# Page header captured from the Hugging Face file viewer (not part of the module):
# mazpie's picture | Initial commit | 2d9a728 | raw | history blame | 10.9 kB
import os
import logging
from collections import OrderedDict
from pkg_resources import packaging
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import torch.utils.checkpoint as checkpoint
import functools
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
# Root location under which every checkpoint listed in _MODELS lives.
MODEL_PATH = 'https://huggingface.co/laion'

# Architecture name -> location of the text-encoder weights, built as
# MODEL_PATH / <repository directory> / <weight file>.
_MODELS = {
    arch: os.path.join(MODEL_PATH, repo_dir, weight_file)
    for arch, repo_dir, weight_file in (
        ("ViT-L/14", "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
        ("ViT-B/16", "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
    )
}
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the caller's dtype.

    This keeps the normalization usable with fp16 inputs: the input is
    up-cast to float32, normalized by ``nn.LayerNorm``, and the result is
    cast back to the dtype the input arrived in.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Normalize in full precision regardless of the incoming dtype.
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual.

    Parameters
    ----------
    d_model : int
        Embedding width of the block.
    n_head : int
        Number of attention heads.
    attn_mask : torch.Tensor, optional
        Mask forwarded to ``nn.MultiheadAttention`` as ``attn_mask``;
        ``None`` disables masking.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        # Feed-forward part: expand to 4x width, activate, project back.
        hidden_dim = d_model * 4
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden_dim)
        mlp_layers["gelu"] = QuickGELU()
        mlp_layers["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(mlp_layers)

        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Run self-attention on ``x`` and return only the attended output."""
        mask = self.attn_mask
        if mask is not None:
            # The mask is a plain attribute (not a buffer), so it is moved to
            # the input's dtype/device here and the converted copy is kept.
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attended

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP sub-layers with residual connections."""
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
class Transformer(nn.Module):
    """Sequential stack of ``ResidualAttentionBlock`` layers.

    Parameters
    ----------
    width : int
        Embedding width passed to every block.
    layers : int
        Number of blocks in the stack.
    heads : int
        Number of attention heads per block.
    attn_mask : torch.Tensor, optional
        Attention mask handed to every block.
    checkpoint_num : int
        When greater than zero, the stack is run through
        ``checkpoint_sequential`` using at most this many segments.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
                 checkpoint_num: int = 0):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)
        self.checkpoint_num = checkpoint_num

    def forward(self, x: torch.Tensor):
        """Run ``x`` through all blocks, optionally with activation checkpointing."""
        if self.checkpoint_num <= 0:
            return self.resblocks(x)
        # Never request more segments than there are blocks.
        segments = min(self.checkpoint_num, len(self.resblocks))
        return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
class CLIP_TEXT(nn.Module):
    """CLIP-style text encoder.

    Token ids are embedded, summed with a learned positional embedding, run
    through a causally-masked ``Transformer``, layer-normalized, and the
    feature at the end-of-text position is projected to ``embed_dim``.

    Parameters
    ----------
    embed_dim : int
        Size of the output text feature.
    context_length : int
        Sequence length the positional embedding and attention mask are built for.
    vocab_size : int
        Number of rows in the token embedding table.
    transformer_width : int
        Hidden width of the transformer.
    transformer_heads : int
        Number of attention heads per transformer block.
    transformer_layers : int
        Number of transformer blocks.
    checkpoint_num : int
        Forwarded to ``Transformer`` (activation checkpointing segments).
    """

    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        checkpoint_num: int,
    ):
        super().__init__()

        self.context_length = context_length
        self._tokenizer = _Tokenizer()

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            checkpoint_num=checkpoint_num,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        # torch.empty() hands back uninitialized memory. Give both raw
        # parameters well-defined values so the model is usable (and free of
        # NaN/garbage) even when no checkpoint is loaded or a checkpoint lacks
        # these keys. Loading pretrained weights overwrites them.
        # NOTE(review): std values mirror the reference CLIP initialization.
        nn.init.normal_(self.positional_embedding, std=0.01)
        nn.init.normal_(self.text_projection, std=transformer_width ** -0.5)

    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay.

        NOTE(review): presumably consumed by the optimizer setup elsewhere
        in the project -- confirm against the caller.
        """
        return {'token_embedding', 'positional_embedding'}

    def build_attention_mask(self):
        """Build the additive causal attention mask for ``context_length`` tokens.

        Returns
        -------
        torch.Tensor
            A ``[context_length, context_length]`` float tensor holding ``-inf``
            strictly above the diagonal and ``0`` elsewhere, so each text token
            can only attend to itself and to earlier tokens.
        """
        # Deliberately NOT wrapped in functools.lru_cache: caching a bound
        # method keys on ``self`` and would keep every model instance alive for
        # the lifetime of the process. The mask is only built once per model.
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    def tokenize(self, texts, context_length=77, truncate=True):
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length
        truncate: bool
            Whether to truncate the text in case its encoding is longer than the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
        We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.

        Raises
        ------
        RuntimeError
            If an encoded text exceeds ``context_length`` and ``truncate`` is False.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]

        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
            token_dtype = torch.long
        else:
            token_dtype = torch.int
        # Zero-initialized, so positions after the end of a text stay 0 (padding).
        result = torch.zeros(len(all_tokens), context_length, dtype=token_dtype)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if not truncate:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
                # Keep the first context_length tokens and force the last one
                # to be end-of-text so forward() can still locate it.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def forward(self, text):
        """Encode token ids into text features.

        Parameters
        ----------
        text : torch.Tensor
            Integer token ids, shape ``[batch_size, n_ctx]``.

        Returns
        -------
        torch.Tensor
            Projected feature taken at each sequence's end-of-text position,
            shape ``[batch_size, embed_dim]``.
        """
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x
def clip_text_b16(
    embed_dim=512,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    """Build the ViT-B/16 text encoder, optionally loading pretrained weights.

    Parameters
    ----------
    embed_dim, context_length, vocab_size, transformer_width,
    transformer_heads, transformer_layers, checkpoint_num
        Passed straight to ``CLIP_TEXT``.
    pretrained : bool or str
        Falsy: no weights are loaded. A string other than
        ``"bert-base-uncased"`` is used as a key into ``_MODELS``; any other
        truthy value selects the ``"ViT-B/16"`` entry.

    Returns
    -------
    CLIP_TEXT
        The model, switched to eval mode.
    """
    model = CLIP_TEXT(
        embed_dim=embed_dim,
        context_length=context_length,
        vocab_size=vocab_size,
        transformer_width=transformer_width,
        transformer_heads=transformer_heads,
        transformer_layers=transformer_layers,
        checkpoint_num=checkpoint_num,
    )
    if not pretrained:
        return model.eval()

    # Resolve which checkpoint to load.
    is_model_key = isinstance(pretrained, str) and pretrained != "bert-base-uncased"
    pretrained = _MODELS[pretrained] if is_model_key else _MODELS["ViT-B/16"]
    logger.info(f"Load pretrained weights from {pretrained}")
    state_dict = torch.load(pretrained, map_location='cpu')

    # Make the checkpoint's positional embedding match the requested length:
    # drop trailing rows when shrinking, append zero rows when growing.
    ckpt_length = state_dict["positional_embedding"].size(0)
    if ckpt_length != context_length:
        print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
        if context_length < ckpt_length:
            resized = state_dict["positional_embedding"][:context_length]
        else:
            resized = F.pad(
                state_dict["positional_embedding"],
                (0, 0, 0, context_length - ckpt_length),
                value=0,
            )
        state_dict["positional_embedding"] = resized

    # strict=False: missing/unexpected keys are reported in `message`, not raised.
    message = model.load_state_dict(state_dict, strict=False)
    print(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()
def clip_text_l14(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    """Build the ViT-L/14 text encoder, optionally loading pretrained weights.

    Parameters
    ----------
    embed_dim, context_length, vocab_size, transformer_width,
    transformer_heads, transformer_layers, checkpoint_num
        Passed straight to ``CLIP_TEXT``.
    pretrained : bool or str
        Falsy: no weights are loaded. A string other than
        ``"bert-base-uncased"`` is used as a key into ``_MODELS``; any other
        truthy value selects the ``"ViT-L/14"`` entry.

    Returns
    -------
    CLIP_TEXT
        The model, switched to eval mode.
    """
    model = CLIP_TEXT(
        embed_dim=embed_dim,
        context_length=context_length,
        vocab_size=vocab_size,
        transformer_width=transformer_width,
        transformer_heads=transformer_heads,
        transformer_layers=transformer_layers,
        checkpoint_num=checkpoint_num,
    )
    if not pretrained:
        return model.eval()

    # Resolve which checkpoint to load.
    is_model_key = isinstance(pretrained, str) and pretrained != "bert-base-uncased"
    pretrained = _MODELS[pretrained] if is_model_key else _MODELS["ViT-L/14"]
    logger.info(f"Load pretrained weights from {pretrained}")
    state_dict = torch.load(pretrained, map_location='cpu')

    # Make the checkpoint's positional embedding match the requested length:
    # drop trailing rows when shrinking, append zero rows when growing.
    ckpt_length = state_dict["positional_embedding"].size(0)
    if ckpt_length != context_length:
        print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
        if context_length < ckpt_length:
            resized = state_dict["positional_embedding"][:context_length]
        else:
            resized = F.pad(
                state_dict["positional_embedding"],
                (0, 0, 0, context_length - ckpt_length),
                value=0,
            )
        state_dict["positional_embedding"] = resized

    # strict=False: missing/unexpected keys are reported in `message`, not raised.
    message = model.load_state_dict(state_dict, strict=False)
    print(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()
def clip_text_l14_336(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
):
    """Placeholder for the ViT-L/14@336 text encoder; not implemented.

    The parameters are accepted for signature compatibility only and are
    not used.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # The model-building code that used to follow the raise was unreachable
    # (and would not have run: it omitted CLIP_TEXT's ``checkpoint_num``
    # argument and looked up a "ViT-L/14_336" checkpoint key), so it was removed.
    raise NotImplementedError(
        "clip_text_l14_336 is not implemented: no 'ViT-L/14_336' text checkpoint is available"
    )
def build_clip(config):
    """Instantiate the text encoder named by ``config.text_encoder.clip_teacher``.

    The configured string is evaluated as a Python expression and the
    resulting callable is invoked with no arguments; its return value is
    handed back to the caller.
    """
    # NOTE(security): eval() executes arbitrary code from the config value.
    # Only use configs from trusted sources; consider replacing this with an
    # explicit name -> builder lookup.
    builder = eval(config.text_encoder.clip_teacher)
    return builder()