# English
# TVR-Ranking / modules /dataset_init.py
# Liangrj5
# init
# 5019d3f
# raw
# history blame
# 4.32 kB
from modules.dataset_tvrr import TrainDataset, QueryEvalDataset, CorpusEvalDataset
import torch
from torch.utils.data import DataLoader
from utils.tensor_utils import pad_sequences_1d
import numpy as np
def _add_padded(batch_data, batch, feat_key, mask_key, fixed_length):
    """Pad ``e[feat_key]`` over the batch and store the feature/mask pair in ``batch_data``."""
    # pad_sequences_1d returns (padded, mask); fixed_length=None leaves the
    # target length to pad_sequences_1d (presumably the batch maximum -- confirm).
    padded_and_mask = pad_sequences_1d(
        [e[feat_key] for e in batch], dtype=torch.float32, fixed_length=fixed_length
    )
    batch_data[feat_key] = padded_and_mask[0]
    batch_data[mask_key] = padded_and_mask[1]


def _build_match_labels(st_ed_indices, fixed_length):
    """Return a (len(st_ed_indices), fixed_length) long tensor with 1s on [st, ed] inclusive."""
    match_labels = np.zeros((len(st_ed_indices), fixed_length), dtype=np.int64)
    for row, st_ed_index in enumerate(st_ed_indices):
        st_ed = st_ed_index.cpu().numpy()
        st, ed = st_ed[0], st_ed[1]
        # ed is inclusive; slicing silently clips spans that run past fixed_length.
        match_labels[row, st:ed + 1] = 1
    return torch.from_numpy(match_labels)


def collate_fn(batch, task, fixed_length=128):
    """Collate a list of per-sample dicts into a dict of batched tensors.

    Args:
        batch: list of sample dicts; which keys are read depends on ``task``.
        task: ``"train"`` reads ``simi``, ``query_feat``, ``video_feat``,
            ``sub_feat`` and ``st_ed_indices``; ``"corpus"`` reads
            ``video_feat`` and ``sub_feat``; ``"eval"`` reads ``query_feat``
            and ``query_id``.
        fixed_length: length that video/subtitle features are padded to, and
            the width of ``match_labels``. The default 128 is the value that
            was previously hard-coded.

    Returns:
        dict of tensors. Every padded feature ``X_feat`` has a matching
        ``X_mask`` from ``pad_sequences_1d``. Query features are padded with
        ``fixed_length=None``. For ``"train"``, ``st_ed_indices`` is the
        stacked per-sample index tensors, and ``match_labels`` is a
        ``(batch, fixed_length)`` long tensor that is 1 on positions
        ``st..ed`` (inclusive) and 0 elsewhere.

    Raises:
        ValueError: if ``task`` is not ``"train"``, ``"corpus"`` or ``"eval"``.
    """
    batch_data = dict()
    if task == "train":
        batch_data["simi"] = torch.tensor([e["simi"] for e in batch])
        _add_padded(batch_data, batch, "query_feat", "query_mask", None)
        _add_padded(batch_data, batch, "video_feat", "video_mask", fixed_length)
        _add_padded(batch_data, batch, "sub_feat", "sub_mask", fixed_length)
        st_ed_indices = [e["st_ed_indices"] for e in batch]
        batch_data["st_ed_indices"] = torch.stack(st_ed_indices, dim=0)
        batch_data["match_labels"] = _build_match_labels(st_ed_indices, fixed_length)
    elif task == "corpus":
        _add_padded(batch_data, batch, "video_feat", "video_mask", fixed_length)
        _add_padded(batch_data, batch, "sub_feat", "sub_mask", fixed_length)
    elif task == "eval":
        _add_padded(batch_data, batch, "query_feat", "query_mask", None)
        batch_data["query_id"] = torch.tensor([e["query_id"] for e in batch])
    else:
        # Previously fell through and returned {}, deferring the failure to a
        # confusing KeyError in the training/eval loop.
        raise ValueError(
            f"Unknown collate task {task!r}; expected 'train', 'corpus' or 'eval'"
        )
    return batch_data
def prepare_dataset(opt):
    """Build the train, corpus, validation and test DataLoaders.

    Args:
        opt: options object. The attributes read are ``train_path``,
            ``val_path``, ``test_path``, ``corpus_path``, ``desc_bert_path``,
            ``sub_bert_path``, ``video_feat_path``, ``max_desc_l``,
            ``max_ctx_l``, ``clip_length``, ``ctx_mode``, ``no_norm_vfeat``,
            ``no_norm_tfeat``, ``bsz``, ``bsz_eval`` and ``num_workers``.

    Returns:
        Tuple ``(train_loader, corpus_loader, corpus_video_list, val_loader,
        test_loader, val_gt, test_gt)``. The ``*_gt`` values are the datasets'
        ``ground_truth`` attributes, and ``corpus_video_list`` is the corpus
        dataset's ``corpus_video_list`` (their structure is defined in
        ``modules.dataset_tvrr``).
    """
    # A lambda collate_fn cannot be pickled, which breaks DataLoader workers
    # (num_workers > 0) under the "spawn" start method (Windows/macOS default).
    # functools.partial of a module-level function pickles fine.
    from functools import partial

    def _loader(dataset, task, batch_size, shuffle):
        # All loaders share worker count and pinned memory; only these vary.
        return DataLoader(
            dataset,
            collate_fn=partial(collate_fn, task=task),
            batch_size=batch_size,
            num_workers=opt.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    train_set = TrainDataset(
        data_path=opt.train_path,
        desc_bert_path=opt.desc_bert_path,
        sub_bert_path=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        video_feat_path=opt.video_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )
    train_loader = _loader(train_set, "train", opt.bsz, shuffle=True)

    corpus_set = CorpusEvalDataset(
        corpus_path=opt.corpus_path,
        max_ctx_len=opt.max_ctx_l,
        sub_bert_path=opt.sub_bert_path,
        video_feat_path=opt.video_feat_path,
        ctx_mode=opt.ctx_mode,
    )
    corpus_loader = _loader(corpus_set, "corpus", opt.bsz, shuffle=False)

    val_set = QueryEvalDataset(
        data_path=opt.val_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l,
    )
    val_loader = _loader(val_set, "eval", opt.bsz_eval, shuffle=False)

    test_set = QueryEvalDataset(
        data_path=opt.test_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l,
    )
    test_loader = _loader(test_set, "eval", opt.bsz_eval, shuffle=False)

    val_gt = val_set.ground_truth
    test_gt = test_set.ground_truth
    corpus_video_list = corpus_set.corpus_video_list
    return train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt