### demo.py
# Define model classes for inference.
###
import json
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from einops import rearrange
from transformers import BertTokenizer
from torchvision import transforms
from torchvision.transforms._transforms_video import (
NormalizeVideo,
)
from svitt.model import SViTT
from svitt.config import load_cfg, setup_config
from svitt.base_dataset import read_frames_cv2_egoclip
class VideoModel(nn.Module):
""" Base model for video understanding based on SViTT architecture. """
def __init__(self, config):
""" Initializes the model.
Parameters:
            config: path to the configuration file (loaded via load_cfg)
"""
super(VideoModel, self).__init__()
self.cfg = load_cfg(config)
self.model = self.build_model()
self.templates = ['{}']
self.dataset = self.cfg['data']['dataset']
self.eval()
def build_model(self):
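        """ Instantiate the SViTT model and BERT tokenizer from the paths given
        in the config, then load the pretrained checkpoint weights. """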
cfg = self.cfg
        ckpt_path = cfg['model'].get('pretrain')
        if not ckpt_path:
            raise ValueError('no pretrained checkpoint specified in the config')
        config_path = cfg['model'].get('config')
        if not config_path:
            raise ValueError('no model config path specified in the config')
self.model_cfg = setup_config(config_path)
self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)
print(f"Loading checkpoint from {ckpt_path}")
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint["model"]
# fix for zero-shot evaluation
for key in list(state_dict.keys()):
if "bert" in key:
encoder_key = key.replace("bert.", "")
state_dict[encoder_key] = state_dict[key]
if torch.cuda.is_available():
model.cuda()
model.load_state_dict(state_dict, strict=False)
return model
def eval(self):
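        """ Freeze all parameters and put the wrapped model in eval mode. """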
cudnn.benchmark = True
for p in self.model.parameters():
p.requires_grad = False
self.model.eval()
class VideoCLSModel(VideoModel):
""" Video model for video classification tasks (Charades-Ego, EGTEA). """
def __init__(self, config, sample_videos):
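        """ Initializes the classification model.
        Parameters:
            config: path to the configuration file
            sample_videos: per-sample collection of candidate video clip paths,
                indexable by sample id (structure inferred from load_data below)
        """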
super(VideoCLSModel, self).__init__(config)
self.sample_videos = sample_videos
self.video_transform = self.init_video_transform()
#def load_data(self, idx=None):
# filename = f"{self.cfg['data']['root']}/{idx}/tensors.pt"
# return torch.load(filename)
def init_video_transform(self,
input_res=224,
center_crop=256,
norm_mean=(0.485, 0.456, 0.406),
norm_std=(0.229, 0.224, 0.225),
):
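        """ Build the per-clip preprocessing: resize, center crop, resize to
        the model input resolution, and normalize with ImageNet statistics. """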
        print('=> Using video transform')
normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
return transforms.Compose(
[
transforms.Resize(center_crop),
transforms.CenterCrop(center_crop),
transforms.Resize(input_res),
normalize,
]
)
def load_data(self, idx):
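        """ Decode, transform, and stack the candidate clips for sample `idx`
        listed in self.sample_videos. """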
num_frames = self.model_cfg.video_input.num_frames
video_paths = self.sample_videos[idx]
clips = [None] * len(video_paths)
for i, path in enumerate(video_paths):
imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform')
imgs = imgs.transpose(0, 1)
imgs = self.video_transform(imgs)
imgs = imgs.transpose(0, 1)
clips[i] = imgs
return torch.stack(clips)
def load_meta(self, idx=None):
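        """ Read the sample's metadata (text prompts and correct index) from
        <data.root>/<idx>/meta.json. """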
filename = f"{self.cfg['data']['root']}/{idx}/meta.json"
with open(filename, "r") as f:
meta = json.load(f)
return meta
@torch.no_grad()
def get_text_features(self, text):
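        """ Tokenize the prompts and return the pooled text embeddings from the
        SViTT text encoder. """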
print('=> Extracting text features')
embeddings = self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=self.model_cfg.max_txt_l.video,
return_tensors="pt",
)
        # Move tokenized inputs to the model's device before encoding.
        embeddings = embeddings.to(next(self.model.parameters()).device)
        _, class_embeddings = self.model.encode_text(embeddings)
return class_embeddings
@torch.no_grad()
def forward(self, idx, text=None):
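        """ Score the clips of sample `idx` against the text prompts.
        If `text` is None, prompts are taken from the sample's meta.json.
        Returns the argmax of the video-text similarity scores together with
        the ground-truth index stored under meta["correct"]. """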
        print('=> Running forward pass')
meta = self.load_meta(idx)
clips = self.load_data(idx)
if text is None:
text = meta["text"][4:]
text_features = self.get_text_features(text)
target = meta["correct"]
# encode images
pooled_image_feat_all = []
for i in range(clips.shape[0]):
            # Add a batch dimension and move the clip to the model's device.
            images = clips[i, :].unsqueeze(0).to(next(self.model.parameters()).device)
bsz = images.shape[0]
_, pooled_image_feat, *outputs = self.model.encode_image(images)
if pooled_image_feat.ndim == 3:
pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz)
else:
pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz)
pooled_image_feat_all.append(pooled_image_feat)
pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0)
similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0]
return similarity.argmax(), target
@torch.no_grad()
def predict(self, idx, text=None):
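        """ Convenience wrapper around forward() that returns the prediction as
        a NumPy value. """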
output, target = self.forward(idx, text)
        # Move the result off the GPU (if one is used) before converting to NumPy.
        return output.cpu().numpy(), target
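
# Minimal usage sketch: the config path and clip paths below are hypothetical
# placeholders and must point to a real SViTT demo config and to clips on disk.
# forward() additionally expects <data.root>/<idx>/meta.json with "text" and
# "correct" fields.
if __name__ == "__main__":
    sample_videos = {
        0: [
            "videos/0/clip_0.mp4",  # hypothetical clip paths
            "videos/0/clip_1.mp4",
        ],
    }
    model = VideoCLSModel("configs/demo.yml", sample_videos)  # hypothetical config path
    pred, target = model.predict(0)
    print(f"predicted index: {pred}, ground-truth index: {target}")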