Spaces:
Running
on
Zero
Running
on
Zero
#original code from https://github.com/genmoai/models under apache 2.0 license | |
#adapted to ComfyUI | |
from typing import Dict, List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
# from flash_attn import flash_attn_varlen_qkvpacked_func | |
from comfy.ldm.modules.attention import optimized_attention | |
from .layers import ( | |
FeedForward, | |
PatchEmbed, | |
RMSNorm, | |
TimestepEmbedder, | |
) | |
from .rope_mixed import ( | |
compute_mixed_rotation, | |
create_position_matrix, | |
) | |
from .temporal_rope import apply_rotary_emb_qk_real | |
from .utils import ( | |
AttentionPool, | |
modulate, | |
) | |
import comfy.ldm.common_dit | |
import comfy.ops | |
def modulated_rmsnorm(x, scale, eps=1e-6):
    """RMS-normalize ``x`` with comfy's ``rms_norm`` and multiply by ``1 + scale``.

    ``scale`` gets a singleton axis inserted at dim 1 before the multiply, so a
    per-sample modulation vector (presumably (B, D)) broadcasts over the token
    axis of ``x`` (presumably (B, N, D)).
    """
    factor = 1 + scale.unsqueeze(1)
    return comfy.ldm.common_dit.rms_norm(x, eps=eps) * factor
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
    """Add a normalized, tanh-gated branch onto a residual stream.

    Returns ``x + rms_norm(x_res) * tanh(gate).unsqueeze(1)``: the branch output
    ``x_res`` is RMS-normalized (sandwich norm) and scaled by a bounded gate
    before being added to ``x``.
    """
    bounded_gate = gate.tanh().unsqueeze(1)
    branch = comfy.ldm.common_dit.rms_norm(x_res, eps=eps)
    return x + branch * bounded_gate
class AsymmetricAttention(nn.Module):
    """Joint self-attention over visual tokens ``x`` and text tokens ``y``.

    Each stream has its own QKV projection, per-head RMSNorm on queries and
    keys, and output projection. For attention, both streams are concatenated
    along the sequence axis. Text features are projected up to the visual
    width ``dim_x`` for attention and, when ``update_y`` is set, back down to
    ``dim_y`` afterwards.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        update_y: bool = True,
        out_bias: bool = True,
        attend_to_padding: bool = False,
        softmax_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        """
        Args:
            dim_x: Width of the visual stream; also the attention width.
            dim_y: Width of the text stream.
            num_heads: Number of attention heads; must divide ``dim_x``.
            qkv_bias: Whether the QKV projections have a bias.
            qk_norm: Must be True; only the QK-normalized variant is implemented.
            attn_drop: Stored only; ``forward`` applies no dropout.
            update_y: If False, ``proj_y`` is an identity (the text output of
                this layer is then not projected back to ``dim_y``).
            out_bias: Whether the output projections have a bias.
            attend_to_padding: Stored only; not read in this class.
            softmax_scale: Stored only; ``forward`` does not pass it to
                ``optimized_attention``, so that function's own default scaling
                is used.
            device, dtype: Forwarded to the created layers.
            operations: Namespace providing the ``Linear`` layer class.

        Raises:
            ValueError: If ``dim_x`` is not divisible by ``num_heads``, or if
                ``qk_norm`` is False.
        """
        super().__init__()
        if dim_x % num_heads != 0:
            raise ValueError(
                f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
            )
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if not qk_norm:
            raise ValueError("AsymmetricAttention requires qk_norm=True")

        self.dim_x = dim_x
        self.dim_y = dim_y
        self.num_heads = num_heads
        self.head_dim = dim_x // num_heads
        self.attn_drop = attn_drop
        self.update_y = update_y
        self.attend_to_padding = attend_to_padding
        self.softmax_scale = softmax_scale

        # Input layers.
        self.qkv_bias = qkv_bias
        self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
        # Project text features to match visual features (dim_y -> dim_x).
        self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)

        # Per-head query and key normalization for stability.
        self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)

        # Output layers. y features go back down from dim_x -> dim_y.
        self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
        self.proj_y = (
            operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
            if update_y
            else nn.Identity()
        )

    def forward(
        self,
        x: torch.Tensor,  # (B, N, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
        crop_y,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend jointly over visual tokens and the first ``crop_y`` text tokens.

        Args:
            x: Visual tokens, (B, N, dim_x).
            y: Text tokens, (B, L, dim_y).
            scale_x, scale_y: Per-sample modulation applied after RMSNorm.
            crop_y: Number of leading text tokens that take part in attention.
                Text positions at or past ``crop_y`` are excluded.
            **rope_rotation: ``rope_cos`` / ``rope_sin`` tables, applied to the
                visual queries and keys only.

        Returns:
            ``(x, y)``:
            - ``x`` is the projected visual output, with N tokens.
            - ``y`` is the text output, with L tokens. Positions past
              ``crop_y`` are ``proj_y`` applied to zeros.
        """
        rope_cos = rope_rotation.get("rope_cos")
        rope_sin = rope_rotation.get("rope_sin")

        # Pre-norm + modulation for both streams.
        x = modulated_rmsnorm(x, scale_x)  # (B, N, dim_x)
        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)

        # Text QKV, projected to dim_x and split into heads: each (B, L, num_heads, head_dim).
        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)
        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Visual QKV: each (B, N, num_heads, head_dim). RoPE is applied to visual q/k only.
        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        # Concatenate visual tokens with the first crop_y text tokens along the
        # sequence axis, then move heads forward: (B, num_heads, N + min(crop_y, L), head_dim).
        q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
        k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
        v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
        xy = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)

        # Split the joint output back into its visual (first N) and text parts along dim 1.
        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
        x = self.proj_x(x)

        # Write the attended text tokens into a zero buffer of the full text
        # length L, so cropped positions are zeros going into proj_y.
        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
        o[:, :y.shape[1]] = y
        y = self.proj_y(o)
        return x, y
