Spanicin committed
Commit 88de352 · verified · 1 Parent(s): 06b6a8f

Update videoretalking/third_part/GPEN/face_model/op/upfirdn2d.py

videoretalking/third_part/GPEN/face_model/op/upfirdn2d.py CHANGED
@@ -1,194 +1,194 @@
- import os
- import platform
-
- import torch
- import torch.nn.functional as F
- from torch.autograd import Function
- from torch.utils.cpp_extension import load, _import_module_from_library
-
- # if running GPEN without cuda, please comment line 10-18
- if platform.system() == 'Linux' and torch.cuda.is_available():
-     module_path = os.path.dirname(__file__)
-     upfirdn2d_op = load(
-         'upfirdn2d',
-         sources=[
-             os.path.join(module_path, 'upfirdn2d.cpp'),
-             os.path.join(module_path, 'upfirdn2d_kernel.cu'),
-         ],
-     )
-
-
- #upfirdn2d_op = _import_module_from_library('upfirdn2d', '/tmp/torch_extensions/upfirdn2d', True)
-
- class UpFirDn2dBackward(Function):
-     @staticmethod
-     def forward(
-         ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
-     ):
-
-         up_x, up_y = up
-         down_x, down_y = down
-         g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
-
-         grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
-
-         grad_input = upfirdn2d_op.upfirdn2d(
-             grad_output,
-             grad_kernel,
-             down_x,
-             down_y,
-             up_x,
-             up_y,
-             g_pad_x0,
-             g_pad_x1,
-             g_pad_y0,
-             g_pad_y1,
-         )
-         grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
-
-         ctx.save_for_backward(kernel)
-
-         pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-         ctx.up_x = up_x
-         ctx.up_y = up_y
-         ctx.down_x = down_x
-         ctx.down_y = down_y
-         ctx.pad_x0 = pad_x0
-         ctx.pad_x1 = pad_x1
-         ctx.pad_y0 = pad_y0
-         ctx.pad_y1 = pad_y1
-         ctx.in_size = in_size
-         ctx.out_size = out_size
-
-         return grad_input
-
-     @staticmethod
-     def backward(ctx, gradgrad_input):
-         kernel, = ctx.saved_tensors
-
-         gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
-
-         gradgrad_out = upfirdn2d_op.upfirdn2d(
-             gradgrad_input,
-             kernel,
-             ctx.up_x,
-             ctx.up_y,
-             ctx.down_x,
-             ctx.down_y,
-             ctx.pad_x0,
-             ctx.pad_x1,
-             ctx.pad_y0,
-             ctx.pad_y1,
-         )
-         # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
-         gradgrad_out = gradgrad_out.view(
-             ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
-         )
-
-         return gradgrad_out, None, None, None, None, None, None, None, None
-
-
- class UpFirDn2d(Function):
-     @staticmethod
-     def forward(ctx, input, kernel, up, down, pad):
-         up_x, up_y = up
-         down_x, down_y = down
-         pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-         kernel_h, kernel_w = kernel.shape
-         batch, channel, in_h, in_w = input.shape
-         ctx.in_size = input.shape
-
-         input = input.reshape(-1, in_h, in_w, 1)
-
-         ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
-
-         out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-         out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-         ctx.out_size = (out_h, out_w)
-
-         ctx.up = (up_x, up_y)
-         ctx.down = (down_x, down_y)
-         ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
-
-         g_pad_x0 = kernel_w - pad_x0 - 1
-         g_pad_y0 = kernel_h - pad_y0 - 1
-         g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
-         g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
-
-         ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
-
-         out = upfirdn2d_op.upfirdn2d(
-             input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-         )
-         # out = out.view(major, out_h, out_w, minor)
-         out = out.view(-1, channel, out_h, out_w)
-
-         return out
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         kernel, grad_kernel = ctx.saved_tensors
-
-         grad_input = UpFirDn2dBackward.apply(
-             grad_output,
-             kernel,
-             grad_kernel,
-             ctx.up,
-             ctx.down,
-             ctx.pad,
-             ctx.g_pad,
-             ctx.in_size,
-             ctx.out_size,
-         )
-
-         return grad_input, None, None, None, None
-
-
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0), device='cpu'):
-     if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu':
-         out = UpFirDn2d.apply(
-             input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
-         )
-     else:
-         out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
-
-     return out
-
-
- def upfirdn2d_native(
-     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
- ):
-     input = input.permute(0, 2, 3, 1)
-     _, in_h, in_w, minor = input.shape
-     kernel_h, kernel_w = kernel.shape
-     out = input.view(-1, in_h, 1, in_w, 1, minor)
-     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-     out = F.pad(
-         out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
-     )
-     out = out[
-         :,
-         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-         max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-         :,
-     ]
-
-     out = out.permute(0, 3, 1, 2)
-     out = out.reshape(
-         [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
-     )
-     w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-     out = F.conv2d(out, w)
-     out = out.reshape(
-         -1,
-         minor,
-         in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-     )
-     # out = out.permute(0, 2, 3, 1)
-     return out[:, :, ::down_y, ::down_x]
-
+ import os
+ import platform
+
+ import torch
+ import torch.nn.functional as F
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load, _import_module_from_library
+
+ # # if running GPEN without cuda, please comment line 10-18
+ # if platform.system() == 'Linux' and torch.cuda.is_available():
+ #     module_path = os.path.dirname(__file__)
+ #     upfirdn2d_op = load(
+ #         'upfirdn2d',
+ #         sources=[
+ #             os.path.join(module_path, 'upfirdn2d.cpp'),
+ #             os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ #         ],
+ #     )
+
+
+ #upfirdn2d_op = _import_module_from_library('upfirdn2d', '/tmp/torch_extensions/upfirdn2d', True)
+
+ class UpFirDn2dBackward(Function):
+     @staticmethod
+     def forward(
+         ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+     ):
+
+         up_x, up_y = up
+         down_x, down_y = down
+         g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+         grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+         grad_input = upfirdn2d_op.upfirdn2d(
+             grad_output,
+             grad_kernel,
+             down_x,
+             down_y,
+             up_x,
+             up_y,
+             g_pad_x0,
+             g_pad_x1,
+             g_pad_y0,
+             g_pad_y1,
+         )
+         grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+         ctx.save_for_backward(kernel)
+
+         pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+         ctx.up_x = up_x
+         ctx.up_y = up_y
+         ctx.down_x = down_x
+         ctx.down_y = down_y
+         ctx.pad_x0 = pad_x0
+         ctx.pad_x1 = pad_x1
+         ctx.pad_y0 = pad_y0
+         ctx.pad_y1 = pad_y1
+         ctx.in_size = in_size
+         ctx.out_size = out_size
+
+         return grad_input
+
+     @staticmethod
+     def backward(ctx, gradgrad_input):
+         kernel, = ctx.saved_tensors
+
+         gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+         gradgrad_out = upfirdn2d_op.upfirdn2d(
+             gradgrad_input,
+             kernel,
+             ctx.up_x,
+             ctx.up_y,
+             ctx.down_x,
+             ctx.down_y,
+             ctx.pad_x0,
+             ctx.pad_x1,
+             ctx.pad_y0,
+             ctx.pad_y1,
+         )
+         # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
+         gradgrad_out = gradgrad_out.view(
+             ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
+         )
+
+         return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+ class UpFirDn2d(Function):
+     @staticmethod
+     def forward(ctx, input, kernel, up, down, pad):
+         up_x, up_y = up
+         down_x, down_y = down
+         pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+         kernel_h, kernel_w = kernel.shape
+         batch, channel, in_h, in_w = input.shape
+         ctx.in_size = input.shape
+
+         input = input.reshape(-1, in_h, in_w, 1)
+
+         ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+         out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+         out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+         ctx.out_size = (out_h, out_w)
+
+         ctx.up = (up_x, up_y)
+         ctx.down = (down_x, down_y)
+         ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+         g_pad_x0 = kernel_w - pad_x0 - 1
+         g_pad_y0 = kernel_h - pad_y0 - 1
+         g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+         g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+         ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+         out = upfirdn2d_op.upfirdn2d(
+             input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+         )
+         # out = out.view(major, out_h, out_w, minor)
+         out = out.view(-1, channel, out_h, out_w)
+
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         kernel, grad_kernel = ctx.saved_tensors
+
+         grad_input = UpFirDn2dBackward.apply(
+             grad_output,
+             kernel,
+             grad_kernel,
+             ctx.up,
+             ctx.down,
+             ctx.pad,
+             ctx.g_pad,
+             ctx.in_size,
+             ctx.out_size,
+         )
+
+         return grad_input, None, None, None, None
+
+
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0), device='cpu'):
+     if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu':
+         out = UpFirDn2d.apply(
+             input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+         )
+     else:
+         out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+
+     return out
+
+
+ def upfirdn2d_native(
+     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ ):
+     input = input.permute(0, 2, 3, 1)
+     _, in_h, in_w, minor = input.shape
+     kernel_h, kernel_w = kernel.shape
+     out = input.view(-1, in_h, 1, in_w, 1, minor)
+     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+     out = F.pad(
+         out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+     )
+     out = out[
+         :,
+         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+         max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+         :,
+     ]
+
+     out = out.permute(0, 3, 1, 2)
+     out = out.reshape(
+         [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+     )
+     w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+     out = F.conv2d(out, w)
+     out = out.reshape(
+         -1,
+         minor,
+         in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+     )
+     # out = out.permute(0, 2, 3, 1)
+     return out[:, :, ::down_y, ::down_x]
+
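
Below is a minimal sketch, not part of the commit, of exercising the pure-PyTorch fallback that this change leaves as the only import-time behaviour: with the load(...) block commented out, importing the module no longer attempts to JIT-compile upfirdn2d.cpp / upfirdn2d_kernel.cu, and calls with the default device='cpu' go through upfirdn2d_native(). The import path, input shape, and box-filter kernel are illustrative assumptions.

import torch
# hypothetical import; adjust to however GPEN/face_model/op is placed on sys.path
from upfirdn2d import upfirdn2d

x = torch.randn(1, 3, 8, 8)    # NCHW input (assumed shape)
k = torch.ones(3, 3) / 9.0     # simple 3x3 box filter used as the FIR kernel
# device defaults to 'cpu', so this takes the upfirdn2d_native branch
y = upfirdn2d(x, k, up=2, down=1, pad=(1, 1))
print(y.shape)                 # torch.Size([1, 3, 16, 16]): 2x upsample, pad 1, 3x3 filter

Note that the CUDA branch of upfirdn2d() (Linux, CUDA available, device != 'cpu') still references upfirdn2d_op, which is left undefined once the block above is commented out, so that branch should not be selected in this configuration.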