|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, Optional |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from ...configuration_utils import ConfigMixin, register_to_config |
|
from ...utils import is_torch_version, logging |
|
from ..attention import BasicTransformerBlock |
|
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection |
|
from ..modeling_outputs import Transformer2DModelOutput |
|
from ..modeling_utils import ModelMixin |
|
from ..normalization import AdaLayerNormSingle |
|
|
|
|
|
# Module-level logger named after this module (uses the project's `logging` helper imported above).
logger = logging.get_logger(__name__)
|
|
|
|
|
class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
    https://arxiv.org/abs/2403.04692).

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional, defaults to 8):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input. If `None`, `in_channels` is used.
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks. It is recorded in the config but is
            not referenced anywhere else in this class.
        cross_attention_dim (int, optional, defaults to 1152):
            The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 128):
            The height and width of the latent images (both are taken from this value). This parameter is fixed
            during training.
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference. Must not be `None`.
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        norm_type (str, optional, defaults to "ada_norm_single"):
            Specifies the type of normalization used. Only 'ada_norm_single' is accepted; any other value raises
            `NotImplementedError`.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        interpolation_scale (int, optional):
            Scale factor to use during interpolating the position embeddings. If `None`, it is set to
            `max(sample_size // 64, 1)`.
        use_additional_conditions (bool, optional):
            If we're using additional conditions as inputs. If `None`, it is enabled only when `sample_size == 128`.
        caption_channels (int, optional, defaults to None):
            Number of channels to use for projecting the caption embeddings. If `None`, no caption projection layer
            is created.
        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.

    Raises:
        NotImplementedError: If `norm_type` is not "ada_norm_single".
        ValueError: If `num_embeds_ada_norm` is `None`.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = 8,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = 1152,
        attention_bias: bool = True,
        sample_size: int = 128,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        use_additional_conditions: Optional[bool] = None,
        caption_channels: Optional[int] = None,
        attention_type: Optional[str] = "default",
    ):
        super().__init__()

        # Validate inputs. Once the first check passes, `norm_type` is known to be "ada_norm_single",
        # so the second check only needs to look at `num_embeds_ada_norm`.
        if norm_type != "ada_norm_single":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        if num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Common derived attributes.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        if use_additional_conditions is None:
            # Unless stated explicitly, additional conditions are enabled only for `sample_size == 128`.
            use_additional_conditions = sample_size == 128
        self.use_additional_conditions = use_additional_conditions

        self.gradient_checkpointing = False

        # Patch embedding of the (square) latent input.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        interpolation_scale = (
            self.config.interpolation_scale
            if self.config.interpolation_scale is not None
            else max(self.config.sample_size // 64, 1)
        )
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
        )

        # Stack of `num_layers` transformer blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    cross_attention_dim=self.config.cross_attention_dim,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    attention_type=self.config.attention_type,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # Output head: affine-free LayerNorm, a learned (shift, scale) table and a linear projection that emits
        # `patch_size * patch_size * out_channels` values per token.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=self.use_additional_conditions
        )

        # The caption projection is optional and only built when `caption_channels` is configured.
        self.caption_projection = None
        if self.config.caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=self.config.caption_channels, hidden_size=self.inner_dim
            )

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `module.gradient_checkpointing` to `value` if `module` has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @staticmethod
    def _mask_to_bias(mask: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
        """Convert a 2-D keep/discard mask into an additive bias; return any other input unchanged.

        For a 2-D `mask`, entries equal to 1 become 0 and entries equal to 0 become -10000.0 (computed in `dtype`),
        and a singleton dimension is inserted at axis 1. `None` and non-2-D tensors are passed through as they are.
        """
        if mask is not None and mask.ndim == 2:
            mask = (1 - mask.to(dtype)) * -10000.0
            mask = mask.unsqueeze(1)
        return mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PixArtTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. It is passed to `adaln_single` together with `added_cond_kwargs`.
            added_cond_kwargs: (`Dict[str, Any]`, *optional*):
                Additional conditions to be used as inputs. Required when the model was built with additional
                conditions enabled.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first
            element is the sample tensor.

        Raises:
            ValueError: If additional conditions are enabled and `added_cond_kwargs` is `None`.
        """
        if self.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # 2-D masks (1 = keep, 0 = discard) are turned into additive biases with a singleton axis at dim 1;
        # masks that are `None` or already have another rank are forwarded untouched.
        attention_mask = self._mask_to_bias(attention_mask, hidden_states.dtype)
        encoder_attention_mask = self._mask_to_bias(encoder_attention_mask, hidden_states.dtype)

        # 1. Input. The patch-grid size must be read before `pos_embed` replaces `hidden_states`.
        batch_size = hidden_states.shape[0]
        patch_size = self.config.patch_size
        height = hidden_states.shape[-2] // patch_size
        width = hidden_states.shape[-1] // patch_size
        hidden_states = self.pos_embed(hidden_states)

        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.caption_projection is not None:
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            # Reshape the projected captions so their last dim matches that of `hidden_states`.
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks. The checkpointing decision, wrapper factory and kwargs do not depend on the block, so
        # they are prepared once instead of once per layer.
        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:

            def create_custom_forward(module):
                # `checkpoint` passes its inputs positionally; this wrapper forwards them to the block.
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for block in self.transformer_blocks:
            if use_checkpointing:
                # Positional order mirrors the keyword call below; the trailing `None` is `class_labels`.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

        # 3. Output. The learned (2, inner_dim) table is offset by the embedded timestep and split along dim 1
        # into `shift` and `scale` (assumes `embedded_timestep` is (batch, inner_dim) - TODO confirm).
        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
        ).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
        hidden_states = self.proj_out(hidden_states)
        # `squeeze(1)` only has an effect when dim 1 has size 1; the reshape below restores the layout.
        hidden_states = hidden_states.squeeze(1)

        # Unpatchify: (n, h, w, p, q, c) -> (n, c, h, p, w, q) -> (n, c, h * p, w * q).
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * patch_size, width * patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
|