Spaces:
Runtime error
Runtime error
File size: 2,443 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from typing import Optional
from torch.nn import Sigmoid, Sequential, Tanh
from tha3.nn.conv import create_conv3, create_conv3_from_block_args
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs
class PoserArgs00:
    """Configuration shared by poser networks, plus factories for small output heads.

    Stores image/channel/pose sizes and the ``BlockArgs`` used to build
    convolution blocks. Every ``create_*_block`` method builds a head whose
    input is a ``start_channels``-channel feature map.
    """

    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 start_channels: int,
                 num_pose_params: int,
                 block_args: Optional[BlockArgs] = None):
        """Record the network dimensions and block construction arguments.

        Args:
            image_size: Image size; stored for callers (not used by the
                factory methods of this class).
            input_image_channels: Number of input image channels; stored
                for callers (not used by the factory methods of this class).
            output_image_channels: Channel count produced by
                ``create_all_channel_alpha_block`` and
                ``create_color_change_block``.
            start_channels: Number of input feature channels for every head
                built by this class.
            num_pose_params: Number of pose parameters; stored for callers
                (not used by the factory methods of this class).
            block_args: Arguments used to build conv blocks. When ``None``,
                a default ``BlockArgs`` with an ``InstanceNorm2dFactory``
                normalization layer and an in-place ``ReLUFactory``
                nonlinearity is created.
        """
        self.num_pose_params = num_pose_params
        self.start_channels = start_channels
        self.output_image_channels = output_image_channels
        self.input_image_channels = input_image_channels
        self.image_size = image_size
        if block_args is None:
            # Built inside __init__ so each instance gets its own default
            # object instead of sharing one via the parameter default.
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        self.block_args = block_args

    def _create_sigmoid_conv3_block(self, out_channels: int) -> Sequential:
        """Return ``conv3(start_channels -> out_channels)`` followed by ``Sigmoid``.

        The convolution has a bias, uses ``block_args.initialization_method``
        and no spectral normalization. The sigmoid maps outputs into (0, 1).

        Args:
            out_channels: Number of output channels of the convolution.
        """
        return Sequential(
            create_conv3(
                in_channels=self.start_channels,
                out_channels=out_channels,
                bias=True,
                initialization_method=self.block_args.initialization_method,
                use_spectral_norm=False),
            Sigmoid())

    def create_alpha_block(self) -> Sequential:
        """Return a single-channel conv3 + Sigmoid head with outputs in (0, 1).

        NOTE(review): presumably used as an alpha/blending mask — confirm
        with callers.
        """
        return self._create_sigmoid_conv3_block(out_channels=1)

    def create_all_channel_alpha_block(self) -> Sequential:
        """Return an ``output_image_channels``-channel conv3 + Sigmoid head.

        Same construction as ``create_alpha_block``, but with one (0, 1)
        output per output image channel instead of a single channel.
        """
        return self._create_sigmoid_conv3_block(
            out_channels=self.output_image_channels)

    def create_color_change_block(self) -> Sequential:
        """Return an ``output_image_channels``-channel conv3 + Tanh head.

        Unlike the alpha heads, the convolution is built from the full
        ``block_args`` via ``create_conv3_from_block_args``. ``Tanh`` bounds
        the output to (-1, 1).
        """
        return Sequential(
            create_conv3_from_block_args(
                in_channels=self.start_channels,
                out_channels=self.output_image_channels,
                bias=True,
                block_args=self.block_args),
            Tanh())

    def create_grid_change_block(self):
        """Return a bare conv3 mapping ``start_channels`` to 2 channels.

        The convolution has no bias, no spectral normalization, and
        ``initialization_method='zero'``. No output activation is applied.

        NOTE(review): presumably the two channels are per-pixel (x, y)
        offsets for a sampling grid. If 'zero' means zero-initialized
        weights, the head (having no bias) initially outputs all zeros.
        Confirm against ``create_conv3`` and callers.
        """
        return create_conv3(
            in_channels=self.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)