# APISR / architecture / cunet.py
# Commit 561c629 by HikariDawn: "feat: initial push"
# GitHub repository: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/README_EN.md
# Code snippet (with some modifications) from: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/VapourSynth/upcunet_v3_vs.py
import torch
from torch import nn as nn
from torch.nn import functional as F
import os, sys
import numpy as np
from time import time as ttime, sleep
class UNet_Full(nn.Module):
    """Two-stage 2x upscaler: ``UNet1`` (transposed-conv head) then ``UNet2``.

    The first stage's output is cropped and added to the second stage's
    output as a residual.
    """

    def __init__(self):
        super(UNet_Full, self).__init__()
        # Registration order fixes state_dict key order and the order in which
        # default parameter initialization consumes the RNG; keep it.
        self.unet1 = UNet1(3, 3, deconv=True)
        self.unet2 = UNet2(3, 3, deconv=False)

    def forward(self, x):
        """Upscale ``x`` by a factor of 2.

        Args:
            x: 4-D tensor; indexed here as (batch, channel, height, width).
               The channel count must match the 3 input channels of UNet1.

        Returns:
            Tensor with spatial size (2 * height, 2 * width).
        """
        _, _, h0, w0 = x.shape

        # Round each spatial size up to an even number so the stride-2
        # down/up path inside both U-Nets lines up.
        ph = h0 + h0 % 2
        pw = w0 + w0 % 2

        # 18 px reflect margin on every side, plus the extra row/column needed
        # for even sizes on the bottom/right. The margin is consumed by the
        # unpadded convolutions of both stages.
        pad_right = 18 + (pw - w0)
        pad_bottom = 18 + (ph - h0)
        padded = F.pad(x, (18, pad_right, 18, pad_bottom), 'reflect')

        stage1 = self.unet1(padded)
        stage2 = self.unet2(stage1)
        # Negative padding crops 20 px per side so stage1 matches stage2.
        residual = F.pad(stage1, (-20, -20, -20, -20))
        out = torch.add(stage2, residual)

        if h0 == ph and w0 == pw:
            return out
        # Drop the rows/columns that came from rounding up to even sizes.
        return out[:, :, :h0 * 2, :w0 * 2]
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescale each channel by a learned factor in (0, 1).

    Args:
        in_channels: number of channels of the input feature map.
        reduction: bottleneck factor; the hidden width is
            ``in_channels // reduction``.
        bias: whether the two 1x1 convolutions carry a bias term.
    """

    def __init__(self, in_channels, reduction=8, bias=False):
        super(SEBlock, self).__init__()
        hidden = in_channels // reduction
        self.conv1 = nn.Conv2d(in_channels, hidden, 1, 1, 0, bias=bias)
        self.conv2 = nn.Conv2d(hidden, in_channels, 1, 1, 0, bias=bias)

    def forward(self, x):
        """Return ``x`` multiplied channel-wise by the SE gate."""
        # Average over dims 2 and 3 (spatial, assuming NCHW input). Half
        # tensors are averaged in float32 and cast back to half.
        if "Half" in x.type():  # torch.HalfTensor / torch.cuda.HalfTensor
            squeezed = x.float().mean(dim=(2, 3), keepdim=True).half()
        else:
            squeezed = x.mean(dim=(2, 3), keepdim=True)
        hidden = F.relu(self.conv1(squeezed), inplace=True)
        gate = torch.sigmoid(self.conv2(hidden))
        return x * gate
class UNetConv(nn.Module):
    """Two unpadded 3x3 convolutions with LeakyReLU(0.1), optionally followed by SE.

    Both convolutions use padding 0, so the output is 4 px smaller than the
    input in each spatial dimension.

    Args:
        in_channels: channels of the input.
        mid_channels: channels after the first convolution.
        out_channels: channels of the output.
        se: when truthy, append an ``SEBlock(out_channels, reduction=8, bias=True)``.
    """

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super(UNetConv, self).__init__()
        layers = [
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        # The attribute always exists; None means "no squeeze-and-excitation".
        self.seblock = SEBlock(out_channels, reduction=8, bias=True) if se else None

    def forward(self, x):
        """Apply the conv pair, then the SE gate if one was configured."""
        features = self.conv(x)
        if self.seblock is None:
            return features
        return self.seblock(features)
class UNet1(nn.Module):
    """Shallow U-Net with a single down/up level; every convolution is unpadded.

    For an input of even spatial size H, the output size is ``2 * H - 32``
    when ``deconv`` is true (stride-2 transposed-conv head) and ``H - 16``
    otherwise.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by the final layer.
        deconv: choose a 4x4 stride-2 ``ConvTranspose2d`` head (True) or a
            3x3 ``Conv2d`` head (False).
    """

    def __init__(self, in_channels, out_channels, deconv):
        super(UNet1, self).__init__()
        # Construction order determines RNG consumption and state_dict order.
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.conv_bottom = (
            nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
            if deconv
            else nn.Conv2d(64, out_channels, 3, 1, 0)
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for conv weights; normal init for any nn.Linear."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                # Conv biases keep PyTorch's default initialization.
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        skip = self.conv1(x)
        deep = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        deep = F.leaky_relu(self.conv2_up(self.conv2(deep)), 0.1, inplace=True)
        # Crop the skip branch by 4 px per side so it matches the deep branch.
        skip = F.pad(skip, (-4, -4, -4, -4))
        fused = F.leaky_relu(self.conv3(skip + deep), 0.1, inplace=True)
        return self.conv_bottom(fused)
class UNet2(nn.Module):
    """U-Net with two down/up levels; every convolution is unpadded.

    With ``deconv`` false, the output is 40 px smaller than the input in each
    spatial dimension. The stride-2 path only lines up for inputs whose
    spatial size is a multiple of 4.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the final layer.
        deconv: choose a 4x4 stride-2 ``ConvTranspose2d`` head (True) or a
            3x3 ``Conv2d`` head (False).
    """

    def __init__(self, in_channels, out_channels, deconv):
        super(UNet2, self).__init__()
        # Construction order determines RNG consumption and state_dict order.
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
        self.conv_bottom = (
            nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
            if deconv
            else nn.Conv2d(64, out_channels, 3, 1, 0)
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for conv weights; normal init for any nn.Linear."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                # Conv biases keep PyTorch's default initialization.
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        level1 = self.conv1(x)
        level2 = self.conv2(F.leaky_relu(self.conv1_down(level1), 0.1, inplace=True))
        bottom = F.leaky_relu(self.conv2_down(level2), 0.1, inplace=True)
        bottom = F.leaky_relu(self.conv3_up(self.conv3(bottom)), 0.1, inplace=True)
        # Skip connections are cropped (negative padding) to the deeper branch's size.
        merged = self.conv4(F.pad(level2, (-4, -4, -4, -4)) + bottom)
        merged = F.leaky_relu(self.conv4_up(merged), 0.1, inplace=True)
        merged = self.conv5(F.pad(level1, (-16, -16, -16, -16)) + merged)
        merged = F.leaky_relu(merged, 0.1, inplace=True)
        return self.conv_bottom(merged)
def main():
    """Smoke-test UNet_Full: print its parameter count, output shape and forward latency.

    Expects to be run from the repository root. The root is appended to
    ``sys.path`` so the project-local ``opt`` module can be imported; per the
    original comment, that module manages which GPU is used. Falls back to
    CPU when CUDA is unavailable.
    """
    import time

    root_path = os.path.abspath('.')
    sys.path.append(root_path)
    from opt import opt  # noqa: F401  Manage GPU to choose (presumably used only for its import side effects)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet_Full().to(device).eval()
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(f"CuNet has param {pytorch_total_params//1000} K params")

    # Sample on the CPU RNG and then move the tensor, as the original .cuda() call did.
    x = torch.randn((1, 3, 180, 180)).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        start = time.time()
        output = model(x)
        # CUDA kernels launch asynchronously; without a sync the timer would
        # only measure kernel launch, not the actual computation.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        total = time.time() - start
    print("output size is ", output.shape)
    print(total)


if __name__ == "__main__":
    main()