Spaces:
Runtime error
Runtime error
query colors from RGB image
Browse files- apps/infer.py +33 -46
apps/infer.py
CHANGED
@@ -33,6 +33,7 @@ from apps.Normal import Normal
|
|
33 |
from apps.IFGeo import IFGeo
|
34 |
from pytorch3d.ops import SubdivideMeshes
|
35 |
from lib.common.config import cfg
|
|
|
36 |
from lib.common.train_util import init_loss, load_normal_networks, load_networks
|
37 |
from lib.common.BNI import BNI
|
38 |
from lib.common.BNI_utils import save_normal_tensor
|
@@ -93,14 +94,14 @@ if __name__ == "__main__":
|
|
93 |
"vol_res": cfg.vol_res,
|
94 |
"single": args.multi,
|
95 |
}
|
96 |
-
|
97 |
if cfg.bni.use_ifnet:
|
98 |
print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
|
99 |
else:
|
100 |
print(colored("Use SMPL-X (Explicit) for completion", "green"))
|
101 |
|
102 |
dataset = TestDataset(dataset_param, device)
|
103 |
-
|
104 |
print(colored(f"Dataset Size: {len(dataset)}", "green"))
|
105 |
|
106 |
pbar = tqdm(dataset)
|
@@ -130,11 +131,7 @@ if __name__ == "__main__":
|
|
130 |
|
131 |
os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
|
132 |
|
133 |
-
in_tensor = {
|
134 |
-
"smpl_faces": data["smpl_faces"],
|
135 |
-
"image": data["img_icon"].to(device),
|
136 |
-
"mask": data["img_mask"].to(device)
|
137 |
-
}
|
138 |
|
139 |
# The optimizer and variables
|
140 |
optimed_pose = data["body_pose"].requires_grad_(True)
|
@@ -158,7 +155,7 @@ if __name__ == "__main__":
|
|
158 |
N_body, N_pose = optimed_pose.shape[:2]
|
159 |
|
160 |
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
|
161 |
-
|
162 |
if osp.exists(smpl_path):
|
163 |
|
164 |
smpl_verts_lst = []
|
@@ -183,7 +180,7 @@ if __name__ == "__main__":
|
|
183 |
|
184 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
185 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
186 |
-
|
187 |
else:
|
188 |
# smpl optimization
|
189 |
loop_smpl = tqdm(range(args.loop_smpl))
|
@@ -252,16 +249,14 @@ if __name__ == "__main__":
|
|
252 |
|
253 |
# BUG: PyTorch3D silhouette renderer generates dilated mask
|
254 |
bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
|
255 |
-
smpl_arr_fake = torch.cat(
|
256 |
-
|
257 |
-
dim=-1)
|
258 |
|
259 |
body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
|
260 |
body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
|
261 |
body_overlap_flag = body_overlap < cfg.body_overlap_thres
|
262 |
|
263 |
-
losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] +
|
264 |
-
diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
|
265 |
|
266 |
losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
|
267 |
occluded_idx = torch.where(body_overlap_flag)[0]
|
@@ -308,18 +303,15 @@ if __name__ == "__main__":
|
|
308 |
|
309 |
img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
|
310 |
torchvision.utils.save_image(
|
311 |
-
torch.cat(
|
312 |
-
data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
|
313 |
-
|
314 |
-
],
|
315 |
-
dim=3), img_crop_path)
|
316 |
|
317 |
rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
|
318 |
rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
|
319 |
|
320 |
img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
|
321 |
-
torchvision.utils.save_image(
|
322 |
-
torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
|
323 |
|
324 |
smpl_obj_lst = []
|
325 |
|
@@ -397,12 +389,7 @@ if __name__ == "__main__":
|
|
397 |
)
|
398 |
|
399 |
# BNI process
|
400 |
-
BNI_object = BNI(
|
401 |
-
dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
|
402 |
-
name=data["name"],
|
403 |
-
BNI_dict=BNI_dict,
|
404 |
-
cfg=cfg.bni,
|
405 |
-
device=device)
|
406 |
|
407 |
BNI_object.extract_surface(False)
|
408 |
|
@@ -419,16 +406,11 @@ if __name__ == "__main__":
|
|
419 |
side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
|
420 |
|
421 |
# mesh completion via IF-net
|
422 |
-
in_tensor.update(
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
occupancies = VoxelGrid.from_mesh(
|
429 |
-
side_mesh, cfg.vol_res, loc=[
|
430 |
-
0,
|
431 |
-
] * 3, scale=2.0).data.transpose(2, 1, 0)
|
432 |
occupancies = np.flip(occupancies, axis=1)
|
433 |
|
434 |
in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
|
@@ -446,10 +428,9 @@ if __name__ == "__main__":
|
|
446 |
else:
|
447 |
side_mesh = apply_vertex_mask(
|
448 |
side_mesh,
|
449 |
-
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
|
450 |
-
SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
451 |
)
|
452 |
-
|
453 |
#register side_mesh to BNI surfaces
|
454 |
side_mesh = Meshes(
|
455 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
@@ -458,7 +439,6 @@ if __name__ == "__main__":
|
|
458 |
sm = SubdivideMeshes(side_mesh)
|
459 |
side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
|
460 |
|
461 |
-
|
462 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
463 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
464 |
|
@@ -469,7 +449,6 @@ if __name__ == "__main__":
|
|
469 |
|
470 |
# export intermediate meshes
|
471 |
BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
|
472 |
-
|
473 |
full_lst = []
|
474 |
|
475 |
if "face" in cfg.bni.use_smpl:
|
@@ -479,8 +458,7 @@ if __name__ == "__main__":
|
|
479 |
face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
|
480 |
|
481 |
# remove face neighbor triangles
|
482 |
-
BNI_object.F_B_trimesh = part_removal(
|
483 |
-
BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
484 |
side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
485 |
face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
|
486 |
full_lst += [face_mesh]
|
@@ -497,8 +475,7 @@ if __name__ == "__main__":
|
|
497 |
hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
|
498 |
|
499 |
# remove hand neighbor triangles
|
500 |
-
BNI_object.F_B_trimesh = part_removal(
|
501 |
-
BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
502 |
side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
503 |
hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
|
504 |
full_lst += [hand_mesh]
|
@@ -528,6 +505,16 @@ if __name__ == "__main__":
|
|
528 |
rotate_recon_lst = dataset.render.get_image(cam_type="four")
|
529 |
per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
|
530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
531 |
# for video rendering
|
532 |
in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
|
533 |
in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())
|
|
|
33 |
from apps.IFGeo import IFGeo
|
34 |
from pytorch3d.ops import SubdivideMeshes
|
35 |
from lib.common.config import cfg
|
36 |
+
from lib.common.render import query_color
|
37 |
from lib.common.train_util import init_loss, load_normal_networks, load_networks
|
38 |
from lib.common.BNI import BNI
|
39 |
from lib.common.BNI_utils import save_normal_tensor
|
|
|
94 |
"vol_res": cfg.vol_res,
|
95 |
"single": args.multi,
|
96 |
}
|
97 |
+
|
98 |
if cfg.bni.use_ifnet:
|
99 |
print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
|
100 |
else:
|
101 |
print(colored("Use SMPL-X (Explicit) for completion", "green"))
|
102 |
|
103 |
dataset = TestDataset(dataset_param, device)
|
104 |
+
|
105 |
print(colored(f"Dataset Size: {len(dataset)}", "green"))
|
106 |
|
107 |
pbar = tqdm(dataset)
|
|
|
131 |
|
132 |
os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
|
133 |
|
134 |
+
in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["img_icon"].to(device), "mask": data["img_mask"].to(device)}
|
|
|
|
|
|
|
|
|
135 |
|
136 |
# The optimizer and variables
|
137 |
optimed_pose = data["body_pose"].requires_grad_(True)
|
|
|
155 |
N_body, N_pose = optimed_pose.shape[:2]
|
156 |
|
157 |
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
|
158 |
+
|
159 |
if osp.exists(smpl_path):
|
160 |
|
161 |
smpl_verts_lst = []
|
|
|
180 |
|
181 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
182 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
183 |
+
|
184 |
else:
|
185 |
# smpl optimization
|
186 |
loop_smpl = tqdm(range(args.loop_smpl))
|
|
|
249 |
|
250 |
# BUG: PyTorch3D silhouette renderer generates dilated mask
|
251 |
bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
|
252 |
+
smpl_arr_fake = torch.cat([in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), in_tensor["T_normal_B"][:, 0].ne(bg_value).float()],
|
253 |
+
dim=-1)
|
|
|
254 |
|
255 |
body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
|
256 |
body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
|
257 |
body_overlap_flag = body_overlap < cfg.body_overlap_thres
|
258 |
|
259 |
+
losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] + diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
|
|
|
260 |
|
261 |
losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
|
262 |
occluded_idx = torch.where(body_overlap_flag)[0]
|
|
|
303 |
|
304 |
img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
|
305 |
torchvision.utils.save_image(
|
306 |
+
torch.cat(
|
307 |
+
[data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5, (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5],
|
308 |
+
dim=3), img_crop_path)
|
|
|
|
|
309 |
|
310 |
rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
|
311 |
rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
|
312 |
|
313 |
img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
|
314 |
+
torchvision.utils.save_image(torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
|
|
|
315 |
|
316 |
smpl_obj_lst = []
|
317 |
|
|
|
389 |
)
|
390 |
|
391 |
# BNI process
|
392 |
+
BNI_object = BNI(dir_path=osp.join(args.out_dir, cfg.name, "BNI"), name=data["name"], BNI_dict=BNI_dict, cfg=cfg.bni, device=device)
|
|
|
|
|
|
|
|
|
|
|
393 |
|
394 |
BNI_object.extract_surface(False)
|
395 |
|
|
|
406 |
side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
|
407 |
|
408 |
# mesh completion via IF-net
|
409 |
+
in_tensor.update(dataset.depth_to_voxel({"depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B": BNI_object.B_depth.unsqueeze(0)}))
|
410 |
+
|
411 |
+
occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
|
412 |
+
0,
|
413 |
+
] * 3, scale=2.0).data.transpose(2, 1, 0)
|
|
|
|
|
|
|
|
|
|
|
414 |
occupancies = np.flip(occupancies, axis=1)
|
415 |
|
416 |
in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
|
|
|
428 |
else:
|
429 |
side_mesh = apply_vertex_mask(
|
430 |
side_mesh,
|
431 |
+
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
|
|
432 |
)
|
433 |
+
|
434 |
#register side_mesh to BNI surfaces
|
435 |
side_mesh = Meshes(
|
436 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
|
|
439 |
sm = SubdivideMeshes(side_mesh)
|
440 |
side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
|
441 |
|
|
|
442 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
443 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
444 |
|
|
|
449 |
|
450 |
# export intermediate meshes
|
451 |
BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
|
|
|
452 |
full_lst = []
|
453 |
|
454 |
if "face" in cfg.bni.use_smpl:
|
|
|
458 |
face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
|
459 |
|
460 |
# remove face neighbor triangles
|
461 |
+
BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
|
|
462 |
side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
463 |
face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
|
464 |
full_lst += [face_mesh]
|
|
|
475 |
hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
|
476 |
|
477 |
# remove hand neighbor triangles
|
478 |
+
BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
|
|
479 |
side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
480 |
hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
|
481 |
full_lst += [hand_mesh]
|
|
|
505 |
rotate_recon_lst = dataset.render.get_image(cam_type="four")
|
506 |
per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
|
507 |
|
508 |
+
# coloring the final mesh
|
509 |
+
final_colors = query_color(
|
510 |
+
torch.tensor(final_mesh.vertices).float(),
|
511 |
+
torch.tensor(final_mesh.faces).long(),
|
512 |
+
in_tensor["image"][idx:idx + 1],
|
513 |
+
device=device,
|
514 |
+
)
|
515 |
+
final_mesh.visual.vertex_colors = final_colors
|
516 |
+
final_mesh.export(final_path)
|
517 |
+
|
518 |
# for video rendering
|
519 |
in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
|
520 |
in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())
|