Sapir committed
Commit 65dad79
1 Parent(s): f63ea56

VAE: Add timestep conditioning

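In short: `AutoencoderKLWrapper.decode` gains an optional `timesteps` argument that is threaded through `Decoder.forward`, `UNetMidBlock3D`, and `ResnetBlock3D`. Timesteps are embedded with diffusers' `PixArtAlphaCombinedTimestepSizeEmbeddings` and applied as learned per-channel scale/shift modulations, both inside the mid-block resnets and once more on the decoder output; a sketch of the modulation pattern follows the first file's diff below.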
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -10,6 +10,8 @@ from einops import rearrange
 from torch import nn
 from diffusers.utils import logging
 import torch.nn.functional as F
+from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
+
 
 from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
 from xora.models.autoencoders.pixel_norm import PixelNorm
@@ -94,6 +96,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
             patch_size=config.get("patch_size", 1),
             norm_layer=config.get("norm_layer", "group_norm"),
             causal=config.get("causal_decoder", False),
+            timestep_conditioning=config.get("timestep_conditioning", False),
         )
 
         dims = config["dims"]
@@ -122,6 +125,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
             latent_log_var=self.encoder.latent_log_var,
             use_quant_conv=self.use_quant_conv,
             causal_decoder=self.decoder.causal,
+            timestep_conditioning=self.decoder.timestep_conditioning,
         )
 
     @property
@@ -449,6 +453,7 @@ class Decoder(nn.Module):
         patch_size: int = 1,
         norm_layer: str = "group_norm",
         causal: bool = True,
+        timestep_conditioning: bool = False,
     ):
         super().__init__()
         self.patch_size = patch_size
@@ -502,6 +507,7 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                     attention_head_dim=block_params["attention_head_dim"],
                     inject_noise=block_params.get("inject_noise", False),
+                    timestep_conditioning=timestep_conditioning,
                 )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
@@ -513,6 +519,7 @@ class Decoder(nn.Module):
                     groups=norm_num_groups,
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
+                    timestep_conditioning=False,
                 )
             elif block_name == "compress_time":
                 block = DepthToSpaceUpsample(
@@ -552,9 +559,28 @@ class Decoder(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
+        self.timestep_conditioning = timestep_conditioning
+
+        if timestep_conditioning:
+            self.timestep_scale_multiplier = nn.Parameter(
+                torch.tensor(1000.0, dtype=torch.float32)
+            )
+            self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
+                output_channel * 2, 0
+            )
+            self.last_scale_shift_table = nn.Parameter(
+                torch.randn(2, output_channel) / output_channel**0.5
+            )
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        target_shape,
+        timesteps: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
         assert target_shape is not None, "target_shape must be provided"
+        batch_size = sample.shape[0]
 
         sample = self.conv_in(sample, causal=self.causal)
 
@@ -568,10 +594,46 @@ class Decoder(nn.Module):
 
         sample = sample.to(upscale_dtype)
 
+        if self.timestep_conditioning:
+            assert (
+                timesteps is not None
+            ), "should pass timesteps with timestep_conditioning=True"
+            scaled_timesteps = timesteps * self.timestep_scale_multiplier
+
         for up_block in self.up_blocks:
-            sample = checkpoint_fn(up_block)(sample, causal=self.causal)
+            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
+                sample = checkpoint_fn(up_block)(
+                    sample, causal=self.causal, timesteps=scaled_timesteps
+                )
+            else:
+                sample = checkpoint_fn(up_block)(sample, causal=self.causal)
 
         sample = self.conv_norm_out(sample)
+
+        if self.timestep_conditioning:
+            embedded_timesteps = self.last_time_embedder(
+                timestep=scaled_timesteps.flatten(),
+                resolution=None,
+                aspect_ratio=None,
+                batch_size=sample.shape[0],
+                hidden_dtype=sample.dtype,
+            )
+            embedded_timesteps = embedded_timesteps.view(
+                batch_size, embedded_timesteps.shape[-1], 1, 1, 1
+            )
+            ada_values = self.last_scale_shift_table[
+                None, ..., None, None, None
+            ] + embedded_timesteps.reshape(
+                batch_size,
+                2,
+                -1,
+                embedded_timesteps.shape[-3],
+                embedded_timesteps.shape[-2],
+                embedded_timesteps.shape[-1],
+            )
+            shift, scale = ada_values.unbind(dim=1)
+            sample = sample * (1 + scale) + shift
+
         sample = self.conv_act(sample)
         sample = self.conv_out(sample, causal=self.causal)
 
@@ -731,11 +793,18 @@ class UNetMidBlock3D(nn.Module):
         resnet_groups: int = 32,
         norm_layer: str = "group_norm",
         inject_noise: bool = False,
+        timestep_conditioning: bool = False,
     ):
         super().__init__()
         resnet_groups = (
             resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
         )
+        self.timestep_conditioning = timestep_conditioning
+
+        if timestep_conditioning:
+            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
+                in_channels * 4, 0
+            )
 
         self.res_blocks = nn.ModuleList(
             [
@@ -748,17 +817,38 @@ class UNetMidBlock3D(nn.Module):
                     dropout=dropout,
                     norm_layer=norm_layer,
                     inject_noise=inject_noise,
+                    timestep_conditioning=timestep_conditioning,
                 )
                 for _ in range(num_layers)
             ]
         )
 
     def forward(
-        self, hidden_states: torch.FloatTensor, causal: bool = True
+        self,
+        hidden_states: torch.FloatTensor,
+        causal: bool = True,
+        timesteps: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        timestep_embed = None
+        if self.timestep_conditioning:
+            assert (
+                timesteps is not None
+            ), "should pass timesteps with timestep_conditioning=True"
+            batch_size = hidden_states.shape[0]
+            timestep_embed = self.time_embedder(
+                timestep=timesteps.flatten(),
+                resolution=None,
+                aspect_ratio=None,
+                batch_size=batch_size,
+                hidden_dtype=hidden_states.dtype,
+            )
+            timestep_embed = timestep_embed.view(
+                batch_size, timestep_embed.shape[-1], 1, 1, 1
+            )
         for resnet in self.res_blocks:
-            hidden_states = resnet(hidden_states, causal=causal)
-
+            hidden_states = resnet(
+                hidden_states, causal=causal, timesteps=timestep_embed
+            )
         return hidden_states
 
 
@@ -846,6 +936,7 @@ class ResnetBlock3D(nn.Module):
         eps: float = 1e-6,
         norm_layer: str = "group_norm",
         inject_noise: bool = False,
+        timestep_conditioning: bool = False,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -915,6 +1006,13 @@ class ResnetBlock3D(nn.Module):
             else nn.Identity()
         )
 
+        self.timestep_conditioning = timestep_conditioning
+
+        if timestep_conditioning:
+            self.scale_shift_table = nn.Parameter(
+                torch.randn(4, in_channels) / in_channels**0.5
+            )
+
     def _feed_spatial_noise(
         self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
     ) -> torch.FloatTensor:
@@ -933,10 +1031,29 @@ class ResnetBlock3D(nn.Module):
         self,
        input_tensor: torch.FloatTensor,
         causal: bool = True,
+        timesteps: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         hidden_states = input_tensor
+        batch_size = hidden_states.shape[0]
 
         hidden_states = self.norm1(hidden_states)
+        if self.timestep_conditioning:
+            assert (
+                timesteps is not None
+            ), "should pass timesteps with timestep_conditioning=True"
+            ada_values = self.scale_shift_table[
+                None, ..., None, None, None
+            ] + timesteps.reshape(
+                batch_size,
+                4,
+                -1,
+                timesteps.shape[-3],
+                timesteps.shape[-2],
+                timesteps.shape[-1],
+            )
+            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
+
+            hidden_states = hidden_states * (1 + scale1) + shift1
 
         hidden_states = self.non_linearity(hidden_states)
 
@@ -949,6 +1066,9 @@ class ResnetBlock3D(nn.Module):
 
         hidden_states = self.norm2(hidden_states)
 
+        if self.timestep_conditioning:
+            hidden_states = hidden_states * (1 + scale2) + shift2
+
         hidden_states = self.non_linearity(hidden_states)
 
         hidden_states = self.dropout(hidden_states)
@@ -962,6 +1082,8 @@ class ResnetBlock3D(nn.Module):
 
         input_tensor = self.norm3(input_tensor)
 
+        batch_size = input_tensor.shape[0]
+
         input_tensor = self.conv_shortcut(input_tensor)
 
         output_tensor = input_tensor + hidden_states
@@ -1013,35 +1135,42 @@ def unpatchify(x, patch_size_hw, patch_size_t=1):
 def create_video_autoencoder_config(
     latent_channels: int = 64,
 ):
-    config = {
+    encoder_blocks = [
+        ("res_x", {"num_layers": 4}),
+        ("compress_all_x_y", {"multiplier": 3}),
+        ("res_x", {"num_layers": 4}),
+        ("compress_all_x_y", {"multiplier": 2}),
+        ("res_x", {"num_layers": 4}),
+        ("compress_all", {}),
+        ("res_x", {"num_layers": 3}),
+        ("res_x", {"num_layers": 4}),
+    ]
+    decoder_blocks = [
+        ("res_x", {"num_layers": 4}),
+        ("compress_all", {"residual": True}),
+        ("res_x_y", {"multiplier": 3}),
+        ("res_x", {"num_layers": 3}),
+        ("compress_all", {"residual": True}),
+        ("res_x_y", {"multiplier": 2}),
+        ("res_x", {"num_layers": 3}),
+        ("compress_all", {"residual": True}),
+        ("res_x", {"num_layers": 3}),
+        ("res_x", {"num_layers": 4}),
+    ]
+    return {
         "_class_name": "CausalVideoAutoencoder",
-        "dims": 3,  # (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
-        "in_channels": 3,  # Number of input color channels (e.g., RGB)
-        "out_channels": 3,  # Number of output color channels
-        "latent_channels": latent_channels,  # Number of channels in the latent space representation
-        "blocks": [
-            ("res_x", 4),
-            ("compress_space", 1),
-            ("res_x_y", 1),
-            ("res_x", 2),
-            ("compress_all", 1),
-            ("res_x", 3),
-            ("compress_all", 1),
-            ("res_x_y", 1),
-            ("res_x", 2),
-            ("compress_time", 1),
-            ("res_x", 3),
-            ("res_x", 3),
-        ],
+        "dims": 3,
+        "encoder_blocks": encoder_blocks,
+        "decoder_blocks": decoder_blocks,
+        "latent_channels": latent_channels,
+        "norm_layer": "pixel_norm",
         "patch_size": 4,
         "latent_log_var": "uniform",
         "use_quant_conv": False,
-        "norm_layer": "layer_norm",
-        "causal_decoder": True,
+        "causal_decoder": False,
+        "timestep_conditioning": True,
     }
 
-    return config
-
 
 def test_vae_patchify_unpatchify():
     import torch
@@ -1075,8 +1204,9 @@ def demo_video_autoencoder_forward_backward():
     print(f"input shape={input_videos.shape}")
     print(f"latent shape={latent.shape}")
 
+    timesteps = torch.ones(input_videos.shape[0]) * 0.1
     reconstructed_videos = video_autoencoder.decode(
-        latent, target_shape=input_videos.shape
+        latent, target_shape=input_videos.shape, timesteps=timesteps
     ).sample
 
     print(f"reconstructed shape={reconstructed_videos.shape}")
@@ -1084,16 +1214,16 @@ def demo_video_autoencoder_forward_backward():
     # Validate that single image gets treated the same way as first frame
     input_image = input_videos[:, :, :1, :, :]
     image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
-    reconstructed_image = video_autoencoder.decode(
-        image_latent, target_shape=image_latent.shape
+    _ = video_autoencoder.decode(
+        image_latent, target_shape=image_latent.shape, timesteps=timesteps
     ).sample
 
-    first_frame_latent = latent[:, :, :1, :, :]
+    # first_frame_latent = latent[:, :, :1, :, :]
 
     # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
     # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
-    assert (image_latent == first_frame_latent).all()
-    assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()
+    # assert (image_latent == first_frame_latent).all()
+    # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()
 
     # Calculate the loss (e.g., mean squared error)
     loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
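The recurring pattern above is a PixArt-alpha-style adaptive scale/shift: a learned (k, C) table is added to the reshaped timestep embedding, then split into shift/scale pairs that modulate the normalized hidden states (k = 4 in `ResnetBlock3D`, one pair per norm; k = 2 for the decoder output). A minimal self-contained sketch of the k = 2 case, with illustrative names and shapes that are not part of the diff:

import torch

def ada_scale_shift(x, temb, table):
    # x:     (B, C, F, H, W) hidden states, already normalized
    # temb:  (B, 2 * C)      flattened timestep embedding
    # table: (2, C)          learned base shift/scale parameters
    b, c = x.shape[0], x.shape[1]
    # (1, 2, C, 1, 1, 1) + (B, 2, C, 1, 1, 1), broadcast over F, H, W
    ada = table[None, ..., None, None, None] + temb.reshape(b, 2, c, 1, 1, 1)
    shift, scale = ada.unbind(dim=1)  # each (B, C, 1, 1, 1)
    return x * (1 + scale) + shift

x = torch.randn(2, 128, 4, 16, 16)
temb = torch.randn(2, 2 * 128)
table = torch.randn(2, 128) / 128**0.5
assert ada_scale_shift(x, temb, table).shape == x.shape

The `1 + scale` form keeps the modulation near identity when the table entries and embeddings are small, which matches the `randn(...) / C**0.5` initialization of the tables in the diff.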
xora/models/autoencoders/vae.py CHANGED
@@ -251,14 +251,21 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
         return moments
 
     def _decode(
-        self, z: torch.FloatTensor, target_shape=None
+        self,
+        z: torch.FloatTensor,
+        target_shape=None,
+        timesteps: Optional[torch.Tensor] = None,
     ) -> Union[DecoderOutput, torch.FloatTensor]:
         z = self.post_quant_conv(z)
-        dec = self.decoder(z, target_shape=target_shape)
+        dec = self.decoder(z, target_shape=target_shape, timesteps=timesteps)
         return dec
 
     def decode(
-        self, z: torch.FloatTensor, return_dict: bool = True, target_shape=None
+        self,
+        z: torch.FloatTensor,
+        return_dict: bool = True,
+        target_shape=None,
+        timesteps: Optional[torch.Tensor] = None,
     ) -> Union[DecoderOutput, torch.FloatTensor]:
         assert target_shape is not None, "target_shape must be provided for decoding"
         if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
@@ -291,7 +298,7 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
         decoded = (
             self._hw_tiled_decode(z, target_shape)
             if self.use_hw_tiling
-            else self._decode(z, target_shape=target_shape)
+            else self._decode(z, target_shape=target_shape, timesteps=timesteps)
         )
 
         if not return_dict:
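End to end, decoding with the new conditioning mirrors the updated demo: build a model whose config sets `timestep_conditioning: True`, then pass one timestep per batch element to `decode`. A hedged usage sketch (construction via `from_config` and the input shape are assumptions, not shown in this diff):

import torch
from xora.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
    create_video_autoencoder_config,
)

config = create_video_autoencoder_config(latent_channels=64)
video_autoencoder = CausalVideoAutoencoder.from_config(config)

input_videos = torch.randn(2, 3, 17, 64, 64)  # (B, C, F, H, W), shape illustrative
latent = video_autoencoder.encode(input_videos).latent_dist.mode()

# One noise level per batch element; Decoder.forward multiplies these by its
# learned timestep_scale_multiplier (initialized to 1000.0) before embedding.
timesteps = torch.ones(input_videos.shape[0]) * 0.1
reconstructed_videos = video_autoencoder.decode(
    latent, target_shape=input_videos.shape, timesteps=timesteps
).sample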