eval_retrieval_video.py

'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip_retrieval import blip_retrieval
import utils
from data.video_dataset import VideoDataset


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
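    # Two-stage retrieval: rank every video-text pair by the cosine similarity of
    # the unimodal embeddings, then re-score only the top-k candidates per query
    # with the image-grounded text encoder and the ITM head.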
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i + text_bs)]
        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:, 0, :]))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
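    # Replace [CLS] with BLIP's encoder token id ([ENC] in the released tokenizer),
    # the id used when these text ids are later fed to the multimodal text encoder
    # for ITM scoring.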
    text_ids[:, 0] = tokenizer.additional_special_tokens_ids[0]

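    # Encode each video: run all sampled frames through the ViT, average the
    # projected frame-level [CLS] embeddings into a single video embedding, and
    # keep the full per-frame features for the cross-attention reranking stage.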
    video_feats = []
    video_embeds = []
    for video, video_id in data_loader:
        B, N, C, W, H = video.size()
        video = video.view(-1, C, W, H)
        video = video.to(device, non_blocking=True)
        video_feat = model.visual_encoder(video)
        video_embed = model.vision_proj(video_feat[:, 0, :])
        video_embed = video_embed.view(B, N, -1).mean(dim=1)
        video_embed = F.normalize(video_embed, dim=-1)

        video_feat = video_feat.view(B, -1, video_feat.shape[-1])
        video_feats.append(video_feat.cpu())
        video_embeds.append(video_embed)

    video_feats = torch.cat(video_feats, dim=0)
    video_embeds = torch.cat(video_embeds, dim=0)

    sims_matrix = video_embeds @ text_embeds.t()
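    # Score matrices are filled with a large negative value so entries outside the
    # reranked top-k stay at the bottom of the ranking. Rows of the similarity
    # matrix are sharded across ranks; partial results are summed with all_reduce below.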
    score_matrix_v2t = torch.full((len(texts), len(texts)), -100.0).to(device)

    num_tasks = utils.get_world_size()
    rank = utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)

        encoder_output = video_feats[start + i].repeat(config['k_test'], 1, 1).to(device, non_blocking=True)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
        output = model.text_encoder(text_ids[topk_idx],
                                    attention_mask=text_atts[topk_idx],
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_v2t[start + i, topk_idx] = score + topk_sim
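
    # Repeat the reranking in the text->video direction on the transposed
    # similarity matrix.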
    sims_matrix = sims_matrix.t()
    score_matrix_t2v = torch.full((len(texts), len(texts)), -100.0).to(device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
        encoder_output = video_feats[topk_idx].to(device, non_blocking=True)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
        output = model.text_encoder(text_ids[start + i].repeat(config['k_test'], 1),
                                    attention_mask=text_atts[start + i].repeat(config['k_test'], 1),
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2v[start + i, topk_idx] = score + topk_sim

    if args.distributed:
        dist.barrier()
        torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()


@torch.no_grad()
def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
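    # For each query, sort all candidates by score and record the rank of the
    # ground-truth match; R@K is the percentage of queries whose match ranks in
    # the top K, and mdR is the median (1-indexed) rank for text->video.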
    # Video->Text
    ranks = np.zeros(scores_v2t.shape[0])
    for index, score in enumerate(scores_v2t):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == vid2txt[index])[0][0]

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text->Video
    ranks = np.zeros(scores_t2v.shape[0])
    for index, score in enumerate(scores_t2v):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2vmg[index])[0][0]

    mdR = np.median(ranks + 1)

    # Compute metrics
    vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    vr_mean = (vr1 + vr5 + vr10) / 3
    r_mean = (tr_mean + vr_mean) / 2

    eval_result = {'txt_r1': tr1,
                   'txt_r5': tr5,
                   'txt_r10': tr10,
                   'txt_r_mean': tr_mean,
                   'vid_r1': vr1,
                   'vid_r5': vr5,
                   'vid_r10': vr10,
                   'vid_r_mean': vr_mean,
                   'vid_mdR': mdR,
                   'r_mean': r_mean}
    return eval_result


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    test_dataset = VideoDataset(config['video_root'], config['ann_root'], num_frm=config['num_frm_test'],
                                max_img_size=config['image_size'], frm_sampling_strategy='uniform')

    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
    )

    #### Model ####
    print("Creating model")
    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    score_v2t, score_t2v = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)

    if utils.is_main_process():
        test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
        print(test_result)

        log_stats = {**{f'{k}': v for k, v in test_result.items()}}
        with open(os.path.join(args.output_dir, "test_result.txt"), "a") as f:
            f.write(json.dumps(log_stats) + "\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
    parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
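    # Note: argparse's type=bool converts any non-empty string (including 'False')
    # to True, so this flag effectively stays enabled whenever a value is supplied.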
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)