|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin |
|
from diffusers.models.attention import FeedForward |
|
from diffusers.models.attention_processor import ( |
|
Attention, |
|
apply_rope, |
|
) |
|
from diffusers.models.modeling_utils import ModelMixin |
|
from diffusers.models.normalization import ( |
|
AdaLayerNormContinuous, |
|
AdaLayerNormZero, |
|
AdaLayerNormZeroSingle, |
|
) |
|
from diffusers.utils import ( |
|
USE_PEFT_BACKEND, |
|
is_torch_version, |
|
logging, |
|
scale_lora_layers, |
|
unscale_lora_layers, |
|
) |
|
from diffusers.utils.torch_utils import maybe_allow_in_graph |
|
from diffusers.models.embeddings import ( |
|
CombinedTimestepGuidanceTextProjEmbeddings, |
|
CombinedTimestepTextProjEmbeddings, |
|
) |
|
from diffusers.models.modeling_outputs import Transformer2DModelOutput |
|
|
|
|
|
# Module-level logger from diffusers' logging utilities; the model's forward uses it to
# warn when a LoRA `scale` is supplied but cannot take effect.
logger = logging.get_logger(__name__)
|
|
|
|
|
class FluxSingleAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Used by the single-stream Flux blocks: only ``attn.to_q`` / ``attn.to_k`` / ``attn.to_v`` are applied to
    ``hidden_states`` and the raw multi-head attention result is returned (no output projection is run here).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run multi-head self-attention over ``hidden_states``.

        Args:
            attn: The owning ``Attention`` module; supplies the q/k/v projections, the head count and the
                optional ``norm_q`` / ``norm_k`` layers.
            hidden_states: ``(batch, seq_len, dim)`` tokens, or a ``(batch, channel, height, width)`` feature map
                which is flattened into a sequence and restored on return.
            encoder_hidden_states: Accepted for processor-interface compatibility; not used.
            attention_mask: Optional key mask, expected as ``(batch, seq_len)``. Positions whose value is ``> 0``
                take part in attention; all other key positions are excluded.
            image_rotary_emb: Optional rotary embedding applied to query and key via ``apply_rope``.

        Returns:
            The attention output in the same layout as the input ``hidden_states``.
        """
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, _, _ = hidden_states.shape
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query, key = apply_rope(query, key, image_rotary_emb)

        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads and query positions.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Keep the mask BOOLEAN. SDPA treats a bool mask as "True = attend", but ADDS a
            # float mask to the logits, so casting to a 0/1 float would mask nothing at all.
            attention_mask = (attention_mask > 0).to(device=hidden_states.device)

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        return hidden_states
|
|
|
|
|
class FluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    Implements Flux joint (dual-stream) attention. Image tokens (``hidden_states``) and text tokens
    (``encoder_hidden_states``) are projected with separate q/k/v weights. They are concatenated text-first and
    attended jointly, then split back apart and given separate output projections (``attn.to_out`` for the image
    stream, ``attn.to_add_out`` for the text stream).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """
        Run joint attention over the concatenated text + image sequence.

        Args:
            attn: The owning ``Attention`` module (q/k/v and ``add_*_proj`` projections, optional norms,
                ``to_out`` and ``to_add_out``).
            hidden_states: Image tokens ``(batch, img_len, dim)``, or a 4-D feature map that is flattened and restored.
            encoder_hidden_states: Text tokens ``(batch, txt_len, dim)``, or a 4-D feature map. Required.
            attention_mask: Optional key mask, expected as ``(batch, txt_len + img_len)`` in text-first order.
                Positions whose value is ``> 0`` take part in attention.
            image_rotary_emb: Optional rotary embedding applied to the joint query/key via ``apply_rope``.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` — the image-stream and text-stream outputs.

        Raises:
            ValueError: If ``encoder_hidden_states`` is ``None``.
        """
        if encoder_hidden_states is None:
            raise ValueError(
                "FluxAttnProcessor2_0 requires `encoder_hidden_states` (the text tokens)."
            )

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            # Separate names so the image feature-map shape captured above is not overwritten
            # (it is needed to restore a 4-D `hidden_states` at the end).
            ctx_batch, ctx_channel, ctx_height, ctx_width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(
                ctx_batch, ctx_channel, ctx_height * ctx_width
            ).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # --- image-stream projections ---
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # --- text-stream projections ---
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(
                encoder_hidden_states_query_proj
            )
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(
                encoder_hidden_states_key_proj
            )

        # Joint sequence, text first (this ordering must match the attention mask layout).
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query, key = apply_rope(query, key, image_rotary_emb)

        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads and query positions.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Keep the mask BOOLEAN. SDPA treats a bool mask as "True = attend", but ADDS a
            # float mask to the logits, so casting to a 0/1 float would mask nothing at all.
            attention_mask = (attention_mask > 0).to(device=hidden_states.device)

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Split the joint sequence back into its text and image parts.
        text_seq_len = encoder_hidden_states.shape[1]
        encoder_hidden_states, hidden_states = (
            hidden_states[:, :text_seq_len],
            hidden_states[:, text_seq_len:],
        )

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                batch_size, ctx_channel, ctx_height, ctx_width
            )

        return hidden_states, encoder_hidden_states
|
|
|
|
|
|
|
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Build rotary-position-embedding rotation matrices for a single positional axis.

    Args:
        pos: ``(batch, seq_len)`` tensor of positions along this axis.
        dim: Number of embedding channels for this axis; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(batch, seq_len, dim // 2, 2, 2)``. For every position and frequency it holds
        the 2x2 rotation matrix ``[[cos, -sin], [sin, cos]]``.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError("The dimension must be even.")

    # float64 keeps the angles accurate for large positions, but the MPS backend has no
    # float64 support, so fall back to float32 there.
    freqs_dtype = torch.float32 if pos.device.type == "mps" else torch.float64
    scale = torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, _ = pos.shape  # enforces a 2-D (batch, seq_len) input
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()
|
|
|
|
|
|
|
class EmbedND(nn.Module):
    """
    Multi-axis rotary position embedding.

    Coordinate ``i`` of the position ids gets its own ``rope`` table of width ``axes_dim[i]``. The per-axis
    tables are joined along their frequency dimension.
    """

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim  # kept as an attribute; forward does not read it
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: Position ids whose last dimension indexes the axis; ``ids[..., i]`` is fed to ``rope``.

        Returns:
            The concatenated per-axis rotary tables with a singleton dimension inserted at index 1.
        """
        tables = []
        for axis in range(ids.shape[-1]):
            tables.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        # rope() returns (..., freqs, 2, 2), so the frequency axis is the third from the end.
        joined = torch.cat(tables, dim=-3)
        return joined.unsqueeze(1)
|
|
|
|
|
def expand_flux_attention_mask(
    hidden_states: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Expand a mask so that the image is included.

    The returned mask spans the full sequence length of ``hidden_states``. Its first ``attn_mask.shape[1]``
    positions are copied from ``attn_mask``, and every remaining (trailing) position is set to 1.

    Args:
        hidden_states: Tensor whose first two dimensions are ``(batch, total_seq_len)``; only its shape is read.
        attn_mask: ``(batch, mask_seq_len)`` mask for the leading part of the sequence.

    Returns:
        A ``(batch, total_seq_len)`` tensor of the default floating dtype, on ``attn_mask``'s device.

    Raises:
        ValueError: If the batch sizes of ``hidden_states`` and ``attn_mask`` differ.
    """
    bsz = attn_mask.shape[0]
    if bsz != hidden_states.shape[0]:
        raise ValueError(
            f"Batch size mismatch: attention mask has {bsz}, "
            f"hidden_states has {hidden_states.shape[0]}."
        )
    residual_seq_len = hidden_states.shape[1]
    mask_seq_len = attn_mask.shape[1]

    # Allocate on the mask's device: building on CPU forced a host/device copy every block call.
    expanded_mask = torch.ones(bsz, residual_seq_len, device=attn_mask.device)
    expanded_mask[:, :mask_seq_len] = attn_mask

    return expanded_mask
|
|
|
|
|
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    r"""
    A single-stream Flux transformer block.

    Text and image tokens arrive already concatenated in one sequence. The block works as follows:
    1. An adaptive (timestep-conditioned) norm is applied.
    2. Self-attention and an MLP run in parallel on the normalized tokens.
    3. Both outputs are concatenated along the channel dimension and projected back to ``dim``.
    4. The result is gated and added to the residual.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        mlp_ratio (`float`, defaults to 4.0): Hidden width of the MLP branch, as a multiple of ``dim``.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        # Projects the concatenated [attention, mlp] features back to `dim`.
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        processor = FluxSingleAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states: ``(batch, seq_len, dim)`` joint text + image tokens.
            temb: Conditioning embedding consumed by the adaptive norm.
            image_rotary_emb: Optional rotary embedding forwarded to the attention processor.
            attention_mask: Optional mask covering the leading part of the sequence. It is padded with ones
                up to ``seq_len`` via ``expand_flux_attention_mask`` before attention.

        Returns:
            The updated ``hidden_states`` (same shape as the input).
        """
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        if attention_mask is not None:
            attention_mask = expand_flux_attention_mask(
                hidden_states,
                attention_mask,
            )

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
        )

        # Parallel attention/MLP branches fused by one output projection.
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        # unsqueeze(1) broadcasts the gate over the sequence axis (gate presumably (batch, dim)).
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
|
|
|
|
|
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    r"""
    A dual-stream Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Image tokens (``hidden_states``) and text tokens (``encoder_hidden_states``) have their own adaptive norms
    and feed-forward networks, but share one joint attention over both sequences.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        qk_norm (`str`, defaults to `"rms_norm"`): Query/key normalization passed to the `Attention` module.
        eps (`float`, defaults to 1e-6): Epsilon passed to the `Attention` module.
    """

    def __init__(
        self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
    ):
        super().__init__()

        # Separate adaptive norms for the image stream and the text (context) stream.
        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        if hasattr(F, "scaled_dot_product_attention"):
            processor = FluxAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(
            dim=dim, dim_out=dim, activation_fn="gelu-approximate"
        )

        # Feed-forward chunking settings (not read by this block's forward).
        self._chunk_size = None
        self._chunk_dim = 0

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states: Image tokens.
            encoder_hidden_states: Text tokens.
            temb: Conditioning embedding consumed by both adaptive norms.
            image_rotary_emb: Optional rotary embedding forwarded to the joint attention.
            attention_mask: Optional mask over the text tokens. It is padded with ones so it covers the joint
                ``[text, image]`` sequence before attention.

        Returns:
            ``(encoder_hidden_states, hidden_states)`` — note the text stream comes FIRST, the reverse of the
            argument order.
        """
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb
        )

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
            self.norm1_context(encoder_hidden_states, emb=temb)
        )

        if attention_mask is not None:
            # Only the shape of the joint (text-first) sequence is needed for the expansion.
            attention_mask = expand_flux_attention_mask(
                torch.cat([encoder_hidden_states, hidden_states], dim=1),
                attention_mask,
            )

        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
        )

        # --- image stream: gated attention residual, then modulated, gated feed-forward ---
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = (
            norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        )

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # --- text stream: same structure with the context parameters ---
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = (
            norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
            + c_shift_mlp[:, None]
        )

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = (
            encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        )

        return encoder_hidden_states, hidden_states
|
|
|
|
|
class FluxTransformer2DModelWithMasking(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
    """
    The Transformer model introduced in Flux, with optional masking of text tokens in attention.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`, *optional*, defaults to 1): Patch size used when sizing the output projection.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT (dual-stream) blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions.
        pooled_projection_dim (`int`, *optional*, defaults to 768): Number of dimensions of `pooled_projections`.
        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
        axes_dims_rope (`List[int]`, defaults to `[16, 56, 56]`): Per-axis rotary embedding widths.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = (
            self.config.num_attention_heads * self.config.attention_head_dim
        )

        self.pos_embed = EmbedND(
            dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
        )
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings
            if guidance_embeds
            else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim,
            pooled_projection_dim=self.config.pooled_projection_dim,
        )

        self.context_embedder = nn.Linear(
            self.config.joint_attention_dim, self.inner_dim
        )
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
        )
        self.proj_out = nn.Linear(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModelWithMasking`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, image_seq_len, in_channels)`):
                Input image tokens (the last dimension is projected by `x_embedder`).
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step. It is multiplied by 1000 before embedding.
            img_ids (`torch.Tensor`):
                Positional ids of the image tokens, concatenated after `txt_ids` for the rotary embedding.
            txt_ids (`torch.Tensor`):
                Positional ids of the text tokens.
            guidance (`torch.Tensor`, *optional*):
                Guidance strength. It is multiplied by 1000 and passed to the guidance-aware time embedding.
            joint_attention_kwargs (`dict`, *optional*):
                Only the `"scale"` key is read, as the LoRA scale (PEFT backend only). Other keys are currently
                ignored.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            attention_mask (`torch.Tensor`, *optional*):
                Mask over the text tokens. Each block expands it to cover the image tokens and passes it to its
                attention processor.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        scale_was_passed = False
        if joint_attention_kwargs is not None:
            # Copy so popping "scale" does not mutate the caller's dict.
            joint_attention_kwargs = joint_attention_kwargs.copy()
            # Record this BEFORE popping; otherwise the warning below could never fire.
            scale_was_passed = joint_attention_kwargs.get("scale", None) is not None
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        try:
            hidden_states = self.x_embedder(hidden_states)

            timestep = timestep.to(hidden_states.dtype) * 1000
            if guidance is not None:
                guidance = guidance.to(hidden_states.dtype) * 1000
                temb = self.time_text_embed(timestep, guidance, pooled_projections)
            else:
                temb = self.time_text_embed(timestep, pooled_projections)
            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

            # Text ids first, matching the text-first order of the joint sequence.
            ids = torch.cat((txt_ids, img_ids), dim=1)
            image_rotary_emb = self.pos_embed(ids)

            # Loop-invariant checkpointing setup, computed once instead of per block.
            use_checkpointing = self.training and self.gradient_checkpointing
            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )

            # --- dual-stream blocks ---
            for block in self.transformer_blocks:
                if use_checkpointing:
                    # Positional order matches FluxTransformerBlock.forward.
                    encoder_hidden_states, hidden_states = (
                        torch.utils.checkpoint.checkpoint(
                            block,
                            hidden_states,
                            encoder_hidden_states,
                            temb,
                            image_rotary_emb,
                            attention_mask,
                            **ckpt_kwargs,
                        )
                    )
                else:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        attention_mask=attention_mask,
                    )

            # --- single-stream blocks over the joint [text, image] sequence ---
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            for block in self.single_transformer_blocks:
                if use_checkpointing:
                    # Positional order matches FluxSingleTransformerBlock.forward.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        block,
                        hidden_states,
                        temb,
                        image_rotary_emb,
                        attention_mask,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        attention_mask=attention_mask,
                    )

            # Drop the text tokens; only image tokens are projected to the output.
            hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

            hidden_states = self.norm_out(hidden_states, temb)
            output = self.proj_out(hidden_states)
        finally:
            if USE_PEFT_BACKEND:
                # remove `lora_scale` from each PEFT layer, even if the forward pass raised
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
|