causal_video_autoencoder: add option to half channels in depth to space upsample block
Browse files
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -455,6 +455,8 @@ class Decoder(nn.Module):
|
|
455 |
block_params = block_params if isinstance(block_params, dict) else {}
|
456 |
if block_name == "res_x_y":
|
457 |
output_channel = output_channel * block_params.get("multiplier", 2)
|
|
|
|
|
458 |
|
459 |
self.conv_in = make_conv_nd(
|
460 |
dims,
|
@@ -501,11 +503,13 @@ class Decoder(nn.Module):
|
|
501 |
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
502 |
)
|
503 |
elif block_name == "compress_all":
|
|
|
504 |
block = DepthToSpaceUpsample(
|
505 |
dims=dims,
|
506 |
in_channels=input_channel,
|
507 |
stride=(2, 2, 2),
|
508 |
residual=block_params.get("residual", False),
|
|
|
509 |
)
|
510 |
else:
|
511 |
raise ValueError(f"unknown layer: {block_name}")
|
@@ -614,10 +618,14 @@ class UNetMidBlock3D(nn.Module):
|
|
614 |
|
615 |
|
616 |
class DepthToSpaceUpsample(nn.Module):
|
617 |
-
def __init__(
|
|
|
|
|
618 |
super().__init__()
|
619 |
self.stride = stride
|
620 |
-
self.out_channels =
|
|
|
|
|
621 |
self.conv = make_conv_nd(
|
622 |
dims=dims,
|
623 |
in_channels=in_channels,
|
@@ -627,6 +635,7 @@ class DepthToSpaceUpsample(nn.Module):
|
|
627 |
causal=True,
|
628 |
)
|
629 |
self.residual = residual
|
|
|
630 |
|
631 |
def forward(self, x, causal: bool = True):
|
632 |
if self.residual:
|
@@ -638,7 +647,8 @@ class DepthToSpaceUpsample(nn.Module):
|
|
638 |
p2=self.stride[1],
|
639 |
p3=self.stride[2],
|
640 |
)
|
641 |
-
|
|
|
642 |
if self.stride[0] == 2:
|
643 |
x_in = x_in[:, :, 1:, :, :]
|
644 |
x = self.conv(x, causal=causal)
|
|
|
455 |
block_params = block_params if isinstance(block_params, dict) else {}
|
456 |
if block_name == "res_x_y":
|
457 |
output_channel = output_channel * block_params.get("multiplier", 2)
|
458 |
+
if block_name == "compress_all":
|
459 |
+
output_channel = output_channel * block_params.get("multiplier", 1)
|
460 |
|
461 |
self.conv_in = make_conv_nd(
|
462 |
dims,
|
|
|
503 |
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
504 |
)
|
505 |
elif block_name == "compress_all":
|
506 |
+
output_channel = output_channel // block_params.get("multiplier", 1)
|
507 |
block = DepthToSpaceUpsample(
|
508 |
dims=dims,
|
509 |
in_channels=input_channel,
|
510 |
stride=(2, 2, 2),
|
511 |
residual=block_params.get("residual", False),
|
512 |
+
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
513 |
)
|
514 |
else:
|
515 |
raise ValueError(f"unknown layer: {block_name}")
|
|
|
618 |
|
619 |
|
620 |
class DepthToSpaceUpsample(nn.Module):
|
621 |
+
def __init__(
|
622 |
+
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
623 |
+
):
|
624 |
super().__init__()
|
625 |
self.stride = stride
|
626 |
+
self.out_channels = (
|
627 |
+
np.prod(stride) * in_channels // out_channels_reduction_factor
|
628 |
+
)
|
629 |
self.conv = make_conv_nd(
|
630 |
dims=dims,
|
631 |
in_channels=in_channels,
|
|
|
635 |
causal=True,
|
636 |
)
|
637 |
self.residual = residual
|
638 |
+
self.out_channels_reduction_factor = out_channels_reduction_factor
|
639 |
|
640 |
def forward(self, x, causal: bool = True):
|
641 |
if self.residual:
|
|
|
647 |
p2=self.stride[1],
|
648 |
p3=self.stride[2],
|
649 |
)
|
650 |
+
num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
|
651 |
+
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
652 |
if self.stride[0] == 2:
|
653 |
x_in = x_in[:, :, 1:, :, :]
|
654 |
x = self.conv(x, causal=causal)
|