from typing import Tuple, Union

import torch
import torch.nn as nn


class CausalConv3d(nn.Module):
    """3D convolution whose temporal axis is padded by replicating edge frames.

    Height and width are zero-padded by ``kernel // 2`` on each side inside
    ``nn.Conv3d``. The temporal axis (dim 2 of the input) is padded in
    ``forward``:

    * ``causal=True``: ``dilation * (time_kernel - 1)`` copies of the first
      frame are prepended. With ``stride=1`` output frame ``t`` depends only
      on input frames ``<= t``, and the number of frames is preserved.
    * ``causal=False``: ``dilation * (time_kernel - 1) // 2`` copies of the
      first frame are prepended, and as many copies of the last frame are
      appended. The length is preserved when ``dilation * (time_kernel - 1)``
      is even.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: An int, used for time, height and width, or a 3-sequence
            ``(time, height, width)``.
        stride: Forwarded to ``nn.Conv3d``.
        dilation: Temporal dilation. Spatial dilation is always 1.
        groups: Forwarded to ``nn.Conv3d``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if isinstance(kernel_size, (tuple, list)):
            kernel_size = tuple(kernel_size)
            if len(kernel_size) != 3:
                raise ValueError(
                    f"kernel_size must be an int or a 3-tuple, got {kernel_size!r}"
                )
        else:
            kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        # Time is padded manually in forward(); only space is padded by the conv.
        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=(dilation, 1, 1),
            padding=(0, height_pad, width_pad),
            padding_mode="zeros",
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        """Pad the temporal axis by edge-frame replication, then convolve.

        Args:
            x: Input tensor. It is indexed as ``x[:, :, t, :, :]``, so it is
                assumed to be ``(batch, channels, time, height, width)``.
            causal: If True, pad only at the start of the time axis.
                Otherwise pad symmetrically.

        Returns:
            The output of ``self.conv`` applied to the padded input.
        """
        # A dilated kernel spans dilation * (k - 1) + 1 frames. That many
        # frames minus one must be padded for the time length to be preserved.
        time_dilation = self.conv.dilation[0]
        total_pad = time_dilation * (self.time_kernel_size - 1)

        if causal:
            front_pad, back_pad = total_pad, 0
        else:
            front_pad = back_pad = total_pad // 2

        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, front_pad, 1, 1))
        last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, back_pad, 1, 1))
        # Zero-length pads are empty tensors; torch.cat handles them as no-ops.
        x = torch.cat((first_frame_pad, x, last_frame_pad), dim=2)
        return self.conv(x)

    @property
    def weight(self):
        """The underlying ``nn.Conv3d`` weight parameter."""
        return self.conv.weight