Meloo committed on
Commit 46255b1 • 1 Parent(s): f0b7d84

Create app.py

Files changed (1)
  app.py +468 -0
app.py ADDED
@@ -0,0 +1,468 @@
import sys
sys.path.append('SAFMN')

import os
import cv2
import math
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr


########################################## Wavelet colorfix ###################################
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToTensor, ToPILImage

def adain_color_fix(target: Image.Image, source: Image.Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply adaptive instance normalization
    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

def wavelet_color_fix(target: Image.Image, source: Image.Image):
    # PIL images expose size as a (width, height) tuple, not a method
    if target.size != source.size:
        source = source.resize(target.size, Image.LANCZOS)
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be a 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.

    Adjust the reference features to have similar color and illumination
    to those of the degraded features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded feature.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

def wavelet_blur(image: Tensor, radius: int):
    """Apply wavelet blur to the input tensor."""
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output

def wavelet_decomposition(image: Tensor, levels=5):
    """Apply wavelet decomposition to the input tensor.

    This function only returns the low frequency & the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq

def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """Apply wavelet decomposition, so that the content will have the same color as the style."""
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct: the content's high frequency plus the style's low frequency
    return content_high_freq + style_low_freq


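# --- Usage sketch (illustrative only; not called by the app) ---
# The two PIL-level helpers wrap the tensor routines above. Assuming `sr_img` is a
# super-resolved PIL image and `lr_img` is the original input resized to the same
# size (both names are hypothetical here), either call transfers the input's colors:
#
#   fixed = adain_color_fix(sr_img, lr_img)      # statistics-based color match
#   fixed = wavelet_color_fix(sr_img, lr_img)    # keep SR high frequencies, take the
#                                                # input's low-frequency colors
#
# Note that inference() below applies wavelet_reconstruction() to tensors directly.
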
########################################## URL Load ###################################
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse

def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load a file from an http url, downloading the model if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path.
            If None, use pytorch hub_dir. Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url.
            Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


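# --- Usage sketch (illustrative only) ---
# The weight-download calls further below use this helper; a standalone call looks like:
#
#   ckpt_path = load_file_from_url(
#       url=pretrain_model_url['safmn_x4'],
#       model_dir='./experiments/pretrained_models/',
#   )   # returns the absolute path to the cached .pth file
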
########################################## Model Define ###################################
# Layer Norm
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

# CCM
class CCM(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        self.ccm = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0)
        )

    def forward(self, x):
        return self.ccm(x)


# SAFM
class SAFM(nn.Module):
    def __init__(self, dim, n_levels=4):
        super().__init__()
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial Weighting
        self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])

        # Feature Aggregation
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation
        self.act = nn.GELU()

    def forward(self, x):
        h, w = x.size()[-2:]

        xc = x.chunk(self.n_levels, dim=1)
        out = []
        for i in range(self.n_levels):
            if i > 0:
                p_size = (h // 2**i, w // 2**i)
                s = F.adaptive_max_pool2d(xc[i], p_size)
                s = self.mfr[i](s)
                s = F.interpolate(s, size=(h, w), mode='nearest')
            else:
                s = self.mfr[i](xc[i])
            out.append(s)

        out = self.aggr(torch.cat(out, dim=1))
        out = self.act(out) * x
        return out

class AttBlock(nn.Module):
    def __init__(self, dim, ffn_scale=2.0):
        super().__init__()

        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)

        # Multiscale Block
        self.safm = SAFM(dim)
        # Feedforward layer
        self.ccm = CCM(dim, ffn_scale)

    def forward(self, x):
        x = self.safm(self.norm1(x)) + x
        x = self.ccm(self.norm2(x)) + x
        return x


class SAFMN(nn.Module):
    def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
        super().__init__()
        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)

        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])

        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
            nn.PixelShuffle(upscaling_factor)
        )

    def forward(self, x):
        x = self.to_feat(x)
        x = self.feats(x) + x
        x = self.to_img(x)
        return x

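# --- Shape sketch (illustrative; mirrors the configuration used in set_safmn below) ---
#
#   model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=4)
#   x = torch.randn(1, 3, 64, 64)          # (B, 3, H, W) input in [0, 1]
#   y = model(x)                           # -> (1, 3, 256, 256) via PixelShuffle
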
########################################## Gradio inference ###################################
pretrain_model_url = {
    'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
    'safmn_x4': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x4-v2.pth',
}


# download weights
if not os.path.exists('./experiments/pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x2'], model_dir='./experiments/pretrained_models/', progress=True, file_name=None)

if not os.path.exists('./experiments/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x4'], model_dir='./experiments/pretrained_models/', progress=True, file_name=None)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_safmn(upscale):
    model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=upscale)
    if upscale == 2:
        model_path = './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'
    elif upscale == 4:
        model_path = './experiments/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'
    else:
        raise NotImplementedError('Only x2/x4 upscaling is supported!')

    # map_location keeps CPU-only environments working
    model.load_state_dict(torch.load(model_path, map_location=device)['params'], strict=True)
    model.eval()
    return model.to(device)


def img2patch(lq, scale=4, crop_size=512):
    b, c, hl, wl = lq.size()
    h, w = hl * scale, wl * scale
    sr_size = (b, c, h, w)
    assert b == 1

    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    # adaptive step_i, step_j
    num_row = (h - 1) // crop_size_h + 1
    num_col = (w - 1) // crop_size_w + 1

    step_j = crop_size_w if num_col == 1 else math.ceil((w - crop_size_w) / (num_col - 1) - 1e-8)
    step_i = crop_size_h if num_row == 1 else math.ceil((h - crop_size_h) / (num_row - 1) - 1e-8)

    step_i = step_i // scale * scale
    step_j = step_j // scale * scale

    parts = []
    idxes = []

    i = 0  # 0~h-1
    last_i = False
    while i < h and not last_i:
        j = 0
        if i + crop_size_h >= h:
            i = h - crop_size_h
            last_i = True

        last_j = False
        while j < w and not last_j:
            if j + crop_size_w >= w:
                j = w - crop_size_w
                last_j = True
            parts.append(lq[:, :, i // scale:(i + crop_size_h) // scale, j // scale:(j + crop_size_w) // scale])
            idxes.append({'i': i, 'j': j})
            j = j + step_j
        i = i + step_i

    return torch.cat(parts, dim=0), idxes, sr_size


def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
    preds = torch.zeros(sr_size).to(outs.device)
    b, c, h, w = sr_size

    count_mt = torch.zeros((b, 1, h, w)).to(outs.device)
    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    for cnt, each_idx in enumerate(idxes):
        i = each_idx['i']
        j = each_idx['j']
        preds[0, :, i: i + crop_size_h, j: j + crop_size_w] += outs[cnt]
        count_mt[0, 0, i: i + crop_size_h, j: j + crop_size_w] += 1.

    return (preds / count_mt).to(outs.device)


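# --- Tiling sketch (illustrative; this is what the memory-efficient path below does) ---
# img2patch() cuts the low-res tensor into overlapping crops and records where each
# SR crop should land; patch2img() pastes the per-crop outputs back and divides by
# count_mt, so overlapping regions are averaged:
#
#   patches, idxes, sr_size = img2patch(lq, scale=4, crop_size=512)   # lq: (1, 3, h, w)
#   outs = torch.cat([model(p.unsqueeze(0)) for p in patches], dim=0)
#   sr = patch2img(outs, idxes, sr_size, scale=4, crop_size=512)      # (1, 3, 4h, 4w)
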
os.makedirs('./results', exist_ok=True)

def inference(image, upscale, large_input_flag, color_fix):
    upscale = int(upscale)  # convert type to int
    # only x2 and x4 models are available: fall back to the nearest supported factor
    if upscale >= 3:
        upscale = 4
    else:
        upscale = 2

    model = set_safmn(upscale)

    img = cv2.imread(str(image), cv2.IMREAD_COLOR)
    print(f'input size: {img.shape}')

    # img2tensor
    img = img.astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img = img.unsqueeze(0).to(device)

    # inference
    if large_input_flag:
        patches, idx, size = img2patch(img, scale=upscale)
        with torch.no_grad():
            n = len(patches)
            outs = []
            m = 1  # number of patches per forward pass
            i = 0
            while i < n:
                j = i + m
                if j >= n:
                    j = n
                pred = model(patches[i:j])
                if isinstance(pred, list):
                    pred = pred[-1]
                outs.append(pred.detach())
                i = j
            output = torch.cat(outs, dim=0)

        output = patch2img(output, idx, size, scale=upscale)
    else:
        with torch.no_grad():
            output = model(img)

    # color fix
    if color_fix:
        img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
        output = wavelet_reconstruction(output, img)
    # tensor2img
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    # save restored img
    save_path = './results/out.png'
    cv2.imwrite(save_path, output)

    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    return output, save_path


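# --- Direct call sketch (illustrative; the Gradio interface below passes these arguments) ---
#
#   out_rgb, saved = inference('input.png', upscale=2, large_input_flag=False, color_fix=True)
#   # out_rgb is an RGB uint8 array; saved is './results/out.png' ('input.png' is a placeholder)
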
title = "Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution"
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/sunny2109/SAFMN' target='_blank'><b>Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution (ICCV 2023)</b></a>.<br>
"""
article = r"""
If SAFMN is helpful, please help to ⭐ the <a href='https://github.com/sunny2109/SAFMN' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/sunny2109/SAFMN?style=social)](https://github.com/sunny2109/SAFMN)

---
📝 **Citation**

If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{sun2023safmn,
    title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
    author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
    year={2023}
}
```

<center><img src='https://visitor-badge.laobi.icu/badge?page_id=sunny2109/SAFMN' alt='visitors'></center>
"""

demo = gr.Interface(
    inference, [
        gr.Image(type="filepath", label="Input"),
        gr.Number(value=2, label="Upscaling factor (up to 4)"),
        gr.Checkbox(value=False, label="Memory-efficient inference"),
        gr.Checkbox(value=False, label="Color correction"),
    ], [
        gr.Image(type="numpy", label="Output"),
        gr.File(label="Download the output")
    ],
    title=title,
    description=description,
    article=article,
)

demo.queue()
demo.launch()
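
# To try the demo locally (assuming the packages imported above are installed and the
# SAFMN repository is available so that sys.path.append('SAFMN') resolves), run:
#
#   python app.py
#
# then open the local URL that Gradio prints.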