pretrain_dataset.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import json
  2. import os
  3. import random
  4. from torch.utils.data import Dataset
  5. from PIL import Image
  6. from PIL import ImageFile
  7. ImageFile.LOAD_TRUNCATED_IMAGES = True
  8. Image.MAX_IMAGE_PIXELS = None
  9. from data.utils import pre_caption
  10. import os,glob
  11. class pretrain_dataset(Dataset):
  12. def __init__(self, ann_file, laion_path, transform):
  13. self.ann_pretrain = []
  14. for f in ann_file:
  15. print('loading '+f)
  16. ann = json.load(open(f,'r'))
  17. self.ann_pretrain += ann
  18. self.laion_path = laion_path
  19. if self.laion_path:
  20. self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
  21. print('loading '+self.laion_files[0])
  22. with open(self.laion_files[0],'r') as f:
  23. self.ann_laion = json.load(f)
  24. self.annotation = self.ann_pretrain + self.ann_laion
  25. else:
  26. self.annotation = self.ann_pretrain
  27. self.transform = transform
  28. def reload_laion(self, epoch):
  29. n = epoch%len(self.laion_files)
  30. print('loading '+self.laion_files[n])
  31. with open(self.laion_files[n],'r') as f:
  32. self.ann_laion = json.load(f)
  33. self.annotation = self.ann_pretrain + self.ann_laion
  34. def __len__(self):
  35. return len(self.annotation)
  36. def __getitem__(self, index):
  37. ann = self.annotation[index]
  38. image = Image.open(ann['image']).convert('RGB')
  39. image = self.transform(image)
  40. caption = pre_caption(ann['caption'],30)
  41. return image, caption