import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from tqdm import tqdm
import argparse
from collections import OrderedDict
import json
from collections import defaultdict
from model.deberta_moe import DebertaV2ForMaskedLM
from transformers import DebertaV2Tokenizer
import clip
import ffmpeg
from VideoLoader import VideoLoader


def get_mask(lengths, max_length):
    """Compute a batch of padding masks from batched lengths.

    Args:
        lengths: number of valid positions per sample (an int, or a 1-D
            tensor with one length per sample).
        max_length: width of the mask.

    Returns:
        An int64 tensor of shape ``(batch, max_length)`` holding 1 at
        positions ``< length`` and 0 elsewhere (batch is 1 for a scalar
        ``lengths``).
    """
    mask = 1 * (torch.arange(max_length).unsqueeze(1) < lengths).transpose(0, 1)
    return mask


class Infer:
    """Multiple-choice video question answering with a video-conditioned MLM.

    Every candidate answer is turned into a yes/no masked prompt
    ("Question: ... Is it '<candidate>'? [MASK]."); the candidate whose
    "Yes" probability at the mask position is highest is returned.
    """

    def __init__(self, device):
        """Load CLIP, the DeBERTa tokenizer/model and the fine-tuned checkpoint.

        Args:
            device: torch device (or device string) used for inference.
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects (the
        # checkpoint carries an ``args`` object), so only load trusted files.
        pretrained_ckpt = torch.load("ckpts/model.pth", map_location="cpu")
        args = pretrained_ckpt["args"]
        # Two answer embeddings are registered in set_answer(): "Yes" and "No".
        args.n_ans = 2
        args.max_tokens = 256
        self.args = args

        self.clip_model = clip.load("ViT-L/14", device=device)[0]
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(
            "ckpts/deberta-v2-xlarge", local_files_only=True
        )
        self.model = DebertaV2ForMaskedLM.from_pretrained(
            features_dim=args.features_dim if args.use_video else 0,
            max_feats=args.max_feats,
            freeze_lm=args.freeze_lm,
            freeze_mlm=args.freeze_mlm,
            ft_ln=args.ft_ln,
            ds_factor_attn=args.ds_factor_attn,
            ds_factor_ff=args.ds_factor_ff,
            dropout=args.dropout,
            n_ans=args.n_ans,
            freeze_last=args.freeze_last,
            pretrained_model_name_or_path="ckpts/deberta-v2-xlarge",
            local_files_only=False,
            add_video_feat=args.add_video_feat,
            freeze_ad=args.freeze_ad,
        )

        # Strip the "module." prefix from parameter names (presumably added by
        # (Distributed)DataParallel at training time) so they match the bare model.
        new_state_dict = OrderedDict(
            (k.replace("module.", ""), v)
            for k, v in pretrained_ckpt["model"].items()
        )
        # Load the cleaned weights. Passing the raw checkpoint dict (keys
        # "args", "model", ...) with strict=False silently loaded nothing.
        # strict=False is kept, so keys absent on either side are still ignored.
        self.model.load_state_dict(new_state_dict, strict=False)
        self.model.eval()
        self.model.to(device)
        self.device = device
        self.video_loader = VideoLoader()
        self.set_answer()

    def _get_clip_feature(self, video):
        """Encode video frames with the CLIP image encoder.

        The features are returned as produced by ``encode_image`` (they are
        not L2-normalized). Presumably ``video`` is a stack of preprocessed
        frames, one row per frame — TODO confirm against VideoLoader.
        """
        return self.clip_model.encode_image(video.to(self.device))

    def set_answer(self):
        """Register "Yes" (index 0) and "No" (index 1) as the answer vocabulary.

        Each word is tokenized to exactly one id (truncated/padded to length 1).
        generate() reads the probability at index 0, i.e. "Yes".
        """
        tok_yes = torch.tensor(
            self.tokenizer(
                "Yes",
                add_special_tokens=False,
                max_length=1,
                truncation=True,
                padding="max_length",
            )["input_ids"],
            dtype=torch.long,
        )
        tok_no = torch.tensor(
            self.tokenizer(
                "No",
                add_special_tokens=False,
                max_length=1,
                truncation=True,
                padding="max_length",
            )["input_ids"],
            dtype=torch.long,
        )
        a2tok = torch.stack([tok_yes, tok_no])
        self.model.set_answer_embeddings(
            a2tok.to(self.model.device), freeze_last=self.args.freeze_last
        )

    @torch.no_grad()  # inference only: do not build an autograd graph
    def generate(self, text, candidates, video_path):
        """Pick the candidate answer the model most strongly affirms.

        Args:
            text: the question; it is stripped, capitalized and suffixed
                with "?" if it does not already end with one.
            candidates: non-empty sequence of answer strings.
            video_path: video passed to ``VideoLoader``.

        Returns:
            The element of ``candidates`` with the highest "Yes" probability.
            With a single candidate, that candidate is returned.

        Raises:
            ValueError: if ``candidates`` is empty, or if the mask token is
                not present exactly once in an encoded prompt (e.g. it was
                truncated away at ``args.max_tokens``).
        """
        if len(candidates) == 0:
            raise ValueError("candidates must contain at least one answer")

        video, video_len = self.video_loader(video_path)
        video = self._get_clip_feature(video).unsqueeze(0).float()
        # NOTE(review): 10 presumably should equal args.max_feats — TODO confirm.
        video_mask = get_mask(video_len, 10)
        # Prepend one always-valid position. Cast explicitly to float so the
        # concatenation does not rely on implicit int64/float type promotion.
        video_mask = torch.cat([torch.ones((1, 1)), video_mask.float()], dim=1)

        question = text.strip().capitalize()
        if not question.endswith("?"):
            question += "?"

        yes_probs = []
        for candidate in candidates:
            prompt = (
                f" Question: {question} Is it '{candidate}'? {self.tokenizer.mask_token}. Subtitles: "
            ).strip()
            encoded = self.tokenizer(
                prompt,
                add_special_tokens=True,
                max_length=self.args.max_tokens,
                padding="longest",
                truncation=True,
                return_tensors="pt",
            )
            output = self.model(
                video=video.to(self.device),
                video_mask=video_mask.to(self.device),
                input_ids=encoded["input_ids"].to(self.device),
                attention_mask=encoded["attention_mask"].to(self.device),
            )
            logits = output["logits"]
            # Offset of the text tokens inside the model output. 11 equals the
            # width of video_mask (1 + 10); presumably the video positions come
            # first in the sequence — TODO confirm against the model.
            delay = 11
            is_mask = encoded["input_ids"] == self.tokenizer.mask_token_id
            mask_logits = logits[:, delay : encoded["input_ids"].size(1) + delay][
                is_mask
            ]
            if mask_logits.size(0) != 1:
                raise ValueError(
                    f"expected exactly one mask token in the prompt, found "
                    f"{mask_logits.size(0)} (the prompt may have been truncated)"
                )
            # Probability of answer index 0 ("Yes", see set_answer()).
            yes_probs.append(mask_logits.softmax(-1)[:, 0])

        scores = torch.stack(yes_probs, 1)  # (1, len(candidates))
        # argmax also covers the single-candidate case; the previous
        # round()-based branch indexed candidates[1] (IndexError) whenever the
        # lone candidate's "Yes" probability was >= 0.5.
        best = int(scores.argmax(dim=1))
        return candidates[best]