Maikou's picture
all files first commit
9c3a994
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer
from collections import OrderedDict
from michelangelo.data.transforms import RandomResize
class AbstractEncoder(nn.Module):
embedding_dim: int
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class"):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class FrozenCLIPTextEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
tokenizer_version=None,
device="cuda",
max_length=77,
zero_embedding_radio: float = 0.1,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
self.device = device
self.max_length = max_length
self.zero_embedding_radio = zero_embedding_radio
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
transformer = CLIPModel.from_pretrained(version).text_model
for param in transformer.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = transformer
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def unconditional_embedding(self, batch_size):
empty_text = [""] * batch_size
empty_z = self.forward(empty_text)
return empty_z
def forward(self, text):
self.move()
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.clip(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
batch_size = len(text)
batch_mask = torch.rand((batch_size,))
for i in range(batch_size):
if batch_mask[i] < self.zero_embedding_radio:
text[i] = ""
return self(text)
class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
tokenizer_version=None,
device="cuda",
max_length=77,
zero_embedding_radio: float = 0.1,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
self.device = device
self.max_length = max_length
self.zero_embedding_radio = zero_embedding_radio
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
transformer = CLIPModel.from_pretrained(version).text_model
for param in transformer.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = transformer
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def unconditional_embedding(self, batch_size):
empty_text = [""] * batch_size
empty_z = self.forward(empty_text)
return empty_z
def forward(self, text):
self.move()
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.clip(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
batch_size = len(text)
batch_mask = torch.rand((batch_size,))
for i in range(batch_size):
if batch_mask[i] < self.zero_embedding_radio:
text[i] = ""
return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
zero_embedding_radio=0.1,
normalize_embedding=True,
num_projection_vector=0,
linear_mapping_bias=True,
reverse_visual_projection=False,
):
super().__init__()
self.device = device
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
clip_model = CLIPModel.from_pretrained(version)
clip_model.text_model = None
clip_model.text_projection = None
clip_model = clip_model.eval()
for param in self.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = clip_model
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
self.zero_embedding_radio = zero_embedding_radio
self.num_projection_vector = num_projection_vector
self.reverse_visual_projection = reverse_visual_projection
self.normalize_embedding = normalize_embedding
embedding_dim = (
clip_model.visual_projection.in_features
if reverse_visual_projection
else clip_model.visual_projection.out_features
)
self.embedding_dim = embedding_dim
if self.num_projection_vector > 0:
self.projection = nn.Linear(
embedding_dim,
clip_model.visual_projection.out_features * num_projection_vector,
bias=linear_mapping_bias,
)
nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def unconditional_embedding(self, batch_size):
zero = torch.zeros(
batch_size,
1,
self.embedding_dim,
device=self.device,
dtype=self.clip.visual_projection.weight.dtype,
)
if self.num_projection_vector > 0:
zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
return zero
def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
if self.reverse_visual_projection:
z = self.clip.vision_model(self.transform(image))[1]
else:
z = self.clip.get_image_features(self.transform(image))
if self.normalize_embedding:
z = z / z.norm(dim=-1, keepdim=True)
if z.ndim == 2:
z = z.unsqueeze(dim=-2)
if zero_embedding_radio > 0:
mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) < zero_embedding_radio
z = z * mask.to(z)
if self.num_projection_vector > 0:
z = self.projection(z).view(len(image), self.num_projection_vector, -1)
return z
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def encode(self, image):
self.move()
return self(image, zero_embedding_radio=self.zero_embedding_radio)
class FrozenCLIPImageGridEmbedder(AbstractEncoder):
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
zero_embedding_radio=0.1,
):
super().__init__()
self.device = device
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
clip_model: CLIPModel = CLIPModel.from_pretrained(version)
clip_model.text_model = None
clip_model.text_projection = None
clip_model = clip_model.eval()
for param in self.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = clip_model
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
self.zero_embedding_radio = zero_embedding_radio
self.embedding_dim = clip_model.vision_embed_dim
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def unconditional_embedding(self, batch_size):
zero = torch.zeros(
batch_size,
self.clip.vision_model.embeddings.num_positions,
self.embedding_dim,
device=self.device,
dtype=self.clip.visual_projection.weight.dtype,
)
return zero
def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
self.move()
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
z = self.clip.vision_model(self.transform(image)).last_hidden_state
if zero_embedding_radio > 0:
mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
z = z * mask.to(z)
return z
def encode(self, image):
return self(image, zero_embedding_radio=self.zero_embedding_radio)
class MoECLIPImageEncoder(nn.Module):
def __init__(
self,
versions,
hidden_state_dim,
num_projection_vector=8,
zero_embedding_radio=0.1,
device="cuda",
precision="fp16",
normalize=False,
clip_max=0,
transform_type="base",
argument_p=0.2,
):
super().__init__()
self.device = torch.device(device)
self.hidden_state_dim = hidden_state_dim
self.zero_embedding_radio = zero_embedding_radio
self.num_projection_vector = num_projection_vector
self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
self.normalize = normalize
self.clip_max = clip_max
if transform_type == "base":
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
elif transform_type == "crop_blur_resize":
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.RandomApply(
transforms=[
transforms.RandomResizedCrop(
size=224,
scale=(0.8, 1.0),
ratio=(0.99, 1.01),
interpolation=transforms.InterpolationMode.BICUBIC,
),
],
p=argument_p,
),
transforms.RandomApply(
transforms=[
transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
],
p=argument_p,
),
transforms.RandomApply(
transforms=[
RandomResize(size=224, resize_radio=(0.2, 1)),
],
p=argument_p,
),
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
else:
raise ValueError(f"invalid {transform_type=}")
if isinstance(versions, str):
versions = (versions,)
# 如果直接把clips定位为当前类的子module,1. 会在保存ckp时存无用的多个权重。 2. pl会调用to,导致layer_norm的权重也被转换成fp16
clips = OrderedDict()
for v in versions:
# 因为clips不是子module,直接指定device="cuda"会错误地导致clip模型权重都被放到cuda:0上。
clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
delattr(clips[v], "transformer")
clips[v].eval()
clips[v].requires_grad_(False)
self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
if self.num_projection_vector == 0:
self.projection = nn.Identity()
else:
self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
self.projection.to(dtype=self.dtype)
nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
self.clips = clips
self._move_flag = False
def move(self):
if self._move_flag:
return
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.type(self.dtype)
if l.bias is not None:
l.bias.data = l.bias.data.type(self.dtype)
if isinstance(l, nn.MultiheadAttention):
for attr in [
*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
"in_proj_bias",
"bias_k",
"bias_v",
]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.type(self.dtype)
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.type(self.dtype)
model.apply(_convert_weights_to_fp16)
for k in self.clips:
self.clips[k].to(self.device)
convert_weights(self.clips[k]) # fp32 -> self.dtype
self._move_flag = True
def unconditional_embedding(self, batch_size=None):
zero = torch.zeros(
batch_size,
self.clips_hidden_dim,
device=self.device,
dtype=self.dtype,
)
if self.num_projection_vector > 0:
zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
return zero
def convert_embedding(self, z):
if self.num_projection_vector > 0:
z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
return z
def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = self.transform(image)
with torch.no_grad():
embs = []
for v in self.clips:
x = self.clips[v].encode_image(image)
if self.normalize:
x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
# clip_max only works with normalization
if self.clip_max > 0:
x = x.clamp(-self.clip_max, self.clip_max)
embs.append(x)
z = torch.cat(embs, dim=-1)
if self.normalize:
z /= z.size(-1) ** 0.5
if zero_embedding_radio > 0:
mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
z = z + mask.to(z)
if self.num_projection_vector > 0:
z = self.projection(z).view(len(image), self.num_projection_vector, -1)
return z
def encode(self, image):
self.move()
return self(image, zero_embedding_radio=self.zero_embedding_radio)