__init__.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import torch
  2. from torch.utils.data import DataLoader
  3. from torchvision import transforms
  4. from torchvision.transforms.functional import InterpolationMode
  5. from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
  6. from data.nocaps_dataset import nocaps_eval
  7. from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
  8. from data.vqa_dataset import vqa_dataset
  9. from data.nlvr_dataset import nlvr_dataset
  10. from data.pretrain_dataset import pretrain_dataset
  11. from transform.randaugment import RandomAugment
  12. def create_dataset(dataset, config, min_scale=0.5):
  13. normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
  14. transform_train = transforms.Compose([
  15. transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
  16. transforms.RandomHorizontalFlip(),
  17. RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
  18. 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
  19. transforms.ToTensor(),
  20. normalize,
  21. ])
  22. transform_test = transforms.Compose([
  23. transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
  24. transforms.ToTensor(),
  25. normalize,
  26. ])
  27. if dataset=='pretrain':
  28. dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
  29. return dataset
  30. elif dataset=='caption_coco':
  31. train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
  32. val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
  33. test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
  34. return train_dataset, val_dataset, test_dataset
  35. elif dataset=='nocaps':
  36. val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
  37. test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
  38. return val_dataset, test_dataset
  39. elif dataset=='retrieval_coco':
  40. train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
  41. val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
  42. test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
  43. return train_dataset, val_dataset, test_dataset
  44. elif dataset=='retrieval_flickr':
  45. train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
  46. val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
  47. test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
  48. return train_dataset, val_dataset, test_dataset
  49. elif dataset=='vqa':
  50. train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
  51. train_files = config['train_files'], split='train')
  52. test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
  53. return train_dataset, test_dataset
  54. elif dataset=='nlvr':
  55. train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
  56. val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
  57. test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
  58. return train_dataset, val_dataset, test_dataset
  59. def create_sampler(datasets, shuffles, num_tasks, global_rank):
  60. samplers = []
  61. for dataset,shuffle in zip(datasets,shuffles):
  62. sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
  63. samplers.append(sampler)
  64. return samplers
  65. def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
  66. loaders = []
  67. for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
  68. if is_train:
  69. shuffle = (sampler is None)
  70. drop_last = True
  71. else:
  72. shuffle = False
  73. drop_last = False
  74. loader = DataLoader(
  75. dataset,
  76. batch_size=bs,
  77. num_workers=n_worker,
  78. pin_memory=True,
  79. sampler=sampler,
  80. shuffle=shuffle,
  81. collate_fn=collate_fn,
  82. drop_last=drop_last,
  83. )
  84. loaders.append(loader)
  85. return loaders