class AsymmetricJointBlock(nn.Module):
    """One joint DiT block: modulated joint attention, then per-stream MLPs.

    Both streams are modulated by the conditioning vector ``c``. When
    ``update_y`` is False (``AsymmDiTJoint`` uses this for its last block), the
    text stream is still used as attention context, but it is returned
    unchanged and no text MLP is built.
    """

    def __init__(
        self,
        hidden_size_x: int,
        hidden_size_y: int,
        num_heads: int,
        *,
        mlp_ratio_x: float = 8.0,  # Ratio of hidden size to d_model for MLP for visual tokens.
        mlp_ratio_y: float = 4.0,  # Ratio of hidden size to d_model for MLP for text tokens.
        update_y: bool = True,  # Whether to update text tokens in this block.
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        super().__init__()
        self.update_y = update_y
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y

        # Per-stream modulation from c: attention scale/gate + MLP scale/gate (4 chunks).
        self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
        if self.update_y:
            self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
        else:
            # Only the attention pre-norm scale is needed when y is not updated.
            self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)

        # Self-attention.
        self.attn = AsymmetricAttention(
            hidden_size_x,
            hidden_size_y,
            num_heads=num_heads,
            update_y=update_y,
            device=device,
            dtype=dtype,
            operations=operations,
            **block_kwargs,
        )

        # MLP. The hidden width is simply hidden_size_x * mlp_ratio_x. Any model
        # size is accepted; there is no check pinning it to one configuration.
        mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
        self.mlp_x = FeedForward(
            in_features=hidden_size_x,
            hidden_size=mlp_hidden_dim_x,
            multiple_of=256,
            ffn_dim_multiplier=None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        # MLP for text not needed in last block.
        if self.update_y:
            mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
            self.mlp_y = FeedForward(
                in_features=hidden_size_y,
                hidden_size=mlp_hidden_dim_y,
                multiple_of=256,
                ffn_dim_multiplier=None,
                device=device,
                dtype=dtype,
                operations=operations,
            )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        y: torch.Tensor,
        **attn_kwargs,
    ):
        """Forward pass of a block.

        Args:
            x: (B, N, hidden_size_x) tensor of visual tokens.
            c: (B, hidden_size_x) conditioning vector.
            y: (B, L, hidden_size_y) tensor of text tokens.
            **attn_kwargs: Forwarded to ``self.attn``. ``AsymmDiTJoint`` passes
                ``crop_y`` and the ``rope_cos`` / ``rope_sin`` tables.

        Returns:
            x: (B, N, hidden_size_x) visual tokens after the block.
            y: Text tokens after the block. Returned unchanged when
                ``update_y`` is False.
        """
        N = x.size(1)
        c = F.silu(c)
        mod_x = self.mod_x(c)
        scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
        mod_y = self.mod_y(c)
        if self.update_y:
            scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
        else:
            scale_msa_y = mod_y

        # Self-attention block.
        x_attn, y_attn = self.attn(
            x,
            y,
            scale_x=scale_msa_x,
            scale_y=scale_msa_y,
            **attn_kwargs,
        )
        assert x_attn.size(1) == N
        x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
        if self.update_y:
            y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)

        # MLP block.
        x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
        if self.update_y:
            y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
        return x, y

    def ff_block_x(self, x, scale_x, gate_x):
        """Modulated pre-norm visual MLP with a tanh-gated, post-normed residual."""
        x_mod = modulated_rmsnorm(x, scale_x)
        x_res = self.mlp_x(x_mod)
        x = residual_tanh_gated_rmsnorm(x, x_res, gate_x)  # Sandwich norm
        return x

    def ff_block_y(self, y, scale_y, gate_y):
        """Modulated pre-norm text MLP with a tanh-gated, post-normed residual."""
        y_mod = modulated_rmsnorm(y, scale_y)
        y_res = self.mlp_y(y_mod)
        y = residual_tanh_gated_rmsnorm(y, y_res, gate_y)  # Sandwich norm
        return y
