zejunyang committed
Commit 9667e74 · 1 Parent(s): 9464d6e
update
Browse files
- .gitattributes +1 -0
- NTED/NTED_module.py +101 -0
- NTED/base_function.py +434 -0
- NTED/base_module.py +115 -0
- NTED/config.py +202 -0
- NTED/demo_dataset.py +182 -0
- NTED/edge_attention_layer.py +116 -0
- NTED/extraction_distribution_model.py +62 -0
- NTED/fashion_512.yaml +129 -0
- NTED/nted_checkpoint.pt +3 -0
- NTED/op/__init__.py +2 -0
- NTED/op/conv2d_gradfix.py +227 -0
- NTED/op/fused_act.py +127 -0
- NTED/op/fused_bias_act.cpp +32 -0
- NTED/op/fused_bias_act_kernel.cu +105 -0
- NTED/op/upfirdn2d.cpp +31 -0
- NTED/op/upfirdn2d.py +209 -0
- NTED/op/upfirdn2d_kernel.cu +369 -0
- app.py +20 -8
- example/exp1.png +0 -0
- example/exp2.png +0 -0
- example/exp3.png +0 -0
- example/exp4.png +0 -0
- example/exp5.png +0 -0
- example/exp6.png +0 -0
- example/ref_img.png +3 -0
- lite_openpose/body_bbox_detector.py +179 -0
- lite_openpose/checkpoint_iter_370000.pth +3 -0
- lite_openpose/modules/__init__.py +0 -0
- lite_openpose/modules/conv.py +32 -0
- lite_openpose/modules/get_parameters.py +23 -0
- lite_openpose/modules/keypoints.py +201 -0
- lite_openpose/modules/load_state.py +32 -0
- lite_openpose/modules/loss.py +5 -0
- lite_openpose/modules/one_euro_filter.py +51 -0
- lite_openpose/modules/pose.py +118 -0
- lite_openpose/pose2d_models/__init__.py +0 -0
- lite_openpose/pose2d_models/with_mobilenet.py +123 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*/ref_img.png filter=lfs diff=lfs merge=lfs -text
NTED/NTED_module.py
ADDED
@@ -0,0 +1,101 @@
+import cv2
+import numpy as np
+import torch
+import random
+
+import mediapipe as mp
+from lite_openpose.body_bbox_detector import BodyPoseEstimator
+from NTED.extraction_distribution_model import Generator
+from NTED.demo_dataset import DemoDataset
+from NTED.base_function import accumulate
+from NTED.config import Config
+
+
+def set_random_seed(seed):
+    r"""Set random seeds for everything.
+
+    Args:
+        seed (int): Random seed.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+class NTED():
+    def __init__(self):
+        super(NTED, self).__init__()
+
+        self.openpose_module = BodyPoseEstimator('cpu')
+        set_random_seed(0)
+        self.opt = Config('NTED/fashion_512.yaml', is_train=False)
+
+        net_G = Generator(**self.opt.gen.param).to('cpu')
+        net_G_ema = Generator(**self.opt.gen.param).to('cpu')
+        net_G_ema.eval()
+        accumulate(net_G_ema, net_G, 0)
+
+        checkpoint = torch.load('NTED/nted_checkpoint.pt', map_location=lambda storage, loc: storage)
+        net_G_ema.load_state_dict(checkpoint['net_G_ema'])
+        self.net_G = net_G_ema.eval()
+
+        self.data_loader = DemoDataset()
+
+        mp_hands = mp.solutions.hands
+        self.hands = mp_hands.Hands(static_image_mode=True, max_num_hands=2, min_detection_confidence=0.1)
+
+        self.ref_img = cv2.imread('example/ref_img.png')
+        self.ref_img = cv2.resize(self.ref_img, (352, 512))
+
+    def hand_pose_est(self, img):
+        results = self.hands.process(cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_BGR2RGB))
+        image_height, image_width, _ = img.shape
+        pose_data = []
+
+        if results.multi_hand_landmarks is not None:
+            for hand_landmarks in results.multi_hand_landmarks:
+                for joint_idx in range(21):
+                    pose_data.append([image_width - hand_landmarks.landmark[joint_idx].x * image_width, hand_landmarks.landmark[joint_idx].y * image_height])
+            if len(results.multi_hand_landmarks) == 2:
+                if results.multi_handedness[0].classification[0].label == 'Right':
+                    # Swap so the left hand comes first, then the right hand.
+                    tmp = pose_data[:21].copy()
+                    pose_data[:21] = pose_data[21:]
+                    pose_data[21:] = tmp
+            elif len(results.multi_hand_landmarks) == 1:
+                miss_hand = [[-1, -1] for _ in range(21)]
+                if results.multi_handedness[0].classification[0].label == 'Left':
+                    pose_data += miss_hand
+                else:
+                    pose_data = miss_hand + pose_data
+        else:
+            for _ in range(42):
+                pose_data.append([-1, -1])
+        pose_data = np.array(pose_data, dtype=np.int32)
+
+        return pose_data
+
+    def inference(self, img):
+
+        img = cv2.resize(img, (352, 512))
+
+        body_pose, bbox = self.openpose_module.detect_body_pose(img.copy())
+
+        hand_pose = self.hand_pose_est(img.copy())
+
+        data = self.data_loader.load_item(self.ref_img, body_pose[0], hand_pose)
+
+        output = self.net_G(
+            data['reference_image'],
+            data['target_skeleton'],
+        )
+        fake_image = output['fake_image'][0]
+
+        fake_image = self.data_loader.tensor2im(fake_image)
+
+        fake_image = cv2.resize(fake_image, (288, 480))
+
+        return data['skeleton_img'], fake_image
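Note (not part of the commit): a minimal usage sketch of the NTED wrapper added above, assuming mediapipe is installed and the checkpoint plus the example images shipped in this commit are present; everything runs on CPU, matching the code.

    import cv2
    from NTED.NTED_module import NTED

    model = NTED()                          # loads nted_checkpoint.pt and example/ref_img.png
    frame = cv2.imread('example/exp1.png')  # any BGR frame with a visible person
    skeleton_img, fake_image = model.inference(frame)
    cv2.imwrite('skeleton.png', skeleton_img)
    # tensor2im returns an RGB array, so convert before saving with OpenCV.
    cv2.imwrite('result.png', cv2.cvtColor(fake_image, cv2.COLOR_RGB2BGR))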
NTED/base_function.py
ADDED
@@ -0,0 +1,434 @@
+import sys
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from NTED.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
+
+class ExtractionOperation(nn.Module):
+    def __init__(self, in_channel, num_label, match_kernel):
+        super(ExtractionOperation, self).__init__()
+        self.value_conv = EqualConv2d(in_channel, in_channel, match_kernel, 1, match_kernel//2, bias=True)
+        self.semantic_extraction_filter = EqualConv2d(in_channel, num_label, match_kernel, 1, match_kernel//2, bias=False)
+
+        self.softmax = nn.Softmax(dim=-1)
+        self.num_label = num_label
+
+    def forward(self, value, recoder):
+        key = value
+        b, c, h, w = value.shape
+        key = self.semantic_extraction_filter(self.feature_norm(key))
+        extraction_softmax = self.softmax(key.view(b, -1, h*w))  # bkm
+        values_flatten = self.value_conv(value).view(b, -1, h*w)
+        neural_textures = torch.einsum('bkm,bvm->bvk', extraction_softmax, values_flatten)
+        recoder['extraction_softmax'].insert(0, extraction_softmax)
+        recoder['neural_textures'].insert(0, neural_textures)
+        return neural_textures, extraction_softmax
+
+    def feature_norm(self, input_tensor):
+        input_tensor = input_tensor - input_tensor.mean(dim=1, keepdim=True)
+        norm = torch.norm(input_tensor, 2, 1, keepdim=True) + sys.float_info.epsilon
+        out = torch.div(input_tensor, norm)
+        return out
+
+class DistributionOperation(nn.Module):
+    def __init__(self, num_label, input_dim, match_kernel=3):
+        super(DistributionOperation, self).__init__()
+        self.semantic_distribution_filter = EqualConv2d(input_dim, num_label,
+                                                        kernel_size=match_kernel,
+                                                        stride=1,
+                                                        padding=match_kernel//2)
+        self.num_label = num_label
+
+    def forward(self, query, extracted_feature, recoder):
+        b, c, h, w = query.shape
+
+        query = self.semantic_distribution_filter(query)
+        query_flatten = query.view(b, self.num_label, -1)
+        query_softmax = F.softmax(query_flatten, 1)
+        values_q = torch.einsum('bkm,bkv->bvm', query_softmax, extracted_feature.permute(0, 2, 1))
+        attn_out = values_q.view(b, -1, h, w)
+        recoder['semantic_distribution'].append(query)
+        return attn_out
+
+class EncoderLayer(nn.Sequential):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        downsample=False,
+        blur_kernel=[1, 3, 3, 1],
+        bias=True,
+        activate=True,
+        use_extraction=False,
+        num_label=None,
+        match_kernel=None,
+        num_extractions=2
+    ):
+        super().__init__()
+
+        if downsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) + (kernel_size - 1)
+            pad0 = (p + 1) // 2
+            pad1 = p // 2
+            self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+            stride = 2
+            padding = 0
+
+        else:
+            self.blur = None
+            stride = 1
+            padding = kernel_size // 2
+
+        self.conv = EqualConv2d(
+            in_channel,
+            out_channel,
+            kernel_size,
+            padding=padding,
+            stride=stride,
+            bias=bias and not activate,
+        )
+
+        self.activate = FusedLeakyReLU(out_channel, bias=bias) if activate else None
+        self.use_extraction = use_extraction
+        if self.use_extraction:
+            self.extraction_operations = nn.ModuleList()
+            for _ in range(num_extractions):
+                self.extraction_operations.append(
+                    ExtractionOperation(
+                        out_channel,
+                        num_label,
+                        match_kernel
+                    )
+                )
+
+    def forward(self, input, recoder=None):
+        out = self.blur(input) if self.blur is not None else input
+        out = self.conv(out)
+        out = self.activate(out) if self.activate is not None else out
+        if self.use_extraction:
+            for extraction_operation in self.extraction_operations:
+                extraction_operation(out, recoder)
+        return out
+
+
+class DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        upsample=False,
+        blur_kernel=[1, 3, 3, 1],
+        bias=True,
+        activate=True,
+        use_distribution=True,
+        num_label=16,
+        match_kernel=3,
+    ):
+        super().__init__()
+        if upsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) - (kernel_size - 1)
+            pad0 = (p + 1) // 2 + factor - 1
+            pad1 = p // 2 + 1
+
+            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+            self.conv = EqualTransposeConv2d(
+                in_channel,
+                out_channel,
+                kernel_size,
+                stride=2,
+                padding=0,
+                bias=bias and not activate,
+            )
+        else:
+            self.conv = EqualConv2d(
+                in_channel,
+                out_channel,
+                kernel_size,
+                stride=1,
+                padding=kernel_size//2,
+                bias=bias and not activate,
+            )
+            self.blur = None
+
+        self.distribution_operation = DistributionOperation(
+            num_label,
+            out_channel,
+            match_kernel=match_kernel
+        ) if use_distribution else None
+        self.activate = FusedLeakyReLU(out_channel, bias=bias) if activate else None
+        self.use_distribution = use_distribution
+
+    def forward(self, input, neural_texture=None, recoder=None):
+        out = self.conv(input)
+        out = self.blur(out) if self.blur is not None else out
+        if self.use_distribution and neural_texture is not None:
+            out_attn = self.distribution_operation(out, neural_texture, recoder)
+            out = (out + out_attn) / math.sqrt(2)
+
+        out = self.activate(out.contiguous()) if self.activate is not None else out
+
+        return out
+
+class EqualConv2d(nn.Module):
+    def __init__(
+        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+    ):
+        super().__init__()
+
+        self.weight = nn.Parameter(
+            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+        )
+        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+        self.stride = stride
+        self.padding = padding
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channel))
+
+        else:
+            self.bias = None
+
+    def forward(self, input):
+        out = conv2d_gradfix.conv2d(
+            input,
+            self.weight * self.scale,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+        )
+
+        return out
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
+        )
+
+
+class EqualTransposeConv2d(nn.Module):
+    def __init__(
+        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+    ):
+        super().__init__()
+
+        self.weight = nn.Parameter(
+            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+        )
+        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+        self.stride = stride
+        self.padding = padding
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channel))
+
+        else:
+            self.bias = None
+
+    def forward(self, input):
+        weight = self.weight.transpose(0, 1)
+        out = conv2d_gradfix.conv_transpose2d(
+            input,
+            weight * self.scale,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+        )
+
+        return out
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
+        )
+
+class ToRGB(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        upsample=True,
+        blur_kernel=[1, 3, 3, 1]
+    ):
+        super().__init__()
+
+        if upsample:
+            self.upsample = Upsample(blur_kernel)
+        self.conv = EqualConv2d(in_channel, 3, 3, stride=1, padding=1)
+
+    def forward(self, input, skip=None):
+        out = self.conv(input)
+        if skip is not None:
+            skip = self.upsample(skip)
+            out = out + skip
+        return out
+
+
+class EqualLinear(nn.Module):
+    def __init__(
+        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+    ):
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+        else:
+            self.bias = None
+
+        self.activation = activation
+
+        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+        self.lr_mul = lr_mul
+
+    def forward(self, input):
+        if self.activation:
+            out = F.linear(input, self.weight * self.scale)
+            out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+        else:
+            out = F.linear(
+                input, self.weight * self.scale, bias=self.bias * self.lr_mul
+            )
+
+        return out
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
+        )
+
+class Upsample(nn.Module):
+    def __init__(self, kernel, factor=2):
+        super().__init__()
+
+        self.factor = factor
+        kernel = make_kernel(kernel) * (factor ** 2)
+        self.register_buffer("kernel", kernel)
+
+        p = kernel.shape[0] - factor
+
+        pad0 = (p + 1) // 2 + factor - 1
+        pad1 = p // 2
+
+        self.pad = (pad0, pad1)
+
+    def forward(self, input):
+        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+        return out
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        self.conv1 = ConvLayer(in_channel, in_channel, 3)
+        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+        self.skip = ConvLayer(
+            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+        )
+
+    def forward(self, input):
+        out = self.conv1(input)
+        out = self.conv2(out)
+
+        skip = self.skip(input)
+        out = (out + skip) / math.sqrt(2)
+
+        return out
+
+class ConvLayer(nn.Sequential):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        downsample=False,
+        blur_kernel=[1, 3, 3, 1],
+        bias=True,
+        activate=True,
+    ):
+        layers = []
+
+        if downsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) + (kernel_size - 1)
+            pad0 = (p + 1) // 2
+            pad1 = p // 2
+
+            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+            stride = 2
+            self.padding = 0
+
+        else:
+            stride = 1
+            self.padding = kernel_size // 2
+
+        layers.append(
+            EqualConv2d(
+                in_channel,
+                out_channel,
+                kernel_size,
+                padding=self.padding,
+                stride=stride,
+                bias=bias and not activate,
+            )
+        )
+
+        if activate:
+            layers.append(FusedLeakyReLU(out_channel, bias=bias))
+
+        super().__init__(*layers)
+
+
+class Blur(nn.Module):
+    def __init__(self, kernel, pad, upsample_factor=1):
+        super().__init__()
+
+        kernel = make_kernel(kernel)
+
+        if upsample_factor > 1:
+            kernel = kernel * (upsample_factor ** 2)
+
+        self.register_buffer("kernel", kernel)
+
+        self.pad = pad
+
+    def forward(self, input):
+        out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+        return out
+
+
+def make_kernel(k):
+    k = torch.tensor(k, dtype=torch.float32)
+
+    if k.ndim == 1:
+        k = k[None, :] * k[:, None]
+
+    k /= k.sum()
+
+    return k
+
+def accumulate(model1, model2, decay=0.999):
+    par1 = dict(model1.named_parameters())
+    par2 = dict(model2.named_parameters())
+
+    for k in par1.keys():
+        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
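Note (not part of the commit): `accumulate` is the exponential-moving-average update used for the `net_G_ema` copy in NTED_module.py. A minimal sketch of how it is typically called during training; the 0.999 decay is just the function's default.

    # After each optimizer step on net_G:
    accumulate(net_G_ema, net_G, decay=0.999)   # ema_param = 0.999 * ema_param + 0.001 * param
    # NTED_module.py instead calls accumulate(net_G_ema, net_G, 0) once, which copies the weights exactly.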
NTED/base_module.py
ADDED
@@ -0,0 +1,115 @@
+import math
+import functools
+import sys
+
+import torch
+import torch.nn as nn
+
+from NTED.base_function import EncoderLayer, DecoderLayer, ToRGB
+from NTED.edge_attention_layer import Edge_Attn
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        size,
+        input_dim,
+        channels,
+        num_labels=None,
+        match_kernels=None,
+        blur_kernel=[1, 3, 3, 1],
+    ):
+        super().__init__()
+        self.first = EncoderLayer(input_dim, channels[size], 1)
+        self.convs = nn.ModuleList()
+
+        log_size = int(math.log(size, 2))
+        self.log_size = log_size
+
+        in_channel = channels[size]
+        for i in range(log_size-1, 3, -1):
+            out_channel = channels[2 ** i]
+            num_label = num_labels[2 ** i] if num_labels is not None else None
+            match_kernel = match_kernels[2 ** i] if match_kernels is not None else None
+            use_extraction = num_label and match_kernel
+            conv = EncoderLayer(
+                in_channel,
+                out_channel,
+                kernel_size=3,
+                downsample=True,
+                blur_kernel=blur_kernel,
+                use_extraction=use_extraction,
+                num_label=num_label,
+                match_kernel=match_kernel
+            )
+
+            self.convs.append(conv)
+            in_channel = out_channel
+
+    def forward(self, input, recoder=None):
+        out = self.first(input)
+        for idx, layer in enumerate(self.convs):
+            out = layer(out, recoder)
+        return out
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        size,
+        channels,
+        num_labels,
+        match_kernels,
+        blur_kernel=[1, 3, 3, 1],
+    ):
+        super().__init__()
+
+        self.convs = nn.ModuleList()
+        # input at resolution 16*16
+        in_channel = channels[16]
+        self.log_size = int(math.log(size, 2))
+
+        for i in range(4, self.log_size + 1):
+            out_channel = channels[2 ** i]
+            num_label, match_kernel = num_labels[2 ** i], match_kernels[2 ** i]
+            use_distribution = num_label and match_kernel
+            upsample = (i != 4)
+
+            base_layer = functools.partial(
+                DecoderLayer,
+                out_channel=out_channel,
+                kernel_size=3,
+                blur_kernel=blur_kernel,
+                use_distribution=use_distribution,
+                num_label=num_label,
+                match_kernel=match_kernel
+            )
+
+            up = nn.Module()
+            up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
+            up.conv1 = base_layer(in_channel=out_channel, upsample=False)
+            up.to_rgb = ToRGB(out_channel, upsample=upsample)
+            self.convs.append(up)
+            in_channel = out_channel
+
+        self.num_labels, self.match_kernels = num_labels, match_kernels
+
+        self.edge_attn_block = Edge_Attn(in_channels=3)
+
+    def forward(self, input, neural_textures, recoder):
+        counter = 0
+        out, skip = input, None
+        for i, up in enumerate(self.convs):
+            if self.num_labels[2**(i+4)] and self.match_kernels[2**(i+4)]:
+                neural_texture_conv0 = neural_textures[counter]
+                neural_texture_conv1 = neural_textures[counter+1]
+                counter += 2
+            else:
+                neural_texture_conv0, neural_texture_conv1 = None, None
+            out = up.conv0(out, neural_texture=neural_texture_conv0, recoder=recoder)
+            out = up.conv1(out, neural_texture=neural_texture_conv1, recoder=recoder)
+
+            skip = up.to_rgb(out, skip)
+        image = self.edge_attn_block(skip)
+        # image = skip
+        return image
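Note (not part of the commit): a short sketch of the resolution-keyed dictionaries that drive Encoder/Decoder construction; the values below simply mirror NTED/fashion_512.yaml. Resolutions whose num_labels/match_kernels entries are False get no extraction/distribution operations.

    channels      = {16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}
    num_labels    = {16: 16, 32: 32, 64: 32, 128: 64, 256: 64, 512: False}
    match_kernels = {16: 1, 32: 3, 64: 3, 128: 3, 256: 3, 512: False}
    # Encoder downsamples 512 -> 16; Decoder walks back up 16 -> 512 with two DecoderLayers
    # plus one ToRGB per level, and applies Edge_Attn to the final RGB skip.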
NTED/config.py
ADDED
@@ -0,0 +1,202 @@
+import collections
+import functools
+import os
+import re
+
+import yaml
+
+class AttrDict(dict):
+    """Dict as attribute trick."""
+
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+        for key, value in self.__dict__.items():
+            if isinstance(value, dict):
+                self.__dict__[key] = AttrDict(value)
+            elif isinstance(value, (list, tuple)):
+                if isinstance(value[0], dict):
+                    self.__dict__[key] = [AttrDict(item) for item in value]
+                else:
+                    self.__dict__[key] = value
+
+    def yaml(self):
+        """Convert object to yaml dict and return."""
+        yaml_dict = {}
+        for key, value in self.__dict__.items():
+            if isinstance(value, AttrDict):
+                yaml_dict[key] = value.yaml()
+            elif isinstance(value, list):
+                if isinstance(value[0], AttrDict):
+                    new_l = []
+                    for item in value:
+                        new_l.append(item.yaml())
+                    yaml_dict[key] = new_l
+                else:
+                    yaml_dict[key] = value
+            else:
+                yaml_dict[key] = value
+        return yaml_dict
+
+    def __repr__(self):
+        """Print all variables."""
+        ret_str = []
+        for key, value in self.__dict__.items():
+            if isinstance(value, AttrDict):
+                ret_str.append('{}:'.format(key))
+                child_ret_str = value.__repr__().split('\n')
+                for item in child_ret_str:
+                    ret_str.append('    ' + item)
+            elif isinstance(value, list):
+                if isinstance(value[0], AttrDict):
+                    ret_str.append('{}:'.format(key))
+                    for item in value:
+                        # Treat as AttrDict above.
+                        child_ret_str = item.__repr__().split('\n')
+                        for item in child_ret_str:
+                            ret_str.append('    ' + item)
+                else:
+                    ret_str.append('{}: {}'.format(key, value))
+            else:
+                ret_str.append('{}: {}'.format(key, value))
+        return '\n'.join(ret_str)
+
+
+class Config(AttrDict):
+    r"""Configuration class. This should include every human specifiable
+    hyperparameter values for your training."""
+
+    def __init__(self, filename=None, verbose=False, is_train=True):
+        super(Config, self).__init__()
+        # Set default parameters.
+        # Logging.
+
+        large_number = 1000000000
+        self.snapshot_save_iter = large_number
+        self.snapshot_save_epoch = large_number
+        self.snapshot_save_start_iter = 0
+        self.snapshot_save_start_epoch = 0
+        self.image_save_iter = large_number
+        self.eval_epoch = large_number
+        self.start_eval_epoch = large_number
+        self.eval_epoch = large_number
+        self.max_epoch = large_number
+        self.max_iter = large_number
+        self.logging_iter = 100
+        self.image_to_tensorboard = False
+        self.which_iter = None
+        self.resume = True
+
+        self.checkpoints_dir = 'NTED'
+        self.name = 'nted_checkpoint.pt'
+        self.phase = 'train' if is_train else 'test'
+
+        # Networks.
+        self.gen = AttrDict(type='generators.dummy')
+        self.dis = AttrDict(type='discriminators.dummy')
+
+        # Optimizers.
+        self.gen_optimizer = AttrDict(type='adam',
+                                      lr=0.0001,
+                                      adam_beta1=0.0,
+                                      adam_beta2=0.999,
+                                      eps=1e-8,
+                                      lr_policy=AttrDict(iteration_mode=False,
+                                                         type='step',
+                                                         step_size=large_number,
+                                                         gamma=1))
+        self.dis_optimizer = AttrDict(type='adam',
+                                      lr=0.0001,
+                                      adam_beta1=0.0,
+                                      adam_beta2=0.999,
+                                      eps=1e-8,
+                                      lr_policy=AttrDict(iteration_mode=False,
+                                                         type='step',
+                                                         step_size=large_number,
+                                                         gamma=1))
+        # Data.
+        self.data = AttrDict(name='dummy',
+                             type='datasets.images',
+                             num_workers=0)
+        self.test_data = AttrDict(name='dummy',
+                                  type='datasets.images',
+                                  num_workers=0,
+                                  test=AttrDict(is_lmdb=False,
+                                                roots='',
+                                                batch_size=1))
+        self.trainer = AttrDict(
+            image_to_tensorboard=False,
+            hparam_to_tensorboard=False)
+
+        # Cudnn.
+        self.cudnn = AttrDict(deterministic=False,
+                              benchmark=True)
+
+        # Others.
+        self.pretrained_weight = ''
+        self.inference_args = AttrDict()
+
+        # Update with given configurations.
+        assert os.path.exists(filename), 'File {} not exist.'.format(filename)
+        loader = yaml.SafeLoader
+        loader.add_implicit_resolver(
+            u'tag:yaml.org,2002:float',
+            re.compile(u'''^(?:
+             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+            |[-+]?\\.(?:inf|Inf|INF)
+            |\\.(?:nan|NaN|NAN))$''', re.X),
+            list(u'-+0123456789.'))
+        try:
+            with open(filename, 'r') as f:
+                cfg_dict = yaml.load(f, Loader=loader)
+        except EnvironmentError:
+            print('Please check the file with name of "%s"', filename)
+        recursive_update(self, cfg_dict)
+
+        # Put common opts in both gen and dis.
+        if 'common' in cfg_dict:
+            self.common = AttrDict(**cfg_dict['common'])
+            self.gen.common = self.common
+            self.dis.common = self.common
+
+        if verbose:
+            print(' config '.center(80, '-'))
+            print(self.__repr__())
+            print(''.center(80, '-'))
+
+
+def rsetattr(obj, attr, val):
+    """Recursively find object and set value"""
+    pre, _, post = attr.rpartition('.')
+    return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+    """Recursively find object and return value"""
+
+    def _getattr(obj, attr):
+        r"""Get attribute."""
+        return getattr(obj, attr, *args)
+
+    return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def recursive_update(d, u):
+    """Recursively update AttrDict d with AttrDict u"""
+    for key, value in u.items():
+        if isinstance(value, collections.abc.Mapping):
+            d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
+        elif isinstance(value, (list, tuple)):
+            if isinstance(value[0], dict):
+                d.__dict__[key] = [AttrDict(item) for item in value]
+            else:
+                d.__dict__[key] = value
+        else:
+            d.__dict__[key] = value
+    return d
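Note (not part of the commit): Config merges the defaults above with the YAML file and exposes everything as attributes; a minimal sketch matching how NTED_module.py uses it.

    from NTED.config import Config

    opt = Config('NTED/fashion_512.yaml', is_train=False, verbose=True)
    print(opt.phase)           # 'test'
    print(opt.gen.type)        # 'NTED.extraction_distribution_model::Generator'
    print(opt.gen.param.size)  # 512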
NTED/demo_dataset.py
ADDED
@@ -0,0 +1,182 @@
+
+import os
+import cv2
+import math
+import numpy as np
+from PIL import Image
+
+import torch
+import torchvision.transforms.functional as F
+
+class DemoDataset(object):
+    def __init__(self):
+        super().__init__()
+        self.LIMBSEQ = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
+                        [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
+                        [1, 16], [16, 18], [3, 17], [6, 18]]
+
+        self.COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
+                       [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
+                       [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+        self.LIMBSEQ_hands = [[0, 1], [1, 2], [2, 3], [3, 4],
+                              [0, 5], [5, 6], [6, 7], [7, 8],
+                              [0, 9], [9, 10], [10, 11], [11, 12],
+                              [0, 13], [13, 14], [14, 15], [15, 16],
+                              [0, 17], [17, 18], [18, 19], [19, 20],
+                              [21, 22], [22, 23], [23, 24], [24, 25],
+                              [21, 26], [26, 27], [27, 28], [28, 29],
+                              [21, 30], [30, 31], [31, 32], [32, 33],
+                              [21, 34], [34, 35], [35, 36], [36, 37],
+                              [21, 38], [38, 39], [39, 40], [40, 41]]
+
+        self.COLORS_hands = [[85, 0, 0], [170, 0, 0], [85, 85, 0], [85, 170, 0], [170, 85, 0], [170, 170, 0], [85, 85, 85],
+                             [85, 85, 170], [85, 170, 85], [85, 170, 170], [0, 85, 0], [0, 170, 0], [0, 85, 85], [0, 85, 170],
+                             [0, 170, 85], [0, 170, 170], [50, 0, 0], [135, 0, 0], [50, 50, 0], [50, 135, 0], [135, 50, 0],
+                             [135, 135, 0], [50, 50, 50], [50, 50, 135], [50, 135, 50], [50, 135, 135], [0, 50, 0], [0, 135, 0],
+                             [0, 50, 50], [0, 50, 135], [0, 135, 50], [0, 135, 135], [100, 0, 0], [200, 0, 0], [100, 100, 0],
+                             [100, 200, 0], [200, 100, 0], [200, 200, 0], [100, 100, 100], [100, 100, 200], [100, 200, 100], [100, 200, 200]
+                             ]
+
+        self.img_size = tuple([512, 352])
+
+    def load_item(self, img, pose, handpose=None):
+
+        reference_img = self.get_image_tensor(img)[None, :]
+        label, ske = self.get_label_tensor(pose, handpose)
+        label = label[None, :]
+
+        return {'reference_image': reference_img, 'target_skeleton': label, 'skeleton_img': ske}
+
+    def get_image_tensor(self, bgr_img):
+        img = Image.fromarray(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
+        img = F.resize(img, self.img_size)
+        img = F.to_tensor(img)
+        img = F.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        return img
+
+    def get_label_tensor(self, pose, hand_pose=None):
+        canvas = np.zeros((self.img_size[0], self.img_size[1], 3)).astype(np.uint8)
+        keypoint = np.array(pose)
+        if hand_pose is not None:
+            keypoint_hands = np.array(hand_pose)
+        else:
+            keypoint_hands = None
+
+        # keypoint = self.trans_keypoins(keypoint)
+
+        stickwidth = 4
+        for i in range(18):
+            x, y = keypoint[i, 0:2]
+            if x == -1 or y == -1:
+                continue
+            cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS[i], thickness=-1)
+        if keypoint_hands is not None:
+            for i in range(42):
+                x, y = keypoint_hands[i, 0:2]
+                if x == -1 or y == -1:
+                    continue
+                cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS_hands[i], thickness=-1)
+
+        joints = []
+        for i in range(17):
+            Y = keypoint[np.array(self.LIMBSEQ[i])-1, 0]
+            X = keypoint[np.array(self.LIMBSEQ[i])-1, 1]
+            cur_canvas = canvas.copy()
+            if -1 in Y or -1 in X:
+                joints.append(np.zeros_like(cur_canvas[:, :, 0]))
+                continue
+            mX = np.mean(X)
+            mY = np.mean(Y)
+            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS[i])
+            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+
+            joint = np.zeros_like(cur_canvas[:, :, 0])
+            cv2.fillConvexPoly(joint, polygon, 255)
+            joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
+            joints.append(joint)
+        if keypoint_hands is not None:
+            for i in range(40):
+                Y = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 0]
+                X = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 1]
+                cur_canvas = canvas.copy()
+                if -1 in Y or -1 in X:
+                    if (i+1) % 4 == 0:
+                        joints.append(np.zeros_like(cur_canvas[:, :, 0]))
+                    continue
+                mX = np.mean(X)
+                mY = np.mean(Y)
+                length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+                angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+                polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), int(stickwidth/2)), int(angle), 0, 360, 1)
+                cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS_hands[i])
+                canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+
+                # One distance-map channel per finger.
+                if i % 4 == 0:
+                    joint = np.zeros_like(cur_canvas[:, :, 0])
+                cv2.fillConvexPoly(joint, polygon, 255)
+                joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
+                if (i+1) % 4 == 0:
+                    joints.append(joint)
+
+        pose = F.to_tensor(Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
+
+        tensors_dist = 0
+        e = 1
+        for i in range(len(joints)):
+            im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
+            im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
+            tensor_dist = F.to_tensor(Image.fromarray(im_dist))
+            tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
+            e += 1
+
+        label_tensor = torch.cat((pose, tensors_dist), dim=0)
+
+        return label_tensor, canvas
+
+    def tensor2im(self, image_tensor, imtype=np.uint8, normalize=True,
+                  three_channel_output=True):
+        r"""Convert tensor to image.
+
+        Args:
+            image_tensor (torch.tensor or list of torch.tensor): If tensor then
+                (NxCxHxW) or (NxTxCxHxW) or (CxHxW).
+            imtype (np.dtype): Type of output image.
+            normalize (bool): Is the input image normalized or not?
+            three_channel_output (bool): Should single channel images be made 3
+                channel in output?
+
+        Returns:
+            (numpy.ndarray, list if case 1, 2 above).
+        """
+        if image_tensor is None:
+            return None
+        if isinstance(image_tensor, list):
+            return [self.tensor2im(x, imtype, normalize) for x in image_tensor]
+        if image_tensor.dim() == 5 or image_tensor.dim() == 4:
+            return [self.tensor2im(image_tensor[idx], imtype, normalize)
+                    for idx in range(image_tensor.size(0))]
+
+        if image_tensor.dim() == 3:
+            image_numpy = image_tensor.detach().float().numpy()
+            if normalize:
+                image_numpy = (np.transpose(
+                    image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+            else:
+                image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+            image_numpy = np.clip(image_numpy, 0, 255)
+            if image_numpy.shape[2] == 1 and three_channel_output:
+                image_numpy = np.repeat(image_numpy, 3, axis=2)
+            elif image_numpy.shape[2] > 3:
+                image_numpy = image_numpy[:, :, :3]
+        return image_numpy.astype(imtype)
+
+    def trans_keypoins(self, keypoints):
+        missing_keypoint_index = keypoints == -1
+
+        keypoints[missing_keypoint_index] = -1
+        return keypoints
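Note (not part of the commit): `get_label_tensor` stacks the 3-channel skeleton drawing with one distance-transform map per body limb (17) and per finger (10), giving the 30-channel `target_skeleton` that matches `semantic_dim: 30` in fashion_512.yaml. A minimal sketch, assuming 18 body and 42 hand keypoints given as (x, y) with -1 marking missing points:

    import numpy as np
    from NTED.demo_dataset import DemoDataset

    ds = DemoDataset()
    body = np.full((18, 2), -1)   # 18 OpenPose-style body keypoints
    hands = np.full((42, 2), -1)  # 21 left-hand + 21 right-hand keypoints
    label, canvas = ds.get_label_tensor(body, hands)
    print(label.shape)            # torch.Size([30, 512, 352])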
NTED/edge_attention_layer.py
ADDED
@@ -0,0 +1,116 @@
+# Date: 2023-03-14
+# Creator: zejunyang
+# Function: edge attention layer.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from NTED.base_function import Blur
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_nc, out_nc, scale='down'):  # , norm_layer=nn.BatchNorm2d
+        super(ResBlock, self).__init__()
+        use_bias = True
+        assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"
+
+        if scale == 'same':
+            # self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
+            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=True)
+        if scale == 'up':
+            self.scale = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear'),
+                nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
+            )
+        if scale == 'down':
+            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)
+
+        self.block = nn.Sequential(
+            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
+            # norm_layer(out_nc),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
+            # norm_layer(out_nc)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        # self.padding = nn.ReplicationPad2d(padding=(0, 1, 0, 0))
+
+    def forward(self, x):
+        residual = self.scale(x)
+        return self.relu(residual + self.block(residual))
+
+
+class Edge_Attn(nn.Module):
+    def __init__(self, in_channels=3):
+        super(Edge_Attn, self).__init__()
+        self.in_channels = in_channels
+
+        blur_kernel = [1, 3, 3, 3, 1]
+        self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)
+
+        # self.conv = nn.Conv2d(self.in_channels, self.in_channels, 3, padding=1, bias=False)
+        self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
+        self.sigmoid = nn.Sigmoid()
+
+    def gradient(self, x):
+        h_x = x.size()[2]
+        w_x = x.size()[3]
+        stride = 3
+        r = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]
+        l = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]
+        t = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]
+        b = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]
+        xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
+        xgrad = self.blur(xgrad)
+        return xgrad
+
+    def forward(self, x):
+        # feature_edge = self.gradient(x).detach()
+        # attn = self.conv(feature_edge)
+
+        for b in range(x.shape[0]):
+            for c in range(x.shape[1]):
+                if c == 0:
+                    channel_edge = self.gradient(x[b:b+1, c:c+1])
+                else:
+                    channel_edge = torch.concat([channel_edge, self.gradient(x[b:b+1, c:c+1])], dim=1)
+            if b == 0:
+                feature_edge = channel_edge
+            else:
+                feature_edge = torch.concat([feature_edge, channel_edge], dim=0)
+        feature_edge = feature_edge.detach()
+        feature_edge = x * feature_edge
+        attn = self.res_block(feature_edge)
+        attn = self.sigmoid(attn)
+
+        # out = x * attn
+
+        out = x * attn + x
+
+        return out
+
+
+if __name__ == '__main__':
+    from PIL import Image
+    import numpy as np
+    import cv2
+
+    edg_atten = Edge_Attn()
+
+    im = Image.open('/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png')
+    npim = np.array(im, dtype=np.float32)
+    npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)
+
+    # npim = npim[:, :, 2]
+    tim = torch.from_numpy(npim).unsqueeze_(0).unsqueeze_(0)
+    edge = edg_atten.gradient(tim)
+    npgrad = edge.squeeze(0).squeeze(0).data.clamp(0, 255).numpy()
+    Image.fromarray(npgrad.astype('uint8')).save('tmp.png')
+
+    # tim = torch.from_numpy(npim).unsqueeze_(0)
+    # edge = edg_atten.gradient_1order(tim)
+    # npgrad = edge.squeeze(0).data.clamp(0,255).numpy()[:, :, 0]
+    # Image.fromarray(npgrad.astype('uint8')).save('tmp.png')
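Note (not part of the commit): a minimal sketch of the Edge_Attn block on a random image-sized tensor, assuming the NTED.op filters fall back to their native PyTorch implementations when no CUDA extension is available. The block computes a blurred gradient magnitude per channel, turns it into a sigmoid attention map via a ResBlock, and returns x * attn + x, so the output keeps the input shape.

    import torch
    from NTED.edge_attention_layer import Edge_Attn

    block = Edge_Attn(in_channels=3)
    x = torch.randn(1, 3, 512, 352)
    out = block(x)
    print(out.shape)   # torch.Size([1, 3, 512, 352])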
NTED/extraction_distribution_model.py
ADDED
@@ -0,0 +1,62 @@
+import collections
+from torch import nn
+from NTED.base_module import Encoder, Decoder
+
+from torch.cuda.amp import autocast as autocast
+
+class Generator(nn.Module):
+    def __init__(
+        self,
+        size,
+        semantic_dim,
+        channels,
+        num_labels,
+        match_kernels,
+        blur_kernel=[1, 3, 3, 1],
+    ):
+        super().__init__()
+        self.size = size
+        self.reference_encoder = Encoder(
+            size, 3, channels, num_labels, match_kernels, blur_kernel
+        )
+
+        self.skeleton_encoder = Encoder(
+            size, semantic_dim, channels,
+        )
+
+        self.target_image_renderer = Decoder(
+            size, channels, num_labels, match_kernels, blur_kernel
+        )
+
+    def _cal_temp(self, module):
+        return sum(p.numel() for p in module.parameters() if p.requires_grad)
+
+    def forward(
+        self,
+        source_image,
+        skeleton,
+        amp_flag=False,
+    ):
+        if amp_flag:
+            with autocast():
+                output_dict = {}
+                recoder = collections.defaultdict(list)
+                skeleton_feature = self.skeleton_encoder(skeleton)
+                _ = self.reference_encoder(source_image, recoder)
+                neural_textures = recoder["neural_textures"]
+                output_dict['fake_image'] = self.target_image_renderer(
+                    skeleton_feature, neural_textures, recoder
+                )
+                output_dict['info'] = recoder
+                return output_dict
+        else:
+            output_dict = {}
+            recoder = collections.defaultdict(list)
+            skeleton_feature = self.skeleton_encoder(skeleton)
+            _ = self.reference_encoder(source_image, recoder)
+            neural_textures = recoder["neural_textures"]
+            output_dict['fake_image'] = self.target_image_renderer(
+                skeleton_feature, neural_textures, recoder
+            )
+            output_dict['info'] = recoder
+            return output_dict
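Note (not part of the commit): a sketch of building the Generator directly from the parameters in NTED/fashion_512.yaml and running one forward pass, again assuming the NTED.op layers have a CPU fallback; the 30-channel skeleton tensor is what DemoDataset.get_label_tensor produces.

    import torch
    from NTED.config import Config
    from NTED.extraction_distribution_model import Generator

    opt = Config('NTED/fashion_512.yaml', is_train=False)
    net_G = Generator(**opt.gen.param).eval()

    ref = torch.randn(1, 3, 512, 352)        # reference image, normalized to [-1, 1]
    skeleton = torch.randn(1, 30, 512, 352)  # target skeleton (semantic_dim = 30)
    with torch.no_grad():
        out = net_G(ref, skeleton)
    print(out['fake_image'].shape)           # expected: torch.Size([1, 3, 512, 352])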
NTED/fashion_512.yaml
ADDED
@@ -0,0 +1,129 @@
+distributed: True
+image_to_tensorboard: True
+snapshot_save_iter: 50000
+snapshot_save_epoch: 20
+snapshot_save_start_iter: 20000
+snapshot_save_start_epoch: 100
+image_save_iter: 1000
+max_epoch: 400
+logging_iter: 100
+amp: False
+
+gen_optimizer:
+  type: adam
+  lr: 0.002
+  adam_beta1: 0.
+  adam_beta2: 0.99
+  lr_policy:
+    iteration_mode: False
+    type: step
+    step_size: 1000000
+    gamma: 1
+
+dis_optimizer:
+  type: adam
+  lr: 0.001882
+  adam_beta1: 0.
+  adam_beta2: 0.9905
+  lr_policy:
+    iteration_mode: False
+    type: step
+    step_size: 1000000
+    gamma: 1
+
+
+trainer:
+  type: NTED.extraction_distribution_trainer::Trainer
+  gan_mode: style_gan2
+  gan_start_iteration: 1000  # 0
+  face_crop_method: util.face_crop::crop_face_from_output
+  hand_crop_method: util.face_crop::crop_hands_from_output
+  d_reg_every: 16
+  r1: 10
+  loss_weight:
+    weight_perceptual: 1
+    weight_gan: 1.5
+    weight_attn_rec: 15
+    weight_face: 1
+    weight_hand: 1
+    weight_l1: 1
+    weight_l1_hand: 0.8
+    weight_edge: 100
+  attn_weights:
+    8: 1
+    16: 1
+    32: 1
+    64: 1
+    128: 1
+    256: 1
+  vgg_param:
+    network: vgg19
+    layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+    num_scales: 3
+    use_style_loss: True
+    style_to_perceptual: 1000
+  vgg_hand_param:
+    network: vgg19
+    layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_3_3', 'relu_4_1', 'relu_4_3', 'relu_5_1']
+
+gen:
+  type: NTED.extraction_distribution_model::Generator
+  param:
+    size: 512
+    semantic_dim: 30
+    channels:
+      16: 512
+      32: 512
+      64: 512
+      128: 256
+      256: 128
+      512: 64
+      1024: 32
+    num_labels:
+      16: 16
+      32: 32
+      64: 32
+      128: 64
+      256: 64
+      512: False
+    match_kernels:
+      16: 1
+      32: 3
+      64: 3
+      128: 3
+      256: 3
+      512: False
+
+dis:
+  type: generators.discriminator::Discriminator
+  param:
+    size: 512
+    channels:
+      4: 512
+      8: 512
+      16: 512
+      32: 512
+      64: 512
+      128: 256
+      256: 128
+      512: 64
+    is_square_image: False
+
+
+data:
+  type: data.fashion_data::Dataset
+  preprocess_mode: resize_and_crop  # resize_and_crop
+  path: /apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset_2d
+  num_workers: 16
+  sub_path: 512-352
+  resolution: 512
+  scale_param: 0.1
+  train:
+    batch_size: 4  # real_batch_size: 2 * 2 (source-->target & target --> source) * 4 (GPUs) = 16
+    distributed: True
+  val:
+    batch_size: 4
+    distributed: True
+  hand_keypoint: True
NTED/nted_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:359d3d3bac365afe04aa8b906f1dc8891f0dd87ff1dfe5e60059b4fb9bb96af8
+size 284375285
NTED/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
NTED/op/conv2d_gradfix.py
ADDED
@@ -0,0 +1,227 @@
+import contextlib
+import warnings
+
+import torch
+from torch import autograd
+from torch.nn import functional as F
+
+enabled = True
+weight_gradients_disabled = False
+
+
+@contextlib.contextmanager
+def no_weight_gradients():
+    global weight_gradients_disabled
+
+    old = weight_gradients_disabled
+    weight_gradients_disabled = True
+    yield
+    weight_gradients_disabled = old
+
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+    if could_use_op(input):
+        return conv2d_gradfix(
+            transpose=False,
+            weight_shape=weight.shape,
+            stride=stride,
+            padding=padding,
+            output_padding=0,
+            dilation=dilation,
+            groups=groups,
+        ).apply(input, weight, bias)
+
+    return F.conv2d(
+        input=input,
+        weight=weight,
+        bias=bias,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+    )
+
+
+def conv_transpose2d(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    output_padding=0,
+    groups=1,
+    dilation=1,
+):
+    if could_use_op(input):
+        return conv2d_gradfix(
+            transpose=True,
+            weight_shape=weight.shape,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ).apply(input, weight, bias)
+
+    return F.conv_transpose2d(
+        input=input,
+        weight=weight,
+        bias=bias,
+        stride=stride,
+        padding=padding,
+        output_padding=output_padding,
+        dilation=dilation,
+        groups=groups,
+    )
+
+
+def could_use_op(input):
+    if (not enabled) or (not torch.backends.cudnn.enabled):
+        return False
+
+    if input.device.type != "cuda":
+        return False
+
+    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
+        return True
+
+    warnings.warn(
+        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
+    )
+
+    return False
+
+
+def ensure_tuple(xs, ndim):
+    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+
+    return xs
+
+
+conv2d_gradfix_cache = dict()
+
+
+def conv2d_gradfix(
+    transpose, weight_shape, stride, padding, output_padding, dilation, groups
+):
+    ndim = 2
+    weight_shape = tuple(weight_shape)
+    stride = ensure_tuple(stride, ndim)
+    padding = ensure_tuple(padding, ndim)
+    output_padding = ensure_tuple(output_padding, ndim)
+    dilation = ensure_tuple(dilation, ndim)
+
+    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+    if key in conv2d_gradfix_cache:
+        return conv2d_gradfix_cache[key]
+
+    common_kwargs = dict(
+        stride=stride, padding=padding, dilation=dilation, groups=groups
+    )
+
+    def calc_output_padding(input_shape, output_shape):
+        if transpose:
+            return [0, 0]
+
+        return [
+            input_shape[i + 2]
+            - (output_shape[i + 2] - 1) * stride[i]
+            - (1 - 2 * padding[i])
+            - dilation[i] * (weight_shape[i + 2] - 1)
+            for i in range(ndim)
+        ]
+
+    class Conv2d(autograd.Function):
+        @staticmethod
+        def forward(ctx, input, weight, bias):
+            if not transpose:
+                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+            else:
+                out = F.conv_transpose2d(
+                    input=input,
+                    weight=weight,
+                    bias=bias,
+                    output_padding=output_padding,
+                    **common_kwargs,
+                )
+
+            ctx.save_for_backward(input, weight)
+
+            return out
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            input, weight = ctx.saved_tensors
+            grad_input, grad_weight, grad_bias = None, None, None
+
+            if ctx.needs_input_grad[0]:
+                p = calc_output_padding(
+                    input_shape=input.shape, output_shape=grad_output.shape
+                )
+                grad_input = conv2d_gradfix(
+                    transpose=(not transpose),
+                    weight_shape=weight_shape,
+                    output_padding=p,
+                    **common_kwargs,
+                ).apply(grad_output, weight, None)
+
+            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+                grad_weight = Conv2dGradWeight.apply(grad_output, input)
+
+            if ctx.needs_input_grad[2]:
+                grad_bias = grad_output.sum((0, 2, 3))
+
+            return grad_input, grad_weight, grad_bias
+
+    class Conv2dGradWeight(autograd.Function):
+        @staticmethod
+        def forward(ctx, grad_output, input):
+            op = torch._C._jit_get_operation(
+                "aten::cudnn_convolution_backward_weight"
+                if not transpose
+                else "aten::cudnn_convolution_transpose_backward_weight"
|
184 |
+
)
|
185 |
+
flags = [
|
186 |
+
torch.backends.cudnn.benchmark,
|
187 |
+
torch.backends.cudnn.deterministic,
|
188 |
+
torch.backends.cudnn.allow_tf32,
|
189 |
+
]
|
190 |
+
grad_weight = op(
|
191 |
+
weight_shape,
|
192 |
+
grad_output,
|
193 |
+
input,
|
194 |
+
padding,
|
195 |
+
stride,
|
196 |
+
dilation,
|
197 |
+
groups,
|
198 |
+
*flags,
|
199 |
+
)
|
200 |
+
ctx.save_for_backward(grad_output, input)
|
201 |
+
|
202 |
+
return grad_weight
|
203 |
+
|
204 |
+
@staticmethod
|
205 |
+
def backward(ctx, grad_grad_weight):
|
206 |
+
grad_output, input = ctx.saved_tensors
|
207 |
+
grad_grad_output, grad_grad_input = None, None
|
208 |
+
|
209 |
+
if ctx.needs_input_grad[0]:
|
210 |
+
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
|
211 |
+
|
212 |
+
if ctx.needs_input_grad[1]:
|
213 |
+
p = calc_output_padding(
|
214 |
+
input_shape=input.shape, output_shape=grad_output.shape
|
215 |
+
)
|
216 |
+
grad_grad_input = conv2d_gradfix(
|
217 |
+
transpose=(not transpose),
|
218 |
+
weight_shape=weight_shape,
|
219 |
+
output_padding=p,
|
220 |
+
**common_kwargs,
|
221 |
+
).apply(grad_output, grad_grad_weight, None)
|
222 |
+
|
223 |
+
return grad_grad_output, grad_grad_input
|
224 |
+
|
225 |
+
conv2d_gradfix_cache[key] = Conv2d
|
226 |
+
|
227 |
+
return Conv2d
|
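Usage note (illustrative, not part of this commit): conv2d_gradfix provides drop-in replacements for F.conv2d / F.conv_transpose2d whose backward pass supports double differentiation and can skip weight gradients, the pattern StyleGAN2-style regularizers rely on; on CPU or unsupported PyTorch versions it falls back to the stock ops with a warning. A minimal sketch, assuming the `NTED.op` package imports cleanly (its __init__ also builds the CUDA extensions):

import torch
from NTED.op import conv2d_gradfix

x = torch.randn(1, 3, 64, 64, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)   # custom op on CUDA + torch 1.7/1.8, otherwise F.conv2d
loss = y.square().mean()

# e.g. during a regularization pass, gradients w.r.t. the weights can be disabled:
with conv2d_gradfix.no_weight_gradients():
    (grad_x,) = torch.autograd.grad(loss, x, create_graph=True)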
NTED/op/fused_act.py
ADDED
@@ -0,0 +1,127 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.autograd import Function
|
7 |
+
from torch.utils.cpp_extension import load
|
8 |
+
|
9 |
+
|
10 |
+
module_path = os.path.dirname(__file__)
|
11 |
+
fused = load(
|
12 |
+
"fused",
|
13 |
+
sources=[
|
14 |
+
os.path.join(module_path, "fused_bias_act.cpp"),
|
15 |
+
os.path.join(module_path, "fused_bias_act_kernel.cu"),
|
16 |
+
],
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class FusedLeakyReLUFunctionBackward(Function):
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, grad_output, out, bias, negative_slope, scale):
|
23 |
+
ctx.save_for_backward(out)
|
24 |
+
ctx.negative_slope = negative_slope
|
25 |
+
ctx.scale = scale
|
26 |
+
|
27 |
+
empty = grad_output.new_empty(0)
|
28 |
+
|
29 |
+
grad_input = fused.fused_bias_act(
|
30 |
+
grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
|
31 |
+
)
|
32 |
+
|
33 |
+
dim = [0]
|
34 |
+
|
35 |
+
if grad_input.ndim > 2:
|
36 |
+
dim += list(range(2, grad_input.ndim))
|
37 |
+
|
38 |
+
if bias:
|
39 |
+
grad_bias = grad_input.sum(dim).detach()
|
40 |
+
|
41 |
+
else:
|
42 |
+
grad_bias = empty
|
43 |
+
|
44 |
+
return grad_input, grad_bias
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def backward(ctx, gradgrad_input, gradgrad_bias):
|
48 |
+
out, = ctx.saved_tensors
|
49 |
+
gradgrad_out = fused.fused_bias_act(
|
50 |
+
gradgrad_input.contiguous(),
|
51 |
+
gradgrad_bias.to(gradgrad_input.dtype),
|
52 |
+
out,
|
53 |
+
3,
|
54 |
+
1,
|
55 |
+
ctx.negative_slope,
|
56 |
+
ctx.scale,
|
57 |
+
)
|
58 |
+
|
59 |
+
return gradgrad_out, None, None, None, None
|
60 |
+
|
61 |
+
|
62 |
+
class FusedLeakyReLUFunction(Function):
|
63 |
+
@staticmethod
|
64 |
+
def forward(ctx, input, bias, negative_slope, scale):
|
65 |
+
empty = input.new_empty(0)
|
66 |
+
|
67 |
+
ctx.bias = bias is not None
|
68 |
+
|
69 |
+
if bias is None:
|
70 |
+
bias = empty
|
71 |
+
|
72 |
+
out = fused.fused_bias_act(input, bias.to(input.dtype), empty, 3, 0, negative_slope, scale)
|
73 |
+
ctx.save_for_backward(out)
|
74 |
+
ctx.negative_slope = negative_slope
|
75 |
+
ctx.scale = scale
|
76 |
+
|
77 |
+
return out
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def backward(ctx, grad_output):
|
81 |
+
out, = ctx.saved_tensors
|
82 |
+
|
83 |
+
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
84 |
+
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
|
85 |
+
)
|
86 |
+
|
87 |
+
if not ctx.bias:
|
88 |
+
grad_bias = None
|
89 |
+
|
90 |
+
return grad_input, grad_bias, None, None
|
91 |
+
|
92 |
+
|
93 |
+
class FusedLeakyReLU(nn.Module):
|
94 |
+
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
if bias:
|
98 |
+
self.bias = nn.Parameter(torch.zeros(channel))
|
99 |
+
|
100 |
+
else:
|
101 |
+
self.bias = None
|
102 |
+
|
103 |
+
self.negative_slope = negative_slope
|
104 |
+
self.scale = scale
|
105 |
+
|
106 |
+
def forward(self, input):
|
107 |
+
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
108 |
+
|
109 |
+
|
110 |
+
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
|
111 |
+
if input.device.type == "cpu":
|
112 |
+
if bias is not None:
|
113 |
+
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
114 |
+
return (
|
115 |
+
F.leaky_relu(
|
116 |
+
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
|
117 |
+
)
|
118 |
+
* scale
|
119 |
+
)
|
120 |
+
|
121 |
+
else:
|
122 |
+
return F.leaky_relu(input, negative_slope=0.2) * scale
|
123 |
+
|
124 |
+
else:
|
125 |
+
return FusedLeakyReLUFunction.apply(
|
126 |
+
input.contiguous(), bias, negative_slope, scale
|
127 |
+
)
|
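Reading note (illustrative, not part of this commit): on CUDA, fused_leaky_relu runs the bias add, leaky ReLU and the sqrt(2) gain in a single fused kernel; on CPU it falls back to the plain PyTorch expression. That fallback is small enough to restate as a reference, a sketch assuming only torch:

import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # same math as the CPU branch above: per-channel bias add, leaky ReLU, then a constant gain
    rest_dim = [1] * (x.ndim - bias.ndim - 1)
    return F.leaky_relu(x + bias.view(1, bias.shape[0], *rest_dim), negative_slope) * scale

out = fused_leaky_relu_reference(torch.randn(2, 8, 16, 16), torch.zeros(8))
print(out.shape)  # torch.Size([2, 8, 16, 16])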
NTED/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,32 @@
+
+ #include <ATen/ATen.h>
+ #include <torch/extension.h>
+
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
+                                 const torch::Tensor &bias,
+                                 const torch::Tensor &refer, int act, int grad,
+                                 float alpha, float scale);
+
+ #define CHECK_CUDA(x) \
+   TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) \
+   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) \
+   CHECK_CUDA(x); \
+   CHECK_CONTIGUOUS(x)
+
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
+                              const torch::Tensor &bias,
+                              const torch::Tensor &refer, int act, int grad,
+                              float alpha, float scale) {
+   CHECK_INPUT(input);
+   CHECK_INPUT(bias);
+
+   at::DeviceGuard guard(input.device());
+
+   return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+ }
NTED/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,105 @@
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
12 |
+
#include <ATen/cuda/CUDAContext.h>
|
13 |
+
|
14 |
+
|
15 |
+
#include <cuda.h>
|
16 |
+
#include <cuda_runtime.h>
|
17 |
+
|
18 |
+
template <typename scalar_t>
|
19 |
+
static __global__ void
|
20 |
+
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
|
21 |
+
const scalar_t *p_ref, int act, int grad, scalar_t alpha,
|
22 |
+
scalar_t scale, int loop_x, int size_x, int step_b,
|
23 |
+
int size_b, int use_bias, int use_ref) {
|
24 |
+
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
25 |
+
|
26 |
+
scalar_t zero = 0.0;
|
27 |
+
|
28 |
+
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
|
29 |
+
loop_idx++, xi += blockDim.x) {
|
30 |
+
scalar_t x = p_x[xi];
|
31 |
+
|
32 |
+
if (use_bias) {
|
33 |
+
x += p_b[(xi / step_b) % size_b];
|
34 |
+
}
|
35 |
+
|
36 |
+
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
37 |
+
|
38 |
+
scalar_t y;
|
39 |
+
|
40 |
+
switch (act * 10 + grad) {
|
41 |
+
default:
|
42 |
+
case 10:
|
43 |
+
y = x;
|
44 |
+
break;
|
45 |
+
case 11:
|
46 |
+
y = x;
|
47 |
+
break;
|
48 |
+
case 12:
|
49 |
+
y = 0.0;
|
50 |
+
break;
|
51 |
+
|
52 |
+
case 30:
|
53 |
+
y = (x > 0.0) ? x : x * alpha;
|
54 |
+
break;
|
55 |
+
case 31:
|
56 |
+
y = (ref > 0.0) ? x : x * alpha;
|
57 |
+
break;
|
58 |
+
case 32:
|
59 |
+
y = 0.0;
|
60 |
+
break;
|
61 |
+
}
|
62 |
+
|
63 |
+
out[xi] = y * scale;
|
64 |
+
}
|
65 |
+
}
|
66 |
+
|
67 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
|
68 |
+
const torch::Tensor &bias,
|
69 |
+
const torch::Tensor &refer, int act, int grad,
|
70 |
+
float alpha, float scale) {
|
71 |
+
int curDevice = -1;
|
72 |
+
cudaGetDevice(&curDevice);
|
73 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
74 |
+
|
75 |
+
auto x = input.contiguous();
|
76 |
+
auto b = bias.contiguous();
|
77 |
+
auto ref = refer.contiguous();
|
78 |
+
|
79 |
+
int use_bias = b.numel() ? 1 : 0;
|
80 |
+
int use_ref = ref.numel() ? 1 : 0;
|
81 |
+
|
82 |
+
int size_x = x.numel();
|
83 |
+
int size_b = b.numel();
|
84 |
+
int step_b = 1;
|
85 |
+
|
86 |
+
for (int i = 1 + 1; i < x.dim(); i++) {
|
87 |
+
step_b *= x.size(i);
|
88 |
+
}
|
89 |
+
|
90 |
+
int loop_x = 4;
|
91 |
+
int block_size = 4 * 32;
|
92 |
+
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
93 |
+
|
94 |
+
auto y = torch::empty_like(x);
|
95 |
+
|
96 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
97 |
+
x.scalar_type(), "fused_bias_act_kernel", [&] {
|
98 |
+
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
99 |
+
y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
|
100 |
+
b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
|
101 |
+
scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
|
102 |
+
});
|
103 |
+
|
104 |
+
return y;
|
105 |
+
}
|
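Reading note (an interpretation, not text from the commit): the kernel above dispatches on `act * 10 + grad`; 10/11/12 are the linear pass-through cases, 30 is the leaky-ReLU forward, 31 its gradient gated by the saved reference tensor, and 32 the zero second-order term. A Python sketch of the same dispatch (bias add and CUDA indexing omitted for brevity):

import torch

def fused_bias_act_reference(x, ref, act, grad, alpha, scale):
    # mirrors the switch (act * 10 + grad) in fused_bias_act_kernel.cu
    key = act * 10 + grad
    if key in (10, 11):                        # linear: forward and first-order grad are identity
        y = x
    elif key == 30:                            # leaky ReLU forward
        y = torch.where(x > 0, x, x * alpha)
    elif key == 31:                            # leaky ReLU backward, sign taken from the saved output
        y = torch.where(ref > 0, x, x * alpha)
    else:                                      # keys 12 and 32: second-order terms are zero
        y = torch.zeros_like(x)
    return y * scale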
NTED/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,31 @@
+ #include <ATen/ATen.h>
+ #include <torch/extension.h>
+
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+                            const torch::Tensor &kernel, int up_x, int up_y,
+                            int down_x, int down_y, int pad_x0, int pad_x1,
+                            int pad_y0, int pad_y1);
+
+ #define CHECK_CUDA(x) \
+   TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) \
+   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) \
+   CHECK_CUDA(x); \
+   CHECK_CONTIGUOUS(x)
+
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
+                         int up_x, int up_y, int down_x, int down_y, int pad_x0,
+                         int pad_x1, int pad_y0, int pad_y1) {
+   CHECK_INPUT(input);
+   CHECK_INPUT(kernel);
+
+   at::DeviceGuard guard(input.device());
+
+   return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+                       pad_y0, pad_y1);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+ }
NTED/op/upfirdn2d.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
from collections import abc
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.autograd import Function
|
7 |
+
from torch.utils.cpp_extension import load
|
8 |
+
|
9 |
+
|
10 |
+
module_path = os.path.dirname(__file__)
|
11 |
+
upfirdn2d_op = load(
|
12 |
+
"upfirdn2d",
|
13 |
+
sources=[
|
14 |
+
os.path.join(module_path, "upfirdn2d.cpp"),
|
15 |
+
os.path.join(module_path, "upfirdn2d_kernel.cu"),
|
16 |
+
],
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class UpFirDn2dBackward(Function):
|
21 |
+
@staticmethod
|
22 |
+
def forward(
|
23 |
+
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
24 |
+
):
|
25 |
+
|
26 |
+
up_x, up_y = up
|
27 |
+
down_x, down_y = down
|
28 |
+
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
29 |
+
|
30 |
+
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
31 |
+
|
32 |
+
grad_input = upfirdn2d_op.upfirdn2d(
|
33 |
+
grad_output,
|
34 |
+
grad_kernel.to(grad_output.dtype),
|
35 |
+
down_x,
|
36 |
+
down_y,
|
37 |
+
up_x,
|
38 |
+
up_y,
|
39 |
+
g_pad_x0,
|
40 |
+
g_pad_x1,
|
41 |
+
g_pad_y0,
|
42 |
+
g_pad_y1,
|
43 |
+
)
|
44 |
+
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
45 |
+
|
46 |
+
ctx.save_for_backward(kernel)
|
47 |
+
|
48 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
49 |
+
|
50 |
+
ctx.up_x = up_x
|
51 |
+
ctx.up_y = up_y
|
52 |
+
ctx.down_x = down_x
|
53 |
+
ctx.down_y = down_y
|
54 |
+
ctx.pad_x0 = pad_x0
|
55 |
+
ctx.pad_x1 = pad_x1
|
56 |
+
ctx.pad_y0 = pad_y0
|
57 |
+
ctx.pad_y1 = pad_y1
|
58 |
+
ctx.in_size = in_size
|
59 |
+
ctx.out_size = out_size
|
60 |
+
|
61 |
+
return grad_input
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def backward(ctx, gradgrad_input):
|
65 |
+
kernel, = ctx.saved_tensors
|
66 |
+
|
67 |
+
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
68 |
+
|
69 |
+
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
70 |
+
gradgrad_input,
|
71 |
+
kernel.to(gradgrad_input.dtype),
|
72 |
+
ctx.up_x,
|
73 |
+
ctx.up_y,
|
74 |
+
ctx.down_x,
|
75 |
+
ctx.down_y,
|
76 |
+
ctx.pad_x0,
|
77 |
+
ctx.pad_x1,
|
78 |
+
ctx.pad_y0,
|
79 |
+
ctx.pad_y1,
|
80 |
+
)
|
81 |
+
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
|
82 |
+
gradgrad_out = gradgrad_out.view(
|
83 |
+
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
|
84 |
+
)
|
85 |
+
|
86 |
+
return gradgrad_out, None, None, None, None, None, None, None, None
|
87 |
+
|
88 |
+
|
89 |
+
class UpFirDn2d(Function):
|
90 |
+
@staticmethod
|
91 |
+
def forward(ctx, input, kernel, up, down, pad):
|
92 |
+
up_x, up_y = up
|
93 |
+
down_x, down_y = down
|
94 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
95 |
+
|
96 |
+
kernel_h, kernel_w = kernel.shape
|
97 |
+
batch, channel, in_h, in_w = input.shape
|
98 |
+
ctx.in_size = input.shape
|
99 |
+
|
100 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
101 |
+
|
102 |
+
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
103 |
+
|
104 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
105 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
106 |
+
ctx.out_size = (out_h, out_w)
|
107 |
+
|
108 |
+
ctx.up = (up_x, up_y)
|
109 |
+
ctx.down = (down_x, down_y)
|
110 |
+
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
111 |
+
|
112 |
+
g_pad_x0 = kernel_w - pad_x0 - 1
|
113 |
+
g_pad_y0 = kernel_h - pad_y0 - 1
|
114 |
+
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
115 |
+
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
116 |
+
|
117 |
+
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
118 |
+
|
119 |
+
out = upfirdn2d_op.upfirdn2d(
|
120 |
+
input, kernel.to(input.dtype), up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
121 |
+
)
|
122 |
+
# out = out.view(major, out_h, out_w, minor)
|
123 |
+
out = out.view(-1, channel, out_h, out_w)
|
124 |
+
|
125 |
+
return out
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def backward(ctx, grad_output):
|
129 |
+
kernel, grad_kernel = ctx.saved_tensors
|
130 |
+
|
131 |
+
grad_input = None
|
132 |
+
|
133 |
+
if ctx.needs_input_grad[0]:
|
134 |
+
grad_input = UpFirDn2dBackward.apply(
|
135 |
+
grad_output,
|
136 |
+
kernel,
|
137 |
+
grad_kernel,
|
138 |
+
ctx.up,
|
139 |
+
ctx.down,
|
140 |
+
ctx.pad,
|
141 |
+
ctx.g_pad,
|
142 |
+
ctx.in_size,
|
143 |
+
ctx.out_size,
|
144 |
+
)
|
145 |
+
|
146 |
+
return grad_input, None, None, None, None
|
147 |
+
|
148 |
+
|
149 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
150 |
+
if not isinstance(up, abc.Iterable):
|
151 |
+
up = (up, up)
|
152 |
+
|
153 |
+
if not isinstance(down, abc.Iterable):
|
154 |
+
down = (down, down)
|
155 |
+
|
156 |
+
if len(pad) == 2:
|
157 |
+
pad = (pad[0], pad[1], pad[0], pad[1])
|
158 |
+
|
159 |
+
if input.device.type == "cpu":
|
160 |
+
out = upfirdn2d_native(input, kernel, *up, *down, *pad)
|
161 |
+
|
162 |
+
else:
|
163 |
+
out = UpFirDn2d.apply(input, kernel, up, down, pad)
|
164 |
+
|
165 |
+
return out
|
166 |
+
|
167 |
+
|
168 |
+
def upfirdn2d_native(
|
169 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
170 |
+
):
|
171 |
+
_, channel, in_h, in_w = input.shape
|
172 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
173 |
+
|
174 |
+
_, in_h, in_w, minor = input.shape
|
175 |
+
kernel_h, kernel_w = kernel.shape
|
176 |
+
|
177 |
+
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
178 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
179 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
180 |
+
|
181 |
+
out = F.pad(
|
182 |
+
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
183 |
+
)
|
184 |
+
out = out[
|
185 |
+
:,
|
186 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
187 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
188 |
+
:,
|
189 |
+
]
|
190 |
+
|
191 |
+
out = out.permute(0, 3, 1, 2)
|
192 |
+
out = out.reshape(
|
193 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
194 |
+
)
|
195 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
196 |
+
out = F.conv2d(out, w)
|
197 |
+
out = out.reshape(
|
198 |
+
-1,
|
199 |
+
minor,
|
200 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
201 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
202 |
+
)
|
203 |
+
out = out.permute(0, 2, 3, 1)
|
204 |
+
out = out[:, ::down_y, ::down_x, :]
|
205 |
+
|
206 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
207 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
208 |
+
|
209 |
+
return out.view(-1, channel, out_h, out_w)
|
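Usage note (illustrative, not part of this commit): upfirdn2d performs upsample, FIR filter, and downsample in one pass; the wrapper dispatches to the CUDA op on GPU and to upfirdn2d_native on CPU. A sketch of 2x antialiased downsampling with a box kernel, assuming the module imports (importing it still tries to build the CUDA extension, so a working toolchain or prebuilt cache is assumed):

import torch
from NTED.op.upfirdn2d import upfirdn2d

img = torch.randn(1, 3, 8, 8)
kernel = torch.ones(2, 2) / 4                 # normalized 2x2 box filter

# out_h = (8*1 + 0 + 1 - 2 + 2) // 2 = 4, so the 8x8 input comes out as 4x4
out = upfirdn2d(img, kernel, up=1, down=2, pad=(0, 1))
print(out.shape)  # torch.Size([1, 3, 4, 4])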
NTED/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,369 @@
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
12 |
+
#include <ATen/cuda/CUDAContext.h>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
18 |
+
int c = a / b;
|
19 |
+
|
20 |
+
if (c * b > a) {
|
21 |
+
c--;
|
22 |
+
}
|
23 |
+
|
24 |
+
return c;
|
25 |
+
}
|
26 |
+
|
27 |
+
struct UpFirDn2DKernelParams {
|
28 |
+
int up_x;
|
29 |
+
int up_y;
|
30 |
+
int down_x;
|
31 |
+
int down_y;
|
32 |
+
int pad_x0;
|
33 |
+
int pad_x1;
|
34 |
+
int pad_y0;
|
35 |
+
int pad_y1;
|
36 |
+
|
37 |
+
int major_dim;
|
38 |
+
int in_h;
|
39 |
+
int in_w;
|
40 |
+
int minor_dim;
|
41 |
+
int kernel_h;
|
42 |
+
int kernel_w;
|
43 |
+
int out_h;
|
44 |
+
int out_w;
|
45 |
+
int loop_major;
|
46 |
+
int loop_x;
|
47 |
+
};
|
48 |
+
|
49 |
+
template <typename scalar_t>
|
50 |
+
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
|
51 |
+
const scalar_t *kernel,
|
52 |
+
const UpFirDn2DKernelParams p) {
|
53 |
+
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
54 |
+
int out_y = minor_idx / p.minor_dim;
|
55 |
+
minor_idx -= out_y * p.minor_dim;
|
56 |
+
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
|
57 |
+
int major_idx_base = blockIdx.z * p.loop_major;
|
58 |
+
|
59 |
+
if (out_x_base >= p.out_w || out_y >= p.out_h ||
|
60 |
+
major_idx_base >= p.major_dim) {
|
61 |
+
return;
|
62 |
+
}
|
63 |
+
|
64 |
+
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
|
65 |
+
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
|
66 |
+
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
|
67 |
+
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
|
68 |
+
|
69 |
+
for (int loop_major = 0, major_idx = major_idx_base;
|
70 |
+
loop_major < p.loop_major && major_idx < p.major_dim;
|
71 |
+
loop_major++, major_idx++) {
|
72 |
+
for (int loop_x = 0, out_x = out_x_base;
|
73 |
+
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
|
74 |
+
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
|
75 |
+
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
|
76 |
+
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
|
77 |
+
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
|
78 |
+
|
79 |
+
const scalar_t *x_p =
|
80 |
+
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
|
81 |
+
minor_idx];
|
82 |
+
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
|
83 |
+
int x_px = p.minor_dim;
|
84 |
+
int k_px = -p.up_x;
|
85 |
+
int x_py = p.in_w * p.minor_dim;
|
86 |
+
int k_py = -p.up_y * p.kernel_w;
|
87 |
+
|
88 |
+
scalar_t v = 0.0f;
|
89 |
+
|
90 |
+
for (int y = 0; y < h; y++) {
|
91 |
+
for (int x = 0; x < w; x++) {
|
92 |
+
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
|
93 |
+
x_p += x_px;
|
94 |
+
k_p += k_px;
|
95 |
+
}
|
96 |
+
|
97 |
+
x_p += x_py - w * x_px;
|
98 |
+
k_p += k_py - w * k_px;
|
99 |
+
}
|
100 |
+
|
101 |
+
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
102 |
+
minor_idx] = v;
|
103 |
+
}
|
104 |
+
}
|
105 |
+
}
|
106 |
+
|
107 |
+
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
|
108 |
+
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
109 |
+
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
|
110 |
+
const scalar_t *kernel,
|
111 |
+
const UpFirDn2DKernelParams p) {
|
112 |
+
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
113 |
+
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
114 |
+
|
115 |
+
__shared__ volatile float sk[kernel_h][kernel_w];
|
116 |
+
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
117 |
+
|
118 |
+
int minor_idx = blockIdx.x;
|
119 |
+
int tile_out_y = minor_idx / p.minor_dim;
|
120 |
+
minor_idx -= tile_out_y * p.minor_dim;
|
121 |
+
tile_out_y *= tile_out_h;
|
122 |
+
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
123 |
+
int major_idx_base = blockIdx.z * p.loop_major;
|
124 |
+
|
125 |
+
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
|
126 |
+
major_idx_base >= p.major_dim) {
|
127 |
+
return;
|
128 |
+
}
|
129 |
+
|
130 |
+
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
|
131 |
+
tap_idx += blockDim.x) {
|
132 |
+
int ky = tap_idx / kernel_w;
|
133 |
+
int kx = tap_idx - ky * kernel_w;
|
134 |
+
scalar_t v = 0.0;
|
135 |
+
|
136 |
+
if (kx < p.kernel_w & ky < p.kernel_h) {
|
137 |
+
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
138 |
+
}
|
139 |
+
|
140 |
+
sk[ky][kx] = v;
|
141 |
+
}
|
142 |
+
|
143 |
+
for (int loop_major = 0, major_idx = major_idx_base;
|
144 |
+
loop_major < p.loop_major & major_idx < p.major_dim;
|
145 |
+
loop_major++, major_idx++) {
|
146 |
+
for (int loop_x = 0, tile_out_x = tile_out_x_base;
|
147 |
+
loop_x < p.loop_x & tile_out_x < p.out_w;
|
148 |
+
loop_x++, tile_out_x += tile_out_w) {
|
149 |
+
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
150 |
+
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
151 |
+
int tile_in_x = floor_div(tile_mid_x, up_x);
|
152 |
+
int tile_in_y = floor_div(tile_mid_y, up_y);
|
153 |
+
|
154 |
+
__syncthreads();
|
155 |
+
|
156 |
+
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
|
157 |
+
in_idx += blockDim.x) {
|
158 |
+
int rel_in_y = in_idx / tile_in_w;
|
159 |
+
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
160 |
+
int in_x = rel_in_x + tile_in_x;
|
161 |
+
int in_y = rel_in_y + tile_in_y;
|
162 |
+
|
163 |
+
scalar_t v = 0.0;
|
164 |
+
|
165 |
+
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
166 |
+
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
|
167 |
+
p.minor_dim +
|
168 |
+
minor_idx];
|
169 |
+
}
|
170 |
+
|
171 |
+
sx[rel_in_y][rel_in_x] = v;
|
172 |
+
}
|
173 |
+
|
174 |
+
__syncthreads();
|
175 |
+
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
|
176 |
+
out_idx += blockDim.x) {
|
177 |
+
int rel_out_y = out_idx / tile_out_w;
|
178 |
+
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
179 |
+
int out_x = rel_out_x + tile_out_x;
|
180 |
+
int out_y = rel_out_y + tile_out_y;
|
181 |
+
|
182 |
+
int mid_x = tile_mid_x + rel_out_x * down_x;
|
183 |
+
int mid_y = tile_mid_y + rel_out_y * down_y;
|
184 |
+
int in_x = floor_div(mid_x, up_x);
|
185 |
+
int in_y = floor_div(mid_y, up_y);
|
186 |
+
int rel_in_x = in_x - tile_in_x;
|
187 |
+
int rel_in_y = in_y - tile_in_y;
|
188 |
+
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
189 |
+
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
190 |
+
|
191 |
+
scalar_t v = 0.0;
|
192 |
+
|
193 |
+
#pragma unroll
|
194 |
+
for (int y = 0; y < kernel_h / up_y; y++)
|
195 |
+
#pragma unroll
|
196 |
+
for (int x = 0; x < kernel_w / up_x; x++)
|
197 |
+
v += sx[rel_in_y + y][rel_in_x + x] *
|
198 |
+
sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
199 |
+
|
200 |
+
if (out_x < p.out_w & out_y < p.out_h) {
|
201 |
+
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
202 |
+
minor_idx] = v;
|
203 |
+
}
|
204 |
+
}
|
205 |
+
}
|
206 |
+
}
|
207 |
+
}
|
208 |
+
|
209 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
|
210 |
+
const torch::Tensor &kernel, int up_x, int up_y,
|
211 |
+
int down_x, int down_y, int pad_x0, int pad_x1,
|
212 |
+
int pad_y0, int pad_y1) {
|
213 |
+
int curDevice = -1;
|
214 |
+
cudaGetDevice(&curDevice);
|
215 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
216 |
+
|
217 |
+
UpFirDn2DKernelParams p;
|
218 |
+
|
219 |
+
auto x = input.contiguous();
|
220 |
+
auto k = kernel.contiguous();
|
221 |
+
|
222 |
+
p.major_dim = x.size(0);
|
223 |
+
p.in_h = x.size(1);
|
224 |
+
p.in_w = x.size(2);
|
225 |
+
p.minor_dim = x.size(3);
|
226 |
+
p.kernel_h = k.size(0);
|
227 |
+
p.kernel_w = k.size(1);
|
228 |
+
p.up_x = up_x;
|
229 |
+
p.up_y = up_y;
|
230 |
+
p.down_x = down_x;
|
231 |
+
p.down_y = down_y;
|
232 |
+
p.pad_x0 = pad_x0;
|
233 |
+
p.pad_x1 = pad_x1;
|
234 |
+
p.pad_y0 = pad_y0;
|
235 |
+
p.pad_y1 = pad_y1;
|
236 |
+
|
237 |
+
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
|
238 |
+
p.down_y;
|
239 |
+
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
|
240 |
+
p.down_x;
|
241 |
+
|
242 |
+
auto out =
|
243 |
+
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
244 |
+
|
245 |
+
int mode = -1;
|
246 |
+
|
247 |
+
int tile_out_h = -1;
|
248 |
+
int tile_out_w = -1;
|
249 |
+
|
250 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
251 |
+
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
252 |
+
mode = 1;
|
253 |
+
tile_out_h = 16;
|
254 |
+
tile_out_w = 64;
|
255 |
+
}
|
256 |
+
|
257 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
258 |
+
p.kernel_h <= 3 && p.kernel_w <= 3) {
|
259 |
+
mode = 2;
|
260 |
+
tile_out_h = 16;
|
261 |
+
tile_out_w = 64;
|
262 |
+
}
|
263 |
+
|
264 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
265 |
+
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
266 |
+
mode = 3;
|
267 |
+
tile_out_h = 16;
|
268 |
+
tile_out_w = 64;
|
269 |
+
}
|
270 |
+
|
271 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
272 |
+
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
273 |
+
mode = 4;
|
274 |
+
tile_out_h = 16;
|
275 |
+
tile_out_w = 64;
|
276 |
+
}
|
277 |
+
|
278 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
279 |
+
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
280 |
+
mode = 5;
|
281 |
+
tile_out_h = 8;
|
282 |
+
tile_out_w = 32;
|
283 |
+
}
|
284 |
+
|
285 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
286 |
+
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
287 |
+
mode = 6;
|
288 |
+
tile_out_h = 8;
|
289 |
+
tile_out_w = 32;
|
290 |
+
}
|
291 |
+
|
292 |
+
dim3 block_size;
|
293 |
+
dim3 grid_size;
|
294 |
+
|
295 |
+
if (tile_out_h > 0 && tile_out_w > 0) {
|
296 |
+
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
297 |
+
p.loop_x = 1;
|
298 |
+
block_size = dim3(32 * 8, 1, 1);
|
299 |
+
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
300 |
+
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
301 |
+
(p.major_dim - 1) / p.loop_major + 1);
|
302 |
+
} else {
|
303 |
+
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
304 |
+
p.loop_x = 4;
|
305 |
+
block_size = dim3(4, 32, 1);
|
306 |
+
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
|
307 |
+
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
|
308 |
+
(p.major_dim - 1) / p.loop_major + 1);
|
309 |
+
}
|
310 |
+
|
311 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
312 |
+
switch (mode) {
|
313 |
+
case 1:
|
314 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
|
315 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
316 |
+
x.data_ptr<scalar_t>(),
|
317 |
+
k.data_ptr<scalar_t>(), p);
|
318 |
+
|
319 |
+
break;
|
320 |
+
|
321 |
+
case 2:
|
322 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
|
323 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
324 |
+
x.data_ptr<scalar_t>(),
|
325 |
+
k.data_ptr<scalar_t>(), p);
|
326 |
+
|
327 |
+
break;
|
328 |
+
|
329 |
+
case 3:
|
330 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
|
331 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
332 |
+
x.data_ptr<scalar_t>(),
|
333 |
+
k.data_ptr<scalar_t>(), p);
|
334 |
+
|
335 |
+
break;
|
336 |
+
|
337 |
+
case 4:
|
338 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
|
339 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
340 |
+
x.data_ptr<scalar_t>(),
|
341 |
+
k.data_ptr<scalar_t>(), p);
|
342 |
+
|
343 |
+
break;
|
344 |
+
|
345 |
+
case 5:
|
346 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
347 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
348 |
+
x.data_ptr<scalar_t>(),
|
349 |
+
k.data_ptr<scalar_t>(), p);
|
350 |
+
|
351 |
+
break;
|
352 |
+
|
353 |
+
case 6:
|
354 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
355 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
356 |
+
x.data_ptr<scalar_t>(),
|
357 |
+
k.data_ptr<scalar_t>(), p);
|
358 |
+
|
359 |
+
break;
|
360 |
+
|
361 |
+
default:
|
362 |
+
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
363 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
|
364 |
+
k.data_ptr<scalar_t>(), p);
|
365 |
+
}
|
366 |
+
});
|
367 |
+
|
368 |
+
return out;
|
369 |
+
}
|
app.py
CHANGED
@@ -1,17 +1,29 @@
  import gradio as gr
+ import cv2
+ import numpy as np
+ import torch
+ from NTED.NTED_module import NTED

- return "恭喜,您今年" + 年龄预测器_输入您的年龄 + "岁了!"
+ NTED_Module = NTED()

+ def pose_transfer(上传人体姿态图):
+     img = 上传人体姿态图
+     fake_img = NTED_Module.inference(img)
+
+     return fake_img

+ with gr.Column():
+     result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+
+ gr.Interface(fn=pose_transfer,
+              inputs=["image"],
+              outputs=[result_gallery],
+              title="谷小雨姿态驱动图像",
+              examples=[["example/exp1.png"], ["example/exp2.png"], ["example/exp3.png"],\
+                        ["example/exp4.png"], ["example/exp5.png"], ["example/exp6.png"]],
+              ).launch(server_name='0.0.0.0')

  '''
  TODO
- 先把openpose light整合进来测试一下
  测试视频展示功能
  '''
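Usage note (illustrative, not part of this commit): the new app.py wires NTED_Module.inference into a Gradio Interface with an image input and a gallery output. A local smoke test of the same callback, assuming inference() accepts and returns RGB uint8 arrays (the Gradio "image" component passes RGB numpy arrays):

import cv2
from NTED.NTED_module import NTED

nted = NTED()
pose_img = cv2.cvtColor(cv2.imread("example/exp1.png"), cv2.COLOR_BGR2RGB)
fake_img = nted.inference(pose_img)            # assumed to return an RGB uint8 array
cv2.imwrite("result.png", cv2.cvtColor(fake_img, cv2.COLOR_RGB2BGR))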
example/exp1.png
ADDED
example/exp2.png
ADDED
example/exp3.png
ADDED
example/exp4.png
ADDED
example/exp5.png
ADDED
example/exp6.png
ADDED
example/ref_img.png
ADDED
Git LFS Details
|
lite_openpose/body_bbox_detector.py
ADDED
@@ -0,0 +1,179 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import sys
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import math
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchvision.transforms as transforms
|
12 |
+
# from PIL import Image
|
13 |
+
|
14 |
+
# Code from https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch/blob/master/demo.py
|
15 |
+
|
16 |
+
# 2D body pose estimator
|
17 |
+
sys.path.append('/apdcephfs/share_1474453/zejunzhang/workspace/HR-VITON/dataset_process_utils/lite_openpose')
|
18 |
+
from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet
|
19 |
+
from modules.load_state import load_state
|
20 |
+
from modules.pose import Pose, track_poses
|
21 |
+
from modules.keypoints import extract_keypoints, group_keypoints
|
22 |
+
|
23 |
+
|
24 |
+
def normalize(img, img_mean, img_scale):
|
25 |
+
img = np.array(img, dtype=np.float32)
|
26 |
+
img = (img - img_mean) * img_scale
|
27 |
+
return img
|
28 |
+
|
29 |
+
|
30 |
+
def pad_width(img, stride, pad_value, min_dims):
|
31 |
+
h, w, _ = img.shape
|
32 |
+
h = min(min_dims[0], h)
|
33 |
+
min_dims[0] = math.ceil(min_dims[0] / float(stride)) * stride
|
34 |
+
min_dims[1] = max(min_dims[1], w)
|
35 |
+
min_dims[1] = math.ceil(min_dims[1] / float(stride)) * stride
|
36 |
+
pad = []
|
37 |
+
pad.append(int(math.floor((min_dims[0] - h) / 2.0)))
|
38 |
+
pad.append(int(math.floor((min_dims[1] - w) / 2.0)))
|
39 |
+
pad.append(int(min_dims[0] - h - pad[0]))
|
40 |
+
pad.append(int(min_dims[1] - w - pad[1]))
|
41 |
+
padded_img = cv2.copyMakeBorder(img, pad[0], pad[2], pad[1], pad[3],
|
42 |
+
cv2.BORDER_CONSTANT, value=pad_value)
|
43 |
+
return padded_img, pad
|
44 |
+
|
45 |
+
|
46 |
+
class BodyPoseEstimator(object):
|
47 |
+
"""
|
48 |
+
Hand Detector for third-view input.
|
49 |
+
It combines a body pose estimator (https://github.com/jhugestar/lightweight-human-pose-estimation.pytorch.git)
|
50 |
+
"""
|
51 |
+
def __init__(self, device='cpu'):
|
52 |
+
# print("Loading Body Pose Estimator")
|
53 |
+
self.device=device
|
54 |
+
self.__load_body_estimator()
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
def __load_body_estimator(self):
|
59 |
+
net = PoseEstimationWithMobileNet()
|
60 |
+
pose2d_checkpoint = "lite_openpose/checkpoint_iter_370000.pth"
|
61 |
+
checkpoint = torch.load(pose2d_checkpoint, map_location='cpu')
|
62 |
+
load_state(net, checkpoint)
|
63 |
+
net = net.eval()
|
64 |
+
net = net.to(self.device)
|
65 |
+
self.model = net
|
66 |
+
|
67 |
+
|
68 |
+
#Code from https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch/demo.py
|
69 |
+
def __infer_fast(self, img, input_height_size, stride, upsample_ratio,
|
70 |
+
cpu=False, pad_value=(0, 0, 0), img_mean=(128, 128, 128), img_scale=1/256):
|
71 |
+
height, width, _ = img.shape
|
72 |
+
scale = input_height_size / height
|
73 |
+
|
74 |
+
scaled_img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
75 |
+
scaled_img = normalize(scaled_img, img_mean, img_scale)
|
76 |
+
min_dims = [input_height_size, max(scaled_img.shape[1], input_height_size)]
|
77 |
+
padded_img, pad = pad_width(scaled_img, stride, pad_value, min_dims)
|
78 |
+
|
79 |
+
tensor_img = torch.from_numpy(padded_img).permute(2, 0, 1).unsqueeze(0).float()
|
80 |
+
if not cpu:
|
81 |
+
tensor_img = tensor_img.to(self.device)
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
stages_output = self.model(tensor_img)
|
85 |
+
|
86 |
+
stage2_heatmaps = stages_output[-2]
|
87 |
+
heatmaps = np.transpose(stage2_heatmaps.squeeze().cpu().data.numpy(), (1, 2, 0))
|
88 |
+
heatmaps = cv2.resize(heatmaps, (0, 0), fx=upsample_ratio, fy=upsample_ratio, interpolation=cv2.INTER_CUBIC)
|
89 |
+
|
90 |
+
stage2_pafs = stages_output[-1]
|
91 |
+
pafs = np.transpose(stage2_pafs.squeeze().cpu().data.numpy(), (1, 2, 0))
|
92 |
+
pafs = cv2.resize(pafs, (0, 0), fx=upsample_ratio, fy=upsample_ratio, interpolation=cv2.INTER_CUBIC)
|
93 |
+
|
94 |
+
return heatmaps, pafs, scale, pad
|
95 |
+
|
96 |
+
def detect_body_pose(self, img):
|
97 |
+
"""
|
98 |
+
Output:
|
99 |
+
current_bbox: BBOX_XYWH
|
100 |
+
"""
|
101 |
+
stride = 8
|
102 |
+
upsample_ratio = 4
|
103 |
+
orig_img = img.copy()
|
104 |
+
|
105 |
+
# forward
|
106 |
+
heatmaps, pafs, scale, pad = self.__infer_fast(img,
|
107 |
+
input_height_size=256, stride=stride, upsample_ratio=upsample_ratio)
|
108 |
+
|
109 |
+
total_keypoints_num = 0
|
110 |
+
all_keypoints_by_type = []
|
111 |
+
num_keypoints = Pose.num_kpts
|
112 |
+
for kpt_idx in range(num_keypoints): # 19th for bg
|
113 |
+
total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx], all_keypoints_by_type, total_keypoints_num)
|
114 |
+
|
115 |
+
pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, pafs, demo=True)
|
116 |
+
for kpt_id in range(all_keypoints.shape[0]):
|
117 |
+
all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride / upsample_ratio - pad[1]) / scale
|
118 |
+
all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride / upsample_ratio - pad[0]) / scale
|
119 |
+
|
120 |
+
'''
|
121 |
+
# print(len(pose_entries))
|
122 |
+
if len(pose_entries)>1:
|
123 |
+
pose_entries = pose_entries[:1]
|
124 |
+
print("We only support one person currently")
|
125 |
+
# assert len(pose_entries) == 1, "We only support one person currently"
|
126 |
+
'''
|
127 |
+
|
128 |
+
current_poses, current_bbox = list(), list()
|
129 |
+
for n in range(len(pose_entries)):
|
130 |
+
if len(pose_entries[n]) == 0:
|
131 |
+
continue
|
132 |
+
pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1
|
133 |
+
for kpt_id in range(num_keypoints):
|
134 |
+
if pose_entries[n][kpt_id] != -1.0: # keypoint was found
|
135 |
+
pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0])
|
136 |
+
pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1])
|
137 |
+
pose = Pose(pose_keypoints, pose_entries[n][18])
|
138 |
+
current_poses.append(pose.keypoints)
|
139 |
+
current_bbox.append(np.array(pose.bbox))
|
140 |
+
|
141 |
+
# enlarge the bbox
|
142 |
+
for i, bbox in enumerate(current_bbox):
|
143 |
+
x, y, w, h = bbox
|
144 |
+
margin = 0.2
|
145 |
+
x_margin = int(w * margin)
|
146 |
+
y_margin = int(h * margin)
|
147 |
+
x0 = max(x-x_margin, 0)
|
148 |
+
y0 = max(y-y_margin, 0)
|
149 |
+
x1 = min(x+w+x_margin, orig_img.shape[1])
|
150 |
+
y1 = min(y+h+y_margin, orig_img.shape[0])
|
151 |
+
current_bbox[i] = np.array((x0, y0, x1, y1)).astype(np.int32) # ltrb
|
152 |
+
|
153 |
+
# 只拿一个人 (only keep the first detected person)
|
154 |
+
body_point_list = []
|
155 |
+
if len(current_poses) > 0:
|
156 |
+
for item in current_poses[0]:
|
157 |
+
if item[0] == item[1] == -1:
|
158 |
+
body_point_list += [0.0, 0.0, 0.0]
|
159 |
+
else:
|
160 |
+
body_point_list += [float(item[0]), float(item[1]), 1.0]
|
161 |
+
else:
|
162 |
+
for i in range(18):
|
163 |
+
body_point_list += [0.0, 0.0, 0.0]
|
164 |
+
|
165 |
+
pose_dict = dict()
|
166 |
+
pose_dict["people"] = []
|
167 |
+
pose_dict["people"].append({
|
168 |
+
"person_id": [-1],
|
169 |
+
"pose_keypoints_2d": body_point_list,
|
170 |
+
"hand_left_keypoints_2d": [],
|
171 |
+
"hand_right_keypoints_2d": [],
|
172 |
+
"face_keypoints_2d": [],
|
173 |
+
"pose_keypoints_3d": [],
|
174 |
+
"face_keypoints_3d": [],
|
175 |
+
"hand_left_keypoints_3d": [],
|
176 |
+
"hand_right_keypoints_3d": [],
|
177 |
+
})
|
178 |
+
|
179 |
+
return current_poses, current_bbox
|
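Usage note (illustrative, not part of this commit): BodyPoseEstimator wraps the lightweight-OpenPose MobileNet model and returns 18-keypoint arrays plus margin-enlarged ltrb boxes per detected person. A minimal sketch, assuming the checkpoint added below is present and that the repo root is on sys.path (body_bbox_detector.py itself also appends an absolute cluster path, which will not exist elsewhere):

import cv2
from lite_openpose.body_bbox_detector import BodyPoseEstimator

estimator = BodyPoseEstimator('cpu')           # loads lite_openpose/checkpoint_iter_370000.pth
img = cv2.imread("example/exp1.png")           # BGR image, as read by OpenCV
poses, bboxes = estimator.detect_body_pose(img)
print(len(poses), bboxes[0] if bboxes else None)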
lite_openpose/checkpoint_iter_370000.pth
ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:517c86f769c6636583083f1467e3d212a0006c27109edb3aeffc19a79622d411
+ size 87959810
lite_openpose/modules/__init__.py
ADDED
File without changes
|
lite_openpose/modules/conv.py
ADDED
@@ -0,0 +1,32 @@
+ from torch import nn
+
+
+ def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, stride=1, relu=True, bias=True):
+     modules = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)]
+     if bn:
+         modules.append(nn.BatchNorm2d(out_channels))
+     if relu:
+         modules.append(nn.ReLU(inplace=True))
+     return nn.Sequential(*modules)
+
+
+ def conv_dw(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
+     return nn.Sequential(
+         nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation=dilation, groups=in_channels, bias=False),
+         nn.BatchNorm2d(in_channels),
+         nn.ReLU(inplace=True),
+
+         nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
+         nn.BatchNorm2d(out_channels),
+         nn.ReLU(inplace=True),
+     )
+
+
+ def conv_dw_no_bn(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
+     return nn.Sequential(
+         nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation=dilation, groups=in_channels, bias=False),
+         nn.ELU(inplace=True),
+
+         nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
+         nn.ELU(inplace=True),
+     )
lite_openpose/modules/get_parameters.py
ADDED
@@ -0,0 +1,23 @@
+ from torch import nn
+
+
+ def get_parameters(model, predicate):
+     for module in model.modules():
+         for param_name, param in module.named_parameters():
+             if predicate(module, param_name):
+                 yield param
+
+
+ def get_parameters_conv(model, name):
+     return get_parameters(model, lambda m, p: isinstance(m, nn.Conv2d) and m.groups == 1 and p == name)
+
+
+ def get_parameters_conv_depthwise(model, name):
+     return get_parameters(model, lambda m, p: isinstance(m, nn.Conv2d)
+                                               and m.groups == m.in_channels
+                                               and m.in_channels == m.out_channels
+                                               and p == name)
+
+
+ def get_parameters_bn(model, name):
+     return get_parameters(model, lambda m, p: isinstance(m, nn.BatchNorm2d) and p == name)
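Usage note (illustrative, not part of this commit): these helpers split a network's parameters into regular-conv, depthwise-conv and batch-norm groups so an optimizer can give each group its own weight decay or learning rate, as the upstream lightweight-OpenPose training script does. A sketch with assumed group settings and import paths inferred from the directory layout added in this commit:

import torch
from lite_openpose.modules.get_parameters import (
    get_parameters_conv, get_parameters_conv_depthwise, get_parameters_bn)
from lite_openpose.pose2d_models.with_mobilenet import PoseEstimationWithMobileNet

net = PoseEstimationWithMobileNet()
optimizer = torch.optim.Adam([
    {'params': get_parameters_conv(net, 'weight')},
    {'params': get_parameters_conv(net, 'bias'), 'weight_decay': 0},
    {'params': get_parameters_conv_depthwise(net, 'weight'), 'weight_decay': 0},
    {'params': get_parameters_bn(net, 'weight'), 'weight_decay': 0},
    {'params': get_parameters_bn(net, 'bias'), 'weight_decay': 0},
], lr=4e-5)  # hypothetical settings, shown only to illustrate the grouping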
lite_openpose/modules/keypoints.py
ADDED
@@ -0,0 +1,201 @@
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
from operator import itemgetter
|
4 |
+
|
5 |
+
BODY_PARTS_KPT_IDS = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10], [1, 11],
|
6 |
+
[11, 12], [12, 13], [1, 0], [0, 14], [14, 16], [0, 15], [15, 17], [2, 16], [5, 17]]
|
7 |
+
BODY_PARTS_PAF_IDS = ([12, 13], [20, 21], [14, 15], [16, 17], [22, 23], [24, 25], [0, 1], [2, 3], [4, 5],
|
8 |
+
[6, 7], [8, 9], [10, 11], [28, 29], [30, 31], [34, 35], [32, 33], [36, 37], [18, 19], [26, 27])
|
9 |
+
|
10 |
+
|
11 |
+
def linspace2d(start, stop, n=10):
|
12 |
+
points = 1 / (n - 1) * (stop - start)
|
13 |
+
return points[:, None] * np.arange(n) + start[:, None]
|
14 |
+
|
15 |
+
|
16 |
+
def extract_keypoints(heatmap, all_keypoints, total_keypoint_num):
|
17 |
+
heatmap[heatmap < 0.1] = 0
|
18 |
+
heatmap_with_borders = np.pad(heatmap, [(2, 2), (2, 2)], mode='constant')
|
19 |
+
heatmap_center = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 1:heatmap_with_borders.shape[1]-1]
|
20 |
+
heatmap_left = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 2:heatmap_with_borders.shape[1]]
|
21 |
+
heatmap_right = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 0:heatmap_with_borders.shape[1]-2]
|
22 |
+
heatmap_up = heatmap_with_borders[2:heatmap_with_borders.shape[0], 1:heatmap_with_borders.shape[1]-1]
|
23 |
+
heatmap_down = heatmap_with_borders[0:heatmap_with_borders.shape[0]-2, 1:heatmap_with_borders.shape[1]-1]
|
24 |
+
|
25 |
+
heatmap_peaks = (heatmap_center > heatmap_left) &\
|
26 |
+
(heatmap_center > heatmap_right) &\
|
27 |
+
(heatmap_center > heatmap_up) &\
|
28 |
+
(heatmap_center > heatmap_down)
|
29 |
+
heatmap_peaks = heatmap_peaks[1:heatmap_center.shape[0]-1, 1:heatmap_center.shape[1]-1]
|
30 |
+
keypoints = list(zip(np.nonzero(heatmap_peaks)[1], np.nonzero(heatmap_peaks)[0])) # (w, h)
|
31 |
+
keypoints = sorted(keypoints, key=itemgetter(0))
|
32 |
+
|
33 |
+
suppressed = np.zeros(len(keypoints), np.uint8)
|
34 |
+
keypoints_with_score_and_id = []
|
35 |
+
keypoint_num = 0
|
36 |
+
for i in range(len(keypoints)):
|
37 |
+
if suppressed[i]:
|
38 |
+
            continue
        for j in range(i+1, len(keypoints)):
            if math.sqrt((keypoints[i][0] - keypoints[j][0]) ** 2 +
                         (keypoints[i][1] - keypoints[j][1]) ** 2) < 6:
                suppressed[j] = 1
        keypoint_with_score_and_id = (keypoints[i][0], keypoints[i][1], heatmap[keypoints[i][1], keypoints[i][0]],
                                      total_keypoint_num + keypoint_num)
        keypoints_with_score_and_id.append(keypoint_with_score_and_id)
        keypoint_num += 1
    all_keypoints.append(keypoints_with_score_and_id)
    return keypoint_num


def group_keypoints(all_keypoints_by_type, pafs, pose_entry_size=20, min_paf_score=0.05, demo=False):
    pose_entries = []
    all_keypoints = np.array([item for sublist in all_keypoints_by_type for item in sublist])
    for part_id in range(len(BODY_PARTS_PAF_IDS)):
        part_pafs = pafs[:, :, BODY_PARTS_PAF_IDS[part_id]]
        kpts_a = all_keypoints_by_type[BODY_PARTS_KPT_IDS[part_id][0]]
        kpts_b = all_keypoints_by_type[BODY_PARTS_KPT_IDS[part_id][1]]
        num_kpts_a = len(kpts_a)
        num_kpts_b = len(kpts_b)
        kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
        kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]

        if num_kpts_a == 0 and num_kpts_b == 0:  # no keypoints for such body part
            continue
        elif num_kpts_a == 0:  # body part has just 'b' keypoints
            for i in range(num_kpts_b):
                num = 0
                for j in range(len(pose_entries)):  # check if already in some pose, was added by another body part
                    if pose_entries[j][kpt_b_id] == kpts_b[i][3]:
                        num += 1
                        continue
                if num == 0:
                    pose_entry = np.ones(pose_entry_size) * -1
                    pose_entry[kpt_b_id] = kpts_b[i][3]  # keypoint idx
                    pose_entry[-1] = 1                   # num keypoints in pose
                    pose_entry[-2] = kpts_b[i][2]        # pose score
                    pose_entries.append(pose_entry)
            continue
        elif num_kpts_b == 0:  # body part has just 'a' keypoints
            for i in range(num_kpts_a):
                num = 0
                for j in range(len(pose_entries)):
                    if pose_entries[j][kpt_a_id] == kpts_a[i][3]:
                        num += 1
                        continue
                if num == 0:
                    pose_entry = np.ones(pose_entry_size) * -1
                    pose_entry[kpt_a_id] = kpts_a[i][3]
                    pose_entry[-1] = 1
                    pose_entry[-2] = kpts_a[i][2]
                    pose_entries.append(pose_entry)
            continue

        connections = []
        for i in range(num_kpts_a):
            kpt_a = np.array(kpts_a[i][0:2])
            for j in range(num_kpts_b):
                kpt_b = np.array(kpts_b[j][0:2])
                mid_point = [(), ()]
                mid_point[0] = (int(round((kpt_a[0] + kpt_b[0]) * 0.5)),
                                int(round((kpt_a[1] + kpt_b[1]) * 0.5)))
                mid_point[1] = mid_point[0]

                vec = [kpt_b[0] - kpt_a[0], kpt_b[1] - kpt_a[1]]
                vec_norm = math.sqrt(vec[0] ** 2 + vec[1] ** 2)
                if vec_norm == 0:
                    continue
                vec[0] /= vec_norm
                vec[1] /= vec_norm
                cur_point_score = (vec[0] * part_pafs[mid_point[0][1], mid_point[0][0], 0] +
                                   vec[1] * part_pafs[mid_point[1][1], mid_point[1][0], 1])

                height_n = pafs.shape[0] // 2
                success_ratio = 0
                point_num = 10  # number of points to integrate over the paf
                if cur_point_score > -100:
                    passed_point_score = 0
                    passed_point_num = 0
                    x, y = linspace2d(kpt_a, kpt_b)
                    for point_idx in range(point_num):
                        if not demo:
                            px = int(round(x[point_idx]))
                            py = int(round(y[point_idx]))
                        else:
                            px = int(x[point_idx])
                            py = int(y[point_idx])
                        paf = part_pafs[py, px, 0:2]
                        cur_point_score = vec[0] * paf[0] + vec[1] * paf[1]
                        if cur_point_score > min_paf_score:
                            passed_point_score += cur_point_score
                            passed_point_num += 1
                    success_ratio = passed_point_num / point_num
                    ratio = 0
                    if passed_point_num > 0:
                        ratio = passed_point_score / passed_point_num
                    ratio += min(height_n / vec_norm - 1, 0)
                    if ratio > 0 and success_ratio > 0.8:
                        score_all = ratio + kpts_a[i][2] + kpts_b[j][2]
                        connections.append([i, j, ratio, score_all])
        if len(connections) > 0:
            connections = sorted(connections, key=itemgetter(2), reverse=True)

        num_connections = min(num_kpts_a, num_kpts_b)
        has_kpt_a = np.zeros(num_kpts_a, dtype=np.int32)
        has_kpt_b = np.zeros(num_kpts_b, dtype=np.int32)
        filtered_connections = []
        for row in range(len(connections)):
            if len(filtered_connections) == num_connections:
                break
            i, j, cur_point_score = connections[row][0:3]
            if not has_kpt_a[i] and not has_kpt_b[j]:
                filtered_connections.append([kpts_a[i][3], kpts_b[j][3], cur_point_score])
                has_kpt_a[i] = 1
                has_kpt_b[j] = 1
        connections = filtered_connections
        if len(connections) == 0:
            continue

        if part_id == 0:
            pose_entries = [np.ones(pose_entry_size) * -1 for _ in range(len(connections))]
            for i in range(len(connections)):
                pose_entries[i][BODY_PARTS_KPT_IDS[0][0]] = connections[i][0]
                pose_entries[i][BODY_PARTS_KPT_IDS[0][1]] = connections[i][1]
                pose_entries[i][-1] = 2
                pose_entries[i][-2] = np.sum(all_keypoints[connections[i][0:2], 2]) + connections[i][2]
        elif part_id == 17 or part_id == 18:
            kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
            kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
            for i in range(len(connections)):
                for j in range(len(pose_entries)):
                    if pose_entries[j][kpt_a_id] == connections[i][0] and pose_entries[j][kpt_b_id] == -1:
                        pose_entries[j][kpt_b_id] = connections[i][1]
                    elif pose_entries[j][kpt_b_id] == connections[i][1] and pose_entries[j][kpt_a_id] == -1:
                        pose_entries[j][kpt_a_id] = connections[i][0]
            continue
        else:
            kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
            kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
            for i in range(len(connections)):
                num = 0
                for j in range(len(pose_entries)):
                    if pose_entries[j][kpt_a_id] == connections[i][0]:
                        pose_entries[j][kpt_b_id] = connections[i][1]
                        num += 1
                        pose_entries[j][-1] += 1
                        pose_entries[j][-2] += all_keypoints[connections[i][1], 2] + connections[i][2]
                if num == 0:
                    pose_entry = np.ones(pose_entry_size) * -1
                    pose_entry[kpt_a_id] = connections[i][0]
                    pose_entry[kpt_b_id] = connections[i][1]
                    pose_entry[-1] = 2
                    pose_entry[-2] = np.sum(all_keypoints[connections[i][0:2], 2]) + connections[i][2]
                    pose_entries.append(pose_entry)

    filtered_entries = []
    for i in range(len(pose_entries)):
        if pose_entries[i][-1] < 3 or (pose_entries[i][-2] / pose_entries[i][-1] < 0.2):
            continue
        filtered_entries.append(pose_entries[i])
    pose_entries = np.asarray(filtered_entries)
    return pose_entries, all_keypoints
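For reference, a minimal sketch of how these two helpers are typically combined; it is not part of this commit. The array names `heatmaps` (H, W, 19) and `pafs` (H, W, 38) are assumptions for network outputs already resized to a common resolution.

# Illustrative sketch only, under the assumptions above.
total_keypoints_num = 0
all_keypoints_by_type = []
for kpt_idx in range(18):  # channel 19 is assumed to be the background heatmap
    total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx],
                                             all_keypoints_by_type, total_keypoints_num)

pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, pafs, demo=True)
# Each row of `pose_entries` indexes into `all_keypoints` (-1 marks a missing joint);
# the last two columns hold the pose score and the number of found keypoints.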
lite_openpose/modules/load_state.py
ADDED
@@ -0,0 +1,32 @@
import collections


def load_state(net, checkpoint):
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        if target_key in source_state and source_state[target_key].size() == target_state[target_key].size():
            new_target_state[target_key] = source_state[target_key]
        else:
            new_target_state[target_key] = target_state[target_key]
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    net.load_state_dict(new_target_state)


def load_from_mobilenet(net, checkpoint):
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        k = target_key
        if k.find('model') != -1:
            k = k.replace('model', 'module.model')
        if k in source_state and source_state[k].size() == target_state[target_key].size():
            new_target_state[target_key] = source_state[k]
        else:
            new_target_state[target_key] = target_state[target_key]
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    net.load_state_dict(new_target_state)
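A minimal usage sketch, not part of the commit: `load_state` copies matching tensors from any checkpoint that stores its weights under a 'state_dict' key. The checkpoint filename and import paths below are placeholders and may differ from this repo's layout.

# Illustrative sketch only; adjust imports and paths to the actual package layout.
import torch
from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet  # assumed import path
from modules.load_state import load_state

net = PoseEstimationWithMobileNet()
checkpoint = torch.load('checkpoint_iter_370000.pth', map_location='cpu')  # placeholder path
load_state(net, checkpoint)  # layers without matching weights keep their random init and print a warning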
lite_openpose/modules/loss.py
ADDED
@@ -0,0 +1,5 @@
def l2_loss(input, target, mask, batch_size):
    loss = (input - target) * mask
    loss = (loss * loss) / 2 / batch_size

    return loss.sum()
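A small hedged example of the masked L2 loss above; the tensor shapes are assumptions, not values used by this repo.

# Illustrative sketch only, assuming (batch, channels, H, W) heatmap tensors.
import torch

pred = torch.rand(8, 19, 32, 32)     # predicted heatmaps
target = torch.rand(8, 19, 32, 32)   # ground-truth heatmaps
mask = torch.ones_like(target)       # 1 where the loss should be counted, 0 elsewhere
loss = l2_loss(pred, target, mask, batch_size=8)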
lite_openpose/modules/one_euro_filter.py
ADDED
@@ -0,0 +1,51 @@
import math


def get_alpha(rate=30, cutoff=1):
    tau = 1 / (2 * math.pi * cutoff)
    te = 1 / rate
    return 1 / (1 + tau / te)


class LowPassFilter:
    def __init__(self):
        self.x_previous = None

    def __call__(self, x, alpha=0.5):
        if self.x_previous is None:
            self.x_previous = x
            return x
        x_filtered = alpha * x + (1 - alpha) * self.x_previous
        self.x_previous = x_filtered
        return x_filtered


class OneEuroFilter:
    def __init__(self, freq=15, mincutoff=1, beta=0.05, dcutoff=1):
        self.freq = freq
        self.mincutoff = mincutoff
        self.beta = beta
        self.dcutoff = dcutoff
        self.filter_x = LowPassFilter()
        self.filter_dx = LowPassFilter()
        self.x_previous = None
        self.dx = None

    def __call__(self, x):
        if self.dx is None:
            self.dx = 0
        else:
            self.dx = (x - self.x_previous) * self.freq
        dx_smoothed = self.filter_dx(self.dx, get_alpha(self.freq, self.dcutoff))
        cutoff = self.mincutoff + self.beta * abs(dx_smoothed)
        x_filtered = self.filter_x(x, get_alpha(self.freq, cutoff))
        self.x_previous = x
        return x_filtered


if __name__ == '__main__':
    filter = OneEuroFilter(freq=15, beta=0.1)
    for val in range(10):
        x = val + (-1)**(val % 2)
        x_filtered = filter(x)
        print(x_filtered, x)
lite_openpose/modules/pose.py
ADDED
@@ -0,0 +1,118 @@
import cv2
import numpy as np

from modules.keypoints import BODY_PARTS_KPT_IDS, BODY_PARTS_PAF_IDS
from modules.one_euro_filter import OneEuroFilter


class Pose:
    num_kpts = 18
    kpt_names = ['nose', 'neck',
                 'r_sho', 'r_elb', 'r_wri', 'l_sho', 'l_elb', 'l_wri',
                 'r_hip', 'r_knee', 'r_ank', 'l_hip', 'l_knee', 'l_ank',
                 'r_eye', 'l_eye',
                 'r_ear', 'l_ear']
    sigmas = np.array([.26, .79, .79, .72, .62, .79, .72, .62, 1.07, .87, .89, 1.07, .87, .89, .25, .25, .35, .35],
                      dtype=np.float32) / 10.0
    vars = (sigmas * 2) ** 2
    last_id = -1
    color = [0, 224, 255]

    def __init__(self, keypoints, confidence):
        super().__init__()
        self.keypoints = keypoints
        self.confidence = confidence
        self.bbox = Pose.get_bbox(self.keypoints)
        self.id = None
        self.filters = [[OneEuroFilter(), OneEuroFilter()] for _ in range(Pose.num_kpts)]

    @staticmethod
    def get_bbox(keypoints):
        found_keypoints = np.zeros((np.count_nonzero(keypoints[:, 0] != -1), 2), dtype=np.int32)
        found_kpt_id = 0
        for kpt_id in range(Pose.num_kpts):
            if keypoints[kpt_id, 0] == -1:
                continue
            found_keypoints[found_kpt_id] = keypoints[kpt_id]
            found_kpt_id += 1
        bbox = cv2.boundingRect(found_keypoints)
        return bbox

    def update_id(self, id=None):
        self.id = id
        if self.id is None:
            self.id = Pose.last_id + 1
            Pose.last_id += 1

    def draw(self, img):
        assert self.keypoints.shape == (Pose.num_kpts, 2)

        for part_id in range(len(BODY_PARTS_PAF_IDS) - 2):
            kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
            global_kpt_a_id = self.keypoints[kpt_a_id, 0]
            if global_kpt_a_id != -1:
                x_a, y_a = self.keypoints[kpt_a_id]
                cv2.circle(img, (int(x_a), int(y_a)), 3, Pose.color, -1)
            kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
            global_kpt_b_id = self.keypoints[kpt_b_id, 0]
            if global_kpt_b_id != -1:
                x_b, y_b = self.keypoints[kpt_b_id]
                cv2.circle(img, (int(x_b), int(y_b)), 3, Pose.color, -1)
            if global_kpt_a_id != -1 and global_kpt_b_id != -1:
                cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), Pose.color, 2)


def get_similarity(a, b, threshold=0.5):
    num_similar_kpt = 0
    for kpt_id in range(Pose.num_kpts):
        if a.keypoints[kpt_id, 0] != -1 and b.keypoints[kpt_id, 0] != -1:
            distance = np.sum((a.keypoints[kpt_id] - b.keypoints[kpt_id]) ** 2)
            area = max(a.bbox[2] * a.bbox[3], b.bbox[2] * b.bbox[3])
            similarity = np.exp(-distance / (2 * (area + np.spacing(1)) * Pose.vars[kpt_id]))
            if similarity > threshold:
                num_similar_kpt += 1
    return num_similar_kpt


def track_poses(previous_poses, current_poses, threshold=3, smooth=False):
    """Propagate pose ids from the previous frame's results. An id is propagated
    if there are at least `threshold` similar keypoints between a pose from the previous frame and the current one.
    If a correspondence between poses on the previous and current frame is established, the pose keypoints are smoothed.

    :param previous_poses: poses from previous frame with ids
    :param current_poses: poses from current frame to assign ids
    :param threshold: minimal number of similar keypoints between poses
    :param smooth: smooth pose keypoints between frames
    :return: None
    """
    current_poses = sorted(current_poses, key=lambda pose: pose.confidence, reverse=True)  # match confident poses first
    mask = np.ones(len(previous_poses), dtype=np.int32)
    for current_pose in current_poses:
        best_matched_id = None
        best_matched_pose_id = None
        best_matched_iou = 0
        for id, previous_pose in enumerate(previous_poses):
            if not mask[id]:
                continue
            iou = get_similarity(current_pose, previous_pose)
            if iou > best_matched_iou:
                best_matched_iou = iou
                best_matched_pose_id = previous_pose.id
                best_matched_id = id
        if best_matched_iou >= threshold:
            mask[best_matched_id] = 0
        else:  # pose not similar to any previous
            best_matched_pose_id = None
        current_pose.update_id(best_matched_pose_id)

        if smooth:
            for kpt_id in range(Pose.num_kpts):
                if current_pose.keypoints[kpt_id, 0] == -1:
                    continue
                # reuse filter if previous pose has valid filter
                if (best_matched_pose_id is not None
                        and previous_poses[best_matched_id].keypoints[kpt_id, 0] != -1):
                    current_pose.filters[kpt_id] = previous_poses[best_matched_id].filters[kpt_id]
                current_pose.keypoints[kpt_id, 0] = current_pose.filters[kpt_id][0](current_pose.keypoints[kpt_id, 0])
                current_pose.keypoints[kpt_id, 1] = current_pose.filters[kpt_id][1](current_pose.keypoints[kpt_id, 1])
            current_pose.bbox = Pose.get_bbox(current_pose.keypoints)
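A brief hedged sketch of id tracking across two frames, not part of the commit. The keypoint arrays `keypoints_prev` and `keypoints_cur` are assumed to be (18, 2) int32 arrays with -1 marking missing joints, as expected by `Pose.get_bbox`.

# Illustrative sketch only, under the assumptions above.
previous_poses = [Pose(keypoints_prev, confidence=10.0)]
for pose in previous_poses:
    pose.update_id()  # assign initial ids

current_poses = [Pose(keypoints_cur, confidence=9.0)]
track_poses(previous_poses, current_poses, smooth=True)
# ids are propagated to `current_poses`, and keypoints are smoothed in place by the per-joint OneEuroFilters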
lite_openpose/pose2d_models/__init__.py
ADDED
File without changes
lite_openpose/pose2d_models/with_mobilenet.py
ADDED
@@ -0,0 +1,123 @@
import torch
from torch import nn

from modules.conv import conv, conv_dw, conv_dw_no_bn


class Cpm(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.align = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        self.trunk = nn.Sequential(
            conv_dw_no_bn(out_channels, out_channels),
            conv_dw_no_bn(out_channels, out_channels),
            conv_dw_no_bn(out_channels, out_channels)
        )
        self.conv = conv(out_channels, out_channels, bn=False)

    def forward(self, x):
        x = self.align(x)
        x = self.conv(x + self.trunk(x))
        return x


class InitialStage(nn.Module):
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        self.trunk = nn.Sequential(
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False)
        )
        self.heatmaps = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
        )
        self.pafs = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
        )

    def forward(self, x):
        trunk_features = self.trunk(x)
        heatmaps = self.heatmaps(trunk_features)
        pafs = self.pafs(trunk_features)
        return [heatmaps, pafs]


class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        self.trunk = nn.Sequential(
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2)
        )

    def forward(self, x):
        initial_features = self.initial(x)
        trunk_features = self.trunk(initial_features)
        return initial_features + trunk_features


class RefinementStage(nn.Module):
    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        super().__init__()
        self.trunk = nn.Sequential(
            RefinementStageBlock(in_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels)
        )
        self.heatmaps = nn.Sequential(
            conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
            conv(out_channels, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
        )
        self.pafs = nn.Sequential(
            conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
            conv(out_channels, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
        )

    def forward(self, x):
        trunk_features = self.trunk(x)
        heatmaps = self.heatmaps(trunk_features)
        pafs = self.pafs(trunk_features)
        return [heatmaps, pafs]


class PoseEstimationWithMobileNet(nn.Module):
    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        super().__init__()
        self.model = nn.Sequential(
            conv(      3,  32, stride=2, bias=False),
            conv_dw(  32,  64),
            conv_dw(  64, 128, stride=2),
            conv_dw( 128, 128),
            conv_dw( 128, 256, stride=2),
            conv_dw( 256, 256),
            conv_dw( 256, 512),  # conv4_2
            conv_dw( 512, 512, dilation=2, padding=2),
            conv_dw( 512, 512),
            conv_dw( 512, 512),
            conv_dw( 512, 512),
            conv_dw( 512, 512)   # conv5_5
        )
        self.cpm = Cpm(512, num_channels)

        self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
        self.refinement_stages = nn.ModuleList()
        for idx in range(num_refinement_stages):
            self.refinement_stages.append(RefinementStage(num_channels + num_heatmaps + num_pafs, num_channels,
                                                          num_heatmaps, num_pafs))

    def forward(self, x):
        backbone_features = self.model(x)
        backbone_features = self.cpm(backbone_features)

        stages_output = self.initial_stage(backbone_features)
        for refinement_stage in self.refinement_stages:
            stages_output.extend(
                refinement_stage(torch.cat([backbone_features, stages_output[-2], stages_output[-1]], dim=1)))

        return stages_output
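A hedged sketch of a forward pass through this network, not part of the commit. The input resolution is an assumption; the model is fully convolutional, and the three stride-2 layers give an overall output stride of 8.

# Illustrative sketch only, under the assumptions above.
import torch

net = PoseEstimationWithMobileNet()
net.eval()
with torch.no_grad():
    stages_output = net(torch.randn(1, 3, 256, 456))
heatmaps, pafs = stages_output[-2], stages_output[-1]  # outputs of the last refinement stage
print(heatmaps.shape, pafs.shape)  # expected roughly (1, 19, 32, 57) and (1, 38, 32, 57)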