# Osterkarten / modules/filters.py
# Last commit by lllyasviel: "SAG implemented (#88)" (59ddae4), 1.19 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def gaussian_kernel(kernel_size, sigma):
    """Build a normalized square 2-D Gaussian kernel.

    Args:
        kernel_size: Side length of the returned (kernel_size, kernel_size) array.
        sigma: Standard deviation of the Gaussian, in index (pixel) units.
            Must be non-zero; a zero sigma divides by zero.

    Returns:
        A floating-point NumPy array whose entries sum to 1, peaked at the
        geometric centre ((kernel_size - 1) / 2, (kernel_size - 1) / 2).
        For an even kernel_size that centre falls between pixels.
    """
    # Float row/column index grids: exactly the arguments np.fromfunction
    # would hand to a callback for this shape.
    rows, cols = np.indices((kernel_size, kernel_size), dtype=float)
    center = (kernel_size - 1) / 2
    squared_dist = (rows - center) ** 2 + (cols - center) ** 2
    # The continuous-Gaussian prefactor cancels out in the final
    # normalization; it is kept so the arithmetic matches the usual formula.
    prefactor = 1 / (2 * np.pi * sigma ** 2)
    weights = prefactor * np.exp(-squared_dist / (2 * sigma ** 2))
    return weights / weights.sum()
class GaussianBlur(nn.Module):
    """Depthwise 2-D Gaussian blur applied independently to every channel.

    Args:
        channels: Number of input channels. Inputs to ``forward`` must have
            exactly this many channels, because the convolution uses
            ``groups=channels`` with one single-channel kernel per group.
        kernel_size: Side length of the square Gaussian kernel.
        sigma: Standard deviation of the Gaussian, in pixels.
    """

    def __init__(self, channels, kernel_size, sigma):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.sigma = sigma
        # Symmetric padding of kernel_size // 2 keeps H and W unchanged for an
        # odd kernel_size; an even kernel_size grows each spatial dim by one.
        self.padding = kernel_size // 2
        kernel = torch.tensor(gaussian_kernel(kernel_size, sigma), dtype=torch.float32)
        # Depthwise conv weight layout is (out_channels, in_channels / groups, kH, kW),
        # i.e. (channels, 1, k, k). Build the final shape *before* registering, and
        # make it a real copy per channel: an expand()-ed buffer has stride 0 on dim 0,
        # so every channel aliases one storage block and load_state_dict's in-place
        # copy_ into it raises for channels > 1.
        kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
        self.register_buffer('kernel', kernel)

    def forward(self, x):
        """Blur ``x`` with the per-channel Gaussian kernel.

        Args:
            x: Input tensor with ``self.channels`` channels. Presumably laid out
               as (N, C, H, W), as ``F.conv2d`` expects; confirm against callers.

        Returns:
            The blurred tensor, with the same dtype and device as ``x``.
        """
        # The buffer is stored as float32; cast it to x's dtype/device on every
        # call so the module works even if it was never moved explicitly.
        return F.conv2d(x, self.kernel.to(x), padding=self.padding, groups=self.channels)
# Module-level blur instance: 4 channels, 7x7 kernel, sigma = 0.8.
# Inputs passed to it must have exactly 4 channels (the conv uses groups=channels).
# NOTE(review): the channel count presumably matches the tensors callers blur; confirm.
gaussian_filter_2d = GaussianBlur(channels=4, kernel_size=7, sigma=0.8)