class FinalLayer(nn.Module):
    """Output head of the DiT.

    The head applies three steps to each token:
    - a non-affine LayerNorm;
    - a shift/scale modulation derived from the conditioning vector;
    - a linear map to ``patch_size * patch_size * out_channels`` values.
    """

    def __init__(
        self,
        hidden_size,
        patch_size,
        out_channels,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        # Registration order (norm_final, mod, linear) and names are kept: they are state_dict keys.
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory)
        self.mod = operations.Linear(hidden_size, 2 * hidden_size, **factory)
        per_token_out = patch_size * patch_size * out_channels
        self.linear = operations.Linear(hidden_size, per_token_out, **factory)

    def forward(self, x, c):
        """Normalize ``x``, modulate it with shift/scale from ``silu(c)``, and project it."""
        shift, scale = self.mod(F.silu(c)).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
class AsymmDiTJoint(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    Ingests text embeddings instead of a label.

    Structure:
    - Visual latents are patch-embedded.
    - The patch tokens are processed jointly with projected T5 token features
      by ``depth`` AsymmetricJointBlocks. The last block does not update the
      text stream.
    - A timestep embedding plus a pooled T5 embedding forms the conditioning
      vector shared by all blocks.
    """

    def __init__(
        self,
        *,
        patch_size=2,
        in_channels=4,
        hidden_size_x=1152,
        hidden_size_y=1152,
        depth=48,
        num_heads=16,
        mlp_ratio_x=8.0,
        mlp_ratio_y=4.0,
        use_t5: bool = False,
        t5_feat_dim: int = 4096,
        t5_token_length: int = 256,
        learn_sigma=True,
        patch_embed_bias: bool = True,
        timestep_mlp_bias: bool = True,
        attend_to_padding: bool = False,
        timestep_scale: Optional[float] = None,
        use_extended_posenc: bool = False,
        posenc_preserve_area: bool = False,
        rope_theta: float = 10000.0,
        image_model=None,  # Accepted for interface compatibility; not used here.
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        super().__init__()
        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        # Head dimension and count are determined by the visual stream.
        self.head_dim = hidden_size_x // num_heads
        self.attend_to_padding = attend_to_padding
        self.use_extended_posenc = use_extended_posenc
        self.posenc_preserve_area = posenc_preserve_area
        self.use_t5 = use_t5
        self.t5_token_length = t5_token_length
        self.t5_feat_dim = t5_feat_dim
        # Scaling factor for frequency computation for temporal RoPE (stored; not read in this class).
        self.rope_theta = rope_theta

        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size_x,
            bias=patch_embed_bias,
            dtype=dtype,
            device=device,
            operations=operations
        )

        # Conditionings
        # Timestep
        self.t_embedder = TimestepEmbedder(
            hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )

        if self.use_t5:
            # Caption Pooling (T5)
            self.t5_y_embedder = AttentionPool(
                t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
            )

            # Dense Embedding Projection (T5)
            self.t5_yproj = operations.Linear(
                t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
            )

        # Mixed RoPE frequencies; created uninitialized and expected to be
        # filled from a checkpoint (presumably — confirm with the loader).
        self.pos_frequencies = nn.Parameter(
            torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
        )

        assert not self.attend_to_padding

        # for depth 48:
        #   b =  0: AsymmetricJointBlock, update_y=True
        #   b =  1: AsymmetricJointBlock, update_y=True
        #   ...
        #   b = 46: AsymmetricJointBlock, update_y=True
        #   b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
        blocks = []
        for b in range(depth):
            # Joint multi-modal block
            update_y = b < depth - 1
            block = AsymmetricJointBlock(
                hidden_size_x,
                hidden_size_y,
                num_heads,
                mlp_ratio_x=mlp_ratio_x,
                mlp_ratio_y=mlp_ratio_y,
                update_y=update_y,
                attend_to_padding=attend_to_padding,
                device=device,
                dtype=dtype,
                operations=operations,
                **block_kwargs,
            )
            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.final_layer = FinalLayer(
            hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
        )

    def embed_x(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, in_channels, T, H, W) tensor of visual latents.

        Returns:
            x: (B, N, hidden_size_x) patch tokens, where
               N = T * (H // patch_size) * (W // patch_size).
        """
        return self.x_embedder(x)

    def prepare(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        t5_feat: torch.Tensor,
        t5_mask: torch.Tensor,
    ):
        """Prepare input and conditioning embeddings.

        Requires ``use_t5=True``; otherwise ``t5_y_embedder`` / ``t5_yproj`` do
        not exist and this raises ``AttributeError``.

        Returns:
            x: (B, N, hidden_size_x) patch tokens.
            c: (B, hidden_size_x) conditioning = timestep embedding of
               ``1 - sigma`` + pooled T5 features.
            y_feat: (B, L, hidden_size_y) projected T5 token features.
            rope_cos, rope_sin: RoPE tables for the N visual positions.
        """
        # Visual patch embeddings with positional encoding.
        T, H, W = x.shape[-3:]
        pH, pW = H // self.patch_size, W // self.patch_size
        x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
        assert x.ndim == 3
        N = T * pH * pW
        assert x.size(1) == N
        pos = create_position_matrix(
            T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
        )  # (N, 3)
        rope_cos, rope_sin = compute_mixed_rotation(
            freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
        )  # Each are (N, num_heads, dim // 2)

        # The timestep embedder is fed 1 - sigma.
        c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype)  # (B, D)
        t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)
        c = c_t + t5_y_pool

        y_feat = self.t5_yproj(t5_feat)  # (B, L, t5_feat_dim) --> (B, L, D)

        return x, c, y_feat, rope_cos, rope_sin

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        attention_mask: torch.Tensor,
        num_tokens=256,
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
        control=None, transformer_options=None, **kwargs
    ):
        """Forward pass of DiT.

        Args:
            x: (B, C, T, H, W) tensor of spatial inputs (latent video).
            timestep: (B,) noise levels. They are passed to ``prepare`` as
                ``sigma``, and the timestep embedder receives ``1 - sigma``.
            context: (B, L, t5_feat_dim) T5 token features.
            attention_mask: T5 token mask, passed to the caption pooling layer.
            num_tokens: Number of leading text tokens used in attention
                (``crop_y`` in each block).
            packed_indices, control, **kwargs: Accepted for interface
                compatibility; not used.
            rope_cos, rope_sin: Ignored. ``prepare`` recomputes both from
                ``pos_frequencies``.
            transformer_options: Optional dict. Its
                ``"patches_replace"]["dit"]`` entries keyed
                ``("double_block", i)`` wrap block ``i``.

        Returns:
            (B, out_channels, T, H, W) model output, negated (see final line).
        """
        if transformer_options is None:
            transformer_options = {}
        patches_replace = transformer_options.get("patches_replace", {})

        y_feat = context
        y_mask = attention_mask
        sigma = timestep

        _, _, T, H, W = x.shape

        x, c, y_feat, rope_cos, rope_sin = self.prepare(
            x, sigma, y_feat, y_mask
        )
        del y_mask

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(
                        args["img"],
                        args["vec"],
                        args["txt"],
                        rope_cos=args["rope_cos"],
                        rope_sin=args["rope_sin"],
                        crop_y=args["num_tokens"]
                    )
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
                y_feat = out["txt"]
                x = out["img"]
            else:
                x, y_feat = block(
                    x,
                    c,
                    y_feat,
                    rope_cos=rope_cos,
                    rope_sin=rope_sin,
                    crop_y=num_tokens,
                )  # (B, M, D), (B, L, D)
        del y_feat  # Final layers don't use dense text features.

        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
        x = rearrange(
            x,
            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
            T=T,
            hp=H // self.patch_size,
            wp=W // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.out_channels,
        )

        # Output sign is flipped; presumably this matches the 1 - sigma timestep
        # convention expected by the sampler — confirm before changing.
        return -x