import logging
import os
import json
import random
import io

import torch
import numpy as np

from dataset.base_dataset import BaseDataset
from dataset.text_prompt import kinetics_templates, imagenet_templates
from dataset.utils import pre_text
from dataset.video_utils import VIDEO_READER_FUNCS
from dataset.serialize import get_local_rank, TorchShmSerializedList

logger = logging.getLogger(__name__)


class ImgTxtPtTrainDataset(BaseDataset):
    """Image-text pre-training dataset driven by a JSON annotation file.

    Each annotation is expected to be a dict holding an ``"image"`` path
    (relative to ``data_root``) and a ``"caption"`` (or a list under
    ``"captions"``/``"caption"`` when caption augmentation is enabled).

    Args:
        ann_file: Config object that supports both attribute access
            (``media_type``, ``anno_path``, ``data_root``, ``prompt``) and
            ``.get(key, default)`` for the optional keys read below.
        transform: Stored on ``self.transform``; not called in this file
            (presumably consumed by the inherited media loader - confirm
            against ``BaseDataset``).
        num_epochs: Must be >= 1; smaller values raise
            ``NotImplementedError`` on the rank that loads the file.

    Raises:
        NotImplementedError: If the annotation path does not contain
            ``.json``, the prompt name is unknown, prompts are combined with
            caption augmentation, the caption sample type is unsupported, or
            ``num_epochs < 1``.
    """

    media_type = "image"

    def __init__(self, ann_file, transform, num_epochs=1):
        super().__init__()

        logger.info(f"ann_file: {ann_file}")

        self.media_type = ann_file.media_type
        self.label_file = ann_file.anno_path
        self.data_root = ann_file.data_root
        self.data_root_prefix = ann_file.get("data_root_prefix", "")
        self.min_caption_length = ann_file.get("min_caption_length", 2)
        self.caption_augmentation = ann_file.get("caption_augmentation", None)
        self.transform = transform
        # each caption has multiple image as ground_truth, e.g., ssv2
        self.has_multi_vision_gt = ann_file.get("has_multi_vision_gt", False)
        assert not self.has_multi_vision_gt

        self.crop_img = ann_file.get("crop_img", False)

        self.use_prompt = ann_file.get("prompt", "") != ""
        if self.use_prompt:
            if ann_file.prompt == "imagenet":
                self.prompt = imagenet_templates
                logger.info("Use prompt for ImageNet")
            elif ann_file.prompt == "kinetics":
                self.prompt = kinetics_templates
                logger.info("Use prompt for Kinetics")
            else:
                raise NotImplementedError(ann_file.prompt)
            logger.info(self.prompt)

        # A prompt template is formatted with a single caption string, which
        # does not fit the multi-caption annotations used by augmentation.
        if self.use_prompt and self.caption_augmentation is not None:
            raise NotImplementedError("You can't use prompt because of multiple captions!")

        if '.json' not in self.label_file:
            raise NotImplementedError("We need json file!!!")

        logger.info(f"Loading json file {self.label_file}")

        if get_local_rank() == 0:  # Only one rank need to read the file
            # `self.client` is not defined in this file; presumably provided
            # by BaseDataset - confirm.
            with io.BytesIO(self.client.get(self.label_file)) as f:
                annos = json.load(f)

            if ann_file.get("jump_filter", False):
                logger.info("Jump filter!")
            elif self.caption_augmentation is not None:
                # filter out the caption with length less than min_caption_length
                new_annos = []
                sample_type = self.caption_augmentation.caption_sample_type
                if self.media_type == "audio_video" and sample_type == 'avs_all':
                    for anno in annos:
                        ok = anno['video'].endswith('.mp4')
                        # Every non-ASR caption field must be long enough.
                        for k in anno:
                            if "caption" in k and 'asr' not in k:
                                tmp_c = pre_text(anno[k])
                                if len(tmp_c.split()) < self.min_caption_length:
                                    ok = False
                                    break
                        if ok:
                            new_annos.append(anno)
                elif sample_type == 'uniform':
                    for anno in annos:
                        caption_key = "captions" if "captions" in anno else "caption"
                        assert isinstance(anno[caption_key], list), type(anno[caption_key])
                        # Keep the sample if at least one caption survives.
                        caption_list = [
                            tmp_c
                            for tmp_c in (pre_text(c) for c in anno[caption_key])
                            if len(tmp_c.split()) >= self.min_caption_length
                        ]
                        if caption_list:
                            new_annos.append(anno)
                else:
                    raise NotImplementedError(ann_file)

                logger.info(f"Num samples: {len(annos)}")
                logger.info(f"Num samples not too short: {len(new_annos)} min_caption_length={self.min_caption_length}")
                annos = new_annos
            else:
                # filter out the caption with length less than min_caption_length
                captions = [pre_text(anno["caption"]) for anno in annos]
                captions_len = [len(caption.split()) for caption in captions]
                num_short = sum(n < self.min_caption_length for n in captions_len)
                logger.info(f"Num samples: {len(captions)}")
                logger.info(f"Num samples too short: {num_short}")
                annos = [
                    anno
                    for anno, n in zip(annos, captions_len)
                    if n >= self.min_caption_length
                ]

            if num_epochs < 1:
                raise NotImplementedError
        else:
            # Other local ranks pass an empty list; presumably
            # TorchShmSerializedList shares rank 0's data with them - confirm.
            annos = []

        self.anno = TorchShmSerializedList(annos)
        self.num_examples = len(self.anno)
        logger.info(f"num_examples: {self.num_examples}")

    def __len__(self):
        """Return the number of annotations kept after filtering."""
        return self.num_examples

    def get_caption(self, index):
        """Return the caption for sample ``index``.

        Without caption augmentation this is ``anno["caption"]``. With
        ``caption_sample_type == 'uniform'`` one entry is drawn at random
        from ``anno["captions"]`` (or ``anno["caption"]``). With
        ``'avs_all'`` a dict of every annotation field whose key contains
        ``'caption'`` is returned.

        Raises:
            NotImplementedError: For a non-JSON label file or an
                unsupported caption sample type.
        """
        if '.json' not in self.label_file:
            raise NotImplementedError

        anno = self.anno[index]  # fetch once instead of once per field
        if self.caption_augmentation is None:
            return anno["caption"]

        sample_type = self.caption_augmentation.caption_sample_type
        if sample_type == 'avs_all':
            return {k: v for k, v in anno.items() if 'caption' in k}
        if sample_type == 'uniform':
            captions = anno["captions"] if "captions" in anno else anno["caption"]
            return random.choice(captions)
        raise NotImplementedError

    def get_anno(self, index):
        """Build the per-sample dict with ``"caption"``, ``"image"`` and,
        when ``crop_img`` is set, ``"crop_bbox"``.

        The image path is ``data_root_prefix`` + ``join(data_root, image)``.
        When a prompt is configured, the caption is inserted into a randomly
        chosen template.
        """
        assert self.media_type == 'image', self.media_type
        raw = self.anno[index]
        anno = {"caption": self.get_caption(index)}
        anno["image"] = self.data_root_prefix + os.path.join(self.data_root, raw["image"])
        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])
        if self.crop_img:
            anno["crop_bbox"] = raw["crop_bbox"]
        return anno

    def pre_caption(self, caption):
        """Apply ``pre_text`` to a caption string, or to every value of a
        caption dict (only valid for the ``'avs_all'`` sample type).

        Raises:
            NotImplementedError: If ``caption`` is neither ``str`` nor ``dict``.
        """
        if isinstance(caption, str):
            return pre_text(caption)
        if isinstance(caption, dict):
            assert self.caption_augmentation.caption_sample_type == 'avs_all'
            return {k: pre_text(v) for k, v in caption.items()}
        raise NotImplementedError(caption)

    def __getitem__(self, index):
        """Return ``(image, caption, index)`` for sample ``index``.

        On any exception the failure is logged and a uniformly random
        replacement index is loaded instead. The retry is recursive and has
        no upper bound.
        """
        # Bound before the try so the handler can log even when get_anno()
        # itself is what failed (previously an UnboundLocalError).
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])
            if self.crop_img:
                data_path = {"image": ann["image"], "crop_bbox": ann["crop_bbox"]}
            else:
                data_path = ann["image"]
            image, index = self.load_and_transform_media_data(index, data_path)
            return image, caption, index
        except Exception as e:
            failed = ann if ann is not None else f"index {index}"
            logger.warning(f"Caught exception {e} when loading image {failed}")
            print(e)
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)


