mazpie's picture
Initial commit
2d9a728
raw
history blame
2.39 kB
import json
from dataset.base_dataset import BaseDataset
from dataset.utils import pre_text, load_anno
from dataset.video_utils import VIDEO_READER_FUNCS
import logging
logger = logging.getLogger(__name__)
class ImageQADataset(BaseDataset):
    """Question-answering dataset driven by an annotation file.

    Every annotation is read as a mapping with a ``"media"`` entry and a
    ``"question"`` entry. Training samples additionally read ``"answer"``;
    evaluation samples read ``"question_id"``.
    """

    # NOTE(review): presumably consulted by BaseDataset to choose the media
    # loader -- confirm against the base class.
    media_type = "image"

    def __init__(self, ann_file, transform, eos="[SEP]", mode="train", answer_list=None):
        """Load the annotations and, in eval mode, the candidate answers.

        Args:
            ann_file: Annotation source handed to ``load_anno``.
            transform: Stored on the instance; not applied directly in this
                class (presumably used by the base-class media loader).
            eos: End-of-sequence marker appended to every training answer.
            mode: Either ``"train"`` or ``"eval"``.
            answer_list: Path to a JSON file with the answer candidates.
                Only opened when ``mode == "eval"``.
        """
        super(ImageQADataset, self).__init__()
        assert mode in ["train", "eval"]
        self.mode = mode
        self.transform = transform
        self.eos = eos
        self.anno_list = load_anno(ann_file)

        if mode == "eval":
            # Use a context manager so the file handle is closed
            # deterministically instead of being left to the garbage collector.
            with open(answer_list, "r") as answer_file:
                self.answer_list = json.load(answer_file)

    def __len__(self):
        """Return the number of loaded annotations."""
        return len(self.anno_list)

    def get_answers_with_weights(self, raw_answers):
        """Collapse duplicate answers into unique answers with weights.

        Args:
            raw_answers: A single answer string or a sequence of answer
                strings (duplicates allowed).

        Returns:
            A tuple ``(answers, weights)``. ``answers`` holds each distinct
            answer in first-seen order with ``" " + self.eos`` appended;
            ``weights`` holds, at the same positions, the fraction of
            ``raw_answers`` equal to that answer. Both are empty when
            ``raw_answers`` is empty.
        """
        if isinstance(raw_answers, str):
            raw_answers = [raw_answers]

        num_answers = len(raw_answers)
        if num_answers == 0:
            # Nothing to weight; also avoids dividing by zero below.
            return [], []

        # Every occurrence contributes the same share of the total weight.
        share = 1 / num_answers
        answer_weight = {}
        for answer in raw_answers:
            answer_weight[answer] = answer_weight.get(answer, 0) + share

        # Dicts preserve insertion order, so answers and weights stay aligned.
        answers = [answer + " " + self.eos for answer in answer_weight]
        weights = list(answer_weight.values())
        return answers, weights

    def __getitem__(self, index):
        """Return one sample.

        Returns:
            ``(image, question, answers, weights)`` in train mode, otherwise
            ``(image, question, question_id)``.
        """
        ann = self.anno_list[index]
        # The loader hands back an index alongside the media; it is not used
        # afterwards in this method.
        # NOTE(review): if the base-class loader can substitute a different
        # sample (e.g. after a failed read), `ann` would no longer match the
        # returned media -- confirm against BaseDataset.
        image, index = self.load_and_transform_media_data(index, ann["media"])
        question = pre_text(ann["question"])

        if self.mode == "train":
            answers, weights = self.get_answers_with_weights(ann["answer"])
            return image, question, answers, weights

        # Any non-train mode is treated as evaluation.
        question_id = ann["question_id"]
        return image, question, question_id
class VideoQADataset(ImageQADataset):
    """Video flavour of :class:`ImageQADataset`.

    Adds the video-reading settings on top of the parent's annotation
    handling; sample construction itself is inherited unchanged.
    """

    media_type = "video"

    def __init__(
        self,
        ann_file,
        transform,
        eos="[SEP]",
        mode="train",
        answer_list=None,
        num_frames=4,
        video_reader_type="decord",
        sample_type="rand",
        num_tries=1
    ):
        """Initialise the parent dataset and record the video settings.

        Args:
            ann_file, transform, eos, mode, answer_list: Forwarded to
                ``ImageQADataset.__init__``.
            num_frames: Stored as ``self.num_frames``.
            video_reader_type: Key used to look up the reader callable in
                ``VIDEO_READER_FUNCS``; also stored on the instance.
            sample_type: Stored as ``self.sample_type``.
            num_tries: Stored as ``self.num_tries``.
        """
        super(VideoQADataset, self).__init__(
            ann_file=ann_file,
            transform=transform,
            eos=eos,
            mode=mode,
            answer_list=answer_list,
        )

        # Settings that are only recorded here.
        # NOTE(review): presumably read by the base-class video loader -- confirm.
        self.num_frames = num_frames
        self.sample_type = sample_type
        self.num_tries = num_tries

        # Resolve the reader callable from its registered name.
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]