mazpie's picture
Initial commit
2d9a728
raw
history blame
20.7 kB
import logging
import os
import json
import random
import io
import torch
import numpy as np
from dataset.base_dataset import BaseDataset
from dataset.text_prompt import kinetics_templates, imagenet_templates
from dataset.utils import pre_text
from dataset.video_utils import VIDEO_READER_FUNCS
from dataset.serialize import get_local_rank, TorchShmSerializedList
logger = logging.getLogger(__name__)
class ImgTxtPtTrainDataset(BaseDataset):
    """Image-caption pretraining dataset backed by a json annotation file.

    The annotation file is read once on local rank 0 and then shared across
    ranks via ``TorchShmSerializedList``.  Captions shorter than
    ``min_caption_length`` words are filtered out (unless ``jump_filter`` is
    set), and optional prompt templates (ImageNet / Kinetics) can be applied
    to each caption at sampling time.
    """

    media_type = "image"

    def __init__(self, ann_file, transform, num_epochs=1):
        """
        Args:
            ann_file: config-like object (attribute + ``get`` access) holding
                ``anno_path``, ``data_root``, ``media_type`` and optional
                filtering / prompt / caption-augmentation settings.
            transform: vision transform applied by the media loader.
            num_epochs: must be >= 1; smaller values are not supported.
        """
        super().__init__()
        logger.info(f"ann_file: {ann_file}")

        self.media_type = ann_file.media_type
        self.label_file = ann_file.anno_path
        self.data_root = ann_file.data_root
        self.data_root_prefix = ann_file.get("data_root_prefix", "")
        self.min_caption_length = ann_file.get("min_caption_length", 2)
        self.caption_augmentation = ann_file.get("caption_augmentation", None)
        self.transform = transform
        # each caption has multiple image as ground_truth, e.g., ssv2
        self.has_multi_vision_gt = ann_file.get("has_multi_vision_gt", False)
        assert not self.has_multi_vision_gt
        self.crop_img = ann_file.get("crop_img", False)

        self.use_prompt = ann_file.get("prompt", "") != ""
        if self.use_prompt:
            if ann_file.prompt == "imagenet":
                self.prompt = imagenet_templates
                logger.info("Use prompt for ImageNet")
            elif ann_file.prompt == "kinetics":
                self.prompt = kinetics_templates
                logger.info("Use prompt for Kinetics")
            else:
                raise NotImplementedError(ann_file.prompt)
            logger.info(self.prompt)

        if self.use_prompt and self.caption_augmentation is not None:
            raise NotImplementedError("You can't use prompt because of multiple captions!")

        if '.json' in self.label_file:
            logger.info(f"Loading json file {self.label_file}")

            if get_local_rank() == 0:  # Only one rank need to read the file
                with io.BytesIO(self.client.get(self.label_file)) as f:
                    annos = json.load(f)

                if ann_file.get("jump_filter", False):
                    logger.info("Jump filter!")
                else:
                    if self.caption_augmentation is not None:
                        # filter out the caption with length less than min_caption_length
                        new_annos = []
                        if self.media_type == "audio_video" and self.caption_augmentation.caption_sample_type == 'avs_all':
                            # keep only .mp4 clips whose every non-asr caption field is long enough
                            for anno in annos:
                                ok = True
                                if not anno['video'].endswith('.mp4'):
                                    ok = False
                                for k in anno.keys():
                                    if "caption" in k and 'asr' not in k:
                                        tmp_c = pre_text(anno[k])
                                        if len(tmp_c.split()) < self.min_caption_length:
                                            ok = False
                                            break
                                if ok:
                                    new_annos.append(anno)
                        elif self.caption_augmentation.caption_sample_type == 'uniform':
                            # keep samples that retain at least one caption after filtering
                            for anno in annos:
                                if "captions" in anno.keys():
                                    caption_key = "captions"
                                else:
                                    caption_key = "caption"

                                assert type(anno[caption_key]) is list, type(anno[caption_key])
                                caption_list = []
                                for c in anno[caption_key]:
                                    tmp_c = pre_text(c)
                                    if len(tmp_c.split()) >= self.min_caption_length:
                                        caption_list.append(tmp_c)

                                if len(caption_list) > 0:
                                    new_annos.append(anno)
                        else:
                            raise NotImplementedError(ann_file)

                        logger.info(f"Num samples: {len(annos)}")
                        logger.info(f"Num samples not too short: {len(new_annos)} min_caption_length={self.min_caption_length}")
                        annos = new_annos
                    else:
                        # filter out the caption with length less than min_caption_length
                        captions = [pre_text(anno["caption"]) for anno in annos]
                        captions_len = [len(caption.split()) for caption in captions]
                        logger.info("Num samples: {}".format(len(captions)))
                        logger.info("Num samples too short: {}".format(sum([l < self.min_caption_length for l in captions_len])))
                        annos = [anno for anno, l in zip(annos, captions_len) if l >= self.min_caption_length]

                if num_epochs < 1:
                    raise NotImplementedError
            else:
                # non-zero ranks receive the annotations through shared memory below
                annos = []

            self.anno = TorchShmSerializedList(annos)
            self.num_examples = len(self.anno)
            logger.info(f"num_examples: {self.num_examples}")
        else:
            raise NotImplementedError("We need json file!!!")

    def __len__(self):
        return self.num_examples

    def get_caption(self, index):
        """Return the raw caption for ``index``.

        Depending on ``caption_augmentation`` this is a single string, a
        uniformly sampled string from a caption list ('uniform'), or a dict
        of every caption-like field of the annotation ('avs_all').
        """
        if '.json' in self.label_file:
            if self.caption_augmentation is not None:
                if self.caption_augmentation.caption_sample_type == 'avs_all':
                    caption_dict = {}
                    for k in self.anno[index].keys():
                        if 'caption' in k:
                            caption_dict[k] = self.anno[index][k]
                else:
                    if "captions" in self.anno[index].keys():
                        captions = self.anno[index]["captions"]
                    else:
                        captions = self.anno[index]["caption"]
            else:
                caption = self.anno[index]["caption"]
        else:
            raise NotImplementedError

        if self.caption_augmentation is not None:
            if self.caption_augmentation.caption_sample_type == 'uniform':
                caption = random.choice(captions)
            elif self.caption_augmentation.caption_sample_type == 'avs_all':
                caption = caption_dict
            else:
                raise NotImplementedError
        return caption

    def get_anno(self, index):
        """Build the annotation dict (caption + absolute image path, plus an
        optional crop bbox) for ``index``."""
        assert self.media_type == 'image', self.media_type
        anno = {"caption": self.get_caption(index)}
        anno["image"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["image"])
        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])
        if self.crop_img:
            anno["crop_bbox"] = self.anno[index]["crop_bbox"]
        return anno

    def pre_caption(self, caption):
        """Normalize a caption string, or every entry of an 'avs_all' caption dict."""
        if type(caption) is str:
            return pre_text(caption)
        elif type(caption) is dict:
            assert self.caption_augmentation.caption_sample_type == 'avs_all'
            caption_dict = {}
            for k in caption.keys():
                caption_dict[k] = pre_text(caption[k])
            return caption_dict
        else:
            raise NotImplementedError(caption)

    def __getitem__(self, index):
        """Return ``(image, caption, index)``; on loading failure, retry a
        random other index.

        NOTE(review): retries are recursive with no cap; a long run of
        failures could hit the recursion limit.
        """
        # Bind ann before the try: if get_anno() itself raises, the except
        # handler below logs `ann` — previously this was an UnboundLocalError
        # that masked the real exception.
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])
            if self.crop_img:
                data_path = {"image": ann["image"], "crop_bbox": ann["crop_bbox"]}
                image, index = self.load_and_transform_media_data(index, data_path)
            else:
                image, index = self.load_and_transform_media_data(index, ann["image"])
            return image, caption, index
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading image {ann}")
            print(e)
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)
class VidTxtPtTrainDataset(ImgTxtPtTrainDataset):
    """Video-caption pretraining dataset.

    Annotation loading/filtering is inherited from ``ImgTxtPtTrainDataset``;
    this class adds frame sampling configuration and optional per-annotation
    clip boundaries (``read_clip_from_video``).
    """

    media_type = "video"

    def __init__(
        self,
        ann_file,
        transform,
        num_frames=4,
        video_reader_type="decord",
        sample_type="rand",
        num_tries=3,
        num_epochs=1
    ):
        super().__init__(ann_file, transform, num_epochs)
        self.num_frames = num_frames
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
        self.sample_type = sample_type
        self.num_tries = num_tries
        self.is_paragraph_retrieval = ann_file.get("is_paragraph_retrieval", False)
        # when True, annotations carry explicit start/end frame indices
        self.read_clip_from_video = ann_file.get("read_clip_from_video", False)

        if self.is_paragraph_retrieval:
            raise NotImplementedError

    def get_anno(self, index):
        """Build the annotation dict (caption + absolute video path, plus an
        optional clip frame range) for ``index``."""
        assert self.media_type == "video", self.media_type
        anno = {"caption": self.get_caption(index)}
        anno["video"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["video"])
        if self.read_clip_from_video:
            anno["video_start_frame"] = self.anno[index]["video_start_frame"]
            anno["video_end_frame"] = self.anno[index]["video_end_frame"]
        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])
        return anno

    def __getitem__(self, index):
        """Return ``(video, caption, index)``; on loading failure, retry a
        random other index (recursive, uncapped — see NOTE on the parent)."""
        # Bind ann before the try so the except handler can log it even when
        # get_anno() raises (previously an UnboundLocalError).
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])
            if self.read_clip_from_video:
                data_path = {
                    "video": ann["video"],
                    "video_start_frame": ann["video_start_frame"],
                    "video_end_frame": ann["video_end_frame"],
                    "read_clip_from_video": True
                }
            else:
                data_path = ann["video"]
            video, index = self.load_and_transform_media_data(index, data_path)
            return video, caption, index
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading video {ann}")
            print(e)
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)
class AudioVidTxtPtTrainDataset(VidTxtPtTrainDataset):
    """Audio+video caption pretraining dataset.

    Loads paired (audio, video) media; the audio stream may come from a
    separate file referenced by the annotation or be extracted from the video
    itself.  Loading failures are retried on random indices, capped at
    ``num_tries`` consecutive failures before the exception is re-raised.
    """

    media_type = "audio_video"

    def __init__(
        self,
        ann_file,
        transform,
        audio_sample_rate=16000,
        audio_reader_type='torchaudio',
        max_audio_length=10,
        num_frames=4,
        video_reader_type="decord",
        sample_type="rand",
        num_tries=3,
        num_epochs=1
    ):
        super().__init__(ann_file, transform, num_epochs=num_epochs, num_frames=num_frames, video_reader_type=video_reader_type, sample_type=sample_type, num_tries=num_tries)
        assert self.media_type == 'audio_video', self.media_type
        self.audio_sample_rate = audio_sample_rate
        self.audio_reader_type = audio_reader_type
        self.max_audio_length = max_audio_length
        self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
        self.read_audio_from_video = ann_file.get("read_audio_from_video", False)
        self.zero_audio_padding_for_video = ann_file.get("zero_audio_padding_for_video", False)
        # consecutive-failure counter; reset to 0 after every successful load
        self.now_tries = 0

    def get_anno(self, index):
        """Build the annotation dict for ``index``: caption, absolute video
        path, optional clip frame range, and optional separate audio path."""
        anno = {"caption": self.get_caption(index)}
        anno["video"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["video"])
        if self.read_clip_from_video:
            anno["video_start_frame"] = self.anno[index]["video_start_frame"]
            anno["video_end_frame"] = self.anno[index]["video_end_frame"]
        if "audio" in self.anno[index].keys():
            anno["audio"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["audio"])
        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])
        return anno

    def __getitem__(self, index):
        """Return ``(media, caption, index)`` where ``media`` holds the audio
        and video streams; retries up to ``num_tries`` consecutive failures."""
        # Bind ann before the try so the except handler can log it even when
        # get_anno() raises (previously an UnboundLocalError).
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])

            data_path = {'video': ann["video"]}
            if self.read_clip_from_video:
                data_path["video_start_frame"] = ann["video_start_frame"]
                data_path["video_end_frame"] = ann["video_end_frame"]
            if "audio" in ann.keys():
                # a dedicated audio file takes precedence over demuxing audio
                # from the video stream
                data_path["read_audio_from_video"] = False
                data_path["audio"] = ann["audio"]
            else:
                data_path["read_audio_from_video"] = self.read_audio_from_video
            data_path["read_clip_from_video"] = self.read_clip_from_video

            media, index = self.load_and_transform_media_data(index, data_path)
            self.now_tries = 0
            # assumes media is [audio, video] — TODO confirm loader contract
            audio = media[0]
            if audio is None and self.zero_audio_padding_for_video:
                logger.warning(f"No audio in {data_path}")
                media[0] = torch.zeros((998, 64), dtype=torch.float32)
            return media, caption, index
        except Exception as e:
            # allow up to num_tries consecutive failures before surfacing the
            # error; note the comparison permits num_tries + 1 total attempts
            if self.num_tries < self.now_tries:
                raise e
            else:
                self.now_tries += 1
            logger.warning(f"Caught exception {e} when loading audio-video {ann}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)
class AudioTxtPtTrainDataset(BaseDataset):
    """Audio-caption pretraining dataset backed by a json annotation file.

    Mirrors ``ImgTxtPtTrainDataset`` (rank-0 json load, caption-length
    filtering, optional prompt templates, shared-memory annotation list) but
    yields a single audio clip per sample.
    """

    media_type = "audio"

    def __init__(self, ann_file, transform,
                 audio_sample_rate=16000,
                 audio_reader_type='torchaudio',
                 max_audio_length=10,
                 num_tries=3,
                 num_epochs=1):
        """
        Args:
            ann_file: config-like object (attribute + ``get`` access) holding
                ``anno_path``, ``data_root``, ``media_type`` and optional
                filtering / prompt / caption-augmentation settings.
            transform: transform applied by the media loader.
            audio_sample_rate: target sample rate for the audio reader.
            audio_reader_type: backend used to decode audio.
            max_audio_length: maximum clip length accepted by the loader.
            num_tries: stored for loader use; retries here are uncapped.
            num_epochs: must be >= 1; smaller values are not supported.
        """
        super().__init__()
        logger.info(f"ann_file: {ann_file}")

        self.media_type = ann_file.media_type
        self.label_file = ann_file.anno_path
        self.data_root = ann_file.data_root
        self.data_root_prefix = ann_file.get("data_root_prefix", "")
        self.min_caption_length = ann_file.get("min_caption_length", 2)
        self.caption_augmentation = ann_file.get("caption_augmentation", None)
        self.transform = transform
        self.audio_sample_rate = audio_sample_rate
        self.max_audio_length = max_audio_length
        self.audio_reader_type = audio_reader_type
        self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
        assert not self.has_multi_audio_gt

        self.use_prompt = ann_file.get("prompt", "") != ""
        if self.use_prompt:
            if ann_file.prompt == "imagenet":
                self.prompt = imagenet_templates
                logger.info("Use prompt for ImageNet")
            elif ann_file.prompt == "kinetics":
                self.prompt = kinetics_templates
                logger.info("Use prompt for Kinetics")
            else:
                raise NotImplementedError(ann_file.prompt)
            logger.info(self.prompt)

        if self.use_prompt and self.caption_augmentation is not None:
            raise NotImplementedError("You can't use prompt because of multiple captions!")

        if '.json' in self.label_file:
            logger.info(f"Loading json file {self.label_file}")

            if get_local_rank() == 0:  # Only one rank need to read the file
                with io.BytesIO(self.client.get(self.label_file)) as f:
                    annos = json.load(f)

                if ann_file.get("jump_filter", False):
                    logger.info("Jump filter!")
                else:
                    if self.caption_augmentation is not None:
                        # filter out the caption with length less than min_caption_length
                        new_annos = []
                        if self.caption_augmentation.caption_sample_type == 'uniform':
                            # keep samples that retain at least one caption after filtering
                            for anno in annos:
                                if "captions" in anno.keys():
                                    caption_key = "captions"
                                else:
                                    caption_key = "caption"

                                assert type(anno[caption_key]) is list, type(anno[caption_key])
                                caption_list = []
                                for c in anno[caption_key]:
                                    tmp_c = pre_text(c)
                                    if len(tmp_c.split()) >= self.min_caption_length:
                                        caption_list.append(tmp_c)

                                if len(caption_list) > 0:
                                    new_annos.append(anno)
                        else:
                            raise NotImplementedError(ann_file)

                        logger.info(f"Num samples: {len(annos)}")
                        logger.info(f"Num samples not too short: {len(new_annos)} min_caption_length={self.min_caption_length}")
                        annos = new_annos
                    else:
                        # filter out the caption with length less than min_caption_length
                        captions = [pre_text(anno["caption"]) for anno in annos]
                        captions_len = [len(caption.split()) for caption in captions]
                        logger.info("Num samples: {}".format(len(captions)))
                        logger.info("Num samples too short: {}".format(sum([l < self.min_caption_length for l in captions_len])))
                        annos = [anno for anno, l in zip(annos, captions_len) if l >= self.min_caption_length]

                if num_epochs < 1:
                    raise NotImplementedError
            else:
                # non-zero ranks receive the annotations through shared memory below
                annos = []

            self.anno = TorchShmSerializedList(annos)
            self.num_examples = len(self.anno)
            logger.info(f"num_examples: {self.num_examples}")
        else:
            raise NotImplementedError("We need json file!!!")

    def __len__(self):
        return self.num_examples

    def get_caption(self, index):
        """Return the caption for ``index``; with 'uniform' caption
        augmentation a caption is sampled uniformly from the stored list."""
        if '.json' in self.label_file:
            if self.caption_augmentation is not None:
                if "captions" in self.anno[index].keys():
                    captions = self.anno[index]["captions"]
                else:
                    captions = self.anno[index]["caption"]
            else:
                caption = self.anno[index]["caption"]
        else:
            raise NotImplementedError

        if self.caption_augmentation is not None:
            if self.caption_augmentation.caption_sample_type == 'uniform':
                caption = random.choice(captions)
            else:
                raise NotImplementedError
        return caption

    def get_anno(self, index):
        """Build the annotation dict (caption + absolute audio path) for ``index``."""
        assert self.media_type == 'audio', self.media_type
        anno = {"caption": self.get_caption(index)}
        anno["audio"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["audio"])
        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])
        return anno

    def pre_caption(self, caption):
        """Normalize a string caption; other caption types are unsupported here."""
        if type(caption) is str:
            return pre_text(caption)
        else:
            raise NotImplementedError(caption)

    def __getitem__(self, index):
        """Return ``(audio, caption, index)``; on loading failure, retry a
        random other index (recursive, uncapped)."""
        # Bind ann before the try so the except handler can log it even when
        # get_anno() raises (previously an UnboundLocalError).
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])
            audio, index = self.load_and_transform_media_data(index, ann["audio"])
            return audio, caption, index
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading audio {ann}")
            print(e)
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)