from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader

from modules.dataset_tvrr import TrainDataset, QueryEvalDataset, CorpusEvalDataset
from utils.tensor_utils import pad_sequences_1d


def _pad_into(batch_data, batch, feat_key, mask_key, fixed_length):
    """Pad ``e[feat_key]`` across ``batch`` and store the result in ``batch_data``.

    The first value returned by ``pad_sequences_1d`` is stored under
    ``feat_key`` and the second under ``mask_key``.  With
    ``fixed_length=None`` the padded length is left to ``pad_sequences_1d``
    (presumably the longest sequence in the batch -- TODO confirm against
    utils.tensor_utils).
    """
    padded = pad_sequences_1d([e[feat_key] for e in batch],
                              dtype=torch.float32, fixed_length=fixed_length)
    batch_data[feat_key] = padded[0]
    batch_data[mask_key] = padded[1]


def collate_fn(batch, task, fixed_length=128):
    """Collate a list of per-sample dicts into a dict of batched tensors.

    Args:
        batch: list of sample dicts produced by the dataset used for ``task``.
        task: ``"train"``, ``"corpus"`` or ``"eval"``; selects which keys are
            read from each sample and which keys are emitted.
        fixed_length: length that ``video_feat`` / ``sub_feat`` are padded to,
            and the width of ``match_labels``.  Defaults to 128, the value
            previously hard-coded here.

    Returns:
        dict whose keys depend on ``task``:
          * ``"train"``: ``simi``, ``query_feat``/``query_mask``,
            ``video_feat``/``video_mask``, ``sub_feat``/``sub_mask``,
            ``st_ed_indices`` (samples stacked along dim 0) and
            ``match_labels`` (long tensor of shape ``(len(batch),
            fixed_length)``, set to 1 on the inclusive span ``[st, ed]``).
          * ``"corpus"``: ``video_feat``/``video_mask``, ``sub_feat``/``sub_mask``.
          * ``"eval"``: ``query_feat``/``query_mask``, ``query_id``.

    Raises:
        ValueError: if ``task`` is not one of the values above (previously an
            empty dict was returned silently).
    """
    batch_data = {}
    if task == "train":
        batch_data["simi"] = torch.tensor([e["simi"] for e in batch])
        # Queries use fixed_length=None; video/subtitle contexts are padded to
        # fixed_length, the same width used for match_labels below.
        _pad_into(batch_data, batch, "query_feat", "query_mask", None)
        _pad_into(batch_data, batch, "video_feat", "video_mask", fixed_length)
        _pad_into(batch_data, batch, "sub_feat", "sub_mask", fixed_length)

        st_ed_indices = torch.stack([e["st_ed_indices"] for e in batch], dim=0)
        batch_data["st_ed_indices"] = st_ed_indices

        match_labels = np.zeros((len(batch), fixed_length), dtype=np.int32)
        # One host copy for the whole batch instead of one per sample.
        for row, st_ed in enumerate(st_ed_indices.cpu().numpy()):
            st, ed = st_ed[0], st_ed[1]
            # ed is inclusive; numpy slicing clips anything past fixed_length.
            match_labels[row, st:ed + 1] = 1
        batch_data["match_labels"] = torch.tensor(match_labels, dtype=torch.long)
    elif task == "corpus":
        _pad_into(batch_data, batch, "video_feat", "video_mask", fixed_length)
        _pad_into(batch_data, batch, "sub_feat", "sub_mask", fixed_length)
    elif task == "eval":
        _pad_into(batch_data, batch, "query_feat", "query_mask", None)
        batch_data["query_id"] = torch.tensor([e["query_id"] for e in batch])
    else:
        raise ValueError(
            f"collate_fn: unknown task {task!r}; expected 'train', 'corpus' or 'eval'")
    return batch_data


def _make_loader(dataset, task, batch_size, num_workers, shuffle):
    """Wrap ``dataset`` in a DataLoader that collates with ``collate_fn`` for ``task``."""
    return DataLoader(
        dataset,
        # functools.partial of a module-level function can be pickled (a lambda
        # cannot), so worker processes started with "spawn" work as well.
        collate_fn=partial(collate_fn, task=task),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=True,
    )


def prepare_dataset(opt):
    """Build the train / corpus / val / test datasets and their DataLoaders.

    ``opt`` is read for: ``train_path``, ``corpus_path``, ``val_path``,
    ``test_path``, ``desc_bert_path``, ``sub_bert_path``, ``video_feat_path``,
    ``max_desc_l``, ``max_ctx_l``, ``clip_length``, ``ctx_mode``,
    ``no_norm_vfeat``, ``no_norm_tfeat``, ``bsz``, ``bsz_eval`` and
    ``num_workers``.

    Returns:
        tuple ``(train_loader, corpus_loader, corpus_video_list, val_loader,
        test_loader, val_gt, test_gt)``, where ``corpus_video_list`` is
        ``corpus_set.corpus_video_list`` and ``val_gt`` / ``test_gt`` are the
        ``ground_truth`` attributes of the val / test query datasets.
    """
    # Datasets are constructed in the same order as before, in case their
    # constructors have side effects (file loading, logging).
    train_set = TrainDataset(
        data_path=opt.train_path,
        desc_bert_path=opt.desc_bert_path,
        sub_bert_path=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        video_feat_path=opt.video_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat)
    train_loader = _make_loader(train_set, "train", opt.bsz, opt.num_workers, shuffle=True)

    # NOTE(review): unlike train_set, the corpus set is not given clip_length or
    # the normalize_* flags -- confirm CorpusEvalDataset's defaults match.
    corpus_set = CorpusEvalDataset(
        corpus_path=opt.corpus_path,
        max_ctx_len=opt.max_ctx_l,
        sub_bert_path=opt.sub_bert_path,
        video_feat_path=opt.video_feat_path,
        ctx_mode=opt.ctx_mode)
    # The corpus loader uses the training batch size (opt.bsz), as before.
    corpus_loader = _make_loader(corpus_set, "corpus", opt.bsz, opt.num_workers, shuffle=False)

    val_set = QueryEvalDataset(
        data_path=opt.val_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l)
    val_loader = _make_loader(val_set, "eval", opt.bsz_eval, opt.num_workers, shuffle=False)

    test_set = QueryEvalDataset(
        data_path=opt.test_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l)
    test_loader = _make_loader(test_set, "eval", opt.bsz_eval, opt.num_workers, shuffle=False)

    val_gt = val_set.ground_truth
    test_gt = test_set.ground_truth
    corpus_video_list = corpus_set.corpus_video_list
    return train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt