SAFMN / app.py
Meloo's picture
Create app.py
46255b1 verified
raw
history blame
15.8 kB
import sys
sys.path.append('SAFMN')
import os
import cv2
import argparse
import glob
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
########################################## Wavelet colorfix ###################################
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToTensor, ToPILImage
def adain_color_fix(target: Image, source: Image):
    """Colour-correct ``target`` so its channel statistics follow ``source``.

    Both images are turned into batched tensors with ``ToTensor`` and
    ``unsqueeze(0)``, passed through ``adaptive_instance_normalization``
    (``target`` as content, ``source`` as style), clamped to ``[0, 1]`` and
    converted back into a PIL image.

    Args:
        target: Image whose structure is kept.
        source: Image providing the reference colour statistics.

    Returns:
        The colour-corrected PIL image.
    """
    pil_to_tensor = ToTensor()
    tensor_to_pil = ToPILImage()

    # Add a leading batch dimension for the normalisation helper.
    content = pil_to_tensor(target).unsqueeze(0)
    style = pil_to_tensor(source).unsqueeze(0)

    corrected = adaptive_instance_normalization(content, style)

    # Drop the batch dimension and clip to the valid image range.
    corrected = corrected.squeeze(0).clamp_(0.0, 1.0)
    return tensor_to_pil(corrected)
def wavelet_color_fix(target: Image, source: Image):
    """Colour-correct ``target`` using the low-frequency content of ``source``.

    If the two images differ in size, ``source`` is first resized to the size
    of ``target`` with Lanczos resampling. Both images are then converted to
    batched tensors, combined with ``wavelet_reconstruction`` (``target`` as
    content, ``source`` as style), clamped to ``[0, 1]`` and converted back
    into a PIL image.

    Args:
        target: PIL image whose detail is kept.
        source: PIL image providing the reference colours.

    Returns:
        The colour-corrected PIL image.
    """
    # ``size`` on a PIL image is a ``(width, height)`` tuple attribute, not a
    # method; calling it raises ``TypeError``. ``resize`` expects that same
    # ``(width, height)`` ordering, so the tuple can be passed through as is.
    if target.size != source.size:
        source = source.resize(target.size, Image.LANCZOS)

    # Convert images to batched tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    return to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation.

    The statistics are taken over all positions of each ``(batch, channel)``
    slice and reshaped to ``(b, c, 1, 1)``.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): Small value added to the variance before the square
            root to avoid divide-by-zero downstream. Default: 1e-5.

    Returns:
        tuple[Tensor, Tensor]: ``(mean, std)``, each of shape ``(b, c, 1, 1)``.
    """
    shape = feat.size()
    assert len(shape) == 4, 'The input feature should be 4D tensor.'
    batch, channels = shape[0], shape[1]

    # Collapse the two trailing dimensions so statistics are one reduction.
    flat = feat.view(batch, channels, -1)

    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
    """Adaptive instance normalization (AdaIN).

    Normalises ``content_feat`` with its own per-channel mean/std (as
    computed by ``calc_mean_std``) and then rescales and shifts it with the
    per-channel std/mean of ``style_feat``.

    Args:
        content_feat (Tensor): Feature to be re-coloured.
        style_feat (Tensor): Feature supplying the target statistics.

    Returns:
        Tensor: Re-normalised feature with the size of ``content_feat``.
    """
    target_shape = content_feat.size()

    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    # Remove the content statistics ...
    centred = content_feat - content_mean.expand(target_shape)
    whitened = centred / content_std.expand(target_shape)

    # ... and impose the style statistics instead.
    return whitened * style_std.expand(target_shape) + style_mean.expand(target_shape)
def wavelet_blur(image: Tensor, radius: int):
    """Blur ``image`` with a dilated 3x3 binomial kernel.

    The same 3x3 kernel is applied independently to every channel (grouped
    convolution) with dilation ``radius``. The input is replicate-padded by
    ``radius`` on each side, so the output has the same spatial size as the
    input.

    Args:
        image (Tensor): Batched input of shape ``(N, C, H, W)``. The channel
            count is read from the tensor, so it is not limited to 3.
        radius (int): Dilation of the kernel and amount of padding.

    Returns:
        Tensor: Blurred tensor with the same shape as ``image``.
    """
    channels = image.shape[1]

    # 3x3 smoothing kernel; the weights sum to 1.
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)

    # Shape (C, 1, 3, 3): one copy of the kernel per input channel, as
    # required by a depthwise (groups == channels) convolution.
    kernel = kernel[None, None].repeat(channels, 1, 1, 1)

    # Pad by the dilation so the dilated 3x3 window keeps the spatial size.
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(image, kernel, groups=channels, dilation=radius)
def wavelet_decomposition(image: Tensor, levels=5):
    """Split ``image`` into accumulated high-frequency and low-frequency parts.

    At level ``i`` the current low-frequency image is blurred with
    ``wavelet_blur`` using radius ``2 ** i``; the difference between the
    image before and after the blur is added to the high-frequency
    accumulator and the blurred result becomes the input of the next level.

    Args:
        image (Tensor): Input tensor accepted by ``wavelet_blur``.
        levels (int): Number of blur levels. With ``levels <= 0`` nothing is
            blurred and the result is ``(zeros_like(image), image)``.

    Returns:
        tuple[Tensor, Tensor]: ``(high_freq, low_freq)``; their sum equals
        the input up to floating point rounding.
    """
    high_freq = torch.zeros_like(image)
    # Start from the input so the function is well defined when the loop
    # body never runs (previously ``low_freq`` was unbound for levels=0).
    low_freq = image
    for level in range(levels):
        blurred = wavelet_blur(low_freq, 2 ** level)
        high_freq += low_freq - blurred
        low_freq = blurred
    return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
    """Combine the detail of ``content_feat`` with the colour of ``style_feat``.

    Both inputs are run through ``wavelet_decomposition``. The result is the
    high-frequency part of the content plus the low-frequency part of the
    style; the other two components are discarded.

    Args:
        content_feat (Tensor): Tensor providing the high-frequency detail.
        style_feat (Tensor): Tensor providing the low-frequency base.

    Returns:
        Tensor: ``content_high_freq + style_low_freq``.
    """
    # Index the returned tuple directly so the unused half of each
    # decomposition is released immediately.
    content_detail = wavelet_decomposition(content_feat)[0]
    style_base = wavelet_decomposition(style_feat)[1]
    return content_detail + style_base
########################################## URL Load ###################################
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path of the file at ``url``, downloading it if needed.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory in which the file is stored; created when
            missing. If None, ``<torch hub dir>/checkpoints`` is used.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): Name of the local file. If None, the last path
            component of the URL is used. Default: None.

    Returns:
        str: Absolute path of the (possibly freshly downloaded) file.
    """
    if model_dir is None:
        # Fall back to the PyTorch hub cache.
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    url_name = os.path.basename(urlparse(url).path)
    local_name = url_name if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(model_dir, local_name))

    # A file already on disk is reused without touching the network.
    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
