Meloo commited on
Commit
d966562
·
verified ·
1 Parent(s): 333eb80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -287
app.py CHANGED
@@ -8,271 +8,14 @@ import glob
8
  import numpy as np
9
  import os
10
  import torch
11
- import torch.nn as nn
12
  import torch.nn.functional as F
13
  import gradio as gr
14
 
 
 
 
15
 
16
- ########################################## Wavelet colorfix ###################################
17
- from PIL import Image
18
- from torch import Tensor
19
- from torchvision.transforms import ToTensor, ToPILImage
20
- def adain_color_fix(target: Image, source: Image):
21
- # Convert images to tensors
22
- to_tensor = ToTensor()
23
- target_tensor = to_tensor(target).unsqueeze(0)
24
- source_tensor = to_tensor(source).unsqueeze(0)
25
-
26
- # Apply adaptive instance normalization
27
- result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
28
-
29
- # Convert tensor back to image
30
- to_image = ToPILImage()
31
- result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
32
-
33
- return result_image
34
-
35
- def wavelet_color_fix(target: Image, source: Image):
36
- if target.size() != source.size():
37
- source = source.resize((target.size()[-2], target.size()[-1]), Image.LANCZOS)
38
- # Convert images to tensors
39
- to_tensor = ToTensor()
40
- target_tensor = to_tensor(target).unsqueeze(0)
41
- source_tensor = to_tensor(source).unsqueeze(0)
42
-
43
- # Apply wavelet reconstruction
44
- result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
45
-
46
- # Convert tensor back to image
47
- to_image = ToPILImage()
48
- result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
49
-
50
- return result_image
51
-
52
- def calc_mean_std(feat: Tensor, eps=1e-5):
53
- """Calculate mean and std for adaptive_instance_normalization.
54
- Args:
55
- feat (Tensor): 4D tensor.
56
- eps (float): A small value added to the variance to avoid
57
- divide-by-zero. Default: 1e-5.
58
- """
59
- size = feat.size()
60
- assert len(size) == 4, 'The input feature should be 4D tensor.'
61
- b, c = size[:2]
62
- feat_var = feat.view(b, c, -1).var(dim=2) + eps
63
- feat_std = feat_var.sqrt().view(b, c, 1, 1)
64
- feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
65
- return feat_mean, feat_std
66
-
67
- def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
68
- """Adaptive instance normalization.
69
- Adjust the reference features to have the similar color and illuminations
70
- as those in the degradate features.
71
- Args:
72
- content_feat (Tensor): The reference feature.
73
- style_feat (Tensor): The degradate features.
74
- """
75
- size = content_feat.size()
76
- style_mean, style_std = calc_mean_std(style_feat)
77
- content_mean, content_std = calc_mean_std(content_feat)
78
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
79
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
80
-
81
- def wavelet_blur(image: Tensor, radius: int):
82
- """
83
- Apply wavelet blur to the input tensor.
84
- """
85
- # input shape: (1, 3, H, W)
86
- # convolution kernel
87
- kernel_vals = [
88
- [0.0625, 0.125, 0.0625],
89
- [0.125, 0.25, 0.125],
90
- [0.0625, 0.125, 0.0625],
91
- ]
92
- kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
93
- # add channel dimensions to the kernel to make it a 4D tensor
94
- kernel = kernel[None, None]
95
- # repeat the kernel across all input channels
96
- kernel = kernel.repeat(3, 1, 1, 1)
97
- image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
98
- # apply convolution
99
- output = F.conv2d(image, kernel, groups=3, dilation=radius)
100
- return output
101
-
102
- def wavelet_decomposition(image: Tensor, levels=5):
103
- """
104
- Apply wavelet decomposition to the input tensor.
105
- This function only returns the low frequency & the high frequency.
106
- """
107
- high_freq = torch.zeros_like(image)
108
- for i in range(levels):
109
- radius = 2 ** i
110
- low_freq = wavelet_blur(image, radius)
111
- high_freq += (image - low_freq)
112
- image = low_freq
113
-
114
- return high_freq, low_freq
115
-
116
- def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
117
- """
118
- Apply wavelet decomposition, so that the content will have the same color as the style.
119
- """
120
- # calculate the wavelet decomposition of the content feature
121
- content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
122
- del content_low_freq
123
- # calculate the wavelet decomposition of the style feature
124
- style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
125
- del style_high_freq
126
- # reconstruct the content feature with the style's high frequency
127
- return content_high_freq + style_low_freq
128
-
129
-
130
- ########################################## URL Load ###################################
131
- from torch.hub import download_url_to_file, get_dir
132
- from urllib.parse import urlparse
133
-
134
- def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
135
- """Load file form http url, will download models if necessary.
136
-
137
- Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
138
-
139
- Args:
140
- url (str): URL to be downloaded.
141
- model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
142
- Default: None.
143
- progress (bool): Whether to show the download progress. Default: True.
144
- file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
145
-
146
- Returns:
147
- str: The path to the downloaded file.
148
- """
149
- if model_dir is None: # use the pytorch hub_dir
150
- hub_dir = get_dir()
151
- model_dir = os.path.join(hub_dir, 'checkpoints')
152
-
153
- os.makedirs(model_dir, exist_ok=True)
154
-
155
- parts = urlparse(url)
156
- filename = os.path.basename(parts.path)
157
- if file_name is not None:
158
- filename = file_name
159
- cached_file = os.path.abspath(os.path.join(model_dir, filename))
160
- if not os.path.exists(cached_file):
161
- print(f'Downloading: "{url}" to {cached_file}\n')
162
- download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
163
- return cached_file
164
-
165
-
166
- ########################################## Model Define ###################################
167
- # Layer Norm
168
- class LayerNorm(nn.Module):
169
- def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
170
- super().__init__()
171
- self.weight = nn.Parameter(torch.ones(normalized_shape))
172
- self.bias = nn.Parameter(torch.zeros(normalized_shape))
173
- self.eps = eps
174
- self.data_format = data_format
175
- if self.data_format not in ["channels_last", "channels_first"]:
176
- raise NotImplementedError
177
- self.normalized_shape = (normalized_shape, )
178
-
179
- def forward(self, x):
180
- if self.data_format == "channels_last":
181
- return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
182
- elif self.data_format == "channels_first":
183
- u = x.mean(1, keepdim=True)
184
- s = (x - u).pow(2).mean(1, keepdim=True)
185
- x = (x - u) / torch.sqrt(s + self.eps)
186
- x = self.weight[:, None, None] * x + self.bias[:, None, None]
187
- return x
188
-
189
- # CCM
190
- class CCM(nn.Module):
191
- def __init__(self, dim, growth_rate=2.0):
192
- super().__init__()
193
- hidden_dim = int(dim * growth_rate)
194
-
195
- self.ccm = nn.Sequential(
196
- nn.Conv2d(dim, hidden_dim, 3, 1, 1),
197
- nn.GELU(),
198
- nn.Conv2d(hidden_dim, dim, 1, 1, 0)
199
- )
200
-
201
- def forward(self, x):
202
- return self.ccm(x)
203
-
204
-
205
- # SAFM
206
- class SAFM(nn.Module):
207
- def __init__(self, dim, n_levels=4):
208
- super().__init__()
209
- self.n_levels = n_levels
210
- chunk_dim = dim // n_levels
211
-
212
- # Spatial Weighting
213
- self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])
214
-
215
- # # Feature Aggregation
216
- self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)
217
-
218
- # Activation
219
- self.act = nn.GELU()
220
-
221
- def forward(self, x):
222
- h, w = x.size()[-2:]
223
-
224
- xc = x.chunk(self.n_levels, dim=1)
225
- out = []
226
- for i in range(self.n_levels):
227
- if i > 0:
228
- p_size = (h//2**i, w//2**i)
229
- s = F.adaptive_max_pool2d(xc[i], p_size)
230
- s = self.mfr[i](s)
231
- s = F.interpolate(s, size=(h, w), mode='nearest')
232
- else:
233
- s = self.mfr[i](xc[i])
234
- out.append(s)
235
-
236
- out = self.aggr(torch.cat(out, dim=1))
237
- out = self.act(out) * x
238
- return out
239
-
240
- class AttBlock(nn.Module):
241
- def __init__(self, dim, ffn_scale=2.0):
242
- super().__init__()
243
-
244
- self.norm1 = LayerNorm(dim)
245
- self.norm2 = LayerNorm(dim)
246
-
247
- # Multiscale Block
248
- self.safm = SAFM(dim)
249
- # Feedforward layer
250
- self.ccm = CCM(dim, ffn_scale)
251
-
252
- def forward(self, x):
253
- x = self.safm(self.norm1(x)) + x
254
- x = self.ccm(self.norm2(x)) + x
255
- return x
256
-
257
-
258
- class SAFMN(nn.Module):
259
- def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
260
- super().__init__()
261
- self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)
262
-
263
- self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])
264
-
265
- self.to_img = nn.Sequential(
266
- nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
267
- nn.PixelShuffle(upscaling_factor)
268
- )
269
-
270
- def forward(self, x):
271
- x = self.to_feat(x)
272
- x = self.feats(x) + x
273
- x = self.to_img(x)
274
- return x
275
-
276
  ########################################## Gradio inference ###################################
