# Spaces: Runtime error
# (Appears to be a hosting-page status banner captured with this file; it is not part of the module.)
from typing import Optional | |
from tha3.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \ | |
create_downsample_block_from_block_args, create_conv3 | |
from tha3.nn.resnet_block import ResnetBlock | |
from tha3.nn.resnet_block_seperable import ResnetBlockSeparable | |
from tha3.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \ | |
create_separable_downsample_block, create_separable_conv3 | |
from tha3.nn.util import BlockArgs | |
class ConvBlockFactory:
    """Dispatch layer/block construction to separable or regular builders.

    Each ``create_*`` method checks ``use_separable_convolution`` and calls
    either the separable builder or its regular counterpart, returning
    whatever that builder returns.
    """

    def __init__(self,
                 block_args: BlockArgs,
                 use_separable_convolution: bool = False):
        """Store the shared block arguments and the separable/regular switch.

        Args:
            block_args: Settings handed to the builders; this class reads its
                ``initialization_method`` and ``use_spectral_norm`` attributes.
            use_separable_convolution: When truthy, the separable builders
                are used; otherwise the regular ones.
        """
        self.block_args = block_args
        self.use_separable_convolution = use_separable_convolution

    def create_conv3(self,
                     in_channels: int,
                     out_channels: int,
                     bias: bool,
                     initialization_method: Optional[str] = None):
        """Build a conv3 layer via ``create_separable_conv3`` or ``create_conv3``.

        Args:
            in_channels: Forwarded to the builder as its first argument.
            out_channels: Forwarded to the builder as its second argument.
            bias: Forwarded to the builder as its third argument.
            initialization_method: Forwarded to the builder; when ``None``,
                ``block_args.initialization_method`` is used instead.

        Returns:
            The result of the selected builder.
        """
        # An explicit argument wins; only None falls back to the shared setting.
        init_method = (self.block_args.initialization_method
                       if initialization_method is None
                       else initialization_method)
        build = create_separable_conv3 if self.use_separable_convolution else create_conv3
        return build(in_channels, out_channels, bias, init_method, self.block_args.use_spectral_norm)

    def create_conv7_block(self, in_channels: int, out_channels: int):
        """Build a conv7 block, separable or regular, passing ``block_args`` along."""
        build = (create_separable_conv7_block
                 if self.use_separable_convolution
                 else create_conv7_block_from_block_args)
        return build(in_channels, out_channels, self.block_args)

    def create_conv3_block(self, in_channels: int, out_channels: int):
        """Build a conv3 block, separable or regular, passing ``block_args`` along."""
        build = (create_separable_conv3_block
                 if self.use_separable_convolution
                 else create_conv3_block_from_block_args)
        return build(in_channels, out_channels, self.block_args)

    def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool):
        """Build a downsample block, separable or regular.

        Args:
            in_channels: Forwarded to the builder as its first argument.
            out_channels: Forwarded to the builder as its second argument.
            is_output_1x1: Forwarded to the builder as its third argument.

        Returns:
            The result of the selected builder.
        """
        if self.use_separable_convolution:
            return create_separable_downsample_block(
                in_channels, out_channels, is_output_1x1, self.block_args)
        # NOTE(review): unlike every other method here, this branch does not
        # forward self.block_args, so the regular builder presumably uses its
        # own defaults. This may be an oversight, but forwarding the args would
        # change the blocks produced -- confirm against the builder's signature
        # and any existing trained weights before altering it.
        return create_downsample_block_from_block_args(in_channels, out_channels, is_output_1x1)

    def create_resnet_block(self, num_channels: int, is_1x1: bool):
        """Build a resnet block via ``ResnetBlockSeparable.create`` or ``ResnetBlock.create``."""
        block_cls = ResnetBlockSeparable if self.use_separable_convolution else ResnetBlock
        return block_cls.create(num_channels, is_1x1, block_args=self.block_args)