Merge pull request #33 from LightricksResearch/compress-all-half-channels
Browse files
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -455,6 +455,8 @@ class Decoder(nn.Module):
|
|
455 |
block_params = block_params if isinstance(block_params, dict) else {}
|
456 |
if block_name == "res_x_y":
|
457 |
output_channel = output_channel * block_params.get("multiplier", 2)
|
|
|
|
|
458 |
|
459 |
self.conv_in = make_conv_nd(
|
460 |
dims,
|
@@ -503,11 +505,13 @@ class Decoder(nn.Module):
|
|
503 |
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
504 |
)
|
505 |
elif block_name == "compress_all":
|
|
|
506 |
block = DepthToSpaceUpsample(
|
507 |
dims=dims,
|
508 |
in_channels=input_channel,
|
509 |
stride=(2, 2, 2),
|
510 |
residual=block_params.get("residual", False),
|
|
|
511 |
)
|
512 |
else:
|
513 |
raise ValueError(f"unknown layer: {block_name}")
|
@@ -618,10 +622,14 @@ class UNetMidBlock3D(nn.Module):
|
|
618 |
|
619 |
|
620 |
class DepthToSpaceUpsample(nn.Module):
|
621 |
-
def __init__(self, dims, in_channels, stride, residual=False):
|
|
|
|
|
622 |
super().__init__()
|
623 |
self.stride = stride
|
624 |
-
self.out_channels = np.prod(stride) * in_channels
|
|
|
|
|
625 |
self.conv = make_conv_nd(
|
626 |
dims=dims,
|
627 |
in_channels=in_channels,
|
@@ -631,6 +639,7 @@ class DepthToSpaceUpsample(nn.Module):
|
|
631 |
causal=True,
|
632 |
)
|
633 |
self.residual = residual
|
|
|
634 |
|
635 |
def forward(self, x, causal: bool = True):
|
636 |
if self.residual:
|
@@ -642,7 +651,8 @@ class DepthToSpaceUpsample(nn.Module):
|
|
642 |
p2=self.stride[1],
|
643 |
p3=self.stride[2],
|
644 |
)
|
645 |
-
|
|
|
646 |
if self.stride[0] == 2:
|
647 |
x_in = x_in[:, :, 1:, :, :]
|
648 |
x = self.conv(x, causal=causal)
|
|
|
455 |
block_params = block_params if isinstance(block_params, dict) else {}
|
456 |
if block_name == "res_x_y":
|
457 |
output_channel = output_channel * block_params.get("multiplier", 2)
|
458 |
+
if block_name == "compress_all":
|
459 |
+
output_channel = output_channel * block_params.get("multiplier", 1)
|
460 |
|
461 |
self.conv_in = make_conv_nd(
|
462 |
dims,
|
|
|
505 |
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
506 |
)
|
507 |
elif block_name == "compress_all":
|
508 |
+
output_channel = output_channel // block_params.get("multiplier", 1)
|
509 |
block = DepthToSpaceUpsample(
|
510 |
dims=dims,
|
511 |
in_channels=input_channel,
|
512 |
stride=(2, 2, 2),
|
513 |
residual=block_params.get("residual", False),
|
514 |
+
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
515 |
)
|
516 |
else:
|
517 |
raise ValueError(f"unknown layer: {block_name}")
|
|
|
622 |
|
623 |
|
624 |
class DepthToSpaceUpsample(nn.Module):
|
625 |
+
def __init__(
|
626 |
+
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
627 |
+
):
|
628 |
super().__init__()
|
629 |
self.stride = stride
|
630 |
+
self.out_channels = (
|
631 |
+
np.prod(stride) * in_channels // out_channels_reduction_factor
|
632 |
+
)
|
633 |
self.conv = make_conv_nd(
|
634 |
dims=dims,
|
635 |
in_channels=in_channels,
|
|
|
639 |
causal=True,
|
640 |
)
|
641 |
self.residual = residual
|
642 |
+
self.out_channels_reduction_factor = out_channels_reduction_factor
|
643 |
|
644 |
def forward(self, x, causal: bool = True):
|
645 |
if self.residual:
|
|
|
651 |
p2=self.stride[1],
|
652 |
p3=self.stride[2],
|
653 |
)
|
654 |
+
num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
|
655 |
+
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
656 |
if self.stride[0] == 2:
|
657 |
x_in = x_in[:, :, 1:, :, :]
|
658 |
x = self.conv(x, causal=causal)
|