from tqdm import tqdm, trange
import torch
import torch.nn.functional as F
import numpy as np

from utils.run_utils import topk_3d, generate_min_max_length_mask, extract_topk_elements
from modules.ndcg_iou import calculate_ndcg_iou


def grab_corpus_feature(model, corpus_loader, device):
    """Encode every corpus entry with ``model.encode_context`` and collect the results.

    Args:
        model: network exposing
            ``encode_context(video_feat, video_mask, sub_feat, sub_mask)``, which
            returns a ``(video_feat, sub_feat)`` pair.
        corpus_loader: iterable of dict batches with keys ``"video_feat"``,
            ``"video_mask"``, ``"sub_feat"`` and ``"sub_mask"``; every value must
            support ``.to(device)``.
        device: device each batch is moved to before encoding.

    Returns:
        dict with keys ``"all_video_feat"``, ``"all_video_mask"``,
        ``"all_sub_feat"`` and ``"all_sub_mask"``. Each value is the per-batch
        CPU tensors concatenated along dim 0. The masks are the input masks,
        not outputs of the model.
    """
    model.eval()
    all_video_feat, all_video_mask = [], []
    all_sub_feat, all_sub_mask = [], []
    with torch.no_grad():
        for batch_input in tqdm(corpus_loader, desc="Compute Corpus Feature: ", total=len(corpus_loader)):
            batch_input = {k: v.to(device) for k, v in batch_input.items()}
            _video_feat, _sub_feat = model.encode_context(
                batch_input["video_feat"], batch_input["video_mask"],
                batch_input["sub_feat"], batch_input["sub_mask"])
            # Accumulate on the CPU; the concatenated tensors are moved back to
            # the device only when they are needed (see eval_epoch).
            all_video_feat.append(_video_feat.detach().cpu())
            all_video_mask.append(batch_input["video_mask"].detach().cpu())
            all_sub_feat.append(_sub_feat.detach().cpu())
            all_sub_mask.append(batch_input["sub_mask"].detach().cpu())
    all_video_feat = torch.cat(all_video_feat, dim=0)
    all_video_mask = torch.cat(all_video_mask, dim=0)
    all_sub_feat = torch.cat(all_sub_feat, dim=0)
    all_sub_mask = torch.cat(all_sub_mask, dim=0)
    return {
        "all_video_feat": all_video_feat,
        "all_video_mask": all_video_mask,
        "all_sub_feat": all_sub_feat,
        "all_sub_mask": all_sub_mask,
    }


def eval_epoch(model, corpus_feature, eval_loader, eval_gt, opt, corpus_video_list, topn_video=100):
    """Score every eval query against the whole corpus and return the NDCG@IoU result.

    Args:
        model: network exposing ``get_pred_from_raw_query(query_feat=..., query_mask=...,
            video_feat=..., video_mask=..., sub_feat=..., sub_mask=..., cross=True)``,
            which returns ``(query_scores, start_probs, end_probs)``.
        corpus_feature: dict produced by ``grab_corpus_feature``.
        eval_loader: iterable of dict batches with at least ``"query_feat"``,
            ``"query_mask"`` and ``"query_id"``.
        eval_gt: ground truth, forwarded unchanged to ``calculate_ndcg_iou``.
        opt: options object. This function reads ``device`` and ``q2c_alpha``;
            ``calculate_average_ndcg`` reads further fields.
        corpus_video_list: video names forwarded to ``extract_topk_elements``.
            Presumably aligned with the corpus feature rows; confirm.
        topn_video: number of top-ranked videos kept per query (default 100,
            the previously hard-coded value).

    Returns:
        Whatever ``calculate_ndcg_iou`` returns for the collected predictions.
    """
    device = opt.device
    model.eval()

    all_video_feat = corpus_feature["all_video_feat"].to(device)
    all_video_mask = corpus_feature["all_video_mask"].to(device)
    all_sub_feat = corpus_feature["all_sub_feat"].to(device)
    all_sub_mask = corpus_feature["all_sub_mask"].to(device)

    all_query_id = []
    all_query_score, all_end_prob, all_start_prob, all_top_video_name = [], [], [], []
    # Inference only. Without no_grad, every batch records an autograd graph
    # through the model parameters, which needlessly inflates memory.
    with torch.no_grad():
        for batch_input in tqdm(eval_loader, desc="Compute Query Scores: ", total=len(eval_loader)):
            batch_input = {k: v.to(device) for k, v in batch_input.items()}
            query_scores, start_probs, end_probs = model.get_pred_from_raw_query(
                query_feat=batch_input["query_feat"], query_mask=batch_input["query_mask"],
                video_feat=all_video_feat, video_mask=all_video_mask,
                sub_feat=all_sub_feat, sub_mask=all_sub_mask, cross=True)
            # Turn the raw query-to-corpus scores into positive weights (temperature q2c_alpha)
            # and the start/end logits into distributions over the last dimension.
            query_scores = torch.exp(opt.q2c_alpha * query_scores)
            start_probs = F.softmax(start_probs, dim=-1)
            end_probs = F.softmax(end_probs, dim=-1)
            query_scores, start_probs, end_probs, video_name_top = extract_topk_elements(
                query_scores, start_probs, end_probs, corpus_video_list, topn_video)

            all_query_id.append(batch_input["query_id"].detach().cpu())
            all_query_score.append(query_scores.detach().cpu())
            all_start_prob.append(start_probs.detach().cpu())
            all_end_prob.append(end_probs.detach().cpu())
            # extend (not append): presumably one list of video names per query in the batch.
            all_top_video_name.extend(video_name_top)

    all_query_id = torch.cat(all_query_id, dim=0).tolist()
    all_query_score = torch.cat(all_query_score, dim=0)
    all_start_prob = torch.cat(all_start_prob, dim=0)
    all_end_prob = torch.cat(all_end_prob, dim=0)
    average_ndcg = calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob,
                                          all_top_video_name, eval_gt, opt)
    return average_ndcg


def calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob,
                           all_top_video_name, eval_gt, opt):
    """Build per-query moment predictions and evaluate them with ``calculate_ndcg_iou``.

    The einsum below requires ``all_start_prob`` of shape (Q, V, M),
    ``all_query_score`` of shape (Q, V) and ``all_end_prob`` of shape (Q, V, N).
    Q is the number of queries and V the number of candidate videos kept per
    query. M/N are presumably clip positions for start/end; confirm.

    Args:
        all_query_id: list of Q query ids, used as keys of the prediction dict.
        all_start_prob: start distributions, shape (Q, V, M).
        all_query_score: video-level scores, shape (Q, V).
        all_end_prob: end distributions, shape (Q, V, N).
        all_top_video_name: per-query sequences of video names, indexed by
            the video index returned from ``topk_3d``.
        eval_gt: ground truth forwarded to ``calculate_ndcg_iou``.
        opt: options object. This function reads ``ndcg_topk`` (an iterable
            whose max sets the number of moments kept), ``min_pred_l``,
            ``max_pred_l``, ``clip_length`` and ``iou_threshold``.

    Returns:
        Whatever ``calculate_ndcg_iou`` returns.
    """
    topn_moment = max(opt.ndcg_topk)
    # Joint score of (video v, start m, end n) = P(start=m) * score(v) * P(end=n).
    all_2D_map = torch.einsum("qvm,qv,qvn->qvmn", all_start_prob, all_query_score, all_end_prob)
    # Presumably zeroes (start, end) pairs whose span length is outside
    # [min_pred_l, max_pred_l]; confirm in utils.run_utils.
    map_mask = generate_min_max_length_mask(all_2D_map.shape, min_l=opt.min_pred_l, max_l=opt.max_pred_l)
    all_2D_map = all_2D_map * map_mask

    all_pred = {}
    for idx in trange(len(all_2D_map), desc="Collect Predictions: "):
        query_id = all_query_id[idx]
        top_score, top_idx = topk_3d(all_2D_map[idx], topn_moment)
        top_video_name = all_top_video_name[idx]
        # Each row of top_idx is (video index, start index, end index); the
        # start/end indices are converted to time by multiplying by clip_length.
        all_pred[query_id] = [
            {
                "video_name": top_video_name[i[0]],
                "timestamp": [i[1].item() * opt.clip_length, i[2].item() * opt.clip_length],
                "model_scores": score,
            }
            for i, score in zip(top_idx, top_score)
        ]
    average_ndcg = calculate_ndcg_iou(eval_gt, all_pred, opt.iou_threshold, opt.ndcg_topk)
    return average_ndcg