NimaBoscarino
commited on
Commit
•
e048e19
1
Parent(s):
60dd3db
Remove model code, only keep model checkpoint
Browse files- README.md +2 -0
- archs/__init__.py +0 -12
- archs/arcface_arch.py +0 -198
- archs/gfpganv1_arch.py +0 -418
- data/__init__.py +0 -11
- data/ffhq_degradation_dataset.py +0 -213
- experiments/pretrained_models/GFPGANv1.pth +0 -3
- experiments/pretrained_models/README.md +0 -7
- inference_gfpgan_full.py +0 -130
- models/__init__.py +0 -12
- models/gfpgan_model.py +0 -562
- requirements.txt +0 -10
- setup.cfg +0 -22
- train.py +0 -10
- train_gfpgan_v1.yml +0 -210
README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2 |
|
3 |
[**Paper**](https://arxiv.org/abs/2101.04061) **|** [**Project Page**](https://xinntao.github.io/projects/gfpgan)    [English](README.md) **|** [简体中文](README_CN.md)
|
4 |
|
|
|
|
|
5 |
GFPGAN is a blind face restoration algorithm towards real-world face images.
|
6 |
|
7 |
<a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
|
|
|
2 |
|
3 |
[**Paper**](https://arxiv.org/abs/2101.04061) **|** [**Project Page**](https://xinntao.github.io/projects/gfpgan)    [English](README.md) **|** [简体中文](README_CN.md)
|
4 |
|
5 |
+
GitHub: https://github.com/TencentARC/GFPGAN
|
6 |
+
|
7 |
GFPGAN is a blind face restoration algorithm towards real-world face images.
|
8 |
|
9 |
<a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
|
archs/__init__.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
import importlib
|
2 |
-
from os import path as osp
|
3 |
-
|
4 |
-
from basicsr.utils import scandir
|
5 |
-
|
6 |
-
# automatically scan and import arch modules for registry
|
7 |
-
# scan all the files under the 'archs' folder and collect files ending with
|
8 |
-
# '_arch.py'
|
9 |
-
arch_folder = osp.dirname(osp.abspath(__file__))
|
10 |
-
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
11 |
-
# import all the arch modules
|
12 |
-
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
archs/arcface_arch.py
DELETED
@@ -1,198 +0,0 @@
|
|
1 |
-
import torch.nn as nn
|
2 |
-
|
3 |
-
from basicsr.utils.registry import ARCH_REGISTRY
|
4 |
-
|
5 |
-
|
6 |
-
def conv3x3(in_planes, out_planes, stride=1):
|
7 |
-
"""3x3 convolution with padding"""
|
8 |
-
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
9 |
-
|
10 |
-
|
11 |
-
class BasicBlock(nn.Module):
|
12 |
-
expansion = 1
|
13 |
-
|
14 |
-
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
15 |
-
super(BasicBlock, self).__init__()
|
16 |
-
self.conv1 = conv3x3(inplanes, planes, stride)
|
17 |
-
self.bn1 = nn.BatchNorm2d(planes)
|
18 |
-
self.relu = nn.ReLU(inplace=True)
|
19 |
-
self.conv2 = conv3x3(planes, planes)
|
20 |
-
self.bn2 = nn.BatchNorm2d(planes)
|
21 |
-
self.downsample = downsample
|
22 |
-
self.stride = stride
|
23 |
-
|
24 |
-
def forward(self, x):
|
25 |
-
residual = x
|
26 |
-
|
27 |
-
out = self.conv1(x)
|
28 |
-
out = self.bn1(out)
|
29 |
-
out = self.relu(out)
|
30 |
-
|
31 |
-
out = self.conv2(out)
|
32 |
-
out = self.bn2(out)
|
33 |
-
|
34 |
-
if self.downsample is not None:
|
35 |
-
residual = self.downsample(x)
|
36 |
-
|
37 |
-
out += residual
|
38 |
-
out = self.relu(out)
|
39 |
-
|
40 |
-
return out
|
41 |
-
|
42 |
-
|
43 |
-
class IRBlock(nn.Module):
|
44 |
-
expansion = 1
|
45 |
-
|
46 |
-
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
47 |
-
super(IRBlock, self).__init__()
|
48 |
-
self.bn0 = nn.BatchNorm2d(inplanes)
|
49 |
-
self.conv1 = conv3x3(inplanes, inplanes)
|
50 |
-
self.bn1 = nn.BatchNorm2d(inplanes)
|
51 |
-
self.prelu = nn.PReLU()
|
52 |
-
self.conv2 = conv3x3(inplanes, planes, stride)
|
53 |
-
self.bn2 = nn.BatchNorm2d(planes)
|
54 |
-
self.downsample = downsample
|
55 |
-
self.stride = stride
|
56 |
-
self.use_se = use_se
|
57 |
-
if self.use_se:
|
58 |
-
self.se = SEBlock(planes)
|
59 |
-
|
60 |
-
def forward(self, x):
|
61 |
-
residual = x
|
62 |
-
out = self.bn0(x)
|
63 |
-
out = self.conv1(out)
|
64 |
-
out = self.bn1(out)
|
65 |
-
out = self.prelu(out)
|
66 |
-
|
67 |
-
out = self.conv2(out)
|
68 |
-
out = self.bn2(out)
|
69 |
-
if self.use_se:
|
70 |
-
out = self.se(out)
|
71 |
-
|
72 |
-
if self.downsample is not None:
|
73 |
-
residual = self.downsample(x)
|
74 |
-
|
75 |
-
out += residual
|
76 |
-
out = self.prelu(out)
|
77 |
-
|
78 |
-
return out
|
79 |
-
|
80 |
-
|
81 |
-
class Bottleneck(nn.Module):
|
82 |
-
expansion = 4
|
83 |
-
|
84 |
-
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
85 |
-
super(Bottleneck, self).__init__()
|
86 |
-
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
87 |
-
self.bn1 = nn.BatchNorm2d(planes)
|
88 |
-
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
89 |
-
self.bn2 = nn.BatchNorm2d(planes)
|
90 |
-
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
91 |
-
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
92 |
-
self.relu = nn.ReLU(inplace=True)
|
93 |
-
self.downsample = downsample
|
94 |
-
self.stride = stride
|
95 |
-
|
96 |
-
def forward(self, x):
|
97 |
-
residual = x
|
98 |
-
|
99 |
-
out = self.conv1(x)
|
100 |
-
out = self.bn1(out)
|
101 |
-
out = self.relu(out)
|
102 |
-
|
103 |
-
out = self.conv2(out)
|
104 |
-
out = self.bn2(out)
|
105 |
-
out = self.relu(out)
|
106 |
-
|
107 |
-
out = self.conv3(out)
|
108 |
-
out = self.bn3(out)
|
109 |
-
|
110 |
-
if self.downsample is not None:
|
111 |
-
residual = self.downsample(x)
|
112 |
-
|
113 |
-
out += residual
|
114 |
-
out = self.relu(out)
|
115 |
-
|
116 |
-
return out
|
117 |
-
|
118 |
-
|
119 |
-
class SEBlock(nn.Module):
|
120 |
-
|
121 |
-
def __init__(self, channel, reduction=16):
|
122 |
-
super(SEBlock, self).__init__()
|
123 |
-
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
124 |
-
self.fc = nn.Sequential(
|
125 |
-
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
|
126 |
-
nn.Sigmoid())
|
127 |
-
|
128 |
-
def forward(self, x):
|
129 |
-
b, c, _, _ = x.size()
|
130 |
-
y = self.avg_pool(x).view(b, c)
|
131 |
-
y = self.fc(y).view(b, c, 1, 1)
|
132 |
-
return x * y
|
133 |
-
|
134 |
-
|
135 |
-
@ARCH_REGISTRY.register()
|
136 |
-
class ResNetArcFace(nn.Module):
|
137 |
-
|
138 |
-
def __init__(self, block, layers, use_se=True):
|
139 |
-
if block == 'IRBlock':
|
140 |
-
block = IRBlock
|
141 |
-
self.inplanes = 64
|
142 |
-
self.use_se = use_se
|
143 |
-
super(ResNetArcFace, self).__init__()
|
144 |
-
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
145 |
-
self.bn1 = nn.BatchNorm2d(64)
|
146 |
-
self.prelu = nn.PReLU()
|
147 |
-
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
|
148 |
-
self.layer1 = self._make_layer(block, 64, layers[0])
|
149 |
-
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
150 |
-
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
151 |
-
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
152 |
-
self.bn4 = nn.BatchNorm2d(512)
|
153 |
-
self.dropout = nn.Dropout()
|
154 |
-
self.fc5 = nn.Linear(512 * 8 * 8, 512)
|
155 |
-
self.bn5 = nn.BatchNorm1d(512)
|
156 |
-
|
157 |
-
for m in self.modules():
|
158 |
-
if isinstance(m, nn.Conv2d):
|
159 |
-
nn.init.xavier_normal_(m.weight)
|
160 |
-
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
161 |
-
nn.init.constant_(m.weight, 1)
|
162 |
-
nn.init.constant_(m.bias, 0)
|
163 |
-
elif isinstance(m, nn.Linear):
|
164 |
-
nn.init.xavier_normal_(m.weight)
|
165 |
-
nn.init.constant_(m.bias, 0)
|
166 |
-
|
167 |
-
def _make_layer(self, block, planes, blocks, stride=1):
|
168 |
-
downsample = None
|
169 |
-
if stride != 1 or self.inplanes != planes * block.expansion:
|
170 |
-
downsample = nn.Sequential(
|
171 |
-
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
172 |
-
nn.BatchNorm2d(planes * block.expansion),
|
173 |
-
)
|
174 |
-
layers = []
|
175 |
-
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
|
176 |
-
self.inplanes = planes
|
177 |
-
for _ in range(1, blocks):
|
178 |
-
layers.append(block(self.inplanes, planes, use_se=self.use_se))
|
179 |
-
|
180 |
-
return nn.Sequential(*layers)
|
181 |
-
|
182 |
-
def forward(self, x):
|
183 |
-
x = self.conv1(x)
|
184 |
-
x = self.bn1(x)
|
185 |
-
x = self.prelu(x)
|
186 |
-
x = self.maxpool(x)
|
187 |
-
|
188 |
-
x = self.layer1(x)
|
189 |
-
x = self.layer2(x)
|
190 |
-
x = self.layer3(x)
|
191 |
-
x = self.layer4(x)
|
192 |
-
x = self.bn4(x)
|
193 |
-
x = self.dropout(x)
|
194 |
-
x = x.view(x.size(0), -1)
|
195 |
-
x = self.fc5(x)
|
196 |
-
x = self.bn5(x)
|
197 |
-
|
198 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
archs/gfpganv1_arch.py
DELETED
@@ -1,418 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import random
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn import functional as F
|
6 |
-
|
7 |
-
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
8 |
-
StyleGAN2Generator)
|
9 |
-
from basicsr.ops.fused_act import FusedLeakyReLU
|
10 |
-
from basicsr.utils.registry import ARCH_REGISTRY
|
11 |
-
|
12 |
-
|
13 |
-
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
14 |
-
"""StyleGAN2 Generator.
|
15 |
-
|
16 |
-
Args:
|
17 |
-
out_size (int): The spatial size of outputs.
|
18 |
-
num_style_feat (int): Channel number of style features. Default: 512.
|
19 |
-
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
20 |
-
channel_multiplier (int): Channel multiplier for large networks of
|
21 |
-
StyleGAN2. Default: 2.
|
22 |
-
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
23 |
-
magnitude. A cross production will be applied to extent 1D resample
|
24 |
-
kenrel to 2D resample kernel. Default: [1, 3, 3, 1].
|
25 |
-
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
26 |
-
"""
|
27 |
-
|
28 |
-
def __init__(self,
|
29 |
-
out_size,
|
30 |
-
num_style_feat=512,
|
31 |
-
num_mlp=8,
|
32 |
-
channel_multiplier=2,
|
33 |
-
resample_kernel=(1, 3, 3, 1),
|
34 |
-
lr_mlp=0.01,
|
35 |
-
narrow=1,
|
36 |
-
sft_half=False):
|
37 |
-
super(StyleGAN2GeneratorSFT, self).__init__(
|
38 |
-
out_size,
|
39 |
-
num_style_feat=num_style_feat,
|
40 |
-
num_mlp=num_mlp,
|
41 |
-
channel_multiplier=channel_multiplier,
|
42 |
-
resample_kernel=resample_kernel,
|
43 |
-
lr_mlp=lr_mlp,
|
44 |
-
narrow=narrow)
|
45 |
-
self.sft_half = sft_half
|
46 |
-
|
47 |
-
def forward(self,
|
48 |
-
styles,
|
49 |
-
conditions,
|
50 |
-
input_is_latent=False,
|
51 |
-
noise=None,
|
52 |
-
randomize_noise=True,
|
53 |
-
truncation=1,
|
54 |
-
truncation_latent=None,
|
55 |
-
inject_index=None,
|
56 |
-
return_latents=False):
|
57 |
-
"""Forward function for StyleGAN2Generator.
|
58 |
-
|
59 |
-
Args:
|
60 |
-
styles (list[Tensor]): Sample codes of styles.
|
61 |
-
input_is_latent (bool): Whether input is latent style.
|
62 |
-
Default: False.
|
63 |
-
noise (Tensor | None): Input noise or None. Default: None.
|
64 |
-
randomize_noise (bool): Randomize noise, used when 'noise' is
|
65 |
-
False. Default: True.
|
66 |
-
truncation (float): TODO. Default: 1.
|
67 |
-
truncation_latent (Tensor | None): TODO. Default: None.
|
68 |
-
inject_index (int | None): The injection index for mixing noise.
|
69 |
-
Default: None.
|
70 |
-
return_latents (bool): Whether to return style latents.
|
71 |
-
Default: False.
|
72 |
-
"""
|
73 |
-
# style codes -> latents with Style MLP layer
|
74 |
-
if not input_is_latent:
|
75 |
-
styles = [self.style_mlp(s) for s in styles]
|
76 |
-
# noises
|
77 |
-
if noise is None:
|
78 |
-
if randomize_noise:
|
79 |
-
noise = [None] * self.num_layers # for each style conv layer
|
80 |
-
else: # use the stored noise
|
81 |
-
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
82 |
-
# style truncation
|
83 |
-
if truncation < 1:
|
84 |
-
style_truncation = []
|
85 |
-
for style in styles:
|
86 |
-
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
87 |
-
styles = style_truncation
|
88 |
-
# get style latent with injection
|
89 |
-
if len(styles) == 1:
|
90 |
-
inject_index = self.num_latent
|
91 |
-
|
92 |
-
if styles[0].ndim < 3:
|
93 |
-
# repeat latent code for all the layers
|
94 |
-
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
95 |
-
else: # used for encoder with different latent code for each layer
|
96 |
-
latent = styles[0]
|
97 |
-
elif len(styles) == 2: # mixing noises
|
98 |
-
if inject_index is None:
|
99 |
-
inject_index = random.randint(1, self.num_latent - 1)
|
100 |
-
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
101 |
-
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
102 |
-
latent = torch.cat([latent1, latent2], 1)
|
103 |
-
|
104 |
-
# main generation
|
105 |
-
out = self.constant_input(latent.shape[0])
|
106 |
-
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
107 |
-
skip = self.to_rgb1(out, latent[:, 1])
|
108 |
-
|
109 |
-
i = 1
|
110 |
-
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
111 |
-
noise[2::2], self.to_rgbs):
|
112 |
-
out = conv1(out, latent[:, i], noise=noise1)
|
113 |
-
|
114 |
-
# the conditions may have fewer levels
|
115 |
-
if i < len(conditions):
|
116 |
-
# SFT part to combine the conditions
|
117 |
-
if self.sft_half:
|
118 |
-
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
119 |
-
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
120 |
-
out = torch.cat([out_same, out_sft], dim=1)
|
121 |
-
else:
|
122 |
-
out = out * conditions[i - 1] + conditions[i]
|
123 |
-
|
124 |
-
out = conv2(out, latent[:, i + 1], noise=noise2)
|
125 |
-
skip = to_rgb(out, latent[:, i + 2], skip)
|
126 |
-
i += 2
|
127 |
-
|
128 |
-
image = skip
|
129 |
-
|
130 |
-
if return_latents:
|
131 |
-
return image, latent
|
132 |
-
else:
|
133 |
-
return image, None
|
134 |
-
|
135 |
-
|
136 |
-
class ConvUpLayer(nn.Module):
|
137 |
-
"""Conv Up Layer. Bilinear upsample + Conv.
|
138 |
-
|
139 |
-
Args:
|
140 |
-
in_channels (int): Channel number of the input.
|
141 |
-
out_channels (int): Channel number of the output.
|
142 |
-
kernel_size (int): Size of the convolving kernel.
|
143 |
-
stride (int): Stride of the convolution. Default: 1
|
144 |
-
padding (int): Zero-padding added to both sides of the input.
|
145 |
-
Default: 0.
|
146 |
-
bias (bool): If ``True``, adds a learnable bias to the output.
|
147 |
-
Default: ``True``.
|
148 |
-
bias_init_val (float): Bias initialized value. Default: 0.
|
149 |
-
activate (bool): Whether use activateion. Default: True.
|
150 |
-
"""
|
151 |
-
|
152 |
-
def __init__(self,
|
153 |
-
in_channels,
|
154 |
-
out_channels,
|
155 |
-
kernel_size,
|
156 |
-
stride=1,
|
157 |
-
padding=0,
|
158 |
-
bias=True,
|
159 |
-
bias_init_val=0,
|
160 |
-
activate=True):
|
161 |
-
super(ConvUpLayer, self).__init__()
|
162 |
-
self.in_channels = in_channels
|
163 |
-
self.out_channels = out_channels
|
164 |
-
self.kernel_size = kernel_size
|
165 |
-
self.stride = stride
|
166 |
-
self.padding = padding
|
167 |
-
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
168 |
-
|
169 |
-
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
170 |
-
|
171 |
-
if bias and not activate:
|
172 |
-
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
173 |
-
else:
|
174 |
-
self.register_parameter('bias', None)
|
175 |
-
|
176 |
-
# activation
|
177 |
-
if activate:
|
178 |
-
if bias:
|
179 |
-
self.activation = FusedLeakyReLU(out_channels)
|
180 |
-
else:
|
181 |
-
self.activation = ScaledLeakyReLU(0.2)
|
182 |
-
else:
|
183 |
-
self.activation = None
|
184 |
-
|
185 |
-
def forward(self, x):
|
186 |
-
# bilinear upsample
|
187 |
-
out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
|
188 |
-
# conv
|
189 |
-
out = F.conv2d(
|
190 |
-
out,
|
191 |
-
self.weight * self.scale,
|
192 |
-
bias=self.bias,
|
193 |
-
stride=self.stride,
|
194 |
-
padding=self.padding,
|
195 |
-
)
|
196 |
-
# activation
|
197 |
-
if self.activation is not None:
|
198 |
-
out = self.activation(out)
|
199 |
-
return out
|
200 |
-
|
201 |
-
|
202 |
-
class ResUpBlock(nn.Module):
|
203 |
-
"""Residual block with upsampling.
|
204 |
-
|
205 |
-
Args:
|
206 |
-
in_channels (int): Channel number of the input.
|
207 |
-
out_channels (int): Channel number of the output.
|
208 |
-
"""
|
209 |
-
|
210 |
-
def __init__(self, in_channels, out_channels):
|
211 |
-
super(ResUpBlock, self).__init__()
|
212 |
-
|
213 |
-
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
214 |
-
self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
|
215 |
-
self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
|
216 |
-
|
217 |
-
def forward(self, x):
|
218 |
-
out = self.conv1(x)
|
219 |
-
out = self.conv2(out)
|
220 |
-
skip = self.skip(x)
|
221 |
-
out = (out + skip) / math.sqrt(2)
|
222 |
-
return out
|
223 |
-
|
224 |
-
|
225 |
-
@ARCH_REGISTRY.register()
|
226 |
-
class GFPGANv1(nn.Module):
|
227 |
-
"""Unet + StyleGAN2 decoder with SFT."""
|
228 |
-
|
229 |
-
def __init__(
|
230 |
-
self,
|
231 |
-
out_size,
|
232 |
-
num_style_feat=512,
|
233 |
-
channel_multiplier=1,
|
234 |
-
resample_kernel=(1, 3, 3, 1),
|
235 |
-
decoder_load_path=None,
|
236 |
-
fix_decoder=True,
|
237 |
-
# for stylegan decoder
|
238 |
-
num_mlp=8,
|
239 |
-
lr_mlp=0.01,
|
240 |
-
input_is_latent=False,
|
241 |
-
different_w=False,
|
242 |
-
narrow=1,
|
243 |
-
sft_half=False):
|
244 |
-
|
245 |
-
super(GFPGANv1, self).__init__()
|
246 |
-
self.input_is_latent = input_is_latent
|
247 |
-
self.different_w = different_w
|
248 |
-
self.num_style_feat = num_style_feat
|
249 |
-
|
250 |
-
unet_narrow = narrow * 0.5
|
251 |
-
channels = {
|
252 |
-
'4': int(512 * unet_narrow),
|
253 |
-
'8': int(512 * unet_narrow),
|
254 |
-
'16': int(512 * unet_narrow),
|
255 |
-
'32': int(512 * unet_narrow),
|
256 |
-
'64': int(256 * channel_multiplier * unet_narrow),
|
257 |
-
'128': int(128 * channel_multiplier * unet_narrow),
|
258 |
-
'256': int(64 * channel_multiplier * unet_narrow),
|
259 |
-
'512': int(32 * channel_multiplier * unet_narrow),
|
260 |
-
'1024': int(16 * channel_multiplier * unet_narrow)
|
261 |
-
}
|
262 |
-
|
263 |
-
self.log_size = int(math.log(out_size, 2))
|
264 |
-
first_out_size = 2**(int(math.log(out_size, 2)))
|
265 |
-
|
266 |
-
self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
|
267 |
-
|
268 |
-
# downsample
|
269 |
-
in_channels = channels[f'{first_out_size}']
|
270 |
-
self.conv_body_down = nn.ModuleList()
|
271 |
-
for i in range(self.log_size, 2, -1):
|
272 |
-
out_channels = channels[f'{2**(i - 1)}']
|
273 |
-
self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
|
274 |
-
in_channels = out_channels
|
275 |
-
|
276 |
-
self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
|
277 |
-
|
278 |
-
# upsample
|
279 |
-
in_channels = channels['4']
|
280 |
-
self.conv_body_up = nn.ModuleList()
|
281 |
-
for i in range(3, self.log_size + 1):
|
282 |
-
out_channels = channels[f'{2**i}']
|
283 |
-
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
|
284 |
-
in_channels = out_channels
|
285 |
-
|
286 |
-
# to RGB
|
287 |
-
self.toRGB = nn.ModuleList()
|
288 |
-
for i in range(3, self.log_size + 1):
|
289 |
-
self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
|
290 |
-
|
291 |
-
if different_w:
|
292 |
-
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
293 |
-
else:
|
294 |
-
linear_out_channel = num_style_feat
|
295 |
-
|
296 |
-
self.final_linear = EqualLinear(
|
297 |
-
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
|
298 |
-
|
299 |
-
self.stylegan_decoder = StyleGAN2GeneratorSFT(
|
300 |
-
out_size=out_size,
|
301 |
-
num_style_feat=num_style_feat,
|
302 |
-
num_mlp=num_mlp,
|
303 |
-
channel_multiplier=channel_multiplier,
|
304 |
-
resample_kernel=resample_kernel,
|
305 |
-
lr_mlp=lr_mlp,
|
306 |
-
narrow=narrow,
|
307 |
-
sft_half=sft_half)
|
308 |
-
|
309 |
-
if decoder_load_path:
|
310 |
-
self.stylegan_decoder.load_state_dict(
|
311 |
-
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
312 |
-
if fix_decoder:
|
313 |
-
for _, param in self.stylegan_decoder.named_parameters():
|
314 |
-
param.requires_grad = False
|
315 |
-
|
316 |
-
# for SFT
|
317 |
-
self.condition_scale = nn.ModuleList()
|
318 |
-
self.condition_shift = nn.ModuleList()
|
319 |
-
for i in range(3, self.log_size + 1):
|
320 |
-
out_channels = channels[f'{2**i}']
|
321 |
-
if sft_half:
|
322 |
-
sft_out_channels = out_channels
|
323 |
-
else:
|
324 |
-
sft_out_channels = out_channels * 2
|
325 |
-
self.condition_scale.append(
|
326 |
-
nn.Sequential(
|
327 |
-
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
328 |
-
ScaledLeakyReLU(0.2),
|
329 |
-
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
|
330 |
-
self.condition_shift.append(
|
331 |
-
nn.Sequential(
|
332 |
-
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
333 |
-
ScaledLeakyReLU(0.2),
|
334 |
-
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
|
335 |
-
|
336 |
-
def forward(self,
|
337 |
-
x,
|
338 |
-
return_latents=False,
|
339 |
-
save_feat_path=None,
|
340 |
-
load_feat_path=None,
|
341 |
-
return_rgb=True,
|
342 |
-
randomize_noise=True):
|
343 |
-
conditions = []
|
344 |
-
unet_skips = []
|
345 |
-
out_rgbs = []
|
346 |
-
|
347 |
-
# encoder
|
348 |
-
feat = self.conv_body_first(x)
|
349 |
-
for i in range(self.log_size - 2):
|
350 |
-
feat = self.conv_body_down[i](feat)
|
351 |
-
unet_skips.insert(0, feat)
|
352 |
-
|
353 |
-
feat = self.final_conv(feat)
|
354 |
-
|
355 |
-
# style code
|
356 |
-
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
357 |
-
if self.different_w:
|
358 |
-
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
359 |
-
|
360 |
-
# decode
|
361 |
-
for i in range(self.log_size - 2):
|
362 |
-
# add unet skip
|
363 |
-
feat = feat + unet_skips[i]
|
364 |
-
# ResUpLayer
|
365 |
-
feat = self.conv_body_up[i](feat)
|
366 |
-
# generate scale and shift for SFT layer
|
367 |
-
scale = self.condition_scale[i](feat)
|
368 |
-
conditions.append(scale.clone())
|
369 |
-
shift = self.condition_shift[i](feat)
|
370 |
-
conditions.append(shift.clone())
|
371 |
-
# generate rgb images
|
372 |
-
if return_rgb:
|
373 |
-
out_rgbs.append(self.toRGB[i](feat))
|
374 |
-
|
375 |
-
if save_feat_path is not None:
|
376 |
-
torch.save(conditions, save_feat_path)
|
377 |
-
if load_feat_path is not None:
|
378 |
-
conditions = torch.load(load_feat_path)
|
379 |
-
conditions = [v.cuda() for v in conditions]
|
380 |
-
|
381 |
-
# decoder
|
382 |
-
image, _ = self.stylegan_decoder([style_code],
|
383 |
-
conditions,
|
384 |
-
return_latents=return_latents,
|
385 |
-
input_is_latent=self.input_is_latent,
|
386 |
-
randomize_noise=randomize_noise)
|
387 |
-
|
388 |
-
return image, out_rgbs
|
389 |
-
|
390 |
-
|
391 |
-
@ARCH_REGISTRY.register()
|
392 |
-
class FacialComponentDiscriminator(nn.Module):
|
393 |
-
|
394 |
-
def __init__(self):
|
395 |
-
super(FacialComponentDiscriminator, self).__init__()
|
396 |
-
|
397 |
-
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
398 |
-
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
399 |
-
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
400 |
-
self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
401 |
-
self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
402 |
-
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
|
403 |
-
|
404 |
-
def forward(self, x, return_feats=False):
|
405 |
-
feat = self.conv1(x)
|
406 |
-
feat = self.conv3(self.conv2(feat))
|
407 |
-
rlt_feats = []
|
408 |
-
if return_feats:
|
409 |
-
rlt_feats.append(feat.clone())
|
410 |
-
feat = self.conv5(self.conv4(feat))
|
411 |
-
if return_feats:
|
412 |
-
rlt_feats.append(feat.clone())
|
413 |
-
out = self.final_conv(feat)
|
414 |
-
|
415 |
-
if return_feats:
|
416 |
-
return out, rlt_feats
|
417 |
-
else:
|
418 |
-
return out, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/__init__.py
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
import importlib
|
2 |
-
from os import path as osp
|
3 |
-
|
4 |
-
from basicsr.utils import scandir
|
5 |
-
|
6 |
-
# automatically scan and import dataset modules for registry
|
7 |
-
# scan all the files under the data folder with '_dataset' in file names
|
8 |
-
data_folder = osp.dirname(osp.abspath(__file__))
|
9 |
-
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
10 |
-
# import all the dataset modules
|
11 |
-
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/ffhq_degradation_dataset.py
DELETED
@@ -1,213 +0,0 @@
|
|
1 |
-
import cv2
|
2 |
-
import math
|
3 |
-
import numpy as np
|
4 |
-
import os.path as osp
|
5 |
-
import torch
|
6 |
-
import torch.utils.data as data
|
7 |
-
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
8 |
-
normalize)
|
9 |
-
|
10 |
-
from basicsr.data import degradations as degradations
|
11 |
-
from basicsr.data.data_util import paths_from_folder
|
12 |
-
from basicsr.data.transforms import augment
|
13 |
-
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
14 |
-
from basicsr.utils.registry import DATASET_REGISTRY
|
15 |
-
|
16 |
-
|
17 |
-
@DATASET_REGISTRY.register()
|
18 |
-
class FFHQDegradationDataset(data.Dataset):
|
19 |
-
|
20 |
-
def __init__(self, opt):
|
21 |
-
super(FFHQDegradationDataset, self).__init__()
|
22 |
-
self.opt = opt
|
23 |
-
# file client (io backend)
|
24 |
-
self.file_client = None
|
25 |
-
self.io_backend_opt = opt['io_backend']
|
26 |
-
|
27 |
-
self.gt_folder = opt['dataroot_gt']
|
28 |
-
self.mean = opt['mean']
|
29 |
-
self.std = opt['std']
|
30 |
-
self.out_size = opt['out_size']
|
31 |
-
|
32 |
-
self.crop_components = opt.get('crop_components', False) # facial components
|
33 |
-
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
|
34 |
-
|
35 |
-
if self.crop_components:
|
36 |
-
self.components_list = torch.load(opt.get('component_path'))
|
37 |
-
|
38 |
-
if self.io_backend_opt['type'] == 'lmdb':
|
39 |
-
self.io_backend_opt['db_paths'] = self.gt_folder
|
40 |
-
if not self.gt_folder.endswith('.lmdb'):
|
41 |
-
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
|
42 |
-
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
43 |
-
self.paths = [line.split('.')[0] for line in fin]
|
44 |
-
else:
|
45 |
-
self.paths = paths_from_folder(self.gt_folder)
|
46 |
-
|
47 |
-
# degradations
|
48 |
-
self.blur_kernel_size = opt['blur_kernel_size']
|
49 |
-
self.kernel_list = opt['kernel_list']
|
50 |
-
self.kernel_prob = opt['kernel_prob']
|
51 |
-
self.blur_sigma = opt['blur_sigma']
|
52 |
-
self.downsample_range = opt['downsample_range']
|
53 |
-
self.noise_range = opt['noise_range']
|
54 |
-
self.jpeg_range = opt['jpeg_range']
|
55 |
-
|
56 |
-
# color jitter
|
57 |
-
self.color_jitter_prob = opt.get('color_jitter_prob')
|
58 |
-
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
|
59 |
-
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
|
60 |
-
# to gray
|
61 |
-
self.gray_prob = opt.get('gray_prob')
|
62 |
-
|
63 |
-
logger = get_root_logger()
|
64 |
-
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
|
65 |
-
f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
66 |
-
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
|
67 |
-
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
|
68 |
-
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
|
69 |
-
|
70 |
-
if self.color_jitter_prob is not None:
|
71 |
-
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
|
72 |
-
f'shift: {self.color_jitter_shift}')
|
73 |
-
if self.gray_prob is not None:
|
74 |
-
logger.info(f'Use random gray. Prob: {self.gray_prob}')
|
75 |
-
|
76 |
-
self.color_jitter_shift /= 255.
|
77 |
-
|
78 |
-
@staticmethod
|
79 |
-
def color_jitter(img, shift):
|
80 |
-
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
81 |
-
img = img + jitter_val
|
82 |
-
img = np.clip(img, 0, 1)
|
83 |
-
return img
|
84 |
-
|
85 |
-
@staticmethod
|
86 |
-
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
87 |
-
fn_idx = torch.randperm(4)
|
88 |
-
for fn_id in fn_idx:
|
89 |
-
if fn_id == 0 and brightness is not None:
|
90 |
-
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
91 |
-
img = adjust_brightness(img, brightness_factor)
|
92 |
-
|
93 |
-
if fn_id == 1 and contrast is not None:
|
94 |
-
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
95 |
-
img = adjust_contrast(img, contrast_factor)
|
96 |
-
|
97 |
-
if fn_id == 2 and saturation is not None:
|
98 |
-
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
99 |
-
img = adjust_saturation(img, saturation_factor)
|
100 |
-
|
101 |
-
if fn_id == 3 and hue is not None:
|
102 |
-
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
103 |
-
img = adjust_hue(img, hue_factor)
|
104 |
-
return img
|
105 |
-
|
106 |
-
def get_component_coordinates(self, index, status):
|
107 |
-
components_bbox = self.components_list[f'{index:08d}']
|
108 |
-
if status[0]: # hflip
|
109 |
-
# exchange right and left eye
|
110 |
-
tmp = components_bbox['left_eye']
|
111 |
-
components_bbox['left_eye'] = components_bbox['right_eye']
|
112 |
-
components_bbox['right_eye'] = tmp
|
113 |
-
# modify the width coordinate
|
114 |
-
components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
|
115 |
-
components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
|
116 |
-
components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
|
117 |
-
|
118 |
-
# get coordinates
|
119 |
-
locations = []
|
120 |
-
for part in ['left_eye', 'right_eye', 'mouth']:
|
121 |
-
mean = components_bbox[part][0:2]
|
122 |
-
half_len = components_bbox[part][2]
|
123 |
-
if 'eye' in part:
|
124 |
-
half_len *= self.eye_enlarge_ratio
|
125 |
-
loc = np.hstack((mean - half_len + 1, mean + half_len))
|
126 |
-
loc = torch.from_numpy(loc).float()
|
127 |
-
locations.append(loc)
|
128 |
-
return locations
|
129 |
-
|
130 |
-
def __getitem__(self, index):
|
131 |
-
if self.file_client is None:
|
132 |
-
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
133 |
-
|
134 |
-
# load gt image
|
135 |
-
gt_path = self.paths[index]
|
136 |
-
img_bytes = self.file_client.get(gt_path)
|
137 |
-
img_gt = imfrombytes(img_bytes, float32=True)
|
138 |
-
|
139 |
-
# random horizontal flip
|
140 |
-
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
|
141 |
-
h, w, _ = img_gt.shape
|
142 |
-
|
143 |
-
if self.crop_components:
|
144 |
-
locations = self.get_component_coordinates(index, status)
|
145 |
-
loc_left_eye, loc_right_eye, loc_mouth = locations
|
146 |
-
|
147 |
-
# ------------------------ generate lq image ------------------------ #
|
148 |
-
# blur
|
149 |
-
kernel = degradations.random_mixed_kernels(
|
150 |
-
self.kernel_list,
|
151 |
-
self.kernel_prob,
|
152 |
-
self.blur_kernel_size,
|
153 |
-
self.blur_sigma,
|
154 |
-
self.blur_sigma, [-math.pi, math.pi],
|
155 |
-
noise_range=None)
|
156 |
-
img_lq = cv2.filter2D(img_gt, -1, kernel)
|
157 |
-
# downsample
|
158 |
-
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
|
159 |
-
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
|
160 |
-
# noise
|
161 |
-
if self.noise_range is not None:
|
162 |
-
img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
|
163 |
-
# jpeg compression
|
164 |
-
if self.jpeg_range is not None:
|
165 |
-
img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
|
166 |
-
|
167 |
-
# resize to original size
|
168 |
-
img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
|
169 |
-
|
170 |
-
# random color jitter (only for lq)
|
171 |
-
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
|
172 |
-
img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
|
173 |
-
# random to gray (only for lq)
|
174 |
-
if self.gray_prob and np.random.uniform() < self.gray_prob:
|
175 |
-
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
|
176 |
-
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
|
177 |
-
if self.opt.get('gt_gray'):
|
178 |
-
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
179 |
-
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
|
180 |
-
|
181 |
-
# BGR to RGB, HWC to CHW, numpy to tensor
|
182 |
-
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
183 |
-
|
184 |
-
# random color jitter (pytorch version) (only for lq)
|
185 |
-
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
|
186 |
-
brightness = self.opt.get('brightness', (0.5, 1.5))
|
187 |
-
contrast = self.opt.get('contrast', (0.5, 1.5))
|
188 |
-
saturation = self.opt.get('saturation', (0, 1.5))
|
189 |
-
hue = self.opt.get('hue', (-0.1, 0.1))
|
190 |
-
img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
|
191 |
-
|
192 |
-
# round and clip
|
193 |
-
img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
|
194 |
-
|
195 |
-
# normalize
|
196 |
-
normalize(img_gt, self.mean, self.std, inplace=True)
|
197 |
-
normalize(img_lq, self.mean, self.std, inplace=True)
|
198 |
-
|
199 |
-
if self.crop_components:
|
200 |
-
return_dict = {
|
201 |
-
'lq': img_lq,
|
202 |
-
'gt': img_gt,
|
203 |
-
'gt_path': gt_path,
|
204 |
-
'loc_left_eye': loc_left_eye,
|
205 |
-
'loc_right_eye': loc_right_eye,
|
206 |
-
'loc_mouth': loc_mouth
|
207 |
-
}
|
208 |
-
return return_dict
|
209 |
-
else:
|
210 |
-
return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
|
211 |
-
|
212 |
-
def __len__(self):
|
213 |
-
return len(self.paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/pretrained_models/GFPGANv1.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:6db3a33dd00dd427b8a70a7e3c6244a5bcccb818736f4861ce1d609024a991de
|
3 |
-
size 615378983
|
|
|
|
|
|
|
|
experiments/pretrained_models/README.md
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
# Pre-trained Models and Other Data
|
2 |
-
|
3 |
-
Download pre-trained models and other data. Put them in this folder.
|
4 |
-
|
5 |
-
1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
|
6 |
-
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
|
7 |
-
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference_gfpgan_full.py
DELETED
@@ -1,130 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import cv2
|
3 |
-
import glob
|
4 |
-
import numpy as np
|
5 |
-
import os
|
6 |
-
import torch
|
7 |
-
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
8 |
-
from torchvision.transforms.functional import normalize
|
9 |
-
|
10 |
-
from archs.gfpganv1_arch import GFPGANv1
|
11 |
-
from basicsr.utils import img2tensor, imwrite, tensor2img
|
12 |
-
|
13 |
-
|
14 |
-
def restoration(gfpgan,
|
15 |
-
face_helper,
|
16 |
-
img_path,
|
17 |
-
save_root,
|
18 |
-
has_aligned=False,
|
19 |
-
only_center_face=True,
|
20 |
-
suffix=None,
|
21 |
-
paste_back=False):
|
22 |
-
# read image
|
23 |
-
img_name = os.path.basename(img_path)
|
24 |
-
print(f'Processing {img_name} ...')
|
25 |
-
basename, _ = os.path.splitext(img_name)
|
26 |
-
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
27 |
-
face_helper.clean_all()
|
28 |
-
|
29 |
-
if has_aligned:
|
30 |
-
input_img = cv2.resize(input_img, (512, 512))
|
31 |
-
face_helper.cropped_faces = [input_img]
|
32 |
-
else:
|
33 |
-
face_helper.read_image(input_img)
|
34 |
-
# get face landmarks for each face
|
35 |
-
face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
|
36 |
-
# align and warp each face
|
37 |
-
save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
|
38 |
-
face_helper.align_warp_face(save_crop_path)
|
39 |
-
|
40 |
-
# face restoration
|
41 |
-
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
42 |
-
# prepare data
|
43 |
-
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
44 |
-
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
45 |
-
cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')
|
46 |
-
|
47 |
-
try:
|
48 |
-
with torch.no_grad():
|
49 |
-
output = gfpgan(cropped_face_t, return_rgb=False)[0]
|
50 |
-
# convert to image
|
51 |
-
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
52 |
-
except RuntimeError as error:
|
53 |
-
print(f'\tFailed inference for GFPGAN: {error}.')
|
54 |
-
restored_face = cropped_face
|
55 |
-
|
56 |
-
restored_face = restored_face.astype('uint8')
|
57 |
-
face_helper.add_restored_face(restored_face)
|
58 |
-
|
59 |
-
if suffix is not None:
|
60 |
-
save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
|
61 |
-
else:
|
62 |
-
save_face_name = f'{basename}_{idx:02d}.png'
|
63 |
-
save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
|
64 |
-
imwrite(restored_face, save_restore_path)
|
65 |
-
|
66 |
-
# save cmp image
|
67 |
-
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
|
68 |
-
imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
|
69 |
-
|
70 |
-
if not has_aligned and paste_back:
|
71 |
-
face_helper.get_inverse_affine(None)
|
72 |
-
save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
|
73 |
-
# paste each restored face to the input image
|
74 |
-
face_helper.paste_faces_to_input_image(save_restore_path)
|
75 |
-
|
76 |
-
|
77 |
-
if __name__ == '__main__':
|
78 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
79 |
-
parser = argparse.ArgumentParser()
|
80 |
-
|
81 |
-
parser.add_argument('--upscale_factor', type=int, default=1)
|
82 |
-
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth')
|
83 |
-
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
|
84 |
-
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
|
85 |
-
parser.add_argument('--only_center_face', action='store_true')
|
86 |
-
parser.add_argument('--aligned', action='store_true')
|
87 |
-
parser.add_argument('--paste_back', action='store_true')
|
88 |
-
|
89 |
-
args = parser.parse_args()
|
90 |
-
if args.test_path.endswith('/'):
|
91 |
-
args.test_path = args.test_path[:-1]
|
92 |
-
save_root = 'results/'
|
93 |
-
os.makedirs(save_root, exist_ok=True)
|
94 |
-
|
95 |
-
# initialize the GFP-GAN
|
96 |
-
gfpgan = GFPGANv1(
|
97 |
-
out_size=512,
|
98 |
-
num_style_feat=512,
|
99 |
-
channel_multiplier=1,
|
100 |
-
decoder_load_path=None,
|
101 |
-
fix_decoder=True,
|
102 |
-
# for stylegan decoder
|
103 |
-
num_mlp=8,
|
104 |
-
input_is_latent=True,
|
105 |
-
different_w=True,
|
106 |
-
narrow=1,
|
107 |
-
sft_half=True)
|
108 |
-
|
109 |
-
gfpgan.to(device)
|
110 |
-
checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
|
111 |
-
gfpgan.load_state_dict(checkpoint['params_ema'])
|
112 |
-
gfpgan.eval()
|
113 |
-
|
114 |
-
# initialize face helper
|
115 |
-
face_helper = FaceRestoreHelper(
|
116 |
-
args.upscale_factor, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
|
117 |
-
|
118 |
-
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
|
119 |
-
for img_path in img_list:
|
120 |
-
restoration(
|
121 |
-
gfpgan,
|
122 |
-
face_helper,
|
123 |
-
img_path,
|
124 |
-
save_root,
|
125 |
-
has_aligned=args.aligned,
|
126 |
-
only_center_face=args.only_center_face,
|
127 |
-
suffix=args.suffix,
|
128 |
-
paste_back=args.paste_back)
|
129 |
-
|
130 |
-
print('Results are in the <results> folder.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/__init__.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
import importlib
|
2 |
-
from os import path as osp
|
3 |
-
|
4 |
-
from basicsr.utils import scandir
|
5 |
-
|
6 |
-
# automatically scan and import model modules for registry
|
7 |
-
# scan all the files under the 'models' folder and collect files ending with
|
8 |
-
# '_model.py'
|
9 |
-
model_folder = osp.dirname(osp.abspath(__file__))
|
10 |
-
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
11 |
-
# import all the model modules
|
12 |
-
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/gfpgan_model.py
DELETED
@@ -1,562 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import os.path as osp
|
3 |
-
import torch
|
4 |
-
from collections import OrderedDict
|
5 |
-
from torch.nn import functional as F
|
6 |
-
from torchvision.ops import roi_align
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
from basicsr.archs import build_network
|
10 |
-
from basicsr.losses import build_loss
|
11 |
-
from basicsr.losses.losses import r1_penalty
|
12 |
-
from basicsr.metrics import calculate_metric
|
13 |
-
from basicsr.models.base_model import BaseModel
|
14 |
-
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
15 |
-
from basicsr.utils.registry import MODEL_REGISTRY
|
16 |
-
|
17 |
-
|
18 |
-
@MODEL_REGISTRY.register()
|
19 |
-
class GFPGANModel(BaseModel):
|
20 |
-
"""GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
|
21 |
-
|
22 |
-
def __init__(self, opt):
|
23 |
-
super(GFPGANModel, self).__init__(opt)
|
24 |
-
self.idx = 0
|
25 |
-
|
26 |
-
# define network
|
27 |
-
self.net_g = build_network(opt['network_g'])
|
28 |
-
self.net_g = self.model_to_device(self.net_g)
|
29 |
-
self.print_network(self.net_g)
|
30 |
-
|
31 |
-
# load pretrained model
|
32 |
-
load_path = self.opt['path'].get('pretrain_network_g', None)
|
33 |
-
if load_path is not None:
|
34 |
-
param_key = self.opt['path'].get('param_key_g', 'params')
|
35 |
-
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
|
36 |
-
|
37 |
-
self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
|
38 |
-
|
39 |
-
if self.is_train:
|
40 |
-
self.init_training_settings()
|
41 |
-
|
42 |
-
def init_training_settings(self):
|
43 |
-
train_opt = self.opt['train']
|
44 |
-
|
45 |
-
# ----------- define net_d ----------- #
|
46 |
-
self.net_d = build_network(self.opt['network_d'])
|
47 |
-
self.net_d = self.model_to_device(self.net_d)
|
48 |
-
self.print_network(self.net_d)
|
49 |
-
# load pretrained model
|
50 |
-
load_path = self.opt['path'].get('pretrain_network_d', None)
|
51 |
-
if load_path is not None:
|
52 |
-
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
53 |
-
|
54 |
-
# ----------- define net_g with Exponential Moving Average (EMA) ----------- #
|
55 |
-
# net_g_ema only used for testing on one GPU and saving
|
56 |
-
# There is no need to wrap with DistributedDataParallel
|
57 |
-
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
58 |
-
# load pretrained model
|
59 |
-
load_path = self.opt['path'].get('pretrain_network_g', None)
|
60 |
-
if load_path is not None:
|
61 |
-
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
62 |
-
else:
|
63 |
-
self.model_ema(0) # copy net_g weight
|
64 |
-
|
65 |
-
self.net_g.train()
|
66 |
-
self.net_d.train()
|
67 |
-
self.net_g_ema.eval()
|
68 |
-
|
69 |
-
# ----------- facial components networks ----------- #
|
70 |
-
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
|
71 |
-
self.use_facial_disc = True
|
72 |
-
else:
|
73 |
-
self.use_facial_disc = False
|
74 |
-
|
75 |
-
if self.use_facial_disc:
|
76 |
-
# left eye
|
77 |
-
self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
|
78 |
-
self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
|
79 |
-
self.print_network(self.net_d_left_eye)
|
80 |
-
load_path = self.opt['path'].get('pretrain_network_d_left_eye')
|
81 |
-
if load_path is not None:
|
82 |
-
self.load_network(self.net_d_left_eye, load_path, True, 'params')
|
83 |
-
# right eye
|
84 |
-
self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
|
85 |
-
self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
|
86 |
-
self.print_network(self.net_d_right_eye)
|
87 |
-
load_path = self.opt['path'].get('pretrain_network_d_right_eye')
|
88 |
-
if load_path is not None:
|
89 |
-
self.load_network(self.net_d_right_eye, load_path, True, 'params')
|
90 |
-
# mouth
|
91 |
-
self.net_d_mouth = build_network(self.opt['network_d_mouth'])
|
92 |
-
self.net_d_mouth = self.model_to_device(self.net_d_mouth)
|
93 |
-
self.print_network(self.net_d_mouth)
|
94 |
-
load_path = self.opt['path'].get('pretrain_network_d_mouth')
|
95 |
-
if load_path is not None:
|
96 |
-
self.load_network(self.net_d_mouth, load_path, True, 'params')
|
97 |
-
|
98 |
-
self.net_d_left_eye.train()
|
99 |
-
self.net_d_right_eye.train()
|
100 |
-
self.net_d_mouth.train()
|
101 |
-
|
102 |
-
# ----------- define facial component gan loss ----------- #
|
103 |
-
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
|
104 |
-
|
105 |
-
# ----------- define losses ----------- #
|
106 |
-
if train_opt.get('pixel_opt'):
|
107 |
-
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
108 |
-
else:
|
109 |
-
self.cri_pix = None
|
110 |
-
|
111 |
-
if train_opt.get('perceptual_opt'):
|
112 |
-
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
113 |
-
else:
|
114 |
-
self.cri_perceptual = None
|
115 |
-
|
116 |
-
# L1 loss used in pyramid loss, component style loss and identity loss
|
117 |
-
self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
|
118 |
-
|
119 |
-
# gan loss (wgan)
|
120 |
-
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
|
121 |
-
|
122 |
-
# ----------- define identity loss ----------- #
|
123 |
-
if 'network_identity' in self.opt:
|
124 |
-
self.use_identity = True
|
125 |
-
else:
|
126 |
-
self.use_identity = False
|
127 |
-
|
128 |
-
if self.use_identity:
|
129 |
-
# define identity network
|
130 |
-
self.network_identity = build_network(self.opt['network_identity'])
|
131 |
-
self.network_identity = self.model_to_device(self.network_identity)
|
132 |
-
self.print_network(self.network_identity)
|
133 |
-
load_path = self.opt['path'].get('pretrain_network_identity')
|
134 |
-
if load_path is not None:
|
135 |
-
self.load_network(self.network_identity, load_path, True, None)
|
136 |
-
self.network_identity.eval()
|
137 |
-
for param in self.network_identity.parameters():
|
138 |
-
param.requires_grad = False
|
139 |
-
|
140 |
-
# regularization weights
|
141 |
-
self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
|
142 |
-
self.net_d_iters = train_opt.get('net_d_iters', 1)
|
143 |
-
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
|
144 |
-
self.net_d_reg_every = train_opt['net_d_reg_every']
|
145 |
-
|
146 |
-
# set up optimizers and schedulers
|
147 |
-
self.setup_optimizers()
|
148 |
-
self.setup_schedulers()
|
149 |
-
|
150 |
-
def setup_optimizers(self):
|
151 |
-
train_opt = self.opt['train']
|
152 |
-
|
153 |
-
# ----------- optimizer g ----------- #
|
154 |
-
net_g_reg_ratio = 1
|
155 |
-
normal_params = []
|
156 |
-
for _, param in self.net_g.named_parameters():
|
157 |
-
normal_params.append(param)
|
158 |
-
optim_params_g = [{ # add normal params first
|
159 |
-
'params': normal_params,
|
160 |
-
'lr': train_opt['optim_g']['lr']
|
161 |
-
}]
|
162 |
-
optim_type = train_opt['optim_g'].pop('type')
|
163 |
-
lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
|
164 |
-
betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
|
165 |
-
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
|
166 |
-
self.optimizers.append(self.optimizer_g)
|
167 |
-
|
168 |
-
# ----------- optimizer d ----------- #
|
169 |
-
net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
|
170 |
-
normal_params = []
|
171 |
-
for _, param in self.net_d.named_parameters():
|
172 |
-
normal_params.append(param)
|
173 |
-
optim_params_d = [{ # add normal params first
|
174 |
-
'params': normal_params,
|
175 |
-
'lr': train_opt['optim_d']['lr']
|
176 |
-
}]
|
177 |
-
optim_type = train_opt['optim_d'].pop('type')
|
178 |
-
lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
|
179 |
-
betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
|
180 |
-
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
|
181 |
-
self.optimizers.append(self.optimizer_d)
|
182 |
-
|
183 |
-
if self.use_facial_disc:
|
184 |
-
# setup optimizers for facial component discriminators
|
185 |
-
optim_type = train_opt['optim_component'].pop('type')
|
186 |
-
lr = train_opt['optim_component']['lr']
|
187 |
-
# left eye
|
188 |
-
self.optimizer_d_left_eye = self.get_optimizer(
|
189 |
-
optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
|
190 |
-
self.optimizers.append(self.optimizer_d_left_eye)
|
191 |
-
# right eye
|
192 |
-
self.optimizer_d_right_eye = self.get_optimizer(
|
193 |
-
optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
|
194 |
-
self.optimizers.append(self.optimizer_d_right_eye)
|
195 |
-
# mouth
|
196 |
-
self.optimizer_d_mouth = self.get_optimizer(
|
197 |
-
optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
|
198 |
-
self.optimizers.append(self.optimizer_d_mouth)
|
199 |
-
|
200 |
-
def feed_data(self, data):
|
201 |
-
self.lq = data['lq'].to(self.device)
|
202 |
-
if 'gt' in data:
|
203 |
-
self.gt = data['gt'].to(self.device)
|
204 |
-
|
205 |
-
if 'loc_left_eye' in data:
|
206 |
-
# get facial component locations, shape (batch, 4)
|
207 |
-
self.loc_left_eyes = data['loc_left_eye']
|
208 |
-
self.loc_right_eyes = data['loc_right_eye']
|
209 |
-
self.loc_mouths = data['loc_mouth']
|
210 |
-
|
211 |
-
# uncomment to check data
|
212 |
-
# import torchvision
|
213 |
-
# if self.opt['rank'] == 0:
|
214 |
-
# import os
|
215 |
-
# os.makedirs('tmp/gt', exist_ok=True)
|
216 |
-
# os.makedirs('tmp/lq', exist_ok=True)
|
217 |
-
# print(self.idx)
|
218 |
-
# torchvision.utils.save_image(
|
219 |
-
# self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
220 |
-
# torchvision.utils.save_image(
|
221 |
-
# self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
222 |
-
# self.idx = self.idx + 1
|
223 |
-
|
224 |
-
def construct_img_pyramid(self):
|
225 |
-
pyramid_gt = [self.gt]
|
226 |
-
down_img = self.gt
|
227 |
-
for _ in range(0, self.log_size - 3):
|
228 |
-
down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
|
229 |
-
pyramid_gt.insert(0, down_img)
|
230 |
-
return pyramid_gt
|
231 |
-
|
232 |
-
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
|
233 |
-
# hard code
|
234 |
-
face_ratio = int(self.opt['network_g']['out_size'] / 512)
|
235 |
-
eye_out_size *= face_ratio
|
236 |
-
mouth_out_size *= face_ratio
|
237 |
-
|
238 |
-
rois_eyes = []
|
239 |
-
rois_mouths = []
|
240 |
-
for b in range(self.loc_left_eyes.size(0)): # loop for batch size
|
241 |
-
# left eye and right eye
|
242 |
-
img_inds = self.loc_left_eyes.new_full((2, 1), b)
|
243 |
-
bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4)
|
244 |
-
rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5)
|
245 |
-
rois_eyes.append(rois)
|
246 |
-
# mouse
|
247 |
-
img_inds = self.loc_left_eyes.new_full((1, 1), b)
|
248 |
-
rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5)
|
249 |
-
rois_mouths.append(rois)
|
250 |
-
|
251 |
-
rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
|
252 |
-
rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
|
253 |
-
|
254 |
-
# real images
|
255 |
-
all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
|
256 |
-
self.left_eyes_gt = all_eyes[0::2, :, :, :]
|
257 |
-
self.right_eyes_gt = all_eyes[1::2, :, :, :]
|
258 |
-
self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
259 |
-
# output
|
260 |
-
all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
|
261 |
-
self.left_eyes = all_eyes[0::2, :, :, :]
|
262 |
-
self.right_eyes = all_eyes[1::2, :, :, :]
|
263 |
-
self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
264 |
-
|
265 |
-
def _gram_mat(self, x):
|
266 |
-
"""Calculate Gram matrix.
|
267 |
-
|
268 |
-
Args:
|
269 |
-
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
270 |
-
|
271 |
-
Returns:
|
272 |
-
torch.Tensor: Gram matrix.
|
273 |
-
"""
|
274 |
-
n, c, h, w = x.size()
|
275 |
-
features = x.view(n, c, w * h)
|
276 |
-
features_t = features.transpose(1, 2)
|
277 |
-
gram = features.bmm(features_t) / (c * h * w)
|
278 |
-
return gram
|
279 |
-
|
280 |
-
def gray_resize_for_identity(self, out, size=128):
|
281 |
-
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
|
282 |
-
out_gray = out_gray.unsqueeze(1)
|
283 |
-
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
|
284 |
-
return out_gray
|
285 |
-
|
286 |
-
def optimize_parameters(self, current_iter):
|
287 |
-
# optimize net_g
|
288 |
-
for p in self.net_d.parameters():
|
289 |
-
p.requires_grad = False
|
290 |
-
self.optimizer_g.zero_grad()
|
291 |
-
|
292 |
-
if self.use_facial_disc:
|
293 |
-
for p in self.net_d_left_eye.parameters():
|
294 |
-
p.requires_grad = False
|
295 |
-
for p in self.net_d_right_eye.parameters():
|
296 |
-
p.requires_grad = False
|
297 |
-
for p in self.net_d_mouth.parameters():
|
298 |
-
p.requires_grad = False
|
299 |
-
|
300 |
-
# image pyramid loss weight
|
301 |
-
if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
|
302 |
-
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
|
303 |
-
else:
|
304 |
-
pyramid_loss_weight = 1e-12 # very small loss
|
305 |
-
if pyramid_loss_weight > 0:
|
306 |
-
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
|
307 |
-
pyramid_gt = self.construct_img_pyramid()
|
308 |
-
else:
|
309 |
-
self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
|
310 |
-
|
311 |
-
# get roi-align regions
|
312 |
-
if self.use_facial_disc:
|
313 |
-
self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
|
314 |
-
|
315 |
-
l_g_total = 0
|
316 |
-
loss_dict = OrderedDict()
|
317 |
-
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
|
318 |
-
# pixel loss
|
319 |
-
if self.cri_pix:
|
320 |
-
l_g_pix = self.cri_pix(self.output, self.gt)
|
321 |
-
l_g_total += l_g_pix
|
322 |
-
loss_dict['l_g_pix'] = l_g_pix
|
323 |
-
|
324 |
-
# image pyramid loss
|
325 |
-
if pyramid_loss_weight > 0:
|
326 |
-
for i in range(0, self.log_size - 2):
|
327 |
-
l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
|
328 |
-
l_g_total += l_pyramid
|
329 |
-
loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
|
330 |
-
|
331 |
-
# perceptual loss
|
332 |
-
if self.cri_perceptual:
|
333 |
-
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
|
334 |
-
if l_g_percep is not None:
|
335 |
-
l_g_total += l_g_percep
|
336 |
-
loss_dict['l_g_percep'] = l_g_percep
|
337 |
-
if l_g_style is not None:
|
338 |
-
l_g_total += l_g_style
|
339 |
-
loss_dict['l_g_style'] = l_g_style
|
340 |
-
|
341 |
-
# gan loss
|
342 |
-
fake_g_pred = self.net_d(self.output)
|
343 |
-
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
344 |
-
l_g_total += l_g_gan
|
345 |
-
loss_dict['l_g_gan'] = l_g_gan
|
346 |
-
|
347 |
-
# facial component loss
|
348 |
-
if self.use_facial_disc:
|
349 |
-
# left eye
|
350 |
-
fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
|
351 |
-
l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
|
352 |
-
l_g_total += l_g_gan
|
353 |
-
loss_dict['l_g_gan_left_eye'] = l_g_gan
|
354 |
-
# right eye
|
355 |
-
fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
|
356 |
-
l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
|
357 |
-
l_g_total += l_g_gan
|
358 |
-
loss_dict['l_g_gan_right_eye'] = l_g_gan
|
359 |
-
# mouth
|
360 |
-
fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
|
361 |
-
l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
|
362 |
-
l_g_total += l_g_gan
|
363 |
-
loss_dict['l_g_gan_mouth'] = l_g_gan
|
364 |
-
|
365 |
-
if self.opt['train'].get('comp_style_weight', 0) > 0:
|
366 |
-
# get gt feat
|
367 |
-
_, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
|
368 |
-
_, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
|
369 |
-
_, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
|
370 |
-
|
371 |
-
def _comp_style(feat, feat_gt, criterion):
|
372 |
-
return criterion(self._gram_mat(feat[0]), self._gram_mat(
|
373 |
-
feat_gt[0].detach())) * 0.5 + criterion(
|
374 |
-
self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
|
375 |
-
|
376 |
-
# facial component style loss
|
377 |
-
comp_style_loss = 0
|
378 |
-
comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
|
379 |
-
comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
|
380 |
-
comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
|
381 |
-
comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
|
382 |
-
l_g_total += comp_style_loss
|
383 |
-
loss_dict['l_g_comp_style_loss'] = comp_style_loss
|
384 |
-
|
385 |
-
# identity loss
|
386 |
-
if self.use_identity:
|
387 |
-
identity_weight = self.opt['train']['identity_weight']
|
388 |
-
# get gray images and resize
|
389 |
-
out_gray = self.gray_resize_for_identity(self.output)
|
390 |
-
gt_gray = self.gray_resize_for_identity(self.gt)
|
391 |
-
|
392 |
-
identity_gt = self.network_identity(gt_gray).detach()
|
393 |
-
identity_out = self.network_identity(out_gray)
|
394 |
-
l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
|
395 |
-
l_g_total += l_identity
|
396 |
-
loss_dict['l_identity'] = l_identity
|
397 |
-
|
398 |
-
l_g_total.backward()
|
399 |
-
self.optimizer_g.step()
|
400 |
-
|
401 |
-
# EMA
|
402 |
-
self.model_ema(decay=0.5**(32 / (10 * 1000)))
|
403 |
-
|
404 |
-
# ----------- optimize net_d ----------- #
|
405 |
-
for p in self.net_d.parameters():
|
406 |
-
p.requires_grad = True
|
407 |
-
self.optimizer_d.zero_grad()
|
408 |
-
if self.use_facial_disc:
|
409 |
-
for p in self.net_d_left_eye.parameters():
|
410 |
-
p.requires_grad = True
|
411 |
-
for p in self.net_d_right_eye.parameters():
|
412 |
-
p.requires_grad = True
|
413 |
-
for p in self.net_d_mouth.parameters():
|
414 |
-
p.requires_grad = True
|
415 |
-
self.optimizer_d_left_eye.zero_grad()
|
416 |
-
self.optimizer_d_right_eye.zero_grad()
|
417 |
-
self.optimizer_d_mouth.zero_grad()
|
418 |
-
|
419 |
-
fake_d_pred = self.net_d(self.output.detach())
|
420 |
-
real_d_pred = self.net_d(self.gt)
|
421 |
-
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
|
422 |
-
loss_dict['l_d'] = l_d
|
423 |
-
# In wgan, real_score should be positive and fake_score should benegative
|
424 |
-
loss_dict['real_score'] = real_d_pred.detach().mean()
|
425 |
-
loss_dict['fake_score'] = fake_d_pred.detach().mean()
|
426 |
-
l_d.backward()
|
427 |
-
|
428 |
-
if current_iter % self.net_d_reg_every == 0:
|
429 |
-
self.gt.requires_grad = True
|
430 |
-
real_pred = self.net_d(self.gt)
|
431 |
-
l_d_r1 = r1_penalty(real_pred, self.gt)
|
432 |
-
l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
|
433 |
-
loss_dict['l_d_r1'] = l_d_r1.detach().mean()
|
434 |
-
l_d_r1.backward()
|
435 |
-
|
436 |
-
self.optimizer_d.step()
|
437 |
-
|
438 |
-
if self.use_facial_disc:
|
439 |
-
# lefe eye
|
440 |
-
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
|
441 |
-
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
|
442 |
-
l_d_left_eye = self.cri_component(
|
443 |
-
real_d_pred, True, is_disc=True) + self.cri_gan(
|
444 |
-
fake_d_pred, False, is_disc=True)
|
445 |
-
loss_dict['l_d_left_eye'] = l_d_left_eye
|
446 |
-
l_d_left_eye.backward()
|
447 |
-
# right eye
|
448 |
-
fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
|
449 |
-
real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
|
450 |
-
l_d_right_eye = self.cri_component(
|
451 |
-
real_d_pred, True, is_disc=True) + self.cri_gan(
|
452 |
-
fake_d_pred, False, is_disc=True)
|
453 |
-
loss_dict['l_d_right_eye'] = l_d_right_eye
|
454 |
-
l_d_right_eye.backward()
|
455 |
-
# mouth
|
456 |
-
fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
|
457 |
-
real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
|
458 |
-
l_d_mouth = self.cri_component(
|
459 |
-
real_d_pred, True, is_disc=True) + self.cri_gan(
|
460 |
-
fake_d_pred, False, is_disc=True)
|
461 |
-
loss_dict['l_d_mouth'] = l_d_mouth
|
462 |
-
l_d_mouth.backward()
|
463 |
-
|
464 |
-
self.optimizer_d_left_eye.step()
|
465 |
-
self.optimizer_d_right_eye.step()
|
466 |
-
self.optimizer_d_mouth.step()
|
467 |
-
|
468 |
-
self.log_dict = self.reduce_loss_dict(loss_dict)
|
469 |
-
|
470 |
-
def test(self):
|
471 |
-
with torch.no_grad():
|
472 |
-
if hasattr(self, 'net_g_ema'):
|
473 |
-
self.net_g_ema.eval()
|
474 |
-
self.output, _ = self.net_g_ema(self.lq)
|
475 |
-
else:
|
476 |
-
logger = get_root_logger()
|
477 |
-
logger.warning('Do not have self.net_g_ema, use self.net_g.')
|
478 |
-
self.net_g.eval()
|
479 |
-
self.output, _ = self.net_g(self.lq)
|
480 |
-
self.net_g.train()
|
481 |
-
|
482 |
-
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
483 |
-
if self.opt['rank'] == 0:
|
484 |
-
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
485 |
-
|
486 |
-
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
487 |
-
dataset_name = dataloader.dataset.opt['name']
|
488 |
-
with_metrics = self.opt['val'].get('metrics') is not None
|
489 |
-
if with_metrics:
|
490 |
-
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
491 |
-
pbar = tqdm(total=len(dataloader), unit='image')
|
492 |
-
|
493 |
-
for idx, val_data in enumerate(dataloader):
|
494 |
-
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
495 |
-
self.feed_data(val_data)
|
496 |
-
self.test()
|
497 |
-
|
498 |
-
visuals = self.get_current_visuals()
|
499 |
-
sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
|
500 |
-
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
501 |
-
|
502 |
-
if 'gt' in visuals:
|
503 |
-
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
504 |
-
del self.gt
|
505 |
-
# tentative for out of GPU memory
|
506 |
-
del self.lq
|
507 |
-
del self.output
|
508 |
-
torch.cuda.empty_cache()
|
509 |
-
|
510 |
-
if save_img:
|
511 |
-
if self.opt['is_train']:
|
512 |
-
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
513 |
-
f'{img_name}_{current_iter}.png')
|
514 |
-
else:
|
515 |
-
if self.opt['val']['suffix']:
|
516 |
-
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
517 |
-
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
518 |
-
else:
|
519 |
-
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
520 |
-
f'{img_name}_{self.opt["name"]}.png')
|
521 |
-
imwrite(sr_img, save_img_path)
|
522 |
-
|
523 |
-
if with_metrics:
|
524 |
-
# calculate metrics
|
525 |
-
for name, opt_ in self.opt['val']['metrics'].items():
|
526 |
-
metric_data = dict(img1=sr_img, img2=gt_img)
|
527 |
-
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
528 |
-
pbar.update(1)
|
529 |
-
pbar.set_description(f'Test {img_name}')
|
530 |
-
pbar.close()
|
531 |
-
|
532 |
-
if with_metrics:
|
533 |
-
for metric in self.metric_results.keys():
|
534 |
-
self.metric_results[metric] /= (idx + 1)
|
535 |
-
|
536 |
-
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
537 |
-
|
538 |
-
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
539 |
-
log_str = f'Validation {dataset_name}\n'
|
540 |
-
for metric, value in self.metric_results.items():
|
541 |
-
log_str += f'\t # {metric}: {value:.4f}\n'
|
542 |
-
logger = get_root_logger()
|
543 |
-
logger.info(log_str)
|
544 |
-
if tb_logger:
|
545 |
-
for metric, value in self.metric_results.items():
|
546 |
-
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
547 |
-
|
548 |
-
def get_current_visuals(self):
|
549 |
-
out_dict = OrderedDict()
|
550 |
-
out_dict['gt'] = self.gt.detach().cpu()
|
551 |
-
out_dict['sr'] = self.output.detach().cpu()
|
552 |
-
return out_dict
|
553 |
-
|
554 |
-
def save(self, epoch, current_iter):
|
555 |
-
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
556 |
-
self.save_network(self.net_d, 'net_d', current_iter)
|
557 |
-
# save component discriminators
|
558 |
-
if self.use_facial_disc:
|
559 |
-
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
|
560 |
-
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
|
561 |
-
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
|
562 |
-
self.save_training_state(epoch, current_iter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
facexlib
|
2 |
-
lmdb
|
3 |
-
numpy
|
4 |
-
opencv-python
|
5 |
-
pyyaml
|
6 |
-
tb-nightly
|
7 |
-
torch>=1.7
|
8 |
-
torchvision
|
9 |
-
tqdm
|
10 |
-
yapf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
setup.cfg
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
[flake8]
|
2 |
-
ignore =
|
3 |
-
# line break before binary operator (W503)
|
4 |
-
W503,
|
5 |
-
# line break after binary operator (W504)
|
6 |
-
W504,
|
7 |
-
max-line-length=120
|
8 |
-
|
9 |
-
[yapf]
|
10 |
-
based_on_style = pep8
|
11 |
-
column_limit = 120
|
12 |
-
blank_line_before_nested_class_or_def = true
|
13 |
-
split_before_expression_after_opening_paren = true
|
14 |
-
|
15 |
-
[isort]
|
16 |
-
line_length = 120
|
17 |
-
multi_line_output = 0
|
18 |
-
known_standard_library = pkg_resources,setuptools
|
19 |
-
known_first_party = basicsr
|
20 |
-
known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
|
21 |
-
no_lines_before = STDLIB,LOCALFOLDER
|
22 |
-
default_section = THIRDPARTY
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.py
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
import os.path as osp
|
2 |
-
|
3 |
-
import archs # noqa: F401
|
4 |
-
import data # noqa: F401
|
5 |
-
import models # noqa: F401
|
6 |
-
from basicsr.train import train_pipeline
|
7 |
-
|
8 |
-
if __name__ == '__main__':
|
9 |
-
root_path = osp.abspath(osp.join(__file__, osp.pardir))
|
10 |
-
train_pipeline(root_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_gfpgan_v1.yml
DELETED
@@ -1,210 +0,0 @@
|
|
1 |
-
# general settings
|
2 |
-
name: train_GFPGANv1_512
|
3 |
-
model_type: GFPGANModel
|
4 |
-
num_gpu: 4
|
5 |
-
manual_seed: 0
|
6 |
-
|
7 |
-
# dataset and data loader settings
|
8 |
-
datasets:
|
9 |
-
train:
|
10 |
-
name: FFHQ
|
11 |
-
type: FFHQDegradationDataset
|
12 |
-
# dataroot_gt: datasets/ffhq/ffhq_512.lmdb
|
13 |
-
dataroot_gt: datasets/ffhq/ffhq_512
|
14 |
-
io_backend:
|
15 |
-
# type: lmdb
|
16 |
-
type: disk
|
17 |
-
|
18 |
-
use_hflip: true
|
19 |
-
mean: [0.5, 0.5, 0.5]
|
20 |
-
std: [0.5, 0.5, 0.5]
|
21 |
-
out_size: 512
|
22 |
-
|
23 |
-
blur_kernel_size: 41
|
24 |
-
kernel_list: ['iso', 'aniso']
|
25 |
-
kernel_prob: [0.5, 0.5]
|
26 |
-
blur_sigma: [0.1, 10]
|
27 |
-
downsample_range: [0.8, 8]
|
28 |
-
noise_range: [0, 20]
|
29 |
-
jpeg_range: [60, 100]
|
30 |
-
|
31 |
-
# color jitter and gray
|
32 |
-
color_jitter_prob: 0.3
|
33 |
-
color_jitter_shift: 20
|
34 |
-
color_jitter_pt_prob: 0.3
|
35 |
-
gray_prob: 0.01
|
36 |
-
|
37 |
-
crop_components: true
|
38 |
-
component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
|
39 |
-
eye_enlarge_ratio: 1.4
|
40 |
-
|
41 |
-
# data loader
|
42 |
-
use_shuffle: true
|
43 |
-
num_worker_per_gpu: 6
|
44 |
-
batch_size_per_gpu: 3
|
45 |
-
dataset_enlarge_ratio: 100
|
46 |
-
prefetch_mode: ~
|
47 |
-
|
48 |
-
val:
|
49 |
-
# Please modify accordingly to use your own validation
|
50 |
-
# Or comment the val block if do not need validation during training
|
51 |
-
name: validation
|
52 |
-
type: PairedImageDataset
|
53 |
-
dataroot_lq: datasets/faces/validation/input
|
54 |
-
dataroot_gt: datasets/faces/validation/reference
|
55 |
-
io_backend:
|
56 |
-
type: disk
|
57 |
-
mean: [0.5, 0.5, 0.5]
|
58 |
-
std: [0.5, 0.5, 0.5]
|
59 |
-
scale: 1
|
60 |
-
|
61 |
-
# network structures
|
62 |
-
network_g:
|
63 |
-
type: GFPGANv1
|
64 |
-
out_size: 512
|
65 |
-
num_style_feat: 512
|
66 |
-
channel_multiplier: 1
|
67 |
-
resample_kernel: [1, 3, 3, 1]
|
68 |
-
decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
|
69 |
-
fix_decoder: true
|
70 |
-
num_mlp: 8
|
71 |
-
lr_mlp: 0.01
|
72 |
-
input_is_latent: true
|
73 |
-
different_w: true
|
74 |
-
narrow: 1
|
75 |
-
sft_half: true
|
76 |
-
|
77 |
-
network_d:
|
78 |
-
type: StyleGAN2Discriminator
|
79 |
-
out_size: 512
|
80 |
-
channel_multiplier: 1
|
81 |
-
resample_kernel: [1, 3, 3, 1]
|
82 |
-
|
83 |
-
network_d_left_eye:
|
84 |
-
type: FacialComponentDiscriminator
|
85 |
-
|
86 |
-
network_d_right_eye:
|
87 |
-
type: FacialComponentDiscriminator
|
88 |
-
|
89 |
-
network_d_mouth:
|
90 |
-
type: FacialComponentDiscriminator
|
91 |
-
|
92 |
-
network_identity:
|
93 |
-
type: ResNetArcFace
|
94 |
-
block: IRBlock
|
95 |
-
layers: [2, 2, 2, 2]
|
96 |
-
use_se: False
|
97 |
-
|
98 |
-
# path
|
99 |
-
path:
|
100 |
-
pretrain_network_g: ~
|
101 |
-
param_key_g: params_ema
|
102 |
-
strict_load_g: ~
|
103 |
-
pretrain_network_d: ~
|
104 |
-
pretrain_network_d_left_eye: ~
|
105 |
-
pretrain_network_d_right_eye: ~
|
106 |
-
pretrain_network_d_mouth: ~
|
107 |
-
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
|
108 |
-
# resume
|
109 |
-
resume_state: ~
|
110 |
-
ignore_resume_networks: ['network_identity']
|
111 |
-
|
112 |
-
# training settings
|
113 |
-
train:
|
114 |
-
optim_g:
|
115 |
-
type: Adam
|
116 |
-
lr: !!float 2e-3
|
117 |
-
optim_d:
|
118 |
-
type: Adam
|
119 |
-
lr: !!float 2e-3
|
120 |
-
optim_component:
|
121 |
-
type: Adam
|
122 |
-
lr: !!float 2e-3
|
123 |
-
|
124 |
-
scheduler:
|
125 |
-
type: MultiStepLR
|
126 |
-
milestones: [600000, 700000]
|
127 |
-
gamma: 0.5
|
128 |
-
|
129 |
-
total_iter: 800000
|
130 |
-
warmup_iter: -1 # no warm up
|
131 |
-
|
132 |
-
# losses
|
133 |
-
# pixel loss
|
134 |
-
pixel_opt:
|
135 |
-
type: L1Loss
|
136 |
-
loss_weight: !!float 1e-1
|
137 |
-
reduction: mean
|
138 |
-
# L1 loss used in pyramid loss, component style loss and identity loss
|
139 |
-
L1_opt:
|
140 |
-
type: L1Loss
|
141 |
-
loss_weight: 1
|
142 |
-
reduction: mean
|
143 |
-
|
144 |
-
# image pyramid loss
|
145 |
-
pyramid_loss_weight: 1
|
146 |
-
remove_pyramid_loss: 50000
|
147 |
-
# perceptual loss (content and style losses)
|
148 |
-
perceptual_opt:
|
149 |
-
type: PerceptualLoss
|
150 |
-
layer_weights:
|
151 |
-
# before relu
|
152 |
-
'conv1_2': 0.1
|
153 |
-
'conv2_2': 0.1
|
154 |
-
'conv3_4': 1
|
155 |
-
'conv4_4': 1
|
156 |
-
'conv5_4': 1
|
157 |
-
vgg_type: vgg19
|
158 |
-
use_input_norm: true
|
159 |
-
perceptual_weight: !!float 1
|
160 |
-
style_weight: 50
|
161 |
-
range_norm: true
|
162 |
-
criterion: l1
|
163 |
-
# gan loss
|
164 |
-
gan_opt:
|
165 |
-
type: GANLoss
|
166 |
-
gan_type: wgan_softplus
|
167 |
-
loss_weight: !!float 1e-1
|
168 |
-
# r1 regularization for discriminator
|
169 |
-
r1_reg_weight: 10
|
170 |
-
# facial component loss
|
171 |
-
gan_component_opt:
|
172 |
-
type: GANLoss
|
173 |
-
gan_type: vanilla
|
174 |
-
real_label_val: 1.0
|
175 |
-
fake_label_val: 0.0
|
176 |
-
loss_weight: !!float 1
|
177 |
-
comp_style_weight: 200
|
178 |
-
# identity loss
|
179 |
-
identity_weight: 10
|
180 |
-
|
181 |
-
net_d_iters: 1
|
182 |
-
net_d_init_iters: 0
|
183 |
-
net_d_reg_every: 16
|
184 |
-
|
185 |
-
# validation settings
|
186 |
-
val:
|
187 |
-
val_freq: !!float 5e3
|
188 |
-
save_img: true
|
189 |
-
|
190 |
-
metrics:
|
191 |
-
psnr: # metric name, can be arbitrary
|
192 |
-
type: calculate_psnr
|
193 |
-
crop_border: 0
|
194 |
-
test_y_channel: false
|
195 |
-
|
196 |
-
# logging settings
|
197 |
-
logger:
|
198 |
-
print_freq: 100
|
199 |
-
save_checkpoint_freq: !!float 5e3
|
200 |
-
use_tb_logger: true
|
201 |
-
wandb:
|
202 |
-
project: ~
|
203 |
-
resume_id: ~
|
204 |
-
|
205 |
-
# dist training settings
|
206 |
-
dist_params:
|
207 |
-
backend: nccl
|
208 |
-
port: 29500
|
209 |
-
|
210 |
-
find_unused_parameters: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|