English
File size: 4,404 Bytes
5019d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dae63ab
 
 
5019d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from modules.dataset_tvrr import TrainDataset, QueryEvalDataset, CorpusEvalDataset
import torch
from torch.utils.data import DataLoader
from utils.tensor_utils import pad_sequences_1d
import numpy as np

def collate_fn(batch, task):
    """Collate a list of dataset examples into a batch of padded tensors.

    Args:
        batch: list of example dicts emitted by the dataset for this task.
        task: one of "train", "corpus", "eval"; selects which keys are batched.

    Returns:
        dict with padded feature tensors and their masks, plus task-specific
        extras ("simi"/"st_ed_indices"/"match_labels" for train, "query_id"
        for eval).
    """
    fixed_length = 128  # all context (video/sub) features are padded/clipped to this many clips
    batch_data = dict()

    def _pad(key, length):
        # Pad variable-length per-example features; returns (padded_feats, mask).
        return pad_sequences_1d([e[key] for e in batch],
                                dtype=torch.float32, fixed_length=length)

    if task == "train":
        batch_data["simi"] = torch.tensor([e["simi"] for e in batch])

        # Queries keep their natural max length; context is padded to fixed_length.
        batch_data["query_feat"], batch_data["query_mask"] = _pad("query_feat", None)
        batch_data["video_feat"], batch_data["video_mask"] = _pad("video_feat", fixed_length)
        batch_data["sub_feat"], batch_data["sub_mask"] = _pad("sub_feat", fixed_length)

        st_ed_indices = [e["st_ed_indices"] for e in batch]
        batch_data["st_ed_indices"] = torch.stack(st_ed_indices, dim=0)
        # Dense 0/1 labels over clip positions: 1 inside [st, ed] (inclusive).
        # Built directly as a long tensor instead of a numpy int32 round-trip.
        match_labels = torch.zeros(len(st_ed_indices), fixed_length, dtype=torch.long)
        for idx, st_ed_index in enumerate(st_ed_indices):
            st, ed = int(st_ed_index[0]), int(st_ed_index[1])
            match_labels[idx, st:ed + 1] = 1
        batch_data['match_labels'] = match_labels

    elif task == "corpus":
        batch_data["video_feat"], batch_data["video_mask"] = _pad("video_feat", fixed_length)
        batch_data["sub_feat"], batch_data["sub_mask"] = _pad("sub_feat", fixed_length)

    elif task == "eval":
        batch_data["query_feat"], batch_data["query_mask"] = _pad("query_feat", None)
        batch_data["query_id"] = torch.tensor([e["query_id"] for e in batch])

    return batch_data




def prepare_dataset(opt):
    """Instantiate the train/corpus/val/test datasets and wrap them in loaders.

    Args:
        opt: parsed options namespace carrying paths, feature settings and
            batch sizes (see the dataset constructors for the fields read).

    Returns:
        (train_loader, corpus_loader, corpus_video_list,
         val_loader, test_loader, val_gt, test_gt)
    """
    shared_kwargs = dict(num_workers=opt.num_workers, pin_memory=True)

    train_set = TrainDataset(
        data_path=opt.train_path,
        desc_bert_path=opt.desc_bert_path,
        sub_bert_path=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        video_feat_path=opt.video_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat)
    train_loader = DataLoader(
        train_set,
        collate_fn=lambda b: collate_fn(b, task='train'),
        batch_size=opt.bsz,
        shuffle=True,
        **shared_kwargs)

    corpus_set = CorpusEvalDataset(
        corpus_path=opt.corpus_path,
        max_ctx_len=opt.max_ctx_l,
        sub_bert_path=opt.sub_bert_path,
        video_feat_path=opt.video_feat_path,
        ctx_mode=opt.ctx_mode)
    corpus_loader = DataLoader(
        corpus_set,
        collate_fn=lambda b: collate_fn(b, task='corpus'),
        batch_size=opt.bsz,
        shuffle=False,
        **shared_kwargs)

    val_set = QueryEvalDataset(
        data_path=opt.val_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l)
    val_loader = DataLoader(
        val_set,
        collate_fn=lambda b: collate_fn(b, task='eval'),
        batch_size=opt.bsz_eval,
        shuffle=False,
        **shared_kwargs)

    test_set = QueryEvalDataset(
        data_path=opt.test_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l)
    test_loader = DataLoader(
        test_set,
        collate_fn=lambda b: collate_fn(b, task='eval'),
        batch_size=opt.bsz_eval,
        shuffle=False,
        **shared_kwargs)

    return (train_loader, corpus_loader, corpus_set.corpus_video_list,
            val_loader, test_loader, val_set.ground_truth, test_set.ground_truth)