### demo.py
# Define model classes for inference.
###
import json
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from einops import rearrange
from transformers import BertTokenizer
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo

from svitt.model import SViTT
from svitt.config import load_cfg, setup_config
from svitt.base_dataset import read_frames_cv2_egoclip


class VideoModel(nn.Module):
    """Base model for video understanding based on the SViTT architecture."""
    def __init__(self, config):
        """Initialize the model.

        Parameters:
            config: path to the config file, loaded via `load_cfg`.
        """
        super().__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')

        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # Fix for zero-shot evaluation: duplicate text-encoder weights under
        # keys without the "bert." prefix so they match the model's parameter names.
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()

        model.load_state_dict(state_dict, strict=False)

        return model

    def eval(self):
        """Freeze all parameters and put the underlying model in eval mode."""
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()


class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""
    def __init__(self, config, sample_videos):
        super().__init__(config)
        self.sample_videos = sample_videos
        self.video_transform = self.init_video_transform()
    def init_video_transform(self,
            input_res=224,
            center_crop=256,
            norm_mean=(0.485, 0.456, 0.406),
            norm_std=(0.229, 0.224, 0.225),
        ):
        """Build the clip preprocessing pipeline: resize, center crop, resize to
        the model input resolution, then normalize with ImageNet statistics."""
        print('Using video transform (resize, center crop, normalize)')
        normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
        return transforms.Compose(
            [
                transforms.Resize(center_crop),
                transforms.CenterCrop(center_crop),
                transforms.Resize(input_res),
                normalize,
            ]
        )
    
    def load_data(self, idx):
        """Read and preprocess the candidate video clips for sample `idx`."""
        num_frames = self.model_cfg.video_input.num_frames
        video_paths = self.sample_videos[idx]
        clips = [None] * len(video_paths)
        for i, path in enumerate(video_paths):
            imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform')
            imgs = imgs.transpose(0, 1)  # channels first, as expected by NormalizeVideo
            imgs = self.video_transform(imgs)
            imgs = imgs.transpose(0, 1)  # back to frames-first layout
            clips[i] = imgs
        return torch.stack(clips)
    
    def load_meta(self, idx=None):
        """Load the JSON metadata for sample `idx` (provides "text" and "correct")."""
        filename = f"{self.cfg['data']['root']}/{idx}/meta.json"
        with open(filename, "r") as f:
            meta = json.load(f)
        return meta
        
    @torch.no_grad()
    def get_text_features(self, text):
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        )
        # Move tokenized inputs to the model's device (no-op on CPU).
        embeddings = embeddings.to(next(self.model.parameters()).device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    @torch.no_grad()
    def forward(self, idx, text=None):
        print('=> Start forwarding')
        meta = self.load_meta(idx)
        clips = self.load_data(idx)
        if text is None:
            text = meta["text"][4:]
        text_features = self.get_text_features(text)
        target = meta["correct"]

        # Encode each clip with the visual encoder.
        device = next(self.model.parameters()).device
        pooled_image_feat_all = []
        for i in range(clips.shape[0]):
            images = clips[i].unsqueeze(0).to(device)  # add a batch dimension; move to the model's device
            bsz = images.shape[0]

            _, pooled_image_feat, *outputs = self.model.encode_image(images)
            # Regroup pooled features from a flattened (b*k, ...) layout to batch-first.
            if pooled_image_feat.ndim == 3:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz)
            else:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz)

            pooled_image_feat_all.append(pooled_image_feat)

        # Rank the clips against the text by video-text similarity.
        pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0)
        similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0]
        return similarity.argmax(), target

    @torch.no_grad()
    def predict(self, idx, text=None):
        output, target = self.forward(idx, text)
        # Move the prediction to CPU before converting to numpy.
        return output.cpu().numpy(), target
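

# Minimal usage sketch. The config path and the `sample_videos` structure below are
# illustrative assumptions, not defined by this file: `sample_videos` maps a sample
# index to a list of candidate video paths, and the data root named in the config
# must contain `<idx>/meta.json` with "text" and "correct" fields.
if __name__ == "__main__":
    sample_videos = {
        0: [
            "data/samples/0/clip_0.mp4",  # hypothetical clip paths
            "data/samples/0/clip_1.mp4",
            "data/samples/0/clip_2.mp4",
        ],
    }
    model = VideoCLSModel("configs/demo.yml", sample_videos)  # hypothetical config path
    prediction, target = model.predict(0)
    print(f"Prediction: {prediction}, target: {target}")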