train_nlvr.py

'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
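
# Example launch commands (a sketch; adjust the GPU count and paths to your setup):
#   multi-GPU training:  python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --config ./configs/nlvr.yaml --output_dir output/NLVR
#   evaluation only:     python train_nlvr.py --evaluate
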
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
import pickle
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from models.blip_nlvr import blip_nlvr
import utils
from utils import cosine_lr_schedule, warmup_lr_schedule
from data import create_dataset, create_sampler, create_loader


def train(model, data_loader, optimizer, epoch, device, config):
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    step_size = 10

    for i, (image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
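        # NLVR2 pairs two images with one statement; stack them along the batch
        # dimension so a single forward pass encodes both (the blip_nlvr head is
        # expected to split them back and reason over the pair against the text)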
        images = torch.cat([image0, image1], dim=0)
        images, targets = images.to(device), targets.to(device)

        loss = model(images, text, targets=targets, train=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(loss=loss.item())

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(model, data_loader, device, config):
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")

    header = 'Evaluation:'
    print_freq = 50

    for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = torch.cat([image0, image1], dim=0)
        images, targets = images.to(device), targets.to(device)

        prediction = model(images, text, targets=targets, train=False)
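        # prediction holds per-class scores; take the argmax as the predicted class
        # and accumulate per-batch accuracy below, weighted by the batch size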
        _, pred_class = prediction.max(1)
        accuracy = (targets == pred_class).sum() / targets.size(0)

        metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    datasets = create_dataset('nlvr', config)

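    # shuffle only the training split ([True, False, False]); the val and test splits keep a fixed order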
    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False, False], num_tasks, global_rank)
    else:
        samplers = [None, None, None]

    batch_size = [config['batch_size_train'], config['batch_size_test'], config['batch_size_test']]
    train_loader, val_loader, test_loader = create_loader(datasets, samplers, batch_size=batch_size,
                                                          num_workers=[4, 4, 4], is_trains=[True, False, False],
                                                          collate_fns=[None, None, None])

    #### Model ####
    print("Creating model")
    model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
                      vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    print("Start training")
    start_time = time.time()
    best = 0
    best_epoch = 0

    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

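            # step the learning rate once per epoch on a cosine schedule from init_lr down to min_lr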
            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        val_stats = evaluate(model, val_loader, device, config)
        test_stats = evaluate(model, test_loader, device, config)

        if utils.is_main_process():
            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                            }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                             'epoch': epoch,
                            }

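                # save a checkpoint only when validation accuracy improves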
                if float(val_stats['acc']) > best:
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'config': config,
                        'epoch': epoch,
                    }
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                    best = float(val_stats['acc'])
                    best_epoch = epoch

                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        if args.evaluate:
            break

        dist.barrier()

    if utils.is_main_process():
        with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
            f.write("best epoch: %d" % best_epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nlvr.yaml')
    parser.add_argument('--output_dir', default='output/NLVR')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
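    # note: argparse's type=bool parses any non-empty string (including "False") as True,
    # so this flag cannot be switched off from the command line as written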
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
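
# To re-use the best checkpoint saved above, one option (assuming blip_nlvr() accepts a
# local checkpoint path for `pretrained`, as in the public BLIP repo) is to point
# config['pretrained'] at output/NLVR/checkpoint_best.pth and rerun with --evaluate.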