import math
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Module, ModuleList

from tha3.nn.common.poser_args import PoserArgs00
from tha3.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \
    create_upsample_block_from_block_args
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.resnet_block import ResnetBlock
from tha3.nn.util import BlockArgs


class PoserEncoderDecoder00Args(PoserArgs00):
    """Hyperparameters for PoserEncoderDecoder00."""

    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 num_pose_params: int,
                 start_channels: int,
                 bottleneck_image_size: int,
                 num_bottleneck_blocks: int,
                 max_channels: int,
                 block_args: Optional[BlockArgs] = None):
        super().__init__(
            image_size, input_image_channels, output_image_channels, start_channels, num_pose_params, block_args)
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        assert bottleneck_image_size > 1
        if block_args is None:
            # Default to instance normalization with an in-place ReLU when no block_args is given.
            self.block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        else:
            self.block_args = block_args
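

# A worked example (assumed values, for illustration only; not the project's actual
# configuration): with image_size=256, bottleneck_image_size=32, start_channels=32, and
# max_channels=512, the module below builds num_levels = log2(256 // 32) + 1 = 4 encoder
# blocks, and the channel counts at resolutions 256, 128, 64, 32 are
# min(32 * (256 // s), 512) = 32, 64, 128, 256.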

class PoserEncoderDecoder00(Module):
    """An encoder-decoder that downsamples the input image to a bottleneck resolution,
    injects the pose parameters there, upsamples back to the original resolution, and
    returns the feature maps computed at every stage."""

    def __init__(self, args: PoserEncoderDecoder00Args):
        super().__init__()
        self.args = args
        # Number of resolution levels from image_size down to bottleneck_image_size (inclusive).
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1

        # Encoder: an initial conv block followed by downsampling blocks that halve the
        # spatial size until the bottleneck resolution is reached.
        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(
            create_conv3_block_from_block_args(
                args.input_image_channels,
                args.start_channels,
                args.block_args))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(create_downsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        assert len(self.downsample_blocks) == self.num_levels

        # Bottleneck: the first block also consumes the pose parameters, which are
        # concatenated to the feature map along the channel dimension in forward().
        self.bottleneck_blocks = ModuleList()
        self.bottleneck_blocks.append(create_conv3_block_from_block_args(
            in_channels=current_num_channels + args.num_pose_params,
            out_channels=current_num_channels,
            block_args=args.block_args))
        for i in range(1, args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                ResnetBlock.create(
                    num_channels=current_num_channels,
                    is1x1=False,
                    block_args=args.block_args))

        # Decoder: upsampling blocks that double the spatial size back up to image_size.
        self.upsample_blocks = ModuleList()
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(create_upsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels

    def get_num_output_channels_from_level(self, level: int):
        # Level 0 is the full image resolution; each subsequent level halves the resolution.
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int):
        # The channel count doubles each time the resolution is halved, capped at max_channels.
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
        # A pose vector is required exactly when the network was configured with pose parameters.
        if self.args.num_pose_params != 0:
            assert pose is not None
        else:
            assert pose is None
        outputs = []
        feature = image
        outputs.append(feature)
        # Encoder: record the feature map at every resolution level.
        for block in self.downsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        # Tile the pose vector over the bottleneck resolution and concatenate it channel-wise.
        if pose is not None:
            n, c = pose.shape
            pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
            feature = torch.cat([feature, pose], dim=1)
        for block in self.bottleneck_blocks:
            feature = block(feature)
        outputs.append(feature)
        # Decoder: keep appending feature maps as the resolution is restored.
        for block in self.upsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        # Reverse so the most recently computed feature maps come first;
        # outputs[0] is the full-resolution decoder output, and the input image comes last.
        outputs.reverse()
        return outputs
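

# A minimal usage sketch (not part of the original module). The hyperparameter values
# below are illustrative assumptions, not the configuration used by the THA3 models.
if __name__ == "__main__":
    args = PoserEncoderDecoder00Args(
        image_size=256,
        input_image_channels=4,
        output_image_channels=4,
        num_pose_params=6,
        start_channels=32,
        bottleneck_image_size=32,
        num_bottleneck_blocks=3,
        max_channels=512,
        block_args=None)
    module = PoserEncoderDecoder00(args)
    image = torch.zeros(1, args.input_image_channels, args.image_size, args.image_size)
    pose = torch.zeros(1, args.num_pose_params)
    # forward() returns the feature maps from every stage; outputs[0] is the
    # full-resolution decoder output.
    outputs = module(image, pose)
    for feature in outputs:
        print(feature.shape)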