# easyanimate/models/transformer3d.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import math
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.embeddings import (PatchEmbed, TimestepEmbedding,
                                         Timesteps, get_2d_sincos_pos_embed)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous
from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
logging)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange
from torch import nn
from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
SelfAttentionTemporalTransformerBlock,
TemporalTransformerBlock, zero_module)
from .embeddings import (HunyuanCombinedTimestepTextSizeStyleEmbedding,
TimePositionalEncoding)
from .norm import AdaLayerNormSingle, EasyAnimateRMSNorm
from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
TemporalUpsampler3D, UnPatch1D)
from .resampler import Resampler
try:
from diffusers.models.embeddings import PixArtAlphaTextProjection
except ImportError:
from diffusers.models.embeddings import \
CaptionProjection as PixArtAlphaTextProjection
class CLIPProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, num_tokens=120):
super().__init__()
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
self.linear_2 = zero_module(self.linear_2)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
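# ---------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): rough input/output shapes expected by
# CLIPProjection. The batch size, token count, and hidden size below are hypothetical values
# chosen only for demonstration; in_features=768 matches the CLIP width used elsewhere in this
# file, but treat the call as an assumption rather than the canonical usage.
def _example_clip_projection():
    proj = CLIPProjection(in_features=768, hidden_size=1152)
    caption = torch.randn(2, 8, 768)   # (batch, num_tokens, in_features), hypothetical sizes
    out = proj(caption)                # (batch, num_tokens, hidden_size)
    # linear_2 is zero-initialized, so the projection contributes zeros until it is trained
    return out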
@dataclass
class Transformer3DModelOutput(BaseOutput):
"""
    The output of [`Transformer3DModel`].
Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            The hidden states output conditioned on the `encoder_hidden_states` input.
"""
sample: torch.FloatTensor
class Transformer3DModel(ModelMixin, ConfigMixin):
"""
    A 3D Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`):
            The width of the latent images. This is fixed during training since it is used to learn a number of
            position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
            During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
attention_type: str = "default",
        caption_channels: Optional[int] = None,
n_query=8,
# block type
basic_block_type: str = "motionmodule",
# enable_uvit
enable_uvit: bool = False,
# 3d patch params
patch_3d: bool = False,
fake_3d: bool = False,
time_patch_size: Optional[int] = None,
casual_3d: bool = False,
casual_3d_upsampler_index: Optional[list] = None,
# motion module kwargs
motion_module_type = "VanillaGrid",
motion_module_kwargs = None,
motion_module_kwargs_odd = None,
motion_module_kwargs_even = None,
# time position encoding
time_position_encoding_before_transformer = False,
qk_norm = False,
after_norm = False,
resize_inpaint_mask_directly: bool = False,
enable_clip_in_inpaint: bool = True,
position_of_clip_embedding: str = "head",
enable_zero_in_inpaint: bool = False,
enable_text_attention_mask: bool = True,
add_noise_in_inpaint_model: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.enable_uvit = enable_uvit
inner_dim = num_attention_heads * attention_head_dim
self.basic_block_type = basic_block_type
self.patch_3d = patch_3d
self.fake_3d = fake_3d
self.casual_3d = casual_3d
self.casual_3d_upsampler_index = casual_3d_upsampler_index
assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
interpolation_scale = max(interpolation_scale, 1)
self.n_query = n_query
if self.casual_3d:
self.pos_embed = CasualPatchEmbed3D(
height=sample_size,
width=sample_size,
patch_size=patch_size,
time_patch_size=self.time_patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
elif self.patch_3d:
if self.fake_3d:
self.pos_embed = PatchEmbedF3D(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
else:
self.pos_embed = PatchEmbed3D(
height=sample_size,
width=sample_size,
patch_size=patch_size,
time_patch_size=self.time_patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
else:
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
        # 3. Define transformer blocks
if self.basic_block_type == "motionmodule":
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
qk_norm=qk_norm,
after_norm=after_norm,
)
for d in range(num_layers)
]
)
elif self.basic_block_type == "global_motionmodule":
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs_even if d % 2 == 0 else motion_module_kwargs_odd,
qk_norm=qk_norm,
after_norm=after_norm,
)
for d in range(num_layers)
]
)
elif self.basic_block_type == "selfattentiontemporal":
self.transformer_blocks = nn.ModuleList(
[
SelfAttentionTemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
qk_norm=qk_norm,
after_norm=after_norm,
)
for d in range(num_layers)
]
)
else:
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
]
)
if self.casual_3d:
self.unpatch1d = TemporalUpsampler3D()
elif self.patch_3d and self.fake_3d:
self.unpatch1d = UnPatch1D(inner_dim, True)
if self.enable_uvit:
self.long_connect_fc = nn.ModuleList(
[
nn.Linear(inner_dim, inner_dim, True) for d in range(13)
]
)
for index in range(13):
self.long_connect_fc[index] = zero_module(self.long_connect_fc[index])
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
if self.patch_3d and not self.fake_3d:
self.proj_out_2 = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
else:
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
elif norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
if self.patch_3d and not self.fake_3d:
self.proj_out = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
else:
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
self.caption_projection = None
self.clip_projection = None
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
if in_channels == 12:
self.clip_projection = CLIPProjection(in_features=768, hidden_size=inner_dim * 8)
self.gradient_checkpointing = False
self.time_position_encoding_before_transformer = time_position_encoding_before_transformer
if self.time_position_encoding_before_transformer:
self.t_pos = TimePositionalEncoding(max_len = 4096, d_model = inner_dim)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
timestep_cond = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
text_embedding_mask: Optional[torch.Tensor] = None,
encoder_hidden_states_t5: Optional[torch.Tensor] = None,
text_embedding_mask_t5: Optional[torch.Tensor] = None,
image_meta_size = None,
style = None,
image_rotary_emb: Optional[torch.Tensor] = None,
inpaint_latents: torch.Tensor = None,
control_latents: torch.Tensor = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
clip_encoder_hidden_states: Optional[torch.Tensor] = None,
clip_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
        The [`Transformer3DModel`] forward method.
Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
                Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
text_embedding_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`Transformer3DModelOutput`] instead of a plain tuple.
        Returns:
            If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a `tuple` where the
            first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
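            # e.g. a keep/discard mask [[1, 1, 0]] becomes the additive bias [[[0.0, 0.0, -10000.0]]]
            # after the conversion and unsqueeze above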
        if text_embedding_mask is not None:
            text_embedding_mask = text_embedding_mask.squeeze(1)
            if clip_attention_mask is not None:
                text_embedding_mask = torch.cat([text_embedding_mask, clip_attention_mask], dim=1)
        # convert text_embedding_mask to a bias the same way we do for attention_mask
        encoder_attention_mask = None
        if text_embedding_mask is not None and text_embedding_mask.ndim == 2:
            encoder_attention_mask = (1 - text_embedding_mask.to(encoder_hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 1)
# 1. Input
if self.casual_3d:
video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
elif self.patch_3d:
video_length, height, width = hidden_states.shape[-3] // self.time_patch_size, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
else:
video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
batch_size = hidden_states.shape[0] // video_length
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
        # hidden_states: (b, c, f, h, w) -> (b, f h w, c)
if self.time_position_encoding_before_transformer:
hidden_states = self.t_pos(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
if clip_encoder_hidden_states is not None and encoder_hidden_states is not None:
batch_size = hidden_states.shape[0]
clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states = torch.cat([encoder_hidden_states, clip_encoder_hidden_states], dim = 1)
skips = []
skip_index = 0
for index, block in enumerate(self.transformer_blocks):
if self.enable_uvit:
if index >= 15:
long_connect = self.long_connect_fc[skip_index](skips.pop())
hidden_states = hidden_states + long_connect
skip_index += 1
if self.casual_3d_upsampler_index is not None and index in self.casual_3d_upsampler_index:
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
hidden_states = self.unpatch1d(hidden_states)
video_length = (video_length - 1) * 2 + 1
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=video_length, h=height, w=width)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
args = {
"basic": [],
"motionmodule": [video_length, height, width],
"global_motionmodule": [video_length, height, width],
"selfattentiontemporal": [],
}[self.basic_block_type]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
*args,
**ckpt_kwargs,
)
else:
kwargs = {
"basic": {},
"motionmodule": {"num_frames":video_length, "height":height, "width":width},
"global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
"selfattentiontemporal": {},
}[self.basic_block_type]
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
**kwargs
)
if self.enable_uvit:
if index < 13:
skips.append(hidden_states)
if self.fake_3d and self.patch_3d:
hidden_states = rearrange(hidden_states, "b (f h w) c -> (b h w) c f", f=video_length, w=width, h=height)
hidden_states = self.unpatch1d(hidden_states)
hidden_states = rearrange(hidden_states, "(b h w) c f -> b (f h w) c", w=width, h=height)
# 3. Output
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
if self.patch_3d:
if self.fake_3d:
hidden_states = hidden_states.reshape(
shape=(-1, video_length * self.patch_size, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
else:
hidden_states = hidden_states.reshape(
shape=(-1, video_length, height, width, self.time_patch_size, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwopqc->ncfohpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, video_length * self.time_patch_size, height * self.patch_size, width * self.patch_size)
)
else:
hidden_states = hidden_states.reshape(
shape=(-1, video_length, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, video_length, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(
cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={},
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if low_cpu_mem_usage:
try:
import re
from diffusers.models.modeling_utils import \
load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
import accelerate
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model = cls.from_config(config, **transformer_additional_kwargs)
param_device = "cpu"
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_path,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
print(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
return model
except Exception as e:
            print(
                f"The low_cpu_mem_usage mode does not work because of: {e}. Falling back to low_cpu_mem_usage=False."
            )
model = cls.from_config(config, **transformer_additional_kwargs)
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for model_file_safetensors in model_files_safetensors:
_state_dict = load_file(model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
new_shape = model.state_dict()['pos_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
else:
model.state_dict()['pos_embed.proj.weight'][:, :4, :, :] = state_dict['pos_embed.proj.weight']
model.state_dict()['pos_embed.proj.weight'][:, 4:, :, :] = 0
state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
new_shape = model.state_dict()['proj_out.weight'].size()
state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
new_shape = model.state_dict()['proj_out.bias'].size()
state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
model = model.to(torch_dtype)
return model
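# ---------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a typical call into
# `Transformer3DModel.from_pretrained_2d`. The checkpoint directory, subfolder name, and the
# config override passed through `transformer_additional_kwargs` are hypothetical placeholders.
def _example_load_transformer3d():
    model = Transformer3DModel.from_pretrained_2d(
        "path/to/pretrained_model",                # hypothetical directory containing config.json
        subfolder="transformer",                   # hypothetical subfolder
        transformer_additional_kwargs={"basic_block_type": "motionmodule"},  # assumed override
        low_cpu_mem_usage=False,
        torch_dtype=torch.float16,
    )
    return model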
class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
    Inherits `ModelMixin` and `ConfigMixin` to be compatible with the diffusers samplers and pipelines.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88):
The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
patch_size (`int`, *optional*):
The size of the patch to use for the input.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward.
sample_size (`int`, *optional*):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
        cross_attention_dim (`int`, *optional*):
            The number of dimensions in the CLIP text embedding.
hidden_size (`int`, *optional*):
The size of hidden layer in the conditioning embedding layers.
num_layers (`int`, *optional*, defaults to 1):
The number of layers of Transformer blocks to use.
mlp_ratio (`float`, *optional*, defaults to 4.0):
The ratio of the hidden layer size to the input size.
learn_sigma (`bool`, *optional*, defaults to `True`):
Whether to predict variance.
        cross_attention_dim_t5 (`int`, *optional*):
            The number of dimensions in the T5 text embedding.
pooled_projection_dim (`int`, *optional*):
The size of the pooled projection.
        text_len (`int`, *optional*):
            The length of the CLIP text embedding.
text_len_t5 (`int`, *optional*):
The length of the T5 text embedding.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
patch_size: Optional[int] = None,
n_query=16,
projection_dim=768,
activation_fn: str = "gelu-approximate",
sample_size=32,
hidden_size=1152,
num_layers: int = 28,
mlp_ratio: float = 4.0,
learn_sigma: bool = True,
cross_attention_dim: int = 1024,
norm_type: str = "layer_norm",
cross_attention_dim_t5: int = 2048,
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
# block type
basic_block_type: str = "basic",
time_position_encoding = False,
time_position_encoding_type: str = "2d_rope",
after_norm = False,
resize_inpaint_mask_directly: bool = False,
enable_clip_in_inpaint: bool = True,
position_of_clip_embedding: str = "full",
enable_text_attention_mask: bool = True,
add_noise_in_inpaint_model: bool = False,
):
super().__init__()
# 4. Define output layers
if learn_sigma:
self.out_channels = in_channels * 2 if out_channels is None else out_channels
else:
self.out_channels = in_channels if out_channels is None else out_channels
self.enable_inpaint = in_channels * 2 != self.out_channels if learn_sigma else in_channels != self.out_channels
self.num_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.basic_block_type = basic_block_type
self.resize_inpaint_mask_directly = resize_inpaint_mask_directly
self.text_embedder = PixArtAlphaTextProjection(
in_features=cross_attention_dim_t5,
hidden_size=cross_attention_dim_t5 * 4,
out_features=cross_attention_dim,
act_fn="silu_fp32",
)
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
in_channels=in_channels,
embed_dim=hidden_size,
patch_size=patch_size,
pos_embed_type=None,
)
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
hidden_size,
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
)
        # 3. Define transformer blocks
if self.basic_block_type == "hybrid_attention":
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
after_norm=after_norm,
time_position_encoding=time_position_encoding,
is_local_attention=False if layer % 2 == 0 else True,
local_attention_frames=2,
enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
)
for layer in range(num_layers)
]
)
elif self.basic_block_type == "kvcompression_basic":
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
after_norm=after_norm,
time_position_encoding=time_position_encoding,
kvcompression=False if layer < num_layers // 2 else True,
enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
)
for layer in range(num_layers)
]
)
else:
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
after_norm=after_norm,
time_position_encoding=time_position_encoding,
enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
)
for layer in range(num_layers)
]
)
self.n_query = n_query
if self.enable_inpaint and enable_clip_in_inpaint:
self.clip_padding = nn.Parameter(
torch.randn((self.n_query, cross_attention_dim)) * 0.02
)
self.clip_projection = Resampler(
int(math.sqrt(n_query)),
embed_dim=cross_attention_dim,
num_heads=self.config.num_attention_heads,
kv_dim=projection_dim,
norm_layer=nn.LayerNorm,
)
else:
self.clip_padding = None
self.clip_projection = None
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
inpaint_latents=None,
control_latents: torch.Tensor = None,
clip_encoder_hidden_states: Optional[torch.Tensor]=None,
clip_attention_mask: Optional[torch.Tensor]=None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
return_dict=True,
):
"""
        The [`HunyuanTransformer3DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step.
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
text_embedding_mask: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
            image_meta_size (`torch.Tensor`):
                Conditional embedding indicating the image sizes.
            style (`torch.Tensor`):
                Conditional embedding indicating the style.
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict: bool
Whether to return a dictionary.
"""
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 1)
        # patch embedding
patch_size = self.pos_embed.patch_size
video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // patch_size, hidden_states.shape[-1] // patch_size
hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
hidden_states = self.pos_embed(hidden_states)
hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
temb = self.time_extra_emb(
timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
) # [B, D]
# text projection
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
encoder_hidden_states_t5 = self.text_embedder(
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
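        # tokens masked out by text_embedding_mask are swapped for the learned
        # `text_embedding_padding` rows instead of being dropped, so the sequence length is kept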
if clip_encoder_hidden_states is not None:
batch_size = encoder_hidden_states.shape[0]
clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
clip_attention_mask = clip_attention_mask.unsqueeze(2).bool()
clip_encoder_hidden_states = torch.where(clip_attention_mask, clip_encoder_hidden_states, self.clip_padding)
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.config.num_layers // 2:
skip = skips.pop()
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
args = {
"kvcompression_basic": [video_length, height, width, clip_encoder_hidden_states],
"basic": [video_length, height, width, clip_encoder_hidden_states],
"hybrid_attention": [video_length, height, width, clip_encoder_hidden_states],
}[self.basic_block_type]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
skip,
*args,
**ckpt_kwargs,
)
else:
kwargs = {
"kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
}[self.basic_block_type]
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
skip=skip,
**kwargs
) # (N, L, D)
else:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
args = {
"kvcompression_basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
"basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
"hybrid_attention": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
}[self.basic_block_type]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
*args,
**ckpt_kwargs,
)
else:
kwargs = {
"kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
}[self.basic_block_type]
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
disable_image_rotary_emb_in_attn1=True if layer==0 else False,
**kwargs
) # (N, L, D)
if layer < (self.config.num_layers // 2 - 1):
skips.append(hidden_states)
# final layer
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
hidden_states = self.proj_out(hidden_states)
# (N, L, patch_size ** 2 * out_channels)
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], video_length, height, width, patch_size, patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, video_length, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if low_cpu_mem_usage:
try:
import re
from diffusers.models.modeling_utils import \
load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
import accelerate
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model = cls.from_config(config, **transformer_additional_kwargs)
param_device = "cpu"
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_path,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
print(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
return model
except Exception as e:
            print(
                f"The low_cpu_mem_usage mode does not work because of: {e}. Falling back to low_cpu_mem_usage=False."
            )
model = cls.from_config(config, **transformer_additional_kwargs)
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for model_file_safetensors in model_files_safetensors:
_state_dict = load_file(model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
new_shape = model.state_dict()['pos_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
else:
if model.state_dict()['pos_embed.proj.weight'].size()[1] > state_dict['pos_embed.proj.weight'].size()[1]:
model.state_dict()['pos_embed.proj.weight'][:, :state_dict['pos_embed.proj.weight'].size()[1], :, :] = state_dict['pos_embed.proj.weight']
model.state_dict()['pos_embed.proj.weight'][:, state_dict['pos_embed.proj.weight'].size()[1]:, :, :] = 0
state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
else:
model.state_dict()['pos_embed.proj.weight'][:, :, :, :] = state_dict['pos_embed.proj.weight'][:, :model.state_dict()['pos_embed.proj.weight'].size()[1], :, :]
state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
if model.state_dict()['proj_out.weight'].size()[0] > state_dict['proj_out.weight'].size()[0]:
model.state_dict()['proj_out.weight'][:state_dict['proj_out.weight'].size()[0], :] = state_dict['proj_out.weight']
state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight']
else:
model.state_dict()['proj_out.weight'][:, :] = state_dict['proj_out.weight'][:model.state_dict()['proj_out.weight'].size()[0], :]
state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight']
if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
if model.state_dict()['proj_out.bias'].size()[0] > state_dict['proj_out.bias'].size()[0]:
model.state_dict()['proj_out.bias'][:state_dict['proj_out.bias'].size()[0]] = state_dict['proj_out.bias']
state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias']
else:
                model.state_dict()['proj_out.bias'][:] = state_dict['proj_out.bias'][:model.state_dict()['proj_out.bias'].size()[0]]
state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias']
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
print(m)
params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
print(f"### Mamba Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
model = model.to(torch_dtype)
return model
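# ---------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the unpatchify step used at the end of
# the forward passes above. Token features of width patch_size**2 * out_channels are folded back
# into a (N, C, F, H*p, W*p) video tensor with the same reshape/einsum pattern. All sizes below
# are hypothetical and chosen only for demonstration.
def _example_unpatchify():
    n, f, h, w, p, c = 1, 4, 6, 6, 2, 8            # batch, frames, patch grid, patch size, channels
    tokens = torch.randn(n, f * h * w, p * p * c)  # (N, L, p*p*C), as produced by proj_out
    x = tokens.reshape(-1, f, h, w, p, p, c)
    x = torch.einsum("nfhwpqc->ncfhpwq", x)
    video = x.reshape(-1, c, f, h * p, w * p)      # (1, 8, 4, 12, 12)
    return video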
class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
patch_size: Optional[int] = None,
sample_width: int = 90,
sample_height: int = 60,
ref_channels: int = None,
clip_channels: int = None,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
freq_shift: int = 0,
num_layers: int = 30,
mmdit_layers: int = 10000,
swa_layers: list = None,
dropout: float = 0.0,
time_embed_dim: int = 512,
add_norm_text_encoder: bool = False,
text_embed_dim: int = 4096,
text_embed_dim_t5: int = 4096,
norm_eps: float = 1e-5,
norm_elementwise_affine: bool = True,
flip_sin_to_cos: bool = True,
time_position_encoding_type: str = "3d_rope",
after_norm = False,
resize_inpaint_mask_directly: bool = False,
enable_clip_in_inpaint: bool = True,
position_of_clip_embedding: str = "full",
enable_text_attention_mask: bool = True,
add_noise_in_inpaint_model: bool = False,
add_ref_latent_in_control_model: bool = False,
):
super().__init__()
self.num_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.resize_inpaint_mask_directly = resize_inpaint_mask_directly
self.patch_size = patch_size
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
self.post_patch_height = post_patch_height
self.post_patch_width = post_patch_width
self.time_proj = Timesteps(self.inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(self.inner_dim, time_embed_dim, timestep_activation_fn)
self.proj = nn.Conv2d(
in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
)
if not add_norm_text_encoder:
self.text_proj = nn.Linear(text_embed_dim, self.inner_dim)
if text_embed_dim_t5 is not None:
self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim)
else:
self.text_proj = nn.Sequential(
EasyAnimateRMSNorm(text_embed_dim),
nn.Linear(text_embed_dim, self.inner_dim)
)
if text_embed_dim_t5 is not None:
self.text_proj_t5 = nn.Sequential(
                    EasyAnimateRMSNorm(text_embed_dim_t5),
nn.Linear(text_embed_dim_t5, self.inner_dim)
)
if ref_channels is not None:
self.ref_proj = nn.Conv2d(
ref_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
)
ref_pos_embedding = get_2d_sincos_pos_embed(self.inner_dim, (post_patch_height, post_patch_width))
ref_pos_embedding = torch.from_numpy(ref_pos_embedding)
self.register_buffer("ref_pos_embedding", ref_pos_embedding, persistent=False)
if clip_channels is not None:
self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
self.swa_layers = swa_layers
if swa_layers is not None:
self.transformer_blocks = nn.ModuleList(
[
EasyAnimateDiTBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
after_norm=after_norm,
is_mmdit_block=True if index < mmdit_layers else False,
is_swa=True if index in swa_layers else False,
)
for index in range(num_layers)
]
)
else:
self.transformer_blocks = nn.ModuleList(
[
EasyAnimateDiTBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
after_norm=after_norm,
is_mmdit_block=True if _ < mmdit_layers else False,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * self.inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states,
timestep,
timestep_cond = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
text_embedding_mask: Optional[torch.Tensor] = None,
encoder_hidden_states_t5: Optional[torch.Tensor] = None,
text_embedding_mask_t5: Optional[torch.Tensor] = None,
image_meta_size = None,
style = None,
image_rotary_emb: Optional[torch.Tensor] = None,
inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None,
ref_latents: Optional[torch.Tensor] = None,
clip_encoder_hidden_states: Optional[torch.Tensor] = None,
clip_attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
return_dict=True,
):
batch_size, channels, video_length, height, width = hidden_states.size()
# 1. Time embedding
temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
temb = self.time_embedding(temb, timestep_cond)
# 2. Patch embedding
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 1)
hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
hidden_states = self.proj(hidden_states)
hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length, h=height // self.patch_size, w=width // self.patch_size)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
encoder_hidden_states = self.text_proj(encoder_hidden_states)
if encoder_hidden_states_t5 is not None:
encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
if ref_latents is not None:
ref_batch, ref_channels, ref_video_length, ref_height, ref_width = ref_latents.shape
ref_latents = rearrange(ref_latents, "b c f h w ->(b f) c h w")
ref_latents = self.ref_proj(ref_latents)
ref_latents = rearrange(ref_latents, "(b f) c h w -> b c f h w", f=ref_video_length, h=ref_height // self.patch_size, w=ref_width // self.patch_size)
ref_latents = ref_latents.flatten(2).transpose(1, 2)
emb_size = hidden_states.size()[-1]
ref_pos_embedding = self.ref_pos_embedding
ref_pos_embedding_interpolate = ref_pos_embedding.view(1, 1, self.post_patch_height, self.post_patch_width, emb_size).permute([0, 4, 1, 2, 3])
ref_pos_embedding_interpolate = F.interpolate(
ref_pos_embedding_interpolate,
size=[1, height // self.config.patch_size, width // self.config.patch_size],
mode='trilinear', align_corners=False
)
ref_pos_embedding_interpolate = ref_pos_embedding_interpolate.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
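            # after interpolation the reference positional embedding is flattened to
            # (1, (height // patch_size) * (width // patch_size), inner_dim), matching ref_latents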
ref_latents = ref_latents + ref_pos_embedding_interpolate
encoder_hidden_states = ref_latents
if clip_encoder_hidden_states is not None:
clip_encoder_hidden_states = self.clip_proj(clip_encoder_hidden_states)
encoder_hidden_states = torch.concat([clip_encoder_hidden_states, ref_latents], dim=1)
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
video_length,
height // self.patch_size,
width // self.patch_size,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
num_frames=video_length,
height=height // self.patch_size,
width=width // self.patch_size
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, encoder_hidden_states.size()[1]:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, video_length, height // p, width // p, channels, p, p)
output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if low_cpu_mem_usage:
try:
import re
from diffusers.models.modeling_utils import \
load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
import accelerate
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model = cls.from_config(config, **transformer_additional_kwargs)
param_device = "cpu"
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_path,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
print(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
return model
except Exception as e:
            print(
                f"The low_cpu_mem_usage mode does not work because of: {e}. Falling back to low_cpu_mem_usage=False."
            )
model = cls.from_config(config, **transformer_additional_kwargs)
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for model_file_safetensors in model_files_safetensors:
_state_dict = load_file(model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
if model.state_dict()['proj.weight'].size() != state_dict['proj.weight'].size():
new_shape = model.state_dict()['proj.weight'].size()
if len(new_shape) == 5:
state_dict['proj.weight'] = state_dict['proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['proj.weight'][:, :, :-1] = 0
else:
if model.state_dict()['proj.weight'].size()[1] > state_dict['proj.weight'].size()[1]:
model.state_dict()['proj.weight'][:, :state_dict['proj.weight'].size()[1], :, :] = state_dict['proj.weight']
model.state_dict()['proj.weight'][:, state_dict['proj.weight'].size()[1]:, :, :] = 0
state_dict['proj.weight'] = model.state_dict()['proj.weight']
else:
model.state_dict()['proj.weight'][:, :, :, :] = state_dict['proj.weight'][:, :model.state_dict()['proj.weight'].size()[1], :, :]
state_dict['proj.weight'] = model.state_dict()['proj.weight']
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
print(m)
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
print(f"### All Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
model = model.to(torch_dtype)
return model
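# ---------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the patch-embedding path used by
# EasyAnimateTransformer3DModel.forward, reproduced with standalone tensors. The Conv2d below is
# a stand-in for `self.proj`, and every size is a hypothetical value chosen only for demonstration.
def _example_patchify():
    b, c, f, h, w, inner_dim, p = 1, 16, 4, 32, 32, 64, 2
    proj = nn.Conv2d(c, inner_dim, kernel_size=(p, p), stride=p, bias=True)
    latents = torch.randn(b, c, f, h, w)
    x = rearrange(latents, "b c f h w -> (b f) c h w")
    x = proj(x)                                              # (b*f, inner_dim, h//p, w//p)
    x = rearrange(x, "(b f) c h w -> b c f h w", f=f, h=h // p, w=w // p)
    tokens = x.flatten(2).transpose(1, 2)                    # (b, f*(h//p)*(w//p), inner_dim)
    return tokens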