import glob
import json
import os

from torch.utils.data import Dataset

from PIL import Image
from PIL import ImageFile

# Tolerate truncated files and disable the decompression-bomb size check,
# since web-scraped pretraining images are often imperfect or very large
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

from data.utils import pre_caption
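
# Each annotation JSON file is expected to hold a list of dicts shaped like
# the example below (the field names come from __getitem__; the values are
# hypothetical):
#   [{"image": "images/0001.jpg", "caption": "a dog playing in the park"}, ...]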

class pretrain_dataset(Dataset):
    """Image-caption dataset that mixes fixed pretraining annotations with
    one LAION shard, which reload_laion() swaps out every epoch."""

    def __init__(self, ann_file, laion_path, transform):
        # Load the fixed pretraining annotations from one or more JSON files
        self.ann_pretrain = []
        for f in ann_file:
            print('loading ' + f)
            with open(f, 'r') as fp:
                self.ann_pretrain += json.load(fp)

        # Optionally add the first LAION shard found under laion_path
        self.laion_path = laion_path
        if self.laion_path:
            self.laion_files = glob.glob(os.path.join(laion_path, '*.json'))

            print('loading ' + self.laion_files[0])
            with open(self.laion_files[0], 'r') as fp:
                self.ann_laion = json.load(fp)

            self.annotation = self.ann_pretrain + self.ann_laion
        else:
            self.annotation = self.ann_pretrain

        self.transform = transform

    def reload_laion(self, epoch):
        # Rotate to a different LAION shard each epoch
        n = epoch % len(self.laion_files)
        print('loading ' + self.laion_files[n])
        with open(self.laion_files[n], 'r') as fp:
            self.ann_laion = json.load(fp)

        self.annotation = self.ann_pretrain + self.ann_laion

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        # Load the image, apply the transform, and normalise the caption
        image = Image.open(ann['image']).convert('RGB')
        image = self.transform(image)
        caption = pre_caption(ann['caption'], 30)  # caption capped at 30 words

        return image, caption
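
# A minimal usage sketch (an assumption, not part of the original file): wrap
# the dataset in a DataLoader and call reload_laion() once per epoch to rotate
# the LAION shard. The annotation path, LAION directory, and transform below
# are hypothetical placeholders.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ])
    dataset = pretrain_dataset(['annotations/coco.json'],  # hypothetical path
                               'laion/', transform)        # hypothetical path
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    for epoch in range(2):
        dataset.reload_laion(epoch)  # swap in a different LAION shard
        for images, captions in loader:
            pass  # a training step would go here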