Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import random | |
import kiui | |
from kiui.op import recenter | |
import torchvision | |
import torchvision.transforms.v2 | |
from contextlib import nullcontext | |
from functools import partial | |
from typing import Dict, List, Optional, Tuple, Union | |
from pdb import set_trace as st | |
import kornia | |
import numpy as np | |
import open_clip | |
import torch | |
import torch.nn as nn | |
from einops import rearrange, repeat | |
from omegaconf import ListConfig | |
from torch.utils.checkpoint import checkpoint | |
from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer, | |
T5EncoderModel, T5Tokenizer) | |
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer | |
from ...modules.diffusionmodules.model import Encoder | |
from ...modules.diffusionmodules.openaimodel import Timestep | |
from ...modules.diffusionmodules.util import (extract_into_tensor, | |
make_beta_schedule) | |
from ...modules.distributions.distributions import DiagonalGaussianDistribution | |
from ...util import (append_dims, autocast, count_params, default, | |
disabled_train, expand_dims_like, instantiate_from_config) | |
from dit.dit_models_xformers import CaptionEmbedder, approx_gelu, t2i_modulate | |
class AbstractEmbModel(nn.Module):
    """Base class for conditioning embedders.

    Exposes three configuration values -- ``is_trainable``, ``ucg_rate`` and
    ``input_key`` -- as properties backed by private attributes that are
    initialised to ``None``. Each property has a getter, a setter and a
    deleter.

    The accessors must be properties (not plain methods): callers read and
    assign them as attributes, and compare ``input_key`` against ``None``.
    """

    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    @property
    def is_trainable(self) -> bool:
        """Value stored by the ``is_trainable`` setter (``None`` until set)."""
        return self._is_trainable

    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        """Value stored by the ``ucg_rate`` setter (``None`` until set)."""
        return self._ucg_rate

    @property
    def input_key(self) -> str:
        """Value stored by the ``input_key`` setter (``None`` until set)."""
        return self._input_key

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key
class GeneralConditioner(nn.Module):
    """Runs a list of embedders over a batch and merges their outputs.

    Every embedding tensor is routed to an output key chosen by its rank
    (``OUTPUT_DIM2KEYS``); tensors that share a key are concatenated along
    the dimension given by ``KEY2CATDIM``.
    """

    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig]):
        """Instantiate and configure one embedder per entry of ``emb_models``.

        Each config must yield an ``AbstractEmbModel`` and must contain either
        ``input_key`` or ``input_keys``. Optional config entries read here:
        ``is_trainable`` (default ``False``), ``ucg_rate`` (default ``0.0``)
        and ``legacy_ucg_value`` (default ``None``).

        Raises:
            KeyError: if a config has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Frozen embedder: pin train(), switch off gradients, and put
                # the module in eval mode.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )
            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()
            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel,
                             batch: Dict) -> Dict:
        """Randomly overwrite entries of ``batch[embedder.input_key]``.

        Each entry is replaced, in place, by ``embedder.legacy_ucg_val`` with
        probability ``embedder.ucg_rate`` (drawn from ``embedder.ucg_prng``).
        Returns the same ``batch`` object.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        entries = batch[embedder.input_key]
        for i in range(len(entries)):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                entries[i] = val
        return batch

    def forward(self,
                batch: Dict,
                force_zero_embeddings: Optional[List] = None) -> Dict:
        """Embed ``batch`` with every embedder and merge the results.

        Args:
            batch: mapping from input key to the value passed to the embedder.
            force_zero_embeddings: ``input_key`` values whose embeddings are
                replaced by zeros.

        Returns:
            Dict keyed by the values of ``OUTPUT_DIM2KEYS``.

        Raises:
            KeyError: if an embedder has neither a non-``None`` ``input_key``
                nor an ``input_keys`` attribute.
        """
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key
                                                       is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(
                        *[batch[k] for k in embedder.input_keys])
                else:
                    # Without this branch the previous embedder's output would
                    # be silently reused (or the name would be unbound).
                    raise KeyError(
                        f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                    )
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample Bernoulli keep-mask broadcast to emb's rank.
                    emb = (expand_dims_like(
                        torch.bernoulli(
                            (1.0 - embedder.ucg_rate) *
                            torch.ones(emb.shape[0], device=emb.device)),
                        emb,
                    ) * emb)
                if (hasattr(embedder, "input_key")
                        and embedder.input_key in force_zero_embeddings):
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    output[out_key] = torch.cat((output[out_key], emb),
                                                self.KEY2CATDIM[out_key])
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: Dict,
        batch_uc: Optional[Dict] = None,
        force_uc_zero_embeddings: Optional[List[str]] = None,
        force_cond_zero_embeddings: Optional[List[str]] = None,
    ):
        """Return ``(c, uc)`` computed with every ``ucg_rate`` forced to 0.

        ``c`` is ``self(batch_c, force_cond_zero_embeddings)`` and ``uc`` is
        ``self(batch_uc or batch_c, force_uc_zero_embeddings)``. The original
        ``ucg_rate`` of each embedder is restored afterwards, including when
        one of the forward passes raises.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        try:
            for embedder in self.embedders:
                embedder.ucg_rate = 0.0  # ! force no drop during inference
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(batch_c if batch_uc is None else batch_uc,
                      force_uc_zero_embeddings)
        finally:
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
class InceptionV3(nn.Module):
    """Thin wrapper over the pytorch-fid Inception port
    (https://github.com/mseitzer/pytorch-fid) that squeezes the result when
    the wrapped model returns exactly one output."""

    def __init__(self, normalize_input=False, **kwargs):
        super().__init__()
        # Imported lazily so the dependency is only needed when this class is used.
        from pytorch_fid import inception

        # resize_input is always forced on, overriding any caller value.
        kwargs["resize_input"] = True
        self.model = inception.InceptionV3(
            normalize_input=normalize_input, **kwargs)

    def forward(self, inp):
        features = self.model(inp)
        if len(features) != 1:
            return features
        return features[0].squeeze()
class IdentityEncoder(AbstractEmbModel):
    """Pass-through embedder: both entry points hand the input back unchanged."""

    def forward(self, x):
        return x

    def encode(self, x):
        return x
class ClassEmbedder(AbstractEmbModel):
    """Looks class indices up in an ``nn.Embedding`` table."""

    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        """
        Args:
            embed_dim: width of each embedding vector.
            n_classes: number of rows in the embedding table.
            add_sequence_dim: when true, ``forward`` inserts a length-1 axis
                at position 1 of its result.
        """
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    def forward(self, c):
        """Embed the index tensor ``c``."""
        c = self.embedding(c)
        if self.add_sequence_dim:
            c = c[:, None, :]
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return ``{key: LongTensor of shape (bs,)}`` filled with the
        unconditional class index ``n_classes - 1``.

        The dictionary key is ``self.key`` when such an attribute exists and
        otherwise ``self.input_key``. Nothing in this class assigns
        ``self.key``, so reading it unconditionally raised ``AttributeError``.
        """
        # The last index of the table is used as the "unconditional" class.
        uc_class = self.n_classes - 1
        key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        uc = torch.full((bs, ), uc_class, dtype=torch.long, device=device)
        return {key: uc}
class ClassEmbedderForMultiCond(ClassEmbedder):
    """``ClassEmbedder`` variant that works on a whole batch dict."""

    def forward(self, batch, key=None, disable_dropout=False):
        """Replace ``batch[key]`` with its class embedding, in place.

        Args:
            batch: dict holding the class indices under ``key``. If the value
                is a list, only its first element is embedded and the result
                is wrapped back into a one-element list.
            key: entry to embed. Defaults to ``self.key`` when that attribute
                exists, otherwise to ``self.input_key``.
            disable_dropout: accepted for interface compatibility; the parent
                ``forward`` takes only the index tensor, so it is not used.

        Returns:
            The same ``batch`` object, with ``batch[key]`` overwritten.
        """
        if key is None:
            # Resolved lazily: nothing in the class hierarchy assigns `key`.
            key = getattr(self, "key", None)
            if key is None:
                key = self.input_key
        out = batch
        islist = isinstance(batch[key], list)
        if islist:
            batch[key] = batch[key][0]
        # The parent forward takes exactly one argument (the index tensor).
        c_out = super().forward(batch[key])
        out[key] = [c_out] if islist else c_out
        return out
class FrozenT5Embedder(AbstractEmbModel):
    """Text embedder built on the T5 encoder."""

    def __init__(
            self,
            version="google/t5-v1_1-xxl",
            device="cuda",
            max_length=77,
            freeze=True,
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and turn off gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to ``max_length``) and return
        the encoder's ``last_hidden_state``."""
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        input_ids = encoded["input_ids"].to(self.device)
        # The encoder is run with CUDA autocast explicitly disabled.
        with torch.autocast("cuda", enabled=False):
            result = self.transformer(input_ids=input_ids)
        return result.last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenByT5Embedder(AbstractEmbModel):
    """
    Text embedder built on the ByT5 encoder. Is character-aware.
    """

    def __init__(
            self,
            version="google/byt5-base",
            device="cuda",
            max_length=77,
            freeze=True,
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and turn off gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to ``max_length``) and return
        the encoder's ``last_hidden_state``."""
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        input_ids = encoded["input_ids"].to(self.device)
        # The encoder is run with CUDA autocast explicitly disabled.
        with torch.autocast("cuda", enabled=False):
            result = self.transformer(input_ids=input_ids)
        return result.last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenCLIPEmbedder(AbstractEmbModel):
    """Text embedder built on the huggingface CLIP text model."""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            # A hidden-state index is mandatory in this mode.
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the text model in eval mode and turn off gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the representation selected by
        ``self.layer``; with ``return_pooled`` the pooler output is returned
        as a second element."""
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        input_ids = encoded["input_ids"].to(self.device)
        want_hidden = self.layer == "hidden"
        result = self.transformer(input_ids=input_ids,
                                  output_hidden_states=want_hidden)
        if self.layer == "last":
            z = result.last_hidden_state
        elif self.layer == "pooled":
            z = result.pooler_output[:, None, :]
        else:
            z = result.hidden_states[self.layer_idx]
        if not self.return_pooled:
            return z
        return z, result.pooler_output

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    Text embedder built on the OpenCLIP text transformer
    """

    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the text side is used; drop the vision tower.
        del clip_model.visual
        self.model = clip_model
        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        # "pooled" passes the LAYERS assert above but has no index here.
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        self.layer_idx = layer_to_idx[self.layer]
        self.legacy = legacy

    def freeze(self):
        """Put the model in eval mode and turn off gradients for all parameters."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize and encode ``text``.

        Legacy mode returns the single encoded tensor. Otherwise the selected
        layer is returned, together with the pooled embedding when
        ``return_pooled`` is set (which is asserted to be non-legacy).
        """
        token_ids = open_clip.tokenize(text)
        z = self.encode_with_transformer(token_ids.to(self.device))
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        if self.legacy:
            return z
        return z[self.layer]

    def encode_with_transformer(self, text):
        hidden = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.model.positional_embedding
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        layers = self.text_transformer_forward(
            hidden, attn_mask=self.model.attn_mask)
        if not self.legacy:
            # `layers` stays a dict; the pooled vector is added to it.
            normed_last = self.model.ln_final(layers["last"])
            layers["pooled"] = self.pool(normed_last, text)
            return layers
        return self.model.ln_final(layers[self.layer])

    def pool(self, x, text):
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(x.shape[0])
        eot_index = text.argmax(dim=-1)
        return x[batch_index, eot_index] @ self.model.text_projection

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run all residual blocks and return ``{"penultimate", "last"}``
        activations, both permuted back to NLD."""
        blocks = self.model.transformer.resblocks
        final_index = len(blocks) - 1
        collected = {}
        for idx, block in enumerate(blocks):
            if idx == final_index:
                # Input of the final block == output of the penultimate one.
                collected["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            use_checkpoint = (self.model.transformer.grad_checkpointing
                              and not torch.jit.is_scripting())
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        collected["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return collected

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    """Text embedder built on the OpenCLIP text transformer; returns the
    layer-normed output of the last or penultimate residual block."""

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
    ):
        super().__init__()
        assert layer in self.LAYERS
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version)
        # Only the text side is used; drop the vision tower.
        del clip_model.visual
        self.model = clip_model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx = number of trailing residual blocks to skip.
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        self.layer_idx = layer_to_idx[self.layer]

    def freeze(self):
        """Put the model in eval mode and turn off gradients for all parameters."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        token_ids = open_clip.tokenize(text)
        return self.encode_with_transformer(token_ids.to(self.device))

    def encode_with_transformer(self, text):
        hidden = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.model.positional_embedding
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.text_transformer_forward(
            hidden, attn_mask=self.model.attn_mask)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        return self.model.ln_final(hidden)

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the residual blocks, stopping ``layer_idx`` blocks before the end."""
        blocks = self.model.transformer.resblocks
        stop_at = len(blocks) - self.layer_idx
        for idx, block in enumerate(blocks):
            if idx == stop_at:
                break
            use_checkpoint = (self.model.transformer.grad_checkpointing
                              and not torch.jit.is_scripting())
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Image embedder built on the OpenCLIP vision transformer
    """

    def __init__(
        self,
        # arch="ViT-H-14",
        # version="laion2b_s32b_b79k",
        arch="ViT-L-14",
        # version="laion2b_s32b_b82k",
        version="openai",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        init_device=None,
    ):
        super().__init__()
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version,
        )
        # Only the vision side is used; drop the text transformer.
        del clip_model.transformer
        self.model = clip_model
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        # Padding takes precedence over repeating.
        self.repeat_to_max_len = repeat_to_max_len and (
            not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.antialias = antialias
        self.register_buffer(
            "mean",
            torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
            persistent=False,
        )
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224, apply ``(x + 1) / 2`` and normalise with the
        registered mean/std buffers."""
        resized = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # shift/scale before renormalising with the CLIP statistics
        unit_range = (resized + 1.0) / 2.0
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def freeze(self):
        """Put the model in eval mode and turn off gradients for all parameters."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, image, no_dropout=False):
        """Embed ``image``.

        The return shape depends on the configuration: ``(tokens, z)`` with
        ``output_tokens``; ``(repeated, z)`` with ``repeat_to_max_len``;
        ``(padded, padded[:, 0])`` with ``pad_to_max_len``; otherwise ``z``.
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)

        apply_ucg = (self.ucg_rate > 0.0 and not no_dropout
                     and not (self.max_crops > 0))
        if apply_ucg:
            # Independent per-sample Bernoulli keep-masks for z and tokens.
            z_mask = torch.bernoulli(
                (1.0 - self.ucg_rate) *
                torch.ones(z.shape[0], device=z.device))
            z = z_mask[:, None] * z
            if tokens is not None:
                token_mask = torch.bernoulli(
                    (1.0 - self.ucg_rate) *
                    torch.ones(tokens.shape[0], device=tokens.device))
                tokens = expand_dims_like(token_mask, tokens) * tokens

        if self.unsqueeze_dim:
            z = z[:, None, :]

        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z

        if self.repeat_to_max_len:
            z_seq = z[:, None, :] if z.dim() == 2 else z
            return repeat(z_seq, "b 1 d -> b n d", n=self.max_length), z

        if self.pad_to_max_len:
            assert z.dim() == 3
            padding = torch.zeros(
                z.shape[0],
                self.max_length - z.shape[1],
                z.shape[2],
                device=z.device,
            )
            z_pad = torch.cat((z, padding), 1)
            return z_pad, z_pad[:, 0, ...]

        return z

    def encode_with_vision_transformer(self, img):
        # if self.max_crops > 0:
        #    img = self.preprocess_by_cropping(img)
        if img.dim() == 5:
            # 5-D input carries a crop axis that is folded into the batch.
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)

        if self.output_tokens:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        else:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None

        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            crop_mask = torch.bernoulli(
                (1.0 - self.ucg_rate) *
                torch.ones(x.shape[0], x.shape[1], 1, device=x.device))
            x = crop_mask * x
            if tokens is not None:
                tokens = rearrange(tokens,
                                   "(b n) t d -> b t (n d)",
                                   n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message.")

        if not self.output_tokens:
            return x
        return x, tokens

    def encode(self, text):
        return self(text)
# dino-v2 embedder | |
class FrozenDinov2ImageEmbedder(AbstractEmbModel):
    """
    Image embedder built on a DINOv2 backbone loaded through ``torch.hub``.

    ``forward`` returns the normalised patch tokens, and additionally the
    normalised class token when ``output_cls`` is set.
    """

    def __init__(
        self,
        arch="vitl",
        version="dinov2",  # by default
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        output_cls=False,
        init_device=None,
    ):
        """
        Args:
            arch, version: select the hub entry ``{version}_{arch}14_reg``
                from the repo ``facebookresearch/{version}``.
            device, max_length, antialias, ucg_rate, unsqueeze_dim,
                output_tokens: stored on the instance.
            freeze: when true, ``freeze()`` is called right after loading.
            repeat_to_max_len, num_image_crops: stored as
                ``repeat_to_max_len`` / ``max_crops`` / ``pad_to_max_len``.
            output_cls: make ``forward`` return ``(tokens, cls)``.
            init_device: device the loaded model is moved to (default "cpu").
        """
        super().__init__()
        # with registers better performance. vitl and vitg similar. Since fixed, load the best one.
        hub_entry = f"{version}_{arch}14_reg"
        self.model = torch.hub.load(
            f"facebookresearch/{version}",
            hub_entry,
            pretrained=True).to(torch.device(default(init_device, "cpu")))

        if freeze:
            self.freeze()

        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (
            not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        self.antialias = antialias

        # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/data/transforms.py#L41
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        self.register_buffer("mean",
                             torch.Tensor(IMAGENET_DEFAULT_MEAN),
                             persistent=False)
        self.register_buffer("std",
                             torch.Tensor(IMAGENET_DEFAULT_STD),
                             persistent=False)

        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.output_tokens = output_tokens
        self.output_cls = output_cls

    def preprocess(self, x):
        """Resize to 224x224, apply ``(x + 1) / 2`` and normalise with the
        registered mean/std buffers."""
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        x = (x + 1.0) / 2.0
        # renormalize with the ImageNet statistics registered in __init__
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Put the model in eval mode and turn off gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def _model_forward(self, *args, **kwargs):
        """Hook around the backbone call so subclasses can change its inputs."""
        return self.model(*args, **kwargs)

    def encode_with_vision_transformer(self, img, **kwargs):
        """Preprocess ``img`` and run the backbone with ``is_training=True``.

        A 5-D input has its second axis folded into the batch first. Extra
        keyword arguments are forwarded to ``_model_forward`` in both
        branches.

        Returns:
            ``x_norm_patchtokens`` when ``output_cls`` is false, otherwise
            ``(x_norm_clstoken, x_norm_patchtokens)``.
        """
        if img.dim() == 5:
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)

        # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L326
        # is_training=True makes the backbone return its dict of features.
        dino_ret_dict = self._model_forward(img, is_training=True, **kwargs)
        if not self.output_cls:
            return dino_ret_dict['x_norm_patchtokens']
        return (dino_ret_dict['x_norm_clstoken'],
                dino_ret_dict['x_norm_patchtokens'])

    def forward(self, image, no_dropout=False, **kwargs):
        """Embed ``image``.

        Args:
            image: input batch; results are cast to its dtype.
            no_dropout: skip the ``ucg_rate`` dropout.
            **kwargs: forwarded to ``encode_with_vision_transformer``.

        Returns:
            ``tokens`` or, with ``output_cls``, ``(tokens, cls_token)``.
        """
        encoded = self.encode_with_vision_transformer(image, **kwargs)

        z = None
        if self.output_cls:
            # The encoder returned (cls token, patch tokens).
            z, tokens = encoded
            z = z.to(image.dtype)
        else:
            tokens = encoded
        tokens = tokens.to(image.dtype)

        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            # Independent per-sample Bernoulli keep-masks for z and tokens.
            if z is not None:
                z = (torch.bernoulli(
                    (1.0 - self.ucg_rate) *
                    torch.ones(z.shape[0], device=z.device))[:, None] * z)
            tokens = (expand_dims_like(
                torch.bernoulli(
                    (1.0 - self.ucg_rate) *
                    torch.ones(tokens.shape[0], device=tokens.device)),
                tokens,
            ) * tokens)

        if self.output_cls:
            return tokens, z
        return tokens
class FrozenDinov2ImageEmbedderMVPlucker(FrozenDinov2ImageEmbedder):
    """Multi-view variant of the DINOv2 embedder that feeds camera rays in.

    The backbone's patch embedding is replaced by a 9-channel ``PatchEmbed``
    and ``_model_forward`` concatenates the image with a 6-channel Plucker
    ray map (built from the camera vector) along the channel axis.
    """

    def __init__(
        self,
        arch="vitl",
        version="dinov2",  # by default
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        output_cls=False,
        init_device=None,
        # mv cond settings
        n_cond_frames=4,  # number of condition views
        enable_bf16=False,
        modLN=False,
        aug_c=False,
    ):
        # NOTE(review): `modLN` is accepted but not referenced anywhere in
        # this class.
        super().__init__(
            arch,
            version,
            device,
            max_length,
            freeze,
            antialias,
            ucg_rate,
            unsqueeze_dim,
            repeat_to_max_len,
            num_image_crops,
            output_tokens,
            output_cls,
            init_device,
        )
        self.n_cond_frames = n_cond_frames
        self.dtype = torch.bfloat16 if enable_bf16 else torch.float32
        self.enable_bf16 = enable_bf16
        self.aug_c = aug_c
        # ! proj c_cond to features
        self.reso_encoder = 224
        orig_patch_embed_weight = self.model.patch_embed.state_dict()
        # ! 9-d input
        # The new patch embedding is created after the parent constructor has
        # (optionally) called freeze(), so freeze() has not touched its
        # parameters.
        with torch.no_grad():
            new_patch_embed = PatchEmbed(img_size=224,
                                         patch_size=14,
                                         in_chans=9,
                                         embed_dim=self.model.embed_dim)
            # zero init first
            nn.init.constant_(new_patch_embed.proj.weight, 0)
            nn.init.constant_(new_patch_embed.proj.bias, 0)
            # load pre-trained first 3 layers weights, bias into the new patch_embed
            new_patch_embed.proj.weight[:, :3].copy_(orig_patch_embed_weight['proj.weight'])
            new_patch_embed.proj.bias[:].copy_(orig_patch_embed_weight['proj.bias'])
            self.model.patch_embed = new_patch_embed  # xyz in the front
        # self.scale_jitter_aug = torchvision.transforms.v2.ScaleJitter(target_size=(self.reso_encoder, self.reso_encoder), scale_range=(0.5, 1.5))

    def scale_jitter_aug(self, x):
        """Downsample ``x`` to a random size and upsample it back.

        The intermediate size is ``max(0.5, random.random())`` times
        ``x.shape[2]``; both resizes are bilinear with antialiasing.
        """
        inp_size = x.shape[2]
        # aug_size = torch.randint(low=50, high=100, size=(1,)) / 100 * inp_size
        aug_size = int(max(0.5, random.random()) * inp_size)
        # st()
        x = torch.nn.functional.interpolate(x,
                                            size=aug_size,
                                            mode='bilinear',
                                            antialias=True)
        x = torch.nn.functional.interpolate(x,size=inp_size,
                                            mode='bilinear', antialias=True)
        return x

    def gen_rays(self, c):
        """Build per-pixel ray origins and directions for one camera.

        ``c`` is a flat vector: the first 16 entries are reshaped into a 4x4
        camera-to-world matrix and the remainder is read as intrinsics, with
        indices 0, 4, 2, 5 used as fx, fy, cx, cy. Pixel centres are
        normalised by the grid size before the intrinsics are applied.

        Side effect: sets ``self.h`` and ``self.w`` to ``self.reso_encoder``.

        Returns:
            ``(origins, dirs)``, each viewed as ``(h, w, 3)``.
        """
        # Generate rays
        intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
        self.h = self.reso_encoder
        self.w = self.reso_encoder
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32, device=c.device) + 0.5,
            torch.arange(self.w, dtype=torch.float32, device=c.device) + 0.5,
            indexing='ij')
        # normalize to 0-1 pixel range
        yy = yy / self.h
        xx = xx / self.w
        # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
        cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
            0], intrinsics[4]
        # cx *= self.w
        # cy *= self.h
        # f_x = f_y = fx * h / res_raw
        # c2w = torch.from_numpy(c2w).float()
        c2w = c2w.float()
        xx = (xx - cx) / fx
        yy = (yy - cy) / fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(-1, 3, 1)
        del xx, yy, zz
        # st()
        # Rotate the unit directions by the upper-left 3x3 block of c2w.
        dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
        # Every pixel shares the translation column of c2w as its origin.
        origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
        origins = origins.view(self.h, self.w, 3)
        dirs = dirs.view(self.h, self.w, 3)
        return origins, dirs

    def get_plucker_ray(self, c):
        """Return stacked Plucker ray maps, one ``(6, h, w)`` map per row of ``c``.

        Each map is ``cross(origin, direction)`` concatenated with the
        direction along the last axis, then permuted to channel-first.
        """
        rays_plucker = []
        for idx in range(c.shape[0]):
            rays_o, rays_d = self.gen_rays(c[idx])
            rays_plucker.append(
                torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
                          dim=-1).permute(2, 0, 1))  # [h, w, 6] -> 6,h,w
        rays_plucker = torch.stack(rays_plucker, 0)
        return rays_plucker

    def _model_forward(self, x, plucker_c, *args, **kwargs):
        """Concatenate ``x`` and ``plucker_c`` on dim 1 and run the backbone
        under CUDA autocast with ``self.dtype``. Positional ``*args`` are
        accepted but not forwarded."""
        with torch.cuda.amp.autocast(dtype=self.dtype, enabled=True):
            x = torch.cat([x, plucker_c], dim=1).to(self.dtype)
            return self.model(x, **kwargs)

    def preprocess(self, x):
        """Compute a jittered copy of ``x`` when ``ucg_rate > 0`` and then
        return the parent's ``preprocess`` of the *original* ``x``.

        NOTE(review): ``x_aug`` is computed but never used -- the return
        statement passes ``x``. If the augmentation is meant to be active this
        should probably pass ``x_aug``; confirm intent. The computation still
        consumes random numbers, so removing it would change RNG state.
        """
        # add gaussian noise and rescale augmentation
        if self.ucg_rate > 0.0:
            # 1 means maintain the input x
            enable_drop_flag = torch.bernoulli(
                (1.0 - self.ucg_rate) *
                torch.ones(x.shape[0], device=x.device))[:, None, None, None]  # broadcast to B,1,1,1
            # * add random downsample & upsample
            # rescaled_x = self.downsample_upsample(x)
            # torchvision.utils.save_image(x, 'tmp/x.png', normalize=True, value_range=(-1,1))
            x_aug = self.scale_jitter_aug(x)
            # torchvision.utils.save_image(x_aug, 'tmp/rescale-x.png', normalize=True, value_range=(-1,1))
            # x_aug = x * enable_drop_flag + (1-enable_drop_flag) * x_aug
            # * gaussian noise jitter
            # force linear_weight > 0.24
            # linear_weight = torch.max(enable_drop_flag, torch.max(torch.rand_like(enable_drop_flag), 0.25 * torch.ones_like(enable_drop_flag), dim=0, keepdim=True), dim=0, keepdim=True)
            # Per-sample blend weight, clamped from below at 0.5.
            gaussian_jitter_scale, jitter_lb = torch.rand_like(enable_drop_flag), 0.5 * torch.ones_like(enable_drop_flag)
            gaussian_jitter_scale = torch.where(gaussian_jitter_scale>jitter_lb, gaussian_jitter_scale, jitter_lb)
            # torchvision.utils.save_image(x, 'tmp/aug-x.png', normalize=True, value_range=(-1,1))
            x_aug = gaussian_jitter_scale * x_aug + (1 - gaussian_jitter_scale) * torch.randn_like(x).clamp(-1,1)
            x_aug = x * enable_drop_flag + (1-enable_drop_flag) * x_aug
            # torchvision.utils.save_image(x_aug, 'tmp/final-x.png', normalize=True, value_range=(-1,1))
            # st()
        return super().preprocess(x)

    def random_rotate_c(self, c):
        """Left-multiply the camera-to-world part of ``c`` by a random rotation.

        The angle is uniform in ``[0, 2*pi)``; the axis is z or y with equal
        probability. The intrinsics tail of ``c`` is passed through unchanged.
        """
        intrinsics, c2ws = c[16:], c[:16].reshape(4, 4)
        # https://github.com/TencentARC/InstantMesh/blob/34c193cc96eebd46deb7c48a76613753ad777122/src/data/objaverse.py#L195
        degree = np.random.uniform(0, math.pi * 2)
        # random rotation along z axis
        if random.random() > 0.5:
            rot = torch.tensor([
                [np.cos(degree), -np.sin(degree), 0, 0],
                [np.sin(degree), np.cos(degree), 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]).to(c2ws)
        else:
            # random rotation along y axis
            rot = torch.tensor([
                [np.cos(degree), 0, np.sin(degree), 0],
                [0, 1, 0, 0],
                [-np.sin(degree), 0, np.cos(degree), 0],
                [0, 0, 0, 1],
            ]).to(c2ws)
        c2ws = torch.matmul(rot, c2ws)
        return torch.cat([c2ws.reshape(-1), intrinsics])

    def forward(self, img_c, no_dropout=False):
        """Embed the first ``n_cond_frames`` views of ``img_c['img']``
        together with Plucker rays built from ``img_c['c']``.

        With ``aug_c`` each camera is randomly rotated with probability 0.4.
        NOTE(review): that rotation writes into ``img_c['c']`` in place --
        confirm callers do not reuse the tensor afterwards.
        NOTE(review): the final rearrange is applied directly to the parent's
        return value, which is a tuple when ``output_cls`` is set -- TODO
        confirm that combination is unsupported.
        """
        mv_image, c = img_c['img'], img_c['c']
        if self.aug_c:
            for idx_b in range(c.shape[0]):
                for idx_v in range(c.shape[1]):
                    if random.random() > 0.6:
                        c[idx_b, idx_v] = self.random_rotate_c(c[idx_b, idx_v])
        # plucker_c = self.get_plucker_ray(
        #     rearrange(c[:, 1:1 + self.n_cond_frames], "b t ... -> (b t) ..."))
        plucker_c = self.get_plucker_ray(
            rearrange(c[:, :self.n_cond_frames], "b t ... -> (b t) ..."))
        # mv_image_tokens = super().forward(mv_image[:, 1:1 + self.n_cond_frames],
        mv_image_tokens = super().forward(mv_image[:, :self.n_cond_frames],
                                          plucker_c=plucker_c,
                                          no_dropout=no_dropout)
        # Un-fold the view axis that was merged into the batch.
        mv_image_tokens = rearrange(mv_image_tokens,
                                    "(b t) ... -> b t ...",
                                    t=self.n_cond_frames)
        return mv_image_tokens
def make_2tuple(x):
    """Normalise ``x`` to a pair.

    A tuple is returned as-is after asserting it has length 2; an ``int`` is
    duplicated into ``(x, x)``. Anything else fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)
    assert len(x) == 2
    return x
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer (called as ``norm_layer(embed_dim)``);
            ``nn.Identity`` is used when it is falsy.
        flatten_embedding: When false, ``forward`` reshapes its output to
            (B, H', W', D) instead of (B, N, D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )
        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding
        # Non-overlapping patches: kernel size equals stride.
        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_HW,
                              stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Project ``x`` (B, C, H, W) into patch embeddings.

        H and W must be multiples of the patch size (asserted); they are not
        checked against ``img_size``.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size
        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate FLOPs at the configured ``img_size``.

        Counts the projection, plus one op per output element for the norm
        layer. The norm term is skipped when ``self.norm`` is the
        ``nn.Identity`` placeholder; a plain ``is not None`` test is always
        true because ``__init__`` never leaves ``norm`` as ``None``.
        """
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (
            self.patch_size[0] * self.patch_size[1])
        if self.norm is not None and not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
class FrozenDinov2ImageEmbedderMV(FrozenDinov2ImageEmbedder):
    """DINOv2 image embedder for multi-view conditioning with camera modulation.

    On top of the parent embedder this class adds:

    * ``cam_proj``: embeds the per-view camera vector to the backbone width;
    * ``model.modLN_modulation``: a zero-initialised SiLU + Linear head that
      produces four modulation vectors from the camera embedding;
    * a zero-initialised ``scale_shift_table`` parameter on every backbone
      block, added to those modulation vectors.

    ``_model_forward`` re-implements the backbone forward pass so that the
    modulation is applied inside every block.
    """

    def __init__(
            self,
            arch="vitl",
            version="dinov2",  # by default
            device="cuda",
            max_length=77,
            freeze=True,
            antialias=True,
            ucg_rate=0.0,
            unsqueeze_dim=False,
            repeat_to_max_len=False,
            num_image_crops=0,
            output_tokens=False,
            output_cls=False,
            init_device=None,
            # mv cond settings
            n_cond_frames=4,  # number of condition views
            enable_bf16=False,
            modLN=False,  # accepted but not read anywhere in this class
    ):
        super().__init__(
            arch,
            version,
            device,
            max_length,
            freeze,
            antialias,
            ucg_rate,
            unsqueeze_dim,
            repeat_to_max_len,
            num_image_crops,
            output_tokens,
            output_cls,
            init_device,
        )
        self.n_cond_frames = n_cond_frames
        self.dtype = torch.bfloat16 if enable_bf16 else torch.float32
        self.enable_bf16 = enable_bf16

        # ! proj c_cond to features
        hidden_size = self.model.embed_dim  # 768 for vit-b
        # Input width 25 -- presumably a flattened 4x4 pose plus intrinsics;
        # TODO confirm against the data pipeline.
        self.cam_proj = CaptionEmbedder(25, hidden_size, act_layer=approx_gelu)

        # ! single-modLN
        self.model.modLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 4 * hidden_size, bias=True))

        # zero-init modLN
        nn.init.constant_(self.model.modLN_modulation[-1].weight, 0)
        nn.init.constant_(self.model.modLN_modulation[-1].bias, 0)

        # inject modLN to dino block
        for block in self.model.blocks:
            block.scale_shift_table = nn.Parameter(torch.zeros(
                4, hidden_size))  # zero init also

    def _model_forward(self, x, *args, **kwargs):
        """Run the backbone with camera-conditioned modulation in every block.

        Args:
            x: Image batch handed to ``self.model.prepare_tokens_with_masks``.
            **kwargs: Must contain ``c``, the camera condition fed to
                ``self.cam_proj``.

        Returns:
            Dict with ``"x_norm_clstoken"`` (token 0 after the final norm)
            and ``"x_norm_patchtokens"`` (tokens after the class and register
            tokens).
        """
        # re-define model forward, finetune dino-v2.
        assert self.training

        # ``torch.autocast("cuda", ...)`` is the non-deprecated spelling of
        # ``torch.cuda.amp.autocast(...)`` and matches the rest of this file.
        with torch.autocast("cuda", dtype=self.dtype, enabled=True):
            x = self.model.prepare_tokens_with_masks(x, masks=None)
            B, N, C = x.shape

            c = kwargs.get('c')
            c = self.cam_proj(c)
            cond = self.model.modLN_modulation(c)

            # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/block.py#L89
            for blk in self.model.blocks:  # inject modLN
                # Per-block learned table + camera-dependent modulation,
                # split into four (B, 1, hidden) pieces.
                shift_msa, scale_msa, shift_mlp, scale_mlp = (
                    blk.scale_shift_table[None] +
                    cond.reshape(B, 4, -1)).chunk(4, dim=1)

                # Attention branch: modulation is applied to the normed input.
                attn_out = blk.ls1(
                    blk.attn(t2i_modulate(blk.norm1(x), shift_msa,
                                          scale_msa)))
                x = x + blk.drop_path1(
                    attn_out)  # all drop_path identity() here.

                # MLP branch: modulation is applied to the MLP output.
                ffn_out = blk.ls2(
                    t2i_modulate(blk.mlp(blk.norm2(x)), shift_mlp, scale_mlp))
                x = x + blk.drop_path2(ffn_out)

            x_norm = self.model.norm(x)

        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_patchtokens": x_norm[:,
                                         self.model.num_register_tokens + 1:],
        }

    def forward(self, img_c, no_dropout=False):
        """Embed views ``1 .. n_cond_frames`` of a multi-view batch.

        Args:
            img_c: Mapping that provides ``'img'`` and ``'c'``; both are
                sliced on their second axis.
            no_dropout: Passed through unchanged to the parent ``forward``.

        Returns:
            The parent's output with its leading ``(b t)`` axis split back
            into ``b`` and ``t``, where ``t == self.n_cond_frames``.
        """
        mv_image, c = img_c['img'], img_c['c']

        # ! use zero c here, ablation. current version wrong.
        # c = torch.zeros_like(c)

        # ! frame-0 as canonical here: view 0 is skipped and the following
        # ``n_cond_frames`` views are used as conditions.
        mv_image = super().forward(mv_image[:, 1:1 + self.n_cond_frames],
                                   c=rearrange(c[:, 1:1 + self.n_cond_frames],
                                               "b t ... -> (b t) ...",
                                               t=self.n_cond_frames),
                                   no_dropout=no_dropout)

        mv_image = rearrange(mv_image,
                             "(b t) ... -> b t ...",
                             t=self.n_cond_frames)

        return mv_image
class FrozenCLIPT5Encoder(AbstractEmbModel):
    """Text encoder that pairs a CLIP embedder with a T5 embedder.

    ``forward`` returns both encodings of the same text as a two-item list
    ``[clip_z, t5_z]``.
    """

    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version,
                                               device,
                                               max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version,
                                           device,
                                           max_length=t5_max_length)
        # Report the size of both sub-encoders once at construction time.
        summary = (
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )
        print(summary)

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)

    def forward(self, text):
        """Encode ``text`` with CLIP first, then T5, and return both results."""
        return [self.clip_encoder.encode(text), self.t5_encoder.encode(text)]
class SpatialRescaler(nn.Module):
    """Rescale an input ``n_stages`` times and optionally remap its channels.

    Args:
        n_stages: Number of successive interpolation passes (>= 0).
        method: Interpolation mode passed to ``torch.nn.functional.interpolate``.
        multiplier: Scale factor applied at every stage.
        in_channels: Input channels of the optional channel mapper.
        out_channels: Output channels of the optional channel mapper; giving a
            value enables the mapper.
        bias: Whether the channel mapper has a bias.
        wrap_video: When True, 5D inputs (b, c, t, h, w) are processed
            frame-by-frame and returned in the same layout.
        kernel_size: Kernel size of the channel mapper (padding is
            ``kernel_size // 2``).
        remap_output: Enables the channel mapper even if ``out_channels`` is
            None.
    """

    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
        wrap_video=False,
        kernel_size=1,
        remap_output=False,
    ):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate,
                                    mode=method)
        self.remap_output = out_channels is not None or remap_output
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video

    def forward(self, x):
        """Resize ``x`` and, if configured, remap its channels.

        4D inputs are handled directly. 5D inputs are folded into the batch
        axis only when ``wrap_video`` is set, and unfolded again on return.
        """
        # Decide once whether this call takes the video path, so that a 4D
        # input with wrap_video=True no longer touches unbound B/T values.
        is_video = self.wrap_video and x.ndim == 5
        if is_video:
            B, _, T, _, _ = x.shape
            x = rearrange(x, "b c t h w -> (b t) c h w")

        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        # The mapper is a Conv2d, so it must see the (N, C, H, W) layout;
        # apply it before the frames are unfolded back into a 5D tensor.
        if self.remap_output:
            x = self.channel_mapper(x)

        if is_video:
            # The channel count is left free because the mapper may change it.
            x = rearrange(x, "(b t) c h w -> b c t h w", b=B, t=T)
        return x

    def encode(self, x):
        """Alias for calling the module."""
        return self(x)
class LowScaleEncoder(nn.Module):
    """Encode an input with a wrapped model and add noise at a random level.

    The noise follows a beta schedule registered as buffers; ``forward``
    returns both the noised latent and the sampled noise level.
    """

    def __init__(
        self,
        model_config,
        linear_start,
        linear_end,
        timesteps=1000,
        max_noise_level=250,
        output_size=64,
        scale_factor=1.0,
    ):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # register_schedule has no return statement, so this attribute ends up
        # holding None; the schedule itself lives in the registered buffers.
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end)
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Build the beta schedule and register the derived float32 buffers."""
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # The effective number of steps is taken from the schedule itself.
        (timesteps, ) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (alphas_cumprod.shape[0] == self.num_timesteps
                ), "alphas have to be defined for each timestep"

        # Registration order is kept stable: basic quantities first, then the
        # terms used for diffusion q(x_t | x_{t-1}) and others.
        schedule_buffers = (
            ("betas", betas),
            ("alphas_cumprod", alphas_cumprod),
            ("alphas_cumprod_prev", alphas_cumprod_prev),
            ("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod)),
            ("sqrt_one_minus_alphas_cumprod", np.sqrt(1.0 - alphas_cumprod)),
            ("log_one_minus_alphas_cumprod", np.log(1.0 - alphas_cumprod)),
            ("sqrt_recip_alphas_cumprod", np.sqrt(1.0 / alphas_cumprod)),
            ("sqrt_recipm1_alphas_cumprod",
             np.sqrt(1.0 / alphas_cumprod - 1)),
        )
        for buffer_name, values in schedule_buffers:
            self.register_buffer(buffer_name,
                                 torch.tensor(values, dtype=torch.float32))

    def q_sample(self, x_start, t, noise=None):
        """Return ``x_start`` noised to step ``t``; noise is sampled if not given."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t,
                                           x_start.shape)
        noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
                                          t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x):
        """Encode ``x``, scale it, noise it and optionally resize it.

        Returns:
            ``(z, noise_level)`` where ``noise_level`` holds one integer per
            batch element drawn from ``[0, max_noise_level)``.
        """
        latent = self.model.encode(x)
        if isinstance(latent, DiagonalGaussianDistribution):
            latent = latent.sample()
        latent = latent * self.scale_factor

        batch_size = x.shape[0]
        noise_level = torch.randint(0,
                                    self.max_noise_level, (batch_size, ),
                                    device=x.device).long()
        latent = self.q_sample(latent, noise_level)

        if self.out_size is not None:
            latent = torch.nn.functional.interpolate(latent,
                                                     size=self.out_size,
                                                     mode="nearest")
        return latent, noise_level

    def decode(self, z):
        """Undo the latent scaling and decode with the wrapped model."""
        return self.model.decode(z / self.scale_factor)
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """embeds each dimension independently and concatenates them"""

    def __init__(self, outdim):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim

    def forward(self, x):
        """Embed a (b,) or (b, d) input and return a (b, d * outdim) tensor."""
        # A 1D input is treated as a single-column 2D input.
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2

        batch, n_dims = x.shape[0], x.shape[1]
        # Every scalar is embedded on its own, then the embeddings belonging
        # to the same row are laid side by side.
        per_scalar = self.timestep(rearrange(x, "b d -> (b d)"))
        return rearrange(per_scalar,
                         "(b d) d2 -> b (d d2)",
                         b=batch,
                         d=n_dims,
                         d2=self.outdim)
class GaussianEncoder(Encoder, AbstractEmbModel):
    """``Encoder`` whose output goes through a diagonal Gaussian regularizer.

    Args:
        weight: Value stored under ``"weight"`` in the returned log dict.
        flatten_output: When True the latent is rearranged from
            ``b c h w`` to ``b (h w) c``.
        *args, **kwargs: Forwarded to the parent constructors.
    """

    def __init__(self,
                 weight: float = 1.0,
                 flatten_output: bool = True,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.posterior = DiagonalGaussianRegularizer()
        self.weight = weight
        self.flatten_output = flatten_output

    def forward(self, x) -> Tuple[Dict, torch.Tensor]:
        """Encode ``x`` and return ``(log, z)``.

        ``log`` is the regularizer's log dict, extended with ``"loss"``
        (a copy of ``"kl_loss"``) and ``"weight"``.
        """
        latent, log = self.posterior(super().forward(x))
        log["loss"] = log["kl_loss"]
        log["weight"] = self.weight
        if not self.flatten_output:
            return log, latent
        return log, rearrange(latent, "b c h w -> b (h w ) c")
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
    """Encode conditioning frames, optionally noise-augmenting them first.

    Args:
        n_cond_frames: Number of conditioning frames per sample; the input
            batch axis is interpreted as ``(b t)`` with ``t`` equal to this.
        n_copies: How many times each encoded sample is repeated on the
            batch axis of the output.
        encoder_config: Config instantiated into ``self.encoder``.
        sigma_sampler_config: Optional config for the noise-level sampler.
        sigma_cond_config: Optional config for the noise-level embedder.
        is_ae: When True ``self.encoder.encode`` is called instead of
            ``self.encoder``.
        scale_factor: Multiplier applied to the encoder output.
        disable_encoder_autocast: Turns CUDA autocast off around the encoder.
        en_and_decode_n_samples_a_time: Optional chunk size for encoding.
    """

    def __init__(
        self,
        n_cond_frames: int,
        n_copies: int,
        encoder_config: dict,
        sigma_sampler_config: Optional[dict] = None,
        sigma_cond_config: Optional[dict] = None,
        is_ae: bool = False,
        scale_factor: float = 1.0,
        disable_encoder_autocast: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        super().__init__()

        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.encoder = instantiate_from_config(encoder_config)

        if sigma_sampler_config is not None:
            self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        else:
            self.sigma_sampler = None

        if sigma_cond_config is not None:
            self.sigma_cond = instantiate_from_config(sigma_cond_config)
        else:
            self.sigma_cond = None

        self.is_ae = is_ae
        self.scale_factor = scale_factor
        self.disable_encoder_autocast = disable_encoder_autocast
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

    def forward(
        self, vid: torch.Tensor
    ) -> Union[
            torch.Tensor,
            Tuple[torch.Tensor, torch.Tensor],
            Tuple[torch.Tensor, dict],
            Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
    ]:
        """Return the encoded frames, plus the sigma embedding if configured.

        NOTE: ``sigma_cond`` is only computed when a sigma sampler exists, so
        configuring ``sigma_cond_config`` without ``sigma_sampler_config``
        fails at the return statement.
        """
        if self.sigma_sampler is not None:
            # One noise level per sample, shared by all of its frames.
            batch = vid.shape[0] // self.n_cond_frames
            sigmas = self.sigma_sampler(batch).to(vid.device)
            if self.sigma_cond is not None:
                sigma_cond = repeat(self.sigma_cond(sigmas),
                                    "b d -> (b t) d",
                                    t=self.n_copies)
            sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
            vid = vid + torch.randn_like(vid) * append_dims(sigmas, vid.ndim)

        with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
            # Encode in chunks; without an explicit chunk size the whole
            # batch is encoded in one go.
            chunk = self.en_and_decode_n_samples_a_time
            if chunk is None:
                chunk = vid.shape[0]
            encoded = []
            for i in range(math.ceil(vid.shape[0] / chunk)):
                piece = vid[i * chunk:(i + 1) * chunk]
                encoded.append(
                    self.encoder.encode(piece) if self.is_ae else self.
                    encoder(piece))

        vid = torch.cat(encoded, dim=0)
        vid *= self.scale_factor

        # Stack the frames of each sample on the channel axis, then repeat
        # every sample ``n_copies`` times along the batch axis.
        vid = rearrange(vid,
                        "(b t) c h w -> b () (t c) h w",
                        t=self.n_cond_frames)
        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)

        if self.sigma_cond is not None:
            return vid, sigma_cond
        return vid
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    """Wrap an OpenCLIP image embedder for frame-sequence conditioning.

    Args:
        open_clip_embedding_config: Config instantiated into ``self.open_clip``.
        n_cond_frames: Number of frames per sample; the embedder output's
            leading axis is interpreted as ``(b t)`` with ``t`` equal to this.
        n_copies: How many times each sample is repeated on the batch axis.
    """

    def __init__(
        self,
        open_clip_embedding_config: Dict,
        n_cond_frames: int,
        n_copies: int,
    ):
        super().__init__()

        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid):
        """Embed all frames, regroup them per sample and repeat each sample."""
        per_frame = self.open_clip(vid)
        per_sample = rearrange(per_frame,
                               "(b t) d -> b t d",
                               t=self.n_cond_frames)
        return repeat(per_sample, "b t d -> (b s) t d", s=self.n_copies)
class FrozenOpenCLIPImageMVEmbedder(AbstractEmbModel):
    """OpenCLIP image embedder for multi-view 3D diffusion conditioning.

    Only index 0 on the second axis of the input (the first frame) is
    embedded.
    """

    def __init__(
        self,
        open_clip_embedding_config: Dict,
    ):
        super().__init__()
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid, no_dropout=False):
        """Embed the first frame of ``vid``; ``no_dropout`` is passed through."""
        first_frame = vid[:, 0, ...]
        return self.open_clip(first_frame, no_dropout=no_dropout)