Spaces:
Runtime error
Runtime error
optimize rot_6d instead of rot_mat
Browse files- apps/infer.py +11 -6
- lib/common/train_util.py +2 -0
- lib/dataset/TestDataset.py +5 -0
- lib/dataset/mesh_util.py +16 -0
apps/infer.py
CHANGED
@@ -26,6 +26,7 @@ from lib.dataset.mesh_util import (
|
|
26 |
unwrap,
|
27 |
remesh,
|
28 |
tensor2variable,
|
|
|
29 |
)
|
30 |
|
31 |
from lib.dataset.TestDataset import TestDataset
|
@@ -165,12 +166,16 @@ def generate_model(in_path, model_type):
|
|
165 |
for _ in loop_smpl:
|
166 |
|
167 |
optimizer_smpl.zero_grad()
|
|
|
|
|
|
|
|
|
168 |
|
169 |
if dataset_param["hps_type"] != "pixie":
|
170 |
smpl_out = dataset.smpl_model(
|
171 |
betas=optimed_betas,
|
172 |
-
body_pose=
|
173 |
-
global_orient=
|
174 |
pose2rot=False,
|
175 |
)
|
176 |
|
@@ -180,8 +185,8 @@ def generate_model(in_path, model_type):
|
|
180 |
smpl_verts, _, _ = dataset.smpl_model(
|
181 |
shape_params=optimed_betas,
|
182 |
expression_params=tensor2variable(data["exp"], device),
|
183 |
-
body_pose=
|
184 |
-
global_pose=
|
185 |
jaw_pose=tensor2variable(data["jaw_pose"], device),
|
186 |
left_hand_pose=tensor2variable(
|
187 |
data["left_hand_pose"], device),
|
@@ -316,8 +321,8 @@ def generate_model(in_path, model_type):
|
|
316 |
f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")
|
317 |
|
318 |
smpl_info = {'betas': optimed_betas,
|
319 |
-
'pose':
|
320 |
-
'orient':
|
321 |
'trans': optimed_trans}
|
322 |
|
323 |
np.save(
|
|
|
26 |
unwrap,
|
27 |
remesh,
|
28 |
tensor2variable,
|
29 |
+
rot6d_to_rotmat
|
30 |
)
|
31 |
|
32 |
from lib.dataset.TestDataset import TestDataset
|
|
|
166 |
for _ in loop_smpl:
|
167 |
|
168 |
optimizer_smpl.zero_grad()
|
169 |
+
|
170 |
+
# 6d_rot to rot_mat
|
171 |
+
optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1,6)).unsqueeze(0)
|
172 |
+
optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1,6)).unsqueeze(0)
|
173 |
|
174 |
if dataset_param["hps_type"] != "pixie":
|
175 |
smpl_out = dataset.smpl_model(
|
176 |
betas=optimed_betas,
|
177 |
+
body_pose=optimed_pose_mat,
|
178 |
+
global_orient=optimed_orient_mat,
|
179 |
pose2rot=False,
|
180 |
)
|
181 |
|
|
|
185 |
smpl_verts, _, _ = dataset.smpl_model(
|
186 |
shape_params=optimed_betas,
|
187 |
expression_params=tensor2variable(data["exp"], device),
|
188 |
+
body_pose=optimed_pose_mat,
|
189 |
+
global_pose=optimed_orient_mat,
|
190 |
jaw_pose=tensor2variable(data["jaw_pose"], device),
|
191 |
left_hand_pose=tensor2variable(
|
192 |
data["left_hand_pose"], device),
|
|
|
321 |
f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")
|
322 |
|
323 |
smpl_info = {'betas': optimed_betas,
|
324 |
+
'pose': optimed_pose_mat,
|
325 |
+
'orient': optimed_orient_mat,
|
326 |
'trans': optimed_trans}
|
327 |
|
328 |
np.save(
|
lib/common/train_util.py
CHANGED
@@ -32,6 +32,8 @@ import os
|
|
32 |
from termcolor import colored
|
33 |
|
34 |
|
|
|
|
|
35 |
def reshape_sample_tensor(sample_tensor, num_views):
|
36 |
if num_views == 1:
|
37 |
return sample_tensor
|
|
|
32 |
from termcolor import colored
|
33 |
|
34 |
|
35 |
+
|
36 |
+
|
37 |
def reshape_sample_tensor(sample_tensor, num_views):
|
38 |
if num_views == 1:
|
39 |
return sample_tensor
|
lib/dataset/TestDataset.py
CHANGED
@@ -240,6 +240,11 @@ class TestDataset():
|
|
240 |
# body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
|
241 |
# global_orient - [1, 1, 3, 3]
|
242 |
# smpl_verts - [1, 6890, 3] / [1, 10475, 3]
|
|
|
|
|
|
|
|
|
|
|
243 |
|
244 |
return data_dict
|
245 |
|
|
|
240 |
# body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
|
241 |
# global_orient - [1, 1, 3, 3]
|
242 |
# smpl_verts - [1, 6890, 3] / [1, 10475, 3]
|
243 |
+
|
244 |
+
# from rot_mat to rot_6d for better optimization
|
245 |
+
N_body = data_dict["body_pose"].shape[1]
|
246 |
+
data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body,-1)
|
247 |
+
data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1,-1)
|
248 |
|
249 |
return data_dict
|
250 |
|
lib/dataset/mesh_util.py
CHANGED
@@ -44,6 +44,22 @@ from pytorch3d.loss import (
|
|
44 |
|
45 |
from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
def tensor2variable(tensor, device):
|
49 |
# [1,23,3,3]
|
|
|
44 |
|
45 |
from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
|
46 |
|
47 |
+
def rot6d_to_rotmat(x):
|
48 |
+
"""Convert 6D rotation representation to 3x3 rotation matrix.
|
49 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
50 |
+
Input:
|
51 |
+
(B,6) Batch of 6-D rotation representations
|
52 |
+
Output:
|
53 |
+
(B,3,3) Batch of corresponding rotation matrices
|
54 |
+
"""
|
55 |
+
x = x.view(-1, 3, 2)
|
56 |
+
a1 = x[:, :, 0]
|
57 |
+
a2 = x[:, :, 1]
|
58 |
+
b1 = F.normalize(a1)
|
59 |
+
b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
|
60 |
+
b3 = torch.cross(b1, b2)
|
61 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
62 |
+
|
63 |
|
64 |
def tensor2variable(tensor, device):
|
65 |
# [1,23,3,3]
|