import os
import pdb
import h5py
import nncore
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import random
import logging
from os.path import join, exists

from nncore.dataset import DATASETS
from nncore.parallel import DataContainer

from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
from utils.tensor_utils import pad_sequences_1d
from utils.span_utils import span_xx_to_cxw

logger = logging.getLogger(__name__)
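

# Dataset and batch-preparation helpers for query-focused video summarization (QFVS).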
class DatasetQFVS(Dataset):
    def __init__(self, config, use_tef=True):
        # pdb.set_trace()
        self.config = config
        self.dataset = []
        self.use_tef = use_tef
        self.embedding = load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")

        # Map raw annotation tags to the concept names used in the text-embedding pickle.
        self.transfer = {"Cupglass": "Glass",
                         "Musicalinstrument": "Instrument",
                         "Petsanimal": "Animal"}

        self.f_dict = {}
        feat_type = self.config['vid_feature']

        for video_id in self.config["train_videos"]:
            self.f_dict[str(video_id)] = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')

            # Oracle summaries are stored as "<concept1>_<concept2>_oracle.txt" per video.
            for _, _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id)):
                for file in files:
                    self.dataset.append(['Oracle', file[:file.find("_oracle.txt")] + "_" + str(video_id)])

            if self.config['qfvs_dense_shot'] > 0:
                dense_concept = {}
                feat_type = self.config['vid_feature']
                feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
                features = feat['features'][()]
                seg_len = feat['seg_len'][()]
                with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0" + str(video_id) + "/P0" + str(video_id) + ".txt", "r") as f:
                    lines = f.readlines()
                    for index, line in enumerate(lines):
                        concepts = line.strip().split(',')
                        for concept in concepts:
                            if concept in self.transfer:
                                concept = self.transfer[concept]
                            if concept not in dense_concept:
                                # dense_concept[concept] = torch.zeros(seg_len.sum())
                                dense_concept[concept] = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
                            # Mark this shot as containing the concept.
                            dense_concept[concept][index] = 1

                for key, value in dense_concept.items():
                    if value.sum().item() > 0:
                        self.dataset.append([video_id, key, value])
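
    # Entries in self.dataset are either ['Oracle', '<concept1>_<concept2>_<video_id>']
    # or [video_id, concept, dense per-shot GT vector] when qfvs_dense_shot is enabled.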

    def __getitem__(self, index):
        if self.dataset[index][0] == 'Oracle':
            return self.get_oracle(index)
        else:
            return self.get_dense(index)

    def get_dense(self, index):
        video_id = str(self.dataset[index][0])
        f = self.f_dict[video_id]
        # feat_type = self.config['vid_feature']
        # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
        features = f['features'][()]
        seg_len = f['seg_len'][()]
        dim = features.shape[-1]

        mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
        for j in range(len(seg_len)):
            for k in range(seg_len[j]):
                mask_GT[j][k] = 1

        features = torch.from_numpy(features)

        concept1 = concept2 = self.dataset[index][1]
        concept1_GT = concept2_GT = oracle_summary = self.dataset[index][2]

        if concept1 in self.transfer:
            concept1 = self.transfer[concept1]
        if concept2 in self.transfer:
            concept2 = self.transfer[concept2]
        concept1 = self.embedding[concept1]
        concept2 = self.embedding[concept2]

        concept1 = l2_normalize_np_array(concept1)
        concept2 = l2_normalize_np_array(concept2)

        # Sample one positive (GT > 0) index for each query's saliency loss;
        # fall back to an empty tensor when there is no positive shot.
        try:
            saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_1 = torch.Tensor(0)

        try:
            saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_2 = torch.Tensor(0)

        try:
            saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_oracle = torch.Tensor(0)

        return {
            'features': features,
            'seg_len': torch.from_numpy(seg_len),
            'concept1_GT': concept1_GT,
            'concept2_GT': concept2_GT,
            'mask_GT': mask_GT,
            'oracle_summary': oracle_summary,
            'tokens_pad1': torch.from_numpy(concept1),
            'tokens_pad2': torch.from_numpy(concept2),
            'saliency_pos_labels_1': saliency_pos_labels_1,
            'saliency_pos_labels_2': saliency_pos_labels_2,
            'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
        }

    def get_oracle(self, index):
        # Entry format: ['Oracle', '<concept1>_<concept2>_<video_id>'].
        video_id = self.dataset[index][1].split('_')[2]
        f = self.f_dict[video_id]
        # feat_type = self.config['vid_feature']
        # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
        features = f['features'][()]
        seg_len = f['seg_len'][()]
        dim = features.shape[-1]

        mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
        for j in range(len(seg_len)):
            for k in range(seg_len[j]):
                mask_GT[j][k] = 1

        features = torch.from_numpy(features)

        concept1, concept2 = self.dataset[index][1].split('_')[0:2]

        # Per-shot GT vectors, padded to max_segment_num * max_frame_num.
        concept1_GT = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
        concept2_GT = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
        # concept1_GT = torch.zeros(seg_len.sum())
        # concept2_GT = torch.zeros(seg_len.sum())
with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+video_id+"/P0"+video_id+".txt","r") as f: | |
lines=f.readlines() | |
for index,line in enumerate(lines): | |
concepts=line.strip().split(',') | |
if concept1 in concepts: | |
concept1_GT[index]=1 | |
if concept2 in concepts: | |
concept2_GT[index]=1 | |

        # oracle_summary = torch.zeros(seg_len.sum())
        oracle_summary = torch.zeros(self.config["max_segment_num"] * self.config["max_frame_num"])
        GT_summary_shots = []
        with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id) + "/" + str(concept1) + "_" + str(concept2) + "_" + "oracle.txt", "r") as f:
            for line in f.readlines():
                GT_summary_shots.append(int(line.strip()))
        # Oracle files list 1-based shot ids; shift to 0-based indices.
        GT_summary_shots = [x - 1 for x in GT_summary_shots]
        for element in GT_summary_shots:
            oracle_summary[element] = 1

        if concept1 in self.transfer:
            concept1 = self.transfer[concept1]
        if concept2 in self.transfer:
            concept2 = self.transfer[concept2]
        concept1 = self.embedding[concept1]
        concept2 = self.embedding[concept2]

        concept1 = l2_normalize_np_array(concept1)
        concept2 = l2_normalize_np_array(concept2)

        # Same positive-index sampling as in get_dense.
        try:
            saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_1 = torch.Tensor(0)

        try:
            saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_2 = torch.Tensor(0)

        try:
            saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_oracle = torch.Tensor(0)

        return {
            'features': features,
            'seg_len': torch.from_numpy(seg_len),
            'concept1_GT': concept1_GT,
            'concept2_GT': concept2_GT,
            'mask_GT': mask_GT,
            'oracle_summary': oracle_summary,
            'tokens_pad1': torch.from_numpy(concept1),
            'tokens_pad2': torch.from_numpy(concept2),
            'saliency_pos_labels_1': saliency_pos_labels_1,
            'saliency_pos_labels_2': saliency_pos_labels_2,
            'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
        }

    def __len__(self):
        return len(self.dataset)
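

# Collate a list of __getitem__ dicts into a batch: each field is padded with
# pad_sequences_1d, which returns a (padded_tensor, mask) pair per key.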
def start_end_collate_qfvs(batch):
    model_inputs_keys = batch[0].keys()

    batched_data = dict()
    for k in model_inputs_keys:
        batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
    return batched_data
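

# Move a collated batch to the GPU, append temporal-endpoint features (TEF) to the
# video stream, replicate the text queries across segments, and assemble the model
# inputs / targets for the concept-1, concept-2, and oracle queries.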
def prepare_batch_inputs_qfvs(data, config, eval=False):
    if not eval:
        features, mask, seg_len, \
        concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
        src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2, \
        saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
            data['features'][0], data['mask_GT'][0], data['seg_len'][0], \
            data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
            data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0]
    else:
        features, mask, seg_len, \
        src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
            data['features'][0], data['mask_GT'][0], data['seg_len'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]

    # preprocess for vid input.
    mask_GT = mask.to('cuda').reshape(1, -1).bool()
    seq = features.to('cuda').squeeze(0)
    mask = mask.to('cuda').squeeze(0)
    num_seg = seq.shape[0]

    # Append temporal-endpoint features (normalized start/end position of each clip)
    # to every frame feature.
    ctx_l = seq.shape[1]
    tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
    tef_ed = tef_st + 1.0 / ctx_l
    tef = torch.stack([tef_st, tef_ed], dim=1).to('cuda')  # (Lv, 2)
    tef = tef.squeeze(0).repeat(seq.shape[0], 1, 1)
    seq = torch.cat([seq, tef], dim=-1)

    # for txt input.
    src_txt_1 = src_txt_1.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
    src_txt_2 = src_txt_2.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
    src_txt_mask_1 = src_txt_mask_1.to('cuda').repeat(num_seg, 1)
    src_txt_mask_2 = src_txt_mask_2.to('cuda').repeat(num_seg, 1)

    src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
    src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')

    model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
    model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
    model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)

    # concept1_GT = concept1_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
    # concept2_GT = concept2_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
    # oracle_summary_GT = oracle_summary_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])

    if not eval:
        targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
        targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
        targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))

        targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda')
        targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda')
        targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')

        return model_inputs_1, model_inputs_2, model_inputs_oracle, \
               targets_1, targets_2, targets_oracle, mask_GT
    else:
        return model_inputs_1, model_inputs_2, model_inputs_oracle, mask_GT
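

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original pipeline): wire the dataset,
    # collate function, and batch preparation together. All config values below are
    # placeholder assumptions; real values come from the project's QFVS config files.
    from torch.utils.data import DataLoader

    demo_config = {
        'txt_feature': 'clip',       # assumed text-feature name
        'vid_feature': 'clip',       # assumed video-feature name
        'train_videos': [1, 2, 3],   # assumed training video ids (P01-P03)
        'qfvs_dense_shot': -1,       # <= 0 disables the dense-shot samples
        'max_segment_num': 20,       # assumed padding sizes
        'max_frame_num': 200,
    }

    dataset = DatasetQFVS(demo_config)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=start_end_collate_qfvs)
    for batch in loader:
        inputs_1, inputs_2, inputs_oracle, targets_1, targets_2, targets_oracle, mask_GT = \
            prepare_batch_inputs_qfvs(batch, demo_config)
        break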