Spanicin commited on
Commit
35f2e7d
1 Parent(s): af60c81

Update videoretalking/third_part/GPEN/face_model/gpen_model.py

Browse files
videoretalking/third_part/GPEN/face_model/gpen_model.py CHANGED
@@ -1,746 +1,746 @@
1
- '''
2
- @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3
- @author: yangxy (yangtao9009@gmail.com)
4
- '''
5
- import math
6
- import random
7
- import functools
8
- import operator
9
- import itertools
10
-
11
- import torch
12
- from torch import nn
13
- from torch.nn import functional as F
14
- from torch.autograd import Function
15
-
16
- from face_model.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
17
-
18
- class PixelNorm(nn.Module):
19
- def __init__(self):
20
- super().__init__()
21
-
22
- def forward(self, input):
23
- return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
24
-
25
-
26
- def make_kernel(k):
27
- k = torch.tensor(k, dtype=torch.float32)
28
-
29
- if k.ndim == 1:
30
- k = k[None, :] * k[:, None]
31
-
32
- k /= k.sum()
33
-
34
- return k
35
-
36
-
37
- class Upsample(nn.Module):
38
- def __init__(self, kernel, factor=2, device='cpu'):
39
- super().__init__()
40
-
41
- self.factor = factor
42
- kernel = make_kernel(kernel) * (factor ** 2)
43
- self.register_buffer('kernel', kernel)
44
-
45
- p = kernel.shape[0] - factor
46
-
47
- pad0 = (p + 1) // 2 + factor - 1
48
- pad1 = p // 2
49
-
50
- self.pad = (pad0, pad1)
51
- self.device = device
52
-
53
- def forward(self, input):
54
- out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad, device=self.device)
55
-
56
- return out
57
-
58
-
59
- class Downsample(nn.Module):
60
- def __init__(self, kernel, factor=2, device='cpu'):
61
- super().__init__()
62
-
63
- self.factor = factor
64
- kernel = make_kernel(kernel)
65
- self.register_buffer('kernel', kernel)
66
-
67
- p = kernel.shape[0] - factor
68
-
69
- pad0 = (p + 1) // 2
70
- pad1 = p // 2
71
-
72
- self.pad = (pad0, pad1)
73
- self.device = device
74
-
75
- def forward(self, input):
76
- out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad, device=self.device)
77
-
78
- return out
79
-
80
-
81
- class Blur(nn.Module):
82
- def __init__(self, kernel, pad, upsample_factor=1, device='cpu'):
83
- super().__init__()
84
-
85
- kernel = make_kernel(kernel)
86
-
87
- if upsample_factor > 1:
88
- kernel = kernel * (upsample_factor ** 2)
89
-
90
- self.register_buffer('kernel', kernel)
91
-
92
- self.pad = pad
93
- self.device = device
94
-
95
- def forward(self, input):
96
- out = upfirdn2d(input, self.kernel, pad=self.pad, device=self.device)
97
-
98
- return out
99
-
100
-
101
- class EqualConv2d(nn.Module):
102
- def __init__(
103
- self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
104
- ):
105
- super().__init__()
106
-
107
- self.weight = nn.Parameter(
108
- torch.randn(out_channel, in_channel, kernel_size, kernel_size)
109
- )
110
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
111
-
112
- self.stride = stride
113
- self.padding = padding
114
-
115
- if bias:
116
- self.bias = nn.Parameter(torch.zeros(out_channel))
117
-
118
- else:
119
- self.bias = None
120
-
121
- def forward(self, input):
122
- out = F.conv2d(
123
- input,
124
- self.weight * self.scale,
125
- bias=self.bias,
126
- stride=self.stride,
127
- padding=self.padding,
128
- )
129
-
130
- return out
131
-
132
- def __repr__(self):
133
- return (
134
- f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
135
- f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
136
- )
137
-
138
-
139
- class EqualLinear(nn.Module):
140
- def __init__(
141
- self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, device='cpu'
142
- ):
143
- super().__init__()
144
-
145
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
146
-
147
- if bias:
148
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
149
-
150
- else:
151
- self.bias = None
152
-
153
- self.activation = activation
154
- self.device = device
155
-
156
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
157
- self.lr_mul = lr_mul
158
-
159
- def forward(self, input):
160
- if self.activation:
161
- out = F.linear(input, self.weight * self.scale)
162
- out = fused_leaky_relu(out, self.bias * self.lr_mul, device=self.device)
163
-
164
- else:
165
- out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
166
-
167
- return out
168
-
169
- def __repr__(self):
170
- return (
171
- f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
172
- )
173
-
174
-
175
- class ScaledLeakyReLU(nn.Module):
176
- def __init__(self, negative_slope=0.2):
177
- super().__init__()
178
-
179
- self.negative_slope = negative_slope
180
-
181
- def forward(self, input):
182
- out = F.leaky_relu(input, negative_slope=self.negative_slope)
183
-
184
- return out * math.sqrt(2)
185
-
186
-
187
- class ModulatedConv2d(nn.Module):
188
- def __init__(
189
- self,
190
- in_channel,
191
- out_channel,
192
- kernel_size,
193
- style_dim,
194
- demodulate=True,
195
- upsample=False,
196
- downsample=False,
197
- blur_kernel=[1, 3, 3, 1],
198
- device='cpu'
199
- ):
200
- super().__init__()
201
-
202
- self.eps = 1e-8
203
- self.kernel_size = kernel_size
204
- self.in_channel = in_channel
205
- self.out_channel = out_channel
206
- self.upsample = upsample
207
- self.downsample = downsample
208
-
209
- if upsample:
210
- factor = 2
211
- p = (len(blur_kernel) - factor) - (kernel_size - 1)
212
- pad0 = (p + 1) // 2 + factor - 1
213
- pad1 = p // 2 + 1
214
-
215
- self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor, device=device)
216
-
217
- if downsample:
218
- factor = 2
219
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
220
- pad0 = (p + 1) // 2
221
- pad1 = p // 2
222
-
223
- self.blur = Blur(blur_kernel, pad=(pad0, pad1), device=device)
224
-
225
- fan_in = in_channel * kernel_size ** 2
226
- self.scale = 1 / math.sqrt(fan_in)
227
- self.padding = kernel_size // 2
228
-
229
- self.weight = nn.Parameter(
230
- torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
231
- )
232
-
233
- self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
234
-
235
- self.demodulate = demodulate
236
-
237
- def __repr__(self):
238
- return (
239
- f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
240
- f'upsample={self.upsample}, downsample={self.downsample})'
241
- )
242
-
243
- def forward(self, input, style):
244
- batch, in_channel, height, width = input.shape
245
-
246
- style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
247
- weight = self.scale * self.weight * style
248
-
249
- if self.demodulate:
250
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
251
- weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
252
-
253
- weight = weight.view(
254
- batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
255
- )
256
-
257
- if self.upsample:
258
- input = input.view(1, batch * in_channel, height, width)
259
- weight = weight.view(
260
- batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
261
- )
262
- weight = weight.transpose(1, 2).reshape(
263
- batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
264
- )
265
- out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
266
- _, _, height, width = out.shape
267
- out = out.view(batch, self.out_channel, height, width)
268
- out = self.blur(out)
269
-
270
- elif self.downsample:
271
- input = self.blur(input)
272
- _, _, height, width = input.shape
273
- input = input.view(1, batch * in_channel, height, width)
274
- out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
275
- _, _, height, width = out.shape
276
- out = out.view(batch, self.out_channel, height, width)
277
-
278
- else:
279
- input = input.view(1, batch * in_channel, height, width)
280
- out = F.conv2d(input, weight, padding=self.padding, groups=batch)
281
- _, _, height, width = out.shape
282
- out = out.view(batch, self.out_channel, height, width)
283
-
284
- return out
285
-
286
-
287
- class NoiseInjection(nn.Module):
288
- def __init__(self, isconcat=True):
289
- super().__init__()
290
-
291
- self.isconcat = isconcat
292
- self.weight = nn.Parameter(torch.zeros(1))
293
-
294
- def forward(self, image, noise=None):
295
- if noise is None:
296
- batch, _, height, width = image.shape
297
- noise = image.new_empty(batch, 1, height, width).normal_()
298
-
299
- if self.isconcat:
300
- return torch.cat((image, self.weight * noise), dim=1)
301
- else:
302
- return image + self.weight * noise
303
-
304
-
305
- class ConstantInput(nn.Module):
306
- def __init__(self, channel, size=4):
307
- super().__init__()
308
-
309
- self.input = nn.Parameter(torch.randn(1, channel, size, size))
310
-
311
- def forward(self, input):
312
- batch = input.shape[0]
313
- out = self.input.repeat(batch, 1, 1, 1)
314
-
315
- return out
316
-
317
-
318
- class StyledConv(nn.Module):
319
- def __init__(
320
- self,
321
- in_channel,
322
- out_channel,
323
- kernel_size,
324
- style_dim,
325
- upsample=False,
326
- blur_kernel=[1, 3, 3, 1],
327
- demodulate=True,
328
- isconcat=True,
329
- device='cpu'
330
- ):
331
- super().__init__()
332
-
333
- self.conv = ModulatedConv2d(
334
- in_channel,
335
- out_channel,
336
- kernel_size,
337
- style_dim,
338
- upsample=upsample,
339
- blur_kernel=blur_kernel,
340
- demodulate=demodulate,
341
- device=device
342
- )
343
-
344
- self.noise = NoiseInjection(isconcat)
345
- #self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
346
- #self.activate = ScaledLeakyReLU(0.2)
347
- feat_multiplier = 2 if isconcat else 1
348
- self.activate = FusedLeakyReLU(out_channel*feat_multiplier, device=device)
349
-
350
- def forward(self, input, style, noise=None):
351
- out = self.conv(input, style)
352
- out = self.noise(out, noise=noise)
353
- # out = out + self.bias
354
- out = self.activate(out)
355
-
356
- return out
357
-
358
-
359
- class ToRGB(nn.Module):
360
- def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], device='cpu'):
361
- super().__init__()
362
-
363
- if upsample:
364
- self.upsample = Upsample(blur_kernel, device=device)
365
-
366
- self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False, device=device)
367
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
368
-
369
- def forward(self, input, style, skip=None):
370
- out = self.conv(input, style)
371
- out = out + self.bias
372
-
373
- if skip is not None:
374
- skip = self.upsample(skip)
375
-
376
- out = out + skip
377
-
378
- return out
379
-
380
- class Generator(nn.Module):
381
- def __init__(
382
- self,
383
- size,
384
- style_dim,
385
- n_mlp,
386
- channel_multiplier=2,
387
- blur_kernel=[1, 3, 3, 1],
388
- lr_mlp=0.01,
389
- isconcat=True,
390
- narrow=1,
391
- device='cpu'
392
- ):
393
- super().__init__()
394
-
395
- self.size = size
396
- self.n_mlp = n_mlp
397
- self.style_dim = style_dim
398
- self.feat_multiplier = 2 if isconcat else 1
399
-
400
- layers = [PixelNorm()]
401
-
402
- for i in range(n_mlp):
403
- layers.append(
404
- EqualLinear(
405
- style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu', device=device
406
- )
407
- )
408
-
409
- self.style = nn.Sequential(*layers)
410
-
411
- self.channels = {
412
- 4: int(512 * narrow),
413
- 8: int(512 * narrow),
414
- 16: int(512 * narrow),
415
- 32: int(512 * narrow),
416
- 64: int(256 * channel_multiplier * narrow),
417
- 128: int(128 * channel_multiplier * narrow),
418
- 256: int(64 * channel_multiplier * narrow),
419
- 512: int(32 * channel_multiplier * narrow),
420
- 1024: int(16 * channel_multiplier * narrow)
421
- }
422
-
423
- self.input = ConstantInput(self.channels[4])
424
- self.conv1 = StyledConv(
425
- self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat, device=device
426
- )
427
- self.to_rgb1 = ToRGB(self.channels[4]*self.feat_multiplier, style_dim, upsample=False, device=device)
428
-
429
- self.log_size = int(math.log(size, 2))
430
-
431
- self.convs = nn.ModuleList()
432
- self.upsamples = nn.ModuleList()
433
- self.to_rgbs = nn.ModuleList()
434
-
435
- in_channel = self.channels[4]
436
-
437
- for i in range(3, self.log_size + 1):
438
- out_channel = self.channels[2 ** i]
439
-
440
- self.convs.append(
441
- StyledConv(
442
- in_channel*self.feat_multiplier,
443
- out_channel,
444
- 3,
445
- style_dim,
446
- upsample=True,
447
- blur_kernel=blur_kernel,
448
- isconcat=isconcat,
449
- device=device
450
- )
451
- )
452
-
453
- self.convs.append(
454
- StyledConv(
455
- out_channel*self.feat_multiplier, out_channel, 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat, device=device
456
- )
457
- )
458
-
459
- self.to_rgbs.append(ToRGB(out_channel*self.feat_multiplier, style_dim, device=device))
460
-
461
- in_channel = out_channel
462
-
463
- self.n_latent = self.log_size * 2 - 2
464
-
465
- def make_noise(self):
466
- device = self.input.input.device
467
-
468
- noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
469
-
470
- for i in range(3, self.log_size + 1):
471
- for _ in range(2):
472
- noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
473
-
474
- return noises
475
-
476
- def mean_latent(self, n_latent):
477
- latent_in = torch.randn(
478
- n_latent, self.style_dim, device=self.input.input.device
479
- )
480
- latent = self.style(latent_in).mean(0, keepdim=True)
481
-
482
- return latent
483
-
484
- def get_latent(self, input):
485
- return self.style(input)
486
-
487
- def forward(
488
- self,
489
- styles,
490
- return_latents=False,
491
- inject_index=None,
492
- truncation=1,
493
- truncation_latent=None,
494
- input_is_latent=False,
495
- noise=None,
496
- ):
497
- if not input_is_latent:
498
- styles = [self.style(s) for s in styles]
499
-
500
- if noise is None:
501
- '''
502
- noise = [None] * (2 * (self.log_size - 2) + 1)
503
- '''
504
- noise = []
505
- batch = styles[0].shape[0]
506
- for i in range(self.n_mlp + 1):
507
- size = 2 ** (i+2)
508
- noise.append(torch.randn(batch, self.channels[size], size, size, device=styles[0].device))
509
-
510
- if truncation < 1:
511
- style_t = []
512
-
513
- for style in styles:
514
- style_t.append(
515
- truncation_latent + truncation * (style - truncation_latent)
516
- )
517
-
518
- styles = style_t
519
-
520
- if len(styles) < 2:
521
- inject_index = self.n_latent
522
-
523
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
524
-
525
- else:
526
- if inject_index is None:
527
- inject_index = random.randint(1, self.n_latent - 1)
528
-
529
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
530
- latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
531
-
532
- latent = torch.cat([latent, latent2], 1)
533
-
534
- out = self.input(latent)
535
- out = self.conv1(out, latent[:, 0], noise=noise[0])
536
-
537
- skip = self.to_rgb1(out, latent[:, 1])
538
-
539
- i = 1
540
- for conv1, conv2, noise1, noise2, to_rgb in zip(
541
- self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
542
- ):
543
- out = conv1(out, latent[:, i], noise=noise1)
544
- out = conv2(out, latent[:, i + 1], noise=noise2)
545
- skip = to_rgb(out, latent[:, i + 2], skip)
546
-
547
- i += 2
548
-
549
- image = skip
550
-
551
- if return_latents:
552
- return image, latent
553
-
554
- else:
555
- return image, None
556
-
557
- class ConvLayer(nn.Sequential):
558
- def __init__(
559
- self,
560
- in_channel,
561
- out_channel,
562
- kernel_size,
563
- downsample=False,
564
- blur_kernel=[1, 3, 3, 1],
565
- bias=True,
566
- activate=True,
567
- device='cpu'
568
- ):
569
- layers = []
570
-
571
- if downsample:
572
- factor = 2
573
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
574
- pad0 = (p + 1) // 2
575
- pad1 = p // 2
576
-
577
- layers.append(Blur(blur_kernel, pad=(pad0, pad1), device=device))
578
-
579
- stride = 2
580
- self.padding = 0
581
-
582
- else:
583
- stride = 1
584
- self.padding = kernel_size // 2
585
-
586
- layers.append(
587
- EqualConv2d(
588
- in_channel,
589
- out_channel,
590
- kernel_size,
591
- padding=self.padding,
592
- stride=stride,
593
- bias=bias and not activate,
594
- )
595
- )
596
-
597
- if activate:
598
- if bias:
599
- layers.append(FusedLeakyReLU(out_channel, device=device))
600
-
601
- else:
602
- layers.append(ScaledLeakyReLU(0.2))
603
-
604
- super().__init__(*layers)
605
-
606
-
607
- class ResBlock(nn.Module):
608
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], device='cpu'):
609
- super().__init__()
610
-
611
- self.conv1 = ConvLayer(in_channel, in_channel, 3, device=device)
612
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)
613
-
614
- self.skip = ConvLayer(
615
- in_channel, out_channel, 1, downsample=True, activate=False, bias=False
616
- )
617
-
618
- def forward(self, input):
619
- out = self.conv1(input)
620
- out = self.conv2(out)
621
-
622
- skip = self.skip(input)
623
- out = (out + skip) / math.sqrt(2)
624
-
625
- return out
626
-
627
- class FullGenerator(nn.Module):
628
- def __init__(
629
- self,
630
- size,
631
- style_dim,
632
- n_mlp,
633
- channel_multiplier=2,
634
- blur_kernel=[1, 3, 3, 1],
635
- lr_mlp=0.01,
636
- isconcat=True,
637
- narrow=1,
638
- device='cpu'
639
- ):
640
- super().__init__()
641
- channels = {
642
- 4: int(512 * narrow),
643
- 8: int(512 * narrow),
644
- 16: int(512 * narrow),
645
- 32: int(512 * narrow),
646
- 64: int(256 * channel_multiplier * narrow),
647
- 128: int(128 * channel_multiplier * narrow),
648
- 256: int(64 * channel_multiplier * narrow),
649
- 512: int(32 * channel_multiplier * narrow),
650
- 1024: int(16 * channel_multiplier * narrow)
651
- }
652
-
653
- self.log_size = int(math.log(size, 2))
654
- self.generator = Generator(size, style_dim, n_mlp, channel_multiplier=channel_multiplier, blur_kernel=blur_kernel, lr_mlp=lr_mlp, isconcat=isconcat, narrow=narrow, device=device)
655
-
656
- conv = [ConvLayer(3, channels[size], 1, device=device)]
657
- self.ecd0 = nn.Sequential(*conv)
658
- in_channel = channels[size]
659
-
660
- self.names = ['ecd%d'%i for i in range(self.log_size-1)]
661
- for i in range(self.log_size, 2, -1):
662
- out_channel = channels[2 ** (i - 1)]
663
- #conv = [ResBlock(in_channel, out_channel, blur_kernel)]
664
- conv = [ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)]
665
- setattr(self, self.names[self.log_size-i+1], nn.Sequential(*conv))
666
- in_channel = out_channel
667
- self.final_linear = nn.Sequential(EqualLinear(channels[4] * 4 * 4, style_dim, activation='fused_lrelu', device=device))
668
-
669
- def forward(self,
670
- inputs,
671
- return_latents=False,
672
- inject_index=None,
673
- truncation=1,
674
- truncation_latent=None,
675
- input_is_latent=False,
676
- ):
677
- noise = []
678
- for i in range(self.log_size-1):
679
- ecd = getattr(self, self.names[i])
680
- inputs = ecd(inputs)
681
- noise.append(inputs)
682
-
683
- inputs = inputs.view(inputs.shape[0], -1)
684
- outs = self.final_linear(inputs)
685
- noise = list(itertools.chain.from_iterable(itertools.repeat(x, 2) for x in noise))[::-1]
686
- outs = self.generator([outs], return_latents, inject_index, truncation, truncation_latent, input_is_latent, noise=noise[1:])
687
- return outs
688
-
689
- class Discriminator(nn.Module):
690
- def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], narrow=1, device='cpu'):
691
- super().__init__()
692
-
693
- channels = {
694
- 4: int(512 * narrow),
695
- 8: int(512 * narrow),
696
- 16: int(512 * narrow),
697
- 32: int(512 * narrow),
698
- 64: int(256 * channel_multiplier * narrow),
699
- 128: int(128 * channel_multiplier * narrow),
700
- 256: int(64 * channel_multiplier * narrow),
701
- 512: int(32 * channel_multiplier * narrow),
702
- 1024: int(16 * channel_multiplier * narrow)
703
- }
704
-
705
- convs = [ConvLayer(3, channels[size], 1, device=device)]
706
-
707
- log_size = int(math.log(size, 2))
708
-
709
- in_channel = channels[size]
710
-
711
- for i in range(log_size, 2, -1):
712
- out_channel = channels[2 ** (i - 1)]
713
-
714
- convs.append(ResBlock(in_channel, out_channel, blur_kernel, device=device))
715
-
716
- in_channel = out_channel
717
-
718
- self.convs = nn.Sequential(*convs)
719
-
720
- self.stddev_group = 4
721
- self.stddev_feat = 1
722
-
723
- self.final_conv = ConvLayer(in_channel + 1, channels[4], 3, device=device)
724
- self.final_linear = nn.Sequential(
725
- EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu', device=device),
726
- EqualLinear(channels[4], 1),
727
- )
728
-
729
- def forward(self, input):
730
- out = self.convs(input)
731
-
732
- batch, channel, height, width = out.shape
733
- group = min(batch, self.stddev_group)
734
- stddev = out.view(
735
- group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
736
- )
737
- stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
738
- stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
739
- stddev = stddev.repeat(group, 1, height, width)
740
- out = torch.cat([out, stddev], 1)
741
-
742
- out = self.final_conv(out)
743
-
744
- out = out.view(batch, -1)
745
- out = self.final_linear(out)
746
- return out
 
