Spaces:
Runtime error
Runtime error
File size: 5,164 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import math
from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import ModuleList, Module
from tha3.nn.common.poser_args import PoserArgs00
from tha3.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \
create_upsample_block_from_block_args
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.resnet_block import ResnetBlock
from tha3.nn.util import BlockArgs
class PoserEncoderDecoder00Args(PoserArgs00):
    """Hyper-parameters for ``PoserEncoderDecoder00``.

    On top of what ``PoserArgs00`` stores, this records the spatial size of the
    bottleneck, how many blocks run at the bottleneck, and an upper bound on the
    number of feature channels. When ``block_args`` is None a default
    ``BlockArgs`` (instance normalization + in-place ReLU) is used.
    """

    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 num_pose_params: int,
                 start_channels: int,
                 bottleneck_image_size,
                 num_bottleneck_blocks,
                 max_channels: int,
                 block_args: Optional[BlockArgs] = None):
        # NOTE(review): this signature lists num_pose_params before start_channels,
        # but the base class receives start_channels first. Presumably that matches
        # PoserArgs00's positional order; confirm against its definition.
        super().__init__(
            image_size,
            input_image_channels,
            output_image_channels,
            start_channels,
            num_pose_params,
            block_args)

        self.bottleneck_image_size = bottleneck_image_size
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.max_channels = max_channels
        # A 1x1 bottleneck is not supported by this encoder-decoder.
        assert bottleneck_image_size > 1

        # Fall back to instance norm + in-place ReLU when no block config is given.
        self.block_args = block_args if block_args is not None else BlockArgs(
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=True))
class PoserEncoderDecoder00(Module):
    """Encoder-decoder that can inject a pose vector at its bottleneck.

    Structure (all blocks come from the factory helpers configured by
    ``args.block_args``):

    * encoder: a 3x3 conv block at full resolution, then downsample blocks.
      The bookkeeping assumes each downsample block halves the spatial size,
      until ``args.bottleneck_image_size`` is reached;
    * bottleneck: a 3x3 conv block that also consumes the pose channels,
      followed by ``args.num_bottleneck_blocks - 1`` residual blocks;
    * decoder: upsample blocks (assumed to double the spatial size) until
      ``args.image_size`` is reached again.

    The channel count at a given resolution is
    ``min(start_channels * (image_size // size), max_channels)``.
    """

    def __init__(self, args: "PoserEncoderDecoder00Args"):
        super().__init__()
        self.args = args
        # Number of encoder blocks: the initial conv plus one per halving step.
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1

        # NOTE: submodules are created in a fixed order (encoder, bottleneck,
        # decoder); changing it would alter parameter initialization and
        # state_dict layout.
        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(
            create_conv3_block_from_block_args(
                args.input_image_channels,
                args.start_channels,
                args.block_args))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(create_downsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        # Fails when image_size does not reduce to bottleneck_image_size by the
        # same number of halvings that num_levels predicts.
        assert len(self.downsample_blocks) == self.num_levels, (
            f"image_size ({args.image_size}) should be bottleneck_image_size "
            f"({args.bottleneck_image_size}) times a power of two")

        self.bottleneck_blocks = ModuleList()
        # The first bottleneck block also takes the broadcast pose channels.
        self.bottleneck_blocks.append(create_conv3_block_from_block_args(
            in_channels=current_num_channels + args.num_pose_params,
            out_channels=current_num_channels,
            block_args=args.block_args))
        for _ in range(1, args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                ResnetBlock.create(
                    num_channels=current_num_channels,
                    is1x1=False,
                    block_args=args.block_args))

        self.upsample_blocks = ModuleList()
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(create_upsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels

    def get_num_output_channels_from_level(self, level: int) -> int:
        """Return the channel count at encoder level ``level``.

        Level ``k`` corresponds to a spatial size of ``image_size // 2**k``.
        """
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int) -> int:
        """Return the channel count used at spatial size ``image_size``.

        Channels grow in proportion to the downscaling factor relative to
        ``args.image_size`` and are capped at ``args.max_channels``.
        """
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
        """Run the encoder, bottleneck and decoder, returning every intermediate feature.

        Args:
            image: input batch passed to the first encoder block.
            pose: required iff ``args.num_pose_params != 0``. It must be a 2-D
                tensor of shape ``(batch, args.num_pose_params)``. It is
                broadcast over the bottleneck's spatial grid and concatenated
                to the features along dim 1.

        Returns:
            All features in reverse order of computation: the last decoder
            output first, then earlier decoder outputs, the bottleneck outputs,
            the encoder outputs, and finally ``image`` itself.

        Raises:
            AssertionError: if ``pose`` presence does not match ``num_pose_params``.
            ValueError: if ``pose`` is not 2-D or its width is not ``num_pose_params``.
        """
        if self.args.num_pose_params != 0:
            assert pose is not None, "pose is required when num_pose_params != 0"
        else:
            assert pose is None, "pose must be None when num_pose_params == 0"

        outputs = [image]
        feature = image
        for block in self.downsample_blocks:
            feature = block(feature)
            outputs.append(feature)

        if pose is not None:
            # Validate up front: a wrong width would otherwise surface as an
            # opaque channel mismatch inside the first bottleneck block.
            if pose.dim() != 2 or pose.shape[1] != self.args.num_pose_params:
                raise ValueError(
                    f"expected pose of shape (batch, {self.args.num_pose_params}), "
                    f"got {tuple(pose.shape)}")
            # Tile to the bottleneck's actual spatial size, not the configured
            # one, so inputs at other resolutions also work. expand() is a view;
            # torch.cat performs the only copy.
            height, width = feature.shape[-2:]
            pose_map = pose[:, :, None, None].expand(-1, -1, height, width)
            feature = torch.cat([feature, pose_map], dim=1)

        for block in self.bottleneck_blocks:
            feature = block(feature)
            outputs.append(feature)
        for block in self.upsample_blocks:
            feature = block(feature)
            outputs.append(feature)

        outputs.reverse()
        return outputs
|