class VidTxtPtTrainDataset(ImgTxtPtTrainDataset):
    """Video-text variant of :class:`ImgTxtPtTrainDataset`.

    Annotation loading and caption handling are inherited; samples carry a
    ``"video"`` path and, when ``read_clip_from_video`` is set in
    ``ann_file``, ``"video_start_frame"``/``"video_end_frame"``.

    Args:
        ann_file: Same config object as the parent; additionally read for
            ``is_paragraph_retrieval`` and ``read_clip_from_video``.
        transform: Forwarded to the parent.
        num_frames, video_reader_type, sample_type, num_tries: Stored on the
            instance; not otherwise used in this class (presumably read by
            the inherited media loader - confirm).
        num_epochs: Forwarded to the parent.

    Raises:
        KeyError: If ``video_reader_type`` is not in ``VIDEO_READER_FUNCS``
            (assuming it is a plain mapping).
        NotImplementedError: If ``is_paragraph_retrieval`` is enabled.
    """

    media_type = "video"

    def __init__(
        self,
        ann_file,
        transform,
        num_frames=4,
        video_reader_type="decord",
        sample_type="rand",
        num_tries=3,
        num_epochs=1
    ):
        super().__init__(ann_file, transform, num_epochs)
        self.num_frames = num_frames
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
        self.sample_type = sample_type
        self.num_tries = num_tries

        self.is_paragraph_retrieval = ann_file.get("is_paragraph_retrieval", False)
        self.read_clip_from_video = ann_file.get("read_clip_from_video", False)
        if self.is_paragraph_retrieval:
            raise NotImplementedError

    def get_anno(self, index):
        """Build the per-sample dict with ``"caption"``, ``"video"`` and,
        when clips are read from a longer video, the start/end frame fields.
        """
        assert self.media_type == "video", self.media_type
        raw = self.anno[index]
        anno = {"caption": self.get_caption(index)}
        anno["video"] = self.data_root_prefix + os.path.join(self.data_root, raw["video"])
        if self.read_clip_from_video:
            anno["video_start_frame"] = raw["video_start_frame"]
            anno["video_end_frame"] = raw["video_end_frame"]
        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])
        return anno

    def __getitem__(self, index):
        """Return ``(video, caption, index)`` for sample ``index``.

        On any exception the failure is logged and a uniformly random
        replacement index is loaded instead. The retry is recursive and has
        no upper bound (``num_tries`` is not consulted here).
        """
        # Bound before the try so the handler can log even when get_anno()
        # itself is what failed (previously an UnboundLocalError).
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])
            if self.read_clip_from_video:
                data_path = {
                    "video": ann["video"],
                    "video_start_frame": ann["video_start_frame"],
                    "video_end_frame": ann["video_end_frame"],
                    "read_clip_from_video": True
                }
            else:
                data_path = ann["video"]
            video, index = self.load_and_transform_media_data(index, data_path)
            return video, caption, index
        except Exception as e:
            failed = ann if ann is not None else f"index {index}"
            logger.warning(f"Caught exception {e} when loading video {failed}")
            print(e)
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)