1
+ '''
2
+ @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3
+ @author: yangxy (yangtao9009@gmail.com)
4
+ '''
5
+ import math
6
+ import random
7
+ import functools
8
+ import operator
9
+ import itertools
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from torch.autograd import Function
15
+
16
+ from videoretalking.third_part.GPEN.face_model.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
17
+
18
+ class PixelNorm(nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def forward(self, input):
23
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
24
+
25
+
26
+ def make_kernel(k):
27
+ k = torch.tensor(k, dtype=torch.float32)
28
+
29
+ if k.ndim == 1:
30
+ k = k[None, :] * k[:, None]
31
+
32
+ k /= k.sum()
33
+
34
+ return k
35
+
36
+
37
+ class Upsample(nn.Module):
38
+ def __init__(self, kernel, factor=2, device='cpu'):
39
+ super().__init__()
40
+
41
+ self.factor = factor
42
+ kernel = make_kernel(kernel) * (factor ** 2)
43
+ self.register_buffer('kernel', kernel)
44
+
45
+ p = kernel.shape[0] - factor
46
+
47
+ pad0 = (p + 1) // 2 + factor - 1
48
+ pad1 = p // 2
49
+
50
+ self.pad = (pad0, pad1)
51
+ self.device = device
52
+
53
+ def forward(self, input):
54
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad, device=self.device)
55
+
56
+ return out
57
+
58
+
59
+ class Downsample(nn.Module):
60
+ def __init__(self, kernel, factor=2, device='cpu'):
61
+ super().__init__()
62
+
63
+ self.factor = factor
64
+ kernel = make_kernel(kernel)
65
+ self.register_buffer('kernel', kernel)
66
+
67
+ p = kernel.shape[0] - factor
68
+
69
+ pad0 = (p + 1) // 2
70
+ pad1 = p // 2
71
+
72
+ self.pad = (pad0, pad1)
73
+ self.device = device
74
+
75
+ def forward(self, input):
76
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad, device=self.device)
77
+
78
+ return out
79
+
80
+
81
+ class Blur(nn.Module):
82
+ def __init__(self, kernel, pad, upsample_factor=1, device='cpu'):
83
+ super().__init__()
84
+
85
+ kernel = make_kernel(kernel)
86
+
87
+ if upsample_factor > 1:
88
+ kernel = kernel * (upsample_factor ** 2)
89
+
90
+ self.register_buffer('kernel', kernel)
91
+
92
+ self.pad = pad
93
+ self.device = device
94
+
95
+ def forward(self, input):
96
+ out = upfirdn2d(input, self.kernel, pad=self.pad, device=self.device)
97
+
98
+ return out
99
+
100
+
101
+ class EqualConv2d(nn.Module):
102
+ def __init__(
103
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
104
+ ):
105
+ super().__init__()
106
+
107
+ self.weight = nn.Parameter(
108
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
109
+ )
110
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
111
+
112
+ self.stride = stride
113
+ self.padding = padding
114
+
115
+ if bias:
116
+ self.bias = nn.Parameter(torch.zeros(out_channel))
117
+
118
+ else:
119
+ self.bias = None
120
+
121
+ def forward(self, input):
122
+ out = F.conv2d(
123
+ input,
124
+ self.weight * self.scale,
125
+ bias=self.bias,
126
+ stride=self.stride,
127
+ padding=self.padding,
128
+ )
129
+
130
+ return out
131
+
132
+ def __repr__(self):
133
+ return (
134
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
135
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
136
+ )
137
+
138
+
139
+ class EqualLinear(nn.Module):
140
+ def __init__(
141
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, device='cpu'
142
+ ):
143
+ super().__init__()
144
+
145
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
146
+
147
+ if bias:
148
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
149
+
150
+ else:
151
+ self.bias = None
152
+
153
+ self.activation = activation
154
+ self.device = device
155
+
156
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
157
+ self.lr_mul = lr_mul
158
+
159
+ def forward(self, input):
160
+ if self.activation:
161
+ out = F.linear(input, self.weight * self.scale)
162
+ out = fused_leaky_relu(out, self.bias * self.lr_mul, device=self.device)
163
+
164
+ else:
165
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
166
+
167
+ return out
168
+
169
+ def __repr__(self):
170
+ return (
171
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
172
+ )
173
+
174
+
175
+ class ScaledLeakyReLU(nn.Module):
176
+ def __init__(self, negative_slope=0.2):
177
+ super().__init__()
178
+
179
+ self.negative_slope = negative_slope
180
+
181
+ def forward(self, input):
182
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
183
+
184
+ return out * math.sqrt(2)
185
+
186
+
187
+ class ModulatedConv2d(nn.Module):
188
+ def __init__(
189
+ self,
190
+ in_channel,
191
+ out_channel,
192
+ kernel_size,
193
+ style_dim,
194
+ demodulate=True,
195
+ upsample=False,
196
+ downsample=False,
197
+ blur_kernel=[1, 3, 3, 1],
198
+ device='cpu'
199
+ ):
200
+ super().__init__()
201
+
202
+ self.eps = 1e-8
203
+ self.kernel_size = kernel_size
204
+ self.in_channel = in_channel
205
+ self.out_channel = out_channel
206
+ self.upsample = upsample
207
+ self.downsample = downsample
208
+
209
+ if upsample:
210
+ factor = 2
211
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
212
+ pad0 = (p + 1) // 2 + factor - 1
213
+ pad1 = p // 2 + 1
214
+
215
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor, device=device)
216
+
217
+ if downsample:
218
+ factor = 2
219
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
220
+ pad0 = (p + 1) // 2
221
+ pad1 = p // 2
222
+
223
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), device=device)
224
+
225
+ fan_in = in_channel * kernel_size ** 2
226
+ self.scale = 1 / math.sqrt(fan_in)
227
+ self.padding = kernel_size // 2
228
+
229
+ self.weight = nn.Parameter(
230
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
231
+ )
232
+
233
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
234
+
235
+ self.demodulate = demodulate
236
+
237
+ def __repr__(self):
238
+ return (
239
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
240
+ f'upsample={self.upsample}, downsample={self.downsample})'
241
+ )
242
+
243
+ def forward(self, input, style):
244
+ batch, in_channel, height, width = input.shape
245
+
246
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
247
+ weight = self.scale * self.weight * style
248
+
249
+ if self.demodulate:
250
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
251
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
252
+
253
+ weight = weight.view(
254
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
255
+ )
256
+
257
+ if self.upsample:
258
+ input = input.view(1, batch * in_channel, height, width)
259
+ weight = weight.view(
260
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
261
+ )
262
+ weight = weight.transpose(1, 2).reshape(
263
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
264
+ )
265
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
266
+ _, _, height, width = out.shape
267
+ out = out.view(batch, self.out_channel, height, width)
268
+ out = self.blur(out)
269
+
270
+ elif self.downsample:
271
+ input = self.blur(input)
272
+ _, _, height, width = input.shape
273
+ input = input.view(1, batch * in_channel, height, width)
274
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
275
+ _, _, height, width = out.shape
276
+ out = out.view(batch, self.out_channel, height, width)
277
+
278
+ else:
279
+ input = input.view(1, batch * in_channel, height, width)
280
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
281
+ _, _, height, width = out.shape
282
+ out = out.view(batch, self.out_channel, height, width)
283
+
284
+ return out
285
+
286
+
287
+ class NoiseInjection(nn.Module):
288
+ def __init__(self, isconcat=True):
289
+ super().__init__()
290
+
291
+ self.isconcat = isconcat
292
+ self.weight = nn.Parameter(torch.zeros(1))
293
+
294
+ def forward(self, image, noise=None):
295
+ if noise is None:
296
+ batch, _, height, width = image.shape
297
+ noise = image.new_empty(batch, 1, height, width).normal_()
298
+
299
+ if self.isconcat:
300
+ return torch.cat((image, self.weight * noise), dim=1)
301
+ else:
302
+ return image + self.weight * noise
303
+
304
+
305
+ class ConstantInput(nn.Module):
306
+ def __init__(self, channel, size=4):
307
+ super().__init__()
308
+
309
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
310
+
311
+ def forward(self, input):
312
+ batch = input.shape[0]
313
+ out = self.input.repeat(batch, 1, 1, 1)
314
+
315
+ return out
316
+
317
+
318
+ class StyledConv(nn.Module):
319
+ def __init__(
320
+ self,
321
+ in_channel,
322
+ out_channel,
323
+ kernel_size,
324
+ style_dim,
325
+ upsample=False,
326
+ blur_kernel=[1, 3, 3, 1],
327
+ demodulate=True,
328
+ isconcat=True,
329
+ device='cpu'
330
+ ):
331
+ super().__init__()
332
+
333
+ self.conv = ModulatedConv2d(
334
+ in_channel,
335
+ out_channel,
336
+ kernel_size,
337
+ style_dim,
338
+ upsample=upsample,
339
+ blur_kernel=blur_kernel,
340
+ demodulate=demodulate,
341
+ device=device
342
+ )
343
+
344
+ self.noise = NoiseInjection(isconcat)
345
+ #self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
346
+ #self.activate = ScaledLeakyReLU(0.2)
347
+ feat_multiplier = 2 if isconcat else 1
348
+ self.activate = FusedLeakyReLU(out_channel*feat_multiplier, device=device)
349
+
350
+ def forward(self, input, style, noise=None):
351
+ out = self.conv(input, style)
352
+ out = self.noise(out, noise=noise)
353
+ # out = out + self.bias
354
+ out = self.activate(out)
355
+
356
+ return out
357
+
358
+
359
+ class ToRGB(nn.Module):
360
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], device='cpu'):
361
+ super().__init__()
362
+
363
+ if upsample:
364
+ self.upsample = Upsample(blur_kernel, device=device)
365
+
366
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False, device=device)
367
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
368
+
369
+ def forward(self, input, style, skip=None):
370
+ out = self.conv(input, style)
371
+ out = out + self.bias
372
+
373
+ if skip is not None:
374
+ skip = self.upsample(skip)
375
+
376
+ out = out + skip
377
+
378
+ return out
379
+
380
+ class Generator(nn.Module):
381
+ def __init__(
382
+ self,
383
+ size,
384
+ style_dim,
385
+ n_mlp,
386
+ channel_multiplier=2,
387
+ blur_kernel=[1, 3, 3, 1],
388
+ lr_mlp=0.01,
389
+ isconcat=True,
390
+ narrow=1,
391
+ device='cpu'
392
+ ):
393
+ super().__init__()
394
+
395
+ self.size = size
396
+ self.n_mlp = n_mlp
397
+ self.style_dim = style_dim
398
+ self.feat_multiplier = 2 if isconcat else 1
399
+
400
+ layers = [PixelNorm()]
401
+
402
+ for i in range(n_mlp):
403
+ layers.append(
404
+ EqualLinear(
405
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu', device=device
406
+ )
407
+ )
408
+
409
+ self.style = nn.Sequential(*layers)
410
+
411
+ self.channels = {
412
+ 4: int(512 * narrow),
413
+ 8: int(512 * narrow),
414
+ 16: int(512 * narrow),
415
+ 32: int(512 * narrow),
416
+ 64: int(256 * channel_multiplier * narrow),
417
+ 128: int(128 * channel_multiplier * narrow),
418
+ 256: int(64 * channel_multiplier * narrow),
419
+ 512: int(32 * channel_multiplier * narrow),
420
+ 1024: int(16 * channel_multiplier * narrow)
421
+ }
422
+
423
+ self.input = ConstantInput(self.channels[4])
424
+ self.conv1 = StyledConv(
425
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat, device=device
426
+ )
427
+ self.to_rgb1 = ToRGB(self.channels[4]*self.feat_multiplier, style_dim, upsample=False, device=device)
428
+
429
+ self.log_size = int(math.log(size, 2))
430
+
431
+ self.convs = nn.ModuleList()
432
+ self.upsamples = nn.ModuleList()
433
+ self.to_rgbs = nn.ModuleList()
434
+
435
+ in_channel = self.channels[4]
436
+
437
+ for i in range(3, self.log_size + 1):
438
+ out_channel = self.channels[2 ** i]
439
+
440
+ self.convs.append(
441
+ StyledConv(
442
+ in_channel*self.feat_multiplier,
443
+ out_channel,
444
+ 3,
445
+ style_dim,
446
+ upsample=True,
447
+ blur_kernel=blur_kernel,
448
+ isconcat=isconcat,
449
+ device=device
450
+ )
451
+ )
452
+
453
+ self.convs.append(
454
+ StyledConv(
455
+ out_channel*self.feat_multiplier, out_channel, 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat, device=device
456
+ )
457
+ )
458
+
459
+ self.to_rgbs.append(ToRGB(out_channel*self.feat_multiplier, style_dim, device=device))
460
+
461
+ in_channel = out_channel
462
+
463
+ self.n_latent = self.log_size * 2 - 2
464
+
465
+ def make_noise(self):
466
+ device = self.input.input.device
467
+
468
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
469
+
470
+ for i in range(3, self.log_size + 1):
471
+ for _ in range(2):
472
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
473
+
474
+ return noises
475
+
476
+ def mean_latent(self, n_latent):
477
+ latent_in = torch.randn(
478
+ n_latent, self.style_dim, device=self.input.input.device
479
+ )
480
+ latent = self.style(latent_in).mean(0, keepdim=True)
481
+
482
+ return latent
483
+
484
+ def get_latent(self, input):
485
+ return self.style(input)
486
+
487
+ def forward(
488
+ self,
489
+ styles,
490
+ return_latents=False,
491
+ inject_index=None,
492
+ truncation=1,
493
+ truncation_latent=None,
494
+ input_is_latent=False,
495
+ noise=None,
496
+ ):
497
+ if not input_is_latent:
498
+ styles = [self.style(s) for s in styles]
499
+
500
+ if noise is None:
501
+ '''
502
+ noise = [None] * (2 * (self.log_size - 2) + 1)
503
+ '''
504
+ noise = []
505
+ batch = styles[0].shape[0]
506
+ for i in range(self.n_mlp + 1):
507
+ size = 2 ** (i+2)
508
+ noise.append(torch.randn(batch, self.channels[size], size, size, device=styles[0].device))
509
+
510
+ if truncation < 1:
511
+ style_t = []
512
+
513
+ for style in styles:
514
+ style_t.append(
515
+ truncation_latent + truncation * (style - truncation_latent)
516
+ )
517
+
518
+ styles = style_t
519
+
520
+ if len(styles) < 2:
521
+ inject_index = self.n_latent
522
+
523
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
524
+
525
+ else:
526
+ if inject_index is None:
527
+ inject_index = random.randint(1, self.n_latent - 1)
528
+
529
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
530
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
531
+
532
+ latent = torch.cat([latent, latent2], 1)
533
+
534
+ out = self.input(latent)
535
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
536
+
537
+ skip = self.to_rgb1(out, latent[:, 1])
538
+
539
+ i = 1
540
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
541
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
542
+ ):
543
+ out = conv1(out, latent[:, i], noise=noise1)
544
+ out = conv2(out, latent[:, i + 1], noise=noise2)
545
+ skip = to_rgb(out, latent[:, i + 2], skip)
546
+
547
+ i += 2
548
+
549
+ image = skip
550
+
551
+ if return_latents:
552
+ return image, latent
553
+
554
+ else:
555
+ return image, None
556
+
557
+ class ConvLayer(nn.Sequential):
558
+ def __init__(
559
+ self,
560
+ in_channel,
561
+ out_channel,
562
+ kernel_size,
563
+ downsample=False,
564
+ blur_kernel=[1, 3, 3, 1],
565
+ bias=True,
566
+ activate=True,
567
+ device='cpu'
568
+ ):
569
+ layers = []
570
+
571
+ if downsample:
572
+ factor = 2
573
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
574
+ pad0 = (p + 1) // 2
575
+ pad1 = p // 2
576
+
577
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1), device=device))
578
+
579
+ stride = 2
580
+ self.padding = 0
581
+
582
+ else:
583
+ stride = 1
584
+ self.padding = kernel_size // 2
585
+
586
+ layers.append(
587
+ EqualConv2d(
588
+ in_channel,
589
+ out_channel,
590
+ kernel_size,
591
+ padding=self.padding,
592
+ stride=stride,
593
+ bias=bias and not activate,
594
+ )
595
+ )
596
+
597
+ if activate:
598
+ if bias:
599
+ layers.append(FusedLeakyReLU(out_channel, device=device))
600
+
601
+ else:
602
+ layers.append(ScaledLeakyReLU(0.2))
603
+
604
+ super().__init__(*layers)
605
+
606
+
607
+ class ResBlock(nn.Module):
608
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], device='cpu'):
609
+ super().__init__()
610
+
611
+ self.conv1 = ConvLayer(in_channel, in_channel, 3, device=device)
612
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)
613
+
614
+ self.skip = ConvLayer(
615
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
616
+ )
617
+
618
+ def forward(self, input):
619
+ out = self.conv1(input)
620
+ out = self.conv2(out)
621
+
622
+ skip = self.skip(input)
623
+ out = (out + skip) / math.sqrt(2)
624
+
625
+ return out
626
+
627
+ class FullGenerator(nn.Module):
628
+ def __init__(
629
+ self,
630
+ size,
631
+ style_dim,
632
+ n_mlp,
633
+ channel_multiplier=2,
634
+ blur_kernel=[1, 3, 3, 1],
635
+ lr_mlp=0.01,
636
+ isconcat=True,
637
+ narrow=1,
638
+ device='cpu'
639
+ ):
640
+ super().__init__()
641
+ channels = {
642
+ 4: int(512 * narrow),
643
+ 8: int(512 * narrow),
644
+ 16: int(512 * narrow),
645
+ 32: int(512 * narrow),
646
+ 64: int(256 * channel_multiplier * narrow),
647
+ 128: int(128 * channel_multiplier * narrow),
648
+ 256: int(64 * channel_multiplier * narrow),
649
+ 512: int(32 * channel_multiplier * narrow),
650
+ 1024: int(16 * channel_multiplier * narrow)
651
+ }
652
+
653
+ self.log_size = int(math.log(size, 2))
654
+ self.generator = Generator(size, style_dim, n_mlp, channel_multiplier=channel_multiplier, blur_kernel=blur_kernel, lr_mlp=lr_mlp, isconcat=isconcat, narrow=narrow, device=device)
655
+
656
+ conv = [ConvLayer(3, channels[size], 1, device=device)]
657
+ self.ecd0 = nn.Sequential(*conv)
658
+ in_channel = channels[size]
659
+
660
+ self.names = ['ecd%d'%i for i in range(self.log_size-1)]
661
+ for i in range(self.log_size, 2, -1):
662
+ out_channel = channels[2 ** (i - 1)]
663
+ #conv = [ResBlock(in_channel, out_channel, blur_kernel)]
664
+ conv = [ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)]
665
+ setattr(self, self.names[self.log_size-i+1], nn.Sequential(*conv))
666
+ in_channel = out_channel
667
+ self.final_linear = nn.Sequential(EqualLinear(channels[4] * 4 * 4, style_dim, activation='fused_lrelu', device=device))
668
+
669
+ def forward(self,
670
+ inputs,
671
+ return_latents=False,
672
+ inject_index=None,
673
+ truncation=1,
674
+ truncation_latent=None,
675
+ input_is_latent=False,
676
+ ):
677
+ noise = []
678
+ for i in range(self.log_size-1):
679
+ ecd = getattr(self, self.names[i])
680
+ inputs = ecd(inputs)
681
+ noise.append(inputs)
682
+
683
+ inputs = inputs.view(inputs.shape[0], -1)
684
+ outs = self.final_linear(inputs)
685
+ noise = list(itertools.chain.from_iterable(itertools.repeat(x, 2) for x in noise))[::-1]
686
+ outs = self.generator([outs], return_latents, inject_index, truncation, truncation_latent, input_is_latent, noise=noise[1:])
687
+ return outs
688
+
689
+ class Discriminator(nn.Module):
690
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], narrow=1, device='cpu'):
691
+ super().__init__()
692
+
693
+ channels = {
694
+ 4: int(512 * narrow),
695
+ 8: int(512 * narrow),
696
+ 16: int(512 * narrow),
697
+ 32: int(512 * narrow),
698
+ 64: int(256 * channel_multiplier * narrow),
699
+ 128: int(128 * channel_multiplier * narrow),
700
+ 256: int(64 * channel_multiplier * narrow),
701
+ 512: int(32 * channel_multiplier * narrow),
702
+ 1024: int(16 * channel_multiplier * narrow)
703
+ }
704
+
705
+ convs = [ConvLayer(3, channels[size], 1, device=device)]
706
+
707
+ log_size = int(math.log(size, 2))
708
+
709
+ in_channel = channels[size]
710
+
711
+ for i in range(log_size, 2, -1):
712
+ out_channel = channels[2 ** (i - 1)]
713
+
714
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel, device=device))
715
+
716
+ in_channel = out_channel
717
+
718
+ self.convs = nn.Sequential(*convs)
719
+
720
+ self.stddev_group = 4
721
+ self.stddev_feat = 1
722
+
723
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3, device=device)
724
+ self.final_linear = nn.Sequential(
725
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu', device=device),
726
+ EqualLinear(channels[4], 1),
727
+ )
728
+
729
+ def forward(self, input):
730
+ out = self.convs(input)
731
+
732
+ batch, channel, height, width = out.shape
733
+ group = min(batch, self.stddev_group)
734
+ stddev = out.view(
735
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
736
+ )
737
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
738
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
739
+ stddev = stddev.repeat(group, 1, height, width)
740
+ out = torch.cat([out, stddev], 1)
741
+
742
+ out = self.final_conv(out)
743
+
744
+ out = out.view(batch, -1)
745
+ out = self.final_linear(out)
746
+ return out