from typing import Tuple, Union

import torch

from xora.models.autoencoders.dual_conv3d import DualConv3d
from xora.models.autoencoders.causal_conv3d import CausalConv3d


def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
):
    """Build a convolution layer for the requested dimensionality.

    Args:
        dims: ``2`` selects ``torch.nn.Conv2d``; ``3`` selects
            ``torch.nn.Conv3d`` (or ``CausalConv3d`` when ``causal`` is
            truthy); ``(2, 1)`` selects ``DualConv3d``.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size, forwarded unchanged.
        stride: Convolution stride, forwarded unchanged.
        padding: Convolution padding, forwarded unchanged.
        dilation: Convolution dilation. Not forwarded when ``dims == (2, 1)``.
        groups: Number of groups. Not forwarded when ``dims == (2, 1)``.
        bias: Whether the layer has a bias term.
        causal: Only consulted when ``dims == 3``; ignored otherwise.

    Returns:
        The constructed convolution layer.

    Raises:
        ValueError: If ``dims`` is not ``2``, ``3`` or ``(2, 1)``.
    """
    # Arguments accepted by every supported layer type.
    shared_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    # Arguments passed only to the Conv2d / Conv3d / CausalConv3d variants.
    full_kwargs = dict(shared_kwargs, dilation=dilation, groups=groups)

    if dims == 2:
        return torch.nn.Conv2d(**full_kwargs)
    elif dims == 3:
        if causal:
            return CausalConv3d(**full_kwargs)
        return torch.nn.Conv3d(**full_kwargs)
    elif dims == (2, 1):
        # NOTE(review): dilation, groups and causal are intentionally not
        # forwarded here, matching the existing behavior -- confirm whether
        # DualConv3d should receive them.
        return DualConv3d(**shared_kwargs)
    else:
        raise ValueError(f"unsupported dimensions: {dims}")


def make_linear_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    bias=True,
):
    """Build a kernel-size-1 convolution for the requested dimensionality.

    Args:
        dims: ``2`` selects ``torch.nn.Conv2d``; ``3`` or ``(2, 1)`` selects
            ``torch.nn.Conv3d``.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the layer has a bias term.

    Returns:
        A ``torch.nn.Conv2d`` or ``torch.nn.Conv3d`` with ``kernel_size=1``.

    Raises:
        ValueError: If ``dims`` is not ``2``, ``3`` or ``(2, 1)``.
    """
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
        )
    elif dims == 3 or dims == (2, 1):
        return torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")