### demo.py
# Define model classes for inference.
###
import json

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from einops import rearrange
from transformers import BertTokenizer
from torchvision import transforms
from torchvision.transforms._transforms_video import (
    NormalizeVideo,
)

from svitt.model import SViTT
from svitt.config import load_cfg, setup_config
from svitt.base_dataset import read_frames_cv2_egoclip


class VideoModel(nn.Module):
    """Base model for video understanding based on the SViTT architecture."""

    def __init__(self, config):
        """Initialize the model.

        Parameters:
            config: path to the demo config file (parsed with `load_cfg`).
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')

        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # Fix for zero-shot evaluation: duplicate text-encoder weights under
        # keys without the "bert." prefix so they match the model's parameter names.
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()
        model.load_state_dict(state_dict, strict=False)

        return model

    def eval(self):
        # Freeze all parameters and put the backbone in inference mode.
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()


class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA).
""" def __init__(self, config, sample_videos): super(VideoCLSModel, self).__init__(config) self.sample_videos = sample_videos self.video_transform = self.init_video_transform() #def load_data(self, idx=None): # filename = f"{self.cfg['data']['root']}/{idx}/tensors.pt" # return torch.load(filename) def init_video_transform(self, input_res=224, center_crop=256, norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), ): print('Video Transform is used!') normalize = NormalizeVideo(mean=norm_mean, std=norm_std) return transforms.Compose( [ transforms.Resize(center_crop), transforms.CenterCrop(center_crop), transforms.Resize(input_res), normalize, ] ) def load_data(self, idx): num_frames = self.model_cfg.video_input.num_frames video_paths = self.sample_videos[idx] clips = [None] * len(video_paths) for i, path in enumerate(video_paths): imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform') imgs = imgs.transpose(0, 1) imgs = self.video_transform(imgs) imgs = imgs.transpose(0, 1) clips[i] = imgs return torch.stack(clips) def load_meta(self, idx=None): filename = f"{self.cfg['data']['root']}/{idx}/meta.json" with open(filename, "r") as f: meta = json.load(f) return meta @torch.no_grad() def get_text_features(self, text): print('=> Extracting text features') embeddings = self.tokenizer( text, padding="max_length", truncation=True, max_length=self.model_cfg.max_txt_l.video, return_tensors="pt", ) _, class_embeddings = self.model.encode_text(embeddings) return class_embeddings @torch.no_grad() def forward(self, idx, text=None): print('=> Start forwarding') meta = self.load_meta(idx) clips = self.load_data(idx) if text is None: text = meta["text"][4:] text_features = self.get_text_features(text) target = meta["correct"] # encode images pooled_image_feat_all = [] for i in range(clips.shape[0]): images = clips[i,:].unsqueeze(0) bsz = images.shape[0] _, pooled_image_feat, *outputs = self.model.encode_image(images) if pooled_image_feat.ndim == 3: pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz) else: pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz) pooled_image_feat_all.append(pooled_image_feat) pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0) similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0] return similarity.argmax(), target @torch.no_grad() def predict(self, idx, text=None): output, target = self.forward(idx, text) return output.numpy(), target