########################################## Model Define ###################################
# Layer Norm
class LayerNorm(nn.Module):
    """Layer normalisation with a learnable per-channel scale and shift.

    ``data_format="channels_last"`` delegates to ``F.layer_norm`` over the
    trailing dimension. ``data_format="channels_first"`` normalises over
    dimension 1 and broadcasts the affine parameters over two trailing
    dimensions.

    Args:
        normalized_shape (int): Number of channels being normalised.
        eps (float): Value added to the variance for numerical stability.
        data_format (str): ``"channels_first"`` or ``"channels_last"``.

    Raises:
        NotImplementedError: If ``data_format`` is neither supported value.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # Statistics over the channel axis (dim 1), biased variance.
            mean = x.mean(1, keepdim=True)
            variance = (x - mean).pow(2).mean(1, keepdim=True)
            normalised = (x - mean) / torch.sqrt(variance + self.eps)
            return self.weight[:, None, None] * normalised + self.bias[:, None, None]
# CCM
class CCM(nn.Module):
    """Two-convolution channel mixer: 3x3 conv -> GELU -> 1x1 conv.

    The first convolution expands ``dim`` channels to
    ``int(dim * growth_rate)`` and the second projects back to ``dim``, so
    the input and output shapes match.

    Args:
        dim (int): Number of input and output channels.
        growth_rate (float): Expansion factor of the hidden width.
    """

    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        expand = nn.Conv2d(dim, hidden_dim, kernel_size=3, stride=1, padding=1)
        project = nn.Conv2d(hidden_dim, dim, kernel_size=1, stride=1, padding=0)
        self.ccm = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        return self.ccm(x)
# SAFM
class SAFM(nn.Module):
    """Multi-scale feature modulation block.

    The input is split into ``n_levels`` channel groups. Group 0 goes through
    a depthwise 3x3 convolution at full resolution; group ``i > 0`` is
    adaptively max-pooled to ``(h // 2**i, w // 2**i)``, convolved, and
    upsampled back with nearest-neighbour interpolation. The groups are
    concatenated, mixed by a 1x1 convolution, passed through GELU and used to
    modulate (multiply) the input.

    Args:
        dim (int): Number of input channels.
        n_levels (int): Number of channel groups / scales.
    """

    def __init__(self, dim, n_levels=4):
        super().__init__()
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial weighting: one depthwise conv per scale
        self.mfr = nn.ModuleList([
            nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim)
            for _ in range(self.n_levels)
        ])

        # Feature aggregation
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation
        self.act = nn.GELU()

    def forward(self, x):
        height, width = x.size()[-2:]
        groups = x.chunk(self.n_levels, dim=1)

        branches = []
        for level, conv in enumerate(self.mfr):
            group = groups[level]
            if level == 0:
                # Full-resolution branch.
                feat = conv(group)
            else:
                # Process at a coarser scale, then restore the resolution.
                pooled_size = (height // 2**level, width // 2**level)
                pooled = F.adaptive_max_pool2d(group, pooled_size)
                feat = F.interpolate(conv(pooled), size=(height, width), mode='nearest')
            branches.append(feat)

        fused = self.aggr(torch.cat(branches, dim=1))
        return self.act(fused) * x
class AttBlock(nn.Module):
    """Residual block: normalised SAFM followed by a normalised CCM.

    Args:
        dim (int): Number of feature channels.
        ffn_scale (float): Hidden-width expansion factor passed to ``CCM``.
    """

    def __init__(self, dim, ffn_scale=2.0):
        super().__init__()
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)

        # Multiscale block
        self.safm = SAFM(dim)
        # Feedforward layer
        self.ccm = CCM(dim, ffn_scale)

    def forward(self, x):
        # Each sub-layer is wrapped in its own skip connection.
        modulated = self.safm(self.norm1(x)) + x
        return self.ccm(self.norm2(modulated)) + modulated
class SAFMN(nn.Module):
    """Super-resolution network: conv stem, ``AttBlock`` stack, pixel shuffle.

    Args:
        dim (int): Number of feature channels.
        n_blocks (int): Number of ``AttBlock`` layers.
        ffn_scale (float): Expansion factor forwarded to each ``AttBlock``.
        upscaling_factor (int): Factor passed to ``nn.PixelShuffle``.
    """

    def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
        super().__init__()
        # 3-channel image -> feature space
        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)

        blocks = [AttBlock(dim, ffn_scale) for _ in range(n_blocks)]
        self.feats = nn.Sequential(*blocks)

        # Produce 3 * r^2 channels, rearranged by PixelShuffle into an image
        # that is r times larger in each spatial dimension.
        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
            nn.PixelShuffle(upscaling_factor)
        )

    def forward(self, x):
        shallow = self.to_feat(x)
        # Global skip connection around the whole block stack.
        deep = self.feats(shallow) + shallow
        return self.to_img(deep)
########################################## Gradio inference ###################################
# Release URLs of the pretrained checkpoints, keyed by model variant.
pretrain_model_url = {
    'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
    'safmn_x4': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x4-v2.pth',
}

# download weights: fetch each checkpoint that is not on disk yet (runs at import time)
for _weights_url in pretrain_model_url.values():
    _weights_path = os.path.join('./experiments/pretrained_models/', os.path.basename(_weights_url))
    if not os.path.exists(_weights_path):
        load_file_from_url(
            url=_weights_url,
            model_dir='./experiments/pretrained_models/',
            progress=True,
            file_name=None,
        )
del _weights_url, _weights_path

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def set_safmn(upscale):
    """Build the SAFMN model for ``upscale`` and load its pretrained weights.

    Args:
        upscale (int): Upscaling factor; only 2 and 4 are supported.

    Returns:
        The model in eval mode, moved to the module-level ``device``.

    Raises:
        NotImplementedError: If ``upscale`` is neither 2 nor 4.
    """
    # These must match the files fetched at module import time; the x2
    # checkpoint downloaded there is the "-v2" file.
    weight_files = {
        2: './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth',
        4: './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth',
    }

    # Validate before building the network so an unsupported scale fails fast.
    if upscale not in weight_files:
        raise NotImplementedError('Only support x2/x4 upscaling!')

    model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=upscale)

    # Load onto the CPU first: a checkpoint saved from a GPU would otherwise
    # fail to deserialize on a CPU-only host. The model is moved to
    # ``device`` afterwards.
    checkpoint = torch.load(weight_files[upscale], map_location='cpu')
    model.load_state_dict(checkpoint['params'], strict=True)
    model.eval()
    return model.to(device)
def _tile_offsets(length, tile, step):
    """Return start offsets of ``tile``-sized windows covering ``[0, length)``.

    Offsets advance by ``step``; the final window is shifted back so that it
    ends exactly at ``length``.
    """
    offsets = []
    pos = 0
    while pos < length:
        if pos + tile >= length:
            # Last window: align it with the end of the axis and stop.
            offsets.append(length - tile)
            break
        offsets.append(pos)
        pos += step
    return offsets


def img2patch(lq, scale=4, crop_size=512):
    """Cut a low-resolution image into (possibly overlapping) patches.

    Tile positions are computed in super-resolved (SR) coordinates, i.e. on a
    grid of size ``(H * scale, W * scale)``, and converted back to
    low-resolution coordinates when slicing ``lq``.

    Args:
        lq (Tensor): Low-resolution input of shape ``(1, C, H, W)``.
        scale (int): Upscaling factor relating ``lq`` to the SR image.
        crop_size (int): Desired tile size in SR pixels; it is rounded down
            to a multiple of ``scale`` and clamped to the SR image size.

    Returns:
        tuple: ``(patches, idxes, sr_size)`` where ``patches`` stacks the
        tiles along dim 0, ``idxes`` is a list of ``{'i': row, 'j': col}``
        SR-space tile origins in row-major order, and ``sr_size`` is
        ``(1, C, H * scale, W * scale)``.
    """
    import math

    b, c, hl, wl = lq.size()
    h, w = hl * scale, wl * scale
    sr_size = (b, c, h, w)
    assert b == 1

    crop = crop_size // scale * scale
    # Clamp the tile to the image: without this an image smaller than the
    # tile yields a negative origin, a truncated patch and SR pixels that are
    # never written (0/0 -> NaN after averaging).
    crop_size_h, crop_size_w = min(crop, h), min(crop, w)

    # Adaptive strides: spread the tiles evenly over each axis.
    num_row = (h - 1) // crop_size_h + 1
    num_col = (w - 1) // crop_size_w + 1
    step_j = crop_size_w if num_col == 1 else math.ceil((w - crop_size_w) / (num_col - 1) - 1e-8)
    step_i = crop_size_h if num_row == 1 else math.ceil((h - crop_size_h) / (num_row - 1) - 1e-8)

    # Keep strides on the LR pixel grid.
    step_i = step_i // scale * scale
    step_j = step_j // scale * scale

    rows = _tile_offsets(h, crop_size_h, step_i)
    cols = _tile_offsets(w, crop_size_w, step_j)

    parts = []
    idxes = []
    for i in rows:
        for j in cols:
            parts.append(lq[:, :, i // scale:(i + crop_size_h) // scale, j // scale:(j + crop_size_w) // scale])
            idxes.append({'i': i, 'j': j})
    return torch.cat(parts, dim=0), idxes, sr_size
def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
    """Stitch super-resolved patches back into one image.

    Every patch is added into an accumulator at its ``{'i', 'j'}`` origin and
    a per-pixel hit counter is incremented; the result is the accumulator
    divided by the counter, so overlapping regions are averaged.

    Args:
        outs (Tensor): Stacked patches, indexed along dim 0.
        idxes (list[dict]): Tile origins, one ``{'i': row, 'j': col}`` per patch.
        sr_size (tuple): ``(b, c, h, w)`` of the stitched image.
        scale (int): Upscaling factor used to round the tile size.
        crop_size (int): Tile size before rounding to a multiple of ``scale``.

    Returns:
        Tensor: Stitched image of shape ``sr_size`` on the device of ``outs``.
    """
    b, c, h, w = sr_size
    tile_h = tile_w = crop_size // scale * scale

    accum = torch.zeros(sr_size).to(outs.device)
    hits = torch.zeros((b, 1, h, w)).to(outs.device)

    for n, origin in enumerate(idxes):
        top, left = origin['i'], origin['j']
        accum[0, :, top:top + tile_h, left:left + tile_w] += outs[n]
        hits[0, 0, top:top + tile_h, left:left + tile_w] += 1.

    # Average where tiles overlap.
    return (accum / hits).to(outs.device)
# Make sure the directory that inference() writes its output image into
# exists; this runs once at import time and is a no-op if it already exists.
os.makedirs('./results', exist_ok=True)
def inference(image, upscale, large_input_flag, color_fix):
    """Super-resolve an image file and save the result.

    Args:
        image: Path of the input image (anything ``str()`` turns into a path).
        upscale: Requested upscaling factor. It is cast to ``int``; values
            above 4 become 4 and values 1-2 become 2. Any other value is
            passed to ``set_safmn`` unchanged, which rejects it.
        large_input_flag (bool): If true, the image is processed patch by
            patch (``img2patch`` / ``patch2img``) to limit peak memory.
        color_fix (bool): If true, the output colours are corrected against a
            bilinearly upsampled copy of the input via
            ``wavelet_reconstruction``.

    Returns:
        tuple: ``(rgb_image, save_path)`` where ``rgb_image`` is a uint8
        ``(H, W, 3)`` RGB array and ``save_path`` is the PNG written to the
        ``results`` directory.

    Raises:
        ValueError: If the input image cannot be read.
        OSError: If the output image cannot be written.
        NotImplementedError: From ``set_safmn`` for unsupported factors.
    """
    import uuid

    upscale = int(upscale)  # convert type to int
    if upscale > 4:
        upscale = 4
    if 0 < upscale < 3:
        upscale = 2

    model = set_safmn(upscale)

    img = cv2.imread(str(image), cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise ValueError(f'Could not read input image: {image}')
    print(f'input size: {img.shape}')

    # img2tensor: BGR uint8 (H, W, 3) -> RGB float (1, 3, H, W) in [0, 1]
    img = img.astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img = img.unsqueeze(0).to(device)

    # inference
    if large_input_flag:
        patches, idx, size = img2patch(img, scale=upscale)
        outs = []
        with torch.no_grad():
            # One patch per forward pass keeps peak memory low.
            for k in range(len(patches)):
                pred = model(patches[k:k + 1])
                if isinstance(pred, list):
                    pred = pred[-1]
                outs.append(pred.detach())
        output = patch2img(torch.cat(outs, dim=0), idx, size, scale=upscale)
    else:
        with torch.no_grad():
            output = model(img)

    # color fix
    if color_fix:
        img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
        output = wavelet_reconstruction(output, img)

    # tensor2img: drop only the batch dimension so that 1-pixel-wide or
    # 1-pixel-high outputs keep their (3, H, W) layout.
    output = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # RGB CHW -> BGR HWC
    output = (output * 255.0).round().astype(np.uint8)

    # save restored img under a unique name: the demo queue can run several
    # requests at once, and a shared file name would let them overwrite each
    # other's download.
    save_path = os.path.join('results', f'out_{uuid.uuid4().hex}.png')
    if not cv2.imwrite(save_path, output):
        raise OSError(f'Could not write output image: {save_path}')

    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    return output, save_path
# Heading shown at the top of the demo page.
title = "Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution"

# HTML blurb rendered under the title.
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/sunny2109/SAFMN' target='_blank'><b>Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution (ICCV 2023)</b></a>.<br>
"""

# Markdown/HTML footer with repository link and citation.
# The memo emoji before "Citation" had been mis-decoded into "πŸ“"; restored.
article = r"""
If SAFMN is helpful, please help to ⭐ the <a href='https://github.com/sunny2109/SAFMN' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/sunny2109/SAFMN?style=social)](https://github.com/sunny2109/SAFMN)
---
📝 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{sun2023safmn,
title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
year={2023}
}
```
<center><img src='https://visitor-badge.laobi.icu/badge?page_id=sunny2109/SAFMN' alt='visitors'></center>
"""
# Build the web UI: four inputs feeding inference(), two outputs.
# NOTE(review): the gr.inputs / gr.outputs namespaces and
# queue(concurrency_count=...) belong to older Gradio releases - confirm the
# pinned Gradio version before upgrading.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(type="filepath", label="Input"),
        gr.inputs.Number(default=2, label="Upscaling factor (up to 4)"),
        gr.inputs.Checkbox(default=False, label="Memory-efficient inference"),
        gr.inputs.Checkbox(default=False, label="Color correction"),
    ],
    outputs=[
        gr.outputs.Image(type="numpy", label="Output"),
        gr.outputs.File(label="Download the output"),
    ],
    title=title,
    description=description,
    article=article,
)

# Allow two requests to be processed concurrently, then start the server.
demo.queue(concurrency_count=2)
demo.launch()