from typing import List

import torch
import torch.nn.functional as F

from .kernels import normalize_kernel2d
def _compute_padding(kernel_size: List[int]) -> List[int]:
    """Compute padding tuple."""
    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)
    computed = [k - 1 for k in kernel_size]

    # for even kernels we need to do asymmetric padding :(
    out_padding = 2 * len(kernel_size) * [0]

    for i in range(len(kernel_size)):
        computed_tmp = computed[-(i + 1)]

        pad_front = computed_tmp // 2
        pad_rear = computed_tmp - pad_front

        out_padding[2 * i + 0] = pad_front
        out_padding[2 * i + 1] = pad_rear

    return out_padding
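
# Illustrative note (not part of the original module): _compute_padding returns the
# padding list in the order expected by ``torch.nn.functional.pad``, i.e. last
# dimension first, and for even kernels the larger half goes on the trailing side:
#   _compute_padding([3, 3])  ->  [1, 1, 1, 1]   # (left, right, top, bottom)
#   _compute_padding([2, 4])  ->  [1, 2, 0, 1]   # width padded 1/2, height padded 0/1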


def filter2d(
    input: torch.Tensor,
    kernel: torch.Tensor,
    border_type: str = 'reflect',
    normalized: bool = False,
    padding: str = 'same',
) -> torch.Tensor:
    r"""Convolve a tensor with a 2d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input: the input tensor with shape of
          :math:`(B, C, H, W)`.
        kernel: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`.
        border_type: the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
          Two modes available: ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor of same size and number of channels
        as the input with shape :math:`(B, C, H, W)`.

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3, 3)
        >>> filter2d(input, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")

    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")

    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")

    if border_type not in ['constant', 'reflect', 'replicate', 'circular']:
        raise ValueError(
            f"Invalid border type, we expect 'constant', 'reflect', 'replicate', 'circular'. Got: {border_type}"
        )

    if not isinstance(padding, str):
        raise TypeError(f"Input padding is not string. Got {type(padding)}")

    if padding not in ['valid', 'same']:
        raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if (not len(kernel.shape) == 3) and not ((kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])):
        raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}")

    # prepare kernel
    b, c, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        tmp_kernel = normalize_kernel2d(tmp_kernel)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    # pad the input tensor
    if padding == 'same':
        padding_shape: List[int] = _compute_padding([height, width])
        input = F.pad(input, padding_shape, mode=border_type)

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    if padding == 'same':
        out = output.view(b, c, h, w)
    else:
        out = output.view(b, c, h - height + 1, w - width + 1)

    return out
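
# Usage sketch (illustrative only, not from the original source): a normalized box
# kernel turns filter2d into a mean filter, and 'valid' padding shrinks the output.
# The names img and box below are hypothetical example values.
#   >>> img = torch.rand(2, 3, 32, 32)
#   >>> box = torch.ones(1, 5, 5)
#   >>> filter2d(img, box, normalized=True).shape     # 5x5 mean filter, 'same' padding
#   torch.Size([2, 3, 32, 32])
#   >>> filter2d(img, box, padding='valid').shape     # no padding: 32 - 5 + 1 = 28
#   torch.Size([2, 3, 28, 28])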


def filter2d_separable(
    input: torch.Tensor,
    kernel_x: torch.Tensor,
    kernel_y: torch.Tensor,
    border_type: str = 'reflect',
    normalized: bool = False,
    padding: str = 'same',
) -> torch.Tensor:
    r"""Convolve a tensor with two 1d kernels, in x and y directions.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input: the input tensor with shape of
          :math:`(B, C, H, W)`.
        kernel_x: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`.
        kernel_y: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`.
        border_type: the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
          Two modes available: ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor of same size and number of channels
        as the input with shape :math:`(B, C, H, W)`.

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3)
        >>> filter2d_separable(input, kernel, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding)
    out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding)
    return out
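
# Equivalence sketch (illustrative only): the two 1d passes should match a single
# filter2d call with the outer-product kernel, up to floating point error. The names
# kx, ky and img below are hypothetical example values, not part of the module.
#   >>> kx, ky = torch.ones(1, 3), torch.ones(1, 3)
#   >>> img = torch.rand(1, 1, 8, 8)
#   >>> a = filter2d_separable(img, kx, ky)
#   >>> b = filter2d(img, ky.unsqueeze(-1) * kx.unsqueeze(1))   # (1, kH, kW) kernel
#   >>> bool(torch.allclose(a, b, atol=1e-5))
#   True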


def filter3d(
    input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False
) -> torch.Tensor:
    r"""Convolve a tensor with a 3d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input: the input tensor with shape of
          :math:`(B, C, D, H, W)`.
        kernel: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`.
        border_type: the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``,
          ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.

    Return:
        the convolved tensor of same size and number of channels
        as the input with shape :math:`(B, C, D, H, W)`.

    Example:
        >>> input = torch.tensor([[[
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 5., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]]
        ... ]]])
        >>> kernel = torch.ones(1, 3, 3, 3)
        >>> filter3d(input, kernel)
        tensor([[[[[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]]]]])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")

    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")

    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")

    if not len(input.shape) == 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    if not len(kernel.shape) == 4 and kernel.shape[0] != 1:
        raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}")

    # prepare kernel
    b, c, d, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        bk, dk, hk, wk = kernel.shape
        tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)

    # pad the input tensor
    depth, height, width = tmp_kernel.shape[-3:]
    padding_shape: List[int] = _compute_padding([depth, height, width])
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
    input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1))

    # convolve the tensor with the kernel.
    output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    return output.view(b, c, d, h, w)
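
# Usage sketch (illustrative only, not from the original source): with a normalized
# box kernel, filter3d acts as a 3d mean filter over a volume. The names vol and
# kernel below are hypothetical example values.
#   >>> vol = torch.rand(1, 1, 4, 5, 5)
#   >>> kernel = torch.ones(1, 3, 3, 3)
#   >>> filter3d(vol, kernel, normalized=True).shape
#   torch.Size([1, 1, 4, 5, 5])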