TomatoCocotree
上传
6a62ffb
raw
history blame
2.44 kB
from typing import Optional
from torch.nn import Sigmoid, Sequential, Tanh
from tha3.nn.conv import create_conv3, create_conv3_from_block_args
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs
class PoserArgs00:
    """Configuration for a poser network plus factories for its output heads.

    Every ``create_*`` method builds a fresh module whose input has
    ``start_channels`` channels. ``image_size``, ``input_image_channels`` and
    ``num_pose_params`` are only stored here; no method of this class reads
    them (presumably they are consumed by the networks built from these args).
    """

    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 start_channels: int,
                 num_pose_params: int,
                 block_args: Optional[BlockArgs] = None) -> None:
        """Store the configuration.

        Args:
            image_size: Stored as-is; not read by this class.
            input_image_channels: Stored as-is; not read by this class.
            output_image_channels: Number of output channels of the
                all-channel alpha head and of the color-change head.
            start_channels: Number of input channels of every head built here.
            num_pose_params: Stored as-is; not read by this class.
            block_args: Conv-block settings. When ``None``, a ``BlockArgs``
                with an ``InstanceNorm2dFactory()`` normalization layer and a
                ``ReLUFactory(inplace=True)`` nonlinearity is created. The
                alpha heads read only its ``initialization_method``; the
                color-change head receives the whole object.
        """
        self.image_size = image_size
        self.input_image_channels = input_image_channels
        self.output_image_channels = output_image_channels
        self.start_channels = start_channels
        self.num_pose_params = num_pose_params
        if block_args is None:
            # Built per call rather than as a default argument, so separate
            # instances never share the same factory objects.
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        self.block_args = block_args

    def _create_sigmoid_conv3_block(self, out_channels: int) -> Sequential:
        """Return ``create_conv3(start_channels -> out_channels)`` followed by ``Sigmoid``.

        Shared by the two alpha heads. The conv has a bias and uses
        ``block_args.initialization_method``. Spectral norm is always disabled,
        regardless of ``block_args``. The sigmoid squashes outputs into (0, 1).
        """
        return Sequential(
            create_conv3(
                in_channels=self.start_channels,
                out_channels=out_channels,
                bias=True,
                initialization_method=self.block_args.initialization_method,
                use_spectral_norm=False),
            Sigmoid())

    def create_alpha_block(self) -> Sequential:
        """Build a single-output-channel head whose values lie in (0, 1).

        Presumably used as a blending/alpha mask (inferred from the name).
        """
        return self._create_sigmoid_conv3_block(out_channels=1)

    def create_all_channel_alpha_block(self) -> Sequential:
        """Build a head with ``output_image_channels`` outputs, each in (0, 1).

        Same construction as ``create_alpha_block``, but with one sigmoid
        value per output channel instead of a single shared one.
        """
        return self._create_sigmoid_conv3_block(
            out_channels=self.output_image_channels)

    def create_color_change_block(self) -> Sequential:
        """Build a head with ``output_image_channels`` outputs in (-1, 1).

        Unlike the alpha heads, the conv is built by
        ``create_conv3_from_block_args``, so it receives the whole
        ``block_args`` (not just its initialization method). The trailing
        ``Tanh`` bounds the outputs to (-1, 1).
        """
        return Sequential(
            create_conv3_from_block_args(
                in_channels=self.start_channels,
                out_channels=self.output_image_channels,
                bias=True,
                block_args=self.block_args),
            Tanh())

    def create_grid_change_block(self):
        """Build a bias-free conv with 2 output channels and no activation.

        NOTE(review): ``initialization_method='zero'`` presumably zero-initializes
        the weights so the head initially predicts no change. With two channels
        this looks like a 2-D (x, y) offset field. Confirm both against
        ``create_conv3`` and the caller.
        """
        return create_conv3(
            in_channels=self.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)