diff --git a/README.md b/README.md
index 997b626d8f415f03e66d10c10f815290af50b344..70b48651eb55091f6c13dcae46cc44356ee8b939 100644
--- a/README.md
+++ b/README.md
@@ -103,20 +103,23 @@ python -m apps.avatarizer -n {filename}
### Some adjustable parameters in _config/econ.yaml_
-- `use_ifnet: True`
- - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality)
- - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed)
+- `use_ifnet: False`
+ - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality, **~2min / img**)
+ - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed, **~1.5min / img**)
- `use_smpl: ["hand", "face"]`
- [ ]: don't use either hands or face parts from SMPL-X
- ["hand"]: only use the **visible** hands from SMPL-X
- ["hand", "face"]: use both **visible** hands and face from SMPL-X
- `thickness: 2cm`
- could be increased accordingly in case final reconstruction **xx_full.obj** looks flat
+- `k: 4`
+ - could be reduced accordingly in case the surface of **xx_full.obj** has discontinuous artifacts
- `hps_type: PIXIE`
- "pixie": more accurate for face and hands
- "pymafx": more robust for challenging poses
-- `k: 4`
- - could be reduced accordingly in case the surface of **xx_full.obj** has discontinous artifacts
+- `texture_src: image`
+ - "image": direct mapping the aligned pixels to final mesh
+ - "SD": use Stable Diffusion to generate full texture (TODO)
@@ -160,7 +163,6 @@ Here are some great resources we benefit from:
- [BiNI](https://github.com/hoshino042/bilateral_normal_integration) for Bilateral Normal Integration
- [MonoPortDataset](https://github.com/Project-Splinter/MonoPortDataset) for Data Processing, [MonoPort](https://github.com/Project-Splinter/MonoPort) for fast implicit surface query
- [rembg](https://github.com/danielgatis/rembg) for Human Segmentation
-- [pypoisson](https://github.com/mmolero/pypoisson) for poisson reconstruction
- [MediaPipe](https://google.github.io/mediapipe/getting_started/python.html) for full-body landmark estimation
- [PyTorch-NICP](https://github.com/wuhaozhe/pytorch-nicp) for non-rigid registration
- [smplx](https://github.com/vchoutas/smplx), [PyMAF-X](https://www.liuyebin.com/pymaf-x/), [PIXIE](https://github.com/YadiraF/PIXIE) for Human Pose & Shape Estimation
diff --git a/apps/IFGeo.py b/apps/IFGeo.py
index 966462dd93c3aedd84567b392f0ff9f876321bfc..8cb033d8d3fbd597ac80526d3d0c691451975685 100644
--- a/apps/IFGeo.py
+++ b/apps/IFGeo.py
@@ -24,7 +24,6 @@ torch.backends.cudnn.benchmark = True
class IFGeo(pl.LightningModule):
-
def __init__(self, cfg):
super(IFGeo, self).__init__()
@@ -44,14 +43,15 @@ class IFGeo(pl.LightningModule):
from lib.net.IFGeoNet_nobody import IFGeoNet
self.netG = IFGeoNet(cfg)
-
- self.resolutions = (np.logspace(
- start=5,
- stop=np.log2(self.mcube_res),
- base=2,
- num=int(np.log2(self.mcube_res) - 4),
- endpoint=True,
- ) + 1.0)
+ self.resolutions = (
+ np.logspace(
+ start=5,
+ stop=np.log2(self.mcube_res),
+ base=2,
+ num=int(np.log2(self.mcube_res) - 4),
+ endpoint=True,
+ ) + 1.0
+ )
self.resolutions = self.resolutions.astype(np.int16).tolist()
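
The reformatted `np.logspace` expression above defines the coarse-to-fine marching-cubes resolution schedule. A small worked sketch of the arithmetic, assuming `mcube_res = 512` as set by the `cfg_show_list` in `apps/infer.py`:

```python
# Worked example of the resolution schedule, assuming mcube_res = 512.
import numpy as np

mcube_res = 512
resolutions = np.logspace(
    start=5,                          # coarsest level: 2^5
    stop=np.log2(mcube_res),          # finest level: 2^9 = 512
    base=2,
    num=int(np.log2(mcube_res) - 4),  # 5 levels in total
    endpoint=True,
) + 1.0
print(resolutions.astype(np.int16).tolist())  # [33, 65, 129, 257, 513]
```
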
@@ -82,9 +82,9 @@ class IFGeo(pl.LightningModule):
if self.cfg.optim == "Adadelta":
- optimizer_G = torch.optim.Adadelta(optim_params_G,
- lr=self.lr_G,
- weight_decay=weight_decay)
+ optimizer_G = torch.optim.Adadelta(
+ optim_params_G, lr=self.lr_G, weight_decay=weight_decay
+ )
elif self.cfg.optim == "Adam":
@@ -103,20 +103,14 @@ class IFGeo(pl.LightningModule):
raise NotImplementedError
# set scheduler
- scheduler_G = torch.optim.lr_scheduler.MultiStepLR(optimizer_G,
- milestones=self.cfg.schedule,
- gamma=self.cfg.gamma)
+ scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
return [optimizer_G], [scheduler_G]
def training_step(self, batch, batch_idx):
- # cfg log
- if self.cfg.devices == 1:
- if not self.cfg.fast_dev and self.global_step == 0:
- export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg)
- self.logger.experiment.config.update(convert_to_dict(self.cfg))
-
self.netG.train()
preds_G = self.netG(batch)
@@ -127,12 +121,9 @@ class IFGeo(pl.LightningModule):
"loss": error_G,
}
- self.log_dict(metrics_log,
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=False,
- sync_dist=True)
+ self.log_dict(
+ metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
+ )
return metrics_log
@@ -143,12 +134,14 @@ class IFGeo(pl.LightningModule):
"train/avgloss": batch_mean(outputs, "loss"),
}
- self.log_dict(metrics_log,
- prog_bar=False,
- logger=True,
- on_step=False,
- on_epoch=True,
- rank_zero_only=True)
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
def validation_step(self, batch, batch_idx):
@@ -162,12 +155,9 @@ class IFGeo(pl.LightningModule):
"val/loss": error_G,
}
- self.log_dict(metrics_log,
- prog_bar=True,
- logger=False,
- on_step=True,
- on_epoch=False,
- sync_dist=True)
+ self.log_dict(
+ metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True
+ )
return metrics_log
@@ -178,9 +168,11 @@ class IFGeo(pl.LightningModule):
"val/avgloss": batch_mean(outputs, "val/loss"),
}
- self.log_dict(metrics_log,
- prog_bar=False,
- logger=True,
- on_step=False,
- on_epoch=True,
- rank_zero_only=True)
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
diff --git a/apps/Normal.py b/apps/Normal.py
index a57df041fe1523a04a9ba9e58e0e88e43023ca1b..235c0aef05914ef040f1495843d339c758ebd9f3 100644
--- a/apps/Normal.py
+++ b/apps/Normal.py
@@ -1,14 +1,12 @@
from lib.net import NormalNet
-from lib.common.train_util import convert_to_dict, export_cfg, batch_mean
+from lib.common.train_util import batch_mean
import torch
import numpy as np
-import os.path as osp
from skimage.transform import resize
import pytorch_lightning as pl
class Normal(pl.LightningModule):
-
def __init__(self, cfg):
super(Normal, self).__init__()
self.cfg = cfg
@@ -44,19 +42,19 @@ class Normal(pl.LightningModule):
optimizer_N_F = torch.optim.Adam(optim_params_N_F, lr=self.lr_F, betas=(0.5, 0.999))
optimizer_N_B = torch.optim.Adam(optim_params_N_B, lr=self.lr_B, betas=(0.5, 0.999))
- scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_F,
- milestones=self.cfg.schedule,
- gamma=self.cfg.gamma)
+ scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
- scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_B,
- milestones=self.cfg.schedule,
- gamma=self.cfg.gamma)
+ scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
if 'gan' in self.ALL_losses:
optim_params_N_D = [{"params": self.netG.netD.parameters(), "lr": self.lr_D}]
optimizer_N_D = torch.optim.Adam(optim_params_N_D, lr=self.lr_D, betas=(0.5, 0.999))
- scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_D,
- milestones=self.cfg.schedule,
- gamma=self.cfg.gamma)
+ scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_N_D, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
self.schedulers = [scheduler_N_F, scheduler_N_B, scheduler_N_D]
optims = [optimizer_N_F, optimizer_N_B, optimizer_N_D]
@@ -77,19 +75,16 @@ class Normal(pl.LightningModule):
((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0),
(height, height),
anti_aliasing=True,
- ))
+ )
+ )
- self.logger.log_image(key=f"Normal/{dataset}/{idx if not self.overfit else 1}",
- images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)
- ])
+ self.logger.log_image(
+ key=f"Normal/{dataset}/{idx if not self.overfit else 1}",
+ images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)]
+ )
def training_step(self, batch, batch_idx):
- # cfg log
- if not self.cfg.fast_dev and self.global_step == 0 and self.cfg.devices == 1:
- export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg)
- self.logger.experiment.config.update(convert_to_dict(self.cfg))
-
self.netG.train()
# retrieve the data
@@ -125,7 +120,8 @@ class Normal(pl.LightningModule):
opt_B.step()
if batch_idx > 0 and batch_idx % int(
- self.cfg.freq_show_train) == 0 and self.cfg.devices == 1:
+ self.cfg.freq_show_train
+ ) == 0 and self.cfg.devices == 1:
self.netG.eval()
with torch.no_grad():
@@ -142,12 +138,9 @@ class Normal(pl.LightningModule):
for key in error_dict.keys():
metrics_log["train/loss_" + key] = error_dict[key].item()
- self.log_dict(metrics_log,
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=False,
- sync_dist=True)
+ self.log_dict(
+ metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
+ )
return metrics_log
@@ -163,12 +156,14 @@ class Normal(pl.LightningModule):
loss_name = key
metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
- self.log_dict(metrics_log,
- prog_bar=False,
- logger=True,
- on_step=False,
- on_epoch=True,
- rank_zero_only=True)
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
def validation_step(self, batch, batch_idx):
@@ -212,9 +207,11 @@ class Normal(pl.LightningModule):
[stage, loss_name] = key.split("/")
metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
- self.log_dict(metrics_log,
- prog_bar=False,
- logger=True,
- on_step=False,
- on_epoch=True,
- rank_zero_only=True)
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
diff --git a/apps/avatarizer.py b/apps/avatarizer.py
index 12c8fe781a78af9a4f586564f7b2826027b3d5f4..a601b1a60a2ee61c21936bb720f06ec87198d8d8 100644
--- a/apps/avatarizer.py
+++ b/apps/avatarizer.py
@@ -44,7 +44,8 @@ smpl_model = smplx.create(
use_pca=False,
num_betas=200,
num_expression_coeffs=50,
- ext='pkl')
+ ext='pkl'
+)
smpl_out_lst = []
@@ -62,7 +63,9 @@ for pose_type in ["t-pose", "da-pose", "pose"]:
return_full_pose=True,
return_joint_transformation=True,
return_vertex_transformation=True,
- pose_type=pose_type))
+ pose_type=pose_type
+ )
+ )
smpl_verts = smpl_out_lst[2].vertices.detach()[0]
smpl_tree = cKDTree(smpl_verts.cpu().numpy())
@@ -74,7 +77,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
econ_verts = torch.tensor(econ_obj.vertices).float()
rot_mat_t = smpl_out_lst[2].vertex_transformation.detach()[0][idx[:, 0]]
homo_coord = torch.ones_like(econ_verts)[..., :1]
- econ_cano_verts = torch.inverse(rot_mat_t) @ torch.cat([econ_verts, homo_coord], dim=1).unsqueeze(-1)
+ econ_cano_verts = torch.inverse(rot_mat_t) @ torch.cat([econ_verts, homo_coord],
+ dim=1).unsqueeze(-1)
econ_cano_verts = econ_cano_verts[:, :3, 0].cpu()
econ_cano = trimesh.Trimesh(econ_cano_verts, econ_obj.faces)
@@ -84,7 +88,9 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
econ_da = trimesh.Trimesh(econ_da_verts[:, :3, 0].cpu(), econ_obj.faces)
# da-pose for SMPL-X
- smpl_da = trimesh.Trimesh(smpl_out_lst[1].vertices.detach()[0], smpl_model.faces, maintain_orders=True, process=False)
+ smpl_da = trimesh.Trimesh(
+ smpl_out_lst[1].vertices.detach()[0], smpl_model.faces, maintain_orders=True, process=False
+ )
smpl_da.export(f"{prefix}_smpl_da.obj")
# remove hands from ECON for next registeration
@@ -97,7 +103,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
# remove SMPL-X hand and face
register_mask = ~np.isin(
np.arange(smpl_da.vertices.shape[0]),
- np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid]))
+ np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid])
+ )
register_mask *= ~smplx_container.eyeball_vertex_mask.bool().numpy()
smpl_da_body = smpl_da.copy()
smpl_da_body.update_faces(register_mask[smpl_da.faces].all(axis=1))
@@ -115,8 +122,13 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
# remove over-streched+hand faces from ECON
econ_da_body = econ_da.copy()
edge_before = np.sqrt(
- ((econ_obj.vertices[econ_cano.edges[:, 0]] - econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
- edge_after = np.sqrt(((econ_da.vertices[econ_cano.edges[:, 0]] - econ_da.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
+ ((econ_obj.vertices[econ_cano.edges[:, 0]] -
+ econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)
+ )
+ edge_after = np.sqrt(
+ ((econ_da.vertices[econ_cano.edges[:, 0]] -
+ econ_da.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)
+ )
edge_diff = edge_after / edge_before.clip(1e-2)
streched_mask = np.unique(econ_cano.edges[edge_diff > 6])
mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)
@@ -148,8 +160,9 @@ econ_J_regressor = (smpl_model.J_regressor[:, idx] * knn_weights[None]).sum(axis
econ_lbs_weights = (smpl_model.lbs_weights.T[:, idx] * knn_weights[None]).sum(axis=-1).T
num_posedirs = smpl_model.posedirs.shape[0]
-econ_posedirs = (smpl_model.posedirs.view(num_posedirs, -1, 3)[:, idx, :] *
- knn_weights[None, ..., None]).sum(axis=-2).view(num_posedirs, -1).float()
+econ_posedirs = (
+ smpl_model.posedirs.view(num_posedirs, -1, 3)[:, idx, :] * knn_weights[None, ..., None]
+).sum(axis=-2).view(num_posedirs, -1).float()
econ_J_regressor /= econ_J_regressor.sum(axis=1, keepdims=True)
econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
@@ -157,8 +170,9 @@ econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
# re-compute da-pose rot_mat for ECON
rot_mat_da = smpl_out_lst[1].vertex_transformation.detach()[0][idx[:, 0]]
econ_da_verts = torch.tensor(econ_da.vertices).float()
-econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat([econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]],
- dim=1).unsqueeze(-1)
+econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat(
+ [econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]], dim=1
+).unsqueeze(-1)
econ_cano_verts = econ_cano_verts[:, :3, 0].double()
# ----------------------------------------------------
@@ -174,7 +188,8 @@ posed_econ_verts, _ = general_lbs(
posedirs=econ_posedirs,
J_regressor=econ_J_regressor,
parents=smpl_model.parents,
- lbs_weights=econ_lbs_weights)
+ lbs_weights=econ_lbs_weights
+)
econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_da.faces)
-econ_pose.export(f"{prefix}_econ_pose.obj")
\ No newline at end of file
+econ_pose.export(f"{prefix}_econ_pose.obj")
diff --git a/apps/infer.py b/apps/infer.py
index 1e354b0e202771934a8a81e16c175489185e047a..fe88f0bbf64f1656fcf496ddce277f56384d1d20 100644
--- a/apps/infer.py
+++ b/apps/infer.py
@@ -34,7 +34,8 @@ from apps.IFGeo import IFGeo
from pytorch3d.ops import SubdivideMeshes
from lib.common.config import cfg
from lib.common.render import query_color
-from lib.common.train_util import init_loss, load_normal_networks, load_networks
+from lib.common.train_util import init_loss, Format
+from lib.common.imutils import blend_rgb_norm
from lib.common.BNI import BNI
from lib.common.BNI_utils import save_normal_tensor
from lib.dataset.TestDataset import TestDataset
@@ -68,20 +69,25 @@ if __name__ == "__main__":
device = torch.device(f"cuda:{args.gpu_device}")
# setting for testing on in-the-wild images
- cfg_show_list = ["test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True, "batch_size", 1]
+ cfg_show_list = [
+ "test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True,
+ "batch_size", 1
+ ]
cfg.merge_from_list(cfg_show_list)
cfg.freeze()
- # load model
- normal_model = Normal(cfg).to(device)
- load_normal_networks(normal_model, cfg.normal_path)
- normal_model.netG.eval()
-
- # load IFGeo model
- ifnet_model = IFGeo(cfg).to(device)
- load_networks(ifnet_model, mlp_path=cfg.ifnet_path)
- ifnet_model.netG.eval()
+ # load normal model
+ normal_net = Normal.load_from_checkpoint(
+ cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False
+ )
+ normal_net = normal_net.to(device)
+ normal_net.netG.eval()
+ print(
+ colored(
+ f"Resume Normal Estimator from {Format.start} {cfg.normal_path} {Format.end}", "green"
+ )
+ )
# SMPLX object
SMPLX_object = SMPLX()
@@ -89,16 +95,24 @@ if __name__ == "__main__":
dataset_param = {
"image_dir": args.in_dir,
"seg_dir": args.seg_dir,
- "use_seg": True, # w/ or w/o segmentation
- "hps_type": cfg.bni.hps_type, # pymafx/pixie
+ "use_seg": True, # w/ or w/o segmentation
+ "hps_type": cfg.bni.hps_type, # pymafx/pixie
"vol_res": cfg.vol_res,
"single": args.multi,
}
if cfg.bni.use_ifnet:
- print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
+ # load IFGeo model
+ ifnet = IFGeo.load_from_checkpoint(
+ cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False
+ )
+ ifnet = ifnet.to(device)
+ ifnet.netG.eval()
+
+ print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green"))
+ print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green"))
else:
- print(colored("Use SMPL-X (Explicit) for completion", "green"))
+ print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green"))
dataset = TestDataset(dataset_param, device)
@@ -125,13 +139,17 @@ if __name__ == "__main__":
# 2. SMPL params (xxx_smpl.npy)
# 3. d-BiNI surfaces (xxx_BNI.obj)
# 4. seperate face/hand mesh (xxx_hand/face.obj)
- # 5. full shape impainted by IF-Nets+, and remeshed shape (xxx_IF_(remesh).obj)
+ # 5. full shape inpainted by IF-Nets+ after remeshing (xxx_IF.obj)
# 6. sideded or occluded parts (xxx_side.obj)
# 7. final reconstructed clothed human (xxx_full.obj)
os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
- in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["img_icon"].to(device), "mask": data["img_mask"].to(device)}
+ in_tensor = {
+ "smpl_faces": data["smpl_faces"],
+ "image": data["img_icon"].to(device),
+ "mask": data["img_mask"].to(device)
+ }
# The optimizer and variables
optimed_pose = data["body_pose"].requires_grad_(True)
@@ -139,7 +157,9 @@ if __name__ == "__main__":
optimed_betas = data["betas"].requires_grad_(True)
optimed_orient = data["global_orient"].requires_grad_(True)
- optimizer_smpl = torch.optim.Adam([optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True)
+ optimizer_smpl = torch.optim.Adam(
+ [optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True
+ )
scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_smpl,
mode="min",
@@ -156,10 +176,12 @@ if __name__ == "__main__":
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
+ # remove this line if you change loop_smpl and want to obtain different SMPL-X fits
if osp.exists(smpl_path):
smpl_verts_lst = []
smpl_faces_lst = []
+
for idx in range(N_body):
smpl_obj = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"
@@ -173,10 +195,12 @@ if __name__ == "__main__":
batch_smpl_faces = torch.stack(smpl_faces_lst)
# render optimized mesh as normal [-1,1]
- in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(batch_smpl_verts, batch_smpl_faces)
+ in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
+ batch_smpl_verts, batch_smpl_faces
+ )
with torch.no_grad():
- in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor)
+ in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
@@ -194,8 +218,10 @@ if __name__ == "__main__":
N_body, N_pose = optimed_pose.shape[:2]
# 6d_rot to rot_mat
- optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).view(N_body, 1, 3, 3)
- optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).view(N_body, N_pose, 3, 3)
+ optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1,
+ 6)).view(N_body, 1, 3, 3)
+ optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1,
+ 6)).view(N_body, N_pose, 3, 3)
smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model(
shape_params=optimed_betas,
@@ -208,11 +234,16 @@ if __name__ == "__main__":
)
smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
- smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor([1.0, 1.0, -1.0]).to(device)
+ smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor(
+ [1.0, 1.0, -1.0]
+ ).to(device)
# landmark errors
- smpl_joints_3d = (smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0) * 0.5
- in_tensor["smpl_joint"] = smpl_joints[:, dataset.smpl_data.smpl_joint_ids_24_pixie, :]
+ smpl_joints_3d = (
+ smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0
+ ) * 0.5
+ in_tensor["smpl_joint"] = smpl_joints[:,
+ dataset.smpl_data.smpl_joint_ids_24_pixie, :]
ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device)
ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device)
@@ -227,7 +258,7 @@ if __name__ == "__main__":
T_mask_F, T_mask_B = dataset.render.get_image(type="mask")
with torch.no_grad():
- in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor)
+ in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)
diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])
@@ -249,25 +280,37 @@ if __name__ == "__main__":
# BUG: PyTorch3D silhouette renderer generates dilated mask
bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
- smpl_arr_fake = torch.cat([in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), in_tensor["T_normal_B"][:, 0].ne(bg_value).float()],
- dim=-1)
+ smpl_arr_fake = torch.cat(
+ [
+ in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
+ in_tensor["T_normal_B"][:, 0].ne(bg_value).float()
+ ],
+ dim=-1
+ )
- body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
+ body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)
+ ).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
body_overlap_flag = body_overlap < cfg.body_overlap_thres
- losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] + diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
+ losses["normal"]["value"] = (
+ diff_F_smpl * body_overlap_mask[..., :512] +
+ diff_B_smpl * body_overlap_mask[..., 512:]
+ ).mean() / 2.0
losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
occluded_idx = torch.where(body_overlap_flag)[0]
ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95
- losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) * ghum_conf).mean(dim=1)
+ losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) *
+ ghum_conf).mean(dim=1)
# Weighted sum of the losses
smpl_loss = 0.0
- pbar_desc = "Body Fitting --- "
+ pbar_desc = "Body Fitting -- "
for k in ["normal", "silhouette", "joint"]:
- per_loop_loss = (losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)).mean()
+ per_loop_loss = (
+ losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)
+ ).mean()
pbar_desc += f"{k}: {per_loop_loss:.3f} | "
smpl_loss += per_loop_loss
pbar_desc += f"Total: {smpl_loss:.3f}"
@@ -279,19 +322,25 @@ if __name__ == "__main__":
# save intermediate results / vis_freq and final_step
if (i % args.vis_freq == 0) or (i == args.loop_smpl - 1):
- per_loop_lst.extend([
- in_tensor["image"],
- in_tensor["T_normal_F"],
- in_tensor["normal_F"],
- diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
- ])
- per_loop_lst.extend([
- in_tensor["image"],
- in_tensor["T_normal_B"],
- in_tensor["normal_B"],
- diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
- ])
- per_data_lst.append(get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl"))
+ per_loop_lst.extend(
+ [
+ in_tensor["image"],
+ in_tensor["T_normal_F"],
+ in_tensor["normal_F"],
+ diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
+ ]
+ )
+ per_loop_lst.extend(
+ [
+ in_tensor["image"],
+ in_tensor["T_normal_B"],
+ in_tensor["normal_B"],
+ diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
+ ]
+ )
+ per_data_lst.append(
+ get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")
+ )
smpl_loss.backward()
optimizer_smpl.step()
@@ -304,14 +353,21 @@ if __name__ == "__main__":
img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
torchvision.utils.save_image(
torch.cat(
- [data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5, (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5],
- dim=3), img_crop_path)
+ [
+ data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
+ (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
+ ],
+ dim=3
+ ), img_crop_path
+ )
rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
- torchvision.utils.save_image(torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
+ torchvision.utils.save_image(
+ torch.cat([data["img_raw"], rgb_norm_F, rgb_norm_B], dim=-1) / 255., img_overlap_path
+ )
smpl_obj_lst = []
@@ -329,15 +385,28 @@ if __name__ == "__main__":
if not osp.exists(smpl_obj_path):
smpl_obj.export(smpl_obj_path)
smpl_info = {
- "betas": optimed_betas[idx].detach().cpu().unsqueeze(0),
- "body_pose": rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()).cpu().unsqueeze(0),
- "global_orient": rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()).cpu().unsqueeze(0),
- "transl": optimed_trans[idx].detach().cpu(),
- "expression": data["exp"][idx].cpu().unsqueeze(0),
- "jaw_pose": rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
- "left_hand_pose": rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]).cpu().unsqueeze(0),
- "right_hand_pose": rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]).cpu().unsqueeze(0),
- "scale": data["scale"][idx].cpu(),
+ "betas":
+ optimed_betas[idx].detach().cpu().unsqueeze(0),
+ "body_pose":
+ rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()
+ ).cpu().unsqueeze(0),
+ "global_orient":
+ rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()
+ ).cpu().unsqueeze(0),
+ "transl":
+ optimed_trans[idx].detach().cpu(),
+ "expression":
+ data["exp"][idx].cpu().unsqueeze(0),
+ "jaw_pose":
+ rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
+ "left_hand_pose":
+ rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]
+ ).cpu().unsqueeze(0),
+ "right_hand_pose":
+ rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]
+ ).cpu().unsqueeze(0),
+ "scale":
+ data["scale"][idx].cpu(),
}
np.save(
smpl_obj_path.replace(".obj", ".npy"),
@@ -359,10 +428,13 @@ if __name__ == "__main__":
per_data_lst = []
- batch_smpl_verts = in_tensor["smpl_verts"].detach() * torch.tensor([1.0, -1.0, 1.0], device=device)
+ batch_smpl_verts = in_tensor["smpl_verts"].detach(
+ ) * torch.tensor([1.0, -1.0, 1.0], device=device)
batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]
- in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(batch_smpl_verts, batch_smpl_faces)
+ in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
+ batch_smpl_verts, batch_smpl_faces
+ )
per_loop_lst = []
@@ -389,7 +461,13 @@ if __name__ == "__main__":
)
# BNI process
- BNI_object = BNI(dir_path=osp.join(args.out_dir, cfg.name, "BNI"), name=data["name"], BNI_dict=BNI_dict, cfg=cfg.bni, device=device)
+ BNI_object = BNI(
+ dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
+ name=data["name"],
+ BNI_dict=BNI_dict,
+ cfg=cfg.bni,
+ device=device
+ )
BNI_object.extract_surface(False)
@@ -406,29 +484,40 @@ if __name__ == "__main__":
side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
# mesh completion via IF-net
- in_tensor.update(dataset.depth_to_voxel({"depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B": BNI_object.B_depth.unsqueeze(0)}))
+ in_tensor.update(
+ dataset.depth_to_voxel(
+ {
+ "depth_F": BNI_object.F_depth.unsqueeze(0),
+ "depth_B": BNI_object.B_depth.unsqueeze(0)
+ }
+ )
+ )
occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
0,
] * 3, scale=2.0).data.transpose(2, 1, 0)
occupancies = np.flip(occupancies, axis=1)
- in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
+ in_tensor["body_voxels"] = torch.tensor(occupancies.copy()
+ ).float().unsqueeze(0).to(device)
with torch.no_grad():
- sdf = ifnet_model.reconEngine(netG=ifnet_model.netG, batch=in_tensor)
- verts_IF, faces_IF = ifnet_model.reconEngine.export_mesh(sdf)
+ sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor)
+ verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf)
- if ifnet_model.clean_mesh_flag:
+ if ifnet.clean_mesh_flag:
verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)
side_mesh = trimesh.Trimesh(verts_IF, faces_IF)
- side_mesh = remesh(side_mesh, side_mesh_path)
+ side_mesh = remesh_laplacian(side_mesh, side_mesh_path)
else:
side_mesh = apply_vertex_mask(
side_mesh,
- (SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + SMPLX_object.eyeball_vertex_mask).eq(0).float(),
+ (
+ SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
+ SMPLX_object.eyeball_vertex_mask
+ ).eq(0).float(),
)
#register side_mesh to BNI surfaces
@@ -448,7 +537,9 @@ if __name__ == "__main__":
# 3. remove eyeball faces
# export intermediate meshes
- BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
+ BNI_object.F_B_trimesh.export(
+ f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
+ )
full_lst = []
if "face" in cfg.bni.use_smpl:
@@ -458,37 +549,63 @@ if __name__ == "__main__":
face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
# remove face neighbor triangles
- BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
- side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
+ BNI_object.F_B_trimesh = part_removal(
+ BNI_object.F_B_trimesh,
+ face_mesh,
+ cfg.bni.face_thres,
+ device,
+ smplx_mesh,
+ region="face"
+ )
+ side_mesh = part_removal(
+ side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face"
+ )
face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
full_lst += [face_mesh]
if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]):
- hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0],)
+ hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], )
if data['hands_visibility'][idx][0]:
- hand_mask.index_fill_(0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0)
+ hand_mask.index_fill_(
+ 0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0
+ )
if data['hands_visibility'][idx][1]:
- hand_mask.index_fill_(0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0)
+ hand_mask.index_fill_(
+ 0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0
+ )
# only hands
hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
# remove hand neighbor triangles
- BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
- side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
+ BNI_object.F_B_trimesh = part_removal(
+ BNI_object.F_B_trimesh,
+ hand_mesh,
+ cfg.bni.hand_thres,
+ device,
+ smplx_mesh,
+ region="hand"
+ )
+ side_mesh = part_removal(
+ side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
+ )
hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
full_lst += [hand_mesh]
full_lst += [BNI_object.F_B_trimesh]
# initial side_mesh could be SMPLX or IF-net
- side_mesh = part_removal(side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False)
+ side_mesh = part_removal(
+ side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
+ )
full_lst += [side_mesh]
# # export intermediate meshes
- BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
+ BNI_object.F_B_trimesh.export(
+ f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
+ )
side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj")
if cfg.bni.use_poisson:
@@ -505,15 +622,22 @@ if __name__ == "__main__":
rotate_recon_lst = dataset.render.get_image(cam_type="four")
per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
- # coloring the final mesh
- final_colors = query_color(
- torch.tensor(final_mesh.vertices).float(),
- torch.tensor(final_mesh.faces).long(),
- in_tensor["image"][idx:idx + 1],
- device=device,
- )
- final_mesh.visual.vertex_colors = final_colors
- final_mesh.export(final_path)
+ if cfg.bni.texture_src == 'image':
+
+ # coloring the final mesh (front: RGB pixels, back: normal colors)
+ final_colors = query_color(
+ torch.tensor(final_mesh.vertices).float(),
+ torch.tensor(final_mesh.faces).long(),
+ in_tensor["image"][idx:idx + 1],
+ device=device,
+ )
+ final_mesh.visual.vertex_colors = final_colors
+ final_mesh.export(final_path)
+
+ elif cfg.bni.texture_src == 'SD':
+
+ # !TODO: add texture from Stable Diffusion
+ pass
# for video rendering
in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
diff --git a/apps/multi_render.py b/apps/multi_render.py
index 933029cae4f98c2bc6400431dc3eb828701158ef..4088440757ce81137aaad7685d9df4b53b1c1383 100644
--- a/apps/multi_render.py
+++ b/apps/multi_render.py
@@ -20,6 +20,4 @@ faces_lst = in_tensor["body_faces"] + in_tensor["BNI_faces"]
# self-rotated video
render.load_meshes(verts_lst, faces_lst)
-render.get_rendered_video_multi(
- in_tensor,
- f"{root}/{args.name}_cloth.mp4")
\ No newline at end of file
+render.get_rendered_video_multi(in_tensor, f"{root}/{args.name}_cloth.mp4")
diff --git a/configs/econ.yaml b/configs/econ.yaml
index 548b683af89750e58ce8536f19f99170aeb5e693..ec35721b8ac9cc25fb31a9bc7836d44f3373aeb4 100644
--- a/configs/econ.yaml
+++ b/configs/econ.yaml
@@ -35,3 +35,4 @@ bni:
face_thres: 6e-2
thickness: 0.02
hps_type: "pixie"
+ texture_src: "image"
diff --git a/docs/installation.md b/docs/installation.md
index e53298dec5f0e1cc89b2433f63ed23b9d97606dc..df52326db113c77165fdd254dd3e75aeb5f8a2f3 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -9,12 +9,11 @@ cd ECON
## Environment
-- Ubuntu 20 / 18
-- GCC=7 (required by [pypoisson](https://github.com/mmolero/pypoisson/issues/13))
+- Ubuntu 20 / 18 (Windows also works, see [issue#7](https://github.com/YuliangXiu/ECON/issues/7))
- **CUDA=11.4, GPU Memory > 12GB**
- Python = 3.8
- PyTorch >= 1.13.0 (official [Get Started](https://pytorch.org/get-started/locally/))
-- CUPY >= 11.3.0 (offcial [Installation](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi))
+- CuPy >= 11.3.0 (official [Installation](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi))
- PyTorch3D (official [INSTALL.md](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), recommend [install-from-local-clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone))
```bash
diff --git a/lib/common/BNI.py b/lib/common/BNI.py
index df26ffb077682e8525f0bcf727520d181801a24a..1b65777913c2a3573848842f65dd266f213c7f87 100644
--- a/lib/common/BNI.py
+++ b/lib/common/BNI.py
@@ -1,12 +1,12 @@
-from lib.common.BNI_utils import (verts_inverse_transform, depth_inverse_transform,
- double_side_bilateral_normal_integration)
+from lib.common.BNI_utils import (
+ verts_inverse_transform, depth_inverse_transform, double_side_bilateral_normal_integration
+)
import torch
import trimesh
class BNI:
-
def __init__(self, dir_path, name, BNI_dict, cfg, device):
self.scale = 256.0
@@ -64,22 +64,20 @@ class BNI:
F_B_verts = torch.cat((F_verts, B_verts), dim=0)
F_B_faces = torch.cat(
- (bni_result["F_faces"], bni_result["B_faces"] + bni_result["F_faces"].max() + 1), dim=0)
+ (bni_result["F_faces"], bni_result["B_faces"] + bni_result["F_faces"].max() + 1), dim=0
+ )
- self.F_B_trimesh = trimesh.Trimesh(F_B_verts.float(),
- F_B_faces.long(),
- process=False,
- maintain_order=True)
+ self.F_B_trimesh = trimesh.Trimesh(
+ F_B_verts.float(), F_B_faces.long(), process=False, maintain_order=True
+ )
- self.F_trimesh = trimesh.Trimesh(F_verts.float(),
- bni_result["F_faces"].long(),
- process=False,
- maintain_order=True)
+ self.F_trimesh = trimesh.Trimesh(
+ F_verts.float(), bni_result["F_faces"].long(), process=False, maintain_order=True
+ )
- self.B_trimesh = trimesh.Trimesh(B_verts.float(),
- bni_result["B_faces"].long(),
- process=False,
- maintain_order=True)
+ self.B_trimesh = trimesh.Trimesh(
+ B_verts.float(), bni_result["B_faces"].long(), process=False, maintain_order=True
+ )
if __name__ == "__main__":
@@ -93,16 +91,18 @@ if __name__ == "__main__":
bni_dict = np.load(npy_file, allow_pickle=True).item()
default_cfg = {'k': 2, 'lambda1': 1e-4, 'boundary_consist': 1e-6}
-
+
# for k in [1, 2, 4, 10, 100]:
# default_cfg['k'] = k
# for k in [1e-8, 1e-4, 1e-2, 1e-1, 1]:
- # default_cfg['lambda1'] = k
+ # default_cfg['lambda1'] = k
# for k in [1e-4, 1e-2, 0]:
- # default_cfg['boundary_consist'] = k
-
- bni_object = BNI(osp.dirname(npy_file), osp.basename(npy_file), bni_dict, default_cfg,
- torch.device('cuda:0'))
+ # default_cfg['boundary_consist'] = k
+
+ bni_object = BNI(
+ osp.dirname(npy_file), osp.basename(npy_file), bni_dict, default_cfg,
+ torch.device('cuda:0')
+ )
bni_object.extract_surface()
bni_object.F_trimesh.export(osp.join(osp.dirname(npy_file), "F.obj"))
diff --git a/lib/common/BNI_utils.py b/lib/common/BNI_utils.py
index b5f5f873dbdfbaf9b40c268d481a6368824edbec..57deed1d9e7da0379a3017e4f0f44c1bf34314b9 100644
--- a/lib/common/BNI_utils.py
+++ b/lib/common/BNI_utils.py
@@ -53,8 +53,9 @@ def find_contour(mask, method='all'):
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
else:
- contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE,
- cv2.CHAIN_APPROX_SIMPLE)
+ contours, _ = cv2.findContours(
+ mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+ )
contour_cloth = np.array(find_max_list(contours))[:, 0, :]
@@ -67,16 +68,19 @@ def mean_value_cordinates(inner_pts, contour_pts):
body_edges_c = np.roll(body_edges_a, shift=-1, axis=1)
body_edges_b = np.sqrt(((contour_pts - np.roll(contour_pts, shift=-1, axis=0))**2).sum(axis=1))
- body_edges = np.concatenate([
- body_edges_a[..., None], body_edges_c[..., None],
- np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts))
- ],
- axis=-1)
+ body_edges = np.concatenate(
+ [
+ body_edges_a[..., None], body_edges_c[..., None],
+ np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts))
+ ],
+ axis=-1
+ )
body_cos = (body_edges[:, :, 0]**2 + body_edges[:, :, 1]**2 -
body_edges[:, :, 2]**2) / (2 * body_edges[:, :, 0] * body_edges[:, :, 1])
body_tan_half = np.sqrt(
- (1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.))
+ (1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.)
+ )
w = (body_tan_half + np.roll(body_tan_half, shift=1, axis=1)) / body_edges_a
w /= w.sum(axis=1, keepdims=True)
@@ -97,16 +101,18 @@ def dispCorres(img_size, contour1, contour2, phi, dir_path):
contour2 = contour2[None, :, None, :].astype(np.int32)
disp = np.zeros((img_size, img_size, 3), dtype=np.uint8)
- cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1) # green
- cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1) # blue
+ cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1) # green
+ cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1) # blue
- for i in range(contour1.shape[1]): # do not show all the points when display
+ for i in range(contour1.shape[1]): # do not show all the points when display
# cv2.circle(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), 1,
# (255, 0, 0), -1)
corresPoint = contour2[0, phi[i], 0]
# cv2.circle(disp, (corresPoint[0], corresPoint[1]), 1, (0, 255, 0), -1)
- cv2.line(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]),
- (corresPoint[0], corresPoint[1]), (255, 255, 255), 1)
+ cv2.line(
+ disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), (corresPoint[0], corresPoint[1]),
+ (255, 255, 255), 1
+ )
cv2.imwrite(osp.join(dir_path, "corres.png"), disp)
@@ -162,7 +168,8 @@ def verts_transform(t, depth_scale):
t_copy *= depth_scale * 0.5
t_copy += depth_scale * 0.5
t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor(
- [0.0, 0.0, depth_scale])
+ [0.0, 0.0, depth_scale]
+ )
return t_copy
@@ -328,19 +335,22 @@ def construct_facets_from(mask):
facet_move_top_mask = move_top(mask)
facet_move_left_mask = move_left(mask)
facet_move_top_left_mask = move_top_left(mask)
- facet_top_left_mask = (facet_move_top_mask * facet_move_left_mask * facet_move_top_left_mask *
- mask)
+ facet_top_left_mask = (
+ facet_move_top_mask * facet_move_left_mask * facet_move_top_left_mask * mask
+ )
facet_top_right_mask = move_right(facet_top_left_mask)
facet_bottom_left_mask = move_bottom(facet_top_left_mask)
facet_bottom_right_mask = move_bottom_right(facet_top_left_mask)
- return cp.hstack((
- 4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)),
- idx[facet_top_left_mask][:, None],
- idx[facet_bottom_left_mask][:, None],
- idx[facet_bottom_right_mask][:, None],
- idx[facet_top_right_mask][:, None],
- )).astype(int)
+ return cp.hstack(
+ (
+ 4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)),
+ idx[facet_top_left_mask][:, None],
+ idx[facet_bottom_left_mask][:, None],
+ idx[facet_bottom_right_mask][:, None],
+ idx[facet_top_right_mask][:, None],
+ )
+ ).astype(int)
def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
@@ -364,8 +374,8 @@ def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
u[..., 0] = xx
u[..., 1] = yy
u[..., 2] = 1
- u = u[mask].T # 3 x m
- vertices = (cp.linalg.inv(K) @ u).T * depth_map[mask, cp.newaxis] # m x 3
+ u = u[mask].T # 3 x m
+ vertices = (cp.linalg.inv(K) @ u).T * depth_map[mask, cp.newaxis] # m x 3
return vertices
@@ -374,7 +384,6 @@ def sigmoid(x, k=1):
return 1 / (1 + cp.exp(-k * x))
-
def boundary_excluded_mask(mask):
top_mask = cp.pad(mask, ((1, 0), (0, 0)), "constant", constant_values=0)[:-1, :]
bottom_mask = cp.pad(mask, ((0, 1), (0, 0)), "constant", constant_values=0)[1:, :]
@@ -410,22 +419,24 @@ def create_boundary_matrix(mask):
return B, B_full
-def double_side_bilateral_normal_integration(normal_front,
- normal_back,
- normal_mask,
- depth_front=None,
- depth_back=None,
- depth_mask=None,
- k=2,
- lambda_normal_back=1,
- lambda_depth_front=1e-4,
- lambda_depth_back=1e-2,
- lambda_boundary_consistency=1,
- step_size=1,
- max_iter=150,
- tol=1e-4,
- cg_max_iter=5000,
- cg_tol=1e-3):
+def double_side_bilateral_normal_integration(
+ normal_front,
+ normal_back,
+ normal_mask,
+ depth_front=None,
+ depth_back=None,
+ depth_mask=None,
+ k=2,
+ lambda_normal_back=1,
+ lambda_depth_front=1e-4,
+ lambda_depth_back=1e-2,
+ lambda_boundary_consistency=1,
+ step_size=1,
+ max_iter=150,
+ tol=1e-4,
+ cg_max_iter=5000,
+ cg_tol=1e-3
+):
# To avoid confusion, we list the coordinate systems in this code as follows
#
@@ -467,14 +478,12 @@ def double_side_bilateral_normal_integration(normal_front,
del normal_map_back
# right, left, top, bottom
- A3_f, A4_f, A1_f, A2_f = generate_dx_dy(normal_mask,
- nz_horizontal=nz_front,
- nz_vertical=nz_front,
- step_size=step_size)
- A3_b, A4_b, A1_b, A2_b = generate_dx_dy(normal_mask,
- nz_horizontal=nz_back,
- nz_vertical=nz_back,
- step_size=step_size)
+ A3_f, A4_f, A1_f, A2_f = generate_dx_dy(
+ normal_mask, nz_horizontal=nz_front, nz_vertical=nz_front, step_size=step_size
+ )
+ A3_b, A4_b, A1_b, A2_b = generate_dx_dy(
+ normal_mask, nz_horizontal=nz_back, nz_vertical=nz_back, step_size=step_size
+ )
has_left_mask = cp.logical_and(move_right(normal_mask), normal_mask)
has_right_mask = cp.logical_and(move_left(normal_mask), normal_mask)
@@ -498,29 +507,25 @@ def double_side_bilateral_normal_integration(normal_front,
b_back = cp.concatenate((-nx_back, -nx_back, -ny_back, -ny_back))
# initialization
- W_front = spdiags(0.5 * cp.ones(4 * num_normals),
- 0,
- 4 * num_normals,
- 4 * num_normals,
- format="csr")
- W_back = spdiags(0.5 * cp.ones(4 * num_normals),
- 0,
- 4 * num_normals,
- 4 * num_normals,
- format="csr")
+ W_front = spdiags(
+ 0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr"
+ )
+ W_back = spdiags(
+ 0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr"
+ )
z_front = cp.zeros(num_normals, float)
z_back = cp.zeros(num_normals, float)
z_combined = cp.concatenate((z_front, z_back))
B, B_full = create_boundary_matrix(normal_mask)
- B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get()) #bug
+ B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get()) #bug
energy_list = []
if depth_mask is not None:
- depth_mask_flat = depth_mask[normal_mask].astype(bool) # shape: (num_normals,)
- z_prior_front = depth_map_front[normal_mask] # shape: (num_normals,)
+ depth_mask_flat = depth_mask[normal_mask].astype(bool) # shape: (num_normals,)
+ z_prior_front = depth_map_front[normal_mask] # shape: (num_normals,)
z_prior_front[~depth_mask_flat] = 0
z_prior_back = depth_map_back[normal_mask]
z_prior_back[~depth_mask_flat] = 0
@@ -554,40 +559,43 @@ def double_side_bilateral_normal_integration(normal_front,
vstack((csr_matrix((num_normals, num_normals)), A_mat_back))]) + B_mat
b_vec_combined = cp.concatenate((b_vec_front, b_vec_back))
- D = spdiags(1 / cp.clip(A_mat_combined.diagonal(), 1e-5, None), 0, 2 * num_normals,
- 2 * num_normals, "csr") # Jacob preconditioner
+ D = spdiags(
+ 1 / cp.clip(A_mat_combined.diagonal(), 1e-5, None), 0, 2 * num_normals, 2 * num_normals,
+ "csr"
+ ) # Jacobi preconditioner
- z_combined, _ = cg(A_mat_combined,
- b_vec_combined,
- M=D,
- x0=z_combined,
- maxiter=cg_max_iter,
- tol=cg_tol)
+ z_combined, _ = cg(
+ A_mat_combined, b_vec_combined, M=D, x0=z_combined, maxiter=cg_max_iter, tol=cg_tol
+ )
z_front = z_combined[:num_normals]
z_back = z_combined[num_normals:]
- wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k) # top
- wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k) # right
+ wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k) # top
+ wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k) # right
wu_f[top_boundnary_mask] = 0.5
wu_f[bottom_boundary_mask] = 0.5
wv_f[left_boundary_mask] = 0.5
wv_f[right_boudnary_mask] = 0.5
- W_front = spdiags(cp.concatenate((wu_f, 1 - wu_f, wv_f, 1 - wv_f)),
- 0,
- 4 * num_normals,
- 4 * num_normals,
- format="csr")
-
- wu_b = sigmoid((A2_b.dot(z_back))**2 - (A1_b.dot(z_back))**2, k) # top
- wv_b = sigmoid((A4_b.dot(z_back))**2 - (A3_b.dot(z_back))**2, k) # right
+ W_front = spdiags(
+ cp.concatenate((wu_f, 1 - wu_f, wv_f, 1 - wv_f)),
+ 0,
+ 4 * num_normals,
+ 4 * num_normals,
+ format="csr"
+ )
+
+ wu_b = sigmoid((A2_b.dot(z_back))**2 - (A1_b.dot(z_back))**2, k) # top
+ wv_b = sigmoid((A4_b.dot(z_back))**2 - (A3_b.dot(z_back))**2, k) # right
wu_b[top_boundnary_mask] = 0.5
wu_b[bottom_boundary_mask] = 0.5
wv_b[left_boundary_mask] = 0.5
wv_b[right_boudnary_mask] = 0.5
- W_back = spdiags(cp.concatenate((wu_b, 1 - wu_b, wv_b, 1 - wv_b)),
- 0,
- 4 * num_normals,
- 4 * num_normals,
- format="csr")
+ W_back = spdiags(
+ cp.concatenate((wu_b, 1 - wu_b, wv_b, 1 - wv_b)),
+ 0,
+ 4 * num_normals,
+ 4 * num_normals,
+ format="csr"
+ )
energy_old = energy
energy = (A_front_data @ z_front - b_front).T @ W_front @ (A_front_data @ z_front - b_front) + \
@@ -603,23 +611,26 @@ def double_side_bilateral_normal_integration(normal_front,
if relative_energy < tol:
break
# del A1, A2, A3, A4, nx, ny
-
+
depth_map_front_est = cp.ones_like(normal_mask, float) * cp.nan
depth_map_front_est[normal_mask] = z_front
depth_map_back_est = cp.ones_like(normal_mask, float) * cp.nan
depth_map_back_est[normal_mask] = z_back
-
+
# manually cut the intersection
- normal_mask[depth_map_front_est>=depth_map_back_est] = False
+ normal_mask[depth_map_front_est >= depth_map_back_est] = False
depth_map_front_est[~normal_mask] = cp.nan
depth_map_back_est[~normal_mask] = cp.nan
vertices_front = cp.asnumpy(
- map_depth_map_to_point_clouds(depth_map_front_est, normal_mask, K=None,
- step_size=step_size))
+ map_depth_map_to_point_clouds(
+ depth_map_front_est, normal_mask, K=None, step_size=step_size
+ )
+ )
vertices_back = cp.asnumpy(
- map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size))
+ map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size)
+ )
facets_back = cp.asnumpy(construct_facets_from(normal_mask))
@@ -656,7 +667,7 @@ def save_normal_tensor(in_tensor, idx, png_path, thickness=0.0):
depth_B_arr = depth2arr(in_tensor["depth_B"][idx])
BNI_dict = {}
-
+
# clothed human
BNI_dict["normal_F"] = normal_F_arr
BNI_dict["normal_B"] = normal_B_arr
diff --git a/lib/common/blender_utils.py b/lib/common/blender_utils.py
index 45b443f2157d712c8be145458bf4f3197b727521..a02260cc722bd9729dfbeb153543ac5f648deacf 100644
--- a/lib/common/blender_utils.py
+++ b/lib/common/blender_utils.py
@@ -3,6 +3,7 @@ import sys, os
from math import radians
import mathutils
import bmesh
+
print(sys.exec_prefix)
from tqdm import tqdm
import numpy as np
@@ -29,7 +30,6 @@ shadows = False
# diffuse_color = (18/255., 139/255., 142/255.,1) #correct
# diffuse_color = (251/255., 60/255., 60/255.,1) #wrong
-
smooth = False
wireframe = False
@@ -47,13 +47,16 @@ compositor_alpha = 0.7
# Helper functions
##################################################
+
def blender_print(*args, **kwargs):
- print (*args, **kwargs, file=sys.stderr)
+ print(*args, **kwargs, file=sys.stderr)
+
def using_app():
''' Returns if script is running through Blender application (GUI or background processing)'''
return (not sys.argv[0].endswith('.py'))
+
def setup_diffuse_transparent_material(target, color, object_transparent, backface_transparent):
''' Sets up diffuse/transparent material with backface culling in cycles'''
@@ -110,8 +113,10 @@ def setup_diffuse_transparent_material(target, color, object_transparent, backfa
links.new(node_mix_backface.outputs[0], node_output.inputs[0])
return
+
##################################################
+
def setup_scene():
global render
global cycles_gpu
@@ -150,12 +155,13 @@ def setup_scene():
if cycles_gpu:
print('Activating GPU acceleration')
bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
-
+
if bpy.app.version[0] >= 3:
- cuda_devices = bpy.context.preferences.addons['cycles'].preferences.get_devices_for_type(compute_device_type = 'CUDA')
+ cuda_devices = bpy.context.preferences.addons[
+ 'cycles'].preferences.get_devices_for_type(compute_device_type='CUDA')
else:
- (cuda_devices, opencl_devices) = bpy.context.preferences.addons['cycles'].preferences.get_devices()
-
+ (cuda_devices, opencl_devices
+ ) = bpy.context.preferences.addons['cycles'].preferences.get_devices()
if (len(cuda_devices) < 1):
print('ERROR: CUDA GPU acceleration not available')
@@ -178,7 +184,7 @@ def setup_scene():
if bpy.app.version[0] < 3:
scene.render.tile_x = 64
scene.render.tile_y = 64
-
+
# Disable Blender 3 denoiser to properly measure Cycles render speed
if bpy.app.version[0] >= 3:
scene.cycles.use_denoising = False
@@ -226,7 +232,6 @@ def setup_scene():
bpy.ops.mesh.mark_freestyle_edge(clear=True)
bpy.ops.object.mode_set(mode='OBJECT')
-
# Setup freestyle mode for wireframe overlay rendering
if wireframe:
scene.render.use_freestyle = True
@@ -245,8 +250,10 @@ def setup_scene():
# Output transparent image when no background is used
scene.render.image_settings.color_mode = 'RGBA'
+
##################################################
+
def setup_compositing():
global compositor_image_scale
@@ -275,6 +282,7 @@ def setup_compositing():
links.new(blend_node.outputs[0], tree.nodes['Composite'].inputs[0])
+
def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
'''Render image of given model file'''
global smooth
@@ -288,13 +296,13 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
# Import object into scene
bpy.ops.import_scene.obj(filepath=path)
object = bpy.context.selected_objects[0]
-
+
object.rotation_euler = (radians(90.0), 0.0, radians(yaw))
- z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:,1])
+ z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:, 1])
# z_top = np.max(np.array([vert.co for vert in object.data.vertices])[:,1])
# blender_print(radians(90.0), z_bottom, z_top)
object.location -= mathutils.Vector((0.0, 0.0, z_bottom))
-
+
if quads:
bpy.context.view_layer.objects.active = object
bpy.ops.object.mode_set(mode='EDIT')
@@ -309,11 +317,11 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
bpy.ops.object.mode_set(mode='EDIT')
bpy.ops.mesh.mark_freestyle_edge(clear=False)
bpy.ops.object.mode_set(mode='OBJECT')
-
+
if correct:
- diffuse_color = (18/255., 139/255., 142/255.,1) #correct
+ diffuse_color = (18 / 255., 139 / 255., 142 / 255., 1) #correct
else:
- diffuse_color = (251/255., 60/255., 60/255.,1) #wrong
+ diffuse_color = (251 / 255., 60 / 255., 60 / 255., 1) #wrong
setup_diffuse_transparent_material(object, diffuse_color, object_transparent, mouth_transparent)
@@ -336,10 +344,10 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
bpy.ops.render.render(write_still=True)
# Remove temporary output redirection
-# sys.stdout.flush()
-# os.close(1)
-# os.dup(old)
-# os.close(old)
+ # sys.stdout.flush()
+ # os.close(1)
+ # os.dup(old)
+ # os.close(old)
# Delete last selected object from scene
object.select_set(True)
@@ -351,7 +359,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True):
global quality_preview
if not input_file.endswith('.obj'):
- print('ERROR: Invalid input: ' + input_file )
+ print('ERROR: Invalid input: ' + input_file)
return
print('Processing: ' + input_file)
@@ -361,7 +369,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True):
if quality_preview:
output_file = output_file.replace('.png', '-preview.png')
- angle = 360.0/views
+ angle = 360.0 / views
pbar = tqdm(range(0, views))
for view in pbar:
pbar.set_description(f"{os.path.basename(output_file)} | View:{str(view)}")
@@ -369,8 +377,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True):
output_file_view = f"{output_file}/{view:03d}.png"
if not os.path.exists(os.path.join(output_dir, output_file_view)):
render_file(input_file, input_dir, output_file_view, output_dir, yaw, correct)
-
+
cmd = "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 -i " + os.path.join(output_dir, output_file, '%3d.png') + \
" -shortest -filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" -map \"[out]\" -y " + output_dir+"/"+output_file+".mp4"
os.system(cmd)
-
\ No newline at end of file
diff --git a/lib/common/cloth_extraction.py b/lib/common/cloth_extraction.py
index 7da5f0ec1102f49ff513af27e08597b0bd65bcb7..612a96787e1aa836b097971e7aaf55b284ef178a 100644
--- a/lib/common/cloth_extraction.py
+++ b/lib/common/cloth_extraction.py
@@ -36,11 +36,13 @@ def load_segmentation(path, shape):
xy = np.vstack((x, y)).T
coordinates.append(xy)
- segmentations.append({
- "type": val["category_name"],
- "type_id": val["category_id"],
- "coordinates": coordinates,
- })
+ segmentations.append(
+ {
+ "type": val["category_name"],
+ "type_id": val["category_id"],
+ "coordinates": coordinates,
+ }
+ )
return segmentations
@@ -56,9 +58,8 @@ def smpl_to_recon_labels(recon, smpl, k=1):
Returns a dictionary containing the bodypart and the corresponding indices
"""
smpl_vert_segmentation = json.load(
- open(
- os.path.join(os.path.dirname(__file__),
- "smpl_vert_segmentation.json")))
+ open(os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json"))
+ )
n = smpl.vertices.shape[0]
y = np.array([None] * n)
for key, val in smpl_vert_segmentation.items():
@@ -71,8 +72,7 @@ def smpl_to_recon_labels(recon, smpl, k=1):
recon_labels = {}
for key in smpl_vert_segmentation.keys():
- recon_labels[key] = list(
- np.argwhere(y_pred == key).flatten().astype(int))
+ recon_labels[key] = list(np.argwhere(y_pred == key).flatten().astype(int))
return recon_labels
@@ -139,8 +139,7 @@ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
if type == 1 or type == 3 or type == 10:
body_parts_to_remove += ["leftForeArm", "rightForeArm"]
# No sleeves at all or lower body clothes
- elif (type == 5 or type == 6 or type == 12 or type == 13 or type == 8
- or type == 9):
+ elif (type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9):
body_parts_to_remove += [
"leftForeArm",
"rightForeArm",
@@ -159,8 +158,8 @@ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
]
verts_to_remove = list(
- itertools.chain.from_iterable(
- [recon_labels[part] for part in body_parts_to_remove]))
+ itertools.chain.from_iterable([recon_labels[part] for part in body_parts_to_remove])
+ )
label_mask = np.zeros(num_verts, dtype=bool)
label_mask[verts_to_remove] = True
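# --- illustrative sketch, not part of the patch above ---------------------------------
# Minimal sketch of the label-transfer idea behind smpl_to_recon_labels: the k-NN fit
# itself lies outside the hunks shown, so sklearn's KNeighborsClassifier is used here
# as a stand-in.  "smpl_verts" / "recon_verts" are placeholder (N, 3) arrays and
# "smpl_vert_segmentation" is the JSON mapping loaded in the patched function.
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def transfer_labels(smpl_verts, smpl_vert_segmentation, recon_verts, k=1):
    # per-vertex body-part name for every SMPL(-X) vertex
    y = np.array([None] * len(smpl_verts))
    for part, idxs in smpl_vert_segmentation.items():
        y[idxs] = part

    # nearest-neighbour vote in 3D assigns each reconstructed vertex a body part
    clf = KNeighborsClassifier(n_neighbors=k).fit(smpl_verts, y)
    y_pred = clf.predict(recon_verts)

    # same output layout as the patched function: {part: [recon vertex indices]}
    return {
        part: list(np.argwhere(y_pred == part).flatten().astype(int))
        for part in smpl_vert_segmentation.keys()
    }
# ---------------------------------------------------------------------------------------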
diff --git a/lib/common/config.py b/lib/common/config.py
index 04a65599f0e32aaa95e007aea1aa106a5c58a868..33c917cc61fce59cccf1f5e4a34056a0f22fc0e3 100644
--- a/lib/common/config.py
+++ b/lib/common/config.py
@@ -100,6 +100,7 @@ _C.bni.thickness = 0.00
_C.bni.hand_thres = 4e-2
_C.bni.face_thres = 6e-2
_C.bni.hps_type = "pixie"
+_C.bni.texture_src = "image"
# kernel_size, stride, dilation, padding
@@ -170,10 +171,10 @@ _C.dataset.rp_type = "pifu900"
_C.dataset.th_type = "train"
_C.dataset.input_size = 512
_C.dataset.rotation_num = 3
-_C.dataset.num_precomp = 10 # Number of segmentation classifiers
-_C.dataset.num_multiseg = 500 # Number of categories per classifier
-_C.dataset.num_knn = 10 # for loss/error
-_C.dataset.num_knn_dis = 20 # for accuracy
+_C.dataset.num_precomp = 10 # Number of segmentation classifiers
+_C.dataset.num_multiseg = 500 # Number of categories per classifier
+_C.dataset.num_knn = 10 # for loss/error
+_C.dataset.num_knn_dis = 20 # for accuracy
_C.dataset.num_verts_max = 20000
_C.dataset.zray_type = False
_C.dataset.online_smpl = False
@@ -210,8 +211,7 @@ def get_cfg_defaults():
# Alternatively, provide a way to import the defaults as
# a global singleton:
-cfg = _C # users can `from config import cfg`
-
+cfg = _C # users can `from config import cfg`
# cfg = get_cfg_defaults()
# cfg.merge_from_file('./configs/example.yaml')
@@ -244,9 +244,7 @@ def parse_args(args):
def parse_args_extend(args):
if args.resume:
if not os.path.exists(args.log_dir):
- raise ValueError(
- "Experiment are set to resume mode, but log directory does not exist."
- )
+ raise ValueError("Experiment are set to resume mode, but log directory does not exist.")
# load log's cfg
cfg_file = os.path.join(args.log_dir, "cfg.yaml")
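# --- illustrative sketch, not part of the patch above ---------------------------------
# How the new bni.texture_src default would typically be overridden with yacs
# (assuming lib.common.config exposes get_cfg_defaults as in this file; the YAML path
# below is a placeholder):
from lib.common.config import get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_file("./configs/example.yaml")        # file-level overrides
cfg.merge_from_list(["bni.texture_src", "image"])    # key/value overrides
cfg.freeze()
print(cfg.bni.texture_src, cfg.bni.hps_type)
# ---------------------------------------------------------------------------------------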
diff --git a/lib/common/imutils.py b/lib/common/imutils.py
index cc9e09f888ffc2b268308d0a1802debf798db0cd..f96e666b8a80d139aa240275a33f06d1464a6207 100644
--- a/lib/common/imutils.py
+++ b/lib/common/imutils.py
@@ -3,14 +3,13 @@ import mediapipe as mp
import torch
import numpy as np
import torch.nn.functional as F
-from rembg import remove
-from rembg.session_factory import new_session
from PIL import Image
-from torchvision.models import detection
-
from lib.pymafx.core import constants
-from lib.common.cloth_extraction import load_segmentation
+
+from rembg import remove
+from rembg.session_factory import new_session
from torchvision import transforms
+from kornia.geometry.transform import get_affine_matrix2d, warp_affine
def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
@@ -24,42 +23,40 @@ def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
return transforms.Compose(all_ops)
-def aug_matrix(w1, h1, w2, h2):
- dx = (w2 - w1) / 2.0
- dy = (h2 - h1) / 2.0
-
- matrix_trans = np.array([[1.0, 0, dx], [0, 1.0, dy], [0, 0, 1.0]])
-
- scale = np.min([float(w2) / w1, float(h2) / h1])
+def get_affine_matrix_wh(w1, h1, w2, h2):
- M = get_affine_matrix(center=(w2 / 2.0, h2 / 2.0), translate=(0, 0), scale=scale)
-
- M = np.array(M + [0.0, 0.0, 1.0]).reshape(3, 3)
- M = M.dot(matrix_trans)
+ transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0)
+ center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0)
+ scale = torch.min(torch.tensor([w2 / w1, h2 / h1])).repeat(2).unsqueeze(0)
+ M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
return M
-def get_affine_matrix(center, translate, scale):
- cx, cy = center
- tx, ty = translate
-
- M = [1, 0, 0, 0, 1, 0]
- M = [x * scale for x in M]
+def get_affine_matrix_box(boxes, w2, h2):
- # Apply translation and of center translation: RSS * C^-1
- M[2] += M[0] * (-cx) + M[1] * (-cy)
- M[5] += M[3] * (-cx) + M[4] * (-cy)
+ # boxes [left, top, right, bottom]
+ width = boxes[:, 2] - boxes[:, 0] #(N,)
+ height = boxes[:, 3] - boxes[:, 1] #(N,)
+ center = torch.tensor(
+ [(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]
+ ).T #(N,2)
+ scale = torch.min(torch.tensor([w2 / width, h2 / height]),
+ dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9 #(N,2)
+ transl = torch.tensor([w2 / 2.0 - center[:, 0], h2 / 2.0 - center[:, 1]]).unsqueeze(0) #(N,2)
+ M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
- # Apply center translation: T * C * RSS * C^-1
- M[2] += cx + tx
- M[5] += cy + ty
return M
def load_img(img_file):
img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
+
+ # considering 16-bit image
+ if img.dtype == np.uint16:
+ img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
+
if len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
@@ -68,11 +65,10 @@ def load_img(img_file):
else:
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
- return img
+ return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2]
def get_keypoints(image):
-
def collect_xyv(x, body=True):
lmk = x.landmark
all_lmks = []
@@ -84,8 +80,8 @@ def get_keypoints(image):
mp_holistic = mp.solutions.holistic
with mp_holistic.Holistic(
- static_image_mode=True,
- model_complexity=2,
+ static_image_mode=True,
+ model_complexity=2,
) as holistic:
results = holistic.process(image)
@@ -93,9 +89,15 @@ def get_keypoints(image):
result = {}
result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps
- result["lhand"] = collect_xyv(results.left_hand_landmarks, False) if results.left_hand_landmarks else fake_kps
- result["rhand"] = collect_xyv(results.right_hand_landmarks, False) if results.right_hand_landmarks else fake_kps
- result["face"] = collect_xyv(results.face_landmarks, False) if results.face_landmarks else fake_kps
+ result["lhand"] = collect_xyv(
+ results.left_hand_landmarks, False
+ ) if results.left_hand_landmarks else fake_kps
+ result["rhand"] = collect_xyv(
+ results.right_hand_landmarks, False
+ ) if results.right_hand_landmarks else fake_kps
+ result["face"] = collect_xyv(
+ results.face_landmarks, False
+ ) if results.face_landmarks else fake_kps
return result
@@ -104,13 +106,21 @@ def get_pymafx(image, landmarks):
# image [3,512,512]
- item = {'img_body': F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]}
+ item = {
+ 'img_body':
+ F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]
+ }
for part in ['lhand', 'rhand', 'face']:
kp2d = landmarks[part]
kp2d_valid = kp2d[kp2d[:, 3] > 0.]
if len(kp2d_valid) > 0:
- bbox = [min(kp2d_valid[:, 0]), min(kp2d_valid[:, 1]), max(kp2d_valid[:, 0]), max(kp2d_valid[:, 1])]
+ bbox = [
+ min(kp2d_valid[:, 0]),
+ min(kp2d_valid[:, 1]),
+ max(kp2d_valid[:, 0]),
+ max(kp2d_valid[:, 1])
+ ]
center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.]
scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
@@ -141,20 +151,6 @@ def get_pymafx(image, landmarks):
return item
-def expand_bbox(bbox, width, height, ratio=0.1):
-
- bbox = np.around(bbox).astype(np.int16)
- bbox_width = bbox[2] - bbox[0]
- bbox_height = bbox[3] - bbox[1]
-
- bbox[1] = max(bbox[1] - bbox_height * ratio, 0)
- bbox[3] = min(bbox[3] + bbox_height * ratio, height)
- bbox[0] = max(bbox[0] - bbox_width * ratio, 0)
- bbox[2] = min(bbox[2] + bbox_width * ratio, width)
-
- return bbox
-
-
def remove_floats(mask):
# 1. find all the contours
@@ -173,51 +169,48 @@ def remove_floats(mask):
return new_mask
-def process_image(img_file, hps_type, single, input_res=512):
+def process_image(img_file, hps_type, single, input_res, detector):
- img_raw = load_img(img_file)
-
- in_height, in_width = img_raw.shape[:2]
- M = aug_matrix(in_width, in_height, input_res * 2, input_res * 2)
-
- # from rectangle to square by padding (input_res*2, input_res*2)
- img_square = cv2.warpAffine(img_raw, M[0:2, :], (input_res * 2, input_res * 2), flags=cv2.INTER_CUBIC)
+ img_raw, (in_height, in_width) = load_img(img_file)
+ tgt_res = input_res * 2
+ M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res)
+ img_square = warp_affine(
+ img_raw,
+ M_square[:, :2], (tgt_res, ) * 2,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True
+ )
# detection for bbox
- detector = detection.maskrcnn_resnet50_fpn(weights=detection.MaskRCNN_ResNet50_FPN_V2_Weights)
- detector.eval()
- predictions = detector([torch.from_numpy(img_square).permute(2, 0, 1) / 255.])[0]
+ predictions = detector(img_square / 255.)[0]
if single:
top_score = predictions["scores"][predictions["labels"] == 1].max()
human_ids = torch.where(predictions["scores"] == top_score)[0]
else:
- human_ids = torch.logical_and(predictions["labels"] == 1, predictions["scores"] > 0.9).nonzero().squeeze(1)
+ human_ids = torch.logical_and(predictions["labels"] == 1,
+ predictions["scores"] > 0.9).nonzero().squeeze(1)
boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()
- width = boxes[:, 2] - boxes[:, 0] #(N,)
- height = boxes[:, 3] - boxes[:, 1] #(N,)
- center = np.array([(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]).T #(N,2)
- scale = np.array([width, height]).max(axis=0) / 90.
+ M_crop = get_affine_matrix_box(boxes, input_res, input_res)
img_icon_lst = []
img_crop_lst = []
img_hps_lst = []
img_mask_lst = []
- uncrop_param_lst = []
landmark_lst = []
hands_visibility_lst = []
img_pymafx_lst = []
uncrop_param = {
- "center": center,
- "scale": scale,
"ori_shape": [in_height, in_width],
"box_shape": [input_res, input_res],
- "crop_shape": [input_res * 2, input_res * 2, 3],
- "M": M,
+ "square_shape": [tgt_res, tgt_res],
+ "M_square": M_square,
+ "M_crop": M_crop
}
for idx in range(len(boxes)):
@@ -228,59 +221,74 @@ def process_image(img_file, hps_type, single, input_res=512):
else:
mask_detection = masks[0] * 0.
- img_crop, _ = crop(
- np.concatenate([img_square, (mask_detection < 0.4) * 255], axis=2), center[idx], scale[idx], [input_res, input_res])
-
- # get accurate segmentation mask of focus person
+ img_square_rgba = torch.cat(
+ [img_square.squeeze(0).permute(1, 2, 0),
+ torch.tensor(mask_detection < 0.4) * 255],
+ dim=2
+ )
+
+ img_crop = warp_affine(
+ img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2),
+ M_crop[idx:idx + 1, :2], (input_res, ) * 2,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True
+ ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
+
+ # get accurate person segmentation mask
img_rembg = remove(img_crop, post_process_mask=True, session=new_session("u2net"))
img_mask = remove_floats(img_rembg[:, :, [3]])
- # required image tensors / arrays
-
- # img_icon (tensor): (-1, 1), [3,512,512]
- # img_hps (tensor): (-2.11, 2.44), [3,224,224]
-
- # img_np (array): (0, 255), [512,512,3]
- # img_rembg (array): (0, 255), [512,512,4]
- # img_mask (array): (0, 1), [512,512,1]
- # img_crop (array): (0, 255), [512,512,4]
-
mean_icon = std_icon = (0.5, 0.5, 0.5)
img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8)
- img_icon = transform_to_tensor(512, mean_icon, std_icon)(Image.fromarray(img_np)) * torch.tensor(img_mask).permute(
- 2, 0, 1)
- img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN, constants.IMG_NORM_STD)(Image.fromarray(img_np))
+ img_icon = transform_to_tensor(512, mean_icon, std_icon)(
+ Image.fromarray(img_np)
+ ) * torch.tensor(img_mask).permute(2, 0, 1)
+ img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN,
+ constants.IMG_NORM_STD)(Image.fromarray(img_np))
landmarks = get_keypoints(img_np)
+ # get hands visibility
+ hands_visibility = [True, True]
+ if landmarks['lhand'][:, -1].mean() == 0.:
+ hands_visibility[0] = False
+ if landmarks['rhand'][:, -1].mean() == 0.:
+ hands_visibility[1] = False
+ hands_visibility_lst.append(hands_visibility)
+
if hps_type == 'pymafx':
img_pymafx_lst.append(
get_pymafx(
- transform_to_tensor(512, constants.IMG_NORM_MEAN, constants.IMG_NORM_STD)(Image.fromarray(img_np)),
- landmarks))
+ transform_to_tensor(512, constants.IMG_NORM_MEAN,
+ constants.IMG_NORM_STD)(Image.fromarray(img_np)), landmarks
+ )
+ )
img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0)
img_icon_lst.append(img_icon)
img_hps_lst.append(img_hps)
img_mask_lst.append(torch.tensor(img_mask[..., 0]))
- uncrop_param_lst.append(uncrop_param)
landmark_lst.append(landmarks['body'])
- hands_visibility = [True, True]
- if landmarks['lhand'][:, -1].mean() == 0.:
- hands_visibility[0] = False
- if landmarks['rhand'][:, -1].mean() == 0.:
- hands_visibility[1] = False
- hands_visibility_lst.append(hands_visibility)
+ # required image tensors / arrays
+
+ # img_icon (tensor): (-1, 1), [3,512,512]
+ # img_hps (tensor): (-2.11, 2.44), [3,224,224]
+
+ # img_np (array): (0, 255), [512,512,3]
+ # img_rembg (array): (0, 255), [512,512,4]
+ # img_mask (array): (0, 1), [512,512,1]
+ # img_crop (array): (0, 255), [512,512,4]
return_dict = {
- "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res]
- "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res]
- "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res]
- "img_raw": img_raw, #[H, W, 3]
- "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res]
+ "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res]
+ "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res]
+ "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res]
+ "img_raw": img_raw, #[1, 3, H, W]
+ "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res]
"uncrop_param": uncrop_param,
- "landmark": torch.stack(landmark_lst), #[N, 33, 4]
+ "landmark": torch.stack(landmark_lst), #[N, 33, 4]
"hands_visibility": hands_visibility_lst,
}
@@ -302,250 +310,51 @@ def process_image(img_file, hps_type, single, input_res=512):
return return_dict
-def get_transform(center, scale, res):
- """Generate transformation matrix."""
- h = 100 * scale
- t = np.zeros((3, 3))
- t[0, 0] = float(res[1]) / h
- t[1, 1] = float(res[0]) / h
- t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
- t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
- t[2, 2] = 1
-
- return t
-
-
-def transform(pt, center, scale, res, invert=0):
- """Transform pixel location to different reference."""
- t = get_transform(center, scale, res)
- if invert:
- t = np.linalg.inv(t)
- new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
- new_pt = np.dot(t, new_pt)
- return np.around(new_pt[:2]).astype(np.int16)
-
-
-def crop(img, center, scale, res):
- """Crop image according to the supplied bounding box."""
-
- img_height, img_width = img.shape[:2]
-
- # Upper left point
- ul = np.array(transform([0, 0], center, scale, res, invert=1))
-
- # Bottom right point
- br = np.array(transform(res, center, scale, res, invert=1))
-
- new_shape = [br[1] - ul[1], br[0] - ul[0]]
- if len(img.shape) > 2:
- new_shape += [img.shape[2]]
- new_img = np.zeros(new_shape)
-
- # Range to fill new array
- new_x = max(0, -ul[0]), min(br[0], img_width) - ul[0]
- new_y = max(0, -ul[1]), min(br[1], img_height) - ul[1]
-
- # Range to sample from original image
- old_x = max(0, ul[0]), min(img_width, br[0])
- old_y = max(0, ul[1]), min(img_height, br[1])
-
- new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
- new_img = F.interpolate(
- torch.tensor(new_img).permute(2, 0, 1).unsqueeze(0), res, mode='bilinear').permute(0, 2, 3,
- 1)[0].numpy().astype(np.uint8)
-
- return new_img, (old_x, new_x, old_y, new_y, new_shape)
-
-
-def crop_segmentation(org_coord, res, cropping_parameters):
- old_x, new_x, old_y, new_y, new_shape = cropping_parameters
+def blend_rgb_norm(norms, data):
- new_coord = np.zeros((org_coord.shape))
- new_coord[:, 0] = new_x[0] + (org_coord[:, 0] - old_x[0])
- new_coord[:, 1] = new_y[0] + (org_coord[:, 1] - old_y[0])
-
- new_coord[:, 0] = res[0] * (new_coord[:, 0] / new_shape[1])
- new_coord[:, 1] = res[1] * (new_coord[:, 1] / new_shape[0])
-
- return new_coord
-
-
-def corner_align(ul, br):
-
- if ul[1] - ul[0] != br[1] - br[0]:
- ul[1] = ul[0] + br[1] - br[0]
-
- return ul, br
-
-
-def uncrop(img, center, scale, orig_shape):
- """'Undo' the image cropping/resizing.
- This function is used when evaluating mask/part segmentation.
- """
-
- res = img.shape[:2]
-
- # Upper left point
- ul = np.array(transform([0, 0], center, scale, res, invert=1))
- # Bottom right point
- br = np.array(transform(res, center, scale, res, invert=1))
-
- # quick fix
- ul, br = corner_align(ul, br)
-
- # size of cropped image
- crop_shape = [br[1] - ul[1], br[0] - ul[0]]
- new_img = np.zeros(orig_shape, dtype=np.uint8)
-
- # Range to fill new array
- new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
- new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
+ # norms [N, 3, res, res]
+ masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
+ norm_mask = F.interpolate(
+ torch.cat([norms, masks], dim=1).detach(),
+ size=data["uncrop_param"]["box_shape"],
+ mode="bilinear",
+ align_corners=False
+ )
+ final = data["img_raw"].type_as(norm_mask)
- # Range to sample from original image
- old_x = max(0, ul[0]), min(orig_shape[1], br[0])
- old_y = max(0, ul[1]), min(orig_shape[0], br[1])
+ for idx in range(len(norms)):
- img = np.array(Image.fromarray(img.astype(np.uint8)).resize(crop_shape))
+ norm_pred = (norm_mask[idx:idx + 1, :3, :, :] + 1.0) * 255.0 / 2.0
+ mask_pred = norm_mask[idx:idx + 1, 3:4, :, :].repeat(1, 3, 1, 1)
- new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
+ norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
+ mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)
- return new_img
+ final = final * (1.0 - mask_ori) + norm_ori * mask_ori
+ return final.detach().cpu()
-def rot_aa(aa, rot):
- """Rotate axis angle parameters."""
- # pose parameters
- R = np.array([
- [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
- [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
- [0, 0, 1],
- ])
- # find the rotation of the body in camera frame
- per_rdg, _ = cv2.Rodrigues(aa)
- # apply the global rotation to the global orientation
- resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
- aa = (resrot.T)[0]
- return aa
+def unwrap(image, uncrop_param, idx):
-def flip_img(img):
- """Flip rgb images or masks.
- channels come last, e.g. (256,256,3).
- """
- img = np.fliplr(img)
- return img
+ device = image.device
+ img_square = warp_affine(
+ image,
+ torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device),
+ uncrop_param["square_shape"],
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True
+ )
-def flip_kp(kp, is_smpl=False):
- """Flip keypoints."""
- if len(kp) == 24:
- if is_smpl:
- flipped_parts = constants.SMPL_JOINTS_FLIP_PERM
- else:
- flipped_parts = constants.J24_FLIP_PERM
- elif len(kp) == 49:
- if is_smpl:
- flipped_parts = constants.SMPL_J49_FLIP_PERM
- else:
- flipped_parts = constants.J49_FLIP_PERM
- kp = kp[flipped_parts]
- kp[:, 0] = -kp[:, 0]
- return kp
-
-
-def flip_pose(pose):
- """Flip pose.
- The flipping is based on SMPL parameters.
- """
- flipped_parts = constants.SMPL_POSE_FLIP_PERM
- pose = pose[flipped_parts]
- # we also negate the second and the third dimension of the axis-angle
- pose[1::3] = -pose[1::3]
- pose[2::3] = -pose[2::3]
- return pose
-
-
-def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
- # Normalize keypoints between -1, 1
- if not inv:
- ratio = 1.0 / crop_size
- kp_2d = 2.0 * kp_2d * ratio - 1.0
- else:
- ratio = 1.0 / crop_size
- kp_2d = (kp_2d + 1.0) / (2 * ratio)
-
- return kp_2d
-
-
-def visualize_landmarks(image, joints, color):
-
- img_w, img_h = image.shape[:2]
-
- for joint in joints:
- image = cv2.circle(image, (int(joint[0] * img_w), int(joint[1] * img_h)), 5, color)
-
- return image
-
-
-def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
- """
- param joints: [num_joints, 3]
- param joints_vis: [num_joints, 3]
- return: target, target_weight(1: visible, 0: invisible)
- """
- num_joints = joints.shape[0]
- device = joints.device
- cur_device = torch.device(device.type, device.index)
- if not hasattr(heatmap_size, "__len__"):
- # width height
- heatmap_size = [heatmap_size, heatmap_size]
- assert len(heatmap_size) == 2
- target_weight = np.ones((num_joints, 1), dtype=np.float32)
- if joints_vis is not None:
- target_weight[:, 0] = joints_vis[:, 0]
- target = torch.zeros(
- (num_joints, heatmap_size[1], heatmap_size[0]),
- dtype=torch.float32,
- device=cur_device,
+ img_ori = warp_affine(
+ img_square,
+ torch.inverse(uncrop_param["M_square"])[:, :2].to(device),
+ uncrop_param["ori_shape"],
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True
)
- tmp_size = sigma * 3
-
- for joint_id in range(num_joints):
- mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5)
- mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5)
- # Check that any part of the gaussian is in-bounds
- ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
- br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
- if (ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] or br[0] < 0 or br[1] < 0):
- # If not, just return the image as is
- target_weight[joint_id] = 0
- continue
-
- # # Generate gaussian
- size = 2 * tmp_size + 1
- # x = np.arange(0, size, 1, np.float32)
- # y = x[:, np.newaxis]
- # x0 = y0 = size // 2
- # # The gaussian is not normalized, we want the center value to equal 1
- # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
- # g = torch.from_numpy(g.astype(np.float32))
-
- x = torch.arange(0, size, dtype=torch.float32, device=cur_device)
- y = x.unsqueeze(-1)
- x0 = y0 = size // 2
- # The gaussian is not normalized, we want the center value to equal 1
- g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
-
- # Usable gaussian range
- g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
- g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
- # Image range
- img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
- img_y = max(0, ul[1]), min(br[1], heatmap_size[1])
-
- v = target_weight[joint_id]
- if v > 0.5:
- target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
-
- return target, target_weight
+ return img_ori
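# --- illustrative sketch, not part of the patch above ---------------------------------
# Minimal round-trip sketch of the kornia-based resize/crop introduced above: build a
# 3x3 affine with get_affine_matrix2d, warp with its top two rows, and undo the warp
# with the inverse matrix.  This mirrors get_affine_matrix_wh / unwrap; the input
# tensor here is random data, not a real photo.
import torch
from kornia.geometry.transform import get_affine_matrix2d, warp_affine

h1, w1, tgt = 720, 540, 1024
img = torch.rand(1, 3, h1, w1)                   # [B, C, H, W], placeholder image

transl = torch.tensor([[(tgt - w1) / 2.0, (tgt - h1) / 2.0]])
center = torch.tensor([[w1 / 2.0, h1 / 2.0]])
scale = torch.min(torch.tensor([tgt / w1, tgt / h1])).repeat(2).unsqueeze(0)
M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))   # [1, 3, 3]

square = warp_affine(
    img, M[:, :2], (tgt, tgt), mode='bilinear', padding_mode='zeros', align_corners=True
)

# the inverse warp restores the original frame, as unwrap() does with M_square / M_crop
restored = warp_affine(
    square, torch.inverse(M)[:, :2], (h1, w1),
    mode='bilinear', padding_mode='zeros', align_corners=True
)
print(square.shape, restored.shape)              # (1, 3, 1024, 1024) (1, 3, 720, 540)
# ---------------------------------------------------------------------------------------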
diff --git a/lib/common/libmesh/inside_mesh.py b/lib/common/libmesh/inside_mesh.py
index 110ff87407efc49010328764299b824f647708cb..eaac43c2e6fe103c6a1dd4e182642ff0cc6a024a 100644
--- a/lib/common/libmesh/inside_mesh.py
+++ b/lib/common/libmesh/inside_mesh.py
@@ -5,7 +5,7 @@ from .triangle_hash import TriangleHash as _TriangleHash
def check_mesh_contains(mesh, points, hash_resolution=512):
intersector = MeshIntersector(mesh, hash_resolution)
contains, hole_points = intersector.query(points)
- return contains, hole_points
+ return contains, hole_points
class MeshIntersector:
@@ -25,8 +25,7 @@ class MeshIntersector:
# assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5))
triangles2d = triangles[:, :, :2]
- self._tri_intersector2d = TriangleIntersector2d(
- triangles2d, resolution)
+ self._tri_intersector2d = TriangleIntersector2d(triangles2d, resolution)
def query(self, points):
# Rescale points
@@ -38,8 +37,7 @@ class MeshIntersector:
# cull points outside of the axis aligned bounding box
# this avoids running ray tests unless points are close
- inside_aabb = np.all(
- (0 <= points) & (points <= self.resolution), axis=1)
+ inside_aabb = np.all((0 <= points) & (points <= self.resolution), axis=1)
if not inside_aabb.any():
return contains, hole_points
@@ -48,14 +46,14 @@ class MeshIntersector:
points = points[mask]
# Compute intersection depth and check order
- points_indices, tri_indices = self._tri_intersector2d.query(
- points[:, :2])
+ points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])
triangles_intersect = self._triangles[tri_indices]
points_intersect = points[points_indices]
depth_intersect, abs_n_2 = self.compute_intersection_depth(
- points_intersect, triangles_intersect)
+ points_intersect, triangles_intersect
+ )
# Count number of intersections in both directions
smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
@@ -73,7 +71,7 @@ class MeshIntersector:
# print('Warning: contains1 != contains2 for some points.')
contains[mask] = (contains1 & contains2)
hole_points[mask] = np.logical_xor(contains1, contains2)
- return contains, hole_points
+ return contains, hole_points
def compute_intersection_depth(self, points, triangles):
t1 = triangles[:, 0, :]
@@ -150,7 +148,7 @@ class TriangleIntersector2d:
sum_uv = u + v
contains[mask] = (
- (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
- & (0 < sum_uv) & (sum_uv < abs_detA)
+ (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) & (0 < sum_uv) &
+ (sum_uv < abs_detA)
)
return contains
diff --git a/lib/common/libmesh/setup.py b/lib/common/libmesh/setup.py
index a565e470dd6bb6a2042c86b47a1524c5f7194d58..38ac162300df4e987134e81306e1a6ad674a5323 100644
--- a/lib/common/libmesh/setup.py
+++ b/lib/common/libmesh/setup.py
@@ -2,7 +2,4 @@ from setuptools import setup
from Cython.Build import cythonize
import numpy
-
-setup(name = 'libmesh',
- ext_modules = cythonize("*.pyx"),
- include_dirs=[numpy.get_include()])
+setup(name='libmesh', ext_modules=cythonize("*.pyx"), include_dirs=[numpy.get_include()])
diff --git a/lib/common/libvoxelize/setup.py b/lib/common/libvoxelize/setup.py
index 7a4056e8914dbc65b4fe99acc4d7e3e9f49a04e6..1a534ece09af40fbabd3221eae2e2f5d7931f80c 100644
--- a/lib/common/libvoxelize/setup.py
+++ b/lib/common/libvoxelize/setup.py
@@ -1,5 +1,4 @@
from setuptools import setup
from Cython.Build import cythonize
-setup(name = 'libvoxelize',
- ext_modules = cythonize("*.pyx"))
+setup(name='libvoxelize', ext_modules=cythonize("*.pyx"))
diff --git a/lib/common/local_affine.py b/lib/common/local_affine.py
index 6cbaa8f0626214c518f95551cfae1ba78a60fc43..3a6ef580ebb306bc8f3a3fbb87200b33df262c68 100644
--- a/lib/common/local_affine.py
+++ b/lib/common/local_affine.py
@@ -16,7 +16,6 @@ from lib.common.train_util import init_loss
# reference: https://github.com/wuhaozhe/pytorch-nicp
class LocalAffine(nn.Module):
-
def __init__(self, num_points, batch_size=1, edges=None):
'''
specify the number of points, the number of points should be constant across the batch
@@ -26,8 +25,14 @@ class LocalAffine(nn.Module):
add additional pooling on top of w matrix
'''
super(LocalAffine, self).__init__()
- self.A = nn.Parameter(torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1))
- self.b = nn.Parameter(torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(batch_size, num_points, 1, 1))
+ self.A = nn.Parameter(
+ torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)
+ )
+ self.b = nn.Parameter(
+ torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(
+ batch_size, num_points, 1, 1
+ )
+ )
self.edges = edges
self.num_points = num_points
@@ -38,24 +43,23 @@ class LocalAffine(nn.Module):
'''
if self.edges is None:
raise Exception("edges cannot be none when calculate stiff")
- idx1 = self.edges[:, 0]
- idx2 = self.edges[:, 1]
affine_weight = torch.cat((self.A, self.b), dim=3)
- w1 = torch.index_select(affine_weight, dim=1, index=idx1)
- w2 = torch.index_select(affine_weight, dim=1, index=idx2)
+ w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
+ w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
w_diff = (w1 - w2)**2
w_rigid = (torch.linalg.det(self.A) - 1.0)**2
return w_diff, w_rigid
def forward(self, x):
'''
- x should have shape of B * N * 3
+ x should have shape of B * N * 3 * 1
'''
x = x.unsqueeze(3)
out_x = torch.matmul(self.A, x)
out_x = out_x + self.b
- stiffness, rigid = self.stiffness()
out_x.squeeze_(3)
+ stiffness, rigid = self.stiffness()
+
return out_x, stiffness, rigid
@@ -75,10 +79,16 @@ def register(target_mesh, src_mesh, device):
tgt_mesh = trimesh2meshes(target_mesh).to(device)
src_verts = src_mesh.verts_padded().clone()
- local_affine_model = LocalAffine(src_mesh.verts_padded().shape[1],
- src_mesh.verts_padded().shape[0], src_mesh.edges_packed()).to(device)
+ local_affine_model = LocalAffine(
+ src_mesh.verts_padded().shape[1],
+ src_mesh.verts_padded().shape[0], src_mesh.edges_packed()
+ ).to(device)
- optimizer_cloth = torch.optim.Adam([{'params': local_affine_model.parameters()}], lr=1e-2, amsgrad=True)
+ optimizer_cloth = torch.optim.Adam(
+ [{
+ 'params': local_affine_model.parameters()
+ }], lr=1e-2, amsgrad=True
+ )
scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_cloth,
mode="min",
@@ -90,28 +100,27 @@ def register(target_mesh, src_mesh, device):
losses = init_loss()
- loop_cloth = tqdm(range(200))
+ loop_cloth = tqdm(range(100))
for i in loop_cloth:
optimizer_cloth.zero_grad()
- deformed_verts, stiffness, rigid = local_affine_model(src_verts)
+ deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
src_mesh = src_mesh.update_padded(deformed_verts)
# losses for laplacian, edge, normal consistency
update_mesh_shape_prior_losses(src_mesh, losses)
losses["cloth"]["value"] = chamfer_distance(
- x=src_mesh.verts_padded(),
- y=tgt_mesh.verts_padded())[0]
-
- losses["stiffness"]["value"] = torch.mean(stiffness)
+ x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded()
+ )[0]
+ losses["stiff"]["value"] = torch.mean(stiffness)
losses["rigid"]["value"] = torch.mean(rigid)
# Weighted sum of the losses
cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
- pbar_desc = "Register SMPL-X towards ECON --- "
+ pbar_desc = "Register SMPL-X -> d-BiNI -- "
for k in losses.keys():
if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
@@ -119,7 +128,7 @@ def register(target_mesh, src_mesh, device):
losses[k]["value"] * losses[k]["weight"]
pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | "
- pbar_desc += f"Total: {cloth_loss:.5f}"
+ pbar_desc += f"TOTAL: {cloth_loss:.3f}"
loop_cloth.set_description(pbar_desc)
# update params
@@ -131,6 +140,7 @@ def register(target_mesh, src_mesh, device):
src_mesh.verts_packed().detach().squeeze(0).cpu(),
src_mesh.faces_packed().detach().squeeze(0).cpu(),
process=False,
- maintains_order=True)
+ maintains_order=True
+ )
return final
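# --- illustrative sketch, not part of the patch above ---------------------------------
# Toy sketch of the LocalAffine idea used in register(): every vertex gets its own 3x3
# A and 3x1 b, the stiffness term keeps neighbouring transforms similar, and an Adam
# loop pulls the deformed points towards a target.  Chamfer distance is replaced by a
# plain MSE so the sketch stays small; the point clouds are random and it assumes the
# ECON environment (pytorch3d installed) since the module imports it at load time.
import torch
from lib.common.local_affine import LocalAffine

num_pts = 64
src = torch.rand(1, num_pts, 3)
tgt = src + 0.05 * torch.randn_like(src)                 # fake registration target
edges = torch.stack([torch.arange(num_pts - 1), torch.arange(1, num_pts)], dim=1)

model = LocalAffine(num_pts, batch_size=1, edges=edges)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, amsgrad=True)

for _ in range(50):
    optim.zero_grad()
    deformed, stiffness, rigid = model(x=src)            # [1, N, 3], per-edge, per-vertex
    loss = ((deformed - tgt) ** 2).mean() \
         + 1e-2 * stiffness.mean() + 1e-2 * rigid.mean()
    loss.backward()
    optim.step()
# ---------------------------------------------------------------------------------------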
diff --git a/lib/common/render.py b/lib/common/render.py
index c4407bb3e0035d21300c87dd198f573d767d21e1..392566099a1034eb9138d905923248465a625aad 100644
--- a/lib/common/render.py
+++ b/lib/common/render.py
@@ -31,7 +31,8 @@ from pytorch3d.renderer import (
)
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes
-from lib.dataset.mesh_util import get_visibility, blend_rgb_norm
+from lib.dataset.mesh_util import get_visibility
+from lib.common.imutils import blend_rgb_norm
import lib.common.render_utils as util
import torch
@@ -74,20 +75,23 @@ def query_color(verts, faces, image, device):
(xy, z) = verts.split([2, 1], dim=1)
visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
- uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
+ uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
colors = (
- (torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) +
- 1.0) * 0.5 * 255.0)
+ (
+ torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :,
+ 0].permute(1, 0) + 1.0
+ ) * 0.5 * 255.0
+ )
colors[visibility == 0.0] = (
(Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0) *
- 0.5 * 255.0)[visibility == 0.0]
+ 0.5 * 255.0
+ )[visibility == 0.0]
return colors.detach().cpu()
class cleanShader(torch.nn.Module):
-
def __init__(self, blend_params=None):
super().__init__()
self.blend_params = blend_params if blend_params is not None else BlendParams()
@@ -103,7 +107,6 @@ class cleanShader(torch.nn.Module):
class Render:
-
def __init__(self, size=512, device=torch.device("cuda:0")):
self.device = device
self.size = size
@@ -119,21 +122,30 @@ class Render:
self.cam_pos = {
"frontback":
- torch.tensor([
- (0, self.mesh_y_center, self.dis),
- (0, self.mesh_y_center, -self.dis),
- ]),
+ torch.tensor(
+ [
+ (0, self.mesh_y_center, self.dis),
+ (0, self.mesh_y_center, -self.dis),
+ ]
+ ),
"four":
- torch.tensor([
- (0, self.mesh_y_center, self.dis),
- (self.dis, self.mesh_y_center, 0),
- (0, self.mesh_y_center, -self.dis),
- (-self.dis, self.mesh_y_center, 0),
- ]),
+ torch.tensor(
+ [
+ (0, self.mesh_y_center, self.dis),
+ (self.dis, self.mesh_y_center, 0),
+ (0, self.mesh_y_center, -self.dis),
+ (-self.dis, self.mesh_y_center, 0),
+ ]
+ ),
"around":
- torch.tensor([(100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
- 100.0 * math.sin(np.pi / 180 * angle))
- for angle in range(0, 360, self.step)])
+ torch.tensor(
+ [
+ (
+ 100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
+ 100.0 * math.sin(np.pi / 180 * angle)
+ ) for angle in range(0, 360, self.step)
+ ]
+ )
}
self.type = "color"
@@ -153,8 +165,8 @@ class Render:
R, T = look_at_view_transform(
eye=self.cam_pos[type][idx],
- at=((0, self.mesh_y_center, 0),),
- up=((0, 1, 0),),
+ at=((0, self.mesh_y_center, 0), ),
+ up=((0, 1, 0), ),
)
cameras = FoVOrthographicCameras(
@@ -167,7 +179,7 @@ class Render:
min_y=-100.0,
max_x=100.0,
min_x=-100.0,
- scale_xyz=(self.scale * np.ones(3),) * len(R),
+ scale_xyz=(self.scale * np.ones(3), ) * len(R),
)
return cameras
@@ -202,15 +214,17 @@ class Render:
cull_backfaces=True,
)
- self.silhouetteRas = MeshRasterizer(cameras=camera,
- raster_settings=self.raster_settings_silhouette)
- self.renderer = MeshRenderer(rasterizer=self.silhouetteRas,
- shader=SoftSilhouetteShader())
+ self.silhouetteRas = MeshRasterizer(
+ cameras=camera, raster_settings=self.raster_settings_silhouette
+ )
+ self.renderer = MeshRenderer(
+ rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
+ )
elif type == "pointcloud":
- self.raster_settings_pcd = PointsRasterizationSettings(image_size=self.size,
- radius=0.006,
- points_per_pixel=10)
+ self.raster_settings_pcd = PointsRasterizationSettings(
+ image_size=self.size, radius=0.006, points_per_pixel=10
+ )
self.pcdRas = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings_pcd)
self.renderer = PointsRenderer(
@@ -230,8 +244,12 @@ class Render:
V_lst = []
F_lst = []
for V, F in zip(verts, faces):
- V_lst.append(torch.tensor(V).float().to(self.device))
- F_lst.append(torch.tensor(F).long().to(self.device))
+ if not torch.is_tensor(V):
+ V_lst.append(torch.tensor(V).float().to(self.device))
+ F_lst.append(torch.tensor(F).long().to(self.device))
+ else:
+ V_lst.append(V.float().to(self.device))
+ F_lst.append(F.long().to(self.device))
self.meshes = Meshes(V_lst, F_lst).to(self.device)
else:
# array or tensor
@@ -248,7 +266,8 @@ class Render:
# texture only support single mesh
if len(self.meshes) == 1:
self.meshes.textures = TexturesVertex(
- verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5)
+ verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5
+ )
def get_image(self, cam_type="frontback", type="rgb", bg="gray"):
@@ -260,7 +279,8 @@ class Render:
current_mesh = self.meshes[mesh_id]
current_mesh.textures = TexturesVertex(
- verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5)
+ verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
+ )
if type == "depth":
fragments = self.meshRas(current_mesh.extend(len(self.cam_pos[cam_type])))
@@ -276,7 +296,7 @@ class Render:
print(f"unknown {type}")
if cam_type == 'frontback':
- images[1] = torch.flip(images[1], dims=(-1,))
+ images[1] = torch.flip(images[1], dims=(-1, ))
# images [N_render, 3, res, res]
img_lst.append(images.unsqueeze(1))
@@ -287,9 +307,8 @@ class Render:
return list(meshes)
def get_rendered_video_multi(self, data, save_path):
-
- width = data["img_raw"].shape[1]
- height = data["img_raw"].shape[0]
+
+ height, width = data["img_raw"].shape[2:]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video = cv2.VideoWriter(
@@ -302,14 +321,15 @@ class Render:
pbar = tqdm(range(len(self.meshes)))
pbar.set_description(colored(f"Normal Rendering {os.path.basename(save_path)}...", "blue"))
- mesh_renders = [] #[(N_cam, 3, res, res)*N_mesh]
+ mesh_renders = [] #[(N_cam, 3, res, res)*N_mesh]
# render all the normals
for mesh_id in pbar:
current_mesh = self.meshes[mesh_id]
current_mesh.textures = TexturesVertex(
- verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5)
+ verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
+ )
norm_lst = []
@@ -320,21 +340,33 @@ class Render:
self.init_renderer(batch_cams, "mesh", "gray")
norm_lst.append(
- self.renderer(current_mesh.extend(len(batch_cams_idx)))[..., :3].permute(
- 0, 3, 1, 2))
+ self.renderer(current_mesh.extend(len(batch_cams_idx))
+ )[..., :3].permute(0, 3, 1, 2)
+ )
mesh_renders.append(torch.cat(norm_lst).detach().cpu())
# generate video frame by frame
pbar = tqdm(range(len(self.cam_pos["around"])))
pbar.set_description(colored(f"Video Exporting {os.path.basename(save_path)}...", "blue"))
+
for cam_id in pbar:
- img_raw = data["img_raw"].astype(np.uint8)
+ img_raw = data["img_raw"]
num_obj = len(mesh_renders) // 2
- img_smpl = blend_rgb_norm((torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data)
- img_cloth = blend_rgb_norm((torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data)
+ img_smpl = blend_rgb_norm(
+ (torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data
+ )
+ img_cloth = blend_rgb_norm(
+ (torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data
+ )
- top_img = cv2.resize(np.concatenate([img_raw, img_smpl], axis=1), (width, height // 2))
- final_img = np.concatenate([top_img, img_cloth], axis=0)
+ top_img = cv2.resize(
+ torch.cat([img_raw, img_smpl],
+ dim=-1).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8),
+ (width, height // 2)
+ )
+ final_img = np.concatenate(
+ [top_img, img_cloth.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)], axis=0
+ )
video.write(final_img[:, :, ::-1])
video.release()
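# --- illustrative sketch, not part of the patch above ---------------------------------
# Sketch of the grid_sample lookup inside query_color: vertex xy coordinates in [-1, 1]
# index into the image, giving one RGB value per vertex.  Visibility handling and the
# normal-based fallback from the patch are omitted; image and vertices are random here.
import torch
import torch.nn.functional as F

image = torch.rand(1, 3, 512, 512) * 2.0 - 1.0      # image normalized to (-1, 1)
verts_xy = torch.rand(1000, 2) * 2.0 - 1.0          # projected vertices in NDC

uv = verts_xy.unsqueeze(0).unsqueeze(2)             # [1, N, 1, 2] grid for grid_sample
uv = uv * torch.tensor([1.0, -1.0])                 # flip y, matching query_color

colors = F.grid_sample(image, uv, align_corners=True)[0, :, :, 0].permute(1, 0)
colors = (colors + 1.0) * 0.5 * 255.0               # back to 0..255 per vertex
print(colors.shape)                                  # torch.Size([1000, 3])
# ---------------------------------------------------------------------------------------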
diff --git a/lib/common/render_utils.py b/lib/common/render_utils.py
index 013f625ec7c62a54e1fd5f5bf8579afdb6561023..cb2ca46f420c063c7a1c6a82276d41c42852e451 100644
--- a/lib/common/render_utils.py
+++ b/lib/common/render_utils.py
@@ -25,9 +25,7 @@ from pytorch3d.renderer.mesh import rasterize_meshes
Tensor = NewType("Tensor", torch.Tensor)
-def solid_angles(points: Tensor,
- triangles: Tensor,
- thresh: float = 1e-8) -> Tensor:
+def solid_angles(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor:
"""Compute solid angle between the input points and triangles
Follows the method described in:
The Solid Angle of a Plane Triangle
@@ -55,9 +53,7 @@ def solid_angles(points: Tensor,
norms = torch.norm(centered_tris, dim=-1)
# Should be BxQxFx3
- cross_prod = torch.cross(centered_tris[:, :, :, 1],
- centered_tris[:, :, :, 2],
- dim=-1)
+ cross_prod = torch.cross(centered_tris[:, :, :, 1], centered_tris[:, :, :, 2], dim=-1)
# Should be BxQxF
numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
del cross_prod
@@ -67,8 +63,10 @@ def solid_angles(points: Tensor,
dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
del centered_tris
- denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
- dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
+ denominator = (
+ norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] + dot02 * norms[:, :, :, 1] +
+ dot12 * norms[:, :, :, 0]
+ )
del dot01, dot12, dot02, norms
# Should be BxQ
@@ -80,9 +78,7 @@ def solid_angles(points: Tensor,
return 2 * solid_angle
-def winding_numbers(points: Tensor,
- triangles: Tensor,
- thresh: float = 1e-8) -> Tensor:
+def winding_numbers(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor:
"""Uses winding_numbers to compute inside/outside
Robust inside-outside segmentation using generalized winding numbers
Alec Jacobson,
@@ -109,8 +105,7 @@ def winding_numbers(points: Tensor,
"""
# The generalized winding number is the sum of solid angles of the point
# with respect to all triangles.
- return (1 / (4 * math.pi) *
- solid_angles(points, triangles, thresh=thresh).sum(dim=-1))
+ return (1 / (4 * math.pi) * solid_angles(points, triangles, thresh=thresh).sum(dim=-1))
def batch_contains(verts, faces, points):
@@ -124,8 +119,7 @@ def batch_contains(verts, faces, points):
contains = torch.zeros(B, N)
for i in range(B):
- contains[i] = torch.as_tensor(
- trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
+ contains[i] = torch.as_tensor(trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
return 2.0 * (contains - 0.5)
@@ -155,8 +149,7 @@ def face_vertices(vertices, faces):
bs, nv = vertices.shape[:2]
bs, nf = faces.shape[:2]
device = vertices.device
- faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
- nv)[:, None, None]
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
return vertices[faces.long()]
@@ -168,7 +161,6 @@ class Pytorch3dRasterizer(nn.Module):
x,y,z are in image space, normalized
    can only render square images now
"""
-
def __init__(self, image_size=224, blur_radius=0.0, faces_per_pixel=1):
"""
use fixed raster_settings for rendering faces
@@ -189,8 +181,7 @@ class Pytorch3dRasterizer(nn.Module):
def forward(self, vertices, faces, attributes=None):
fixed_vertices = vertices.clone()
fixed_vertices[..., :2] = -fixed_vertices[..., :2]
- meshes_screen = Meshes(verts=fixed_vertices.float(),
- faces=faces.long())
+ meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
raster_settings = self.raster_settings
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
meshes_screen,
@@ -204,8 +195,9 @@ class Pytorch3dRasterizer(nn.Module):
vismask = (pix_to_face > -1).float()
D = attributes.shape[-1]
attributes = attributes.clone()
- attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
- 3, attributes.shape[-1])
+ attributes = attributes.view(
+ attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1]
+ )
N, H, W, K, _ = bary_coords.shape
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
@@ -213,8 +205,7 @@ class Pytorch3dRasterizer(nn.Module):
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
- pixel_vals[mask] = 0 # Replace masked values in output.
+ pixel_vals[mask] = 0 # Replace masked values in output.
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
- pixel_vals = torch.cat(
- [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
+ pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
return pixel_vals
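# --- illustrative sketch, not part of the patch above ---------------------------------
# Quick sketch of the winding-number inside/outside test behind winding_numbers /
# batch_contains: points with generalized winding number > 0.5 are inside the surface.
# The unit cube and query points are placeholders; run inside the repo environment.
import torch
import trimesh
from lib.common.render_utils import winding_numbers

box = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
verts = torch.tensor(box.vertices, dtype=torch.float32).unsqueeze(0)   # [1, V, 3]
faces = torch.tensor(box.faces, dtype=torch.long)

triangles = verts[:, faces]                       # [1, F, 3, 3]
points = (torch.rand(1, 2000, 3) - 0.5) * 1.5     # some inside, some outside

wn = winding_numbers(points, triangles)           # [1, 2000]
inside = wn > 0.5
print(inside.float().mean().item())               # fraction of points inside the box
# ---------------------------------------------------------------------------------------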
diff --git a/lib/common/seg3d_lossless.py b/lib/common/seg3d_lossless.py
index 6afdc822369608dc89a4215c7f68f4041523ed30..4f5cba2a1edb3a5df14d17beabb9d296203865c1 100644
--- a/lib/common/seg3d_lossless.py
+++ b/lib/common/seg3d_lossless.py
@@ -31,7 +31,6 @@ logging.getLogger("lightning").setLevel(logging.ERROR)
class Seg3dLossless(nn.Module):
-
def __init__(
self,
query_func,
@@ -53,19 +52,14 @@ class Seg3dLossless(nn.Module):
"""
super().__init__()
self.query_func = query_func
- self.register_buffer(
- "b_min",
- torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
- self.register_buffer(
- "b_max",
- torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
+ self.register_buffer("b_min", torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
+ self.register_buffer("b_max", torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
# ti.init(arch=ti.cuda)
# self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
if type(resolutions[0]) is int:
- resolutions = torch.tensor([(res, res, res)
- for res in resolutions])
+ resolutions = torch.tensor([(res, res, res) for res in resolutions])
else:
resolutions = torch.tensor(resolutions)
self.register_buffer("resolutions", resolutions)
@@ -87,45 +81,36 @@ class Seg3dLossless(nn.Module):
), f"resolution {resolution} need to be odd becuase of align_corner."
# init first resolution
- init_coords = create_grid3D(0,
- resolutions[-1] - 1,
- steps=resolutions[0]) # [N, 3]
- init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
- 1) # [bz, N, 3]
+ init_coords = create_grid3D(0, resolutions[-1] - 1, steps=resolutions[0]) # [N, 3]
+ init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1, 1) # [bz, N, 3]
self.register_buffer("init_coords", init_coords)
# some useful tensors
calculated = torch.zeros(
- (self.resolutions[-1][2], self.resolutions[-1][1],
- self.resolutions[-1][0]),
+ (self.resolutions[-1][2], self.resolutions[-1][1], self.resolutions[-1][0]),
dtype=torch.bool,
)
self.register_buffer("calculated", calculated)
- gird8_offsets = (torch.stack(
- torch.meshgrid(
- [
- torch.tensor([-1, 0, 1]),
- torch.tensor([-1, 0, 1]),
- torch.tensor([-1, 0, 1]),
- ],
- indexing="ij",
- )).int().view(3, -1).t()) # [27, 3]
+ gird8_offsets = (
+ torch.stack(
+ torch.meshgrid(
+ [
+ torch.tensor([-1, 0, 1]),
+ torch.tensor([-1, 0, 1]),
+ torch.tensor([-1, 0, 1]),
+ ],
+ indexing="ij",
+ )
+ ).int().view(3, -1).t()
+ ) # [27, 3]
self.register_buffer("gird8_offsets", gird8_offsets)
# smooth convs
- self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
- out_channels=1,
- kernel_size=3)
- self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
- out_channels=1,
- kernel_size=5)
- self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
- out_channels=1,
- kernel_size=7)
- self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
- out_channels=1,
- kernel_size=9)
+ self.smooth_conv3x3 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
+ self.smooth_conv5x5 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=5)
+ self.smooth_conv7x7 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=7)
+ self.smooth_conv9x9 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=9)
@torch.no_grad()
def batch_eval(self, coords, **kwargs):
@@ -144,7 +129,7 @@ class Seg3dLossless(nn.Module):
# query function
occupancys = self.query_func(**kwargs, points=coords2D)
if type(occupancys) is list:
- occupancys = torch.stack(occupancys) # [bz, C, N]
+ occupancys = torch.stack(occupancys) # [bz, C, N]
assert (
len(occupancys.size()) == 3
), "query_func should return a occupancy with shape of [bz, C, N]"
@@ -175,10 +160,9 @@ class Seg3dLossless(nn.Module):
# first step
if torch.equal(resolution, self.resolutions[0]):
- coords = self.init_coords.clone() # torch.long
+ coords = self.init_coords.clone() # torch.long
occupancys = self.batch_eval(coords, **kwargs)
- occupancys = occupancys.view(self.batchsize, self.channels, D,
- H, W)
+ occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)
if (occupancys > 0.5).sum() == 0:
# return F.interpolate(
# occupancys, size=(final_D, final_H, final_W),
@@ -239,23 +223,22 @@ class Seg3dLossless(nn.Module):
with torch.no_grad():
if torch.equal(resolution, self.resolutions[1]):
- is_boundary = (self.smooth_conv9x9(is_boundary.float())
- > 0)[0, 0]
+ is_boundary = (self.smooth_conv9x9(is_boundary.float()) > 0)[0, 0]
elif torch.equal(resolution, self.resolutions[2]):
- is_boundary = (self.smooth_conv7x7(is_boundary.float())
- > 0)[0, 0]
+ is_boundary = (self.smooth_conv7x7(is_boundary.float()) > 0)[0, 0]
else:
- is_boundary = (self.smooth_conv3x3(is_boundary.float())
- > 0)[0, 0]
+ is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]
coords_accum = coords_accum.long()
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
coords_accum[0, :, 0], ] = False
- point_coords = (is_boundary.permute(
- 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0))
- point_indices = (point_coords[:, :, 2] * H * W +
- point_coords[:, :, 1] * W +
- point_coords[:, :, 0])
+ point_coords = (
+ is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
+ )
+ point_indices = (
+ point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
+ point_coords[:, :, 0]
+ )
R, C, D, H, W = occupancys.shape
@@ -269,13 +252,15 @@ class Seg3dLossless(nn.Module):
# put mask point predictions to the right places on the upsampled grid.
R, C, D, H, W = occupancys.shape
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
- occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
- 2, point_indices, occupancys_topk).view(R, C, D, H, W))
+ occupancys = (
+ occupancys.reshape(R, C,
+ D * H * W).scatter_(2, point_indices,
+ occupancys_topk).view(R, C, D, H, W)
+ )
with torch.no_grad():
voxels = coords / stride
- coords_accum = torch.cat([voxels, coords_accum],
- dim=1).unique(dim=1)
+ coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
return occupancys[0, 0]
@@ -300,18 +285,16 @@ class Seg3dLossless(nn.Module):
# first step
if torch.equal(resolution, self.resolutions[0]):
- coords = self.init_coords.clone() # torch.long
+ coords = self.init_coords.clone() # torch.long
occupancys = self.batch_eval(coords, **kwargs)
- occupancys = occupancys.view(self.batchsize, self.channels, D,
- H, W)
+ occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)
if self.visualize:
self.plot(occupancys, coords, final_D, final_H, final_W)
with torch.no_grad():
coords_accum = coords / stride
- calculated[coords[0, :, 2], coords[0, :, 1],
- coords[0, :, 0]] = True
+ calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
# next steps
else:
@@ -338,35 +321,34 @@ class Seg3dLossless(nn.Module):
with torch.no_grad():
# TODO
- if self.use_shadow and torch.equal(resolution,
- self.resolutions[-1]):
+ if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
# larger z means smaller depth here
depth_res = resolution[2].item()
- depth_index = torch.linspace(0,
- depth_res - 1,
- steps=depth_res).type_as(
- occupancys.device)
- depth_index_max = (torch.max(
- (occupancys > self.balance_value) *
- (depth_index + 1),
- dim=-1,
- keepdim=True,
- )[0] - 1)
+ depth_index = torch.linspace(0, depth_res - 1,
+ steps=depth_res).type_as(occupancys.device)
+ depth_index_max = (
+ torch.max(
+ (occupancys > self.balance_value) * (depth_index + 1),
+ dim=-1,
+ keepdim=True,
+ )[0] - 1
+ )
shadow = depth_index < depth_index_max
is_boundary[shadow] = False
is_boundary = is_boundary[0, 0]
else:
- is_boundary = (self.smooth_conv3x3(is_boundary.float())
- > 0)[0, 0]
+ is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]
# is_boundary = is_boundary[0, 0]
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
coords_accum[0, :, 0], ] = False
- point_coords = (is_boundary.permute(
- 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0))
- point_indices = (point_coords[:, :, 2] * H * W +
- point_coords[:, :, 1] * W +
- point_coords[:, :, 0])
+ point_coords = (
+ is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
+ )
+ point_indices = (
+ point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
+ point_coords[:, :, 0]
+ )
R, C, D, H, W = occupancys.shape
# interpolated value
@@ -388,28 +370,28 @@ class Seg3dLossless(nn.Module):
# put mask point predictions to the right places on the upsampled grid.
R, C, D, H, W = occupancys.shape
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
- occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
- 2, point_indices, occupancys_topk).view(R, C, D, H, W))
+ occupancys = (
+ occupancys.reshape(R, C,
+ D * H * W).scatter_(2, point_indices,
+ occupancys_topk).view(R, C, D, H, W)
+ )
with torch.no_grad():
# conflicts
- conflicts = ((occupancys_interp - self.balance_value) *
- (occupancys_topk - self.balance_value) < 0)[0,
- 0]
+ conflicts = (
+ (occupancys_interp - self.balance_value) *
+ (occupancys_topk - self.balance_value) < 0
+ )[0, 0]
if self.visualize:
- self.plot(occupancys, coords, final_D, final_H,
- final_W)
+ self.plot(occupancys, coords, final_D, final_H, final_W)
voxels = coords / stride
- coords_accum = torch.cat([voxels, coords_accum],
- dim=1).unique(dim=1)
- calculated[coords[0, :, 2], coords[0, :, 1],
- coords[0, :, 0]] = True
+ coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
+ calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
while conflicts.sum() > 0:
- if self.use_shadow and torch.equal(resolution,
- self.resolutions[-1]):
+ if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
break
with torch.no_grad():
@@ -426,25 +408,27 @@ class Seg3dLossless(nn.Module):
)
conflicts_boundary = (
- (conflicts_coords.int() +
- self.gird8_offsets.unsqueeze(1) *
- stride.int()).reshape(-1, 3).long().unique(dim=0))
- conflicts_boundary[:,
- 0] = conflicts_boundary[:, 0].clamp(
- 0,
- calculated.size(2) - 1)
- conflicts_boundary[:,
- 1] = conflicts_boundary[:, 1].clamp(
- 0,
- calculated.size(1) - 1)
- conflicts_boundary[:,
- 2] = conflicts_boundary[:, 2].clamp(
- 0,
- calculated.size(0) - 1)
-
- coords = conflicts_boundary[calculated[
- conflicts_boundary[:, 2], conflicts_boundary[:, 1],
- conflicts_boundary[:, 0], ] == False]
+ (
+ conflicts_coords.int() +
+ self.gird8_offsets.unsqueeze(1) * stride.int()
+ ).reshape(-1, 3).long().unique(dim=0)
+ )
+ conflicts_boundary[:, 0] = conflicts_boundary[:, 0].clamp(
+ 0,
+ calculated.size(2) - 1
+ )
+ conflicts_boundary[:, 1] = conflicts_boundary[:, 1].clamp(
+ 0,
+ calculated.size(1) - 1
+ )
+ conflicts_boundary[:, 2] = conflicts_boundary[:, 2].clamp(
+ 0,
+ calculated.size(0) - 1
+ )
+
+ coords = conflicts_boundary[calculated[conflicts_boundary[:, 2],
+ conflicts_boundary[:, 1],
+ conflicts_boundary[:, 0], ] == False]
if self.debug:
self.plot(
@@ -458,9 +442,10 @@ class Seg3dLossless(nn.Module):
coords = coords.unsqueeze(0)
point_coords = coords / stride
- point_indices = (point_coords[:, :, 2] * H * W +
- point_coords[:, :, 1] * W +
- point_coords[:, :, 0])
+ point_indices = (
+ point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
+ point_coords[:, :, 0]
+ )
R, C, D, H, W = occupancys.shape
# interpolated value
@@ -481,44 +466,37 @@ class Seg3dLossless(nn.Module):
with torch.no_grad():
# conflicts
- conflicts = ((occupancys_interp - self.balance_value) *
- (occupancys_topk - self.balance_value) <
- 0)[0, 0]
+ conflicts = (
+ (occupancys_interp - self.balance_value) *
+ (occupancys_topk - self.balance_value) < 0
+ )[0, 0]
# put mask point predictions to the right places on the upsampled grid.
- point_indices = point_indices.unsqueeze(1).expand(
- -1, C, -1)
- occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
- 2, point_indices, occupancys_topk).view(R, C, D, H, W))
+ point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
+ occupancys = (
+ occupancys.reshape(R, C,
+ D * H * W).scatter_(2, point_indices,
+ occupancys_topk).view(R, C, D, H, W)
+ )
with torch.no_grad():
voxels = coords / stride
- coords_accum = torch.cat([voxels, coords_accum],
- dim=1).unique(dim=1)
- calculated[coords[0, :, 2], coords[0, :, 1],
- coords[0, :, 0]] = True
+ coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
+ calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
if self.visualize:
this_stage_coords = torch.cat(this_stage_coords, dim=1)
- self.plot(occupancys, this_stage_coords, final_D, final_H,
- final_W)
+ self.plot(occupancys, this_stage_coords, final_D, final_H, final_W)
return occupancys[0, 0]
- def plot(self,
- occupancys,
- coords,
- final_D,
- final_H,
- final_W,
- title="",
- **kwargs):
+ def plot(self, occupancys, coords, final_D, final_H, final_W, title="", **kwargs):
final = F.interpolate(
occupancys.float(),
size=(final_D, final_H, final_W),
mode="trilinear",
align_corners=True,
- ) # here true is correct!
+ ) # here true is correct!
x = coords[0, :, 0].to("cpu")
y = coords[0, :, 1].to("cpu")
z = coords[0, :, 2].to("cpu")
@@ -548,20 +526,19 @@ class Seg3dLossless(nn.Module):
sdf_all = sdf.permute(2, 1, 0)
# shadow
- grad_v = (sdf_all > 0.5) * torch.linspace(
- resolution, 1, steps=resolution).to(sdf.device)
- grad_c = torch.ones_like(sdf_all) * torch.linspace(
- 0, resolution - 1, steps=resolution).to(sdf.device)
+ grad_v = (sdf_all > 0.5) * torch.linspace(resolution, 1, steps=resolution).to(sdf.device)
+ grad_c = torch.ones_like(sdf_all) * torch.linspace(0, resolution - 1,
+ steps=resolution).to(sdf.device)
max_v, max_c = grad_v.max(dim=2)
shadow = grad_c > max_c.view(resolution, resolution, 1)
keep = (sdf_all > 0.5) & (~shadow)
- p1 = keep.nonzero(as_tuple=False).t() # [3, N]
- p2 = p1.clone() # z
+ p1 = keep.nonzero(as_tuple=False).t() # [3, N]
+ p2 = p1.clone() # z
p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
- p3 = p1.clone() # y
+ p3 = p1.clone() # y
p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
- p4 = p1.clone() # x
+ p4 = p1.clone() # x
p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
@@ -569,10 +546,10 @@ class Seg3dLossless(nn.Module):
v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
- X = p1[0, :].long() # [N,]
- Y = p1[1, :].long() # [N,]
- Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * (
- v2 - 0.5) / (v2 - v1) # [N,]
+ X = p1[0, :].long() # [N,]
+ Y = p1[1, :].long() # [N,]
+ Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * (v2 - 0.5
+ ) / (v2 - v1) # [N,]
Z = Z.clamp(0, resolution)
# normal
@@ -588,8 +565,7 @@ class Seg3dLossless(nn.Module):
@torch.no_grad()
def render_normal(self, resolution, X, Y, Z, norm):
- image = torch.ones((1, 3, resolution, resolution),
- dtype=torch.float32).to(norm.device)
+ image = torch.ones((1, 3, resolution, resolution), dtype=torch.float32).to(norm.device)
color = (norm + 1) / 2.0
color = color.clamp(0, 1)
image[0, :, Y, X] = color.t()
@@ -617,9 +593,9 @@ class Seg3dLossless(nn.Module):
def export_mesh(self, occupancys):
final = occupancys[1:, 1:, 1:].contiguous()
-
+
verts, faces = marching_cubes(final.unsqueeze(0), isolevel=0.5)
verts = verts[0].cpu().float()
- faces = faces[0].cpu().long()[:,[0,2,1]]
-
+ faces = faces[0].cpu().long()[:, [0, 2, 1]]
+
return verts, faces
diff --git a/lib/common/seg3d_utils.py b/lib/common/seg3d_utils.py
index 958fb338f02d814c1d73c37edbcad777c2cff9fe..bee264615a54777bb948414c82f502c678664329 100644
--- a/lib/common/seg3d_utils.py
+++ b/lib/common/seg3d_utils.py
@@ -20,11 +20,7 @@ import torch.nn.functional as F
import matplotlib.pyplot as plt
-def plot_mask2D(mask,
- title="",
- point_coords=None,
- figsize=10,
- point_marker_size=5):
+def plot_mask2D(mask, title="", point_coords=None, figsize=10, point_marker_size=5):
'''
Simple plotting tool to show intermediate mask predictions and points
where PointRend is applied.
@@ -46,26 +42,19 @@ def plot_mask2D(mask,
plt.xlabel(W, fontsize=30)
plt.xticks([], [])
plt.yticks([], [])
- plt.imshow(mask.detach(),
- interpolation="nearest",
- cmap=plt.get_cmap('gray'))
+ plt.imshow(mask.detach(), interpolation="nearest", cmap=plt.get_cmap('gray'))
if point_coords is not None:
- plt.scatter(x=point_coords[0],
- y=point_coords[1],
- color="red",
- s=point_marker_size,
- clip_on=True)
+ plt.scatter(
+ x=point_coords[0], y=point_coords[1], color="red", s=point_marker_size, clip_on=True
+ )
plt.xlim(-0.5, W - 0.5)
plt.ylim(H - 0.5, -0.5)
plt.show()
-def plot_mask3D(mask=None,
- title="",
- point_coords=None,
- figsize=1500,
- point_marker_size=8,
- interactive=True):
+def plot_mask3D(
+ mask=None, title="", point_coords=None, figsize=1500, point_marker_size=8, interactive=True
+):
'''
Simple plotting tool to show intermediate mask predictions and points
where PointRend is applied.
@@ -90,7 +79,8 @@ def plot_mask3D(mask=None,
# marching cube to find surface
verts, faces, normals, values = measure.marching_cubes_lewiner(
- mask, 0.5, gradient_direction='ascent')
+ mask, 0.5, gradient_direction='ascent'
+ )
# create a mesh
mesh = trimesh.Trimesh(verts, faces)
@@ -110,57 +100,49 @@ def plot_mask3D(mask=None,
pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
vis_list.append(pc)
- vp.show(*vis_list,
- bg="white",
- axes=1,
- interactive=interactive,
- azimuth=30,
- elevation=30)
+ vp.show(*vis_list, bg="white", axes=1, interactive=interactive, azimuth=30, elevation=30)
def create_grid3D(min, max, steps):
if type(min) is int:
- min = (min, min, min) # (x, y, z)
+ min = (min, min, min) # (x, y, z)
if type(max) is int:
- max = (max, max, max) # (x, y)
+ max = (max, max, max) # (x, y)
if type(steps) is int:
- steps = (steps, steps, steps) # (x, y, z)
+ steps = (steps, steps, steps) # (x, y, z)
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
- gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX],
- indexing='ij')
- coords = torch.stack([gridW, girdH,
- gridD]) # [2, steps[0], steps[1], steps[2]]
- coords = coords.view(3, -1).t() # [N, 3]
+ gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX], indexing='ij')
+ coords = torch.stack([gridW, girdH, gridD]) # [2, steps[0], steps[1], steps[2]]
+ coords = coords.view(3, -1).t() # [N, 3]
return coords
def create_grid2D(min, max, steps):
if type(min) is int:
- min = (min, min) # (x, y)
+ min = (min, min) # (x, y)
if type(max) is int:
- max = (max, max) # (x, y)
+ max = (max, max) # (x, y)
if type(steps) is int:
- steps = (steps, steps) # (x, y)
+ steps = (steps, steps) # (x, y)
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
girdH, gridW = torch.meshgrid([arrangeY, arrangeX], indexing='ij')
- coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
- coords = coords.view(2, -1).t() # [N, 2]
+ coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
+ coords = coords.view(2, -1).t() # [N, 2]
return coords
class SmoothConv2D(nn.Module):
-
def __init__(self, in_channels, out_channels, kernel_size=3):
super().__init__()
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
self.padding = (kernel_size - 1) // 2
weight = torch.ones(
- (in_channels, out_channels, kernel_size, kernel_size),
- dtype=torch.float32) / (kernel_size**2)
+ (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
+ ) / (kernel_size**2)
self.register_buffer('weight', weight)
def forward(self, input):
@@ -168,53 +150,49 @@ class SmoothConv2D(nn.Module):
class SmoothConv3D(nn.Module):
-
def __init__(self, in_channels, out_channels, kernel_size=3):
super().__init__()
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
self.padding = (kernel_size - 1) // 2
weight = torch.ones(
- (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
- dtype=torch.float32) / (kernel_size**3)
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32
+ ) / (kernel_size**3)
self.register_buffer('weight', weight)
def forward(self, input):
return F.conv3d(input, self.weight, padding=self.padding)
-def build_smooth_conv3D(in_channels=1,
- out_channels=1,
- kernel_size=3,
- padding=1):
- smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- padding=padding)
+def build_smooth_conv3D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
+ smooth_conv = torch.nn.Conv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding
+ )
smooth_conv.weight.data = torch.ones(
- (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
- dtype=torch.float32) / (kernel_size**3)
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32
+ ) / (kernel_size**3)
smooth_conv.bias.data = torch.zeros(out_channels)
return smooth_conv
-def build_smooth_conv2D(in_channels=1,
- out_channels=1,
- kernel_size=3,
- padding=1):
- smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- padding=padding)
+def build_smooth_conv2D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
+ smooth_conv = torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding
+ )
smooth_conv.weight.data = torch.ones(
- (in_channels, out_channels, kernel_size, kernel_size),
- dtype=torch.float32) / (kernel_size**2)
+ (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
+ ) / (kernel_size**2)
smooth_conv.bias.data = torch.zeros(out_channels)
return smooth_conv
-def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
- **kwargs):
+def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, **kwargs):
"""
Find `num_points` most uncertain points from `uncertainty_map` grid.
Args:
@@ -233,28 +211,21 @@ def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
# d_step = 1.0 / float(D)
num_points = min(D * H * W, num_points)
- point_scores, point_indices = torch.topk(uncertainty_map.view(
- R, D * H * W),
- k=num_points,
- dim=1)
- point_coords = torch.zeros(R,
- num_points,
- 3,
- dtype=torch.float,
- device=uncertainty_map.device)
+ point_scores, point_indices = torch.topk(
+ uncertainty_map.view(R, D * H * W), k=num_points, dim=1
+ )
+ point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device)
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
- point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
- point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
- point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
- print(f"resolution {D} x {H} x {W}", point_scores.min(),
- point_scores.max())
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
+ print(f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
return point_indices, point_coords
-def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
- clip_min):
+def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, clip_min):
"""
Find `num_points` most uncertain points from `uncertainty_map` grid.
Args:
@@ -276,28 +247,21 @@ def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
uncertainty_map = uncertainty_map.view(D * H * W)
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
num_points = min(num_points, indices.size(0))
- point_scores, point_indices = torch.topk(uncertainty_map[indices],
- k=num_points,
- dim=0)
+ point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0)
point_indices = indices[point_indices].unsqueeze(0)
- point_coords = torch.zeros(R,
- num_points,
- 3,
- dtype=torch.float,
- device=uncertainty_map.device)
+ point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device)
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
- point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
- point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
- point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
# print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
return point_indices, point_coords
-def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
- **kwargs):
+def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, **kwargs):
"""
Find `num_points` most uncertain points from `uncertainty_map` grid.
Args:
@@ -315,14 +279,8 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
# w_step = 1.0 / float(W)
num_points = min(H * W, num_points)
- point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
- k=num_points,
- dim=1)
- point_coords = torch.zeros(R,
- num_points,
- 2,
- dtype=torch.long,
- device=uncertainty_map.device)
+ point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)
+ point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device)
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
@@ -331,8 +289,7 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
return point_indices, point_coords
-def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
- clip_min):
+def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, clip_min):
"""
Find `num_points` most uncertain points from `uncertainty_map` grid.
Args:
@@ -353,16 +310,10 @@ def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
uncertainty_map = uncertainty_map.view(H * W)
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
num_points = min(num_points, indices.size(0))
- point_scores, point_indices = torch.topk(uncertainty_map[indices],
- k=num_points,
- dim=0)
+ point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0)
point_indices = indices[point_indices].unsqueeze(0)
- point_coords = torch.zeros(R,
- num_points,
- 2,
- dtype=torch.long,
- device=uncertainty_map.device)
+ point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device)
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
@@ -388,7 +339,6 @@ def calculate_uncertainty(logits, classes=None, balance_value=0.5):
if logits.shape[1] == 1:
gt_class_logits = logits
else:
- gt_class_logits = logits[
- torch.arange(logits.shape[0], device=logits.device),
- classes].unsqueeze(1)
+ gt_class_logits = logits[torch.arange(logits.shape[0], device=logits.device),
+ classes].unsqueeze(1)
return -torch.abs(gt_class_logits - balance_value)
diff --git a/lib/common/train_util.py b/lib/common/train_util.py
index 06ff842e30fab3f29b4046a518f48f46bcb8eb76..a39102a5849e25d056b9d96d2df9538790bec6ea 100644
--- a/lib/common/train_util.py
+++ b/lib/common/train_util.py
@@ -14,63 +14,62 @@
#
# Contact: ps-license@tuebingen.mpg.de
-import yaml
-import os.path as osp
import torch
-import numpy as np
from ..dataset.mesh_util import *
from ..net.geometry import orthogonal
-import cv2, PIL
-from tqdm import tqdm
-import os
from termcolor import colored
import pytorch_lightning as pl
+class Format:
+ end = '\033[0m'
+ start = '\033[4m'
+
+
def init_loss():
losses = {
- # Cloth: Normal_recon - Normal_pred
+ # Cloth: chamfer distance
"cloth": {
"weight": 1e3,
"value": 0.0
},
- # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
- "stiffness": {
+ # Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2)
+ "stiff": {
"weight": 1e5,
"value": 0.0
},
- # Cloth: det(R) = 1
+ # Cloth: det(R) = 1
"rigid": {
"weight": 1e5,
"value": 0.0
},
- # Cloth: edge length
+ # Cloth: edge length
"edge": {
"weight": 0,
"value": 0.0
},
- # Cloth: normal consistency
+ # Cloth: normal consistency
"nc": {
"weight": 0,
"value": 0.0
},
- # Cloth: laplacian smoonth
- "laplacian": {
+    # Cloth: laplacian smooth
+ "lapla": {
"weight": 1e2,
"value": 0.0
},
- # Body: Normal_pred - Normal_smpl
+ # Body: Normal_pred - Normal_smpl
"normal": {
"weight": 1e0,
"value": 0.0
},
- # Body: Silhouette_pred - Silhouette_smpl
+ # Body: Silhouette_pred - Silhouette_smpl
"silhouette": {
"weight": 1e0,
"value": 0.0
},
- # Joint: reprojected joints difference
+ # Joint: reprojected joints difference
"joint": {
"weight": 5e0,
"value": 0.0
@@ -81,7 +80,6 @@ def init_loss():
class SubTrainer(pl.Trainer):
-
def save_checkpoint(self, filepath, weights_only=False):
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
@@ -101,214 +99,6 @@ class SubTrainer(pl.Trainer):
pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
-def rename(old_dict, old_name, new_name):
- new_dict = {}
- for key, value in zip(old_dict.keys(), old_dict.values()):
- new_key = key if key != old_name else new_name
- new_dict[new_key] = old_dict[key]
- return new_dict
-
-
-def load_normal_networks(model, normal_path):
-
- pretrained_dict = torch.load(
- normal_path,
- map_location=model.device)["state_dict"]
- model_dict = model.state_dict()
-
- # 1. filter out unnecessary keys
- pretrained_dict = {
- k: v
- for k, v in pretrained_dict.items()
- if k in model_dict and v.shape == model_dict[k].shape
- }
-
- # # 2. overwrite entries in the existing state dict
- model_dict.update(pretrained_dict)
- # 3. load the new state dict
- model.load_state_dict(model_dict)
-
- del pretrained_dict
- del model_dict
-
- print(colored(f"Resume Normal weights from {normal_path}", "green"))
-
-
-def load_networks(model, mlp_path, normal_path=None):
-
- model_dict = model.state_dict()
- main_dict = {}
- normal_dict = {}
-
- # MLP part loading
- if os.path.exists(mlp_path) and mlp_path.endswith("ckpt"):
- main_dict = torch.load(
- mlp_path,
- map_location=model.device)["state_dict"]
-
- main_dict = {
- k: v
- for k, v in main_dict.items()
- if k in model_dict and v.shape == model_dict[k].shape and (
- "reconEngine" not in k) and ("normal_filter" not in k) and (
- "voxelization" not in k)
- }
- print(colored(f"Resume MLP weights from {mlp_path}", "green"))
-
- # normal network part loading
- if normal_path is not None and os.path.exists(normal_path) and normal_path.endswith("ckpt"):
- normal_dict = torch.load(
- normal_path,
- map_location=model.device)["state_dict"]
-
- for key in normal_dict.keys():
- normal_dict = rename(normal_dict, key,
- key.replace("netG", "netG.normal_filter"))
-
- normal_dict = {
- k: v
- for k, v in normal_dict.items()
- if k in model_dict and v.shape == model_dict[k].shape
- }
- print(colored(f"Resume normal model from {normal_path}", "green"))
-
- model_dict.update(main_dict)
- model_dict.update(normal_dict)
- model.load_state_dict(model_dict)
-
- # clean unused GPU memory
- del main_dict
- del normal_dict
- del model_dict
- torch.cuda.empty_cache()
-
-
-def reshape_sample_tensor(sample_tensor, num_views):
- if num_views == 1:
- return sample_tensor
- # Need to repeat sample_tensor along the batch dim num_views times
- sample_tensor = sample_tensor.unsqueeze(dim=1)
- sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
- sample_tensor = sample_tensor.view(
- sample_tensor.shape[0] * sample_tensor.shape[1],
- sample_tensor.shape[2],
- sample_tensor.shape[3],
- )
- return sample_tensor
-
-
-def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
- """Sets the learning rate to the initial LR decayed by schedule"""
- if epoch in schedule:
- lr *= gamma
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
- return lr
-
-
-def compute_acc(pred, gt, thresh=0.5):
- """
- return:
- IOU, precision, and recall
- """
- with torch.no_grad():
- vol_pred = pred > thresh
- vol_gt = gt > thresh
-
- union = vol_pred | vol_gt
- inter = vol_pred & vol_gt
-
- true_pos = inter.sum().float()
-
- union = union.sum().float()
- if union == 0:
- union = 1
- vol_pred = vol_pred.sum().float()
- if vol_pred == 0:
- vol_pred = 1
- vol_gt = vol_gt.sum().float()
- if vol_gt == 0:
- vol_gt = 1
- return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
-
-def calc_error(opt, net, cuda, dataset, num_tests):
- if num_tests > len(dataset):
- num_tests = len(dataset)
- with torch.no_grad():
- erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
- for idx in tqdm(range(num_tests)):
- data = dataset[idx * len(dataset) // num_tests]
- # retrieve the data
- image_tensor = data["img"].to(device=cuda)
- calib_tensor = data["calib"].to(device=cuda)
- sample_tensor = data["samples"].to(device=cuda).unsqueeze(0)
- if opt.num_views > 1:
- sample_tensor = reshape_sample_tensor(sample_tensor,
- opt.num_views)
- label_tensor = data["labels"].to(device=cuda).unsqueeze(0)
-
- res, error = net.forward(image_tensor,
- sample_tensor,
- calib_tensor,
- labels=label_tensor)
-
- IOU, prec, recall = compute_acc(res, label_tensor)
-
- # print(
- # '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
- # .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
- erorr_arr.append(error.item())
- IOU_arr.append(IOU.item())
- prec_arr.append(prec.item())
- recall_arr.append(recall.item())
-
- return (
- np.average(erorr_arr),
- np.average(IOU_arr),
- np.average(prec_arr),
- np.average(recall_arr),
- )
-
-
-def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
- if num_tests > len(dataset):
- num_tests = len(dataset)
- with torch.no_grad():
- error_color_arr = []
-
- for idx in tqdm(range(num_tests)):
- data = dataset[idx * len(dataset) // num_tests]
- # retrieve the data
- image_tensor = data["img"].to(device=cuda)
- calib_tensor = data["calib"].to(device=cuda)
- color_sample_tensor = data["color_samples"].to(
- device=cuda).unsqueeze(0)
-
- if opt.num_views > 1:
- color_sample_tensor = reshape_sample_tensor(
- color_sample_tensor, opt.num_views)
-
- rgb_tensor = data["rgbs"].to(device=cuda).unsqueeze(0)
-
- netG.filter(image_tensor)
- _, errorC = netC.forward(
- image_tensor,
- netG.get_im_feat(),
- color_sample_tensor,
- calib_tensor,
- labels=rgb_tensor,
- )
-
- # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
- # .format(idx, num_tests, errorG.item(), errorC.item()))
- error_color_arr.append(errorC.item())
-
- return np.average(error_color_arr)
-
-
-# pytorch lightning training related fucntions
-
-
def query_func(opt, netG, features, points, proj_matrix=None):
"""
- points: size of (bz, N, 3)
@@ -317,7 +107,7 @@ def query_func(opt, netG, features, points, proj_matrix=None):
"""
assert len(points) == 1
samples = points.repeat(opt.num_views, 1, 1)
- samples = samples.permute(0, 2, 1) # [bz, 3, N]
+ samples = samples.permute(0, 2, 1) # [bz, 3, N]
# view specific query
if proj_matrix is not None:
@@ -337,85 +127,25 @@ def query_func(opt, netG, features, points, proj_matrix=None):
return preds
+
def query_func_IF(batch, netG, points):
"""
- points: size of (bz, N, 3)
return: size of (bz, 1, N)
"""
-
+
batch["samples_geo"] = points
batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points)
-
+
preds = netG(batch)
return preds.unsqueeze(1)
-def isin(ar1, ar2):
- return (ar1[..., None] == ar2).any(-1)
-
-
-def in1d(ar1, ar2):
- mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
- mask[ar2.unique()] = True
- return mask[ar1]
-
def batch_mean(res, key):
- return torch.stack([
- x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key])
- for x in res
- ]).mean()
-
-
-def tf_log_convert(log_dict):
- new_log_dict = log_dict.copy()
- for k, v in log_dict.items():
- new_log_dict[k.replace("_", "/")] = v
- del new_log_dict[k]
-
- return new_log_dict
-
-
-def bar_log_convert(log_dict, name=None, rot=None):
- from decimal import Decimal
-
- new_log_dict = {}
-
- if name is not None:
- new_log_dict["name"] = name[0]
- if rot is not None:
- new_log_dict["rot"] = rot[0]
-
- for k, v in log_dict.items():
- color = "yellow"
- if "loss" in k:
- color = "red"
- k = k.replace("loss", "L")
- elif "acc" in k:
- color = "green"
- k = k.replace("acc", "A")
- elif "iou" in k:
- color = "green"
- k = k.replace("iou", "I")
- elif "prec" in k:
- color = "green"
- k = k.replace("prec", "P")
- elif "recall" in k:
- color = "green"
- k = k.replace("recall", "R")
-
- if "lr" not in k:
- new_log_dict[colored(k.split("_")[1],
- color)] = colored(f"{v:.3f}", color)
- else:
- new_log_dict[colored(k.split("_")[1],
- color)] = colored(f"{Decimal(str(v)):.1E}",
- color)
-
- if "loss" in new_log_dict.keys():
- del new_log_dict["loss"]
-
- return new_log_dict
+ return torch.stack(
+ [x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res]
+ ).mean()
def accumulate(outputs, rot_num, split):
@@ -430,160 +160,10 @@ def accumulate(outputs, rot_num, split):
keyword = f"{dataset}/{metric}"
if keyword not in hparam_log_dict.keys():
hparam_log_dict[keyword] = 0
- for idx in range(split[dataset][0] * rot_num,
- split[dataset][1] * rot_num):
+ for idx in range(split[dataset][0] * rot_num, split[dataset][1] * rot_num):
hparam_log_dict[keyword] += outputs[idx][metric].item()
- hparam_log_dict[keyword] /= (split[dataset][1] -
- split[dataset][0]) * rot_num
+ hparam_log_dict[keyword] /= (split[dataset][1] - split[dataset][0]) * rot_num
print(colored(hparam_log_dict, "green"))
return hparam_log_dict
-
-
-def calc_error_N(outputs, targets):
- """calculate the error of normal (IGR)
-
- Args:
- outputs (torch.tensor): [B, 3, N]
- target (torch.tensor): [B, N, 3]
-
- # manifold loss and grad_loss in IGR paper
- grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
- normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
-
- Returns:
- torch.tensor: error of valid normals on the surface
- """
- # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
- outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
- targets = targets.reshape(-1, 3)[:, 2:3]
- with_normals = targets.sum(dim=1).abs() > 0.0
-
- # eikonal loss
- grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
- # normals loss
- normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()
-
- return grad_loss * 0.0 + normal_loss
-
-
-def calc_knn_acc(preds, carn_verts, labels, pick_num):
- """calculate knn accuracy
-
- Args:
- preds (torch.tensor): [B, 3, N]
- carn_verts (torch.tensor): [SMPLX_V_num, 3]
- labels (torch.tensor): [B, N_knn, N]
- """
- N_knn_full = labels.shape[1]
- preds = preds.permute(0, 2, 1).reshape(-1, 3)
- labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn]
- labels = labels[:, :pick_num]
-
- dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num]
- knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn]
- cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
- bool_col = torch.zeros_like(cat_mat)[:, 0]
- for i in range(pick_num * 2 - 1):
- bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
- acc = (bool_col > 0).sum() / len(bool_col)
-
- return acc
-
-
-def calc_acc_seg(output, target, num_multiseg):
- from pytorch_lightning.metrics import Accuracy
-
- return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
- target.flatten().cpu())
-
-
-def add_watermark(imgs, titles):
-
- # Write some Text
-
- font = cv2.FONT_HERSHEY_SIMPLEX
- bottomLeftCornerOfText = (350, 50)
- bottomRightCornerOfText = (800, 50)
- fontScale = 1
- fontColor = (1.0, 1.0, 1.0)
- lineType = 2
-
- for i in range(len(imgs)):
-
- title = titles[i + 1]
- cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
- fontColor, lineType)
-
- if i == 0:
- cv2.putText(
- imgs[i],
- str(titles[i][0]),
- bottomRightCornerOfText,
- font,
- fontScale,
- fontColor,
- lineType,
- )
-
- result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)
-
- return result
-
-
-def make_test_gif(img_dir):
-
- if img_dir is not None and len(os.listdir(img_dir)) > 0:
- for dataset in os.listdir(img_dir):
- for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
- img_lst = []
- im1 = None
- for file in sorted(
- os.listdir(osp.join(img_dir, dataset, subject))):
- if file[-3:] not in ["obj", "gif"]:
- img_path = os.path.join(img_dir, dataset, subject,
- file)
- if im1 == None:
- im1 = PIL.Image.open(img_path)
- else:
- img_lst.append(PIL.Image.open(img_path))
-
- print(os.path.join(img_dir, dataset, subject, "out.gif"))
- im1.save(
- os.path.join(img_dir, dataset, subject, "out.gif"),
- save_all=True,
- append_images=img_lst,
- duration=500,
- loop=0,
- )
-
-
-def export_cfg(logger, dir, cfg):
-
- cfg_export_file = osp.join(dir, f"cfg_{logger.version}.yaml")
-
- if not osp.exists(cfg_export_file):
- os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
- with open(cfg_export_file, "w+") as file:
- _ = yaml.dump(cfg, file)
-
-
-from yacs.config import CfgNode
-
-_VALID_TYPES = {tuple, list, str, int, float, bool}
-
-
-def convert_to_dict(cfg_node, key_list=[]):
- """ Convert a config node to dictionary """
- if not isinstance(cfg_node, CfgNode):
- if type(cfg_node) not in _VALID_TYPES:
- print(
- "Key {} with value {} is not a valid type; valid types: {}".
- format(".".join(key_list), type(cfg_node), _VALID_TYPES), )
- return cfg_node
- else:
- cfg_dict = dict(cfg_node)
- for k, v in cfg_dict.items():
- cfg_dict[k] = convert_to_dict(v, key_list + [k])
- return cfg_dict
diff --git a/lib/common/voxelize.py b/lib/common/voxelize.py
index 112eb4ad82dc763e132ad7da49bb4eb409b15629..f792189ccc185e9a7b596eae5a9230fe21482aef 100644
--- a/lib/common/voxelize.py
+++ b/lib/common/voxelize.py
@@ -13,6 +13,7 @@ from lib.common.libmesh.inside_mesh import check_mesh_contains
# From Occupancy Networks, Mescheder et. al. CVPR'19
+
def make_3d_grid(bb_min, bb_max, shape):
''' Makes a 3D grid.
@@ -37,7 +38,7 @@ def make_3d_grid(bb_min, bb_max, shape):
class VoxelGrid:
def __init__(self, data, loc=(0., 0., 0.), scale=1):
- assert(data.shape[0] == data.shape[1] == data.shape[2])
+ assert (data.shape[0] == data.shape[1] == data.shape[2])
data = np.asarray(data, dtype=np.bool)
loc = np.asarray(loc)
self.data = data
@@ -53,7 +54,7 @@ class VoxelGrid:
# Default scale, scales the mesh to [-0.45, 0.45]^3
if scale is None:
- scale = (bounds[1] - bounds[0]).max()/0.9
+ scale = (bounds[1] - bounds[0]).max() / 0.9
loc = np.asarray(loc)
scale = float(scale)
@@ -61,7 +62,7 @@ class VoxelGrid:
# Transform mesh
mesh = mesh.copy()
mesh.apply_translation(-loc)
- mesh.apply_scale(1/scale)
+ mesh.apply_scale(1 / scale)
# Apply method
if method == 'ray':
@@ -75,7 +76,7 @@ class VoxelGrid:
def down_sample(self, factor=2):
if not (self.resolution % factor) == 0:
raise ValueError('Resolution must be divisible by factor.')
- new_data = block_reduce(self.data, (factor,) * 3, np.max)
+ new_data = block_reduce(self.data, (factor, ) * 3, np.max)
return VoxelGrid(new_data, self.loc, self.scale)
def to_mesh(self):
@@ -103,9 +104,9 @@ class VoxelGrid:
f2 = f2_r | f2_l
f3 = f3_r | f3_l
- assert(f1.shape == (nx + 1, ny, nz))
- assert(f2.shape == (nx, ny + 1, nz))
- assert(f3.shape == (nx, ny, nz + 1))
+ assert (f1.shape == (nx + 1, ny, nz))
+ assert (f2.shape == (nx, ny + 1, nz))
+ assert (f3.shape == (nx, ny, nz + 1))
# Determine if vertex present
v = np.full(grid_shape, False)
@@ -146,53 +147,76 @@ class VoxelGrid:
f2_r_x, f2_r_y, f2_r_z = np.where(f2_r)
f3_r_x, f3_r_y, f3_r_z = np.where(f3_r)
- faces_1_l = np.stack([
- v_idx[f1_l_x, f1_l_y, f1_l_z],
- v_idx[f1_l_x, f1_l_y, f1_l_z + 1],
- v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1],
- v_idx[f1_l_x, f1_l_y + 1, f1_l_z],
- ], axis=1)
-
- faces_1_r = np.stack([
- v_idx[f1_r_x, f1_r_y, f1_r_z],
- v_idx[f1_r_x, f1_r_y + 1, f1_r_z],
- v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1],
- v_idx[f1_r_x, f1_r_y, f1_r_z + 1],
- ], axis=1)
-
- faces_2_l = np.stack([
- v_idx[f2_l_x, f2_l_y, f2_l_z],
- v_idx[f2_l_x + 1, f2_l_y, f2_l_z],
- v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1],
- v_idx[f2_l_x, f2_l_y, f2_l_z + 1],
- ], axis=1)
-
- faces_2_r = np.stack([
- v_idx[f2_r_x, f2_r_y, f2_r_z],
- v_idx[f2_r_x, f2_r_y, f2_r_z + 1],
- v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1],
- v_idx[f2_r_x + 1, f2_r_y, f2_r_z],
- ], axis=1)
-
- faces_3_l = np.stack([
- v_idx[f3_l_x, f3_l_y, f3_l_z],
- v_idx[f3_l_x, f3_l_y + 1, f3_l_z],
- v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z],
- v_idx[f3_l_x + 1, f3_l_y, f3_l_z],
- ], axis=1)
-
- faces_3_r = np.stack([
- v_idx[f3_r_x, f3_r_y, f3_r_z],
- v_idx[f3_r_x + 1, f3_r_y, f3_r_z],
- v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z],
- v_idx[f3_r_x, f3_r_y + 1, f3_r_z],
- ], axis=1)
-
- faces = np.concatenate([
- faces_1_l, faces_1_r,
- faces_2_l, faces_2_r,
- faces_3_l, faces_3_r,
- ], axis=0)
+ faces_1_l = np.stack(
+ [
+ v_idx[f1_l_x, f1_l_y, f1_l_z],
+ v_idx[f1_l_x, f1_l_y, f1_l_z + 1],
+ v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1],
+ v_idx[f1_l_x, f1_l_y + 1, f1_l_z],
+ ],
+ axis=1
+ )
+
+ faces_1_r = np.stack(
+ [
+ v_idx[f1_r_x, f1_r_y, f1_r_z],
+ v_idx[f1_r_x, f1_r_y + 1, f1_r_z],
+ v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1],
+ v_idx[f1_r_x, f1_r_y, f1_r_z + 1],
+ ],
+ axis=1
+ )
+
+ faces_2_l = np.stack(
+ [
+ v_idx[f2_l_x, f2_l_y, f2_l_z],
+ v_idx[f2_l_x + 1, f2_l_y, f2_l_z],
+ v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1],
+ v_idx[f2_l_x, f2_l_y, f2_l_z + 1],
+ ],
+ axis=1
+ )
+
+ faces_2_r = np.stack(
+ [
+ v_idx[f2_r_x, f2_r_y, f2_r_z],
+ v_idx[f2_r_x, f2_r_y, f2_r_z + 1],
+ v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1],
+ v_idx[f2_r_x + 1, f2_r_y, f2_r_z],
+ ],
+ axis=1
+ )
+
+ faces_3_l = np.stack(
+ [
+ v_idx[f3_l_x, f3_l_y, f3_l_z],
+ v_idx[f3_l_x, f3_l_y + 1, f3_l_z],
+ v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z],
+ v_idx[f3_l_x + 1, f3_l_y, f3_l_z],
+ ],
+ axis=1
+ )
+
+ faces_3_r = np.stack(
+ [
+ v_idx[f3_r_x, f3_r_y, f3_r_z],
+ v_idx[f3_r_x + 1, f3_r_y, f3_r_z],
+ v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z],
+ v_idx[f3_r_x, f3_r_y + 1, f3_r_z],
+ ],
+ axis=1
+ )
+
+ faces = np.concatenate(
+ [
+ faces_1_l,
+ faces_1_r,
+ faces_2_l,
+ faces_2_r,
+ faces_3_l,
+ faces_3_r,
+ ], axis=0
+ )
vertices = self.loc + self.scale * vertices
mesh = trimesh.Trimesh(vertices, faces, process=False)
@@ -200,7 +224,7 @@ class VoxelGrid:
@property
def resolution(self):
- assert(self.data.shape[0] == self.data.shape[1] == self.data.shape[2])
+ assert (self.data.shape[0] == self.data.shape[1] == self.data.shape[2])
return self.data.shape[0]
def contains(self, points):
@@ -211,12 +235,9 @@ class VoxelGrid:
# Discretize points to [0, nx-1]^3
points_i = ((points + 0.5) * nx).astype(np.int32)
# i1, i2, i3 have sizes (batch_size, T)
- i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2]
+ i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2]
# Only use indices inside bounding box
- mask = (
- (i1 >= 0) & (i2 >= 0) & (i3 >= 0)
- & (nx > i1) & (nx > i2) & (nx > i3)
- )
+ mask = ((i1 >= 0) & (i2 >= 0) & (i3 >= 0) & (nx > i1) & (nx > i2) & (nx > i3))
# Prevent out of bounds error
i1 = i1[mask]
i2 = i2[mask]
@@ -254,7 +275,7 @@ def voxelize_surface(mesh, resolution):
vertices = (vertices + 0.5) * resolution
face_loc = vertices[faces]
- occ = np.full((resolution,) * 3, 0, dtype=np.int32)
+ occ = np.full((resolution, ) * 3, 0, dtype=np.int32)
face_loc = face_loc.astype(np.float32)
voxelize_mesh_(occ, face_loc)
@@ -264,9 +285,9 @@ def voxelize_surface(mesh, resolution):
def voxelize_interior(mesh, resolution):
- shape = (resolution,) * 3
- bb_min = (0.5,) * 3
- bb_max = (resolution - 0.5,) * 3
+ shape = (resolution, ) * 3
+ bb_min = (0.5, ) * 3
+ bb_max = (resolution - 0.5, ) * 3
# Create points. Add noise to break symmetry
points = make_3d_grid(bb_min, bb_max, shape=shape).numpy()
points = points + 0.1 * (np.random.rand(*points.shape) - 0.5)
@@ -280,14 +301,9 @@ def check_voxel_occupied(occupancy_grid):
occ = occupancy_grid
occupied = (
- occ[..., :-1, :-1, :-1]
- & occ[..., :-1, :-1, 1:]
- & occ[..., :-1, 1:, :-1]
- & occ[..., :-1, 1:, 1:]
- & occ[..., 1:, :-1, :-1]
- & occ[..., 1:, :-1, 1:]
- & occ[..., 1:, 1:, :-1]
- & occ[..., 1:, 1:, 1:]
+ occ[..., :-1, :-1, :-1] & occ[..., :-1, :-1, 1:] & occ[..., :-1, 1:, :-1] &
+ occ[..., :-1, 1:, 1:] & occ[..., 1:, :-1, :-1] & occ[..., 1:, :-1, 1:] &
+ occ[..., 1:, 1:, :-1] & occ[..., 1:, 1:, 1:]
)
return occupied
@@ -296,14 +312,9 @@ def check_voxel_unoccupied(occupancy_grid):
occ = occupancy_grid
unoccupied = ~(
- occ[..., :-1, :-1, :-1]
- | occ[..., :-1, :-1, 1:]
- | occ[..., :-1, 1:, :-1]
- | occ[..., :-1, 1:, 1:]
- | occ[..., 1:, :-1, :-1]
- | occ[..., 1:, :-1, 1:]
- | occ[..., 1:, 1:, :-1]
- | occ[..., 1:, 1:, 1:]
+ occ[..., :-1, :-1, :-1] | occ[..., :-1, :-1, 1:] | occ[..., :-1, 1:, :-1] |
+ occ[..., :-1, 1:, 1:] | occ[..., 1:, :-1, :-1] | occ[..., 1:, :-1, 1:] |
+ occ[..., 1:, 1:, :-1] | occ[..., 1:, 1:, 1:]
)
return unoccupied
diff --git a/lib/dataset/Evaluator.py b/lib/dataset/Evaluator.py
index 6e3f1c7218d607174a58c6ac9f6406dbef3262d2..b215a9bb2f81b88029d63b7a83d8d76a842559e3 100644
--- a/lib/dataset/Evaluator.py
+++ b/lib/dataset/Evaluator.py
@@ -37,7 +37,6 @@ class _PointFaceDistance(Function):
"""
Torch autograd Function wrapper PointFaceDistance Cuda implementation
"""
-
@staticmethod
def forward(
ctx,
@@ -92,12 +91,15 @@ class _PointFaceDistance(Function):
grad_dists = grad_dists.contiguous()
points, tris, idxs = ctx.saved_tensors
min_triangle_area = ctx.min_triangle_area
- grad_points, grad_tris = _C.point_face_dist_backward(points, tris, idxs, grad_dists, min_triangle_area)
+ grad_points, grad_tris = _C.point_face_dist_backward(
+ points, tris, idxs, grad_dists, min_triangle_area
+ )
return grad_points, None, grad_tris, None, None, None
-def _rand_barycentric_coords(size1, size2, dtype: torch.dtype,
- device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def _rand_barycentric_coords(
+ size1, size2, dtype: torch.dtype, device: torch.device
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Helper function to generate random barycentric coordinates which are uniformly
distributed over a triangle.
@@ -167,19 +169,21 @@ def sample_points_from_meshes(meshes, num_samples: int = 10000):
faces = meshes.faces_packed()
mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
num_meshes = len(meshes)
- num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.
+ num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.
# Initialize samples tensor with fill value 0 for empty meshes.
samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
# Only compute samples for non empty meshes
with torch.no_grad():
- areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero.
+ areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero.
max_faces = meshes.num_faces_per_mesh().max().item()
- areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces) # (N, F)
+ areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces) # (N, F)
# TODO (gkioxari) Confirm multinomial bug is not present with real data.
- samples_face_idxs = areas_padded.multinomial(num_samples, replacement=True) # (N, num_samples)
+ samples_face_idxs = areas_padded.multinomial(
+ num_samples, replacement=True
+ ) # (N, num_samples)
samples_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
# Randomly generate barycentric coords.
@@ -200,23 +204,25 @@ def point_mesh_distance(meshes, pcls, weighted=True):
raise ValueError("meshes and pointclouds must be equal sized batches")
# packed representation for pointclouds
- points = pcls.points_packed() # (P, 3)
+ points = pcls.points_packed() # (P, 3)
points_first_idx = pcls.cloud_to_packed_first_idx()
max_points = pcls.num_points_per_cloud().max().item()
# packed representation for faces
verts_packed = meshes.verts_packed()
faces_packed = meshes.faces_packed()
- tris = verts_packed[faces_packed] # (T, 3, 3)
+ tris = verts_packed[faces_packed] # (T, 3, 3)
tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
# point to face distance: shape (P,)
- point_to_face, idxs = _PointFaceDistance.apply(points, points_first_idx, tris, tris_first_idx, max_points, 5e-3)
+ point_to_face, idxs = _PointFaceDistance.apply(
+ points, points_first_idx, tris, tris_first_idx, max_points, 5e-3
+ )
if weighted:
# weight each example by the inverse of number of points in the example
- point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
- num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
+ point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
+ num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
weights_p = 1.0 / weights_p.float()
point_to_face = torch.sqrt(point_to_face) * weights_p
@@ -225,7 +231,6 @@ def point_mesh_distance(meshes, pcls, weighted=True):
class Evaluator:
-
def __init__(self, device):
self.render = Render(size=512, device=device)
@@ -253,8 +258,8 @@ class Evaluator:
self.render.meshes = self.tgt_mesh
tgt_normal_imgs = self.render.get_image(cam_type="four", bg="black")
- src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
- tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
+ src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
+ tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True)
tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True)
@@ -274,8 +279,11 @@ class Evaluator:
# error_hf = ((((src_normal_arr - tgt_normal_arr) * sim_mask)**2).sum(dim=0).mean()) * 4.0
normal_img = Image.fromarray(
- (torch.cat([src_normal_arr, tgt_normal_arr], dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0).astype(
- np.uint8))
+ (
+ torch.cat([src_normal_arr, tgt_normal_arr],
+ dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0
+ ).astype(np.uint8)
+ )
normal_img.save(normal_path)
return error
@@ -291,7 +299,9 @@ class Evaluator:
p2s_dist_all, _ = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
p2s_dist = p2s_dist_all.sum()
- chamfer_dist = (point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist) * 0.5
+ chamfer_dist = (
+ point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist
+ ) * 0.5
return chamfer_dist, p2s_dist
diff --git a/lib/dataset/NormalDataset.py b/lib/dataset/NormalDataset.py
index 1e532b3c820885a8ea96ee65439796ad23de9230..3567ac8cd5a83517a93c80c008bbb9b8d23616a7 100644
--- a/lib/dataset/NormalDataset.py
+++ b/lib/dataset/NormalDataset.py
@@ -23,7 +23,6 @@ import torchvision.transforms as transforms
class NormalDataset:
-
def __init__(self, cfg, split="train"):
self.split = split
@@ -44,8 +43,7 @@ class NormalDataset:
if self.split != "train":
self.rotations = range(0, 360, 120)
else:
- self.rotations = np.arange(0, 360, 360 //
- self.opt.rotation_num).astype(np.int)
+            self.rotations = np.arange(0, 360, 360 // self.opt.rotation_num).astype(int)
self.datasets_dict = {}
@@ -54,26 +52,29 @@ class NormalDataset:
dataset_dir = osp.join(self.root, dataset)
self.datasets_dict[dataset] = {
- "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"),
- dtype=str),
+ "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
"scale": self.scales[dataset_id],
}
self.subject_list = self.get_subject_list(split)
# PIL to tensor
- self.image_to_tensor = transforms.Compose([
- transforms.Resize(self.input_size),
- transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
- ])
+ self.image_to_tensor = transforms.Compose(
+ [
+ transforms.Resize(self.input_size),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+ )
# PIL to tensor
- self.mask_to_tensor = transforms.Compose([
- transforms.Resize(self.input_size),
- transforms.ToTensor(),
- transforms.Normalize((0.0, ), (1.0, )),
- ])
+ self.mask_to_tensor = transforms.Compose(
+ [
+ transforms.Resize(self.input_size),
+ transforms.ToTensor(),
+ transforms.Normalize((0.0, ), (1.0, )),
+ ]
+ )
def get_subject_list(self, split):
@@ -88,16 +89,12 @@ class NormalDataset:
subject_list += np.loadtxt(split_txt, dtype=str).tolist()
if self.split != "test":
- subject_list += subject_list[:self.bsize -
- len(subject_list) % self.bsize]
+ subject_list += subject_list[:self.bsize - len(subject_list) % self.bsize]
print(colored(f"total: {len(subject_list)}", "yellow"))
- bug_list = sorted(
- np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())
+ bug_list = sorted(np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())
- subject_list = [
- subject for subject in subject_list if (subject not in bug_list)
- ]
+ subject_list = [subject for subject in subject_list if (subject not in bug_list)]
# subject_list = ["thuman2/0008"]
return subject_list
@@ -113,48 +110,41 @@ class NormalDataset:
rotation = self.rotations[rid]
subject = self.subject_list[mid].split("/")[1]
dataset = self.subject_list[mid].split("/")[0]
- render_folder = "/".join(
- [dataset + f"_{self.opt.rotation_num}views", subject])
+ render_folder = "/".join([dataset + f"_{self.opt.rotation_num}views", subject])
if not osp.exists(osp.join(self.root, render_folder)):
render_folder = "/".join([dataset + f"_36views", subject])
# setup paths
data_dict = {
- "dataset":
- dataset,
- "subject":
- subject,
- "rotation":
- rotation,
- "scale":
- self.datasets_dict[dataset]["scale"],
- "image_path":
- osp.join(self.root, render_folder, "render",
- f"{rotation:03d}.png"),
+ "dataset": dataset,
+ "subject": subject,
+ "rotation": rotation,
+ "scale": self.datasets_dict[dataset]["scale"],
+ "image_path": osp.join(self.root, render_folder, "render", f"{rotation:03d}.png"),
}
# image/normal/depth loader
for name, channel in zip(self.in_total, self.in_total_dim):
if f"{name}_path" not in data_dict.keys():
- data_dict.update({
- f"{name}_path":
- osp.join(self.root, render_folder, name,
- f"{rotation:03d}.png")
- })
-
- data_dict.update({
- name:
- self.imagepath2tensor(data_dict[f"{name}_path"],
- channel,
- inv=False,
- erasing=False)
- })
-
- path_keys = [
- key for key in data_dict.keys() if "_path" in key or "_dir" in key
- ]
+ data_dict.update(
+ {
+ f"{name}_path":
+ osp.join(self.root, render_folder, name, f"{rotation:03d}.png")
+ }
+ )
+
+ data_dict.update(
+ {
+ name:
+ self.imagepath2tensor(
+ data_dict[f"{name}_path"], channel, inv=False, erasing=False
+ )
+ }
+ )
+
+ path_keys = [key for key in data_dict.keys() if "_path" in key or "_dir" in key]
for key in path_keys:
del data_dict[key]
@@ -172,10 +162,9 @@ class NormalDataset:
# simulate occlusion
if erasing:
- mask = kornia.augmentation.RandomErasing(p=0.2,
- scale=(0.01, 0.2),
- ratio=(0.3, 3.3),
- keepdim=True)(mask)
+ mask = kornia.augmentation.RandomErasing(
+ p=0.2, scale=(0.01, 0.2), ratio=(0.3, 3.3), keepdim=True
+ )(mask)
image = (image * mask)[:channel]
return (image * (0.5 - inv) * 2.0).float()
diff --git a/lib/dataset/NormalModule.py b/lib/dataset/NormalModule.py
index fbf0a4533f2e23db5ec25cd95511894b9f6296f9..ff672b3c42f5951f4ebf6c8446014d1d277ab02c 100644
--- a/lib/dataset/NormalModule.py
+++ b/lib/dataset/NormalModule.py
@@ -22,7 +22,6 @@ import pytorch_lightning as pl
class NormalModule(pl.LightningDataModule):
-
def __init__(self, cfg):
super(NormalModule, self).__init__()
self.cfg = cfg
@@ -40,7 +39,7 @@ class NormalModule(pl.LightningDataModule):
self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
self.test_dataset = NormalDataset(cfg=self.cfg, split="test")
-
+
self.data_size = {
"train": len(self.train_dataset),
"val": len(self.val_dataset),
@@ -69,7 +68,7 @@ class NormalModule(pl.LightningDataModule):
)
return val_data_loader
-
+
def val_dataloader(self):
test_data_loader = DataLoader(
diff --git a/lib/dataset/PointFeat.py b/lib/dataset/PointFeat.py
index 4ade2b319aaa383c3d86cc319f013254d7ccebd9..457b949e5ce712a1eace33b1306fd48613ba8887 100644
--- a/lib/dataset/PointFeat.py
+++ b/lib/dataset/PointFeat.py
@@ -6,7 +6,6 @@ from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection
class PointFeat:
-
def __init__(self, verts, faces):
# verts [B, N_vert, 3]
@@ -23,7 +22,10 @@ class PointFeat:
if verts.shape[1] == 10475:
faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask]
- mouth_faces = (torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1, 1).to(self.device))
+ mouth_faces = (
+ torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1,
+ 1).to(self.device)
+ )
self.faces = torch.cat([faces, mouth_faces], dim=1).long()
self.verts = verts.float()
@@ -35,11 +37,15 @@ class PointFeat:
points = points.float()
residues, pts_ind = point_mesh_distance(self.mesh, Pointclouds(points), weighted=False)
- closest_triangles = torch.gather(self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
+ closest_triangles = torch.gather(
+ self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
+ ).view(-1, 3, 3)
bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles)
feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
- closest_normals = torch.gather(feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
+ closest_normals = torch.gather(
+ feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
+ ).view(-1, 3, 3)
shoot_verts = ((closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0))
pts2shoot_normals = points - shoot_verts
diff --git a/lib/dataset/TestDataset.py b/lib/dataset/TestDataset.py
index 49d6187bbfd131ad95cf63d5f59984e26acb1071..e016d3d94b3d9f043ef5a8526d2d1be67be6f4a7 100644
--- a/lib/dataset/TestDataset.py
+++ b/lib/dataset/TestDataset.py
@@ -25,6 +25,7 @@ from lib.pixielib.utils.config import cfg as pixie_cfg
from lib.pixielib.pixie import PIXIE
from lib.pixielib.models.SMPLX import SMPLX as PIXIE_SMPLX
from lib.common.imutils import process_image
+from lib.common.train_util import Format
from lib.net.geometry import rotation_matrix_to_angle_axis, rot6d_to_rotmat
from lib.pymafx.core import path_config
@@ -36,8 +37,9 @@ from lib.dataset.body_model import TetraSMPLModel
from lib.dataset.mesh_util import get_visibility, SMPLX
import torch.nn.functional as F
from torchvision import transforms
+from torchvision.models import detection
+
import os.path as osp
-import os
import torch
import glob
import numpy as np
@@ -48,7 +50,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
class TestDataset:
-
def __init__(self, cfg, device):
self.image_dir = cfg["image_dir"]
@@ -65,7 +66,9 @@ class TestDataset:
keep_lst = sorted(glob.glob(f"{self.image_dir}/*"))
img_fmts = ["jpg", "png", "jpeg", "JPG", "bmp"]
- self.subject_list = sorted([item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False)
+ self.subject_list = sorted(
+ [item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False
+ )
# smpl related
self.smpl_data = SMPLX()
@@ -80,7 +83,16 @@ class TestDataset:
self.smpl_model = PIXIE_SMPLX(pixie_cfg.model).to(self.device)
- print(colored(f"Use {self.hps_type.upper()} to estimate human pose and shape", "green"))
+ self.detector = detection.maskrcnn_resnet50_fpn(
+ weights=detection.MaskRCNN_ResNet50_FPN_V2_Weights
+ )
+ self.detector.eval()
+
+ print(
+ colored(
+                f"SMPL-X estimate with {Format.start} {self.hps_type.upper()} {Format.end}", "green"
+ )
+ )
self.render = Render(size=512, device=self.device)
@@ -90,7 +102,9 @@ class TestDataset:
def compute_vis_cmap(self, smpl_verts, smpl_faces):
(xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=-1)
- smpl_vis = get_visibility(xy, z, torch.as_tensor(smpl_faces).long()[:, :, [0, 2, 1]]).unsqueeze(-1)
+ smpl_vis = get_visibility(xy, z,
+ torch.as_tensor(smpl_faces).long()[:, :,
+ [0, 2, 1]]).unsqueeze(-1)
smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type).unsqueeze(0)
return {
@@ -109,7 +123,8 @@ class TestDataset:
depth_FB[:, ~depth_mask[0]] = 0.
# Important: index_long = depth_value - 1
- index_z = (((depth_FB + 1.) * 0.5 * self.vol_res) - 1).clip(0, self.vol_res - 1).permute(1, 2, 0)
+ index_z = (((depth_FB + 1.) * 0.5 * self.vol_res) - 1).clip(0, self.vol_res -
+ 1).permute(1, 2, 0)
index_z_ceil = torch.ceil(index_z).long()
index_z_floor = torch.floor(index_z).long()
index_z_frac = torch.frac(index_z)
@@ -121,7 +136,7 @@ class TestDataset:
F.one_hot(index_z_floor[..., 1], self.vol_res) * (1.0 - index_z_frac[..., 1])
voxels[index_mask] *= 0
- voxels = torch.flip(voxels, [2]).permute(2, 0, 1).float() #[x-2, y-0, z-1]
+ voxels = torch.flip(voxels, [2]).permute(2, 0, 1).float() #[x-2, y-0, z-1]
return {
"depth_voxels": voxels.flip([
@@ -139,18 +154,25 @@ class TestDataset:
smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])
verts = (
- np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * scale.item() + trans.detach().cpu().numpy())
+ np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * scale.item() +
+ trans.detach().cpu().numpy()
+ )
faces = (
np.loadtxt(
osp.join(self.smpl_data.tedra_dir, "tetrahedrons_neutral_adult.txt"),
dtype=np.int32,
- ) - 1)
+ ) - 1
+ )
pad_v_num = int(8000 - verts.shape[0])
pad_f_num = int(25100 - faces.shape[0])
- verts = (np.pad(verts, ((0, pad_v_num), (0, 0)), mode="constant", constant_values=0.0).astype(np.float32) * 0.5)
- faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant", constant_values=0.0).astype(np.int32)
+ verts = (
+ np.pad(verts, ((0, pad_v_num),
+ (0, 0)), mode="constant", constant_values=0.0).astype(np.float32) * 0.5
+ )
+ faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant",
+ constant_values=0.0).astype(np.int32)
verts[:, 2] *= -1.0
@@ -168,7 +190,7 @@ class TestDataset:
img_path = self.subject_list[index]
img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
- arr_dict = process_image(img_path, self.hps_type, self.single, 512)
+ arr_dict = process_image(img_path, self.hps_type, self.single, 512, self.detector)
arr_dict.update({"name": img_name})
with torch.no_grad():
@@ -179,7 +201,10 @@ class TestDataset:
preds_dict, _ = self.hps.forward(batch)
arr_dict["smpl_faces"] = (
- torch.as_tensor(self.smpl_data.smplx_faces.astype(np.int64)).unsqueeze(0).long().to(self.device))
+ torch.as_tensor(self.smpl_data.smplx_faces.astype(np.int64)).unsqueeze(0).long().to(
+ self.device
+ )
+ )
arr_dict["type"] = self.smpl_type
if self.hps_type == "pymafx":
@@ -198,13 +223,16 @@ class TestDataset:
elif self.hps_type == "pixie":
arr_dict.update(preds_dict)
arr_dict["global_orient"] = preds_dict["global_pose"]
- arr_dict["betas"] = preds_dict["shape"] #200
+ arr_dict["betas"] = preds_dict["shape"] #200
arr_dict["smpl_verts"] = preds_dict["vertices"]
scale, tranX, tranY = preds_dict["cam"].split(1, dim=1)
# 1.1435, 0.0128, 0.3520
arr_dict["scale"] = scale.unsqueeze(1)
- arr_dict["trans"] = (torch.cat([tranX, tranY, torch.zeros_like(tranX)], dim=1).unsqueeze(1).to(self.device).float())
+ arr_dict["trans"] = (
+ torch.cat([tranX, tranY, torch.zeros_like(tranX)],
+ dim=1).unsqueeze(1).to(self.device).float()
+ )
# data_dict info (key-shape):
# scale, tranX, tranY - tensor.float
@@ -230,4 +258,4 @@ class TestDataset:
# render optimized mesh (normal, T_normal, image [-1,1])
self.render.load_meshes(verts, faces)
- return self.render.get_image(type="depth")
\ No newline at end of file
+ return self.render.get_image(type="depth")
diff --git a/lib/dataset/body_model.py b/lib/dataset/body_model.py
index f41f481ae67e124cbb45c5c8f34179c5c6f49311..cebb105591cab29d833f2965ec609c85fd522881 100644
--- a/lib/dataset/body_model.py
+++ b/lib/dataset/body_model.py
@@ -21,7 +21,6 @@ import os
class SMPLModel:
-
def __init__(self, model_path, age):
"""
SMPL model.
@@ -49,20 +48,16 @@ class SMPLModel:
if age == "kid":
v_template_smil = np.load(
- os.path.join(os.path.dirname(model_path),
- "smpl/smpl_kid_template.npy"))
+ os.path.join(os.path.dirname(model_path), "smpl/smpl_kid_template.npy")
+ )
v_template_smil -= np.mean(v_template_smil, axis=0)
- v_template_diff = np.expand_dims(v_template_smil - self.v_template,
- axis=2)
+ v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2)
self.shapedirs = np.concatenate(
- (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
- axis=2)
+ (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2
+ )
self.beta_shape[0] += 1
- id_to_col = {
- self.kintree_table[1, i]: i
- for i in range(self.kintree_table.shape[1])
- }
+ id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])}
self.parent = {
i: id_to_col[self.kintree_table[0, i]]
for i in range(1, self.kintree_table.shape[1])
@@ -121,33 +116,30 @@ class SMPLModel:
pose_cube = self.pose.reshape((-1, 1, 3))
# rotation matrix for each joint
self.R = self.rodrigues(pose_cube)
- I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
- (self.R.shape[0] - 1, 3, 3))
+ I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3))
lrotmin = (self.R[1:] - I_cube).ravel()
# how pose affect body shape in zero pose
v_posed = v_shaped + self.posedirs.dot(lrotmin)
# world transformation of each joint
G = np.empty((self.kintree_table.shape[1], 4, 4))
- G[0] = self.with_zeros(
- np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
+ G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
for i in range(1, self.kintree_table.shape[1]):
G[i] = G[self.parent[i]].dot(
self.with_zeros(
- np.hstack([
- self.R[i],
- ((self.J[i, :] - self.J[self.parent[i], :]).reshape(
- [3, 1])),
- ])))
+ np.hstack(
+ [
+ self.R[i],
+ ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
+ ]
+ )
+ )
+ )
# remove the transformation due to the rest pose
- G = G - self.pack(
- np.matmul(
- G,
- np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
+ G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
# transformation of each vertex
T = np.tensordot(self.weights, G, axes=[[1], [0]])
rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
- v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
- 4])[:, :3]
+ v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
self.verts = v + self.trans.reshape([1, 3])
self.G = G
@@ -171,19 +163,20 @@ class SMPLModel:
r_hat = r / theta
cos = np.cos(theta)
z_stick = np.zeros(theta.shape[0])
- m = np.dstack([
- z_stick,
- -r_hat[:, 0, 2],
- r_hat[:, 0, 1],
- r_hat[:, 0, 2],
- z_stick,
- -r_hat[:, 0, 0],
- -r_hat[:, 0, 1],
- r_hat[:, 0, 0],
- z_stick,
- ]).reshape([-1, 3, 3])
- i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
- [theta.shape[0], 3, 3])
+ m = np.dstack(
+ [
+ z_stick,
+ -r_hat[:, 0, 2],
+ r_hat[:, 0, 1],
+ r_hat[:, 0, 2],
+ z_stick,
+ -r_hat[:, 0, 0],
+ -r_hat[:, 0, 1],
+ r_hat[:, 0, 0],
+ z_stick,
+ ]
+ ).reshape([-1, 3, 3])
+ i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])
A = np.transpose(r_hat, axes=[0, 2, 1])
B = r_hat
dot = np.matmul(A, B)
@@ -238,12 +231,7 @@ class SMPLModel:
class TetraSMPLModel:
-
- def __init__(self,
- model_path,
- model_addition_path,
- age="adult",
- v_template=None):
+ def __init__(self, model_path, model_addition_path, age="adult", v_template=None):
"""
SMPL model.
@@ -276,10 +264,7 @@ class TetraSMPLModel:
self.posedirs_added = params_added["posedirs_added"]
self.tetrahedrons = params_added["tetrahedrons"]
- id_to_col = {
- self.kintree_table[1, i]: i
- for i in range(self.kintree_table.shape[1])
- }
+ id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])}
self.parent = {
i: id_to_col[self.kintree_table[0, i]]
for i in range(1, self.kintree_table.shape[1])
@@ -291,14 +276,13 @@ class TetraSMPLModel:
if age == "kid":
v_template_smil = np.load(
- os.path.join(os.path.dirname(model_path),
- "smpl_kid_template.npy"))
+ os.path.join(os.path.dirname(model_path), "smpl_kid_template.npy")
+ )
v_template_smil -= np.mean(v_template_smil, axis=0)
- v_template_diff = np.expand_dims(v_template_smil - self.v_template,
- axis=2)
+ v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2)
self.shapedirs = np.concatenate(
- (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
- axis=2)
+ (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2
+ )
self.beta_shape[0] += 1
self.pose = np.zeros(self.pose_shape)
@@ -356,50 +340,42 @@ class TetraSMPLModel:
"""
# how beta affect body shape
v_shaped = self.shapedirs.dot(self.beta) + self.v_template
- v_shaped_added = self.shapedirs_added.dot(
- self.beta) + self.v_template_added
+ v_shaped_added = self.shapedirs_added.dot(self.beta) + self.v_template_added
# joints location
self.J = self.J_regressor.dot(v_shaped)
pose_cube = self.pose.reshape((-1, 1, 3))
# rotation matrix for each joint
self.R = self.rodrigues(pose_cube)
- I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
- (self.R.shape[0] - 1, 3, 3))
+ I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3))
lrotmin = (self.R[1:] - I_cube).ravel()
# how pose affect body shape in zero pose
v_posed = v_shaped + self.posedirs.dot(lrotmin)
v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin)
# world transformation of each joint
G = np.empty((self.kintree_table.shape[1], 4, 4))
- G[0] = self.with_zeros(
- np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
+ G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
for i in range(1, self.kintree_table.shape[1]):
G[i] = G[self.parent[i]].dot(
self.with_zeros(
- np.hstack([
- self.R[i],
- ((self.J[i, :] - self.J[self.parent[i], :]).reshape(
- [3, 1])),
- ])))
+ np.hstack(
+ [
+ self.R[i],
+ ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
+ ]
+ )
+ )
+ )
# remove the transformation due to the rest pose
- G = G - self.pack(
- np.matmul(
- G,
- np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
+ G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
self.G = G
# transformation of each vertex
T = np.tensordot(self.weights, G, axes=[[1], [0]])
rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
- v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
- 4])[:, :3]
+ v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
self.verts = v + self.trans.reshape([1, 3])
T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]])
- rest_shape_added_h = np.hstack(
- (v_posed_added, np.ones([v_posed_added.shape[0], 1])))
- v_added = np.matmul(T_added,
- rest_shape_added_h.reshape([-1, 4,
- 1])).reshape([-1, 4
- ])[:, :3]
+ rest_shape_added_h = np.hstack((v_posed_added, np.ones([v_posed_added.shape[0], 1])))
+ v_added = np.matmul(T_added, rest_shape_added_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
self.verts_added = v_added + self.trans.reshape([1, 3])
def rodrigues(self, r):
@@ -422,19 +398,20 @@ class TetraSMPLModel:
r_hat = r / theta
cos = np.cos(theta)
z_stick = np.zeros(theta.shape[0])
- m = np.dstack([
- z_stick,
- -r_hat[:, 0, 2],
- r_hat[:, 0, 1],
- r_hat[:, 0, 2],
- z_stick,
- -r_hat[:, 0, 0],
- -r_hat[:, 0, 1],
- r_hat[:, 0, 0],
- z_stick,
- ]).reshape([-1, 3, 3])
- i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
- [theta.shape[0], 3, 3])
+ m = np.dstack(
+ [
+ z_stick,
+ -r_hat[:, 0, 2],
+ r_hat[:, 0, 1],
+ r_hat[:, 0, 2],
+ z_stick,
+ -r_hat[:, 0, 0],
+ -r_hat[:, 0, 1],
+ r_hat[:, 0, 0],
+ z_stick,
+ ]
+ ).reshape([-1, 3, 3])
+ i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])
A = np.transpose(r_hat, axes=[0, 2, 1])
B = r_hat
dot = np.matmul(A, B)
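The reformatted `rodrigues` methods above implement the standard axis-angle to rotation-matrix conversion. As a self-contained, single-vector sketch of the same formula R = cos(θ)I + (1 − cos(θ)) r̂ r̂ᵀ + sin(θ)[r̂]ₓ (the function name and the epsilon guard below are illustrative, not taken from the file):

```python
import numpy as np

def rodrigues_single(r, eps=1e-8):
    """Convert one axis-angle vector r (3,) into a 3x3 rotation matrix,
    mirroring the batched SMPLModel.rodrigues computation."""
    theta = np.linalg.norm(r) + eps           # rotation angle
    r_hat = (r / theta).reshape(3, 1)         # unit rotation axis
    skew = np.array([
        [0.0,          -r_hat[2, 0],  r_hat[1, 0]],
        [r_hat[2, 0],   0.0,         -r_hat[0, 0]],
        [-r_hat[1, 0],  r_hat[0, 0],  0.0],
    ])
    return (np.cos(theta) * np.eye(3)
            + (1.0 - np.cos(theta)) * (r_hat @ r_hat.T)
            + np.sin(theta) * skew)

# sanity check: rotating by pi/2 around z maps the x-axis onto the y-axis
R = rodrigues_single(np.array([0.0, 0.0, np.pi / 2]))
assert np.allclose(R @ np.array([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0], atol=1e-6)
```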
diff --git a/lib/dataset/mesh_util.py b/lib/dataset/mesh_util.py
index 52496a221fc995f891abdebe892acbf2f701aeca..b639763e5ea98fd79698878a47c3089b08f4b86e 100644
--- a/lib/dataset/mesh_util.py
+++ b/lib/dataset/mesh_util.py
@@ -14,32 +14,33 @@
#
# Contact: ps-license@tuebingen.mpg.de
+import os
import numpy as np
-import cv2
-import pymeshlab
import torch
import torchvision
import trimesh
-import os
-from termcolor import colored
+import open3d as o3d
+import tinyobjloader
import os.path as osp
import _pickle as cPickle
+from termcolor import colored
from scipy.spatial import cKDTree
from pytorch3d.structures import Meshes
import torch.nn.functional as F
import lib.smplx as smplx
+from lib.common.render_utils import Pytorch3dRasterizer
from pytorch3d.renderer.mesh import rasterize_meshes
from PIL import Image, ImageFont, ImageDraw
from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
-import tinyobjloader
-from lib.common.imutils import uncrop
-from lib.common.render_utils import Pytorch3dRasterizer
+class Format:
+ end = '\033[0m'
+ start = '\033[4m'
-class SMPLX:
+class SMPLX:
def __init__(self):
self.current_dir = osp.join(osp.dirname(__file__), "../../data/smpl_related")
@@ -54,10 +55,14 @@ class SMPLX:
self.smplx_eyeball_fid_path = osp.join(self.current_dir, "smpl_data/eyeball_fid.npy")
self.smplx_fill_mouth_fid_path = osp.join(self.current_dir, "smpl_data/fill_mouth_fid.npy")
- self.smplx_flame_vid_path = osp.join(self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy")
+ self.smplx_flame_vid_path = osp.join(
+ self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy"
+ )
self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl")
self.front_flame_path = osp.join(self.current_dir, "smpl_data/FLAME_face_mask_ids.npy")
- self.smplx_vertex_lmkid_path = osp.join(self.current_dir, "smpl_data/smplx_vertex_lmkid.npy")
+ self.smplx_vertex_lmkid_path = osp.join(
+ self.current_dir, "smpl_data/smplx_vertex_lmkid.npy"
+ )
self.smplx_faces = np.load(self.smplx_faces_path)
self.smplx_verts = np.load(self.smplx_verts_path)
@@ -68,84 +73,51 @@ class SMPLX:
self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid_path)
self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid_path)
self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True)
- self.smplx_mano_vid = np.concatenate([self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]])
+ self.smplx_mano_vid = np.concatenate(
+ [self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]]
+ )
self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True)
self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)]
# hands
- self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(0, torch.tensor(self.smplx_mano_vid), 1.0)
+ self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
+ 0, torch.tensor(self.smplx_mano_vid), 1.0
+ )
# face
- self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(
- 0, torch.tensor(self.smplx_front_flame_vid), 1.0)
- self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(
- 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0)
+ self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
+ 0, torch.tensor(self.smplx_front_flame_vid), 1.0
+ )
+ self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
+ 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0
+ )
self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, "rb"))
self.model_dir = osp.join(self.current_dir, "models")
self.tedra_dir = osp.join(self.current_dir, "../tedra_data")
- self.ghum_smpl_pairs = torch.tensor([
- (0, 24),
- (2, 26),
- (5, 25),
- (7, 28),
- (8, 27),
- (11, 16),
- (12, 17),
- (13, 18),
- (14, 19),
- (15, 20),
- (16, 21),
- (17, 39),
- (18, 44),
- (19, 36),
- (20, 41),
- (21, 35),
- (22, 40),
- (23, 1),
- (24, 2),
- (25, 4),
- (26, 5),
- (27, 7),
- (28, 8),
- (29, 31),
- (30, 34),
- (31, 29),
- (32, 32),
- ]).long()
+ self.ghum_smpl_pairs = torch.tensor(
+ [
+ (0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16), (12, 17), (13, 18), (14, 19),
+ (15, 20), (16, 21), (17, 39), (18, 44), (19, 36), (20, 41), (21, 35), (22, 40),
+ (23, 1), (24, 2), (25, 4), (26, 5), (27, 7), (28, 8), (29, 31), (30, 34), (31, 29),
+ (32, 32)
+ ]
+ ).long()
# smpl-smplx correspondence
self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68]
- self.smpl_joint_ids_45 = (np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist())
-
- self.extra_joint_ids = (
- np.array([
- 61,
- 72,
- 66,
- 69,
- 58,
- 68,
- 57,
- 56,
- 64,
- 59,
- 67,
- 75,
- 70,
- 65,
- 60,
- 61,
- 63,
- 62,
- 76,
- 71,
- 72,
- 74,
- 73,
- ]) + 68)
+ self.smpl_joint_ids_45 = np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist()
+
+ self.extra_joint_ids = np.array(
+ [
+ 61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72,
+ 74, 73
+ ]
+ )
+
+ self.extra_joint_ids += 68
self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist())
@@ -222,27 +194,6 @@ def load_fit_body(fitted_path, scale, smpl_type="smplx", smpl_gender="neutral",
return smpl_mesh, smpl_joints
-def create_grid_points_from_xyz_bounds(bound, res):
-
- min_x, max_x, min_y, max_y, min_z, max_z = bound
- x = torch.linspace(min_x, max_x, res)
- y = torch.linspace(min_y, max_y, res)
- z = torch.linspace(min_z, max_z, res)
- X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
-
- return torch.stack([X, Y, Z], dim=-1)
-
-
-def create_grid_points_from_xy_bounds(bound, res):
-
- min_x, max_x, min_y, max_y = bound
- x = torch.linspace(min_x, max_x, res)
- y = torch.linspace(min_y, max_y, res)
- X, Y = torch.meshgrid(x, y, indexing='ij')
-
- return torch.stack([X, Y], dim=-1)
-
-
def apply_face_mask(mesh, face_mask):
mesh.update_faces(face_mask)
@@ -277,7 +228,8 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
part_extractor = PointFeat(
torch.tensor(part_mesh.vertices).unsqueeze(0).to(device),
- torch.tensor(part_mesh.faces).unsqueeze(0).to(device))
+ torch.tensor(part_mesh.faces).unsqueeze(0).to(device)
+ )
(part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device))
@@ -286,12 +238,20 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
if region == "hand":
_, idx = smpl_tree.query(full_mesh.vertices, k=1)
full_lmkid = SMPL_container.smplx_vertex_lmkid[idx]
- remove_mask = torch.logical_and(remove_mask, torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0))
+ remove_mask = torch.logical_and(
+ remove_mask,
+ torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0)
+ )
elif region == "face":
_, idx = smpl_tree.query(full_mesh.vertices, k=5)
- face_space_mask = torch.isin(torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid))
- remove_mask = torch.logical_and(remove_mask, face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0))
+ face_space_mask = torch.isin(
+ torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid)
+ )
+ remove_mask = torch.logical_and(
+ remove_mask,
+ face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0)
+ )
BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1)
full_mesh.update_faces(BNI_part_mask.detach().cpu())
@@ -303,109 +263,6 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
return full_mesh
-def cross(triangles):
- """
- Returns the cross product of two edges from input triangles
- Parameters
- --------------
- triangles: (n, 3, 3) float
- Vertices of triangles
- Returns
- --------------
- crosses : (n, 3) float
- Cross product of two edge vectors
- """
- vectors = np.diff(triangles, axis=1)
- crosses = np.cross(vectors[:, 0], vectors[:, 1])
- return crosses
-
-
-def tri_area(triangles=None, crosses=None, sum=False):
- """
- Calculates the sum area of input triangles
- Parameters
- ----------
- triangles : (n, 3, 3) float
- Vertices of triangles
- crosses : (n, 3) float or None
- As a speedup don't re- compute cross products
- sum : bool
- Return summed area or individual triangle area
- Returns
- ----------
- area : (n,) float or float
- Individual or summed area depending on `sum` argument
- """
- if crosses is None:
- crosses = cross(triangles)
- area = (np.sum(crosses**2, axis=1)**.5) * .5
- if sum:
- return np.sum(area)
- return area
-
-
-def sample_surface(triangles, count, area=None):
- """
- Sample the surface of a mesh, returning the specified
- number of points
- For individual triangle sampling uses this method:
- http://mathworld.wolfram.com/TrianglePointPicking.html
- Parameters
- ---------
- triangles : (n, 3, 3) float
- Vertices of triangles
- count : int
- Number of points to return
- Returns
- ---------
- samples : (count, 3) float
- Points in space on the surface of mesh
- face_index : (count,) int
- Indices of faces for each sampled point
- """
-
- # len(mesh.faces) float, array of the areas
- # of each face of the mesh
- if area is None:
- area = tri_area(triangles)
-
- # total area (float)
- area_sum = np.sum(area)
- # cumulative area (len(mesh.faces))
- area_cum = np.cumsum(area)
- face_pick = np.random.random(count) * area_sum
- face_index = np.searchsorted(area_cum, face_pick)
-
- # pull triangles into the form of an origin + 2 vectors
- tri_origins = triangles[:, 0]
- tri_vectors = triangles[:, 1:].copy()
- tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))
-
- # pull the vectors for the faces we are going to sample from
- tri_origins = tri_origins[face_index]
- tri_vectors = tri_vectors[face_index]
-
- # randomly generate two 0-1 scalar components to multiply edge vectors by
- random_lengths = np.random.random((len(tri_vectors), 2, 1))
-
- # points will be distributed on a quadrilateral if we use 2 0-1 samples
- # if the two scalar components sum less than 1.0 the point will be
- # inside the triangle, so we find vectors longer than 1.0 and
- # transform them to be inside the triangle
- random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0
- random_lengths[random_test] -= 1.0
- random_lengths = np.abs(random_lengths)
-
- # multiply triangle edge vectors by the random lengths and sum
- sample_vector = (tri_vectors * random_lengths).sum(axis=1)
-
- # finally, offset by the origin to generate
- # (n,3) points in space on the triangle
- samples = torch.tensor(sample_vector + tri_origins).float()
-
- return samples, face_index
-
-
def obj_loader(path, with_uv=True):
# Create reader.
reader = tinyobjloader.ObjReader()
@@ -424,8 +281,8 @@ def obj_loader(path, with_uv=True):
f_vt = tri[:, [2, 5, 8]]
if with_uv:
- face_uvs = vt[f_vt].mean(axis=1) #[m, 2]
- vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32) #[n, 2]
+ face_uvs = vt[f_vt].mean(axis=1) #[m, 2]
+ vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32) #[n, 2]
vert_uvs[f_v.reshape(-1)] = vt[f_vt.reshape(-1)]
return v, f_v, vert_uvs, face_uvs
@@ -434,7 +291,6 @@ def obj_loader(path, with_uv=True):
class HoppeMesh:
-
def __init__(self, verts, faces, uvs=None, texture=None):
"""
The HoppeSDF calculates signed distance towards a predefined oriented point cloud
@@ -459,34 +315,20 @@ class HoppeMesh:
- points: [n, 3]
- return: [n, 4] rgba
"""
- triangles = self.verts[faces] #[n, 3, 3]
- barycentric = trimesh.triangles.points_to_barycentric(triangles, points) #[n, 3]
- vert_colors = self.vertex_colors[faces] #[n, 3, 4]
+ triangles = self.verts[faces] #[n, 3, 3]
+ barycentric = trimesh.triangles.points_to_barycentric(triangles, points) #[n, 3]
+ vert_colors = self.vertex_colors[faces] #[n, 3, 4]
point_colors = torch.tensor((barycentric[:, :, None] * vert_colors).sum(axis=1)).float()
return point_colors
def triangles(self):
- return self.verts[self.faces].numpy() #[n, 3, 3]
+ return self.verts[self.faces].numpy() #[n, 3, 3]
def tensor2variable(tensor, device):
return tensor.requires_grad_(True).to(device)
-class GMoF(torch.nn.Module):
-
- def __init__(self, rho=1):
- super(GMoF, self).__init__()
- self.rho = rho
-
- def extra_repr(self):
- return "rho = {}".format(self.rho)
-
- def forward(self, residual):
- dist = torch.div(residual, residual + self.rho**2)
- return self.rho**2 * dist
-
-
def mesh_edge_loss(meshes, target_length: float = 0.0):
"""
Computes mesh edge length regularization loss averaged across all meshes
@@ -508,10 +350,10 @@ def mesh_edge_loss(meshes, target_length: float = 0.0):
return torch.tensor([0.0], dtype=torch.float32, device=meshes.device, requires_grad=True)
N = len(meshes)
- edges_packed = meshes.edges_packed() # (sum(E_n), 3)
- verts_packed = meshes.verts_packed() # (sum(V_n), 3)
- edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
- num_edges_per_mesh = meshes.num_edges_per_mesh() # N
+ edges_packed = meshes.edges_packed() # (sum(E_n), 3)
+ verts_packed = meshes.verts_packed() # (sum(V_n), 3)
+ edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
+ num_edges_per_mesh = meshes.num_edges_per_mesh() # N
# Determine the weight for each edge based on the number of edges in the
# mesh it corresponds to.
@@ -531,99 +373,37 @@ def mesh_edge_loss(meshes, target_length: float = 0.0):
return loss_all
-def remesh(obj, obj_path):
-
- obj.export(obj_path)
- ms = pymeshlab.MeshSet()
- ms.load_new_mesh(obj_path)
- # ms.meshing_decimation_quadric_edge_collapse(targetfacenum=100000)
- ms.meshing_isotropic_explicit_remeshing(targetlen=pymeshlab.Percentage(0.5), adaptive=True)
- ms.apply_coord_laplacian_smoothing()
- ms.save_current_mesh(obj_path[:-4] + "_remesh.obj")
- polished_mesh = trimesh.load_mesh(obj_path[:-4] + "_remesh.obj")
+def remesh_laplacian(mesh, obj_path):
- return polished_mesh
-
-
-def poisson_remesh(obj_path):
-
- ms = pymeshlab.MeshSet()
- ms.load_new_mesh(obj_path)
- ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000)
- # ms.apply_coord_laplacian_smoothing()
- ms.save_current_mesh(obj_path)
- # ms.save_current_mesh(obj_path.replace(".obj", ".ply"))
- polished_mesh = trimesh.load_mesh(obj_path)
+ mesh = mesh.simplify_quadratic_decimation(50000)
+ mesh = trimesh.smoothing.filter_humphrey(
+ mesh, alpha=0.1, beta=0.5, iterations=10, laplacian_operator=None
+ )
+ mesh.export(obj_path)
- return polished_mesh
+ return mesh
def poisson(mesh, obj_path, depth=10):
- from pypoisson import poisson_reconstruction
- faces, vertices = poisson_reconstruction(mesh.vertices, mesh.vertex_normals, depth=depth)
-
- new_meshes = trimesh.Trimesh(vertices, faces)
- new_mesh_lst = new_meshes.split(only_watertight=False)
- comp_num = [new_mesh.vertices.shape[0] for new_mesh in new_mesh_lst]
- final_mesh = new_mesh_lst[comp_num.index(max(comp_num))]
- final_mesh.export(obj_path)
+ pcd_path = obj_path[:-4] + ".ply"
+ assert (mesh.vertex_normals.shape[1] == 3)
+ mesh.export(pcd_path)
+ pcl = o3d.io.read_point_cloud(pcd_path)
+ with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm:
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
+ pcl, depth=depth, n_threads=-1
+ )
+ print(colored(f"\n Poisson completion to {Format.start} {obj_path} {Format.end}", "yellow"))
- final_mesh = poisson_remesh(obj_path)
+ # only keep the largest component
+ largest_mesh = keep_largest(trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles)))
+ largest_mesh.export(obj_path)
- return final_mesh
+ # mesh decimation for faster rendering
+ low_res_mesh = largest_mesh.simplify_quadratic_decimation(50000)
-
-def get_mask(tensor, dim):
-
- mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0
- mask = mask.type_as(tensor)
-
- return mask
-
-
-def blend_rgb_norm(norms, data):
-
- # norms [N, 3, res, res]
-
- masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
- norm_mask = F.interpolate(
- torch.cat([norms, masks], dim=1).detach().cpu(),
- size=data["uncrop_param"]["box_shape"],
- mode="bilinear",
- align_corners=False).permute(0, 2, 3, 1).numpy()
- final = data["img_raw"]
-
- for idx in range(len(norms)):
-
- norm_pred = (norm_mask[idx, :, :, :3] + 1.0) * 255.0 / 2.0
- mask_pred = np.repeat(norm_mask[idx, :, :, 3:4], 3, axis=-1)
-
- norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
- mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)
-
- final = final * (1.0 - mask_ori) + norm_ori * mask_ori
-
- return final.astype(np.uint8)
-
-
-def unwrap(image, uncrop_param, idx):
-
- img_uncrop = uncrop(
- image,
- uncrop_param["center"][idx],
- uncrop_param["scale"][idx],
- uncrop_param["crop_shape"],
- )
-
- img_orig = cv2.warpAffine(
- img_uncrop,
- np.linalg.inv(uncrop_param["M"])[:2, :],
- uncrop_param["ori_shape"][::-1],
- flags=cv2.INTER_CUBIC,
- )
-
- return img_orig
+ return low_res_mesh
# Losses to smooth / regularize the mesh shape
@@ -634,60 +414,7 @@ def update_mesh_shape_prior_losses(mesh, losses):
# mesh normal consistency
losses["nc"]["value"] = mesh_normal_consistency(mesh)
# mesh laplacian smoothing
- losses["laplacian"]["value"] = mesh_laplacian_smoothing(mesh, method="uniform")
-
-
-def rename(old_dict, old_name, new_name):
- new_dict = {}
- for key, value in zip(old_dict.keys(), old_dict.values()):
- new_key = key if key != old_name else new_name
- new_dict[new_key] = old_dict[key]
- return new_dict
-
-
-def load_checkpoint(model, cfg):
-
- model_dict = model.state_dict()
- main_dict = {}
- normal_dict = {}
-
- device = torch.device(f"cuda:{cfg['test_gpus'][0]}")
-
- if os.path.exists(cfg.resume_path) and cfg.resume_path.endswith("ckpt"):
- main_dict = torch.load(cfg.resume_path, map_location=device)["state_dict"]
-
- main_dict = {
- k: v for k, v in main_dict.items() if k in model_dict and v.shape == model_dict[k].shape and
- ("reconEngine" not in k) and ("normal_filter" not in k) and ("voxelization" not in k)
- }
- print(colored(f"Resume MLP weights from {cfg.resume_path}", "green"))
-
- if os.path.exists(cfg.normal_path) and cfg.normal_path.endswith("ckpt"):
- normal_dict = torch.load(cfg.normal_path, map_location=device)["state_dict"]
-
- for key in normal_dict.keys():
- normal_dict = rename(normal_dict, key, key.replace("netG", "netG.normal_filter"))
-
- normal_dict = {k: v for k, v in normal_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
- print(colored(f"Resume normal model from {cfg.normal_path}", "green"))
-
- model_dict.update(main_dict)
- model_dict.update(normal_dict)
- model.load_state_dict(model_dict)
-
- model.netG = model.netG.to(device)
- model.reconEngine = model.reconEngine.to(device)
-
- model.netG.training = False
- model.netG.eval()
-
- del main_dict
- del normal_dict
- del model_dict
-
- torch.cuda.empty_cache()
-
- return model
+ losses["lapla"]["value"] = mesh_laplacian_smoothing(mesh, method="uniform")
def read_smpl_constants(folder):
@@ -706,8 +433,10 @@ def read_smpl_constants(folder):
smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
"""Load smpl faces & tetrahedrons"""
smpl_faces = np.loadtxt(os.path.join(folder, "faces.txt"), dtype=np.int32) - 1
- smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] + smpl_vertex_code[smpl_faces[:, 1]] +
- smpl_vertex_code[smpl_faces[:, 2]]) / 3.0
+ smpl_face_code = (
+ smpl_vertex_code[smpl_faces[:, 0]] + smpl_vertex_code[smpl_faces[:, 1]] +
+ smpl_vertex_code[smpl_faces[:, 2]]
+ ) / 3.0
smpl_tetras = (np.loadtxt(os.path.join(folder, "tetrahedrons.txt"), dtype=np.int32) - 1)
return_dict = {
@@ -720,19 +449,6 @@ def read_smpl_constants(folder):
return return_dict
-def feat_select(feat, select):
-
- # feat [B, featx2, N]
- # select [B, 1, N]
- # return [B, feat, N]
-
- dim = feat.shape[1] // 2
- idx = torch.tile((1 - select), (1, dim, 1)) * dim + torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select)
- feat_select = torch.gather(feat, 1, idx.long())
-
- return feat_select
-
-
def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel=1):
"""get the visibility of vertices
@@ -771,7 +487,9 @@ def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel
for idx in range(N_body):
Num_faces = len(faces[idx])
- vis_vertices_id = torch.unique(faces[idx][torch.unique(pix_to_face[idx][pix_to_face[idx] != -1]) - Num_faces * idx, :])
+ vis_vertices_id = torch.unique(
+ faces[idx][torch.unique(pix_to_face[idx][pix_to_face[idx] != -1]) - Num_faces * idx, :]
+ )
vis_mask[idx, vis_vertices_id] = 1.0
# print("------------------------\n")
@@ -825,7 +543,7 @@ def orthogonal(points, calibrations, transforms=None):
"""
rot = calibrations[:, :3, :3]
trans = calibrations[:, :3, 3:4]
- pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
+ pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
if transforms is not None:
scale = transforms[:2, :2]
shift = transforms[:2, 2:3]
@@ -925,37 +643,14 @@ def compute_normal_batch(vertices, faces):
return vert_norm
-def calculate_mIoU(outputs, labels):
-
- SMOOTH = 1e-6
-
- outputs = outputs.int()
- labels = labels.int()
-
- intersection = ((outputs & labels).float().sum()) # Will be zero if Truth=0 or Prediction=0
- union = (outputs | labels).float().sum() # Will be zzero if both are 0
-
- iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our devision to avoid 0/0
-
- thresholded = (torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10) # This is equal to comparing with thresolds
-
- return (thresholded.mean().detach().cpu().numpy()
- ) # Or thresholded.mean() if you are interested in average across the batch
-
-
-def add_alpha(colors, alpha=0.7):
-
- colors_pad = np.pad(colors, ((0, 0), (0, 1)), mode="constant", constant_values=alpha)
-
- return colors_pad
-
-
def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
font = ImageFont.truetype(font_path, 30)
grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), nrow=nrow, padding=0)
- grid_img = Image.fromarray(((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * 255.0).astype(np.uint8))
+ grid_img = Image.fromarray(
+ ((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * 255.0).astype(np.uint8)
+ )
if False:
# add text
@@ -965,16 +660,20 @@ def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
if type == "smpl":
- for col_id, col_txt in enumerate([
+ for col_id, col_txt in enumerate(
+ [
"image",
"smpl-norm(render)",
"cloth-norm(pred)",
"diff-norm",
"diff-mask",
- ]):
+ ]
+ ):
draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
elif type == "cloth":
- for col_id, col_txt in enumerate(["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"]):
+ for col_id, col_txt in enumerate(
+ ["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"]
+ ):
draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
for col_id, col_txt in enumerate(["0", "90", "180", "270"]):
draw.text(
@@ -996,12 +695,9 @@ def clean_mesh(verts, faces):
device = verts.device
mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(), faces.detach().cpu().numpy())
- mesh_lst = mesh_lst.split(only_watertight=False)
- comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst]
-
- mesh_clean = mesh_lst[comp_num.index(max(comp_num))]
- final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device)
- final_faces = torch.as_tensor(mesh_clean.faces).long().to(device)
+ largest_mesh = keep_largest(mesh_lst)
+ final_verts = torch.as_tensor(largest_mesh.vertices).float().to(device)
+ final_faces = torch.as_tensor(largest_mesh.faces).long().to(device)
return final_verts, final_faces
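The rewritten `poisson()` above replaces pypoisson with Open3D's screened Poisson reconstruction and then keeps only the largest connected component. A minimal in-memory sketch of that call chain follows; `poisson_complete` is a hypothetical name, the point cloud is built directly from arrays rather than via the `.ply` round-trip the repository uses, and the final component filter stands in for the repository's `keep_largest()` helper.

```python
import numpy as np
import open3d as o3d
import trimesh

def poisson_complete(verts, normals, depth=10):
    """Oriented points -> screened Poisson mesh -> largest connected component."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(verts, dtype=np.float64))
    pcd.normals = o3d.utility.Vector3dVector(np.asarray(normals, dtype=np.float64))

    # screened Poisson surface reconstruction (returns mesh + per-vertex densities)
    o3d_mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=depth
    )

    # back to trimesh; keep only the largest connected component
    mesh = trimesh.Trimesh(np.asarray(o3d_mesh.vertices), np.asarray(o3d_mesh.triangles))
    components = mesh.split(only_watertight=False)
    return max(components, key=lambda m: len(m.vertices))
```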
diff --git a/lib/net/BasePIFuNet.py b/lib/net/BasePIFuNet.py
index 6793a1d771fe6be62d38d9a9ea621002195b8ab9..eb18dbb3245d57c9e030c18094322a58e874db93 100644
--- a/lib/net/BasePIFuNet.py
+++ b/lib/net/BasePIFuNet.py
@@ -21,11 +21,10 @@ from .geometry import index, orthogonal, perspective
class BasePIFuNet(pl.LightningModule):
-
def __init__(
- self,
- projection_mode="orthogonal",
- error_term=nn.MSELoss(),
+ self,
+ projection_mode="orthogonal",
+ error_term=nn.MSELoss(),
):
"""
:param projection_mode:
diff --git a/lib/net/Discriminator.py b/lib/net/Discriminator.py
index 83dc1ac393e2fb9130b3f8904f84e76f83329f98..c60acdde000d414c78af0705ba268af3117c6ec9 100644
--- a/lib/net/Discriminator.py
+++ b/lib/net/Discriminator.py
@@ -9,17 +9,18 @@ from lib.torch_utils.ops.native_ops import FusedLeakyReLU, fused_leaky_relu, upf
class DiscriminatorHead(nn.Module):
-
def __init__(self, in_channel, disc_stddev=False):
super().__init__()
self.disc_stddev = disc_stddev
stddev_dim = 1 if disc_stddev else 0
- self.conv_stddev = ConvLayer2d(in_channel=in_channel + stddev_dim,
- out_channel=in_channel,
- kernel_size=3,
- activate=True)
+ self.conv_stddev = ConvLayer2d(
+ in_channel=in_channel + stddev_dim,
+ out_channel=in_channel,
+ kernel_size=3,
+ activate=True
+ )
self.final_linear = nn.Sequential(
nn.Flatten(),
@@ -32,8 +33,8 @@ class DiscriminatorHead(nn.Module):
inv_perm = torch.argsort(perm)
batch, channel, height, width = x.shape
- x = x[
- perm] # shuffle inputs so that all views in a single trajectory don't get put together
+ x = x[perm
+ ] # shuffle inputs so that all views in a single trajectory don't get put together
group = min(batch, stddev_group)
stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
@@ -41,7 +42,7 @@ class DiscriminatorHead(nn.Module):
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
stddev = stddev.repeat(group, 1, height, width)
- stddev = stddev[inv_perm] # reorder inputs
+ stddev = stddev[inv_perm] # reorder inputs
x = x[inv_perm]
out = torch.cat([x, stddev], 1)
@@ -56,7 +57,6 @@ class DiscriminatorHead(nn.Module):
class ConvDecoder(nn.Module):
-
def __init__(self, in_channel, out_channel, in_res, out_res):
super().__init__()
@@ -68,20 +68,22 @@ class ConvDecoder(nn.Module):
for i in range(log_size_in, log_size_out):
out_ch = in_ch // 2
self.layers.append(
- ConvLayer2d(in_channel=in_ch,
- out_channel=out_ch,
- kernel_size=3,
- upsample=True,
- bias=True,
- activate=True))
+ ConvLayer2d(
+ in_channel=in_ch,
+ out_channel=out_ch,
+ kernel_size=3,
+ upsample=True,
+ bias=True,
+ activate=True
+ )
+ )
in_ch = out_ch
self.layers.append(
- ConvLayer2d(in_channel=in_ch,
- out_channel=out_channel,
- kernel_size=3,
- bias=True,
- activate=False))
+ ConvLayer2d(
+ in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False
+ )
+ )
self.layers = nn.Sequential(*self.layers)
def forward(self, x):
@@ -89,7 +91,6 @@ class ConvDecoder(nn.Module):
class StyleDiscriminator(nn.Module):
-
def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
super().__init__()
@@ -104,7 +105,8 @@ class StyleDiscriminator(nn.Module):
for i in range(log_size_in, log_size_out, -1):
out_channels = int(min(in_channels * 2, ch_max))
self.layers.append(
- ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True))
+ ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)
+ )
in_channels = out_channels
self.layers = nn.Sequential(*self.layers)
@@ -147,7 +149,6 @@ class Blur(nn.Module):
Upsample factor.
"""
-
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
@@ -177,7 +178,6 @@ class Upsample(nn.Module):
Upsampling factor.
"""
-
def __init__(self, kernel=[1, 3, 3, 1], factor=2):
super().__init__()
@@ -208,7 +208,6 @@ class Downsample(nn.Module):
Downsampling factor.
"""
-
def __init__(self, kernel=[1, 3, 3, 1], factor=2):
super().__init__()
@@ -250,7 +249,6 @@ class EqualLinear(nn.Module):
Apply leakyReLU activation.
"""
-
def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
super().__init__()
@@ -300,7 +298,6 @@ class EqualConv2d(nn.Module):
Use bias term.
"""
-
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
super().__init__()
@@ -316,16 +313,20 @@ class EqualConv2d(nn.Module):
self.bias = None
def forward(self, input):
- out = F.conv2d(input,
- self.weight * self.scale,
- bias=self.bias,
- stride=self.stride,
- padding=self.padding)
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding
+ )
return out
def __repr__(self):
- return (f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
- f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})")
+ return (
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
+ )
class EqualConvTranspose2d(nn.Module):
@@ -353,15 +354,16 @@ class EqualConvTranspose2d(nn.Module):
Use bias term.
"""
-
- def __init__(self,
- in_channel,
- out_channel,
- kernel_size,
- stride=1,
- padding=0,
- output_padding=0,
- bias=True):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ bias=True
+ ):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
@@ -388,12 +390,13 @@ class EqualConvTranspose2d(nn.Module):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
- f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})')
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
class ConvLayer2d(nn.Sequential):
-
def __init__(
self,
in_channel,
@@ -415,12 +418,15 @@ class ConvLayer2d(nn.Sequential):
pad1 = p // 2 + 1
layers.append(
- EqualConvTranspose2d(in_channel,
- out_channel,
- kernel_size,
- padding=0,
- stride=2,
- bias=bias and not activate))
+ EqualConvTranspose2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=0,
+ stride=2,
+ bias=bias and not activate
+ )
+ )
layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))
if downsample:
@@ -431,23 +437,29 @@ class ConvLayer2d(nn.Sequential):
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
layers.append(
- EqualConv2d(in_channel,
- out_channel,
- kernel_size,
- padding=0,
- stride=2,
- bias=bias and not activate))
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=0,
+ stride=2,
+ bias=bias and not activate
+ )
+ )
if (not downsample) and (not upsample):
padding = kernel_size // 2
layers.append(
- EqualConv2d(in_channel,
- out_channel,
- kernel_size,
- padding=padding,
- stride=1,
- bias=bias and not activate))
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=padding,
+ stride=1,
+ bias=bias and not activate
+ )
+ )
if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias))
@@ -472,7 +484,6 @@ class ConvResBlock2d(nn.Module):
Apply downsampling via strided convolution in the second conv.
"""
-
def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
super().__init__()
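`DiscriminatorHead.concat_stddev` above appends a minibatch standard-deviation channel before the final convolution. The sketch below reproduces that statistic in isolation (the view shuffling with `perm`/`inv_perm` and the class wrapper are omitted; the helper name is illustrative):

```python
import torch

def minibatch_stddev(x, stddev_group=4, stddev_feat=1):
    """Append one channel holding the group-wise feature standard deviation."""
    batch, channel, height, width = x.shape
    group = min(batch, stddev_group)
    y = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
    y = torch.sqrt(y.var(dim=0, unbiased=False) + 1e-8)     # std over the group
    y = y.mean(dim=[2, 3, 4], keepdim=True).squeeze(2)       # one scalar per sub-batch
    y = y.repeat(group, 1, height, width)                    # broadcast back
    return torch.cat([x, y], dim=1)

out = minibatch_stddev(torch.randn(8, 16, 4, 4))
print(out.shape)    # torch.Size([8, 17, 4, 4])
```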
diff --git a/lib/net/FBNet.py b/lib/net/FBNet.py
index 0122d2fb47aa7316075d29f10cd8fe8012e7862a..f4797667d4d800019967d7ee2ed944ec8b8550fc 100644
--- a/lib/net/FBNet.py
+++ b/lib/net/FBNet.py
@@ -51,17 +51,17 @@ def get_norm_layer(norm_type="instance"):
def define_G(
- input_nc,
- output_nc,
- ngf,
- netG,
- n_downsample_global=3,
- n_blocks_global=9,
- n_local_enhancers=1,
- n_blocks_local=3,
- norm="instance",
- gpu_ids=[],
- last_op=nn.Tanh(),
+ input_nc,
+ output_nc,
+ ngf,
+ netG,
+ n_downsample_global=3,
+ n_blocks_global=9,
+ n_local_enhancers=1,
+ n_blocks_local=3,
+ norm="instance",
+ gpu_ids=[],
+ last_op=nn.Tanh(),
):
norm_layer = get_norm_layer(norm_type=norm)
if netG == "global":
@@ -97,17 +97,20 @@ def define_G(
return netG
-def define_D(input_nc,
- ndf,
- n_layers_D,
- norm='instance',
- use_sigmoid=False,
- num_D=1,
- getIntermFeat=False,
- gpu_ids=[]):
+def define_D(
+ input_nc,
+ ndf,
+ n_layers_D,
+ norm='instance',
+ use_sigmoid=False,
+ num_D=1,
+ getIntermFeat=False,
+ gpu_ids=[]
+):
norm_layer = get_norm_layer(norm_type=norm)
- netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D,
- getIntermFeat)
+ netD = MultiscaleDiscriminator(
+ input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat
+ )
if len(gpu_ids) > 0:
assert (torch.cuda.is_available())
netD.cuda(gpu_ids[0])
@@ -129,7 +132,6 @@ def print_network(net):
# Generator
##############################################################################
class LocalEnhancer(pl.LightningModule):
-
def __init__(
self,
input_nc,
@@ -155,8 +157,9 @@ class LocalEnhancer(pl.LightningModule):
n_blocks_global,
norm_layer,
).model
- model_global = [model_global[i] for i in range(len(model_global) - 3)
- ] # get rid of final convolution layers
+ model_global = [
+ model_global[i] for i in range(len(model_global) - 3)
+ ] # get rid of final convolution layers
self.model = nn.Sequential(*model_global)
###### local enhancer layers #####
@@ -224,17 +227,16 @@ class LocalEnhancer(pl.LightningModule):
class GlobalGenerator(pl.LightningModule):
-
def __init__(
- self,
- input_nc,
- output_nc,
- ngf=64,
- n_downsampling=3,
- n_blocks=9,
- norm_layer=nn.BatchNorm2d,
- padding_type="reflect",
- last_op=nn.Tanh(),
+ self,
+ input_nc,
+ output_nc,
+ ngf=64,
+ n_downsampling=3,
+ n_blocks=9,
+ norm_layer=nn.BatchNorm2d,
+ padding_type="reflect",
+ last_op=nn.Tanh(),
):
assert n_blocks >= 0
super(GlobalGenerator, self).__init__()
@@ -296,42 +298,49 @@ class GlobalGenerator(pl.LightningModule):
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
-
- def __init__(self,
- input_nc,
- ndf=64,
- n_layers=3,
- norm_layer=nn.BatchNorm2d,
- use_sigmoid=False,
- getIntermFeat=False):
+ def __init__(
+ self,
+ input_nc,
+ ndf=64,
+ n_layers=3,
+ norm_layer=nn.BatchNorm2d,
+ use_sigmoid=False,
+ getIntermFeat=False
+ ):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw - 1.0) / 2))
- sequence = [[
- nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
- nn.LeakyReLU(0.2, True)
- ]]
+ sequence = [
+ [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+ nn.LeakyReLU(0.2, True)
+ ]
+ ]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
- sequence += [[
- nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
- norm_layer(nf),
- nn.LeakyReLU(0.2, True)
- ]]
+ sequence += [
+ [
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
+ norm_layer(nf),
+ nn.LeakyReLU(0.2, True)
+ ]
+ ]
nf_prev = nf
nf = min(nf * 2, 512)
- sequence += [[
- nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
- norm_layer(nf),
- nn.LeakyReLU(0.2, True)
- ]]
+ sequence += [
+ [
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+ norm_layer(nf),
+ nn.LeakyReLU(0.2, True)
+ ]
+ ]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
@@ -359,27 +368,30 @@ class NLayerDiscriminator(nn.Module):
class MultiscaleDiscriminator(pl.LightningModule):
-
- def __init__(self,
- input_nc,
- ndf=64,
- n_layers=3,
- norm_layer=nn.BatchNorm2d,
- use_sigmoid=False,
- num_D=3,
- getIntermFeat=False):
+ def __init__(
+ self,
+ input_nc,
+ ndf=64,
+ n_layers=3,
+ norm_layer=nn.BatchNorm2d,
+ use_sigmoid=False,
+ num_D=3,
+ getIntermFeat=False
+ ):
super(MultiscaleDiscriminator, self).__init__()
self.num_D = num_D
self.n_layers = n_layers
self.getIntermFeat = getIntermFeat
for i in range(num_D):
- netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid,
- getIntermFeat)
+ netD = NLayerDiscriminator(
+ input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
+ )
if getIntermFeat:
for j in range(n_layers + 2):
- setattr(self, 'scale' + str(i) + '_layer' + str(j),
- getattr(netD, 'model' + str(j)))
+ setattr(
+ self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))
+ )
else:
setattr(self, 'layer' + str(i), netD.model)
@@ -414,11 +426,11 @@ class MultiscaleDiscriminator(pl.LightningModule):
# Define a resnet block
class ResnetBlock(pl.LightningModule):
-
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
super(ResnetBlock, self).__init__()
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation,
- use_dropout)
+ self.conv_block = self.build_conv_block(
+ dim, padding_type, norm_layer, activation, use_dropout
+ )
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
conv_block = []
@@ -459,7 +471,6 @@ class ResnetBlock(pl.LightningModule):
class Encoder(pl.LightningModule):
-
def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
super(Encoder, self).__init__()
self.output_nc = output_nc
@@ -510,18 +521,17 @@ class Encoder(pl.LightningModule):
inst_list = np.unique(inst.cpu().numpy().astype(int))
for i in inst_list:
for b in range(input.size()[0]):
- indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4
+ indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4
for j in range(self.output_nc):
output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
- indices[:, 3],]
+ indices[:, 3], ]
mean_feat = torch.mean(output_ins).expand_as(output_ins)
outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
- indices[:, 3],] = mean_feat
+ indices[:, 3], ] = mean_feat
return outputs_mean
class Vgg19(nn.Module):
-
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
@@ -555,7 +565,6 @@ class Vgg19(nn.Module):
class VGG19FeatLayer(nn.Module):
-
def __init__(self):
super(VGG19FeatLayer, self).__init__()
self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
@@ -593,7 +602,6 @@ class VGG19FeatLayer(nn.Module):
class VGGLoss(pl.LightningModule):
-
def __init__(self):
super(VGGLoss, self).__init__()
self.vgg = Vgg19().eval()
@@ -609,11 +617,7 @@ class VGGLoss(pl.LightningModule):
class GANLoss(pl.LightningModule):
-
- def __init__(self,
- use_lsgan=True,
- target_real_label=1.0,
- target_fake_label=0.0):
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
@@ -628,16 +632,18 @@ class GANLoss(pl.LightningModule):
def get_target_tensor(self, input, target_is_real):
target_tensor = None
if target_is_real:
- create_label = ((self.real_label_var is None) or
- (self.real_label_var.numel() != input.numel()))
+ create_label = (
+ (self.real_label_var is None) or (self.real_label_var.numel() != input.numel())
+ )
if create_label:
real_tensor = self.tensor(input.size()).fill_(self.real_label)
self.real_label_var = real_tensor
self.real_label_var.requires_grad = False
target_tensor = self.real_label_var
else:
- create_label = ((self.fake_label_var is None) or
- (self.fake_label_var.numel() != input.numel()))
+ create_label = (
+ (self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())
+ )
if create_label:
fake_tensor = self.tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = fake_tensor
@@ -659,7 +665,6 @@ class GANLoss(pl.LightningModule):
class IDMRFLoss(pl.LightningModule):
-
def __init__(self, featlayer=VGG19FeatLayer):
super(IDMRFLoss, self).__init__()
self.featlayer = featlayer()
@@ -678,7 +683,8 @@ class IDMRFLoss(pl.LightningModule):
patch_size = 1
patch_stride = 1
patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(
- 3, patch_size, patch_stride)
+ 3, patch_size, patch_stride
+ )
self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
dims = self.patches_OIHW.size()
self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
@@ -743,7 +749,8 @@ class IDMRFLoss(pl.LightningModule):
self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer])
for layer in self.feat_content_layers
]
- self.content_loss = functools.reduce(lambda x, y: x + y,
- content_loss_list) * self.lambda_content
+ self.content_loss = functools.reduce(
+ lambda x, y: x + y, content_loss_list
+ ) * self.lambda_content
return self.style_loss + self.content_loss
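`VGGLoss` above compares VGG19 feature maps of the prediction and the target. A stripped-down sketch of that perceptual loss is shown below; the slice indices and layer weights are assumptions for illustration and do not match the repository's `Vgg19` split or its weighting.

```python
import torch
import torch.nn as nn
from torchvision import models

class TinyVGGLoss(nn.Module):
    """Minimal perceptual loss: L1 between frozen VGG19 features of x and y."""
    def __init__(self):
        super().__init__()
        feats = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
        for p in feats.parameters():
            p.requires_grad = False
        # cut points and weights below are illustrative, not the repo's choices
        self.slices = nn.ModuleList([feats[:2], feats[2:7], feats[7:12]])
        self.weights = [1.0, 0.5, 0.25]
        self.l1 = nn.L1Loss()

    def forward(self, x, y):
        loss = 0.0
        for w, block in zip(self.weights, self.slices):
            x, y = block(x), block(y)
            loss = loss + w * self.l1(x, y.detach())
        return loss
```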
diff --git a/lib/net/GANLoss.py b/lib/net/GANLoss.py
index 9be907f5c3f74a3a05fd9a52913325ce54b09a9f..5d6711479980e89a3fc067b5ef579bb382eb29df 100644
--- a/lib/net/GANLoss.py
+++ b/lib/net/GANLoss.py
@@ -32,13 +32,12 @@ def logistic_loss(fake_pred, real_pred, mode):
def r1_loss(real_pred, real_img):
- (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
+ (grad_real, ) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
class GANLoss(nn.Module):
-
def __init__(
self,
opt,
@@ -64,7 +63,7 @@ class GANLoss(nn.Module):
logits_fake = self.discriminator(disc_in_fake)
disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')
-
+
log = {
"disc_loss": disc_loss.detach(),
"logits_real": logits_real.mean().detach(),
diff --git a/lib/net/IFGeoNet.py b/lib/net/IFGeoNet.py
index 195953d0ed91aa7663040dadad3a757bf1086699..a72be083da26093093fb2da46dade7ace3df1bae 100644
--- a/lib/net/IFGeoNet.py
+++ b/lib/net/IFGeoNet.py
@@ -8,20 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX
class SelfAttention(torch.nn.Module):
-
def __init__(self, in_channels, out_channels):
super().__init__()
- self.conv = nn.Conv3d(in_channels,
- out_channels,
- 3,
- padding=1,
- padding_mode='replicate')
- self.attention = nn.Conv3d(in_channels,
- out_channels,
- kernel_size=3,
- padding=1,
- padding_mode='replicate',
- bias=False)
+ self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate')
+ self.attention = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ padding_mode='replicate',
+ bias=False
+ )
with torch.no_grad():
self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
@@ -32,38 +29,45 @@ class SelfAttention(torch.nn.Module):
class IFGeoNet(nn.Module):
-
def __init__(self, cfg, hidden_dim=256):
super(IFGeoNet, self).__init__()
- self.conv_in_partial = nn.Conv3d(1, 16, 3, padding=1,
- padding_mode='replicate') # out: 256 ->m.p. 128
+ self.conv_in_partial = nn.Conv3d(
+ 1, 16, 3, padding=1, padding_mode='replicate'
+ ) # out: 256 ->m.p. 128
- self.conv_in_smpl = nn.Conv3d(1, 4, 3, padding=1,
- padding_mode='replicate') # out: 256 ->m.p. 128
+ self.conv_in_smpl = nn.Conv3d(
+ 1, 4, 3, padding=1, padding_mode='replicate'
+ ) # out: 256 ->m.p. 128
self.SA = SelfAttention(4, 4)
- self.conv_0_fusion = nn.Conv3d(16 + 4, 32, 3, padding=1,
- padding_mode='replicate') # out: 128
- self.conv_0_1_fusion = nn.Conv3d(32, 32, 3, padding=1,
- padding_mode='replicate') # out: 128 ->m.p. 64
-
- self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
- self.conv_0_1 = nn.Conv3d(32, 32, 3, padding=1,
- padding_mode='replicate') # out: 128 ->m.p. 64
-
- self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
- self.conv_1_1 = nn.Conv3d(64, 64, 3, padding=1,
- padding_mode='replicate') # out: 64 -> mp 32
-
- self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
- self.conv_2_1 = nn.Conv3d(128, 128, 3, padding=1,
- padding_mode='replicate') # out: 32 -> mp 16
- self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
- self.conv_3_1 = nn.Conv3d(128, 128, 3, padding=1,
- padding_mode='replicate') # out: 16 -> mp 8
- self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
- self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
+ self.conv_0_fusion = nn.Conv3d(
+ 16 + 4, 32, 3, padding=1, padding_mode='replicate'
+ ) # out: 128
+ self.conv_0_1_fusion = nn.Conv3d(
+ 32, 32, 3, padding=1, padding_mode='replicate'
+ ) # out: 128 ->m.p. 64
+
+ self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
+ self.conv_0_1 = nn.Conv3d(
+ 32, 32, 3, padding=1, padding_mode='replicate'
+ ) # out: 128 ->m.p. 64
+
+ self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
+ self.conv_1_1 = nn.Conv3d(
+ 64, 64, 3, padding=1, padding_mode='replicate'
+ ) # out: 64 -> mp 32
+
+ self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
+ self.conv_2_1 = nn.Conv3d(
+ 128, 128, 3, padding=1, padding_mode='replicate'
+ ) # out: 32 -> mp 16
+ self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
+ self.conv_3_1 = nn.Conv3d(
+ 128, 128, 3, padding=1, padding_mode='replicate'
+ ) # out: 16 -> mp 8
+ self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
+ self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
@@ -97,21 +101,21 @@ class IFGeoNet(nn.Module):
smooth_kernel_size=7,
batch_size=cfg.batch_size,
)
-
+
self.l1_loss = nn.SmoothL1Loss()
def forward(self, batch):
-
+
if "body_voxels" in batch.keys():
x_smpl = batch["body_voxels"]
else:
with torch.no_grad():
self.voxelization.update_param(batch["voxel_faces"])
- x_smpl = self.voxelization(batch["voxel_verts"])[:, 0] #[B, 128, 128, 128]
-
+ x_smpl = self.voxelization(batch["voxel_verts"])[:, 0] #[B, 128, 128, 128]
+
p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
- batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
- x = batch["depth_voxels"] #[B, 128, 128, 128]
+ batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
+ x = batch["depth_voxels"] #[B, 128, 128, 128]
x = x.unsqueeze(1)
x_smpl = x_smpl.unsqueeze(1)
@@ -119,63 +123,67 @@ class IFGeoNet(nn.Module):
p = p.unsqueeze(1).unsqueeze(1)
# partial inputs feature extraction
- feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners = True)
+ feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True)
net_partial = self.actvn(self.conv_in_partial(x))
net_partial = self.partial_conv_in_bn(net_partial)
- net_partial = self.maxpool(net_partial) # out 64
+ net_partial = self.maxpool(net_partial) # out 64
# smpl inputs feature extraction
# feature_0_smpl = F.grid_sample(x_smpl, p, padding_mode='border', align_corners = True)
net_smpl = self.actvn(self.conv_in_smpl(x_smpl))
net_smpl = self.smpl_conv_in_bn(net_smpl)
- net_smpl = self.maxpool(net_smpl) # out 64
+ net_smpl = self.maxpool(net_smpl) # out 64
net_smpl = self.SA(net_smpl)
-
+
# Feature fusion
net = self.actvn(self.conv_0_fusion(torch.concat([net_partial, net_smpl], dim=1)))
net = self.actvn(self.conv_0_1_fusion(net))
net = self.conv0_1_bn_fusion(net)
- feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners = True)
+ feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners=True)
# net = self.maxpool(net) # out 64
net = self.actvn(self.conv_0(net))
net = self.actvn(self.conv_0_1(net))
net = self.conv0_1_bn(net)
- feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
- net = self.maxpool(net) # out 32
+ feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
+ net = self.maxpool(net) # out 32
net = self.actvn(self.conv_1(net))
net = self.actvn(self.conv_1_1(net))
net = self.conv1_1_bn(net)
- feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
- net = self.maxpool(net) # out 16
+ feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
+ net = self.maxpool(net) # out 16
net = self.actvn(self.conv_2(net))
net = self.actvn(self.conv_2_1(net))
net = self.conv2_1_bn(net)
- feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
- net = self.maxpool(net) # out 8
+ feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
+ net = self.maxpool(net) # out 8
net = self.actvn(self.conv_3(net))
net = self.actvn(self.conv_3_1(net))
net = self.conv3_1_bn(net)
- feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
- net = self.maxpool(net) # out 4
+ feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
+ net = self.maxpool(net) # out 4
net = self.actvn(self.conv_4(net))
net = self.actvn(self.conv_4_1(net))
net = self.conv4_1_bn(net)
- feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners = True) # out 2
+ feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2
# here every channel corresponse to one feature.
- features = torch.cat((feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4,
- feature_5, feature_6),
- dim=1) # (B, features, 1,7,sample_num)
+ features = torch.cat(
+ (
+ feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
+ feature_6
+ ),
+ dim=1
+ ) # (B, features, 1,7,sample_num)
shape = features.shape
features = torch.reshape(
- features,
- (shape[0], shape[1] * shape[3], shape[4])) # (B, featues_per_sample, samples_num)
+ features, (shape[0], shape[1] * shape[3], shape[4])
+        ) # (B, features_per_sample, samples_num)
# (B, featue_size, samples_num)
features = torch.cat((features, p_features), dim=1)
@@ -183,7 +191,7 @@ class IFGeoNet(nn.Module):
net = self.actvn(self.fc_1(net))
net = self.actvn(self.fc_2(net))
net = self.fc_out(net).squeeze(1)
-
+
return net
def compute_loss(self, prds, tgts):
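`IFGeoNet.forward` above repeatedly queries 3D feature volumes at the projected sample points with `F.grid_sample`, packing the N query points into the last spatial axis of a 5D grid. A minimal shape-level sketch (all sizes and tensors here are made up; only the call pattern mirrors the network):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; the real network uses 128^3 voxel grids and the
# orthogonal() projection to bring query points into [-1, 1].
B, C, R, N = 2, 16, 32, 1000
feature_volume = torch.randn(B, C, R, R, R)            # [B, C, D, H, W]
points = torch.rand(B, N, 3) * 2.0 - 1.0               # queries in [-1, 1]^3

# grid_sample expects a 5D grid [B, D_out, H_out, W_out, 3],
# so the N query points are packed into the last spatial axis.
grid = points.unsqueeze(1).unsqueeze(1)                 # [B, 1, 1, N, 3]
sampled = F.grid_sample(feature_volume, grid,
                        padding_mode='border', align_corners=True)
point_features = sampled.reshape(B, C, N)               # per-point features
print(point_features.shape)                             # torch.Size([2, 16, 1000])
```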
diff --git a/lib/net/IFGeoNet_nobody.py b/lib/net/IFGeoNet_nobody.py
index bf83b5c09557294a025f975241068f3cf03d19b6..ceda5dedfcf09167f670a66a91b152f6181631cc 100644
--- a/lib/net/IFGeoNet_nobody.py
+++ b/lib/net/IFGeoNet_nobody.py
@@ -8,16 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX
class SelfAttention(torch.nn.Module):
-
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate')
- self.attention = nn.Conv3d(in_channels,
- out_channels,
- kernel_size=3,
- padding=1,
- padding_mode='replicate',
- bias=False)
+ self.attention = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ padding_mode='replicate',
+ bias=False
+ )
with torch.no_grad():
self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
@@ -28,34 +29,39 @@ class SelfAttention(torch.nn.Module):
class IFGeoNet(nn.Module):
-
def __init__(self, cfg, hidden_dim=256):
super(IFGeoNet, self).__init__()
- self.conv_in_partial = nn.Conv3d(1, 16, 3, padding=1,
- padding_mode='replicate') # out: 256 ->m.p. 128
+ self.conv_in_partial = nn.Conv3d(
+ 1, 16, 3, padding=1, padding_mode='replicate'
+ ) # out: 256 ->m.p. 128
self.SA = SelfAttention(4, 4)
- self.conv_0_fusion = nn.Conv3d(16, 32, 3, padding=1, padding_mode='replicate') # out: 128
- self.conv_0_1_fusion = nn.Conv3d(32, 32, 3, padding=1,
- padding_mode='replicate') # out: 128 ->m.p. 64
-
- self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
- self.conv_0_1 = nn.Conv3d(32, 32, 3, padding=1,
- padding_mode='replicate') # out: 128 ->m.p. 64
-
- self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
- self.conv_1_1 = nn.Conv3d(64, 64, 3, padding=1,
- padding_mode='replicate') # out: 64 -> mp 32
-
- self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
- self.conv_2_1 = nn.Conv3d(128, 128, 3, padding=1,
- padding_mode='replicate') # out: 32 -> mp 16
- self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
- self.conv_3_1 = nn.Conv3d(128, 128, 3, padding=1,
- padding_mode='replicate') # out: 16 -> mp 8
- self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
- self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
+ self.conv_0_fusion = nn.Conv3d(16, 32, 3, padding=1, padding_mode='replicate') # out: 128
+ self.conv_0_1_fusion = nn.Conv3d(
+ 32, 32, 3, padding=1, padding_mode='replicate'
+ ) # out: 128 ->m.p. 64
+
+ self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
+ self.conv_0_1 = nn.Conv3d(
+ 32, 32, 3, padding=1, padding_mode='replicate'
+ ) # out: 128 ->m.p. 64
+
+ self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
+ self.conv_1_1 = nn.Conv3d(
+ 64, 64, 3, padding=1, padding_mode='replicate'
+ ) # out: 64 -> mp 32
+
+ self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
+ self.conv_2_1 = nn.Conv3d(
+ 128, 128, 3, padding=1, padding_mode='replicate'
+ ) # out: 32 -> mp 16
+ self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
+ self.conv_3_1 = nn.Conv3d(
+ 128, 128, 3, padding=1, padding_mode='replicate'
+ ) # out: 16 -> mp 8
+ self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
+ self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
@@ -95,8 +101,8 @@ class IFGeoNet(nn.Module):
def forward(self, batch):
p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
- batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
- x = batch["depth_voxels"] #[B, 128, 128, 128]
+ batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
+ x = batch["depth_voxels"] #[B, 128, 128, 128]
x = x.unsqueeze(1)
p_features = p.transpose(1, -1)
@@ -106,7 +112,7 @@ class IFGeoNet(nn.Module):
feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True)
net_partial = self.actvn(self.conv_in_partial(x))
net_partial = self.partial_conv_in_bn(net_partial)
- net_partial = self.maxpool(net_partial) # out 64
+ net_partial = self.maxpool(net_partial) # out 64
# Feature fusion
net = self.actvn(self.conv_0_fusion(net_partial))
@@ -119,40 +125,44 @@ class IFGeoNet(nn.Module):
net = self.actvn(self.conv_0_1(net))
net = self.conv0_1_bn(net)
feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
- net = self.maxpool(net) # out 32
+ net = self.maxpool(net) # out 32
net = self.actvn(self.conv_1(net))
net = self.actvn(self.conv_1_1(net))
net = self.conv1_1_bn(net)
feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
- net = self.maxpool(net) # out 16
+ net = self.maxpool(net) # out 16
net = self.actvn(self.conv_2(net))
net = self.actvn(self.conv_2_1(net))
net = self.conv2_1_bn(net)
feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
- net = self.maxpool(net) # out 8
+ net = self.maxpool(net) # out 8
net = self.actvn(self.conv_3(net))
net = self.actvn(self.conv_3_1(net))
net = self.conv3_1_bn(net)
feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
- net = self.maxpool(net) # out 4
+ net = self.maxpool(net) # out 4
net = self.actvn(self.conv_4(net))
net = self.actvn(self.conv_4_1(net))
net = self.conv4_1_bn(net)
- feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2
+ feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2
# here every channel corresponds to one feature.
- features = torch.cat((feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4,
- feature_5, feature_6),
- dim=1) # (B, features, 1,7,sample_num)
+ features = torch.cat(
+ (
+ feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
+ feature_6
+ ),
+ dim=1
+ ) # (B, features, 1,7,sample_num)
shape = features.shape
features = torch.reshape(
- features,
- (shape[0], shape[1] * shape[3], shape[4])) # (B, featues_per_sample, samples_num)
+ features, (shape[0], shape[1] * shape[3], shape[4])
+ ) # (B, features_per_sample, samples_num)
# (B, feature_size, samples_num)
features = torch.cat((features, p_features), dim=1)
@@ -167,4 +177,4 @@ class IFGeoNet(nn.Module):
loss = self.l1_loss(prds, tgts)
- return loss
\ No newline at end of file
+ return loss
diff --git a/lib/net/NormalNet.py b/lib/net/NormalNet.py
index 71620076ace11b3542763f9b5285ec58a86bb949..a065840ed859137e72ba1e37a40c636da7c32e6f 100644
--- a/lib/net/NormalNet.py
+++ b/lib/net/NormalNet.py
@@ -35,7 +35,6 @@ class NormalNet(BasePIFuNet):
4. Classification.
5. During training, error is calculated on all stacks.
"""
-
def __init__(self, cfg):
super(NormalNet, self).__init__()
@@ -65,9 +64,11 @@ class NormalNet(BasePIFuNet):
item[0] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"
]
self.in_nmlF_dim = sum(
- [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"])
+ [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"]
+ )
self.in_nmlB_dim = sum(
- [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"])
+ [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"]
+ )
self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
@@ -134,18 +135,20 @@ class NormalNet(BasePIFuNet):
if 'mrf' in self.F_losses:
mrf_F_loss = self.mrf_loss(
F.interpolate(prd_F, scale_factor=scale_factor, mode='bicubic', align_corners=True),
- F.interpolate(tgt_F, scale_factor=scale_factor, mode='bicubic', align_corners=True))
+ F.interpolate(tgt_F, scale_factor=scale_factor, mode='bicubic', align_corners=True)
+ )
total_loss["netF"] += self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
total_loss["mrf_F"] = self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
if 'mrf' in self.B_losses:
mrf_B_loss = self.mrf_loss(
F.interpolate(prd_B, scale_factor=scale_factor, mode='bicubic', align_corners=True),
- F.interpolate(tgt_B, scale_factor=scale_factor, mode='bicubic', align_corners=True))
+ F.interpolate(tgt_B, scale_factor=scale_factor, mode='bicubic', align_corners=True)
+ )
total_loss["netB"] += self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss
total_loss["mrf_B"] = self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss
if 'gan' in self.ALL_losses:
-
+
total_loss["netD"] = 0.0
pred_fake = self.netD.forward(prd_B)
@@ -154,8 +157,8 @@ class NormalNet(BasePIFuNet):
loss_D_real = self.gan_loss(pred_real, True)
loss_G_fake = self.gan_loss(pred_fake, True)
- total_loss["netD"] += 0.5 * (
- loss_D_fake + loss_D_real) * self.B_losses_ratio[self.B_losses.index('gan')]
+ total_loss["netD"] += 0.5 * (loss_D_fake + loss_D_real
+ ) * self.B_losses_ratio[self.B_losses.index('gan')]
total_loss["D_fake"] = loss_D_fake * self.B_losses_ratio[self.B_losses.index('gan')]
total_loss["D_real"] = loss_D_real * self.B_losses_ratio[self.B_losses.index('gan')]
@@ -167,8 +170,8 @@ class NormalNet(BasePIFuNet):
for i in range(2):
for j in range(len(pred_fake[i]) - 1):
loss_G_GAN_Feat += self.l1_loss(pred_fake[i][j], pred_real[i][j].detach())
- total_loss["netB"] += loss_G_GAN_Feat * self.B_losses_ratio[self.B_losses.index(
- 'gan_feat')]
+ total_loss["netB"] += loss_G_GAN_Feat * self.B_losses_ratio[
+ self.B_losses.index('gan_feat')]
total_loss["G_GAN_Feat"] = loss_G_GAN_Feat * self.B_losses_ratio[
self.B_losses.index('gan_feat')]
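
The gan_feat term above sums L1 distances between intermediate discriminator features of the prediction and the detached ground truth. A minimal sketch of that accumulation with toy tensors, not the project's discriminator outputs:

import torch

l1 = torch.nn.L1Loss()
# two discriminator scales, each returning a list of feature maps plus a final score
pred_fake = [[torch.randn(2, 8, 16, 16), torch.randn(2, 1, 8, 8)] for _ in range(2)]
pred_real = [[torch.randn(2, 8, 16, 16), torch.randn(2, 1, 8, 8)] for _ in range(2)]

loss_feat = 0.0
for i in range(2):
    for j in range(len(pred_fake[i]) - 1):          # skip the final prediction layer
        loss_feat = loss_feat + l1(pred_fake[i][j], pred_real[i][j].detach())
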
diff --git a/lib/net/geometry.py b/lib/net/geometry.py
index af6bf154723addf0565820468f32d8a2efa980a1..6d7d82d2cb6b760596d1bbf70804e542999f802e 100644
--- a/lib/net/geometry.py
+++ b/lib/net/geometry.py
@@ -19,12 +19,12 @@ import numpy as np
import numbers
from torch.nn import functional as F
from einops.einops import rearrange
-
"""
Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
"""
+
def quaternion_to_rotation_matrix(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
@@ -42,11 +42,13 @@ def quaternion_to_rotation_matrix(quat):
wx, wy, wz = w * x, w * y, w * z
xy, xz, yz = x * y, x * z, y * z
- rotMat = torch.stack([
- w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
- 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
- ],
- dim=1).view(B, 3, 3)
+ rotMat = torch.stack(
+ [
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
+ 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
+ ],
+ dim=1
+ ).view(B, 3, 3)
return rotMat
@@ -56,7 +58,7 @@ def index(feat, uv):
:param uv: [B, 2, N] uv coordinates in the image plane, range [0, 1]
:return: [B, C, N] image features at the uv coordinates
"""
- uv = uv.transpose(1, 2) # [B, N, 2]
+ uv = uv.transpose(1, 2) # [B, N, 2]
(B, N, _) = uv.shape
C = feat.shape[1]
@@ -64,14 +66,14 @@ def index(feat, uv):
if uv.shape[-1] == 3:
# uv = uv[:,:,[2,1,0]]
# uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...]
- uv = uv.unsqueeze(2).unsqueeze(3) # [B, N, 1, 1, 3]
+ uv = uv.unsqueeze(2).unsqueeze(3) # [B, N, 1, 1, 3]
else:
- uv = uv.unsqueeze(2) # [B, N, 1, 2]
+ uv = uv.unsqueeze(2) # [B, N, 1, 2]
# NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
# for old versions, simply remove the align_corners argument.
- samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1]
- return samples.view(B, C, N) # [B, C, N]
+ samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1]
+ return samples.view(B, C, N) # [B, C, N]
def orthogonal(points, calibrations, transforms=None):
@@ -84,7 +86,7 @@ def orthogonal(points, calibrations, transforms=None):
"""
rot = calibrations[:, :3, :3]
trans = calibrations[:, :3, 3:4]
- pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
+ pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
if transforms is not None:
scale = transforms[:2, :2]
shift = transforms[:2, 2:3]
@@ -102,7 +104,7 @@ def perspective(points, calibrations, transforms=None):
"""
rot = calibrations[:, :3, :3]
trans = calibrations[:, :3, 3:4]
- homo = torch.baddbmm(trans, rot, points) # [B, 3, N]
+ homo = torch.baddbmm(trans, rot, points) # [B, 3, N]
xy = homo[:, :2, :] / homo[:, 2:3, :]
if transforms is not None:
scale = transforms[:2, :2]
@@ -187,7 +189,8 @@ def rotation_matrix_to_angle_axis(rotation_matrix):
if rotation_matrix.shape[1:] == (3, 3):
rot_mat = rotation_matrix.reshape(-1, 3, 3)
hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape(
- 1, 3, 1).expand(rot_mat.shape[0], -1, -1)
+ 1, 3, 1
+ ).expand(rot_mat.shape[0], -1, -1)
rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
@@ -222,8 +225,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))
if not quaternion.shape[-1] == 4:
- raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(
- quaternion.shape))
+ raise ValueError(
+ "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
+ )
# unpack input and compute conversion
q1: torch.Tensor = quaternion[..., 1]
q2: torch.Tensor = quaternion[..., 2]
@@ -276,11 +280,13 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))
if len(rotation_matrix.shape) > 3:
- raise ValueError("Input size must be a three dimensional tensor. Got {}".format(
- rotation_matrix.shape))
+ raise ValueError(
+ "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape)
+ )
if not rotation_matrix.shape[-2:] == (3, 4):
- raise ValueError("Input size must be a N x 3 x 4 tensor. Got {}".format(
- rotation_matrix.shape))
+ raise ValueError(
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(rotation_matrix.shape)
+ )
rmat_t = torch.transpose(rotation_matrix, 1, 2)
@@ -347,8 +353,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
- q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa
- + t3_rep * mask_c3) # noqa
+ q /= torch.sqrt(
+ t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa
+ + t3_rep * mask_c3
+ ) # noqa
q *= 0.5
return q
@@ -389,6 +397,7 @@ def rot6d_to_rotmat(x):
mat = torch.stack((b1, b2, b3), dim=-1)
return mat
+
def rotmat_to_rot6d(x):
"""Convert 3x3 rotation matrix to 6D rotation representation.
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
@@ -402,6 +411,7 @@ def rotmat_to_rot6d(x):
x = x.reshape(batch_size, 6)
return x
+
def rotmat_to_angle(x):
"""Convert rotation to one-D angle.
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
@@ -440,12 +450,9 @@ def projection(pred_joints, pred_camera, retain_z=False):
return pred_keypoints_2d
-def perspective_projection(points,
- rotation,
- translation,
- focal_length,
- camera_center,
- retain_z=False):
+def perspective_projection(
+ points, rotation, translation, focal_length, camera_center, retain_z=False
+):
"""
This function computes the perspective projection of a set of points.
Input:
@@ -501,10 +508,12 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si
weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
# least squares
- Q = np.array([
- F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
- O - np.reshape(joints_2d, -1)
- ]).T
+ Q = np.array(
+ [
+ F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
+ O - np.reshape(joints_2d, -1)
+ ]
+ ).T
c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
# weighted least squares
@@ -558,15 +567,12 @@ def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_al
S_i = S[i]
joints_i = joints_2d[i]
conf_i = joints_conf[i]
- trans[i] = estimate_translation_np(S_i,
- joints_i,
- conf_i,
- focal_length=focal_length[i],
- img_size=img_size[i])
+ trans[i] = estimate_translation_np(
+ S_i, joints_i, conf_i, focal_length=focal_length[i], img_size=img_size[i]
+ )
return torch.from_numpy(trans).to(device)
-
def Rot_y(angle, category="torch", prepend_dim=True, device=None):
"""Rotate around y-axis by angle
Args:
@@ -574,11 +580,13 @@ def Rot_y(angle, category="torch", prepend_dim=True, device=None):
prepend_dim: prepend an extra dimension
Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
"""
- m = np.array([
- [np.cos(angle), 0.0, np.sin(angle)],
- [0.0, 1.0, 0.0],
- [-np.sin(angle), 0.0, np.cos(angle)],
- ])
+ m = np.array(
+ [
+ [np.cos(angle), 0.0, np.sin(angle)],
+ [0.0, 1.0, 0.0],
+ [-np.sin(angle), 0.0, np.cos(angle)],
+ ]
+ )
if category == "torch":
if prepend_dim:
return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -600,11 +608,13 @@ def Rot_x(angle, category="torch", prepend_dim=True, device=None):
prepend_dim: prepend an extra dimension
Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
"""
- m = np.array([
- [1.0, 0.0, 0.0],
- [0.0, np.cos(angle), -np.sin(angle)],
- [0.0, np.sin(angle), np.cos(angle)],
- ])
+ m = np.array(
+ [
+ [1.0, 0.0, 0.0],
+ [0.0, np.cos(angle), -np.sin(angle)],
+ [0.0, np.sin(angle), np.cos(angle)],
+ ]
+ )
if category == "torch":
if prepend_dim:
return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -626,11 +636,13 @@ def Rot_z(angle, category="torch", prepend_dim=True, device=None):
prepend_dim: prepend an extra dimension
Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
"""
- m = np.array([
- [np.cos(angle), -np.sin(angle), 0.0],
- [np.sin(angle), np.cos(angle), 0.0],
- [0.0, 0.0, 1.0],
- ])
+ m = np.array(
+ [
+ [np.cos(angle), -np.sin(angle), 0.0],
+ [np.sin(angle), np.cos(angle), 0.0],
+ [0.0, 0.0, 1.0],
+ ]
+ )
if category == "torch":
if prepend_dim:
return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -672,7 +684,7 @@ def compute_twist_rotation(rotation_matrix, twist_axis):
twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)
twist_aa = quaternion_to_angle_axis(twist_quaternion)
- twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum(
- twist_axis, dim=1, keepdim=True)
+ twist_angle = torch.sum(twist_aa, dim=1,
+ keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True)
- return twist_rotation, twist_angle
\ No newline at end of file
+ return twist_rotation, twist_angle
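
orthogonal() and perspective() above both apply the calibration with torch.baddbmm, i.e. y = trans + rot @ points per batch element. A minimal sketch with a toy 4x4 calibration (the layout here is an assumption for illustration only):

import torch

B, N = 2, 5
points = torch.randn(B, 3, N)
calib = torch.eye(4).repeat(B, 1, 1)
calib[:, :3, 3] = torch.tensor([0.1, -0.2, 0.0])    # hypothetical translation

rot, trans = calib[:, :3, :3], calib[:, :3, 3:4]    # (B, 3, 3) and (B, 3, 1)
projected = torch.baddbmm(trans, rot, points)       # (B, 3, N)
assert torch.allclose(projected, rot @ points + trans)
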
diff --git a/lib/net/net_util.py b/lib/net/net_util.py
index 200a87e5e09094069379f989082d8099d97b75f8..d89fcff5670909cd41c2e917e87b3bdb25870d8a 100644
--- a/lib/net/net_util.py
+++ b/lib/net/net_util.py
@@ -71,11 +71,10 @@ def init_weights(net, init_type="normal", init_gain=0.02):
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
work better for some applications. Feel free to try yourself.
"""
-
- def init_func(m): # define the initialization function
+ def init_func(m): # define the initialization function
classname = m.__class__.__name__
- if hasattr(m, "weight") and (classname.find("Conv") != -1 or
- classname.find("Linear") != -1):
+ if hasattr(m, "weight") and (
+ classname.find("Conv") != -1 or classname.find("Linear") != -1
+ ):
if init_type == "normal":
init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == "xavier":
@@ -85,17 +84,19 @@ def init_weights(net, init_type="normal", init_gain=0.02):
elif init_type == "orthogonal":
init.orthogonal_(m.weight.data, gain=init_gain)
else:
- raise NotImplementedError("initialization method [%s] is not implemented" %
- init_type)
+ raise NotImplementedError(
+ "initialization method [%s] is not implemented" % init_type
+ )
if hasattr(m, "bias") and m.bias is not None:
init.constant_(m.bias.data, 0.0)
- elif (classname.find("BatchNorm2d") !=
- -1): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+ elif (
+ classname.find("BatchNorm2d") != -1
+ ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
# print('initialize network with %s' % init_type)
- net.apply(init_func) # apply the initialization function
+ net.apply(init_func) # apply the initialization function
def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]):
@@ -110,7 +111,7 @@ def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]):
"""
if len(gpu_ids) > 0:
assert torch.cuda.is_available()
- net = torch.nn.DataParallel(net) # multi-GPUs
+ net = torch.nn.DataParallel(net) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
return net
@@ -127,13 +128,9 @@ def imageSpaceRotation(xy, rot):
return (disp * xy).sum(dim=1)
-def cal_gradient_penalty(netD,
- real_data,
- fake_data,
- device,
- type="mixed",
- constant=1.0,
- lambda_gp=10.0):
+def cal_gradient_penalty(
+ netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0
+):
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
Arguments:
@@ -155,9 +152,11 @@ def cal_gradient_penalty(netD,
interpolatesv = fake_data
elif type == "mixed":
alpha = torch.rand(real_data.shape[0], 1)
- alpha = (alpha.expand(real_data.shape[0],
- real_data.nelement() //
- real_data.shape[0]).contiguous().view(*real_data.shape))
+ alpha = (
+ alpha.expand(real_data.shape[0],
+ real_data.nelement() //
+ real_data.shape[0]).contiguous().view(*real_data.shape)
+ )
alpha = alpha.to(device)
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
else:
@@ -172,9 +171,9 @@ def cal_gradient_penalty(netD,
retain_graph=True,
only_inputs=True,
)
- gradients = gradients[0].view(real_data.size(0), -1) # flat the data
- gradient_penalty = ((
- (gradients + 1e-16).norm(2, dim=1) - constant)**2).mean() * lambda_gp # added eps
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
+ gradient_penalty = (
+ ((gradients + 1e-16).norm(2, dim=1) - constant)**2
+ ).mean() * lambda_gp # added eps
return gradient_penalty, gradients
else:
return 0.0, None
@@ -201,13 +200,11 @@ def get_norm_layer(norm_type="instance"):
class Flatten(nn.Module):
-
def forward(self, input):
return input.view(input.size(0), -1)
class ConvBlock(nn.Module):
-
def __init__(self, in_planes, out_planes, opt):
super(ConvBlock, self).__init__()
[k, s, d, p] = opt.conv3x3
@@ -258,5 +255,3 @@ class ConvBlock(nn.Module):
out3 += residual
return out3
-
-
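
cal_gradient_penalty above implements the WGAN-GP penalty: take the gradient of the critic output with respect to an interpolation of real and fake samples and penalize deviations of its 2-norm from the constant. A minimal sketch with a stand-in linear critic, not the project's netD:

import torch

critic = torch.nn.Linear(8, 1)                       # stand-in critic
real, fake = torch.randn(4, 8), torch.randn(4, 8)
alpha = torch.rand(4, 1)
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

out = critic(interp)
grads, = torch.autograd.grad(out, interp, grad_outputs=torch.ones_like(out),
                             create_graph=True)
penalty = ((grads.norm(2, dim=1) - 1.0) ** 2).mean() * 10.0    # lambda_gp = 10
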
diff --git a/lib/net/voxelize.py b/lib/net/voxelize.py
index ba341ce905455e14fa21456acfe899ba19e6f783..394b40e6eeeb158bb691c1e518b6b1f7a889b8d8 100644
--- a/lib/net/voxelize.py
+++ b/lib/net/voxelize.py
@@ -13,7 +13,6 @@ class VoxelizationFunction(Function):
Definition of differentiable voxelization function
Currently implemented only for cuda Tensors
"""
-
@staticmethod
def forward(
ctx,
@@ -48,12 +47,15 @@ class VoxelizationFunction(Function):
smpl_face_code = smpl_face_code.contiguous()
smpl_tetrahedrons = smpl_tetrahedrons.contiguous()
- occ_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res,
- ctx.volume_res).fill_(0.0)
- semantic_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res,
- ctx.volume_res, 3).fill_(0.0)
- weight_sum_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res,
- ctx.volume_res).fill_(1e-3)
+ occ_volume = torch.cuda.FloatTensor(
+ ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res
+ ).fill_(0.0)
+ semantic_volume = torch.cuda.FloatTensor(
+ ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res, 3
+ ).fill_(0.0)
+ weight_sum_volume = torch.cuda.FloatTensor(
+ ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res
+ ).fill_(1e-3)
# occ_volume [B, volume_res, volume_res, volume_res]
# semantic_volume [B, volume_res, volume_res, volume_res, 3]
@@ -80,7 +82,6 @@ class Voxelization(nn.Module):
"""
Wrapper around the autograd function VoxelizationFunction
"""
-
def __init__(
self,
smpl_vertex_code,
@@ -151,21 +152,25 @@ class Voxelization(nn.Module):
self.sigma,
self.smooth_kernel_size,
)
- return vol.permute((0, 4, 1, 2, 3)) # (bzyxc --> bcdhw)
+ return vol.permute((0, 4, 1, 2, 3)) # (bzyxc --> bcdhw)
def vertices_to_faces(self, vertices):
assert vertices.ndimension() == 3
bs, nv = vertices.shape[:2]
- face = (self.smpl_face_indices_batch +
- (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None])
+ face = (
+ self.smpl_face_indices_batch +
+ (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]
+ )
vertices_ = vertices.reshape((bs * nv, 3))
return vertices_[face.long()]
def vertices_to_tetrahedrons(self, vertices):
assert vertices.ndimension() == 3
bs, nv = vertices.shape[:2]
- tets = (self.smpl_tetraderon_indices_batch +
- (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None])
+ tets = (
+ self.smpl_tetraderon_indices_batch +
+ (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]
+ )
vertices_ = vertices.reshape((bs * nv, 3))
return vertices_[tets.long()]
@@ -174,8 +179,9 @@ class Voxelization(nn.Module):
assert face_verts.shape[2] == 3
assert face_verts.shape[3] == 3
bs, nf = face_verts.shape[:2]
- face_centers = (face_verts[:, :, 0, :] + face_verts[:, :, 1, :] +
- face_verts[:, :, 2, :]) / 3.0
+ face_centers = (
+ face_verts[:, :, 0, :] + face_verts[:, :, 1, :] + face_verts[:, :, 2, :]
+ ) / 3.0
face_centers = face_centers.reshape((bs, nf, 3))
return face_centers
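
vertices_to_faces / vertices_to_tetrahedrons above rely on one indexing trick: offset each mesh's face indices by batch_index * n_vertices so a single flattened vertex table can be gathered for the whole batch at once. A minimal sketch with toy sizes:

import torch

bs, nv = 2, 6
vertices = torch.randn(bs, nv, 3)
faces = torch.tensor([[0, 1, 2], [2, 3, 4]])                      # (nf, 3), shared per mesh

faces_b = faces[None] + (torch.arange(bs) * nv)[:, None, None]    # (bs, nf, 3) offset indices
face_verts = vertices.reshape(bs * nv, 3)[faces_b.long()]         # (bs, nf, 3, 3)
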
diff --git a/lib/pixielib/models/FLAME.py b/lib/pixielib/models/FLAME.py
index 6c0b78d5fd8efa12651ac5cb689b4e5c4d790636..b62b1069b6083685e8ff1511e57c48ccf79bc927 100755
--- a/lib/pixielib/models/FLAME.py
+++ b/lib/pixielib/models/FLAME.py
@@ -27,7 +27,6 @@ class FLAMETex(nn.Module):
FLAME texture converted from BFM:
https://github.com/TimoBolkart/BFM_to_FLAME
"""
-
def __init__(self, config):
super(FLAMETex, self).__init__()
if config.tex_type == "BFM":
@@ -54,8 +53,7 @@ class FLAMETex(nn.Module):
n_tex = config.n_tex
num_components = texture_basis.shape[1]
texture_mean = torch.from_numpy(texture_mean).float()[None, ...]
- texture_basis = torch.from_numpy(
- texture_basis[:, :n_tex]).float()[None, ...]
+ texture_basis = torch.from_numpy(texture_basis[:, :n_tex]).float()[None, ...]
self.register_buffer("texture_mean", texture_mean)
self.register_buffer("texture_basis", texture_basis)
@@ -64,10 +62,8 @@ class FLAMETex(nn.Module):
texcode: [batchsize, n_tex]
texture: [bz, 3, 256, 256], range: 0-1
"""
- texture = self.texture_mean + (self.texture_basis *
- texcode[:, None, :]).sum(-1)
- texture = texture.reshape(texcode.shape[0], 512, 512,
- 3).permute(0, 3, 1, 2)
+ texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
+ texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
texture = F.interpolate(texture, [256, 256])
texture = texture[:, [2, 1, 0], :, :]
return texture
@@ -78,13 +74,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture):
TODO: pytorch version ==> grid sample
"""
if smplx_texture.shape[0] != smplx_texture.shape[1]:
- print("SMPL-X texture not squared (%d != %d)" %
- (smplx_texture[0], smplx_texture[1]))
+ print("SMPL-X texture not squared (%d != %d)" % (smplx_texture[0], smplx_texture[1]))
return
if smplx_texture.shape[0] != cached_data["target_resolution"]:
print(
- "SMPL-X texture size does not match cached image resolution (%d != %d)"
- % (smplx_texture.shape[0], cached_data["target_resolution"]))
+ "SMPL-X texture size does not match cached image resolution (%d != %d)" %
+ (smplx_texture.shape[0], cached_data["target_resolution"])
+ )
return
x_coords = cached_data["x_coords"]
y_coords = cached_data["y_coords"]
@@ -98,11 +94,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture):
flame_texture.shape[0],
).astype(int)
source_tex_coords[:, 1] = np.clip(
- flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0,
- flame_texture.shape[1]).astype(int)
+ flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0, flame_texture.shape[1]
+ ).astype(int)
smplx_texture[y_coords[target_pixel_ids].astype(int),
- x_coords[target_pixel_ids].astype(int), :, ] = flame_texture[
- source_tex_coords[:, 0], source_tex_coords[:, 1]]
+ x_coords[target_pixel_ids].astype(int), :, ] = flame_texture[source_tex_coords[:, 0], source_tex_coords[:, 1]]
return smplx_texture
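
FLAMETex.forward above is a linear texture model: the output is the mean texture plus a texcode-weighted sum of basis components, reshaped into an image. A minimal sketch at a toy resolution, using random statistics instead of the BFM/FLAME data:

import torch

res, n_tex = 64, 10
n_pix = res * res * 3
texture_mean = torch.zeros(1, n_pix)
texture_basis = torch.randn(1, n_pix, n_tex)
texcode = torch.randn(4, n_tex)

texture = texture_mean + (texture_basis * texcode[:, None, :]).sum(-1)    # (4, n_pix)
texture = texture.reshape(4, res, res, 3).permute(0, 3, 1, 2)             # (4, 3, res, res)
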
diff --git a/lib/pixielib/models/SMPLX.py b/lib/pixielib/models/SMPLX.py
index 0e940d2d42bd07484943f13519df97e6cf3e0fa8..beb672facebc7fa9e61eee7f8e7f3f185ac6cdad 100644
--- a/lib/pixielib/models/SMPLX.py
+++ b/lib/pixielib/models/SMPLX.py
@@ -209,452 +209,468 @@ extra_names = [
SMPLX_names += extra_names
part_indices = {}
-part_indices["body"] = np.array([
- 0,
- 1,
- 2,
- 3,
- 4,
- 5,
- 6,
- 7,
- 8,
- 9,
- 10,
- 11,
- 12,
- 13,
- 14,
- 15,
- 16,
- 17,
- 18,
- 19,
- 20,
- 21,
- 22,
- 23,
- 24,
- 123,
- 124,
- 125,
- 126,
- 127,
- 132,
- 134,
- 135,
- 136,
- 137,
- 138,
- 143,
-])
-part_indices["torso"] = np.array([
- 0,
- 1,
- 2,
- 3,
- 6,
- 9,
- 12,
- 13,
- 14,
- 15,
- 16,
- 17,
- 18,
- 19,
- 22,
- 23,
- 24,
- 55,
- 56,
- 57,
- 58,
- 59,
- 76,
- 77,
- 78,
- 79,
- 80,
- 81,
- 82,
- 83,
- 84,
- 85,
- 86,
- 87,
- 88,
- 89,
- 90,
- 91,
- 92,
- 93,
- 94,
- 95,
- 96,
- 97,
- 98,
- 99,
- 100,
- 101,
- 102,
- 103,
- 104,
- 105,
- 106,
- 107,
- 108,
- 109,
- 110,
- 111,
- 112,
- 113,
- 114,
- 115,
- 116,
- 117,
- 118,
- 119,
- 120,
- 121,
- 122,
- 123,
- 124,
- 125,
- 126,
- 127,
- 128,
- 129,
- 130,
- 131,
- 132,
- 133,
- 134,
- 135,
- 136,
- 137,
- 138,
- 139,
- 140,
- 141,
- 142,
- 143,
- 144,
-])
-part_indices["head"] = np.array([
- 12,
- 15,
- 22,
- 23,
- 24,
- 55,
- 56,
- 57,
- 58,
- 59,
- 60,
- 61,
- 62,
- 63,
- 64,
- 65,
- 66,
- 67,
- 68,
- 69,
- 70,
- 71,
- 72,
- 73,
- 74,
- 75,
- 76,
- 77,
- 78,
- 79,
- 80,
- 81,
- 82,
- 83,
- 84,
- 85,
- 86,
- 87,
- 88,
- 89,
- 90,
- 91,
- 92,
- 93,
- 94,
- 95,
- 96,
- 97,
- 98,
- 99,
- 100,
- 101,
- 102,
- 103,
- 104,
- 105,
- 106,
- 107,
- 108,
- 109,
- 110,
- 111,
- 112,
- 113,
- 114,
- 115,
- 116,
- 117,
- 118,
- 119,
- 120,
- 121,
- 122,
- 123,
- 125,
- 126,
- 134,
- 136,
- 137,
-])
-part_indices["face"] = np.array([
- 55,
- 56,
- 57,
- 58,
- 59,
- 60,
- 61,
- 62,
- 63,
- 64,
- 65,
- 66,
- 67,
- 68,
- 69,
- 70,
- 71,
- 72,
- 73,
- 74,
- 75,
- 76,
- 77,
- 78,
- 79,
- 80,
- 81,
- 82,
- 83,
- 84,
- 85,
- 86,
- 87,
- 88,
- 89,
- 90,
- 91,
- 92,
- 93,
- 94,
- 95,
- 96,
- 97,
- 98,
- 99,
- 100,
- 101,
- 102,
- 103,
- 104,
- 105,
- 106,
- 107,
- 108,
- 109,
- 110,
- 111,
- 112,
- 113,
- 114,
- 115,
- 116,
- 117,
- 118,
- 119,
- 120,
- 121,
- 122,
-])
-part_indices["upper"] = np.array([
- 12,
- 13,
- 14,
- 55,
- 56,
- 57,
- 58,
- 59,
- 60,
- 61,
- 62,
- 63,
- 64,
- 65,
- 66,
- 67,
- 68,
- 69,
- 70,
- 71,
- 72,
- 73,
- 74,
- 75,
- 76,
- 77,
- 78,
- 79,
- 80,
- 81,
- 82,
- 83,
- 84,
- 85,
- 86,
- 87,
- 88,
- 89,
- 90,
- 91,
- 92,
- 93,
- 94,
- 95,
- 96,
- 97,
- 98,
- 99,
- 100,
- 101,
- 102,
- 103,
- 104,
- 105,
- 106,
- 107,
- 108,
- 109,
- 110,
- 111,
- 112,
- 113,
- 114,
- 115,
- 116,
- 117,
- 118,
- 119,
- 120,
- 121,
- 122,
-])
-part_indices["hand"] = np.array([
- 20,
- 21,
- 25,
- 26,
- 27,
- 28,
- 29,
- 30,
- 31,
- 32,
- 33,
- 34,
- 35,
- 36,
- 37,
- 38,
- 39,
- 40,
- 41,
- 42,
- 43,
- 44,
- 45,
- 46,
- 47,
- 48,
- 49,
- 50,
- 51,
- 52,
- 53,
- 54,
- 128,
- 129,
- 130,
- 131,
- 133,
- 139,
- 140,
- 141,
- 142,
- 144,
-])
-part_indices["left_hand"] = np.array([
- 20,
- 25,
- 26,
- 27,
- 28,
- 29,
- 30,
- 31,
- 32,
- 33,
- 34,
- 35,
- 36,
- 37,
- 38,
- 39,
- 128,
- 129,
- 130,
- 131,
- 133,
-])
-part_indices["right_hand"] = np.array([
- 21,
- 40,
- 41,
- 42,
- 43,
- 44,
- 45,
- 46,
- 47,
- 48,
- 49,
- 50,
- 51,
- 52,
- 53,
- 54,
- 139,
- 140,
- 141,
- 142,
- 144,
-])
+part_indices["body"] = np.array(
+ [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 123,
+ 124,
+ 125,
+ 126,
+ 127,
+ 132,
+ 134,
+ 135,
+ 136,
+ 137,
+ 138,
+ 143,
+ ]
+)
+part_indices["torso"] = np.array(
+ [
+ 0,
+ 1,
+ 2,
+ 3,
+ 6,
+ 9,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 22,
+ 23,
+ 24,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 119,
+ 120,
+ 121,
+ 122,
+ 123,
+ 124,
+ 125,
+ 126,
+ 127,
+ 128,
+ 129,
+ 130,
+ 131,
+ 132,
+ 133,
+ 134,
+ 135,
+ 136,
+ 137,
+ 138,
+ 139,
+ 140,
+ 141,
+ 142,
+ 143,
+ 144,
+ ]
+)
+part_indices["head"] = np.array(
+ [
+ 12,
+ 15,
+ 22,
+ 23,
+ 24,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 119,
+ 120,
+ 121,
+ 122,
+ 123,
+ 125,
+ 126,
+ 134,
+ 136,
+ 137,
+ ]
+)
+part_indices["face"] = np.array(
+ [
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 119,
+ 120,
+ 121,
+ 122,
+ ]
+)
+part_indices["upper"] = np.array(
+ [
+ 12,
+ 13,
+ 14,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 119,
+ 120,
+ 121,
+ 122,
+ ]
+)
+part_indices["hand"] = np.array(
+ [
+ 20,
+ 21,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 128,
+ 129,
+ 130,
+ 131,
+ 133,
+ 139,
+ 140,
+ 141,
+ 142,
+ 144,
+ ]
+)
+part_indices["left_hand"] = np.array(
+ [
+ 20,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 128,
+ 129,
+ 130,
+ 131,
+ 133,
+ ]
+)
+part_indices["right_hand"] = np.array(
+ [
+ 21,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 139,
+ 140,
+ 141,
+ 142,
+ 144,
+ ]
+)
# kinematic tree
head_kin_chain = [15, 12, 9, 6, 3, 0]
@@ -691,13 +707,12 @@ class SMPLX(nn.Module):
Given smplx parameters, this class generates a differentiable SMPLX function
which outputs a mesh and 3D joints
"""
-
def __init__(self, config):
super(SMPLX, self).__init__()
# print("creating the SMPLX Decoder")
ss = np.load(config.smplx_model_path, allow_pickle=True)
smplx_model = Struct(**ss)
-
+
self.dtype = torch.float32
self.register_buffer(
"faces_tensor",
@@ -705,8 +720,8 @@ class SMPLX(nn.Module):
)
# The vertices of the template model
self.register_buffer(
- "v_template",
- to_tensor(to_np(smplx_model.v_template), dtype=self.dtype))
+ "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)
+ )
# The shape components and expression
# expression space is the same as FLAME
shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
@@ -721,21 +736,18 @@ class SMPLX(nn.Module):
# The pose components
num_pose_basis = smplx_model.posedirs.shape[-1]
posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
- self.register_buffer("posedirs",
- to_tensor(to_np(posedirs), dtype=self.dtype))
+ self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
self.register_buffer(
- "J_regressor",
- to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype))
+ "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)
+ )
parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
parents[0] = -1
self.register_buffer("parents", parents)
- self.register_buffer(
- "lbs_weights",
- to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
+ self.register_buffer("lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
# for face keypoints
self.register_buffer(
- "lmk_faces_idx",
- torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long))
+ "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)
+ )
self.register_buffer(
"lmk_bary_coords",
torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
@@ -746,24 +758,20 @@ class SMPLX(nn.Module):
)
self.register_buffer(
"dynamic_lmk_bary_coords",
- torch.tensor(smplx_model.dynamic_lmk_bary_coords,
- dtype=self.dtype),
+ torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype),
)
# pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
- self.register_buffer("head_kin_chain",
- torch.tensor(head_kin_chain, dtype=torch.long))
+ self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long))
# -- initialize parameters
# shape and expression
self.register_buffer(
"shape_params",
- nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype),
- requires_grad=False),
+ nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False),
)
self.register_buffer(
"expression_params",
- nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype),
- requires_grad=False),
+ nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False),
)
# pose: represented as rotation matrx [number of joints, 3, 3]
self.register_buffer(
@@ -824,8 +832,7 @@ class SMPLX(nn.Module):
)
if config.extra_joint_path:
- self.extra_joint_selector = JointsFromVerticesSelector(
- fname=config.extra_joint_path)
+ self.extra_joint_selector = JointsFromVerticesSelector(fname=config.extra_joint_path)
self.use_joint_regressor = True
self.keypoint_names = SMPLX_names
if self.use_joint_regressor:
@@ -843,7 +850,8 @@ class SMPLX(nn.Module):
self.register_buffer("target_idxs", torch.from_numpy(target))
self.register_buffer(
"extra_joint_regressor",
- torch.from_numpy(j14_regressor).to(torch.float32))
+ torch.from_numpy(j14_regressor).to(torch.float32)
+ )
self.part_indices = part_indices
def forward(
@@ -880,23 +888,17 @@ class SMPLX(nn.Module):
if expression_params is None:
expression_params = self.expression_params.expand(batch_size, -1)
if global_pose is None:
- global_pose = self.global_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
+ global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
if body_pose is None:
- body_pose = self.body_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
+ body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
if jaw_pose is None:
- jaw_pose = self.jaw_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
+ jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
if eye_pose is None:
- eye_pose = self.eye_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
+ eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
if left_hand_pose is None:
- left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
+ left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
if right_hand_pose is None:
- right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
+ right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
shape_components = torch.cat([shape_params, expression_params], dim=1)
full_pose = torch.cat(
@@ -910,8 +912,7 @@ class SMPLX(nn.Module):
],
dim=1,
)
- template_vertices = self.v_template.unsqueeze(0).expand(
- batch_size, -1, -1)
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
# smplx
vertices, joints = lbs(
shape_components,
@@ -926,10 +927,8 @@ class SMPLX(nn.Module):
pose2rot=False,
)
# face dynamic landmarks
- lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
- batch_size, -1)
- lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(
- batch_size, -1, -1)
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
vertices,
full_pose,
@@ -939,14 +938,12 @@ class SMPLX(nn.Module):
)
lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
- landmarks = vertices2landmarks(vertices, self.faces_tensor,
- lmk_faces_idx, lmk_bary_coords)
+ landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
final_joint_set = [joints, landmarks]
if hasattr(self, "extra_joint_selector"):
# Add any extra joints that might be needed
- extra_joints = self.extra_joint_selector(vertices,
- self.faces_tensor)
+ extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
final_joint_set.append(extra_joints)
# Create the final joint set
joints = torch.cat(final_joint_set, dim=1)
@@ -978,16 +975,15 @@ class SMPLX(nn.Module):
# -> Left elbow -> Left wrist
kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
else:
- raise NotImplementedError(
- f"pose_abs2rel does not support: {abs_joint}")
+ raise NotImplementedError(f"pose_abs2rel does not support: {abs_joint}")
batch_size = global_pose.shape[0]
dtype = global_pose.dtype
device = global_pose.device
full_pose = torch.cat([global_pose, body_pose], dim=1)
- rel_rot_mat = (torch.eye(3, device=device,
- dtype=dtype).unsqueeze_(dim=0).repeat(
- batch_size, 1, 1))
+ rel_rot_mat = (
+ torch.eye(3, device=device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
+ )
for idx in kin_chain[1:]:
rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)
@@ -1027,11 +1023,8 @@ class SMPLX(nn.Module):
# -> Left elbow -> Left wrist
kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
else:
- raise NotImplementedError(
- f"pose_rel2abs does not support: {abs_joint}")
- rel_rot_mat = torch.eye(3,
- device=full_pose.device,
- dtype=full_pose.dtype).unsqueeze_(dim=0)
+ raise NotImplementedError(f"pose_rel2abs does not support: {abs_joint}")
+ rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze_(dim=0)
for idx in kin_chain:
rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
abs_pose = rel_rot_mat[:, None, :, :]
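
pose_abs2rel / pose_rel2abs above walk a kinematic chain (e.g. head_kin_chain = [15, 12, 9, 6, 3, 0]) and compose the per-joint rotation matrices to move between relative and absolute joint rotations. A minimal sketch of the absolute-pose composition with an identity toy pose:

import torch

kin_chain = [15, 12, 9, 6, 3, 0]                    # head -> ... -> pelvis
full_pose = torch.eye(3).repeat(1, 55, 1, 1)        # (B, n_joints, 3, 3), identity toy pose

abs_rot = torch.eye(3).unsqueeze(0)                 # start from the identity
for idx in kin_chain:
    abs_rot = torch.matmul(full_pose[:, idx], abs_rot)
# abs_rot is the global rotation of the last joint in the chain (here still identity)
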
diff --git a/lib/pixielib/models/encoders.py b/lib/pixielib/models/encoders.py
index 6b0d0e17cf1ca9dc7c87aeba5d8dc3df97f04011..0783c9265ab442a259fd693a55039026cc7608db 100755
--- a/lib/pixielib/models/encoders.py
+++ b/lib/pixielib/models/encoders.py
@@ -5,14 +5,13 @@ import torch.nn.functional as F
class ResnetEncoder(nn.Module):
-
def __init__(self, append_layers=None):
super(ResnetEncoder, self).__init__()
from . import resnet
# feature_size = 2048
self.feature_dim = 2048
- self.encoder = resnet.load_ResNet50Model() # out: 2048
+ self.encoder = resnet.load_ResNet50Model() # out: 2048
# regressor
self.append_layers = append_layers
@@ -25,7 +24,6 @@ class ResnetEncoder(nn.Module):
class MLP(nn.Module):
-
def __init__(self, channels=[2048, 1024, 1], last_op=None):
super(MLP, self).__init__()
layers = []
@@ -45,13 +43,12 @@ class MLP(nn.Module):
class HRNEncoder(nn.Module):
-
def __init__(self, append_layers=None):
super(HRNEncoder, self).__init__()
from . import hrnet
self.feature_dim = 2048
- self.encoder = hrnet.load_HRNet(pretrained=True) # out: 2048
+ self.encoder = hrnet.load_HRNet(pretrained=True) # out: 2048
# regressor
self.append_layers = append_layers
diff --git a/lib/pixielib/models/hrnet.py b/lib/pixielib/models/hrnet.py
index 158c3cc31189d488877f6d2884fab7dc65bc8815..c1fd871abf8ae79dd87f96e30d14d726c913db05 100644
--- a/lib/pixielib/models/hrnet.py
+++ b/lib/pixielib/models/hrnet.py
@@ -15,38 +15,42 @@ def load_HRNet(pretrained=False):
hr_net_cfg_dict = {
"use_old_impl": False,
"pretrained_layers": ["*"],
- "stage1": {
- "num_modules": 1,
- "num_branches": 1,
- "num_blocks": [4],
- "num_channels": [64],
- "block": "BOTTLENECK",
- "fuse_method": "SUM",
- },
- "stage2": {
- "num_modules": 1,
- "num_branches": 2,
- "num_blocks": [4, 4],
- "num_channels": [48, 96],
- "block": "BASIC",
- "fuse_method": "SUM",
- },
- "stage3": {
- "num_modules": 4,
- "num_branches": 3,
- "num_blocks": [4, 4, 4],
- "num_channels": [48, 96, 192],
- "block": "BASIC",
- "fuse_method": "SUM",
- },
- "stage4": {
- "num_modules": 3,
- "num_branches": 4,
- "num_blocks": [4, 4, 4, 4],
- "num_channels": [48, 96, 192, 384],
- "block": "BASIC",
- "fuse_method": "SUM",
- },
+ "stage1":
+ {
+ "num_modules": 1,
+ "num_branches": 1,
+ "num_blocks": [4],
+ "num_channels": [64],
+ "block": "BOTTLENECK",
+ "fuse_method": "SUM",
+ },
+ "stage2":
+ {
+ "num_modules": 1,
+ "num_branches": 2,
+ "num_blocks": [4, 4],
+ "num_channels": [48, 96],
+ "block": "BASIC",
+ "fuse_method": "SUM",
+ },
+ "stage3":
+ {
+ "num_modules": 4,
+ "num_branches": 3,
+ "num_blocks": [4, 4, 4],
+ "num_channels": [48, 96, 192],
+ "block": "BASIC",
+ "fuse_method": "SUM",
+ },
+ "stage4":
+ {
+ "num_modules": 3,
+ "num_branches": 4,
+ "num_blocks": [4, 4, 4, 4],
+ "num_channels": [48, 96, 192, 384],
+ "block": "BASIC",
+ "fuse_method": "SUM",
+ },
}
hr_net_cfg = hr_net_cfg_dict
model = HighResolutionNet(hr_net_cfg)
@@ -55,7 +59,6 @@ def load_HRNet(pretrained=False):
class HighResolutionModule(nn.Module):
-
def __init__(
self,
num_branches,
@@ -67,8 +70,7 @@ class HighResolutionModule(nn.Module):
multi_scale_output=True,
):
super(HighResolutionModule, self).__init__()
- self._check_branches(num_branches, blocks, num_blocks, num_inchannels,
- num_channels)
+ self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
@@ -76,37 +78,33 @@ class HighResolutionModule(nn.Module):
self.multi_scale_output = multi_scale_output
- self.branches = self._make_branches(num_branches, blocks, num_blocks,
- num_channels)
+ self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(True)
- def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels,
- num_channels):
+ def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
if num_branches != len(num_blocks):
- error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
- num_branches, len(num_blocks))
+ error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
- num_branches, len(num_channels))
+ num_branches, len(num_channels)
+ )
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
- num_branches, len(num_inchannels))
+ num_branches, len(num_inchannels)
+ )
raise ValueError(error_msg)
- def _make_one_branch(self,
- branch_index,
- block,
- num_blocks,
- num_channels,
- stride=1):
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
downsample = None
- if (stride != 1 or self.num_inchannels[branch_index] !=
- num_channels[branch_index] * block.expansion):
+ if (
+ stride != 1 or
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion
+ ):
downsample = nn.Sequential(
nn.Conv2d(
self.num_inchannels[branch_index],
@@ -115,8 +113,7 @@ class HighResolutionModule(nn.Module):
stride=stride,
bias=False,
),
- nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
- momentum=BN_MOMENTUM),
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
@@ -126,13 +123,11 @@ class HighResolutionModule(nn.Module):
num_channels[branch_index],
stride,
downsample,
- ))
- self.num_inchannels[
- branch_index] = num_channels[branch_index] * block.expansion
+ )
+ )
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
- layers.append(
- block(self.num_inchannels[branch_index],
- num_channels[branch_index]))
+ layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
return nn.Sequential(*layers)
@@ -140,8 +135,7 @@ class HighResolutionModule(nn.Module):
branches = []
for i in range(num_branches):
- branches.append(
- self._make_one_branch(i, block, num_blocks, num_channels))
+ branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
@@ -167,9 +161,9 @@ class HighResolutionModule(nn.Module):
bias=False,
),
nn.BatchNorm2d(num_inchannels[i]),
- nn.Upsample(scale_factor=2**(j - i),
- mode="nearest"),
- ))
+ nn.Upsample(scale_factor=2**(j - i), mode="nearest"),
+ )
+ )
elif j == i:
fuse_layer.append(None)
else:
@@ -188,7 +182,8 @@ class HighResolutionModule(nn.Module):
bias=False,
),
nn.BatchNorm2d(num_outchannels_conv3x3),
- ))
+ )
+ )
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(
@@ -203,7 +198,8 @@ class HighResolutionModule(nn.Module):
),
nn.BatchNorm2d(num_outchannels_conv3x3),
nn.ReLU(True),
- ))
+ )
+ )
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
@@ -237,7 +233,6 @@ blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}
class HighResolutionNet(nn.Module):
-
def __init__(self, cfg, **kwargs):
self.inplanes = 64
super(HighResolutionNet, self).__init__()
@@ -245,19 +240,9 @@ class HighResolutionNet(nn.Module):
self.use_old_impl = use_old_impl
# stem net
- self.conv1 = nn.Conv2d(3,
- 64,
- kernel_size=3,
- stride=2,
- padding=1,
- bias=False)
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
- self.conv2 = nn.Conv2d(64,
- 64,
- kernel_size=3,
- stride=2,
- padding=1,
- bias=False)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
@@ -271,41 +256,29 @@ class HighResolutionNet(nn.Module):
self.stage2_cfg = cfg.get("stage2", {})
num_channels = self.stage2_cfg.get("num_channels", (32, 64))
block = blocks_dict[self.stage2_cfg.get("block")]
- num_channels = [
- num_channels[i] * block.expansion for i in range(len(num_channels))
- ]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
stage2_num_channels = num_channels
- self.transition1 = self._make_transition_layer([stage1_out_channel],
- num_channels)
- self.stage2, pre_stage_channels = self._make_stage(
- self.stage2_cfg, num_channels)
+ self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
self.stage3_cfg = cfg.get("stage3")
num_channels = self.stage3_cfg["num_channels"]
block = blocks_dict[self.stage3_cfg["block"]]
- num_channels = [
- num_channels[i] * block.expansion for i in range(len(num_channels))
- ]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
stage3_num_channels = num_channels
- self.transition2 = self._make_transition_layer(pre_stage_channels,
- num_channels)
- self.stage3, pre_stage_channels = self._make_stage(
- self.stage3_cfg, num_channels)
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
self.stage4_cfg = cfg.get("stage4")
num_channels = self.stage4_cfg["num_channels"]
block = blocks_dict[self.stage4_cfg["block"]]
- num_channels = [
- num_channels[i] * block.expansion for i in range(len(num_channels))
- ]
- self.transition3 = self._make_transition_layer(pre_stage_channels,
- num_channels)
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
stage_4_out_channels = num_channels
self.stage4, pre_stage_channels = self._make_stage(
- self.stage4_cfg,
- num_channels,
- multi_scale_output=not self.use_old_impl)
+ self.stage4_cfg, num_channels, multi_scale_output=not self.use_old_impl
+ )
stage4_num_channels = num_channels
self.output_channels_dim = pre_stage_channels
@@ -316,35 +289,34 @@ class HighResolutionNet(nn.Module):
self.avg_pooling = nn.AdaptiveAvgPool2d(1)
if use_old_impl:
- in_dims = (2**2 * stage2_num_channels[-1] +
- 2**1 * stage3_num_channels[-1] +
- stage_4_out_channels[-1])
+ in_dims = (
+ 2**2 * stage2_num_channels[-1] + 2**1 * stage3_num_channels[-1] +
+ stage_4_out_channels[-1]
+ )
else:
# TODO: Replace with parameters
in_dims = 4 * 384
self.subsample_4 = self._make_subsample_layer(
- in_channels=stage4_num_channels[0], num_layers=3)
+ in_channels=stage4_num_channels[0], num_layers=3
+ )
self.subsample_3 = self._make_subsample_layer(
- in_channels=stage2_num_channels[-1], num_layers=2)
+ in_channels=stage2_num_channels[-1], num_layers=2
+ )
self.subsample_2 = self._make_subsample_layer(
- in_channels=stage3_num_channels[-1], num_layers=1)
- self.conv_layers = self._make_conv_layer(in_channels=in_dims,
- num_layers=5)
+ in_channels=stage3_num_channels[-1], num_layers=1
+ )
+ self.conv_layers = self._make_conv_layer(in_channels=in_dims, num_layers=5)
def get_output_dim(self):
- base_output = {
- f"layer{idx + 1}": val
- for idx, val in enumerate(self.output_channels_dim)
- }
+ base_output = {f"layer{idx + 1}": val for idx, val in enumerate(self.output_channels_dim)}
output = base_output.copy()
for key in base_output:
output[f"{key}_avg_pooling"] = output[key]
output["concat"] = 2048
return output
- def _make_transition_layer(self, num_channels_pre_layer,
- num_channels_cur_layer):
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
@@ -364,26 +336,24 @@ class HighResolutionNet(nn.Module):
),
nn.BatchNorm2d(num_channels_cur_layer[i]),
nn.ReLU(inplace=True),
- ))
+ )
+ )
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i + 1 - num_branches_pre):
inchannels = num_channels_pre_layer[-1]
- outchannels = (num_channels_cur_layer[i] if j == i -
- num_branches_pre else inchannels)
+ outchannels = (
+ num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
+ )
conv3x3s.append(
nn.Sequential(
- nn.Conv2d(inchannels,
- outchannels,
- 3,
- 2,
- 1,
- bias=False),
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
nn.BatchNorm2d(outchannels),
nn.ReLU(inplace=True),
- ))
+ )
+ )
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
@@ -410,24 +380,13 @@ class HighResolutionNet(nn.Module):
return nn.Sequential(*layers)
- def _make_conv_layer(self,
- in_channels=2048,
- num_layers=3,
- num_filters=2048,
- stride=1):
+ def _make_conv_layer(self, in_channels=2048, num_layers=3, num_filters=2048, stride=1):
layers = []
for i in range(num_layers):
- downsample = nn.Conv2d(in_channels,
- num_filters,
- stride=1,
- kernel_size=1,
- bias=False)
- layers.append(
- Bottleneck(in_channels,
- num_filters // 4,
- downsample=downsample))
+ downsample = nn.Conv2d(in_channels, num_filters, stride=1, kernel_size=1, bias=False)
+ layers.append(Bottleneck(in_channels, num_filters // 4, downsample=downsample))
in_channels = num_filters
return nn.Sequential(*layers)
@@ -444,18 +403,15 @@ class HighResolutionNet(nn.Module):
kernel_size=3,
stride=stride,
padding=1,
- ))
+ )
+ )
in_channels = 2 * in_channels
layers.append(nn.BatchNorm2d(in_channels, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
- def _make_stage(self,
- layer_config,
- num_inchannels,
- multi_scale_output=True,
- log=False):
+ def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True, log=False):
num_modules = layer_config["num_modules"]
num_branches = layer_config["num_branches"]
num_blocks = layer_config["num_blocks"]
@@ -480,7 +436,8 @@ class HighResolutionNet(nn.Module):
num_channels,
fuse_method,
reset_multi_scale_output,
- ))
+ )
+ )
modules[-1].log = log
num_inchannels = modules[-1].get_num_inchannels()
@@ -580,15 +537,14 @@ class HighResolutionNet(nn.Module):
def load_weights(self, pretrained=""):
pretrained = osp.expandvars(pretrained)
if osp.isfile(pretrained):
- pretrained_state_dict = torch.load(
- pretrained, map_location=torch.device("cpu"))
+ pretrained_state_dict = torch.load(pretrained, map_location=torch.device("cpu"))
need_init_state_dict = {}
for name, m in pretrained_state_dict.items():
- if (name.split(".")[0] in self.pretrained_layers
- or self.pretrained_layers[0] == "*"):
+ if (
+ name.split(".")[0] in self.pretrained_layers or self.pretrained_layers[0] == "*"
+ ):
need_init_state_dict[name] = m
- missing, unexpected = self.load_state_dict(need_init_state_dict,
- strict=False)
+ missing, unexpected = self.load_state_dict(need_init_state_dict, strict=False)
elif pretrained:
             raise ValueError("{} does not exist!".format(pretrained))
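
A side note on the `load_weights` hunk above: filtering a checkpoint by top-level module name and loading with `strict=False` initialises only the whitelisted sub-modules while tolerating everything else. A minimal standalone sketch of that pattern (the function name, checkpoint path, and prefix list below are illustrative, not taken from ECON):

```python
import torch
import torch.nn as nn


def load_partial_weights(model: nn.Module, ckpt_path: str, allowed_prefixes=("*",)):
    """Keep only checkpoint entries whose top-level name is whitelisted ('*' keeps all)."""
    state = torch.load(ckpt_path, map_location=torch.device("cpu"))
    keep = {
        name: tensor
        for name, tensor in state.items()
        if name.split(".")[0] in allowed_prefixes or allowed_prefixes[0] == "*"
    }
    # strict=False tolerates model parameters that are absent from `keep`
    missing, unexpected = model.load_state_dict(keep, strict=False)
    return missing, unexpected
```
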
diff --git a/lib/pixielib/models/lbs.py b/lib/pixielib/models/lbs.py
index 2b5f9a648408f3a83670b9bce94f7a7a08de37ae..a2252a9a81c7e9ca3633a02cc08f3fafd5bd22cc 100755
--- a/lib/pixielib/models/lbs.py
+++ b/lib/pixielib/models/lbs.py
@@ -30,8 +30,7 @@ def rot_mat_to_euler(rot_mats):
# Calculates rotation matrix to euler angles
     # Careful for extreme cases of euler angles like [0.0, pi, 0.0]
- sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
- rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
+ sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
return torch.atan2(-rot_mats[:, 2, 0], sy)
@@ -86,15 +85,13 @@ def find_dynamic_lmk_idx_and_bcoords(
# aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
rot_mats = torch.index_select(pose, 1, head_kin_chain)
- rel_rot_mat = torch.eye(3, device=vertices.device,
- dtype=dtype).unsqueeze_(dim=0)
+ rel_rot_mat = torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0)
for idx in range(len(head_kin_chain)):
# rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat)
- y_rot_angle = torch.round(
- torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
- max=39)).to(dtype=torch.long)
+    y_rot_angle = torch.round(
+        torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, max=39)
+    ).to(dtype=torch.long)
# print(y_rot_angle[0])
neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
mask = y_rot_angle.lt(-39).to(dtype=torch.long)
@@ -102,8 +99,7 @@ def find_dynamic_lmk_idx_and_bcoords(
y_rot_angle = neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle
# print(y_rot_angle[0])
- dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0,
- y_rot_angle)
+ dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle)
dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle)
return dyn_lmk_faces_idx, dyn_lmk_b_coords
@@ -135,11 +131,11 @@ def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
batch_size, num_verts = vertices.shape[:2]
device = vertices.device
- lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
- batch_size, -1, 3)
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(batch_size, -1, 3)
- lmk_faces += (torch.arange(batch_size, dtype=torch.long,
- device=device).view(-1, 1, 1) * num_verts)
+ lmk_faces += (
+ torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
+ )
lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
@@ -211,13 +207,11 @@ def lbs(
# N x J x 3 x 3
ident = torch.eye(3, dtype=dtype, device=device)
if pose2rot:
- rot_mats = batch_rodrigues(pose.view(-1, 3),
- dtype=dtype).view([batch_size, -1, 3, 3])
+ rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
# (N x P) x (P, V * 3) -> N x V x 3
- pose_offsets = torch.matmul(pose_feature,
- posedirs).view(batch_size, -1, 3)
+ pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
else:
pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
rot_mats = pose.view(batch_size, -1, 3, 3)
@@ -234,12 +228,9 @@ def lbs(
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
# (N x V x (J + 1)) x (N x (J + 1) x 16)
num_joints = J_regressor.shape[0]
- T = torch.matmul(W, A.view(batch_size, num_joints,
- 16)).view(batch_size, -1, 4, 4)
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
- homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
- dtype=dtype,
- device=device)
+ homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device)
v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
@@ -318,8 +309,7 @@ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
- K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
- dim=1).view((batch_size, 3, 3))
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
@@ -335,9 +325,7 @@ def transform_mat(R, t):
- T: Bx4x4 Transformation matrix
"""
# No padding left or right, only add an extra row
- return torch.cat([F.pad(R, [0, 0, 0, 1]),
- F.pad(t, [0, 0, 0, 1], value=1)],
- dim=2)
+ return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
@@ -370,15 +358,13 @@ def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
rel_joints[:, 1:] -= joints[:, parents[1:]]
transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
- rel_joints.reshape(-1, 3, 1)).reshape(
- -1, joints.shape[1], 4, 4)
+ rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
transform_chain = [transforms_mat[:, 0]]
for i in range(1, parents.shape[0]):
# Subtract the joint location at the rest pose
# No need for rotation, since it's identity when at rest
- curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:,
- i])
+ curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i])
transform_chain.append(curr_res)
transforms = torch.stack(transform_chain, dim=1)
@@ -392,21 +378,22 @@ def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
joints_homogen = F.pad(joints, [0, 0, 0, 1])
rel_transforms = transforms - F.pad(
- torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
+ )
return posed_joints, rel_transforms
class JointsFromVerticesSelector(nn.Module):
-
def __init__(self, fname):
"""Selects extra joints from vertices"""
super(JointsFromVerticesSelector, self).__init__()
err_msg = ("Either pass a filename or triangle face ids, names and"
" barycentrics")
- assert fname is not None or (face_ids is not None and bcs is not None
- and names is not None), err_msg
+ assert fname is not None or (
+ face_ids is not None and bcs is not None and names is not None
+ ), err_msg
if fname is not None:
fname = os.path.expanduser(os.path.expandvars(fname))
with open(fname, "r") as f:
@@ -422,13 +409,11 @@ class JointsFromVerticesSelector(nn.Module):
assert len(bcs) == len(
face_ids
), "The number of barycentric coordinates must be equal to the faces"
- assert len(names) == len(
- face_ids), "The number of names must be equal to the number of "
+        assert len(names) == len(
+            face_ids
+        ), "The number of names must be equal to the number of faces"
self.names = names
self.register_buffer("bcs", torch.tensor(bcs, dtype=torch.float32))
- self.register_buffer("face_ids",
- torch.tensor(face_ids, dtype=torch.long))
+ self.register_buffer("face_ids", torch.tensor(face_ids, dtype=torch.long))
def extra_joint_names(self):
"""Returns the names of the extra joints"""
@@ -439,8 +424,7 @@ class JointsFromVerticesSelector(nn.Module):
return []
vertex_ids = faces[self.face_ids].reshape(-1)
# Should be BxNx3x3
- triangles = torch.index_select(vertices, 1, vertex_ids).reshape(
- -1, len(self.bcs), 3, 3)
+ triangles = torch.index_select(vertices, 1, vertex_ids).reshape(-1, len(self.bcs), 3, 3)
return (triangles * self.bcs[None, :, :, None]).sum(dim=2)
@@ -463,7 +447,6 @@ def to_np(array, dtype=np.float32):
class Struct(object):
-
def __init__(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
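
To make the `transform_mat` one-liner above concrete: `F.pad(R, [0, 0, 0, 1])` appends a row of zeros under each 3x3 rotation, `F.pad(t, [0, 0, 0, 1], value=1)` appends a 1 under each translation column, and concatenating along `dim=2` yields the homogeneous [R | t; 0 0 0 1] matrix. A tiny self-contained check (the rotation and translation values are arbitrary):

```python
import torch
import torch.nn.functional as F

R = torch.eye(3).unsqueeze(0)               # B x 3 x 3 rotation (identity here)
t = torch.tensor([[[1.0], [2.0], [3.0]]])   # B x 3 x 1 translation

T = torch.cat(
    [
        F.pad(R, [0, 0, 0, 1]),             # B x 4 x 3: zero row appended at the bottom
        F.pad(t, [0, 0, 0, 1], value=1),    # B x 4 x 1: trailing 1 appended
    ],
    dim=2,
)                                           # B x 4 x 4 homogeneous transform

assert T.shape == (1, 4, 4)
assert torch.allclose(T[0, 3], torch.tensor([0.0, 0.0, 0.0, 1.0]))
```
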
diff --git a/lib/pixielib/models/moderators.py b/lib/pixielib/models/moderators.py
index 8a14c472530787be97045a4e620e28cae051df65..3ab139ac2ad3e0cbd99c8e40dbf6136a37e53cb5 100644
--- a/lib/pixielib/models/moderators.py
+++ b/lib/pixielib/models/moderators.py
@@ -12,11 +12,7 @@ import torch.nn.functional as F
class TempSoftmaxFusion(nn.Module):
-
- def __init__(self,
- channels=[2048 * 2, 1024, 1],
- detach_inputs=False,
- detach_feature=False):
+ def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False):
super(TempSoftmaxFusion, self).__init__()
self.detach_inputs = detach_inputs
self.detach_feature = detach_feature
@@ -63,11 +59,7 @@ class TempSoftmaxFusion(nn.Module):
class GumbelSoftmaxFusion(nn.Module):
-
- def __init__(self,
- channels=[2048 * 2, 1024, 1],
- detach_inputs=False,
- detach_feature=False):
+ def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False):
super(GumbelSoftmaxFusion, self).__init__()
self.detach_inputs = detach_inputs
self.detach_feature = detach_feature
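
The moderator hunks above only touch the constructors, so the fusion rule itself is not visible here. As a rough, hypothetical sketch of what a temperature-softmax gate with a `[2048 * 2, 1024, ...]` MLP could look like (an assumption inferred from the constructor signature and from how `pixie.py` blends body and part features with the returned weight — not the actual PIXIE implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyFusion(nn.Module):
    """Hypothetical gate: score cat([f_body, f_part]) with an MLP, then blend the two features."""
    def __init__(self, channels=(2048 * 2, 1024, 2)):
        super().__init__()
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [nn.Linear(c_in, c_out), nn.ReLU(inplace=True)]
        self.mlp = nn.Sequential(*layers[:-1])    # no activation on the final logits

    def forward(self, f_body, f_part, temperature=1.0):
        logits = self.mlp(torch.cat([f_body, f_part], dim=1)) / temperature
        weight = F.softmax(logits, dim=1)                        # B x 2
        fused = f_body * weight[:, [0]] + f_part * weight[:, [1]]
        return fused, weight


fused, weight = ToyFusion()(torch.randn(4, 2048), torch.randn(4, 2048))
```
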
diff --git a/lib/pixielib/models/resnet.py b/lib/pixielib/models/resnet.py
index 72bbf174a9b0ff9a9d75010e70e8059326fb72e3..162bc655bff1bd3ca2058334de2e15660de8f5f5 100755
--- a/lib/pixielib/models/resnet.py
+++ b/lib/pixielib/models/resnet.py
@@ -22,16 +22,10 @@ from torchvision import models
class ResNet(nn.Module):
-
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNet, self).__init__()
- self.conv1 = nn.Conv2d(3,
- 64,
- kernel_size=7,
- stride=2,
- padding=3,
- bias=False)
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -98,12 +92,7 @@ class Bottleneck(nn.Module):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
- self.conv2 = nn.Conv2d(planes,
- planes,
- kernel_size=3,
- stride=stride,
- padding=1,
- bias=False)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
@@ -136,12 +125,7 @@ class Bottleneck(nn.Module):
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
- return nn.Conv2d(in_planes,
- out_planes,
- kernel_size=3,
- stride=stride,
- padding=1,
- bias=False)
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
@@ -196,8 +180,7 @@ def load_ResNet50Model():
model = ResNet(Bottleneck, [3, 4, 6, 3])
copy_parameter_from_resnet(
model,
- torchvision.models.resnet50(
- weights=models.ResNet50_Weights.DEFAULT).state_dict(),
+ torchvision.models.resnet50(weights=models.ResNet50_Weights.DEFAULT).state_dict(),
)
return model
@@ -206,8 +189,7 @@ def load_ResNet101Model():
model = ResNet(Bottleneck, [3, 4, 23, 3])
copy_parameter_from_resnet(
model,
- torchvision.models.resnet101(
- weights=models.ResNet101_Weights.DEFAULT).state_dict(),
+ torchvision.models.resnet101(weights=models.ResNet101_Weights.DEFAULT).state_dict(),
)
return model
@@ -216,8 +198,7 @@ def load_ResNet152Model():
model = ResNet(Bottleneck, [3, 8, 36, 3])
copy_parameter_from_resnet(
model,
- torchvision.models.resnet152(
- weights=models.ResNet152_Weights.DEFAULT).state_dict(),
+ torchvision.models.resnet152(weights=models.ResNet152_Weights.DEFAULT).state_dict(),
)
return model
@@ -229,7 +210,6 @@ def load_ResNet152Model():
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
-
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
@@ -247,11 +227,9 @@ class DoubleConv(nn.Module):
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
-
def __init__(self, in_channels, out_channels):
super().__init__()
- self.maxpool_conv = nn.Sequential(
- nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
+ self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
def forward(self, x):
return self.maxpool_conv(x)
@@ -259,20 +237,16 @@ class Down(nn.Module):
class Up(nn.Module):
"""Upscaling then double conv"""
-
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
- self.up = nn.Upsample(scale_factor=2,
- mode="bilinear",
- align_corners=True)
+ self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
else:
- self.up = nn.ConvTranspose2d(in_channels // 2,
- in_channels // 2,
- kernel_size=2,
- stride=2)
+ self.up = nn.ConvTranspose2d(
+ in_channels // 2, in_channels // 2, kernel_size=2, stride=2
+ )
self.conv = DoubleConv(in_channels, out_channels)
@@ -282,9 +256,7 @@ class Up(nn.Module):
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
- x1 = F.pad(
- x1,
- [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
@@ -293,7 +265,6 @@ class Up(nn.Module):
class OutConv(nn.Module):
-
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
@@ -303,7 +274,6 @@ class OutConv(nn.Module):
class UNet(nn.Module):
-
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
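
One detail worth calling out in the `Up.forward` hunk above is the spatial padding: the upsampled tensor is padded to the skip connection's height and width before the channel-wise concatenation, which is what keeps odd input sizes from breaking the U-Net. A standalone sketch of that size-matching step (the tensor shapes are arbitrary examples):

```python
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 8, 13, 17)   # upsampled decoder feature (odd sizes on purpose)
x2 = torch.randn(1, 8, 14, 18)   # encoder skip connection

diffY = x2.size(2) - x1.size(2)
diffX = x2.size(3) - x1.size(3)
# F.pad order is [left, right, top, bottom]; split each difference across both sides
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

fused = torch.cat([x2, x1], dim=1)
assert fused.shape == (1, 16, 14, 18)
```
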
diff --git a/lib/pixielib/pixie.py b/lib/pixielib/pixie.py
index 575ec41079d5867ea4e3cce8968eb6b5e0bb4e95..545bc46f92b73aff4037da7ff3c6ebeba2b4c361 100644
--- a/lib/pixielib/pixie.py
+++ b/lib/pixielib/pixie.py
@@ -33,7 +33,6 @@ from .utils.config import cfg
class PIXIE(object):
-
def __init__(self, config=None, device="cuda:0"):
if config is None:
self.cfg = cfg
@@ -45,10 +44,7 @@ class PIXIE(object):
self.param_list_dict = {}
for lst in self.cfg.params.keys():
param_list = cfg.params.get(lst)
- self.param_list_dict[lst] = {
- i: cfg.model.get("n_" + i)
- for i in param_list
- }
+ self.param_list_dict[lst] = {i: cfg.model.get("n_" + i) for i in param_list}
# Build the models
self._create_model()
@@ -97,24 +93,19 @@ class PIXIE(object):
self.Regressor = {}
for key in self.cfg.network.regressor.keys():
n_output = sum(self.param_list_dict[f"{key}_list"].values())
- channels = ([2048] + self.cfg.network.regressor.get(key).channels +
- [n_output])
+            channels = [2048] + self.cfg.network.regressor.get(key).channels + [n_output]
if self.cfg.network.regressor.get(key).type == "mlp":
self.Regressor[key] = MLP(channels=channels).to(self.device)
- self.model_dict[f"Regressor_{key}"] = self.Regressor[
- key].state_dict()
+ self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict()
# Build the extractors
# to extract separate head/left hand/right hand feature from body feature
self.Extractor = {}
for key in self.cfg.network.extractor.keys():
- channels = [
- 2048
- ] + self.cfg.network.extractor.get(key).channels + [2048]
+ channels = [2048] + self.cfg.network.extractor.get(key).channels + [2048]
if self.cfg.network.extractor.get(key).type == "mlp":
self.Extractor[key] = MLP(channels=channels).to(self.device)
- self.model_dict[f"Extractor_{key}"] = self.Extractor[
- key].state_dict()
+ self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict()
# Build the moderators
self.Moderator = {}
@@ -122,15 +113,13 @@ class PIXIE(object):
share_part = key.split("_")[0]
detach_inputs = self.cfg.network.moderator.get(key).detach_inputs
detach_feature = self.cfg.network.moderator.get(key).detach_feature
- channels = [2048 * 2
- ] + self.cfg.network.moderator.get(key).channels + [2]
+ channels = [2048 * 2] + self.cfg.network.moderator.get(key).channels + [2]
self.Moderator[key] = TempSoftmaxFusion(
detach_inputs=detach_inputs,
detach_feature=detach_feature,
channels=channels,
).to(self.device)
- self.model_dict[f"Moderator_{key}"] = self.Moderator[
- key].state_dict()
+ self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict()
# Build the SMPL-X body model, which we also use to represent faces and
# hands, using the relevant parts only
@@ -147,9 +136,7 @@ class PIXIE(object):
print(f"pixie trained model path: {model_path} does not exist!")
exit()
# eval mode
- for module in [
- self.Encoder, self.Regressor, self.Moderator, self.Extractor
- ]:
+ for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]:
for net in module.values():
net.eval()
@@ -185,14 +172,14 @@ class PIXIE(object):
# crop
cropper_key = "hand" if "hand" in part_key else part_key
points_scale = image.shape[-2:]
- cropped_image, tform = self.Cropper[cropper_key].crop(
- image, points_for_crop, points_scale)
+ cropped_image, tform = self.Cropper[cropper_key].crop(image, points_for_crop, points_scale)
         # transform points (must be normalized to [-1, 1]) accordingly
cropped_points_dict = {}
for points_key in points_dict.keys():
points = points_dict[points_key]
cropped_points = self.Cropper[cropper_key].transform_points(
- points, tform, points_scale, normalize=True)
+ points, tform, points_scale, normalize=True
+ )
cropped_points_dict[points_key] = cropped_points
return cropped_image, cropped_points_dict
@@ -244,8 +231,7 @@ class PIXIE(object):
# then predict share parameters
feature[key][f"{key}_share"] = feature[key][key]
share_dict = self.decompose_code(
- self.Regressor[f"{part}_share"](
- feature[key][f"{part}_share"]),
+ self.Regressor[f"{part}_share"](feature[key][f"{part}_share"]),
self.param_list_dict[f"{part}_share_list"],
)
# compose parameters
@@ -257,13 +243,16 @@ class PIXIE(object):
f_body = feature["body"]["body"]
# extract part feature
for part_name in ["head", "left_hand", "right_hand"]:
- feature["body"][f"{part_name}_share"] = self.Extractor[
- f"{part_name}_share"](f_body)
+ feature["body"][f"{part_name}_share"] = self.Extractor[f"{part_name}_share"](
+ f_body
+ )
# -- check if part crops are given, if not, crop parts by coarse body estimation
- if ("head_image" not in data[key].keys()
- or "left_hand_image" not in data[key].keys()
- or "right_hand_image" not in data[key].keys()):
+ if (
+ "head_image" not in data[key].keys() or
+ "left_hand_image" not in data[key].keys() or
+ "right_hand_image" not in data[key].keys()
+ ):
# - run without fusion to get coarse estimation, for cropping parts
# body only
body_dict = self.decompose_code(
@@ -272,29 +261,26 @@ class PIXIE(object):
)
# head share
head_share_dict = self.decompose_code(
- self.Regressor["head" + "_share"](
- feature[key]["head" + "_share"]),
+ self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
self.param_list_dict["head" + "_share_list"],
)
# right hand share
right_hand_share_dict = self.decompose_code(
- self.Regressor["hand" + "_share"](
- feature[key]["right_hand" + "_share"]),
+ self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
self.param_list_dict["hand" + "_share_list"],
)
# left hand share
left_hand_share_dict = self.decompose_code(
- self.Regressor["hand" + "_share"](
- feature[key]["left_hand" + "_share"]),
+ self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
self.param_list_dict["hand" + "_share_list"],
)
# change the dict name from right to left
- left_hand_share_dict[
- "left_hand_pose"] = left_hand_share_dict.pop(
- "right_hand_pose")
- left_hand_share_dict[
- "left_wrist_pose"] = left_hand_share_dict.pop(
- "right_wrist_pose")
+ left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop(
+ "right_hand_pose"
+ )
+ left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
+ "right_wrist_pose"
+ )
param_dict[key] = {
**body_dict,
**head_share_dict,
@@ -304,21 +290,18 @@ class PIXIE(object):
if body_only:
param_dict["moderator_weight"] = None
return param_dict
- prediction_body_only = self.decode(param_dict[key],
- param_type="body")
+ prediction_body_only = self.decode(param_dict[key], param_type="body")
# crop
for part_name in ["head", "left_hand", "right_hand"]:
part = part_name.split("_")[-1]
points_dict = {
- "smplx_kpt":
- prediction_body_only["smplx_kpt"],
- "trans_verts":
- prediction_body_only["transformed_vertices"],
+ "smplx_kpt": prediction_body_only["smplx_kpt"],
+ "trans_verts": prediction_body_only["transformed_vertices"],
}
- image_hd = torchvision.transforms.Resize(1024)(
- data["body"]["image"])
+ image_hd = torchvision.transforms.Resize(1024)(data["body"]["image"])
cropped_image, cropped_joints_dict = self.part_from_body(
- image_hd, part_name, points_dict)
+ image_hd, part_name, points_dict
+ )
data[key][part_name + "_image"] = cropped_image
# -- encode features from part crops, then fuse feature using the weight from moderator
@@ -338,16 +321,12 @@ class PIXIE(object):
self.Regressor[f"{part}_share"](f_part),
self.param_list_dict[f"{part}_share_list"],
)
- param_dict["body_" + part_name] = {
- **part_dict,
- **part_share_dict
- }
+ param_dict["body_" + part_name] = {**part_dict, **part_share_dict}
# moderator to assign weight, then integrate features
- f_body_out, f_part_out, f_weight = self.Moderator[
- f"{part}_share"](feature["body"][f"{part_name}_share"],
- f_part,
- work=True)
+ f_body_out, f_part_out, f_weight = self.Moderator[f"{part}_share"](
+ feature["body"][f"{part_name}_share"], f_part, work=True
+ )
if copy_and_paste:
# copy and paste strategy always trusts the results from part
feature["body"][f"{part_name}_share"] = f_part
@@ -355,8 +334,9 @@ class PIXIE(object):
# for hand, if part weight > 0.7 (very confident, then fully trust part)
part_w = f_weight[:, [1]]
part_w[part_w > 0.7] = 1.0
- f_body_out = (feature["body"][f"{part_name}_share"] *
- (1.0 - part_w) + f_part * part_w)
+ f_body_out = (
+ feature["body"][f"{part_name}_share"] * (1.0 - part_w) + f_part * part_w
+ )
feature["body"][f"{part_name}_share"] = f_body_out
else:
feature["body"][f"{part_name}_share"] = f_body_out
@@ -367,29 +347,24 @@ class PIXIE(object):
# -- predict parameters from fused body feature
# head share
head_share_dict = self.decompose_code(
- self.Regressor["head" + "_share"](feature[key]["head" +
- "_share"]),
+ self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
self.param_list_dict["head" + "_share_list"],
)
# right hand share
right_hand_share_dict = self.decompose_code(
- self.Regressor["hand" + "_share"](
- feature[key]["right_hand" + "_share"]),
+ self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
self.param_list_dict["hand" + "_share_list"],
)
# left hand share
left_hand_share_dict = self.decompose_code(
- self.Regressor["hand" + "_share"](
- feature[key]["left_hand" + "_share"]),
+ self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
self.param_list_dict["hand" + "_share_list"],
)
# change the dict name from right to left
- left_hand_share_dict[
- "left_hand_pose"] = left_hand_share_dict.pop(
- "right_hand_pose")
- left_hand_share_dict[
- "left_wrist_pose"] = left_hand_share_dict.pop(
- "right_wrist_pose")
+ left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop("right_hand_pose")
+ left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
+ "right_wrist_pose"
+ )
param_dict["body"] = {
**body_dict,
**head_share_dict,
@@ -403,10 +378,10 @@ class PIXIE(object):
if keep_local:
                 # for local changes that do not affect the whole body or produce unnatural poses, trust the part
param_dict[key]["exp"] = param_dict["body_head"]["exp"]
- param_dict[key]["right_hand_pose"] = param_dict[
- "body_right_hand"]["right_hand_pose"]
- param_dict[key]["left_hand_pose"] = param_dict[
- "body_left_hand"]["right_hand_pose"]
+ param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][
+ "right_hand_pose"]
+ param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][
+ "right_hand_pose"]
return param_dict
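
`decompose_code` itself is not touched by this diff, but from the call sites above it evidently splits a flat regressor output into named parameter chunks using the sizes stored in `param_list_dict`. A hypothetical sketch of such a splitter (function name, parameter names, and sizes below are illustrative only):

```python
import torch


def split_code(code, param_dims):
    """Slice a [B, sum(dims)] tensor into a dict of [B, dim] tensors, in dict order."""
    out, start = {}, 0
    for name, dim in param_dims.items():
        out[name] = code[:, start:start + dim]
        start += dim
    assert start == code.shape[1], "regressor output does not match the parameter list"
    return out


# e.g. a fake 9-dim prediction split into a 3-dim camera and a 6-dim global pose
parts = split_code(torch.randn(2, 9), {"body_cam": 3, "global_pose": 6})
```
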
@@ -426,75 +401,70 @@ class PIXIE(object):
if "pose" in key and "jaw" not in key:
param_dict[key] = converter.batch_cont2matrix(param_dict[key])
if param_type == "body" or param_type == "head":
- param_dict["jaw_pose"] = converter.batch_euler2matrix(
- param_dict["jaw_pose"])[:, None, :, :]
+ param_dict["jaw_pose"] = converter.batch_euler2matrix(param_dict["jaw_pose"]
+ )[:, None, :, :]
# complement params if it's not in given param dict
if param_type == "head":
batch_size = param_dict["shape"].shape[0]
param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
param_dict["global_pose"] = param_dict["head_pose"]
- param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(
- 0).expand(
- batch_size, -1, -1,
- -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]]
+ param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
- param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
- 0).expand(batch_size, -1, -1, -1)
- param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(
- 0).expand(batch_size, -1, -1, -1)
- param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
- 0).expand(batch_size, -1, -1, -1)
- param_dict[
- "right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(
- 0).expand(batch_size, -1, -1, -1)
+ batch_size, -1, -1, -1
+ )
+ param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )
+ param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )
+ param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )
+ param_dict["right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )
elif param_type == "hand":
batch_size = param_dict["right_hand_pose"].shape[0]
- param_dict["abs_right_wrist_pose"] = param_dict[
- "right_wrist_pose"].clone()
+ param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
dtype = param_dict["right_hand_pose"].dtype
device = param_dict["right_hand_pose"].device
- x_180_pose = (torch.eye(3, dtype=dtype,
- device=device).unsqueeze(0).repeat(
- 1, 1, 1))
+ x_180_pose = (torch.eye(3, dtype=dtype, device=device).unsqueeze(0).repeat(1, 1, 1))
x_180_pose[0, 2, 2] = -1.0
x_180_pose[0, 1, 1] = -1.0
- param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
- param_dict["shape"] = self.smplx.shape_params.expand(
- batch_size, -1)
- param_dict["exp"] = self.smplx.expression_params.expand(
- batch_size, -1)
+ param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
+ param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1)
+ param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1)
param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
+ batch_size, -1, -1, -1
+ )
param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
- param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(
- batch_size, -1, -1, -1)
- param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(
- 0).expand(
- batch_size, -1, -1,
- -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]]
- param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
- 0).expand(batch_size, -1, -1, -1)
- param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(
- 0).expand(batch_size, -1, -1, -1)
+ batch_size, -1, -1, -1
+ )
+ param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
+ param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
+ param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )
+ param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
+ batch_size, -1, -1, -1
+ )
elif param_type == "body":
             # the prediction from the head and hand share regressor is always absolute pose
batch_size = param_dict["shape"].shape[0]
param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
- param_dict["abs_right_wrist_pose"] = param_dict[
- "right_wrist_pose"].clone()
- param_dict["abs_left_wrist_pose"] = param_dict[
- "left_wrist_pose"].clone()
+ param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
+ param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone()
             # the body-hand share regressor works on the right hand, so the body network
             # is assumed to see a flipped feature for the left hand; the regressed parameters
             # are then flipped back to the left so they match the input left hand
- param_dict["left_wrist_pose"] = util.flip_pose(
- param_dict["left_wrist_pose"])
- param_dict["left_hand_pose"] = util.flip_pose(
- param_dict["left_hand_pose"])
+ param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"])
+ param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"])
else:
exit()
@@ -508,8 +478,7 @@ class PIXIE(object):
Returns:
predictions: smplx predictions
"""
- if "jaw_pose" in param_dict.keys() and len(
- param_dict["jaw_pose"].shape) == 2:
+ if "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2:
self.convert_pose(param_dict, param_type)
elif param_dict["right_wrist_pose"].shape[-1] == 6:
self.convert_pose(param_dict, param_type)
@@ -532,9 +501,8 @@ class PIXIE(object):
# change absolute head&hand pose to relative pose according to rest body pose
if param_type == "head" or param_type == "body":
param_dict["body_pose"] = self.smplx.pose_abs2rel(
- param_dict["global_pose"],
- param_dict["body_pose"],
- abs_joint="head")
+ param_dict["global_pose"], param_dict["body_pose"], abs_joint="head"
+ )
if param_type == "hand" or param_type == "body":
param_dict["body_pose"] = self.smplx.pose_abs2rel(
param_dict["global_pose"],
@@ -550,7 +518,7 @@ class PIXIE(object):
if self.cfg.model.check_pose:
# check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose)
# xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left)
- for pose_ind in [14]: # head [15-1, 20-1, 21-1]:
+ for pose_ind in [14]: # head [15-1, 20-1, 21-1]:
curr_pose = param_dict["body_pose"][:, pose_ind]
euler_pose = converter._compute_euler_from_matrix(curr_pose)
for i, max_angle in enumerate([20, 70, 10]):
@@ -560,9 +528,7 @@ class PIXIE(object):
min=-max_angle * np.pi / 180,
max=max_angle * np.pi / 180,
)] = 0.0
- param_dict[
- "body_pose"][:, pose_ind] = converter.batch_euler2matrix(
- euler_pose)
+ param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose)
# SMPLX
verts, landmarks, joints = self.smplx(
@@ -594,8 +560,8 @@ class PIXIE(object):
# change the order of face keypoints, to be the same as "standard" 68 keypoints
prediction["face_kpt"] = torch.cat(
- [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]],
- dim=1)
+ [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], dim=1
+ )
prediction.update(param_dict)
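
To make the `check_pose` branch above concrete: the decoded head joint is converted to Euler angles and any axis that leaves roughly ±20°/±70°/±10° (pitch/yaw/roll) is reset to zero before converting back to a rotation matrix. A self-contained sketch of just that clamp-to-zero step (the matrix-to-Euler conversions are omitted):

```python
import numpy as np
import torch


def zero_out_of_range(euler_pose, max_deg=(20.0, 70.0, 10.0)):
    """Zero any pitch/yaw/roll component outside its allowed range (angles in radians)."""
    euler_pose = euler_pose.clone()
    for i, max_angle in enumerate(max_deg):
        limit = max_angle * np.pi / 180.0
        out_of_range = (euler_pose[:, i] < -limit) | (euler_pose[:, i] > limit)
        euler_pose[:, i][out_of_range] = 0.0
    return euler_pose


angles = torch.tensor([[0.1, 1.5, -0.4]])    # yaw ~86 deg and roll ~-23 deg exceed the limits
print(zero_out_of_range(angles))             # tensor([[0.1000, 0.0000, 0.0000]])
```
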
diff --git a/lib/pixielib/utils/array_cropper.py b/lib/pixielib/utils/array_cropper.py
index 661146ec42d58c207d00182f162a88f1c594e3df..fbee84b6a6f0f3dcad7fcd6b33bf03faf56be625 100644
--- a/lib/pixielib/utils/array_cropper.py
+++ b/lib/pixielib/utils/array_cropper.py
@@ -23,15 +23,14 @@ def points2bbox(points, points_scale=None):
bottom = np.max(points[:, 1])
size = max(right - left, bottom - top)
# + old_size*0.1])
- center = np.array(
- [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
+ center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
return center, size
# translate center
def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.0):
trans_scale = (np.random.rand(2) * 2 - 1) * trans_scale
- center = center + trans_scale * bbox_size # 0.5
+ center = center + trans_scale * bbox_size # 0.5
scale = np.random.rand() * (scale[1] - scale[0]) + scale[0]
size = int(bbox_size * scale)
return center, size
@@ -48,27 +47,25 @@ def crop_array(image, center, bboxsize, crop_size):
tform: 3x3 affine matrix
"""
# points: top-left, top-right, bottom-right
- src_pts = np.array([
- [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
- [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
- [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
- ])
- DST_PTS = np.array([[0, 0], [crop_size - 1, 0],
- [crop_size - 1, crop_size - 1]])
+ src_pts = np.array(
+ [
+ [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
+ [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
+ [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
+ ]
+ )
+ DST_PTS = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]])
# estimate transformation between points
tform = estimate_transform("similarity", src_pts, DST_PTS)
# warp images
- cropped_image = warp(image,
- tform.inverse,
- output_shape=(crop_size, crop_size))
+ cropped_image = warp(image, tform.inverse, output_shape=(crop_size, crop_size))
return cropped_image, tform.params.T
class Cropper(object):
-
def __init__(self, crop_size, scale=[1, 1], trans_scale=0.0):
self.crop_size = crop_size
self.scale = scale
@@ -78,11 +75,9 @@ class Cropper(object):
# points to bbox
center, bbox_size = points2bbox(points, points_scale)
         # augment bbox.
- center, bbox_size = augment_bbox(center,
- bbox_size,
- scale=self.scale,
- trans_scale=self.trans_scale)
+ center, bbox_size = augment_bbox(
+ center, bbox_size, scale=self.scale, trans_scale=self.trans_scale
+ )
# crop
- cropped_image, tform = crop_array(image, center, bbox_size,
- self.crop_size)
+ cropped_image, tform = crop_array(image, center, bbox_size, self.crop_size)
return cropped_image, tform
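
For reference, `crop_array` above maps three bounding-box corners (top-left, top-right, bottom-right) onto the matching crop corners with a scikit-image similarity transform and then warps the image with its inverse. A minimal usage sketch of those two calls (the image and box values are made up):

```python
import numpy as np
from skimage.transform import estimate_transform, warp

image = np.random.rand(256, 256, 3)
center, bboxsize, crop_size = np.array([128.0, 128.0]), 120.0, 224

src_pts = np.array([
    [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
    [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
    [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
])
dst_pts = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]])

tform = estimate_transform("similarity", src_pts, dst_pts)
cropped = warp(image, tform.inverse, output_shape=(crop_size, crop_size))
assert cropped.shape[:2] == (224, 224)
```
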
diff --git a/lib/pixielib/utils/config.py b/lib/pixielib/utils/config.py
index 04d8ed809489dac385a1764aae0de0565dfe9d6d..115a38e9c52b7cf025defa4a3d37d9490fc71833 100644
--- a/lib/pixielib/utils/config.py
+++ b/lib/pixielib/utils/config.py
@@ -8,59 +8,59 @@ import os
cfg = CN()
-abs_pixie_dir = os.path.abspath(
- os.path.join(os.path.dirname(__file__), "..", "..", ".."))
+abs_pixie_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
cfg.pixie_dir = abs_pixie_dir
cfg.device = "cuda"
cfg.device_id = "0"
-cfg.pretrained_modelpath = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "pixie_model.tar")
+cfg.pretrained_modelpath = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "pixie_model.tar")
# smplx parameter settings
cfg.params = CN()
-cfg.params.body_list = [
- "body_cam", "global_pose", "partbody_pose", "neck_pose"
-]
+cfg.params.body_list = ["body_cam", "global_pose", "partbody_pose", "neck_pose"]
cfg.params.head_list = ["head_cam", "tex", "light"]
cfg.params.head_share_list = ["shape", "exp", "head_pose", "jaw_pose"]
cfg.params.hand_list = ["hand_cam"]
cfg.params.hand_share_list = [
"right_wrist_pose",
"right_hand_pose",
-] # only for right hand
+] # only for right hand
# ---------------------------------------------------------------------------- #
# Options for Body model
# ---------------------------------------------------------------------------- #
cfg.model = CN()
-cfg.model.topology_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "SMPL_X_template_FLAME_uv.obj")
-cfg.model.topology_smplxtex_path = os.path.join(cfg.pixie_dir,
- "data/HPS/pixie_data",
- "smplx_tex.obj")
-cfg.model.topology_smplx_hand_path = os.path.join(cfg.pixie_dir,
- "data/HPS/pixie_data",
- "smplx_hand.obj")
-cfg.model.smplx_model_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "SMPLX_NEUTRAL_2020.npz")
-cfg.model.face_mask_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "uv_face_mask.png")
-cfg.model.face_eye_mask_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "uv_face_eye_mask.png")
-cfg.model.tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "FLAME_albedo_from_BFM.npz")
-cfg.model.extra_joint_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "smplx_extra_joints.yaml")
-cfg.model.j14_regressor_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "SMPLX_to_J14.pkl")
-cfg.model.flame2smplx_cached_path = os.path.join(cfg.pixie_dir,
- "data/HPS/pixie_data",
- "flame2smplx_tex_1024.npy")
-cfg.model.smplx_tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "smplx_tex.png")
-cfg.model.mano_ids_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "MANO_SMPLX_vertex_ids.pkl")
-cfg.model.flame_ids_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data",
- "SMPL-X__FLAME_vertex_ids.npy")
+cfg.model.topology_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "SMPL_X_template_FLAME_uv.obj"
+)
+cfg.model.topology_smplxtex_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "smplx_tex.obj"
+)
+cfg.model.topology_smplx_hand_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "smplx_hand.obj"
+)
+cfg.model.smplx_model_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "SMPLX_NEUTRAL_2020.npz"
+)
+cfg.model.face_mask_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "uv_face_mask.png")
+cfg.model.face_eye_mask_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "uv_face_eye_mask.png"
+)
+cfg.model.tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "FLAME_albedo_from_BFM.npz")
+cfg.model.extra_joint_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "smplx_extra_joints.yaml"
+)
+cfg.model.j14_regressor_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "SMPLX_to_J14.pkl"
+)
+cfg.model.flame2smplx_cached_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "flame2smplx_tex_1024.npy"
+)
+cfg.model.smplx_tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "smplx_tex.png")
+cfg.model.mano_ids_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "MANO_SMPLX_vertex_ids.pkl"
+)
+cfg.model.flame_ids_path = os.path.join(
+ cfg.pixie_dir, "data/HPS/pixie_data", "SMPL-X__FLAME_vertex_ids.npy"
+)
cfg.model.uv_size = 256
cfg.model.n_shape = 200
cfg.model.n_tex = 50
@@ -68,16 +68,16 @@ cfg.model.n_exp = 50
cfg.model.n_body_cam = 3
cfg.model.n_head_cam = 3
cfg.model.n_hand_cam = 3
-cfg.model.tex_type = "BFM" # BFM, FLAME, albedoMM
-cfg.model.uvtex_type = "SMPLX" # FLAME or SMPLX
-cfg.model.use_tex = False # whether to use flame texture model
+cfg.model.tex_type = "BFM" # BFM, FLAME, albedoMM
+cfg.model.uvtex_type = "SMPLX" # FLAME or SMPLX
+cfg.model.use_tex = False # whether to use flame texture model
cfg.model.flame_tex_path = ""
# pose
cfg.model.n_global_pose = 3 * 2
cfg.model.n_head_pose = 3 * 2
cfg.model.n_neck_pose = 3 * 2
-cfg.model.n_jaw_pose = 3 # euler angle
+cfg.model.n_jaw_pose = 3 # euler angle
cfg.model.n_body_pose = 21 * 3 * 2
cfg.model.n_partbody_pose = (21 - 4) * 3 * 2
cfg.model.n_left_hand_pose = 15 * 3 * 2
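
Since this block only changes how the default paths are assembled, a short note on how the config is typically consumed may help: assuming `CN` is the usual `yacs.config.CfgNode` alias (the import is not part of this hunk), a caller can clone the defaults and override individual entries instead of editing the module. The values below are illustrative:

```python
from lib.pixielib.utils.config import cfg as pixie_cfg

my_cfg = pixie_cfg.clone()       # work on a copy so the module-level defaults stay intact
my_cfg.defrost()
my_cfg.device = "cuda"
my_cfg.device_id = "0"
my_cfg.model.uv_size = 512       # e.g. rasterize UV maps at a higher resolution
my_cfg.freeze()
```
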
diff --git a/lib/pixielib/utils/renderer.py b/lib/pixielib/utils/renderer.py
index d45e9ae0adefdc5dceab89e9bb6e71cca58a3630..eb2dc795e01b3e5c78a4ce848777d6cbc5558401 100755
--- a/lib/pixielib/utils/renderer.py
+++ b/lib/pixielib/utils/renderer.py
@@ -36,7 +36,7 @@ def set_rasterizer(type="pytorch3d"):
f"{curr_dir}/rasterizer/standard_rasterize_cuda_kernel.cu",
],
extra_cuda_cflags=["-std=c++14", "-ccbin=$$(which gcc-7)"],
- ) # cuda10.2 is not compatible with gcc9. Specify gcc 7
+ ) # cuda10.2 is not compatible with gcc9. Specify gcc 7
from standard_rasterize_cuda import standard_rasterize
         # If JIT does not work, try manual installation first
@@ -51,7 +51,6 @@ class StandardRasterizer(nn.Module):
can render non-squared image
not differentiable
"""
-
def __init__(self, height, width=None):
"""
use fixed raster_settings for rendering faces
@@ -80,15 +79,15 @@ class StandardRasterizer(nn.Module):
vertices[..., 2] = vertices[..., 2] * w / 2
f_vs = util.face_vertices(vertices, faces)
- standard_rasterize(f_vs, depth_buffer, triangle_buffer, baryw_buffer,
- h, w)
+ standard_rasterize(f_vs, depth_buffer, triangle_buffer, baryw_buffer, h, w)
pix_to_face = triangle_buffer[:, :, :, None].long()
bary_coords = baryw_buffer[:, :, :, None, :]
vismask = (pix_to_face > -1).float()
D = attributes.shape[-1]
attributes = attributes.clone()
- attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
- 3, attributes.shape[-1])
+ attributes = attributes.view(
+ attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1]
+ )
N, H, W, K, _ = bary_coords.shape
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
@@ -96,10 +95,9 @@ class StandardRasterizer(nn.Module):
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
- pixel_vals[mask] = 0 # Replace masked values in output.
+ pixel_vals[mask] = 0 # Replace masked values in output.
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
- pixel_vals = torch.cat(
- [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
+ pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
return pixel_vals
@@ -110,7 +108,6 @@ class Pytorch3dRasterizer(nn.Module):
x,y,z are in image space, normalized
can only render squared image now
"""
-
def __init__(self, image_size=224):
"""
use fixed raster_settings for rendering faces
@@ -130,8 +127,7 @@ class Pytorch3dRasterizer(nn.Module):
def forward(self, vertices, faces, attributes=None, h=None, w=None):
fixed_vertices = vertices.clone()
fixed_vertices[..., :2] = -fixed_vertices[..., :2]
- meshes_screen = Meshes(verts=fixed_vertices.float(),
- faces=faces.long())
+ meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
raster_settings = self.raster_settings
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
meshes_screen,
@@ -145,8 +141,9 @@ class Pytorch3dRasterizer(nn.Module):
vismask = (pix_to_face > -1).float()
D = attributes.shape[-1]
attributes = attributes.clone()
- attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
- 3, attributes.shape[-1])
+ attributes = attributes.view(
+ attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1]
+ )
N, H, W, K, _ = bary_coords.shape
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
@@ -154,20 +151,14 @@ class Pytorch3dRasterizer(nn.Module):
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
- pixel_vals[mask] = 0 # Replace masked values in output.
+ pixel_vals[mask] = 0 # Replace masked values in output.
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
- pixel_vals = torch.cat(
- [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
+ pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
return pixel_vals
class SRenderY(nn.Module):
-
- def __init__(self,
- image_size,
- obj_filename,
- uv_size=256,
- rasterizer_type="standard"):
+ def __init__(self, image_size, obj_filename, uv_size=256, rasterizer_type="standard"):
super(SRenderY, self).__init__()
self.image_size = image_size
self.uv_size = uv_size
@@ -176,8 +167,8 @@ class SRenderY(nn.Module):
self.rasterizer = Pytorch3dRasterizer(image_size)
self.uv_rasterizer = Pytorch3dRasterizer(uv_size)
verts, faces, aux = load_obj(obj_filename)
- uvcoords = aux.verts_uvs[None, ...] # (N, V, 2)
- uvfaces = faces.textures_idx[None, ...] # (N, F, 3)
+ uvcoords = aux.verts_uvs[None, ...] # (N, V, 2)
+ uvfaces = faces.textures_idx[None, ...] # (N, F, 3)
faces = faces.verts_idx[None, ...]
elif rasterizer_type == "standard":
self.rasterizer = StandardRasterizer(image_size)
@@ -192,15 +183,12 @@ class SRenderY(nn.Module):
# faces
dense_triangles = util.generate_triangles(uv_size, uv_size)
- self.register_buffer(
- "dense_faces",
- torch.from_numpy(dense_triangles).long()[None, :, :])
+ self.register_buffer("dense_faces", torch.from_numpy(dense_triangles).long()[None, :, :])
self.register_buffer("faces", faces)
self.register_buffer("raw_uvcoords", uvcoords)
# uv coords
- uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0.0 + 1.0],
- -1) # [bz, ntv, 3]
+ uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0.0 + 1.0], -1) # [bz, ntv, 3]
uvcoords = uvcoords * 2 - 1
uvcoords[..., 1] = -uvcoords[..., 1]
face_uvcoords = util.face_vertices(uvcoords, uvfaces)
@@ -209,26 +197,29 @@ class SRenderY(nn.Module):
self.register_buffer("face_uvcoords", face_uvcoords)
# shape colors, for rendering shape overlay
- colors = (torch.tensor([180, 180, 180])[None, None, :].repeat(
- 1,
- faces.max() + 1, 1).float() / 255.0)
+ colors = (
+ torch.tensor([180, 180, 180])[None, None, :].repeat(1,
+ faces.max() + 1, 1).float() / 255.0
+ )
face_colors = util.face_vertices(colors, faces)
self.register_buffer("vertex_colors", colors)
self.register_buffer("face_colors", face_colors)
# SH factors for lighting
pi = np.pi
- constant_factor = torch.tensor([
- 1 / np.sqrt(4 * pi),
- ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
- ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
- ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
- (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
- (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
- (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
- (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
- (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
- ]).float()
+ constant_factor = torch.tensor(
+ [
+ 1 / np.sqrt(4 * pi),
+ ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+ ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+ ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+ (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
+ ]
+ ).float()
self.register_buffer("constant_factor", constant_factor)
def forward(
@@ -256,23 +247,24 @@ class SRenderY(nn.Module):
batch_size = vertices.shape[0]
         # normalize z to 10-90 for rasterization (in pytorch3d, near far: 0-100)
transformed_vertices = transformed_vertices.clone()
- transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] -
- transformed_vertices[:, :, 2].min())
- transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] /
- transformed_vertices[:, :, 2].max())
+ transformed_vertices[:, :, 2] = (
+ transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min()
+ )
+ transformed_vertices[:, :, 2] = (
+ transformed_vertices[:, :, 2] / transformed_vertices[:, :, 2].max()
+ )
transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10
# attributes
- face_vertices = util.face_vertices(
- vertices, self.faces.expand(batch_size, -1, -1))
- normals = util.vertex_normals(vertices,
- self.faces.expand(batch_size, -1, -1))
- face_normals = util.face_vertices(
- normals, self.faces.expand(batch_size, -1, -1))
+ face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1))
+ normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1))
+ face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1))
transformed_normals = util.vertex_normals(
- transformed_vertices, self.faces.expand(batch_size, -1, -1))
+ transformed_vertices, self.faces.expand(batch_size, -1, -1)
+ )
transformed_face_normals = util.face_vertices(
- transformed_normals, self.faces.expand(batch_size, -1, -1))
+ transformed_normals, self.faces.expand(batch_size, -1, -1)
+ )
attributes = torch.cat(
[
self.face_uvcoords.expand(batch_size, -1, -1, -1),
@@ -314,38 +306,32 @@ class SRenderY(nn.Module):
if light_type == "point":
vertice_images = rendering[:, 6:9, :, :].detach()
shading = self.add_pointlight(
- vertice_images.permute(0, 2, 3,
- 1).reshape([batch_size, -1, 3]),
- normal_images.permute(0, 2, 3,
- 1).reshape([batch_size, -1, 3]),
+ vertice_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
+ normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
lights,
)
- shading_images = shading.reshape([
- batch_size, albedo_images.shape[2],
- albedo_images.shape[3], 3
- ]).permute(0, 3, 1, 2)
+ shading_images = shading.reshape(
+ [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]
+ ).permute(0, 3, 1, 2)
else:
shading = self.add_directionlight(
- normal_images.permute(0, 2, 3,
- 1).reshape([batch_size, -1, 3]),
+ normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
lights,
)
- shading_images = shading.reshape([
- batch_size, albedo_images.shape[2],
- albedo_images.shape[3], 3
- ]).permute(0, 3, 1, 2)
+ shading_images = shading.reshape(
+ [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]
+ ).permute(0, 3, 1, 2)
images = albedo_images * shading_images
else:
images = albedo_images
shading_images = images.detach() * 0.0
if background is None:
- images = images * alpha_images + torch.ones_like(images).to(
- vertices.device) * (1 - alpha_images)
+            images = images * alpha_images + torch.ones_like(images).to(
+                vertices.device
+            ) * (1 - alpha_images)
else:
# background = F.interpolate(background, [self.image_size, self.image_size])
- images = images * alpha_images + background.contiguous() * (
- 1 - alpha_images)
+ images = images * alpha_images + background.contiguous() * (1 - alpha_images)
outputs = {
"images": images,
@@ -379,11 +365,10 @@ class SRenderY(nn.Module):
3 * (N[:, 2]**2) - 1,
],
1,
- ) # [bz, 9, h, w]
+ ) # [bz, 9, h, w]
sh = sh * self.constant_factor[None, :, None, None]
# [bz, 9, 3, h, w]
- shading = torch.sum(
- sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1)
+ shading = torch.sum(sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1)
return shading
def add_pointlight(self, vertices, normals, lights):
@@ -395,14 +380,12 @@ class SRenderY(nn.Module):
"""
light_positions = lights[:, :, :3]
light_intensities = lights[:, :, 3:]
- directions_to_lights = F.normalize(light_positions[:, :, None, :] -
- vertices[:, None, :, :],
- dim=3)
+ directions_to_lights = F.normalize(
+ light_positions[:, :, None, :] - vertices[:, None, :, :], dim=3
+ )
# normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.)
- normals_dot_lights = (normals[:, None, :, :] *
- directions_to_lights).sum(dim=3)
- shading = normals_dot_lights[:, :, :,
- None] * light_intensities[:, :, None, :]
+ normals_dot_lights = (normals[:, None, :, :] * directions_to_lights).sum(dim=3)
+ shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :]
return shading.mean(1)
def add_directionlight(self, normals, lights):
@@ -415,16 +398,14 @@ class SRenderY(nn.Module):
light_direction = lights[:, :, :3]
light_intensities = lights[:, :, 3:]
directions_to_lights = F.normalize(
- light_direction[:, :, None, :].expand(-1, -1, normals.shape[1],
- -1),
- dim=3)
+ light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], -1), dim=3
+ )
# normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.)
# normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3)
normals_dot_lights = torch.clamp(
- (normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0.0,
- 1.0)
- shading = normals_dot_lights[:, :, :,
- None] * light_intensities[:, :, None, :]
+ (normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0.0, 1.0
+ )
+ shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :]
return shading.mean(1)
def render_shape(
@@ -445,36 +426,38 @@ class SRenderY(nn.Module):
"""
batch_size = vertices.shape[0]
if lights is None:
- light_positions = (torch.tensor([
- [-5, 5, -5],
- [5, 5, -5],
- [-5, -5, -5],
- [5, -5, -5],
- [0, 0, -5],
- ])[None, :, :].expand(batch_size, -1, -1).float())
+ light_positions = (
+ torch.tensor([
+ [-5, 5, -5],
+ [5, 5, -5],
+ [-5, -5, -5],
+ [5, -5, -5],
+ [0, 0, -5],
+ ])[None, :, :].expand(batch_size, -1, -1).float()
+ )
light_intensities = torch.ones_like(light_positions).float() * 1.7
- lights = torch.cat((light_positions, light_intensities),
- 2).to(vertices.device)
+ lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device)
         # normalize z to 10-90 for rasterization (in pytorch3d, near far: 0-100)
transformed_vertices = transformed_vertices.clone()
- transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] -
- transformed_vertices[:, :, 2].min())
- transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] /
- transformed_vertices[:, :, 2].max())
+ transformed_vertices[:, :, 2] = (
+ transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min()
+ )
+ transformed_vertices[:, :, 2] = (
+ transformed_vertices[:, :, 2] / transformed_vertices[:, :, 2].max()
+ )
transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10
# Attributes
- face_vertices = util.face_vertices(
- vertices, self.faces.expand(batch_size, -1, -1))
- normals = util.vertex_normals(vertices,
- self.faces.expand(batch_size, -1, -1))
- face_normals = util.face_vertices(
- normals, self.faces.expand(batch_size, -1, -1))
+ face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1))
+ normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1))
+ face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1))
transformed_normals = util.vertex_normals(
- transformed_vertices, self.faces.expand(batch_size, -1, -1))
+ transformed_vertices, self.faces.expand(batch_size, -1, -1)
+ )
transformed_face_normals = util.face_vertices(
- transformed_normals, self.faces.expand(batch_size, -1, -1))
+ transformed_normals, self.faces.expand(batch_size, -1, -1)
+ )
if colors is None:
colors = self.face_colors.expand(batch_size, -1, -1, -1)
attributes = torch.cat(
@@ -513,22 +496,22 @@ class SRenderY(nn.Module):
if uv_detail_normals is not None:
uvcoords_images = rendering[:, 12:15, :, :]
grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2]
- detail_normal_images = F.grid_sample(uv_detail_normals,
- grid,
- align_corners=False)
+ detail_normal_images = F.grid_sample(uv_detail_normals, grid, align_corners=False)
normal_images = detail_normal_images
shading = self.add_directionlight(
- normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
- lights)
- shading_images = (shading.reshape(
- [batch_size, albedo_images.shape[2], albedo_images.shape[3],
- 3]).permute(0, 3, 1, 2).contiguous())
+ normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), lights
+ )
+        shading_images = shading.reshape(
+            [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]
+        ).permute(0, 3, 1, 2).contiguous()
shaded_images = albedo_images * shading_images
if background is None:
- shape_images = shaded_images * alpha_images + torch.ones_like(
- shaded_images).to(vertices.device) * (1 - alpha_images)
+ shape_images = shaded_images * alpha_images + torch.ones_like(shaded_images).to(
+ vertices.device
+ ) * (1 - alpha_images)
else:
# background = F.interpolate(background, [self.image_size, self.image_size])
shape_images = shaded_images * alpha_images + background.contiguous(
@@ -548,18 +531,18 @@ class SRenderY(nn.Module):
transformed_vertices = transformed_vertices.clone()
batch_size = transformed_vertices.shape[0]
- transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] -
- transformed_vertices[:, :, 2].min())
+ transformed_vertices[:, :, 2] = (
+ transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min()
+ )
z = -transformed_vertices[:, :, 2:].repeat(1, 1, 3)
z = z - z.min()
z = z / z.max()
# Attributes
- attributes = util.face_vertices(z,
- self.faces.expand(batch_size, -1, -1))
+ attributes = util.face_vertices(z, self.faces.expand(batch_size, -1, -1))
# rasterize
- rendering = self.rasterizer(transformed_vertices,
- self.faces.expand(batch_size, -1, -1),
- attributes)
+ rendering = self.rasterizer(
+ transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes
+ )
####
alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()
@@ -574,14 +557,15 @@ class SRenderY(nn.Module):
transformed_vertices = transformed_vertices.clone()
batch_size = colors.shape[0]
         # normalize z to 10-90 for rasterization (in pytorch3d, near far: 0-100)
- transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] -
- transformed_vertices[:, :, 2].min())
- transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] /
- transformed_vertices[:, :, 2].max())
+ transformed_vertices[:, :, 2] = (
+ transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min()
+ )
+ transformed_vertices[:, :, 2] = (
+ transformed_vertices[:, :, 2] / transformed_vertices[:, :, 2].max()
+ )
transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10
# Attributes
- attributes = util.face_vertices(colors,
- self.faces.expand(batch_size, -1, -1))
+ attributes = util.face_vertices(colors, self.faces.expand(batch_size, -1, -1))
# rasterize
rendering = self.rasterizer(
transformed_vertices,
@@ -602,8 +586,7 @@ class SRenderY(nn.Module):
uv_vertices: [bz, 3, h, w]
"""
batch_size = vertices.shape[0]
- face_vertices = util.face_vertices(
- vertices, self.faces.expand(batch_size, -1, -1))
+ face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1))
uv_vertices = self.uv_rasterizer(
self.uvcoords.expand(batch_size, -1, -1),
self.uvfaces.expand(batch_size, -1, -1),
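Not part of the patch, but as a quick illustration of the depth handling reformatted above: before rasterizing colors, the z values are shifted and rescaled into [10, 90] so they sit inside the renderer's 0-100 near/far range. A minimal sketch with made-up depths:

```python
import torch

# Illustrative sketch (not library code): the z-normalization used before
# rasterization maps per-vertex depths into [10, 90], matching the
# "near far: 0-100" note in the comment above.
z = torch.randn(2, 6890, 1) * 0.3 + 2.0   # hypothetical per-vertex depths
z = z - z.min()                           # shift so the minimum is 0
z = z / z.max()                           # rescale into [0, 1]
z = z * 80 + 10                           # final range: [10, 90]
print(z.min().item(), z.max().item())     # ~10.0, ~90.0
```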
diff --git a/lib/pixielib/utils/rotation_converter.py b/lib/pixielib/utils/rotation_converter.py
index 257e4eb2c12c242657d0275825717c48c56b5948..f8057cab4e0f84d035a0b8f964823bd61e91dae4 100644
--- a/lib/pixielib/utils/rotation_converter.py
+++ b/lib/pixielib/utils/rotation_converter.py
@@ -27,8 +27,7 @@ def rad2deg(tensor):
>>> output = tgm.rad2deg(input)
"""
if not torch.is_tensor(tensor):
- raise TypeError("Input type is not a torch.Tensor. Got {}".format(
- type(tensor)))
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(tensor)))
return 180.0 * tensor / pi.to(tensor.device).type(tensor.dtype)
@@ -50,8 +49,7 @@ def deg2rad(tensor):
>>> output = tgm.deg2rad(input)
"""
if not torch.is_tensor(tensor):
- raise TypeError("Input type is not a torch.Tensor. Got {}".format(
- type(tensor)))
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(tensor)))
return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.0
@@ -102,13 +100,12 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
>>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
"""
if not torch.is_tensor(rotation_matrix):
- raise TypeError("Input type is not a torch.Tensor. Got {}".format(
- type(rotation_matrix)))
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))
if len(rotation_matrix.shape) > 3:
raise ValueError(
- "Input size must be a three dimensional tensor. Got {}".format(
- rotation_matrix.shape))
+ "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape)
+ )
# if not rotation_matrix.shape[-2:] == (3, 4):
# raise ValueError(
# "Input size must be a N x 3 x 4 tensor. Got {}".format(
@@ -179,9 +176,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
- q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 +
- t2_rep * mask_c2 # noqa
- + t3_rep * mask_c3) # noqa
+    q /= torch.sqrt(
+        t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2
+        + t3_rep * mask_c3
+    )
q *= 0.5
return q
@@ -206,13 +204,12 @@ def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor:
>>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3
"""
if not torch.is_tensor(angle_axis):
- raise TypeError("Input type is not a torch.Tensor. Got {}".format(
- type(angle_axis)))
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(angle_axis)))
if not angle_axis.shape[-1] == 3:
raise ValueError(
- "Input must be a tensor of shape Nx3 or 3. Got {}".format(
- angle_axis.shape))
+ "Input must be a tensor of shape Nx3 or 3. Got {}".format(angle_axis.shape)
+ )
# unpack input and compute conversion
a0: torch.Tensor = angle_axis[..., 0:1]
a1: torch.Tensor = angle_axis[..., 1:2]
@@ -249,9 +246,7 @@ def quaternion_to_rotation_matrix(quat):
"""
norm_quat = quat
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
- w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:,
- 2], norm_quat[:,
- 3]
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
B = quat.size(0)
@@ -296,13 +291,12 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor):
>>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
"""
if not torch.is_tensor(quaternion):
- raise TypeError("Input type is not a torch.Tensor. Got {}".format(
- type(quaternion)))
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))
if not quaternion.shape[-1] == 4:
raise ValueError(
- "Input must be a tensor of shape Nx4 or 4. Got {}".format(
- quaternion.shape))
+ "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
+ )
# unpack input and compute conversion
q1: torch.Tensor = quaternion[..., 1]
q2: torch.Tensor = quaternion[..., 2]
@@ -318,12 +312,10 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor):
)
k_pos: torch.Tensor = two_theta / sin_theta
- k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(
- quaternion.device)
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(quaternion.device)
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
- angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(
- quaternion.device)[..., :3]
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(quaternion.device)[..., :3]
angle_axis[..., 0] += q1 * k
angle_axis[..., 1] += q2 * k
angle_axis[..., 2] += q3 * k
@@ -408,10 +400,10 @@ def _compute_euler_from_matrix(dcm, seq="xyz", extrinsic=False):
# 5b
safe_mask = torch.logical_and(safe1, safe2)
- angles[safe_mask, 0] = torch.atan2(dcm_transformed[safe_mask, 0, 2],
- -dcm_transformed[safe_mask, 1, 2])
- angles[safe_mask, 2] = torch.atan2(dcm_transformed[safe_mask, 2, 0],
- dcm_transformed[safe_mask, 2, 1])
+ angles[safe_mask,
+ 0] = torch.atan2(dcm_transformed[safe_mask, 0, 2], -dcm_transformed[safe_mask, 1, 2])
+ angles[safe_mask,
+ 2] = torch.atan2(dcm_transformed[safe_mask, 2, 0], dcm_transformed[safe_mask, 2, 1])
if extrinsic:
# For extrinsic, set first angle to zero so that after reversal we
# ensure that third angle is zero
@@ -448,8 +440,7 @@ def _compute_euler_from_matrix(dcm, seq="xyz", extrinsic=False):
adjust_mask = torch.logical_or(angles[:, 1] < 0, angles[:, 1] > np.pi)
else:
# lambda = + or - pi/2, so we can ensure angle2 -> [-pi/2, pi/2]
- adjust_mask = torch.logical_or(angles[:, 1] < -np.pi / 2,
- angles[:, 1] > np.pi / 2)
+ adjust_mask = torch.logical_or(angles[:, 1] < -np.pi / 2, angles[:, 1] > np.pi / 2)
# Dont adjust gimbal locked angle sequences
adjust_mask = torch.logical_and(adjust_mask, safe_mask)
@@ -463,8 +454,10 @@ def _compute_euler_from_matrix(dcm, seq="xyz", extrinsic=False):
# Step 8
if not torch.all(safe_mask):
- print("Gimbal lock detected. Setting third angle to zero since"
- "it is not possible to uniquely determine all angles.")
+ print(
+ "Gimbal lock detected. Setting third angle to zero since"
+ "it is not possible to uniquely determine all angles."
+ )
# Reverse role of extrinsic and intrinsic rotations, but let third angle be
# zero for gimbal locked cases
@@ -497,8 +490,7 @@ def batch_matrix2euler(rot_mats):
# Careful for extreme cases of eular angles like [0.0, pi, 0.0]
# only y biw
# TODO: add x, z
- sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
- rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
+ sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
return torch.atan2(-rot_mats[:, 2, 0], sy)
@@ -550,8 +542,7 @@ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
- K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
- dim=1).view((batch_size, 3, 3))
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
@@ -571,9 +562,7 @@ def batch_cont2matrix(module_input):
# Normalize the first vector
b1 = F.normalize(reshaped_input[:, :, 0].clone(), dim=1)
- dot_prod = torch.sum(b1 * reshaped_input[:, :, 1].clone(),
- dim=1,
- keepdim=True)
+ dot_prod = torch.sum(b1 * reshaped_input[:, :, 1].clone(), dim=1, keepdim=True)
# Compute the second vector by finding the orthogonal complement to it
b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=1)
# Finish building the basis by taking the cross product
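For readers skimming the reformatted rotation utilities, here is a self-contained sketch of the Rodrigues step that batch_rodrigues performs (illustrative only; it mirrors the structure of the code above rather than importing it):

```python
import torch

def rodrigues(rot_vec, eps=1e-8):
    # Axis-angle (B, 3) -> rotation matrices (B, 3, 3) via the Rodrigues formula.
    angle = torch.norm(rot_vec + eps, dim=1, keepdim=True)   # (B, 1)
    axis = rot_vec / angle                                   # unit rotation axis
    cos = torch.cos(angle)[:, :, None]                       # (B, 1, 1)
    sin = torch.sin(angle)[:, :, None]
    rx, ry, rz = torch.split(axis, 1, dim=1)
    zeros = torch.zeros_like(rx)
    # Skew-symmetric cross-product matrix K of the axis
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(-1, 3, 3)
    ident = torch.eye(3).unsqueeze(0)
    return ident + sin * K + (1 - cos) * torch.bmm(K, K)

R = rodrigues(torch.tensor([[0.0, 0.0, 3.14159 / 2]]))  # ~90 deg about z
```

Passing [0, 0, pi/2] yields a 90-degree rotation about z, i.e. R is approximately [[0, -1, 0], [1, 0, 0], [0, 0, 1]].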
diff --git a/lib/pixielib/utils/tensor_cropper.py b/lib/pixielib/utils/tensor_cropper.py
index 21520901c79d4b690ab3747df6a7c820bf5de951..c486f7709ad9216080102ee275f7165d276eb0ce 100644
--- a/lib/pixielib/utils/tensor_cropper.py
+++ b/lib/pixielib/utils/tensor_cropper.py
@@ -34,21 +34,14 @@ def points2bbox(points, points_scale=None):
def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.0):
batch_size = center.shape[0]
- trans_scale = (torch.rand([batch_size, 2], device=center.device) * 2.0 -
- 1.0) * trans_scale
- center = center + trans_scale * bbox_size # 0.5
- scale = (torch.rand([batch_size, 1], device=center.device) *
- (scale[1] - scale[0]) + scale[0])
+ trans_scale = (torch.rand([batch_size, 2], device=center.device) * 2.0 - 1.0) * trans_scale
+ center = center + trans_scale * bbox_size # 0.5
+ scale = (torch.rand([batch_size, 1], device=center.device) * (scale[1] - scale[0]) + scale[0])
size = bbox_size * scale
return center, size
-def crop_tensor(image,
- center,
- bbox_size,
- crop_size,
- interpolation="bilinear",
- align_corners=False):
+def crop_tensor(image, center, bbox_size, crop_size, interpolation="bilinear", align_corners=False):
"""for batch image
Args:
image (torch.Tensor): the reference tensor of shape BXHxWXC.
@@ -66,11 +59,12 @@ def crop_tensor(image,
device = image.device
batch_size = image.shape[0]
# points: top-left, top-right, bottom-right, bottom-left
- src_pts = (torch.zeros([4, 2], dtype=dtype,
- device=device).unsqueeze(0).expand(
- batch_size, -1, -1).contiguous())
+ src_pts = (
+ torch.zeros([4, 2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1,
+ -1).contiguous()
+ )
- src_pts[:, 0, :] = center - bbox_size * 0.5 # / (self.crop_size - 1)
+ src_pts[:, 0, :] = center - bbox_size * 0.5 # / (self.crop_size - 1)
src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
src_pts[:, 2, :] = center + bbox_size * 0.5
@@ -107,7 +101,6 @@ def crop_tensor(image,
class Cropper(object):
-
def __init__(self, crop_size, scale=[1, 1], trans_scale=0.0):
self.crop_size = crop_size
self.scale = scale
@@ -116,21 +109,14 @@ class Cropper(object):
def crop(self, image, points, points_scale=None):
# points to bbox
center, bbox_size = points2bbox(points.clone(), points_scale)
- # argument bbox. TODO: add rotation?
- center, bbox_size = augment_bbox(center,
- bbox_size,
- scale=self.scale,
- trans_scale=self.trans_scale)
+ center, bbox_size = augment_bbox(
+ center, bbox_size, scale=self.scale, trans_scale=self.trans_scale
+ )
# crop
- cropped_image, tform = crop_tensor(image, center, bbox_size,
- self.crop_size)
+ cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size)
return cropped_image, tform
- def transform_points(self,
- points,
- tform,
- points_scale=None,
- normalize=True):
+ def transform_points(self, points, tform, points_scale=None, normalize=True):
points_2d = points[:, :, :2]
#'input points must use original range'
@@ -153,11 +139,9 @@ class Cropper(object):
),
tform,
)
- trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]],
- dim=-1)
+ trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], dim=-1)
if normalize:
- trans_points[:, :, :
- 2] = trans_points[:, :, :2] / self.crop_size * 2 - 1
+ trans_points[:, :, :2] = trans_points[:, :, :2] / self.crop_size * 2 - 1
return trans_points
@@ -174,14 +158,11 @@ def transform_points(points, tform, points_scale=None):
torch.cat(
[
points_2d,
- torch.ones([batch_size, n_points, 1],
- device=points.device,
- dtype=points.dtype),
+ torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype),
],
dim=-1,
),
tform,
)
- trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]],
- dim=-1)
+ trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], dim=-1)
return trans_points
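A tiny sketch (outside the patch) of the normalization at the end of transform_points: crop-space pixel coordinates are mapped into [-1, 1], the range expected by grid-sample-style lookups. The crop size of 224 used here is a hypothetical example value:

```python
import torch

crop_size = 224                                                         # hypothetical crop resolution
pts_px = torch.tensor([[[0.0, 0.0], [112.0, 112.0], [224.0, 224.0]]])   # (B, N, 2) in pixels
pts_norm = pts_px / crop_size * 2 - 1                                   # -> [[-1, -1], [0, 0], [1, 1]]
print(pts_norm)
```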
diff --git a/lib/pixielib/utils/util.py b/lib/pixielib/utils/util.py
index 1e8ec90836ce198186b0842d3eadee1b02318a2d..566eda3a6e6ddf7f236bf4e20bf7220b39981ce3 100755
--- a/lib/pixielib/utils/util.py
+++ b/lib/pixielib/utils/util.py
@@ -46,8 +46,7 @@ def face_vertices(vertices, faces):
bs, nv = vertices.shape[:2]
bs, nf = faces.shape[:2]
device = vertices.device
- faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
- nv)[:, None, None]
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
vertices = vertices.reshape((bs * nv, 3))
# pytorch only supports long and byte tensors for indexing
return vertices[faces.long()]
@@ -71,9 +70,8 @@ def vertex_normals(vertices, faces):
normals = torch.zeros(bs * nv, 3).to(device)
faces = (
- faces +
- (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
- ) # expanded faces
+ faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+ ) # expanded faces
vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()]
faces = faces.reshape(-1, 3)
@@ -145,12 +143,10 @@ def flip_pose(pose_vector, pose_format="rot-mat"):
# -------------------------------------- image processing
# ref: https://torchgeometry.readthedocs.io/en/latest/_modules/kornia/filters
def gaussian(window_size, sigma):
-
def gauss_fcn(x):
return -((x - window_size // 2)**2) / float(2 * sigma**2)
- gauss = torch.stack(
- [torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)])
+ gauss = torch.stack([torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)])
return gauss / gauss.sum()
@@ -175,10 +171,11 @@ def get_gaussian_kernel(kernel_size: int, sigma: float):
>>> kornia.image.get_gaussian_kernel(5, 1.5)
tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
"""
- if not isinstance(kernel_size,
- int) or kernel_size % 2 == 0 or kernel_size <= 0:
- raise TypeError("kernel_size must be an odd positive integer. "
- "Got {}".format(kernel_size))
+ if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
+ raise TypeError(
+ "kernel_size must be an odd positive integer. "
+ "Got {}".format(kernel_size)
+ )
window_1d = gaussian(kernel_size, sigma)
return window_1d
@@ -211,18 +208,14 @@ def get_gaussian_kernel2d(kernel_size, sigma):
[0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
"""
if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
- raise TypeError(
- "kernel_size must be a tuple of length two. Got {}".format(
- kernel_size))
+ raise TypeError("kernel_size must be a tuple of length two. Got {}".format(kernel_size))
if not isinstance(sigma, tuple) or len(sigma) != 2:
- raise TypeError(
- "sigma must be a tuple of length two. Got {}".format(sigma))
+ raise TypeError("sigma must be a tuple of length two. Got {}".format(sigma))
ksize_x, ksize_y = kernel_size
sigma_x, sigma_y = sigma
kernel_x = get_gaussian_kernel(ksize_x, sigma_x)
kernel_y = get_gaussian_kernel(ksize_y, sigma_y)
- kernel_2d = torch.matmul(kernel_x.unsqueeze(-1),
- kernel_y.unsqueeze(-1).t())
+ kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
return kernel_2d
@@ -283,10 +276,8 @@ def get_laplacian_kernel2d(kernel_size: int):
[ 1., 1., 1., 1., 1.]])
"""
- if not isinstance(kernel_size,
- int) or kernel_size % 2 == 0 or kernel_size <= 0:
- raise TypeError("ksize must be an odd positive integer. Got {}".format(
- kernel_size))
+ if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
+ raise TypeError("ksize must be an odd positive integer. Got {}".format(kernel_size))
kernel = torch.ones((kernel_size, kernel_size))
mid = kernel_size // 2
@@ -309,7 +300,6 @@ def laplacian(x):
def copy_state_dict(cur_state_dict, pre_state_dict, prefix="", load_name=None):
-
def _get_params(key):
key = prefix + key
if key in pre_state_dict:
@@ -353,7 +343,7 @@ def remove_module(state_dict):
# create new OrderedDict that does not contain `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
- name = k[7:] # remove `module.`
+ name = k[7:] # remove `module.`
new_state_dict[name] = v
return new_state_dict
@@ -433,24 +423,24 @@ def write_obj(
# write vertices
if colors is None:
for i in range(vertices.shape[0]):
- f.write("v {} {} {}\n".format(vertices[i, 0], vertices[i, 1],
- vertices[i, 2]))
+ f.write("v {} {} {}\n".format(vertices[i, 0], vertices[i, 1], vertices[i, 2]))
else:
for i in range(vertices.shape[0]):
- f.write("v {} {} {} {} {} {}\n".format(
- vertices[i, 0],
- vertices[i, 1],
- vertices[i, 2],
- colors[i, 0],
- colors[i, 1],
- colors[i, 2],
- ))
+ f.write(
+ "v {} {} {} {} {} {}\n".format(
+ vertices[i, 0],
+ vertices[i, 1],
+ vertices[i, 2],
+ colors[i, 0],
+ colors[i, 1],
+ colors[i, 2],
+ )
+ )
# write uv coords
if texture is None:
for i in range(faces.shape[0]):
- f.write("f {} {} {}\n".format(faces[i, 0], faces[i, 1],
- faces[i, 2]))
+ f.write("f {} {} {}\n".format(faces[i, 0], faces[i, 1], faces[i, 2]))
else:
for i in range(uvcoords.shape[0]):
f.write("vt {} {}\n".format(uvcoords[i, 0], uvcoords[i, 1]))
@@ -458,37 +448,37 @@ def write_obj(
# write f: ver ind/ uv ind
uvfaces = uvfaces + 1
for i in range(faces.shape[0]):
- f.write("f {}/{} {}/{} {}/{}\n".format(
- faces[i, 0],
- uvfaces[i, 0],
- faces[i, 1],
- uvfaces[i, 1],
- faces[i, 2],
- uvfaces[i, 2],
- ))
+ f.write(
+ "f {}/{} {}/{} {}/{}\n".format(
+ faces[i, 0],
+ uvfaces[i, 0],
+ faces[i, 1],
+ uvfaces[i, 1],
+ faces[i, 2],
+ uvfaces[i, 2],
+ )
+ )
# write mtl
with open(mtl_name, "w") as f:
f.write("newmtl %s\n" % material_name)
- s = "map_Kd {}\n".format(
- os.path.basename(texture_name)) # map to image
+ s = "map_Kd {}\n".format(os.path.basename(texture_name)) # map to image
f.write(s)
if normal_map is not None:
if torch.is_tensor(normal_map):
- normal_map = normal_map.detach().cpu().numpy().squeeze(
- )
+ normal_map = normal_map.detach().cpu().numpy().squeeze()
normal_map = np.transpose(normal_map, (1, 2, 0))
name, _ = os.path.splitext(obj_name)
normal_name = f"{name}_normals.png"
f.write(f"disp {normal_name}")
- out_normal_map = normal_map / (np.linalg.norm(
- normal_map, axis=-1, keepdims=True) + 1e-9)
+ out_normal_map = normal_map / (
+ np.linalg.norm(normal_map, axis=-1, keepdims=True) + 1e-9
+ )
out_normal_map = (out_normal_map + 1) * 0.5
- cv2.imwrite(normal_name, (out_normal_map * 255).astype(
- np.uint8)[:, :, ::-1])
+ cv2.imwrite(normal_name, (out_normal_map * 255).astype(np.uint8)[:, :, ::-1])
cv2.imwrite(texture_name, texture)
@@ -523,20 +513,20 @@ def load_obj(obj_filename):
for line in lines:
tokens = line.strip().split()
- if line.startswith("v "): # Line is a vertex.
+ if line.startswith("v "): # Line is a vertex.
vert = [float(x) for x in tokens[1:4]]
if len(vert) != 3:
msg = "Vertex %s does not have 3 values. Line: %s"
raise ValueError(msg % (str(vert), str(line)))
verts.append(vert)
- elif line.startswith("vt "): # Line is a texture.
+ elif line.startswith("vt "): # Line is a texture.
tx = [float(x) for x in tokens[1:3]]
if len(tx) != 2:
raise ValueError(
- "Texture %s does not have 2 values. Line: %s" %
- (str(tx), str(line)))
+ "Texture %s does not have 2 values. Line: %s" % (str(tx), str(line))
+ )
uvcoords.append(tx)
- elif line.startswith("f "): # Line is a face.
+ elif line.startswith("f "): # Line is a face.
# Update face properties info.
face = tokens[1:]
face_list = [f.split("/") for f in face]
@@ -558,12 +548,7 @@ def load_obj(obj_filename):
# ---------------------------------- visualization
-def draw_rectangle(img,
- bbox,
- bbox_color=(255, 255, 255),
- thickness=3,
- is_opaque=False,
- alpha=0.5):
+def draw_rectangle(img, bbox, bbox_color=(255, 255, 255), thickness=3, is_opaque=False, alpha=0.5):
"""Draws the rectangle around the object
borrowed from: https://bbox-visualizer.readthedocs.io/en/latest/_modules/bbox_visualizer/bbox_visualizer.html
Parameters
@@ -589,13 +574,11 @@ def draw_rectangle(img,
output = img.copy()
if not is_opaque:
- cv2.rectangle(output, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
- bbox_color, thickness)
+ cv2.rectangle(output, (bbox[0], bbox[1]), (bbox[2], bbox[3]), bbox_color, thickness)
else:
overlay = img.copy()
- cv2.rectangle(overlay, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
- bbox_color, -1)
+ cv2.rectangle(overlay, (bbox[0], bbox[1]), (bbox[2], bbox[3]), bbox_color, -1)
# cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output)
return output
@@ -607,9 +590,9 @@ def plot_bbox(image, bbox):
image: the input image
bbox: [left, top, right, bottom]
"""
- image = cv2.rectangle(image.copy(), (bbox[1], bbox[0]), (bbox[3], bbox[2]),
- [0, 255, 0],
- thickness=3)
+ image = cv2.rectangle(
+ image.copy(), (bbox[1], bbox[0]), (bbox[3], bbox[2]), [0, 255, 0], thickness=3
+ )
# image = draw_rectangle(image, bbox, bbox_color=[0,255,0])
return image
@@ -644,8 +627,7 @@ def plot_kpts(image, kpts, color="r"):
if i in end_list:
continue
ed = kpts[i + 1, :2]
- image = cv2.line(image, (st[0], st[1]), (ed[0], ed[1]),
- (255, 255, 255), 1)
+ image = cv2.line(image, (st[0], st[1]), (ed[0], ed[1]), (255, 255, 255), 1)
return image
@@ -674,11 +656,7 @@ def plot_verts(image, kpts, color="r"):
return image
-def tensor_vis_landmarks(images,
- landmarks,
- gt_landmarks=None,
- color="g",
- isScale=True):
+def tensor_vis_landmarks(images, landmarks, gt_landmarks=None, color="g", isScale=True):
# visualize landmarks
vis_landmarks = []
images = images.cpu().numpy()
@@ -690,8 +668,7 @@ def tensor_vis_landmarks(images,
image = image.transpose(1, 2, 0)[:, :, [2, 1, 0]].copy()
image = image * 255
if isScale:
- predicted_landmark = (predicted_landmarks[i] * image.shape[0] / 2 +
- image.shape[0] / 2)
+ predicted_landmark = (predicted_landmarks[i] * image.shape[0] / 2 + image.shape[0] / 2)
else:
predicted_landmark = predicted_landmarks[i]
if predicted_landmark.shape[0] == 68:
@@ -699,8 +676,7 @@ def tensor_vis_landmarks(images,
if gt_landmarks is not None:
image_landmarks = plot_verts(
image_landmarks,
- gt_landmarks_np[i] * image.shape[0] / 2 +
- image.shape[0] / 2,
+ gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2,
"r",
)
else:
@@ -708,14 +684,13 @@ def tensor_vis_landmarks(images,
if gt_landmarks is not None:
image_landmarks = plot_verts(
image_landmarks,
- gt_landmarks_np[i] * image.shape[0] / 2 +
- image.shape[0] / 2,
+ gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2,
"r",
)
vis_landmarks.append(image_landmarks)
vis_landmarks = np.stack(vis_landmarks)
- vis_landmarks = (torch.from_numpy(
- vis_landmarks[:, :, :, [2, 1, 0]].transpose(0, 3, 1, 2)) / 255.0
- ) # , dtype=torch.float32)
+ vis_landmarks = (
+ torch.from_numpy(vis_landmarks[:, :, :, [2, 1, 0]].transpose(0, 3, 1, 2)) / 255.0
+ ) # , dtype=torch.float32)
return vis_landmarks
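As a standalone reference for the Gaussian filtering helpers reformatted above, the following sketch rebuilds the same separable kernel; the expected 1-D values come from the get_gaussian_kernel docstring example:

```python
import torch

def gaussian_1d(window_size, sigma):
    # Normalized 1-D Gaussian window, mirroring gaussian()/get_gaussian_kernel above.
    x = torch.arange(window_size, dtype=torch.float32)
    gauss = torch.exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2))
    return gauss / gauss.sum()

k1d = gaussian_1d(5, 1.5)   # ~[0.1201, 0.2339, 0.2921, 0.2339, 0.1201]
# get_gaussian_kernel2d builds the 2-D kernel as an outer product of two 1-D windows.
k2d = torch.matmul(k1d.unsqueeze(-1), k1d.unsqueeze(-1).t())
print(k2d.shape)            # torch.Size([5, 5])
```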
diff --git a/lib/pymafx/core/cfgs.py b/lib/pymafx/core/cfgs.py
index 580643a7cb2ad7caaec29223b372669b17992926..c970c6c0caafe7a4c2f3abbb311adcd0cef42b94 100644
--- a/lib/pymafx/core/cfgs.py
+++ b/lib/pymafx/core/cfgs.py
@@ -67,6 +67,7 @@ def get_cfg_defaults():
# return cfg.clone()
return cfg
+
def update_cfg(cfg_file):
# cfg = get_cfg_defaults()
cfg.merge_from_file(cfg_file)
@@ -86,6 +87,7 @@ def parse_args(args):
return cfg
+
def parse_args_extend(args):
if args.resume:
if not os.path.exists(args.log_dir):
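update_cfg and parse_args above follow the usual yacs CfgNode pattern: defaults are declared once and then overridden from a file or a key/value list. A minimal sketch with hypothetical node names (not ECON's actual schema):

```python
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TRAIN = CN()
cfg.TRAIN.BATCH_SIZE = 32                        # hypothetical default

# cfg.merge_from_file("path/to/overrides.yaml")  # what update_cfg() does with a YAML file
cfg.merge_from_list(["TRAIN.BATCH_SIZE", 64])    # command-line style override
cfg.freeze()
print(cfg.TRAIN.BATCH_SIZE)                      # 64
```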
diff --git a/lib/pymafx/core/constants.py b/lib/pymafx/core/constants.py
index 47fd25e31d1ebfb11cd19a2fa8d8c4e61c79cc5d..5354a289f892a764a16221b469fc49794ff54127 100644
--- a/lib/pymafx/core/constants.py
+++ b/lib/pymafx/core/constants.py
@@ -43,23 +43,23 @@ SPIN_JOINT_NAMES = [
# 24 Ground Truth joints (superset of joints from different datasets)
'Right Ankle',
'Right Knee',
- 'Right Hip', # 2
+ 'Right Hip', # 2
'Left Hip',
- 'Left Knee', # 4
+ 'Left Knee', # 4
'Left Ankle',
- 'Right Wrist', # 6
+ 'Right Wrist', # 6
'Right Elbow',
- 'Right Shoulder', # 8
+ 'Right Shoulder', # 8
'Left Shoulder',
- 'Left Elbow', # 10
+ 'Left Elbow', # 10
'Left Wrist',
- 'Neck (LSP)', # 12
+ 'Neck (LSP)', # 12
'Top of Head (LSP)',
- 'Pelvis (MPII)', # 14
+ 'Pelvis (MPII)', # 14
'Thorax (MPII)',
- 'Spine (H36M)', # 16
+ 'Spine (H36M)', # 16
'Jaw (H36M)',
- 'Head (H36M)', # 18
+ 'Head (H36M)', # 18
'Nose',
'Left Eye',
'Right Eye',
@@ -278,8 +278,8 @@ FACIAL_LANDMARKS = [
'left_mouth_3',
'left_mouth_2',
'left_mouth_1',
- 'left_mouth_5', # 59 in OpenPose output
- 'left_mouth_4', # 58 in OpenPose output
+ 'left_mouth_5', # 59 in OpenPose output
+ 'left_mouth_4', # 58 in OpenPose output
'mouth_bottom',
'right_mouth_4',
'right_mouth_5',
diff --git a/lib/pymafx/models/attention.py b/lib/pymafx/models/attention.py
index 21bf6d10d907546e5462bd896e8d8bb819e41b24..b0f7d3c5c63ba1471ff15ee1a3cf0d8c94a17699 100644
--- a/lib/pymafx/models/attention.py
+++ b/lib/pymafx/models/attention.py
@@ -16,6 +16,7 @@ from .transformers.bert.modeling_bert import BertPreTrainedModel, BertEmbeddings
# import src.modeling.data.config as cfg
# from src.modeling._gcnn import GraphConvolution, GraphResBlock
from .transformers.bert.modeling_utils import prune_linear_layer
+
LayerNormClass = torch.nn.LayerNorm
BertLayerNorm = torch.nn.LayerNorm
from .transformers.bert import BertConfig
@@ -27,7 +28,8 @@ class BertSelfAttention(nn.Module):
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
@@ -45,8 +47,7 @@ class BertSelfAttention(nn.Module):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
- def forward(self, hidden_states, attention_mask, head_mask=None,
- history_state=None):
+ def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None):
if history_state is not None:
raise
x_states = torch.cat([history_state, hidden_states], dim=1)
@@ -85,12 +86,13 @@ class BertSelfAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
context_layer = context_layer.view(*new_context_layer_shape)
- outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, )
return outputs
+
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
@@ -114,12 +116,10 @@ class BertAttention(nn.Module):
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
- def forward(self, input_tensor, attention_mask, head_mask=None,
- history_state=None):
- self_outputs = self.self(input_tensor, attention_mask, head_mask,
- history_state)
+ def forward(self, input_tensor, attention_mask, head_mask=None, history_state=None):
+ self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state)
attention_output = self.output(self_outputs[0], input_tensor)
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them
return outputs
@@ -131,10 +131,8 @@ class AttLayer(nn.Module):
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
- def MHA(self, hidden_states, attention_mask, head_mask=None,
- history_state=None):
- attention_outputs = self.attention(hidden_states, attention_mask,
- head_mask, history_state)
+ def MHA(self, hidden_states, attention_mask, head_mask=None, history_state=None):
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_state)
attention_output = attention_outputs[0]
# print('attention_output', hidden_states.shape, attention_output.shape)
@@ -143,12 +141,11 @@ class AttLayer(nn.Module):
# print('intermediate_output', intermediate_output.shape)
layer_output = self.output(intermediate_output, attention_output)
# print('layer_output', layer_output.shape)
- outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+ outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them
return outputs
- def forward(self, hidden_states, attention_mask, head_mask=None,
- history_state=None):
- return self.MHA(hidden_states, attention_mask, head_mask,history_state)
+ def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None):
+ return self.MHA(hidden_states, attention_mask, head_mask, history_state)
class AttEncoder(nn.Module):
@@ -158,34 +155,32 @@ class AttEncoder(nn.Module):
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([AttLayer(config) for _ in range(config.num_hidden_layers)])
- def forward(self, hidden_states, attention_mask, head_mask=None,
- encoder_history_states=None):
+ def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
+ all_hidden_states = all_hidden_states + (hidden_states, )
history_state = None if encoder_history_states is None else encoder_history_states[i]
- layer_outputs = layer_module(
- hidden_states, attention_mask, head_mask[i],
- history_state)
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], history_state)
hidden_states = layer_outputs[0]
if self.output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
+ all_attentions = all_attentions + (layer_outputs[1], )
# Add last layer
if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
+ all_hidden_states = all_hidden_states + (hidden_states, )
- outputs = (hidden_states,)
+ outputs = (hidden_states, )
if self.output_hidden_states:
- outputs = outputs + (all_hidden_states,)
+ outputs = outputs + (all_hidden_states, )
if self.output_attentions:
- outputs = outputs + (all_attentions,)
+ outputs = outputs + (all_attentions, )
+
+ return outputs # outputs, (hidden states), (attentions)
- return outputs # outputs, (hidden states), (attentions)
class EncoderBlock(BertPreTrainedModel):
def __init__(self, config):
@@ -195,7 +190,7 @@ class EncoderBlock(BertPreTrainedModel):
self.encoder = AttEncoder(config)
# self.pooler = BertPooler(config)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
- self.img_dim = config.img_feature_dim
+ self.img_dim = config.img_feature_dim
try:
self.use_img_layernorm = config.use_img_layernorm
@@ -217,26 +212,32 @@ class EncoderBlock(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
- def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None,
- position_ids=None, head_mask=None):
+ def forward(
+ self,
+ img_feats,
+ input_ids=None,
+ token_type_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None
+ ):
batch_size = len(img_feats)
seq_length = len(img_feats[0])
- input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(img_feats.device)
+ input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long).to(img_feats.device)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# print('-------------------')
# print('position_ids', seq_length, position_ids.shape)
- # 494 torch.Size([2, 494])
+ # 494 torch.Size([2, 494])
position_embeddings = self.position_embeddings(position_ids)
# print('position_embeddings', position_embeddings.shape, self.config.max_position_embeddings, self.config.hidden_size)
- # torch.Size([2, 494, 1024]) 512 1024
+ # torch.Size([2, 494, 1024]) 512 1024
# torch.Size([2, 494, 256]) 512 256
-
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
else:
@@ -255,7 +256,9 @@ class EncoderBlock(BertPreTrainedModel):
raise NotImplementedError
# extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
- extended_attention_mask = extended_attention_mask.to(dtype=img_feats.dtype) # fp16 compatibility
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=img_feats.dtype
+ ) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if head_mask is not None:
@@ -264,15 +267,19 @@ class EncoderBlock(BertPreTrainedModel):
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
- head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
- head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
+ -1
+ ) # We can specify head_mask for each layer
+ head_mask = head_mask.to(
+ dtype=next(self.parameters()).dtype
+                ) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
# Project input token features to have spcified hidden size
- # print('img_feats', img_feats.shape) # torch.Size([2, 494, 2051])
+ # print('img_feats', img_feats.shape) # torch.Size([2, 494, 2051])
img_embedding_output = self.img_embedding(img_feats)
- # print('img_embedding_output', img_embedding_output.shape) # torch.Size([2, 494, 1024])
+ # print('img_embedding_output', img_embedding_output.shape) # torch.Size([2, 494, 1024])
# We empirically observe that adding an additional learnable position embedding leads to more stable training
embeddings = position_embeddings + img_embedding_output
@@ -282,21 +289,27 @@ class EncoderBlock(BertPreTrainedModel):
# embeddings = self.dropout(embeddings)
# print('extended_attention_mask', extended_attention_mask.shape) # torch.Size([2, 1, 1, 494])
- encoder_outputs = self.encoder(embeddings,
- extended_attention_mask, head_mask=head_mask)
+ encoder_outputs = self.encoder(embeddings, extended_attention_mask, head_mask=head_mask)
sequence_output = encoder_outputs[0]
- outputs = (sequence_output,)
+ outputs = (sequence_output, )
if self.config.output_hidden_states:
all_hidden_states = encoder_outputs[1]
- outputs = outputs + (all_hidden_states,)
+ outputs = outputs + (all_hidden_states, )
if self.config.output_attentions:
all_attentions = encoder_outputs[-1]
- outputs = outputs + (all_attentions,)
+ outputs = outputs + (all_attentions, )
return outputs
-def get_att_block(img_feature_dim=2048, output_feat_dim=512, hidden_feat_dim=1024, num_attention_heads=4, num_hidden_layers=1):
+
+def get_att_block(
+ img_feature_dim=2048,
+ output_feat_dim=512,
+ hidden_feat_dim=1024,
+ num_attention_heads=4,
+ num_hidden_layers=1
+):
config_class = BertConfig
config = config_class.from_pretrained('lib/pymafx/models/transformers/bert/bert-base-uncased/')
@@ -316,7 +329,7 @@ def get_att_block(img_feature_dim=2048, output_feat_dim=512, hidden_feat_dim=102
# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
- att_model = EncoderBlock(config=config)
+ att_model = EncoderBlock(config=config)
return att_model
@@ -333,16 +346,31 @@ class Graphormer(BertPreTrainedModel):
self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim)
self.apply(self.init_weights)
- def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
- next_sentence_label=None, position_ids=None, head_mask=None):
+ def forward(
+ self,
+ img_feats,
+ input_ids=None,
+ token_type_ids=None,
+ attention_mask=None,
+ masked_lm_labels=None,
+ next_sentence_label=None,
+ position_ids=None,
+ head_mask=None
+ ):
'''
# self.bert has three outputs
# predictions[0]: output tokens
# predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states"
# predictions[2]: attentions, if enable "self.config.output_attentions"
'''
- predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ predictions = self.bert(
+ img_feats=img_feats,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
# We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification.
pred_score = self.cls_head(predictions[0])
@@ -354,5 +382,3 @@ class Graphormer(BertPreTrainedModel):
return pred_score, predictions[1], predictions[-1]
else:
return pred_score
-
-
\ No newline at end of file
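For orientation, a short sketch of the head split that BertSelfAttention's divisibility check guards; the shapes (batch 2, sequence 494, hidden size 1024, 4 heads) are taken from the debug comments and defaults visible above:

```python
import torch

# hidden_size must be divisible by num_attention_heads so each head
# receives an equal slice of the feature dimension.
hidden_size, num_heads = 1024, 4
assert hidden_size % num_heads == 0
head_size = hidden_size // num_heads                            # 256

x = torch.randn(2, 494, hidden_size)                            # (batch, seq, hidden)
x = x.view(2, 494, num_heads, head_size).permute(0, 2, 1, 3)    # (batch, heads, seq, head_size)
print(x.shape)                                                  # torch.Size([2, 4, 494, 256])
```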
diff --git a/lib/pymafx/models/hmr.py b/lib/pymafx/models/hmr.py
index f91f4a8311b940afca6155d2f31487c6e77fa5ad..da5459d355d3a3f00c53638a376ab3143b23c01e 100755
--- a/lib/pymafx/models/hmr.py
+++ b/lib/pymafx/models/hmr.py
@@ -8,10 +8,12 @@ import math
from lib.net.geometry import rot6d_to_rotmat
import logging
+
logger = logging.getLogger(__name__)
BN_MOMENTUM = 0.1
+
class Bottleneck(nn.Module):
""" Redefinition of Bottleneck residual block
Adapted from the official PyTorch implementation
@@ -22,8 +24,7 @@ class Bottleneck(nn.Module):
super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
@@ -57,18 +58,16 @@ class Bottleneck(nn.Module):
class ResNet_Backbone(nn.Module):
""" Feature Extrator with ResNet backbone
"""
-
def __init__(self, model='res50', pretrained=True):
if model == 'res50':
block, layers = Bottleneck, [3, 4, 6, 3]
else:
- pass # TODO
+ pass # TODO
self.inplanes = 64
super().__init__()
npose = 24 * 6
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
- bias=False)
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -87,8 +86,13 @@ class ResNet_Backbone(nn.Module):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False
+ ),
nn.BatchNorm2d(planes * block.expansion),
)
@@ -105,7 +109,7 @@ class ResNet_Backbone(nn.Module):
'ERROR: num_deconv_layers is different len(num_deconv_filters)'
assert num_layers == len(num_kernels), \
'ERROR: num_deconv_layers is different len(num_deconv_filters)'
-
+
def _get_deconv_cfg(deconv_kernel, index):
if deconv_kernel == 4:
padding = 1
@@ -132,7 +136,9 @@ class ResNet_Backbone(nn.Module):
stride=2,
padding=padding,
output_padding=output_padding,
- bias=self.deconv_with_bias))
+ bias=self.deconv_with_bias
+ )
+ )
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
self.inplanes = planes
@@ -164,13 +170,11 @@ class ResNet_Backbone(nn.Module):
class HMR(nn.Module):
""" SMPL Iterative Regressor with ResNet50 backbone
"""
-
def __init__(self, block, layers, smpl_mean_params):
self.inplanes = 64
super().__init__()
npose = 24 * 6
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
- bias=False)
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -206,13 +210,17 @@ class HMR(nn.Module):
self.register_buffer('init_shape', init_shape)
self.register_buffer('init_cam', init_cam)
-
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False
+ ),
nn.BatchNorm2d(planes * block.expansion),
)
@@ -224,7 +232,6 @@ class HMR(nn.Module):
return nn.Sequential(*layers)
-
def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
batch_size = x.shape[0]
@@ -253,7 +260,7 @@ class HMR(nn.Module):
pred_shape = init_shape
pred_cam = init_cam
for i in range(n_iter):
- xc = torch.cat([xf, pred_pose, pred_shape, pred_cam],1)
+ xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
xc = self.fc1(xc)
xc = self.drop1(xc)
xc = self.fc2(xc)
@@ -266,13 +273,14 @@ class HMR(nn.Module):
return pred_rotmat, pred_shape, pred_cam
+
def hmr(smpl_mean_params, pretrained=True, **kwargs):
""" Constructs an HMR model with ResNet50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
- model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
+ model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
if pretrained:
resnet_imagenet = resnet.resnet50(pretrained=True)
- model.load_state_dict(resnet_imagenet.state_dict(),strict=False)
- return model
\ No newline at end of file
+ model.load_state_dict(resnet_imagenet.state_dict(), strict=False)
+ return model
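The iterative regressor in HMR concatenates the pooled image feature with the current pose, shape and camera estimates on every iteration. A sketch of the tensor sizes involved (the 10 betas and 3 camera parameters are assumed from the standard SMPL/HMR setup, not shown in this hunk):

```python
import torch

batch = 2
xf = torch.randn(batch, 2048)            # pooled ResNet-50 feature
pred_pose = torch.randn(batch, 24 * 6)   # 6-D rotation per SMPL joint (npose = 144)
pred_shape = torch.randn(batch, 10)      # SMPL betas (assumed 10, as in standard HMR)
pred_cam = torch.randn(batch, 3)         # weak-perspective camera (assumed 3-D)
xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
print(xc.shape)                          # torch.Size([2, 2205])
```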
diff --git a/lib/pymafx/models/hr_module.py b/lib/pymafx/models/hr_module.py
index 285cd2c56728e439fdcd1a8bccbafca3f5549ef3..7396f1ea59860235db8fdd24434114381c4a7083 100644
--- a/lib/pymafx/models/hr_module.py
+++ b/lib/pymafx/models/hr_module.py
@@ -7,16 +7,25 @@ import torch.nn.functional as F
from .res_module import BasicBlock, Bottleneck
import logging
+
logger = logging.getLogger(__name__)
BN_MOMENTUM = 0.1
+
class HighResolutionModule(nn.Module):
- def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
- num_channels, fuse_method, multi_scale_output=True):
+ def __init__(
+ self,
+ num_branches,
+ blocks,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ fuse_method,
+ multi_scale_output=True
+ ):
super().__init__()
- self._check_branches(
- num_branches, blocks, num_blocks, num_inchannels, num_channels)
+ self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
@@ -24,33 +33,31 @@ class HighResolutionModule(nn.Module):
self.multi_scale_output = multi_scale_output
- self.branches = self._make_branches(
- num_branches, blocks, num_blocks, num_channels)
+ self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(True)
- def _check_branches(self, num_branches, blocks, num_blocks,
- num_inchannels, num_channels):
+ def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
if num_branches != len(num_blocks):
- error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
- num_branches, len(num_blocks))
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
- num_branches, len(num_channels))
+ num_branches, len(num_channels)
+ )
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
- num_branches, len(num_inchannels))
+ num_branches, len(num_inchannels)
+ )
logger.error(error_msg)
raise ValueError(error_msg)
- def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
- stride=1):
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
downsample = None
if stride != 1 or \
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
@@ -58,32 +65,23 @@ class HighResolutionModule(nn.Module):
nn.Conv2d(
self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
- kernel_size=1, stride=stride, bias=False
- ),
- nn.BatchNorm2d(
- num_channels[branch_index] * block.expansion,
- momentum=BN_MOMENTUM
+ kernel_size=1,
+ stride=stride,
+ bias=False
),
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
layers.append(
block(
- self.num_inchannels[branch_index],
- num_channels[branch_index],
- stride,
- downsample
+ self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample
)
)
self.num_inchannels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
- layers.append(
- block(
- self.num_inchannels[branch_index],
- num_channels[branch_index]
- )
- )
+ layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
return nn.Sequential(*layers)
@@ -91,9 +89,7 @@ class HighResolutionModule(nn.Module):
branches = []
for i in range(num_branches):
- branches.append(
- self._make_one_branch(i, block, num_blocks, num_channels)
- )
+ branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
@@ -110,20 +106,16 @@ class HighResolutionModule(nn.Module):
if j > i:
fuse_layer.append(
nn.Sequential(
- nn.Conv2d(
- num_inchannels[j],
- num_inchannels[i],
- 1, 1, 0, bias=False
- ),
+ nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
nn.BatchNorm2d(num_inchannels[i]),
- nn.Upsample(scale_factor=2**(j-i), mode='nearest')
+ nn.Upsample(scale_factor=2**(j - i), mode='nearest')
)
)
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
- for k in range(i-j):
+ for k in range(i - j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(
@@ -131,9 +123,11 @@ class HighResolutionModule(nn.Module):
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
- 3, 2, 1, bias=False
- ),
- nn.BatchNorm2d(num_outchannels_conv3x3)
+ 3,
+ 2,
+ 1,
+ bias=False
+ ), nn.BatchNorm2d(num_outchannels_conv3x3)
)
)
else:
@@ -143,10 +137,11 @@ class HighResolutionModule(nn.Module):
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
- 3, 2, 1, bias=False
- ),
- nn.BatchNorm2d(num_outchannels_conv3x3),
- nn.ReLU(True)
+ 3,
+ 2,
+ 1,
+ bias=False
+ ), nn.BatchNorm2d(num_outchannels_conv3x3), nn.ReLU(True)
)
)
fuse_layer.append(nn.Sequential(*conv3x3s))
@@ -178,25 +173,19 @@ class HighResolutionModule(nn.Module):
return x_fuse
-blocks_dict = {
- 'BASIC': BasicBlock,
- 'BOTTLENECK': Bottleneck
-}
+blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
class PoseHighResolutionNet(nn.Module):
-
def __init__(self, cfg, pretrained=True, global_mode=False):
self.inplanes = 64
extra = cfg.HR_MODEL.EXTRA
super().__init__()
# stem net
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
- bias=False)
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
- self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
- bias=False)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(Bottleneck, self.inplanes, 64, 4)
@@ -204,34 +193,25 @@ class PoseHighResolutionNet(nn.Module):
self.stage2_cfg = cfg['HR_MODEL']['EXTRA']['STAGE2']
num_channels = self.stage2_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage2_cfg['BLOCK']]
- num_channels = [
- num_channels[i] * block.expansion for i in range(len(num_channels))
- ]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition1 = self._make_transition_layer([256], num_channels)
- self.stage2, pre_stage_channels = self._make_stage(
- self.stage2_cfg, num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
self.stage3_cfg = cfg['HR_MODEL']['EXTRA']['STAGE3']
num_channels = self.stage3_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage3_cfg['BLOCK']]
- num_channels = [
- num_channels[i] * block.expansion for i in range(len(num_channels))
- ]
- self.transition2 = self._make_transition_layer(
- pre_stage_channels, num_channels)
- self.stage3, pre_stage_channels = self._make_stage(
- self.stage3_cfg, num_channels)
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
self.stage4_cfg = cfg['HR_MODEL']['EXTRA']['STAGE4']
num_channels = self.stage4_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage4_cfg['BLOCK']]
- num_channels = [
- num_channels[i] * block.expansion for i in range(len(num_channels))
- ]
- self.transition3 = self._make_transition_layer(
- pre_stage_channels, num_channels)
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
self.stage4, pre_stage_channels = self._make_stage(
- self.stage4_cfg, num_channels, multi_scale_output=True)
+ self.stage4_cfg, num_channels, multi_scale_output=True
+ )
# Classification Head
self.global_mode = global_mode
@@ -249,11 +229,7 @@ class PoseHighResolutionNet(nn.Module):
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
incre_modules = []
for i, channels in enumerate(pre_stage_channels):
- incre_module = self._make_layer(head_block,
- channels,
- head_channels[i],
- 1,
- stride=1)
+ incre_module = self._make_layer(head_block, channels, head_channels[i], 1, stride=1)
incre_modules.append(incre_module)
incre_modules = nn.ModuleList(incre_modules)
@@ -264,13 +240,13 @@ class PoseHighResolutionNet(nn.Module):
out_channels = head_channels[i + 1] * head_block.expansion
downsamp_module = nn.Sequential(
- nn.Conv2d(in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=3,
- stride=2,
- padding=1),
- nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
- nn.ReLU(inplace=True)
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1
+ ), nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), nn.ReLU(inplace=True)
)
downsamp_modules.append(downsamp_module)
@@ -283,15 +259,12 @@ class PoseHighResolutionNet(nn.Module):
kernel_size=1,
stride=1,
padding=0
- ),
- nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
- nn.ReLU(inplace=True)
+ ), nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), nn.ReLU(inplace=True)
)
return incre_modules, downsamp_modules, final_layer
- def _make_transition_layer(
- self, num_channels_pre_layer, num_channels_cur_layer):
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
@@ -304,27 +277,25 @@ class PoseHighResolutionNet(nn.Module):
nn.Conv2d(
num_channels_pre_layer[i],
num_channels_cur_layer[i],
- 3, 1, 1, bias=False
- ),
- nn.BatchNorm2d(num_channels_cur_layer[i]),
- nn.ReLU(inplace=True)
+ 3,
+ 1,
+ 1,
+ bias=False
+ ), nn.BatchNorm2d(num_channels_cur_layer[i]), nn.ReLU(inplace=True)
)
)
else:
transition_layers.append(None)
else:
conv3x3s = []
- for j in range(i+1-num_branches_pre):
+ for j in range(i + 1 - num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = num_channels_cur_layer[i] \
if j == i-num_branches_pre else inchannels
conv3x3s.append(
nn.Sequential(
- nn.Conv2d(
- inchannels, outchannels, 3, 2, 1, bias=False
- ),
- nn.BatchNorm2d(outchannels),
- nn.ReLU(inplace=True)
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(outchannels), nn.ReLU(inplace=True)
)
)
transition_layers.append(nn.Sequential(*conv3x3s))
@@ -336,8 +307,7 @@ class PoseHighResolutionNet(nn.Module):
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
- inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False
+ inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
@@ -350,8 +320,7 @@ class PoseHighResolutionNet(nn.Module):
return nn.Sequential(*layers)
- def _make_stage(self, layer_config, num_inchannels,
- multi_scale_output=True):
+ def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
num_modules = layer_config['NUM_MODULES']
num_branches = layer_config['NUM_BRANCHES']
num_blocks = layer_config['NUM_BLOCKS']
@@ -369,12 +338,7 @@ class PoseHighResolutionNet(nn.Module):
modules.append(
HighResolutionModule(
- num_branches,
- block,
- num_blocks,
- num_inchannels,
- num_channels,
- fuse_method,
+ num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method,
reset_multi_scale_output
)
)
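A small sketch of the branch-fusion pattern that HighResolutionModule builds above: a lower-resolution branch is channel-matched with a 1x1 convolution and upsampled by 2**(j - i) before being summed into a higher-resolution branch (channel and spatial sizes here are hypothetical):

```python
import torch
import torch.nn as nn

i, j = 0, 2            # target branch i, source branch j (j > i means lower resolution)
c_i, c_j = 32, 128     # hypothetical channel counts
fuse = nn.Sequential(
    nn.Conv2d(c_j, c_i, 1, 1, 0, bias=False),                 # match channel count
    nn.BatchNorm2d(c_i),
    nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'),   # match spatial resolution
)
x_i = torch.randn(1, c_i, 64, 64)   # high-resolution branch
x_j = torch.randn(1, c_j, 16, 16)   # low-resolution branch
y = x_i + fuse(x_j)                 # fused feature map, (1, 32, 64, 64)
print(y.shape)
```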
diff --git a/lib/pymafx/models/maf_extractor.py b/lib/pymafx/models/maf_extractor.py
index 1d1af8dac7a4bf9c157f9396ffe32f7811ec50ad..34237bc55663dcbcbd67beb4c5d0b6e693aae266 100644
--- a/lib/pymafx/models/maf_extractor.py
+++ b/lib/pymafx/models/maf_extractor.py
@@ -10,6 +10,7 @@ from lib.pymafx.core import path_config
from lib.pymafx.utils.geometry import projection
import logging
+
logger = logging.getLogger(__name__)
from .transformers.net_utils import PosEnSine
@@ -19,7 +20,9 @@ from lib.pymafx.utils.imutils import j2d_processing
class TransformerDecoderUnit(nn.Module):
- def __init__(self, feat_dim, attri_dim=0, n_head=8, pos_en_flag=True, attn_type='softmax', P=None):
+ def __init__(
+ self, feat_dim, attri_dim=0, n_head=8, pos_en_flag=True, attn_type='softmax', P=None
+ ):
super(TransformerDecoderUnit, self).__init__()
self.feat_dim = feat_dim
self.attn_type = attn_type
@@ -32,7 +35,9 @@ class TransformerDecoderUnit(nn.Module):
self.pos_en = PosEnSine(pe_dim)
else:
pe_dim = 0
- self.attn = OurMultiheadAttention(feat_dim+attri_dim+pe_dim*3, feat_dim+pe_dim*3, feat_dim, n_head) # cross-attention
+ self.attn = OurMultiheadAttention(
+ feat_dim + attri_dim + pe_dim * 3, feat_dim + pe_dim * 3, feat_dim, n_head
+ ) # cross-attention
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
@@ -50,7 +55,7 @@ class TransformerDecoderUnit(nn.Module):
# else:
# q_pos_embed = 0
# k_pos_embed = 0
-
+
# cross-multi-head attention
out = self.attn(q=q, k=k, v=v, attn_type=self.attn_type, P=self.P)[0]
@@ -65,25 +70,28 @@ class TransformerDecoderUnit(nn.Module):
class Mesh_Sampler(nn.Module):
''' Mesh Up/Down-sampling
'''
-
def __init__(self, type='smpl', level=2, device=torch.device('cuda'), option=None):
super().__init__()
# downsample SMPL mesh and assign part labels
if type == 'smpl':
# from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz
- smpl_mesh_graph = np.load(path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1')
+ smpl_mesh_graph = np.load(
+ path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1'
+ )
A = smpl_mesh_graph['A']
U = smpl_mesh_graph['U']
- D = smpl_mesh_graph['D'] # shape: (2,)
+ D = smpl_mesh_graph['D'] # shape: (2,)
elif type == 'mano':
# from https://github.com/microsoft/MeshGraphormer/blob/main/src/modeling/data/mano_downsampling.npz
- mano_mesh_graph = np.load(path_config.MANO_DOWNSAMPLING, allow_pickle=True, encoding='latin1')
+ mano_mesh_graph = np.load(
+ path_config.MANO_DOWNSAMPLING, allow_pickle=True, encoding='latin1'
+ )
A = mano_mesh_graph['A']
U = mano_mesh_graph['U']
- D = mano_mesh_graph['D'] # shape: (2,)
+ D = mano_mesh_graph['D'] # shape: (2,)
# downsampling
ptD = []
@@ -92,14 +100,14 @@ class Mesh_Sampler(nn.Module):
i = torch.LongTensor(np.array([d.row, d.col]))
v = torch.FloatTensor(d.data)
ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
-
+
# downsampling mapping from 6890 points to 431 points
# ptD[0].to_dense() - Size: [1723, 6890] , [195, 778]
# ptD[1].to_dense() - Size: [431, 1723] , [49, 195]
if level == 2:
- Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431
+ Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431
elif level == 1:
- Dmap = ptD[0].to_dense() #
+ Dmap = ptD[0].to_dense() #
self.register_buffer('Dmap', Dmap)
# upsampling
@@ -109,21 +117,21 @@ class Mesh_Sampler(nn.Module):
i = torch.LongTensor(np.array([d.row, d.col]))
v = torch.FloatTensor(d.data)
ptU.append(torch.sparse.FloatTensor(i, v, d.shape))
-
+
# upsampling mapping from 431 points to 6890 points
# ptU[0].to_dense() - Size: [6890, 1723]
# ptU[1].to_dense() - Size: [1723, 431]
if level == 2:
- Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890
+ Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890
elif level == 1:
- Umap = ptU[0].to_dense() #
+ Umap = ptU[0].to_dense() #
self.register_buffer('Umap', Umap)
def downsample(self, x):
- return torch.matmul(self.Dmap.unsqueeze(0), x) # [B, 431, 3]
-
+ return torch.matmul(self.Dmap.unsqueeze(0), x) # [B, 431, 3]
+
def upsample(self, x):
- return torch.matmul(self.Umap.unsqueeze(0), x) # [B, 6890, 3]
+ return torch.matmul(self.Umap.unsqueeze(0), x) # [B, 6890, 3]
def forward(self, x, mode='downsample'):
if mode == 'downsample':
@@ -137,8 +145,9 @@ class MAF_Extractor(nn.Module):
As discussed in the paper, we extract mesh-aligned features based on 2D projection of the mesh vertices.
The features extrated from spatial feature maps will go through a MLP for dimension reduction.
'''
-
- def __init__(self, filter_channels, device=torch.device('cuda'), iwp_cam_mode=True, option=None):
+ def __init__(
+ self, filter_channels, device=torch.device('cuda'), iwp_cam_mode=True, option=None
+ ):
super().__init__()
self.device = device
@@ -151,25 +160,22 @@ class MAF_Extractor(nn.Module):
for l in range(0, len(filter_channels) - 1):
if 0 != l:
self.filters.append(
- nn.Conv1d(
- filter_channels[l] + filter_channels[0],
- filter_channels[l + 1],
- 1))
+ nn.Conv1d(filter_channels[l] + filter_channels[0], filter_channels[l + 1], 1)
+ )
else:
- self.filters.append(nn.Conv1d(
- filter_channels[l],
- filter_channels[l + 1],
- 1))
+ self.filters.append(nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))
self.add_module("conv%d" % l, self.filters[l])
# downsample SMPL mesh and assign part labels
# from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz
- smpl_mesh_graph = np.load(path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1')
+ smpl_mesh_graph = np.load(
+ path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1'
+ )
A = smpl_mesh_graph['A']
U = smpl_mesh_graph['U']
- D = smpl_mesh_graph['D'] # shape: (2,)
+ D = smpl_mesh_graph['D'] # shape: (2,)
# downsampling
ptD = []
@@ -178,11 +184,11 @@ class MAF_Extractor(nn.Module):
i = torch.LongTensor(np.array([d.row, d.col]))
v = torch.FloatTensor(d.data)
ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
-
+
# downsampling mapping from 6890 points to 431 points
# ptD[0].to_dense() - Size: [1723, 6890]
# ptD[1].to_dense() - Size: [431. 1723]
- Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431
+ Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431
self.register_buffer('Dmap', Dmap)
# upsampling
@@ -192,14 +198,13 @@ class MAF_Extractor(nn.Module):
i = torch.LongTensor(np.array([d.row, d.col]))
v = torch.FloatTensor(d.data)
ptU.append(torch.sparse.FloatTensor(i, v, d.shape))
-
+
# upsampling mapping from 431 points to 6890 points
# ptU[0].to_dense() - Size: [6890, 1723]
# ptU[1].to_dense() - Size: [1723, 431]
- Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890
+ Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890
self.register_buffer('Umap', Umap)
-
def reduce_dim(self, feature):
'''
Dimension reduction by multi-layer perceptrons
@@ -209,19 +214,13 @@ class MAF_Extractor(nn.Module):
y = feature
tmpy = feature
for i, f in enumerate(self.filters):
- y = self._modules['conv' + str(i)](
- y if i == 0
- else torch.cat([y, tmpy], 1)
- )
+ y = self._modules['conv' + str(i)](y if i == 0 else torch.cat([y, tmpy], 1))
if i != len(self.filters) - 1:
y = F.leaky_relu(y)
if self.num_views > 1 and i == len(self.filters) // 2:
- y = y.view(
- -1, self.num_views, y.shape[1], y.shape[2]
- ).mean(dim=1)
- tmpy = feature.view(
- -1, self.num_views, feature.shape[1], feature.shape[2]
- ).mean(dim=1)
+ y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1)
+ tmpy = feature.view(-1, self.num_views, feature.shape[1],
+ feature.shape[2]).mean(dim=1)
y = self.last_op(y)
@@ -242,7 +241,9 @@ class MAF_Extractor(nn.Module):
# im_feat = self.im_feat
batch_size = im_feat.shape[0]
- point_feat = torch.nn.functional.grid_sample(im_feat, points.unsqueeze(2), align_corners=False)[..., 0]
+ point_feat = torch.nn.functional.grid_sample(
+ im_feat, points.unsqueeze(2), align_corners=False
+ )[..., 0]
if reduce_dim:
mesh_align_feat = self.reduce_dim(point_feat)
@@ -266,6 +267,6 @@ class MAF_Extractor(nn.Module):
# Normalize keypoints to [-1,1]
p_proj_2d = p_proj_2d / (224. / 2.)
else:
- p_proj_2d = j2d_processing(p_proj_2d, cam['kps_transf'])
+ p_proj_2d = j2d_processing(p_proj_2d, cam['kps_transf'])
mesh_align_feat = self.sampling(p_proj_2d, im_feat, add_att=add_att, reduce_dim=reduce_dim)
return mesh_align_feat
diff --git a/lib/pymafx/models/pose_resnet.py b/lib/pymafx/models/pose_resnet.py
index e9a2f6716c002b2fd9645d1877081b4177730049..d97b6609cf02fd2a94d2951f82f71de2be2356c0 100644
--- a/lib/pymafx/models/pose_resnet.py
+++ b/lib/pymafx/models/pose_resnet.py
@@ -14,17 +14,13 @@ import logging
import torch
import torch.nn as nn
-
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
- return nn.Conv2d(
- in_planes, out_planes, kernel_size=3, stride=stride,
- padding=1, bias=False
- )
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
@@ -66,13 +62,10 @@ class Bottleneck(nn.Module):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
- bias=False)
- self.bn3 = nn.BatchNorm2d(planes * self.expansion,
- momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
@@ -101,7 +94,6 @@ class Bottleneck(nn.Module):
class PoseResNet(nn.Module):
-
def __init__(self, block, layers, cfg, global_mode, **kwargs):
self.inplanes = 64
extra = cfg.POSE_RES_MODEL.EXTRA
@@ -109,8 +101,7 @@ class PoseResNet(nn.Module):
self.deconv_with_bias = extra.DECONV_WITH_BIAS
super(PoseResNet, self).__init__()
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
- bias=False)
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -144,8 +135,13 @@ class PoseResNet(nn.Module):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False
+ ),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
@@ -190,7 +186,9 @@ class PoseResNet(nn.Module):
stride=2,
padding=padding,
output_padding=output_padding,
- bias=self.deconv_with_bias))
+ bias=self.deconv_with_bias
+ )
+ )
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
self.inplanes = planes
@@ -218,7 +216,9 @@ class PoseResNet(nn.Module):
else:
g_feat = None
if self.extra.NUM_DECONV_LAYERS == 3:
- deconv_blocks = [self.deconv_layers[0:3], self.deconv_layers[3:6], self.deconv_layers[6:9]]
+ deconv_blocks = [
+ self.deconv_layers[0:3], self.deconv_layers[3:6], self.deconv_layers[6:9]
+ ]
s_feat_list = []
s_feat = x
@@ -284,6 +284,7 @@ resnet_spec = {
152: (Bottleneck, [3, 8, 36, 3])
}
+
def get_resnet_encoder(cfg, init_weight=True, global_mode=False, **kwargs):
num_layers = cfg.POSE_RES_MODEL.EXTRA.NUM_LAYERS
diff --git a/lib/pymafx/models/pymaf_net.py b/lib/pymafx/models/pymaf_net.py
index 5b6f3587e5c470236a0647550b64309058545d1b..ca57e4b1c8ce971d76ce53d02827f441016a19ab 100644
--- a/lib/pymafx/models/pymaf_net.py
+++ b/lib/pymafx/models/pymaf_net.py
@@ -23,15 +23,16 @@ BN_MOMENTUM = 0.1
class Regressor(nn.Module):
-
- def __init__(self,
- feat_dim,
- smpl_mean_params,
- use_cam_feats=False,
- feat_dim_hand=0,
- feat_dim_face=0,
- bhf_names=['body'],
- smpl_models={}):
+ def __init__(
+ self,
+ feat_dim,
+ smpl_mean_params,
+ use_cam_feats=False,
+ feat_dim_hand=0,
+ feat_dim_face=0,
+ bhf_names=['body'],
+ smpl_models={}
+ ):
super().__init__()
npose = 24 * 6
@@ -96,8 +97,9 @@ class Regressor(nn.Module):
rh_cam_dim = 3
rh_orient_dim = 6
rh_shape_dim = 10
- self.fc3_hand = nn.Linear(1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim,
- 1024)
+ self.fc3_hand = nn.Linear(
+ 1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024
+ )
self.drop3_hand = nn.Dropout()
self.decshape_rhand = nn.Linear(1024, 10)
@@ -122,8 +124,9 @@ class Regressor(nn.Module):
rh_cam_dim = 3
rh_orient_dim = 6
rh_shape_dim = 10
- self.fc3_face = nn.Linear(1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim,
- 1024)
+ self.fc3_face = nn.Linear(
+ 1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024
+ )
self.drop3_face = nn.Dropout()
self.decshape_face = nn.Linear(1024, 10)
@@ -167,10 +170,14 @@ class Regressor(nn.Module):
if not self.smpl_mode:
lhand_mean_rot6d = rotmat_to_rot6d(
batch_rodrigues(self.smpl.model.model_neutral.left_hand_mean.view(-1, 3)).view(
- [-1, 3, 3]))
+ [-1, 3, 3]
+ )
+ )
rhand_mean_rot6d = rotmat_to_rot6d(
batch_rodrigues(self.smpl.model.model_neutral.right_hand_mean.view(-1, 3)).view(
- [-1, 3, 3]))
+ [-1, 3, 3]
+ )
+ )
init_lhand = lhand_mean_rot6d.reshape(-1).unsqueeze(0)
init_rhand = rhand_mean_rot6d.reshape(-1).unsqueeze(0)
# init_hand = torch.cat([init_lhand, init_rhand]).unsqueeze(0)
@@ -185,14 +192,16 @@ class Regressor(nn.Module):
self.register_buffer('init_face', init_face)
self.register_buffer('init_exp', init_exp)
- def forward(self,
- x=None,
- n_iter=1,
- J_regressor=None,
- rw_cam={},
- init_mode=False,
- global_iter=-1,
- **kwargs):
+ def forward(
+ self,
+ x=None,
+ n_iter=1,
+ J_regressor=None,
+ rw_cam={},
+ init_mode=False,
+ global_iter=-1,
+ **kwargs
+ ):
if x is not None:
batch_size = x.shape[0]
else:
@@ -215,8 +224,9 @@ class Regressor(nn.Module):
if self.full_body_mode or self.body_hand_mode:
if cfg.MODEL.PyMAF.OPT_WRIST:
- pred_rotmat_body = rot6d_to_rotmat(pred_pose.reshape(
- batch_size, -1, 6)) # .view(batch_size, 24, 3, 3)
+ pred_rotmat_body = rot6d_to_rotmat(
+ pred_pose.reshape(batch_size, -1, 6)
+ ) # .view(batch_size, 24, 3, 3)
if cfg.MODEL.PyMAF.PRED_VIS_H:
pred_vis_hands = None
@@ -291,7 +301,8 @@ class Regressor(nn.Module):
vfov = rw_cam['vfov'][:, None]
crop_ratio = rw_cam['crop_ratio'][:, None]
crop_center = rw_cam['bbox_center'] / torch.cat(
- [rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]], 1)
+ [rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]], 1
+ )
xc = torch.cat([xc, vfov, crop_ratio, crop_center], 1)
xc = self.fc1(xc)
@@ -315,8 +326,8 @@ class Regressor(nn.Module):
xc_lhand = torch.cat([xc_lhand, pred_lhand], 1)
xc_rhand = torch.cat([xc_rhand, pred_rhand], 1)
elif self.full_body_mode:
- xc_lhand, xc_rhand, xc_face = kwargs['xc_lhand'], kwargs[
- 'xc_rhand'], kwargs['xc_face']
+ xc_lhand, xc_rhand, xc_face = kwargs['xc_lhand'], kwargs['xc_rhand'
+ ], kwargs['xc_face']
xc_lhand = torch.cat([xc_lhand, pred_lhand], 1)
xc_rhand = torch.cat([xc_rhand, pred_rhand], 1)
xc_face = torch.cat([xc_face, pred_face, pred_exp], 1)
@@ -328,7 +339,8 @@ class Regressor(nn.Module):
if cfg.MODEL.PyMAF.OPT_WRIST:
xc_lhand = torch.cat(
- [xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh], 1)
+ [xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh], 1
+ )
xc_lhand = self.drop3_hand(self.fc3_hand(xc_lhand))
pred_shape_lh = self.decshape_rhand(xc_lhand) + pred_shape_lh
@@ -342,7 +354,8 @@ class Regressor(nn.Module):
if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST:
xc_rhand = torch.cat(
- [xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh], 1)
+ [xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh], 1
+ )
xc_rhand = self.drop3_hand(self.fc3_hand(xc_rhand))
pred_shape_rh = self.decshape_rhand(xc_rhand) + pred_shape_rh
@@ -351,7 +364,8 @@ class Regressor(nn.Module):
if cfg.MODEL.MESH_MODEL == 'mano':
pred_cam = torch.cat(
- [pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.], dim=1)
+ [pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.], dim=1
+ )
if 'face' in self.part_names:
xc_face = self.drop1_face(self.fc1_face(xc_face))
@@ -361,7 +375,8 @@ class Regressor(nn.Module):
if cfg.MODEL.MESH_MODEL == 'flame':
xc_face = torch.cat(
- [xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa], 1)
+ [xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa], 1
+ )
xc_face = self.drop3_face(self.fc3_face(xc_face))
pred_shape_fa = self.decshape_face(xc_face) + pred_shape_fa
@@ -370,7 +385,8 @@ class Regressor(nn.Module):
if cfg.MODEL.MESH_MODEL == 'flame':
pred_cam = torch.cat(
- [pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.], dim=1)
+ [pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.], dim=1
+ )
if self.full_body_mode or self.body_hand_mode:
if cfg.MODEL.PyMAF.PRED_VIS_H:
@@ -385,22 +401,26 @@ class Regressor(nn.Module):
if cfg.MODEL.PyMAF.OPT_WRIST:
- pred_rotmat_body = rot6d_to_rotmat(pred_pose.reshape(
- batch_size, -1, 6)) # .view(batch_size, 24, 3, 3)
+ pred_rotmat_body = rot6d_to_rotmat(
+ pred_pose.reshape(batch_size, -1, 6)
+ ) # .view(batch_size, 24, 3, 3)
pred_lwrist = pred_rotmat_body[:, 20]
pred_rwrist = pred_rotmat_body[:, 21]
pred_gl_body, body_joints = self.body_model.get_global_rotation(
global_orient=pred_rotmat_body[:, 0:1],
- body_pose=pred_rotmat_body[:, 1:])
+ body_pose=pred_rotmat_body[:, 1:]
+ )
pred_gl_lelbow = pred_gl_body[:, 18]
pred_gl_relbow = pred_gl_body[:, 19]
target_gl_lwrist = rot6d_to_rotmat(
- pred_orient_lh.reshape(batch_size, -1, 6))
+ pred_orient_lh.reshape(batch_size, -1, 6)
+ )
target_gl_lwrist *= self.flip_vector.to(target_gl_lwrist.device)
target_gl_rwrist = rot6d_to_rotmat(
- pred_orient_rh.reshape(batch_size, -1, 6))
+ pred_orient_rh.reshape(batch_size, -1, 6)
+ )
opt_lwrist = torch.bmm(pred_gl_lelbow.transpose(1, 2), target_gl_lwrist)
opt_rwrist = torch.bmm(pred_gl_relbow.transpose(1, 2), target_gl_rwrist)
@@ -408,34 +428,40 @@ class Regressor(nn.Module):
if cfg.MODEL.PyMAF.ADAPT_INTEGR:
# if cfg.MODEL.PyMAF.ADAPT_INTEGR and global_iter == (cfg.MODEL.PyMAF.N_ITER - 1):
tpose_joints = self.smpl.get_tpose(betas=pred_shape)
- lelbow_twist_axis = nn.functional.normalize(tpose_joints[:, 20] -
- tpose_joints[:, 18],
- dim=1)
- relbow_twist_axis = nn.functional.normalize(tpose_joints[:, 21] -
- tpose_joints[:, 19],
- dim=1)
+ lelbow_twist_axis = nn.functional.normalize(
+ tpose_joints[:, 20] - tpose_joints[:, 18], dim=1
+ )
+ relbow_twist_axis = nn.functional.normalize(
+ tpose_joints[:, 21] - tpose_joints[:, 19], dim=1
+ )
lelbow_twist, lelbow_twist_angle = compute_twist_rotation(
- opt_lwrist, lelbow_twist_axis)
+ opt_lwrist, lelbow_twist_axis
+ )
relbow_twist, relbow_twist_angle = compute_twist_rotation(
- opt_rwrist, relbow_twist_axis)
+ opt_rwrist, relbow_twist_axis
+ )
min_angle = -0.4 * float(np.pi)
max_angle = 0.4 * float(np.pi)
- lelbow_twist_angle[lelbow_twist_angle == torch.clamp(
- lelbow_twist_angle, min_angle, max_angle)] = 0
- relbow_twist_angle[relbow_twist_angle == torch.clamp(
- relbow_twist_angle, min_angle, max_angle)] = 0
+ lelbow_twist_angle[lelbow_twist_angle == torch.
+ clamp(lelbow_twist_angle, min_angle, max_angle)
+ ] = 0
+ relbow_twist_angle[relbow_twist_angle == torch.
+ clamp(relbow_twist_angle, min_angle, max_angle)
+ ] = 0
lelbow_twist_angle[lelbow_twist_angle > max_angle] -= max_angle
lelbow_twist_angle[lelbow_twist_angle < min_angle] -= min_angle
relbow_twist_angle[relbow_twist_angle > max_angle] -= max_angle
relbow_twist_angle[relbow_twist_angle < min_angle] -= min_angle
- lelbow_twist = batch_rodrigues(lelbow_twist_axis *
- lelbow_twist_angle)
- relbow_twist = batch_rodrigues(relbow_twist_axis *
- relbow_twist_angle)
+ lelbow_twist = batch_rodrigues(
+ lelbow_twist_axis * lelbow_twist_angle
+ )
+ relbow_twist = batch_rodrigues(
+ relbow_twist_axis * relbow_twist_angle
+ )
opt_lwrist = torch.bmm(lelbow_twist.transpose(1, 2), opt_lwrist)
opt_rwrist = torch.bmm(relbow_twist.transpose(1, 2), opt_rwrist)
@@ -446,7 +472,8 @@ class Regressor(nn.Module):
opt_relbow = torch.bmm(pred_rotmat_body[:, 19], relbow_twist)
if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == (
- cfg.MODEL.PyMAF.N_ITER - 1):
+ cfg.MODEL.PyMAF.N_ITER - 1
+ ):
opt_lwrist_filtered = [
opt_lwrist[_i]
if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20]
@@ -473,16 +500,19 @@ class Regressor(nn.Module):
opt_lelbow = torch.stack(opt_lelbow_filtered)
opt_relbow = torch.stack(opt_relbow_filtered)
- pred_rotmat_body = torch.cat([
- pred_rotmat_body[:, :18],
- opt_lelbow.unsqueeze(1),
- opt_relbow.unsqueeze(1),
- opt_lwrist.unsqueeze(1),
- opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
- ], 1)
+ pred_rotmat_body = torch.cat(
+ [
+ pred_rotmat_body[:, :18],
+ opt_lelbow.unsqueeze(1),
+ opt_relbow.unsqueeze(1),
+ opt_lwrist.unsqueeze(1),
+ opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
+ ], 1
+ )
else:
if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == (
- cfg.MODEL.PyMAF.N_ITER - 1):
+ cfg.MODEL.PyMAF.N_ITER - 1
+ ):
opt_lwrist_filtered = [
opt_lwrist[_i]
if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20]
@@ -497,32 +527,36 @@ class Regressor(nn.Module):
opt_lwrist = torch.stack(opt_lwrist_filtered)
opt_rwrist = torch.stack(opt_rwrist_filtered)
- pred_rotmat_body = torch.cat([
- pred_rotmat_body[:, :20],
- opt_lwrist.unsqueeze(1),
- opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
- ], 1)
+ pred_rotmat_body = torch.cat(
+ [
+ pred_rotmat_body[:, :20],
+ opt_lwrist.unsqueeze(1),
+ opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
+ ], 1
+ )
if self.hand_only_mode:
pred_rotmat_rh = rot6d_to_rotmat(
- torch.cat([pred_orient_rh, pred_rhand],
- dim=1).reshape(batch_size, -1, 6)) # .view(batch_size, 16, 3, 3)
+ torch.cat([pred_orient_rh, pred_rhand], dim=1).reshape(batch_size, -1, 6)
+ ) # .view(batch_size, 16, 3, 3)
assert pred_rotmat_rh.shape[1] == 1 + 15
elif self.face_only_mode:
pred_rotmat_fa = rot6d_to_rotmat(
- torch.cat([pred_orient_fa, pred_face],
- dim=1).reshape(batch_size, -1, 6)) # .view(batch_size, 16, 3, 3)
+ torch.cat([pred_orient_fa, pred_face], dim=1).reshape(batch_size, -1, 6)
+ ) # .view(batch_size, 16, 3, 3)
assert pred_rotmat_fa.shape[1] == 1 + 3
elif self.full_body_mode or self.body_hand_mode:
if cfg.MODEL.PyMAF.OPT_WRIST:
pred_rotmat = pred_rotmat_body
else:
- pred_rotmat = rot6d_to_rotmat(pred_pose.reshape(batch_size, -1,
- 6)) # .view(batch_size, 24, 3, 3)
+ pred_rotmat = rot6d_to_rotmat(
+ pred_pose.reshape(batch_size, -1, 6)
+ ) # .view(batch_size, 24, 3, 3)
assert pred_rotmat.shape[1] == 24
else:
- pred_rotmat = rot6d_to_rotmat(pred_pose.reshape(batch_size, -1,
- 6)) # .view(batch_size, 24, 3, 3)
+ pred_rotmat = rot6d_to_rotmat(
+ pred_pose.reshape(batch_size, -1, 6)
+ ) # .view(batch_size, 24, 3, 3)
assert pred_rotmat.shape[1] == 24
# if self.full_body_mode:
@@ -547,8 +581,8 @@ class Regressor(nn.Module):
assert pred_hfrotmat.shape[1] == (15 * 2 + 3)
# flip left hand pose
- pred_lhand_rotmat = pred_hfrotmat[:, :15] * self.flip_vector.to(
- pred_hfrotmat.device).unsqueeze(0)
+ pred_lhand_rotmat = pred_hfrotmat[:, :15] * self.flip_vector.to(pred_hfrotmat.device
+ ).unsqueeze(0)
pred_rhand_rotmat = pred_hfrotmat[:, 15:30]
pred_face_rotmat = pred_hfrotmat[:, 30:]
@@ -596,17 +630,20 @@ class Regressor(nn.Module):
elif self.face_only_mode:
pred_joints_full = pred_output.face_joints
elif self.smplx_mode:
- pred_joints_full = torch.cat([
- pred_joints, pred_output.lhand_joints, pred_output.rhand_joints,
- pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints
- ],
- dim=1)
+ pred_joints_full = torch.cat(
+ [
+ pred_joints, pred_output.lhand_joints, pred_output.rhand_joints,
+ pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints
+ ],
+ dim=1
+ )
else:
pred_joints_full = pred_joints
- pred_keypoints_2d = projection(pred_joints_full, {
- **rw_cam, 'cam_sxy': pred_cam
- },
- iwp_mode=cfg.MODEL.USE_IWP_CAM)
+ pred_keypoints_2d = projection(
+ pred_joints_full, {
+ **rw_cam, 'cam_sxy': pred_cam
+ }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+ )
if cfg.MODEL.USE_IWP_CAM:
# Normalize keypoints to [-1,1]
pred_keypoints_2d = pred_keypoints_2d / (224. / 2.)
@@ -624,126 +661,137 @@ class Regressor(nn.Module):
else:
kp_3d = pred_joints
pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)
- output.update({
- 'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
- 'verts': pred_vertices,
- 'kp_2d': pred_keypoints_2d[:, :len_b_kp],
- 'kp_3d': kp_3d,
- 'pred_joints': pred_joints,
- 'smpl_kp_3d': pred_output.smpl_joints,
- 'rotmat': pred_rotmat,
- 'pred_cam': pred_cam,
- 'pred_shape': pred_shape,
- 'pred_pose': pred_pose,
- })
+ output.update(
+ {
+ 'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
+ 'verts': pred_vertices,
+ 'kp_2d': pred_keypoints_2d[:, :len_b_kp],
+ 'kp_3d': kp_3d,
+ 'pred_joints': pred_joints,
+ 'smpl_kp_3d': pred_output.smpl_joints,
+ 'rotmat': pred_rotmat,
+ 'pred_cam': pred_cam,
+ 'pred_shape': pred_shape,
+ 'pred_pose': pred_pose,
+ }
+ )
# if self.full_body_mode:
if self.smplx_mode:
# assert pred_keypoints_2d.shape[1] == 144
len_h_kp = len(constants.HAND_NAMES)
len_f_kp = len(constants.FACIAL_LANDMARKS)
len_feet_kp = 2 * len(constants.FOOT_NAMES)
- output.update({
- 'smplx_verts':
- pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None,
- 'pred_lhand':
- pred_lhand,
- 'pred_rhand':
- pred_rhand,
- 'pred_face':
- pred_face,
- 'pred_exp':
- pred_exp,
- 'verts_lh':
- pred_output.lhand_vertices,
- 'verts_rh':
- pred_output.rhand_vertices,
- # 'pred_arm_rotmat': pred_arm_rotmat,
- # 'pred_hfrotmat': pred_hfrotmat,
- 'pred_lhand_rotmat':
- pred_lhand_rotmat,
- 'pred_rhand_rotmat':
- pred_rhand_rotmat,
- 'pred_face_rotmat':
- pred_face_rotmat,
- 'pred_lhand_kp3d':
- pred_output.lhand_joints,
- 'pred_rhand_kp3d':
- pred_output.rhand_joints,
- 'pred_face_kp3d':
- pred_output.face_joints,
- 'pred_lhand_kp2d':
- pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp],
- 'pred_rhand_kp2d':
- pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2],
- 'pred_face_kp2d':
- pred_keypoints_2d[:, len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 +
- len_f_kp],
- 'pred_feet_kp2d':
- pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp +
- len_h_kp * 2 + len_f_kp + len_feet_kp],
- })
+ output.update(
+ {
+ 'smplx_verts':
+ pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None,
+ 'pred_lhand':
+ pred_lhand,
+ 'pred_rhand':
+ pred_rhand,
+ 'pred_face':
+ pred_face,
+ 'pred_exp':
+ pred_exp,
+ 'verts_lh':
+ pred_output.lhand_vertices,
+ 'verts_rh':
+ pred_output.rhand_vertices,
+ # 'pred_arm_rotmat': pred_arm_rotmat,
+ # 'pred_hfrotmat': pred_hfrotmat,
+ 'pred_lhand_rotmat':
+ pred_lhand_rotmat,
+ 'pred_rhand_rotmat':
+ pred_rhand_rotmat,
+ 'pred_face_rotmat':
+ pred_face_rotmat,
+ 'pred_lhand_kp3d':
+ pred_output.lhand_joints,
+ 'pred_rhand_kp3d':
+ pred_output.rhand_joints,
+ 'pred_face_kp3d':
+ pred_output.face_joints,
+ 'pred_lhand_kp2d':
+ pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp],
+ 'pred_rhand_kp2d':
+ pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2],
+ 'pred_face_kp2d':
+ pred_keypoints_2d[:, len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 +
+ len_f_kp],
+ 'pred_feet_kp2d':
+ pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp +
+ len_h_kp * 2 + len_f_kp + len_feet_kp],
+ }
+ )
if cfg.MODEL.PyMAF.OPT_WRIST:
- output.update({
- 'pred_orient_lh': pred_orient_lh,
- 'pred_shape_lh': pred_shape_lh,
- 'pred_orient_rh': pred_orient_rh,
- 'pred_shape_rh': pred_shape_rh,
- 'pred_cam_fa': pred_cam_fa,
- 'pred_cam_lh': pred_cam_lh,
- 'pred_cam_rh': pred_cam_rh,
- })
+ output.update(
+ {
+ 'pred_orient_lh': pred_orient_lh,
+ 'pred_shape_lh': pred_shape_lh,
+ 'pred_orient_rh': pred_orient_rh,
+ 'pred_shape_rh': pred_shape_rh,
+ 'pred_cam_fa': pred_cam_fa,
+ 'pred_cam_lh': pred_cam_lh,
+ 'pred_cam_rh': pred_cam_rh,
+ }
+ )
if cfg.MODEL.PyMAF.PRED_VIS_H:
output.update({'pred_vis_hands': pred_vis_hands})
elif self.hand_only_mode:
# hand mesh out
assert pred_keypoints_2d.shape[1] == 21
- output.update({
- 'theta': pred_cam,
- 'pred_cam': pred_cam,
- 'pred_rhand': pred_rhand,
- 'pred_rhand_rotmat': pred_rotmat_rh[:, 1:],
- 'pred_orient_rh': pred_orient_rh,
- 'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0],
- 'verts_rh': pred_output.rhand_vertices,
- 'pred_cam_rh': pred_cam_rh,
- 'pred_shape_rh': pred_shape_rh,
- 'pred_rhand_kp3d': pred_output.rhand_joints,
- 'pred_rhand_kp2d': pred_keypoints_2d,
- })
+ output.update(
+ {
+ 'theta': pred_cam,
+ 'pred_cam': pred_cam,
+ 'pred_rhand': pred_rhand,
+ 'pred_rhand_rotmat': pred_rotmat_rh[:, 1:],
+ 'pred_orient_rh': pred_orient_rh,
+ 'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0],
+ 'verts_rh': pred_output.rhand_vertices,
+ 'pred_cam_rh': pred_cam_rh,
+ 'pred_shape_rh': pred_shape_rh,
+ 'pred_rhand_kp3d': pred_output.rhand_joints,
+ 'pred_rhand_kp2d': pred_keypoints_2d,
+ }
+ )
elif self.face_only_mode:
# face mesh out
assert pred_keypoints_2d.shape[1] == 68
- output.update({
- 'theta': pred_cam,
- 'pred_cam': pred_cam,
- 'pred_face': pred_face,
- 'pred_exp': pred_exp,
- 'pred_face_rotmat': pred_rotmat_fa[:, 1:],
- 'pred_orient_fa': pred_orient_fa,
- 'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0],
- 'verts_fa': pred_output.flame_vertices,
- 'pred_cam_fa': pred_cam_fa,
- 'pred_shape_fa': pred_shape_fa,
- 'pred_face_kp3d': pred_output.face_joints,
- 'pred_face_kp2d': pred_keypoints_2d,
- })
+ output.update(
+ {
+ 'theta': pred_cam,
+ 'pred_cam': pred_cam,
+ 'pred_face': pred_face,
+ 'pred_exp': pred_exp,
+ 'pred_face_rotmat': pred_rotmat_fa[:, 1:],
+ 'pred_orient_fa': pred_orient_fa,
+ 'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0],
+ 'verts_fa': pred_output.flame_vertices,
+ 'pred_cam_fa': pred_cam_fa,
+ 'pred_shape_fa': pred_shape_fa,
+ 'pred_face_kp3d': pred_output.face_joints,
+ 'pred_face_kp2d': pred_keypoints_2d,
+ }
+ )
return output
-def get_attention_modules(module_keys,
- img_feature_dim_list,
- hidden_feat_dim,
- n_iter,
- num_attention_heads=1):
+def get_attention_modules(
+ module_keys, img_feature_dim_list, hidden_feat_dim, n_iter, num_attention_heads=1
+):
align_attention = nn.ModuleDict()
for k in module_keys:
align_attention[k] = nn.ModuleList()
for i in range(n_iter):
align_attention[k].append(
- get_att_block(img_feature_dim=img_feature_dim_list[k][i],
- hidden_feat_dim=hidden_feat_dim,
- num_attention_heads=num_attention_heads))
+ get_att_block(
+ img_feature_dim=img_feature_dim_list[k][i],
+ hidden_feat_dim=hidden_feat_dim,
+ num_attention_heads=num_attention_heads
+ )
+ )
return align_attention
@@ -764,11 +812,9 @@ class PyMAF(nn.Module):
PyMAF: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop, in ICCV, 2021
PyMAF-X: Towards Well-aligned Full-body Model Regression from Monocular Images, arXiv:2207.06400, 2022
"""
-
- def __init__(self,
- smpl_mean_params=SMPL_MEAN_PARAMS,
- pretrained=True,
- device=torch.device('cuda')):
+ def __init__(
+ self, smpl_mean_params=SMPL_MEAN_PARAMS, pretrained=True, device=torch.device('cuda')
+ ):
super().__init__()
self.device = device
@@ -829,8 +875,9 @@ class PyMAF(nn.Module):
self.smpl_family['face'] = SMPL_Family(model_type='flame')
self.smpl_family['body'] = SMPL_Family(model_type='smplx')
else:
- self.smpl_family['body'] = SMPL_Family(model_type=cfg.MODEL.MESH_MODEL,
- all_gender=cfg.MODEL.ALL_GENDER)
+ self.smpl_family['body'] = SMPL_Family(
+ model_type=cfg.MODEL.MESH_MODEL, all_gender=cfg.MODEL.ALL_GENDER
+ )
self.init_mesh_output = None
self.batch_size = 1
@@ -845,14 +892,14 @@ class PyMAF(nn.Module):
if 'body' in bhf_names:
# if self.smplx_mode or 'hr' in cfg.MODEL.PyMAF.BACKBONE:
if cfg.MODEL.PyMAF.BACKBONE == 'res50':
- body_encoder = get_resnet_encoder(cfg,
- init_weight=(not cfg.MODEL.EVAL_MODE),
- global_mode=self.global_mode)
+ body_encoder = get_resnet_encoder(
+ cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode
+ )
body_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS)
elif cfg.MODEL.PyMAF.BACKBONE == 'hr48':
- body_encoder = get_hrnet_encoder(cfg,
- init_weight=(not cfg.MODEL.EVAL_MODE),
- global_mode=self.global_mode)
+ body_encoder = get_hrnet_encoder(
+ cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode
+ )
body_sfeat_dim = list(cfg.HR_MODEL.EXTRA.STAGE4.NUM_CHANNELS)
body_sfeat_dim.reverse()
body_sfeat_dim = body_sfeat_dim[1:]
@@ -885,7 +932,8 @@ class PyMAF(nn.Module):
self.encoders[hf] = get_resnet_encoder(
cfg,
init_weight=(not cfg.MODEL.EVAL_MODE),
- global_mode=self.global_mode)
+ global_mode=self.global_mode
+ )
self.part_module_names[hf].update({f'encoders.{hf}': self.encoders[hf]})
hf_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS)
else:
@@ -895,15 +943,19 @@ class PyMAF(nn.Module):
assert cfg.MODEL.PyMAF.MAF_ON
self.dp_head_hf = nn.ModuleDict()
if 'hand' in bhf_names:
- self.dp_head_hf['hand'] = IUV_predict_layer(feat_dim=hf_sfeat_dim[-1],
- mode='pncc')
+ self.dp_head_hf['hand'] = IUV_predict_layer(
+ feat_dim=hf_sfeat_dim[-1], mode='pncc'
+ )
self.part_module_names['hand'].update(
- {'dp_head_hf.hand': self.dp_head_hf['hand']})
+ {'dp_head_hf.hand': self.dp_head_hf['hand']}
+ )
if 'face' in bhf_names:
- self.dp_head_hf['face'] = IUV_predict_layer(feat_dim=hf_sfeat_dim[-1],
- mode='pncc')
+ self.dp_head_hf['face'] = IUV_predict_layer(
+ feat_dim=hf_sfeat_dim[-1], mode='pncc'
+ )
self.part_module_names['face'].update(
- {'dp_head_hf.face': self.dp_head_hf['face']})
+ {'dp_head_hf.face': self.dp_head_hf['face']}
+ )
smpl2limb_vert_faces = get_partial_smpl()
@@ -914,7 +966,8 @@ class PyMAF(nn.Module):
grid_size = 21
xv, yv = torch.meshgrid(
[torch.linspace(-1, 1, grid_size),
- torch.linspace(-1, 1, grid_size)])
+ torch.linspace(-1, 1, grid_size)]
+ )
grid_points = torch.stack([xv.reshape(-1), yv.reshape(-1)]).unsqueeze(0)
self.register_buffer('grid_points', grid_points)
grid_feat_dim = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1]
@@ -943,7 +996,8 @@ class PyMAF(nn.Module):
if 'face' in self.bhf_names:
bhf_ma_feat_dim.update(
- {'face': len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]})
+ {'face': len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]}
+ )
if self.fuse_grid_align:
bhf_att_feat_dim.update({'face': 1024})
@@ -959,25 +1013,31 @@ class PyMAF(nn.Module):
if 'face' in bhf_names:
hfimg_feat_dim_list['face'] = hf_sfeat_dim[-n_iter_att:]
- self.align_attention = get_attention_modules(bhf_names,
- hfimg_feat_dim_list,
- hidden_feat_dim,
- n_iter=n_iter_att,
- num_attention_heads=num_att_heads)
+ self.align_attention = get_attention_modules(
+ bhf_names,
+ hfimg_feat_dim_list,
+ hidden_feat_dim,
+ n_iter=n_iter_att,
+ num_attention_heads=num_att_heads
+ )
for part in bhf_names:
self.part_module_names[part].update(
- {f'align_attention.{part}': self.align_attention[part]})
+ {f'align_attention.{part}': self.align_attention[part]}
+ )
if self.fuse_grid_align:
- self.att_feat_reduce = get_fusion_modules(bhf_names,
- bhf_ma_feat_dim,
- grid_feat_dim,
- n_iter=n_iter_att,
- out_feat_len=bhf_att_feat_dim)
+ self.att_feat_reduce = get_fusion_modules(
+ bhf_names,
+ bhf_ma_feat_dim,
+ grid_feat_dim,
+ n_iter=n_iter_att,
+ out_feat_len=bhf_att_feat_dim
+ )
for part in bhf_names:
self.part_module_names[part].update(
- {f'att_feat_reduce.{part}': self.att_feat_reduce[part]})
+ {f'att_feat_reduce.{part}': self.att_feat_reduce[part]}
+ )
# build regressor for parameter prediction
self.regressor = nn.ModuleList()
@@ -1002,10 +1062,13 @@ class PyMAF(nn.Module):
if self.smpl_mode:
self.regressor.append(
- Regressor(feat_dim=ref_infeat_dim,
- smpl_mean_params=smpl_mean_params,
- use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT,
- smpl_models=self.smpl_family))
+ Regressor(
+ feat_dim=ref_infeat_dim,
+ smpl_mean_params=smpl_mean_params,
+ use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT,
+ smpl_models=self.smpl_family
+ )
+ )
else:
if cfg.MODEL.PyMAF.MAF_ON:
if 'hand' in self.bhf_names or 'face' in self.bhf_names:
@@ -1032,28 +1095,35 @@ class PyMAF(nn.Module):
feat_dim_face = global_feat_dim
self.regressor.append(
- Regressor(feat_dim=ref_infeat_dim,
- smpl_mean_params=smpl_mean_params,
- use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT,
- feat_dim_hand=feat_dim_hand,
- feat_dim_face=feat_dim_face,
- bhf_names=bhf_names,
- smpl_models=self.smpl_family))
+ Regressor(
+ feat_dim=ref_infeat_dim,
+ smpl_mean_params=smpl_mean_params,
+ use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT,
+ feat_dim_hand=feat_dim_hand,
+ feat_dim_face=feat_dim_face,
+ bhf_names=bhf_names,
+ smpl_models=self.smpl_family
+ )
+ )
# assign sub-regressor to each part
for dec_name, dec_module in self.regressor[-1].named_children():
if 'hand' in dec_name:
self.part_module_names['hand'].update(
- {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module})
+ {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
+ )
elif 'face' in dec_name or 'head' in dec_name or 'exp' in dec_name:
self.part_module_names['face'].update(
- {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module})
+ {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
+ )
elif 'res' in dec_name or 'vis' in dec_name:
self.part_module_names['link'].update(
- {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module})
+ {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
+ )
elif 'body' in self.part_module_names:
self.part_module_names['body'].update(
- {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module})
+ {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
+ )
# mesh-aligned feature extractor
self.maf_extractor = nn.ModuleDict()
@@ -1070,12 +1140,17 @@ class PyMAF(nn.Module):
if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT and i >= self.att_starts:
self.maf_extractor[part].append(
- MAF_Extractor(filter_channels=filter_channels_default[att_feat_dim_idx:],
- iwp_cam_mode=cfg.MODEL.USE_IWP_CAM))
+ MAF_Extractor(
+ filter_channels=filter_channels_default[att_feat_dim_idx:],
+ iwp_cam_mode=cfg.MODEL.USE_IWP_CAM
+ )
+ )
else:
self.maf_extractor[part].append(
- MAF_Extractor(filter_channels=filter_channels,
- iwp_cam_mode=cfg.MODEL.USE_IWP_CAM))
+ MAF_Extractor(
+ filter_channels=filter_channels, iwp_cam_mode=cfg.MODEL.USE_IWP_CAM
+ )
+ )
self.part_module_names[part].update({f'maf_extractor.{part}': self.maf_extractor[part]})
# check all modules have been added to part_module_names
@@ -1099,10 +1174,9 @@ class PyMAF(nn.Module):
""" initialize the mesh model with default poses and shapes
"""
if self.init_mesh_output is None or self.batch_size != batch_size:
- self.init_mesh_output = self.regressor[0](torch.zeros(batch_size),
- J_regressor=J_regressor,
- rw_cam=rw_cam,
- init_mode=True)
+ self.init_mesh_output = self.regressor[0](
+ torch.zeros(batch_size), J_regressor=J_regressor, rw_cam=rw_cam, init_mode=True
+ )
self.batch_size = batch_size
return self.init_mesh_output
@@ -1110,11 +1184,13 @@ class PyMAF(nn.Module):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2d(self.inplanes,
- planes * block.expansion,
- kernel_size=1,
- stride=stride,
- bias=False),
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False
+ ),
nn.BatchNorm2d(planes * block.expansion),
)
@@ -1156,13 +1232,16 @@ class PyMAF(nn.Module):
planes = num_filters[i]
layers.append(
- nn.ConvTranspose2d(in_channels=self.inplanes,
- out_channels=planes,
- kernel_size=kernel,
- stride=2,
- padding=padding,
- output_padding=output_padding,
- bias=self.deconv_with_bias))
+ nn.ConvTranspose2d(
+ in_channels=self.inplanes,
+ out_channels=planes,
+ kernel_size=kernel,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=self.deconv_with_bias
+ )
+ )
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
self.inplanes = planes
@@ -1196,7 +1275,7 @@ class PyMAF(nn.Module):
vis_feat_list: the list containing features for visualization
'''
- # batch keys: ['img_body', 'orig_height', 'orig_width', 'person_id', 'img_lhand',
+ # batch keys: ['img_body', 'orig_height', 'orig_width', 'person_id', 'img_lhand',
# 'lhand_theta_inv', 'img_rhand', 'rhand_theta_inv', 'img_face', 'face_theta_inv']
# extract spatial features or global features
@@ -1234,7 +1313,8 @@ class PyMAF(nn.Module):
img_rhand = batch['img_rhand']
batch_size = img_rhand.shape[0]
limb_feat_dict['rhand'], limb_gfeat_dict['rhand'] = self.encoders['hand'](
- img_rhand)
+ img_rhand
+ )
if cfg.MODEL.PyMAF.MAF_ON:
for k in limb_feat_dict.keys():
@@ -1292,10 +1372,11 @@ class PyMAF(nn.Module):
if self.hand_only_mode:
pred_cam = mesh_output['pred_cam'].detach()
pred_rhand_v = self.mano_sampler(mesh_output['verts_rh'])
- pred_rhand_proj = projection(pred_rhand_v, {
- **rw_cam, 'cam_sxy': pred_cam
- },
- iwp_mode=cfg.MODEL.USE_IWP_CAM)
+ pred_rhand_proj = projection(
+ pred_rhand_v, {
+ **rw_cam, 'cam_sxy': pred_cam
+ }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+ )
if cfg.MODEL.USE_IWP_CAM:
pred_rhand_proj = pred_rhand_proj / (224. / 2.)
else:
@@ -1310,10 +1391,11 @@ class PyMAF(nn.Module):
elif self.face_only_mode:
pred_cam = mesh_output['pred_cam'].detach()
pred_face_v = mesh_output['pred_face_kp3d']
- pred_face_proj = projection(pred_face_v, {
- **rw_cam, 'cam_sxy': pred_cam
- },
- iwp_mode=cfg.MODEL.USE_IWP_CAM)
+ pred_face_proj = projection(
+ pred_face_v, {
+ **rw_cam, 'cam_sxy': pred_cam
+ }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+ )
if cfg.MODEL.USE_IWP_CAM:
pred_face_proj = pred_face_proj / (224. / 2.)
else:
@@ -1326,10 +1408,11 @@ class PyMAF(nn.Module):
pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand])
pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand])
pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1)
- pred_hand_proj = projection(pred_hand_v, {
- **rw_cam, 'cam_sxy': pred_cam
- },
- iwp_mode=cfg.MODEL.USE_IWP_CAM)
+ pred_hand_proj = projection(
+ pred_hand_v, {
+ **rw_cam, 'cam_sxy': pred_cam
+ }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+ )
if cfg.MODEL.USE_IWP_CAM:
pred_hand_proj = pred_hand_proj / (224. / 2.)
else:
@@ -1343,20 +1426,23 @@ class PyMAF(nn.Module):
}
proj_hf_pts = {
'lhand':
- torch.cat([proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]],
- dim=1),
+ torch.cat(
+ [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1
+ ),
'rhand':
- torch.cat([proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]],
- dim=1),
+ torch.cat(
+ [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1
+ ),
}
elif self.full_body_mode:
pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand])
pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand])
pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1)
- pred_hand_proj = projection(pred_hand_v, {
- **rw_cam, 'cam_sxy': pred_cam
- },
- iwp_mode=cfg.MODEL.USE_IWP_CAM)
+ pred_hand_proj = projection(
+ pred_hand_v, {
+ **rw_cam, 'cam_sxy': pred_cam
+ }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+ )
if cfg.MODEL.USE_IWP_CAM:
pred_hand_proj = pred_hand_proj / (224. / 2.)
else:
@@ -1372,11 +1458,13 @@ class PyMAF(nn.Module):
}
proj_hf_pts = {
'lhand':
- torch.cat([proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]],
- dim=1),
+ torch.cat(
+ [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1
+ ),
'rhand':
- torch.cat([proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]],
- dim=1),
+ torch.cat(
+ [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1
+ ),
'face':
torch.cat([proj_hf_center['face'], mesh_output['pred_face_kp2d']], dim=1)
}
@@ -1402,7 +1490,8 @@ class PyMAF(nn.Module):
if limb_rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT:
limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling(
- grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim)
+ grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim
+ )
else:
if self.hand_only_mode or self.face_only_mode:
proj_hf_pts_crop = proj_hf_pts[part_name][:, :, :2]
@@ -1422,8 +1511,8 @@ class PyMAF(nn.Module):
theta_i_inv = batch[f'{part_name}_theta_inv']
proj_hf_pts_crop = torch.bmm(
theta_i_inv,
- homo_vector(proj_hf_pts[part_name][:, :, :2]).permute(
- 0, 2, 1)).permute(0, 2, 1)
+ homo_vector(proj_hf_pts[part_name][:, :, :2]).permute(0, 2, 1)
+ ).permute(0, 2, 1)
if part_name == 'lhand':
flip_x = torch.tensor([-1, 1])[None,
@@ -1445,15 +1534,17 @@ class PyMAF(nn.Module):
limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling(
proj_hf_pts_crop_ctd.detach(),
im_feat=limb_feat_i,
- reduce_dim=limb_reduce_dim)
+ reduce_dim=limb_reduce_dim
+ )
if self.fuse_grid_align and limb_rf_i >= self.att_starts:
limb_grid_feature_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling(
- grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim)
+ grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim
+ )
limb_grid_ref_feat_ctd = torch.cat(
- [limb_grid_feature_ctd, limb_ref_feat_ctd],
- dim=-1).permute(0, 2, 1)
+ [limb_grid_feature_ctd, limb_ref_feat_ctd], dim=-1
+ ).permute(0, 2, 1)
if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT:
att_ref_feat_ctd = self.align_attention[hf_key][
@@ -1462,7 +1553,8 @@ class PyMAF(nn.Module):
att_ref_feat_ctd = limb_grid_ref_feat_ctd
att_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].reduce_dim(
- att_ref_feat_ctd.permute(0, 2, 1)).view(batch_size, -1)
+ att_ref_feat_ctd.permute(0, 2, 1)
+ ).view(batch_size, -1)
limb_ref_feat_ctd = self.att_feat_reduce[hf_key][
limb_rf_i - self.att_starts](att_ref_feat_ctd)
@@ -1479,11 +1571,13 @@ class PyMAF(nn.Module):
reduce_dim = (not self.fuse_grid_align) or (rf_i < self.att_starts)
if rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT:
ref_feature = self.maf_extractor['body'][rf_i].sampling(
- grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim)
+ grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim
+ )
else:
# TODO: use a more sparse SMPL implementation (with 431 vertices) for acceleration
pred_smpl_verts_ds = self.mesh_sampler.downsample(
- pred_smpl_verts) # [B, 431, 3]
+ pred_smpl_verts
+ ) # [B, 431, 3]
ref_feature = self.maf_extractor['body'][rf_i](
pred_smpl_verts_ds,
im_feat=s_feat_i,
@@ -1491,25 +1585,28 @@ class PyMAF(nn.Module):
**rw_cam, 'cam_sxy': pred_cam
},
add_att=True,
- reduce_dim=reduce_dim) # [B, 431 * n_feat]
+ reduce_dim=reduce_dim
+ ) # [B, 431 * n_feat]
if self.fuse_grid_align and rf_i >= self.att_starts:
if rf_i > 0 and not cfg.MODEL.PyMAF.GRID_FEAT:
grid_feature = self.maf_extractor['body'][rf_i].sampling(
- grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim)
+ grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim
+ )
grid_ref_feat = torch.cat([grid_feature, ref_feature], dim=-1)
else:
grid_ref_feat = ref_feature
grid_ref_feat = grid_ref_feat.permute(0, 2, 1)
if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT:
- att_ref_feat = self.align_attention['body'][rf_i - self.att_starts](
- grid_ref_feat)[0]
+ att_ref_feat = self.align_attention['body'][
+ rf_i - self.att_starts](grid_ref_feat)[0]
elif cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC:
att_ref_feat = grid_ref_feat
att_ref_feat = self.maf_extractor['body'][rf_i].reduce_dim(
- att_ref_feat.permute(0, 2, 1))
+ att_ref_feat.permute(0, 2, 1)
+ )
att_ref_feat = att_ref_feat.view(batch_size, -1)
ref_feature = self.att_feat_reduce['body'][rf_i -
@@ -1560,12 +1657,14 @@ class PyMAF(nn.Module):
current_states['init_cam_rh'] = mesh_output['pred_cam_rh'].detach()
# update mesh parameters
- mesh_output = self.regressor[rf_i](ref_feature,
- n_iter=1,
- J_regressor=J_regressor,
- rw_cam=rw_cam,
- global_iter=rf_i,
- **current_states)
+ mesh_output = self.regressor[rf_i](
+ ref_feature,
+ n_iter=1,
+ J_regressor=J_regressor,
+ rw_cam=rw_cam,
+ global_iter=rf_i,
+ **current_states
+ )
out_dict['mesh_out'].append(mesh_output)
diff --git a/lib/pymafx/models/res_module.py b/lib/pymafx/models/res_module.py
index 98d7721d8562110472e4de730028a5ff6da6c0e7..94de7ecaa2ba3ead51c5f960e0ae08b806d9cd80 100644
--- a/lib/pymafx/models/res_module.py
+++ b/lib/pymafx/models/res_module.py
@@ -12,17 +12,24 @@ from collections import OrderedDict
from lib.pymafx.core.cfgs import cfg
# from .transformers.tokenlearner import TokenLearner
-
import logging
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
BN_MOMENTUM = 0.1
+
def conv3x3(in_planes, out_planes, stride=1, bias=False, groups=1):
"""3x3 convolution with padding"""
- return nn.Conv2d(in_planes * groups, out_planes * groups, kernel_size=3, stride=stride,
- padding=1, bias=bias, groups=groups)
+ return nn.Conv2d(
+ in_planes * groups,
+ out_planes * groups,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=bias,
+ groups=groups
+ )
class BasicBlock(nn.Module):
@@ -62,15 +69,28 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
super().__init__()
- self.conv1 = nn.Conv2d(inplanes * groups, planes * groups, kernel_size=1, bias=False, groups=groups)
+ self.conv1 = nn.Conv2d(
+ inplanes * groups, planes * groups, kernel_size=1, bias=False, groups=groups
+ )
self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
- self.conv2 = nn.Conv2d(planes * groups, planes * groups, kernel_size=3, stride=stride,
- padding=1, bias=False, groups=groups)
+ self.conv2 = nn.Conv2d(
+ planes * groups,
+ planes * groups,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ groups=groups
+ )
self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
- self.conv3 = nn.Conv2d(planes * groups, planes * self.expansion * groups, kernel_size=1,
- bias=False, groups=groups)
- self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups,
- momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(
+ planes * groups,
+ planes * self.expansion * groups,
+ kernel_size=1,
+ bias=False,
+ groups=groups
+ )
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
@@ -98,11 +118,13 @@ class Bottleneck(nn.Module):
return out
-resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
- 34: (BasicBlock, [3, 4, 6, 3]),
- 50: (Bottleneck, [3, 4, 6, 3]),
- 101: (Bottleneck, [3, 4, 23, 3]),
- 152: (Bottleneck, [3, 8, 36, 3])}
+resnet_spec = {
+ 18: (BasicBlock, [2, 2, 2, 2]),
+ 34: (BasicBlock, [3, 4, 6, 3]),
+ 50: (Bottleneck, [3, 4, 6, 3]),
+ 101: (Bottleneck, [3, 4, 23, 3]),
+ 152: (Bottleneck, [3, 8, 36, 3])
+}
class IUV_predict_layer(nn.Module):
@@ -162,12 +184,12 @@ class IUV_predict_layer(nn.Module):
)
elif mode in ['pncc']:
self.predict_pncc = nn.Conv2d(
- in_channels=feat_dim,
- out_channels=3,
- kernel_size=final_cov_k,
- stride=1,
- padding=1 if final_cov_k == 3 else 0
- )
+ in_channels=feat_dim,
+ out_channels=3,
+ kernel_size=final_cov_k,
+ stride=1,
+ padding=1 if final_cov_k == 3 else 0
+ )
self.inplanes = feat_dim
@@ -175,8 +197,13 @@ class IUV_predict_layer(nn.Module):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False
+ ),
nn.BatchNorm2d(planes * block.expansion),
)
@@ -197,7 +224,6 @@ class IUV_predict_layer(nn.Module):
return_dict['predict_uv_index'] = predict_uv_index
return_dict['predict_ann_index'] = predict_ann_index
-
if self.mode == 'iuv':
predict_u = self.predict_u(x)
@@ -209,7 +235,7 @@ class IUV_predict_layer(nn.Module):
return_dict['predict_v'] = None
# return_dict['predict_u'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device)
# return_dict['predict_v'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device)
-
+
if self.mode == 'pncc':
predict_pncc = self.predict_pncc(x)
return_dict['predict_pncc'] = predict_pncc
@@ -252,10 +278,11 @@ class Kps_predict_layer(nn.Module):
stride=1,
padding=1 if final_cov_k == 3 else 0
)
- self.predict_kps = nn.Sequential(add_module,
- # nn.BatchNorm2d(feat_dim, momentum=BN_MOMENTUM),
- # conv,
- )
+ self.predict_kps = nn.Sequential(
+ add_module,
+ # nn.BatchNorm2d(feat_dim, momentum=BN_MOMENTUM),
+ # conv,
+ )
else:
self.predict_kps = nn.Conv2d(
in_channels=feat_dim,
@@ -277,8 +304,16 @@ class Kps_predict_layer(nn.Module):
class SmplResNet(nn.Module):
-
- def __init__(self, resnet_nums, in_channels=3, num_classes=229, last_stride=2, n_extra_feat=0, truncate=0, **kwargs):
+ def __init__(
+ self,
+ resnet_nums,
+ in_channels=3,
+ num_classes=229,
+ last_stride=2,
+ n_extra_feat=0,
+ truncate=0,
+ **kwargs
+ ):
super().__init__()
self.inplanes = 64
@@ -287,15 +322,16 @@ class SmplResNet(nn.Module):
# self.deconv_with_bias = extra.DECONV_WITH_BIAS
block, layers = resnet_spec[resnet_nums]
- self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
- bias=False)
+ self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2) if truncate < 2 else None
- self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) if truncate < 1 else None
+ self.layer4 = self._make_layer(
+ block, 512, layers[3], stride=last_stride
+ ) if truncate < 1 else None
self.avg_pooling = nn.AdaptiveAvgPool2d(1)
@@ -306,16 +342,26 @@ class SmplResNet(nn.Module):
self.n_extra_feat = n_extra_feat
if n_extra_feat > 0:
- self.trans_conv = nn.Sequential(nn.Conv2d(n_extra_feat + 512*block.expansion, 512*block.expansion, kernel_size=1, bias=False),
- nn.BatchNorm2d(512*block.expansion, momentum=BN_MOMENTUM),
- nn.ReLU(True))
+ self.trans_conv = nn.Sequential(
+ nn.Conv2d(
+ n_extra_feat + 512 * block.expansion,
+ 512 * block.expansion,
+ kernel_size=1,
+ bias=False
+ ), nn.BatchNorm2d(512 * block.expansion, momentum=BN_MOMENTUM), nn.ReLU(True)
+ )
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False
+ ),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
@@ -378,8 +424,7 @@ class SmplResNet(nn.Module):
else:
state_dict[key] = state_dict_old[key]
else:
- raise RuntimeError(
- 'No state_dict found in checkpoint file {}'.format(pretrained))
+ raise RuntimeError('No state_dict found in checkpoint file {}'.format(pretrained))
self.load_state_dict(state_dict, strict=False)
else:
logger.error('=> imagenet pretrained model dose not exist')
@@ -388,7 +433,6 @@ class SmplResNet(nn.Module):
class LimbResLayers(nn.Module):
-
def __init__(self, resnet_nums, inplanes, outplanes=None, groups=1, **kwargs):
super().__init__()
@@ -407,8 +451,14 @@ class LimbResLayers(nn.Module):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2d(self.inplanes * groups, planes * block.expansion * groups,
- kernel_size=1, stride=stride, bias=False, groups=groups),
+ nn.Conv2d(
+ self.inplanes * groups,
+ planes * block.expansion * groups,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ groups=groups
+ ),
nn.BatchNorm2d(planes * block.expansion * groups, momentum=BN_MOMENTUM),
)
diff --git a/lib/pymafx/models/smpl.py b/lib/pymafx/models/smpl.py
index 2ac5405473b792eb5e89ebe30328461c619354e3..0a69eaf24e518545542cdd1eb55c819a549e8d55 100644
--- a/lib/pymafx/models/smpl.py
+++ b/lib/pymafx/models/smpl.py
@@ -19,6 +19,7 @@ from lib.pymafx.core import path_config, constants
SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS
SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR
+
@dataclass
class ModelOutput(SMPLXOutput):
smpl_joints: Optional[torch.Tensor] = None
@@ -33,16 +34,31 @@ class ModelOutput(SMPLXOutput):
lfoot_joints: Optional[torch.Tensor] = None
rfoot_joints: Optional[torch.Tensor] = None
+
class SMPL(_SMPL):
""" Extension of the official SMPL implementation to support more joints """
- def __init__(self, create_betas=False, create_global_orient=False, create_body_pose=False, create_transl=False, *args, **kwargs):
- super().__init__(create_betas=create_betas,
- create_global_orient=create_global_orient,
- create_body_pose=create_body_pose,
- create_transl=create_transl, *args, **kwargs)
+ def __init__(
+ self,
+ create_betas=False,
+ create_global_orient=False,
+ create_body_pose=False,
+ create_transl=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__(
+ create_betas=create_betas,
+ create_global_orient=create_global_orient,
+ create_body_pose=create_body_pose,
+ create_transl=create_transl,
+ *args,
+ **kwargs
+ )
joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
- self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
+ self.register_buffer(
+ 'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
+ )
self.joint_map = torch.tensor(joints, dtype=torch.long)
# self.ModelOutput = namedtuple('ModelOutput_', ModelOutput._fields + ('smpl_joints', 'joints_J19',))
# self.ModelOutput.__new__.__defaults__ = (None,) * len(self.ModelOutput._fields)
@@ -58,17 +74,19 @@ class SMPL(_SMPL):
vertices = smpl_output.vertices
joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
smpl_joints = smpl_output.joints[:, :24]
- joints = joints[:, self.joint_map, :] # [B, 49, 3]
+ joints = joints[:, self.joint_map, :] # [B, 49, 3]
joints_J24 = joints[:, -24:, :]
joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
- output = ModelOutput(vertices=vertices,
- global_orient=smpl_output.global_orient,
- body_pose=smpl_output.body_pose,
- joints=joints,
- joints_J19=joints_J19,
- smpl_joints=smpl_joints,
- betas=smpl_output.betas,
- full_pose=smpl_output.full_pose)
+ output = ModelOutput(
+ vertices=vertices,
+ global_orient=smpl_output.global_orient,
+ body_pose=smpl_output.body_pose,
+ joints=joints,
+ joints_J19=joints_J19,
+ smpl_joints=smpl_joints,
+ betas=smpl_output.betas,
+ full_pose=smpl_output.full_pose
+ )
return output
def get_global_rotation(
@@ -107,18 +125,20 @@ class SMPL(_SMPL):
batch_size = max(batch_size, len(var))
if global_orient is None:
- global_orient = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ global_orient = torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1,
+ -1).contiguous()
if body_pose is None:
- body_pose = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(
- batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous()
+ body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+ batch_size, self.NUM_BODY_JOINTS, -1, -1
+ ).contiguous()
# Concatenate all pose vectors
full_pose = torch.cat(
[global_orient.reshape(-1, 1, 3, 3),
body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)],
- dim=1)
+ dim=1
+ )
rot_mats = full_pose.view(batch_size, -1, 3, 3)
@@ -132,16 +152,15 @@ class SMPL(_SMPL):
rel_joints = joints.clone()
rel_joints[:, 1:] -= joints[:, self.parents[1:]]
- transforms_mat = transform_mat(
- rot_mats.reshape(-1, 3, 3),
- rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
+ transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
+ rel_joints.reshape(-1, 3,
+ 1)).reshape(-1, joints.shape[1], 4, 4)
transform_chain = [transforms_mat[:, 0]]
for i in range(1, self.parents.shape[0]):
# Subtract the joint location at the rest pose
# No need for rotation, since it's identity when at rest
- curr_res = torch.matmul(transform_chain[self.parents[i]],
- transforms_mat[:, i])
+ curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
transform_chain.append(curr_res)
transforms = torch.stack(transform_chain, dim=1)
@@ -230,60 +249,72 @@ class SMPLX(SMPLXLayer):
batch_size = max(batch_size, len(var))
if global_orient is None:
- global_orient = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+            global_orient = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+                batch_size, -1, -1, -1
+            ).contiguous()
if body_pose is None:
- body_pose = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(
- batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous()
+ body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+ batch_size, self.NUM_BODY_JOINTS, -1, -1
+ ).contiguous()
if left_hand_pose is None:
- left_hand_pose = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
+            left_hand_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+                batch_size, 15, -1, -1
+            ).contiguous()
if right_hand_pose is None:
- right_hand_pose = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
+            right_hand_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+                batch_size, 15, -1, -1
+            ).contiguous()
if jaw_pose is None:
- jaw_pose = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+            jaw_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+                batch_size, -1, -1, -1
+            ).contiguous()
if leye_pose is None:
- leye_pose = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+            leye_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+                batch_size, -1, -1, -1
+            ).contiguous()
if reye_pose is None:
- reye_pose = torch.eye(3, device=device, dtype=dtype).view(
- 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+            reye_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
+                batch_size, -1, -1, -1
+            ).contiguous()
# Concatenate all pose vectors
full_pose = torch.cat(
- [global_orient.reshape(-1, 1, 3, 3),
- body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
- jaw_pose.reshape(-1, 1, 3, 3),
- leye_pose.reshape(-1, 1, 3, 3),
- reye_pose.reshape(-1, 1, 3, 3),
- left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
- right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)],
- dim=1)
-
+ [
+ global_orient.reshape(-1, 1, 3, 3),
+ body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
+ jaw_pose.reshape(-1, 1, 3, 3),
+ leye_pose.reshape(-1, 1, 3, 3),
+ reye_pose.reshape(-1, 1, 3, 3),
+ left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
+ right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
+ ],
+ dim=1
+ )
+
rot_mats = full_pose.view(batch_size, -1, 3, 3)
# Get the joints
# NxJx3 array
- joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0).expand(batch_size, -1, -1))
+ joints = vertices2joints(
+ self.J_regressor,
+ self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+ )
joints = torch.unsqueeze(joints, dim=-1)
rel_joints = joints.clone()
rel_joints[:, 1:] -= joints[:, self.parents[1:]]
- transforms_mat = transform_mat(
- rot_mats.reshape(-1, 3, 3),
- rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
+        transforms_mat = transform_mat(
+            rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
+        ).reshape(-1, joints.shape[1], 4, 4)
transform_chain = [transforms_mat[:, 0]]
for i in range(1, self.parents.shape[0]):
# Subtract the joint location at the rest pose
# No need for rotation, since it's identity when at rest
- curr_res = torch.matmul(transform_chain[self.parents[i]],
- transforms_mat[:, i])
+ curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
transform_chain.append(curr_res)
transforms = torch.stack(transform_chain, dim=1)
@@ -298,7 +329,6 @@ class SMPLX(SMPLXLayer):
class SMPLX_ALL(nn.Module):
""" Extension of the official SMPLX implementation to support more joints """
-
def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs):
super().__init__()
numBetas = 10
@@ -309,45 +339,72 @@ class SMPLX_ALL(nn.Module):
self.genders = ['neutral']
for gender in self.genders:
assert gender in ['male', 'female', 'neutral']
- self.model_dict = nn.ModuleDict({gender: SMPLX(path_config.SMPL_MODEL_DIR,
- gender=gender,
- ext='npz',
- num_betas=numBetas,
- use_pca=False, batch_size=batch_size, use_face_contour=use_face_contour, num_pca_comps=45, **kwargs)
- for gender in self.genders})
+ self.model_dict = nn.ModuleDict(
+ {
+ gender: SMPLX(
+ path_config.SMPL_MODEL_DIR,
+ gender=gender,
+ ext='npz',
+ num_betas=numBetas,
+ use_pca=False,
+ batch_size=batch_size,
+ use_face_contour=use_face_contour,
+ num_pca_comps=45,
+ **kwargs
+ )
+ for gender in self.genders
+ }
+ )
self.model_neutral = self.model_dict['neutral']
joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
- self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
+ self.register_buffer(
+ 'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
+ )
self.joint_map = torch.tensor(joints, dtype=torch.long)
# smplx_to_smpl.pkl, file source: https://smpl-x.is.tue.mpg.de
- smplx_to_smpl = pickle.load(open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb'))
- self.register_buffer('smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32))
+ smplx_to_smpl = pickle.load(
+ open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb')
+ )
+ self.register_buffer(
+ 'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)
+ )
smpl2limb_vert_faces = get_partial_smpl('smpl')
self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()
# left and right hand joint mapping
- smplx2lhand_joints = [constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES]
- smplx2rhand_joints = [constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES]
+ smplx2lhand_joints = [
+ constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES
+ ]
+ smplx2rhand_joints = [
+ constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES
+ ]
self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long)
self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long)
# left and right foot joint mapping
- smplx2lfoot_joints = [constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES]
- smplx2rfoot_joints = [constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES]
+ smplx2lfoot_joints = [
+ constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES
+ ]
+ smplx2rfoot_joints = [
+ constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES
+ ]
self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long)
self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long)
for g in self.genders:
- J_template = torch.einsum('ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template])
- J_dirs = torch.einsum('ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs])
+ J_template = torch.einsum(
+ 'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]
+ )
+ J_dirs = torch.einsum(
+ 'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]
+ )
self.register_buffer(f'{g}_J_template', J_template)
self.register_buffer(f'{g}_J_dirs', J_dirs)
-
def forward(self, *args, **kwargs):
batch_size = kwargs['body_pose'].shape[0]
kwargs['get_skin'] = True
@@ -357,7 +414,10 @@ class SMPLX_ALL(nn.Module):
kwargs['gender'] = 2 * torch.ones(batch_size).to(kwargs['body_pose'].device)
# pose for 55 joints: 1, 21, 15, 15, 1, 1, 1
- pose_keys = ['global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose', 'leye_pose', 'reye_pose']
+ pose_keys = [
+ 'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose',
+ 'leye_pose', 'reye_pose'
+ ]
param_keys = ['betas'] + pose_keys
if kwargs['pose2rot']:
for key in pose_keys:
@@ -366,7 +426,9 @@ class SMPLX_ALL(nn.Module):
# kwargs[key] += self.model_neutral.left_hand_mean
# elif key == 'right_hand_pose':
# kwargs[key] += self.model_neutral.right_hand_mean
- kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([batch_size, -1, 3, 3])
+ kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
+ [batch_size, -1, 3, 3]
+ )
if kwargs['body_pose'].shape[1] == 23:
# remove hand pose in the body_pose
kwargs['body_pose'] = kwargs['body_pose'][:, :21]
@@ -406,26 +468,27 @@ class SMPLX_ALL(nn.Module):
smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45]
joints = torch.cat([smplx_j45, extra_joints], dim=1)
smpl_joints = smplx_j45[:, :24]
- joints = joints[:, self.joint_map, :] # [B, 49, 3]
+ joints = joints[:, self.joint_map, :] # [B, 49, 3]
joints_J24 = joints[:, -24:, :]
joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
- output = ModelOutput(vertices=smpl_vertices,
- smplx_vertices=smplx_vertices,
- lhand_vertices=lhand_vertices,
- rhand_vertices=rhand_vertices,
- # global_orient=smplx_output.global_orient,
- # body_pose=smplx_output.body_pose,
- joints=joints,
- joints_J19=joints_J19,
- smpl_joints=smpl_joints,
- # betas=smplx_output.betas,
- # full_pose=smplx_output.full_pose,
- lhand_joints=lhand_joints,
- rhand_joints=rhand_joints,
- lfoot_joints=lfoot_joints,
- rfoot_joints=rfoot_joints,
- face_joints=face_joints,
- )
+ output = ModelOutput(
+ vertices=smpl_vertices,
+ smplx_vertices=smplx_vertices,
+ lhand_vertices=lhand_vertices,
+ rhand_vertices=rhand_vertices,
+ # global_orient=smplx_output.global_orient,
+ # body_pose=smplx_output.body_pose,
+ joints=joints,
+ joints_J19=joints_J19,
+ smpl_joints=smpl_joints,
+ # betas=smplx_output.betas,
+ # full_pose=smplx_output.full_pose,
+ lhand_joints=lhand_joints,
+ rhand_joints=rhand_joints,
+ lfoot_joints=lfoot_joints,
+ rfoot_joints=rfoot_joints,
+ face_joints=face_joints,
+ )
return output
# def make_hand_regressor(self):
@@ -467,7 +530,7 @@ class SMPLX_ALL(nn.Module):
kwargs['gender'] = 2 * torch.ones(batch_size).to(device)
else:
kwargs['gender'] = gender
-
+
param_keys = ['betas']
gender_idx_list = []
@@ -480,7 +543,9 @@ class SMPLX_ALL(nn.Module):
gender_kwargs = {}
gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs})
- J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(gender_kwargs['betas'], getattr(self, f'{g}_J_dirs'))
+ J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(
+ gender_kwargs['betas'], getattr(self, f'{g}_J_dirs')
+ )
smplx_joints.append(J)
@@ -491,9 +556,10 @@ class SMPLX_ALL(nn.Module):
return smplx_joints
+
class MANO(MANOLayer):
""" Extension of the official MANO implementation to support more joints """
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs):
@@ -504,7 +570,9 @@ class MANO(MANOLayer):
if kwargs['pose2rot']:
for key in pose_keys:
if key in kwargs:
- kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([batch_size, -1, 3, 3])
+ kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
+ [batch_size, -1, 3, 3]
+ )
kwargs['hand_pose'] = kwargs.pop('right_hand_pose')
mano_output = super().forward(*args, **kwargs)
th_verts = mano_output.vertices
@@ -515,15 +583,18 @@ class MANO(MANOLayer):
tips = th_verts[:, [745, 317, 445, 556, 673]]
th_jtr = torch.cat([th_jtr, tips], 1)
# Reorder joints to match visualization utilities
- th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]
- output = ModelOutput(rhand_vertices=th_verts,
- rhand_joints=th_jtr,
- )
+        reorder = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
+        th_jtr = th_jtr[:, reorder]
+ output = ModelOutput(
+ rhand_vertices=th_verts,
+ rhand_joints=th_jtr,
+ )
return output
+
class FLAME(FLAMELayer):
""" Extension of the official FLAME implementation to support more joints """
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs):
@@ -534,30 +605,33 @@ class FLAME(FLAMELayer):
if kwargs['pose2rot']:
for key in pose_keys:
if key in kwargs:
- kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([batch_size, -1, 3, 3])
+ kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
+ [batch_size, -1, 3, 3]
+ )
flame_output = super().forward(*args, **kwargs)
- output = ModelOutput(flame_vertices=flame_output.vertices,
- face_joints=flame_output.joints[:, 5:],
- )
+ output = ModelOutput(
+ flame_vertices=flame_output.vertices,
+ face_joints=flame_output.joints[:, 5:],
+ )
return output
+
class SMPL_Family():
def __init__(self, model_type='smpl', *args, **kwargs):
if model_type == 'smpl':
- self.model = SMPL(
- model_path=SMPL_MODEL_DIR,
- *args, **kwargs
- )
+ self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs)
elif model_type == 'smplx':
self.model = SMPLX_ALL(*args, **kwargs)
elif model_type == 'mano':
- self.model = MANO(model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs)
+ self.model = MANO(
+ model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs
+ )
elif model_type == 'flame':
self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs)
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
-
+
def get_tpose(self, *args, **kwargs):
return self.model.get_tpose(*args, **kwargs)
@@ -570,14 +644,17 @@ class SMPL_Family():
# else:
# self.model.cuda(device)
+
def get_smpl_faces():
smpl = SMPL(model_path=SMPL_MODEL_DIR, batch_size=1)
return smpl.faces
+
def get_smplx_faces():
smplx = SMPLX(SMPL_MODEL_DIR, batch_size=1)
return smplx.faces
+
def get_mano_faces(hand_type='right'):
assert hand_type in ['right', 'left']
is_rhand = True if hand_type == 'right' else False
@@ -585,11 +662,13 @@ def get_mano_faces(hand_type='right'):
return mano.faces
+
def get_flame_faces():
flame = FLAME(SMPL_MODEL_DIR, batch_size=1)
return flame.faces
+
def get_model_faces(type='smpl'):
if type == 'smpl':
return get_smpl_faces()
@@ -600,6 +679,7 @@ def get_model_faces(type='smpl'):
elif type == 'flame':
return get_flame_faces()
+
def get_model_tpose(type='smpl'):
if type == 'smpl':
return get_smpl_tpose()
@@ -610,43 +690,64 @@ def get_model_tpose(type='smpl'):
elif type == 'flame':
return get_flame_tpose()
+
def get_smpl_tpose():
- smpl = SMPL(create_betas=True, create_global_orient=True, create_body_pose=True, model_path=SMPL_MODEL_DIR, batch_size=1)
+ smpl = SMPL(
+ create_betas=True,
+ create_global_orient=True,
+ create_body_pose=True,
+ model_path=SMPL_MODEL_DIR,
+ batch_size=1
+ )
vertices = smpl().vertices[0]
return vertices.detach()
+
def get_smpl_tpose_joint():
- smpl = SMPL(create_betas=True, create_global_orient=True, create_body_pose=True, model_path=SMPL_MODEL_DIR, batch_size=1)
+ smpl = SMPL(
+ create_betas=True,
+ create_global_orient=True,
+ create_body_pose=True,
+ model_path=SMPL_MODEL_DIR,
+ batch_size=1
+ )
tpose_joint = smpl().smpl_joints[0]
return tpose_joint.detach()
+
def get_smplx_tpose():
smplx = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
vertices = smplx().vertices[0]
return vertices
+
def get_smplx_tpose_joint():
smplx = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
tpose_joint = smplx().joints[0]
return tpose_joint
+
def get_mano_tpose():
mano = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True)
- vertices = mano(global_orient=torch.zeros(1, 3),
- right_hand_pose=torch.zeros(1, 15*3)).rhand_vertices[0]
+ vertices = mano(global_orient=torch.zeros(1, 3),
+ right_hand_pose=torch.zeros(1, 15 * 3)).rhand_vertices[0]
return vertices
+
def get_flame_tpose():
flame = FLAME(SMPL_MODEL_DIR, batch_size=1)
vertices = flame(global_orient=torch.zeros(1, 3)).flame_vertices[0]
return vertices
+
def get_part_joints(smpl_joints):
batch_size = smpl_joints.shape[0]
# part_joints = torch.zeros().to(smpl_joints.device)
- one_seg_pairs = [(0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17)]
+ one_seg_pairs = [
+ (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17)
+ ]
two_seg_pairs = [(1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21)]
one_seg_pairs.extend(two_seg_pairs)
@@ -660,12 +761,13 @@ def get_part_joints(smpl_joints):
part_joints.append(new_joint)
for j_p in single_joints:
- part_joints.append(smpl_joints[:, j_p:j_p+1])
+ part_joints.append(smpl_joints[:, j_p:j_p + 1])
part_joints = torch.cat(part_joints, dim=1)
return part_joints
+
def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
body_model_faces = get_model_faces(body_model)
@@ -680,9 +782,13 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']}
else:
if part in ['lhand', 'rhand']:
- with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb') as json_file:
+ with open(
+ os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb'
+ ) as json_file:
smplx_mano_id = pickle.load(json_file)
- with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') as json_file:
+ with open(
+ os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb'
+ ) as json_file:
smplx_smpl_id = pickle.load(json_file)
smplx_tpose = get_smplx_tpose()
@@ -701,13 +807,17 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
smpl2mano_id.append(int(v_closest))
smpl2mano_vids = np.array(smpl2mano_id).astype(np.long)
- mano_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left').astype(np.long)
+                hand_type = 'right' if part == 'rhand' else 'left'
+                mano_faces = get_mano_faces(hand_type=hand_type).astype(np.long)
np.savez(part_vid_fname, vids=smpl2mano_vids, faces=mano_faces)
part_vert_faces[part] = {'vids': smpl2mano_vids, 'faces': mano_faces}
elif part in ['face', 'arm', 'forearm', 'larm', 'rarm']:
- with open(os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), 'rb') as json_file:
+ with open(
+ os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)),
+ 'rb'
+ ) as json_file:
smplx_part_id = json.load(json_file)
# main_body_part = list(smplx_part_id.keys())
@@ -716,12 +826,30 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
if part == 'face':
selected_body_part = ['head']
elif part == 'arm':
- selected_body_part = ['rightHand', 'leftArm', 'leftShoulder', 'rightShoulder', 'rightArm', 'leftHandIndex1', 'rightHandIndex1', 'leftForeArm', 'rightForeArm', 'leftHand',]
+ selected_body_part = [
+ 'rightHand',
+ 'leftArm',
+ 'leftShoulder',
+ 'rightShoulder',
+ 'rightArm',
+ 'leftHandIndex1',
+ 'rightHandIndex1',
+ 'leftForeArm',
+ 'rightForeArm',
+ 'leftHand',
+ ]
# selected_body_part = ['rightHand', 'leftArm', 'rightArm', 'leftHandIndex1', 'rightHandIndex1', 'leftForeArm', 'rightForeArm', 'leftHand',]
elif part == 'forearm':
- selected_body_part = ['rightHand', 'leftHandIndex1', 'rightHandIndex1', 'leftForeArm', 'rightForeArm', 'leftHand',]
+ selected_body_part = [
+ 'rightHand',
+ 'leftHandIndex1',
+ 'rightHandIndex1',
+ 'leftForeArm',
+ 'rightForeArm',
+ 'leftHand',
+ ]
elif part == 'arm_eval':
- selected_body_part = ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm']
+ selected_body_part = ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm']
elif part == 'larm':
# selected_body_part = ['leftArm', 'leftForeArm']
selected_body_part = ['leftForeArm']
@@ -749,7 +877,7 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
np.savez(part_vid_fname, vids=smpl2head_vids, faces=head_faces)
part_vert_faces[part] = {'vids': smpl2head_vids, 'faces': head_faces}
-
+
elif part in ['lwrist', 'rwrist']:
if body_model == 'smplx':
@@ -765,11 +893,11 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
wrist_vids = []
for vid, vt in enumerate(body_model_verts):
- v_j_dist = torch.sum((vt - wrist_joint) ** 2)
+ v_j_dist = torch.sum((vt - wrist_joint)**2)
if v_j_dist < dist:
wrist_vids.append(vid)
-
+
wrist_vids = np.array(wrist_vids)
part_body_fid = []
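
Both get_global_rotation reformattings above (in SMPL and SMPLX) follow the same forward-kinematics pattern: build one 4x4 rigid transform per joint from its rotation matrix and its rest-pose offset relative to its parent, then compose the transforms along the kinematic tree given by self.parents. Below is a minimal standalone sketch of that pattern, assuming rot_mats of shape [B, J, 3, 3], rest-pose joints [B, J, 3] and a parents list with parents[0] == -1; transform_mat here is a local stand-in for the helper imported in smpl.py, not the library function itself.

import torch

def transform_mat(R, t):
    # Local stand-in: [N, 3, 3] rotations and [N, 3, 1] translations -> [N, 4, 4] rigid transforms.
    top = torch.cat([R, t], dim=2)              # [N, 3, 4]
    bottom = torch.zeros_like(top[:, :1, :])    # [N, 1, 4], last row becomes (0, 0, 0, 1)
    bottom[..., 3] = 1.0
    return torch.cat([top, bottom], dim=1)

def global_joint_transforms(rot_mats, joints, parents):
    # rot_mats: [B, J, 3, 3], joints: rest-pose joints [B, J, 3], parents: list of length J.
    B, J = rot_mats.shape[:2]
    rel_joints = joints.clone()
    rel_joints[:, 1:] -= joints[:, parents[1:]]   # offsets relative to each joint's parent
    T = transform_mat(rot_mats.reshape(-1, 3, 3),
                      rel_joints.reshape(-1, 3, 1)).reshape(B, J, 4, 4)
    chain = [T[:, 0]]                             # root transform
    for i in range(1, J):
        # Parents precede children in the SMPL ordering, so the parent transform is ready.
        chain.append(torch.matmul(chain[parents[i]], T[:, i]))
    return torch.stack(chain, dim=1)              # [B, J, 4, 4] world-space joint transforms

# Toy check on a 3-joint chain 0 -> 1 -> 2 with identity rotations:
R = torch.eye(3).expand(1, 3, 3, 3).contiguous()
J = torch.tensor([[[0., 0., 0.], [0., 1., 0.], [0., 2., 0.]]])
print(global_joint_transforms(R, J, [-1, 0, 1])[:, :, :3, 3])  # recovers the rest-pose joints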
diff --git a/lib/pymafx/models/transformers/bert/__init__.py b/lib/pymafx/models/transformers/bert/__init__.py
index 4e18c1b84248ce3d4425cad34e5bb0f08fe32fec..0432a1e92856c438e5fd2f550dc5029a78fa354c 100644
--- a/lib/pymafx/models/transformers/bert/__init__.py
+++ b/lib/pymafx/models/transformers/bert/__init__.py
@@ -1,8 +1,9 @@
__version__ = "1.0.0"
-from .modeling_bert import (BertConfig, BertModel,
- load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
- BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
+from .modeling_bert import (
+ BertConfig, BertModel, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
+)
from .modeling_graphormer import Graphormer
@@ -10,7 +11,9 @@ from .modeling_graphormer import Graphormer
# from .e2e_hand_network import Graphormer_Hand_Network
-from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
- PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
+from .modeling_utils import (
+ WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_layer,
+ Conv1D
+)
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
diff --git a/lib/pymafx/models/transformers/bert/e2e_body_network.py b/lib/pymafx/models/transformers/bert/e2e_body_network.py
index 22faf2f7c3a3a58047d5553c179d332827f6daa6..9d1c75e276aa18fa1e8f2d865cbef7a275f71b8c 100644
--- a/lib/pymafx/models/transformers/bert/e2e_body_network.py
+++ b/lib/pymafx/models/transformers/bert/e2e_body_network.py
@@ -7,6 +7,7 @@ Licensed under the MIT license.
import torch
import src.modeling.data.config as cfg
+
class Graphormer_Body_Network(torch.nn.Module):
'''
End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
@@ -24,25 +25,27 @@ class Graphormer_Body_Network(torch.nn.Module):
self.cam_param_fc3 = torch.nn.Linear(250, 3)
self.grid_feat_dim = torch.nn.Linear(1024, 2051)
-
def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False):
batch_size = images.size(0)
# Generate T-pose template mesh
- template_pose = torch.zeros((1,72))
- template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
+ template_pose = torch.zeros((1, 72))
+ template_pose[:, 0] = 3.1416 # Rectify "upside down" reference mesh in global coord
template_pose = template_pose.cuda(self.config.device)
- template_betas = torch.zeros((1,10)).cuda(self.config.device)
+ template_betas = torch.zeros((1, 10)).cuda(self.config.device)
template_vertices = smpl(template_pose, template_betas)
# template mesh simplification
template_vertices_sub = mesh_sampler.downsample(template_vertices)
template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2)
- print('template_vertices', template_vertices.shape, template_vertices_sub.shape, template_vertices_sub2.shape)
+ print(
+ 'template_vertices', template_vertices.shape, template_vertices_sub.shape,
+ template_vertices_sub2.shape
+ )
- # template mesh-to-joint regression
+ # template mesh-to-joint regression
template_3d_joints = smpl.get_h36m_joints(template_vertices)
- template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
- template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:]
+ template_pelvis = template_3d_joints[:, cfg.H36M_J17_NAME.index('Pelvis'), :]
+ template_3d_joints = template_3d_joints[:, cfg.H36M_J17_TO_J14, :]
num_joints = template_3d_joints.shape[1]
# normalize
@@ -50,7 +53,7 @@ class Graphormer_Body_Network(torch.nn.Module):
template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :]
# concatinate template joints and template vertices, and then duplicate to batch size
- ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1)
+ ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2], dim=1)
ref_vertices = ref_vertices.expand(batch_size, -1, -1)
print('ref_vertices', ref_vertices.shape)
@@ -62,7 +65,7 @@ class Graphormer_Body_Network(torch.nn.Module):
print('image_feat', image_feat.shape)
# process grid features
grid_feat = torch.flatten(grid_feat, start_dim=2)
- grid_feat = grid_feat.transpose(1,2)
+ grid_feat = grid_feat.transpose(1, 2)
print('grid_feat bf', grid_feat.shape)
grid_feat = self.grid_feat_dim(grid_feat)
print('grid_feat', grid_feat.shape)
@@ -70,42 +73,43 @@ class Graphormer_Body_Network(torch.nn.Module):
features = torch.cat([ref_vertices, image_feat], dim=2)
print('features', features.shape, ref_vertices.shape, image_feat.shape)
# prepare input tokens including joint/vertex queries and grid features
- features = torch.cat([features, grid_feat],dim=1)
+ features = torch.cat([features, grid_feat], dim=1)
print('features', features.shape)
- if is_train==True:
+ if is_train == True:
# apply mask vertex/joint modeling
# meta_masks is a tensor of all the masks, randomly generated in dataloader
# we pre-define a [MASK] token, which is a floating-value vector with 0.01s
- special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01
+ special_token = torch.ones_like(features[:, :-49, :]).cuda() * 0.01
print('special_token', special_token.shape, meta_masks.shape)
print('meta_masks', torch.unique(meta_masks))
- features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
+            features[:, :-49, :] = (
+                features[:, :-49, :] * meta_masks + special_token * (1 - meta_masks)
+            )
# forward pass
- if self.config.output_attentions==True:
+ if self.config.output_attentions == True:
features, hidden_states, att = self.trans_encoder(features)
else:
features = self.trans_encoder(features)
- pred_3d_joints = features[:,:num_joints,:]
- pred_vertices_sub2 = features[:,num_joints:-49,:]
+ pred_3d_joints = features[:, :num_joints, :]
+ pred_vertices_sub2 = features[:, num_joints:-49, :]
# learn camera parameters
x = self.cam_param_fc(pred_vertices_sub2)
- x = x.transpose(1,2)
+ x = x.transpose(1, 2)
x = self.cam_param_fc2(x)
x = self.cam_param_fc3(x)
- cam_param = x.transpose(1,2)
+ cam_param = x.transpose(1, 2)
cam_param = cam_param.squeeze()
- temp_transpose = pred_vertices_sub2.transpose(1,2)
+ temp_transpose = pred_vertices_sub2.transpose(1, 2)
pred_vertices_sub = self.upsampling(temp_transpose)
pred_vertices_full = self.upsampling2(pred_vertices_sub)
- pred_vertices_sub = pred_vertices_sub.transpose(1,2)
- pred_vertices_full = pred_vertices_full.transpose(1,2)
+ pred_vertices_sub = pred_vertices_sub.transpose(1, 2)
+ pred_vertices_full = pred_vertices_full.transpose(1, 2)
- if self.config.output_attentions==True:
+ if self.config.output_attentions == True:
return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att
else:
- return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full
\ No newline at end of file
+ return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full
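
The masked vertex/joint modeling reformatted above blends the joint/vertex query tokens (everything except the trailing 49 grid-feature tokens) with a constant 0.01 "[MASK]" vector wherever meta_masks is 0, so the transformer has to reconstruct the masked queries during training. Below is a small self-contained sketch of that blend; only the 49 grid tokens, the 0.01 mask value and the 2051 feature dimension come from the code above, while the token counts and mask shape in the example are illustrative assumptions.

import torch

def apply_masked_vertex_joint_modeling(features, meta_masks, num_grid_tokens=49, mask_value=0.01):
    # features: [B, N, C] query tokens followed by num_grid_tokens grid-feature tokens.
    # meta_masks: broadcastable to [B, N - num_grid_tokens, C]; 1 keeps a token, 0 masks it.
    queries = features[:, :-num_grid_tokens, :]
    mask_token = torch.full_like(queries, mask_value)   # the constant "[MASK]" vector
    features = features.clone()                         # sketch avoids mutating the caller's tensor
    features[:, :-num_grid_tokens, :] = queries * meta_masks + mask_token * (1 - meta_masks)
    return features

# Example: 14 joint queries + 431 coarse-vertex queries + 49 grid tokens, feature dim 2051.
feats = torch.randn(2, 14 + 431 + 49, 2051)
keep = (torch.rand(2, 14 + 431, 1) > 0.3).float()       # randomly mask ~30% of the queries
masked = apply_masked_vertex_joint_modeling(feats, keep)
print(masked.shape)  # torch.Size([2, 494, 2051])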
diff --git a/lib/pymafx/models/transformers/bert/e2e_hand_network.py b/lib/pymafx/models/transformers/bert/e2e_hand_network.py
index 7030d0f6f1ec7e6741d4c5f67b2e78eaad708a5f..410968c4abc63e1ae8281b2e0297c8eef4e7bbcf 100644
--- a/lib/pymafx/models/transformers/bert/e2e_hand_network.py
+++ b/lib/pymafx/models/transformers/bert/e2e_hand_network.py
@@ -7,6 +7,7 @@ Licensed under the MIT license.
import torch
import src.modeling.data.config as cfg
+
class Graphormer_Hand_Network(torch.nn.Module):
'''
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
@@ -18,31 +19,31 @@ class Graphormer_Hand_Network(torch.nn.Module):
self.trans_encoder = trans_encoder
self.upsampling = torch.nn.Linear(195, 778)
self.cam_param_fc = torch.nn.Linear(3, 1)
- self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
+ self.cam_param_fc2 = torch.nn.Linear(195 + 21, 150)
self.cam_param_fc3 = torch.nn.Linear(150, 3)
self.grid_feat_dim = torch.nn.Linear(1024, 2051)
def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False):
batch_size = images.size(0)
# Generate T-pose template mesh
- template_pose = torch.zeros((1,48))
+ template_pose = torch.zeros((1, 48))
template_pose = template_pose.cuda()
- template_betas = torch.zeros((1,10)).cuda()
+ template_betas = torch.zeros((1, 10)).cuda()
template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
- template_vertices = template_vertices/1000.0
- template_3d_joints = template_3d_joints/1000.0
+ template_vertices = template_vertices / 1000.0
+ template_3d_joints = template_3d_joints / 1000.0
template_vertices_sub = mesh_sampler.downsample(template_vertices)
# normalize
- template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
+ template_root = template_3d_joints[:, cfg.J_NAME.index('Wrist'), :]
template_3d_joints = template_3d_joints - template_root[:, None, :]
template_vertices = template_vertices - template_root[:, None, :]
template_vertices_sub = template_vertices_sub - template_root[:, None, :]
num_joints = template_3d_joints.shape[1]
# concatinate template joints and template vertices, and then duplicate to batch size
- ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1)
+ ref_vertices = torch.cat([template_3d_joints, template_vertices_sub], dim=1)
ref_vertices = ref_vertices.expand(batch_size, -1, -1)
# extract grid features and global image features using a CNN backbone
@@ -51,42 +52,43 @@ class Graphormer_Hand_Network(torch.nn.Module):
image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
# process grid features
grid_feat = torch.flatten(grid_feat, start_dim=2)
- grid_feat = grid_feat.transpose(1,2)
+ grid_feat = grid_feat.transpose(1, 2)
grid_feat = self.grid_feat_dim(grid_feat)
# concatinate image feat and template mesh to form the joint/vertex queries
features = torch.cat([ref_vertices, image_feat], dim=2)
# prepare input tokens including joint/vertex queries and grid features
- features = torch.cat([features, grid_feat],dim=1)
+ features = torch.cat([features, grid_feat], dim=1)
- if is_train==True:
+ if is_train == True:
# apply mask vertex/joint modeling
# meta_masks is a tensor of all the masks, randomly generated in dataloader
- # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
- special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01
- features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
+ # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
+ special_token = torch.ones_like(features[:, :-49, :]).cuda() * 0.01
+            features[:, :-49, :] = (
+                features[:, :-49, :] * meta_masks + special_token * (1 - meta_masks)
+            )
# forward pass
- if self.config.output_attentions==True:
+ if self.config.output_attentions == True:
features, hidden_states, att = self.trans_encoder(features)
else:
features = self.trans_encoder(features)
- pred_3d_joints = features[:,:num_joints,:]
- pred_vertices_sub = features[:,num_joints:-49,:]
+ pred_3d_joints = features[:, :num_joints, :]
+ pred_vertices_sub = features[:, num_joints:-49, :]
# learn camera parameters
- x = self.cam_param_fc(features[:,:-49,:])
- x = x.transpose(1,2)
+ x = self.cam_param_fc(features[:, :-49, :])
+ x = x.transpose(1, 2)
x = self.cam_param_fc2(x)
x = self.cam_param_fc3(x)
- cam_param = x.transpose(1,2)
+ cam_param = x.transpose(1, 2)
cam_param = cam_param.squeeze()
- temp_transpose = pred_vertices_sub.transpose(1,2)
+ temp_transpose = pred_vertices_sub.transpose(1, 2)
pred_vertices = self.upsampling(temp_transpose)
- pred_vertices = pred_vertices.transpose(1,2)
+ pred_vertices = pred_vertices.transpose(1, 2)
- if self.config.output_attentions==True:
+ if self.config.output_attentions == True:
return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att
else:
- return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices
\ No newline at end of file
+ return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices
diff --git a/lib/pymafx/models/transformers/bert/file_utils.py b/lib/pymafx/models/transformers/bert/file_utils.py
index fd655cec0ed897d5abaea8289a6395aaa672d767..ee58bed427f90be254caee9a0733d81ae92c8711 100644
--- a/lib/pymafx/models/transformers/bert/file_utils.py
+++ b/lib/pymafx/models/transformers/bert/file_utils.py
@@ -26,8 +26,8 @@ try:
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
- os.getenv('TORCH_HOME', os.path.join(
- os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
+ os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))
+ )
default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
try:
@@ -38,12 +38,12 @@ except ImportError:
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(
- os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
+ os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)
+ )
except (AttributeError, ImportError):
- PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
- default_cache_path)
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)
-logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url, etag=None):
@@ -138,7 +138,6 @@ def s3_request(func):
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
-
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
@@ -175,7 +174,7 @@ def http_get(url, temp_file):
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
- if chunk: # filter out keep-alive new chunks
+ if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
@@ -251,7 +250,7 @@ def get_from_cache(url, cache_dir=None):
with open(meta_path, 'w') as meta_file:
output_string = json.dumps(meta)
if sys.version_info[0] == 2 and isinstance(output_string, str):
- output_string = unicode(output_string, 'utf-8') # The beauty of python 2
+ output_string = unicode(output_string, 'utf-8') # The beauty of python 2
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
diff --git a/lib/pymafx/models/transformers/bert/modeling_bert.py b/lib/pymafx/models/transformers/bert/modeling_bert.py
index 738ebe6ec5f0697d0ec526fab4973489d01afd8e..c4a7f27f1bc0e69d87ac3747b8d8acfafb03b4b8 100644
--- a/lib/pymafx/models/transformers/bert/modeling_bert.py
+++ b/lib/pymafx/models/transformers/bert/modeling_bert.py
@@ -28,41 +28,69 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
-from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel,
- prune_linear_layer, add_start_docstrings)
+from .modeling_utils import (
+ WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer,
+ add_start_docstrings
+)
logger = logging.getLogger(__name__)
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
- 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
- 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
- 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
- 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
- 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
- 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
- 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
- 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
- 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
- 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
- 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
- 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
- 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
+ 'bert-base-uncased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+ 'bert-large-uncased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+ 'bert-base-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
+ 'bert-large-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
+ 'bert-base-multilingual-uncased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
+ 'bert-base-multilingual-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
+ 'bert-base-chinese':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
+ 'bert-base-german-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking-finetuned-squad':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking-finetuned-squad':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-base-cased-finetuned-mrpc':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
- 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
- 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
- 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
- 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
- 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
- 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
- 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
- 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
- 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
- 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
- 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
- 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
+ 'bert-base-uncased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
+ 'bert-large-uncased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
+ 'bert-base-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
+ 'bert-large-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
+ 'bert-base-multilingual-uncased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
+ 'bert-base-multilingual-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
+ 'bert-base-chinese':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
+ 'bert-base-german-cased':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
+ 'bert-large-uncased-whole-word-masking':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
+ 'bert-large-cased-whole-word-masking':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
+ 'bert-large-uncased-whole-word-masking-finetuned-squad':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
+ 'bert-large-cased-whole-word-masking-finetuned-squad':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
+ 'bert-base-cased-finetuned-mrpc':
+ "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
}
@@ -74,8 +102,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
- logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
- "https://www.tensorflow.org/install/ for installation instructions.")
+ logger.error(
+ "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
@@ -180,23 +210,26 @@ class BertConfig(PretrainedConfig):
"""
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
- def __init__(self,
- vocab_size_or_config_json_file=30522,
- hidden_size=768,
- num_hidden_layers=12,
- num_attention_heads=12,
- intermediate_size=3072,
- hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=512,
- type_vocab_size=2,
- initializer_range=0.02,
- layer_norm_eps=1e-12,
- **kwargs):
+ def __init__(
+ self,
+ vocab_size_or_config_json_file=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ **kwargs
+ ):
super(BertConfig, self).__init__(**kwargs)
- if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
- and isinstance(vocab_size_or_config_json_file, unicode)):
+ if isinstance(
+ vocab_size_or_config_json_file, str
+ ) or (sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
@@ -215,9 +248,10 @@ class BertConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
else:
- raise ValueError("First argument must be either a vocabulary size (int)"
- "or the path to a pretrained model config file (str)")
-
+ raise ValueError(
+ "First argument must be either a vocabulary size (int)"
+ "or the path to a pretrained model config file (str)"
+ )
# try:
@@ -240,6 +274,7 @@ class BertLayerNorm(nn.Module):
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
+
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
@@ -278,7 +313,8 @@ class BertSelfAttention(nn.Module):
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
@@ -325,10 +361,10 @@ class BertSelfAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
context_layer = context_layer.view(*new_context_layer_shape)
- outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, )
return outputs
@@ -372,7 +408,7 @@ class BertAttention(nn.Module):
def forward(self, input_tensor, attention_mask, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask)
attention_output = self.output(self_outputs[0], input_tensor)
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them
return outputs
@@ -380,7 +416,8 @@ class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
- if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
+        if isinstance(config.hidden_act, str) or (
+            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)
+        ):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
@@ -417,7 +454,7 @@ class BertLayer(nn.Module):
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
- outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+ outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them
return outputs
@@ -433,24 +470,24 @@ class BertEncoder(nn.Module):
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
+ all_hidden_states = all_hidden_states + (hidden_states, )
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
hidden_states = layer_outputs[0]
if self.output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
+ all_attentions = all_attentions + (layer_outputs[1], )
# Add last layer
if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
+ all_hidden_states = all_hidden_states + (hidden_states, )
- outputs = (hidden_states,)
+ outputs = (hidden_states, )
if self.output_hidden_states:
- outputs = outputs + (all_hidden_states,)
+ outputs = outputs + (all_hidden_states, )
if self.output_attentions:
- outputs = outputs + (all_attentions,)
- return outputs # outputs, (hidden states), (attentions)
+ outputs = outputs + (all_attentions, )
+ return outputs # outputs, (hidden states), (attentions)
class BertPooler(nn.Module):
@@ -472,7 +509,8 @@ class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
+        if isinstance(config.hidden_act, str) or (
+            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)
+        ):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
@@ -492,9 +530,7 @@ class BertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
- self.decoder = nn.Linear(config.hidden_size,
- config.vocab_size,
- bias=False)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
@@ -620,8 +656,11 @@ BERT_INPUTS_DOCSTRING = r"""
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
-@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
- BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+
+@add_start_docstrings(
+ "The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
+)
class BertModel(BertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -675,7 +714,14 @@ class BertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None
+ ):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@@ -693,7 +739,9 @@ class BertModel(BertPreTrainedModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
- extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=next(self.parameters()).dtype
+ ) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
@@ -706,25 +754,36 @@ class BertModel(BertPreTrainedModel):
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
- head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
- head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+            # We can specify head_mask for each layer
+            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+            # switch to float if needed + fp16 compatibility
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
else:
head_mask = [None] * self.config.num_hidden_layers
- embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
- encoder_outputs = self.encoder(embedding_output,
- extended_attention_mask,
- head_mask=head_mask)
+ embedding_output = self.embeddings(
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
+ )
+ encoder_outputs = self.encoder(
+ embedding_output, extended_attention_mask, head_mask=head_mask
+ )
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
- outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
- return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
+ outputs = (
+ sequence_output,
+ pooled_output,
+ ) + encoder_outputs[1:] # add hidden_states and attentions if they are here
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
-@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
+@add_start_docstrings(
+ """Bert Model with two heads on top as done during the pre-training:
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
- BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
+)
class BertForPreTraining(BertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
@@ -777,31 +836,54 @@ class BertForPreTraining(BertPreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
- self._tie_or_clone_weights(self.cls.predictions.decoder,
- self.bert.embeddings.word_embeddings)
-
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
- next_sentence_label=None, position_ids=None, head_mask=None):
- outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ self._tie_or_clone_weights(
+ self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
+ )
+
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ masked_lm_labels=None,
+ next_sentence_label=None,
+ position_ids=None,
+ head_mask=None
+ ):
+ outputs = self.bert(
+ input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
- outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
+ outputs = (
+ prediction_scores,
+ seq_relationship_score,
+ ) + outputs[2:] # add hidden states and attention if they are here
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
- masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
- next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)
+ )
+ next_sentence_loss = loss_fct(
+ seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)
+ )
total_loss = masked_lm_loss + next_sentence_loss
- outputs = (total_loss,) + outputs
+ outputs = (total_loss, ) + outputs
- return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
-@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
- BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+@add_start_docstrings(
+ """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING,
+ BERT_INPUTS_DOCSTRING
+)
class BertForMaskedLM(BertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
@@ -847,28 +929,46 @@ class BertForMaskedLM(BertPreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
- self._tie_or_clone_weights(self.cls.predictions.decoder,
- self.bert.embeddings.word_embeddings)
-
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
- position_ids=None, head_mask=None):
- outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ self._tie_or_clone_weights(
+ self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
+ )
+
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ masked_lm_labels=None,
+ position_ids=None,
+ head_mask=None
+ ):
+ outputs = self.bert(
+ input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
- outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention is they are here
+ outputs = (prediction_scores,
+                  ) + outputs[2:]  # Add hidden states and attention if they are here
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
- masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
- outputs = (masked_lm_loss,) + outputs
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)
+ )
+ outputs = (masked_lm_loss, ) + outputs
- return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
-@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
- BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+@add_start_docstrings(
+ """Bert Model with a `next sentence prediction (classification)` head on top. """,
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
+)
class BertForNextSentencePrediction(BertPreTrainedModel):
r"""
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -909,26 +1009,42 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.apply(self.init_weights)
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
- position_ids=None, head_mask=None):
- outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ next_sentence_label=None,
+ position_ids=None,
+ head_mask=None
+ ):
+ outputs = self.bert(
+ input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
pooled_output = outputs[1]
seq_relationship_score = self.cls(pooled_output)
- outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
+ outputs = (seq_relationship_score,
+ ) + outputs[2:] # add hidden states and attention if they are here
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
- next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
- outputs = (next_sentence_loss,) + outputs
+ next_sentence_loss = loss_fct(
+ seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)
+ )
+ outputs = (next_sentence_loss, ) + outputs
- return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
-@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
- the pooled output) e.g. for GLUE tasks. """,
- BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+@add_start_docstrings(
+ """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
+ the pooled output) e.g. for GLUE tasks. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
+)
class BertForSequenceClassification(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -972,16 +1088,28 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.apply(self.init_weights)
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
- position_ids=None, head_mask=None):
- outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ labels=None,
+ position_ids=None,
+ head_mask=None
+ ):
+ outputs = self.bert(
+ input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
- outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
+ outputs = (logits, ) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
if self.num_labels == 1:
@@ -991,14 +1119,15 @@ class BertForSequenceClassification(BertPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- outputs = (loss,) + outputs
+ outputs = (loss, ) + outputs
- return outputs # (loss), logits, (hidden_states), (attentions)
+ return outputs # (loss), logits, (hidden_states), (attentions)
-@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
- the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
- BERT_START_DOCSTRING)
+@add_start_docstrings(
+ """Bert Model with a multiple choice classification head on top (a linear layer on top of
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, BERT_START_DOCSTRING
+)
class BertForMultipleChoice(BertPreTrainedModel):
r"""
Inputs:
@@ -1078,35 +1207,56 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.apply(self.init_weights)
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
- position_ids=None, head_mask=None):
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ labels=None,
+ position_ids=None,
+ head_mask=None
+ ):
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
- flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
- flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
- flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
- outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
- attention_mask=flat_attention_mask, head_mask=head_mask)
+ flat_position_ids = position_ids.view(
+ -1, position_ids.size(-1)
+ ) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(
+ -1, token_type_ids.size(-1)
+ ) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(
+ -1, attention_mask.size(-1)
+ ) if attention_mask is not None else None
+ outputs = self.bert(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ head_mask=head_mask
+ )
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
- outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
+ outputs = (reshaped_logits,
+ ) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
- outputs = (loss,) + outputs
+ outputs = (loss, ) + outputs
- return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
-@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
+@add_start_docstrings(
+ """Bert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
- BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
+)
class BertForTokenClassification(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
@@ -1148,16 +1298,28 @@ class BertForTokenClassification(BertPreTrainedModel):
self.apply(self.init_weights)
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
- position_ids=None, head_mask=None):
- outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ labels=None,
+ position_ids=None,
+ head_mask=None
+ ):
+ outputs = self.bert(
+ input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
- outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
+ outputs = (logits, ) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@@ -1168,14 +1330,16 @@ class BertForTokenClassification(BertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- outputs = (loss,) + outputs
+ outputs = (loss, ) + outputs
- return outputs # (loss), scores, (hidden_states), (attentions)
+ return outputs # (loss), scores, (hidden_states), (attentions)
-@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
+@add_start_docstrings(
+ """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
- BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
+)
class BertForQuestionAnswering(BertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -1224,10 +1388,23 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.apply(self.init_weights)
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
- end_positions=None, position_ids=None, head_mask=None):
- outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ def forward(
+ self,
+ input_ids,
+ token_type_ids=None,
+ attention_mask=None,
+ start_positions=None,
+ end_positions=None,
+ position_ids=None,
+ head_mask=None
+ ):
+ outputs = self.bert(
+ input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
@@ -1235,7 +1412,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
- outputs = (start_logits, end_logits,) + outputs[2:]
+ outputs = (
+ start_logits,
+ end_logits,
+ ) + outputs[2:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@@ -1251,6 +1431,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
- outputs = (total_loss,) + outputs
+ outputs = (total_loss, ) + outputs
- return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
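
For orientation, all of the BERT heads reformatted above share one output convention: an optional loss comes first, followed by the task scores, followed by whatever hidden states and attentions the encoder passed through. A minimal sketch of how a caller can unpack that tuple (illustrative only, not part of this patch; `unpack_head_outputs` is a hypothetical helper):

```python
import torch

def unpack_head_outputs(outputs, has_loss):
    """Split a BERT-head output tuple into (loss, scores, extras).

    Mirrors the '(loss), scores, (hidden_states), (attentions)' comments
    on the return statements above; the extras may be empty.
    """
    loss = outputs[0] if has_loss else None
    scores = outputs[1] if has_loss else outputs[0]
    extras = outputs[2:] if has_loss else outputs[1:]
    return loss, scores, extras

# e.g. a sequence-classification head called with labels returns (loss, logits, ...)
loss, logits, extras = unpack_head_outputs((torch.tensor(0.7), torch.randn(4, 2)), has_loss=True)
```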
diff --git a/lib/pymafx/models/transformers/bert/modeling_graphormer.py b/lib/pymafx/models/transformers/bert/modeling_graphormer.py
index 91f2b869511ea6228bef40bb0d30b45c3194ce95..e318af8a45d34148e0db68f42181f692afbf8754 100644
--- a/lib/pymafx/models/transformers/bert/modeling_graphormer.py
+++ b/lib/pymafx/models/transformers/bert/modeling_graphormer.py
@@ -16,6 +16,7 @@ from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, Bert
# import src.modeling.data.config as cfg
# from src.modeling._gcnn import GraphConvolution, GraphResBlock
from .modeling_utils import prune_linear_layer
+
LayerNormClass = torch.nn.LayerNorm
BertLayerNorm = torch.nn.LayerNorm
@@ -26,7 +27,8 @@ class BertSelfAttention(nn.Module):
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
@@ -44,8 +46,7 @@ class BertSelfAttention(nn.Module):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
- def forward(self, hidden_states, attention_mask, head_mask=None,
- history_state=None):
+ def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None):
if history_state is not None:
raise
x_states = torch.cat([history_state, hidden_states], dim=1)
@@ -57,7 +58,10 @@ class BertSelfAttention(nn.Module):
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
- print('mixed_query_layer', mixed_query_layer.shape, mixed_key_layer.shape, mixed_value_layer.shape)
+ print(
+ 'mixed_query_layer', mixed_query_layer.shape, mixed_key_layer.shape,
+ mixed_value_layer.shape
+ )
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
@@ -84,12 +88,13 @@ class BertSelfAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
context_layer = context_layer.view(*new_context_layer_shape)
- outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, )
return outputs
+
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
@@ -113,12 +118,10 @@ class BertAttention(nn.Module):
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
- def forward(self, input_tensor, attention_mask, head_mask=None,
- history_state=None):
- self_outputs = self.self(input_tensor, attention_mask, head_mask,
- history_state)
+ def forward(self, input_tensor, attention_mask, head_mask=None, history_state=None):
+ self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state)
attention_output = self.output(self_outputs[0], input_tensor)
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them
return outputs
@@ -130,45 +133,46 @@ class GraphormerLayer(nn.Module):
self.mesh_type = config.mesh_type
if self.has_graph_conv == True:
- if self.mesh_type=='hand':
- self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
- elif self.mesh_type=='body':
- self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
-
+ if self.mesh_type == 'hand':
+ self.graph_conv = GraphResBlock(
+ config.hidden_size, config.hidden_size, mesh_type=self.mesh_type
+ )
+ elif self.mesh_type == 'body':
+ self.graph_conv = GraphResBlock(
+ config.hidden_size, config.hidden_size, mesh_type=self.mesh_type
+ )
+
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
- def MHA_GCN(self, hidden_states, attention_mask, head_mask=None,
- history_state=None):
- attention_outputs = self.attention(hidden_states, attention_mask,
- head_mask, history_state)
+ def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, history_state=None):
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_state)
attention_output = attention_outputs[0]
- if self.has_graph_conv==True:
+ if self.has_graph_conv == True:
if self.mesh_type == 'body':
- joints = attention_output[:,0:14,:]
- vertices = attention_output[:,14:-49,:]
- img_tokens = attention_output[:,-49:,:]
+ joints = attention_output[:, 0:14, :]
+ vertices = attention_output[:, 14:-49, :]
+ img_tokens = attention_output[:, -49:, :]
elif self.mesh_type == 'hand':
- joints = attention_output[:,0:21,:]
- vertices = attention_output[:,21:-49,:]
- img_tokens = attention_output[:,-49:,:]
+ joints = attention_output[:, 0:21, :]
+ vertices = attention_output[:, 21:-49, :]
+ img_tokens = attention_output[:, -49:, :]
vertices = self.graph_conv(vertices)
- joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1)
+ joints_vertices = torch.cat([joints, vertices, img_tokens], dim=1)
else:
joints_vertices = attention_output
intermediate_output = self.intermediate(joints_vertices)
layer_output = self.output(intermediate_output, joints_vertices)
print('layer_output', layer_output.shape)
- outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+ outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them
return outputs
- def forward(self, hidden_states, attention_mask, head_mask=None,
- history_state=None):
- return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state)
+ def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None):
+ return self.MHA_GCN(hidden_states, attention_mask, head_mask, history_state)
class GraphormerEncoder(nn.Module):
@@ -176,36 +180,36 @@ class GraphormerEncoder(nn.Module):
super(GraphormerEncoder, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
- self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)])
+ self.layer = nn.ModuleList(
+ [GraphormerLayer(config) for _ in range(config.num_hidden_layers)]
+ )
- def forward(self, hidden_states, attention_mask, head_mask=None,
- encoder_history_states=None):
+ def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
+ all_hidden_states = all_hidden_states + (hidden_states, )
history_state = None if encoder_history_states is None else encoder_history_states[i]
- layer_outputs = layer_module(
- hidden_states, attention_mask, head_mask[i],
- history_state)
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], history_state)
hidden_states = layer_outputs[0]
if self.output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
+ all_attentions = all_attentions + (layer_outputs[1], )
# Add last layer
if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
+ all_hidden_states = all_hidden_states + (hidden_states, )
- outputs = (hidden_states,)
+ outputs = (hidden_states, )
if self.output_hidden_states:
- outputs = outputs + (all_hidden_states,)
+ outputs = outputs + (all_hidden_states, )
if self.output_attentions:
- outputs = outputs + (all_attentions,)
+ outputs = outputs + (all_attentions, )
+
+ return outputs # outputs, (hidden states), (attentions)
- return outputs # outputs, (hidden states), (attentions)
class EncoderBlock(BertPreTrainedModel):
def __init__(self, config):
@@ -215,7 +219,7 @@ class EncoderBlock(BertPreTrainedModel):
self.encoder = GraphormerEncoder(config)
# self.pooler = BertPooler(config)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
- self.img_dim = config.img_feature_dim
+ self.img_dim = config.img_feature_dim
try:
self.use_img_layernorm = config.use_img_layernorm
@@ -237,12 +241,19 @@ class EncoderBlock(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
- def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None,
- position_ids=None, head_mask=None):
+ def forward(
+ self,
+ img_feats,
+ input_ids=None,
+ token_type_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None
+ ):
batch_size = len(img_feats)
seq_length = len(img_feats[0])
- input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).cuda()
+ input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long).cuda()
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
@@ -251,7 +262,10 @@ class EncoderBlock(BertPreTrainedModel):
print('position_ids', seq_length, position_ids.shape)
position_embeddings = self.position_embeddings(position_ids)
- print('position_embeddings', position_embeddings.shape, self.config.max_position_embeddings, self.config.hidden_size)
+ print(
+ 'position_embeddings', position_embeddings.shape, self.config.max_position_embeddings,
+ self.config.hidden_size
+ )
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
@@ -270,7 +284,9 @@ class EncoderBlock(BertPreTrainedModel):
else:
raise NotImplementedError
- extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=next(self.parameters()).dtype
+ ) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if head_mask is not None:
@@ -279,8 +295,12 @@ class EncoderBlock(BertPreTrainedModel):
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
- head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
- head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
+ -1
+ ) # We can specify head_mask for each layer
+ head_mask = head_mask.to(
+ dtype=next(self.parameters()).dtype
+            )    # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
@@ -297,20 +317,20 @@ class EncoderBlock(BertPreTrainedModel):
embeddings = self.dropout(embeddings)
print('extended_attention_mask', extended_attention_mask.shape)
- encoder_outputs = self.encoder(embeddings,
- extended_attention_mask, head_mask=head_mask)
+ encoder_outputs = self.encoder(embeddings, extended_attention_mask, head_mask=head_mask)
sequence_output = encoder_outputs[0]
- outputs = (sequence_output,)
+ outputs = (sequence_output, )
if self.config.output_hidden_states:
all_hidden_states = encoder_outputs[1]
- outputs = outputs + (all_hidden_states,)
+ outputs = outputs + (all_hidden_states, )
if self.config.output_attentions:
all_attentions = encoder_outputs[-1]
- outputs = outputs + (all_attentions,)
+ outputs = outputs + (all_attentions, )
return outputs
+
class Graphormer(BertPreTrainedModel):
'''
The archtecture of a transformer encoder block we used in Graphormer
@@ -323,16 +343,31 @@ class Graphormer(BertPreTrainedModel):
self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim)
self.apply(self.init_weights)
- def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
- next_sentence_label=None, position_ids=None, head_mask=None):
+ def forward(
+ self,
+ img_feats,
+ input_ids=None,
+ token_type_ids=None,
+ attention_mask=None,
+ masked_lm_labels=None,
+ next_sentence_label=None,
+ position_ids=None,
+ head_mask=None
+ ):
'''
# self.bert has three outputs
# predictions[0]: output tokens
# predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states"
# predictions[2]: attentions, if enable "self.config.output_attentions"
'''
- predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
- attention_mask=attention_mask, head_mask=head_mask)
+ predictions = self.bert(
+ img_feats=img_feats,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask
+ )
# We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification.
pred_score = self.cls_head(predictions[0])
@@ -344,5 +379,3 @@ class Graphormer(BertPreTrainedModel):
return pred_score, predictions[1], predictions[-1]
else:
return pred_score
-
-
\ No newline at end of file
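
`GraphormerLayer.MHA_GCN` above slices its token sequence purely by position: the first tokens are joints (14 for `'body'`, 21 for `'hand'`), the last 49 are image-grid tokens, and everything in between is mesh vertices that are passed through the graph convolution. A self-contained sketch of that split (the vertex count used below is a placeholder, not the real SMPL/MANO count):

```python
import torch

def split_graphormer_tokens(attention_output, mesh_type):
    # token layout: [joints | mesh vertices | 49 image-grid tokens]
    n_joints = 14 if mesh_type == 'body' else 21
    joints = attention_output[:, :n_joints, :]
    vertices = attention_output[:, n_joints:-49, :]
    img_tokens = attention_output[:, -49:, :]
    return joints, vertices, img_tokens

tokens = torch.randn(2, 14 + 431 + 49, 64)    # 431 vertices is a made-up placeholder
joints, vertices, img_tokens = split_graphormer_tokens(tokens, 'body')
assert joints.shape[1] == 14 and vertices.shape[1] == 431 and img_tokens.shape[1] == 49
```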
diff --git a/lib/pymafx/models/transformers/bert/modeling_utils.py b/lib/pymafx/models/transformers/bert/modeling_utils.py
index 458852810218b7b85f25cb564da8c96886500b7c..40a0915822c8e736de8ac2466c075e6cc5ef7e83 100644
--- a/lib/pymafx/models/transformers/bert/modeling_utils.py
+++ b/lib/pymafx/models/transformers/bert/modeling_utils.py
@@ -15,8 +15,7 @@
# limitations under the License.
"""PyTorch BERT model."""
-from __future__ import (absolute_import, division, print_function,
- unicode_literals)
+from __future__ import (absolute_import, division, print_function, unicode_literals)
import copy
import json
@@ -38,7 +37,6 @@ CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
TF_WEIGHTS_NAME = 'model.ckpt'
-
try:
from torch.nn import Identity
except ImportError:
@@ -54,16 +52,19 @@ except ImportError:
if not six.PY2:
+
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = ''.join(docstr) + fn.__doc__
return fn
+
return docstring_decorator
else:
# Not possible to update class docstrings on python2
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
return fn
+
return docstring_decorator
@@ -84,7 +85,9 @@ class PretrainedConfig(object):
""" Save a configuration object to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
- assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
+ assert os.path.isdir(
+ save_directory
+ ), "Saving path should be a directory where the model and configuration can be saved"
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
@@ -145,23 +148,28 @@ class PretrainedConfig(object):
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
- "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
- config_file))
+ "Couldn't reach server at '{}' to download pretrained model configuration file."
+ .format(config_file)
+ )
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
- ', '.join(cls.pretrained_config_archive_map.keys()),
- config_file))
+ ', '.join(cls.pretrained_config_archive_map.keys()), config_file
+ )
+ )
return None
if resolved_config_file == config_file:
pass
# logger.info("loading configuration file {}".format(config_file))
else:
- logger.info("loading configuration file {} from cache at {}".format(
- config_file, resolved_config_file))
+ logger.info(
+ "loading configuration file {} from cache at {}".format(
+ config_file, resolved_config_file
+ )
+ )
# Load config
config = cls.from_json_file(resolved_config_file)
@@ -235,7 +243,8 @@ class PreTrainedModel(nn.Module):
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
- ))
+ )
+ )
# Save config in model
self.config = config
@@ -269,7 +278,8 @@ class PreTrainedModel(nn.Module):
# Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
- new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
+ new_embeddings.weight.data[:num_tokens_to_copy, :
+ ] = old_embeddings.weight.data[:num_tokens_to_copy, :]
return new_embeddings
@@ -295,7 +305,7 @@ class PreTrainedModel(nn.Module):
Return: ``torch.nn.Embeddings``
Pointer to the input tokens Embedding Module of the model
"""
- base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
@@ -315,14 +325,16 @@ class PreTrainedModel(nn.Module):
Args:
heads_to_prune: dict of {layer_num (int): list of heads to prune in this layer (list of int)}
"""
- base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
""" Save a model with its configuration file to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
- assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
+ assert os.path.isdir(
+ save_directory
+ ), "Saving path should be a directory where the model and configuration can be saved"
# Only save the model it-self if we are using distributed training
model_to_save = self.module if hasattr(self, 'module') else self
@@ -402,8 +414,10 @@ class PreTrainedModel(nn.Module):
# Load config
if config is None:
config, model_kwargs = cls.config_class.from_pretrained(
- pretrained_model_name_or_path, *model_args,
- cache_dir=cache_dir, return_unused_kwargs=True,
+ pretrained_model_name_or_path,
+ *model_args,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
**kwargs
)
else:
@@ -415,7 +429,9 @@ class PreTrainedModel(nn.Module):
elif os.path.isdir(pretrained_model_name_or_path):
if from_tf:
# Directly load from a TensorFlow checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
+ archive_file = os.path.join(
+ pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
+ )
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
@@ -430,22 +446,27 @@ class PreTrainedModel(nn.Module):
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
- "Couldn't reach server at '{}' to download pretrained weights.".format(
- archive_file))
+ "Couldn't reach server at '{}' to download pretrained weights.".
+ format(archive_file)
+ )
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
- ', '.join(cls.pretrained_model_archive_map.keys()),
- archive_file))
+ ', '.join(cls.pretrained_model_archive_map.keys()), archive_file
+ )
+ )
return None
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
- logger.info("loading weights file {} from cache at {}".format(
- archive_file, resolved_archive_file))
+ logger.info(
+ "loading weights file {} from cache at {}".format(
+ archive_file, resolved_archive_file
+ )
+ )
# Instantiate model.
model = cls(config, *model_args, **model_kwargs)
@@ -454,7 +475,9 @@ class PreTrainedModel(nn.Module):
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
- return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
+ return cls.load_tf_weights(
+ model, config, resolved_archive_file[:-6]
+ ) # Remove the '.index'
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
@@ -484,7 +507,8 @@ class PreTrainedModel(nn.Module):
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
- state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
+ )
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
@@ -492,30 +516,46 @@ class PreTrainedModel(nn.Module):
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ''
model_to_load = model
- if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
+ if not hasattr(model, cls.base_model_prefix) and any(
+ s.startswith(cls.base_model_prefix) for s in state_dict.keys()
+ ):
start_prefix = cls.base_model_prefix + '.'
- if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
+ if hasattr(model, cls.base_model_prefix
+ ) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
model_to_load = getattr(model, cls.base_model_prefix)
load(model_to_load, prefix=start_prefix)
if len(missing_keys) > 0:
- logger.info("Weights of {} not initialized from pretrained model: {}".format(
- model.__class__.__name__, missing_keys))
+ logger.info(
+ "Weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, missing_keys
+ )
+ )
if len(unexpected_keys) > 0:
- logger.info("Weights from pretrained model not used in {}: {}".format(
- model.__class__.__name__, unexpected_keys))
+ logger.info(
+ "Weights from pretrained model not used in {}: {}".format(
+ model.__class__.__name__, unexpected_keys
+ )
+ )
if len(error_msgs) > 0:
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
- model.__class__.__name__, "\n\t".join(error_msgs)))
+ raise RuntimeError(
+ 'Error(s) in loading state_dict for {}:\n\t{}'.format(
+ model.__class__.__name__, "\n\t".join(error_msgs)
+ )
+ )
if hasattr(model, 'tie_weights'):
- model.tie_weights() # make sure word embedding weights are still tied
+ model.tie_weights() # make sure word embedding weights are still tied
# Set model in evaluation mode to desactivate DropOut modules by default
model.eval()
if output_loading_info:
- loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "error_msgs": error_msgs
+ }
return model, loading_info
return model
@@ -534,7 +574,7 @@ class Conv1D(nn.Module):
self.bias = nn.Parameter(torch.zeros(nf))
def forward(self, x):
- size_out = x.size()[:-1] + (self.nf,)
+ size_out = x.size()[:-1] + (self.nf, )
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
return x
@@ -586,9 +626,10 @@ class PoolerEndLogits(nn.Module):
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
if start_positions is not None:
slen, hsz = hidden_states.shape[-2:]
- start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
- start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
- start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
+ start_positions = start_positions[:, None,
+ None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
x = self.activation(x)
@@ -629,14 +670,16 @@ class PoolerAnswerClass(nn.Module):
hsz = hidden_states.shape[-1]
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
if start_positions is not None:
- start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
- start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
+ start_positions = start_positions[:, None,
+ None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
+ start_states = hidden_states.gather(-2,
+ start_positions).squeeze(-2) # shape (bsz, hsz)
if cls_index is not None:
- cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
- cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
else:
- cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
x = self.activation(x)
@@ -694,8 +737,15 @@ class SQuADHead(nn.Module):
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)
- def forward(self, hidden_states, start_positions=None, end_positions=None,
- cls_index=None, is_impossible=None, p_mask=None):
+ def forward(
+ self,
+ hidden_states,
+ start_positions=None,
+ end_positions=None,
+ cls_index=None,
+ is_impossible=None,
+ p_mask=None
+ ):
outputs = ()
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
@@ -707,7 +757,9 @@ class SQuADHead(nn.Module):
x.squeeze_(-1)
# during training, compute the end logits based on the ground truth of the start position
- end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+ end_logits = self.end_logits(
+ hidden_states, start_positions=start_positions, p_mask=p_mask
+ )
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
@@ -716,38 +768,58 @@ class SQuADHead(nn.Module):
if cls_index is not None and is_impossible is not None:
# Predict answerability from the representation of CLS and START
- cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+ cls_logits = self.answer_class(
+ hidden_states, start_positions=start_positions, cls_index=cls_index
+ )
loss_fct_cls = nn.BCEWithLogitsLoss()
cls_loss = loss_fct_cls(cls_logits, is_impossible)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
- outputs = (total_loss,) + outputs
+ outputs = (total_loss, ) + outputs
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
- start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
-
- start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
- start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
- start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
- start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
-
- hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
+ start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
+
+ start_top_log_probs, start_top_index = torch.topk(
+ start_log_probs, self.start_n_top, dim=-1
+ ) # shape (bsz, start_n_top)
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(
+ -1, -1, hsz
+ ) # shape (bsz, start_n_top, hsz)
+ start_states = torch.gather(
+ hidden_states, -2, start_top_index_exp
+ ) # shape (bsz, start_n_top, hsz)
+ start_states = start_states.unsqueeze(1).expand(
+ -1, slen, -1, -1
+ ) # shape (bsz, slen, start_n_top, hsz)
+
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
+ start_states
+ ) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
- end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
- end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
+ end_logits = self.end_logits(
+ hidden_states_expanded, start_states=start_states, p_mask=p_mask
+ )
+ end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
- end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
+ end_top_log_probs, end_top_index = torch.topk(
+ end_log_probs, self.end_n_top, dim=1
+ ) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
- cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+ cls_logits = self.answer_class(
+ hidden_states, start_states=start_states, cls_index=cls_index
+ )
- outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
+ outputs = (
+ start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
+ ) + outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
# or (if labels are provided) (total_loss,)
@@ -781,7 +853,9 @@ class SequenceSummary(nn.Module):
self.summary = Identity()
if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
- if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
+ if hasattr(
+ config, 'summary_proj_to_labels'
+ ) and config.summary_proj_to_labels and config.num_labels > 0:
num_classes = config.num_labels
else:
num_classes = config.hidden_size
@@ -814,12 +888,17 @@ class SequenceSummary(nn.Module):
output = hidden_states.mean(dim=1)
elif self.summary_type == 'token_ids':
if token_ids is None:
- token_ids = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
+ token_ids = torch.full_like(
+ hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long
+ )
else:
token_ids = token_ids.unsqueeze(-1).unsqueeze(-1)
- token_ids = token_ids.expand((-1,) * (token_ids.dim()-1) + (hidden_states.size(-1),))
+ token_ids = token_ids.expand(
+ (-1, ) * (token_ids.dim() - 1) + (hidden_states.size(-1), )
+ )
# shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
- output = hidden_states.gather(-2, token_ids).squeeze(-2) # shape (bsz, XX, hidden_size)
+ output = hidden_states.gather(-2,
+ token_ids).squeeze(-2) # shape (bsz, XX, hidden_size)
elif self.summary_type == 'attn':
raise NotImplementedError
@@ -845,7 +924,8 @@ def prune_linear_layer(layer, index, dim=0):
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
- new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias
+ is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
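
`prune_linear_layer` above rebuilds a smaller `nn.Linear` that keeps only the weight rows (`dim=0`) or columns (`dim=1`) selected by `index`, copying weights and bias over with gradients temporarily disabled. A hedged, self-contained sketch of the `dim=0` case; this is a re-derivation for illustration, not the library function itself:

```python
import torch
import torch.nn as nn

def prune_output_units(layer, index):
    """Keep only the output units listed in `index` (the dim=0 case)."""
    kept_w = layer.weight.index_select(0, index).clone().detach()
    kept_b = layer.bias.index_select(0, index).clone().detach() if layer.bias is not None else None
    new_layer = nn.Linear(layer.in_features, len(index), bias=layer.bias is not None)
    with torch.no_grad():
        new_layer.weight.copy_(kept_w)
        if kept_b is not None:
            new_layer.bias.copy_(kept_b)
    return new_layer

layer = nn.Linear(8, 4)
smaller = prune_output_units(layer, torch.tensor([0, 2]))    # keep output units 0 and 2
assert smaller.weight.shape == (2, 8)
```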
diff --git a/lib/pymafx/models/transformers/net_utils.py b/lib/pymafx/models/transformers/net_utils.py
index 3e29bb1f1f0b9428a0eeb0eeb4eb432db190a5e8..52782911e276705ec0dd908ce9676430c0a58d72 100644
--- a/lib/pymafx/models/transformers/net_utils.py
+++ b/lib/pymafx/models/transformers/net_utils.py
@@ -7,22 +7,24 @@ import torch.nn.functional as F
class single_conv(nn.Module):
def __init__(self, in_ch, out_ch):
super(single_conv, self).__init__()
- self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
- nn.BatchNorm2d(out_ch),
- nn.ReLU(inplace=True),)
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
+ nn.BatchNorm2d(out_ch),
+ nn.ReLU(inplace=True),
+ )
def forward(self, x):
return self.conv(x)
+
class double_conv(nn.Module):
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
- self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
- nn.BatchNorm2d(out_ch),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
- nn.BatchNorm2d(out_ch),
- nn.ReLU(inplace=True))
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), nn.BatchNorm2d(out_ch),
+ nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
+ nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)
+ )
def forward(self, x):
return self.conv(x)
@@ -31,12 +33,11 @@ class double_conv(nn.Module):
class double_conv_down(nn.Module):
def __init__(self, in_ch, out_ch):
super(double_conv_down, self).__init__()
- self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
- nn.BatchNorm2d(out_ch),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
- nn.BatchNorm2d(out_ch),
- nn.ReLU(inplace=True))
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), nn.BatchNorm2d(out_ch),
+ nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
+ nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)
+ )
def forward(self, x):
return self.conv(x)
@@ -45,13 +46,12 @@ class double_conv_down(nn.Module):
class double_conv_up(nn.Module):
def __init__(self, in_ch, out_ch):
super(double_conv_up, self).__init__()
- self.conv = nn.Sequential(nn.UpsamplingNearest2d(scale_factor=2),
- nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
- nn.BatchNorm2d(out_ch),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
- nn.BatchNorm2d(out_ch),
- nn.ReLU(inplace=True))
+ self.conv = nn.Sequential(
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), nn.BatchNorm2d(out_ch),
+ nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
+ nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)
+ )
def forward(self, x):
return self.conv(x)
@@ -87,31 +87,35 @@ class PosEnSine(nn.Module):
x_embed = x_embed / (torch.max(x_embed) + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+ dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_z = z_embed[:, :, :, None] / dim_t
- pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
- pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
- pos_z = torch.stack((pos_z[:, :, :, 0::2].sin(), pos_z[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos_z = torch.stack((pos_z[:, :, :, 0::2].sin(), pos_z[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
pos = torch.cat((pos_x, pos_y, pos_z), dim=3).permute(0, 3, 1, 2)
# if pt_coord is None:
pos = pos.repeat(b, 1, 1, 1)
return pos
+
def softmax_attention(q, k, v):
# b x n x d x h x w
h, w = q.shape[-2], q.shape[-1]
- q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d
- k = k.flatten(-2) # b x n x d x hw
+ q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d
+ k = k.flatten(-2) # b x n x d x hw
v = v.flatten(-2).transpose(-2, -1)
print('softmax', q.shape, k.shape, v.shape)
- N = k.shape[-1] # ?????? maybe change to k.shape[-2]????
- attn = torch.matmul(q / N ** 0.5, k)
+ N = k.shape[-1] # ?????? maybe change to k.shape[-2]????
+ attn = torch.matmul(q / N**0.5, k)
attn = F.softmax(attn, dim=-1)
output = torch.matmul(attn, v)
@@ -125,8 +129,8 @@ def dotproduct_attention(q, k, v):
# b x n x d x h x w
h, w = q.shape[-2], q.shape[-1]
- q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d
- k = k.flatten(-2) # b x n x d x hw
+ q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d
+ k = k.flatten(-2) # b x n x d x hw
v = v.flatten(-2).transpose(-2, -1)
N = k.shape[-1]
@@ -140,7 +144,7 @@ def dotproduct_attention(q, k, v):
return output, attn
-def long_range_attention(q, k, v, P_h, P_w): # fixed patch size
+def long_range_attention(q, k, v, P_h, P_w): # fixed patch size
B, N, C, qH, qW = q.size()
_, _, _, kH, kW = k.size()
@@ -151,17 +155,17 @@ def long_range_attention(q, k, v, P_h, P_w): # fixed patch size
k = k.reshape(B, N, C, kQ_h, P_h, kQ_w, P_w)
v = v.reshape(B, N, -1, kQ_h, P_h, kQ_w, P_w)
- q = q.permute(0, 1, 4, 6, 2, 3, 5) # [b, n, Ph, Pw, d, Qh, Qw]
+ q = q.permute(0, 1, 4, 6, 2, 3, 5) # [b, n, Ph, Pw, d, Qh, Qw]
k = k.permute(0, 1, 4, 6, 2, 3, 5)
v = v.permute(0, 1, 4, 6, 2, 3, 5)
- output, attn = softmax_attention(q, k, v) # attn: [b, n, Ph, Pw, qQh*qQw, kQ_h*kQ_w]
+ output, attn = softmax_attention(q, k, v) # attn: [b, n, Ph, Pw, qQh*qQw, kQ_h*kQ_w]
output = output.permute(0, 1, 4, 5, 2, 6, 3)
output = output.reshape(B, N, -1, qH, qW)
return output, attn
-def short_range_attention(q, k, v, Q_h, Q_w): # fixed patch number
+def short_range_attention(q, k, v, Q_h, Q_w): # fixed patch number
B, N, C, qH, qW = q.size()
_, _, _, kH, kW = k.size()
@@ -172,11 +176,11 @@ def short_range_attention(q, k, v, Q_h, Q_w): # fixed patch number
k = k.reshape(B, N, C, Q_h, kP_h, Q_w, kP_w)
v = v.reshape(B, N, -1, Q_h, kP_h, Q_w, kP_w)
- q = q.permute(0, 1, 3, 5, 2, 4, 6) # [b, n, Qh, Qw, d, Ph, Pw]
+ q = q.permute(0, 1, 3, 5, 2, 4, 6) # [b, n, Qh, Qw, d, Ph, Pw]
k = k.permute(0, 1, 3, 5, 2, 4, 6)
v = v.permute(0, 1, 3, 5, 2, 4, 6)
- output, attn = softmax_attention(q, k, v) # attn: [b, n, Qh, Qw, qPh*qPw, kPh*kPw]
+ output, attn = softmax_attention(q, k, v) # attn: [b, n, Qh, Qw, qPh*qPw, kPh*kPw]
output = output.permute(0, 1, 4, 2, 5, 3, 6)
output = output.reshape(B, N, -1, qH, qW)
return output, attn
@@ -188,7 +192,7 @@ def space_to_depth(x, block_size):
if len(x.shape) >= 5:
x = x.view(-1, c, h, w)
unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
- return unfolded_x.view(*x_shape[0:-3], c * block_size ** 2, h // block_size, w // block_size)
+ return unfolded_x.view(*x_shape[0:-3], c * block_size**2, h // block_size, w // block_size)
def depth_to_space(x, block_size):
@@ -196,17 +200,17 @@ def depth_to_space(x, block_size):
c, h, w = x_shape[-3:]
x = x.view(-1, c, h, w)
y = torch.nn.functional.pixel_shuffle(x, block_size)
- return y.view(*x_shape[0:-3], -1, h*block_size, w*block_size)
+ return y.view(*x_shape[0:-3], -1, h * block_size, w * block_size)
def patch_attention(q, k, v, P):
# q: [b, nhead, c, h, w]
- q_patch = space_to_depth(q, P) # [b, nhead, cP^2, h/P, w/P]
+ q_patch = space_to_depth(q, P) # [b, nhead, cP^2, h/P, w/P]
k_patch = space_to_depth(k, P)
v_patch = space_to_depth(v, P)
# output: [b, nhead, cP^2, h/P, w/P]
# attn: [b, nhead, h/P*w/P, h/P*w/P]
- output, attn = softmax_attention(q_patch, k_patch, v_patch)
- output = depth_to_space(output, P) # output: [b, nhead, c, h, w]
+ output, attn = softmax_attention(q_patch, k_patch, v_patch)
+ output = depth_to_space(output, P) # output: [b, nhead, c, h, w]
return output, attn
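
`softmax_attention` in `net_utils.py` treats the H×W positions of `[batch, heads, dim, H, W]` feature maps as a token sequence: flatten the spatial dims, run scaled dot-product attention, then fold the result back into a feature map. A self-contained sketch of that pattern; note it scales by the key dimension `d`, whereas the code above divides by `N = k.shape[-1]` (the number of key positions) and carries a comment questioning that choice:

```python
import torch
import torch.nn.functional as F

def spatial_softmax_attention(q, k, v):
    # q, k, v: [batch, heads, dim, H, W]
    b, n, d, h, w = q.shape
    q_ = q.flatten(-2).transpose(-2, -1)            # [b, n, hw, d]
    k_ = k.flatten(-2)                              # [b, n, d, hw]
    v_ = v.flatten(-2).transpose(-2, -1)            # [b, n, hw, d]
    attn = F.softmax(q_ @ k_ / d ** 0.5, dim=-1)    # [b, n, hw, hw]
    out = (attn @ v_).transpose(-2, -1).reshape(b, n, d, h, w)
    return out, attn

q = k = v = torch.randn(1, 8, 16, 4, 5)
out, attn = spatial_softmax_attention(q, k, v)
assert out.shape == (1, 8, 16, 4, 5) and attn.shape == (1, 8, 20, 20)
```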
diff --git a/lib/pymafx/models/transformers/texformer.py b/lib/pymafx/models/transformers/texformer.py
index 4554ce9629f6804e6df82c097f0774905fe332bf..4266b24ed6839f91ce5ca819cc3750143387d48f 100644
--- a/lib/pymafx/models/transformers/texformer.py
+++ b/lib/pymafx/models/transformers/texformer.py
@@ -2,6 +2,7 @@ import torch.nn as nn
from .net_utils import single_conv, double_conv, double_conv_down, double_conv_up, PosEnSine
from .transformer_basics import OurMultiheadAttention
+
class TransformerDecoderUnit(nn.Module):
def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None):
super(TransformerDecoderUnit, self).__init__()
@@ -11,8 +12,8 @@ class TransformerDecoderUnit(nn.Module):
self.P = P
self.pos_en = PosEnSine(self.feat_dim // 2)
- self.attn = OurMultiheadAttention(feat_dim, n_head) # cross-attention
-
+ self.attn = OurMultiheadAttention(feat_dim, n_head) # cross-attention
+
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.activation = nn.ReLU(inplace=True)
@@ -28,7 +29,9 @@ class TransformerDecoderUnit(nn.Module):
k_pos_embed = 0
# cross-multi-head attention
- out = self.attn(q=q+q_pos_embed, k=k+k_pos_embed, v=v, attn_type=self.attn_type, P=self.P)[0]
+ out = self.attn(
+ q=q + q_pos_embed, k=k + k_pos_embed, v=v, attn_type=self.attn_type, P=self.P
+ )[0]
# feed forward
out2 = self.linear2(self.activation(self.linear1(out)))
@@ -52,17 +55,17 @@ class Unet(nn.Module):
def forward(self, x):
feat0 = self.conv_in(x) # H
- feat1 = self.conv1(feat0) # H/2
+ feat1 = self.conv1(feat0) # H/2
feat2 = self.conv2(feat1) # H/4
feat3 = self.conv3(feat2) # H/4
- feat3 = feat3 + feat2 # H/4
+ feat3 = feat3 + feat2 # H/4
feat4 = self.conv4(feat3) # H/2
feat4 = feat4 + feat1 # H/2
- feat5 = self.conv5(feat4) # H
- feat5 = feat5 + feat0 # H
+ feat5 = self.conv5(feat4) # H
+ feat5 = feat5 + feat0 # H
feat6 = self.conv6(feat5)
- return feat0, feat1, feat2, feat3, feat4, feat6
+ return feat0, feat1, feat2, feat3, feat4, feat6
class Texformer(nn.Module):
@@ -77,18 +80,20 @@ class Texformer(nn.Module):
if not self.mask_fusion:
v_ch = out_ch
else:
- v_ch = 2+3
+ v_ch = 2 + 3
self.unet_q = Unet(tgt_ch, self.feat_dim, self.feat_dim)
self.unet_k = Unet(src_ch, self.feat_dim, self.feat_dim)
self.unet_v = Unet(v_ch, self.feat_dim, self.feat_dim)
- self.trans_dec = nn.ModuleList([None,
- None,
- None,
- TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'),
- TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'),
- TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct')])
+ self.trans_dec = nn.ModuleList(
+ [
+ None, None, None,
+ TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'),
+ TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'),
+ TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct')
+ ]
+ )
self.conv0 = double_conv(self.feat_dim, self.feat_dim)
self.conv1 = double_conv_down(self.feat_dim, self.feat_dim)
@@ -96,13 +101,17 @@ class Texformer(nn.Module):
self.conv3 = double_conv(self.feat_dim, self.feat_dim)
self.conv4 = double_conv_up(self.feat_dim, self.feat_dim)
self.conv5 = double_conv_up(self.feat_dim, self.feat_dim)
-
+
if not self.mask_fusion:
- self.conv6 = nn.Sequential(single_conv(self.feat_dim, self.feat_dim),
- nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1))
+ self.conv6 = nn.Sequential(
+ single_conv(self.feat_dim, self.feat_dim),
+ nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1)
+ )
else:
- self.conv6 = nn.Sequential(single_conv(self.feat_dim, self.feat_dim),
- nn.Conv2d(self.feat_dim, 2+3+1, 3, 1, 1)) # mask*flow-sampling + (1-mask)*rgb
+ self.conv6 = nn.Sequential(
+ single_conv(self.feat_dim, self.feat_dim),
+ nn.Conv2d(self.feat_dim, 2 + 3 + 1, 3, 1, 1)
+ ) # mask*flow-sampling + (1-mask)*rgb
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
@@ -120,16 +129,16 @@ class Texformer(nn.Module):
outputs.append(self.trans_dec[i](q_feat[i], k_feat[i], v_feat[i]))
print('outputs', outputs[-1].shape)
- f0 = self.conv0(outputs[2]) # H
- f1 = self.conv1(f0) # H/2
+ f0 = self.conv0(outputs[2]) # H
+ f1 = self.conv1(f0) # H/2
f1 = f1 + outputs[1]
- f2 = self.conv2(f1) # H/4
+ f2 = self.conv2(f1) # H/4
f2 = f2 + outputs[0]
- f3 = self.conv3(f2) # H/4
- f3 = f3 + outputs[0] + f2
- f4 = self.conv4(f3) # H/2
+ f3 = self.conv3(f2) # H/4
+ f3 = f3 + outputs[0] + f2
+ f4 = self.conv4(f3) # H/2
f4 = f4 + outputs[1] + f1
- f5 = self.conv5(f4) # H
+ f5 = self.conv5(f4) # H
f5 = f5 + outputs[2] + f0
if not self.mask_fusion:
out = self.tanh(self.conv6(f5))
@@ -137,4 +146,3 @@ class Texformer(nn.Module):
out_ = self.conv6(f5)
out = [self.tanh(out_[:, :2]), self.tanh(out_[:, 2:5]), self.sigmoid(out_[:, 5:])]
return out
-
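
The mask-fusion head above splits its 2+3+1 output channels into a sampling flow, a predicted RGB image, and a blending mask (`mask*flow-sampling + (1-mask)*rgb`). A minimal standalone sketch of that fusion step, assuming the flow is consumed with `grid_sample` and using made-up shapes:

```python
import torch
import torch.nn.functional as F

# Illustrative only: the tensor sizes and the grid_sample-based "flow sampling"
# are assumptions, not code taken from this repository.
B, H, W = 1, 8, 8
out_ = torch.randn(B, 2 + 3 + 1, H, W)     # raw output of the mask-fusion conv6 head
flow = torch.tanh(out_[:, :2])             # 2-ch sampling coordinates in [-1, 1]
rgb = torch.tanh(out_[:, 2:5])             # 3-ch directly predicted color
mask = torch.sigmoid(out_[:, 5:])          # 1-ch fusion weight

src = torch.rand(B, 3, H, W)               # hypothetical source image to sample colors from
sampled = F.grid_sample(src, flow.permute(0, 2, 3, 1), align_corners=True)
fused = mask * sampled + (1 - mask) * rgb  # mask*flow-sampling + (1-mask)*rgb
print(fused.shape)                         # torch.Size([1, 3, 8, 8])
```
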
diff --git a/lib/pymafx/models/transformers/tokenlearner.py b/lib/pymafx/models/transformers/tokenlearner.py
index 5127fa57e7350daac11ed4e0fde34748eddbbd1f..441b361a721f685f481e764c19b624b593124c1b 100644
--- a/lib/pymafx/models/transformers/tokenlearner.py
+++ b/lib/pymafx/models/transformers/tokenlearner.py
@@ -2,44 +2,45 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
+
class SpatialAttention(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Sequential(
- nn.Conv2d(2, 1, kernel_size=(1,1), stride=1),
- nn.BatchNorm2d(1),
- nn.ReLU()
+ nn.Conv2d(2, 1, kernel_size=(1, 1), stride=1), nn.BatchNorm2d(1), nn.ReLU()
)
-
+
self.sgap = nn.AvgPool2d(2)
def forward(self, x):
B, H, W, C = x.shape
x = x.reshape(B, C, H, W)
-
+
mx = torch.max(x, 1)[0].unsqueeze(1)
avg = torch.mean(x, 1).unsqueeze(1)
combined = torch.cat([mx, avg], dim=1)
fmap = self.conv(combined)
weight_map = torch.sigmoid(fmap)
out = (x * weight_map).mean(dim=(-2, -1))
-
+
return out, x * weight_map
+
class TokenLearner(nn.Module):
def __init__(self, S) -> None:
super().__init__()
self.S = S
self.tokenizers = nn.ModuleList([SpatialAttention() for _ in range(S)])
-
+
def forward(self, x):
B, _, _, C = x.shape
Z = torch.Tensor(B, self.S, C).to(x)
for i in range(self.S):
- Ai, _ = self.tokenizers[i](x) # [B, C]
+ Ai, _ = self.tokenizers[i](x) # [B, C]
Z[:, i, :] = Ai
return Z
+
class TokenFuser(nn.Module):
def __init__(self, H, W, C, S) -> None:
super().__init__()
@@ -47,18 +48,18 @@ class TokenFuser(nn.Module):
self.Bi = nn.Linear(C, S)
self.spatial_attn = SpatialAttention()
self.S = S
-
+
def forward(self, y, x):
B, S, C = y.shape
B, H, W, C = x.shape
-
+
Y = self.projection(y.reshape(B, C, S)).reshape(B, S, C)
- Bw = torch.sigmoid(self.Bi(x)).reshape(B, H*W, S) # [B, HW, S]
+ Bw = torch.sigmoid(self.Bi(x)).reshape(B, H * W, S) # [B, HW, S]
BwY = torch.matmul(Bw, Y)
-
+
_, xj = self.spatial_attn(x)
- xj = xj.reshape(B, H*W, C)
+ xj = xj.reshape(B, H * W, C)
out = (BwY + xj).reshape(B, H, W, C)
-
- return out
\ No newline at end of file
+
+ return out
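
For reference, a shape-level smoke test of the `TokenLearner`/`TokenFuser` pair above, assuming the repository root is on `PYTHONPATH` and the package imports cleanly; the batch and spatial sizes are arbitrary:

```python
import torch

from lib.pymafx.models.transformers.tokenlearner import TokenLearner, TokenFuser

B, H, W, C, S = 2, 16, 16, 64, 8
x = torch.randn(B, H, W, C)       # channels-last feature map, as expected by forward()
learner = TokenLearner(S)         # S learned tokens, one SpatialAttention per token
fuser = TokenFuser(H, W, C, S)

tokens = learner(x)               # [B, S, C]
remapped = fuser(tokens, x)       # [B, H, W, C]
print(tokens.shape, remapped.shape)
```
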
diff --git a/lib/pymafx/models/transformers/transformer_basics.py b/lib/pymafx/models/transformers/transformer_basics.py
index 05c26c1639b5e5a0e4f68d782a08041024d302ff..144ccd76b7e2f73189634ab551691c4262781b9d 100644
--- a/lib/pymafx/models/transformers/transformer_basics.py
+++ b/lib/pymafx/models/transformers/transformer_basics.py
@@ -35,7 +35,7 @@ class OurMultiheadAttention(nn.Module):
# -------------- Attention -----------------
if attn_type == 'softmax':
- q, attn = softmax_attention(q, k, v) # b x n x dk x h x w --> b x n x dv x h x w
+ q, attn = softmax_attention(q, k, v) # b x n x dk x h x w --> b x n x dv x h x w
elif attn_type == 'dotproduct':
q, attn = dotproduct_attention(q, k, v)
elif attn_type == 'patch':
@@ -50,7 +50,7 @@ class OurMultiheadAttention(nn.Module):
# Concatenate all the heads together: b x (n*dv) x h x w
q = q.reshape(q.shape[0], -1, q.shape[3], q.shape[4])
- q = self.fc(q) # b x d x h x w
+ q = self.fc(q) # b x d x h x w
return q, attn
@@ -65,22 +65,24 @@ class TransformerEncoderUnit(nn.Module):
self.pos_en = PosEnSine(self.feat_dim // 2)
self.attn = OurMultiheadAttention(feat_dim, n_head)
-
+
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.activation = nn.ReLU(inplace=True)
self.norm1 = nn.BatchNorm2d(self.feat_dim)
- self.norm2 = nn.BatchNorm2d(self.feat_dim)
+ self.norm2 = nn.BatchNorm2d(self.feat_dim)
def forward(self, src):
if self.pos_en_flag:
pos_embed = self.pos_en(src)
else:
pos_embed = 0
-
+
# multi-head attention
- src2 = self.attn(q=src+pos_embed, k=src+pos_embed, v=src, attn_type=self.attn_type, P=self.P)[0]
+ src2 = self.attn(
+ q=src + pos_embed, k=src + pos_embed, v=src, attn_type=self.attn_type, P=self.P
+ )[0]
src = src + src2
src = self.norm1(src)
@@ -102,26 +104,40 @@ class TransformerEncoderUnitSparse(nn.Module):
self.pos_en = PosEnSine(self.feat_dim // 2)
self.attn1 = OurMultiheadAttention(feat_dim, n_head) # long range
self.attn2 = OurMultiheadAttention(feat_dim, n_head) # short range
-
+
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.activation = nn.ReLU(inplace=True)
self.norm1 = nn.BatchNorm2d(self.feat_dim)
- self.norm2 = nn.BatchNorm2d(self.feat_dim)
+ self.norm2 = nn.BatchNorm2d(self.feat_dim)
def forward(self, src):
if self.pos_en_flag:
pos_embed = self.pos_en(src)
else:
pos_embed = 0
-
+
# multi-head long-range attention
- src2 = self.attn1(q=src+pos_embed, k=src+pos_embed, v=src, attn_type='sparse_long', ah=self.ahw[0], aw=self.ahw[1])[0]
+ src2 = self.attn1(
+ q=src + pos_embed,
+ k=src + pos_embed,
+ v=src,
+ attn_type='sparse_long',
+ ah=self.ahw[0],
+ aw=self.ahw[1]
+ )[0]
src = src + src2 # ? this might be ok to remove
-
+
# multi-head short-range attention
- src2 = self.attn2(q=src+pos_embed, k=src+pos_embed, v=src, attn_type='sparse_short', ah=self.ahw[2], aw=self.ahw[3])[0]
+ src2 = self.attn2(
+ q=src + pos_embed,
+ k=src + pos_embed,
+ v=src,
+ attn_type='sparse_short',
+ ah=self.ahw[2],
+ aw=self.ahw[3]
+ )[0]
src = src + src2
src = self.norm1(src)
@@ -142,16 +158,16 @@ class TransformerDecoderUnit(nn.Module):
self.P = P
self.pos_en = PosEnSine(self.feat_dim // 2)
- self.attn1 = OurMultiheadAttention(feat_dim, n_head) # self-attention
- self.attn2 = OurMultiheadAttention(feat_dim, n_head) # cross-attention
-
+ self.attn1 = OurMultiheadAttention(feat_dim, n_head) # self-attention
+ self.attn2 = OurMultiheadAttention(feat_dim, n_head) # cross-attention
+
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.activation = nn.ReLU(inplace=True)
self.norm1 = nn.BatchNorm2d(self.feat_dim)
- self.norm2 = nn.BatchNorm2d(self.feat_dim)
- self.norm3 = nn.BatchNorm2d(self.feat_dim)
+ self.norm2 = nn.BatchNorm2d(self.feat_dim)
+ self.norm3 = nn.BatchNorm2d(self.feat_dim)
def forward(self, tgt, src):
if self.pos_en_flag:
@@ -160,14 +176,18 @@ class TransformerDecoderUnit(nn.Module):
else:
src_pos_embed = 0
tgt_pos_embed = 0
-
+
# self-multi-head attention
- tgt2 = self.attn1(q=tgt+tgt_pos_embed, k=tgt+tgt_pos_embed, v=tgt, attn_type=self.attn_type, P=self.P)[0]
+ tgt2 = self.attn1(
+ q=tgt + tgt_pos_embed, k=tgt + tgt_pos_embed, v=tgt, attn_type=self.attn_type, P=self.P
+ )[0]
tgt = tgt + tgt2
tgt = self.norm1(tgt)
# cross-multi-head attention
- tgt2 = self.attn2(q=tgt+tgt_pos_embed, k=src+src_pos_embed, v=src, attn_type=self.attn_type, P=self.P)[0]
+ tgt2 = self.attn2(
+ q=tgt + tgt_pos_embed, k=src + src_pos_embed, v=src, attn_type=self.attn_type, P=self.P
+ )[0]
tgt = tgt + tgt2
tgt = self.norm2(tgt)
@@ -183,23 +203,25 @@ class TransformerDecoderUnitSparse(nn.Module):
def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None):
super(TransformerDecoderUnitSparse, self).__init__()
self.feat_dim = feat_dim
- self.ahw = ahw # [Ph_tgt, Pw_tgt, Qh_tgt, Qw_tgt, Ph_src, Pw_src, Qh_tgt, Qw_tgt]
+ self.ahw = ahw # [Ph_tgt, Pw_tgt, Qh_tgt, Qw_tgt, Ph_src, Pw_src, Qh_tgt, Qw_tgt]
self.pos_en_flag = pos_en_flag
self.pos_en = PosEnSine(self.feat_dim // 2)
- self.attn1_1 = OurMultiheadAttention(feat_dim, n_head) # self-attention: long
- self.attn1_2 = OurMultiheadAttention(feat_dim, n_head) # self-attention: short
+ self.attn1_1 = OurMultiheadAttention(feat_dim, n_head) # self-attention: long
+ self.attn1_2 = OurMultiheadAttention(feat_dim, n_head) # self-attention: short
- self.attn2_1 = OurMultiheadAttention(feat_dim, n_head) # cross-attention: self-attention-long + cross-attention-short
+ self.attn2_1 = OurMultiheadAttention(
+ feat_dim, n_head
+ ) # cross-attention: self-attention-long + cross-attention-short
self.attn2_2 = OurMultiheadAttention(feat_dim, n_head)
-
+
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.activation = nn.ReLU(inplace=True)
self.norm1 = nn.BatchNorm2d(self.feat_dim)
- self.norm2 = nn.BatchNorm2d(self.feat_dim)
- self.norm3 = nn.BatchNorm2d(self.feat_dim)
+ self.norm2 = nn.BatchNorm2d(self.feat_dim)
+ self.norm3 = nn.BatchNorm2d(self.feat_dim)
def forward(self, tgt, src):
if self.pos_en_flag:
@@ -208,20 +230,48 @@ class TransformerDecoderUnitSparse(nn.Module):
else:
src_pos_embed = 0
tgt_pos_embed = 0
-
+
# self-multi-head attention: sparse long
- tgt2 = self.attn1_1(q=tgt+tgt_pos_embed, k=tgt+tgt_pos_embed, v=tgt, attn_type='sparse_long', ah=self.ahw[0], aw=self.ahw[1])[0]
+ tgt2 = self.attn1_1(
+ q=tgt + tgt_pos_embed,
+ k=tgt + tgt_pos_embed,
+ v=tgt,
+ attn_type='sparse_long',
+ ah=self.ahw[0],
+ aw=self.ahw[1]
+ )[0]
tgt = tgt + tgt2
# self-multi-head attention: sparse short
- tgt2 = self.attn1_2(q=tgt+tgt_pos_embed, k=tgt+tgt_pos_embed, v=tgt, attn_type='sparse_short', ah=self.ahw[2], aw=self.ahw[3])[0]
+ tgt2 = self.attn1_2(
+ q=tgt + tgt_pos_embed,
+ k=tgt + tgt_pos_embed,
+ v=tgt,
+ attn_type='sparse_short',
+ ah=self.ahw[2],
+ aw=self.ahw[3]
+ )[0]
tgt = tgt + tgt2
tgt = self.norm1(tgt)
# self-multi-head attention: sparse long
- src2 = self.attn2_1(q=src+src_pos_embed, k=src+src_pos_embed, v=src, attn_type='sparse_long', ah=self.ahw[4], aw=self.ahw[5])[0]
+ src2 = self.attn2_1(
+ q=src + src_pos_embed,
+ k=src + src_pos_embed,
+ v=src,
+ attn_type='sparse_long',
+ ah=self.ahw[4],
+ aw=self.ahw[5]
+ )[0]
src = src + src2
# cross-multi-head attention: sparse short
- tgt2 = self.attn2_2(q=tgt+tgt_pos_embed, k=src+src_pos_embed, v=src, attn_type='sparse_short', ah=self.ahw[6], aw=self.ahw[7])[0]
+ tgt2 = self.attn2_2(
+ q=tgt + tgt_pos_embed,
+ k=src + src_pos_embed,
+ v=src,
+ attn_type='sparse_short',
+ ah=self.ahw[6],
+ aw=self.ahw[7]
+ )[0]
tgt = tgt + tgt2
tgt = self.norm2(tgt)
@@ -231,4 +281,3 @@ class TransformerDecoderUnitSparse(nn.Module):
tgt = self.norm3(tgt)
return tgt
-
diff --git a/lib/pymafx/utils/binvox_rw.py b/lib/pymafx/utils/binvox_rw.py
index c9c11d6992827ca2132a87599f2042867f77a455..947c3258691da908954f765bde07e0978cfb9f97 100644
--- a/lib/pymafx/utils/binvox_rw.py
+++ b/lib/pymafx/utils/binvox_rw.py
@@ -16,7 +16,6 @@
#
# Modified by Christopher B. Choy
# for python 3 support
-
"""
Binvox to Numpy and back.
@@ -65,6 +64,7 @@ True
import numpy as np
+
class Voxels(object):
""" Holds a binvox model.
data is either a three-dimensional numpy boolean array (dense representation)
@@ -86,7 +86,6 @@ class Voxels(object):
z = scale*z_n + translate[2]
"""
-
def __init__(self, data, dims, translate, scale, axis_order):
self.data = data
self.dims = dims
@@ -104,6 +103,7 @@ class Voxels(object):
def write(self, fp):
write(self, fp)
+
def read_header(fp):
""" Read binvox header. Mostly meant for internal use.
"""
@@ -116,6 +116,7 @@ def read_header(fp):
line = fp.readline()
return dims, translate, scale
+
def read_as_3d_array(fp, fix_coords=True):
""" Read binary binvox format as array.
@@ -189,8 +190,8 @@ def read_as_coord_array(fp, fix_coords=True):
# according to docs,
# index = x * wxh + z * width + y; // wxh = width * height = d * d
- x = nz_voxels / (dims[0]*dims[1])
- zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y
+ x = nz_voxels / (dims[0] * dims[1])
+ zwpy = nz_voxels % (dims[0] * dims[1]) # z*w + y
z = zwpy / dims[0]
y = zwpy % dims[0]
if fix_coords:
@@ -203,34 +204,38 @@ def read_as_coord_array(fp, fix_coords=True):
#return Voxels(data, dims, translate, scale, axis_order)
return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
+
def dense_to_sparse(voxel_data, dtype=np.int):
""" From dense representation to sparse (coordinate) representation.
No coordinate reordering.
"""
- if voxel_data.ndim!=3:
+ if voxel_data.ndim != 3:
raise ValueError('voxel_data is wrong shape; should be 3D array.')
return np.asarray(np.nonzero(voxel_data), dtype)
+
def sparse_to_dense(voxel_data, dims, dtype=np.bool):
- if voxel_data.ndim!=2 or voxel_data.shape[0]!=3:
+ if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:
raise ValueError('voxel_data is wrong shape; should be 3xN array.')
if np.isscalar(dims):
- dims = [dims]*3
+ dims = [dims] * 3
dims = np.atleast_2d(dims).T
# truncate to integers
xyz = voxel_data.astype(np.int)
# discard voxels that fall outside dims
valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
- xyz = xyz[:,valid_ix]
+ xyz = xyz[:, valid_ix]
out = np.zeros(dims.flatten(), dtype=dtype)
out[tuple(xyz)] = True
return out
+
#def get_linear_index(x, y, z, dims):
- #""" Assuming xzy order. (y increasing fastest.
- #TODO ensure this is right when dims are not all same
- #"""
- #return x*(dims[1]*dims[2]) + z*dims[1] + y
+#""" Assuming xzy order. (y increasing fastest.
+#TODO ensure this is right when dims are not all same
+#"""
+#return x*(dims[1]*dims[2]) + z*dims[1] + y
+
def write(voxel_model, fp):
""" Write binary binvox format.
@@ -241,33 +246,33 @@ def write(voxel_model, fp):
Doesn't check if the model is 'sane'.
"""
- if voxel_model.data.ndim==2:
+ if voxel_model.data.ndim == 2:
# TODO avoid conversion to dense
dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
else:
dense_voxel_data = voxel_model.data
fp.write('#binvox 1\n')
- fp.write('dim '+' '.join(map(str, voxel_model.dims))+'\n')
- fp.write('translate '+' '.join(map(str, voxel_model.translate))+'\n')
- fp.write('scale '+str(voxel_model.scale)+'\n')
+ fp.write('dim ' + ' '.join(map(str, voxel_model.dims)) + '\n')
+ fp.write('translate ' + ' '.join(map(str, voxel_model.translate)) + '\n')
+ fp.write('scale ' + str(voxel_model.scale) + '\n')
fp.write('data\n')
if not voxel_model.axis_order in ('xzy', 'xyz'):
raise ValueError('Unsupported voxel model axis order')
- if voxel_model.axis_order=='xzy':
+ if voxel_model.axis_order == 'xzy':
voxels_flat = dense_voxel_data.flatten()
- elif voxel_model.axis_order=='xyz':
+ elif voxel_model.axis_order == 'xyz':
voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
# keep a sort of state machine for writing run length encoding
state = voxels_flat[0]
ctr = 0
for c in voxels_flat:
- if c==state:
+ if c == state:
ctr += 1
# if ctr hits max, dump
- if ctr==255:
+ if ctr == 255:
fp.write(chr(state))
fp.write(chr(ctr))
ctr = 0
@@ -282,6 +287,7 @@ def write(voxel_model, fp):
fp.write(chr(state))
fp.write(chr(ctr))
+
if __name__ == '__main__':
import doctest
doctest.testmod()
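
A small usage sketch for the module above; the `.binvox` path is a placeholder, and the `np.int`/`np.bool` aliases used in the module assume an older NumPy (< 1.24) where they still exist:

```python
import numpy as np

from lib.pymafx.utils import binvox_rw   # assumes the repository root is on PYTHONPATH

# 'model.binvox' is a hypothetical path, not a file shipped with the repo.
with open('model.binvox', 'rb') as f:
    vox = binvox_rw.read_as_3d_array(f)  # dense boolean occupancy grid

print(vox.dims, vox.translate, vox.scale, vox.axis_order)
occupied = np.argwhere(vox.data)         # N x 3 indices of filled voxels
print(occupied.shape)
```
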
diff --git a/lib/pymafx/utils/blob.py b/lib/pymafx/utils/blob.py
index 0d989f13c139abe4905280579c083b93b92e68d8..00123338e18a3fa74a6c3cb730cac9fb41b59ac5 100644
--- a/lib/pymafx/utils/blob.py
+++ b/lib/pymafx/utils/blob.py
@@ -45,9 +45,7 @@ def get_image_blob(im, target_scale, target_max_size):
im_scale (float): image scale (target size) / (original size)
im_info (ndarray)
"""
- processed_im, im_scale = prep_im_for_blob(
- im, cfg.PIXEL_MEANS, [target_scale], target_max_size
- )
+ processed_im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, [target_scale], target_max_size)
blob = im_list_to_blob(processed_im)
# NOTE: this height and width may be larger than actual scaled input image
# due to the FPN.COARSEST_STRIDE related padding in im_list_to_blob. We are
@@ -76,8 +74,7 @@ def im_list_to_blob(ims):
max_shape = get_max_shape([im.shape[:2] for im in ims])
num_images = len(ims)
- blob = np.zeros(
- (num_images, max_shape[0], max_shape[1], 3), dtype=np.float32)
+ blob = np.zeros((num_images, max_shape[0], max_shape[1], 3), dtype=np.float32)
for i in range(num_images):
im = ims[i]
blob[i, 0:im.shape[0], 0:im.shape[1], :] = im
@@ -119,8 +116,9 @@ def prep_im_for_blob(im, pixel_means, target_sizes, max_size):
im_scales = []
for target_size in target_sizes:
im_scale = get_target_scale(im_size_min, im_size_max, target_size, max_size)
- im_resized = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
- interpolation=cv2.INTER_LINEAR)
+ im_resized = cv2.resize(
+ im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR
+ )
ims.append(im_resized)
im_scales.append(im_scale)
return ims, im_scales
diff --git a/lib/pymafx/utils/cam_params.py b/lib/pymafx/utils/cam_params.py
index 7cf877bbf0de1f952d1be3efe200439712b289b5..1f6c1a8d89b2c80d72c90c841d02425df77aa4a5 100644
--- a/lib/pymafx/utils/cam_params.py
+++ b/lib/pymafx/utils/cam_params.py
@@ -22,6 +22,7 @@ import joblib
from .geometry import batch_euler2matrix
+
def f_pix2vfov(f_pix, img_h):
if torch.is_tensor(f_pix):
@@ -31,6 +32,7 @@ def f_pix2vfov(f_pix, img_h):
return fov
+
def vfov2f_pix(fov, img_h):
if torch.is_tensor(fov):
@@ -40,6 +42,7 @@ def vfov2f_pix(fov, img_h):
return f_pix
+
def read_cam_params(cam_params, orig_shape=None):
# These are predicted camera parameters
# cam_param_folder = CAM_PARAM_FOLDERS[dataset_name][cam_param_type]
@@ -69,6 +72,7 @@ def read_cam_params(cam_params, orig_shape=None):
return cam_rotmat, cam_int, cam_vfov, cam_pitch, cam_roll, cam_focal_length
+
def homo_vector(vector):
"""
vector: B x N x C
diff --git a/lib/pymafx/utils/collections.py b/lib/pymafx/utils/collections.py
index 465c9df196d762430d2318fc30e85ccd107b8b84..edd20a8c89d5d2221dc9d35948eda12c6304ba29 100644
--- a/lib/pymafx/utils/collections.py
+++ b/lib/pymafx/utils/collections.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
-
"""A simple attribute dictionary used for representing configuration options."""
from __future__ import absolute_import
@@ -45,8 +44,7 @@ class AttrDict(dict):
self[name] = value
else:
raise AttributeError(
- 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
- format(name, value)
+ 'Attempted to set "{}" to "{}", but AttrDict is immutable'.format(name, value)
)
def immutable(self, is_immutable):
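
A brief sketch of the immutability guard above; attribute writes go through `__setattr__`, so once the config is frozen with `immutable(True)` further assignments raise (assuming the usual Detectron-style `AttrDict` behavior):

```python
from lib.pymafx.utils.collections import AttrDict   # assumes repo root on PYTHONPATH

cfg = AttrDict()
cfg.MODEL_NAME = 'pymafx'      # stored as a dict entry while mutable
cfg.immutable(True)            # freeze the config
try:
    cfg.MODEL_NAME = 'other'   # rejected by __setattr__ once immutable
except AttributeError as err:
    print(err)
```
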
diff --git a/lib/pymafx/utils/colormap.py b/lib/pymafx/utils/colormap.py
index bc6869f289a9c47519ca69bdddba3dd4fa82ea27..44ef28c050021a6f03d088e9437de0c4adeb5ee5 100644
--- a/lib/pymafx/utils/colormap.py
+++ b/lib/pymafx/utils/colormap.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
-
"""An awesome colormap for really neat visualizations."""
from __future__ import absolute_import
@@ -26,85 +25,26 @@ import numpy as np
def colormap(rgb=False):
color_list = np.array(
[
- 0.000, 0.447, 0.741,
- 0.850, 0.325, 0.098,
- 0.929, 0.694, 0.125,
- 0.494, 0.184, 0.556,
- 0.466, 0.674, 0.188,
- 0.301, 0.745, 0.933,
- 0.635, 0.078, 0.184,
- 0.300, 0.300, 0.300,
- 0.600, 0.600, 0.600,
- 1.000, 0.000, 0.000,
- 1.000, 0.500, 0.000,
- 0.749, 0.749, 0.000,
- 0.000, 1.000, 0.000,
- 0.000, 0.000, 1.000,
- 0.667, 0.000, 1.000,
- 0.333, 0.333, 0.000,
- 0.333, 0.667, 0.000,
- 0.333, 1.000, 0.000,
- 0.667, 0.333, 0.000,
- 0.667, 0.667, 0.000,
- 0.667, 1.000, 0.000,
- 1.000, 0.333, 0.000,
- 1.000, 0.667, 0.000,
- 1.000, 1.000, 0.000,
- 0.000, 0.333, 0.500,
- 0.000, 0.667, 0.500,
- 0.000, 1.000, 0.500,
- 0.333, 0.000, 0.500,
- 0.333, 0.333, 0.500,
- 0.333, 0.667, 0.500,
- 0.333, 1.000, 0.500,
- 0.667, 0.000, 0.500,
- 0.667, 0.333, 0.500,
- 0.667, 0.667, 0.500,
- 0.667, 1.000, 0.500,
- 1.000, 0.000, 0.500,
- 1.000, 0.333, 0.500,
- 1.000, 0.667, 0.500,
- 1.000, 1.000, 0.500,
- 0.000, 0.333, 1.000,
- 0.000, 0.667, 1.000,
- 0.000, 1.000, 1.000,
- 0.333, 0.000, 1.000,
- 0.333, 0.333, 1.000,
- 0.333, 0.667, 1.000,
- 0.333, 1.000, 1.000,
- 0.667, 0.000, 1.000,
- 0.667, 0.333, 1.000,
- 0.667, 0.667, 1.000,
- 0.667, 1.000, 1.000,
- 1.000, 0.000, 1.000,
- 1.000, 0.333, 1.000,
- 1.000, 0.667, 1.000,
- 0.167, 0.000, 0.000,
- 0.333, 0.000, 0.000,
- 0.500, 0.000, 0.000,
- 0.667, 0.000, 0.000,
- 0.833, 0.000, 0.000,
- 1.000, 0.000, 0.000,
- 0.000, 0.167, 0.000,
- 0.000, 0.333, 0.000,
- 0.000, 0.500, 0.000,
- 0.000, 0.667, 0.000,
- 0.000, 0.833, 0.000,
- 0.000, 1.000, 0.000,
- 0.000, 0.000, 0.167,
- 0.000, 0.000, 0.333,
- 0.000, 0.000, 0.500,
- 0.000, 0.000, 0.667,
- 0.000, 0.000, 0.833,
- 0.000, 0.000, 1.000,
- 0.000, 0.000, 0.000,
- 0.143, 0.143, 0.143,
- 0.286, 0.286, 0.286,
- 0.429, 0.429, 0.429,
- 0.571, 0.571, 0.571,
- 0.714, 0.714, 0.714,
- 0.857, 0.857, 0.857,
- 1.000, 1.000, 1.000
+ 0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, 0.184, 0.556,
+ 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, 0.184, 0.300, 0.300, 0.300,
+ 0.600, 0.600, 0.600, 1.000, 0.000, 0.000, 1.000, 0.500, 0.000, 0.749, 0.749, 0.000,
+ 0.000, 1.000, 0.000, 0.000, 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000,
+ 0.333, 0.667, 0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
+ 0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000, 1.000, 0.000,
+ 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000, 0.500, 0.333, 0.000, 0.500,
+ 0.333, 0.333, 0.500, 0.333, 0.667, 0.500, 0.333, 1.000, 0.500, 0.667, 0.000, 0.500,
+ 0.667, 0.333, 0.500, 0.667, 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500,
+ 1.000, 0.333, 0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
+ 0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333, 0.333, 1.000,
+ 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, 1.000, 0.667, 0.333, 1.000,
+ 0.667, 0.667, 1.000, 0.667, 1.000, 1.000, 1.000, 0.000, 1.000, 1.000, 0.333, 1.000,
+ 1.000, 0.667, 1.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000,
+ 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
+ 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000,
+ 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500,
+ 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000,
+ 0.143, 0.143, 0.143, 0.286, 0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571,
+ 0.714, 0.714, 0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000
]
).astype(np.float32)
color_list = color_list.reshape((-1, 3)) * 255
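
The flattened literal above still encodes the same 79 RGB triples (237 values); a typical use is cycling through them for per-instance colors (assuming the package imports cleanly):

```python
from lib.pymafx.utils.colormap import colormap   # assumes repo root on PYTHONPATH

colors = colormap(rgb=True)               # presumably a (79, 3) float32 array in [0, 255]
instance_color = colors[7 % len(colors)]  # pick a color for, e.g., instance id 7
```
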
diff --git a/lib/pymafx/utils/common.py b/lib/pymafx/utils/common.py
index ee29cc1c3ed765261c265e15b3959ebecb439857..f3330ea18c4783ccacb21657808b8b8ce2301f86 100755
--- a/lib/pymafx/utils/common.py
+++ b/lib/pymafx/utils/common.py
@@ -4,7 +4,6 @@ import logging
from copy import deepcopy
from .utils.libkdtree import KDTree
-
logger_py = logging.getLogger(__name__)
@@ -37,6 +36,7 @@ def compute_iou(occ1, occ2):
return iou
+
def rgb2gray(rgb):
''' rgb of size B x h x w x 3
'''
@@ -46,8 +46,9 @@ def rgb2gray(rgb):
return gray
-def sample_patch_points(batch_size, n_points, patch_size=1,
- image_resolution=(128, 128), continuous=True):
+def sample_patch_points(
+ batch_size, n_points, patch_size=1, image_resolution=(128, 128), continuous=True
+):
''' Returns sampled points in the range [-1, 1].
Args:
@@ -60,21 +61,21 @@ def sample_patch_points(batch_size, n_points, patch_size=1,
continuous (bool): whether to sample continuously or only on pixel
locations
'''
- assert(patch_size > 0)
+ assert (patch_size > 0)
# Calculate step size for [-1, 1] that is equivalent to a pixel in
# original resolution
h_step = 1. / image_resolution[0]
w_step = 1. / image_resolution[1]
# Get number of patches
- patch_size_squared = patch_size ** 2
+ patch_size_squared = patch_size**2
n_patches = int(n_points / patch_size_squared)
if continuous:
- p = torch.rand(batch_size, n_patches, 2) # [0, 1]
+ p = torch.rand(batch_size, n_patches, 2) # [0, 1]
else:
- px = torch.randint(0, image_resolution[1], size=(
- batch_size, n_patches, 1)).float() / (image_resolution[1] - 1)
- py = torch.randint(0, image_resolution[0], size=(
- batch_size, n_patches, 1)).float() / (image_resolution[0] - 1)
+ px = torch.randint(0, image_resolution[1],
+ size=(batch_size, n_patches, 1)).float() / (image_resolution[1] - 1)
+ py = torch.randint(0, image_resolution[0],
+ size=(batch_size, n_patches, 1)).float() / (image_resolution[0] - 1)
p = torch.cat([px, py], dim=-1)
# Scale p to [0, (1 - (patch_size - 1) * step) ]
p[:, :, 0] *= 1 - (patch_size - 1) * w_step
@@ -83,9 +84,8 @@ def sample_patch_points(batch_size, n_points, patch_size=1,
# Add points
patch_arange = torch.arange(patch_size)
x_offset, y_offset = torch.meshgrid(patch_arange, patch_arange)
- patch_offsets = torch.stack(
- [x_offset.reshape(-1), y_offset.reshape(-1)],
- dim=1).view(1, 1, -1, 2).repeat(batch_size, n_patches, 1, 1).float()
+ patch_offsets = torch.stack([x_offset.reshape(-1), y_offset.reshape(-1)],
+ dim=1).view(1, 1, -1, 2).repeat(batch_size, n_patches, 1, 1).float()
patch_offsets[:, :, :, 0] *= w_step
patch_offsets[:, :, :, 1] *= h_step
@@ -99,13 +99,12 @@ def sample_patch_points(batch_size, n_points, patch_size=1,
p = p.view(batch_size, -1, 2)
amax, amin = p.max(), p.min()
- assert(amax <= 1. and amin >= -1.)
+ assert (amax <= 1. and amin >= -1.)
return p
-def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1,
- eps=1e-6, n_steps=40):
+def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, n_steps=40):
''' Returns n_steps equally spaced points inside the unit cube on the rays
cast from ray0 with direction ray_direction.
@@ -138,8 +137,7 @@ def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1,
return d_proposal, mask_inside_cube
-def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1,
- eps=1e-6, scale=2.0):
+def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, scale=2.0):
''' Checks if rays ray0 + d * ray_direction intersect with unit cube with
padding padding.
@@ -160,7 +158,7 @@ def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1,
# d = - /
# Get points on plane p_e
- p_distance = (scale * 0.5) + padding/2
+ p_distance = (scale * 0.5) + padding / 2
p_e = torch.ones(batch_size, n_pts, 6).to(device) * p_distance
p_e[:, :, 3:] *= -1.
@@ -185,35 +183,32 @@ def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1,
mask_inside_cube = p_mask_inside_cube.sum(-1) == 2
# Get interval values for p's which are valid
- p_intervals = p_intersect[mask_inside_cube][p_mask_inside_cube[
- mask_inside_cube]].view(-1, 2, 3)
+ p_intervals = p_intersect[mask_inside_cube][p_mask_inside_cube[mask_inside_cube]].view(-1, 2, 3)
p_intervals_batch = torch.zeros(batch_size, n_pts, 2, 3).to(device)
p_intervals_batch[mask_inside_cube] = p_intervals
# Calculate ray lengths for the interval points
d_intervals_batch = torch.zeros(batch_size, n_pts, 2).to(device)
norm_ray = torch.norm(ray_direction[mask_inside_cube], dim=-1)
- d_intervals_batch[mask_inside_cube] = torch.stack([
- torch.norm(p_intervals[:, 0] -
- ray0[mask_inside_cube], dim=-1) / norm_ray,
- torch.norm(p_intervals[:, 1] -
- ray0[mask_inside_cube], dim=-1) / norm_ray,
- ], dim=-1)
+ d_intervals_batch[mask_inside_cube] = torch.stack(
+ [
+ torch.norm(p_intervals[:, 0] - ray0[mask_inside_cube], dim=-1) / norm_ray,
+ torch.norm(p_intervals[:, 1] - ray0[mask_inside_cube], dim=-1) / norm_ray,
+ ],
+ dim=-1
+ )
# Sort the ray lengths
d_intervals_batch, indices_sort = d_intervals_batch.sort()
- p_intervals_batch = p_intervals_batch[
- torch.arange(batch_size).view(-1, 1, 1),
- torch.arange(n_pts).view(1, -1, 1),
- indices_sort
- ]
+ p_intervals_batch = p_intervals_batch[torch.arange(batch_size).view(-1, 1, 1),
+ torch.arange(n_pts).view(1, -1, 1), indices_sort]
return p_intervals_batch, d_intervals_batch, mask_inside_cube
def intersect_camera_rays_with_unit_cube(
- pixels, camera_mat, world_mat, scale_mat, padding=0.1, eps=1e-6,
- use_ray_length_as_depth=True):
+ pixels, camera_mat, world_mat, scale_mat, padding=0.1, eps=1e-6, use_ray_length_as_depth=True
+):
''' Returns the intersection points of ray cast from camera origin to
pixel points p on the image plane.
@@ -231,24 +226,22 @@ def intersect_camera_rays_with_unit_cube(
'''
batch_size, n_points, _ = pixels.shape
- pixel_world = image_points_to_world(
- pixels, camera_mat, world_mat, scale_mat)
- camera_world = origin_to_world(
- n_points, camera_mat, world_mat, scale_mat)
+ pixel_world = image_points_to_world(pixels, camera_mat, world_mat, scale_mat)
+ camera_world = origin_to_world(n_points, camera_mat, world_mat, scale_mat)
ray_vector = (pixel_world - camera_world)
p_cube, d_cube, mask_cube = check_ray_intersection_with_unit_cube(
- camera_world, ray_vector, padding=padding, eps=eps)
+ camera_world, ray_vector, padding=padding, eps=eps
+ )
if not use_ray_length_as_depth:
- p_cam = transform_to_camera_space(p_cube.view(
- batch_size, -1, 3), camera_mat, world_mat, scale_mat).view(
- batch_size, n_points, -1, 3)
+ p_cam = transform_to_camera_space(
+ p_cube.view(batch_size, -1, 3), camera_mat, world_mat, scale_mat
+ ).view(batch_size, n_points, -1, 3)
d_cube = p_cam[:, :, :, -1]
return p_cube, d_cube, mask_cube
-def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),
- subsample_to=None):
+def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), subsample_to=None):
''' Arranges pixels for given resolution in range image_range.
The function returns the unscaled pixel locations as integers and the
@@ -266,9 +259,8 @@ def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),
# Arrange pixel location in scale resolution
pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h))
- pixel_locations = torch.stack(
- [pixel_locations[0], pixel_locations[1]],
- dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)
+ pixel_locations = torch.stack([pixel_locations[0], pixel_locations[1]],
+ dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)
pixel_scaled = pixel_locations.clone().float()
# Shift and scale points to match image_range
@@ -278,10 +270,8 @@ def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),
pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc
# Subsample points if subsample_to is not None and > 0
- if (subsample_to is not None and subsample_to > 0 and
- subsample_to < n_points):
- idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),
- replace=False)
+ if (subsample_to is not None and subsample_to > 0 and subsample_to < n_points):
+ idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to, ), replace=False)
pixel_scaled = pixel_scaled[:, idx]
pixel_locations = pixel_locations[:, idx]
@@ -342,15 +332,13 @@ def transform_pointcloud(pointcloud, transform):
transform (tensor): transformation of size 4 x 4
'''
- assert(transform.shape == (4, 4) and pointcloud.shape[-1] == 3)
+ assert (transform.shape == (4, 4) and pointcloud.shape[-1] == 3)
pcl, is_numpy = to_pytorch(pointcloud, True)
transform = to_pytorch(transform)
# Transform point cloud to homogen coordinate system
- pcl_hom = torch.cat([
- pcl, torch.ones(pcl.shape[0], 1)
- ], dim=-1).transpose(1, 0)
+ pcl_hom = torch.cat([pcl, torch.ones(pcl.shape[0], 1)], dim=-1).transpose(1, 0)
# Apply transformation to point cloud
pcl_hom_transformed = transform @ pcl_hom
@@ -371,13 +359,11 @@ def transform_points_batch(p, transform):
transform (tensor): transformation of size B x 4 x 4
'''
device = p.device
- assert(transform.shape[1:] == (4, 4) and p.shape[-1]
- == 3 and p.shape[0] == transform.shape[0])
+ assert (transform.shape[1:] == (4, 4) and p.shape[-1] == 3 and p.shape[0] == transform.shape[0])
# Transform points to homogen coordinates
- pcl_hom = torch.cat([
- p, torch.ones(p.shape[0], p.shape[1], 1).to(device)
- ], dim=-1).transpose(2, 1)
+ pcl_hom = torch.cat([p, torch.ones(p.shape[0], p.shape[1], 1).to(device)],
+ dim=-1).transpose(2, 1)
# Apply transformation
pcl_hom_transformed = transform @ pcl_hom
@@ -387,8 +373,9 @@ def transform_points_batch(p, transform):
return pcl_out
-def get_tensor_values(tensor, p, grid_sample=True, mode='nearest',
- with_mask=False, squeeze_channel_dim=False):
+def get_tensor_values(
+ tensor, p, grid_sample=True, mode='nearest', with_mask=False, squeeze_channel_dim=False
+):
'''
Returns values from tensor at given location p.
@@ -415,8 +402,7 @@ def get_tensor_values(tensor, p, grid_sample=True, mode='nearest',
p[:, :, 0] = (p[:, :, 0] + 1) * (w) / 2
p[:, :, 1] = (p[:, :, 1] + 1) * (h) / 2
p = p.long()
- values = tensor[torch.arange(batch_size).unsqueeze(-1), :, p[:, :, 1],
- p[:, :, 0]]
+ values = tensor[torch.arange(batch_size).unsqueeze(-1), :, p[:, :, 1], p[:, :, 0]]
if with_mask:
mask = get_mask(values)
@@ -436,8 +422,7 @@ def get_tensor_values(tensor, p, grid_sample=True, mode='nearest',
return values
-def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat,
- invert=True):
+def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat, invert=True):
''' Transforms pixel positions p with given depth value d to world coordinates.
Args:
@@ -448,7 +433,7 @@ def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat,
scale_mat (tensor): scale matrix
invert (bool): whether to invert matrices (default: true)
'''
- assert(pixels.shape[-1] == 2)
+ assert (pixels.shape[-1] == 2)
# Convert to pytorch
pixels, is_numpy = to_pytorch(pixels, True)
@@ -493,8 +478,8 @@ def transform_to_camera_space(p_world, camera_mat, world_mat, scale_mat):
device = p_world.device
# Transform world points to homogen coordinates
- p_world = torch.cat([p_world, torch.ones(
- batch_size, n_p, 1).to(device)], dim=-1).permute(0, 2, 1)
+ p_world = torch.cat([p_world, torch.ones(batch_size, n_p, 1).to(device)],
+ dim=-1).permute(0, 2, 1)
# Apply matrices to transform p_world to camera space
p_cam = camera_mat @ world_mat @ scale_mat @ p_world
@@ -536,8 +521,7 @@ def origin_to_world(n_points, camera_mat, world_mat, scale_mat, invert=True):
return p_world
-def image_points_to_world(image_points, camera_mat, world_mat, scale_mat,
- invert=True):
+def image_points_to_world(image_points, camera_mat, world_mat, scale_mat, invert=True):
''' Transforms points on image plane to world coordinates.
In contrast to transform_to_world, no depth value is needed as points on
@@ -551,12 +535,13 @@ def image_points_to_world(image_points, camera_mat, world_mat, scale_mat,
invert (bool): whether to invert matrices (default: true)
'''
batch_size, n_pts, dim = image_points.shape
- assert(dim == 2)
+ assert (dim == 2)
device = image_points.device
d_image = torch.ones(batch_size, n_pts, 1).to(device)
- return transform_to_world(image_points, d_image, camera_mat, world_mat,
- scale_mat, invert=invert)
+ return transform_to_world(
+ image_points, d_image, camera_mat, world_mat, scale_mat, invert=invert
+ )
def check_weights(params):
@@ -602,7 +587,7 @@ def get_logits_from_prob(probs, eps=1e-4):
probs (tensor): probability tensor
eps (float): epsilon value for numerical stability
'''
- probs = np.clip(probs, a_min=eps, a_max=1-eps)
+ probs = np.clip(probs, a_min=eps, a_max=1 - eps)
logits = np.log(probs / (1 - probs))
return logits
@@ -629,7 +614,7 @@ def chamfer_distance_naive(points1, points2):
points1 (numpy array): first point set
points2 (numpy array): second point set
'''
- assert(points1.size() == points2.size())
+ assert (points1.size() == points2.size())
batch_size, T, _ = points1.size()
points1 = points1.view(batch_size, T, 1, 3)
@@ -748,10 +733,16 @@ def make_3d_grid(bb_min, bb_max, shape):
return p
-def get_occupancy_loss_points(pixels, camera_mat, world_mat, scale_mat,
- depth_image=None, use_cube_intersection=True,
- occupancy_random_normal=False,
- depth_range=[0, 2.4]):
+def get_occupancy_loss_points(
+ pixels,
+ camera_mat,
+ world_mat,
+ scale_mat,
+ depth_image=None,
+ use_cube_intersection=True,
+ occupancy_random_normal=False,
+ depth_range=[0, 2.4]
+):
''' Returns 3D points for occupancy loss.
Args:
@@ -794,16 +785,19 @@ def get_occupancy_loss_points(pixels, camera_mat, world_mat, scale_mat,
if depth_image is not None:
depth_gt, mask_gt_depth = get_tensor_values(
- depth_image, pixels, squeeze_channel_dim=True, with_mask=True)
+ depth_image, pixels, squeeze_channel_dim=True, with_mask=True
+ )
d_occupancy[mask_gt_depth] = depth_gt[mask_gt_depth]
- p_occupancy = transform_to_world(pixels, d_occupancy.unsqueeze(-1),
- camera_mat, world_mat, scale_mat)
+ p_occupancy = transform_to_world(
+ pixels, d_occupancy.unsqueeze(-1), camera_mat, world_mat, scale_mat
+ )
return p_occupancy
-def get_freespace_loss_points(pixels, camera_mat, world_mat, scale_mat,
- use_cube_intersection=True, depth_range=[0, 2.4]):
+def get_freespace_loss_points(
+ pixels, camera_mat, world_mat, scale_mat, use_cube_intersection=True, depth_range=[0, 2.4]
+):
''' Returns 3D points for freespace loss.
Args:
@@ -832,7 +826,8 @@ def get_freespace_loss_points(pixels, camera_mat, world_mat, scale_mat,
device) * (d_cube[:, 1] - d_cube[:, 0])
p_freespace = transform_to_world(
- pixels, d_freespace.unsqueeze(-1), camera_mat, world_mat, scale_mat)
+ pixels, d_freespace.unsqueeze(-1), camera_mat, world_mat, scale_mat
+ )
return p_freespace
@@ -844,7 +839,6 @@ def normalize_tensor(tensor, min_norm=1e-5, feat_dim=-1):
min_norm (float): minimum norm for numerical stability
feat_dim (int): feature dimension in tensor (default: -1)
'''
- norm_tensor = torch.clamp(torch.norm(tensor, dim=feat_dim, keepdim=True),
- min=min_norm)
+ norm_tensor = torch.clamp(torch.norm(tensor, dim=feat_dim, keepdim=True), min=min_norm)
normed_tensor = tensor / norm_tensor
return normed_tensor
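
Since `common.py` also pulls in the compiled KDTree extension at import time, here is the clamped-norm normalization above restated standalone, with a worked value:

```python
import torch

def normalize_tensor(tensor, min_norm=1e-5, feat_dim=-1):
    # Same clamped-norm normalization as above, restated so it runs without
    # importing lib.pymafx.utils.common (which also requires libkdtree).
    norm_tensor = torch.clamp(torch.norm(tensor, dim=feat_dim, keepdim=True), min=min_norm)
    return tensor / norm_tensor

print(normalize_tensor(torch.tensor([[3.0, 4.0]])))   # tensor([[0.6000, 0.8000]])
```
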
diff --git a/lib/pymafx/utils/data_loader.py b/lib/pymafx/utils/data_loader.py
index 2c34d300c43f15d7de24460f19f1a6da7d483d60..cc92ad223836e9de322bc80bbab887bb9ec3f17b 100644
--- a/lib/pymafx/utils/data_loader.py
+++ b/lib/pymafx/utils/data_loader.py
@@ -3,47 +3,57 @@ import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
-class RandomSampler(Sampler):
+class RandomSampler(Sampler):
def __init__(self, data_source, checkpoint):
self.data_source = data_source
if checkpoint is not None and checkpoint['dataset_perm'] is not None:
self.dataset_perm = checkpoint['dataset_perm']
- self.perm = self.dataset_perm[checkpoint['batch_size']*checkpoint['batch_idx']:]
+ self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:]
else:
self.dataset_perm = torch.randperm(len(self.data_source)).tolist()
- self.perm = torch.randperm(len(self.data_source)).tolist()
+ self.perm = torch.randperm(len(self.data_source)).tolist()
def __iter__(self):
return iter(self.perm)
-
+
def __len__(self):
return len(self.perm)
-class SequentialSampler(Sampler):
+class SequentialSampler(Sampler):
def __init__(self, data_source, checkpoint):
self.data_source = data_source
if checkpoint is not None and checkpoint['dataset_perm'] is not None:
self.dataset_perm = checkpoint['dataset_perm']
- self.perm = self.dataset_perm[checkpoint['batch_size']*checkpoint['batch_idx']:]
+ self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:]
else:
self.dataset_perm = list(range(len(self.data_source)))
self.perm = self.dataset_perm
def __iter__(self):
return iter(self.perm)
-
+
def __len__(self):
return len(self.perm)
+
class CheckpointDataLoader(DataLoader):
"""
Extends torch.utils.data.DataLoader to handle resuming training from an arbitrary point within an epoch.
"""
- def __init__(self, dataset, checkpoint=None, batch_size=1,
- shuffle=False, num_workers=0, pin_memory=False, drop_last=True,
- timeout=0, worker_init_fn=None):
+ def __init__(
+ self,
+ dataset,
+ checkpoint=None,
+ batch_size=1,
+ shuffle=False,
+ num_workers=0,
+ pin_memory=False,
+ drop_last=True,
+ timeout=0,
+ worker_init_fn=None
+ ):
if shuffle:
sampler = RandomSampler(dataset, checkpoint)
@@ -54,5 +64,14 @@ class CheckpointDataLoader(DataLoader):
else:
self.checkpoint_batch_idx = 0
- super(CheckpointDataLoader, self).__init__(dataset, sampler=sampler, shuffle=False, batch_size=batch_size, num_workers=num_workers,
- drop_last=drop_last, pin_memory=pin_memory, timeout=timeout, worker_init_fn=None)
+ super(CheckpointDataLoader, self).__init__(
+ dataset,
+ sampler=sampler,
+ shuffle=False,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ timeout=timeout,
+ worker_init_fn=None
+ )
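
A minimal resume sketch for the loader above; the checkpoint dict layout (`dataset_perm`, `batch_size`, `batch_idx`) follows the keys read by `RandomSampler`, and the toy dataset is made up:

```python
import torch
from torch.utils.data import TensorDataset

from lib.pymafx.utils.data_loader import CheckpointDataLoader   # assumes repo root on PYTHONPATH

dataset = TensorDataset(torch.arange(20).float())

# Fresh epoch: a new random permutation is drawn.
loader = CheckpointDataLoader(dataset, checkpoint=None, batch_size=4, shuffle=True)
perm = loader.sampler.dataset_perm

# Resume mid-epoch: skip the first 2 batches of the saved permutation.
checkpoint = {'dataset_perm': perm, 'batch_size': 4, 'batch_idx': 2}
resumed = CheckpointDataLoader(dataset, checkpoint=checkpoint, batch_size=4, shuffle=True)
print(len(loader), len(resumed))   # 5 vs. 3 batches (drop_last=True)
```
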
diff --git a/lib/pymafx/utils/demo_utils.py b/lib/pymafx/utils/demo_utils.py
index 40ec3d576e9f93d81ac789f74c5b54feb711e0ae..b1ad8da91c7a7f6f67d4770c9866a02a78aa5275 100644
--- a/lib/pymafx/utils/demo_utils.py
+++ b/lib/pymafx/utils/demo_utils.py
@@ -46,8 +46,8 @@ def preprocess_video(video, joints2d, bboxes, frames, scale=1.0, crop_size=224):
if joints2d is not None:
bboxes, time_pt1, time_pt2 = get_all_bbox_params(joints2d, vis_thresh=0.3)
- bboxes[:,2:] = 150. / bboxes[:,2:]
- bboxes = np.stack([bboxes[:,0], bboxes[:,1], bboxes[:,2], bboxes[:,2]]).T
+ bboxes[:, 2:] = 150. / bboxes[:, 2:]
+ bboxes = np.stack([bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 2]]).T
video = video[time_pt1:time_pt2]
joints2d = joints2d[time_pt1:time_pt2]
@@ -66,11 +66,8 @@ def preprocess_video(video, joints2d, bboxes, frames, scale=1.0, crop_size=224):
j2d = joints2d[idx] if joints2d is not None else None
norm_img, raw_img, kp_2d = get_single_image_crop_demo(
- img,
- bbox,
- kp_2d=j2d,
- scale=scale,
- crop_size=crop_size)
+ img, bbox, kp_2d=j2d, scale=scale, crop_size=crop_size
+ )
if joints2d is not None:
joints2d[idx] = kp_2d
@@ -88,16 +85,16 @@ def download_youtube_clip(url, download_folder):
def smplify_runner(
- pred_rotmat,
- pred_betas,
- pred_cam,
- j2d,
- device,
- batch_size,
- lr=1.0,
- opt_steps=1,
- use_lbfgs=True,
- pose2aa=True
+ pred_rotmat,
+ pred_betas,
+ pred_cam,
+ j2d,
+ device,
+ batch_size,
+ lr=1.0,
+ opt_steps=1,
+ use_lbfgs=True,
+ pose2aa=True
):
smplify = TemporalSMPLify(
step_size=lr,
@@ -106,7 +103,7 @@ def smplify_runner(
focal_length=5000.,
use_lbfgs=use_lbfgs,
device=device,
- # max_iter=10,
+ # max_iter=10,
)
# Convert predicted rotation matrices to axis-angle
if pose2aa:
@@ -115,18 +112,16 @@ def smplify_runner(
pred_pose = pred_rotmat
# Calculate camera parameters for smplify
- pred_cam_t = torch.stack([
- pred_cam[:, 1], pred_cam[:, 2],
- 2 * 5000 / (224 * pred_cam[:, 0] + 1e-9)
- ], dim=-1)
+ pred_cam_t = torch.stack(
+ [pred_cam[:, 1], pred_cam[:, 2], 2 * 5000 / (224 * pred_cam[:, 0] + 1e-9)], dim=-1
+ )
gt_keypoints_2d_orig = j2d
# Before running compute reprojection error of the network
opt_joint_loss = smplify.get_fitting_loss(
- pred_pose.detach(), pred_betas.detach(),
- pred_cam_t.detach(),
- 0.5 * 224 * torch.ones(batch_size, 2, device=device),
- gt_keypoints_2d_orig).mean(dim=-1)
+ pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(),
+ 0.5 * 224 * torch.ones(batch_size, 2, device=device), gt_keypoints_2d_orig
+ ).mean(dim=-1)
best_prediction_id = torch.argmin(opt_joint_loss).item()
pred_betas = pred_betas[best_prediction_id].unsqueeze(0)
@@ -140,7 +135,8 @@ def smplify_runner(
# new_opt_pose, new_opt_betas, \
# new_opt_cam_t, \
output, new_opt_joint_loss = smplify(
- pred_pose.detach(), pred_betas.detach(),
+ pred_pose.detach(),
+ pred_betas.detach(),
pred_cam_t.detach(),
0.5 * 224 * torch.ones(batch_size, 2, device=device),
gt_keypoints_2d_orig,
@@ -152,29 +148,34 @@ def smplify_runner(
update = (new_opt_joint_loss < opt_joint_loss)
new_opt_vertices = output['verts']
- new_opt_cam_t = output['theta'][:,:3]
- new_opt_pose = output['theta'][:,3:75]
- new_opt_betas = output['theta'][:,75:]
+ new_opt_cam_t = output['theta'][:, :3]
+ new_opt_pose = output['theta'][:, 3:75]
+ new_opt_betas = output['theta'][:, 75:]
new_opt_joints3d = output['kp_3d']
return_val = [
- update, new_opt_vertices.cpu(), new_opt_cam_t.cpu(),
- new_opt_pose.cpu(), new_opt_betas.cpu(), new_opt_joints3d.cpu(),
- new_opt_joint_loss, opt_joint_loss,
+ update,
+ new_opt_vertices.cpu(),
+ new_opt_cam_t.cpu(),
+ new_opt_pose.cpu(),
+ new_opt_betas.cpu(),
+ new_opt_joints3d.cpu(),
+ new_opt_joint_loss,
+ opt_joint_loss,
]
return return_val
def trim_videos(filename, start_time, end_time, output_filename):
- command = ['ffmpeg',
- '-i', '"%s"' % filename,
- '-ss', str(start_time),
- '-t', str(end_time - start_time),
- '-c:v', 'libx264', '-c:a', 'copy',
- '-threads', '1',
- '-loglevel', 'panic',
- '"%s"' % output_filename]
+    command = [
+        'ffmpeg',
+        '-i', '"%s"' % filename,
+        '-ss', str(start_time),
+        '-t', str(end_time - start_time),
+        '-c:v', 'libx264', '-c:a', 'copy',
+        '-threads', '1', '-loglevel', 'panic', '"%s"' % output_filename
+    ]
# command = ' '.join(command)
subprocess.call(command)
@@ -187,11 +188,7 @@ def video_to_images(vid_file, img_folder=None, return_info=False):
print(img_folder)
os.makedirs(img_folder, exist_ok=True)
- command = ['ffmpeg',
- '-i', vid_file,
- '-f', 'image2',
- '-v', 'error',
- f'{img_folder}/%06d.png']
+ command = ['ffmpeg', '-i', vid_file, '-f', 'image2', '-v', 'error', f'{img_folder}/%06d.png']
print(f'Running \"{" ".join(command)}\"')
try:
@@ -236,8 +233,24 @@ def images_to_video(img_folder, output_vid_file):
os.makedirs(img_folder, exist_ok=True)
command = [
- 'ffmpeg', '-y', '-threads', '16', '-i', f'{img_folder}/%06d.png', '-profile:v', 'baseline',
- '-level', '3.0', '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-an', '-v', 'error', output_vid_file,
+ 'ffmpeg',
+ '-y',
+ '-threads',
+ '16',
+ '-i',
+ f'{img_folder}/%06d.png',
+ '-profile:v',
+ 'baseline',
+ '-level',
+ '3.0',
+ '-c:v',
+ 'libx264',
+ '-pix_fmt',
+ 'yuv420p',
+ '-an',
+ '-v',
+ 'error',
+ output_vid_file,
]
print(f'Running \"{" ".join(command)}\"')
@@ -257,12 +270,12 @@ def convert_crop_cam_to_orig_img(cam, bbox, img_width, img_height):
:param img_height (int): original image height
:return:
'''
- cx, cy, h = bbox[:,0], bbox[:,1], bbox[:,2]
+ cx, cy, h = bbox[:, 0], bbox[:, 1], bbox[:, 2]
hw, hh = img_width / 2., img_height / 2.
- sx = cam[:,0] * (1. / (img_width / h))
- sy = cam[:,0] * (1. / (img_height / h))
- tx = ((cx - hw) / hw / sx) + cam[:,1]
- ty = ((cy - hh) / hh / sy) + cam[:,2]
+ sx = cam[:, 0] * (1. / (img_width / h))
+ sy = cam[:, 0] * (1. / (img_height / h))
+ tx = ((cx - hw) / hw / sx) + cam[:, 1]
+ ty = ((cy - hh) / hh / sy) + cam[:, 2]
orig_cam = np.stack([sx, sy, tx, ty]).T
return orig_cam
@@ -272,19 +285,24 @@ def prepare_rendering_results(results_dict, nframes):
for person_id, person_data in results_dict.items():
for idx, frame_id in enumerate(person_data['frame_ids']):
frame_results[frame_id][person_id] = {
- 'verts': person_data['verts'][idx],
- 'smplx_verts': person_data['smplx_verts'][idx] if 'smplx_verts' in person_data else None,
- 'cam': person_data['orig_cam'][idx],
- 'cam_t': person_data['orig_cam_t'][idx] if 'orig_cam_t' in person_data else None,
- # 'cam': person_data['pred_cam'][idx],
+ 'verts':
+ person_data['verts'][idx],
+ 'smplx_verts':
+ person_data['smplx_verts'][idx] if 'smplx_verts' in person_data else None,
+ 'cam':
+ person_data['orig_cam'][idx],
+ 'cam_t':
+ person_data['orig_cam_t'][idx] if 'orig_cam_t' in person_data else None,
+ # 'cam': person_data['pred_cam'][idx],
}
# naive depth ordering based on the scale of the weak perspective camera
for frame_id, frame_data in enumerate(frame_results):
# sort based on y-scale of the cam in original image coords
- sort_idx = np.argsort([v['cam'][1] for k,v in frame_data.items()])
+ sort_idx = np.argsort([v['cam'][1] for k, v in frame_data.items()])
frame_results[frame_id] = OrderedDict(
- {list(frame_data.keys())[i]:frame_data[list(frame_data.keys())[i]] for i in sort_idx}
+ {list(frame_data.keys())[i]: frame_data[list(frame_data.keys())[i]]
+ for i in sort_idx}
)
return frame_results
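
The crop-to-original camera conversion above is plain NumPy; restated standalone with a toy bounding box to make the scaling explicit (all values are illustrative):

```python
import numpy as np

def convert_crop_cam_to_orig_img(cam, bbox, img_width, img_height):
    # Same formula as above: weak-perspective (s, tx, ty) in crop coordinates
    # -> (sx, sy, tx, ty) in original-image coordinates for square crops (cx, cy, h).
    cx, cy, h = bbox[:, 0], bbox[:, 1], bbox[:, 2]
    hw, hh = img_width / 2., img_height / 2.
    sx = cam[:, 0] * (1. / (img_width / h))
    sy = cam[:, 0] * (1. / (img_height / h))
    tx = ((cx - hw) / hw / sx) + cam[:, 1]
    ty = ((cy - hh) / hh / sy) + cam[:, 2]
    return np.stack([sx, sy, tx, ty]).T

cam = np.array([[1.0, 0.0, 0.0]])       # s, tx, ty predicted for a 224x224 crop
bbox = np.array([[320., 240., 224.]])   # crop center (cx, cy) and size h, made up
print(convert_crop_cam_to_orig_img(cam, bbox, img_width=640, img_height=480))
```
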
diff --git a/lib/pymafx/utils/densepose_methods.py b/lib/pymafx/utils/densepose_methods.py
index 3d12827674b12784a83da625b5e5fc50c1481e28..93fdf66a6651dcfe05f6e95c55379eaa00c52cb0 100644
--- a/lib/pymafx/utils/densepose_methods.py
+++ b/lib/pymafx/utils/densepose_methods.py
@@ -23,8 +23,9 @@ class DensePoseMethods:
self.All_vertices = ALP_UV['All_vertices'][0]
## Info to compute symmetries.
self.SemanticMaskSymmetries = [0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14]
- self.Index_Symmetry_List = [1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24,
- 23];
+ self.Index_Symmetry_List = [
+ 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23
+ ]
UV_symmetry_filename = os.path.join('./data/UV_data', 'UV_symmetry_transforms.mat')
self.UV_symmetry_transformations = loadmat(UV_symmetry_filename)
@@ -65,17 +66,17 @@ class DensePoseMethods:
vCrossW = np.cross(v, w)
vCrossU = np.cross(v, u)
if (np.dot(vCrossW, vCrossU) < 0):
- return False;
+ return False
#
uCrossW = np.cross(u, w)
uCrossV = np.cross(u, v)
#
if (np.dot(uCrossW, uCrossV) < 0):
- return False;
+ return False
#
- denom = np.sqrt((uCrossV ** 2).sum())
- r = np.sqrt((vCrossW ** 2).sum()) / denom
- t = np.sqrt((uCrossW ** 2).sum()) / denom
+ denom = np.sqrt((uCrossV**2).sum())
+ r = np.sqrt((vCrossW**2).sum()) / denom
+ t = np.sqrt((uCrossW**2).sum()) / denom
#
return ((r <= 1) & (t <= 1) & (r + t <= 1))
@@ -90,9 +91,9 @@ class DensePoseMethods:
uCrossW = np.cross(u, w)
uCrossV = np.cross(u, v)
#
- denom = np.sqrt((uCrossV ** 2).sum())
- r = np.sqrt((vCrossW ** 2).sum()) / denom
- t = np.sqrt((uCrossW ** 2).sum()) / denom
+ denom = np.sqrt((uCrossV**2).sum())
+ r = np.sqrt((vCrossW**2).sum()) / denom
+ t = np.sqrt((uCrossW**2).sum()) / denom
#
return (1 - (r + t), r, t)
@@ -101,12 +102,24 @@ class DensePoseMethods:
FaceIndicesNow = np.where(self.FaceIndices == I_point)
FacesNow = self.FacesDensePose[FaceIndicesNow]
#
- P_0 = np.vstack((self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0],
- np.zeros(self.U_norm[FacesNow][:, 0].shape))).transpose()
- P_1 = np.vstack((self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1],
- np.zeros(self.U_norm[FacesNow][:, 1].shape))).transpose()
- P_2 = np.vstack((self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2],
- np.zeros(self.U_norm[FacesNow][:, 2].shape))).transpose()
+ P_0 = np.vstack(
+ (
+ self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0],
+ np.zeros(self.U_norm[FacesNow][:, 0].shape)
+ )
+ ).transpose()
+ P_1 = np.vstack(
+ (
+ self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1],
+ np.zeros(self.U_norm[FacesNow][:, 1].shape)
+ )
+ ).transpose()
+ P_2 = np.vstack(
+ (
+ self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2],
+ np.zeros(self.U_norm[FacesNow][:, 2].shape)
+ )
+ ).transpose()
#
for i, [P0, P1, P2] in enumerate(zip(P_0, P_1, P_2)):
@@ -116,9 +129,12 @@ class DensePoseMethods:
#
# If the found UV is not inside any faces, select the vertex that is closest!
#
- D1 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_0[:, 0:2]).squeeze()
- D2 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_1[:, 0:2]).squeeze()
- D3 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_2[:, 0:2]).squeeze()
+ D1 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :],
+ P_0[:, 0:2]).squeeze()
+ D2 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :],
+ P_1[:, 0:2]).squeeze()
+ D3 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :],
+ P_2[:, 0:2]).squeeze()
#
minD1 = D1.min()
minD2 = D2.min()
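
The two cross-product helpers in the hunk above compute standard barycentric coordinates of a UV point inside a face; a standalone restatement with a quick numeric check (the wrapper name here is illustrative):

```python
import numpy as np

def barycentric_coords(P0, P1, P2, P):
    # Same cross-product construction as the helpers above.
    u, v, w = P1 - P0, P2 - P0, P - P0
    vCrossW, uCrossW, uCrossV = np.cross(v, w), np.cross(u, w), np.cross(u, v)
    denom = np.sqrt((uCrossV**2).sum())
    r = np.sqrt((vCrossW**2).sum()) / denom
    t = np.sqrt((uCrossW**2).sum()) / denom
    return (1 - (r + t), r, t)

tri = (np.array([0., 0., 0.]), np.array([1., 0., 0.]), np.array([0., 1., 0.]))
print(barycentric_coords(*tri, np.array([0.25, 0.25, 0.])))   # (0.5, 0.25, 0.25)
```
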
diff --git a/lib/pymafx/utils/geometry.py b/lib/pymafx/utils/geometry.py
index 804b08cf8dc3acc0f69b55ae672f3000b6e878b5..608288fc4d73a4918ab95938a7bf5dbe98ce606f 100644
--- a/lib/pymafx/utils/geometry.py
+++ b/lib/pymafx/utils/geometry.py
@@ -43,11 +43,13 @@ def quat_to_rotmat(quat):
wx, wy, wz = w * x, w * y, w * z
xy, xz, yz = x * y, x * z, y * z
- rotMat = torch.stack([
- w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
- 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
- ],
- dim=1).view(B, 3, 3)
+ rotMat = torch.stack(
+ [
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
+ 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
+ ],
+ dim=1
+ ).view(B, 3, 3)
return rotMat
@@ -74,7 +76,8 @@ def rotation_matrix_to_angle_axis(rotation_matrix):
if rotation_matrix.shape[1:] == (3, 3):
rot_mat = rotation_matrix.reshape(-1, 3, 3)
hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape(
- 1, 3, 1).expand(rot_mat.shape[0], -1, -1)
+ 1, 3, 1
+ ).expand(rot_mat.shape[0], -1, -1)
rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
@@ -109,8 +112,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))
if not quaternion.shape[-1] == 4:
- raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(
- quaternion.shape))
+ raise ValueError(
+ "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
+ )
# unpack input and compute conversion
q1: torch.Tensor = quaternion[..., 1]
q2: torch.Tensor = quaternion[..., 2]
@@ -119,8 +123,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
cos_theta: torch.Tensor = quaternion[..., 0]
- two_theta: torch.Tensor = 2.0 * torch.where(cos_theta < 0.0, torch.atan2(
- -sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta))
+ two_theta: torch.Tensor = 2.0 * torch.where(
+ cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta)
+ )
k_pos: torch.Tensor = two_theta / sin_theta
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
@@ -155,8 +160,9 @@ def quaternion_to_angle(quaternion: torch.Tensor) -> torch.Tensor:
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))
if not quaternion.shape[-1] == 4:
- raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(
- quaternion.shape))
+ raise ValueError(
+ "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
+ )
# unpack input and compute conversion
q1: torch.Tensor = quaternion[..., 1]
q2: torch.Tensor = quaternion[..., 2]
@@ -165,8 +171,9 @@ def quaternion_to_angle(quaternion: torch.Tensor) -> torch.Tensor:
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
cos_theta: torch.Tensor = quaternion[..., 0]
- theta: torch.Tensor = 2.0 * torch.where(cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta),
- torch.atan2(sin_theta, cos_theta))
+ theta: torch.Tensor = 2.0 * torch.where(
+ cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta)
+ )
# theta: torch.Tensor = 2.0 * torch.atan2(sin_theta, cos_theta)
@@ -202,8 +209,9 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))
if len(rotation_matrix.shape) > 3:
- raise ValueError("Input size must be a three dimensional tensor. Got {}".format(
- rotation_matrix.shape))
+ raise ValueError(
+ "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape)
+ )
# if not rotation_matrix.shape[-2:] == (3, 4):
# raise ValueError(
# "Input size must be a N x 3 x 4 tensor. Got {}".format(
@@ -217,31 +225,39 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
- q0 = torch.stack([
- rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
- rmat_t[:, 2, 0] + rmat_t[:, 0, 2]
- ], -1)
+ q0 = torch.stack(
+ [
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2]
+ ], -1
+ )
t0_rep = t0.repeat(4, 1).t()
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
- q1 = torch.stack([
- rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1,
- rmat_t[:, 1, 2] + rmat_t[:, 2, 1]
- ], -1)
+ q1 = torch.stack(
+ [
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1,
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1]
+ ], -1
+ )
t1_rep = t1.repeat(4, 1).t()
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
- q2 = torch.stack([
- rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
- rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2
- ], -1)
+ q2 = torch.stack(
+ [
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2
+ ], -1
+ )
t2_rep = t2.repeat(4, 1).t()
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
- q3 = torch.stack([
- t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
- rmat_t[:, 0, 1] - rmat_t[:, 1, 0]
- ], -1)
+ q3 = torch.stack(
+ [
+ t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]
+ ], -1
+ )
t3_rep = t3.repeat(4, 1).t()
mask_c0 = mask_d2 * mask_d0_d1
@@ -254,8 +270,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
- q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
- t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
+ q /= torch.sqrt(
+ t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
+ t2_rep * mask_c2 + t3_rep * mask_c3
+    )
q *= 0.5
return q
@@ -303,11 +321,13 @@ def quaternion_to_rotation_matrix(quat):
wx, wy, wz = w * x, w * y, w * z
xy, xz, yz = x * y, x * z, y * z
- rotMat = torch.stack([
- w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
- 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
- ],
- dim=1).view(B, 3, 3)
+ rotMat = torch.stack(
+ [
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
+ 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
+ ],
+ dim=1
+ ).view(B, 3, 3)
return rotMat
@@ -386,16 +406,18 @@ def projection(pred_joints, pred_camera, retain_z=False, iwp_mode=True):
if iwp_mode:
cam_sxy = pred_camera['cam_sxy']
pred_cam_t = torch.stack(
- [cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)], dim=-1)
+ [cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)], dim=-1
+ )
camera_center = torch.zeros(batch_size, 2)
- pred_keypoints_2d = perspective_projection(pred_joints,
- rotation=torch.eye(3).unsqueeze(0).expand(
- batch_size, -1, -1).to(pred_joints.device),
- translation=pred_cam_t,
- focal_length=5000.,
- camera_center=camera_center,
- retain_z=retain_z)
+ pred_keypoints_2d = perspective_projection(
+ pred_joints,
+ rotation=torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(pred_joints.device),
+ translation=pred_cam_t,
+ focal_length=5000.,
+ camera_center=camera_center,
+ retain_z=retain_z
+ )
# # Normalize keypoints to [-1,1]
# pred_keypoints_2d = pred_keypoints_2d / (224. / 2.)
else:
@@ -427,13 +449,15 @@ def projection(pred_joints, pred_camera, retain_z=False, iwp_mode=True):
return pred_keypoints_2d
-def perspective_projection(points,
- rotation,
- translation,
- focal_length=None,
- camera_center=None,
- cam_intrinsics=None,
- retain_z=False):
+def perspective_projection(
+ points,
+ rotation,
+ translation,
+ focal_length=None,
+ camera_center=None,
+ cam_intrinsics=None,
+ retain_z=False
+):
"""
This function computes the perspective projection of a set of points.
Input:
@@ -513,10 +537,12 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si
weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
# least squares
- Q = np.array([
- F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
- O - np.reshape(joints_2d, -1)
- ]).T
+ Q = np.array(
+ [
+ F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
+ O - np.reshape(joints_2d, -1)
+ ]
+ ).T
c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
# weighted least squares
@@ -570,11 +596,9 @@ def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_al
S_i = S[i]
joints_i = joints_2d[i]
conf_i = joints_conf[i]
- trans[i] = estimate_translation_np(S_i,
- joints_i,
- conf_i,
- focal_length=focal_length[i],
- img_size=img_size[i])
+ trans[i] = estimate_translation_np(
+ S_i, joints_i, conf_i, focal_length=focal_length[i], img_size=img_size[i]
+ )
return torch.from_numpy(trans).to(device)
@@ -585,8 +609,10 @@ def Rot_y(angle, category='torch', prepend_dim=True, device=None):
prepend_dim: prepend an extra dimension
Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
'''
- m = np.array([[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
- [-np.sin(angle), 0., np.cos(angle)]])
+    m = np.array(
+        [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
+         [-np.sin(angle), 0., np.cos(angle)]]
+    )
if category == 'torch':
if prepend_dim:
return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -608,8 +634,10 @@ def Rot_x(angle, category='torch', prepend_dim=True, device=None):
prepend_dim: prepend an extra dimension
Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
'''
- m = np.array([[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)],
- [0., np.sin(angle), np.cos(angle)]])
+    m = np.array(
+        [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)],
+         [0., np.sin(angle), np.cos(angle)]]
+    )
if category == 'torch':
if prepend_dim:
return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -631,8 +659,9 @@ def Rot_z(angle, category='torch', prepend_dim=True, device=None):
prepend_dim: prepend an extra dimension
Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
'''
- m = np.array([[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle),
- np.cos(angle), 0.], [0., 0., 1.]])
+ m = np.array(
+ [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]]
+ )
if category == 'torch':
if prepend_dim:
return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -674,7 +703,7 @@ def compute_twist_rotation(rotation_matrix, twist_axis):
twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)
twist_aa = quaternion_to_angle_axis(twist_quaternion)
- twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum(
- twist_axis, dim=1, keepdim=True)
+    twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / \
+        torch.sum(twist_axis, dim=1, keepdim=True)
return twist_rotation, twist_angle
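
A note on the camera math reformatted above: the recurring expression `2 * 5000. / (224. * s + 1e-9)` converts a weak-perspective scale into a depth translation so that `perspective_projection()` can use a fixed 5000-pixel focal length on a 224-pixel crop. A minimal sketch of that conversion and the pinhole projection it feeds (the helper names below are illustrative, not taken from the repository):

import torch

def weak_to_perspective_translation(cam_sxy, focal_length=5000.0, img_res=224.0):
    # cam_sxy: (B, 3) tensor holding [scale, tx, ty]
    s, tx, ty = cam_sxy[:, 0], cam_sxy[:, 1], cam_sxy[:, 2]
    tz = 2.0 * focal_length / (img_res * s + 1e-9)    # depth implied by the scale
    return torch.stack([tx, ty, tz], dim=-1)

def project(points, cam_t, focal_length=5000.0):
    # points: (B, N, 3) joints in a camera-aligned frame (identity rotation)
    cam_points = points + cam_t.unsqueeze(1)           # translate into the camera frame
    xy = cam_points[..., :2] / cam_points[..., 2:3]    # pinhole division
    return focal_length * xy                           # image-plane coords, principal point at 0

cam = torch.tensor([[1.0, 0.1, -0.05]])                # one sample of [s, tx, ty]
uv = project(torch.zeros(1, 17, 3), weak_to_perspective_translation(cam))    # (1, 17, 2)
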
diff --git a/lib/pymafx/utils/imutils.py b/lib/pymafx/utils/imutils.py
index ad75ff1d206ef344a98abcc2dc32fa4c7fcb6012..b3522fee118cf47c5101bfd8e16991e5c30f58ad 100644
--- a/lib/pymafx/utils/imutils.py
+++ b/lib/pymafx/utils/imutils.py
@@ -9,6 +9,7 @@ from PIL import Image
from lib.pymafx.core import constants
+
def get_transform(center, scale, res, rot=0):
"""Generate transformation matrix."""
h = 200 * scale
@@ -19,29 +20,31 @@ def get_transform(center, scale, res, rot=0):
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
t[2, 2] = 1
if not rot == 0:
- t = np.dot(get_rot_transf(res, rot),t)
+ t = np.dot(get_rot_transf(res, rot), t)
return t
+
def get_rot_transf(res, rot):
"""Generate rotation transformation matrix."""
if rot == 0:
return np.identity(3)
- rot = -rot # To match direction of rotation from cropping
- rot_mat = np.zeros((3,3))
+ rot = -rot # To match direction of rotation from cropping
+ rot_mat = np.zeros((3, 3))
rot_rad = rot * np.pi / 180
- sn,cs = np.sin(rot_rad), np.cos(rot_rad)
- rot_mat[0,:2] = [cs, -sn]
- rot_mat[1,:2] = [sn, cs]
- rot_mat[2,2] = 1
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0, :2] = [cs, -sn]
+ rot_mat[1, :2] = [sn, cs]
+ rot_mat[2, 2] = 1
# Need to rotate around center
t_mat = np.eye(3)
- t_mat[0,2] = -res[1]/2
- t_mat[1,2] = -res[0]/2
+ t_mat[0, 2] = -res[1] / 2
+ t_mat[1, 2] = -res[0] / 2
t_inv = t_mat.copy()
- t_inv[:2,2] *= -1
- rot_transf = np.dot(t_inv,np.dot(rot_mat,t_mat))
+ t_inv[:2, 2] *= -1
+ rot_transf = np.dot(t_inv, np.dot(rot_mat, t_mat))
return rot_transf
+
def transform(pt, center, scale, res, invert=0, rot=0):
"""Transform pixel location to different reference."""
t = get_transform(center, scale, res, rot=rot)
@@ -51,6 +54,7 @@ def transform(pt, center, scale, res, invert=0, rot=0):
new_pt = np.dot(t, new_pt)
return new_pt[:2].astype(int) + 1
+
def transform_pts(coords, center, scale, res, invert=0, rot=0):
"""Transform coordinates (N x 2) to different reference."""
new_coords = coords.copy()
@@ -58,14 +62,14 @@ def transform_pts(coords, center, scale, res, invert=0, rot=0):
new_coords[p, 0:2] = transform(coords[p, 0:2], center, scale, res, invert, rot)
return new_coords
+
def crop(img, center, scale, res, rot=0):
"""Crop image according to the supplied bounding box."""
# Upper left point
- ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
# Bottom right point
- br = np.array(transform([res[0]+1,
- res[1]+1], center, scale, res, invert=1))-1
-
+ br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
+
# Padding so that when rotated proper amount of context is included
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
if not rot == 0:
@@ -84,8 +88,7 @@ def crop(img, center, scale, res, rot=0):
old_x = max(0, ul[0]), min(len(img[0]), br[0])
old_y = max(0, ul[1]), min(len(img), br[1])
- new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
- old_x[0]:old_x[1]]
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
if not rot == 0:
# Remove padding
@@ -95,15 +98,16 @@ def crop(img, center, scale, res, rot=0):
new_img_resized = np.array(Image.fromarray(new_img.astype(np.uint8)).resize(res))
return new_img_resized, new_img, new_shape
+
def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
"""'Undo' the image cropping/resizing.
This function is used when evaluating mask/part segmentation.
"""
res = img.shape[:2]
# Upper left point
- ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
# Bottom right point
- br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
+ br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
# size of cropped image
crop_shape = [br[1] - ul[1], br[0] - ul[0]]
@@ -121,19 +125,24 @@ def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
return new_img
+
def rot_aa(aa, rot):
"""Rotate axis angle parameters."""
# pose parameters
- R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
- [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
- [0, 0, 1]])
+ R = np.array(
+ [
+ [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], [0, 0, 1]
+ ]
+ )
# find the rotation of the body in camera frame
per_rdg, _ = cv2.Rodrigues(aa)
# apply the global rotation to the global orientation
- resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
aa = (resrot.T)[0]
return aa
+
def flip_img(img):
"""Flip rgb images or masks.
channels come last, e.g. (256,256,3).
@@ -141,6 +150,7 @@ def flip_img(img):
img = np.fliplr(img)
return img
+
def flip_kp(kp, is_smpl=False, type='body'):
"""Flip keypoints."""
assert type in ['body', 'hand', 'face', 'feet']
@@ -164,11 +174,12 @@ def flip_kp(kp, is_smpl=False, type='body'):
flipped_parts = constants.FACE_FLIP_PERM
elif type == 'feet':
flipped_parts = constants.FEEF_FLIP_PERM
-
+
kp = kp[flipped_parts]
- kp[:,0] = - kp[:,0]
+ kp[:, 0] = -kp[:, 0]
return kp
+
def flip_pose(pose):
"""Flip pose.
The flipping is based on SMPL parameters.
@@ -180,6 +191,7 @@ def flip_pose(pose):
pose[2::3] = -pose[2::3]
return pose
+
def flip_aa(pose):
"""Flip aa.
"""
@@ -194,6 +206,7 @@ def flip_aa(pose):
raise NotImplementedError
return pose
+
def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
# Normalize keypoints between -1, 1
if not inv:
@@ -201,10 +214,11 @@ def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
kp_2d = 2.0 * kp_2d * ratio - 1.0
else:
ratio = 1.0 / crop_size
- kp_2d = (kp_2d + 1.0)/(2*ratio)
+ kp_2d = (kp_2d + 1.0) / (2 * ratio)
return kp_2d
+
def j2d_processing(kp, transf):
"""Process gt 2D keypoints and apply transforms."""
# nparts = kp.shape[1]
@@ -212,9 +226,10 @@ def j2d_processing(kp, transf):
kp_pad = torch.cat([kp, torch.ones((bs, npart, 1)).to(kp)], dim=-1)
kp_new = torch.bmm(transf, kp_pad.transpose(1, 2))
kp_new = kp_new.transpose(1, 2)
- kp_new[:, :, :-1] = 2.*kp_new[:, :, :-1] / constants.IMG_RES - 1.
+ kp_new[:, :, :-1] = 2. * kp_new[:, :, :-1] / constants.IMG_RES - 1.
return kp_new[:, :, :2]
+
def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
'''
param joints: [num_joints, 3]
@@ -231,11 +246,9 @@ def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
target_weight = np.ones((num_joints, 1), dtype=np.float32)
if joints_vis is not None:
target_weight[:, 0] = joints_vis[:, 0]
- target = torch.zeros((num_joints,
- heatmap_size[1],
- heatmap_size[0]),
- dtype=torch.float32,
- device=cur_device)
+ target = torch.zeros(
+ (num_joints, heatmap_size[1], heatmap_size[0]), dtype=torch.float32, device=cur_device
+ )
tmp_size = sigma * 3
@@ -264,7 +277,7 @@ def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
y = x.unsqueeze(-1)
x0 = y0 = size // 2
# The gaussian is not normalized, we want the center value to equal 1
- g = torch.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
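
For orientation in the `imutils.py` hunks above: `get_transform()` builds a 3x3 affine that maps original-image pixels into a `res`-sized crop whose square bounding box has side `200 * scale`, while `transform(..., invert=1)` goes back the other way. A rough sketch of that mapping (the entries not visible in this hunk are reconstructed from the usual convention; rotation and the 1-based index offsets handled inside `transform()` are left out, and the helper name is illustrative):

import numpy as np

def crop_affine(center, scale, res):
    h = 200.0 * scale                                # bounding-box side length in source pixels
    t = np.zeros((3, 3))
    t[0, 0], t[1, 1] = res[1] / h, res[0] / h        # rescale to the crop resolution
    t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
    t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
    t[2, 2] = 1.0
    return t

T = crop_affine(center=(128, 96), scale=1.0, res=(224, 224))
pt = np.array([120.0, 80.0, 1.0])                    # homogeneous pixel in the original image
pt_crop = T @ pt                                     # where that pixel lands inside the crop
pt_back = np.linalg.inv(T) @ pt_crop                 # the invert=1 path goes back the same way
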
diff --git a/lib/pymafx/utils/io.py b/lib/pymafx/utils/io.py
index 3edb5227c3c58c060646770b8757b1bc61687a6b..0926624ddeb1eccf2e9c6393595acfd34a62e84d 100644
--- a/lib/pymafx/utils/io.py
+++ b/lib/pymafx/utils/io.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
-
"""IO utilities."""
from __future__ import absolute_import
@@ -28,7 +27,7 @@ import re
import sys
try:
from urllib.request import urlopen
-except ImportError: #python2
+except ImportError:  # python2
from urllib2 import urlopen
logger = logging.getLogger(__name__)
@@ -59,8 +58,8 @@ def cache_url(url_or_file, cache_dir):
# 'bucket: {}').format(_DETECTRON_S3_BASE_URL)
#
# cache_file_path = url.replace(_DETECTRON_S3_BASE_URL, cache_dir)
- Len_filename = len(url.split('/')[-1])
- BASE_URL = url[0:-Len_filename-1]
+ Len_filename = len(url.split('/')[-1])
+ BASE_URL = url[0:-Len_filename - 1]
#
cache_file_path = url.replace(BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
@@ -102,18 +101,13 @@ def _progress_bar(count, total):
percents = round(100.0 * count / float(total), 1)
bar = '=' * filled_len + '-' * (bar_len - filled_len)
- sys.stdout.write(
- ' [{}] {}% of {:.1f}MB file \r'.
- format(bar, percents, total / 1024 / 1024)
- )
+ sys.stdout.write(' [{}] {}% of {:.1f}MB file \r'.format(bar, percents, total / 1024 / 1024))
sys.stdout.flush()
if count >= total:
sys.stdout.write('\n')
-def download_url(
- url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar
-):
+def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
"""Download url and write it to dst_file_path.
Credit:
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
diff --git a/lib/pymafx/utils/iuvmap.py b/lib/pymafx/utils/iuvmap.py
index 7da02e6cdc1e656493f7566080f3108ca38cac87..7f7c25398e04e30b2b244d44badc83415d583852 100644
--- a/lib/pymafx/utils/iuvmap.py
+++ b/lib/pymafx/utils/iuvmap.py
@@ -9,11 +9,13 @@ def iuvmap_clean(U_uv, V_uv, Index_UV, AnnIndex=None):
recon_Index_UV = []
for i in range(Index_UV.size(1)):
if i == 0:
- recon_Index_UV_i = torch.min(F.threshold(Index_UV_max + 1, 0.5, 0),
- -F.threshold(-Index_UV_max - 1, -1.5, 0))
+ recon_Index_UV_i = torch.min(
+ F.threshold(Index_UV_max + 1, 0.5, 0), -F.threshold(-Index_UV_max - 1, -1.5, 0)
+ )
else:
- recon_Index_UV_i = torch.min(F.threshold(Index_UV_max, i - 0.5, 0),
- -F.threshold(-Index_UV_max, -i - 0.5, 0)) / float(i)
+ recon_Index_UV_i = torch.min(
+ F.threshold(Index_UV_max, i - 0.5, 0), -F.threshold(-Index_UV_max, -i - 0.5, 0)
+ ) / float(i)
recon_Index_UV.append(recon_Index_UV_i)
recon_Index_UV = torch.stack(recon_Index_UV, dim=1)
@@ -24,11 +26,13 @@ def iuvmap_clean(U_uv, V_uv, Index_UV, AnnIndex=None):
recon_Ann_Index = []
for i in range(AnnIndex.size(1)):
if i == 0:
- recon_Ann_Index_i = torch.min(F.threshold(AnnIndex_max + 1, 0.5, 0),
- -F.threshold(-AnnIndex_max - 1, -1.5, 0))
+ recon_Ann_Index_i = torch.min(
+ F.threshold(AnnIndex_max + 1, 0.5, 0), -F.threshold(-AnnIndex_max - 1, -1.5, 0)
+ )
else:
- recon_Ann_Index_i = torch.min(F.threshold(AnnIndex_max, i - 0.5, 0),
- -F.threshold(-AnnIndex_max, -i - 0.5, 0)) / float(i)
+ recon_Ann_Index_i = torch.min(
+ F.threshold(AnnIndex_max, i - 0.5, 0), -F.threshold(-AnnIndex_max, -i - 0.5, 0)
+ ) / float(i)
recon_Ann_Index.append(recon_Ann_Index_i)
recon_Ann_Index = torch.stack(recon_Ann_Index, dim=1)
@@ -66,8 +70,10 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N
for part_id in range(0, K):
CurrentU = U_uv[batch_id, part_id]
CurrentV = V_uv[batch_id, part_id]
- output[1, Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id]
- output[2, Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id]
+            output[1, Index_UV_max[batch_id] == part_id] = \
+                CurrentU[Index_UV_max[batch_id] == part_id]
+            output[2, Index_UV_max[batch_id] == part_id] = \
+                CurrentV[Index_UV_max[batch_id] == part_id]
if uv_rois is None:
outputs.append(output.unsqueeze(0))
@@ -88,12 +94,16 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N
new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)]
output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest')
paddingleft = int(0.5 * (heatmap_size - new_size[1]))
- output = F.pad(output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0))
+ output = F.pad(
+ output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0)
+ )
else:
new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size]
output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest')
paddingtop = int(0.5 * (heatmap_size - new_size[0]))
- output = F.pad(output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop))
+ output = F.pad(
+ output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop)
+ )
outputs.append(output)
@@ -105,8 +115,10 @@ def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24):
batch_size = uvimages.size(0)
uvimg_size = uvimages.size(-1)
- Index2mask = [[0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17], [16, 18], [19, 21],
- [20, 22], [23, 24]]
+ Index2mask = [
+ [0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17], [16, 18],
+ [19, 21], [20, 22], [23, 24]
+ ]
part_ind = torch.round(uvimages[:, 0, :, :] * n_part)
part_u = uvimages[:, 1, :, :]
@@ -117,12 +129,15 @@ def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24):
recon_Index_UV = []
recon_Ann_Index = []
- for i in range(n_part+1):
+ for i in range(n_part + 1):
if i == 0:
- recon_Index_UV_i = torch.min(F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0))
+ recon_Index_UV_i = torch.min(
+ F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0)
+ )
else:
- recon_Index_UV_i = torch.min(F.threshold(part_ind, i - 0.5, 0),
- -F.threshold(-part_ind, -i - 0.5, 0)) / float(i)
+ recon_Index_UV_i = torch.min(
+ F.threshold(part_ind, i - 0.5, 0), -F.threshold(-part_ind, -i - 0.5, 0)
+ ) / float(i)
recon_U_i = recon_Index_UV_i * part_u
recon_V_i = recon_Index_UV_i * part_v
@@ -192,8 +207,12 @@ def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24):
recon_U_roi_i = F.interpolate(recon_U_roi_i.unsqueeze(0), size=(M, M), mode='nearest')
recon_V_roi_i = F.interpolate(recon_V_roi_i.unsqueeze(0), size=(M, M), mode='nearest')
- recon_Index_UV_roi_i = F.interpolate(recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest')
- recon_Ann_Index_roi_i = F.interpolate(recon_Ann_Index_roi_i.unsqueeze(0), size=(M, M), mode='nearest')
+ recon_Index_UV_roi_i = F.interpolate(
+ recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest'
+ )
+ recon_Ann_Index_roi_i = F.interpolate(
+ recon_Ann_Index_roi_i.unsqueeze(0), size=(M, M), mode='nearest'
+ )
recon_U_roi.append(recon_U_roi_i)
recon_V_roi.append(recon_V_roi_i)
@@ -217,12 +236,15 @@ def seg_img2map(segimages, uv_rois=None, new_size=None, n_part=24):
recon_Index_UV = []
- for i in range(n_part+1):
+ for i in range(n_part + 1):
if i == 0:
- recon_Index_UV_i = torch.min(F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0))
+ recon_Index_UV_i = torch.min(
+ F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0)
+ )
else:
- recon_Index_UV_i = torch.min(F.threshold(part_ind, i - 0.5, 0),
- -F.threshold(-part_ind, -i - 0.5, 0)) / float(i)
+ recon_Index_UV_i = torch.min(
+ F.threshold(part_ind, i - 0.5, 0), -F.threshold(-part_ind, -i - 0.5, 0)
+ ) / float(i)
recon_Index_UV.append(recon_Index_UV_i)
@@ -262,7 +284,9 @@ def seg_img2map(segimages, uv_rois=None, new_size=None, n_part=24):
recon_Index_UV_roi_i = recon_Index_UV[i, :, h_margin:h_margin + h_size, :]
- recon_Index_UV_roi_i = F.interpolate(recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest')
+ recon_Index_UV_roi_i = F.interpolate(
+ recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest'
+ )
recon_Index_UV_roi.append(recon_Index_UV_roi_i)
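
The pattern `torch.min(F.threshold(x, i - 0.5, 0), -F.threshold(-x, -i - 0.5, 0)) / i` that recurs throughout `iuvmap.py` is a comparison-free way of building the per-part mask `(x == i)` from an integer-valued index map (the `+ 1` variant handles the background index 0). A small self-contained check of that equivalence:

import torch
import torch.nn.functional as F

part_ind = torch.tensor([[0., 1., 2.], [2., 2., 5.]])    # integer part indices stored as floats
i = 2
mask = torch.min(
    F.threshold(part_ind, i - 0.5, 0), -F.threshold(-part_ind, -i - 0.5, 0)
) / float(i)
assert torch.allclose(mask, (part_ind == i).float())     # 1 where part_ind == i, 0 elsewhere
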
diff --git a/lib/pymafx/utils/keypoints.py b/lib/pymafx/utils/keypoints.py
index b505616e14436bcfecdaf3b65a18255ce98b86b0..2ab223c2bef79518adc523da1606cfc331ef8251 100644
--- a/lib/pymafx/utils/keypoints.py
+++ b/lib/pymafx/utils/keypoints.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
-
"""Keypoint utilities (somewhat specific to COCO keypoints)."""
from __future__ import absolute_import
@@ -35,23 +34,9 @@ def get_keypoints():
# Keypoints are not available in the COCO json for the test split, so we
# provide them here.
keypoints = [
- 'nose',
- 'left_eye',
- 'right_eye',
- 'left_ear',
- 'right_ear',
- 'left_shoulder',
- 'right_shoulder',
- 'left_elbow',
- 'right_elbow',
- 'left_wrist',
- 'right_wrist',
- 'left_hip',
- 'right_hip',
- 'left_knee',
- 'right_knee',
- 'left_ankle',
- 'right_ankle'
+ 'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder',
+ 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
+ 'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
]
keypoint_flip_map = {
'left_eye': 'right_eye',
@@ -126,8 +111,7 @@ def heatmaps_to_keypoints(maps, rois):
# NCHW to NHWC for use with OpenCV
maps = np.transpose(maps, [0, 2, 3, 1])
min_size = cfg.KRCNN.INFERENCE_MIN_SIZE
- xy_preds = np.zeros(
- (len(rois), 4, cfg.KRCNN.NUM_KEYPOINTS), dtype=np.float32)
+ xy_preds = np.zeros((len(rois), 4, cfg.KRCNN.NUM_KEYPOINTS), dtype=np.float32)
for i in range(len(rois)):
if min_size > 0:
roi_map_width = int(np.maximum(widths_ceil[i], min_size))
@@ -138,8 +122,8 @@ def heatmaps_to_keypoints(maps, rois):
width_correction = widths[i] / roi_map_width
height_correction = heights[i] / roi_map_height
roi_map = cv2.resize(
- maps[i], (roi_map_width, roi_map_height),
- interpolation=cv2.INTER_CUBIC)
+ maps[i], (roi_map_width, roi_map_height), interpolation=cv2.INTER_CUBIC
+ )
# Bring back to CHW
roi_map = np.transpose(roi_map, [2, 0, 1])
roi_map_probs = scores_to_probs(roi_map.copy())
@@ -148,8 +132,7 @@ def heatmaps_to_keypoints(maps, rois):
pos = roi_map[k, :, :].argmax()
x_int = pos % w
y_int = (pos - x_int) // w
- assert (roi_map_probs[k, y_int, x_int] ==
- roi_map_probs[k, :, :].max())
+ assert (roi_map_probs[k, y_int, x_int] == roi_map_probs[k, :, :].max())
x = (x_int + 0.5) * width_correction
y = (y_int + 0.5) * height_correction
xy_preds[i, 0, k] = x + offset_x[i]
@@ -201,8 +184,8 @@ def keypoints_to_heatmap_labels(keypoints, rois):
valid_loc = np.logical_and(
np.logical_and(x >= 0, y >= 0),
- np.logical_and(
- x < cfg.KRCNN.HEATMAP_SIZE, y < cfg.KRCNN.HEATMAP_SIZE))
+ np.logical_and(x < cfg.KRCNN.HEATMAP_SIZE, y < cfg.KRCNN.HEATMAP_SIZE)
+ )
valid = np.logical_and(valid_loc, vis)
valid = valid.astype(np.int32)
@@ -234,9 +217,7 @@ def nms_oks(kp_predictions, rois, thresh):
while order.size > 0:
i = order[0]
keep.append(i)
- ovr = compute_oks(
- kp_predictions[i], rois[i], kp_predictions[order[1:]],
- rois[order[1:]])
+ ovr = compute_oks(kp_predictions[i], rois[i], kp_predictions[order[1:]], rois[order[1:]])
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
@@ -251,9 +232,9 @@ def compute_oks(src_keypoints, src_roi, dst_keypoints, dst_roi):
dst_roi: Nx4
"""
- sigmas = np.array([
- .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87,
- .87, .89, .89]) / 10.0
+ sigmas = np.array(
+ [.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]
+ ) / 10.0
vars = (sigmas * 2)**2
# area
@@ -313,9 +294,15 @@ def generate_3d_integral_preds_tensor(heatmaps, num_joints, x_dim, y_dim, z_dim)
accu_z = heatmaps.sum(dim=3)
accu_z = accu_z.sum(dim=3)
- accu_x = accu_x * torch.cuda.comm.broadcast(torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index])[0]
- accu_y = accu_y * torch.cuda.comm.broadcast(torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index])[0]
- accu_z = accu_z * torch.cuda.comm.broadcast(torch.arange(z_dim, dtype=torch.float32), devices=[accu_z.device.index])[0]
+ accu_x = accu_x * torch.cuda.comm.broadcast(
+ torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index]
+ )[0]
+ accu_y = accu_y * torch.cuda.comm.broadcast(
+ torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index]
+ )[0]
+ accu_z = accu_z * torch.cuda.comm.broadcast(
+ torch.arange(z_dim, dtype=torch.float32), devices=[accu_z.device.index]
+ )[0]
accu_x = accu_x.sum(dim=2, keepdim=True)
accu_y = accu_y.sum(dim=2, keepdim=True)
@@ -326,8 +313,12 @@ def generate_3d_integral_preds_tensor(heatmaps, num_joints, x_dim, y_dim, z_dim)
accu_x = heatmaps.sum(dim=2)
accu_y = heatmaps.sum(dim=3)
- accu_x = accu_x * torch.cuda.comm.broadcast(torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index])[0]
- accu_y = accu_y * torch.cuda.comm.broadcast(torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index])[0]
+ accu_x = accu_x * torch.cuda.comm.broadcast(
+ torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index]
+ )[0]
+ accu_y = accu_y * torch.cuda.comm.broadcast(
+ torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index]
+ )[0]
accu_x = accu_x.sum(dim=2, keepdim=True)
accu_y = accu_y.sum(dim=2, keepdim=True)
@@ -347,14 +338,18 @@ def softmax_integral_tensor(preds, num_joints, hm_width, hm_height, hm_depth=Non
# integrate heatmap into joint location
if output_3d:
- x, y, z = generate_3d_integral_preds_tensor(preds, num_joints, hm_width, hm_height, hm_depth)
+ x, y, z = generate_3d_integral_preds_tensor(
+ preds, num_joints, hm_width, hm_height, hm_depth
+ )
# x = x / float(hm_width) - 0.5
# y = y / float(hm_height) - 0.5
# z = z / float(hm_depth) - 0.5
preds = torch.cat((x, y, z), dim=2)
# preds = preds.reshape((preds.shape[0], num_joints * 3))
else:
- x, y, _ = generate_3d_integral_preds_tensor(preds, num_joints, hm_width, hm_height, z_dim=None)
+ x, y, _ = generate_3d_integral_preds_tensor(
+ preds, num_joints, hm_width, hm_height, z_dim=None
+ )
# x = x / float(hm_width) - 0.5
# y = y / float(hm_height) - 0.5
preds = torch.cat((x, y), dim=2)
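
For context on `generate_3d_integral_preds_tensor()` and `softmax_integral_tensor()` above: after a spatial softmax, the predicted keypoint coordinate is the expectation of the pixel index under the normalised heatmap (a soft-argmax). A CPU-only 2D sketch of that idea, without the `torch.cuda.comm.broadcast` calls used in the file:

import torch

def soft_argmax_2d(heatmaps):
    # heatmaps: (B, J, H, W) unnormalised scores
    b, j, h, w = heatmaps.shape
    probs = torch.softmax(heatmaps.reshape(b, j, -1), dim=-1).reshape(b, j, h, w)
    xs = torch.arange(w, dtype=torch.float32)
    ys = torch.arange(h, dtype=torch.float32)
    x = (probs.sum(dim=2) * xs).sum(dim=-1)    # marginal over rows, expectation over columns
    y = (probs.sum(dim=3) * ys).sum(dim=-1)    # marginal over columns, expectation over rows
    return torch.stack([x, y], dim=-1)         # (B, J, 2) in heatmap pixels

hm = torch.zeros(1, 1, 64, 64)
hm[0, 0, 40, 10] = 20.0                        # sharp peak at (x=10, y=40)
print(soft_argmax_2d(hm))                      # approximately [[[10., 40.]]]
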
diff --git a/lib/pymafx/utils/mesh_generation.py b/lib/pymafx/utils/mesh_generation.py
index 94943f2b2fb6dc33427418f406d65a1b96efafdd..2876209e7678d2906a84850208f6c288103d07c5 100644
--- a/lib/pymafx/utils/mesh_generation.py
+++ b/lib/pymafx/utils/mesh_generation.py
@@ -33,13 +33,21 @@ class Generator3D(object):
size for refinement process (we added this functionality in this
work)
'''
-
- def __init__(self, model, points_batch_size=100000,
- threshold=0.5, refinement_step=0, device=None,
- resolution0=16, upsampling_steps=3,
- with_normals=False, padding=0.1,
- simplify_nfaces=None, with_color=False,
- refine_max_faces=10000):
+ def __init__(
+ self,
+ model,
+ points_batch_size=100000,
+ threshold=0.5,
+ refinement_step=0,
+ device=None,
+ resolution0=16,
+ upsampling_steps=3,
+ with_normals=False,
+ padding=0.1,
+ simplify_nfaces=None,
+ with_color=False,
+ refine_max_faces=10000
+ ):
self.model = model.to(device)
self.points_batch_size = points_batch_size
self.refinement_step = refinement_step
@@ -68,8 +76,7 @@ class Generator3D(object):
kwargs = {}
c = self.model.encode_inputs(inputs)
- mesh = self.generate_from_latent(c, stats_dict=stats_dict,
- data=data, **kwargs)
+ mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data, **kwargs)
return mesh, stats_dict
@@ -95,8 +102,7 @@ class Generator3D(object):
return meshes
- def generate_pointcloud(self, mesh, data=None, n_points=2000000,
- scale_back=True):
+ def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True):
''' Generates a point cloud from the mesh.
Args:
@@ -117,8 +123,7 @@ class Generator3D(object):
pcl_out = trimesh.Trimesh(vertices=pcl, process=False)
return pcl_out
- def generate_from_latent(self, c=None, pl=None, stats_dict={}, data=None,
- **kwargs):
+ def generate_from_latent(self, c=None, pl=None, stats_dict={}, data=None, **kwargs):
''' Generates mesh from latent.
Args:
@@ -135,14 +140,11 @@ class Generator3D(object):
# Shortcut
if self.upsampling_steps == 0:
nx = self.resolution0
- pointsf = box_size * make_3d_grid(
- (-0.5,)*3, (0.5,)*3, (nx,)*3
- )
+ pointsf = box_size * make_3d_grid((-0.5, ) * 3, (0.5, ) * 3, (nx, ) * 3)
values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
value_grid = values.reshape(nx, nx, nx)
else:
- mesh_extractor = MISE(
- self.resolution0, self.upsampling_steps, threshold)
+ mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold)
points = mesh_extractor.query()
@@ -153,8 +155,7 @@ class Generator3D(object):
pointsf = 2 * pointsf / mesh_extractor.resolution
pointsf = box_size * (pointsf - 1.0)
# Evaluate model and update
- values = self.eval_points(
- pointsf, c, pl, **kwargs).cpu().numpy()
+ values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
values = values.astype(np.float64)
mesh_extractor.update(points, values)
@@ -203,17 +204,15 @@ class Generator3D(object):
threshold = np.log(self.threshold) - np.log(1. - self.threshold)
# Make sure that mesh is watertight
t0 = time.time()
- occ_hat_padded = np.pad(
- occ_hat, 1, 'constant', constant_values=-1e6)
- vertices, triangles = libmcubes.marching_cubes(
- occ_hat_padded, threshold)
+ occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
+ vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold)
stats_dict['time (marching cubes)'] = time.time() - t0
# Strange behaviour in libmcubes: vertices are shifted by 0.5
vertices -= 0.5
# Undo padding
vertices -= 1
# Normalize to bounding box
- vertices /= np.array([n_x-1, n_y-1, n_z-1])
+ vertices /= np.array([n_x - 1, n_y - 1, n_z - 1])
vertices *= 2
vertices = box_size * (vertices - 1)
@@ -228,10 +227,13 @@ class Generator3D(object):
else:
normals = None
# Create mesh
- mesh = trimesh.Trimesh(vertices, triangles,
- vertex_normals=normals,
- # vertex_colors=vertex_colors,
- process=False)
+ mesh = trimesh.Trimesh(
+ vertices,
+ triangles,
+ vertex_normals=normals,
+ # vertex_colors=vertex_colors,
+ process=False
+ )
# Directly return if mesh is empty
if vertices.shape[0] == 0:
@@ -255,9 +257,12 @@ class Generator3D(object):
vertex_colors = self.estimate_colors(np.array(mesh.vertices), c)
stats_dict['time (color)'] = time.time() - t0
mesh = trimesh.Trimesh(
- vertices=mesh.vertices, faces=mesh.faces,
+ vertices=mesh.vertices,
+ faces=mesh.faces,
vertex_normals=mesh.vertex_normals,
- vertex_colors=vertex_colors, process=False)
+ vertex_colors=vertex_colors,
+ process=False
+ )
return mesh
@@ -275,16 +280,15 @@ class Generator3D(object):
for vi in vertices_split:
vi = vi.to(device)
with torch.no_grad():
- ci = self.model.decode_color(
- vi.unsqueeze(0), c).squeeze(0).cpu()
+ ci = self.model.decode_color(vi.unsqueeze(0), c).squeeze(0).cpu()
colors.append(ci)
colors = np.concatenate(colors, axis=0)
colors = np.clip(colors, 0, 1)
colors = (colors * 255).astype(np.uint8)
- colors = np.concatenate([
- colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)],
- axis=1)
+ colors = np.concatenate(
+ [colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], axis=1
+ )
return colors
def estimate_normals(self, vertices, c=None):
@@ -328,7 +332,7 @@ class Generator3D(object):
# Some shorthands
n_x, n_y, n_z = occ_hat.shape
- assert(n_x == n_y == n_z)
+ assert (n_x == n_y == n_z)
# threshold = np.log(self.threshold) - np.log(1. - self.threshold)
threshold = self.threshold
@@ -348,8 +352,7 @@ class Generator3D(object):
# Dataset
ds_faces = TensorDataset(faces)
- dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces,
- shuffle=True)
+ dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, shuffle=True)
# We updated the refinement algorithm to subsample faces; this is
# usefull when using a high extraction resolution / when working on
@@ -372,13 +375,16 @@ class Generator3D(object):
face_normal = face_normal / \
(face_normal.norm(dim=1, keepdim=True) + 1e-10)
- face_value = torch.cat([
- torch.sigmoid(self.model.decode(p_split, c).logits)
- for p_split in torch.split(
- face_point.unsqueeze(0), 20000, dim=1)], dim=1)
+ face_value = torch.cat(
+ [
+ torch.sigmoid(self.model.decode(p_split, c).logits)
+ for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1)
+ ],
+ dim=1
+ )
- normal_target = -autograd.grad(
- [face_value.sum()], [face_point], create_graph=True)[0]
+ normal_target = -autograd.grad([face_value.sum()], [face_point],
+ create_graph=True)[0]
normal_target = \
normal_target / \
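
One detail worth spelling out from the hunks above: `extract_mesh()` compares raw occupancy logits against `np.log(t) - np.log(1 - t)`, which is the same as thresholding `sigmoid(logits)` at `t`, so marching cubes can run on the padded logit grid directly. A quick check:

import numpy as np

t = 0.5
logit_threshold = np.log(t) - np.log(1.0 - t)              # 0.0 when t = 0.5
logits = np.array([-2.0, -0.1, 0.3, 4.0])
occupied_via_logits = logits > logit_threshold
occupied_via_probs = 1.0 / (1.0 + np.exp(-logits)) > t     # sigmoid then threshold
assert np.array_equal(occupied_via_logits, occupied_via_probs)
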
diff --git a/lib/pymafx/utils/part_utils.py b/lib/pymafx/utils/part_utils.py
index 88bcf06d713ee0cbf8d673c16e1acaed9313ccab..12f0de443fa11e90674761816a644cf82a48a786 100644
--- a/lib/pymafx/utils/part_utils.py
+++ b/lib/pymafx/utils/part_utils.py
@@ -5,6 +5,7 @@ from core import path_config
from models import SMPL
+
class PartRenderer():
"""Renderer used to render segmentation masks and part segmentations.
Internally it uses the Neural 3D Mesh Renderer
@@ -14,40 +15,57 @@ class PartRenderer():
self.focal_length = focal_length
self.render_res = render_res
# We use Neural 3D mesh renderer for rendering masks and part segmentations
- self.neural_renderer = nr.Renderer(dist_coeffs=None, orig_size=self.render_res,
- image_size=render_res,
- light_intensity_ambient=1,
- light_intensity_directional=0,
- anti_aliasing=False)
- self.faces = torch.from_numpy(SMPL(path_config.SMPL_MODEL_DIR).faces.astype(np.int32)).cuda()
+ self.neural_renderer = nr.Renderer(
+ dist_coeffs=None,
+ orig_size=self.render_res,
+ image_size=render_res,
+ light_intensity_ambient=1,
+ light_intensity_directional=0,
+ anti_aliasing=False
+ )
+ self.faces = torch.from_numpy(SMPL(path_config.SMPL_MODEL_DIR).faces.astype(np.int32)
+ ).cuda()
textures = np.load(path_config.VERTEX_TEXTURE_FILE)
self.textures = torch.from_numpy(textures).cuda().float()
self.cube_parts = torch.cuda.FloatTensor(np.load(path_config.CUBE_PARTS_FILE))
def get_parts(self, parts, mask):
"""Process renderer part image to get body part indices."""
- bn,c,h,w = parts.shape
- mask = mask.view(-1,1)
- parts_index = torch.floor(100*parts.permute(0,2,3,1).contiguous().view(-1,3)).long()
- parts = self.cube_parts[parts_index[:,0], parts_index[:,1], parts_index[:,2], None]
+ bn, c, h, w = parts.shape
+ mask = mask.view(-1, 1)
+ parts_index = torch.floor(100 * parts.permute(0, 2, 3, 1).contiguous().view(-1, 3)).long()
+ parts = self.cube_parts[parts_index[:, 0], parts_index[:, 1], parts_index[:, 2], None]
parts *= mask
- parts = parts.view(bn,h,w).long()
+ parts = parts.view(bn, h, w).long()
return parts
def __call__(self, vertices, camera):
"""Wrapper function for rendering process."""
# Estimate camera parameters given a fixed focal length
- cam_t = torch.stack([camera[:,1], camera[:,2], 2*self.focal_length/(self.render_res * camera[:,0] +1e-9)],dim=-1)
+ cam_t = torch.stack(
+ [
+ camera[:, 1], camera[:, 2], 2 * self.focal_length /
+ (self.render_res * camera[:, 0] + 1e-9)
+ ],
+ dim=-1
+ )
batch_size = vertices.shape[0]
K = torch.eye(3, device=vertices.device)
- K[0,0] = self.focal_length
- K[1,1] = self.focal_length
- K[2,2] = 1
- K[0,2] = self.render_res / 2.
- K[1,2] = self.render_res / 2.
+ K[0, 0] = self.focal_length
+ K[1, 1] = self.focal_length
+ K[2, 2] = 1
+ K[0, 2] = self.render_res / 2.
+ K[1, 2] = self.render_res / 2.
K = K[None, :, :].expand(batch_size, -1, -1)
R = torch.eye(3, device=vertices.device)[None, :, :].expand(batch_size, -1, -1)
faces = self.faces[None, :, :].expand(batch_size, -1, -1)
- parts, _, mask = self.neural_renderer(vertices, faces, textures=self.textures.expand(batch_size, -1, -1, -1, -1, -1), K=K, R=R, t=cam_t.unsqueeze(1))
+ parts, _, mask = self.neural_renderer(
+ vertices,
+ faces,
+ textures=self.textures.expand(batch_size, -1, -1, -1, -1, -1),
+ K=K,
+ R=R,
+ t=cam_t.unsqueeze(1)
+ )
parts = self.get_parts(parts, mask)
- return mask, parts
\ No newline at end of file
+ return mask, parts
diff --git a/lib/pymafx/utils/pose_tracker.py b/lib/pymafx/utils/pose_tracker.py
index 5028bb5706b9f1e9e3ccf656734650434b8c82ab..92c383cdb3dba6053a0595b9f03305c02e9fc277 100644
--- a/lib/pymafx/utils/pose_tracker.py
+++ b/lib/pymafx/utils/pose_tracker.py
@@ -23,10 +23,10 @@ import os.path as osp
def run_openpose(
- video_file,
- output_folder,
- staf_folder,
- vis=False,
+ video_file,
+ output_folder,
+ staf_folder,
+ vis=False,
):
pwd = os.getcwd()
@@ -35,13 +35,10 @@ def run_openpose(
render = 1 if vis else 0
display = 2 if vis else 0
cmd = [
- 'build/examples/openpose/openpose.bin',
- '--model_pose', 'BODY_21A',
- '--tracking', '1',
- '--render_pose', str(render),
- '--video', video_file,
- '--write_json', output_folder,
- '--display', str(display)
+        'build/examples/openpose/openpose.bin',
+        '--model_pose', 'BODY_21A', '--tracking', '1', '--render_pose', str(render),
+        '--video', video_file, '--write_json', output_folder,
+        '--display', str(display)
]
print('Executing', ' '.join(cmd))
@@ -59,7 +56,7 @@ def read_posetrack_keypoints(output_folder):
# print(idx, data)
for person in data['people']:
person_id = person['person_id'][0]
- joints2d = person['pose_keypoints_2d']
+ joints2d = person['pose_keypoints_2d']
if person_id in people.keys():
people[person_id]['joints2d'].append(joints2d)
people[person_id]['frames'].append(idx)
@@ -72,7 +69,9 @@ def read_posetrack_keypoints(output_folder):
people[person_id]['frames'].append(idx)
for k in people.keys():
- people[k]['joints2d'] = np.array(people[k]['joints2d']).reshape((len(people[k]['joints2d']), -1, 3))
+ people[k]['joints2d'] = np.array(people[k]['joints2d']).reshape(
+ (len(people[k]['joints2d']), -1, 3)
+ )
people[k]['frames'] = np.array(people[k]['frames'])
return people
@@ -80,20 +79,14 @@ def read_posetrack_keypoints(output_folder):
def run_posetracker(video_file, staf_folder, posetrack_output_folder='/tmp', display=False):
posetrack_output_folder = os.path.join(
- posetrack_output_folder,
- f'{os.path.basename(video_file)}_posetrack'
+ posetrack_output_folder, f'{os.path.basename(video_file)}_posetrack'
)
# run posetrack on video
- run_openpose(
- video_file,
- posetrack_output_folder,
- vis=display,
- staf_folder=staf_folder
- )
+ run_openpose(video_file, posetrack_output_folder, vis=display, staf_folder=staf_folder)
people_dict = read_posetrack_keypoints(posetrack_output_folder)
shutil.rmtree(posetrack_output_folder)
- return people_dict
\ No newline at end of file
+ return people_dict
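
For reference, `read_posetrack_keypoints()` gathers the flat `pose_keypoints_2d` lists that OpenPose writes per person (x, y, confidence triplets) and reshapes them to `(num_frames, num_joints, 3)`. A toy illustration of that layout (the joint count here is arbitrary, not tied to the BODY_21A model):

import numpy as np

frames_of_flat_joints = [
    [10.0, 20.0, 0.9] * 25,    # frame 0: every joint at a dummy location
    [11.0, 21.0, 0.8] * 25,    # frame 1
]
joints2d = np.array(frames_of_flat_joints).reshape((len(frames_of_flat_joints), -1, 3))
assert joints2d.shape == (2, 25, 3)    # (frames, joints, x/y/confidence)
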
diff --git a/lib/pymafx/utils/pose_utils.py b/lib/pymafx/utils/pose_utils.py
index f74bfd6668cb6214e4414cab095c00aee26e7314..55eb1d771376da71c864a715d1dd6b5d66e9894e 100644
--- a/lib/pymafx/utils/pose_utils.py
+++ b/lib/pymafx/utils/pose_utils.py
@@ -7,6 +7,7 @@ from __future__ import print_function
import numpy as np
import torch
+
def compute_similarity_transform(S1, S2):
"""
Computes a similarity transform (sR, t) that takes
@@ -19,7 +20,7 @@ def compute_similarity_transform(S1, S2):
S1 = S1.T
S2 = S2.T
transposed = True
- assert(S2.shape[1] == S1.shape[1])
+ assert (S2.shape[1] == S1.shape[1])
# 1. Remove mean.
mu1 = S1.mean(axis=1, keepdims=True)
@@ -47,16 +48,17 @@ def compute_similarity_transform(S1, S2):
scale = np.trace(R.dot(K)) / var1
# 6. Recover translation.
- t = mu2 - scale*(R.dot(mu1))
+ t = mu2 - scale * (R.dot(mu1))
# 7. Error:
- S1_hat = scale*R.dot(S1) + t
+ S1_hat = scale * R.dot(S1) + t
if transposed:
S1_hat = S1_hat.T
return S1_hat
+
def compute_similarity_transform_batch(S1, S2):
"""Batched version of compute_similarity_transform."""
S1_hat = np.zeros_like(S1)
@@ -64,10 +66,11 @@ def compute_similarity_transform_batch(S1, S2):
S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
return S1_hat
+
def reconstruction_error(S1, S2, reduction='mean'):
"""Do Procrustes alignment and compute reconstruction error."""
S1_hat = compute_similarity_transform_batch(S1, S2)
- re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
+ re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
if reduction == 'mean':
re = re.mean()
elif reduction == 'sum':
@@ -113,6 +116,7 @@ def axis_angle_add(theta, roll_axis, alpha):
return c_n
+
def axis_angle_add_np(theta, roll_axis, alpha):
"""Composition of two axis-angle rotations (NumPy version)
Args:
@@ -145,4 +149,4 @@ def axis_angle_add_np(theta, roll_axis, alpha):
c_sin = np.sin(c_angle * 0.5)
c_n = (c_angle / c_sin) * c_sin_n
- return c_n
\ No newline at end of file
+ return c_n
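
The `compute_similarity_transform()` touched above is the Procrustes step behind PA-MPJPE: it finds the scale, rotation and translation that best map `S1` onto `S2` before the joint error is measured. A compact NumPy sketch of the same closed-form solution, written for row-stacked (N, 3) point sets rather than the transposed layout used in the file:

import numpy as np

def similarity_align(S1, S2):
    mu1, mu2 = S1.mean(0), S2.mean(0)
    X1, X2 = S1 - mu1, S2 - mu2
    var1 = (X1 ** 2).sum()
    K = X1.T @ X2                                     # 3x3 cross-covariance
    U, _, Vt = np.linalg.svd(K)
    Z = np.eye(3)
    Z[-1, -1] = np.sign(np.linalg.det(U @ Vt))        # keep a proper rotation (det = +1)
    R = Vt.T @ Z @ U.T
    scale = np.trace(R @ K) / var1
    t = mu2 - scale * (R @ mu1)
    return scale * S1 @ R.T + t

S1 = np.random.randn(17, 3)
angle = 0.3
R_gt = np.array([[np.cos(angle), -np.sin(angle), 0.],
                 [np.sin(angle), np.cos(angle), 0.],
                 [0., 0., 1.]])
S2 = 1.7 * S1 @ R_gt.T + np.array([0.1, -0.2, 0.3])
err = np.sqrt(((similarity_align(S1, S2) - S2) ** 2).sum(-1)).mean()
assert err < 1e-6                                     # exact recovery for a clean similarity
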
diff --git a/lib/pymafx/utils/renderer.py b/lib/pymafx/utils/renderer.py
index 032deb76ef2690cdb046e67ff5d1680741dfab3a..9fb19568680b839f93c00a5288c94a5a52025242 100644
--- a/lib/pymafx/utils/renderer.py
+++ b/lib/pymafx/utils/renderer.py
@@ -34,33 +34,20 @@ from pytorch3d.structures.meshes import Meshes
# from pytorch3d.renderer.mesh.renderer import MeshRendererWithFragments
from pytorch3d.renderer import (
- look_at_view_transform,
- FoVPerspectiveCameras,
- PerspectiveCameras,
- AmbientLights,
- PointLights,
- RasterizationSettings,
- BlendParams,
- MeshRenderer,
- MeshRasterizer,
- SoftPhongShader,
- SoftSilhouetteShader,
- HardPhongShader,
- HardGouraudShader,
- HardFlatShader,
- TexturesVertex
+ look_at_view_transform, FoVPerspectiveCameras, PerspectiveCameras, AmbientLights, PointLights,
+ RasterizationSettings, BlendParams, MeshRenderer, MeshRasterizer, SoftPhongShader,
+ SoftSilhouetteShader, HardPhongShader, HardGouraudShader, HardFlatShader, TexturesVertex
)
import logging
+
logger = logging.getLogger(__name__)
+
class WeakPerspectiveCamera(pyrender.Camera):
- def __init__(self,
- scale,
- translation,
- znear=pyrender.camera.DEFAULT_Z_NEAR,
- zfar=None,
- name=None):
+ def __init__(
+ self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None
+ ):
super(WeakPerspectiveCamera, self).__init__(
znear=znear,
zfar=zfar,
@@ -80,21 +67,22 @@ class WeakPerspectiveCamera(pyrender.Camera):
class PyRenderer:
- def __init__(self, resolution=(224,224), orig_img=False, wireframe=False, scale_ratio=1., vis_ratio=1.):
+ def __init__(
+ self, resolution=(224, 224), orig_img=False, wireframe=False, scale_ratio=1., vis_ratio=1.
+ ):
self.resolution = (resolution[0] * scale_ratio, resolution[1] * scale_ratio)
# self.scale_ratio = scale_ratio
- self.faces = {'smplx': get_model_faces('smplx'),
- 'smpl': get_model_faces('smpl'),
- # 'mano': get_model_faces('mano'),
- # 'flame': get_model_faces('flame'),
- }
+ self.faces = {
+ 'smplx': get_model_faces('smplx'),
+ 'smpl': get_model_faces('smpl'),
+ # 'mano': get_model_faces('mano'),
+ # 'flame': get_model_faces('flame'),
+ }
self.orig_img = orig_img
self.wireframe = wireframe
self.renderer = pyrender.OffscreenRenderer(
- viewport_width=self.resolution[0],
- viewport_height=self.resolution[1],
- point_size=1.0
+ viewport_width=self.resolution[0], viewport_height=self.resolution[1], point_size=1.0
)
self.vis_ratio = vis_ratio
@@ -104,7 +92,7 @@ class PyRenderer:
light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=1)
- yrot = np.radians(120) # angle of lights
+ yrot = np.radians(120) # angle of lights
light_pose = np.eye(4)
light_pose[:3, 3] = [0, -1, 1]
@@ -116,8 +104,9 @@ class PyRenderer:
light_pose[:3, 3] = [1, 1, 2]
self.scene.add(light, pose=light_pose)
- spot_l = pyrender.SpotLight(color=np.ones(3), intensity=15.0,
- innerConeAngle=np.pi/3, outerConeAngle=np.pi/2)
+ spot_l = pyrender.SpotLight(
+ color=np.ones(3), intensity=15.0, innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2
+ )
light_pose[:3, 3] = [1, 2, 2]
self.scene.add(spot_l, pose=light_pose)
@@ -135,17 +124,34 @@ class PyRenderer:
'red': np.array([0.5, 0.2, 0.2]),
'pink': np.array([0.7, 0.5, 0.5]),
'neutral': np.array([0.7, 0.7, 0.6]),
- # 'purple': np.array([0.5, 0.5, 0.7]),
+ # 'purple': np.array([0.5, 0.5, 0.7]),
'purple': np.array([0.55, 0.4, 0.9]),
'green': np.array([0.5, 0.55, 0.3]),
'sky': np.array([0.3, 0.5, 0.55]),
'white': np.array([1.0, 0.98, 0.94]),
}
- def __call__(self, verts, faces=None, img=np.zeros((224, 224, 3)), cam=np.array([1, 0, 0]),
- focal_length=[5000, 5000], camera_rotation=np.eye(3), crop_info=None,
- angle=None, axis=None, mesh_filename=None, color_type=None, color=[1.0, 1.0, 0.9], iwp_mode=True, crop_img=True, mesh_type='smpl', scale_ratio=1., rgba_mode=False):
-
+ def __call__(
+ self,
+ verts,
+ faces=None,
+ img=np.zeros((224, 224, 3)),
+ cam=np.array([1, 0, 0]),
+ focal_length=[5000, 5000],
+ camera_rotation=np.eye(3),
+ crop_info=None,
+ angle=None,
+ axis=None,
+ mesh_filename=None,
+ color_type=None,
+ color=[1.0, 1.0, 0.9],
+ iwp_mode=True,
+ crop_img=True,
+ mesh_type='smpl',
+ scale_ratio=1.,
+ rgba_mode=False
+ ):
+
if faces is None:
faces = self.faces[mesh_type]
mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
@@ -166,24 +172,28 @@ class PyRenderer:
if len(cam) == 4:
sx, sy, tx, ty = cam
# sy = sx
- camera_translation = np.array([tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)])
+ camera_translation = np.array(
+ [tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
+ )
elif len(cam) == 3:
sx, tx, ty = cam
sy = sx
- camera_translation = np.array([- tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)])
+ camera_translation = np.array(
+ [-tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
+ )
render_res = resolution
self.renderer.viewport_width = render_res[1]
self.renderer.viewport_height = render_res[0]
else:
if crop_info['opt_cam_t'] is None:
camera_translation = convert_to_full_img_cam(
- pare_cam=cam[None],
- bbox_height=crop_info['bbox_scale'] * 200.,
- bbox_center=crop_info['bbox_center'],
- img_w=crop_info['img_w'],
- img_h=crop_info['img_h'],
- focal_length=focal_length[0],
- )
+ pare_cam=cam[None],
+ bbox_height=crop_info['bbox_scale'] * 200.,
+ bbox_center=crop_info['bbox_center'],
+ img_w=crop_info['img_w'],
+ img_h=crop_info['img_h'],
+ focal_length=focal_length[0],
+ )
else:
camera_translation = crop_info['opt_cam_t']
if torch.is_tensor(camera_translation):
@@ -197,8 +207,9 @@ class PyRenderer:
self.renderer.viewport_width = render_res[1]
self.renderer.viewport_height = render_res[0]
camera_rotation = camera_rotation.T
- camera = pyrender.IntrinsicsCamera(fx=focal_length[0], fy=focal_length[1],
- cx=render_res[1]/2., cy=render_res[0]/2.)
+ camera = pyrender.IntrinsicsCamera(
+ fx=focal_length[0], fy=focal_length[1], cx=render_res[1] / 2., cy=render_res[0] / 2.
+ )
if color_type != None:
color = self.colors_dict[color_type]
@@ -237,9 +248,14 @@ class PyRenderer:
for item in image_list:
if scale_ratio != 1:
orig_size = item.shape[:2]
- item = resize(item, (orig_size[0] * scale_ratio, orig_size[1] * scale_ratio), anti_aliasing=True)
+ item = resize(
+ item, (orig_size[0] * scale_ratio, orig_size[1] * scale_ratio),
+ anti_aliasing=True
+ )
item = (item * 255).astype(np.uint8)
- output_img = rgb[:, :, :-1] * valid_mask * self.vis_ratio + (1 - valid_mask * self.vis_ratio) * item
+ output_img = rgb[:, :, :-1] * valid_mask * self.vis_ratio + (
+ 1 - valid_mask * self.vis_ratio
+ ) * item
# output_img[valid_mask < 0.5] = item[valid_mask < 0.5]
# if scale_ratio != 1:
# output_img = resize(output_img, (orig_size[0], orig_size[1]), anti_aliasing=True)
@@ -253,7 +269,7 @@ class PyRenderer:
return_img.append(item)
if type(img) is not list:
- # if scale_ratio == 1:
+ # if scale_ratio == 1:
return_img = return_img[0]
self.scene.remove_node(mesh_node)
@@ -267,9 +283,12 @@ class OpenDRenderer:
self.resolution = (resolution[0] * ratio, resolution[1] * ratio)
self.ratio = ratio
self.focal_length = 5000.
- self.K = np.array([[self.focal_length, 0., self.resolution[1] / 2.],
- [0., self.focal_length, self.resolution[0] / 2.],
- [0., 0., 1.]])
+ self.K = np.array(
+ [
+ [self.focal_length, 0., self.resolution[1] / 2.],
+ [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.]
+ ]
+ )
self.colors_dict = {
'red': np.array([0.5, 0.2, 0.2]),
'pink': np.array([0.7, 0.5, 0.5]),
@@ -281,16 +300,29 @@ class OpenDRenderer:
}
self.renderer = ColoredRenderer()
self.faces = get_smpl_faces()
-
+
def reset_res(self, resolution):
self.resolution = (resolution[0] * self.ratio, resolution[1] * self.ratio)
- self.K = np.array([[self.focal_length, 0., self.resolution[1] / 2.],
- [0., self.focal_length, self.resolution[0] / 2.],
- [0., 0., 1.]])
+ self.K = np.array(
+ [
+ [self.focal_length, 0., self.resolution[1] / 2.],
+ [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.]
+ ]
+ )
- def __call__(self, verts, faces=None, color=None, color_type='white', R=None, mesh_filename=None,
- img=np.zeros((224, 224, 3)), cam=np.array([1, 0, 0]),
- rgba=False, addlight=True):
+ def __call__(
+ self,
+ verts,
+ faces=None,
+ color=None,
+ color_type='white',
+ R=None,
+ mesh_filename=None,
+ img=np.zeros((224, 224, 3)),
+ cam=np.array([1, 0, 0]),
+ rgba=False,
+ addlight=True
+ ):
'''Render mesh using OpenDR
verts: shape - (V, 3)
faces: shape - (F, 3)
@@ -307,18 +339,18 @@ class OpenDRenderer:
f = np.array([K[0, 0], K[1, 1]])
c = np.array([K[0, 2], K[1, 2]])
-
+
if faces is None:
faces = self.faces
if len(cam) == 4:
t = np.array([cam[2], cam[3], 2 * K[0, 0] / (w * cam[0] + 1e-9)])
elif len(cam) == 3:
t = np.array([cam[1], cam[2], 2 * K[0, 0] / (w * cam[0] + 1e-9)])
-
+
rn.camera = ProjectPoints(rt=np.array([0, 0, 0]), t=t, f=f, c=c, k=np.zeros(5))
rn.frustum = {'near': 1., 'far': 1000., 'width': w, 'height': h}
- albedo = np.ones_like(verts)*.9
+ albedo = np.ones_like(verts) * .9
if color is not None:
color0 = np.array(color)
@@ -343,7 +375,7 @@ class OpenDRenderer:
rn.set(v=verts, f=faces, vc=color, bgcolor=np.zeros(3))
if addlight:
- yrot = np.radians(120) # angle of lights
+ yrot = np.radians(120) # angle of lights
# # 1. 1. 0.7
rn.vc = LambertianPointLight(
f=rn.f,
@@ -351,7 +383,8 @@ class OpenDRenderer:
num_verts=len(rn.v),
light_pos=rotateY(np.array([-200, -100, -100]), yrot),
vc=albedo,
- light_color=color0)
+ light_color=color0
+ )
# Construct Left Light
rn.vc += LambertianPointLight(
@@ -360,7 +393,8 @@ class OpenDRenderer:
num_verts=len(rn.v),
light_pos=rotateY(np.array([800, 10, 300]), yrot),
vc=albedo,
- light_color=color1)
+ light_color=color1
+ )
# Construct Right Light
rn.vc += LambertianPointLight(
@@ -369,7 +403,8 @@ class OpenDRenderer:
num_verts=len(rn.v),
light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
vc=albedo,
- light_color=color2)
+ light_color=color2
+ )
rendered_image = rn.r
visibility_image = rn.visibility_image
@@ -379,12 +414,16 @@ class OpenDRenderer:
return_img = []
for item in image_list:
if self.ratio != 1:
- img_resized = resize(item, (item.shape[0] * self.ratio, item.shape[1] * self.ratio), anti_aliasing=True)
+ img_resized = resize(
+ item, (item.shape[0] * self.ratio, item.shape[1] * self.ratio),
+ anti_aliasing=True
+ )
else:
img_resized = item / 255.
try:
- img_resized[visibility_image != (2**32 - 1)] = rendered_image[visibility_image != (2**32 - 1)]
+                img_resized[visibility_image != (2**32 - 1)] = \
+                    rendered_image[visibility_image != (2**32 - 1)]
except:
logger.warning('Can not render mesh.')
@@ -407,34 +446,40 @@ class OpenDRenderer:
# https://github.com/classner/up/blob/master/up_tools/camera.py
def rotateY(points, angle):
"""Rotate all points in a 2D array around the y axis."""
- ry = np.array([
- [np.cos(angle), 0., np.sin(angle)],
- [0., 1., 0. ],
- [-np.sin(angle), 0., np.cos(angle)]
- ])
+    ry = np.array(
+        [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
+         [-np.sin(angle), 0., np.cos(angle)]]
+    )
return np.dot(points, ry)
-def rotateX( points, angle ):
+
+def rotateX(points, angle):
"""Rotate all points in a 2D array around the x axis."""
- rx = np.array([
- [1., 0., 0. ],
- [0., np.cos(angle), -np.sin(angle)],
- [0., np.sin(angle), np.cos(angle) ]
- ])
+    rx = np.array(
+        [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)],
+         [0., np.sin(angle), np.cos(angle)]]
+    )
return np.dot(points, rx)
-def rotateZ( points, angle ):
+
+def rotateZ(points, angle):
"""Rotate all points in a 2D array around the z axis."""
- rz = np.array([
- [np.cos(angle), -np.sin(angle), 0. ],
- [np.sin(angle), np.cos(angle), 0. ],
- [0., 0., 1. ]
- ])
+ rz = np.array(
+ [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]]
+ )
return np.dot(points, rz)
class IUV_Renderer(object):
- def __init__(self, focal_length=5000., orig_size=224, output_size=56, mode='iuv', device=torch.device('cuda'), mesh_type='smpl'):
+ def __init__(
+ self,
+ focal_length=5000.,
+ orig_size=224,
+ output_size=56,
+ mode='iuv',
+ device=torch.device('cuda'),
+ mesh_type='smpl'
+ ):
self.focal_length = focal_length
self.orig_size = orig_size
@@ -449,7 +494,9 @@ class IUV_Renderer(object):
faces = DP.FacesDensePose
faces = faces[None, :, :]
- self.faces = torch.from_numpy(faces.astype(np.int32)) # [1, 13774, 3], torch.int32
+ self.faces = torch.from_numpy(
+ faces.astype(np.int32)
+ ) # [1, 13774, 3], torch.int32
num_part = float(np.max(DP.FaceIndices))
self.num_part = num_part
@@ -468,13 +515,22 @@ class IUV_Renderer(object):
np.save(dp_vert_pid_fname, np.array(dp_vert_pid))
textures_vts = np.array(
- [(dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i]) for i in
- range(len(vert_mapping))])
- self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32)) # (1, 7829, 3)
+ [
+ (dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i])
+ for i in range(len(vert_mapping))
+ ]
+ )
+ self.textures_vts = torch.from_numpy(
+ textures_vts[None].astype(np.float32)
+ ) # (1, 7829, 3)
elif mode == 'pncc':
self.vert_mapping = None
- self.faces = torch.from_numpy(get_model_faces(mesh_type)[None].astype(np.int32)) # mano: torch.Size([1, 1538, 3])
- textures_vts = get_model_tpose(mesh_type).unsqueeze(0) # mano: torch.Size([1, 778, 3])
+ self.faces = torch.from_numpy(
+ get_model_faces(mesh_type)[None].astype(np.int32)
+ ) # mano: torch.Size([1, 1538, 3])
+ textures_vts = get_model_tpose(mesh_type).unsqueeze(
+ 0
+ ) # mano: torch.Size([1, 778, 3])
texture_min = torch.min(textures_vts) - 0.001
texture_range = torch.max(textures_vts) - texture_min + 0.001
@@ -485,7 +541,11 @@ class IUV_Renderer(object):
self.faces = torch.from_numpy(get_smpl_faces().astype(np.int32)[None])
- with open(os.path.join(path_config.SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), 'rb') as json_file:
+ with open(
+ os.path.join(
+ path_config.SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)
+ ), 'rb'
+ ) as json_file:
smpl_part_id = json.load(json_file)
v_id = []
@@ -509,9 +569,12 @@ class IUV_Renderer(object):
# range(n_verts)])
self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32))
- K = np.array([[self.focal_length, 0., self.orig_size / 2.],
- [0., self.focal_length, self.orig_size / 2.],
- [0., 0., 1.]])
+ K = np.array(
+ [
+ [self.focal_length, 0., self.orig_size / 2.],
+ [0., self.focal_length, self.orig_size / 2.], [0., 0., 1.]
+ ]
+ )
R = np.array([[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]])
@@ -540,26 +603,27 @@ class IUV_Renderer(object):
raster_settings = RasterizationSettings(
image_size=output_size,
- blur_radius=0,
+ blur_radius=0,
faces_per_pixel=1,
)
self.renderer = MeshRenderer(
- rasterizer=MeshRasterizer(
- raster_settings=raster_settings
- ),
- shader=HardFlatShader(
- device=self.device,
- lights=lights,
- blend_params=BlendParams(background_color=[0, 0, 0], sigma=0.0, gamma=0.0)
- )
+ rasterizer=MeshRasterizer(raster_settings=raster_settings),
+ shader=HardFlatShader(
+ device=self.device,
+ lights=lights,
+ blend_params=BlendParams(background_color=[0, 0, 0], sigma=0.0, gamma=0.0)
)
+ )
def camera_matrix(self, cam):
batch_size = cam.size(0)
K = self.K.repeat(batch_size, 1, 1)
R = self.R.repeat(batch_size, 1, 1)
- t = torch.stack([-cam[:, 1], -cam[:, 2], 2 * self.focal_length/(self.orig_size * cam[:, 0] + 1e-9)], dim=-1)
+ t = torch.stack(
+ [-cam[:, 1], -cam[:, 2], 2 * self.focal_length / (self.orig_size * cam[:, 0] + 1e-9)],
+ dim=-1
+ )
if cam.is_cuda:
# device_id = cam.get_device()
@@ -580,9 +644,18 @@ class IUV_Renderer(object):
vertices = verts[:, self.vert_mapping, :]
mesh = Meshes(vertices, self.faces.to(verts.device).expand(batch_size, -1, -1))
- mesh.textures = TexturesVertex(verts_features=self.textures_vts.to(verts.device).expand(batch_size, -1, -1))
+ mesh.textures = TexturesVertex(
+ verts_features=self.textures_vts.to(verts.device).expand(batch_size, -1, -1)
+ )
- cameras = PerspectiveCameras(device=verts.device, R=R, T=t, K=K, in_ndc=False, image_size=[(self.orig_size, self.orig_size)])
+ cameras = PerspectiveCameras(
+ device=verts.device,
+ R=R,
+ T=t,
+ K=K,
+ in_ndc=False,
+ image_size=[(self.orig_size, self.orig_size)]
+ )
iuv_image = self.renderer(mesh, cameras=cameras)
iuv_image = iuv_image[..., :3].permute(0, 3, 1, 2)
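
Aside: the translation built in `camera_matrix` above converts a weak-perspective camera (scale, tx, ty) into a full-perspective translation, with depth tz = 2f / (orig_size * s). Below is a minimal standalone sketch of that conversion (not part of the patch; `iwp_cam_to_translation` is a name introduced here for illustration).

import torch

def iwp_cam_to_translation(cam, focal_length=5000., orig_size=224):
    # cam: (B, 3) tensor of (scale, trans_x, trans_y); a larger scale places
    # the mesh closer to the camera, i.e. at a smaller depth t_z.
    tz = 2 * focal_length / (orig_size * cam[:, 0] + 1e-9)   # 1e-9 guards s == 0
    return torch.stack([-cam[:, 1], -cam[:, 2], tz], dim=-1)

print(iwp_cam_to_translation(torch.tensor([[1.0, 0.05, -0.02]])))
# approximately [[-0.05, 0.02, 44.64]]
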
diff --git a/lib/pymafx/utils/sample_mesh.py b/lib/pymafx/utils/sample_mesh.py
index 9d8833cf0642394ba6ff9ba86bad11c278ebdd01..2599bee12d2577b6826ea8bfad8c937f2bcc2db2 100644
--- a/lib/pymafx/utils/sample_mesh.py
+++ b/lib/pymafx/utils/sample_mesh.py
@@ -3,7 +3,17 @@ import trimesh
import numpy as np
from .utils.libmesh import check_mesh_contains
-def get_occ_gt(in_path=None, vertices=None, faces=None, pts_num=1000, points_sigma=0.01, with_dp=False, points=None, extra_points=None):
+
+def get_occ_gt(
+ in_path=None,
+ vertices=None,
+ faces=None,
+ pts_num=1000,
+ points_sigma=0.01,
+ with_dp=False,
+ points=None,
+ extra_points=None
+):
if in_path is not None:
mesh = trimesh.load(in_path, process=False)
print(type(mesh.vertices), mesh.vertices.shape, mesh.faces.shape)
@@ -27,7 +37,7 @@ def get_occ_gt(in_path=None, vertices=None, faces=None, pts_num=1000, points_sig
points_surface, index_surface = mesh.sample(n_points_surface, return_index=True)
points_surface += points_sigma * np.random.randn(n_points_surface, 3)
points = np.concatenate([points_uniform, points_surface], axis=0)
-
+
if extra_points is not None:
extra_points += points_sigma * np.random.randn(len(extra_points), 3)
points = np.concatenate([points, extra_points], axis=0)
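
Aside: `get_occ_gt` mixes uniform box samples with Gaussian-jittered surface samples before labelling them inside/outside the mesh. The sketch below reproduces that sampling pattern under stated assumptions; `trimesh.Trimesh.contains` stands in for the repository's `check_mesh_contains`, and `sample_occupancy` is a hypothetical helper name.

import numpy as np
import trimesh

def sample_occupancy(mesh, pts_num=1000, points_sigma=0.01, padding=0.1):
    n_uniform = pts_num // 2
    n_surface = pts_num - n_uniform
    # uniform samples inside a slightly enlarged bounding box
    lo, hi = mesh.bounds
    size, center = (hi - lo) * (1 + padding), (hi + lo) / 2
    points_uniform = center + (np.random.rand(n_uniform, 3) - 0.5) * size
    # surface samples jittered with Gaussian noise of std points_sigma
    points_surface, _ = trimesh.sample.sample_surface(mesh, n_surface)
    points_surface = points_surface + points_sigma * np.random.randn(n_surface, 3)
    points = np.concatenate([points_uniform, points_surface], axis=0)
    occupancy = mesh.contains(points)        # boolean inside/outside labels
    return points, occupancy.astype(np.float32)

points, occ = sample_occupancy(trimesh.creation.icosphere())
print(points.shape, occ.mean())              # (1000, 3) and the inside ratio
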
diff --git a/lib/pymafx/utils/saver.py b/lib/pymafx/utils/saver.py
index 417db9bc59684579a4f0d9778b8c5fd251a2d8f1..6a6bd3a184cc658dbc666ad2dcf3bc15d8cc427b 100644
--- a/lib/pymafx/utils/saver.py
+++ b/lib/pymafx/utils/saver.py
@@ -3,8 +3,10 @@ import os
import torch
import datetime
import logging
+
logger = logging.getLogger(__name__)
+
class CheckpointSaver():
"""Class that handles saving and loading checkpoints during training."""
def __init__(self, save_dir, save_steps=1000, overwrite=False):
@@ -22,26 +24,41 @@ class CheckpointSaver():
return False if self.latest_checkpoint is None else True
else:
return os.path.isfile(checkpoint_file)
-
- def save_checkpoint(self, models, optimizers, epoch, batch_idx, batch_size,
- total_step_count, is_best=False, save_by_step=False, interval=5, with_optimizer=True):
+
+ def save_checkpoint(
+ self,
+ models,
+ optimizers,
+ epoch,
+ batch_idx,
+ batch_size,
+ total_step_count,
+ is_best=False,
+ save_by_step=False,
+ interval=5,
+ with_optimizer=True
+ ):
"""Save checkpoint."""
timestamp = datetime.datetime.now()
if self.overwrite:
checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_latest.pt'))
elif save_by_step:
- checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, '{:08d}.pt'.format(total_step_count)))
+ checkpoint_filename = os.path.abspath(
+ os.path.join(self.save_dir, '{:08d}.pt'.format(total_step_count))
+ )
else:
if epoch % interval == 0:
- checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt'))
+ checkpoint_filename = os.path.abspath(
+ os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt')
+ )
else:
checkpoint_filename = None
-
+
checkpoint = {}
for model in models:
model_dict = models[model].state_dict()
for k in list(model_dict.keys()):
- if '.smpl.' in k:
+ if '.smpl.' in k:
del model_dict[k]
checkpoint[model] = model_dict
if with_optimizer:
@@ -56,7 +73,7 @@ class CheckpointSaver():
if checkpoint_filename is not None:
torch.save(checkpoint, checkpoint_filename)
print('Saving checkpoint file [' + checkpoint_filename + ']')
- if is_best: # save the best
+ if is_best: # save the best
checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt'))
torch.save(checkpoint, checkpoint_filename)
print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)
@@ -64,7 +81,6 @@ class CheckpointSaver():
torch.save(checkpoint, checkpoint_filename)
print('Saved checkpoint file [' + checkpoint_filename + ']')
-
def load_checkpoint(self, models, optimizers, checkpoint_file=None):
"""Load a checkpoint."""
if checkpoint_file is None:
@@ -74,8 +90,10 @@ class CheckpointSaver():
for model in models:
if model in checkpoint:
model_dict = models[model].state_dict()
- pretrained_dict = {k: v for k, v in checkpoint[model].items()
- if k in model_dict.keys()}
+ pretrained_dict = {
+ k: v
+ for k, v in checkpoint[model].items() if k in model_dict.keys()
+ }
model_dict.update(pretrained_dict)
models[model].load_state_dict(model_dict)
@@ -83,20 +101,23 @@ class CheckpointSaver():
for optimizer in optimizers:
if optimizer in checkpoint:
optimizers[optimizer].load_state_dict(checkpoint[optimizer])
- return {'epoch': checkpoint['epoch'],
- 'batch_idx': checkpoint['batch_idx'],
- 'batch_size': checkpoint['batch_size'],
- 'total_step_count': checkpoint['total_step_count']}
+ return {
+ 'epoch': checkpoint['epoch'],
+ 'batch_idx': checkpoint['batch_idx'],
+ 'batch_size': checkpoint['batch_size'],
+ 'total_step_count': checkpoint['total_step_count']
+ }
def get_latest_checkpoint(self):
"""Get filename of latest checkpoint if it exists."""
- checkpoint_list = []
+ checkpoint_list = []
for dirpath, dirnames, filenames in os.walk(self.save_dir):
for filename in filenames:
if filename.endswith('.pt'):
checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename)))
# sort
import re
+
def atof(text):
try:
retval = float(text)
@@ -111,8 +132,8 @@ class CheckpointSaver():
(See Toothy's implementation in the comments)
float regex comes from https://stackoverflow.com/a/12643073/190597
'''
- return [ atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text) ]
-
+ return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)]
+
checkpoint_list.sort(key=natural_keys)
- self.latest_checkpoint = None if (len(checkpoint_list) == 0) else checkpoint_list[-1]
+ self.latest_checkpoint = None if (len(checkpoint_list) == 0) else checkpoint_list[-1]
return
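
Aside: `get_latest_checkpoint` relies on the natural-sort key above so that numbered checkpoints order correctly. A minimal standalone sketch of that behaviour (not part of the patch):

import re

def natural_keys(text):
    def atof(token):
        try:
            return float(token)
        except ValueError:
            return token
    return [atof(tok) for tok in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)]

ckpts = ['model_epoch_10.pt', 'model_epoch_9.pt', 'model_epoch_2.pt']
print(sorted(ckpts, key=natural_keys)[-1])   # model_epoch_10.pt, not ..._9.pt
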
diff --git a/lib/pymafx/utils/segms.py b/lib/pymafx/utils/segms.py
index 651dd0072e93660bdd0faf565c3358353ef0664b..44c617529d67323a8664c3e00872e5db091b8be6 100644
--- a/lib/pymafx/utils/segms.py
+++ b/lib/pymafx/utils/segms.py
@@ -32,249 +32,237 @@ import pycocotools.mask as mask_util
def GetDensePoseMask(Polys):
- MaskGen = np.zeros([256, 256])
- for i in range(1, 15):
- if (Polys[i - 1]):
- current_mask = mask_util.decode(Polys[i - 1])
- MaskGen[current_mask > 0] = i
- return MaskGen
+ MaskGen = np.zeros([256, 256])
+ for i in range(1, 15):
+ if (Polys[i - 1]):
+ current_mask = mask_util.decode(Polys[i - 1])
+ MaskGen[current_mask > 0] = i
+ return MaskGen
def flip_segms(segms, height, width):
- """Left/right flip each mask in a list of masks."""
-
- def _flip_poly(poly, width):
- flipped_poly = np.array(poly)
- flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
- return flipped_poly.tolist()
-
- def _flip_rle(rle, height, width):
- if 'counts' in rle and type(rle['counts']) == list:
- # Magic RLE format handling painfully discovered by looking at the
- # COCO API showAnns function.
- rle = mask_util.frPyObjects([rle], height, width)
- mask = mask_util.decode(rle)
- mask = mask[:, ::-1, :]
- rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
- return rle
-
- flipped_segms = []
- for segm in segms:
- if type(segm) == list:
- # Polygon format
- flipped_segms.append([_flip_poly(poly, width) for poly in segm])
- else:
- # RLE format
- assert type(segm) == dict
- flipped_segms.append(_flip_rle(segm, height, width))
- return flipped_segms
+ """Left/right flip each mask in a list of masks."""
+ def _flip_poly(poly, width):
+ flipped_poly = np.array(poly)
+ flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
+ return flipped_poly.tolist()
+
+ def _flip_rle(rle, height, width):
+ if 'counts' in rle and type(rle['counts']) == list:
+ # Magic RLE format handling painfully discovered by looking at the
+ # COCO API showAnns function.
+ rle = mask_util.frPyObjects([rle], height, width)
+ mask = mask_util.decode(rle)
+ mask = mask[:, ::-1, :]
+ rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
+ return rle
+
+ flipped_segms = []
+ for segm in segms:
+ if type(segm) == list:
+ # Polygon format
+ flipped_segms.append([_flip_poly(poly, width) for poly in segm])
+ else:
+ # RLE format
+ assert type(segm) == dict
+ flipped_segms.append(_flip_rle(segm, height, width))
+ return flipped_segms
def polys_to_mask(polygons, height, width):
- """Convert from the COCO polygon segmentation format to a binary mask
+ """Convert from the COCO polygon segmentation format to a binary mask
encoded as a 2D array of data type numpy.float32. The polygon segmentation
is understood to be enclosed inside a height x width image. The resulting
mask is therefore of shape (height, width).
"""
- rle = mask_util.frPyObjects(polygons, height, width)
- mask = np.array(mask_util.decode(rle), dtype=np.float32)
- # Flatten in case polygons was a list
- mask = np.sum(mask, axis=2)
- mask = np.array(mask > 0, dtype=np.float32)
- return mask
+ rle = mask_util.frPyObjects(polygons, height, width)
+ mask = np.array(mask_util.decode(rle), dtype=np.float32)
+ # Flatten in case polygons was a list
+ mask = np.sum(mask, axis=2)
+ mask = np.array(mask > 0, dtype=np.float32)
+ return mask
def mask_to_bbox(mask):
- """Compute the tight bounding box of a binary mask."""
- xs = np.where(np.sum(mask, axis=0) > 0)[0]
- ys = np.where(np.sum(mask, axis=1) > 0)[0]
+ """Compute the tight bounding box of a binary mask."""
+ xs = np.where(np.sum(mask, axis=0) > 0)[0]
+ ys = np.where(np.sum(mask, axis=1) > 0)[0]
- if len(xs) == 0 or len(ys) == 0:
- return None
+ if len(xs) == 0 or len(ys) == 0:
+ return None
- x0 = xs[0]
- x1 = xs[-1]
- y0 = ys[0]
- y1 = ys[-1]
- return np.array((x0, y0, x1, y1), dtype=np.float32)
+ x0 = xs[0]
+ x1 = xs[-1]
+ y0 = ys[0]
+ y1 = ys[-1]
+ return np.array((x0, y0, x1, y1), dtype=np.float32)
def polys_to_mask_wrt_box(polygons, box, M):
- """Convert from the COCO polygon segmentation format to a binary mask
+ """Convert from the COCO polygon segmentation format to a binary mask
encoded as a 2D array of data type numpy.float32. The polygon segmentation
is understood to be enclosed in the given box and rasterized to an M x M
mask. The resulting mask is therefore of shape (M, M).
"""
- w = box[2] - box[0]
- h = box[3] - box[1]
+ w = box[2] - box[0]
+ h = box[3] - box[1]
- w = np.maximum(w, 1)
- h = np.maximum(h, 1)
+ w = np.maximum(w, 1)
+ h = np.maximum(h, 1)
- polygons_norm = []
- for poly in polygons:
- p = np.array(poly, dtype=np.float32)
- p[0::2] = (p[0::2] - box[0]) * M / w
- p[1::2] = (p[1::2] - box[1]) * M / h
- polygons_norm.append(p)
+ polygons_norm = []
+ for poly in polygons:
+ p = np.array(poly, dtype=np.float32)
+ p[0::2] = (p[0::2] - box[0]) * M / w
+ p[1::2] = (p[1::2] - box[1]) * M / h
+ polygons_norm.append(p)
- rle = mask_util.frPyObjects(polygons_norm, M, M)
- mask = np.array(mask_util.decode(rle), dtype=np.float32)
- # Flatten in case polygons was a list
- mask = np.sum(mask, axis=2)
- mask = np.array(mask > 0, dtype=np.float32)
- return mask
+ rle = mask_util.frPyObjects(polygons_norm, M, M)
+ mask = np.array(mask_util.decode(rle), dtype=np.float32)
+ # Flatten in case polygons was a list
+ mask = np.sum(mask, axis=2)
+ mask = np.array(mask > 0, dtype=np.float32)
+ return mask
def polys_to_boxes(polys):
- """Convert a list of polygons into an array of tight bounding boxes."""
- boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32)
- for i in range(len(polys)):
- poly = polys[i]
- x0 = min(min(p[::2]) for p in poly)
- x1 = max(max(p[::2]) for p in poly)
- y0 = min(min(p[1::2]) for p in poly)
- y1 = max(max(p[1::2]) for p in poly)
- boxes_from_polys[i, :] = [x0, y0, x1, y1]
-
- return boxes_from_polys
-
-
-def rle_mask_voting(top_masks,
- all_masks,
- all_dets,
- iou_thresh,
- binarize_thresh,
- method='AVG'):
- """Returns new masks (in correspondence with `top_masks`) by combining
+ """Convert a list of polygons into an array of tight bounding boxes."""
+ boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32)
+ for i in range(len(polys)):
+ poly = polys[i]
+ x0 = min(min(p[::2]) for p in poly)
+ x1 = max(max(p[::2]) for p in poly)
+ y0 = min(min(p[1::2]) for p in poly)
+ y1 = max(max(p[1::2]) for p in poly)
+ boxes_from_polys[i, :] = [x0, y0, x1, y1]
+
+ return boxes_from_polys
+
+
+def rle_mask_voting(top_masks, all_masks, all_dets, iou_thresh, binarize_thresh, method='AVG'):
+ """Returns new masks (in correspondence with `top_masks`) by combining
multiple overlapping masks coming from the pool of `all_masks`. Two methods
for combining masks are supported: 'AVG' uses a weighted average of
overlapping mask pixels; 'UNION' takes the union of all mask pixels.
"""
- if len(top_masks) == 0:
- return
-
- all_not_crowd = [False] * len(all_masks)
- top_to_all_overlaps = mask_util.iou(top_masks, all_masks, all_not_crowd)
- decoded_all_masks = [
- np.array(mask_util.decode(rle), dtype=np.float32) for rle in all_masks
- ]
- decoded_top_masks = [
- np.array(mask_util.decode(rle), dtype=np.float32) for rle in top_masks
- ]
- all_boxes = all_dets[:, :4].astype(np.int32)
- all_scores = all_dets[:, 4]
-
- # Fill box support with weights
- mask_shape = decoded_all_masks[0].shape
- mask_weights = np.zeros((len(all_masks), mask_shape[0], mask_shape[1]))
- for k in range(len(all_masks)):
- ref_box = all_boxes[k]
- x_0 = max(ref_box[0], 0)
- x_1 = min(ref_box[2] + 1, mask_shape[1])
- y_0 = max(ref_box[1], 0)
- y_1 = min(ref_box[3] + 1, mask_shape[0])
- mask_weights[k, y_0:y_1, x_0:x_1] = all_scores[k]
- mask_weights = np.maximum(mask_weights, 1e-5)
-
- top_segms_out = []
- for k in range(len(top_masks)):
- # Corner case of empty mask
- if decoded_top_masks[k].sum() == 0:
- top_segms_out.append(top_masks[k])
- continue
-
- inds_to_vote = np.where(top_to_all_overlaps[k] >= iou_thresh)[0]
- # Only matches itself
- if len(inds_to_vote) == 1:
- top_segms_out.append(top_masks[k])
- continue
-
- masks_to_vote = [decoded_all_masks[i] for i in inds_to_vote]
- if method == 'AVG':
- ws = mask_weights[inds_to_vote]
- soft_mask = np.average(masks_to_vote, axis=0, weights=ws)
- mask = np.array(soft_mask > binarize_thresh, dtype=np.uint8)
- elif method == 'UNION':
- # Any pixel that's on joins the mask
- soft_mask = np.sum(masks_to_vote, axis=0)
- mask = np.array(soft_mask > 1e-5, dtype=np.uint8)
- else:
- raise NotImplementedError('Method {} is unknown'.format(method))
- rle = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
- top_segms_out.append(rle)
-
- return top_segms_out
+ if len(top_masks) == 0:
+ return
+
+ all_not_crowd = [False] * len(all_masks)
+ top_to_all_overlaps = mask_util.iou(top_masks, all_masks, all_not_crowd)
+ decoded_all_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in all_masks]
+ decoded_top_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in top_masks]
+ all_boxes = all_dets[:, :4].astype(np.int32)
+ all_scores = all_dets[:, 4]
+
+ # Fill box support with weights
+ mask_shape = decoded_all_masks[0].shape
+ mask_weights = np.zeros((len(all_masks), mask_shape[0], mask_shape[1]))
+ for k in range(len(all_masks)):
+ ref_box = all_boxes[k]
+ x_0 = max(ref_box[0], 0)
+ x_1 = min(ref_box[2] + 1, mask_shape[1])
+ y_0 = max(ref_box[1], 0)
+ y_1 = min(ref_box[3] + 1, mask_shape[0])
+ mask_weights[k, y_0:y_1, x_0:x_1] = all_scores[k]
+ mask_weights = np.maximum(mask_weights, 1e-5)
+
+ top_segms_out = []
+ for k in range(len(top_masks)):
+ # Corner case of empty mask
+ if decoded_top_masks[k].sum() == 0:
+ top_segms_out.append(top_masks[k])
+ continue
+
+ inds_to_vote = np.where(top_to_all_overlaps[k] >= iou_thresh)[0]
+ # Only matches itself
+ if len(inds_to_vote) == 1:
+ top_segms_out.append(top_masks[k])
+ continue
+
+ masks_to_vote = [decoded_all_masks[i] for i in inds_to_vote]
+ if method == 'AVG':
+ ws = mask_weights[inds_to_vote]
+ soft_mask = np.average(masks_to_vote, axis=0, weights=ws)
+ mask = np.array(soft_mask > binarize_thresh, dtype=np.uint8)
+ elif method == 'UNION':
+ # Any pixel that's on joins the mask
+ soft_mask = np.sum(masks_to_vote, axis=0)
+ mask = np.array(soft_mask > 1e-5, dtype=np.uint8)
+ else:
+ raise NotImplementedError('Method {} is unknown'.format(method))
+ rle = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
+ top_segms_out.append(rle)
+
+ return top_segms_out
def rle_mask_nms(masks, dets, thresh, mode='IOU'):
- """Performs greedy non-maximum suppression based on an overlap measurement
+ """Performs greedy non-maximum suppression based on an overlap measurement
between masks. The type of measurement is determined by `mode` and can be
either 'IOU' (standard intersection over union) or 'IOMA' (intersection over
minimum area).
"""
- if len(masks) == 0:
- return []
- if len(masks) == 1:
- return [0]
-
- if mode == 'IOU':
- # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(union(m1, m2))
- all_not_crowds = [False] * len(masks)
- ious = mask_util.iou(masks, masks, all_not_crowds)
- elif mode == 'IOMA':
- # Computes ious[m1, m2] = area(intersect(m1, m2)) / min(area(m1), area(m2))
- all_crowds = [True] * len(masks)
- # ious[m1, m2] = area(intersect(m1, m2)) / area(m2)
- ious = mask_util.iou(masks, masks, all_crowds)
- # ... = max(area(intersect(m1, m2)) / area(m2),
- # area(intersect(m2, m1)) / area(m1))
- ious = np.maximum(ious, ious.transpose())
- elif mode == 'CONTAINMENT':
- # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(m2)
- # Which measures how much m2 is contained inside m1
- all_crowds = [True] * len(masks)
- ious = mask_util.iou(masks, masks, all_crowds)
- else:
- raise NotImplementedError('Mode {} is unknown'.format(mode))
-
- scores = dets[:, 4]
- order = np.argsort(-scores)
-
- keep = []
- while order.size > 0:
- i = order[0]
- keep.append(i)
- ovr = ious[i, order[1:]]
- inds_to_keep = np.where(ovr <= thresh)[0]
- order = order[inds_to_keep + 1]
-
- return keep
+ if len(masks) == 0:
+ return []
+ if len(masks) == 1:
+ return [0]
+
+ if mode == 'IOU':
+ # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(union(m1, m2))
+ all_not_crowds = [False] * len(masks)
+ ious = mask_util.iou(masks, masks, all_not_crowds)
+ elif mode == 'IOMA':
+ # Computes ious[m1, m2] = area(intersect(m1, m2)) / min(area(m1), area(m2))
+ all_crowds = [True] * len(masks)
+ # ious[m1, m2] = area(intersect(m1, m2)) / area(m2)
+ ious = mask_util.iou(masks, masks, all_crowds)
+ # ... = max(area(intersect(m1, m2)) / area(m2),
+ # area(intersect(m2, m1)) / area(m1))
+ ious = np.maximum(ious, ious.transpose())
+ elif mode == 'CONTAINMENT':
+ # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(m2)
+ # Which measures how much m2 is contained inside m1
+ all_crowds = [True] * len(masks)
+ ious = mask_util.iou(masks, masks, all_crowds)
+ else:
+ raise NotImplementedError('Mode {} is unknown'.format(mode))
+
+ scores = dets[:, 4]
+ order = np.argsort(-scores)
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ ovr = ious[i, order[1:]]
+ inds_to_keep = np.where(ovr <= thresh)[0]
+ order = order[inds_to_keep + 1]
+
+ return keep
def rle_masks_to_boxes(masks):
- """Computes the bounding box of each mask in a list of RLE encoded masks."""
- if len(masks) == 0:
- return []
-
- decoded_masks = [
- np.array(mask_util.decode(rle), dtype=np.float32) for rle in masks
- ]
-
- def get_bounds(flat_mask):
- inds = np.where(flat_mask > 0)[0]
- return inds.min(), inds.max()
-
- boxes = np.zeros((len(decoded_masks), 4))
- keep = [True] * len(decoded_masks)
- for i, mask in enumerate(decoded_masks):
- if mask.sum() == 0:
- keep[i] = False
- continue
- flat_mask = mask.sum(axis=0)
- x0, x1 = get_bounds(flat_mask)
- flat_mask = mask.sum(axis=1)
- y0, y1 = get_bounds(flat_mask)
- boxes[i, :] = (x0, y0, x1, y1)
-
- return boxes, np.where(keep)[0]
+ """Computes the bounding box of each mask in a list of RLE encoded masks."""
+ if len(masks) == 0:
+ return []
+
+ decoded_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in masks]
+
+ def get_bounds(flat_mask):
+ inds = np.where(flat_mask > 0)[0]
+ return inds.min(), inds.max()
+
+ boxes = np.zeros((len(decoded_masks), 4))
+ keep = [True] * len(decoded_masks)
+ for i, mask in enumerate(decoded_masks):
+ if mask.sum() == 0:
+ keep[i] = False
+ continue
+ flat_mask = mask.sum(axis=0)
+ x0, x1 = get_bounds(flat_mask)
+ flat_mask = mask.sum(axis=1)
+ y0, y1 = get_bounds(flat_mask)
+ boxes[i, :] = (x0, y0, x1, y1)
+
+ return boxes, np.where(keep)[0]
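
Aside: `rle_mask_nms` above is a greedy suppression loop over an IoU-like matrix. The toy sketch below isolates that loop on a precomputed matrix so it runs without pycocotools; `greedy_nms` is a name introduced here for illustration.

import numpy as np

def greedy_nms(ious, scores, thresh=0.5):
    order = np.argsort(-scores)              # visit detections by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        ovr = ious[i, order[1:]]             # overlap with the remaining masks
        order = order[np.where(ovr <= thresh)[0] + 1]
    return keep

ious = np.array([[1.0, 0.8, 0.1],
                 [0.8, 1.0, 0.2],
                 [0.1, 0.2, 1.0]])
scores = np.array([0.9, 0.75, 0.6])
print(greedy_nms(ious, scores))              # [0, 2]: mask 1 overlaps mask 0 too much
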
diff --git a/lib/pymafx/utils/smooth_bbox.py b/lib/pymafx/utils/smooth_bbox.py
index 1d31f74dbfad1cfc5eb4e32da31490106d16d510..4393320e7f50128d6838d99c76b5d0f8f45f6efc 100644
--- a/lib/pymafx/utils/smooth_bbox.py
+++ b/lib/pymafx/utils/smooth_bbox.py
@@ -94,8 +94,11 @@ def get_all_bbox_params(kps, vis_thresh=2):
previous = bbox_params[-1]
# This will be 3x(n+2)
interpolated = np.array(
- [np.linspace(prev, curr, num_to_interpolate + 2)
- for prev, curr in zip(previous, bbox_param)])
+ [
+ np.linspace(prev, curr, num_to_interpolate + 2)
+ for prev, curr in zip(previous, bbox_param)
+ ]
+ )
bbox_params = np.vstack((bbox_params, interpolated.T[1:-1]))
num_to_interpolate = 0
bbox_params = np.vstack((bbox_params, bbox_param))
@@ -116,6 +119,5 @@ def smooth_bbox_params(bbox_params, kernel_size=11, sigma=8):
Returns:
Smoothed bounding box parameters (Nx3).
"""
- smoothed = np.array([signal.medfilt(param, kernel_size)
- for param in bbox_params.T]).T
+ smoothed = np.array([signal.medfilt(param, kernel_size) for param in bbox_params.T]).T
return np.array([gaussian_filter1d(traj, sigma) for traj in smoothed.T]).T
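
Aside: `smooth_bbox_params` applies a median filter per parameter trajectory to reject single-frame outliers, then a Gaussian filter for smoothness. A self-contained sketch of that two-stage filter (helper name and test data are illustrative only):

import numpy as np
from scipy import signal
from scipy.ndimage import gaussian_filter1d

def smooth_params(bbox_params, kernel_size=11, sigma=8):
    # bbox_params: (N, 3) array of per-frame (center_x, center_y, scale)
    med = np.array([signal.medfilt(p, kernel_size) for p in bbox_params.T]).T
    return np.array([gaussian_filter1d(traj, sigma) for traj in med.T]).T

noisy = np.cumsum(np.random.randn(100, 3), axis=0)
noisy[50] += 30.                             # inject a single-frame outlier
print(smooth_params(noisy).shape)            # (100, 3); the spike is filtered out
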
diff --git a/lib/pymafx/utils/transforms.py b/lib/pymafx/utils/transforms.py
index a283e5122bff1b9eb61b1a9092ead15200a355f8..25534674631d40b8b263b242d05339443b169dcb 100644
--- a/lib/pymafx/utils/transforms.py
+++ b/lib/pymafx/utils/transforms.py
@@ -43,7 +43,7 @@ def fliplr_joints(joints, joints_vis, width, matched_parts):
joints_vis[pair[0], :], joints_vis[pair[1], :] = \
joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
- return joints*joints_vis, joints_vis
+ return joints * joints_vis, joints_vis
def transform_preds(coords, center, scale, output_size):
@@ -55,8 +55,7 @@ def transform_preds(coords, center, scale, output_size):
def get_affine_transform(
- center, scale, rot, output_size,
- shift=np.array([0, 0], dtype=np.float32), inv=0
+ center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
# print(scale)
@@ -114,8 +113,7 @@ def crop(img, center, scale, output_size, rot=0):
trans = get_affine_transform(center, scale, rot, output_size)
dst_img = cv2.warpAffine(
- img, trans, (int(output_size[0]), int(output_size[1])),
- flags=cv2.INTER_LINEAR
+ img, trans, (int(output_size[0]), int(output_size[1])), flags=cv2.INTER_LINEAR
)
return dst_img
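
Aside: the `crop` path above warps a (center, scale) region of the input image to a fixed output size with `cv2.warpAffine`. The sketch below shows the same idea using `cv2.getAffineTransform` from three point correspondences as a stand-in for the repository's `get_affine_transform`:

import cv2
import numpy as np

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
center, box_size, out_size = np.array([320., 240.]), 200., (224, 224)

# map the box center, its top edge midpoint and its left edge midpoint
half = box_size / 2.
src = np.float32([center, center + [0., -half], center + [-half, 0.]])
dst = np.float32([[112., 112.], [112., 0.], [0., 112.]])
trans = cv2.getAffineTransform(src, dst)

crop = cv2.warpAffine(img, trans, out_size, flags=cv2.INTER_LINEAR)
print(crop.shape)                            # (224, 224, 3)
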
diff --git a/lib/pymafx/utils/uv_vis.py b/lib/pymafx/utils/uv_vis.py
index 8dd0d3fd75cfa06c9cd2b6fabac8ab47b8eae833..86fdd33ddee774c2bbe02478b2d74f53f8522256 100644
--- a/lib/pymafx/utils/uv_vis.py
+++ b/lib/pymafx/utils/uv_vis.py
@@ -5,6 +5,7 @@ import torch.nn.functional as F
from skimage.transform import resize
# Use a non-interactive backend
import matplotlib
+
matplotlib.use('Agg')
from .renderer import OpenDRenderer, PyRenderer
@@ -37,8 +38,10 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N
for part_id in range(1, K):
CurrentU = U_uv[batch_id, part_id]
CurrentV = V_uv[batch_id, part_id]
- output[1, Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id]
- output[2, Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id]
+ output[1,
+ Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id]
+ output[2,
+ Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id]
if uv_rois is None:
outputs.append(output.unsqueeze(0))
@@ -53,19 +56,34 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N
new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)]
output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest')
paddingleft = int(0.5 * (heatmap_size - new_size[1]))
- output = F.pad(output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0))
+ output = F.pad(
+ output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0)
+ )
else:
new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size]
output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest')
paddingtop = int(0.5 * (heatmap_size - new_size[0]))
- output = F.pad(output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop))
+ output = F.pad(
+ output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop)
+ )
outputs.append(output)
return torch.cat(outputs, dim=0)
-def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, image_name, save_path, opt, ratio=1):
+def vis_smpl_iuv(
+ image,
+ cam_pred,
+ vert_pred,
+ face,
+ pred_uv,
+ vert_errors_batch,
+ image_name,
+ save_path,
+ opt,
+ ratio=1
+):
# save_path = os.path.join('./notebooks/output/demo_results-wild', ids[f_id][0])
if not os.path.exists(save_path):
@@ -82,9 +100,9 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i
for draw_i in range(len(cam_pred)):
err_val = '{:06d}_'.format(int(10 * vert_errors_batch[draw_i]))
draw_name = err_val + image_name[draw_i]
- K = np.array([[focal_length, 0., orig_size / 2.],
- [0., focal_length, orig_size / 2.],
- [0., 0., 1.]])
+ K = np.array(
+ [[focal_length, 0., orig_size / 2.], [0., focal_length, orig_size / 2.], [0., 0., 1.]]
+ )
# img_orig, img_resized, img_smpl, render_smpl_rgba = dr_render(
# image[draw_i],
@@ -100,13 +118,14 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i
mesh_filename = None
img_orig = np.moveaxis(image[draw_i], 0, -1)
- img_smpl, img_resized = dr_render(vert_pred[draw_i],
- img=img_orig,
- cam=cam_pred[draw_i],
- iwp_mode=True,
- scale_ratio=4.,
- mesh_filename=mesh_filename,
- )
+ img_smpl, img_resized = dr_render(
+ vert_pred[draw_i],
+ img=img_orig,
+ cam=cam_pred[draw_i],
+ iwp_mode=True,
+ scale_ratio=4.,
+ mesh_filename=mesh_filename,
+ )
ones_img = np.ones(img_smpl.shape[:2]) * 255
ones_img = ones_img[:, :, None]
@@ -117,7 +136,9 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i
render_img = np.concatenate((img_resized_rgba, img_smpl_rgba), axis=1)
render_img[render_img < 0] = 0
render_img[render_img > 255] = 255
- matplotlib.image.imsave(os.path.join(save_path, draw_name[:-4] + '.png'), render_img.astype(np.uint8))
+ matplotlib.image.imsave(
+ os.path.join(save_path, draw_name[:-4] + '.png'), render_img.astype(np.uint8)
+ )
if pred_uv is not None:
# estimated global IUV
@@ -126,4 +147,6 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i
global_iuv = resize(global_iuv, img_resized.shape[:2])
global_iuv[global_iuv > 1] = 1
global_iuv[global_iuv < 0] = 0
- matplotlib.image.imsave(os.path.join(save_path, 'pred_uv_' + draw_name[:-4] + '.png'), global_iuv)
\ No newline at end of file
+ matplotlib.image.imsave(
+ os.path.join(save_path, 'pred_uv_' + draw_name[:-4] + '.png'), global_iuv
+ )
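
Aside: `iuv_map2img` resizes each IUV map so its longer side matches the heatmap size and pads the shorter side symmetrically. A compact standalone sketch of that resize-and-pad step (the aspect ratio is derived from the map itself here, which is an assumption; in the patched code it comes from `uv_rois`):

import torch
import torch.nn.functional as F

def resize_and_pad(iuv, heatmap_size=56):
    _, h, w = iuv.shape
    aspect_ratio = float(w) / h
    if aspect_ratio < 1:                     # taller than wide: pad left/right
        new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)]
        out = F.interpolate(iuv.unsqueeze(0), size=new_size, mode='nearest')
        left = (heatmap_size - new_size[1]) // 2
        out = F.pad(out, pad=(left, heatmap_size - new_size[1] - left, 0, 0))
    else:                                    # wider than tall: pad top/bottom
        new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size]
        out = F.interpolate(iuv.unsqueeze(0), size=new_size, mode='nearest')
        top = (heatmap_size - new_size[0]) // 2
        out = F.pad(out, pad=(0, 0, top, heatmap_size - new_size[0] - top))
    return out

print(resize_and_pad(torch.rand(3, 120, 80)).shape)   # torch.Size([1, 3, 56, 56])
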
diff --git a/lib/pymafx/utils/vis.py b/lib/pymafx/utils/vis.py
index 873ee694b266e7ba875d548b13f63f65513d2ff7..5273707c05f66275150e7cb2d86f44dcf4c92223 100644
--- a/lib/pymafx/utils/vis.py
+++ b/lib/pymafx/utils/vis.py
@@ -17,7 +17,6 @@
# limitations under the License.
##############################################################################
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -36,14 +35,14 @@ from .imutils import normalize_2d_kp
# Use a non-interactive backend
import matplotlib
+
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from mpl_toolkits.mplot3d import Axes3D
from skimage.transform import resize
-plt.rcParams['pdf.fonttype'] = 42 # For editing in Adobe Illustrator
-
+plt.rcParams['pdf.fonttype'] = 42 # For editing in Adobe Illustrator
_GRAY = (218, 227, 218)
_GREEN = (18, 127, 15)
@@ -52,24 +51,23 @@ _WHITE = (255, 255, 255)
def get_colors():
colors = {
- 'pink': np.array([197, 27, 125]), # L lower leg
- 'light_pink': np.array([233, 163, 201]), # L upper leg
- 'light_green': np.array([161, 215, 106]), # L lower arm
- 'green': np.array([77, 146, 33]), # L upper arm
- 'red': np.array([215, 48, 39]), # head
- 'light_red': np.array([252, 146, 114]), # head
- 'light_orange': np.array([252, 141, 89]), # chest
- 'purple': np.array([118, 42, 131]), # R lower leg
- 'light_purple': np.array([175, 141, 195]), # R upper
- 'light_blue': np.array([145, 191, 219]), # R lower arm
- 'blue': np.array([69, 117, 180]), # R upper arm
- 'gray': np.array([130, 130, 130]), #
- 'white': np.array([255, 255, 255]), #
+ 'pink': np.array([197, 27, 125]), # L lower leg
+ 'light_pink': np.array([233, 163, 201]), # L upper leg
+ 'light_green': np.array([161, 215, 106]), # L lower arm
+ 'green': np.array([77, 146, 33]), # L upper arm
+ 'red': np.array([215, 48, 39]), # head
+ 'light_red': np.array([252, 146, 114]), # head
+ 'light_orange': np.array([252, 141, 89]), # chest
+ 'purple': np.array([118, 42, 131]), # R lower leg
+ 'light_purple': np.array([175, 141, 195]), # R upper
+ 'light_blue': np.array([145, 191, 219]), # R lower arm
+ 'blue': np.array([69, 117, 180]), # R upper arm
+ 'gray': np.array([130, 130, 130]), #
+ 'white': np.array([255, 255, 255]), #
}
return colors
-
def kp_connections(keypoints):
kp_lines = [
[keypoints.index('left_eye'), keypoints.index('right_eye')],
@@ -77,15 +75,21 @@ def kp_connections(keypoints):
[keypoints.index('right_eye'), keypoints.index('nose')],
[keypoints.index('right_eye'), keypoints.index('right_ear')],
[keypoints.index('left_eye'), keypoints.index('left_ear')],
- [keypoints.index('right_shoulder'), keypoints.index('right_elbow')],
- [keypoints.index('right_elbow'), keypoints.index('right_wrist')],
- [keypoints.index('left_shoulder'), keypoints.index('left_elbow')],
- [keypoints.index('left_elbow'), keypoints.index('left_wrist')],
+ [keypoints.index('right_shoulder'),
+ keypoints.index('right_elbow')],
+ [keypoints.index('right_elbow'),
+ keypoints.index('right_wrist')],
+ [keypoints.index('left_shoulder'),
+ keypoints.index('left_elbow')],
+ [keypoints.index('left_elbow'),
+ keypoints.index('left_wrist')],
[keypoints.index('right_hip'), keypoints.index('right_knee')],
- [keypoints.index('right_knee'), keypoints.index('right_ankle')],
+ [keypoints.index('right_knee'),
+ keypoints.index('right_ankle')],
[keypoints.index('left_hip'), keypoints.index('left_knee')],
[keypoints.index('left_knee'), keypoints.index('left_ankle')],
- [keypoints.index('right_shoulder'), keypoints.index('left_shoulder')],
+ [keypoints.index('right_shoulder'),
+ keypoints.index('left_shoulder')],
[keypoints.index('right_hip'), keypoints.index('left_hip')],
]
return kp_lines
@@ -130,16 +134,27 @@ def get_class_string(class_index, score, dataset):
def vis_one_image(
- im, im_name, output_dir, boxes, segms=None, keypoints=None, body_uv=None, thresh=0.9,
- kp_thresh=2, dpi=200, box_alpha=0.0, dataset=None, show_class=False,
- ext='pdf'):
+ im,
+ im_name,
+ output_dir,
+ boxes,
+ segms=None,
+ keypoints=None,
+ body_uv=None,
+ thresh=0.9,
+ kp_thresh=2,
+ dpi=200,
+ box_alpha=0.0,
+ dataset=None,
+ show_class=False,
+ ext='pdf'
+):
"""Visual debugging of detections."""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if isinstance(boxes, list):
- boxes, segms, keypoints, classes = convert_from_cls_format(
- boxes, segms, keypoints)
+ boxes, segms, keypoints, classes = convert_from_cls_format(boxes, segms, keypoints)
if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
return
@@ -176,21 +191,27 @@ def vis_one_image(
print(dataset.classes[classes[i]], score)
# show box (off by default, box_alpha=0.0)
ax.add_patch(
- plt.Rectangle((bbox[0], bbox[1]),
- bbox[2] - bbox[0],
- bbox[3] - bbox[1],
- fill=False, edgecolor='g',
- linewidth=0.5, alpha=box_alpha))
+ plt.Rectangle(
+ (bbox[0], bbox[1]),
+ bbox[2] - bbox[0],
+ bbox[3] - bbox[1],
+ fill=False,
+ edgecolor='g',
+ linewidth=0.5,
+ alpha=box_alpha
+ )
+ )
if show_class:
ax.text(
- bbox[0], bbox[1] - 2,
+ bbox[0],
+ bbox[1] - 2,
get_class_string(classes[i], score, dataset),
fontsize=3,
family='serif',
- bbox=dict(
- facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
- color='white')
+ bbox=dict(facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
+ color='white'
+ )
# show mask
if segms is not None and len(segms) > i:
@@ -205,15 +226,17 @@ def vis_one_image(
img[:, :, c] = color_mask[c]
e = masks[:, :, i]
- _, contour, hier = cv2.findContours(
- e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+ _, contour, hier = cv2.findContours(e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
for c in contour:
polygon = Polygon(
c.reshape((-1, 2)),
- fill=True, facecolor=color_mask,
- edgecolor='w', linewidth=1.2,
- alpha=0.5)
+ fill=True,
+ facecolor=color_mask,
+ edgecolor='w',
+ linewidth=1.2,
+ alpha=0.5
+ )
ax.add_patch(polygon)
# show keypoints
@@ -229,41 +252,39 @@ def vis_one_image(
line = ax.plot(x, y)
plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7)
if kps[2, i1] > kp_thresh:
- ax.plot(
- kps[0, i1], kps[1, i1], '.', color=colors[l],
- markersize=3.0, alpha=0.7)
+ ax.plot(kps[0, i1], kps[1, i1], '.', color=colors[l], markersize=3.0, alpha=0.7)
if kps[2, i2] > kp_thresh:
- ax.plot(
- kps[0, i2], kps[1, i2], '.', color=colors[l],
- markersize=3.0, alpha=0.7)
+ ax.plot(kps[0, i2], kps[1, i2], '.', color=colors[l], markersize=3.0, alpha=0.7)
# add mid shoulder / mid hip for better visualization
mid_shoulder = (
kps[:2, dataset_keypoints.index('right_shoulder')] +
- kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
+ kps[:2, dataset_keypoints.index('left_shoulder')]
+ ) / 2.0
sc_mid_shoulder = np.minimum(
kps[2, dataset_keypoints.index('right_shoulder')],
- kps[2, dataset_keypoints.index('left_shoulder')])
+ kps[2, dataset_keypoints.index('left_shoulder')]
+ )
mid_hip = (
kps[:2, dataset_keypoints.index('right_hip')] +
- kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
+ kps[:2, dataset_keypoints.index('left_hip')]
+ ) / 2.0
sc_mid_hip = np.minimum(
kps[2, dataset_keypoints.index('right_hip')],
- kps[2, dataset_keypoints.index('left_hip')])
- if (sc_mid_shoulder > kp_thresh and
- kps[2, dataset_keypoints.index('nose')] > kp_thresh):
+ kps[2, dataset_keypoints.index('left_hip')]
+ )
+ if (
+ sc_mid_shoulder > kp_thresh and kps[2, dataset_keypoints.index('nose')] > kp_thresh
+ ):
x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]]
y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]]
line = ax.plot(x, y)
- plt.setp(
- line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7)
+ plt.setp(line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7)
if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
x = [mid_shoulder[0], mid_hip[0]]
y = [mid_shoulder[1], mid_hip[1]]
line = ax.plot(x, y)
- plt.setp(
- line, color=colors[len(kp_lines) + 1], linewidth=1.0,
- alpha=0.7)
+ plt.setp(line, color=colors[len(kp_lines) + 1], linewidth=1.0, alpha=0.7)
# DensePose Visualization Starts!!
## Get full IUV image out
@@ -283,14 +304,19 @@ def vis_one_image(
####
output = IUV_fields[ind]
####
- All_Coords_Old = All_Coords[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2], :]
- All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, 0])[All_Coords_Old == 0]
- All_Coords[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2], :] = All_Coords_Old
+ All_Coords_Old = All_Coords[entry[1]:entry[1] + output.shape[1],
+ entry[0]:entry[0] + output.shape[2], :]
+ All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2,
+ 0])[All_Coords_Old == 0]
+ All_Coords[entry[1]:entry[1] + output.shape[1],
+ entry[0]:entry[0] + output.shape[2], :] = All_Coords_Old
###
CurrentMask = (output[0, :, :] > 0).astype(np.float32)
- All_inds_old = All_inds[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2]]
+ All_inds_old = All_inds[entry[1]:entry[1] + output.shape[1],
+ entry[0]:entry[0] + output.shape[2]]
All_inds_old[All_inds_old == 0] = CurrentMask[All_inds_old == 0] * i
- All_inds[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2]] = All_inds_old
+ All_inds[entry[1]:entry[1] + output.shape[1],
+ entry[0]:entry[0] + output.shape[2]] = All_inds_old
#
All_Coords[:, :, 1:3] = 255. * All_Coords[:, :, 1:3]
All_Coords[All_Coords > 255] = 255.
@@ -323,7 +349,7 @@ def vis_one_image(
entry = boxes[ind, :]
if entry[4] > 0.75:
entry = entry[0:4].astype(int)
- center_roi = [(entry[2]+entry[0]) / 2., (entry[3]+entry[1]) / 2.]
+ center_roi = [(entry[2] + entry[0]) / 2., (entry[3] + entry[1]) / 2.]
####
output, center_out = smpl_fields[ind]
####
@@ -345,7 +371,8 @@ def vis_one_image(
# All_Coords_Old = All_Coords[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2],
# :]
- All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, 0])[All_Coords_Old == 0]
+ All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2,
+ 0])[All_Coords_Old == 0]
All_Coords[y1_img:y2_img, x1_img:x2_img, :] = All_Coords_Old
###
# CurrentMask = (output[0, :, :] > 0).astype(np.float32)
@@ -376,8 +403,16 @@ def vis_one_image(
plt.close('all')
-def vis_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis,
- file_name=None, nrow=8, padding=0, pad_value=1, add_text=True):
+def vis_batch_image_with_joints(
+ batch_image,
+ batch_joints,
+ batch_joints_vis,
+ file_name=None,
+ nrow=8,
+ padding=0,
+ pad_value=1,
+ add_text=True
+):
'''
batch_image: [batch_size, channel, height, width]
batch_joints: [batch_size, num_joints, 3],
@@ -417,8 +452,10 @@ def vis_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis,
else:
cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 0, [0, 255, 0], -1)
if add_text:
- cv2.putText(ndarr, str(count), (int(joint[0]), int(joint[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
- ( 0, 255, 0), 1)
+ cv2.putText(
+ ndarr, str(count), (int(joint[0]), int(joint[1])),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1
+ )
except Exception as e:
print(e)
k = k + 1
@@ -436,6 +473,7 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None):
n_sample = max_show
color = ['#00B0F0', '#00B050', '#DC6464', '#207070', '#BC4484']
+
# color = ['g', 'b', 'r']
def m_l_r(idx):
@@ -452,7 +490,7 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None):
# ax_img = plt.subplot(n_sample, 2, i * 2 + 1)
ax_img = plt.subplot(2, n_sample, i + 1)
img_np = batch_img[i].cpu().numpy()
- img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C
+ img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C
ax_img.imshow(img_np)
ax_img.set_axis_off()
ax_pred = plt.subplot(2, n_sample, n_sample + i + 1, projection='3d')
@@ -464,14 +502,29 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None):
if plot_kps.shape[1] > 2:
if joint_group is None:
ax_pred.scatter(plot_kps[:, 2], plot_kps[:, 0], plot_kps[:, 1], s=10, marker='.')
- ax_pred.scatter(plot_kps[0, 2], plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.')
+ ax_pred.scatter(
+ plot_kps[0, 2], plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.'
+ )
else:
for j in range(len(joint_group)):
- ax_pred.scatter(plot_kps[joint_group[j], 2], plot_kps[joint_group[j], 0], plot_kps[joint_group[j], 1], s=30, c=color[j], marker='s')
+ ax_pred.scatter(
+ plot_kps[joint_group[j], 2],
+ plot_kps[joint_group[j], 0],
+ plot_kps[joint_group[j], 1],
+ s=30,
+ c=color[j],
+ marker='s'
+ )
if pairs is not None:
for p in pairs:
- ax_pred.plot(plot_kps[p, 2], plot_kps[p, 0], plot_kps[p, 1], c=color[m_l_r(p[1])], linewidth=2)
+ ax_pred.plot(
+ plot_kps[p, 2],
+ plot_kps[p, 0],
+ plot_kps[p, 1],
+ c=color[m_l_r(p[1])],
+ linewidth=2
+ )
# ax_pred.set_axis_off()
@@ -483,7 +536,6 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None):
ax_pred.zaxis.set_ticks([])
-
def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None):
n_sample = joints.shape[0]
max_show = 2
@@ -494,6 +546,7 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None):
n_sample = max_show
color = ['#00B0F0', '#00B050', '#DC6464', '#207070', '#BC4484']
+
# color = ['g', 'b', 'r']
def m_l_r(idx):
@@ -510,7 +563,7 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None):
# ax_img = plt.subplot(n_sample, 2, i * 2 + 1)
ax_img = plt.subplot(2, n_sample, i + 1)
img_np = batch_img[i].cpu().numpy()
- img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C
+ img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C
ax_img.imshow(img_np)
ax_img.set_axis_off()
ax_pred = plt.subplot(2, n_sample, n_sample + i + 1)
@@ -526,11 +579,23 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None):
# ax_pred.scatter(plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.')
else:
for j in range(len(joint_group)):
- ax_pred.scatter(plot_kps[joint_group[j], 0], plot_kps[joint_group[j], 1], s=100, c=color[j], marker='o')
+ ax_pred.scatter(
+ plot_kps[joint_group[j], 0],
+ plot_kps[joint_group[j], 1],
+ s=100,
+ c=color[j],
+ marker='o'
+ )
if pairs is not None:
for p in pairs:
- ax_pred.plot(plot_kps[p, 0], plot_kps[p, 1], c=color[m_l_r(p[1])], linestyle=':', linewidth=3)
+ ax_pred.plot(
+ plot_kps[p, 0],
+ plot_kps[p, 1],
+ c=color[m_l_r(p[1])],
+ linestyle=':',
+ linewidth=3
+ )
ax_pred.set_axis_off()
@@ -542,34 +607,35 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None):
ax_pred.yaxis.set_ticks([])
# ax_pred.zaxis.set_ticks([])
+
def draw_skeleton(image, kp_2d, dataset='common', unnormalize=True, thickness=2):
if unnormalize:
- kp_2d[:,:2] = normalize_2d_kp(kp_2d[:,:2], 224, inv=True)
+ kp_2d[:, :2] = normalize_2d_kp(kp_2d[:, :2], 224, inv=True)
- kp_2d[:,2] = kp_2d[:,2] > 0.3
+ kp_2d[:, 2] = kp_2d[:, 2] > 0.3
kp_2d = np.array(kp_2d, dtype=int)
rcolor = get_colors()['red'].tolist()
pcolor = get_colors()['green'].tolist()
lcolor = get_colors()['blue'].tolist()
- common_lr = [0,0,1,1,0,0,0,0,1,0,0,1,1,1,0]
- for idx,pt in enumerate(kp_2d):
- if pt[2] > 0: # if visible
+ common_lr = [0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0]
+ for idx, pt in enumerate(kp_2d):
+ if pt[2] > 0: # if visible
if idx % 2 == 0:
color = rcolor
else:
color = pcolor
cv2.circle(image, (pt[0], pt[1]), 4, color, -1)
# cv2.putText(image, f'{idx}', (pt[0]+1, pt[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 0))
-
+
if dataset == 'common' and len(kp_2d) != 15:
return image
skeleton = eval(f'kp_utils.get_{dataset}_skeleton')()
- for i,(j1,j2) in enumerate(skeleton):
- if kp_2d[j1, 2] > 0 and kp_2d[j2, 2] > 0: # if visible
+ for i, (j1, j2) in enumerate(skeleton):
+ if kp_2d[j1, 2] > 0 and kp_2d[j2, 2] > 0: # if visible
if dataset == 'common':
color = rcolor if common_lr[i] == 0 else lcolor
else:
@@ -579,6 +645,7 @@ def draw_skeleton(image, kp_2d, dataset='common', unnormalize=True, thickness=2)
return image
+
# https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to
def set_axes_equal(ax):
'''Make axes of 3D plot have equal scale so that spheres appear as spheres,
@@ -602,8 +669,8 @@ def set_axes_equal(ax):
# The plot bounding box is a sphere in the sense of the infinity
# norm, hence I call half the max range the plot radius.
- plot_radius = 0.5*max([x_range, y_range, z_range])
+ plot_radius = 0.5 * max([x_range, y_range, z_range])
ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
- ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
\ No newline at end of file
+ ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
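
Aside: `set_axes_equal` exists because matplotlib scales each 3D axis independently, which distorts skeleton plots. The sketch below applies the same equal-radius idea directly to an Axes3D instance (a usage illustration, not the patched helper itself):

import numpy as np
import matplotlib
matplotlib.use('Agg')                        # non-interactive backend, as in vis.py

import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
theta = np.linspace(0, 2 * np.pi, 100)
ax.plot(np.cos(theta), np.sin(theta), 0.1 * np.sin(3 * theta))   # flat-ish ring

# equalise the three axis ranges around their midpoints
limits = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
middle = limits.mean(axis=1)
radius = 0.5 * np.max(limits[:, 1] - limits[:, 0])
ax.set_xlim3d(middle[0] - radius, middle[0] + radius)
ax.set_ylim3d(middle[1] - radius, middle[1] + radius)
ax.set_zlim3d(middle[2] - radius, middle[2] + radius)
fig.savefig('equal_axes_demo.png')
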
diff --git a/lib/smplx/body_models.py b/lib/smplx/body_models.py
index bad1f542e62a4c2765e56893fd78156122409992..b98adb635a3f9296c102e2bb6ca93bcdb14ab57d 100644
--- a/lib/smplx/body_models.py
+++ b/lib/smplx/body_models.py
@@ -61,7 +61,7 @@ ModelOutput = namedtuple(
"jaw_pose",
],
)
-ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields)
+ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields)
class SMPL(nn.Module):
@@ -234,7 +234,9 @@ class SMPL(nn.Module):
default_body_pose = body_pose.clone().detach()
else:
default_body_pose = torch.tensor(body_pose, dtype=dtype)
- self.register_parameter("body_pose", nn.Parameter(default_body_pose, requires_grad=True))
+ self.register_parameter(
+ "body_pose", nn.Parameter(default_body_pose, requires_grad=True)
+ )
if create_transl:
if transl is None:
@@ -403,7 +405,6 @@ class SMPL(nn.Module):
class SMPLLayer(SMPL):
-
def __init__(self, *args, **kwargs) -> None:
# Just create a SMPL module without any member variables
super(SMPLLayer, self).__init__(
@@ -465,11 +466,16 @@ class SMPLLayer(SMPL):
device, dtype = self.shapedirs.device, self.shapedirs.dtype
if global_orient is None:
global_orient = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if body_pose is None:
body_pose = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1,
- -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3,
+ 3).expand(batch_size, self.NUM_BODY_JOINTS, -1,
+ -1).contiguous()
+ )
if betas is None:
betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
if transl is None:
@@ -630,7 +636,9 @@ class SMPLH(SMPL):
self.np_left_hand_components = left_hand_components
self.np_right_hand_components = right_hand_components
if self.use_pca:
- self.register_buffer("left_hand_components", torch.tensor(left_hand_components, dtype=dtype))
+ self.register_buffer(
+ "left_hand_components", torch.tensor(left_hand_components, dtype=dtype)
+ )
self.register_buffer(
"right_hand_components",
torch.tensor(right_hand_components, dtype=dtype),
@@ -733,7 +741,9 @@ class SMPLH(SMPL):
if self.use_pca:
left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
- right_hand_pose = torch.einsum("bi,ij->bj", [right_hand_pose, self.right_hand_components])
+ right_hand_pose = torch.einsum(
+ "bi,ij->bj", [right_hand_pose, self.right_hand_components]
+ )
full_pose = torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1)
@@ -775,7 +785,6 @@ class SMPLH(SMPL):
class SMPLHLayer(SMPLH):
-
def __init__(self, *args, **kwargs) -> None:
"""SMPL+H as a layer model constructor"""
super(SMPLHLayer, self).__init__(
@@ -857,15 +866,24 @@ class SMPLHLayer(SMPLH):
device, dtype = self.shapedirs.device, self.shapedirs.dtype
if global_orient is None:
global_orient = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if body_pose is None:
- body_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous())
+ body_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous()
+ )
if left_hand_pose is None:
left_hand_pose = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
+ )
if right_hand_pose is None:
right_hand_pose = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
+ )
if betas is None:
betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
if transl is None:
@@ -926,7 +944,7 @@ class SMPLX(SMPLH):
which includes joints for the neck, jaw, eyeballs and fingers.
"""
- NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21
+ NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21
NUM_HAND_JOINTS = 15
NUM_FACE_JOINTS = 3
NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS
@@ -1092,7 +1110,9 @@ class SMPLX(SMPLH):
if create_expression:
if expression is None:
- default_expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype)
+ default_expression = torch.zeros(
+ [batch_size, self.num_expression_coeffs], dtype=dtype
+ )
else:
default_expression = torch.tensor(expression, dtype=dtype)
expression_param = nn.Parameter(default_expression, requires_grad=True)
@@ -1226,7 +1246,9 @@ class SMPLX(SMPLH):
if self.use_pca:
left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
- right_hand_pose = torch.einsum("bi,ij->bj", [right_hand_pose, self.right_hand_components])
+ right_hand_pose = torch.einsum(
+ "bi,ij->bj", [right_hand_pose, self.right_hand_components]
+ )
full_pose = torch.cat(
[
@@ -1315,7 +1337,9 @@ class SMPLX(SMPLH):
dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
- lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1)
+ lmk_bary_coords = torch.cat(
+ [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
+ )
landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
@@ -1350,7 +1374,6 @@ class SMPLX(SMPLH):
class SMPLXLayer(SMPLX):
-
def __init__(self, *args, **kwargs) -> None:
# Just create a SMPLX module without any member variables
super(SMPLXLayer, self).__init__(
@@ -1454,25 +1477,45 @@ class SMPLXLayer(SMPLX):
if global_orient is None:
global_orient = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if body_pose is None:
body_pose = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1,
- -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3,
+ 3).expand(batch_size, self.NUM_BODY_JOINTS, -1,
+ -1).contiguous()
+ )
if left_hand_pose is None:
left_hand_pose = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
+ )
if right_hand_pose is None:
right_hand_pose = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
+ )
if jaw_pose is None:
- jaw_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ jaw_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if leye_pose is None:
- leye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ leye_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if reye_pose is None:
- reye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ reye_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if expression is None:
- expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype, device=device)
+ expression = torch.zeros(
+ [batch_size, self.num_expression_coeffs], dtype=dtype, device=device
+ )
if betas is None:
betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
if transl is None:
@@ -1521,7 +1564,9 @@ class SMPLXLayer(SMPLX):
dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
- lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1)
+ lmk_bary_coords = torch.cat(
+ [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
+ )
landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
@@ -1646,7 +1691,9 @@ class MANO(SMPL):
)
# add only MANO tips to the extra joints
- self.vertex_joint_selector.extra_joints_idxs = to_tensor(list(VERTEX_IDS["mano"].values()), dtype=torch.long)
+ self.vertex_joint_selector.extra_joints_idxs = to_tensor(
+ list(VERTEX_IDS["mano"].values()), dtype=torch.long
+ )
self.use_pca = use_pca
self.num_pca_comps = num_pca_comps
@@ -1765,7 +1812,6 @@ class MANO(SMPL):
class MANOLayer(MANO):
-
def __init__(self, *args, **kwargs) -> None:
"""MANO as a layer model constructor"""
super(MANOLayer, self).__init__(
@@ -1795,11 +1841,16 @@ class MANOLayer(MANO):
if global_orient is None:
batch_size = 1
global_orient = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
else:
batch_size = global_orient.shape[0]
if hand_pose is None:
- hand_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous())
+ hand_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
+ )
if betas is None:
betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
if transl is None:
@@ -1993,7 +2044,9 @@ class FLAME(SMPL):
if create_expression:
if expression is None:
- default_expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype)
+ default_expression = torch.zeros(
+ [batch_size, self.num_expression_coeffs], dtype=dtype
+ )
else:
default_expression = torch.tensor(expression, dtype=dtype)
expression_param = nn.Parameter(default_expression, requires_grad=True)
@@ -2012,7 +2065,8 @@ class FLAME(SMPL):
self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))
if self.use_face_contour:
face_contour_path = os.path.join(model_path, "flame_dynamic_embedding.npy")
- contour_embeddings = np.load(face_contour_path, allow_pickle=True, encoding="latin1")[()]
+ contour_embeddings = np.load(face_contour_path, allow_pickle=True,
+ encoding="latin1")[()]
dynamic_lmk_faces_idx = np.array(contour_embeddings["lmk_face_idx"], dtype=np.int64)
dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
@@ -2148,7 +2202,9 @@ class FLAME(SMPL):
)
dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
- lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1)
+ lmk_bary_coords = torch.cat(
+ [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
+ )
landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
@@ -2179,7 +2235,6 @@ class FLAME(SMPL):
class FLAMELayer(FLAME):
-
def __init__(self, *args, **kwargs) -> None:
""" FLAME as a layer model constructor """
super(FLAMELayer, self).__init__(
@@ -2248,21 +2303,37 @@ class FLAMELayer(FLAME):
if global_orient is None:
batch_size = 1
global_orient = (
- torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
else:
batch_size = global_orient.shape[0]
if neck_pose is None:
- neck_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous())
+ neck_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous()
+ )
if jaw_pose is None:
- jaw_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ jaw_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if leye_pose is None:
- leye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ leye_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if reye_pose is None:
- reye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous())
+ reye_pose = (
+ torch.eye(3, device=device,
+ dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
+ )
if betas is None:
betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
if expression is None:
- expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype, device=device)
+ expression = torch.zeros(
+ [batch_size, self.num_expression_coeffs], dtype=dtype, device=device
+ )
if transl is None:
transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
@@ -2296,7 +2367,9 @@ class FLAMELayer(FLAME):
)
dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
- lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1)
+ lmk_bary_coords = torch.cat(
+ [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
+ )
landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
@@ -2391,7 +2464,9 @@ def build_layer(model_path: str,
raise ValueError(f"Unknown model type {model_type}, exiting!")
-def create(model_path: str, model_type: str = "smpl", **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
+def create(model_path: str,
+ model_type: str = "smpl",
+ **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
"""Method for creating a model from a path and a model type
Parameters
diff --git a/lib/smplx/joint_names.py b/lib/smplx/joint_names.py
index eadde0139deb4a62345d53cc1a8eb151bf83c1b5..a4f7cb0a3d2f9712de47e32e23ae36f918be6abc 100644
--- a/lib/smplx/joint_names.py
+++ b/lib/smplx/joint_names.py
@@ -129,8 +129,8 @@ JOINT_NAMES = [
"left_mouth_3",
"left_mouth_2",
"left_mouth_1",
- "left_mouth_5", # 59 in OpenPose output
- "left_mouth_4", # 58 in OpenPose output
+ "left_mouth_5", # 59 in OpenPose output
+ "left_mouth_4", # 58 in OpenPose output
"mouth_bottom",
"right_mouth_4",
"right_mouth_5",
diff --git a/lib/smplx/lbs.py b/lib/smplx/lbs.py
index c74f480fd146db9f70e92a20baec2543a7c30ca2..ac64f4b41be569331d632bfeb50fef9c50dc3d71 100644
--- a/lib/smplx/lbs.py
+++ b/lib/smplx/lbs.py
@@ -79,11 +79,15 @@ def find_dynamic_lmk_idx_and_bcoords(
else:
rot_mats = torch.index_select(pose.view(batch_size, -1, 3, 3), 1, neck_kin_chain)
- rel_rot_mat = (torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1))
+ rel_rot_mat = (
+ torch.eye(3, device=vertices.device,
+ dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
+ )
for idx in range(len(neck_kin_chain)):
rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
- y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, max=39)).to(dtype=torch.long)
+ y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
+ max=39)).to(dtype=torch.long)
neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
mask = y_rot_angle.lt(-39).to(dtype=torch.long)
neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
@@ -95,7 +99,9 @@ def find_dynamic_lmk_idx_and_bcoords(
return dyn_lmk_faces_idx, dyn_lmk_b_coords
-def vertices2landmarks(vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, lmk_bary_coords: Tensor) -> Tensor:
+def vertices2landmarks(
+ vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, lmk_bary_coords: Tensor
+) -> Tensor:
"""Calculates landmarks by barycentric interpolation
Parameters
@@ -123,7 +129,9 @@ def vertices2landmarks(vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, l
lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(batch_size, -1, 3)
- lmk_faces += (torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts)
+ lmk_faces += (
+ torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
+ )
lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
@@ -205,7 +213,8 @@ def lbs(
pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
rot_mats = pose.view(batch_size, -1, 3, 3)
- pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(batch_size, -1, 3)
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
+ posedirs).view(batch_size, -1, 3)
v_posed = pose_offsets + v_shaped
# 4. Get the global joint location
@@ -292,7 +301,8 @@ def general_lbs(
else:
rot_mats = pose.view(batch_size, -1, 3, 3)
pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
- pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(batch_size, -1, 3)
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
+ posedirs).view(batch_size, -1, 3)
v_posed = pose_offsets + v_template
@@ -407,7 +417,9 @@ def transform_mat(R: Tensor, t: Tensor) -> Tensor:
return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
-def batch_rigid_transform(rot_mats: Tensor, joints: Tensor, parents: Tensor, dtype=torch.float32) -> Tensor:
+def batch_rigid_transform(
+ rot_mats: Tensor, joints: Tensor, parents: Tensor, dtype=torch.float32
+) -> Tensor:
"""
Applies a batch of rigid transformations to the joints
@@ -436,7 +448,8 @@ def batch_rigid_transform(rot_mats: Tensor, joints: Tensor, parents: Tensor, dty
rel_joints = joints.clone()
rel_joints[:, 1:] -= joints[:, parents[1:]]
- transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
+ transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
+ rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
transform_chain = [transforms_mat[:, 0]]
for i in range(1, parents.shape[0]):
@@ -452,6 +465,8 @@ def batch_rigid_transform(rot_mats: Tensor, joints: Tensor, parents: Tensor, dty
joints_homogen = F.pad(joints, [0, 0, 0, 1])
- rel_transforms = transforms - F.pad(torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
+ rel_transforms = transforms - F.pad(
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
+ )
return posed_joints, rel_transforms
diff --git a/lib/smplx/utils.py b/lib/smplx/utils.py
index 6deb46a514bf97ed74ab5314b93b1148b1f376f8..d43a25217573f4c327adbf0411a76d1081632a69 100644
--- a/lib/smplx/utils.py
+++ b/lib/smplx/utils.py
@@ -105,7 +105,6 @@ def to_tensor(array: Union[Array, Tensor], dtype=torch.float32) -> Tensor:
class Struct(object):
-
def __init__(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
@@ -121,6 +120,5 @@ def rot_mat_to_euler(rot_mats):
# Calculates rotation matrix to euler angles
# Careful for extreme cases of eular angles like [0.0, pi, 0.0]
- sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
- rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
+ sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
return torch.atan2(-rot_mats[:, 2, 0], sy)
diff --git a/lib/smplx/vertex_ids.py b/lib/smplx/vertex_ids.py
index b45c0b84d5462f40adf776e6b615e8fb07f7be26..31ed146ed4b3529bfbe0c92450bd3b02559f338b 100644
--- a/lib/smplx/vertex_ids.py
+++ b/lib/smplx/vertex_ids.py
@@ -21,52 +21,54 @@ from __future__ import division
# Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to
# MSCOCO and OpenPose joints
vertex_ids = {
- "smplh": {
- "nose": 332,
- "reye": 6260,
- "leye": 2800,
- "rear": 4071,
- "lear": 583,
- "rthumb": 6191,
- "rindex": 5782,
- "rmiddle": 5905,
- "rring": 6016,
- "rpinky": 6133,
- "lthumb": 2746,
- "lindex": 2319,
- "lmiddle": 2445,
- "lring": 2556,
- "lpinky": 2673,
- "LBigToe": 3216,
- "LSmallToe": 3226,
- "LHeel": 3387,
- "RBigToe": 6617,
- "RSmallToe": 6624,
- "RHeel": 6787,
- },
- "smplx": {
- "nose": 9120,
- "reye": 9929,
- "leye": 9448,
- "rear": 616,
- "lear": 6,
- "rthumb": 8079,
- "rindex": 7669,
- "rmiddle": 7794,
- "rring": 7905,
- "rpinky": 8022,
- "lthumb": 5361,
- "lindex": 4933,
- "lmiddle": 5058,
- "lring": 5169,
- "lpinky": 5286,
- "LBigToe": 5770,
- "LSmallToe": 5780,
- "LHeel": 8846,
- "RBigToe": 8463,
- "RSmallToe": 8474,
- "RHeel": 8635,
- },
+ "smplh":
+ {
+ "nose": 332,
+ "reye": 6260,
+ "leye": 2800,
+ "rear": 4071,
+ "lear": 583,
+ "rthumb": 6191,
+ "rindex": 5782,
+ "rmiddle": 5905,
+ "rring": 6016,
+ "rpinky": 6133,
+ "lthumb": 2746,
+ "lindex": 2319,
+ "lmiddle": 2445,
+ "lring": 2556,
+ "lpinky": 2673,
+ "LBigToe": 3216,
+ "LSmallToe": 3226,
+ "LHeel": 3387,
+ "RBigToe": 6617,
+ "RSmallToe": 6624,
+ "RHeel": 6787,
+ },
+ "smplx":
+ {
+ "nose": 9120,
+ "reye": 9929,
+ "leye": 9448,
+ "rear": 616,
+ "lear": 6,
+ "rthumb": 8079,
+ "rindex": 7669,
+ "rmiddle": 7794,
+ "rring": 7905,
+ "rpinky": 8022,
+ "lthumb": 5361,
+ "lindex": 4933,
+ "lmiddle": 5058,
+ "lring": 5169,
+ "lpinky": 5286,
+ "LBigToe": 5770,
+ "LSmallToe": 5780,
+ "LHeel": 8846,
+ "RBigToe": 8463,
+ "RSmallToe": 8474,
+ "RHeel": 8635,
+ },
"mano": {
"thumb": 744,
"index": 320,
diff --git a/lib/smplx/vertex_joint_selector.py b/lib/smplx/vertex_joint_selector.py
index facf2afe433fde7f63a9978caa0258a7a38a30f3..1680e07acb03402a54fc0621ab36ec1d4de2c78e 100644
--- a/lib/smplx/vertex_joint_selector.py
+++ b/lib/smplx/vertex_joint_selector.py
@@ -27,12 +27,7 @@ from .utils import to_tensor
class VertexJointSelector(nn.Module):
-
- def __init__(self,
- vertex_ids=None,
- use_hands=True,
- use_feet_keypoints=True,
- **kwargs):
+ def __init__(self, vertex_ids=None, use_hands=True, use_feet_keypoints=True, **kwargs):
super(VertexJointSelector, self).__init__()
extra_joints_idxs = []
@@ -63,8 +58,7 @@ class VertexJointSelector(nn.Module):
dtype=np.int32,
)
- extra_joints_idxs = np.concatenate(
- [extra_joints_idxs, feet_keyp_idxs])
+ extra_joints_idxs = np.concatenate([extra_joints_idxs, feet_keyp_idxs])
if use_hands:
self.tip_names = ["thumb", "index", "middle", "ring", "pinky"]
@@ -76,8 +70,7 @@ class VertexJointSelector(nn.Module):
extra_joints_idxs = np.concatenate([extra_joints_idxs, tips_idxs])
- self.register_buffer("extra_joints_idxs",
- to_tensor(extra_joints_idxs, dtype=torch.long))
+ self.register_buffer("extra_joints_idxs", to_tensor(extra_joints_idxs, dtype=torch.long))
def forward(self, vertices, joints):
extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs)
diff --git a/lib/torch_utils/custom_ops.py b/lib/torch_utils/custom_ops.py
index 4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e..2170f4732aba52f614b7cec09ac62465275ad90b 100644
--- a/lib/torch_utils/custom_ops.py
+++ b/lib/torch_utils/custom_ops.py
@@ -20,11 +20,12 @@ from torch.utils.file_baton import FileBaton
#----------------------------------------------------------------------------
# Global options.
-verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
#----------------------------------------------------------------------------
# Internal helper funcs.
+
def _find_compiler_bindir():
patterns = [
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
@@ -38,11 +39,13 @@ def _find_compiler_bindir():
return matches[-1]
return None
+
#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.
_cached_plugins = dict()
+
def get_plugin(module_name, sources, **build_kwargs):
assert verbosity in ['none', 'brief', 'full']
@@ -56,12 +59,14 @@ def get_plugin(module_name, sources, **build_kwargs):
elif verbosity == 'brief':
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
- try: # pylint: disable=too-many-nested-blocks
+ try: # pylint: disable=too-many-nested-blocks
# Make sure we can find the necessary compiler binaries.
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
compiler_bindir = _find_compiler_bindir()
if compiler_bindir is None:
- raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ raise RuntimeError(
+ f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".'
+ )
os.environ['PATH'] += ';' + compiler_bindir
# Compile and load.
@@ -79,7 +84,9 @@ def get_plugin(module_name, sources, **build_kwargs):
# actually cares about this.)
source_dirs_set = set(os.path.dirname(source) for source in sources)
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
- all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+ all_source_files = sorted(
+ list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())
+ )
# Compute a combined hash digest for all source files in the same
# custom op directory (usually .cu, .cpp, .py and .h files).
@@ -87,7 +94,9 @@ def get_plugin(module_name, sources, **build_kwargs):
for src in all_source_files:
with open(src, 'rb') as f:
hash_md5.update(f.read())
- build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ build_dir = torch.utils.cpp_extension._get_build_directory(
+ module_name, verbose=verbose_build
+ ) # pylint: disable=protected-access
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
if not os.path.isdir(digest_build_dir):
@@ -96,7 +105,9 @@ def get_plugin(module_name, sources, **build_kwargs):
if baton.try_acquire():
try:
for src in all_source_files:
- shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+ shutil.copyfile(
+ src, os.path.join(digest_build_dir, os.path.basename(src))
+ )
finally:
baton.release()
else:
@@ -104,10 +115,17 @@ def get_plugin(module_name, sources, **build_kwargs):
# wait until done and continue.
baton.wait()
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
- torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
- verbose=verbose_build, sources=digest_sources, **build_kwargs)
+ torch.utils.cpp_extension.load(
+ name=module_name,
+ build_directory=build_dir,
+ verbose=verbose_build,
+ sources=digest_sources,
+ **build_kwargs
+ )
else:
- torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+ torch.utils.cpp_extension.load(
+ name=module_name, verbose=verbose_build, sources=sources, **build_kwargs
+ )
module = importlib.import_module(module_name)
except:
@@ -123,4 +141,5 @@ def get_plugin(module_name, sources, **build_kwargs):
_cached_plugins[module_name] = module
return module
+
#----------------------------------------------------------------------------
diff --git a/lib/torch_utils/misc.py b/lib/torch_utils/misc.py
index 7829f4d9f168557ce8a9a6dec289aa964234cb8c..61c266a84d83e9a486df52e725af1c51488951e4 100644
--- a/lib/torch_utils/misc.py
+++ b/lib/torch_utils/misc.py
@@ -19,6 +19,7 @@ import dnnlib
_constant_cache = dict()
+
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
value = np.asarray(value)
if shape is not None:
@@ -40,13 +41,15 @@ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
_constant_cache[key] = tensor
return tensor
+
#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.
try:
- nan_to_num = torch.nan_to_num # 1.8.0a0
+ nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
- def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
assert isinstance(input, torch.Tensor)
if posinf is None:
posinf = torch.finfo(input.dtype).max
@@ -55,57 +58,73 @@ except AttributeError:
assert nan == 0
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
#----------------------------------------------------------------------------
# Symbolic assert.
try:
- symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
- symbolic_assert = torch.Assert # 1.7.0
+ symbolic_assert = torch.Assert # 1.7.0
#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().
+
class suppress_tracer_warnings(warnings.catch_warnings):
def __enter__(self):
super().__enter__()
warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
return self
+
#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().
+
def assert_shape(tensor, ref_shape):
if tensor.ndim != len(ref_shape):
- raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ raise AssertionError(
+ f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}'
+ )
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
if ref_size is None:
pass
elif isinstance(ref_size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(
+ torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}'
+ )
elif isinstance(size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(
+ torch.equal(size, torch.as_tensor(ref_size)),
+ f'Wrong size for dimension {idx}: expected {ref_size}'
+ )
elif size != ref_size:
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().
+
def profiled_function(fn):
def decorator(*args, **kwargs):
with torch.autograd.profiler.record_function(fn.__name__):
return fn(*args, **kwargs)
+
decorator.__name__ = fn.__name__
return decorator
+
#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.
+
class InfiniteSampler(torch.utils.data.Sampler):
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
assert len(dataset) > 0
@@ -139,17 +158,21 @@ class InfiniteSampler(torch.utils.data.Sampler):
order[i], order[j] = order[j], order[i]
idx += 1
+
#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.
+
def params_and_buffers(module):
assert isinstance(module, torch.nn.Module)
return list(module.parameters()) + list(module.buffers())
+
def named_params_and_buffers(module):
assert isinstance(module, torch.nn.Module)
return list(module.named_parameters()) + list(module.named_buffers())
+
def copy_params_and_buffers(src_module, dst_module, require_all=False):
assert isinstance(src_module, torch.nn.Module)
assert isinstance(dst_module, torch.nn.Module)
@@ -159,10 +182,12 @@ def copy_params_and_buffers(src_module, dst_module, require_all=False):
if name in src_tensors:
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.
+
@contextlib.contextmanager
def ddp_sync(module, sync):
assert isinstance(module, torch.nn.Module)
@@ -172,9 +197,11 @@ def ddp_sync(module, sync):
with module.no_sync():
yield
+
#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.
+
def check_ddp_consistency(module, ignore_regex=None):
assert isinstance(module, torch.nn.Module)
for name, tensor in named_params_and_buffers(module):
@@ -186,9 +213,11 @@ def check_ddp_consistency(module, ignore_regex=None):
torch.distributed.broadcast(tensor=other, src=0)
assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
+
#----------------------------------------------------------------------------
# Print summary table of module hierarchy.
+
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
assert isinstance(module, torch.nn.Module)
assert not isinstance(module, torch.jit.ScriptModule)
@@ -197,14 +226,17 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
# Register hooks.
entries = []
nesting = [0]
+
def pre_hook(_mod, _inputs):
nesting[0] += 1
+
def post_hook(mod, _inputs, outputs):
nesting[0] -= 1
if nesting[0] <= max_nesting:
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
@@ -223,7 +255,10 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
# Filter out redundant entries.
if skip_redundant:
- entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+ entries = [
+ e for e in entries
+ if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)
+ ]
# Construct table.
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
@@ -237,13 +272,15 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
buffer_size = sum(t.numel() for t in e.unique_buffers)
output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
- rows += [[
- name + (':0' if len(e.outputs) >= 2 else ''),
- str(param_size) if param_size else '-',
- str(buffer_size) if buffer_size else '-',
- (output_shapes + ['-'])[0],
- (output_dtypes + ['-'])[0],
- ]]
+ rows += [
+ [
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]
+ ]
for idx in range(1, len(e.outputs)):
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
param_total += param_size
@@ -259,4 +296,5 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
print()
return outputs
+
#----------------------------------------------------------------------------
diff --git a/lib/torch_utils/ops/bias_act.py b/lib/torch_utils/ops/bias_act.py
index 4bcb409a89ccf6c6f6ecfca5962683df2d280b1f..d8cfdb65d25ed077827862bc70e860c450fe929a 100644
--- a/lib/torch_utils/ops/bias_act.py
+++ b/lib/torch_utils/ops/bias_act.py
@@ -5,7 +5,6 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""Custom PyTorch ops for efficient bias and activation."""
import os
@@ -21,15 +20,82 @@ from .. import misc
#----------------------------------------------------------------------------
activation_funcs = {
- 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
- 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
- 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
- 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
- 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
- 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
- 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
- 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
- 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+ 'linear':
+ dnnlib.EasyDict(
+ func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False
+ ),
+ 'relu':
+ dnnlib.EasyDict(
+ func=lambda x, **_: torch.nn.functional.relu(x),
+ def_alpha=0,
+ def_gain=np.sqrt(2),
+ cuda_idx=2,
+ ref='y',
+ has_2nd_grad=False
+ ),
+ 'lrelu':
+ dnnlib.EasyDict(
+ func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
+ def_alpha=0.2,
+ def_gain=np.sqrt(2),
+ cuda_idx=3,
+ ref='y',
+ has_2nd_grad=False
+ ),
+ 'tanh':
+ dnnlib.EasyDict(
+ func=lambda x, **_: torch.tanh(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=4,
+ ref='y',
+ has_2nd_grad=True
+ ),
+ 'sigmoid':
+ dnnlib.EasyDict(
+ func=lambda x, **_: torch.sigmoid(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=5,
+ ref='y',
+ has_2nd_grad=True
+ ),
+ 'elu':
+ dnnlib.EasyDict(
+ func=lambda x, **_: torch.nn.functional.elu(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=6,
+ ref='y',
+ has_2nd_grad=True
+ ),
+ 'selu':
+ dnnlib.EasyDict(
+ func=lambda x, **_: torch.nn.functional.selu(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=7,
+ ref='y',
+ has_2nd_grad=True
+ ),
+ 'softplus':
+ dnnlib.EasyDict(
+ func=lambda x, **_: torch.nn.functional.softplus(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=8,
+ ref='y',
+ has_2nd_grad=True
+ ),
+ 'swish':
+ dnnlib.EasyDict(
+ func=lambda x, **_: torch.sigmoid(x) * x,
+ def_alpha=0,
+ def_gain=np.sqrt(2),
+ cuda_idx=9,
+ ref='x',
+ has_2nd_grad=True
+ ),
}
#----------------------------------------------------------------------------
@@ -38,6 +104,7 @@ _inited = False
_plugin = None
_null_tensor = torch.empty([0])
+
def _init():
global _inited, _plugin
if not _inited:
@@ -45,13 +112,20 @@ def _init():
sources = ['bias_act.cpp', 'bias_act.cu']
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
try:
- _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ _plugin = custom_ops.get_plugin(
+ 'bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']
+ )
except:
- warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ warnings.warn(
+ 'Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n'
+ + traceback.format_exc()
+ )
return _plugin is not None
+
#----------------------------------------------------------------------------
+
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
r"""Fused bias and activation function.
@@ -88,8 +162,10 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
#----------------------------------------------------------------------------
+
@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
@@ -119,13 +195,15 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N
# Clamp.
if clamp >= 0:
- x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
return x
+
#----------------------------------------------------------------------------
_bias_act_cuda_cache = dict()
+
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
"""Fast CUDA implementation of `bias_act()` using custom ops.
"""
@@ -144,21 +222,26 @@ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
# Forward op.
class BiasActCuda(torch.autograd.Function):
@staticmethod
- def forward(ctx, x, b): # pylint: disable=arguments-differ
- ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+            ctx.memory_format = (torch.channels_last if x.ndim > 2 and x.stride()[1] == 1
+                                 else torch.contiguous_format)
x = x.contiguous(memory_format=ctx.memory_format)
b = b.contiguous() if b is not None else _null_tensor
y = x
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
- y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ y = _plugin.bias_act(
+ x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha,
+ gain, clamp
+ )
ctx.save_for_backward(
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
- y if 'y' in spec.ref else _null_tensor)
+ y if 'y' in spec.ref else _null_tensor
+ )
return y
@staticmethod
- def backward(ctx, dy): # pylint: disable=arguments-differ
+ def backward(ctx, dy): # pylint: disable=arguments-differ
dy = dy.contiguous(memory_format=ctx.memory_format)
x, b, y = ctx.saved_tensors
dx = None
@@ -177,16 +260,17 @@ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
# Backward op.
class BiasActCudaGrad(torch.autograd.Function):
@staticmethod
- def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
- ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
- dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
- ctx.save_for_backward(
- dy if spec.has_2nd_grad else _null_tensor,
- x, b, y)
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+            ctx.memory_format = (torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1
+                                 else torch.contiguous_format)
+ dx = _plugin.bias_act(
+ dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp
+ )
+ ctx.save_for_backward(dy if spec.has_2nd_grad else _null_tensor, x, b, y)
return dx
@staticmethod
- def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
dy, x, b, y = ctx.saved_tensors
d_dy = None
@@ -209,4 +293,5 @@ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
_bias_act_cuda_cache[key] = BiasActCuda
return BiasActCuda
+
#----------------------------------------------------------------------------
diff --git a/lib/torch_utils/ops/conv2d_gradfix.py b/lib/torch_utils/ops/conv2d_gradfix.py
index e95e10d0b1d0315a63a76446fd4c5c293c8bbc6d..29c3d8f5a8a1e2816e225af3157fc1bb99a4fd33 100644
--- a/lib/torch_utils/ops/conv2d_gradfix.py
+++ b/lib/torch_utils/ops/conv2d_gradfix.py
@@ -5,7 +5,6 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""
@@ -19,8 +18,9 @@ import torch
#----------------------------------------------------------------------------
-enabled = False # Enable the custom op by setting this to true.
-weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
@contextlib.contextmanager
def no_weight_gradients():
@@ -30,20 +30,60 @@ def no_weight_gradients():
yield
weight_gradients_disabled = old
+
#----------------------------------------------------------------------------
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if _should_use_custom_op(input):
- return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
- return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
-
-def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ return _conv2d_gradfix(
+ transpose=False,
+ weight_shape=weight.shape,
+ stride=stride,
+ padding=padding,
+ output_padding=0,
+ dilation=dilation,
+ groups=groups
+ ).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(
+ input=input,
+ weight=weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups
+ )
+
+
+def conv_transpose2d(
+ input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1
+):
if _should_use_custom_op(input):
- return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
- return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+ return _conv2d_gradfix(
+ transpose=True,
+ weight_shape=weight.shape,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation
+ ).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(
+ input=input,
+ weight=weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation
+ )
+
#----------------------------------------------------------------------------
+
def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
@@ -52,19 +92,24 @@ def _should_use_custom_op(input):
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
return True
- warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
+ warnings.warn(
+ f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().'
+ )
return False
+
def _tuple_of_ints(xs, ndim):
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim
assert len(xs) == ndim
assert all(isinstance(x, int) for x in xs)
return xs
+
#----------------------------------------------------------------------------
_conv2d_gradfix_cache = dict()
+
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
# Parse arguments.
ndim = 2
@@ -87,20 +132,18 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di
assert all(dilation[i] >= 0 for i in range(ndim))
if not transpose:
assert all(output_padding[i] == 0 for i in range(ndim))
- else: # transpose
+ else: # transpose
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
# Helpers.
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+
def calc_output_padding(input_shape, output_shape):
if transpose:
return [0, 0]
return [
- input_shape[i + 2]
- - (output_shape[i + 2] - 1) * stride[i]
- - (1 - 2 * padding[i])
- - dilation[i] * (weight_shape[i + 2] - 1)
- for i in range(ndim)
+ input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] - (1 - 2 * padding[i]) -
+ dilation[i] * (weight_shape[i + 2] - 1) for i in range(ndim)
]
# Forward & backward.
@@ -109,9 +152,17 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di
def forward(ctx, input, weight, bias):
assert weight.shape == weight_shape
if not transpose:
- output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
- else: # transpose
- output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ output = torch.nn.functional.conv2d(
+ input=input, weight=weight, bias=bias, **common_kwargs
+ )
+ else: # transpose
+ output = torch.nn.functional.conv_transpose2d(
+ input=input,
+ weight=weight,
+ bias=bias,
+ output_padding=output_padding,
+ **common_kwargs
+ )
ctx.save_for_backward(input, weight)
return output
@@ -124,7 +175,12 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di
if ctx.needs_input_grad[0]:
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
- grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
+ grad_input = _conv2d_gradfix(
+ transpose=(not transpose),
+ weight_shape=weight_shape,
+ output_padding=p,
+ **common_kwargs
+ ).apply(grad_output, weight, None)
assert grad_input.shape == input.shape
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
@@ -140,9 +196,17 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di
class Conv2dGradWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input):
- op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
- flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
- grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
+ op = torch._C._jit_get_operation(
+ 'aten::cudnn_convolution_backward_weight'
+ if not transpose else 'aten::cudnn_convolution_transpose_backward_weight'
+ )
+ flags = [
+ torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
+ torch.backends.cudnn.allow_tf32
+ ]
+ grad_weight = op(
+ weight_shape, grad_output, input, padding, stride, dilation, groups, *flags
+ )
assert grad_weight.shape == weight_shape
ctx.save_for_backward(grad_output, input)
return grad_weight
@@ -159,7 +223,12 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di
if ctx.needs_input_grad[1]:
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
- grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
+ grad2_input = _conv2d_gradfix(
+ transpose=(not transpose),
+ weight_shape=weight_shape,
+ output_padding=p,
+ **common_kwargs
+ ).apply(grad_output, grad2_grad_weight, None)
assert grad2_input.shape == input.shape
return grad2_grad_output, grad2_input
@@ -167,4 +236,5 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di
_conv2d_gradfix_cache[key] = Conv2d
return Conv2d
+
#----------------------------------------------------------------------------
diff --git a/lib/torch_utils/ops/conv2d_resample.py b/lib/torch_utils/ops/conv2d_resample.py
index cd4750744c83354bab78704d4ef51ad1070fcc4a..9f347c59165d1aceafee936b36281610b5a64e1b 100644
--- a/lib/torch_utils/ops/conv2d_resample.py
+++ b/lib/torch_utils/ops/conv2d_resample.py
@@ -5,7 +5,6 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""2D convolution with optional up/downsampling."""
import torch
@@ -18,21 +17,24 @@ from .upfirdn2d import _get_filter_size
#----------------------------------------------------------------------------
+
def _get_weight_shape(w):
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
shape = [int(sz) for sz in w.shape]
misc.assert_shape(w, shape)
return shape
+
#----------------------------------------------------------------------------
+
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
"""
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
# Flip weight if requested.
- if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
w = w.flip([2, 3])
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
@@ -53,10 +55,14 @@ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_w
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
return op(x, w, stride=stride, padding=padding, groups=groups)
+
#----------------------------------------------------------------------------
+
@misc.profiled_function
-def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+def conv2d_resample(
+ x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
+):
r"""2D convolution with optional up/downsampling.
Padding is performed only once at the beginning, not between the operations.
@@ -83,7 +89,9 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
# Validate arguments.
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
- assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert f is None or (
+ isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32
+ )
assert isinstance(up, int) and (up >= 1)
assert isinstance(down, int) and (down >= 1)
assert isinstance(groups, int) and (groups >= 1)
@@ -105,19 +113,23 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
if kw == 1 and kh == 1 and (down > 1 and up == 1):
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = upfirdn2d.upfirdn2d(
+ x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
+ )
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
return x
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
if kw == 1 and kh == 1 and (up > 1 and down == 1):
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = upfirdn2d.upfirdn2d(
+ x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter
+ )
return x
# Fast path: downsampling only => use strided convolution.
if down > 1 and up == 1:
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
return x
@@ -135,8 +147,22 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
py1 -= kh - up
pxt = max(min(-px0, -px1), 0)
pyt = max(min(-py0, -py1), 0)
- x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(
+ x=x,
+ w=w,
+ stride=up,
+ padding=[pyt, pxt],
+ groups=groups,
+ transpose=True,
+ flip_weight=(not flip_weight)
+ )
+ x = upfirdn2d.upfirdn2d(
+ x=x,
+ f=f,
+ padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
+ gain=up**2,
+ flip_filter=flip_filter
+ )
if down > 1:
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
return x
@@ -144,13 +170,23 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
if up == 1 and down == 1:
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
- return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+ return _conv2d_wrapper(
+ x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
+ )
# Fallback: Generic reference implementation.
- x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = upfirdn2d.upfirdn2d(
+ x=x,
+ f=(f if up > 1 else None),
+ up=up,
+ padding=[px0, px1, py0, py1],
+ gain=up**2,
+ flip_filter=flip_filter
+ )
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
if down > 1:
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
return x
+
#----------------------------------------------------------------------------
diff --git a/lib/torch_utils/ops/fma.py b/lib/torch_utils/ops/fma.py
index 2eeac58a626c49231e04122b93e321ada954c5d3..5c030932fb439b4dcc7b08ad55d0fa2aa9d8f82f 100644
--- a/lib/torch_utils/ops/fma.py
+++ b/lib/torch_utils/ops/fma.py
@@ -5,28 +5,30 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
import torch
#----------------------------------------------------------------------------
-def fma(a, b, c): # => a * b + c
+
+def fma(a, b, c): # => a * b + c
return _FusedMultiplyAdd.apply(a, b, c)
+
#----------------------------------------------------------------------------
-class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
@staticmethod
- def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
out = torch.addcmul(c, a, b)
ctx.save_for_backward(a, b)
ctx.c_shape = c.shape
return out
@staticmethod
- def backward(ctx, dout): # pylint: disable=arguments-differ
+ def backward(ctx, dout): # pylint: disable=arguments-differ
a, b = ctx.saved_tensors
c_shape = ctx.c_shape
da = None
@@ -44,17 +46,23 @@ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
return da, db, dc
+
#----------------------------------------------------------------------------
+
def _unbroadcast(x, shape):
extra_dims = x.ndim - len(shape)
assert extra_dims >= 0
- dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ dim = [
+ i
+ for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
+ ]
if len(dim):
x = x.sum(dim=dim, keepdim=True)
if extra_dims:
- x = x.reshape(-1, *x.shape[extra_dims+1:])
+ x = x.reshape(-1, *x.shape[extra_dims + 1:])
assert x.shape == shape
return x
+
#----------------------------------------------------------------------------
diff --git a/lib/torch_utils/ops/fused_act.py b/lib/torch_utils/ops/fused_act.py
index 394a8c57229e47243ad645bc8be54674871650f6..c38a2aa0c94f033f7ebcd01eddf8da126fd7add8 100644
--- a/lib/torch_utils/ops/fused_act.py
+++ b/lib/torch_utils/ops/fused_act.py
@@ -36,8 +36,10 @@ class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
- (out,) = ctx.saved_tensors
- gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale)
+ (out, ) = ctx.saved_tensors
+ gradgrad_out = fused.fused_bias_act(
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+ )
return gradgrad_out, None, None, None, None
@@ -65,7 +67,7 @@ class FusedLeakyReLUFunction(Function):
@staticmethod
def backward(ctx, grad_output):
- (out,) = ctx.saved_tensors
+ (out, ) = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
@@ -78,7 +80,7 @@ class FusedLeakyReLUFunction(Function):
class FusedLeakyReLU(nn.Module):
- def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5):
super().__init__()
if bias:
@@ -93,11 +95,13 @@ class FusedLeakyReLU(nn.Module):
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
-def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
if input.device.type == "cpu":
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
- return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale
+ return F.leaky_relu(
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+ ) * scale
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
diff --git a/lib/torch_utils/ops/grid_sample_gradfix.py b/lib/torch_utils/ops/grid_sample_gradfix.py
index ca6b3413ea72a734703c34382c023b84523601fd..850feacd5a6300b85493cd7f713bffab1af70536 100644
--- a/lib/torch_utils/ops/grid_sample_gradfix.py
+++ b/lib/torch_utils/ops/grid_sample_gradfix.py
@@ -5,7 +5,6 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes
@@ -20,33 +19,44 @@ import torch
#----------------------------------------------------------------------------
-enabled = False # Enable the custom op by setting this to true.
+enabled = False # Enable the custom op by setting this to true.
#----------------------------------------------------------------------------
+
def grid_sample(input, grid):
if _should_use_custom_op():
return _GridSample2dForward.apply(input, grid)
- return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ return torch.nn.functional.grid_sample(
+ input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False
+ )
+
#----------------------------------------------------------------------------
+
def _should_use_custom_op():
if not enabled:
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
return True
- warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
+ warnings.warn(
+ f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().'
+ )
return False
+
#----------------------------------------------------------------------------
+
class _GridSample2dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid):
assert input.ndim == 4
assert grid.ndim == 4
- output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ output = torch.nn.functional.grid_sample(
+ input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False
+ )
ctx.save_for_backward(input, grid)
return output
@@ -56,8 +66,10 @@ class _GridSample2dForward(torch.autograd.Function):
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
return grad_input, grad_grid
+
#----------------------------------------------------------------------------
+
class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
@@ -68,7 +80,7 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def backward(ctx, grad2_grad_input, grad2_grad_grid):
- _ = grad2_grad_grid # unused
+ _ = grad2_grad_grid # unused
grid, = ctx.saved_tensors
grad2_grad_output = None
grad2_input = None
@@ -80,4 +92,5 @@ class _GridSample2dBackward(torch.autograd.Function):
assert not ctx.needs_input_grad[2]
return grad2_grad_output, grad2_input, grad2_grid
+
#----------------------------------------------------------------------------
diff --git a/lib/torch_utils/ops/native_ops.py b/lib/torch_utils/ops/native_ops.py
index 09cc5c3245113c690ae7f4891f512351cfdd5187..a21a1368c69aee0e802fa710d34a59ec63523fb6 100644
--- a/lib/torch_utils/ops/native_ops.py
+++ b/lib/torch_utils/ops/native_ops.py
@@ -4,7 +4,7 @@ from torch.nn import functional as F
class FusedLeakyReLU(nn.Module):
- def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5):
super().__init__()
if bias:
@@ -20,13 +20,15 @@ class FusedLeakyReLU(nn.Module):
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
-def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
if input.dtype == torch.float16:
bias = bias.half()
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
- return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale
+ return F.leaky_relu(
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+ ) * scale
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
@@ -48,12 +50,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
- out = out[
- :,
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
- :,
- ]
+ out = out[:,
+ max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
+                  max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
diff --git a/lib/torch_utils/ops/upfirdn2d.py b/lib/torch_utils/ops/upfirdn2d.py
index ceeac2b9834e33b7c601c28bf27f32aa91c69256..86f6fb36eb83711db42aef6b05c003eceaeeaa69 100644
--- a/lib/torch_utils/ops/upfirdn2d.py
+++ b/lib/torch_utils/ops/upfirdn2d.py
@@ -5,7 +5,6 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""Custom PyTorch ops for efficient resampling of 2D images."""
import os
@@ -23,17 +22,24 @@ from . import conv2d_gradfix
_inited = False
_plugin = None
+
def _init():
global _inited, _plugin
if not _inited:
sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
try:
- _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ _plugin = custom_ops.get_plugin(
+ 'upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']
+ )
except:
- warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ warnings.warn(
+ 'Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n'
+ + traceback.format_exc()
+ )
return _plugin is not None
+
def _parse_scaling(scaling):
if isinstance(scaling, int):
scaling = [scaling, scaling]
@@ -43,6 +49,7 @@ def _parse_scaling(scaling):
assert sx >= 1 and sy >= 1
return sx, sy
+
def _parse_padding(padding):
if isinstance(padding, int):
padding = [padding, padding]
@@ -54,6 +61,7 @@ def _parse_padding(padding):
padx0, padx1, pady0, pady1 = padding
return padx0, padx1, pady0, pady1
+
def _get_filter_size(f):
if f is None:
return 1, 1
@@ -67,9 +75,13 @@ def _get_filter_size(f):
assert fw >= 1 and fh >= 1
return fw, fh
+
#----------------------------------------------------------------------------
-def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
+
+def setup_filter(
+ f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None
+):
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
Args:
@@ -111,12 +123,14 @@ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=Fals
f /= f.sum()
if flip_filter:
f = f.flip(list(range(f.ndim)))
- f = f * (gain ** (f.ndim / 2))
+ f = f * (gain**(f.ndim / 2))
f = f.to(device=device)
return f
+
#----------------------------------------------------------------------------
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Pad, upsample, filter, and downsample a batch of 2D images.
@@ -160,11 +174,17 @@ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cu
assert isinstance(x, torch.Tensor)
assert impl in ['ref', 'cuda']
if impl == 'cuda' and x.device.type == 'cuda' and _init():
- return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
- return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
+ return _upfirdn2d_cuda(
+ up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
+ ).apply(x, f)
+ return _upfirdn2d_ref(
+ x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
+ )
+
#----------------------------------------------------------------------------
+
@misc.profiled_function
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
@@ -187,10 +207,12 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
# Pad or crop.
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
- x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
+ x = x[:, :,
+ max(-pady0, 0):x.shape[2] - max(-pady1, 0),
+ max(-padx0, 0):x.shape[3] - max(-padx1, 0)]
# Setup filter.
- f = f * (gain ** (f.ndim / 2))
+ f = f * (gain**(f.ndim / 2))
f = f.to(x.dtype)
if not flip_filter:
f = f.flip(list(range(f.ndim)))
@@ -207,10 +229,12 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
x = x[:, :, ::downy, ::downx]
return x
+
#----------------------------------------------------------------------------
_upfirdn2d_cuda_cache = dict()
+
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
"""
@@ -227,23 +251,31 @@ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
# Forward op.
class Upfirdn2dCuda(torch.autograd.Function):
@staticmethod
- def forward(ctx, x, f): # pylint: disable=arguments-differ
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
assert isinstance(x, torch.Tensor) and x.ndim == 4
if f is None:
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
y = x
if f.ndim == 2:
- y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ y = _plugin.upfirdn2d(
+ y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain
+ )
else:
- y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
- y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
+ y = _plugin.upfirdn2d(
+ y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter,
+ np.sqrt(gain)
+ )
+ y = _plugin.upfirdn2d(
+ y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter,
+ np.sqrt(gain)
+ )
ctx.save_for_backward(f)
ctx.x_shape = x.shape
return y
@staticmethod
- def backward(ctx, dy): # pylint: disable=arguments-differ
+ def backward(ctx, dy): # pylint: disable=arguments-differ
f, = ctx.saved_tensors
_, _, ih, iw = ctx.x_shape
_, _, oh, ow = dy.shape
@@ -258,7 +290,9 @@ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
df = None
if ctx.needs_input_grad[0]:
- dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
+ dx = _upfirdn2d_cuda(
+ up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain
+ ).apply(dy, f)
assert not ctx.needs_input_grad[1]
return dx, df
@@ -267,8 +301,10 @@ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
return Upfirdn2dCuda
+
#----------------------------------------------------------------------------
+
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Filter a batch of 2D images using the given 2D FIR filter.
@@ -303,8 +339,10 @@ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
]
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
#----------------------------------------------------------------------------
+
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Upsample a batch of 2D images using the given 2D FIR filter.
@@ -340,10 +378,14 @@ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
pady0 + (fh + upy - 1) // 2,
pady1 + (fh - upy) // 2,
]
- return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+ return upfirdn2d(
+ x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl
+ )
+
#----------------------------------------------------------------------------
+
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Downsample a batch of 2D images using the given 2D FIR filter.
@@ -381,4 +423,5 @@ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'
]
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
#----------------------------------------------------------------------------
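
For reference, a short sketch of how the upfirdn2d API touched above fits together; the filter taps and image size are illustrative, and impl='ref' is used so the sketch does not depend on the CUDA plugin building:

    import torch
    from lib.torch_utils.ops import upfirdn2d

    x = torch.randn(1, 3, 64, 64)                             # NCHW image batch
    f = upfirdn2d.setup_filter([1, 3, 3, 1])                  # normalized 2D FIR filter

    up = upfirdn2d.upsample2d(x, f, up=2, impl='ref')         # -> (1, 3, 128, 128)
    down = upfirdn2d.downsample2d(up, f, down=2, impl='ref')  # -> (1, 3, 64, 64)

    # general entry point; with impl='cuda' it falls back to the reference
    # path whenever the custom plugin cannot be built
    y = upfirdn2d.upfirdn2d(x, f, up=2, down=2, padding=1, impl='ref')
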
diff --git a/lib/torch_utils/persistence.py b/lib/torch_utils/persistence.py
index 0186cfd97bca0fcb397a7b73643520c1d1105a02..c3263dc0690ac12d5d2e74a6d9d8d2af2fed0f5b 100644
--- a/lib/torch_utils/persistence.py
+++ b/lib/torch_utils/persistence.py
@@ -5,7 +5,6 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""Facilities for pickling Python code alongside other data.
The pickled code is automatically imported into a separate Python module
@@ -24,14 +23,15 @@ import dnnlib
#----------------------------------------------------------------------------
-_version = 6 # internal version number
-_decorators = set() # {decorator_class, ...}
-_import_hooks = [] # [hook_function, ...]
+_version = 6 # internal version number
+_decorators = set() # {decorator_class, ...}
+_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}
#----------------------------------------------------------------------------
+
def persistent_class(orig_class):
r"""Class decorator that extends a given class to save its source code
when pickled.
@@ -119,18 +119,26 @@ def persistent_class(orig_class):
fields = list(super().__reduce__())
fields += [None] * max(3 - len(fields), 0)
if fields[0] is not _reconstruct_persistent_obj:
- meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
- fields[0] = _reconstruct_persistent_obj # reconstruct func
- fields[1] = (meta,) # reconstruct args
- fields[2] = None # state dict
+ meta = dict(
+ type='class',
+ version=_version,
+ module_src=self._orig_module_src,
+ class_name=self._orig_class_name,
+ state=fields[2]
+ )
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
+ fields[1] = (meta, ) # reconstruct args
+ fields[2] = None # state dict
return tuple(fields)
Decorator.__name__ = orig_class.__name__
_decorators.add(Decorator)
return Decorator
+
#----------------------------------------------------------------------------
+
def is_persistent(obj):
r"""Test whether the given object or class is persistent, i.e.,
whether it will save its source code when pickled.
@@ -140,10 +148,12 @@ def is_persistent(obj):
return True
except TypeError:
pass
- return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
#----------------------------------------------------------------------------
+
def import_hook(hook):
r"""Register an import hook that is called whenever a persistent object
is being unpickled. A typical use case is to patch the pickled source
@@ -174,8 +184,10 @@ def import_hook(hook):
assert callable(hook)
_import_hooks.append(hook)
+
#----------------------------------------------------------------------------
+
def _reconstruct_persistent_obj(meta):
r"""Hook that is called internally by the `pickle` module to unpickle
a persistent object.
@@ -196,13 +208,15 @@ def _reconstruct_persistent_obj(meta):
setstate = getattr(obj, '__setstate__', None)
if callable(setstate):
- setstate(meta.state) # pylint: disable=not-callable
+ setstate(meta.state) # pylint: disable=not-callable
else:
obj.__dict__.update(meta.state)
return obj
+
#----------------------------------------------------------------------------
+
def _module_to_src(module):
r"""Query the source code of a given Python module.
"""
@@ -213,6 +227,7 @@ def _module_to_src(module):
_src_to_module_dict[src] = module
return src
+
def _src_to_module(src):
r"""Get or create a Python module for the given source code.
"""
@@ -223,11 +238,13 @@ def _src_to_module(src):
sys.modules[module_name] = module
_module_to_src_dict[module] = src
_src_to_module_dict[src] = module
- exec(src, module.__dict__) # pylint: disable=exec-used
+ exec(src, module.__dict__) # pylint: disable=exec-used
return module
+
#----------------------------------------------------------------------------
+
def _check_pickleable(obj):
r"""Check that the given object is pickleable, raising an exception if
it is not. This function is expected to be considerably more efficient
@@ -239,13 +256,15 @@ def _check_pickleable(obj):
if isinstance(obj, dict):
return [[recurse(x), recurse(y)] for x, y in obj.items()]
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
- return None # Python primitive types are pickleable.
+ return None # Python primitive types are pickleable.
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
- return None # NumPy arrays and PyTorch tensors are pickleable.
+ return None # NumPy arrays and PyTorch tensors are pickleable.
if is_persistent(obj):
- return None # Persistent objects are pickleable, by virtue of the constructor check.
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
return obj
+
with io.BytesIO() as f:
pickle.dump(recurse(obj), f)
+
#----------------------------------------------------------------------------
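
For reference, a minimal sketch of the persistence decorator whose formatting is adjusted above; the TinyNet class is illustrative, and the snippet assumes it is defined in an importable source file, since the decorator stores the defining module's source inside the pickle:

    import pickle
    import torch.nn as nn
    from lib.torch_utils import persistence

    @persistence.persistent_class
    class TinyNet(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.fc = nn.Linear(dim, dim)

        def forward(self, x):
            return self.fc(x)

    net = TinyNet(8)
    assert persistence.is_persistent(net)
    blob = pickle.dumps(net)    # source code of TinyNet travels with the pickle
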
diff --git a/lib/torch_utils/training_stats.py b/lib/torch_utils/training_stats.py
index d2c265f5c8ab235156a4bb12de2df69d00074de5..11658fdbf55450f5f0d4679e247ff65a4b37151e 100644
--- a/lib/torch_utils/training_stats.py
+++ b/lib/torch_utils/training_stats.py
@@ -5,7 +5,6 @@
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
@@ -20,17 +19,19 @@ from . import misc
#----------------------------------------------------------------------------
-_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
-_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
-_counter_dtype = torch.float64 # Data type to use for the internal counters.
-_rank = 0 # Rank of the current process.
-_sync_device = None # Device to use for multiprocess communication. None = single-process.
-_sync_called = False # Has _sync() been called yet?
-_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
-_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
+_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
+_counter_dtype = torch.float64 # Data type to use for the internal counters.
+_rank = 0 # Rank of the current process.
+_sync_device = None # Device to use for multiprocess communication. None = single-process.
+_sync_called = False # Has _sync() been called yet?
+# Running counters on each device, updated by report(): name => device => torch.Tensor
+_counters = dict()
+_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
#----------------------------------------------------------------------------
+
def init_multiprocessing(rank, sync_device):
r"""Initializes `torch_utils.training_stats` for collecting statistics
across multiple processes.
@@ -50,8 +51,10 @@ def init_multiprocessing(rank, sync_device):
_rank = rank
_sync_device = sync_device
+
#----------------------------------------------------------------------------
+
@misc.profiled_function
def report(name, value):
r"""Broadcasts the given set of scalars to all interested instances of
@@ -98,8 +101,10 @@ def report(name, value):
_counters[name][device].add_(moments)
return value
+
#----------------------------------------------------------------------------
+
def report0(name, value):
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
but ignores any scalars provided by the other processes.
@@ -108,8 +113,10 @@ def report0(name, value):
report(name, value if _rank == 0 else [])
return value
+
#----------------------------------------------------------------------------
+
class Collector:
r"""Collects the scalars broadcasted by `report()` and `report0()` and
computes their long-term averages (mean and standard deviation) over
@@ -220,7 +227,9 @@ class Collector:
"""
stats = dnnlib.EasyDict()
for name in self.names():
- stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
+ stats[name] = dnnlib.EasyDict(
+ num=self.num(name), mean=self.mean(name), std=self.std(name)
+ )
return stats
def __getitem__(self, name):
@@ -229,8 +238,10 @@ class Collector:
"""
return self.mean(name)
+
#----------------------------------------------------------------------------
+
def _sync(names):
r"""Synchronize the global cumulative counters across devices and
processes. Called internally by `Collector.update()`.
@@ -265,4 +276,5 @@ def _sync(names):
# Return name-value pairs.
return [(name, _cumulative[name]) for name in names]
+
#----------------------------------------------------------------------------
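
For reference, a single-process sketch of the training_stats API reformatted above; the loss values are synthetic and the 'Loss/G' name is illustrative:

    import torch
    from lib.torch_utils import training_stats

    collector = training_stats.Collector()        # defaults collect every reported name

    for step in range(100):
        fake_loss = torch.rand([])                # stand-in for a real training loss
        training_stats.report('Loss/G', fake_loss)

    collector.update()                            # sync the counters collected so far
    print(collector.num('Loss/G'), collector.mean('Loss/G'), collector.std('Loss/G'))
    print(collector['Loss/G'])                    # __getitem__ is shorthand for mean()
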
diff --git a/requirements.txt b/requirements.txt
index 4faea63e38f40eea0c1f3595a1485ba524c3f73f..f29715018cec6e3a8cea1b896e27bbb0eb05e496 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,19 +3,18 @@ scikit-image
trimesh
rtree
pytorch_lightning
-kornia
+kornia>0.4.0
chumpy
opencv-python
opencv_contrib_python
scikit-learn
protobuf
-pymeshlab
dataclasses
mediapipe
einops
boto3
+open3d
tinyobjloader==2.0.0rc7
git+https://github.com/facebookresearch/pytorch3d.git
git+https://github.com/YuliangXiu/neural_voxelization_layer.git
git+https://github.com/YuliangXiu/rembg.git
-git+https://github.com/mmolero/pypoisson.git
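
The dependency swap above drops pypoisson and adds open3d, presumably so that Poisson surface reconstruction goes through open3d instead. A minimal open3d sketch of that capability; the point cloud and the depth value are synthetic placeholders, not taken from this patch:

    import numpy as np
    import open3d as o3d

    pts = np.random.rand(2000, 3)                 # synthetic point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts)
    pcd.estimate_normals()                        # Poisson reconstruction needs normals

    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=8
    )
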