Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,410 Bytes
2d9a728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 |
import os
import logging
import torch
from einops import rearrange
from torch import nn
import math
# from .criterions import VTC_VTM_Loss
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from .viclip_vision import clip_joint_l14, clip_joint_b16
from .viclip_text import clip_text_l14, clip_text_b16
logger = logging.getLogger(__name__)
class ViCLIP(nn.Module):
"""docstring for ViCLIP"""
def __init__(self,
tokenizer=None,
size='l',
pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"),
freeze_text=True):
super(ViCLIP, self).__init__()
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = _Tokenizer()
self.max_txt_l = 32
if size.lower() == 'l':
self.vision_encoder_name = 'vit_l14'
elif size.lower() == 'b':
self.vision_encoder_name = 'vit_b16'
else:
raise NotImplementedError(f"Size {size} not implemented")
self.vision_encoder_pretrained = False
self.inputs_image_res = 224
self.vision_encoder_kernel_size = 1
self.vision_encoder_center = True
self.video_input_num_frames = 8
self.vision_encoder_drop_path_rate = 0.1
self.vision_encoder_checkpoint_num = 24
self.is_pretrain = pretrain
self.vision_width = 1024
self.text_width = 768
self.embed_dim = 768
self.masking_prob = 0.9
if size.lower() == 'l':
self.text_encoder_name = 'vit_l14'
elif size.lower() == 'b':
self.text_encoder_name = 'vit_b16'
else:
raise NotImplementedError(f"Size {size} not implemented")
self.text_encoder_pretrained = False#'bert-base-uncased'
self.text_encoder_d_model = 768
self.text_encoder_vocab_size = 49408
# create modules.
self.vision_encoder = self.build_vision_encoder()
self.text_encoder = self.build_text_encoder()
self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0)
self.temp_min = 1 / 100.0
if pretrain:
logger.info(f"Load pretrained weights from {pretrain}")
state_dict = torch.load(pretrain, map_location='cpu')['model']
self.load_state_dict(state_dict)
# Freeze weights
if freeze_text:
self.freeze_text()
def freeze_text(self):
"""freeze text encoder"""
for p in self.text_encoder.parameters():
p.requires_grad = False
def no_weight_decay(self):
ret = {"temp"}
ret.update(
{"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
)
ret.update(
{"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
)
return ret
def forward(self, image, text, raw_text, idx, log_generation=None, return_sims=False):
"""forward and calculate loss.
Args:
image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
text (dict): TODO
idx (torch.Tensor): TODO
Returns: TODO
"""
self.clip_contrastive_temperature()
vision_embeds = self.encode_vision(image)
text_embeds = self.encode_text(raw_text)
if return_sims:
sims = torch.nn.functional.normalize(vision_embeds, dim=-1) @ \
torch.nn.functional.normalize(text_embeds, dim=-1).transpose(0, 1)
return sims
# calculate loss
## VTC loss
loss_vtc = self.clip_loss.vtc_loss(
vision_embeds, text_embeds, idx, self.temp, all_gather=True
)
return dict(
loss_vtc=loss_vtc,
)
def encode_vision(self, image, test=False):
"""encode image / videos as features.
Args:
image (torch.Tensor): The input images.
test (bool): Whether testing.
Returns: tuple.
- vision_embeds (torch.Tensor): The features of all patches. Shape: [B,T,L,C].
- pooled_vision_embeds (torch.Tensor): The pooled features. Shape: [B,T,C].
"""
if image.ndim == 5:
image = image.permute(0, 2, 1, 3, 4).contiguous()
else:
image = image.unsqueeze(2)
if not test and self.masking_prob > 0.0:
return self.vision_encoder(
image, masking_prob=self.masking_prob
)
return self.vision_encoder(image)
def encode_text(self, text):
"""encode text.
Args:
text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
- input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
- attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
- other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
Returns: tuple.
- text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
- pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C].
"""
device = next(self.text_encoder.parameters()).device
text = self.text_encoder.tokenize(
text, context_length=self.max_txt_l
).to(device)
text_embeds = self.text_encoder(text)
return text_embeds
@torch.no_grad()
def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
"""Seems only used during pre-training"""
self.temp.clamp_(min=self.temp_min)
def build_vision_encoder(self):
"""build vision encoder
Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
"""
encoder_name = self.vision_encoder_name
if encoder_name == "vit_l14":
vision_encoder = clip_joint_l14(
pretrained=self.vision_encoder_pretrained,
input_resolution=self.inputs_image_res,
kernel_size=self.vision_encoder_kernel_size,
center=self.vision_encoder_center,
num_frames=self.video_input_num_frames,
drop_path=self.vision_encoder_drop_path_rate,
checkpoint_num=self.vision_encoder_checkpoint_num,
)
elif encoder_name == "vit_b16":
vision_encoder = clip_joint_b16(
pretrained=self.vision_encoder_pretrained,
input_resolution=self.inputs_image_res,
kernel_size=self.vision_encoder_kernel_size,
center=self.vision_encoder_center,
num_frames=self.video_input_num_frames,
drop_path=self.vision_encoder_drop_path_rate,
checkpoint_num=self.vision_encoder_checkpoint_num,
)
else:
raise NotImplementedError(f"Not implemented: {encoder_name}")
return vision_encoder
def build_text_encoder(self):
"""build text_encoder and possiblly video-to-text multimodal fusion encoder.
Returns: nn.Module. The text encoder
"""
encoder_name = self.text_encoder_name
if encoder_name == "vit_l14":
text_encoder = clip_text_l14(
pretrained=self.text_encoder_pretrained,
context_length=self.max_txt_l,
vocab_size=self.text_encoder_vocab_size,
checkpoint_num=0,
)
elif encoder_name == "vit_b16":
text_encoder = clip_text_b16(
pretrained=self.text_encoder_pretrained,
context_length=self.max_txt_l,
vocab_size=self.text_encoder_vocab_size,
checkpoint_num=0,
)
else:
raise NotImplementedError(f"Not implemented: {encoder_name}")
return text_encoder
def get_text_encoder(self):
"""get text encoder, used for text and cross-modal encoding"""
encoder = self.text_encoder
return encoder.bert if hasattr(encoder, "bert") else encoder
def get_text_features(self, input_text, tokenizer, text_feature_dict={}):
if input_text in text_feature_dict:
return text_feature_dict[input_text]
text_template= f"{input_text}"
with torch.no_grad():
# text_token = tokenizer.encode(text_template).cuda()
text_features = self.encode_text(text_template).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_feature_dict[input_text] = text_features
return text_features
def get_vid_features(self, input_frames):
with torch.no_grad():
clip_feat = self.encode_vision(input_frames,test=True).float()
clip_feat /= clip_feat.norm(dim=-1, keepdim=True)
return clip_feat
def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
return top_probs, top_labels
if __name__ =="__main__":
tokenizer = _Tokenizer()
|