'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip import blip_decoder
import utils
from data import create_dataset, create_sampler, create_loader
from data.utils import save_result
  27. @torch.no_grad()
  28. def evaluate(model, data_loader, device, config):
  29. # evaluate
  30. model.eval()
  31. metric_logger = utils.MetricLogger(delimiter=" ")
  32. header = 'Evaluation:'
  33. print_freq = 10
  34. result = []
  35. for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
  36. image = image.to(device)
  37. captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
  38. min_length=config['min_length'], repetition_penalty=1.1)
  39. for caption, img_id in zip(captions, image_id):
  40. result.append({"image_id": img_id.item(), "caption": caption})
  41. return result
  42. def main(args, config):
  43. utils.init_distributed_mode(args)
  44. device = torch.device(args.device)
  45. # fix the seed for reproducibility
  46. seed = args.seed + utils.get_rank()
  47. torch.manual_seed(seed)
  48. np.random.seed(seed)
  49. random.seed(seed)
  50. cudnn.benchmark = True
  51. #### Dataset ####
  52. print("Creating captioning dataset")
  53. val_dataset, test_dataset = create_dataset('nocaps', config)
  54. if args.distributed:
  55. num_tasks = utils.get_world_size()
  56. global_rank = utils.get_rank()
  57. samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
  58. else:
  59. samplers = [None,None]
  60. val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
  61. batch_size=[config['batch_size']]*2,num_workers=[4,4],
  62. is_trains=[False, False], collate_fns=[None,None])
  63. #### Model ####
  64. print("Creating model")
  65. model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
  66. prompt=config['prompt'])
  67. model = model.to(device)
  68. model_without_ddp = model
  69. if args.distributed:
  70. model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
  71. model_without_ddp = model.module
  72. val_result = evaluate(model_without_ddp, val_loader, device, config)
  73. val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
  74. test_result = evaluate(model_without_ddp, test_loader, device, config)
  75. test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
  76. if __name__ == '__main__':
  77. parser = argparse.ArgumentParser()
  78. parser.add_argument('--config', default='./configs/nocaps.yaml')
  79. parser.add_argument('--output_dir', default='output/NoCaps')
  80. parser.add_argument('--device', default='cuda')
  81. parser.add_argument('--seed', default=42, type=int)
  82. parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
  83. parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
  84. parser.add_argument('--distributed', default=True, type=bool)
  85. args = parser.parse_args()
  86. config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
  87. args.result_dir = os.path.join(args.output_dir, 'result')
  88. Path(args.output_dir).mkdir(parents=True, exist_ok=True)
  89. Path(args.result_dir).mkdir(parents=True, exist_ok=True)
  90. yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
  91. main(args, config)