jhj0517 committed on
Commit
1e3934e
1 Parent(s): f73fef2

Move file location

Browse files
modules/image_restoration/real_esrgan/__init__.py ADDED
File without changes
modules/image_restoration/{model_downloader.py → real_esrgan/model_downloader.py} RENAMED
@@ -1,8 +1,13 @@
1
  from modules.live_portrait.model_downloader import download_model
2
 
3
  MODELS_REALESRGAN_URL = {
4
- "RealESRGAN_x2": "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth",
5
  "realesr-general-x4v3": "https://huggingface.co/jhj0517/realesr-general-x4v3/resolve/main/realesr-general-x4v3.pth",
 
 
 
 
 
 
6
  }
7
 
8
 
 
1
  from modules.live_portrait.model_downloader import download_model
2
 
3
  MODELS_REALESRGAN_URL = {
 
4
  "realesr-general-x4v3": "https://huggingface.co/jhj0517/realesr-general-x4v3/resolve/main/realesr-general-x4v3.pth",
5
+ "RealESRGAN_x2": "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth",
6
+ }
7
+
8
+ MODELS_REALESRGAN_SCALABILITY = {
9
+ "realesr-general-x4v3": [1, 2, 4],
10
+ "RealESRGAN_x2": [2]
11
  }
12
 
13
 
modules/image_restoration/{real_esrgan_inferencer.py → real_esrgan/real_esrgan_inferencer.py} RENAMED
@@ -3,10 +3,14 @@ import gradio as gr
3
  import torch
4
  from PIL import Image
5
  import numpy as np
6
- from typing import Optional
 
 
7
 
8
  from modules.utils.paths import *
9
- from .model_downloader import download_resrgan_model, MODELS_REALESRGAN_URL
 
 
10
 
11
 
12
  class RealESRGANInferencer:
@@ -16,46 +20,88 @@ class RealESRGANInferencer:
16
  self.model_dir = model_dir
17
  self.output_dir = output_dir
18
  self.device = self.get_device()
 
19
  self.model = None
20
- self.up_sampler = None
21
  self.face_enhancer = None
22
 
23
  self.available_models = list(MODELS_REALESRGAN_URL.keys())
24
  self.default_model = self.available_models[0]
 
 
 
 
 
25
 
26
  def load_model(self,
27
  model_name: Optional[str] = None,
28
- scale: int = 1,
 
29
  progress: gr.Progress = gr.Progress()):
 
 
 
 
 
 
 
 
 
 
30
  if model_name is None:
31
  model_name = self.default_model
32
- if not model_name.endswith(".pth"):
33
- model_name += ".pth"
34
  model_path = os.path.join(self.model_dir, model_name)
 
 
35
 
36
  if not os.path.exists(model_path):
37
  progress(0, f"Downloading RealESRGAN model to : {model_path}")
38
- name, ext = os.path.splitext(model_name)
39
- download_resrgan_model(model_path, MODELS_REALESRGAN_URL[name])
 
 
 
 
 
 
 
 
40
 
41
- if self.model is None:
42
- self.model = RealESRGAN(self.device, scale=scale)
43
- self.model.load_weights(model_path=model_path, download=False)
 
 
 
 
44
 
45
  def restore_image(self,
46
  img_path: str,
 
 
 
47
  overwrite: bool = True):
48
- if self.model is None:
49
- self.load_model()
 
 
 
 
 
 
 
 
 
50
 
51
  try:
52
- img = Image.open(img_path).convert('RGB')
53
- sr_img = self.model.predict(img)
54
  if overwrite:
55
  output_path = img_path
56
  else:
57
  output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
58
- sr_img.save(output_path)
 
59
  return output_path
60
  except Exception as e:
61
  raise
 
3
  import torch
4
  from PIL import Image
5
  import numpy as np
6
+ from typing import Optional, Literal, List, Dict, Tuple, Union
7
+ from realesrgan.utils import RealESRGANer
8
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
9
 
