Spaces:
Sleeping
Sleeping
upload_part1
Browse files- render_utils/__pycache__/stitch_body_and_head.cpython-310.pyc +0 -0
- render_utils/__pycache__/stitch_funcs.cpython-310.pyc +0 -0
- render_utils/calc_smplx2faceverse.py +238 -0
- render_utils/camera_dir.py +171 -0
- render_utils/lib/networks/__init__.py +0 -0
- render_utils/lib/networks/__pycache__/__init__.cpython-310.pyc +0 -0
- render_utils/lib/networks/__pycache__/__init__.cpython-38.pyc +0 -0
- render_utils/lib/networks/__pycache__/faceverse_torch.cpython-310.pyc +0 -0
- render_utils/lib/networks/__pycache__/faceverse_torch.cpython-38.pyc +0 -0
- render_utils/lib/networks/__pycache__/smpl_torch.cpython-310.pyc +0 -0
- render_utils/lib/networks/__pycache__/smpl_torch.cpython-38.pyc +0 -0
- render_utils/lib/networks/faceverse_torch.py +292 -0
- render_utils/lib/networks/smpl_torch.py +341 -0
- render_utils/lib/utils/__init__.py +0 -0
- render_utils/lib/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- render_utils/lib/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-310.pyc +0 -0
- render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-38.pyc +0 -0
- render_utils/lib/utils/__pycache__/geometry.cpython-310.pyc +0 -0
- render_utils/lib/utils/__pycache__/geometry.cpython-38.pyc +0 -0
- render_utils/lib/utils/__pycache__/graphics_utils.cpython-310.pyc +0 -0
- render_utils/lib/utils/__pycache__/graphics_utils.cpython-38.pyc +0 -0
- render_utils/lib/utils/__pycache__/rotation_conversions.cpython-310.pyc +0 -0
- render_utils/lib/utils/__pycache__/rotation_conversions.cpython-38.pyc +0 -0
- render_utils/lib/utils/__pycache__/sh_utils.cpython-310.pyc +0 -0
- render_utils/lib/utils/__pycache__/sh_utils.cpython-38.pyc +0 -0
- render_utils/lib/utils/gaussian_np_utils.py +162 -0
- render_utils/lib/utils/geometry.py +517 -0
- render_utils/lib/utils/graphics_utils.py +181 -0
- render_utils/lib/utils/rotation_conversions.py +586 -0
- render_utils/lib/utils/sh_utils.py +118 -0
- render_utils/stitch_body_and_head.py +433 -0
- render_utils/stitch_funcs.py +145 -0
render_utils/__pycache__/stitch_body_and_head.cpython-310.pyc
ADDED
Binary file (12.9 kB). View file
|
|
render_utils/__pycache__/stitch_funcs.cpython-310.pyc
ADDED
Binary file (3.73 kB). View file
|
|
render_utils/calc_smplx2faceverse.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
import time, random
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import smplx
|
9 |
+
import pickle
|
10 |
+
import trimesh
|
11 |
+
import os, glob
|
12 |
+
import argparse
|
13 |
+
|
14 |
+
from .lib.networks.faceverse_torch import FaceVerseModel
|
15 |
+
|
16 |
+
|
17 |
+
def estimate_rigid_transformation(src, tgt):
|
18 |
+
src = src.transpose()
|
19 |
+
tgt = tgt.transpose()
|
20 |
+
mu1, mu2 = src.mean(axis=1, keepdims=True), tgt.mean(axis=1, keepdims=True)
|
21 |
+
X1, X2 = src - mu1, tgt - mu2
|
22 |
+
|
23 |
+
K = X1.dot(X2.T)
|
24 |
+
U, s, Vh = np.linalg.svd(K)
|
25 |
+
V = Vh.T
|
26 |
+
Z = np.eye(U.shape[0])
|
27 |
+
Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
|
28 |
+
R = V.dot(Z.dot(U.T))
|
29 |
+
t = mu2 - R.dot(mu1)
|
30 |
+
|
31 |
+
orient, _ = cv2.Rodrigues(R)
|
32 |
+
orient = orient.reshape([-1])
|
33 |
+
t = t.reshape([-1])
|
34 |
+
return orient, t
|
35 |
+
|
36 |
+
|
37 |
+
def load_smpl_beta(body_fitting_result_fpath):
|
38 |
+
if os.path.isdir(body_fitting_result_fpath):
|
39 |
+
body_fitting_result = torch.load(
|
40 |
+
os.path.join(body_fitting_result_fpath, 'checkpoints/latest.pt'), map_location='cpu')
|
41 |
+
betas = body_fitting_result['betas']['weight']
|
42 |
+
elif body_fitting_result_fpath.endswith('.pt'):
|
43 |
+
body_fitting_result = torch.load(body_fitting_result_fpath, map_location='cpu')
|
44 |
+
betas = body_fitting_result['beta'].reshape(1, -1)
|
45 |
+
elif body_fitting_result_fpath.endswith('.npz'):
|
46 |
+
body_fitting_result = np.load(body_fitting_result_fpath)
|
47 |
+
betas = body_fitting_result['betas'].reshape(1, -1)
|
48 |
+
betas = torch.from_numpy(betas.astype(np.float32))
|
49 |
+
else:
|
50 |
+
raise ValueError('Unknown body fitting result file format: {}'.format(body_fitting_result_fpath))
|
51 |
+
return betas
|
52 |
+
|
53 |
+
|
54 |
+
def load_face_id_scale_param(face_fitting_result_fpath):
|
55 |
+
if os.path.isfile(face_fitting_result_fpath):
|
56 |
+
face_fitting_result = dict(np.load(face_fitting_result_fpath))
|
57 |
+
id_tensor = face_fitting_result['id_coeff'].astype(np.float32)
|
58 |
+
scale_tensor = face_fitting_result['scale'].astype(np.float32)
|
59 |
+
id_tensor = torch.from_numpy(id_tensor).reshape(1, -1)
|
60 |
+
scale_tensor = torch.from_numpy(scale_tensor).reshape(1, -1)
|
61 |
+
else:
|
62 |
+
param_paths = sorted(glob.glob(os.path.join(face_fitting_result_fpath, '*', 'params.npz')))
|
63 |
+
param = np.load(param_paths[0])
|
64 |
+
id_tensor = torch.from_numpy(param['id_coeff']).reshape(1, -1)
|
65 |
+
scale_tensor = torch.from_numpy(param['scale']).reshape(1, 1)
|
66 |
+
|
67 |
+
return id_tensor, scale_tensor
|
68 |
+
|
69 |
+
|
70 |
+
def calc_smplx2faceverse(body_fitting_result_fpath, face_fitting_result_fpath, output_dir):
|
71 |
+
device = torch.device('cuda')
|
72 |
+
os.makedirs(output_dir, exist_ok=True)
|
73 |
+
|
74 |
+
betas = load_smpl_beta(body_fitting_result_fpath)
|
75 |
+
id_tensor, scale_tensor = load_face_id_scale_param(face_fitting_result_fpath)
|
76 |
+
|
77 |
+
smpl = smplx.SMPLX(model_path='./AnimatableGaussians/smpl_files/smplx', gender='neutral',
|
78 |
+
use_pca=True, num_pca_comps=45, flat_hand_mean=True, batch_size=1)
|
79 |
+
flame = smplx.FLAME(model_path='./AnimatableGaussians/smpl_files/FLAME2019', gender='neutral', batch_size=1)
|
80 |
+
|
81 |
+
pose = np.zeros([63], dtype=np.float32)
|
82 |
+
|
83 |
+
pose = torch.from_numpy(pose).unsqueeze(0)
|
84 |
+
smpl_out = smpl(body_pose=pose, betas=betas)
|
85 |
+
verts = smpl_out.vertices.detach().cpu().squeeze(0).numpy()
|
86 |
+
flame_out = flame()
|
87 |
+
verts_flame = flame_out.vertices.detach().cpu().squeeze(0).numpy()
|
88 |
+
|
89 |
+
smplx2flame_data = np.load('./data/smpl_models/smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy')
|
90 |
+
verts_flame_on_smplx = verts[smplx2flame_data]
|
91 |
+
|
92 |
+
orient, t = estimate_rigid_transformation(verts_flame, verts_flame_on_smplx)
|
93 |
+
R, _ = cv2.Rodrigues(orient)
|
94 |
+
|
95 |
+
rel_transf = np.eye(4)
|
96 |
+
rel_transf[:3, :3] = R
|
97 |
+
rel_transf[:3, 3] = t.reshape(-1)
|
98 |
+
np.save('%s/flame_to_smplx.npy' % (output_dir), rel_transf.astype(np.float32))
|
99 |
+
|
100 |
+
# TODO: DELETE ME
|
101 |
+
with open('./debug/debug_verts_smplx.obj', 'w') as fp:
|
102 |
+
for v in verts:
|
103 |
+
fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
|
104 |
+
with open('./debug/debug_verts_flame_in_smplx.obj', 'w') as fp:
|
105 |
+
for v in np.matmul(verts_flame, R.transpose()) + t.reshape([1, 3]):
|
106 |
+
fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
|
107 |
+
|
108 |
+
# align Faceverse to T-pose FLAME on SMPL-X
|
109 |
+
faceverse_mesh = trimesh.load_mesh('./data/smpl_models/faceverse2flame/faceverse_icp.obj', process=False)
|
110 |
+
verts_faceverse_ref = np.matmul(np.asarray(faceverse_mesh.vertices), R.transpose()) + t.reshape(1, 3)
|
111 |
+
|
112 |
+
# TODO: DELETE ME
|
113 |
+
with open('./debug/debug_verts_faceverse_ref.obj', 'w') as fp:
|
114 |
+
for v in verts_faceverse_ref:
|
115 |
+
fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
|
116 |
+
|
117 |
+
model_dict = np.load('./data/faceverse_models/faceverse_simple_v2.npy', allow_pickle=True).item()
|
118 |
+
faceverse_model = FaceVerseModel(model_dict, batch_size=1)
|
119 |
+
faceverse_model.init_coeff_tensors()
|
120 |
+
coeffs = faceverse_model.get_packed_tensors()
|
121 |
+
fv_out = faceverse_model.forward(coeffs=coeffs)
|
122 |
+
verts_faceverse = fv_out['v'].squeeze(0).detach().cpu().numpy()
|
123 |
+
|
124 |
+
# calculate the relative transformation between FLAME in canonical pose and its position on SMPL-X
|
125 |
+
orient, t = estimate_rigid_transformation(verts_faceverse, verts_faceverse_ref)
|
126 |
+
orient = torch.from_numpy(orient.astype(np.float32)).unsqueeze(0).to(device)
|
127 |
+
t = torch.from_numpy(t.astype(np.float32)).unsqueeze(0).to(device)
|
128 |
+
|
129 |
+
# optimize the Faceverse to fit SMPL-X
|
130 |
+
faceverse_model.init_coeff_tensors(
|
131 |
+
rot_coeff=orient, trans_coeff=t, id_coeff=id_tensor.to(device), scale_coeff=scale_tensor.to(device))
|
132 |
+
nonrigid_optim_params = [
|
133 |
+
faceverse_model.get_exp_tensor(), faceverse_model.get_rot_tensor(), faceverse_model.get_trans_tensor(),
|
134 |
+
# faceverse_model.get_scale_tensor(), faceverse_model.get_id_tensor()
|
135 |
+
]
|
136 |
+
nonrigid_optimizer = torch.optim.Adam(nonrigid_optim_params, lr=1e-1)
|
137 |
+
verts_faceverse_ref = torch.from_numpy(verts_faceverse_ref.astype(np.float32)).to(device).unsqueeze(0)
|
138 |
+
for iter in range(1000):
|
139 |
+
coeffs = faceverse_model.get_packed_tensors()
|
140 |
+
fv_out = faceverse_model.forward(coeffs=coeffs)
|
141 |
+
verts_pred = fv_out['v']
|
142 |
+
loss = torch.mean(torch.square(verts_pred - verts_faceverse_ref))
|
143 |
+
if iter % 10 == 0:
|
144 |
+
print(loss.item())
|
145 |
+
nonrigid_optimizer.zero_grad()
|
146 |
+
loss.backward()
|
147 |
+
nonrigid_optimizer.step()
|
148 |
+
|
149 |
+
np.savez('%s/faceverse_param_to_smplx.npz' % (output_dir), {
|
150 |
+
'id': faceverse_model.get_id_tensor().detach().cpu().numpy(),
|
151 |
+
'exp': faceverse_model.get_exp_tensor().detach().cpu().numpy(),
|
152 |
+
'rot': faceverse_model.get_rot_tensor().detach().cpu().numpy(),
|
153 |
+
'transl': faceverse_model.get_trans_tensor().detach().cpu().numpy(),
|
154 |
+
'scale': faceverse_model.get_scale_tensor().detach().cpu().numpy(),
|
155 |
+
})
|
156 |
+
|
157 |
+
# calculate SMPLX to faceverse space transformation (without scale)
|
158 |
+
orient = faceverse_model.get_rot_tensor().detach().cpu().numpy()
|
159 |
+
transl = faceverse_model.get_trans_tensor().detach().cpu().numpy()
|
160 |
+
rotmat, _ = cv2.Rodrigues(orient)
|
161 |
+
transform_mat = np.eye(4)
|
162 |
+
transform_mat[:3, :3] = rotmat
|
163 |
+
transform_mat[:3, 3] = transl
|
164 |
+
transform_mat = np.linalg.inv(transform_mat)
|
165 |
+
np.save('%s/smplx_to_faceverse_space.npy' % (output_dir), transform_mat.astype(np.float32))
|
166 |
+
|
167 |
+
# calculate SMPLX to faceverse distance
|
168 |
+
dists = []
|
169 |
+
verts_faceverse_ref = verts_faceverse_ref.detach().cpu().numpy()
|
170 |
+
for v in verts:
|
171 |
+
dist = np.linalg.norm(v.reshape(1, 3) - verts_faceverse_ref, axis=-1)
|
172 |
+
dist = np.min(dist)
|
173 |
+
dists.append(dist)
|
174 |
+
dists = np.asarray(dists)
|
175 |
+
np.save('%s/smplx_verts_to_faceverse_dist.npy' % (output_dir), dists.astype(np.float32))
|
176 |
+
|
177 |
+
# sample nodes on facial area
|
178 |
+
dists_ = np.ones_like(dists)
|
179 |
+
smplx_topo_new = np.load('./data/smpl_models/smplx_topo_new.npy')
|
180 |
+
valid_vids = set(smplx_topo_new.reshape([-1]).tolist())
|
181 |
+
dists_[np.asarray(list(valid_vids))] = dists[np.asarray(list(valid_vids))]
|
182 |
+
|
183 |
+
vids_on_face = np.where(dists_ < 0.01)[0]
|
184 |
+
verts_on_face = verts[vids_on_face]
|
185 |
+
geod_dist_mat = np.linalg.norm(np.expand_dims(verts_on_face, axis=0) - np.expand_dims(verts_on_face, axis=1),
|
186 |
+
axis=2)
|
187 |
+
nodes = [0] # nose
|
188 |
+
dist_nodes_to_rest_points = geod_dist_mat[nodes[0]]
|
189 |
+
for i in range(1, 256):
|
190 |
+
idx = np.argmax(dist_nodes_to_rest_points)
|
191 |
+
nodes.append(idx)
|
192 |
+
new_dist = geod_dist_mat[idx]
|
193 |
+
update_flag = new_dist < dist_nodes_to_rest_points
|
194 |
+
dist_nodes_to_rest_points[update_flag] = new_dist[update_flag]
|
195 |
+
|
196 |
+
# with open('./debug/debug_face_nodes.obj', 'w') as fp:
|
197 |
+
# for n in verts_on_face[np.asarray(nodes)]:
|
198 |
+
# fp.write('v %f %f %f\n' % (n[0], n[1], n[2]))
|
199 |
+
vids_on_faces_sampled = vids_on_face[np.asarray(nodes)]
|
200 |
+
vids_on_faces_sampled = np.ascontiguousarray(vids_on_faces_sampled).astype(np.int32)
|
201 |
+
np.save('%s/vids_on_faces_sampled.npy' % (output_dir), vids_on_faces_sampled)
|
202 |
+
|
203 |
+
# test SMPLX-to-faceverse space transformation (without scale)
|
204 |
+
verts_smpl_in_faceverse = np.matmul(verts, transform_mat[:3, :3].transpose()) + \
|
205 |
+
transform_mat[:3, 3].reshape(1, 3)
|
206 |
+
with open('./debug/debug_verts_smpl_in_faceverse.obj', 'w') as fp:
|
207 |
+
for v in verts_smpl_in_faceverse:
|
208 |
+
fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
|
209 |
+
|
210 |
+
# save personalized, canonical faceverse model
|
211 |
+
faceverse_model.init_coeff_tensors(id_coeff=id_tensor.to(device), scale_coeff=scale_tensor.to(device))
|
212 |
+
coeffs = faceverse_model.get_packed_tensors()
|
213 |
+
fv_out = faceverse_model.forward(coeffs=coeffs)
|
214 |
+
verts_faceverse = fv_out['v'].squeeze(0).detach().cpu().numpy()
|
215 |
+
with open('./debug/debug_verts_faceverse.obj', 'w') as fp:
|
216 |
+
for v in verts_faceverse:
|
217 |
+
fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
|
218 |
+
|
219 |
+
|
220 |
+
if __name__ == '__main__':
|
221 |
+
# body_fitting_result_dir = 'D:/UpperBodyAvatar/code/smplfitting_multiview_sequence_smplx/results/Shuangqing/zzr_fullbody_20221130_01_2k/whole.pt'
|
222 |
+
# face_fitting_result_dir = 'D:\\UpperBodyAvatar\\data\\Shuangqing\\zzr_face_20221130_01_2k\\faceverse_params'
|
223 |
+
#
|
224 |
+
# output_dir = './data/faceverse/'
|
225 |
+
# result_suffix = 'shuangqing_zzr'
|
226 |
+
|
227 |
+
# body_fitting_result_dir = 'D:/Product/FullAppRelease/smplfitting_multiview_sequence_smplx/results/body_data_model_20231224/whole.pt'
|
228 |
+
# face_fitting_result_dir = 'D:/Product/data/HuiyingCenter/20231224_model/model_20231224_face_data/faceverse_params'
|
229 |
+
#
|
230 |
+
# output_dir = './data/faceverse/'
|
231 |
+
# result_suffix = 'huiyin_model20231224'
|
232 |
+
|
233 |
+
body_fitting_result_dir = 'D:/Product/FullAppRelease/smplfitting_multiview_sequence_smplx/results/body_data_male20230530_betterhand/whole.pt'
|
234 |
+
face_fitting_result_dir = 'D:/Product/data/HuiyingCenter/20230531_models/male20230530_face_data/faceverse_params'
|
235 |
+
|
236 |
+
output_dir = './data/body_face_stitching/huiyin_male20230530'
|
237 |
+
|
238 |
+
calc_smplx2faceverse(body_fitting_result_dir, face_fitting_result_dir, output_dir)
|
render_utils/camera_dir.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def get_camera_dir(idx):
|
5 |
+
img_scale = self.body['test'].get('img_scale', 1.0)
|
6 |
+
view_setting = self.body['test'].get('view_setting', 'free')
|
7 |
+
if view_setting == 'camera':
|
8 |
+
# training view setting
|
9 |
+
cam_id = self.body['test']['render_view_idx']
|
10 |
+
intr = self.dataset.intr_mats[cam_id].copy()
|
11 |
+
intr[:2] *= img_scale
|
12 |
+
extr = self.dataset.extr_mats[cam_id].copy()
|
13 |
+
img_h, img_w = int(self.dataset.img_heights[cam_id] * img_scale), int(self.dataset.img_widths[cam_id] * img_scale)
|
14 |
+
elif view_setting.startswith('free'):
|
15 |
+
# free view setting
|
16 |
+
# frame_num_per_circle = 360
|
17 |
+
# print(self.opt['test'].get('global_orient', False))
|
18 |
+
frame_num_per_circle = 360
|
19 |
+
rot_Y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
|
20 |
+
|
21 |
+
extr = visualize_util.calc_free_mv(object_center,
|
22 |
+
tar_pos = np.array([0, 0, 2.5]),
|
23 |
+
rot_Y = rot_Y,
|
24 |
+
rot_X = 0.3 if view_setting.endswith('bird') else 0.,
|
25 |
+
global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
|
26 |
+
intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
|
27 |
+
intr[:2] *= img_scale
|
28 |
+
img_h = int(1024 * img_scale)
|
29 |
+
img_w = int(1024 * img_scale)
|
30 |
+
|
31 |
+
extr_list.append(extr)
|
32 |
+
intr_list.append(intr)
|
33 |
+
img_h_list.append(img_h)
|
34 |
+
img_w_list.append(img_w)
|
35 |
+
|
36 |
+
elif view_setting.startswith('degree120'):
|
37 |
+
print('we render 120 degree')
|
38 |
+
# +- 60 degree
|
39 |
+
frame_per_cycle = 480
|
40 |
+
max_degree = 60
|
41 |
+
frame_half_cycle = frame_per_cycle // 2
|
42 |
+
if idx%frame_per_cycle < frame_per_cycle/2:
|
43 |
+
rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
|
44 |
+
# rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi
|
45 |
+
else:
|
46 |
+
rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
|
47 |
+
|
48 |
+
# to radian
|
49 |
+
rot_Y = rot_Y * np.pi / 180
|
50 |
+
if rot_Y<0:
|
51 |
+
rot_Y = rot_Y + 2 * np.pi
|
52 |
+
# print('rot_Y: ', rot_Y)
|
53 |
+
extr = visualize_util.calc_free_mv(object_center,
|
54 |
+
tar_pos = np.array([0, 0, 2.5]),
|
55 |
+
rot_Y = rot_Y,
|
56 |
+
rot_X = 0.3 if view_setting.endswith('bird') else 0.,
|
57 |
+
global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
|
58 |
+
intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
|
59 |
+
intr[:2] *= img_scale
|
60 |
+
img_h = int(1024 * img_scale)
|
61 |
+
img_w = int(1024 * img_scale)
|
62 |
+
|
63 |
+
extr_list.append(extr)
|
64 |
+
intr_list.append(intr)
|
65 |
+
img_h_list.append(img_h)
|
66 |
+
img_w_list.append(img_w)
|
67 |
+
|
68 |
+
elif view_setting.startswith('degree90'):
|
69 |
+
print('we render 90 degree')
|
70 |
+
# +- 60 degree
|
71 |
+
frame_per_cycle = 360
|
72 |
+
max_degree = 45
|
73 |
+
frame_half_cycle = frame_per_cycle // 2
|
74 |
+
if idx%frame_per_cycle < frame_per_cycle/2:
|
75 |
+
rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
|
76 |
+
# rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi
|
77 |
+
else:
|
78 |
+
rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
|
79 |
+
|
80 |
+
# to radian
|
81 |
+
rot_Y = rot_Y * np.pi / 180
|
82 |
+
if rot_Y<0:
|
83 |
+
rot_Y = rot_Y + 2 * np.pi
|
84 |
+
# print('rot_Y: ', rot_Y)
|
85 |
+
extr = visualize_util.calc_free_mv(object_center,
|
86 |
+
tar_pos = np.array([0, 0, 2.5]),
|
87 |
+
rot_Y = rot_Y,
|
88 |
+
rot_X = 0.3 if view_setting.endswith('bird') else 0.,
|
89 |
+
global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
|
90 |
+
intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
|
91 |
+
intr[:2] *= img_scale
|
92 |
+
img_h = int(1024 * img_scale)
|
93 |
+
img_w = int(1024 * img_scale)
|
94 |
+
|
95 |
+
extr_list.append(extr)
|
96 |
+
intr_list.append(intr)
|
97 |
+
img_h_list.append(img_h)
|
98 |
+
img_w_list.append(img_w)
|
99 |
+
|
100 |
+
|
101 |
+
elif view_setting.startswith('front'):
|
102 |
+
# front view setting
|
103 |
+
extr = visualize_util.calc_free_mv(object_center,
|
104 |
+
tar_pos = np.array([0, 0, 2.5]),
|
105 |
+
rot_Y = 0.,
|
106 |
+
rot_X = 0.3 if view_setting.endswith('bird') else 0.,
|
107 |
+
global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
|
108 |
+
intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
|
109 |
+
intr[:2] *= img_scale
|
110 |
+
img_h = int(1024 * img_scale)
|
111 |
+
img_w = int(1024 * img_scale)
|
112 |
+
|
113 |
+
extr_list.append(extr)
|
114 |
+
intr_list.append(intr)
|
115 |
+
img_h_list.append(img_h)
|
116 |
+
img_w_list.append(img_w)
|
117 |
+
|
118 |
+
# print('extr: ', extr)
|
119 |
+
# print('intr: ', intr)
|
120 |
+
# print('img_h: ', img_h)
|
121 |
+
# print('img_w: ', img_w)
|
122 |
+
# exit()
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
elif view_setting.startswith('back'):
|
127 |
+
# back view setting
|
128 |
+
extr = visualize_util.calc_free_mv(object_center,
|
129 |
+
tar_pos = np.array([0, 0, 2.5]),
|
130 |
+
rot_Y = np.pi,
|
131 |
+
rot_X = 0.5 * np.pi / 4. if view_setting.endswith('bird') else 0.,
|
132 |
+
global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
|
133 |
+
intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
|
134 |
+
intr[:2] *= img_scale
|
135 |
+
img_h = int(1024 * img_scale)
|
136 |
+
img_w = int(1024 * img_scale)
|
137 |
+
elif view_setting.startswith('moving'):
|
138 |
+
# moving camera setting
|
139 |
+
extr = visualize_util.calc_free_mv(object_center,
|
140 |
+
# tar_pos = np.array([0, 0, 3.0]),
|
141 |
+
# rot_Y = -0.3,
|
142 |
+
tar_pos = np.array([0, 0, 2.5]),
|
143 |
+
rot_Y = 0.,
|
144 |
+
rot_X = 0.3 if view_setting.endswith('bird') else 0.,
|
145 |
+
global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
|
146 |
+
intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
|
147 |
+
intr[:2] *= img_scale
|
148 |
+
img_h = int(1024 * img_scale)
|
149 |
+
img_w = int(1024 * img_scale)
|
150 |
+
elif view_setting.startswith('cano'):
|
151 |
+
cano_center = self.dataset.cano_bounds.mean(0)
|
152 |
+
extr = np.identity(4, np.float32)
|
153 |
+
extr[:3, 3] = -cano_center
|
154 |
+
rot_x = np.identity(4, np.float32)
|
155 |
+
rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
|
156 |
+
extr = rot_x @ extr
|
157 |
+
f_len = 5000
|
158 |
+
extr[2, 3] += f_len / 512
|
159 |
+
intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
|
160 |
+
# item = self.dataset.getitem(idx,
|
161 |
+
# training = False,
|
162 |
+
# extr = extr,
|
163 |
+
# intr = intr,
|
164 |
+
# img_w = 1024,
|
165 |
+
# img_h = 1024)
|
166 |
+
img_w, img_h = 1024, 1024
|
167 |
+
# item['live_smpl_v'] = item['cano_smpl_v']
|
168 |
+
# item['cano2live_jnt_mats'] = torch.eye(4, dtype = torch.float32)[None].expand(item['cano2live_jnt_mats'].shape[0], -1, -1)
|
169 |
+
# item['live_bounds'] = item['cano_bounds']
|
170 |
+
else:
|
171 |
+
raise ValueError('Invalid view setting for animation!')
|
render_utils/lib/networks/__init__.py
ADDED
File without changes
|
render_utils/lib/networks/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (196 Bytes). View file
|
|
render_utils/lib/networks/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (189 Bytes). View file
|
|
render_utils/lib/networks/__pycache__/faceverse_torch.cpython-310.pyc
ADDED
Binary file (8.88 kB). View file
|
|
render_utils/lib/networks/__pycache__/faceverse_torch.cpython-38.pyc
ADDED
Binary file (9.05 kB). View file
|
|
render_utils/lib/networks/__pycache__/smpl_torch.cpython-310.pyc
ADDED
Binary file (12.8 kB). View file
|
|
render_utils/lib/networks/__pycache__/smpl_torch.cpython-38.pyc
ADDED
Binary file (13.4 kB). View file
|
|
render_utils/lib/networks/faceverse_torch.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import numpy as np
|
4 |
+
from ..utils.rotation_conversions import axis_angle_to_matrix
|
5 |
+
|
6 |
+
|
7 |
+
class FaceVerseModel(nn.Module):
|
8 |
+
def __init__(self, model_dict, batch_size=1, device='cuda:0', **kargs):
|
9 |
+
super(FaceVerseModel, self).__init__()
|
10 |
+
|
11 |
+
self.batch_size = batch_size
|
12 |
+
self.device = torch.device(device)
|
13 |
+
|
14 |
+
self.rotXYZ = torch.eye(3).view(1, 3, 3).repeat(3, 1, 1).view(3, 1, 3, 3).to(self.device)
|
15 |
+
|
16 |
+
self.skinmask = torch.tensor(model_dict['skinmask'], requires_grad=False, device=self.device)
|
17 |
+
|
18 |
+
self.kp_inds = torch.tensor(model_dict['keypoints'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)
|
19 |
+
|
20 |
+
self.meanshape = torch.tensor(model_dict['meanshape'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
|
21 |
+
self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
|
22 |
+
|
23 |
+
self.idBase = torch.tensor(model_dict['idBase'], dtype=torch.float32, requires_grad=False, device=self.device)
|
24 |
+
self.expBase = torch.tensor(model_dict['exBase'], dtype=torch.float32, requires_grad=False, device=self.device)
|
25 |
+
self.texBase = torch.tensor(model_dict['texBase'], dtype=torch.float32, requires_grad=False, device=self.device)
|
26 |
+
# self.expBase[:, :8] *= 4 # 控制眼睛
|
27 |
+
# self.expBase[:, 8:] *= 2 # 控制其它
|
28 |
+
|
29 |
+
self.tri = torch.tensor(model_dict['tri'], dtype=torch.int64, requires_grad=False, device=self.device)
|
30 |
+
# self.tri = self.tri[:, [0, 2, 1]] # TODO
|
31 |
+
self.point_buf = torch.tensor(model_dict['point_buf'], dtype=torch.int64, requires_grad=False, device=self.device)
|
32 |
+
|
33 |
+
self.num_vertex = self.meanshape.shape[0] // 3
|
34 |
+
self.id_dims = self.idBase.shape[1]
|
35 |
+
self.tex_dims = self.texBase.shape[1]
|
36 |
+
self.exp_dims = self.expBase.shape[1]
|
37 |
+
self.all_dims = self.id_dims + self.tex_dims + self.exp_dims
|
38 |
+
self.init_coeff_tensors()
|
39 |
+
|
40 |
+
# for tracking by landmarks
|
41 |
+
self.kp_inds_view = torch.cat([self.kp_inds[:, None] * 3, self.kp_inds[:, None] * 3 + 1, self.kp_inds[:, None] * 3 + 2], dim=1).flatten()
|
42 |
+
self.idBase_view = self.idBase[self.kp_inds_view, :].detach().clone()
|
43 |
+
self.expBase_view = self.expBase[self.kp_inds_view, :].detach().clone()
|
44 |
+
self.meanshape_view = self.meanshape[:, self.kp_inds_view].detach().clone()
|
45 |
+
|
46 |
+
# zxc
|
47 |
+
self.identity = torch.eye(3, dtype=torch.float32, device=self.device)
|
48 |
+
|
49 |
+
def init_coeff_tensors(self, id_coeff=None, tex_coeff=None, exp_coeff=None, gamma_coeff=None, trans_coeff=None, rot_coeff=None, scale_coeff=None):
|
50 |
+
if id_coeff is None:
|
51 |
+
self.id_tensor = torch.zeros(
|
52 |
+
(1, self.id_dims), dtype=torch.float32,
|
53 |
+
requires_grad=True, device=self.device)
|
54 |
+
else:
|
55 |
+
assert id_coeff.shape == (1, self.id_dims)
|
56 |
+
# self.id_tensor = torch.tensor(id_coeff, dtype=torch.float32, requires_grad=True, device=self.device)
|
57 |
+
self.id_tensor = id_coeff.clone().detach().requires_grad_(True)
|
58 |
+
|
59 |
+
if tex_coeff is None:
|
60 |
+
self.tex_tensor = torch.zeros(
|
61 |
+
(1, self.tex_dims), dtype=torch.float32,
|
62 |
+
requires_grad=True, device=self.device)
|
63 |
+
else:
|
64 |
+
assert tex_coeff.shape == (1, self.tex_dims)
|
65 |
+
# self.tex_tensor = torch.tensor(tex_coeff, dtype=torch.float32, requires_grad=True, device=self.device)
|
66 |
+
self.tex_tensor = tex_coeff.clone().detach().requires_grad_(True)
|
67 |
+
|
68 |
+
if exp_coeff is None:
|
69 |
+
self.exp_tensor = torch.zeros(
|
70 |
+
(self.batch_size, self.exp_dims), dtype=torch.float32,
|
71 |
+
requires_grad=True, device=self.device)
|
72 |
+
else:
|
73 |
+
assert exp_coeff.shape == (1, self.exp_dims)
|
74 |
+
# self.exp_tensor = torch.tensor(exp_coeff, dtype=torch.float32, requires_grad=True, device=self.device)
|
75 |
+
self.exp_tensor = exp_coeff.clone().detach().requires_grad_(True)
|
76 |
+
|
77 |
+
if gamma_coeff is None:
|
78 |
+
self.gamma_tensor = torch.zeros(
|
79 |
+
(self.batch_size, 27), dtype=torch.float32,
|
80 |
+
requires_grad=True, device=self.device)
|
81 |
+
else:
|
82 |
+
# self.gamma_tensor = torch.tensor(
|
83 |
+
# gamma_coeff, dtype=torch.float32,
|
84 |
+
# requires_grad=True, device=self.device)
|
85 |
+
self.gamma_tensor = gamma_coeff.clone().detach().requires_grad_(True)
|
86 |
+
|
87 |
+
if trans_coeff is None:
|
88 |
+
self.trans_tensor = torch.zeros(
|
89 |
+
(self.batch_size, 3), dtype=torch.float32,
|
90 |
+
requires_grad=True, device=self.device)
|
91 |
+
else:
|
92 |
+
# self.trans_tensor = torch.tensor(
|
93 |
+
# trans_coeff, dtype=torch.float32,
|
94 |
+
# requires_grad=True, device=self.device)
|
95 |
+
self.trans_tensor = trans_coeff.clone().detach().requires_grad_(True)
|
96 |
+
|
97 |
+
if scale_coeff is None:
|
98 |
+
self.scale_tensor = 0.18 * torch.ones((self.batch_size, 1), dtype=torch.float32, device=self.device)
|
99 |
+
# self.scale_tensor = torch.ones((self.batch_size, 1), dtype=torch.float32, device=self.device)
|
100 |
+
self.scale_tensor.requires_grad_(True)
|
101 |
+
# self.scale_tensor = torch.ones(
|
102 |
+
# (self.batch_size, 1), dtype=torch.float32,
|
103 |
+
# requires_grad=True, device=self.device)
|
104 |
+
else:
|
105 |
+
# self.scale_tensor = torch.tensor(
|
106 |
+
# scale_coeff, dtype=torch.float32,
|
107 |
+
# requires_grad=True, device=self.device)
|
108 |
+
self.scale_tensor = scale_coeff.clone().detach().requires_grad_(True)
|
109 |
+
|
110 |
+
if rot_coeff is None:
|
111 |
+
self.rot_tensor = torch.zeros(
|
112 |
+
(self.batch_size, 3), dtype=torch.float32,
|
113 |
+
requires_grad=True, device=self.device)
|
114 |
+
else:
|
115 |
+
# self.rot_tensor = torch.tensor(
|
116 |
+
# rot_coeff, dtype=torch.float32,
|
117 |
+
# requires_grad=True, device=self.device)
|
118 |
+
self.rot_tensor = rot_coeff.clone().detach().requires_grad_(True)
|
119 |
+
|
120 |
+
def get_lms(self, vs):
|
121 |
+
lms = vs[:, self.kp_inds, :]
|
122 |
+
return lms
|
123 |
+
|
124 |
+
def split_coeffs(self, coeffs):
|
125 |
+
id_coeff = coeffs[:, :self.id_dims] # identity(shape) coeff
|
126 |
+
exp_coeff = coeffs[:, self.id_dims:self.id_dims + self.exp_dims] # expression coeff
|
127 |
+
tex_coeff = coeffs[:, self.id_dims + self.exp_dims:self.all_dims] # texture(albedo) coeff
|
128 |
+
angles = coeffs[:, self.all_dims:self.all_dims + 3] # ruler angles(x,y,z) for rotation of dim 3
|
129 |
+
gamma = coeffs[:, self.all_dims + 3:self.all_dims + 30] # lighting coeff for 3 channel SH function of dim 27
|
130 |
+
translation = coeffs[:, self.all_dims + 30:-1] # translation coeff of dim 3
|
131 |
+
scale = coeffs[:, -1:]
|
132 |
+
return id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale
|
133 |
+
|
134 |
+
def merge_coeffs(self, id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale):
|
135 |
+
coeffs = torch.cat([id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale], dim=1)
|
136 |
+
return coeffs
|
137 |
+
|
138 |
+
def get_packed_tensors(self):
|
139 |
+
return self.merge_coeffs(self.id_tensor.repeat(self.batch_size, 1),
|
140 |
+
self.exp_tensor,
|
141 |
+
self.tex_tensor.repeat(self.batch_size, 1),
|
142 |
+
self.rot_tensor, self.gamma_tensor,
|
143 |
+
self.trans_tensor, self.scale_tensor)
|
144 |
+
|
145 |
+
def forward(self, coeffs=None, camT=None):
|
146 |
+
if coeffs is None:
|
147 |
+
id_coeff = self.id_tensor.repeat(self.batch_size, 1)
|
148 |
+
tex_coeff = self.tex_tensor.repeat(self.batch_size, 1)
|
149 |
+
exp_coeff, angles, gamma = self.exp_tensor, self.rot_tensor, self.gamma_tensor
|
150 |
+
translation, scale = self.rot_tensor, self.scale_tensor
|
151 |
+
else:
|
152 |
+
id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale = self.split_coeffs(coeffs)
|
153 |
+
rotation = axis_angle_to_matrix(angles)
|
154 |
+
|
155 |
+
if camT is not None:
|
156 |
+
rotation2 = camT[:3, :3].reshape(1, 3, 3)
|
157 |
+
translation2 = camT[:3, 3:].reshape(1, 3, 1)
|
158 |
+
if torch.allclose(rotation2, self.identity):
|
159 |
+
translation = translation + translation2
|
160 |
+
else:
|
161 |
+
rotation = torch.matmul(rotation2, rotation)
|
162 |
+
translation = torch.matmul(rotation2, translation) + translation2
|
163 |
+
|
164 |
+
vs = self.get_vs(id_coeff, exp_coeff)
|
165 |
+
vs_t = self.rigid_transform(vs, rotation, translation, torch.abs(scale))
|
166 |
+
|
167 |
+
lms_t = self.get_lms(vs_t)
|
168 |
+
|
169 |
+
face_texture = self.get_color(tex_coeff)
|
170 |
+
face_norm = self.compute_norm(vs, self.tri, self.point_buf)
|
171 |
+
face_norm_r = face_norm.bmm(rotation)
|
172 |
+
face_color = self.add_illumination(face_texture, face_norm_r, gamma)
|
173 |
+
|
174 |
+
return {'v': vs_t, 'lm': lms_t, 'face_texture': face_texture, 'face_color': face_color}
|
175 |
+
|
176 |
+
def forward_landmarks(self, coeffs=None, camT=None):
|
177 |
+
if coeffs is None:
|
178 |
+
id_coeff = self.id_tensor.repeat(self.batch_size, 1)
|
179 |
+
tex_coeff = self.tex_tensor.repeat(self.batch_size, 1)
|
180 |
+
exp_coeff, angles, gamma = self.exp_tensor, self.rot_tensor, self.gamma_tensor
|
181 |
+
translation, scale = self.trans_tensor, self.scale_tensor
|
182 |
+
else:
|
183 |
+
id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale = self.split_coeffs(coeffs)
|
184 |
+
rotation = axis_angle_to_matrix(angles)
|
185 |
+
|
186 |
+
if camT is not None:
|
187 |
+
rotation2 = camT[:3, :3].reshape(1, 3, 3)
|
188 |
+
translation2 = camT[:3, 3:].reshape(1, 3, 1)
|
189 |
+
if torch.allclose(rotation2, self.identity):
|
190 |
+
translation = translation + translation2
|
191 |
+
else:
|
192 |
+
rotation = torch.matmul(rotation2, rotation)
|
193 |
+
translation = torch.matmul(rotation2, translation) + translation2
|
194 |
+
lms = self.get_vs_lms(id_coeff, exp_coeff)
|
195 |
+
lms_t = self.rigid_transform(lms, rotation, translation, torch.abs(scale))
|
196 |
+
return lms_t
|
197 |
+
|
198 |
+
def get_vs(self, id_coeff, exp_coeff):
|
199 |
+
face_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + \
|
200 |
+
torch.einsum('ij,aj->ai', self.expBase, exp_coeff) + self.meanshape
|
201 |
+
face_shape = face_shape.view(self.batch_size, -1, 3)
|
202 |
+
return face_shape
|
203 |
+
|
204 |
+
def get_vs_lms(self, id_coeff, exp_coeff):
|
205 |
+
face_shape = torch.einsum('ij,aj->ai', self.idBase_view, id_coeff) + \
|
206 |
+
torch.einsum('ij,aj->ai', self.expBase_view, exp_coeff) + self.meanshape_view
|
207 |
+
face_shape = face_shape.view(self.batch_size, -1, 3)
|
208 |
+
return face_shape
|
209 |
+
|
210 |
+
def get_color(self, tex_coeff):
|
211 |
+
face_texture = torch.einsum('ij,aj->ai', self.texBase, tex_coeff) + self.meantex
|
212 |
+
face_texture = face_texture.view(self.batch_size, -1, 3)
|
213 |
+
return face_texture
|
214 |
+
|
215 |
+
def get_skinmask(self):
|
216 |
+
return self.skinmask
|
217 |
+
|
218 |
+
def compute_norm(self, vs, tri, point_buf):
|
219 |
+
face_id = tri
|
220 |
+
point_id = point_buf
|
221 |
+
v1 = vs[:, face_id[:, 0], :]
|
222 |
+
v2 = vs[:, face_id[:, 1], :]
|
223 |
+
v3 = vs[:, face_id[:, 2], :]
|
224 |
+
e1 = v1 - v2
|
225 |
+
e2 = v2 - v3
|
226 |
+
face_norm = e1.cross(e2)
|
227 |
+
|
228 |
+
v_norm = face_norm[:, point_id, :].sum(2)
|
229 |
+
v_norm = v_norm / (v_norm.norm(dim=2).unsqueeze(2) + 1e-9)
|
230 |
+
|
231 |
+
return v_norm
|
232 |
+
|
233 |
+
def add_illumination(self, face_texture, norm, gamma):
|
234 |
+
gamma = gamma.view(-1, 3, 9).clone()
|
235 |
+
gamma[:, :, 0] += 0.8
|
236 |
+
gamma = gamma.permute(0, 2, 1)
|
237 |
+
|
238 |
+
a0 = np.pi
|
239 |
+
a1 = 2 * np.pi / np.sqrt(3.0)
|
240 |
+
a2 = 2 * np.pi / np.sqrt(8.0)
|
241 |
+
c0 = 1 / np.sqrt(4 * np.pi)
|
242 |
+
c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
|
243 |
+
c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
|
244 |
+
d0 = 0.5 / np.sqrt(3.0)
|
245 |
+
|
246 |
+
norm = norm.view(-1, 3)
|
247 |
+
nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
|
248 |
+
arrH = []
|
249 |
+
|
250 |
+
arrH.append(a0 * c0 * (nx * 0 + 1))
|
251 |
+
arrH.append(-a1 * c1 * ny)
|
252 |
+
arrH.append(a1 * c1 * nz)
|
253 |
+
arrH.append(-a1 * c1 * nx)
|
254 |
+
arrH.append(a2 * c2 * nx * ny)
|
255 |
+
arrH.append(-a2 * c2 * ny * nz)
|
256 |
+
arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
|
257 |
+
arrH.append(-a2 * c2 * nx * nz)
|
258 |
+
arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))
|
259 |
+
|
260 |
+
H = torch.stack(arrH, 1)
|
261 |
+
Y = H.view(self.batch_size, face_texture.shape[1], 9)
|
262 |
+
lighting = Y.bmm(gamma)
|
263 |
+
|
264 |
+
face_color = face_texture * lighting
|
265 |
+
return face_color
|
266 |
+
|
267 |
+
def rigid_transform(self, vs, rot, trans, scale):
|
268 |
+
scale = scale.reshape(-1, 1, 1)
|
269 |
+
vs_r = torch.matmul(vs * scale, rot.permute(0, 2, 1))
|
270 |
+
vs_t = vs_r + trans.reshape(-1, 1, 3)
|
271 |
+
return vs_t
|
272 |
+
|
273 |
+
def get_rot_tensor(self):
|
274 |
+
return self.rot_tensor
|
275 |
+
|
276 |
+
def get_trans_tensor(self):
|
277 |
+
return self.trans_tensor
|
278 |
+
|
279 |
+
def get_exp_tensor(self):
|
280 |
+
return self.exp_tensor
|
281 |
+
|
282 |
+
def get_tex_tensor(self):
|
283 |
+
return self.tex_tensor
|
284 |
+
|
285 |
+
def get_id_tensor(self):
|
286 |
+
return self.id_tensor
|
287 |
+
|
288 |
+
def get_gamma_tensor(self):
|
289 |
+
return self.gamma_tensor
|
290 |
+
|
291 |
+
def get_scale_tensor(self):
|
292 |
+
return self.scale_tensor
|
render_utils/lib/networks/smpl_torch.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import pickle
|
4 |
+
import numpy as np
|
5 |
+
import scipy.sparse
|
6 |
+
import os
|
7 |
+
|
8 |
+
import smplx
|
9 |
+
import smplx.lbs
|
10 |
+
|
11 |
+
from ..utils.geometry import rodrigues
|
12 |
+
def _convert_to_sparse(mat):
|
13 |
+
if isinstance(mat, np.ndarray):
|
14 |
+
row_ind, col_ind = np.asarray(mat > 0).nonzero()
|
15 |
+
data = [mat[r, c] for (r, c) in zip(row_ind, col_ind)]
|
16 |
+
mat_sparse = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=mat.shape)
|
17 |
+
return mat_sparse
|
18 |
+
else:
|
19 |
+
return mat
|
20 |
+
|
21 |
+
|
22 |
+
class SmplTorch(nn.Module):
|
23 |
+
def __init__(self, model_file='./data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'):
|
24 |
+
super(SmplTorch, self).__init__()
|
25 |
+
if model_file.endswith('.pkl'):
|
26 |
+
with open(model_file, 'rb') as f:
|
27 |
+
smpl_model = pickle.load(f, encoding='iso-8859-1')
|
28 |
+
# print('SmplTorch: Loading SMPL data from %s' % model_file)
|
29 |
+
elif model_file.endswith('.npz'):
|
30 |
+
smpl_model = np.load(model_file)
|
31 |
+
# print('SmplTorch: Loading SMPL data from %s' % model_file)
|
32 |
+
else:
|
33 |
+
raise ValueError('Unknown file extension: %s' % model_file)
|
34 |
+
|
35 |
+
J_regressor = _convert_to_sparse(smpl_model['J_regressor']).tocoo()
|
36 |
+
row = J_regressor.row
|
37 |
+
col = J_regressor.col
|
38 |
+
data = J_regressor.data
|
39 |
+
i = torch.LongTensor([row, col])
|
40 |
+
v = torch.FloatTensor(data)
|
41 |
+
J_regressor_shape = J_regressor.shape
|
42 |
+
self.v_num = len(smpl_model['v_template'])
|
43 |
+
self.joint_num = smpl_model['weights'].shape[1]
|
44 |
+
self.beta_dim = 10
|
45 |
+
|
46 |
+
self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense())
|
47 |
+
self.register_buffer('weights', torch.FloatTensor(smpl_model['weights']))
|
48 |
+
self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs']))
|
49 |
+
self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template']))
|
50 |
+
self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs'][:, :, :self.beta_dim])))
|
51 |
+
self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64)))
|
52 |
+
self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64)))
|
53 |
+
id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
|
54 |
+
self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
|
55 |
+
|
56 |
+
self.pose_shape = [self.joint_num, 3]
|
57 |
+
self.beta_shape = [self.beta_dim]
|
58 |
+
self.translation_shape = [3]
|
59 |
+
|
60 |
+
self.pose = torch.zeros(self.pose_shape)
|
61 |
+
self.beta = torch.zeros(self.beta_shape)
|
62 |
+
self.translation = torch.zeros(self.translation_shape)
|
63 |
+
|
64 |
+
self.verts = None
|
65 |
+
|
66 |
+
# # BODY25 joint regressor
|
67 |
+
# with open(os.path.join(os.path.dirname(model_file), 'J_regressor_body25.pkl'), 'rb') as f:
|
68 |
+
# jreg_data = pickle.load(f, encoding='iso-8859-1')
|
69 |
+
# J_regressor25 = jreg_data['J_regressor_body25'].tocoo()
|
70 |
+
# row = J_regressor25.row
|
71 |
+
# col = J_regressor25.col
|
72 |
+
# data = J_regressor25.data
|
73 |
+
# i = torch.LongTensor([row, col])
|
74 |
+
# v = torch.FloatTensor(data)
|
75 |
+
# J_regressor_shape = J_regressor25.shape
|
76 |
+
# self.register_buffer('J_regressor_body25', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense())
|
77 |
+
|
78 |
+
def forward(self, pose, beta, apply_pose_blend_shape=True):
|
79 |
+
device = pose.device
|
80 |
+
batch_size = pose.shape[0]
|
81 |
+
v_template = self.v_template[None, :]
|
82 |
+
shapedirs = self.shapedirs.view(-1, self.beta_dim)[None, :].expand(batch_size, -1, -1)
|
83 |
+
beta = beta[:, :, None]
|
84 |
+
v_shaped = torch.matmul(shapedirs, beta).view(-1, self.v_num, 3) + v_template
|
85 |
+
# batched sparse matmul not supported in pytorch
|
86 |
+
J = []
|
87 |
+
for i in range(batch_size):
|
88 |
+
J.append(torch.matmul(self.J_regressor, v_shaped[i]))
|
89 |
+
J = torch.stack(J, dim=0)
|
90 |
+
# input it rotmat: (bs,24,3,3)
|
91 |
+
if pose.ndimension() == 4:
|
92 |
+
R = pose
|
93 |
+
# input it rotmat: (bs,72)
|
94 |
+
elif pose.ndimension() == 2:
|
95 |
+
pose_cube = pose.view(-1, 3) # (batch_size * 24, 1, 3)
|
96 |
+
R = rodrigues(pose_cube).view(batch_size, self.joint_num, 3, 3)
|
97 |
+
R = R.view(batch_size, self.joint_num, 3, 3)
|
98 |
+
|
99 |
+
if apply_pose_blend_shape:
|
100 |
+
I_cube = torch.eye(3)[None, None, :].to(device)
|
101 |
+
# I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1)
|
102 |
+
lrotmin = (R[:, 1:, :] - I_cube).view(batch_size, -1)
|
103 |
+
posedirs = self.posedirs.view(-1, 9*(self.joint_num-1))[None, :].expand(batch_size, -1, -1)
|
104 |
+
v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, self.v_num, 3)
|
105 |
+
else:
|
106 |
+
v_posed = v_shaped
|
107 |
+
|
108 |
+
J_ = J.clone()
|
109 |
+
J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
|
110 |
+
G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
|
111 |
+
pad_row = torch.FloatTensor([0, 0, 0, 1]).to(device).view(1, 1, 1, 4).expand(batch_size, self.joint_num, -1, -1)
|
112 |
+
G_ = torch.cat([G_, pad_row], dim=2)
|
113 |
+
G = [G_[:, 0].clone()]
|
114 |
+
for i in range(1, self.joint_num):
|
115 |
+
G.append(torch.matmul(G[self.parent[i - 1]], G_[:, i, :, :]))
|
116 |
+
G = torch.stack(G, dim=1)
|
117 |
+
|
118 |
+
rest = torch.cat([J, torch.zeros(batch_size, self.joint_num, 1).to(device)], dim=2).view(batch_size, self.joint_num, 4, 1)
|
119 |
+
zeros = torch.zeros(batch_size, self.joint_num, 4, 3).to(device)
|
120 |
+
rest = torch.cat([zeros, rest], dim=-1)
|
121 |
+
rest = torch.matmul(G, rest)
|
122 |
+
G = G - rest
|
123 |
+
T = torch.matmul(self.weights, G.permute(1, 0, 2, 3).contiguous().view(self.joint_num, -1)).view(self.v_num, batch_size, 4, 4).transpose(0, 1)
|
124 |
+
rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
|
125 |
+
v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
|
126 |
+
|
127 |
+
return v, {'J': J, 'G': G, 'T': T, 'R': R, 'v_shaped': v_shaped, 'v_posed': v_posed}
|
128 |
+
|
129 |
+
def get_joints(self, vertices):
|
130 |
+
"""
|
131 |
+
This method is used to get the joint locations from the SMPL mesh
|
132 |
+
Input:
|
133 |
+
vertices: size = (B, self.v_num, 3)
|
134 |
+
Output:
|
135 |
+
3D joints: size = (B, 38, 3)
|
136 |
+
"""
|
137 |
+
joints = torch.einsum('bik,ji->bjk', vertices, self.J_regressor)
|
138 |
+
return joints
|
139 |
+
|
140 |
+
def get_root(self, vertices):
|
141 |
+
"""
|
142 |
+
This method is used to get the root locations from the SMPL mesh
|
143 |
+
Input:
|
144 |
+
vertices: size = (B, self.v_num, 3)
|
145 |
+
Output:
|
146 |
+
3D joints: size = (B, 1, 3)
|
147 |
+
"""
|
148 |
+
joints = torch.einsum('bik,ji->bjk', vertices, self.J_regressor)
|
149 |
+
return joints[:, 0:1, :]
|
150 |
+
|
151 |
+
# def get_joints_body25(self, vertices):
|
152 |
+
# joints = torch.einsum('bik,ji->bjk', vertices, self.J_regressor_body25)
|
153 |
+
# return joints
|
154 |
+
|
155 |
+
|
156 |
+
class SmplProjector(nn.Module):
|
157 |
+
def __init__(self, verts_cano=None, face_triangles=None):
|
158 |
+
super(SmplProjector, self).__init__()
|
159 |
+
self.verts_cano = verts_cano
|
160 |
+
self.face_triangles = face_triangles
|
161 |
+
|
162 |
+
face_node_idx = np.loadtxt('./data/smplx/face_coarse_level_ids.txt').astype(np.int64)
|
163 |
+
node_to_face_table = np.loadtxt('./data/smplx/face_coarse_level_to_all_table.txt').astype(np.int64)
|
164 |
+
self.register_buffer('face_node_idx', torch.from_numpy(face_node_idx), persistent=False)
|
165 |
+
self.register_buffer('node_to_face_table', torch.from_numpy(node_to_face_table), persistent=False)
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def batch_select(tensor, index, dim):
|
169 |
+
"""
|
170 |
+
Perform index_select for batched inputs, where the index tensors of different items are different
|
171 |
+
:param tensor: [B, G, ....]
|
172 |
+
:param index: [B, N], which is the key difference from torch.index_select()
|
173 |
+
:param dim: int
|
174 |
+
:return:
|
175 |
+
"""
|
176 |
+
res = []
|
177 |
+
for bi in range(tensor.shape[0]):
|
178 |
+
res.append(
|
179 |
+
torch.index_select(tensor[bi], dim=dim-1, index=index[bi])
|
180 |
+
)
|
181 |
+
res = torch.stack(res, dim=0)
|
182 |
+
return res
|
183 |
+
|
184 |
+
def get_current_beta(self):
|
185 |
+
return self.density_func.get_beta()
|
186 |
+
|
187 |
+
def calculate_nearest_barycentric(self, verts, pts):
|
188 |
+
nearest_face_id, nearest_dist, nearest_point_barycentric = calc_nearest_barycentric_coord(
|
189 |
+
verts, self.face_triangles, pts, self.face_node_idx, self.node_to_face_table)
|
190 |
+
return nearest_face_id, nearest_dist, nearest_point_barycentric
|
191 |
+
|
192 |
+
def interpolate_vert_feat(self, vert_feat, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
|
193 |
+
fvfa = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 0]) # [B, G, 3]
|
194 |
+
fvfb = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 1]) # [B, G, 3]
|
195 |
+
fvfc = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 2]) # [B, G, 3]
|
196 |
+
fvfs = torch.cat([fvfa.unsqueeze(2), fvfb.unsqueeze(2), fvfc.unsqueeze(2)], dim=2) # [B, G, 3, 3]
|
197 |
+
nearest_fvfs = self.batch_select(fvfs, nearest_face_id, dim=1)
|
198 |
+
nearest_vf = torch.sum(nearest_point_barycentric.unsqueeze(-1)*nearest_fvfs, dim=-2) # [B, N, 3]
|
199 |
+
return nearest_vf
|
200 |
+
|
201 |
+
def calculate_nearest_point(self, verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
|
202 |
+
verts_cano = self.verts_cano.unsqueeze(0).expand(verts.shape[0], -1, -1)
|
203 |
+
nearest_v = self.interpolate_vert_feat(verts, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
204 |
+
nearest_nml = self.interpolate_vert_feat(vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
205 |
+
nearest_vc = self.interpolate_vert_feat(verts_cano, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
206 |
+
return nearest_v, nearest_nml, nearest_vc
|
207 |
+
|
208 |
+
def forward(self, verts, vert_nmls, pts, extra_vert_feat=None):
|
209 |
+
nearest_face_id, nearest_dist, nearest_point_barycentric = self.calculate_nearest_barycentric(verts, pts)
|
210 |
+
nearest_v, nearest_nml, nearest_vc = self.calculate_nearest_point(verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
211 |
+
nearest_dist = nearest_dist.unsqueeze(-1)
|
212 |
+
nearest_h = torch.sum((pts-nearest_v)*nearest_nml, dim=-1, keepdim=True) # [B, N, 1]
|
213 |
+
sdf = nearest_h.sign() * nearest_dist
|
214 |
+
if extra_vert_feat is None:
|
215 |
+
return sdf, nearest_v, nearest_nml, nearest_vc
|
216 |
+
else:
|
217 |
+
if len(extra_vert_feat.shape) == 2:
|
218 |
+
extra_vert_feat = extra_vert_feat.unsqueeze(0).expand(verts.shape[0], -1, -1)
|
219 |
+
nearest_feat = self.interpolate_vert_feat(extra_vert_feat, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
220 |
+
return sdf, nearest_v, nearest_nml, nearest_vc, nearest_feat
|
221 |
+
|
222 |
+
|
223 |
+
class SmplProjectorBF(nn.Module):
|
224 |
+
"""
|
225 |
+
Calculate barycentric projection in a brute-force manner
|
226 |
+
"""
|
227 |
+
def __init__(self, verts_cano=None, face_triangles=None):
|
228 |
+
super(SmplProjectorBF, self).__init__()
|
229 |
+
self.verts_cano = verts_cano
|
230 |
+
self.face_triangles = face_triangles
|
231 |
+
|
232 |
+
@staticmethod
|
233 |
+
def batch_select(tensor, index, dim):
|
234 |
+
"""
|
235 |
+
Perform index_select for batched inputs, where the index tensors of different items are different
|
236 |
+
:param tensor: [B, G, ....]
|
237 |
+
:param index: [B, N], which is the key difference from torch.index_select()
|
238 |
+
:param dim: int
|
239 |
+
:return:
|
240 |
+
"""
|
241 |
+
res = []
|
242 |
+
for bi in range(tensor.shape[0]):
|
243 |
+
res.append(
|
244 |
+
torch.index_select(tensor[bi], dim=dim-1, index=index[bi])
|
245 |
+
)
|
246 |
+
res = torch.stack(res, dim=0)
|
247 |
+
return res
|
248 |
+
|
249 |
+
def get_current_beta(self):
|
250 |
+
return self.density_func.get_beta()
|
251 |
+
|
252 |
+
def calculate_nearest_barycentric(self, verts, pts):
|
253 |
+
nearest_face_id, nearest_dist, nearest_point_barycentric = calc_nearest_barycentric_coord_bf(
|
254 |
+
verts, self.face_triangles, pts)
|
255 |
+
return nearest_face_id, nearest_dist, nearest_point_barycentric
|
256 |
+
|
257 |
+
def interpolate_vert_feat(self, vert_feat, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
|
258 |
+
fvfa = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 0]) # [B, G, 3]
|
259 |
+
fvfb = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 1]) # [B, G, 3]
|
260 |
+
fvfc = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 2]) # [B, G, 3]
|
261 |
+
fvfs = torch.cat([fvfa.unsqueeze(2), fvfb.unsqueeze(2), fvfc.unsqueeze(2)], dim=2) # [B, G, 3, 3]
|
262 |
+
nearest_fvfs = self.batch_select(fvfs, nearest_face_id, dim=1)
|
263 |
+
nearest_vf = torch.sum(nearest_point_barycentric.unsqueeze(-1)*nearest_fvfs, dim=-2) # [B, N, 3]
|
264 |
+
return nearest_vf
|
265 |
+
|
266 |
+
def calculate_nearest_point(self, verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
|
267 |
+
verts_cano = self.verts_cano.unsqueeze(0).expand(verts.shape[0], -1, -1)
|
268 |
+
nearest_v = self.interpolate_vert_feat(verts, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
269 |
+
nearest_nml = self.interpolate_vert_feat(vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
270 |
+
nearest_vc = self.interpolate_vert_feat(verts_cano, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
271 |
+
return nearest_v, nearest_nml, nearest_vc
|
272 |
+
|
273 |
+
def forward(self, verts, vert_nmls, pts):
|
274 |
+
nearest_face_id, nearest_dist, nearest_point_barycentric = self.calculate_nearest_barycentric(verts, pts)
|
275 |
+
nearest_v, nearest_nml, nearest_vc = self.calculate_nearest_point(verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
|
276 |
+
nearest_dist = nearest_dist.unsqueeze(-1)
|
277 |
+
nearest_h = torch.sum((pts-nearest_v)*nearest_nml, dim=-1, keepdim=True) # [B, N, 1]
|
278 |
+
sdf = nearest_h.sign() * nearest_dist
|
279 |
+
return sdf, nearest_v, nearest_nml, nearest_vc
|
280 |
+
|
281 |
+
|
282 |
+
class SmplJointRegressor(nn.Module):
|
283 |
+
def __init__(self, model_file, extra_regressor_file):
|
284 |
+
super().__init__()
|
285 |
+
# '../smpl_models/J_regressor_extra_' + 'smplx' + '.npy'
|
286 |
+
|
287 |
+
self.smpl = smplx.SMPLX(model_path=model_file, gender='neutral', use_pca=False,
|
288 |
+
num_pca_comps=45, flat_hand_mean=True, batch_size=1)
|
289 |
+
self.extra_jregressor = torch.tensor(np.load(extra_regressor_file), dtype=torch.float32)
|
290 |
+
self.smpl2coco = {
|
291 |
+
'body': torch.LongTensor([55, 57, 56, 59, 58, 16, 17, 18, 19, 20, 21, 128, 127, 4, 5, 7, 8]),
|
292 |
+
'foot': torch.LongTensor(np.arange(60, 66, dtype=int)),
|
293 |
+
'face': torch.LongTensor(np.arange(76, 127, dtype=int)),
|
294 |
+
'lhand': torch.LongTensor(
|
295 |
+
[20, 37, 38, 39, 66, 25, 26, 27, 67, 28, 29, 30, 68, 34, 35, 36, 69, 31, 32, 33, 70]),
|
296 |
+
'rhand': torch.LongTensor(
|
297 |
+
[21, 52, 53, 54, 71, 40, 41, 42, 72, 43, 44, 45, 73, 49, 50, 51, 74, 46, 47, 48, 75])
|
298 |
+
}
|
299 |
+
|
300 |
+
def to(self, device):
|
301 |
+
super().to(device)
|
302 |
+
self.smpl = self.smpl.to(device)
|
303 |
+
self.smpl2coco = {k: v.to(device) for k, v in self.smpl2coco.items()}
|
304 |
+
self.extra_jregressor = self.extra_jregressor.to(device)
|
305 |
+
return self
|
306 |
+
|
307 |
+
def get_all_joints(self, verts):
|
308 |
+
lmk_faces_idx = self.smpl.lmk_faces_idx.unsqueeze(dim=0).expand(verts.shape[0], -1).contiguous()
|
309 |
+
lmk_bary_coords = self.smpl.lmk_bary_coords.unsqueeze(dim=0).repeat(verts.shape[0], 1, 1)
|
310 |
+
landmarks = smplx.lbs.vertices2landmarks(verts, self.smpl.faces_tensor, lmk_faces_idx, lmk_bary_coords)
|
311 |
+
|
312 |
+
joints = torch.einsum('bik,ji->bjk', verts, self.smpl.J_regressor)
|
313 |
+
joints = self.smpl.vertex_joint_selector(verts, joints)
|
314 |
+
joints = torch.cat([joints, landmarks], dim=1)
|
315 |
+
if self.smpl.joint_mapper is not None:
|
316 |
+
joints = self.joint_mapper(joints=joints, vertices=verts)
|
317 |
+
|
318 |
+
extra_joints = torch.einsum('bik,ji->bjk', verts, self.extra_jregressor)
|
319 |
+
joints = torch.cat([joints, extra_joints], dim=1)
|
320 |
+
|
321 |
+
return joints
|
322 |
+
|
323 |
+
def forward(self, verts, cam_param=None):
|
324 |
+
joints = self.get_all_joints(verts)
|
325 |
+
|
326 |
+
if cam_param:
|
327 |
+
if len(cam_param['K'].shape) == 3:
|
328 |
+
smpl_j2d = torch.einsum("ljk,lik->lij", cam_param['K'], torch.einsum(
|
329 |
+
"ijk,lk->ilj", cam_param['R'], joints.reshape(-1, 3)) + cam_param['T'].reshape(-1, 1, 3))
|
330 |
+
else:
|
331 |
+
smpl_j2d = torch.einsum("jk,lik->lij", cam_param['K'], torch.einsum(
|
332 |
+
"ijk,lk->ilj", cam_param['R'], joints.reshape(-1, 3)) + cam_param['T'].reshape(-1, 1, 3))
|
333 |
+
valid = smpl_j2d[:, :, 2] > 1e-3
|
334 |
+
smpl_j2d[:, :, :2][valid] = smpl_j2d[:, :, :2][valid] / smpl_j2d[:, :, 2:3][valid]
|
335 |
+
smpl_j2d[:, :, 2] = valid.float()
|
336 |
+
skel_2d = {k: torch.index_select(smpl_j2d, dim=1, index=v) for k, v in self.smpl2coco.items()}
|
337 |
+
else:
|
338 |
+
skel_2d = None
|
339 |
+
|
340 |
+
skel = {k: torch.index_select(joints, dim=1, index=v) for k, v in self.smpl2coco.items()}
|
341 |
+
return skel, skel_2d
|
render_utils/lib/utils/__init__.py
ADDED
File without changes
|
render_utils/lib/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (193 Bytes). View file
|
|
render_utils/lib/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (186 Bytes). View file
|
|
render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-310.pyc
ADDED
Binary file (6.28 kB). View file
|
|
render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-38.pyc
ADDED
Binary file (6.54 kB). View file
|
|
render_utils/lib/utils/__pycache__/geometry.cpython-310.pyc
ADDED
Binary file (16.2 kB). View file
|
|
render_utils/lib/utils/__pycache__/geometry.cpython-38.pyc
ADDED
Binary file (16 kB). View file
|
|
render_utils/lib/utils/__pycache__/graphics_utils.cpython-310.pyc
ADDED
Binary file (5.44 kB). View file
|
|
render_utils/lib/utils/__pycache__/graphics_utils.cpython-38.pyc
ADDED
Binary file (5.42 kB). View file
|
|
render_utils/lib/utils/__pycache__/rotation_conversions.cpython-310.pyc
ADDED
Binary file (17.1 kB). View file
|
|
render_utils/lib/utils/__pycache__/rotation_conversions.cpython-38.pyc
ADDED
Binary file (17.1 kB). View file
|
|
render_utils/lib/utils/__pycache__/sh_utils.cpython-310.pyc
ADDED
Binary file (2.58 kB). View file
|
|
render_utils/lib/utils/__pycache__/sh_utils.cpython-38.pyc
ADDED
Binary file (2.63 kB). View file
|
|
render_utils/lib/utils/gaussian_np_utils.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from plyfile import PlyData, PlyElement
|
4 |
+
from typing import NamedTuple
|
5 |
+
import smplx
|
6 |
+
import tqdm
|
7 |
+
import cv2 as cv
|
8 |
+
import os
|
9 |
+
|
10 |
+
from scipy.spatial.transform import Rotation as R
|
11 |
+
|
12 |
+
|
13 |
+
class GaussianAttributes(NamedTuple):
|
14 |
+
xyz: np.ndarray
|
15 |
+
opacities: np.ndarray
|
16 |
+
features_dc: np.ndarray
|
17 |
+
features_extra: np.ndarray
|
18 |
+
scales: np.ndarray
|
19 |
+
rot: np.ndarray
|
20 |
+
|
21 |
+
def load_gaussians_from_ply(path):
|
22 |
+
max_sh_degree = 3
|
23 |
+
plydata = PlyData.read(path)
|
24 |
+
|
25 |
+
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
26 |
+
np.asarray(plydata.elements[0]["y"]),
|
27 |
+
np.asarray(plydata.elements[0]["z"])), axis=1)
|
28 |
+
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
|
29 |
+
|
30 |
+
features_dc = np.zeros((xyz.shape[0], 3, 1))
|
31 |
+
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
|
32 |
+
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
|
33 |
+
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
|
34 |
+
|
35 |
+
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
|
36 |
+
extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split('_')[-1]))
|
37 |
+
assert len(extra_f_names) == 3 * (max_sh_degree + 1) ** 2 - 3
|
38 |
+
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
|
39 |
+
for idx, attr_name in enumerate(extra_f_names):
|
40 |
+
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
41 |
+
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
|
42 |
+
features_extra = features_extra.reshape((features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1))
|
43 |
+
|
44 |
+
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
|
45 |
+
scale_names = sorted(scale_names, key=lambda x: int(x.split('_')[-1]))
|
46 |
+
scales = np.zeros((xyz.shape[0], len(scale_names)))
|
47 |
+
for idx, attr_name in enumerate(scale_names):
|
48 |
+
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
49 |
+
|
50 |
+
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
|
51 |
+
rot_names = sorted(rot_names, key=lambda x: int(x.split('_')[-1]))
|
52 |
+
rots = np.zeros((xyz.shape[0], len(rot_names)))
|
53 |
+
for idx, attr_name in enumerate(rot_names):
|
54 |
+
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
55 |
+
|
56 |
+
return GaussianAttributes(xyz, opacities, features_dc, features_extra, scales, rots)
|
57 |
+
|
58 |
+
|
59 |
+
def construct_list_of_attributes(_features_dc, _features_rest, _scaling, _rotation):
|
60 |
+
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
|
61 |
+
# All channels except the 3 DC
|
62 |
+
for i in range(_features_dc.shape[1] * _features_dc.shape[2]):
|
63 |
+
l.append('f_dc_{}'.format(i))
|
64 |
+
for i in range(_features_rest.shape[1] * _features_rest.shape[2]):
|
65 |
+
l.append('f_rest_{}'.format(i))
|
66 |
+
l.append('opacity')
|
67 |
+
for i in range(_scaling.shape[1]):
|
68 |
+
l.append('scale_{}'.format(i))
|
69 |
+
for i in range(_rotation.shape[1]):
|
70 |
+
l.append('rot_{}'.format(i))
|
71 |
+
return l
|
72 |
+
|
73 |
+
|
74 |
+
def select_gaussians(gaussian_attrs, select_mask_or_idx):
|
75 |
+
return GaussianAttributes(
|
76 |
+
xyz=gaussian_attrs.xyz[select_mask_or_idx],
|
77 |
+
opacities=gaussian_attrs.opacities[select_mask_or_idx],
|
78 |
+
features_dc=gaussian_attrs.features_dc[select_mask_or_idx],
|
79 |
+
features_extra=gaussian_attrs.features_extra[select_mask_or_idx],
|
80 |
+
scales=gaussian_attrs.scales[select_mask_or_idx],
|
81 |
+
rot=gaussian_attrs.rot[select_mask_or_idx]
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
def combine_gaussians(gaussian_attrs_list):
|
86 |
+
return GaussianAttributes(
|
87 |
+
xyz=np.concatenate([gau.xyz for gau in gaussian_attrs_list], axis=0),
|
88 |
+
opacities=np.concatenate([gau.opacities for gau in gaussian_attrs_list], axis=0),
|
89 |
+
features_dc=np.concatenate([gau.features_dc for gau in gaussian_attrs_list], axis=0),
|
90 |
+
features_extra=np.concatenate([gau.features_extra for gau in gaussian_attrs_list], axis=0),
|
91 |
+
scales=np.concatenate([gau.scales for gau in gaussian_attrs_list], axis=0),
|
92 |
+
rot=np.concatenate([gau.rot for gau in gaussian_attrs_list], axis=0),
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
def apply_transformation_to_gaussians(gaussian_attrs, spatial_transformation, color_transformation=None):
|
97 |
+
xyzs = np.copy(gaussian_attrs.xyz)
|
98 |
+
xyzs = np.matmul(xyzs, spatial_transformation[:3, :3].transpose()) + spatial_transformation[:3, 3].reshape([1, 3])
|
99 |
+
|
100 |
+
gaussian_rotmats = R.from_quat(gaussian_attrs.rot[:, (1, 2, 3, 0)]).as_matrix()
|
101 |
+
new_rots = []
|
102 |
+
for rotmat in gaussian_rotmats:
|
103 |
+
rotmat = np.matmul(spatial_transformation[:3, :3], rotmat)
|
104 |
+
rotq = R.from_matrix(rotmat).as_quat()
|
105 |
+
rotq = np.array([rotq[3], rotq[0], rotq[1], rotq[2]])
|
106 |
+
new_rots.append(rotq)
|
107 |
+
new_rots = np.stack(new_rots, axis=0)
|
108 |
+
if color_transformation is not None:
|
109 |
+
if color_transformation.shape[0] == 3 and color_transformation.shape[1] == 3:
|
110 |
+
new_clrs = np.matmul(gaussian_attrs.features_dc[:, :, 0], color_transformation)[:, :, np.newaxis]
|
111 |
+
elif color_transformation.shape[0] == 4 and color_transformation.shape[1] == 4:
|
112 |
+
clrs = gaussian_attrs.features_dc[:, :, 0]
|
113 |
+
clrs = np.concatenate([clrs, np.ones_like(clrs[:, :1])], axis=1)
|
114 |
+
new_clrs = np.matmul(clrs, color_transformation)
|
115 |
+
new_clrs = new_clrs[:, :3, np.newaxis]
|
116 |
+
else:
|
117 |
+
new_clrs = gaussian_attrs.features_dc
|
118 |
+
|
119 |
+
return GaussianAttributes(
|
120 |
+
xyz=xyzs,
|
121 |
+
opacities=gaussian_attrs.opacities,
|
122 |
+
features_dc=new_clrs,
|
123 |
+
features_extra=gaussian_attrs.features_extra,
|
124 |
+
scales=gaussian_attrs.scales,
|
125 |
+
rot=new_rots,
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
def update_gaussian_attributes(
|
130 |
+
orig_gaussian,
|
131 |
+
new_xyz=None, new_rgb=None, new_rot=None, new_opacity=None, new_scale=None):
|
132 |
+
return GaussianAttributes(
|
133 |
+
xyz=orig_gaussian.xyz if new_xyz is None else new_xyz,
|
134 |
+
opacities=orig_gaussian.opacities if new_opacity is None else new_opacity,
|
135 |
+
features_dc=orig_gaussian.features_dc if new_rgb is None else new_rgb,
|
136 |
+
features_extra=orig_gaussian.features_extra,
|
137 |
+
scales=orig_gaussian.scales if new_scale is None else new_scale,
|
138 |
+
rot=orig_gaussian.rot if new_rot is None else new_rot,
|
139 |
+
)
|
140 |
+
|
141 |
+
|
142 |
+
def save_gaussians_as_ply(path, gaussian_attrs: GaussianAttributes):
|
143 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
144 |
+
|
145 |
+
xyz = gaussian_attrs.xyz
|
146 |
+
normals = np.zeros_like(xyz)
|
147 |
+
features_dc = gaussian_attrs.features_dc
|
148 |
+
features_rest = gaussian_attrs.features_extra
|
149 |
+
opacities = gaussian_attrs.opacities
|
150 |
+
scale = gaussian_attrs.scales
|
151 |
+
rotation = gaussian_attrs.rot
|
152 |
+
|
153 |
+
dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes(features_dc, features_rest, scale, rotation)]
|
154 |
+
|
155 |
+
elements = np.empty(xyz.shape[0], dtype = dtype_full)
|
156 |
+
attributes = np.concatenate((xyz, normals, features_dc.reshape(features_dc.shape[0], -1),
|
157 |
+
features_rest.reshape(features_rest.shape[0], -1),
|
158 |
+
opacities, scale, rotation), axis=1)
|
159 |
+
elements[:] = list(map(tuple, attributes))
|
160 |
+
el = PlyElement.describe(elements, 'vertex')
|
161 |
+
PlyData([el]).write(path)
|
162 |
+
return
|
render_utils/lib/utils/geometry.py
ADDED
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import trimesh
|
7 |
+
import scipy.sparse as sp
|
8 |
+
import collections
|
9 |
+
|
10 |
+
|
11 |
+
def rodrigues(theta):
|
12 |
+
"""Convert axis-angle representation to rotation matrix.
|
13 |
+
Args:
|
14 |
+
theta: size = [B, 3]
|
15 |
+
Returns:
|
16 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
17 |
+
"""
|
18 |
+
l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
|
19 |
+
angle = torch.unsqueeze(l1norm, -1)
|
20 |
+
normalized = torch.div(theta, angle)
|
21 |
+
angle = angle * 0.5
|
22 |
+
v_cos = torch.cos(angle)
|
23 |
+
v_sin = torch.sin(angle)
|
24 |
+
quat = torch.cat([v_cos, v_sin * normalized], dim=1)
|
25 |
+
return quat2mat(quat)
|
26 |
+
|
27 |
+
|
28 |
+
def quat2mat(quat):
|
29 |
+
"""Convert quaternion coefficients to rotation matrix.
|
30 |
+
Args:
|
31 |
+
quat: size = [B, 4] 4 <===>(w, x, y, z)
|
32 |
+
Returns:
|
33 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
34 |
+
"""
|
35 |
+
norm_quat = quat
|
36 |
+
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
|
37 |
+
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
|
38 |
+
|
39 |
+
B = quat.size(0)
|
40 |
+
|
41 |
+
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
42 |
+
wx, wy, wz = w * x, w * y, w * z
|
43 |
+
xy, xz, yz = x * y, x * z, y * z
|
44 |
+
|
45 |
+
rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
|
46 |
+
2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
|
47 |
+
2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
|
48 |
+
return rotMat
|
49 |
+
|
50 |
+
|
51 |
+
def inv_4x4(mats):
|
52 |
+
"""Calculate the inverse of homogeneous transformations
|
53 |
+
:param mats: [B, 4, 4]
|
54 |
+
:return:
|
55 |
+
"""
|
56 |
+
Rs = mats[:, :3, :3]
|
57 |
+
ts = mats[:, :3, 3:]
|
58 |
+
# R_invs = torch.transpose(Rs, 1, 2)
|
59 |
+
R_invs = torch.inverse(Rs)
|
60 |
+
t_invs = -torch.matmul(R_invs, ts)
|
61 |
+
Rt_invs = torch.cat([R_invs, t_invs], dim=-1) # [B, 3, 4]
|
62 |
+
|
63 |
+
device = R_invs.device
|
64 |
+
pad_row = torch.FloatTensor([0, 0, 0, 1]).to(device).view(1, 1, 4).expand(Rs.shape[0], -1, -1)
|
65 |
+
mat_invs = torch.cat([Rt_invs, pad_row], dim=1)
|
66 |
+
return mat_invs
|
67 |
+
|
68 |
+
|
69 |
+
def as_mesh(scene_or_mesh):
|
70 |
+
"""
|
71 |
+
Convert a possible scene to a mesh.
|
72 |
+
|
73 |
+
If conversion occurs, the returned mesh has only vertex and face data.
|
74 |
+
"""
|
75 |
+
if isinstance(scene_or_mesh, trimesh.Scene):
|
76 |
+
if len(scene_or_mesh.geometry) == 0:
|
77 |
+
mesh = None # empty scene
|
78 |
+
else:
|
79 |
+
# we lose texture information here
|
80 |
+
mesh = trimesh.util.concatenate(
|
81 |
+
tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
|
82 |
+
for g in scene_or_mesh.geometry.values()))
|
83 |
+
else:
|
84 |
+
assert(isinstance(scene_or_mesh, trimesh.Trimesh))
|
85 |
+
mesh = scene_or_mesh
|
86 |
+
return mesh
|
87 |
+
|
88 |
+
|
89 |
+
def get_edge_unique(faces):
|
90 |
+
"""
|
91 |
+
Parameters
|
92 |
+
------------
|
93 |
+
faces: n x 3 int array
|
94 |
+
Should be from a watertight mesh without degenerated triangles and intersection
|
95 |
+
"""
|
96 |
+
faces = np.asanyarray(faces)
|
97 |
+
|
98 |
+
# each face has three edges
|
99 |
+
edges = faces[:, [0, 1, 1, 2, 2, 0]].reshape((-1, 2))
|
100 |
+
flags = edges[:, 0] < edges[:, 1]
|
101 |
+
edges = edges[flags]
|
102 |
+
return edges
|
103 |
+
|
104 |
+
|
105 |
+
def get_neighbors(edges):
|
106 |
+
neighbors = collections.defaultdict(set)
|
107 |
+
[(neighbors[edge[0]].add(edge[1]),
|
108 |
+
neighbors[edge[1]].add(edge[0]))
|
109 |
+
for edge in edges]
|
110 |
+
|
111 |
+
max_index = edges.max() + 1
|
112 |
+
array = [list(neighbors[i]) for i in range(max_index)]
|
113 |
+
|
114 |
+
return array
|
115 |
+
|
116 |
+
|
117 |
+
def construct_degree_matrix(vnum, faces):
|
118 |
+
row = col = list(range(vnum))
|
119 |
+
value = [0] * vnum
|
120 |
+
es = get_edge_unique(faces)
|
121 |
+
for e in es:
|
122 |
+
if e[0] < e[1]:
|
123 |
+
value[e[0]] += 1
|
124 |
+
value[e[0]] += 1
|
125 |
+
|
126 |
+
dm = sp.coo_matrix((value, (row, col)), shape=(vnum, vnum), dtype=np.float32)
|
127 |
+
return dm
|
128 |
+
|
129 |
+
|
130 |
+
def construct_neighborhood_matrix(vnum, faces):
|
131 |
+
row = list()
|
132 |
+
col = list()
|
133 |
+
value = list()
|
134 |
+
es = get_edge_unique(faces)
|
135 |
+
for e in es:
|
136 |
+
if e[0] < e[1]:
|
137 |
+
row.append(e[0])
|
138 |
+
col.append(e[1])
|
139 |
+
value.append(1)
|
140 |
+
row.append(e[1])
|
141 |
+
col.append(e[0])
|
142 |
+
value.append(1)
|
143 |
+
|
144 |
+
nm = sp.coo_matrix((value, (row, col)), shape=(vnum, vnum), dtype=np.float32)
|
145 |
+
return nm
|
146 |
+
|
147 |
+
|
148 |
+
def construct_laplacian_matrix(vnum, faces, normalized=False):
|
149 |
+
edges = get_edge_unique(faces)
|
150 |
+
neighbors = get_neighbors(edges)
|
151 |
+
|
152 |
+
col = np.concatenate(neighbors)
|
153 |
+
row = np.concatenate([[i] * len(n)
|
154 |
+
for i, n in enumerate(neighbors)])
|
155 |
+
col = np.concatenate([col, np.arange(0, vnum)])
|
156 |
+
row = np.concatenate([row, np.arange(0, vnum)])
|
157 |
+
|
158 |
+
if normalized:
|
159 |
+
data = [[1.0 / len(n)] * len(n) for n in neighbors]
|
160 |
+
data += [[-1.0] * vnum]
|
161 |
+
else:
|
162 |
+
data = [[1.0] * len(n) for n in neighbors]
|
163 |
+
data += [[-len(n) for n in neighbors]]
|
164 |
+
|
165 |
+
data = np.concatenate(data)
|
166 |
+
# create the sparse matrix
|
167 |
+
matrix = sp.coo_matrix((data, (row, col)), shape=[vnum] * 2)
|
168 |
+
return matrix
|
169 |
+
|
170 |
+
|
171 |
+
def rotationx_4x4(theta):
|
172 |
+
return np.array([
|
173 |
+
[1.0, 0.0, 0.0, 0.0],
|
174 |
+
[0.0, np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi), 0.0],
|
175 |
+
[0.0, -np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi), 0.0],
|
176 |
+
[0.0, 0.0, 0.0, 1.0]
|
177 |
+
])
|
178 |
+
|
179 |
+
|
180 |
+
def rotationy_4x4(theta):
|
181 |
+
return np.array([
|
182 |
+
[np.cos(theta / 180 * np.pi), 0.0, np.sin(theta / 180 * np.pi), 0.0],
|
183 |
+
[0.0, 1.0, 0.0, 0.0],
|
184 |
+
[-np.sin(theta / 180 * np.pi), 0.0, np.cos(theta / 180 * np.pi), 0.0],
|
185 |
+
[0.0, 0.0, 0.0, 1.0]
|
186 |
+
])
|
187 |
+
|
188 |
+
|
189 |
+
def rotationz_4x4(theta):
|
190 |
+
return np.array([
|
191 |
+
[np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi), 0.0, 0.0],
|
192 |
+
[-np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi), 0.0, 0.0],
|
193 |
+
[0.0, 0.0, 1.0, 0.0],
|
194 |
+
[0.0, 0.0, 0.0, 1.0]
|
195 |
+
])
|
196 |
+
|
197 |
+
|
198 |
+
def rotationx_3x3(theta):
|
199 |
+
return np.array([
|
200 |
+
[1.0, 0.0, 0.0],
|
201 |
+
[0.0, np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi)],
|
202 |
+
[0.0, -np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi)],
|
203 |
+
])
|
204 |
+
|
205 |
+
|
206 |
+
def rotationy_3x3(theta):
|
207 |
+
return np.array([
|
208 |
+
[np.cos(theta / 180 * np.pi), 0.0, np.sin(theta / 180 * np.pi)],
|
209 |
+
[0.0, 1.0, 0.0],
|
210 |
+
[-np.sin(theta / 180 * np.pi), 0.0, np.cos(theta / 180 * np.pi)],
|
211 |
+
])
|
212 |
+
|
213 |
+
|
214 |
+
def rotationz_3x3(theta):
|
215 |
+
return np.array([
|
216 |
+
[np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi), 0.0],
|
217 |
+
[-np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi), 0.0],
|
218 |
+
[0.0, 0.0, 1.0],
|
219 |
+
])
|
220 |
+
|
221 |
+
|
222 |
+
def generate_point_grids(vol_res):
|
223 |
+
x_coords = np.array(range(0, vol_res), dtype=np.float32)
|
224 |
+
y_coords = np.array(range(0, vol_res), dtype=np.float32)
|
225 |
+
z_coords = np.array(range(0, vol_res), dtype=np.float32)
|
226 |
+
|
227 |
+
yv, xv, zv = np.meshgrid(x_coords, y_coords, z_coords)
|
228 |
+
xv = np.reshape(xv, (-1, 1))
|
229 |
+
yv = np.reshape(yv, (-1, 1))
|
230 |
+
zv = np.reshape(zv, (-1, 1))
|
231 |
+
pts = np.concatenate([xv, yv, zv], axis=-1)
|
232 |
+
pts = pts.astype(np.float32)
|
233 |
+
return pts
|
234 |
+
|
235 |
+
|
236 |
+
def infer_occupancy_value_grid_octree(test_res, pts, query_fn, init_res=64, ignore_thres=0.05):
|
237 |
+
pts = np.reshape(pts, (test_res, test_res, test_res, 3))
|
238 |
+
|
239 |
+
pts_ov = np.zeros([test_res, test_res, test_res])
|
240 |
+
dirty = np.ones_like(pts_ov, dtype=np.bool)
|
241 |
+
grid_mask = np.zeros_like(pts_ov, dtype=np.bool)
|
242 |
+
|
243 |
+
reso = test_res // init_res
|
244 |
+
while reso > 0:
|
245 |
+
grid_mask[0:test_res:reso, 0:test_res:reso, 0:test_res:reso] = True
|
246 |
+
test_mask = np.logical_and(grid_mask, dirty)
|
247 |
+
|
248 |
+
pts_ = pts[test_mask]
|
249 |
+
pts_ov[test_mask] = np.reshape(query_fn(pts_), pts_ov[test_mask].shape)
|
250 |
+
|
251 |
+
if reso <= 1:
|
252 |
+
break
|
253 |
+
for x in range(0, test_res - reso, reso):
|
254 |
+
for y in range(0, test_res - reso, reso):
|
255 |
+
for z in range(0, test_res - reso, reso):
|
256 |
+
# if center marked, return
|
257 |
+
if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]:
|
258 |
+
continue
|
259 |
+
v0 = pts_ov[x, y, z]
|
260 |
+
v1 = pts_ov[x, y, z + reso]
|
261 |
+
v2 = pts_ov[x, y + reso, z]
|
262 |
+
v3 = pts_ov[x, y + reso, z + reso]
|
263 |
+
v4 = pts_ov[x + reso, y, z]
|
264 |
+
v5 = pts_ov[x + reso, y, z + reso]
|
265 |
+
v6 = pts_ov[x + reso, y + reso, z]
|
266 |
+
v7 = pts_ov[x + reso, y + reso, z + reso]
|
267 |
+
v = np.array([v0, v1, v2, v3, v4, v5, v6, v7])
|
268 |
+
v_min = np.min(v)
|
269 |
+
v_max = np.max(v)
|
270 |
+
# this cell is all the same
|
271 |
+
if (v_max - v_min) < ignore_thres:
|
272 |
+
pts_ov[x:x + reso, y:y + reso, z:z + reso] = (v_max + v_min) / 2
|
273 |
+
dirty[x:x + reso, y:y + reso, z:z + reso] = False
|
274 |
+
reso //= 2
|
275 |
+
return pts_ov
|
276 |
+
|
277 |
+
|
278 |
+
def batch_rod2quat(rot_vecs):
|
279 |
+
batch_size = rot_vecs.shape[0]
|
280 |
+
|
281 |
+
angle = torch.norm(rot_vecs + 1e-16, dim=1, keepdim=True)
|
282 |
+
rot_dir = rot_vecs / angle
|
283 |
+
|
284 |
+
cos = torch.cos(angle / 2)
|
285 |
+
sin = torch.sin(angle / 2)
|
286 |
+
|
287 |
+
# Bx1 arrays
|
288 |
+
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
|
289 |
+
|
290 |
+
qx = rx * sin
|
291 |
+
qy = ry * sin
|
292 |
+
qz = rz * sin
|
293 |
+
qw = cos-1.0
|
294 |
+
|
295 |
+
return torch.cat([qx,qy,qz,qw], dim=1)
|
296 |
+
|
297 |
+
|
298 |
+
def batch_quat2matrix(rvec):
|
299 |
+
'''
|
300 |
+
args:
|
301 |
+
rvec: (B, N, 4)
|
302 |
+
'''
|
303 |
+
B, N, _ = rvec.size()
|
304 |
+
|
305 |
+
theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=2))
|
306 |
+
rvec = rvec / theta[:, :, None]
|
307 |
+
return torch.stack((
|
308 |
+
1. - 2. * rvec[:, :, 1] ** 2 - 2. * rvec[:, :, 2] ** 2,
|
309 |
+
2. * (rvec[:, :, 0] * rvec[:, :, 1] - rvec[:, :, 2] * rvec[:, :, 3]),
|
310 |
+
2. * (rvec[:, :, 0] * rvec[:, :, 2] + rvec[:, :, 1] * rvec[:, :, 3]),
|
311 |
+
|
312 |
+
2. * (rvec[:, :, 0] * rvec[:, :, 1] + rvec[:, :, 2] * rvec[:, :, 3]),
|
313 |
+
1. - 2. * rvec[:, :, 0] ** 2 - 2. * rvec[:, :, 2] ** 2,
|
314 |
+
2. * (rvec[:, :, 1] * rvec[:, :, 2] - rvec[:, :, 0] * rvec[:, :, 3]),
|
315 |
+
|
316 |
+
2. * (rvec[:, :, 0] * rvec[:, :, 2] - rvec[:, :, 1] * rvec[:, :, 3]),
|
317 |
+
2. * (rvec[:, :, 0] * rvec[:, :, 3] + rvec[:, :, 1] * rvec[:, :, 2]),
|
318 |
+
1. - 2. * rvec[:, :, 0] ** 2 - 2. * rvec[:, :, 1] ** 2
|
319 |
+
), dim=2).view(B, N, 3, 3)
|
320 |
+
|
321 |
+
|
322 |
+
def get_posemap(map_type, n_joints, parents, n_traverse=1, normalize=True):
|
323 |
+
pose_map = torch.zeros(n_joints,n_joints-1)
|
324 |
+
if map_type == 'parent':
|
325 |
+
for i in range(n_joints-1):
|
326 |
+
pose_map[i+1,i] = 1.0
|
327 |
+
elif map_type == 'children':
|
328 |
+
for i in range(n_joints-1):
|
329 |
+
parent = parents[i+1]
|
330 |
+
for j in range(n_traverse):
|
331 |
+
pose_map[parent, i] += 1.0
|
332 |
+
if parent == 0:
|
333 |
+
break
|
334 |
+
parent = parents[parent]
|
335 |
+
if normalize:
|
336 |
+
pose_map /= pose_map.sum(0,keepdim=True)+1e-16
|
337 |
+
elif map_type == 'both':
|
338 |
+
for i in range(n_joints-1):
|
339 |
+
pose_map[i+1,i] += 1.0
|
340 |
+
parent = parents[i+1]
|
341 |
+
for j in range(n_traverse):
|
342 |
+
pose_map[parent, i] += 1.0
|
343 |
+
if parent == 0:
|
344 |
+
break
|
345 |
+
parent = parents[parent]
|
346 |
+
if normalize:
|
347 |
+
pose_map /= pose_map.sum(0,keepdim=True)+1e-16
|
348 |
+
else:
|
349 |
+
raise NotImplementedError('unsupported pose map type [%s]' % map_type)
|
350 |
+
pose_map = torch.cat([torch.zeros(n_joints, 1), pose_map], dim=1)
|
351 |
+
return pose_map
|
352 |
+
|
353 |
+
|
354 |
+
def vertices_to_triangles(vertices, faces):
|
355 |
+
"""
|
356 |
+
:param vertices: [batch size, number of vertices, 3]
|
357 |
+
:param faces: [batch size, number of faces, 3)
|
358 |
+
:return: [batch size, number of faces, 3, 3]
|
359 |
+
"""
|
360 |
+
assert (vertices.ndimension() == 3)
|
361 |
+
assert (faces.ndimension() == 3)
|
362 |
+
assert (vertices.shape[0] == faces.shape[0])
|
363 |
+
assert (vertices.shape[2] == 3)
|
364 |
+
assert (faces.shape[2] == 3)
|
365 |
+
|
366 |
+
bs, nv = vertices.shape[:2]
|
367 |
+
bs, nf = faces.shape[:2]
|
368 |
+
device = vertices.device
|
369 |
+
faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
|
370 |
+
vertices = vertices.reshape((bs * nv, 3))
|
371 |
+
# pytorch only supports long and byte tensors for indexing
|
372 |
+
return vertices[faces.long()]
|
373 |
+
|
374 |
+
|
375 |
+
def calc_face_normals(vertices, faces):
|
376 |
+
assert len(vertices.shape) == 3
|
377 |
+
assert len(faces.shape) == 2
|
378 |
+
if isinstance(faces, np.ndarray):
|
379 |
+
faces = torch.from_numpy(faces.astype(np.int64)).to(vertices.device)
|
380 |
+
|
381 |
+
batch_size, pt_num = vertices.shape[:2]
|
382 |
+
face_num = faces.shape[0]
|
383 |
+
|
384 |
+
triangles = vertices_to_triangles(vertices, faces.unsqueeze(0).expand(batch_size, -1, -1))
|
385 |
+
triangles = triangles.reshape((batch_size * face_num, 3, 3))
|
386 |
+
v10 = triangles[:, 0] - triangles[:, 1]
|
387 |
+
v12 = triangles[:, 2] - triangles[:, 1]
|
388 |
+
# pytorch normalize divides by max(norm, eps) instead of (norm+eps) in chainer
|
389 |
+
normals = F.normalize(torch.cross(v10, v12), eps=1e-5)
|
390 |
+
normals = normals.reshape((batch_size, face_num, 3))
|
391 |
+
|
392 |
+
return normals
|
393 |
+
|
394 |
+
|
395 |
+
def calc_vert_normals(vertices, faces):
|
396 |
+
"""
|
397 |
+
vertices: [B, N, 3]
|
398 |
+
faces: [F, 3]
|
399 |
+
"""
|
400 |
+
normals = torch.zeros_like(vertices)
|
401 |
+
v0s = torch.index_select(vertices, dim=1, index=faces[:, 0]) # [B, F, 3]
|
402 |
+
v1s = torch.index_select(vertices, dim=1, index=faces[:, 1])
|
403 |
+
v2s = torch.index_select(vertices, dim=1, index=faces[:, 2])
|
404 |
+
normals = torch.index_add(normals, dim=1, index=faces[:, 1], source=torch.cross(v2s-v1s, v0s-v1s, dim=-1))
|
405 |
+
normals = torch.index_add(normals, dim=1, index=faces[:, 2], source=torch.cross(v0s-v2s, v1s-v2s, dim=-1))
|
406 |
+
normals = torch.index_add(normals, dim=1, index=faces[:, 0], source=torch.cross(v1s-v0s, v2s-v0s, dim=-1))
|
407 |
+
normals = F.normalize(normals, dim=-1)
|
408 |
+
return normals
|
409 |
+
|
410 |
+
|
411 |
+
def calc_vert_normals_numpy(vertices, faces):
|
412 |
+
assert len(vertices.shape) == 2
|
413 |
+
assert len(faces.shape) == 2
|
414 |
+
|
415 |
+
nmls = np.zeros_like(vertices)
|
416 |
+
fv0 = vertices[faces[:, 0]]
|
417 |
+
fv1 = vertices[faces[:, 1]]
|
418 |
+
fv2 = vertices[faces[:, 2]]
|
419 |
+
face_nmls = np.cross(fv1-fv0, fv2-fv0, axis=-1)
|
420 |
+
face_nmls = face_nmls / (np.linalg.norm(face_nmls, axis=-1, keepdims=True) + 1e-20)
|
421 |
+
for f, fn in zip(faces, face_nmls):
|
422 |
+
nmls[f] += fn
|
423 |
+
nmls = nmls / (np.linalg.norm(nmls, axis=-1, keepdims=True) + 1e-20)
|
424 |
+
return nmls
|
425 |
+
|
426 |
+
|
427 |
+
def glUV2torchUV(gl_uv):
|
428 |
+
torch_uv = torch.stack([
|
429 |
+
gl_uv[..., 0]*2.0-1.0,
|
430 |
+
gl_uv[..., 1]*-2.0+1.0
|
431 |
+
], dim=-1)
|
432 |
+
return torch_uv
|
433 |
+
|
434 |
+
|
435 |
+
def normalize_vert_bbox(verts, dim=-1, per_axis=False):
|
436 |
+
bbox_min = torch.min(verts, dim=dim, keepdim=True)[0]
|
437 |
+
bbox_max = torch.max(verts, dim=dim, keepdim=True)[0]
|
438 |
+
verts = verts - 0.5 * (bbox_max + bbox_min)
|
439 |
+
if per_axis:
|
440 |
+
verts = 2 * verts / (bbox_max - bbox_min)
|
441 |
+
else:
|
442 |
+
verts = 2 * verts / torch.max(bbox_max-bbox_min, dim=dim, keepdim=True)[0]
|
443 |
+
return verts
|
444 |
+
|
445 |
+
|
446 |
+
def upsample_sdf_volume(sdf, upsample_factor):
|
447 |
+
assert sdf.shape[0] == sdf.shape[1] == sdf.shape[2]
|
448 |
+
coarse_resolution = sdf.shape[0]
|
449 |
+
fine_resolution = coarse_resolution * upsample_factor
|
450 |
+
|
451 |
+
sdf_interp_buffer = np.zeros([2, 2, 2, coarse_resolution, coarse_resolution, coarse_resolution],
|
452 |
+
dtype=np.float32)
|
453 |
+
dx_list = [0, 1, 0, 1, 0, 1, 0, 1]
|
454 |
+
dy_list = [0, 0, 1, 1, 0, 0, 1, 1]
|
455 |
+
dz_list = [0, 0, 0, 0, 1, 1, 1, 1]
|
456 |
+
for dx, dy, dz in zip(dx_list, dy_list, dz_list):
|
457 |
+
sdf_interp_buffer[dx, dy, dz, :, :, :] = np.roll(sdf, (-dx, -dy, -dz), axis=(0, 1, 2))
|
458 |
+
|
459 |
+
sdf_fine = np.zeros([fine_resolution, fine_resolution, fine_resolution], dtype=np.float32)
|
460 |
+
for dx in range(upsample_factor):
|
461 |
+
for dy in range(upsample_factor):
|
462 |
+
for dz in range(upsample_factor):
|
463 |
+
wx = (1.0 - dx / upsample_factor)
|
464 |
+
wy = (1.0 - dy / upsample_factor)
|
465 |
+
wz = (1.0 - dz / upsample_factor)
|
466 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
467 |
+
wx * wy * wz * sdf_interp_buffer[0, 0, 0]
|
468 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
469 |
+
(1.0 - wx) * wy * wz * sdf_interp_buffer[1, 0, 0]
|
470 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
471 |
+
wx * (1.0 - wy) * wz * sdf_interp_buffer[0, 1, 0]
|
472 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
473 |
+
(1.0 - wx) * (1.0 - wy) * wz * sdf_interp_buffer[1, 1, 0]
|
474 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
475 |
+
wx * wy * (1.0 - wz) * sdf_interp_buffer[0, 0, 1]
|
476 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
477 |
+
(1.0 - wx) * wy * (1.0 - wz) * sdf_interp_buffer[1, 0, 1]
|
478 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
479 |
+
wx * (1.0 - wy) * (1.0 - wz) * sdf_interp_buffer[0, 1, 1]
|
480 |
+
sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
|
481 |
+
(1.0 - wx) * (1.0 - wy) * (1.0 - wz) * sdf_interp_buffer[1, 1, 1]
|
482 |
+
return sdf_fine
|
483 |
+
|
484 |
+
|
485 |
+
def search_nearest_correspondence(src, tgt):
|
486 |
+
tgt_idx = np.zeros(len(src), dtype=np.int32)
|
487 |
+
tgt_dist = np.zeros(len(src), dtype=np.float32)
|
488 |
+
for i in range(len(src)):
|
489 |
+
dist = np.linalg.norm(tgt - src[i:(i+1)], axis=1, keepdims=False)
|
490 |
+
tgt_idx[i] = np.argmin(dist)
|
491 |
+
tgt_dist[i] = dist[tgt_idx[i]]
|
492 |
+
return tgt_idx, tgt_dist
|
493 |
+
|
494 |
+
|
495 |
+
def estimate_rigid_transformation(src, tgt):
|
496 |
+
src = src.transpose()
|
497 |
+
tgt = tgt.transpose()
|
498 |
+
mu1, mu2 = src.mean(axis=1, keepdims=True), tgt.mean(axis=1, keepdims=True)
|
499 |
+
X1, X2 = src - mu1, tgt - mu2
|
500 |
+
|
501 |
+
K = X1.dot(X2.T)
|
502 |
+
U, s, Vh = np.linalg.svd(K)
|
503 |
+
V = Vh.T
|
504 |
+
Z = np.eye(U.shape[0])
|
505 |
+
Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
|
506 |
+
R = V.dot(Z.dot(U.T))
|
507 |
+
t = mu2 - R.dot(mu1)
|
508 |
+
|
509 |
+
# orient, _ = cv.Rodrigues(R)
|
510 |
+
# orient = orient.reshape([-1])
|
511 |
+
# t = t.reshape([-1])
|
512 |
+
# return orient, t
|
513 |
+
|
514 |
+
transf = np.eye(4, dtype=np.float32)
|
515 |
+
transf[:3, :3] = R
|
516 |
+
transf[:3, 3] = t.reshape([-1])
|
517 |
+
return transf
|
render_utils/lib/utils/graphics_utils.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
import numpy as np
|
15 |
+
from typing import NamedTuple
|
16 |
+
|
17 |
+
class BasicPointCloud(NamedTuple):
|
18 |
+
points : np.array
|
19 |
+
colors : np.array
|
20 |
+
normals : np.array
|
21 |
+
|
22 |
+
def geom_transform_points(points, transf_matrix):
|
23 |
+
P, _ = points.shape
|
24 |
+
ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
|
25 |
+
points_hom = torch.cat([points, ones], dim=1)
|
26 |
+
points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
|
27 |
+
|
28 |
+
denom = points_out[..., 3:] + 0.0000001
|
29 |
+
return (points_out[..., :3] / denom).squeeze(dim=0)
|
30 |
+
|
31 |
+
def getWorld2View(R, t):
|
32 |
+
Rt = np.zeros((4, 4))
|
33 |
+
Rt[:3, :3] = R.transpose()
|
34 |
+
Rt[:3, 3] = t
|
35 |
+
Rt[3, 3] = 1.0
|
36 |
+
return np.float32(Rt)
|
37 |
+
|
38 |
+
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
|
39 |
+
Rt = np.zeros((4, 4))
|
40 |
+
Rt[:3, :3] = R.transpose()
|
41 |
+
Rt[:3, 3] = t
|
42 |
+
Rt[3, 3] = 1.0
|
43 |
+
|
44 |
+
C2W = np.linalg.inv(Rt)
|
45 |
+
cam_center = C2W[:3, 3]
|
46 |
+
cam_center = (cam_center + translate) * scale
|
47 |
+
C2W[:3, 3] = cam_center
|
48 |
+
Rt = np.linalg.inv(C2W)
|
49 |
+
return np.float32(Rt)
|
50 |
+
|
51 |
+
def getProjectionMatrix(znear, zfar, fovX, fovY, K = None, img_h = None, img_w = None):
|
52 |
+
if K is None:
|
53 |
+
tanHalfFovY = math.tan((fovY / 2))
|
54 |
+
tanHalfFovX = math.tan((fovX / 2))
|
55 |
+
top = tanHalfFovY * znear
|
56 |
+
bottom = -top
|
57 |
+
right = tanHalfFovX * znear
|
58 |
+
left = -right
|
59 |
+
else:
|
60 |
+
near_fx = znear / K[0, 0]
|
61 |
+
near_fy = znear / K[1, 1]
|
62 |
+
|
63 |
+
left = - (img_w - K[0, 2]) * near_fx
|
64 |
+
right = K[0, 2] * near_fx
|
65 |
+
bottom = (K[1, 2] - img_h) * near_fy
|
66 |
+
top = K[1, 2] * near_fy
|
67 |
+
|
68 |
+
P = torch.zeros(4, 4)
|
69 |
+
|
70 |
+
z_sign = 1.0
|
71 |
+
|
72 |
+
P[0, 0] = 2.0 * znear / (right - left)
|
73 |
+
P[1, 1] = 2.0 * znear / (top - bottom)
|
74 |
+
P[0, 2] = (right + left) / (right - left)
|
75 |
+
P[1, 2] = (top + bottom) / (top - bottom)
|
76 |
+
P[3, 2] = z_sign
|
77 |
+
P[2, 2] = z_sign * zfar / (zfar - znear)
|
78 |
+
P[2, 3] = -(zfar * znear) / (zfar - znear)
|
79 |
+
return P
|
80 |
+
|
81 |
+
def fov2focal(fov, pixels):
|
82 |
+
return pixels / (2 * math.tan(fov / 2))
|
83 |
+
|
84 |
+
def focal2fov(focal, pixels):
|
85 |
+
return 2*math.atan(pixels/(2*focal))
|
86 |
+
|
87 |
+
|
88 |
+
def _so3_exp_map(
|
89 |
+
log_rot: torch.Tensor, eps: float = 0.0001
|
90 |
+
):
|
91 |
+
"""
|
92 |
+
A helper function that computes the so3 exponential map and,
|
93 |
+
apart from the rotation matrix, also returns intermediate variables
|
94 |
+
that can be re-used in other functions.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def hat(v: torch.Tensor) -> torch.Tensor:
|
98 |
+
"""
|
99 |
+
Compute the Hat operator [1] of a batch of 3D vectors.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
v: Batch of vectors of shape `(minibatch , 3)`.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
Batch of skew-symmetric matrices of shape
|
106 |
+
`(minibatch, 3 , 3)` where each matrix is of the form:
|
107 |
+
`[ 0 -v_z v_y ]
|
108 |
+
[ v_z 0 -v_x ]
|
109 |
+
[ -v_y v_x 0 ]`
|
110 |
+
|
111 |
+
Raises:
|
112 |
+
ValueError if `v` is of incorrect shape.
|
113 |
+
|
114 |
+
[1] https://en.wikipedia.org/wiki/Hat_operator
|
115 |
+
"""
|
116 |
+
|
117 |
+
N, dim = v.shape
|
118 |
+
if dim != 3:
|
119 |
+
raise ValueError("Input vectors have to be 3-dimensional.")
|
120 |
+
|
121 |
+
h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device)
|
122 |
+
|
123 |
+
x, y, z = v.unbind(1)
|
124 |
+
|
125 |
+
h[:, 0, 1] = -z
|
126 |
+
h[:, 0, 2] = y
|
127 |
+
h[:, 1, 0] = z
|
128 |
+
h[:, 1, 2] = -x
|
129 |
+
h[:, 2, 0] = -y
|
130 |
+
h[:, 2, 1] = x
|
131 |
+
|
132 |
+
return h
|
133 |
+
|
134 |
+
_, dim = log_rot.shape
|
135 |
+
if dim != 3:
|
136 |
+
raise ValueError("Input tensor shape has to be Nx3.")
|
137 |
+
|
138 |
+
nrms = (log_rot * log_rot).sum(1)
|
139 |
+
# phis ... rotation angles
|
140 |
+
rot_angles = torch.clamp(nrms, eps).sqrt()
|
141 |
+
rot_angles_inv = 1.0 / rot_angles
|
142 |
+
fac1 = rot_angles_inv * rot_angles.sin()
|
143 |
+
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
|
144 |
+
skews = hat(log_rot)
|
145 |
+
skews_square = torch.bmm(skews, skews)
|
146 |
+
|
147 |
+
R = (
|
148 |
+
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
|
149 |
+
fac1[:, None, None] * skews
|
150 |
+
+ fac2[:, None, None] * skews_square
|
151 |
+
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
|
152 |
+
)
|
153 |
+
|
154 |
+
return R, rot_angles, skews, skews_square
|
155 |
+
|
156 |
+
|
157 |
+
def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
|
158 |
+
"""
|
159 |
+
Convert a batch of logarithmic representations of rotation matrices `log_rot`
|
160 |
+
to a batch of 3x3 rotation matrices using Rodrigues formula [1].
|
161 |
+
|
162 |
+
In the logarithmic representation, each rotation matrix is represented as
|
163 |
+
a 3-dimensional vector (`log_rot`) who's l2-norm and direction correspond
|
164 |
+
to the magnitude of the rotation angle and the axis of rotation respectively.
|
165 |
+
|
166 |
+
The conversion has a singularity around `log(R) = 0`
|
167 |
+
which is handled by clamping controlled with the `eps` argument.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
log_rot: Batch of vectors of shape `(minibatch, 3)`.
|
171 |
+
eps: A float constant handling the conversion singularity.
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
Batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
175 |
+
|
176 |
+
Raises:
|
177 |
+
ValueError if `log_rot` is of incorrect shape.
|
178 |
+
|
179 |
+
[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
|
180 |
+
"""
|
181 |
+
return _so3_exp_map(log_rot, eps=eps)[0]
|
render_utils/lib/utils/rotation_conversions.py
ADDED
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import Optional, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
Device = Union[str, torch.device]
|
13 |
+
|
14 |
+
"""
|
15 |
+
The transformation matrices returned from the functions in this file assume
|
16 |
+
the points on which the transformation will be applied are column vectors.
|
17 |
+
i.e. the R matrix is structured as
|
18 |
+
|
19 |
+
R = [
|
20 |
+
[Rxx, Rxy, Rxz],
|
21 |
+
[Ryx, Ryy, Ryz],
|
22 |
+
[Rzx, Rzy, Rzz],
|
23 |
+
] # (3, 3)
|
24 |
+
|
25 |
+
This matrix can be applied to column vectors by post multiplication
|
26 |
+
by the points e.g.
|
27 |
+
|
28 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
29 |
+
transformed_points = R * points
|
30 |
+
|
31 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
32 |
+
can be transposed and pre multiplied by the points:
|
33 |
+
|
34 |
+
e.g.
|
35 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
36 |
+
transformed_points = points * R.transpose(1, 0)
|
37 |
+
"""
|
38 |
+
|
39 |
+
|
40 |
+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
|
41 |
+
"""
|
42 |
+
Convert rotations given as quaternions to rotation matrices.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
quaternions: quaternions with real part first,
|
46 |
+
as tensor of shape (..., 4).
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
50 |
+
"""
|
51 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
52 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
53 |
+
|
54 |
+
o = torch.stack(
|
55 |
+
(
|
56 |
+
1 - two_s * (j * j + k * k),
|
57 |
+
two_s * (i * j - k * r),
|
58 |
+
two_s * (i * k + j * r),
|
59 |
+
two_s * (i * j + k * r),
|
60 |
+
1 - two_s * (i * i + k * k),
|
61 |
+
two_s * (j * k - i * r),
|
62 |
+
two_s * (i * k - j * r),
|
63 |
+
two_s * (j * k + i * r),
|
64 |
+
1 - two_s * (i * i + j * j),
|
65 |
+
),
|
66 |
+
-1,
|
67 |
+
)
|
68 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
69 |
+
|
70 |
+
|
71 |
+
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
72 |
+
"""
|
73 |
+
Return a tensor where each element has the absolute value taken from the,
|
74 |
+
corresponding element of a, with sign taken from the corresponding
|
75 |
+
element of b. This is like the standard copysign floating-point operation,
|
76 |
+
but is not careful about negative 0 and NaN.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
a: source tensor.
|
80 |
+
b: tensor whose signs will be used, of the same shape as a.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
Tensor of the same shape as a with the signs of b.
|
84 |
+
"""
|
85 |
+
signs_differ = (a < 0) != (b < 0)
|
86 |
+
return torch.where(signs_differ, -a, a)
|
87 |
+
|
88 |
+
|
89 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
90 |
+
"""
|
91 |
+
Returns torch.sqrt(torch.max(0, x))
|
92 |
+
but with a zero subgradient where x is 0.
|
93 |
+
"""
|
94 |
+
ret = torch.zeros_like(x)
|
95 |
+
positive_mask = x > 0
|
96 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
97 |
+
return ret
|
98 |
+
|
99 |
+
|
100 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
101 |
+
"""
|
102 |
+
Convert rotations given as rotation matrices to quaternions.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
109 |
+
"""
|
110 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
111 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
112 |
+
|
113 |
+
batch_dim = matrix.shape[:-2]
|
114 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
115 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
116 |
+
)
|
117 |
+
|
118 |
+
q_abs = _sqrt_positive_part(
|
119 |
+
torch.stack(
|
120 |
+
[
|
121 |
+
1.0 + m00 + m11 + m22,
|
122 |
+
1.0 + m00 - m11 - m22,
|
123 |
+
1.0 - m00 + m11 - m22,
|
124 |
+
1.0 - m00 - m11 + m22,
|
125 |
+
],
|
126 |
+
dim=-1,
|
127 |
+
)
|
128 |
+
)
|
129 |
+
|
130 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
131 |
+
quat_by_rijk = torch.stack(
|
132 |
+
[
|
133 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
134 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
135 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
136 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
137 |
+
],
|
138 |
+
dim=-2,
|
139 |
+
)
|
140 |
+
|
141 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
142 |
+
# the candidate won't be picked.
|
143 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
144 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
145 |
+
|
146 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
147 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
148 |
+
|
149 |
+
return quat_candidates[
|
150 |
+
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
|
151 |
+
].reshape(batch_dim + (4,))
|
152 |
+
|
153 |
+
|
154 |
+
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
|
155 |
+
"""
|
156 |
+
Return the rotation matrices for one of the rotations about an axis
|
157 |
+
of which Euler angles describe, for each value of the angle given.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
axis: Axis label "X" or "Y or "Z".
|
161 |
+
angle: any shape tensor of Euler angles in radians
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
165 |
+
"""
|
166 |
+
|
167 |
+
cos = torch.cos(angle)
|
168 |
+
sin = torch.sin(angle)
|
169 |
+
one = torch.ones_like(angle)
|
170 |
+
zero = torch.zeros_like(angle)
|
171 |
+
|
172 |
+
if axis == "X":
|
173 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
174 |
+
elif axis == "Y":
|
175 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
176 |
+
elif axis == "Z":
|
177 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
178 |
+
else:
|
179 |
+
raise ValueError("letter must be either X, Y or Z.")
|
180 |
+
|
181 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
182 |
+
|
183 |
+
|
184 |
+
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
|
185 |
+
"""
|
186 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
190 |
+
convention: Convention string of three uppercase letters from
|
191 |
+
{"X", "Y", and "Z"}.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
195 |
+
"""
|
196 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
197 |
+
raise ValueError("Invalid input euler angles.")
|
198 |
+
if len(convention) != 3:
|
199 |
+
raise ValueError("Convention must have 3 letters.")
|
200 |
+
if convention[1] in (convention[0], convention[2]):
|
201 |
+
raise ValueError(f"Invalid convention {convention}.")
|
202 |
+
for letter in convention:
|
203 |
+
if letter not in ("X", "Y", "Z"):
|
204 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
205 |
+
matrices = [
|
206 |
+
_axis_angle_rotation(c, e)
|
207 |
+
for c, e in zip(convention, torch.unbind(euler_angles, -1))
|
208 |
+
]
|
209 |
+
# return functools.reduce(torch.matmul, matrices)
|
210 |
+
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
|
211 |
+
|
212 |
+
|
213 |
+
def _angle_from_tan(
|
214 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
215 |
+
) -> torch.Tensor:
|
216 |
+
"""
|
217 |
+
Extract the first or third Euler angle from the two members of
|
218 |
+
the matrix which are positive constant times its sine and cosine.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
222 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
223 |
+
convention.
|
224 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
225 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
226 |
+
which means the relevant entries are in the same row of the
|
227 |
+
rotation matrix. If not, they are in the same column.
|
228 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
Euler Angles in radians for each matrix in data as a tensor
|
232 |
+
of shape (...).
|
233 |
+
"""
|
234 |
+
|
235 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
236 |
+
if horizontal:
|
237 |
+
i2, i1 = i1, i2
|
238 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
239 |
+
if horizontal == even:
|
240 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
241 |
+
if tait_bryan:
|
242 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
243 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
244 |
+
|
245 |
+
|
246 |
+
def _index_from_letter(letter: str) -> int:
|
247 |
+
if letter == "X":
|
248 |
+
return 0
|
249 |
+
if letter == "Y":
|
250 |
+
return 1
|
251 |
+
if letter == "Z":
|
252 |
+
return 2
|
253 |
+
raise ValueError("letter must be either X, Y or Z.")
|
254 |
+
|
255 |
+
|
256 |
+
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
|
257 |
+
"""
|
258 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
262 |
+
convention: Convention string of three uppercase letters.
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
Euler angles in radians as tensor of shape (..., 3).
|
266 |
+
"""
|
267 |
+
if len(convention) != 3:
|
268 |
+
raise ValueError("Convention must have 3 letters.")
|
269 |
+
if convention[1] in (convention[0], convention[2]):
|
270 |
+
raise ValueError(f"Invalid convention {convention}.")
|
271 |
+
for letter in convention:
|
272 |
+
if letter not in ("X", "Y", "Z"):
|
273 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
274 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
275 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
276 |
+
i0 = _index_from_letter(convention[0])
|
277 |
+
i2 = _index_from_letter(convention[2])
|
278 |
+
tait_bryan = i0 != i2
|
279 |
+
if tait_bryan:
|
280 |
+
central_angle = torch.asin(
|
281 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
285 |
+
|
286 |
+
o = (
|
287 |
+
_angle_from_tan(
|
288 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
289 |
+
),
|
290 |
+
central_angle,
|
291 |
+
_angle_from_tan(
|
292 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
293 |
+
),
|
294 |
+
)
|
295 |
+
return torch.stack(o, -1)
|
296 |
+
|
297 |
+
|
298 |
+
def random_quaternions(
|
299 |
+
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
300 |
+
) -> torch.Tensor:
|
301 |
+
"""
|
302 |
+
Generate random quaternions representing rotations,
|
303 |
+
i.e. versors with nonnegative real part.
|
304 |
+
|
305 |
+
Args:
|
306 |
+
n: Number of quaternions in a batch to return.
|
307 |
+
dtype: Type to return.
|
308 |
+
device: Desired device of returned tensor. Default:
|
309 |
+
uses the current device for the default tensor type.
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
Quaternions as tensor of shape (N, 4).
|
313 |
+
"""
|
314 |
+
if isinstance(device, str):
|
315 |
+
device = torch.device(device)
|
316 |
+
o = torch.randn((n, 4), dtype=dtype, device=device)
|
317 |
+
s = (o * o).sum(1)
|
318 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
319 |
+
return o
|
320 |
+
|
321 |
+
|
322 |
+
def random_rotations(
|
323 |
+
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
324 |
+
) -> torch.Tensor:
|
325 |
+
"""
|
326 |
+
Generate random rotations as 3x3 rotation matrices.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
n: Number of rotation matrices in a batch to return.
|
330 |
+
dtype: Type to return.
|
331 |
+
device: Device of returned tensor. Default: if None,
|
332 |
+
uses the current device for the default tensor type.
|
333 |
+
|
334 |
+
Returns:
|
335 |
+
Rotation matrices as tensor of shape (n, 3, 3).
|
336 |
+
"""
|
337 |
+
quaternions = random_quaternions(n, dtype=dtype, device=device)
|
338 |
+
return quaternion_to_matrix(quaternions)
|
339 |
+
|
340 |
+
|
341 |
+
def random_rotation(
|
342 |
+
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
343 |
+
) -> torch.Tensor:
|
344 |
+
"""
|
345 |
+
Generate a single random 3x3 rotation matrix.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
dtype: Type to return
|
349 |
+
device: Device of returned tensor. Default: if None,
|
350 |
+
uses the current device for the default tensor type
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
Rotation matrix as tensor of shape (3, 3).
|
354 |
+
"""
|
355 |
+
return random_rotations(1, dtype, device)[0]
|
356 |
+
|
357 |
+
|
358 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
359 |
+
"""
|
360 |
+
Convert a unit quaternion to a standard form: one in which the real
|
361 |
+
part is non negative.
|
362 |
+
|
363 |
+
Args:
|
364 |
+
quaternions: Quaternions with real part first,
|
365 |
+
as tensor of shape (..., 4).
|
366 |
+
|
367 |
+
Returns:
|
368 |
+
Standardized quaternions as tensor of shape (..., 4).
|
369 |
+
"""
|
370 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
371 |
+
|
372 |
+
|
373 |
+
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
374 |
+
"""
|
375 |
+
Multiply two quaternions.
|
376 |
+
Usual torch rules for broadcasting apply.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
380 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
381 |
+
|
382 |
+
Returns:
|
383 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
384 |
+
"""
|
385 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
386 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
387 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
388 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
389 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
390 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
391 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
392 |
+
|
393 |
+
|
394 |
+
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
395 |
+
"""
|
396 |
+
Multiply two quaternions representing rotations, returning the quaternion
|
397 |
+
representing their composition, i.e. the versor with nonnegative real part.
|
398 |
+
Usual torch rules for broadcasting apply.
|
399 |
+
|
400 |
+
Args:
|
401 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
402 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
403 |
+
|
404 |
+
Returns:
|
405 |
+
The product of a and b, a tensor of quaternions of shape (..., 4).
|
406 |
+
"""
|
407 |
+
ab = quaternion_raw_multiply(a, b)
|
408 |
+
return standardize_quaternion(ab)
|
409 |
+
|
410 |
+
|
411 |
+
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
|
412 |
+
"""
|
413 |
+
Given a quaternion representing rotation, get the quaternion representing
|
414 |
+
its inverse.
|
415 |
+
|
416 |
+
Args:
|
417 |
+
quaternion: Quaternions as tensor of shape (..., 4), with real part
|
418 |
+
first, which must be versors (unit quaternions).
|
419 |
+
|
420 |
+
Returns:
|
421 |
+
The inverse, a tensor of quaternions of shape (..., 4).
|
422 |
+
"""
|
423 |
+
|
424 |
+
scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
|
425 |
+
return quaternion * scaling
|
426 |
+
|
427 |
+
|
428 |
+
def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
|
429 |
+
"""
|
430 |
+
Apply the rotation given by a quaternion to a 3D point.
|
431 |
+
Usual torch rules for broadcasting apply.
|
432 |
+
|
433 |
+
Args:
|
434 |
+
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
|
435 |
+
point: Tensor of 3D points of shape (..., 3).
|
436 |
+
|
437 |
+
Returns:
|
438 |
+
Tensor of rotated points of shape (..., 3).
|
439 |
+
"""
|
440 |
+
if point.size(-1) != 3:
|
441 |
+
raise ValueError(f"Points are not in 3D, {point.shape}.")
|
442 |
+
real_parts = point.new_zeros(point.shape[:-1] + (1,))
|
443 |
+
point_as_quaternion = torch.cat((real_parts, point), -1)
|
444 |
+
out = quaternion_raw_multiply(
|
445 |
+
quaternion_raw_multiply(quaternion, point_as_quaternion),
|
446 |
+
quaternion_invert(quaternion),
|
447 |
+
)
|
448 |
+
return out[..., 1:]
|
449 |
+
|
450 |
+
|
451 |
+
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
|
452 |
+
"""
|
453 |
+
Convert rotations given as axis/angle to rotation matrices.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
457 |
+
as a tensor of shape (..., 3), where the magnitude is
|
458 |
+
the angle turned anticlockwise in radians around the
|
459 |
+
vector's direction.
|
460 |
+
|
461 |
+
Returns:
|
462 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
463 |
+
"""
|
464 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
465 |
+
|
466 |
+
|
467 |
+
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
|
468 |
+
"""
|
469 |
+
Convert rotations given as rotation matrices to axis/angle.
|
470 |
+
|
471 |
+
Args:
|
472 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
473 |
+
|
474 |
+
Returns:
|
475 |
+
Rotations given as a vector in axis angle form, as a tensor
|
476 |
+
of shape (..., 3), where the magnitude is the angle
|
477 |
+
turned anticlockwise in radians around the vector's
|
478 |
+
direction.
|
479 |
+
"""
|
480 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
481 |
+
|
482 |
+
|
483 |
+
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
|
484 |
+
"""
|
485 |
+
Convert rotations given as axis/angle to quaternions.
|
486 |
+
|
487 |
+
Args:
|
488 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
489 |
+
as a tensor of shape (..., 3), where the magnitude is
|
490 |
+
the angle turned anticlockwise in radians around the
|
491 |
+
vector's direction.
|
492 |
+
|
493 |
+
Returns:
|
494 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
495 |
+
"""
|
496 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
497 |
+
half_angles = angles * 0.5
|
498 |
+
eps = 1e-6
|
499 |
+
small_angles = angles.abs() < eps
|
500 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
501 |
+
sin_half_angles_over_angles[~small_angles] = (
|
502 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
503 |
+
)
|
504 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
505 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
506 |
+
sin_half_angles_over_angles[small_angles] = (
|
507 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
508 |
+
)
|
509 |
+
quaternions = torch.cat(
|
510 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
511 |
+
)
|
512 |
+
return quaternions
|
513 |
+
|
514 |
+
|
515 |
+
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
516 |
+
"""
|
517 |
+
Convert rotations given as quaternions to axis/angle.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
quaternions: quaternions with real part first,
|
521 |
+
as tensor of shape (..., 4).
|
522 |
+
|
523 |
+
Returns:
|
524 |
+
Rotations given as a vector in axis angle form, as a tensor
|
525 |
+
of shape (..., 3), where the magnitude is the angle
|
526 |
+
turned anticlockwise in radians around the vector's
|
527 |
+
direction.
|
528 |
+
"""
|
529 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
530 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
531 |
+
angles = 2 * half_angles
|
532 |
+
eps = 1e-6
|
533 |
+
small_angles = angles.abs() < eps
|
534 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
535 |
+
sin_half_angles_over_angles[~small_angles] = (
|
536 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
537 |
+
)
|
538 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
539 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
540 |
+
sin_half_angles_over_angles[small_angles] = (
|
541 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
542 |
+
)
|
543 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
544 |
+
|
545 |
+
|
546 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
547 |
+
"""
|
548 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
549 |
+
using Gram--Schmidt orthogonalization per Section B of [1].
|
550 |
+
Args:
|
551 |
+
d6: 6D rotation representation, of size (*, 6)
|
552 |
+
|
553 |
+
Returns:
|
554 |
+
batch of rotation matrices of size (*, 3, 3)
|
555 |
+
|
556 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
557 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
558 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
559 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
560 |
+
"""
|
561 |
+
|
562 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
563 |
+
b1 = F.normalize(a1, dim=-1)
|
564 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
565 |
+
b2 = F.normalize(b2, dim=-1)
|
566 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
567 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
568 |
+
|
569 |
+
|
570 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
571 |
+
"""
|
572 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
573 |
+
by dropping the last row. Note that 6D representation is not unique.
|
574 |
+
Args:
|
575 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
576 |
+
|
577 |
+
Returns:
|
578 |
+
6D rotation representation, of size (*, 6)
|
579 |
+
|
580 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
581 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
582 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
583 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
584 |
+
"""
|
585 |
+
batch_dim = matrix.size()[:-2]
|
586 |
+
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
|
render_utils/lib/utils/sh_utils.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The PlenOctree Authors.
|
2 |
+
# Redistribution and use in source and binary forms, with or without
|
3 |
+
# modification, are permitted provided that the following conditions are met:
|
4 |
+
#
|
5 |
+
# 1. Redistributions of source code must retain the above copyright notice,
|
6 |
+
# this list of conditions and the following disclaimer.
|
7 |
+
#
|
8 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
9 |
+
# this list of conditions and the following disclaimer in the documentation
|
10 |
+
# and/or other materials provided with the distribution.
|
11 |
+
#
|
12 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
13 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
14 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
15 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
16 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
17 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
18 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
19 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
20 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
21 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
22 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
23 |
+
|
24 |
+
import torch
|
25 |
+
|
26 |
+
C0 = 0.28209479177387814
|
27 |
+
C1 = 0.4886025119029199
|
28 |
+
C2 = [
|
29 |
+
1.0925484305920792,
|
30 |
+
-1.0925484305920792,
|
31 |
+
0.31539156525252005,
|
32 |
+
-1.0925484305920792,
|
33 |
+
0.5462742152960396
|
34 |
+
]
|
35 |
+
C3 = [
|
36 |
+
-0.5900435899266435,
|
37 |
+
2.890611442640554,
|
38 |
+
-0.4570457994644658,
|
39 |
+
0.3731763325901154,
|
40 |
+
-0.4570457994644658,
|
41 |
+
1.445305721320277,
|
42 |
+
-0.5900435899266435
|
43 |
+
]
|
44 |
+
C4 = [
|
45 |
+
2.5033429417967046,
|
46 |
+
-1.7701307697799304,
|
47 |
+
0.9461746957575601,
|
48 |
+
-0.6690465435572892,
|
49 |
+
0.10578554691520431,
|
50 |
+
-0.6690465435572892,
|
51 |
+
0.47308734787878004,
|
52 |
+
-1.7701307697799304,
|
53 |
+
0.6258357354491761,
|
54 |
+
]
|
55 |
+
|
56 |
+
|
57 |
+
def eval_sh(deg, sh, dirs):
|
58 |
+
"""
|
59 |
+
Evaluate spherical harmonics at unit directions
|
60 |
+
using hardcoded SH polynomials.
|
61 |
+
Works with torch/np/jnp.
|
62 |
+
... Can be 0 or more batch dimensions.
|
63 |
+
Args:
|
64 |
+
deg: int SH deg. Currently, 0-3 supported
|
65 |
+
sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
|
66 |
+
dirs: jnp.ndarray unit directions [..., 3]
|
67 |
+
Returns:
|
68 |
+
[..., C]
|
69 |
+
"""
|
70 |
+
assert deg <= 4 and deg >= 0
|
71 |
+
coeff = (deg + 1) ** 2
|
72 |
+
assert sh.shape[-1] >= coeff
|
73 |
+
|
74 |
+
result = C0 * sh[..., 0]
|
75 |
+
if deg > 0:
|
76 |
+
x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
|
77 |
+
result = (result -
|
78 |
+
C1 * y * sh[..., 1] +
|
79 |
+
C1 * z * sh[..., 2] -
|
80 |
+
C1 * x * sh[..., 3])
|
81 |
+
|
82 |
+
if deg > 1:
|
83 |
+
xx, yy, zz = x * x, y * y, z * z
|
84 |
+
xy, yz, xz = x * y, y * z, x * z
|
85 |
+
result = (result +
|
86 |
+
C2[0] * xy * sh[..., 4] +
|
87 |
+
C2[1] * yz * sh[..., 5] +
|
88 |
+
C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
|
89 |
+
C2[3] * xz * sh[..., 7] +
|
90 |
+
C2[4] * (xx - yy) * sh[..., 8])
|
91 |
+
|
92 |
+
if deg > 2:
|
93 |
+
result = (result +
|
94 |
+
C3[0] * y * (3 * xx - yy) * sh[..., 9] +
|
95 |
+
C3[1] * xy * z * sh[..., 10] +
|
96 |
+
C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
|
97 |
+
C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
|
98 |
+
C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
|
99 |
+
C3[5] * z * (xx - yy) * sh[..., 14] +
|
100 |
+
C3[6] * x * (xx - 3 * yy) * sh[..., 15])
|
101 |
+
|
102 |
+
if deg > 3:
|
103 |
+
result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
|
104 |
+
C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
|
105 |
+
C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
|
106 |
+
C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
|
107 |
+
C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
|
108 |
+
C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
|
109 |
+
C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
|
110 |
+
C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
|
111 |
+
C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
|
112 |
+
return result
|
113 |
+
|
114 |
+
def RGB2SH(rgb):
|
115 |
+
return (rgb - 0.5) / C0
|
116 |
+
|
117 |
+
def SH2RGB(sh):
|
118 |
+
return sh * C0 + 0.5
|
render_utils/stitch_body_and_head.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import trimesh
|
5 |
+
from sklearn.neighbors import KDTree
|
6 |
+
import tqdm
|
7 |
+
import cv2 as cv
|
8 |
+
import os, glob
|
9 |
+
|
10 |
+
from .lib.networks.faceverse_torch import FaceVerseModel
|
11 |
+
from .lib.networks.smpl_torch import SmplTorch
|
12 |
+
from .lib.utils.gaussian_np_utils import GaussianAttributes, load_gaussians_from_ply, save_gaussians_as_ply, \
|
13 |
+
apply_transformation_to_gaussians, combine_gaussians, select_gaussians, update_gaussian_attributes
|
14 |
+
from .lib.utils.geometry import search_nearest_correspondence, estimate_rigid_transformation
|
15 |
+
from .lib.utils.sh_utils import SH2RGB
|
16 |
+
|
17 |
+
|
18 |
+
def process_smpl_head():
|
19 |
+
smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
|
20 |
+
smpl_v_template = smpl.v_template.detach().cpu().numpy()
|
21 |
+
smpl_faces = smpl.faces.detach().cpu().numpy()
|
22 |
+
# head_skinning_weights = smpl.weights[:, 15] + smpl.weights[:, 22] + smpl.weights[:, 23] + smpl.weights[:, 24]
|
23 |
+
#
|
24 |
+
# blend_weight = np.clip(head_skinning_weights* 1.2 - 0.2, 0, 1)
|
25 |
+
# head_ids = np.where(blend_weight > 0)[0]
|
26 |
+
#
|
27 |
+
# np.savez('./data/smplx_head_vidx_and_blendweight.npz', blend_weight=blend_weight, head_ids=head_ids)
|
28 |
+
|
29 |
+
if not os.path.exists('./data/smpl_models/smplx_head_3.obj'):
|
30 |
+
trimesh.Trimesh(vertices=smpl_v_template, faces=smpl_faces).export('./data/smpl_models/smplx_head_3.obj')
|
31 |
+
print('Please cut out SMPL head!!!')
|
32 |
+
import pdb; pdb.set_trace()
|
33 |
+
|
34 |
+
smplx_head = trimesh.load('./data/smpl_models/smplx_head_3.obj')
|
35 |
+
smplx_to_head_dist = np.zeros([smpl_v_template.shape[0]])
|
36 |
+
for vi, v in enumerate(smpl_v_template):
|
37 |
+
nndist = np.min(np.linalg.norm(v.reshape([1, 3]) - smplx_head.vertices, axis=1))
|
38 |
+
smplx_to_head_dist[vi] = nndist
|
39 |
+
|
40 |
+
head_ids = np.where(smplx_to_head_dist < 0.001)[0]
|
41 |
+
blend_weight = np.exp(-smplx_to_head_dist*smplx_to_head_dist * 2000)
|
42 |
+
|
43 |
+
np.savez('./data/smpl_models/smplx_head_vidx_and_blendweight.npz', blend_weight=blend_weight, head_ids=head_ids)
|
44 |
+
|
45 |
+
return
|
46 |
+
|
47 |
+
|
48 |
+
def load_body_params(path):
|
49 |
+
param = dict(np.load(path))
|
50 |
+
global_orient = param['global_orient']
|
51 |
+
transl = param['transl']
|
52 |
+
body_pose = param['body_pose']
|
53 |
+
betas = param['betas']
|
54 |
+
return global_orient, transl, body_pose, betas
|
55 |
+
|
56 |
+
|
57 |
+
def load_face_params(path):
|
58 |
+
param = dict(np.load(path))
|
59 |
+
pose = param['pose']
|
60 |
+
scale = param['scale']
|
61 |
+
id_coeff = param['id_coeff']
|
62 |
+
exp_coeff = param['exp_coeff']
|
63 |
+
return pose, scale, id_coeff, exp_coeff
|
64 |
+
|
65 |
+
|
66 |
+
def get_smpl_verts_and_head_transformation(smpl, global_orient, body_pose, transl, betas):
|
67 |
+
pose = torch.cat([
|
68 |
+
torch.from_numpy(global_orient.astype(np.float32)),
|
69 |
+
torch.from_numpy(body_pose.astype(np.float32)),
|
70 |
+
torch.zeros([(3+15+15)*3], dtype=torch.float32)], dim=-1)
|
71 |
+
beta = torch.from_numpy(betas.astype(np.float32))
|
72 |
+
verts, skinning_dict = smpl.forward(pose.reshape(1, -1), beta.reshape(1, -1))
|
73 |
+
verts = verts[0].detach().cpu().numpy()
|
74 |
+
head_joint_transfmat = skinning_dict['G'][0, 15].detach().cpu().numpy()
|
75 |
+
verts += transl.reshape([1, 3])
|
76 |
+
head_joint_transfmat[:3, 3] += transl.reshape([3])
|
77 |
+
return verts, head_joint_transfmat
|
78 |
+
|
79 |
+
|
80 |
+
def crop_facial_area(faceverse_verts, points, dist_thres=0.025):
|
81 |
+
min_x, min_y, min_z = np.min(faceverse_verts, axis=0)
|
82 |
+
max_x, max_y, max_z = np.max(faceverse_verts, axis=0)
|
83 |
+
pad = dist_thres*2
|
84 |
+
in_bbox_mask = (points[:, 0] > min_x - pad) * (points[:, 0] < max_x + pad) * \
|
85 |
+
(points[:, 1] > min_y - pad) * (points[:, 1] < max_y + pad) * \
|
86 |
+
(points[:, 2] > min_z - pad) * (points[:, 2] < max_z + pad)
|
87 |
+
in_bbox_idx = np.where(in_bbox_mask)[0]
|
88 |
+
facial_points = points[in_bbox_mask]
|
89 |
+
nndist = np.ones([len(facial_points)]) * 1e10
|
90 |
+
for i in tqdm.trange(len(facial_points), desc='calculating facial area'):
|
91 |
+
nndist[i] = np.min(np.linalg.norm(faceverse_verts - facial_points[i:(i+1)], axis=1, keepdims=False))
|
92 |
+
close_to_face_mask = nndist < dist_thres
|
93 |
+
facial_points = facial_points[close_to_face_mask]
|
94 |
+
facial_idx = in_bbox_idx[close_to_face_mask]
|
95 |
+
return facial_points, facial_idx
|
96 |
+
|
97 |
+
|
98 |
+
def crop_facial_area2(smpl_verts, smpl_head_vids, points):
|
99 |
+
min_x, min_y, min_z = np.min(smpl_verts[smpl_head_vids], axis=0)
|
100 |
+
max_x, max_y, max_z = np.max(smpl_verts[smpl_head_vids], axis=0)
|
101 |
+
pad = 0.05
|
102 |
+
in_bbox_mask = (points[:, 0] > min_x - pad) * (points[:, 0] < max_x + pad) * \
|
103 |
+
(points[:, 1] > min_y - pad) * (points[:, 1] < max_y + pad) * \
|
104 |
+
(points[:, 2] > min_z - pad) * (points[:, 2] < max_z + pad)
|
105 |
+
in_bbox_idx = np.where(in_bbox_mask)[0]
|
106 |
+
facial_points = points[in_bbox_mask]
|
107 |
+
smpl_head_mask = np.zeros([len(smpl_verts)], dtype=np.bool_)
|
108 |
+
smpl_head_mask[smpl_head_vids] = True
|
109 |
+
close_to_face_mask = np.zeros([len(facial_points)], dtype=np.bool_)
|
110 |
+
for i in tqdm.trange(len(facial_points)):
|
111 |
+
nnid = np.argmin(np.linalg.norm(smpl_verts - facial_points[i:(i+1)], axis=1, keepdims=False))
|
112 |
+
close_to_face_mask[i] = smpl_head_mask[nnid]
|
113 |
+
facial_points = facial_points[close_to_face_mask]
|
114 |
+
facial_idx = in_bbox_idx[close_to_face_mask]
|
115 |
+
return facial_points, facial_idx
|
116 |
+
|
117 |
+
|
118 |
+
def transform_faceverse_to_live_body_space(faceverse_verts, faceverse_to_smplx, head_joint_transfmat):
|
119 |
+
faceverse_verts = np.matmul(faceverse_verts, faceverse_to_smplx[:3, :3].transpose()) + faceverse_to_smplx[:3, 3].reshape(1, 3)
|
120 |
+
faceverse_verts = np.matmul(faceverse_verts, head_joint_transfmat[:3, :3].transpose()) + head_joint_transfmat[:3, 3].reshape(1, 3)
|
121 |
+
return faceverse_verts
|
122 |
+
|
123 |
+
|
124 |
+
def calc_livehead2livebody(head_pose, smplx_to_faceverse, head_joint_transfmat):
|
125 |
+
head_cano2live = np.eye(4, dtype=np.float32)
|
126 |
+
head_cano2live[:3, :3] = cv.Rodrigues(head_pose[:3])[0]
|
127 |
+
head_cano2live[:3, 3] = head_pose[3:]
|
128 |
+
head_live2cano = np.linalg.inv(head_cano2live)
|
129 |
+
|
130 |
+
faceverse_to_smplx = np.linalg.inv(smplx_to_faceverse)
|
131 |
+
|
132 |
+
total_transf = np.eye(4, dtype=np.float32)
|
133 |
+
for t in [head_live2cano, np.diag([1, -1, -1, 1]), faceverse_to_smplx, head_joint_transfmat]:
|
134 |
+
total_transf = np.matmul(t, total_transf)
|
135 |
+
|
136 |
+
return total_transf
|
137 |
+
|
138 |
+
|
139 |
+
def get_face_blend_weight(head_facial_points, smpl_verts, sigma=0.015):
|
140 |
+
# dists = np.load('./data/faceverse/smplx_verts_to_faceverse_dist.npy').astype(np.float32)
|
141 |
+
# face_nerf_blend_weight = np.exp(-dists**2/(2*sigma**2))
|
142 |
+
# face_nerf_blend_weight = np.clip(face_nerf_blend_weight*1.2 - 0.1, 0, 1)
|
143 |
+
|
144 |
+
smpl_blend_weight = dict(np.load('./data/smpl_models/smplx_head_vidx_and_blendweight.npz'))['blend_weight']
|
145 |
+
|
146 |
+
corr_idx_, _ = search_nearest_correspondence(head_facial_points, smpl_verts)
|
147 |
+
corr_bw = smpl_blend_weight[corr_idx_]
|
148 |
+
|
149 |
+
for _ in tqdm.trange(10):
|
150 |
+
corr_bw_ = np.zeros_like(corr_bw)
|
151 |
+
tree = KDTree(head_facial_points, leaf_size=2)
|
152 |
+
for i in range(len(head_facial_points)):
|
153 |
+
_, idx = tree.query(head_facial_points[i:(i+1)], k=4)
|
154 |
+
corr_bw_[i] = np.mean(corr_bw[idx])
|
155 |
+
corr_bw = np.copy(corr_bw_)
|
156 |
+
|
157 |
+
# corr_bw = np.clip(corr_bw*1.2 - 0.15, 0, 1)
|
158 |
+
|
159 |
+
# with open('./debug/debug_head_facial_bw.obj', 'w') as fp:
|
160 |
+
# for p, w in zip(head_facial_points, corr_bw):
|
161 |
+
# fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))
|
162 |
+
# import pdb; pdb.set_trace()
|
163 |
+
return corr_bw
|
164 |
+
|
165 |
+
|
166 |
+
def get_face_blend_weight2(head_facial_points, body_points, body_facial_idx):
|
167 |
+
body_facial_bbox_min = np.min(body_points[body_facial_idx], axis=0)
|
168 |
+
body_facial_bbox_max = np.max(body_points[body_facial_idx], axis=0)
|
169 |
+
body_facial_bbox_min = body_facial_bbox_min - 0.1
|
170 |
+
body_facial_bbox_max = body_facial_bbox_max + 0.1
|
171 |
+
inside_bbox_flag = \
|
172 |
+
np.int32(body_points[:, 0] > body_facial_bbox_min[0]) * \
|
173 |
+
np.int32(body_points[:, 0] < body_facial_bbox_max[0]) * \
|
174 |
+
np.int32(body_points[:, 1] > body_facial_bbox_min[1]) * \
|
175 |
+
np.int32(body_points[:, 1] < body_facial_bbox_max[1]) * \
|
176 |
+
np.int32(body_points[:, 2] > body_facial_bbox_min[2]) * \
|
177 |
+
np.int32(body_points[:, 2] < body_facial_bbox_max[2])
|
178 |
+
point_idx_inside_bbox = np.nonzero(inside_bbox_flag >0)[0]
|
179 |
+
body_blend_weight = np.zeros([len(body_points)], dtype=np.float32)
|
180 |
+
body_blend_weight[body_facial_idx] = 1
|
181 |
+
|
182 |
+
body_points_in_bbox = body_points[point_idx_inside_bbox]
|
183 |
+
body_blend_weight_in_bbox = body_blend_weight[point_idx_inside_bbox]
|
184 |
+
for _ in tqdm.trange(1, desc='Calculating body facial blend weight'):
|
185 |
+
corr_bw_ = np.zeros_like(body_blend_weight_in_bbox)
|
186 |
+
tree = KDTree(body_points_in_bbox, leaf_size=2)
|
187 |
+
for i in tqdm.trange(len(body_points_in_bbox)):
|
188 |
+
ind = tree.query_radius(body_points_in_bbox[i:(i+1)], r=0.035)
|
189 |
+
corr_bw_[i] = np.mean(body_blend_weight_in_bbox[ind[0]])
|
190 |
+
body_blend_weight_in_bbox = np.copy(corr_bw_)
|
191 |
+
body_blend_weight[point_idx_inside_bbox] = body_blend_weight_in_bbox
|
192 |
+
|
193 |
+
with open('./debug/debug_body_facial_bw.obj', 'w') as fp:
|
194 |
+
for p, w in zip(body_points, body_blend_weight):
|
195 |
+
fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))
|
196 |
+
|
197 |
+
tree = KDTree(body_points, leaf_size=2)
|
198 |
+
corr_bw = np.zeros([len(head_facial_points)], dtype=np.float32)
|
199 |
+
for i in range(len(head_facial_points)):
|
200 |
+
_, idx = tree.query(head_facial_points[i:(i+1)], k=4)
|
201 |
+
corr_bw[i] = np.mean(body_blend_weight[idx])
|
202 |
+
# corr_bw = np.clip(corr_bw*1.2 - 0.15, 0, 1)
|
203 |
+
|
204 |
+
corr_bw_bmin, corr_bw_bmax = np.percentile(corr_bw, 5), np.percentile(corr_bw, 95)
|
205 |
+
corr_bw = np.clip((corr_bw-corr_bw_bmin)/(corr_bw_bmax-corr_bw_bmin), 0, 1)
|
206 |
+
with open('./debug/debug_head_facial_bw.obj', 'w') as fp:
|
207 |
+
for p, w in zip(head_facial_points, corr_bw):
|
208 |
+
fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))
|
209 |
+
return corr_bw
|
210 |
+
|
211 |
+
|
212 |
+
def estimate_color_transfer(head_facial_points, body_facial_points, head_facial_color, body_facial_color, head_facial_opacity):
|
213 |
+
head_facial_color = head_facial_color * 0.28209479177387814 + 0.5
|
214 |
+
body_facial_color = body_facial_color * 0.28209479177387814 + 0.5
|
215 |
+
|
216 |
+
corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
|
217 |
+
corr_color = body_facial_color[corr_idx]
|
218 |
+
|
219 |
+
opacity = 1/(1+np.exp(-head_facial_opacity))
|
220 |
+
weight = np.float32(opacity > 0.35)
|
221 |
+
head_facial_color = head_facial_color.reshape(len(head_facial_color), 3) * weight.reshape([-1, 1])
|
222 |
+
corr_color = corr_color.reshape(len(corr_color), 3) * weight.reshape([-1, 1])
|
223 |
+
|
224 |
+
head_facial_color = np.concatenate([head_facial_color, np.zeros_like(head_facial_color[:, :1])], axis=1)
|
225 |
+
corr_color = np.concatenate([corr_color, np.zeros_like(corr_color[:, :1])], axis=1)
|
226 |
+
|
227 |
+
transfer = nn.Parameter(torch.eye(4, dtype=torch.float32))
|
228 |
+
head_facial_color_th = torch.from_numpy(head_facial_color).float()
|
229 |
+
corr_color_th = torch.from_numpy(corr_color).float()
|
230 |
+
weight_th = torch.from_numpy(weight).float()
|
231 |
+
optim = torch.optim.Adam([transfer], lr=1e-2)
|
232 |
+
|
233 |
+
for i in range(500):
|
234 |
+
optim.zero_grad()
|
235 |
+
loss = torch.mean(torch.abs(corr_color_th - torch.matmul(head_facial_color_th, transfer.permute(1, 0)))*weight_th)
|
236 |
+
loss = loss + torch.sum(torch.square(transfer - torch.eye(4, dtype=torch.float32))) * 5e-2
|
237 |
+
if i % 25 == 0:
|
238 |
+
print(loss.item())
|
239 |
+
loss.backward()
|
240 |
+
optim.step()
|
241 |
+
transfer = transfer.detach().cpu().numpy()
|
242 |
+
print(transfer)
|
243 |
+
|
244 |
+
# with open('./debug/debug_body_facial_color_updated.obj', 'w') as fp:
|
245 |
+
# for p, c in zip(body_facial_points, body_facial_color):
|
246 |
+
# # c = c * 0.28209479177387814 + 0.5
|
247 |
+
# c = np.clip(c, 0, 1)
|
248 |
+
# fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))
|
249 |
+
# with open('./debug/debug_head_facial_color_updated.obj', 'w') as fp:
|
250 |
+
# head_facial_color = np.matmul(head_facial_color, transfer)
|
251 |
+
# for p, c, w in zip(head_facial_points, head_facial_color, weight):
|
252 |
+
# if w < 0.1:
|
253 |
+
# continue
|
254 |
+
# # c = c * 0.28209479177387814 + 0.5
|
255 |
+
# c = np.clip(c, 0, 1)
|
256 |
+
# fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))
|
257 |
+
# import pdb; pdb.set_trace()
|
258 |
+
|
259 |
+
return transfer
|
260 |
+
|
261 |
+
|
262 |
+
def blend_color(head_facial_color, body_facial_color, blend_weight):
|
263 |
+
blend_weight = blend_weight.reshape([len(blend_weight)] + [1]*(len(head_facial_color.shape)-1))
|
264 |
+
result = head_facial_color * blend_weight + body_facial_color * (1-blend_weight)
|
265 |
+
return result
|
266 |
+
|
267 |
+
def save_body_face_stitching_data(
|
268 |
+
result_path, smplx_to_faceverse, residual_transf, body_nonface_mask, head_nonface_mask,
|
269 |
+
head_facial_idx, body_facial_idx, corr_idx, face_color_bw, color_transfer):
|
270 |
+
# os.makedirs('./data/%s' % result_suffix, exist_ok=True)
|
271 |
+
# np.savez('./data/%s/body_face_blending_param.npz' % result_suffix,
|
272 |
+
# smplx_to_faceverse=smplx_to_faceverse.astype(np.float32),
|
273 |
+
# residual_transf=residual_transf.astype(np.float32),
|
274 |
+
# body_nonface_mask=body_nonface_mask.astype(np.int32),
|
275 |
+
# head_facial_idx=head_facial_idx.astype(np.int32),
|
276 |
+
# body_facial_idx=body_facial_idx.astype(np.int32),
|
277 |
+
# head_body_facial_corr_idx=corr_idx.astype(np.int32),
|
278 |
+
# face_color_bw=face_color_bw.astype(np.float32),
|
279 |
+
# color_transfer=color_transfer.astype(np.float32))
|
280 |
+
head_color_bw = np.zeros([len(head_nonface_mask)])
|
281 |
+
head_color_bw[head_facial_idx] = face_color_bw
|
282 |
+
head_corr_idx = np.zeros([len(head_nonface_mask)])
|
283 |
+
head_corr_idx[head_facial_idx] = body_facial_idx[corr_idx]
|
284 |
+
np.savez(result_path,
|
285 |
+
smplx_to_faceverse=smplx_to_faceverse.astype(np.float32),
|
286 |
+
residual_transf=residual_transf.astype(np.float32),
|
287 |
+
body_nonface_mask=body_nonface_mask.astype(np.int32),
|
288 |
+
head_nonface_mask=head_nonface_mask.astype(np.int32),
|
289 |
+
head_facial_idx=head_facial_idx.astype(np.int32),
|
290 |
+
body_facial_idx=body_facial_idx.astype(np.int32),
|
291 |
+
head_body_corr_idx=head_corr_idx.astype(np.int32),
|
292 |
+
head_color_bw=head_color_bw.astype(np.float32),
|
293 |
+
color_transfer=color_transfer.astype(np.float32))
|
294 |
+
return
|
295 |
+
|
296 |
+
|
297 |
+
def manual_refine_facial_cropping(head_facial_points, head_facial_idx, head_facial_colors, body_facial_points, body_facial_idx, body_facial_colors):
|
298 |
+
def _save_points_as_obj(fpath, points, points_color):
|
299 |
+
points_color = np.clip(points_color, 0, 1)
|
300 |
+
with open(fpath, 'w') as fp:
|
301 |
+
for p, c in zip(points, points_color):
|
302 |
+
fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))
|
303 |
+
return
|
304 |
+
|
305 |
+
_save_points_as_obj('./debug/head_facial_points.obj', head_facial_points, head_facial_colors)
|
306 |
+
_save_points_as_obj('./debug/body_facial_points.obj', body_facial_points, body_facial_colors)
|
307 |
+
# trimesh.Trimesh(vertices=head_facial_points, vertex_colors=head_facial_colors).export('./debug/head_facial_points.obj')
|
308 |
+
# trimesh.Trimesh(vertices=body_facial_points, vertex_colors=body_facial_colors).export('./debug/body_facial_points.obj')
|
309 |
+
if True:
|
310 |
+
print('Saving facial points cropped by algorithms. Please remove unnecessary points manually!')
|
311 |
+
import pdb; pdb.set_trace()
|
312 |
+
|
313 |
+
head_facial_points_ = trimesh.load('./debug/head_facial_points.obj').vertices
|
314 |
+
body_facial_points_ = trimesh.load('./debug/body_facial_points.obj').vertices
|
315 |
+
_, head_nndist = search_nearest_correspondence(head_facial_points, head_facial_points_)
|
316 |
+
_, body_nndist = search_nearest_correspondence(body_facial_points, body_facial_points_)
|
317 |
+
head_flag = head_nndist < 1e-4
|
318 |
+
body_flag = body_nndist < 1e-4
|
319 |
+
return head_facial_points[head_flag], head_facial_idx[head_flag], body_facial_points[body_flag], body_facial_idx[body_flag]
|
320 |
+
|
321 |
+
|
322 |
+
def stitch_body_and_head(ref_body_gaussian_path, ref_head_gaussian_path,
|
323 |
+
ref_body_param_path, ref_head_param_path,
|
324 |
+
smplx2faceverse_path, result_folder):
|
325 |
+
device = torch.device("cuda")
|
326 |
+
|
327 |
+
body_gaussians = load_gaussians_from_ply(ref_body_gaussian_path)
|
328 |
+
head_gaussians = load_gaussians_from_ply(ref_head_gaussian_path)
|
329 |
+
global_orient, transl, body_pose, betas = load_body_params(ref_body_param_path)
|
330 |
+
head_pose, head_scale, id_coeff, exp_coeff = load_face_params(ref_head_param_path)
|
331 |
+
smplx_to_faceverse = np.load(smplx2faceverse_path)
|
332 |
+
faceverse_to_smplx = np.linalg.inv(smplx_to_faceverse)
|
333 |
+
|
334 |
+
smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
|
335 |
+
smpl_verts, head_joint_transfmat = get_smpl_verts_and_head_transformation(
|
336 |
+
smpl, global_orient, body_pose, transl, betas)
|
337 |
+
smpl_head_vids = dict(np.load('./data/smpl_models/smplx_head_vidx_and_blendweight.npz'))['head_ids']
|
338 |
+
smpl_head_verts = smpl_verts[smpl_head_vids]
|
339 |
+
|
340 |
+
model_dict = np.load('./data/faceverse_models/faceverse_simple_v2.npy', allow_pickle=True).item()
|
341 |
+
faceverse_model = FaceVerseModel(model_dict, batch_size=1)
|
342 |
+
faceverse_model.init_coeff_tensors(
|
343 |
+
id_coeff=torch.from_numpy(id_coeff).reshape([1, -1]).to(device),
|
344 |
+
scale_coeff=torch.from_numpy(head_scale).reshape([1, 1]).to(device),
|
345 |
+
)
|
346 |
+
faceverse_verts = faceverse_model.forward()['v'][0].detach().cpu().numpy()
|
347 |
+
faceverse_verts = transform_faceverse_to_live_body_space(faceverse_verts, faceverse_to_smplx, head_joint_transfmat)
|
348 |
+
|
349 |
+
livehead2livebody = calc_livehead2livebody(
|
350 |
+
head_pose, smplx_to_faceverse, head_joint_transfmat)
|
351 |
+
head_gaussians_xyz = np.matmul(head_gaussians.xyz, livehead2livebody[:3, :3].transpose()) \
|
352 |
+
+ livehead2livebody[:3, 3].reshape(1, 3)
|
353 |
+
|
354 |
+
# head_facial_points, head_facial_idx = crop_facial_area(smpl_head_verts, head_gaussians_xyz)
|
355 |
+
# body_facial_points, body_facial_idx = crop_facial_area(smpl_head_verts, body_gaussians.xyz)
|
356 |
+
head_facial_points, head_facial_idx = crop_facial_area2(smpl_verts, smpl_head_vids, head_gaussians_xyz)
|
357 |
+
body_facial_points, body_facial_idx = crop_facial_area2(smpl_verts, smpl_head_vids, body_gaussians.xyz)
|
358 |
+
|
359 |
+
residual_transf = np.eye(4)
|
360 |
+
head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
|
361 |
+
head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
|
362 |
+
body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))
|
363 |
+
while True:
|
364 |
+
for _ in tqdm.trange(4, desc='Fitting residual transformation'):
|
365 |
+
corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
|
366 |
+
corr = body_facial_points[corr_idx]
|
367 |
+
transf = estimate_rigid_transformation(head_facial_points, corr)
|
368 |
+
residual_transf = np.matmul(transf, residual_transf)
|
369 |
+
head_facial_points = np.matmul(head_facial_points, transf[:3, :3].transpose()) + transf[:3, 3].reshape(1, 3)
|
370 |
+
if_crop_well = input('If the facial cropping is good enough? (y/n): ')
|
371 |
+
if if_crop_well == 'y':
|
372 |
+
break
|
373 |
+
else:
|
374 |
+
head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
|
375 |
+
head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
|
376 |
+
body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))
|
377 |
+
|
378 |
+
# head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
|
379 |
+
# head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
|
380 |
+
# body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))
|
381 |
+
|
382 |
+
# 更改一下逻辑,改成直到对齐为止。
|
383 |
+
|
384 |
+
|
385 |
+
print(np.matmul(residual_transf, livehead2livebody))
|
386 |
+
residual_transf = np.matmul(np.linalg.inv(livehead2livebody), np.matmul(residual_transf, livehead2livebody))
|
387 |
+
corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
|
388 |
+
|
389 |
+
# head_gaussians_xyz = np.matmul(head_gaussians_xyz, residual_transf[:3, :3].transpose()) + residual_transf[:3, 3].reshape(1, 3)
|
390 |
+
# faceverse_verts = np.matmul(faceverse_verts, residual_transf[:3, :3].transpose()) + residual_transf[:3, 3].reshape(1, 3)
|
391 |
+
|
392 |
+
# total_transf = np.matmul(residual_transf, livehead2livebody)
|
393 |
+
total_transf = np.matmul(livehead2livebody, residual_transf)
|
394 |
+
print(total_transf)
|
395 |
+
|
396 |
+
color_transfer = estimate_color_transfer(
|
397 |
+
head_facial_points, body_facial_points,
|
398 |
+
head_gaussians.features_dc[head_facial_idx], body_gaussians.features_dc[body_facial_idx],
|
399 |
+
head_gaussians.opacities[head_facial_idx]
|
400 |
+
)
|
401 |
+
# face_color_bw = get_face_blend_weight(head_facial_points, smpl_verts, sigma=0.015)
|
402 |
+
face_color_bw = get_face_blend_weight2(head_facial_points, body_gaussians.xyz, body_facial_idx)
|
403 |
+
|
404 |
+
body_nonface_mask = np.ones([len(body_gaussians.xyz)], dtype=np.bool_)
|
405 |
+
body_nonface_mask[body_facial_idx] = 0
|
406 |
+
head_nonface_mask = np.ones([len(head_gaussians.xyz)], dtype=np.bool_)
|
407 |
+
head_nonface_mask[head_facial_idx] = 0
|
408 |
+
|
409 |
+
save_body_face_stitching_data(
|
410 |
+
os.path.join(result_folder, 'body_head_blending_param.npz'),
|
411 |
+
smplx_to_faceverse, residual_transf, body_nonface_mask, head_nonface_mask,
|
412 |
+
head_facial_idx, body_facial_idx, corr_idx, face_color_bw, color_transfer)
|
413 |
+
|
414 |
+
body_gaussians = apply_transformation_to_gaussians(body_gaussians, np.eye(4))
|
415 |
+
head_gaussians = apply_transformation_to_gaussians(head_gaussians, total_transf, np.eye(3))
|
416 |
+
body_gaussians_wo_face = select_gaussians(body_gaussians, body_nonface_mask)
|
417 |
+
head_gaussians_face_only = select_gaussians(head_gaussians, head_facial_idx)
|
418 |
+
head_gaussians_face_only_new_color = blend_color(
|
419 |
+
head_gaussians_face_only.features_dc, body_gaussians.features_dc[body_facial_idx][corr_idx], face_color_bw)
|
420 |
+
head_gaussians_face_only_new_xyz = blend_color(
|
421 |
+
head_gaussians_face_only.xyz, body_gaussians.xyz[body_facial_idx][corr_idx], face_color_bw)
|
422 |
+
head_gaussians_face_only_new_opacities = blend_color(
|
423 |
+
head_gaussians_face_only.opacities, body_gaussians.opacities[body_facial_idx][corr_idx], face_color_bw)
|
424 |
+
head_gaussians_face_only_new_scales = blend_color(
|
425 |
+
head_gaussians_face_only.scales, body_gaussians.scales[body_facial_idx][corr_idx], face_color_bw)
|
426 |
+
|
427 |
+
head_gaussians_face_only = update_gaussian_attributes(
|
428 |
+
head_gaussians_face_only, new_rgb=head_gaussians_face_only_new_color,
|
429 |
+
new_xyz=head_gaussians_face_only_new_xyz, new_opacity=head_gaussians_face_only_new_opacities,
|
430 |
+
new_scale=head_gaussians_face_only_new_scales)
|
431 |
+
|
432 |
+
full_gaussians = combine_gaussians([body_gaussians_wo_face, head_gaussians_face_only])
|
433 |
+
save_gaussians_as_ply(os.path.join(result_folder, 'full_gaussians.ply'), full_gaussians)
|
render_utils/stitch_funcs.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import argparse
|
5 |
+
import tqdm
|
6 |
+
import json
|
7 |
+
import cv2 as cv
|
8 |
+
import os, glob
|
9 |
+
import math
|
10 |
+
|
11 |
+
|
12 |
+
from render_utils.lib.utils.graphics_utils import focal2fov, getProjectionMatrix
|
13 |
+
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
14 |
+
|
15 |
+
|
16 |
+
def render3(
|
17 |
+
gaussian_vals: dict,
|
18 |
+
bg_color: torch.Tensor,
|
19 |
+
extr: torch.Tensor,
|
20 |
+
intr: torch.Tensor,
|
21 |
+
img_w: int,
|
22 |
+
img_h: int,
|
23 |
+
scaling_modifier = 1.0,
|
24 |
+
override_color = None,
|
25 |
+
compute_cov3D_python = False
|
26 |
+
):
|
27 |
+
means3D = gaussian_vals['positions']
|
28 |
+
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
|
29 |
+
screenspace_points = torch.zeros_like(means3D, dtype = means3D.dtype, requires_grad = True, device = "cuda") + 0
|
30 |
+
try:
|
31 |
+
screenspace_points.retain_grad()
|
32 |
+
except:
|
33 |
+
pass
|
34 |
+
means2D = screenspace_points
|
35 |
+
opacity = gaussian_vals['opacity']
|
36 |
+
|
37 |
+
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
|
38 |
+
# scaling / rotation by the rasterizer.
|
39 |
+
scales = None
|
40 |
+
rotations = None
|
41 |
+
cov3D_precomp = None
|
42 |
+
scales = gaussian_vals['scales']
|
43 |
+
rotations = gaussian_vals['rotations']
|
44 |
+
|
45 |
+
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
46 |
+
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
47 |
+
shs = None
|
48 |
+
# colors_precomp = None
|
49 |
+
# if override_color is None:
|
50 |
+
# shs = gaussian_vals['shs']
|
51 |
+
# else:
|
52 |
+
# colors_precomp = override_color
|
53 |
+
if 'colors' in gaussian_vals:
|
54 |
+
colors_precomp = gaussian_vals['colors']
|
55 |
+
else:
|
56 |
+
colors_precomp = None
|
57 |
+
|
58 |
+
# Set up rasterization configuration
|
59 |
+
FoVx = focal2fov(intr[0, 0].item(), img_w)
|
60 |
+
FoVy = focal2fov(intr[1, 1].item(), img_h)
|
61 |
+
tanfovx = math.tan(FoVx * 0.5)
|
62 |
+
tanfovy = math.tan(FoVy * 0.5)
|
63 |
+
world_view_transform = extr.transpose(1, 0).cuda()
|
64 |
+
projection_matrix = getProjectionMatrix(znear = 0.1, zfar = 100, fovX = FoVx, fovY = FoVy, K = intr, img_w = img_w, img_h = img_h).transpose(0, 1).cuda()
|
65 |
+
full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
|
66 |
+
camera_center = torch.linalg.inv(extr)[:3, 3]
|
67 |
+
|
68 |
+
raster_settings = GaussianRasterizationSettings(
|
69 |
+
image_height = img_h,
|
70 |
+
image_width = img_w,
|
71 |
+
tanfovx = tanfovx,
|
72 |
+
tanfovy = tanfovy,
|
73 |
+
bg = bg_color,
|
74 |
+
scale_modifier = scaling_modifier,
|
75 |
+
viewmatrix = world_view_transform,
|
76 |
+
projmatrix = full_proj_transform,
|
77 |
+
sh_degree = gaussian_vals['max_sh_degree'],
|
78 |
+
campos = camera_center,
|
79 |
+
prefiltered = False,
|
80 |
+
debug = False
|
81 |
+
)
|
82 |
+
|
83 |
+
rasterizer = GaussianRasterizer(raster_settings = raster_settings)
|
84 |
+
|
85 |
+
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
86 |
+
rendered_image, radii = rasterizer(
|
87 |
+
means3D = means3D,
|
88 |
+
means2D = means2D,
|
89 |
+
shs = shs,
|
90 |
+
colors_precomp = colors_precomp,
|
91 |
+
opacities = opacity,
|
92 |
+
scales = scales,
|
93 |
+
rotations = rotations,
|
94 |
+
cov3D_precomp = cov3D_precomp)
|
95 |
+
|
96 |
+
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
97 |
+
# They will be excluded from value updates used in the splitting criteria.
|
98 |
+
return {
|
99 |
+
"render": rendered_image,
|
100 |
+
"viewspace_points": screenspace_points,
|
101 |
+
"visibility_filter": radii > 0,
|
102 |
+
"radii": radii
|
103 |
+
}
|
104 |
+
|
105 |
+
|
106 |
+
def blend_color(head_facial_color, body_facial_color, blend_weight):
|
107 |
+
blend_weight = blend_weight.reshape([len(blend_weight)] + [1]*(len(head_facial_color.shape)-1))
|
108 |
+
result = head_facial_color * blend_weight + body_facial_color * (1-blend_weight)
|
109 |
+
return result
|
110 |
+
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def paste_back_with_linear_interp(pasteback_scale, pasteback_center, src, tgt_size):
|
114 |
+
pasteback_topleft = [pasteback_center[0] - src.shape[1]/2/pasteback_scale,
|
115 |
+
pasteback_center[1] - src.shape[0]/2/pasteback_scale]
|
116 |
+
|
117 |
+
h, w = src.shape[0], src.shape[1]
|
118 |
+
grayscale = False
|
119 |
+
if len(src.shape) == 2:
|
120 |
+
src = src.reshape([h, w, 1])
|
121 |
+
grayscale = True
|
122 |
+
src = torch.from_numpy(src)
|
123 |
+
src = src.permute(2, 0, 1).unsqueeze(0)
|
124 |
+
grid = torch.meshgrid(torch.arange(0, tgt_size[0]), torch.arange(0, tgt_size[1]), indexing='xy')
|
125 |
+
grid = torch.stack(grid, dim = -1).float().to(src.device).unsqueeze(0)
|
126 |
+
grid[..., 0] = (grid[..., 0] - pasteback_topleft[0]) * pasteback_scale
|
127 |
+
grid[..., 1] = (grid[..., 1] - pasteback_topleft[1]) * pasteback_scale
|
128 |
+
|
129 |
+
grid[..., 0] = grid[..., 0] / (src.shape[-1] / 2.0) - 1.0
|
130 |
+
grid[..., 1] = grid[..., 1] / (src.shape[-2] / 2.0) - 1.0
|
131 |
+
out = F.grid_sample(src, grid, align_corners = True)
|
132 |
+
out = out[0].detach().permute(1, 2, 0).cpu().numpy()
|
133 |
+
if grayscale:
|
134 |
+
out = out[:, :, 0]
|
135 |
+
return out
|
136 |
+
|
137 |
+
|
138 |
+
def soften_blending_mask(blending_mask, valid_mask):
|
139 |
+
blending_mask = np.clip(blending_mask*2.0, 0.0, 1.0)
|
140 |
+
blending_mask = cv.erode(blending_mask, np.ones((5, 5))) * valid_mask
|
141 |
+
blending_mask_bk = np.copy(blending_mask)
|
142 |
+
blending_mask = cv.blur(blending_mask*valid_mask, (25, 25))
|
143 |
+
valid_mask = cv.blur(valid_mask, (25, 25))
|
144 |
+
blending_mask = blending_mask / (valid_mask + 1e-6) * blending_mask_bk
|
145 |
+
return blending_mask
|