diff --git a/.gitignore b/.gitignore
index bd973b50e5620cc6ca0a87c6ba5fe9790f643f64..f660fd6c046a92a58b99b25b3a84e0634467daf3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,5 @@ build
 dist
 *egg-info
 *.so
-run.sh
\ No newline at end of file
+run.sh
+*.log
\ No newline at end of file
diff --git a/README.md b/README.md
index 8066a33a5ad4f0c8d395639b723127751b02cf0b..83f9681e81c28a2678137d2ee916284a38c2e69b 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
     <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
     <a href="https://cupy.dev/"><img alt="cupy" src="https://img.shields.io/badge/-Cupy-46C02B?logo=numpy&logoColor=white"></a>
     <a href="https://twitter.com/yuliangxiu"><img alt='Twitter' src="https://img.shields.io/twitter/follow/yuliangxiu?label=%40yuliangxiu"></a>
+    <a href="https://discord.gg/Vqa7KBGRyk"><img alt="discord invitation link" src="https://dcbadge.vercel.app/api/server/Vqa7KBGRyk?style=flat"></a>
     <br></br>
     <a href='https://colab.research.google.com/drive/1YRgwoRCZIrSB2e7auEWFyG10Xzjbrbno?usp=sharing'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Google Colab'></a>
     <a href='https://github.com/YuliangXiu/ECON/blob/master/docs/installation-docker.md'><img src='https://img.shields.io/badge/Docker-9cf.svg?logo=Docker' alt='Docker'></a>
@@ -35,7 +36,6 @@
     </a>
     <a href='https://xiuyuliang.cn/econ/'>
       <img src='https://img.shields.io/badge/ECON-Page-orange?style=for-the-badge&logo=Google%20chrome&logoColor=white&labelColor=D35400' alt='Project Page'></a>
-    <a href="https://discord.gg/Vqa7KBGRyk"><img src="https://img.shields.io/discord/940240966844035082?color=7289DA&labelColor=4a64bd&logo=discord&logoColor=white&style=for-the-badge"></a>
     <a href="https://youtu.be/j5hw4tsWpoY"><img alt="youtube views" title="Subscribe to my YouTube channel" src="https://img.shields.io/youtube/views/j5hw4tsWpoY?logo=youtube&labelColor=ce4630&style=for-the-badge"/></a>
   </p>
 </p>
diff --git a/apps/IFGeo.py b/apps/IFGeo.py
index 8cb033d8d3fbd597ac80526d3d0c691451975685..45e892a655fd894d59ee925f32a62d378370bd49 100644
--- a/apps/IFGeo.py
+++ b/apps/IFGeo.py
@@ -14,11 +14,12 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from lib.common.seg3d_lossless import Seg3dLossless
-from lib.common.train_util import *
-import torch
 import numpy as np
 import pytorch_lightning as pl
+import torch
+
+from lib.common.seg3d_lossless import Seg3dLossless
+from lib.common.train_util import *
 
 torch.backends.cudnn.benchmark = True
 
diff --git a/apps/Normal.py b/apps/Normal.py
index 235c0aef05914ef040f1495843d339c758ebd9f3..94176f9996753680c3ae8608a485093deae14822 100644
--- a/apps/Normal.py
+++ b/apps/Normal.py
@@ -1,9 +1,10 @@
-from lib.net import NormalNet
-from lib.common.train_util import batch_mean
-import torch
 import numpy as np
-from skimage.transform import resize
 import pytorch_lightning as pl
+import torch
+from skimage.transform import resize
+
+from lib.common.train_util import batch_mean
+from lib.net import NormalNet
 
 
 class Normal(pl.LightningModule):
diff --git a/apps/avatarizer.py b/apps/avatarizer.py
index 10c35fa3d31d844685c1e17d3c6e8714426891c0..cc48a600f177f7892b68bb9e3a0b8e70ad733d56 100644
--- a/apps/avatarizer.py
+++ b/apps/avatarizer.py
@@ -1,17 +1,25 @@
-import numpy as np
-import trimesh
-import torch
 import argparse
+import os
 import os.path as osp
-import lib.smplx as smplx
+
+import numpy as np
+import torch
+import trimesh
 from pytorch3d.ops import SubdivideMeshes
 from pytorch3d.structures import Meshes
-
-from lib.smplx.lbs import general_lbs
-from lib.dataset.mesh_util import keep_largest, poisson
 from scipy.spatial import cKDTree
-from lib.dataset.mesh_util import SMPLX
+
+import lib.smplx as smplx
 from lib.common.local_affine import register
+from lib.dataset.mesh_util import (
+    SMPLX,
+    export_obj,
+    keep_largest,
+    o3d_ransac,
+    poisson,
+    remesh_laplacian,
+)
+from lib.smplx.lbs import general_lbs
 
 # loading cfg file
 parser = argparse.ArgumentParser()
@@ -22,12 +30,18 @@ args = parser.parse_args()
 smplx_container = SMPLX()
 device = torch.device(f"cuda:{args.gpu}")
 
+# load the SMPL-X parameters and the clothed mesh reconstructed by ECON
 prefix = f"./results/econ/obj/{args.name}"
 smpl_path = f"{prefix}_smpl_00.npy"
-econ_path = f"{prefix}_0_full.obj"
-
 smplx_param = np.load(smpl_path, allow_pickle=True).item()
+
+# export the raw ECON mesh (with pre-computed vertex normals) as PLY
+econ_path = f"{prefix}_0_full.obj"
 econ_obj = trimesh.load(econ_path)
+assert (econ_obj.vertex_normals.shape[1] == 3)
+econ_obj.export(f"{prefix}_econ_raw.ply")
+
+# align econ with SMPL-X
 econ_obj.vertices *= np.array([1.0, -1.0, -1.0])
 econ_obj.vertices /= smplx_param["scale"].cpu().numpy()
 econ_obj.vertices -= smplx_param["transl"].cpu().numpy()
@@ -49,6 +63,7 @@ smpl_model = smplx.create(
 
 smpl_out_lst = []
 
+# run SMPL-X forward passes for the T-pose, DA-pose, and the original pose
 for pose_type in ["t-pose", "da-pose", "pose"]:
     smpl_out_lst.append(
         smpl_model(
@@ -67,6 +82,12 @@ for pose_type in ["t-pose", "da-pose", "pose"]:
         )
     )
 
+# -------------------------- align econ and SMPL-X in DA-pose space ------------------------- #
+# 1. find the vertex-correspondence between SMPL-X and econ
+# 2. ECON + SMPL-X: posed space --> T-pose space --> DA-pose space
+# 3. ECON (w/o hands & over-stretched faces) + SMPL-X (w/ hands & registered inpainting parts)
+# ------------------------------------------------------------------------------------------- #
+
 smpl_verts = smpl_out_lst[2].vertices.detach()[0]
 smpl_tree = cKDTree(smpl_verts.cpu().numpy())
 dist, idx = smpl_tree.query(econ_obj.vertices, k=5)
@@ -143,14 +164,25 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
     smpl_da_body.remove_unreferenced_vertices()
 
     smpl_hand = smpl_da.copy()
-    smpl_hand.update_faces(smplx_container.smplx_mano_vertex_mask.numpy()[smpl_hand.faces].all(axis=1))
+    smpl_hand.update_faces(
+        smplx_container.smplx_mano_vertex_mask.numpy()[smpl_hand.faces].all(axis=1)
+    )
     smpl_hand.remove_unreferenced_vertices()
     econ_da = sum([smpl_hand, smpl_da_body, econ_da_body])
-    econ_da = poisson(econ_da, f"{prefix}_econ_da.obj", depth=10, decimation=False)
+    econ_da = poisson(econ_da, f"{prefix}_econ_da.obj", depth=10, face_count=50000)
+    econ_da = remesh_laplacian(econ_da, f"{prefix}_econ_da.obj")
 else:
     econ_da = trimesh.load(f"{prefix}_econ_da.obj")
     smpl_da = trimesh.load(f"{prefix}_smpl_da.obj", maintain_orders=True, process=False)
 
+# ---------------------- SMPL-X compatible ECON ---------------------- #
+# 1. Find the new vertex-correspondence between NEW ECON and SMPL-X
+# 2. Build the new J_regressor, lbs_weights, posedirs
+# 3. Canonicalize the NEW ECON
+# ------------------------------------------------------------------- #
+
+print("Start building the SMPL-X compatible ECON model...")
+
 smpl_tree = cKDTree(smpl_da.vertices)
 dist, idx = smpl_tree.query(econ_da.vertices, k=5)
 knn_weights = np.exp(-dist**2)
@@ -167,19 +199,137 @@ econ_posedirs = (
 econ_J_regressor /= econ_J_regressor.sum(dim=1, keepdims=True).clip(min=1e-10)
 econ_lbs_weights /= econ_lbs_weights.sum(dim=1, keepdims=True)
 
-# re-compute da-pose rot_mat for ECON
 rot_mat_da = smpl_out_lst[1].vertex_transformation.detach()[0][idx[:, 0]]
 econ_da_verts = torch.tensor(econ_da.vertices).float()
-econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat(
-    [econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]], dim=1
-).unsqueeze(-1)
+econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat([
+    econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]
+],
+                                                        dim=1).unsqueeze(-1)
 econ_cano_verts = econ_cano_verts[:, :3, 0].double()
 
 # ----------------------------------------------------
-# use any SMPL-X pose to animate ECON reconstruction
+# use the original pose to animate the ECON reconstruction
 # ----------------------------------------------------
 
 new_pose = smpl_out_lst[2].full_pose
+# new_pose[:, :3] = 0.
+
+posed_econ_verts, _ = general_lbs(
+    pose=new_pose,
+    v_template=econ_cano_verts.unsqueeze(0),
+    posedirs=econ_posedirs,
+    J_regressor=econ_J_regressor,
+    parents=smpl_model.parents,
+    lbs_weights=econ_lbs_weights
+)
+
+aligned_econ_verts = posed_econ_verts[0].detach().cpu().numpy()
+aligned_econ_verts += smplx_param["transl"].cpu().numpy()
+aligned_econ_verts *= smplx_param["scale"].cpu().numpy() * np.array([1.0, -1.0, -1.0])
+econ_pose = trimesh.Trimesh(aligned_econ_verts, econ_da.faces)
+assert (econ_pose.vertex_normals.shape[1] == 3)
+econ_pose.export(f"{prefix}_econ_pose.ply")
+
+# -------------------------------------------------------------------------
+# Align posed ECON with original ECON, for pixel-aligned texture extraction
+# -------------------------------------------------------------------------
+
+print("Start ICP registration between posed & original ECON...")
+import open3d as o3d
+
+source = o3d.io.read_point_cloud(f"{prefix}_econ_pose.ply")
+target = o3d.io.read_point_cloud(f"{prefix}_econ_raw.ply")
+trans_init = o3d_ransac(source, target)
+icp_criteria = o3d.pipelines.registration.ICPConvergenceCriteria(
+    relative_fitness=0.000001, relative_rmse=0.000001, max_iteration=100
+)
+
+reg_p2l = o3d.pipelines.registration.registration_icp(
+    source,
+    target,
+    0.1,
+    trans_init,
+    o3d.pipelines.registration.TransformationEstimationPointToPlane(),
+    criteria=icp_criteria
+)
+econ_pose.apply_transform(reg_p2l.transformation)
+
+cache_path = f"{prefix.replace('obj','cache')}"
+os.makedirs(cache_path, exist_ok=True)
+
+# -----------------------------------------------------------------
+# create UV texture (.obj .mtl .png) from posed ECON reconstruction
+# -----------------------------------------------------------------
+
+print("Start Color mapping...")
+from PIL import Image
+from torchvision import transforms
+
+from lib.common.render import query_color
+from lib.common.render_utils import Pytorch3dRasterizer
+
+if not osp.exists(f"{prefix}_econ_icp_rgb.ply"):
+    masked_image = f"./results/econ/png/{args.name}_cloth.png"
+    tensor_image = transforms.ToTensor()(Image.open(masked_image))[:, :, :512]
+    final_colors = query_color(
+        torch.tensor(econ_pose.vertices).float(),
+        torch.tensor(econ_pose.faces).long(),
+        ((tensor_image - 0.5) * 2.0).unsqueeze(0).to(device),
+        device=device,
+        paint_normal=False,
+    )
+    final_colors[final_colors == tensor_image[:, 0, 0] * 255.0] = 0.0
+    final_colors = final_colors.detach().cpu().numpy()
+    econ_pose.visual.vertex_colors = final_colors
+    econ_pose.export(f"{prefix}_econ_icp_rgb.ply")
+else:
+    mesh = trimesh.load(f"{prefix}_econ_icp_rgb.ply")
+    final_colors = mesh.visual.vertex_colors[:, :3]
+
+print("Start UV texture generation...")
+
+# Generate UV coords
+v_np = econ_pose.vertices
+f_np = econ_pose.faces
+
+vt_cache = osp.join(cache_path, "vt.pt")
+ft_cache = osp.join(cache_path, "ft.pt")
+
+if osp.exists(vt_cache) and osp.exists(ft_cache):
+    vt = torch.load(vt_cache).to(device)
+    ft = torch.load(ft_cache).to(device)
+else:
+    import xatlas
+    atlas = xatlas.Atlas()
+    atlas.add_mesh(v_np, f_np)
+    chart_options = xatlas.ChartOptions()
+    chart_options.max_iterations = 4
+    atlas.generate(chart_options=chart_options)
+    vmapping, ft_np, vt_np = atlas[0]
+
+    vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
+    ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
+    torch.save(vt.cpu(), vt_cache)
+    torch.save(ft.cpu(), ft_cache)
+
+# UV texture rendering
+uv_rasterizer = Pytorch3dRasterizer(image_size=512, device=device)
+texture_npy = uv_rasterizer.get_texture(
+    torch.cat([(vt - 0.5) * 2.0, torch.ones_like(vt[:, :1])], dim=1),
+    ft,
+    torch.tensor(v_np).unsqueeze(0).float(),
+    torch.tensor(f_np).unsqueeze(0).long(),
+    torch.tensor(final_colors).unsqueeze(0).float() / 255.0,
+)
+
+Image.fromarray((texture_npy * 255.0).astype(np.uint8)).save(f"{cache_path}/texture.png")
+
+# UV mask for TEXTure (https://readpaper.com/paper/4720151447010820097)
+texture_npy[texture_npy.sum(axis=2) == 0.0] = 1.0
+Image.fromarray((texture_npy * 255.0).astype(np.uint8)).save(f"{cache_path}/mask.png")
+
+# generate da-pose vertices
+new_pose = smpl_out_lst[1].full_pose
 new_pose[:, :3] = 0.
 
 posed_econ_verts, _ = general_lbs(
@@ -191,5 +341,8 @@ posed_econ_verts, _ = general_lbs(
     lbs_weights=econ_lbs_weights
 )
 
-econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_da.faces)
-econ_pose.export(f"{prefix}_econ_pose.obj")
+# export mtl file
+mtl_string = f"newmtl mat0 \nKa 1.000000 1.000000 1.000000 \nKd 1.000000 1.000000 1.000000 \nKs 0.000000 0.000000 0.000000 \nTr 1.000000 \nillum 1 \nNs 0.000000\nmap_Kd texture.png"
+with open(f"{cache_path}/material.mtl", 'w') as file:
+    file.write(mtl_string)
+export_obj(posed_econ_verts[0].detach().cpu().numpy(), f_np, vt, ft, f"{cache_path}/mesh.obj")
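The avatarizer.py changes above transfer SMPL-X skinning attributes (lbs_weights, J_regressor, posedirs) onto the ECON mesh with a weighted k-nearest-neighbour lookup in DA-pose space, then canonicalize and re-pose the mesh via general_lbs. The minimal sketch below (not part of the diff; all shapes and values are illustrative stand-ins) isolates just the kNN weight-transfer step:

```python
import numpy as np
from scipy.spatial import cKDTree

smpl_verts = np.random.rand(100, 3)          # stand-in for SMPL-X DA-pose vertices
smpl_lbs_weights = np.random.rand(100, 55)   # stand-in for per-vertex skinning weights
econ_verts = np.random.rand(500, 3)          # stand-in for ECON DA-pose vertices

# each ECON vertex borrows skinning weights from its 5 nearest SMPL-X vertices,
# blended with a Gaussian of the point-to-point distance (as in the hunk above)
tree = cKDTree(smpl_verts)
dist, idx = tree.query(econ_verts, k=5)                  # both (500, 5)
knn_weights = np.exp(-dist**2)
knn_weights /= knn_weights.sum(axis=1, keepdims=True)

econ_lbs_weights = (smpl_lbs_weights[idx] * knn_weights[..., None]).sum(axis=1)
econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)   # rows sum to 1
```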
diff --git a/apps/benchmark.py b/apps/benchmark.py
index 3c95b0e209439af0c221742a3f32168e1fb5c280..504e2537e7c38794390345a0b0dca3b3667540db 100644
--- a/apps/benchmark.py
+++ b/apps/benchmark.py
@@ -14,28 +14,29 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import warnings
 import logging
+import warnings
 
 warnings.filterwarnings("ignore")
 logging.getLogger("lightning").setLevel(logging.ERROR)
 logging.getLogger("trimesh").setLevel(logging.ERROR)
 
-import torch
 import argparse
 import os
 
+import torch
 from termcolor import colored
 from tqdm.auto import tqdm
-from apps.Normal import Normal
+
 from apps.IFGeo import IFGeo
-from lib.common.config import cfg
+from apps.Normal import Normal
 from lib.common.BNI import BNI
 from lib.common.BNI_utils import save_normal_tensor
+from lib.common.config import cfg
+from lib.common.voxelize import VoxelGrid
 from lib.dataset.EvalDataset import EvalDataset
 from lib.dataset.Evaluator import Evaluator
 from lib.dataset.mesh_util import *
-from lib.common.voxelize import VoxelGrid
 
 torch.backends.cudnn.benchmark = True
 speed_analysis = False
@@ -62,8 +63,14 @@ if __name__ == "__main__":
     device = torch.device("cuda:0")
 
     cfg_test_list = [
-        "dataset.rotation_num", 3, "bni.use_smpl", ["hand"], "bni.use_ifnet", args.ifnet,
-        "bni.cut_intersection", True,
+        "dataset.rotation_num",
+        3,
+        "bni.use_smpl",
+        ["hand"],
+        "bni.use_ifnet",
+        args.ifnet,
+        "bni.cut_intersection",
+        True,
     ]
 
     # # if w/ RenderPeople+CAPE
@@ -176,12 +183,10 @@ if __name__ == "__main__":
 
                 # mesh completion via IF-net
                 in_tensor.update(
-                    dataset.depth_to_voxel(
-                        {
-                            "depth_F": BNI_object.F_depth.unsqueeze(0).to(device),
-                            "depth_B": BNI_object.B_depth.unsqueeze(0).to(device)
-                        }
-                    )
+                    dataset.depth_to_voxel({
+                        "depth_F": BNI_object.F_depth.unsqueeze(0).to(device), "depth_B":
+                        BNI_object.B_depth.unsqueeze(0).to(device)
+                    })
                 )
 
                 occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
diff --git a/apps/infer.py b/apps/infer.py
index 97b80c437270711cccf0df2ff14779ecd1fc0ef7..3a9d85c6d92f6e92cbb4acd5f2f905163f079c4f 100644
--- a/apps/infer.py
+++ b/apps/infer.py
@@ -14,35 +14,37 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import warnings
 import logging
+import warnings
 
 warnings.filterwarnings("ignore")
 logging.getLogger("lightning").setLevel(logging.ERROR)
 logging.getLogger("trimesh").setLevel(logging.ERROR)
 
-import torch, torchvision
-import trimesh
-import numpy as np
 import argparse
 import os
 
+import numpy as np
+import torch
+import torchvision
+import trimesh
+from pytorch3d.ops import SubdivideMeshes
 from termcolor import colored
 from tqdm.auto import tqdm
-from apps.Normal import Normal
+
 from apps.IFGeo import IFGeo
-from pytorch3d.ops import SubdivideMeshes
-from lib.common.config import cfg
-from lib.common.render import query_color
-from lib.common.train_util import init_loss, Format
-from lib.common.imutils import blend_rgb_norm
+from apps.Normal import Normal
 from lib.common.BNI import BNI
 from lib.common.BNI_utils import save_normal_tensor
-from lib.dataset.TestDataset import TestDataset
+from lib.common.config import cfg
+from lib.common.imutils import blend_rgb_norm
 from lib.common.local_affine import register
-from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
-from lib.dataset.mesh_util import *
+from lib.common.render import query_color
+from lib.common.train_util import Format, init_loss
 from lib.common.voxelize import VoxelGrid
+from lib.dataset.mesh_util import *
+from lib.dataset.TestDataset import TestDataset
+from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
 
 torch.backends.cudnn.benchmark = True
 
@@ -146,9 +148,8 @@ if __name__ == "__main__":
         os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
 
         in_tensor = {
-            "smpl_faces": data["smpl_faces"],
-            "image": data["img_icon"].to(device),
-            "mask": data["img_mask"].to(device)
+            "smpl_faces": data["smpl_faces"], "image": data["img_icon"].to(device), "mask":
+            data["img_mask"].to(device)
         }
 
         # The optimizer and variables
@@ -157,9 +158,11 @@ if __name__ == "__main__":
         optimed_betas = data["betas"].requires_grad_(True)
         optimed_orient = data["global_orient"].requires_grad_(True)
 
-        optimizer_smpl = torch.optim.Adam(
-            [optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True
-        )
+        optimizer_smpl = torch.optim.Adam([
+            optimed_pose, optimed_trans, optimed_betas, optimed_orient
+        ],
+                                          lr=1e-2,
+                                          amsgrad=True)
         scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
             optimizer_smpl,
             mode="min",
@@ -234,9 +237,9 @@ if __name__ == "__main__":
                 )
 
                 smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
-                smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor(
-                    [1.0, 1.0, -1.0]
-                ).to(device)
+                smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor([
+                    1.0, 1.0, -1.0
+                ]).to(device)
 
                 # landmark errors
                 smpl_joints_3d = (
@@ -280,13 +283,11 @@ if __name__ == "__main__":
 
                 # BUG: PyTorch3D silhouette renderer generates dilated mask
                 bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
-                smpl_arr_fake = torch.cat(
-                    [
-                        in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
-                        in_tensor["T_normal_B"][:, 0].ne(bg_value).float()
-                    ],
-                    dim=-1
-                )
+                smpl_arr_fake = torch.cat([
+                    in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
+                    in_tensor["T_normal_B"][:, 0].ne(bg_value).float()
+                ],
+                                          dim=-1)
 
                 body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)
                                ).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
@@ -322,22 +323,18 @@ if __name__ == "__main__":
                 # save intermediate results
                 if (i == args.loop_smpl - 1) and (not args.novis):
 
-                    per_loop_lst.extend(
-                        [
-                            in_tensor["image"],
-                            in_tensor["T_normal_F"],
-                            in_tensor["normal_F"],
-                            diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
-                        ]
-                    )
-                    per_loop_lst.extend(
-                        [
-                            in_tensor["image"],
-                            in_tensor["T_normal_B"],
-                            in_tensor["normal_B"],
-                            diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
-                        ]
-                    )
+                    per_loop_lst.extend([
+                        in_tensor["image"],
+                        in_tensor["T_normal_F"],
+                        in_tensor["normal_F"],
+                        diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
+                    ])
+                    per_loop_lst.extend([
+                        in_tensor["image"],
+                        in_tensor["T_normal_B"],
+                        in_tensor["normal_B"],
+                        diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
+                    ])
                     per_data_lst.append(
                         get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")
                     )
@@ -357,13 +354,11 @@ if __name__ == "__main__":
         if not args.novis:
             img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
             torchvision.utils.save_image(
-                torch.cat(
-                    [
-                        data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
-                        (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
-                    ],
-                    dim=3
-                ), img_crop_path
+                torch.cat([
+                    data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
+                    (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
+                ],
+                          dim=3), img_crop_path
             )
 
             rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
@@ -392,27 +387,25 @@ if __name__ == "__main__":
                 smpl_obj.export(smpl_obj_path)
                 smpl_info = {
                     "betas":
-                        optimed_betas[idx].detach().cpu().unsqueeze(0),
+                    optimed_betas[idx].detach().cpu().unsqueeze(0),
                     "body_pose":
-                        rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()
-                                                     ).cpu().unsqueeze(0),
+                    rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()
+                                                 ).cpu().unsqueeze(0),
                     "global_orient":
-                        rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()
-                                                     ).cpu().unsqueeze(0),
+                    rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()
+                                                 ).cpu().unsqueeze(0),
                     "transl":
-                        optimed_trans[idx].detach().cpu(),
+                    optimed_trans[idx].detach().cpu(),
                     "expression":
-                        data["exp"][idx].cpu().unsqueeze(0),
+                    data["exp"][idx].cpu().unsqueeze(0),
                     "jaw_pose":
-                        rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
+                    rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
                     "left_hand_pose":
-                        rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]
-                                                     ).cpu().unsqueeze(0),
+                    rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]).cpu().unsqueeze(0),
                     "right_hand_pose":
-                        rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]
-                                                     ).cpu().unsqueeze(0),
+                    rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]).cpu().unsqueeze(0),
                     "scale":
-                        data["scale"][idx].cpu(),
+                    data["scale"][idx].cpu(),
                 }
                 np.save(
                     smpl_obj_path.replace(".obj", ".npy"),
@@ -434,8 +427,8 @@ if __name__ == "__main__":
 
         per_data_lst = []
 
-        batch_smpl_verts = in_tensor["smpl_verts"].detach(
-        ) * torch.tensor([1.0, -1.0, 1.0], device=device)
+        batch_smpl_verts = in_tensor["smpl_verts"].detach() * torch.tensor([1.0, -1.0, 1.0],
+                                                                           device=device)
         batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]
 
         in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
@@ -491,12 +484,10 @@ if __name__ == "__main__":
 
                 # mesh completion via IF-net
                 in_tensor.update(
-                    dataset.depth_to_voxel(
-                        {
-                            "depth_F": BNI_object.F_depth.unsqueeze(0),
-                            "depth_B": BNI_object.B_depth.unsqueeze(0)
-                        }
-                    )
+                    dataset.depth_to_voxel({
+                        "depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B":
+                        BNI_object.B_depth.unsqueeze(0)
+                    })
                 )
 
                 occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
diff --git a/apps/multi_render.py b/apps/multi_render.py
index 4088440757ce81137aaad7685d9df4b53b1c1383..10bdf4e52b3b69e756906d402124d2359c9403f0 100644
--- a/apps/multi_render.py
+++ b/apps/multi_render.py
@@ -1,7 +1,9 @@
-from lib.common.render import Render
-import torch
 import argparse
 
+import torch
+
+from lib.common.render import Render
+
 root = "./results/econ/vid"
 
 # loading cfg file
diff --git a/configs/econ.yaml b/configs/econ.yaml
index f96784da9a111dbc8e1139d3f6e9af641401e9e1..5283393be72b3dc50ed7f88021c4c76515d19a1d 100644
--- a/configs/econ.yaml
+++ b/configs/econ.yaml
@@ -28,7 +28,7 @@ bni:
   lambda1: 1e-4
   boundary_consist: 1e-6
   poisson_depth: 10
-  use_smpl: ["hand", "face"]
+  use_smpl: ["hand"]
   use_ifnet: False
   use_poisson: True
   hand_thres: 8e-2
diff --git a/docs/tricks.md b/docs/tricks.md
index 8fa32714b3f77561705a9c83052a6367ed7cae58..17ba7e98f502d385fc075ae756656a639b0502a5 100644
--- a/docs/tricks.md
+++ b/docs/tricks.md
@@ -2,7 +2,7 @@
 
 ### If the reconstructed geometry is not satisfying, play with the adjustable parameters in _config/econ.yaml_
 
-- `use_smpl: ["hand", "face"]`
+- `use_smpl: ["hand"]`
   - [ ]: don't use either hands or face parts from SMPL-X
   - ["hand"]: only use the **visible** hands from SMPL-X
   - ["hand", "face"]: use both **visible** hands and face from SMPL-X
diff --git a/lib/common/BNI.py b/lib/common/BNI.py
index 2e4365e012c3c862b5f009bb879c4dda45490db4..1ae381db6786bbf3c07c1b423467e39457fe651a 100644
--- a/lib/common/BNI.py
+++ b/lib/common/BNI.py
@@ -1,10 +1,12 @@
-from lib.common.BNI_utils import (
-    verts_inverse_transform, depth_inverse_transform, double_side_bilateral_normal_integration
-)
-
 import torch
 import trimesh
 
+from lib.common.BNI_utils import (
+    depth_inverse_transform,
+    double_side_bilateral_normal_integration,
+    verts_inverse_transform,
+)
+
 
 class BNI:
     def __init__(self, dir_path, name, BNI_dict, cfg, device):
@@ -84,8 +86,9 @@ class BNI:
 
 if __name__ == "__main__":
 
-    import numpy as np
     import os.path as osp
+
+    import numpy as np
     from tqdm import tqdm
 
     root = "/home/yxiu/Code/ECON/results/examples/BNI"
diff --git a/lib/common/BNI_utils.py b/lib/common/BNI_utils.py
index 6fa64c44c9049d0a7a37fb3e6d252549e180a931..235910fb51c263c51410afc247795f224456dd78 100644
--- a/lib/common/BNI_utils.py
+++ b/lib/common/BNI_utils.py
@@ -1,13 +1,23 @@
-import torch
-import trimesh
-import cv2, os
-from PIL import Image
+import os
 import os.path as osp
+
 import cupy as cp
+import cv2
 import numpy as np
-from cupyx.scipy.sparse import csr_matrix, vstack, hstack, spdiags, diags, coo_matrix
+import torch
+import trimesh
+from cupyx.scipy.sparse import (
+    coo_matrix,
+    csr_matrix,
+    diags,
+    hstack,
+    spdiags,
+    vstack,
+)
 from cupyx.scipy.sparse.linalg import cg
+from PIL import Image
 from tqdm.auto import tqdm
+
 from lib.dataset.mesh_util import clean_floats
 
 
@@ -68,13 +78,11 @@ def mean_value_cordinates(inner_pts, contour_pts):
     body_edges_c = np.roll(body_edges_a, shift=-1, axis=1)
     body_edges_b = np.sqrt(((contour_pts - np.roll(contour_pts, shift=-1, axis=0))**2).sum(axis=1))
 
-    body_edges = np.concatenate(
-        [
-            body_edges_a[..., None], body_edges_c[..., None],
-            np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts))
-        ],
-        axis=-1
-    )
+    body_edges = np.concatenate([
+        body_edges_a[..., None], body_edges_c[..., None],
+        np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts))
+    ],
+                                axis=-1)
 
     body_cos = (body_edges[:, :, 0]**2 + body_edges[:, :, 1]**2 -
                 body_edges[:, :, 2]**2) / (2 * body_edges[:, :, 0] * body_edges[:, :, 1])
@@ -167,9 +175,9 @@ def verts_transform(t, depth_scale):
     t_copy = t.clone()
     t_copy *= depth_scale * 0.5
     t_copy += depth_scale * 0.5
-    t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor(
-        [0.0, 0.0, depth_scale]
-    )
+    t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor([
+        0.0, 0.0, depth_scale
+    ])
 
     return t_copy
 
@@ -342,15 +350,13 @@ def construct_facets_from(mask):
     facet_bottom_left_mask = move_bottom(facet_top_left_mask)
     facet_bottom_right_mask = move_bottom_right(facet_top_left_mask)
 
-    return cp.hstack(
-        (
-            4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)),
-            idx[facet_top_left_mask][:, None],
-            idx[facet_bottom_left_mask][:, None],
-            idx[facet_bottom_right_mask][:, None],
-            idx[facet_top_right_mask][:, None],
-        )
-    ).astype(int)
+    return cp.hstack((
+        4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)),
+        idx[facet_top_left_mask][:, None],
+        idx[facet_bottom_left_mask][:, None],
+        idx[facet_bottom_right_mask][:, None],
+        idx[facet_top_right_mask][:, None],
+    )).astype(int)
 
 
 def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
@@ -614,7 +620,7 @@ def double_side_bilateral_normal_integration(
 
         energy_list.append(energy)
         relative_energy = cp.abs(energy - energy_old) / energy_old
-        
+
         # print(f"step {i + 1}/{max_iter} energy: {energy:.3e}"
         #       f" relative energy: {relative_energy:.3e}")
 
@@ -640,13 +646,11 @@ def double_side_bilateral_normal_integration(
             B_verts = verts_inverse_transform(torch.as_tensor(vertices_back).float(), 256.0)
 
             F_B_verts = torch.cat((F_verts, B_verts), dim=0)
-            F_B_faces = torch.cat(
-                (
-                    torch.as_tensor(faces_front_).long(),
-                    torch.as_tensor(faces_back_).long() + faces_front_.max() + 1
-                ),
-                dim=0
-            )
+            F_B_faces = torch.cat((
+                torch.as_tensor(faces_front_).long(),
+                torch.as_tensor(faces_back_).long() + faces_front_.max() + 1
+            ),
+                                  dim=0)
 
             front_surf = trimesh.Trimesh(F_verts, faces_front_)
             back_surf = trimesh.Trimesh(B_verts, faces_back_)
@@ -690,12 +694,12 @@ def double_side_bilateral_normal_integration(
     back_mesh = clean_floats(trimesh.Trimesh(vertices_back, faces_back))
 
     result = {
-        "F_verts": torch.as_tensor(front_mesh.vertices).float(),
-        "F_faces": torch.as_tensor(front_mesh.faces).long(),
-        "B_verts": torch.as_tensor(back_mesh.vertices).float(),
-        "B_faces": torch.as_tensor(back_mesh.faces).long(),
-        "F_depth": torch.as_tensor(depth_map_front_est).float(),
-        "B_depth": torch.as_tensor(depth_map_back_est).float()
+        "F_verts": torch.as_tensor(front_mesh.vertices).float(), "F_faces": torch.as_tensor(
+            front_mesh.faces
+        ).long(), "B_verts": torch.as_tensor(back_mesh.vertices).float(), "B_faces":
+        torch.as_tensor(back_mesh.faces).long(), "F_depth":
+        torch.as_tensor(depth_map_front_est).float(), "B_depth":
+        torch.as_tensor(depth_map_back_est).float()
     }
 
     return result
diff --git a/lib/common/blender_utils.py b/lib/common/blender_utils.py
deleted file mode 100644
index a02260cc722bd9729dfbeb153543ac5f648deacf..0000000000000000000000000000000000000000
--- a/lib/common/blender_utils.py
+++ /dev/null
@@ -1,383 +0,0 @@
-import bpy
-import sys, os
-from math import radians
-import mathutils
-import bmesh
-
-print(sys.exec_prefix)
-from tqdm import tqdm
-import numpy as np
-
-##################################################
-# Globals
-##################################################
-
-views = 120
-
-render = 'eevee'
-cycles_gpu = False
-
-quality_preview = False
-samples_preview = 16
-samples_final = 256
-
-resolution_x = 512
-resolution_y = 512
-
-shadows = False
-
-# diffuse_color = (57.0/255.0, 108.0/255.0, 189.0/255.0, 1.0)
-# diffuse_color = (18/255., 139/255., 142/255.,1)     #correct
-# diffuse_color = (251/255., 60/255., 60/255.,1)    #wrong
-
-smooth = False
-
-wireframe = False
-line_thickness = 0.1
-quads = False
-
-object_transparent = False
-mouth_transparent = False
-
-compositor_background_image = False
-compositor_image_scale = 1.0
-compositor_alpha = 0.7
-
-##################################################
-# Helper functions
-##################################################
-
-
-def blender_print(*args, **kwargs):
-    print(*args, **kwargs, file=sys.stderr)
-
-
-def using_app():
-    ''' Returns if script is running through Blender application (GUI or background processing)'''
-    return (not sys.argv[0].endswith('.py'))
-
-
-def setup_diffuse_transparent_material(target, color, object_transparent, backface_transparent):
-    ''' Sets up diffuse/transparent material with backface culling in cycles'''
-
-    mat = target.active_material
-    if mat is None:
-        # Create material
-        mat = bpy.data.materials.new(name='Material')
-        target.data.materials.append(mat)
-
-    mat.use_nodes = True
-    nodes = mat.node_tree.nodes
-    for node in nodes:
-        nodes.remove(node)
-
-    node_geometry = nodes.new('ShaderNodeNewGeometry')
-
-    node_diffuse = nodes.new('ShaderNodeBsdfDiffuse')
-    node_diffuse.inputs[0].default_value = color
-
-    node_transparent = nodes.new('ShaderNodeBsdfTransparent')
-    node_transparent.inputs[0].default_value = (1.0, 1.0, 1.0, 1.0)
-
-    node_emission = nodes.new('ShaderNodeEmission')
-    node_emission.inputs[0].default_value = (0.0, 0.0, 0.0, 1.0)
-
-    node_mix = nodes.new(type='ShaderNodeMixShader')
-    if object_transparent:
-        node_mix.inputs[0].default_value = 1.0
-    else:
-        node_mix.inputs[0].default_value = 0.0
-
-    node_mix_mouth = nodes.new(type='ShaderNodeMixShader')
-    if object_transparent or backface_transparent:
-        node_mix_mouth.inputs[0].default_value = 1.0
-    else:
-        node_mix_mouth.inputs[0].default_value = 0.0
-
-    node_mix_backface = nodes.new(type='ShaderNodeMixShader')
-
-    node_output = nodes.new(type='ShaderNodeOutputMaterial')
-
-    links = mat.node_tree.links
-
-    links.new(node_geometry.outputs[6], node_mix_backface.inputs[0])
-
-    links.new(node_diffuse.outputs[0], node_mix.inputs[1])
-    links.new(node_transparent.outputs[0], node_mix.inputs[2])
-    links.new(node_mix.outputs[0], node_mix_backface.inputs[1])
-
-    links.new(node_emission.outputs[0], node_mix_mouth.inputs[1])
-    links.new(node_transparent.outputs[0], node_mix_mouth.inputs[2])
-    links.new(node_mix_mouth.outputs[0], node_mix_backface.inputs[2])
-
-    links.new(node_mix_backface.outputs[0], node_output.inputs[0])
-    return
-
-
-##################################################
-
-
-def setup_scene():
-    global render
-    global cycles_gpu
-    global quality_preview
-    global resolution_x
-    global resolution_y
-    global shadows
-    global wireframe
-    global line_thickness
-    global compositor_background_image
-
-    # Remove default cube
-    if 'Cube' in bpy.data.objects:
-        bpy.data.objects['Cube'].select_set(True)
-        bpy.ops.object.delete()
-
-    scene = bpy.data.scenes['Scene']
-
-    # Setup render engine
-    if render == 'cycles':
-        scene.render.engine = 'CYCLES'
-    else:
-        scene.render.engine = 'BLENDER_EEVEE'
-
-    scene.render.resolution_x = resolution_x
-    scene.render.resolution_y = resolution_y
-    scene.render.resolution_percentage = 100
-    scene.render.film_transparent = True
-    if quality_preview:
-        scene.cycles.samples = samples_preview
-    else:
-        scene.cycles.samples = samples_final
-
-    # Setup Cycles CUDA GPU acceleration if requested
-    if render == 'cycles':
-        if cycles_gpu:
-            print('Activating GPU acceleration')
-            bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
-
-            if bpy.app.version[0] >= 3:
-                cuda_devices = bpy.context.preferences.addons[
-                    'cycles'].preferences.get_devices_for_type(compute_device_type='CUDA')
-            else:
-                (cuda_devices, opencl_devices
-                ) = bpy.context.preferences.addons['cycles'].preferences.get_devices()
-
-            if (len(cuda_devices) < 1):
-                print('ERROR: CUDA GPU acceleration not available')
-                sys.exit(1)
-
-            for cuda_device in cuda_devices:
-                if cuda_device.type == 'CUDA':
-                    cuda_device.use = True
-                    print('Using CUDA device: ' + str(cuda_device.name))
-                else:
-                    cuda_device.use = False
-                    print('Igoring CUDA device: ' + str(cuda_device.name))
-
-            scene.cycles.device = 'GPU'
-            if bpy.app.version[0] < 3:
-                scene.render.tile_x = 256
-                scene.render.tile_y = 256
-        else:
-            scene.cycles.device = 'CPU'
-            if bpy.app.version[0] < 3:
-                scene.render.tile_x = 64
-                scene.render.tile_y = 64
-
-    # Disable Blender 3 denoiser to properly measure Cycles render speed
-    if bpy.app.version[0] >= 3:
-        scene.cycles.use_denoising = False
-
-    # Setup camera
-    camera = bpy.data.objects['Camera']
-    camera.location = (0.0, -3, 1.8)
-    camera.rotation_euler = (radians(74), 0.0, 0)
-    bpy.data.cameras['Camera'].lens = 55
-
-    # Setup light
-
-    # Setup lights
-    light = bpy.data.objects['Light']
-    light.location = (-2, -3.0, 0.0)
-    light.rotation_euler = (radians(90.0), 0.0, 0.0)
-    bpy.data.lights['Light'].type = 'POINT'
-    bpy.data.lights['Light'].energy = 2
-    light.data.cycles.cast_shadow = False
-
-    if 'Sun' not in bpy.data.objects:
-        bpy.ops.object.light_add(type='SUN')
-        light_sun = bpy.context.active_object
-        light_sun.location = (0.0, -3, 0.0)
-        light_sun.rotation_euler = (radians(45.0), 0.0, radians(30))
-        bpy.data.lights['Sun'].energy = 2
-        light_sun.data.cycles.cast_shadow = shadows
-    else:
-        light_sun = bpy.data.objects['Sun']
-
-    if shadows:
-        # Setup shadow catcher
-        bpy.ops.mesh.primitive_plane_add()
-        plane = bpy.context.active_object
-        plane.scale = (5.0, 5.0, 1)
-
-        plane.cycles.is_shadow_catcher = True
-
-        # Exclude plane from diffuse cycles contribution to avoid bright pixel noise in body rendering
-        # plane.cycles_visibility.diffuse = False
-
-        if wireframe:
-            # Unmark freestyle edges
-            bpy.ops.object.mode_set(mode='EDIT')
-            bpy.ops.mesh.mark_freestyle_edge(clear=True)
-            bpy.ops.object.mode_set(mode='OBJECT')
-
-    # Setup freestyle mode for wireframe overlay rendering
-    if wireframe:
-        scene.render.use_freestyle = True
-        scene.render.line_thickness = line_thickness
-        bpy.context.view_layer.freestyle_settings.linesets[0].select_edge_mark = True
-
-        # Disable border edges so that we don't see contour of shadow catcher plane
-        bpy.context.view_layer.freestyle_settings.linesets[0].select_border = False
-    else:
-        scene.render.use_freestyle = False
-
-    if compositor_background_image:
-        # Setup compositing when using background image
-        setup_compositing()
-    else:
-        # Output transparent image when no background is used
-        scene.render.image_settings.color_mode = 'RGBA'
-
-
-##################################################
-
-
-def setup_compositing():
-
-    global compositor_image_scale
-    global compositor_alpha
-
-    # Node editor compositing setup
-    bpy.context.scene.use_nodes = True
-    tree = bpy.context.scene.node_tree
-
-    # Create input image node
-    image_node = tree.nodes.new(type='CompositorNodeImage')
-
-    scale_node = tree.nodes.new(type='CompositorNodeScale')
-    scale_node.inputs[1].default_value = compositor_image_scale
-    scale_node.inputs[2].default_value = compositor_image_scale
-
-    blend_node = tree.nodes.new(type='CompositorNodeAlphaOver')
-    blend_node.inputs[0].default_value = compositor_alpha
-
-    # Link nodes
-    links = tree.links
-    links.new(image_node.outputs[0], scale_node.inputs[0])
-
-    links.new(scale_node.outputs[0], blend_node.inputs[1])
-    links.new(tree.nodes['Render Layers'].outputs[0], blend_node.inputs[2])
-
-    links.new(blend_node.outputs[0], tree.nodes['Composite'].inputs[0])
-
-
-def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
-    '''Render image of given model file'''
-    global smooth
-    global object_transparent
-    global mouth_transparent
-    global compositor_background_image
-    global quads
-
-    path = input_dir + input_file
-
-    # Import object into scene
-    bpy.ops.import_scene.obj(filepath=path)
-    object = bpy.context.selected_objects[0]
-
-    object.rotation_euler = (radians(90.0), 0.0, radians(yaw))
-    z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:, 1])
-    # z_top = np.max(np.array([vert.co for vert in object.data.vertices])[:,1])
-    # blender_print(radians(90.0), z_bottom, z_top)
-    object.location -= mathutils.Vector((0.0, 0.0, z_bottom))
-
-    if quads:
-        bpy.context.view_layer.objects.active = object
-        bpy.ops.object.mode_set(mode='EDIT')
-        bpy.ops.mesh.tris_convert_to_quads()
-        bpy.ops.object.mode_set(mode='OBJECT')
-
-    if smooth:
-        bpy.ops.object.shade_smooth()
-
-    # Mark freestyle edges
-    bpy.context.view_layer.objects.active = object
-    bpy.ops.object.mode_set(mode='EDIT')
-    bpy.ops.mesh.mark_freestyle_edge(clear=False)
-    bpy.ops.object.mode_set(mode='OBJECT')
-
-    if correct:
-        diffuse_color = (18 / 255., 139 / 255., 142 / 255., 1)    #correct
-    else:
-        diffuse_color = (251 / 255., 60 / 255., 60 / 255., 1)    #wrong
-
-    setup_diffuse_transparent_material(object, diffuse_color, object_transparent, mouth_transparent)
-
-    if compositor_background_image:
-        # Set background image
-        image_path = input_dir + input_file.replace('.obj', '_original.png')
-        bpy.context.scene.node_tree.nodes['Image'].image = bpy.data.images.load(image_path)
-
-    # Render
-    bpy.context.scene.render.filepath = os.path.join(output_dir, output_file)
-
-    # Silence console output of bpy.ops.render.render by redirecting stdout to file
-    # Note: Does not actually write the output to file (Windows 7)
-    sys.stdout.flush()
-    old = os.dup(1)
-    os.close(1)
-    os.open('blender_render.log', os.O_WRONLY | os.O_CREAT)
-
-    # Render
-    bpy.ops.render.render(write_still=True)
-
-    # Remove temporary output redirection
-    #    sys.stdout.flush()
-    #    os.close(1)
-    #    os.dup(old)
-    #    os.close(old)
-
-    # Delete last selected object from scene
-    object.select_set(True)
-    bpy.ops.object.delete()
-
-
-def process_file(input_file, input_dir, output_file, output_dir, correct=True):
-    global views
-    global quality_preview
-
-    if not input_file.endswith('.obj'):
-        print('ERROR: Invalid input: ' + input_file)
-        return
-
-    print('Processing: ' + input_file)
-    if output_file == '':
-        output_file = input_file[:-4]
-
-    if quality_preview:
-        output_file = output_file.replace('.png', '-preview.png')
-
-    angle = 360.0 / views
-    pbar = tqdm(range(0, views))
-    for view in pbar:
-        pbar.set_description(f"{os.path.basename(output_file)} | View:{str(view)}")
-        yaw = view * angle
-        output_file_view = f"{output_file}/{view:03d}.png"
-        if not os.path.exists(os.path.join(output_dir, output_file_view)):
-            render_file(input_file, input_dir, output_file_view, output_dir, yaw, correct)
-
-    cmd = "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 -i " + os.path.join(output_dir, output_file, '%3d.png') + \
-        " -shortest -filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" -map \"[out]\" -y " + output_dir+"/"+output_file+".mp4"
-    os.system(cmd)
diff --git a/lib/common/cloth_extraction.py b/lib/common/cloth_extraction.py
index 612a96787e1aa836b097971e7aaf55b284ef178a..3e6f67234e25eaa2124c3d20692d138b49d20f8d 100644
--- a/lib/common/cloth_extraction.py
+++ b/lib/common/cloth_extraction.py
@@ -1,10 +1,11 @@
-import numpy as np
+import itertools
 import json
 import os
-import itertools
+from collections import Counter
+
+import numpy as np
 import trimesh
 from matplotlib.path import Path
-from collections import Counter
 from sklearn.neighbors import KNeighborsClassifier
 
 
@@ -36,13 +37,11 @@ def load_segmentation(path, shape):
                 xy = np.vstack((x, y)).T
                 coordinates.append(xy)
 
-            segmentations.append(
-                {
-                    "type": val["category_name"],
-                    "type_id": val["category_id"],
-                    "coordinates": coordinates,
-                }
-            )
+            segmentations.append({
+                "type": val["category_name"],
+                "type_id": val["category_id"],
+                "coordinates": coordinates,
+            })
 
         return segmentations
 
diff --git a/lib/common/config.py b/lib/common/config.py
index 3b42fc73cea714259e2a350ee3bd7dc14dbf9375..6d946c115f2b482d598b06d405608f3004b75e78 100644
--- a/lib/common/config.py
+++ b/lib/common/config.py
@@ -14,9 +14,10 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from yacs.config import CfgNode as CN
 import os
 
+from yacs.config import CfgNode as CN
+
 _C = CN(new_allowed=True)
 
 # needed by trainer
diff --git a/lib/common/imutils.py b/lib/common/imutils.py
index 39f61cff3d38d9b50ed245b1a87655ceb2c0eae3..287ab057ee20ac6d28ae8f7a83a77f681af040db 100644
--- a/lib/common/imutils.py
+++ b/lib/common/imutils.py
@@ -1,17 +1,18 @@
 import os
-os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 import cv2
 import mediapipe as mp
-import torch
 import numpy as np
+import torch
 import torch.nn.functional as F
+from kornia.geometry.transform import get_affine_matrix2d, warp_affine
 from PIL import Image
-from lib.pymafx.core import constants
-
 from rembg import remove
 from rembg.session_factory import new_session
 from torchvision import transforms
-from kornia.geometry.transform import get_affine_matrix2d, warp_affine
+
+from lib.pymafx.core import constants
 
 
 def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
@@ -40,13 +41,14 @@ def get_affine_matrix_box(boxes, w2, h2):
     # boxes [left, top, right, bottom]
     width = boxes[:, 2] - boxes[:, 0]    #(N,)
     height = boxes[:, 3] - boxes[:, 1]    #(N,)
-    center = torch.tensor(
-        [(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]
-    ).T    #(N,2)
+    center = torch.tensor([(boxes[:, 0] + boxes[:, 2]) / 2.0,
+                           (boxes[:, 1] + boxes[:, 3]) / 2.0]).T    #(N,2)
     scale = torch.min(torch.tensor([w2 / width, h2 / height]),
                       dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9    #(N,2)
-    transl = torch.cat([w2 / 2.0 - center[:, 0:1], h2 / 2.0 - center[:, 1:2]], dim=1)   #(N,2)
-    M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.,]*transl.shape[0]))
+    transl = torch.cat([w2 / 2.0 - center[:, 0:1], h2 / 2.0 - center[:, 1:2]], dim=1)    #(N,2)
+    M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([
+        0.,
+    ] * transl.shape[0]))
 
     return M
 
@@ -54,12 +56,12 @@ def get_affine_matrix_box(boxes, w2, h2):
 def load_img(img_file):
 
     if img_file.endswith("exr"):
-        img = cv2.imread(img_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)  
-    else :
+        img = cv2.imread(img_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+    else:
         img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
 
     # considering non 8-bit image
-    if img.dtype != np.uint8 :
+    if img.dtype != np.uint8:
         img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
 
     if len(img.shape) == 2:
@@ -112,8 +114,8 @@ def get_pymafx(image, landmarks):
     # image [3,512,512]
 
     item = {
-        'img_body':
-            F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]
+        'img_body': F.interpolate(image.unsqueeze(0), size=224, mode='bicubic',
+                                  align_corners=True)[0]
     }
 
     for part in ['lhand', 'rhand', 'face']:
@@ -211,11 +213,8 @@ def process_image(img_file, hps_type, single, input_res, detector):
     img_pymafx_lst = []
 
     uncrop_param = {
-        "ori_shape": [in_height, in_width],
-        "box_shape": [input_res, input_res],
-        "square_shape": [tgt_res, tgt_res],
-        "M_square": M_square,
-        "M_crop": M_crop
+        "ori_shape": [in_height, in_width], "box_shape": [input_res, input_res], "square_shape":
+        [tgt_res, tgt_res], "M_square": M_square, "M_crop": M_crop
     }
 
     for idx in range(len(boxes)):
@@ -226,11 +225,11 @@ def process_image(img_file, hps_type, single, input_res, detector):
         else:
             mask_detection = masks[0] * 0.
 
-        img_square_rgba = torch.cat(
-            [img_square.squeeze(0).permute(1, 2, 0),
-             torch.tensor(mask_detection < 0.4) * 255],
-            dim=2
-        )
+        img_square_rgba = torch.cat([
+            img_square.squeeze(0).permute(1, 2, 0),
+            torch.tensor(mask_detection < 0.4) * 255
+        ],
+                                    dim=2)
 
         img_crop = warp_affine(
             img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2),
diff --git a/lib/common/libmesh/inside_mesh.py b/lib/common/libmesh/inside_mesh.py
index eaac43c2e6fe103c6a1dd4e182642ff0cc6a024a..24fa682633ce0a751054bdf0bdde3fa1f595d0b7 100644
--- a/lib/common/libmesh/inside_mesh.py
+++ b/lib/common/libmesh/inside_mesh.py
@@ -1,4 +1,5 @@
 import numpy as np
+
 from .triangle_hash import TriangleHash as _TriangleHash
 
 
@@ -147,8 +148,6 @@ class TriangleIntersector2d:
         v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA
 
         sum_uv = u + v
-        contains[mask] = (
-            (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) & (0 < sum_uv) &
-            (sum_uv < abs_detA)
-        )
+        contains[mask] = ((0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) & (0 < sum_uv) &
+                          (sum_uv < abs_detA))
         return contains
diff --git a/lib/common/libmesh/setup.py b/lib/common/libmesh/setup.py
index 38ac162300df4e987134e81306e1a6ad674a5323..65a56afc07363ce47320a69da57883e1906f35ab 100644
--- a/lib/common/libmesh/setup.py
+++ b/lib/common/libmesh/setup.py
@@ -1,5 +1,5 @@
+import numpy
 from setuptools import setup
 from Cython.Build import cythonize
-import numpy
 
 setup(name='libmesh', ext_modules=cythonize("*.pyx"), include_dirs=[numpy.get_include()])
diff --git a/lib/common/libmesh/triangle_hash.cpp b/lib/common/libmesh/triangle_hash.cpp
index 0fc7f44a1db876944eeb72eca57c56b038e80dd6..5f46e5e4f372d1c4dd629743e41063cc6e528adf 100644
--- a/lib/common/libmesh/triangle_hash.cpp
+++ b/lib/common/libmesh/triangle_hash.cpp
@@ -720,12 +720,12 @@ static CYTHON_INLINE float __PYX_NAN() {
 
     /* NumPy API declarations from "numpy/__init__.pxd" */
     
+#include <math.h>
 #include "ios"
 #include "new"
 #include "stdexcept"
 #include "typeinfo"
 #include <vector>
-#include <math.h>
 #include "pythread.h"
 #include <stdlib.h>
 #include "pystate.h"
@@ -1330,8 +1330,8 @@ typedef npy_clongdouble __pyx_t_5numpy_clongdouble_t;
  */
 typedef npy_cdouble __pyx_t_5numpy_complex_t;
 
-/* "triangle_hash.pyx":9
- * from libc.math cimport floor, ceil
+/* "triangle_hash.pyx":11
+ * 
  * 
  * cdef class TriangleHash:             # <<<<<<<<<<<<<<
  *     cdef vector[vector[int]] spatial_hash
@@ -1423,8 +1423,8 @@ struct __pyx_memoryviewslice_obj {
 
 
 
-/* "triangle_hash.pyx":9
- * from libc.math cimport floor, ceil
+/* "triangle_hash.pyx":11
+ * 
  * 
  * cdef class TriangleHash:             # <<<<<<<<<<<<<<
  *     cdef vector[vector[int]] spatial_hash
@@ -2279,6 +2279,10 @@ static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryvie
 static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/
 static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/
 
+/* Module declarations from 'cython.view' */
+
+/* Module declarations from 'cython' */
+
 /* Module declarations from 'cpython.buffer' */
 
 /* Module declarations from 'libc.string' */
@@ -2317,14 +2321,10 @@ static PyTypeObject *__pyx_ptype_5numpy_flexible = 0;
 static PyTypeObject *__pyx_ptype_5numpy_character = 0;
 static PyTypeObject *__pyx_ptype_5numpy_ufunc = 0;
 
-/* Module declarations from 'cython.view' */
-
-/* Module declarations from 'cython' */
+/* Module declarations from 'libc.math' */
 
 /* Module declarations from 'libcpp.vector' */
 
-/* Module declarations from 'libc.math' */
-
 /* Module declarations from 'triangle_hash' */
 static PyTypeObject *__pyx_ptype_13triangle_hash_TriangleHash = 0;
 static PyTypeObject *__pyx_array_type = 0;
@@ -2667,7 +2667,7 @@ static PyObject *__pyx_tuple__28;
 static PyObject *__pyx_codeobj__29;
 /* Late includes */
 
-/* "triangle_hash.pyx":13
+/* "triangle_hash.pyx":15
  *     cdef int resolution
  * 
  *     def __cinit__(self, double[:, :, :] triangles, int resolution):             # <<<<<<<<<<<<<<
@@ -2709,11 +2709,11 @@ static int __pyx_pw_13triangle_hash_12TriangleHash_1__cinit__(PyObject *__pyx_v_
         case  1:
         if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_resolution)) != 0)) kw_args--;
         else {
-          __Pyx_RaiseArgtupleInvalid("__cinit__", 1, 2, 2, 1); __PYX_ERR(0, 13, __pyx_L3_error)
+          __Pyx_RaiseArgtupleInvalid("__cinit__", 1, 2, 2, 1); __PYX_ERR(0, 15, __pyx_L3_error)
         }
       }
       if (unlikely(kw_args > 0)) {
-        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__cinit__") < 0)) __PYX_ERR(0, 13, __pyx_L3_error)
+        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__cinit__") < 0)) __PYX_ERR(0, 15, __pyx_L3_error)
       }
     } else if (PyTuple_GET_SIZE(__pyx_args) != 2) {
       goto __pyx_L5_argtuple_error;
@@ -2721,12 +2721,12 @@ static int __pyx_pw_13triangle_hash_12TriangleHash_1__cinit__(PyObject *__pyx_v_
       values[0] = PyTuple_GET_ITEM(__pyx_args, 0);
       values[1] = PyTuple_GET_ITEM(__pyx_args, 1);
     }
-    __pyx_v_triangles = __Pyx_PyObject_to_MemoryviewSlice_dsdsds_double(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_triangles.memview)) __PYX_ERR(0, 13, __pyx_L3_error)
-    __pyx_v_resolution = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_resolution == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 13, __pyx_L3_error)
+    __pyx_v_triangles = __Pyx_PyObject_to_MemoryviewSlice_dsdsds_double(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_triangles.memview)) __PYX_ERR(0, 15, __pyx_L3_error)
+    __pyx_v_resolution = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_resolution == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 15, __pyx_L3_error)
   }
   goto __pyx_L4_argument_unpacking_done;
   __pyx_L5_argtuple_error:;
-  __Pyx_RaiseArgtupleInvalid("__cinit__", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 13, __pyx_L3_error)
+  __Pyx_RaiseArgtupleInvalid("__cinit__", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 15, __pyx_L3_error)
   __pyx_L3_error:;
   __Pyx_AddTraceback("triangle_hash.TriangleHash.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename);
   __Pyx_RefNannyFinishContext();
@@ -2747,7 +2747,7 @@ static int __pyx_pf_13triangle_hash_12TriangleHash___cinit__(struct __pyx_obj_13
   int __pyx_clineno = 0;
   __Pyx_RefNannySetupContext("__cinit__", 0);
 
-  /* "triangle_hash.pyx":14
+  /* "triangle_hash.pyx":16
  * 
  *     def __cinit__(self, double[:, :, :] triangles, int resolution):
  *         self.spatial_hash.resize(resolution * resolution)             # <<<<<<<<<<<<<<
@@ -2758,10 +2758,10 @@ static int __pyx_pf_13triangle_hash_12TriangleHash___cinit__(struct __pyx_obj_13
     __pyx_v_self->spatial_hash.resize((__pyx_v_resolution * __pyx_v_resolution));
   } catch(...) {
     __Pyx_CppExn2PyErr();
-    __PYX_ERR(0, 14, __pyx_L1_error)
+    __PYX_ERR(0, 16, __pyx_L1_error)
   }
 
-  /* "triangle_hash.pyx":15
+  /* "triangle_hash.pyx":17
  *     def __cinit__(self, double[:, :, :] triangles, int resolution):
  *         self.spatial_hash.resize(resolution * resolution)
  *         self.resolution = resolution             # <<<<<<<<<<<<<<
@@ -2770,7 +2770,7 @@ static int __pyx_pf_13triangle_hash_12TriangleHash___cinit__(struct __pyx_obj_13
  */
   __pyx_v_self->resolution = __pyx_v_resolution;
 
-  /* "triangle_hash.pyx":16
+  /* "triangle_hash.pyx":18
  *         self.spatial_hash.resize(resolution * resolution)
  *         self.resolution = resolution
  *         self._build_hash(triangles)             # <<<<<<<<<<<<<<
@@ -2779,7 +2779,7 @@ static int __pyx_pf_13triangle_hash_12TriangleHash___cinit__(struct __pyx_obj_13
  */
   (void)(((struct __pyx_vtabstruct_13triangle_hash_TriangleHash *)__pyx_v_self->__pyx_vtab)->_build_hash(__pyx_v_self, __pyx_v_triangles));
 
-  /* "triangle_hash.pyx":13
+  /* "triangle_hash.pyx":15
  *     cdef int resolution
  * 
  *     def __cinit__(self, double[:, :, :] triangles, int resolution):             # <<<<<<<<<<<<<<
@@ -2799,7 +2799,7 @@ static int __pyx_pf_13triangle_hash_12TriangleHash___cinit__(struct __pyx_obj_13
   return __pyx_r;
 }
 
-/* "triangle_hash.pyx":20
+/* "triangle_hash.pyx":22
  *     @cython.boundscheck(False)  # Deactivate bounds checking
  *     @cython.wraparound(False)   # Deactivate negative indexing.
  *     cdef int _build_hash(self, double[:, :, :] triangles):             # <<<<<<<<<<<<<<
@@ -2839,7 +2839,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
   int __pyx_clineno = 0;
   __Pyx_RefNannySetupContext("_build_hash", 0);
 
-  /* "triangle_hash.pyx":21
+  /* "triangle_hash.pyx":23
  *     @cython.wraparound(False)   # Deactivate negative indexing.
  *     cdef int _build_hash(self, double[:, :, :] triangles):
  *         assert(triangles.shape[1] == 3)             # <<<<<<<<<<<<<<
@@ -2850,12 +2850,12 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
   if (unlikely(!Py_OptimizeFlag)) {
     if (unlikely(!(((__pyx_v_triangles.shape[1]) == 3) != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 21, __pyx_L1_error)
+      __PYX_ERR(0, 23, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "triangle_hash.pyx":22
+  /* "triangle_hash.pyx":24
  *     cdef int _build_hash(self, double[:, :, :] triangles):
  *         assert(triangles.shape[1] == 3)
  *         assert(triangles.shape[2] == 2)             # <<<<<<<<<<<<<<
@@ -2866,12 +2866,12 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
   if (unlikely(!Py_OptimizeFlag)) {
     if (unlikely(!(((__pyx_v_triangles.shape[2]) == 2) != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 22, __pyx_L1_error)
+      __PYX_ERR(0, 24, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "triangle_hash.pyx":24
+  /* "triangle_hash.pyx":26
  *         assert(triangles.shape[2] == 2)
  * 
  *         cdef int n_tri = triangles.shape[0]             # <<<<<<<<<<<<<<
@@ -2880,7 +2880,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
  */
   __pyx_v_n_tri = (__pyx_v_triangles.shape[0]);
 
-  /* "triangle_hash.pyx":31
+  /* "triangle_hash.pyx":33
  *         cdef int spatial_idx
  * 
  *         for i_tri in range(n_tri):             # <<<<<<<<<<<<<<
@@ -2892,7 +2892,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
   for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) {
     __pyx_v_i_tri = __pyx_t_3;
 
-    /* "triangle_hash.pyx":33
+    /* "triangle_hash.pyx":35
  *         for i_tri in range(n_tri):
  *             # Compute bounding box
  *             for j in range(2):             # <<<<<<<<<<<<<<
@@ -2902,7 +2902,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
     for (__pyx_t_4 = 0; __pyx_t_4 < 2; __pyx_t_4+=1) {
       __pyx_v_j = __pyx_t_4;
 
-      /* "triangle_hash.pyx":35
+      /* "triangle_hash.pyx":37
  *             for j in range(2):
  *                 bbox_min[j] = <int> min(
  *                     triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]             # <<<<<<<<<<<<<<
@@ -2933,7 +2933,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
         __pyx_t_11 = __pyx_t_10;
       }
 
-      /* "triangle_hash.pyx":34
+      /* "triangle_hash.pyx":36
  *             # Compute bounding box
  *             for j in range(2):
  *                 bbox_min[j] = <int> min(             # <<<<<<<<<<<<<<
@@ -2942,7 +2942,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
  */
       (__pyx_v_bbox_min[__pyx_v_j]) = ((int)__pyx_t_11);
 
-      /* "triangle_hash.pyx":38
+      /* "triangle_hash.pyx":40
  *                 )
  *                 bbox_max[j] = <int> max(
  *                     triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]             # <<<<<<<<<<<<<<
@@ -2973,7 +2973,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
         __pyx_t_10 = __pyx_t_9;
       }
 
-      /* "triangle_hash.pyx":37
+      /* "triangle_hash.pyx":39
  *                     triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
  *                 )
  *                 bbox_max[j] = <int> max(             # <<<<<<<<<<<<<<
@@ -2982,7 +2982,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
  */
       (__pyx_v_bbox_max[__pyx_v_j]) = ((int)__pyx_t_10);
 
-      /* "triangle_hash.pyx":40
+      /* "triangle_hash.pyx":42
  *                     triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
  *                 )
  *                 bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1)             # <<<<<<<<<<<<<<
@@ -3005,7 +3005,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
       }
       (__pyx_v_bbox_min[__pyx_v_j]) = __pyx_t_15;
 
-      /* "triangle_hash.pyx":41
+      /* "triangle_hash.pyx":43
  *                 )
  *                 bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1)
  *                 bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1)             # <<<<<<<<<<<<<<
@@ -3029,7 +3029,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
       (__pyx_v_bbox_max[__pyx_v_j]) = __pyx_t_13;
     }
 
-    /* "triangle_hash.pyx":44
+    /* "triangle_hash.pyx":46
  * 
  *             # Find all voxels where bounding box intersects
  *             for x in range(bbox_min[0], bbox_max[0] + 1):             # <<<<<<<<<<<<<<
@@ -3041,7 +3041,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
     for (__pyx_t_4 = (__pyx_v_bbox_min[0]); __pyx_t_4 < __pyx_t_15; __pyx_t_4+=1) {
       __pyx_v_x = __pyx_t_4;
 
-      /* "triangle_hash.pyx":45
+      /* "triangle_hash.pyx":47
  *             # Find all voxels where bounding box intersects
  *             for x in range(bbox_min[0], bbox_max[0] + 1):
  *                 for y in range(bbox_min[1], bbox_max[1] + 1):             # <<<<<<<<<<<<<<
@@ -3053,7 +3053,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
       for (__pyx_t_14 = (__pyx_v_bbox_min[1]); __pyx_t_14 < __pyx_t_16; __pyx_t_14+=1) {
         __pyx_v_y = __pyx_t_14;
 
-        /* "triangle_hash.pyx":46
+        /* "triangle_hash.pyx":48
  *             for x in range(bbox_min[0], bbox_max[0] + 1):
  *                 for y in range(bbox_min[1], bbox_max[1] + 1):
  *                     spatial_idx = self.resolution * x + y             # <<<<<<<<<<<<<<
@@ -3062,7 +3062,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
  */
         __pyx_v_spatial_idx = ((__pyx_v_self->resolution * __pyx_v_x) + __pyx_v_y);
 
-        /* "triangle_hash.pyx":47
+        /* "triangle_hash.pyx":49
  *                 for y in range(bbox_min[1], bbox_max[1] + 1):
  *                     spatial_idx = self.resolution * x + y
  *                     self.spatial_hash[spatial_idx].push_back(i_tri)             # <<<<<<<<<<<<<<
@@ -3073,13 +3073,13 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
           (__pyx_v_self->spatial_hash[__pyx_v_spatial_idx]).push_back(__pyx_v_i_tri);
         } catch(...) {
           __Pyx_CppExn2PyErr();
-          __PYX_ERR(0, 47, __pyx_L1_error)
+          __PYX_ERR(0, 49, __pyx_L1_error)
         }
       }
     }
   }
 
-  /* "triangle_hash.pyx":20
+  /* "triangle_hash.pyx":22
  *     @cython.boundscheck(False)  # Deactivate bounds checking
  *     @cython.wraparound(False)   # Deactivate negative indexing.
  *     cdef int _build_hash(self, double[:, :, :] triangles):             # <<<<<<<<<<<<<<
@@ -3098,7 +3098,7 @@ static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_1
   return __pyx_r;
 }
 
-/* "triangle_hash.pyx":51
+/* "triangle_hash.pyx":53
  *     @cython.boundscheck(False)  # Deactivate bounds checking
  *     @cython.wraparound(False)   # Deactivate negative indexing.
  *     cpdef query(self, double[:, :] points):             # <<<<<<<<<<<<<<
@@ -3155,12 +3155,12 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
       PY_UINT64_T __pyx_type_dict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
       #endif
-      __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_query); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 51, __pyx_L1_error)
+      __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_query); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 53, __pyx_L1_error)
       __Pyx_GOTREF(__pyx_t_1);
       if (!PyCFunction_Check(__pyx_t_1) || (PyCFunction_GET_FUNCTION(__pyx_t_1) != (PyCFunction)(void*)__pyx_pw_13triangle_hash_12TriangleHash_3query)) {
         __Pyx_XDECREF(__pyx_r);
-        if (unlikely(!__pyx_v_points.memview)) { __Pyx_RaiseUnboundLocalError("points"); __PYX_ERR(0, 51, __pyx_L1_error) }
-        __pyx_t_3 = __pyx_memoryview_fromslice(__pyx_v_points, 2, (PyObject *(*)(char *)) __pyx_memview_get_double, (int (*)(char *, PyObject *)) __pyx_memview_set_double, 0);; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 51, __pyx_L1_error)
+        if (unlikely(!__pyx_v_points.memview)) { __Pyx_RaiseUnboundLocalError("points"); __PYX_ERR(0, 53, __pyx_L1_error) }
+        __pyx_t_3 = __pyx_memoryview_fromslice(__pyx_v_points, 2, (PyObject *(*)(char *)) __pyx_memview_get_double, (int (*)(char *, PyObject *)) __pyx_memview_set_double, 0);; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 53, __pyx_L1_error)
         __Pyx_GOTREF(__pyx_t_3);
         __Pyx_INCREF(__pyx_t_1);
         __pyx_t_4 = __pyx_t_1; __pyx_t_5 = NULL;
@@ -3176,7 +3176,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
         __pyx_t_2 = (__pyx_t_5) ? __Pyx_PyObject_Call2Args(__pyx_t_4, __pyx_t_5, __pyx_t_3) : __Pyx_PyObject_CallOneArg(__pyx_t_4, __pyx_t_3);
         __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0;
         __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
-        if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 51, __pyx_L1_error)
+        if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 53, __pyx_L1_error)
         __Pyx_GOTREF(__pyx_t_2);
         __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
         __pyx_r = __pyx_t_2;
@@ -3197,7 +3197,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     #endif
   }
 
-  /* "triangle_hash.pyx":52
+  /* "triangle_hash.pyx":54
  *     @cython.wraparound(False)   # Deactivate negative indexing.
  *     cpdef query(self, double[:, :] points):
  *         assert(points.shape[1] == 2)             # <<<<<<<<<<<<<<
@@ -3208,12 +3208,12 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
   if (unlikely(!Py_OptimizeFlag)) {
     if (unlikely(!(((__pyx_v_points.shape[1]) == 2) != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 52, __pyx_L1_error)
+      __PYX_ERR(0, 54, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "triangle_hash.pyx":53
+  /* "triangle_hash.pyx":55
  *     cpdef query(self, double[:, :] points):
  *         assert(points.shape[1] == 2)
  *         cdef int n_points = points.shape[0]             # <<<<<<<<<<<<<<
@@ -3222,7 +3222,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
  */
   __pyx_v_n_points = (__pyx_v_points.shape[0]);
 
-  /* "triangle_hash.pyx":63
+  /* "triangle_hash.pyx":65
  *         cdef int spatial_idx
  * 
  *         for i_point in range(n_points):             # <<<<<<<<<<<<<<
@@ -3234,7 +3234,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
   for (__pyx_t_8 = 0; __pyx_t_8 < __pyx_t_7; __pyx_t_8+=1) {
     __pyx_v_i_point = __pyx_t_8;
 
-    /* "triangle_hash.pyx":64
+    /* "triangle_hash.pyx":66
  * 
  *         for i_point in range(n_points):
  *             x = int(points[i_point, 0])             # <<<<<<<<<<<<<<
@@ -3245,7 +3245,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     __pyx_t_10 = 0;
     __pyx_v_x = ((int)(*((double *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_points.data + __pyx_t_9 * __pyx_v_points.strides[0]) ) + __pyx_t_10 * __pyx_v_points.strides[1]) ))));
 
-    /* "triangle_hash.pyx":65
+    /* "triangle_hash.pyx":67
  *         for i_point in range(n_points):
  *             x = int(points[i_point, 0])
  *             y = int(points[i_point, 1])             # <<<<<<<<<<<<<<
@@ -3256,7 +3256,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     __pyx_t_9 = 1;
     __pyx_v_y = ((int)(*((double *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_points.data + __pyx_t_10 * __pyx_v_points.strides[0]) ) + __pyx_t_9 * __pyx_v_points.strides[1]) ))));
 
-    /* "triangle_hash.pyx":66
+    /* "triangle_hash.pyx":68
  *             x = int(points[i_point, 0])
  *             y = int(points[i_point, 1])
  *             if not (0 <= x < self.resolution and 0 <= y < self.resolution):             # <<<<<<<<<<<<<<
@@ -3283,7 +3283,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     __pyx_t_12 = ((!__pyx_t_11) != 0);
     if (__pyx_t_12) {
 
-      /* "triangle_hash.pyx":67
+      /* "triangle_hash.pyx":69
  *             y = int(points[i_point, 1])
  *             if not (0 <= x < self.resolution and 0 <= y < self.resolution):
  *                 continue             # <<<<<<<<<<<<<<
@@ -3292,7 +3292,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
  */
       goto __pyx_L3_continue;
 
-      /* "triangle_hash.pyx":66
+      /* "triangle_hash.pyx":68
  *             x = int(points[i_point, 0])
  *             y = int(points[i_point, 1])
  *             if not (0 <= x < self.resolution and 0 <= y < self.resolution):             # <<<<<<<<<<<<<<
@@ -3301,7 +3301,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
  */
     }
 
-    /* "triangle_hash.pyx":69
+    /* "triangle_hash.pyx":71
  *                 continue
  * 
  *             spatial_idx = self.resolution * x +  y             # <<<<<<<<<<<<<<
@@ -3310,7 +3310,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
  */
     __pyx_v_spatial_idx = ((__pyx_v_self->resolution * __pyx_v_x) + __pyx_v_y);
 
-    /* "triangle_hash.pyx":70
+    /* "triangle_hash.pyx":72
  * 
  *             spatial_idx = self.resolution * x +  y
  *             for i_tri in self.spatial_hash[spatial_idx]:             # <<<<<<<<<<<<<<
@@ -3325,7 +3325,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
       ++__pyx_t_14;
       __pyx_v_i_tri = __pyx_t_16;
 
-      /* "triangle_hash.pyx":71
+      /* "triangle_hash.pyx":73
  *             spatial_idx = self.resolution * x +  y
  *             for i_tri in self.spatial_hash[spatial_idx]:
  *                 points_indices.push_back(i_point)             # <<<<<<<<<<<<<<
@@ -3336,10 +3336,10 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
         __pyx_v_points_indices.push_back(__pyx_v_i_point);
       } catch(...) {
         __Pyx_CppExn2PyErr();
-        __PYX_ERR(0, 71, __pyx_L1_error)
+        __PYX_ERR(0, 73, __pyx_L1_error)
       }
 
-      /* "triangle_hash.pyx":72
+      /* "triangle_hash.pyx":74
  *             for i_tri in self.spatial_hash[spatial_idx]:
  *                 points_indices.push_back(i_point)
  *                 tri_indices.push_back(i_tri)             # <<<<<<<<<<<<<<
@@ -3350,10 +3350,10 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
         __pyx_v_tri_indices.push_back(__pyx_v_i_tri);
       } catch(...) {
         __Pyx_CppExn2PyErr();
-        __PYX_ERR(0, 72, __pyx_L1_error)
+        __PYX_ERR(0, 74, __pyx_L1_error)
       }
 
-      /* "triangle_hash.pyx":70
+      /* "triangle_hash.pyx":72
  * 
  *             spatial_idx = self.resolution * x +  y
  *             for i_tri in self.spatial_hash[spatial_idx]:             # <<<<<<<<<<<<<<
@@ -3364,35 +3364,35 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     __pyx_L3_continue:;
   }
 
-  /* "triangle_hash.pyx":74
+  /* "triangle_hash.pyx":76
  *                 tri_indices.push_back(i_tri)
  * 
  *         points_indices_np = np.zeros(points_indices.size(), dtype=np.int32)             # <<<<<<<<<<<<<<
  *         tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)
  * 
  */
-  __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
-  __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_zeros); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_zeros); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_2);
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
-  __pyx_t_1 = __Pyx_PyInt_FromSize_t(__pyx_v_points_indices.size()); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __pyx_t_1 = __Pyx_PyInt_FromSize_t(__pyx_v_points_indices.size()); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
-  __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_4);
   __Pyx_GIVEREF(__pyx_t_1);
   PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_1);
   __pyx_t_1 = 0;
-  __pyx_t_1 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __pyx_t_1 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
-  __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_np); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_np); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_3);
-  __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_int32); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_int32); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
   __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
-  if (PyDict_SetItem(__pyx_t_1, __pyx_n_s_dtype, __pyx_t_5) < 0) __PYX_ERR(0, 74, __pyx_L1_error)
+  if (PyDict_SetItem(__pyx_t_1, __pyx_n_s_dtype, __pyx_t_5) < 0) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
-  __pyx_t_5 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_4, __pyx_t_1); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 74, __pyx_L1_error)
+  __pyx_t_5 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_4, __pyx_t_1); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 76, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
   __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
   __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
@@ -3400,35 +3400,35 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
   __pyx_v_points_indices_np = __pyx_t_5;
   __pyx_t_5 = 0;
 
-  /* "triangle_hash.pyx":75
+  /* "triangle_hash.pyx":77
  * 
  *         points_indices_np = np.zeros(points_indices.size(), dtype=np.int32)
  *         tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)             # <<<<<<<<<<<<<<
  * 
  *         cdef int[:] points_indices_view = points_indices_np
  */
-  __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
-  __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_zeros); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_zeros); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
-  __pyx_t_5 = __Pyx_PyInt_FromSize_t(__pyx_v_tri_indices.size()); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __pyx_t_5 = __Pyx_PyInt_FromSize_t(__pyx_v_tri_indices.size()); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
-  __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_4);
   __Pyx_GIVEREF(__pyx_t_5);
   PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_5);
   __pyx_t_5 = 0;
-  __pyx_t_5 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __pyx_t_5 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
-  __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_np); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_np); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_2);
-  __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_int32); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_int32); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_3);
   __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
-  if (PyDict_SetItem(__pyx_t_5, __pyx_n_s_dtype, __pyx_t_3) < 0) __PYX_ERR(0, 75, __pyx_L1_error)
+  if (PyDict_SetItem(__pyx_t_5, __pyx_n_s_dtype, __pyx_t_3) < 0) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
-  __pyx_t_3 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_4, __pyx_t_5); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 75, __pyx_L1_error)
+  __pyx_t_3 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_4, __pyx_t_5); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 77, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_3);
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
   __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
@@ -3436,31 +3436,31 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
   __pyx_v_tri_indices_np = __pyx_t_3;
   __pyx_t_3 = 0;
 
-  /* "triangle_hash.pyx":77
+  /* "triangle_hash.pyx":79
  *         tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)
  * 
  *         cdef int[:] points_indices_view = points_indices_np             # <<<<<<<<<<<<<<
  *         cdef int[:] tri_indices_view = tri_indices_np
  * 
  */
-  __pyx_t_17 = __Pyx_PyObject_to_MemoryviewSlice_ds_int(__pyx_v_points_indices_np, PyBUF_WRITABLE); if (unlikely(!__pyx_t_17.memview)) __PYX_ERR(0, 77, __pyx_L1_error)
+  __pyx_t_17 = __Pyx_PyObject_to_MemoryviewSlice_ds_int(__pyx_v_points_indices_np, PyBUF_WRITABLE); if (unlikely(!__pyx_t_17.memview)) __PYX_ERR(0, 79, __pyx_L1_error)
   __pyx_v_points_indices_view = __pyx_t_17;
   __pyx_t_17.memview = NULL;
   __pyx_t_17.data = NULL;
 
-  /* "triangle_hash.pyx":78
+  /* "triangle_hash.pyx":80
  * 
  *         cdef int[:] points_indices_view = points_indices_np
  *         cdef int[:] tri_indices_view = tri_indices_np             # <<<<<<<<<<<<<<
  * 
  *         for k in range(points_indices.size()):
  */
-  __pyx_t_17 = __Pyx_PyObject_to_MemoryviewSlice_ds_int(__pyx_v_tri_indices_np, PyBUF_WRITABLE); if (unlikely(!__pyx_t_17.memview)) __PYX_ERR(0, 78, __pyx_L1_error)
+  __pyx_t_17 = __Pyx_PyObject_to_MemoryviewSlice_ds_int(__pyx_v_tri_indices_np, PyBUF_WRITABLE); if (unlikely(!__pyx_t_17.memview)) __PYX_ERR(0, 80, __pyx_L1_error)
   __pyx_v_tri_indices_view = __pyx_t_17;
   __pyx_t_17.memview = NULL;
   __pyx_t_17.data = NULL;
 
-  /* "triangle_hash.pyx":80
+  /* "triangle_hash.pyx":82
  *         cdef int[:] tri_indices_view = tri_indices_np
  * 
  *         for k in range(points_indices.size()):             # <<<<<<<<<<<<<<
@@ -3472,7 +3472,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
   for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_19; __pyx_t_6+=1) {
     __pyx_v_k = __pyx_t_6;
 
-    /* "triangle_hash.pyx":81
+    /* "triangle_hash.pyx":83
  * 
  *         for k in range(points_indices.size()):
  *             points_indices_view[k] = points_indices[k]             # <<<<<<<<<<<<<<
@@ -3483,7 +3483,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     *((int *) ( /* dim=0 */ (__pyx_v_points_indices_view.data + __pyx_t_9 * __pyx_v_points_indices_view.strides[0]) )) = (__pyx_v_points_indices[__pyx_v_k]);
   }
 
-  /* "triangle_hash.pyx":83
+  /* "triangle_hash.pyx":85
  *             points_indices_view[k] = points_indices[k]
  * 
  *         for k in range(tri_indices.size()):             # <<<<<<<<<<<<<<
@@ -3495,7 +3495,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
   for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_19; __pyx_t_6+=1) {
     __pyx_v_k = __pyx_t_6;
 
-    /* "triangle_hash.pyx":84
+    /* "triangle_hash.pyx":86
  * 
  *         for k in range(tri_indices.size()):
  *             tri_indices_view[k] = tri_indices[k]             # <<<<<<<<<<<<<<
@@ -3506,13 +3506,13 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
     *((int *) ( /* dim=0 */ (__pyx_v_tri_indices_view.data + __pyx_t_9 * __pyx_v_tri_indices_view.strides[0]) )) = (__pyx_v_tri_indices[__pyx_v_k]);
   }
 
-  /* "triangle_hash.pyx":86
+  /* "triangle_hash.pyx":88
  *             tri_indices_view[k] = tri_indices[k]
  * 
  *         return points_indices_np, tri_indices_np             # <<<<<<<<<<<<<<
  */
   __Pyx_XDECREF(__pyx_r);
-  __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 86, __pyx_L1_error)
+  __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 88, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_3);
   __Pyx_INCREF(__pyx_v_points_indices_np);
   __Pyx_GIVEREF(__pyx_v_points_indices_np);
@@ -3524,7 +3524,7 @@ static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_1
   __pyx_t_3 = 0;
   goto __pyx_L0;
 
-  /* "triangle_hash.pyx":51
+  /* "triangle_hash.pyx":53
  *     @cython.boundscheck(False)  # Deactivate bounds checking
  *     @cython.wraparound(False)   # Deactivate negative indexing.
  *     cpdef query(self, double[:, :] points):             # <<<<<<<<<<<<<<
@@ -3563,7 +3563,7 @@ static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_3query(PyObject *__pyx_
   __Pyx_RefNannyDeclarations
   __Pyx_RefNannySetupContext("query (wrapper)", 0);
   assert(__pyx_arg_points); {
-    __pyx_v_points = __Pyx_PyObject_to_MemoryviewSlice_dsds_double(__pyx_arg_points, PyBUF_WRITABLE); if (unlikely(!__pyx_v_points.memview)) __PYX_ERR(0, 51, __pyx_L3_error)
+    __pyx_v_points = __Pyx_PyObject_to_MemoryviewSlice_dsds_double(__pyx_arg_points, PyBUF_WRITABLE); if (unlikely(!__pyx_v_points.memview)) __PYX_ERR(0, 53, __pyx_L3_error)
   }
   goto __pyx_L4_argument_unpacking_done;
   __pyx_L3_error:;
@@ -3587,8 +3587,8 @@ static PyObject *__pyx_pf_13triangle_hash_12TriangleHash_2query(struct __pyx_obj
   int __pyx_clineno = 0;
   __Pyx_RefNannySetupContext("query", 0);
   __Pyx_XDECREF(__pyx_r);
-  if (unlikely(!__pyx_v_points.memview)) { __Pyx_RaiseUnboundLocalError("points"); __PYX_ERR(0, 51, __pyx_L1_error) }
-  __pyx_t_1 = __pyx_f_13triangle_hash_12TriangleHash_query(__pyx_v_self, __pyx_v_points, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 51, __pyx_L1_error)
+  if (unlikely(!__pyx_v_points.memview)) { __Pyx_RaiseUnboundLocalError("points"); __PYX_ERR(0, 53, __pyx_L1_error) }
+  __pyx_t_1 = __pyx_f_13triangle_hash_12TriangleHash_query(__pyx_v_self, __pyx_v_points, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 53, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   __pyx_r = __pyx_t_1;
   __pyx_t_1 = 0;
@@ -18738,7 +18738,7 @@ static __Pyx_StringTabEntry __pyx_string_tab[] = {
   {0, 0, 0, 0, 0, 0, 0}
 };
 static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) {
-  __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 31, __pyx_L1_error)
+  __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 33, __pyx_L1_error)
   __pyx_builtin_TypeError = __Pyx_GetBuiltinName(__pyx_n_s_TypeError); if (!__pyx_builtin_TypeError) __PYX_ERR(1, 2, __pyx_L1_error)
   __pyx_builtin_ImportError = __Pyx_GetBuiltinName(__pyx_n_s_ImportError); if (!__pyx_builtin_ImportError) __PYX_ERR(2, 945, __pyx_L1_error)
   __pyx_builtin_ValueError = __Pyx_GetBuiltinName(__pyx_n_s_ValueError); if (!__pyx_builtin_ValueError) __PYX_ERR(1, 133, __pyx_L1_error)
@@ -19118,16 +19118,16 @@ static int __Pyx_modinit_type_init_code(void) {
   __pyx_vtabptr_13triangle_hash_TriangleHash = &__pyx_vtable_13triangle_hash_TriangleHash;
   __pyx_vtable_13triangle_hash_TriangleHash._build_hash = (int (*)(struct __pyx_obj_13triangle_hash_TriangleHash *, __Pyx_memviewslice))__pyx_f_13triangle_hash_12TriangleHash__build_hash;
   __pyx_vtable_13triangle_hash_TriangleHash.query = (PyObject *(*)(struct __pyx_obj_13triangle_hash_TriangleHash *, __Pyx_memviewslice, int __pyx_skip_dispatch))__pyx_f_13triangle_hash_12TriangleHash_query;
-  if (PyType_Ready(&__pyx_type_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 9, __pyx_L1_error)
+  if (PyType_Ready(&__pyx_type_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 11, __pyx_L1_error)
   #if PY_VERSION_HEX < 0x030800B1
   __pyx_type_13triangle_hash_TriangleHash.tp_print = 0;
   #endif
   if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_type_13triangle_hash_TriangleHash.tp_dictoffset && __pyx_type_13triangle_hash_TriangleHash.tp_getattro == PyObject_GenericGetAttr)) {
     __pyx_type_13triangle_hash_TriangleHash.tp_getattro = __Pyx_PyObject_GenericGetAttr;
   }
-  if (__Pyx_SetVtable(__pyx_type_13triangle_hash_TriangleHash.tp_dict, __pyx_vtabptr_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 9, __pyx_L1_error)
-  if (PyObject_SetAttr(__pyx_m, __pyx_n_s_TriangleHash, (PyObject *)&__pyx_type_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 9, __pyx_L1_error)
-  if (__Pyx_setup_reduce((PyObject*)&__pyx_type_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 9, __pyx_L1_error)
+  if (__Pyx_SetVtable(__pyx_type_13triangle_hash_TriangleHash.tp_dict, __pyx_vtabptr_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 11, __pyx_L1_error)
+  if (PyObject_SetAttr(__pyx_m, __pyx_n_s_TriangleHash, (PyObject *)&__pyx_type_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 11, __pyx_L1_error)
+  if (__Pyx_setup_reduce((PyObject*)&__pyx_type_13triangle_hash_TriangleHash) < 0) __PYX_ERR(0, 11, __pyx_L1_error)
   __pyx_ptype_13triangle_hash_TriangleHash = &__pyx_type_13triangle_hash_TriangleHash;
   __pyx_vtabptr_array = &__pyx_vtable_array;
   __pyx_vtable_array.get_memview = (PyObject *(*)(struct __pyx_array_obj *))__pyx_array_get_memview;
@@ -19468,7 +19468,7 @@ if (!__Pyx_RefNanny) {
  * 
  * # distutils: language=c++
  * import numpy as np             # <<<<<<<<<<<<<<
- * cimport numpy as np
+ * 
  * cimport cython
  */
   __pyx_t_1 = __Pyx_Import(__pyx_n_s_numpy, 0, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 3, __pyx_L1_error)
@@ -19480,7 +19480,7 @@ if (!__Pyx_RefNanny) {
  * 
  * # distutils: language=c++             # <<<<<<<<<<<<<<
  * import numpy as np
- * cimport numpy as np
+ * 
  */
   __pyx_t_1 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 2, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
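
The triangle_hash.cpp changes above are regenerated Cython output: the `__PYX_ERR(0, N, ...)` calls and the `"triangle_hash.pyx":N` comments simply shift by two lines to track the reordered imports at the top of the source, and the module-declaration blocks move to match the new cimport order. If the file ever needs to be regenerated outside the normal build, a minimal sketch (assuming Cython is installed; the repo's own setup script is the authoritative entry point) is:

```python
# Hypothetical regeneration sketch -- shown only to illustrate where the
# generated .cpp comes from; the package's build script may drive this differently.
from Cython.Build import cythonize

# The "# distutils: language=c++" directive in triangle_hash.pyx makes
# cythonize emit triangle_hash.cpp next to the source file.
cythonize("lib/common/libmesh/triangle_hash.pyx", language_level=3)
```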
diff --git a/lib/common/libmesh/triangle_hash.pyx b/lib/common/libmesh/triangle_hash.pyx
index 9e3ad59bd7bfc35a3e92e0b40e0349bfb3759a4b..6d6da0a763044307899100e3fa0f849ba2e0444b 100644
--- a/lib/common/libmesh/triangle_hash.pyx
+++ b/lib/common/libmesh/triangle_hash.pyx
@@ -1,10 +1,12 @@
 
 # distutils: language=c++
 import numpy as np
-cimport numpy as np
+
 cimport cython
+cimport numpy as np
+from libc.math cimport ceil, floor
 from libcpp.vector cimport vector
-from libc.math cimport floor, ceil
+
 
 cdef class TriangleHash:
     cdef vector[vector[int]] spatial_hash
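
For context, TriangleHash is a small 2D spatial hash over triangles: the constructor buckets each triangle into every grid cell its bounding box covers, and query() returns candidate (point index, triangle index) pairs for points that fall in occupied cells. A minimal usage sketch, assuming the extension has been built in place and is importable from lib.common.libmesh:

```python
import numpy as np

# Assumed import path; in practice the module is imported by the package's
# own mesh utilities after the extension is compiled.
from lib.common.libmesh.triangle_hash import TriangleHash

resolution = 64
# (n_tri, 3, 2) float64 triangles, already expressed in grid coordinates
triangles = np.random.rand(100, 3, 2) * resolution
# (n_points, 2) float64 query points in the same coordinate system
points = np.random.rand(1000, 2) * resolution

tri_hash = TriangleHash(triangles, resolution)
point_idx, tri_idx = tri_hash.query(points)    # int32 index arrays of equal length
```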
diff --git a/lib/common/libvoxelize/voxelize.c b/lib/common/libvoxelize/voxelize.c
index 7320c8246e2cec1892352b4e4208765868af9651..11d0d3b5c00684dd4cd6b5268462adb83d7c89ad 100644
--- a/lib/common/libvoxelize/voxelize.c
+++ b/lib/common/libvoxelize/voxelize.c
@@ -2115,7 +2115,7 @@ static PyObject *__pyx_tuple__24;
 static PyObject *__pyx_codeobj__25;
 /* Late includes */
 
-/* "voxelize.pyx":12
+/* "voxelize.pyx":13
  * @cython.boundscheck(False)  # Deactivate bounds checking
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces):             # <<<<<<<<<<<<<<
@@ -2138,7 +2138,7 @@ static int __pyx_f_8voxelize_voxelize_mesh_(__Pyx_memviewslice __pyx_v_occ, __Py
   int __pyx_clineno = 0;
   __Pyx_RefNannySetupContext("voxelize_mesh_", 0);
 
-  /* "voxelize.pyx":13
+  /* "voxelize.pyx":14
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces):
  *     assert(faces.shape[1] == 3)             # <<<<<<<<<<<<<<
@@ -2149,12 +2149,12 @@ static int __pyx_f_8voxelize_voxelize_mesh_(__Pyx_memviewslice __pyx_v_occ, __Py
   if (unlikely(!Py_OptimizeFlag)) {
     if (unlikely(!(((__pyx_v_faces.shape[1]) == 3) != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 13, __pyx_L1_error)
+      __PYX_ERR(0, 14, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "voxelize.pyx":14
+  /* "voxelize.pyx":15
  * cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces):
  *     assert(faces.shape[1] == 3)
  *     assert(faces.shape[2] == 3)             # <<<<<<<<<<<<<<
@@ -2165,12 +2165,12 @@ static int __pyx_f_8voxelize_voxelize_mesh_(__Pyx_memviewslice __pyx_v_occ, __Py
   if (unlikely(!Py_OptimizeFlag)) {
     if (unlikely(!(((__pyx_v_faces.shape[2]) == 3) != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 14, __pyx_L1_error)
+      __PYX_ERR(0, 15, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "voxelize.pyx":16
+  /* "voxelize.pyx":17
  *     assert(faces.shape[2] == 3)
  * 
  *     n_faces = faces.shape[0]             # <<<<<<<<<<<<<<
@@ -2179,7 +2179,7 @@ static int __pyx_f_8voxelize_voxelize_mesh_(__Pyx_memviewslice __pyx_v_occ, __Py
  */
   __pyx_v_n_faces = (__pyx_v_faces.shape[0]);
 
-  /* "voxelize.pyx":18
+  /* "voxelize.pyx":19
  *     n_faces = faces.shape[0]
  *     cdef int i
  *     for i in range(n_faces):             # <<<<<<<<<<<<<<
@@ -2191,7 +2191,7 @@ static int __pyx_f_8voxelize_voxelize_mesh_(__Pyx_memviewslice __pyx_v_occ, __Py
   for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) {
     __pyx_v_i = __pyx_t_3;
 
-    /* "voxelize.pyx":19
+    /* "voxelize.pyx":20
  *     cdef int i
  *     for i in range(n_faces):
  *         voxelize_triangle_(occ, faces[i])             # <<<<<<<<<<<<<<
@@ -2221,7 +2221,7 @@ __pyx_t_4.strides[1] = __pyx_v_faces.strides[2];
     __pyx_t_4.data = NULL;
   }
 
-  /* "voxelize.pyx":12
+  /* "voxelize.pyx":13
  * @cython.boundscheck(False)  # Deactivate bounds checking
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces):             # <<<<<<<<<<<<<<
@@ -2275,11 +2275,11 @@ static PyObject *__pyx_pw_8voxelize_1voxelize_mesh_(PyObject *__pyx_self, PyObje
         case  1:
         if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_faces)) != 0)) kw_args--;
         else {
-          __Pyx_RaiseArgtupleInvalid("voxelize_mesh_", 1, 2, 2, 1); __PYX_ERR(0, 12, __pyx_L3_error)
+          __Pyx_RaiseArgtupleInvalid("voxelize_mesh_", 1, 2, 2, 1); __PYX_ERR(0, 13, __pyx_L3_error)
         }
       }
       if (unlikely(kw_args > 0)) {
-        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "voxelize_mesh_") < 0)) __PYX_ERR(0, 12, __pyx_L3_error)
+        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "voxelize_mesh_") < 0)) __PYX_ERR(0, 13, __pyx_L3_error)
       }
     } else if (PyTuple_GET_SIZE(__pyx_args) != 2) {
       goto __pyx_L5_argtuple_error;
@@ -2287,12 +2287,12 @@ static PyObject *__pyx_pw_8voxelize_1voxelize_mesh_(PyObject *__pyx_self, PyObje
       values[0] = PyTuple_GET_ITEM(__pyx_args, 0);
       values[1] = PyTuple_GET_ITEM(__pyx_args, 1);
     }
-    __pyx_v_occ = __Pyx_PyObject_to_MemoryviewSlice_dsdsds_int(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_occ.memview)) __PYX_ERR(0, 12, __pyx_L3_error)
-    __pyx_v_faces = __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_float(values[1], PyBUF_WRITABLE); if (unlikely(!__pyx_v_faces.memview)) __PYX_ERR(0, 12, __pyx_L3_error)
+    __pyx_v_occ = __Pyx_PyObject_to_MemoryviewSlice_dsdsds_int(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_occ.memview)) __PYX_ERR(0, 13, __pyx_L3_error)
+    __pyx_v_faces = __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_float(values[1], PyBUF_WRITABLE); if (unlikely(!__pyx_v_faces.memview)) __PYX_ERR(0, 13, __pyx_L3_error)
   }
   goto __pyx_L4_argument_unpacking_done;
   __pyx_L5_argtuple_error:;
-  __Pyx_RaiseArgtupleInvalid("voxelize_mesh_", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 12, __pyx_L3_error)
+  __Pyx_RaiseArgtupleInvalid("voxelize_mesh_", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 13, __pyx_L3_error)
   __pyx_L3_error:;
   __Pyx_AddTraceback("voxelize.voxelize_mesh_", __pyx_clineno, __pyx_lineno, __pyx_filename);
   __Pyx_RefNannyFinishContext();
@@ -2314,9 +2314,9 @@ static PyObject *__pyx_pf_8voxelize_voxelize_mesh_(CYTHON_UNUSED PyObject *__pyx
   int __pyx_clineno = 0;
   __Pyx_RefNannySetupContext("voxelize_mesh_", 0);
   __Pyx_XDECREF(__pyx_r);
-  if (unlikely(!__pyx_v_occ.memview)) { __Pyx_RaiseUnboundLocalError("occ"); __PYX_ERR(0, 12, __pyx_L1_error) }
-  if (unlikely(!__pyx_v_faces.memview)) { __Pyx_RaiseUnboundLocalError("faces"); __PYX_ERR(0, 12, __pyx_L1_error) }
-  __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_f_8voxelize_voxelize_mesh_(__pyx_v_occ, __pyx_v_faces, 0)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 12, __pyx_L1_error)
+  if (unlikely(!__pyx_v_occ.memview)) { __Pyx_RaiseUnboundLocalError("occ"); __PYX_ERR(0, 13, __pyx_L1_error) }
+  if (unlikely(!__pyx_v_faces.memview)) { __Pyx_RaiseUnboundLocalError("faces"); __PYX_ERR(0, 13, __pyx_L1_error) }
+  __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_f_8voxelize_voxelize_mesh_(__pyx_v_occ, __pyx_v_faces, 0)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 13, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   __pyx_r = __pyx_t_1;
   __pyx_t_1 = 0;
@@ -2335,7 +2335,7 @@ static PyObject *__pyx_pf_8voxelize_voxelize_mesh_(CYTHON_UNUSED PyObject *__pyx
   return __pyx_r;
 }
 
-/* "voxelize.pyx":24
+/* "voxelize.pyx":25
  * @cython.boundscheck(False)  # Deactivate bounds checking
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cpdef int voxelize_triangle_(bint[:, :, :] occupancies, float[:, ::1] triverts):             # <<<<<<<<<<<<<<
@@ -2382,7 +2382,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
   Py_ssize_t __pyx_t_25;
   __Pyx_RefNannySetupContext("voxelize_triangle_", 0);
 
-  /* "voxelize.pyx":32
+  /* "voxelize.pyx":33
  *     cdef bint intersection
  * 
  *     boxhalfsize[:] = (0.5, 0.5, 0.5)             # <<<<<<<<<<<<<<
@@ -2397,7 +2397,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
   (__pyx_t_1[1]) = __pyx_t_3;
   (__pyx_t_1[2]) = __pyx_t_4;
 
-  /* "voxelize.pyx":34
+  /* "voxelize.pyx":35
  *     boxhalfsize[:] = (0.5, 0.5, 0.5)
  * 
  *     for i in range(3):             # <<<<<<<<<<<<<<
@@ -2407,7 +2407,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
   for (__pyx_t_5 = 0; __pyx_t_5 < 3; __pyx_t_5+=1) {
     __pyx_v_i = __pyx_t_5;
 
-    /* "voxelize.pyx":36
+    /* "voxelize.pyx":37
  *     for i in range(3):
  *         bbox_min[i] = <int> (
  *             min(triverts[0, i], triverts[1, i], triverts[2, i])             # <<<<<<<<<<<<<<
@@ -2435,7 +2435,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
       __pyx_t_11 = __pyx_t_10;
     }
 
-    /* "voxelize.pyx":35
+    /* "voxelize.pyx":36
  * 
  *     for i in range(3):
  *         bbox_min[i] = <int> (             # <<<<<<<<<<<<<<
@@ -2444,7 +2444,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
  */
     (__pyx_v_bbox_min[__pyx_v_i]) = ((int)__pyx_t_11);
 
-    /* "voxelize.pyx":38
+    /* "voxelize.pyx":39
  *             min(triverts[0, i], triverts[1, i], triverts[2, i])
  *         )
  *         bbox_min[i] = min(max(bbox_min[i], 0), occupancies.shape[i] - 1)             # <<<<<<<<<<<<<<
@@ -2468,7 +2468,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
     (__pyx_v_bbox_min[__pyx_v_i]) = __pyx_t_16;
   }
 
-  /* "voxelize.pyx":40
+  /* "voxelize.pyx":41
  *         bbox_min[i] = min(max(bbox_min[i], 0), occupancies.shape[i] - 1)
  * 
  *     for i in range(3):             # <<<<<<<<<<<<<<
@@ -2478,7 +2478,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
   for (__pyx_t_5 = 0; __pyx_t_5 < 3; __pyx_t_5+=1) {
     __pyx_v_i = __pyx_t_5;
 
-    /* "voxelize.pyx":42
+    /* "voxelize.pyx":43
  *     for i in range(3):
  *         bbox_max[i] = <int> (
  *             max(triverts[0, i], triverts[1, i], triverts[2, i])             # <<<<<<<<<<<<<<
@@ -2506,7 +2506,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
       __pyx_t_10 = __pyx_t_9;
     }
 
-    /* "voxelize.pyx":41
+    /* "voxelize.pyx":42
  * 
  *     for i in range(3):
  *         bbox_max[i] = <int> (             # <<<<<<<<<<<<<<
@@ -2515,7 +2515,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
  */
     (__pyx_v_bbox_max[__pyx_v_i]) = ((int)__pyx_t_10);
 
-    /* "voxelize.pyx":44
+    /* "voxelize.pyx":45
  *             max(triverts[0, i], triverts[1, i], triverts[2, i])
  *         )
  *         bbox_max[i] = min(max(bbox_max[i], 0), occupancies.shape[i] - 1)             # <<<<<<<<<<<<<<
@@ -2539,7 +2539,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
     (__pyx_v_bbox_max[__pyx_v_i]) = __pyx_t_12;
   }
 
-  /* "voxelize.pyx":46
+  /* "voxelize.pyx":47
  *         bbox_max[i] = min(max(bbox_max[i], 0), occupancies.shape[i] - 1)
  * 
  *     for i in range(bbox_min[0], bbox_max[0] + 1):             # <<<<<<<<<<<<<<
@@ -2551,7 +2551,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
   for (__pyx_t_5 = (__pyx_v_bbox_min[0]); __pyx_t_5 < __pyx_t_15; __pyx_t_5+=1) {
     __pyx_v_i = __pyx_t_5;
 
-    /* "voxelize.pyx":47
+    /* "voxelize.pyx":48
  * 
  *     for i in range(bbox_min[0], bbox_max[0] + 1):
  *         for j in range(bbox_min[1], bbox_max[1] + 1):             # <<<<<<<<<<<<<<
@@ -2563,7 +2563,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
     for (__pyx_t_14 = (__pyx_v_bbox_min[1]); __pyx_t_14 < __pyx_t_18; __pyx_t_14+=1) {
       __pyx_v_j = __pyx_t_14;
 
-      /* "voxelize.pyx":48
+      /* "voxelize.pyx":49
  *     for i in range(bbox_min[0], bbox_max[0] + 1):
  *         for j in range(bbox_min[1], bbox_max[1] + 1):
  *             for k in range(bbox_min[2], bbox_max[2] + 1):             # <<<<<<<<<<<<<<
@@ -2575,7 +2575,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
       for (__pyx_t_21 = (__pyx_v_bbox_min[2]); __pyx_t_21 < __pyx_t_20; __pyx_t_21+=1) {
         __pyx_v_k = __pyx_t_21;
 
-        /* "voxelize.pyx":49
+        /* "voxelize.pyx":50
  *         for j in range(bbox_min[1], bbox_max[1] + 1):
  *             for k in range(bbox_min[2], bbox_max[2] + 1):
  *                 boxcenter[:] = (i + 0.5, j + 0.5, k + 0.5)             # <<<<<<<<<<<<<<
@@ -2590,7 +2590,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
         (__pyx_t_1[1]) = __pyx_t_3;
         (__pyx_t_1[2]) = __pyx_t_2;
 
-        /* "voxelize.pyx":51
+        /* "voxelize.pyx":52
  *                 boxcenter[:] = (i + 0.5, j + 0.5, k + 0.5)
  *                 intersection = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],
  *                                              &triverts[0, 0], &triverts[1, 0], &triverts[2, 0])             # <<<<<<<<<<<<<<
@@ -2604,7 +2604,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
         __pyx_t_24 = 2;
         __pyx_t_25 = 0;
 
-        /* "voxelize.pyx":50
+        /* "voxelize.pyx":51
  *             for k in range(bbox_min[2], bbox_max[2] + 1):
  *                 boxcenter[:] = (i + 0.5, j + 0.5, k + 0.5)
  *                 intersection = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],             # <<<<<<<<<<<<<<
@@ -2613,7 +2613,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
  */
         __pyx_v_intersection = triBoxOverlap((&(__pyx_v_boxcenter[0])), (&(__pyx_v_boxhalfsize[0])), (&(*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_triverts.data + __pyx_t_6 * __pyx_v_triverts.strides[0]) )) + __pyx_t_7)) )))), (&(*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_triverts.data + __pyx_t_22 * __pyx_v_triverts.strides[0]) )) + __pyx_t_23)) )))), (&(*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_triverts.data + __pyx_t_24 * __pyx_v_triverts.strides[0]) )) + __pyx_t_25)) )))));
 
-        /* "voxelize.pyx":52
+        /* "voxelize.pyx":53
  *                 intersection = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],
  *                                              &triverts[0, 0], &triverts[1, 0], &triverts[2, 0])
  *                 occupancies[i, j, k] |= intersection             # <<<<<<<<<<<<<<
@@ -2628,7 +2628,7 @@ static int __pyx_f_8voxelize_voxelize_triangle_(__Pyx_memviewslice __pyx_v_occup
     }
   }
 
-  /* "voxelize.pyx":24
+  /* "voxelize.pyx":25
  * @cython.boundscheck(False)  # Deactivate bounds checking
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cpdef int voxelize_triangle_(bint[:, :, :] occupancies, float[:, ::1] triverts):             # <<<<<<<<<<<<<<
@@ -2676,11 +2676,11 @@ static PyObject *__pyx_pw_8voxelize_3voxelize_triangle_(PyObject *__pyx_self, Py
         case  1:
         if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_triverts)) != 0)) kw_args--;
         else {
-          __Pyx_RaiseArgtupleInvalid("voxelize_triangle_", 1, 2, 2, 1); __PYX_ERR(0, 24, __pyx_L3_error)
+          __Pyx_RaiseArgtupleInvalid("voxelize_triangle_", 1, 2, 2, 1); __PYX_ERR(0, 25, __pyx_L3_error)
         }
       }
       if (unlikely(kw_args > 0)) {
-        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "voxelize_triangle_") < 0)) __PYX_ERR(0, 24, __pyx_L3_error)
+        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "voxelize_triangle_") < 0)) __PYX_ERR(0, 25, __pyx_L3_error)
       }
     } else if (PyTuple_GET_SIZE(__pyx_args) != 2) {
       goto __pyx_L5_argtuple_error;
@@ -2688,12 +2688,12 @@ static PyObject *__pyx_pw_8voxelize_3voxelize_triangle_(PyObject *__pyx_self, Py
       values[0] = PyTuple_GET_ITEM(__pyx_args, 0);
       values[1] = PyTuple_GET_ITEM(__pyx_args, 1);
     }
-    __pyx_v_occupancies = __Pyx_PyObject_to_MemoryviewSlice_dsdsds_int(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_occupancies.memview)) __PYX_ERR(0, 24, __pyx_L3_error)
-    __pyx_v_triverts = __Pyx_PyObject_to_MemoryviewSlice_d_dc_float(values[1], PyBUF_WRITABLE); if (unlikely(!__pyx_v_triverts.memview)) __PYX_ERR(0, 24, __pyx_L3_error)
+    __pyx_v_occupancies = __Pyx_PyObject_to_MemoryviewSlice_dsdsds_int(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_occupancies.memview)) __PYX_ERR(0, 25, __pyx_L3_error)
+    __pyx_v_triverts = __Pyx_PyObject_to_MemoryviewSlice_d_dc_float(values[1], PyBUF_WRITABLE); if (unlikely(!__pyx_v_triverts.memview)) __PYX_ERR(0, 25, __pyx_L3_error)
   }
   goto __pyx_L4_argument_unpacking_done;
   __pyx_L5_argtuple_error:;
-  __Pyx_RaiseArgtupleInvalid("voxelize_triangle_", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 24, __pyx_L3_error)
+  __Pyx_RaiseArgtupleInvalid("voxelize_triangle_", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 25, __pyx_L3_error)
   __pyx_L3_error:;
   __Pyx_AddTraceback("voxelize.voxelize_triangle_", __pyx_clineno, __pyx_lineno, __pyx_filename);
   __Pyx_RefNannyFinishContext();
@@ -2715,9 +2715,9 @@ static PyObject *__pyx_pf_8voxelize_2voxelize_triangle_(CYTHON_UNUSED PyObject *
   int __pyx_clineno = 0;
   __Pyx_RefNannySetupContext("voxelize_triangle_", 0);
   __Pyx_XDECREF(__pyx_r);
-  if (unlikely(!__pyx_v_occupancies.memview)) { __Pyx_RaiseUnboundLocalError("occupancies"); __PYX_ERR(0, 24, __pyx_L1_error) }
-  if (unlikely(!__pyx_v_triverts.memview)) { __Pyx_RaiseUnboundLocalError("triverts"); __PYX_ERR(0, 24, __pyx_L1_error) }
-  __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_f_8voxelize_voxelize_triangle_(__pyx_v_occupancies, __pyx_v_triverts, 0)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
+  if (unlikely(!__pyx_v_occupancies.memview)) { __Pyx_RaiseUnboundLocalError("occupancies"); __PYX_ERR(0, 25, __pyx_L1_error) }
+  if (unlikely(!__pyx_v_triverts.memview)) { __Pyx_RaiseUnboundLocalError("triverts"); __PYX_ERR(0, 25, __pyx_L1_error) }
+  __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_f_8voxelize_voxelize_triangle_(__pyx_v_occupancies, __pyx_v_triverts, 0)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 25, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   __pyx_r = __pyx_t_1;
   __pyx_t_1 = 0;
@@ -2736,7 +2736,7 @@ static PyObject *__pyx_pf_8voxelize_2voxelize_triangle_(CYTHON_UNUSED PyObject *
   return __pyx_r;
 }
 
-/* "voxelize.pyx":57
+/* "voxelize.pyx":58
  * @cython.boundscheck(False)  # Deactivate bounds checking
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts):             # <<<<<<<<<<<<<<
@@ -2762,7 +2762,7 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
   int __pyx_clineno = 0;
   __Pyx_RefNannySetupContext("test_triangle_aabb", 0);
 
-  /* "voxelize.pyx":58
+  /* "voxelize.pyx":59
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts):
  *     assert(boxcenter.shape[0] == 3)             # <<<<<<<<<<<<<<
@@ -2773,12 +2773,12 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
   if (unlikely(!Py_OptimizeFlag)) {
     if (unlikely(!(((__pyx_v_boxcenter.shape[0]) == 3) != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 58, __pyx_L1_error)
+      __PYX_ERR(0, 59, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "voxelize.pyx":59
+  /* "voxelize.pyx":60
  * cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts):
  *     assert(boxcenter.shape[0] == 3)
  *     assert(boxhalfsize.shape[0] == 3)             # <<<<<<<<<<<<<<
@@ -2789,12 +2789,12 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
   if (unlikely(!Py_OptimizeFlag)) {
     if (unlikely(!(((__pyx_v_boxhalfsize.shape[0]) == 3) != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 59, __pyx_L1_error)
+      __PYX_ERR(0, 60, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "voxelize.pyx":60
+  /* "voxelize.pyx":61
  *     assert(boxcenter.shape[0] == 3)
  *     assert(boxhalfsize.shape[0] == 3)
  *     assert(triverts.shape[0] == triverts.shape[1] == 3)             # <<<<<<<<<<<<<<
@@ -2809,12 +2809,12 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
     }
     if (unlikely(!(__pyx_t_1 != 0))) {
       PyErr_SetNone(PyExc_AssertionError);
-      __PYX_ERR(0, 60, __pyx_L1_error)
+      __PYX_ERR(0, 61, __pyx_L1_error)
     }
   }
   #endif
 
-  /* "voxelize.pyx":64
+  /* "voxelize.pyx":65
  *     # print(triverts)
  *     # Call functions
  *     cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],             # <<<<<<<<<<<<<<
@@ -2824,7 +2824,7 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
   __pyx_t_2 = 0;
   __pyx_t_3 = 0;
 
-  /* "voxelize.pyx":65
+  /* "voxelize.pyx":66
  *     # Call functions
  *     cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],
  *                                     &triverts[0, 0], &triverts[1, 0], &triverts[2, 0])             # <<<<<<<<<<<<<<
@@ -2837,7 +2837,7 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
   __pyx_t_8 = 2;
   __pyx_t_9 = 0;
 
-  /* "voxelize.pyx":64
+  /* "voxelize.pyx":65
  *     # print(triverts)
  *     # Call functions
  *     cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],             # <<<<<<<<<<<<<<
@@ -2846,7 +2846,7 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
  */
   __pyx_v_result = triBoxOverlap((&(*((float *) ( /* dim=0 */ ((char *) (((float *) __pyx_v_boxcenter.data) + __pyx_t_2)) )))), (&(*((float *) ( /* dim=0 */ ((char *) (((float *) __pyx_v_boxhalfsize.data) + __pyx_t_3)) )))), (&(*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_triverts.data + __pyx_t_4 * __pyx_v_triverts.strides[0]) )) + __pyx_t_5)) )))), (&(*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_triverts.data + __pyx_t_6 * __pyx_v_triverts.strides[0]) )) + __pyx_t_7)) )))), (&(*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_triverts.data + __pyx_t_8 * __pyx_v_triverts.strides[0]) )) + __pyx_t_9)) )))));
 
-  /* "voxelize.pyx":66
+  /* "voxelize.pyx":67
  *     cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],
  *                                     &triverts[0, 0], &triverts[1, 0], &triverts[2, 0])
  *     return result             # <<<<<<<<<<<<<<
@@ -2854,7 +2854,7 @@ static int __pyx_f_8voxelize_test_triangle_aabb(__Pyx_memviewslice __pyx_v_boxce
   __pyx_r = __pyx_v_result;
   goto __pyx_L0;
 
-  /* "voxelize.pyx":57
+  /* "voxelize.pyx":58
  * @cython.boundscheck(False)  # Deactivate bounds checking
  * @cython.wraparound(False)   # Deactivate negative indexing.
  * cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts):             # <<<<<<<<<<<<<<
@@ -16757,7 +16757,7 @@ static __Pyx_StringTabEntry __pyx_string_tab[] = {
   {0, 0, 0, 0, 0, 0, 0}
 };
 static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) {
-  __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 18, __pyx_L1_error)
+  __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 19, __pyx_L1_error)
   __pyx_builtin_ValueError = __Pyx_GetBuiltinName(__pyx_n_s_ValueError); if (!__pyx_builtin_ValueError) __PYX_ERR(1, 133, __pyx_L1_error)
   __pyx_builtin_MemoryError = __Pyx_GetBuiltinName(__pyx_n_s_MemoryError); if (!__pyx_builtin_MemoryError) __PYX_ERR(1, 148, __pyx_L1_error)
   __pyx_builtin_enumerate = __Pyx_GetBuiltinName(__pyx_n_s_enumerate); if (!__pyx_builtin_enumerate) __PYX_ERR(1, 151, __pyx_L1_error)
@@ -17377,8 +17377,8 @@ if (!__Pyx_RefNanny) {
 
   /* "voxelize.pyx":1
  * cimport cython             # <<<<<<<<<<<<<<
- * from libc.math cimport floor, ceil
  * from cython.view cimport array as cvarray
+ * from libc.math cimport ceil, floor
  */
   __pyx_t_1 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
diff --git a/lib/common/libvoxelize/voxelize.pyx b/lib/common/libvoxelize/voxelize.pyx
index 1ba840295c88f2ae11e2c575b3aece9dc07a930a..8e6a4829ba31b11bfab21ac1ae40bff27a74be65 100644
--- a/lib/common/libvoxelize/voxelize.pyx
+++ b/lib/common/libvoxelize/voxelize.pyx
@@ -1,6 +1,7 @@
 cimport cython
-from libc.math cimport floor, ceil
 from cython.view cimport array as cvarray
+from libc.math cimport ceil, floor
+
 
 cdef extern from "tribox2.h":
     int triBoxOverlap(float boxcenter[3], float boxhalfsize[3],
diff --git a/lib/common/local_affine.py b/lib/common/local_affine.py
index a8c7225ba1e0428161425669a19da3a94d0866df..ca23bd61e7c90de4a8ac19a4554c46417c5be87d 100644
--- a/lib/common/local_affine.py
+++ b/lib/common/local_affine.py
@@ -5,13 +5,14 @@
 # file that should have been included as part of this package.
 
 import torch
-import trimesh
 import torch.nn as nn
-from tqdm import tqdm
-from pytorch3d.structures import Meshes
+import trimesh
 from pytorch3d.loss import chamfer_distance
-from lib.dataset.mesh_util import update_mesh_shape_prior_losses
+from pytorch3d.structures import Meshes
+from tqdm import tqdm
+
 from lib.common.train_util import init_loss
+from lib.dataset.mesh_util import update_mesh_shape_prior_losses
 
 
 # reference: https://github.com/wuhaozhe/pytorch-nicp
@@ -84,11 +85,9 @@ def register(target_mesh, src_mesh, device, verbose=True):
         src_mesh.verts_padded().shape[0], src_mesh.edges_packed()
     ).to(device)
 
-    optimizer_cloth = torch.optim.Adam(
-        [{
-            'params': local_affine_model.parameters()
-        }], lr=1e-2, amsgrad=True
-    )
+    optimizer_cloth = torch.optim.Adam([{'params': local_affine_model.parameters()}],
+                                       lr=1e-2,
+                                       amsgrad=True)
     scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
         optimizer_cloth,
         mode="min",
@@ -104,7 +103,7 @@ def register(target_mesh, src_mesh, device, verbose=True):
         loop_cloth = tqdm(range(100))
     else:
         loop_cloth = range(100)
-    
+
     for i in loop_cloth:
 
         optimizer_cloth.zero_grad()
diff --git a/lib/common/render.py b/lib/common/render.py
index 769f6899a72a11bfe51cbd204e8e886b61aa4194..5f89eceeacba0d7ebac9602a800db48a0d406f4e 100644
--- a/lib/common/render.py
+++ b/lib/common/render.py
@@ -14,35 +14,36 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
+import math
+import os
+
+import cv2
+import numpy as np
+import torch
+from PIL import ImageColor
 from pytorch3d.renderer import (
+    AlphaCompositor,
     BlendParams,
-    blending,
-    look_at_view_transform,
     FoVOrthographicCameras,
-    RasterizationSettings,
+    MeshRasterizer,
+    MeshRenderer,
     PointsRasterizationSettings,
-    PointsRenderer,
-    AlphaCompositor,
     PointsRasterizer,
-    MeshRenderer,
-    MeshRasterizer,
+    PointsRenderer,
+    RasterizationSettings,
     SoftSilhouetteShader,
     TexturesVertex,
+    blending,
+    look_at_view_transform,
 )
 from pytorch3d.renderer.mesh import TexturesVertex
 from pytorch3d.structures import Meshes
-from lib.dataset.mesh_util import get_visibility
-from lib.common.imutils import blend_rgb_norm
+from termcolor import colored
+from tqdm import tqdm
 
 import lib.common.render_utils as util
-import torch
-import numpy as np
-from PIL import ImageColor
-from tqdm import tqdm
-import os
-import cv2
-import math
-from termcolor import colored
+from lib.common.imutils import blend_rgb_norm
+from lib.dataset.mesh_util import get_visibility
 
 
 def image2vid(images, vid_path):
@@ -58,7 +59,7 @@ def image2vid(images, vid_path):
     video.release()
 
 
-def query_color(verts, faces, image, device):
+def query_color(verts, faces, image, device, paint_normal=True):
     """query colors from points and image
 
     Args:
@@ -77,16 +78,16 @@ def query_color(verts, faces, image, device):
     visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
     uv = xy.unsqueeze(0).unsqueeze(2)    # [B, N, 2]
     uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
-    colors = (
-        (
-            torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :,
-                                                                           0].permute(1, 0) + 1.0
-        ) * 0.5 * 255.0
-    )
-    colors[visibility == 0.0] = (
-        (Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0) *
-        0.5 * 255.0
-    )[visibility == 0.0]
+    colors = ((
+        torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) +
+        1.0
+    ) * 0.5 * 255.0)
+    if paint_normal:
+        colors[visibility == 0.0] = ((
+            Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0
+        ) * 0.5 * 255.0)[visibility == 0.0]
+    else:
+        colors[visibility == 0.0] = torch.tensor([0.0, 0.0, 0.0]).to(device)
 
     return colors.detach().cpu()
 
@@ -121,31 +122,25 @@ class Render:
         self.step = 3
 
         self.cam_pos = {
-            "frontback":
-                torch.tensor(
-                    [
-                        (0, self.mesh_y_center, self.dis),
-                        (0, self.mesh_y_center, -self.dis),
-                    ]
-                ),
-            "four":
-                torch.tensor(
-                    [
-                        (0, self.mesh_y_center, self.dis),
-                        (self.dis, self.mesh_y_center, 0),
-                        (0, self.mesh_y_center, -self.dis),
-                        (-self.dis, self.mesh_y_center, 0),
-                    ]
-                ),
-            "around":
-                torch.tensor(
-                    [
-                        (
-                            100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
-                            100.0 * math.sin(np.pi / 180 * angle)
-                        ) for angle in range(0, 360, self.step)
-                    ]
-                )
+            "front":
+            torch.tensor([
+                (0, self.mesh_y_center, self.dis),
+                (0, self.mesh_y_center, -self.dis),
+            ]), "frontback":
+            torch.tensor([
+                (0, self.mesh_y_center, self.dis),
+                (0, self.mesh_y_center, -self.dis),
+            ]), "four":
+            torch.tensor([
+                (0, self.mesh_y_center, self.dis),
+                (self.dis, self.mesh_y_center, 0),
+                (0, self.mesh_y_center, -self.dis),
+                (-self.dis, self.mesh_y_center, 0),
+            ]), "around":
+            torch.tensor([(
+                100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
+                100.0 * math.sin(np.pi / 180 * angle)
+            ) for angle in range(0, 360, self.step)])
         }
 
         self.type = "color"
@@ -315,7 +310,7 @@ class Render:
             save_path,
             fourcc,
             self.fps,
-            (width*3, int(height)),
+            (width * 3, int(height)),
         )
 
         pbar = tqdm(range(len(self.meshes)))
@@ -352,15 +347,13 @@ class Render:
         for cam_id in pbar:
             img_raw = data["img_raw"]
             num_obj = len(mesh_renders) // 2
-            img_smpl = blend_rgb_norm(
-                (torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data
-            )
-            img_cloth = blend_rgb_norm(
-                (torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data
-            )
-            final_img = torch.cat(
-                [img_raw, img_smpl, img_cloth], dim=-1).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
-            
+            img_smpl = blend_rgb_norm((torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0,
+                                      data)
+            img_cloth = blend_rgb_norm((torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0,
+                                       data)
+            final_img = torch.cat([img_raw, img_smpl, img_cloth],
+                                  dim=-1).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
+
             video.write(final_img[:, :, ::-1])
 
         video.release()
diff --git a/lib/common/render_utils.py b/lib/common/render_utils.py
index cb2ca46f420c063c7a1c6a82276d41c42852e451..389eab740fc341b9395e1e008b4fa91e9ef3cf83 100644
--- a/lib/common/render_utils.py
+++ b/lib/common/render_utils.py
@@ -14,13 +14,15 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import torch
-from torch import nn
-import trimesh
 import math
 from typing import NewType
-from pytorch3d.structures import Meshes
+
+import numpy as np
+import torch
+import trimesh
 from pytorch3d.renderer.mesh import rasterize_meshes
+from pytorch3d.structures import Meshes
+from torch import nn
 
 Tensor = NewType("Tensor", torch.Tensor)
 
@@ -125,8 +127,6 @@ def batch_contains(verts, faces, points):
 
 
 def dict2obj(d):
-    # if isinstance(d, list):
-    #     d = [dict2obj(x) for x in d]
     if not isinstance(d, dict):
         return d
 
@@ -161,7 +161,9 @@ class Pytorch3dRasterizer(nn.Module):
         x,y,z are in image space, normalized
         can only render squared image now
     """
-    def __init__(self, image_size=224, blur_radius=0.0, faces_per_pixel=1):
+    def __init__(
+        self, image_size=224, blur_radius=0.0, faces_per_pixel=1, device=torch.device("cuda:0")
+    ):
         """
         use fixed raster_settings for rendering faces
         """
@@ -177,6 +179,7 @@ class Pytorch3dRasterizer(nn.Module):
         }
         raster_settings = dict2obj(raster_settings)
         self.raster_settings = raster_settings
+        self.device = device
 
     def forward(self, vertices, faces, attributes=None):
         fixed_vertices = vertices.clone()
@@ -209,3 +212,15 @@ class Pytorch3dRasterizer(nn.Module):
         pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
         pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
         return pixel_vals
+
+    def get_texture(self, uvcoords, uvfaces, verts, faces, verts_color):
+
+        batch_size = verts.shape[0]
+        uv_verts_color = face_vertices(verts_color, faces.expand(batch_size, -1,
+                                                                 -1)).to(self.device)
+        uv_map = self.forward(
+            uvcoords.expand(batch_size, -1, -1), uvfaces.expand(batch_size, -1, -1), uv_verts_color
+        )[:, :3]
+        uv_map_npy = np.flip(uv_map.squeeze(0).permute(1, 2, 0).cpu().numpy(), 0)
+
+        return uv_map_npy
diff --git a/lib/common/seg3d_lossless.py b/lib/common/seg3d_lossless.py
index 4f5cba2a1edb3a5df14d17beabb9d296203865c1..bf086b505a98e0fb8a639394e982cddb048ab8b0 100644
--- a/lib/common/seg3d_lossless.py
+++ b/lib/common/seg3d_lossless.py
@@ -14,19 +14,16 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from .seg3d_utils import (
-    create_grid3D,
-    plot_mask3D,
-    SmoothConv3D,
-)
+import logging
 
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 import torch.nn.functional as F
-import logging
 from pytorch3d.ops.marching_cubes import marching_cubes
 
+from .seg3d_utils import SmoothConv3D, create_grid3D, plot_mask3D
+
 logging.getLogger("lightning").setLevel(logging.ERROR)
 
 
@@ -378,10 +375,8 @@ class Seg3dLossless(nn.Module):
 
                 with torch.no_grad():
                     # conflicts
-                    conflicts = (
-                        (occupancys_interp - self.balance_value) *
-                        (occupancys_topk - self.balance_value) < 0
-                    )[0, 0]
+                    conflicts = ((occupancys_interp - self.balance_value) *
+                                 (occupancys_topk - self.balance_value) < 0)[0, 0]
 
                     if self.visualize:
                         self.plot(occupancys, coords, final_D, final_H, final_W)
@@ -407,12 +402,9 @@ class Seg3dLossless(nn.Module):
                                 title="conflicts",
                             )
 
-                        conflicts_boundary = (
-                            (
-                                conflicts_coords.int() +
-                                self.gird8_offsets.unsqueeze(1) * stride.int()
-                            ).reshape(-1, 3).long().unique(dim=0)
-                        )
+                        conflicts_boundary = ((
+                            conflicts_coords.int() + self.gird8_offsets.unsqueeze(1) * stride.int()
+                        ).reshape(-1, 3).long().unique(dim=0))
                         conflicts_boundary[:, 0] = conflicts_boundary[:, 0].clamp(
                             0,
                             calculated.size(2) - 1
@@ -466,10 +458,8 @@ class Seg3dLossless(nn.Module):
 
                     with torch.no_grad():
                         # conflicts
-                        conflicts = (
-                            (occupancys_interp - self.balance_value) *
-                            (occupancys_topk - self.balance_value) < 0
-                        )[0, 0]
+                        conflicts = ((occupancys_interp - self.balance_value) *
+                                     (occupancys_topk - self.balance_value) < 0)[0, 0]
 
                     # put mask point predictions to the right places on the upsampled grid.
                     point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
diff --git a/lib/common/seg3d_utils.py b/lib/common/seg3d_utils.py
index bee264615a54777bb948414c82f502c678664329..8f2c547d2e02b40c9e149071818116160df0be3a 100644
--- a/lib/common/seg3d_utils.py
+++ b/lib/common/seg3d_utils.py
@@ -14,10 +14,10 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
+import matplotlib.pyplot as plt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import matplotlib.pyplot as plt
 
 
 def plot_mask2D(mask, title="", point_coords=None, figsize=10, point_marker_size=5):
@@ -140,9 +140,8 @@ class SmoothConv2D(nn.Module):
         assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
         self.padding = (kernel_size - 1) // 2
 
-        weight = torch.ones(
-            (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
-        ) / (kernel_size**2)
+        weight = torch.ones((in_channels, out_channels, kernel_size, kernel_size),
+                            dtype=torch.float32) / (kernel_size**2)
         self.register_buffer('weight', weight)
 
     def forward(self, input):
@@ -155,9 +154,8 @@ class SmoothConv3D(nn.Module):
         assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
         self.padding = (kernel_size - 1) // 2
 
-        weight = torch.ones(
-            (in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32
-        ) / (kernel_size**3)
+        weight = torch.ones((in_channels, out_channels, kernel_size, kernel_size, kernel_size),
+                            dtype=torch.float32) / (kernel_size**3)
         self.register_buffer('weight', weight)
 
     def forward(self, input):
@@ -185,9 +183,8 @@ def build_smooth_conv2D(in_channels=1, out_channels=1, kernel_size=3, padding=1)
         kernel_size=kernel_size,
         padding=padding
     )
-    smooth_conv.weight.data = torch.ones(
-        (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
-    ) / (kernel_size**2)
+    smooth_conv.weight.data = torch.ones((in_channels, out_channels, kernel_size, kernel_size),
+                                         dtype=torch.float32) / (kernel_size**2)
     smooth_conv.bias.data = torch.zeros(out_channels)
     return smooth_conv
 
diff --git a/lib/common/train_util.py b/lib/common/train_util.py
index a39102a5849e25d056b9d96d2df9538790bec6ea..324547c05fd7281381262f16d6240d1b9f2240da 100644
--- a/lib/common/train_util.py
+++ b/lib/common/train_util.py
@@ -14,11 +14,12 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
+import pytorch_lightning as pl
 import torch
+from termcolor import colored
+
 from ..dataset.mesh_util import *
 from ..net.geometry import orthogonal
-from termcolor import colored
-import pytorch_lightning as pl
 
 
 class Format:
@@ -30,50 +31,23 @@ def init_loss():
 
     losses = {
     # Cloth: chamfer distance
-        "cloth": {
-            "weight": 1e3,
-            "value": 0.0
-        },
+        "cloth": {"weight": 1e3, "value": 0.0},
     # Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2)
-        "stiff": {
-            "weight": 1e5,
-            "value": 0.0
-        },
+        "stiff": {"weight": 1e5, "value": 0.0},
     # Cloth: det(R) = 1
-        "rigid": {
-            "weight": 1e5,
-            "value": 0.0
-        },
+        "rigid": {"weight": 1e5, "value": 0.0},
     # Cloth: edge length
-        "edge": {
-            "weight": 0,
-            "value": 0.0
-        },
+        "edge": {"weight": 0, "value": 0.0},
     # Cloth: normal consistency
-        "nc": {
-            "weight": 0,
-            "value": 0.0
-        },
+        "nc": {"weight": 0, "value": 0.0},
     # Cloth: laplacian smoothing
-        "lapla": {
-            "weight": 1e2,
-            "value": 0.0
-        },
+        "lapla": {"weight": 1e2, "value": 0.0},
     # Body: Normal_pred - Normal_smpl
-        "normal": {
-            "weight": 1e0,
-            "value": 0.0
-        },
+        "normal": {"weight": 1e0, "value": 0.0},
     # Body: Silhouette_pred - Silhouette_smpl
-        "silhouette": {
-            "weight": 1e0,
-            "value": 0.0
-        },
+        "silhouette": {"weight": 1e0, "value": 0.0},
     # Joint: reprojected joints difference
-        "joint": {
-            "weight": 5e0,
-            "value": 0.0
-        },
+        "joint": {"weight": 5e0, "value": 0.0},
     }
 
     return losses
@@ -143,9 +117,9 @@ def query_func_IF(batch, netG, points):
 
 
 def batch_mean(res, key):
-    return torch.stack(
-        [x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res]
-    ).mean()
+    return torch.stack([
+        x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res
+    ]).mean()
 
 
 def accumulate(outputs, rot_num, split):
diff --git a/lib/common/voxelize.py b/lib/common/voxelize.py
index f792189ccc185e9a7b596eae5a9230fe21482aef..44cbf2b4f43aeed217a56b543feddf6110336773 100644
--- a/lib/common/voxelize.py
+++ b/lib/common/voxelize.py
@@ -1,15 +1,14 @@
-import trimesh
-import numpy as np
 import os
 import traceback
 
-import torch
 import numpy as np
+import torch
 import trimesh
 from scipy import ndimage
 from skimage.measure import block_reduce
-from lib.common.libvoxelize.voxelize import voxelize_mesh_
+
 from lib.common.libmesh.inside_mesh import check_mesh_contains
+from lib.common.libvoxelize.voxelize import voxelize_mesh_
 
 # From Occupancy Networks, Mescheder et. al. CVPR'19
 
@@ -147,76 +146,63 @@ class VoxelGrid:
         f2_r_x, f2_r_y, f2_r_z = np.where(f2_r)
         f3_r_x, f3_r_y, f3_r_z = np.where(f3_r)
 
-        faces_1_l = np.stack(
-            [
-                v_idx[f1_l_x, f1_l_y, f1_l_z],
-                v_idx[f1_l_x, f1_l_y, f1_l_z + 1],
-                v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1],
-                v_idx[f1_l_x, f1_l_y + 1, f1_l_z],
-            ],
-            axis=1
-        )
-
-        faces_1_r = np.stack(
-            [
-                v_idx[f1_r_x, f1_r_y, f1_r_z],
-                v_idx[f1_r_x, f1_r_y + 1, f1_r_z],
-                v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1],
-                v_idx[f1_r_x, f1_r_y, f1_r_z + 1],
-            ],
-            axis=1
-        )
-
-        faces_2_l = np.stack(
-            [
-                v_idx[f2_l_x, f2_l_y, f2_l_z],
-                v_idx[f2_l_x + 1, f2_l_y, f2_l_z],
-                v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1],
-                v_idx[f2_l_x, f2_l_y, f2_l_z + 1],
-            ],
-            axis=1
-        )
-
-        faces_2_r = np.stack(
-            [
-                v_idx[f2_r_x, f2_r_y, f2_r_z],
-                v_idx[f2_r_x, f2_r_y, f2_r_z + 1],
-                v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1],
-                v_idx[f2_r_x + 1, f2_r_y, f2_r_z],
-            ],
-            axis=1
-        )
-
-        faces_3_l = np.stack(
-            [
-                v_idx[f3_l_x, f3_l_y, f3_l_z],
-                v_idx[f3_l_x, f3_l_y + 1, f3_l_z],
-                v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z],
-                v_idx[f3_l_x + 1, f3_l_y, f3_l_z],
-            ],
-            axis=1
-        )
-
-        faces_3_r = np.stack(
-            [
-                v_idx[f3_r_x, f3_r_y, f3_r_z],
-                v_idx[f3_r_x + 1, f3_r_y, f3_r_z],
-                v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z],
-                v_idx[f3_r_x, f3_r_y + 1, f3_r_z],
-            ],
-            axis=1
-        )
-
-        faces = np.concatenate(
-            [
-                faces_1_l,
-                faces_1_r,
-                faces_2_l,
-                faces_2_r,
-                faces_3_l,
-                faces_3_r,
-            ], axis=0
-        )
+        faces_1_l = np.stack([
+            v_idx[f1_l_x, f1_l_y, f1_l_z],
+            v_idx[f1_l_x, f1_l_y, f1_l_z + 1],
+            v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1],
+            v_idx[f1_l_x, f1_l_y + 1, f1_l_z],
+        ],
+                             axis=1)
+
+        faces_1_r = np.stack([
+            v_idx[f1_r_x, f1_r_y, f1_r_z],
+            v_idx[f1_r_x, f1_r_y + 1, f1_r_z],
+            v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1],
+            v_idx[f1_r_x, f1_r_y, f1_r_z + 1],
+        ],
+                             axis=1)
+
+        faces_2_l = np.stack([
+            v_idx[f2_l_x, f2_l_y, f2_l_z],
+            v_idx[f2_l_x + 1, f2_l_y, f2_l_z],
+            v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1],
+            v_idx[f2_l_x, f2_l_y, f2_l_z + 1],
+        ],
+                             axis=1)
+
+        faces_2_r = np.stack([
+            v_idx[f2_r_x, f2_r_y, f2_r_z],
+            v_idx[f2_r_x, f2_r_y, f2_r_z + 1],
+            v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1],
+            v_idx[f2_r_x + 1, f2_r_y, f2_r_z],
+        ],
+                             axis=1)
+
+        faces_3_l = np.stack([
+            v_idx[f3_l_x, f3_l_y, f3_l_z],
+            v_idx[f3_l_x, f3_l_y + 1, f3_l_z],
+            v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z],
+            v_idx[f3_l_x + 1, f3_l_y, f3_l_z],
+        ],
+                             axis=1)
+
+        faces_3_r = np.stack([
+            v_idx[f3_r_x, f3_r_y, f3_r_z],
+            v_idx[f3_r_x + 1, f3_r_y, f3_r_z],
+            v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z],
+            v_idx[f3_r_x, f3_r_y + 1, f3_r_z],
+        ],
+                             axis=1)
+
+        faces = np.concatenate([
+            faces_1_l,
+            faces_1_r,
+            faces_2_l,
+            faces_2_r,
+            faces_3_l,
+            faces_3_r,
+        ],
+                               axis=0)
 
         vertices = self.loc + self.scale * vertices
         mesh = trimesh.Trimesh(vertices, faces, process=False)
diff --git a/lib/dataset/EvalDataset.py b/lib/dataset/EvalDataset.py
index d7aca595c2f3027ee3ba3fdfa9a99ea742afe2b7..9ef6426285b103b04db4f0c827c6755056bcc01b 100644
--- a/lib/dataset/EvalDataset.py
+++ b/lib/dataset/EvalDataset.py
@@ -14,22 +14,24 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import torch.nn.functional as F
-from lib.common.render import Render
-from lib.dataset.mesh_util import (SMPLX, projection, rescale_smpl, HoppeMesh)
-import os.path as osp
-import numpy as np
-from PIL import Image
 import os
+import os.path as osp
+
 import cv2
-import trimesh
+import numpy as np
 import torch
+import torch.nn.functional as F
 import torchvision.transforms as transforms
+import trimesh
+from PIL import Image
+
+from lib.common.render import Render
+from lib.dataset.mesh_util import SMPLX, HoppeMesh, projection, rescale_smpl
 
 cape_gender = {
     "male":
-        ['00032', '00096', '00122', '00127', '00145', '00215', '02474', '03284', '03375', '03394'],
-    "female": ['00134', '00159', '03223', '03331', '03383']
+    ['00032', '00096', '00122', '00127', '00145', '00215', '02474', '03284', '03375',
+     '03394'], "female": ['00134', '00159', '03223', '03331', '03383']
 }
 
 
@@ -74,30 +76,27 @@ class EvalDataset:
                 "scale": self.scales[dataset_id],
             }
 
-            self.datasets_dict[dataset].update(
-                {"subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str)}
-            )
+            self.datasets_dict[dataset].update({
+                "subjects":
+                np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str)
+            })
 
         self.subject_list = self.get_subject_list()
         self.smplx = SMPLX()
 
         # PIL to tensor
-        self.image_to_tensor = transforms.Compose(
-            [
-                transforms.Resize(self.input_size),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-            ]
-        )
+        self.image_to_tensor = transforms.Compose([
+            transforms.Resize(self.input_size),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ])
 
         # PIL to tensor
-        self.mask_to_tensor = transforms.Compose(
-            [
-                transforms.Resize(self.input_size),
-                transforms.ToTensor(),
-                transforms.Normalize((0.0, ), (1.0, )),
-            ]
-        )
+        self.mask_to_tensor = transforms.Compose([
+            transforms.Resize(self.input_size),
+            transforms.ToTensor(),
+            transforms.Normalize((0.0, ), (1.0, )),
+        ])
 
         self.device = device
         self.render = Render(size=512, device=self.device)
@@ -154,27 +153,23 @@ class EvalDataset:
         }
 
         if dataset == "cape":
-            data_dict.update(
-                {
-                    "mesh_path":
-                        osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}.obj"),
-                    "smpl_path":
-                        osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.obj"),
-                }
-            )
+            data_dict.update({
+                "mesh_path":
+                osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}.obj"),
+                "smpl_path":
+                osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.obj"),
+            })
         else:
 
-            data_dict.update(
-                {
-                    "mesh_path":
-                        osp.join(
-                            self.datasets_dict[dataset]["mesh_dir"],
-                            f"{subject}.obj",
-                        ),
-                    "smplx_path":
-                        osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}.obj"),
-                }
-            )
+            data_dict.update({
+                "mesh_path":
+                osp.join(
+                    self.datasets_dict[dataset]["mesh_dir"],
+                    f"{subject}.obj",
+                ),
+                "smplx_path":
+                osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}.obj"),
+            })
 
         # load training data
         data_dict.update(self.load_calib(data_dict))
@@ -183,18 +178,17 @@ class EvalDataset:
         for name, channel in zip(self.in_total, self.in_total_dim):
 
             if f"{name}_path" not in data_dict.keys():
-                data_dict.update(
-                    {
-                        f"{name}_path":
-                            osp.join(self.root, render_folder, name, f"{rotation:03d}.png")
-                    }
-                )
+                data_dict.update({
+                    f"{name}_path":
+                    osp.join(self.root, render_folder, name, f"{rotation:03d}.png")
+                })
 
             # tensor update
             if os.path.exists(data_dict[f"{name}_path"]):
-                data_dict.update(
-                    {name: self.imagepath2tensor(data_dict[f"{name}_path"], channel, inv=False)}
-                )
+                data_dict.update({
+                    name:
+                    self.imagepath2tensor(data_dict[f"{name}_path"], channel, inv=False)
+                })
 
         data_dict.update(self.load_mesh(data_dict))
         data_dict.update(self.load_smpl(data_dict))
diff --git a/lib/dataset/Evaluator.py b/lib/dataset/Evaluator.py
index 066e183617fbe79532d7a7f291227d8dfaef9c7c..3d840033735f04447fc296888e2775a5f8686b5e 100644
--- a/lib/dataset/Evaluator.py
+++ b/lib/dataset/Evaluator.py
@@ -14,20 +14,21 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from lib.dataset.mesh_util import projection
-from lib.common.render import Render
+from typing import Tuple
+
 import numpy as np
 import torch
-from torchvision.utils import make_grid
+from PIL import Image
 from pytorch3d import _C
+from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
+from pytorch3d.ops.packed_to_padded import packed_to_padded
+from pytorch3d.structures import Pointclouds
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
-from pytorch3d.structures import Pointclouds
-from PIL import Image
+from torchvision.utils import make_grid
 
-from typing import Tuple
-from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
-from pytorch3d.ops.packed_to_padded import packed_to_padded
+from lib.common.render import Render
+from lib.dataset.mesh_util import projection
 
 _DEFAULT_MIN_TRIANGLE_AREA: float = 5e-3
 
@@ -278,12 +279,10 @@ class Evaluator:
 
         # error_hf = ((((src_normal_arr - tgt_normal_arr) * sim_mask)**2).sum(dim=0).mean()) * 4.0
 
-        normal_img = Image.fromarray(
-            (
-                torch.cat([src_normal_arr, tgt_normal_arr],
-                          dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0
-            ).astype(np.uint8)
-        )
+        normal_img = Image.fromarray((
+            torch.cat([src_normal_arr, tgt_normal_arr],
+                      dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0
+        ).astype(np.uint8))
         normal_img.save(normal_path)
 
         return error
diff --git a/lib/dataset/NormalDataset.py b/lib/dataset/NormalDataset.py
index 3567ac8cd5a83517a93c80c008bbb9b8d23616a7..59132030d61c27f4cf6883025a440613d3eecbce 100644
--- a/lib/dataset/NormalDataset.py
+++ b/lib/dataset/NormalDataset.py
@@ -14,12 +14,13 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import kornia
 import os.path as osp
+
+import kornia
 import numpy as np
+import torchvision.transforms as transforms
 from PIL import Image
 from termcolor import colored
-import torchvision.transforms as transforms
 
 
 class NormalDataset:
@@ -59,22 +60,18 @@ class NormalDataset:
         self.subject_list = self.get_subject_list(split)
 
         # PIL to tensor
-        self.image_to_tensor = transforms.Compose(
-            [
-                transforms.Resize(self.input_size),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-            ]
-        )
+        self.image_to_tensor = transforms.Compose([
+            transforms.Resize(self.input_size),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ])
 
         # PIL to tensor
-        self.mask_to_tensor = transforms.Compose(
-            [
-                transforms.Resize(self.input_size),
-                transforms.ToTensor(),
-                transforms.Normalize((0.0, ), (1.0, )),
-            ]
-        )
+        self.mask_to_tensor = transforms.Compose([
+            transforms.Resize(self.input_size),
+            transforms.ToTensor(),
+            transforms.Normalize((0.0, ), (1.0, )),
+        ])
 
     def get_subject_list(self, split):
 
@@ -128,21 +125,15 @@ class NormalDataset:
         for name, channel in zip(self.in_total, self.in_total_dim):
 
             if f"{name}_path" not in data_dict.keys():
-                data_dict.update(
-                    {
-                        f"{name}_path":
-                            osp.join(self.root, render_folder, name, f"{rotation:03d}.png")
-                    }
-                )
-
-            data_dict.update(
-                {
-                    name:
-                        self.imagepath2tensor(
-                            data_dict[f"{name}_path"], channel, inv=False, erasing=False
-                        )
-                }
-            )
+                data_dict.update({
+                    f"{name}_path":
+                    osp.join(self.root, render_folder, name, f"{rotation:03d}.png")
+                })
+
+            data_dict.update({
+                name:
+                self.imagepath2tensor(data_dict[f"{name}_path"], channel, inv=False, erasing=False)
+            })
 
         path_keys = [key for key in data_dict.keys() if "_path" in key or "_dir" in key]
 
diff --git a/lib/dataset/NormalModule.py b/lib/dataset/NormalModule.py
index ff672b3c42f5951f4ebf6c8446014d1d277ab02c..16dd02ec26789d40715b24b67f371da45aff2f8f 100644
--- a/lib/dataset/NormalModule.py
+++ b/lib/dataset/NormalModule.py
@@ -14,11 +14,11 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from torch.utils.data import DataLoader
-from lib.dataset.NormalDataset import NormalDataset
-
 # pytorch lightning related libs
 import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+from lib.dataset.NormalDataset import NormalDataset
 
 
 class NormalModule(pl.LightningDataModule):
diff --git a/lib/dataset/PointFeat.py b/lib/dataset/PointFeat.py
index 457b949e5ce712a1eace33b1306fd48613ba8887..f6b5a708cca885f99205662fe6b42e13bf432da7 100644
--- a/lib/dataset/PointFeat.py
+++ b/lib/dataset/PointFeat.py
@@ -1,5 +1,6 @@
-from pytorch3d.structures import Meshes, Pointclouds
 import torch
+from pytorch3d.structures import Meshes, Pointclouds
+
 from lib.common.render_utils import face_vertices
 from lib.dataset.Evaluator import point_mesh_distance
 from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection
diff --git a/lib/dataset/TestDataset.py b/lib/dataset/TestDataset.py
index e99627ec8eeb6a369e10caf4d19c198910fc4a2e..b6eef20dd4a33654fa10b2f68bc5eb82715fa7db 100644
--- a/lib/dataset/TestDataset.py
+++ b/lib/dataset/TestDataset.py
@@ -14,37 +14,34 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import warnings
 import logging
+import warnings
 
 warnings.filterwarnings("ignore")
 logging.getLogger("lightning").setLevel(logging.ERROR)
 logging.getLogger("trimesh").setLevel(logging.ERROR)
 
-from lib.pixielib.utils.config import cfg as pixie_cfg
-from lib.pixielib.pixie import PIXIE
-from lib.pixielib.models.SMPLX import SMPLX as PIXIE_SMPLX
-from lib.common.imutils import process_image
-from lib.common.train_util import Format
-from lib.net.geometry import rotation_matrix_to_angle_axis, rot6d_to_rotmat
-
-from lib.pymafx.core import path_config
-from lib.pymafx.models import pymaf_net
+import glob
+import os.path as osp
 
-from lib.common.config import cfg
-from lib.common.render import Render
-from lib.dataset.body_model import TetraSMPLModel
-from lib.dataset.mesh_util import get_visibility, SMPLX
+import numpy as np
+import torch
 import torch.nn.functional as F
+from PIL import ImageFile
+from termcolor import colored
 from torchvision import transforms
 from torchvision.models import detection
 
-import os.path as osp
-import torch
-import glob
-import numpy as np
-from termcolor import colored
-from PIL import ImageFile
+from lib.common.config import cfg
+from lib.common.imutils import process_image
+from lib.common.render import Render
+from lib.common.train_util import Format
+from lib.dataset.mesh_util import SMPLX, get_visibility
+from lib.pixielib.models.SMPLX import SMPLX as PIXIE_SMPLX
+from lib.pixielib.pixie import PIXIE
+from lib.pixielib.utils.config import cfg as pixie_cfg
+from lib.pymafx.core import path_config
+from lib.pymafx.models import pymaf_net
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
@@ -66,9 +63,8 @@ class TestDataset:
         keep_lst = sorted(glob.glob(f"{self.image_dir}/*"))
         img_fmts = ["jpg", "png", "jpeg", "JPG", "bmp", "exr"]
 
-        self.subject_list = sorted(
-            [item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False
-        )
+        self.subject_list = sorted([item for item in keep_lst if item.split(".")[-1] in img_fmts],
+                                   reverse=False)
 
         # smpl related
         self.smpl_data = SMPLX()
diff --git a/lib/dataset/body_model.py b/lib/dataset/body_model.py
index cebb105591cab29d833f2965ec609c85fd522881..dff2e07ef1277930974dc066da1a2685bcd0cdf3 100644
--- a/lib/dataset/body_model.py
+++ b/lib/dataset/body_model.py
@@ -14,10 +14,11 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import numpy as np
+import os
 import pickle
+
+import numpy as np
 import torch
-import os
 
 
 class SMPLModel:
@@ -126,12 +127,10 @@ class SMPLModel:
         for i in range(1, self.kintree_table.shape[1]):
             G[i] = G[self.parent[i]].dot(
                 self.with_zeros(
-                    np.hstack(
-                        [
-                            self.R[i],
-                            ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
-                        ]
-                    )
+                    np.hstack([
+                        self.R[i],
+                        ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
+                    ])
                 )
             )
         # remove the transformation due to the rest pose
@@ -163,19 +162,17 @@ class SMPLModel:
         r_hat = r / theta
         cos = np.cos(theta)
         z_stick = np.zeros(theta.shape[0])
-        m = np.dstack(
-            [
-                z_stick,
-                -r_hat[:, 0, 2],
-                r_hat[:, 0, 1],
-                r_hat[:, 0, 2],
-                z_stick,
-                -r_hat[:, 0, 0],
-                -r_hat[:, 0, 1],
-                r_hat[:, 0, 0],
-                z_stick,
-            ]
-        ).reshape([-1, 3, 3])
+        m = np.dstack([
+            z_stick,
+            -r_hat[:, 0, 2],
+            r_hat[:, 0, 1],
+            r_hat[:, 0, 2],
+            z_stick,
+            -r_hat[:, 0, 0],
+            -r_hat[:, 0, 1],
+            r_hat[:, 0, 0],
+            z_stick,
+        ]).reshape([-1, 3, 3])
         i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])
         A = np.transpose(r_hat, axes=[0, 2, 1])
         B = r_hat
@@ -357,12 +354,10 @@ class TetraSMPLModel:
         for i in range(1, self.kintree_table.shape[1]):
             G[i] = G[self.parent[i]].dot(
                 self.with_zeros(
-                    np.hstack(
-                        [
-                            self.R[i],
-                            ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
-                        ]
-                    )
+                    np.hstack([
+                        self.R[i],
+                        ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
+                    ])
                 )
             )
         # remove the transformation due to the rest pose
@@ -398,19 +393,17 @@ class TetraSMPLModel:
         r_hat = r / theta
         cos = np.cos(theta)
         z_stick = np.zeros(theta.shape[0])
-        m = np.dstack(
-            [
-                z_stick,
-                -r_hat[:, 0, 2],
-                r_hat[:, 0, 1],
-                r_hat[:, 0, 2],
-                z_stick,
-                -r_hat[:, 0, 0],
-                -r_hat[:, 0, 1],
-                r_hat[:, 0, 0],
-                z_stick,
-            ]
-        ).reshape([-1, 3, 3])
+        m = np.dstack([
+            z_stick,
+            -r_hat[:, 0, 2],
+            r_hat[:, 0, 1],
+            r_hat[:, 0, 2],
+            z_stick,
+            -r_hat[:, 0, 0],
+            -r_hat[:, 0, 1],
+            r_hat[:, 0, 0],
+            z_stick,
+        ]).reshape([-1, 3, 3])
         i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])
         A = np.transpose(r_hat, axes=[0, 2, 1])
         B = r_hat
diff --git a/lib/dataset/mesh_util.py b/lib/dataset/mesh_util.py
index 7dc46d51c3abd9a4e22c7bbba96c78ef86e03dbf..6f257765db55681016424c446db86e1f157e3866 100644
--- a/lib/dataset/mesh_util.py
+++ b/lib/dataset/mesh_util.py
@@ -14,25 +14,25 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
+import json
 import os
+import os.path as osp
+
+import _pickle as cPickle
 import numpy as np
+import open3d as o3d
 import torch
+import torch.nn.functional as F
 import torchvision
 import trimesh
-import json
-import open3d as o3d
-import os.path as osp
-import _pickle as cPickle
-from termcolor import colored
+from PIL import Image, ImageDraw, ImageFont
+from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
+from pytorch3d.renderer.mesh import rasterize_meshes
+from pytorch3d.structures import Meshes
 from scipy.spatial import cKDTree
 
-from pytorch3d.structures import Meshes
-import torch.nn.functional as F
 import lib.smplx as smplx
-from lib.common.render_utils import Pytorch3dRasterizer
-from pytorch3d.renderer.mesh import rasterize_meshes
-from PIL import Image, ImageFont, ImageDraw
-from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
+from lib.common.render_utils import Pytorch3dRasterizer, face_vertices
 
 
 class Format:
@@ -74,19 +74,17 @@ class SMPLX:
         self.smplx_vertex_lmkid = np.load(self.smplx_vertex_lmkid_path)
 
         self.smpl_vert_seg = json.load(open(self.smpl_vert_seg_path))
-        self.smpl_mano_vid = np.concatenate(
-            [
-                self.smpl_vert_seg["rightHand"], self.smpl_vert_seg["rightHandIndex1"],
-                self.smpl_vert_seg["leftHand"], self.smpl_vert_seg["leftHandIndex1"]
-            ]
-        )
+        self.smpl_mano_vid = np.concatenate([
+            self.smpl_vert_seg["rightHand"], self.smpl_vert_seg["rightHandIndex1"],
+            self.smpl_vert_seg["leftHand"], self.smpl_vert_seg["leftHandIndex1"]
+        ])
 
         self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid_path)
         self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid_path)
         self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True)
-        self.smplx_mano_vid = np.concatenate(
-            [self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]]
-        )
+        self.smplx_mano_vid = np.concatenate([
+            self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]
+        ])
         self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True)
         self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)]
 
@@ -110,26 +108,22 @@ class SMPLX:
 
         self.model_dir = osp.join(self.current_dir, "models")
 
-        self.ghum_smpl_pairs = torch.tensor(
-            [
-                (0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16), (12, 17), (13, 18), (14, 19),
-                (15, 20), (16, 21), (17, 39), (18, 44), (19, 36), (20, 41), (21, 35), (22, 40),
-                (23, 1), (24, 2), (25, 4), (26, 5), (27, 7), (28, 8), (29, 31), (30, 34), (31, 29),
-                (32, 32)
-            ]
-        ).long()
+        self.ghum_smpl_pairs = torch.tensor([(0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16),
+                                             (12, 17), (13, 18), (14, 19), (15, 20), (16, 21),
+                                             (17, 39), (18, 44), (19, 36), (20, 41), (21, 35),
+                                             (22, 40), (23, 1), (24, 2), (25, 4), (26, 5), (27, 7),
+                                             (28, 8), (29, 31), (30, 34), (31, 29),
+                                             (32, 32)]).long()
 
         # smpl-smplx correspondence
         self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
         self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68]
         self.smpl_joint_ids_45 = np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist()
 
-        self.extra_joint_ids = np.array(
-            [
-                61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72,
-                74, 73
-            ]
-        )
+        self.extra_joint_ids = np.array([
+            61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72, 74,
+            73
+        ])
 
         self.extra_joint_ids += 68
 
@@ -369,9 +363,9 @@ def mesh_edge_loss(meshes, target_length: float = 0.0):
     return loss_all
 
 
-def remesh_laplacian(mesh, obj_path):
+def remesh_laplacian(mesh, obj_path, face_count=50000):
 
-    mesh = mesh.simplify_quadratic_decimation(50000)
+    mesh = mesh.simplify_quadratic_decimation(face_count)
     mesh = trimesh.smoothing.filter_humphrey(
         mesh, alpha=0.1, beta=0.5, iterations=10, laplacian_operator=None
     )
@@ -380,7 +374,7 @@ def remesh_laplacian(mesh, obj_path):
     return mesh
 
 
-def poisson(mesh, obj_path, depth=10, decimation=True):
+def poisson(mesh, obj_path, depth=10, face_count=50000):
 
     pcd_path = obj_path[:-4] + "_soups.ply"
     assert (mesh.vertex_normals.shape[1] == 3)
@@ -395,12 +389,9 @@ def poisson(mesh, obj_path, depth=10, decimation=True):
     largest_mesh = keep_largest(trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles)))
     largest_mesh.export(obj_path)
 
-    if decimation:
-        # mesh decimation for faster rendering
-        low_res_mesh = largest_mesh.simplify_quadratic_decimation(50000)
-        return low_res_mesh
-    else:
-        return largest_mesh
+    # mesh decimation for faster rendering
+    low_res_mesh = largest_mesh.simplify_quadratic_decimation(face_count)
+    return low_res_mesh
 
 
 # Losses to smooth / regularize the mesh shape
@@ -437,10 +428,9 @@ def read_smpl_constants(folder):
     smpl_tetras = (np.loadtxt(os.path.join(folder, "tetrahedrons.txt"), dtype=np.int32) - 1)
 
     return_dict = {
-        "smpl_vertex_code": torch.tensor(smpl_vertex_code),
-        "smpl_face_code": torch.tensor(smpl_face_code),
-        "smpl_faces": torch.tensor(smpl_faces),
-        "smpl_tetras": torch.tensor(smpl_tetras)
+        "smpl_vertex_code": torch.tensor(smpl_vertex_code), "smpl_face_code":
+        torch.tensor(smpl_face_code), "smpl_faces": torch.tensor(smpl_faces), "smpl_tetras":
+        torch.tensor(smpl_tetras)
     }
 
     return return_dict
@@ -598,22 +588,6 @@ def compute_normal(vertices, faces):
     return vert_norms, face_norms
 
 
-def face_vertices(vertices, faces):
-    """
-    :param vertices: [batch size, number of vertices, 3]
-    :param faces: [batch size, number of faces, 3]
-    :return: [batch size, number of faces, 3, 3]
-    """
-
-    bs, nv = vertices.shape[:2]
-    bs, nf = faces.shape[:2]
-    device = vertices.device
-    faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
-    vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
-
-    return vertices[faces.long()]
-
-
 def compute_normal_batch(vertices, faces):
 
     if faces.shape[0] != vertices.shape[0]:
@@ -657,20 +631,18 @@ def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
             draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
 
         if type == "smpl":
-            for col_id, col_txt in enumerate(
-                [
-                    "image",
-                    "smpl-norm(render)",
-                    "cloth-norm(pred)",
-                    "diff-norm",
-                    "diff-mask",
-                ]
-            ):
+            for col_id, col_txt in enumerate([
+                "image",
+                "smpl-norm(render)",
+                "cloth-norm(pred)",
+                "diff-norm",
+                "diff-mask",
+            ]):
                 draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
         elif type == "cloth":
-            for col_id, col_txt in enumerate(
-                ["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"]
-            ):
+            for col_id, col_txt in enumerate([
+                "image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"
+            ]):
                 draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
             for col_id, col_txt in enumerate(["0", "90", "180", "270"]):
                 draw.text(
@@ -751,3 +723,61 @@ def get_joint_mesh(joints, radius=2.0):
         else:
             combined = sum([combined, ball_new])
     return combined
+
+
+def preprocess_point_cloud(pcd, voxel_size):
+    pcd_down = pcd
+    pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
+        pcd_down, o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 5.0, max_nn=100)
+    )
+    return (pcd_down, pcd_fpfh)
+
+
+def o3d_ransac(src, dst):
+
+    voxel_size = 0.01
+    distance_threshold = 1.5 * voxel_size
+
+    o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
+
+    # print('Downsampling inputs')
+    src_down, src_fpfh = preprocess_point_cloud(src, voxel_size)
+    dst_down, dst_fpfh = preprocess_point_cloud(dst, voxel_size)
+
+    # print('Running RANSAC')
+    result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
+        src_down,
+        dst_down,
+        src_fpfh,
+        dst_fpfh,
+        mutual_filter=False,
+        max_correspondence_distance=distance_threshold,
+        estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
+        ransac_n=3,
+        checkers=[
+            o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
+            o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)
+        ],
+        criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(1000000, 0.999)
+    )
+
+    return result.transformation
+
+
+def export_obj(v_np, f_np, vt, ft, path):
+
+    # write mtl info into obj
+    new_line = f"mtllib material.mtl \n"
+    vt_lines = "\nusemtl mat0 \n"
+    v_lines = ""
+    f_lines = ""
+
+    for _v in v_np:
+        v_lines += f"v {_v[0]} {_v[1]} {_v[2]}\n"
+    for fid, _f in enumerate(f_np):
+        f_lines += f"f {_f[0]+1}/{ft[fid][0]+1} {_f[1]+1}/{ft[fid][1]+1} {_f[2]+1}/{ft[fid][2]+1}\n"
+    for _vt in vt:
+        vt_lines += f"vt {_vt[0]} {_vt[1]}\n"
+    new_file_data = new_line + v_lines + vt_lines + f_lines
+    with open(path, 'w') as file:
+        file.write(new_file_data)
diff --git a/lib/net/BasePIFuNet.py b/lib/net/BasePIFuNet.py
index eb18dbb3245d57c9e030c18094322a58e874db93..3b5a77aec804f0fcbf72f6320b3ef30cdeb83ea1 100644
--- a/lib/net/BasePIFuNet.py
+++ b/lib/net/BasePIFuNet.py
@@ -14,8 +14,8 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import torch.nn as nn
 import pytorch_lightning as pl
+import torch.nn as nn
 
 from .geometry import index, orthogonal, perspective
 
diff --git a/lib/net/Discriminator.py b/lib/net/Discriminator.py
index c60acdde000d414c78af0705ba268af3117c6ec9..b47ef9fd05ef645950be61111d417638a57ae3c6 100644
--- a/lib/net/Discriminator.py
+++ b/lib/net/Discriminator.py
@@ -1,11 +1,16 @@
 """ The code is based on https://github.com/apple/ml-gsn/ with adaption. """
 
 import math
+
 import torch
 import torch.nn as nn
-import math
 import torch.nn.functional as F
-from lib.torch_utils.ops.native_ops import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+
+from lib.torch_utils.ops.native_ops import (
+    FusedLeakyReLU,
+    fused_leaky_relu,
+    upfirdn2d,
+)
 
 
 class DiscriminatorHead(nn.Module):
diff --git a/lib/net/FBNet.py b/lib/net/FBNet.py
index f4797667d4d800019967d7ee2ed944ec8b8550fc..5e04d5a04551c186379847637cc0a8d4b813b3da 100644
--- a/lib/net/FBNet.py
+++ b/lib/net/FBNet.py
@@ -19,13 +19,14 @@ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 
 OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 """
-import torch
-import torch.nn as nn
 import functools
+
 import numpy as np
 import pytorch_lightning as pl
-from torchvision import models
+import torch
+import torch.nn as nn
 import torch.nn.functional as F
+from torchvision import models
 
 
 ###############################################################################
@@ -313,34 +314,28 @@ class NLayerDiscriminator(nn.Module):
 
         kw = 4
         padw = int(np.ceil((kw - 1.0) / 2))
-        sequence = [
-            [
-                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
-                nn.LeakyReLU(0.2, True)
-            ]
-        ]
+        sequence = [[
+            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+            nn.LeakyReLU(0.2, True)
+        ]]
 
         nf = ndf
         for n in range(1, n_layers):
             nf_prev = nf
             nf = min(nf * 2, 512)
-            sequence += [
-                [
-                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
-                    norm_layer(nf),
-                    nn.LeakyReLU(0.2, True)
-                ]
-            ]
+            sequence += [[
+                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
+                norm_layer(nf),
+                nn.LeakyReLU(0.2, True)
+            ]]
 
         nf_prev = nf
         nf = min(nf * 2, 512)
-        sequence += [
-            [
-                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
-                norm_layer(nf),
-                nn.LeakyReLU(0.2, True)
-            ]
-        ]
+        sequence += [[
+            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+            norm_layer(nf),
+            nn.LeakyReLU(0.2, True)
+        ]]
 
         sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
 
@@ -632,18 +627,16 @@ class GANLoss(pl.LightningModule):
     def get_target_tensor(self, input, target_is_real):
         target_tensor = None
         if target_is_real:
-            create_label = (
-                (self.real_label_var is None) or (self.real_label_var.numel() != input.numel())
-            )
+            create_label = ((self.real_label_var is None) or
+                            (self.real_label_var.numel() != input.numel()))
             if create_label:
                 real_tensor = self.tensor(input.size()).fill_(self.real_label)
                 self.real_label_var = real_tensor
                 self.real_label_var.requires_grad = False
             target_tensor = self.real_label_var
         else:
-            create_label = (
-                (self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())
-            )
+            create_label = ((self.fake_label_var is None) or
+                            (self.fake_label_var.numel() != input.numel()))
             if create_label:
                 fake_tensor = self.tensor(input.size()).fill_(self.fake_label)
                 self.fake_label_var = fake_tensor
diff --git a/lib/net/GANLoss.py b/lib/net/GANLoss.py
index 5d6711479980e89a3fc067b5ef579bb382eb29df..79c778a403eb249687ac4850cfb2fb84e5e1dcfb 100644
--- a/lib/net/GANLoss.py
+++ b/lib/net/GANLoss.py
@@ -2,8 +2,9 @@
 
 import torch
 import torch.nn as nn
-from torch import autograd
 import torch.nn.functional as F
+from torch import autograd
+
 from lib.net.Discriminator import StyleDiscriminator
 
 
diff --git a/lib/net/IFGeoNet.py b/lib/net/IFGeoNet.py
index e7fd92a9fe766642100beb7083203401adffa20e..06d3deb4b406f10d21fb91767d824e7569fe5928 100644
--- a/lib/net/IFGeoNet.py
+++ b/lib/net/IFGeoNet.py
@@ -1,7 +1,9 @@
 from pickle import TRUE
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from lib.net.geometry import orthogonal
 
 
@@ -151,13 +153,11 @@ class IFGeoNet(nn.Module):
 
         # here every channel corresponds to one feature.
 
-        features = torch.cat(
-            (
-                feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
-                feature_6
-            ),
-            dim=1
-        )    # (B, features, 1,7,sample_num)
+        features = torch.cat((
+            feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
+            feature_6
+        ),
+                             dim=1)    # (B, features, 1,7,sample_num)
         shape = features.shape
         features = torch.reshape(
             features, (shape[0], shape[1] * shape[3], shape[4])
diff --git a/lib/net/IFGeoNet_nobody.py b/lib/net/IFGeoNet_nobody.py
index 56de86268dcfbbf4a0226c206dcbb992f906db98..1daedb9ec1bccdf1a76c1a8938a902e11dbed9dc 100644
--- a/lib/net/IFGeoNet_nobody.py
+++ b/lib/net/IFGeoNet_nobody.py
@@ -1,7 +1,9 @@
 from pickle import TRUE
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from lib.net.geometry import orthogonal
 
 
@@ -136,13 +138,11 @@ class IFGeoNet(nn.Module):
 
         # here every channel corresponds to one feature.
 
-        features = torch.cat(
-            (
-                feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
-                feature_6
-            ),
-            dim=1
-        )    # (B, features, 1,7,sample_num)
+        features = torch.cat((
+            feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
+            feature_6
+        ),
+                             dim=1)    # (B, features, 1,7,sample_num)
         shape = features.shape
         features = torch.reshape(
             features, (shape[0], shape[1] * shape[3], shape[4])
diff --git a/lib/net/NormalNet.py b/lib/net/NormalNet.py
index a065840ed859137e72ba1e37a40c636da7c32e6f..d4f0fe8e9a9ab935ffc8a49ae2d22d23bc44d4cc 100644
--- a/lib/net/NormalNet.py
+++ b/lib/net/NormalNet.py
@@ -14,14 +14,14 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from lib.net.FBNet import define_G, define_D, VGGLoss, GANLoss, IDMRFLoss
-from lib.net.net_util import init_net
-from lib.net.BasePIFuNet import BasePIFuNet
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from lib.net.BasePIFuNet import BasePIFuNet
+from lib.net.FBNet import GANLoss, IDMRFLoss, VGGLoss, define_D, define_G
+from lib.net.net_util import init_net
+
 
 class NormalNet(BasePIFuNet):
     """
@@ -63,12 +63,12 @@ class NormalNet(BasePIFuNet):
         self.in_nmlB = [
             item[0] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"
         ]
-        self.in_nmlF_dim = sum(
-            [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"]
-        )
-        self.in_nmlB_dim = sum(
-            [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"]
-        )
+        self.in_nmlF_dim = sum([
+            item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"
+        ])
+        self.in_nmlB_dim = sum([
+            item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"
+        ])
 
         self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
         self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
diff --git a/lib/net/geometry.py b/lib/net/geometry.py
index 6d7d82d2cb6b760596d1bbf70804e542999f802e..4c98530b316a3164ad82d742ac55da9e2fc8c212 100644
--- a/lib/net/geometry.py
+++ b/lib/net/geometry.py
@@ -14,11 +14,13 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import torch
-import numpy as np
 import numbers
-from torch.nn import functional as F
+
+import numpy as np
+import torch
 from einops.einops import rearrange
+from torch.nn import functional as F
+
 """
 Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
 Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
@@ -42,13 +44,11 @@ def quaternion_to_rotation_matrix(quat):
     wx, wy, wz = w * x, w * y, w * z
     xy, xz, yz = x * y, x * z, y * z
 
-    rotMat = torch.stack(
-        [
-            w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
-            2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
-        ],
-        dim=1
-    ).view(B, 3, 3)
+    rotMat = torch.stack([
+        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
+        2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
+    ],
+                         dim=1).view(B, 3, 3)
     return rotMat
 
 
@@ -508,12 +508,10 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si
     weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
 
     # least squares
-    Q = np.array(
-        [
-            F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
-            O - np.reshape(joints_2d, -1)
-        ]
-    ).T
+    Q = np.array([
+        F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
+        O - np.reshape(joints_2d, -1)
+    ]).T
     c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
 
     # weighted least squares
@@ -580,13 +578,11 @@ def Rot_y(angle, category="torch", prepend_dim=True, device=None):
             prepend_dim: prepend an extra dimension
     Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
     """
-    m = np.array(
-        [
-            [np.cos(angle), 0.0, np.sin(angle)],
-            [0.0, 1.0, 0.0],
-            [-np.sin(angle), 0.0, np.cos(angle)],
-        ]
-    )
+    m = np.array([
+        [np.cos(angle), 0.0, np.sin(angle)],
+        [0.0, 1.0, 0.0],
+        [-np.sin(angle), 0.0, np.cos(angle)],
+    ])
     if category == "torch":
         if prepend_dim:
             return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -608,13 +604,11 @@ def Rot_x(angle, category="torch", prepend_dim=True, device=None):
             prepend_dim: prepend an extra dimension
     Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
     """
-    m = np.array(
-        [
-            [1.0, 0.0, 0.0],
-            [0.0, np.cos(angle), -np.sin(angle)],
-            [0.0, np.sin(angle), np.cos(angle)],
-        ]
-    )
+    m = np.array([
+        [1.0, 0.0, 0.0],
+        [0.0, np.cos(angle), -np.sin(angle)],
+        [0.0, np.sin(angle), np.cos(angle)],
+    ])
     if category == "torch":
         if prepend_dim:
             return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -636,13 +630,11 @@ def Rot_z(angle, category="torch", prepend_dim=True, device=None):
             prepend_dim: prepend an extra dimension
     Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
     """
-    m = np.array(
-        [
-            [np.cos(angle), -np.sin(angle), 0.0],
-            [np.sin(angle), np.cos(angle), 0.0],
-            [0.0, 0.0, 1.0],
-        ]
-    )
+    m = np.array([
+        [np.cos(angle), -np.sin(angle), 0.0],
+        [np.sin(angle), np.cos(angle), 0.0],
+        [0.0, 0.0, 1.0],
+    ])
     if category == "torch":
         if prepend_dim:
             return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
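
The geometry.py hunk only reflows the rotation helpers; their behavior is unchanged. A small, hypothetical sanity check of the `Rot_y` signature visible above (usage sketch, not part of the patch):

```python
import numpy as np
import torch

from lib.net.geometry import Rot_y

# Rot_y(angle) returns a [1, 3, 3] rotation matrix about the y-axis when
# prepend_dim=True (the default shown in the hunk above).
R = Rot_y(np.pi / 2, category="torch", prepend_dim=True, device="cpu")
print(R.shape)    # torch.Size([1, 3, 3])

# Rotating the x unit vector by 90 degrees about y lands on -z.
v = torch.tensor([[1.0], [0.0], [0.0]])
print((R[0] @ v).squeeze())    # ~ tensor([0., 0., -1.])
```
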
diff --git a/lib/net/net_util.py b/lib/net/net_util.py
index d89fcff5670909cd41c2e917e87b3bdb25870d8a..fa3c9491596688de0425b4471318ff5c23a9a909 100644
--- a/lib/net/net_util.py
+++ b/lib/net/net_util.py
@@ -14,12 +14,13 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
+import functools
+
 import torch
-from torch.nn import init
 import torch.nn as nn
 import torch.nn.functional as F
-import functools
 from torch.autograd import grad
+from torch.nn import init
 
 
 def gradient(inputs, outputs):
diff --git a/lib/net/voxelize.py b/lib/net/voxelize.py
index 394b40e6eeeb158bb691c1e518b6b1f7a889b8d8..8525ef6cf389c40d8e4a82e1a442c24915a7acd8 100644
--- a/lib/net/voxelize.py
+++ b/lib/net/voxelize.py
@@ -1,11 +1,11 @@
 from __future__ import division, print_function
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
-from torch.autograd import Function
-
 import voxelize_cuda
+from torch.autograd import Function
 
 
 class VoxelizationFunction(Function):
diff --git a/lib/pixielib/models/FLAME.py b/lib/pixielib/models/FLAME.py
index b62b1069b6083685e8ff1511e57c48ccf79bc927..fb9ca09c4890f17206905546f4373ab186a5e6d1 100755
--- a/lib/pixielib/models/FLAME.py
+++ b/lib/pixielib/models/FLAME.py
@@ -13,10 +13,11 @@
 # For comments or questions, please email us at pixie@tue.mpg.de
 # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
 
+import pickle
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
-import pickle
 import torch.nn.functional as F
 
 
diff --git a/lib/pixielib/models/SMPLX.py b/lib/pixielib/models/SMPLX.py
index beb672facebc7fa9e61eee7f8e7f3f185ac6cdad..9f07f5740100133c94ba9e5f2f9767ba7ea4b42c 100644
--- a/lib/pixielib/models/SMPLX.py
+++ b/lib/pixielib/models/SMPLX.py
@@ -3,19 +3,20 @@ original from https://github.com/vchoutas/smplx
 modified by Vassilis and Yao
 """
 
+import pickle
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
-import pickle
 
 from .lbs import (
+    JointsFromVerticesSelector,
     Struct,
-    to_tensor,
-    to_np,
+    find_dynamic_lmk_idx_and_bcoords,
     lbs,
+    to_np,
+    to_tensor,
     vertices2landmarks,
-    JointsFromVerticesSelector,
-    find_dynamic_lmk_idx_and_bcoords,
 )
 
 # SMPLX
@@ -209,468 +210,452 @@ extra_names = [
 SMPLX_names += extra_names
 
 part_indices = {}
-part_indices["body"] = np.array(
-    [
-        0,
-        1,
-        2,
-        3,
-        4,
-        5,
-        6,
-        7,
-        8,
-        9,
-        10,
-        11,
-        12,
-        13,
-        14,
-        15,
-        16,
-        17,
-        18,
-        19,
-        20,
-        21,
-        22,
-        23,
-        24,
-        123,
-        124,
-        125,
-        126,
-        127,
-        132,
-        134,
-        135,
-        136,
-        137,
-        138,
-        143,
-    ]
-)
-part_indices["torso"] = np.array(
-    [
-        0,
-        1,
-        2,
-        3,
-        6,
-        9,
-        12,
-        13,
-        14,
-        15,
-        16,
-        17,
-        18,
-        19,
-        22,
-        23,
-        24,
-        55,
-        56,
-        57,
-        58,
-        59,
-        76,
-        77,
-        78,
-        79,
-        80,
-        81,
-        82,
-        83,
-        84,
-        85,
-        86,
-        87,
-        88,
-        89,
-        90,
-        91,
-        92,
-        93,
-        94,
-        95,
-        96,
-        97,
-        98,
-        99,
-        100,
-        101,
-        102,
-        103,
-        104,
-        105,
-        106,
-        107,
-        108,
-        109,
-        110,
-        111,
-        112,
-        113,
-        114,
-        115,
-        116,
-        117,
-        118,
-        119,
-        120,
-        121,
-        122,
-        123,
-        124,
-        125,
-        126,
-        127,
-        128,
-        129,
-        130,
-        131,
-        132,
-        133,
-        134,
-        135,
-        136,
-        137,
-        138,
-        139,
-        140,
-        141,
-        142,
-        143,
-        144,
-    ]
-)
-part_indices["head"] = np.array(
-    [
-        12,
-        15,
-        22,
-        23,
-        24,
-        55,
-        56,
-        57,
-        58,
-        59,
-        60,
-        61,
-        62,
-        63,
-        64,
-        65,
-        66,
-        67,
-        68,
-        69,
-        70,
-        71,
-        72,
-        73,
-        74,
-        75,
-        76,
-        77,
-        78,
-        79,
-        80,
-        81,
-        82,
-        83,
-        84,
-        85,
-        86,
-        87,
-        88,
-        89,
-        90,
-        91,
-        92,
-        93,
-        94,
-        95,
-        96,
-        97,
-        98,
-        99,
-        100,
-        101,
-        102,
-        103,
-        104,
-        105,
-        106,
-        107,
-        108,
-        109,
-        110,
-        111,
-        112,
-        113,
-        114,
-        115,
-        116,
-        117,
-        118,
-        119,
-        120,
-        121,
-        122,
-        123,
-        125,
-        126,
-        134,
-        136,
-        137,
-    ]
-)
-part_indices["face"] = np.array(
-    [
-        55,
-        56,
-        57,
-        58,
-        59,
-        60,
-        61,
-        62,
-        63,
-        64,
-        65,
-        66,
-        67,
-        68,
-        69,
-        70,
-        71,
-        72,
-        73,
-        74,
-        75,
-        76,
-        77,
-        78,
-        79,
-        80,
-        81,
-        82,
-        83,
-        84,
-        85,
-        86,
-        87,
-        88,
-        89,
-        90,
-        91,
-        92,
-        93,
-        94,
-        95,
-        96,
-        97,
-        98,
-        99,
-        100,
-        101,
-        102,
-        103,
-        104,
-        105,
-        106,
-        107,
-        108,
-        109,
-        110,
-        111,
-        112,
-        113,
-        114,
-        115,
-        116,
-        117,
-        118,
-        119,
-        120,
-        121,
-        122,
-    ]
-)
-part_indices["upper"] = np.array(
-    [
-        12,
-        13,
-        14,
-        55,
-        56,
-        57,
-        58,
-        59,
-        60,
-        61,
-        62,
-        63,
-        64,
-        65,
-        66,
-        67,
-        68,
-        69,
-        70,
-        71,
-        72,
-        73,
-        74,
-        75,
-        76,
-        77,
-        78,
-        79,
-        80,
-        81,
-        82,
-        83,
-        84,
-        85,
-        86,
-        87,
-        88,
-        89,
-        90,
-        91,
-        92,
-        93,
-        94,
-        95,
-        96,
-        97,
-        98,
-        99,
-        100,
-        101,
-        102,
-        103,
-        104,
-        105,
-        106,
-        107,
-        108,
-        109,
-        110,
-        111,
-        112,
-        113,
-        114,
-        115,
-        116,
-        117,
-        118,
-        119,
-        120,
-        121,
-        122,
-    ]
-)
-part_indices["hand"] = np.array(
-    [
-        20,
-        21,
-        25,
-        26,
-        27,
-        28,
-        29,
-        30,
-        31,
-        32,
-        33,
-        34,
-        35,
-        36,
-        37,
-        38,
-        39,
-        40,
-        41,
-        42,
-        43,
-        44,
-        45,
-        46,
-        47,
-        48,
-        49,
-        50,
-        51,
-        52,
-        53,
-        54,
-        128,
-        129,
-        130,
-        131,
-        133,
-        139,
-        140,
-        141,
-        142,
-        144,
-    ]
-)
-part_indices["left_hand"] = np.array(
-    [
-        20,
-        25,
-        26,
-        27,
-        28,
-        29,
-        30,
-        31,
-        32,
-        33,
-        34,
-        35,
-        36,
-        37,
-        38,
-        39,
-        128,
-        129,
-        130,
-        131,
-        133,
-    ]
-)
-part_indices["right_hand"] = np.array(
-    [
-        21,
-        40,
-        41,
-        42,
-        43,
-        44,
-        45,
-        46,
-        47,
-        48,
-        49,
-        50,
-        51,
-        52,
-        53,
-        54,
-        139,
-        140,
-        141,
-        142,
-        144,
-    ]
-)
+part_indices["body"] = np.array([
+    0,
+    1,
+    2,
+    3,
+    4,
+    5,
+    6,
+    7,
+    8,
+    9,
+    10,
+    11,
+    12,
+    13,
+    14,
+    15,
+    16,
+    17,
+    18,
+    19,
+    20,
+    21,
+    22,
+    23,
+    24,
+    123,
+    124,
+    125,
+    126,
+    127,
+    132,
+    134,
+    135,
+    136,
+    137,
+    138,
+    143,
+])
+part_indices["torso"] = np.array([
+    0,
+    1,
+    2,
+    3,
+    6,
+    9,
+    12,
+    13,
+    14,
+    15,
+    16,
+    17,
+    18,
+    19,
+    22,
+    23,
+    24,
+    55,
+    56,
+    57,
+    58,
+    59,
+    76,
+    77,
+    78,
+    79,
+    80,
+    81,
+    82,
+    83,
+    84,
+    85,
+    86,
+    87,
+    88,
+    89,
+    90,
+    91,
+    92,
+    93,
+    94,
+    95,
+    96,
+    97,
+    98,
+    99,
+    100,
+    101,
+    102,
+    103,
+    104,
+    105,
+    106,
+    107,
+    108,
+    109,
+    110,
+    111,
+    112,
+    113,
+    114,
+    115,
+    116,
+    117,
+    118,
+    119,
+    120,
+    121,
+    122,
+    123,
+    124,
+    125,
+    126,
+    127,
+    128,
+    129,
+    130,
+    131,
+    132,
+    133,
+    134,
+    135,
+    136,
+    137,
+    138,
+    139,
+    140,
+    141,
+    142,
+    143,
+    144,
+])
+part_indices["head"] = np.array([
+    12,
+    15,
+    22,
+    23,
+    24,
+    55,
+    56,
+    57,
+    58,
+    59,
+    60,
+    61,
+    62,
+    63,
+    64,
+    65,
+    66,
+    67,
+    68,
+    69,
+    70,
+    71,
+    72,
+    73,
+    74,
+    75,
+    76,
+    77,
+    78,
+    79,
+    80,
+    81,
+    82,
+    83,
+    84,
+    85,
+    86,
+    87,
+    88,
+    89,
+    90,
+    91,
+    92,
+    93,
+    94,
+    95,
+    96,
+    97,
+    98,
+    99,
+    100,
+    101,
+    102,
+    103,
+    104,
+    105,
+    106,
+    107,
+    108,
+    109,
+    110,
+    111,
+    112,
+    113,
+    114,
+    115,
+    116,
+    117,
+    118,
+    119,
+    120,
+    121,
+    122,
+    123,
+    125,
+    126,
+    134,
+    136,
+    137,
+])
+part_indices["face"] = np.array([
+    55,
+    56,
+    57,
+    58,
+    59,
+    60,
+    61,
+    62,
+    63,
+    64,
+    65,
+    66,
+    67,
+    68,
+    69,
+    70,
+    71,
+    72,
+    73,
+    74,
+    75,
+    76,
+    77,
+    78,
+    79,
+    80,
+    81,
+    82,
+    83,
+    84,
+    85,
+    86,
+    87,
+    88,
+    89,
+    90,
+    91,
+    92,
+    93,
+    94,
+    95,
+    96,
+    97,
+    98,
+    99,
+    100,
+    101,
+    102,
+    103,
+    104,
+    105,
+    106,
+    107,
+    108,
+    109,
+    110,
+    111,
+    112,
+    113,
+    114,
+    115,
+    116,
+    117,
+    118,
+    119,
+    120,
+    121,
+    122,
+])
+part_indices["upper"] = np.array([
+    12,
+    13,
+    14,
+    55,
+    56,
+    57,
+    58,
+    59,
+    60,
+    61,
+    62,
+    63,
+    64,
+    65,
+    66,
+    67,
+    68,
+    69,
+    70,
+    71,
+    72,
+    73,
+    74,
+    75,
+    76,
+    77,
+    78,
+    79,
+    80,
+    81,
+    82,
+    83,
+    84,
+    85,
+    86,
+    87,
+    88,
+    89,
+    90,
+    91,
+    92,
+    93,
+    94,
+    95,
+    96,
+    97,
+    98,
+    99,
+    100,
+    101,
+    102,
+    103,
+    104,
+    105,
+    106,
+    107,
+    108,
+    109,
+    110,
+    111,
+    112,
+    113,
+    114,
+    115,
+    116,
+    117,
+    118,
+    119,
+    120,
+    121,
+    122,
+])
+part_indices["hand"] = np.array([
+    20,
+    21,
+    25,
+    26,
+    27,
+    28,
+    29,
+    30,
+    31,
+    32,
+    33,
+    34,
+    35,
+    36,
+    37,
+    38,
+    39,
+    40,
+    41,
+    42,
+    43,
+    44,
+    45,
+    46,
+    47,
+    48,
+    49,
+    50,
+    51,
+    52,
+    53,
+    54,
+    128,
+    129,
+    130,
+    131,
+    133,
+    139,
+    140,
+    141,
+    142,
+    144,
+])
+part_indices["left_hand"] = np.array([
+    20,
+    25,
+    26,
+    27,
+    28,
+    29,
+    30,
+    31,
+    32,
+    33,
+    34,
+    35,
+    36,
+    37,
+    38,
+    39,
+    128,
+    129,
+    130,
+    131,
+    133,
+])
+part_indices["right_hand"] = np.array([
+    21,
+    40,
+    41,
+    42,
+    43,
+    44,
+    45,
+    46,
+    47,
+    48,
+    49,
+    50,
+    51,
+    52,
+    53,
+    54,
+    139,
+    140,
+    141,
+    142,
+    144,
+])
 # kinematic tree
 head_kin_chain = [15, 12, 9, 6, 3, 0]
 
diff --git a/lib/pixielib/models/encoders.py b/lib/pixielib/models/encoders.py
index 0783c9265ab442a259fd693a55039026cc7608db..44f979a2063fe62e3de451bebb267a8852e85955 100755
--- a/lib/pixielib/models/encoders.py
+++ b/lib/pixielib/models/encoders.py
@@ -1,6 +1,6 @@
 import numpy as np
-import torch.nn as nn
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 
diff --git a/lib/pixielib/models/hrnet.py b/lib/pixielib/models/hrnet.py
index c1fd871abf8ae79dd87f96e30d14d726c913db05..665b96efa29fb273b2e28773e5ea35391d99b90e 100644
--- a/lib/pixielib/models/hrnet.py
+++ b/lib/pixielib/models/hrnet.py
@@ -3,10 +3,10 @@ borrowed from https://github.com/vchoutas/expose/blob/master/expose/models/backb
 """
 
 import os.path as osp
+
 import torch
 import torch.nn as nn
-
-from torchvision.models.resnet import Bottleneck, BasicBlock
+from torchvision.models.resnet import BasicBlock, Bottleneck
 
 BN_MOMENTUM = 0.1
 
@@ -15,42 +15,38 @@ def load_HRNet(pretrained=False):
     hr_net_cfg_dict = {
         "use_old_impl": False,
         "pretrained_layers": ["*"],
-        "stage1":
-            {
-                "num_modules": 1,
-                "num_branches": 1,
-                "num_blocks": [4],
-                "num_channels": [64],
-                "block": "BOTTLENECK",
-                "fuse_method": "SUM",
-            },
-        "stage2":
-            {
-                "num_modules": 1,
-                "num_branches": 2,
-                "num_blocks": [4, 4],
-                "num_channels": [48, 96],
-                "block": "BASIC",
-                "fuse_method": "SUM",
-            },
-        "stage3":
-            {
-                "num_modules": 4,
-                "num_branches": 3,
-                "num_blocks": [4, 4, 4],
-                "num_channels": [48, 96, 192],
-                "block": "BASIC",
-                "fuse_method": "SUM",
-            },
-        "stage4":
-            {
-                "num_modules": 3,
-                "num_branches": 4,
-                "num_blocks": [4, 4, 4, 4],
-                "num_channels": [48, 96, 192, 384],
-                "block": "BASIC",
-                "fuse_method": "SUM",
-            },
+        "stage1": {
+            "num_modules": 1,
+            "num_branches": 1,
+            "num_blocks": [4],
+            "num_channels": [64],
+            "block": "BOTTLENECK",
+            "fuse_method": "SUM",
+        },
+        "stage2": {
+            "num_modules": 1,
+            "num_branches": 2,
+            "num_blocks": [4, 4],
+            "num_channels": [48, 96],
+            "block": "BASIC",
+            "fuse_method": "SUM",
+        },
+        "stage3": {
+            "num_modules": 4,
+            "num_branches": 3,
+            "num_blocks": [4, 4, 4],
+            "num_channels": [48, 96, 192],
+            "block": "BASIC",
+            "fuse_method": "SUM",
+        },
+        "stage4": {
+            "num_modules": 3,
+            "num_branches": 4,
+            "num_blocks": [4, 4, 4, 4],
+            "num_channels": [48, 96, 192, 384],
+            "block": "BASIC",
+            "fuse_method": "SUM",
+        },
     }
     hr_net_cfg = hr_net_cfg_dict
     model = HighResolutionNet(hr_net_cfg)
diff --git a/lib/pixielib/models/lbs.py b/lib/pixielib/models/lbs.py
index a2252a9a81c7e9ca3633a02cc08f3fafd5bd22cc..7b490bd9bc79a0e252aec2df99bead814edf4195 100755
--- a/lib/pixielib/models/lbs.py
+++ b/lib/pixielib/models/lbs.py
@@ -14,15 +14,14 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function
 
-import numpy as np
 import os
-import yaml
+
+import numpy as np
 import torch
 import torch.nn.functional as F
+import yaml
 from torch import nn
 
 
diff --git a/lib/pixielib/models/moderators.py b/lib/pixielib/models/moderators.py
index 3ab139ac2ad3e0cbd99c8e40dbf6136a37e53cb5..205777192a5601c4f37c75a22981fcda8e0416e0 100644
--- a/lib/pixielib/models/moderators.py
+++ b/lib/pixielib/models/moderators.py
@@ -3,8 +3,8 @@
 # output: fused feature, weight
 """
 import numpy as np
-import torch.nn as nn
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 # MLP + temperature softmax
diff --git a/lib/pixielib/models/resnet.py b/lib/pixielib/models/resnet.py
index 162bc655bff1bd3ca2058334de2e15660de8f5f5..9732daf972dd6d7adebb96927bbc80242a3233dd 100755
--- a/lib/pixielib/models/resnet.py
+++ b/lib/pixielib/models/resnet.py
@@ -11,13 +11,14 @@ Loads different resnet models
     mark:   copied from pytorch source code
 """
 
+import math
+
+import numpy as np
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch
-from torch.nn.parameter import Parameter
-import numpy as np
-import math
 import torchvision
+from torch.nn.parameter import Parameter
 from torchvision import models
 
 
diff --git a/lib/pixielib/pixie.py b/lib/pixielib/pixie.py
index 545bc46f92b73aff4037da7ff3c6ebeba2b4c361..6cadc83bd8b67f32d5819e0d5218de194e63692c 100644
--- a/lib/pixielib/pixie.py
+++ b/lib/pixielib/pixie.py
@@ -14,21 +14,20 @@
 # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
 
 import os
-import torch
-import torchvision
-import torch.nn.functional as F
-import torch.nn as nn
 
+import cv2
 import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
 from skimage.io import imread
-import cv2
 
-from .models.encoders import ResnetEncoder, MLP, HRNEncoder
+from .models.encoders import MLP, HRNEncoder, ResnetEncoder
 from .models.moderators import TempSoftmaxFusion
 from .models.SMPLX import SMPLX
-from .utils import util
 from .utils import rotation_converter as converter
-from .utils import tensor_cropper
+from .utils import tensor_cropper, util
 from .utils.config import cfg
 
 
@@ -55,9 +54,7 @@ class PIXIE(object):
 
         # encode + decode
         param_dict = self.encode(
-            {"body": {
-                "image": data
-            }},
+            {"body": {"image": data}},
             threthold=True,
             keep_local=True,
             copy_and_paste=False,
@@ -559,9 +556,10 @@ class PIXIE(object):
         }
 
         # change the order of face keypoints, to be the same as "standard" 68 keypoints
-        prediction["face_kpt"] = torch.cat(
-            [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], dim=1
-        )
+        prediction["face_kpt"] = torch.cat([
+            prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]
+        ],
+                                           dim=1)
 
         prediction.update(param_dict)
 
diff --git a/lib/pixielib/utils/array_cropper.py b/lib/pixielib/utils/array_cropper.py
index fbee84b6a6f0f3dcad7fcd6b33bf03faf56be625..d18e15b504fd7d894dcb3ed72830374a9837e42b 100644
--- a/lib/pixielib/utils/array_cropper.py
+++ b/lib/pixielib/utils/array_cropper.py
@@ -8,7 +8,7 @@ only support crop to squared images
 """
 
 import numpy as np
-from skimage.transform import estimate_transform, warp, resize, rescale
+from skimage.transform import estimate_transform, rescale, resize, warp
 
 
 def points2bbox(points, points_scale=None):
@@ -47,13 +47,11 @@ def crop_array(image, center, bboxsize, crop_size):
         tform: 3x3 affine matrix
     """
     # points: top-left, top-right, bottom-right
-    src_pts = np.array(
-        [
-            [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
-            [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
-            [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
-        ]
-    )
+    src_pts = np.array([
+        [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
+        [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
+        [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
+    ])
     DST_PTS = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]])
 
     # estimate transformation between points
diff --git a/lib/pixielib/utils/config.py b/lib/pixielib/utils/config.py
index 115a38e9c52b7cf025defa4a3d37d9490fc71833..cd8c87dc43a3adb4ae7c3f3a69595b0fb35d2d07 100644
--- a/lib/pixielib/utils/config.py
+++ b/lib/pixielib/utils/config.py
@@ -1,11 +1,12 @@
 """
 Default config for PIXIE
 """
-from yacs.config import CfgNode as CN
 import argparse
-import yaml
 import os
 
+import yaml
+from yacs.config import CfgNode as CN
+
 cfg = CN()
 
 abs_pixie_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
diff --git a/lib/pixielib/utils/renderer.py b/lib/pixielib/utils/renderer.py
index eb2dc795e01b3e5c78a4ce848777d6cbc5558401..efcc560d7856959f70005858432171776b1bffc7 100755
--- a/lib/pixielib/utils/renderer.py
+++ b/lib/pixielib/utils/renderer.py
@@ -3,12 +3,12 @@ Author: Yao Feng
 Copyright (c) 2020, Yao Feng
 All rights reserved.
 """
+import imageio
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from skimage.io import imread
-import imageio
 
 from . import util
 
@@ -16,17 +16,18 @@ from . import util
 def set_rasterizer(type="pytorch3d"):
     if type == "pytorch3d":
         global Meshes, load_obj, rasterize_meshes
-        from pytorch3d.structures import Meshes
         from pytorch3d.io import load_obj
         from pytorch3d.renderer.mesh import rasterize_meshes
+        from pytorch3d.structures import Meshes
     elif type == "standard":
         global standard_rasterize, load_obj
         import os
-        from .util import load_obj
 
         # Use JIT Compiling Extensions
         # ref: https://pytorch.org/tutorials/advanced/cpp_extension.html
-        from torch.utils.cpp_extension import load, CUDA_HOME
+        from torch.utils.cpp_extension import CUDA_HOME, load
+
+        from .util import load_obj
 
         curr_dir = os.path.dirname(__file__)
         standard_rasterize_cuda = load(
@@ -207,19 +208,17 @@ class SRenderY(nn.Module):
 
         # SH factors for lighting
         pi = np.pi
-        constant_factor = torch.tensor(
-            [
-                1 / np.sqrt(4 * pi),
-                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
-                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
-                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
-                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
-                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
-                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
-                (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
-                (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
-            ]
-        ).float()
+        constant_factor = torch.tensor([
+            1 / np.sqrt(4 * pi),
+            ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+            ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+            ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+            (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+            (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+            (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+            (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
+            (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
+        ]).float()
         self.register_buffer("constant_factor", constant_factor)
 
     def forward(
@@ -310,17 +309,17 @@ class SRenderY(nn.Module):
                         normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
                         lights,
                     )
-                    shading_images = shading.reshape(
-                        [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]
-                    ).permute(0, 3, 1, 2)
+                    shading_images = shading.reshape([
+                        batch_size, albedo_images.shape[2], albedo_images.shape[3], 3
+                    ]).permute(0, 3, 1, 2)
                 else:
                     shading = self.add_directionlight(
                         normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
                         lights,
                     )
-                    shading_images = shading.reshape(
-                        [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]
-                    ).permute(0, 3, 1, 2)
+                    shading_images = shading.reshape([
+                        batch_size, albedo_images.shape[2], albedo_images.shape[3], 3
+                    ]).permute(0, 3, 1, 2)
             images = albedo_images * shading_images
         else:
             images = albedo_images
@@ -402,9 +401,8 @@ class SRenderY(nn.Module):
         )
         # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.)
         # normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3)
-        normals_dot_lights = torch.clamp(
-            (normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0.0, 1.0
-        )
+        normals_dot_lights = torch.clamp((normals[:, None, :, :] * directions_to_lights).sum(dim=3),
+                                         0.0, 1.0)
         shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :]
         return shading.mean(1)
 
diff --git a/lib/pixielib/utils/rotation_converter.py b/lib/pixielib/utils/rotation_converter.py
index f8057cab4e0f84d035a0b8f964823bd61e91dae4..e66386d6427db73e50f791a6fd479f085230da51 100644
--- a/lib/pixielib/utils/rotation_converter.py
+++ b/lib/pixielib/utils/rotation_converter.py
@@ -1,6 +1,7 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
-import numpy as np
+
 """ Rotation Converter
     This function is borrowed from https://github.com/kornia/kornia
 
diff --git a/lib/pixielib/utils/tensor_cropper.py b/lib/pixielib/utils/tensor_cropper.py
index c486f7709ad9216080102ee275f7165d276eb0ce..6863ff044a71d054b460f78557f6d09d11f20a30 100644
--- a/lib/pixielib/utils/tensor_cropper.py
+++ b/lib/pixielib/utils/tensor_cropper.py
@@ -8,9 +8,9 @@ only support crop to squared images
 """
 import torch
 from kornia.geometry.transform.imgwarp import (
-    warp_perspective,
     get_perspective_transform,
     warp_affine,
+    warp_perspective,
 )
 
 
diff --git a/lib/pixielib/utils/util.py b/lib/pixielib/utils/util.py
index 566eda3a6e6ddf7f236bf4e20bf7220b39981ce3..5affbd8314fe175cd673ad07a64aa6964e581cbf 100755
--- a/lib/pixielib/utils/util.py
+++ b/lib/pixielib/utils/util.py
@@ -1,10 +1,11 @@
+import os
+import pickle
+from collections import OrderedDict
+
+import cv2
 import numpy as np
 import torch
 import torch.nn.functional as F
-from collections import OrderedDict
-import os
-import cv2
-import pickle
 
 # ---------------------------- process/generate vertices, normals, faces
 
diff --git a/lib/pymafx/core/cfgs.py b/lib/pymafx/core/cfgs.py
index c970c6c0caafe7a4c2f3abbb311adcd0cef42b94..17abd247de8d335131d8facc866d95e485ea9a7a 100644
--- a/lib/pymafx/core/cfgs.py
+++ b/lib/pymafx/core/cfgs.py
@@ -14,12 +14,13 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import os
+import argparse
 import json
+import os
 import random
 import string
-import argparse
 from datetime import datetime
+
 from yacs.config import CfgNode as CN
 
 # Configuration variables
diff --git a/lib/pymafx/core/constants.py b/lib/pymafx/core/constants.py
index 5354a289f892a764a16221b469fc49794ff54127..24077a0c7b89215315b39dcbdf9335193ee6ce50 100644
--- a/lib/pymafx/core/constants.py
+++ b/lib/pymafx/core/constants.py
@@ -79,55 +79,16 @@ JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
 
 # Map joints to SMPL joints
 JOINT_MAP = {
-    'OP Nose': 24,
-    'OP Neck': 12,
-    'OP RShoulder': 17,
-    'OP RElbow': 19,
-    'OP RWrist': 21,
-    'OP LShoulder': 16,
-    'OP LElbow': 18,
-    'OP LWrist': 20,
-    'OP MidHip': 0,
-    'OP RHip': 2,
-    'OP RKnee': 5,
-    'OP RAnkle': 8,
-    'OP LHip': 1,
-    'OP LKnee': 4,
-    'OP LAnkle': 7,
-    'OP REye': 25,
-    'OP LEye': 26,
-    'OP REar': 27,
-    'OP LEar': 28,
-    'OP LBigToe': 29,
-    'OP LSmallToe': 30,
-    'OP LHeel': 31,
-    'OP RBigToe': 32,
-    'OP RSmallToe': 33,
-    'OP RHeel': 34,
-    'Right Ankle': 8,
-    'Right Knee': 5,
-    'Right Hip': 45,
-    'Left Hip': 46,
-    'Left Knee': 4,
-    'Left Ankle': 7,
-    'Right Wrist': 21,
-    'Right Elbow': 19,
-    'Right Shoulder': 17,
-    'Left Shoulder': 16,
-    'Left Elbow': 18,
-    'Left Wrist': 20,
-    'Neck (LSP)': 47,
-    'Top of Head (LSP)': 48,
-    'Pelvis (MPII)': 49,
-    'Thorax (MPII)': 50,
-    'Spine (H36M)': 51,
-    'Jaw (H36M)': 52,
-    'Head (H36M)': 53,
-    'Nose': 24,
-    'Left Eye': 26,
-    'Right Eye': 25,
-    'Left Ear': 28,
-    'Right Ear': 27
+    'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, 'OP RElbow': 19, 'OP RWrist': 21,
+    'OP LShoulder': 16, 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, 'OP RHip': 2, 'OP RKnee':
+    5, 'OP RAnkle': 8, 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, 'OP REye': 25, 'OP LEye': 26,
+    'OP REar': 27, 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, 'OP LHeel': 31,
+    'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, 'Right Ankle': 8, 'Right Knee': 5,
+    'Right Hip': 45, 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, 'Right Wrist': 21,
+    'Right Elbow': 19, 'Right Shoulder': 17, 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist':
+    20, 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
+    'Spine (H36M)': 51, 'Jaw (H36M)': 52, 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
+    'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
 }
 
 # Joint selectors
@@ -163,30 +124,11 @@ SMPL_J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15,
 SMPLX2SMPL_J45 = [i for i in range(22)] + [30, 45] + [i for i in range(55, 55 + 21)]
 
 SMPL_PART_ID = {
-    'rightHand': 1,
-    'rightUpLeg': 2,
-    'leftArm': 3,
-    'leftLeg': 4,
-    'leftToeBase': 5,
-    'leftFoot': 6,
-    'spine1': 7,
-    'spine2': 8,
-    'leftShoulder': 9,
-    'rightShoulder': 10,
-    'rightFoot': 11,
-    'head': 12,
-    'rightArm': 13,
-    'leftHandIndex1': 14,
-    'rightLeg': 15,
-    'rightHandIndex1': 16,
-    'leftForeArm': 17,
-    'rightForeArm': 18,
-    'neck': 19,
-    'rightToeBase': 20,
-    'spine': 21,
-    'leftUpLeg': 22,
-    'leftHand': 23,
-    'hips': 24
+    'rightHand': 1, 'rightUpLeg': 2, 'leftArm': 3, 'leftLeg': 4, 'leftToeBase': 5, 'leftFoot': 6,
+    'spine1': 7, 'spine2': 8, 'leftShoulder': 9, 'rightShoulder': 10, 'rightFoot': 11, 'head': 12,
+    'rightArm': 13, 'leftHandIndex1': 14, 'rightLeg': 15, 'rightHandIndex1': 16, 'leftForeArm': 17,
+    'rightForeArm': 18, 'neck': 19, 'rightToeBase': 20, 'spine': 21, 'leftUpLeg': 22, 'leftHand':
+    23, 'hips': 24
 }
 
 # MANO_NAMES = [
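
The constants.py hunk above compacts `JOINT_MAP` and `SMPL_PART_ID` without changing any values. For orientation, a minimal sketch of how such a name-to-index map is typically consumed (the values are copied from the hunk; the selection itself is illustrative):

```python
# Subset of JOINT_MAP as shown in the hunk above.
JOINT_MAP = {'OP Nose': 24, 'OP Neck': 12, 'Left Wrist': 20, 'Right Wrist': 21}

# Turn a list of requested joint names into SMPL joint indices.
requested = ['OP Nose', 'Left Wrist', 'Right Wrist']
indices = [JOINT_MAP[name] for name in requested]
print(indices)    # [24, 20, 21]
```
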
diff --git a/lib/pymafx/models/attention.py b/lib/pymafx/models/attention.py
index b0f7d3c5c63ba1471ff15ee1a3cf0d8c94a17699..b87008b0e68541970943eb3e04027223e85c68e8 100644
--- a/lib/pymafx/models/attention.py
+++ b/lib/pymafx/models/attention.py
@@ -4,15 +4,29 @@ Licensed under the MIT license.
 
 """
 
-from __future__ import absolute_import, division, print_function, unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
+import code
 import logging
 import math
 import os
-import code
+
 import torch
 from torch import nn
-from .transformers.bert.modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
+
+from .transformers.bert.modeling_bert import (
+    BertEmbeddings,
+    BertIntermediate,
+    BertOutput,
+    BertPooler,
+    BertPreTrainedModel,
+    BertSelfOutput,
+)
 # import src.modeling.data.config as cfg
 # from src.modeling._gcnn import GraphConvolution, GraphResBlock
 from .transformers.bert.modeling_utils import prune_linear_layer
diff --git a/lib/pymafx/models/hmr.py b/lib/pymafx/models/hmr.py
index da5459d355d3a3f00c53638a376ab3143b23c01e..e9ba5759d7a59cb2c5b9ce0964aaf899c27a1e8a 100755
--- a/lib/pymafx/models/hmr.py
+++ b/lib/pymafx/models/hmr.py
@@ -1,13 +1,14 @@
 # This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/hmr.py
 
+import logging
+import math
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torchvision.models.resnet as resnet
-import numpy as np
-import math
-from lib.net.geometry import rot6d_to_rotmat
 
-import logging
+from lib.net.geometry import rot6d_to_rotmat
 
 logger = logging.getLogger(__name__)
 
diff --git a/lib/pymafx/models/hr_module.py b/lib/pymafx/models/hr_module.py
index 7396f1ea59860235db8fdd24434114381c4a7083..ad6243a463a45733a0c518e34c0dbcb115d39bcc 100644
--- a/lib/pymafx/models/hr_module.py
+++ b/lib/pymafx/models/hr_module.py
@@ -1,13 +1,14 @@
+import logging
 import os
+
 import torch
-import torch.nn as nn
 import torch._utils
+import torch.nn as nn
 import torch.nn.functional as F
+
 # from core.cfgs import cfg
 from .res_module import BasicBlock, Bottleneck
 
-import logging
-
 logger = logging.getLogger(__name__)
 
 BN_MOMENTUM = 0.1
diff --git a/lib/pymafx/models/maf_extractor.py b/lib/pymafx/models/maf_extractor.py
index 34237bc55663dcbcbd67beb4c5d0b6e693aae266..ffe4e73427e30848798df2f57e835a8b10ae2934 100644
--- a/lib/pymafx/models/maf_extractor.py
+++ b/lib/pymafx/models/maf_extractor.py
@@ -1,23 +1,23 @@
 # This script is borrowed and extended from https://github.com/shunsukesaito/PIFu/blob/master/lib/model/SurfaceClassifier.py
 
-import torch
-import scipy
+import logging
+
 import numpy as np
+import scipy
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from lib.pymafx.core import path_config
 from lib.pymafx.utils.geometry import projection
 
-import logging
-
 logger = logging.getLogger(__name__)
 
+from lib.pymafx.utils.imutils import j2d_processing
+
 from .transformers.net_utils import PosEnSine
 from .transformers.transformer_basics import OurMultiheadAttention
 
-from lib.pymafx.utils.imutils import j2d_processing
-
 
 class TransformerDecoderUnit(nn.Module):
     def __init__(
diff --git a/lib/pymafx/models/pose_resnet.py b/lib/pymafx/models/pose_resnet.py
index d97b6609cf02fd2a94d2951f82f71de2be2356c0..16b22e815f715d2ae8e5f217431055ee2ba57ddf 100644
--- a/lib/pymafx/models/pose_resnet.py
+++ b/lib/pymafx/models/pose_resnet.py
@@ -4,12 +4,10 @@
 # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 # ------------------------------------------------------------------------------
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
-import os
 import logging
+import os
 
 import torch
 import torch.nn as nn
@@ -277,11 +275,8 @@ class PoseResNet(nn.Module):
 
 
 resnet_spec = {
-    18: (BasicBlock, [2, 2, 2, 2]),
-    34: (BasicBlock, [3, 4, 6, 3]),
-    50: (Bottleneck, [3, 4, 6, 3]),
-    101: (Bottleneck, [3, 4, 23, 3]),
-    152: (Bottleneck, [3, 8, 36, 3])
+    18: (BasicBlock, [2, 2, 2, 2]), 34: (BasicBlock, [3, 4, 6, 3]), 50: (Bottleneck, [3, 4, 6, 3]),
+    101: (Bottleneck, [3, 4, 23, 3]), 152: (Bottleneck, [3, 8, 36, 3])
 }
 
 
diff --git a/lib/pymafx/models/pymaf_net.py b/lib/pymafx/models/pymaf_net.py
index ca57e4b1c8ce971d76ce53d02827f441016a19ab..1f46c10f4bb0384277a4951135ce6edaabe5ac7d 100644
--- a/lib/pymafx/models/pymaf_net.py
+++ b/lib/pymafx/models/pymaf_net.py
@@ -1,21 +1,34 @@
+import logging
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
-from lib.pymafx.core import constants
 
 from lib.common.config import cfg
-from lib.pymafx.utils.geometry import rot6d_to_rotmat, rotmat_to_rot6d, projection, rotation_matrix_to_angle_axis, compute_twist_rotation
-from .maf_extractor import MAF_Extractor, Mesh_Sampler
-from .smpl import SMPL, SMPL_MODEL_DIR, SMPL_MEAN_PARAMS, get_partial_smpl, SMPL_Family
+from lib.pymafx.core import constants
+from lib.pymafx.utils.cam_params import homo_vector
+from lib.pymafx.utils.geometry import (
+    compute_twist_rotation,
+    projection,
+    rot6d_to_rotmat,
+    rotation_matrix_to_angle_axis,
+    rotmat_to_rot6d,
+)
+from lib.pymafx.utils.imutils import j2d_processing
 from lib.smplx.lbs import batch_rodrigues
-from .res_module import IUV_predict_layer
+
+from .attention import get_att_block
 from .hr_module import get_hrnet_encoder
+from .maf_extractor import MAF_Extractor, Mesh_Sampler
 from .pose_resnet import get_resnet_encoder
-from lib.pymafx.utils.imutils import j2d_processing
-from lib.pymafx.utils.cam_params import homo_vector
-from .attention import get_att_block
-
-import logging
+from .res_module import IUV_predict_layer
+from .smpl import (
+    SMPL,
+    SMPL_MEAN_PARAMS,
+    SMPL_MODEL_DIR,
+    SMPL_Family,
+    get_partial_smpl,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -169,14 +182,14 @@ class Regressor(nn.Module):
 
         if not self.smpl_mode:
             lhand_mean_rot6d = rotmat_to_rot6d(
-                batch_rodrigues(self.smpl.model.model_neutral.left_hand_mean.view(-1, 3)).view(
-                    [-1, 3, 3]
-                )
+                batch_rodrigues(self.smpl.model.model_neutral.left_hand_mean.view(-1, 3)).view([
+                    -1, 3, 3
+                ])
             )
             rhand_mean_rot6d = rotmat_to_rot6d(
-                batch_rodrigues(self.smpl.model.model_neutral.right_hand_mean.view(-1, 3)).view(
-                    [-1, 3, 3]
-                )
+                batch_rodrigues(self.smpl.model.model_neutral.right_hand_mean.view(-1, 3)).view([
+                    -1, 3, 3
+                ])
             )
             init_lhand = lhand_mean_rot6d.reshape(-1).unsqueeze(0)
             init_rhand = rhand_mean_rot6d.reshape(-1).unsqueeze(0)
@@ -300,9 +313,9 @@ class Regressor(nn.Module):
                         else:
                             vfov = rw_cam['vfov'][:, None]
                             crop_ratio = rw_cam['crop_ratio'][:, None]
-                            crop_center = rw_cam['bbox_center'] / torch.cat(
-                                [rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]], 1
-                            )
+                            crop_center = rw_cam['bbox_center'] / torch.cat([
+                                rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]
+                            ], 1)
                         xc = torch.cat([xc, vfov, crop_ratio, crop_center], 1)
 
                     xc = self.fc1(xc)
@@ -338,9 +351,9 @@ class Regressor(nn.Module):
                         pred_lhand = self.decrhand(xc_lhand) + pred_lhand
 
                         if cfg.MODEL.PyMAF.OPT_WRIST:
-                            xc_lhand = torch.cat(
-                                [xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh], 1
-                            )
+                            xc_lhand = torch.cat([
+                                xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh
+                            ], 1)
                             xc_lhand = self.drop3_hand(self.fc3_hand(xc_lhand))
 
                             pred_shape_lh = self.decshape_rhand(xc_lhand) + pred_shape_lh
@@ -353,9 +366,9 @@ class Regressor(nn.Module):
                         pred_rhand = self.decrhand(xc_rhand) + pred_rhand
 
                         if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST:
-                            xc_rhand = torch.cat(
-                                [xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh], 1
-                            )
+                            xc_rhand = torch.cat([
+                                xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh
+                            ], 1)
                             xc_rhand = self.drop3_hand(self.fc3_hand(xc_rhand))
 
                             pred_shape_rh = self.decshape_rhand(xc_rhand) + pred_shape_rh
@@ -363,9 +376,10 @@ class Regressor(nn.Module):
                             pred_cam_rh = self.deccam_rhand(xc_rhand) + pred_cam_rh
 
                             if cfg.MODEL.MESH_MODEL == 'mano':
-                                pred_cam = torch.cat(
-                                    [pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.], dim=1
-                                )
+                                pred_cam = torch.cat([
+                                    pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.
+                                ],
+                                                     dim=1)
 
                     if 'face' in self.part_names:
                         xc_face = self.drop1_face(self.fc1_face(xc_face))
@@ -374,9 +388,9 @@ class Regressor(nn.Module):
                         pred_exp = self.decexp(xc_face) + pred_exp
 
                         if cfg.MODEL.MESH_MODEL == 'flame':
-                            xc_face = torch.cat(
-                                [xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa], 1
-                            )
+                            xc_face = torch.cat([
+                                xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa
+                            ], 1)
                             xc_face = self.drop3_face(self.fc3_face(xc_face))
 
                             pred_shape_fa = self.decshape_face(xc_face) + pred_shape_fa
@@ -384,9 +398,10 @@ class Regressor(nn.Module):
                             pred_cam_fa = self.deccam_face(xc_face) + pred_cam_fa
 
                             if cfg.MODEL.MESH_MODEL == 'flame':
-                                pred_cam = torch.cat(
-                                    [pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.], dim=1
-                                )
+                                pred_cam = torch.cat([
+                                    pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.
+                                ],
+                                                     dim=1)
 
                     if self.full_body_mode or self.body_hand_mode:
                         if cfg.MODEL.PyMAF.PRED_VIS_H:
@@ -500,15 +515,13 @@ class Regressor(nn.Module):
                                     opt_lelbow = torch.stack(opt_lelbow_filtered)
                                     opt_relbow = torch.stack(opt_relbow_filtered)
 
-                                pred_rotmat_body = torch.cat(
-                                    [
-                                        pred_rotmat_body[:, :18],
-                                        opt_lelbow.unsqueeze(1),
-                                        opt_relbow.unsqueeze(1),
-                                        opt_lwrist.unsqueeze(1),
-                                        opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
-                                    ], 1
-                                )
+                                pred_rotmat_body = torch.cat([
+                                    pred_rotmat_body[:, :18],
+                                    opt_lelbow.unsqueeze(1),
+                                    opt_relbow.unsqueeze(1),
+                                    opt_lwrist.unsqueeze(1),
+                                    opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
+                                ], 1)
                             else:
                                 if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == (
                                     cfg.MODEL.PyMAF.N_ITER - 1
@@ -527,13 +540,11 @@ class Regressor(nn.Module):
                                     opt_lwrist = torch.stack(opt_lwrist_filtered)
                                     opt_rwrist = torch.stack(opt_rwrist_filtered)
 
-                                pred_rotmat_body = torch.cat(
-                                    [
-                                        pred_rotmat_body[:, :20],
-                                        opt_lwrist.unsqueeze(1),
-                                        opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
-                                    ], 1
-                                )
+                                pred_rotmat_body = torch.cat([
+                                    pred_rotmat_body[:, :20],
+                                    opt_lwrist.unsqueeze(1),
+                                    opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
+                                ], 1)
 
         if self.hand_only_mode:
             pred_rotmat_rh = rot6d_to_rotmat(
@@ -630,19 +641,15 @@ class Regressor(nn.Module):
         elif self.face_only_mode:
             pred_joints_full = pred_output.face_joints
         elif self.smplx_mode:
-            pred_joints_full = torch.cat(
-                [
-                    pred_joints, pred_output.lhand_joints, pred_output.rhand_joints,
-                    pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints
-                ],
-                dim=1
-            )
+            pred_joints_full = torch.cat([
+                pred_joints, pred_output.lhand_joints, pred_output.rhand_joints,
+                pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints
+            ],
+                                         dim=1)
         else:
             pred_joints_full = pred_joints
         pred_keypoints_2d = projection(
-            pred_joints_full, {
-                **rw_cam, 'cam_sxy': pred_cam
-            }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+            pred_joints_full, {**rw_cam, 'cam_sxy': pred_cam}, iwp_mode=cfg.MODEL.USE_IWP_CAM
         )
         if cfg.MODEL.USE_IWP_CAM:
             # Normalize keypoints to [-1,1]
@@ -661,119 +668,109 @@ class Regressor(nn.Module):
             else:
                 kp_3d = pred_joints
             pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)
-            output.update(
-                {
-                    'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
-                    'verts': pred_vertices,
-                    'kp_2d': pred_keypoints_2d[:, :len_b_kp],
-                    'kp_3d': kp_3d,
-                    'pred_joints': pred_joints,
-                    'smpl_kp_3d': pred_output.smpl_joints,
-                    'rotmat': pred_rotmat,
-                    'pred_cam': pred_cam,
-                    'pred_shape': pred_shape,
-                    'pred_pose': pred_pose,
-                }
-            )
+            output.update({
+                'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
+                'verts': pred_vertices,
+                'kp_2d': pred_keypoints_2d[:, :len_b_kp],
+                'kp_3d': kp_3d,
+                'pred_joints': pred_joints,
+                'smpl_kp_3d': pred_output.smpl_joints,
+                'rotmat': pred_rotmat,
+                'pred_cam': pred_cam,
+                'pred_shape': pred_shape,
+                'pred_pose': pred_pose,
+            })
             # if self.full_body_mode:
             if self.smplx_mode:
                 # assert pred_keypoints_2d.shape[1] == 144
                 len_h_kp = len(constants.HAND_NAMES)
                 len_f_kp = len(constants.FACIAL_LANDMARKS)
                 len_feet_kp = 2 * len(constants.FOOT_NAMES)
-                output.update(
-                    {
-                        'smplx_verts':
-                            pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None,
-                        'pred_lhand':
-                            pred_lhand,
-                        'pred_rhand':
-                            pred_rhand,
-                        'pred_face':
-                            pred_face,
-                        'pred_exp':
-                            pred_exp,
-                        'verts_lh':
-                            pred_output.lhand_vertices,
-                        'verts_rh':
-                            pred_output.rhand_vertices,
+                output.update({
+                    'smplx_verts':
+                    pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None,
+                    'pred_lhand':
+                    pred_lhand,
+                    'pred_rhand':
+                    pred_rhand,
+                    'pred_face':
+                    pred_face,
+                    'pred_exp':
+                    pred_exp,
+                    'verts_lh':
+                    pred_output.lhand_vertices,
+                    'verts_rh':
+                    pred_output.rhand_vertices,
                 # 'pred_arm_rotmat': pred_arm_rotmat,
                 # 'pred_hfrotmat': pred_hfrotmat,
-                        'pred_lhand_rotmat':
-                            pred_lhand_rotmat,
-                        'pred_rhand_rotmat':
-                            pred_rhand_rotmat,
-                        'pred_face_rotmat':
-                            pred_face_rotmat,
-                        'pred_lhand_kp3d':
-                            pred_output.lhand_joints,
-                        'pred_rhand_kp3d':
-                            pred_output.rhand_joints,
-                        'pred_face_kp3d':
-                            pred_output.face_joints,
-                        'pred_lhand_kp2d':
-                            pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp],
-                        'pred_rhand_kp2d':
-                            pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2],
-                        'pred_face_kp2d':
-                            pred_keypoints_2d[:, len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 +
-                                              len_f_kp],
-                        'pred_feet_kp2d':
-                            pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp +
-                                              len_h_kp * 2 + len_f_kp + len_feet_kp],
-                    }
-                )
+                    'pred_lhand_rotmat':
+                    pred_lhand_rotmat,
+                    'pred_rhand_rotmat':
+                    pred_rhand_rotmat,
+                    'pred_face_rotmat':
+                    pred_face_rotmat,
+                    'pred_lhand_kp3d':
+                    pred_output.lhand_joints,
+                    'pred_rhand_kp3d':
+                    pred_output.rhand_joints,
+                    'pred_face_kp3d':
+                    pred_output.face_joints,
+                    'pred_lhand_kp2d':
+                    pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp],
+                    'pred_rhand_kp2d':
+                    pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2],
+                    'pred_face_kp2d':
+                    pred_keypoints_2d[:,
+                                      len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 + len_f_kp],
+                    'pred_feet_kp2d':
+                    pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp +
+                                      len_h_kp * 2 + len_f_kp + len_feet_kp],
+                })
                 if cfg.MODEL.PyMAF.OPT_WRIST:
-                    output.update(
-                        {
-                            'pred_orient_lh': pred_orient_lh,
-                            'pred_shape_lh': pred_shape_lh,
-                            'pred_orient_rh': pred_orient_rh,
-                            'pred_shape_rh': pred_shape_rh,
-                            'pred_cam_fa': pred_cam_fa,
-                            'pred_cam_lh': pred_cam_lh,
-                            'pred_cam_rh': pred_cam_rh,
-                        }
-                    )
+                    output.update({
+                        'pred_orient_lh': pred_orient_lh,
+                        'pred_shape_lh': pred_shape_lh,
+                        'pred_orient_rh': pred_orient_rh,
+                        'pred_shape_rh': pred_shape_rh,
+                        'pred_cam_fa': pred_cam_fa,
+                        'pred_cam_lh': pred_cam_lh,
+                        'pred_cam_rh': pred_cam_rh,
+                    })
                 if cfg.MODEL.PyMAF.PRED_VIS_H:
                     output.update({'pred_vis_hands': pred_vis_hands})
         elif self.hand_only_mode:
             # hand mesh out
             assert pred_keypoints_2d.shape[1] == 21
-            output.update(
-                {
-                    'theta': pred_cam,
-                    'pred_cam': pred_cam,
-                    'pred_rhand': pred_rhand,
-                    'pred_rhand_rotmat': pred_rotmat_rh[:, 1:],
-                    'pred_orient_rh': pred_orient_rh,
-                    'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0],
-                    'verts_rh': pred_output.rhand_vertices,
-                    'pred_cam_rh': pred_cam_rh,
-                    'pred_shape_rh': pred_shape_rh,
-                    'pred_rhand_kp3d': pred_output.rhand_joints,
-                    'pred_rhand_kp2d': pred_keypoints_2d,
-                }
-            )
+            output.update({
+                'theta': pred_cam,
+                'pred_cam': pred_cam,
+                'pred_rhand': pred_rhand,
+                'pred_rhand_rotmat': pred_rotmat_rh[:, 1:],
+                'pred_orient_rh': pred_orient_rh,
+                'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0],
+                'verts_rh': pred_output.rhand_vertices,
+                'pred_cam_rh': pred_cam_rh,
+                'pred_shape_rh': pred_shape_rh,
+                'pred_rhand_kp3d': pred_output.rhand_joints,
+                'pred_rhand_kp2d': pred_keypoints_2d,
+            })
         elif self.face_only_mode:
             # face mesh out
             assert pred_keypoints_2d.shape[1] == 68
-            output.update(
-                {
-                    'theta': pred_cam,
-                    'pred_cam': pred_cam,
-                    'pred_face': pred_face,
-                    'pred_exp': pred_exp,
-                    'pred_face_rotmat': pred_rotmat_fa[:, 1:],
-                    'pred_orient_fa': pred_orient_fa,
-                    'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0],
-                    'verts_fa': pred_output.flame_vertices,
-                    'pred_cam_fa': pred_cam_fa,
-                    'pred_shape_fa': pred_shape_fa,
-                    'pred_face_kp3d': pred_output.face_joints,
-                    'pred_face_kp2d': pred_keypoints_2d,
-                }
-            )
+            output.update({
+                'theta': pred_cam,
+                'pred_cam': pred_cam,
+                'pred_face': pred_face,
+                'pred_exp': pred_exp,
+                'pred_face_rotmat': pred_rotmat_fa[:, 1:],
+                'pred_orient_fa': pred_orient_fa,
+                'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0],
+                'verts_fa': pred_output.flame_vertices,
+                'pred_cam_fa': pred_cam_fa,
+                'pred_shape_fa': pred_shape_fa,
+                'pred_face_kp3d': pred_output.face_joints,
+                'pred_face_kp2d': pred_keypoints_2d,
+            })
         return output
 
 
@@ -946,16 +943,16 @@ class PyMAF(nn.Module):
                     self.dp_head_hf['hand'] = IUV_predict_layer(
                         feat_dim=hf_sfeat_dim[-1], mode='pncc'
                     )
-                    self.part_module_names['hand'].update(
-                        {'dp_head_hf.hand': self.dp_head_hf['hand']}
-                    )
+                    self.part_module_names['hand'].update({
+                        'dp_head_hf.hand': self.dp_head_hf['hand']
+                    })
                 if 'face' in bhf_names:
                     self.dp_head_hf['face'] = IUV_predict_layer(
                         feat_dim=hf_sfeat_dim[-1], mode='pncc'
                     )
-                    self.part_module_names['face'].update(
-                        {'dp_head_hf.face': self.dp_head_hf['face']}
-                    )
+                    self.part_module_names['face'].update({
+                        'dp_head_hf.face': self.dp_head_hf['face']
+                    })
 
             smpl2limb_vert_faces = get_partial_smpl()
 
@@ -964,10 +961,10 @@ class PyMAF(nn.Module):
 
         # grid points for grid feature extraction
         grid_size = 21
-        xv, yv = torch.meshgrid(
-            [torch.linspace(-1, 1, grid_size),
-             torch.linspace(-1, 1, grid_size)]
-        )
+        xv, yv = torch.meshgrid([
+            torch.linspace(-1, 1, grid_size),
+            torch.linspace(-1, 1, grid_size)
+        ])
         grid_points = torch.stack([xv.reshape(-1), yv.reshape(-1)]).unsqueeze(0)
         self.register_buffer('grid_points', grid_points)
         grid_feat_dim = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1]
@@ -995,9 +992,10 @@ class PyMAF(nn.Module):
                 bhf_att_feat_dim.update({'hand': 1024})
 
         if 'face' in self.bhf_names:
-            bhf_ma_feat_dim.update(
-                {'face': len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]}
-            )
+            bhf_ma_feat_dim.update({
+                'face':
+                len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]
+            })
             if self.fuse_grid_align:
                 bhf_att_feat_dim.update({'face': 1024})
 
@@ -1022,9 +1020,10 @@ class PyMAF(nn.Module):
             )
 
             for part in bhf_names:
-                self.part_module_names[part].update(
-                    {f'align_attention.{part}': self.align_attention[part]}
-                )
+                self.part_module_names[part].update({
+                    f'align_attention.{part}':
+                    self.align_attention[part]
+                })
 
         if self.fuse_grid_align:
             self.att_feat_reduce = get_fusion_modules(
@@ -1035,9 +1034,10 @@ class PyMAF(nn.Module):
                 out_feat_len=bhf_att_feat_dim
             )
             for part in bhf_names:
-                self.part_module_names[part].update(
-                    {f'att_feat_reduce.{part}': self.att_feat_reduce[part]}
-                )
+                self.part_module_names[part].update({
+                    f'att_feat_reduce.{part}':
+                    self.att_feat_reduce[part]
+                })
 
         # build regressor for parameter prediction
         self.regressor = nn.ModuleList()
@@ -1109,21 +1109,25 @@ class PyMAF(nn.Module):
             # assign sub-regressor to each part
             for dec_name, dec_module in self.regressor[-1].named_children():
                 if 'hand' in dec_name:
-                    self.part_module_names['hand'].update(
-                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
-                    )
+                    self.part_module_names['hand'].update({
+                        'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name):
+                        dec_module
+                    })
                 elif 'face' in dec_name or 'head' in dec_name or 'exp' in dec_name:
-                    self.part_module_names['face'].update(
-                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
-                    )
+                    self.part_module_names['face'].update({
+                        'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name):
+                        dec_module
+                    })
                 elif 'res' in dec_name or 'vis' in dec_name:
-                    self.part_module_names['link'].update(
-                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
-                    )
+                    self.part_module_names['link'].update({
+                        'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name):
+                        dec_module
+                    })
                 elif 'body' in self.part_module_names:
-                    self.part_module_names['body'].update(
-                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
-                    )
+                    self.part_module_names['body'].update({
+                        'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name):
+                        dec_module
+                    })
 
         # mesh-aligned feature extractor
         self.maf_extractor = nn.ModuleDict()
@@ -1373,17 +1377,15 @@ class PyMAF(nn.Module):
                 pred_cam = mesh_output['pred_cam'].detach()
                 pred_rhand_v = self.mano_sampler(mesh_output['verts_rh'])
                 pred_rhand_proj = projection(
-                    pred_rhand_v, {
-                        **rw_cam, 'cam_sxy': pred_cam
-                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+                    pred_rhand_v, {**rw_cam, 'cam_sxy': pred_cam}, iwp_mode=cfg.MODEL.USE_IWP_CAM
                 )
                 if cfg.MODEL.USE_IWP_CAM:
                     pred_rhand_proj = pred_rhand_proj / (224. / 2.)
                 else:
                     pred_rhand_proj = j2d_processing(pred_rhand_proj, rw_cam['kps_transf'])
                 proj_hf_center = {
-                    'rhand':
-                        mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1)
+                    'rhand':
+                    mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1)
                 }
                 proj_hf_pts = {
                     'rhand': torch.cat([proj_hf_center['rhand'], pred_rhand_proj], dim=1)
@@ -1392,9 +1394,7 @@ class PyMAF(nn.Module):
                 pred_cam = mesh_output['pred_cam'].detach()
                 pred_face_v = mesh_output['pred_face_kp3d']
                 pred_face_proj = projection(
-                    pred_face_v, {
-                        **rw_cam, 'cam_sxy': pred_cam
-                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+                    pred_face_v, {**rw_cam, 'cam_sxy': pred_cam}, iwp_mode=cfg.MODEL.USE_IWP_CAM
                 )
                 if cfg.MODEL.USE_IWP_CAM:
                     pred_face_proj = pred_face_proj / (224. / 2.)
@@ -1409,9 +1409,7 @@ class PyMAF(nn.Module):
                 pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand])
                 pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1)
                 pred_hand_proj = projection(
-                    pred_hand_v, {
-                        **rw_cam, 'cam_sxy': pred_cam
-                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+                    pred_hand_v, {**rw_cam, 'cam_sxy': pred_cam}, iwp_mode=cfg.MODEL.USE_IWP_CAM
                 )
                 if cfg.MODEL.USE_IWP_CAM:
                     pred_hand_proj = pred_hand_proj / (224. / 2.)
@@ -1419,29 +1417,25 @@ class PyMAF(nn.Module):
                     pred_hand_proj = j2d_processing(pred_hand_proj, rw_cam['kps_transf'])
 
                 proj_hf_center = {
-                    'lhand':
-                        mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1),
-                    'rhand':
-                        mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1),
+                    'lhand':
+                    mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1),
+                    'rhand':
+                    mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1),
                 }
                 proj_hf_pts = {
                     'lhand':
-                        torch.cat(
-                            [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1
-                        ),
+                    torch.cat([proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]],
+                              dim=1),
                     'rhand':
-                        torch.cat(
-                            [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1
-                        ),
+                    torch.cat([proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]],
+                              dim=1),
                 }
             elif self.full_body_mode:
                 pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand])
                 pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand])
                 pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1)
                 pred_hand_proj = projection(
-                    pred_hand_v, {
-                        **rw_cam, 'cam_sxy': pred_cam
-                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
+                    pred_hand_v, {**rw_cam, 'cam_sxy': pred_cam}, iwp_mode=cfg.MODEL.USE_IWP_CAM
                 )
                 if cfg.MODEL.USE_IWP_CAM:
                     pred_hand_proj = pred_hand_proj / (224. / 2.)
@@ -1449,24 +1443,19 @@ class PyMAF(nn.Module):
                     pred_hand_proj = j2d_processing(pred_hand_proj, rw_cam['kps_transf'])
 
                 proj_hf_center = {
-                    'lhand':
-                        mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1),
-                    'rhand':
-                        mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1),
-                    'face':
-                        mesh_output['pred_face_kp2d'][:, self.hf_root_idx['face']].unsqueeze(1)
+                    'lhand':
+                    mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1),
+                    'rhand':
+                    mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1),
+                    'face': mesh_output['pred_face_kp2d'][:, self.hf_root_idx['face']].unsqueeze(1)
                 }
                 proj_hf_pts = {
                     'lhand':
-                        torch.cat(
-                            [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1
-                        ),
-                    'rhand':
-                        torch.cat(
-                            [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1
-                        ),
-                    'face':
-                        torch.cat([proj_hf_center['face'], mesh_output['pred_face_kp2d']], dim=1)
+                    torch.cat([proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1),
+                    'rhand':
+                    torch.cat([proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1),
+                    'face':
+                    torch.cat([proj_hf_center['face'], mesh_output['pred_face_kp2d']], dim=1)
                 }
 
             # extract mesh-aligned features for the hand / face part
@@ -1542,9 +1531,10 @@ class PyMAF(nn.Module):
                             limb_grid_feature_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling(
                                 grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim
                             )
-                            limb_grid_ref_feat_ctd = torch.cat(
-                                [limb_grid_feature_ctd, limb_ref_feat_ctd], dim=-1
-                            ).permute(0, 2, 1)
+                            limb_grid_ref_feat_ctd = torch.cat([
+                                limb_grid_feature_ctd, limb_ref_feat_ctd
+                            ],
+                                                               dim=-1).permute(0, 2, 1)
 
                             if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT:
                                 att_ref_feat_ctd = self.align_attention[hf_key][
@@ -1581,9 +1571,7 @@ class PyMAF(nn.Module):
                         ref_feature = self.maf_extractor['body'][rf_i](
                             pred_smpl_verts_ds,
                             im_feat=s_feat_i,
-                            cam={
-                                **rw_cam, 'cam_sxy': pred_cam
-                            },
+                            cam={**rw_cam, 'cam_sxy': pred_cam},
                             add_att=True,
                             reduce_dim=reduce_dim
                         )    # [B, 431 * n_feat]
diff --git a/lib/pymafx/models/res_module.py b/lib/pymafx/models/res_module.py
index 94de7ecaa2ba3ead51c5f960e0ae08b806d9cd80..69ca53960d7ce7159eb25deb8684f1605eb2c830 100644
--- a/lib/pymafx/models/res_module.py
+++ b/lib/pymafx/models/res_module.py
@@ -1,18 +1,18 @@
 # code brought in part from https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
+import logging
 import os
+from collections import OrderedDict
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from collections import OrderedDict
+
 from lib.pymafx.core.cfgs import cfg
-# from .transformers.tokenlearner import TokenLearner
 
-import logging
+# from .transformers.tokenlearner import TokenLearner
 
 logger = logging.getLogger(__name__)
 
@@ -119,11 +119,8 @@ class Bottleneck(nn.Module):
 
 
 resnet_spec = {
-    18: (BasicBlock, [2, 2, 2, 2]),
-    34: (BasicBlock, [3, 4, 6, 3]),
-    50: (Bottleneck, [3, 4, 6, 3]),
-    101: (Bottleneck, [3, 4, 23, 3]),
-    152: (Bottleneck, [3, 8, 36, 3])
+    18: (BasicBlock, [2, 2, 2, 2]), 34: (BasicBlock, [3, 4, 6, 3]), 50: (Bottleneck, [3, 4, 6, 3]),
+    101: (Bottleneck, [3, 4, 23, 3]), 152: (Bottleneck, [3, 8, 36, 3])
 }
 
 
diff --git a/lib/pymafx/models/smpl.py b/lib/pymafx/models/smpl.py
index 0a69eaf24e518545542cdd1eb55c819a549e8d55..6dcb6127886e9671fde6a4036d0889ab39ff2b66 100644
--- a/lib/pymafx/models/smpl.py
+++ b/lib/pymafx/models/smpl.py
@@ -1,20 +1,25 @@
 # This script is extended based on https://github.com/nkolot/SPIN/blob/master/models/smpl.py
 
-from typing import Optional
+import json
+import os
+import pickle
 from dataclasses import dataclass
+from typing import Optional
 
-import os
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
-import pickle
+
+from lib.pymafx.core import constants, path_config
 from lib.smplx import SMPL as _SMPL
-from lib.smplx import SMPLXLayer, MANOLayer, FLAMELayer
-from lib.smplx.lbs import batch_rodrigues, transform_mat, vertices2joints, blend_shapes
+from lib.smplx import FLAMELayer, MANOLayer, SMPLXLayer
 from lib.smplx.body_models import SMPLXOutput
-import json
-
-from lib.pymafx.core import path_config, constants
+from lib.smplx.lbs import (
+    batch_rodrigues,
+    blend_shapes,
+    transform_mat,
+    vertices2joints,
+)
 
 SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS
 SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR
@@ -134,11 +139,11 @@ class SMPL(_SMPL):
             ).contiguous()
 
         # Concatenate all pose vectors
-        full_pose = torch.cat(
-            [global_orient.reshape(-1, 1, 3, 3),
-             body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)],
-            dim=1
-        )
+        full_pose = torch.cat([
+            global_orient.reshape(-1, 1, 3, 3),
+            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)
+        ],
+                              dim=1)
 
         rot_mats = full_pose.view(batch_size, -1, 3, 3)
 
@@ -279,18 +284,16 @@ class SMPLX(SMPLXLayer):
                                                                        -1).contiguous()
 
         # Concatenate all pose vectors
-        full_pose = torch.cat(
-            [
-                global_orient.reshape(-1, 1, 3, 3),
-                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
-                jaw_pose.reshape(-1, 1, 3, 3),
-                leye_pose.reshape(-1, 1, 3, 3),
-                reye_pose.reshape(-1, 1, 3, 3),
-                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
-                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
-            ],
-            dim=1
-        )
+        full_pose = torch.cat([
+            global_orient.reshape(-1, 1, 3, 3),
+            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
+            jaw_pose.reshape(-1, 1, 3, 3),
+            leye_pose.reshape(-1, 1, 3, 3),
+            reye_pose.reshape(-1, 1, 3, 3),
+            left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
+            right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
+        ],
+                              dim=1)
 
         rot_mats = full_pose.view(batch_size, -1, 3, 3)
 
@@ -339,22 +342,20 @@ class SMPLX_ALL(nn.Module):
             self.genders = ['neutral']
         for gender in self.genders:
             assert gender in ['male', 'female', 'neutral']
-        self.model_dict = nn.ModuleDict(
-            {
-                gender: SMPLX(
-                    path_config.SMPL_MODEL_DIR,
-                    gender=gender,
-                    ext='npz',
-                    num_betas=numBetas,
-                    use_pca=False,
-                    batch_size=batch_size,
-                    use_face_contour=use_face_contour,
-                    num_pca_comps=45,
-                    **kwargs
-                )
-                for gender in self.genders
-            }
-        )
+        self.model_dict = nn.ModuleDict({
+            gender: SMPLX(
+                path_config.SMPL_MODEL_DIR,
+                gender=gender,
+                ext='npz',
+                num_betas=numBetas,
+                use_pca=False,
+                batch_size=batch_size,
+                use_face_contour=use_face_contour,
+                num_pca_comps=45,
+                **kwargs
+            )
+            for gender in self.genders
+        })
         self.model_neutral = self.model_dict['neutral']
         joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
         J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
@@ -426,9 +427,9 @@ class SMPLX_ALL(nn.Module):
                     #     kwargs[key] += self.model_neutral.left_hand_mean
                     # elif key == 'right_hand_pose':
                     #     kwargs[key] += self.model_neutral.right_hand_mean
-                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
-                        [batch_size, -1, 3, 3]
-                    )
+                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([
+                        batch_size, -1, 3, 3
+                    ])
         if kwargs['body_pose'].shape[1] == 23:
             # remove hand pose in the body_pose
             kwargs['body_pose'] = kwargs['body_pose'][:, :21]
@@ -570,9 +571,9 @@ class MANO(MANOLayer):
         if kwargs['pose2rot']:
             for key in pose_keys:
                 if key in kwargs:
-                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
-                        [batch_size, -1, 3, 3]
-                    )
+                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([
+                        batch_size, -1, 3, 3
+                    ])
         kwargs['hand_pose'] = kwargs.pop('right_hand_pose')
         mano_output = super().forward(*args, **kwargs)
         th_verts = mano_output.vertices
@@ -605,9 +606,9 @@ class FLAME(FLAMELayer):
         if kwargs['pose2rot']:
             for key in pose_keys:
                 if key in kwargs:
-                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
-                        [batch_size, -1, 3, 3]
-                    )
+                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([
+                        batch_size, -1, 3, 3
+                    ])
         flame_output = super().forward(*args, **kwargs)
         output = ModelOutput(
             flame_vertices=flame_output.vertices,
@@ -745,9 +746,8 @@ def get_part_joints(smpl_joints):
 
     # part_joints = torch.zeros().to(smpl_joints.device)
 
-    one_seg_pairs = [
-        (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17)
-    ]
+    one_seg_pairs = [(0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16),
+                     (14, 17)]
     two_seg_pairs = [(1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21)]
 
     one_seg_pairs.extend(two_seg_pairs)
diff --git a/lib/pymafx/models/transformers/bert/__init__.py b/lib/pymafx/models/transformers/bert/__init__.py
index 0432a1e92856c438e5fd2f550dc5029a78fa354c..d66c20fc9ec44b8fb3ae68a611b784bc624c2616 100644
--- a/lib/pymafx/models/transformers/bert/__init__.py
+++ b/lib/pymafx/models/transformers/bert/__init__.py
@@ -1,19 +1,23 @@
 __version__ = "1.0.0"
 
+from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
 from .modeling_bert import (
-    BertConfig, BertModel, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
+    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+    BertConfig,
+    BertModel,
+    load_tf_weights_in_bert,
 )
-
 from .modeling_graphormer import Graphormer
-
-# from .e2e_body_network import Graphormer_Body_Network
-
-# from .e2e_hand_network import Graphormer_Hand_Network
-
 from .modeling_utils import (
-    WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_layer,
-    Conv1D
+    CONFIG_NAME,
+    TF_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    Conv1D,
+    PretrainedConfig,
+    PreTrainedModel,
+    prune_layer,
 )
 
-from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
+# from .e2e_body_network import Graphormer_Body_Network
+# from .e2e_hand_network import Graphormer_Hand_Network
diff --git a/lib/pymafx/models/transformers/bert/e2e_body_network.py b/lib/pymafx/models/transformers/bert/e2e_body_network.py
index 9d1c75e276aa18fa1e8f2d865cbef7a275f71b8c..6e00cf8716db068ecdca3d790ecacbdb9e518edc 100644
--- a/lib/pymafx/models/transformers/bert/e2e_body_network.py
+++ b/lib/pymafx/models/transformers/bert/e2e_body_network.py
@@ -4,8 +4,8 @@ Licensed under the MIT license.
 
 """
 
-import torch
 import src.modeling.data.config as cfg
+import torch
 
 
 class Graphormer_Body_Network(torch.nn.Module):
diff --git a/lib/pymafx/models/transformers/bert/e2e_hand_network.py b/lib/pymafx/models/transformers/bert/e2e_hand_network.py
index 410968c4abc63e1ae8281b2e0297c8eef4e7bbcf..be88fc5f32099daf9ca56eaa8d362efd5d8915b6 100644
--- a/lib/pymafx/models/transformers/bert/e2e_hand_network.py
+++ b/lib/pymafx/models/transformers/bert/e2e_hand_network.py
@@ -4,8 +4,8 @@ Licensed under the MIT license.
 
 """
 
-import torch
 import src.modeling.data.config as cfg
+import torch
 
 
 class Graphormer_Hand_Network(torch.nn.Module):
diff --git a/lib/pymafx/models/transformers/bert/file_utils.py b/lib/pymafx/models/transformers/bert/file_utils.py
index ee58bed427f90be254caee9a0733d81ae92c8711..2159d8ea87e216039104054f70a159f8c534179b 100644
--- a/lib/pymafx/models/transformers/bert/file_utils.py
+++ b/lib/pymafx/models/transformers/bert/file_utils.py
@@ -3,15 +3,20 @@ Utilities for working with the local dataset cache.
 This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
 Copyright by the AllenNLP authors.
 """
-from __future__ import (absolute_import, division, print_function, unicode_literals)
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
-import sys
+import fnmatch
 import json
 import logging
 import os
 import shutil
+import sys
 import tempfile
-import fnmatch
 from functools import wraps
 from hashlib import sha256
 from io import open
diff --git a/lib/pymafx/models/transformers/bert/modeling_bert.py b/lib/pymafx/models/transformers/bert/modeling_bert.py
index c4a7f27f1bc0e69d87ac3747b8d8acfafb03b4b8..c922f5229d9b85f4eeecade118264e18e6e56753 100644
--- a/lib/pymafx/models/transformers/bert/modeling_bert.py
+++ b/lib/pymafx/models/transformers/bert/modeling_bert.py
@@ -15,7 +15,12 @@
 # limitations under the License.
 """PyTorch BERT model. """
 
-from __future__ import absolute_import, division, print_function, unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
 import json
 import logging
@@ -29,68 +34,72 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from .modeling_utils import (
-    WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer,
-    add_start_docstrings
+    CONFIG_NAME,
+    WEIGHTS_NAME,
+    PretrainedConfig,
+    PreTrainedModel,
+    add_start_docstrings,
+    prune_linear_layer,
 )
 
 logger = logging.getLogger(__name__)
 
 BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-uncased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
     'bert-large-uncased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
     'bert-base-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
     'bert-large-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
     'bert-base-multilingual-uncased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
     'bert-base-multilingual-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
     'bert-base-chinese':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
     'bert-base-german-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
     'bert-large-uncased-whole-word-masking':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
     'bert-large-cased-whole-word-masking':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
     'bert-large-uncased-whole-word-masking-finetuned-squad':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
     'bert-large-cased-whole-word-masking-finetuned-squad':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
     'bert-base-cased-finetuned-mrpc':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
 }
 
 BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'bert-base-uncased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
     'bert-large-uncased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
     'bert-base-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
     'bert-large-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
     'bert-base-multilingual-uncased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
     'bert-base-multilingual-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
     'bert-base-chinese':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
     'bert-base-german-cased':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
     'bert-large-uncased-whole-word-masking':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
     'bert-large-cased-whole-word-masking':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
     'bert-large-uncased-whole-word-masking-finetuned-squad':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
     'bert-large-cased-whole-word-masking-finetuned-squad':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
     'bert-base-cased-finetuned-mrpc':
-        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
+    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
 }
 
 
@@ -99,6 +108,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     """
     try:
         import re
+
         import numpy as np
         import tensorflow as tf
     except ImportError:
diff --git a/lib/pymafx/models/transformers/bert/modeling_graphormer.py b/lib/pymafx/models/transformers/bert/modeling_graphormer.py
index e318af8a45d34148e0db68f42181f692afbf8754..315e2e6fe0b3f4b4a35add9dd9fe3d4e2b2a53c4 100644
--- a/lib/pymafx/models/transformers/bert/modeling_graphormer.py
+++ b/lib/pymafx/models/transformers/bert/modeling_graphormer.py
@@ -4,15 +4,29 @@ Licensed under the MIT license.
 
 """
 
-from __future__ import absolute_import, division, print_function, unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
+import code
 import logging
 import math
 import os
-import code
+
 import torch
 from torch import nn
-from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
+
+from .modeling_bert import (
+    BertEmbeddings,
+    BertIntermediate,
+    BertOutput,
+    BertPooler,
+    BertPreTrainedModel,
+    BertSelfOutput,
+)
 # import src.modeling.data.config as cfg
 # from src.modeling._gcnn import GraphConvolution, GraphResBlock
 from .modeling_utils import prune_linear_layer
@@ -180,9 +194,9 @@ class GraphormerEncoder(nn.Module):
         super(GraphormerEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layer = nn.ModuleList(
-            [GraphormerLayer(config) for _ in range(config.num_hidden_layers)]
-        )
+        self.layer = nn.ModuleList([
+            GraphormerLayer(config) for _ in range(config.num_hidden_layers)
+        ])
 
     def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None):
         all_hidden_states = ()
diff --git a/lib/pymafx/models/transformers/bert/modeling_utils.py b/lib/pymafx/models/transformers/bert/modeling_utils.py
index 40a0915822c8e736de8ac2466c075e6cc5ef7e83..56cb54703a5ca9b068933a92440fce1c278ea01e 100644
--- a/lib/pymafx/models/transformers/bert/modeling_utils.py
+++ b/lib/pymafx/models/transformers/bert/modeling_utils.py
@@ -15,7 +15,12 @@
 # limitations under the License.
 """PyTorch BERT model."""
 
-from __future__ import (absolute_import, division, print_function, unicode_literals)
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
 import copy
 import json
@@ -552,9 +557,8 @@ class PreTrainedModel(nn.Module):
 
         if output_loading_info:
             loading_info = {
-                "missing_keys": missing_keys,
-                "unexpected_keys": unexpected_keys,
-                "error_msgs": error_msgs
+                "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs":
+                error_msgs
             }
             return model, loading_info
 
@@ -893,9 +897,8 @@ class SequenceSummary(nn.Module):
                 )
             else:
                 token_ids = token_ids.unsqueeze(-1).unsqueeze(-1)
-                token_ids = token_ids.expand(
-                    (-1, ) * (token_ids.dim() - 1) + (hidden_states.size(-1), )
-                )
+                token_ids = token_ids.expand((-1, ) * (token_ids.dim() - 1) +
+                                             (hidden_states.size(-1), ))
             # shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
             output = hidden_states.gather(-2,
                                           token_ids).squeeze(-2)    # shape (bsz, XX, hidden_size)
diff --git a/lib/pymafx/models/transformers/net_utils.py b/lib/pymafx/models/transformers/net_utils.py
index 52782911e276705ec0dd908ce9676430c0a58d72..ff59e7ebdf0756306352ccd1ebc2827b4e032c90 100644
--- a/lib/pymafx/models/transformers/net_utils.py
+++ b/lib/pymafx/models/transformers/net_utils.py
@@ -1,6 +1,7 @@
-import torch.nn as nn
-import torch
 import math
+
+import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 
diff --git a/lib/pymafx/models/transformers/texformer.py b/lib/pymafx/models/transformers/texformer.py
index 4266b24ed6839f91ce5ca819cc3750143387d48f..aada36cd9fad64a8ff47743da40155ed6b012b38 100644
--- a/lib/pymafx/models/transformers/texformer.py
+++ b/lib/pymafx/models/transformers/texformer.py
@@ -1,5 +1,12 @@
 import torch.nn as nn
-from .net_utils import single_conv, double_conv, double_conv_down, double_conv_up, PosEnSine
+
+from .net_utils import (
+    PosEnSine,
+    double_conv,
+    double_conv_down,
+    double_conv_up,
+    single_conv,
+)
 from .transformer_basics import OurMultiheadAttention
 
 
@@ -86,14 +93,12 @@ class Texformer(nn.Module):
         self.unet_k = Unet(src_ch, self.feat_dim, self.feat_dim)
         self.unet_v = Unet(v_ch, self.feat_dim, self.feat_dim)
 
-        self.trans_dec = nn.ModuleList(
-            [
-                None, None, None,
-                TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'),
-                TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'),
-                TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct')
-            ]
-        )
+        self.trans_dec = nn.ModuleList([
+            None, None, None,
+            TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'),
+            TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'),
+            TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct')
+        ])
 
         self.conv0 = double_conv(self.feat_dim, self.feat_dim)
         self.conv1 = double_conv_down(self.feat_dim, self.feat_dim)
diff --git a/lib/pymafx/models/transformers/transformer_basics.py b/lib/pymafx/models/transformers/transformer_basics.py
index 144ccd76b7e2f73189634ab551691c4262781b9d..f2c9d9533926b88d0308c18560b9cf327e9d317a 100644
--- a/lib/pymafx/models/transformers/transformer_basics.py
+++ b/lib/pymafx/models/transformers/transformer_basics.py
@@ -1,6 +1,13 @@
 import torch.nn as nn
-from .net_utils import PosEnSine, softmax_attention, dotproduct_attention, long_range_attention, \
-                                   short_range_attention, patch_attention
+
+from .net_utils import (
+    PosEnSine,
+    dotproduct_attention,
+    long_range_attention,
+    patch_attention,
+    short_range_attention,
+    softmax_attention,
+)
 
 
 class OurMultiheadAttention(nn.Module):
diff --git a/lib/pymafx/utils/blob.py b/lib/pymafx/utils/blob.py
index 00123338e18a3fa74a6c3cb730cac9fb41b59ac5..11814bbec48887f622d11a786ab25271f98d5450 100644
--- a/lib/pymafx/utils/blob.py
+++ b/lib/pymafx/utils/blob.py
@@ -22,16 +22,17 @@
 # --------------------------------------------------------
 """blob helper functions."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
-from six.moves import cPickle as pickle
-import numpy as np
 import cv2
-
+import numpy as np
 from models.core.config import cfg
+from six.moves import cPickle as pickle
 
 
 def get_image_blob(im, target_scale, target_max_size):
diff --git a/lib/pymafx/utils/cam_params.py b/lib/pymafx/utils/cam_params.py
index 1f6c1a8d89b2c80d72c90c841d02425df77aa4a5..f8138bede445a95571e0b179f4f8515a7a2cb672 100644
--- a/lib/pymafx/utils/cam_params.py
+++ b/lib/pymafx/utils/cam_params.py
@@ -15,10 +15,11 @@
 # Contact: ps-license@tuebingen.mpg.de
 
 import os
-from numpy.testing._private.utils import print_assert_equal
-import torch
-import numpy as np
+
 import joblib
+import numpy as np
+import torch
+from numpy.testing._private.utils import print_assert_equal
 
 from .geometry import batch_euler2matrix
 
diff --git a/lib/pymafx/utils/collections.py b/lib/pymafx/utils/collections.py
index edd20a8c89d5d2221dc9d35948eda12c6304ba29..b7288875cd49a06d8305f95a104e0309907d9ec0 100644
--- a/lib/pymafx/utils/collections.py
+++ b/lib/pymafx/utils/collections.py
@@ -14,10 +14,12 @@
 ##############################################################################
 """A simple attribute dictionary used for representing configuration options."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
 
 class AttrDict(dict):
diff --git a/lib/pymafx/utils/colormap.py b/lib/pymafx/utils/colormap.py
index 44ef28c050021a6f03d088e9437de0c4adeb5ee5..1b275cf20b8c307564a2469d77426fe3ac5d2996 100644
--- a/lib/pymafx/utils/colormap.py
+++ b/lib/pymafx/utils/colormap.py
@@ -14,39 +14,38 @@
 ##############################################################################
 """An awesome colormap for really neat visualizations."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
 import numpy as np
 
 
 def colormap(rgb=False):
-    color_list = np.array(
-        [
-            0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, 0.184, 0.556,
-            0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, 0.184, 0.300, 0.300, 0.300,
-            0.600, 0.600, 0.600, 1.000, 0.000, 0.000, 1.000, 0.500, 0.000, 0.749, 0.749, 0.000,
-            0.000, 1.000, 0.000, 0.000, 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000,
-            0.333, 0.667, 0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
-            0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000, 1.000, 0.000,
-            0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000, 0.500, 0.333, 0.000, 0.500,
-            0.333, 0.333, 0.500, 0.333, 0.667, 0.500, 0.333, 1.000, 0.500, 0.667, 0.000, 0.500,
-            0.667, 0.333, 0.500, 0.667, 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500,
-            1.000, 0.333, 0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
-            0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333, 0.333, 1.000,
-            0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, 1.000, 0.667, 0.333, 1.000,
-            0.667, 0.667, 1.000, 0.667, 1.000, 1.000, 1.000, 0.000, 1.000, 1.000, 0.333, 1.000,
-            1.000, 0.667, 1.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000,
-            0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
-            0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000,
-            0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500,
-            0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000,
-            0.143, 0.143, 0.143, 0.286, 0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571,
-            0.714, 0.714, 0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000
-        ]
-    ).astype(np.float32)
+    color_list = np.array([
+        0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, 0.184, 0.556, 0.466,
+        0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, 0.184, 0.300, 0.300, 0.300, 0.600, 0.600,
+        0.600, 1.000, 0.000, 0.000, 1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000,
+        0.000, 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667, 0.000, 0.333,
+        1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000, 0.667, 1.000, 0.000, 1.000, 0.333,
+        0.000, 1.000, 0.667, 0.000, 1.000, 1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500,
+        0.000, 1.000, 0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500, 0.333,
+        1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667, 0.667, 0.500, 0.667, 1.000,
+        0.500, 1.000, 0.000, 0.500, 1.000, 0.333, 0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500,
+        0.000, 0.333, 1.000, 0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333,
+        0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, 1.000, 0.667, 0.333,
+        1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000, 1.000, 0.000, 1.000, 1.000, 0.333, 1.000,
+        1.000, 0.667, 1.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667,
+        0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, 0.333,
+        0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000,
+        0.000, 0.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000,
+        0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286, 0.286,
+        0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714, 0.714, 0.857, 0.857, 0.857,
+        1.000, 1.000, 1.000
+    ]).astype(np.float32)
     color_list = color_list.reshape((-1, 3)) * 255
     if not rgb:
         color_list = color_list[:, ::-1]
diff --git a/lib/pymafx/utils/common.py b/lib/pymafx/utils/common.py
index f3330ea18c4783ccacb21657808b8b8ce2301f86..bdcb5d0e3b18ab60f88bcab94de75922513e6195 100755
--- a/lib/pymafx/utils/common.py
+++ b/lib/pymafx/utils/common.py
@@ -1,7 +1,9 @@
-import torch
-import numpy as np
 import logging
 from copy import deepcopy
+
+import numpy as np
+import torch
+
 from .utils.libkdtree import KDTree
 
 logger_py = logging.getLogger(__name__)
@@ -170,14 +172,12 @@ def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, eps=
         ray_direction.unsqueeze(-2)
 
     # Calculate mask where points intersect unit cube
-    p_mask_inside_cube = (
-        (p_intersect[:, :, :, 0] <= p_distance + eps) &
-        (p_intersect[:, :, :, 1] <= p_distance + eps) &
-        (p_intersect[:, :, :, 2] <= p_distance + eps) &
-        (p_intersect[:, :, :, 0] >= -(p_distance + eps)) &
-        (p_intersect[:, :, :, 1] >= -(p_distance + eps)) &
-        (p_intersect[:, :, :, 2] >= -(p_distance + eps))
-    ).cpu()
+    p_mask_inside_cube = ((p_intersect[:, :, :, 0] <= p_distance + eps) &
+                          (p_intersect[:, :, :, 1] <= p_distance + eps) &
+                          (p_intersect[:, :, :, 2] <= p_distance + eps) &
+                          (p_intersect[:, :, :, 0] >= -(p_distance + eps)) &
+                          (p_intersect[:, :, :, 1] >= -(p_distance + eps)) &
+                          (p_intersect[:, :, :, 2] >= -(p_distance + eps))).cpu()
 
     # Correct rays are these which intersect exactly 2 times
     mask_inside_cube = p_mask_inside_cube.sum(-1) == 2
@@ -190,13 +190,11 @@ def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, eps=
     # Calculate ray lengths for the interval points
     d_intervals_batch = torch.zeros(batch_size, n_pts, 2).to(device)
     norm_ray = torch.norm(ray_direction[mask_inside_cube], dim=-1)
-    d_intervals_batch[mask_inside_cube] = torch.stack(
-        [
-            torch.norm(p_intervals[:, 0] - ray0[mask_inside_cube], dim=-1) / norm_ray,
-            torch.norm(p_intervals[:, 1] - ray0[mask_inside_cube], dim=-1) / norm_ray,
-        ],
-        dim=-1
-    )
+    d_intervals_batch[mask_inside_cube] = torch.stack([
+        torch.norm(p_intervals[:, 0] - ray0[mask_inside_cube], dim=-1) / norm_ray,
+        torch.norm(p_intervals[:, 1] - ray0[mask_inside_cube], dim=-1) / norm_ray,
+    ],
+                                                      dim=-1)
 
     # Sort the ray lengths
     d_intervals_batch, indices_sort = d_intervals_batch.sort()
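The reflowed hunks above touch `check_ray_intersection_with_unit_cube`, which keeps only rays whose candidate intersection points lie inside all three face slabs of a padded unit cube and then converts the two surviving hit points into ray lengths. A single-ray NumPy sketch of the same idea, under the assumption that the padded half-extent is `(1 + padding) / 2` (the helper name and that constant are illustrative, not taken from the repo):

```python
import numpy as np

def ray_unit_cube_hits(ray0, ray_dir, padding=0.1, eps=1e-6):
    """Return the two ray lengths at which a ray enters/exits a padded unit cube."""
    half = (1.0 + padding) / 2.0          # assumed half-extent of the padded cube
    hits = []
    for axis in range(3):                 # intersect the two planes per axis
        for plane in (-half, half):
            if abs(ray_dir[axis]) < eps:
                continue
            t = (plane - ray0[axis]) / ray_dir[axis]
            p = ray0 + t * ray_dir
            if np.all(np.abs(p) <= half + eps):   # hit must lie on the cube face
                hits.append(p)
    if len(hits) != 2:                    # valid rays hit the cube exactly twice
        return None
    norm = np.linalg.norm(ray_dir)
    return sorted(np.linalg.norm(h - ray0) / norm for h in hits)   # [d_enter, d_exit]

# Example: a ray starting outside the cube, pointing at the origin
print(ray_unit_cube_hits(np.array([2.0, 0.0, 0.0]), np.array([-1.0, 0.0, 0.0])))
```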
diff --git a/lib/pymafx/utils/data_loader.py b/lib/pymafx/utils/data_loader.py
index cc92ad223836e9de322bc80bbab887bb9ec3f17b..3d109f82b3473242a9fb9442037c47471fd0f7d2 100644
--- a/lib/pymafx/utils/data_loader.py
+++ b/lib/pymafx/utils/data_loader.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import Sampler
diff --git a/lib/pymafx/utils/demo_utils.py b/lib/pymafx/utils/demo_utils.py
index b1ad8da91c7a7f6f67d4770c9866a02a78aa5275..437a2b42cae65f51dc357ffffef0cc63d6aa58b9 100644
--- a/lib/pymafx/utils/demo_utils.py
+++ b/lib/pymafx/utils/demo_utils.py
@@ -14,20 +14,20 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import os
-import cv2
-import time
 import json
-import torch
-import subprocess
-import numpy as np
+import os
 import os.path as osp
+import subprocess
+import time
 # from pytube import YouTube
 from collections import OrderedDict
 
-from utils.smooth_bbox import get_smooth_bbox_params, get_all_bbox_params
+import cv2
+import numpy as np
+import torch
 from datasets.data_utils.img_utils import get_single_image_crop_demo
 from utils.geometry import rotation_matrix_to_angle_axis
+from utils.smooth_bbox import get_all_bbox_params, get_smooth_bbox_params
 
 
 def preprocess_video(video, joints2d, bboxes, frames, scale=1.0, crop_size=224):
@@ -112,9 +112,10 @@ def smplify_runner(
         pred_pose = pred_rotmat
 
     # Calculate camera parameters for smplify
-    pred_cam_t = torch.stack(
-        [pred_cam[:, 1], pred_cam[:, 2], 2 * 5000 / (224 * pred_cam[:, 0] + 1e-9)], dim=-1
-    )
+    pred_cam_t = torch.stack([
+        pred_cam[:, 1], pred_cam[:, 2], 2 * 5000 / (224 * pred_cam[:, 0] + 1e-9)
+    ],
+                             dim=-1)
 
     gt_keypoints_2d_orig = j2d
     # Before running compute reprojection error of the network
@@ -285,14 +286,11 @@ def prepare_rendering_results(results_dict, nframes):
     for person_id, person_data in results_dict.items():
         for idx, frame_id in enumerate(person_data['frame_ids']):
             frame_results[frame_id][person_id] = {
-                'verts':
-                    person_data['verts'][idx],
+                'verts': person_data['verts'][idx],
                 'smplx_verts':
-                    person_data['smplx_verts'][idx] if 'smplx_verts' in person_data else None,
-                'cam':
-                    person_data['orig_cam'][idx],
-                'cam_t':
-                    person_data['orig_cam_t'][idx] if 'orig_cam_t' in person_data else None,
+                person_data['smplx_verts'][idx] if 'smplx_verts' in person_data else None,
+                'cam': person_data['orig_cam'][idx],
+                'cam_t': person_data['orig_cam_t'][idx] if 'orig_cam_t' in person_data else None,
             # 'cam': person_data['pred_cam'][idx],
             }
 
@@ -300,9 +298,9 @@ def prepare_rendering_results(results_dict, nframes):
     for frame_id, frame_data in enumerate(frame_results):
         # sort based on y-scale of the cam in original image coords
         sort_idx = np.argsort([v['cam'][1] for k, v in frame_data.items()])
-        frame_results[frame_id] = OrderedDict(
-            {list(frame_data.keys())[i]: frame_data[list(frame_data.keys())[i]]
-             for i in sort_idx}
-        )
+        frame_results[frame_id] = OrderedDict({
+            list(frame_data.keys())[i]: frame_data[list(frame_data.keys())[i]]
+            for i in sort_idx
+        })
 
     return frame_results
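Besides the import reordering, the reflowed `pred_cam_t` stack above is the usual weak-perspective-to-translation conversion, with the fixed 5000 px focal length and 224 px crop that appear literally in the hunk. A standalone sketch of that formula:

```python
import torch

def weak_perspective_to_translation(pred_cam, focal_length=5000.0, crop_size=224.0):
    """Convert (s, tx, ty) weak-perspective camera params to a 3D camera translation.

    Mirrors the stacked expression above: tz is chosen so that an object filling a
    `crop_size`-pixel crop under `focal_length` matches the predicted scale s.
    """
    s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
    tz = 2.0 * focal_length / (crop_size * s + 1e-9)
    return torch.stack([tx, ty, tz], dim=-1)

cam = torch.tensor([[0.9, 0.02, -0.01]])
print(weak_perspective_to_translation(cam))   # tensor([[ 0.0200, -0.0100, 49.6032]])
```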
diff --git a/lib/pymafx/utils/densepose_methods.py b/lib/pymafx/utils/densepose_methods.py
index 93fdf66a6651dcfe05f6e95c55379eaa00c52cb0..7e02eafd746a4a152cbb3d3569ed3fdaf673f7fe 100644
--- a/lib/pymafx/utils/densepose_methods.py
+++ b/lib/pymafx/utils/densepose_methods.py
@@ -4,12 +4,13 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import numpy as np
 import copy
+import os
+
 import cv2
-from scipy.io import loadmat
+import numpy as np
 import scipy.spatial.distance
-import os
+from scipy.io import loadmat
 
 
 class DensePoseMethods:
@@ -102,24 +103,18 @@ class DensePoseMethods:
         FaceIndicesNow = np.where(self.FaceIndices == I_point)
         FacesNow = self.FacesDensePose[FaceIndicesNow]
         #
-        P_0 = np.vstack(
-            (
-                self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0],
-                np.zeros(self.U_norm[FacesNow][:, 0].shape)
-            )
-        ).transpose()
-        P_1 = np.vstack(
-            (
-                self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1],
-                np.zeros(self.U_norm[FacesNow][:, 1].shape)
-            )
-        ).transpose()
-        P_2 = np.vstack(
-            (
-                self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2],
-                np.zeros(self.U_norm[FacesNow][:, 2].shape)
-            )
-        ).transpose()
+        P_0 = np.vstack((
+            self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0],
+            np.zeros(self.U_norm[FacesNow][:, 0].shape)
+        )).transpose()
+        P_1 = np.vstack((
+            self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1],
+            np.zeros(self.U_norm[FacesNow][:, 1].shape)
+        )).transpose()
+        P_2 = np.vstack((
+            self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2],
+            np.zeros(self.U_norm[FacesNow][:, 2].shape)
+        )).transpose()
         #
 
         for i, [P0, P1, P2] in enumerate(zip(P_0, P_1, P_2)):
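The `P_0/P_1/P_2` stacking above collects triangle vertices in UV space (with a zero z column) so each candidate face can be tested for containing the query point via barycentric coordinates. A minimal barycentric sketch for one triangle (pure NumPy; the function name is illustrative):

```python
import numpy as np

def barycentric_coords(p, p0, p1, p2):
    """Barycentric coordinates (u, v, w) of point p w.r.t. triangle (p0, p1, p2)."""
    v0, v1, v2 = p1 - p0, p2 - p0, p - p0
    d00, d01, d11 = v0 @ v0, v0 @ v1, v1 @ v1
    d20, d21 = v2 @ v0, v2 @ v1
    denom = d00 * d11 - d01 * d01
    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1.0 - v - w
    return u, v, w        # p lies inside the triangle iff all three are in [0, 1]

p0, p1, p2 = np.array([0.0, 0.0]), np.array([1.0, 0.0]), np.array([0.0, 1.0])
print(barycentric_coords(np.array([0.25, 0.25]), p0, p1, p2))   # (0.5, 0.25, 0.25)
```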
diff --git a/lib/pymafx/utils/geometry.py b/lib/pymafx/utils/geometry.py
index 608288fc4d73a4918ab95938a7bf5dbe98ce606f..5c2d2c1f8b22147344afc549ae74630bea56d3ef 100644
--- a/lib/pymafx/utils/geometry.py
+++ b/lib/pymafx/utils/geometry.py
@@ -1,8 +1,10 @@
-import torch
-from torch.nn import functional as F
-import numpy as np
 import numbers
+
+import numpy as np
+import torch
 from einops.einops import rearrange
+from torch.nn import functional as F
+
 """
 Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
 Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
@@ -43,13 +45,11 @@ def quat_to_rotmat(quat):
     wx, wy, wz = w * x, w * y, w * z
     xy, xz, yz = x * y, x * z, y * z
 
-    rotMat = torch.stack(
-        [
-            w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
-            2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
-        ],
-        dim=1
-    ).view(B, 3, 3)
+    rotMat = torch.stack([
+        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
+        2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
+    ],
+                         dim=1).view(B, 3, 3)
     return rotMat
 
 
@@ -225,39 +225,31 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
     mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
 
     t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
-    q0 = torch.stack(
-        [
-            rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
-            rmat_t[:, 2, 0] + rmat_t[:, 0, 2]
-        ], -1
-    )
+    q0 = torch.stack([
+        rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+        rmat_t[:, 2, 0] + rmat_t[:, 0, 2]
+    ], -1)
     t0_rep = t0.repeat(4, 1).t()
 
     t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
-    q1 = torch.stack(
-        [
-            rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1,
-            rmat_t[:, 1, 2] + rmat_t[:, 2, 1]
-        ], -1
-    )
+    q1 = torch.stack([
+        rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1,
+        rmat_t[:, 1, 2] + rmat_t[:, 2, 1]
+    ], -1)
     t1_rep = t1.repeat(4, 1).t()
 
     t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
-    q2 = torch.stack(
-        [
-            rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
-            rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2
-        ], -1
-    )
+    q2 = torch.stack([
+        rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+        rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2
+    ], -1)
     t2_rep = t2.repeat(4, 1).t()
 
     t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
-    q3 = torch.stack(
-        [
-            t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
-            rmat_t[:, 0, 1] - rmat_t[:, 1, 0]
-        ], -1
-    )
+    q3 = torch.stack([
+        t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+        rmat_t[:, 0, 1] - rmat_t[:, 1, 0]
+    ], -1)
     t3_rep = t3.repeat(4, 1).t()
 
     mask_c0 = mask_d2 * mask_d0_d1
@@ -321,13 +313,11 @@ def quaternion_to_rotation_matrix(quat):
     wx, wy, wz = w * x, w * y, w * z
     xy, xz, yz = x * y, x * z, y * z
 
-    rotMat = torch.stack(
-        [
-            w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
-            2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
-        ],
-        dim=1
-    ).view(B, 3, 3)
+    rotMat = torch.stack([
+        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
+        2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
+    ],
+                         dim=1).view(B, 3, 3)
     return rotMat
 
 
@@ -405,9 +395,10 @@ def projection(pred_joints, pred_camera, retain_z=False, iwp_mode=True):
     batch_size = pred_joints.shape[0]
     if iwp_mode:
         cam_sxy = pred_camera['cam_sxy']
-        pred_cam_t = torch.stack(
-            [cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)], dim=-1
-        )
+        pred_cam_t = torch.stack([
+            cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)
+        ],
+                                 dim=-1)
 
         camera_center = torch.zeros(batch_size, 2)
         pred_keypoints_2d = perspective_projection(
@@ -537,12 +528,10 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si
     weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
 
     # least squares
-    Q = np.array(
-        [
-            F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
-            O - np.reshape(joints_2d, -1)
-        ]
-    ).T
+    Q = np.array([
+        F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
+        O - np.reshape(joints_2d, -1)
+    ]).T
     c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
 
     # weighted least squares
@@ -609,10 +598,8 @@ def Rot_y(angle, category='torch', prepend_dim=True, device=None):
 		prepend_dim: prepend an extra dimension
 	Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
 	'''
-    m = np.array(
-        [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], [-np.sin(angle), 0.,
-                                                            np.cos(angle)]]
-    )
+    m = np.array([[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
+                  [-np.sin(angle), 0., np.cos(angle)]])
     if category == 'torch':
         if prepend_dim:
             return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -634,10 +621,8 @@ def Rot_x(angle, category='torch', prepend_dim=True, device=None):
 		prepend_dim: prepend an extra dimension
 	Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
 	'''
-    m = np.array(
-        [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], [0., np.sin(angle),
-                                                             np.cos(angle)]]
-    )
+    m = np.array([[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)],
+                  [0., np.sin(angle), np.cos(angle)]])
     if category == 'torch':
         if prepend_dim:
             return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -659,9 +644,8 @@ def Rot_z(angle, category='torch', prepend_dim=True, device=None):
 		prepend_dim: prepend an extra dimension
 	Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
 	'''
-    m = np.array(
-        [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]]
-    )
+    m = np.array([[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle),
+                                                        np.cos(angle), 0.], [0., 0., 1.]])
     if category == 'torch':
         if prepend_dim:
             return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
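The reflowed `quat_to_rotmat` / `quaternion_to_rotation_matrix` hunks above both build the rotation matrix directly from quaternion products. The same expression, written out for a single quaternion in NumPy with the (w, x, y, z) ordering used in the code:

```python
import numpy as np

def quat_to_rotmat_np(q):
    """Rotation matrix from a quaternion (w, x, y, z); the quaternion is normalized first."""
    w, x, y, z = q / np.linalg.norm(q)
    w2, x2, y2, z2 = w * w, x * x, y * y, z * z
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z
    return np.array([
        [w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz],
        [2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx],
        [2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2],
    ])

# 90 degree rotation about z: quaternion (cos 45, 0, 0, sin 45)
print(np.round(quat_to_rotmat_np(np.array([np.sqrt(0.5), 0.0, 0.0, np.sqrt(0.5)])), 3))
```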
diff --git a/lib/pymafx/utils/imutils.py b/lib/pymafx/utils/imutils.py
index b3522fee118cf47c5101bfd8e16991e5c30f58ad..f3df8bf7ccf43a2f30074f52ccbc025017af66cf 100644
--- a/lib/pymafx/utils/imutils.py
+++ b/lib/pymafx/utils/imutils.py
@@ -1,10 +1,10 @@
 """
 This file contains functions that are used to perform data augmentation.
 """
-import torch
-import numpy as np
 import cv2
+import numpy as np
 import skimage.transform
+import torch
 from PIL import Image
 
 from lib.pymafx.core import constants
@@ -129,12 +129,8 @@ def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
 def rot_aa(aa, rot):
     """Rotate axis angle parameters."""
     # pose parameters
-    R = np.array(
-        [
-            [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
-            [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], [0, 0, 1]
-        ]
-    )
+    R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
+                  [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], [0, 0, 1]])
     # find the rotation of the body in camera frame
     per_rdg, _ = cv2.Rodrigues(aa)
     # apply the global rotation to the global orientation
@@ -246,9 +242,9 @@ def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
     target_weight = np.ones((num_joints, 1), dtype=np.float32)
     if joints_vis is not None:
         target_weight[:, 0] = joints_vis[:, 0]
-    target = torch.zeros(
-        (num_joints, heatmap_size[1], heatmap_size[0]), dtype=torch.float32, device=cur_device
-    )
+    target = torch.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
+                         dtype=torch.float32,
+                         device=cur_device)
 
     tmp_size = sigma * 3
 
diff --git a/lib/pymafx/utils/io.py b/lib/pymafx/utils/io.py
index 0926624ddeb1eccf2e9c6393595acfd34a62e84d..67d5b50542ef8f831ac59d46947631ae8f2cc78e 100644
--- a/lib/pymafx/utils/io.py
+++ b/lib/pymafx/utils/io.py
@@ -14,17 +14,21 @@
 ##############################################################################
 """IO utilities."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
-from six.moves import cPickle as pickle
 import hashlib
 import logging
 import os
 import re
 import sys
+
+from six.moves import cPickle as pickle
+
 try:
     from urllib.request import urlopen
 except ImportError:    #python2
diff --git a/lib/pymafx/utils/iuvmap.py b/lib/pymafx/utils/iuvmap.py
index 7f7c25398e04e30b2b244d44badc83415d583852..8c1914dac03f553f7651fac74ca3bd1b45b65645 100644
--- a/lib/pymafx/utils/iuvmap.py
+++ b/lib/pymafx/utils/iuvmap.py
@@ -115,10 +115,8 @@ def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24):
     batch_size = uvimages.size(0)
     uvimg_size = uvimages.size(-1)
 
-    Index2mask = [
-        [0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17], [16, 18],
-        [19, 21], [20, 22], [23, 24]
-    ]
+    Index2mask = [[0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17],
+                  [16, 18], [19, 21], [20, 22], [23, 24]]
 
     part_ind = torch.round(uvimages[:, 0, :, :] * n_part)
     part_u = uvimages[:, 1, :, :]
diff --git a/lib/pymafx/utils/keypoints.py b/lib/pymafx/utils/keypoints.py
index 2ab223c2bef79518adc523da1606cfc331ef8251..a4e52d6a51dd24624be07a469fef62e1d0130995 100644
--- a/lib/pymafx/utils/keypoints.py
+++ b/lib/pymafx/utils/keypoints.py
@@ -14,16 +14,18 @@
 ##############################################################################
 """Keypoint utilities (somewhat specific to COCO keypoints)."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
 import cv2
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.cuda.comm
+import torch.nn.functional as F
 
 # from core.config import cfg
 # import utils.blob as blob_utils
@@ -39,14 +41,9 @@ def get_keypoints():
         'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
     ]
     keypoint_flip_map = {
-        'left_eye': 'right_eye',
-        'left_ear': 'right_ear',
-        'left_shoulder': 'right_shoulder',
-        'left_elbow': 'right_elbow',
-        'left_wrist': 'right_wrist',
-        'left_hip': 'right_hip',
-        'left_knee': 'right_knee',
-        'left_ankle': 'right_ankle'
+        'left_eye': 'right_eye', 'left_ear': 'right_ear', 'left_shoulder': 'right_shoulder',
+        'left_elbow': 'right_elbow', 'left_wrist': 'right_wrist', 'left_hip': 'right_hip',
+        'left_knee': 'right_knee', 'left_ankle': 'right_ankle'
     }
     return keypoints, keypoint_flip_map
 
@@ -232,9 +229,9 @@ def compute_oks(src_keypoints, src_roi, dst_keypoints, dst_roi):
     dst_roi: Nx4
     """
 
-    sigmas = np.array(
-        [.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]
-    ) / 10.0
+    sigmas = np.array([
+        .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89
+    ]) / 10.0
     vars = (sigmas * 2)**2
 
     # area
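The `sigmas` array being reflowed above feeds COCO's object-keypoint-similarity (OKS) metric. For reference, a minimal single-instance OKS sketch using the same per-keypoint constants (variable names are illustrative, not the repo's):

```python
import numpy as np

COCO_SIGMAS = np.array([
    .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89
]) / 10.0

def oks(src_kps, dst_kps, area, vis=None):
    """COCO-style OKS between two 17x2 keypoint sets for one object of given area."""
    k2 = (2 * COCO_SIGMAS) ** 2                        # per-keypoint falloff constants
    d2 = np.sum((src_kps - dst_kps) ** 2, axis=1)      # squared pixel distances
    sim = np.exp(-d2 / (2.0 * k2 * (area + np.spacing(1))))
    if vis is not None:                                # average only over visible keypoints
        sim = sim[vis > 0]
    return sim.mean()

gt = np.random.rand(17, 2) * 100
print(oks(gt, gt + 1.0, area=50 * 100))                # close to 1 for small offsets
```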
diff --git a/lib/pymafx/utils/mesh_generation.py b/lib/pymafx/utils/mesh_generation.py
index 2876209e7678d2906a84850208f6c288103d07c5..3442d16a79e7594a17c1a3ca7ae195ec5c463b4c 100644
--- a/lib/pymafx/utils/mesh_generation.py
+++ b/lib/pymafx/utils/mesh_generation.py
@@ -1,16 +1,16 @@
 import time
-import torch
-import trimesh
+
 import numpy as np
+import torch
 import torch.optim as optim
+import trimesh
 from torch import autograd
-from torch.utils.data import TensorDataset, DataLoader
+from torch.utils.data import DataLoader, TensorDataset
 
-from .common import make_3d_grid
+from .common import make_3d_grid, transform_pointcloud
 from .utils import libmcubes
 from .utils.libmise import MISE
 from .utils.libsimplify import simplify_mesh
-from .common import transform_pointcloud
 
 
 class Generator3D(object):
@@ -286,9 +286,8 @@ class Generator3D(object):
         colors = np.concatenate(colors, axis=0)
         colors = np.clip(colors, 0, 1)
         colors = (colors * 255).astype(np.uint8)
-        colors = np.concatenate(
-            [colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], axis=1
-        )
+        colors = np.concatenate([colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)],
+                                axis=1)
         return colors
 
     def estimate_normals(self, vertices, c=None):
@@ -375,13 +374,11 @@ class Generator3D(object):
                 face_normal = face_normal / \
                     (face_normal.norm(dim=1, keepdim=True) + 1e-10)
 
-                face_value = torch.cat(
-                    [
-                        torch.sigmoid(self.model.decode(p_split, c).logits)
-                        for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1)
-                    ],
-                    dim=1
-                )
+                face_value = torch.cat([
+                    torch.sigmoid(self.model.decode(p_split, c).logits)
+                    for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1)
+                ],
+                                       dim=1)
 
                 normal_target = -autograd.grad([face_value.sum()], [face_point],
                                                create_graph=True)[0]
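The `face_value` hunk above evaluates the occupancy decoder in 20 000-point chunks with `torch.split` and concatenates the results to bound peak memory. The same pattern in isolation, with a stand-in network in place of the repo's decoder (the decoder here is an assumption, not the project's API):

```python
import torch

def eval_in_chunks(fn, points, chunk=20000):
    """Evaluate fn on (1, N, 3) points in chunks along dim=1 to limit peak memory."""
    outs = [fn(p) for p in torch.split(points, chunk, dim=1)]
    return torch.cat(outs, dim=1)

# Stand-in "decoder": distance-to-origin turned into a pseudo-occupancy value.
decoder = lambda p: torch.sigmoid(1.0 - p.norm(dim=-1))

pts = torch.rand(1, 50000, 3)
print(eval_in_chunks(decoder, pts).shape)   # torch.Size([1, 50000])
```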
diff --git a/lib/pymafx/utils/part_utils.py b/lib/pymafx/utils/part_utils.py
index 12f0de443fa11e90674761816a644cf82a48a786..5b17763404418e9b97d738af4c497be2d1918cde 100644
--- a/lib/pymafx/utils/part_utils.py
+++ b/lib/pymafx/utils/part_utils.py
@@ -1,8 +1,7 @@
-import torch
-import numpy as np
 import neural_renderer as nr
+import numpy as np
+import torch
 from core import path_config
-
 from models import SMPL
 
 
@@ -42,13 +41,11 @@ class PartRenderer():
     def __call__(self, vertices, camera):
         """Wrapper function for rendering process."""
         # Estimate camera parameters given a fixed focal length
-        cam_t = torch.stack(
-            [
-                camera[:, 1], camera[:, 2], 2 * self.focal_length /
-                (self.render_res * camera[:, 0] + 1e-9)
-            ],
-            dim=-1
-        )
+        cam_t = torch.stack([
+            camera[:, 1], camera[:, 2], 2 * self.focal_length /
+            (self.render_res * camera[:, 0] + 1e-9)
+        ],
+                            dim=-1)
         batch_size = vertices.shape[0]
         K = torch.eye(3, device=vertices.device)
         K[0, 0] = self.focal_length
diff --git a/lib/pymafx/utils/pose_tracker.py b/lib/pymafx/utils/pose_tracker.py
index 92c383cdb3dba6053a0595b9f03305c02e9fc277..be6008b31580d600533921ed722a89568c3a36fd 100644
--- a/lib/pymafx/utils/pose_tracker.py
+++ b/lib/pymafx/utils/pose_tracker.py
@@ -14,12 +14,13 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-import os
 import json
+import os
+import os.path as osp
 import shutil
 import subprocess
+
 import numpy as np
-import os.path as osp
 
 
 def run_openpose(
diff --git a/lib/pymafx/utils/pose_utils.py b/lib/pymafx/utils/pose_utils.py
index 55eb1d771376da71c864a715d1dd6b5d66e9894e..966eb51850b6fa793b1e11d4dcad1eed2c6698da 100644
--- a/lib/pymafx/utils/pose_utils.py
+++ b/lib/pymafx/utils/pose_utils.py
@@ -1,9 +1,8 @@
 """
 Parts of the code are adapted from https://github.com/akanazawa/hmr
 """
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
+
 import numpy as np
 import torch
 
diff --git a/lib/pymafx/utils/renderer.py b/lib/pymafx/utils/renderer.py
index 9fb19568680b839f93c00a5288c94a5a52025242..5b201badc2c6da232d0d753c9303b7c62361ca7f 100644
--- a/lib/pymafx/utils/renderer.py
+++ b/lib/pymafx/utils/renderer.py
@@ -1,45 +1,59 @@
 import imp
+import json
 import os
 from pickle import NONE
+
+import numpy as np
 # os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
 import torch
+import torch.nn.functional as F
 import trimesh
-import numpy as np
+from core import constants, path_config
+from models.smpl import get_model_faces, get_model_tpose, get_smpl_faces
 # import neural_renderer as nr
 from skimage.transform import resize
 from torchvision.utils import make_grid
-import torch.nn.functional as F
-
-from models.smpl import get_smpl_faces, get_model_faces, get_model_tpose
 from utils.densepose_methods import DensePoseMethods
-from core import constants, path_config
-import json
-from .geometry import convert_to_full_img_cam
 from utils.imutils import crop
 
+from .geometry import convert_to_full_img_cam
+
 try:
     import math
+
     import pyrender
     from pyrender.constants import RenderFlags
 except:
     pass
 try:
-    from opendr.renderer import ColoredRenderer
-    from opendr.lighting import LambertianPointLight, SphericalHarmonics
     from opendr.camera import ProjectPoints
+    from opendr.lighting import LambertianPointLight, SphericalHarmonics
+    from opendr.renderer import ColoredRenderer
 except:
     pass
 
-from pytorch3d.structures.meshes import Meshes
-# from pytorch3d.renderer.mesh.renderer import MeshRendererWithFragments
+import logging
 
 from pytorch3d.renderer import (
-    look_at_view_transform, FoVPerspectiveCameras, PerspectiveCameras, AmbientLights, PointLights,
-    RasterizationSettings, BlendParams, MeshRenderer, MeshRasterizer, SoftPhongShader,
-    SoftSilhouetteShader, HardPhongShader, HardGouraudShader, HardFlatShader, TexturesVertex
+    AmbientLights,
+    BlendParams,
+    FoVPerspectiveCameras,
+    HardFlatShader,
+    HardGouraudShader,
+    HardPhongShader,
+    MeshRasterizer,
+    MeshRenderer,
+    PerspectiveCameras,
+    PointLights,
+    RasterizationSettings,
+    SoftPhongShader,
+    SoftSilhouetteShader,
+    TexturesVertex,
+    look_at_view_transform,
 )
+from pytorch3d.structures.meshes import Meshes
 
-import logging
+# from pytorch3d.renderer.mesh.renderer import MeshRendererWithFragments
 
 logger = logging.getLogger(__name__)
 
@@ -172,15 +186,15 @@ class PyRenderer:
             if len(cam) == 4:
                 sx, sy, tx, ty = cam
                 # sy = sx
-                camera_translation = np.array(
-                    [tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
-                )
+                camera_translation = np.array([
+                    tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)
+                ])
             elif len(cam) == 3:
                 sx, tx, ty = cam
                 sy = sx
-                camera_translation = np.array(
-                    [-tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
-                )
+                camera_translation = np.array([
+                    -tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)
+                ])
             render_res = resolution
             self.renderer.viewport_width = render_res[1]
             self.renderer.viewport_height = render_res[0]
@@ -283,12 +297,8 @@ class OpenDRenderer:
         self.resolution = (resolution[0] * ratio, resolution[1] * ratio)
         self.ratio = ratio
         self.focal_length = 5000.
-        self.K = np.array(
-            [
-                [self.focal_length, 0., self.resolution[1] / 2.],
-                [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.]
-            ]
-        )
+        self.K = np.array([[self.focal_length, 0., self.resolution[1] / 2.],
+                           [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.]])
         self.colors_dict = {
             'red': np.array([0.5, 0.2, 0.2]),
             'pink': np.array([0.7, 0.5, 0.5]),
@@ -303,12 +313,8 @@ class OpenDRenderer:
 
     def reset_res(self, resolution):
         self.resolution = (resolution[0] * self.ratio, resolution[1] * self.ratio)
-        self.K = np.array(
-            [
-                [self.focal_length, 0., self.resolution[1] / 2.],
-                [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.]
-            ]
-        )
+        self.K = np.array([[self.focal_length, 0., self.resolution[1] / 2.],
+                           [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.]])
 
     def __call__(
         self,
@@ -446,27 +452,22 @@ class OpenDRenderer:
 #  https://github.com/classner/up/blob/master/up_tools/camera.py
 def rotateY(points, angle):
     """Rotate all points in a 2D array around the y axis."""
-    ry = np.array(
-        [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], [-np.sin(angle), 0.,
-                                                            np.cos(angle)]]
-    )
+    ry = np.array([[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
+                   [-np.sin(angle), 0., np.cos(angle)]])
     return np.dot(points, ry)
 
 
 def rotateX(points, angle):
     """Rotate all points in a 2D array around the x axis."""
-    rx = np.array(
-        [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], [0., np.sin(angle),
-                                                             np.cos(angle)]]
-    )
+    rx = np.array([[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)],
+                   [0., np.sin(angle), np.cos(angle)]])
     return np.dot(points, rx)
 
 
 def rotateZ(points, angle):
     """Rotate all points in a 2D array around the z axis."""
-    rz = np.array(
-        [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]]
-    )
+    rz = np.array([[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle),
+                                                         np.cos(angle), 0.], [0., 0., 1.]])
     return np.dot(points, rz)
 
 
@@ -514,12 +515,8 @@ class IUV_Renderer(object):
                                 break
                     np.save(dp_vert_pid_fname, np.array(dp_vert_pid))
 
-                textures_vts = np.array(
-                    [
-                        (dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i])
-                        for i in range(len(vert_mapping))
-                    ]
-                )
+                textures_vts = np.array([(dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i])
+                                         for i in range(len(vert_mapping))])
                 self.textures_vts = torch.from_numpy(
                     textures_vts[None].astype(np.float32)
                 )    # (1, 7829, 3)
@@ -569,12 +566,8 @@ class IUV_Renderer(object):
             #     range(n_verts)])
             self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32))
 
-        K = np.array(
-            [
-                [self.focal_length, 0., self.orig_size / 2.],
-                [0., self.focal_length, self.orig_size / 2.], [0., 0., 1.]
-            ]
-        )
+        K = np.array([[self.focal_length, 0., self.orig_size / 2.],
+                      [0., self.focal_length, self.orig_size / 2.], [0., 0., 1.]])
 
         R = np.array([[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]])
 
@@ -620,10 +613,10 @@ class IUV_Renderer(object):
 
         K = self.K.repeat(batch_size, 1, 1)
         R = self.R.repeat(batch_size, 1, 1)
-        t = torch.stack(
-            [-cam[:, 1], -cam[:, 2], 2 * self.focal_length / (self.orig_size * cam[:, 0] + 1e-9)],
-            dim=-1
-        )
+        t = torch.stack([
+            -cam[:, 1], -cam[:, 2], 2 * self.focal_length / (self.orig_size * cam[:, 0] + 1e-9)
+        ],
+                        dim=-1)
 
         if cam.is_cuda:
             # device_id = cam.get_device()
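The renderer hunks above repeatedly rebuild the same pinhole intrinsics `K` from a fixed focal length and the render resolution, and recover the camera translation `t` from the weak-perspective `cam` vector. A small NumPy sketch of projecting camera-frame points with such a `K` (a single view; helper names are illustrative):

```python
import numpy as np

def pinhole_K(focal_length, height, width):
    """Pinhole intrinsics with the principal point at the image center, as in the hunks above."""
    return np.array([
        [focal_length, 0.0, width / 2.0],
        [0.0, focal_length, height / 2.0],
        [0.0, 0.0, 1.0],
    ])

def project(points, K, t):
    """Project Nx3 camera-frame points translated by t to pixel coordinates."""
    p = (points + t) @ K.T
    return p[:, :2] / p[:, 2:3]

K = pinhole_K(5000.0, 512, 512)
pts = np.array([[0.1, -0.05, 0.0]])
print(project(pts, K, t=np.array([0.0, 0.0, 50.0])))   # ~[[266., 251.]]
```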
diff --git a/lib/pymafx/utils/sample_mesh.py b/lib/pymafx/utils/sample_mesh.py
index 2599bee12d2577b6826ea8bfad8c937f2bcc2db2..1a696cf9580d7cd0e27ac0ade1db3c9e6a9c6e84 100644
--- a/lib/pymafx/utils/sample_mesh.py
+++ b/lib/pymafx/utils/sample_mesh.py
@@ -1,6 +1,8 @@
 import os
-import trimesh
+
 import numpy as np
+import trimesh
+
 from .utils.libmesh import check_mesh_contains
 
 
diff --git a/lib/pymafx/utils/saver.py b/lib/pymafx/utils/saver.py
index 6a6bd3a184cc658dbc666ad2dcf3bc15d8cc427b..faed475e6bc4a8f1e2e3cd16d81b267d9bbb8496 100644
--- a/lib/pymafx/utils/saver.py
+++ b/lib/pymafx/utils/saver.py
@@ -1,8 +1,10 @@
 from __future__ import division
-import os
-import torch
+
 import datetime
 import logging
+import os
+
+import torch
 
 logger = logging.getLogger(__name__)
 
@@ -102,10 +104,8 @@ class CheckpointSaver():
             if optimizer in checkpoint:
                 optimizers[optimizer].load_state_dict(checkpoint[optimizer])
         return {
-            'epoch': checkpoint['epoch'],
-            'batch_idx': checkpoint['batch_idx'],
-            'batch_size': checkpoint['batch_size'],
-            'total_step_count': checkpoint['total_step_count']
+            'epoch': checkpoint['epoch'], 'batch_idx': checkpoint['batch_idx'], 'batch_size':
+            checkpoint['batch_size'], 'total_step_count': checkpoint['total_step_count']
         }
 
     def get_latest_checkpoint(self):
diff --git a/lib/pymafx/utils/segms.py b/lib/pymafx/utils/segms.py
index 44c617529d67323a8664c3e00872e5db091b8be6..c8fbf7e2c49422449cf4a8c4a38e1f320a0b15c0 100644
--- a/lib/pymafx/utils/segms.py
+++ b/lib/pymafx/utils/segms.py
@@ -21,13 +21,14 @@ The following terms are used in this module
     RLE: COCO's run length encoding format
 """
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
 import numpy as np
-
 import pycocotools.mask as mask_util
 
 
diff --git a/lib/pymafx/utils/smooth_bbox.py b/lib/pymafx/utils/smooth_bbox.py
index 4393320e7f50128d6838d99c76b5d0f8f45f6efc..a5eb8e431840f974b2cd896d80e8368755f9a682 100644
--- a/lib/pymafx/utils/smooth_bbox.py
+++ b/lib/pymafx/utils/smooth_bbox.py
@@ -93,12 +93,10 @@ def get_all_bbox_params(kps, vis_thresh=2):
             # Linearly interpolate each param.
             previous = bbox_params[-1]
             # This will be 3x(n+2)
-            interpolated = np.array(
-                [
-                    np.linspace(prev, curr, num_to_interpolate + 2)
-                    for prev, curr in zip(previous, bbox_param)
-                ]
-            )
+            interpolated = np.array([
+                np.linspace(prev, curr, num_to_interpolate + 2)
+                for prev, curr in zip(previous, bbox_param)
+            ])
             bbox_params = np.vstack((bbox_params, interpolated.T[1:-1]))
             num_to_interpolate = 0
         bbox_params = np.vstack((bbox_params, bbox_param))
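The reflowed `interpolated` array above fills frames with too few visible keypoints by linearly interpolating each bbox parameter between the surrounding valid frames. A one-gap sketch of the same `np.linspace` trick (function name is illustrative):

```python
import numpy as np

def interpolate_gap(prev_params, next_params, num_missing):
    """Linearly interpolate each bbox parameter across `num_missing` skipped frames."""
    # linspace includes both endpoints, so drop them with [1:-1] as the hunk does
    interp = np.array([
        np.linspace(p, n, num_missing + 2) for p, n in zip(prev_params, next_params)
    ])
    return interp.T[1:-1]          # shape: (num_missing, n_params)

print(interpolate_gap(np.array([0.0, 10.0, 10.0]), np.array([4.0, 14.0, 18.0]), 3))
```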
diff --git a/lib/pymafx/utils/transforms.py b/lib/pymafx/utils/transforms.py
index 25534674631d40b8b263b242d05339443b169dcb..5f4189ee0e2da45e565b322d207b011ae3ed70f5 100644
--- a/lib/pymafx/utils/transforms.py
+++ b/lib/pymafx/utils/transforms.py
@@ -4,12 +4,10 @@
 # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 # ------------------------------------------------------------------------------
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
-import numpy as np
 import cv2
+import numpy as np
 
 
 def flip_back(output_flipped, matched_parts):
diff --git a/lib/pymafx/utils/uv_vis.py b/lib/pymafx/utils/uv_vis.py
index 86fdd33ddee774c2bbe02478b2d74f53f8522256..2bac9e86cdeb7d0091382914188755e7ecbf541a 100644
--- a/lib/pymafx/utils/uv_vis.py
+++ b/lib/pymafx/utils/uv_vis.py
@@ -1,10 +1,11 @@
 import os
-import torch
+
+# Use a non-interactive backend
+import matplotlib
 import numpy as np
+import torch
 import torch.nn.functional as F
 from skimage.transform import resize
-# Use a non-interactive backend
-import matplotlib
 
 matplotlib.use('Agg')
 
@@ -100,9 +101,8 @@ def vis_smpl_iuv(
     for draw_i in range(len(cam_pred)):
         err_val = '{:06d}_'.format(int(10 * vert_errors_batch[draw_i]))
         draw_name = err_val + image_name[draw_i]
-        K = np.array(
-            [[focal_length, 0., orig_size / 2.], [0., focal_length, orig_size / 2.], [0., 0., 1.]]
-        )
+        K = np.array([[focal_length, 0., orig_size / 2.], [0., focal_length, orig_size / 2.],
+                      [0., 0., 1.]])
 
         # img_orig, img_resized, img_smpl, render_smpl_rgba = dr_render(
         #     image[draw_i],
diff --git a/lib/pymafx/utils/vis.py b/lib/pymafx/utils/vis.py
index 5273707c05f66275150e7cb2d86f44dcf4c92223..f64a490e60ded94ee8cffa51dae26f2ffa3a6ead 100644
--- a/lib/pymafx/utils/vis.py
+++ b/lib/pymafx/utils/vis.py
@@ -17,24 +17,26 @@
 # limitations under the License.
 ##############################################################################
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
+
+import math
+import os
 
 import cv2
+# Use a non-interactive backend
+import matplotlib
 import numpy as np
-import os
 import pycocotools.mask as mask_util
-import math
 import torchvision
 
 from .colormap import colormap
-from .keypoints import get_keypoints
 from .imutils import normalize_2d_kp
-
-# Use a non-interactive backend
-import matplotlib
+from .keypoints import get_keypoints
 
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
@@ -191,15 +193,13 @@ def vis_one_image(
         print(dataset.classes[classes[i]], score)
         # show box (off by default, box_alpha=0.0)
         ax.add_patch(
-            plt.Rectangle(
-                (bbox[0], bbox[1]),
-                bbox[2] - bbox[0],
-                bbox[3] - bbox[1],
-                fill=False,
-                edgecolor='g',
-                linewidth=0.5,
-                alpha=box_alpha
-            )
+            plt.Rectangle((bbox[0], bbox[1]),
+                          bbox[2] - bbox[0],
+                          bbox[3] - bbox[1],
+                          fill=False,
+                          edgecolor='g',
+                          linewidth=0.5,
+                          alpha=box_alpha)
         )
 
         if show_class:
diff --git a/lib/smplx/__init__.py b/lib/smplx/__init__.py
index 886949df670691d1ef5995737cafa285224826c4..f29a068748ad6f0cd5e246a7678aeda15991c1ea 100644
--- a/lib/smplx/__init__.py
+++ b/lib/smplx/__init__.py
@@ -15,16 +15,16 @@
 # Contact: ps-license@tuebingen.mpg.de
 
 from .body_models import (
-    create,
+    FLAME,
+    MANO,
     SMPL,
     SMPLH,
     SMPLX,
-    MANO,
-    FLAME,
-    build_layer,
-    SMPLLayer,
+    FLAMELayer,
+    MANOLayer,
     SMPLHLayer,
+    SMPLLayer,
     SMPLXLayer,
-    MANOLayer,
-    FLAMELayer,
+    build_layer,
+    create,
 )
diff --git a/lib/smplx/body_models.py b/lib/smplx/body_models.py
index b98adb635a3f9296c102e2bb6ca93bcdb14ab57d..ad55315dc5aa23422352bcca1d17b5f177d55e5d 100644
--- a/lib/smplx/body_models.py
+++ b/lib/smplx/body_models.py
@@ -14,36 +14,34 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from typing import Optional, Dict, Union
+import logging
 import os
 import os.path as osp
 import pickle
+from collections import namedtuple
+from typing import Dict, Optional, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
-from collections import namedtuple
-
-import logging
 
 logging.getLogger("smplx").setLevel(logging.ERROR)
 
-from .lbs import lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords
-
-from .vertex_ids import vertex_ids as VERTEX_IDS
+from .lbs import find_dynamic_lmk_idx_and_bcoords, lbs, vertices2landmarks
 from .utils import (
-    Struct,
-    to_np,
-    to_tensor,
-    Tensor,
     Array,
-    SMPLOutput,
+    FLAMEOutput,
+    MANOOutput,
     SMPLHOutput,
+    SMPLOutput,
     SMPLXOutput,
-    MANOOutput,
-    FLAMEOutput,
+    Struct,
+    Tensor,
     find_joint_kin_chain,
+    to_np,
+    to_tensor,
 )
+from .vertex_ids import vertex_ids as VERTEX_IDS
 from .vertex_joint_selector import VertexJointSelector
 
 ModelOutput = namedtuple(
@@ -1110,9 +1108,8 @@ class SMPLX(SMPLH):
 
         if create_expression:
             if expression is None:
-                default_expression = torch.zeros(
-                    [batch_size, self.num_expression_coeffs], dtype=dtype
-                )
+                default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
+                                                 dtype=dtype)
             else:
                 default_expression = torch.tensor(expression, dtype=dtype)
             expression_param = nn.Parameter(default_expression, requires_grad=True)
@@ -1337,9 +1334,9 @@ class SMPLX(SMPLH):
             dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
 
             lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
-            lmk_bary_coords = torch.cat(
-                [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
-            )
+            lmk_bary_coords = torch.cat([
+                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
+            ], 1)
 
         landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
 
@@ -1513,9 +1510,9 @@ class SMPLXLayer(SMPLX):
                           dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
             )
         if expression is None:
-            expression = torch.zeros(
-                [batch_size, self.num_expression_coeffs], dtype=dtype, device=device
-            )
+            expression = torch.zeros([batch_size, self.num_expression_coeffs],
+                                     dtype=dtype,
+                                     device=device)
         if betas is None:
             betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
         if transl is None:
@@ -1564,9 +1561,9 @@ class SMPLXLayer(SMPLX):
             dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
 
             lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
-            lmk_bary_coords = torch.cat(
-                [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
-            )
+            lmk_bary_coords = torch.cat([
+                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
+            ], 1)
 
         landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
 
@@ -2044,9 +2041,8 @@ class FLAME(SMPL):
 
         if create_expression:
             if expression is None:
-                default_expression = torch.zeros(
-                    [batch_size, self.num_expression_coeffs], dtype=dtype
-                )
+                default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
+                                                 dtype=dtype)
             else:
                 default_expression = torch.tensor(expression, dtype=dtype)
             expression_param = nn.Parameter(default_expression, requires_grad=True)
@@ -2202,9 +2198,9 @@ class FLAME(SMPL):
             )
             dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
             lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
-            lmk_bary_coords = torch.cat(
-                [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
-            )
+            lmk_bary_coords = torch.cat([
+                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
+            ], 1)
 
         landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
 
@@ -2331,9 +2327,9 @@ class FLAMELayer(FLAME):
         if betas is None:
             betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
         if expression is None:
-            expression = torch.zeros(
-                [batch_size, self.num_expression_coeffs], dtype=dtype, device=device
-            )
+            expression = torch.zeros([batch_size, self.num_expression_coeffs],
+                                     dtype=dtype,
+                                     device=device)
         if transl is None:
             transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
 
@@ -2367,9 +2363,9 @@ class FLAMELayer(FLAME):
             )
             dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
             lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
-            lmk_bary_coords = torch.cat(
-                [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1
-            )
+            lmk_bary_coords = torch.cat([
+                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
+            ], 1)
 
         landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
 
diff --git a/lib/smplx/lbs.py b/lib/smplx/lbs.py
index ac64f4b41be569331d632bfeb50fef9c50dc3d71..13862e837723d71ed70ecf7b68dd41e84ebd772c 100644
--- a/lib/smplx/lbs.py
+++ b/lib/smplx/lbs.py
@@ -14,17 +14,15 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function
 
-from typing import Tuple, List, Optional
-import numpy as np
+from typing import List, Optional, Tuple
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 
-from .utils import rot_mat_to_euler, Tensor
+from .utils import Tensor, rot_mat_to_euler
 
 
 def find_dynamic_lmk_idx_and_bcoords(
diff --git a/lib/smplx/utils.py b/lib/smplx/utils.py
index d43a25217573f4c327adbf0411a76d1081632a69..6e2a2ff45f6597c2f8a0acf0820a3eb7dac9cfb2 100644
--- a/lib/smplx/utils.py
+++ b/lib/smplx/utils.py
@@ -14,8 +14,9 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from typing import NewType, Union, Optional
-from dataclasses import dataclass, asdict, fields
+from dataclasses import asdict, dataclass, fields
+from typing import NewType, Optional, Union
+
 import numpy as np
 import torch
 
diff --git a/lib/smplx/vertex_ids.py b/lib/smplx/vertex_ids.py
index 31ed146ed4b3529bfbe0c92450bd3b02559f338b..060ed8ed60117a33358944abb85891abb3de8e30 100644
--- a/lib/smplx/vertex_ids.py
+++ b/lib/smplx/vertex_ids.py
@@ -14,61 +14,57 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from __future__ import print_function
-from __future__ import absolute_import
-from __future__ import division
+from __future__ import absolute_import, division, print_function
 
 # Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to
 # MSCOCO and OpenPose joints
 vertex_ids = {
-    "smplh":
-        {
-            "nose": 332,
-            "reye": 6260,
-            "leye": 2800,
-            "rear": 4071,
-            "lear": 583,
-            "rthumb": 6191,
-            "rindex": 5782,
-            "rmiddle": 5905,
-            "rring": 6016,
-            "rpinky": 6133,
-            "lthumb": 2746,
-            "lindex": 2319,
-            "lmiddle": 2445,
-            "lring": 2556,
-            "lpinky": 2673,
-            "LBigToe": 3216,
-            "LSmallToe": 3226,
-            "LHeel": 3387,
-            "RBigToe": 6617,
-            "RSmallToe": 6624,
-            "RHeel": 6787,
-        },
-    "smplx":
-        {
-            "nose": 9120,
-            "reye": 9929,
-            "leye": 9448,
-            "rear": 616,
-            "lear": 6,
-            "rthumb": 8079,
-            "rindex": 7669,
-            "rmiddle": 7794,
-            "rring": 7905,
-            "rpinky": 8022,
-            "lthumb": 5361,
-            "lindex": 4933,
-            "lmiddle": 5058,
-            "lring": 5169,
-            "lpinky": 5286,
-            "LBigToe": 5770,
-            "LSmallToe": 5780,
-            "LHeel": 8846,
-            "RBigToe": 8463,
-            "RSmallToe": 8474,
-            "RHeel": 8635,
-        },
+    "smplh": {
+        "nose": 332,
+        "reye": 6260,
+        "leye": 2800,
+        "rear": 4071,
+        "lear": 583,
+        "rthumb": 6191,
+        "rindex": 5782,
+        "rmiddle": 5905,
+        "rring": 6016,
+        "rpinky": 6133,
+        "lthumb": 2746,
+        "lindex": 2319,
+        "lmiddle": 2445,
+        "lring": 2556,
+        "lpinky": 2673,
+        "LBigToe": 3216,
+        "LSmallToe": 3226,
+        "LHeel": 3387,
+        "RBigToe": 6617,
+        "RSmallToe": 6624,
+        "RHeel": 6787,
+    },
+    "smplx": {
+        "nose": 9120,
+        "reye": 9929,
+        "leye": 9448,
+        "rear": 616,
+        "lear": 6,
+        "rthumb": 8079,
+        "rindex": 7669,
+        "rmiddle": 7794,
+        "rring": 7905,
+        "rpinky": 8022,
+        "lthumb": 5361,
+        "lindex": 4933,
+        "lmiddle": 5058,
+        "lring": 5169,
+        "lpinky": 5286,
+        "LBigToe": 5770,
+        "LSmallToe": 5780,
+        "LHeel": 8846,
+        "RBigToe": 8463,
+        "RSmallToe": 8474,
+        "RHeel": 8635,
+    },
     "mano": {
         "thumb": 744,
         "index": 320,
diff --git a/lib/smplx/vertex_joint_selector.py b/lib/smplx/vertex_joint_selector.py
index 1680e07acb03402a54fc0621ab36ec1d4de2c78e..dfee954435df6a3e674eb0c553476caf5d2a019a 100644
--- a/lib/smplx/vertex_joint_selector.py
+++ b/lib/smplx/vertex_joint_selector.py
@@ -14,12 +14,9 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function
 
 import numpy as np
-
 import torch
 import torch.nn as nn
 
diff --git a/lib/torch_utils/custom_ops.py b/lib/torch_utils/custom_ops.py
index 2170f4732aba52f614b7cec09ac62465275ad90b..c76fc0e6a9c41c9e5b8e861079865eae41189226 100644
--- a/lib/torch_utils/custom_ops.py
+++ b/lib/torch_utils/custom_ops.py
@@ -6,15 +6,15 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-import os
 import glob
-import torch
-import torch.utils.cpp_extension
-import importlib
 import hashlib
+import importlib
+import os
 import shutil
 from pathlib import Path
 
+import torch
+import torch.utils.cpp_extension
 from torch.utils.file_baton import FileBaton
 
 #----------------------------------------------------------------------------
diff --git a/lib/torch_utils/misc.py b/lib/torch_utils/misc.py
index 61c266a84d83e9a486df52e725af1c51488951e4..4946f0cc5fd29bb20bec9db27d0285c35878ec43 100644
--- a/lib/torch_utils/misc.py
+++ b/lib/torch_utils/misc.py
@@ -6,12 +6,13 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-import re
 import contextlib
-import numpy as np
-import torch
+import re
 import warnings
+
 import dnnlib
+import numpy as np
+import torch
 
 #----------------------------------------------------------------------------
 # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
@@ -272,15 +273,13 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
         buffer_size = sum(t.numel() for t in e.unique_buffers)
         output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
         output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
-        rows += [
-            [
-                name + (':0' if len(e.outputs) >= 2 else ''),
-                str(param_size) if param_size else '-',
-                str(buffer_size) if buffer_size else '-',
-                (output_shapes + ['-'])[0],
-                (output_dtypes + ['-'])[0],
-            ]
-        ]
+        rows += [[
+            name + (':0' if len(e.outputs) >= 2 else ''),
+            str(param_size) if param_size else '-',
+            str(buffer_size) if buffer_size else '-',
+            (output_shapes + ['-'])[0],
+            (output_dtypes + ['-'])[0],
+        ]]
         for idx in range(1, len(e.outputs)):
             rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
         param_total += param_size
diff --git a/lib/torch_utils/ops/bias_act.py b/lib/torch_utils/ops/bias_act.py
index d8cfdb65d25ed077827862bc70e860c450fe929a..81d07ac029a2c36c12bcc6caff59a73476bdaf1e 100644
--- a/lib/torch_utils/ops/bias_act.py
+++ b/lib/torch_utils/ops/bias_act.py
@@ -8,94 +8,94 @@
 """Custom PyTorch ops for efficient bias and activation."""
 
 import os
+import traceback
 import warnings
+
+import dnnlib
 import numpy as np
 import torch
-import dnnlib
-import traceback
 
-from .. import custom_ops
-from .. import misc
+from .. import custom_ops, misc
 
 #----------------------------------------------------------------------------
 
 activation_funcs = {
     'linear':
-        dnnlib.EasyDict(
-            func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False
+    ),
     'relu':
-        dnnlib.EasyDict(
-            func=lambda x, **_: torch.nn.functional.relu(x),
-            def_alpha=0,
-            def_gain=np.sqrt(2),
-            cuda_idx=2,
-            ref='y',
-            has_2nd_grad=False
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: torch.nn.functional.relu(x),
+        def_alpha=0,
+        def_gain=np.sqrt(2),
+        cuda_idx=2,
+        ref='y',
+        has_2nd_grad=False
+    ),
     'lrelu':
-        dnnlib.EasyDict(
-            func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
-            def_alpha=0.2,
-            def_gain=np.sqrt(2),
-            cuda_idx=3,
-            ref='y',
-            has_2nd_grad=False
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
+        def_alpha=0.2,
+        def_gain=np.sqrt(2),
+        cuda_idx=3,
+        ref='y',
+        has_2nd_grad=False
+    ),
     'tanh':
-        dnnlib.EasyDict(
-            func=lambda x, **_: torch.tanh(x),
-            def_alpha=0,
-            def_gain=1,
-            cuda_idx=4,
-            ref='y',
-            has_2nd_grad=True
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: torch.tanh(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=4,
+        ref='y',
+        has_2nd_grad=True
+    ),
     'sigmoid':
-        dnnlib.EasyDict(
-            func=lambda x, **_: torch.sigmoid(x),
-            def_alpha=0,
-            def_gain=1,
-            cuda_idx=5,
-            ref='y',
-            has_2nd_grad=True
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: torch.sigmoid(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=5,
+        ref='y',
+        has_2nd_grad=True
+    ),
     'elu':
-        dnnlib.EasyDict(
-            func=lambda x, **_: torch.nn.functional.elu(x),
-            def_alpha=0,
-            def_gain=1,
-            cuda_idx=6,
-            ref='y',
-            has_2nd_grad=True
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: torch.nn.functional.elu(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=6,
+        ref='y',
+        has_2nd_grad=True
+    ),
     'selu':
-        dnnlib.EasyDict(
-            func=lambda x, **_: torch.nn.functional.selu(x),
-            def_alpha=0,
-            def_gain=1,
-            cuda_idx=7,
-            ref='y',
-            has_2nd_grad=True
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: torch.nn.functional.selu(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=7,
+        ref='y',
+        has_2nd_grad=True
+    ),
     'softplus':
-        dnnlib.EasyDict(
-            func=lambda x, **_: torch.nn.functional.softplus(x),
-            def_alpha=0,
-            def_gain=1,
-            cuda_idx=8,
-            ref='y',
-            has_2nd_grad=True
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: torch.nn.functional.softplus(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=8,
+        ref='y',
+        has_2nd_grad=True
+    ),
     'swish':
-        dnnlib.EasyDict(
-            func=lambda x, **_: torch.sigmoid(x) * x,
-            def_alpha=0,
-            def_gain=np.sqrt(2),
-            cuda_idx=9,
-            ref='x',
-            has_2nd_grad=True
-        ),
+    dnnlib.EasyDict(
+        func=lambda x, **_: torch.sigmoid(x) * x,
+        def_alpha=0,
+        def_gain=np.sqrt(2),
+        cuda_idx=9,
+        ref='x',
+        has_2nd_grad=True
+    ),
 }
 
 #----------------------------------------------------------------------------
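The `activation_funcs` table being re-indented above stores, per activation, a reference callable plus default alpha/gain; the slow fallback of `bias_act` amounts to "add bias, apply func, multiply by gain". A minimal sketch of that fallback using the same fields, with a small standalone dict rather than an import of the repo's table:

```python
import numpy as np
import torch

# Two entries mirroring the spec fields seen above (func, def_alpha, def_gain).
SPECS = {
    'relu': dict(func=lambda x, **_: torch.nn.functional.relu(x),
                 def_alpha=0, def_gain=np.sqrt(2)),
    'lrelu': dict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
                  def_alpha=0.2, def_gain=np.sqrt(2)),
}

def bias_act_ref(x, b=None, act='lrelu', alpha=None, gain=None):
    """Reference bias + activation + gain, following the per-activation defaults."""
    spec = SPECS[act]
    alpha = spec['def_alpha'] if alpha is None else alpha
    gain = spec['def_gain'] if gain is None else gain
    if b is not None:
        x = x + b.reshape(1, -1, *([1] * (x.ndim - 2)))   # broadcast bias over channels
    return spec['func'](x, alpha=alpha) * gain

x = torch.randn(4, 8, 16, 16)
print(bias_act_ref(x, b=torch.zeros(8)).shape)   # torch.Size([4, 8, 16, 16])
```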
diff --git a/lib/torch_utils/ops/conv2d_gradfix.py b/lib/torch_utils/ops/conv2d_gradfix.py
index 29c3d8f5a8a1e2816e225af3157fc1bb99a4fd33..16bcdfdb229acf55b30a33c47ed419bd008bae7d 100644
--- a/lib/torch_utils/ops/conv2d_gradfix.py
+++ b/lib/torch_utils/ops/conv2d_gradfix.py
@@ -8,8 +8,9 @@
 """Custom replacement for `torch.nn.functional.conv2d` that supports
 arbitrarily high order gradients with zero performance penalty."""
 
-import warnings
 import contextlib
+import warnings
+
 import torch
 
 # pylint: disable=redefined-builtin
diff --git a/lib/torch_utils/ops/conv2d_resample.py b/lib/torch_utils/ops/conv2d_resample.py
index 9f347c59165d1aceafee936b36281610b5a64e1b..2529947e6f5d34f9aa65bd21ede9e0fac87190ab 100644
--- a/lib/torch_utils/ops/conv2d_resample.py
+++ b/lib/torch_utils/ops/conv2d_resample.py
@@ -10,10 +10,8 @@
 import torch
 
 from .. import misc
-from . import conv2d_gradfix
-from . import upfirdn2d
-from .upfirdn2d import _parse_padding
-from .upfirdn2d import _get_filter_size
+from . import conv2d_gradfix, upfirdn2d
+from .upfirdn2d import _get_filter_size, _parse_padding
 
 #----------------------------------------------------------------------------
 
diff --git a/lib/torch_utils/ops/grid_sample_gradfix.py b/lib/torch_utils/ops/grid_sample_gradfix.py
index 850feacd5a6300b85493cd7f713bffab1af70536..b4049485344ac9cc09e5522b992962db05900dc7 100644
--- a/lib/torch_utils/ops/grid_sample_gradfix.py
+++ b/lib/torch_utils/ops/grid_sample_gradfix.py
@@ -11,6 +11,7 @@ Only works on 2D images and assumes
 `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
 
 import warnings
+
 import torch
 
 # pylint: disable=redefined-builtin
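Since grid_sample_gradfix fixes mode='bilinear', padding_mode='zeros', and align_corners=False per the docstring above, the wrapper is presumably called with just (input, grid). A small sketch under that assumption; the grid_sample name is taken from the NVIDIA torch_utils variant of this file.

import torch
from lib.torch_utils.ops import grid_sample_gradfix

img = torch.randn(1, 3, 32, 32, requires_grad=True)
# Identity sampling grid in [-1, 1], shape [N, H, W, 2], as torch.nn.functional.grid_sample expects.
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 32), torch.linspace(-1, 1, 32), indexing='ij')
grid = torch.stack([xs, ys], dim=-1).unsqueeze(0)
out = grid_sample_gradfix.grid_sample(img, grid)    # bilinear, zero padding, align_corners=False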
diff --git a/lib/torch_utils/ops/upfirdn2d.py b/lib/torch_utils/ops/upfirdn2d.py
index 86f6fb36eb83711db42aef6b05c003eceaeeaa69..0d5fa241d9d8d56a7a4e8c3cb9553c5916fbd4d2 100644
--- a/lib/torch_utils/ops/upfirdn2d.py
+++ b/lib/torch_utils/ops/upfirdn2d.py
@@ -8,13 +8,13 @@
 """Custom PyTorch ops for efficient resampling of 2D images."""
 
 import os
+import traceback
 import warnings
+
 import numpy as np
 import torch
-import traceback
 
-from .. import custom_ops
-from .. import misc
+from .. import custom_ops, misc
 from . import conv2d_gradfix
 
 #----------------------------------------------------------------------------
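upfirdn2d is described above as efficient FIR-filtered resampling of 2D images. A hedged sketch of 2x upsampling; the setup_filter and upsample2d helpers are assumed from the NVIDIA torch_utils version of this file.

import torch
from lib.torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 16, 16)
f = upfirdn2d.setup_filter([1, 3, 3, 1])    # small separable low-pass filter
y = upfirdn2d.upsample2d(x, f, up=2)        # filtered 2x upsampling -> [1, 3, 32, 32]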
diff --git a/lib/torch_utils/persistence.py b/lib/torch_utils/persistence.py
index c3263dc0690ac12d5d2e74a6d9d8d2af2fed0f5b..34f16d48f416fc07eb6954aa810acf852db5eed5 100644
--- a/lib/torch_utils/persistence.py
+++ b/lib/torch_utils/persistence.py
@@ -12,13 +12,14 @@ during unpickling. This way, any previously exported pickles will remain
 usable even if the original code is no longer available, or if the current
 version of the code is not consistent with what was originally pickled."""
 
-import sys
-import pickle
-import io
-import inspect
 import copy
-import uuid
+import inspect
+import io
+import pickle
+import sys
 import types
+import uuid
+
 import dnnlib
 
 #----------------------------------------------------------------------------
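The persistence docstring above explains that source code is pickled alongside objects so previously exported pickles stay loadable. A hedged sketch of the intended usage; the persistent_class decorator name is assumed from the NVIDIA torch_utils version of this module.

import torch
from lib.torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Pickling a TinyNet instance now embeds the class source, so the pickle can be
# loaded later even if this definition changes or is removed.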
diff --git a/lib/torch_utils/training_stats.py b/lib/torch_utils/training_stats.py
index 11658fdbf55450f5f0d4679e247ff65a4b37151e..a4fd0a4c3687aff712547b2688225ba1ec621f47 100644
--- a/lib/torch_utils/training_stats.py
+++ b/lib/torch_utils/training_stats.py
@@ -11,8 +11,10 @@ synchronization overhead as well as the amount of boilerplate in user
 code."""
 
 import re
+
 import numpy as np
 import torch
+
 import lib.dnnlib
 
 from . import misc
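training_stats, per its docstring, collects scalar statistics across processes while keeping synchronization overhead and boilerplate low. A hedged usage sketch; the report and Collector names are assumed from the NVIDIA torch_utils version of this module.

from lib.torch_utils import training_stats

training_stats.report('Loss/example', 0.5)           # accumulate locally; no immediate sync
stats = training_stats.Collector(regex='Loss/.*')
stats.update()                                       # one explicit synchronization point
print(stats.as_dict())                               # per-name mean/std/count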
diff --git a/requirements.txt b/requirements.txt
index ad16e028ac6e07019e69d3fa8407af0bd8eca9e4..44b6403e52c59039b093c764343d02700b0b42ef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ mediapipe
 einops
 boto3
 open3d
+xatlas
 git+https://github.com/YuliangXiu/rembg.git
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..718afe479d0ace65dabe450ac2d5b3d5060315f2
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,18 @@
+[yapf]
+based_on_style = facebook
+column_limit = 100
+indent_width = 4
+spaces_before_comment = 4
+split_all_comma_separated_values = false
+split_all_top_level_comma_separated_values = false
+dedent_closing_brackets = true
+coalesce_brackets = true
+split_before_dot = false
+each_dict_entry_on_separate_line = false
+indent_dictionary_value = false
+
+[isort]
+multi_line_output = 3
+line_length = 80
+include_trailing_comma = true
+skip=./log,./results,./data,./debug,./lib/common/libmesh/setup.py,./lib/common/libvoxelize/setup.py
\ No newline at end of file
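With this setup.cfg in place, both formatters can presumably be driven from the repository root: yapf picks up the [yapf] section (e.g. "yapf --in-place --recursive ."), and isort picks up the [isort] section ("isort ."). The skip entry keeps the log, results, data, and debug output directories, as well as the vendored libmesh/libvoxelize setup scripts, out of import sorting.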