Merge pull request #39 from LightricksResearch/bugfix/fix-attention-and-timestep-conditioning
Browse files
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -220,7 +220,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
220 |
|
221 |
def set_use_tpu_flash_attention(self):
|
222 |
for block in self.decoder.up_blocks:
|
223 |
-
if isinstance(block,
|
224 |
for attention_block in block.attention_blocks:
|
225 |
attention_block.set_use_tpu_flash_attention()
|
226 |
|
@@ -497,17 +497,18 @@ class Decoder(nn.Module):
|
|
497 |
resnet_groups=norm_num_groups,
|
498 |
norm_layer=norm_layer,
|
499 |
inject_noise=block_params.get("inject_noise", False),
|
|
|
500 |
)
|
501 |
elif block_name == "attn_res_x":
|
502 |
-
block =
|
503 |
dims=dims,
|
504 |
in_channels=input_channel,
|
505 |
num_layers=block_params["num_layers"],
|
506 |
resnet_groups=norm_num_groups,
|
507 |
norm_layer=norm_layer,
|
508 |
-
attention_head_dim=block_params["attention_head_dim"],
|
509 |
inject_noise=block_params.get("inject_noise", False),
|
510 |
timestep_conditioning=timestep_conditioning,
|
|
|
511 |
)
|
512 |
elif block_name == "res_x_y":
|
513 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
@@ -642,129 +643,6 @@ class Decoder(nn.Module):
|
|
642 |
return sample
|
643 |
|
644 |
|
645 |
-
class AttentionResBlocks(nn.Module):
|
646 |
-
"""
|
647 |
-
A 3D convolution residual block followed by self attention residual block
|
648 |
-
|
649 |
-
Args:
|
650 |
-
dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
|
651 |
-
in_channels (`int`): The number of input channels.
|
652 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
653 |
-
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
654 |
-
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
655 |
-
resnet_groups (`int`, *optional*, defaults to 32):
|
656 |
-
The number of groups to use in the group normalization layers of the resnet blocks.
|
657 |
-
norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use.
|
658 |
-
attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
|
659 |
-
inject_noise (`bool`, *optional*, defaults to `False`): Whether to inject noise or not between convolution layers.
|
660 |
-
|
661 |
-
Returns:
|
662 |
-
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
663 |
-
in_channels, height, width)`.
|
664 |
-
|
665 |
-
"""
|
666 |
-
|
667 |
-
def __init__(
|
668 |
-
self,
|
669 |
-
dims: Union[int, Tuple[int, int]],
|
670 |
-
in_channels: int,
|
671 |
-
dropout: float = 0.0,
|
672 |
-
num_layers: int = 1,
|
673 |
-
resnet_eps: float = 1e-6,
|
674 |
-
resnet_groups: int = 32,
|
675 |
-
norm_layer: str = "group_norm",
|
676 |
-
attention_head_dim: int = 64,
|
677 |
-
inject_noise: bool = False,
|
678 |
-
):
|
679 |
-
super().__init__()
|
680 |
-
|
681 |
-
if attention_head_dim > in_channels:
|
682 |
-
raise ValueError(
|
683 |
-
"attention_head_dim must be less than or equal to in_channels"
|
684 |
-
)
|
685 |
-
|
686 |
-
resnet_groups = (
|
687 |
-
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
688 |
-
)
|
689 |
-
|
690 |
-
self.res_blocks = []
|
691 |
-
self.attention_blocks = []
|
692 |
-
for i in range(num_layers):
|
693 |
-
self.res_blocks.append(
|
694 |
-
ResnetBlock3D(
|
695 |
-
dims=dims,
|
696 |
-
in_channels=in_channels,
|
697 |
-
out_channels=in_channels,
|
698 |
-
eps=resnet_eps,
|
699 |
-
groups=resnet_groups,
|
700 |
-
dropout=dropout,
|
701 |
-
norm_layer=norm_layer,
|
702 |
-
inject_noise=inject_noise,
|
703 |
-
)
|
704 |
-
)
|
705 |
-
self.attention_blocks.append(
|
706 |
-
Attention(
|
707 |
-
query_dim=in_channels,
|
708 |
-
heads=in_channels // attention_head_dim,
|
709 |
-
dim_head=attention_head_dim,
|
710 |
-
bias=True,
|
711 |
-
out_bias=True,
|
712 |
-
qk_norm="rms_norm",
|
713 |
-
residual_connection=True,
|
714 |
-
)
|
715 |
-
)
|
716 |
-
|
717 |
-
self.res_blocks = nn.ModuleList(self.res_blocks)
|
718 |
-
self.attention_blocks = nn.ModuleList(self.attention_blocks)
|
719 |
-
|
720 |
-
def forward(
|
721 |
-
self, hidden_states: torch.FloatTensor, causal: bool = True
|
722 |
-
) -> torch.FloatTensor:
|
723 |
-
for resnet, attention in zip(self.res_blocks, self.attention_blocks):
|
724 |
-
hidden_states = resnet(hidden_states, causal=causal)
|
725 |
-
|
726 |
-
# Reshape the hidden states to be (batch_size, frames * height * width, channel)
|
727 |
-
batch_size, channel, frames, height, width = hidden_states.shape
|
728 |
-
hidden_states = hidden_states.view(
|
729 |
-
batch_size, channel, frames * height * width
|
730 |
-
).transpose(1, 2)
|
731 |
-
|
732 |
-
if attention.use_tpu_flash_attention:
|
733 |
-
# Pad the second dimension to be divisible by block_k_major (block in flash attention)
|
734 |
-
seq_len = hidden_states.shape[1]
|
735 |
-
block_k_major = 512
|
736 |
-
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
|
737 |
-
if pad_len > 0:
|
738 |
-
hidden_states = F.pad(
|
739 |
-
hidden_states, (0, 0, 0, pad_len), "constant", 0
|
740 |
-
)
|
741 |
-
|
742 |
-
# Create a mask with ones for the original sequence length and zeros for the padded indexes
|
743 |
-
mask = torch.ones(
|
744 |
-
(hidden_states.shape[0], seq_len),
|
745 |
-
device=hidden_states.device,
|
746 |
-
dtype=hidden_states.dtype,
|
747 |
-
)
|
748 |
-
if pad_len > 0:
|
749 |
-
mask = F.pad(mask, (0, pad_len), "constant", 0)
|
750 |
-
|
751 |
-
hidden_states = attention(
|
752 |
-
hidden_states,
|
753 |
-
attention_mask=None if not attention.use_tpu_flash_attention else mask,
|
754 |
-
)
|
755 |
-
|
756 |
-
if attention.use_tpu_flash_attention:
|
757 |
-
# Remove the padding
|
758 |
-
if pad_len > 0:
|
759 |
-
hidden_states = hidden_states[:, :-pad_len, :]
|
760 |
-
|
761 |
-
# Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
|
762 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
763 |
-
batch_size, channel, frames, height, width
|
764 |
-
)
|
765 |
-
return hidden_states
|
766 |
-
|
767 |
-
|
768 |
class UNetMidBlock3D(nn.Module):
|
769 |
"""
|
770 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
@@ -776,6 +654,14 @@ class UNetMidBlock3D(nn.Module):
|
|
776 |
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
777 |
resnet_groups (`int`, *optional*, defaults to 32):
|
778 |
The number of groups to use in the group normalization layers of the resnet blocks.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
779 |
|
780 |
Returns:
|
781 |
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
@@ -794,6 +680,7 @@ class UNetMidBlock3D(nn.Module):
|
|
794 |
norm_layer: str = "group_norm",
|
795 |
inject_noise: bool = False,
|
796 |
timestep_conditioning: bool = False,
|
|
|
797 |
):
|
798 |
super().__init__()
|
799 |
resnet_groups = (
|
@@ -823,6 +710,29 @@ class UNetMidBlock3D(nn.Module):
|
|
823 |
]
|
824 |
)
|
825 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
826 |
def forward(
|
827 |
self,
|
828 |
hidden_states: torch.FloatTensor,
|
@@ -845,10 +755,60 @@ class UNetMidBlock3D(nn.Module):
|
|
845 |
timestep_embed = timestep_embed.view(
|
846 |
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
847 |
)
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
852 |
return hidden_states
|
853 |
|
854 |
|
|
|
220 |
|
221 |
def set_use_tpu_flash_attention(self):
|
222 |
for block in self.decoder.up_blocks:
|
223 |
+
if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
|
224 |
for attention_block in block.attention_blocks:
|
225 |
attention_block.set_use_tpu_flash_attention()
|
226 |
|
|
|
497 |
resnet_groups=norm_num_groups,
|
498 |
norm_layer=norm_layer,
|
499 |
inject_noise=block_params.get("inject_noise", False),
|
500 |
+
timestep_conditioning=timestep_conditioning,
|
501 |
)
|
502 |
elif block_name == "attn_res_x":
|
503 |
+
block = UNetMidBlock3D(
|
504 |
dims=dims,
|
505 |
in_channels=input_channel,
|
506 |
num_layers=block_params["num_layers"],
|
507 |
resnet_groups=norm_num_groups,
|
508 |
norm_layer=norm_layer,
|
|
|
509 |
inject_noise=block_params.get("inject_noise", False),
|
510 |
timestep_conditioning=timestep_conditioning,
|
511 |
+
attention_head_dim=block_params["attention_head_dim"],
|
512 |
)
|
513 |
elif block_name == "res_x_y":
|
514 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
|
|
643 |
return sample
|
644 |
|
645 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
646 |
class UNetMidBlock3D(nn.Module):
|
647 |
"""
|
648 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
|
|
654 |
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
655 |
resnet_groups (`int`, *optional*, defaults to 32):
|
656 |
The number of groups to use in the group normalization layers of the resnet blocks.
|
657 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
658 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
659 |
+
inject_noise (`bool`, *optional*, defaults to `False`):
|
660 |
+
Whether to inject noise into the hidden states.
|
661 |
+
timestep_conditioning (`bool`, *optional*, defaults to `False`):
|
662 |
+
Whether to condition the hidden states on the timestep.
|
663 |
+
attention_head_dim (`int`, *optional*, defaults to -1):
|
664 |
+
The dimension of the attention head. If -1, no attention is used.
|
665 |
|
666 |
Returns:
|
667 |
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
|
|
680 |
norm_layer: str = "group_norm",
|
681 |
inject_noise: bool = False,
|
682 |
timestep_conditioning: bool = False,
|
683 |
+
attention_head_dim: int = -1,
|
684 |
):
|
685 |
super().__init__()
|
686 |
resnet_groups = (
|
|
|
710 |
]
|
711 |
)
|
712 |
|
713 |
+
self.attention_blocks = None
|
714 |
+
|
715 |
+
if attention_head_dim > 0:
|
716 |
+
if attention_head_dim > in_channels:
|
717 |
+
raise ValueError(
|
718 |
+
"attention_head_dim must be less than or equal to in_channels"
|
719 |
+
)
|
720 |
+
|
721 |
+
self.attention_blocks = nn.ModuleList(
|
722 |
+
[
|
723 |
+
Attention(
|
724 |
+
query_dim=in_channels,
|
725 |
+
heads=in_channels // attention_head_dim,
|
726 |
+
dim_head=attention_head_dim,
|
727 |
+
bias=True,
|
728 |
+
out_bias=True,
|
729 |
+
qk_norm="rms_norm",
|
730 |
+
residual_connection=True,
|
731 |
+
)
|
732 |
+
for _ in range(num_layers)
|
733 |
+
]
|
734 |
+
)
|
735 |
+
|
736 |
def forward(
|
737 |
self,
|
738 |
hidden_states: torch.FloatTensor,
|
|
|
755 |
timestep_embed = timestep_embed.view(
|
756 |
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
757 |
)
|
758 |
+
|
759 |
+
if self.attention_blocks:
|
760 |
+
for resnet, attention in zip(self.res_blocks, self.attention_blocks):
|
761 |
+
hidden_states = resnet(
|
762 |
+
hidden_states, causal=causal, timesteps=timestep_embed
|
763 |
+
)
|
764 |
+
|
765 |
+
# Reshape the hidden states to be (batch_size, frames * height * width, channel)
|
766 |
+
batch_size, channel, frames, height, width = hidden_states.shape
|
767 |
+
hidden_states = hidden_states.view(
|
768 |
+
batch_size, channel, frames * height * width
|
769 |
+
).transpose(1, 2)
|
770 |
+
|
771 |
+
if attention.use_tpu_flash_attention:
|
772 |
+
# Pad the second dimension to be divisible by block_k_major (block in flash attention)
|
773 |
+
seq_len = hidden_states.shape[1]
|
774 |
+
block_k_major = 512
|
775 |
+
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
|
776 |
+
if pad_len > 0:
|
777 |
+
hidden_states = F.pad(
|
778 |
+
hidden_states, (0, 0, 0, pad_len), "constant", 0
|
779 |
+
)
|
780 |
+
|
781 |
+
# Create a mask with ones for the original sequence length and zeros for the padded indexes
|
782 |
+
mask = torch.ones(
|
783 |
+
(hidden_states.shape[0], seq_len),
|
784 |
+
device=hidden_states.device,
|
785 |
+
dtype=hidden_states.dtype,
|
786 |
+
)
|
787 |
+
if pad_len > 0:
|
788 |
+
mask = F.pad(mask, (0, pad_len), "constant", 0)
|
789 |
+
|
790 |
+
hidden_states = attention(
|
791 |
+
hidden_states,
|
792 |
+
attention_mask=(
|
793 |
+
None if not attention.use_tpu_flash_attention else mask
|
794 |
+
),
|
795 |
+
)
|
796 |
+
|
797 |
+
if attention.use_tpu_flash_attention:
|
798 |
+
# Remove the padding
|
799 |
+
if pad_len > 0:
|
800 |
+
hidden_states = hidden_states[:, :-pad_len, :]
|
801 |
+
|
802 |
+
# Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
|
803 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
804 |
+
batch_size, channel, frames, height, width
|
805 |
+
)
|
806 |
+
else:
|
807 |
+
for resnet in self.res_blocks:
|
808 |
+
hidden_states = resnet(
|
809 |
+
hidden_states, causal=causal, timesteps=timestep_embed
|
810 |
+
)
|
811 |
+
|
812 |
return hidden_states
|
813 |
|
814 |
|