Create app.py
app.py
ADDED
@@ -0,0 +1,468 @@
import sys
sys.path.append('SAFMN')

import os
import cv2
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr

########################################## Wavelet colorfix ###################################
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToTensor, ToPILImage

def adain_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply adaptive instance normalization
    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

def wavelet_color_fix(target: Image, source: Image):
    # PIL images expose `size` as a (width, height) tuple attribute, not a method
    if target.size != source.size:
        source = source.resize(target.size, Image.LANCZOS)
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

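# Note: adain_color_fix / wavelet_color_fix are PIL-level convenience wrappers around the
# tensor routines below; the demo's inference() applies wavelet_reconstruction() to tensors
# directly. Illustrative usage only (not called anywhere in this app):
#   fixed = wavelet_color_fix(Image.open('sr.png'), Image.open('reference.png'))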
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.
    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.
    Adjust the reference features to have similar color and illumination
    to those of the degraded features.
    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded feature.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output

def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency and the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq

def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content feature with the style's low frequency
    return content_high_freq + style_low_freq


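# How the color fix is used in this demo: the super-resolved output keeps its own
# high-frequency detail, while its low-frequency band (overall color / illumination)
# is replaced by that of the bilinearly upscaled input, which suppresses color shifts
# introduced by the network (see inference() below).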
########################################## URL Load ###################################
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse

def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load a file from an http url, downloading it if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


########################################## Model Define ###################################
# Layer Norm
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

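# For "channels_first" (NCHW) feature maps, the normalization above is computed manually
# over the channel dimension, which avoids permuting to NHWC just to call F.layer_norm.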
# CCM
class CCM(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        self.ccm = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0)
        )

    def forward(self, x):
        return self.ccm(x)


# SAFM
class SAFM(nn.Module):
    def __init__(self, dim, n_levels=4):
        super().__init__()
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial Weighting
        self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])

        # Feature Aggregation
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation
        self.act = nn.GELU()

    def forward(self, x):
        h, w = x.size()[-2:]

        xc = x.chunk(self.n_levels, dim=1)
        out = []
        for i in range(self.n_levels):
            if i > 0:
                p_size = (h//2**i, w//2**i)
                s = F.adaptive_max_pool2d(xc[i], p_size)
                s = self.mfr[i](s)
                s = F.interpolate(s, size=(h, w), mode='nearest')
            else:
                s = self.mfr[i](xc[i])
            out.append(s)

        out = self.aggr(torch.cat(out, dim=1))
        out = self.act(out) * x
        return out

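# SAFM splits the channels into n_levels groups: group 0 is filtered at full resolution,
# while every further group is adaptive-max-pooled to a coarser grid, filtered with a
# depthwise 3x3 conv, and upsampled back. The concatenated multi-scale maps are fused by
# a 1x1 conv and, after GELU, modulate the input feature (spatially-adaptive modulation).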
class AttBlock(nn.Module):
    def __init__(self, dim, ffn_scale=2.0):
        super().__init__()

        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)

        # Multiscale Block
        self.safm = SAFM(dim)
        # Feedforward layer
        self.ccm = CCM(dim, ffn_scale)

    def forward(self, x):
        x = self.safm(self.norm1(x)) + x
        x = self.ccm(self.norm2(x)) + x
        return x


class SAFMN(nn.Module):
    def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
        super().__init__()
        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)

        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])

        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
            nn.PixelShuffle(upscaling_factor)
        )

    def forward(self, x):
        x = self.to_feat(x)
        x = self.feats(x) + x
        x = self.to_img(x)
        return x

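# SAFMN overall: a 3x3 conv lifts the RGB input to `dim` channels, a stack of AttBlocks
# (SAFM + CCM with pre-norm and residuals) refines the features under a global skip
# connection, and a conv + PixelShuffle head produces the RGB image at the target scale.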
########################################## Gradio inference ###################################
pretrain_model_url = {
    'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
    'safmn_x4': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x4-v2.pth',
}


# download weights
if not os.path.exists('./experiments/pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x2'], model_dir='./experiments/pretrained_models/', progress=True, file_name=None)

if not os.path.exists('./experiments/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x4'], model_dir='./experiments/pretrained_models/', progress=True, file_name=None)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_safmn(upscale):
    model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=upscale)
    if upscale == 2:
        # '-v2' matches the checkpoint filename downloaded above
        model_path = './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'
    elif upscale == 4:
        model_path = './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'
    else:
        raise NotImplementedError('Only support x2/x4 upscaling!')

    # map_location keeps checkpoint loading working on CPU-only hosts
    model.load_state_dict(torch.load(model_path, map_location=device)['params'], strict=True)
    model.eval()
    return model.to(device)


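# Tiled inference helpers: img2patch() cuts the low-resolution input into overlapping
# crops so large images fit in memory, and patch2img() pastes the super-resolved crops
# back onto a full-size canvas, averaging the overlapping regions.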
def img2patch(lq, scale=4, crop_size=512):
    b, c, hl, wl = lq.size()
    h, w = hl*scale, wl*scale
    sr_size = (b, c, h, w)
    assert b == 1

    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    # adaptive step_i, step_j
    num_row = (h - 1) // crop_size_h + 1
    num_col = (w - 1) // crop_size_w + 1

    import math
    step_j = crop_size_w if num_col == 1 else math.ceil((w - crop_size_w) / (num_col - 1) - 1e-8)
    step_i = crop_size_h if num_row == 1 else math.ceil((h - crop_size_h) / (num_row - 1) - 1e-8)

    step_i = step_i // scale * scale
    step_j = step_j // scale * scale

    parts = []
    idxes = []

    i = 0  # 0~h-1
    last_i = False
    while i < h and not last_i:
        j = 0
        if i + crop_size_h >= h:
            i = h - crop_size_h
            last_i = True

        last_j = False
        while j < w and not last_j:
            if j + crop_size_w >= w:
                j = w - crop_size_w
                last_j = True
            parts.append(lq[:, :, i // scale:(i + crop_size_h) // scale, j // scale:(j + crop_size_w) // scale])
            idxes.append({'i': i, 'j': j})
            j = j + step_j
        i = i + step_i

    return torch.cat(parts, dim=0), idxes, sr_size


def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
    preds = torch.zeros(sr_size).to(outs.device)
    b, c, h, w = sr_size

    count_mt = torch.zeros((b, 1, h, w)).to(outs.device)
    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    for cnt, each_idx in enumerate(idxes):
        i = each_idx['i']
        j = each_idx['j']
        preds[0, :, i: i + crop_size_h, j: j + crop_size_w] += outs[cnt]
        count_mt[0, 0, i: i + crop_size_h, j: j + crop_size_w] += 1.

    return (preds / count_mt).to(outs.device)


os.makedirs('./results', exist_ok=True)

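# inference(): reads the upload with OpenCV (BGR), converts it to an RGB float tensor,
# runs the model on the whole image or tile-by-tile (memory-efficient mode), optionally
# applies the wavelet color fix against the bilinearly upscaled input, and returns the
# result both as an RGB array for display and as a saved PNG for download.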
def inference(image, upscale, large_input_flag, color_fix):
    upscale = int(upscale)  # convert type to int
    if upscale > 4:
        upscale = 4
    if 0 < upscale < 3:
        upscale = 2

    model = set_safmn(upscale)

    img = cv2.imread(str(image), cv2.IMREAD_COLOR)
    print(f'input size: {img.shape}')

    # img2tensor
    img = img.astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img = img.unsqueeze(0).to(device)

    # inference
    if large_input_flag:
        patches, idx, size = img2patch(img, scale=upscale)
        with torch.no_grad():
            n = len(patches)
            outs = []
            m = 1
            i = 0
            while i < n:
                j = i + m
                if j >= n:
                    j = n
                pred = model(patches[i:j])
                if isinstance(pred, list):
                    pred = pred[-1]
                outs.append(pred.detach())
                i = j
            output = torch.cat(outs, dim=0)

        output = patch2img(output, idx, size, scale=upscale)
    else:
        with torch.no_grad():
            output = model(img)

    # color fix
    if color_fix:
        img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
        output = wavelet_reconstruction(output, img)
    # tensor2img
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    # save restored img
    save_path = 'results/out.png'
    cv2.imwrite(save_path, output)

    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    return output, save_path



title = "Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution"
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/sunny2109/SAFMN' target='_blank'><b>Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution (ICCV 2023)</b></a>.<br>
"""
article = r"""
If SAFMN is helpful, please help to ⭐ the <a href='https://github.com/sunny2109/SAFMN' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/sunny2109/SAFMN?style=social)](https://github.com/sunny2109/SAFMN)

---
**Citation**

If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{sun2023safmn,
    title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
    author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
    year={2023}
}
```

<center><img src='https://visitor-badge.laobi.icu/badge?page_id=sunny2109/SAFMN' alt='visitors'></center>
"""

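# Gradio UI. The gr.inputs / gr.outputs components and queue(concurrency_count=...) below
# follow the older (pre-4.0) Gradio API, so the Space presumably pins a compatible
# Gradio version in its requirements.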
demo = gr.Interface(
    inference, [
        gr.inputs.Image(type="filepath", label="Input"),
        gr.inputs.Number(default=2, label="Upscaling factor (up to 4)"),
        gr.inputs.Checkbox(default=False, label="Memory-efficient inference"),
        gr.inputs.Checkbox(default=False, label="Color correction"),
    ], [
        gr.outputs.Image(type="numpy", label="Output"),
        gr.outputs.File(label="Download the output")
    ],
    title=title,
    description=description,
    article=article,
)

demo.queue(concurrency_count=2)
demo.launch()