diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f3a1426b661451e292ecef2a42ec695b8565df8b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/charades.mp4 filter=lfs diff=lfs merge=lfs -text +examples/ego4d.mp4 filter=lfs diff=lfs merge=lfs -text +examples/youtube.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..87099eaefc08f7b967b350f6a2120385e29f742c --- /dev/null +++ b/app.py @@ -0,0 +1,236 @@ +import os +import pdb +import time +import torch +import gradio as gr +import numpy as np +import argparse +import subprocess +from run_on_video import clip, vid2clip, txt2clip + +parser = argparse.ArgumentParser(description='') +parser.add_argument('--save_dir', type=str, default='./tmp') +parser.add_argument('--resume', type=str, default='./results/omni/model_best.ckpt') +parser.add_argument("--gpu_id", type=int, default=2) +args = parser.parse_args() +os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) + +################################# +model_version = "ViT-B/32" +output_feat_size = 512 +clip_len = 2 +overwrite = True +num_decoding_thread = 4 +half_precision = False + +clip_model, _ = clip.load(model_version, device=args.gpu_id, jit=False) + +import logging +import torch.backends.cudnn as cudnn +from main.config import TestOptions, setup_model +from utils.basic_utils import l2_normalize_np_array + +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def load_model(): + logger.info("Setup config, data and model...") + opt = TestOptions().parse(args) + # pdb.set_trace() + cudnn.benchmark = True + cudnn.deterministic = False + + if opt.lr_warmup > 0: + total_steps = opt.n_epoch + warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps) + opt.lr_warmup = [warmup_steps, total_steps] + + model, criterion, _, _ = setup_model(opt) + return model + +vtg_model = load_model() + +def convert_to_hms(seconds): + return time.strftime('%H:%M:%S', time.gmtime(seconds)) + +def load_data(save_dir): + vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32) + txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32) + + vid = torch.from_numpy(l2_normalize_np_array(vid)) + txt = torch.from_numpy(l2_normalize_np_array(txt)) + clip_len = 2 + ctx_l = vid.shape[0] + + timestamp = ( (torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2) + + if True: + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2) + vid = torch.cat([vid, tef], dim=1) # (Lv, Dv+2) + + src_vid = vid.unsqueeze(0).cuda() + src_txt = txt.unsqueeze(0).cuda() + src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda() + src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda() + + return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l + +def forward(model, save_dir, query): + src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir) + src_vid = src_vid.cuda(args.gpu_id) + src_txt = src_txt.cuda(args.gpu_id) + src_vid_mask = 
src_vid_mask.cuda(args.gpu_id) + src_txt_mask = src_txt_mask.cuda(args.gpu_id) + + with torch.no_grad(): + output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask) + + # prepare the model prediction + pred_logits = output['pred_logits'][0].cpu() + pred_spans = output['pred_spans'][0].cpu() + pred_saliency = output['saliency_scores'].cpu() + + # prepare the model prediction + pred_windows = (pred_spans + timestamp) * ctx_l * clip_len + pred_confidence = pred_logits + + # grounding + top1_window = pred_windows[torch.argmax(pred_confidence)].tolist() + top5_values, top5_indices = torch.topk(pred_confidence.flatten(), k=5) + top5_windows = pred_windows[top5_indices].tolist() + + # print(f"The video duration is {convert_to_hms(src_vid.shape[1]*clip_len)}.") + q_response = f"For query: {query}" + + mr_res = " - ".join([convert_to_hms(int(i)) for i in top1_window]) + mr_response = f"The Top-1 interval is: {mr_res}" + + hl_res = convert_to_hms(torch.argmax(pred_saliency) * clip_len) + hl_response = f"The Top-1 highlight is: {hl_res}" + return '\n'.join([q_response, mr_response, hl_response]) + +def extract_vid(vid_path, state): + history = state['messages'] + vid_features = vid2clip(clip_model, vid_path, args.save_dir) + history.append({"role": "user", "content": "Finish extracting video features."}) + history.append({"role": "system", "content": "Please Enter the text query."}) + chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history),2)] + return '', chat_messages, state + +def extract_txt(txt): + txt_features = txt2clip(clip_model, txt, args.save_dir) + return + +def download_video(url, save_dir='./examples', size=768): + save_path = f'{save_dir}/{url}.mp4' + cmd = f'yt-dlp -S ext:mp4:m4a --throttled-rate 5M -f "best[width<={size}][height<={size}]" --output {save_path} --merge-output-format mp4 https://www.youtube.com/embed/{url}' + if not os.path.exists(save_path): + try: + subprocess.call(cmd, shell=True) + except: + return None + return save_path + +def get_empty_state(): + return {"total_tokens": 0, "messages": []} + +def submit_message(prompt, state): + history = state['messages'] + + if not prompt: + return gr.update(value=''), [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)], state + + prompt_msg = { "role": "user", "content": prompt } + + try: + history.append(prompt_msg) + # answer = vlogger.chat2video(prompt) + # answer = prompt + extract_txt(prompt) + answer = forward(vtg_model, args.save_dir, prompt) + history.append({"role": "system", "content": answer}) + + except Exception as e: + history.append(prompt_msg) + history.append({ + "role": "system", + "content": f"Error: {e}" + }) + + chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)] + return '', chat_messages, state + + +def clear_conversation(): + return gr.update(value=None, visible=True), gr.update(value=None, interactive=True), None, gr.update(value=None, visible=True), get_empty_state() + + +def subvid_fn(vid): + save_path = download_video(vid) + return gr.update(value=save_path) + + +css = """ + #col-container {max-width: 80%; margin-left: auto; margin-right: auto;} + #video_inp {min-height: 100px} + #chatbox {min-height: 100px;} + #header {text-align: center;} + #hint {font-size: 1.0em; padding: 0.5em; margin: 0;} + .message { font-size: 1.2em; } + """ + +with gr.Blocks(css=css) as demo: + + state = gr.State(get_empty_state()) + + + with 
gr.Column(elem_id="col-container"): + gr.Markdown("""## 🤖️ UniVTG: Towards Unified Video-Language Temporal Grounding + Given a video and text query, return relevant window and highlight.""", + elem_id="header") + + with gr.Row(): + with gr.Column(): + video_inp = gr.Video(label="video_input") + gr.Markdown("👋 **Step1**: Select a video in Examples (bottom) or input youtube video_id in this textbox, *e.g.* *G7zJK6lcbyU* for https://www.youtube.com/watch?v=G7zJK6lcbyU", elem_id="hint") + with gr.Row(): + video_id = gr.Textbox(value="", placeholder="Youtube video url", show_label=False) + vidsub_btn = gr.Button("(Optional) Submit Youtube id") + + with gr.Column(): + vid_ext = gr.Button("Step2: Extract video feature, may takes a while") + # vlog_outp = gr.Textbox(label="Document output", lines=40) + total_tokens_str = gr.Markdown(elem_id="total_tokens_str") + + chatbot = gr.Chatbot(elem_id="chatbox") + input_message = gr.Textbox(show_label=False, placeholder="Enter text query and press enter", visible=True).style(container=False) + btn_submit = gr.Button("Step3: Enter your text query") + btn_clear_conversation = gr.Button("🔃 Clear") + + examples = gr.Examples( + examples=[ + ["./examples/youtube.mp4"], + ["./examples/charades.mp4"], + ["./examples/ego4d.mp4"], + ], + inputs=[video_inp], + ) + + gr.HTML('''
You can duplicate this Space to skip the queue: Duplicate Space
''') + + btn_submit.click(submit_message, [input_message, state], [input_message, chatbot]) + input_message.submit(submit_message, [input_message, state], [input_message, chatbot]) + # btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, vlog_outp, state]) + btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, state]) + vid_ext.click(extract_vid, [video_inp, state], [input_message, chatbot]) + vidsub_btn.click(subvid_fn, [video_id], [video_inp]) + + demo.load(queur=False) + + +demo.queue(concurrency_count=10) +demo.launch(height='800px', server_port=2253, debug=True, share=True) diff --git a/examples/charades.mp4 b/examples/charades.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0d77e221ed025160d916a817c96e0016dbd6cea7 --- /dev/null +++ b/examples/charades.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa3d1ba99bf28103844e1313cc5543b7c626d87c42a1c18108c2a69479a6d679 +size 1301669 diff --git a/examples/ego4d.mp4 b/examples/ego4d.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7b2617ddde3a942056ba39fc71ced4a5041f8898 --- /dev/null +++ b/examples/ego4d.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf1271d42415c793e659bebbd48394326cc50e970d44e6fdd0af5dfb4cb4ede4 +size 28306388 diff --git a/examples/youtube.mp4 b/examples/youtube.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3f214974cd37b2cb8d41df55431dda54fd731d06 --- /dev/null +++ b/examples/youtube.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dd6b483e5346a777b5d6448460c5e30b8fe46aa1133cf6bba94c84dd7262b49 +size 47353846 diff --git a/main/__init__.py b/main/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/main/_train_qfvs.py b/main/_train_qfvs.py new file mode 100644 index 0000000000000000000000000000000000000000..c0f5522cb0a721e19970e4cb64409db0900a0691 --- /dev/null +++ b/main/_train_qfvs.py @@ -0,0 +1,293 @@ +import os +import pdb +import time +import json +import pprint +import random +import importlib +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import h5py +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import sys +sys.path.append('/data/home/qinghonglin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle +from utils.model_utils import count_parameters +from eval.qfvs import calculate_semantic_matching, load_videos_tag + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def eval_epoch(model, config, opt): + model.eval() + f1_sum = 0; p_sum = 0; r_sum = 0 + + assert len(config['test_videos']) == 1 + video_id = config['test_videos'][0] + embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl") + + feat_type = config['vid_feature'] + feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r') + features = torch.tensor(feat['feature'][()]).unsqueeze(0).cuda() 
+ # pdb.set_trace() + # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda() + + # dim = features.shape[-1] + # ctx_l = seg_len.sum().cpu() + + dim = features.shape[-1] + ctx_l = features.shape[1] + seg_len = torch.ones(ctx_l) + features = features.reshape(-1, dim)[:ctx_l] + + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2) + features = torch.cat([features, tef], dim=1) # (Lv, Dv+2) + + transfer = {"Cupglass": "Glass", + "Musicalinstrument": "Instrument", + "Petsanimal": "Animal"} + + for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)): + evaluation_num=len(files) + for file in files: + summaries_GT=[] + with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f: + for line in f.readlines(): + summaries_GT.append(int(line.strip())) + + concept1, concept2 = file.split('_')[0:2] + + ############## + if concept1 in transfer: + concept1 = transfer[concept1] + if concept2 in transfer: + concept2 = transfer[concept2] + concept1 = embedding[concept1] + concept2 = embedding[concept2] + + data = { + 'features':features, + 'seg_len': seg_len, + 'tokens_pad1':torch.from_numpy(concept1), + 'tokens_pad2':torch.from_numpy(concept2), + } + + input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True) + + summaries_GT = [x - 1 for x in summaries_GT] + video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat") + + + output_type = 'pred_logits' # only saliency. + # if opt.f_loss_coef == 0: + # output_type = 'saliency_scores' # only saliency. + # elif opt.s_loss_intra_coef == 0: + # output_type = 'pred_logits' # cls is default. + # else: + # output_type = ['pred_logits', 'saliency_scores'] + + # if opt.qfvs_score_multiple > 0: + # output_type = ['pred_logits', 'saliency_scores'] + + with torch.no_grad(): + if not isinstance(output_type, list): + score1 = model(**input1)[output_type].squeeze() + # score1 = score1.masked_select(mask) + score2 = model(**input2)[output_type].squeeze() + # score2 = score2.masked_select(mask) + + score = model(**input_oracle)[output_type].squeeze() + # score = score.masked_select(mask) + else: + score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda() + for output_t in output_type: + # score1 *= model(**input1)[output_t].squeeze() #.masked_select(mask) + # score2 *= model(**input2)[output_t].squeeze() #.masked_select(mask) + # score *= model(**input_oracle)[output_t].squeeze() #.masked_select(mask) + score1 += model(**input1)[output_t].squeeze() #.masked_select(mask) + score2 += model(**input2)[output_t].squeeze() #.masked_select(mask) + score += model(**input_oracle)[output_t].squeeze() #.masked_select(mask) + + score = score + # score = score + score1 + score2 + + # since video4 features dim is greater than video_shots_tag. 
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])] + _, top_index = score.topk(int(score.shape[0] * config["top_percent"])) + p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1) + f1_sum+=f1; r_sum+=r; p_sum+=p + + return {'F': round(100* f1_sum/evaluation_num,2) , + 'R': round(100* r_sum/evaluation_num,2) , + 'P': round(100* p_sum/evaluation_num,2) } + +def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer): + model.train() + criterion.train() + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + timer_dataloading = time.time() + loss_total = 0 + + # optimizer.zero_grad() + for batch_idx, batch in enumerate(tqdm(train_loader)): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + timer_start = time.time() + model_input1, model_input2, model_input_oracle, \ + model_gt1, model_gt2, model_gt_oracle, \ + mask_GT = prepare_batch_inputs_qfvs(batch, config) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + output1 = model(**model_input1) + output2 = model(**model_input2) + output_oracle = model(**model_input_oracle) + + loss_dict = {} + loss_dict1 = criterion(output1, model_gt1) + loss_dict2 = criterion(output2, model_gt2) + loss_dict3 = criterion(output_oracle, model_gt_oracle) + + weight_dict = criterion.weight_dict + for k in loss_dict1.keys(): + loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k] + + # print(loss_dict) + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + loss_total += losses.item() + + time_meters["model_forward_time"].update(time.time() - timer_start) + timer_start = time.time() + # optimizer.zero_grad() + optimizer.zero_grad() + losses.backward() + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + # if ((batch_idx + 1) % opt.bsz==0) or (batch_idx == len(train_loader)-1): + # pdb.set_trace() + # optimizer.step() + # optimizer.zero_grad() + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + timer_dataloading = time.time() + return round(loss_total / len(train_loader), 2) + +# train in single domain. 
+def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config): + if opt.device.type == "cuda": + logger.info("CUDA enabled.") + model.to(opt.device) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0} + if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + + val_score = eval_epoch(model, config, opt) + tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0) + logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]" + f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]" + f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]") + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + val_score = eval_epoch(model, config, opt) + tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1) + logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]" + f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]" + f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]") + + if prev_best_score['Fscore'] < val_score['F']: + prev_best_score['Fscore'] = val_score['F'] + prev_best_score['Precision'] = val_score['P'] + prev_best_score['Recall'] = val_score['R'] + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt")) + tb_writer.close() + return prev_best_score + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + + config = load_json("./main/config_qfvs.json") + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + + # key -> test video; value -> training videos. + qfvs_split = {1: [2, 3, 4], + 2: [1, 3, 4], + 3: [1, 2, 4], + 4: [1, 2, 3]} + # qfvs_split = { + # 2: [1, 3, 4], + # 3: [1, 2, 4], + # } + + scores_videos = {} + for test_id, splits in qfvs_split.items(): + logger.info(f"Start Training {opt.dset_name}: {test_id}") + config['train_videos'] = qfvs_split[test_id] + config['test_videos'] = [test_id] + train_dataset = DatasetQFVS(config) + train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers) + + model, criterion, optimizer, lr_scheduler = setup_model(opt) + count_parameters(model) + best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config) + scores_videos['V'+str(test_id)] = best_score + + # save the final results. 
+ avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos) + avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos) + avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos) + scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall} + + save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json") + save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False) + + tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1) + tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None)) + tb_writer.close() + + print(scores_videos) + return + +if __name__ == '__main__': + start_training() + results = logger.info("\n\n\nFINISHED TRAINING!!!") \ No newline at end of file diff --git a/main/config.py b/main/config.py new file mode 100644 index 0000000000000000000000000000000000000000..40eab1902681354b755102bbfbccd15976a6e9b6 --- /dev/null +++ b/main/config.py @@ -0,0 +1,378 @@ +import os +import pdb +import time +import torch +import logging +import argparse +import importlib +from utils.basic_utils import mkdirp, remkdirp, \ + load_json, save_json, make_zipfile, dict_to_markdown + +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +class BaseOptions(object): + saved_option_filename = "opt.json" + ckpt_filename = "model.ckpt" + tensorboard_log_dir = "tensorboard_log" + train_log_filename = "train.log.txt" + eval_log_filename = "eval.log.txt" + + def __init__(self): + self.parser = None + self.initialized = False + self.opt = None + + def initialize(self): + self.initialized = True + parser = argparse.ArgumentParser() + # * Running configs + parser.add_argument("--dset_type", type=str, choices=["mr", "hl", "vs", "vlp"]) # moment retrieval, highlight detection, and video summarization + parser.add_argument("--dset_name", type=str, choices=["qvhighlights", "charades", "anet", "tvsum", "youtube", "summe", "ego4d", "qfvs", "video2gif", "coin", "hacs", "vlp", "videocc", "tacos"]) + parser.add_argument("--domain_name", type=str, default=None) + parser.add_argument("--model_id", type=str, default="moment_detr") + parser.add_argument("--exp_id", type=str, default="debug", help="id of this run, required at training") + parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu") + parser.add_argument("--gpu_id", type=int, default=0) + parser.add_argument("--debug", action="store_true", + help="debug (fast) mode, break all loops, do not load all data into memory.") + parser.add_argument("--seed", type=int, default=2018, help="random seed") + + # * DDP + parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') + + + parser.add_argument("--eval_split_name", type=str, default="val", + help="should match keys in video_duration_idx_path, must set for VCMR") + parser.add_argument("--data_ratio", type=float, default=1.0, + help="how many training and eval data to use. 1.0: use all, 0.1: use 10%." + "Use small portion for debug purposes. 
Note this is different from --debug, " + "which works by breaking the loops, typically they are not used together.") + parser.add_argument("--results_root", type=str, default="results") + parser.add_argument("--num_workers", type=int, default=0, + help="num subprocesses used to load the data, 0: use main process") + parser.add_argument("--no_pin_memory", action="store_true", + help="Don't use pin_memory=True for dataloader. " + "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4") + + # * Training configs + parser.add_argument("--bsz", type=int, default=32, help="mini-batch size") + parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs to run") + parser.add_argument("--max_es_cnt", type=int, default=200, + help="number of epochs to early stop, use -1 to disable early stop") + parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") + parser.add_argument("--lr_drop", type=int, default=400, help="drop learning rate to 1/10 every lr_drop epochs") + parser.add_argument("--lr_gamma", type=float, default=0.1, help="lr reduces the gamma times after the `drop' epoch") + parser.add_argument("--lr_warmup", type=float, default=-1, help="linear warmup scheme") + parser.add_argument("--wd", type=float, default=1e-4, help="weight decay") + parser.add_argument("--grad_clip", type=float, default=0.1, help="perform gradient clip, -1: disable") + + # ** Loss coefficients + # *** boundary branch + parser.add_argument("--span_loss_type", default="l1", type=str, choices=['l1', 'ce'], + help="l1: (center-x, width) regression. ce: (st_idx, ed_idx) classification.") + parser.add_argument('--b_loss_coef', default=10, type=float) # boundary regression e.g., l1 + parser.add_argument('--g_loss_coef', default=1, type=float) # giou loss + # *** foreground branch + parser.add_argument('--eos_coef', default=0.1, type=float, help="relative classification weight of the no-object class") + parser.add_argument('--f_loss_coef', default=4, type=float) # cls loss for foreground + # *** saliency branch + parser.add_argument("--s_loss_intra_coef", type=float, default=1., help="inter-video (frame-level) saliency loss e.g. momentdetr saliency loss") + parser.add_argument("--s_loss_inter_coef", type=float, default=0., help="intra-video (sample-level) saliency loss,") + + # * Eval configs + parser.add_argument("--main_metric", type=str, default="MR-full-mAP") + parser.add_argument('--eval_mode', default=None, type=str, + help="how to integrate foreground and saliency for better prediction") + parser.add_argument("--eval_bsz", type=int, default=100, + help="mini-batch size at inference, for query") + parser.add_argument("--eval_epoch", type=int, default=5, + help="number of epochs for once inference") + parser.add_argument("--eval_init", action="store_true", help="evaluate model before training i.e. 
`epoch=-1'") + parser.add_argument("--save_interval", type=int, default=50) + + parser.add_argument("--resume", type=str, default=None, + help="checkpoint path to resume or evaluate, without --resume_all this only load weights") + parser.add_argument("--resume_dir", type=str, default=None, + help="checkpoint path to resume or evaluate, without --resume_all this only load weights") + parser.add_argument("--resume_all", action="store_true", + help="if --resume_all, load optimizer/scheduler/epoch as well") + parser.add_argument("--start_epoch", type=int, default=None, + help="if None, will be set automatically when using --resume_all") + + # ** NMS configs + parser.add_argument("--no_sort_results", action="store_true", + help="do not sort results, use this for moment query visualization") + parser.add_argument("--max_before_nms", type=int, default=10) + parser.add_argument("--max_after_nms", type=int, default=10) + parser.add_argument("--conf_thd", type=float, default=0.0, help="only keep windows with conf >= conf_thd") + parser.add_argument("--nms_thd", type=float, default=-1, + help="additionally use non-maximum suppression " + "(or non-minimum suppression for distance)" + "to post-processing the predictions. " + "-1: do not use nms. [0, 1]") + + # * Dataset configs + parser.add_argument("--use_cache", type=int, default=-1, help="Preload features into cache for fast IO") + parser.add_argument("--max_q_l", type=int, default=75) + parser.add_argument("--max_v_l", type=int, default=75) + parser.add_argument("--clip_length", type=float, default=1.0) + parser.add_argument("--clip_len_list", type=int, nargs='+') + parser.add_argument("--max_windows", type=int, default=5) + + parser.add_argument("--add_easy_negative", type=int, default=1) + parser.add_argument("--easy_negative_only", type=int, default=1) + parser.add_argument("--round_multiple", type=int, default=1) + + parser.add_argument("--train_path", type=str, default=None, nargs='+') + parser.add_argument("--eval_path", type=str, default=None, + help="Evaluating during training, for Dev set. If None, will only do training, ") + parser.add_argument("--train_path_list", type=str, nargs='+') + parser.add_argument("--eval_path_list", type=str, nargs='+') + parser.add_argument("--feat_root_list", type=str, nargs='+') + + parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalize video feat") + parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalize text feat") + parser.add_argument("--v_feat_dirs", type=str, nargs="+", + help="video feature dirs. If more than one, will concat their features. 
" + "Note that sub ctx features are also accepted here.") + parser.add_argument("--t_feat_dir", type=str, help="text/query feature dir") + parser.add_argument("--v_feat_dim", type=int, help="video feature dim") + parser.add_argument("--t_feat_dim", type=int, help="text/query feature dim") + parser.add_argument("--ctx_mode", type=str, default="video_tef") + parser.add_argument("--v_feat_types", type=str) + parser.add_argument("--t_feat_type", type=str) + + # * Model configs + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features") + parser.add_argument("--n_input_proj", type=int, default=2, help="#layers to vid/txt projector") + parser.add_argument("--temperature", type=float, default=0.07, help="temperature nce contrastive_align_loss") + + # ** Transformer + parser.add_argument('--enc_layers', default=4, type=int, + help="Number of encoding layers in the transformer") + parser.add_argument('--sub_enc_layers', default=2, type=int, + help="Number of encoding layers in the video / text transformer in albef-style.") + parser.add_argument('--dec_layers', default=2, type=int, + help="Number of decoding layers in the transformer, N/A for UniVTG") + parser.add_argument('--dim_feedforward', default=1024, type=int, + help="Intermediate size of the feedforward layers in the transformer blocks") + parser.add_argument('--hidden_dim', default=256, type=int, + help="Size of the embeddings (dimension of the transformer)") + parser.add_argument('--input_dropout', default=0.5, type=float, + help="Dropout applied in input") + parser.add_argument('--dropout', default=0.1, type=float, + help="Dropout applied in the transformer") + parser.add_argument('--droppath', default=0.1, type=float, + help="Droppath applied in the transformer") + parser.add_argument("--txt_drop_ratio", default=0, type=float, + help="drop txt_drop_ratio tokens from text input. 0.1=10%") + parser.add_argument("--use_txt_pos", action="store_true", help="use position_embedding for text as well.") + parser.add_argument('--nheads', default=8, type=int, + help="Number of attention heads inside the transformer's attentions") + parser.add_argument('--num_queries', default=10, type=int, + help="Number of query slots") + parser.add_argument('--pre_norm', action='store_true') + + # ** momentdetr configs e.g. 
Matcher, saliency margin + parser.add_argument('--set_cost_span', default=10, type=float, + help="L1 span coefficient in the matching cost") + parser.add_argument('--set_cost_giou', default=1, type=float, + help="giou span coefficient in the matching cost") + parser.add_argument('--set_cost_class', default=4, type=float, + help="Class coefficient in the matching cost") + parser.add_argument("--saliency_margin", type=float, default=0.2) + parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_true', + help="Disables auxiliary decoding losses (loss at each layer)") + + # * Query-Force Video Summarization + parser.add_argument("--max_segment_num", type=int, default=20) + parser.add_argument("--max_frame_num", type=int, default=200) + parser.add_argument("--top_percent", type=float, default=0.02) + + parser.add_argument("--qfvs_vid_feature", type=str, default='fps1') + parser.add_argument("--qfvs_txt_feature", type=str, default='query') + parser.add_argument("--qfvs_split", type=int, default=-1) + + parser.add_argument("--qfvs_dense_shot", type=int, default=-1) + parser.add_argument("--qfvs_score_ensemble", type=int, default=-1) + parser.add_argument("--qfvs_score_gather", type=int, default=-1) + parser.add_argument("--qfvs_loss_gather", type=int, default=-1) + self.parser = parser + + def display_save(self, opt): + args = vars(opt) + # Display settings + print(dict_to_markdown(vars(opt), max_str_len=120)) + # Save settings + if not isinstance(self, TestOptions): + option_file_path = os.path.join(opt.results_dir, self.saved_option_filename) # not yaml file indeed + save_json(args, option_file_path, save_pretty=True) + + def parse(self, args=None): + if not self.initialized: + self.initialize() + opt = self.parser.parse_args() + + if args is not None: + args_dict = vars(args) + opt_dict = vars(opt) + for key, value in args_dict.items(): + opt_dict[key] = value + opt = argparse.Namespace(**opt_dict) + opt.model_dir = os.path.dirname(opt.resume) + torch.cuda.set_device(opt.gpu_id) + + if opt.debug: + opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ]) + opt.num_workers = 0 + + if isinstance(self, TestOptions): + # modify model_dir to absolute path + # opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir) + opt.model_dir = os.path.dirname(opt.resume) + saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename)) + for arg in saved_options: # use saved options to overwrite all BaseOptions args. + if arg not in ["results_root", "num_workers", "nms_thd", "debug", "max_before_nms", "max_after_nms" + "max_pred_l", "min_pred_l", "gpu_id", + "resume", "resume_all", "no_sort_results", + "eval_path", "eval_split_name"]: + # "dset_name", "v_feat_dirs", "t_feat_dir"]: + setattr(opt, arg, saved_options[arg]) + # opt.no_core_driver = True + if opt.eval_results_dir is not None: + opt.results_dir = opt.eval_results_dir + else: + if opt.exp_id is None: + raise ValueError("--exp_id is required for at a training option!") + + # ctx_str = opt.ctx_mode + "_sub" if any(["sub_ctx" in p for p in opt.v_feat_dirs]) else opt.ctx_mode + + if 'debug' not in opt.exp_id: + opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), "-".join([opt.exp_id, opt.v_feat_types, opt.t_feat_type, time.strftime("%Y_%m_%d_%H")])) + else: + opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), opt.exp_id) # debug mode. 
+ + if int(opt.local_rank) in [0, -1]: + # mkdirp(opt.results_dir) + remkdirp(opt.results_dir) # remove dir and remkdir it. + + # save a copy of current code + code_dir = os.path.dirname(os.path.realpath(__file__)) + code_zip_filename = os.path.join(opt.results_dir, "code.zip") + make_zipfile(code_dir, code_zip_filename, + enclosing_dir="code", + exclude_dirs_substring="results", + exclude_dirs=["results", "debug_results", "__pycache__"], + exclude_extensions=[".pyc", ".ipynb", ".swap"], ) + + if int(opt.local_rank) in [0, -1]: + self.display_save(opt) + opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename) + opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename) + opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename) + opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir) + # opt.device = torch.device("cuda" if opt.device >= 0 else "cpu") + + if int(opt.local_rank) in [-1]: + torch.cuda.set_device(opt.gpu_id) + opt.pin_memory = not opt.no_pin_memory + + if opt.local_rank == -1: + torch.cuda.set_device(opt.gpu_id) + + opt.use_tef = "tef" in opt.ctx_mode + opt.use_video = "video" in opt.ctx_mode + if not opt.use_video: + opt.v_feat_dim = 0 + if opt.use_tef: + opt.v_feat_dim += 2 + + self.opt = opt + return opt + +class TestOptions(BaseOptions): + """add additional options for evaluating""" + + def initialize(self): + BaseOptions.initialize(self) + # also need to specify --eval_split_name + self.parser.add_argument("--eval_id", type=str, help="evaluation id") + self.parser.add_argument("--eval_results_dir", type=str, default=None, + help="dir to save results, if not set, fall back to training results_dir") + self.parser.add_argument("--model_dir", type=str, + help="dir contains the model file, will be converted to absolute path afterwards") + +class WarmupStepLR(torch.optim.lr_scheduler.StepLR): + def __init__(self, optimizer, warmup_steps, step_size, gamma=0.1, last_epoch=-1): + self.warmup_steps = warmup_steps + self.step_size = step_size + self.gamma = gamma + super(WarmupStepLR, self).__init__(optimizer, step_size, gamma=self.gamma, last_epoch=last_epoch) + def get_lr(self): + if not self._get_lr_called_within_step: + import warnings + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + # e.g. warmup_steps = 10, case: 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21... + if self.last_epoch == self.warmup_steps or(self.last_epoch % self.step_size != 0 and self.last_epoch > self.warmup_steps): + return [group['lr'] for group in self.optimizer.param_groups] + # e.g. warmup_steps = 10, case: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + elif self.last_epoch < self.warmup_steps: + return [group['initial_lr'] * float(self.last_epoch + 1) / float(self.warmup_steps) for group in self.optimizer.param_groups] + + + # e.g. warmup_steps = 10, case: 10, 20, 30, 40... 
+ return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + def _get_closed_form_lr(self): + if self.last_epoch <= self.warmup_steps: + return [base_lr * float(self.last_epoch) / (self.warmup_steps) for base_lr in self.base_lrs] + else: + return [base_lr * self.gamma ** ((self.last_epoch - self.warmup_steps)// self.step_size) for base_lr in self.base_lrs] + +def setup_model(opt): + """setup model/optimizer/scheduler and load checkpoints when needed""" + logger.info("setup model/optimizer/scheduler") + + importer = importlib.import_module('.'.join(['model', opt.model_id])) + model, criterion = importer.build_model(opt) + + if int(opt.device) >= 0: + logger.info("CUDA enabled.") + model.to(opt.gpu_id) + criterion.to(opt.gpu_id) + + param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}] + optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd) + + if opt.lr_warmup != -1 and opt.lr_drop > 0: + lr_scheduler = WarmupStepLR(optimizer, warmup_steps=opt.lr_warmup[0], step_size=opt.lr_drop, gamma=opt.lr_gamma) + + elif opt.lr_warmup != -1: + from transformers import get_constant_schedule_with_warmup + lr_scheduler = get_constant_schedule_with_warmup(optimizer, opt.lr_warmup[0]) + + elif opt.lr_drop > 0: + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop, gamma=opt.lr_gamma) + + if opt.resume is not None: + logger.info(f"Load checkpoint from {opt.resume}") + checkpoint = torch.load(opt.resume, map_location="cpu") + + for key in list(checkpoint["model"].keys()): + checkpoint["model"][key.replace('module.', '')] = checkpoint["model"].pop(key) + model.load_state_dict(checkpoint["model"]) + + if opt.resume_all: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + opt.start_epoch = checkpoint['epoch'] + 1 + logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}") + else: + logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path") + + return model, criterion, optimizer, lr_scheduler diff --git a/main/config_hl.py b/main/config_hl.py new file mode 100644 index 0000000000000000000000000000000000000000..853049a62973aa1b5774257b21d97d2f2a9fdef2 --- /dev/null +++ b/main/config_hl.py @@ -0,0 +1,190 @@ +# Copyright (c) THL A29 Limited, a Tencent company. All rights reserved. 
+ +YOUTUBE_SPLITS = { + 'dog': { + 'train': [ + 'BsjTtq337mM', 'eGCD1F74iy8', 'x2Za-t1yHtI', 'iyYiqa0QZXM', + 'azy9ijU6f9I', 'NNtSZ6cPiwA', 'U9CBalvFfbM', 'AZDkqJaOgJU', + '-olTgMPAyMI', 'i35F1Ec3Ats', '6bS6-GVLBeM', 'ZGszTEn28v8', + 'EEb8iSMqwj4', 'p2hYGNkRMCw', '3kbptPDIz4U', 'iLHRqR-M9HQ', + 'zyooMDuAgCA', 'dOVsQ63N0gg', '7H_qqQvPUzY', 'Z5BEFsaYIS4', + 'iWO6io44-Fs', 'vVmGisWK0QI', 'L10kN7Btk90', '2yql1mvWbDs', + 'Iu2nbtr_Uuk', 'NSmOKAauZpM', 'PAhQGoURAro', 'uJ81Us4mBOc', + '1krGVyfIaOw', 'p9yW6FxsrJ4', 'DLGRJfpGmCQ', '0XTXKe2TOAg', + 'qpc4OSqeV7I', 'q_PJFuBOk7k', '0Uu53hCnKQ4', '-szRD9kyNug', + 'rUPxwWmJYpg', 'hseONiKKx_8', 'BLaQcOcDfjo', 'nW5JulWYEc8', + 'rMvH1SMGwwI', 'l6KlvTJkTgk', 'O8j4U3NjNvs', '8AJTZeEeStk' + ], + 'val': [ + 'a2nj7XCo2Rk', '9rP5yF9EC3Y', 'OxSsRZqPfyk', 'bZzP2MieC1c', + 'PcvdX5OVgfQ', 'p0oxRJD1GUk', 'msjK8nHZHZ0', 'hSRyclcZyGM', + 'dlH2K9N_jSM', 'OCVXhRG2fEA', 'MkBdHvXPocc', 'yN7h90Y-04g', + 'PWqLJKZeBC8', '9D_Q8l_ruQk', 'Mp8Pz86J660', '1gjntnYm8NA', + 'O3XxuutEvoo', 'wf_qlAizlSM', 'fXx44D1sqUw', 'P0MnXh6bnKk', + 'sTd06idFa0E', 'ppNjl3I3iJs', 'Om5mczkpcVg', 'xZIN_s-qhbU' + ] + }, + 'gymnastics': { + 'train': [ + 'Wfv90YJ2YtA', 'MbD5OIR9yWc', 'fZwCJWkC_Qw', 'AyRI1CioQfY', + 'xV_5YCdVqSM', '19UO7T32DJI', 'o2gAP2Clg_s', 'ewyfAOrBzjQ', + 'CMTKpA683Ig', 'aNjphhjTgqs', 'dmJ0Nq4DF2w', '57IQ6EudvGU', + 'BAlUYtPUsVI', '_UU4XqYVDqE', 'Kq4OhBiQk_E', 'D6nyvx9kEac', + 'g-m4-zeCisU', '_45vTFtcduE', '9L-Pocc_u70', '0636XaURL-A', + 'GCabQyaHSMg', 'vUi1Scb35fQ', 'eK-Yuoou_1I', 'kkS7TgNZwJI', + '2EFkINKg3nA', 'eKvALYDh7RU', 'Hyp3Hpk6dyA', '9rpzf3sgQkw', + 'kHNAnpewyeo', 'ydQij10qrZM', '41u2V_ZAKto', '6NSWsMKAgEU', + 'kUs_yUR-C2k', 'bs3ZBcfhvKA' + ], + 'val': [ + '2AuigNFEsTM', 'rPsKpHKzUso', 'tzq5cJQ9NQA', 'DyZ0gZ5xmxI', + 'PEKRfJYYEgU', 'affAIVH9uRA', 'FT7yIi3-tG0', 'T_zWyrVzyvw', + 'RoiLzMA_ilA', 'nBZiGSccsTg', 'z3cNtOMKK7A', 'EwQ-aMK2sKg', + 'Rq0BpciuvBM', 's6LNwTThBgs', '-hE9v3izo4c', 'KldEfRhv7H0', + 'eUyuw2J5FaE', 'E0aRE1_ea8E', 'BU7YlQAOBkM', 'iDJM9j11U-c', + 'zr5LSPMBpiI', 'NAfBa7lqg2Q', 'eB4Toq9dUWs', 'YPd7RDN5CkE', + '86YLsw7efDM', 'iQRMMFiYAUw', 'lzEhLAPxZyQ', 'PAjJbT1DRnY' + ] + }, + 'parkour': { + 'train': [ + 'qz1UnnxlWhI', 'MzODICzycHs', '0swXWs9yWA4', 'Nnv22OW_PaI', + 'LUhZJLY2uKc', 'yZz8z1l3XJU', '3dvjtdMC2ls', 'e27ppPer9XY', + 'HJNn2WlKFhM', 'j4OxlxnapNI', 'rhABvn7VjSQ', '3PCwXpwYqLs', + 'LECL1bIpi5w', 'w0ouP79iZWc', 'z6aKQPMJUC0', 'kATlFTwxBVY', + '3SM6a8eyuVA', 'v-Sfc4COqRQ', '64eu8pwuIUE', '7WKm0XDk3og', + '2F5Sc0Jgk4g' + ], + 'val': [ + 'TFdbCRkVeIA', 'uGLs9atTvNc', 'qlGPuopK3CI', 'ucTkpjZO_o4', + '4-4BgyGphLQ', '08k4ysX_XJE', '6sMNnWqa_as', 'oT6g0I2Ok9o', + 'Be4IlnKeBOo', 'yUjJq0kvxcw', 'fLek7GRIxjE' + ] + }, + 'skating': { + 'train': [ + '7owXLUkpoNY', '1OLM0_Jzt5M', 'b1LXb0Sbiy0', '3fGux6-ttlA', + 'HQvRun80GyA', 'a8M-5nTrll8', 'bA3CxZllhsI', 'AUAsfZtcB4E', + 'FG57uCJvQLw', 'jXIuv5uFPTI', 'eG-hdYLoS98', '2SdJBl251PU', + '2PHJqqrGC80', 'EtZkkFhniRw', 'jUiwyguxzIw', 'FL6mXlaF78Q', + 'BdemklZtYWI', 'ATk_ncI1-BA', '4wiKDfq3X8U', 'BN7GBjVlFTo', + 'JiMZvMkkbRo', '2DIXYkSnRf4', 'dZ3i-HuhQXM', '7jZydh62m8M' + ], + 'val': [ + '2oOe2_Ew6Ao', 'DGcO0QgcXtw', 'ixsKaNplm6o', '7TQbqKWjLcI', + 'CQZNrEstSag', 'g1WbAIzkw80', '4cyx1VpDjc4', 'BGZaaqFjoRY', + 'AJ98A2y1dVw', '1n7Afe5AZCM', '8x8ESK5MnR0' + ] + }, + 'skiing': { + 'train': [ + '6Usy87KaF-A', 'DtjKkp_4KDQ', '4Wt7TM2wDxI', 'iKnzSGFwdbc', + 'nALCc6HPQNs', 'WL4TA--CVcA', 'dFrfsgW1M98', 'x6qmrVojcYc', + 'pvcmQ9J_BYw', 'S3VEYFAP_pk', 'pU57a3jYMEk', '33TrLdo3ook', + 'xLhHU8uo2aY', 'fAHBmka6Psc', '9HYzZk5kiJA', 
'T0gjqYbeU1g', + '7o628W-bFy0', 'YKDm_PCa-HM', 'R3DV2zDnNqg', 'NCe9YeXTvHo', + '5tXxvscmZ-Y', 'thNiPQLbi5w', '1TtJy8cSzqA', 'zDRzOsmwa08', + 'gCI4gArPjNA', 'uw0i26NHucs', '1giAsZC_ywQ', 'OvgaPTfEnqo', + 'bFD_p5znoq4', 'uKmqaAvjKgw', '5ivw_sdCTCU', 'iwCSAYGwPq4', + 'HmmOPntPlRA', 'FHCEyiM-NoY', 'EUSFMmoE_jI', 'igvSxtdsT8w', + 'zEgMYFiEaX4', '0K2FKccDp9A', 'tdyz6h4ZtYs', 'PO7GEbi2z3c', + 'mmiu7rRmSAU', 'qL6Kic-CdTo', '0fNCsOY1WGk', 'V3J26hr1ZSE', + 'GS-qBunN3B4', 'ZLNvg8025Nw', 'puAxGH6aWMY', 'h-SlvHubhs8', + 'AdovZ4OAS8I', 'UDvA1XMa1m4', 'qdo3d7mR_9s', 'qAinbyORWIw', + 'v1JpJueAElY', 'TjH29fdjcqI', 'f76B1uucoyo', 'DNPPDcOd5eQ', + '-GX95udKKm8', 'YRO_RQ3aBgg', '1ptV2E7lm9U', 'qa7dtf1Qcew', + '_UJTkqYNrpA', 'md14DNKq2_o', 'tpewrb9dDyo', 'yGoWYi_dHLY', + 'DZ3NRjDHwy8', 'aMFcEuJUqpk', '6fT9KLuE7no', 'lPdQMMAuOZo' + ], + 'val': [ + 'SSlv7qJK5zA', '_BYqZjuKpKA', 'ZueaKXReGjU', 'mGST8ZekCZc', + 'JJSu7Lh9rvs', 'IyoD3G5igY0', 'MXyv-Ut9HRg', 'Z8X9WIojH1U', + 'vT33-8KUb2Q', 'HW6_sPym938', '9wtXO2lF6hM', 'mRdthCqe6Nk', + 'RGxiOb9hlS0', 'ruySf5zL7Kw', 'I7wFmP6P7p0', '0AHkDElk3ws', + 'zqXd4EgUFhE', '91lDbBHUx0w', 'iaHbK6ogafc', 'jRbst8kjWW8', + 'drHPy6wSZGs', '5VaY6LgIqDs', 'bXq9rRSbI3c', 'hjZLa2DTuqs', + 'Ka2qcp3jmWo', 'ZnA4-ggkFu8', 'iXdt4v42mbs', '8aWN-0NZErI', + '09v0HNf81J0', 'YJCR2q-WRhQ', 'RjagI4pAUpw', '_10CbYdTG5M', + 'lhgmIgzBQxs', '2pstGBM4p0w', 'b53-VPsWom4', 'x-G4r153n6o', + 'qBbqK5qlVSM', 'XamrS9XyHuQ', 'u_n7jMS1vlw', 'AO6p0jlOd6U', + 'm-W-lcTkBQ0', 'bMuyPVIlXW8', 'kAAvTAKkIy4', 'U6vnbCurZQA', + 'dHE8q7sZ70U', 'w7fzLVRPSUc', 'FLYkD7zHuHQ', 'nhOhI24P7dM', + 'n5q2KhfoiWw', '7Hcyse0h9HE', '6_BPy_VaPSY' + ] + }, + 'surfing': { + 'train': [ + 'Ai9FwQGn5ds', 'hBl0Sm3_auw', 'LMxMeg407Vg', 'D3fk8doVui4', + 'Y9pxmLg6ti8', 'p_JsivYdbgQ', 'UokX-hcXQeo', 'VYe5QfM5ecE', + 'I48VJ92ouTQ', 'Tn-ebtUnq6E', 'eWae-nWocPU', '-Yamat_0tbw', + 'c2Fy-rdXJy4', 'xQ4NAp4vWbI', 'g9kXCIjIjoE', 'A96Jx6gv6_4', + 'e427qElqqN0', 'tTcA5hiViPo', 'wMdXzj_3aA0', 'fqNzMz1n6uA', + 'jKVOA7RFCUo', 'TJBJrk9iPPA', '_C8EjMxrS2s', 'yj7abHfZTQQ', + 'NDcqgpsyWaU', 'UJjwoivaGNo', 'GZ_XS8EnnWo', 'kJUBIcBjUZ0', + 'lWoLyR7lDAU', 'FilbyF_PGjI', 'fapRkcOe4vE', 't05r50PQqww', + 'QgStLppe610', '2TY8Q2WXUyk', '9y_ED3DyNhE', 'CGwtinVGkVU', + 'nOuRhrAMaIw', 'UN4TwjDajtQ', '-FHmVZWWgcE', 'ksx0_BfpsLg', + 'agOBPDsQrTM', 'XqggBwFOmFU', 'orNzj1J8i-4', '6ZbTCHwt1gk', + '0un3wh_pQAc', '4u6OURBLZDs', 'us0agAKuvEM', 'mVQYl7Q-TQs', + 'cB2SdlGHLMQ', 'WK5t4To0zlA', 'NNEuH_juUHI', 'KTU7xfVOat0', + 'Y1nhbNaY1ZY', 'YlXJnZe575s', 'SH7Ns0ANzJU', '3TbZfeokCkE' + ], + 'val': [ + 'o0on6yIXJQE', '4RsZz_8d8Ro', 'p8VUjcZyK70', '0P2PZXUa0Bg', + 'p2eU5z647Mw', 'mSVxaAJcNJQ', 'bcmXVyFbsRg', 'Eiq8GHi4kEo', + 'H5FEdJYokO4', 'Mkyp0z_Cgig', 'NB5Ez5kJfMU', 'Xa0y6b6Vm6U', + 'gVcCGUtpA90', '0-fstXuo_Pw', '-d72e4v9skA', 'lbp6_wCXqvw', + '9GpZHq1n8ps', 'CefGXyYu_zU', 'SI2JbS48Upg', 'hdklRTNrq0I', + 'J-P-t6g19SM', 'K0f_DpVOjfA', 'lw_1fEY9QTo', 'uUuYnKLETLw', + 'HwKv3Xc5MAE', 'wvQ0h5Nwsxc', 'l8ME6z_EWKE', 's9dTu2fcbNg', + 'GS09SevPYT4', 'YbwdDCzVczU', 'jaCOI_VwIjc', '3Y1Jp1_fFLQ', + '82OzgxT2tH8', 'IjQhHPlTfdE', 'KzQcJrT91jU', 't05AD0c08zE', + 'rGxWxX6nYO4', 'QGp0kRzKiAc', 'pK9gDWoOyko', 'Srjd4pe6vck', + 'twGcxuhCXoU', 'AshLUHPEb8M', '8En3M5CUc2E', '8sTJfTUk1d0', + 'o-bubyWTw60', 'NctbssxGCtU', 'L09Qo1ql0nM' + ] + } +} + +TVSUM_SPLITS = { + 'BK': { + 'train': ['WxtbjNsCQ8A', 'EE-bNr36nyA', 'oDXZc0tZe04', 'uGu_10sucQo'], + 'val': ['Se3oxnaPsz0'] + }, + 'BT': { + 'train': ['eQu1rNs0an0', 'qqR6AEXwxoQ', 'EYqVtI9YWJA', 'iVt07TCkFM0'], + 'val': ['JgHubY5Vw3Y'] + }, + 'DS': { + 'train': 
['kLxoNp-UchI', 'NyBmCxDoHJU', 'jcoYJXDG9sw', '-esJrBWj2d8'], + 'val': ['E11zDS9XGzg'] + }, + 'FM': { + 'train': ['_xMr-HKMfVA', 'byxOvuiIJV0', 'VuWGsYPqAX8', 'xmEERLqJ2kU'], + 'val': ['JKpqYvAdIsw'] + }, + 'GA': { + 'train': ['xxdtq8mxegs', 'i3wAGJaaktw', '0tmA_C6XwfM', '3eYKfiOEJNs'], + 'val': ['Bhxk-O1Y7Ho'] + }, + 'MS': { + 'train': ['Hl-__g2gn_A', 'WG0MBPpPC6I', 'LRw_obCPUt0', '37rzWOQsNIw'], + 'val': ['Yi4Ij2NM7U4'] + }, + 'PK': { + 'train': ['GsAD1KT1xo8', 'XkqCExn6_Us', 'b626MiF1ew4', 'PJrm840pAUI'], + 'val': ['cjibtmSLxQ4'] + }, + 'PR': { + 'train': ['RBCABdttQmI', 'z_6gVvQb2d0', '4wU_LUjG5Ic', '91IHQYk1IQM'], + 'val': ['fWutDQy1nnY'] + }, + 'VT': { + 'train': ['gzDbaEs1Rlg', 'XzYM3PfTM4w', '98MoyGZKHXc', 'AwmHb44_ouw'], + 'val': ['J0nA4VgnoCo'] + }, + 'VU': { + 'train': ['akI8YFjEmUw', 'HT5vyqe0Xaw', 'vdmoEJ5YbrQ', 'xwqBXPGE9pQ'], + 'val': ['sTEELN-vY30'] + } +} \ No newline at end of file diff --git a/main/config_qfvs.json b/main/config_qfvs.json new file mode 100644 index 0000000000000000000000000000000000000000..1948047231c7911893e31576f5b1dec34e68f148 --- /dev/null +++ b/main/config_qfvs.json @@ -0,0 +1,14 @@ +{ + // "max_segment_num": 20, + // "max_frame_num": 200, + + // "train_videos": null, + // "test_videos": null, + // "top_percent": 0.02, + + // "vid_feature": "fps1", + // "txt_feature": "query", + // "txt_max_len": 5, + + // "factor": null +} \ No newline at end of file diff --git a/main/dataset.py b/main/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e534a0c17c64a53cfd29cf1a9188299c0088e42c --- /dev/null +++ b/main/dataset.py @@ -0,0 +1,1261 @@ +import os +import pdb +import h5py +import nncore +import torch +from torch.utils.data import Dataset +import numpy as np +from tqdm import tqdm +import random +import logging +from os.path import join, exists +from nncore.dataset import DATASETS +from nncore.parallel import DataContainer +from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS +from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array +from utils.tensor_utils import pad_sequences_1d +from utils.span_utils import span_xx_to_cxw +from random import shuffle + +logger = logging.getLogger(__name__) + +class DatasetVLP(Dataset): + Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"] + """One line in data loaded from data_path." 
+ { + "qid": 7803, + "query": "Man in gray top walks from outside to inside.", + "duration": 150, + "vid": "RoripwjYFp8_360.0_510.0", + "relevant_clip_ids": [13, 14, 15, 16, 17], + "relevant_windows": [[26, 36]] + } + """ + def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim, + q_feat_type="last_hidden_state", + max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video", + normalize_v=True, normalize_t=True, load_labels=True, + clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0, + use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1): + self.dset_name = dset_name + self.data_path = data_path + self.data_ratio = data_ratio + self.v_feat_dirs = v_feat_dirs \ + if isinstance(v_feat_dirs, list) else [v_feat_dirs] + self.q_feat_dir = q_feat_dir + self.q_feat_type = q_feat_type + self.v_feat_dim = v_feat_dim + self.q_feat_dim = q_feat_dim + self.max_q_l = max_q_l + self.max_v_l = max_v_l + self.ctx_mode = ctx_mode + self.use_tef = "tef" in ctx_mode + self.use_video = "video" in ctx_mode + self.normalize_t = normalize_t + self.normalize_v = normalize_v + self.load_labels = load_labels + self.clip_len = clip_len + self.fix_len = fix_len + self.max_windows = max_windows # maximum number of windows to use as labels + self.span_loss_type = span_loss_type + self.txt_drop_ratio = txt_drop_ratio + self.use_cache = use_cache + self.add_easy_negative = add_easy_negative + self.easy_negative_only = easy_negative_only + + self.vlp_mapping = { + # 'data/qvhighlights/metadata/qvhighlights_asr.jsonl': { + # 'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '_asr', 'type': 'interval', + # }, + # 'data/ego4d/metadata/point_train_1m.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + # }, + # 'data/ego4d/metadata/point_train_1m_0.1p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + # }, + # 'data/ego4d/metadata/point_train_1m_0.2p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + # }, + # 'data/ego4d/metadata/point_train_1m_0.5p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + # }, + # 'data/ego4d/metadata/point_train_1m_0.75p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + # }, + # 'data/ego4d/metadata/point_train_2m.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + # }, + # 'data/ego4d/metadata/point_train_1m_egoclip.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + # }, + # 'data/hacs/metadata/hacs_train_cs.jsonl': { + # 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '_cs', 'type': 'curve', + # }, + # 'data/hacs/metadata/hacs_train.jsonl': { + # 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve', + # }, + # 'data/videocc/metadata/train_300k.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + # }, + # 'data/videocc/metadata/train_600k.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + # }, + # 'data/videocc/metadata/train_600k_0.1p.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + # }, + # 'data/videocc/metadata/train_600k_0.2p.jsonl': 
{ + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + # }, + # 'data/videocc/metadata/train_600k_0.5p.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + # }, + # 'data/videocc/metadata/train_600k_0.75p.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + # }, + # 'data/videocc/metadata/train_900k.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + # }, + # 'data/ego4d/metadata/concept_train_top10_window.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/ego4d/metadata/concept_train_top5_window.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/ego4d/metadata/concept_train_top5_window_0.1p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/ego4d/metadata/concept_train_top5_window_0.2p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/ego4d/metadata/concept_train_top5_window_0.5p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/ego4d/metadata/concept_train_top5_window_0.75p.jsonl': { + # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/videocc/metadata/concept_train_top10_window.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/videocc/metadata/concept_train_top5_window.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/videocc/metadata/concept_train_top5_window_0.1p.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/videocc/metadata/concept_train_top5_window_0.2p.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/videocc/metadata/concept_train_top5_window_0.5p.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # 'data/videocc/metadata/concept_train_top5_window_0.75p.jsonl': { + # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve', + # }, + # + # pre-training + 'data/ego4d/metadata/point_egoclip_wo_val.jsonl': { + 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point', + }, + 'data/videocc/metadata/interval_900k.jsonl': { + 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + }, + 'data/videocc/metadata/curve_5_window.jsonl': { + 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve', + }, + # downstream + 'data/qvhighlights/metadata/qvhighlights_train.jsonl': { + 'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve', + }, + 'data/charades/metadata/charades_train.jsonl': { + 'dset_name': 'charades', 'v_feat_suffix': '_2', 'q_feat_suffix': '', 'type': 'interval', + }, + 'data/ego4d/metadata/nlq_train.jsonl': { + 'dset_name': 'ego4d', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + }, + 'data/tacos/metadata/train.jsonl': { + 'dset_name': 
'tacos', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + }, + 'data/anet/metadata/train.jsonl': { + 'dset_name': 'anet', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + }, + 'data/didemo/metadata/train.jsonl': { + 'dset_name': 'didemo', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval', + }, + } + + if "val" in data_path or "test" in data_path: + assert txt_drop_ratio == 0 + + # checks + assert q_feat_type in self.Q_FEAT_TYPES + + # data + self.data = self.load_data() + + self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs] + t_feat_type = q_feat_dir.split('/')[-1] + + if self.use_cache > 0: + print('Loading the off-line features...') + dset_dir = os.path.join('data', self.dset_name) + vid_keys = [meta['vid'] for meta in self.data] + qid_keys = [meta['qid'] for meta in self.data] + + self.vid_cache = {} + for v_feat_type in self.v_feat_types: + assert 'vid' in v_feat_type + with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f: + self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)} + + assert 'txt' in t_feat_type + self.txt_cache = {} + with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f: + for key in tqdm(qid_keys): + try: + self.txt_cache[key] = f[str(key)][:] + except: + logger.info(f"text {key} is not in the cache.") + + def load_data(self): + # datalist = load_jsonl(self.data_path[0]) + datalist = [] + for dset_path in self.data_path: + dset_info = self.vlp_mapping[dset_path] + dset_list = load_jsonl(dset_path) + for x in dset_list: x.update(dset_info) + datalist += dset_list + n_examples = int(len(datalist)) + if self.data_ratio != 1: + n_examples = int(len(datalist) * self.data_ratio) + shuffle(datalist) + datalist = datalist[:n_examples] + logger.info("Using {}% of the data: {} examples" + .format(self.data_ratio * 100, n_examples)) + return datalist + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + meta = self.data[index] + + model_inputs = dict() + model_inputs["query_feat"] = self._get_query_feat_by_qid(meta) # (Dq, ) or (Lq, Dq) + + if self.use_video: + model_inputs["video_feat"] = self._get_video_feat_by_vid(meta) # (Lv, Dv) + ctx_l = len(model_inputs["video_feat"]) + else: + ctx_l = self.max_v_l + + if meta['dset_name'] in ['hacs', 'ego4d', 'activitynet']: + for i, window_i in enumerate(meta["relevant_windows"]): + if window_i[1] - window_i[0] < self.clip_len: + center = (window_i[1] + window_i[0]) / 2 + window_i[0] = max(0, center - 0.5 * self.clip_len) + window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len) + window_i[1] = max(self.clip_len, window_i[1]) + + model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2) + + if 'test' in self.data_path and 'qvhighlights' in self.dset_name: + meta["relevant_windows"] = [[0, 150]] + relevant_windows = torch.Tensor(meta["relevant_windows"]) + + # assign the nearest window for each timestamp i.e., qvhighlights. 
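+        # Using the docstring example above as an illustration: duration 150s with clip_len=2
+        # gives ctx_l = 75 clips, and the ground-truth window [26, 36] is rescaled below by
+        # (ctx_l * clip_len) = 150 to roughly [0.173, 0.240]. A clip counts as inside a window
+        # when its normalized reference time ("timestamp") falls within that span, which is what
+        # the diff_left / diff_right checks implement.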
+ num_vid_seq = model_inputs["timestamp"].shape[0] + num_windows = relevant_windows.shape[0] + + relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len) + relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1) + model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1) + + if meta['qid'] is not None: + nn_window_ts = torch.zeros_like(model_inputs["timestamp"]) + diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0] + diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1] + assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0)) + if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet. + nn_window_ts = relevant_windows_ts.squeeze(1) + else: + nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]] + + model_inputs["span_labels_nn"] = nn_window_ts + model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1]) + + # for activitynet. + if model_inputs["timestamp_window"].sum() < 1: + idx = int(meta['relevant_windows'][0][0] / self.clip_len) + idx = max(0, min(idx, ctx_l-1)) + model_inputs["timestamp_window"][idx] = 1 + + if self.use_tef: + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2) + if self.use_video: + model_inputs["video_feat"] = torch.cat( + [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2) + else: + model_inputs["video_feat"] = tef + + if self.load_labels: + model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2) + if 'saliency_scores' in meta.keys(): + # this is for highlight-only task + model_inputs["saliency_scores"] = torch.zeros(ctx_l).double() + limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None + model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1)) + model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \ + self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l) + # pdb.set_trace() + else: + model_inputs["saliency_scores"] = model_inputs["timestamp_window"] + model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \ + self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt + model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ] + + if 'type' in meta.keys(): + if meta['type'] == 'point': + model_inputs['weight_ablation'] = torch.tensor([0, 0, 1, 0, 0]) + if meta['type'] == 'interval': + model_inputs['weight_ablation'] = torch.tensor([1, 1, 0, 0, 0]) + if meta['type'] == 'curve': + model_inputs['weight_ablation'] = torch.tensor([0, 0, 0, 1, 1]) + + return dict(meta=meta, model_inputs=model_inputs) + + def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1): + gt_st = int(gt_window[0] / self.clip_len) + gt_st = min(gt_st, ctx_l-1) + gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1) + if gt_st > gt_ed: + # gt_st = gt_ed + gt_ed = gt_st + + if gt_st != gt_ed: + pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n) + else: + pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st] + + neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l)) + # neg_clip_indices = random.sample(neg_pool, 
k=max_n) + + try: + neg_clip_indices = random.sample(neg_pool, k=max_n) + except: + neg_clip_indices = pos_clip_indices + + return pos_clip_indices, neg_clip_indices + + def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1): + """Sum the scores from the three annotations, then take the two clips with the + maximum scores as positive, and two with the minimum scores as negative. + Args: + rel_clip_ids: list(int), list of relevant clip ids + scores: list([anno1_score, anno2_score, anno3_score]), + ctx_l: int + max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively. + add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids. + """ + # indices inside rel_clip_ids + scores = np.array(scores) # (#rel_clips, 3) + agg_scores = np.sum(scores, 1) # (#rel_clips, ) + sort_indices = np.argsort(agg_scores) # increasing + + # indices in the whole video + # the min(_, ctx_l-1) here is incorrect, but should not cause + # much troubles since this should be rarely used. + hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]] + hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]] + + if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]: + hard_neg_clip_indices = hard_pos_clip_indices + + easy_pos_clip_indices = [] + easy_neg_clip_indices = [] + # pdb.set_trace() + if self.add_easy_negative > 0: + easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids)) + if len(easy_neg_pool) >= max_n: + easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n) + easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n) + else: # copy the hard ones + easy_pos_clip_indices = hard_pos_clip_indices + easy_neg_clip_indices = hard_neg_clip_indices + + if self.easy_negative_only > 0: + return easy_pos_clip_indices, easy_neg_clip_indices + + pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices + neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices + + return pos_clip_indices, neg_clip_indices + + def get_span_labels(self, windows, ctx_l): + """ + windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive) + Note a maximum of `self.max_windows` windows are used. 
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length + """ + if len(windows) > self.max_windows: + random.shuffle(windows) + windows = windows[:self.max_windows] + if self.span_loss_type == "l1": + windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx + windows = span_xx_to_cxw(windows) # normalized windows in cxw + elif self.span_loss_type == "ce": + windows = torch.Tensor([ + [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1] + for w in windows]).long() # inclusive + else: + raise NotImplementedError + return windows + + def _get_query_feat_by_qid(self, meta): + qid = meta['qid'] + dset_name = meta['dset_name'] + q_feat_suffix = meta['q_feat_suffix'] + q_feat_dir = self.q_feat_dir + q_feat_suffix + + if self.use_cache > 0: + try: + q_feat = self.txt_cache[qid] + except: + q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32) + return torch.from_numpy(q_feat) + + q_feat_path = os.path.join('data', dset_name, q_feat_dir, f"{qid}.npz") + try: + q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32) + except: + q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32) + logger.info(f"Something wrong when loading the query feature {q_feat_path}.") + + if self.q_feat_type == "last_hidden_state": + # q_feat = q_feat[:self.max_q_l] + q_feat = q_feat + if self.normalize_t: + q_feat = l2_normalize_np_array(q_feat) + if self.txt_drop_ratio > 0: + q_feat = self.random_drop_rows(q_feat) + return torch.from_numpy(q_feat) # (D, ) or (Lq, D) + + def random_drop_rows(self, embeddings): + """randomly mask num_drop rows in embeddings to be zero. + Args: + embeddings: np.ndarray (L, D) + """ + num_drop_rows = round(len(embeddings) * self.txt_drop_ratio) + if num_drop_rows > 0: + row_indices = np.random.choice( + len(embeddings), size=num_drop_rows, replace=False) + embeddings[row_indices] = 0 + return embeddings + + def _get_video_feat_by_vid(self, meta): + dset_name = meta['dset_name'] + v_feat_suffix = meta['v_feat_suffix'] + vid = meta['vid'] + + v_feat_list = [] + for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs): + v_feat_dir = _feat_dir + v_feat_suffix + if self.use_cache > 0: + _feat = self.vid_cache[feat_type][vid] + else: + _feat_path = os.path.join('data', dset_name, v_feat_dir, f"{vid}.npz") + _feat = np.load(_feat_path)["features"].astype(np.float32) + if self.normalize_v: + _feat = l2_normalize_np_array(_feat) + v_feat_list.append(_feat) + # some features are slightly longer than the others + min_len = min([len(e) for e in v_feat_list]) + v_feat_list = [e[:min_len] for e in v_feat_list] + v_feat = np.concatenate(v_feat_list, axis=1) + return torch.from_numpy(v_feat) # (Lv, D) + +class DatasetMR(Dataset): + Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"] + """One line in data loaded from data_path." 
+ { + "qid": 7803, + "query": "Man in gray top walks from outside to inside.", + "duration": 150, + "vid": "RoripwjYFp8_360.0_510.0", + "relevant_clip_ids": [13, 14, 15, 16, 17], + "relevant_windows": [[26, 36]] + } + """ + def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim, + q_feat_type="last_hidden_state", + max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video", + normalize_v=True, normalize_t=True, load_labels=True, + clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0, + use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1): + self.dset_name = dset_name + self.data_path = data_path[0] if isinstance(data_path, list) else data_path + self.data_ratio = data_ratio + self.v_feat_dirs = v_feat_dirs \ + if isinstance(v_feat_dirs, list) else [v_feat_dirs] + self.q_feat_dir = q_feat_dir + self.q_feat_type = q_feat_type + self.v_feat_dim = v_feat_dim + self.q_feat_dim = q_feat_dim + self.max_q_l = max_q_l + self.max_v_l = max_v_l + self.ctx_mode = ctx_mode + self.use_tef = "tef" in ctx_mode + self.use_video = "video" in ctx_mode + self.normalize_t = normalize_t + self.normalize_v = normalize_v + self.load_labels = load_labels + self.clip_len = clip_len + self.fix_len = fix_len + self.max_windows = max_windows # maximum number of windows to use as labels + self.span_loss_type = span_loss_type + self.txt_drop_ratio = txt_drop_ratio + self.use_cache = use_cache + self.add_easy_negative = add_easy_negative + self.easy_negative_only = easy_negative_only + + if "val" in data_path or "test" in data_path: + assert txt_drop_ratio == 0 + + # checks + assert q_feat_type in self.Q_FEAT_TYPES + + # data + self.data = self.load_data() + + self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs] + t_feat_type = q_feat_dir.split('/')[-1] + + if self.use_cache > 0: + print('Loading the off-line features...') + dset_dir = os.path.join('data', self.dset_name) + vid_keys = [meta['vid'] for meta in self.data] + qid_keys = [meta['qid'] for meta in self.data] + + self.vid_cache = {} + for v_feat_type in self.v_feat_types: + assert 'vid' in v_feat_type + with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f: + self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)} + + assert 'txt' in t_feat_type + self.txt_cache = {} + with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f: + for key in tqdm(qid_keys): + try: + self.txt_cache[key] = f[str(key)][:] + except: + logger.info(f"text {key} is not in the cache.") + + def load_data(self): + datalist = load_jsonl(self.data_path) + if self.data_ratio != 1: + n_examples = int(len(datalist) * self.data_ratio) + datalist = datalist[:n_examples] + logger.info("Using {}% of the data: {} examples" + .format(self.data_ratio * 100, n_examples)) + return datalist + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + meta = self.data[index] + + model_inputs = dict() + model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"]) # (Dq, ) or (Lq, Dq) + + if self.use_video: + model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"]) # (Lv, Dv) + ctx_l = len(model_inputs["video_feat"]) + else: + ctx_l = self.max_v_l + + if self.dset_name in ['hacs', 'ego4d', 'videocc', 'activitynet']: + for i, window_i in enumerate(meta["relevant_windows"]): + if window_i[1] - window_i[0] < self.clip_len: + center = (window_i[1] + window_i[0]) / 2 + window_i[0] = max(0, center - 0.5 * 
self.clip_len) + window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len) + window_i[1] = max(self.clip_len, window_i[1]) + + model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2) + + if 'test' in self.data_path and 'qvhighlights' in self.dset_name: + meta["relevant_windows"] = [[0, 150]] + relevant_windows = torch.Tensor(meta["relevant_windows"]) + + # assign the nearest window for each timestamp i.e., qvhighlights. + num_vid_seq = model_inputs["timestamp"].shape[0] + num_windows = relevant_windows.shape[0] + + relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len) + relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1) + model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1) + + if meta['qid'] is not None: + nn_window_ts = torch.zeros_like(model_inputs["timestamp"]) + diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0] + diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1] + assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0)) + if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet. + nn_window_ts = relevant_windows_ts.squeeze(1) + else: + nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]] + + model_inputs["span_labels_nn"] = nn_window_ts + model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1]) + + # for activitynet. + if model_inputs["timestamp_window"].sum() < 1: + idx = int(meta['relevant_windows'][0][0] / self.clip_len) + idx = max(0, min(idx, ctx_l-1)) + model_inputs["timestamp_window"][idx] = 1 + + if self.use_tef: + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2) + if self.use_video: + model_inputs["video_feat"] = torch.cat( + [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2) + else: + model_inputs["video_feat"] = tef + + if self.load_labels: + model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2) + if 'saliency_scores' in meta.keys(): + model_inputs["saliency_scores"] = torch.zeros(ctx_l).double() + limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None + model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1)) + model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \ + self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l) + else: + model_inputs["saliency_scores"] = model_inputs["timestamp_window"] + model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \ + self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt + model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ] + + return dict(meta=meta, model_inputs=model_inputs) + + def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1): + gt_st = int(gt_window[0] / self.clip_len) + gt_st = min(gt_st, ctx_l-1) + gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1) + if gt_st > gt_ed: + gt_ed = gt_st + + if gt_st != gt_ed: + pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n) + else: + pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st] + + neg_pool = 
list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l)) + + try: + neg_clip_indices = random.sample(neg_pool, k=max_n) + except: + neg_clip_indices = pos_clip_indices + + return pos_clip_indices, neg_clip_indices + + def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1): + """Sum the scores from the three annotations, then take the two clips with the + maximum scores as positive, and two with the minimum scores as negative. + Args: + rel_clip_ids: list(int), list of relevant clip ids + scores: list([anno1_score, anno2_score, anno3_score]), + ctx_l: int + max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively. + add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids. + """ + # indices inside rel_clip_ids + scores = np.array(scores) # (#rel_clips, 3) + agg_scores = np.sum(scores, 1) # (#rel_clips, ) + sort_indices = np.argsort(agg_scores) # increasing + + # indices in the whole video + # the min(_, ctx_l-1) here is incorrect, but should not cause + # much troubles since this should be rarely used. + hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]] + hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]] + + if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]: + hard_neg_clip_indices = hard_pos_clip_indices + + easy_pos_clip_indices = [] + easy_neg_clip_indices = [] + + if self.add_easy_negative > 0: + easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids)) + if len(easy_neg_pool) >= max_n: + easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n) + easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n) + else: # copy the hard ones + easy_pos_clip_indices = hard_pos_clip_indices + easy_neg_clip_indices = hard_neg_clip_indices + + if self.easy_negative_only > 0: + return easy_pos_clip_indices, easy_neg_clip_indices + + pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices + neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices + return pos_clip_indices, neg_clip_indices + + def get_span_labels(self, windows, ctx_l): + """ + windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive) + Note a maximum of `self.max_windows` windows are used. 
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length + """ + if len(windows) > self.max_windows: + random.shuffle(windows) + windows = windows[:self.max_windows] + if self.span_loss_type == "l1": + windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx + windows = span_xx_to_cxw(windows) # normalized windows in cxw + elif self.span_loss_type == "ce": + windows = torch.Tensor([ + [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1] + for w in windows]).long() # inclusive + else: + raise NotImplementedError + return windows + + def _get_query_feat_by_qid(self, qid): + if self.use_cache > 0: + try: + q_feat = self.txt_cache[qid] + except: + q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32) + return torch.from_numpy(q_feat) + + q_feat_path = join(self.q_feat_dir, f"{qid}.npz") + try: + q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32) + except: + q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32) + logger.info(f"Something wrong when loading the query feature {q_feat_path}.") + + if self.q_feat_type == "last_hidden_state": + # q_feat = q_feat[:self.max_q_l] + q_feat = q_feat + if self.normalize_t: + q_feat = l2_normalize_np_array(q_feat) + if self.txt_drop_ratio > 0: + q_feat = self.random_drop_rows(q_feat) + return torch.from_numpy(q_feat) # (D, ) or (Lq, D) + + def random_drop_rows(self, embeddings): + """randomly mask num_drop rows in embeddings to be zero. + Args: + embeddings: np.ndarray (L, D) + """ + num_drop_rows = round(len(embeddings) * self.txt_drop_ratio) + if num_drop_rows > 0: + row_indices = np.random.choice( + len(embeddings), size=num_drop_rows, replace=False) + embeddings[row_indices] = 0 + return embeddings + + def _get_video_feat_by_vid(self, vid): + v_feat_list = [] + for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs): + if self.use_cache > 0: + _feat = self.vid_cache[feat_type][vid] + else: + _feat_path = join(_feat_dir, f"{vid}.npz") + _feat = np.load(_feat_path)["features"].astype(np.float32) + # _feat = np.load(_feat_path)["features"][:self.max_v_l].astype(np.float32) + if self.normalize_v: + _feat = l2_normalize_np_array(_feat) + v_feat_list.append(_feat) + # some features are slightly longer than the others + min_len = min([len(e) for e in v_feat_list]) + v_feat_list = [e[:min_len] for e in v_feat_list] + v_feat = np.concatenate(v_feat_list, axis=1) + return torch.from_numpy(v_feat) # (Lv, D) + +class DatasetHL(Dataset): + def __init__(self, + dset_name, + domain, + data_path, + v_feat_types, + v_feat_dirs, + t_feat_dir, + use_tef=False + ): + assert dset_name in ['tvsum', 'youtube'] + self.dset_name = dset_name + dset_domain = {'tvsum': TVSUM_SPLITS, + 'youtube': YOUTUBE_SPLITS} + self.splits = dset_domain[dset_name] + assert domain in self.splits.keys() + + self.domain = domain + assert len(data_path) == 1 + self.data_path = data_path[0] if isinstance(data_path, list) else data_path + self.v_feat_types = v_feat_types.split('_') + self.v_feat_dirs = v_feat_dirs + self.q_feat_type = "last_hidden_state" + self.q_feat_dir = t_feat_dir + + self.txt_drop_ratio = 0 + self.normalize_t = True + self.normalize_v = True + + self.label = nncore.load(self.data_path) + self.use_tef = use_tef + + self.video_id = { + k: [s for s in self.splits[domain][k] if s in self.label] + for k in ('train', 'val') + } + self.set_state('train') + + def __len__(self): + return len(self.video_id[self.state]) + + def __getitem__(self, idx): + vid = 
self.get_video_id(idx) + video = self._get_video_feat_by_vid(vid) + saliency = self.get_saliency(idx) + + if self.dset_name == 'youtube': + saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())]) + elif self.dset_name == 'tvsum': + saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())]) + # saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency != min(saliency))[0].tolist())]) + else: + raise NotImplementedError + + num_clips = min(c.size(0) for c in (video, saliency)) + + video = video[:num_clips] + saliency = saliency[:num_clips] + + if self.use_tef: + ctx_l = video.shape[0] + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2) + video = torch.cat([video, tef], dim=1) # (Lv, Dv+2) + + data = dict( + video=DataContainer(video), + saliency=DataContainer(saliency, pad_value=-1), + saliency_pos_labels=saliency_pos_labels) + + if self.q_feat_dir is not None: + query = self._get_query_feat_by_qid(vid) + data['query'] = DataContainer(query, pad_value=float('inf')) + return data + + def set_state(self, state): + self.state = 'train' if state == 'train' else 'val' + + def get_video_id(self, idx): + return self.video_id[self.state][idx] + + def get_video(self, idx): + video_id = self.get_video_id(idx) + video = torch.from_numpy(self.video[video_id]).float() + optic = torch.from_numpy(self.optic[video_id]).float() + return torch.cat((video, optic), dim=1) + + def _get_video_feat_by_vid(self, vid): + v_feat_list = [] + for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs): + # if self.use_cache > 0: + # _feat = self.vid_cache[feat_type][vid] + # else: + if True: + _feat_path = join(_feat_dir, f"{vid}.npz") + _feat = np.load(_feat_path)["features"].astype(np.float32) + if self.normalize_v: + _feat = l2_normalize_np_array(_feat) + v_feat_list.append(_feat) + # some features are slightly longer than the others + min_len = min([len(e) for e in v_feat_list]) + v_feat_list = [e[:min_len] for e in v_feat_list] + v_feat = np.concatenate(v_feat_list, axis=1) + return torch.from_numpy(v_feat) # (Lv, D) + + def _get_query_feat_by_qid(self, qid): + # if self.use_cache > 0: + # try: + # q_feat = self.txt_cache[qid] + # except: + # q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32) + # return torch.from_numpy(q_feat) + + q_feat_path = join(self.q_feat_dir, f"{qid}.npz") + try: + q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32) + except: + q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32) + logger.info(f"Something wrong when loading the query feature {q_feat_path}.") + + if self.q_feat_type == "last_hidden_state": + # q_feat = q_feat[:self.max_q_l] + q_feat = q_feat + if self.normalize_t: + q_feat = l2_normalize_np_array(q_feat) + if self.txt_drop_ratio > 0: + q_feat = self.random_drop_rows(q_feat) + return torch.from_numpy(q_feat) # (D, ) or (Lq, D) + + def get_saliency(self, idx): + if self.dset_name == 'tvsum': + video_id = self.get_video_id(idx) + saliency = torch.Tensor(self.label[video_id]['anno']) + + # top-5 saliency scores as a threshold. 
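+            # The commented-out variants below are alternative normalizations; the active line
+            # further down mean-centers the per-annotator score matrix and averages it over
+            # annotators, giving a single (signed) saliency value per clip.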
+ # saliency_tmp = saliency.mean(1) + # topk = int(saliency_tmp.shape[0] * 0.1) + # th = saliency_tmp[torch.sort(saliency_tmp)[1][-topk]] # v4 + # saliency = saliency_tmp - th + + # saliency_tmp = saliency.mean(1) # med + # th = saliency_tmp.median() + # saliency = saliency_tmp - th + + saliency = (saliency - saliency.mean()).mean(dim=1) + # saliency = (saliency.sum(dim=1) - 20) / 80 # v2 + + elif self.dset_name == 'youtube': + video_id = self.get_video_id(idx) + saliency = [1 if s > 0 else 0 for s in self.label[video_id]['match']] + else: + raise NotImplementedError + return torch.Tensor(saliency) + + def evaluate(self, blob, k=5, save_dir=None, **kwargs): + # blob = nncore.to_dict_of_list(blob) + collected = [] + + if save_dir is not None: + import json + with open(os.path.join(save_dir, self.dset_name, self.domain +'.jsonl'), 'w') as f: + for idx, score in enumerate(blob): + video_id = self.get_video_id(idx) + entry = {'vid':video_id, 'pred': score[0].tolist(), 'gt': self.get_saliency(idx).tolist(), + 'duration': int(self.label[video_id]['frames']) / int(self.label[video_id]['fps']), + 'domain': self.label[video_id]['domain'], 'fps': self.label[video_id]['fps']} + if self.dset_name == 'tvsum': + entry.update({'title':self.label[video_id]['title']}) + if self.dset_name == 'youtube': + entry.update({'clip':self.label[video_id]['clip']}) + f.write(json.dumps(entry) + '\n') + + if self.dset_name == 'tvsum': + for i in range(20): + video_ap = [] + for idx, score in enumerate(blob): + inds = torch.argsort(score[0], descending=True) + video_id = self.get_video_id(idx) + label = torch.Tensor(self.label[video_id]['anno'])[:, i] + label = torch.where(label > label.median(), 1.0, .0) + label = label[inds].tolist()[:k] + + if (num_gt := sum(label)) == 0: + video_ap.append(0) + continue + + hits = ap = rec = 0 + prc = 1 + + for j, gt in enumerate(label): + hits += gt + _rec = hits / num_gt + _prc = hits / (j + 1) + ap += (_rec - rec) * (prc + _prc) / 2 + rec, prc = _rec, _prc + video_ap.append(ap) + collected.append(sum(video_ap) / len(video_ap)) + + elif self.dset_name == 'youtube': + for idx, score in enumerate(blob): + inds = torch.argsort(score[0], descending=True) + label = self.get_saliency(idx)[inds].tolist() + + if (num_gt := sum(label)) == 0: + collected.append(0) + continue + + hits = ap = rec = 0 + prc = 1 + + for i, gt in enumerate(label): + hits += gt + _rec = hits / num_gt + _prc = hits / (i + 1) + ap += (_rec - rec) * (prc + _prc) / 2 + rec, prc = _rec, _prc + collected.append(ap) + else: + raise NotImplementedError + + mean_ap = sum(collected) / len(collected) + results = dict(mAP=round(mean_ap, 5)) + return results + +class DatasetQFVS(Dataset): + def __init__(self,config, use_tef=True): + # pdb.set_trace() + self.config=config + self.dataset=[] + self.use_tef=use_tef + + self.embedding=load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl") + + for video_id in self.config["train_videos"]: + for _ , _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)): + for file in files: + self.dataset.append(file[:file.find("_oracle.txt")]+"_"+str(video_id)) + + def __getitem__(self,index): + video_id=self.dataset[index].split('_')[2] + feat_type = self.config['vid_feature'] + # pdb.set_trace() + feat_type = self.config['vid_feature'] + f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r') + features=f['feature'][()] + # dim=features.shape[-1] + # features=features.reshape(-1, dim) + # 
seg_len=f['seg_len'][()] + dim = features.shape[-1] + ctx_l = features.shape[0] + seg_len = np.ones(ctx_l) + + # mask = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool) + # for j in range(len(seg_len)): + # for k in range(seg_len[j]): + # mask[j][k] = 1 + + # ctx_l = seg_len.sum() + features = torch.from_numpy(features) + # features = features[mask, :] + + if self.use_tef: + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2) + features = torch.cat([features, tef], dim=1) # (Lv, Dv+2) + + transfer={"Cupglass":"Glass", + "Musicalinstrument":"Instrument", + "Petsanimal":"Animal"} + + concept1,concept2=self.dataset[index].split('_')[0:2] + + concept1_GT=torch.zeros(ctx_l) + concept2_GT=torch.zeros(ctx_l) + with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+video_id+"/P0"+video_id+".txt","r") as f: + lines=f.readlines() + for index,line in enumerate(lines): + concepts=line.strip().split(',') + if concept1 in concepts: + concept1_GT[index]=1 + if concept2 in concepts: + concept2_GT[index]=1 + + # shot_num=seg_len.sum() + # mask_GT=torch.zeros(ctx_l) + # for i in range(shot_num): + # mask_GT[i]=1 + mask_GT=torch.ones(ctx_l) + + oracle_summary = torch.zeros(ctx_l) + GT_summary_shots = [] + with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+str(concept1)+"_"+str(concept2)+"_"+"oracle.txt","r") as f: + for line in f.readlines(): + GT_summary_shots.append(int(line.strip())) + GT_summary_shots = [x - 1 for x in GT_summary_shots] + for element in GT_summary_shots: + oracle_summary[element] = 1 + + if concept1 in transfer: + concept1=transfer[concept1] + if concept2 in transfer: + concept2=transfer[concept2] + concept1=self.embedding[concept1] + concept2=self.embedding[concept2] + + try: + saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())]) + except: + saliency_pos_labels_1 = torch.Tensor(0) + + try: + saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())]) + except: + saliency_pos_labels_2 = torch.Tensor(0) + + try: + saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())]) + except: + saliency_pos_labels_oracle = torch.Tensor(0) + + return { + 'features':features, + 'seg_len':torch.from_numpy(seg_len), + 'concept1_GT':concept1_GT, + 'concept2_GT':concept2_GT, + 'mask_GT':mask_GT, + 'oracle_summary':oracle_summary, + 'tokens_pad1':torch.from_numpy(concept1), + 'tokens_pad2':torch.from_numpy(concept2), + 'saliency_pos_labels_1': saliency_pos_labels_1, + 'saliency_pos_labels_2': saliency_pos_labels_2, + 'saliency_pos_labels_oracle': saliency_pos_labels_oracle, + } + + def __len__(self): + return len(self.dataset) + +def start_end_collate_mr(batch): + batch_meta = [e["meta"] for e in batch] # seems no need to collate ? 
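+    # Each model_inputs field is batched below: "span_labels" stays a list of dicts (one per
+    # example), the saliency pos/neg label lists become LongTensors, and every other field is
+    # padded with pad_sequences_1d, which yields a (padded_tensor, mask) pair per key.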
+ + model_inputs_keys = batch[0]["model_inputs"].keys() + batched_data = dict() + for k in model_inputs_keys: + if k == "span_labels": + batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch] + continue + if k in ["saliency_pos_labels", "saliency_neg_labels"]: + batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch]) + continue + + batched_data[k] = pad_sequences_1d( + [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None) + return batch_meta, batched_data + +def start_end_collate_hl(batch): + model_inputs_keys = batch[0].keys() + + batched_data = dict() + for k in model_inputs_keys: + batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None) + return batched_data + +def start_end_collate_qfvs(batch): + model_inputs_keys = batch[0].keys() + + batched_data = dict() + for k in model_inputs_keys: + batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None) + + return batched_data + +def prepare_batch_inputs_mr(batched_model_inputs, device, non_blocking=False): + model_inputs = dict( + src_txt=batched_model_inputs["query_feat"][0].to(device, non_blocking=non_blocking), + src_txt_mask=batched_model_inputs["query_feat"][1].to(device, non_blocking=non_blocking), + src_vid=batched_model_inputs["video_feat"][0].to(device, non_blocking=non_blocking), + src_vid_mask=batched_model_inputs["video_feat"][1].to(device, non_blocking=non_blocking), + ) + targets = {} + targets['timestamp'] = batched_model_inputs["timestamp"][0].to(device, non_blocking=non_blocking) + targets['timestamp_mask'] = batched_model_inputs["timestamp"][1].to(device, non_blocking=non_blocking) + targets['timestamp_window'] = batched_model_inputs["timestamp_window"][0].to(device, non_blocking=non_blocking) + targets['span_labels_nn'] = batched_model_inputs["span_labels_nn"][0].to(device, non_blocking=non_blocking) + + if 'saliency_scores' in batched_model_inputs.keys(): + targets['saliency_scores'] = batched_model_inputs["saliency_scores"][0].to(device, non_blocking=non_blocking) + + if "span_labels" in batched_model_inputs: + targets["span_labels"] = [ + dict(spans=e["spans"].to(device, non_blocking=non_blocking)) + for e in batched_model_inputs["span_labels"] + ] + if "saliency_pos_labels" in batched_model_inputs: + for name in ["saliency_pos_labels", "saliency_neg_labels"]: + targets[name] = batched_model_inputs[name].to(device, non_blocking=non_blocking) + + if "weight_ablation" in batched_model_inputs: + targets["weight_ablation"] = batched_model_inputs["weight_ablation"][0].to(device, non_blocking=non_blocking) + + targets = None if len(targets) == 0 else targets + return model_inputs, targets + +def prepare_batch_inputs_hl(batched_model_inputs, device='cuda', non_blocking=False): + src_vid = batched_model_inputs['video'][0].to(device, non_blocking=non_blocking) + src_vid_mask = batched_model_inputs['video'][1].bool().to(device, non_blocking=non_blocking) + src_txt = batched_model_inputs['query'][0].to(device, non_blocking=non_blocking) \ + if 'query' in batched_model_inputs.keys() else None + src_txt_mask = batched_model_inputs['query'][1].bool().to(device, non_blocking=non_blocking) \ + if 'query' in batched_model_inputs.keys() else None + + model_inputs = dict( + src_vid=src_vid, src_vid_mask=src_vid_mask, + src_txt=src_txt, src_txt_mask=src_txt_mask) + + # if 'audio' in batched_model_inputs.keys(): + # src_aud = batched_model_inputs['audio'][0].bool().to(device, 
non_blocking=non_blocking) + # src_aud_mask = batched_model_inputs['audio'][1].bool().to(device, non_blocking=non_blocking) + # model_inputs['src_aud']=src_aud; model_inputs['src_aud_mask']=src_aud_mask; + + targets = {} + saliency = batched_model_inputs['saliency'][0].to(device, non_blocking=non_blocking) + saliency_pos_labels = batched_model_inputs['saliency_pos_labels'][0].to(device, non_blocking=non_blocking) + + targets['saliency_scores'] = saliency + targets['saliency_pos_labels'] = saliency_pos_labels.long() + targets['timestamp_mask'] = batched_model_inputs["video"][1].to(device, non_blocking=non_blocking) + targets['timestamp_window'] = 1 * (saliency > 0) + + return model_inputs, targets + +def prepare_batch_inputs_qfvs(data, config, eval=False): + if not eval: + features, mask, seg_len, \ + concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \ + src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2,\ + saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \ + data['features'][0], data['features'][1], data['seg_len'][0],\ + data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0],\ + data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \ + data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0], + else: + features, mask, seg_len, \ + src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \ + data['features'][0], data['features'][1], data['seg_len'][0],\ + data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1] + + # preprocess for vid input. + seq = features.to('cuda') + mask = mask.to('cuda') + + # for txt input. + src_txt_1 = src_txt_1.to(torch.float32).to('cuda') + src_txt_2 = src_txt_2.to(torch.float32).to('cuda') + src_txt_mask_1 = src_txt_mask_1.to('cuda') + src_txt_mask_2 = src_txt_mask_2.to('cuda') + + src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda') + src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda') + + model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1) + model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2) + model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle) + + if not eval: + targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda')) + targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda')) + targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda')) + + targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda') + targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda') + targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda') + + return model_inputs_1, model_inputs_2, model_inputs_oracle, \ + targets_1, targets_2, targets_oracle, mask_GT + else: + return model_inputs_1, model_inputs_2, model_inputs_oracle, mask diff --git a/main/dataset_qfvs.py b/main/dataset_qfvs.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec00e2df1ff33503d6d145e27bad72293cd73d0 --- /dev/null +++ b/main/dataset_qfvs.py @@ -0,0 +1,284 @@ +import os +import pdb +import h5py +import 
nncore +import torch +from torch.utils.data import Dataset +import numpy as np +from tqdm import tqdm +import random +import logging +from os.path import join, exists +from nncore.dataset import DATASETS +from nncore.parallel import DataContainer +from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS +from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array +from utils.tensor_utils import pad_sequences_1d +from utils.span_utils import span_xx_to_cxw + +logger = logging.getLogger(__name__) + +class DatasetQFVS(Dataset): + def __init__(self,config, use_tef=True): + # pdb.set_trace() + self.config=config + self.dataset=[] + self.use_tef=use_tef + + self.embedding=load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl") + + self.transfer={"Cupglass":"Glass", + "Musicalinstrument":"Instrument", + "Petsanimal":"Animal"} + + self.f_dict = {} + feat_type = self.config['vid_feature'] + + for video_id in self.config["train_videos"]: + self.f_dict[str(video_id)] = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r') + for _ , _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)): + for file in files: + self.dataset.append(['Oracle', file[:file.find("_oracle.txt")]+"_"+str(video_id)]) + + if self.config['qfvs_dense_shot'] > 0: + dense_concept = {} + feat_type = self.config['vid_feature'] + feat=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r') + features=feat['features'][()] + seg_len=feat['seg_len'][()] + with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+str(video_id)+"/P0"+str(video_id)+".txt","r") as f: + lines=f.readlines() + for index,line in enumerate(lines): + concepts=line.strip().split(',') + for concept in concepts: + if concept in self.transfer: + concept= self.transfer[concept] + if concept not in dense_concept: + # dense_concept[concept] = torch.zeros(seg_len.sum()) + dense_concept[concept] = torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"]) + else: + dense_concept[concept][index] = 1 + + for key, value in dense_concept.items(): + if value.sum().item() > 0: + self.dataset.append([video_id, key, value]) + + def __getitem__(self, index): + if self.dataset[index][0] == 'Oracle': + return self.get_oracle(index) + else: + return self.get_dense(index) + + def get_dense(self,index): + video_id=str(self.dataset[index][0]) + f = self.f_dict[video_id] + # feat_type = self.config['vid_feature'] + # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r') + features=f['features'][()] + seg_len=f['seg_len'][()] + + dim = features.shape[-1] + + mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool) + for j in range(len(seg_len)): + for k in range(seg_len[j]): + mask_GT[j][k] = 1 + + features = torch.from_numpy(features) + + concept1 = concept2 = self.dataset[index][1] + concept1_GT = concept2_GT = oracle_summary = self.dataset[index][2] + + if concept1 in self.transfer: + concept1=self.transfer[concept1] + if concept2 in self.transfer: + concept2=self.transfer[concept2] + concept1=self.embedding[concept1] + concept2=self.embedding[concept2] + + concept1 = l2_normalize_np_array(concept1) + concept2 = l2_normalize_np_array(concept2) + + try: + saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())]) + except: + saliency_pos_labels_1 = torch.Tensor(0) + + try: + saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 
0)[0].tolist())]) + except: + saliency_pos_labels_2 = torch.Tensor(0) + + try: + saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())]) + except: + saliency_pos_labels_oracle = torch.Tensor(0) + + return { + 'features':features, + 'seg_len':torch.from_numpy(seg_len), + 'concept1_GT':concept1_GT, + 'concept2_GT':concept2_GT, + 'mask_GT':mask_GT, + 'oracle_summary':oracle_summary, + 'tokens_pad1':torch.from_numpy(concept1), + 'tokens_pad2':torch.from_numpy(concept2), + 'saliency_pos_labels_1': saliency_pos_labels_1, + 'saliency_pos_labels_2': saliency_pos_labels_2, + 'saliency_pos_labels_oracle': saliency_pos_labels_oracle, + } + + def get_oracle(self,index): + video_id=self.dataset[index][1].split('_')[2] + f = self.f_dict[video_id] + # video_id=self.dataset[index][1].split('_')[2] + # feat_type = self.config['vid_feature'] + # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r') + features=f['features'][()] + seg_len=f['seg_len'][()] + + dim = features.shape[-1] + + mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool) + for j in range(len(seg_len)): + for k in range(seg_len[j]): + mask_GT[j][k] = 1 + + features = torch.from_numpy(features) + + concept1,concept2=self.dataset[index][1].split('_')[0:2] + + concept1_GT=torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"]) + concept2_GT=torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"]) + # concept1_GT=torch.zeros(seg_len.sum()) + # concept2_GT= torch.zeros(seg_len.sum()) + with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+video_id+"/P0"+video_id+".txt","r") as f: + lines=f.readlines() + for index,line in enumerate(lines): + concepts=line.strip().split(',') + if concept1 in concepts: + concept1_GT[index]=1 + if concept2 in concepts: + concept2_GT[index]=1 + + # oracle_summary =torch.zeros(seg_len.sum()) + oracle_summary = torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"]) + GT_summary_shots = [] + with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+str(concept1)+"_"+str(concept2)+"_"+"oracle.txt","r") as f: + for line in f.readlines(): + GT_summary_shots.append(int(line.strip())) + GT_summary_shots = [x - 1 for x in GT_summary_shots] + for element in GT_summary_shots: + oracle_summary[element] = 1 + + if concept1 in self.transfer: + concept1=self.transfer[concept1] + if concept2 in self.transfer: + concept2=self.transfer[concept2] + concept1=self.embedding[concept1] + concept2=self.embedding[concept2] + + concept1 = l2_normalize_np_array(concept1) + concept2 = l2_normalize_np_array(concept2) + + try: + saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())]) + except: + saliency_pos_labels_1 = torch.Tensor(0) + + try: + saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())]) + except: + saliency_pos_labels_2 = torch.Tensor(0) + + try: + saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())]) + except: + saliency_pos_labels_oracle = torch.Tensor(0) + + return { + 'features':features, + 'seg_len':torch.from_numpy(seg_len), + 'concept1_GT':concept1_GT, + 'concept2_GT':concept2_GT, + 'mask_GT':mask_GT, + 'oracle_summary':oracle_summary, + 'tokens_pad1':torch.from_numpy(concept1), + 'tokens_pad2':torch.from_numpy(concept2), + 'saliency_pos_labels_1': saliency_pos_labels_1, + 
'saliency_pos_labels_2': saliency_pos_labels_2, + 'saliency_pos_labels_oracle': saliency_pos_labels_oracle, + } + + def __len__(self): + return len(self.dataset) + +def start_end_collate_qfvs(batch): + model_inputs_keys = batch[0].keys() + + batched_data = dict() + for k in model_inputs_keys: + batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None) + + return batched_data + +def prepare_batch_inputs_qfvs(data, config, eval=False): + if not eval: + features, mask, seg_len, \ + concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \ + src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2,\ + saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \ + data['features'][0], data['mask_GT'][0], data['seg_len'][0],\ + data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0],\ + data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \ + data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0], + else: + features, mask, seg_len, \ + src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \ + data['features'][0], data['mask_GT'][0], data['seg_len'][0],\ + data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1] + + # preprocess for vid input. + mask_GT = mask.to('cuda').reshape(1, -1).bool() + seq = features.to('cuda').squeeze(0) + mask = mask.to('cuda').squeeze(0) + num_seg = seq.shape[0] + + ctx_l = seq.shape[1] + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1).to('cuda') # (Lv, 2) + + tef = tef.squeeze(0).repeat(seq.shape[0], 1, 1) + seq = torch.cat([seq, tef], dim=-1) + + # for txt input. 
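+    # The single concept embedding is tiled across all num_seg video segments below, so every
+    # segment is scored against the same query text (and likewise for the text masks).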
+ src_txt_1 = src_txt_1.to(torch.float32).to('cuda').repeat(num_seg, 1, 1) + src_txt_2 = src_txt_2.to(torch.float32).to('cuda').repeat(num_seg, 1, 1) + src_txt_mask_1 = src_txt_mask_1.to('cuda').repeat(num_seg, 1) + src_txt_mask_2 = src_txt_mask_2.to('cuda').repeat(num_seg, 1) + + src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda') + src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda') + + model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1) + model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2) + model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle) + + # concept1_GT = concept1_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num']) + # concept2_GT = concept2_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num']) + # oracle_summary_GT = oracle_summary_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num']) + + if not eval: + targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda')) + targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda')) + targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda')) + + targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda') + targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda') + targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda') + + return model_inputs_1, model_inputs_2, model_inputs_oracle, \ + targets_1, targets_2, targets_oracle, mask_GT + else: + return model_inputs_1, model_inputs_2, model_inputs_oracle, mask_GT \ No newline at end of file diff --git a/main/inference_demo.py b/main/inference_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..7659564d0598b249807620cec917374c2fa193f0 --- /dev/null +++ b/main/inference_demo.py @@ -0,0 +1,81 @@ +import pdb +import pprint +from tqdm import tqdm, trange +import numpy as np +import os +from collections import OrderedDict, defaultdict +from utils.basic_utils import AverageMeter + +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader + +from main.config import TestOptions, setup_model +from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr +from eval.eval import eval_submission +from eval.postprocessing import PostProcessorDETR +from utils.basic_utils import save_jsonl, save_json +from utils.temporal_nms import temporal_nms +from utils.span_utils import span_cxw_to_xx +from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array + +import logging +import importlib + +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def load_model(): + logger.info("Setup config, data and model...") + opt = TestOptions().parse() + # pdb.set_trace() + cudnn.benchmark = True + cudnn.deterministic = False + + model, criterion, _, _ = setup_model(opt) + return model + +def load_data(save_dir): + vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32) + txt = 
np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32) + + vid = torch.from_numpy(l2_normalize_np_array(vid)) + txt = torch.from_numpy(l2_normalize_np_array(txt)) + clip_len = 2 + ctx_l = vid.shape[0] + + timestamp = ( (torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2) + + if True: + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2) + vid = torch.cat([vid, tef], dim=1) # (Lv, Dv+2) + + src_vid = vid.unsqueeze(0).cuda() + src_txt = txt.unsqueeze(0).cuda() + src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda() + src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda() + + return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l + +if __name__ == '__main__': + clip_len = 2 + save_dir = '/data/home/qinghonglin/univtg/demo/tmp' + + model = load_model() + src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir) + with torch.no_grad(): + output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask) + + pred_logits = output['pred_logits'][0].cpu() + pred_spans = output['pred_spans'][0].cpu() + pred_saliency = output['saliency_scores'].cpu() + + pdb.set_trace() + top1 = (pred_spans + timestamp)[torch.argmax(pred_logits)] * ctx_l * clip_len + print(top1) + print(pred_saliency.argmax()*clip_len) \ No newline at end of file diff --git a/main/inference_hl.py b/main/inference_hl.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4e1eb206aefe89e741ba86825002e23b368ce0 --- /dev/null +++ b/main/inference_hl.py @@ -0,0 +1,229 @@ +import os +import pdb +import time +import json +import pprint +import random +import importlib +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import sys +sys.path.append('/Users/kevin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl +from utils.model_utils import count_parameters + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def eval_epoch(model, train_val_dataset, opt): #, nms_thresh, device): + model.eval() + + scores = [] + train_val_dataset.set_state('val') + val_loader = DataLoader( + train_val_dataset, + collate_fn=start_end_collate_hl, + batch_size=opt.eval_bsz, + num_workers=opt.num_workers, + shuffle=False, + pin_memory=opt.pin_memory + ) + + with torch.no_grad(): + for data in val_loader: + model_inputs, targets = prepare_batch_inputs_hl(data) + outputs = model(**model_inputs) + # pred_cls = outputs['pred_logits'].squeeze(-1) + # pred_cls = outputs['saliency_scores'] + # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1) + + # pdb.set_trace() + if opt.f_loss_coef == 0: + pred_cls = outputs['saliency_scores'] + elif opt.s_loss_intra_coef == 0: + pred_cls = outputs['pred_logits'].squeeze(-1) + else: + if opt.eval_mode == 'add': + pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1) + else: + pred_cls = 
outputs['pred_logits'].squeeze(-1) + + pred_cls = pred_cls.detach().cpu() + scores.append(pred_cls) + map = round(train_val_dataset.evaluate(scores, save_dir='./plot')['mAP'] * 100, 4) + return map + +def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer): + logger.info(f"[Epoch {epoch_i+1}]") + model.train() + criterion.train() + + train_val_dataset.set_state('train') + train_loader = DataLoader( + train_val_dataset, + collate_fn=start_end_collate_hl, + batch_size=opt.bsz, + num_workers=opt.num_workers, + shuffle=True, + pin_memory=opt.pin_memory + ) + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + num_training_examples = len(train_loader) + timer_dataloading = time.time() + for batch_idx, batch in enumerate(train_loader): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + timer_start = time.time() + model_inputs, targets = prepare_batch_inputs_hl(batch) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + outputs = model(**model_inputs) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + time_meters["model_forward_time"].update(time.time() - timer_start) + + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + loss_dict["loss_overall"] = float(losses) + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + timer_dataloading = time.time() + if opt.debug and batch_idx == 3: + break + + # print/add logs + tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1) + for k, v in loss_meters.items(): + tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1) + + to_write = opt.train_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i+1, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()])) + with open(opt.train_log_filepath, "a") as f: + f.write(to_write) + + logger.info("Epoch time stats:") + for name, meter in time_meters.items(): + d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]} + logger.info(f"{name} ==> {d}") + +# train in single domain. +def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt): + # if opt.device.type == "cuda": + # logger.info("CUDA enabled.") + # model.to(opt.device) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + prev_best_score = 0. 
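+ # When opt.eval_init is set, start_epoch becomes -1: the first pass through the loop
+ # below skips train_epoch() (the `epoch_i > -1` guard) and, since (epoch_i + 1) % eval_epoch == 0
+ # holds for epoch_i == -1, runs one evaluation before any training (provided opt.eval_path is set).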
+ if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + scores = eval_epoch(model, train_val_dataset, opt) + tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1) + if prev_best_score < scores: + prev_best_score = scores + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt")) + tb_writer.close() + return prev_best_score + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + + from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS + if opt.dset_name == "tvsum": + domain_splits = TVSUM_SPLITS.keys() + if opt.dset_name == "youtube": + domain_splits = YOUTUBE_SPLITS.keys() + + scores = {} + if opt.lr_warmup > 0: + # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz + total_steps = opt.n_epoch + warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps) + opt.lr_warmup = [warmup_steps, total_steps] + + domain_splits = domain_splits if not opt.domain_name else [opt.domain_name] + + for domain in domain_splits: + dataset_config = dict( + dset_name=opt.dset_name, + domain=domain, + data_path=opt.train_path, + v_feat_types=opt.v_feat_types, + v_feat_dirs=opt.v_feat_dirs, + t_feat_dir=opt.t_feat_dir, + use_tef=True + ) + dataloader = DatasetHL(**dataset_config) + + model, criterion, optimizer, lr_scheduler = setup_model(opt) + count_parameters(model) + logger.info(f"Start Training {domain}") + best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt) + scores[domain] = best_score + scores['AVG'] = sum(scores.values()) / len(scores) + + # save the final results. 
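+ # (`scores` holds the best mAP per domain plus the cross-domain 'AVG' entry computed above.)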
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json") + save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None)) + tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1) + tb_writer.close() + # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug + + print(opt.dset_name) + print(scores) + return + +if __name__ == '__main__': + start_training() + results = logger.info("\n\n\nFINISHED TRAINING!!!") diff --git a/main/inference_mr.py b/main/inference_mr.py new file mode 100644 index 0000000000000000000000000000000000000000..4aea2de137ac46fa91f737d49998a00165423bce --- /dev/null +++ b/main/inference_mr.py @@ -0,0 +1,273 @@ +import pdb +import pprint +from tqdm import tqdm, trange +import numpy as np +import os +from collections import OrderedDict, defaultdict +from utils.basic_utils import AverageMeter + +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader + +from main.config import TestOptions, setup_model +from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr +from eval.eval import eval_submission +from eval.postprocessing import PostProcessorDETR +from utils.basic_utils import save_jsonl, save_json +from utils.temporal_nms import temporal_nms +from utils.span_utils import span_cxw_to_xx + +import logging +import importlib + +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + + +def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms): + mr_res_after_nms = [] + for e in mr_res: + e["pred_relevant_windows"] = temporal_nms( + e["pred_relevant_windows"][:max_before_nms], + nms_thd=nms_thd, + max_after_nms=max_after_nms + ) + mr_res_after_nms.append(e) + return mr_res_after_nms + + +def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename): + # IOU_THDS = (0.5, 0.7) + logger.info("Saving/Evaluating before nms results") + submission_path = os.path.join(opt.results_dir, save_submission_filename) + save_jsonl(submission, submission_path) + + if opt.eval_split_name in ["val", "test"]: # since test_public has no GT + metrics = eval_submission( + submission, gt_data, + verbose=opt.debug, match_number=not opt.debug, + ) + save_metrics_path = submission_path.replace(".jsonl", "_metrics.json") + save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False) + latest_file_paths = [submission_path, save_metrics_path] + else: + metrics = None + latest_file_paths = [submission_path, ] + + if opt.nms_thd != -1: + logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd)) + submission_after_nms = post_processing_mr_nms( + submission, nms_thd=opt.nms_thd, + max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms + ) + + logger.info("Saving/Evaluating nms results") + submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd)) + save_jsonl(submission_after_nms, submission_nms_path) + if opt.eval_split_name == "val": + metrics_nms = eval_submission( + submission_after_nms, gt_data, + verbose=opt.debug, match_number=not opt.debug + ) + save_metrics_nms_path = 
submission_nms_path.replace(".jsonl", "_metrics.json") + save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False) + latest_file_paths += [submission_nms_path, save_metrics_nms_path] + else: + metrics_nms = None + latest_file_paths = [submission_nms_path, ] + else: + metrics_nms = None + return metrics, metrics_nms, latest_file_paths + + +@torch.no_grad() +def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None): + model.eval() + if criterion: + assert eval_loader.dataset.load_labels + criterion.eval() + + loss_meters = defaultdict(AverageMeter) + write_tb = tb_writer is not None and epoch_i is not None + + mr_res = [] + for batch in tqdm(eval_loader, desc="compute st ed scores"): + query_meta = batch[0] + model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory) + outputs = model(**model_inputs) + prob = outputs["pred_logits"] # the last channel may be 1 or 2. + # if opt.eval_mode == 'v1': + # prob = prob * outputs["saliency_scores"].unsqueeze(-1) # v1 + # if opt.eval_mode == 'v2': + # prob = F.softmax(prob, dim=1) * outputs["saliency_scores"].unsqueeze(-1) # v2 + # if opt.eval_mode == 'v3': + # prob = outputs["saliency_scores"].unsqueeze(-1) + if outputs["pred_logits"].shape[-1] > 1: + prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #queries, #classes=2) + if opt.span_loss_type == "l1": + scores = prob[..., 0] # * (batch_size, #queries) foreground label is 0, we directly take it + pred_spans = outputs["pred_spans"] # (bsz, #queries, 2) + + if opt.model_id not in ['moment_detr']: # dense regression. + start_spans = targets['timestamp'] + pred_spans = start_spans + pred_spans + mask = targets['timestamp_mask'].bool() + scores[~mask] = 0 + # if opt.eval_mode == 'v4': + # _mask = targets['timestamp_window'].bool() + # scores[~_mask] = 0 + + if opt.eval_mode == 'add': + # pdb.set_trace() + _saliency_scores = outputs["saliency_scores"].half() + prob.squeeze(-1) + else: + _saliency_scores = outputs["saliency_scores"].half() # (bsz, L) + + if opt.eval_mode == 'add_mr': + prob = outputs["saliency_scores"].half().unsqueeze(-1) + prob + + saliency_scores = [] + valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist() + for j in range(len(valid_vid_lengths)): + saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist()) + else: + bsz, n_queries = outputs["pred_spans"].shape[:2] # # (bsz, #queries, max_v_l *2) + pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l) + # TODO use more advanced decoding method with st_ed product + pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) # 2 * (bsz, #queries, 2) + scores = torch.prod(pred_span_scores, 2) # (bsz, #queries) + pred_spans[:, 1] += 1 + pred_spans *= opt.clip_length + + # compose predictions + for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())): + if opt.span_loss_type == "l1": + if opt.model_id in ['moment_detr']: + spans = span_cxw_to_xx(spans) * meta["duration"] + else: + spans = spans * meta["duration"] + spans = torch.clamp(spans, 0, meta["duration"]) # added by Kevin, since window cannot be longer than video duration. 
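+ # Decoding sketch for the dense-regression branch (illustrative numbers, not from any dataset):
+ # each clip regresses a (left, right) offset around its own normalized timestamp, so
+ # `start_spans + pred_spans` above yields a normalized window and scaling by `meta["duration"]`
+ # converts it to seconds, e.g. timestamp 0.25 with offsets (-0.05, +0.15) on a 120 s video gives
+ # (0.20 * 120, 0.40 * 120) = (24.0, 48.0) s; the clamp above then keeps the window inside [0, duration].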
+ + # (#queries, 3), [st(float), ed(float), score(float)] + cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist() + if not opt.no_sort_results: + cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True) + cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds] + cur_query_pred = dict( + qid=meta["qid"], + query=meta["query"], + vid=meta["vid"], + pred_relevant_windows=cur_ranked_preds, + pred_saliency_scores=saliency_scores[idx] + ) + mr_res.append(cur_query_pred) + + if criterion: + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + loss_dict["loss_overall"] = float(losses) # for logging only + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + if opt.debug: + break + + if write_tb and criterion: + for k, v in loss_meters.items(): + tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1) + + post_processor = PostProcessorDETR( + clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150, + min_w_l=2, max_w_l=150, move_window_method="left", + # process_func_names=("clip_ts", "round_multiple") + process_func_names=["round_multiple"] # have added `clamp' op on line 147, thus we do not need `clip_ts' again; + ) + # todo: are we need round_multiple? + if opt.round_multiple > 0: + mr_res = post_processor(mr_res) + return mr_res, loss_meters + +def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer): + """compute and save query and video proposal embeddings""" + eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # list(dict) + return eval_res, eval_loss_meters + +def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None): + logger.info("Generate submissions") + model.eval() + if criterion is not None and eval_dataset.load_labels: + criterion.eval() + else: + criterion = None + + eval_loader = DataLoader( + eval_dataset, + collate_fn=start_end_collate_mr, + batch_size=opt.eval_bsz, + num_workers=opt.num_workers, + shuffle=False, + pin_memory=opt.pin_memory + ) + + submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer) + if opt.no_sort_results: + save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl") + metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing( + submission, opt, eval_dataset.data, save_submission_filename) + return metrics, metrics_nms, eval_loss_meters, latest_file_paths + +def start_inference(): + logger.info("Setup config, data and model...") + opt = TestOptions().parse() + # pdb.set_trace() + cudnn.benchmark = True + cudnn.deterministic = False + + assert opt.eval_path is not None + eval_dataset = DatasetMR( + dset_name=opt.dset_name, + data_path=opt.eval_path, + v_feat_dirs=opt.v_feat_dirs, + q_feat_dir=opt.t_feat_dir, + v_feat_dim=opt.v_feat_dim, + q_feat_dim=opt.t_feat_dim, + q_feat_type="last_hidden_state", + max_q_l=opt.max_q_l, + max_v_l=opt.max_v_l, + ctx_mode=opt.ctx_mode, + data_ratio=opt.data_ratio, + normalize_v=not opt.no_norm_vfeat, + normalize_t=not opt.no_norm_tfeat, + clip_len=opt.clip_length, + max_windows=opt.max_windows, + load_labels=True, # opt.eval_split_name == "val", + span_loss_type=opt.span_loss_type, + txt_drop_ratio=0, + use_cache=opt.use_cache, + ) + + if opt.lr_warmup > 0: + # 
total_steps = opt.n_epoch * len(train_dataset) // opt.bsz + total_steps = opt.n_epoch + warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps) + opt.lr_warmup = [warmup_steps, total_steps] + + model, criterion, _, _ = setup_model(opt) + save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format( + opt.dset_name, opt.eval_split_name, opt.eval_id) + logger.info("Starting inference...") + with torch.no_grad(): + metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ + eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion) + logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) + if metrics_nms is not None: + logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) + + +if __name__ == '__main__': + start_inference() diff --git a/main/inference_qfvs.py b/main/inference_qfvs.py new file mode 100644 index 0000000000000000000000000000000000000000..4feed8b8399a2dc1fb081e14e9acc4ece64650ed --- /dev/null +++ b/main/inference_qfvs.py @@ -0,0 +1,342 @@ +import os +import pdb +import time +import json +import pprint +import random +import importlib +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import h5py +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import sys +sys.path.append('/Users/kevin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array +from utils.model_utils import count_parameters +from eval.qfvs import calculate_semantic_matching, load_videos_tag + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def eval_epoch(model, config, opt): + model.eval() + f1_sum = 0; p_sum = 0; r_sum = 0 + + assert len(config['test_videos']) == 1 + video_id = config['test_videos'][0] + embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl") + + feat_type = config['vid_feature'] + feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r') + features = torch.from_numpy(feat['features'][()]) + seg_len = torch.from_numpy(feat['seg_len'][()]) + # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda() + + # dim = features.shape[-1] + # ctx_l = seg_len.sum().cpu() + + # dim = features.shape[-1] + # ctx_l = features.shape[1] + # seg_len = torch.ones(ctx_l) + # features = features.reshape(-1, dim)[:ctx_l] + + # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + # tef_ed = tef_st + 1.0 / ctx_l + # tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2) + # features = torch.cat([features, tef], dim=1) # (Lv, Dv+2) + + transfer = {"Cupglass": "Glass", + "Musicalinstrument": "Instrument", + "Petsanimal": "Animal"} + + with open(os.path.join('./plot', opt.dset_name, str(opt.qfvs_split) +'.jsonl'), 'w') as f_write: + for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)): + evaluation_num=len(files) + + mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda() + for j in 
range(len(seg_len)): + for k in range(seg_len[j]): + mask_GT[j][k] = 1 + + for file in files: + summaries_GT=[] + with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f: + for line in f.readlines(): + summaries_GT.append(int(line.strip())) + + concept1, concept2 = file.split('_')[0:2] + + ############## + if concept1 in transfer: + concept1 = transfer[concept1] + if concept2 in transfer: + concept2 = transfer[concept2] + concept1 = embedding[concept1] + concept2 = embedding[concept2] + + concept1 = l2_normalize_np_array(concept1) + concept2 = l2_normalize_np_array(concept2) + + data = { + 'features':features, + 'seg_len': seg_len, + 'tokens_pad1':torch.from_numpy(concept1), + 'tokens_pad2':torch.from_numpy(concept2), + 'mask_GT': mask_GT + } + + input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True) + + summaries_GT = [x - 1 for x in summaries_GT] + video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat") + + if opt.f_loss_coef == 0: + output_type = 'saliency_scores' + elif opt.s_loss_intra_coef == 0: + output_type = 'pred_logits' + else: + if config['qfvs_score_ensemble'] > 0: + output_type = ['pred_logits', 'saliency_scores'] + else: + output_type = 'pred_logits' + + with torch.no_grad(): + if not isinstance(output_type, list): + score1 = model(**input1)[output_type].squeeze() + score1 = score1.masked_select(mask_GT) + + score2 = model(**input2)[output_type].squeeze() + score2 = score2.masked_select(mask_GT) + + score = model(**input_oracle)[output_type].squeeze() + score = score.masked_select(mask_GT) + else: + score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda() + for output_t in output_type: + score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT) + score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT) + score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT) + + if config['qfvs_score_gather'] > 0: + score = score + score1 + score2 + else: + score = score + + # since video4 features dim is greater than video_shots_tag. 
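+ # (the feature file for video 4 contains more shots than the Tags.mat annotation,
+ # so scores are truncated to the annotated shot count before ranking)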
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])] + _, top_index = score.topk(int(score.shape[0] * config["top_percent"])) + + c1, c2 = file.split('_')[0:2] + if c1 in transfer: + c1 = transfer[c1] + if c2 in transfer: + c2 = transfer[c2] + + p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1) + entry = {'concept1': c1, 'concept2': c2, + 'score':score.tolist(), + 'top_percent': config["top_percent"], + 'top_pred':top_index.tolist(), + 'gt':summaries_GT, + 'p': p, 'r': r, 'f1': f1, + 'shots': video_shots_tag[video_id-1].shape[0]} + f_write.write(json.dumps(entry) + '\n') + f1_sum+=f1; r_sum+=r; p_sum+=p + return {'F': round(100* f1_sum/evaluation_num,2) , + 'R': round(100* r_sum/evaluation_num,2) , + 'P': round(100* p_sum/evaluation_num,2) } + +def idx2time(idx): + sec1, sec2 = idx*5, (idx+1)*5 + + h1 = sec1 // 3600 + m1 = (sec1 - h1*3600) // 60 + s1 = sec1 % 60 + + h2 = sec2 // 3600 + m2 = (sec2 - h2*3600) // 60 + s2 = sec2 % 60 + print(h1,m1,s1,'\t', h2,m2,s2) + +def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer): + model.train() + criterion.train() + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + timer_dataloading = time.time() + loss_total = 0 + + for batch_idx, batch in enumerate(tqdm(train_loader)): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + timer_start = time.time() + model_input1, model_input2, model_input_oracle, \ + model_gt1, model_gt2, model_gt_oracle, \ + mask_GT = prepare_batch_inputs_qfvs(batch, config) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + output1 = model(**model_input1) + output2 = model(**model_input2) + output_oracle = model(**model_input_oracle) + + loss_dict = {} + loss_dict1 = criterion(output1, model_gt1, mask_GT) + loss_dict2 = criterion(output2, model_gt2, mask_GT) + loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT) + + weight_dict = criterion.weight_dict + if config['qfvs_loss_gather'] > 0: + for k in loss_dict1.keys(): + loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k] + else: + loss_dict = loss_dict3 + + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + loss_total += losses.item() + + time_meters["model_forward_time"].update(time.time() - timer_start) + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + timer_dataloading = time.time() + return round(loss_total / len(train_loader), 2) + +# train in single domain. 
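+# Each fold trains on three videos and evaluates on the held-out one (see qfvs_split in
+# start_training() below); the loop in train() keeps the checkpoint with the best F-score
+# on that test video.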
+def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config): + # if opt.device.type == "cuda": + # logger.info("CUDA enabled.") + # model.to(opt.device) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0} + if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + + val_score = eval_epoch(model, config, opt) + tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0) + logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]" + f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]" + f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]") + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + val_score = eval_epoch(model, config, opt) + tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1) + logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]" + f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]" + f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]") + + if prev_best_score['Fscore'] < val_score['F']: + prev_best_score['Fscore'] = val_score['F'] + prev_best_score['Precision'] = val_score['P'] + prev_best_score['Recall'] = val_score['R'] + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt")) + tb_writer.close() + return prev_best_score + +def update_config(opt, config): + # for key in ["max_segment_num", "max_frame_num", "top_percent", + # "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot", + # "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]: + config["max_segment_num"] = opt.max_segment_num + config["max_frame_num"] = opt.max_frame_num + config["top_percent"] = opt.top_percent + config["vid_feature"] = opt.qfvs_vid_feature + config["txt_feature"] = opt.qfvs_txt_feature + config["qfvs_dense_shot"] = opt.qfvs_dense_shot + config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble + config["qfvs_score_gather"] = opt.qfvs_score_gather + config["qfvs_loss_gather"] = opt.qfvs_loss_gather + return config + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + + # config = load_json("./main/config_qfvs.json") + config = {} + config = update_config(opt, config) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + + # key -> test video; value -> training videos. 
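+ # e.g. fold 1 evaluates on video 1 and trains on videos 2-4; opt.qfvs_split != -1 restricts the run to a single fold.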
+ qfvs_split = { + 1: [2, 3, 4], + 2: [1, 3, 4], + 3: [1, 2, 4], + 4: [1, 2, 3] + } + + scores_videos = {} + for test_id, splits in qfvs_split.items(): + if opt.qfvs_split != -1: + if test_id != opt.qfvs_split: + continue + logger.info(f"Start Training {opt.dset_name}: {test_id}") + config['train_videos'] = qfvs_split[test_id] + config['test_videos'] = [test_id] + train_dataset = DatasetQFVS(config) + train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers) + + model, criterion, optimizer, lr_scheduler = setup_model(opt) + count_parameters(model) + best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config) + scores_videos['V'+str(test_id)] = best_score + + # save the final results. + avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos) + avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos) + avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos) + scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall} + + save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json") + save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False) + + tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1) + tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None)) + tb_writer.close() + + print(scores_videos) + return + +if __name__ == '__main__': + start_training() + results = logger.info("\n\n\nFINISHED TRAINING!!!") diff --git a/main/train_hl.py b/main/train_hl.py new file mode 100644 index 0000000000000000000000000000000000000000..ceec407bc1f4ff92077bda46e90cfd7b566ca56b --- /dev/null +++ b/main/train_hl.py @@ -0,0 +1,229 @@ +import os +import pdb +import time +import json +import pprint +import random +import importlib +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import sys +sys.path.append('/data/home/qinghonglin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl +from utils.model_utils import count_parameters + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def eval_epoch(model, train_val_dataset, opt): #, nms_thresh, device): + model.eval() + + scores = [] + train_val_dataset.set_state('val') + val_loader = DataLoader( + train_val_dataset, + collate_fn=start_end_collate_hl, + batch_size=opt.eval_bsz, + num_workers=opt.num_workers, + shuffle=False, + pin_memory=opt.pin_memory + ) + + with torch.no_grad(): + for data in val_loader: + model_inputs, targets = prepare_batch_inputs_hl(data) + outputs = model(**model_inputs) + # pred_cls = outputs['pred_logits'].squeeze(-1) + # pred_cls = outputs['saliency_scores'] + # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1) + + # pdb.set_trace() + if opt.f_loss_coef == 0: + pred_cls = outputs['saliency_scores'] + 
elif opt.s_loss_intra_coef == 0: + pred_cls = outputs['pred_logits'].squeeze(-1) + else: + if opt.eval_mode == 'add': + pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1) + else: + pred_cls = outputs['pred_logits'].squeeze(-1) + + pred_cls = pred_cls.detach().cpu() + scores.append(pred_cls) + map = round(train_val_dataset.evaluate(scores)['mAP'] * 100, 4) + return map + +def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer): + logger.info(f"[Epoch {epoch_i+1}]") + model.train() + criterion.train() + + train_val_dataset.set_state('train') + train_loader = DataLoader( + train_val_dataset, + collate_fn=start_end_collate_hl, + batch_size=opt.bsz, + num_workers=opt.num_workers, + shuffle=True, + pin_memory=opt.pin_memory + ) + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + num_training_examples = len(train_loader) + timer_dataloading = time.time() + for batch_idx, batch in enumerate(train_loader): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + timer_start = time.time() + model_inputs, targets = prepare_batch_inputs_hl(batch) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + outputs = model(**model_inputs) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + time_meters["model_forward_time"].update(time.time() - timer_start) + + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + loss_dict["loss_overall"] = float(losses) + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + timer_dataloading = time.time() + if opt.debug and batch_idx == 3: + break + + # print/add logs + tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1) + for k, v in loss_meters.items(): + tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1) + + to_write = opt.train_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i+1, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()])) + with open(opt.train_log_filepath, "a") as f: + f.write(to_write) + + logger.info("Epoch time stats:") + for name, meter in time_meters.items(): + d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]} + logger.info(f"{name} ==> {d}") + +# train in single domain. +def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt): + # if opt.device.type == "cuda": + # logger.info("CUDA enabled.") + # model.to(opt.device) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + prev_best_score = 0. 
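+ # prev_best_score tracks the best validation mAP for this domain; train() returns it so
+ # start_training() can average scores across domains.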
+ if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + scores = eval_epoch(model, train_val_dataset, opt) + tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1) + if prev_best_score < scores: + prev_best_score = scores + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt")) + tb_writer.close() + return prev_best_score + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + + from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS + if opt.dset_name == "tvsum": + domain_splits = TVSUM_SPLITS.keys() + if opt.dset_name == "youtube": + domain_splits = YOUTUBE_SPLITS.keys() + + scores = {} + if opt.lr_warmup > 0: + # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz + total_steps = opt.n_epoch + warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps) + opt.lr_warmup = [warmup_steps, total_steps] + + domain_splits = domain_splits if not opt.domain_name else [opt.domain_name] + + for domain in domain_splits: + dataset_config = dict( + dset_name=opt.dset_name, + domain=domain, + data_path=opt.train_path, + v_feat_types=opt.v_feat_types, + v_feat_dirs=opt.v_feat_dirs, + t_feat_dir=opt.t_feat_dir, + use_tef=True + ) + dataloader = DatasetHL(**dataset_config) + + model, criterion, optimizer, lr_scheduler = setup_model(opt) + count_parameters(model) + logger.info(f"Start Training {domain}") + best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt) + scores[domain] = best_score + scores['AVG'] = sum(scores.values()) / len(scores) + + # save the final results. 
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json") + save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None)) + tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1) + tb_writer.close() + # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug + + print(opt.dset_name) + print(scores) + return + +if __name__ == '__main__': + start_training() + results = logger.info("\n\n\nFINISHED TRAINING!!!") \ No newline at end of file diff --git a/main/train_mr.py b/main/train_mr.py new file mode 100644 index 0000000000000000000000000000000000000000..1a10d029f81f86733d6dab71a3aee575917b092b --- /dev/null +++ b/main/train_mr.py @@ -0,0 +1,266 @@ +import os +import pdb +import sys +import time +import json +import pprint +import random +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +sys.path.append('/data/home/qinghonglin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset import \ + DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr +from main.inference_mr import eval_epoch, start_inference +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown +from utils.model_utils import count_parameters + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer): + logger.info(f"[Epoch {epoch_i+1}]") + model.train() + criterion.train() + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + num_training_examples = len(train_loader) + timer_dataloading = time.time() + for batch_idx, batch in tqdm(enumerate(train_loader), + desc="Training Iteration", + total=num_training_examples): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + + timer_start = time.time() + model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + + # try: + outputs = model(**model_inputs) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + time_meters["model_forward_time"].update(time.time() - timer_start) + + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + loss_dict["loss_overall"] = float(losses) # for logging only + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + timer_dataloading = time.time() + + # print/add logs + tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1) + for k, v in loss_meters.items(): + 
tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1) + + to_write = opt.train_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i+1, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()])) + with open(opt.train_log_filepath, "a") as f: + f.write(to_write) + + logger.info("Epoch time stats:") + for name, meter in time_meters.items(): + d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]} + logger.info(f"{name} ==> {d}") + + +def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt): + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + train_loader = DataLoader( + train_dataset, + collate_fn=start_end_collate_mr, + batch_size=opt.bsz, + num_workers=opt.num_workers, + shuffle=True, + pin_memory=opt.pin_memory + ) + + prev_best_score = 0. + es_cnt = 0 + if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name) + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ + eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer) + + # log + to_write = opt.eval_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]), + eval_metrics_str=json.dumps(metrics_no_nms)) + + with open(opt.eval_log_filepath, "a") as f: + f.write(to_write) + logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) + if metrics_nms is not None: + logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) + + metrics = metrics_nms if metrics_nms is not None else metrics_no_nms + for k, v in metrics["brief"].items(): + tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1) + + # stop_score = metrics["brief"]["MR-full-mAP"] + # pdb.set_trace() + stop_score = metrics["brief"][opt.main_metric] + if stop_score > prev_best_score: + es_cnt = 0 + prev_best_score = stop_score + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt")) + + best_file_paths = [e.replace("latest", "best") for e in latest_file_paths] + for src, tgt in zip(latest_file_paths, best_file_paths): + os.renames(src, tgt) + logger.info("The checkpoint file has been updated.") + else: + es_cnt += 1 + if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop + with open(opt.train_log_filepath, "a") as f: + f.write(f"Early Stop at epoch {epoch_i}") + logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n") + break + + # save ckpt + checkpoint = { + "model": 
model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt")) + + if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt")) + + if opt.debug: + break + + tb_writer.close() + + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + if opt.debug: # keep the model run deterministically + # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config. + # Enable this only when input size is fixed. + cudnn.benchmark = False + cudnn.deterministic = True + + dataset_config = dict( + dset_name=opt.dset_name, + data_path=opt.train_path, + v_feat_dirs=opt.v_feat_dirs, + q_feat_dir=opt.t_feat_dir, + v_feat_dim=opt.v_feat_dim, + q_feat_dim=opt.t_feat_dim, + q_feat_type="last_hidden_state", + max_q_l=opt.max_q_l, + max_v_l=opt.max_v_l, + ctx_mode=opt.ctx_mode, + data_ratio=opt.data_ratio, + normalize_v=not opt.no_norm_vfeat, + normalize_t=not opt.no_norm_tfeat, + clip_len=opt.clip_length, + max_windows=opt.max_windows, + span_loss_type=opt.span_loss_type, + txt_drop_ratio=opt.txt_drop_ratio, + use_cache=opt.use_cache, + add_easy_negative=opt.add_easy_negative, + easy_negative_only=opt.easy_negative_only + ) + + dataset_config["data_path"] = opt.train_path + train_dataset = DatasetMR(**dataset_config) + + if opt.eval_path is not None: + dataset_config["data_path"] = opt.eval_path + dataset_config["txt_drop_ratio"] = 0 + dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining + # dataset_config["load_labels"] = False # uncomment to calculate eval loss + eval_dataset = DatasetMR(**dataset_config) + else: + eval_dataset = None + + if opt.lr_warmup > 0: + # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz + total_steps = opt.n_epoch + warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps) + opt.lr_warmup = [warmup_steps, total_steps] + model, criterion, optimizer, lr_scheduler = setup_model(opt) + logger.info(f"Model {model}") + count_parameters(model) + logger.info("Start Training...") + train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt) + return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug + + +if __name__ == '__main__': + best_ckpt_path, eval_split_name, eval_path, debug = start_training() + if not debug: + input_args = ["--resume", best_ckpt_path, + "--eval_split_name", eval_split_name, + "--eval_path", eval_path] + + import sys + sys.argv[1:] = input_args + logger.info("\n\n\nFINISHED TRAINING!!!") + logger.info("Evaluating model at {}".format(best_ckpt_path)) + logger.info("Input args {}".format(sys.argv[1:])) + start_inference() \ No newline at end of file diff --git a/main/train_qfvs.py b/main/train_qfvs.py new file mode 100644 index 0000000000000000000000000000000000000000..65a3b155ed6d432da7eb0072f87e1bae18d8a994 --- /dev/null +++ b/main/train_qfvs.py @@ -0,0 +1,325 @@ +import os +import pdb +import time +import json +import pprint +import random +import importlib +import numpy as np +from 
tqdm import tqdm, trange +from collections import defaultdict + +import h5py +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import sys +sys.path.append('/Users/kevin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array +from utils.model_utils import count_parameters +from eval.qfvs import calculate_semantic_matching, load_videos_tag + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def eval_epoch(model, config, opt): + model.eval() + f1_sum = 0; p_sum = 0; r_sum = 0 + + assert len(config['test_videos']) == 1 + video_id = config['test_videos'][0] + embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl") + + feat_type = config['vid_feature'] + feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r') + features = torch.from_numpy(feat['features'][()]) + seg_len = torch.from_numpy(feat['seg_len'][()]) + # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda() + + # dim = features.shape[-1] + # ctx_l = seg_len.sum().cpu() + + # dim = features.shape[-1] + # ctx_l = features.shape[1] + # seg_len = torch.ones(ctx_l) + # features = features.reshape(-1, dim)[:ctx_l] + + # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + # tef_ed = tef_st + 1.0 / ctx_l + # tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2) + # features = torch.cat([features, tef], dim=1) # (Lv, Dv+2) + + transfer = {"Cupglass": "Glass", + "Musicalinstrument": "Instrument", + "Petsanimal": "Animal"} + + for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)): + evaluation_num=len(files) + + mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda() + for j in range(len(seg_len)): + for k in range(seg_len[j]): + mask_GT[j][k] = 1 + + for file in files: + summaries_GT=[] + with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f: + for line in f.readlines(): + summaries_GT.append(int(line.strip())) + + concept1, concept2 = file.split('_')[0:2] + + ############## + if concept1 in transfer: + concept1 = transfer[concept1] + if concept2 in transfer: + concept2 = transfer[concept2] + concept1 = embedding[concept1] + concept2 = embedding[concept2] + + concept1 = l2_normalize_np_array(concept1) + concept2 = l2_normalize_np_array(concept2) + + data = { + 'features':features, + 'seg_len': seg_len, + 'tokens_pad1':torch.from_numpy(concept1), + 'tokens_pad2':torch.from_numpy(concept2), + 'mask_GT': mask_GT + } + + input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True) + + summaries_GT = [x - 1 for x in summaries_GT] + video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat") + + if opt.f_loss_coef == 0: + output_type = 'saliency_scores' + elif opt.s_loss_intra_coef == 0: + output_type = 'pred_logits' + else: + if config['qfvs_score_ensemble'] > 0: + output_type = ['pred_logits', 'saliency_scores'] + else: + output_type = 'pred_logits' + + with 
torch.no_grad(): + if not isinstance(output_type, list): + score1 = model(**input1)[output_type].squeeze() + score1 = score1.masked_select(mask_GT) + + score2 = model(**input2)[output_type].squeeze() + score2 = score2.masked_select(mask_GT) + + score = model(**input_oracle)[output_type].squeeze() + score = score.masked_select(mask_GT) + else: + score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda() + for output_t in output_type: + score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT) + score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT) + score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT) + + if config['qfvs_score_gather'] > 0: + score = score + score1 + score2 + else: + score = score + + # since video4 features dim is greater than video_shots_tag. + score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])] + _, top_index = score.topk(int(score.shape[0] * config["top_percent"])) + + p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1) + f1_sum+=f1; r_sum+=r; p_sum+=p + + return {'F': round(100* f1_sum/evaluation_num,2) , + 'R': round(100* r_sum/evaluation_num,2) , + 'P': round(100* p_sum/evaluation_num,2) } + +def idx2time(idx): + sec1, sec2 = idx*5, (idx+1)*5 + + h1 = sec1 // 3600 + m1 = (sec1 - h1*3600) // 60 + s1 = sec1 % 60 + + h2 = sec2 // 3600 + m2 = (sec2 - h2*3600) // 60 + s2 = sec2 % 60 + print(h1,m1,s1,'\t', h2,m2,s2) + +def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer): + model.train() + criterion.train() + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + timer_dataloading = time.time() + loss_total = 0 + + for batch_idx, batch in enumerate(tqdm(train_loader)): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + timer_start = time.time() + model_input1, model_input2, model_input_oracle, \ + model_gt1, model_gt2, model_gt_oracle, \ + mask_GT = prepare_batch_inputs_qfvs(batch, config) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + output1 = model(**model_input1) + output2 = model(**model_input2) + output_oracle = model(**model_input_oracle) + + loss_dict = {} + loss_dict1 = criterion(output1, model_gt1, mask_GT) + loss_dict2 = criterion(output2, model_gt2, mask_GT) + loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT) + + weight_dict = criterion.weight_dict + if config['qfvs_loss_gather'] > 0: + for k in loss_dict1.keys(): + loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k] + else: + loss_dict = loss_dict3 + + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + loss_total += losses.item() + + time_meters["model_forward_time"].update(time.time() - timer_start) + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + timer_dataloading = time.time() + return round(loss_total / len(train_loader), 2) + +# train in single domain. 
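+# train() evaluates once before the first epoch (logged at step 0) and then alternates training
+# and evaluation, keeping the best-F-score checkpoint for the held-out test video.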
+def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config): + # if opt.device.type == "cuda": + # logger.info("CUDA enabled.") + # model.to(opt.device) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0} + if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + + val_score = eval_epoch(model, config, opt) + tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0) + logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]" + f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]" + f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]") + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + val_score = eval_epoch(model, config, opt) + tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1) + logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]" + f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]" + f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]") + + if prev_best_score['Fscore'] < val_score['F']: + prev_best_score['Fscore'] = val_score['F'] + prev_best_score['Precision'] = val_score['P'] + prev_best_score['Recall'] = val_score['R'] + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt")) + tb_writer.close() + return prev_best_score + +def update_config(opt, config): + # for key in ["max_segment_num", "max_frame_num", "top_percent", + # "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot", + # "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]: + config["max_segment_num"] = opt.max_segment_num + config["max_frame_num"] = opt.max_frame_num + config["top_percent"] = opt.top_percent + config["vid_feature"] = opt.qfvs_vid_feature + config["txt_feature"] = opt.qfvs_txt_feature + config["qfvs_dense_shot"] = opt.qfvs_dense_shot + config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble + config["qfvs_score_gather"] = opt.qfvs_score_gather + config["qfvs_loss_gather"] = opt.qfvs_loss_gather + return config + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + + # config = load_json("./main/config_qfvs.json") + config = {} + config = update_config(opt, config) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + + # key -> test video; value -> training videos. 
+ qfvs_split = { + 1: [2, 3, 4], + 2: [1, 3, 4], + 3: [1, 2, 4], + 4: [1, 2, 3] + } + + scores_videos = {} + for test_id, splits in qfvs_split.items(): + logger.info(f"Start Training {opt.dset_name}: {test_id}") + config['train_videos'] = qfvs_split[test_id] + config['test_videos'] = [test_id] + train_dataset = DatasetQFVS(config) + train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers) + + model, criterion, optimizer, lr_scheduler = setup_model(opt) + count_parameters(model) + best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config) + scores_videos['V'+str(test_id)] = best_score + + # save the final results. + avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos) + avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos) + avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos) + scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall} + + save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json") + save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False) + + tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1) + tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None)) + tb_writer.close() + + print(scores_videos) + return + +if __name__ == '__main__': + start_training() + results = logger.info("\n\n\nFINISHED TRAINING!!!") diff --git a/main/train_vlp.py b/main/train_vlp.py new file mode 100644 index 0000000000000000000000000000000000000000..feed89496947c621c1a20f4f9326bcd13ec1fa52 --- /dev/null +++ b/main/train_vlp.py @@ -0,0 +1,278 @@ +import os +import pdb +import sys +import time +import json +import pprint +import random +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +sys.path.append('/data/home/qinghonglin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset import \ + DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr +from main.inference_mr import eval_epoch, start_inference +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown +from utils.model_utils import count_parameters + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls=None): + logger.info(f"[Epoch {epoch_i+1}]") + model.train() + criterion.train() + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + num_training_examples = len(train_loader) + timer_dataloading = time.time() + for batch_idx, batch in tqdm(enumerate(train_loader), + desc="Training Iteration", + total=num_training_examples): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + + timer_start = time.time() + model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + + 
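+        # The optional 'cls' dict carries precomputed src_cls / src_cls_mask tensors
+        # (built once in train() below for TAL/MQ data) and is merged into every batch's inputs.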
if cls is not None: + model_inputs.update(cls) + + # try: + outputs = model(**model_inputs) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + time_meters["model_forward_time"].update(time.time() - timer_start) + + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + # except: + # pdb.set_trace() + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + loss_dict["loss_overall"] = float(losses) # for logging only + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + timer_dataloading = time.time() + + # print/add logs + tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1) + for k, v in loss_meters.items(): + tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1) + + to_write = opt.train_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i+1, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()])) + with open(opt.train_log_filepath, "a") as f: + f.write(to_write) + + logger.info("Epoch time stats:") + for name, meter in time_meters.items(): + d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]} + logger.info(f"{name} ==> {d}") + + +def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt): + if opt.device.type == "cuda": + logger.info("CUDA enabled.") + model.to(opt.device) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + train_loader = DataLoader( + train_dataset, + collate_fn=start_end_collate_mr, + batch_size=opt.bsz, + num_workers=opt.num_workers, + shuffle=True, + pin_memory=opt.pin_memory + ) + + if ('tal' in opt.train_path) or ('mq' in opt.train_path): + cls = { + 'src_cls': train_dataset.src_cls.cuda(), + 'src_cls_mask': train_dataset.src_cls_mask.cuda(),} + else: + cls = None + + prev_best_score = 0. 
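+    # es_cnt counts consecutive evaluations without improvement on opt.main_metric;
+    # when opt.max_es_cnt != -1 and the counter exceeds it, training stops early.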
+ es_cnt = 0 + if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name) + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ + eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer) + + # log + to_write = opt.eval_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]), + eval_metrics_str=json.dumps(metrics_no_nms)) + + with open(opt.eval_log_filepath, "a") as f: + f.write(to_write) + logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) + if metrics_nms is not None: + logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) + + metrics = metrics_nms if metrics_nms is not None else metrics_no_nms + for k, v in metrics["brief"].items(): + tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1) + + # stop_score = metrics["brief"]["MR-full-mAP"] + # pdb.set_trace() + stop_score = metrics["brief"][opt.main_metric] + if stop_score > prev_best_score: + es_cnt = 0 + prev_best_score = stop_score + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt")) + + best_file_paths = [e.replace("latest", "best") for e in latest_file_paths] + for src, tgt in zip(latest_file_paths, best_file_paths): + os.renames(src, tgt) + logger.info("The checkpoint file has been updated.") + else: + es_cnt += 1 + if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop + with open(opt.train_log_filepath, "a") as f: + f.write(f"Early Stop at epoch {epoch_i}") + logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n") + break + + # save ckpt + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt")) + + if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt")) + + if opt.debug: + break + + tb_writer.close() + + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + if opt.debug: # keep the model run deterministically + # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config. + # Enable this only when input size is fixed. 
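+        # Debug runs therefore disable benchmarking and force deterministic kernels,
+        # trading speed for reproducibility.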
+ cudnn.benchmark = False + cudnn.deterministic = True + + dataset_config = dict( + dset_name=opt.dset_name, + data_path=opt.train_path, + v_feat_dirs=opt.v_feat_dirs, + q_feat_dir=opt.t_feat_dir, + v_feat_dim=opt.v_feat_dim, + q_feat_dim=opt.t_feat_dim, + q_feat_type="last_hidden_state", + max_q_l=opt.max_q_l, + max_v_l=opt.max_v_l, + ctx_mode=opt.ctx_mode, + data_ratio=opt.data_ratio, + normalize_v=not opt.no_norm_vfeat, + normalize_t=not opt.no_norm_tfeat, + clip_len=opt.clip_length, + max_windows=opt.max_windows, + span_loss_type=opt.span_loss_type, + txt_drop_ratio=opt.txt_drop_ratio, + use_cache=opt.use_cache, + add_easy_negative=opt.add_easy_negative, + easy_negative_only=opt.easy_negative_only + ) + + dataset_config["data_path"] = opt.train_path + train_dataset = DatasetVLP(**dataset_config) + + if opt.eval_path is not None: + dataset_config["data_path"] = opt.eval_path + dataset_config["txt_drop_ratio"] = 0 + dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining + # dataset_config["load_labels"] = False # uncomment to calculate eval loss + eval_dataset = DatasetVLP(**dataset_config) + else: + eval_dataset = None + + if opt.lr_warmup > 0: + opt.lr_warmup = opt.n_epoch + model, criterion, optimizer, lr_scheduler = setup_model(opt) + logger.info(f"Model {model}") + count_parameters(model) + logger.info("Start Training...") + train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt) + return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug + + +if __name__ == '__main__': + best_ckpt_path, eval_split_name, eval_path, debug = start_training() + if not debug: + input_args = ["--resume", best_ckpt_path, + "--eval_split_name", eval_split_name, + "--eval_path", eval_path] + + import sys + sys.argv[1:] = input_args + logger.info("\n\n\nFINISHED TRAINING!!!") + logger.info("Evaluating model at {}".format(best_ckpt_path)) + logger.info("Input args {}".format(sys.argv[1:])) + start_inference() \ No newline at end of file diff --git a/main/train_vlp_ddp.py b/main/train_vlp_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c5d735a1676dc53828ee9c97c048fe7641ba16 --- /dev/null +++ b/main/train_vlp_ddp.py @@ -0,0 +1,288 @@ +import os +import pdb +import sys +import time +import json +import pprint +import random +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data.distributed import DistributedSampler + +sys.path.append('/data/home/qinghonglin/univtg') +from main.config import BaseOptions, setup_model +from main.dataset import \ + DatasetMR, DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr +from main.inference_mr import eval_epoch, start_inference +from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown +from utils.model_utils import count_parameters + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + +def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer): + logger.info(f"[Epoch {epoch_i+1}]") + model.train() + criterion.train() + + # init meters + time_meters = 
defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + num_training_examples = len(train_loader) + timer_dataloading = time.time() + for batch_idx, batch in tqdm(enumerate(train_loader), + desc="Training Iteration", + total=num_training_examples): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + + timer_start = time.time() + model_inputs, targets = prepare_batch_inputs_mr(batch[1], torch.device("cuda", int(opt.local_rank)), non_blocking=opt.pin_memory) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + + timer_start = time.time() + + # try: + outputs = model(**model_inputs) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + time_meters["model_forward_time"].update(time.time() - timer_start) + + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + loss_dict["loss_overall"] = float(losses) # for logging only + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + timer_dataloading = time.time() + + # print/add logs + if int(opt.local_rank) in [0, -1]: + tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1) + for k, v in loss_meters.items(): + tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1) + + to_write = opt.train_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i+1, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()])) + with open(opt.train_log_filepath, "a") as f: + f.write(to_write) + + logger.info("Epoch time stats:") + for name, meter in time_meters.items(): + d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]} + logger.info(f"{name} ==> {d}") + + +def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt): + if int(opt.local_rank) in [0, -1]: + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + else: + tb_writer = None + + train_loader = DataLoader( + train_dataset, + collate_fn=start_end_collate_mr, + batch_size=opt.bsz, + num_workers=opt.num_workers, + # shuffle=True, + pin_memory=opt.pin_memory, + sampler=DistributedSampler(train_dataset) + ) + + prev_best_score = 0. 
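+    # In the DDP version, evaluation, early stopping and best-checkpoint updates below
+    # run on rank 0 only (local_rank in [0, -1]); other ranks only execute the training loop.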
+ es_cnt = 0 + if opt.start_epoch is None: + start_epoch = -1 if opt.eval_init else 0 + else: + start_epoch = opt.start_epoch + save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name) + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if int(opt.local_rank) in [0, -1] and opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ + eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer) + + # log + to_write = opt.eval_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]), + eval_metrics_str=json.dumps(metrics_no_nms)) + + if int(opt.local_rank) in [0, -1]: + with open(opt.eval_log_filepath, "a") as f: + f.write(to_write) + logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) + if metrics_nms is not None: + logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) + + metrics = metrics_nms if metrics_nms is not None else metrics_no_nms + for k, v in metrics["brief"].items(): + tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1) + + # stop_score = metrics["brief"]["MR-full-mAP"] + # pdb.set_trace() + stop_score = metrics["brief"][opt.main_metric] + if stop_score > prev_best_score: + es_cnt = 0 + prev_best_score = stop_score + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt")) + + best_file_paths = [e.replace("latest", "best") for e in latest_file_paths] + for src, tgt in zip(latest_file_paths, best_file_paths): + os.renames(src, tgt) + logger.info("The checkpoint file has been updated.") + else: + es_cnt += 1 + if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop + with open(opt.train_log_filepath, "a") as f: + f.write(f"Early Stop at epoch {epoch_i}") + logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n") + break + + # save ckpt + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt")) + + if int(opt.local_rank) in [0, -1] and ((epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0): # additional copies + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt")) + + if opt.debug: + break + + if int(opt.local_rank) in [0, -1]: + tb_writer.close() + + +def start_training(): + # logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + if opt.debug: # keep the model run deterministically + # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config. + # Enable this only when input size is fixed. 
+ cudnn.benchmark = False + cudnn.deterministic = True + + local_rank = int(opt.local_rank) + dist.init_process_group(backend='nccl') + + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + + dataset_config = dict( + dset_name=opt.dset_name, + data_path=opt.train_path, + v_feat_dirs=opt.v_feat_dirs, + q_feat_dir=opt.t_feat_dir, + v_feat_dim=opt.v_feat_dim, + q_feat_dim=opt.t_feat_dim, + q_feat_type="last_hidden_state", + max_q_l=opt.max_q_l, + max_v_l=opt.max_v_l, + ctx_mode=opt.ctx_mode, + data_ratio=opt.data_ratio, + normalize_v=not opt.no_norm_vfeat, + normalize_t=not opt.no_norm_tfeat, + clip_len=opt.clip_length, + max_windows=opt.max_windows, + span_loss_type=opt.span_loss_type, + txt_drop_ratio=opt.txt_drop_ratio, + use_cache=opt.use_cache, + add_easy_negative=opt.add_easy_negative, + easy_negative_only=opt.easy_negative_only + ) + + dataset_config["data_path"] = opt.train_path + train_dataset = DatasetVLP(**dataset_config) + + if opt.eval_path is not None: + # perform zero-shot on qvhl. + dataset_config["data_path"] = opt.eval_path + dataset_config["txt_drop_ratio"] = 0 + if len(dataset_config["v_feat_dirs"]) == 1: + dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_clip"] + elif len(dataset_config["v_feat_dirs"]) == 2: + dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_slowfast", "data/qvhighlights/vid_clip"] + else: + raise NotImplementedError + dataset_config["q_feat_dir"] = "data/qvhighlights/txt_clip" + dataset_config["data_ratio"] = 1 + # dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining + eval_dataset = DatasetMR(**dataset_config) + else: + eval_dataset = None + + if opt.lr_warmup > 0: + # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz + total_steps = opt.n_epoch + warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps) + opt.lr_warmup = [warmup_steps, total_steps] + model, criterion, optimizer, lr_scheduler = setup_model(opt) + + model.to(device) + logger.info(f"Using {torch.cuda.device_count()} GPUs.") + model = torch.nn.parallel.DistributedDataParallel(model, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=True) + + if int(opt.local_rank) in [0, -1]: + logger.info(f"Model {model}") + count_parameters(model) + logger.info("Start Training...") + train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt) + # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug + return + +if __name__ == '__main__': + # best_ckpt_path, eval_split_name, eval_path, debug = start_training() + start_training() diff --git a/model/base.py b/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e30e473b5490bd74ca180db4b5e46945a8c4fcaa --- /dev/null +++ b/model/base.py @@ -0,0 +1,449 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not 
None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. """ + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. 
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # type token. 
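+        # Token-type embeddings act like segment ids: video tokens are tagged with type 1,
+        # text (and class) tokens with type 0, so the shared transformer can tell the two
+        # modalities apart after they are concatenated below.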
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + mask = targets['timestamp_mask'].bool() + mask_valid = targets['timestamp_window'].bool() + target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[mask_valid] = 1 + # target_classes = targets['timestamp_window'] # soft cls. 
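+        # Hard targets: clips inside the ground-truth window (timestamp_window) are labelled 1,
+        # all other valid clips 0; the commented line above would use soft window scores instead.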
+ target_classes.float() + # pdb.set_trace() + + weights = torch.zeros_like(target_classes).float() + weights[mask] = self.empty_weight[1] + weights[mask_valid] = self.empty_weight[0] + + # pdb.set_trace() + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + mask = targets['timestamp_mask'] + selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + neg_indices_in = (saliency_scores < selected_scores) + neg_indices_in[batch_indices, pos_indices] = True + mask_invalid = neg_indices_in * mask.bool() + + sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + sim_in = sim_in + (mask_invalid + 1e-45).log() + logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + loss_saliency_intra = - loss_in_i - loss_in_j + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = 
outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, hl_only=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + 
n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion \ No newline at end of file diff --git a/model/base_albef.py b/model/base_albef.py new file mode 100644 index 0000000000000000000000000000000000000000..3f067b66558cfb3e380126ea8e9740cb062790b4 --- /dev/null +++ b/model/base_albef.py @@ -0,0 +1,478 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder import build_transformer, Transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. """ + + def __init__(self, transformer_mm, transformer_v, transformer_t, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. 
+ ce: (st_idx, ed_idx) classification. + # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer_mm + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer_mm.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + self.transformer_v = transformer_v + self.transformer_t = transformer_t + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # pos embed. + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + + src_vid = self.transformer_v(src_vid, ~src_vid_mask.bool(), pos_vid) + src_txt = self.transformer_t(src_txt, ~src_txt_mask.bool(), pos_txt) + + # type token. 
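+        # Unlike model/base.py, video and text are first encoded by separate unimodal
+        # transformers (transformer_v / transformer_t) above; type embeddings are then added
+        # before both streams are fused in the shared multi-modal transformer.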
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + mask = targets['timestamp_mask'].bool() + mask_valid = targets['timestamp_window'].bool() + target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[mask_valid] = 1 + # target_classes = targets['timestamp_window'] # soft cls. 
+ target_classes.float() + # pdb.set_trace() + + weights = torch.zeros_like(target_classes).float() + weights[mask] = self.empty_weight[1] + weights[mask_valid] = self.empty_weight[0] + + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + return {"loss_f": loss_ce.sum() / mask.sum()} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + mask = targets['timestamp_mask'] + selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + neg_indices_in = (saliency_scores < selected_scores) + neg_indices_in[batch_indices, pos_indices] = True + mask_invalid = neg_indices_in * mask.bool() + + sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + sim_in = sim_in + (mask_invalid + 1e-45).log() + logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + loss_saliency_intra = - loss_in_i - loss_in_j + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = 
F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, hl_only=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer_mm = build_transformer(args) + transformer_v = Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.sub_enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + transformer_t = Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, 
+ num_encoder_layers=args.sub_enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + # pdb.set_trace() + + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer_mm, + transformer_v, + transformer_t, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion \ No newline at end of file diff --git a/model/base_droppath.py b/model/base_droppath.py new file mode 100644 index 0000000000000000000000000000000000000000..1be7420925310f7f6e29b6a46df788b93227b4d4 --- /dev/null +++ b/model/base_droppath.py @@ -0,0 +1,449 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder_droppath import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. 
""" + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. + # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # type token. 
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + mask = targets['timestamp_mask'].bool() + mask_valid = targets['timestamp_window'].bool() + target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[mask_valid] = 1 + # target_classes = targets['timestamp_window'] # soft cls. 
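+ # Per-clip binary cross-entropy: clips inside the annotated window (timestamp_window) are foreground, + # and self.empty_weight gives background clips the lower eos_coef weight before averaging over valid positions.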
+ target_classes.float() + # pdb.set_trace() + + weights = torch.zeros_like(target_classes).float() + weights[mask] = self.empty_weight[1] + weights[mask_valid] = self.empty_weight[0] + + # pdb.set_trace() + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + mask = targets['timestamp_mask'] + selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + neg_indices_in = (saliency_scores < selected_scores) + neg_indices_in[batch_indices, pos_indices] = True + mask_invalid = neg_indices_in * mask.bool() + + sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + sim_in = sim_in + (mask_invalid + 1e-45).log() + logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + loss_saliency_intra = - loss_in_i - loss_in_j + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = 
outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, hl_only=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + 
n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion diff --git a/model/base_droppath_ablation.py b/model/base_droppath_ablation.py new file mode 100644 index 0000000000000000000000000000000000000000..a3fb06a293c0a7130715d53cecc9b98406d70fdf --- /dev/null +++ b/model/base_droppath_ablation.py @@ -0,0 +1,474 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder_droppath import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. """ + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. 
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # type token. 
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + weight_abalation_b = targets['weight_ablation'][:,0].unsqueeze(-1) + if weight_abalation_b.sum() == 0: + return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()} + + mask_valid = (mask_valid * weight_abalation_b).bool() + mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool() + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + mask = targets['timestamp_mask'].bool() + mask_valid = targets['timestamp_window'].bool() + target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[mask_valid] = 1 + # target_classes = targets['timestamp_window'] # soft cls. 
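+ # Same foreground/background BCE as base_droppath.py, except each sample is additionally gated by + # targets['weight_ablation'][:, 2] below, so individual loss terms can be switched off per sample in ablation runs.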
+ target_classes.float() + # pdb.set_trace() + + weights = torch.zeros_like(target_classes).float() + weights[mask] = self.empty_weight[1] + weights[mask_valid] = self.empty_weight[0] + + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + + weight_abalation_f = targets['weight_ablation'][:,2].unsqueeze(-1) + if weight_abalation_f.sum() == 0: + return {"loss_f": torch.tensor(0).cuda()} + + mask = mask * weight_abalation_f + loss_ce = loss_ce * weight_abalation_f + return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + + weight_abalation_s = targets['weight_ablation'][:,3].bool() + if weight_abalation_s.sum() == 0: + return {"loss_s_inter": torch.tensor(0).cuda(), + "loss_s_intra": torch.tensor(0).cuda()} + + _idiag = idiag[weight_abalation_s] + _jdiag = jdiag[weight_abalation_s] + + loss_i = _idiag.sum() / len(_idiag) + loss_j = _jdiag.sum() / len(_jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + mask = targets['timestamp_mask'] + selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + neg_indices_in = (saliency_scores < selected_scores) + neg_indices_in[batch_indices, pos_indices] = True + mask_invalid = neg_indices_in * mask.bool() + + sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + sim_in = sim_in + (mask_invalid + 1e-45).log() + logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + _pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s] + _pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s] + + loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i) + loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j) + + loss_saliency_intra = - loss_in_i - loss_in_j + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + 
txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, hl_only=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" 
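+ # Optional pre-LayerNorm, followed by dropout and a linear projection; ReLU is applied last when enabled.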
+ if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion \ No newline at end of file diff --git a/model/base_droppath_qfvs.py b/model/base_droppath_qfvs.py new file mode 100644 index 0000000000000000000000000000000000000000..7236540c8ae0ada454a66cfc8d1a0018e00287e8 --- /dev/null +++ b/model/base_droppath_qfvs.py @@ -0,0 +1,476 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder_droppath import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. 
""" + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. + # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # type token. 
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_f": 0.} + + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + target_classes = targets["saliency_scores"].squeeze() + + weights = torch.ones_like(target_classes).float() * self.empty_weight[1] + weights[target_classes.bool()] = self.empty_weight[0] + + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none") + return {"loss_f": loss_ce.sum() / target_classes.sum()} + # return {"loss_f": loss_ce.sum() / len(target_classes)} + + # mask = targets['timestamp_mask'].bool() + # mask_valid = targets['timestamp_window'].bool() + # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + # target_classes[mask_valid] = 1 + # # target_classes = targets['timestamp_window'] # soft cls. 
+ # target_classes.float() + # # pdb.set_trace() + + # weights = torch.zeros_like(target_classes).float() + # weights[mask] = self.empty_weight[1] + # weights[mask_valid] = self.empty_weight[0] + + # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + # # return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / mask_valid.sum()} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * qfvs mil-nce mode + pos_indices = saliency_scores.squeeze() > 0 + + sim = outputs['saliency_scores'] + sim_soft = F.softmax(sim / self.temperature, dim=0) + sim_log = torch.log(sim_soft[pos_indices]) + loss_saliency_intra = -sim_log.sum() / len(sim_log) + return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra} + + # * inter-vid mode + # vid_mem_proj = outputs["vid_mem_proj"] + # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + # vid_feats = vid_mem_proj[batch_indices, pos_indices] + # txt_feats = outputs["txt_mem_proj"].squeeze(1) + # sim = sim_matrix(vid_feats, txt_feats) + + # i_logsm = F.log_softmax(sim / self.temperature, dim=1) + # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # # sum over positives + # idiag = torch.diag(i_logsm) + # jdiag = torch.diag(j_logsm) + # loss_i = idiag.sum() / len(idiag) + # loss_j = jdiag.sum() / len(jdiag) + + # loss_saliency_inter = - loss_i - loss_j + + # # * intra-vid mode + # mask = targets['timestamp_mask'] + # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + # neg_indices_in = (saliency_scores < selected_scores) + # neg_indices_in[batch_indices, pos_indices] = True + # mask_invalid = neg_indices_in * mask.bool() + + # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + # sim_in = sim_in + (mask_invalid + 1e-45).log() + # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + # loss_saliency_intra = - loss_in_i - loss_in_j + + # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum 
over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, mask_GT=None): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0]) + count = mask_GT.sum() + outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0]) + # targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0]) + targets['saliency_scores'] = targets['saliency_scores'][0,:count] + + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = 
nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion \ No newline at end of file diff --git a/model/base_prompt.py b/model/base_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..5816b7429f3c8be69ca8c3f4322a11ade60b8217 --- /dev/null +++ b/model/base_prompt.py @@ -0,0 +1,460 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. 
""" + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. + # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.prompt_learner = nn.Embedding(10, hidden_dim) + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + src_prompt = self.prompt_learner.weight.unsqueeze(0).repeat(bs, 1, 1) + src_prompt_mask = torch.ones((bs, src_prompt.shape[1])).cuda() + + if self.training: + # src_txt = src_prompt + # src_txt_mask = torch.ones_like(src_prompt).cuda() + src_txt = torch.cat([src_prompt, src_txt], dim=1) + src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1) + else: + src_txt = torch.cat([src_prompt, src_txt], dim=1) + src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1) + + # type token. 
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
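# Editor's note: illustrative sketch of the clip-level saliency score computed at the end of
# Model.forward() above — cosine similarity between each clip feature and the pooled sentence
# feature, with (mask + 1e-45).log() pushing padded clips toward -inf. Toy shapes only.
import torch
import torch.nn.functional as F

bs, n_clips, d = 2, 5, 256
vid_mem_proj = torch.randn(bs, n_clips, d)          # per-clip video features
txt_mem_proj = torch.randn(bs, 1, d)                # pooled sentence feature
src_vid_mask = torch.tensor([[1., 1., 1., 0., 0.],
                             [1., 1., 1., 1., 1.]])

sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1)   # (bs, n_clips), in [-1, 1]
sim = sim + (src_vid_mask + 1e-45).log()                        # padded clips drop to ~ -103.6
top_clip = sim.argmax(dim=1)                                    # highest-saliency clip per video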
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + mask = targets['timestamp_mask'].bool() + mask_valid = targets['timestamp_window'].bool() + target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[mask_valid] = 1 + # target_classes = targets['timestamp_window'] # soft cls. 
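# Editor's note: hedged, for-intuition-only sketch of a 1-D generalized IoU between
# (start, end) spans, as used in loss_spans above; the real implementation lives in
# utils.span_utils.generalized_temporal_iou and may differ in details.
import torch

def giou_1d_sketch(spans1, spans2):
    """spans*: (N, 2) / (M, 2) tensors of [start, end] with end >= start; returns an (N, M) gIoU matrix."""
    len1 = spans1[:, 1] - spans1[:, 0]
    len2 = spans2[:, 1] - spans2[:, 0]
    inter = (torch.min(spans1[:, None, 1], spans2[None, :, 1])
             - torch.max(spans1[:, None, 0], spans2[None, :, 0])).clamp(min=0)
    union = len1[:, None] + len2[None, :] - inter
    enclose = (torch.max(spans1[:, None, 1], spans2[None, :, 1])
               - torch.min(spans1[:, None, 0], spans2[None, :, 0]))
    return inter / union - (enclose - union) / enclose

spans = torch.tensor([[0.1, 0.4], [0.5, 0.9]])
print(torch.diag(giou_1d_sketch(spans, spans)))   # identical spans -> gIoU of 1.0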
+ target_classes.float() + # pdb.set_trace() + + weights = torch.zeros_like(target_classes).float() + weights[mask] = self.empty_weight[1] + weights[mask_valid] = self.empty_weight[0] + + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + return {"loss_f": loss_ce.sum() / mask.sum()} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + mask = targets['timestamp_mask'] + selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + neg_indices_in = (saliency_scores < selected_scores) + neg_indices_in[batch_indices, pos_indices] = True + mask_invalid = neg_indices_in * mask.bool() + + sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + sim_in = sim_in + (mask_invalid + 1e-45).log() + logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + loss_saliency_intra = - loss_in_i - loss_in_j + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = 
F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, hl_only=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + 
"loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion \ No newline at end of file diff --git a/model/base_qfvs.py b/model/base_qfvs.py new file mode 100644 index 0000000000000000000000000000000000000000..19e9553ef35763eebd97125964c7c4864b9dd7db --- /dev/null +++ b/model/base_qfvs.py @@ -0,0 +1,476 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder_droppath import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. """ + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. 
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # type token. 
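# Editor's note: standalone sketch of the token-type embedding applied in the statements that
# follow — an nn.Embedding with two rows marks every video token with type 1 and every text
# token with type 0 before the two streams are concatenated. Toy shapes only.
import torch
from torch import nn

bs, n_vid, n_txt, d = 2, 8, 5, 256
token_type_embeddings = nn.Embedding(2, d)

src_vid = torch.randn(bs, n_vid, d)
src_txt = torch.randn(bs, n_txt, d)
src_vid_mask = torch.ones(bs, n_vid)
src_txt_mask = torch.ones(bs, n_txt)

src_vid = src_vid + token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
src_txt = src_txt + token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
src = torch.cat([src_vid, src_txt], dim=1)          # (bs, n_vid + n_txt, d)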
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
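# Editor's note: minimal sketch of the symmetric InfoNCE pattern used by the saliency losses of
# this criterion (and of SetCriterion in base_prompt.py): a video-text similarity matrix is
# log-softmaxed along both axes and the diagonal (matched pairs) is maximized. Toy values only.
import torch
import torch.nn.functional as F

temperature = 0.07
vid_feats = F.normalize(torch.randn(4, 256), dim=1)   # one positive-clip feature per video
txt_feats = F.normalize(torch.randn(4, 256), dim=1)   # one pooled text feature per query
sim = vid_feats @ txt_feats.t()                        # (4, 4); diagonal = matched pairs

i_logsm = F.log_softmax(sim / temperature, dim=1)      # video -> text direction
j_logsm = F.log_softmax(sim.t() / temperature, dim=1)  # text -> video direction
loss_inter = -torch.diag(i_logsm).mean() - torch.diag(j_logsm).mean()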
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_f": 0.} + + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + target_classes = targets["saliency_scores"].squeeze() + + weights = torch.ones_like(target_classes).float() * self.empty_weight[1] + weights[target_classes.bool()] = self.empty_weight[0] + + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none") + # pdb.set_trace() + return {"loss_f": loss_ce.sum() / target_classes.sum()} + # return {"loss_f": loss_ce.sum() / len(target_classes)} + + # mask = targets['timestamp_mask'].bool() + # mask_valid = targets['timestamp_window'].bool() + # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + # target_classes[mask_valid] = 1 + # # target_classes = targets['timestamp_window'] # soft cls. 
+ # target_classes.float() + # # pdb.set_trace() + + # weights = torch.zeros_like(target_classes).float() + # weights[mask] = self.empty_weight[1] + # weights[mask_valid] = self.empty_weight[0] + + # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + # # return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / mask_valid.sum()} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * qfvs mil-nce mode + pos_indices = saliency_scores.squeeze() > 0 + + sim = outputs['saliency_scores'] + sim_soft = F.softmax(sim / self.temperature, dim=0) + sim_log = torch.log(sim_soft[pos_indices]) + loss_saliency_intra = -sim_log.sum() / len(sim_log) + return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra} + + # * inter-vid mode + # vid_mem_proj = outputs["vid_mem_proj"] + # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + # vid_feats = vid_mem_proj[batch_indices, pos_indices] + # txt_feats = outputs["txt_mem_proj"].squeeze(1) + # sim = sim_matrix(vid_feats, txt_feats) + + # i_logsm = F.log_softmax(sim / self.temperature, dim=1) + # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # # sum over positives + # idiag = torch.diag(i_logsm) + # jdiag = torch.diag(j_logsm) + # loss_i = idiag.sum() / len(idiag) + # loss_j = jdiag.sum() / len(jdiag) + + # loss_saliency_inter = - loss_i - loss_j + + # # * intra-vid mode + # mask = targets['timestamp_mask'] + # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + # neg_indices_in = (saliency_scores < selected_scores) + # neg_indices_in[batch_indices, pos_indices] = True + # mask_invalid = neg_indices_in * mask.bool() + + # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + # sim_in = sim_in + (mask_invalid + 1e-45).log() + # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + # loss_saliency_intra = - loss_in_i - loss_in_j + + # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum 
over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, mask_GT=None): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + # pdb.set_trace() + outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0]) + outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0]) + targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0]) + + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + 
self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion \ No newline at end of file diff --git a/model/matcher.py b/model/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a27af5fb40c43bfdc54a4277c022f0b040fe0c --- /dev/null +++ b/model/matcher.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn +import torch.nn.functional as F +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). 
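# Editor's note: toy illustration of the assignment problem this matcher solves — given a cost
# matrix between predicted and ground-truth spans, scipy's linear_sum_assignment returns the
# 1-to-1 matching with minimal total cost. The numbers below are made up.
import numpy as np
from scipy.optimize import linear_sum_assignment

# rows = predicted spans (num_queries), cols = ground-truth spans; lower cost = better match
cost = np.array([[0.1, 0.9, 0.8],
                 [0.7, 0.2, 0.6],
                 [0.5, 0.4, 0.3],
                 [0.9, 0.8, 0.7]])
pred_idx, tgt_idx = linear_sum_assignment(cost)
print(pred_idx, tgt_idx)   # e.g. [0 1 2] [0 1 2]: each target paired with its cheapest free prediction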
+ """ + def __init__(self, cost_class: float = 1, cost_span: float = 1, cost_giou: float = 1, + span_loss_type: str = "l1", max_v_l: int = 75): + """Creates the matcher + + Params: + cost_span: This is the relative weight of the L1 error of the span coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the spans in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_span = cost_span + self.cost_giou = cost_giou + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.foreground_label = 0 + assert cost_class != 0 or cost_span != 0 or cost_giou != 0, "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_spans": Tensor of dim [batch_size, num_queries, 2] with the predicted span coordinates, + in normalized (cx, w) format + ""pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "spans": Tensor of dim [num_target_spans, 2] containing the target span coordinates. The spans are + in normalized (cx, w) format + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_spans) + """ + bs, num_queries = outputs["pred_spans"].shape[:2] + targets = targets["span_labels"] + + # Also concat the target labels and spans + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + tgt_spans = torch.cat([v["spans"] for v in targets]) # [num_target_spans in batch, 2] + tgt_ids = torch.full([len(tgt_spans)], self.foreground_label) # [total #spans in the batch] + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - prob[target class]. + # The 1 is a constant that doesn't change the matching, it can be omitted. 
+ cost_class = -out_prob[:, tgt_ids] # [batch_size * num_queries, total #spans in the batch] + + if self.span_loss_type == "l1": + # We flatten to compute the cost matrices in a batch + out_spans = outputs["pred_spans"].flatten(0, 1) # [batch_size * num_queries, 2] + + # Compute the L1 cost between spans + cost_span = torch.cdist(out_spans, tgt_spans, p=1) # [batch_size * num_queries, total #spans in the batch] + + # Compute the giou cost between spans + # [batch_size * num_queries, total #spans in the batch] + cost_giou = - generalized_temporal_iou(span_cxw_to_xx(out_spans), span_cxw_to_xx(tgt_spans)) + else: + pred_spans = outputs["pred_spans"] # (bsz, #queries, max_v_l * 2) + pred_spans = pred_spans.view(bs * num_queries, 2, self.max_v_l).softmax(-1) # (bsz * #queries, 2, max_v_l) + cost_span = - pred_spans[:, 0][:, tgt_spans[:, 0]] - \ + pred_spans[:, 1][:, tgt_spans[:, 1]] # (bsz * #queries, #spans) + # pred_spans = pred_spans.repeat(1, n_spans, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, max_v_l, 2) + # tgt_spans = tgt_spans.view(1, n_spans, 2).repeat(bs * num_queries, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, 2) + # cost_span = pred_spans[tgt_spans] + # cost_span = cost_span.view(bs * num_queries, n_spans) + + # giou + cost_giou = 0 + + # Final cost matrix + # import ipdb; ipdb.set_trace() + C = self.cost_span * cost_span + self.cost_giou * cost_giou + self.cost_class * cost_class + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["spans"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +def build_matcher(args): + return HungarianMatcher( + cost_span=args.set_cost_span, cost_giou=args.set_cost_giou, + cost_class=args.set_cost_class, span_loss_type=args.span_loss_type, max_v_l=args.max_v_l + ) diff --git a/model/moment_detr.py b/model/moment_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..29fcfcf8ca078a59f8078e176c1c361de6a2d460 --- /dev/null +++ b/model/moment_detr.py @@ -0,0 +1,462 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn + +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +from model.transformer import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k + output: (#items, #classes) + target: int, + """ + maxk = max(topk) + num_items = output.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / num_items)) + return res + +class Model(nn.Module): + """ This is the Moment-DETR module that performs moment localization. """ + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + num_queries, input_dropout, aux_loss=False, + contrastive_align_loss=False, contrastive_hdim=64, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. 
See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + Moment-DETR can detect in a single video. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + contrastive_align_loss: If true, perform span - tokens contrastive learning + contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. + # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3) + self.class_embed = nn.Linear(hidden_dim, 2) # 0: background, 1: foreground + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + # self.foreground_thd = foreground_thd + # self.background_thd = background_thd + self.query_embed = nn.Embedding(num_queries, hidden_dim) + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.contrastive_align_loss = contrastive_align_loss + if contrastive_align_loss: + self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim) + self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim) + self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim) + + self.saliency_proj = nn.Linear(hidden_dim, 1) + self.aux_loss = aux_loss + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask): + """The forward expects two tensors: + - src_txt: [batch_size, L_txt, D_txt] + - src_txt_mask: [batch_size, L_txt], containing 0 on padded pixels, + will convert to 1 as padding later for transformer + - src_vid: [batch_size, L_vid, D_vid] + - src_vid_mask: [batch_size, L_vid], containing 0 on padded pixels, + will convert to 1 as padding later for transformer + + It returns a dict with the following elements: + - "pred_spans": The normalized boxes coordinates for all queries, represented as + (center_x, width). 
These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + # TODO should we remove or use different positional embeddings to the src_txt? + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + # pos_txt = torch.zeros_like(src_txt) + # pad zeros for txt positions + pos = torch.cat([pos_vid, pos_txt], dim=1) + # (#layers, bsz, #queries, d), (bsz, L_vid+L_txt, d) + hs, memory = self.transformer(src, ~mask, self.query_embed.weight, pos) + outputs_class = self.class_embed(hs) # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(hs) # (#layers, bsz, #queries, 2 or max_v_l * 2) + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]} + + txt_mem = memory[:, src_vid.shape[1]:] # (bsz, L_txt, d) + vid_mem = memory[:, :src_vid.shape[1]] # (bsz, L_vid, d) + if self.contrastive_align_loss: + proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1) + proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1) + proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1) + out.update(dict( + proj_queries=proj_queries[-1], + proj_txt_mem=proj_txt_mem, + proj_vid_mem=proj_vid_mem + )) + + out["saliency_scores"] = self.saliency_proj(vid_mem).squeeze(-1) # (bsz, L_vid) + + if self.aux_loss: + # assert proj_queries and proj_txt_mem + out['aux_outputs'] = [ + {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + if self.contrastive_align_loss: + assert proj_queries is not None + for idx, d in enumerate(proj_queries[:-1]): + out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem)) + return out + + # @torch.jit.unused + # def _set_aux_loss(self, outputs_class, outputs_coord): + # # this is a workaround to make torchscript happy, as torchscript + # # doesn't support dictionary with non-homogeneous values, such + # # as a dict having both a Tensor and a list. + # return [{'pred_logits': a, 'pred_spans': b} + # for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. 
+ eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2] + The target spans are expected in format (center_x, w), normalized by the image size. + """ + assert 'pred_spans' in outputs + targets = targets["span_labels"] + idx = self._get_src_permutation_idx(indices) + src_spans = outputs['pred_spans'][idx] # (#spans, max_v_l * 2) + tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0) # (#spans, 2) + if self.span_loss_type == "l1": + loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none') + loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans))) + else: # ce + n_spans = src_spans.shape[0] + src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2) + loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none') + + # giou + # src_span_indices = src_spans.max(1)[1] # (#spans, 2) + # src_span_indices[:, 1] += 1 # ed non-inclusive [st, ed) + # + # tgt_span_indices = tgt_spans + # tgt_span_indices[:, 1] += 1 + # loss_giou = 1 - torch.diag(generalized_temporal_iou(src_span_indices, tgt_span_indices)) + loss_giou = loss_span.new_zeros([1]) + + losses = {} + losses['loss_b'] = loss_span.mean() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + # TODO add foreground and background classifier. use all non-matched as background. 
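# Editor's note: sketch of how loss_spans above consumes the matcher output — the per-sample
# (pred_idx, tgt_idx) pairs are flattened into (batch_idx, pred_idx) so matched predictions can
# be gathered in one indexing step (see _get_src_permutation_idx further below). Toy tensors.
import torch

pred_spans = torch.rand(2, 10, 2)                 # (batch, num_queries, 2) in (center, width)
indices = [(torch.tensor([3]), torch.tensor([0])),         # sample 0: query 3 <-> target 0
           (torch.tensor([1, 7]), torch.tensor([1, 0]))]   # sample 1: queries 1, 7 <-> targets 1, 0

batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
matched_pred = pred_spans[(batch_idx, src_idx)]   # (3, 2): spans of queries (0,3), (1,1), (1,7)
print(matched_pred.shape)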
+ assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] # (batch_size, #queries, #classes=2) + # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch + idx = self._get_src_permutation_idx(indices) + target_classes = torch.full(src_logits.shape[:2], self.background_label, + dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[idx] = self.foreground_label + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none") + losses = {'loss_f': loss_ce.mean()} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0] + return losses + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_intra": 0} + saliency_scores = outputs["saliency_scores"] # (N, L) + pos_indices = targets["saliency_pos_labels"] # (N, #pairs) + neg_indices = targets["saliency_neg_labels"] # (N, #pairs) + num_pairs = pos_indices.shape[1] # typically 2 or 4 + batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device) + pos_scores = torch.stack( + [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + neg_scores = torch.stack( + [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \ + / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale + return {"loss_s_intra": loss_saliency} + + def loss_contrastive_align(self, outputs, targets, indices, log=True): + """encourage higher scores between matched query span and input text""" + normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens + normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d) + logits = torch.einsum( + "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens) + logits = logits.sum(2) / self.temperature # (bsz, #queries) + idx = self._get_src_permutation_idx(indices) + positive_map = torch.zeros_like(logits, dtype=torch.bool) + positive_map[idx] = True + positive_logits = logits.masked_fill(~positive_map, 0) + + pos_term = positive_logits.sum(1) # (bsz, ) + num_pos = positive_map.sum(1) # (bsz, ) + neg_term = logits.logsumexp(1) # (bsz, ) + loss_nce = - pos_term / num_pos + neg_term # (bsz, ) + losses = {"loss_contrastive_align": loss_nce.mean()} + return losses + + def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True): + """encourage higher scores between matched query span and input text""" + # TODO (1) align vid_mem and txt_mem; + # TODO (2) change L1 loss as CE loss on 75 labels, similar to soft token prediction in MDETR + normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens + normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d) + logits = torch.einsum( + "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens) + logits = logits.sum(2) / self.temperature # (bsz, #queries) + idx = self._get_src_permutation_idx(indices) + positive_map = torch.zeros_like(logits, dtype=torch.bool) + positive_map[idx] = True + positive_logits = logits.masked_fill(~positive_map, 0) + + pos_term = positive_logits.sum(1) # (bsz, ) + num_pos = 
positive_map.sum(1) # (bsz, ) + neg_term = logits.logsumexp(1) # (bsz, ) + loss_nce = - pos_term / num_pos + neg_term # (bsz, ) + losses = {"loss_contrastive_align": loss_nce.mean()} + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx # two 1D tensors of the same length + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "contrastive_align": self.loss_contrastive_align, + "saliency": self.loss_saliency, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + # list(tuples), each tuple is (pred_span_indices, tgt_span_indices) + indices = self.matcher(outputs_without_aux, targets) + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
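# Editor's note: simplified sketch of the margin-ranking saliency loss in loss_saliency above —
# for each (positive clip, negative clip) pair the positive score must beat the negative one by
# at least saliency_margin, otherwise the shortfall is penalized. Toy scores and indices.
import torch

saliency_margin = 1.0
saliency_scores = torch.tensor([[0.9, 0.1, 0.4, 0.2],
                                [0.3, 0.8, 0.7, 0.0]])      # (batch, n_clips)
pos_indices = torch.tensor([[0], [1]])                      # one positive clip per video
neg_indices = torch.tensor([[1], [3]])                      # one sampled negative per video

batch = torch.arange(len(saliency_scores))
pos = saliency_scores[batch, pos_indices[:, 0]]             # (batch,)
neg = saliency_scores[batch, neg_indices[:, 0]]             # (batch,)
loss = torch.clamp(saliency_margin + neg - pos, min=0).mean()
print(loss)   # both gaps are 0.8, so each pair is penalized by 0.2; mean = 0.2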
+ if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if "saliency" == loss: # skip as it is only in the top layer + continue + kwargs = {} + l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + # the `num_classes` naming here is somewhat misleading. + # it indeed corresponds to `max_obj_id + 1`, where max_obj_id + # is the maximum id for a class in your dataset. For example, + # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. + # As another example, for a dataset that has a single class with id 1, + # you should pass `num_classes` to be 2 (max_obj_id + 1). 
+ # For more details on this, check the following discussion + # https://github.com/facebookresearch/moment_bert/issues/108#issuecomment-650269223 + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + num_queries=args.num_queries, + input_dropout=args.input_dropout, + aux_loss=args.aux_loss, + # contrastive_align_loss=args.contrastive_align_loss, + # contrastive_hdim=args.contrastive_hdim, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + # if args.contrastive_align_loss: + # weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"}) + weight_dict.update(aux_weight_dict) + + losses = ['spans', 'labels', 'saliency'] + # if args.contrastive_align_loss: + # losses += ["contrastive_align"] + criterion = SetCriterion( + matcher=matcher, weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin + ) + criterion.to(device) + return model, criterion diff --git a/model/position_encoding.py b/model/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9bad0b7867faede6179cd27e0a7c859137dcb8 --- /dev/null +++ b/model/position_encoding.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn +import numpy as np + +def PositionalEncoding(n_position, d_hid): + def get_position_angle_vec(position, d_hid): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i, d_hid) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + return torch.FloatTensor(sinusoid_table) # shape:(1, maxLen(n_position), d_hid) + +class TrainablePositionalEncoding(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. 
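# Editor's note: the criteria in this diff return a dict of unweighted loss terms; a typical
# DETR-style training step (not shown here) scales each term by weight_dict before backprop.
# This is a hedged sketch of that consumption pattern with made-up coefficients and values.
import torch

weight_dict = {"loss_b": 10.0, "loss_g": 1.0, "loss_f": 4.0,
               "loss_s_intra": 1.0, "loss_s_inter": 0.1}
loss_dict = {"loss_b": torch.tensor(0.25), "loss_g": torch.tensor(0.6),
             "loss_f": torch.tensor(0.4), "loss_s_intra": torch.tensor(1.2),
             "loss_s_inter": torch.tensor(2.0)}             # e.g. criterion(outputs, targets)

total_loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
# total_loss.backward() would follow in the real training loop
print(total_loss)   # 10*0.25 + 1*0.6 + 4*0.4 + 1*1.2 + 0.1*2.0 = 6.1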
+ """ + def __init__(self, max_position_embeddings, hidden_size, dropout=0.1): + super(TrainablePositionalEncoding, self).__init__() + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, input_feat): + """ + Args: + input_feat: (N, L, D) + """ + bsz, seq_length = input_feat.shape[:2] + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device) + position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L) + + position_embeddings = self.position_embeddings(position_ids) + + embeddings = self.LayerNorm(input_feat + position_embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. (To 1D sequences) + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask): + """ + Args: + x: torch.tensor, (batch_size, L, d) + mask: torch.tensor, (batch_size, L), with 1 as valid + + Returns: + + """ + assert mask is not None + x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L) + if self.normalize: + eps = 1e-6 + x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + # import pdb; pdb.set_trace() + # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2).int() / self.num_pos_feats) + + pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats) + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2) + # import ipdb; ipdb.set_trace() + return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L) + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
+ """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, x, mask): + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + # elif args.position_embedding in ('v3', 'learned'): + # position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + txt_pos_embed = TrainablePositionalEncoding( + max_position_embeddings=args.max_q_l, + hidden_size=args.hidden_dim, dropout=args.input_dropout) + return position_embedding, txt_pos_embed diff --git a/model/transformer.py b/model/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1a4958ec06d8c1fd1f091362d2c5a22714ed0714 --- /dev/null +++ b/model/transformer.py @@ -0,0 +1,471 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + # TransformerEncoderLayerThin + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + # TransformerDecoderLayerThin + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + """ + Args: + src: (batch_size, L, d) + mask: (batch_size, L) + query_embed: (#queries, d) + pos_embed: (batch_size, L, d) the same as src + + Returns: + + """ + # flatten NxCxHxW to HWxNxC + bs, l, d = src.shape + src = src.permute(1, 0, 2) # (L, batch_size, d) + pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (#queries, batch_size, d) + + tgt = 
torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # (L, batch_size, d) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) # (#layers, #queries, batch_size, d) + hs = hs.transpose(1, 2) # (#layers, batch_size, #qeries, d) + # memory = memory.permute(1, 2, 0) # (batch_size, d, L) + memory = memory.transpose(0, 1) # (batch_size, L, d) + return hs, memory + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + intermediate = [] + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + if self.return_intermediate: + intermediate.append(output) + + if self.norm is not None: + output = self.norm(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayerThin(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + # self.linear1 = nn.Linear(d_model, dim_feedforward) + # self.dropout = nn.Dropout(dropout) + # self.linear2 = nn.Linear(dim_feedforward, d_model) + self.linear = nn.Linear(d_model, d_model) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + # self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src2 = self.linear(src2) + 
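# "thin" encoder layer: a single linear projection replaces the two-layer FFN; the residual add, dropout and LayerNorm below complete the block +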
src = src + self.dropout(src2) + src = self.norm(src) + # src = src + self.dropout1(src2) + # src = self.norm1(src) + # src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + # src = src + self.dropout2(src2) + # src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + """not used""" + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # 
Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +class TransformerDecoderLayerThin(nn.Module): + """removed intermediate layer""" + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, d_model) + # self.linear1 = nn.Linear(d_model, dim_feedforward) + # self.dropout = nn.Dropout(dropout) + # self.linear2 = 
nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + # self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + # self.dropout3 = nn.Dropout(dropout) + + # self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt2 = self.linear1(tgt2) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + # tgt = tgt + self.dropout2(tgt2) + # tgt = self.norm2(tgt) + # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + # tgt = tgt + self.dropout3(tgt2) + # tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == 
"glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/model/transformer_encoder.py b/model/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..cb912bfe43022280e9e8ad3c1bcbaf926bf2ebec --- /dev/null +++ b/model/transformer_encoder.py @@ -0,0 +1,159 @@ +import copy +import pdb +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=4, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, # False as default + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + """ + Args: + src: (batch_size, L, d) + mask: (batch_size, L) + query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1 + pos_embed: (batch_size, L, d) the same as src + + Returns: + + """ + # flatten NxCxHxW to HWxNxC + src = src.permute(1, 0, 2) # (L, batch_size, d) + pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + memory = memory.transpose(0, 1) + + return memory + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + intermediate = [] + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + if self.return_intermediate: + intermediate.append(output) + + if self.norm is not None: + output = self.norm(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + 
src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/model/transformer_encoder_droppath.py b/model/transformer_encoder_droppath.py new file mode 100644 index 0000000000000000000000000000000000000000..536d46529a4722c1bf787d85fbace0afe1e3a33b --- /dev/null +++ b/model/transformer_encoder_droppath.py @@ -0,0 +1,194 @@ +import copy +import pdb +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=4, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, droppath=0.1, + activation="gelu", normalize_before=False, # False as default + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, droppath, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + """ + Args: + src: (batch_size, L, d) + mask: (batch_size, L) + query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1 + pos_embed: (batch_size, L, d) the same as src + + Returns: + + """ + # flatten NxCxHxW to HWxNxC + src = src.permute(1, 0, 2) # (L, batch_size, d) + pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + memory = memory.transpose(0, 1) + + return memory + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = 
return_intermediate + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + intermediate = [] + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + if self.return_intermediate: + intermediate.append(output) + + if self.norm is not None: + output = self.norm(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, droppath=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + # self.dropout1 = nn.Dropout(dropout) + # self.dropout2 = nn.Dropout(dropout) + self.droppath1 = DropPath(droppath) + self.droppath2 = DropPath(droppath) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + # src2 = self.self_attn_eff(q=q, k=k, v=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.droppath1(src2) + src = self.norm1(src) + src2 = self.linear2(self.activation(self.linear1(src))) + # src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.droppath2(src2) + src = self.norm2(src) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + droppath=args.droppath, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + +def drop_path(x, drop_prob=0.0, training=False): + """ + Stochastic Depth per sample. + """ + if drop_prob == 0.0 or not training: + return x + + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + mask.floor_() + x = x.div(keep_prob) * mask + + return x + + +class DropPath(nn.Module): + """ + Drop paths per sample (when applied in main path of residual blocks). 
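+ During training each sample's residual branch is zeroed with probability drop_prob and the surviving samples are rescaled by 1/keep_prob (see drop_path above); at eval time this is a no-op.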
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + + self.drop_prob = drop_prob + + def forward(self, x): + x = x.permute(1, 0, 2) + res = drop_path(x, self.drop_prob, self.training) + return res.permute(1, 0, 2) + # return drop_path(x, self.drop_prob, self.training) + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") \ No newline at end of file diff --git a/model/univtg.py b/model/univtg.py new file mode 100644 index 0000000000000000000000000000000000000000..607f8ad325ce6697ca3e49d911447489fa407f7f --- /dev/null +++ b/model/univtg.py @@ -0,0 +1,450 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder_droppath import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. """ + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. 
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + device_id = src_vid.device + + # type token. 
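+ # video tokens get type id 1 and text tokens type id 0 (analogous to BERT segment embeddings), so the encoder can tell the two modalities apart after concatenation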
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).to(device_id) + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + mask = targets['timestamp_mask'].bool() + mask_valid = targets['timestamp_window'].bool() + target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[mask_valid] = 1 + # target_classes = targets['timestamp_window'] # soft cls. 
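+ # note: the bare target_classes.float() on the next line has no effect; the cast actually used is the one inside the binary_cross_entropy call below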
+ target_classes.float() + # pdb.set_trace() + + weights = torch.zeros_like(target_classes).float() + weights[mask] = self.empty_weight[1] + weights[mask_valid] = self.empty_weight[0] + + # pdb.set_trace() + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + mask = targets['timestamp_mask'] + selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + neg_indices_in = (saliency_scores < selected_scores) + neg_indices_in[batch_indices, pos_indices] = True + mask_invalid = neg_indices_in * mask.bool() + + sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + sim_in = sim_in + (mask_invalid + 1e-45).log() + logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + loss_saliency_intra = - loss_in_i - loss_in_j + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = 
outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, hl_only=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + 
n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion diff --git a/model/univtg_ablation.py b/model/univtg_ablation.py new file mode 100644 index 0000000000000000000000000000000000000000..a3fb06a293c0a7130715d53cecc9b98406d70fdf --- /dev/null +++ b/model/univtg_ablation.py @@ -0,0 +1,474 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder_droppath import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. """ + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. 
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # type token. 
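+ # same video (id 1) / text (id 0) token-type embedding scheme as in model/univtg.py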
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + weight_abalation_b = targets['weight_ablation'][:,0].unsqueeze(-1) + if weight_abalation_b.sum() == 0: + return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()} + + mask_valid = (mask_valid * weight_abalation_b).bool() + mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool() + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + mask = targets['timestamp_mask'].bool() + mask_valid = targets['timestamp_window'].bool() + target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[mask_valid] = 1 + # target_classes = targets['timestamp_window'] # soft cls. 
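+ # as in model/univtg.py, the bare target_classes.float() below is a no-op; the cast used is the one passed to binary_cross_entropy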
+ target_classes.float() + # pdb.set_trace() + + weights = torch.zeros_like(target_classes).float() + weights[mask] = self.empty_weight[1] + weights[mask_valid] = self.empty_weight[0] + + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + + weight_abalation_f = targets['weight_ablation'][:,2].unsqueeze(-1) + if weight_abalation_f.sum() == 0: + return {"loss_f": torch.tensor(0).cuda()} + + mask = mask * weight_abalation_f + loss_ce = loss_ce * weight_abalation_f + return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + + weight_abalation_s = targets['weight_ablation'][:,3].bool() + if weight_abalation_s.sum() == 0: + return {"loss_s_inter": torch.tensor(0).cuda(), + "loss_s_intra": torch.tensor(0).cuda()} + + _idiag = idiag[weight_abalation_s] + _jdiag = jdiag[weight_abalation_s] + + loss_i = _idiag.sum() / len(_idiag) + loss_j = _jdiag.sum() / len(_jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + mask = targets['timestamp_mask'] + selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + neg_indices_in = (saliency_scores < selected_scores) + neg_indices_in[batch_indices, pos_indices] = True + mask_invalid = neg_indices_in * mask.bool() + + sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + sim_in = sim_in + (mask_invalid + 1e-45).log() + logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + _pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s] + _pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s] + + loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i) + loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j) + + loss_saliency_intra = - loss_in_i - loss_in_j + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + 
txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, hl_only=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" 
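+        # Descriptive note: optional LayerNorm on the input, then Dropout -> Linear via
+        # self.net, with an optional ReLU on the output; only the feature dim changes from
+        # in_hsz to out_hsz. Illustrative shape check (not part of the original code):
+        # LinearLayer(512, 256)(torch.randn(4, 75, 512)) -> tensor of shape (4, 75, 256).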
+ if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion \ No newline at end of file diff --git a/model/univtg_qfvs.py b/model/univtg_qfvs.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4b133857affa8f34e4c77e753539291cb96a75 --- /dev/null +++ b/model/univtg_qfvs.py @@ -0,0 +1,476 @@ +import pdb +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from model.transformer_encoder_droppath import build_transformer +from model.matcher import build_matcher +from model.position_encoding import build_position_encoding +from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +class WeightedPool(nn.Module): + def __init__(self, dim): + super(WeightedPool, self).__init__() + weight = torch.empty(dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def forward(self, x, mask): + alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1) + alpha = mask_logits(alpha, mask=mask.unsqueeze(2)) + alphas = nn.Softmax(dim=1)(alpha) + pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1) + pooled_x = pooled_x.squeeze(2) + return pooled_x + +class Model(nn.Module): + """ This is the UniVTG module that performs moment localization. 
""" + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + input_dropout, aux_loss=False, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. + # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + + # Conv projector + self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3) + self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground + + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + + # MLP Projector + self.weightedpool = WeightedPool(hidden_dim) + + def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None): + bs = src_vid.shape[0] + src_vid = self.input_vid_proj(src_vid) + src_txt = self.input_txt_proj(src_txt) + if src_cls is not None: + src_cls = self.input_txt_proj(src_cls) + + # type token. 
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + if src_cls is not None: + src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long())) + + src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt) + + pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d) + pos = torch.cat([pos_vid, pos_txt], dim=1) + + memory = self.transformer(src, ~mask, pos) + vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d) + + outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes) + outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2) + + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda() + idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1) + outputs_coord = outputs_coord * idx_mask + else: + raise NotImplementedError + + out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord, + 'src_vid_mask': src_vid_mask} + + vid_mem_proj = src_vid + + # word-level -> sentence-level + txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1) + sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log() + + out["vid_mem_proj"] = vid_mem_proj + out["txt_mem_proj"] = txt_mem_proj + if src_cls is not None: + cls_mem_proj = self.weightedpool(src_cls, src_cls_mask) + out["cls_mem_proj"] = cls_mem_proj + out["saliency_scores"] = sim + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + self.temperature = 0.07 + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + def loss_spans(self, outputs, targets, indices): + assert 'pred_spans' in outputs + + start_spans = targets['timestamp'] + pred_spans = outputs['pred_spans'] + src_spans = start_spans + pred_spans + gt_spans = targets['span_labels_nn'] + + mask = targets['timestamp_mask'].bool() + mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2) + mask_valid = targets['timestamp_window'].bool() + mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2) + + loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full + loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid])) + + losses = {} + losses['loss_b'] = loss_span.sum() / mask_valid.sum() + losses['loss_g'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_f": 0.} + + src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2) + target_classes = targets["saliency_scores"].squeeze() + + weights = torch.ones_like(target_classes).float() * self.empty_weight[1] + weights[target_classes.bool()] = self.empty_weight[0] + + loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none") + return {"loss_f": loss_ce.sum() / target_classes.sum()} + # return {"loss_f": loss_ce.sum() / len(target_classes)} + + # mask = targets['timestamp_mask'].bool() + # mask_valid = targets['timestamp_window'].bool() + # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + # target_classes[mask_valid] = 1 + # # target_classes = targets['timestamp_window'] # soft cls. 
+ # target_classes.float() + # # pdb.set_trace() + + # weights = torch.zeros_like(target_classes).float() + # weights[mask] = self.empty_weight[1] + # weights[mask_valid] = self.empty_weight[0] + + # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask + # # return {"loss_f": loss_ce.sum() / mask.sum()} + # return {"loss_f": loss_ce.sum() / mask_valid.sum()} + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * qfvs mil-nce mode + pos_indices = saliency_scores.squeeze() > 0 + + sim = outputs['saliency_scores'] + sim_soft = F.softmax(sim / self.temperature, dim=0) + sim_log = torch.log(sim_soft[pos_indices]) + loss_saliency_intra = -sim_log.sum() / len(sim_log) + return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra} + + # * inter-vid mode + # vid_mem_proj = outputs["vid_mem_proj"] + # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + # vid_feats = vid_mem_proj[batch_indices, pos_indices] + # txt_feats = outputs["txt_mem_proj"].squeeze(1) + # sim = sim_matrix(vid_feats, txt_feats) + + # i_logsm = F.log_softmax(sim / self.temperature, dim=1) + # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # # sum over positives + # idiag = torch.diag(i_logsm) + # jdiag = torch.diag(j_logsm) + # loss_i = idiag.sum() / len(idiag) + # loss_j = jdiag.sum() / len(jdiag) + + # loss_saliency_inter = - loss_i - loss_j + + # # * intra-vid mode + # mask = targets['timestamp_mask'] + # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1) + # neg_indices_in = (saliency_scores < selected_scores) + # neg_indices_in[batch_indices, pos_indices] = True + # mask_invalid = neg_indices_in * mask.bool() + + # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1) + # sim_in = sim_in + (mask_invalid + 1e-45).log() + # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1) + # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1) + + # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices] + # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices] + # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i) + # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j) + + # loss_saliency_intra = - loss_in_i - loss_in_j + + # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def loss_saliency_cls(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + saliency_scores = targets["saliency_scores"] + if saliency_scores.sum() == 0: + return {"loss_s_inter": 0., "loss_s_intra": 0.} + + # * inter-vid mode + vid_mem_proj = outputs["vid_mem_proj"] + pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs) + batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device) + + vid_feats = vid_mem_proj[batch_indices, pos_indices] + txt_feats = outputs["txt_mem_proj"].squeeze(1) + sim = sim_matrix(vid_feats, txt_feats) + + i_logsm = F.log_softmax(sim / self.temperature, dim=1) + j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1) + + # sum 
over positives + idiag = torch.diag(i_logsm) + jdiag = torch.diag(j_logsm) + loss_i = idiag.sum() / len(idiag) + loss_j = jdiag.sum() / len(jdiag) + + loss_saliency_inter = - loss_i - loss_j + + # * intra-vid mode + if 'cls_idx' not in targets.keys(): # eval + return {"loss_s_inter": loss_saliency_inter} + + cls_indices = targets['cls_idx'].bool() + cls_feats = outputs["cls_mem_proj"].squeeze(1) + sim_cls = sim_matrix(vid_feats, cls_feats) + + i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1) + idiag_cls = i_logsm_cls[cls_indices] + loss_cls_i = idiag_cls.sum() / len(idiag_cls) + + loss_saliency_intra = - loss_cls_i + + return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra} + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "saliency": self.loss_saliency, + "saliency_cls": self.loss_saliency_cls, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets, mask_GT=None): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + indices = None + # Compute all the requested losses + losses = {} + outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0]) + count = mask_GT.sum() + outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0]) + # targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0]) + targets['saliency_scores'] = targets['saliency_scores'][0,:count] + + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + return losses + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class Conv(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros') + for n, k in zip([input_dim] + h, h + [output_dim])) + def forward(self, x): + x = x.permute(0,2,1) + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.permute(0, 2, 1) + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = 
nn.LayerNorm(in_hsz) + layers = [ + nn.Dropout(dropout), + nn.Linear(in_hsz, out_hsz) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + + +def build_model(args): + device = torch.device(args.device) + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + model = Model( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + input_dropout=args.input_dropout, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + ) + + matcher = build_matcher(args) + weight_dict = {"loss_b": args.b_loss_coef, + "loss_g": args.g_loss_coef, + "loss_f": args.f_loss_coef, + "loss_s_intra": args.s_loss_intra_coef, + "loss_s_inter": args.s_loss_inter_coef} + + if args.dset_type in ['mr', 'vlp']: + if 'tal' not in args.train_path: + losses = ['spans', 'labels', 'saliency'] + else: + losses = ['spans', 'labels', 'saliency_cls'] + elif args.dset_type in ['hl', 'vs']: + losses = ['labels', 'saliency'] + + criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, + ) + criterion.to(device) + return model, criterion diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f550f0f568552976c3867b559c32dedfa79e852a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,355 @@ +absl-py==1.2.0 +accelerate==0.19.0 +aiodns==3.0.0 +aiofiles==23.1.0 +aiohttp==3.8.3 +aiohttp-socks==0.7.1 +aiosignal==1.3.1 +altair==5.0.1 +antiorm==1.2.1 +antlr4-python3-runtime==4.9.3 +anyio==3.7.0 +appdirs==1.4.4 +argilla==1.8.0 +argon2-cffi==21.3.0 +argon2-cffi-bindings==21.2.0 +asttokens==2.0.7 +async-timeout==4.0.2 +attrs==22.1.0 +Babel==2.12.1 +backcall==0.2.0 +backoff==2.2.1 +beautifulsoup4==4.11.1 +bert-score==0.3.13 +black==22.3.0 +bleach==5.0.1 +blis==0.7.9 +boto3==1.24.84 +botocore==1.27.84 +Brotli==1.0.9 +brotlipy==0.7.0 +cachetools==5.2.0 +catalogue==2.0.8 +cchardet==2.1.7 +certifi==2023.5.7 +cffi==1.15.1 +chardet==5.1.0 +charset-normalizer==2.1.1 +cinemagoer==2023.5.1 +click==8.1.3 +cloudpickle==2.2.0 +cmake==3.26.3 +coloredlogs==15.0.1 +colorlog==6.7.0 +commonmark==0.9.1 +confection==0.0.4 +contourpy==1.0.6 +cryptography==37.0.1 +cycler==0.11.0 +cymem==2.0.7 +dataclasses==0.6 +dataclasses-json==0.5.7 +dataflow==0.9.5 +db==0.1.1 +db-sqlite3==0.0.1 +debugpy==1.6.3 +decoder==0.5 +decorator==4.4.2 +decord==0.6.0 +defusedxml==0.7.1 +Deprecated==1.2.14 +detectron2==0.6 +docker==6.0.0 +docker-pycreds==0.4.0 +easydict==1.9 +ego4d==1.2.5 +einops==0.6.0 +elastic-transport==8.4.0 +elasticsearch==8.5.0 +entrypoints==0.4 +et-xmlfile==1.1.0 +exceptiongroup==1.1.1 +executing==0.10.0 +fairscale==0.4.12 +fake-useragent==0.1.14 +fastapi==0.98.0 +fastjsonschema==2.16.1 +ffmpeg==1.4 +ffmpeg-python==0.2.0 +ffmpy==0.3.0 +ffprobe==0.5 +filelock==3.7.1 +fonttools==4.38.0 +frozenlist==1.3.3 +fsspec==2023.5.0 +ftfy==6.1.1 +future==0.18.2 +fvcore==0.1.5.post20220512 +gdown==4.7.1 +gensim==4.2.0 +geographiclib==2.0 +geopy==2.3.0 +gitdb==4.0.10 +GitPython==3.1.31 +glide-text2im==0.0.0 +google-api-core==2.11.1 +google-api-python-client==2.95.0 +google-auth==2.22.0 
+google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-cloud==0.34.0 +google-cloud-vision==3.4.4 +google-measurement-protocol==1.1.0 +googleapis-common-protos==1.59.1 +googletransx==2.4.2 +gradio==3.23.0 +greenlet==2.0.2 +grpcio==1.56.2 +grpcio-status==1.56.2 +h11==0.14.0 +h5py==3.7.0 +httpcore==0.16.3 +httplib2==0.22.0 +httpx==0.23.3 +huggingface-hub==0.15.1 +humanfriendly==10.0 +hydra-core==1.2.0 +idna==3.3 +imageio==2.31.0 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +iopath==0.1.9 +ipdb==0.13.11 +ipykernel==6.15.3 +ipython==8.4.0 +ipython-genutils==0.2.0 +ipywidgets==8.0.2 +jedi==0.18.1 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.1.0 +jsonlines==3.1.0 +jsonschema==4.16.0 +jupyter==1.0.0 +jupyter_client==7.3.5 +jupyter-console==6.4.4 +jupyter-core==4.11.1 +jupyterlab-pygments==0.2.2 +jupyterlab-widgets==3.0.3 +kiwisolver==1.4.4 +langchain==0.0.191 +langcodes==3.3.0 +language-evaluation==0.1.0 +lazy_loader==0.2 +linkify-it-py==2.0.2 +lit==16.0.5.post0 +lxml==4.9.1 +Markdown==3.4.1 +markdown-it-py==2.2.0 +markdown2==2.4.9 +MarkupSafe==2.1.1 +marshmallow==3.19.0 +marshmallow-enum==1.5.1 +matplotlib==3.6.2 +matplotlib-inline==0.1.3 +mdit-py-plugins==0.3.3 +mdurl==0.1.2 +mistune==2.0.4 +mkl-fft==1.3.1 +mkl-random==1.2.2 +mkl-service==2.4.0 +monotonic==1.6 +more-itertools==9.1.0 +moviepy==1.0.3 +mpmath==1.3.0 +msg-parser==1.2.0 +msgpack==1.0.4 +msgpack-numpy==0.4.8 +multidict==6.0.4 +murmurhash==1.0.9 +mutagen==1.46.0 +mypy-extensions==0.4.3 +nbclient==0.6.8 +nbconvert==7.0.0 +nbformat==5.5.0 +nest-asyncio==1.5.5 +networkx==2.8.7 +nh3==0.2.13 +nltk==3.7 +nms-1d-cpu==0.0.0 +nncore==0.3.6 +notebook==6.4.12 +numexpr==2.8.4 +numpy==1.23.1 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +oauthlib==3.2.0 +olefile==0.46 +omegaconf==2.2.3 +openai==0.27.7 +openapi-schema-pydantic==1.2.4 +opencv-python==4.5.4.58 +openpyxl==3.1.2 +orjson==3.9.1 +ortools==9.4.1874 +packaging==21.3 +pandas==1.5.2 +pandocfilters==1.5.0 +parso==0.8.3 +pathspec==0.10.1 +pathtools==0.1.2 +pathy==0.10.1 +pdfminer.six==20221105 +peft==0.3.0 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==9.3.0 +pip==22.2.2 +pkgutil_resolve_name==1.3.10 +platformdirs==2.5.2 +portalocker==2.5.1 +preshed==3.0.8 +prices==1.1.1 +proglog==0.1.10 +prometheus-client==0.14.1 +prompt-toolkit==3.0.30 +proto-plus==1.22.3 +protobuf==3.20.1 +psutil==5.9.2 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycares==4.2.2 +pycipher==0.5.2 +pycocoevalcap==1.2 +pycocotools==2.0.5 +pycparser==2.21 +pycryptodomex==3.18.0 +pydantic==1.10.8 +pydot==1.4.2 +pydub==0.25.1 +pyfiglet==0.8.post1 +Pygments==2.12.0 +pynvml==11.5.0 +pyOpenSSL==22.0.0 +pypandoc==1.11 +pyparsing==3.0.9 +pyrsistent==0.18.1 +PySocks==1.7.1 +python-dateutil==2.8.2 +python-docx==0.8.11 +python-hostlist==1.21 +python-magic==0.4.27 +python-multipart==0.0.6 +python-pptx==0.6.21 +python-socks==2.0.3 +pytz==2022.7 +PyWavelets==1.4.1 +PyYAML==6.0 +pyzmq==23.2.1 +qtconsole==5.3.2 +QtPy==2.2.0 +regex==2022.7.25 +requests==2.28.1 +requests-oauthlib==1.3.1 +rfc3986==1.5.0 +rich==13.0.1 +rouge-score==0.1.2 +rsa==4.9 +ruamel.yaml==0.17.21 +ruamel.yaml.clib==0.2.7 +s3transfer==0.6.0 +sacremoses==0.0.53 +safetensors==0.3.1 +schedule==1.1.0 
+scikit-image==0.21.0 +scikit-learn==1.1.2 +scipy==1.9.3 +seaborn==0.12.0 +semantic-version==2.10.0 +Send2Trash==1.8.0 +sentencepiece==0.1.99 +sentry-sdk==1.26.0 +setproctitle==1.3.2 +setuptools==59.5.0 +shortuuid==1.0.11 +simplejson==3.17.6 +six==1.16.0 +smart-open==6.2.0 +smmap==5.0.0 +sniffio==1.3.0 +soupsieve==2.3.2.post1 +spacy==3.5.3 +spacy-legacy==3.0.12 +spacy-loggers==1.0.4 +SQLAlchemy==2.0.15 +srsly==2.4.6 +stack-data==0.4.0 +starlette==0.27.0 +svgwrite==1.4.3 +sympy==1.12 +tabulate==0.8.10 +tenacity==8.2.2 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +termcolor==1.1.0 +terminado==0.15.0 +terminaltables==3.1.10 +thinc==8.1.10 +threadpoolctl==3.1.0 +tifffile==2023.4.12 +timm==0.4.12 +tinycss2==1.1.1 +tokenizers==0.13.2 +tomli==2.0.1 +toolz==0.12.0 +torch==2.0.1 +torchaudio==0.9.0a0+33b2469 +torchdata==0.6.1 +torchtext==0.15.2 +torchvision==0.10.0a0 +tornado==6.2 +tqdm==4.64.1 +traitlets==5.3.0 +transformers==4.28.1 +triton==2.0.0 +twint==2.1.21 +typer==0.7.0 +typing_extensions==4.3.0 +typing-inspect==0.9.0 +uc-micro-py==1.0.2 +unstructured==0.7.1 +uritemplate==4.1.1 +urllib3==1.26.12 +uvicorn==0.22.0 +wandb==0.15.4 +warmup-scheduler==0.3 +wasabi==1.1.2 +wavedrom==2.0.3.post3 +wcwidth==0.2.5 +webencodings==0.5.1 +websocket-client==1.4.1 +websockets==11.0.3 +Werkzeug==2.2.1 +wheel==0.37.1 +widgetsnbextension==4.0.3 +wrapt==1.14.1 +xlrd==2.0.1 +XlsxWriter==3.1.2 +yacs==0.1.8 +yarl==1.9.2 +youtube-dl==2021.12.17 +yt-dlp==2023.3.4 +zipp==3.8.1 diff --git a/results/omni/opt.json b/results/omni/opt.json new file mode 100644 index 0000000000000000000000000000000000000000..7fbfd399c087577661ae58fdf6be0f8990b2a789 --- /dev/null +++ b/results/omni/opt.json @@ -0,0 +1,111 @@ +{ + "dset_type": "vlp", + "dset_name": "vlp", + "domain_name": null, + "model_id": "univtg", + "exp_id": "omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1", + "device": 0, + "gpu_id": 0, + "debug": false, + "seed": 2018, + "local_rank": 0, + "eval_split_name": "val", + "data_ratio": 1.0, + "results_root": "results", + "num_workers": 8, + "no_pin_memory": false, + "bsz": 64, + "n_epoch": 100, + "max_es_cnt": 200, + "lr": 0.0001, + "lr_drop": 200, + "lr_gamma": 0.1, + "lr_warmup": 10.0, + "wd": 0.0001, + "grad_clip": 0.1, + "span_loss_type": "l1", + "b_loss_coef": 10.0, + "g_loss_coef": 1.0, + "eos_coef": 0.1, + "f_loss_coef": 10.0, + "s_loss_intra_coef": 0.1, + "s_loss_inter_coef": 0.1, + "main_metric": "MR-full-R1@0.3-key", + "eval_mode": null, + "eval_bsz": 32, + "eval_epoch": 5, + "eval_init": true, + "save_interval": 5, + "resume": "/data/home/qinghonglin/univtg/results/vlp-vlp/aio_unified_mini-clip-clip-2023_05_27_00/model_e0003.ckpt", + "resume_dir": null, + "resume_all": false, + "start_epoch": null, + "no_sort_results": false, + "max_before_nms": 1000, + "max_after_nms": 10, + "conf_thd": 0.0, + "nms_thd": 0.7, + "use_cache": -1, + "max_q_l": 75, + "max_v_l": 75, + "clip_length": 2.0, + "clip_len_list": null, + "max_windows": 5, + "add_easy_negative": 1, + "easy_negative_only": 1, + "round_multiple": 1, + "train_path": [ + "data/qvhighlights/metadata/qvhighlights_train.jsonl", + "data/charades/metadata/charades_train.jsonl", + "data/ego4d/metadata/nlq_train.jsonl", + "data/tacos/metadata/train.jsonl", + "data/anet/metadata/train.jsonl", + "data/didemo/metadata/train.jsonl" + ], + "eval_path": "data/qvhighlights/metadata/qvhighlights_val.jsonl", + "train_path_list": null, + "eval_path_list": null, + "feat_root_list": null, + "no_norm_vfeat": false, + "no_norm_tfeat": false, + 
"v_feat_dirs": [ + "vid_clip" + ], + "t_feat_dir": "txt_clip", + "v_feat_dim": 512, + "t_feat_dim": 512, + "ctx_mode": "video_tef", + "v_feat_types": "clip", + "t_feat_type": "clip", + "position_embedding": "sine", + "n_input_proj": 2, + "temperature": 0.07, + "enc_layers": 4, + "sub_enc_layers": 2, + "dec_layers": 2, + "dim_feedforward": 1024, + "hidden_dim": 512, + "input_dropout": 0.5, + "dropout": 0.0, + "droppath": 0.1, + "txt_drop_ratio": 0, + "use_txt_pos": false, + "nheads": 8, + "num_queries": 10, + "pre_norm": false, + "set_cost_span": 10, + "set_cost_giou": 1, + "set_cost_class": 4, + "saliency_margin": 0.2, + "aux_loss": false, + "max_segment_num": 20, + "max_frame_num": 200, + "top_percent": 0.02, + "qfvs_vid_feature": "fps1", + "qfvs_txt_feature": "query", + "qfvs_dense_shot": -1, + "qfvs_score_ensemble": -1, + "qfvs_score_gather": -1, + "qfvs_loss_gather": -1, + "results_dir": "results/vlp-vlp/omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1-clip-clip-2023_05_31_06" +} \ No newline at end of file diff --git a/run_on_video/__init__.py b/run_on_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed2adb16c75bea2cdd119bc3c19b67bf68bcc37 --- /dev/null +++ b/run_on_video/__init__.py @@ -0,0 +1 @@ +from run_on_video.video_extractor import vid2clip, txt2clip diff --git a/run_on_video/clip/__init__.py b/run_on_video/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/run_on_video/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/run_on_video/clip/bpe_simple_vocab_16e6.txt.gz b/run_on_video/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/run_on_video/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/run_on_video/clip/clip.py b/run_on_video/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..9000dd80de5171b359dc336d0c955cc2332d12d4 --- /dev/null +++ b/run_on_video/clip/clip.py @@ -0,0 +1,195 @@ +import hashlib +import os +import urllib +import warnings +from typing import Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", +} + + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not 
os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=Image.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, max_valid_length: int = 32) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + max_valid_length: + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text)[:max_valid_length-2] + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/run_on_video/clip/model.py b/run_on_video/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..658fbbf5ed1379f9c0179fa456635a9ed6d4d4de --- /dev/null +++ b/run_on_video/clip/model.py @@ -0,0 +1,432 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisualTransformer(nn.Module): + def __init__(self, 
input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in 
resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + eos_x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return dict(last_hidden_state=x, pooler_output=eos_x) + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + 
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/run_on_video/clip/simple_tokenizer.py b/run_on_video/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/run_on_video/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/run_on_video/clip_feature_extractor.py b/run_on_video/clip_feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2852db020d8bcc74fd52cc2e0cd97d59c61b6b94 --- /dev/null +++ b/run_on_video/clip_feature_extractor.py @@ -0,0 +1,101 @@ +import pdb +import torch as th +import math +import numpy as np +import torch +from video_loader import VideoLoader +from torch.utils.data import DataLoader +import argparse +from preprocessing import Preprocessing +import torch.nn.functional as F +from tqdm import tqdm +import os +import sys +from feature_extractor import clip +import argparse + +################################# +model_version = "ViT-B/32" +output_feat_size = 512 +clip_len = 2 +overwrite = True +num_decoding_thread = 4 +half_precision = False + +@torch.no_grad() +def extractor(vid_path, text, output_file): + dataset = 
VideoLoader( + vid_path, + framerate=1/clip_len, + size=224, + centercrop=True, + overwrite=overwrite, + model_version=model_version + ) + n_dataset = len(dataset) + loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=num_decoding_thread, + sampler=sampler if n_dataset > 10 else None, + ) + preprocess = Preprocessing() + model, _ = clip.load(model_version, device="cuda", jit=False) + + encoded_texts = clip.tokenize(text).to('cuda') + text_feature = model.encode_text(encoded_texts)['last_hidden_state'] + valid_lengths = (encoded_texts != 0).sum(1).tolist()[0] + text_feature = text_feature[0, :valid_lengths].cpu().numpy() + np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature) + + totatl_num_frames = 0 + with th.no_grad(): + for k, data in enumerate(tqdm(loader)): + input_file = data['input'][0] + if os.path.isfile(output_file): + # print(f'Video {input_file} already processed.') + continue + elif not os.path.isfile(input_file): + print(f'{input_file}, does not exist.\n') + elif len(data['video'].shape) > 4: + video = data['video'].squeeze(0) + if len(video.shape) == 4: + video = preprocess(video) + n_chunk = len(video) + vid_features = th.cuda.FloatTensor( + n_chunk, output_feat_size).fill_(0) + n_iter = int(math.ceil(n_chunk)) + for i in range(n_iter): + min_ind = i + max_ind = (i + 1) + video_batch = video[min_ind:max_ind].cuda() + batch_features = model.encode_image(video_batch) + vid_features[min_ind:max_ind] = batch_features + vid_features = vid_features.cpu().numpy() + if half_precision: + vid_features = vid_features.astype('float16') + totatl_num_frames += vid_features.shape[0] + # safeguard output path before saving + dirname = os.path.dirname(output_file) + if not os.path.exists(dirname): + print(f"Output directory {dirname} does not exists, creating...") + os.makedirs(dirname) + np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features) + else: + print(f'{input_file}, failed at ffprobe.\n') + + print(f"Total number of frames: {totatl_num_frames}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='') + parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4') + parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.') + parser.add_argument('--save_dir', type=str, default='./tmp') + args = parser.parse_args() + + query = ' '.join(args.text) + + print(args.vid_path) + print(query) + extractor(args.vid_path, [query], args.save_dir) diff --git a/run_on_video/data_utils.py b/run_on_video/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44c073f4689b20bd8f7adb2aac77c0b7523ed8dd --- /dev/null +++ b/run_on_video/data_utils.py @@ -0,0 +1,170 @@ +import torch +import os +import numpy as np +import ffmpeg +import math +from run_on_video import clip + + +class ClipFeatureExtractor: + def __init__(self, framerate=1/2, size=224, centercrop=True, model_name_or_path="ViT-B/32", device="cuda"): + self.video_loader = VideoLoader(framerate=framerate, size=size, centercrop=centercrop) + print("Loading CLIP models") + self.clip_extractor, _ = clip.load(model_name_or_path, device=device, jit=False) + self.tokenizer = clip.tokenize + self.video_preprocessor = Preprocessing() + self.device = device + + @torch.no_grad() + def encode_video(self, video_path: str, bsz=60): + video_frames = self.video_loader.read_video_from_file(video_path) # (T, H, W, 3) + video_frames = 
self.video_preprocessor(video_frames) + n_frames = len(video_frames) + n_batch = int(math.ceil(n_frames / bsz)) + video_features = [] + for i in range(n_batch): + st_idx = i * bsz + ed_idx = (i+1) * bsz + _video_frames = video_frames[st_idx:ed_idx].to(self.device) + _video_features = self.clip_extractor.encode_image(_video_frames) + video_features.append(_video_features) + video_features = torch.cat(video_features, dim=0) + return video_features # (T=#frames, d) torch tensor + + @torch.no_grad() + def encode_text(self, text_list, bsz=60): + n_text = len(text_list) + n_batch = int(math.ceil(n_text / bsz)) + text_features = [] + for i in range(n_batch): + st_idx = i * bsz + ed_idx = (i+1) * bsz + encoded_texts = self.tokenizer(text_list[st_idx:ed_idx], context_length=77).to(self.device) + output = self.clip_extractor.encode_text(encoded_texts) + valid_lengths = (encoded_texts != 0).sum(1).tolist() + batch_last_hidden_states = output["last_hidden_state"] + for j, valid_len in enumerate(valid_lengths): + text_features.append(batch_last_hidden_states[j, :valid_len]) + return text_features # List([L_j, d]) torch tensor + + +def convert_to_float(frac_str): + try: + return float(frac_str) + except ValueError: + try: + num, denom = frac_str.split('/') + except ValueError: + return None + try: + leading, num = num.split(' ') + except ValueError: + return float(num) / float(denom) + if float(leading) < 0: + sign_mult = -1 + else: + sign_mult = 1 + return float(leading) + sign_mult * (float(num) / float(denom)) + + +class Normalize(object): + + def __init__(self, mean, std): + self.mean = torch.FloatTensor(mean).view(1, 3, 1, 1) + self.std = torch.FloatTensor(std).view(1, 3, 1, 1) + + def __call__(self, tensor): + tensor = (tensor - self.mean) / (self.std + 1e-8) + return tensor + + +class Preprocessing(object): + + def __init__(self): + self.norm = Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + + def __call__(self, tensor): + tensor = tensor / 255.0 + tensor = self.norm(tensor) + return tensor + + +class VideoLoader: + """Pytorch video loader. 
+ Copied and modified from: + https://github.com/linjieli222/HERO_Video_Feature_Extractor/blob/main/clip/video_loader.py + """ + def __init__( + self, + framerate=1/2, + size=224, + centercrop=True, + ): + self.centercrop = centercrop + self.size = size + self.framerate = framerate + + def _get_video_info(self, video_path): + probe = ffmpeg.probe(video_path) + video_stream = next((stream for stream in probe['streams'] + if stream['codec_type'] == 'video'), None) + width = int(video_stream['width']) + height = int(video_stream['height']) + fps = math.floor(convert_to_float(video_stream['avg_frame_rate'])) + try: + frames_length = int(video_stream['nb_frames']) + duration = float(video_stream['duration']) + except Exception: + frames_length, duration = -1, -1 + info = {"duration": duration, "frames_length": frames_length, + "fps": fps, "height": height, "width": width} + return info + + def _get_output_dim(self, h, w): + if isinstance(self.size, tuple) and len(self.size) == 2: + return self.size + elif h >= w: + return int(h * self.size / w), self.size + else: + return self.size, int(w * self.size / h) + + def read_video_from_file(self, video_path): + try: + info = self._get_video_info(video_path) + h, w = info["height"], info["width"] + except Exception: + print('ffprobe failed at: {}'.format(video_path)) + return {'video': torch.zeros(1), 'input': video_path, + 'info': {}} + height, width = self._get_output_dim(h, w) + try: + duration = info["duration"] + fps = self.framerate + if duration > 0 and duration < 1/fps+0.1: + fps = 2/max(int(duration), 1) + print(duration, fps) + except Exception: + fps = self.framerate + cmd = ( + ffmpeg + .input(video_path) + .filter('fps', fps=fps) + .filter('scale', width, height) + ) + if self.centercrop: + x = int((width - self.size) / 2.0) + y = int((height - self.size) / 2.0) + cmd = cmd.crop(x, y, self.size, self.size) + out, _ = ( + cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) + if self.centercrop and isinstance(self.size, int): + height, width = self.size, self.size + video = np.frombuffer(out, np.uint8).reshape( + [-1, height, width, 3]) + video = torch.from_numpy(video.astype('float32')) + video = video.permute(0, 3, 1, 2) + return video diff --git a/run_on_video/preprocessing.py b/run_on_video/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..93b3e8112b299c93667fb433ac683fa7f46e0fda --- /dev/null +++ b/run_on_video/preprocessing.py @@ -0,0 +1,25 @@ +import torch as th + + +class Normalize(object): + + def __init__(self, mean, std): + self.mean = th.FloatTensor(mean).view(1, 3, 1, 1) + self.std = th.FloatTensor(std).view(1, 3, 1, 1) + + def __call__(self, tensor): + tensor = (tensor - self.mean) / (self.std + 1e-8) + return tensor + + +class Preprocessing(object): + + def __init__(self): + self.norm = Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + + def __call__(self, tensor): + tensor = tensor / 255.0 + tensor = self.norm(tensor) + return tensor diff --git a/run_on_video/text_extractor.py b/run_on_video/text_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..e26cf4d7968ab718675c49e85907be38fcd5e2f3 --- /dev/null +++ b/run_on_video/text_extractor.py @@ -0,0 +1,36 @@ +import pdb +import sys +import json +import torch +import numpy as np +from run_on_video.data_utils import ClipFeatureExtractor +import torch.nn.functional as F +import tqdm +import os + +query_list = [] 
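# Note on the metadata consumed below: each line of
# data/{dataset}/metadata/{dataset}_{split}.jsonl is parsed as one JSON object,
# and only its 'qid' and 'query' fields are used. A line is therefore assumed
# to look roughly like this (illustrative values only, not taken from the data):
#   {"qid": 0, "query": "a boy is drinking."}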
+qid_list = [] +dataset = 'charades' +split = 'test' + +save_dir = f'' + +with open(f"data/{dataset}/metadata/{dataset}_{split}.jsonl", 'r') as f: + while True: + line = f.readline() + if not line: + break + js = json.loads(line) + query_list.append(js['query']) + qid_list.append(str(js['qid'])) + +# clip +feature_extractor = ClipFeatureExtractor( + framerate=1 / 2, size=224, centercrop=True, + model_name_or_path="ViT-B/32", device='cuda' +) +# pdb.set_trace() +query_feats = feature_extractor.encode_text(query_list) + +for i in tqdm.tqdm(range(len(query_feats))): + np.savez(save_dir + '/' + qid_list[i], last_hidden_state=query_feats[i].cpu().numpy()) diff --git a/run_on_video/video_extractor.py b/run_on_video/video_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..8804063be0b6de26c8ea63c5d035dfd4a27e4884 --- /dev/null +++ b/run_on_video/video_extractor.py @@ -0,0 +1,94 @@ +import pdb +import torch as th +import math +import numpy as np +import torch +from run_on_video.video_loader import VideoLoader +from torch.utils.data import DataLoader +import argparse +from run_on_video.preprocessing import Preprocessing +import torch.nn.functional as F +from tqdm import tqdm +import os +import sys +from run_on_video import clip +import argparse + +################################# +@torch.no_grad() +def vid2clip(model, vid_path, output_file, + model_version="ViT-B/32", output_feat_size=512, + clip_len=2, overwrite=True, num_decoding_thread=4, half_precision=False): + dataset = VideoLoader( + vid_path, + framerate=1/clip_len, + size=224, + centercrop=True, + overwrite=overwrite, + model_version=model_version + ) + n_dataset = len(dataset) + loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=num_decoding_thread, + sampler=None, + ) + preprocess = Preprocessing() + device_id = next(model.parameters()).device + + totatl_num_frames = 0 + with th.no_grad(): + for k, data in enumerate(tqdm(loader)): + input_file = data['input'][0] + if os.path.isfile(output_file): + # print(f'Video {input_file} already processed.') + continue + elif not os.path.isfile(input_file): + print(f'{input_file}, does not exist.\n') + elif len(data['video'].shape) > 4: + video = data['video'].squeeze(0) + if len(video.shape) == 4: + video = preprocess(video) + n_chunk = len(video) + vid_features = th.cuda.FloatTensor( + n_chunk, output_feat_size).fill_(0) + n_iter = int(math.ceil(n_chunk)) + for i in range(n_iter): + min_ind = i + max_ind = (i + 1) + video_batch = video[min_ind:max_ind].to(device_id) + batch_features = model.encode_image(video_batch) + vid_features[min_ind:max_ind] = batch_features + vid_features = vid_features.cpu().numpy() + if half_precision: + vid_features = vid_features.astype('float16') + totatl_num_frames += vid_features.shape[0] + # safeguard output path before saving + dirname = os.path.dirname(output_file) + if not os.path.exists(dirname): + print(f"Output directory {dirname} does not exists, creating...") + os.makedirs(dirname) + np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features) + else: + print(f'{input_file}, failed at ffprobe.\n') + print(f"Total number of frames: {totatl_num_frames}") + return vid_features + +def txt2clip(model, text, output_file): + device_id = next(model.parameters()).device + encoded_texts = clip.tokenize(text).to(device_id) + text_feature = model.encode_text(encoded_texts)['last_hidden_state'] + valid_lengths = (encoded_texts != 0).sum(1).tolist()[0] + text_feature = text_feature[0, 
:valid_lengths].detach().cpu().numpy() + + np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature) + return text_feature + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='') + parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4') + parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.') + parser.add_argument('--save_dir', type=str, default='./tmp') + args = parser.parse_args() diff --git a/run_on_video/video_loader.py b/run_on_video/video_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..7dddbf5df1b22187226fc53ec260663c708d8452 --- /dev/null +++ b/run_on_video/video_loader.py @@ -0,0 +1,125 @@ +import torch as th +from torch.utils.data import Dataset +import pandas as pd +import os +import numpy as np +import ffmpeg +import math + + +def convert_to_float(frac_str): + try: + return float(frac_str) + except ValueError: + try: + num, denom = frac_str.split('/') + except ValueError: + return None + try: + leading, num = num.split(' ') + except ValueError: + return float(num) / float(denom) + if float(leading) < 0: + sign_mult = -1 + else: + sign_mult = 1 + return float(leading) + sign_mult * (float(num) / float(denom)) + + +class VideoLoader(Dataset): + """Pytorch video loader.""" + + def __init__( + self, + vid_path, + framerate=1, + size=112, + centercrop=False, + overwrite=False, + model_version="ViT-B/32", + ): + """ + Args: + """ + self.vid_path = vid_path + + self.centercrop = centercrop + self.size = size + self.framerate = framerate + self.overwrite = overwrite + self.model_version = model_version + + def __len__(self): + return 1 + + def _get_video_info(self, video_path): + probe = ffmpeg.probe(video_path) + video_stream = next((stream for stream in probe['streams'] + if stream['codec_type'] == 'video'), None) + width = int(video_stream['width']) + height = int(video_stream['height']) + fps = math.floor(convert_to_float(video_stream['avg_frame_rate'])) + try: + frames_length = int(video_stream['nb_frames']) + duration = float(video_stream['duration']) + except Exception: + frames_length, duration = -1, -1 + info = {"duration": duration, "frames_length": frames_length, + "fps": fps, "height": height, "width": width} + return info + + def _get_output_dim(self, h, w): + if isinstance(self.size, tuple) and len(self.size) == 2: + return self.size + elif h >= w: + return int(h * self.size / w), self.size + else: + return self.size, int(w * self.size / h) + + def __getitem__(self, id): + video_path = self.vid_path + + load_flag = os.path.isfile(video_path) + if load_flag: + try: + info = self._get_video_info(video_path) + h, w = info["height"], info["width"] + except Exception: + print('ffprobe failed at: {}'.format(video_path)) + return {'video': th.zeros(1), 'input': video_path,'info': {}} + try: + height, width = self._get_output_dim(h, w) + try: + duration = info["duration"] + fps = self.framerate + if duration > 0 and duration < 1/fps+0.1: + fps = 2/max(int(duration), 1) + print(duration, fps) + except Exception: + fps = self.framerate + cmd = ( + ffmpeg + .input(video_path) + .filter('fps', fps=fps) + .filter('scale', width, height) + # .filter('scale', self.size, self.size) + ) + if self.centercrop: + x = int((width - self.size) / 2.0) + y = int((height - self.size) / 2.0) + cmd = cmd.crop(x, y, self.size, self.size) + out, _ = ( + cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, 
quiet=True) + ) + if self.centercrop and isinstance(self.size, int): + height, width = self.size, self.size + video = np.frombuffer(out, np.uint8).reshape( + [-1, height, width, 3]) + video = th.from_numpy(video.astype('float32')) + video = video.permute(0, 3, 1, 2) + except: + return {'video': th.zeros(1), 'input': video_path,'info': {}} + else: + video = th.zeros(1) + return {'video': video, 'input': video_path} diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/basic_utils.py b/utils/basic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ab03b808017927d67665a37d606de7b6538253a6 --- /dev/null +++ b/utils/basic_utils.py @@ -0,0 +1,234 @@ +import os +import json +import torch +import random +import zipfile +import numpy as np +import pickle +from collections import OrderedDict, Counter +import pandas as pd +import shutil + +def set_seed(seed, use_cuda=True): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda: + torch.cuda.manual_seed_all(seed) + +def load_pickle(filename): + with open(filename, "rb") as f: + return pickle.load(f) + + +def save_pickle(data, filename): + with open(filename, "wb") as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +def save_json(data, filename, save_pretty=False, sort_keys=False): + with open(filename, "w") as f: + if save_pretty: + f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) + else: + json.dump(data, f) + + +def load_jsonl(filename): + with open(filename, "r") as f: + return [json.loads(l.strip("\n")) for l in f.readlines()] + + +def save_jsonl(data, filename): + """data is a list""" + with open(filename, "w") as f: + f.write("\n".join([json.dumps(e) for e in data])) + + +def save_lines(list_of_str, filepath): + with open(filepath, "w") as f: + f.write("\n".join(list_of_str)) + + +def read_lines(filepath): + with open(filepath, "r") as f: + return [e.strip("\n") for e in f.readlines()] + + +def mkdirp(p): + if not os.path.exists(p): + os.makedirs(p) + +def remkdirp(p): + if os.path.exists(p): + shutil.rmtree(p) + os.makedirs(p) + +def flat_list_of_lists(l): + """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" + return [item for sublist in l for item in sublist] + + +def convert_to_seconds(hms_time): + """ convert '00:01:12' to 72 seconds. + :hms_time (str): time in comma separated string, e.g. '00:01:12' + :return (int): time in seconds, e.g. 72 + """ + times = [float(t) for t in hms_time.split(":")] + return times[0] * 3600 + times[1] * 60 + times[2] + + +def get_video_name_from_url(url): + return url.split("/")[-1][:-4] + + +def merge_dicts(list_dicts): + merged_dict = list_dicts[0].copy() + for i in range(1, len(list_dicts)): + merged_dict.update(list_dicts[i]) + return merged_dict + + +def l2_normalize_np_array(np_array, eps=1e-5): + """np_array: np.ndarray, (*, D), where the last dim will be normalized""" + return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps) + + +def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None, + exclude_dirs_substring=None): + """make a zip file of root_dir, save it to save_path. + exclude_paths will be excluded if it is a subdir of root_dir. + An enclosing_dir is added is specified. 
+ """ + abs_src = os.path.abspath(src_dir) + with zipfile.ZipFile(save_path, "w") as zf: + for dirname, subdirs, files in os.walk(src_dir): + if exclude_dirs is not None: + for e_p in exclude_dirs: + if e_p in subdirs: + subdirs.remove(e_p) + if exclude_dirs_substring is not None: + to_rm = [] + for d in subdirs: + if exclude_dirs_substring in d: + to_rm.append(d) + for e in to_rm: + subdirs.remove(e) + arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:]) + zf.write(dirname, arcname) + for filename in files: + if exclude_extensions is not None: + if os.path.splitext(filename)[1] in exclude_extensions: + continue # do not zip it + absname = os.path.join(dirname, filename) + arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:]) + zf.write(absname, arcname) + + +class AverageMeter(object): + """Computes and stores the average and current/max/min value""" + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.max = -1e10 + self.min = 1e10 + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.max = -1e10 + self.min = 1e10 + + def update(self, val, n=1): + self.max = max(val, self.max) + self.min = min(val, self.min) + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True): + """Dissect an array (N, D) into a list a sub-array, + np_array.shape[0] == sum(lengths), Output is a list of nd arrays, singlton dimention is kept""" + if assert_equal: + assert len(np_array) == sum(lengths) + length_indices = [0, ] + for i in range(len(lengths)): + length_indices.append(length_indices[i] + lengths[i]) + if dim == 0: + array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))] + elif dim == 1: + array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))] + elif dim == 2: + array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))] + else: + raise NotImplementedError + return array_list + + +def get_ratio_from_counter(counter_obj, threshold=200): + keys = counter_obj.keys() + values = counter_obj.values() + filtered_values = [counter_obj[k] for k in keys if k > threshold] + return float(sum(filtered_values)) / sum(values) + + +def get_counter_dist(counter_object, sort_type="none"): + _sum = sum(counter_object.values()) + dist = {k: float(f"{100 * v / _sum:.2f}") for k, v in counter_object.items()} + if sort_type == "value": + dist = OrderedDict(sorted(dist.items(), reverse=True)) + return dist + + +def get_show_name(vid_name): + """ + get tvshow name from vid_name + :param vid_name: video clip name + :return: tvshow name + """ + show_list = ["friends", "met", "castle", "house", "grey"] + vid_name_prefix = vid_name.split("_")[0] + show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt" + return show_name + + +def get_abspaths_by_ext(dir_path, ext=(".jpg",)): + """Get absolute paths to files in dir_path with extensions specified by ext. + Note this function does work recursively. 
+ """ + if isinstance(ext, list): + ext = tuple(ext) + if isinstance(ext, str): + ext = tuple([ext, ]) + filepaths = [os.path.join(root, name) + for root, dirs, files in os.walk(dir_path) + for name in files + if name.endswith(tuple(ext))] + return filepaths + + +def get_basename_no_ext(path): + """ '/data/movienet/240p_keyframe_feats/tt7672188.npz' --> 'tt7672188' """ + return os.path.splitext(os.path.split(path)[1])[0] + + +def dict_to_markdown(d, max_str_len=120): + # convert list into its str representation + d = {k: v.__repr__() if isinstance(v, list) else v for k, v in d.items()} + # truncate string that is longer than max_str_len + if max_str_len is not None: + d = {k: v[-max_str_len:] if isinstance(v, str) else v for k, v in d.items()} + return pd.DataFrame(d, index=[0]).transpose().to_markdown() + diff --git a/utils/cpd_auto.py b/utils/cpd_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..b561b08c4764350843955b510402e3c0a28d62a7 --- /dev/null +++ b/utils/cpd_auto.py @@ -0,0 +1,89 @@ +import numpy as np +from .cpd_nonlin import cpd_nonlin + +def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs): + """Main interface + + Detect change points automatically selecting their number + K - kernel between each pair of frames in video + ncp - maximum ncp + vmax - special parameter + Optional arguments: + lmin - minimum segment length + lmax - maximum segment length + desc_rate - rate of descriptor sampling (vmax always corresponds to 1x) + + Note: + - cps are always calculated in subsampled coordinates irrespective to + desc_rate + - lmin and m should be in agreement + --- + Returns: (cps, costs) + cps - best selected change-points + costs - costs for 0,1,2,...,m change-points + + Memory requirement: ~ (3*N*N + N*ncp)*4 bytes ~= 16 * N^2 bytes + That is 1,6 Gb for the N=10000. 
+ """ + m = ncp + (_, scores) = cpd_nonlin(K, m, backtrack=False, **kwargs) + # print("scores ",scores) + + N = K.shape[0] + N2 = N*desc_rate # length of the video before subsampling + + penalties = np.zeros(m+1) + # Prevent division by zero (in case of 0 changes) + ncp = np.arange(1, m+1) + penalties[1:] = (vmax*ncp/(2.0*N2))*(np.log(float(N2)/ncp)+1) + + costs = scores/float(N) + penalties + m_best = np.argmin(costs) + # print("cost ",costs) + # print("m_best ",m_best) + (cps, scores2) = cpd_nonlin(K, m_best, **kwargs) + + return (cps, costs) + + +# ------------------------------------------------------------------------------ +# Extra functions (currently not used) + +def estimate_vmax(K_stable): + """K_stable - kernel between all frames of a stable segment""" + n = K_stable.shape[0] + vmax = np.trace(centering(K_stable)/n) + return vmax + + +def centering(K): + """Apply kernel centering""" + mean_rows = np.mean(K, 1)[:, np.newaxis] + return K - mean_rows - mean_rows.T + np.mean(mean_rows) + + +def eval_score(K, cps): + """ Evaluate unnormalized empirical score + (sum of kernelized scatters) for the given change-points """ + N = K.shape[0] + cps = [0] + list(cps) + [N] + V1 = 0 + V2 = 0 + for i in range(len(cps)-1): + K_sub = K[cps[i]:cps[i+1], :][:, cps[i]:cps[i+1]] + V1 += np.sum(np.diag(K_sub)) + V2 += np.sum(K_sub) / float(cps[i+1] - cps[i]) + return (V1 - V2) + + +def eval_cost(K, cps, score, vmax): + """ Evaluate cost function for automatic number of change points selection + K - kernel between all frames + cps - selected change-points + score - unnormalized empirical score (sum of kernelized scatters) + vmax - vmax parameter""" + + N = K.shape[0] + penalty = (vmax*len(cps)/(2.0*N))*(np.log(float(N)/len(cps))+1) + return score/float(N) + penalty + diff --git a/utils/cpd_nonlin.py b/utils/cpd_nonlin.py new file mode 100644 index 0000000000000000000000000000000000000000..ee577b95e36b2208bde52fa80e73c88048f5536c --- /dev/null +++ b/utils/cpd_nonlin.py @@ -0,0 +1,115 @@ +import numpy as np +# from scipy import weave + +def calc_scatters(K): + n = K.shape[0] + K1 = np.cumsum([0] + list(np.diag(K))) + K2 = np.zeros((n+1, n+1)) + K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) # TODO: use the fact that K - symmetric + + scatters = np.zeros((n, n)) + +# code = r""" +# for (int i = 0; i < n; i++) { +# for (int j = i; j < n; j++) { +# scatters(i,j) = K1(j+1)-K1(i) - (K2(j+1,j+1)+K2(i,i)-K2(j+1,i)-K2(i,j+1))/(j-i+1); +# } +# } +# """ +# weave.inline(code, ['K1','K2','scatters','n'], global_dict = \ +# {'K1':K1, 'K2':K2, 'scatters':scatters, 'n':n}, type_converters=weave.converters.blitz) + + for i in range(n): + for j in range(i, n): + scatters[i,j] = K1[j+1] - K1[i] - (K2[j+1,j+1]+K2[i,i]-K2[j+1,i]-K2[i,j+1])/(j-i+1) + return scatters + +def cpd_nonlin(K, ncp, lmin=1, lmax=100000, backtrack=True, verbose=True, + out_scatters=None): + """ Change point detection with dynamic programming + K - square kernel matrix + ncp - number of change points to detect (ncp >= 0) + lmin - minimal length of a segment + lmax - maximal length of a segment + backtrack - when False - only evaluate objective scores (to save memory) + + Returns: (cps, obj) + cps - detected array of change points: mean is thought to be constant on [ cps[i], cps[i+1] ) + obj_vals - values of the objective function for 0..m changepoints + + """ + m = int(ncp) # prevent numpy.int64 + + (n, n1) = K.shape + assert(n == n1), "Kernel matrix awaited." 
+ + assert(n >= (m + 1)*lmin) + assert(n <= (m + 1)*lmax) + assert(lmax >= lmin >= 1) + + if verbose: + #print "n =", n + print("Precomputing scatters...") + J = calc_scatters(K) + + if out_scatters != None: + out_scatters[0] = J + + if verbose: + print("Inferring best change points...") + I = 1e101*np.ones((m+1, n+1)) + I[0, lmin:lmax] = J[0, lmin-1:lmax-1] + + if backtrack: + p = np.zeros((m+1, n+1), dtype=int) + else: + p = np.zeros((1,1), dtype=int) + + # dynamic programming: I[k, l] is the minimal objective for segmenting the first l frames + # with exactly k change points; p[k, l] keeps the argmin for backtracking + for k in range(1, m+1): + for l in range((k+1)*lmin, n+1): + tmin = max(k*lmin, l-lmax) + tmax = l-lmin+1 + c = J[tmin:tmax, l-1].reshape(-1) + I[k-1, tmin:tmax].reshape(-1) + I[k, l] = np.min(c) + if backtrack: + p[k, l] = np.argmin(c) + tmin + + # collect the change points by backtracking through p + cps = np.zeros(m, dtype=int) + if backtrack: + cur = n + for k in range(m, 0, -1): + cps[k-1] = p[k, cur] + cur = cps[k-1] + + scores = I[:, n].copy() + scores[scores > 1e99] = np.inf + return cps, scores + + diff --git a/utils/kts_utils.py b/utils/kts_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3265f9a0d3317578cb0fa5eb1a8033f333bec937 --- /dev/null +++ b/utils/kts_utils.py @@ -0,0 +1,204 @@ +import numpy as np +from .cpd_nonlin import cpd_nonlin + +def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs): + """Main interface + + Detect change points automatically selecting their number + K - kernel between each pair of frames in video + ncp - maximum ncp + vmax - special parameter + Optional arguments: + lmin - minimum segment length + lmax - maximum segment length + desc_rate - rate of descriptor sampling (vmax always corresponds to 1x) + + Note: + - cps are always calculated in subsampled coordinates irrespective to + desc_rate + - lmin and m should be in agreement + --- + Returns: (cps, costs) + cps - best selected change-points + costs - costs for 0,1,2,...,m change-points + + Memory requirement: ~ (3*N*N + N*ncp)*4 bytes ~= 16 * N^2 bytes + That is 1,6 Gb for the N=10000. + """ + m = ncp + (_, scores) = cpd_nonlin(K, m, backtrack=False, **kwargs) + # print("scores ",scores) + + N = K.shape[0] + N2 = N * desc_rate # length of the video before subsampling + + penalties = np.zeros(m + 1) + # Prevent division by zero (in case of 0 changes) + ncp = np.arange(1, m + 1) + penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1) + + costs = scores / float(N) + penalties + m_best = np.argmin(costs) + # print("cost ",costs) + # print("m_best ",m_best) + (cps, scores2) = cpd_nonlin(K, m_best, **kwargs) + + return (cps, costs) + + +# ------------------------------------------------------------------------------ +# Extra functions (currently not used) + +def estimate_vmax(K_stable): + """K_stable - kernel between all frames of a stable segment""" + n = K_stable.shape[0] + vmax = np.trace(centering(K_stable) / n) + return vmax + + +def centering(K): + """Apply kernel centering""" + mean_rows = np.mean(K, 1)[:, np.newaxis] + return K - mean_rows - mean_rows.T + np.mean(mean_rows) + + +def eval_score(K, cps): + """ Evaluate unnormalized empirical score + (sum of kernelized scatters) for the given change-points """ + N = K.shape[0] + cps = [0] + list(cps) + [N] + V1 = 0 + V2 = 0 + for i in range(len(cps) - 1): + K_sub = K[cps[i]:cps[i + 1], :][:, cps[i]:cps[i + 1]] + V1 += np.sum(np.diag(K_sub)) + V2 += np.sum(K_sub) / float(cps[i + 1] - cps[i]) + return (V1 - V2) + + +def eval_cost(K, cps, score, vmax): + """ Evaluate cost function for automatic number of change points selection + K - kernel between all frames + cps - selected change-points + score - unnormalized empirical score (sum of kernelized scatters) + vmax - vmax parameter""" + + N = K.shape[0] + penalty = (vmax * len(cps) / (2.0 * N)) * (np.log(float(N) / len(cps)) + 1) + return score / float(N) + penalty + + +def calc_scatters(K): + n = K.shape[0] + K1 = np.cumsum([0] + list(np.diag(K))) + K2 = np.zeros((n + 1, n + 1)).astype(np.double()) + K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) # TODO: use the fact that K - symmetric + # KK = np.cumsum(K, 0).astype(np.double()) + # K2[1:, 1:] = np.cumsum(KK, 1) # TODO: use the fact that K - symmetric + + scatters = np.zeros((n, n)) + + # code = r""" + # for (int i = 0; i < n; i++) { + # for (int j = i; j < n; j++) { + # scatters(i,j) = K1(j+1)-K1(i) - (K2(j+1,j+1)+K2(i,i)-K2(j+1,i)-K2(i,j+1))/(j-i+1); + # } + # } + # """ + # weave.inline(code, ['K1','K2','scatters','n'], global_dict = \ + # {'K1':K1, 'K2':K2, 'scatters':scatters, 'n':n}, type_converters=weave.converters.blitz) + + for i in range(n): + for j in range(i, n): + scatters[i, j] = K1[j + 1] - K1[i] - (K2[j + 1, j + 1] + K2[i, i] - K2[j + 1, i] - K2[i, j + 1]) / ( + j - i + 1) + return scatters + +def cpd_nonlin(K, ncp, lmin=1, lmax=100000, backtrack=True, verbose=True, + out_scatters=None): + """ Change point detection with dynamic programming + K - square kernel matrix + ncp - number of change points to detect (ncp >= 0) + lmin - minimal length of a segment + lmax - maximal length of a segment + backtrack - when False - only evaluate objective scores (to save memory) + + Returns: (cps, obj) + cps - detected array of change points: mean is thought to be constant on [ cps[i], cps[i+1] ) + obj_vals - values of the objective function for 0..m changepoints + + """ + m = int(ncp) # prevent numpy.int64 + + (n, n1) = K.shape + assert (n == n1), "Kernel matrix awaited." + + assert (n >= (m + 1) * lmin) + assert (n <= (m + 1) * lmax) + assert (lmax >= lmin >= 1) + + if verbose: + # print "n =", n + print("Precomputing scatters...") + J = calc_scatters(K) + + if out_scatters != None: + out_scatters[0] = J + + if verbose: + print("Inferring best change points...") + I = 1e101 * np.ones((m + 1, n + 1)) + I[0, lmin:lmax] = J[0, lmin - 1:lmax - 1] + + if backtrack: + p = np.zeros((m + 1, n + 1), dtype=int) + else: + p = np.zeros((1, 1), dtype=int) + + # dynamic programming over the number of change points k and the segment end l + for k in range(1, m + 1): + for l in range((k + 1) * lmin, n + 1): + tmin = max(k * lmin, l - lmax) + tmax = l - lmin + 1 + c = J[tmin:tmax, l - 1].reshape(-1) + I[k - 1, tmin:tmax].reshape(-1) + I[k, l] = np.min(c) + if backtrack: + p[k, l] = np.argmin(c) + tmin + + # collect the change points by backtracking through p + cps = np.zeros(m, dtype=int) + if backtrack: + cur = n + for k in range(m, 0, -1): + cps[k - 1] = p[k, cur] + cur = cps[k - 1] + + scores = I[:, n].copy() + scores[scores > 1e99] = np.inf + return cps, scores + + diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06eed4751ad15e78692e64926dfd2741664949ce --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,15 @@ +def count_parameters(model, verbose=True): + """Count number of parameters in PyTorch model, + References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7. 
+ + from utils.utils import count_parameters + count_parameters(model) + import sys + sys.exit(1) + """ + n_all = sum(p.numel() for p in model.parameters()) + n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + if verbose: + print("Parameter Count: all {:,d}; trainable {:,d}".format(n_all, n_trainable)) + return n_all, n_trainable + diff --git a/utils/span_utils.py b/utils/span_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3b8a23d6d73820b3bc03b1dd1ce498214eca10 --- /dev/null +++ b/utils/span_utils.py @@ -0,0 +1,124 @@ +import pdb + +import torch + + +def span_xx_to_cxw(xx_spans): + """ + Args: + xx_spans: tensor, (#windows, 2) or (..., 2), each row is a window of format (st, ed) + + Returns: + cxw_spans: tensor, (#windows, 2), each row is a window of format (center=(st+ed)/2, width=(ed-st)) + >>> spans = torch.Tensor([[0, 1], [0.2, 0.4]]) + >>> span_xx_to_cxw(spans) + tensor([[0.5000, 1.0000], + [0.3000, 0.2000]]) + >>> spans = torch.Tensor([[[0, 1], [0.2, 0.4]]]) + >>> span_xx_to_cxw(spans) + tensor([[[0.5000, 1.0000], + [0.3000, 0.2000]]]) + """ + center = xx_spans.sum(-1) * 0.5 + width = xx_spans[..., 1] - xx_spans[..., 0] + return torch.stack([center, width], dim=-1) + + +def span_cxw_to_xx(cxw_spans): + """ + Args: + cxw_spans: tensor, (#windows, 2) or (..., 2), the last dim is a row denoting a window of format (center, width) + + >>> spans = torch.Tensor([[0.5000, 1.0000], [0.3000, 0.2000]]) + >>> span_cxw_to_xx(spans) + tensor([[0.0000, 1.0000], + [0.2000, 0.4000]]) + >>> spans = torch.Tensor([[[0.5000, 1.0000], [0.3000, 0.2000]]]) + >>> span_cxw_to_xx(spans) + tensor([[[0.0000, 1.0000], + [0.2000, 0.4000]]]) + """ + x1 = cxw_spans[..., 0] - 0.5 * cxw_spans[..., 1] + x2 = cxw_spans[..., 0] + 0.5 * cxw_spans[..., 1] + return torch.stack([x1, x2], dim=-1) + + +def temporal_iou(spans1, spans2): + """ + Args: + spans1: (N, 2) torch.Tensor, each row defines a span [st, ed] + spans2: (M, 2) torch.Tensor, ... + + Returns: + iou: (N, M) torch.Tensor + union: (N, M) torch.Tensor + >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]]) + >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]]) + >>> temporal_iou(test_spans1, test_spans2) + (tensor([[0.6667, 0.2000], + [0.0000, 0.5000]]), + tensor([[0.3000, 1.0000], + [0.8000, 1.0000]])) + """ + areas1 = spans1[:, 1] - spans1[:, 0] # (N, ) + areas2 = spans2[:, 1] - spans2[:, 0] # (M, ) + + left = torch.max(spans1[:, None, 0], spans2[:, 0]) # (N, M) + right = torch.min(spans1[:, None, 1], spans2[:, 1]) # (N, M + + inter = (right - left).clamp(min=0) # (N, M) + union = areas1[:, None] + areas2 - inter # (N, M) + + iou = inter / union + return iou, union + + +def temporal_intersection_over_pred(gt_spans, pred_spans): + """ intersection over the second input spans + Args: + gt_spans: (N, 2), + pred_spans: (M, 2) + + Returns: + + """ + left = torch.max(gt_spans[:, None, 0], pred_spans[:, 0]) + right = torch.min(gt_spans[:, None, 1], pred_spans[:, 1]) + + inter = (right - left).clamp(min=0) # (N, M) + inter_over_pred = inter / (pred_spans[:, 1] - pred_spans[:, 0]) + return inter_over_pred + + +def generalized_temporal_iou(spans1, spans2): + """ + Generalized IoU from https://giou.stanford.edu/ + Also reference to DETR implementation of generalized_box_iou + https://github.com/facebookresearch/detr/blob/master/util/box_ops.py#L40 + + Args: + spans1: (N, 2) torch.Tensor, each row defines a span in xx format [st, ed] + spans2: (M, 2) torch.Tensor, ... 
+ + Returns: + giou: (N, M) torch.Tensor + + >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]]) + >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]]) + >>> generalized_temporal_iou(test_spans1, test_spans2) + tensor([[ 0.6667, 0.2000], + [-0.2000, 0.5000]]) + """ + spans1 = spans1.float() + spans2 = spans2.float() + assert (spans1[:, 1] >= spans1[:, 0]).all() + assert (spans2[:, 1] >= spans2[:, 0]).all() + iou, union = temporal_iou(spans1, spans2) + + left = torch.min(spans1[:, None, 0], spans2[:, 0]) # (N, M) + right = torch.max(spans1[:, None, 1], spans2[:, 1]) # (N, M) + enclosing_area = (right - left).clamp(min=0) # (N, M) + + return iou - (enclosing_area - union) / enclosing_area + + diff --git a/utils/temporal_nms.py b/utils/temporal_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..2844f5d4c1ac71760cd82c7aaf82c6b2daa9a207 --- /dev/null +++ b/utils/temporal_nms.py @@ -0,0 +1,74 @@ +""" +Non-Maximum Suppression for video proposals. +""" + + +def compute_temporal_iou(pred, gt): + """ deprecated due to performance concerns + compute intersection-over-union along temporal axis + Args: + pred: [st (float), ed (float)] + gt: [st (float), ed (float)] + Returns: + iou (float): + + Ref: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py + """ + intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0])) + union = max(pred[1], gt[1]) - min(pred[0], gt[0]) # not the correct union though + if union == 0: + return 0 + else: + return 1.0 * intersection / union + + +def temporal_nms(predictions, nms_thd, max_after_nms=100): + """ + Args: + predictions: list(sublist), each sublist is [st (float), ed(float), score (float)], + note larger scores are better and are preserved. For metrics that are better when smaller, + please convert to its negative, e.g., convert distance to negative distance. + nms_thd: float in [0, 1] + max_after_nms: + Returns: + predictions_after_nms: list(sublist), each sublist is [st (float), ed(float), score (float)] + References: + https://github.com/wzmsltw/BSN-boundary-sensitive-network/blob/7b101fc5978802aa3c95ba5779eb54151c6173c6/Post_processing.py#L42 + """ + if len(predictions) == 1: # only has one prediction, no need for nms + return predictions + + predictions = sorted(predictions, key=lambda x: x[2], reverse=True) # descending order + + tstart = [e[0] for e in predictions] + tend = [e[1] for e in predictions] + tscore = [e[2] for e in predictions] + rstart = [] + rend = [] + rscore = [] + while len(tstart) > 1 and len(rscore) < max_after_nms: # max 100 after nms + idx = 1 + while idx < len(tstart): # compare with every prediction in the list. + if compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]) > nms_thd: + # rm highly overlapped lower score entries. + tstart.pop(idx) + tend.pop(idx) + tscore.pop(idx) + # print("--------------------------------") + # print(compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]])) + # print([tstart[0], tend[0]], [tstart[idx], tend[idx]]) + # print(tstart.pop(idx), tend.pop(idx), tscore.pop(idx)) + else: + # move to next + idx += 1 + rstart.append(tstart.pop(0)) + rend.append(tend.pop(0)) + rscore.append(tscore.pop(0)) + + if len(rscore) < max_after_nms and len(tstart) >= 1: # add the last, possibly empty. 
+ rstart.append(tstart.pop(0)) + rend.append(tend.pop(0)) + rscore.append(tscore.pop(0)) + + predictions_after_nms = [[st, ed, s] for s, st, ed in zip(rscore, rstart, rend)] + return predictions_after_nms diff --git a/utils/tensor_utils.py b/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2c25a83b66092b1ce8731b4d9fae1523438b29 --- /dev/null +++ b/utils/tensor_utils.py @@ -0,0 +1,93 @@ +import numpy as np +import torch + + +def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None): + """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray) + into a (n+1)-d array, only allow the first dim has variable lengths. + Args: + sequences: list(n-d tensor or list) + dtype: np.dtype or torch.dtype + device: + fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length. + return will be of shape [len(sequences), fixed_length, ...] + Returns: + padded_seqs: ((n+1)-d tensor) padded with zeros + mask: (2d tensor) of the same shape as the first two dims of padded_seqs, + 1 indicate valid, 0 otherwise + Examples: + >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] + >>> pad_sequences_1d(test_data_list, dtype=torch.long) + >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)] + >>> pad_sequences_1d(test_data_3d, dtype=torch.float) + >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] + >>> pad_sequences_1d(test_data_list, dtype=np.float32) + >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)] + >>> pad_sequences_1d(test_data_3d, dtype=np.float32) + """ + if isinstance(sequences[0], list): + if "torch" in str(dtype): + sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences] + else: + sequences = [np.asarray(s, dtype=dtype) for s in sequences] + + extra_dims = sequences[0].shape[1:] # the extra dims should be the same for all elements + lengths = [len(seq) for seq in sequences] + if fixed_length is not None: + max_length = fixed_length + else: + max_length = max(lengths) + if isinstance(sequences[0], torch.Tensor): + assert "torch" in str(dtype), "dtype and input type does not match" + padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device) + mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device) + else: # np + assert "numpy" in str(dtype), "dtype and input type does not match" + padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype) + mask = np.zeros((len(sequences), max_length), dtype=np.float32) + + for idx, seq in enumerate(sequences): + end = lengths[idx] + padded_seqs[idx, :end] = seq + mask[idx, :end] = 1 + return padded_seqs, mask # , lengths + + +def pad_sequences_2d(sequences, dtype=torch.long): + """ Pad a double-nested list or a sequence of n-d torch tensor into a (n+1)-d tensor, + only allow the first two dims has variable lengths + Args: + sequences: list(n-d tensor or list) + dtype: torch.long for word indices / torch.float (float32) for other cases + Returns: + Examples: + >>> test_data_list = [[[1, 3, 5], [3, 7, 4, 1]], [[98, 34, 11, 89, 90], [22], [34, 56]],] + >>> pad_sequences_2d(test_data_list, dtype=torch.long) # torch.Size([2, 3, 5]) + >>> test_data_3d = [torch.randn(2,2,4), torch.randn(4,3,4), torch.randn(1,5,4)] + >>> pad_sequences_2d(test_data_3d, dtype=torch.float) # torch.Size([2, 3, 5]) + >>> test_data_3d2 = [[torch.randn(2,4), ], [torch.randn(3,4), 
torch.randn(5,4)]] + >>> pad_sequences_2d(test_data_3d2, dtype=torch.float) # torch.Size([2, 3, 5]) + # TODO add support for numpy array + """ + bsz = len(sequences) + para_lengths = [len(seq) for seq in sequences] + max_para_len = max(para_lengths) + sen_lengths = [[len(word_seq) for word_seq in seq] for seq in sequences] + max_sen_len = max([max(e) for e in sen_lengths]) + + if isinstance(sequences[0], torch.Tensor): + extra_dims = sequences[0].shape[2:] + elif isinstance(sequences[0][0], torch.Tensor): + extra_dims = sequences[0][0].shape[1:] + else: + sequences = [[torch.Tensor(word_seq, dtype=dtype) for word_seq in seq] for seq in sequences] + extra_dims = () + + padded_seqs = torch.zeros((bsz, max_para_len, max_sen_len) + extra_dims, dtype=dtype) + mask = torch.zeros(bsz, max_para_len, max_sen_len).float() + + for b_i in range(bsz): + for sen_i, sen_l in enumerate(sen_lengths[b_i]): + padded_seqs[b_i, sen_i, :sen_l] = sequences[b_i][sen_i] + mask[b_i, sen_i, :sen_l] = 1 + return padded_seqs, mask # , sen_lengths diff --git a/utils/windows_utils.py b/utils/windows_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3527cdfd7107db5d7eb57afe47f3e8b3bbbc15d --- /dev/null +++ b/utils/windows_utils.py @@ -0,0 +1,59 @@ +""" +Find windows from a video with clip_ids. + +A window is defined by a [start_clip_idx, end_clip_idx] pair: +For example, assuming clip_len = 2 seconds +[0, 0] meaning a single clip window [0, 2] (seconds) +[10, 19] meaning a 9 clip window [20, 40] (seconds) + +""" + + +def convert_clip_ids_to_windows(clip_ids): + """ Inverse function of convert_windows_to_clip_ids + Args: + clip_ids: list(int), each is a index of a clip, starting from 0 + + Returns: + list(list(int)), each sublist contains two integers which are clip indices. + [10, 19] meaning a 9 clip window [20, 40] (seconds), if each clip is 2 seconds. + + >>> test_clip_ids = [56, 57, 58, 59, 60, 61, 62] + [64, ] + [67, 68, 69, 70, 71] + >>> convert_clip_ids_to_windows(test_clip_ids) + [[56, 62], [64, 64], [67, 71]] + """ + windows = [] + _window = [clip_ids[0], None] + last_clip_id = clip_ids[0] + for clip_id in clip_ids: + if clip_id - last_clip_id > 1: # find gap + _window[1] = last_clip_id + windows.append(_window) + _window = [clip_id, None] + last_clip_id = clip_id + _window[1] = last_clip_id + windows.append(_window) + return windows + + +def convert_windows_to_clip_ids(windows): + """ Inverse function of convert_clip_ids_to_windows + Args: + windows: list(list(int)), each sublist contains two integers which are clip indices. + [10, 11] meaning a 9 clip window [20, 40] (seconds), if each clip is 2 seconds. + + Returns: + clip_ids: list(int) + + >>> test_windows =[[56, 62], [64, 64], [67, 71]] + >>> convert_windows_to_clip_ids(test_windows) + [56, 57, 58, 59, 60, 61, 62] + [64, ] + [67, 68, 69, 70, 71] + """ + clip_ids = [] + for w in windows: + clip_ids += list(range(w[0], w[1]+1)) + return clip_ids + + +def convert_clip_window_to_seconds(window, clip_len=2): + return [window[0] * clip_len, (window[1] + 1) * clip_len]
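The three helpers in utils/windows_utils.py are easiest to read end to end. A minimal sketch, reusing the values from their own docstrings and assuming the repository root is on PYTHONPATH (utils/ ships an __init__.py, so the package import below should resolve):

from utils.windows_utils import (
    convert_clip_ids_to_windows,
    convert_windows_to_clip_ids,
    convert_clip_window_to_seconds,
)

clip_ids = [56, 57, 58, 59, 60, 61, 62] + [64] + [67, 68, 69, 70, 71]
windows = convert_clip_ids_to_windows(clip_ids)        # [[56, 62], [64, 64], [67, 71]]
assert convert_windows_to_clip_ids(windows) == clip_ids  # the two conversions round-trip
print(convert_clip_window_to_seconds([10, 19], clip_len=2))  # [20, 40]: clips 10..19 cover seconds 20..40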