# video_dataset.py
  1. from torch.utils.data import Dataset
  2. from torchvision.datasets.utils import download_url
  3. from PIL import Image
  4. import torch
  5. import numpy as np
  6. import random
  7. import decord
  8. from decord import VideoReader
  9. import json
  10. import os
  11. from data.utils import pre_caption
  12. decord.bridge.set_bridge("torch")
  13. class ImageNorm(object):
  14. """Apply Normalization to Image Pixels on GPU
  15. """
  16. def __init__(self, mean, std):
  17. self.mean = torch.tensor(mean).view(1, 3, 1, 1)
  18. self.std = torch.tensor(std).view(1, 3, 1, 1)
  19. def __call__(self, img):
  20. if torch.max(img) > 1 and self.mean.max() <= 1:
  21. img.div_(255.)
  22. return img.sub_(self.mean).div_(self.std)
  23. def load_jsonl(filename):
  24. with open(filename, "r") as f:
  25. return [json.loads(l.strip("\n")) for l in f.readlines()]
  26. class VideoDataset(Dataset):
  27. def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
  28. '''
  29. image_root (string): Root directory of video
  30. ann_root (string): directory to store the annotation file
  31. '''
  32. url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
  33. filename = 'msrvtt_test.jsonl'
  34. download_url(url,ann_root)
  35. self.annotation = load_jsonl(os.path.join(ann_root,filename))
  36. self.num_frm = num_frm
  37. self.frm_sampling_strategy = frm_sampling_strategy
  38. self.max_img_size = max_img_size
  39. self.video_root = video_root
  40. self.video_fmt = video_fmt
  41. self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
  42. self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
  43. self.txt2video = [i for i in range(len(self.annotation))]
  44. self.video2txt = self.txt2video
  45. def __len__(self):
  46. return len(self.annotation)
  47. def __getitem__(self, index):
  48. ann = self.annotation[index]
  49. video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
  50. vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
  51. video = self.img_norm(vid_frm_array.float())
  52. return video, ann['clip_name']
  53. def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
  54. try:
  55. if not height or not width:
  56. vr = VideoReader(video_path)
  57. else:
  58. vr = VideoReader(video_path, width=width, height=height)
  59. vlen = len(vr)
  60. if start_time or end_time:
  61. assert fps > 0, 'must provide video fps if specifying start and end time.'
  62. start_idx = min(int(start_time * fps), vlen)
  63. end_idx = min(int(end_time * fps), vlen)
  64. else:
  65. start_idx, end_idx = 0, vlen
  66. if self.frm_sampling_strategy == 'uniform':
  67. frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
  68. elif self.frm_sampling_strategy == 'rand':
  69. frame_indices = sorted(random.sample(range(vlen), self.num_frm))
  70. elif self.frm_sampling_strategy == 'headtail':
  71. frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
  72. frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
  73. frame_indices = frame_indices_head + frame_indices_tail
  74. else:
  75. raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
  76. raw_sample_frms = vr.get_batch(frame_indices)
  77. except Exception as e:
  78. return None
  79. raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
  80. return raw_sample_frms