diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f3a1426b661451e292ecef2a42ec695b8565df8b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/charades.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/ego4d.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/youtube.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..87099eaefc08f7b967b350f6a2120385e29f742c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,236 @@
+import os
+import pdb
+import time
+import torch
+import gradio as gr
+import numpy as np
+import argparse
+import subprocess
+from run_on_video import clip, vid2clip, txt2clip
+
+parser = argparse.ArgumentParser(description='')
+parser.add_argument('--save_dir', type=str, default='./tmp')
+parser.add_argument('--resume', type=str, default='./results/omni/model_best.ckpt')
+parser.add_argument("--gpu_id", type=int, default=2)
+args = parser.parse_args()
+os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
+
+#################################
+model_version = "ViT-B/32"
+output_feat_size = 512
+clip_len = 2
+overwrite = True
+num_decoding_thread = 4
+half_precision = False
+
+clip_model, _ = clip.load(model_version, device=args.gpu_id, jit=False)
+
+import logging
+import torch.backends.cudnn as cudnn
+from main.config import TestOptions, setup_model
+from utils.basic_utils import l2_normalize_np_array
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def load_model():
+ logger.info("Setup config, data and model...")
+ opt = TestOptions().parse(args)
+ # pdb.set_trace()
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ if opt.lr_warmup > 0:
+ total_steps = opt.n_epoch
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
+ opt.lr_warmup = [warmup_steps, total_steps]
+
+ model, criterion, _, _ = setup_model(opt)
+ return model
+
+vtg_model = load_model()
+
+def convert_to_hms(seconds):
+ return time.strftime('%H:%M:%S', time.gmtime(seconds))
+
+def load_data(save_dir):
+ vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
+ txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)
+
+ vid = torch.from_numpy(l2_normalize_np_array(vid))
+ txt = torch.from_numpy(l2_normalize_np_array(txt))
+ clip_len = 2
+ ctx_l = vid.shape[0]
+
+ timestamp = ( (torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
+
+ if True:
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
+ vid = torch.cat([vid, tef], dim=1) # (Lv, Dv+2)
+
+ src_vid = vid.unsqueeze(0).cuda()
+ src_txt = txt.unsqueeze(0).cuda()
+ src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
+ src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()
+
+ return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l
+
+def forward(model, save_dir, query):
+ src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
+ src_vid = src_vid.cuda(args.gpu_id)
+ src_txt = src_txt.cuda(args.gpu_id)
+ src_vid_mask = src_vid_mask.cuda(args.gpu_id)
+ src_txt_mask = src_txt_mask.cuda(args.gpu_id)
+
+ with torch.no_grad():
+ output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)
+
+ # prepare the model prediction
+ pred_logits = output['pred_logits'][0].cpu()
+ pred_spans = output['pred_spans'][0].cpu()
+ pred_saliency = output['saliency_scores'].cpu()
+
+ # prepare the model prediction
+ pred_windows = (pred_spans + timestamp) * ctx_l * clip_len
+ pred_confidence = pred_logits
+
+ # grounding
+ top1_window = pred_windows[torch.argmax(pred_confidence)].tolist()
+ top5_values, top5_indices = torch.topk(pred_confidence.flatten(), k=5)
+ top5_windows = pred_windows[top5_indices].tolist()
+
+ # print(f"The video duration is {convert_to_hms(src_vid.shape[1]*clip_len)}.")
+ q_response = f"For query: {query}"
+
+ mr_res = " - ".join([convert_to_hms(int(i)) for i in top1_window])
+ mr_response = f"The Top-1 interval is: {mr_res}"
+
+ hl_res = convert_to_hms(torch.argmax(pred_saliency) * clip_len)
+ hl_response = f"The Top-1 highlight is: {hl_res}"
+ return '\n'.join([q_response, mr_response, hl_response])
+
+def extract_vid(vid_path, state):
+ history = state['messages']
+ vid_features = vid2clip(clip_model, vid_path, args.save_dir)
+ history.append({"role": "user", "content": "Finish extracting video features."})
+ history.append({"role": "system", "content": "Please Enter the text query."})
+ chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history),2)]
+ return '', chat_messages, state
+
+def extract_txt(txt):
+ txt_features = txt2clip(clip_model, txt, args.save_dir)
+ return
+
+def download_video(url, save_dir='./examples', size=768):
+ save_path = f'{save_dir}/{url}.mp4'
+ cmd = f'yt-dlp -S ext:mp4:m4a --throttled-rate 5M -f "best[width<={size}][height<={size}]" --output {save_path} --merge-output-format mp4 https://www.youtube.com/embed/{url}'
+ if not os.path.exists(save_path):
+ try:
+ subprocess.call(cmd, shell=True)
+ except:
+ return None
+ return save_path
+
+def get_empty_state():
+ return {"total_tokens": 0, "messages": []}
+
+def submit_message(prompt, state):
+ history = state['messages']
+
+ if not prompt:
+ return gr.update(value=''), [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)], state
+
+ prompt_msg = { "role": "user", "content": prompt }
+
+ try:
+ history.append(prompt_msg)
+ # answer = vlogger.chat2video(prompt)
+ # answer = prompt
+ extract_txt(prompt)
+ answer = forward(vtg_model, args.save_dir, prompt)
+ history.append({"role": "system", "content": answer})
+
+ except Exception as e:
+ history.append(prompt_msg)
+ history.append({
+ "role": "system",
+ "content": f"Error: {e}"
+ })
+
+ chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)]
+ return '', chat_messages, state
+
+
+def clear_conversation():
+ return gr.update(value=None, visible=True), gr.update(value=None, interactive=True), None, gr.update(value=None, visible=True), get_empty_state()
+
+
+def subvid_fn(vid):
+ save_path = download_video(vid)
+ return gr.update(value=save_path)
+
+
+css = """
+ #col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
+ #video_inp {min-height: 100px}
+ #chatbox {min-height: 100px;}
+ #header {text-align: center;}
+ #hint {font-size: 1.0em; padding: 0.5em; margin: 0;}
+ .message { font-size: 1.2em; }
+ """
+
+with gr.Blocks(css=css) as demo:
+
+ state = gr.State(get_empty_state())
+
+
+ with gr.Column(elem_id="col-container"):
+ gr.Markdown("""## 🤖️ UniVTG: Towards Unified Video-Language Temporal Grounding
+ Given a video and text query, return relevant window and highlight.""",
+ elem_id="header")
+
+ with gr.Row():
+ with gr.Column():
+ video_inp = gr.Video(label="video_input")
+ gr.Markdown("👋 **Step1**: Select a video in Examples (bottom) or input youtube video_id in this textbox, *e.g.* *G7zJK6lcbyU* for https://www.youtube.com/watch?v=G7zJK6lcbyU", elem_id="hint")
+ with gr.Row():
+ video_id = gr.Textbox(value="", placeholder="Youtube video url", show_label=False)
+ vidsub_btn = gr.Button("(Optional) Submit Youtube id")
+
+ with gr.Column():
+ vid_ext = gr.Button("Step2: Extract video feature, may takes a while")
+ # vlog_outp = gr.Textbox(label="Document output", lines=40)
+ total_tokens_str = gr.Markdown(elem_id="total_tokens_str")
+
+ chatbot = gr.Chatbot(elem_id="chatbox")
+ input_message = gr.Textbox(show_label=False, placeholder="Enter text query and press enter", visible=True).style(container=False)
+ btn_submit = gr.Button("Step3: Enter your text query")
+ btn_clear_conversation = gr.Button("🔃 Clear")
+
+ examples = gr.Examples(
+ examples=[
+ ["./examples/youtube.mp4"],
+ ["./examples/charades.mp4"],
+ ["./examples/ego4d.mp4"],
+ ],
+ inputs=[video_inp],
+ )
+
+ gr.HTML('''
You can duplicate this Space to skip the queue:
''')
+
+ btn_submit.click(submit_message, [input_message, state], [input_message, chatbot])
+ input_message.submit(submit_message, [input_message, state], [input_message, chatbot])
+ # btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, vlog_outp, state])
+ btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, state])
+ vid_ext.click(extract_vid, [video_inp, state], [input_message, chatbot])
+ vidsub_btn.click(subvid_fn, [video_id], [video_inp])
+
+ demo.load(queur=False)
+
+
+demo.queue(concurrency_count=10)
+demo.launch(height='800px', server_port=2253, debug=True, share=True)
diff --git a/examples/charades.mp4 b/examples/charades.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0d77e221ed025160d916a817c96e0016dbd6cea7
--- /dev/null
+++ b/examples/charades.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa3d1ba99bf28103844e1313cc5543b7c626d87c42a1c18108c2a69479a6d679
+size 1301669
diff --git a/examples/ego4d.mp4 b/examples/ego4d.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7b2617ddde3a942056ba39fc71ced4a5041f8898
--- /dev/null
+++ b/examples/ego4d.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf1271d42415c793e659bebbd48394326cc50e970d44e6fdd0af5dfb4cb4ede4
+size 28306388
diff --git a/examples/youtube.mp4 b/examples/youtube.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3f214974cd37b2cb8d41df55431dda54fd731d06
--- /dev/null
+++ b/examples/youtube.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1dd6b483e5346a777b5d6448460c5e30b8fe46aa1133cf6bba94c84dd7262b49
+size 47353846
diff --git a/main/__init__.py b/main/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/main/_train_qfvs.py b/main/_train_qfvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0f5522cb0a721e19970e4cb64409db0900a0691
--- /dev/null
+++ b/main/_train_qfvs.py
@@ -0,0 +1,293 @@
+import os
+import pdb
+import time
+import json
+import pprint
+import random
+import importlib
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import h5py
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import sys
+sys.path.append('/data/home/qinghonglin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle
+from utils.model_utils import count_parameters
+from eval.qfvs import calculate_semantic_matching, load_videos_tag
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def eval_epoch(model, config, opt):
+ model.eval()
+ f1_sum = 0; p_sum = 0; r_sum = 0
+
+ assert len(config['test_videos']) == 1
+ video_id = config['test_videos'][0]
+ embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")
+
+ feat_type = config['vid_feature']
+ feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
+ features = torch.tensor(feat['feature'][()]).unsqueeze(0).cuda()
+ # pdb.set_trace()
+ # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()
+
+ # dim = features.shape[-1]
+ # ctx_l = seg_len.sum().cpu()
+
+ dim = features.shape[-1]
+ ctx_l = features.shape[1]
+ seg_len = torch.ones(ctx_l)
+ features = features.reshape(-1, dim)[:ctx_l]
+
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2)
+ features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
+
+ transfer = {"Cupglass": "Glass",
+ "Musicalinstrument": "Instrument",
+ "Petsanimal": "Animal"}
+
+ for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
+ evaluation_num=len(files)
+ for file in files:
+ summaries_GT=[]
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
+ for line in f.readlines():
+ summaries_GT.append(int(line.strip()))
+
+ concept1, concept2 = file.split('_')[0:2]
+
+ ##############
+ if concept1 in transfer:
+ concept1 = transfer[concept1]
+ if concept2 in transfer:
+ concept2 = transfer[concept2]
+ concept1 = embedding[concept1]
+ concept2 = embedding[concept2]
+
+ data = {
+ 'features':features,
+ 'seg_len': seg_len,
+ 'tokens_pad1':torch.from_numpy(concept1),
+ 'tokens_pad2':torch.from_numpy(concept2),
+ }
+
+ input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)
+
+ summaries_GT = [x - 1 for x in summaries_GT]
+ video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")
+
+
+ output_type = 'pred_logits' # only saliency.
+ # if opt.f_loss_coef == 0:
+ # output_type = 'saliency_scores' # only saliency.
+ # elif opt.s_loss_intra_coef == 0:
+ # output_type = 'pred_logits' # cls is default.
+ # else:
+ # output_type = ['pred_logits', 'saliency_scores']
+
+ # if opt.qfvs_score_multiple > 0:
+ # output_type = ['pred_logits', 'saliency_scores']
+
+ with torch.no_grad():
+ if not isinstance(output_type, list):
+ score1 = model(**input1)[output_type].squeeze()
+ # score1 = score1.masked_select(mask)
+ score2 = model(**input2)[output_type].squeeze()
+ # score2 = score2.masked_select(mask)
+
+ score = model(**input_oracle)[output_type].squeeze()
+ # score = score.masked_select(mask)
+ else:
+ score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
+ for output_t in output_type:
+ # score1 *= model(**input1)[output_t].squeeze() #.masked_select(mask)
+ # score2 *= model(**input2)[output_t].squeeze() #.masked_select(mask)
+ # score *= model(**input_oracle)[output_t].squeeze() #.masked_select(mask)
+ score1 += model(**input1)[output_t].squeeze() #.masked_select(mask)
+ score2 += model(**input2)[output_t].squeeze() #.masked_select(mask)
+ score += model(**input_oracle)[output_t].squeeze() #.masked_select(mask)
+
+ score = score
+ # score = score + score1 + score2
+
+ # since video4 features dim is greater than video_shots_tag.
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
+ _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))
+ p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
+ f1_sum+=f1; r_sum+=r; p_sum+=p
+
+ return {'F': round(100* f1_sum/evaluation_num,2) ,
+ 'R': round(100* r_sum/evaluation_num,2) ,
+ 'P': round(100* p_sum/evaluation_num,2) }
+
+def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
+ model.train()
+ criterion.train()
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ timer_dataloading = time.time()
+ loss_total = 0
+
+ # optimizer.zero_grad()
+ for batch_idx, batch in enumerate(tqdm(train_loader)):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+ timer_start = time.time()
+ model_input1, model_input2, model_input_oracle, \
+ model_gt1, model_gt2, model_gt_oracle, \
+ mask_GT = prepare_batch_inputs_qfvs(batch, config)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ output1 = model(**model_input1)
+ output2 = model(**model_input2)
+ output_oracle = model(**model_input_oracle)
+
+ loss_dict = {}
+ loss_dict1 = criterion(output1, model_gt1)
+ loss_dict2 = criterion(output2, model_gt2)
+ loss_dict3 = criterion(output_oracle, model_gt_oracle)
+
+ weight_dict = criterion.weight_dict
+ for k in loss_dict1.keys():
+ loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
+
+ # print(loss_dict)
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ loss_total += losses.item()
+
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+ timer_start = time.time()
+ # optimizer.zero_grad()
+ optimizer.zero_grad()
+ losses.backward()
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ # if ((batch_idx + 1) % opt.bsz==0) or (batch_idx == len(train_loader)-1):
+ # pdb.set_trace()
+ # optimizer.step()
+ # optimizer.zero_grad()
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ timer_dataloading = time.time()
+ return round(loss_total / len(train_loader), 2)
+
+# train in single domain.
+def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
+ if opt.device.type == "cuda":
+ logger.info("CUDA enabled.")
+ model.to(opt.device)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+
+ prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+
+ val_score = eval_epoch(model, config, opt)
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
+ logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ val_score = eval_epoch(model, config, opt)
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
+ logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
+
+ if prev_best_score['Fscore'] < val_score['F']:
+ prev_best_score['Fscore'] = val_score['F']
+ prev_best_score['Precision'] = val_score['P']
+ prev_best_score['Recall'] = val_score['R']
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
+ tb_writer.close()
+ return prev_best_score
+
+def start_training():
+ logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+
+ config = load_json("./main/config_qfvs.json")
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+
+ # key -> test video; value -> training videos.
+ qfvs_split = {1: [2, 3, 4],
+ 2: [1, 3, 4],
+ 3: [1, 2, 4],
+ 4: [1, 2, 3]}
+ # qfvs_split = {
+ # 2: [1, 3, 4],
+ # 3: [1, 2, 4],
+ # }
+
+ scores_videos = {}
+ for test_id, splits in qfvs_split.items():
+ logger.info(f"Start Training {opt.dset_name}: {test_id}")
+ config['train_videos'] = qfvs_split[test_id]
+ config['test_videos'] = [test_id]
+ train_dataset = DatasetQFVS(config)
+ train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)
+
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+ count_parameters(model)
+ best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
+ scores_videos['V'+str(test_id)] = best_score
+
+ # save the final results.
+ avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
+ avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
+ avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
+ scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}
+
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
+ save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)
+
+ tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
+ tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
+ tb_writer.close()
+
+ print(scores_videos)
+ return
+
+if __name__ == '__main__':
+ start_training()
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
\ No newline at end of file
diff --git a/main/config.py b/main/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..40eab1902681354b755102bbfbccd15976a6e9b6
--- /dev/null
+++ b/main/config.py
@@ -0,0 +1,378 @@
+import os
+import pdb
+import time
+import torch
+import logging
+import argparse
+import importlib
+from utils.basic_utils import mkdirp, remkdirp, \
+ load_json, save_json, make_zipfile, dict_to_markdown
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+class BaseOptions(object):
+ saved_option_filename = "opt.json"
+ ckpt_filename = "model.ckpt"
+ tensorboard_log_dir = "tensorboard_log"
+ train_log_filename = "train.log.txt"
+ eval_log_filename = "eval.log.txt"
+
+ def __init__(self):
+ self.parser = None
+ self.initialized = False
+ self.opt = None
+
+ def initialize(self):
+ self.initialized = True
+ parser = argparse.ArgumentParser()
+ # * Running configs
+ parser.add_argument("--dset_type", type=str, choices=["mr", "hl", "vs", "vlp"]) # moment retrieval, highlight detection, and video summarization
+ parser.add_argument("--dset_name", type=str, choices=["qvhighlights", "charades", "anet", "tvsum", "youtube", "summe", "ego4d", "qfvs", "video2gif", "coin", "hacs", "vlp", "videocc", "tacos"])
+ parser.add_argument("--domain_name", type=str, default=None)
+ parser.add_argument("--model_id", type=str, default="moment_detr")
+ parser.add_argument("--exp_id", type=str, default="debug", help="id of this run, required at training")
+ parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
+ parser.add_argument("--gpu_id", type=int, default=0)
+ parser.add_argument("--debug", action="store_true",
+ help="debug (fast) mode, break all loops, do not load all data into memory.")
+ parser.add_argument("--seed", type=int, default=2018, help="random seed")
+
+ # * DDP
+ parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
+
+
+ parser.add_argument("--eval_split_name", type=str, default="val",
+ help="should match keys in video_duration_idx_path, must set for VCMR")
+ parser.add_argument("--data_ratio", type=float, default=1.0,
+ help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
+ "Use small portion for debug purposes. Note this is different from --debug, "
+ "which works by breaking the loops, typically they are not used together.")
+ parser.add_argument("--results_root", type=str, default="results")
+ parser.add_argument("--num_workers", type=int, default=0,
+ help="num subprocesses used to load the data, 0: use main process")
+ parser.add_argument("--no_pin_memory", action="store_true",
+ help="Don't use pin_memory=True for dataloader. "
+ "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")
+
+ # * Training configs
+ parser.add_argument("--bsz", type=int, default=32, help="mini-batch size")
+ parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs to run")
+ parser.add_argument("--max_es_cnt", type=int, default=200,
+ help="number of epochs to early stop, use -1 to disable early stop")
+ parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
+ parser.add_argument("--lr_drop", type=int, default=400, help="drop learning rate to 1/10 every lr_drop epochs")
+ parser.add_argument("--lr_gamma", type=float, default=0.1, help="lr reduces the gamma times after the `drop' epoch")
+ parser.add_argument("--lr_warmup", type=float, default=-1, help="linear warmup scheme")
+ parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
+ parser.add_argument("--grad_clip", type=float, default=0.1, help="perform gradient clip, -1: disable")
+
+ # ** Loss coefficients
+ # *** boundary branch
+ parser.add_argument("--span_loss_type", default="l1", type=str, choices=['l1', 'ce'],
+ help="l1: (center-x, width) regression. ce: (st_idx, ed_idx) classification.")
+ parser.add_argument('--b_loss_coef', default=10, type=float) # boundary regression e.g., l1
+ parser.add_argument('--g_loss_coef', default=1, type=float) # giou loss
+ # *** foreground branch
+ parser.add_argument('--eos_coef', default=0.1, type=float, help="relative classification weight of the no-object class")
+ parser.add_argument('--f_loss_coef', default=4, type=float) # cls loss for foreground
+ # *** saliency branch
+ parser.add_argument("--s_loss_intra_coef", type=float, default=1., help="inter-video (frame-level) saliency loss e.g. momentdetr saliency loss")
+ parser.add_argument("--s_loss_inter_coef", type=float, default=0., help="intra-video (sample-level) saliency loss,")
+
+ # * Eval configs
+ parser.add_argument("--main_metric", type=str, default="MR-full-mAP")
+ parser.add_argument('--eval_mode', default=None, type=str,
+ help="how to integrate foreground and saliency for better prediction")
+ parser.add_argument("--eval_bsz", type=int, default=100,
+ help="mini-batch size at inference, for query")
+ parser.add_argument("--eval_epoch", type=int, default=5,
+ help="number of epochs for once inference")
+ parser.add_argument("--eval_init", action="store_true", help="evaluate model before training i.e. `epoch=-1'")
+ parser.add_argument("--save_interval", type=int, default=50)
+
+ parser.add_argument("--resume", type=str, default=None,
+ help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
+ parser.add_argument("--resume_dir", type=str, default=None,
+ help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
+ parser.add_argument("--resume_all", action="store_true",
+ help="if --resume_all, load optimizer/scheduler/epoch as well")
+ parser.add_argument("--start_epoch", type=int, default=None,
+ help="if None, will be set automatically when using --resume_all")
+
+ # ** NMS configs
+ parser.add_argument("--no_sort_results", action="store_true",
+ help="do not sort results, use this for moment query visualization")
+ parser.add_argument("--max_before_nms", type=int, default=10)
+ parser.add_argument("--max_after_nms", type=int, default=10)
+ parser.add_argument("--conf_thd", type=float, default=0.0, help="only keep windows with conf >= conf_thd")
+ parser.add_argument("--nms_thd", type=float, default=-1,
+ help="additionally use non-maximum suppression "
+ "(or non-minimum suppression for distance)"
+ "to post-processing the predictions. "
+ "-1: do not use nms. [0, 1]")
+
+ # * Dataset configs
+ parser.add_argument("--use_cache", type=int, default=-1, help="Preload features into cache for fast IO")
+ parser.add_argument("--max_q_l", type=int, default=75)
+ parser.add_argument("--max_v_l", type=int, default=75)
+ parser.add_argument("--clip_length", type=float, default=1.0)
+ parser.add_argument("--clip_len_list", type=int, nargs='+')
+ parser.add_argument("--max_windows", type=int, default=5)
+
+ parser.add_argument("--add_easy_negative", type=int, default=1)
+ parser.add_argument("--easy_negative_only", type=int, default=1)
+ parser.add_argument("--round_multiple", type=int, default=1)
+
+ parser.add_argument("--train_path", type=str, default=None, nargs='+')
+ parser.add_argument("--eval_path", type=str, default=None,
+ help="Evaluating during training, for Dev set. If None, will only do training, ")
+ parser.add_argument("--train_path_list", type=str, nargs='+')
+ parser.add_argument("--eval_path_list", type=str, nargs='+')
+ parser.add_argument("--feat_root_list", type=str, nargs='+')
+
+ parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalize video feat")
+ parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalize text feat")
+ parser.add_argument("--v_feat_dirs", type=str, nargs="+",
+ help="video feature dirs. If more than one, will concat their features. "
+ "Note that sub ctx features are also accepted here.")
+ parser.add_argument("--t_feat_dir", type=str, help="text/query feature dir")
+ parser.add_argument("--v_feat_dim", type=int, help="video feature dim")
+ parser.add_argument("--t_feat_dim", type=int, help="text/query feature dim")
+ parser.add_argument("--ctx_mode", type=str, default="video_tef")
+ parser.add_argument("--v_feat_types", type=str)
+ parser.add_argument("--t_feat_type", type=str)
+
+ # * Model configs
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
+ help="Type of positional embedding to use on top of the image features")
+ parser.add_argument("--n_input_proj", type=int, default=2, help="#layers to vid/txt projector")
+ parser.add_argument("--temperature", type=float, default=0.07, help="temperature nce contrastive_align_loss")
+
+ # ** Transformer
+ parser.add_argument('--enc_layers', default=4, type=int,
+ help="Number of encoding layers in the transformer")
+ parser.add_argument('--sub_enc_layers', default=2, type=int,
+ help="Number of encoding layers in the video / text transformer in albef-style.")
+ parser.add_argument('--dec_layers', default=2, type=int,
+ help="Number of decoding layers in the transformer, N/A for UniVTG")
+ parser.add_argument('--dim_feedforward', default=1024, type=int,
+ help="Intermediate size of the feedforward layers in the transformer blocks")
+ parser.add_argument('--hidden_dim', default=256, type=int,
+ help="Size of the embeddings (dimension of the transformer)")
+ parser.add_argument('--input_dropout', default=0.5, type=float,
+ help="Dropout applied in input")
+ parser.add_argument('--dropout', default=0.1, type=float,
+ help="Dropout applied in the transformer")
+ parser.add_argument('--droppath', default=0.1, type=float,
+ help="Droppath applied in the transformer")
+ parser.add_argument("--txt_drop_ratio", default=0, type=float,
+ help="drop txt_drop_ratio tokens from text input. 0.1=10%")
+ parser.add_argument("--use_txt_pos", action="store_true", help="use position_embedding for text as well.")
+ parser.add_argument('--nheads', default=8, type=int,
+ help="Number of attention heads inside the transformer's attentions")
+ parser.add_argument('--num_queries', default=10, type=int,
+ help="Number of query slots")
+ parser.add_argument('--pre_norm', action='store_true')
+
+ # ** momentdetr configs e.g. Matcher, saliency margin
+ parser.add_argument('--set_cost_span', default=10, type=float,
+ help="L1 span coefficient in the matching cost")
+ parser.add_argument('--set_cost_giou', default=1, type=float,
+ help="giou span coefficient in the matching cost")
+ parser.add_argument('--set_cost_class', default=4, type=float,
+ help="Class coefficient in the matching cost")
+ parser.add_argument("--saliency_margin", type=float, default=0.2)
+ parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_true',
+ help="Disables auxiliary decoding losses (loss at each layer)")
+
+ # * Query-Force Video Summarization
+ parser.add_argument("--max_segment_num", type=int, default=20)
+ parser.add_argument("--max_frame_num", type=int, default=200)
+ parser.add_argument("--top_percent", type=float, default=0.02)
+
+ parser.add_argument("--qfvs_vid_feature", type=str, default='fps1')
+ parser.add_argument("--qfvs_txt_feature", type=str, default='query')
+ parser.add_argument("--qfvs_split", type=int, default=-1)
+
+ parser.add_argument("--qfvs_dense_shot", type=int, default=-1)
+ parser.add_argument("--qfvs_score_ensemble", type=int, default=-1)
+ parser.add_argument("--qfvs_score_gather", type=int, default=-1)
+ parser.add_argument("--qfvs_loss_gather", type=int, default=-1)
+ self.parser = parser
+
+ def display_save(self, opt):
+ args = vars(opt)
+ # Display settings
+ print(dict_to_markdown(vars(opt), max_str_len=120))
+ # Save settings
+ if not isinstance(self, TestOptions):
+ option_file_path = os.path.join(opt.results_dir, self.saved_option_filename) # not yaml file indeed
+ save_json(args, option_file_path, save_pretty=True)
+
+ def parse(self, args=None):
+ if not self.initialized:
+ self.initialize()
+ opt = self.parser.parse_args()
+
+ if args is not None:
+ args_dict = vars(args)
+ opt_dict = vars(opt)
+ for key, value in args_dict.items():
+ opt_dict[key] = value
+ opt = argparse.Namespace(**opt_dict)
+ opt.model_dir = os.path.dirname(opt.resume)
+ torch.cuda.set_device(opt.gpu_id)
+
+ if opt.debug:
+ opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
+ opt.num_workers = 0
+
+ if isinstance(self, TestOptions):
+ # modify model_dir to absolute path
+ # opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
+ opt.model_dir = os.path.dirname(opt.resume)
+ saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
+ for arg in saved_options: # use saved options to overwrite all BaseOptions args.
+ if arg not in ["results_root", "num_workers", "nms_thd", "debug", "max_before_nms", "max_after_nms"
+ "max_pred_l", "min_pred_l", "gpu_id",
+ "resume", "resume_all", "no_sort_results",
+ "eval_path", "eval_split_name"]:
+ # "dset_name", "v_feat_dirs", "t_feat_dir"]:
+ setattr(opt, arg, saved_options[arg])
+ # opt.no_core_driver = True
+ if opt.eval_results_dir is not None:
+ opt.results_dir = opt.eval_results_dir
+ else:
+ if opt.exp_id is None:
+ raise ValueError("--exp_id is required for at a training option!")
+
+ # ctx_str = opt.ctx_mode + "_sub" if any(["sub_ctx" in p for p in opt.v_feat_dirs]) else opt.ctx_mode
+
+ if 'debug' not in opt.exp_id:
+ opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), "-".join([opt.exp_id, opt.v_feat_types, opt.t_feat_type, time.strftime("%Y_%m_%d_%H")]))
+ else:
+ opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), opt.exp_id) # debug mode.
+
+ if int(opt.local_rank) in [0, -1]:
+ # mkdirp(opt.results_dir)
+ remkdirp(opt.results_dir) # remove dir and remkdir it.
+
+ # save a copy of current code
+ code_dir = os.path.dirname(os.path.realpath(__file__))
+ code_zip_filename = os.path.join(opt.results_dir, "code.zip")
+ make_zipfile(code_dir, code_zip_filename,
+ enclosing_dir="code",
+ exclude_dirs_substring="results",
+ exclude_dirs=["results", "debug_results", "__pycache__"],
+ exclude_extensions=[".pyc", ".ipynb", ".swap"], )
+
+ if int(opt.local_rank) in [0, -1]:
+ self.display_save(opt)
+ opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
+ opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
+ opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
+ opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
+ # opt.device = torch.device("cuda" if opt.device >= 0 else "cpu")
+
+ if int(opt.local_rank) in [-1]:
+ torch.cuda.set_device(opt.gpu_id)
+ opt.pin_memory = not opt.no_pin_memory
+
+ if opt.local_rank == -1:
+ torch.cuda.set_device(opt.gpu_id)
+
+ opt.use_tef = "tef" in opt.ctx_mode
+ opt.use_video = "video" in opt.ctx_mode
+ if not opt.use_video:
+ opt.v_feat_dim = 0
+ if opt.use_tef:
+ opt.v_feat_dim += 2
+
+ self.opt = opt
+ return opt
+
+class TestOptions(BaseOptions):
+ """add additional options for evaluating"""
+
+ def initialize(self):
+ BaseOptions.initialize(self)
+ # also need to specify --eval_split_name
+ self.parser.add_argument("--eval_id", type=str, help="evaluation id")
+ self.parser.add_argument("--eval_results_dir", type=str, default=None,
+ help="dir to save results, if not set, fall back to training results_dir")
+ self.parser.add_argument("--model_dir", type=str,
+ help="dir contains the model file, will be converted to absolute path afterwards")
+
+class WarmupStepLR(torch.optim.lr_scheduler.StepLR):
+ def __init__(self, optimizer, warmup_steps, step_size, gamma=0.1, last_epoch=-1):
+ self.warmup_steps = warmup_steps
+ self.step_size = step_size
+ self.gamma = gamma
+ super(WarmupStepLR, self).__init__(optimizer, step_size, gamma=self.gamma, last_epoch=last_epoch)
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ import warnings
+ warnings.warn("To get the last learning rate computed by the scheduler, "
+ "please use `get_last_lr()`.", DeprecationWarning)
+ # e.g. warmup_steps = 10, case: 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21...
+ if self.last_epoch == self.warmup_steps or(self.last_epoch % self.step_size != 0 and self.last_epoch > self.warmup_steps):
+ return [group['lr'] for group in self.optimizer.param_groups]
+ # e.g. warmup_steps = 10, case: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
+ elif self.last_epoch < self.warmup_steps:
+ return [group['initial_lr'] * float(self.last_epoch + 1) / float(self.warmup_steps) for group in self.optimizer.param_groups]
+
+
+ # e.g. warmup_steps = 10, case: 10, 20, 30, 40...
+ return [group['lr'] * self.gamma
+ for group in self.optimizer.param_groups]
+ def _get_closed_form_lr(self):
+ if self.last_epoch <= self.warmup_steps:
+ return [base_lr * float(self.last_epoch) / (self.warmup_steps) for base_lr in self.base_lrs]
+ else:
+ return [base_lr * self.gamma ** ((self.last_epoch - self.warmup_steps)// self.step_size) for base_lr in self.base_lrs]
+
+def setup_model(opt):
+ """setup model/optimizer/scheduler and load checkpoints when needed"""
+ logger.info("setup model/optimizer/scheduler")
+
+ importer = importlib.import_module('.'.join(['model', opt.model_id]))
+ model, criterion = importer.build_model(opt)
+
+ if int(opt.device) >= 0:
+ logger.info("CUDA enabled.")
+ model.to(opt.gpu_id)
+ criterion.to(opt.gpu_id)
+
+ param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}]
+ optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd)
+
+ if opt.lr_warmup != -1 and opt.lr_drop > 0:
+ lr_scheduler = WarmupStepLR(optimizer, warmup_steps=opt.lr_warmup[0], step_size=opt.lr_drop, gamma=opt.lr_gamma)
+
+ elif opt.lr_warmup != -1:
+ from transformers import get_constant_schedule_with_warmup
+ lr_scheduler = get_constant_schedule_with_warmup(optimizer, opt.lr_warmup[0])
+
+ elif opt.lr_drop > 0:
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop, gamma=opt.lr_gamma)
+
+ if opt.resume is not None:
+ logger.info(f"Load checkpoint from {opt.resume}")
+ checkpoint = torch.load(opt.resume, map_location="cpu")
+
+ for key in list(checkpoint["model"].keys()):
+ checkpoint["model"][key.replace('module.', '')] = checkpoint["model"].pop(key)
+ model.load_state_dict(checkpoint["model"])
+
+ if opt.resume_all:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ opt.start_epoch = checkpoint['epoch'] + 1
+ logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}")
+ else:
+ logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path")
+
+ return model, criterion, optimizer, lr_scheduler
diff --git a/main/config_hl.py b/main/config_hl.py
new file mode 100644
index 0000000000000000000000000000000000000000..853049a62973aa1b5774257b21d97d2f2a9fdef2
--- /dev/null
+++ b/main/config_hl.py
@@ -0,0 +1,190 @@
+# Copyright (c) THL A29 Limited, a Tencent company. All rights reserved.
+
+YOUTUBE_SPLITS = {
+ 'dog': {
+ 'train': [
+ 'BsjTtq337mM', 'eGCD1F74iy8', 'x2Za-t1yHtI', 'iyYiqa0QZXM',
+ 'azy9ijU6f9I', 'NNtSZ6cPiwA', 'U9CBalvFfbM', 'AZDkqJaOgJU',
+ '-olTgMPAyMI', 'i35F1Ec3Ats', '6bS6-GVLBeM', 'ZGszTEn28v8',
+ 'EEb8iSMqwj4', 'p2hYGNkRMCw', '3kbptPDIz4U', 'iLHRqR-M9HQ',
+ 'zyooMDuAgCA', 'dOVsQ63N0gg', '7H_qqQvPUzY', 'Z5BEFsaYIS4',
+ 'iWO6io44-Fs', 'vVmGisWK0QI', 'L10kN7Btk90', '2yql1mvWbDs',
+ 'Iu2nbtr_Uuk', 'NSmOKAauZpM', 'PAhQGoURAro', 'uJ81Us4mBOc',
+ '1krGVyfIaOw', 'p9yW6FxsrJ4', 'DLGRJfpGmCQ', '0XTXKe2TOAg',
+ 'qpc4OSqeV7I', 'q_PJFuBOk7k', '0Uu53hCnKQ4', '-szRD9kyNug',
+ 'rUPxwWmJYpg', 'hseONiKKx_8', 'BLaQcOcDfjo', 'nW5JulWYEc8',
+ 'rMvH1SMGwwI', 'l6KlvTJkTgk', 'O8j4U3NjNvs', '8AJTZeEeStk'
+ ],
+ 'val': [
+ 'a2nj7XCo2Rk', '9rP5yF9EC3Y', 'OxSsRZqPfyk', 'bZzP2MieC1c',
+ 'PcvdX5OVgfQ', 'p0oxRJD1GUk', 'msjK8nHZHZ0', 'hSRyclcZyGM',
+ 'dlH2K9N_jSM', 'OCVXhRG2fEA', 'MkBdHvXPocc', 'yN7h90Y-04g',
+ 'PWqLJKZeBC8', '9D_Q8l_ruQk', 'Mp8Pz86J660', '1gjntnYm8NA',
+ 'O3XxuutEvoo', 'wf_qlAizlSM', 'fXx44D1sqUw', 'P0MnXh6bnKk',
+ 'sTd06idFa0E', 'ppNjl3I3iJs', 'Om5mczkpcVg', 'xZIN_s-qhbU'
+ ]
+ },
+ 'gymnastics': {
+ 'train': [
+ 'Wfv90YJ2YtA', 'MbD5OIR9yWc', 'fZwCJWkC_Qw', 'AyRI1CioQfY',
+ 'xV_5YCdVqSM', '19UO7T32DJI', 'o2gAP2Clg_s', 'ewyfAOrBzjQ',
+ 'CMTKpA683Ig', 'aNjphhjTgqs', 'dmJ0Nq4DF2w', '57IQ6EudvGU',
+ 'BAlUYtPUsVI', '_UU4XqYVDqE', 'Kq4OhBiQk_E', 'D6nyvx9kEac',
+ 'g-m4-zeCisU', '_45vTFtcduE', '9L-Pocc_u70', '0636XaURL-A',
+ 'GCabQyaHSMg', 'vUi1Scb35fQ', 'eK-Yuoou_1I', 'kkS7TgNZwJI',
+ '2EFkINKg3nA', 'eKvALYDh7RU', 'Hyp3Hpk6dyA', '9rpzf3sgQkw',
+ 'kHNAnpewyeo', 'ydQij10qrZM', '41u2V_ZAKto', '6NSWsMKAgEU',
+ 'kUs_yUR-C2k', 'bs3ZBcfhvKA'
+ ],
+ 'val': [
+ '2AuigNFEsTM', 'rPsKpHKzUso', 'tzq5cJQ9NQA', 'DyZ0gZ5xmxI',
+ 'PEKRfJYYEgU', 'affAIVH9uRA', 'FT7yIi3-tG0', 'T_zWyrVzyvw',
+ 'RoiLzMA_ilA', 'nBZiGSccsTg', 'z3cNtOMKK7A', 'EwQ-aMK2sKg',
+ 'Rq0BpciuvBM', 's6LNwTThBgs', '-hE9v3izo4c', 'KldEfRhv7H0',
+ 'eUyuw2J5FaE', 'E0aRE1_ea8E', 'BU7YlQAOBkM', 'iDJM9j11U-c',
+ 'zr5LSPMBpiI', 'NAfBa7lqg2Q', 'eB4Toq9dUWs', 'YPd7RDN5CkE',
+ '86YLsw7efDM', 'iQRMMFiYAUw', 'lzEhLAPxZyQ', 'PAjJbT1DRnY'
+ ]
+ },
+ 'parkour': {
+ 'train': [
+ 'qz1UnnxlWhI', 'MzODICzycHs', '0swXWs9yWA4', 'Nnv22OW_PaI',
+ 'LUhZJLY2uKc', 'yZz8z1l3XJU', '3dvjtdMC2ls', 'e27ppPer9XY',
+ 'HJNn2WlKFhM', 'j4OxlxnapNI', 'rhABvn7VjSQ', '3PCwXpwYqLs',
+ 'LECL1bIpi5w', 'w0ouP79iZWc', 'z6aKQPMJUC0', 'kATlFTwxBVY',
+ '3SM6a8eyuVA', 'v-Sfc4COqRQ', '64eu8pwuIUE', '7WKm0XDk3og',
+ '2F5Sc0Jgk4g'
+ ],
+ 'val': [
+ 'TFdbCRkVeIA', 'uGLs9atTvNc', 'qlGPuopK3CI', 'ucTkpjZO_o4',
+ '4-4BgyGphLQ', '08k4ysX_XJE', '6sMNnWqa_as', 'oT6g0I2Ok9o',
+ 'Be4IlnKeBOo', 'yUjJq0kvxcw', 'fLek7GRIxjE'
+ ]
+ },
+ 'skating': {
+ 'train': [
+ '7owXLUkpoNY', '1OLM0_Jzt5M', 'b1LXb0Sbiy0', '3fGux6-ttlA',
+ 'HQvRun80GyA', 'a8M-5nTrll8', 'bA3CxZllhsI', 'AUAsfZtcB4E',
+ 'FG57uCJvQLw', 'jXIuv5uFPTI', 'eG-hdYLoS98', '2SdJBl251PU',
+ '2PHJqqrGC80', 'EtZkkFhniRw', 'jUiwyguxzIw', 'FL6mXlaF78Q',
+ 'BdemklZtYWI', 'ATk_ncI1-BA', '4wiKDfq3X8U', 'BN7GBjVlFTo',
+ 'JiMZvMkkbRo', '2DIXYkSnRf4', 'dZ3i-HuhQXM', '7jZydh62m8M'
+ ],
+ 'val': [
+ '2oOe2_Ew6Ao', 'DGcO0QgcXtw', 'ixsKaNplm6o', '7TQbqKWjLcI',
+ 'CQZNrEstSag', 'g1WbAIzkw80', '4cyx1VpDjc4', 'BGZaaqFjoRY',
+ 'AJ98A2y1dVw', '1n7Afe5AZCM', '8x8ESK5MnR0'
+ ]
+ },
+ 'skiing': {
+ 'train': [
+ '6Usy87KaF-A', 'DtjKkp_4KDQ', '4Wt7TM2wDxI', 'iKnzSGFwdbc',
+ 'nALCc6HPQNs', 'WL4TA--CVcA', 'dFrfsgW1M98', 'x6qmrVojcYc',
+ 'pvcmQ9J_BYw', 'S3VEYFAP_pk', 'pU57a3jYMEk', '33TrLdo3ook',
+ 'xLhHU8uo2aY', 'fAHBmka6Psc', '9HYzZk5kiJA', 'T0gjqYbeU1g',
+ '7o628W-bFy0', 'YKDm_PCa-HM', 'R3DV2zDnNqg', 'NCe9YeXTvHo',
+ '5tXxvscmZ-Y', 'thNiPQLbi5w', '1TtJy8cSzqA', 'zDRzOsmwa08',
+ 'gCI4gArPjNA', 'uw0i26NHucs', '1giAsZC_ywQ', 'OvgaPTfEnqo',
+ 'bFD_p5znoq4', 'uKmqaAvjKgw', '5ivw_sdCTCU', 'iwCSAYGwPq4',
+ 'HmmOPntPlRA', 'FHCEyiM-NoY', 'EUSFMmoE_jI', 'igvSxtdsT8w',
+ 'zEgMYFiEaX4', '0K2FKccDp9A', 'tdyz6h4ZtYs', 'PO7GEbi2z3c',
+ 'mmiu7rRmSAU', 'qL6Kic-CdTo', '0fNCsOY1WGk', 'V3J26hr1ZSE',
+ 'GS-qBunN3B4', 'ZLNvg8025Nw', 'puAxGH6aWMY', 'h-SlvHubhs8',
+ 'AdovZ4OAS8I', 'UDvA1XMa1m4', 'qdo3d7mR_9s', 'qAinbyORWIw',
+ 'v1JpJueAElY', 'TjH29fdjcqI', 'f76B1uucoyo', 'DNPPDcOd5eQ',
+ '-GX95udKKm8', 'YRO_RQ3aBgg', '1ptV2E7lm9U', 'qa7dtf1Qcew',
+ '_UJTkqYNrpA', 'md14DNKq2_o', 'tpewrb9dDyo', 'yGoWYi_dHLY',
+ 'DZ3NRjDHwy8', 'aMFcEuJUqpk', '6fT9KLuE7no', 'lPdQMMAuOZo'
+ ],
+ 'val': [
+ 'SSlv7qJK5zA', '_BYqZjuKpKA', 'ZueaKXReGjU', 'mGST8ZekCZc',
+ 'JJSu7Lh9rvs', 'IyoD3G5igY0', 'MXyv-Ut9HRg', 'Z8X9WIojH1U',
+ 'vT33-8KUb2Q', 'HW6_sPym938', '9wtXO2lF6hM', 'mRdthCqe6Nk',
+ 'RGxiOb9hlS0', 'ruySf5zL7Kw', 'I7wFmP6P7p0', '0AHkDElk3ws',
+ 'zqXd4EgUFhE', '91lDbBHUx0w', 'iaHbK6ogafc', 'jRbst8kjWW8',
+ 'drHPy6wSZGs', '5VaY6LgIqDs', 'bXq9rRSbI3c', 'hjZLa2DTuqs',
+ 'Ka2qcp3jmWo', 'ZnA4-ggkFu8', 'iXdt4v42mbs', '8aWN-0NZErI',
+ '09v0HNf81J0', 'YJCR2q-WRhQ', 'RjagI4pAUpw', '_10CbYdTG5M',
+ 'lhgmIgzBQxs', '2pstGBM4p0w', 'b53-VPsWom4', 'x-G4r153n6o',
+ 'qBbqK5qlVSM', 'XamrS9XyHuQ', 'u_n7jMS1vlw', 'AO6p0jlOd6U',
+ 'm-W-lcTkBQ0', 'bMuyPVIlXW8', 'kAAvTAKkIy4', 'U6vnbCurZQA',
+ 'dHE8q7sZ70U', 'w7fzLVRPSUc', 'FLYkD7zHuHQ', 'nhOhI24P7dM',
+ 'n5q2KhfoiWw', '7Hcyse0h9HE', '6_BPy_VaPSY'
+ ]
+ },
+ 'surfing': {
+ 'train': [
+ 'Ai9FwQGn5ds', 'hBl0Sm3_auw', 'LMxMeg407Vg', 'D3fk8doVui4',
+ 'Y9pxmLg6ti8', 'p_JsivYdbgQ', 'UokX-hcXQeo', 'VYe5QfM5ecE',
+ 'I48VJ92ouTQ', 'Tn-ebtUnq6E', 'eWae-nWocPU', '-Yamat_0tbw',
+ 'c2Fy-rdXJy4', 'xQ4NAp4vWbI', 'g9kXCIjIjoE', 'A96Jx6gv6_4',
+ 'e427qElqqN0', 'tTcA5hiViPo', 'wMdXzj_3aA0', 'fqNzMz1n6uA',
+ 'jKVOA7RFCUo', 'TJBJrk9iPPA', '_C8EjMxrS2s', 'yj7abHfZTQQ',
+ 'NDcqgpsyWaU', 'UJjwoivaGNo', 'GZ_XS8EnnWo', 'kJUBIcBjUZ0',
+ 'lWoLyR7lDAU', 'FilbyF_PGjI', 'fapRkcOe4vE', 't05r50PQqww',
+ 'QgStLppe610', '2TY8Q2WXUyk', '9y_ED3DyNhE', 'CGwtinVGkVU',
+ 'nOuRhrAMaIw', 'UN4TwjDajtQ', '-FHmVZWWgcE', 'ksx0_BfpsLg',
+ 'agOBPDsQrTM', 'XqggBwFOmFU', 'orNzj1J8i-4', '6ZbTCHwt1gk',
+ '0un3wh_pQAc', '4u6OURBLZDs', 'us0agAKuvEM', 'mVQYl7Q-TQs',
+ 'cB2SdlGHLMQ', 'WK5t4To0zlA', 'NNEuH_juUHI', 'KTU7xfVOat0',
+ 'Y1nhbNaY1ZY', 'YlXJnZe575s', 'SH7Ns0ANzJU', '3TbZfeokCkE'
+ ],
+ 'val': [
+ 'o0on6yIXJQE', '4RsZz_8d8Ro', 'p8VUjcZyK70', '0P2PZXUa0Bg',
+ 'p2eU5z647Mw', 'mSVxaAJcNJQ', 'bcmXVyFbsRg', 'Eiq8GHi4kEo',
+ 'H5FEdJYokO4', 'Mkyp0z_Cgig', 'NB5Ez5kJfMU', 'Xa0y6b6Vm6U',
+ 'gVcCGUtpA90', '0-fstXuo_Pw', '-d72e4v9skA', 'lbp6_wCXqvw',
+ '9GpZHq1n8ps', 'CefGXyYu_zU', 'SI2JbS48Upg', 'hdklRTNrq0I',
+ 'J-P-t6g19SM', 'K0f_DpVOjfA', 'lw_1fEY9QTo', 'uUuYnKLETLw',
+ 'HwKv3Xc5MAE', 'wvQ0h5Nwsxc', 'l8ME6z_EWKE', 's9dTu2fcbNg',
+ 'GS09SevPYT4', 'YbwdDCzVczU', 'jaCOI_VwIjc', '3Y1Jp1_fFLQ',
+ '82OzgxT2tH8', 'IjQhHPlTfdE', 'KzQcJrT91jU', 't05AD0c08zE',
+ 'rGxWxX6nYO4', 'QGp0kRzKiAc', 'pK9gDWoOyko', 'Srjd4pe6vck',
+ 'twGcxuhCXoU', 'AshLUHPEb8M', '8En3M5CUc2E', '8sTJfTUk1d0',
+ 'o-bubyWTw60', 'NctbssxGCtU', 'L09Qo1ql0nM'
+ ]
+ }
+}
+
+TVSUM_SPLITS = {
+ 'BK': {
+ 'train': ['WxtbjNsCQ8A', 'EE-bNr36nyA', 'oDXZc0tZe04', 'uGu_10sucQo'],
+ 'val': ['Se3oxnaPsz0']
+ },
+ 'BT': {
+ 'train': ['eQu1rNs0an0', 'qqR6AEXwxoQ', 'EYqVtI9YWJA', 'iVt07TCkFM0'],
+ 'val': ['JgHubY5Vw3Y']
+ },
+ 'DS': {
+ 'train': ['kLxoNp-UchI', 'NyBmCxDoHJU', 'jcoYJXDG9sw', '-esJrBWj2d8'],
+ 'val': ['E11zDS9XGzg']
+ },
+ 'FM': {
+ 'train': ['_xMr-HKMfVA', 'byxOvuiIJV0', 'VuWGsYPqAX8', 'xmEERLqJ2kU'],
+ 'val': ['JKpqYvAdIsw']
+ },
+ 'GA': {
+ 'train': ['xxdtq8mxegs', 'i3wAGJaaktw', '0tmA_C6XwfM', '3eYKfiOEJNs'],
+ 'val': ['Bhxk-O1Y7Ho']
+ },
+ 'MS': {
+ 'train': ['Hl-__g2gn_A', 'WG0MBPpPC6I', 'LRw_obCPUt0', '37rzWOQsNIw'],
+ 'val': ['Yi4Ij2NM7U4']
+ },
+ 'PK': {
+ 'train': ['GsAD1KT1xo8', 'XkqCExn6_Us', 'b626MiF1ew4', 'PJrm840pAUI'],
+ 'val': ['cjibtmSLxQ4']
+ },
+ 'PR': {
+ 'train': ['RBCABdttQmI', 'z_6gVvQb2d0', '4wU_LUjG5Ic', '91IHQYk1IQM'],
+ 'val': ['fWutDQy1nnY']
+ },
+ 'VT': {
+ 'train': ['gzDbaEs1Rlg', 'XzYM3PfTM4w', '98MoyGZKHXc', 'AwmHb44_ouw'],
+ 'val': ['J0nA4VgnoCo']
+ },
+ 'VU': {
+ 'train': ['akI8YFjEmUw', 'HT5vyqe0Xaw', 'vdmoEJ5YbrQ', 'xwqBXPGE9pQ'],
+ 'val': ['sTEELN-vY30']
+ }
+}
\ No newline at end of file
diff --git a/main/config_qfvs.json b/main/config_qfvs.json
new file mode 100644
index 0000000000000000000000000000000000000000..1948047231c7911893e31576f5b1dec34e68f148
--- /dev/null
+++ b/main/config_qfvs.json
@@ -0,0 +1,14 @@
+{
+ // "max_segment_num": 20,
+ // "max_frame_num": 200,
+
+ // "train_videos": null,
+ // "test_videos": null,
+ // "top_percent": 0.02,
+
+ // "vid_feature": "fps1",
+ // "txt_feature": "query",
+ // "txt_max_len": 5,
+
+ // "factor": null
+}
\ No newline at end of file
diff --git a/main/dataset.py b/main/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e534a0c17c64a53cfd29cf1a9188299c0088e42c
--- /dev/null
+++ b/main/dataset.py
@@ -0,0 +1,1261 @@
+import os
+import pdb
+import h5py
+import nncore
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+from tqdm import tqdm
+import random
+import logging
+from os.path import join, exists
+from nncore.dataset import DATASETS
+from nncore.parallel import DataContainer
+from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
+from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
+from utils.tensor_utils import pad_sequences_1d
+from utils.span_utils import span_xx_to_cxw
+from random import shuffle
+
+logger = logging.getLogger(__name__)
+
+class DatasetVLP(Dataset):
+ Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
+ """One line in data loaded from data_path."
+ {
+ "qid": 7803,
+ "query": "Man in gray top walks from outside to inside.",
+ "duration": 150,
+ "vid": "RoripwjYFp8_360.0_510.0",
+ "relevant_clip_ids": [13, 14, 15, 16, 17],
+ "relevant_windows": [[26, 36]]
+ }
+ """
+ def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
+ q_feat_type="last_hidden_state",
+ max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
+ normalize_v=True, normalize_t=True, load_labels=True,
+ clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
+ use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
+ self.dset_name = dset_name
+ self.data_path = data_path
+ self.data_ratio = data_ratio
+ self.v_feat_dirs = v_feat_dirs \
+ if isinstance(v_feat_dirs, list) else [v_feat_dirs]
+ self.q_feat_dir = q_feat_dir
+ self.q_feat_type = q_feat_type
+ self.v_feat_dim = v_feat_dim
+ self.q_feat_dim = q_feat_dim
+ self.max_q_l = max_q_l
+ self.max_v_l = max_v_l
+ self.ctx_mode = ctx_mode
+ self.use_tef = "tef" in ctx_mode
+ self.use_video = "video" in ctx_mode
+ self.normalize_t = normalize_t
+ self.normalize_v = normalize_v
+ self.load_labels = load_labels
+ self.clip_len = clip_len
+ self.fix_len = fix_len
+ self.max_windows = max_windows # maximum number of windows to use as labels
+ self.span_loss_type = span_loss_type
+ self.txt_drop_ratio = txt_drop_ratio
+ self.use_cache = use_cache
+ self.add_easy_negative = add_easy_negative
+ self.easy_negative_only = easy_negative_only
+
+ self.vlp_mapping = {
+ # 'data/qvhighlights/metadata/qvhighlights_asr.jsonl': {
+ # 'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '_asr', 'type': 'interval',
+ # },
+ # 'data/ego4d/metadata/point_train_1m.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ # },
+ # 'data/ego4d/metadata/point_train_1m_0.1p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ # },
+ # 'data/ego4d/metadata/point_train_1m_0.2p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ # },
+ # 'data/ego4d/metadata/point_train_1m_0.5p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ # },
+ # 'data/ego4d/metadata/point_train_1m_0.75p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ # },
+ # 'data/ego4d/metadata/point_train_2m.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ # },
+ # 'data/ego4d/metadata/point_train_1m_egoclip.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ # },
+ # 'data/hacs/metadata/hacs_train_cs.jsonl': {
+ # 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '_cs', 'type': 'curve',
+ # },
+ # 'data/hacs/metadata/hacs_train.jsonl': {
+ # 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
+ # },
+ # 'data/videocc/metadata/train_300k.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ # },
+ # 'data/videocc/metadata/train_600k.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ # },
+ # 'data/videocc/metadata/train_600k_0.1p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ # },
+ # 'data/videocc/metadata/train_600k_0.2p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ # },
+ # 'data/videocc/metadata/train_600k_0.5p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ # },
+ # 'data/videocc/metadata/train_600k_0.75p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ # },
+ # 'data/videocc/metadata/train_900k.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ # },
+ # 'data/ego4d/metadata/concept_train_top10_window.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/ego4d/metadata/concept_train_top5_window.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/ego4d/metadata/concept_train_top5_window_0.1p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/ego4d/metadata/concept_train_top5_window_0.2p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/ego4d/metadata/concept_train_top5_window_0.5p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/ego4d/metadata/concept_train_top5_window_0.75p.jsonl': {
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/videocc/metadata/concept_train_top10_window.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/videocc/metadata/concept_train_top5_window.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/videocc/metadata/concept_train_top5_window_0.1p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/videocc/metadata/concept_train_top5_window_0.2p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/videocc/metadata/concept_train_top5_window_0.5p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ # 'data/videocc/metadata/concept_train_top5_window_0.75p.jsonl': {
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
+ # },
+ #
+ # pre-training
+ 'data/ego4d/metadata/point_egoclip_wo_val.jsonl': {
+ 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
+ },
+ 'data/videocc/metadata/interval_900k.jsonl': {
+ 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ },
+ 'data/videocc/metadata/curve_5_window.jsonl': {
+ 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
+ },
+ # downstream
+ 'data/qvhighlights/metadata/qvhighlights_train.jsonl': {
+ 'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
+ },
+ 'data/charades/metadata/charades_train.jsonl': {
+ 'dset_name': 'charades', 'v_feat_suffix': '_2', 'q_feat_suffix': '', 'type': 'interval',
+ },
+ 'data/ego4d/metadata/nlq_train.jsonl': {
+ 'dset_name': 'ego4d', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ },
+ 'data/tacos/metadata/train.jsonl': {
+ 'dset_name': 'tacos', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ },
+ 'data/anet/metadata/train.jsonl': {
+ 'dset_name': 'anet', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ },
+ 'data/didemo/metadata/train.jsonl': {
+ 'dset_name': 'didemo', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
+ },
+ }
+
+ if "val" in data_path or "test" in data_path:
+ assert txt_drop_ratio == 0
+
+ # checks
+ assert q_feat_type in self.Q_FEAT_TYPES
+
+ # data
+ self.data = self.load_data()
+
+ self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
+ t_feat_type = q_feat_dir.split('/')[-1]
+
+ if self.use_cache > 0:
+ print('Loading the off-line features...')
+ dset_dir = os.path.join('data', self.dset_name)
+ vid_keys = [meta['vid'] for meta in self.data]
+ qid_keys = [meta['qid'] for meta in self.data]
+
+ self.vid_cache = {}
+ for v_feat_type in self.v_feat_types:
+ assert 'vid' in v_feat_type
+ with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
+ self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}
+
+ assert 'txt' in t_feat_type
+ self.txt_cache = {}
+ with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
+ for key in tqdm(qid_keys):
+ try:
+ self.txt_cache[key] = f[str(key)][:]
+ except:
+ logger.info(f"text {key} is not in the cache.")
+
+ def load_data(self):
+ # datalist = load_jsonl(self.data_path[0])
+ datalist = []
+ for dset_path in self.data_path:
+ dset_info = self.vlp_mapping[dset_path]
+ dset_list = load_jsonl(dset_path)
+ for x in dset_list: x.update(dset_info)
+ datalist += dset_list
+ n_examples = int(len(datalist))
+ if self.data_ratio != 1:
+ n_examples = int(len(datalist) * self.data_ratio)
+ shuffle(datalist)
+ datalist = datalist[:n_examples]
+ logger.info("Using {}% of the data: {} examples"
+ .format(self.data_ratio * 100, n_examples))
+ return datalist
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ meta = self.data[index]
+
+ model_inputs = dict()
+ model_inputs["query_feat"] = self._get_query_feat_by_qid(meta) # (Dq, ) or (Lq, Dq)
+
+ if self.use_video:
+ model_inputs["video_feat"] = self._get_video_feat_by_vid(meta) # (Lv, Dv)
+ ctx_l = len(model_inputs["video_feat"])
+ else:
+ ctx_l = self.max_v_l
+
+ if meta['dset_name'] in ['hacs', 'ego4d', 'activitynet']:
+ for i, window_i in enumerate(meta["relevant_windows"]):
+ if window_i[1] - window_i[0] < self.clip_len:
+ center = (window_i[1] + window_i[0]) / 2
+ window_i[0] = max(0, center - 0.5 * self.clip_len)
+ window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
+ window_i[1] = max(self.clip_len, window_i[1])
+
+ model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
+
+ if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
+ meta["relevant_windows"] = [[0, 150]]
+ relevant_windows = torch.Tensor(meta["relevant_windows"])
+
+ # assign the nearest window for each timestamp i.e., qvhighlights.
+ num_vid_seq = model_inputs["timestamp"].shape[0]
+ num_windows = relevant_windows.shape[0]
+
+ relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
+ relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
+ model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)
+
+ if meta['qid'] is not None:
+ nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
+ diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
+ diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
+ assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
+ if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet.
+ nn_window_ts = relevant_windows_ts.squeeze(1)
+ else:
+ nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]
+
+ model_inputs["span_labels_nn"] = nn_window_ts
+ model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1])
+
+ # for activitynet.
+ if model_inputs["timestamp_window"].sum() < 1:
+ idx = int(meta['relevant_windows'][0][0] / self.clip_len)
+ idx = max(0, min(idx, ctx_l-1))
+ model_inputs["timestamp_window"][idx] = 1
+
+ if self.use_tef:
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
+ if self.use_video:
+ model_inputs["video_feat"] = torch.cat(
+ [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
+ else:
+ model_inputs["video_feat"] = tef
+
+ if self.load_labels:
+ model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
+ if 'saliency_scores' in meta.keys():
+ # this is for highlight-only task
+ model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
+ limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
+ model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
+ self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
+ # pdb.set_trace()
+ else:
+ model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
+ self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt
+ model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ]
+
+ if 'type' in meta.keys():
+ if meta['type'] == 'point':
+ model_inputs['weight_ablation'] = torch.tensor([0, 0, 1, 0, 0])
+ if meta['type'] == 'interval':
+ model_inputs['weight_ablation'] = torch.tensor([1, 1, 0, 0, 0])
+ if meta['type'] == 'curve':
+ model_inputs['weight_ablation'] = torch.tensor([0, 0, 0, 1, 1])
+
+ return dict(meta=meta, model_inputs=model_inputs)
+
+ def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
+ gt_st = int(gt_window[0] / self.clip_len)
+ gt_st = min(gt_st, ctx_l-1)
+ gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
+ if gt_st > gt_ed:
+ # gt_st = gt_ed
+ gt_ed = gt_st
+
+ if gt_st != gt_ed:
+ pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n)
+ else:
+ pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st]
+
+ neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l))
+ # neg_clip_indices = random.sample(neg_pool, k=max_n)
+
+ try:
+ neg_clip_indices = random.sample(neg_pool, k=max_n)
+ except:
+ neg_clip_indices = pos_clip_indices
+
+ return pos_clip_indices, neg_clip_indices
+
+ def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
+ """Sum the scores from the three annotations, then take the two clips with the
+ maximum scores as positive, and two with the minimum scores as negative.
+ Args:
+ rel_clip_ids: list(int), list of relevant clip ids
+ scores: list([anno1_score, anno2_score, anno3_score]),
+ ctx_l: int
+ max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
+ add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
+ """
+ # indices inside rel_clip_ids
+ scores = np.array(scores) # (#rel_clips, 3)
+ agg_scores = np.sum(scores, 1) # (#rel_clips, )
+ sort_indices = np.argsort(agg_scores) # increasing
+
+ # indices in the whole video
+ # the min(_, ctx_l-1) here is incorrect, but should not cause
+ # much troubles since this should be rarely used.
+ hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
+ hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
+
+ if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
+ hard_neg_clip_indices = hard_pos_clip_indices
+
+ easy_pos_clip_indices = []
+ easy_neg_clip_indices = []
+ # pdb.set_trace()
+ if self.add_easy_negative > 0:
+ easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
+ if len(easy_neg_pool) >= max_n:
+ easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
+ easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
+ else: # copy the hard ones
+ easy_pos_clip_indices = hard_pos_clip_indices
+ easy_neg_clip_indices = hard_neg_clip_indices
+
+ if self.easy_negative_only > 0:
+ return easy_pos_clip_indices, easy_neg_clip_indices
+
+ pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
+ neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
+
+ return pos_clip_indices, neg_clip_indices
+
+ def get_span_labels(self, windows, ctx_l):
+ """
+ windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
+ Note a maximum of `self.max_windows` windows are used.
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
+ """
+ if len(windows) > self.max_windows:
+ random.shuffle(windows)
+ windows = windows[:self.max_windows]
+ if self.span_loss_type == "l1":
+ windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
+ windows = span_xx_to_cxw(windows) # normalized windows in cxw
+ elif self.span_loss_type == "ce":
+ windows = torch.Tensor([
+ [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
+ for w in windows]).long() # inclusive
+ else:
+ raise NotImplementedError
+ return windows
+
+ def _get_query_feat_by_qid(self, meta):
+ qid = meta['qid']
+ dset_name = meta['dset_name']
+ q_feat_suffix = meta['q_feat_suffix']
+ q_feat_dir = self.q_feat_dir + q_feat_suffix
+
+ if self.use_cache > 0:
+ try:
+ q_feat = self.txt_cache[qid]
+ except:
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
+ return torch.from_numpy(q_feat)
+
+ q_feat_path = os.path.join('data', dset_name, q_feat_dir, f"{qid}.npz")
+ try:
+ q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
+ except:
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
+ logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
+
+ if self.q_feat_type == "last_hidden_state":
+ # q_feat = q_feat[:self.max_q_l]
+ q_feat = q_feat
+ if self.normalize_t:
+ q_feat = l2_normalize_np_array(q_feat)
+ if self.txt_drop_ratio > 0:
+ q_feat = self.random_drop_rows(q_feat)
+ return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
+
+ def random_drop_rows(self, embeddings):
+ """randomly mask num_drop rows in embeddings to be zero.
+ Args:
+ embeddings: np.ndarray (L, D)
+ """
+ num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
+ if num_drop_rows > 0:
+ row_indices = np.random.choice(
+ len(embeddings), size=num_drop_rows, replace=False)
+ embeddings[row_indices] = 0
+ return embeddings
+
+ def _get_video_feat_by_vid(self, meta):
+ dset_name = meta['dset_name']
+ v_feat_suffix = meta['v_feat_suffix']
+ vid = meta['vid']
+
+ v_feat_list = []
+ for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
+ v_feat_dir = _feat_dir + v_feat_suffix
+ if self.use_cache > 0:
+ _feat = self.vid_cache[feat_type][vid]
+ else:
+ _feat_path = os.path.join('data', dset_name, v_feat_dir, f"{vid}.npz")
+ _feat = np.load(_feat_path)["features"].astype(np.float32)
+ if self.normalize_v:
+ _feat = l2_normalize_np_array(_feat)
+ v_feat_list.append(_feat)
+ # some features are slightly longer than the others
+ min_len = min([len(e) for e in v_feat_list])
+ v_feat_list = [e[:min_len] for e in v_feat_list]
+ v_feat = np.concatenate(v_feat_list, axis=1)
+ return torch.from_numpy(v_feat) # (Lv, D)
+
+class DatasetMR(Dataset):
+ Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
+ """One line in data loaded from data_path."
+ {
+ "qid": 7803,
+ "query": "Man in gray top walks from outside to inside.",
+ "duration": 150,
+ "vid": "RoripwjYFp8_360.0_510.0",
+ "relevant_clip_ids": [13, 14, 15, 16, 17],
+ "relevant_windows": [[26, 36]]
+ }
+ """
+ def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
+ q_feat_type="last_hidden_state",
+ max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
+ normalize_v=True, normalize_t=True, load_labels=True,
+ clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
+ use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
+ self.dset_name = dset_name
+ self.data_path = data_path[0] if isinstance(data_path, list) else data_path
+ self.data_ratio = data_ratio
+ self.v_feat_dirs = v_feat_dirs \
+ if isinstance(v_feat_dirs, list) else [v_feat_dirs]
+ self.q_feat_dir = q_feat_dir
+ self.q_feat_type = q_feat_type
+ self.v_feat_dim = v_feat_dim
+ self.q_feat_dim = q_feat_dim
+ self.max_q_l = max_q_l
+ self.max_v_l = max_v_l
+ self.ctx_mode = ctx_mode
+ self.use_tef = "tef" in ctx_mode
+ self.use_video = "video" in ctx_mode
+ self.normalize_t = normalize_t
+ self.normalize_v = normalize_v
+ self.load_labels = load_labels
+ self.clip_len = clip_len
+ self.fix_len = fix_len
+ self.max_windows = max_windows # maximum number of windows to use as labels
+ self.span_loss_type = span_loss_type
+ self.txt_drop_ratio = txt_drop_ratio
+ self.use_cache = use_cache
+ self.add_easy_negative = add_easy_negative
+ self.easy_negative_only = easy_negative_only
+
+ if "val" in data_path or "test" in data_path:
+ assert txt_drop_ratio == 0
+
+ # checks
+ assert q_feat_type in self.Q_FEAT_TYPES
+
+ # data
+ self.data = self.load_data()
+
+ self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
+ t_feat_type = q_feat_dir.split('/')[-1]
+
+ if self.use_cache > 0:
+ print('Loading the off-line features...')
+ dset_dir = os.path.join('data', self.dset_name)
+ vid_keys = [meta['vid'] for meta in self.data]
+ qid_keys = [meta['qid'] for meta in self.data]
+
+ self.vid_cache = {}
+ for v_feat_type in self.v_feat_types:
+ assert 'vid' in v_feat_type
+ with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
+ self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}
+
+ assert 'txt' in t_feat_type
+ self.txt_cache = {}
+ with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
+ for key in tqdm(qid_keys):
+ try:
+ self.txt_cache[key] = f[str(key)][:]
+ except:
+ logger.info(f"text {key} is not in the cache.")
+
+ def load_data(self):
+ datalist = load_jsonl(self.data_path)
+ if self.data_ratio != 1:
+ n_examples = int(len(datalist) * self.data_ratio)
+ datalist = datalist[:n_examples]
+ logger.info("Using {}% of the data: {} examples"
+ .format(self.data_ratio * 100, n_examples))
+ return datalist
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ meta = self.data[index]
+
+ model_inputs = dict()
+ model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"]) # (Dq, ) or (Lq, Dq)
+
+ if self.use_video:
+ model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"]) # (Lv, Dv)
+ ctx_l = len(model_inputs["video_feat"])
+ else:
+ ctx_l = self.max_v_l
+
+ if self.dset_name in ['hacs', 'ego4d', 'videocc', 'activitynet']:
+ for i, window_i in enumerate(meta["relevant_windows"]):
+ if window_i[1] - window_i[0] < self.clip_len:
+ center = (window_i[1] + window_i[0]) / 2
+ window_i[0] = max(0, center - 0.5 * self.clip_len)
+ window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
+ window_i[1] = max(self.clip_len, window_i[1])
+
+ model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
+
+ if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
+ meta["relevant_windows"] = [[0, 150]]
+ relevant_windows = torch.Tensor(meta["relevant_windows"])
+
+ # assign the nearest window for each timestamp i.e., qvhighlights.
+ num_vid_seq = model_inputs["timestamp"].shape[0]
+ num_windows = relevant_windows.shape[0]
+
+ relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
+ relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
+ model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)
+
+ if meta['qid'] is not None:
+ nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
+ diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
+ diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
+ assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
+ if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet.
+ nn_window_ts = relevant_windows_ts.squeeze(1)
+ else:
+ nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]
+
+ model_inputs["span_labels_nn"] = nn_window_ts
+ model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1])
+
+ # for activitynet.
+ if model_inputs["timestamp_window"].sum() < 1:
+ idx = int(meta['relevant_windows'][0][0] / self.clip_len)
+ idx = max(0, min(idx, ctx_l-1))
+ model_inputs["timestamp_window"][idx] = 1
+
+ if self.use_tef:
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
+ if self.use_video:
+ model_inputs["video_feat"] = torch.cat(
+ [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
+ else:
+ model_inputs["video_feat"] = tef
+
+ if self.load_labels:
+ model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
+ if 'saliency_scores' in meta.keys():
+ model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
+ limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
+ model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
+ self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
+ else:
+ model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
+ self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt
+ model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ]
+
+ return dict(meta=meta, model_inputs=model_inputs)
+
+ def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
+ gt_st = int(gt_window[0] / self.clip_len)
+ gt_st = min(gt_st, ctx_l-1)
+ gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
+ if gt_st > gt_ed:
+ gt_ed = gt_st
+
+ if gt_st != gt_ed:
+ pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n)
+ else:
+ pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st]
+
+ neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l))
+
+ try:
+ neg_clip_indices = random.sample(neg_pool, k=max_n)
+ except:
+ neg_clip_indices = pos_clip_indices
+
+ return pos_clip_indices, neg_clip_indices
+
+ def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
+ """Sum the scores from the three annotations, then take the two clips with the
+ maximum scores as positive, and two with the minimum scores as negative.
+ Args:
+ rel_clip_ids: list(int), list of relevant clip ids
+ scores: list([anno1_score, anno2_score, anno3_score]),
+ ctx_l: int
+ max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
+ add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
+ """
+ # indices inside rel_clip_ids
+ scores = np.array(scores) # (#rel_clips, 3)
+ agg_scores = np.sum(scores, 1) # (#rel_clips, )
+ sort_indices = np.argsort(agg_scores) # increasing
+
+ # indices in the whole video
+ # the min(_, ctx_l-1) here is incorrect, but should not cause
+ # much troubles since this should be rarely used.
+ hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
+ hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
+
+ if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
+ hard_neg_clip_indices = hard_pos_clip_indices
+
+ easy_pos_clip_indices = []
+ easy_neg_clip_indices = []
+
+ if self.add_easy_negative > 0:
+ easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
+ if len(easy_neg_pool) >= max_n:
+ easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
+ easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
+ else: # copy the hard ones
+ easy_pos_clip_indices = hard_pos_clip_indices
+ easy_neg_clip_indices = hard_neg_clip_indices
+
+ if self.easy_negative_only > 0:
+ return easy_pos_clip_indices, easy_neg_clip_indices
+
+ pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
+ neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
+ return pos_clip_indices, neg_clip_indices
+
+ def get_span_labels(self, windows, ctx_l):
+ """
+ windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
+ Note a maximum of `self.max_windows` windows are used.
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
+ """
+ if len(windows) > self.max_windows:
+ random.shuffle(windows)
+ windows = windows[:self.max_windows]
+ if self.span_loss_type == "l1":
+ windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
+ windows = span_xx_to_cxw(windows) # normalized windows in cxw
+ elif self.span_loss_type == "ce":
+ windows = torch.Tensor([
+ [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
+ for w in windows]).long() # inclusive
+ else:
+ raise NotImplementedError
+ return windows
+
+ def _get_query_feat_by_qid(self, qid):
+ if self.use_cache > 0:
+ try:
+ q_feat = self.txt_cache[qid]
+ except:
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
+ return torch.from_numpy(q_feat)
+
+ q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
+ try:
+ q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
+ except:
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
+ logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
+
+ if self.q_feat_type == "last_hidden_state":
+ # q_feat = q_feat[:self.max_q_l]
+ q_feat = q_feat
+ if self.normalize_t:
+ q_feat = l2_normalize_np_array(q_feat)
+ if self.txt_drop_ratio > 0:
+ q_feat = self.random_drop_rows(q_feat)
+ return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
+
+ def random_drop_rows(self, embeddings):
+ """randomly mask num_drop rows in embeddings to be zero.
+ Args:
+ embeddings: np.ndarray (L, D)
+ """
+ num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
+ if num_drop_rows > 0:
+ row_indices = np.random.choice(
+ len(embeddings), size=num_drop_rows, replace=False)
+ embeddings[row_indices] = 0
+ return embeddings
+
+ def _get_video_feat_by_vid(self, vid):
+ v_feat_list = []
+ for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
+ if self.use_cache > 0:
+ _feat = self.vid_cache[feat_type][vid]
+ else:
+ _feat_path = join(_feat_dir, f"{vid}.npz")
+ _feat = np.load(_feat_path)["features"].astype(np.float32)
+ # _feat = np.load(_feat_path)["features"][:self.max_v_l].astype(np.float32)
+ if self.normalize_v:
+ _feat = l2_normalize_np_array(_feat)
+ v_feat_list.append(_feat)
+ # some features are slightly longer than the others
+ min_len = min([len(e) for e in v_feat_list])
+ v_feat_list = [e[:min_len] for e in v_feat_list]
+ v_feat = np.concatenate(v_feat_list, axis=1)
+ return torch.from_numpy(v_feat) # (Lv, D)
+
+class DatasetHL(Dataset):
+ def __init__(self,
+ dset_name,
+ domain,
+ data_path,
+ v_feat_types,
+ v_feat_dirs,
+ t_feat_dir,
+ use_tef=False
+ ):
+ assert dset_name in ['tvsum', 'youtube']
+ self.dset_name = dset_name
+ dset_domain = {'tvsum': TVSUM_SPLITS,
+ 'youtube': YOUTUBE_SPLITS}
+ self.splits = dset_domain[dset_name]
+ assert domain in self.splits.keys()
+
+ self.domain = domain
+ assert len(data_path) == 1
+ self.data_path = data_path[0] if isinstance(data_path, list) else data_path
+ self.v_feat_types = v_feat_types.split('_')
+ self.v_feat_dirs = v_feat_dirs
+ self.q_feat_type = "last_hidden_state"
+ self.q_feat_dir = t_feat_dir
+
+ self.txt_drop_ratio = 0
+ self.normalize_t = True
+ self.normalize_v = True
+
+ self.label = nncore.load(self.data_path)
+ self.use_tef = use_tef
+
+ self.video_id = {
+ k: [s for s in self.splits[domain][k] if s in self.label]
+ for k in ('train', 'val')
+ }
+ self.set_state('train')
+
+ def __len__(self):
+ return len(self.video_id[self.state])
+
+ def __getitem__(self, idx):
+ vid = self.get_video_id(idx)
+ video = self._get_video_feat_by_vid(vid)
+ saliency = self.get_saliency(idx)
+
+ if self.dset_name == 'youtube':
+ saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
+ elif self.dset_name == 'tvsum':
+ saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
+ # saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency != min(saliency))[0].tolist())])
+ else:
+ raise NotImplementedError
+
+ num_clips = min(c.size(0) for c in (video, saliency))
+
+ video = video[:num_clips]
+ saliency = saliency[:num_clips]
+
+ if self.use_tef:
+ ctx_l = video.shape[0]
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
+ video = torch.cat([video, tef], dim=1) # (Lv, Dv+2)
+
+ data = dict(
+ video=DataContainer(video),
+ saliency=DataContainer(saliency, pad_value=-1),
+ saliency_pos_labels=saliency_pos_labels)
+
+ if self.q_feat_dir is not None:
+ query = self._get_query_feat_by_qid(vid)
+ data['query'] = DataContainer(query, pad_value=float('inf'))
+ return data
+
+ def set_state(self, state):
+ self.state = 'train' if state == 'train' else 'val'
+
+ def get_video_id(self, idx):
+ return self.video_id[self.state][idx]
+
+ def get_video(self, idx):
+ video_id = self.get_video_id(idx)
+ video = torch.from_numpy(self.video[video_id]).float()
+ optic = torch.from_numpy(self.optic[video_id]).float()
+ return torch.cat((video, optic), dim=1)
+
+ def _get_video_feat_by_vid(self, vid):
+ v_feat_list = []
+ for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
+ # if self.use_cache > 0:
+ # _feat = self.vid_cache[feat_type][vid]
+ # else:
+ if True:
+ _feat_path = join(_feat_dir, f"{vid}.npz")
+ _feat = np.load(_feat_path)["features"].astype(np.float32)
+ if self.normalize_v:
+ _feat = l2_normalize_np_array(_feat)
+ v_feat_list.append(_feat)
+ # some features are slightly longer than the others
+ min_len = min([len(e) for e in v_feat_list])
+ v_feat_list = [e[:min_len] for e in v_feat_list]
+ v_feat = np.concatenate(v_feat_list, axis=1)
+ return torch.from_numpy(v_feat) # (Lv, D)
+
+ def _get_query_feat_by_qid(self, qid):
+ # if self.use_cache > 0:
+ # try:
+ # q_feat = self.txt_cache[qid]
+ # except:
+ # q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
+ # return torch.from_numpy(q_feat)
+
+ q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
+ try:
+ q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
+ except:
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
+ logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
+
+ if self.q_feat_type == "last_hidden_state":
+ # q_feat = q_feat[:self.max_q_l]
+ q_feat = q_feat
+ if self.normalize_t:
+ q_feat = l2_normalize_np_array(q_feat)
+ if self.txt_drop_ratio > 0:
+ q_feat = self.random_drop_rows(q_feat)
+ return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
+
+ def get_saliency(self, idx):
+ if self.dset_name == 'tvsum':
+ video_id = self.get_video_id(idx)
+ saliency = torch.Tensor(self.label[video_id]['anno'])
+
+ # top-5 saliency scores as a threshold.
+ # saliency_tmp = saliency.mean(1)
+ # topk = int(saliency_tmp.shape[0] * 0.1)
+ # th = saliency_tmp[torch.sort(saliency_tmp)[1][-topk]] # v4
+ # saliency = saliency_tmp - th
+
+ # saliency_tmp = saliency.mean(1) # med
+ # th = saliency_tmp.median()
+ # saliency = saliency_tmp - th
+
+ saliency = (saliency - saliency.mean()).mean(dim=1)
+ # saliency = (saliency.sum(dim=1) - 20) / 80 # v2
+
+ elif self.dset_name == 'youtube':
+ video_id = self.get_video_id(idx)
+ saliency = [1 if s > 0 else 0 for s in self.label[video_id]['match']]
+ else:
+ raise NotImplementedError
+ return torch.Tensor(saliency)
+
+ def evaluate(self, blob, k=5, save_dir=None, **kwargs):
+ # blob = nncore.to_dict_of_list(blob)
+ collected = []
+
+ if save_dir is not None:
+ import json
+ with open(os.path.join(save_dir, self.dset_name, self.domain +'.jsonl'), 'w') as f:
+ for idx, score in enumerate(blob):
+ video_id = self.get_video_id(idx)
+ entry = {'vid':video_id, 'pred': score[0].tolist(), 'gt': self.get_saliency(idx).tolist(),
+ 'duration': int(self.label[video_id]['frames']) / int(self.label[video_id]['fps']),
+ 'domain': self.label[video_id]['domain'], 'fps': self.label[video_id]['fps']}
+ if self.dset_name == 'tvsum':
+ entry.update({'title':self.label[video_id]['title']})
+ if self.dset_name == 'youtube':
+ entry.update({'clip':self.label[video_id]['clip']})
+ f.write(json.dumps(entry) + '\n')
+
+ if self.dset_name == 'tvsum':
+ for i in range(20):
+ video_ap = []
+ for idx, score in enumerate(blob):
+ inds = torch.argsort(score[0], descending=True)
+ video_id = self.get_video_id(idx)
+ label = torch.Tensor(self.label[video_id]['anno'])[:, i]
+ label = torch.where(label > label.median(), 1.0, .0)
+ label = label[inds].tolist()[:k]
+
+ if (num_gt := sum(label)) == 0:
+ video_ap.append(0)
+ continue
+
+ hits = ap = rec = 0
+ prc = 1
+
+ for j, gt in enumerate(label):
+ hits += gt
+ _rec = hits / num_gt
+ _prc = hits / (j + 1)
+ ap += (_rec - rec) * (prc + _prc) / 2
+ rec, prc = _rec, _prc
+ video_ap.append(ap)
+ collected.append(sum(video_ap) / len(video_ap))
+
+ elif self.dset_name == 'youtube':
+ for idx, score in enumerate(blob):
+ inds = torch.argsort(score[0], descending=True)
+ label = self.get_saliency(idx)[inds].tolist()
+
+ if (num_gt := sum(label)) == 0:
+ collected.append(0)
+ continue
+
+ hits = ap = rec = 0
+ prc = 1
+
+ for i, gt in enumerate(label):
+ hits += gt
+ _rec = hits / num_gt
+ _prc = hits / (i + 1)
+ ap += (_rec - rec) * (prc + _prc) / 2
+ rec, prc = _rec, _prc
+ collected.append(ap)
+ else:
+ raise NotImplementedError
+
+ mean_ap = sum(collected) / len(collected)
+ results = dict(mAP=round(mean_ap, 5))
+ return results
+
+class DatasetQFVS(Dataset):
+ def __init__(self,config, use_tef=True):
+ # pdb.set_trace()
+ self.config=config
+ self.dataset=[]
+ self.use_tef=use_tef
+
+ self.embedding=load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")
+
+ for video_id in self.config["train_videos"]:
+ for _ , _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
+ for file in files:
+ self.dataset.append(file[:file.find("_oracle.txt")]+"_"+str(video_id))
+
+ def __getitem__(self,index):
+ video_id=self.dataset[index].split('_')[2]
+ feat_type = self.config['vid_feature']
+ # pdb.set_trace()
+ feat_type = self.config['vid_feature']
+ f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
+ features=f['feature'][()]
+ # dim=features.shape[-1]
+ # features=features.reshape(-1, dim)
+ # seg_len=f['seg_len'][()]
+ dim = features.shape[-1]
+ ctx_l = features.shape[0]
+ seg_len = np.ones(ctx_l)
+
+ # mask = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
+ # for j in range(len(seg_len)):
+ # for k in range(seg_len[j]):
+ # mask[j][k] = 1
+
+ # ctx_l = seg_len.sum()
+ features = torch.from_numpy(features)
+ # features = features[mask, :]
+
+ if self.use_tef:
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
+ features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
+
+ transfer={"Cupglass":"Glass",
+ "Musicalinstrument":"Instrument",
+ "Petsanimal":"Animal"}
+
+ concept1,concept2=self.dataset[index].split('_')[0:2]
+
+ concept1_GT=torch.zeros(ctx_l)
+ concept2_GT=torch.zeros(ctx_l)
+ with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+video_id+"/P0"+video_id+".txt","r") as f:
+ lines=f.readlines()
+ for index,line in enumerate(lines):
+ concepts=line.strip().split(',')
+ if concept1 in concepts:
+ concept1_GT[index]=1
+ if concept2 in concepts:
+ concept2_GT[index]=1
+
+ # shot_num=seg_len.sum()
+ # mask_GT=torch.zeros(ctx_l)
+ # for i in range(shot_num):
+ # mask_GT[i]=1
+ mask_GT=torch.ones(ctx_l)
+
+ oracle_summary = torch.zeros(ctx_l)
+ GT_summary_shots = []
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+str(concept1)+"_"+str(concept2)+"_"+"oracle.txt","r") as f:
+ for line in f.readlines():
+ GT_summary_shots.append(int(line.strip()))
+ GT_summary_shots = [x - 1 for x in GT_summary_shots]
+ for element in GT_summary_shots:
+ oracle_summary[element] = 1
+
+ if concept1 in transfer:
+ concept1=transfer[concept1]
+ if concept2 in transfer:
+ concept2=transfer[concept2]
+ concept1=self.embedding[concept1]
+ concept2=self.embedding[concept2]
+
+ try:
+ saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_1 = torch.Tensor(0)
+
+ try:
+ saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_2 = torch.Tensor(0)
+
+ try:
+ saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_oracle = torch.Tensor(0)
+
+ return {
+ 'features':features,
+ 'seg_len':torch.from_numpy(seg_len),
+ 'concept1_GT':concept1_GT,
+ 'concept2_GT':concept2_GT,
+ 'mask_GT':mask_GT,
+ 'oracle_summary':oracle_summary,
+ 'tokens_pad1':torch.from_numpy(concept1),
+ 'tokens_pad2':torch.from_numpy(concept2),
+ 'saliency_pos_labels_1': saliency_pos_labels_1,
+ 'saliency_pos_labels_2': saliency_pos_labels_2,
+ 'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
+ }
+
+ def __len__(self):
+ return len(self.dataset)
+
+def start_end_collate_mr(batch):
+ batch_meta = [e["meta"] for e in batch] # seems no need to collate ?
+
+ model_inputs_keys = batch[0]["model_inputs"].keys()
+ batched_data = dict()
+ for k in model_inputs_keys:
+ if k == "span_labels":
+ batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch]
+ continue
+ if k in ["saliency_pos_labels", "saliency_neg_labels"]:
+ batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch])
+ continue
+
+ batched_data[k] = pad_sequences_1d(
+ [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
+ return batch_meta, batched_data
+
+def start_end_collate_hl(batch):
+ model_inputs_keys = batch[0].keys()
+
+ batched_data = dict()
+ for k in model_inputs_keys:
+ batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
+ return batched_data
+
+def start_end_collate_qfvs(batch):
+ model_inputs_keys = batch[0].keys()
+
+ batched_data = dict()
+ for k in model_inputs_keys:
+ batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
+
+ return batched_data
+
+def prepare_batch_inputs_mr(batched_model_inputs, device, non_blocking=False):
+ model_inputs = dict(
+ src_txt=batched_model_inputs["query_feat"][0].to(device, non_blocking=non_blocking),
+ src_txt_mask=batched_model_inputs["query_feat"][1].to(device, non_blocking=non_blocking),
+ src_vid=batched_model_inputs["video_feat"][0].to(device, non_blocking=non_blocking),
+ src_vid_mask=batched_model_inputs["video_feat"][1].to(device, non_blocking=non_blocking),
+ )
+ targets = {}
+ targets['timestamp'] = batched_model_inputs["timestamp"][0].to(device, non_blocking=non_blocking)
+ targets['timestamp_mask'] = batched_model_inputs["timestamp"][1].to(device, non_blocking=non_blocking)
+ targets['timestamp_window'] = batched_model_inputs["timestamp_window"][0].to(device, non_blocking=non_blocking)
+ targets['span_labels_nn'] = batched_model_inputs["span_labels_nn"][0].to(device, non_blocking=non_blocking)
+
+ if 'saliency_scores' in batched_model_inputs.keys():
+ targets['saliency_scores'] = batched_model_inputs["saliency_scores"][0].to(device, non_blocking=non_blocking)
+
+ if "span_labels" in batched_model_inputs:
+ targets["span_labels"] = [
+ dict(spans=e["spans"].to(device, non_blocking=non_blocking))
+ for e in batched_model_inputs["span_labels"]
+ ]
+ if "saliency_pos_labels" in batched_model_inputs:
+ for name in ["saliency_pos_labels", "saliency_neg_labels"]:
+ targets[name] = batched_model_inputs[name].to(device, non_blocking=non_blocking)
+
+ if "weight_ablation" in batched_model_inputs:
+ targets["weight_ablation"] = batched_model_inputs["weight_ablation"][0].to(device, non_blocking=non_blocking)
+
+ targets = None if len(targets) == 0 else targets
+ return model_inputs, targets
+
+def prepare_batch_inputs_hl(batched_model_inputs, device='cuda', non_blocking=False):
+ src_vid = batched_model_inputs['video'][0].to(device, non_blocking=non_blocking)
+ src_vid_mask = batched_model_inputs['video'][1].bool().to(device, non_blocking=non_blocking)
+ src_txt = batched_model_inputs['query'][0].to(device, non_blocking=non_blocking) \
+ if 'query' in batched_model_inputs.keys() else None
+ src_txt_mask = batched_model_inputs['query'][1].bool().to(device, non_blocking=non_blocking) \
+ if 'query' in batched_model_inputs.keys() else None
+
+ model_inputs = dict(
+ src_vid=src_vid, src_vid_mask=src_vid_mask,
+ src_txt=src_txt, src_txt_mask=src_txt_mask)
+
+ # if 'audio' in batched_model_inputs.keys():
+ # src_aud = batched_model_inputs['audio'][0].bool().to(device, non_blocking=non_blocking)
+ # src_aud_mask = batched_model_inputs['audio'][1].bool().to(device, non_blocking=non_blocking)
+ # model_inputs['src_aud']=src_aud; model_inputs['src_aud_mask']=src_aud_mask;
+
+ targets = {}
+ saliency = batched_model_inputs['saliency'][0].to(device, non_blocking=non_blocking)
+ saliency_pos_labels = batched_model_inputs['saliency_pos_labels'][0].to(device, non_blocking=non_blocking)
+
+ targets['saliency_scores'] = saliency
+ targets['saliency_pos_labels'] = saliency_pos_labels.long()
+ targets['timestamp_mask'] = batched_model_inputs["video"][1].to(device, non_blocking=non_blocking)
+ targets['timestamp_window'] = 1 * (saliency > 0)
+
+ return model_inputs, targets
+
+def prepare_batch_inputs_qfvs(data, config, eval=False):
+ if not eval:
+ features, mask, seg_len, \
+ concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2,\
+ saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
+ data['features'][0], data['features'][1], data['seg_len'][0],\
+ data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0],\
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
+ data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0],
+ else:
+ features, mask, seg_len, \
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
+ data['features'][0], data['features'][1], data['seg_len'][0],\
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]
+
+ # preprocess for vid input.
+ seq = features.to('cuda')
+ mask = mask.to('cuda')
+
+ # for txt input.
+ src_txt_1 = src_txt_1.to(torch.float32).to('cuda')
+ src_txt_2 = src_txt_2.to(torch.float32).to('cuda')
+ src_txt_mask_1 = src_txt_mask_1.to('cuda')
+ src_txt_mask_2 = src_txt_mask_2.to('cuda')
+
+ src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
+ src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')
+
+ model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
+ model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
+ model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)
+
+ if not eval:
+ targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
+ targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
+ targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))
+
+ targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda')
+ targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda')
+ targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')
+
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, \
+ targets_1, targets_2, targets_oracle, mask_GT
+ else:
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, mask
diff --git a/main/dataset_qfvs.py b/main/dataset_qfvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ec00e2df1ff33503d6d145e27bad72293cd73d0
--- /dev/null
+++ b/main/dataset_qfvs.py
@@ -0,0 +1,284 @@
+import os
+import pdb
+import h5py
+import nncore
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+from tqdm import tqdm
+import random
+import logging
+from os.path import join, exists
+from nncore.dataset import DATASETS
+from nncore.parallel import DataContainer
+from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
+from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
+from utils.tensor_utils import pad_sequences_1d
+from utils.span_utils import span_xx_to_cxw
+
+logger = logging.getLogger(__name__)
+
+class DatasetQFVS(Dataset):
+ def __init__(self,config, use_tef=True):
+ # pdb.set_trace()
+ self.config=config
+ self.dataset=[]
+ self.use_tef=use_tef
+
+ self.embedding=load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")
+
+ self.transfer={"Cupglass":"Glass",
+ "Musicalinstrument":"Instrument",
+ "Petsanimal":"Animal"}
+
+ self.f_dict = {}
+ feat_type = self.config['vid_feature']
+
+ for video_id in self.config["train_videos"]:
+ self.f_dict[str(video_id)] = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
+ for _ , _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
+ for file in files:
+ self.dataset.append(['Oracle', file[:file.find("_oracle.txt")]+"_"+str(video_id)])
+
+ if self.config['qfvs_dense_shot'] > 0:
+ dense_concept = {}
+ feat_type = self.config['vid_feature']
+ feat=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
+ features=feat['features'][()]
+ seg_len=feat['seg_len'][()]
+ with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+str(video_id)+"/P0"+str(video_id)+".txt","r") as f:
+ lines=f.readlines()
+ for index,line in enumerate(lines):
+ concepts=line.strip().split(',')
+ for concept in concepts:
+ if concept in self.transfer:
+ concept= self.transfer[concept]
+ if concept not in dense_concept:
+ # dense_concept[concept] = torch.zeros(seg_len.sum())
+ dense_concept[concept] = torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
+ else:
+ dense_concept[concept][index] = 1
+
+ for key, value in dense_concept.items():
+ if value.sum().item() > 0:
+ self.dataset.append([video_id, key, value])
+
+ def __getitem__(self, index):
+ if self.dataset[index][0] == 'Oracle':
+ return self.get_oracle(index)
+ else:
+ return self.get_dense(index)
+
+ def get_dense(self,index):
+ video_id=str(self.dataset[index][0])
+ f = self.f_dict[video_id]
+ # feat_type = self.config['vid_feature']
+ # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
+ features=f['features'][()]
+ seg_len=f['seg_len'][()]
+
+ dim = features.shape[-1]
+
+ mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
+ for j in range(len(seg_len)):
+ for k in range(seg_len[j]):
+ mask_GT[j][k] = 1
+
+ features = torch.from_numpy(features)
+
+ concept1 = concept2 = self.dataset[index][1]
+ concept1_GT = concept2_GT = oracle_summary = self.dataset[index][2]
+
+ if concept1 in self.transfer:
+ concept1=self.transfer[concept1]
+ if concept2 in self.transfer:
+ concept2=self.transfer[concept2]
+ concept1=self.embedding[concept1]
+ concept2=self.embedding[concept2]
+
+ concept1 = l2_normalize_np_array(concept1)
+ concept2 = l2_normalize_np_array(concept2)
+
+ try:
+ saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_1 = torch.Tensor(0)
+
+ try:
+ saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_2 = torch.Tensor(0)
+
+ try:
+ saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_oracle = torch.Tensor(0)
+
+ return {
+ 'features':features,
+ 'seg_len':torch.from_numpy(seg_len),
+ 'concept1_GT':concept1_GT,
+ 'concept2_GT':concept2_GT,
+ 'mask_GT':mask_GT,
+ 'oracle_summary':oracle_summary,
+ 'tokens_pad1':torch.from_numpy(concept1),
+ 'tokens_pad2':torch.from_numpy(concept2),
+ 'saliency_pos_labels_1': saliency_pos_labels_1,
+ 'saliency_pos_labels_2': saliency_pos_labels_2,
+ 'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
+ }
+
+ def get_oracle(self,index):
+ video_id=self.dataset[index][1].split('_')[2]
+ f = self.f_dict[video_id]
+ # video_id=self.dataset[index][1].split('_')[2]
+ # feat_type = self.config['vid_feature']
+ # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
+ features=f['features'][()]
+ seg_len=f['seg_len'][()]
+
+ dim = features.shape[-1]
+
+ mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
+ for j in range(len(seg_len)):
+ for k in range(seg_len[j]):
+ mask_GT[j][k] = 1
+
+ features = torch.from_numpy(features)
+
+ concept1,concept2=self.dataset[index][1].split('_')[0:2]
+
+ concept1_GT=torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
+ concept2_GT=torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
+ # concept1_GT=torch.zeros(seg_len.sum())
+ # concept2_GT= torch.zeros(seg_len.sum())
+ with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+video_id+"/P0"+video_id+".txt","r") as f:
+ lines=f.readlines()
+ for index,line in enumerate(lines):
+ concepts=line.strip().split(',')
+ if concept1 in concepts:
+ concept1_GT[index]=1
+ if concept2 in concepts:
+ concept2_GT[index]=1
+
+ # oracle_summary =torch.zeros(seg_len.sum())
+ oracle_summary = torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
+ GT_summary_shots = []
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+str(concept1)+"_"+str(concept2)+"_"+"oracle.txt","r") as f:
+ for line in f.readlines():
+ GT_summary_shots.append(int(line.strip()))
+ GT_summary_shots = [x - 1 for x in GT_summary_shots]
+ for element in GT_summary_shots:
+ oracle_summary[element] = 1
+
+ if concept1 in self.transfer:
+ concept1=self.transfer[concept1]
+ if concept2 in self.transfer:
+ concept2=self.transfer[concept2]
+ concept1=self.embedding[concept1]
+ concept2=self.embedding[concept2]
+
+ concept1 = l2_normalize_np_array(concept1)
+ concept2 = l2_normalize_np_array(concept2)
+
+ try:
+ saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_1 = torch.Tensor(0)
+
+ try:
+ saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_2 = torch.Tensor(0)
+
+ try:
+ saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())])
+ except:
+ saliency_pos_labels_oracle = torch.Tensor(0)
+
+ return {
+ 'features':features,
+ 'seg_len':torch.from_numpy(seg_len),
+ 'concept1_GT':concept1_GT,
+ 'concept2_GT':concept2_GT,
+ 'mask_GT':mask_GT,
+ 'oracle_summary':oracle_summary,
+ 'tokens_pad1':torch.from_numpy(concept1),
+ 'tokens_pad2':torch.from_numpy(concept2),
+ 'saliency_pos_labels_1': saliency_pos_labels_1,
+ 'saliency_pos_labels_2': saliency_pos_labels_2,
+ 'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
+ }
+
+ def __len__(self):
+ return len(self.dataset)
+
+def start_end_collate_qfvs(batch):
+ model_inputs_keys = batch[0].keys()
+
+ batched_data = dict()
+ for k in model_inputs_keys:
+ batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
+
+ return batched_data
+
+def prepare_batch_inputs_qfvs(data, config, eval=False):
+ if not eval:
+ features, mask, seg_len, \
+ concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2,\
+ saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
+ data['features'][0], data['mask_GT'][0], data['seg_len'][0],\
+ data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0],\
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
+ data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0],
+ else:
+ features, mask, seg_len, \
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
+ data['features'][0], data['mask_GT'][0], data['seg_len'][0],\
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]
+
+ # preprocess for vid input.
+ mask_GT = mask.to('cuda').reshape(1, -1).bool()
+ seq = features.to('cuda').squeeze(0)
+ mask = mask.to('cuda').squeeze(0)
+ num_seg = seq.shape[0]
+
+ ctx_l = seq.shape[1]
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1).to('cuda') # (Lv, 2)
+
+ tef = tef.squeeze(0).repeat(seq.shape[0], 1, 1)
+ seq = torch.cat([seq, tef], dim=-1)
+
+ # for txt input.
+ src_txt_1 = src_txt_1.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
+ src_txt_2 = src_txt_2.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
+ src_txt_mask_1 = src_txt_mask_1.to('cuda').repeat(num_seg, 1)
+ src_txt_mask_2 = src_txt_mask_2.to('cuda').repeat(num_seg, 1)
+
+ src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
+ src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')
+
+ model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
+ model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
+ model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)
+
+ # concept1_GT = concept1_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
+ # concept2_GT = concept2_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
+ # oracle_summary_GT = oracle_summary_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
+
+ if not eval:
+ targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
+ targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
+ targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))
+
+ targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda')
+ targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda')
+ targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')
+
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, \
+ targets_1, targets_2, targets_oracle, mask_GT
+ else:
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, mask_GT
\ No newline at end of file
diff --git a/main/inference_demo.py b/main/inference_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..7659564d0598b249807620cec917374c2fa193f0
--- /dev/null
+++ b/main/inference_demo.py
@@ -0,0 +1,81 @@
+import pdb
+import pprint
+from tqdm import tqdm, trange
+import numpy as np
+import os
+from collections import OrderedDict, defaultdict
+from utils.basic_utils import AverageMeter
+
+import torch
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+
+from main.config import TestOptions, setup_model
+from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
+from eval.eval import eval_submission
+from eval.postprocessing import PostProcessorDETR
+from utils.basic_utils import save_jsonl, save_json
+from utils.temporal_nms import temporal_nms
+from utils.span_utils import span_cxw_to_xx
+from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
+
+import logging
+import importlib
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def load_model():
+ logger.info("Setup config, data and model...")
+ opt = TestOptions().parse()
+ # pdb.set_trace()
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ model, criterion, _, _ = setup_model(opt)
+ return model
+
+def load_data(save_dir):
+ vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
+ txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)
+
+ vid = torch.from_numpy(l2_normalize_np_array(vid))
+ txt = torch.from_numpy(l2_normalize_np_array(txt))
+ clip_len = 2
+ ctx_l = vid.shape[0]
+
+ timestamp = ( (torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
+
+ if True:
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ tef_ed = tef_st + 1.0 / ctx_l
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
+ vid = torch.cat([vid, tef], dim=1) # (Lv, Dv+2)
+
+ src_vid = vid.unsqueeze(0).cuda()
+ src_txt = txt.unsqueeze(0).cuda()
+ src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
+ src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()
+
+ return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l
+
+if __name__ == '__main__':
+ clip_len = 2
+ save_dir = '/data/home/qinghonglin/univtg/demo/tmp'
+
+ model = load_model()
+ src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
+ with torch.no_grad():
+ output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)
+
+ pred_logits = output['pred_logits'][0].cpu()
+ pred_spans = output['pred_spans'][0].cpu()
+ pred_saliency = output['saliency_scores'].cpu()
+
+ pdb.set_trace()
+ top1 = (pred_spans + timestamp)[torch.argmax(pred_logits)] * ctx_l * clip_len
+ print(top1)
+ print(pred_saliency.argmax()*clip_len)
\ No newline at end of file
diff --git a/main/inference_hl.py b/main/inference_hl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f4e1eb206aefe89e741ba86825002e23b368ce0
--- /dev/null
+++ b/main/inference_hl.py
@@ -0,0 +1,229 @@
+import os
+import pdb
+import time
+import json
+import pprint
+import random
+import importlib
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import sys
+sys.path.append('/Users/kevin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl
+from utils.model_utils import count_parameters
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def eval_epoch(model, train_val_dataset, opt): #, nms_thresh, device):
+ model.eval()
+
+ scores = []
+ train_val_dataset.set_state('val')
+ val_loader = DataLoader(
+ train_val_dataset,
+ collate_fn=start_end_collate_hl,
+ batch_size=opt.eval_bsz,
+ num_workers=opt.num_workers,
+ shuffle=False,
+ pin_memory=opt.pin_memory
+ )
+
+ with torch.no_grad():
+ for data in val_loader:
+ model_inputs, targets = prepare_batch_inputs_hl(data)
+ outputs = model(**model_inputs)
+ # pred_cls = outputs['pred_logits'].squeeze(-1)
+ # pred_cls = outputs['saliency_scores']
+ # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
+
+ # pdb.set_trace()
+ if opt.f_loss_coef == 0:
+ pred_cls = outputs['saliency_scores']
+ elif opt.s_loss_intra_coef == 0:
+ pred_cls = outputs['pred_logits'].squeeze(-1)
+ else:
+ if opt.eval_mode == 'add':
+ pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
+ else:
+ pred_cls = outputs['pred_logits'].squeeze(-1)
+
+ pred_cls = pred_cls.detach().cpu()
+ scores.append(pred_cls)
+ map = round(train_val_dataset.evaluate(scores, save_dir='./plot')['mAP'] * 100, 4)
+ return map
+
+def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer):
+ logger.info(f"[Epoch {epoch_i+1}]")
+ model.train()
+ criterion.train()
+
+ train_val_dataset.set_state('train')
+ train_loader = DataLoader(
+ train_val_dataset,
+ collate_fn=start_end_collate_hl,
+ batch_size=opt.bsz,
+ num_workers=opt.num_workers,
+ shuffle=True,
+ pin_memory=opt.pin_memory
+ )
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ num_training_examples = len(train_loader)
+ timer_dataloading = time.time()
+ for batch_idx, batch in enumerate(train_loader):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+ timer_start = time.time()
+ model_inputs, targets = prepare_batch_inputs_hl(batch)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ outputs = model(**model_inputs)
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ optimizer.zero_grad()
+ losses.backward()
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ loss_dict["loss_overall"] = float(losses)
+ for k, v in loss_dict.items():
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
+
+ timer_dataloading = time.time()
+ if opt.debug and batch_idx == 3:
+ break
+
+ # print/add logs
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
+ for k, v in loss_meters.items():
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
+
+ to_write = opt.train_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i+1,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(to_write)
+
+ logger.info("Epoch time stats:")
+ for name, meter in time_meters.items():
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
+ logger.info(f"{name} ==> {d}")
+
+# train in single domain.
+def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt):
+ # if opt.device.type == "cuda":
+ # logger.info("CUDA enabled.")
+ # model.to(opt.device)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+
+ prev_best_score = 0.
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ scores = eval_epoch(model, train_val_dataset, opt)
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1)
+ if prev_best_score < scores:
+ prev_best_score = scores
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt"))
+ tb_writer.close()
+ return prev_best_score
+
+def start_training():
+ logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+
+ from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
+ if opt.dset_name == "tvsum":
+ domain_splits = TVSUM_SPLITS.keys()
+ if opt.dset_name == "youtube":
+ domain_splits = YOUTUBE_SPLITS.keys()
+
+ scores = {}
+ if opt.lr_warmup > 0:
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
+ total_steps = opt.n_epoch
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
+ opt.lr_warmup = [warmup_steps, total_steps]
+
+ domain_splits = domain_splits if not opt.domain_name else [opt.domain_name]
+
+ for domain in domain_splits:
+ dataset_config = dict(
+ dset_name=opt.dset_name,
+ domain=domain,
+ data_path=opt.train_path,
+ v_feat_types=opt.v_feat_types,
+ v_feat_dirs=opt.v_feat_dirs,
+ t_feat_dir=opt.t_feat_dir,
+ use_tef=True
+ )
+ dataloader = DatasetHL(**dataset_config)
+
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+ count_parameters(model)
+ logger.info(f"Start Training {domain}")
+ best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt)
+ scores[domain] = best_score
+ scores['AVG'] = sum(scores.values()) / len(scores)
+
+ # save the final results.
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
+ save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None))
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1)
+ tb_writer.close()
+ # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
+
+ print(opt.dset_name)
+ print(scores)
+ return
+
+if __name__ == '__main__':
+ start_training()
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
diff --git a/main/inference_mr.py b/main/inference_mr.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aea2de137ac46fa91f737d49998a00165423bce
--- /dev/null
+++ b/main/inference_mr.py
@@ -0,0 +1,273 @@
+import pdb
+import pprint
+from tqdm import tqdm, trange
+import numpy as np
+import os
+from collections import OrderedDict, defaultdict
+from utils.basic_utils import AverageMeter
+
+import torch
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+
+from main.config import TestOptions, setup_model
+from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
+from eval.eval import eval_submission
+from eval.postprocessing import PostProcessorDETR
+from utils.basic_utils import save_jsonl, save_json
+from utils.temporal_nms import temporal_nms
+from utils.span_utils import span_cxw_to_xx
+
+import logging
+import importlib
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+
+def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms):
+ mr_res_after_nms = []
+ for e in mr_res:
+ e["pred_relevant_windows"] = temporal_nms(
+ e["pred_relevant_windows"][:max_before_nms],
+ nms_thd=nms_thd,
+ max_after_nms=max_after_nms
+ )
+ mr_res_after_nms.append(e)
+ return mr_res_after_nms
+
+
+def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename):
+ # IOU_THDS = (0.5, 0.7)
+ logger.info("Saving/Evaluating before nms results")
+ submission_path = os.path.join(opt.results_dir, save_submission_filename)
+ save_jsonl(submission, submission_path)
+
+ if opt.eval_split_name in ["val", "test"]: # since test_public has no GT
+ metrics = eval_submission(
+ submission, gt_data,
+ verbose=opt.debug, match_number=not opt.debug,
+ )
+ save_metrics_path = submission_path.replace(".jsonl", "_metrics.json")
+ save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
+ latest_file_paths = [submission_path, save_metrics_path]
+ else:
+ metrics = None
+ latest_file_paths = [submission_path, ]
+
+ if opt.nms_thd != -1:
+ logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd))
+ submission_after_nms = post_processing_mr_nms(
+ submission, nms_thd=opt.nms_thd,
+ max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms
+ )
+
+ logger.info("Saving/Evaluating nms results")
+ submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd))
+ save_jsonl(submission_after_nms, submission_nms_path)
+ if opt.eval_split_name == "val":
+ metrics_nms = eval_submission(
+ submission_after_nms, gt_data,
+ verbose=opt.debug, match_number=not opt.debug
+ )
+ save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json")
+ save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
+ latest_file_paths += [submission_nms_path, save_metrics_nms_path]
+ else:
+ metrics_nms = None
+ latest_file_paths = [submission_nms_path, ]
+ else:
+ metrics_nms = None
+ return metrics, metrics_nms, latest_file_paths
+
+
+@torch.no_grad()
+def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
+ model.eval()
+ if criterion:
+ assert eval_loader.dataset.load_labels
+ criterion.eval()
+
+ loss_meters = defaultdict(AverageMeter)
+ write_tb = tb_writer is not None and epoch_i is not None
+
+ mr_res = []
+ for batch in tqdm(eval_loader, desc="compute st ed scores"):
+ query_meta = batch[0]
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
+ outputs = model(**model_inputs)
+ prob = outputs["pred_logits"] # the last channel may be 1 or 2.
+ # if opt.eval_mode == 'v1':
+ # prob = prob * outputs["saliency_scores"].unsqueeze(-1) # v1
+ # if opt.eval_mode == 'v2':
+ # prob = F.softmax(prob, dim=1) * outputs["saliency_scores"].unsqueeze(-1) # v2
+ # if opt.eval_mode == 'v3':
+ # prob = outputs["saliency_scores"].unsqueeze(-1)
+ if outputs["pred_logits"].shape[-1] > 1:
+ prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #queries, #classes=2)
+ if opt.span_loss_type == "l1":
+ scores = prob[..., 0] # * (batch_size, #queries) foreground label is 0, we directly take it
+ pred_spans = outputs["pred_spans"] # (bsz, #queries, 2)
+
+ if opt.model_id not in ['moment_detr']: # dense regression.
+ start_spans = targets['timestamp']
+ pred_spans = start_spans + pred_spans
+ mask = targets['timestamp_mask'].bool()
+ scores[~mask] = 0
+ # if opt.eval_mode == 'v4':
+ # _mask = targets['timestamp_window'].bool()
+ # scores[~_mask] = 0
+
+ if opt.eval_mode == 'add':
+ # pdb.set_trace()
+ _saliency_scores = outputs["saliency_scores"].half() + prob.squeeze(-1)
+ else:
+ _saliency_scores = outputs["saliency_scores"].half() # (bsz, L)
+
+ if opt.eval_mode == 'add_mr':
+ prob = outputs["saliency_scores"].half().unsqueeze(-1) + prob
+
+ saliency_scores = []
+ valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
+ for j in range(len(valid_vid_lengths)):
+ saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist())
+ else:
+ bsz, n_queries = outputs["pred_spans"].shape[:2] # # (bsz, #queries, max_v_l *2)
+ pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l)
+ # TODO use more advanced decoding method with st_ed product
+ pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) # 2 * (bsz, #queries, 2)
+ scores = torch.prod(pred_span_scores, 2) # (bsz, #queries)
+ pred_spans[:, 1] += 1
+ pred_spans *= opt.clip_length
+
+ # compose predictions
+ for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())):
+ if opt.span_loss_type == "l1":
+ if opt.model_id in ['moment_detr']:
+ spans = span_cxw_to_xx(spans) * meta["duration"]
+ else:
+ spans = spans * meta["duration"]
+ spans = torch.clamp(spans, 0, meta["duration"]) # added by Kevin, since window cannot be longer than video duration.
+
+ # (#queries, 3), [st(float), ed(float), score(float)]
+ cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
+ if not opt.no_sort_results:
+ cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
+ cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
+ cur_query_pred = dict(
+ qid=meta["qid"],
+ query=meta["query"],
+ vid=meta["vid"],
+ pred_relevant_windows=cur_ranked_preds,
+ pred_saliency_scores=saliency_scores[idx]
+ )
+ mr_res.append(cur_query_pred)
+
+ if criterion:
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ loss_dict["loss_overall"] = float(losses) # for logging only
+ for k, v in loss_dict.items():
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
+
+ if opt.debug:
+ break
+
+ if write_tb and criterion:
+ for k, v in loss_meters.items():
+ tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)
+
+ post_processor = PostProcessorDETR(
+ clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
+ min_w_l=2, max_w_l=150, move_window_method="left",
+ # process_func_names=("clip_ts", "round_multiple")
+ process_func_names=["round_multiple"] # have added `clamp' op on line 147, thus we do not need `clip_ts' again;
+ )
+ # todo: are we need round_multiple?
+ if opt.round_multiple > 0:
+ mr_res = post_processor(mr_res)
+ return mr_res, loss_meters
+
+def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
+ """compute and save query and video proposal embeddings"""
+ eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # list(dict)
+ return eval_res, eval_loss_meters
+
+def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None):
+ logger.info("Generate submissions")
+ model.eval()
+ if criterion is not None and eval_dataset.load_labels:
+ criterion.eval()
+ else:
+ criterion = None
+
+ eval_loader = DataLoader(
+ eval_dataset,
+ collate_fn=start_end_collate_mr,
+ batch_size=opt.eval_bsz,
+ num_workers=opt.num_workers,
+ shuffle=False,
+ pin_memory=opt.pin_memory
+ )
+
+ submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer)
+ if opt.no_sort_results:
+ save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl")
+ metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing(
+ submission, opt, eval_dataset.data, save_submission_filename)
+ return metrics, metrics_nms, eval_loss_meters, latest_file_paths
+
+def start_inference():
+ logger.info("Setup config, data and model...")
+ opt = TestOptions().parse()
+ # pdb.set_trace()
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ assert opt.eval_path is not None
+ eval_dataset = DatasetMR(
+ dset_name=opt.dset_name,
+ data_path=opt.eval_path,
+ v_feat_dirs=opt.v_feat_dirs,
+ q_feat_dir=opt.t_feat_dir,
+ v_feat_dim=opt.v_feat_dim,
+ q_feat_dim=opt.t_feat_dim,
+ q_feat_type="last_hidden_state",
+ max_q_l=opt.max_q_l,
+ max_v_l=opt.max_v_l,
+ ctx_mode=opt.ctx_mode,
+ data_ratio=opt.data_ratio,
+ normalize_v=not opt.no_norm_vfeat,
+ normalize_t=not opt.no_norm_tfeat,
+ clip_len=opt.clip_length,
+ max_windows=opt.max_windows,
+ load_labels=True, # opt.eval_split_name == "val",
+ span_loss_type=opt.span_loss_type,
+ txt_drop_ratio=0,
+ use_cache=opt.use_cache,
+ )
+
+ if opt.lr_warmup > 0:
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
+ total_steps = opt.n_epoch
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
+ opt.lr_warmup = [warmup_steps, total_steps]
+
+ model, criterion, _, _ = setup_model(opt)
+ save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format(
+ opt.dset_name, opt.eval_split_name, opt.eval_id)
+ logger.info("Starting inference...")
+ with torch.no_grad():
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
+ eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
+ if metrics_nms is not None:
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
+
+
+if __name__ == '__main__':
+ start_inference()
diff --git a/main/inference_qfvs.py b/main/inference_qfvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..4feed8b8399a2dc1fb081e14e9acc4ece64650ed
--- /dev/null
+++ b/main/inference_qfvs.py
@@ -0,0 +1,342 @@
+import os
+import pdb
+import time
+import json
+import pprint
+import random
+import importlib
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import h5py
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import sys
+sys.path.append('/Users/kevin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array
+from utils.model_utils import count_parameters
+from eval.qfvs import calculate_semantic_matching, load_videos_tag
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def eval_epoch(model, config, opt):
+ model.eval()
+ f1_sum = 0; p_sum = 0; r_sum = 0
+
+ assert len(config['test_videos']) == 1
+ video_id = config['test_videos'][0]
+ embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")
+
+ feat_type = config['vid_feature']
+ feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
+ features = torch.from_numpy(feat['features'][()])
+ seg_len = torch.from_numpy(feat['seg_len'][()])
+ # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()
+
+ # dim = features.shape[-1]
+ # ctx_l = seg_len.sum().cpu()
+
+ # dim = features.shape[-1]
+ # ctx_l = features.shape[1]
+ # seg_len = torch.ones(ctx_l)
+ # features = features.reshape(-1, dim)[:ctx_l]
+
+ # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ # tef_ed = tef_st + 1.0 / ctx_l
+ # tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2)
+ # features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
+
+ transfer = {"Cupglass": "Glass",
+ "Musicalinstrument": "Instrument",
+ "Petsanimal": "Animal"}
+
+ with open(os.path.join('./plot', opt.dset_name, str(opt.qfvs_split) +'.jsonl'), 'w') as f_write:
+ for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
+ evaluation_num=len(files)
+
+ mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda()
+ for j in range(len(seg_len)):
+ for k in range(seg_len[j]):
+ mask_GT[j][k] = 1
+
+ for file in files:
+ summaries_GT=[]
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
+ for line in f.readlines():
+ summaries_GT.append(int(line.strip()))
+
+ concept1, concept2 = file.split('_')[0:2]
+
+ ##############
+ if concept1 in transfer:
+ concept1 = transfer[concept1]
+ if concept2 in transfer:
+ concept2 = transfer[concept2]
+ concept1 = embedding[concept1]
+ concept2 = embedding[concept2]
+
+ concept1 = l2_normalize_np_array(concept1)
+ concept2 = l2_normalize_np_array(concept2)
+
+ data = {
+ 'features':features,
+ 'seg_len': seg_len,
+ 'tokens_pad1':torch.from_numpy(concept1),
+ 'tokens_pad2':torch.from_numpy(concept2),
+ 'mask_GT': mask_GT
+ }
+
+ input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)
+
+ summaries_GT = [x - 1 for x in summaries_GT]
+ video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")
+
+ if opt.f_loss_coef == 0:
+ output_type = 'saliency_scores'
+ elif opt.s_loss_intra_coef == 0:
+ output_type = 'pred_logits'
+ else:
+ if config['qfvs_score_ensemble'] > 0:
+ output_type = ['pred_logits', 'saliency_scores']
+ else:
+ output_type = 'pred_logits'
+
+ with torch.no_grad():
+ if not isinstance(output_type, list):
+ score1 = model(**input1)[output_type].squeeze()
+ score1 = score1.masked_select(mask_GT)
+
+ score2 = model(**input2)[output_type].squeeze()
+ score2 = score2.masked_select(mask_GT)
+
+ score = model(**input_oracle)[output_type].squeeze()
+ score = score.masked_select(mask_GT)
+ else:
+ score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
+ for output_t in output_type:
+ score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT)
+ score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT)
+ score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT)
+
+ if config['qfvs_score_gather'] > 0:
+ score = score + score1 + score2
+ else:
+ score = score
+
+ # since video4 features dim is greater than video_shots_tag.
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
+ _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))
+
+ c1, c2 = file.split('_')[0:2]
+ if c1 in transfer:
+ c1 = transfer[c1]
+ if c2 in transfer:
+ c2 = transfer[c2]
+
+ p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
+ entry = {'concept1': c1, 'concept2': c2,
+ 'score':score.tolist(),
+ 'top_percent': config["top_percent"],
+ 'top_pred':top_index.tolist(),
+ 'gt':summaries_GT,
+ 'p': p, 'r': r, 'f1': f1,
+ 'shots': video_shots_tag[video_id-1].shape[0]}
+ f_write.write(json.dumps(entry) + '\n')
+ f1_sum+=f1; r_sum+=r; p_sum+=p
+ return {'F': round(100* f1_sum/evaluation_num,2) ,
+ 'R': round(100* r_sum/evaluation_num,2) ,
+ 'P': round(100* p_sum/evaluation_num,2) }
+
+def idx2time(idx):
+ sec1, sec2 = idx*5, (idx+1)*5
+
+ h1 = sec1 // 3600
+ m1 = (sec1 - h1*3600) // 60
+ s1 = sec1 % 60
+
+ h2 = sec2 // 3600
+ m2 = (sec2 - h2*3600) // 60
+ s2 = sec2 % 60
+ print(h1,m1,s1,'\t', h2,m2,s2)
+
+def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
+ model.train()
+ criterion.train()
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ timer_dataloading = time.time()
+ loss_total = 0
+
+ for batch_idx, batch in enumerate(tqdm(train_loader)):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+ timer_start = time.time()
+ model_input1, model_input2, model_input_oracle, \
+ model_gt1, model_gt2, model_gt_oracle, \
+ mask_GT = prepare_batch_inputs_qfvs(batch, config)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ output1 = model(**model_input1)
+ output2 = model(**model_input2)
+ output_oracle = model(**model_input_oracle)
+
+ loss_dict = {}
+ loss_dict1 = criterion(output1, model_gt1, mask_GT)
+ loss_dict2 = criterion(output2, model_gt2, mask_GT)
+ loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT)
+
+ weight_dict = criterion.weight_dict
+ if config['qfvs_loss_gather'] > 0:
+ for k in loss_dict1.keys():
+ loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
+ else:
+ loss_dict = loss_dict3
+
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ loss_total += losses.item()
+
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+ timer_start = time.time()
+ optimizer.zero_grad()
+ losses.backward()
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ timer_dataloading = time.time()
+ return round(loss_total / len(train_loader), 2)
+
+# train in single domain.
+def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
+ # if opt.device.type == "cuda":
+ # logger.info("CUDA enabled.")
+ # model.to(opt.device)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+
+ prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+
+ val_score = eval_epoch(model, config, opt)
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
+ logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ val_score = eval_epoch(model, config, opt)
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
+ logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
+
+ if prev_best_score['Fscore'] < val_score['F']:
+ prev_best_score['Fscore'] = val_score['F']
+ prev_best_score['Precision'] = val_score['P']
+ prev_best_score['Recall'] = val_score['R']
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
+ tb_writer.close()
+ return prev_best_score
+
+def update_config(opt, config):
+ # for key in ["max_segment_num", "max_frame_num", "top_percent",
+ # "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot",
+ # "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]:
+ config["max_segment_num"] = opt.max_segment_num
+ config["max_frame_num"] = opt.max_frame_num
+ config["top_percent"] = opt.top_percent
+ config["vid_feature"] = opt.qfvs_vid_feature
+ config["txt_feature"] = opt.qfvs_txt_feature
+ config["qfvs_dense_shot"] = opt.qfvs_dense_shot
+ config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble
+ config["qfvs_score_gather"] = opt.qfvs_score_gather
+ config["qfvs_loss_gather"] = opt.qfvs_loss_gather
+ return config
+
+def start_training():
+ logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+
+ # config = load_json("./main/config_qfvs.json")
+ config = {}
+ config = update_config(opt, config)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+
+ # key -> test video; value -> training videos.
+ qfvs_split = {
+ 1: [2, 3, 4],
+ 2: [1, 3, 4],
+ 3: [1, 2, 4],
+ 4: [1, 2, 3]
+ }
+
+ scores_videos = {}
+ for test_id, splits in qfvs_split.items():
+ if opt.qfvs_split != -1:
+ if test_id != opt.qfvs_split:
+ continue
+ logger.info(f"Start Training {opt.dset_name}: {test_id}")
+ config['train_videos'] = qfvs_split[test_id]
+ config['test_videos'] = [test_id]
+ train_dataset = DatasetQFVS(config)
+ train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)
+
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+ count_parameters(model)
+ best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
+ scores_videos['V'+str(test_id)] = best_score
+
+ # save the final results.
+ avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
+ avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
+ avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
+ scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}
+
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
+ save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)
+
+ tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
+ tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
+ tb_writer.close()
+
+ print(scores_videos)
+ return
+
+if __name__ == '__main__':
+ start_training()
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
diff --git a/main/train_hl.py b/main/train_hl.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceec407bc1f4ff92077bda46e90cfd7b566ca56b
--- /dev/null
+++ b/main/train_hl.py
@@ -0,0 +1,229 @@
+import os
+import pdb
+import time
+import json
+import pprint
+import random
+import importlib
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import sys
+sys.path.append('/data/home/qinghonglin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl
+from utils.model_utils import count_parameters
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def eval_epoch(model, train_val_dataset, opt): #, nms_thresh, device):
+ model.eval()
+
+ scores = []
+ train_val_dataset.set_state('val')
+ val_loader = DataLoader(
+ train_val_dataset,
+ collate_fn=start_end_collate_hl,
+ batch_size=opt.eval_bsz,
+ num_workers=opt.num_workers,
+ shuffle=False,
+ pin_memory=opt.pin_memory
+ )
+
+ with torch.no_grad():
+ for data in val_loader:
+ model_inputs, targets = prepare_batch_inputs_hl(data)
+ outputs = model(**model_inputs)
+ # pred_cls = outputs['pred_logits'].squeeze(-1)
+ # pred_cls = outputs['saliency_scores']
+ # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
+
+ # pdb.set_trace()
+ if opt.f_loss_coef == 0:
+ pred_cls = outputs['saliency_scores']
+ elif opt.s_loss_intra_coef == 0:
+ pred_cls = outputs['pred_logits'].squeeze(-1)
+ else:
+ if opt.eval_mode == 'add':
+ pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
+ else:
+ pred_cls = outputs['pred_logits'].squeeze(-1)
+
+ pred_cls = pred_cls.detach().cpu()
+ scores.append(pred_cls)
+ map = round(train_val_dataset.evaluate(scores)['mAP'] * 100, 4)
+ return map
+
+def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer):
+ logger.info(f"[Epoch {epoch_i+1}]")
+ model.train()
+ criterion.train()
+
+ train_val_dataset.set_state('train')
+ train_loader = DataLoader(
+ train_val_dataset,
+ collate_fn=start_end_collate_hl,
+ batch_size=opt.bsz,
+ num_workers=opt.num_workers,
+ shuffle=True,
+ pin_memory=opt.pin_memory
+ )
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ num_training_examples = len(train_loader)
+ timer_dataloading = time.time()
+ for batch_idx, batch in enumerate(train_loader):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+ timer_start = time.time()
+ model_inputs, targets = prepare_batch_inputs_hl(batch)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ outputs = model(**model_inputs)
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ optimizer.zero_grad()
+ losses.backward()
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ loss_dict["loss_overall"] = float(losses)
+ for k, v in loss_dict.items():
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
+
+ timer_dataloading = time.time()
+ if opt.debug and batch_idx == 3:
+ break
+
+ # print/add logs
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
+ for k, v in loss_meters.items():
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
+
+ to_write = opt.train_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i+1,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(to_write)
+
+ logger.info("Epoch time stats:")
+ for name, meter in time_meters.items():
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
+ logger.info(f"{name} ==> {d}")
+
+# train in single domain.
+def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt):
+ # if opt.device.type == "cuda":
+ # logger.info("CUDA enabled.")
+ # model.to(opt.device)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+
+ prev_best_score = 0.
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ scores = eval_epoch(model, train_val_dataset, opt)
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1)
+ if prev_best_score < scores:
+ prev_best_score = scores
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt"))
+ tb_writer.close()
+ return prev_best_score
+
+def start_training():
+ logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+
+ from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
+ if opt.dset_name == "tvsum":
+ domain_splits = TVSUM_SPLITS.keys()
+ if opt.dset_name == "youtube":
+ domain_splits = YOUTUBE_SPLITS.keys()
+
+ scores = {}
+ if opt.lr_warmup > 0:
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
+ total_steps = opt.n_epoch
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
+ opt.lr_warmup = [warmup_steps, total_steps]
+
+ domain_splits = domain_splits if not opt.domain_name else [opt.domain_name]
+
+ for domain in domain_splits:
+ dataset_config = dict(
+ dset_name=opt.dset_name,
+ domain=domain,
+ data_path=opt.train_path,
+ v_feat_types=opt.v_feat_types,
+ v_feat_dirs=opt.v_feat_dirs,
+ t_feat_dir=opt.t_feat_dir,
+ use_tef=True
+ )
+ dataloader = DatasetHL(**dataset_config)
+
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+ count_parameters(model)
+ logger.info(f"Start Training {domain}")
+ best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt)
+ scores[domain] = best_score
+ scores['AVG'] = sum(scores.values()) / len(scores)
+
+ # save the final results.
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
+ save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None))
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1)
+ tb_writer.close()
+ # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
+
+ print(opt.dset_name)
+ print(scores)
+ return
+
+if __name__ == '__main__':
+ start_training()
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
\ No newline at end of file
diff --git a/main/train_mr.py b/main/train_mr.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a10d029f81f86733d6dab71a3aee575917b092b
--- /dev/null
+++ b/main/train_mr.py
@@ -0,0 +1,266 @@
+import os
+import pdb
+import sys
+import time
+import json
+import pprint
+import random
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+sys.path.append('/data/home/qinghonglin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset import \
+ DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
+from main.inference_mr import eval_epoch, start_inference
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
+from utils.model_utils import count_parameters
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
+ logger.info(f"[Epoch {epoch_i+1}]")
+ model.train()
+ criterion.train()
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ num_training_examples = len(train_loader)
+ timer_dataloading = time.time()
+ for batch_idx, batch in tqdm(enumerate(train_loader),
+ desc="Training Iteration",
+ total=num_training_examples):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+
+ timer_start = time.time()
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+
+ # try:
+ outputs = model(**model_inputs)
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ optimizer.zero_grad()
+ losses.backward()
+
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ loss_dict["loss_overall"] = float(losses) # for logging only
+ for k, v in loss_dict.items():
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
+
+ timer_dataloading = time.time()
+
+ # print/add logs
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
+ for k, v in loss_meters.items():
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
+
+ to_write = opt.train_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i+1,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(to_write)
+
+ logger.info("Epoch time stats:")
+ for name, meter in time_meters.items():
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
+ logger.info(f"{name} ==> {d}")
+
+
+def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=start_end_collate_mr,
+ batch_size=opt.bsz,
+ num_workers=opt.num_workers,
+ shuffle=True,
+ pin_memory=opt.pin_memory
+ )
+
+ prev_best_score = 0.
+ es_cnt = 0
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+ save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
+ eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
+
+ # log
+ to_write = opt.eval_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
+ eval_metrics_str=json.dumps(metrics_no_nms))
+
+ with open(opt.eval_log_filepath, "a") as f:
+ f.write(to_write)
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
+ if metrics_nms is not None:
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
+
+ metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
+ for k, v in metrics["brief"].items():
+ tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
+
+ # stop_score = metrics["brief"]["MR-full-mAP"]
+ # pdb.set_trace()
+ stop_score = metrics["brief"][opt.main_metric]
+ if stop_score > prev_best_score:
+ es_cnt = 0
+ prev_best_score = stop_score
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
+
+ best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
+ for src, tgt in zip(latest_file_paths, best_file_paths):
+ os.renames(src, tgt)
+ logger.info("The checkpoint file has been updated.")
+ else:
+ es_cnt += 1
+ if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(f"Early Stop at epoch {epoch_i}")
+ logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
+ break
+
+ # save ckpt
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
+
+ if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
+
+ if opt.debug:
+ break
+
+ tb_writer.close()
+
+
+def start_training():
+ logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+ if opt.debug: # keep the model run deterministically
+ # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
+ # Enable this only when input size is fixed.
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+ dataset_config = dict(
+ dset_name=opt.dset_name,
+ data_path=opt.train_path,
+ v_feat_dirs=opt.v_feat_dirs,
+ q_feat_dir=opt.t_feat_dir,
+ v_feat_dim=opt.v_feat_dim,
+ q_feat_dim=opt.t_feat_dim,
+ q_feat_type="last_hidden_state",
+ max_q_l=opt.max_q_l,
+ max_v_l=opt.max_v_l,
+ ctx_mode=opt.ctx_mode,
+ data_ratio=opt.data_ratio,
+ normalize_v=not opt.no_norm_vfeat,
+ normalize_t=not opt.no_norm_tfeat,
+ clip_len=opt.clip_length,
+ max_windows=opt.max_windows,
+ span_loss_type=opt.span_loss_type,
+ txt_drop_ratio=opt.txt_drop_ratio,
+ use_cache=opt.use_cache,
+ add_easy_negative=opt.add_easy_negative,
+ easy_negative_only=opt.easy_negative_only
+ )
+
+ dataset_config["data_path"] = opt.train_path
+ train_dataset = DatasetMR(**dataset_config)
+
+ if opt.eval_path is not None:
+ dataset_config["data_path"] = opt.eval_path
+ dataset_config["txt_drop_ratio"] = 0
+ dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
+ # dataset_config["load_labels"] = False # uncomment to calculate eval loss
+ eval_dataset = DatasetMR(**dataset_config)
+ else:
+ eval_dataset = None
+
+ if opt.lr_warmup > 0:
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
+ total_steps = opt.n_epoch
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
+ opt.lr_warmup = [warmup_steps, total_steps]
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+ logger.info(f"Model {model}")
+ count_parameters(model)
+ logger.info("Start Training...")
+ train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
+ return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
+
+
+if __name__ == '__main__':
+ best_ckpt_path, eval_split_name, eval_path, debug = start_training()
+ if not debug:
+ input_args = ["--resume", best_ckpt_path,
+ "--eval_split_name", eval_split_name,
+ "--eval_path", eval_path]
+
+ import sys
+ sys.argv[1:] = input_args
+ logger.info("\n\n\nFINISHED TRAINING!!!")
+ logger.info("Evaluating model at {}".format(best_ckpt_path))
+ logger.info("Input args {}".format(sys.argv[1:]))
+ start_inference()
\ No newline at end of file
diff --git a/main/train_qfvs.py b/main/train_qfvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..65a3b155ed6d432da7eb0072f87e1bae18d8a994
--- /dev/null
+++ b/main/train_qfvs.py
@@ -0,0 +1,325 @@
+import os
+import pdb
+import time
+import json
+import pprint
+import random
+import importlib
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import h5py
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import sys
+sys.path.append('/Users/kevin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array
+from utils.model_utils import count_parameters
+from eval.qfvs import calculate_semantic_matching, load_videos_tag
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def eval_epoch(model, config, opt):
+ model.eval()
+ f1_sum = 0; p_sum = 0; r_sum = 0
+
+ assert len(config['test_videos']) == 1
+ video_id = config['test_videos'][0]
+ embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")
+
+ feat_type = config['vid_feature']
+ feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
+ features = torch.from_numpy(feat['features'][()])
+ seg_len = torch.from_numpy(feat['seg_len'][()])
+ # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()
+
+ # dim = features.shape[-1]
+ # ctx_l = seg_len.sum().cpu()
+
+ # dim = features.shape[-1]
+ # ctx_l = features.shape[1]
+ # seg_len = torch.ones(ctx_l)
+ # features = features.reshape(-1, dim)[:ctx_l]
+
+ # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
+ # tef_ed = tef_st + 1.0 / ctx_l
+ # tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2)
+ # features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
+
+ transfer = {"Cupglass": "Glass",
+ "Musicalinstrument": "Instrument",
+ "Petsanimal": "Animal"}
+
+ for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
+ evaluation_num=len(files)
+
+ mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda()
+ for j in range(len(seg_len)):
+ for k in range(seg_len[j]):
+ mask_GT[j][k] = 1
+
+ for file in files:
+ summaries_GT=[]
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
+ for line in f.readlines():
+ summaries_GT.append(int(line.strip()))
+
+ concept1, concept2 = file.split('_')[0:2]
+
+ ##############
+ if concept1 in transfer:
+ concept1 = transfer[concept1]
+ if concept2 in transfer:
+ concept2 = transfer[concept2]
+ concept1 = embedding[concept1]
+ concept2 = embedding[concept2]
+
+ concept1 = l2_normalize_np_array(concept1)
+ concept2 = l2_normalize_np_array(concept2)
+
+ data = {
+ 'features':features,
+ 'seg_len': seg_len,
+ 'tokens_pad1':torch.from_numpy(concept1),
+ 'tokens_pad2':torch.from_numpy(concept2),
+ 'mask_GT': mask_GT
+ }
+
+ input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)
+
+ summaries_GT = [x - 1 for x in summaries_GT]
+ video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")
+
+ if opt.f_loss_coef == 0:
+ output_type = 'saliency_scores'
+ elif opt.s_loss_intra_coef == 0:
+ output_type = 'pred_logits'
+ else:
+ if config['qfvs_score_ensemble'] > 0:
+ output_type = ['pred_logits', 'saliency_scores']
+ else:
+ output_type = 'pred_logits'
+
+ with torch.no_grad():
+ if not isinstance(output_type, list):
+ score1 = model(**input1)[output_type].squeeze()
+ score1 = score1.masked_select(mask_GT)
+
+ score2 = model(**input2)[output_type].squeeze()
+ score2 = score2.masked_select(mask_GT)
+
+ score = model(**input_oracle)[output_type].squeeze()
+ score = score.masked_select(mask_GT)
+ else:
+ score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
+ for output_t in output_type:
+ score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT)
+ score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT)
+ score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT)
+
+ if config['qfvs_score_gather'] > 0:
+ score = score + score1 + score2
+ else:
+ score = score
+
+ # since video4 features dim is greater than video_shots_tag.
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
+ _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))
+
+ p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
+ f1_sum+=f1; r_sum+=r; p_sum+=p
+
+ return {'F': round(100* f1_sum/evaluation_num,2) ,
+ 'R': round(100* r_sum/evaluation_num,2) ,
+ 'P': round(100* p_sum/evaluation_num,2) }
+
+def idx2time(idx):
+ sec1, sec2 = idx*5, (idx+1)*5
+
+ h1 = sec1 // 3600
+ m1 = (sec1 - h1*3600) // 60
+ s1 = sec1 % 60
+
+ h2 = sec2 // 3600
+ m2 = (sec2 - h2*3600) // 60
+ s2 = sec2 % 60
+ print(h1,m1,s1,'\t', h2,m2,s2)
+
+def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
+ model.train()
+ criterion.train()
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ timer_dataloading = time.time()
+ loss_total = 0
+
+ for batch_idx, batch in enumerate(tqdm(train_loader)):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+ timer_start = time.time()
+ model_input1, model_input2, model_input_oracle, \
+ model_gt1, model_gt2, model_gt_oracle, \
+ mask_GT = prepare_batch_inputs_qfvs(batch, config)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ output1 = model(**model_input1)
+ output2 = model(**model_input2)
+ output_oracle = model(**model_input_oracle)
+
+ loss_dict = {}
+ loss_dict1 = criterion(output1, model_gt1, mask_GT)
+ loss_dict2 = criterion(output2, model_gt2, mask_GT)
+ loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT)
+
+ weight_dict = criterion.weight_dict
+ if config['qfvs_loss_gather'] > 0:
+ for k in loss_dict1.keys():
+ loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
+ else:
+ loss_dict = loss_dict3
+
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ loss_total += losses.item()
+
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+ timer_start = time.time()
+ optimizer.zero_grad()
+ losses.backward()
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ timer_dataloading = time.time()
+ return round(loss_total / len(train_loader), 2)
+
+# train in single domain.
+def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
+ # if opt.device.type == "cuda":
+ # logger.info("CUDA enabled.")
+ # model.to(opt.device)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+
+ prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+
+ val_score = eval_epoch(model, config, opt)
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
+ logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ val_score = eval_epoch(model, config, opt)
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
+ logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
+
+ if prev_best_score['Fscore'] < val_score['F']:
+ prev_best_score['Fscore'] = val_score['F']
+ prev_best_score['Precision'] = val_score['P']
+ prev_best_score['Recall'] = val_score['R']
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
+ tb_writer.close()
+ return prev_best_score
+
+def update_config(opt, config):
+ # for key in ["max_segment_num", "max_frame_num", "top_percent",
+ # "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot",
+ # "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]:
+ config["max_segment_num"] = opt.max_segment_num
+ config["max_frame_num"] = opt.max_frame_num
+ config["top_percent"] = opt.top_percent
+ config["vid_feature"] = opt.qfvs_vid_feature
+ config["txt_feature"] = opt.qfvs_txt_feature
+ config["qfvs_dense_shot"] = opt.qfvs_dense_shot
+ config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble
+ config["qfvs_score_gather"] = opt.qfvs_score_gather
+ config["qfvs_loss_gather"] = opt.qfvs_loss_gather
+ return config
+
+def start_training():
+ logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+
+ # config = load_json("./main/config_qfvs.json")
+ config = {}
+ config = update_config(opt, config)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+
+ # key -> test video; value -> training videos.
+ qfvs_split = {
+ 1: [2, 3, 4],
+ 2: [1, 3, 4],
+ 3: [1, 2, 4],
+ 4: [1, 2, 3]
+ }
+
+ scores_videos = {}
+ for test_id, splits in qfvs_split.items():
+ logger.info(f"Start Training {opt.dset_name}: {test_id}")
+ config['train_videos'] = qfvs_split[test_id]
+ config['test_videos'] = [test_id]
+ train_dataset = DatasetQFVS(config)
+ train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)
+
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+ count_parameters(model)
+ best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
+ scores_videos['V'+str(test_id)] = best_score
+
+ # save the final results.
+ avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
+ avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
+ avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
+ scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}
+
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
+ save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)
+
+ tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
+ tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
+ tb_writer.close()
+
+ print(scores_videos)
+ return
+
+if __name__ == '__main__':
+ start_training()
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
diff --git a/main/train_vlp.py b/main/train_vlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..feed89496947c621c1a20f4f9326bcd13ec1fa52
--- /dev/null
+++ b/main/train_vlp.py
@@ -0,0 +1,278 @@
+import os
+import pdb
+import sys
+import time
+import json
+import pprint
+import random
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+sys.path.append('/data/home/qinghonglin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset import \
+ DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr
+from main.inference_mr import eval_epoch, start_inference
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
+from utils.model_utils import count_parameters
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls=None):
+ logger.info(f"[Epoch {epoch_i+1}]")
+ model.train()
+ criterion.train()
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ num_training_examples = len(train_loader)
+ timer_dataloading = time.time()
+ for batch_idx, batch in tqdm(enumerate(train_loader),
+ desc="Training Iteration",
+ total=num_training_examples):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+
+ timer_start = time.time()
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+
+ if cls is not None:
+ model_inputs.update(cls)
+
+ # try:
+ outputs = model(**model_inputs)
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ optimizer.zero_grad()
+ losses.backward()
+ # except:
+ # pdb.set_trace()
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ loss_dict["loss_overall"] = float(losses) # for logging only
+ for k, v in loss_dict.items():
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
+
+ timer_dataloading = time.time()
+
+ # print/add logs
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
+ for k, v in loss_meters.items():
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
+
+ to_write = opt.train_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i+1,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(to_write)
+
+ logger.info("Epoch time stats:")
+ for name, meter in time_meters.items():
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
+ logger.info(f"{name} ==> {d}")
+
+
+def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
+ if opt.device.type == "cuda":
+ logger.info("CUDA enabled.")
+ model.to(opt.device)
+
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=start_end_collate_mr,
+ batch_size=opt.bsz,
+ num_workers=opt.num_workers,
+ shuffle=True,
+ pin_memory=opt.pin_memory
+ )
+
+ if ('tal' in opt.train_path) or ('mq' in opt.train_path):
+ cls = {
+ 'src_cls': train_dataset.src_cls.cuda(),
+ 'src_cls_mask': train_dataset.src_cls_mask.cuda(),}
+ else:
+ cls = None
+
+ prev_best_score = 0.
+ es_cnt = 0
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+ save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
+ eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
+
+ # log
+ to_write = opt.eval_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
+ eval_metrics_str=json.dumps(metrics_no_nms))
+
+ with open(opt.eval_log_filepath, "a") as f:
+ f.write(to_write)
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
+ if metrics_nms is not None:
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
+
+ metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
+ for k, v in metrics["brief"].items():
+ tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
+
+ # stop_score = metrics["brief"]["MR-full-mAP"]
+ # pdb.set_trace()
+ stop_score = metrics["brief"][opt.main_metric]
+ if stop_score > prev_best_score:
+ es_cnt = 0
+ prev_best_score = stop_score
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
+
+ best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
+ for src, tgt in zip(latest_file_paths, best_file_paths):
+ os.renames(src, tgt)
+ logger.info("The checkpoint file has been updated.")
+ else:
+ es_cnt += 1
+ if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(f"Early Stop at epoch {epoch_i}")
+ logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
+ break
+
+ # save ckpt
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
+
+ if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
+
+ if opt.debug:
+ break
+
+ tb_writer.close()
+
+
+def start_training():
+ logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+ if opt.debug: # keep the model run deterministically
+ # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
+ # Enable this only when input size is fixed.
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+ dataset_config = dict(
+ dset_name=opt.dset_name,
+ data_path=opt.train_path,
+ v_feat_dirs=opt.v_feat_dirs,
+ q_feat_dir=opt.t_feat_dir,
+ v_feat_dim=opt.v_feat_dim,
+ q_feat_dim=opt.t_feat_dim,
+ q_feat_type="last_hidden_state",
+ max_q_l=opt.max_q_l,
+ max_v_l=opt.max_v_l,
+ ctx_mode=opt.ctx_mode,
+ data_ratio=opt.data_ratio,
+ normalize_v=not opt.no_norm_vfeat,
+ normalize_t=not opt.no_norm_tfeat,
+ clip_len=opt.clip_length,
+ max_windows=opt.max_windows,
+ span_loss_type=opt.span_loss_type,
+ txt_drop_ratio=opt.txt_drop_ratio,
+ use_cache=opt.use_cache,
+ add_easy_negative=opt.add_easy_negative,
+ easy_negative_only=opt.easy_negative_only
+ )
+
+ dataset_config["data_path"] = opt.train_path
+ train_dataset = DatasetVLP(**dataset_config)
+
+ if opt.eval_path is not None:
+ dataset_config["data_path"] = opt.eval_path
+ dataset_config["txt_drop_ratio"] = 0
+ dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
+ # dataset_config["load_labels"] = False # uncomment to calculate eval loss
+ eval_dataset = DatasetVLP(**dataset_config)
+ else:
+ eval_dataset = None
+
+ if opt.lr_warmup > 0:
+ opt.lr_warmup = opt.n_epoch
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+ logger.info(f"Model {model}")
+ count_parameters(model)
+ logger.info("Start Training...")
+ train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
+ return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
+
+
+if __name__ == '__main__':
+ best_ckpt_path, eval_split_name, eval_path, debug = start_training()
+ if not debug:
+ input_args = ["--resume", best_ckpt_path,
+ "--eval_split_name", eval_split_name,
+ "--eval_path", eval_path]
+
+ import sys
+ sys.argv[1:] = input_args
+ logger.info("\n\n\nFINISHED TRAINING!!!")
+ logger.info("Evaluating model at {}".format(best_ckpt_path))
+ logger.info("Input args {}".format(sys.argv[1:]))
+ start_inference()
\ No newline at end of file
diff --git a/main/train_vlp_ddp.py b/main/train_vlp_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c5d735a1676dc53828ee9c97c048fe7641ba16
--- /dev/null
+++ b/main/train_vlp_ddp.py
@@ -0,0 +1,288 @@
+import os
+import pdb
+import sys
+import time
+import json
+import pprint
+import random
+import numpy as np
+from tqdm import tqdm, trange
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data.distributed import DistributedSampler
+
+sys.path.append('/data/home/qinghonglin/univtg')
+from main.config import BaseOptions, setup_model
+from main.dataset import \
+ DatasetMR, DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr
+from main.inference_mr import eval_epoch, start_inference
+from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
+from utils.model_utils import count_parameters
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+
+def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
+ logger.info(f"[Epoch {epoch_i+1}]")
+ model.train()
+ criterion.train()
+
+ # init meters
+ time_meters = defaultdict(AverageMeter)
+ loss_meters = defaultdict(AverageMeter)
+
+ num_training_examples = len(train_loader)
+ timer_dataloading = time.time()
+ for batch_idx, batch in tqdm(enumerate(train_loader),
+ desc="Training Iteration",
+ total=num_training_examples):
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
+
+ timer_start = time.time()
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], torch.device("cuda", int(opt.local_rank)), non_blocking=opt.pin_memory)
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+
+ # try:
+ outputs = model(**model_inputs)
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+ time_meters["model_forward_time"].update(time.time() - timer_start)
+
+ timer_start = time.time()
+ optimizer.zero_grad()
+ losses.backward()
+
+ if opt.grad_clip > 0:
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+ optimizer.step()
+ time_meters["model_backward_time"].update(time.time() - timer_start)
+
+ loss_dict["loss_overall"] = float(losses) # for logging only
+ for k, v in loss_dict.items():
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
+
+ timer_dataloading = time.time()
+
+ # print/add logs
+ if int(opt.local_rank) in [0, -1]:
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
+ for k, v in loss_meters.items():
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
+
+ to_write = opt.train_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i+1,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(to_write)
+
+ logger.info("Epoch time stats:")
+ for name, meter in time_meters.items():
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
+ logger.info(f"{name} ==> {d}")
+
+
+def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
+ if int(opt.local_rank) in [0, -1]:
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
+ else:
+ tb_writer = None
+
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=start_end_collate_mr,
+ batch_size=opt.bsz,
+ num_workers=opt.num_workers,
+ # shuffle=True,
+ pin_memory=opt.pin_memory,
+ sampler=DistributedSampler(train_dataset)
+ )
+
+ prev_best_score = 0.
+ es_cnt = 0
+ if opt.start_epoch is None:
+ start_epoch = -1 if opt.eval_init else 0
+ else:
+ start_epoch = opt.start_epoch
+ save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
+ if epoch_i > -1:
+ train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
+ lr_scheduler.step()
+ eval_epoch_interval = opt.eval_epoch
+ if int(opt.local_rank) in [0, -1] and opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
+ with torch.no_grad():
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
+ eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
+
+ # log
+ to_write = opt.eval_log_txt_formatter.format(
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
+ epoch=epoch_i,
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
+ eval_metrics_str=json.dumps(metrics_no_nms))
+
+ if int(opt.local_rank) in [0, -1]:
+ with open(opt.eval_log_filepath, "a") as f:
+ f.write(to_write)
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
+ if metrics_nms is not None:
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
+
+ metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
+ for k, v in metrics["brief"].items():
+ tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
+
+ # stop_score = metrics["brief"]["MR-full-mAP"]
+ # pdb.set_trace()
+ stop_score = metrics["brief"][opt.main_metric]
+ if stop_score > prev_best_score:
+ es_cnt = 0
+ prev_best_score = stop_score
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
+
+ best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
+ for src, tgt in zip(latest_file_paths, best_file_paths):
+ os.renames(src, tgt)
+ logger.info("The checkpoint file has been updated.")
+ else:
+ es_cnt += 1
+ if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
+ with open(opt.train_log_filepath, "a") as f:
+ f.write(f"Early Stop at epoch {epoch_i}")
+ logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
+ break
+
+ # save ckpt
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
+
+ if int(opt.local_rank) in [0, -1] and ((epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0): # additional copies
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch_i,
+ "opt": opt
+ }
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
+
+ if opt.debug:
+ break
+
+ if int(opt.local_rank) in [0, -1]:
+ tb_writer.close()
+
+
+def start_training():
+ # logger.info("Setup config, data and model...")
+ opt = BaseOptions().parse()
+ set_seed(opt.seed)
+ if opt.debug: # keep the model run deterministically
+ # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
+ # Enable this only when input size is fixed.
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+ local_rank = int(opt.local_rank)
+ dist.init_process_group(backend='nccl')
+
+ torch.cuda.set_device(local_rank)
+ device = torch.device("cuda", local_rank)
+
+ dataset_config = dict(
+ dset_name=opt.dset_name,
+ data_path=opt.train_path,
+ v_feat_dirs=opt.v_feat_dirs,
+ q_feat_dir=opt.t_feat_dir,
+ v_feat_dim=opt.v_feat_dim,
+ q_feat_dim=opt.t_feat_dim,
+ q_feat_type="last_hidden_state",
+ max_q_l=opt.max_q_l,
+ max_v_l=opt.max_v_l,
+ ctx_mode=opt.ctx_mode,
+ data_ratio=opt.data_ratio,
+ normalize_v=not opt.no_norm_vfeat,
+ normalize_t=not opt.no_norm_tfeat,
+ clip_len=opt.clip_length,
+ max_windows=opt.max_windows,
+ span_loss_type=opt.span_loss_type,
+ txt_drop_ratio=opt.txt_drop_ratio,
+ use_cache=opt.use_cache,
+ add_easy_negative=opt.add_easy_negative,
+ easy_negative_only=opt.easy_negative_only
+ )
+
+ dataset_config["data_path"] = opt.train_path
+ train_dataset = DatasetVLP(**dataset_config)
+
+ if opt.eval_path is not None:
+ # perform zero-shot on qvhl.
+ dataset_config["data_path"] = opt.eval_path
+ dataset_config["txt_drop_ratio"] = 0
+ if len(dataset_config["v_feat_dirs"]) == 1:
+ dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_clip"]
+ elif len(dataset_config["v_feat_dirs"]) == 2:
+ dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_slowfast", "data/qvhighlights/vid_clip"]
+ else:
+ raise NotImplementedError
+ dataset_config["q_feat_dir"] = "data/qvhighlights/txt_clip"
+ dataset_config["data_ratio"] = 1
+ # dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
+ eval_dataset = DatasetMR(**dataset_config)
+ else:
+ eval_dataset = None
+
+ if opt.lr_warmup > 0:
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
+ total_steps = opt.n_epoch
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
+ opt.lr_warmup = [warmup_steps, total_steps]
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
+
+ model.to(device)
+ logger.info(f"Using {torch.cuda.device_count()} GPUs.")
+ model = torch.nn.parallel.DistributedDataParallel(model,
+ device_ids=[local_rank],
+ output_device=local_rank,
+ find_unused_parameters=True)
+
+ if int(opt.local_rank) in [0, -1]:
+ logger.info(f"Model {model}")
+ count_parameters(model)
+ logger.info("Start Training...")
+ train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
+ # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
+ return
+
+if __name__ == '__main__':
+ # best_ckpt_path, eval_split_name, eval_path, debug = start_training()
+ start_training()
diff --git a/model/base.py b/model/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30e473b5490bd74ca180db4b5e46945a8c4fcaa
--- /dev/null
+++ b/model/base.py
@@ -0,0 +1,449 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ mask = targets['timestamp_mask'].bool()
+ mask_valid = targets['timestamp_window'].bool()
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[mask_valid] = 1
+ # target_classes = targets['timestamp_window'] # soft cls.
+ target_classes.float()
+ # pdb.set_trace()
+
+ weights = torch.zeros_like(target_classes).float()
+ weights[mask] = self.empty_weight[1]
+ weights[mask_valid] = self.empty_weight[0]
+
+ # pdb.set_trace()
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ mask = targets['timestamp_mask']
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ neg_indices_in = (saliency_scores < selected_scores)
+ neg_indices_in[batch_indices, pos_indices] = True
+ mask_invalid = neg_indices_in * mask.bool()
+
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ loss_saliency_intra = - loss_in_i - loss_in_j
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, hl_only=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
\ No newline at end of file
diff --git a/model/base_albef.py b/model/base_albef.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f067b66558cfb3e380126ea8e9740cb062790b4
--- /dev/null
+++ b/model/base_albef.py
@@ -0,0 +1,478 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder import build_transformer, Transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer_mm, transformer_v, transformer_t, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer_mm
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer_mm.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ self.transformer_v = transformer_v
+ self.transformer_t = transformer_t
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # pos embed.
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+
+ src_vid = self.transformer_v(src_vid, ~src_vid_mask.bool(), pos_vid)
+ src_txt = self.transformer_t(src_txt, ~src_txt_mask.bool(), pos_txt)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ mask = targets['timestamp_mask'].bool()
+ mask_valid = targets['timestamp_window'].bool()
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[mask_valid] = 1
+ # target_classes = targets['timestamp_window'] # soft cls.
+ target_classes.float()
+ # pdb.set_trace()
+
+ weights = torch.zeros_like(target_classes).float()
+ weights[mask] = self.empty_weight[1]
+ weights[mask_valid] = self.empty_weight[0]
+
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ return {"loss_f": loss_ce.sum() / mask.sum()}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ mask = targets['timestamp_mask']
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ neg_indices_in = (saliency_scores < selected_scores)
+ neg_indices_in[batch_indices, pos_indices] = True
+ mask_invalid = neg_indices_in * mask.bool()
+
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ loss_saliency_intra = - loss_in_i - loss_in_j
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, hl_only=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer_mm = build_transformer(args)
+ transformer_v = Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.sub_enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ )
+ transformer_t = Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.sub_enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ )
+ # pdb.set_trace()
+
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer_mm,
+ transformer_v,
+ transformer_t,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
\ No newline at end of file
diff --git a/model/base_droppath.py b/model/base_droppath.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be7420925310f7f6e29b6a46df788b93227b4d4
--- /dev/null
+++ b/model/base_droppath.py
@@ -0,0 +1,449 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder_droppath import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ mask = targets['timestamp_mask'].bool()
+ mask_valid = targets['timestamp_window'].bool()
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[mask_valid] = 1
+ # target_classes = targets['timestamp_window'] # soft cls.
+ target_classes.float()
+ # pdb.set_trace()
+
+ weights = torch.zeros_like(target_classes).float()
+ weights[mask] = self.empty_weight[1]
+ weights[mask_valid] = self.empty_weight[0]
+
+ # pdb.set_trace()
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ mask = targets['timestamp_mask']
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ neg_indices_in = (saliency_scores < selected_scores)
+ neg_indices_in[batch_indices, pos_indices] = True
+ mask_invalid = neg_indices_in * mask.bool()
+
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ loss_saliency_intra = - loss_in_i - loss_in_j
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, hl_only=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
diff --git a/model/base_droppath_ablation.py b/model/base_droppath_ablation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3fb06a293c0a7130715d53cecc9b98406d70fdf
--- /dev/null
+++ b/model/base_droppath_ablation.py
@@ -0,0 +1,474 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder_droppath import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ weight_abalation_b = targets['weight_ablation'][:,0].unsqueeze(-1)
+ if weight_abalation_b.sum() == 0:
+ return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()}
+
+ mask_valid = (mask_valid * weight_abalation_b).bool()
+ mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool()
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ mask = targets['timestamp_mask'].bool()
+ mask_valid = targets['timestamp_window'].bool()
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[mask_valid] = 1
+ # target_classes = targets['timestamp_window'] # soft cls.
+ target_classes.float()
+ # pdb.set_trace()
+
+ weights = torch.zeros_like(target_classes).float()
+ weights[mask] = self.empty_weight[1]
+ weights[mask_valid] = self.empty_weight[0]
+
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+
+ weight_abalation_f = targets['weight_ablation'][:,2].unsqueeze(-1)
+ if weight_abalation_f.sum() == 0:
+ return {"loss_f": torch.tensor(0).cuda()}
+
+ mask = mask * weight_abalation_f
+ loss_ce = loss_ce * weight_abalation_f
+ return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+
+ weight_abalation_s = targets['weight_ablation'][:,3].bool()
+ if weight_abalation_s.sum() == 0:
+ return {"loss_s_inter": torch.tensor(0).cuda(),
+ "loss_s_intra": torch.tensor(0).cuda()}
+
+ _idiag = idiag[weight_abalation_s]
+ _jdiag = jdiag[weight_abalation_s]
+
+ loss_i = _idiag.sum() / len(_idiag)
+ loss_j = _jdiag.sum() / len(_jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ mask = targets['timestamp_mask']
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ neg_indices_in = (saliency_scores < selected_scores)
+ neg_indices_in[batch_indices, pos_indices] = True
+ mask_invalid = neg_indices_in * mask.bool()
+
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ _pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s]
+ _pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s]
+
+ loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i)
+ loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j)
+
+ loss_saliency_intra = - loss_in_i - loss_in_j
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, hl_only=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
\ No newline at end of file
diff --git a/model/base_droppath_qfvs.py b/model/base_droppath_qfvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..7236540c8ae0ada454a66cfc8d1a0018e00287e8
--- /dev/null
+++ b/model/base_droppath_qfvs.py
@@ -0,0 +1,476 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder_droppath import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_f": 0.}
+
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ target_classes = targets["saliency_scores"].squeeze()
+
+ weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
+ weights[target_classes.bool()] = self.empty_weight[0]
+
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
+ return {"loss_f": loss_ce.sum() / target_classes.sum()}
+ # return {"loss_f": loss_ce.sum() / len(target_classes)}
+
+ # mask = targets['timestamp_mask'].bool()
+ # mask_valid = targets['timestamp_window'].bool()
+ # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ # target_classes[mask_valid] = 1
+ # # target_classes = targets['timestamp_window'] # soft cls.
+ # target_classes.float()
+ # # pdb.set_trace()
+
+ # weights = torch.zeros_like(target_classes).float()
+ # weights[mask] = self.empty_weight[1]
+ # weights[mask_valid] = self.empty_weight[0]
+
+ # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ # # return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / mask_valid.sum()}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * qfvs mil-nce mode
+ pos_indices = saliency_scores.squeeze() > 0
+
+ sim = outputs['saliency_scores']
+ sim_soft = F.softmax(sim / self.temperature, dim=0)
+ sim_log = torch.log(sim_soft[pos_indices])
+ loss_saliency_intra = -sim_log.sum() / len(sim_log)
+ return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
+
+ # * inter-vid mode
+ # vid_mem_proj = outputs["vid_mem_proj"]
+ # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ # vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ # txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ # sim = sim_matrix(vid_feats, txt_feats)
+
+ # i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # # sum over positives
+ # idiag = torch.diag(i_logsm)
+ # jdiag = torch.diag(j_logsm)
+ # loss_i = idiag.sum() / len(idiag)
+ # loss_j = jdiag.sum() / len(jdiag)
+
+ # loss_saliency_inter = - loss_i - loss_j
+
+ # # * intra-vid mode
+ # mask = targets['timestamp_mask']
+ # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ # neg_indices_in = (saliency_scores < selected_scores)
+ # neg_indices_in[batch_indices, pos_indices] = True
+ # mask_invalid = neg_indices_in * mask.bool()
+
+ # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ # sim_in = sim_in + (mask_invalid + 1e-45).log()
+ # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ # loss_saliency_intra = - loss_in_i - loss_in_j
+
+ # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, mask_GT=None):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
+ count = mask_GT.sum()
+ outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
+ # targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
+ targets['saliency_scores'] = targets['saliency_scores'][0,:count]
+
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
\ No newline at end of file
diff --git a/model/base_prompt.py b/model/base_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..5816b7429f3c8be69ca8c3f4322a11ade60b8217
--- /dev/null
+++ b/model/base_prompt.py
@@ -0,0 +1,460 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.prompt_learner = nn.Embedding(10, hidden_dim)
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ src_prompt = self.prompt_learner.weight.unsqueeze(0).repeat(bs, 1, 1)
+ src_prompt_mask = torch.ones((bs, src_prompt.shape[1])).cuda()
+
+ if self.training:
+ # src_txt = src_prompt
+ # src_txt_mask = torch.ones_like(src_prompt).cuda()
+ src_txt = torch.cat([src_prompt, src_txt], dim=1)
+ src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1)
+ else:
+ src_txt = torch.cat([src_prompt, src_txt], dim=1)
+ src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ mask = targets['timestamp_mask'].bool()
+ mask_valid = targets['timestamp_window'].bool()
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[mask_valid] = 1
+ # target_classes = targets['timestamp_window'] # soft cls.
+ target_classes.float()
+ # pdb.set_trace()
+
+ weights = torch.zeros_like(target_classes).float()
+ weights[mask] = self.empty_weight[1]
+ weights[mask_valid] = self.empty_weight[0]
+
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ return {"loss_f": loss_ce.sum() / mask.sum()}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ mask = targets['timestamp_mask']
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ neg_indices_in = (saliency_scores < selected_scores)
+ neg_indices_in[batch_indices, pos_indices] = True
+ mask_invalid = neg_indices_in * mask.bool()
+
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ loss_saliency_intra = - loss_in_i - loss_in_j
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, hl_only=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
\ No newline at end of file
diff --git a/model/base_qfvs.py b/model/base_qfvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e9553ef35763eebd97125964c7c4864b9dd7db
--- /dev/null
+++ b/model/base_qfvs.py
@@ -0,0 +1,476 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder_droppath import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_f": 0.}
+
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ target_classes = targets["saliency_scores"].squeeze()
+
+ weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
+ weights[target_classes.bool()] = self.empty_weight[0]
+
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
+ # pdb.set_trace()
+ return {"loss_f": loss_ce.sum() / target_classes.sum()}
+ # return {"loss_f": loss_ce.sum() / len(target_classes)}
+
+ # mask = targets['timestamp_mask'].bool()
+ # mask_valid = targets['timestamp_window'].bool()
+ # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ # target_classes[mask_valid] = 1
+ # # target_classes = targets['timestamp_window'] # soft cls.
+ # target_classes.float()
+ # # pdb.set_trace()
+
+ # weights = torch.zeros_like(target_classes).float()
+ # weights[mask] = self.empty_weight[1]
+ # weights[mask_valid] = self.empty_weight[0]
+
+ # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ # # return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / mask_valid.sum()}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * qfvs mil-nce mode
+ pos_indices = saliency_scores.squeeze() > 0
+
+ sim = outputs['saliency_scores']
+ sim_soft = F.softmax(sim / self.temperature, dim=0)
+ sim_log = torch.log(sim_soft[pos_indices])
+ loss_saliency_intra = -sim_log.sum() / len(sim_log)
+ return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
+
+ # * inter-vid mode
+ # vid_mem_proj = outputs["vid_mem_proj"]
+ # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ # vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ # txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ # sim = sim_matrix(vid_feats, txt_feats)
+
+ # i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # # sum over positives
+ # idiag = torch.diag(i_logsm)
+ # jdiag = torch.diag(j_logsm)
+ # loss_i = idiag.sum() / len(idiag)
+ # loss_j = jdiag.sum() / len(jdiag)
+
+ # loss_saliency_inter = - loss_i - loss_j
+
+ # # * intra-vid mode
+ # mask = targets['timestamp_mask']
+ # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ # neg_indices_in = (saliency_scores < selected_scores)
+ # neg_indices_in[batch_indices, pos_indices] = True
+ # mask_invalid = neg_indices_in * mask.bool()
+
+ # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ # sim_in = sim_in + (mask_invalid + 1e-45).log()
+ # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ # loss_saliency_intra = - loss_in_i - loss_in_j
+
+ # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, mask_GT=None):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ # pdb.set_trace()
+ outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
+ outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
+ targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
+
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
\ No newline at end of file
diff --git a/model/matcher.py b/model/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7a27af5fb40c43bfdc54a4277c022f0b040fe0c
--- /dev/null
+++ b/model/matcher.py
@@ -0,0 +1,107 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import torch
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+import torch.nn.functional as F
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+
+class HungarianMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network
+
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+ while the others are un-matched (and thus treated as non-objects).
+ """
+ def __init__(self, cost_class: float = 1, cost_span: float = 1, cost_giou: float = 1,
+ span_loss_type: str = "l1", max_v_l: int = 75):
+ """Creates the matcher
+
+ Params:
+ cost_span: This is the relative weight of the L1 error of the span coordinates in the matching cost
+ cost_giou: This is the relative weight of the giou loss of the spans in the matching cost
+ """
+ super().__init__()
+ self.cost_class = cost_class
+ self.cost_span = cost_span
+ self.cost_giou = cost_giou
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.foreground_label = 0
+ assert cost_class != 0 or cost_span != 0 or cost_giou != 0, "all costs cant be 0"
+
+ @torch.no_grad()
+ def forward(self, outputs, targets):
+ """ Performs the matching
+
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_spans": Tensor of dim [batch_size, num_queries, 2] with the predicted span coordinates,
+ in normalized (cx, w) format
+ ""pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+ "spans": Tensor of dim [num_target_spans, 2] containing the target span coordinates. The spans are
+ in normalized (cx, w) format
+
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_spans)
+ """
+ bs, num_queries = outputs["pred_spans"].shape[:2]
+ targets = targets["span_labels"]
+
+ # Also concat the target labels and spans
+ out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
+ tgt_spans = torch.cat([v["spans"] for v in targets]) # [num_target_spans in batch, 2]
+ tgt_ids = torch.full([len(tgt_spans)], self.foreground_label) # [total #spans in the batch]
+
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - prob[target class].
+ # The 1 is a constant that doesn't change the matching, it can be omitted.
+ cost_class = -out_prob[:, tgt_ids] # [batch_size * num_queries, total #spans in the batch]
+
+ if self.span_loss_type == "l1":
+ # We flatten to compute the cost matrices in a batch
+ out_spans = outputs["pred_spans"].flatten(0, 1) # [batch_size * num_queries, 2]
+
+ # Compute the L1 cost between spans
+ cost_span = torch.cdist(out_spans, tgt_spans, p=1) # [batch_size * num_queries, total #spans in the batch]
+
+ # Compute the giou cost between spans
+ # [batch_size * num_queries, total #spans in the batch]
+ cost_giou = - generalized_temporal_iou(span_cxw_to_xx(out_spans), span_cxw_to_xx(tgt_spans))
+ else:
+ pred_spans = outputs["pred_spans"] # (bsz, #queries, max_v_l * 2)
+ pred_spans = pred_spans.view(bs * num_queries, 2, self.max_v_l).softmax(-1) # (bsz * #queries, 2, max_v_l)
+ cost_span = - pred_spans[:, 0][:, tgt_spans[:, 0]] - \
+ pred_spans[:, 1][:, tgt_spans[:, 1]] # (bsz * #queries, #spans)
+ # pred_spans = pred_spans.repeat(1, n_spans, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, max_v_l, 2)
+ # tgt_spans = tgt_spans.view(1, n_spans, 2).repeat(bs * num_queries, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, 2)
+ # cost_span = pred_spans[tgt_spans]
+ # cost_span = cost_span.view(bs * num_queries, n_spans)
+
+ # giou
+ cost_giou = 0
+
+ # Final cost matrix
+ # import ipdb; ipdb.set_trace()
+ C = self.cost_span * cost_span + self.cost_giou * cost_giou + self.cost_class * cost_class
+ C = C.view(bs, num_queries, -1).cpu()
+
+ sizes = [len(v["spans"]) for v in targets]
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+def build_matcher(args):
+ return HungarianMatcher(
+ cost_span=args.set_cost_span, cost_giou=args.set_cost_giou,
+ cost_class=args.set_cost_class, span_loss_type=args.span_loss_type, max_v_l=args.max_v_l
+ )
diff --git a/model/moment_detr.py b/model/moment_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..29fcfcf8ca078a59f8078e176c1c361de6a2d460
--- /dev/null
+++ b/model/moment_detr.py
@@ -0,0 +1,462 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+DETR model and criterion classes.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+from model.transformer import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k
+ output: (#items, #classes)
+ target: int,
+ """
+ maxk = max(topk)
+ num_items = output.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target)
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / num_items))
+ return res
+
+class Model(nn.Module):
+ """ This is the Moment-DETR module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ num_queries, input_dropout, aux_loss=False,
+ contrastive_align_loss=False, contrastive_hdim=64,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+ Moment-DETR can detect in a single video.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ contrastive_align_loss: If true, perform span - tokens contrastive learning
+ contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.num_queries = num_queries
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+ self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
+ self.class_embed = nn.Linear(hidden_dim, 2) # 0: background, 1: foreground
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ # self.foreground_thd = foreground_thd
+ # self.background_thd = background_thd
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.contrastive_align_loss = contrastive_align_loss
+ if contrastive_align_loss:
+ self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
+ self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
+ self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)
+
+ self.saliency_proj = nn.Linear(hidden_dim, 1)
+ self.aux_loss = aux_loss
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask):
+ """The forward expects two tensors:
+ - src_txt: [batch_size, L_txt, D_txt]
+ - src_txt_mask: [batch_size, L_txt], containing 0 on padded pixels,
+ will convert to 1 as padding later for transformer
+ - src_vid: [batch_size, L_vid, D_vid]
+ - src_vid_mask: [batch_size, L_vid], containing 0 on padded pixels,
+ will convert to 1 as padding later for transformer
+
+ It returns a dict with the following elements:
+ - "pred_spans": The normalized boxes coordinates for all queries, represented as
+ (center_x, width). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
+ dictionnaries containing the two above keys for each decoder layer.
+ """
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+ # TODO should we remove or use different positional embeddings to the src_txt?
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ # pos_txt = torch.zeros_like(src_txt)
+ # pad zeros for txt positions
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+ # (#layers, bsz, #queries, d), (bsz, L_vid+L_txt, d)
+ hs, memory = self.transformer(src, ~mask, self.query_embed.weight, pos)
+ outputs_class = self.class_embed(hs) # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(hs) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}
+
+ txt_mem = memory[:, src_vid.shape[1]:] # (bsz, L_txt, d)
+ vid_mem = memory[:, :src_vid.shape[1]] # (bsz, L_vid, d)
+ if self.contrastive_align_loss:
+ proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
+ proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
+ proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
+ out.update(dict(
+ proj_queries=proj_queries[-1],
+ proj_txt_mem=proj_txt_mem,
+ proj_vid_mem=proj_vid_mem
+ ))
+
+ out["saliency_scores"] = self.saliency_proj(vid_mem).squeeze(-1) # (bsz, L_vid)
+
+ if self.aux_loss:
+ # assert proj_queries and proj_txt_mem
+ out['aux_outputs'] = [
+ {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+ if self.contrastive_align_loss:
+ assert proj_queries is not None
+ for idx, d in enumerate(proj_queries[:-1]):
+ out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
+ return out
+
+ # @torch.jit.unused
+ # def _set_aux_loss(self, outputs_class, outputs_coord):
+ # # this is a workaround to make torchscript happy, as torchscript
+ # # doesn't support dictionary with non-homogeneous values, such
+ # # as a dict having both a Tensor and a list.
+ # return [{'pred_logits': a, 'pred_spans': b}
+ # for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+ targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2]
+ The target spans are expected in format (center_x, w), normalized by the image size.
+ """
+ assert 'pred_spans' in outputs
+ targets = targets["span_labels"]
+ idx = self._get_src_permutation_idx(indices)
+ src_spans = outputs['pred_spans'][idx] # (#spans, max_v_l * 2)
+ tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0) # (#spans, 2)
+ if self.span_loss_type == "l1":
+ loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
+ else: # ce
+ n_spans = src_spans.shape[0]
+ src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
+ loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')
+
+ # giou
+ # src_span_indices = src_spans.max(1)[1] # (#spans, 2)
+ # src_span_indices[:, 1] += 1 # ed non-inclusive [st, ed)
+ #
+ # tgt_span_indices = tgt_spans
+ # tgt_span_indices[:, 1] += 1
+ # loss_giou = 1 - torch.diag(generalized_temporal_iou(src_span_indices, tgt_span_indices))
+ loss_giou = loss_span.new_zeros([1])
+
+ losses = {}
+ losses['loss_b'] = loss_span.mean()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ """Classification loss (NLL)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ # TODO add foreground and background classifier. use all non-matched as background.
+ assert 'pred_logits' in outputs
+ src_logits = outputs['pred_logits'] # (batch_size, #queries, #classes=2)
+ # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch
+ idx = self._get_src_permutation_idx(indices)
+ target_classes = torch.full(src_logits.shape[:2], self.background_label,
+ dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[idx] = self.foreground_label
+
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none")
+ losses = {'loss_f': loss_ce.mean()}
+
+ if log:
+ # TODO this should probably be a separate loss, not hacked in this one here
+ losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0]
+ return losses
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_intra": 0}
+ saliency_scores = outputs["saliency_scores"] # (N, L)
+ pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
+ neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
+ num_pairs = pos_indices.shape[1] # typically 2 or 4
+ batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
+ pos_scores = torch.stack(
+ [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
+ neg_scores = torch.stack(
+ [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
+ loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
+ / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
+ return {"loss_s_intra": loss_saliency}
+
+ def loss_contrastive_align(self, outputs, targets, indices, log=True):
+ """encourage higher scores between matched query span and input text"""
+ normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
+ normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
+ logits = torch.einsum(
+ "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
+ logits = logits.sum(2) / self.temperature # (bsz, #queries)
+ idx = self._get_src_permutation_idx(indices)
+ positive_map = torch.zeros_like(logits, dtype=torch.bool)
+ positive_map[idx] = True
+ positive_logits = logits.masked_fill(~positive_map, 0)
+
+ pos_term = positive_logits.sum(1) # (bsz, )
+ num_pos = positive_map.sum(1) # (bsz, )
+ neg_term = logits.logsumexp(1) # (bsz, )
+ loss_nce = - pos_term / num_pos + neg_term # (bsz, )
+ losses = {"loss_contrastive_align": loss_nce.mean()}
+ return losses
+
+ def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
+ """encourage higher scores between matched query span and input text"""
+ # TODO (1) align vid_mem and txt_mem;
+ # TODO (2) change L1 loss as CE loss on 75 labels, similar to soft token prediction in MDETR
+ normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
+ normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
+ logits = torch.einsum(
+ "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
+ logits = logits.sum(2) / self.temperature # (bsz, #queries)
+ idx = self._get_src_permutation_idx(indices)
+ positive_map = torch.zeros_like(logits, dtype=torch.bool)
+ positive_map[idx] = True
+ positive_logits = logits.masked_fill(~positive_map, 0)
+
+ pos_term = positive_logits.sum(1) # (bsz, )
+ num_pos = positive_map.sum(1) # (bsz, )
+ neg_term = logits.logsumexp(1) # (bsz, )
+ loss_nce = - pos_term / num_pos + neg_term # (bsz, )
+ losses = {"loss_contrastive_align": loss_nce.mean()}
+ return losses
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx # two 1D tensors of the same length
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "contrastive_align": self.loss_contrastive_align,
+ "saliency": self.loss_saliency,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ # list(tuples), each tuple is (pred_span_indices, tgt_span_indices)
+ indices = self.matcher(outputs_without_aux, targets)
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if 'aux_outputs' in outputs:
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
+ indices = self.matcher(aux_outputs, targets)
+ for loss in self.losses:
+ if "saliency" == loss: # skip as it is only in the top layer
+ continue
+ kwargs = {}
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs)
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ # the `num_classes` naming here is somewhat misleading.
+ # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
+ # is the maximum id for a class in your dataset. For example,
+ # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
+ # As another example, for a dataset that has a single class with id 1,
+ # you should pass `num_classes` to be 2 (max_obj_id + 1).
+ # For more details on this, check the following discussion
+ # https://github.com/facebookresearch/moment_bert/issues/108#issuecomment-650269223
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ num_queries=args.num_queries,
+ input_dropout=args.input_dropout,
+ aux_loss=args.aux_loss,
+ # contrastive_align_loss=args.contrastive_align_loss,
+ # contrastive_hdim=args.contrastive_hdim,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+ # if args.contrastive_align_loss:
+ # weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef
+ # TODO this is a hack
+ if args.aux_loss:
+ aux_weight_dict = {}
+ for i in range(args.dec_layers - 1):
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"})
+ weight_dict.update(aux_weight_dict)
+
+ losses = ['spans', 'labels', 'saliency']
+ # if args.contrastive_align_loss:
+ # losses += ["contrastive_align"]
+ criterion = SetCriterion(
+ matcher=matcher, weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin
+ )
+ criterion.to(device)
+ return model, criterion
diff --git a/model/position_encoding.py b/model/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9bad0b7867faede6179cd27e0a7c859137dcb8
--- /dev/null
+++ b/model/position_encoding.py
@@ -0,0 +1,126 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+import numpy as np
+
+def PositionalEncoding(n_position, d_hid):
+ def get_position_angle_vec(position, d_hid):
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+ sinusoid_table = np.array([get_position_angle_vec(pos_i, d_hid) for pos_i in range(n_position)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+ return torch.FloatTensor(sinusoid_table) # shape:(1, maxLen(n_position), d_hid)
+
+class TrainablePositionalEncoding(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings.
+ """
+ def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
+ super(TrainablePositionalEncoding, self).__init__()
+ self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
+ self.LayerNorm = nn.LayerNorm(hidden_size)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, input_feat):
+ """
+ Args:
+ input_feat: (N, L, D)
+ """
+ bsz, seq_length = input_feat.shape[:2]
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
+ position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
+
+ position_embeddings = self.position_embeddings(position_ids)
+
+ embeddings = self.LayerNorm(input_feat + position_embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images. (To 1D sequences)
+ """
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask):
+ """
+ Args:
+ x: torch.tensor, (batch_size, L, d)
+ mask: torch.tensor, (batch_size, L), with 1 as valid
+
+ Returns:
+
+ """
+ assert mask is not None
+ x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L)
+ if self.normalize:
+ eps = 1e-6
+ x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ # import pdb; pdb.set_trace()
+ # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2).int() / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats)
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2)
+ # import ipdb; ipdb.set_trace()
+ return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L)
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, x, mask):
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return pos
+
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim
+ if args.position_embedding in ('v2', 'sine'):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+ # elif args.position_embedding in ('v3', 'learned'):
+ # position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ txt_pos_embed = TrainablePositionalEncoding(
+ max_position_embeddings=args.max_q_l,
+ hidden_size=args.hidden_dim, dropout=args.input_dropout)
+ return position_embedding, txt_pos_embed
diff --git a/model/transformer.py b/model/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a4958ec06d8c1fd1f091362d2c5a22714ed0714
--- /dev/null
+++ b/model/transformer.py
@@ -0,0 +1,471 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+DETR Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False,
+ return_intermediate_dec=False):
+ super().__init__()
+
+ # TransformerEncoderLayerThin
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
+ dropout, activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ # TransformerDecoderLayerThin
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
+ dropout, activation, normalize_before)
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
+ return_intermediate=return_intermediate_dec)
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, query_embed, pos_embed):
+ """
+ Args:
+ src: (batch_size, L, d)
+ mask: (batch_size, L)
+ query_embed: (#queries, d)
+ pos_embed: (batch_size, L, d) the same as src
+
+ Returns:
+
+ """
+ # flatten NxCxHxW to HWxNxC
+ bs, l, d = src.shape
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (#queries, batch_size, d)
+
+ tgt = torch.zeros_like(query_embed)
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # (L, batch_size, d)
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
+ pos=pos_embed, query_pos=query_embed) # (#layers, #queries, batch_size, d)
+ hs = hs.transpose(1, 2) # (#layers, batch_size, #qeries, d)
+ # memory = memory.permute(1, 2, 0) # (batch_size, d, L)
+ memory = memory.transpose(0, 1) # (batch_size, L, d)
+ return hs, memory
+
+
+class TransformerEncoder(nn.Module):
+
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(self, src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ output = src
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(output, src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
+ if self.return_intermediate:
+ intermediate.append(output)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ output = tgt
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(output, memory, tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos, query_pos=query_pos)
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayerThin(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
+ # self.dropout = nn.Dropout(dropout)
+ # self.linear2 = nn.Linear(dim_feedforward, d_model)
+ self.linear = nn.Linear(d_model, d_model)
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ # self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src2 = self.linear(src2)
+ src = src + self.dropout(src2)
+ src = self.norm(src)
+ # src = src + self.dropout1(src2)
+ # src = self.norm1(src)
+ # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ # src = src + self.dropout2(src2)
+ # src = self.norm2(src)
+ return src
+
+ def forward_pre(self, src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ """not used"""
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(self, src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerEncoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(self, src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(self, src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+
+
+class TransformerDecoderLayerThin(nn.Module):
+ """removed intermediate layer"""
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, d_model)
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
+ # self.dropout = nn.Dropout(dropout)
+ # self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ # self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ # self.dropout3 = nn.Dropout(dropout)
+
+ # self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt2 = self.linear1(tgt2)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ # tgt = tgt + self.dropout2(tgt2)
+ # tgt = self.norm2(tgt)
+ # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ # tgt = tgt + self.dropout3(tgt2)
+ # tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_transformer(args):
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ )
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
diff --git a/model/transformer_encoder.py b/model/transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb912bfe43022280e9e8ad3c1bcbaf926bf2ebec
--- /dev/null
+++ b/model/transformer_encoder.py
@@ -0,0 +1,159 @@
+import copy
+import pdb
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=4,
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False, # False as default
+ return_intermediate_dec=False):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
+ dropout, activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, pos_embed):
+ """
+ Args:
+ src: (batch_size, L, d)
+ mask: (batch_size, L)
+ query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1
+ pos_embed: (batch_size, L, d) the same as src
+
+ Returns:
+
+ """
+ # flatten NxCxHxW to HWxNxC
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
+
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ memory = memory.transpose(0, 1)
+
+ return memory
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(self, src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ output = src
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(output, src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
+ if self.return_intermediate:
+ intermediate.append(output)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output
+
+
+class TransformerEncoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward(self, src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_transformer(args):
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ )
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
diff --git a/model/transformer_encoder_droppath.py b/model/transformer_encoder_droppath.py
new file mode 100644
index 0000000000000000000000000000000000000000..536d46529a4722c1bf787d85fbace0afe1e3a33b
--- /dev/null
+++ b/model/transformer_encoder_droppath.py
@@ -0,0 +1,194 @@
+import copy
+import pdb
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=4,
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, droppath=0.1,
+ activation="gelu", normalize_before=False, # False as default
+ return_intermediate_dec=False):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
+ dropout, droppath, activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, pos_embed):
+ """
+ Args:
+ src: (batch_size, L, d)
+ mask: (batch_size, L)
+ query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1
+ pos_embed: (batch_size, L, d) the same as src
+
+ Returns:
+
+ """
+ # flatten NxCxHxW to HWxNxC
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
+
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ memory = memory.transpose(0, 1)
+
+ return memory
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(self, src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ output = src
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(output, src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
+ if self.return_intermediate:
+ intermediate.append(output)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output
+
+class TransformerEncoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, droppath=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ # self.dropout1 = nn.Dropout(dropout)
+ # self.dropout2 = nn.Dropout(dropout)
+ self.droppath1 = DropPath(droppath)
+ self.droppath2 = DropPath(droppath)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+ # src2 = self.self_attn_eff(q=q, k=k, v=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.droppath1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.activation(self.linear1(src)))
+ # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.droppath2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward(self, src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_transformer(args):
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ droppath=args.droppath,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ )
+
+def drop_path(x, drop_prob=0.0, training=False):
+ """
+ Stochastic Depth per sample.
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ mask.floor_()
+ x = x.div(keep_prob) * mask
+
+ return x
+
+
+class DropPath(nn.Module):
+ """
+ Drop paths per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ x = x.permute(1, 0, 2)
+ res = drop_path(x, self.drop_prob, self.training)
+ return res.permute(1, 0, 2)
+ # return drop_path(x, self.drop_prob, self.training)
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
\ No newline at end of file
diff --git a/model/univtg.py b/model/univtg.py
new file mode 100644
index 0000000000000000000000000000000000000000..607f8ad325ce6697ca3e49d911447489fa407f7f
--- /dev/null
+++ b/model/univtg.py
@@ -0,0 +1,450 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder_droppath import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+ device_id = src_vid.device
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).to(device_id)
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ mask = targets['timestamp_mask'].bool()
+ mask_valid = targets['timestamp_window'].bool()
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[mask_valid] = 1
+ # target_classes = targets['timestamp_window'] # soft cls.
+ target_classes.float()
+ # pdb.set_trace()
+
+ weights = torch.zeros_like(target_classes).float()
+ weights[mask] = self.empty_weight[1]
+ weights[mask_valid] = self.empty_weight[0]
+
+ # pdb.set_trace()
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ mask = targets['timestamp_mask']
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ neg_indices_in = (saliency_scores < selected_scores)
+ neg_indices_in[batch_indices, pos_indices] = True
+ mask_invalid = neg_indices_in * mask.bool()
+
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ loss_saliency_intra = - loss_in_i - loss_in_j
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, hl_only=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
diff --git a/model/univtg_ablation.py b/model/univtg_ablation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3fb06a293c0a7130715d53cecc9b98406d70fdf
--- /dev/null
+++ b/model/univtg_ablation.py
@@ -0,0 +1,474 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder_droppath import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ weight_abalation_b = targets['weight_ablation'][:,0].unsqueeze(-1)
+ if weight_abalation_b.sum() == 0:
+ return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()}
+
+ mask_valid = (mask_valid * weight_abalation_b).bool()
+ mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool()
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ mask = targets['timestamp_mask'].bool()
+ mask_valid = targets['timestamp_window'].bool()
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ target_classes[mask_valid] = 1
+ # target_classes = targets['timestamp_window'] # soft cls.
+ target_classes.float()
+ # pdb.set_trace()
+
+ weights = torch.zeros_like(target_classes).float()
+ weights[mask] = self.empty_weight[1]
+ weights[mask_valid] = self.empty_weight[0]
+
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+
+ weight_abalation_f = targets['weight_ablation'][:,2].unsqueeze(-1)
+ if weight_abalation_f.sum() == 0:
+ return {"loss_f": torch.tensor(0).cuda()}
+
+ mask = mask * weight_abalation_f
+ loss_ce = loss_ce * weight_abalation_f
+ return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+
+ weight_abalation_s = targets['weight_ablation'][:,3].bool()
+ if weight_abalation_s.sum() == 0:
+ return {"loss_s_inter": torch.tensor(0).cuda(),
+ "loss_s_intra": torch.tensor(0).cuda()}
+
+ _idiag = idiag[weight_abalation_s]
+ _jdiag = jdiag[weight_abalation_s]
+
+ loss_i = _idiag.sum() / len(_idiag)
+ loss_j = _jdiag.sum() / len(_jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ mask = targets['timestamp_mask']
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ neg_indices_in = (saliency_scores < selected_scores)
+ neg_indices_in[batch_indices, pos_indices] = True
+ mask_invalid = neg_indices_in * mask.bool()
+
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ _pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s]
+ _pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s]
+
+ loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i)
+ loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j)
+
+ loss_saliency_intra = - loss_in_i - loss_in_j
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, hl_only=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
\ No newline at end of file
diff --git a/model/univtg_qfvs.py b/model/univtg_qfvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4b133857affa8f34e4c77e753539291cb96a75
--- /dev/null
+++ b/model/univtg_qfvs.py
@@ -0,0 +1,476 @@
+import pdb
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+
+from model.transformer_encoder_droppath import build_transformer
+from model.matcher import build_matcher
+from model.position_encoding import build_position_encoding
+from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
+
+def init_weights(module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+ added eps for numerical stability
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+class WeightedPool(nn.Module):
+ def __init__(self, dim):
+ super(WeightedPool, self).__init__()
+ weight = torch.empty(dim, 1)
+ nn.init.xavier_uniform_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def forward(self, x, mask):
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
+ alphas = nn.Softmax(dim=1)(alpha)
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
+ pooled_x = pooled_x.squeeze(2)
+ return pooled_x
+
+class Model(nn.Module):
+ """ This is the UniVTG module that performs moment localization. """
+
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
+ input_dropout, aux_loss=False,
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture. See transformer.py
+ position_embed: torch module of the position_embedding, See position_encoding.py
+ txt_position_embed: position_embedding for text
+ txt_dim: int, text query input dimension
+ vid_dim: int, video feature input dimension
+ max_v_l: int, maximum #clips in videos
+ span_loss_type: str, one of [l1, ce]
+ l1: (center-x, width) regression.
+ ce: (st_idx, ed_idx) classification.
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
+ """
+ super().__init__()
+ self.transformer = transformer
+ self.position_embed = position_embed
+ self.txt_position_embed = txt_position_embed
+ hidden_dim = transformer.d_model
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
+
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
+ self.token_type_embeddings.apply(init_weights)
+
+ # Conv projector
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
+
+ self.use_txt_pos = use_txt_pos
+ self.n_input_proj = n_input_proj
+ relu_args = [True] * 3
+ relu_args[n_input_proj-1] = False
+ self.input_txt_proj = nn.Sequential(*[
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+ self.input_vid_proj = nn.Sequential(*[
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
+ ][:n_input_proj])
+
+ # MLP Projector
+ self.weightedpool = WeightedPool(hidden_dim)
+
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
+ bs = src_vid.shape[0]
+ src_vid = self.input_vid_proj(src_vid)
+ src_txt = self.input_txt_proj(src_txt)
+ if src_cls is not None:
+ src_cls = self.input_txt_proj(src_cls)
+
+ # type token.
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
+ if src_cls is not None:
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
+
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
+
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
+
+ memory = self.transformer(src, ~mask, pos)
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
+
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
+
+ if self.span_loss_type == "l1":
+ outputs_coord = outputs_coord.sigmoid()
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
+ outputs_coord = outputs_coord * idx_mask
+ else:
+ raise NotImplementedError
+
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
+ 'src_vid_mask': src_vid_mask}
+
+ vid_mem_proj = src_vid
+
+ # word-level -> sentence-level
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
+
+ out["vid_mem_proj"] = vid_mem_proj
+ out["txt_mem_proj"] = txt_mem_proj
+ if src_cls is not None:
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
+ out["cls_mem_proj"] = cls_mem_proj
+ out["saliency_scores"] = sim
+ return out
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
+ saliency_margin=1):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ temperature: float, temperature for NCE loss
+ span_loss_type: str, [l1, ce]
+ max_v_l: int,
+ saliency_margin: float
+ """
+ super().__init__()
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.temperature = temperature
+ self.span_loss_type = span_loss_type
+ self.max_v_l = max_v_l
+ self.saliency_margin = saliency_margin
+ self.temperature = 0.07
+
+ # foreground and background classification
+ self.foreground_label = 0
+ self.background_label = 1
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(2)
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
+ self.register_buffer('empty_weight', empty_weight)
+
+ def loss_spans(self, outputs, targets, indices):
+ assert 'pred_spans' in outputs
+
+ start_spans = targets['timestamp']
+ pred_spans = outputs['pred_spans']
+ src_spans = start_spans + pred_spans
+ gt_spans = targets['span_labels_nn']
+
+ mask = targets['timestamp_mask'].bool()
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
+ mask_valid = targets['timestamp_window'].bool()
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
+
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
+
+ losses = {}
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
+ losses['loss_g'] = loss_giou.mean()
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, log=True):
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_f": 0.}
+
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
+ target_classes = targets["saliency_scores"].squeeze()
+
+ weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
+ weights[target_classes.bool()] = self.empty_weight[0]
+
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
+ return {"loss_f": loss_ce.sum() / target_classes.sum()}
+ # return {"loss_f": loss_ce.sum() / len(target_classes)}
+
+ # mask = targets['timestamp_mask'].bool()
+ # mask_valid = targets['timestamp_window'].bool()
+ # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
+ # target_classes[mask_valid] = 1
+ # # target_classes = targets['timestamp_window'] # soft cls.
+ # target_classes.float()
+ # # pdb.set_trace()
+
+ # weights = torch.zeros_like(target_classes).float()
+ # weights[mask] = self.empty_weight[1]
+ # weights[mask_valid] = self.empty_weight[0]
+
+ # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
+ # # return {"loss_f": loss_ce.sum() / mask.sum()}
+ # return {"loss_f": loss_ce.sum() / mask_valid.sum()}
+
+ def loss_saliency(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * qfvs mil-nce mode
+ pos_indices = saliency_scores.squeeze() > 0
+
+ sim = outputs['saliency_scores']
+ sim_soft = F.softmax(sim / self.temperature, dim=0)
+ sim_log = torch.log(sim_soft[pos_indices])
+ loss_saliency_intra = -sim_log.sum() / len(sim_log)
+ return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
+
+ # * inter-vid mode
+ # vid_mem_proj = outputs["vid_mem_proj"]
+ # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ # vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ # txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ # sim = sim_matrix(vid_feats, txt_feats)
+
+ # i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # # sum over positives
+ # idiag = torch.diag(i_logsm)
+ # jdiag = torch.diag(j_logsm)
+ # loss_i = idiag.sum() / len(idiag)
+ # loss_j = jdiag.sum() / len(jdiag)
+
+ # loss_saliency_inter = - loss_i - loss_j
+
+ # # * intra-vid mode
+ # mask = targets['timestamp_mask']
+ # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
+ # neg_indices_in = (saliency_scores < selected_scores)
+ # neg_indices_in[batch_indices, pos_indices] = True
+ # mask_invalid = neg_indices_in * mask.bool()
+
+ # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
+ # sim_in = sim_in + (mask_invalid + 1e-45).log()
+ # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
+ # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
+
+ # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
+ # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
+ # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
+ # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
+
+ # loss_saliency_intra = - loss_in_i - loss_in_j
+
+ # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
+ """higher scores for positive clips"""
+ if "saliency_pos_labels" not in targets:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+ saliency_scores = targets["saliency_scores"]
+ if saliency_scores.sum() == 0:
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
+
+ # * inter-vid mode
+ vid_mem_proj = outputs["vid_mem_proj"]
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
+
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
+ sim = sim_matrix(vid_feats, txt_feats)
+
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
+
+ # sum over positives
+ idiag = torch.diag(i_logsm)
+ jdiag = torch.diag(j_logsm)
+ loss_i = idiag.sum() / len(idiag)
+ loss_j = jdiag.sum() / len(jdiag)
+
+ loss_saliency_inter = - loss_i - loss_j
+
+ # * intra-vid mode
+ if 'cls_idx' not in targets.keys(): # eval
+ return {"loss_s_inter": loss_saliency_inter}
+
+ cls_indices = targets['cls_idx'].bool()
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
+ sim_cls = sim_matrix(vid_feats, cls_feats)
+
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
+ idiag_cls = i_logsm_cls[cls_indices]
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
+
+ loss_saliency_intra = - loss_cls_i
+
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
+
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
+ loss_map = {
+ "spans": self.loss_spans,
+ "labels": self.loss_labels,
+ "saliency": self.loss_saliency,
+ "saliency_cls": self.loss_saliency_cls,
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, **kwargs)
+
+ def forward(self, outputs, targets, mask_GT=None):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ indices = None
+ # Compute all the requested losses
+ losses = {}
+ outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
+ count = mask_GT.sum()
+ outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
+ # targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
+ targets['saliency_scores'] = targets['saliency_scores'][0,:count]
+
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices))
+
+ return losses
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class Conv(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+ def forward(self, x):
+ x = x.permute(0,2,1)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.permute(0, 2, 1)
+
+class LinearLayer(nn.Module):
+ """linear layer configurable with layer normalization, dropout, ReLU."""
+
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
+ super(LinearLayer, self).__init__()
+ self.relu = relu
+ self.layer_norm = layer_norm
+ if layer_norm:
+ self.LayerNorm = nn.LayerNorm(in_hsz)
+ layers = [
+ nn.Dropout(dropout),
+ nn.Linear(in_hsz, out_hsz)
+ ]
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """(N, L, D)"""
+ if self.layer_norm:
+ x = self.LayerNorm(x)
+ x = self.net(x)
+ if self.relu:
+ x = F.relu(x, inplace=True)
+ return x # (N, L, D)
+
+
+def build_model(args):
+ device = torch.device(args.device)
+
+ transformer = build_transformer(args)
+ position_embedding, txt_position_embedding = build_position_encoding(args)
+
+ model = Model(
+ transformer,
+ position_embedding,
+ txt_position_embedding,
+ txt_dim=args.t_feat_dim,
+ vid_dim=args.v_feat_dim,
+ input_dropout=args.input_dropout,
+ span_loss_type=args.span_loss_type,
+ use_txt_pos=args.use_txt_pos,
+ n_input_proj=args.n_input_proj,
+ )
+
+ matcher = build_matcher(args)
+ weight_dict = {"loss_b": args.b_loss_coef,
+ "loss_g": args.g_loss_coef,
+ "loss_f": args.f_loss_coef,
+ "loss_s_intra": args.s_loss_intra_coef,
+ "loss_s_inter": args.s_loss_inter_coef}
+
+ if args.dset_type in ['mr', 'vlp']:
+ if 'tal' not in args.train_path:
+ losses = ['spans', 'labels', 'saliency']
+ else:
+ losses = ['spans', 'labels', 'saliency_cls']
+ elif args.dset_type in ['hl', 'vs']:
+ losses = ['labels', 'saliency']
+
+ criterion = SetCriterion(
+ matcher=matcher,
+ weight_dict=weight_dict, losses=losses,
+ eos_coef=args.eos_coef, temperature=args.temperature,
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
+ saliency_margin=args.saliency_margin,
+ )
+ criterion.to(device)
+ return model, criterion
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f550f0f568552976c3867b559c32dedfa79e852a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,355 @@
+absl-py==1.2.0
+accelerate==0.19.0
+aiodns==3.0.0
+aiofiles==23.1.0
+aiohttp==3.8.3
+aiohttp-socks==0.7.1
+aiosignal==1.3.1
+altair==5.0.1
+antiorm==1.2.1
+antlr4-python3-runtime==4.9.3
+anyio==3.7.0
+appdirs==1.4.4
+argilla==1.8.0
+argon2-cffi==21.3.0
+argon2-cffi-bindings==21.2.0
+asttokens==2.0.7
+async-timeout==4.0.2
+attrs==22.1.0
+Babel==2.12.1
+backcall==0.2.0
+backoff==2.2.1
+beautifulsoup4==4.11.1
+bert-score==0.3.13
+black==22.3.0
+bleach==5.0.1
+blis==0.7.9
+boto3==1.24.84
+botocore==1.27.84
+Brotli==1.0.9
+brotlipy==0.7.0
+cachetools==5.2.0
+catalogue==2.0.8
+cchardet==2.1.7
+certifi==2023.5.7
+cffi==1.15.1
+chardet==5.1.0
+charset-normalizer==2.1.1
+cinemagoer==2023.5.1
+click==8.1.3
+cloudpickle==2.2.0
+cmake==3.26.3
+coloredlogs==15.0.1
+colorlog==6.7.0
+commonmark==0.9.1
+confection==0.0.4
+contourpy==1.0.6
+cryptography==37.0.1
+cycler==0.11.0
+cymem==2.0.7
+dataclasses==0.6
+dataclasses-json==0.5.7
+dataflow==0.9.5
+db==0.1.1
+db-sqlite3==0.0.1
+debugpy==1.6.3
+decoder==0.5
+decorator==4.4.2
+decord==0.6.0
+defusedxml==0.7.1
+Deprecated==1.2.14
+detectron2==0.6
+docker==6.0.0
+docker-pycreds==0.4.0
+easydict==1.9
+ego4d==1.2.5
+einops==0.6.0
+elastic-transport==8.4.0
+elasticsearch==8.5.0
+entrypoints==0.4
+et-xmlfile==1.1.0
+exceptiongroup==1.1.1
+executing==0.10.0
+fairscale==0.4.12
+fake-useragent==0.1.14
+fastapi==0.98.0
+fastjsonschema==2.16.1
+ffmpeg==1.4
+ffmpeg-python==0.2.0
+ffmpy==0.3.0
+ffprobe==0.5
+filelock==3.7.1
+fonttools==4.38.0
+frozenlist==1.3.3
+fsspec==2023.5.0
+ftfy==6.1.1
+future==0.18.2
+fvcore==0.1.5.post20220512
+gdown==4.7.1
+gensim==4.2.0
+geographiclib==2.0
+geopy==2.3.0
+gitdb==4.0.10
+GitPython==3.1.31
+glide-text2im==0.0.0
+google-api-core==2.11.1
+google-api-python-client==2.95.0
+google-auth==2.22.0
+google-auth-httplib2==0.1.0
+google-auth-oauthlib==0.4.6
+google-cloud==0.34.0
+google-cloud-vision==3.4.4
+google-measurement-protocol==1.1.0
+googleapis-common-protos==1.59.1
+googletransx==2.4.2
+gradio==3.23.0
+greenlet==2.0.2
+grpcio==1.56.2
+grpcio-status==1.56.2
+h11==0.14.0
+h5py==3.7.0
+httpcore==0.16.3
+httplib2==0.22.0
+httpx==0.23.3
+huggingface-hub==0.15.1
+humanfriendly==10.0
+hydra-core==1.2.0
+idna==3.3
+imageio==2.31.0
+imageio-ffmpeg==0.4.7
+importlib-metadata==4.12.0
+importlib-resources==5.9.0
+iopath==0.1.9
+ipdb==0.13.11
+ipykernel==6.15.3
+ipython==8.4.0
+ipython-genutils==0.2.0
+ipywidgets==8.0.2
+jedi==0.18.1
+Jinja2==3.1.2
+jmespath==1.0.1
+joblib==1.1.0
+jsonlines==3.1.0
+jsonschema==4.16.0
+jupyter==1.0.0
+jupyter_client==7.3.5
+jupyter-console==6.4.4
+jupyter-core==4.11.1
+jupyterlab-pygments==0.2.2
+jupyterlab-widgets==3.0.3
+kiwisolver==1.4.4
+langchain==0.0.191
+langcodes==3.3.0
+language-evaluation==0.1.0
+lazy_loader==0.2
+linkify-it-py==2.0.2
+lit==16.0.5.post0
+lxml==4.9.1
+Markdown==3.4.1
+markdown-it-py==2.2.0
+markdown2==2.4.9
+MarkupSafe==2.1.1
+marshmallow==3.19.0
+marshmallow-enum==1.5.1
+matplotlib==3.6.2
+matplotlib-inline==0.1.3
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+mistune==2.0.4
+mkl-fft==1.3.1
+mkl-random==1.2.2
+mkl-service==2.4.0
+monotonic==1.6
+more-itertools==9.1.0
+moviepy==1.0.3
+mpmath==1.3.0
+msg-parser==1.2.0
+msgpack==1.0.4
+msgpack-numpy==0.4.8
+multidict==6.0.4
+murmurhash==1.0.9
+mutagen==1.46.0
+mypy-extensions==0.4.3
+nbclient==0.6.8
+nbconvert==7.0.0
+nbformat==5.5.0
+nest-asyncio==1.5.5
+networkx==2.8.7
+nh3==0.2.13
+nltk==3.7
+nms-1d-cpu==0.0.0
+nncore==0.3.6
+notebook==6.4.12
+numexpr==2.8.4
+numpy==1.23.1
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-cupti-cu11==11.7.101
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
+nvidia-cufft-cu11==10.9.0.58
+nvidia-curand-cu11==10.2.10.91
+nvidia-cusolver-cu11==11.4.0.1
+nvidia-cusparse-cu11==11.7.4.91
+nvidia-nccl-cu11==2.14.3
+nvidia-nvtx-cu11==11.7.91
+oauthlib==3.2.0
+olefile==0.46
+omegaconf==2.2.3
+openai==0.27.7
+openapi-schema-pydantic==1.2.4
+opencv-python==4.5.4.58
+openpyxl==3.1.2
+orjson==3.9.1
+ortools==9.4.1874
+packaging==21.3
+pandas==1.5.2
+pandocfilters==1.5.0
+parso==0.8.3
+pathspec==0.10.1
+pathtools==0.1.2
+pathy==0.10.1
+pdfminer.six==20221105
+peft==0.3.0
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==9.3.0
+pip==22.2.2
+pkgutil_resolve_name==1.3.10
+platformdirs==2.5.2
+portalocker==2.5.1
+preshed==3.0.8
+prices==1.1.1
+proglog==0.1.10
+prometheus-client==0.14.1
+prompt-toolkit==3.0.30
+proto-plus==1.22.3
+protobuf==3.20.1
+psutil==5.9.2
+ptyprocess==0.7.0
+pure-eval==0.2.2
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycares==4.2.2
+pycipher==0.5.2
+pycocoevalcap==1.2
+pycocotools==2.0.5
+pycparser==2.21
+pycryptodomex==3.18.0
+pydantic==1.10.8
+pydot==1.4.2
+pydub==0.25.1
+pyfiglet==0.8.post1
+Pygments==2.12.0
+pynvml==11.5.0
+pyOpenSSL==22.0.0
+pypandoc==1.11
+pyparsing==3.0.9
+pyrsistent==0.18.1
+PySocks==1.7.1
+python-dateutil==2.8.2
+python-docx==0.8.11
+python-hostlist==1.21
+python-magic==0.4.27
+python-multipart==0.0.6
+python-pptx==0.6.21
+python-socks==2.0.3
+pytz==2022.7
+PyWavelets==1.4.1
+PyYAML==6.0
+pyzmq==23.2.1
+qtconsole==5.3.2
+QtPy==2.2.0
+regex==2022.7.25
+requests==2.28.1
+requests-oauthlib==1.3.1
+rfc3986==1.5.0
+rich==13.0.1
+rouge-score==0.1.2
+rsa==4.9
+ruamel.yaml==0.17.21
+ruamel.yaml.clib==0.2.7
+s3transfer==0.6.0
+sacremoses==0.0.53
+safetensors==0.3.1
+schedule==1.1.0
+scikit-image==0.21.0
+scikit-learn==1.1.2
+scipy==1.9.3
+seaborn==0.12.0
+semantic-version==2.10.0
+Send2Trash==1.8.0
+sentencepiece==0.1.99
+sentry-sdk==1.26.0
+setproctitle==1.3.2
+setuptools==59.5.0
+shortuuid==1.0.11
+simplejson==3.17.6
+six==1.16.0
+smart-open==6.2.0
+smmap==5.0.0
+sniffio==1.3.0
+soupsieve==2.3.2.post1
+spacy==3.5.3
+spacy-legacy==3.0.12
+spacy-loggers==1.0.4
+SQLAlchemy==2.0.15
+srsly==2.4.6
+stack-data==0.4.0
+starlette==0.27.0
+svgwrite==1.4.3
+sympy==1.12
+tabulate==0.8.10
+tenacity==8.2.2
+tensorboard==2.9.1
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+termcolor==1.1.0
+terminado==0.15.0
+terminaltables==3.1.10
+thinc==8.1.10
+threadpoolctl==3.1.0
+tifffile==2023.4.12
+timm==0.4.12
+tinycss2==1.1.1
+tokenizers==0.13.2
+tomli==2.0.1
+toolz==0.12.0
+torch==2.0.1
+torchaudio==0.9.0a0+33b2469
+torchdata==0.6.1
+torchtext==0.15.2
+torchvision==0.10.0a0
+tornado==6.2
+tqdm==4.64.1
+traitlets==5.3.0
+transformers==4.28.1
+triton==2.0.0
+twint==2.1.21
+typer==0.7.0
+typing_extensions==4.3.0
+typing-inspect==0.9.0
+uc-micro-py==1.0.2
+unstructured==0.7.1
+uritemplate==4.1.1
+urllib3==1.26.12
+uvicorn==0.22.0
+wandb==0.15.4
+warmup-scheduler==0.3
+wasabi==1.1.2
+wavedrom==2.0.3.post3
+wcwidth==0.2.5
+webencodings==0.5.1
+websocket-client==1.4.1
+websockets==11.0.3
+Werkzeug==2.2.1
+wheel==0.37.1
+widgetsnbextension==4.0.3
+wrapt==1.14.1
+xlrd==2.0.1
+XlsxWriter==3.1.2
+yacs==0.1.8
+yarl==1.9.2
+youtube-dl==2021.12.17
+yt-dlp==2023.3.4
+zipp==3.8.1
diff --git a/results/omni/opt.json b/results/omni/opt.json
new file mode 100644
index 0000000000000000000000000000000000000000..7fbfd399c087577661ae58fdf6be0f8990b2a789
--- /dev/null
+++ b/results/omni/opt.json
@@ -0,0 +1,111 @@
+{
+ "dset_type": "vlp",
+ "dset_name": "vlp",
+ "domain_name": null,
+ "model_id": "univtg",
+ "exp_id": "omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1",
+ "device": 0,
+ "gpu_id": 0,
+ "debug": false,
+ "seed": 2018,
+ "local_rank": 0,
+ "eval_split_name": "val",
+ "data_ratio": 1.0,
+ "results_root": "results",
+ "num_workers": 8,
+ "no_pin_memory": false,
+ "bsz": 64,
+ "n_epoch": 100,
+ "max_es_cnt": 200,
+ "lr": 0.0001,
+ "lr_drop": 200,
+ "lr_gamma": 0.1,
+ "lr_warmup": 10.0,
+ "wd": 0.0001,
+ "grad_clip": 0.1,
+ "span_loss_type": "l1",
+ "b_loss_coef": 10.0,
+ "g_loss_coef": 1.0,
+ "eos_coef": 0.1,
+ "f_loss_coef": 10.0,
+ "s_loss_intra_coef": 0.1,
+ "s_loss_inter_coef": 0.1,
+ "main_metric": "MR-full-R1@0.3-key",
+ "eval_mode": null,
+ "eval_bsz": 32,
+ "eval_epoch": 5,
+ "eval_init": true,
+ "save_interval": 5,
+ "resume": "/data/home/qinghonglin/univtg/results/vlp-vlp/aio_unified_mini-clip-clip-2023_05_27_00/model_e0003.ckpt",
+ "resume_dir": null,
+ "resume_all": false,
+ "start_epoch": null,
+ "no_sort_results": false,
+ "max_before_nms": 1000,
+ "max_after_nms": 10,
+ "conf_thd": 0.0,
+ "nms_thd": 0.7,
+ "use_cache": -1,
+ "max_q_l": 75,
+ "max_v_l": 75,
+ "clip_length": 2.0,
+ "clip_len_list": null,
+ "max_windows": 5,
+ "add_easy_negative": 1,
+ "easy_negative_only": 1,
+ "round_multiple": 1,
+ "train_path": [
+ "data/qvhighlights/metadata/qvhighlights_train.jsonl",
+ "data/charades/metadata/charades_train.jsonl",
+ "data/ego4d/metadata/nlq_train.jsonl",
+ "data/tacos/metadata/train.jsonl",
+ "data/anet/metadata/train.jsonl",
+ "data/didemo/metadata/train.jsonl"
+ ],
+ "eval_path": "data/qvhighlights/metadata/qvhighlights_val.jsonl",
+ "train_path_list": null,
+ "eval_path_list": null,
+ "feat_root_list": null,
+ "no_norm_vfeat": false,
+ "no_norm_tfeat": false,
+ "v_feat_dirs": [
+ "vid_clip"
+ ],
+ "t_feat_dir": "txt_clip",
+ "v_feat_dim": 512,
+ "t_feat_dim": 512,
+ "ctx_mode": "video_tef",
+ "v_feat_types": "clip",
+ "t_feat_type": "clip",
+ "position_embedding": "sine",
+ "n_input_proj": 2,
+ "temperature": 0.07,
+ "enc_layers": 4,
+ "sub_enc_layers": 2,
+ "dec_layers": 2,
+ "dim_feedforward": 1024,
+ "hidden_dim": 512,
+ "input_dropout": 0.5,
+ "dropout": 0.0,
+ "droppath": 0.1,
+ "txt_drop_ratio": 0,
+ "use_txt_pos": false,
+ "nheads": 8,
+ "num_queries": 10,
+ "pre_norm": false,
+ "set_cost_span": 10,
+ "set_cost_giou": 1,
+ "set_cost_class": 4,
+ "saliency_margin": 0.2,
+ "aux_loss": false,
+ "max_segment_num": 20,
+ "max_frame_num": 200,
+ "top_percent": 0.02,
+ "qfvs_vid_feature": "fps1",
+ "qfvs_txt_feature": "query",
+ "qfvs_dense_shot": -1,
+ "qfvs_score_ensemble": -1,
+ "qfvs_score_gather": -1,
+ "qfvs_loss_gather": -1,
+ "results_dir": "results/vlp-vlp/omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1-clip-clip-2023_05_31_06"
+}
\ No newline at end of file
diff --git a/run_on_video/__init__.py b/run_on_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ed2adb16c75bea2cdd119bc3c19b67bf68bcc37
--- /dev/null
+++ b/run_on_video/__init__.py
@@ -0,0 +1 @@
+from run_on_video.video_extractor import vid2clip, txt2clip
diff --git a/run_on_video/clip/__init__.py b/run_on_video/clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9
--- /dev/null
+++ b/run_on_video/clip/__init__.py
@@ -0,0 +1 @@
+from .clip import *
diff --git a/run_on_video/clip/bpe_simple_vocab_16e6.txt.gz b/run_on_video/clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/run_on_video/clip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/run_on_video/clip/clip.py b/run_on_video/clip/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..9000dd80de5171b359dc336d0c955cc2332d12d4
--- /dev/null
+++ b/run_on_video/clip/clip.py
@@ -0,0 +1,195 @@
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Union, List
+
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+}
+
+
+def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+def available_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+ device : Union[str, torch.device]
+ The device to put the loaded model
+
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if name in _MODELS:
+ model_path = _download(_MODELS[name])
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ model = build_model(state_dict or model.state_dict()).to(device)
+ if str(device) == "cpu":
+ model.float()
+ return model, _transform(model.visual.input_resolution)
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, _transform(model.input_resolution.item())
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, max_valid_length: int = 32) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ max_valid_length:
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text)[:max_valid_length-2] + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
diff --git a/run_on_video/clip/model.py b/run_on_video/clip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..658fbbf5ed1379f9c0179fa456635a9ed6d4d4de
--- /dev/null
+++ b/run_on_video/clip/model.py
@@ -0,0 +1,432 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+ def forward(self, x: torch.Tensor):
+ return self.resblocks(x)
+
+
+class VisualTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ if isinstance(self.visual, ModifiedResNet):
+ if self.visual.attnpool is not None:
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ eos_x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return dict(last_hidden_state=x, pooler_output=eos_x)
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logit_scale * text_features @ image_features.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ return model.eval()
diff --git a/run_on_video/clip/simple_tokenizer.py b/run_on_video/clip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91
--- /dev/null
+++ b/run_on_video/clip/simple_tokenizer.py
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
diff --git a/run_on_video/clip_feature_extractor.py b/run_on_video/clip_feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2852db020d8bcc74fd52cc2e0cd97d59c61b6b94
--- /dev/null
+++ b/run_on_video/clip_feature_extractor.py
@@ -0,0 +1,101 @@
+import pdb
+import torch as th
+import math
+import numpy as np
+import torch
+from video_loader import VideoLoader
+from torch.utils.data import DataLoader
+import argparse
+from preprocessing import Preprocessing
+import torch.nn.functional as F
+from tqdm import tqdm
+import os
+import sys
+from feature_extractor import clip
+import argparse
+
+#################################
+model_version = "ViT-B/32"
+output_feat_size = 512
+clip_len = 2
+overwrite = True
+num_decoding_thread = 4
+half_precision = False
+
+@torch.no_grad()
+def extractor(vid_path, text, output_file):
+ dataset = VideoLoader(
+ vid_path,
+ framerate=1/clip_len,
+ size=224,
+ centercrop=True,
+ overwrite=overwrite,
+ model_version=model_version
+ )
+ n_dataset = len(dataset)
+ loader = DataLoader(
+ dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=num_decoding_thread,
+ sampler=sampler if n_dataset > 10 else None,
+ )
+ preprocess = Preprocessing()
+ model, _ = clip.load(model_version, device="cuda", jit=False)
+
+ encoded_texts = clip.tokenize(text).to('cuda')
+ text_feature = model.encode_text(encoded_texts)['last_hidden_state']
+ valid_lengths = (encoded_texts != 0).sum(1).tolist()[0]
+ text_feature = text_feature[0, :valid_lengths].cpu().numpy()
+ np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature)
+
+ totatl_num_frames = 0
+ with th.no_grad():
+ for k, data in enumerate(tqdm(loader)):
+ input_file = data['input'][0]
+ if os.path.isfile(output_file):
+ # print(f'Video {input_file} already processed.')
+ continue
+ elif not os.path.isfile(input_file):
+ print(f'{input_file}, does not exist.\n')
+ elif len(data['video'].shape) > 4:
+ video = data['video'].squeeze(0)
+ if len(video.shape) == 4:
+ video = preprocess(video)
+ n_chunk = len(video)
+ vid_features = th.cuda.FloatTensor(
+ n_chunk, output_feat_size).fill_(0)
+ n_iter = int(math.ceil(n_chunk))
+ for i in range(n_iter):
+ min_ind = i
+ max_ind = (i + 1)
+ video_batch = video[min_ind:max_ind].cuda()
+ batch_features = model.encode_image(video_batch)
+ vid_features[min_ind:max_ind] = batch_features
+ vid_features = vid_features.cpu().numpy()
+ if half_precision:
+ vid_features = vid_features.astype('float16')
+ totatl_num_frames += vid_features.shape[0]
+ # safeguard output path before saving
+ dirname = os.path.dirname(output_file)
+ if not os.path.exists(dirname):
+ print(f"Output directory {dirname} does not exists, creating...")
+ os.makedirs(dirname)
+ np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features)
+ else:
+ print(f'{input_file}, failed at ffprobe.\n')
+
+ print(f"Total number of frames: {totatl_num_frames}")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='')
+ parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
+ parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.')
+ parser.add_argument('--save_dir', type=str, default='./tmp')
+ args = parser.parse_args()
+
+ query = ' '.join(args.text)
+
+ print(args.vid_path)
+ print(query)
+ extractor(args.vid_path, [query], args.save_dir)
diff --git a/run_on_video/data_utils.py b/run_on_video/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..44c073f4689b20bd8f7adb2aac77c0b7523ed8dd
--- /dev/null
+++ b/run_on_video/data_utils.py
@@ -0,0 +1,170 @@
+import torch
+import os
+import numpy as np
+import ffmpeg
+import math
+from run_on_video import clip
+
+
+class ClipFeatureExtractor:
+ def __init__(self, framerate=1/2, size=224, centercrop=True, model_name_or_path="ViT-B/32", device="cuda"):
+ self.video_loader = VideoLoader(framerate=framerate, size=size, centercrop=centercrop)
+ print("Loading CLIP models")
+ self.clip_extractor, _ = clip.load(model_name_or_path, device=device, jit=False)
+ self.tokenizer = clip.tokenize
+ self.video_preprocessor = Preprocessing()
+ self.device = device
+
+ @torch.no_grad()
+ def encode_video(self, video_path: str, bsz=60):
+ video_frames = self.video_loader.read_video_from_file(video_path) # (T, H, W, 3)
+ video_frames = self.video_preprocessor(video_frames)
+ n_frames = len(video_frames)
+ n_batch = int(math.ceil(n_frames / bsz))
+ video_features = []
+ for i in range(n_batch):
+ st_idx = i * bsz
+ ed_idx = (i+1) * bsz
+ _video_frames = video_frames[st_idx:ed_idx].to(self.device)
+ _video_features = self.clip_extractor.encode_image(_video_frames)
+ video_features.append(_video_features)
+ video_features = torch.cat(video_features, dim=0)
+ return video_features # (T=#frames, d) torch tensor
+
+ @torch.no_grad()
+ def encode_text(self, text_list, bsz=60):
+ n_text = len(text_list)
+ n_batch = int(math.ceil(n_text / bsz))
+ text_features = []
+ for i in range(n_batch):
+ st_idx = i * bsz
+ ed_idx = (i+1) * bsz
+ encoded_texts = self.tokenizer(text_list[st_idx:ed_idx], context_length=77).to(self.device)
+ output = self.clip_extractor.encode_text(encoded_texts)
+ valid_lengths = (encoded_texts != 0).sum(1).tolist()
+ batch_last_hidden_states = output["last_hidden_state"]
+ for j, valid_len in enumerate(valid_lengths):
+ text_features.append(batch_last_hidden_states[j, :valid_len])
+ return text_features # List([L_j, d]) torch tensor
+
+
+def convert_to_float(frac_str):
+ try:
+ return float(frac_str)
+ except ValueError:
+ try:
+ num, denom = frac_str.split('/')
+ except ValueError:
+ return None
+ try:
+ leading, num = num.split(' ')
+ except ValueError:
+ return float(num) / float(denom)
+ if float(leading) < 0:
+ sign_mult = -1
+ else:
+ sign_mult = 1
+ return float(leading) + sign_mult * (float(num) / float(denom))
+
+
+class Normalize(object):
+
+ def __init__(self, mean, std):
+ self.mean = torch.FloatTensor(mean).view(1, 3, 1, 1)
+ self.std = torch.FloatTensor(std).view(1, 3, 1, 1)
+
+ def __call__(self, tensor):
+ tensor = (tensor - self.mean) / (self.std + 1e-8)
+ return tensor
+
+
+class Preprocessing(object):
+
+ def __init__(self):
+ self.norm = Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073],
+ std=[0.26862954, 0.26130258, 0.27577711])
+
+ def __call__(self, tensor):
+ tensor = tensor / 255.0
+ tensor = self.norm(tensor)
+ return tensor
+
+
+class VideoLoader:
+ """Pytorch video loader.
+ Copied and modified from:
+ https://github.com/linjieli222/HERO_Video_Feature_Extractor/blob/main/clip/video_loader.py
+ """
+ def __init__(
+ self,
+ framerate=1/2,
+ size=224,
+ centercrop=True,
+ ):
+ self.centercrop = centercrop
+ self.size = size
+ self.framerate = framerate
+
+ def _get_video_info(self, video_path):
+ probe = ffmpeg.probe(video_path)
+ video_stream = next((stream for stream in probe['streams']
+ if stream['codec_type'] == 'video'), None)
+ width = int(video_stream['width'])
+ height = int(video_stream['height'])
+ fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
+ try:
+ frames_length = int(video_stream['nb_frames'])
+ duration = float(video_stream['duration'])
+ except Exception:
+ frames_length, duration = -1, -1
+ info = {"duration": duration, "frames_length": frames_length,
+ "fps": fps, "height": height, "width": width}
+ return info
+
+ def _get_output_dim(self, h, w):
+ if isinstance(self.size, tuple) and len(self.size) == 2:
+ return self.size
+ elif h >= w:
+ return int(h * self.size / w), self.size
+ else:
+ return self.size, int(w * self.size / h)
+
+ def read_video_from_file(self, video_path):
+ try:
+ info = self._get_video_info(video_path)
+ h, w = info["height"], info["width"]
+ except Exception:
+ print('ffprobe failed at: {}'.format(video_path))
+ return {'video': torch.zeros(1), 'input': video_path,
+ 'info': {}}
+ height, width = self._get_output_dim(h, w)
+ try:
+ duration = info["duration"]
+ fps = self.framerate
+ if duration > 0 and duration < 1/fps+0.1:
+ fps = 2/max(int(duration), 1)
+ print(duration, fps)
+ except Exception:
+ fps = self.framerate
+ cmd = (
+ ffmpeg
+ .input(video_path)
+ .filter('fps', fps=fps)
+ .filter('scale', width, height)
+ )
+ if self.centercrop:
+ x = int((width - self.size) / 2.0)
+ y = int((height - self.size) / 2.0)
+ cmd = cmd.crop(x, y, self.size, self.size)
+ out, _ = (
+ cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True)
+ )
+ if self.centercrop and isinstance(self.size, int):
+ height, width = self.size, self.size
+ video = np.frombuffer(out, np.uint8).reshape(
+ [-1, height, width, 3])
+ video = torch.from_numpy(video.astype('float32'))
+ video = video.permute(0, 3, 1, 2)
+ return video
diff --git a/run_on_video/preprocessing.py b/run_on_video/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..93b3e8112b299c93667fb433ac683fa7f46e0fda
--- /dev/null
+++ b/run_on_video/preprocessing.py
@@ -0,0 +1,25 @@
+import torch as th
+
+
+class Normalize(object):
+
+ def __init__(self, mean, std):
+ self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
+ self.std = th.FloatTensor(std).view(1, 3, 1, 1)
+
+ def __call__(self, tensor):
+ tensor = (tensor - self.mean) / (self.std + 1e-8)
+ return tensor
+
+
+class Preprocessing(object):
+
+ def __init__(self):
+ self.norm = Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073],
+ std=[0.26862954, 0.26130258, 0.27577711])
+
+ def __call__(self, tensor):
+ tensor = tensor / 255.0
+ tensor = self.norm(tensor)
+ return tensor
diff --git a/run_on_video/text_extractor.py b/run_on_video/text_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e26cf4d7968ab718675c49e85907be38fcd5e2f3
--- /dev/null
+++ b/run_on_video/text_extractor.py
@@ -0,0 +1,36 @@
+import pdb
+import sys
+import json
+import torch
+import numpy as np
+from run_on_video.data_utils import ClipFeatureExtractor
+import torch.nn.functional as F
+import tqdm
+import os
+
+query_list = []
+qid_list = []
+dataset = 'charades'
+split = 'test'
+
+save_dir = f''
+
+with open(f"data/{dataset}/metadata/{dataset}_{split}.jsonl", 'r') as f:
+ while True:
+ line = f.readline()
+ if not line:
+ break
+ js = json.loads(line)
+ query_list.append(js['query'])
+ qid_list.append(str(js['qid']))
+
+# clip
+feature_extractor = ClipFeatureExtractor(
+ framerate=1 / 2, size=224, centercrop=True,
+ model_name_or_path="ViT-B/32", device='cuda'
+)
+# pdb.set_trace()
+query_feats = feature_extractor.encode_text(query_list)
+
+for i in tqdm.tqdm(range(len(query_feats))):
+ np.savez(save_dir + '/' + qid_list[i], last_hidden_state=query_feats[i].cpu().numpy())
diff --git a/run_on_video/video_extractor.py b/run_on_video/video_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8804063be0b6de26c8ea63c5d035dfd4a27e4884
--- /dev/null
+++ b/run_on_video/video_extractor.py
@@ -0,0 +1,94 @@
+import pdb
+import torch as th
+import math
+import numpy as np
+import torch
+from run_on_video.video_loader import VideoLoader
+from torch.utils.data import DataLoader
+import argparse
+from run_on_video.preprocessing import Preprocessing
+import torch.nn.functional as F
+from tqdm import tqdm
+import os
+import sys
+from run_on_video import clip
+import argparse
+
+#################################
+@torch.no_grad()
+def vid2clip(model, vid_path, output_file,
+ model_version="ViT-B/32", output_feat_size=512,
+ clip_len=2, overwrite=True, num_decoding_thread=4, half_precision=False):
+ dataset = VideoLoader(
+ vid_path,
+ framerate=1/clip_len,
+ size=224,
+ centercrop=True,
+ overwrite=overwrite,
+ model_version=model_version
+ )
+ n_dataset = len(dataset)
+ loader = DataLoader(
+ dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=num_decoding_thread,
+ sampler=None,
+ )
+ preprocess = Preprocessing()
+ device_id = next(model.parameters()).device
+
+ totatl_num_frames = 0
+ with th.no_grad():
+ for k, data in enumerate(tqdm(loader)):
+ input_file = data['input'][0]
+ if os.path.isfile(output_file):
+ # print(f'Video {input_file} already processed.')
+ continue
+ elif not os.path.isfile(input_file):
+ print(f'{input_file}, does not exist.\n')
+ elif len(data['video'].shape) > 4:
+ video = data['video'].squeeze(0)
+ if len(video.shape) == 4:
+ video = preprocess(video)
+ n_chunk = len(video)
+ vid_features = th.cuda.FloatTensor(
+ n_chunk, output_feat_size).fill_(0)
+ n_iter = int(math.ceil(n_chunk))
+ for i in range(n_iter):
+ min_ind = i
+ max_ind = (i + 1)
+ video_batch = video[min_ind:max_ind].to(device_id)
+ batch_features = model.encode_image(video_batch)
+ vid_features[min_ind:max_ind] = batch_features
+ vid_features = vid_features.cpu().numpy()
+ if half_precision:
+ vid_features = vid_features.astype('float16')
+ totatl_num_frames += vid_features.shape[0]
+ # safeguard output path before saving
+ dirname = os.path.dirname(output_file)
+ if not os.path.exists(dirname):
+ print(f"Output directory {dirname} does not exists, creating...")
+ os.makedirs(dirname)
+ np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features)
+ else:
+ print(f'{input_file}, failed at ffprobe.\n')
+ print(f"Total number of frames: {totatl_num_frames}")
+ return vid_features
+
+def txt2clip(model, text, output_file):
+ device_id = next(model.parameters()).device
+ encoded_texts = clip.tokenize(text).to(device_id)
+ text_feature = model.encode_text(encoded_texts)['last_hidden_state']
+ valid_lengths = (encoded_texts != 0).sum(1).tolist()[0]
+ text_feature = text_feature[0, :valid_lengths].detach().cpu().numpy()
+
+ np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature)
+ return text_feature
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='')
+ parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
+ parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.')
+ parser.add_argument('--save_dir', type=str, default='./tmp')
+ args = parser.parse_args()
diff --git a/run_on_video/video_loader.py b/run_on_video/video_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dddbf5df1b22187226fc53ec260663c708d8452
--- /dev/null
+++ b/run_on_video/video_loader.py
@@ -0,0 +1,125 @@
+import torch as th
+from torch.utils.data import Dataset
+import pandas as pd
+import os
+import numpy as np
+import ffmpeg
+import math
+
+
+def convert_to_float(frac_str):
+ try:
+ return float(frac_str)
+ except ValueError:
+ try:
+ num, denom = frac_str.split('/')
+ except ValueError:
+ return None
+ try:
+ leading, num = num.split(' ')
+ except ValueError:
+ return float(num) / float(denom)
+ if float(leading) < 0:
+ sign_mult = -1
+ else:
+ sign_mult = 1
+ return float(leading) + sign_mult * (float(num) / float(denom))
+
+
+class VideoLoader(Dataset):
+ """Pytorch video loader."""
+
+ def __init__(
+ self,
+ vid_path,
+ framerate=1,
+ size=112,
+ centercrop=False,
+ overwrite=False,
+ model_version="ViT-B/32",
+ ):
+ """
+ Args:
+ """
+ self.vid_path = vid_path
+
+ self.centercrop = centercrop
+ self.size = size
+ self.framerate = framerate
+ self.overwrite = overwrite
+ self.model_version = model_version
+
+ def __len__(self):
+ return 1
+
+ def _get_video_info(self, video_path):
+ probe = ffmpeg.probe(video_path)
+ video_stream = next((stream for stream in probe['streams']
+ if stream['codec_type'] == 'video'), None)
+ width = int(video_stream['width'])
+ height = int(video_stream['height'])
+ fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
+ try:
+ frames_length = int(video_stream['nb_frames'])
+ duration = float(video_stream['duration'])
+ except Exception:
+ frames_length, duration = -1, -1
+ info = {"duration": duration, "frames_length": frames_length,
+ "fps": fps, "height": height, "width": width}
+ return info
+
+ def _get_output_dim(self, h, w):
+ if isinstance(self.size, tuple) and len(self.size) == 2:
+ return self.size
+ elif h >= w:
+ return int(h * self.size / w), self.size
+ else:
+ return self.size, int(w * self.size / h)
+
+ def __getitem__(self, id):
+ video_path = self.vid_path
+
+ load_flag = os.path.isfile(video_path)
+ if load_flag:
+ try:
+ info = self._get_video_info(video_path)
+ h, w = info["height"], info["width"]
+ except Exception:
+ print('ffprobe failed at: {}'.format(video_path))
+ return {'video': th.zeros(1), 'input': video_path,'info': {}}
+ try:
+ height, width = self._get_output_dim(h, w)
+ try:
+ duration = info["duration"]
+ fps = self.framerate
+ if duration > 0 and duration < 1/fps+0.1:
+ fps = 2/max(int(duration), 1)
+ print(duration, fps)
+ except Exception:
+ fps = self.framerate
+ cmd = (
+ ffmpeg
+ .input(video_path)
+ .filter('fps', fps=fps)
+ .filter('scale', width, height)
+ # .filter('scale', self.size, self.size)
+ )
+ if self.centercrop:
+ x = int((width - self.size) / 2.0)
+ y = int((height - self.size) / 2.0)
+ cmd = cmd.crop(x, y, self.size, self.size)
+ out, _ = (
+ cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True)
+ )
+ if self.centercrop and isinstance(self.size, int):
+ height, width = self.size, self.size
+ video = np.frombuffer(out, np.uint8).reshape(
+ [-1, height, width, 3])
+ video = th.from_numpy(video.astype('float32'))
+ video = video.permute(0, 3, 1, 2)
+ except:
+ return {'video': th.zeros(1), 'input': video_path,'info': {}}
+ else:
+ video = th.zeros(1)
+ return {'video': video, 'input': video_path}
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/basic_utils.py b/utils/basic_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab03b808017927d67665a37d606de7b6538253a6
--- /dev/null
+++ b/utils/basic_utils.py
@@ -0,0 +1,234 @@
+import os
+import json
+import torch
+import random
+import zipfile
+import numpy as np
+import pickle
+from collections import OrderedDict, Counter
+import pandas as pd
+import shutil
+
+def set_seed(seed, use_cuda=True):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if use_cuda:
+ torch.cuda.manual_seed_all(seed)
+
+def load_pickle(filename):
+ with open(filename, "rb") as f:
+ return pickle.load(f)
+
+
+def save_pickle(data, filename):
+ with open(filename, "wb") as f:
+ pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def load_json(filename):
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+def save_json(data, filename, save_pretty=False, sort_keys=False):
+ with open(filename, "w") as f:
+ if save_pretty:
+ f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
+ else:
+ json.dump(data, f)
+
+
+def load_jsonl(filename):
+ with open(filename, "r") as f:
+ return [json.loads(l.strip("\n")) for l in f.readlines()]
+
+
+def save_jsonl(data, filename):
+ """data is a list"""
+ with open(filename, "w") as f:
+ f.write("\n".join([json.dumps(e) for e in data]))
+
+
+def save_lines(list_of_str, filepath):
+ with open(filepath, "w") as f:
+ f.write("\n".join(list_of_str))
+
+
+def read_lines(filepath):
+ with open(filepath, "r") as f:
+ return [e.strip("\n") for e in f.readlines()]
+
+
+def mkdirp(p):
+ if not os.path.exists(p):
+ os.makedirs(p)
+
+def remkdirp(p):
+ if os.path.exists(p):
+ shutil.rmtree(p)
+ os.makedirs(p)
+
+def flat_list_of_lists(l):
+ """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
+ return [item for sublist in l for item in sublist]
+
+
+def convert_to_seconds(hms_time):
+ """ convert '00:01:12' to 72 seconds.
+ :hms_time (str): time in comma separated string, e.g. '00:01:12'
+ :return (int): time in seconds, e.g. 72
+ """
+ times = [float(t) for t in hms_time.split(":")]
+ return times[0] * 3600 + times[1] * 60 + times[2]
+
+
+def get_video_name_from_url(url):
+ return url.split("/")[-1][:-4]
+
+
+def merge_dicts(list_dicts):
+ merged_dict = list_dicts[0].copy()
+ for i in range(1, len(list_dicts)):
+ merged_dict.update(list_dicts[i])
+ return merged_dict
+
+
+def l2_normalize_np_array(np_array, eps=1e-5):
+ """np_array: np.ndarray, (*, D), where the last dim will be normalized"""
+ return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps)
+
+
+def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None,
+ exclude_dirs_substring=None):
+ """make a zip file of root_dir, save it to save_path.
+ exclude_paths will be excluded if it is a subdir of root_dir.
+ An enclosing_dir is added is specified.
+ """
+ abs_src = os.path.abspath(src_dir)
+ with zipfile.ZipFile(save_path, "w") as zf:
+ for dirname, subdirs, files in os.walk(src_dir):
+ if exclude_dirs is not None:
+ for e_p in exclude_dirs:
+ if e_p in subdirs:
+ subdirs.remove(e_p)
+ if exclude_dirs_substring is not None:
+ to_rm = []
+ for d in subdirs:
+ if exclude_dirs_substring in d:
+ to_rm.append(d)
+ for e in to_rm:
+ subdirs.remove(e)
+ arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:])
+ zf.write(dirname, arcname)
+ for filename in files:
+ if exclude_extensions is not None:
+ if os.path.splitext(filename)[1] in exclude_extensions:
+ continue # do not zip it
+ absname = os.path.join(dirname, filename)
+ arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:])
+ zf.write(absname, arcname)
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current/max/min value"""
+ def __init__(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+ self.max = -1e10
+ self.min = 1e10
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+ self.max = -1e10
+ self.min = 1e10
+
+ def update(self, val, n=1):
+ self.max = max(val, self.max)
+ self.min = min(val, self.min)
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):
+ """Dissect an array (N, D) into a list a sub-array,
+ np_array.shape[0] == sum(lengths), Output is a list of nd arrays, singlton dimention is kept"""
+ if assert_equal:
+ assert len(np_array) == sum(lengths)
+ length_indices = [0, ]
+ for i in range(len(lengths)):
+ length_indices.append(length_indices[i] + lengths[i])
+ if dim == 0:
+ array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))]
+ elif dim == 1:
+ array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
+ elif dim == 2:
+ array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
+ else:
+ raise NotImplementedError
+ return array_list
+
+
+def get_ratio_from_counter(counter_obj, threshold=200):
+ keys = counter_obj.keys()
+ values = counter_obj.values()
+ filtered_values = [counter_obj[k] for k in keys if k > threshold]
+ return float(sum(filtered_values)) / sum(values)
+
+
+def get_counter_dist(counter_object, sort_type="none"):
+ _sum = sum(counter_object.values())
+ dist = {k: float(f"{100 * v / _sum:.2f}") for k, v in counter_object.items()}
+ if sort_type == "value":
+ dist = OrderedDict(sorted(dist.items(), reverse=True))
+ return dist
+
+
+def get_show_name(vid_name):
+ """
+ get tvshow name from vid_name
+ :param vid_name: video clip name
+ :return: tvshow name
+ """
+ show_list = ["friends", "met", "castle", "house", "grey"]
+ vid_name_prefix = vid_name.split("_")[0]
+ show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt"
+ return show_name
+
+
+def get_abspaths_by_ext(dir_path, ext=(".jpg",)):
+ """Get absolute paths to files in dir_path with extensions specified by ext.
+ Note this function does work recursively.
+ """
+ if isinstance(ext, list):
+ ext = tuple(ext)
+ if isinstance(ext, str):
+ ext = tuple([ext, ])
+ filepaths = [os.path.join(root, name)
+ for root, dirs, files in os.walk(dir_path)
+ for name in files
+ if name.endswith(tuple(ext))]
+ return filepaths
+
+
+def get_basename_no_ext(path):
+ """ '/data/movienet/240p_keyframe_feats/tt7672188.npz' --> 'tt7672188' """
+ return os.path.splitext(os.path.split(path)[1])[0]
+
+
+def dict_to_markdown(d, max_str_len=120):
+ # convert list into its str representation
+ d = {k: v.__repr__() if isinstance(v, list) else v for k, v in d.items()}
+ # truncate string that is longer than max_str_len
+ if max_str_len is not None:
+ d = {k: v[-max_str_len:] if isinstance(v, str) else v for k, v in d.items()}
+ return pd.DataFrame(d, index=[0]).transpose().to_markdown()
+
diff --git a/utils/cpd_auto.py b/utils/cpd_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..b561b08c4764350843955b510402e3c0a28d62a7
--- /dev/null
+++ b/utils/cpd_auto.py
@@ -0,0 +1,89 @@
+import numpy as np
+from .cpd_nonlin import cpd_nonlin
+
+def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs):
+ """Main interface
+
+ Detect change points automatically selecting their number
+ K - kernel between each pair of frames in video
+ ncp - maximum ncp
+ vmax - special parameter
+ Optional arguments:
+ lmin - minimum segment length
+ lmax - maximum segment length
+ desc_rate - rate of descriptor sampling (vmax always corresponds to 1x)
+
+ Note:
+ - cps are always calculated in subsampled coordinates irrespective to
+ desc_rate
+ - lmin and m should be in agreement
+ ---
+ Returns: (cps, costs)
+ cps - best selected change-points
+ costs - costs for 0,1,2,...,m change-points
+
+ Memory requirement: ~ (3*N*N + N*ncp)*4 bytes ~= 16 * N^2 bytes
+ That is 1,6 Gb for the N=10000.
+ """
+ m = ncp
+ (_, scores) = cpd_nonlin(K, m, backtrack=False, **kwargs)
+ # print("scores ",scores)
+
+ N = K.shape[0]
+ N2 = N*desc_rate # length of the video before subsampling
+
+ penalties = np.zeros(m+1)
+ # Prevent division by zero (in case of 0 changes)
+ ncp = np.arange(1, m+1)
+ penalties[1:] = (vmax*ncp/(2.0*N2))*(np.log(float(N2)/ncp)+1)
+
+ costs = scores/float(N) + penalties
+ m_best = np.argmin(costs)
+ # print("cost ",costs)
+ # print("m_best ",m_best)
+ (cps, scores2) = cpd_nonlin(K, m_best, **kwargs)
+
+ return (cps, costs)
+
+
+# ------------------------------------------------------------------------------
+# Extra functions (currently not used)
+
+def estimate_vmax(K_stable):
+ """K_stable - kernel between all frames of a stable segment"""
+ n = K_stable.shape[0]
+ vmax = np.trace(centering(K_stable)/n)
+ return vmax
+
+
+def centering(K):
+ """Apply kernel centering"""
+ mean_rows = np.mean(K, 1)[:, np.newaxis]
+ return K - mean_rows - mean_rows.T + np.mean(mean_rows)
+
+
+def eval_score(K, cps):
+ """ Evaluate unnormalized empirical score
+ (sum of kernelized scatters) for the given change-points """
+ N = K.shape[0]
+ cps = [0] + list(cps) + [N]
+ V1 = 0
+ V2 = 0
+ for i in range(len(cps)-1):
+ K_sub = K[cps[i]:cps[i+1], :][:, cps[i]:cps[i+1]]
+ V1 += np.sum(np.diag(K_sub))
+ V2 += np.sum(K_sub) / float(cps[i+1] - cps[i])
+ return (V1 - V2)
+
+
+def eval_cost(K, cps, score, vmax):
+ """ Evaluate cost function for automatic number of change points selection
+ K - kernel between all frames
+ cps - selected change-points
+ score - unnormalized empirical score (sum of kernelized scatters)
+ vmax - vmax parameter"""
+
+ N = K.shape[0]
+ penalty = (vmax*len(cps)/(2.0*N))*(np.log(float(N)/len(cps))+1)
+ return score/float(N) + penalty
+
diff --git a/utils/cpd_nonlin.py b/utils/cpd_nonlin.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee577b95e36b2208bde52fa80e73c88048f5536c
--- /dev/null
+++ b/utils/cpd_nonlin.py
@@ -0,0 +1,115 @@
+import numpy as np
+# from scipy import weave
+
+def calc_scatters(K):
+ n = K.shape[0]
+ K1 = np.cumsum([0] + list(np.diag(K)))
+ K2 = np.zeros((n+1, n+1))
+ K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) # TODO: use the fact that K - symmetric
+
+ scatters = np.zeros((n, n))
+
+# code = r"""
+# for (int i = 0; i < n; i++) {
+# for (int j = i; j < n; j++) {
+# scatters(i,j) = K1(j+1)-K1(i) - (K2(j+1,j+1)+K2(i,i)-K2(j+1,i)-K2(i,j+1))/(j-i+1);
+# }
+# }
+# """
+# weave.inline(code, ['K1','K2','scatters','n'], global_dict = \
+# {'K1':K1, 'K2':K2, 'scatters':scatters, 'n':n}, type_converters=weave.converters.blitz)
+
+ for i in range(n):
+ for j in range(i, n):
+ scatters[i,j] = K1[j+1] - K1[i] - (K2[j+1,j+1]+K2[i,i]-K2[j+1,i]-K2[i,j+1])/(j-i+1)
+ return scatters
+
+def cpd_nonlin(K, ncp, lmin=1, lmax=100000, backtrack=True, verbose=True,
+ out_scatters=None):
+ """ Change point detection with dynamic programming
+ K - square kernel matrix
+ ncp - number of change points to detect (ncp >= 0)
+ lmin - minimal length of a segment
+ lmax - maximal length of a segment
+ backtrack - when False - only evaluate objective scores (to save memory)
+
+ Returns: (cps, obj)
+ cps - detected array of change points: mean is thought to be constant on [ cps[i], cps[i+1] )
+ obj_vals - values of the objective function for 0..m changepoints
+
+ """
+ m = int(ncp) # prevent numpy.int64
+
+ (n, n1) = K.shape
+ assert(n == n1), "Kernel matrix awaited."
+
+ assert(n >= (m + 1)*lmin)
+ assert(n <= (m + 1)*lmax)
+ assert(lmax >= lmin >= 1)
+
+ if verbose:
+ #print "n =", n
+ print("Precomputing scatters...")
+ J = calc_scatters(K)
+
+ if out_scatters != None:
+ out_scatters[0] = J
+
+ if verbose:
+ print("Inferring best change points...")
+ I = 1e101*np.ones((m+1, n+1))
+ I[0, lmin:lmax] = J[0, lmin-1:lmax-1]
+
+ if backtrack:
+ p = np.zeros((m+1, n+1), dtype=int)
+ else:
+ p = np.zeros((1,1), dtype=int)
+
+# code = r"""
+# #define max(x,y) ((x)>(y)?(x):(y))
+# for (int k=1; k 1e99] = np.inf
+ return cps, scores
+
+
diff --git a/utils/kts_utils.py b/utils/kts_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3265f9a0d3317578cb0fa5eb1a8033f333bec937
--- /dev/null
+++ b/utils/kts_utils.py
@@ -0,0 +1,204 @@
+import numpy as np
+from .cpd_nonlin import cpd_nonlin
+
+def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs):
+ """Main interface
+
+ Detect change points automatically selecting their number
+ K - kernel between each pair of frames in video
+ ncp - maximum ncp
+ vmax - special parameter
+ Optional arguments:
+ lmin - minimum segment length
+ lmax - maximum segment length
+ desc_rate - rate of descriptor sampling (vmax always corresponds to 1x)
+
+ Note:
+ - cps are always calculated in subsampled coordinates irrespective to
+ desc_rate
+ - lmin and m should be in agreement
+ ---
+ Returns: (cps, costs)
+ cps - best selected change-points
+ costs - costs for 0,1,2,...,m change-points
+
+ Memory requirement: ~ (3*N*N + N*ncp)*4 bytes ~= 16 * N^2 bytes
+ That is 1,6 Gb for the N=10000.
+ """
+ m = ncp
+ (_, scores) = cpd_nonlin(K, m, backtrack=False, **kwargs)
+ # print("scores ",scores)
+
+ N = K.shape[0]
+ N2 = N * desc_rate # length of the video before subsampling
+
+ penalties = np.zeros(m + 1)
+ # Prevent division by zero (in case of 0 changes)
+ ncp = np.arange(1, m + 1)
+ penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1)
+
+ costs = scores / float(N) + penalties
+ m_best = np.argmin(costs)
+ # print("cost ",costs)
+ # print("m_best ",m_best)
+ (cps, scores2) = cpd_nonlin(K, m_best, **kwargs)
+
+ return (cps, costs)
+
+
+# ------------------------------------------------------------------------------
+# Extra functions (currently not used)
+
+def estimate_vmax(K_stable):
+ """K_stable - kernel between all frames of a stable segment"""
+ n = K_stable.shape[0]
+ vmax = np.trace(centering(K_stable) / n)
+ return vmax
+
+
+def centering(K):
+ """Apply kernel centering"""
+ mean_rows = np.mean(K, 1)[:, np.newaxis]
+ return K - mean_rows - mean_rows.T + np.mean(mean_rows)
+
+
+def eval_score(K, cps):
+ """ Evaluate unnormalized empirical score
+ (sum of kernelized scatters) for the given change-points """
+ N = K.shape[0]
+ cps = [0] + list(cps) + [N]
+ V1 = 0
+ V2 = 0
+ for i in range(len(cps) - 1):
+ K_sub = K[cps[i]:cps[i + 1], :][:, cps[i]:cps[i + 1]]
+ V1 += np.sum(np.diag(K_sub))
+ V2 += np.sum(K_sub) / float(cps[i + 1] - cps[i])
+ return (V1 - V2)
+
+
+def eval_cost(K, cps, score, vmax):
+ """ Evaluate cost function for automatic number of change points selection
+ K - kernel between all frames
+ cps - selected change-points
+ score - unnormalized empirical score (sum of kernelized scatters)
+ vmax - vmax parameter"""
+
+ N = K.shape[0]
+ penalty = (vmax * len(cps) / (2.0 * N)) * (np.log(float(N) / len(cps)) + 1)
+ return score / float(N) + penalty
+
+
+def calc_scatters(K):
+ n = K.shape[0]
+ K1 = np.cumsum([0] + list(np.diag(K)))
+ K2 = np.zeros((n + 1, n + 1)).astype(np.double())
+ K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) # TODO: use the fact that K - symmetric
+ # KK = np.cumsum(K, 0).astype(np.double())
+ # K2[1:, 1:] = np.cumsum(KK, 1) # TODO: use the fact that K - symmetric
+
+ scatters = np.zeros((n, n))
+
+ # code = r"""
+ # for (int i = 0; i < n; i++) {
+ # for (int j = i; j < n; j++) {
+ # scatters(i,j) = K1(j+1)-K1(i) - (K2(j+1,j+1)+K2(i,i)-K2(j+1,i)-K2(i,j+1))/(j-i+1);
+ # }
+ # }
+ # """
+ # weave.inline(code, ['K1','K2','scatters','n'], global_dict = \
+ # {'K1':K1, 'K2':K2, 'scatters':scatters, 'n':n}, type_converters=weave.converters.blitz)
+
+ for i in range(n):
+ for j in range(i, n):
+ scatters[i, j] = K1[j + 1] - K1[i] - (K2[j + 1, j + 1] + K2[i, i] - K2[j + 1, i] - K2[i, j + 1]) / (
+ j - i + 1)
+ return scatters
+
+def cpd_nonlin(K, ncp, lmin=1, lmax=100000, backtrack=True, verbose=True,
+ out_scatters=None):
+ """ Change point detection with dynamic programming
+ K - square kernel matrix
+ ncp - number of change points to detect (ncp >= 0)
+ lmin - minimal length of a segment
+ lmax - maximal length of a segment
+ backtrack - when False - only evaluate objective scores (to save memory)
+
+ Returns: (cps, obj)
+ cps - detected array of change points: mean is thought to be constant on [ cps[i], cps[i+1] )
+ obj_vals - values of the objective function for 0..m changepoints
+
+ """
+ m = int(ncp) # prevent numpy.int64
+
+ (n, n1) = K.shape
+ assert (n == n1), "Kernel matrix awaited."
+
+ assert (n >= (m + 1) * lmin)
+ assert (n <= (m + 1) * lmax)
+ assert (lmax >= lmin >= 1)
+
+ if verbose:
+ # print "n =", n
+ print("Precomputing scatters...")
+ J = calc_scatters(K)
+
+ if out_scatters != None:
+ out_scatters[0] = J
+
+ if verbose:
+ print("Inferring best change points...")
+ I = 1e101 * np.ones((m + 1, n + 1))
+ I[0, lmin:lmax] = J[0, lmin - 1:lmax - 1]
+
+ if backtrack:
+ p = np.zeros((m + 1, n + 1), dtype=int)
+ else:
+ p = np.zeros((1, 1), dtype=int)
+
+ # code = r"""
+ # #define max(x,y) ((x)>(y)?(x):(y))
+ # for (int k=1; k 1e99] = np.inf
+ return cps, scores
+
+
diff --git a/utils/model_utils.py b/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..06eed4751ad15e78692e64926dfd2741664949ce
--- /dev/null
+++ b/utils/model_utils.py
@@ -0,0 +1,15 @@
+def count_parameters(model, verbose=True):
+ """Count number of parameters in PyTorch model,
+ References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7.
+
+ from utils.utils import count_parameters
+ count_parameters(model)
+ import sys
+ sys.exit(1)
+ """
+ n_all = sum(p.numel() for p in model.parameters())
+ n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ if verbose:
+ print("Parameter Count: all {:,d}; trainable {:,d}".format(n_all, n_trainable))
+ return n_all, n_trainable
+
diff --git a/utils/span_utils.py b/utils/span_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b3b8a23d6d73820b3bc03b1dd1ce498214eca10
--- /dev/null
+++ b/utils/span_utils.py
@@ -0,0 +1,124 @@
+import pdb
+
+import torch
+
+
+def span_xx_to_cxw(xx_spans):
+ """
+ Args:
+ xx_spans: tensor, (#windows, 2) or (..., 2), each row is a window of format (st, ed)
+
+ Returns:
+ cxw_spans: tensor, (#windows, 2), each row is a window of format (center=(st+ed)/2, width=(ed-st))
+ >>> spans = torch.Tensor([[0, 1], [0.2, 0.4]])
+ >>> span_xx_to_cxw(spans)
+ tensor([[0.5000, 1.0000],
+ [0.3000, 0.2000]])
+ >>> spans = torch.Tensor([[[0, 1], [0.2, 0.4]]])
+ >>> span_xx_to_cxw(spans)
+ tensor([[[0.5000, 1.0000],
+ [0.3000, 0.2000]]])
+ """
+ center = xx_spans.sum(-1) * 0.5
+ width = xx_spans[..., 1] - xx_spans[..., 0]
+ return torch.stack([center, width], dim=-1)
+
+
+def span_cxw_to_xx(cxw_spans):
+ """
+ Args:
+ cxw_spans: tensor, (#windows, 2) or (..., 2), the last dim is a row denoting a window of format (center, width)
+
+ >>> spans = torch.Tensor([[0.5000, 1.0000], [0.3000, 0.2000]])
+ >>> span_cxw_to_xx(spans)
+ tensor([[0.0000, 1.0000],
+ [0.2000, 0.4000]])
+ >>> spans = torch.Tensor([[[0.5000, 1.0000], [0.3000, 0.2000]]])
+ >>> span_cxw_to_xx(spans)
+ tensor([[[0.0000, 1.0000],
+ [0.2000, 0.4000]]])
+ """
+ x1 = cxw_spans[..., 0] - 0.5 * cxw_spans[..., 1]
+ x2 = cxw_spans[..., 0] + 0.5 * cxw_spans[..., 1]
+ return torch.stack([x1, x2], dim=-1)
+
+
+def temporal_iou(spans1, spans2):
+ """
+ Args:
+ spans1: (N, 2) torch.Tensor, each row defines a span [st, ed]
+ spans2: (M, 2) torch.Tensor, ...
+
+ Returns:
+ iou: (N, M) torch.Tensor
+ union: (N, M) torch.Tensor
+ >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]])
+ >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]])
+ >>> temporal_iou(test_spans1, test_spans2)
+ (tensor([[0.6667, 0.2000],
+ [0.0000, 0.5000]]),
+ tensor([[0.3000, 1.0000],
+ [0.8000, 1.0000]]))
+ """
+ areas1 = spans1[:, 1] - spans1[:, 0] # (N, )
+ areas2 = spans2[:, 1] - spans2[:, 0] # (M, )
+
+ left = torch.max(spans1[:, None, 0], spans2[:, 0]) # (N, M)
+ right = torch.min(spans1[:, None, 1], spans2[:, 1]) # (N, M
+
+ inter = (right - left).clamp(min=0) # (N, M)
+ union = areas1[:, None] + areas2 - inter # (N, M)
+
+ iou = inter / union
+ return iou, union
+
+
+def temporal_intersection_over_pred(gt_spans, pred_spans):
+ """ intersection over the second input spans
+ Args:
+ gt_spans: (N, 2),
+ pred_spans: (M, 2)
+
+ Returns:
+
+ """
+ left = torch.max(gt_spans[:, None, 0], pred_spans[:, 0])
+ right = torch.min(gt_spans[:, None, 1], pred_spans[:, 1])
+
+ inter = (right - left).clamp(min=0) # (N, M)
+ inter_over_pred = inter / (pred_spans[:, 1] - pred_spans[:, 0])
+ return inter_over_pred
+
+
+def generalized_temporal_iou(spans1, spans2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+ Also reference to DETR implementation of generalized_box_iou
+ https://github.com/facebookresearch/detr/blob/master/util/box_ops.py#L40
+
+ Args:
+ spans1: (N, 2) torch.Tensor, each row defines a span in xx format [st, ed]
+ spans2: (M, 2) torch.Tensor, ...
+
+ Returns:
+ giou: (N, M) torch.Tensor
+
+ >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]])
+ >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]])
+ >>> generalized_temporal_iou(test_spans1, test_spans2)
+ tensor([[ 0.6667, 0.2000],
+ [-0.2000, 0.5000]])
+ """
+ spans1 = spans1.float()
+ spans2 = spans2.float()
+ assert (spans1[:, 1] >= spans1[:, 0]).all()
+ assert (spans2[:, 1] >= spans2[:, 0]).all()
+ iou, union = temporal_iou(spans1, spans2)
+
+ left = torch.min(spans1[:, None, 0], spans2[:, 0]) # (N, M)
+ right = torch.max(spans1[:, None, 1], spans2[:, 1]) # (N, M)
+ enclosing_area = (right - left).clamp(min=0) # (N, M)
+
+ return iou - (enclosing_area - union) / enclosing_area
+
+
diff --git a/utils/temporal_nms.py b/utils/temporal_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..2844f5d4c1ac71760cd82c7aaf82c6b2daa9a207
--- /dev/null
+++ b/utils/temporal_nms.py
@@ -0,0 +1,74 @@
+"""
+Non-Maximum Suppression for video proposals.
+"""
+
+
+def compute_temporal_iou(pred, gt):
+ """ deprecated due to performance concerns
+ compute intersection-over-union along temporal axis
+ Args:
+ pred: [st (float), ed (float)]
+ gt: [st (float), ed (float)]
+ Returns:
+ iou (float):
+
+ Ref: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+ """
+ intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
+ union = max(pred[1], gt[1]) - min(pred[0], gt[0]) # not the correct union though
+ if union == 0:
+ return 0
+ else:
+ return 1.0 * intersection / union
+
+
+def temporal_nms(predictions, nms_thd, max_after_nms=100):
+ """
+ Args:
+ predictions: list(sublist), each sublist is [st (float), ed(float), score (float)],
+ note larger scores are better and are preserved. For metrics that are better when smaller,
+ please convert to its negative, e.g., convert distance to negative distance.
+ nms_thd: float in [0, 1]
+ max_after_nms:
+ Returns:
+ predictions_after_nms: list(sublist), each sublist is [st (float), ed(float), score (float)]
+ References:
+ https://github.com/wzmsltw/BSN-boundary-sensitive-network/blob/7b101fc5978802aa3c95ba5779eb54151c6173c6/Post_processing.py#L42
+ """
+ if len(predictions) == 1: # only has one prediction, no need for nms
+ return predictions
+
+ predictions = sorted(predictions, key=lambda x: x[2], reverse=True) # descending order
+
+ tstart = [e[0] for e in predictions]
+ tend = [e[1] for e in predictions]
+ tscore = [e[2] for e in predictions]
+ rstart = []
+ rend = []
+ rscore = []
+ while len(tstart) > 1 and len(rscore) < max_after_nms: # max 100 after nms
+ idx = 1
+ while idx < len(tstart): # compare with every prediction in the list.
+ if compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]) > nms_thd:
+ # rm highly overlapped lower score entries.
+ tstart.pop(idx)
+ tend.pop(idx)
+ tscore.pop(idx)
+ # print("--------------------------------")
+ # print(compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]))
+ # print([tstart[0], tend[0]], [tstart[idx], tend[idx]])
+ # print(tstart.pop(idx), tend.pop(idx), tscore.pop(idx))
+ else:
+ # move to next
+ idx += 1
+ rstart.append(tstart.pop(0))
+ rend.append(tend.pop(0))
+ rscore.append(tscore.pop(0))
+
+ if len(rscore) < max_after_nms and len(tstart) >= 1: # add the last, possibly empty.
+ rstart.append(tstart.pop(0))
+ rend.append(tend.pop(0))
+ rscore.append(tscore.pop(0))
+
+ predictions_after_nms = [[st, ed, s] for s, st, ed in zip(rscore, rstart, rend)]
+ return predictions_after_nms
diff --git a/utils/tensor_utils.py b/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c2c25a83b66092b1ce8731b4d9fae1523438b29
--- /dev/null
+++ b/utils/tensor_utils.py
@@ -0,0 +1,93 @@
+import numpy as np
+import torch
+
+
+def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None):
+ """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray)
+ into a (n+1)-d array, only allow the first dim has variable lengths.
+ Args:
+ sequences: list(n-d tensor or list)
+ dtype: np.dtype or torch.dtype
+ device:
+ fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length.
+ return will be of shape [len(sequences), fixed_length, ...]
+ Returns:
+ padded_seqs: ((n+1)-d tensor) padded with zeros
+ mask: (2d tensor) of the same shape as the first two dims of padded_seqs,
+ 1 indicate valid, 0 otherwise
+ Examples:
+ >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
+ >>> pad_sequences_1d(test_data_list, dtype=torch.long)
+ >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)]
+ >>> pad_sequences_1d(test_data_3d, dtype=torch.float)
+ >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
+ >>> pad_sequences_1d(test_data_list, dtype=np.float32)
+ >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)]
+ >>> pad_sequences_1d(test_data_3d, dtype=np.float32)
+ """
+ if isinstance(sequences[0], list):
+ if "torch" in str(dtype):
+ sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences]
+ else:
+ sequences = [np.asarray(s, dtype=dtype) for s in sequences]
+
+ extra_dims = sequences[0].shape[1:] # the extra dims should be the same for all elements
+ lengths = [len(seq) for seq in sequences]
+ if fixed_length is not None:
+ max_length = fixed_length
+ else:
+ max_length = max(lengths)
+ if isinstance(sequences[0], torch.Tensor):
+ assert "torch" in str(dtype), "dtype and input type does not match"
+ padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device)
+ mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device)
+ else: # np
+ assert "numpy" in str(dtype), "dtype and input type does not match"
+ padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype)
+ mask = np.zeros((len(sequences), max_length), dtype=np.float32)
+
+ for idx, seq in enumerate(sequences):
+ end = lengths[idx]
+ padded_seqs[idx, :end] = seq
+ mask[idx, :end] = 1
+ return padded_seqs, mask # , lengths
+
+
+def pad_sequences_2d(sequences, dtype=torch.long):
+ """ Pad a double-nested list or a sequence of n-d torch tensor into a (n+1)-d tensor,
+ only allow the first two dims has variable lengths
+ Args:
+ sequences: list(n-d tensor or list)
+ dtype: torch.long for word indices / torch.float (float32) for other cases
+ Returns:
+ Examples:
+ >>> test_data_list = [[[1, 3, 5], [3, 7, 4, 1]], [[98, 34, 11, 89, 90], [22], [34, 56]],]
+ >>> pad_sequences_2d(test_data_list, dtype=torch.long) # torch.Size([2, 3, 5])
+ >>> test_data_3d = [torch.randn(2,2,4), torch.randn(4,3,4), torch.randn(1,5,4)]
+ >>> pad_sequences_2d(test_data_3d, dtype=torch.float) # torch.Size([2, 3, 5])
+ >>> test_data_3d2 = [[torch.randn(2,4), ], [torch.randn(3,4), torch.randn(5,4)]]
+ >>> pad_sequences_2d(test_data_3d2, dtype=torch.float) # torch.Size([2, 3, 5])
+ # TODO add support for numpy array
+ """
+ bsz = len(sequences)
+ para_lengths = [len(seq) for seq in sequences]
+ max_para_len = max(para_lengths)
+ sen_lengths = [[len(word_seq) for word_seq in seq] for seq in sequences]
+ max_sen_len = max([max(e) for e in sen_lengths])
+
+ if isinstance(sequences[0], torch.Tensor):
+ extra_dims = sequences[0].shape[2:]
+ elif isinstance(sequences[0][0], torch.Tensor):
+ extra_dims = sequences[0][0].shape[1:]
+ else:
+ sequences = [[torch.Tensor(word_seq, dtype=dtype) for word_seq in seq] for seq in sequences]
+ extra_dims = ()
+
+ padded_seqs = torch.zeros((bsz, max_para_len, max_sen_len) + extra_dims, dtype=dtype)
+ mask = torch.zeros(bsz, max_para_len, max_sen_len).float()
+
+ for b_i in range(bsz):
+ for sen_i, sen_l in enumerate(sen_lengths[b_i]):
+ padded_seqs[b_i, sen_i, :sen_l] = sequences[b_i][sen_i]
+ mask[b_i, sen_i, :sen_l] = 1
+ return padded_seqs, mask # , sen_lengths
diff --git a/utils/windows_utils.py b/utils/windows_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3527cdfd7107db5d7eb57afe47f3e8b3bbbc15d
--- /dev/null
+++ b/utils/windows_utils.py
@@ -0,0 +1,59 @@
+"""
+Find windows from a video with clip_ids.
+
+A window is defined by a [start_clip_idx, end_clip_idx] pair:
+For example, assuming clip_len = 2 seconds
+[0, 0] meaning a single clip window [0, 2] (seconds)
+[10, 19] meaning a 9 clip window [20, 40] (seconds)
+
+"""
+
+
+def convert_clip_ids_to_windows(clip_ids):
+ """ Inverse function of convert_windows_to_clip_ids
+ Args:
+ clip_ids: list(int), each is a index of a clip, starting from 0
+
+ Returns:
+ list(list(int)), each sublist contains two integers which are clip indices.
+ [10, 19] meaning a 9 clip window [20, 40] (seconds), if each clip is 2 seconds.
+
+ >>> test_clip_ids = [56, 57, 58, 59, 60, 61, 62] + [64, ] + [67, 68, 69, 70, 71]
+ >>> convert_clip_ids_to_windows(test_clip_ids)
+ [[56, 62], [64, 64], [67, 71]]
+ """
+ windows = []
+ _window = [clip_ids[0], None]
+ last_clip_id = clip_ids[0]
+ for clip_id in clip_ids:
+ if clip_id - last_clip_id > 1: # find gap
+ _window[1] = last_clip_id
+ windows.append(_window)
+ _window = [clip_id, None]
+ last_clip_id = clip_id
+ _window[1] = last_clip_id
+ windows.append(_window)
+ return windows
+
+
+def convert_windows_to_clip_ids(windows):
+ """ Inverse function of convert_clip_ids_to_windows
+ Args:
+ windows: list(list(int)), each sublist contains two integers which are clip indices.
+ [10, 11] meaning a 9 clip window [20, 40] (seconds), if each clip is 2 seconds.
+
+ Returns:
+ clip_ids: list(int)
+
+ >>> test_windows =[[56, 62], [64, 64], [67, 71]]
+ >>> convert_windows_to_clip_ids(test_windows)
+ [56, 57, 58, 59, 60, 61, 62] + [64, ] + [67, 68, 69, 70, 71]
+ """
+ clip_ids = []
+ for w in windows:
+ clip_ids += list(range(w[0], w[1]+1))
+ return clip_ids
+
+
+def convert_clip_window_to_seconds(window, clip_len=2):
+ return [window[0] * clip_len, (window[1] + 1) * clip_len]