Spanicin committed on
Commit 780dcf9
1 Parent(s): 4b96359

Update videoretalking/models/base_blocks.py

Files changed (1)
  1. videoretalking/models/base_blocks.py +553 -553
videoretalking/models/base_blocks.py CHANGED
@@ -1,554 +1,554 @@
(the two versions shown differ only in the FFC import on line 8; all other lines are identical and appear once below as context)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm

- from models.ffc import FFC
+ from videoretalking.models.ffc import FFC
from basicsr.archs.arch_util import default_init_weights


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2

    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out


class LayerNorm2d(nn.Module):
    def __init__(self, n_out, affine=True):
        super(LayerNorm2d, self).__init__()
        self.n_out = n_out
        self.affine = affine

        if self.affine:
            self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
            self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))

    def forward(self, x):
        normalized_shape = x.size()[1:]
        if self.affine:
            return F.layer_norm(x, normalized_shape, \
                self.weight.expand(normalized_shape),
                self.bias.expand(normalized_shape))
        else:
            return F.layer_norm(x, normalized_shape)


def spectral_norm(module, use_spect=True):
    if use_spect:
        return SpectralNorm(module)
    else:
        return module


class FirstBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FirstBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out


class DownBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(DownBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        pool = nn.AvgPool2d(kernel_size=(2, 2))

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity, pool)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)

    def forward(self, x):
        out = self.model(x)
        return out


class UpBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(UpBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(F.interpolate(x, scale_factor=2))
        return out


class ADAIN(nn.Module):
    def __init__(self, norm_nc, feature_nc):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        nhidden = 128
        use_bias=True

        self.mlp_shared = nn.Sequential(
            nn.Linear(feature_nc, nhidden, bias=use_bias),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
        self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)

    def forward(self, x, feature):

        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)
        # Part 2. produce scaling and bias conditioned on feature
        feature = feature.view(feature.size(0), -1)
        actv = self.mlp_shared(feature)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        gamma = gamma.view(*gamma.size()[:2], 1,1)
        beta = beta.view(*beta.size()[:2], 1,1)
        out = normalized * (1 + gamma) + beta
        return out


class FineADAINResBlock2d(nn.Module):
    """
    Define an Residual block for different types
    """
    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.norm1 = ADAIN(input_nc, feature_nc)
        self.norm2 = ADAIN(input_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        dx = self.actvn(self.norm1(self.conv1(x), z))
        dx = self.norm2(self.conv2(x), z)
        out = dx + x
        return out


class FineADAINResBlocks(nn.Module):
    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlocks, self).__init__()
        self.num_block = num_block
        for i in range(num_block):
            model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
            setattr(self, 'res'+str(i), model)

    def forward(self, x, z):
        for i in range(self.num_block):
            model = getattr(self, 'res'+str(i))
            x = model(x, z)
        return x


class ADAINEncoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoderBlock, self).__init__()
        kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
        self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)


        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(output_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        x = self.conv_0(self.actvn(self.norm_0(x, z)))
        x = self.conv_1(self.actvn(self.norm_1(x, z)))
        return x


class ADAINDecoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINDecoderBlock, self).__init__()
        # Attributes
        self.actvn = nonlinearity
        hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc

        kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
        if use_transpose:
            kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
        else:
            kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}

        # create conv layers
        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
        if use_transpose:
            self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
            self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
        else:
            self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
            self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
        # define normalization layers
        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(hidden_nc, feature_nc)
        self.norm_s = ADAIN(input_nc, feature_nc)

    def forward(self, x, z):
        x_s = self.shortcut(x, z)
        dx = self.conv_0(self.actvn(self.norm_0(x, z)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
        out = x_s + dx
        return out

    def shortcut(self, x, z):
        x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
        return x_s


class FineEncoder(nn.Module):
    """docstring for Encoder"""
    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineEncoder, self).__init__()
        self.layers = layers
        self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for i in range(layers):
            in_channels = min(ngf*(2**i), img_f)
            out_channels = min(ngf*(2**(i+1)), img_f)
            model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            setattr(self, 'down' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x):
        x = self.first(x)
        out=[x]
        for i in range(self.layers):
            model = getattr(self, 'down'+str(i))
            x = model(x)
            out.append(x)
        return out


class FineDecoder(nn.Module):
    """docstring for FineDecoder"""
    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineDecoder, self).__init__()
        self.layers = layers
        for i in range(layers)[::-1]:
            in_channels = min(ngf*(2**(i+1)), img_f)
            out_channels = min(ngf*(2**i), img_f)
            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
            setattr(self, 'up' + str(i), up)
            setattr(self, 'res' + str(i), res)
            setattr(self, 'jump' + str(i), jump)
        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
        self.output_nc = out_channels

    def forward(self, x, z):
        out = x.pop()
        for i in range(self.layers)[::-1]:
            res_model = getattr(self, 'res' + str(i))
            up_model = getattr(self, 'up' + str(i))
            jump_model = getattr(self, 'jump' + str(i))
            out = res_model(out, z)
            out = up_model(out)
            out = jump_model(x.pop()) + out
        out_image = self.final(out)
        return out_image


class ADAINEncoder(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoder, self).__init__()
        self.layers = layers
        self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
        for i in range(layers):
            in_channels = min(ngf * (2**i), img_f)
            out_channels = min(ngf *(2**(i+1)), img_f)
            model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
            setattr(self, 'encoder' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x, z):
        out = self.input_layer(x)
        out_list = [out]
        for i in range(self.layers):
            model = getattr(self, 'encoder' + str(i))
            out = model(out, z)
            out_list.append(out)
        return out_list


class ADAINDecoder(nn.Module):
    """docstring for ADAINDecoder"""
    def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):

        super(ADAINDecoder, self).__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.skip_connect = skip_connect
        use_transpose = True
        for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
            in_channels = min(ngf * (2**(i+1)), img_f)
            in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
            out_channels = min(ngf * (2**i), img_f)
            model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
            setattr(self, 'decoder' + str(i), model)
        self.output_nc = out_channels*2 if self.skip_connect else out_channels

    def forward(self, x, z):
        out = x.pop() if self.skip_connect else x
        for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
            model = getattr(self, 'decoder' + str(i))
            out = model(out, z)
            out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
        return out


class ADAINHourglass(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
        super(ADAINHourglass, self).__init__()
        self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
        self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
        self.output_nc = self.decoder.output_nc

    def forward(self, x, z):
        return self.decoder(self.encoder(x, z), z)


class FineADAINLama(nn.Module):
    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINLama, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        self.actvn = nonlinearity
        ratio_gin = 0.75
        ratio_gout = 0.75
        self.ffc = FFC(input_nc, input_nc, 3,
                       ratio_gin, ratio_gout, 1, 1, 1,
                       1, False, False, padding_type='reflect')
        global_channels = int(input_nc * ratio_gout)
        self.bn_l = ADAIN(input_nc - global_channels, feature_nc)
        self.bn_g = ADAIN(global_channels, feature_nc)

    def forward(self, x, z):
        x_l, x_g = self.ffc(x)
        x_l = self.actvn(self.bn_l(x_l,z))
        x_g = self.actvn(self.bn_g(x_g,z))
        return x_l, x_g


class FFCResnetBlock(nn.Module):
    def __init__(self, dim, feature_dim, padding_type='reflect', norm_layer=BatchNorm2d, activation_layer=nn.ReLU, dilation=1,
                 spatial_transform_kwargs=None, inline=False, **conv_kwargs):
        super().__init__()
        self.conv1 = FineADAINLama(dim, feature_dim, **conv_kwargs)
        self.conv2 = FineADAINLama(dim, feature_dim, **conv_kwargs)
        self.inline = True

    def forward(self, x, z):
        if self.inline:
            x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
        else:
            x_l, x_g = x if type(x) is tuple else (x, 0)

        id_l, id_g = x_l, x_g
        x_l, x_g = self.conv1((x_l, x_g), z)
        x_l, x_g = self.conv2((x_l, x_g), z)

        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        if self.inline:
            out = torch.cat(out, dim=1)
        return out


class FFCADAINResBlocks(nn.Module):
    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FFCADAINResBlocks, self).__init__()
        self.num_block = num_block
        for i in range(num_block):
            model = FFCResnetBlock(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
            setattr(self, 'res'+str(i), model)

    def forward(self, x, z):
        for i in range(self.num_block):
            model = getattr(self, 'res'+str(i))
            x = model(x, z)
        return x


class Jump(nn.Module):
    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Jump, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out


class FinalBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
        super(FinalBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if tanh_or_sigmoid == 'sigmoid':
            out_nonlinearity = nn.Sigmoid()
        else:
            out_nonlinearity = nn.Tanh()
        self.model = nn.Sequential(conv, out_nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out


class ModulatedConv2d(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        b, c, h, w = x.shape
        style = self.modulation(style).view(b, 1, c, 1, 1)
        weight = self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        # upsample or downsample if necessary
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')


class StyleConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out


class ToRGB(nn.Module):
    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
            out = out + skip
        return out
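
For orientation, below is a minimal usage sketch (not part of the commit) exercising two of the building blocks defined in this file: the ADAIN-conditioned hourglass and the StyleGAN2-style StyleConv/ToRGB pair. The batch size, channel counts, the 73-dimensional pose vector and the 512-dimensional style vector are illustrative assumptions, and it presumes the repository root is on PYTHONPATH (so that videoretalking.* resolves, which is what the import path change in this commit appears to address) and that torch and basicsr are installed.

import torch
import torch.nn as nn

from videoretalking.models.base_blocks import ADAINHourglass, StyleConv, ToRGB

# Hypothetical sizes, chosen only for illustration.
batch, pose_dim, style_dim = 2, 73, 512

# ADAIN hourglass: encodes a 3-channel image and decodes it while injecting
# the driving/pose vector z through ADAIN at every encoder and decoder block.
hourglass = ADAINHourglass(image_nc=3, pose_nc=pose_dim, ngf=32, img_f=256,
                           encoder_layers=2, decoder_layers=2,
                           nonlinearity=nn.LeakyReLU(0.2), use_spect=False)
image = torch.randn(batch, 3, 64, 64)
pose = torch.randn(batch, pose_dim)
feat = hourglass(image, pose)      # -> (batch, hourglass.output_nc, 64, 64)

# StyleGAN2-style modulated convolution plus RGB head.
style_conv = StyleConv(in_channels=64, out_channels=64, kernel_size=3,
                       num_style_feat=style_dim, demodulate=True,
                       sample_mode='upsample')
to_rgb = ToRGB(in_channels=64, num_style_feat=style_dim, upsample=False)
style = torch.randn(batch, style_dim)
x = torch.randn(batch, 64, 16, 16)
x = style_conv(x, style)           # -> (batch, 64, 32, 32), noise injected internally
rgb = to_rgb(x, style)             # -> (batch, 3, 32, 32)
print(feat.shape, x.shape, rgb.shape)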