10
  from modules.utils.paths import *
11
+ from .model_downloader import download_resrgan_model, MODELS_REALESRGAN_URL, MODELS_REALESRGAN_SCALABILITY
12
+ from modules.utils.image_helper import save_image
13
+ from .rrdb_net import RRDBNet
14
 
15
 
16
  class RealESRGANInferencer:
 
20
  self.model_dir = model_dir
21
  self.output_dir = output_dir
22
  self.device = self.get_device()
23
+ self.arc = None
24
  self.model = None
 
25
  self.face_enhancer = None
26
 
27
  self.available_models = list(MODELS_REALESRGAN_URL.keys())
28
  self.default_model = self.available_models[0]
29
+ self.model_config = {
30
+ "model_name": self.default_model,
31
+ "scale": 1,
32
+ "half_precision": True
33
+ }
34
 
35
  def load_model(self,
36
  model_name: Optional[str] = None,
37
+ scale: Literal[1, 2, 4] = 1,
38
+ half_precision: bool = True,
39
  progress: gr.Progress = gr.Progress()):
40
+ model_config = {
41
+ "model_name": model_name,
42
+ "scale": scale,
43
+ "half_precision": half_precision
44
+ }
45
+ if model_config == self.model_config and self.model is not None:
46
+ return
47
+ else:
48
+ self.model_config = model_config
49
+
50
  if model_name is None:
51
  model_name = self.default_model
52
+
 
53
  model_path = os.path.join(self.model_dir, model_name)
54
+ if not model_name.endswith(".pth"):
55
+ model_path += ".pth"
56
 
57
  if not os.path.exists(model_path):
58
  progress(0, f"Downloading RealESRGAN model to : {model_path}")
59
+ download_resrgan_model(model_path, MODELS_REALESRGAN_URL[model_name])
60
+
61
+ name, ext = os.path.splitext(model_name)
62
+ assert scale in MODELS_REALESRGAN_SCALABILITY[name]
63
+ if name == 'RealESRGAN_x2': # x4 RRDBNet model
64
+ arc = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
65
+ netscale = 4
66
+ else: # x4 VGG-style model (S size) : "realesr-general-x4v3"
67
+ arc = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
68
+ netscale = 4
69
 
70
+ self.model = RealESRGANer(
71
+ scale=netscale,
72
+ model_path=model_path,
73
+ model=arc,
74
+ half=half_precision,
75
+ )
76
+ self.model.device = torch.device(self.get_device())
77
 
78
  def restore_image(self,
79
  img_path: str,
80
+ model_name: Optional[str] = None,
81
+ scale: int = 1,
82
+ half_precision: bool = True,
83
  overwrite: bool = True):
84
+ model_config = {
85
+ "model_name": self.model_config["model_name"],
86
+ "scale": scale,
87
+ "half_precision": half_precision
88
+ }
89
+ if self.model is None or self.model_config != model_config:
90
+ self.load_model(
91
+ model_name=self.default_model if model_name is None else model_name,
92
+ scale=scale,
93
+ half_precision=half_precision
94
+ )
95
 
96
  try:
97
+ output, img_mode = self.model.enhance(img_path, outscale=scale)
98
+
99
  if overwrite:
100
  output_path = img_path
101
  else:
102
  output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
103
+
104
+ output_path = save_image(output, output_path=output_path)
105
  return output_path
106
  except Exception as e:
107
  raise
