# SRCNN super-resolution model (Dong et al., 2014) with a sliding-window
# `enhance` helper for full-image inference.
from typing import Tuple, Union

import cv2
import numpy as np
import torch
from torch import nn
from torchvision import transforms as T
class SRCNN(nn.Module):
    """Super-Resolution CNN (SRCNN, Dong et al., 2014).

    The network refines an `input_size` x `input_size` patch of a
    bicubically upscaled image into a sharper `label_size` x `label_size`
    patch; `enhance` tiles such patches over a whole image.
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=3,
        input_size=33,
        label_size=21,
        scale=2,
        device=None,
    ):
        """
        Args:
            input_channels: channels of the input patches (3 for BGR/RGB).
            output_channels: channels produced by the last conv layer.
            input_size: side length of the patch fed to the network.
            label_size: side length of the patch the network predicts.
                With the 9-1-5 kernels below, input_size - label_size must
                equal 12 (the border lost to the valid convolutions).
            scale: super-resolution factor the model is meant for.
            device: torch device; auto-detects CUDA when None.
        """
        super().__init__()
        self.input_size = input_size
        self.label_size = label_size
        # Border lost on each side by the valid (unpadded) convolutions.
        self.pad = (self.input_size - self.label_size) // 2
        self.scale = scale
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, 9),
            nn.ReLU(),
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, output_channels, 5),
            nn.ReLU(),
        )
        self.transform = T.Compose(
            [T.ToTensor()]  # HWC uint8 ndarray -> CHW float in [0, 1]
        )
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        # BUGFIX: move the weights onto the selected device. Previously the
        # crops were sent to `device` in enhance() but the model never was,
        # which raised a device-mismatch error on CUDA machines.
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the 9-1-5 conv stack on a (N, C, H, W) batch."""
        return self.model(x)

    def pre_process(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Scale pixels to [0, 1]; ndarrays are also converted to CHW tensors."""
        if torch.is_tensor(x):
            return x / 255.0
        else:
            return self.transform(x)

    def post_process(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp to the valid range and rescale back to [0, 255]."""
        return x.clip(0, 1) * 255.0

    def enhance(
        self, image: np.ndarray, outscale: float = 2
    ) -> Tuple[np.ndarray, None]:
        """Super-resolve a full image with a sliding window of patches.

        Args:
            image: HWC image as loaded by OpenCV.
            outscale: desired overall upscaling factor; when it differs from
                the model's native `self.scale`, the result is resized with
                OpenCV afterwards.

        Returns:
            (output, None): the enhanced HWC uint8 image, paired with None
            (the second element mirrors a Real-ESRGAN-style API).
        """
        (h, w) = image.shape[:2]
        # Round each side down to a multiple of label_size so the sliding
        # window tiles it exactly, plus input_size of border context.
        scale_w = int((w - w % self.label_size + self.input_size) * self.scale)
        scale_h = int((h - h % self.label_size + self.input_size) * self.scale)
        # SRCNN refines a bicubic upscale rather than upsampling itself.
        scaled = cv2.resize(image, (scale_w, scale_h), interpolation=cv2.INTER_CUBIC)
        in_tensor = self.pre_process(scaled)  # (C, H, W), values in [0, 1]
        out_tensor = torch.zeros_like(in_tensor)  # (C, H, W)
        # Slide a window left-to-right, top-to-bottom; each input_size crop
        # yields a label_size prediction placed `pad` pixels inside the crop.
        # BUGFIX: run inference under no_grad so no autograd graph is built.
        with torch.no_grad():
            for y in range(0, scale_h - self.input_size + 1, self.label_size):
                for x in range(0, scale_w - self.input_size + 1, self.label_size):
                    crop = in_tensor[
                        :, y : y + self.input_size, x : x + self.input_size
                    ]
                    batch = crop.unsqueeze(0).to(self.device)
                    # BUGFIX: squeeze only the batch dim; a bare .squeeze()
                    # also dropped the channel dim when output_channels == 1.
                    pred = self.forward(batch).cpu().squeeze(0)
                    out_tensor[
                        :,
                        y + self.pad : y + self.pad + self.label_size,
                        x + self.pad : x + self.pad + self.label_size,
                    ] = pred
        out_tensor = self.post_process(out_tensor)
        output = out_tensor.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
        # Trim the border region the sliding window never filled.
        output = output[self.pad : -self.pad * 2, self.pad : -self.pad * 2]
        output = np.clip(output, 0, 255).astype("uint8")
        # Resize with OpenCV when the requested factor differs from the
        # model's native one.
        # BUGFIX: compare/divide by self.scale (was hard-coded to 2, which
        # silently misbehaved for models built with scale != 2).
        if outscale != self.scale:
            interpolation = (
                cv2.INTER_AREA if outscale < self.scale else cv2.INTER_LANCZOS4
            )
            h, w = output.shape[0:2]
            output = cv2.resize(
                output,
                (
                    int(w * outscale / self.scale),
                    int(h * outscale / self.scale),
                ),
                interpolation=interpolation,
            )
        return output, None