import sys

# Make the local SAFMN repository importable.
sys.path.append('SAFMN')

import argparse
import glob
import math
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr

from PIL import Image
from torch import Tensor
from torchvision.transforms import ToTensor, ToPILImage
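
# Color-fix helpers: transfer the color statistics of a reference image onto a
# restored image, either with AdaIN or with a wavelet-based low-frequency swap.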
def adain_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply adaptive instance normalization
    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image
def wavelet_color_fix(target: Image, source: Image):
    # PIL's Image.size is a (width, height) property, not a method.
    if target.size != source.size:
        source = source.resize(target.size, Image.LANCZOS)

    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image
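
# Note: inference() below calls the tensor-level wavelet_reconstruction()
# directly; the two PIL-level wrappers above are not used by the demo itself.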
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be a 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std
def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.

    Adjust the reference features to have color and illumination similar to
    those of the degraded features.

    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
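
# In one line: out = std(style) * (content - mean(content)) / std(content) + mean(style),
# with the statistics computed per sample and channel over the spatial dimensions.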
def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # 3x3 binomial (Gaussian-like) convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add batch and channel dimensions: (1, 1, 3, 3)
    kernel = kernel[None, None]
    # repeat the kernel for each of the three image channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # blur via a dilated depthwise convolution at the given radius
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output
def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition so that the content feature takes on the same
    color (low-frequency component) as the style feature.
    """
    # wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # combine the content's high frequencies with the style's low frequencies
    return content_high_freq + style_low_freq
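
# Net effect: fine detail comes from the content image, while the overall
# color/illumination comes from the style image.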
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load a file from an http url, downloading the model if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
class LayerNorm(nn.Module):
    """LayerNorm supporting both channels_last (N, H, W, C) and
    channels_first (N, C, H, W) inputs."""

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
class CCM(nn.Module):
    """Convolutional channel mixer (CCM): a conv feed-forward block with channel expansion."""

    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        self.ccm = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0)
        )

    def forward(self, x):
        return self.ccm(x)
class SAFM(nn.Module):
    """Spatially-adaptive feature modulation (SAFM) layer."""

    def __init__(self, dim, n_levels=4):
        super().__init__()
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Multiscale feature representation: one depthwise conv per level
        self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])

        # Feature aggregation
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation
        self.act = nn.GELU()

    def forward(self, x):
        h, w = x.size()[-2:]

        xc = x.chunk(self.n_levels, dim=1)
        out = []
        for i in range(self.n_levels):
            if i > 0:
                # process this chunk at a reduced resolution, then upsample back
                p_size = (h // 2**i, w // 2**i)
                s = F.adaptive_max_pool2d(xc[i], p_size)
                s = self.mfr[i](s)
                s = F.interpolate(s, size=(h, w), mode='nearest')
            else:
                s = self.mfr[i](xc[i])
            out.append(s)

        # aggregate the levels and use them to modulate the input features
        out = self.aggr(torch.cat(out, dim=1))
        out = self.act(out) * x
        return out
class AttBlock(nn.Module):
    def __init__(self, dim, ffn_scale=2.0):
        super().__init__()

        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)

        # Feature modulation followed by a channel-mixing feed-forward block,
        # each with a pre-norm residual connection.
        self.safm = SAFM(dim)
        self.ccm = CCM(dim, ffn_scale)

    def forward(self, x):
        x = self.safm(self.norm1(x)) + x
        x = self.ccm(self.norm2(x)) + x
        return x
class SAFMN(nn.Module):
    def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
        super().__init__()
        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)

        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])

        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
            nn.PixelShuffle(upscaling_factor)
        )

    def forward(self, x):
        x = self.to_feat(x)
        x = self.feats(x) + x
        x = self.to_img(x)
        return x
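
# Shape sanity check (illustrative): with upscaling_factor=4, a (1, 3, 64, 64)
# input becomes (1, dim, 64, 64) after to_feat and (1, 3, 256, 256) after the
# final PixelShuffle.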
pretrain_model_url = {
    'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
    'safmn_x4': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x4-v2.pth',
}

# Download the pretrained weights on first run.
if not os.path.exists('./experiments/pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x2'], model_dir='./experiments/pretrained_models/', progress=True, file_name=None)

if not os.path.exists('./experiments/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x4'], model_dir='./experiments/pretrained_models/', progress=True, file_name=None)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_safmn(upscale):
    model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=upscale)
    if upscale == 2:
        # use the same x2 checkpoint that is downloaded above
        model_path = './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'
    elif upscale == 4:
        model_path = './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'
    else:
        raise NotImplementedError('Only x2/x4 upscaling is supported!')

    # map_location keeps loading working on CPU-only machines
    model.load_state_dict(torch.load(model_path, map_location=device)['params'], strict=True)
    model.eval()
    return model.to(device)
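
# For large inputs, img2patch splits the image into overlapping crops (with
# coordinates computed in the upscaled output space), each crop is restored
# independently, and patch2img averages the overlapping predictions back into
# a full-resolution result.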
def img2patch(lq, scale=4, crop_size=512):
    b, c, hl, wl = lq.size()
    h, w = hl * scale, wl * scale
    sr_size = (b, c, h, w)
    assert b == 1

    # crop size in the SR (output) space, rounded to a multiple of the scale
    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    num_row = (h - 1) // crop_size_h + 1
    num_col = (w - 1) // crop_size_w + 1

    # step sizes chosen so the crops cover the image with minimal overlap
    step_j = crop_size_w if num_col == 1 else math.ceil((w - crop_size_w) / (num_col - 1) - 1e-8)
    step_i = crop_size_h if num_row == 1 else math.ceil((h - crop_size_h) / (num_row - 1) - 1e-8)

    step_i = step_i // scale * scale
    step_j = step_j // scale * scale

    parts = []
    idxes = []

    i = 0
    last_i = False
    while i < h and not last_i:
        j = 0
        if i + crop_size_h >= h:
            i = h - crop_size_h
            last_i = True

        last_j = False
        while j < w and not last_j:
            if j + crop_size_w >= w:
                j = w - crop_size_w
                last_j = True
            # crop from the low-resolution input using SR-space coordinates
            parts.append(lq[:, :, i // scale:(i + crop_size_h) // scale, j // scale:(j + crop_size_w) // scale])
            idxes.append({'i': i, 'j': j})
            j = j + step_j
        i = i + step_i

    return torch.cat(parts, dim=0), idxes, sr_size
def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
    preds = torch.zeros(sr_size).to(outs.device)
    b, c, h, w = sr_size

    count_mt = torch.zeros((b, 1, h, w)).to(outs.device)
    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    # accumulate the patch predictions and average where they overlap
    for cnt, each_idx in enumerate(idxes):
        i = each_idx['i']
        j = each_idx['j']
        preds[0, :, i: i + crop_size_h, j: j + crop_size_w] += outs[cnt]
        count_mt[0, 0, i: i + crop_size_h, j: j + crop_size_w] += 1.

    return (preds / count_mt).to(outs.device)
os.makedirs('./results', exist_ok=True)
def inference(image, upscale, large_input_flag, color_fix):
    # clamp the requested factor to the supported x2/x4 models
    upscale = int(upscale)
    if upscale > 4:
        upscale = 4
    if 0 < upscale < 3:
        upscale = 2

    model = set_safmn(upscale)

    img = cv2.imread(str(image), cv2.IMREAD_COLOR)
    print(f'input size: {img.shape}')

    # BGR uint8 -> RGB float tensor in [0, 1] with shape (1, 3, H, W)
    img = img.astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img = img.unsqueeze(0).to(device)

    if large_input_flag:
        # memory-efficient inference: restore the image patch by patch
        patches, idx, size = img2patch(img, scale=upscale)
        with torch.no_grad():
            n = len(patches)
            outs = []
            m = 1
            i = 0
            while i < n:
                j = min(i + m, n)
                pred = model(patches[i:j])
                if isinstance(pred, list):
                    pred = pred[-1]
                outs.append(pred.detach())
                i = j
            output = torch.cat(outs, dim=0)

        output = patch2img(output, idx, size, scale=upscale)
    else:
        with torch.no_grad():
            output = model(img)

    # optional color correction: keep the low-frequency color of the
    # bilinearly upscaled input
    if color_fix:
        img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
        output = wavelet_reconstruction(output, img)

    # tensor -> BGR uint8 image for saving with OpenCV
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    save_path = 'results/out.png'
    cv2.imwrite(save_path, output)

    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    return output, save_path
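
# The Gradio UI below maps its widgets one-to-one onto inference()'s
# (image, upscale, large_input_flag, color_fix) arguments. Note that the
# gr.inputs / gr.outputs namespaces and queue(concurrency_count=...) follow
# the older (pre-4.x) Gradio API.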
title = "Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution"

description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/sunny2109/SAFMN' target='_blank'><b>Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution (ICCV 2023)</b></a>.<br>
"""

article = r"""
If SAFMN is helpful, please help to ⭐ the <a href='https://github.com/sunny2109/SAFMN' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/sunny2109/SAFMN?style=social)](https://github.com/sunny2109/SAFMN)

---
**Citation**

If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{sun2023safmn,
    title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
    author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
    year={2023}
}
```

<center><img src='https://visitor-badge.laobi.icu/badge?page_id=sunny2109/SAFMN' alt='visitors'></center>
"""
demo = gr.Interface(
    inference, [
        gr.inputs.Image(type="filepath", label="Input"),
        gr.inputs.Number(default=2, label="Upscaling factor (up to 4)"),
        gr.inputs.Checkbox(default=False, label="Memory-efficient inference"),
        gr.inputs.Checkbox(default=False, label="Color correction"),
    ], [
        gr.outputs.Image(type="numpy", label="Output"),
        gr.outputs.File(label="Download the output")
    ],
    title=title,
    description=description,
    article=article,
)

demo.queue(concurrency_count=2)
demo.launch()