from typing import Optional

from torch.nn import Sigmoid, Sequential, Tanh

from tha3.nn.conv import create_conv3, create_conv3_from_block_args
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs


class PoserArgs00:
    """Size/configuration holder for a poser network, plus output-head factories.

    Stores the image size, channel counts, number of pose parameters and the
    ``BlockArgs`` used to build convolution blocks. The ``create_*_block``
    methods build small output heads that consume a feature map with
    ``start_channels`` channels.
    """

    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 start_channels: int,
                 num_pose_params: int,
                 block_args: Optional[BlockArgs] = None):
        """Record the network sizes and the block-construction arguments.

        Args:
            image_size: Image size used by the network.
            input_image_channels: Number of channels of the input image.
            output_image_channels: Number of channels of the output image;
                used as the output width of the all-channel alpha head and
                the color-change head.
            start_channels: Number of feature channels fed into every output
                head built by this object.
            num_pose_params: Number of pose parameters.
            block_args: Arguments for building conv blocks. When ``None``, a
                ``BlockArgs`` with ``InstanceNorm2dFactory()`` normalization
                and ``ReLUFactory(inplace=True)`` nonlinearity is created.
        """
        self.num_pose_params = num_pose_params
        self.start_channels = start_channels
        self.output_image_channels = output_image_channels
        self.input_image_channels = input_image_channels
        self.image_size = image_size
        if block_args is None:
            # Default: instance normalization with an in-place ReLU.
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        self.block_args = block_args

    def _create_conv3_sigmoid_block(self, out_channels: int) -> Sequential:
        """Build ``create_conv3`` (biased, no spectral norm) followed by ``Sigmoid``.

        The conv maps ``start_channels`` to ``out_channels`` and uses the
        initialization method from ``self.block_args``. The Sigmoid squashes
        the outputs into (0, 1).
        """
        return Sequential(
            create_conv3(
                in_channels=self.start_channels,
                out_channels=out_channels,
                bias=True,
                initialization_method=self.block_args.initialization_method,
                use_spectral_norm=False),
            Sigmoid())

    def create_alpha_block(self) -> Sequential:
        """Build a single-channel conv + Sigmoid head (values in (0, 1))."""
        return self._create_conv3_sigmoid_block(out_channels=1)

    def create_all_channel_alpha_block(self) -> Sequential:
        """Build a conv + Sigmoid head with ``output_image_channels`` outputs."""
        return self._create_conv3_sigmoid_block(
            out_channels=self.output_image_channels)

    def create_color_change_block(self) -> Sequential:
        """Build a conv + Tanh head with ``output_image_channels`` outputs.

        The conv is built with ``create_conv3_from_block_args`` using
        ``self.block_args``. The Tanh bounds the outputs to (-1, 1).
        """
        return Sequential(
            create_conv3_from_block_args(
                in_channels=self.start_channels,
                out_channels=self.output_image_channels,
                bias=True,
                block_args=self.block_args),
            Tanh())

    def create_grid_change_block(self):
        """Build a 2-channel, bias-free, zero-initialized conv head.

        NOTE(review): the 2 output channels are presumably per-pixel (x, y)
        sampling-grid offsets. Zero initialization would then start the head
        at "no change"; confirm this against the caller.
        """
        return create_conv3(
            in_channels=self.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)