pengc02 committed on
Commit
7648567
1 Parent(s): 42fd375

upload_part1

Files changed (33)
  1. render_utils/__pycache__/stitch_body_and_head.cpython-310.pyc +0 -0
  2. render_utils/__pycache__/stitch_funcs.cpython-310.pyc +0 -0
  3. render_utils/calc_smplx2faceverse.py +238 -0
  4. render_utils/camera_dir.py +171 -0
  5. render_utils/lib/networks/__init__.py +0 -0
  6. render_utils/lib/networks/__pycache__/__init__.cpython-310.pyc +0 -0
  7. render_utils/lib/networks/__pycache__/__init__.cpython-38.pyc +0 -0
  8. render_utils/lib/networks/__pycache__/faceverse_torch.cpython-310.pyc +0 -0
  9. render_utils/lib/networks/__pycache__/faceverse_torch.cpython-38.pyc +0 -0
  10. render_utils/lib/networks/__pycache__/smpl_torch.cpython-310.pyc +0 -0
  11. render_utils/lib/networks/__pycache__/smpl_torch.cpython-38.pyc +0 -0
  12. render_utils/lib/networks/faceverse_torch.py +292 -0
  13. render_utils/lib/networks/smpl_torch.py +341 -0
  14. render_utils/lib/utils/__init__.py +0 -0
  15. render_utils/lib/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  16. render_utils/lib/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  17. render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-310.pyc +0 -0
  18. render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-38.pyc +0 -0
  19. render_utils/lib/utils/__pycache__/geometry.cpython-310.pyc +0 -0
  20. render_utils/lib/utils/__pycache__/geometry.cpython-38.pyc +0 -0
  21. render_utils/lib/utils/__pycache__/graphics_utils.cpython-310.pyc +0 -0
  22. render_utils/lib/utils/__pycache__/graphics_utils.cpython-38.pyc +0 -0
  23. render_utils/lib/utils/__pycache__/rotation_conversions.cpython-310.pyc +0 -0
  24. render_utils/lib/utils/__pycache__/rotation_conversions.cpython-38.pyc +0 -0
  25. render_utils/lib/utils/__pycache__/sh_utils.cpython-310.pyc +0 -0
  26. render_utils/lib/utils/__pycache__/sh_utils.cpython-38.pyc +0 -0
  27. render_utils/lib/utils/gaussian_np_utils.py +162 -0
  28. render_utils/lib/utils/geometry.py +517 -0
  29. render_utils/lib/utils/graphics_utils.py +181 -0
  30. render_utils/lib/utils/rotation_conversions.py +586 -0
  31. render_utils/lib/utils/sh_utils.py +118 -0
  32. render_utils/stitch_body_and_head.py +433 -0
  33. render_utils/stitch_funcs.py +145 -0
render_utils/__pycache__/stitch_body_and_head.cpython-310.pyc ADDED
Binary file (12.9 kB).
 
render_utils/__pycache__/stitch_funcs.cpython-310.pyc ADDED
Binary file (3.73 kB).
 
render_utils/calc_smplx2faceverse.py ADDED
@@ -0,0 +1,238 @@
+ import copy
+ import math
+ import time, random
+
+ import torch
+ import numpy as np
+ import cv2
+ import smplx
+ import pickle
+ import trimesh
+ import os, glob
+ import argparse
+
+ from .lib.networks.faceverse_torch import FaceVerseModel
+
+
+ def estimate_rigid_transformation(src, tgt):
+     src = src.transpose()
+     tgt = tgt.transpose()
+     mu1, mu2 = src.mean(axis=1, keepdims=True), tgt.mean(axis=1, keepdims=True)
+     X1, X2 = src - mu1, tgt - mu2
+
+     K = X1.dot(X2.T)
+     U, s, Vh = np.linalg.svd(K)
+     V = Vh.T
+     Z = np.eye(U.shape[0])
+     Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
+     R = V.dot(Z.dot(U.T))
+     t = mu2 - R.dot(mu1)
+
+     orient, _ = cv2.Rodrigues(R)
+     orient = orient.reshape([-1])
+     t = t.reshape([-1])
+     return orient, t
+
+
+ def load_smpl_beta(body_fitting_result_fpath):
+     if os.path.isdir(body_fitting_result_fpath):
+         body_fitting_result = torch.load(
+             os.path.join(body_fitting_result_fpath, 'checkpoints/latest.pt'), map_location='cpu')
+         betas = body_fitting_result['betas']['weight']
+     elif body_fitting_result_fpath.endswith('.pt'):
+         body_fitting_result = torch.load(body_fitting_result_fpath, map_location='cpu')
+         betas = body_fitting_result['beta'].reshape(1, -1)
+     elif body_fitting_result_fpath.endswith('.npz'):
+         body_fitting_result = np.load(body_fitting_result_fpath)
+         betas = body_fitting_result['betas'].reshape(1, -1)
+         betas = torch.from_numpy(betas.astype(np.float32))
+     else:
+         raise ValueError('Unknown body fitting result file format: {}'.format(body_fitting_result_fpath))
+     return betas
+
+
+ def load_face_id_scale_param(face_fitting_result_fpath):
+     if os.path.isfile(face_fitting_result_fpath):
+         face_fitting_result = dict(np.load(face_fitting_result_fpath))
+         id_tensor = face_fitting_result['id_coeff'].astype(np.float32)
+         scale_tensor = face_fitting_result['scale'].astype(np.float32)
+         id_tensor = torch.from_numpy(id_tensor).reshape(1, -1)
+         scale_tensor = torch.from_numpy(scale_tensor).reshape(1, -1)
+     else:
+         param_paths = sorted(glob.glob(os.path.join(face_fitting_result_fpath, '*', 'params.npz')))
+         param = np.load(param_paths[0])
+         id_tensor = torch.from_numpy(param['id_coeff']).reshape(1, -1)
+         scale_tensor = torch.from_numpy(param['scale']).reshape(1, 1)
+
+     return id_tensor, scale_tensor
+
+
+ def calc_smplx2faceverse(body_fitting_result_fpath, face_fitting_result_fpath, output_dir):
+     device = torch.device('cuda')
+     os.makedirs(output_dir, exist_ok=True)
+
+     betas = load_smpl_beta(body_fitting_result_fpath)
+     id_tensor, scale_tensor = load_face_id_scale_param(face_fitting_result_fpath)
+
+     smpl = smplx.SMPLX(model_path='./AnimatableGaussians/smpl_files/smplx', gender='neutral',
+                        use_pca=True, num_pca_comps=45, flat_hand_mean=True, batch_size=1)
+     flame = smplx.FLAME(model_path='./AnimatableGaussians/smpl_files/FLAME2019', gender='neutral', batch_size=1)
+
+     pose = np.zeros([63], dtype=np.float32)
+
+     pose = torch.from_numpy(pose).unsqueeze(0)
+     smpl_out = smpl(body_pose=pose, betas=betas)
+     verts = smpl_out.vertices.detach().cpu().squeeze(0).numpy()
+     flame_out = flame()
+     verts_flame = flame_out.vertices.detach().cpu().squeeze(0).numpy()
+
+     smplx2flame_data = np.load('./data/smpl_models/smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy')
+     verts_flame_on_smplx = verts[smplx2flame_data]
+
+     orient, t = estimate_rigid_transformation(verts_flame, verts_flame_on_smplx)
+     R, _ = cv2.Rodrigues(orient)
+
+     rel_transf = np.eye(4)
+     rel_transf[:3, :3] = R
+     rel_transf[:3, 3] = t.reshape(-1)
+     np.save('%s/flame_to_smplx.npy' % (output_dir), rel_transf.astype(np.float32))
+
+     # TODO: DELETE ME
+     with open('./debug/debug_verts_smplx.obj', 'w') as fp:
+         for v in verts:
+             fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
+     with open('./debug/debug_verts_flame_in_smplx.obj', 'w') as fp:
+         for v in np.matmul(verts_flame, R.transpose()) + t.reshape([1, 3]):
+             fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
+
+     # align FaceVerse to the T-pose FLAME on SMPL-X
+     faceverse_mesh = trimesh.load_mesh('./data/smpl_models/faceverse2flame/faceverse_icp.obj', process=False)
+     verts_faceverse_ref = np.matmul(np.asarray(faceverse_mesh.vertices), R.transpose()) + t.reshape(1, 3)
+
+     # TODO: DELETE ME
+     with open('./debug/debug_verts_faceverse_ref.obj', 'w') as fp:
+         for v in verts_faceverse_ref:
+             fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
+
+     model_dict = np.load('./data/faceverse_models/faceverse_simple_v2.npy', allow_pickle=True).item()
+     faceverse_model = FaceVerseModel(model_dict, batch_size=1)
+     faceverse_model.init_coeff_tensors()
+     coeffs = faceverse_model.get_packed_tensors()
+     fv_out = faceverse_model.forward(coeffs=coeffs)
+     verts_faceverse = fv_out['v'].squeeze(0).detach().cpu().numpy()
+
+     # calculate the relative transformation between FaceVerse in its canonical pose and its position on SMPL-X
+     orient, t = estimate_rigid_transformation(verts_faceverse, verts_faceverse_ref)
+     orient = torch.from_numpy(orient.astype(np.float32)).unsqueeze(0).to(device)
+     t = torch.from_numpy(t.astype(np.float32)).unsqueeze(0).to(device)
+
+     # optimize the FaceVerse model to fit SMPL-X
+     faceverse_model.init_coeff_tensors(
+         rot_coeff=orient, trans_coeff=t, id_coeff=id_tensor.to(device), scale_coeff=scale_tensor.to(device))
+     nonrigid_optim_params = [
+         faceverse_model.get_exp_tensor(), faceverse_model.get_rot_tensor(), faceverse_model.get_trans_tensor(),
+         # faceverse_model.get_scale_tensor(), faceverse_model.get_id_tensor()
+     ]
+     nonrigid_optimizer = torch.optim.Adam(nonrigid_optim_params, lr=1e-1)
+     verts_faceverse_ref = torch.from_numpy(verts_faceverse_ref.astype(np.float32)).to(device).unsqueeze(0)
+     for iter in range(1000):
+         coeffs = faceverse_model.get_packed_tensors()
+         fv_out = faceverse_model.forward(coeffs=coeffs)
+         verts_pred = fv_out['v']
+         loss = torch.mean(torch.square(verts_pred - verts_faceverse_ref))
+         if iter % 10 == 0:
+             print(loss.item())
+         nonrigid_optimizer.zero_grad()
+         loss.backward()
+         nonrigid_optimizer.step()
+
+     # unpack the dict with ** so each entry is saved under its own key
+     # (passing the dict positionally would store it as a single pickled 'arr_0')
+     np.savez('%s/faceverse_param_to_smplx.npz' % (output_dir), **{
+         'id': faceverse_model.get_id_tensor().detach().cpu().numpy(),
+         'exp': faceverse_model.get_exp_tensor().detach().cpu().numpy(),
+         'rot': faceverse_model.get_rot_tensor().detach().cpu().numpy(),
+         'transl': faceverse_model.get_trans_tensor().detach().cpu().numpy(),
+         'scale': faceverse_model.get_scale_tensor().detach().cpu().numpy(),
+     })
+
+     # calculate the SMPL-X to FaceVerse space transformation (without scale)
+     orient = faceverse_model.get_rot_tensor().detach().cpu().numpy()
+     transl = faceverse_model.get_trans_tensor().detach().cpu().numpy()
+     rotmat, _ = cv2.Rodrigues(orient)
+     transform_mat = np.eye(4)
+     transform_mat[:3, :3] = rotmat
+     transform_mat[:3, 3] = transl
+     transform_mat = np.linalg.inv(transform_mat)
+     np.save('%s/smplx_to_faceverse_space.npy' % (output_dir), transform_mat.astype(np.float32))
+
+     # calculate the distance from each SMPL-X vertex to the FaceVerse mesh
+     dists = []
+     verts_faceverse_ref = verts_faceverse_ref.detach().cpu().numpy()
+     for v in verts:
+         dist = np.linalg.norm(v.reshape(1, 3) - verts_faceverse_ref, axis=-1)
+         dist = np.min(dist)
+         dists.append(dist)
+     dists = np.asarray(dists)
+     np.save('%s/smplx_verts_to_faceverse_dist.npy' % (output_dir), dists.astype(np.float32))
+
+     # sample nodes on the facial area
+     dists_ = np.ones_like(dists)
+     smplx_topo_new = np.load('./data/smpl_models/smplx_topo_new.npy')
+     valid_vids = set(smplx_topo_new.reshape([-1]).tolist())
+     dists_[np.asarray(list(valid_vids))] = dists[np.asarray(list(valid_vids))]
+
+     vids_on_face = np.where(dists_ < 0.01)[0]
+     verts_on_face = verts[vids_on_face]
+     geod_dist_mat = np.linalg.norm(np.expand_dims(verts_on_face, axis=0) - np.expand_dims(verts_on_face, axis=1),
+                                    axis=2)
+     nodes = [0]  # nose
+     dist_nodes_to_rest_points = geod_dist_mat[nodes[0]]
+     for i in range(1, 256):
+         idx = np.argmax(dist_nodes_to_rest_points)
+         nodes.append(idx)
+         new_dist = geod_dist_mat[idx]
+         update_flag = new_dist < dist_nodes_to_rest_points
+         dist_nodes_to_rest_points[update_flag] = new_dist[update_flag]
+
+     # with open('./debug/debug_face_nodes.obj', 'w') as fp:
+     #     for n in verts_on_face[np.asarray(nodes)]:
+     #         fp.write('v %f %f %f\n' % (n[0], n[1], n[2]))
+     vids_on_faces_sampled = vids_on_face[np.asarray(nodes)]
+     vids_on_faces_sampled = np.ascontiguousarray(vids_on_faces_sampled).astype(np.int32)
+     np.save('%s/vids_on_faces_sampled.npy' % (output_dir), vids_on_faces_sampled)
+
+     # test the SMPL-X-to-FaceVerse space transformation (without scale)
+     verts_smpl_in_faceverse = np.matmul(verts, transform_mat[:3, :3].transpose()) + \
+                               transform_mat[:3, 3].reshape(1, 3)
+     with open('./debug/debug_verts_smpl_in_faceverse.obj', 'w') as fp:
+         for v in verts_smpl_in_faceverse:
+             fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
+
+     # save the personalized, canonical FaceVerse model
+     faceverse_model.init_coeff_tensors(id_coeff=id_tensor.to(device), scale_coeff=scale_tensor.to(device))
+     coeffs = faceverse_model.get_packed_tensors()
+     fv_out = faceverse_model.forward(coeffs=coeffs)
+     verts_faceverse = fv_out['v'].squeeze(0).detach().cpu().numpy()
+     with open('./debug/debug_verts_faceverse.obj', 'w') as fp:
+         for v in verts_faceverse:
+             fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
+
+
+ if __name__ == '__main__':
+     # body_fitting_result_dir = 'D:/UpperBodyAvatar/code/smplfitting_multiview_sequence_smplx/results/Shuangqing/zzr_fullbody_20221130_01_2k/whole.pt'
+     # face_fitting_result_dir = 'D:\\UpperBodyAvatar\\data\\Shuangqing\\zzr_face_20221130_01_2k\\faceverse_params'
+     #
+     # output_dir = './data/faceverse/'
+     # result_suffix = 'shuangqing_zzr'
+
+     # body_fitting_result_dir = 'D:/Product/FullAppRelease/smplfitting_multiview_sequence_smplx/results/body_data_model_20231224/whole.pt'
+     # face_fitting_result_dir = 'D:/Product/data/HuiyingCenter/20231224_model/model_20231224_face_data/faceverse_params'
+     #
+     # output_dir = './data/faceverse/'
+     # result_suffix = 'huiyin_model20231224'
+
+     body_fitting_result_dir = 'D:/Product/FullAppRelease/smplfitting_multiview_sequence_smplx/results/body_data_male20230530_betterhand/whole.pt'
+     face_fitting_result_dir = 'D:/Product/data/HuiyingCenter/20230531_models/male20230530_face_data/faceverse_params'
+
+     output_dir = './data/body_face_stitching/huiyin_male20230530'
+
+     calc_smplx2faceverse(body_fitting_result_dir, face_fitting_result_dir, output_dir)
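
The `estimate_rigid_transformation` helper above is a standard Kabsch/SVD alignment that returns the rotation as an axis-angle vector plus a translation. A quick sanity check (a minimal sketch, not part of the commit; it assumes the function above is in scope):

import numpy as np
import cv2

# build a synthetic correspondence with a known rigid transform
rng = np.random.default_rng(0)
src = rng.normal(size=(100, 3))
axis_angle_gt = np.array([0.3, -0.2, 0.1])
R_gt, _ = cv2.Rodrigues(axis_angle_gt)
t_gt = np.array([0.05, -0.1, 0.2])
tgt = src @ R_gt.T + t_gt

# the recovered axis-angle/translation should reproduce the ground truth
orient, t = estimate_rigid_transformation(src, tgt)
R_est, _ = cv2.Rodrigues(orient)
assert np.allclose(R_est, R_gt, atol=1e-6)
assert np.allclose(t, t_gt, atol=1e-6)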
render_utils/camera_dir.py ADDED
@@ -0,0 +1,171 @@
+ import numpy as np
+ import cv2 as cv
+
+
+ # NOTE: this function was lifted from a class method. It expects `self` to
+ # provide `body` and `dataset`, and it relies on `visualize_util`,
+ # `object_center`, `global_orient` and the `extr_list` / `intr_list` /
+ # `img_h_list` / `img_w_list` accumulators defined by its original caller.
+ def get_camera_dir(self, idx):
+     img_scale = self.body['test'].get('img_scale', 1.0)
+     view_setting = self.body['test'].get('view_setting', 'free')
+     if view_setting == 'camera':
+         # training view setting
+         cam_id = self.body['test']['render_view_idx']
+         intr = self.dataset.intr_mats[cam_id].copy()
+         intr[:2] *= img_scale
+         extr = self.dataset.extr_mats[cam_id].copy()
+         img_h, img_w = int(self.dataset.img_heights[cam_id] * img_scale), int(self.dataset.img_widths[cam_id] * img_scale)
+     elif view_setting.startswith('free'):
+         # free view setting
+         # frame_num_per_circle = 360
+         # print(self.opt['test'].get('global_orient', False))
+         frame_num_per_circle = 360
+         rot_Y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
+
+         extr = visualize_util.calc_free_mv(object_center,
+                                            tar_pos = np.array([0, 0, 2.5]),
+                                            rot_Y = rot_Y,
+                                            rot_X = 0.3 if view_setting.endswith('bird') else 0.,
+                                            global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
+         intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
+         intr[:2] *= img_scale
+         img_h = int(1024 * img_scale)
+         img_w = int(1024 * img_scale)
+
+         extr_list.append(extr)
+         intr_list.append(intr)
+         img_h_list.append(img_h)
+         img_w_list.append(img_w)
+     elif view_setting.startswith('degree120'):
+         print('we render 120 degrees')
+         # +/- 60 degrees
+         frame_per_cycle = 480
+         max_degree = 60
+         frame_half_cycle = frame_per_cycle // 2
+         if idx % frame_per_cycle < frame_per_cycle / 2:
+             rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
+             # rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi
+         else:
+             rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
+
+         # convert to radians
+         rot_Y = rot_Y * np.pi / 180
+         if rot_Y < 0:
+             rot_Y = rot_Y + 2 * np.pi
+         # print('rot_Y: ', rot_Y)
+         extr = visualize_util.calc_free_mv(object_center,
+                                            tar_pos = np.array([0, 0, 2.5]),
+                                            rot_Y = rot_Y,
+                                            rot_X = 0.3 if view_setting.endswith('bird') else 0.,
+                                            global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
+         intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
+         intr[:2] *= img_scale
+         img_h = int(1024 * img_scale)
+         img_w = int(1024 * img_scale)
+
+         extr_list.append(extr)
+         intr_list.append(intr)
+         img_h_list.append(img_h)
+         img_w_list.append(img_w)
+     elif view_setting.startswith('degree90'):
+         print('we render 90 degrees')
+         # +/- 45 degrees
+         frame_per_cycle = 360
+         max_degree = 45
+         frame_half_cycle = frame_per_cycle // 2
+         if idx % frame_per_cycle < frame_per_cycle / 2:
+             rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
+             # rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi
+         else:
+             rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
+
+         # convert to radians
+         rot_Y = rot_Y * np.pi / 180
+         if rot_Y < 0:
+             rot_Y = rot_Y + 2 * np.pi
+         # print('rot_Y: ', rot_Y)
+         extr = visualize_util.calc_free_mv(object_center,
+                                            tar_pos = np.array([0, 0, 2.5]),
+                                            rot_Y = rot_Y,
+                                            rot_X = 0.3 if view_setting.endswith('bird') else 0.,
+                                            global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
+         intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
+         intr[:2] *= img_scale
+         img_h = int(1024 * img_scale)
+         img_w = int(1024 * img_scale)
+
+         extr_list.append(extr)
+         intr_list.append(intr)
+         img_h_list.append(img_h)
+         img_w_list.append(img_w)
+     elif view_setting.startswith('front'):
+         # front view setting
+         extr = visualize_util.calc_free_mv(object_center,
+                                            tar_pos = np.array([0, 0, 2.5]),
+                                            rot_Y = 0.,
+                                            rot_X = 0.3 if view_setting.endswith('bird') else 0.,
+                                            global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
+         intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
+         intr[:2] *= img_scale
+         img_h = int(1024 * img_scale)
+         img_w = int(1024 * img_scale)
+
+         extr_list.append(extr)
+         intr_list.append(intr)
+         img_h_list.append(img_h)
+         img_w_list.append(img_w)
+
+         # print('extr: ', extr)
+         # print('intr: ', intr)
+         # print('img_h: ', img_h)
+         # print('img_w: ', img_w)
+         # exit()
+     elif view_setting.startswith('back'):
+         # back view setting
+         extr = visualize_util.calc_free_mv(object_center,
+                                            tar_pos = np.array([0, 0, 2.5]),
+                                            rot_Y = np.pi,
+                                            rot_X = 0.5 * np.pi / 4. if view_setting.endswith('bird') else 0.,
+                                            global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
+         intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
+         intr[:2] *= img_scale
+         img_h = int(1024 * img_scale)
+         img_w = int(1024 * img_scale)
+     elif view_setting.startswith('moving'):
+         # moving camera setting
+         extr = visualize_util.calc_free_mv(object_center,
+                                            # tar_pos = np.array([0, 0, 3.0]),
+                                            # rot_Y = -0.3,
+                                            tar_pos = np.array([0, 0, 2.5]),
+                                            rot_Y = 0.,
+                                            rot_X = 0.3 if view_setting.endswith('bird') else 0.,
+                                            global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
+         intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
+         intr[:2] *= img_scale
+         img_h = int(1024 * img_scale)
+         img_w = int(1024 * img_scale)
+     elif view_setting.startswith('cano'):
+         cano_center = self.dataset.cano_bounds.mean(0)
+         extr = np.identity(4, np.float32)
+         extr[:3, 3] = -cano_center
+         rot_x = np.identity(4, np.float32)
+         rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
+         extr = rot_x @ extr
+         f_len = 5000
+         extr[2, 3] += f_len / 512
+         intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
+         # item = self.dataset.getitem(idx,
+         #                             training = False,
+         #                             extr = extr,
+         #                             intr = intr,
+         #                             img_w = 1024,
+         #                             img_h = 1024)
+         img_w, img_h = 1024, 1024
+         # item['live_smpl_v'] = item['cano_smpl_v']
+         # item['cano2live_jnt_mats'] = torch.eye(4, dtype = torch.float32)[None].expand(item['cano2live_jnt_mats'].shape[0], -1, -1)
+         # item['live_bounds'] = item['cano_bounds']
+     else:
+         raise ValueError('Invalid view setting for animation!')
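
`get_camera_dir` builds per-frame intrinsics plus a `visualize_util.calc_free_mv` extrinsic for each view setting, but `calc_free_mv` itself is not part of this commit. The sketch below is an assumed, self-contained construction of an orbiting world-to-camera extrinsic with similar inputs (object center, camera offset `tar_pos`, and `rot_Y`/`rot_X` angles); it is a guess at the convention, not the repository's implementation:

import numpy as np

def orbit_extrinsic(object_center, tar_pos=np.array([0., 0., 2.5]), rot_Y=0.0, rot_X=0.0):
    # rotate the world about the object center, then place the center at tar_pos
    cy, sy = np.cos(rot_Y), np.sin(rot_Y)
    cx, sx = np.cos(rot_X), np.sin(rot_X)
    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], np.float32)
    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], np.float32)
    R = Rx @ Ry
    extr = np.eye(4, dtype=np.float32)
    extr[:3, :3] = R
    extr[:3, 3] = tar_pos - R @ object_center  # object center maps to tar_pos
    return extr

extr = orbit_extrinsic(np.zeros(3, np.float32), rot_Y=np.pi / 4)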
render_utils/lib/networks/__init__.py ADDED
File without changes
render_utils/lib/networks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (196 Bytes).
 
render_utils/lib/networks/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (189 Bytes).
 
render_utils/lib/networks/__pycache__/faceverse_torch.cpython-310.pyc ADDED
Binary file (8.88 kB).
 
render_utils/lib/networks/__pycache__/faceverse_torch.cpython-38.pyc ADDED
Binary file (9.05 kB).
 
render_utils/lib/networks/__pycache__/smpl_torch.cpython-310.pyc ADDED
Binary file (12.8 kB).
 
render_utils/lib/networks/__pycache__/smpl_torch.cpython-38.pyc ADDED
Binary file (13.4 kB).
 
render_utils/lib/networks/faceverse_torch.py ADDED
@@ -0,0 +1,292 @@
+ import torch
+ from torch import nn
+ import numpy as np
+ from ..utils.rotation_conversions import axis_angle_to_matrix
+
+
+ class FaceVerseModel(nn.Module):
+     def __init__(self, model_dict, batch_size=1, device='cuda:0', **kargs):
+         super(FaceVerseModel, self).__init__()
+
+         self.batch_size = batch_size
+         self.device = torch.device(device)
+
+         self.rotXYZ = torch.eye(3).view(1, 3, 3).repeat(3, 1, 1).view(3, 1, 3, 3).to(self.device)
+
+         self.skinmask = torch.tensor(model_dict['skinmask'], requires_grad=False, device=self.device)
+
+         self.kp_inds = torch.tensor(model_dict['keypoints'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)
+
+         self.meanshape = torch.tensor(model_dict['meanshape'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
+         self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
+
+         self.idBase = torch.tensor(model_dict['idBase'], dtype=torch.float32, requires_grad=False, device=self.device)
+         self.expBase = torch.tensor(model_dict['exBase'], dtype=torch.float32, requires_grad=False, device=self.device)
+         self.texBase = torch.tensor(model_dict['texBase'], dtype=torch.float32, requires_grad=False, device=self.device)
+         # self.expBase[:, :8] *= 4  # controls the eyes
+         # self.expBase[:, 8:] *= 2  # controls everything else
+
+         self.tri = torch.tensor(model_dict['tri'], dtype=torch.int64, requires_grad=False, device=self.device)
+         # self.tri = self.tri[:, [0, 2, 1]]  # TODO
+         self.point_buf = torch.tensor(model_dict['point_buf'], dtype=torch.int64, requires_grad=False, device=self.device)
+
+         self.num_vertex = self.meanshape.shape[1] // 3  # meanshape is (1, 3N); counting along dim 0 would always give 0
+         self.id_dims = self.idBase.shape[1]
+         self.tex_dims = self.texBase.shape[1]
+         self.exp_dims = self.expBase.shape[1]
+         self.all_dims = self.id_dims + self.tex_dims + self.exp_dims
+         self.init_coeff_tensors()
+
+         # for tracking by landmarks
+         self.kp_inds_view = torch.cat([self.kp_inds[:, None] * 3, self.kp_inds[:, None] * 3 + 1, self.kp_inds[:, None] * 3 + 2], dim=1).flatten()
+         self.idBase_view = self.idBase[self.kp_inds_view, :].detach().clone()
+         self.expBase_view = self.expBase[self.kp_inds_view, :].detach().clone()
+         self.meanshape_view = self.meanshape[:, self.kp_inds_view].detach().clone()
+
+         self.identity = torch.eye(3, dtype=torch.float32, device=self.device)
+
+     def init_coeff_tensors(self, id_coeff=None, tex_coeff=None, exp_coeff=None, gamma_coeff=None, trans_coeff=None, rot_coeff=None, scale_coeff=None):
+         if id_coeff is None:
+             self.id_tensor = torch.zeros(
+                 (1, self.id_dims), dtype=torch.float32,
+                 requires_grad=True, device=self.device)
+         else:
+             assert id_coeff.shape == (1, self.id_dims)
+             self.id_tensor = id_coeff.clone().detach().requires_grad_(True)
+
+         if tex_coeff is None:
+             self.tex_tensor = torch.zeros(
+                 (1, self.tex_dims), dtype=torch.float32,
+                 requires_grad=True, device=self.device)
+         else:
+             assert tex_coeff.shape == (1, self.tex_dims)
+             self.tex_tensor = tex_coeff.clone().detach().requires_grad_(True)
+
+         if exp_coeff is None:
+             self.exp_tensor = torch.zeros(
+                 (self.batch_size, self.exp_dims), dtype=torch.float32,
+                 requires_grad=True, device=self.device)
+         else:
+             assert exp_coeff.shape == (1, self.exp_dims)
+             self.exp_tensor = exp_coeff.clone().detach().requires_grad_(True)
+
+         if gamma_coeff is None:
+             self.gamma_tensor = torch.zeros(
+                 (self.batch_size, 27), dtype=torch.float32,
+                 requires_grad=True, device=self.device)
+         else:
+             self.gamma_tensor = gamma_coeff.clone().detach().requires_grad_(True)
+
+         if trans_coeff is None:
+             self.trans_tensor = torch.zeros(
+                 (self.batch_size, 3), dtype=torch.float32,
+                 requires_grad=True, device=self.device)
+         else:
+             self.trans_tensor = trans_coeff.clone().detach().requires_grad_(True)
+
+         if scale_coeff is None:
+             self.scale_tensor = 0.18 * torch.ones((self.batch_size, 1), dtype=torch.float32, device=self.device)
+             # self.scale_tensor = torch.ones((self.batch_size, 1), dtype=torch.float32, device=self.device)
+             self.scale_tensor.requires_grad_(True)
+         else:
+             self.scale_tensor = scale_coeff.clone().detach().requires_grad_(True)
+
+         if rot_coeff is None:
+             self.rot_tensor = torch.zeros(
+                 (self.batch_size, 3), dtype=torch.float32,
+                 requires_grad=True, device=self.device)
+         else:
+             self.rot_tensor = rot_coeff.clone().detach().requires_grad_(True)
+
+     def get_lms(self, vs):
+         lms = vs[:, self.kp_inds, :]
+         return lms
+
+     def split_coeffs(self, coeffs):
+         id_coeff = coeffs[:, :self.id_dims]  # identity (shape) coeffs
+         exp_coeff = coeffs[:, self.id_dims:self.id_dims + self.exp_dims]  # expression coeffs
+         tex_coeff = coeffs[:, self.id_dims + self.exp_dims:self.all_dims]  # texture (albedo) coeffs
+         angles = coeffs[:, self.all_dims:self.all_dims + 3]  # rotation as a 3-dim axis-angle vector
+         gamma = coeffs[:, self.all_dims + 3:self.all_dims + 30]  # lighting coeffs for 3-channel SH, dim 27
+         translation = coeffs[:, self.all_dims + 30:-1]  # translation coeffs, dim 3
+         scale = coeffs[:, -1:]
+         return id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale
+
+     def merge_coeffs(self, id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale):
+         coeffs = torch.cat([id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale], dim=1)
+         return coeffs
+
+     def get_packed_tensors(self):
+         return self.merge_coeffs(self.id_tensor.repeat(self.batch_size, 1),
+                                  self.exp_tensor,
+                                  self.tex_tensor.repeat(self.batch_size, 1),
+                                  self.rot_tensor, self.gamma_tensor,
+                                  self.trans_tensor, self.scale_tensor)
+
+     def forward(self, coeffs=None, camT=None):
+         if coeffs is None:
+             id_coeff = self.id_tensor.repeat(self.batch_size, 1)
+             tex_coeff = self.tex_tensor.repeat(self.batch_size, 1)
+             exp_coeff, angles, gamma = self.exp_tensor, self.rot_tensor, self.gamma_tensor
+             # fixed: the original read `self.rot_tensor` here, which would have reused the rotation as the translation
+             translation, scale = self.trans_tensor, self.scale_tensor
+         else:
+             id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale = self.split_coeffs(coeffs)
+         rotation = axis_angle_to_matrix(angles)
+
+         if camT is not None:
+             rotation2 = camT[:3, :3].reshape(1, 3, 3)
+             translation2 = camT[:3, 3:].reshape(1, 3, 1)
+             if torch.allclose(rotation2, self.identity):
+                 translation = translation + translation2
+             else:
+                 rotation = torch.matmul(rotation2, rotation)
+                 translation = torch.matmul(rotation2, translation) + translation2
+
+         vs = self.get_vs(id_coeff, exp_coeff)
+         vs_t = self.rigid_transform(vs, rotation, translation, torch.abs(scale))
+
+         lms_t = self.get_lms(vs_t)
+
+         face_texture = self.get_color(tex_coeff)
+         face_norm = self.compute_norm(vs, self.tri, self.point_buf)
+         face_norm_r = face_norm.bmm(rotation)
+         face_color = self.add_illumination(face_texture, face_norm_r, gamma)
+
+         return {'v': vs_t, 'lm': lms_t, 'face_texture': face_texture, 'face_color': face_color}
+
+     def forward_landmarks(self, coeffs=None, camT=None):
+         if coeffs is None:
+             id_coeff = self.id_tensor.repeat(self.batch_size, 1)
+             tex_coeff = self.tex_tensor.repeat(self.batch_size, 1)
+             exp_coeff, angles, gamma = self.exp_tensor, self.rot_tensor, self.gamma_tensor
+             translation, scale = self.trans_tensor, self.scale_tensor
+         else:
+             id_coeff, exp_coeff, tex_coeff, angles, gamma, translation, scale = self.split_coeffs(coeffs)
+         rotation = axis_angle_to_matrix(angles)
+
+         if camT is not None:
+             rotation2 = camT[:3, :3].reshape(1, 3, 3)
+             translation2 = camT[:3, 3:].reshape(1, 3, 1)
+             if torch.allclose(rotation2, self.identity):
+                 translation = translation + translation2
+             else:
+                 rotation = torch.matmul(rotation2, rotation)
+                 translation = torch.matmul(rotation2, translation) + translation2
+         lms = self.get_vs_lms(id_coeff, exp_coeff)
+         lms_t = self.rigid_transform(lms, rotation, translation, torch.abs(scale))
+         return lms_t
+
+     def get_vs(self, id_coeff, exp_coeff):
+         face_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + \
+                      torch.einsum('ij,aj->ai', self.expBase, exp_coeff) + self.meanshape
+         face_shape = face_shape.view(self.batch_size, -1, 3)
+         return face_shape
+
+     def get_vs_lms(self, id_coeff, exp_coeff):
+         face_shape = torch.einsum('ij,aj->ai', self.idBase_view, id_coeff) + \
+                      torch.einsum('ij,aj->ai', self.expBase_view, exp_coeff) + self.meanshape_view
+         face_shape = face_shape.view(self.batch_size, -1, 3)
+         return face_shape
+
+     def get_color(self, tex_coeff):
+         face_texture = torch.einsum('ij,aj->ai', self.texBase, tex_coeff) + self.meantex
+         face_texture = face_texture.view(self.batch_size, -1, 3)
+         return face_texture
+
+     def get_skinmask(self):
+         return self.skinmask
+
+     def compute_norm(self, vs, tri, point_buf):
+         face_id = tri
+         point_id = point_buf
+         v1 = vs[:, face_id[:, 0], :]
+         v2 = vs[:, face_id[:, 1], :]
+         v3 = vs[:, face_id[:, 2], :]
+         e1 = v1 - v2
+         e2 = v2 - v3
+         face_norm = e1.cross(e2)
+
+         v_norm = face_norm[:, point_id, :].sum(2)
+         v_norm = v_norm / (v_norm.norm(dim=2).unsqueeze(2) + 1e-9)
+
+         return v_norm
+
+     def add_illumination(self, face_texture, norm, gamma):
+         gamma = gamma.view(-1, 3, 9).clone()
+         gamma[:, :, 0] += 0.8
+         gamma = gamma.permute(0, 2, 1)
+
+         a0 = np.pi
+         a1 = 2 * np.pi / np.sqrt(3.0)
+         a2 = 2 * np.pi / np.sqrt(8.0)
+         c0 = 1 / np.sqrt(4 * np.pi)
+         c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
+         c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
+         d0 = 0.5 / np.sqrt(3.0)
+
+         norm = norm.view(-1, 3)
+         nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
+         arrH = []
+
+         arrH.append(a0 * c0 * (nx * 0 + 1))
+         arrH.append(-a1 * c1 * ny)
+         arrH.append(a1 * c1 * nz)
+         arrH.append(-a1 * c1 * nx)
+         arrH.append(a2 * c2 * nx * ny)
+         arrH.append(-a2 * c2 * ny * nz)
+         arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
+         arrH.append(-a2 * c2 * nx * nz)
+         arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))
+
+         H = torch.stack(arrH, 1)
+         Y = H.view(self.batch_size, face_texture.shape[1], 9)
+         lighting = Y.bmm(gamma)
+
+         face_color = face_texture * lighting
+         return face_color
+
+     def rigid_transform(self, vs, rot, trans, scale):
+         scale = scale.reshape(-1, 1, 1)
+         vs_r = torch.matmul(vs * scale, rot.permute(0, 2, 1))
+         vs_t = vs_r + trans.reshape(-1, 1, 3)
+         return vs_t
+
+     def get_rot_tensor(self):
+         return self.rot_tensor
+
+     def get_trans_tensor(self):
+         return self.trans_tensor
+
+     def get_exp_tensor(self):
+         return self.exp_tensor
+
+     def get_tex_tensor(self):
+         return self.tex_tensor
+
+     def get_id_tensor(self):
+         return self.id_tensor
+
+     def get_gamma_tensor(self):
+         return self.gamma_tensor
+
+     def get_scale_tensor(self):
+         return self.scale_tensor
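
`get_packed_tensors`/`split_coeffs` pack every FaceVerse parameter into one vector laid out as [id | exp | tex | rot (3) | gamma (27) | trans (3) | scale (1)], and `rigid_transform` applies v' = s * v @ R^T + t. A minimal standalone check of that transform convention (a sketch that needs no model file):

import torch

vs = torch.randn(1, 5, 3)                 # [batch, verts, 3]
rot = torch.eye(3).unsqueeze(0)           # identity rotation
trans = torch.tensor([[0.0, 1.0, 0.0]])   # translate along +Y
scale = torch.tensor([[2.0]])

# mirror of FaceVerseModel.rigid_transform
vs_t = torch.matmul(vs * scale.reshape(-1, 1, 1), rot.permute(0, 2, 1)) + trans.reshape(-1, 1, 3)
assert torch.allclose(vs_t, vs * 2.0 + trans)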
render_utils/lib/networks/smpl_torch.py ADDED
@@ -0,0 +1,341 @@
+ import torch
+ import torch.nn as nn
+ import pickle
+ import numpy as np
+ import scipy.sparse
+ import os
+
+ import smplx
+ import smplx.lbs
+
+ from ..utils.geometry import rodrigues
+
+
+ def _convert_to_sparse(mat):
+     if isinstance(mat, np.ndarray):
+         row_ind, col_ind = np.asarray(mat > 0).nonzero()
+         data = [mat[r, c] for (r, c) in zip(row_ind, col_ind)]
+         mat_sparse = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=mat.shape)
+         return mat_sparse
+     else:
+         return mat
+
+
+ class SmplTorch(nn.Module):
+     def __init__(self, model_file='./data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'):
+         super(SmplTorch, self).__init__()
+         if model_file.endswith('.pkl'):
+             with open(model_file, 'rb') as f:
+                 smpl_model = pickle.load(f, encoding='iso-8859-1')
+             # print('SmplTorch: Loading SMPL data from %s' % model_file)
+         elif model_file.endswith('.npz'):
+             smpl_model = np.load(model_file)
+             # print('SmplTorch: Loading SMPL data from %s' % model_file)
+         else:
+             raise ValueError('Unknown file extension: %s' % model_file)
+
+         J_regressor = _convert_to_sparse(smpl_model['J_regressor']).tocoo()
+         row = J_regressor.row
+         col = J_regressor.col
+         data = J_regressor.data
+         i = torch.LongTensor([row, col])
+         v = torch.FloatTensor(data)
+         J_regressor_shape = J_regressor.shape
+         self.v_num = len(smpl_model['v_template'])
+         self.joint_num = smpl_model['weights'].shape[1]
+         self.beta_dim = 10
+
+         self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense())
+         self.register_buffer('weights', torch.FloatTensor(smpl_model['weights']))
+         self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs']))
+         self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template']))
+         self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs'][:, :, :self.beta_dim])))
+         self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64)))
+         self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64)))
+         id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
+         self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
+
+         self.pose_shape = [self.joint_num, 3]
+         self.beta_shape = [self.beta_dim]
+         self.translation_shape = [3]
+
+         self.pose = torch.zeros(self.pose_shape)
+         self.beta = torch.zeros(self.beta_shape)
+         self.translation = torch.zeros(self.translation_shape)
+
+         self.verts = None
+
+         # # BODY25 joint regressor
+         # with open(os.path.join(os.path.dirname(model_file), 'J_regressor_body25.pkl'), 'rb') as f:
+         #     jreg_data = pickle.load(f, encoding='iso-8859-1')
+         # J_regressor25 = jreg_data['J_regressor_body25'].tocoo()
+         # row = J_regressor25.row
+         # col = J_regressor25.col
+         # data = J_regressor25.data
+         # i = torch.LongTensor([row, col])
+         # v = torch.FloatTensor(data)
+         # J_regressor_shape = J_regressor25.shape
+         # self.register_buffer('J_regressor_body25', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense())
+
+     def forward(self, pose, beta, apply_pose_blend_shape=True):
+         device = pose.device
+         batch_size = pose.shape[0]
+         v_template = self.v_template[None, :]
+         shapedirs = self.shapedirs.view(-1, self.beta_dim)[None, :].expand(batch_size, -1, -1)
+         beta = beta[:, :, None]
+         v_shaped = torch.matmul(shapedirs, beta).view(-1, self.v_num, 3) + v_template
+         # batched sparse matmul is not supported in pytorch
+         J = []
+         for i in range(batch_size):
+             J.append(torch.matmul(self.J_regressor, v_shaped[i]))
+         J = torch.stack(J, dim=0)
+         # input is a rotation matrix: (bs, 24, 3, 3)
+         if pose.ndimension() == 4:
+             R = pose
+         # input is axis-angle: (bs, 72)
+         elif pose.ndimension() == 2:
+             pose_cube = pose.view(-1, 3)  # (batch_size * joint_num, 3)
+             R = rodrigues(pose_cube).view(batch_size, self.joint_num, 3, 3)
+         R = R.view(batch_size, self.joint_num, 3, 3)
+
+         if apply_pose_blend_shape:
+             I_cube = torch.eye(3)[None, None, :].to(device)
+             # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1)
+             lrotmin = (R[:, 1:, :] - I_cube).view(batch_size, -1)
+             posedirs = self.posedirs.view(-1, 9*(self.joint_num-1))[None, :].expand(batch_size, -1, -1)
+             v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, self.v_num, 3)
+         else:
+             v_posed = v_shaped
+
+         J_ = J.clone()
+         J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
+         G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
+         pad_row = torch.FloatTensor([0, 0, 0, 1]).to(device).view(1, 1, 1, 4).expand(batch_size, self.joint_num, -1, -1)
+         G_ = torch.cat([G_, pad_row], dim=2)
+         G = [G_[:, 0].clone()]
+         for i in range(1, self.joint_num):
+             G.append(torch.matmul(G[self.parent[i - 1]], G_[:, i, :, :]))
+         G = torch.stack(G, dim=1)
+
+         rest = torch.cat([J, torch.zeros(batch_size, self.joint_num, 1).to(device)], dim=2).view(batch_size, self.joint_num, 4, 1)
+         zeros = torch.zeros(batch_size, self.joint_num, 4, 3).to(device)
+         rest = torch.cat([zeros, rest], dim=-1)
+         rest = torch.matmul(G, rest)
+         G = G - rest
+         T = torch.matmul(self.weights, G.permute(1, 0, 2, 3).contiguous().view(self.joint_num, -1)).view(self.v_num, batch_size, 4, 4).transpose(0, 1)
+         rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
+         v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
+
+         return v, {'J': J, 'G': G, 'T': T, 'R': R, 'v_shaped': v_shaped, 'v_posed': v_posed}
+
+     def get_joints(self, vertices):
+         """
+         Get the joint locations from the SMPL mesh.
+         Input:
+             vertices: size = (B, self.v_num, 3)
+         Output:
+             3D joints: size = (B, 38, 3)
+         """
+         joints = torch.einsum('bik,ji->bjk', vertices, self.J_regressor)
+         return joints
+
+     def get_root(self, vertices):
+         """
+         Get the root joint location from the SMPL mesh.
+         Input:
+             vertices: size = (B, self.v_num, 3)
+         Output:
+             3D joints: size = (B, 1, 3)
+         """
+         joints = torch.einsum('bik,ji->bjk', vertices, self.J_regressor)
+         return joints[:, 0:1, :]
+
+     # def get_joints_body25(self, vertices):
+     #     joints = torch.einsum('bik,ji->bjk', vertices, self.J_regressor_body25)
+     #     return joints
+
+
+ class SmplProjector(nn.Module):
+     def __init__(self, verts_cano=None, face_triangles=None):
+         super(SmplProjector, self).__init__()
+         self.verts_cano = verts_cano
+         self.face_triangles = face_triangles
+
+         face_node_idx = np.loadtxt('./data/smplx/face_coarse_level_ids.txt').astype(np.int64)
+         node_to_face_table = np.loadtxt('./data/smplx/face_coarse_level_to_all_table.txt').astype(np.int64)
+         self.register_buffer('face_node_idx', torch.from_numpy(face_node_idx), persistent=False)
+         self.register_buffer('node_to_face_table', torch.from_numpy(node_to_face_table), persistent=False)
+
+     @staticmethod
+     def batch_select(tensor, index, dim):
+         """
+         Perform index_select for batched inputs, where the index tensors of different items are different
+         :param tensor: [B, G, ....]
+         :param index: [B, N], which is the key difference from torch.index_select()
+         :param dim: int
+         :return:
+         """
+         res = []
+         for bi in range(tensor.shape[0]):
+             res.append(
+                 torch.index_select(tensor[bi], dim=dim-1, index=index[bi])
+             )
+         res = torch.stack(res, dim=0)
+         return res
+
+     def get_current_beta(self):
+         return self.density_func.get_beta()
+
+     def calculate_nearest_barycentric(self, verts, pts):
+         # NOTE: calc_nearest_barycentric_coord is not imported in this file;
+         # it is presumably expected to come from ..utils.geometry
+         nearest_face_id, nearest_dist, nearest_point_barycentric = calc_nearest_barycentric_coord(
+             verts, self.face_triangles, pts, self.face_node_idx, self.node_to_face_table)
+         return nearest_face_id, nearest_dist, nearest_point_barycentric
+
+     def interpolate_vert_feat(self, vert_feat, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
+         fvfa = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 0])  # [B, G, 3]
+         fvfb = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 1])  # [B, G, 3]
+         fvfc = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 2])  # [B, G, 3]
+         fvfs = torch.cat([fvfa.unsqueeze(2), fvfb.unsqueeze(2), fvfc.unsqueeze(2)], dim=2)  # [B, G, 3, 3]
+         nearest_fvfs = self.batch_select(fvfs, nearest_face_id, dim=1)
+         nearest_vf = torch.sum(nearest_point_barycentric.unsqueeze(-1)*nearest_fvfs, dim=-2)  # [B, N, 3]
+         return nearest_vf
+
+     def calculate_nearest_point(self, verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
+         verts_cano = self.verts_cano.unsqueeze(0).expand(verts.shape[0], -1, -1)
+         nearest_v = self.interpolate_vert_feat(verts, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         nearest_nml = self.interpolate_vert_feat(vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         nearest_vc = self.interpolate_vert_feat(verts_cano, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         return nearest_v, nearest_nml, nearest_vc
+
+     def forward(self, verts, vert_nmls, pts, extra_vert_feat=None):
+         nearest_face_id, nearest_dist, nearest_point_barycentric = self.calculate_nearest_barycentric(verts, pts)
+         nearest_v, nearest_nml, nearest_vc = self.calculate_nearest_point(verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         nearest_dist = nearest_dist.unsqueeze(-1)
+         nearest_h = torch.sum((pts-nearest_v)*nearest_nml, dim=-1, keepdim=True)  # [B, N, 1]
+         sdf = nearest_h.sign() * nearest_dist
+         if extra_vert_feat is None:
+             return sdf, nearest_v, nearest_nml, nearest_vc
+         else:
+             if len(extra_vert_feat.shape) == 2:
+                 extra_vert_feat = extra_vert_feat.unsqueeze(0).expand(verts.shape[0], -1, -1)
+             nearest_feat = self.interpolate_vert_feat(extra_vert_feat, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+             return sdf, nearest_v, nearest_nml, nearest_vc, nearest_feat
+
+
+ class SmplProjectorBF(nn.Module):
+     """
+     Calculate barycentric projection in a brute-force manner
+     """
+     def __init__(self, verts_cano=None, face_triangles=None):
+         super(SmplProjectorBF, self).__init__()
+         self.verts_cano = verts_cano
+         self.face_triangles = face_triangles
+
+     @staticmethod
+     def batch_select(tensor, index, dim):
+         """
+         Perform index_select for batched inputs, where the index tensors of different items are different
+         :param tensor: [B, G, ....]
+         :param index: [B, N], which is the key difference from torch.index_select()
+         :param dim: int
+         :return:
+         """
+         res = []
+         for bi in range(tensor.shape[0]):
+             res.append(
+                 torch.index_select(tensor[bi], dim=dim-1, index=index[bi])
+             )
+         res = torch.stack(res, dim=0)
+         return res
+
+     def get_current_beta(self):
+         return self.density_func.get_beta()
+
+     def calculate_nearest_barycentric(self, verts, pts):
+         # NOTE: calc_nearest_barycentric_coord_bf is likewise not imported in this file
+         nearest_face_id, nearest_dist, nearest_point_barycentric = calc_nearest_barycentric_coord_bf(
+             verts, self.face_triangles, pts)
+         return nearest_face_id, nearest_dist, nearest_point_barycentric
+
+     def interpolate_vert_feat(self, vert_feat, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
+         fvfa = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 0])  # [B, G, 3]
+         fvfb = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 1])  # [B, G, 3]
+         fvfc = torch.index_select(vert_feat, dim=1, index=self.face_triangles[:, 2])  # [B, G, 3]
+         fvfs = torch.cat([fvfa.unsqueeze(2), fvfb.unsqueeze(2), fvfc.unsqueeze(2)], dim=2)  # [B, G, 3, 3]
+         nearest_fvfs = self.batch_select(fvfs, nearest_face_id, dim=1)
+         nearest_vf = torch.sum(nearest_point_barycentric.unsqueeze(-1)*nearest_fvfs, dim=-2)  # [B, N, 3]
+         return nearest_vf
+
+     def calculate_nearest_point(self, verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric):
+         verts_cano = self.verts_cano.unsqueeze(0).expand(verts.shape[0], -1, -1)
+         nearest_v = self.interpolate_vert_feat(verts, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         nearest_nml = self.interpolate_vert_feat(vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         nearest_vc = self.interpolate_vert_feat(verts_cano, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         return nearest_v, nearest_nml, nearest_vc
+
+     def forward(self, verts, vert_nmls, pts):
+         nearest_face_id, nearest_dist, nearest_point_barycentric = self.calculate_nearest_barycentric(verts, pts)
+         nearest_v, nearest_nml, nearest_vc = self.calculate_nearest_point(verts, vert_nmls, pts, nearest_face_id, nearest_dist, nearest_point_barycentric)
+         nearest_dist = nearest_dist.unsqueeze(-1)
+         nearest_h = torch.sum((pts-nearest_v)*nearest_nml, dim=-1, keepdim=True)  # [B, N, 1]
+         sdf = nearest_h.sign() * nearest_dist
+         return sdf, nearest_v, nearest_nml, nearest_vc
+
+
+ class SmplJointRegressor(nn.Module):
+     def __init__(self, model_file, extra_regressor_file):
+         super().__init__()
+         # '../smpl_models/J_regressor_extra_' + 'smplx' + '.npy'
+
+         self.smpl = smplx.SMPLX(model_path=model_file, gender='neutral', use_pca=False,
+                                 num_pca_comps=45, flat_hand_mean=True, batch_size=1)
+         self.extra_jregressor = torch.tensor(np.load(extra_regressor_file), dtype=torch.float32)
+         self.smpl2coco = {
+             'body': torch.LongTensor([55, 57, 56, 59, 58, 16, 17, 18, 19, 20, 21, 128, 127, 4, 5, 7, 8]),
+             'foot': torch.LongTensor(np.arange(60, 66, dtype=int)),
+             'face': torch.LongTensor(np.arange(76, 127, dtype=int)),
+             'lhand': torch.LongTensor(
+                 [20, 37, 38, 39, 66, 25, 26, 27, 67, 28, 29, 30, 68, 34, 35, 36, 69, 31, 32, 33, 70]),
+             'rhand': torch.LongTensor(
+                 [21, 52, 53, 54, 71, 40, 41, 42, 72, 43, 44, 45, 73, 49, 50, 51, 74, 46, 47, 48, 75])
+         }
+
+     def to(self, device):
+         super().to(device)
+         self.smpl = self.smpl.to(device)
+         self.smpl2coco = {k: v.to(device) for k, v in self.smpl2coco.items()}
+         self.extra_jregressor = self.extra_jregressor.to(device)
+         return self
+
+     def get_all_joints(self, verts):
+         lmk_faces_idx = self.smpl.lmk_faces_idx.unsqueeze(dim=0).expand(verts.shape[0], -1).contiguous()
+         lmk_bary_coords = self.smpl.lmk_bary_coords.unsqueeze(dim=0).repeat(verts.shape[0], 1, 1)
+         landmarks = smplx.lbs.vertices2landmarks(verts, self.smpl.faces_tensor, lmk_faces_idx, lmk_bary_coords)
+
+         joints = torch.einsum('bik,ji->bjk', verts, self.smpl.J_regressor)
+         joints = self.smpl.vertex_joint_selector(verts, joints)
+         joints = torch.cat([joints, landmarks], dim=1)
+         if self.smpl.joint_mapper is not None:
+             joints = self.smpl.joint_mapper(joints=joints, vertices=verts)  # fixed: was self.joint_mapper, which does not exist on this module
+
+         extra_joints = torch.einsum('bik,ji->bjk', verts, self.extra_jregressor)
+         joints = torch.cat([joints, extra_joints], dim=1)
+
+         return joints
+
+     def forward(self, verts, cam_param=None):
+         joints = self.get_all_joints(verts)
+
+         if cam_param:
+             if len(cam_param['K'].shape) == 3:
+                 smpl_j2d = torch.einsum("ljk,lik->lij", cam_param['K'], torch.einsum(
+                     "ijk,lk->ilj", cam_param['R'], joints.reshape(-1, 3)) + cam_param['T'].reshape(-1, 1, 3))
+             else:
+                 smpl_j2d = torch.einsum("jk,lik->lij", cam_param['K'], torch.einsum(
+                     "ijk,lk->ilj", cam_param['R'], joints.reshape(-1, 3)) + cam_param['T'].reshape(-1, 1, 3))
+             valid = smpl_j2d[:, :, 2] > 1e-3
+             smpl_j2d[:, :, :2][valid] = smpl_j2d[:, :, :2][valid] / smpl_j2d[:, :, 2:3][valid]
+             smpl_j2d[:, :, 2] = valid.float()
+             skel_2d = {k: torch.index_select(smpl_j2d, dim=1, index=v) for k, v in self.smpl2coco.items()}
+         else:
+             skel_2d = None
+
+         skel = {k: torch.index_select(joints, dim=1, index=v) for k, v in self.smpl2coco.items()}
+         return skel, skel_2d
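
`batch_select` differs from plain `torch.index_select` in that every batch item gets its own index row. A small self-contained illustration of the same gather it performs (a sketch, not part of the commit):

import torch

feats = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)  # [B=2, G=4, 3]
idx = torch.tensor([[0, 2], [1, 3]])                                # [B=2, N=2]

# per-item index_select, exactly what batch_select(feats, idx, dim=1) does
out = torch.stack([torch.index_select(feats[b], dim=0, index=idx[b])
                   for b in range(feats.shape[0])], dim=0)
assert out.shape == (2, 2, 3) and torch.equal(out[0, 1], feats[0, 2])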
render_utils/lib/utils/__init__.py ADDED
File without changes
render_utils/lib/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (193 Bytes).
 
render_utils/lib/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (186 Bytes).
 
render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-310.pyc ADDED
Binary file (6.28 kB).
 
render_utils/lib/utils/__pycache__/gaussian_np_utils.cpython-38.pyc ADDED
Binary file (6.54 kB).
 
render_utils/lib/utils/__pycache__/geometry.cpython-310.pyc ADDED
Binary file (16.2 kB).
 
render_utils/lib/utils/__pycache__/geometry.cpython-38.pyc ADDED
Binary file (16 kB).
 
render_utils/lib/utils/__pycache__/graphics_utils.cpython-310.pyc ADDED
Binary file (5.44 kB).
 
render_utils/lib/utils/__pycache__/graphics_utils.cpython-38.pyc ADDED
Binary file (5.42 kB).
 
render_utils/lib/utils/__pycache__/rotation_conversions.cpython-310.pyc ADDED
Binary file (17.1 kB).
 
render_utils/lib/utils/__pycache__/rotation_conversions.cpython-38.pyc ADDED
Binary file (17.1 kB).
 
render_utils/lib/utils/__pycache__/sh_utils.cpython-310.pyc ADDED
Binary file (2.58 kB).
 
render_utils/lib/utils/__pycache__/sh_utils.cpython-38.pyc ADDED
Binary file (2.63 kB).
 
render_utils/lib/utils/gaussian_np_utils.py ADDED
@@ -0,0 +1,162 @@
+ import numpy as np
+ import torch
+ from plyfile import PlyData, PlyElement
+ from typing import NamedTuple
+ import smplx
+ import tqdm
+ import cv2 as cv
+ import os
+
+ from scipy.spatial.transform import Rotation as R
+
+
+ class GaussianAttributes(NamedTuple):
+     xyz: np.ndarray
+     opacities: np.ndarray
+     features_dc: np.ndarray
+     features_extra: np.ndarray
+     scales: np.ndarray
+     rot: np.ndarray
+
+ def load_gaussians_from_ply(path):
+     max_sh_degree = 3
+     plydata = PlyData.read(path)
+
+     xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+                     np.asarray(plydata.elements[0]["y"]),
+                     np.asarray(plydata.elements[0]["z"])), axis=1)
+     opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+     features_dc = np.zeros((xyz.shape[0], 3, 1))
+     features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+     features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+     features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+     extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+     extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split('_')[-1]))
+     assert len(extra_f_names) == 3 * (max_sh_degree + 1) ** 2 - 3
+     features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+     for idx, attr_name in enumerate(extra_f_names):
+         features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+     # Reshape (P, F*SH_coeffs) to (P, F, SH_coeffs except DC)
+     features_extra = features_extra.reshape((features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1))
+
+     scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+     scale_names = sorted(scale_names, key=lambda x: int(x.split('_')[-1]))
+     scales = np.zeros((xyz.shape[0], len(scale_names)))
+     for idx, attr_name in enumerate(scale_names):
+         scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+     rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+     rot_names = sorted(rot_names, key=lambda x: int(x.split('_')[-1]))
+     rots = np.zeros((xyz.shape[0], len(rot_names)))
+     for idx, attr_name in enumerate(rot_names):
+         rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+     return GaussianAttributes(xyz, opacities, features_dc, features_extra, scales, rots)
+
+
+ def construct_list_of_attributes(_features_dc, _features_rest, _scaling, _rotation):
+     l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+     # All channels except the 3 DC
+     for i in range(_features_dc.shape[1] * _features_dc.shape[2]):
+         l.append('f_dc_{}'.format(i))
+     for i in range(_features_rest.shape[1] * _features_rest.shape[2]):
+         l.append('f_rest_{}'.format(i))
+     l.append('opacity')
+     for i in range(_scaling.shape[1]):
+         l.append('scale_{}'.format(i))
+     for i in range(_rotation.shape[1]):
+         l.append('rot_{}'.format(i))
+     return l
+
+
+ def select_gaussians(gaussian_attrs, select_mask_or_idx):
+     return GaussianAttributes(
+         xyz=gaussian_attrs.xyz[select_mask_or_idx],
+         opacities=gaussian_attrs.opacities[select_mask_or_idx],
+         features_dc=gaussian_attrs.features_dc[select_mask_or_idx],
+         features_extra=gaussian_attrs.features_extra[select_mask_or_idx],
+         scales=gaussian_attrs.scales[select_mask_or_idx],
+         rot=gaussian_attrs.rot[select_mask_or_idx]
+     )
+
+
+ def combine_gaussians(gaussian_attrs_list):
+     return GaussianAttributes(
+         xyz=np.concatenate([gau.xyz for gau in gaussian_attrs_list], axis=0),
+         opacities=np.concatenate([gau.opacities for gau in gaussian_attrs_list], axis=0),
+         features_dc=np.concatenate([gau.features_dc for gau in gaussian_attrs_list], axis=0),
+         features_extra=np.concatenate([gau.features_extra for gau in gaussian_attrs_list], axis=0),
+         scales=np.concatenate([gau.scales for gau in gaussian_attrs_list], axis=0),
+         rot=np.concatenate([gau.rot for gau in gaussian_attrs_list], axis=0),
+     )
+
+
+ def apply_transformation_to_gaussians(gaussian_attrs, spatial_transformation, color_transformation=None):
+     xyzs = np.copy(gaussian_attrs.xyz)
+     xyzs = np.matmul(xyzs, spatial_transformation[:3, :3].transpose()) + spatial_transformation[:3, 3].reshape([1, 3])
+
+     gaussian_rotmats = R.from_quat(gaussian_attrs.rot[:, (1, 2, 3, 0)]).as_matrix()
+     new_rots = []
+     for rotmat in gaussian_rotmats:
+         rotmat = np.matmul(spatial_transformation[:3, :3], rotmat)
+         rotq = R.from_matrix(rotmat).as_quat()
+         rotq = np.array([rotq[3], rotq[0], rotq[1], rotq[2]])
+         new_rots.append(rotq)
+     new_rots = np.stack(new_rots, axis=0)
+     if color_transformation is not None:
+         if color_transformation.shape[0] == 3 and color_transformation.shape[1] == 3:
+             new_clrs = np.matmul(gaussian_attrs.features_dc[:, :, 0], color_transformation)[:, :, np.newaxis]
+         elif color_transformation.shape[0] == 4 and color_transformation.shape[1] == 4:
+             clrs = gaussian_attrs.features_dc[:, :, 0]
+             clrs = np.concatenate([clrs, np.ones_like(clrs[:, :1])], axis=1)
+             new_clrs = np.matmul(clrs, color_transformation)
+             new_clrs = new_clrs[:, :3, np.newaxis]
+     else:
+         new_clrs = gaussian_attrs.features_dc
+
+     return GaussianAttributes(
+         xyz=xyzs,
+         opacities=gaussian_attrs.opacities,
+         features_dc=new_clrs,
+         features_extra=gaussian_attrs.features_extra,
+         scales=gaussian_attrs.scales,
+         rot=new_rots,
+     )
+
+
+ def update_gaussian_attributes(
+         orig_gaussian,
+         new_xyz=None, new_rgb=None, new_rot=None, new_opacity=None, new_scale=None):
+     return GaussianAttributes(
+         xyz=orig_gaussian.xyz if new_xyz is None else new_xyz,
+         opacities=orig_gaussian.opacities if new_opacity is None else new_opacity,
+         features_dc=orig_gaussian.features_dc if new_rgb is None else new_rgb,
+         features_extra=orig_gaussian.features_extra,
+         scales=orig_gaussian.scales if new_scale is None else new_scale,
+         rot=orig_gaussian.rot if new_rot is None else new_rot,
+     )
+
+
+ def save_gaussians_as_ply(path, gaussian_attrs: GaussianAttributes):
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+
+     xyz = gaussian_attrs.xyz
+     normals = np.zeros_like(xyz)
+     features_dc = gaussian_attrs.features_dc
+     features_rest = gaussian_attrs.features_extra
+     opacities = gaussian_attrs.opacities
+     scale = gaussian_attrs.scales
+     rotation = gaussian_attrs.rot
+
+     dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes(features_dc, features_rest, scale, rotation)]
+
+     elements = np.empty(xyz.shape[0], dtype=dtype_full)
+     attributes = np.concatenate((xyz, normals, features_dc.reshape(features_dc.shape[0], -1),
+                                  features_rest.reshape(features_rest.shape[0], -1),
+                                  opacities, scale, rotation), axis=1)
+     elements[:] = list(map(tuple, attributes))
+     el = PlyElement.describe(elements, 'vertex')
+     PlyData([el]).write(path)
+     return
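For orientation, a minimal usage sketch of the helpers above (not part of the commit diff; the file paths and transform values are assumed placeholders). It loads a Gaussian point set, applies a rigid transform, and writes it back; note that `save_gaussians_as_ply` calls `os.makedirs` on the dirname, so the output path should include a directory:

import numpy as np
from render_utils.lib.utils.gaussian_np_utils import (
    load_gaussians_from_ply, apply_transformation_to_gaussians, save_gaussians_as_ply)

gaussians = load_gaussians_from_ply('./output/body.ply')   # hypothetical input path
transf = np.eye(4, dtype=np.float32)
transf[:3, 3] = [0.0, 0.1, 0.0]                            # translate 10 cm along +y
moved = apply_transformation_to_gaussians(gaussians, transf)
save_gaussians_as_ply('./output/body_shifted.ply', moved)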
render_utils/lib/utils/geometry.py ADDED
@@ -0,0 +1,517 @@
+ import glob
+ import os
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import trimesh
+ import scipy.sparse as sp
+ import collections
+
+
+ def rodrigues(theta):
+     """Convert axis-angle representation to rotation matrix.
+     Args:
+         theta: size = [B, 3]
+     Returns:
+         Rotation matrix corresponding to the axis-angle input -- size = [B, 3, 3]
+     """
+     l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
+     angle = torch.unsqueeze(l1norm, -1)
+     normalized = torch.div(theta, angle)
+     angle = angle * 0.5
+     v_cos = torch.cos(angle)
+     v_sin = torch.sin(angle)
+     quat = torch.cat([v_cos, v_sin * normalized], dim=1)
+     return quat2mat(quat)
+
+
+ def quat2mat(quat):
+     """Convert quaternion coefficients to rotation matrix.
+     Args:
+         quat: size = [B, 4]  4 <===> (w, x, y, z)
+     Returns:
+         Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+     """
+     norm_quat = quat
+     norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+     w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
+
+     B = quat.size(0)
+
+     w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+     wx, wy, wz = w * x, w * y, w * z
+     xy, xz, yz = x * y, x * z, y * z
+
+     rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
+                           2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
+                           2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+     return rotMat
+
+
+ def inv_4x4(mats):
+     """Calculate the inverse of homogeneous transformations
+     :param mats: [B, 4, 4]
+     :return:
+     """
+     Rs = mats[:, :3, :3]
+     ts = mats[:, :3, 3:]
+     # R_invs = torch.transpose(Rs, 1, 2)
+     R_invs = torch.inverse(Rs)
+     t_invs = -torch.matmul(R_invs, ts)
+     Rt_invs = torch.cat([R_invs, t_invs], dim=-1)  # [B, 3, 4]
+
+     device = R_invs.device
+     pad_row = torch.FloatTensor([0, 0, 0, 1]).to(device).view(1, 1, 4).expand(Rs.shape[0], -1, -1)
+     mat_invs = torch.cat([Rt_invs, pad_row], dim=1)
+     return mat_invs
+
+
+ def as_mesh(scene_or_mesh):
+     """
+     Convert a possible scene to a mesh.
+
+     If conversion occurs, the returned mesh has only vertex and face data.
+     """
+     if isinstance(scene_or_mesh, trimesh.Scene):
+         if len(scene_or_mesh.geometry) == 0:
+             mesh = None  # empty scene
+         else:
+             # we lose texture information here
+             mesh = trimesh.util.concatenate(
+                 tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
+                       for g in scene_or_mesh.geometry.values()))
+     else:
+         assert (isinstance(scene_or_mesh, trimesh.Trimesh))
+         mesh = scene_or_mesh
+     return mesh
+
+
+ def get_edge_unique(faces):
+     """
+     Parameters
+     ------------
+     faces: n x 3 int array
+         Should be from a watertight mesh without degenerated triangles and intersection
+     """
+     faces = np.asanyarray(faces)
+
+     # each face has three edges
+     edges = faces[:, [0, 1, 1, 2, 2, 0]].reshape((-1, 2))
+     flags = edges[:, 0] < edges[:, 1]
+     edges = edges[flags]
+     return edges
+
+
+ def get_neighbors(edges):
+     neighbors = collections.defaultdict(set)
+     [(neighbors[edge[0]].add(edge[1]),
+       neighbors[edge[1]].add(edge[0]))
+      for edge in edges]
+
+     max_index = edges.max() + 1
+     array = [list(neighbors[i]) for i in range(max_index)]
+
+     return array
+
+
+ def construct_degree_matrix(vnum, faces):
+     row = col = list(range(vnum))
+     value = [0] * vnum
+     es = get_edge_unique(faces)
+     for e in es:
+         if e[0] < e[1]:
+             value[e[0]] += 1
+             value[e[1]] += 1  # each unique edge adds one degree to both endpoints
+
+     dm = sp.coo_matrix((value, (row, col)), shape=(vnum, vnum), dtype=np.float32)
+     return dm
+
+
+ def construct_neighborhood_matrix(vnum, faces):
+     row = list()
+     col = list()
+     value = list()
+     es = get_edge_unique(faces)
+     for e in es:
+         if e[0] < e[1]:
+             row.append(e[0])
+             col.append(e[1])
+             value.append(1)
+             row.append(e[1])
+             col.append(e[0])
+             value.append(1)
+
+     nm = sp.coo_matrix((value, (row, col)), shape=(vnum, vnum), dtype=np.float32)
+     return nm
+
+
+ def construct_laplacian_matrix(vnum, faces, normalized=False):
+     edges = get_edge_unique(faces)
+     neighbors = get_neighbors(edges)
+
+     col = np.concatenate(neighbors)
+     row = np.concatenate([[i] * len(n)
+                           for i, n in enumerate(neighbors)])
+     col = np.concatenate([col, np.arange(0, vnum)])
+     row = np.concatenate([row, np.arange(0, vnum)])
+
+     if normalized:
+         data = [[1.0 / len(n)] * len(n) for n in neighbors]
+         data += [[-1.0] * vnum]
+     else:
+         data = [[1.0] * len(n) for n in neighbors]
+         data += [[-len(n) for n in neighbors]]
+
+     data = np.concatenate(data)
+     # create the sparse matrix
+     matrix = sp.coo_matrix((data, (row, col)), shape=[vnum] * 2)
+     return matrix
+
+
+ def rotationx_4x4(theta):
+     return np.array([
+         [1.0, 0.0, 0.0, 0.0],
+         [0.0, np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi), 0.0],
+         [0.0, -np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi), 0.0],
+         [0.0, 0.0, 0.0, 1.0]
+     ])
+
+
+ def rotationy_4x4(theta):
+     return np.array([
+         [np.cos(theta / 180 * np.pi), 0.0, np.sin(theta / 180 * np.pi), 0.0],
+         [0.0, 1.0, 0.0, 0.0],
+         [-np.sin(theta / 180 * np.pi), 0.0, np.cos(theta / 180 * np.pi), 0.0],
+         [0.0, 0.0, 0.0, 1.0]
+     ])
+
+
+ def rotationz_4x4(theta):
+     return np.array([
+         [np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi), 0.0, 0.0],
+         [-np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi), 0.0, 0.0],
+         [0.0, 0.0, 1.0, 0.0],
+         [0.0, 0.0, 0.0, 1.0]
+     ])
+
+
+ def rotationx_3x3(theta):
+     return np.array([
+         [1.0, 0.0, 0.0],
+         [0.0, np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi)],
+         [0.0, -np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi)],
+     ])
+
+
+ def rotationy_3x3(theta):
+     return np.array([
+         [np.cos(theta / 180 * np.pi), 0.0, np.sin(theta / 180 * np.pi)],
+         [0.0, 1.0, 0.0],
+         [-np.sin(theta / 180 * np.pi), 0.0, np.cos(theta / 180 * np.pi)],
+     ])
+
+
+ def rotationz_3x3(theta):
+     return np.array([
+         [np.cos(theta / 180 * np.pi), np.sin(theta / 180 * np.pi), 0.0],
+         [-np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi), 0.0],
+         [0.0, 0.0, 1.0],
+     ])
+
+
+ def generate_point_grids(vol_res):
+     x_coords = np.array(range(0, vol_res), dtype=np.float32)
+     y_coords = np.array(range(0, vol_res), dtype=np.float32)
+     z_coords = np.array(range(0, vol_res), dtype=np.float32)
+
+     yv, xv, zv = np.meshgrid(x_coords, y_coords, z_coords)
+     xv = np.reshape(xv, (-1, 1))
+     yv = np.reshape(yv, (-1, 1))
+     zv = np.reshape(zv, (-1, 1))
+     pts = np.concatenate([xv, yv, zv], axis=-1)
+     pts = pts.astype(np.float32)
+     return pts
+
+
+ def infer_occupancy_value_grid_octree(test_res, pts, query_fn, init_res=64, ignore_thres=0.05):
+     pts = np.reshape(pts, (test_res, test_res, test_res, 3))
+
+     pts_ov = np.zeros([test_res, test_res, test_res])
+     dirty = np.ones_like(pts_ov, dtype=bool)  # np.bool was removed in recent NumPy
+     grid_mask = np.zeros_like(pts_ov, dtype=bool)
+
+     reso = test_res // init_res
+     while reso > 0:
+         grid_mask[0:test_res:reso, 0:test_res:reso, 0:test_res:reso] = True
+         test_mask = np.logical_and(grid_mask, dirty)
+
+         pts_ = pts[test_mask]
+         pts_ov[test_mask] = np.reshape(query_fn(pts_), pts_ov[test_mask].shape)
+
+         if reso <= 1:
+             break
+         for x in range(0, test_res - reso, reso):
+             for y in range(0, test_res - reso, reso):
+                 for z in range(0, test_res - reso, reso):
+                     # if center marked, return
+                     if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]:
+                         continue
+                     v0 = pts_ov[x, y, z]
+                     v1 = pts_ov[x, y, z + reso]
+                     v2 = pts_ov[x, y + reso, z]
+                     v3 = pts_ov[x, y + reso, z + reso]
+                     v4 = pts_ov[x + reso, y, z]
+                     v5 = pts_ov[x + reso, y, z + reso]
+                     v6 = pts_ov[x + reso, y + reso, z]
+                     v7 = pts_ov[x + reso, y + reso, z + reso]
+                     v = np.array([v0, v1, v2, v3, v4, v5, v6, v7])
+                     v_min = np.min(v)
+                     v_max = np.max(v)
+                     # this cell is all the same
+                     if (v_max - v_min) < ignore_thres:
+                         pts_ov[x:x + reso, y:y + reso, z:z + reso] = (v_max + v_min) / 2
+                         dirty[x:x + reso, y:y + reso, z:z + reso] = False
+         reso //= 2
+     return pts_ov
+
+
+ def batch_rod2quat(rot_vecs):
+     batch_size = rot_vecs.shape[0]
+
+     angle = torch.norm(rot_vecs + 1e-16, dim=1, keepdim=True)
+     rot_dir = rot_vecs / angle
+
+     cos = torch.cos(angle / 2)
+     sin = torch.sin(angle / 2)
+
+     # Bx1 arrays
+     rx, ry, rz = torch.split(rot_dir, 1, dim=1)
+
+     qx = rx * sin
+     qy = ry * sin
+     qz = rz * sin
+     qw = cos - 1.0
+
+     return torch.cat([qx, qy, qz, qw], dim=1)
+
+
+ def batch_quat2matrix(rvec):
+     '''
+     args:
+         rvec: (B, N, 4)
+     '''
+     B, N, _ = rvec.size()
+
+     theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=2))
+     rvec = rvec / theta[:, :, None]
+     return torch.stack((
+         1. - 2. * rvec[:, :, 1] ** 2 - 2. * rvec[:, :, 2] ** 2,
+         2. * (rvec[:, :, 0] * rvec[:, :, 1] - rvec[:, :, 2] * rvec[:, :, 3]),
+         2. * (rvec[:, :, 0] * rvec[:, :, 2] + rvec[:, :, 1] * rvec[:, :, 3]),
+
+         2. * (rvec[:, :, 0] * rvec[:, :, 1] + rvec[:, :, 2] * rvec[:, :, 3]),
+         1. - 2. * rvec[:, :, 0] ** 2 - 2. * rvec[:, :, 2] ** 2,
+         2. * (rvec[:, :, 1] * rvec[:, :, 2] - rvec[:, :, 0] * rvec[:, :, 3]),
+
+         2. * (rvec[:, :, 0] * rvec[:, :, 2] - rvec[:, :, 1] * rvec[:, :, 3]),
+         2. * (rvec[:, :, 0] * rvec[:, :, 3] + rvec[:, :, 1] * rvec[:, :, 2]),
+         1. - 2. * rvec[:, :, 0] ** 2 - 2. * rvec[:, :, 1] ** 2
+     ), dim=2).view(B, N, 3, 3)
+
+
+ def get_posemap(map_type, n_joints, parents, n_traverse=1, normalize=True):
+     pose_map = torch.zeros(n_joints, n_joints - 1)
+     if map_type == 'parent':
+         for i in range(n_joints - 1):
+             pose_map[i + 1, i] = 1.0
+     elif map_type == 'children':
+         for i in range(n_joints - 1):
+             parent = parents[i + 1]
+             for j in range(n_traverse):
+                 pose_map[parent, i] += 1.0
+                 if parent == 0:
+                     break
+                 parent = parents[parent]
+         if normalize:
+             pose_map /= pose_map.sum(0, keepdim=True) + 1e-16
+     elif map_type == 'both':
+         for i in range(n_joints - 1):
+             pose_map[i + 1, i] += 1.0
+             parent = parents[i + 1]
+             for j in range(n_traverse):
+                 pose_map[parent, i] += 1.0
+                 if parent == 0:
+                     break
+                 parent = parents[parent]
+         if normalize:
+             pose_map /= pose_map.sum(0, keepdim=True) + 1e-16
+     else:
+         raise NotImplementedError('unsupported pose map type [%s]' % map_type)
+     pose_map = torch.cat([torch.zeros(n_joints, 1), pose_map], dim=1)
+     return pose_map
+
+
+ def vertices_to_triangles(vertices, faces):
+     """
+     :param vertices: [batch size, number of vertices, 3]
+     :param faces: [batch size, number of faces, 3]
+     :return: [batch size, number of faces, 3, 3]
+     """
+     assert (vertices.ndimension() == 3)
+     assert (faces.ndimension() == 3)
+     assert (vertices.shape[0] == faces.shape[0])
+     assert (vertices.shape[2] == 3)
+     assert (faces.shape[2] == 3)
+
+     bs, nv = vertices.shape[:2]
+     bs, nf = faces.shape[:2]
+     device = vertices.device
+     faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+     vertices = vertices.reshape((bs * nv, 3))
+     # pytorch only supports long and byte tensors for indexing
+     return vertices[faces.long()]
+
+
+ def calc_face_normals(vertices, faces):
+     assert len(vertices.shape) == 3
+     assert len(faces.shape) == 2
+     if isinstance(faces, np.ndarray):
+         faces = torch.from_numpy(faces.astype(np.int64)).to(vertices.device)
+
+     batch_size, pt_num = vertices.shape[:2]
+     face_num = faces.shape[0]
+
+     triangles = vertices_to_triangles(vertices, faces.unsqueeze(0).expand(batch_size, -1, -1))
+     triangles = triangles.reshape((batch_size * face_num, 3, 3))
+     v10 = triangles[:, 0] - triangles[:, 1]
+     v12 = triangles[:, 2] - triangles[:, 1]
+     # pytorch normalize divides by max(norm, eps) instead of (norm+eps) in chainer
+     normals = F.normalize(torch.cross(v10, v12), eps=1e-5)
+     normals = normals.reshape((batch_size, face_num, 3))
+
+     return normals
+
+
+ def calc_vert_normals(vertices, faces):
+     """
+     vertices: [B, N, 3]
+     faces: [F, 3]
+     """
+     normals = torch.zeros_like(vertices)
+     v0s = torch.index_select(vertices, dim=1, index=faces[:, 0])  # [B, F, 3]
+     v1s = torch.index_select(vertices, dim=1, index=faces[:, 1])
+     v2s = torch.index_select(vertices, dim=1, index=faces[:, 2])
+     normals = torch.index_add(normals, dim=1, index=faces[:, 1], source=torch.cross(v2s-v1s, v0s-v1s, dim=-1))
+     normals = torch.index_add(normals, dim=1, index=faces[:, 2], source=torch.cross(v0s-v2s, v1s-v2s, dim=-1))
+     normals = torch.index_add(normals, dim=1, index=faces[:, 0], source=torch.cross(v1s-v0s, v2s-v0s, dim=-1))
+     normals = F.normalize(normals, dim=-1)
+     return normals
+
+
+ def calc_vert_normals_numpy(vertices, faces):
+     assert len(vertices.shape) == 2
+     assert len(faces.shape) == 2
+
+     nmls = np.zeros_like(vertices)
+     fv0 = vertices[faces[:, 0]]
+     fv1 = vertices[faces[:, 1]]
+     fv2 = vertices[faces[:, 2]]
+     face_nmls = np.cross(fv1-fv0, fv2-fv0, axis=-1)
+     face_nmls = face_nmls / (np.linalg.norm(face_nmls, axis=-1, keepdims=True) + 1e-20)
+     for f, fn in zip(faces, face_nmls):
+         nmls[f] += fn
+     nmls = nmls / (np.linalg.norm(nmls, axis=-1, keepdims=True) + 1e-20)
+     return nmls
+
+
+ def glUV2torchUV(gl_uv):
+     torch_uv = torch.stack([
+         gl_uv[..., 0]*2.0-1.0,
+         gl_uv[..., 1]*-2.0+1.0
+     ], dim=-1)
+     return torch_uv
+
+
+ def normalize_vert_bbox(verts, dim=-1, per_axis=False):
+     bbox_min = torch.min(verts, dim=dim, keepdim=True)[0]
+     bbox_max = torch.max(verts, dim=dim, keepdim=True)[0]
+     verts = verts - 0.5 * (bbox_max + bbox_min)
+     if per_axis:
+         verts = 2 * verts / (bbox_max - bbox_min)
+     else:
+         verts = 2 * verts / torch.max(bbox_max-bbox_min, dim=dim, keepdim=True)[0]
+     return verts
+
+
+ def upsample_sdf_volume(sdf, upsample_factor):
+     assert sdf.shape[0] == sdf.shape[1] == sdf.shape[2]
+     coarse_resolution = sdf.shape[0]
+     fine_resolution = coarse_resolution * upsample_factor
+
+     sdf_interp_buffer = np.zeros([2, 2, 2, coarse_resolution, coarse_resolution, coarse_resolution],
+                                  dtype=np.float32)
+     dx_list = [0, 1, 0, 1, 0, 1, 0, 1]
+     dy_list = [0, 0, 1, 1, 0, 0, 1, 1]
+     dz_list = [0, 0, 0, 0, 1, 1, 1, 1]
+     for dx, dy, dz in zip(dx_list, dy_list, dz_list):
+         sdf_interp_buffer[dx, dy, dz, :, :, :] = np.roll(sdf, (-dx, -dy, -dz), axis=(0, 1, 2))
+
+     sdf_fine = np.zeros([fine_resolution, fine_resolution, fine_resolution], dtype=np.float32)
+     for dx in range(upsample_factor):
+         for dy in range(upsample_factor):
+             for dz in range(upsample_factor):
+                 wx = (1.0 - dx / upsample_factor)
+                 wy = (1.0 - dy / upsample_factor)
+                 wz = (1.0 - dz / upsample_factor)
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     wx * wy * wz * sdf_interp_buffer[0, 0, 0]
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     (1.0 - wx) * wy * wz * sdf_interp_buffer[1, 0, 0]
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     wx * (1.0 - wy) * wz * sdf_interp_buffer[0, 1, 0]
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     (1.0 - wx) * (1.0 - wy) * wz * sdf_interp_buffer[1, 1, 0]
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     wx * wy * (1.0 - wz) * sdf_interp_buffer[0, 0, 1]
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     (1.0 - wx) * wy * (1.0 - wz) * sdf_interp_buffer[1, 0, 1]
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     wx * (1.0 - wy) * (1.0 - wz) * sdf_interp_buffer[0, 1, 1]
+                 sdf_fine[dx::upsample_factor, dy::upsample_factor, dz::upsample_factor] += \
+                     (1.0 - wx) * (1.0 - wy) * (1.0 - wz) * sdf_interp_buffer[1, 1, 1]
+     return sdf_fine
+
+
+ def search_nearest_correspondence(src, tgt):
+     tgt_idx = np.zeros(len(src), dtype=np.int32)
+     tgt_dist = np.zeros(len(src), dtype=np.float32)
+     for i in range(len(src)):
+         dist = np.linalg.norm(tgt - src[i:(i+1)], axis=1, keepdims=False)
+         tgt_idx[i] = np.argmin(dist)
+         tgt_dist[i] = dist[tgt_idx[i]]
+     return tgt_idx, tgt_dist
+
+
+ def estimate_rigid_transformation(src, tgt):
+     src = src.transpose()
+     tgt = tgt.transpose()
+     mu1, mu2 = src.mean(axis=1, keepdims=True), tgt.mean(axis=1, keepdims=True)
+     X1, X2 = src - mu1, tgt - mu2
+
+     K = X1.dot(X2.T)
+     U, s, Vh = np.linalg.svd(K)
+     V = Vh.T
+     Z = np.eye(U.shape[0])
+     Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
+     R = V.dot(Z.dot(U.T))
+     t = mu2 - R.dot(mu1)
+
+     # orient, _ = cv.Rodrigues(R)
+     # orient = orient.reshape([-1])
+     # t = t.reshape([-1])
+     # return orient, t
+
+     transf = np.eye(4, dtype=np.float32)
+     transf[:3, :3] = R
+     transf[:3, 3] = t.reshape([-1])
+     return transf
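`estimate_rigid_transformation` is the standard SVD-based (Kabsch) fit of a rotation and translation between matched point sets, with a determinant sign flip to exclude reflections. A quick sanity check under an assumed synthetic transform (a sketch, not part of the commit diff):

import numpy as np
from render_utils.lib.utils.geometry import estimate_rigid_transformation, rotationz_3x3

src = np.random.randn(100, 3)
R_gt = rotationz_3x3(30.0)                        # 30-degree rotation about z
t_gt = np.array([0.1, -0.2, 0.3])
tgt = src @ R_gt.T + t_gt                         # row-vector convention: tgt_i = R_gt @ src_i + t_gt
transf = estimate_rigid_transformation(src, tgt)  # 4x4 homogeneous transform
assert np.allclose(transf[:3, :3], R_gt, atol=1e-4)   # translation is recovered in transf[:3, 3]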
render_utils/lib/utils/graphics_utils.py ADDED
@@ -0,0 +1,181 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact  george.drettakis@inria.fr
+ #
+
+ import torch
+ import math
+ import numpy as np
+ from typing import NamedTuple
+
+ class BasicPointCloud(NamedTuple):
+     points: np.array
+     colors: np.array
+     normals: np.array
+
+ def geom_transform_points(points, transf_matrix):
+     P, _ = points.shape
+     ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
+     points_hom = torch.cat([points, ones], dim=1)
+     points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
+
+     denom = points_out[..., 3:] + 0.0000001
+     return (points_out[..., :3] / denom).squeeze(dim=0)
+
+ def getWorld2View(R, t):
+     Rt = np.zeros((4, 4))
+     Rt[:3, :3] = R.transpose()
+     Rt[:3, 3] = t
+     Rt[3, 3] = 1.0
+     return np.float32(Rt)
+
+ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+     Rt = np.zeros((4, 4))
+     Rt[:3, :3] = R.transpose()
+     Rt[:3, 3] = t
+     Rt[3, 3] = 1.0
+
+     C2W = np.linalg.inv(Rt)
+     cam_center = C2W[:3, 3]
+     cam_center = (cam_center + translate) * scale
+     C2W[:3, 3] = cam_center
+     Rt = np.linalg.inv(C2W)
+     return np.float32(Rt)
+
+ def getProjectionMatrix(znear, zfar, fovX, fovY, K=None, img_h=None, img_w=None):
+     if K is None:
+         tanHalfFovY = math.tan((fovY / 2))
+         tanHalfFovX = math.tan((fovX / 2))
+         top = tanHalfFovY * znear
+         bottom = -top
+         right = tanHalfFovX * znear
+         left = -right
+     else:
+         near_fx = znear / K[0, 0]
+         near_fy = znear / K[1, 1]
+
+         left = - (img_w - K[0, 2]) * near_fx
+         right = K[0, 2] * near_fx
+         bottom = (K[1, 2] - img_h) * near_fy
+         top = K[1, 2] * near_fy
+
+     P = torch.zeros(4, 4)
+
+     z_sign = 1.0
+
+     P[0, 0] = 2.0 * znear / (right - left)
+     P[1, 1] = 2.0 * znear / (top - bottom)
+     P[0, 2] = (right + left) / (right - left)
+     P[1, 2] = (top + bottom) / (top - bottom)
+     P[3, 2] = z_sign
+     P[2, 2] = z_sign * zfar / (zfar - znear)
+     P[2, 3] = -(zfar * znear) / (zfar - znear)
+     return P
+
+ def fov2focal(fov, pixels):
+     return pixels / (2 * math.tan(fov / 2))
+
+ def focal2fov(focal, pixels):
+     return 2 * math.atan(pixels / (2 * focal))
+
+
+ def _so3_exp_map(
+     log_rot: torch.Tensor, eps: float = 0.0001
+ ):
+     """
+     A helper function that computes the so3 exponential map and,
+     apart from the rotation matrix, also returns intermediate variables
+     that can be re-used in other functions.
+     """
+
+     def hat(v: torch.Tensor) -> torch.Tensor:
+         """
+         Compute the Hat operator [1] of a batch of 3D vectors.
+
+         Args:
+             v: Batch of vectors of shape `(minibatch, 3)`.
+
+         Returns:
+             Batch of skew-symmetric matrices of shape
+             `(minibatch, 3, 3)` where each matrix is of the form:
+                 `[    0  -v_z   v_y ]
+                  [  v_z     0  -v_x ]
+                  [ -v_y   v_x     0 ]`
+
+         Raises:
+             ValueError if `v` is of incorrect shape.
+
+         [1] https://en.wikipedia.org/wiki/Hat_operator
+         """
+
+         N, dim = v.shape
+         if dim != 3:
+             raise ValueError("Input vectors have to be 3-dimensional.")
+
+         h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device)
+
+         x, y, z = v.unbind(1)
+
+         h[:, 0, 1] = -z
+         h[:, 0, 2] = y
+         h[:, 1, 0] = z
+         h[:, 1, 2] = -x
+         h[:, 2, 0] = -y
+         h[:, 2, 1] = x
+
+         return h
+
+     _, dim = log_rot.shape
+     if dim != 3:
+         raise ValueError("Input tensor shape has to be Nx3.")
+
+     nrms = (log_rot * log_rot).sum(1)
+     # phis ... rotation angles
+     rot_angles = torch.clamp(nrms, eps).sqrt()
+     rot_angles_inv = 1.0 / rot_angles
+     fac1 = rot_angles_inv * rot_angles.sin()
+     fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
+     skews = hat(log_rot)
+     skews_square = torch.bmm(skews, skews)
+
+     R = (
+         # pyre-fixme[16]: `float` has no attribute `__getitem__`.
+         fac1[:, None, None] * skews
+         + fac2[:, None, None] * skews_square
+         + torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
+     )
+
+     return R, rot_angles, skews, skews_square
+
+
+ def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
+     """
+     Convert a batch of logarithmic representations of rotation matrices `log_rot`
+     to a batch of 3x3 rotation matrices using Rodrigues formula [1].
+
+     In the logarithmic representation, each rotation matrix is represented as
+     a 3-dimensional vector (`log_rot`) whose l2-norm and direction correspond
+     to the magnitude of the rotation angle and the axis of rotation respectively.
+
+     The conversion has a singularity around `log(R) = 0`
+     which is handled by clamping controlled with the `eps` argument.
+
+     Args:
+         log_rot: Batch of vectors of shape `(minibatch, 3)`.
+         eps: A float constant handling the conversion singularity.
+
+     Returns:
+         Batch of rotation matrices of shape `(minibatch, 3, 3)`.
+
+     Raises:
+         ValueError if `log_rot` is of incorrect shape.
+
+     [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
+     """
+     return _so3_exp_map(log_rot, eps=eps)[0]
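For orientation, a small sketch (assumed intrinsics, not part of the commit diff) of building the OpenGL-style projection from a pinhole camera with `getProjectionMatrix` and `focal2fov`:

import numpy as np
from render_utils.lib.utils.graphics_utils import getProjectionMatrix, focal2fov

img_w, img_h = 1280, 720
K = np.array([[1000.0,    0.0, 640.0],
              [   0.0, 1000.0, 360.0],
              [   0.0,    0.0,   1.0]])      # assumed pinhole intrinsics
fovx = focal2fov(K[0, 0], img_w)
fovy = focal2fov(K[1, 1], img_h)
P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovx, fovY=fovy,
                        K=K, img_h=img_h, img_w=img_w)   # 4x4 torch.Tensor

Passing `K` honors a principal point that is off-center, which the FoV-only path cannot express.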
render_utils/lib/utils/rotation_conversions.py ADDED
@@ -0,0 +1,586 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+
+ Device = Union[str, torch.device]
+
+ """
+ The transformation matrices returned from the functions in this file assume
+ the points on which the transformation will be applied are column vectors.
+ i.e. the R matrix is structured as
+
+     R = [
+             [Rxx, Rxy, Rxz],
+             [Ryx, Ryy, Ryz],
+             [Rzx, Rzy, Rzz],
+         ]  # (3, 3)
+
+ This matrix can be applied to column vectors by post multiplication
+ by the points e.g.
+
+     points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
+     transformed_points = R * points
+
+ To apply the same matrix to points which are row vectors, the R matrix
+ can be transposed and pre multiplied by the points:
+
+ e.g.
+     points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
+     transformed_points = points * R.transpose(1, 0)
+ """
+
+
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as quaternions to rotation matrices.
+
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     r, i, j, k = torch.unbind(quaternions, -1)
+     two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+     o = torch.stack(
+         (
+             1 - two_s * (j * j + k * k),
+             two_s * (i * j - k * r),
+             two_s * (i * k + j * r),
+             two_s * (i * j + k * r),
+             1 - two_s * (i * i + k * k),
+             two_s * (j * k - i * r),
+             two_s * (i * k - j * r),
+             two_s * (j * k + i * r),
+             1 - two_s * (i * i + j * j),
+         ),
+         -1,
+     )
+     return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+ def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+     """
+     Return a tensor where each element has the absolute value taken from the
+     corresponding element of a, with sign taken from the corresponding
+     element of b. This is like the standard copysign floating-point operation,
+     but is not careful about negative 0 and NaN.
+
+     Args:
+         a: source tensor.
+         b: tensor whose signs will be used, of the same shape as a.
+
+     Returns:
+         Tensor of the same shape as a with the signs of b.
+     """
+     signs_differ = (a < 0) != (b < 0)
+     return torch.where(signs_differ, -a, a)
+
+
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+     """
+     Returns torch.sqrt(torch.max(0, x))
+     but with a zero subgradient where x is 0.
+     """
+     ret = torch.zeros_like(x)
+     positive_mask = x > 0
+     ret[positive_mask] = torch.sqrt(x[positive_mask])
+     return ret
+
+
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as rotation matrices to quaternions.
+
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+     Returns:
+         quaternions with real part first, as tensor of shape (..., 4).
+     """
+     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+     batch_dim = matrix.shape[:-2]
+     m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+         matrix.reshape(batch_dim + (9,)), dim=-1
+     )
+
+     q_abs = _sqrt_positive_part(
+         torch.stack(
+             [
+                 1.0 + m00 + m11 + m22,
+                 1.0 + m00 - m11 - m22,
+                 1.0 - m00 + m11 - m22,
+                 1.0 - m00 - m11 + m22,
+             ],
+             dim=-1,
+         )
+     )
+
+     # we produce the desired quaternion multiplied by each of r, i, j, k
+     quat_by_rijk = torch.stack(
+         [
+             torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+             torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+             torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+             torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+         ],
+         dim=-2,
+     )
+
+     # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+     # the candidate won't be picked.
+     flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+     quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+     # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+     # forall i; we pick the best-conditioned one (with the largest denominator)
+
+     return quat_candidates[
+         F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :  # pyre-ignore[16]
+     ].reshape(batch_dim + (4,))
+
+
+ def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
+     """
+     Return the rotation matrices for one of the single-axis rotations
+     that Euler angles describe, for each value of the angle given.
+
+     Args:
+         axis: Axis label "X", "Y", or "Z".
+         angle: any shape tensor of Euler angles in radians
+
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+
+     cos = torch.cos(angle)
+     sin = torch.sin(angle)
+     one = torch.ones_like(angle)
+     zero = torch.zeros_like(angle)
+
+     if axis == "X":
+         R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+     elif axis == "Y":
+         R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+     elif axis == "Z":
+         R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+     else:
+         raise ValueError("letter must be either X, Y or Z.")
+
+     return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
+     """
+     Convert rotations given as Euler angles in radians to rotation matrices.
+
+     Args:
+         euler_angles: Euler angles in radians as tensor of shape (..., 3).
+         convention: Convention string of three uppercase letters from
+             {"X", "Y", and "Z"}.
+
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+         raise ValueError("Invalid input euler angles.")
+     if len(convention) != 3:
+         raise ValueError("Convention must have 3 letters.")
+     if convention[1] in (convention[0], convention[2]):
+         raise ValueError(f"Invalid convention {convention}.")
+     for letter in convention:
+         if letter not in ("X", "Y", "Z"):
+             raise ValueError(f"Invalid letter {letter} in convention string.")
+     matrices = [
+         _axis_angle_rotation(c, e)
+         for c, e in zip(convention, torch.unbind(euler_angles, -1))
+     ]
+     # return functools.reduce(torch.matmul, matrices)
+     return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
+
+
+ def _angle_from_tan(
+     axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+ ) -> torch.Tensor:
+     """
+     Extract the first or third Euler angle from the two members of
+     the matrix which are positive constant times its sine and cosine.
+
+     Args:
+         axis: Axis label "X", "Y", or "Z" for the angle we are finding.
+         other_axis: Axis label "X", "Y", or "Z" for the middle axis in the
+             convention.
+         data: Rotation matrices as tensor of shape (..., 3, 3).
+         horizontal: Whether we are looking for the angle for the third axis,
+             which means the relevant entries are in the same row of the
+             rotation matrix. If not, they are in the same column.
+         tait_bryan: Whether the first and third axes in the convention differ.
+
+     Returns:
+         Euler Angles in radians for each matrix in data as a tensor
+         of shape (...).
+     """
+
+     i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+     if horizontal:
+         i2, i1 = i1, i2
+     even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+     if horizontal == even:
+         return torch.atan2(data[..., i1], data[..., i2])
+     if tait_bryan:
+         return torch.atan2(-data[..., i2], data[..., i1])
+     return torch.atan2(data[..., i2], -data[..., i1])
+
+
+ def _index_from_letter(letter: str) -> int:
+     if letter == "X":
+         return 0
+     if letter == "Y":
+         return 1
+     if letter == "Z":
+         return 2
+     raise ValueError("letter must be either X, Y or Z.")
+
+
+ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
+     """
+     Convert rotations given as rotation matrices to Euler angles in radians.
+
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+         convention: Convention string of three uppercase letters.
+
+     Returns:
+         Euler angles in radians as tensor of shape (..., 3).
+     """
+     if len(convention) != 3:
+         raise ValueError("Convention must have 3 letters.")
+     if convention[1] in (convention[0], convention[2]):
+         raise ValueError(f"Invalid convention {convention}.")
+     for letter in convention:
+         if letter not in ("X", "Y", "Z"):
+             raise ValueError(f"Invalid letter {letter} in convention string.")
+     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+     i0 = _index_from_letter(convention[0])
+     i2 = _index_from_letter(convention[2])
+     tait_bryan = i0 != i2
+     if tait_bryan:
+         central_angle = torch.asin(
+             matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+         )
+     else:
+         central_angle = torch.acos(matrix[..., i0, i0])
+
+     o = (
+         _angle_from_tan(
+             convention[0], convention[1], matrix[..., i2], False, tait_bryan
+         ),
+         central_angle,
+         _angle_from_tan(
+             convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+         ),
+     )
+     return torch.stack(o, -1)
+
+
+ def random_quaternions(
+     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+ ) -> torch.Tensor:
+     """
+     Generate random quaternions representing rotations,
+     i.e. versors with nonnegative real part.
+
+     Args:
+         n: Number of quaternions in a batch to return.
+         dtype: Type to return.
+         device: Desired device of returned tensor. Default:
+             uses the current device for the default tensor type.
+
+     Returns:
+         Quaternions as tensor of shape (N, 4).
+     """
+     if isinstance(device, str):
+         device = torch.device(device)
+     o = torch.randn((n, 4), dtype=dtype, device=device)
+     s = (o * o).sum(1)
+     o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
+     return o
+
+
+ def random_rotations(
+     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+ ) -> torch.Tensor:
+     """
+     Generate random rotations as 3x3 rotation matrices.
+
+     Args:
+         n: Number of rotation matrices in a batch to return.
+         dtype: Type to return.
+         device: Device of returned tensor. Default: if None,
+             uses the current device for the default tensor type.
+
+     Returns:
+         Rotation matrices as tensor of shape (n, 3, 3).
+     """
+     quaternions = random_quaternions(n, dtype=dtype, device=device)
+     return quaternion_to_matrix(quaternions)
+
+
+ def random_rotation(
+     dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+ ) -> torch.Tensor:
+     """
+     Generate a single random 3x3 rotation matrix.
+
+     Args:
+         dtype: Type to return
+         device: Device of returned tensor. Default: if None,
+             uses the current device for the default tensor type
+
+     Returns:
+         Rotation matrix as tensor of shape (3, 3).
+     """
+     return random_rotations(1, dtype, device)[0]
+
+
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+     """
+     Convert a unit quaternion to a standard form: one in which the real
+     part is non negative.
+
+     Args:
+         quaternions: Quaternions with real part first,
+             as tensor of shape (..., 4).
+
+     Returns:
+         Standardized quaternions as tensor of shape (..., 4).
+     """
+     return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+ def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+     """
+     Multiply two quaternions.
+     Usual torch rules for broadcasting apply.
+
+     Args:
+         a: Quaternions as tensor of shape (..., 4), real part first.
+         b: Quaternions as tensor of shape (..., 4), real part first.
+
+     Returns:
+         The product of a and b, a tensor of quaternions shape (..., 4).
+     """
+     aw, ax, ay, az = torch.unbind(a, -1)
+     bw, bx, by, bz = torch.unbind(b, -1)
+     ow = aw * bw - ax * bx - ay * by - az * bz
+     ox = aw * bx + ax * bw + ay * bz - az * by
+     oy = aw * by - ax * bz + ay * bw + az * bx
+     oz = aw * bz + ax * by - ay * bx + az * bw
+     return torch.stack((ow, ox, oy, oz), -1)
+
+
+ def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+     """
+     Multiply two quaternions representing rotations, returning the quaternion
+     representing their composition, i.e. the versor with nonnegative real part.
+     Usual torch rules for broadcasting apply.
+
+     Args:
+         a: Quaternions as tensor of shape (..., 4), real part first.
+         b: Quaternions as tensor of shape (..., 4), real part first.
+
+     Returns:
+         The product of a and b, a tensor of quaternions of shape (..., 4).
+     """
+     ab = quaternion_raw_multiply(a, b)
+     return standardize_quaternion(ab)
+
+
+ def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
+     """
+     Given a quaternion representing rotation, get the quaternion representing
+     its inverse.
+
+     Args:
+         quaternion: Quaternions as tensor of shape (..., 4), with real part
+             first, which must be versors (unit quaternions).
+
+     Returns:
+         The inverse, a tensor of quaternions of shape (..., 4).
+     """
+
+     scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
+     return quaternion * scaling
+
+
+ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
+     """
+     Apply the rotation given by a quaternion to a 3D point.
+     Usual torch rules for broadcasting apply.
+
+     Args:
+         quaternion: Tensor of quaternions, real part first, of shape (..., 4).
+         point: Tensor of 3D points of shape (..., 3).
+
+     Returns:
+         Tensor of rotated points of shape (..., 3).
+     """
+     if point.size(-1) != 3:
+         raise ValueError(f"Points are not in 3D, {point.shape}.")
+     real_parts = point.new_zeros(point.shape[:-1] + (1,))
+     point_as_quaternion = torch.cat((real_parts, point), -1)
+     out = quaternion_raw_multiply(
+         quaternion_raw_multiply(quaternion, point_as_quaternion),
+         quaternion_invert(quaternion),
+     )
+     return out[..., 1:]
+
+
+ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as axis/angle to rotation matrices.
+
+     Args:
+         axis_angle: Rotations given as a vector in axis angle form,
+             as a tensor of shape (..., 3), where the magnitude is
+             the angle turned anticlockwise in radians around the
+             vector's direction.
+
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as rotation matrices to axis/angle.
+
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+     Returns:
+         Rotations given as a vector in axis angle form, as a tensor
+             of shape (..., 3), where the magnitude is the angle
+             turned anticlockwise in radians around the vector's
+             direction.
+     """
+     return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+
+ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as axis/angle to quaternions.
+
+     Args:
+         axis_angle: Rotations given as a vector in axis angle form,
+             as a tensor of shape (..., 3), where the magnitude is
+             the angle turned anticlockwise in radians around the
+             vector's direction.
+
+     Returns:
+         quaternions with real part first, as tensor of shape (..., 4).
+     """
+     angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+     half_angles = angles * 0.5
+     eps = 1e-6
+     small_angles = angles.abs() < eps
+     sin_half_angles_over_angles = torch.empty_like(angles)
+     sin_half_angles_over_angles[~small_angles] = (
+         torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+     )
+     # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+     # so sin(x/2)/x is about 1/2 - (x*x)/48
+     sin_half_angles_over_angles[small_angles] = (
+         0.5 - (angles[small_angles] * angles[small_angles]) / 48
+     )
+     quaternions = torch.cat(
+         [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+     )
+     return quaternions
+
+
+ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as quaternions to axis/angle.
+
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+
+     Returns:
+         Rotations given as a vector in axis angle form, as a tensor
+             of shape (..., 3), where the magnitude is the angle
+             turned anticlockwise in radians around the vector's
+             direction.
+     """
+     norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+     half_angles = torch.atan2(norms, quaternions[..., :1])
+     angles = 2 * half_angles
+     eps = 1e-6
+     small_angles = angles.abs() < eps
+     sin_half_angles_over_angles = torch.empty_like(angles)
+     sin_half_angles_over_angles[~small_angles] = (
+         torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+     )
+     # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+     # so sin(x/2)/x is about 1/2 - (x*x)/48
+     sin_half_angles_over_angles[small_angles] = (
+         0.5 - (angles[small_angles] * angles[small_angles]) / 48
+     )
+     return quaternions[..., 1:] / sin_half_angles_over_angles
+
+
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+     """
+     Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+     using Gram--Schmidt orthogonalization per Section B of [1].
+     Args:
+         d6: 6D rotation representation, of size (*, 6)
+
+     Returns:
+         batch of rotation matrices of size (*, 3, 3)
+
+     [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+     On the Continuity of Rotation Representations in Neural Networks.
+     IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+     Retrieved from http://arxiv.org/abs/1812.07035
+     """
+
+     a1, a2 = d6[..., :3], d6[..., 3:]
+     b1 = F.normalize(a1, dim=-1)
+     b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+     b2 = F.normalize(b2, dim=-1)
+     b3 = torch.cross(b1, b2, dim=-1)
+     return torch.stack((b1, b2, b3), dim=-2)
+
+
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+     """
+     Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+     by dropping the last row. Note that 6D representation is not unique.
+     Args:
+         matrix: batch of rotation matrices of size (*, 3, 3)
+
+     Returns:
+         6D rotation representation, of size (*, 6)
+
+     [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+     On the Continuity of Rotation Representations in Neural Networks.
+     IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+     Retrieved from http://arxiv.org/abs/1812.07035
+     """
+     batch_dim = matrix.size()[:-2]
+     return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
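A round-trip sanity check for the conversions above (a sketch, not part of the commit diff); for rotations well away from the pi-angle ambiguity, axis-angle survives a trip through matrix and quaternion form:

import torch
from render_utils.lib.utils.rotation_conversions import (
    axis_angle_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle)

aa = torch.tensor([[0.1, -0.2, 0.3]])          # one axis-angle rotation
R = axis_angle_to_matrix(aa)                   # (1, 3, 3)
aa_back = quaternion_to_axis_angle(matrix_to_quaternion(R))
assert torch.allclose(aa, aa_back, atol=1e-5)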
render_utils/lib/utils/sh_utils.py ADDED
@@ -0,0 +1,118 @@
+ #  Copyright 2021 The PlenOctree Authors.
+ #  Redistribution and use in source and binary forms, with or without
+ #  modification, are permitted provided that the following conditions are met:
+ #
+ #  1. Redistributions of source code must retain the above copyright notice,
+ #  this list of conditions and the following disclaimer.
+ #
+ #  2. Redistributions in binary form must reproduce the above copyright notice,
+ #  this list of conditions and the following disclaimer in the documentation
+ #  and/or other materials provided with the distribution.
+ #
+ #  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ #  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ #  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ #  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ #  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ #  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ #  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ #  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ #  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ #  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ #  POSSIBILITY OF SUCH DAMAGE.
+
+ import torch
+
+ C0 = 0.28209479177387814
+ C1 = 0.4886025119029199
+ C2 = [
+     1.0925484305920792,
+     -1.0925484305920792,
+     0.31539156525252005,
+     -1.0925484305920792,
+     0.5462742152960396
+ ]
+ C3 = [
+     -0.5900435899266435,
+     2.890611442640554,
+     -0.4570457994644658,
+     0.3731763325901154,
+     -0.4570457994644658,
+     1.445305721320277,
+     -0.5900435899266435
+ ]
+ C4 = [
+     2.5033429417967046,
+     -1.7701307697799304,
+     0.9461746957575601,
+     -0.6690465435572892,
+     0.10578554691520431,
+     -0.6690465435572892,
+     0.47308734787878004,
+     -1.7701307697799304,
+     0.6258357354491761,
+ ]
+
+
+ def eval_sh(deg, sh, dirs):
+     """
+     Evaluate spherical harmonics at unit directions
+     using hardcoded SH polynomials.
+     Works with torch/np/jnp.
+     ... Can be 0 or more batch dimensions.
+     Args:
+         deg: int SH deg. Currently, 0-4 supported
+         sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+         dirs: jnp.ndarray unit directions [..., 3]
+     Returns:
+         [..., C]
+     """
+     assert deg <= 4 and deg >= 0
+     coeff = (deg + 1) ** 2
+     assert sh.shape[-1] >= coeff
+
+     result = C0 * sh[..., 0]
+     if deg > 0:
+         x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+         result = (result -
+                   C1 * y * sh[..., 1] +
+                   C1 * z * sh[..., 2] -
+                   C1 * x * sh[..., 3])
+
+         if deg > 1:
+             xx, yy, zz = x * x, y * y, z * z
+             xy, yz, xz = x * y, y * z, x * z
+             result = (result +
+                       C2[0] * xy * sh[..., 4] +
+                       C2[1] * yz * sh[..., 5] +
+                       C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+                       C2[3] * xz * sh[..., 7] +
+                       C2[4] * (xx - yy) * sh[..., 8])
+
+             if deg > 2:
+                 result = (result +
+                           C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+                           C3[1] * xy * z * sh[..., 10] +
+                           C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
+                           C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+                           C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+                           C3[5] * z * (xx - yy) * sh[..., 14] +
+                           C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+                 if deg > 3:
+                     result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+                               C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+                               C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+                               C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+                               C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+                               C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+                               C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+                               C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+                               C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+     return result
+
+ def RGB2SH(rgb):
+     return (rgb - 0.5) / C0
+
+ def SH2RGB(sh):
+     return sh * C0 + 0.5
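At degree 0 the SH expansion is just a scaled constant, so `RGB2SH`/`SH2RGB` invert each other and `eval_sh(0, ...)` reproduces the zero-centered DC color regardless of the view direction. A quick sketch (not part of the commit diff):

import torch
from render_utils.lib.utils.sh_utils import eval_sh, RGB2SH, SH2RGB

rgb = torch.tensor([0.2, 0.5, 0.8])
sh = RGB2SH(rgb)                           # DC coefficients, shape (3,)
assert torch.allclose(SH2RGB(sh), rgb)
dirs = torch.tensor([0.0, 0.0, 1.0])       # ignored at degree 0
vals = eval_sh(0, sh.unsqueeze(-1), dirs)  # C0 * sh == rgb - 0.5
assert torch.allclose(vals, rgb - 0.5)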
render_utils/stitch_body_and_head.py ADDED
@@ -0,0 +1,433 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import trimesh
+ from sklearn.neighbors import KDTree
+ import tqdm
+ import cv2 as cv
+ import os, glob
+
+ from .lib.networks.faceverse_torch import FaceVerseModel
+ from .lib.networks.smpl_torch import SmplTorch
+ from .lib.utils.gaussian_np_utils import GaussianAttributes, load_gaussians_from_ply, save_gaussians_as_ply, \
+     apply_transformation_to_gaussians, combine_gaussians, select_gaussians, update_gaussian_attributes
+ from .lib.utils.geometry import search_nearest_correspondence, estimate_rigid_transformation
+ from .lib.utils.sh_utils import SH2RGB
+
+
+ def process_smpl_head():
+     smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
+     smpl_v_template = smpl.v_template.detach().cpu().numpy()
+     smpl_faces = smpl.faces.detach().cpu().numpy()
+     # head_skinning_weights = smpl.weights[:, 15] + smpl.weights[:, 22] + smpl.weights[:, 23] + smpl.weights[:, 24]
+     #
+     # blend_weight = np.clip(head_skinning_weights * 1.2 - 0.2, 0, 1)
+     # head_ids = np.where(blend_weight > 0)[0]
+     #
+     # np.savez('./data/smplx_head_vidx_and_blendweight.npz', blend_weight=blend_weight, head_ids=head_ids)
+
+     if not os.path.exists('./data/smpl_models/smplx_head_3.obj'):
+         trimesh.Trimesh(vertices=smpl_v_template, faces=smpl_faces).export('./data/smpl_models/smplx_head_3.obj')
+         print('Please cut out the SMPL head manually!')
+         import pdb; pdb.set_trace()
+
+     smplx_head = trimesh.load('./data/smpl_models/smplx_head_3.obj')
+     smplx_to_head_dist = np.zeros([smpl_v_template.shape[0]])
+     for vi, v in enumerate(smpl_v_template):
+         nndist = np.min(np.linalg.norm(v.reshape([1, 3]) - smplx_head.vertices, axis=1))
+         smplx_to_head_dist[vi] = nndist
+
+     head_ids = np.where(smplx_to_head_dist < 0.001)[0]
+     blend_weight = np.exp(-smplx_to_head_dist * smplx_to_head_dist * 2000)  # Gaussian falloff with distance to the head cut
+
+     np.savez('./data/smpl_models/smplx_head_vidx_and_blendweight.npz', blend_weight=blend_weight, head_ids=head_ids)
+
+     return
+
+
+ def load_body_params(path):
+     param = dict(np.load(path))
+     global_orient = param['global_orient']
+     transl = param['transl']
+     body_pose = param['body_pose']
+     betas = param['betas']
+     return global_orient, transl, body_pose, betas
+
+
+ def load_face_params(path):
+     param = dict(np.load(path))
+     pose = param['pose']
+     scale = param['scale']
+     id_coeff = param['id_coeff']
+     exp_coeff = param['exp_coeff']
+     return pose, scale, id_coeff, exp_coeff
+
+
+ def get_smpl_verts_and_head_transformation(smpl, global_orient, body_pose, transl, betas):
+     pose = torch.cat([
+         torch.from_numpy(global_orient.astype(np.float32)),
+         torch.from_numpy(body_pose.astype(np.float32)),
+         torch.zeros([(3+15+15)*3], dtype=torch.float32)], dim=-1)  # jaw+eyes (3) and both hands (15 each) at rest
+     beta = torch.from_numpy(betas.astype(np.float32))
+     verts, skinning_dict = smpl.forward(pose.reshape(1, -1), beta.reshape(1, -1))
+     verts = verts[0].detach().cpu().numpy()
+     head_joint_transfmat = skinning_dict['G'][0, 15].detach().cpu().numpy()  # joint 15 = SMPL-X head joint
+     verts += transl.reshape([1, 3])
+     head_joint_transfmat[:3, 3] += transl.reshape([3])
+     return verts, head_joint_transfmat
+
+
+ def crop_facial_area(faceverse_verts, points, dist_thres=0.025):
+     min_x, min_y, min_z = np.min(faceverse_verts, axis=0)
+     max_x, max_y, max_z = np.max(faceverse_verts, axis=0)
+     pad = dist_thres * 2
+     in_bbox_mask = (points[:, 0] > min_x - pad) * (points[:, 0] < max_x + pad) * \
+                    (points[:, 1] > min_y - pad) * (points[:, 1] < max_y + pad) * \
+                    (points[:, 2] > min_z - pad) * (points[:, 2] < max_z + pad)
+     in_bbox_idx = np.where(in_bbox_mask)[0]
+     facial_points = points[in_bbox_mask]
+     nndist = np.ones([len(facial_points)]) * 1e10
+     for i in tqdm.trange(len(facial_points), desc='calculating facial area'):
+         nndist[i] = np.min(np.linalg.norm(faceverse_verts - facial_points[i:(i+1)], axis=1, keepdims=False))
+     close_to_face_mask = nndist < dist_thres
+     facial_points = facial_points[close_to_face_mask]
+     facial_idx = in_bbox_idx[close_to_face_mask]
+     return facial_points, facial_idx
+
+
+ def crop_facial_area2(smpl_verts, smpl_head_vids, points):
+     min_x, min_y, min_z = np.min(smpl_verts[smpl_head_vids], axis=0)
+     max_x, max_y, max_z = np.max(smpl_verts[smpl_head_vids], axis=0)
+     pad = 0.05
+     in_bbox_mask = (points[:, 0] > min_x - pad) * (points[:, 0] < max_x + pad) * \
+                    (points[:, 1] > min_y - pad) * (points[:, 1] < max_y + pad) * \
+                    (points[:, 2] > min_z - pad) * (points[:, 2] < max_z + pad)
+     in_bbox_idx = np.where(in_bbox_mask)[0]
+     facial_points = points[in_bbox_mask]
+     smpl_head_mask = np.zeros([len(smpl_verts)], dtype=np.bool_)
+     smpl_head_mask[smpl_head_vids] = True
+     close_to_face_mask = np.zeros([len(facial_points)], dtype=np.bool_)
+     for i in tqdm.trange(len(facial_points)):
+         nnid = np.argmin(np.linalg.norm(smpl_verts - facial_points[i:(i+1)], axis=1, keepdims=False))
+         close_to_face_mask[i] = smpl_head_mask[nnid]
+     facial_points = facial_points[close_to_face_mask]
+     facial_idx = in_bbox_idx[close_to_face_mask]
+     return facial_points, facial_idx
+
+
+ def transform_faceverse_to_live_body_space(faceverse_verts, faceverse_to_smplx, head_joint_transfmat):
+     faceverse_verts = np.matmul(faceverse_verts, faceverse_to_smplx[:3, :3].transpose()) + faceverse_to_smplx[:3, 3].reshape(1, 3)
+     faceverse_verts = np.matmul(faceverse_verts, head_joint_transfmat[:3, :3].transpose()) + head_joint_transfmat[:3, 3].reshape(1, 3)
+     return faceverse_verts
+
+
+ def calc_livehead2livebody(head_pose, smplx_to_faceverse, head_joint_transfmat):
+     head_cano2live = np.eye(4, dtype=np.float32)
+     head_cano2live[:3, :3] = cv.Rodrigues(head_pose[:3])[0]
+     head_cano2live[:3, 3] = head_pose[3:]
+     head_live2cano = np.linalg.inv(head_cano2live)
+
+     faceverse_to_smplx = np.linalg.inv(smplx_to_faceverse)
+
+     total_transf = np.eye(4, dtype=np.float32)  # compose: live head -> cano head -> axis flip -> cano SMPL-X -> live body
+     for t in [head_live2cano, np.diag([1, -1, -1, 1]), faceverse_to_smplx, head_joint_transfmat]:
+         total_transf = np.matmul(t, total_transf)
+
+     return total_transf
+
+
+ def get_face_blend_weight(head_facial_points, smpl_verts, sigma=0.015):
+     # dists = np.load('./data/faceverse/smplx_verts_to_faceverse_dist.npy').astype(np.float32)
+     # face_nerf_blend_weight = np.exp(-dists**2/(2*sigma**2))
+     # face_nerf_blend_weight = np.clip(face_nerf_blend_weight*1.2 - 0.1, 0, 1)
+
+     smpl_blend_weight = dict(np.load('./data/smpl_models/smplx_head_vidx_and_blendweight.npz'))['blend_weight']
+
+     corr_idx_, _ = search_nearest_correspondence(head_facial_points, smpl_verts)
+     corr_bw = smpl_blend_weight[corr_idx_]
+
+     for _ in tqdm.trange(10):  # k-NN averaging smooths the transferred blend weights
+         corr_bw_ = np.zeros_like(corr_bw)
+         tree = KDTree(head_facial_points, leaf_size=2)
+         for i in range(len(head_facial_points)):
+             _, idx = tree.query(head_facial_points[i:(i+1)], k=4)
+             corr_bw_[i] = np.mean(corr_bw[idx])
+         corr_bw = np.copy(corr_bw_)
+
+     # corr_bw = np.clip(corr_bw*1.2 - 0.15, 0, 1)
+
+     # with open('./debug/debug_head_facial_bw.obj', 'w') as fp:
+     #     for p, w in zip(head_facial_points, corr_bw):
+     #         fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))
+     # import pdb; pdb.set_trace()
+     return corr_bw
+
+
+ def get_face_blend_weight2(head_facial_points, body_points, body_facial_idx):
+     body_facial_bbox_min = np.min(body_points[body_facial_idx], axis=0)
+     body_facial_bbox_max = np.max(body_points[body_facial_idx], axis=0)
+     body_facial_bbox_min = body_facial_bbox_min - 0.1
+     body_facial_bbox_max = body_facial_bbox_max + 0.1
+     inside_bbox_flag = \
+         np.int32(body_points[:, 0] > body_facial_bbox_min[0]) * \
+         np.int32(body_points[:, 0] < body_facial_bbox_max[0]) * \
+         np.int32(body_points[:, 1] > body_facial_bbox_min[1]) * \
+         np.int32(body_points[:, 1] < body_facial_bbox_max[1]) * \
+         np.int32(body_points[:, 2] > body_facial_bbox_min[2]) * \
+         np.int32(body_points[:, 2] < body_facial_bbox_max[2])
+     point_idx_inside_bbox = np.nonzero(inside_bbox_flag > 0)[0]
+     body_blend_weight = np.zeros([len(body_points)], dtype=np.float32)
+     body_blend_weight[body_facial_idx] = 1
+
+     body_points_in_bbox = body_points[point_idx_inside_bbox]
+     body_blend_weight_in_bbox = body_blend_weight[point_idx_inside_bbox]
+     for _ in tqdm.trange(1, desc='Calculating body facial blend weight'):
+         corr_bw_ = np.zeros_like(body_blend_weight_in_bbox)
+         tree = KDTree(body_points_in_bbox, leaf_size=2)
+         for i in tqdm.trange(len(body_points_in_bbox)):
+             ind = tree.query_radius(body_points_in_bbox[i:(i+1)], r=0.035)
+             corr_bw_[i] = np.mean(body_blend_weight_in_bbox[ind[0]])
+         body_blend_weight_in_bbox = np.copy(corr_bw_)
+     body_blend_weight[point_idx_inside_bbox] = body_blend_weight_in_bbox
+
+     with open('./debug/debug_body_facial_bw.obj', 'w') as fp:
+         for p, w in zip(body_points, body_blend_weight):
+             fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))
+
+     tree = KDTree(body_points, leaf_size=2)
+     corr_bw = np.zeros([len(head_facial_points)], dtype=np.float32)
+     for i in range(len(head_facial_points)):
+         _, idx = tree.query(head_facial_points[i:(i+1)], k=4)
+         corr_bw[i] = np.mean(body_blend_weight[idx])
+     # corr_bw = np.clip(corr_bw*1.2 - 0.15, 0, 1)
+
+     corr_bw_bmin, corr_bw_bmax = np.percentile(corr_bw, 5), np.percentile(corr_bw, 95)
+     corr_bw = np.clip((corr_bw - corr_bw_bmin) / (corr_bw_bmax - corr_bw_bmin), 0, 1)
+     with open('./debug/debug_head_facial_bw.obj', 'w') as fp:
+         for p, w in zip(head_facial_points, corr_bw):
+             fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))
+     return corr_bw
+
+
+ def estimate_color_transfer(head_facial_points, body_facial_points, head_facial_color, body_facial_color, head_facial_opacity):
+     head_facial_color = head_facial_color * 0.28209479177387814 + 0.5  # SH DC coeff (C0) -> RGB
+     body_facial_color = body_facial_color * 0.28209479177387814 + 0.5
+
+     corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
+     corr_color = body_facial_color[corr_idx]
+
+     opacity = 1/(1+np.exp(-head_facial_opacity))  # sigmoid: opacity logits -> [0, 1]
+     weight = np.float32(opacity > 0.35)  # only well-visible Gaussians constrain the transfer
+     head_facial_color = head_facial_color.reshape(len(head_facial_color), 3) * weight.reshape([-1, 1])
+     corr_color = corr_color.reshape(len(corr_color), 3) * weight.reshape([-1, 1])
+
+     head_facial_color = np.concatenate([head_facial_color, np.zeros_like(head_facial_color[:, :1])], axis=1)
+     corr_color = np.concatenate([corr_color, np.zeros_like(corr_color[:, :1])], axis=1)
+
+     transfer = nn.Parameter(torch.eye(4, dtype=torch.float32))  # learnable 4x4 color transform
+     head_facial_color_th = torch.from_numpy(head_facial_color).float()
+     corr_color_th = torch.from_numpy(corr_color).float()
+     weight_th = torch.from_numpy(weight).float()
+     optim = torch.optim.Adam([transfer], lr=1e-2)
+
+     for i in range(500):
+         optim.zero_grad()
+         loss = torch.mean(torch.abs(corr_color_th - torch.matmul(head_facial_color_th, transfer.permute(1, 0))) * weight_th)
+         loss = loss + torch.sum(torch.square(transfer - torch.eye(4, dtype=torch.float32))) * 5e-2  # stay close to identity
+         if i % 25 == 0:
+             print(loss.item())
+         loss.backward()
+         optim.step()
+     transfer = transfer.detach().cpu().numpy()
+     print(transfer)
+
+     # with open('./debug/debug_body_facial_color_updated.obj', 'w') as fp:
+     #     for p, c in zip(body_facial_points, body_facial_color):
+     #         # c = c * 0.28209479177387814 + 0.5
+     #         c = np.clip(c, 0, 1)
+     #         fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))
+     # with open('./debug/debug_head_facial_color_updated.obj', 'w') as fp:
+     #     head_facial_color = np.matmul(head_facial_color, transfer)
+     #     for p, c, w in zip(head_facial_points, head_facial_color, weight):
+     #         if w < 0.1:
+     #             continue
+     #         # c = c * 0.28209479177387814 + 0.5
+     #         c = np.clip(c, 0, 1)
+     #         fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))
+     # import pdb; pdb.set_trace()
+
+     return transfer
+
+
+ def blend_color(head_facial_color, body_facial_color, blend_weight):  # generic linear blend, reused for xyz/opacity/scales
+     blend_weight = blend_weight.reshape([len(blend_weight)] + [1]*(len(head_facial_color.shape)-1))
+     result = head_facial_color * blend_weight + body_facial_color * (1-blend_weight)
+     return result
+
+ def save_body_face_stitching_data(
+         result_path, smplx_to_faceverse, residual_transf, body_nonface_mask, head_nonface_mask,
+         head_facial_idx, body_facial_idx, corr_idx, face_color_bw, color_transfer):
+     # os.makedirs('./data/%s' % result_suffix, exist_ok=True)
+     # np.savez('./data/%s/body_face_blending_param.npz' % result_suffix,
+     #          smplx_to_faceverse=smplx_to_faceverse.astype(np.float32),
+     #          residual_transf=residual_transf.astype(np.float32),
+     #          body_nonface_mask=body_nonface_mask.astype(np.int32),
+     #          head_facial_idx=head_facial_idx.astype(np.int32),
+     #          body_facial_idx=body_facial_idx.astype(np.int32),
+     #          head_body_facial_corr_idx=corr_idx.astype(np.int32),
+     #          face_color_bw=face_color_bw.astype(np.float32),
+     #          color_transfer=color_transfer.astype(np.float32))
+     head_color_bw = np.zeros([len(head_nonface_mask)])
+     head_color_bw[head_facial_idx] = face_color_bw
+     head_corr_idx = np.zeros([len(head_nonface_mask)])
+     head_corr_idx[head_facial_idx] = body_facial_idx[corr_idx]
+     np.savez(result_path,
+              smplx_to_faceverse=smplx_to_faceverse.astype(np.float32),
+              residual_transf=residual_transf.astype(np.float32),
+              body_nonface_mask=body_nonface_mask.astype(np.int32),
+              head_nonface_mask=head_nonface_mask.astype(np.int32),
+              head_facial_idx=head_facial_idx.astype(np.int32),
+              body_facial_idx=body_facial_idx.astype(np.int32),
+              head_body_corr_idx=head_corr_idx.astype(np.int32),
+              head_color_bw=head_color_bw.astype(np.float32),
+              color_transfer=color_transfer.astype(np.float32))
+     return
+
+
+ def manual_refine_facial_cropping(head_facial_points, head_facial_idx, head_facial_colors, body_facial_points, body_facial_idx, body_facial_colors):
+     def _save_points_as_obj(fpath, points, points_color):
+         points_color = np.clip(points_color, 0, 1)
+         with open(fpath, 'w') as fp:
+             for p, c in zip(points, points_color):
+                 fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))
+         return
+
+     _save_points_as_obj('./debug/head_facial_points.obj', head_facial_points, head_facial_colors)
+     _save_points_as_obj('./debug/body_facial_points.obj', body_facial_points, body_facial_colors)
+     # trimesh.Trimesh(vertices=head_facial_points, vertex_colors=head_facial_colors).export('./debug/head_facial_points.obj')
+     # trimesh.Trimesh(vertices=body_facial_points, vertex_colors=body_facial_colors).export('./debug/body_facial_points.obj')
+     if True:
+         print('Saving the facial points cropped by the algorithm. Please remove unnecessary points manually!')
+         import pdb; pdb.set_trace()
+
+     head_facial_points_ = trimesh.load('./debug/head_facial_points.obj').vertices
+     body_facial_points_ = trimesh.load('./debug/body_facial_points.obj').vertices
+     _, head_nndist = search_nearest_correspondence(head_facial_points, head_facial_points_)
+     _, body_nndist = search_nearest_correspondence(body_facial_points, body_facial_points_)
+     head_flag = head_nndist < 1e-4
+     body_flag = body_nndist < 1e-4
+     return head_facial_points[head_flag], head_facial_idx[head_flag], body_facial_points[body_flag], body_facial_idx[body_flag]
+
+
+ def stitch_body_and_head(ref_body_gaussian_path, ref_head_gaussian_path,
+                          ref_body_param_path, ref_head_param_path,
+                          smplx2faceverse_path, result_folder):
+     device = torch.device("cuda")
+
+     body_gaussians = load_gaussians_from_ply(ref_body_gaussian_path)
+     head_gaussians = load_gaussians_from_ply(ref_head_gaussian_path)
+     global_orient, transl, body_pose, betas = load_body_params(ref_body_param_path)
+     head_pose, head_scale, id_coeff, exp_coeff = load_face_params(ref_head_param_path)
+     smplx_to_faceverse = np.load(smplx2faceverse_path)
+     faceverse_to_smplx = np.linalg.inv(smplx_to_faceverse)
+
+     smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
+     smpl_verts, head_joint_transfmat = get_smpl_verts_and_head_transformation(
+         smpl, global_orient, body_pose, transl, betas)
+     smpl_head_vids = dict(np.load('./data/smpl_models/smplx_head_vidx_and_blendweight.npz'))['head_ids']
+     smpl_head_verts = smpl_verts[smpl_head_vids]
+
+     model_dict = np.load('./data/faceverse_models/faceverse_simple_v2.npy', allow_pickle=True).item()
+     faceverse_model = FaceVerseModel(model_dict, batch_size=1)
+     faceverse_model.init_coeff_tensors(
+         id_coeff=torch.from_numpy(id_coeff).reshape([1, -1]).to(device),
+         scale_coeff=torch.from_numpy(head_scale).reshape([1, 1]).to(device),
+     )
+     faceverse_verts = faceverse_model.forward()['v'][0].detach().cpu().numpy()
+     faceverse_verts = transform_faceverse_to_live_body_space(faceverse_verts, faceverse_to_smplx, head_joint_transfmat)
+
+     livehead2livebody = calc_livehead2livebody(
+         head_pose, smplx_to_faceverse, head_joint_transfmat)
+     head_gaussians_xyz = np.matmul(head_gaussians.xyz, livehead2livebody[:3, :3].transpose()) \
+         + livehead2livebody[:3, 3].reshape(1, 3)
+
+     # head_facial_points, head_facial_idx = crop_facial_area(smpl_head_verts, head_gaussians_xyz)
+     # body_facial_points, body_facial_idx = crop_facial_area(smpl_head_verts, body_gaussians.xyz)
+     head_facial_points, head_facial_idx = crop_facial_area2(smpl_verts, smpl_head_vids, head_gaussians_xyz)
+     body_facial_points, body_facial_idx = crop_facial_area2(smpl_verts, smpl_head_vids, body_gaussians.xyz)
+
+     residual_transf = np.eye(4)
+     head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
+         head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
+         body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))
+     while True:
+         for _ in tqdm.trange(4, desc='Fitting residual transformation'):
+             corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
+             corr = body_facial_points[corr_idx]
+             transf = estimate_rigid_transformation(head_facial_points, corr)
+             residual_transf = np.matmul(transf, residual_transf)
+             head_facial_points = np.matmul(head_facial_points, transf[:3, :3].transpose()) + transf[:3, 3].reshape(1, 3)
+         if_crop_well = input('Is the facial cropping good enough? (y/n): ')
+         if if_crop_well == 'y':
+             break
+         else:
+             head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
+                 head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
+                 body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))
+
+     # head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
+     #     head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
+     #     body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))
+
+     # Note: the logic was changed so that ICP and manual cropping repeat until the head is aligned.
+
+
+     print(np.matmul(residual_transf, livehead2livebody))
+     residual_transf = np.matmul(np.linalg.inv(livehead2livebody), np.matmul(residual_transf, livehead2livebody))  # conjugate into head space
+     corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
+
+     # head_gaussians_xyz = np.matmul(head_gaussians_xyz, residual_transf[:3, :3].transpose()) + residual_transf[:3, 3].reshape(1, 3)
+     # faceverse_verts = np.matmul(faceverse_verts, residual_transf[:3, :3].transpose()) + residual_transf[:3, 3].reshape(1, 3)
+
+     # total_transf = np.matmul(residual_transf, livehead2livebody)
+     total_transf = np.matmul(livehead2livebody, residual_transf)
+     print(total_transf)
+
+     color_transfer = estimate_color_transfer(
+         head_facial_points, body_facial_points,
+         head_gaussians.features_dc[head_facial_idx], body_gaussians.features_dc[body_facial_idx],
+         head_gaussians.opacities[head_facial_idx]
+     )
+     # face_color_bw = get_face_blend_weight(head_facial_points, smpl_verts, sigma=0.015)
+     face_color_bw = get_face_blend_weight2(head_facial_points, body_gaussians.xyz, body_facial_idx)
+
+     body_nonface_mask = np.ones([len(body_gaussians.xyz)], dtype=np.bool_)
+     body_nonface_mask[body_facial_idx] = 0
+     head_nonface_mask = np.ones([len(head_gaussians.xyz)], dtype=np.bool_)
+     head_nonface_mask[head_facial_idx] = 0
+
+     save_body_face_stitching_data(
+         os.path.join(result_folder, 'body_head_blending_param.npz'),
+         smplx_to_faceverse, residual_transf, body_nonface_mask, head_nonface_mask,
+         head_facial_idx, body_facial_idx, corr_idx, face_color_bw, color_transfer)
+
+     body_gaussians = apply_transformation_to_gaussians(body_gaussians, np.eye(4))
+     head_gaussians = apply_transformation_to_gaussians(head_gaussians, total_transf, np.eye(3))
+     body_gaussians_wo_face = select_gaussians(body_gaussians, body_nonface_mask)
+     head_gaussians_face_only = select_gaussians(head_gaussians, head_facial_idx)
+     head_gaussians_face_only_new_color = blend_color(
+         head_gaussians_face_only.features_dc, body_gaussians.features_dc[body_facial_idx][corr_idx], face_color_bw)
+     head_gaussians_face_only_new_xyz = blend_color(
+         head_gaussians_face_only.xyz, body_gaussians.xyz[body_facial_idx][corr_idx], face_color_bw)
+     head_gaussians_face_only_new_opacities = blend_color(
+         head_gaussians_face_only.opacities, body_gaussians.opacities[body_facial_idx][corr_idx], face_color_bw)
+     head_gaussians_face_only_new_scales = blend_color(
+         head_gaussians_face_only.scales, body_gaussians.scales[body_facial_idx][corr_idx], face_color_bw)
+
+     head_gaussians_face_only = update_gaussian_attributes(
+         head_gaussians_face_only, new_rgb=head_gaussians_face_only_new_color,
+         new_xyz=head_gaussians_face_only_new_xyz, new_opacity=head_gaussians_face_only_new_opacities,
+         new_scale=head_gaussians_face_only_new_scales)
+
+     full_gaussians = combine_gaussians([body_gaussians_wo_face, head_gaussians_face_only])
+     save_gaussians_as_ply(os.path.join(result_folder, 'full_gaussians.ply'), full_gaussians)
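
A minimal driver sketch for the stitching entry point above; every path below is a hypothetical placeholder that depends on your data layout. Note that the function is interactive: it drops into pdb so the cropped facial point sets can be edited by hand, and asks on the console whether the alignment is acceptable.

    from render_utils.stitch_body_and_head import stitch_body_and_head

    stitch_body_and_head(
        ref_body_gaussian_path='./results/body/point_cloud.ply',   # hypothetical paths throughout
        ref_head_gaussian_path='./results/head/point_cloud.ply',
        ref_body_param_path='./data/ref/smplx_params.npz',         # keys: global_orient, transl, body_pose, betas
        ref_head_param_path='./data/ref/faceverse_params.npz',     # keys: pose, scale, id_coeff, exp_coeff
        smplx2faceverse_path='./data/ref/smplx_to_faceverse.npy',
        result_folder='./results/stitched')
    # Writes body_head_blending_param.npz and full_gaussians.ply into result_folder.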
render_utils/stitch_funcs.py ADDED
@@ -0,0 +1,145 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import argparse
+ import tqdm
+ import json
+ import cv2 as cv
+ import os, glob
+ import math
+
+
+ from render_utils.lib.utils.graphics_utils import focal2fov, getProjectionMatrix
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
+
+
+ def render3(
+         gaussian_vals: dict,
+         bg_color: torch.Tensor,
+         extr: torch.Tensor,
+         intr: torch.Tensor,
+         img_w: int,
+         img_h: int,
+         scaling_modifier = 1.0,
+         override_color = None,
+         compute_cov3D_python = False
+ ):
+     means3D = gaussian_vals['positions']
+     # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+     screenspace_points = torch.zeros_like(means3D, dtype = means3D.dtype, requires_grad = True, device = "cuda") + 0
+     try:
+         screenspace_points.retain_grad()
+     except:
+         pass
+     means2D = screenspace_points
+     opacity = gaussian_vals['opacity']
+
+     # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+     # scaling / rotation by the rasterizer.
+     scales = None
+     rotations = None
+     cov3D_precomp = None
+     scales = gaussian_vals['scales']
+     rotations = gaussian_vals['rotations']
+
+     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+     shs = None
+     # colors_precomp = None
+     # if override_color is None:
+     #     shs = gaussian_vals['shs']
+     # else:
+     #     colors_precomp = override_color
+     if 'colors' in gaussian_vals:
+         colors_precomp = gaussian_vals['colors']
+     else:
+         colors_precomp = None
+
+     # Set up rasterization configuration
+     FoVx = focal2fov(intr[0, 0].item(), img_w)
+     FoVy = focal2fov(intr[1, 1].item(), img_h)
+     tanfovx = math.tan(FoVx * 0.5)
+     tanfovy = math.tan(FoVy * 0.5)
+     world_view_transform = extr.transpose(1, 0).cuda()
+     projection_matrix = getProjectionMatrix(znear = 0.1, zfar = 100, fovX = FoVx, fovY = FoVy, K = intr, img_w = img_w, img_h = img_h).transpose(0, 1).cuda()
+     full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
+     camera_center = torch.linalg.inv(extr)[:3, 3]
+
+     raster_settings = GaussianRasterizationSettings(
+         image_height = img_h,
+         image_width = img_w,
+         tanfovx = tanfovx,
+         tanfovy = tanfovy,
+         bg = bg_color,
+         scale_modifier = scaling_modifier,
+         viewmatrix = world_view_transform,
+         projmatrix = full_proj_transform,
+         sh_degree = gaussian_vals['max_sh_degree'],
+         campos = camera_center,
+         prefiltered = False,
+         debug = False
+     )
+
+     rasterizer = GaussianRasterizer(raster_settings = raster_settings)
+
+     # Rasterize visible Gaussians to image, obtain their radii (on screen).
+     rendered_image, radii = rasterizer(
+         means3D = means3D,
+         means2D = means2D,
+         shs = shs,
+         colors_precomp = colors_precomp,
+         opacities = opacity,
+         scales = scales,
+         rotations = rotations,
+         cov3D_precomp = cov3D_precomp)
+
+     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+     # They will be excluded from value updates used in the splitting criteria.
+     return {
+         "render": rendered_image,
+         "viewspace_points": screenspace_points,
+         "visibility_filter": radii > 0,
+         "radii": radii
+     }
+
+
+ def blend_color(head_facial_color, body_facial_color, blend_weight):
+     blend_weight = blend_weight.reshape([len(blend_weight)] + [1]*(len(head_facial_color.shape)-1))
+     result = head_facial_color * blend_weight + body_facial_color * (1-blend_weight)
+     return result
+
+
+ @torch.no_grad()
+ def paste_back_with_linear_interp(pasteback_scale, pasteback_center, src, tgt_size):
+     pasteback_topleft = [pasteback_center[0] - src.shape[1]/2/pasteback_scale,
+                          pasteback_center[1] - src.shape[0]/2/pasteback_scale]
+
+     h, w = src.shape[0], src.shape[1]
+     grayscale = False
+     if len(src.shape) == 2:
+         src = src.reshape([h, w, 1])
+         grayscale = True
+     src = torch.from_numpy(src)
+     src = src.permute(2, 0, 1).unsqueeze(0)
+     grid = torch.meshgrid(torch.arange(0, tgt_size[0]), torch.arange(0, tgt_size[1]), indexing='xy')
+     grid = torch.stack(grid, dim = -1).float().to(src.device).unsqueeze(0)
+     grid[..., 0] = (grid[..., 0] - pasteback_topleft[0]) * pasteback_scale
+     grid[..., 1] = (grid[..., 1] - pasteback_topleft[1]) * pasteback_scale
+
+     grid[..., 0] = grid[..., 0] / (src.shape[-1] / 2.0) - 1.0
+     grid[..., 1] = grid[..., 1] / (src.shape[-2] / 2.0) - 1.0
+     out = F.grid_sample(src, grid, align_corners = True)
+     out = out[0].detach().permute(1, 2, 0).cpu().numpy()
+     if grayscale:
+         out = out[:, :, 0]
+     return out
+
+
+ def soften_blending_mask(blending_mask, valid_mask):
+     blending_mask = np.clip(blending_mask*2.0, 0.0, 1.0)
+     blending_mask = cv.erode(blending_mask, np.ones((5, 5))) * valid_mask
+     blending_mask_bk = np.copy(blending_mask)
+     blending_mask = cv.blur(blending_mask*valid_mask, (25, 25))
+     valid_mask = cv.blur(valid_mask, (25, 25))
+     blending_mask = blending_mask / (valid_mask + 1e-6) * blending_mask_bk
+     return blending_mask
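
To close, a small sketch of how the two mask helpers above compose when pasting a head render back into a full-frame body render (array shapes, the paste center, and variable names are illustrative only):

    import numpy as np
    from render_utils.stitch_funcs import paste_back_with_linear_interp, soften_blending_mask

    head_alpha = np.random.rand(512, 512).astype(np.float32)   # stand-in for a 512x512 head blending mask
    frame_alpha = paste_back_with_linear_interp(
        pasteback_scale=1.0, pasteback_center=(512, 300), src=head_alpha, tgt_size=(1024, 1024))
    valid = np.ones((1024, 1024), np.float32)                  # stand-in for body-render coverage
    soft = soften_blending_mask(frame_alpha, valid)            # erode, blur, and renormalize by coverage
    # composite = soft[..., None] * head_img + (1 - soft[..., None]) * body_img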