train_retrieval.py

'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml  # legacy flat package name; newer installs expose ruamel.yaml instead
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip_retrieval import blip_retrieval
import utils
from utils import cosine_lr_schedule
from data import create_dataset, create_sampler, create_loader
def train(model, data_loader, optimizer, epoch, device, config):
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    for i, (image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image = image.to(device, non_blocking=True)
        idx = idx.to(device, non_blocking=True)

        if epoch > 0:
            alpha = config['alpha']
        else:
            alpha = config['alpha'] * min(1, i / len(data_loader))
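        # alpha blends hard one-hot targets with the momentum encoder's soft
        # targets in the contrastive loss (ALBEF-style momentum distillation);
        # it ramps up linearly over the first epoch, then stays fixed.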
        loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
        loss = loss_ita + loss_itm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss_itm=loss_itm.item())
        metric_logger.update(loss_ita=loss_ita.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluation(model, data_loader, device, config):
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i + text_bs)]
        text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:, 0, :]))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    text_ids[:, 0] = model.tokenizer.enc_token_id
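    # Replacing the first token with enc_token_id above puts the text encoder
    # into BLIP's image-grounded mode when these ids are reused for ITM
    # re-ranking below.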
    image_feats = []
    image_embeds = []
    for image, img_id in data_loader:
        image = image.to(device)
        image_feat = model.visual_encoder(image)
        image_embed = model.vision_proj(image_feat[:, 0, :])
        image_embed = F.normalize(image_embed, dim=-1)

        image_feats.append(image_feat.cpu())
        image_embeds.append(image_embed)

    image_feats = torch.cat(image_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = image_embeds @ text_embeds.t()
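    # Two-stage retrieval: the cheap dot-product similarities above select
    # k_test candidates per query, and the ITM head re-scores only those
    # candidates with full cross-attention below.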
    score_matrix_i2t = torch.full((len(data_loader.dataset.image), len(texts)), -100.0).to(device)

    num_tasks = utils.get_world_size()
    rank = utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)
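    # Each rank re-ranks a disjoint slice of rows; the slices are merged with
    # an all_reduce after both directions are scored.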
    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)

        encoder_output = image_feats[start + i].repeat(config['k_test'], 1, 1).to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = model.text_encoder(text_ids[topk_idx],
                                    attention_mask=text_atts[topk_idx],
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim
    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((len(texts), len(data_loader.dataset.image)), -100.0).to(device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
        encoder_output = image_feats[topk_idx].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = model.text_encoder(text_ids[start + i].repeat(config['k_test'], 1),
                                    attention_mask=text_atts[start + i].repeat(config['k_test'], 1),
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim
    if args.distributed:  # relies on the module-level args set in __main__
        dist.barrier()
        torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
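    # Row slices are disjoint across ranks and unfilled entries stay at -100,
    # so the SUM all_reduce simply merges the slices: within each row every
    # scored entry receives the same -100*(world_size-1) offset, leaving the
    # per-row ranking (and hence the recall metrics) unchanged.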
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    # Images -> Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        # rank = best (lowest) position of any ground-truth caption
        rank = 1e20
        for i in img2txt[index]:
            tmp = np.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    # Text -> Images
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2img[index])[0][0]

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {'txt_r1': tr1,
                   'txt_r5': tr5,
                   'txt_r10': tr10,
                   'txt_r_mean': tr_mean,
                   'img_r1': ir1,
                   'img_r5': ir5,
                   'img_r10': ir10,
                   'img_r_mean': ir_mean,
                   'r_mean': r_mean}
    return eval_result
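# Example: an image whose best-ranked ground-truth caption lands at sorted
# position 3 counts toward txt_r5 and txt_r10 but not txt_r1 (positions are
# 0-based, so R@K counts ranks < K).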
def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s' % config['dataset'], config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]
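    # Only the train set gets a distributed sampler; val/test loaders stay
    # unsharded so every rank sees the full gallery, and evaluation() shards
    # work by similarity-matrix rows instead.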
    train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers,
                                                          batch_size=[config['batch_size_train']] + [config['batch_size_test']] * 2,
                                                          num_workers=[4, 4, 4],
                                                          is_trains=[True, False, False],
                                                          collate_fns=[None, None, None])
    #### Model ####
    print("Creating model")
    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                           vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                           queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
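    # The learning rate is set manually once per epoch via cosine_lr_schedule
    # below, so no torch LR scheduler is attached to the optimizer.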
    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()
    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, device, config)
        score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
        if utils.is_main_process():

            val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
            print(val_result)

            if val_result['r_mean'] > best:
                save_obj = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'epoch': epoch,
                }
                torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                best = val_result['r_mean']
                best_epoch = epoch

                # test metrics are (re)computed only when validation improves,
                # so the logged test_result always matches the best checkpoint
                test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
                print(test_result)

            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
                             **{f'test_{k}': v for k, v in test_result.items()},
                             }
                with open(os.path.join(args.output_dir, "evaluate.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in val_result.items()},
                             **{f'test_{k}': v for k, v in test_result.items()},
                             'epoch': epoch,
                             'best_epoch': best_epoch,
                             }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
        if args.evaluate:
            break

        # assumes the (default) distributed mode: barrier() requires an
        # initialized process group
        dist.barrier()
        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
    parser.add_argument('--output_dir', default='output/Retrieval_flickr')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # caution: argparse's type=bool treats any non-empty string (even 'False') as True
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
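# Example launch (following the BLIP README's multi-GPU invocation; adjust
# --nproc_per_node to your GPU count):
#   python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
#       --config ./configs/retrieval_flickr.yaml \
#       --output_dir output/retrieval_flickr
# Add --evaluate to skip training and only run retrieval evaluation.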