# utils.py
  1. import re
  2. import json
  3. import os
  4. import torch
  5. import torch.distributed as dist
  6. import utils
  7. def pre_caption(caption,max_words=50):
  8. caption = re.sub(
  9. r"([.!\"()*#:;~])",
  10. ' ',
  11. caption.lower(),
  12. )
  13. caption = re.sub(
  14. r"\s{2,}",
  15. ' ',
  16. caption,
  17. )
  18. caption = caption.rstrip('\n')
  19. caption = caption.strip(' ')
  20. #truncate caption
  21. caption_words = caption.split(' ')
  22. if len(caption_words)>max_words:
  23. caption = ' '.join(caption_words[:max_words])
  24. return caption
  25. def pre_question(question,max_ques_words=50):
  26. question = re.sub(
  27. r"([.!\"()*#:;~])",
  28. '',
  29. question.lower(),
  30. )
  31. question = question.rstrip(' ')
  32. #truncate question
  33. question_words = question.split(' ')
  34. if len(question_words)>max_ques_words:
  35. question = ' '.join(question_words[:max_ques_words])
  36. return question
  37. def save_result(result, result_dir, filename, remove_duplicate=''):
  38. result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
  39. final_result_file = os.path.join(result_dir, '%s.json'%filename)
  40. json.dump(result,open(result_file,'w'))
  41. dist.barrier()
  42. if utils.is_main_process():
  43. # combine results from all processes
  44. result = []
  45. for rank in range(utils.get_world_size()):
  46. result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
  47. res = json.load(open(result_file,'r'))
  48. result += res
  49. if remove_duplicate:
  50. result_new = []
  51. id_list = []
  52. for res in result:
  53. if res[remove_duplicate] not in id_list:
  54. id_list.append(res[remove_duplicate])
  55. result_new.append(res)
  56. result = result_new
  57. json.dump(result,open(final_result_file,'w'))
  58. print('result file saved to %s'%final_result_file)
  59. return final_result_file
  60. from pycocotools.coco import COCO
  61. from pycocoevalcap.eval import COCOEvalCap
  62. from torchvision.datasets.utils import download_url
  63. def coco_caption_eval(coco_gt_root, results_file, split):
  64. urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
  65. 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
  66. filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
  67. download_url(urls[split],coco_gt_root)
  68. annotation_file = os.path.join(coco_gt_root,filenames[split])
  69. # create coco object and coco_result object
  70. coco = COCO(annotation_file)
  71. coco_result = coco.loadRes(results_file)
  72. # create coco_eval object by taking coco and coco_result
  73. coco_eval = COCOEvalCap(coco, coco_result)
  74. # evaluate on a subset of images by setting
  75. # coco_eval.params['image_id'] = coco_result.getImgIds()
  76. # please remove this line when evaluating the full validation set
  77. # coco_eval.params['image_id'] = coco_result.getImgIds()
  78. # evaluate results
  79. # SPICE will take a few minutes the first time, but speeds up due to caching
  80. coco_eval.evaluate()
  81. # print output evaluation scores
  82. for metric, score in coco_eval.eval.items():
  83. print(f'{metric}: {score:.3f}')
  84. return coco_eval