import os
import pdb
import h5py
import nncore
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import random
import logging
from os.path import join, exists
from nncore.dataset import DATASETS
from nncore.parallel import DataContainer
from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
from utils.tensor_utils import pad_sequences_1d
from utils.span_utils import span_xx_to_cxw
from random import shuffle

logger = logging.getLogger(__name__)


class DatasetVLP(Dataset):
    Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
    """One line in data loaded from data_path:
    {
      "qid": 7803,
      "query": "Man in gray top walks from outside to inside.",
      "duration": 150,
      "vid": "RoripwjYFp8_360.0_510.0",
      "relevant_clip_ids": [13, 14, 15, 16, 17],
      "relevant_windows": [[26, 36]]
    }
    """

    def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
                 q_feat_type="last_hidden_state",
                 max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
                 normalize_v=True, normalize_t=True, load_labels=True,
                 clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
                 use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
        self.dset_name = dset_name
        self.data_path = data_path
        self.data_ratio = data_ratio
        self.v_feat_dirs = v_feat_dirs \
            if isinstance(v_feat_dirs, list) else [v_feat_dirs]
        self.q_feat_dir = q_feat_dir
        self.q_feat_type = q_feat_type
        self.v_feat_dim = v_feat_dim
        self.q_feat_dim = q_feat_dim
        self.max_q_l = max_q_l
        self.max_v_l = max_v_l
        self.ctx_mode = ctx_mode
        self.use_tef = "tef" in ctx_mode
        self.use_video = "video" in ctx_mode
        self.normalize_t = normalize_t
        self.normalize_v = normalize_v
        self.load_labels = load_labels
        self.clip_len = clip_len
        self.fix_len = fix_len
        self.max_windows = max_windows  # maximum number of windows to use as labels
        self.span_loss_type = span_loss_type
        self.txt_drop_ratio = txt_drop_ratio
        self.use_cache = use_cache
        self.add_easy_negative = add_easy_negative
        self.easy_negative_only = easy_negative_only

        self.vlp_mapping = {
            # 'data/qvhighlights/metadata/qvhighlights_asr.jsonl': {
            #     'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '_asr', 'type': 'interval',
            # },
            # 'data/ego4d/metadata/point_train_1m.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            # },
            # 'data/ego4d/metadata/point_train_1m_0.1p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            # },
            # 'data/ego4d/metadata/point_train_1m_0.2p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            # },
            # 'data/ego4d/metadata/point_train_1m_0.5p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            # },
            # 'data/ego4d/metadata/point_train_1m_0.75p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            # },
            # 'data/ego4d/metadata/point_train_2m.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            # },
            # 'data/ego4d/metadata/point_train_1m_egoclip.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            # },
            # 'data/hacs/metadata/hacs_train_cs.jsonl': {
            #     'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '_cs', 'type': 'curve',
            # },
            # 'data/hacs/metadata/hacs_train.jsonl': {
            #     'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
            # },
            # 'data/videocc/metadata/train_300k.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            # },
            # 'data/videocc/metadata/train_600k.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            # },
            # 'data/videocc/metadata/train_600k_0.1p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            # },
            # 'data/videocc/metadata/train_600k_0.2p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            # },
            # 'data/videocc/metadata/train_600k_0.5p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            # },
            # 'data/videocc/metadata/train_600k_0.75p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            # },
            # 'data/videocc/metadata/train_900k.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            # },
            # 'data/ego4d/metadata/concept_train_top10_window.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/ego4d/metadata/concept_train_top5_window.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/ego4d/metadata/concept_train_top5_window_0.1p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/ego4d/metadata/concept_train_top5_window_0.2p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/ego4d/metadata/concept_train_top5_window_0.5p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/ego4d/metadata/concept_train_top5_window_0.75p.jsonl': {
            #     'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/videocc/metadata/concept_train_top10_window.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/videocc/metadata/concept_train_top5_window.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/videocc/metadata/concept_train_top5_window_0.1p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/videocc/metadata/concept_train_top5_window_0.2p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/videocc/metadata/concept_train_top5_window_0.5p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            # 'data/videocc/metadata/concept_train_top5_window_0.75p.jsonl': {
            #     'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
            # },
            #
            # pre-training
            'data/ego4d/metadata/point_egoclip_wo_val.jsonl': {
                'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
            },
            'data/videocc/metadata/interval_900k.jsonl': {
                'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            },
            'data/videocc/metadata/curve_5_window.jsonl': {
                'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
            },
            # downstream
            'data/qvhighlights/metadata/qvhighlights_train.jsonl': {
                'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
            },
            'data/charades/metadata/charades_train.jsonl': {
                'dset_name': 'charades', 'v_feat_suffix': '_2', 'q_feat_suffix': '', 'type': 'interval',
            },
            'data/ego4d/metadata/nlq_train.jsonl': {
                'dset_name': 'ego4d', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            },
            'data/tacos/metadata/train.jsonl': {
                'dset_name': 'tacos', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            },
            'data/anet/metadata/train.jsonl': {
                'dset_name': 'anet', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            },
            'data/didemo/metadata/train.jsonl': {
                'dset_name': 'didemo', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
            },
        }

        if "val" in data_path or "test" in data_path:
            assert txt_drop_ratio == 0

        # checks
        assert q_feat_type in self.Q_FEAT_TYPES

        # data
        self.data = self.load_data()

        self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
        t_feat_type = q_feat_dir.split('/')[-1]

        if self.use_cache > 0:
            print('Loading the off-line features...')
            dset_dir = os.path.join('data', self.dset_name)
            vid_keys = [meta['vid'] for meta in self.data]
            qid_keys = [meta['qid'] for meta in self.data]

            self.vid_cache = {}
            for v_feat_type in self.v_feat_types:
                assert 'vid' in v_feat_type
                with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
                    self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}

            assert 'txt' in t_feat_type
            self.txt_cache = {}
            with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
                for key in tqdm(qid_keys):
                    try:
                        self.txt_cache[key] = f[str(key)][:]
                    except KeyError:
                        logger.info(f"text {key} is not in the cache.")

    def load_data(self):
        # datalist = load_jsonl(self.data_path[0])
        datalist = []
        for dset_path in self.data_path:
            dset_info = self.vlp_mapping[dset_path]
            dset_list = load_jsonl(dset_path)
            for x in dset_list:
                x.update(dset_info)
            datalist += dset_list
        n_examples = int(len(datalist))
        if self.data_ratio != 1:
            n_examples = int(len(datalist) * self.data_ratio)
            shuffle(datalist)
            datalist = datalist[:n_examples]
        logger.info("Using {}% of the data: {} examples"
                    .format(self.data_ratio * 100, n_examples))
        return datalist
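
    # For example (based on the vlp_mapping above), a line loaded from
    # 'data/videocc/metadata/interval_900k.jsonl' comes back from load_data() with the extra
    # keys {'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval'}
    # merged into its original jsonl fields; __getitem__ relies on these keys to build feature
    # paths and to pick the 'weight_ablation' label type.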

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        meta = self.data[index]

        model_inputs = dict()
        model_inputs["query_feat"] = self._get_query_feat_by_qid(meta)  # (Dq, ) or (Lq, Dq)

        if self.use_video:
            model_inputs["video_feat"] = self._get_video_feat_by_vid(meta)  # (Lv, Dv)
            ctx_l = len(model_inputs["video_feat"])
        else:
            ctx_l = self.max_v_l

        if meta['dset_name'] in ['hacs', 'ego4d', 'activitynet']:
            for i, window_i in enumerate(meta["relevant_windows"]):
                if window_i[1] - window_i[0] < self.clip_len:
                    center = (window_i[1] + window_i[0]) / 2
                    window_i[0] = max(0, center - 0.5 * self.clip_len)
                    window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
                    window_i[1] = max(self.clip_len, window_i[1])

        model_inputs["timestamp"] = ((torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)

        if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
            meta["relevant_windows"] = [[0, 150]]
        relevant_windows = torch.Tensor(meta["relevant_windows"])

        # assign the nearest window for each timestamp i.e., qvhighlights.
        num_vid_seq = model_inputs["timestamp"].shape[0]
        num_windows = relevant_windows.shape[0]

        relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
        relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
        model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)

        if meta['qid'] is not None:
            nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
            diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
            diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
            assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
            if min(assign_idx[0].shape) == 0:  # not assigned, happened in activitynet.
                nn_window_ts = relevant_windows_ts.squeeze(1)
            else:
                nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]

        model_inputs["span_labels_nn"] = nn_window_ts
        model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:, 0] >= nn_window_ts[:, 0]) & (model_inputs["timestamp"][:, 1] <= nn_window_ts[:, 1])

        # for activitynet.
        if model_inputs["timestamp_window"].sum() < 1:
            idx = int(meta['relevant_windows'][0][0] / self.clip_len)
            idx = max(0, min(idx, ctx_l - 1))
            model_inputs["timestamp_window"][idx] = 1

        if self.use_tef:
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            if self.use_video:
                model_inputs["video_feat"] = torch.cat(
                    [model_inputs["video_feat"], tef], dim=1)  # (Lv, Dv+2)
            else:
                model_inputs["video_feat"] = tef

        if self.load_labels:
            model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l)  # (#windows, 2)
            if 'saliency_scores' in meta.keys():
                # this is for highlight-only task
                model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
                limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
                model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
                    self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
                # pdb.set_trace()
            else:
                model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
                    self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l)  # only one gt
                model_inputs["saliency_pos_labels"] = [random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist())]

        if 'type' in meta.keys():
            if meta['type'] == 'point':
                model_inputs['weight_ablation'] = torch.tensor([0, 0, 1, 0, 0])
            if meta['type'] == 'interval':
                model_inputs['weight_ablation'] = torch.tensor([1, 1, 0, 0, 0])
            if meta['type'] == 'curve':
                model_inputs['weight_ablation'] = torch.tensor([0, 0, 0, 1, 1])

        return dict(meta=meta, model_inputs=model_inputs)

    def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
        gt_st = int(gt_window[0] / self.clip_len)
        gt_st = min(gt_st, ctx_l - 1)
        gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
        if gt_st > gt_ed:
            # gt_st = gt_ed
            gt_ed = gt_st

        if gt_st != gt_ed:
            pos_clip_indices = random.sample(range(gt_st, gt_ed + 1), k=max_n)
        else:
            pos_clip_indices = [gt_st] * max_n  # [gt_st, gt_st]

        neg_pool = list(range(0, gt_st)) + list(range(gt_ed + 1, ctx_l))
        try:
            neg_clip_indices = random.sample(neg_pool, k=max_n)
        except ValueError:
            # the negative pool can be smaller than max_n (e.g. the GT window covers the whole video)
            neg_clip_indices = pos_clip_indices
        return pos_clip_indices, neg_clip_indices

    def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
        """Sum the scores from the three annotations, then take the `max_n` clips with the
        maximum aggregated scores as positives and the `max_n` clips with the minimum
        aggregated scores as negatives.
        Args:
            rel_clip_ids: list(int), list of relevant clip ids
            scores: list([anno1_score, anno2_score, anno3_score]),
            ctx_l: int
            max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
            add_easy_negative: bool, if True, sample easy negatives outside the relevant_clip_ids.
        """
        # indices inside rel_clip_ids
        scores = np.array(scores)  # (#rel_clips, 3)
        agg_scores = np.sum(scores, 1)  # (#rel_clips, )
        sort_indices = np.argsort(agg_scores)  # increasing

        # indices in the whole video
        # the min(_, ctx_l-1) here is incorrect, but should not cause
        # much trouble since this should be rarely used.
        hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[-max_n:]]
        hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[:max_n]]

        if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
            hard_neg_clip_indices = hard_pos_clip_indices

        easy_pos_clip_indices = []
        easy_neg_clip_indices = []
        # pdb.set_trace()
        if self.add_easy_negative > 0:
            easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
            if len(easy_neg_pool) >= max_n:
                easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
                easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
            else:  # copy the hard ones
                easy_pos_clip_indices = hard_pos_clip_indices
                easy_neg_clip_indices = hard_neg_clip_indices

        if self.easy_negative_only > 0:
            return easy_pos_clip_indices, easy_neg_clip_indices

        pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
        neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
        return pos_clip_indices, neg_clip_indices

    def get_span_labels(self, windows, ctx_l):
        """
        windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
            Note a maximum of `self.max_windows` windows are used.
        returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
        """
        if len(windows) > self.max_windows:
            random.shuffle(windows)
            windows = windows[:self.max_windows]
        if self.span_loss_type == "l1":
            windows = torch.Tensor(windows) / (ctx_l * self.clip_len)  # normalized windows in xx
            windows = span_xx_to_cxw(windows)  # normalized windows in cxw
        elif self.span_loss_type == "ce":
            windows = torch.Tensor([
                [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
                for w in windows]).long()  # inclusive
        else:
            raise NotImplementedError
        return windows
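
    # Worked example for get_span_labels (illustrative only): with the metadata sample in the
    # class docstring (duration 150s, clip_len=2, so ctx_l=75) and windows=[[26, 36]] under the
    # "l1" loss, the window is first normalized to [26/150, 36/150] = [0.1733, 0.2400] and then
    # converted by span_xx_to_cxw into [center, width] = [0.2067, 0.0667].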

    def _get_query_feat_by_qid(self, meta):
        qid = meta['qid']
        dset_name = meta['dset_name']
        q_feat_suffix = meta['q_feat_suffix']
        q_feat_dir = self.q_feat_dir + q_feat_suffix

        if self.use_cache > 0:
            try:
                q_feat = self.txt_cache[qid]
            except KeyError:
                q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
            return torch.from_numpy(q_feat)

        q_feat_path = os.path.join('data', dset_name, q_feat_dir, f"{qid}.npz")
        try:
            q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
        except Exception:
            q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
            logger.info(f"Something went wrong when loading the query feature {q_feat_path}.")

        if self.q_feat_type == "last_hidden_state":
            # q_feat = q_feat[:self.max_q_l]
            q_feat = q_feat
        if self.normalize_t:
            q_feat = l2_normalize_np_array(q_feat)
        if self.txt_drop_ratio > 0:
            q_feat = self.random_drop_rows(q_feat)
        return torch.from_numpy(q_feat)  # (D, ) or (Lq, D)

    def random_drop_rows(self, embeddings):
        """Randomly mask num_drop rows in embeddings to be zero.
        Args:
            embeddings: np.ndarray (L, D)
        """
        num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
        if num_drop_rows > 0:
            row_indices = np.random.choice(
                len(embeddings), size=num_drop_rows, replace=False)
            embeddings[row_indices] = 0
        return embeddings

    def _get_video_feat_by_vid(self, meta):
        dset_name = meta['dset_name']
        v_feat_suffix = meta['v_feat_suffix']
        vid = meta['vid']

        v_feat_list = []
        for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
            v_feat_dir = _feat_dir + v_feat_suffix
            if self.use_cache > 0:
                _feat = self.vid_cache[feat_type][vid]
            else:
                _feat_path = os.path.join('data', dset_name, v_feat_dir, f"{vid}.npz")
                _feat = np.load(_feat_path)["features"].astype(np.float32)
            if self.normalize_v:
                _feat = l2_normalize_np_array(_feat)
            v_feat_list.append(_feat)
        # some features are slightly longer than the others
        min_len = min([len(e) for e in v_feat_list])
        v_feat_list = [e[:min_len] for e in v_feat_list]
        v_feat = np.concatenate(v_feat_list, axis=1)
        return torch.from_numpy(v_feat)  # (Lv, D)
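

# A minimal usage sketch (not part of the original pipeline): how DatasetVLP might be
# constructed for joint pre-training over metadata files registered in `vlp_mapping`.
# The feature directory names and feature dimensions below are placeholders and must
# match your local feature extraction setup.
def _demo_build_vlp_dataset():
    dataset = DatasetVLP(
        dset_name="vlp",
        data_path=[
            "data/ego4d/metadata/point_egoclip_wo_val.jsonl",
            "data/videocc/metadata/interval_900k.jsonl",
        ],
        v_feat_dirs=["vid_clip"],        # placeholder: per-dataset video feature dir name
        q_feat_dir="txt_clip",           # placeholder: per-dataset text feature dir name
        v_feat_dim=512, q_feat_dim=512,  # placeholder dims for CLIP-like features
        ctx_mode="video_tef",
        clip_len=2, use_cache=-1,
    )
    print(len(dataset))
    return dataset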


class DatasetMR(Dataset):
    Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
    """One line in data loaded from data_path:
    {
      "qid": 7803,
      "query": "Man in gray top walks from outside to inside.",
      "duration": 150,
      "vid": "RoripwjYFp8_360.0_510.0",
      "relevant_clip_ids": [13, 14, 15, 16, 17],
      "relevant_windows": [[26, 36]]
    }
    """

    def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
                 q_feat_type="last_hidden_state",
                 max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
                 normalize_v=True, normalize_t=True, load_labels=True,
                 clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
                 use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
        self.dset_name = dset_name
        self.data_path = data_path[0] if isinstance(data_path, list) else data_path
        self.data_ratio = data_ratio
        self.v_feat_dirs = v_feat_dirs \
            if isinstance(v_feat_dirs, list) else [v_feat_dirs]
        self.q_feat_dir = q_feat_dir
        self.q_feat_type = q_feat_type
        self.v_feat_dim = v_feat_dim
        self.q_feat_dim = q_feat_dim
        self.max_q_l = max_q_l
        self.max_v_l = max_v_l
        self.ctx_mode = ctx_mode
        self.use_tef = "tef" in ctx_mode
        self.use_video = "video" in ctx_mode
        self.normalize_t = normalize_t
        self.normalize_v = normalize_v
        self.load_labels = load_labels
        self.clip_len = clip_len
        self.fix_len = fix_len
        self.max_windows = max_windows  # maximum number of windows to use as labels
        self.span_loss_type = span_loss_type
        self.txt_drop_ratio = txt_drop_ratio
        self.use_cache = use_cache
        self.add_easy_negative = add_easy_negative
        self.easy_negative_only = easy_negative_only

        if "val" in data_path or "test" in data_path:
            assert txt_drop_ratio == 0

        # checks
        assert q_feat_type in self.Q_FEAT_TYPES

        # data
        self.data = self.load_data()

        self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
        t_feat_type = q_feat_dir.split('/')[-1]

        if self.use_cache > 0:
            print('Loading the off-line features...')
            dset_dir = os.path.join('data', self.dset_name)
            vid_keys = [meta['vid'] for meta in self.data]
            qid_keys = [meta['qid'] for meta in self.data]

            self.vid_cache = {}
            for v_feat_type in self.v_feat_types:
                assert 'vid' in v_feat_type
                with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
                    self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}

            assert 'txt' in t_feat_type
            self.txt_cache = {}
            with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
                for key in tqdm(qid_keys):
                    try:
                        self.txt_cache[key] = f[str(key)][:]
                    except KeyError:
                        logger.info(f"text {key} is not in the cache.")

    def load_data(self):
        datalist = load_jsonl(self.data_path)
        if self.data_ratio != 1:
            n_examples = int(len(datalist) * self.data_ratio)
            datalist = datalist[:n_examples]
            logger.info("Using {}% of the data: {} examples"
                        .format(self.data_ratio * 100, n_examples))
        return datalist

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        meta = self.data[index]

        model_inputs = dict()
        model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"])  # (Dq, ) or (Lq, Dq)

        if self.use_video:
            model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"])  # (Lv, Dv)
            ctx_l = len(model_inputs["video_feat"])
        else:
            ctx_l = self.max_v_l

        if self.dset_name in ['hacs', 'ego4d', 'videocc', 'activitynet']:
            for i, window_i in enumerate(meta["relevant_windows"]):
                if window_i[1] - window_i[0] < self.clip_len:
                    center = (window_i[1] + window_i[0]) / 2
                    window_i[0] = max(0, center - 0.5 * self.clip_len)
                    window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
                    window_i[1] = max(self.clip_len, window_i[1])

        model_inputs["timestamp"] = ((torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)

        if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
            meta["relevant_windows"] = [[0, 150]]
        relevant_windows = torch.Tensor(meta["relevant_windows"])

        # assign the nearest window for each timestamp i.e., qvhighlights.
        num_vid_seq = model_inputs["timestamp"].shape[0]
        num_windows = relevant_windows.shape[0]

        relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
        relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
        model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)

        if meta['qid'] is not None:
            nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
            diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
            diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
            assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
            if min(assign_idx[0].shape) == 0:  # not assigned, happened in activitynet.
                nn_window_ts = relevant_windows_ts.squeeze(1)
            else:
                nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]

        model_inputs["span_labels_nn"] = nn_window_ts
        model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:, 0] >= nn_window_ts[:, 0]) & (model_inputs["timestamp"][:, 1] <= nn_window_ts[:, 1])

        # for activitynet.
        if model_inputs["timestamp_window"].sum() < 1:
            idx = int(meta['relevant_windows'][0][0] / self.clip_len)
            idx = max(0, min(idx, ctx_l - 1))
            model_inputs["timestamp_window"][idx] = 1

        if self.use_tef:
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            if self.use_video:
                model_inputs["video_feat"] = torch.cat(
                    [model_inputs["video_feat"], tef], dim=1)  # (Lv, Dv+2)
            else:
                model_inputs["video_feat"] = tef

        if self.load_labels:
            model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l)  # (#windows, 2)
            if 'saliency_scores' in meta.keys():
                model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
                limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
                model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
                    self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
            else:
                model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
                    self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l)  # only one gt
                model_inputs["saliency_pos_labels"] = [random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist())]

        return dict(meta=meta, model_inputs=model_inputs)

    def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
        gt_st = int(gt_window[0] / self.clip_len)
        gt_st = min(gt_st, ctx_l - 1)
        gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
        if gt_st > gt_ed:
            gt_ed = gt_st

        if gt_st != gt_ed:
            pos_clip_indices = random.sample(range(gt_st, gt_ed + 1), k=max_n)
        else:
            pos_clip_indices = [gt_st] * max_n  # [gt_st, gt_st]

        neg_pool = list(range(0, gt_st)) + list(range(gt_ed + 1, ctx_l))
        try:
            neg_clip_indices = random.sample(neg_pool, k=max_n)
        except ValueError:
            # the negative pool can be smaller than max_n (e.g. the GT window covers the whole video)
            neg_clip_indices = pos_clip_indices
        return pos_clip_indices, neg_clip_indices

    def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
        """Sum the scores from the three annotations, then take the `max_n` clips with the
        maximum aggregated scores as positives and the `max_n` clips with the minimum
        aggregated scores as negatives.
        Args:
            rel_clip_ids: list(int), list of relevant clip ids
            scores: list([anno1_score, anno2_score, anno3_score]),
            ctx_l: int
            max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
            add_easy_negative: bool, if True, sample easy negatives outside the relevant_clip_ids.
        """
        # indices inside rel_clip_ids
        scores = np.array(scores)  # (#rel_clips, 3)
        agg_scores = np.sum(scores, 1)  # (#rel_clips, )
        sort_indices = np.argsort(agg_scores)  # increasing

        # indices in the whole video
        # the min(_, ctx_l-1) here is incorrect, but should not cause
        # much trouble since this should be rarely used.
        hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[-max_n:]]
        hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[:max_n]]

        if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
            hard_neg_clip_indices = hard_pos_clip_indices

        easy_pos_clip_indices = []
        easy_neg_clip_indices = []
        if self.add_easy_negative > 0:
            easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
            if len(easy_neg_pool) >= max_n:
                easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
                easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
            else:  # copy the hard ones
                easy_pos_clip_indices = hard_pos_clip_indices
                easy_neg_clip_indices = hard_neg_clip_indices

        if self.easy_negative_only > 0:
            return easy_pos_clip_indices, easy_neg_clip_indices

        pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
        neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
        return pos_clip_indices, neg_clip_indices

    def get_span_labels(self, windows, ctx_l):
        """
        windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
            Note a maximum of `self.max_windows` windows are used.
        returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
        """
        if len(windows) > self.max_windows:
            random.shuffle(windows)
            windows = windows[:self.max_windows]
        if self.span_loss_type == "l1":
            windows = torch.Tensor(windows) / (ctx_l * self.clip_len)  # normalized windows in xx
            windows = span_xx_to_cxw(windows)  # normalized windows in cxw
        elif self.span_loss_type == "ce":
            windows = torch.Tensor([
                [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
                for w in windows]).long()  # inclusive
        else:
            raise NotImplementedError
        return windows

    def _get_query_feat_by_qid(self, qid):
        if self.use_cache > 0:
            try:
                q_feat = self.txt_cache[qid]
            except KeyError:
                q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
            return torch.from_numpy(q_feat)

        q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
        try:
            q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
        except Exception:
            q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
            logger.info(f"Something went wrong when loading the query feature {q_feat_path}.")

        if self.q_feat_type == "last_hidden_state":
            # q_feat = q_feat[:self.max_q_l]
            q_feat = q_feat
        if self.normalize_t:
            q_feat = l2_normalize_np_array(q_feat)
        if self.txt_drop_ratio > 0:
            q_feat = self.random_drop_rows(q_feat)
        return torch.from_numpy(q_feat)  # (D, ) or (Lq, D)

    def random_drop_rows(self, embeddings):
        """Randomly mask num_drop rows in embeddings to be zero.
        Args:
            embeddings: np.ndarray (L, D)
        """
        num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
        if num_drop_rows > 0:
            row_indices = np.random.choice(
                len(embeddings), size=num_drop_rows, replace=False)
            embeddings[row_indices] = 0
        return embeddings

    def _get_video_feat_by_vid(self, vid):
        v_feat_list = []
        for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
            if self.use_cache > 0:
                _feat = self.vid_cache[feat_type][vid]
            else:
                _feat_path = join(_feat_dir, f"{vid}.npz")
                _feat = np.load(_feat_path)["features"].astype(np.float32)
                # _feat = np.load(_feat_path)["features"][:self.max_v_l].astype(np.float32)
            if self.normalize_v:
                _feat = l2_normalize_np_array(_feat)
            v_feat_list.append(_feat)
        # some features are slightly longer than the others
        min_len = min([len(e) for e in v_feat_list])
        v_feat_list = [e[:min_len] for e in v_feat_list]
        v_feat = np.concatenate(v_feat_list, axis=1)
        return torch.from_numpy(v_feat)  # (Lv, D)


class DatasetHL(Dataset):
    def __init__(self,
                 dset_name,
                 domain,
                 data_path,
                 v_feat_types,
                 v_feat_dirs,
                 t_feat_dir,
                 use_tef=False):
        assert dset_name in ['tvsum', 'youtube']
        self.dset_name = dset_name
        dset_domain = {'tvsum': TVSUM_SPLITS,
                       'youtube': YOUTUBE_SPLITS}
        self.splits = dset_domain[dset_name]
        assert domain in self.splits.keys()
        self.domain = domain

        assert len(data_path) == 1
        self.data_path = data_path[0] if isinstance(data_path, list) else data_path
        self.v_feat_types = v_feat_types.split('_')
        self.v_feat_dirs = v_feat_dirs
        self.q_feat_type = "last_hidden_state"
        self.q_feat_dir = t_feat_dir

        self.txt_drop_ratio = 0
        self.normalize_t = True
        self.normalize_v = True

        self.label = nncore.load(self.data_path)
        self.use_tef = use_tef

        self.video_id = {
            k: [s for s in self.splits[domain][k] if s in self.label]
            for k in ('train', 'val')
        }
        self.set_state('train')

    def __len__(self):
        return len(self.video_id[self.state])

    def __getitem__(self, idx):
        vid = self.get_video_id(idx)
        video = self._get_video_feat_by_vid(vid)
        saliency = self.get_saliency(idx)

        if self.dset_name == 'youtube':
            saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
        elif self.dset_name == 'tvsum':
            saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
            # saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency != min(saliency))[0].tolist())])
        else:
            raise NotImplementedError

        num_clips = min(c.size(0) for c in (video, saliency))
        video = video[:num_clips]
        saliency = saliency[:num_clips]

        if self.use_tef:
            ctx_l = video.shape[0]
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            video = torch.cat([video, tef], dim=1)  # (Lv, Dv+2)

        data = dict(
            video=DataContainer(video),
            saliency=DataContainer(saliency, pad_value=-1),
            saliency_pos_labels=saliency_pos_labels)
        if self.q_feat_dir is not None:
            query = self._get_query_feat_by_qid(vid)
            data['query'] = DataContainer(query, pad_value=float('inf'))
        return data

    def set_state(self, state):
        self.state = 'train' if state == 'train' else 'val'

    def get_video_id(self, idx):
        return self.video_id[self.state][idx]

    def get_video(self, idx):
        video_id = self.get_video_id(idx)
        video = torch.from_numpy(self.video[video_id]).float()
        optic = torch.from_numpy(self.optic[video_id]).float()
        return torch.cat((video, optic), dim=1)

    def _get_video_feat_by_vid(self, vid):
        v_feat_list = []
        for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
            # if self.use_cache > 0:
            #     _feat = self.vid_cache[feat_type][vid]
            # else:
            _feat_path = join(_feat_dir, f"{vid}.npz")
            _feat = np.load(_feat_path)["features"].astype(np.float32)
            if self.normalize_v:
                _feat = l2_normalize_np_array(_feat)
            v_feat_list.append(_feat)
        # some features are slightly longer than the others
        min_len = min([len(e) for e in v_feat_list])
        v_feat_list = [e[:min_len] for e in v_feat_list]
        v_feat = np.concatenate(v_feat_list, axis=1)
        return torch.from_numpy(v_feat)  # (Lv, D)

    def _get_query_feat_by_qid(self, qid):
        # if self.use_cache > 0:
        #     try:
        #         q_feat = self.txt_cache[qid]
        #     except:
        #         q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
        #     return torch.from_numpy(q_feat)

        q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
        try:
            q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
        except Exception:
            q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
            logger.info(f"Something went wrong when loading the query feature {q_feat_path}.")

        if self.q_feat_type == "last_hidden_state":
            # q_feat = q_feat[:self.max_q_l]
            q_feat = q_feat
        if self.normalize_t:
            q_feat = l2_normalize_np_array(q_feat)
        if self.txt_drop_ratio > 0:
            q_feat = self.random_drop_rows(q_feat)
        return torch.from_numpy(q_feat)  # (D, ) or (Lq, D)

    def get_saliency(self, idx):
        if self.dset_name == 'tvsum':
            video_id = self.get_video_id(idx)
            saliency = torch.Tensor(self.label[video_id]['anno'])

            # top-5 saliency scores as a threshold.
            # saliency_tmp = saliency.mean(1)
            # topk = int(saliency_tmp.shape[0] * 0.1)
            # th = saliency_tmp[torch.sort(saliency_tmp)[1][-topk]]  # v4
            # saliency = saliency_tmp - th

            # saliency_tmp = saliency.mean(1)  # med
            # th = saliency_tmp.median()
            # saliency = saliency_tmp - th

            saliency = (saliency - saliency.mean()).mean(dim=1)
            # saliency = (saliency.sum(dim=1) - 20) / 80  # v2
        elif self.dset_name == 'youtube':
            video_id = self.get_video_id(idx)
            saliency = [1 if s > 0 else 0 for s in self.label[video_id]['match']]
        else:
            raise NotImplementedError
        return torch.Tensor(saliency)

    def evaluate(self, blob, k=5, save_dir=None, **kwargs):
        # blob = nncore.to_dict_of_list(blob)
        collected = []

        if save_dir is not None:
            import json
            with open(os.path.join(save_dir, self.dset_name, self.domain + '.jsonl'), 'w') as f:
                for idx, score in enumerate(blob):
                    video_id = self.get_video_id(idx)
                    entry = {'vid': video_id, 'pred': score[0].tolist(), 'gt': self.get_saliency(idx).tolist(),
                             'duration': int(self.label[video_id]['frames']) / int(self.label[video_id]['fps']),
                             'domain': self.label[video_id]['domain'], 'fps': self.label[video_id]['fps']}
                    if self.dset_name == 'tvsum':
                        entry.update({'title': self.label[video_id]['title']})
                    if self.dset_name == 'youtube':
                        entry.update({'clip': self.label[video_id]['clip']})
                    f.write(json.dumps(entry) + '\n')

        if self.dset_name == 'tvsum':
            for i in range(20):
                video_ap = []
                for idx, score in enumerate(blob):
                    inds = torch.argsort(score[0], descending=True)
                    video_id = self.get_video_id(idx)
                    label = torch.Tensor(self.label[video_id]['anno'])[:, i]
                    label = torch.where(label > label.median(), 1.0, .0)
                    label = label[inds].tolist()[:k]

                    if (num_gt := sum(label)) == 0:
                        video_ap.append(0)
                        continue

                    hits = ap = rec = 0
                    prc = 1
                    for j, gt in enumerate(label):
                        hits += gt
                        _rec = hits / num_gt
                        _prc = hits / (j + 1)
                        ap += (_rec - rec) * (prc + _prc) / 2
                        rec, prc = _rec, _prc
                    video_ap.append(ap)
                collected.append(sum(video_ap) / len(video_ap))
        elif self.dset_name == 'youtube':
            for idx, score in enumerate(blob):
                inds = torch.argsort(score[0], descending=True)
                label = self.get_saliency(idx)[inds].tolist()

                if (num_gt := sum(label)) == 0:
                    collected.append(0)
                    continue

                hits = ap = rec = 0
                prc = 1
                for i, gt in enumerate(label):
                    hits += gt
                    _rec = hits / num_gt
                    _prc = hits / (i + 1)
                    ap += (_rec - rec) * (prc + _prc) / 2
                    rec, prc = _rec, _prc
                collected.append(ap)
        else:
            raise NotImplementedError

        mean_ap = sum(collected) / len(collected)
        results = dict(mAP=round(mean_ap, 5))
        return results


class DatasetQFVS(Dataset):
    def __init__(self, config, use_tef=True):
        # pdb.set_trace()
        self.config = config
        self.dataset = []
        self.use_tef = use_tef

        self.embedding = load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")

        for video_id in self.config["train_videos"]:
            for _, _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id)):
                for file in files:
                    self.dataset.append(file[:file.find("_oracle.txt")] + "_" + str(video_id))

    def __getitem__(self, index):
        video_id = self.dataset[index].split('_')[2]
        feat_type = self.config['vid_feature']
        # pdb.set_trace()
        with h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r') as f:
            features = f['feature'][()]
            # dim = features.shape[-1]
            # features = features.reshape(-1, dim)
            # seg_len = f['seg_len'][()]
        dim = features.shape[-1]
        ctx_l = features.shape[0]
        seg_len = np.ones(ctx_l)

        # mask = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
        # for j in range(len(seg_len)):
        #     for k in range(seg_len[j]):
        #         mask[j][k] = 1

        # ctx_l = seg_len.sum()
        features = torch.from_numpy(features)
        # features = features[mask, :]

        if self.use_tef:
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            features = torch.cat([features, tef], dim=1)  # (Lv, Dv+2)

        transfer = {"Cupglass": "Glass",
                    "Musicalinstrument": "Instrument",
                    "Petsanimal": "Animal"}

        concept1, concept2 = self.dataset[index].split('_')[0:2]

        concept1_GT = torch.zeros(ctx_l)
        concept2_GT = torch.zeros(ctx_l)
        with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0" + video_id + "/P0" + video_id + ".txt", "r") as f:
            lines = f.readlines()
        for shot_id, line in enumerate(lines):
            concepts = line.strip().split(',')
            if concept1 in concepts:
                concept1_GT[shot_id] = 1
            if concept2 in concepts:
                concept2_GT[shot_id] = 1

        # shot_num = seg_len.sum()
        # mask_GT = torch.zeros(ctx_l)
        # for i in range(shot_num):
        #     mask_GT[i] = 1
        mask_GT = torch.ones(ctx_l)

        oracle_summary = torch.zeros(ctx_l)
        GT_summary_shots = []
        with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id) + "/" + str(concept1) + "_" + str(concept2) + "_" + "oracle.txt", "r") as f:
            for line in f.readlines():
                GT_summary_shots.append(int(line.strip()))
        GT_summary_shots = [x - 1 for x in GT_summary_shots]
        for element in GT_summary_shots:
            oracle_summary[element] = 1

        if concept1 in transfer:
            concept1 = transfer[concept1]
        if concept2 in transfer:
            concept2 = transfer[concept2]
        concept1 = self.embedding[concept1]
        concept2 = self.embedding[concept2]

        try:
            saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_1 = torch.Tensor(0)

        try:
            saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_2 = torch.Tensor(0)

        try:
            saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary > 0)[0].tolist())])
        except IndexError:
            saliency_pos_labels_oracle = torch.Tensor(0)

        return {
            'features': features,
            'seg_len': torch.from_numpy(seg_len),
            'concept1_GT': concept1_GT,
            'concept2_GT': concept2_GT,
            'mask_GT': mask_GT,
            'oracle_summary': oracle_summary,
            'tokens_pad1': torch.from_numpy(concept1),
            'tokens_pad2': torch.from_numpy(concept2),
            'saliency_pos_labels_1': saliency_pos_labels_1,
            'saliency_pos_labels_2': saliency_pos_labels_2,
            'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
        }

    def __len__(self):
        return len(self.dataset)


def start_end_collate_mr(batch):
    batch_meta = [e["meta"] for e in batch]  # seems no need to collate ?

    model_inputs_keys = batch[0]["model_inputs"].keys()
    batched_data = dict()
    for k in model_inputs_keys:
        if k == "span_labels":
            batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch]
            continue
        if k in ["saliency_pos_labels", "saliency_neg_labels"]:
            batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch])
            continue
        batched_data[k] = pad_sequences_1d(
            [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
    return batch_meta, batched_data


def start_end_collate_hl(batch):
    model_inputs_keys = batch[0].keys()
    batched_data = dict()
    for k in model_inputs_keys:
        batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
    return batched_data


def start_end_collate_qfvs(batch):
    model_inputs_keys = batch[0].keys()
    batched_data = dict()
    for k in model_inputs_keys:
        batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
    return batched_data


def prepare_batch_inputs_mr(batched_model_inputs, device, non_blocking=False):
    model_inputs = dict(
        src_txt=batched_model_inputs["query_feat"][0].to(device, non_blocking=non_blocking),
        src_txt_mask=batched_model_inputs["query_feat"][1].to(device, non_blocking=non_blocking),
        src_vid=batched_model_inputs["video_feat"][0].to(device, non_blocking=non_blocking),
        src_vid_mask=batched_model_inputs["video_feat"][1].to(device, non_blocking=non_blocking),
    )

    targets = {}
    targets['timestamp'] = batched_model_inputs["timestamp"][0].to(device, non_blocking=non_blocking)
    targets['timestamp_mask'] = batched_model_inputs["timestamp"][1].to(device, non_blocking=non_blocking)
    targets['timestamp_window'] = batched_model_inputs["timestamp_window"][0].to(device, non_blocking=non_blocking)
    targets['span_labels_nn'] = batched_model_inputs["span_labels_nn"][0].to(device, non_blocking=non_blocking)

    if 'saliency_scores' in batched_model_inputs.keys():
        targets['saliency_scores'] = batched_model_inputs["saliency_scores"][0].to(device, non_blocking=non_blocking)

    if "span_labels" in batched_model_inputs:
        targets["span_labels"] = [
            dict(spans=e["spans"].to(device, non_blocking=non_blocking))
            for e in batched_model_inputs["span_labels"]
        ]
    if "saliency_pos_labels" in batched_model_inputs:
        for name in ["saliency_pos_labels", "saliency_neg_labels"]:
            targets[name] = batched_model_inputs[name].to(device, non_blocking=non_blocking)

    if "weight_ablation" in batched_model_inputs:
        targets["weight_ablation"] = batched_model_inputs["weight_ablation"][0].to(device, non_blocking=non_blocking)

    targets = None if len(targets) == 0 else targets
    return model_inputs, targets
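

# A minimal end-to-end sketch (for illustration; the paths, feature directory names, and
# dimensions below are placeholders, not part of this repo): wire DatasetMR to a PyTorch
# DataLoader with start_end_collate_mr, then move one batch to the target device with
# prepare_batch_inputs_mr.
def _demo_mr_dataloader(device="cuda"):
    from torch.utils.data import DataLoader

    dataset = DatasetMR(
        dset_name="qvhighlights",
        data_path=["data/qvhighlights/metadata/qvhighlights_train.jsonl"],
        v_feat_dirs=["data/qvhighlights/vid_clip"],   # placeholder video feature dir
        q_feat_dir="data/qvhighlights/txt_clip",      # placeholder text feature dir
        v_feat_dim=512, q_feat_dim=512,               # placeholder feature dims
        ctx_mode="video_tef", clip_len=2,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=start_end_collate_mr)
    batch_meta, batched_inputs = next(iter(loader))
    model_inputs, targets = prepare_batch_inputs_mr(batched_inputs, device)
    # model_inputs carries src_txt/src_vid (+ masks); targets carries span/saliency labels.
    return batch_meta, model_inputs, targets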


def prepare_batch_inputs_hl(batched_model_inputs, device='cuda', non_blocking=False):
    src_vid = batched_model_inputs['video'][0].to(device, non_blocking=non_blocking)
    src_vid_mask = batched_model_inputs['video'][1].bool().to(device, non_blocking=non_blocking)
    src_txt = batched_model_inputs['query'][0].to(device, non_blocking=non_blocking) \
        if 'query' in batched_model_inputs.keys() else None
    src_txt_mask = batched_model_inputs['query'][1].bool().to(device, non_blocking=non_blocking) \
        if 'query' in batched_model_inputs.keys() else None

    model_inputs = dict(
        src_vid=src_vid, src_vid_mask=src_vid_mask,
        src_txt=src_txt, src_txt_mask=src_txt_mask)

    # if 'audio' in batched_model_inputs.keys():
    #     src_aud = batched_model_inputs['audio'][0].bool().to(device, non_blocking=non_blocking)
    #     src_aud_mask = batched_model_inputs['audio'][1].bool().to(device, non_blocking=non_blocking)
    #     model_inputs['src_aud'] = src_aud
    #     model_inputs['src_aud_mask'] = src_aud_mask

    targets = {}
    saliency = batched_model_inputs['saliency'][0].to(device, non_blocking=non_blocking)
    saliency_pos_labels = batched_model_inputs['saliency_pos_labels'][0].to(device, non_blocking=non_blocking)

    targets['saliency_scores'] = saliency
    targets['saliency_pos_labels'] = saliency_pos_labels.long()
    targets['timestamp_mask'] = batched_model_inputs["video"][1].to(device, non_blocking=non_blocking)
    targets['timestamp_window'] = 1 * (saliency > 0)

    return model_inputs, targets
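

# A minimal sketch for the highlight-detection path (illustrative; the annotation path,
# feature directory names, and domain value are placeholders that must match your local
# setup and the keys of TVSUM_SPLITS): DatasetHL + start_end_collate_hl + prepare_batch_inputs_hl.
def _demo_hl_dataloader(device="cuda"):
    from torch.utils.data import DataLoader

    dataset = DatasetHL(
        dset_name="tvsum",
        domain="BK",                                         # placeholder: any key of TVSUM_SPLITS
        data_path=["data/tvsum/metadata/tvsum_anno.json"],   # placeholder annotation file
        v_feat_types="clip",                                 # placeholder: '_'-separated feature types
        v_feat_dirs=["data/tvsum/vid_clip"],                 # placeholder feature dir (one per type)
        t_feat_dir="data/tvsum/txt_clip",                    # placeholder text feature dir
        use_tef=True,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=start_end_collate_hl)
    batched_inputs = next(iter(loader))
    model_inputs, targets = prepare_batch_inputs_hl(batched_inputs, device)
    return model_inputs, targets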


def prepare_batch_inputs_qfvs(data, config, eval=False):
    if not eval:
        features, mask, seg_len, \
            concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
            src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2, \
            saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
            data['features'][0], data['features'][1], data['seg_len'][0], \
            data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
            data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0]
    else:
        features, mask, seg_len, \
            src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
            data['features'][0], data['features'][1], data['seg_len'][0], \
            data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]

    # preprocess for vid input.
    seq = features.to('cuda')
    mask = mask.to('cuda')

    # for txt input.
    src_txt_1 = src_txt_1.to(torch.float32).to('cuda')
    src_txt_2 = src_txt_2.to(torch.float32).to('cuda')
    src_txt_mask_1 = src_txt_mask_1.to('cuda')
    src_txt_mask_2 = src_txt_mask_2.to('cuda')

    src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
    src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')

    model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
    model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
    model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)

    if not eval:
        targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
        targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
        targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))

        targets_1['timestamp_mask'] = mask
        targets_1['timestamp_window'] = concept1_GT.to('cuda')
        targets_2['timestamp_mask'] = mask
        targets_2['timestamp_window'] = concept2_GT.to('cuda')
        targets_oracle['timestamp_mask'] = mask
        targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')

        return model_inputs_1, model_inputs_2, model_inputs_oracle, \
            targets_1, targets_2, targets_oracle, mask_GT
    else:
        return model_inputs_1, model_inputs_2, model_inputs_oracle, mask
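

# A minimal sketch for the QFVS path (illustrative; the `config` keys shown are assumptions
# inferred from DatasetQFVS above, the values are placeholders, and the data under ./data/qfvs
# must already be prepared): DatasetQFVS + start_end_collate_qfvs + prepare_batch_inputs_qfvs.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    demo_config = {
        "txt_feature": "concept_clip",   # placeholder: pickle name under ./data/qfvs/txt_clip/
        "vid_feature": "clip",           # placeholder: suffix of the processed .h5 files
        "train_videos": [1, 2, 3],       # placeholder: P01-P03 as the train split
    }
    qfvs_dataset = DatasetQFVS(demo_config, use_tef=True)
    qfvs_loader = DataLoader(qfvs_dataset, batch_size=1, shuffle=True, collate_fn=start_end_collate_qfvs)
    batch = next(iter(qfvs_loader))
    inputs_1, inputs_2, inputs_oracle, targets_1, targets_2, targets_oracle, mask_GT = \
        prepare_batch_inputs_qfvs(batch, demo_config, eval=False)
    print(inputs_1['src_vid'].shape, targets_oracle['saliency_scores'].shape)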