Update app.py
app.py
CHANGED
@@ -8,271 +8,14 @@ import glob
 import numpy as np
 import os
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import gradio as gr
 
-
-from PIL import Image
-from torch import Tensor
-from torchvision.transforms import ToTensor, ToPILImage
-
-def adain_color_fix(target: Image, source: Image):
-    # Convert images to tensors
-    to_tensor = ToTensor()
-    target_tensor = to_tensor(target).unsqueeze(0)
-    source_tensor = to_tensor(source).unsqueeze(0)
-
-    # Apply adaptive instance normalization
-    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
-
-    # Convert tensor back to image
-    to_image = ToPILImage()
-    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
-
-    return result_image
-
-def wavelet_color_fix(target: Image, source: Image):
-    if target.size() != source.size():
-        source = source.resize((target.size()[-2], target.size()[-1]), Image.LANCZOS)
-    # Convert images to tensors
-    to_tensor = ToTensor()
-    target_tensor = to_tensor(target).unsqueeze(0)
-    source_tensor = to_tensor(source).unsqueeze(0)
-
-    # Apply wavelet reconstruction
-    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
-
-    # Convert tensor back to image
-    to_image = ToPILImage()
-    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
-
-    return result_image
-
-def calc_mean_std(feat: Tensor, eps=1e-5):
-    """Calculate mean and std for adaptive_instance_normalization.
-    Args:
-        feat (Tensor): 4D tensor.
-        eps (float): A small value added to the variance to avoid
-            divide-by-zero. Default: 1e-5.
-    """
-    size = feat.size()
-    assert len(size) == 4, 'The input feature should be 4D tensor.'
-    b, c = size[:2]
-    feat_var = feat.view(b, c, -1).var(dim=2) + eps
-    feat_std = feat_var.sqrt().view(b, c, 1, 1)
-    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
-    return feat_mean, feat_std
-
-def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
-    """Adaptive instance normalization.
-    Adjust the reference features to have the similar color and illuminations
-    as those in the degradate features.
-    Args:
-        content_feat (Tensor): The reference feature.
-        style_feat (Tensor): The degradate features.
-    """
-    size = content_feat.size()
-    style_mean, style_std = calc_mean_std(style_feat)
-    content_mean, content_std = calc_mean_std(content_feat)
-    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
-    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
-
-def wavelet_blur(image: Tensor, radius: int):
-    """
-    Apply wavelet blur to the input tensor.
-    """
-    # input shape: (1, 3, H, W)
-    # convolution kernel
-    kernel_vals = [
-        [0.0625, 0.125, 0.0625],
-        [0.125, 0.25, 0.125],
-        [0.0625, 0.125, 0.0625],
-    ]
-    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
-    # add channel dimensions to the kernel to make it a 4D tensor
-    kernel = kernel[None, None]
-    # repeat the kernel across all input channels
-    kernel = kernel.repeat(3, 1, 1, 1)
-    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
-    # apply convolution
-    output = F.conv2d(image, kernel, groups=3, dilation=radius)
-    return output
-
-def wavelet_decomposition(image: Tensor, levels=5):
-    """
-    Apply wavelet decomposition to the input tensor.
-    This function only returns the low frequency & the high frequency.
-    """
-    high_freq = torch.zeros_like(image)
-    for i in range(levels):
-        radius = 2 ** i
-        low_freq = wavelet_blur(image, radius)
-        high_freq += (image - low_freq)
-        image = low_freq
-
-    return high_freq, low_freq
-
-def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
-    """
-    Apply wavelet decomposition, so that the content will have the same color as the style.
-    """
-    # calculate the wavelet decomposition of the content feature
-    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
-    del content_low_freq
-    # calculate the wavelet decomposition of the style feature
-    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
-    del style_high_freq
-    # reconstruct the content feature with the style's high frequency
-    return content_high_freq + style_low_freq
-
-
-########################################## URL Load ###################################
-from torch.hub import download_url_to_file, get_dir
-from urllib.parse import urlparse
-
-def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Load file form http url, will download models if necessary.
-
-    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
-
-    Args:
-        url (str): URL to be downloaded.
-        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
-            Default: None.
-        progress (bool): Whether to show the download progress. Default: True.
-        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
-
-    Returns:
-        str: The path to the downloaded file.
-    """
-    if model_dir is None:  # use the pytorch hub_dir
-        hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(model_dir, exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    if file_name is not None:
-        filename = file_name
-    cached_file = os.path.abspath(os.path.join(model_dir, filename))
-    if not os.path.exists(cached_file):
-        print(f'Downloading: "{url}" to {cached_file}\n')
-        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
-    return cached_file
-
-
-########################################## Model Define ###################################
-# Layer Norm
-class LayerNorm(nn.Module):
-    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))
-        self.eps = eps
-        self.data_format = data_format
-        if self.data_format not in ["channels_last", "channels_first"]:
-            raise NotImplementedError
-        self.normalized_shape = (normalized_shape, )
-
-    def forward(self, x):
-        if self.data_format == "channels_last":
-            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        elif self.data_format == "channels_first":
-            u = x.mean(1, keepdim=True)
-            s = (x - u).pow(2).mean(1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            x = self.weight[:, None, None] * x + self.bias[:, None, None]
-            return x
-
-# CCM
-class CCM(nn.Module):
-    def __init__(self, dim, growth_rate=2.0):
-        super().__init__()
-        hidden_dim = int(dim * growth_rate)
-
-        self.ccm = nn.Sequential(
-            nn.Conv2d(dim, hidden_dim, 3, 1, 1),
-            nn.GELU(),
-            nn.Conv2d(hidden_dim, dim, 1, 1, 0)
-        )
-
-    def forward(self, x):
-        return self.ccm(x)
-
-
-# SAFM
-class SAFM(nn.Module):
-    def __init__(self, dim, n_levels=4):
-        super().__init__()
-        self.n_levels = n_levels
-        chunk_dim = dim // n_levels
-
-        # Spatial Weighting
-        self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])
-
-        # Feature Aggregation
-        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)
-
-        # Activation
-        self.act = nn.GELU()
-
-    def forward(self, x):
-        h, w = x.size()[-2:]
-
-        xc = x.chunk(self.n_levels, dim=1)
-        out = []
-        for i in range(self.n_levels):
-            if i > 0:
-                p_size = (h // 2**i, w // 2**i)
-                s = F.adaptive_max_pool2d(xc[i], p_size)
-                s = self.mfr[i](s)
-                s = F.interpolate(s, size=(h, w), mode='nearest')
-            else:
-                s = self.mfr[i](xc[i])
-            out.append(s)
-
-        out = self.aggr(torch.cat(out, dim=1))
-        out = self.act(out) * x
-        return out
-
-class AttBlock(nn.Module):
-    def __init__(self, dim, ffn_scale=2.0):
-        super().__init__()
-
-        self.norm1 = LayerNorm(dim)
-        self.norm2 = LayerNorm(dim)
-
-        # Multiscale Block
-        self.safm = SAFM(dim)
-        # Feedforward layer
-        self.ccm = CCM(dim, ffn_scale)
-
-    def forward(self, x):
-        x = self.safm(self.norm1(x)) + x
-        x = self.ccm(self.norm2(x)) + x
-        return x
-
-
-class SAFMN(nn.Module):
-    def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
-        super().__init__()
-        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)
-
-        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])
-
-        self.to_img = nn.Sequential(
-            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
-            nn.PixelShuffle(upscaling_factor)
-        )
-
-    def forward(self, x):
-        x = self.to_feat(x)
-        x = self.feats(x) + x
-        x = self.to_img(x)
-        return x
-
+from utils.download_url import load_file_from_url
+from utils.color_fix import wavelet_reconstruction
+from models.safmn_arch import SAFMN
+
+
 ########################################## Gradio inference ###################################
 pretrain_model_url = {
     'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
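The helpers deleted above are not dropped; the new import block pulls them back in from dedicated modules. A plausible layout of the refactored Space, inferred from the new imports and the deleted definitions (only the module paths appear in this commit; the file contents are presumed to be the functions removed above):

    app.py                  # Gradio UI only
    utils/
        download_url.py     # load_file_from_url
        color_fix.py        # adain_color_fix, wavelet_color_fix, calc_mean_std,
                            # adaptive_instance_normalization, wavelet_blur,
                            # wavelet_decomposition, wavelet_reconstruction
    models/
        safmn_arch.py       # LayerNorm, CCM, SAFM, AttBlock, SAFMN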
@@ -425,44 +168,75 @@ def inference(image, upscale, large_input_flag, color_fix):
 
 
 
-title = "
-description = 
-
-
-
-If SAFMN is helpful, please help to ⭐ the <a href='https://github.com/sunny2109/SAFMN' target='_blank'>Github Repo</a>. Thanks!
-[![GitHub Stars](https://img.shields.io/github/stars/sunny2109/SAFMN?style=social)](https://github.com/sunny2109/SAFMN)
+title = "SAFMN for Real-world SR"
+description = ''' ### Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution - ICCV 2023
+#### [Long Sun](https://github.com/sunny2109), [Jiangxin Dong](https://scholar.google.com/citations?user=ruebFVEAAAAJ&hl=zh-CN&oi=ao), [Jinhui Tang](https://scholar.google.com/citations?user=ByBLlEwAAAAJ&hl=zh-CN), and [Jinshan Pan](https://jspan.github.io/)
+#### [IMAG Lab](https://imag-njust.net/), Nanjing University of Science and Technology
+
 
-
-
+#### Drag the slider on the super-resolution image left and right to see the changes in the image details. SAFMN performs x2/x4 upscaling on the input image.
+<br>
 
-If our work is useful for your research, please consider citing:
-
+### If our work is useful for your research, please consider citing:
+<code>
 @inproceedings{sun2023safmn,
     title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
     author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
-    booktitle={
+    booktitle={ICCV},
     year={2023}
 }
-
-
-
+</code>
+<br>
+'''
+
+
+article = "<p style='text-align: center'><a href='https://eduardzamfir.github.io/seemore' target='_blank'>See More Details: Efficient Image Super-Resolution by Experts Mining</a></p>"
+
+#### Image, Prompts examples
+examples = [
+    ['images/0801x4.png'],
+    ['images/0840x4.png'],
+    ['images/0841x4.png'],
+    ['images/0870x4.png'],
+    ['images/0878x4.png'],
+    ['images/0884x4.png'],
+    ['images/0900x4.png'],
+    ['images/img002x4.png'],
+    ['images/img003x4.png'],
+    ['images/img004x4.png'],
+    ['images/img035x4.png'],
+    ['images/img053x4.png'],
+    ['images/img064x4.png'],
+    ['images/img083x4.png'],
+    ['images/img092x4.png'],
+]
+
+css = """
+.image-frame img, .image-container img {
+    width: auto;
+    height: auto;
+    max-width: none;
+}
 """
 
 demo = gr.Interface(
-
-
-    gr.
-
-
-
-    gr.outputs.Image(type="numpy", label="Output"),
-    gr.outputs.File(label="Download the output")
+    fn=process_img,
+    inputs=[
+        gr.Image(type="pil", label="Input", value="real_testdata/004.png"),
+        gr.Number(default=2, label="Upscaling factor (up to 4)"),
+        gr.Checkbox(default=False, label="Memory-efficient inference"),
+        gr.Checkbox(default=False, label="Color correction"),
     ],
+    outputs=ImageSlider(label="Super-Resolved Image",
+                        type="pil",
+                        show_download_button=True,
+                        ),
     title=title,
     description=description,
-    article=article,
+    article=article,
+    examples=examples,
+    css=css,
 )
 
-
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
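Neither hunk shows where `process_img` or `ImageSlider` is defined; they presumably come from unchanged parts of app.py (e.g. `from gradio_imageslider import ImageSlider`). Below is a minimal sketch of how the relocated helpers are expected to be wired together after this commit. The module paths, function names, and checkpoint URL are taken from the diff; the SAFMN hyperparameters, the `'params'` checkpoint key, and the output path are assumptions for illustration only.

    import torch
    import torch.nn.functional as F
    from PIL import Image
    from torchvision.transforms.functional import to_tensor, to_pil_image

    from utils.download_url import load_file_from_url   # moved out of app.py in this commit
    from utils.color_fix import wavelet_reconstruction  # moved out of app.py in this commit
    from models.safmn_arch import SAFMN                 # moved out of app.py in this commit

    # Hypothetical hyperparameters; the real values sit in unchanged app.py context.
    model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=2)
    ckpt = load_file_from_url(
        'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
        model_dir='checkpoints')
    state = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(state.get('params', state), strict=True)  # 'params' key is an assumption
    model.eval()

    # Upscale one of the bundled test images.
    lr = to_tensor(Image.open('real_testdata/004.png').convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        sr = model(lr).clamp_(0.0, 1.0)

    # Optional color correction (the "Color correction" checkbox): transplant the
    # low-frequency color of a bicubic upscale onto the network output.
    ref = F.interpolate(lr, scale_factor=2, mode='bicubic', align_corners=False)
    sr = wavelet_reconstruction(sr, ref).clamp_(0.0, 1.0)
    to_pil_image(sr.squeeze(0)).save('output.png')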