from typing import Optional

from tha3.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \
    create_downsample_block_from_block_args, create_conv3
from tha3.nn.resnet_block import ResnetBlock
from tha3.nn.resnet_block_seperable import ResnetBlockSeparable
from tha3.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \
    create_separable_downsample_block, create_separable_conv3
from tha3.nn.util import BlockArgs


class ConvBlockFactory:
    """Build convolutional blocks from one shared configuration.

    Every ``create_*`` method dispatches to either the separable-convolution
    helpers (``tha3.nn.separable_conv`` / ``ResnetBlockSeparable``) or the
    standard ones (``tha3.nn.conv`` / ``ResnetBlock``), depending on
    ``use_separable_convolution``. Each method forwards ``block_args`` so
    that both variants are built with the same settings.
    """

    def __init__(self,
                 block_args: BlockArgs,
                 use_separable_convolution: bool = False):
        """
        Args:
            block_args: Shared block configuration, forwarded to every
                helper this factory calls.
            use_separable_convolution: When True, build the separable
                variants; otherwise build the standard ones.
        """
        self.use_separable_convolution = use_separable_convolution
        self.block_args = block_args

    def create_conv3(self,
                     in_channels: int,
                     out_channels: int,
                     bias: bool,
                     initialization_method: Optional[str] = None):
        """Create a single 3x3 convolution (not a full block).

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            bias: Whether the convolution has a bias term.
            initialization_method: Weight-initialization method name. If
                None, ``self.block_args.initialization_method`` is used.

        Returns:
            Whatever ``create_separable_conv3`` / ``create_conv3`` returns.
        """
        if initialization_method is None:
            initialization_method = self.block_args.initialization_method
        if self.use_separable_convolution:
            return create_separable_conv3(
                in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)
        return create_conv3(
            in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)

    def create_conv7_block(self, in_channels: int, out_channels: int):
        """Create a 7x7 convolution block configured by ``self.block_args``."""
        if self.use_separable_convolution:
            return create_separable_conv7_block(in_channels, out_channels, self.block_args)
        return create_conv7_block_from_block_args(in_channels, out_channels, self.block_args)

    def create_conv3_block(self, in_channels: int, out_channels: int):
        """Create a 3x3 convolution block configured by ``self.block_args``."""
        if self.use_separable_convolution:
            return create_separable_conv3_block(in_channels, out_channels, self.block_args)
        return create_conv3_block_from_block_args(in_channels, out_channels, self.block_args)

    def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool):
        """Create a downsampling block configured by ``self.block_args``.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            is_output_1x1: Flag passed through unchanged to the underlying
                helper (its meaning is defined there).
        """
        if self.use_separable_convolution:
            return create_separable_downsample_block(in_channels, out_channels, is_output_1x1, self.block_args)
        # The configured block_args must be forwarded here as well. Previously
        # they were omitted, so this branch ignored the factory configuration
        # while the separable branch and every other method honored it.
        # NOTE(review): checkpoints trained with the old behavior presumably
        # used the callee's default BlockArgs for these layers; confirm they
        # still load when self.block_args differs from those defaults.
        return create_downsample_block_from_block_args(in_channels, out_channels, is_output_1x1, self.block_args)

    def create_resnet_block(self, num_channels: int, is_1x1: bool):
        """Create a residual block with ``num_channels`` input and output channels.

        Args:
            num_channels: Channel count passed to the residual block.
            is_1x1: Flag passed through unchanged to ``*.create``.
        """
        block_class = ResnetBlockSeparable if self.use_separable_convolution else ResnetBlock
        return block_class.create(num_channels, is_1x1, block_args=self.block_args)