import os
from enum import IntEnum
from pathlib import Path
from tempfile import mktemp
from typing import IO, Dict, Type
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gradio import Interface, inputs, outputs
# Device every upscaler built by main() runs on.
DEVICE = "cpu"
# Directory that is scanned, once at import time, for model weight files.
WEIGHTS_PATH = Path(__file__).parent / "weights"
# Weight name (file name without its extension) -> full path, for every regular
# file in WEIGHTS_PATH whose extension ends with "pth".  The misspelled name is
# kept because the UI code refers to it.  Note that the walrus below also binds
# a module-level ``path`` (the last candidate examined).
AVALIABLE_WEIGHTS = {
    stem: path
    for stem, suffix in map(os.path.splitext, os.listdir(WEIGHTS_PATH))
    if (path := WEIGHTS_PATH / (stem + suffix)).is_file() and suffix.endswith("pth")
}
class ScaleMode(IntEnum):
    """Supported upscaling factors.

    Member names double as weight-file prefixes: main() picks the first member
    whose name starts the selected weight name, and builds ``UpCunet{value}x``.
    """

    up2x = 2  # -> UpCunet2x
    up3x = 3  # -> UpCunet3x
    up4x = 4  # -> UpCunet4x
class TileMode(IntEnum):
    """How the UpCunet*x models split the input before inference.

    0 processes the whole image at once; 1 halves the longer side; 2, 3 and 4
    split both sides into 2, 3 and 4 parts respectively.
    """

    full = 0       # no tiling
    half = 1       # longer side split in two
    quarter = 2    # 2 x 2 tiles
    ninth = 3      # 3 x 3 tiles
    sixteenth = 4  # 4 x 4 tiles
