import os
import logging
import functools
from collections import OrderedDict

from pkg_resources import packaging

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import torch.utils.checkpoint as checkpoint

from .simple_tokenizer import SimpleTokenizer as _Tokenizer

logger = logging.getLogger(__name__)


# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
MODEL_PATH = 'https://huggingface.co/laion'
_MODELS = {
    "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
    "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
}
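# NOTE (assumption, not from the original source): torch.load() below reads local
# files, so in practice MODEL_PATH is expected to be overridden to point at a local
# directory holding the extracted *_text.pth checkpoints rather than this URL.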


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
                 checkpoint_num: int = 0):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
        self.checkpoint_num = checkpoint_num

    def forward(self, x: torch.Tensor):
        if self.checkpoint_num > 0:
            # Activation checkpointing: split the block stack into `checkpoint_num`
            # segments (capped at the number of blocks) to trade compute for memory.
            segments = min(self.checkpoint_num, len(self.resblocks))
            return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
        else:
            return self.resblocks(x)


class CLIP_TEXT(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        checkpoint_num: int,
    ):
        super().__init__()

        self.context_length = context_length
        self._tokenizer = _Tokenizer()

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            checkpoint_num=checkpoint_num,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

    def no_weight_decay(self):
        return {'token_embedding', 'positional_embedding'}
    def build_attention_mask(self):
        # Build the causal attention mask over text tokens once at init time.
        # PyTorch's nn.MultiheadAttention expects an additive mask, so masked
        # positions are filled with -inf and visible positions with 0.
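        # For example (illustrative), with context_length = 3 the mask is:
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # i.e. each token may attend only to itself and to earlier positions.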
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal; zero the diagonal and below
        return mask

    def tokenize(self, texts, context_length=77, truncate=True):
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length
        truncate : bool
            Whether to truncate the text in case its encoding is longer than the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
        We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
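
        Example
        -------
        Illustrative sketch, assuming `model` is an instantiated CLIP_TEXT; exact
        token ids depend on the bundled BPE vocab:

            tokens = model.tokenize(["a diagram", "a dog"])       # shape [2, 77]
            tokens = model.tokenize("a dog", context_length=32)   # shape [1, 32]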
""" | |
if isinstance(texts, str): | |
texts = [texts] | |
sot_token = self._tokenizer.encoder["<|startoftext|>"] | |
eot_token = self._tokenizer.encoder["<|endoftext|>"] | |
all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts] | |
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): | |
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |
else: | |
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) | |
for i, tokens in enumerate(all_tokens): | |
if len(tokens) > context_length: | |
if truncate: | |
tokens = tokens[:context_length] | |
tokens[-1] = eot_token | |
else: | |
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") | |
result[i, :len(tokens)] = torch.tensor(tokens) | |
return result | |

    def forward(self, text):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x


def clip_text_b16(
    embed_dim=512,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-B/16"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            logger.info(
                f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}"
            )
            if context_length < state_dict["positional_embedding"].size(0):
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )
        message = model.load_state_dict(state_dict, strict=False)
        logger.info(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()


def clip_text_l14(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-L/14"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            logger.info(
                f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}"
            )
            if context_length < state_dict["positional_embedding"].size(0):
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )
        message = model.load_state_dict(state_dict, strict=False)
        logger.info(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()


def clip_text_l14_336(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
    checkpoint_num=0,
):
    raise NotImplementedError
    # Unreachable reference code; "ViT-L/14_336" has no entry in _MODELS yet.
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    pretrained = _MODELS["ViT-L/14_336"]
    logger.info(f"Load pretrained weights from {pretrained}")
    state_dict = torch.load(pretrained, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    return model.eval()


def build_clip(config):
    # Instantiate the text encoder builder named in the config,
    # e.g. "clip_text_b16" or "clip_text_l14".
    model_cls = config.text_encoder.clip_teacher
    model = eval(model_cls)()
    return model
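

# A minimal, hedged usage sketch (not part of the original module). It assumes this
# file lives in its package alongside simple_tokenizer.py with the bundled BPE vocab,
# and uses pretrained=False so no checkpoint is fetched; the randomly initialized
# weights make the outputs meaningless, so only the shapes are of interest.
if __name__ == "__main__":
    model = clip_text_b16(pretrained=False)
    tokens = model.tokenize(["a photo of a dog", "a photo of a cat"])  # [2, 77]
    with torch.no_grad():
        feats = model(tokens)
    print(tokens.shape, feats.shape)  # expected: torch.Size([2, 77]) torch.Size([2, 512])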