Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import functools | |
import math | |
import re | |
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from . import block as B | |
# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py,
# which improved on the implementation that was previously here.
class RRDBNet(nn.Module):
    """ESRGAN generator (old-arch RRDB network) configured from a state dict.

    The layer layout (channel counts, number of RRDB blocks, upscale factor,
    optional pixel-unshuffle front end) is derived from the keys and tensor
    shapes of the supplied checkpoint. New-arch checkpoints are renamed to
    the old-arch ``model.N...`` key layout before the weights are loaded.
    """

    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: "B.ConvMode" = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
        and Chen Change Loy.

        This is old-arch Residual in Residual Dense Block Network and is not
        the newest revision that's available at github.com/xinntao/ESRGAN.
        This is on purpose, the newest Network has severely limited the
        potential use of the Network with no benefits.
        This network supports model files from both new and old-arch.

        Args:
            state_dict: Checkpoint mapping of parameter names to tensors. If it
                contains a ``"params_ema"`` entry, that nested mapping is used.
            norm: Normalization layer
            act: Activation layer
            upsampler: Upsample layer. upconv, pixel_shuffle
            mode: Convolution mode

        Raises:
            NotImplementedError: If ``upsampler`` is not a known upsample mode.
            ValueError: If no RRDB block keys can be found in ``state_dict``.
        """
        super().__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        # Maps old-arch key (or re.sub replacement template) -> candidate
        # new-arch keys (or regex patterns). "/NB/" is a placeholder that
        # new_to_old_arch() replaces with the detected number of blocks.
        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch models
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }

        if "params_ema" in self.state:
            self.state = self.state["params_ema"]
            # self.model_arch = "RealESRGAN"

        # The block count must be known before the key conversion because the
        # trunk conv's old-arch index is the number of blocks.
        self.num_blocks = self.get_num_blocks()
        self.plus = any("conv1x1" in key for key in self.state)
        if self.plus:
            self.model_arch = "ESRGAN+"

        self.state = self.new_to_old_arch(self.state)
        self.key_arr = list(self.state.keys())

        first_param = self.state[self.key_arr[0]]
        last_param = self.state[self.key_arr[-1]]
        self.in_nc: int = first_param.shape[1]
        self.out_nc: int = last_param.shape[0]
        self.scale: int = self.get_scale()
        self.num_filters: int = first_param.shape[0]

        # A first conv whose kernel has size 2 along dim -2 marks the
        # "2c2" variant, which gets its own scale calculation.
        c2x2 = False
        if self.state["model.0.weight"].shape[-2] == 2:
            c2x2 = True
            self.scale = round(math.sqrt(self.scale / 4))
            self.model_arch = "ESRGAN-2c2"

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        # Detect if pixelunshuffle was used (Real-ESRGAN): the first conv then
        # takes 4x or 16x as many channels as the last conv produces.
        if self.in_nc in (self.out_nc * 4, self.out_nc * 16):
            self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
        else:
            self.shuffle_factor = None

        upsample_block = {
            "upconv": B.upconv_block,
            "pixel_shuffle": B.pixelshuffle_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

        if self.scale == 3:
            # Single x3 block; it is unpacked with * below like the list case.
            upsample_blocks = upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                upscale_factor=3,
                act_type=self.act,
                c2x2=c2x2,
            )
        else:
            # One block per factor of two; log2 is exact for powers of two.
            upsample_blocks = [
                upsample_block(
                    in_nc=self.num_filters,
                    out_nc=self.num_filters,
                    act_type=self.act,
                    c2x2=c2x2,
                )
                for _ in range(int(math.log2(self.scale)))
            ]

        rrdb_blocks = [
            B.RRDB(
                nf=self.num_filters,
                kernel_size=3,
                gc=32,
                stride=1,
                bias=True,
                pad_type="zero",
                norm_type=self.norm,
                act_type=self.act,
                mode="CNA",
                plus=self.plus,
                c2x2=c2x2,
            )
            for _ in range(self.num_blocks)
        ]

        self.model = B.sequential(
            # fea conv
            B.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            B.ShortcutBlock(
                B.sequential(
                    # rrdb blocks
                    *rrdb_blocks,
                    # lr conv
                    B.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        # Adjust these properties for calculations outside of the model:
        # they now describe the image-space input, before pixel-unshuffle.
        if self.shuffle_factor:
            self.in_nc //= self.shuffle_factor**2
            self.scale //= self.shuffle_factor

        self.load_state_dict(self.state, strict=False)

    def new_to_old_arch(self, state):
        """Convert a new-arch model state dictionary to an old-arch dictionary.

        Args:
            state: Mapping of parameter names to values. A nested
                ``"params_ema"`` mapping is unwrapped first.

        Returns:
            ``state`` itself when it has no ``"conv_first.weight"`` key
            (treated as already old-arch); otherwise a new ``OrderedDict``
            with old-arch ``model.N...`` keys, ordered by ``N``.

        Note:
            On the first conversion the ``/NB/`` placeholder entries of
            ``self.state_map`` are replaced in place using
            ``self.num_blocks``. Later calls reuse those resolved entries.
        """
        if "params_ema" in state:
            state = state["params_ema"]

        if "conv_first.weight" not in state:
            # model is already old arch, this is a loose check, but should be sufficient
            return state

        # add nb to state keys; guarded so a repeated call does not KeyError
        for kind in ("weight", "bias"):
            placeholder = f"model.1.sub./NB/.{kind}"
            if placeholder in self.state_map:
                resolved = f"model.1.sub.{self.num_blocks}.{kind}"
                self.state_map[resolved] = self.state_map.pop(placeholder)

        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            is_template = r"\1" in old_key
            for new_key in new_keys:
                if is_template:
                    # new_key is a regex and old_key its replacement template
                    for name, value in state.items():
                        renamed = re.sub(new_key, old_key, name)
                        if renamed != name:
                            old_state[renamed] = value
                elif new_key in state:
                    old_state[old_key] = state[new_key]

        # upconv layers: the n-th upconv is stored at old-arch index n * 3
        max_upconv = 0
        for name, value in state.items():
            match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", name)
            if match is None:
                continue
            index = int(match.group(2)) * 3
            old_state[f"model.{index}.{match.group(3)}"] = value
            max_upconv = max(max_upconv, index)

        # final layers, placed relative to the last upconv index
        final_layers = {
            "HRconv.weight": f"model.{max_upconv + 2}.weight",
            "conv_hr.weight": f"model.{max_upconv + 2}.weight",
            "HRconv.bias": f"model.{max_upconv + 2}.bias",
            "conv_hr.bias": f"model.{max_upconv + 2}.bias",
            "conv_last.weight": f"model.{max_upconv + 4}.weight",
            "conv_last.bias": f"model.{max_upconv + 4}.bias",
        }
        for name, value in state.items():
            if name in final_layers:
                old_state[final_layers[name]] = value

        # Sort by first numeric value of each layer. sorted() is stable, so
        # keys sharing an index keep their insertion order.
        sorted_keys = sorted(old_state, key=lambda name: int(name.split(".")[1]))

        # Rebuild the output dict in the right order
        return OrderedDict((name, old_state[name]) for name in sorted_keys)

    def get_scale(self, min_part: int = 6) -> int:
        """Derive the upscale factor from the old-arch keys in ``self.state``.

        Counts keys of the exact form ``<prefix>.<N>.weight`` with
        ``N > min_part`` and returns two to the power of that count.

        Args:
            min_part: Layer index threshold; only layers with a larger index
                are counted.

        Returns:
            ``2 ** n`` where ``n`` is the number of counted keys.
        """
        n = 0
        for key in self.state:
            parts = key.split(".")[1:]
            if len(parts) != 2:
                continue
            part_num = int(parts[0])
            if part_num > min_part and parts[1] == "weight":
                n += 1
        return 2**n

    def get_num_blocks(self) -> int:
        """Return the number of RRDB blocks described by ``self.state``.

        The new-arch key patterns from ``self.state_map`` are tried first,
        then the old-arch pattern. The first pattern that matches any key
        decides the result, which is the highest block index plus one.

        Raises:
            ValueError: If no key matches any known RRDB pattern.
        """
        patterns = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for pattern in patterns:
            indices = []
            for key in self.state:
                match = re.search(pattern, key)
                if match:
                    indices.append(int(match.group(1)))
            if indices:
                # max(indices), not max(*indices): the latter fails when
                # exactly one key matched.
                return max(indices) + 1
        raise ValueError("Could not find any RRDB block keys in the state dict")

    def forward(self, x):
        """Run the network on ``x``.

        When a pixel-unshuffle front end was detected, the input is first
        reflect-padded on the right/bottom so its height and width are
        multiples of ``self.shuffle_factor``, then unshuffled, and the output
        is cropped to ``self.scale`` times the original height and width.
        """
        if not self.shuffle_factor:
            return self.model(x)

        factor = self.shuffle_factor
        _, _, h, w = x.size()
        mod_pad_h = (factor - h % factor) % factor
        mod_pad_w = (factor - w % factor) % factor
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
        x = torch.pixel_unshuffle(x, downscale_factor=factor)
        x = self.model(x)
        # Drop the rows/columns that only exist because of the padding.
        return x[:, :, : h * self.scale, : w * self.scale]