train_caption.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. '''
  2. * Copyright (c) 2022, salesforce.com, inc.
  3. * All rights reserved.
  4. * SPDX-License-Identifier: BSD-3-Clause
  5. * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
  6. * By Junnan Li
  7. '''
  8. import argparse
  9. import os
  10. import ruamel_yaml as yaml
  11. import numpy as np
  12. import random
  13. import time
  14. import datetime
  15. import json
  16. from pathlib import Path
  17. import torch
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. import torch.backends.cudnn as cudnn
  21. import torch.distributed as dist
  22. from torch.utils.data import DataLoader
  23. from models.blip import blip_decoder
  24. import utils
  25. from utils import cosine_lr_schedule
  26. from data import create_dataset, create_sampler, create_loader
  27. from data.utils import save_result, coco_caption_eval
  28. def train(model, data_loader, optimizer, epoch, device):
  29. # train
  30. model.train()
  31. metric_logger = utils.MetricLogger(delimiter=" ")
  32. metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
  33. metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
  34. header = 'Train Caption Epoch: [{}]'.format(epoch)
  35. print_freq = 50
  36. for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
  37. image = image.to(device)
  38. loss = model(image, caption)
  39. optimizer.zero_grad()
  40. loss.backward()
  41. optimizer.step()
  42. metric_logger.update(loss=loss.item())
  43. metric_logger.update(lr=optimizer.param_groups[0]["lr"])
  44. # gather the stats from all processes
  45. metric_logger.synchronize_between_processes()
  46. print("Averaged stats:", metric_logger.global_avg())
  47. return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
  48. @torch.no_grad()
  49. def evaluate(model, data_loader, device, config):
  50. # evaluate
  51. model.eval()
  52. metric_logger = utils.MetricLogger(delimiter=" ")
  53. header = 'Caption generation:'
  54. print_freq = 10
  55. result = []
  56. for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
  57. image = image.to(device)
  58. captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
  59. min_length=config['min_length'])
  60. for caption, img_id in zip(captions, image_id):
  61. result.append({"image_id": img_id.item(), "caption": caption})
  62. return result
  63. def main(args, config):
  64. utils.init_distributed_mode(args)
  65. device = torch.device(args.device)
  66. # fix the seed for reproducibility
  67. seed = args.seed + utils.get_rank()
  68. torch.manual_seed(seed)
  69. np.random.seed(seed)
  70. random.seed(seed)
  71. cudnn.benchmark = True
  72. #### Dataset ####
  73. print("Creating captioning dataset")
  74. train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
  75. if args.distributed:
  76. num_tasks = utils.get_world_size()
  77. global_rank = utils.get_rank()
  78. samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
  79. else:
  80. samplers = [None, None, None]
  81. train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
  82. batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
  83. is_trains=[True, False, False], collate_fns=[None,None,None])
  84. #### Model ####
  85. print("Creating model")
  86. model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
  87. vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
  88. prompt=config['prompt'])
  89. model = model.to(device)
  90. model_without_ddp = model
  91. if args.distributed:
  92. model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
  93. model_without_ddp = model.module
  94. optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
  95. best = 0
  96. best_epoch = 0
  97. print("Start training")
  98. start_time = time.time()
  99. for epoch in range(0, config['max_epoch']):
  100. if not args.evaluate:
  101. if args.distributed:
  102. train_loader.sampler.set_epoch(epoch)
  103. cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
  104. train_stats = train(model, train_loader, optimizer, epoch, device)
  105. val_result = evaluate(model_without_ddp, val_loader, device, config)
  106. val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
  107. test_result = evaluate(model_without_ddp, test_loader, device, config)
  108. test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
  109. if utils.is_main_process():
  110. coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
  111. coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
  112. if args.evaluate:
  113. log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
  114. **{f'test_{k}': v for k, v in coco_test.eval.items()},
  115. }
  116. with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
  117. f.write(json.dumps(log_stats) + "\n")
  118. else:
  119. save_obj = {
  120. 'model': model_without_ddp.state_dict(),
  121. 'optimizer': optimizer.state_dict(),
  122. 'config': config,
  123. 'epoch': epoch,
  124. }
  125. if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
  126. best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
  127. best_epoch = epoch
  128. torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
  129. log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
  130. **{f'val_{k}': v for k, v in coco_val.eval.items()},
  131. **{f'test_{k}': v for k, v in coco_test.eval.items()},
  132. 'epoch': epoch,
  133. 'best_epoch': best_epoch,
  134. }
  135. with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
  136. f.write(json.dumps(log_stats) + "\n")
  137. if args.evaluate:
  138. break
  139. dist.barrier()
  140. total_time = time.time() - start_time
  141. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  142. print('Training time {}'.format(total_time_str))
  143. if __name__ == '__main__':
  144. parser = argparse.ArgumentParser()
  145. parser.add_argument('--config', default='./configs/caption_coco.yaml')
  146. parser.add_argument('--output_dir', default='output/Caption_coco')
  147. parser.add_argument('--evaluate', action='store_true')
  148. parser.add_argument('--device', default='cuda')
  149. parser.add_argument('--seed', default=42, type=int)
  150. parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
  151. parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
  152. parser.add_argument('--distributed', default=True, type=bool)
  153. args = parser.parse_args()
  154. config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
  155. args.result_dir = os.path.join(args.output_dir, 'result')
  156. Path(args.output_dir).mkdir(parents=True, exist_ok=True)
  157. Path(args.result_dir).mkdir(parents=True, exist_ok=True)
  158. yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
  159. main(args, config)