'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
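
# Fine-tunes BLIP for visual question answering. Typically launched through
# torch's elastic runner so utils.init_distributed_mode can read the rank and
# world-size environment variables, e.g. (the GPU count is illustrative):
#
#   python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --config ./configs/vqa.yaml
#
# Pass --evaluate to skip training and only generate test-set answers.
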
import argparse
import os
import random
import time
import datetime
import json
from pathlib import Path

import ruamel_yaml as yaml  # note: recent ruamel releases ship this module as ruamel.yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from models.blip_vqa import blip_vqa
import utils
from utils import cosine_lr_schedule
from data import create_dataset, create_sampler, create_loader
from data.vqa_dataset import vqa_collate_fn
from data.utils import save_result
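
# Config keys consumed below (values come from --config, default ./configs/vqa.yaml):
#   model:     pretrained, image_size, vit, vit_grad_ckpt, vit_ckpt_layer
#   data:      batch_size_train, batch_size_test
#   schedule:  init_lr, min_lr, weight_decay, max_epoch
#   inference: inference ('generate' or 'rank'), k_test (rank mode only)
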
def train(model, data_loader, optimizer, epoch, device):
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    # each batch from vqa_collate_fn flattens the per-question answer lists:
    # `n` holds the answer count per question and `weights` the per-answer
    # soft labels, which the model uses to weight the answer loss
    for i, (image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image, weights = image.to(device, non_blocking=True), weights.to(device, non_blocking=True)

        loss = model(image, question, answer, train=True, n=n, weights=weights)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}

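# Evaluation supports two modes: 'generate' decodes free-form answer strings,
# while 'rank' scores every entry of the dataset's answer_list and returns
# indices into it (k_test is forwarded to the model's rank inference). Either
# way the output is a list of {"question_id", "answer"} records that
# save_result (data.utils) later writes to args.result_dir as JSON.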
@torch.no_grad()
def evaluation(model, data_loader, device, config):
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Generate VQA test result:'
    print_freq = 50

    result = []

    if config['inference'] == 'rank':
        answer_list = data_loader.dataset.answer_list
        answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
        # the answer decoder expects each candidate to start with the BOS token
        answer_candidates.input_ids[:, 0] = model.tokenizer.bos_token_id

    for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image = image.to(device, non_blocking=True)

        if config['inference'] == 'generate':
            answers = model(image, question, train=False, inference='generate')

            for answer, ques_id in zip(answers, question_id):
                ques_id = int(ques_id.item())
                result.append({"question_id": ques_id, "answer": answer})

        elif config['inference'] == 'rank':
            answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])

            for ques_id, answer_id in zip(question_id, answer_ids):
                result.append({"question_id": int(ques_id.item()), "answer": answer_list[answer_id]})

    return result

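# Standard fine-tuning scaffold: seeded setup, dataset and loader creation,
# optional DDP wrapping, AdamW with a cosine learning-rate schedule, a
# checkpoint per epoch, then a single evaluation pass at the end.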
def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating vqa datasets")
    datasets = create_dataset('vqa', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
    else:
        samplers = [None, None]

    train_loader, test_loader = create_loader(datasets, samplers,
                                              batch_size=[config['batch_size_train'], config['batch_size_test']],
                                              num_workers=[4, 4], is_trains=[True, False],
                                              collate_fns=[vqa_collate_fn, None])

    #### Model ####
    print("Creating model")
    model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
                     vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
                     vit_ckpt_layer=config['vit_ckpt_layer'])
    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()
    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device)
        else:
            break

        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         }
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            save_obj = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth' % epoch))

        # only synchronize when a process group exists; an unconditional
        # dist.barrier() raises in single-process (non-distributed) runs
        if args.distributed:
            dist.barrier()

    vqa_result = evaluation(model_without_ddp, test_loader, device, config)
    result_file = save_result(vqa_result, args.result_dir, 'vqa_result')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/vqa.yaml')
    parser.add_argument('--output_dir', default='output/VQA')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # argparse's type=bool treats any non-empty string (including 'False') as
    # True, so parse the flag explicitly instead
    parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1'))
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    # keep a copy of the resolved config next to the checkpoints
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    main(args, config)