pretrain.py

'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip_pretrain import blip_pretrain
import utils
from utils import warmup_lr_schedule, step_lr_schedule
from data import create_dataset, create_sampler, create_loader

def train(model, data_loader, optimizer, epoch, device, config):
    """Run one pretraining epoch over image-caption pairs.

    The model returns the three BLIP pretraining losses: image-text
    contrastive (ita), image-text matching (itm), and language modeling (lm).
    """
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
    metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    if config['laion_path']:
        data_loader.dataset.reload_laion(epoch)

    data_loader.sampler.set_epoch(epoch)

    for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        if epoch == 0:
            warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])

        optimizer.zero_grad()

        image = image.to(device, non_blocking=True)

        # ramp up alpha in the first 2 epochs
        alpha = config['alpha'] * min(1, (epoch * len(data_loader) + i) / (2 * len(data_loader)))

        loss_ita, loss_itm, loss_lm = model(image, caption, alpha=alpha)
        loss = loss_ita + loss_itm + loss_lm

        loss.backward()
        optimizer.step()

        metric_logger.update(loss_ita=loss_ita.item())
        metric_logger.update(loss_itm=loss_itm.item())
        metric_logger.update(loss_lm=loss_lm.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}

def main(args, config):
    """Create the pretraining dataset, model, and optimizer, then train for
    config['max_epoch'] epochs, checkpointing and logging on the main process."""
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    datasets = [create_dataset('pretrain', config, min_scale=0.2)]
    print('number of training samples: %d' % len(datasets[0]))

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    samplers = create_sampler(datasets, [True], num_tasks, global_rank)

    data_loader = create_loader(datasets, samplers, batch_size=[config['batch_size']], num_workers=[4],
                                is_trains=[True], collate_fns=[None])[0]

    #### Model ####
    print("Creating model")
    model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
                          vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])

    model = model.to(device)

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    start_epoch = 0
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        state_dict = checkpoint['model']
        model.load_state_dict(state_dict)

        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        print('resume checkpoint from %s' % args.checkpoint)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    print("Start training")
    start_time = time.time()
    for epoch in range(start_epoch, config['max_epoch']):

        step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])

        train_stats = train(model, data_loader, optimizer, epoch, device, config)

        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         }
            save_obj = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth' % epoch))

            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/pretrain.yaml')
    parser.add_argument('--output_dir', default='output/Pretrain')
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # note: argparse's type=bool treats any non-empty string (including "False") as True
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
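
    # For reference, these are the keys this script itself reads from the loaded
    # config (create_dataset may read additional dataset-specific keys, e.g.
    # annotation file paths, which are not listed here):
    #   image_size, vit, vit_grad_ckpt, vit_ckpt_layer, queue_size   -> model construction
    #   batch_size, laion_path                                        -> dataset / data loader
    #   init_lr, min_lr, lr_decay_rate, warmup_lr, warmup_steps,
    #   weight_decay, alpha, max_epoch                                -> optimization schedule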
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
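
# A minimal launch sketch, assuming a single node with multiple GPUs and that
# utils.init_distributed_mode() reads the environment variables set by the
# PyTorch launcher (adjust --nproc_per_node to your GPU count; the paths below
# are just this script's defaults, not a prescribed setup):
#
#   python -m torch.distributed.run --nproc_per_node=8 pretrain.py \
#       --config ./configs/pretrain.yaml --output_dir output/Pretrain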