277
  pretrain_model_url = {
278
  'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
@@ -425,44 +168,75 @@ def inference(image, upscale, large_input_flag, color_fix):
425
 
426
 
427
 
428
- title = "Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution"
429
- description = r"""
430
- <b>Official Gradio demo</b> for <a href='https://github.com/sunny2109/SAFMN' target='_blank'><b>Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution (ICCV 2023)</b></a>.<br>
431
- """
432
- article = r"""
433
- If SAFMN is helpful, please help to ⭐ the <a href='https://github.com/sunny2109/SAFMN' target='_blank'>Github Repo</a>. Thanks!
434
- [![GitHub Stars](https://img.shields.io/github/stars/sunny2109/SAFMN?style=social)](https://github.com/sunny2109/SAFMN)
435
 
436
- ---
437
- 📝 **Citation**
438
 
439
- If our work is useful for your research, please consider citing:
440
- ```bibtex
441
  @inproceedings{sun2023safmn,
442
  title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
443
  author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
444
- booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
445
  year={2023}
446
  }
447
- ```
448
-
449
- <center><img src='https://visitor-badge.laobi.icu/badge?page_id=sunny2109/SAFMN' alt='visitors'></center>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  """
451
 
452
  demo = gr.Interface(
453
- inference, [
454
- gr.inputs.Image(type="filepath", label="Input"),
455
- gr.inputs.Number(default=2, label="Upscaling factor (up to 4)"),
456
- gr.inputs.Checkbox(default=False, label="Memory-efficient inference"),
457
- gr.inputs.Checkbox(default=False, label="Color correction"),
458
- ], [
459
- gr.outputs.Image(type="numpy", label="Output"),
460
- gr.outputs.File(label="Download the output")
461
  ],
 
 
 
 
462
  title=title,
463
  description=description,
464
- article=article,
 
 
465
  )
466
 
467
- demo.queue(concurrency_count=2)
468
- demo.launch()
 
8
  import numpy as np
9
  import os
10
  import torch
 
11
  import torch.nn.functional as F
12
  import gradio as gr
13
 
14
+ from utils.download_url import load_file_from_url
15
+ from utils.color_fix import wavelet_reconstruction
16
+ from models.safmn_arch import SAFMN
17
 
18
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  ########################################## Gradio inference ###################################
20
  pretrain_model_url = {
21
  'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
 
168
 
169
 
170
 
171
+ title = "SAFMN for Real-world SR"
172
+ description = ''' ### Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution - ICCV 2023
173
+ #### Long Sun](https://github.com/sunny2109), [Jiangxin Dong](https://scholar.google.com/citations?user=ruebFVEAAAAJ&hl=zh-CN&oi=ao), [Jinhui Tang](https://scholar.google.com/citations?user=ByBLlEwAAAAJ&hl=zh-CN), and [Jinshan Pan](https://jspan.github.io/)
174
+ #### [IMAG Lab](https://imag-njust.net/), Nanjing University of Science and Technology
175
+
 
 
176
 
177
+ #### Drag the slider on the super-resolution image left and right to see the changes in the image details. SAFMN performs x2/x4 upscaling on the input image.
178
+ <br>
179
 
180
+ ### If our work is useful for your research, please consider citing:
181
+ <code>
182
  @inproceedings{sun2023safmn,
183
  title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
184
  author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
185
+ booktitle={ICCV},
186
  year={2023}
187
  }
188
+ </code>
189
+ <br>
190
+ '''
191
+
192
+
193
+ article = "<p style='text-align: center'><a href='https://eduardzamfir.github.io/seemore' target='_blank'>See More Details: Efficient Image Super-Resolution by Experts Mining</a></p>"
194
+
195
+ #### Image,Prompts examples
196
+ examples = [
197
+ ['images/0801x4.png'],
198
+ ['images/0840x4.png'],
199
+ ['images/0841x4.png'],
200
+ ['images/0870x4.png'],
201
+ ['images/0878x4.png'],
202
+ ['images/0884x4.png'],
203
+ ['images/0900x4.png'],
204
+ ['images/img002x4.png'],
205
+ ['images/img003x4.png'],
206
+ ['images/img004x4.png'],
207
+ ['images/img035x4.png'],
208
+ ['images/img053x4.png'],
209
+ ['images/img064x4.png'],
210
+ ['images/img083x4.png'],
211
+ ['images/img092x4.png'],
212
+ ]
213
+
214
+ css = """
215
+ .image-frame img, .image-container img {
216
+ width: auto;
217
+ height: auto;
218
+ max-width: none;
219
+ }
220
  """
221
 
222
  demo = gr.Interface(
223
+ fn=process_img,
224
+ inputs=[
225
+ gr.Image(type="pil", label="Input", value="real_testdata/004.png"),
226
+ gr.Number(default=2, label="Upscaling factor (up to 4)"),
227
+ gr.Checkbox(default=False, label="Memory-efficient inference"),
228
+ gr.Checkbox(default=False, label="Color correction"),
 
 
229
  ],
230
+ outputs=ImageSlider(label="Super-Resolved Image",
231
+ type="pil",
232
+ show_download_button=True,
233
+ ),
234
  title=title,
235
  description=description,
236
+ article=article,
237
+ examples=examples,
238
+ css=css,
239
  )
240
 
241
+ if __name__ == "__main__":
242
+ demo.launch()