### demo.py
# Define model classes for inference.
###
import json
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from einops import rearrange
from transformers import BertTokenizer
from torchvision import transforms
from torchvision.transforms._transforms_video import (
    NormalizeVideo,
)
from svitt.model import SViTT
from svitt.config import load_cfg, setup_config
from svitt.base_dataset import read_frames_cv2_egoclip

class VideoModel(nn.Module):
    """ Base model for video understanding based on the SViTT architecture. """

    def __init__(self, config):
        """ Initializes the model.

        Parameters:
            config: config file passed to load_cfg
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()
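
    # Build the SViTT model from the model config referenced by the demo config
    # and load the pretrained checkpoint weights.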
    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no pretrained checkpoint specified in config')
        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config specified in config')
        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)
        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]
        # Fix for zero-shot evaluation: text-encoder weights are stored under a
        # "bert." prefix in the checkpoint, so copy them to the unprefixed keys
        # the model expects.
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]
        if torch.cuda.is_available():
            model.cuda()
        model.load_state_dict(state_dict, strict=False)
        return model
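
    # Note: overrides nn.Module.eval() to also freeze all parameters and enable
    # cudnn benchmark mode for inference.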
    def eval(self):
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

class VideoCLSModel(VideoModel):
    """ Video model for video classification tasks (Charades-Ego, EGTEA). """

    def __init__(self, config, sample_videos):
        super(VideoCLSModel, self).__init__(config)
        self.sample_videos = sample_videos
        self.video_transform = self.init_video_transform()

    # def load_data(self, idx=None):
    #     filename = f"{self.cfg['data']['root']}/{idx}/tensors.pt"
    #     return torch.load(filename)
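
    # Build the torchvision preprocessing pipeline applied to every clip:
    # resize, center-crop, resize to the model input resolution, then
    # normalize with ImageNet statistics.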
    def init_video_transform(self,
                             input_res=224,
                             center_crop=256,
                             norm_mean=(0.485, 0.456, 0.406),
                             norm_std=(0.229, 0.224, 0.225),
                             ):
        print('=> Initializing video transform')
        normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
        return transforms.Compose(
            [
                transforms.Resize(center_crop),
                transforms.CenterCrop(center_crop),
                transforms.Resize(input_res),
                normalize,
            ]
        )
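
    # Load one sample: read num_frames frames from each video path belonging to
    # this index, apply the video transform, and stack the clips into a single
    # tensor (one entry per clip).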
    def load_data(self, idx):
        num_frames = self.model_cfg.video_input.num_frames
        video_paths = self.sample_videos[idx]
        clips = [None] * len(video_paths)
        for i, path in enumerate(video_paths):
            imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform')
            imgs = imgs.transpose(0, 1)
            imgs = self.video_transform(imgs)
            imgs = imgs.transpose(0, 1)
            clips[i] = imgs
        return torch.stack(clips)
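
    # Load per-sample metadata (candidate texts and the ground-truth answer)
    # from meta.json under the configured data root.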
    def load_meta(self, idx=None):
        filename = f"{self.cfg['data']['root']}/{idx}/meta.json"
        with open(filename, "r") as f:
            meta = json.load(f)
        return meta
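
    # Tokenize the candidate texts and encode them with the text encoder,
    # returning the pooled text embeddings.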
    @torch.no_grad()
    def get_text_features(self, text):
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        )
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings
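
    # Encode the sample's clips and candidate texts, compute video-text
    # similarity, and return the index of the best-matching text together with
    # the ground-truth target.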
    @torch.no_grad()
    def forward(self, idx, text=None):
        print('=> Running forward pass')
        meta = self.load_meta(idx)
        clips = self.load_data(idx)
        if text is None:
            # Default to the candidate texts stored in the sample metadata.
            text = meta["text"][4:]
        text_features = self.get_text_features(text)
        target = meta["correct"]
        # Encode each clip separately and collect the pooled image features.
        pooled_image_feat_all = []
        for i in range(clips.shape[0]):
            images = clips[i, :].unsqueeze(0)
            bsz = images.shape[0]
            _, pooled_image_feat, *outputs = self.model.encode_image(images)
            if pooled_image_feat.ndim == 3:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz)
            else:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz)
            pooled_image_feat_all.append(pooled_image_feat)
        pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0)
        similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0]
        return similarity.argmax(), target
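
    # Convenience wrapper around forward() that returns the prediction as a
    # NumPy value alongside the target.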
    @torch.no_grad()
    def predict(self, idx, text=None):
        output, target = self.forward(idx, text)
        # Move to CPU before converting, in case the model ran on GPU.
        return output.cpu().numpy(), target
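

# Example usage (sketch): how the demo classes above might be driven end to end.
# The config path, sample video paths, and data layout below are illustrative
# assumptions, not files provided by this repo; predict() also expects a
# matching meta.json under the configured data root.
if __name__ == "__main__":
    # Map sample index -> list of video files making up that sample (hypothetical paths).
    sample_videos = {
        0: ["videos/clip_000.mp4", "videos/clip_001.mp4"],
    }
    # Hypothetical demo config; it must point to the SViTT model config and checkpoint.
    model = VideoCLSModel("configs/demo.yml", sample_videos)
    prediction, target = model.predict(0)
    print(f"Predicted text index: {prediction}, ground truth: {target}")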