maloyan committed on
Commit
6282546
1 Parent(s): e3bfc1e
app.py ADDED
@@ -0,0 +1,91 @@
+ from argparse import Namespace
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torchvision.transforms as transforms
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+ from models.psp import pSp
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ img_transforms = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor()]
+ )
+
+ def log_input_image(x, opts):
+     if opts.label_nc == 0:
+         return tensor2im(x)
+     elif opts.label_nc == 1:
+         return tensor2sketch(x)
+     else:
+         return tensor2map(x)
+
+
+ def tensor2im(var):
+     var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
+     var = ((var + 1) / 2)
+     var[var < 0] = 0
+     var[var > 1] = 1
+     var = var * 255
+     return Image.fromarray(var.astype('uint8'))
+
+ def tensor2map(var):
+     mask = np.argmax(var.data.cpu().numpy(), axis=0)
+     print(np.unique(mask))
+     colors = get_colors()
+     mask_image = np.zeros(shape=(mask.shape[0], mask.shape[1], 3))
+     for class_idx in np.unique(mask):
+         mask_image[mask == class_idx] = colors[class_idx]
+     mask_image = mask_image.astype('uint8')
+     return Image.fromarray(mask_image)
+
+ def tensor2sketch(var):
+     im = var[0].cpu().detach().numpy()
+     im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
+     im = (im * 255).astype(np.uint8)
+     return Image.fromarray(im)
+
+ def get_colors():
+     # currently support up to 19 classes (for the celebs-hq-mask dataset)
+     colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255],
+               [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204],
+               [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
+     return colors
+
+ def sketch_recognition(img):
+     from_im = img_transforms(Image.fromarray(img))
+     with torch.no_grad():
+         res = net(from_im.unsqueeze(0).to(device))
+     return tensor2im(res[0])
+
+
+ path = hf_hub_download('huggan/TediGAN_sketch', 'psp_celebs_sketch_to_face.pt')
+ ckpt = torch.load(path, map_location=device)
+
+ opts = ckpt['opts']
+ opts.update({"checkpoint_path": path})
+ opts = Namespace(**opts)
+
+ net = pSp(opts)
+ net.eval()
+ net.to(device)
+
+ iface = gr.Interface(
+     fn=sketch_recognition,
+     inputs=gr.inputs.Image(
+         shape=(256, 256),
+         image_mode="L",
+         invert_colors=False,
+         source="canvas",
+         tool="editor",
+         type="numpy",
+         label=None,
+         optional=False
+     ),
+     outputs="image"
+ )
+ iface.launch()
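
A quick way to exercise the demo without the browser UI is to call sketch_recognition directly with the kind of array the Gradio canvas produces; this is only a sanity-check sketch and assumes the checkpoint download and pSp construction above have already run:

import numpy as np

blank_canvas = np.full((256, 256), 255, dtype=np.uint8)  # white 256x256 grayscale canvas, like the sketchpad sends
face = sketch_recognition(blank_canvas)                   # returns a 256x256 PIL.Image
face.save("generated_face.png")
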
models/__init__.py ADDED
File without changes
models/encoders/__init__.py ADDED
File without changes
models/encoders/helpers.py ADDED
@@ -0,0 +1,119 @@
+ from collections import namedtuple
+ import torch
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+ """
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+ """
+
+
+ class Flatten(Module):
+     def forward(self, input):
+         return input.view(input.size(0), -1)
+
+
+ def l2_norm(input, axis=1):
+     norm = torch.norm(input, 2, axis, True)
+     output = torch.div(input, norm)
+     return output
+
+
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+     """ A named tuple describing a ResNet block. """
+
+
+ def get_block(in_channel, depth, num_units, stride=2):
+     return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+ def get_blocks(num_layers):
+     if num_layers == 50:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=4),
+             get_block(in_channel=128, depth=256, num_units=14),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 100:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=13),
+             get_block(in_channel=128, depth=256, num_units=30),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 152:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=8),
+             get_block(in_channel=128, depth=256, num_units=36),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     else:
+         raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+     return blocks
+
+
+ class SEModule(Module):
+     def __init__(self, channels, reduction):
+         super(SEModule, self).__init__()
+         self.avg_pool = AdaptiveAvgPool2d(1)
+         self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+         self.relu = ReLU(inplace=True)
+         self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+         self.sigmoid = Sigmoid()
+
+     def forward(self, x):
+         module_input = x
+         x = self.avg_pool(x)
+         x = self.fc1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         x = self.sigmoid(x)
+         return module_input * x
+
+
+ class bottleneck_IR(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
+
+
+ class bottleneck_IR_SE(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR_SE, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+             PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+             BatchNorm2d(depth),
+             SEModule(depth, 16)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
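
The three branches of get_blocks only differ in how many bottleneck units each of the four stages gets; a small inspection sketch (assuming the repo root is on PYTHONPATH) makes the IR-50 layout explicit:

from models.encoders.helpers import get_blocks

blocks = get_blocks(50)
print([len(stage) for stage in blocks])                              # [3, 4, 14, 3]
print([(stage[0].in_channel, stage[0].depth) for stage in blocks])   # [(64, 64), (64, 128), (128, 256), (256, 512)]
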
models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+ """
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+ """
+
+
+ class Backbone(Module):
+     def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+         super(Backbone, self).__init__()
+         assert input_size in [112, 224], "input_size should be 112 or 224"
+         assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+         assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         if input_size == 112:
+             self.output_layer = Sequential(BatchNorm2d(512),
+                                            Dropout(drop_ratio),
+                                            Flatten(),
+                                            Linear(512 * 7 * 7, 512),
+                                            BatchNorm1d(512, affine=affine))
+         else:
+             self.output_layer = Sequential(BatchNorm2d(512),
+                                            Dropout(drop_ratio),
+                                            Flatten(),
+                                            Linear(512 * 14 * 14, 512),
+                                            BatchNorm1d(512, affine=affine))
+
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+     def forward(self, x):
+         x = self.input_layer(x)
+         x = self.body(x)
+         x = self.output_layer(x)
+         return l2_norm(x)
+
+
+ def IR_50(input_size):
+     """Constructs an ir-50 model."""
+     model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_101(input_size):
+     """Constructs an ir-101 model."""
+     model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_152(input_size):
+     """Constructs an ir-152 model."""
+     model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_50(input_size):
+     """Constructs an ir_se-50 model."""
+     model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_101(input_size):
+     """Constructs an ir_se-101 model."""
+     model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_152(input_size):
+     """Constructs an ir_se-152 model."""
+     model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
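
As a shape check (a sketch, assuming the repo root is importable; CPU is enough), the IR-SE-50 backbone maps a 3x112x112 face crop to an L2-normalized 512-d embedding:

import torch

from models.encoders.model_irse import IR_SE_50

model = IR_SE_50(input_size=112).eval()
with torch.no_grad():
    emb = model(torch.randn(2, 3, 112, 112))
print(emb.shape)        # torch.Size([2, 512])
print(emb.norm(dim=1))  # each row has unit norm because of l2_norm
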
models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,185 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module
6
+
7
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
8
+ from models.stylegan2.model import EqualLinear
9
+
10
+
11
+ class GradualStyleBlock(Module):
12
+ def __init__(self, in_c, out_c, spatial):
13
+ super(GradualStyleBlock, self).__init__()
14
+ self.out_c = out_c
15
+ self.spatial = spatial
16
+ num_pools = int(np.log2(spatial))
17
+ modules = []
18
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
19
+ nn.LeakyReLU()]
20
+ for i in range(num_pools - 1):
21
+ modules += [
22
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
23
+ nn.LeakyReLU()
24
+ ]
25
+ self.convs = nn.Sequential(*modules)
26
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
27
+
28
+ def forward(self, x):
29
+ x = self.convs(x)
30
+ x = x.view(-1, self.out_c)
31
+ x = self.linear(x)
32
+ return x
33
+
34
+
35
+ class GradualStyleEncoder(Module):
36
+ def __init__(self, num_layers, mode='ir', opts=None):
37
+ super(GradualStyleEncoder, self).__init__()
38
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
39
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
40
+ blocks = get_blocks(num_layers)
41
+ if mode == 'ir':
42
+ unit_module = bottleneck_IR
43
+ elif mode == 'ir_se':
44
+ unit_module = bottleneck_IR_SE
45
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
46
+ BatchNorm2d(64),
47
+ PReLU(64))
48
+ modules = []
49
+ for block in blocks:
50
+ for bottleneck in block:
51
+ modules.append(unit_module(bottleneck.in_channel,
52
+ bottleneck.depth,
53
+ bottleneck.stride))
54
+ self.body = Sequential(*modules)
55
+
56
+ self.styles = nn.ModuleList()
57
+ self.style_count = 18
58
+ self.coarse_ind = 3
59
+ self.middle_ind = 7
60
+ for i in range(self.style_count):
61
+ if i < self.coarse_ind:
62
+ style = GradualStyleBlock(512, 512, 16)
63
+ elif i < self.middle_ind:
64
+ style = GradualStyleBlock(512, 512, 32)
65
+ else:
66
+ style = GradualStyleBlock(512, 512, 64)
67
+ self.styles.append(style)
68
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
69
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
70
+
71
+ def _upsample_add(self, x, y):
72
+ '''Upsample and add two feature maps.
73
+ Args:
74
+ x: (Variable) top feature map to be upsampled.
75
+ y: (Variable) lateral feature map.
76
+ Returns:
77
+ (Variable) added feature map.
78
+ Note in PyTorch, when input size is odd, the upsampled feature map
79
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
80
+ maybe not equal to the lateral feature map size.
81
+ e.g.
82
+ original input size: [N,_,15,15] ->
83
+ conv2d feature map size: [N,_,8,8] ->
84
+ upsampled feature map size: [N,_,16,16]
85
+ So we choose bilinear upsample which supports arbitrary output sizes.
86
+ '''
87
+ _, _, H, W = y.size()
88
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
89
+
90
+ def forward(self, x):
91
+ x = self.input_layer(x)
92
+
93
+ latents = []
94
+ modulelist = list(self.body._modules.values())
95
+ for i, l in enumerate(modulelist):
96
+ x = l(x)
97
+ if i == 6:
98
+ c1 = x
99
+ elif i == 20:
100
+ c2 = x
101
+ elif i == 23:
102
+ c3 = x
103
+
104
+ for j in range(self.coarse_ind):
105
+ latents.append(self.styles[j](c3))
106
+
107
+ p2 = self._upsample_add(c3, self.latlayer1(c2))
108
+ for j in range(self.coarse_ind, self.middle_ind):
109
+ latents.append(self.styles[j](p2))
110
+
111
+ p1 = self._upsample_add(p2, self.latlayer2(c1))
112
+ for j in range(self.middle_ind, self.style_count):
113
+ latents.append(self.styles[j](p1))
114
+
115
+ out = torch.stack(latents, dim=1)
116
+ return out
117
+
118
+
119
+ class BackboneEncoderUsingLastLayerIntoW(Module):
120
+ def __init__(self, num_layers, mode='ir', opts=None):
121
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
122
+ print('Using BackboneEncoderUsingLastLayerIntoW')
123
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
124
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
125
+ blocks = get_blocks(num_layers)
126
+ if mode == 'ir':
127
+ unit_module = bottleneck_IR
128
+ elif mode == 'ir_se':
129
+ unit_module = bottleneck_IR_SE
130
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
131
+ BatchNorm2d(64),
132
+ PReLU(64))
133
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
134
+ self.linear = EqualLinear(512, 512, lr_mul=1)
135
+ modules = []
136
+ for block in blocks:
137
+ for bottleneck in block:
138
+ modules.append(unit_module(bottleneck.in_channel,
139
+ bottleneck.depth,
140
+ bottleneck.stride))
141
+ self.body = Sequential(*modules)
142
+
143
+ def forward(self, x):
144
+ x = self.input_layer(x)
145
+ x = self.body(x)
146
+ x = self.output_pool(x)
147
+ x = x.view(-1, 512)
148
+ x = self.linear(x)
149
+ return x
150
+
151
+
152
+ class BackboneEncoderUsingLastLayerIntoWPlus(Module):
153
+ def __init__(self, num_layers, mode='ir', opts=None):
154
+ super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
155
+ print('Using BackboneEncoderUsingLastLayerIntoWPlus')
156
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
157
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
158
+ blocks = get_blocks(num_layers)
159
+ if mode == 'ir':
160
+ unit_module = bottleneck_IR
161
+ elif mode == 'ir_se':
162
+ unit_module = bottleneck_IR_SE
163
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
164
+ BatchNorm2d(64),
165
+ PReLU(64))
166
+ self.output_layer_2 = Sequential(BatchNorm2d(512),
167
+ torch.nn.AdaptiveAvgPool2d((7, 7)),
168
+ Flatten(),
169
+ Linear(512 * 7 * 7, 512))
170
+ self.linear = EqualLinear(512, 512 * 18, lr_mul=1)
171
+ modules = []
172
+ for block in blocks:
173
+ for bottleneck in block:
174
+ modules.append(unit_module(bottleneck.in_channel,
175
+ bottleneck.depth,
176
+ bottleneck.stride))
177
+ self.body = Sequential(*modules)
178
+
179
+ def forward(self, x):
180
+ x = self.input_layer(x)
181
+ x = self.body(x)
182
+ x = self.output_layer_2(x)
183
+ x = self.linear(x)
184
+ x = x.view(-1, 18, 512)
185
+ return x
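
The _upsample_add docstring above is the whole trick behind the feature-pyramid fusion in GradualStyleEncoder; a self-contained torch sketch of the same operation (hypothetical shapes, not tied to the encoder) follows:

import torch
import torch.nn.functional as F

coarse = torch.randn(1, 512, 16, 16)    # deeper, lower-resolution feature map
lateral = torch.randn(1, 512, 31, 31)   # shallower lateral map with an odd size
fused = F.interpolate(coarse, size=lateral.shape[-2:], mode='bilinear', align_corners=True) + lateral
print(fused.shape)                       # torch.Size([1, 512, 31, 31])
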
models/psp.py ADDED
@@ -0,0 +1,97 @@
+ """
+ This file defines the core research contribution
+ """
+ import torch
+ from torch import nn
+
+ from models.encoders import psp_encoders
+ from models.stylegan2.model import Generator
+
+ # from configs.paths_config import model_paths
+
+
+ def get_keys(d, name):
+     if 'state_dict' in d:
+         d = d['state_dict']
+     d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+     return d_filt
+
+
+ class pSp(nn.Module):
+
+     def __init__(self, opts):
+         super(pSp, self).__init__()
+         self.set_opts(opts)
+         # Define architecture
+         self.encoder = self.set_encoder()
+         self.decoder = Generator(1024, 512, 8)
+         self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+         # Load weights if needed
+         self.load_weights()
+
+     def set_encoder(self):
+         if self.opts.encoder_type == 'GradualStyleEncoder':
+             encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
+         elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
+             encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
+         elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
+             encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
+         else:
+             raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
+         return encoder
+
+     def load_weights(self):
+         print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
+         ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+         self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
+         self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
+         self.__load_latent_avg(ckpt)
+
+     def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+                 inject_latent=None, return_latents=False, alpha=None):
+         if input_code:
+             codes = x
+         else:
+             codes = self.encoder(x)
+             # normalize with respect to the center of an average face
+             if self.opts.start_from_latent_avg:
+                 if self.opts.learn_in_w:
+                     codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
+                 else:
+                     codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
+
+         if latent_mask is not None:
+             for i in latent_mask:
+                 if inject_latent is not None:
+                     if alpha is not None:
+                         codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+                     else:
+                         codes[:, i] = inject_latent[:, i]
+                 else:
+                     codes[:, i] = 0
+
+         input_is_latent = not input_code
+         images, result_latent = self.decoder([codes],
+                                              input_is_latent=input_is_latent,
+                                              randomize_noise=randomize_noise,
+                                              return_latents=return_latents)
+
+         if resize:
+             images = self.face_pool(images)
+
+         if return_latents:
+             return images, result_latent
+         else:
+             return images
+
+     def set_opts(self, opts):
+         self.opts = opts
+
+     def __load_latent_avg(self, ckpt, repeat=None):
+         if 'latent_avg' in ckpt:
+             self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
+             if repeat is not None:
+                 self.latent_avg = self.latent_avg.repeat(repeat, 1)
+         else:
+             self.latent_avg = None
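
The latent_mask/alpha branch in pSp.forward is just per-style-vector blending in W+; a standalone sketch on dummy codes (batch x 18 x 512, with arbitrarily chosen indices) shows the arithmetic:

import torch

codes = torch.randn(2, 18, 512)           # encoder output in W+
inject_latent = torch.randn(2, 18, 512)   # latents to mix in
alpha, latent_mask = 0.5, [8, 9, 10]      # blend only the mid-level styles

for i in latent_mask:
    codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
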
models/stylegan2/__init__.py ADDED
File without changes
models/stylegan2/model.py ADDED
@@ -0,0 +1,673 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
8
+
9
+
10
+ class PixelNorm(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def forward(self, input):
15
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
16
+
17
+
18
+ def make_kernel(k):
19
+ k = torch.tensor(k, dtype=torch.float32)
20
+
21
+ if k.ndim == 1:
22
+ k = k[None, :] * k[:, None]
23
+
24
+ k /= k.sum()
25
+
26
+ return k
27
+
28
+
29
+ class Upsample(nn.Module):
30
+ def __init__(self, kernel, factor=2):
31
+ super().__init__()
32
+
33
+ self.factor = factor
34
+ kernel = make_kernel(kernel) * (factor ** 2)
35
+ self.register_buffer('kernel', kernel)
36
+
37
+ p = kernel.shape[0] - factor
38
+
39
+ pad0 = (p + 1) // 2 + factor - 1
40
+ pad1 = p // 2
41
+
42
+ self.pad = (pad0, pad1)
43
+
44
+ def forward(self, input):
45
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
46
+
47
+ return out
48
+
49
+
50
+ class Downsample(nn.Module):
51
+ def __init__(self, kernel, factor=2):
52
+ super().__init__()
53
+
54
+ self.factor = factor
55
+ kernel = make_kernel(kernel)
56
+ self.register_buffer('kernel', kernel)
57
+
58
+ p = kernel.shape[0] - factor
59
+
60
+ pad0 = (p + 1) // 2
61
+ pad1 = p // 2
62
+
63
+ self.pad = (pad0, pad1)
64
+
65
+ def forward(self, input):
66
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
67
+
68
+ return out
69
+
70
+
71
+ class Blur(nn.Module):
72
+ def __init__(self, kernel, pad, upsample_factor=1):
73
+ super().__init__()
74
+
75
+ kernel = make_kernel(kernel)
76
+
77
+ if upsample_factor > 1:
78
+ kernel = kernel * (upsample_factor ** 2)
79
+
80
+ self.register_buffer('kernel', kernel)
81
+
82
+ self.pad = pad
83
+
84
+ def forward(self, input):
85
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
86
+
87
+ return out
88
+
89
+
90
+ class EqualConv2d(nn.Module):
91
+ def __init__(
92
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
93
+ ):
94
+ super().__init__()
95
+
96
+ self.weight = nn.Parameter(
97
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
98
+ )
99
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
100
+
101
+ self.stride = stride
102
+ self.padding = padding
103
+
104
+ if bias:
105
+ self.bias = nn.Parameter(torch.zeros(out_channel))
106
+
107
+ else:
108
+ self.bias = None
109
+
110
+ def forward(self, input):
111
+ out = F.conv2d(
112
+ input,
113
+ self.weight * self.scale,
114
+ bias=self.bias,
115
+ stride=self.stride,
116
+ padding=self.padding,
117
+ )
118
+
119
+ return out
120
+
121
+ def __repr__(self):
122
+ return (
123
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
124
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
125
+ )
126
+
127
+
128
+ class EqualLinear(nn.Module):
129
+ def __init__(
130
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
131
+ ):
132
+ super().__init__()
133
+
134
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
135
+
136
+ if bias:
137
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
138
+
139
+ else:
140
+ self.bias = None
141
+
142
+ self.activation = activation
143
+
144
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
145
+ self.lr_mul = lr_mul
146
+
147
+ def forward(self, input):
148
+ if self.activation:
149
+ out = F.linear(input, self.weight * self.scale)
150
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
151
+
152
+ else:
153
+ out = F.linear(
154
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
155
+ )
156
+
157
+ return out
158
+
159
+ def __repr__(self):
160
+ return (
161
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
162
+ )
163
+
164
+
165
+ class ScaledLeakyReLU(nn.Module):
166
+ def __init__(self, negative_slope=0.2):
167
+ super().__init__()
168
+
169
+ self.negative_slope = negative_slope
170
+
171
+ def forward(self, input):
172
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
173
+
174
+ return out * math.sqrt(2)
175
+
176
+
177
+ class ModulatedConv2d(nn.Module):
178
+ def __init__(
179
+ self,
180
+ in_channel,
181
+ out_channel,
182
+ kernel_size,
183
+ style_dim,
184
+ demodulate=True,
185
+ upsample=False,
186
+ downsample=False,
187
+ blur_kernel=[1, 3, 3, 1],
188
+ ):
189
+ super().__init__()
190
+
191
+ self.eps = 1e-8
192
+ self.kernel_size = kernel_size
193
+ self.in_channel = in_channel
194
+ self.out_channel = out_channel
195
+ self.upsample = upsample
196
+ self.downsample = downsample
197
+
198
+ if upsample:
199
+ factor = 2
200
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
201
+ pad0 = (p + 1) // 2 + factor - 1
202
+ pad1 = p // 2 + 1
203
+
204
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
205
+
206
+ if downsample:
207
+ factor = 2
208
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
209
+ pad0 = (p + 1) // 2
210
+ pad1 = p // 2
211
+
212
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
213
+
214
+ fan_in = in_channel * kernel_size ** 2
215
+ self.scale = 1 / math.sqrt(fan_in)
216
+ self.padding = kernel_size // 2
217
+
218
+ self.weight = nn.Parameter(
219
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
220
+ )
221
+
222
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
223
+
224
+ self.demodulate = demodulate
225
+
226
+ def __repr__(self):
227
+ return (
228
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
229
+ f'upsample={self.upsample}, downsample={self.downsample})'
230
+ )
231
+
232
+ def forward(self, input, style):
233
+ batch, in_channel, height, width = input.shape
234
+
235
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
236
+ weight = self.scale * self.weight * style
237
+
238
+ if self.demodulate:
239
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
240
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
241
+
242
+ weight = weight.view(
243
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
244
+ )
245
+
246
+ if self.upsample:
247
+ input = input.view(1, batch * in_channel, height, width)
248
+ weight = weight.view(
249
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
250
+ )
251
+ weight = weight.transpose(1, 2).reshape(
252
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
253
+ )
254
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
255
+ _, _, height, width = out.shape
256
+ out = out.view(batch, self.out_channel, height, width)
257
+ out = self.blur(out)
258
+
259
+ elif self.downsample:
260
+ input = self.blur(input)
261
+ _, _, height, width = input.shape
262
+ input = input.view(1, batch * in_channel, height, width)
263
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
264
+ _, _, height, width = out.shape
265
+ out = out.view(batch, self.out_channel, height, width)
266
+
267
+ else:
268
+ input = input.view(1, batch * in_channel, height, width)
269
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
270
+ _, _, height, width = out.shape
271
+ out = out.view(batch, self.out_channel, height, width)
272
+
273
+ return out
274
+
275
+
276
+ class NoiseInjection(nn.Module):
277
+ def __init__(self):
278
+ super().__init__()
279
+
280
+ self.weight = nn.Parameter(torch.zeros(1))
281
+
282
+ def forward(self, image, noise=None):
283
+ if noise is None:
284
+ batch, _, height, width = image.shape
285
+ noise = image.new_empty(batch, 1, height, width).normal_()
286
+
287
+ return image + self.weight * noise
288
+
289
+
290
+ class ConstantInput(nn.Module):
291
+ def __init__(self, channel, size=4):
292
+ super().__init__()
293
+
294
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
295
+
296
+ def forward(self, input):
297
+ batch = input.shape[0]
298
+ out = self.input.repeat(batch, 1, 1, 1)
299
+
300
+ return out
301
+
302
+
303
+ class StyledConv(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channel,
307
+ out_channel,
308
+ kernel_size,
309
+ style_dim,
310
+ upsample=False,
311
+ blur_kernel=[1, 3, 3, 1],
312
+ demodulate=True,
313
+ ):
314
+ super().__init__()
315
+
316
+ self.conv = ModulatedConv2d(
317
+ in_channel,
318
+ out_channel,
319
+ kernel_size,
320
+ style_dim,
321
+ upsample=upsample,
322
+ blur_kernel=blur_kernel,
323
+ demodulate=demodulate,
324
+ )
325
+
326
+ self.noise = NoiseInjection()
327
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
328
+ # self.activate = ScaledLeakyReLU(0.2)
329
+ self.activate = FusedLeakyReLU(out_channel)
330
+
331
+ def forward(self, input, style, noise=None):
332
+ out = self.conv(input, style)
333
+ out = self.noise(out, noise=noise)
334
+ # out = out + self.bias
335
+ out = self.activate(out)
336
+
337
+ return out
338
+
339
+
340
+ class ToRGB(nn.Module):
341
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
342
+ super().__init__()
343
+
344
+ if upsample:
345
+ self.upsample = Upsample(blur_kernel)
346
+
347
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
348
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
349
+
350
+ def forward(self, input, style, skip=None):
351
+ out = self.conv(input, style)
352
+ out = out + self.bias
353
+
354
+ if skip is not None:
355
+ skip = self.upsample(skip)
356
+
357
+ out = out + skip
358
+
359
+ return out
360
+
361
+
362
+ class Generator(nn.Module):
363
+ def __init__(
364
+ self,
365
+ size,
366
+ style_dim,
367
+ n_mlp,
368
+ channel_multiplier=2,
369
+ blur_kernel=[1, 3, 3, 1],
370
+ lr_mlp=0.01,
371
+ ):
372
+ super().__init__()
373
+
374
+ self.size = size
375
+
376
+ self.style_dim = style_dim
377
+
378
+ layers = [PixelNorm()]
379
+
380
+ for i in range(n_mlp):
381
+ layers.append(
382
+ EqualLinear(
383
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
384
+ )
385
+ )
386
+
387
+ self.style = nn.Sequential(*layers)
388
+
389
+ self.channels = {
390
+ 4: 512,
391
+ 8: 512,
392
+ 16: 512,
393
+ 32: 512,
394
+ 64: 256 * channel_multiplier,
395
+ 128: 128 * channel_multiplier,
396
+ 256: 64 * channel_multiplier,
397
+ 512: 32 * channel_multiplier,
398
+ 1024: 16 * channel_multiplier,
399
+ }
400
+
401
+ self.input = ConstantInput(self.channels[4])
402
+ self.conv1 = StyledConv(
403
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
404
+ )
405
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
406
+
407
+ self.log_size = int(math.log(size, 2))
408
+ self.num_layers = (self.log_size - 2) * 2 + 1
409
+
410
+ self.convs = nn.ModuleList()
411
+ self.upsamples = nn.ModuleList()
412
+ self.to_rgbs = nn.ModuleList()
413
+ self.noises = nn.Module()
414
+
415
+ in_channel = self.channels[4]
416
+
417
+ for layer_idx in range(self.num_layers):
418
+ res = (layer_idx + 5) // 2
419
+ shape = [1, 1, 2 ** res, 2 ** res]
420
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
421
+
422
+ for i in range(3, self.log_size + 1):
423
+ out_channel = self.channels[2 ** i]
424
+
425
+ self.convs.append(
426
+ StyledConv(
427
+ in_channel,
428
+ out_channel,
429
+ 3,
430
+ style_dim,
431
+ upsample=True,
432
+ blur_kernel=blur_kernel,
433
+ )
434
+ )
435
+
436
+ self.convs.append(
437
+ StyledConv(
438
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
439
+ )
440
+ )
441
+
442
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
443
+
444
+ in_channel = out_channel
445
+
446
+ self.n_latent = self.log_size * 2 - 2
447
+
448
+ def make_noise(self):
449
+ device = self.input.input.device
450
+
451
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
452
+
453
+ for i in range(3, self.log_size + 1):
454
+ for _ in range(2):
455
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
456
+
457
+ return noises
458
+
459
+ def mean_latent(self, n_latent):
460
+ latent_in = torch.randn(
461
+ n_latent, self.style_dim, device=self.input.input.device
462
+ )
463
+ latent = self.style(latent_in).mean(0, keepdim=True)
464
+
465
+ return latent
466
+
467
+ def get_latent(self, input):
468
+ return self.style(input)
469
+
470
+ def forward(
471
+ self,
472
+ styles,
473
+ return_latents=False,
474
+ return_features=False,
475
+ inject_index=None,
476
+ truncation=1,
477
+ truncation_latent=None,
478
+ input_is_latent=False,
479
+ noise=None,
480
+ randomize_noise=True,
481
+ ):
482
+ if not input_is_latent:
483
+ styles = [self.style(s) for s in styles]
484
+
485
+ if noise is None:
486
+ if randomize_noise:
487
+ noise = [None] * self.num_layers
488
+ else:
489
+ noise = [
490
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
491
+ ]
492
+
493
+ if truncation < 1:
494
+ style_t = []
495
+
496
+ for style in styles:
497
+ style_t.append(
498
+ truncation_latent + truncation * (style - truncation_latent)
499
+ )
500
+
501
+ styles = style_t
502
+
503
+ if len(styles) < 2:
504
+ inject_index = self.n_latent
505
+
506
+ if styles[0].ndim < 3:
507
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
508
+ else:
509
+ latent = styles[0]
510
+
511
+ else:
512
+ if inject_index is None:
513
+ inject_index = random.randint(1, self.n_latent - 1)
514
+
515
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
516
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
517
+
518
+ latent = torch.cat([latent, latent2], 1)
519
+
520
+ out = self.input(latent)
521
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
522
+
523
+ skip = self.to_rgb1(out, latent[:, 1])
524
+
525
+ i = 1
526
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
527
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
528
+ ):
529
+ out = conv1(out, latent[:, i], noise=noise1)
530
+ out = conv2(out, latent[:, i + 1], noise=noise2)
531
+ skip = to_rgb(out, latent[:, i + 2], skip)
532
+
533
+ i += 2
534
+
535
+ image = skip
536
+
537
+ if return_latents:
538
+ return image, latent
539
+ elif return_features:
540
+ return image, out
541
+ else:
542
+ return image, None
543
+
544
+
545
+ class ConvLayer(nn.Sequential):
546
+ def __init__(
547
+ self,
548
+ in_channel,
549
+ out_channel,
550
+ kernel_size,
551
+ downsample=False,
552
+ blur_kernel=[1, 3, 3, 1],
553
+ bias=True,
554
+ activate=True,
555
+ ):
556
+ layers = []
557
+
558
+ if downsample:
559
+ factor = 2
560
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
561
+ pad0 = (p + 1) // 2
562
+ pad1 = p // 2
563
+
564
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
565
+
566
+ stride = 2
567
+ self.padding = 0
568
+
569
+ else:
570
+ stride = 1
571
+ self.padding = kernel_size // 2
572
+
573
+ layers.append(
574
+ EqualConv2d(
575
+ in_channel,
576
+ out_channel,
577
+ kernel_size,
578
+ padding=self.padding,
579
+ stride=stride,
580
+ bias=bias and not activate,
581
+ )
582
+ )
583
+
584
+ if activate:
585
+ if bias:
586
+ layers.append(FusedLeakyReLU(out_channel))
587
+
588
+ else:
589
+ layers.append(ScaledLeakyReLU(0.2))
590
+
591
+ super().__init__(*layers)
592
+
593
+
594
+ class ResBlock(nn.Module):
595
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
596
+ super().__init__()
597
+
598
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
599
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
600
+
601
+ self.skip = ConvLayer(
602
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
603
+ )
604
+
605
+ def forward(self, input):
606
+ out = self.conv1(input)
607
+ out = self.conv2(out)
608
+
609
+ skip = self.skip(input)
610
+ out = (out + skip) / math.sqrt(2)
611
+
612
+ return out
613
+
614
+
615
+ class Discriminator(nn.Module):
616
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
617
+ super().__init__()
618
+
619
+ channels = {
620
+ 4: 512,
621
+ 8: 512,
622
+ 16: 512,
623
+ 32: 512,
624
+ 64: 256 * channel_multiplier,
625
+ 128: 128 * channel_multiplier,
626
+ 256: 64 * channel_multiplier,
627
+ 512: 32 * channel_multiplier,
628
+ 1024: 16 * channel_multiplier,
629
+ }
630
+
631
+ convs = [ConvLayer(3, channels[size], 1)]
632
+
633
+ log_size = int(math.log(size, 2))
634
+
635
+ in_channel = channels[size]
636
+
637
+ for i in range(log_size, 2, -1):
638
+ out_channel = channels[2 ** (i - 1)]
639
+
640
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
641
+
642
+ in_channel = out_channel
643
+
644
+ self.convs = nn.Sequential(*convs)
645
+
646
+ self.stddev_group = 4
647
+ self.stddev_feat = 1
648
+
649
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
650
+ self.final_linear = nn.Sequential(
651
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
652
+ EqualLinear(channels[4], 1),
653
+ )
654
+
655
+ def forward(self, input):
656
+ out = self.convs(input)
657
+
658
+ batch, channel, height, width = out.shape
659
+ group = min(batch, self.stddev_group)
660
+ stddev = out.view(
661
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
662
+ )
663
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
664
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
665
+ stddev = stddev.repeat(group, 1, height, width)
666
+ out = torch.cat([out, stddev], 1)
667
+
668
+ out = self.final_conv(out)
669
+
670
+ out = out.view(batch, -1)
671
+ out = self.final_linear(out)
672
+
673
+ return out
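
A usage sketch for the Generator above (assumes a CUDA device and a working nvcc, since importing models.stylegan2.model JIT-compiles the fused_act/upfirdn2d extensions; with random weights this only verifies shapes): sample z vectors, estimate the mean latent, and apply the truncation trick:

import torch

from models.stylegan2.model import Generator

g = Generator(1024, 512, 8).cuda().eval()
with torch.no_grad():
    mean_w = g.mean_latent(4096)
    z = torch.randn(4, 512, device='cuda')
    images, _ = g([z], truncation=0.7, truncation_latent=mean_w)
print(images.shape)   # torch.Size([4, 3, 1024, 1024])
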
models/stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
+ from .upfirdn2d import upfirdn2d
models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,85 @@
+ import os
+
+ import torch
+ from torch import nn
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+ module_path = os.path.dirname(__file__)
+ fused = load(
+     'fused',
+     sources=[
+         os.path.join(module_path, 'fused_bias_act.cpp'),
+         os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+     ],
+ )
+
+
+ class FusedLeakyReLUFunctionBackward(Function):
+     @staticmethod
+     def forward(ctx, grad_output, out, negative_slope, scale):
+         ctx.save_for_backward(out)
+         ctx.negative_slope = negative_slope
+         ctx.scale = scale
+
+         empty = grad_output.new_empty(0)
+
+         grad_input = fused.fused_bias_act(
+             grad_output, empty, out, 3, 1, negative_slope, scale
+         )
+
+         dim = [0]
+
+         if grad_input.ndim > 2:
+             dim += list(range(2, grad_input.ndim))
+
+         grad_bias = grad_input.sum(dim).detach()
+
+         return grad_input, grad_bias
+
+     @staticmethod
+     def backward(ctx, gradgrad_input, gradgrad_bias):
+         out, = ctx.saved_tensors
+         gradgrad_out = fused.fused_bias_act(
+             gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+         )
+
+         return gradgrad_out, None, None, None
+
+
+ class FusedLeakyReLUFunction(Function):
+     @staticmethod
+     def forward(ctx, input, bias, negative_slope, scale):
+         empty = input.new_empty(0)
+         out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+         ctx.save_for_backward(out)
+         ctx.negative_slope = negative_slope
+         ctx.scale = scale
+
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         out, = ctx.saved_tensors
+
+         grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+             grad_output, out, ctx.negative_slope, ctx.scale
+         )
+
+         return grad_input, grad_bias, None, None
+
+
+ class FusedLeakyReLU(nn.Module):
+     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+         super().__init__()
+
+         self.bias = nn.Parameter(torch.zeros(channel))
+         self.negative_slope = negative_slope
+         self.scale = scale
+
+     def forward(self, input):
+         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+     return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
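
For readability, here is a plain-PyTorch reference of what fused_leaky_relu computes (a sketch, not the CUDA path above): add a per-channel bias, apply LeakyReLU, and rescale by sqrt(2):

import torch
import torch.nn.functional as F


def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    bias = bias.view(1, -1, *([1] * (x.ndim - 2)))   # broadcast the bias over batch and spatial dims
    return F.leaky_relu(x + bias, negative_slope) * scale


x = torch.randn(2, 8, 4, 4)
print(fused_leaky_relu_reference(x, torch.zeros(8)).shape)   # torch.Size([2, 8, 4, 4])
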
models/stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
+ #include <torch/extension.h>
+
+
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+                                 int act, int grad, float alpha, float scale);
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+                              int act, int grad, float alpha, float scale) {
+     CHECK_CUDA(input);
+     CHECK_CUDA(bias);
+
+     return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+ }
models/stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
models/stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
+ #include <torch/extension.h>
+
+
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+                            int up_x, int up_y, int down_x, int down_y,
+                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+                         int up_x, int up_y, int down_x, int down_y,
+                         int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+     CHECK_CUDA(input);
+     CHECK_CUDA(kernel);
+
+     return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+ }
models/stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,184 @@
1
+ import os
2
+
3
+ import torch
+ from torch.nn import functional as F
4
+ from torch.autograd import Function
5
+ from torch.utils.cpp_extension import load
6
+
7
+ module_path = os.path.dirname(__file__)
8
+ upfirdn2d_op = load(
9
+ 'upfirdn2d',
10
+ sources=[
11
+ os.path.join(module_path, 'upfirdn2d.cpp'),
12
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
13
+ ],
14
+ )
15
+
16
+
17
+ class UpFirDn2dBackward(Function):
18
+ @staticmethod
19
+ def forward(
20
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
21
+ ):
22
+ up_x, up_y = up
23
+ down_x, down_y = down
24
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
25
+
26
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
27
+
28
+ grad_input = upfirdn2d_op.upfirdn2d(
29
+ grad_output,
30
+ grad_kernel,
31
+ down_x,
32
+ down_y,
33
+ up_x,
34
+ up_y,
35
+ g_pad_x0,
36
+ g_pad_x1,
37
+ g_pad_y0,
38
+ g_pad_y1,
39
+ )
40
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
41
+
42
+ ctx.save_for_backward(kernel)
43
+
44
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
45
+
46
+ ctx.up_x = up_x
47
+ ctx.up_y = up_y
48
+ ctx.down_x = down_x
49
+ ctx.down_y = down_y
50
+ ctx.pad_x0 = pad_x0
51
+ ctx.pad_x1 = pad_x1
52
+ ctx.pad_y0 = pad_y0
53
+ ctx.pad_y1 = pad_y1
54
+ ctx.in_size = in_size
55
+ ctx.out_size = out_size
56
+
57
+ return grad_input
58
+
59
+ @staticmethod
60
+ def backward(ctx, gradgrad_input):
61
+ kernel, = ctx.saved_tensors
62
+
63
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
64
+
65
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
66
+ gradgrad_input,
67
+ kernel,
68
+ ctx.up_x,
69
+ ctx.up_y,
70
+ ctx.down_x,
71
+ ctx.down_y,
72
+ ctx.pad_x0,
73
+ ctx.pad_x1,
74
+ ctx.pad_y0,
75
+ ctx.pad_y1,
76
+ )
77
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
78
+ gradgrad_out = gradgrad_out.view(
79
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
80
+ )
81
+
82
+ return gradgrad_out, None, None, None, None, None, None, None, None
83
+
84
+
85
+ class UpFirDn2d(Function):
86
+ @staticmethod
87
+ def forward(ctx, input, kernel, up, down, pad):
88
+ up_x, up_y = up
89
+ down_x, down_y = down
90
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
91
+
92
+ kernel_h, kernel_w = kernel.shape
93
+ batch, channel, in_h, in_w = input.shape
94
+ ctx.in_size = input.shape
95
+
96
+ input = input.reshape(-1, in_h, in_w, 1)
97
+
98
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
99
+
100
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
101
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
102
+ ctx.out_size = (out_h, out_w)
103
+
104
+ ctx.up = (up_x, up_y)
105
+ ctx.down = (down_x, down_y)
106
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
107
+
108
+ g_pad_x0 = kernel_w - pad_x0 - 1
109
+ g_pad_y0 = kernel_h - pad_y0 - 1
110
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
111
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
112
+
113
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
114
+
115
+ out = upfirdn2d_op.upfirdn2d(
116
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
117
+ )
118
+ # out = out.view(major, out_h, out_w, minor)
119
+ out = out.view(-1, channel, out_h, out_w)
120
+
121
+ return out
122
+
123
+ @staticmethod
124
+ def backward(ctx, grad_output):
125
+ kernel, grad_kernel = ctx.saved_tensors
126
+
127
+ grad_input = UpFirDn2dBackward.apply(
128
+ grad_output,
129
+ kernel,
130
+ grad_kernel,
131
+ ctx.up,
132
+ ctx.down,
133
+ ctx.pad,
134
+ ctx.g_pad,
135
+ ctx.in_size,
136
+ ctx.out_size,
137
+ )
138
+
139
+ return grad_input, None, None, None, None
140
+
141
+
142
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
143
+ out = UpFirDn2d.apply(
144
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
145
+ )
146
+
147
+ return out
148
+
149
+
150
+ def upfirdn2d_native(
151
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
152
+ ):
153
+ _, in_h, in_w, minor = input.shape
154
+ kernel_h, kernel_w = kernel.shape
155
+
156
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
157
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
158
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
159
+
160
+ out = F.pad(
161
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
162
+ )
163
+ out = out[
164
+ :,
165
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
166
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
167
+ :,
168
+ ]
169
+
170
+ out = out.permute(0, 3, 1, 2)
171
+ out = out.reshape(
172
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
173
+ )
174
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
175
+ out = F.conv2d(out, w)
176
+ out = out.reshape(
177
+ -1,
178
+ minor,
179
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
180
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
181
+ )
182
+ out = out.permute(0, 2, 3, 1)
183
+
184
+ return out[:, ::down_y, ::down_x, :]
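
The output size computed in UpFirDn2d.forward follows out = (in * up + pad0 + pad1 - kernel) // down + 1; a small worked check with the padding that the Upsample module in model.py picks for a 4-tap kernel and factor 2:

def upfirdn2d_out_size(in_size, kernel, up, down, pad0, pad1):
    return (in_size * up + pad0 + pad1 - kernel) // down + 1

kernel_size, factor = 4, 2
p = kernel_size - factor                          # 2
pad0, pad1 = (p + 1) // 2 + factor - 1, p // 2    # (2, 1)
print(upfirdn2d_out_size(16, kernel_size, up=factor, down=1, pad0=pad0, pad1=pad1))  # 32
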
models/stylegan2/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,272 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19
+ int c = a / b;
20
+
21
+ if (c * b > a) {
22
+ c--;
23
+ }
24
+
25
+ return c;
26
+ }
27
+
28
+
29
+ struct UpFirDn2DKernelParams {
30
+ int up_x;
31
+ int up_y;
32
+ int down_x;
33
+ int down_y;
34
+ int pad_x0;
35
+ int pad_x1;
36
+ int pad_y0;
37
+ int pad_y1;
38
+
39
+ int major_dim;
40
+ int in_h;
41
+ int in_w;
42
+ int minor_dim;
43
+ int kernel_h;
44
+ int kernel_w;
45
+ int out_h;
46
+ int out_w;
47
+ int loop_major;
48
+ int loop_x;
49
+ };
50
+
51
+
52
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53
+ __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56
+
57
+ __shared__ volatile float sk[kernel_h][kernel_w];
58
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
59
+
60
+ int minor_idx = blockIdx.x;
61
+ int tile_out_y = minor_idx / p.minor_dim;
62
+ minor_idx -= tile_out_y * p.minor_dim;
63
+ tile_out_y *= tile_out_h;
64
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65
+ int major_idx_base = blockIdx.z * p.loop_major;
66
+
67
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68
+ return;
69
+ }
70
+
71
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72
+ int ky = tap_idx / kernel_w;
73
+ int kx = tap_idx - ky * kernel_w;
74
+ scalar_t v = 0.0;
75
+
76
+ if (kx < p.kernel_w & ky < p.kernel_h) {
77
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78
+ }
79
+
80
+ sk[ky][kx] = v;
81
+ }
82
+
83
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87
+ int tile_in_x = floor_div(tile_mid_x, up_x);
88
+ int tile_in_y = floor_div(tile_mid_y, up_y);
89
+
90
+ __syncthreads();
91
+
92
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93
+ int rel_in_y = in_idx / tile_in_w;
94
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
95
+ int in_x = rel_in_x + tile_in_x;
96
+ int in_y = rel_in_y + tile_in_y;
97
+
98
+ scalar_t v = 0.0;
99
+
100
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102
+ }
103
+
104
+ sx[rel_in_y][rel_in_x] = v;
105
+ }
106
+
107
+ __syncthreads();
108
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109
+ int rel_out_y = out_idx / tile_out_w;
110
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
111
+ int out_x = rel_out_x + tile_out_x;
112
+ int out_y = rel_out_y + tile_out_y;
113
+
114
+ int mid_x = tile_mid_x + rel_out_x * down_x;
115
+ int mid_y = tile_mid_y + rel_out_y * down_y;
116
+ int in_x = floor_div(mid_x, up_x);
117
+ int in_y = floor_div(mid_y, up_y);
118
+ int rel_in_x = in_x - tile_in_x;
119
+ int rel_in_y = in_y - tile_in_y;
120
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122
+
123
+ scalar_t v = 0.0;
124
+
125
+ #pragma unroll
126
+ for (int y = 0; y < kernel_h / up_y; y++)
127
+ #pragma unroll
128
+ for (int x = 0; x < kernel_w / up_x; x++)
129
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130
+
131
+ if (out_x < p.out_w & out_y < p.out_h) {
132
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+
139
+
140
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141
+ int up_x, int up_y, int down_x, int down_y,
142
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143
+ int curDevice = -1;
144
+ cudaGetDevice(&curDevice);
145
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146
+
147
+ UpFirDn2DKernelParams p;
148
+
149
+ auto x = input.contiguous();
150
+ auto k = kernel.contiguous();
151
+
152
+ p.major_dim = x.size(0);
153
+ p.in_h = x.size(1);
154
+ p.in_w = x.size(2);
155
+ p.minor_dim = x.size(3);
156
+ p.kernel_h = k.size(0);
157
+ p.kernel_w = k.size(1);
158
+ p.up_x = up_x;
159
+ p.up_y = up_y;
160
+ p.down_x = down_x;
161
+ p.down_y = down_y;
162
+ p.pad_x0 = pad_x0;
163
+ p.pad_x1 = pad_x1;
164
+ p.pad_y0 = pad_y0;
165
+ p.pad_y1 = pad_y1;
166
+
167
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169
+
170
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171
+
172
+ int mode = -1;
173
+
174
+ int tile_out_h;
175
+ int tile_out_w;
176
+
177
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178
+ mode = 1;
179
+ tile_out_h = 16;
180
+ tile_out_w = 64;
181
+ }
182
+
183
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184
+ mode = 2;
185
+ tile_out_h = 16;
186
+ tile_out_w = 64;
187
+ }
188
+
189
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190
+ mode = 3;
191
+ tile_out_h = 16;
192
+ tile_out_w = 64;
193
+ }
194
+
195
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196
+ mode = 4;
197
+ tile_out_h = 16;
198
+ tile_out_w = 64;
199
+ }
200
+
201
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202
+ mode = 5;
203
+ tile_out_h = 8;
204
+ tile_out_w = 32;
205
+ }
206
+
207
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208
+ mode = 6;
209
+ tile_out_h = 8;
210
+ tile_out_w = 32;
211
+ }
212
+
213
+ dim3 block_size;
214
+ dim3 grid_size;
215
+
216
+ if (tile_out_h > 0 && tile_out_w) {
217
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
218
+ p.loop_x = 1;
219
+ block_size = dim3(32 * 8, 1, 1);
220
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222
+ (p.major_dim - 1) / p.loop_major + 1);
223
+ }
224
+
225
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226
+ switch (mode) {
227
+ case 1:
228
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230
+ );
231
+
232
+ break;
233
+
234
+ case 2:
235
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237
+ );
238
+
239
+ break;
240
+
241
+ case 3:
242
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244
+ );
245
+
246
+ break;
247
+
248
+ case 4:
249
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251
+ );
252
+
253
+ break;
254
+
255
+ case 5:
256
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258
+ );
259
+
260
+ break;
261
+
262
+ case 6:
263
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
264
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265
+ );
266
+
267
+ break;
268
+ }
269
+ });
270
+
271
+ return out;
272
+ }