import os

import cv2
import numpy as np
import torch
from torch import nn

from models.backbones.internvideo2 import pretrain_internvideo2_1b_patch14_224
from models.backbones.bert.builder import build_bert
from models.backbones.internvideo2.pos_embed import interpolate_pos_embed_internvideo2_new
from models.backbones.bert.tokenization_bert import BertTokenizer


def _frame_from_video(video):
    """Yield BGR frames from an opened `cv2.VideoCapture` until it is exhausted."""
    while video.isOpened():
        success, frame = video.read()
        if success:
            yield frame
        else:
            break
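

# Example usage (a minimal sketch; "example.mp4" is a hypothetical path):
#
#     video = cv2.VideoCapture("example.mp4")
#     frames = [f for f in _frame_from_video(video)]
#     video.release()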


# ImageNet mean/std, shaped [1, 1, 3] for broadcasting over HWC frames.
v_mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
v_std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)


def normalize(data):
    """Scale uint8 pixels to [0, 1], then apply ImageNet normalization."""
    return (data / 255.0 - v_mean) / v_std


def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
    assert len(vid_list) >= fnum, f"need at least {fnum} frames, got {len(vid_list)}"
    # Sample `fnum` frames at a uniform stride.
    step = len(vid_list) // fnum
    vid_list = vid_list[::step][:fnum]
    # BGR -> RGB, resize, normalize, and stack into [1, T, H, W, C].
    vid_list = [cv2.resize(x[:, :, ::-1], target_size) for x in vid_list]
    vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list]
    vid_tube = np.concatenate(vid_tube, axis=1)
    # [1, T, H, W, C] -> [1, T, C, H, W]
    vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))
    vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float()
    return vid_tube
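

# Example (sketch, assuming `frames` holds at least 8 BGR frames):
#
#     clip_tensor = frames2tensor(frames, fnum=8, target_size=(224, 224),
#                                 device=torch.device('cpu'))
#     assert clip_tensor.shape == (1, 8, 3, 224, 224)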


def get_text_feat_dict(texts, clip, text_feat_d=None):
    # Avoid a mutable default argument: a shared dict would leak state across calls.
    if text_feat_d is None:
        text_feat_d = {}
    for t in texts:
        text_feat_d[t] = clip.get_txt_feat(t)
    return text_feat_d


def get_vid_feat(frames, vlm):
    return vlm.get_vid_features(frames)


def retrieve_text(frames,
                  texts,
                  model,
                  topk: int = 5,
                  config: dict = {},
                  device=torch.device('cuda')):
    vlm = model.to(device)

    fn = config.get('num_frames', 8)
    size_t = config.get('size_t', 224)
    frames_tensor = frames2tensor(frames, fnum=fn, target_size=(size_t, size_t), device=device)
    vid_feat = vlm.get_vid_features(frames_tensor)
    print('Video', vid_feat.mean(dim=-1))

    text_feat_d = get_text_feat_dict(texts, vlm)
    text_feats = [text_feat_d[t] for t in texts]
    text_feats_tensor = torch.cat(text_feats, 0)
    print('Text', text_feats_tensor.mean(dim=-1))

    probs, idxs = vlm.predict_label(vid_feat, text_feats_tensor, top=topk)
    ret_texts = [texts[i] for i in idxs.long().numpy()[0].tolist()]
    return ret_texts, probs.float().numpy()[0]
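

# Example usage (a minimal sketch; the video path and captions are hypothetical):
#
#     video = cv2.VideoCapture("example.mp4")
#     frames = [f for f in _frame_from_video(video)]
#     texts = ["a person cooking", "a dog running", "a car driving at night"]
#     ret_texts, probs = retrieve_text(frames, texts, model, topk=3,
#                                      config={'num_frames': 8, 'size_t': 224})
#     for t, p in zip(ret_texts, probs):
#         print(f"{p:.3f}  {t}")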


def setup_internvideo2(config: dict):
    if "bert" in config.model.text_encoder.name:
        tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained)
        model = InternVideo2_Stage2(config=config, tokenizer=tokenizer, is_pretrain=True)
    else:
        model = InternVideo2_Stage2(config=config, is_pretrain=True)
        tokenizer = model.tokenizer

    if config.get('compile_model', False):
        torch.set_float32_matmul_precision('high')
        model = torch.compile(model)

    model = model.to(torch.device(config.device))
    model_without_ddp = model

    if config.pretrained_path.strip() and (os.path.isfile(config.pretrained_path) or "s3://" in config.pretrained_path):
        checkpoint = torch.load(config.pretrained_path, map_location="cpu")
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        elif isinstance(checkpoint, dict) and "module" in checkpoint:
            state_dict = checkpoint["module"]  # a deepspeed stage 1 checkpoint
        else:
            state_dict = checkpoint
        # Note: this was a temporary fix for the bug caused by is_pretrain=False
        # from collections import OrderedDict
        # state_dict = OrderedDict({k.replace('text_encoder.bert', 'text_encoder'): state_dict[k] for k in state_dict})
        if config.get('origin_num_frames', None) is not None:
            num_keys = len(state_dict)
            interpolate_pos_embed_internvideo2_new(state_dict, model_without_ddp.vision_encoder, orig_t_size=config.origin_num_frames)
            assert num_keys == len(state_dict), state_dict.keys()
        msg = model_without_ddp.load_state_dict(state_dict, strict=False)
        print(f"load_state_dict: {msg}")

    if config.get('use_bf16', False):
        model_without_ddp = model_without_ddp.to(torch.bfloat16)
    elif config.get('use_half_precision', False):
        model_without_ddp = model_without_ddp.to(torch.float16)
    else:
        model_without_ddp = model_without_ddp.to(torch.float32)

    return model_without_ddp, tokenizer
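

# Minimal config sketch. Assumptions: the real config is an attribute-accessible
# dict (e.g. easydict.EasyDict, which also supports .get), and the keys/values
# below are illustrative, not the released InternVideo2 defaults:
#
#     from easydict import EasyDict
#     config = EasyDict({
#         'device': 'cuda',
#         'max_txt_l': 40,
#         'pretrained_path': '/path/to/internvideo2_stage2.pt',  # hypothetical
#         'gradient_checkpointing': False,
#         'model': {
#             'embed_dim': 512,
#             'vision_encoder': {...},  # name, img_size, num_frames, patch_size, ...
#             'text_encoder': {...},    # name, pretrained, d_model, ...
#         },
#     })
#     model, tokenizer = setup_internvideo2(config)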


class InternVideo2_Stage2(nn.Module):
    """Stage-2 InternVideo2: frozen vision and text encoders with linear
    projections into a shared video-text embedding space."""

    def __init__(self,
                 config,
                 tokenizer=None,
                 is_pretrain: bool = True):
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain
        self.vision_width = config.model.vision_encoder.clip_embed_dim
        self.text_width = config.model.text_encoder.d_model
        self.embed_dim = config.model.embed_dim

        # create modules
        self.vision_encoder = self.build_vision_encoder()
        self.freeze_vision()

        self.text_encoder = self.build_text_encoder()
        self.freeze_text()

        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
        self.text_proj = nn.Linear(self.text_width, self.embed_dim)

    def freeze_vision(self):
        """Freeze the vision encoder."""
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

    def freeze_text(self):
        """Freeze the text encoder."""
        for p in self.text_encoder.parameters():
            p.requires_grad = False

    @property
    def dtype(self):
        # A property, so `self.dtype` below yields a torch.dtype, not a bound method.
        return self.vision_encoder.patch_embed.proj.weight.dtype

    def encode_vision(self,
                      image: torch.Tensor,
                      test: bool = False):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
            test (bool): Whether this is an inference (test) pass.

        Returns: tuple.
            - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
            - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C].
            - student_output (torch.Tensor): The features for alignment. Shape: [K,B,N,C].
            - clip_output (torch.Tensor): The CLIP features. Shape: [K,B,N,C].
        """
        T = image.shape[1]
        use_image = T == 1  # a single frame is treated as an image
        image = image.permute(0, 2, 1, 3, 4).to(self.dtype)  # [B,T,C,H,W] -> [B,C,T,H,W]
        # whether to keep the temporal dimension
        # keep_temporal = self.config.model.vision_encoder.keep_temporal
        if test:
            vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
                image, None, use_image)
            return vision_embeds, pooled_vision_embeds
        else:
            # NOTE: encode_teacher is not defined in this file; only the
            # test=True path is exercised here.
            mask, targets_clip_middle_vis, targets_clip_final_vis = self.encode_teacher(image)
            # if mask is not None and (self.video_mask_type != 'tube' or self.image_mask_type != 'tube'):
            #     keep_temporal = False
            vision_embeds, pooled_vision_embeds, student_output, student_output_final = self.vision_encoder(
                image, mask, use_image)
            return vision_embeds, pooled_vision_embeds, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis

    def encode_text(self,
                    text: dict):
        """Encode text.

        Args:
            text (dict): The output of huggingface's `PreTrainedTokenizer`, containing keys:
                - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
                - attention_mask (torch.Tensor): The mask indicating padded tokens (0 marks padding). Shape: [B,L].
                - other keys: see https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Returns: tuple.
            - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
            - pooled_text_embeds (torch.Tensor): The pooled [CLS] features. Shape: [B,C].
        """
        text_output = self.get_text_encoder()(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode="text",
        )
        text_embeds = text_output.last_hidden_state
        pooled_text_embeds = text_embeds[:, 0]  # [CLS] token
        return text_embeds, pooled_text_embeds

    def build_vision_encoder(self):
        """Build the vision encoder.

        Returns: nn.Module. The vision encoder.
        """
        encoder_name = self.config.model.vision_encoder.name
        if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
            vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
        else:
            raise ValueError(f"Not implemented: {encoder_name}")

        # parameters for masking
        img_size = self.config.model.vision_encoder.img_size
        num_frames = self.config.model.vision_encoder.num_frames
        tubelet_size = self.config.model.vision_encoder.tubelet_size
        patch_size = self.config.model.vision_encoder.patch_size
        self.clip_img_size = self.config.model.vision_encoder.clip_input_resolution
        self.video_mask_type = self.config.model.vision_encoder.video_mask_type
        self.video_window_size = (num_frames // tubelet_size, img_size // patch_size, img_size // patch_size)
        self.video_mask_ratio = self.config.model.vision_encoder.video_mask_ratio
        self.image_mask_type = self.config.model.vision_encoder.image_mask_type
        self.image_window_size = (1, img_size // patch_size, img_size // patch_size)
        self.image_mask_ratio = self.config.model.vision_encoder.image_mask_ratio

        return vision_encoder

    def build_text_encoder(self):
        """Build the text encoder and possibly a video-to-text multimodal fusion encoder.

        Returns: nn.Module. The text encoder.
        """
        encoder_name = self.config.model.text_encoder.name
        if "bert" in encoder_name:
            text_encoder = build_bert(
                self.config.model,
                self.is_pretrain,
                self.config.gradient_checkpointing,
            )
        else:
            raise ValueError(f"Not implemented: {encoder_name}")
        return text_encoder

    def get_text_encoder(self):
        """Get the text encoder, used for text and cross-modal encoding."""
        encoder = self.text_encoder
        return encoder.bert if hasattr(encoder, "bert") else encoder

    def get_vid_features(self,
                         frames: torch.Tensor):
        """Get the normalized video embedding for the given frames.

        Args:
            frames (torch.Tensor): The input frames. Shape: [B,T,C,H,W].

        Returns:
            torch.Tensor: The projected, L2-normalized video embedding. Shape: [B,C].
        """
        with torch.no_grad():
            _, vfeat = self.encode_vision(frames, test=True)
            vfeat = self.vision_proj(vfeat)
            vfeat /= vfeat.norm(dim=-1, keepdim=True)
        return vfeat

    def get_txt_feat(self,
                     text: str):
        """Get the projected, L2-normalized text embedding for the given text."""
        device = next(self.parameters()).device
        with torch.no_grad():
            text = self.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=self.config.max_txt_l,
                return_tensors="pt",
            ).to(device)
            _, tfeat = self.encode_text(text)
            tfeat = self.text_proj(tfeat)
            tfeat /= tfeat.norm(dim=-1, keepdim=True)
        return tfeat

    def predict_label(self,
                      vid_feat: torch.Tensor,
                      txt_feat: torch.Tensor,
                      top: int = 5):
        """Scale similarities by 100, softmax over texts, and return the top-k
        probabilities and their indices."""
        label_probs = (100.0 * vid_feat @ txt_feat.T).softmax(dim=-1)
        top_probs, top_labels = label_probs.float().cpu().topk(top, dim=-1)
        return top_probs, top_labels
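

# Quick check of predict_label semantics (sketch): for a [1, C] video feature
# and an [N, C] text feature matrix, it returns top-k softmax probabilities and
# the matching row indices into the text matrix:
#
#     probs, idxs = model.predict_label(vid_feat, text_feats_tensor, top=5)
#     best_caption = texts[idxs[0][0].item()]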