class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate.

    The per-channel spatial mean of the input goes through a 1x1 bottleneck
    (conv -> ReLU -> conv -> sigmoid) and the result rescales the input channels.

    Args:
        in_channels: number of input (and output) channels.
        reduction: channel reduction factor of the bottleneck.
        bias: whether the two 1x1 convolutions have a bias.
    """

    def __init__(self, in_channels, reduction=8, bias=False):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, in_channels // reduction, 1, 1, 0, bias=bias
        )
        self.conv2 = nn.Conv2d(
            in_channels // reduction, in_channels, 1, 1, 0, bias=bias
        )

    def forward(self, x):
        """Gate ``x`` with its own per-channel spatial mean (dims 2 and 3)."""
        if x.dtype == torch.float16:
            # Mean taken in float32 and cast back, presumably to avoid float16
            # accumulation error over large feature maps.
            x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half()
        else:
            x0 = torch.mean(x, dim=(2, 3), keepdim=True)
        return self.forward_mean(x, x0)

    def forward_mean(self, x, x0):
        """Gate ``x`` with a precomputed channel descriptor ``x0``.

        The tiled UpCunet code passes the mean over *all* tiles here so that
        every tile is gated identically.  ``x0`` itself is not modified: the
        in-place ReLU acts on the fresh output of ``conv1``.
        """
        x0 = self.conv1(x0)
        x0 = F.relu(x0, inplace=True)
        x0 = self.conv2(x0)
        x0 = torch.sigmoid(x0)
        return torch.mul(x, x0)
class UNetConv(nn.Module):
    """Two unpadded 3x3 conv + LeakyReLU(0.1) layers, optionally SE-gated.

    Each unpadded 3x3 convolution trims one pixel per border, so the output is
    4 pixels smaller than the input in both spatial dimensions.  ``seblock`` is
    always defined (``None`` when ``se`` is false) so callers can test for it.
    """

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        self.seblock = SEBlock(out_channels, reduction=8, bias=True) if se else None

    def forward(self, x):
        features = self.conv(x)
        return features if self.seblock is None else self.seblock(features)
class UNet1(nn.Module):
    """Shallow U-Net (one down/up level), the first Real-CUGAN stage.

    All convolutions are unpadded, so the skip branch is cropped by 4 px per
    side before it is added to the upsampled branch.  With ``deconv=True`` the
    head is a stride-2 transposed convolution, otherwise a 3x3 convolution.

    ``forward_a``/``forward_b`` split ``forward`` around the SE gate of
    ``conv2`` so a tiled caller can gate with statistics gathered over all tiles
    (``conv2.seblock.forward_mean``).
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the whole network: ``forward_a`` -> conv2 SE gate -> ``forward_b``."""
        x1, x2 = self.forward_a(x)
        x2 = self.conv2.seblock(x2)
        return self.forward_b(x1, x2)

    def forward_a(self, x):
        """Encoder half.

        Returns:
            ``(x1, x2)``: the full-resolution skip features and the conv2
            features *before* conv2's SE gate.
        """
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    def forward_b(self, x1, x2):
        """Decoder half: upsample the (already gated) ``x2``, add the cropped skip, head."""
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x1 = F.pad(x1, (-4, -4, -4, -4))  # negative padding crops the skip branch
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z
class UNet1x3(nn.Module):
    """Shallow U-Net like ``UNet1`` but with a stride-3 transposed-conv head.

    All convolutions are unpadded, so the skip branch is cropped by 4 px per
    side before it is added to the upsampled branch.  With ``deconv=True`` the
    head is ``ConvTranspose2d(64, out_channels, 5, 3, 2)``, otherwise a 3x3
    convolution.

    ``forward_a``/``forward_b`` split ``forward`` around the SE gate of
    ``conv2`` for tiled inference (``conv2.seblock.forward_mean``).
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 5, 3, 2)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the whole network: ``forward_a`` -> conv2 SE gate -> ``forward_b``."""
        x1, x2 = self.forward_a(x)
        x2 = self.conv2.seblock(x2)
        return self.forward_b(x1, x2)

    def forward_a(self, x):
        """Encoder half.

        Returns:
            ``(x1, x2)``: the full-resolution skip features and the conv2
            features *before* conv2's SE gate.
        """
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    def forward_b(self, x1, x2):
        """Decoder half: upsample the (already gated) ``x2``, add the cropped skip, head."""
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x1 = F.pad(x1, (-4, -4, -4, -4))  # negative padding crops the skip branch
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z
class UNet2(nn.Module):
    """Two-level U-Net, the Real-CUGAN refinement stage.

    Convolutions are unpadded; the inner skip is cropped by 4 px and the outer
    skip by 16 px per side before being added to the upsampled branches.

    ``forward_a``..``forward_d`` split ``forward`` at the SE gates of
    ``conv2``, ``conv3`` and ``conv4`` so a tiled caller can gate each stage
    with statistics averaged over all tiles (``seblock.forward_mean``).
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the whole network by chaining the staged forwards and SE gates."""
        x1, x2 = self.forward_a(x)
        x2 = self.conv2.seblock(x2)
        x3 = self.forward_b(x2)
        x3 = self.conv3.seblock(x3)
        x4 = self.forward_c(x2, x3)
        x4 = self.conv4.seblock(x4)
        return self.forward_d(x1, x4)

    def forward_a(self, x):
        """Return the outer skip ``x1`` and conv2 features before conv2's SE gate."""
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    def forward_b(self, x2):
        """Bottom level: downsample the gated ``x2``; return conv3 features before its SE gate."""
        x3 = self.conv2_down(x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        x3 = self.conv3.conv(x3)
        return x3

    def forward_c(self, x2, x3):
        """Upsample gated ``x3``, add cropped gated ``x2``; return conv4 features before its SE gate."""
        x3 = self.conv3_up(x3)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        x2 = F.pad(x2, (-4, -4, -4, -4))  # negative padding crops the inner skip
        x4 = self.conv4.conv(x2 + x3)
        return x4

    def forward_d(self, x1, x4):
        """Upsample gated ``x4``, add the cropped outer skip ``x1`` and apply the head."""
        x4 = self.conv4_up(x4)
        x4 = F.leaky_relu(x4, 0.1, inplace=True)
        x1 = F.pad(x1, (-16, -16, -16, -16))  # negative padding crops the outer skip
        x5 = self.conv5(x1 + x4)
        x5 = F.leaky_relu(x5, 0.1, inplace=True)
        z = self.conv_bottom(x5)
        return z
class UpCunet2x(nn.Module):
    """2x upscaler: ``UNet1`` (stride-2 deconv head) refined by a residual ``UNet2``.

    ``forward(x, tile_mode)`` processes the whole image at once for
    ``tile_mode == 0``.  Modes 1-4 split the padded image into overlapping
    tiles and run the network in five passes.  Between passes, the per-tile
    channel means feeding each SE block are averaged over all tiles and applied
    via ``SEBlock.forward_mean``, so every tile is gated with the same statistics.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    def forward(self, x, tile_mode):
        """Upscale ``x`` (unpacked as ``(N, C, H, W)``) by a factor of 2.

        Args:
            x: input batch.
            tile_mode: 0 for no tiling, 1-4 for tiled inference (see ``_crop_size``).

        Returns:
            The upscaled batch; padding added internally is cropped off again.

        Raises:
            ValueError: if ``tile_mode`` is not one of 0-4.
        """
        if tile_mode == 0:
            return self._forward_full(x)
        _, _, h0, w0 = x.shape
        return self._forward_tiled(x, self._crop_size(h0, w0, tile_mode))

    def _forward_full(self, x):
        """Run both U-Nets over the whole image in one pass."""
        _, _, h0, w0 = x.shape
        # Round H and W up to even sizes, then add an 18-px reflect border.
        ph = ((h0 - 1) // 2 + 1) * 2
        pw = ((w0 - 1) // 2 + 1) * 2
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
        x = self.unet1.forward(x)
        x0 = self.unet2.forward(x)
        # Crop unet1's output by 20 px per side to line it up with unet2's output.
        x = torch.add(x0, F.pad(x, (-20, -20, -20, -20)))
        if w0 != pw or h0 != ph:
            x = x[:, :, : h0 * 2, : w0 * 2]
        return x

    @staticmethod
    def _crop_size(h0, w0, tile_mode):
        """Return ``(tile_height, tile_width)`` for ``tile_mode`` 1-4.

        Mode 1 splits the longer side in two (the shorter side is one tile);
        modes 2, 3 and 4 split both sides into 2, 3 and 4 parts.  A side split
        into ``p`` parts is rounded up to a multiple of ``2 * p``, so every tile
        side is even.

        Raises:
            ValueError: if ``tile_mode`` is not one of 1-4.
        """
        if tile_mode == 1:
            parts_h, parts_w = (1, 2) if w0 >= h0 else (2, 1)
        elif tile_mode in (2, 3, 4):
            parts_h = parts_w = int(tile_mode)
        else:
            raise ValueError(f"unsupported tile_mode: {tile_mode!r}")

        def tile_side(length, parts):
            unit = 2 * parts
            return ((length - 1) // unit * unit + unit) // parts

        return tile_side(h0, parts_h), tile_side(w0, parts_w)

    @staticmethod
    def _mean_of_tile_means(ref, feats):
        """Average the per-channel spatial means of ``feats`` over all tiles.

        ``ref`` supplies the batch size, device and precision.  For float16 the
        per-tile means are computed in float32 and cast back, as in ``SEBlock.forward``.
        """
        is_half = ref.dtype == torch.float16
        total = torch.zeros((ref.shape[0], feats[0].shape[1], 1, 1), device=ref.device)
        if is_half:
            total = total.half()
        for feat in feats:
            if is_half:
                total += torch.mean(feat.float(), dim=(2, 3), keepdim=True).half()
            else:
                total += torch.mean(feat, dim=(2, 3), keepdim=True)
        total /= len(feats)
        return total

    def _forward_tiled(self, x, crop_size):
        """Run the network tile by tile with SE statistics shared across tiles."""
        n, c, h0, w0 = x.shape
        crop_h, crop_w = crop_size
        # Round the image up to whole tiles and add an 18-px reflect border;
        # every tile is read with that 18-px context on each side (36 px total).
        ph = ((h0 - 1) // crop_h + 1) * crop_h
        pw = ((w0 - 1) // crop_w + 1) * crop_w
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
        # Tile corners in padded input coordinates, row-major.
        origins = [(i, j) for i in range(0, ph, crop_h) for j in range(0, pw, crop_w)]
        unet1, unet2 = self.unet1, self.unet2

        # Pass 1: unet1 encoder, stopping before unet1.conv2's SE gate.
        stage = [
            unet1.forward_a(x[:, :, i : i + crop_h + 36, j : j + crop_w + 36])
            for i, j in origins
        ]
        se_mean = self._mean_of_tile_means(x, [s[1] for s in stage])

        # Pass 2: gate, finish unet1, run unet2 up to unet2.conv2's SE gate.
        for k, (skip, feat) in enumerate(stage):
            feat = unet1.conv2.seblock.forward_mean(feat, se_mean)
            opt_unet1 = unet1.forward_b(skip, feat)
            x1, x2 = unet2.forward_a(opt_unet1)
            stage[k] = (opt_unet1, x1, x2)
        se_mean = self._mean_of_tile_means(x, [s[2] for s in stage])

        # Pass 3: gate, then unet2's bottom level up to unet2.conv3's SE gate.
        for k, (opt_unet1, x1, x2) in enumerate(stage):
            x2 = unet2.conv2.seblock.forward_mean(x2, se_mean)
            stage[k] = (opt_unet1, x1, x2, unet2.forward_b(x2))
        se_mean = self._mean_of_tile_means(x, [s[3] for s in stage])

        # Pass 4: gate, then unet2's first decoder level up to unet2.conv4's SE gate.
        for k, (opt_unet1, x1, x2, x3) in enumerate(stage):
            x3 = unet2.conv3.seblock.forward_mean(x3, se_mean)
            stage[k] = (opt_unet1, x1, unet2.forward_c(x2, x3))
        se_mean = self._mean_of_tile_means(x, [s[2] for s in stage])

        # Pass 5: gate, finish unet2, add the cropped unet1 residual, and write
        # each tile's 2x output directly into the result canvas.
        res = torch.zeros((n, c, ph * 2, pw * 2), device=x.device)
        if x.dtype == torch.float16:
            res = res.half()
        for k, (i, j) in enumerate(origins):
            opt_unet1, x1, x4 = stage[k]
            stage[k] = None  # release this tile's intermediates early
            x4 = unet2.conv4.seblock.forward_mean(x4, se_mean)
            tile = torch.add(
                unet2.forward_d(x1, x4), F.pad(opt_unet1, (-20, -20, -20, -20))
            )
            res[:, :, i * 2 : i * 2 + crop_h * 2, j * 2 : j * 2 + crop_w * 2] = tile

        if w0 != pw or h0 != ph:
            res = res[:, :, : h0 * 2, : w0 * 2]
        return res
class UpCunet3x(nn.Module):
    """3x upscaler: ``UNet1x3`` (stride-3 deconv head) refined by a residual ``UNet2``.

    ``forward(x, tile_mode)`` processes the whole image at once for
    ``tile_mode == 0``.  Modes 1-4 split the padded image into overlapping
    tiles and run the network in five passes.  Between passes, the per-tile
    channel means feeding each SE block are averaged over all tiles and applied
    via ``SEBlock.forward_mean``, so every tile is gated with the same statistics.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1x3(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    def forward(self, x, tile_mode):
        """Upscale ``x`` (unpacked as ``(N, C, H, W)``) by a factor of 3.

        Args:
            x: input batch.
            tile_mode: 0 for no tiling, 1-4 for tiled inference (see ``_crop_size``).

        Returns:
            The upscaled batch; padding added internally is cropped off again.

        Raises:
            ValueError: if ``tile_mode`` is not one of 0-4.
        """
        if tile_mode == 0:
            return self._forward_full(x)
        _, _, h0, w0 = x.shape
        return self._forward_tiled(x, self._crop_size(h0, w0, tile_mode))

    def _forward_full(self, x):
        """Run both U-Nets over the whole image in one pass."""
        _, _, h0, w0 = x.shape
        # Round H and W up to multiples of 4, then add a 14-px reflect border.
        ph = ((h0 - 1) // 4 + 1) * 4
        pw = ((w0 - 1) // 4 + 1) * 4
        x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
        x = self.unet1.forward(x)
        x0 = self.unet2.forward(x)
        # Crop unet1's output by 20 px per side to line it up with unet2's output.
        x = torch.add(x0, F.pad(x, (-20, -20, -20, -20)))
        if w0 != pw or h0 != ph:
            x = x[:, :, : h0 * 3, : w0 * 3]
        return x

    @staticmethod
    def _crop_size(h0, w0, tile_mode):
        """Return ``(tile_height, tile_width)`` for ``tile_mode`` 1-4.

        Mode 1 splits the longer side in two (the shorter side is one tile);
        modes 2, 3 and 4 split both sides into 2, 3 and 4 parts.  A side split
        into ``p`` parts is rounded up to a multiple of ``4 * p``, so every tile
        side is a multiple of 4.

        Raises:
            ValueError: if ``tile_mode`` is not one of 1-4.
        """
        if tile_mode == 1:
            parts_h, parts_w = (1, 2) if w0 >= h0 else (2, 1)
        elif tile_mode in (2, 3, 4):
            parts_h = parts_w = int(tile_mode)
        else:
            raise ValueError(f"unsupported tile_mode: {tile_mode!r}")

        def tile_side(length, parts):
            unit = 4 * parts
            return ((length - 1) // unit * unit + unit) // parts

        return tile_side(h0, parts_h), tile_side(w0, parts_w)

    @staticmethod
    def _mean_of_tile_means(ref, feats):
        """Average the per-channel spatial means of ``feats`` over all tiles.

        ``ref`` supplies the batch size, device and precision.  For float16 the
        per-tile means are computed in float32 and cast back, as in ``SEBlock.forward``.
        """
        is_half = ref.dtype == torch.float16
        total = torch.zeros((ref.shape[0], feats[0].shape[1], 1, 1), device=ref.device)
        if is_half:
            total = total.half()
        for feat in feats:
            if is_half:
                total += torch.mean(feat.float(), dim=(2, 3), keepdim=True).half()
            else:
                total += torch.mean(feat, dim=(2, 3), keepdim=True)
        total /= len(feats)
        return total

    def _forward_tiled(self, x, crop_size):
        """Run the network tile by tile with SE statistics shared across tiles."""
        n, c, h0, w0 = x.shape
        crop_h, crop_w = crop_size
        # Round the image up to whole tiles and add a 14-px reflect border;
        # every tile is read with that 14-px context on each side (28 px total).
        ph = ((h0 - 1) // crop_h + 1) * crop_h
        pw = ((w0 - 1) // crop_w + 1) * crop_w
        x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
        # Tile corners in padded input coordinates, row-major.
        origins = [(i, j) for i in range(0, ph, crop_h) for j in range(0, pw, crop_w)]
        unet1, unet2 = self.unet1, self.unet2

        # Pass 1: unet1 encoder, stopping before unet1.conv2's SE gate.
        stage = [
            unet1.forward_a(x[:, :, i : i + crop_h + 28, j : j + crop_w + 28])
            for i, j in origins
        ]
        se_mean = self._mean_of_tile_means(x, [s[1] for s in stage])

        # Pass 2: gate, finish unet1, run unet2 up to unet2.conv2's SE gate.
        for k, (skip, feat) in enumerate(stage):
            feat = unet1.conv2.seblock.forward_mean(feat, se_mean)
            opt_unet1 = unet1.forward_b(skip, feat)
            x1, x2 = unet2.forward_a(opt_unet1)
            stage[k] = (opt_unet1, x1, x2)
        se_mean = self._mean_of_tile_means(x, [s[2] for s in stage])

        # Pass 3: gate, then unet2's bottom level up to unet2.conv3's SE gate.
        for k, (opt_unet1, x1, x2) in enumerate(stage):
            x2 = unet2.conv2.seblock.forward_mean(x2, se_mean)
            stage[k] = (opt_unet1, x1, x2, unet2.forward_b(x2))
        se_mean = self._mean_of_tile_means(x, [s[3] for s in stage])

        # Pass 4: gate, then unet2's first decoder level up to unet2.conv4's SE gate.
        for k, (opt_unet1, x1, x2, x3) in enumerate(stage):
            x3 = unet2.conv3.seblock.forward_mean(x3, se_mean)
            stage[k] = (opt_unet1, x1, unet2.forward_c(x2, x3))
        se_mean = self._mean_of_tile_means(x, [s[2] for s in stage])

        # Pass 5: gate, finish unet2, add the cropped unet1 residual, and write
        # each tile's 3x output directly into the result canvas.
        res = torch.zeros((n, c, ph * 3, pw * 3), device=x.device)
        if x.dtype == torch.float16:
            res = res.half()
        for k, (i, j) in enumerate(origins):
            opt_unet1, x1, x4 = stage[k]
            stage[k] = None  # release this tile's intermediates early
            x4 = unet2.conv4.seblock.forward_mean(x4, se_mean)
            tile = torch.add(
                unet2.forward_d(x1, x4), F.pad(opt_unet1, (-20, -20, -20, -20))
            )
            res[:, :, i * 3 : i * 3 + crop_h * 3, j * 3 : j * 3 + crop_w * 3] = tile

        if w0 != pw or h0 != ph:
            res = res[:, :, : h0 * 3, : w0 * 3]
        return res
class UpCunet4x(nn.Module):
    """4x upscaler: ``UNet1`` -> residual ``UNet2`` -> conv + PixelShuffle head.

    Both U-Nets work on 64 channels.  ``conv_final`` maps them to 12 channels
    and ``PixelShuffle(2)`` turns those into 3 channels at twice the size.  A
    nearest-neighbour 4x upsample of the input is added as a global residual.

    ``forward(x, tile_mode)`` processes the whole image at once for
    ``tile_mode == 0``.  Modes 1-4 split the padded image into overlapping
    tiles and run the network in five passes.  Between passes, the per-tile
    channel means feeding each SE block are averaged over all tiles and applied
    via ``SEBlock.forward_mean``, so every tile is gated with the same statistics.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1(in_channels, 64, deconv=True)
        self.unet2 = UNet2(64, 64, deconv=False)
        self.ps = nn.PixelShuffle(2)
        self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)

    def forward(self, x, tile_mode):
        """Upscale ``x`` (unpacked as ``(N, C, H, W)``) by a factor of 4.

        Args:
            x: input batch.
            tile_mode: 0 for no tiling, 1-4 for tiled inference (see ``_crop_size``).

        Returns:
            The upscaled batch; padding added internally is cropped off again.

        Raises:
            ValueError: if ``tile_mode`` is not one of 0-4.
        """
        if tile_mode == 0:
            return self._forward_full(x)
        _, _, h0, w0 = x.shape
        return self._forward_tiled(x, self._crop_size(h0, w0, tile_mode))

    def _head(self, feat):
        """``conv_final``, crop 1 px per side, then ``PixelShuffle(2)``."""
        feat = self.conv_final(feat)
        feat = F.pad(feat, (-1, -1, -1, -1))
        return self.ps(feat)

    def _forward_full(self, x):
        """Run the whole network over the image in one pass."""
        x00 = x
        _, _, h0, w0 = x.shape
        # Round H and W up to even sizes, then add a 19-px reflect border.
        ph = ((h0 - 1) // 2 + 1) * 2
        pw = ((w0 - 1) // 2 + 1) * 2
        x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
        x = self.unet1.forward(x)
        x0 = self.unet2.forward(x)
        # Crop unet1's output by 20 px per side to line it up with unet2's output.
        x = self._head(torch.add(x0, F.pad(x, (-20, -20, -20, -20))))
        if w0 != pw or h0 != ph:
            x = x[:, :, : h0 * 4, : w0 * 4]
        x += F.interpolate(x00, scale_factor=4, mode="nearest")
        return x

    @staticmethod
    def _crop_size(h0, w0, tile_mode):
        """Return ``(tile_height, tile_width)`` for ``tile_mode`` 1-4.

        Mode 1 splits the longer side in two (the shorter side is one tile);
        modes 2, 3 and 4 split both sides into 2, 3 and 4 parts.  A side split
        into ``p`` parts is rounded up to a multiple of ``2 * p``, so every tile
        side is even.

        Raises:
            ValueError: if ``tile_mode`` is not one of 1-4.
        """
        if tile_mode == 1:
            parts_h, parts_w = (1, 2) if w0 >= h0 else (2, 1)
        elif tile_mode in (2, 3, 4):
            parts_h = parts_w = int(tile_mode)
        else:
            raise ValueError(f"unsupported tile_mode: {tile_mode!r}")

        def tile_side(length, parts):
            unit = 2 * parts
            return ((length - 1) // unit * unit + unit) // parts

        return tile_side(h0, parts_h), tile_side(w0, parts_w)

    @staticmethod
    def _mean_of_tile_means(ref, feats):
        """Average the per-channel spatial means of ``feats`` over all tiles.

        ``ref`` supplies the batch size, device and precision.  For float16 the
        per-tile means are computed in float32 and cast back, as in ``SEBlock.forward``.
        """
        is_half = ref.dtype == torch.float16
        total = torch.zeros((ref.shape[0], feats[0].shape[1], 1, 1), device=ref.device)
        if is_half:
            total = total.half()
        for feat in feats:
            if is_half:
                total += torch.mean(feat.float(), dim=(2, 3), keepdim=True).half()
            else:
                total += torch.mean(feat, dim=(2, 3), keepdim=True)
        total /= len(feats)
        return total

    def _forward_tiled(self, x, crop_size):
        """Run the network tile by tile with SE statistics shared across tiles."""
        x00 = x
        n, c, h0, w0 = x.shape
        crop_h, crop_w = crop_size
        # Round the image up to whole tiles and add a 19-px reflect border;
        # every tile is read with that 19-px context on each side (38 px total).
        ph = ((h0 - 1) // crop_h + 1) * crop_h
        pw = ((w0 - 1) // crop_w + 1) * crop_w
        x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
        # Tile corners in padded input coordinates, row-major.
        origins = [(i, j) for i in range(0, ph, crop_h) for j in range(0, pw, crop_w)]
        unet1, unet2 = self.unet1, self.unet2

        # Pass 1: unet1 encoder, stopping before unet1.conv2's SE gate.
        stage = [
            unet1.forward_a(x[:, :, i : i + crop_h + 38, j : j + crop_w + 38])
            for i, j in origins
        ]
        se_mean = self._mean_of_tile_means(x, [s[1] for s in stage])

        # Pass 2: gate, finish unet1, run unet2 up to unet2.conv2's SE gate.
        for k, (skip, feat) in enumerate(stage):
            feat = unet1.conv2.seblock.forward_mean(feat, se_mean)
            opt_unet1 = unet1.forward_b(skip, feat)
            x1, x2 = unet2.forward_a(opt_unet1)
            stage[k] = (opt_unet1, x1, x2)
        se_mean = self._mean_of_tile_means(x, [s[2] for s in stage])

        # Pass 3: gate, then unet2's bottom level up to unet2.conv3's SE gate.
        for k, (opt_unet1, x1, x2) in enumerate(stage):
            x2 = unet2.conv2.seblock.forward_mean(x2, se_mean)
            stage[k] = (opt_unet1, x1, x2, unet2.forward_b(x2))
        se_mean = self._mean_of_tile_means(x, [s[3] for s in stage])

        # Pass 4: gate, then unet2's first decoder level up to unet2.conv4's SE gate.
        for k, (opt_unet1, x1, x2, x3) in enumerate(stage):
            x3 = unet2.conv3.seblock.forward_mean(x3, se_mean)
            stage[k] = (opt_unet1, x1, unet2.forward_c(x2, x3))
        se_mean = self._mean_of_tile_means(x, [s[2] for s in stage])

        # Pass 5: gate, finish unet2, add the cropped unet1 residual, apply the
        # PixelShuffle head and write each 4x tile into the result canvas.
        res = torch.zeros((n, c, ph * 4, pw * 4), device=x.device)
        if x.dtype == torch.float16:
            res = res.half()
        for k, (i, j) in enumerate(origins):
            opt_unet1, x1, x4 = stage[k]
            stage[k] = None  # release this tile's intermediates early
            x4 = unet2.conv4.seblock.forward_mean(x4, se_mean)
            tile = torch.add(
                unet2.forward_d(x1, x4), F.pad(opt_unet1, (-20, -20, -20, -20))
            )
            tile = self._head(tile)
            res[:, :, i * 4 : i * 4 + crop_h * 4, j * 4 : j * 4 + crop_w * 4] = tile

        if w0 != pw or h0 != ph:
            res = res[:, :, : h0 * 4, : w0 * 4]
        res += F.interpolate(x00, scale_factor=4, mode="nearest")
        return res
# Registry of every nn.Module subclass defined in this module, keyed by class
# name; RealWaifuUpScaler looks models up here as f"UpCunet{scale}x".
models: Dict[str, Type[nn.Module]] = dict(
    (candidate.__name__, candidate)
    for candidate in list(globals().values())
    if isinstance(candidate, type) and issubclass(candidate, nn.Module)
)
class RealWaifuUpScaler:
    """Load an ``UpCunet{scale}x`` checkpoint and upscale HWC image arrays with it.

    Args:
        scale: upscale factor; selects ``models[f"UpCunet{scale}x"]``.
        weight_path: state-dict file passed to ``torch.load``.
        half: run the model (and its input tensors) in float16.
        device: torch device string the model and inputs are moved to.
    """

    def __init__(self, scale: int, weight_path: str, half: bool, device: str):
        # NOTE(security): torch.load unpickles the file, so only load trusted
        # weights (here they come from the local weights directory).
        weight = torch.load(weight_path, map_location=device)
        model = models[f"UpCunet{scale}x"]()
        if half:
            model = model.half()
        self.model = model.to(device)
        self.model.load_state_dict(weight, strict=True)
        self.half = half
        self.device = device

    def np2tensor(self, np_frame):
        """Convert an ``(H, W, C)`` array (presumably uint8) to a ``(1, C, H, W)`` tensor scaled by 1/255."""
        tensor = (
            torch.from_numpy(np.transpose(np_frame, (2, 0, 1)))
            .unsqueeze(0)
            .to(self.device)
        )
        tensor = tensor.half() if self.half else tensor.float()
        return tensor / 255

    def tensor2np(self, tensor):
        """Convert a ``(1, C, H, W)`` tensor in [0, 1] back to an ``(H, W, C)`` uint8 array."""
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop 1-pixel-high/wide dimensions and break the transpose below.
        frame = tensor.detach().squeeze(0)
        if self.half:
            frame = frame.float()
        frame = (frame * 255.0).clamp_(0, 255).round().byte().cpu().numpy()
        return np.transpose(frame, (1, 2, 0))

    def __call__(self, frame, tile_mode):
        """Upscale one ``(H, W, C)`` frame with the given tile mode (no autograd)."""
        with torch.no_grad():
            tensor = self.np2tensor(frame)
            return self.tensor2np(self.model(tensor, tile_mode))
# Input widgets, in the same order as main()'s parameters.
input_image = inputs.File(label="Input image")
half_precision = inputs.Checkbox(label="Half precision (NOT work for CPU)", default=False)
model_weight = inputs.Dropdown(sorted(AVALIABLE_WEIGHTS), label="Choice model weight")
tile_mode = inputs.Radio(list(TileMode.__members__), label="Output tile mode")

# Output widgets: a rendered preview plus a downloadable file.
output_image = outputs.Image(label="Output image preview")
output_file = outputs.File(label="Output image file")
def main(file: IO[bytes], half: bool, weight: str, tile: str):
    """Upscale an uploaded image (presumably used as the Gradio callback).

    Args:
        file: uploaded file object; only its ``.name`` path is read.
        half: run the model in float16.
        weight: key of ``AVALIABLE_WEIGHTS``; its prefix selects the scale.
        tile: ``TileMode`` member name.

    Returns:
        ``(path, path)``: the same output file for the preview and download outputs.

    Raises:
        ValueError: if the weight name matches no ``ScaleMode`` prefix, or the
            image cannot be read or the result cannot be written.
    """
    from tempfile import mkstemp

    scale = next(
        (mode.value for mode in ScaleMode if weight.startswith(mode.name)), None
    )
    if scale is None:
        raise ValueError(f"cannot infer the upscale factor from weight {weight!r}")
    upscaler = RealWaifuUpScaler(
        scale, weight_path=str(AVALIABLE_WEIGHTS[weight]), half=half, device=DEVICE
    )

    image = cv2.imread(file.name)
    if image is None:  # cv2.imread signals failure by returning None
        raise ValueError(f"could not read an image from {file.name!r}")
    frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    result = cv2.cvtColor(upscaler(frame, TileMode[tile]), cv2.COLOR_RGB2BGR)

    _, ext = os.path.splitext(file.name)
    # mkstemp creates the file atomically; mktemp only returned a free name,
    # which another process could claim before cv2.imwrite ran.
    fd, out_path = mkstemp(suffix=ext)
    os.close(fd)
    if not cv2.imwrite(out_path, result):
        os.remove(out_path)
        raise ValueError(f"could not write the result as {ext!r}")
    return out_path, out_path
interface = Interface(
inputs=[input_image, half_precision, model_weight, tile_mode],
outputs=[output_image, output_file],