Yoav HaCohen committed
Commit: f63ea56
Parents (2): 427926d 07ddecf

Merge pull request #34 from LightricksResearch/add_atten_to_decoder

xora/models/autoencoders/causal_video_autoencoder.py CHANGED
 
@@ -9,10 +9,12 @@ import numpy as np
 from einops import rearrange
 from torch import nn
 from diffusers.utils import logging
+import torch.nn.functional as F
 
 from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
 from xora.models.autoencoders.pixel_norm import PixelNorm
 from xora.models.autoencoders.vae import AutoencoderKLWrapper
+from xora.models.transformers.attention import Attention
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -212,6 +214,12 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
         last_layer = self.decoder.layers[-1]
         return last_layer
 
+    def set_use_tpu_flash_attention(self):
+        for block in self.decoder.up_blocks:
+            if isinstance(block, AttentionResBlocks):
+                for attention_block in block.attention_blocks:
+                    attention_block.set_use_tpu_flash_attention()
+
 
 class Encoder(nn.Module):
     r"""
 
@@ -485,6 +493,16 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
                 )
+            elif block_name == "attn_res_x":
+                block = AttentionResBlocks(
+                    dims=dims,
+                    in_channels=input_channel,
+                    num_layers=block_params["num_layers"],
+                    resnet_groups=norm_num_groups,
+                    norm_layer=norm_layer,
+                    attention_head_dim=block_params["attention_head_dim"],
+                    inject_noise=block_params.get("inject_noise", False),
+                )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
                 block = ResnetBlock3D(
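The `attn_res_x` branch mirrors the existing `res_x` handling but builds the new attention-augmented block. For illustration only, a hypothetical decoder `blocks` spec that would exercise this branch; the block-name/param schema is inferred from the `block_params` lookups in the diff, not shown here:

    # Illustrative only: names and values inferred from the dispatch above.
    decoder_blocks = [
        ("res_x", {"num_layers": 2}),
        ("attn_res_x", {"num_layers": 2, "attention_head_dim": 64}),
        ("res_x_y", {"multiplier": 2}),
    ]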
 
@@ -562,6 +580,129 @@ class Decoder(nn.Module):
         return sample
 
 
+class AttentionResBlocks(nn.Module):
+    """
+    A 3D convolution residual block followed by a self-attention residual block.
+
+    Args:
+        dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
+        in_channels (`int`): The number of input channels.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
+        resnet_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use in the group normalization layers of the resnet blocks.
+        norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use.
+        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
+        inject_noise (`bool`, *optional*, defaults to `False`): Whether to inject noise between convolution layers.
+
+    Returns:
+        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+        in_channels, height, width)`.
+
+    """
+
+    def __init__(
+        self,
+        dims: Union[int, Tuple[int, int]],
+        in_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_groups: int = 32,
+        norm_layer: str = "group_norm",
+        attention_head_dim: int = 64,
+        inject_noise: bool = False,
+    ):
+        super().__init__()
+
+        if attention_head_dim > in_channels:
+            raise ValueError(
+                "attention_head_dim must be less than or equal to in_channels"
+            )
+
+        resnet_groups = (
+            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+        )
+
+        self.res_blocks = []
+        self.attention_blocks = []
+        for i in range(num_layers):
+            self.res_blocks.append(
+                ResnetBlock3D(
+                    dims=dims,
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    norm_layer=norm_layer,
+                    inject_noise=inject_noise,
+                )
+            )
+            self.attention_blocks.append(
+                Attention(
+                    query_dim=in_channels,
+                    heads=in_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
+                    bias=True,
+                    out_bias=True,
+                    qk_norm="rms_norm",
+                    residual_connection=True,
+                )
+            )
+
+        self.res_blocks = nn.ModuleList(self.res_blocks)
+        self.attention_blocks = nn.ModuleList(self.attention_blocks)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, causal: bool = True
+    ) -> torch.FloatTensor:
+        for resnet, attention in zip(self.res_blocks, self.attention_blocks):
+            hidden_states = resnet(hidden_states, causal=causal)
+
+            # Reshape the hidden states to be (batch_size, frames * height * width, channel)
+            batch_size, channel, frames, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, frames * height * width
+            ).transpose(1, 2)
+
+            if attention.use_tpu_flash_attention:
+                # Pad the second dimension to be divisible by block_k_major (block in flash attention)
+                seq_len = hidden_states.shape[1]
+                block_k_major = 512
+                pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
+                if pad_len > 0:
+                    hidden_states = F.pad(
+                        hidden_states, (0, 0, 0, pad_len), "constant", 0
+                    )
+
+                # Create a mask with ones for the original sequence length and zeros for the padded indexes
+                mask = torch.ones(
+                    (hidden_states.shape[0], seq_len),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+                if pad_len > 0:
+                    mask = F.pad(mask, (0, pad_len), "constant", 0)
+
+            hidden_states = attention(
+                hidden_states,
+                attention_mask=None if not attention.use_tpu_flash_attention else mask,
+            )
+
+            if attention.use_tpu_flash_attention:
+                # Remove the padding
+                if pad_len > 0:
+                    hidden_states = hidden_states[:, :-pad_len, :]
+
+            # Reshape the hidden states back to (batch_size, channel, frames, height, width)
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, frames, height, width
+            )
+        return hidden_states
+
+
 class UNetMidBlock3D(nn.Module):
     """
     A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
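The shape handling in `AttentionResBlocks.forward` is the subtle part of this change: each (C, F, H, W) volume is flattened to a token sequence, padded up to the TPU flash-attention block size, masked, and then un-padded and reshaped back. A self-contained sketch of just that bookkeeping, with plain tensors and no attention call (this is an illustration, not code from the diff; `block_k_major = 512` matches the value used above):

    # Standalone illustration of the pad/mask/unpad bookkeeping used around
    # the TPU flash-attention call; not part of the diff itself.
    import torch
    import torch.nn.functional as F

    batch_size, channel, frames, height, width = 1, 8, 4, 6, 6
    x = torch.randn(batch_size, channel, frames, height, width)

    # (B, C, F, H, W) -> (B, F*H*W, C): one token per spatio-temporal position.
    tokens = x.view(batch_size, channel, frames * height * width).transpose(1, 2)

    # Pad the sequence length up to a multiple of the flash-attention block size.
    block_k_major = 512
    seq_len = tokens.shape[1]
    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
    if pad_len > 0:
        tokens = F.pad(tokens, (0, 0, 0, pad_len))  # pads dim 1 (the sequence)

    # Mask is 1 for real positions, 0 for padding.
    mask = torch.ones(batch_size, seq_len)
    if pad_len > 0:
        mask = F.pad(mask, (0, pad_len))

    # ... the attention call would run here, consuming `tokens` and `mask` ...

    # Drop the padding and restore (B, C, F, H, W).
    if pad_len > 0:
        tokens = tokens[:, :-pad_len, :]
    out = tokens.transpose(-1, -2).reshape(batch_size, channel, frames, height, width)
    assert out.shape == x.shape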