class AudioVidTxtPtTrainDataset(VidTxtPtTrainDataset):
    """Audio+video-text variant of :class:`VidTxtPtTrainDataset`.

    Audio comes either from a separate ``"audio"`` annotation field or,
    when that field is absent, from the video itself if
    ``read_audio_from_video`` is set in ``ann_file``.

    Args:
        ann_file: Same config object as the parent; ``media_type`` must be
            ``'audio_video'``. Additionally read for ``has_multi_audio_gt``,
            ``read_audio_from_video`` and ``zero_audio_padding_for_video``.
        transform: Forwarded to the parent.
        audio_sample_rate, audio_reader_type, max_audio_length: Stored on
            the instance; not otherwise used in this class.
        num_frames, video_reader_type, sample_type, num_tries, num_epochs:
            Forwarded to the parent. ``num_tries`` also limits consecutive
            load failures in ``__getitem__``.
    """

    media_type = "audio_video"

    def __init__(
        self,
        ann_file,
        transform,
        audio_sample_rate=16000,
        audio_reader_type='torchaudio',
        max_audio_length=10,
        num_frames=4,
        video_reader_type="decord",
        sample_type="rand",
        num_tries=3,
        num_epochs=1
    ):
        super().__init__(
            ann_file,
            transform,
            num_epochs=num_epochs,
            num_frames=num_frames,
            video_reader_type=video_reader_type,
            sample_type=sample_type,
            num_tries=num_tries,
        )

        assert self.media_type == 'audio_video', self.media_type
        self.audio_sample_rate = audio_sample_rate
        self.audio_reader_type = audio_reader_type
        self.max_audio_length = max_audio_length

        self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
        self.read_audio_from_video = ann_file.get("read_audio_from_video", False)
        self.zero_audio_padding_for_video = ann_file.get("zero_audio_padding_for_video", False)

        # Consecutive failures seen by __getitem__; reset after each success.
        self.now_tries = 0

    def get_anno(self, index):
        """Build the per-sample dict with ``"caption"``, ``"video"``, the
        optional clip frame range and, if the annotation has one, ``"audio"``.
        """
        raw = self.anno[index]
        anno = {"caption": self.get_caption(index)}
        anno["video"] = self.data_root_prefix + os.path.join(self.data_root, raw["video"])
        if self.read_clip_from_video:
            anno["video_start_frame"] = raw["video_start_frame"]
            anno["video_end_frame"] = raw["video_end_frame"]

        if "audio" in raw:
            anno["audio"] = self.data_root_prefix + os.path.join(self.data_root, raw["audio"])

        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])

        return anno

    def __getitem__(self, index):
        """Return ``(media, caption, index)`` for sample ``index``.

        ``media[0]`` is treated as the audio entry: when it is ``None`` and
        ``zero_audio_padding_for_video`` is set, it is replaced in place by
        a zero tensor.

        On failure a random replacement index is tried. Once ``now_tries``
        exceeds ``num_tries`` the current exception is re-raised instead.
        """
        # Bound before the try so the handler can log even when get_anno()
        # itself is what failed (previously an UnboundLocalError).
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])
            data_path = {'video': ann["video"]}

            if self.read_clip_from_video:
                data_path["video_start_frame"] = ann["video_start_frame"]
                data_path["video_end_frame"] = ann["video_end_frame"]

            if "audio" in ann:
                # A dedicated audio file takes precedence over the video track.
                data_path["read_audio_from_video"] = False
                data_path["audio"] = ann["audio"]
            else:
                data_path["read_audio_from_video"] = self.read_audio_from_video

            data_path["read_clip_from_video"] = self.read_clip_from_video

            media, index = self.load_and_transform_media_data(index, data_path)
            self.now_tries = 0

            audio = media[0]
            if audio is None and self.zero_audio_padding_for_video:
                logger.warning(f"No audio in {data_path}")
                # NOTE(review): (998, 64) presumably matches the audio
                # feature shape produced by the loader - confirm.
                media[0] = torch.zeros((998, 64), dtype=torch.float32)

            return media, caption, index

        except Exception as e:
            if self.num_tries < self.now_tries:
                raise
            self.now_tries += 1
            failed = ann if ann is not None else f"index {index}"
            logger.warning(f"Caught exception {e} when loading audio-video {failed}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)


class AudioTxtPtTrainDataset(BaseDataset):
    """Audio-text pre-training dataset driven by a JSON annotation file.

    Each annotation is expected to be a dict holding an ``"audio"`` path
    (relative to ``data_root``) and a ``"caption"`` (or a list under
    ``"captions"``/``"caption"`` when caption augmentation is enabled).

    Args:
        ann_file: Config object that supports both attribute access
            (``media_type``, ``anno_path``, ``data_root``, ``prompt``) and
            ``.get(key, default)`` for the optional keys read below.
        transform: Stored on ``self.transform``; not called in this file.
        audio_sample_rate, audio_reader_type, max_audio_length: Stored on
            the instance; not otherwise used in this class.
        num_tries: Accepted but not used by this class.
        num_epochs: Must be >= 1; smaller values raise
            ``NotImplementedError`` on the rank that loads the file.

    Raises:
        NotImplementedError: If the annotation path does not contain
            ``.json``, the prompt name is unknown, prompts are combined with
            caption augmentation, the caption sample type is not
            ``'uniform'``, or ``num_epochs < 1``.
    """

    media_type = "audio"

    def __init__(
        self,
        ann_file,
        transform,
        audio_sample_rate=16000,
        audio_reader_type='torchaudio',
        max_audio_length=10,
        num_tries=3,
        num_epochs=1
    ):
        super().__init__()

        logger.info(f"ann_file: {ann_file}")

        self.media_type = ann_file.media_type
        self.label_file = ann_file.anno_path
        self.data_root = ann_file.data_root
        self.data_root_prefix = ann_file.get("data_root_prefix", "")
        self.min_caption_length = ann_file.get("min_caption_length", 2)
        self.caption_augmentation = ann_file.get("caption_augmentation", None)
        self.transform = transform
        self.audio_sample_rate = audio_sample_rate
        self.max_audio_length = max_audio_length
        self.audio_reader_type = audio_reader_type
        self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
        assert not self.has_multi_audio_gt

        self.use_prompt = ann_file.get("prompt", "") != ""
        if self.use_prompt:
            if ann_file.prompt == "imagenet":
                self.prompt = imagenet_templates
                logger.info("Use prompt for ImageNet")
            elif ann_file.prompt == "kinetics":
                self.prompt = kinetics_templates
                logger.info("Use prompt for Kinetics")
            else:
                raise NotImplementedError(ann_file.prompt)
            logger.info(self.prompt)

        # A prompt template is formatted with a single caption string, which
        # does not fit the multi-caption annotations used by augmentation.
        if self.use_prompt and self.caption_augmentation is not None:
            raise NotImplementedError("You can't use prompt because of multiple captions!")

        if '.json' not in self.label_file:
            raise NotImplementedError("We need json file!!!")

        logger.info(f"Loading json file {self.label_file}")

        if get_local_rank() == 0:  # Only one rank need to read the file
            # `self.client` is not defined in this file; presumably provided
            # by BaseDataset - confirm.
            with io.BytesIO(self.client.get(self.label_file)) as f:
                annos = json.load(f)

            if ann_file.get("jump_filter", False):
                logger.info("Jump filter!")
            elif self.caption_augmentation is not None:
                # filter out the caption with length less than min_caption_length
                new_annos = []
                if self.caption_augmentation.caption_sample_type == 'uniform':
                    for anno in annos:
                        caption_key = "captions" if "captions" in anno else "caption"
                        assert isinstance(anno[caption_key], list), type(anno[caption_key])
                        # Keep the sample if at least one caption survives.
                        caption_list = [
                            tmp_c
                            for tmp_c in (pre_text(c) for c in anno[caption_key])
                            if len(tmp_c.split()) >= self.min_caption_length
                        ]
                        if caption_list:
                            new_annos.append(anno)
                else:
                    raise NotImplementedError(ann_file)

                logger.info(f"Num samples: {len(annos)}")
                logger.info(f"Num samples not too short: {len(new_annos)} min_caption_length={self.min_caption_length}")
                annos = new_annos
            else:
                # filter out the caption with length less than min_caption_length
                captions = [pre_text(anno["caption"]) for anno in annos]
                captions_len = [len(caption.split()) for caption in captions]
                num_short = sum(n < self.min_caption_length for n in captions_len)
                logger.info(f"Num samples: {len(captions)}")
                logger.info(f"Num samples too short: {num_short}")
                annos = [
                    anno
                    for anno, n in zip(annos, captions_len)
                    if n >= self.min_caption_length
                ]

            if num_epochs < 1:
                raise NotImplementedError
        else:
            # Other local ranks pass an empty list; presumably
            # TorchShmSerializedList shares rank 0's data with them - confirm.
            annos = []

        self.anno = TorchShmSerializedList(annos)
        self.num_examples = len(self.anno)
        logger.info(f"num_examples: {self.num_examples}")

    def __len__(self):
        """Return the number of annotations kept after filtering."""
        return self.num_examples

    def get_caption(self, index):
        """Return the caption for sample ``index``.

        Without caption augmentation this is ``anno["caption"]``. With
        ``caption_sample_type == 'uniform'`` one entry is drawn at random
        from ``anno["captions"]`` (or ``anno["caption"]``).

        Raises:
            NotImplementedError: For a non-JSON label file or a caption
                sample type other than ``'uniform'``.
        """
        if '.json' not in self.label_file:
            raise NotImplementedError

        anno = self.anno[index]  # fetch once instead of once per field
        if self.caption_augmentation is None:
            return anno["caption"]

        if self.caption_augmentation.caption_sample_type == 'uniform':
            captions = anno["captions"] if "captions" in anno else anno["caption"]
            return random.choice(captions)
        raise NotImplementedError

    def get_anno(self, index):
        """Build the per-sample dict with ``"caption"`` and ``"audio"``.

        The audio path is ``data_root_prefix`` + ``join(data_root, audio)``.
        When a prompt is configured, the caption is inserted into a randomly
        chosen template.
        """
        assert self.media_type == 'audio', self.media_type
        anno = {"caption": self.get_caption(index)}
        anno["audio"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["audio"])
        if self.use_prompt:
            anno["caption"] = random.choice(self.prompt).format(anno["caption"])
        return anno

    def pre_caption(self, caption):
        """Apply ``pre_text`` to a caption string.

        Raises:
            NotImplementedError: If ``caption`` is not a ``str``.
        """
        if isinstance(caption, str):
            return pre_text(caption)
        raise NotImplementedError(caption)

    def __getitem__(self, index):
        """Return ``(audio, caption, index)`` for sample ``index``.

        On any exception the failure is logged and a uniformly random
        replacement index is loaded instead. The retry is recursive and has
        no upper bound.
        """
        # Bound before the try so the handler can log even when get_anno()
        # itself is what failed (previously an UnboundLocalError).
        ann = None
        try:
            ann = self.get_anno(index)
            caption = self.pre_caption(ann["caption"])
            audio, index = self.load_and_transform_media_data(index, ann["audio"])
            return audio, caption, index
        except Exception as e:
            failed = ann if ann is not None else f"index {index}"
            logger.warning(f"Caught exception {e} when loading audio {failed}")
            print(e)
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)