modules/image_restoration/real_esrgan/rrdb_net.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn as nn
2
+ import torch
3
+ from torch.nn import init as init
4
+ from torch.nn import functional as F
5
+ from torch.nn.modules.batchnorm import _BatchNorm
6
+
7
+
8
class ResidualDenseBlock(nn.Module):
    """Densely connected block of five 3x3 convolutions with a scaled residual.

    Building block of the RRDB module used in ESRGAN. Each of the first four
    convolutions sees the block input concatenated with every earlier
    convolution output; the fifth maps back to ``num_feat`` channels.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels added by each of the first four convs.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys that checkpoints are loaded against.
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, **conv_kwargs)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, **conv_kwargs)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Initialise all five convolutions with weights scaled by 0.1.
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        default_init_weights(convs, 0.1)

    def forward(self, x):
        # ``features`` accumulates the input and every activated conv output.
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        fused = self.conv5(torch.cat(features, 1))
        # The residual branch is scaled by 0.2 before being added to the input.
        return fused * 0.2 + x
39
+
40
+
41
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three dense blocks plus an outer skip.

    Used as the trunk unit of RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        # Attribute names are part of the checkpoint (state_dict) layout.
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        residual = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            residual = dense_block(residual)
        # The stacked-block output is scaled by 0.2 before the outer skip add.
        return residual * 0.2 + x
63
+
64
+
65
class RRDBNet(nn.Module):
    """ESRGAN generator built from Residual-in-Residual Dense Blocks.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    The network always ends with two nearest-neighbour x2 upsampling stages.
    For ``scale`` 2 and 1 the input is first pixel-unshuffled (the inverse of
    pixel-shuffle: spatial size is reduced and the channel count enlarged) by
    a factor of 2 and 4 respectively before entering the main architecture.
    Any other ``scale`` value feeds the input in unchanged.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Selects the pixel-unshuffle pre-processing described
            above. Default: 4.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super().__init__()
        self.scale = scale
        # Pixel-unshuffling by a factor f multiplies the channel count by f**2.
        if scale == 1:
            num_in_ch = num_in_ch * 16
        elif scale == 2:
            num_in_ch = num_in_ch * 4

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        # Attribute names and creation order define the checkpoint layout.
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, **conv_kwargs)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        # Upsampling head.
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, **conv_kwargs)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        elif self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        else:
            feat = x

        feat = self.conv_first(feat)
        # Trunk of RRDB blocks with a long skip connection around it.
        feat = feat + self.conv_body(self.body(feat))

        # Two nearest-neighbour x2 upsampling stages, each followed by a conv.
        for up_conv in (self.conv_up1, self.conv_up2):
            upsampled = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(up_conv(upsampled))

        high_res = self.lrelu(self.conv_hr(feat))
        return self.conv_last(high_res)
118
+
119
+
120
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` fresh instances of ``basic_block``.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
        **kwarg: Keyword arguments passed to every ``basic_block(...)`` call.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    # Each block is constructed separately so that no parameters are shared.
    return nn.Sequential(*(basic_block(**kwarg) for _ in range(num_basic_block)))
134
+
135
+
136
def pixel_unshuffle(x, scale):
    """Rearrange ``scale`` x ``scale`` spatial blocks into channels.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature, with shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, channels, in_h, in_w = x.size()
    # Both spatial dimensions must split evenly into scale-sized blocks.
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale

    # (b, c, h, s, w, s) -> (b, c, s, s, h, w) -> (b, c*s*s, h, w)
    blocks = x.view(batch, channels, out_h, scale, out_w, scale)
    blocks = blocks.permute(0, 1, 3, 5, 2, 4)
    return blocks.reshape(batch, channels * (scale ** 2), out_h, out_w)
153
+
154
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Conv2d and Linear layers get Kaiming-normal weights multiplied by
    ``scale``; batch-norm layers get their weight set to 1. In all three
    cases an existing bias is filled with ``bias_fill``. Other layer types
    are left untouched.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        # ``modules()`` walks the root itself and all nested sub-modules.
        for layer in root.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, **kwargs)
                layer.weight.data *= scale
            elif isinstance(layer, _BatchNorm):
                init.constant_(layer.weight, 1)
            else:
                continue
            if layer.bias is not None:
                layer.bias.data.fill_(bias_fill)