|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Conv2d Module with Valid Padding""" |
|
|
|
import torch.nn.functional as F |
|
from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional |
|
|
|
|
|
class Conv2dValid(_ConvNd):
    """2-D convolution whose zero padding is computed from the input at call time.

    By default no padding is applied on either spatial axis. Setting
    ``valid_trigx`` / ``valid_trigy`` switches the corresponding axis to a
    padding amount derived from the input length, stride and kernel size
    (see ``_same_size_padding``).

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size (int or pair).
        stride: Convolution stride (int or pair).
        padding: Forwarded to the ``_ConvNd`` base class only; the forward
            pass of this class does not read it.
        dilation: Kernel dilation (int or pair). It is passed to
            ``F.conv2d`` but is not part of the padding computation.
        groups: Number of blocked connections from input to output channels.
        bias: Whether to learn an additive bias.
        padding_mode: Forwarded to the ``_ConvNd`` base class only; the
            forward pass of this class always pads through ``F.conv2d``.
        device: Factory argument forwarded to the base class.
        dtype: Factory argument forwarded to the base class.
        valid_trigx: Enable computed padding on input dimension ``-2``.
        valid_trigy: Enable computed padding on input dimension ``-1``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        valid_trigx: bool = False,
        valid_trigy: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        # String paddings are handed to the base class untouched.
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            False,  # transposed
            _pair(0),  # output_padding
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )
        self.valid_trigx = valid_trigx
        self.valid_trigy = valid_trigy

    def _same_size_padding(self, input: Tensor, dim: int) -> int:
        """Return the computed padding for spatial dimension ``dim`` (-2 or -1).

        The value is ``(size * (stride - 1) - 1 + kernel) // 2``. With
        stride 1 this is ``(kernel - 1) // 2``, so for an odd kernel and
        dilation 1 the output keeps the input's length along that axis.
        For stride > 1 the amount grows with the input length.
        """
        return (
            input.size(dim) * (self.stride[dim] - 1) - 1 + self.kernel_size[dim]
        ) // 2

    def _conv_forward(
        self, input: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        """Convolve ``input`` with ``weight``/``bias`` using the computed padding.

        An axis whose trigger flag is off gets zero padding.
        """
        validx = self._same_size_padding(input, -2) if self.valid_trigx else 0
        validy = self._same_size_padding(input, -1) if self.valid_trigy else 0
        return F.conv2d(
            input,
            weight,
            bias,
            self.stride,
            (validx, validy),
            self.dilation,
            self.groups,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the convolution with this module's own weight and bias."""
        return self._conv_forward(input, self.weight, self.bias)
|
|