nlvr_dataset.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import json
  3. import random
  4. from torch.utils.data import Dataset
  5. from torchvision.datasets.utils import download_url
  6. from PIL import Image
  7. from data.utils import pre_caption
  8. class nlvr_dataset(Dataset):
  9. def __init__(self, transform, image_root, ann_root, split):
  10. '''
  11. image_root (string): Root directory of images
  12. ann_root (string): directory to store the annotation file
  13. split (string): train, val or test
  14. '''
  15. urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
  16. 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
  17. 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
  18. filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
  19. download_url(urls[split],ann_root)
  20. self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
  21. self.transform = transform
  22. self.image_root = image_root
  23. def __len__(self):
  24. return len(self.annotation)
  25. def __getitem__(self, index):
  26. ann = self.annotation[index]
  27. image0_path = os.path.join(self.image_root,ann['images'][0])
  28. image0 = Image.open(image0_path).convert('RGB')
  29. image0 = self.transform(image0)
  30. image1_path = os.path.join(self.image_root,ann['images'][1])
  31. image1 = Image.open(image1_path).convert('RGB')
  32. image1 = self.transform(image1)
  33. sentence = pre_caption(ann['sentence'], 40)
  34. if ann['label']=='True':
  35. label = 1
  36. else:
  37. label = 0
  38. words = sentence.split(' ')
  39. if 'left' not in words and 'right' not in words:
  40. if random.random()<0.5:
  41. return image0, image1, sentence, label
  42. else:
  43. return image1, image0, sentence, label
  44. else:
  45. if random.random()<0.5:
  46. return image0, image1, sentence, label
  47. else:
  48. new_words = []
  49. for word in words:
  50. if word=='left':
  51. new_words.append('right')
  52. elif word=='right':
  53. new_words.append('left')
  54. else:
  55. new_words.append(word)
  56. sentence = ' '.join(new_words)
  57. return image1, image0, sentence, label