Yuliang commited on
Commit
fb140f6
·
1 Parent(s): e5f16e8

remove MeshLab dependency with Open3D

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +8 -6
  2. apps/IFGeo.py +37 -45
  3. apps/Normal.py +37 -40
  4. apps/avatarizer.py +28 -13
  5. apps/infer.py +211 -87
  6. apps/multi_render.py +1 -3
  7. configs/econ.yaml +1 -0
  8. docs/installation.md +2 -3
  9. lib/common/BNI.py +22 -22
  10. lib/common/BNI_utils.py +104 -93
  11. lib/common/blender_utils.py +29 -22
  12. lib/common/cloth_extraction.py +13 -14
  13. lib/common/config.py +7 -9
  14. lib/common/imutils.py +147 -338
  15. lib/common/libmesh/inside_mesh.py +9 -11
  16. lib/common/libmesh/setup.py +1 -4
  17. lib/common/libvoxelize/setup.py +1 -2
  18. lib/common/local_affine.py +31 -21
  19. lib/common/render.py +79 -47
  20. lib/common/render_utils.py +16 -25
  21. lib/common/seg3d_lossless.py +130 -154
  22. lib/common/seg3d_utils.py +67 -117
  23. lib/common/train_util.py +25 -445
  24. lib/common/voxelize.py +91 -80
  25. lib/dataset/Evaluator.py +29 -19
  26. lib/dataset/NormalDataset.py +45 -56
  27. lib/dataset/NormalModule.py +2 -3
  28. lib/dataset/PointFeat.py +10 -4
  29. lib/dataset/TestDataset.py +44 -16
  30. lib/dataset/body_model.py +68 -91
  31. lib/dataset/mesh_util.py +112 -416
  32. lib/net/BasePIFuNet.py +3 -4
  33. lib/net/Discriminator.py +76 -65
  34. lib/net/FBNet.py +99 -92
  35. lib/net/GANLoss.py +2 -3
  36. lib/net/IFGeoNet.py +72 -64
  37. lib/net/IFGeoNet_nobody.py +54 -44
  38. lib/net/NormalNet.py +13 -10
  39. lib/net/geometry.py +68 -56
  40. lib/net/net_util.py +22 -27
  41. lib/net/voxelize.py +21 -15
  42. lib/pixielib/models/FLAME.py +13 -15
  43. lib/pixielib/models/SMPLX.py +495 -502
  44. lib/pixielib/models/encoders.py +2 -5
  45. lib/pixielib/models/hrnet.py +108 -152
  46. lib/pixielib/models/lbs.py +25 -42
  47. lib/pixielib/models/moderators.py +2 -10
  48. lib/pixielib/models/resnet.py +12 -42
  49. lib/pixielib/pixie.py +102 -136
  50. lib/pixielib/utils/array_cropper.py +15 -20
README.md CHANGED
@@ -103,20 +103,23 @@ python -m apps.avatarizer -n {filename}
103
 
104
  ### Some adjustable parameters in _config/econ.yaml_
105
 
106
- - `use_ifnet: True`
107
- - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality)
108
- - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed)
109
  - `use_smpl: ["hand", "face"]`
110
  - [ ]: don't use either hands or face parts from SMPL-X
111
  - ["hand"]: only use the **visible** hands from SMPL-X
112
  - ["hand", "face"]: use both **visible** hands and face from SMPL-X
113
  - `thickness: 2cm`
114
  - could be increased accordingly in case final reconstruction **xx_full.obj** looks flat
 
 
115
  - `hps_type: PIXIE`
116
  - "pixie": more accurate for face and hands
117
  - "pymafx": more robust for challenging poses
118
- - `k: 4`
119
- - could be reduced accordingly in case the surface of **xx_full.obj** has discontinous artifacts
 
120
 
121
  <br/>
122
 
@@ -160,7 +163,6 @@ Here are some great resources we benefit from:
160
  - [BiNI](https://github.com/hoshino042/bilateral_normal_integration) for Bilateral Normal Integration
161
  - [MonoPortDataset](https://github.com/Project-Splinter/MonoPortDataset) for Data Processing, [MonoPort](https://github.com/Project-Splinter/MonoPort) for fast implicit surface query
162
  - [rembg](https://github.com/danielgatis/rembg) for Human Segmentation
163
- - [pypoisson](https://github.com/mmolero/pypoisson) for poisson reconstruction
164
  - [MediaPipe](https://google.github.io/mediapipe/getting_started/python.html) for full-body landmark estimation
165
  - [PyTorch-NICP](https://github.com/wuhaozhe/pytorch-nicp) for non-rigid registration
166
  - [smplx](https://github.com/vchoutas/smplx), [PyMAF-X](https://www.liuyebin.com/pymaf-x/), [PIXIE](https://github.com/YadiraF/PIXIE) for Human Pose & Shape Estimation
 
103
 
104
  ### Some adjustable parameters in _config/econ.yaml_
105
 
106
+ - `use_ifnet: False`
107
+ - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality, **~2min / img**)
108
+ - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed, **~1.5min / img**)
109
  - `use_smpl: ["hand", "face"]`
110
  - [ ]: don't use either hands or face parts from SMPL-X
111
  - ["hand"]: only use the **visible** hands from SMPL-X
112
  - ["hand", "face"]: use both **visible** hands and face from SMPL-X
113
  - `thickness: 2cm`
114
  - could be increased accordingly in case final reconstruction **xx_full.obj** looks flat
115
+ - `k: 4`
116
+ - could be reduced accordingly in case the surface of **xx_full.obj** has discontinous artifacts
117
  - `hps_type: PIXIE`
118
  - "pixie": more accurate for face and hands
119
  - "pymafx": more robust for challenging poses
120
+ - `texture_src: image`
121
+ - "image": direct mapping the aligned pixels to final mesh
122
+ - "SD": use Stable Diffusion to generate full texture (TODO)
123
 
124
  <br/>
125
 
 
163
  - [BiNI](https://github.com/hoshino042/bilateral_normal_integration) for Bilateral Normal Integration
164
  - [MonoPortDataset](https://github.com/Project-Splinter/MonoPortDataset) for Data Processing, [MonoPort](https://github.com/Project-Splinter/MonoPort) for fast implicit surface query
165
  - [rembg](https://github.com/danielgatis/rembg) for Human Segmentation
 
166
  - [MediaPipe](https://google.github.io/mediapipe/getting_started/python.html) for full-body landmark estimation
167
  - [PyTorch-NICP](https://github.com/wuhaozhe/pytorch-nicp) for non-rigid registration
168
  - [smplx](https://github.com/vchoutas/smplx), [PyMAF-X](https://www.liuyebin.com/pymaf-x/), [PIXIE](https://github.com/YadiraF/PIXIE) for Human Pose & Shape Estimation
apps/IFGeo.py CHANGED
@@ -24,7 +24,6 @@ torch.backends.cudnn.benchmark = True
24
 
25
 
26
  class IFGeo(pl.LightningModule):
27
-
28
  def __init__(self, cfg):
29
  super(IFGeo, self).__init__()
30
 
@@ -44,14 +43,15 @@ class IFGeo(pl.LightningModule):
44
  from lib.net.IFGeoNet_nobody import IFGeoNet
45
  self.netG = IFGeoNet(cfg)
46
 
47
-
48
- self.resolutions = (np.logspace(
49
- start=5,
50
- stop=np.log2(self.mcube_res),
51
- base=2,
52
- num=int(np.log2(self.mcube_res) - 4),
53
- endpoint=True,
54
- ) + 1.0)
 
55
 
56
  self.resolutions = self.resolutions.astype(np.int16).tolist()
57
 
@@ -82,9 +82,9 @@ class IFGeo(pl.LightningModule):
82
 
83
  if self.cfg.optim == "Adadelta":
84
 
85
- optimizer_G = torch.optim.Adadelta(optim_params_G,
86
- lr=self.lr_G,
87
- weight_decay=weight_decay)
88
 
89
  elif self.cfg.optim == "Adam":
90
 
@@ -103,20 +103,14 @@ class IFGeo(pl.LightningModule):
103
  raise NotImplementedError
104
 
105
  # set scheduler
106
- scheduler_G = torch.optim.lr_scheduler.MultiStepLR(optimizer_G,
107
- milestones=self.cfg.schedule,
108
- gamma=self.cfg.gamma)
109
 
110
  return [optimizer_G], [scheduler_G]
111
 
112
  def training_step(self, batch, batch_idx):
113
 
114
- # cfg log
115
- if self.cfg.devices == 1:
116
- if not self.cfg.fast_dev and self.global_step == 0:
117
- export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg)
118
- self.logger.experiment.config.update(convert_to_dict(self.cfg))
119
-
120
  self.netG.train()
121
 
122
  preds_G = self.netG(batch)
@@ -127,12 +121,9 @@ class IFGeo(pl.LightningModule):
127
  "loss": error_G,
128
  }
129
 
130
- self.log_dict(metrics_log,
131
- prog_bar=True,
132
- logger=True,
133
- on_step=True,
134
- on_epoch=False,
135
- sync_dist=True)
136
 
137
  return metrics_log
138
 
@@ -143,12 +134,14 @@ class IFGeo(pl.LightningModule):
143
  "train/avgloss": batch_mean(outputs, "loss"),
144
  }
145
 
146
- self.log_dict(metrics_log,
147
- prog_bar=False,
148
- logger=True,
149
- on_step=False,
150
- on_epoch=True,
151
- rank_zero_only=True)
 
 
152
 
153
  def validation_step(self, batch, batch_idx):
154
 
@@ -162,12 +155,9 @@ class IFGeo(pl.LightningModule):
162
  "val/loss": error_G,
163
  }
164
 
165
- self.log_dict(metrics_log,
166
- prog_bar=True,
167
- logger=False,
168
- on_step=True,
169
- on_epoch=False,
170
- sync_dist=True)
171
 
172
  return metrics_log
173
 
@@ -178,9 +168,11 @@ class IFGeo(pl.LightningModule):
178
  "val/avgloss": batch_mean(outputs, "val/loss"),
179
  }
180
 
181
- self.log_dict(metrics_log,
182
- prog_bar=False,
183
- logger=True,
184
- on_step=False,
185
- on_epoch=True,
186
- rank_zero_only=True)
 
 
 
24
 
25
 
26
  class IFGeo(pl.LightningModule):
 
27
  def __init__(self, cfg):
28
  super(IFGeo, self).__init__()
29
 
 
43
  from lib.net.IFGeoNet_nobody import IFGeoNet
44
  self.netG = IFGeoNet(cfg)
45
 
46
+ self.resolutions = (
47
+ np.logspace(
48
+ start=5,
49
+ stop=np.log2(self.mcube_res),
50
+ base=2,
51
+ num=int(np.log2(self.mcube_res) - 4),
52
+ endpoint=True,
53
+ ) + 1.0
54
+ )
55
 
56
  self.resolutions = self.resolutions.astype(np.int16).tolist()
57
 
 
82
 
83
  if self.cfg.optim == "Adadelta":
84
 
85
+ optimizer_G = torch.optim.Adadelta(
86
+ optim_params_G, lr=self.lr_G, weight_decay=weight_decay
87
+ )
88
 
89
  elif self.cfg.optim == "Adam":
90
 
 
103
  raise NotImplementedError
104
 
105
  # set scheduler
106
+ scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
107
+ optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
108
+ )
109
 
110
  return [optimizer_G], [scheduler_G]
111
 
112
  def training_step(self, batch, batch_idx):
113
 
 
 
 
 
 
 
114
  self.netG.train()
115
 
116
  preds_G = self.netG(batch)
 
121
  "loss": error_G,
122
  }
123
 
124
+ self.log_dict(
125
+ metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
126
+ )
 
 
 
127
 
128
  return metrics_log
129
 
 
134
  "train/avgloss": batch_mean(outputs, "loss"),
135
  }
136
 
137
+ self.log_dict(
138
+ metrics_log,
139
+ prog_bar=False,
140
+ logger=True,
141
+ on_step=False,
142
+ on_epoch=True,
143
+ rank_zero_only=True
144
+ )
145
 
146
  def validation_step(self, batch, batch_idx):
147
 
 
155
  "val/loss": error_G,
156
  }
157
 
158
+ self.log_dict(
159
+ metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True
160
+ )
 
 
 
161
 
162
  return metrics_log
163
 
 
168
  "val/avgloss": batch_mean(outputs, "val/loss"),
169
  }
170
 
171
+ self.log_dict(
172
+ metrics_log,
173
+ prog_bar=False,
174
+ logger=True,
175
+ on_step=False,
176
+ on_epoch=True,
177
+ rank_zero_only=True
178
+ )
apps/Normal.py CHANGED
@@ -1,14 +1,12 @@
1
  from lib.net import NormalNet
2
- from lib.common.train_util import convert_to_dict, export_cfg, batch_mean
3
  import torch
4
  import numpy as np
5
- import os.path as osp
6
  from skimage.transform import resize
7
  import pytorch_lightning as pl
8
 
9
 
10
  class Normal(pl.LightningModule):
11
-
12
  def __init__(self, cfg):
13
  super(Normal, self).__init__()
14
  self.cfg = cfg
@@ -44,19 +42,19 @@ class Normal(pl.LightningModule):
44
  optimizer_N_F = torch.optim.Adam(optim_params_N_F, lr=self.lr_F, betas=(0.5, 0.999))
45
  optimizer_N_B = torch.optim.Adam(optim_params_N_B, lr=self.lr_B, betas=(0.5, 0.999))
46
 
47
- scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_F,
48
- milestones=self.cfg.schedule,
49
- gamma=self.cfg.gamma)
50
 
51
- scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_B,
52
- milestones=self.cfg.schedule,
53
- gamma=self.cfg.gamma)
54
  if 'gan' in self.ALL_losses:
55
  optim_params_N_D = [{"params": self.netG.netD.parameters(), "lr": self.lr_D}]
56
  optimizer_N_D = torch.optim.Adam(optim_params_N_D, lr=self.lr_D, betas=(0.5, 0.999))
57
- scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_D,
58
- milestones=self.cfg.schedule,
59
- gamma=self.cfg.gamma)
60
  self.schedulers = [scheduler_N_F, scheduler_N_B, scheduler_N_D]
61
  optims = [optimizer_N_F, optimizer_N_B, optimizer_N_D]
62
 
@@ -77,19 +75,16 @@ class Normal(pl.LightningModule):
77
  ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0),
78
  (height, height),
79
  anti_aliasing=True,
80
- ))
 
81
 
82
- self.logger.log_image(key=f"Normal/{dataset}/{idx if not self.overfit else 1}",
83
- images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)
84
- ])
 
85
 
86
  def training_step(self, batch, batch_idx):
87
 
88
- # cfg log
89
- if not self.cfg.fast_dev and self.global_step == 0 and self.cfg.devices == 1:
90
- export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg)
91
- self.logger.experiment.config.update(convert_to_dict(self.cfg))
92
-
93
  self.netG.train()
94
 
95
  # retrieve the data
@@ -125,7 +120,8 @@ class Normal(pl.LightningModule):
125
  opt_B.step()
126
 
127
  if batch_idx > 0 and batch_idx % int(
128
- self.cfg.freq_show_train) == 0 and self.cfg.devices == 1:
 
129
 
130
  self.netG.eval()
131
  with torch.no_grad():
@@ -142,12 +138,9 @@ class Normal(pl.LightningModule):
142
  for key in error_dict.keys():
143
  metrics_log["train/loss_" + key] = error_dict[key].item()
144
 
145
- self.log_dict(metrics_log,
146
- prog_bar=True,
147
- logger=True,
148
- on_step=True,
149
- on_epoch=False,
150
- sync_dist=True)
151
 
152
  return metrics_log
153
 
@@ -163,12 +156,14 @@ class Normal(pl.LightningModule):
163
  loss_name = key
164
  metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
165
 
166
- self.log_dict(metrics_log,
167
- prog_bar=False,
168
- logger=True,
169
- on_step=False,
170
- on_epoch=True,
171
- rank_zero_only=True)
 
 
172
 
173
  def validation_step(self, batch, batch_idx):
174
 
@@ -212,9 +207,11 @@ class Normal(pl.LightningModule):
212
  [stage, loss_name] = key.split("/")
213
  metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
214
 
215
- self.log_dict(metrics_log,
216
- prog_bar=False,
217
- logger=True,
218
- on_step=False,
219
- on_epoch=True,
220
- rank_zero_only=True)
 
 
 
1
  from lib.net import NormalNet
2
+ from lib.common.train_util import batch_mean
3
  import torch
4
  import numpy as np
 
5
  from skimage.transform import resize
6
  import pytorch_lightning as pl
7
 
8
 
9
  class Normal(pl.LightningModule):
 
10
  def __init__(self, cfg):
11
  super(Normal, self).__init__()
12
  self.cfg = cfg
 
42
  optimizer_N_F = torch.optim.Adam(optim_params_N_F, lr=self.lr_F, betas=(0.5, 0.999))
43
  optimizer_N_B = torch.optim.Adam(optim_params_N_B, lr=self.lr_B, betas=(0.5, 0.999))
44
 
45
+ scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
46
+ optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
47
+ )
48
 
49
+ scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
50
+ optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
51
+ )
52
  if 'gan' in self.ALL_losses:
53
  optim_params_N_D = [{"params": self.netG.netD.parameters(), "lr": self.lr_D}]
54
  optimizer_N_D = torch.optim.Adam(optim_params_N_D, lr=self.lr_D, betas=(0.5, 0.999))
55
+ scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(
56
+ optimizer_N_D, milestones=self.cfg.schedule, gamma=self.cfg.gamma
57
+ )
58
  self.schedulers = [scheduler_N_F, scheduler_N_B, scheduler_N_D]
59
  optims = [optimizer_N_F, optimizer_N_B, optimizer_N_D]
60
 
 
75
  ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0),
76
  (height, height),
77
  anti_aliasing=True,
78
+ )
79
+ )
80
 
81
+ self.logger.log_image(
82
+ key=f"Normal/{dataset}/{idx if not self.overfit else 1}",
83
+ images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)]
84
+ )
85
 
86
  def training_step(self, batch, batch_idx):
87
 
 
 
 
 
 
88
  self.netG.train()
89
 
90
  # retrieve the data
 
120
  opt_B.step()
121
 
122
  if batch_idx > 0 and batch_idx % int(
123
+ self.cfg.freq_show_train
124
+ ) == 0 and self.cfg.devices == 1:
125
 
126
  self.netG.eval()
127
  with torch.no_grad():
 
138
  for key in error_dict.keys():
139
  metrics_log["train/loss_" + key] = error_dict[key].item()
140
 
141
+ self.log_dict(
142
+ metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
143
+ )
 
 
 
144
 
145
  return metrics_log
146
 
 
156
  loss_name = key
157
  metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
158
 
159
+ self.log_dict(
160
+ metrics_log,
161
+ prog_bar=False,
162
+ logger=True,
163
+ on_step=False,
164
+ on_epoch=True,
165
+ rank_zero_only=True
166
+ )
167
 
168
  def validation_step(self, batch, batch_idx):
169
 
 
207
  [stage, loss_name] = key.split("/")
208
  metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
209
 
210
+ self.log_dict(
211
+ metrics_log,
212
+ prog_bar=False,
213
+ logger=True,
214
+ on_step=False,
215
+ on_epoch=True,
216
+ rank_zero_only=True
217
+ )
apps/avatarizer.py CHANGED
@@ -44,7 +44,8 @@ smpl_model = smplx.create(
44
  use_pca=False,
45
  num_betas=200,
46
  num_expression_coeffs=50,
47
- ext='pkl')
 
48
 
49
  smpl_out_lst = []
50
 
@@ -62,7 +63,9 @@ for pose_type in ["t-pose", "da-pose", "pose"]:
62
  return_full_pose=True,
63
  return_joint_transformation=True,
64
  return_vertex_transformation=True,
65
- pose_type=pose_type))
 
 
66
 
67
  smpl_verts = smpl_out_lst[2].vertices.detach()[0]
68
  smpl_tree = cKDTree(smpl_verts.cpu().numpy())
@@ -74,7 +77,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
74
  econ_verts = torch.tensor(econ_obj.vertices).float()
75
  rot_mat_t = smpl_out_lst[2].vertex_transformation.detach()[0][idx[:, 0]]
76
  homo_coord = torch.ones_like(econ_verts)[..., :1]
77
- econ_cano_verts = torch.inverse(rot_mat_t) @ torch.cat([econ_verts, homo_coord], dim=1).unsqueeze(-1)
 
78
  econ_cano_verts = econ_cano_verts[:, :3, 0].cpu()
79
  econ_cano = trimesh.Trimesh(econ_cano_verts, econ_obj.faces)
80
 
@@ -84,7 +88,9 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
84
  econ_da = trimesh.Trimesh(econ_da_verts[:, :3, 0].cpu(), econ_obj.faces)
85
 
86
  # da-pose for SMPL-X
87
- smpl_da = trimesh.Trimesh(smpl_out_lst[1].vertices.detach()[0], smpl_model.faces, maintain_orders=True, process=False)
 
 
88
  smpl_da.export(f"{prefix}_smpl_da.obj")
89
 
90
  # remove hands from ECON for next registeration
@@ -97,7 +103,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
97
  # remove SMPL-X hand and face
98
  register_mask = ~np.isin(
99
  np.arange(smpl_da.vertices.shape[0]),
100
- np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid]))
 
101
  register_mask *= ~smplx_container.eyeball_vertex_mask.bool().numpy()
102
  smpl_da_body = smpl_da.copy()
103
  smpl_da_body.update_faces(register_mask[smpl_da.faces].all(axis=1))
@@ -115,8 +122,13 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
115
  # remove over-streched+hand faces from ECON
116
  econ_da_body = econ_da.copy()
117
  edge_before = np.sqrt(
118
- ((econ_obj.vertices[econ_cano.edges[:, 0]] - econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
119
- edge_after = np.sqrt(((econ_da.vertices[econ_cano.edges[:, 0]] - econ_da.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
 
 
 
 
 
120
  edge_diff = edge_after / edge_before.clip(1e-2)
121
  streched_mask = np.unique(econ_cano.edges[edge_diff > 6])
122
  mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)
@@ -148,8 +160,9 @@ econ_J_regressor = (smpl_model.J_regressor[:, idx] * knn_weights[None]).sum(axis
148
  econ_lbs_weights = (smpl_model.lbs_weights.T[:, idx] * knn_weights[None]).sum(axis=-1).T
149
 
150
  num_posedirs = smpl_model.posedirs.shape[0]
151
- econ_posedirs = (smpl_model.posedirs.view(num_posedirs, -1, 3)[:, idx, :] *
152
- knn_weights[None, ..., None]).sum(axis=-2).view(num_posedirs, -1).float()
 
153
 
154
  econ_J_regressor /= econ_J_regressor.sum(axis=1, keepdims=True)
155
  econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
@@ -157,8 +170,9 @@ econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
157
  # re-compute da-pose rot_mat for ECON
158
  rot_mat_da = smpl_out_lst[1].vertex_transformation.detach()[0][idx[:, 0]]
159
  econ_da_verts = torch.tensor(econ_da.vertices).float()
160
- econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat([econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]],
161
- dim=1).unsqueeze(-1)
 
162
  econ_cano_verts = econ_cano_verts[:, :3, 0].double()
163
 
164
  # ----------------------------------------------------
@@ -174,7 +188,8 @@ posed_econ_verts, _ = general_lbs(
174
  posedirs=econ_posedirs,
175
  J_regressor=econ_J_regressor,
176
  parents=smpl_model.parents,
177
- lbs_weights=econ_lbs_weights)
 
178
 
179
  econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_da.faces)
180
- econ_pose.export(f"{prefix}_econ_pose.obj")
 
44
  use_pca=False,
45
  num_betas=200,
46
  num_expression_coeffs=50,
47
+ ext='pkl'
48
+ )
49
 
50
  smpl_out_lst = []
51
 
 
63
  return_full_pose=True,
64
  return_joint_transformation=True,
65
  return_vertex_transformation=True,
66
+ pose_type=pose_type
67
+ )
68
+ )
69
 
70
  smpl_verts = smpl_out_lst[2].vertices.detach()[0]
71
  smpl_tree = cKDTree(smpl_verts.cpu().numpy())
 
77
  econ_verts = torch.tensor(econ_obj.vertices).float()
78
  rot_mat_t = smpl_out_lst[2].vertex_transformation.detach()[0][idx[:, 0]]
79
  homo_coord = torch.ones_like(econ_verts)[..., :1]
80
+ econ_cano_verts = torch.inverse(rot_mat_t) @ torch.cat([econ_verts, homo_coord],
81
+ dim=1).unsqueeze(-1)
82
  econ_cano_verts = econ_cano_verts[:, :3, 0].cpu()
83
  econ_cano = trimesh.Trimesh(econ_cano_verts, econ_obj.faces)
84
 
 
88
  econ_da = trimesh.Trimesh(econ_da_verts[:, :3, 0].cpu(), econ_obj.faces)
89
 
90
  # da-pose for SMPL-X
91
+ smpl_da = trimesh.Trimesh(
92
+ smpl_out_lst[1].vertices.detach()[0], smpl_model.faces, maintain_orders=True, process=False
93
+ )
94
  smpl_da.export(f"{prefix}_smpl_da.obj")
95
 
96
  # remove hands from ECON for next registeration
 
103
  # remove SMPL-X hand and face
104
  register_mask = ~np.isin(
105
  np.arange(smpl_da.vertices.shape[0]),
106
+ np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid])
107
+ )
108
  register_mask *= ~smplx_container.eyeball_vertex_mask.bool().numpy()
109
  smpl_da_body = smpl_da.copy()
110
  smpl_da_body.update_faces(register_mask[smpl_da.faces].all(axis=1))
 
122
  # remove over-streched+hand faces from ECON
123
  econ_da_body = econ_da.copy()
124
  edge_before = np.sqrt(
125
+ ((econ_obj.vertices[econ_cano.edges[:, 0]] -
126
+ econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)
127
+ )
128
+ edge_after = np.sqrt(
129
+ ((econ_da.vertices[econ_cano.edges[:, 0]] -
130
+ econ_da.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)
131
+ )
132
  edge_diff = edge_after / edge_before.clip(1e-2)
133
  streched_mask = np.unique(econ_cano.edges[edge_diff > 6])
134
  mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)
 
160
  econ_lbs_weights = (smpl_model.lbs_weights.T[:, idx] * knn_weights[None]).sum(axis=-1).T
161
 
162
  num_posedirs = smpl_model.posedirs.shape[0]
163
+ econ_posedirs = (
164
+ smpl_model.posedirs.view(num_posedirs, -1, 3)[:, idx, :] * knn_weights[None, ..., None]
165
+ ).sum(axis=-2).view(num_posedirs, -1).float()
166
 
167
  econ_J_regressor /= econ_J_regressor.sum(axis=1, keepdims=True)
168
  econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
 
170
  # re-compute da-pose rot_mat for ECON
171
  rot_mat_da = smpl_out_lst[1].vertex_transformation.detach()[0][idx[:, 0]]
172
  econ_da_verts = torch.tensor(econ_da.vertices).float()
173
+ econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat(
174
+ [econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]], dim=1
175
+ ).unsqueeze(-1)
176
  econ_cano_verts = econ_cano_verts[:, :3, 0].double()
177
 
178
  # ----------------------------------------------------
 
188
  posedirs=econ_posedirs,
189
  J_regressor=econ_J_regressor,
190
  parents=smpl_model.parents,
191
+ lbs_weights=econ_lbs_weights
192
+ )
193
 
194
  econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_da.faces)
195
+ econ_pose.export(f"{prefix}_econ_pose.obj")
apps/infer.py CHANGED
@@ -34,7 +34,8 @@ from apps.IFGeo import IFGeo
34
  from pytorch3d.ops import SubdivideMeshes
35
  from lib.common.config import cfg
36
  from lib.common.render import query_color
37
- from lib.common.train_util import init_loss, load_normal_networks, load_networks
 
38
  from lib.common.BNI import BNI
39
  from lib.common.BNI_utils import save_normal_tensor
40
  from lib.dataset.TestDataset import TestDataset
@@ -68,20 +69,25 @@ if __name__ == "__main__":
68
  device = torch.device(f"cuda:{args.gpu_device}")
69
 
70
  # setting for testing on in-the-wild images
71
- cfg_show_list = ["test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True, "batch_size", 1]
 
 
 
72
 
73
  cfg.merge_from_list(cfg_show_list)
74
  cfg.freeze()
75
 
76
- # load model
77
- normal_model = Normal(cfg).to(device)
78
- load_normal_networks(normal_model, cfg.normal_path)
79
- normal_model.netG.eval()
80
-
81
- # load IFGeo model
82
- ifnet_model = IFGeo(cfg).to(device)
83
- load_networks(ifnet_model, mlp_path=cfg.ifnet_path)
84
- ifnet_model.netG.eval()
 
 
85
 
86
  # SMPLX object
87
  SMPLX_object = SMPLX()
@@ -89,16 +95,24 @@ if __name__ == "__main__":
89
  dataset_param = {
90
  "image_dir": args.in_dir,
91
  "seg_dir": args.seg_dir,
92
- "use_seg": True, # w/ or w/o segmentation
93
- "hps_type": cfg.bni.hps_type, # pymafx/pixie
94
  "vol_res": cfg.vol_res,
95
  "single": args.multi,
96
  }
97
 
98
  if cfg.bni.use_ifnet:
99
- print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
 
 
 
 
 
 
 
 
100
  else:
101
- print(colored("Use SMPL-X (Explicit) for completion", "green"))
102
 
103
  dataset = TestDataset(dataset_param, device)
104
 
@@ -125,13 +139,17 @@ if __name__ == "__main__":
125
  # 2. SMPL params (xxx_smpl.npy)
126
  # 3. d-BiNI surfaces (xxx_BNI.obj)
127
  # 4. seperate face/hand mesh (xxx_hand/face.obj)
128
- # 5. full shape impainted by IF-Nets+, and remeshed shape (xxx_IF_(remesh).obj)
129
  # 6. sideded or occluded parts (xxx_side.obj)
130
  # 7. final reconstructed clothed human (xxx_full.obj)
131
 
132
  os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
133
 
134
- in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["img_icon"].to(device), "mask": data["img_mask"].to(device)}
 
 
 
 
135
 
136
  # The optimizer and variables
137
  optimed_pose = data["body_pose"].requires_grad_(True)
@@ -139,7 +157,9 @@ if __name__ == "__main__":
139
  optimed_betas = data["betas"].requires_grad_(True)
140
  optimed_orient = data["global_orient"].requires_grad_(True)
141
 
142
- optimizer_smpl = torch.optim.Adam([optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True)
 
 
143
  scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
144
  optimizer_smpl,
145
  mode="min",
@@ -156,10 +176,12 @@ if __name__ == "__main__":
156
 
157
  smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
158
 
 
159
  if osp.exists(smpl_path):
160
 
161
  smpl_verts_lst = []
162
  smpl_faces_lst = []
 
163
  for idx in range(N_body):
164
 
165
  smpl_obj = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"
@@ -173,10 +195,12 @@ if __name__ == "__main__":
173
  batch_smpl_faces = torch.stack(smpl_faces_lst)
174
 
175
  # render optimized mesh as normal [-1,1]
176
- in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(batch_smpl_verts, batch_smpl_faces)
 
 
177
 
178
  with torch.no_grad():
179
- in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor)
180
 
181
  in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
182
  in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
@@ -194,8 +218,10 @@ if __name__ == "__main__":
194
  N_body, N_pose = optimed_pose.shape[:2]
195
 
196
  # 6d_rot to rot_mat
197
- optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).view(N_body, 1, 3, 3)
198
- optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).view(N_body, N_pose, 3, 3)
 
 
199
 
200
  smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model(
201
  shape_params=optimed_betas,
@@ -208,11 +234,16 @@ if __name__ == "__main__":
208
  )
209
 
210
  smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
211
- smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor([1.0, 1.0, -1.0]).to(device)
 
 
212
 
213
  # landmark errors
214
- smpl_joints_3d = (smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0) * 0.5
215
- in_tensor["smpl_joint"] = smpl_joints[:, dataset.smpl_data.smpl_joint_ids_24_pixie, :]
 
 
 
216
 
217
  ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device)
218
  ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device)
@@ -227,7 +258,7 @@ if __name__ == "__main__":
227
  T_mask_F, T_mask_B = dataset.render.get_image(type="mask")
228
 
229
  with torch.no_grad():
230
- in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor)
231
 
232
  diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
233
  diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])
@@ -249,25 +280,37 @@ if __name__ == "__main__":
249
 
250
  # BUG: PyTorch3D silhouette renderer generates dilated mask
251
  bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
252
- smpl_arr_fake = torch.cat([in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), in_tensor["T_normal_B"][:, 0].ne(bg_value).float()],
253
- dim=-1)
 
 
 
 
 
254
 
255
- body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
 
256
  body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
257
  body_overlap_flag = body_overlap < cfg.body_overlap_thres
258
 
259
- losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] + diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
 
 
 
260
 
261
  losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
262
  occluded_idx = torch.where(body_overlap_flag)[0]
263
  ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95
264
- losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) * ghum_conf).mean(dim=1)
 
265
 
266
  # Weighted sum of the losses
267
  smpl_loss = 0.0
268
- pbar_desc = "Body Fitting --- "
269
  for k in ["normal", "silhouette", "joint"]:
270
- per_loop_loss = (losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)).mean()
 
 
271
  pbar_desc += f"{k}: {per_loop_loss:.3f} | "
272
  smpl_loss += per_loop_loss
273
  pbar_desc += f"Total: {smpl_loss:.3f}"
@@ -279,19 +322,25 @@ if __name__ == "__main__":
279
  # save intermediate results / vis_freq and final_step
280
  if (i % args.vis_freq == 0) or (i == args.loop_smpl - 1):
281
 
282
- per_loop_lst.extend([
283
- in_tensor["image"],
284
- in_tensor["T_normal_F"],
285
- in_tensor["normal_F"],
286
- diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
287
- ])
288
- per_loop_lst.extend([
289
- in_tensor["image"],
290
- in_tensor["T_normal_B"],
291
- in_tensor["normal_B"],
292
- diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
293
- ])
294
- per_data_lst.append(get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl"))
 
 
 
 
 
 
295
 
296
  smpl_loss.backward()
297
  optimizer_smpl.step()
@@ -304,14 +353,21 @@ if __name__ == "__main__":
304
  img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
305
  torchvision.utils.save_image(
306
  torch.cat(
307
- [data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5, (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5],
308
- dim=3), img_crop_path)
 
 
 
 
 
309
 
310
  rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
311
  rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
312
 
313
  img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
314
- torchvision.utils.save_image(torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
 
 
315
 
316
  smpl_obj_lst = []
317
 
@@ -329,15 +385,28 @@ if __name__ == "__main__":
329
  if not osp.exists(smpl_obj_path):
330
  smpl_obj.export(smpl_obj_path)
331
  smpl_info = {
332
- "betas": optimed_betas[idx].detach().cpu().unsqueeze(0),
333
- "body_pose": rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()).cpu().unsqueeze(0),
334
- "global_orient": rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()).cpu().unsqueeze(0),
335
- "transl": optimed_trans[idx].detach().cpu(),
336
- "expression": data["exp"][idx].cpu().unsqueeze(0),
337
- "jaw_pose": rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
338
- "left_hand_pose": rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]).cpu().unsqueeze(0),
339
- "right_hand_pose": rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]).cpu().unsqueeze(0),
340
- "scale": data["scale"][idx].cpu(),
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  }
342
  np.save(
343
  smpl_obj_path.replace(".obj", ".npy"),
@@ -359,10 +428,13 @@ if __name__ == "__main__":
359
 
360
  per_data_lst = []
361
 
362
- batch_smpl_verts = in_tensor["smpl_verts"].detach() * torch.tensor([1.0, -1.0, 1.0], device=device)
 
363
  batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]
364
 
365
- in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(batch_smpl_verts, batch_smpl_faces)
 
 
366
 
367
  per_loop_lst = []
368
 
@@ -389,7 +461,13 @@ if __name__ == "__main__":
389
  )
390
 
391
  # BNI process
392
- BNI_object = BNI(dir_path=osp.join(args.out_dir, cfg.name, "BNI"), name=data["name"], BNI_dict=BNI_dict, cfg=cfg.bni, device=device)
 
 
 
 
 
 
393
 
394
  BNI_object.extract_surface(False)
395
 
@@ -406,29 +484,40 @@ if __name__ == "__main__":
406
  side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
407
 
408
  # mesh completion via IF-net
409
- in_tensor.update(dataset.depth_to_voxel({"depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B": BNI_object.B_depth.unsqueeze(0)}))
 
 
 
 
 
 
 
410
 
411
  occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
412
  0,
413
  ] * 3, scale=2.0).data.transpose(2, 1, 0)
414
  occupancies = np.flip(occupancies, axis=1)
415
 
416
- in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
 
417
 
418
  with torch.no_grad():
419
- sdf = ifnet_model.reconEngine(netG=ifnet_model.netG, batch=in_tensor)
420
- verts_IF, faces_IF = ifnet_model.reconEngine.export_mesh(sdf)
421
 
422
- if ifnet_model.clean_mesh_flag:
423
  verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)
424
 
425
  side_mesh = trimesh.Trimesh(verts_IF, faces_IF)
426
- side_mesh = remesh(side_mesh, side_mesh_path)
427
 
428
  else:
429
  side_mesh = apply_vertex_mask(
430
  side_mesh,
431
- (SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + SMPLX_object.eyeball_vertex_mask).eq(0).float(),
 
 
 
432
  )
433
 
434
  #register side_mesh to BNI surfaces
@@ -448,7 +537,9 @@ if __name__ == "__main__":
448
  # 3. remove eyeball faces
449
 
450
  # export intermediate meshes
451
- BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
 
 
452
  full_lst = []
453
 
454
  if "face" in cfg.bni.use_smpl:
@@ -458,37 +549,63 @@ if __name__ == "__main__":
458
  face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
459
 
460
  # remove face neighbor triangles
461
- BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
462
- side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
 
 
 
 
 
 
 
 
 
463
  face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
464
  full_lst += [face_mesh]
465
 
466
  if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]):
467
 
468
- hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0],)
469
  if data['hands_visibility'][idx][0]:
470
- hand_mask.index_fill_(0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0)
 
 
471
  if data['hands_visibility'][idx][1]:
472
- hand_mask.index_fill_(0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0)
 
 
473
 
474
  # only hands
475
  hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
476
 
477
  # remove hand neighbor triangles
478
- BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
479
- side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
 
 
 
 
 
 
 
 
 
480
  hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
481
  full_lst += [hand_mesh]
482
 
483
  full_lst += [BNI_object.F_B_trimesh]
484
 
485
  # initial side_mesh could be SMPLX or IF-net
486
- side_mesh = part_removal(side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False)
 
 
487
 
488
  full_lst += [side_mesh]
489
 
490
  # # export intermediate meshes
491
- BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
 
 
492
  side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj")
493
 
494
  if cfg.bni.use_poisson:
@@ -505,15 +622,22 @@ if __name__ == "__main__":
505
  rotate_recon_lst = dataset.render.get_image(cam_type="four")
506
  per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
507
 
508
- # coloring the final mesh
509
- final_colors = query_color(
510
- torch.tensor(final_mesh.vertices).float(),
511
- torch.tensor(final_mesh.faces).long(),
512
- in_tensor["image"][idx:idx + 1],
513
- device=device,
514
- )
515
- final_mesh.visual.vertex_colors = final_colors
516
- final_mesh.export(final_path)
 
 
 
 
 
 
 
517
 
518
  # for video rendering
519
  in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
 
34
  from pytorch3d.ops import SubdivideMeshes
35
  from lib.common.config import cfg
36
  from lib.common.render import query_color
37
+ from lib.common.train_util import init_loss, Format
38
+ from lib.common.imutils import blend_rgb_norm
39
  from lib.common.BNI import BNI
40
  from lib.common.BNI_utils import save_normal_tensor
41
  from lib.dataset.TestDataset import TestDataset
 
69
  device = torch.device(f"cuda:{args.gpu_device}")
70
 
71
  # setting for testing on in-the-wild images
72
+ cfg_show_list = [
73
+ "test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True,
74
+ "batch_size", 1
75
+ ]
76
 
77
  cfg.merge_from_list(cfg_show_list)
78
  cfg.freeze()
79
 
80
+ # load normal model
81
+ normal_net = Normal.load_from_checkpoint(
82
+ cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False
83
+ )
84
+ normal_net = normal_net.to(device)
85
+ normal_net.netG.eval()
86
+ print(
87
+ colored(
88
+ f"Resume Normal Estimator from {Format.start} {cfg.normal_path} {Format.end}", "green"
89
+ )
90
+ )
91
 
92
  # SMPLX object
93
  SMPLX_object = SMPLX()
 
95
  dataset_param = {
96
  "image_dir": args.in_dir,
97
  "seg_dir": args.seg_dir,
98
+ "use_seg": True, # w/ or w/o segmentation
99
+ "hps_type": cfg.bni.hps_type, # pymafx/pixie
100
  "vol_res": cfg.vol_res,
101
  "single": args.multi,
102
  }
103
 
104
  if cfg.bni.use_ifnet:
105
+ # load IFGeo model
106
+ ifnet = IFGeo.load_from_checkpoint(
107
+ cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False
108
+ )
109
+ ifnet = ifnet.to(device)
110
+ ifnet.netG.eval()
111
+
112
+ print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green"))
113
+ print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green"))
114
  else:
115
+ print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green"))
116
 
117
  dataset = TestDataset(dataset_param, device)
118
 
 
139
  # 2. SMPL params (xxx_smpl.npy)
140
  # 3. d-BiNI surfaces (xxx_BNI.obj)
141
  # 4. seperate face/hand mesh (xxx_hand/face.obj)
142
+ # 5. full shape impainted by IF-Nets+ after remeshing (xxx_IF.obj)
143
  # 6. sideded or occluded parts (xxx_side.obj)
144
  # 7. final reconstructed clothed human (xxx_full.obj)
145
 
146
  os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
147
 
148
+ in_tensor = {
149
+ "smpl_faces": data["smpl_faces"],
150
+ "image": data["img_icon"].to(device),
151
+ "mask": data["img_mask"].to(device)
152
+ }
153
 
154
  # The optimizer and variables
155
  optimed_pose = data["body_pose"].requires_grad_(True)
 
157
  optimed_betas = data["betas"].requires_grad_(True)
158
  optimed_orient = data["global_orient"].requires_grad_(True)
159
 
160
+ optimizer_smpl = torch.optim.Adam(
161
+ [optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True
162
+ )
163
  scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
164
  optimizer_smpl,
165
  mode="min",
 
176
 
177
  smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
178
 
179
+ # remove this line if you change the loop_smpl and obtain different SMPL-X fits
180
  if osp.exists(smpl_path):
181
 
182
  smpl_verts_lst = []
183
  smpl_faces_lst = []
184
+
185
  for idx in range(N_body):
186
 
187
  smpl_obj = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"
 
195
  batch_smpl_faces = torch.stack(smpl_faces_lst)
196
 
197
  # render optimized mesh as normal [-1,1]
198
+ in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
199
+ batch_smpl_verts, batch_smpl_faces
200
+ )
201
 
202
  with torch.no_grad():
203
+ in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)
204
 
205
  in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
206
  in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
 
218
  N_body, N_pose = optimed_pose.shape[:2]
219
 
220
  # 6d_rot to rot_mat
221
+ optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1,
222
+ 6)).view(N_body, 1, 3, 3)
223
+ optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1,
224
+ 6)).view(N_body, N_pose, 3, 3)
225
 
226
  smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model(
227
  shape_params=optimed_betas,
 
234
  )
235
 
236
  smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
237
+ smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor(
238
+ [1.0, 1.0, -1.0]
239
+ ).to(device)
240
 
241
  # landmark errors
242
+ smpl_joints_3d = (
243
+ smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0
244
+ ) * 0.5
245
+ in_tensor["smpl_joint"] = smpl_joints[:,
246
+ dataset.smpl_data.smpl_joint_ids_24_pixie, :]
247
 
248
  ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device)
249
  ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device)
 
258
  T_mask_F, T_mask_B = dataset.render.get_image(type="mask")
259
 
260
  with torch.no_grad():
261
+ in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)
262
 
263
  diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
264
  diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])
 
280
 
281
  # BUG: PyTorch3D silhouette renderer generates dilated mask
282
  bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
283
+ smpl_arr_fake = torch.cat(
284
+ [
285
+ in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
286
+ in_tensor["T_normal_B"][:, 0].ne(bg_value).float()
287
+ ],
288
+ dim=-1
289
+ )
290
 
291
+ body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)
292
+ ).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
293
  body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
294
  body_overlap_flag = body_overlap < cfg.body_overlap_thres
295
 
296
+ losses["normal"]["value"] = (
297
+ diff_F_smpl * body_overlap_mask[..., :512] +
298
+ diff_B_smpl * body_overlap_mask[..., 512:]
299
+ ).mean() / 2.0
300
 
301
  losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
302
  occluded_idx = torch.where(body_overlap_flag)[0]
303
  ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95
304
+ losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) *
305
+ ghum_conf).mean(dim=1)
306
 
307
  # Weighted sum of the losses
308
  smpl_loss = 0.0
309
+ pbar_desc = "Body Fitting -- "
310
  for k in ["normal", "silhouette", "joint"]:
311
+ per_loop_loss = (
312
+ losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)
313
+ ).mean()
314
  pbar_desc += f"{k}: {per_loop_loss:.3f} | "
315
  smpl_loss += per_loop_loss
316
  pbar_desc += f"Total: {smpl_loss:.3f}"
 
322
  # save intermediate results / vis_freq and final_step
323
  if (i % args.vis_freq == 0) or (i == args.loop_smpl - 1):
324
 
325
+ per_loop_lst.extend(
326
+ [
327
+ in_tensor["image"],
328
+ in_tensor["T_normal_F"],
329
+ in_tensor["normal_F"],
330
+ diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
331
+ ]
332
+ )
333
+ per_loop_lst.extend(
334
+ [
335
+ in_tensor["image"],
336
+ in_tensor["T_normal_B"],
337
+ in_tensor["normal_B"],
338
+ diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
339
+ ]
340
+ )
341
+ per_data_lst.append(
342
+ get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")
343
+ )
344
 
345
  smpl_loss.backward()
346
  optimizer_smpl.step()
 
353
  img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
354
  torchvision.utils.save_image(
355
  torch.cat(
356
+ [
357
+ data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
358
+ (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
359
+ ],
360
+ dim=3
361
+ ), img_crop_path
362
+ )
363
 
364
  rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
365
  rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
366
 
367
  img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
368
+ torchvision.utils.save_image(
369
+ torch.cat([data["img_raw"], rgb_norm_F, rgb_norm_B], dim=-1) / 255., img_overlap_path
370
+ )
371
 
372
  smpl_obj_lst = []
373
 
 
385
  if not osp.exists(smpl_obj_path):
386
  smpl_obj.export(smpl_obj_path)
387
  smpl_info = {
388
+ "betas":
389
+ optimed_betas[idx].detach().cpu().unsqueeze(0),
390
+ "body_pose":
391
+ rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()
392
+ ).cpu().unsqueeze(0),
393
+ "global_orient":
394
+ rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()
395
+ ).cpu().unsqueeze(0),
396
+ "transl":
397
+ optimed_trans[idx].detach().cpu(),
398
+ "expression":
399
+ data["exp"][idx].cpu().unsqueeze(0),
400
+ "jaw_pose":
401
+ rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
402
+ "left_hand_pose":
403
+ rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]
404
+ ).cpu().unsqueeze(0),
405
+ "right_hand_pose":
406
+ rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]
407
+ ).cpu().unsqueeze(0),
408
+ "scale":
409
+ data["scale"][idx].cpu(),
410
  }
411
  np.save(
412
  smpl_obj_path.replace(".obj", ".npy"),
 
428
 
429
  per_data_lst = []
430
 
431
+ batch_smpl_verts = in_tensor["smpl_verts"].detach(
432
+ ) * torch.tensor([1.0, -1.0, 1.0], device=device)
433
  batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]
434
 
435
+ in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
436
+ batch_smpl_verts, batch_smpl_faces
437
+ )
438
 
439
  per_loop_lst = []
440
 
 
461
  )
462
 
463
  # BNI process
464
+ BNI_object = BNI(
465
+ dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
466
+ name=data["name"],
467
+ BNI_dict=BNI_dict,
468
+ cfg=cfg.bni,
469
+ device=device
470
+ )
471
 
472
  BNI_object.extract_surface(False)
473
 
 
484
  side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
485
 
486
  # mesh completion via IF-net
487
+ in_tensor.update(
488
+ dataset.depth_to_voxel(
489
+ {
490
+ "depth_F": BNI_object.F_depth.unsqueeze(0),
491
+ "depth_B": BNI_object.B_depth.unsqueeze(0)
492
+ }
493
+ )
494
+ )
495
 
496
  occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
497
  0,
498
  ] * 3, scale=2.0).data.transpose(2, 1, 0)
499
  occupancies = np.flip(occupancies, axis=1)
500
 
501
+ in_tensor["body_voxels"] = torch.tensor(occupancies.copy()
502
+ ).float().unsqueeze(0).to(device)
503
 
504
  with torch.no_grad():
505
+ sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor)
506
+ verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf)
507
 
508
+ if ifnet.clean_mesh_flag:
509
  verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)
510
 
511
  side_mesh = trimesh.Trimesh(verts_IF, faces_IF)
512
+ side_mesh = remesh_laplacian(side_mesh, side_mesh_path)
513
 
514
  else:
515
  side_mesh = apply_vertex_mask(
516
  side_mesh,
517
+ (
518
+ SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
519
+ SMPLX_object.eyeball_vertex_mask
520
+ ).eq(0).float(),
521
  )
522
 
523
  #register side_mesh to BNI surfaces
 
537
  # 3. remove eyeball faces
538
 
539
  # export intermediate meshes
540
+ BNI_object.F_B_trimesh.export(
541
+ f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
542
+ )
543
  full_lst = []
544
 
545
  if "face" in cfg.bni.use_smpl:
 
549
  face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
550
 
551
  # remove face neighbor triangles
552
+ BNI_object.F_B_trimesh = part_removal(
553
+ BNI_object.F_B_trimesh,
554
+ face_mesh,
555
+ cfg.bni.face_thres,
556
+ device,
557
+ smplx_mesh,
558
+ region="face"
559
+ )
560
+ side_mesh = part_removal(
561
+ side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face"
562
+ )
563
  face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
564
  full_lst += [face_mesh]
565
 
566
  if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]):
567
 
568
+ hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], )
569
  if data['hands_visibility'][idx][0]:
570
+ hand_mask.index_fill_(
571
+ 0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0
572
+ )
573
  if data['hands_visibility'][idx][1]:
574
+ hand_mask.index_fill_(
575
+ 0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0
576
+ )
577
 
578
  # only hands
579
  hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
580
 
581
  # remove hand neighbor triangles
582
+ BNI_object.F_B_trimesh = part_removal(
583
+ BNI_object.F_B_trimesh,
584
+ hand_mesh,
585
+ cfg.bni.hand_thres,
586
+ device,
587
+ smplx_mesh,
588
+ region="hand"
589
+ )
590
+ side_mesh = part_removal(
591
+ side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
592
+ )
593
  hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
594
  full_lst += [hand_mesh]
595
 
596
  full_lst += [BNI_object.F_B_trimesh]
597
 
598
  # initial side_mesh could be SMPLX or IF-net
599
+ side_mesh = part_removal(
600
+ side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
601
+ )
602
 
603
  full_lst += [side_mesh]
604
 
605
  # # export intermediate meshes
606
+ BNI_object.F_B_trimesh.export(
607
+ f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
608
+ )
609
  side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj")
610
 
611
  if cfg.bni.use_poisson:
 
622
  rotate_recon_lst = dataset.render.get_image(cam_type="four")
623
  per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
624
 
625
+ if cfg.bni.texture_src == 'image':
626
+
627
+ # coloring the final mesh (front: RGB pixels, back: normal colors)
628
+ final_colors = query_color(
629
+ torch.tensor(final_mesh.vertices).float(),
630
+ torch.tensor(final_mesh.faces).long(),
631
+ in_tensor["image"][idx:idx + 1],
632
+ device=device,
633
+ )
634
+ final_mesh.visual.vertex_colors = final_colors
635
+ final_mesh.export(final_path)
636
+
637
+ elif cfg.bni.texture_src == 'SD':
638
+
639
+ # !TODO: add texture from Stable Diffusion
640
+ pass
641
 
642
  # for video rendering
643
  in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
apps/multi_render.py CHANGED
@@ -20,6 +20,4 @@ faces_lst = in_tensor["body_faces"] + in_tensor["BNI_faces"]
20
 
21
  # self-rotated video
22
  render.load_meshes(verts_lst, faces_lst)
23
- render.get_rendered_video_multi(
24
- in_tensor,
25
- f"{root}/{args.name}_cloth.mp4")
 
20
 
21
  # self-rotated video
22
  render.load_meshes(verts_lst, faces_lst)
23
+ render.get_rendered_video_multi(in_tensor, f"{root}/{args.name}_cloth.mp4")
 
 
configs/econ.yaml CHANGED
@@ -35,3 +35,4 @@ bni:
35
  face_thres: 6e-2
36
  thickness: 0.02
37
  hps_type: "pixie"
 
 
35
  face_thres: 6e-2
36
  thickness: 0.02
37
  hps_type: "pixie"
38
+ texture_src: "SD"
docs/installation.md CHANGED
@@ -9,12 +9,11 @@ cd ECON
9
 
10
  ## Environment
11
 
12
- - Ubuntu 20 / 18
13
- - GCC=7 (required by [pypoisson](https://github.com/mmolero/pypoisson/issues/13))
14
  - **CUDA=11.4, GPU Memory > 12GB**
15
  - Python = 3.8
16
  - PyTorch >= 1.13.0 (official [Get Started](https://pytorch.org/get-started/locally/))
17
- - CUPY >= 11.3.0 (offcial [Installation](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi))
18
  - PyTorch3D (official [INSTALL.md](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), recommend [install-from-local-clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone))
19
 
20
  ```bash
 
9
 
10
  ## Environment
11
 
12
+ - Ubuntu 20 / 18, (Windows as well, see [issue#7](https://github.com/YuliangXiu/ECON/issues/7))
 
13
  - **CUDA=11.4, GPU Memory > 12GB**
14
  - Python = 3.8
15
  - PyTorch >= 1.13.0 (official [Get Started](https://pytorch.org/get-started/locally/))
16
+ - Cupy >= 11.3.0 (offcial [Installation](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi))
17
  - PyTorch3D (official [INSTALL.md](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), recommend [install-from-local-clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone))
18
 
19
  ```bash
lib/common/BNI.py CHANGED
@@ -1,12 +1,12 @@
1
- from lib.common.BNI_utils import (verts_inverse_transform, depth_inverse_transform,
2
- double_side_bilateral_normal_integration)
 
3
 
4
  import torch
5
  import trimesh
6
 
7
 
8
  class BNI:
9
-
10
  def __init__(self, dir_path, name, BNI_dict, cfg, device):
11
 
12
  self.scale = 256.0
@@ -64,22 +64,20 @@ class BNI:
64
 
65
  F_B_verts = torch.cat((F_verts, B_verts), dim=0)
66
  F_B_faces = torch.cat(
67
- (bni_result["F_faces"], bni_result["B_faces"] + bni_result["F_faces"].max() + 1), dim=0)
 
68
 
69
- self.F_B_trimesh = trimesh.Trimesh(F_B_verts.float(),
70
- F_B_faces.long(),
71
- process=False,
72
- maintain_order=True)
73
 
74
- self.F_trimesh = trimesh.Trimesh(F_verts.float(),
75
- bni_result["F_faces"].long(),
76
- process=False,
77
- maintain_order=True)
78
 
79
- self.B_trimesh = trimesh.Trimesh(B_verts.float(),
80
- bni_result["B_faces"].long(),
81
- process=False,
82
- maintain_order=True)
83
 
84
 
85
  if __name__ == "__main__":
@@ -93,16 +91,18 @@ if __name__ == "__main__":
93
  bni_dict = np.load(npy_file, allow_pickle=True).item()
94
 
95
  default_cfg = {'k': 2, 'lambda1': 1e-4, 'boundary_consist': 1e-6}
96
-
97
  # for k in [1, 2, 4, 10, 100]:
98
  # default_cfg['k'] = k
99
  # for k in [1e-8, 1e-4, 1e-2, 1e-1, 1]:
100
- # default_cfg['lambda1'] = k
101
  # for k in [1e-4, 1e-2, 0]:
102
- # default_cfg['boundary_consist'] = k
103
-
104
- bni_object = BNI(osp.dirname(npy_file), osp.basename(npy_file), bni_dict, default_cfg,
105
- torch.device('cuda:0'))
 
 
106
 
107
  bni_object.extract_surface()
108
  bni_object.F_trimesh.export(osp.join(osp.dirname(npy_file), "F.obj"))
 
1
+ from lib.common.BNI_utils import (
2
+ verts_inverse_transform, depth_inverse_transform, double_side_bilateral_normal_integration
3
+ )
4
 
5
  import torch
6
  import trimesh
7
 
8
 
9
  class BNI:
 
10
  def __init__(self, dir_path, name, BNI_dict, cfg, device):
11
 
12
  self.scale = 256.0
 
64
 
65
  F_B_verts = torch.cat((F_verts, B_verts), dim=0)
66
  F_B_faces = torch.cat(
67
+ (bni_result["F_faces"], bni_result["B_faces"] + bni_result["F_faces"].max() + 1), dim=0
68
+ )
69
 
70
+ self.F_B_trimesh = trimesh.Trimesh(
71
+ F_B_verts.float(), F_B_faces.long(), process=False, maintain_order=True
72
+ )
 
73
 
74
+ self.F_trimesh = trimesh.Trimesh(
75
+ F_verts.float(), bni_result["F_faces"].long(), process=False, maintain_order=True
76
+ )
 
77
 
78
+ self.B_trimesh = trimesh.Trimesh(
79
+ B_verts.float(), bni_result["B_faces"].long(), process=False, maintain_order=True
80
+ )
 
81
 
82
 
83
  if __name__ == "__main__":
 
91
  bni_dict = np.load(npy_file, allow_pickle=True).item()
92
 
93
  default_cfg = {'k': 2, 'lambda1': 1e-4, 'boundary_consist': 1e-6}
94
+
95
  # for k in [1, 2, 4, 10, 100]:
96
  # default_cfg['k'] = k
97
  # for k in [1e-8, 1e-4, 1e-2, 1e-1, 1]:
98
+ # default_cfg['lambda1'] = k
99
  # for k in [1e-4, 1e-2, 0]:
100
+ # default_cfg['boundary_consist'] = k
101
+
102
+ bni_object = BNI(
103
+ osp.dirname(npy_file), osp.basename(npy_file), bni_dict, default_cfg,
104
+ torch.device('cuda:0')
105
+ )
106
 
107
  bni_object.extract_surface()
108
  bni_object.F_trimesh.export(osp.join(osp.dirname(npy_file), "F.obj"))
lib/common/BNI_utils.py CHANGED
@@ -53,8 +53,9 @@ def find_contour(mask, method='all'):
53
 
54
  contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
55
  else:
56
- contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE,
57
- cv2.CHAIN_APPROX_SIMPLE)
 
58
 
59
  contour_cloth = np.array(find_max_list(contours))[:, 0, :]
60
 
@@ -67,16 +68,19 @@ def mean_value_cordinates(inner_pts, contour_pts):
67
  body_edges_c = np.roll(body_edges_a, shift=-1, axis=1)
68
  body_edges_b = np.sqrt(((contour_pts - np.roll(contour_pts, shift=-1, axis=0))**2).sum(axis=1))
69
 
70
- body_edges = np.concatenate([
71
- body_edges_a[..., None], body_edges_c[..., None],
72
- np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts))
73
- ],
74
- axis=-1)
 
 
75
 
76
  body_cos = (body_edges[:, :, 0]**2 + body_edges[:, :, 1]**2 -
77
  body_edges[:, :, 2]**2) / (2 * body_edges[:, :, 0] * body_edges[:, :, 1])
78
  body_tan_half = np.sqrt(
79
- (1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.))
 
80
 
81
  w = (body_tan_half + np.roll(body_tan_half, shift=1, axis=1)) / body_edges_a
82
  w /= w.sum(axis=1, keepdims=True)
@@ -97,16 +101,18 @@ def dispCorres(img_size, contour1, contour2, phi, dir_path):
97
  contour2 = contour2[None, :, None, :].astype(np.int32)
98
 
99
  disp = np.zeros((img_size, img_size, 3), dtype=np.uint8)
100
- cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1) # green
101
- cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1) # blue
102
 
103
- for i in range(contour1.shape[1]): # do not show all the points when display
104
  # cv2.circle(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), 1,
105
  # (255, 0, 0), -1)
106
  corresPoint = contour2[0, phi[i], 0]
107
  # cv2.circle(disp, (corresPoint[0], corresPoint[1]), 1, (0, 255, 0), -1)
108
- cv2.line(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]),
109
- (corresPoint[0], corresPoint[1]), (255, 255, 255), 1)
 
 
110
 
111
  cv2.imwrite(osp.join(dir_path, "corres.png"), disp)
112
 
@@ -162,7 +168,8 @@ def verts_transform(t, depth_scale):
162
  t_copy *= depth_scale * 0.5
163
  t_copy += depth_scale * 0.5
164
  t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor(
165
- [0.0, 0.0, depth_scale])
 
166
 
167
  return t_copy
168
 
@@ -328,19 +335,22 @@ def construct_facets_from(mask):
328
  facet_move_top_mask = move_top(mask)
329
  facet_move_left_mask = move_left(mask)
330
  facet_move_top_left_mask = move_top_left(mask)
331
- facet_top_left_mask = (facet_move_top_mask * facet_move_left_mask * facet_move_top_left_mask *
332
- mask)
 
333
  facet_top_right_mask = move_right(facet_top_left_mask)
334
  facet_bottom_left_mask = move_bottom(facet_top_left_mask)
335
  facet_bottom_right_mask = move_bottom_right(facet_top_left_mask)
336
 
337
- return cp.hstack((
338
- 4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)),
339
- idx[facet_top_left_mask][:, None],
340
- idx[facet_bottom_left_mask][:, None],
341
- idx[facet_bottom_right_mask][:, None],
342
- idx[facet_top_right_mask][:, None],
343
- )).astype(int)
 
 
344
 
345
 
346
  def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
@@ -364,8 +374,8 @@ def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
364
  u[..., 0] = xx
365
  u[..., 1] = yy
366
  u[..., 2] = 1
367
- u = u[mask].T # 3 x m
368
- vertices = (cp.linalg.inv(K) @ u).T * depth_map[mask, cp.newaxis] # m x 3
369
 
370
  return vertices
371
 
@@ -374,7 +384,6 @@ def sigmoid(x, k=1):
374
  return 1 / (1 + cp.exp(-k * x))
375
 
376
 
377
-
378
  def boundary_excluded_mask(mask):
379
  top_mask = cp.pad(mask, ((1, 0), (0, 0)), "constant", constant_values=0)[:-1, :]
380
  bottom_mask = cp.pad(mask, ((0, 1), (0, 0)), "constant", constant_values=0)[1:, :]
@@ -410,22 +419,24 @@ def create_boundary_matrix(mask):
410
  return B, B_full
411
 
412
 
413
- def double_side_bilateral_normal_integration(normal_front,
414
- normal_back,
415
- normal_mask,
416
- depth_front=None,
417
- depth_back=None,
418
- depth_mask=None,
419
- k=2,
420
- lambda_normal_back=1,
421
- lambda_depth_front=1e-4,
422
- lambda_depth_back=1e-2,
423
- lambda_boundary_consistency=1,
424
- step_size=1,
425
- max_iter=150,
426
- tol=1e-4,
427
- cg_max_iter=5000,
428
- cg_tol=1e-3):
 
 
429
 
430
  # To avoid confusion, we list the coordinate systems in this code as follows
431
  #
@@ -467,14 +478,12 @@ def double_side_bilateral_normal_integration(normal_front,
467
  del normal_map_back
468
 
469
  # right, left, top, bottom
470
- A3_f, A4_f, A1_f, A2_f = generate_dx_dy(normal_mask,
471
- nz_horizontal=nz_front,
472
- nz_vertical=nz_front,
473
- step_size=step_size)
474
- A3_b, A4_b, A1_b, A2_b = generate_dx_dy(normal_mask,
475
- nz_horizontal=nz_back,
476
- nz_vertical=nz_back,
477
- step_size=step_size)
478
 
479
  has_left_mask = cp.logical_and(move_right(normal_mask), normal_mask)
480
  has_right_mask = cp.logical_and(move_left(normal_mask), normal_mask)
@@ -498,29 +507,25 @@ def double_side_bilateral_normal_integration(normal_front,
498
  b_back = cp.concatenate((-nx_back, -nx_back, -ny_back, -ny_back))
499
 
500
  # initialization
501
- W_front = spdiags(0.5 * cp.ones(4 * num_normals),
502
- 0,
503
- 4 * num_normals,
504
- 4 * num_normals,
505
- format="csr")
506
- W_back = spdiags(0.5 * cp.ones(4 * num_normals),
507
- 0,
508
- 4 * num_normals,
509
- 4 * num_normals,
510
- format="csr")
511
 
512
  z_front = cp.zeros(num_normals, float)
513
  z_back = cp.zeros(num_normals, float)
514
  z_combined = cp.concatenate((z_front, z_back))
515
 
516
  B, B_full = create_boundary_matrix(normal_mask)
517
- B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get()) #bug
518
 
519
  energy_list = []
520
 
521
  if depth_mask is not None:
522
- depth_mask_flat = depth_mask[normal_mask].astype(bool) # shape: (num_normals,)
523
- z_prior_front = depth_map_front[normal_mask] # shape: (num_normals,)
524
  z_prior_front[~depth_mask_flat] = 0
525
  z_prior_back = depth_map_back[normal_mask]
526
  z_prior_back[~depth_mask_flat] = 0
@@ -554,40 +559,43 @@ def double_side_bilateral_normal_integration(normal_front,
554
  vstack((csr_matrix((num_normals, num_normals)), A_mat_back))]) + B_mat
555
  b_vec_combined = cp.concatenate((b_vec_front, b_vec_back))
556
 
557
- D = spdiags(1 / cp.clip(A_mat_combined.diagonal(), 1e-5, None), 0, 2 * num_normals,
558
- 2 * num_normals, "csr") # Jacob preconditioner
 
 
559
 
560
- z_combined, _ = cg(A_mat_combined,
561
- b_vec_combined,
562
- M=D,
563
- x0=z_combined,
564
- maxiter=cg_max_iter,
565
- tol=cg_tol)
566
  z_front = z_combined[:num_normals]
567
  z_back = z_combined[num_normals:]
568
- wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k) # top
569
- wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k) # right
570
  wu_f[top_boundnary_mask] = 0.5
571
  wu_f[bottom_boundary_mask] = 0.5
572
  wv_f[left_boundary_mask] = 0.5
573
  wv_f[right_boudnary_mask] = 0.5
574
- W_front = spdiags(cp.concatenate((wu_f, 1 - wu_f, wv_f, 1 - wv_f)),
575
- 0,
576
- 4 * num_normals,
577
- 4 * num_normals,
578
- format="csr")
579
-
580
- wu_b = sigmoid((A2_b.dot(z_back))**2 - (A1_b.dot(z_back))**2, k) # top
581
- wv_b = sigmoid((A4_b.dot(z_back))**2 - (A3_b.dot(z_back))**2, k) # right
 
 
582
  wu_b[top_boundnary_mask] = 0.5
583
  wu_b[bottom_boundary_mask] = 0.5
584
  wv_b[left_boundary_mask] = 0.5
585
  wv_b[right_boudnary_mask] = 0.5
586
- W_back = spdiags(cp.concatenate((wu_b, 1 - wu_b, wv_b, 1 - wv_b)),
587
- 0,
588
- 4 * num_normals,
589
- 4 * num_normals,
590
- format="csr")
 
 
591
 
592
  energy_old = energy
593
  energy = (A_front_data @ z_front - b_front).T @ W_front @ (A_front_data @ z_front - b_front) + \
@@ -603,23 +611,26 @@ def double_side_bilateral_normal_integration(normal_front,
603
  if relative_energy < tol:
604
  break
605
  # del A1, A2, A3, A4, nx, ny
606
-
607
  depth_map_front_est = cp.ones_like(normal_mask, float) * cp.nan
608
  depth_map_front_est[normal_mask] = z_front
609
 
610
  depth_map_back_est = cp.ones_like(normal_mask, float) * cp.nan
611
  depth_map_back_est[normal_mask] = z_back
612
-
613
  # manually cut the intersection
614
- normal_mask[depth_map_front_est>=depth_map_back_est] = False
615
  depth_map_front_est[~normal_mask] = cp.nan
616
  depth_map_back_est[~normal_mask] = cp.nan
617
 
618
  vertices_front = cp.asnumpy(
619
- map_depth_map_to_point_clouds(depth_map_front_est, normal_mask, K=None,
620
- step_size=step_size))
 
 
621
  vertices_back = cp.asnumpy(
622
- map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size))
 
623
 
624
  facets_back = cp.asnumpy(construct_facets_from(normal_mask))
625
 
@@ -656,7 +667,7 @@ def save_normal_tensor(in_tensor, idx, png_path, thickness=0.0):
656
  depth_B_arr = depth2arr(in_tensor["depth_B"][idx])
657
 
658
  BNI_dict = {}
659
-
660
  # clothed human
661
  BNI_dict["normal_F"] = normal_F_arr
662
  BNI_dict["normal_B"] = normal_B_arr
 
53
 
54
  contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
55
  else:
56
+ contours, _ = cv2.findContours(
57
+ mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
58
+ )
59
 
60
  contour_cloth = np.array(find_max_list(contours))[:, 0, :]
61
 
 
68
  body_edges_c = np.roll(body_edges_a, shift=-1, axis=1)
69
  body_edges_b = np.sqrt(((contour_pts - np.roll(contour_pts, shift=-1, axis=0))**2).sum(axis=1))
70
 
71
+ body_edges = np.concatenate(
72
+ [
73
+ body_edges_a[..., None], body_edges_c[..., None],
74
+ np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts))
75
+ ],
76
+ axis=-1
77
+ )
78
 
79
  body_cos = (body_edges[:, :, 0]**2 + body_edges[:, :, 1]**2 -
80
  body_edges[:, :, 2]**2) / (2 * body_edges[:, :, 0] * body_edges[:, :, 1])
81
  body_tan_half = np.sqrt(
82
+ (1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.)
83
+ )
84
 
85
  w = (body_tan_half + np.roll(body_tan_half, shift=1, axis=1)) / body_edges_a
86
  w /= w.sum(axis=1, keepdims=True)
 
101
  contour2 = contour2[None, :, None, :].astype(np.int32)
102
 
103
  disp = np.zeros((img_size, img_size, 3), dtype=np.uint8)
104
+ cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1) # green
105
+ cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1) # blue
106
 
107
+ for i in range(contour1.shape[1]): # do not show all the points when display
108
  # cv2.circle(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), 1,
109
  # (255, 0, 0), -1)
110
  corresPoint = contour2[0, phi[i], 0]
111
  # cv2.circle(disp, (corresPoint[0], corresPoint[1]), 1, (0, 255, 0), -1)
112
+ cv2.line(
113
+ disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), (corresPoint[0], corresPoint[1]),
114
+ (255, 255, 255), 1
115
+ )
116
 
117
  cv2.imwrite(osp.join(dir_path, "corres.png"), disp)
118
 
 
168
  t_copy *= depth_scale * 0.5
169
  t_copy += depth_scale * 0.5
170
  t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor(
171
+ [0.0, 0.0, depth_scale]
172
+ )
173
 
174
  return t_copy
175
 
 
335
  facet_move_top_mask = move_top(mask)
336
  facet_move_left_mask = move_left(mask)
337
  facet_move_top_left_mask = move_top_left(mask)
338
+ facet_top_left_mask = (
339
+ facet_move_top_mask * facet_move_left_mask * facet_move_top_left_mask * mask
340
+ )
341
  facet_top_right_mask = move_right(facet_top_left_mask)
342
  facet_bottom_left_mask = move_bottom(facet_top_left_mask)
343
  facet_bottom_right_mask = move_bottom_right(facet_top_left_mask)
344
 
345
+ return cp.hstack(
346
+ (
347
+ 4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)),
348
+ idx[facet_top_left_mask][:, None],
349
+ idx[facet_bottom_left_mask][:, None],
350
+ idx[facet_bottom_right_mask][:, None],
351
+ idx[facet_top_right_mask][:, None],
352
+ )
353
+ ).astype(int)
354
 
355
 
356
  def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
 
374
  u[..., 0] = xx
375
  u[..., 1] = yy
376
  u[..., 2] = 1
377
+ u = u[mask].T # 3 x m
378
+ vertices = (cp.linalg.inv(K) @ u).T * depth_map[mask, cp.newaxis] # m x 3
379
 
380
  return vertices
381
 
 
384
  return 1 / (1 + cp.exp(-k * x))
385
 
386
 
 
387
  def boundary_excluded_mask(mask):
388
  top_mask = cp.pad(mask, ((1, 0), (0, 0)), "constant", constant_values=0)[:-1, :]
389
  bottom_mask = cp.pad(mask, ((0, 1), (0, 0)), "constant", constant_values=0)[1:, :]
 
419
  return B, B_full
420
 
421
 
422
+ def double_side_bilateral_normal_integration(
423
+ normal_front,
424
+ normal_back,
425
+ normal_mask,
426
+ depth_front=None,
427
+ depth_back=None,
428
+ depth_mask=None,
429
+ k=2,
430
+ lambda_normal_back=1,
431
+ lambda_depth_front=1e-4,
432
+ lambda_depth_back=1e-2,
433
+ lambda_boundary_consistency=1,
434
+ step_size=1,
435
+ max_iter=150,
436
+ tol=1e-4,
437
+ cg_max_iter=5000,
438
+ cg_tol=1e-3
439
+ ):
440
 
441
  # To avoid confusion, we list the coordinate systems in this code as follows
442
  #
 
478
  del normal_map_back
479
 
480
  # right, left, top, bottom
481
+ A3_f, A4_f, A1_f, A2_f = generate_dx_dy(
482
+ normal_mask, nz_horizontal=nz_front, nz_vertical=nz_front, step_size=step_size
483
+ )
484
+ A3_b, A4_b, A1_b, A2_b = generate_dx_dy(
485
+ normal_mask, nz_horizontal=nz_back, nz_vertical=nz_back, step_size=step_size
486
+ )
 
 
487
 
488
  has_left_mask = cp.logical_and(move_right(normal_mask), normal_mask)
489
  has_right_mask = cp.logical_and(move_left(normal_mask), normal_mask)
 
507
  b_back = cp.concatenate((-nx_back, -nx_back, -ny_back, -ny_back))
508
 
509
  # initialization
510
+ W_front = spdiags(
511
+ 0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr"
512
+ )
513
+ W_back = spdiags(
514
+ 0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr"
515
+ )
 
 
 
 
516
 
517
  z_front = cp.zeros(num_normals, float)
518
  z_back = cp.zeros(num_normals, float)
519
  z_combined = cp.concatenate((z_front, z_back))
520
 
521
  B, B_full = create_boundary_matrix(normal_mask)
522
+ B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get()) #bug
523
 
524
  energy_list = []
525
 
526
  if depth_mask is not None:
527
+ depth_mask_flat = depth_mask[normal_mask].astype(bool) # shape: (num_normals,)
528
+ z_prior_front = depth_map_front[normal_mask] # shape: (num_normals,)
529
  z_prior_front[~depth_mask_flat] = 0
530
  z_prior_back = depth_map_back[normal_mask]
531
  z_prior_back[~depth_mask_flat] = 0
 
559
  vstack((csr_matrix((num_normals, num_normals)), A_mat_back))]) + B_mat
560
  b_vec_combined = cp.concatenate((b_vec_front, b_vec_back))
561
 
562
+ D = spdiags(
563
+ 1 / cp.clip(A_mat_combined.diagonal(), 1e-5, None), 0, 2 * num_normals, 2 * num_normals,
564
+ "csr"
565
+ ) # Jacob preconditioner
566
 
567
+ z_combined, _ = cg(
568
+ A_mat_combined, b_vec_combined, M=D, x0=z_combined, maxiter=cg_max_iter, tol=cg_tol
569
+ )
 
 
 
570
  z_front = z_combined[:num_normals]
571
  z_back = z_combined[num_normals:]
572
+ wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k) # top
573
+ wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k) # right
574
  wu_f[top_boundnary_mask] = 0.5
575
  wu_f[bottom_boundary_mask] = 0.5
576
  wv_f[left_boundary_mask] = 0.5
577
  wv_f[right_boudnary_mask] = 0.5
578
+ W_front = spdiags(
579
+ cp.concatenate((wu_f, 1 - wu_f, wv_f, 1 - wv_f)),
580
+ 0,
581
+ 4 * num_normals,
582
+ 4 * num_normals,
583
+ format="csr"
584
+ )
585
+
586
+ wu_b = sigmoid((A2_b.dot(z_back))**2 - (A1_b.dot(z_back))**2, k) # top
587
+ wv_b = sigmoid((A4_b.dot(z_back))**2 - (A3_b.dot(z_back))**2, k) # right
588
  wu_b[top_boundnary_mask] = 0.5
589
  wu_b[bottom_boundary_mask] = 0.5
590
  wv_b[left_boundary_mask] = 0.5
591
  wv_b[right_boudnary_mask] = 0.5
592
+ W_back = spdiags(
593
+ cp.concatenate((wu_b, 1 - wu_b, wv_b, 1 - wv_b)),
594
+ 0,
595
+ 4 * num_normals,
596
+ 4 * num_normals,
597
+ format="csr"
598
+ )
599
 
600
  energy_old = energy
601
  energy = (A_front_data @ z_front - b_front).T @ W_front @ (A_front_data @ z_front - b_front) + \
 
611
  if relative_energy < tol:
612
  break
613
  # del A1, A2, A3, A4, nx, ny
614
+
615
  depth_map_front_est = cp.ones_like(normal_mask, float) * cp.nan
616
  depth_map_front_est[normal_mask] = z_front
617
 
618
  depth_map_back_est = cp.ones_like(normal_mask, float) * cp.nan
619
  depth_map_back_est[normal_mask] = z_back
620
+
621
  # manually cut the intersection
622
+ normal_mask[depth_map_front_est >= depth_map_back_est] = False
623
  depth_map_front_est[~normal_mask] = cp.nan
624
  depth_map_back_est[~normal_mask] = cp.nan
625
 
626
  vertices_front = cp.asnumpy(
627
+ map_depth_map_to_point_clouds(
628
+ depth_map_front_est, normal_mask, K=None, step_size=step_size
629
+ )
630
+ )
631
  vertices_back = cp.asnumpy(
632
+ map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size)
633
+ )
634
 
635
  facets_back = cp.asnumpy(construct_facets_from(normal_mask))
636
 
 
667
  depth_B_arr = depth2arr(in_tensor["depth_B"][idx])
668
 
669
  BNI_dict = {}
670
+
671
  # clothed human
672
  BNI_dict["normal_F"] = normal_F_arr
673
  BNI_dict["normal_B"] = normal_B_arr
lib/common/blender_utils.py CHANGED
@@ -3,6 +3,7 @@ import sys, os
3
  from math import radians
4
  import mathutils
5
  import bmesh
 
6
  print(sys.exec_prefix)
7
  from tqdm import tqdm
8
  import numpy as np
@@ -29,7 +30,6 @@ shadows = False
29
  # diffuse_color = (18/255., 139/255., 142/255.,1) #correct
30
  # diffuse_color = (251/255., 60/255., 60/255.,1) #wrong
31
 
32
-
33
  smooth = False
34
 
35
  wireframe = False
@@ -47,13 +47,16 @@ compositor_alpha = 0.7
47
  # Helper functions
48
  ##################################################
49
 
 
50
  def blender_print(*args, **kwargs):
51
- print (*args, **kwargs, file=sys.stderr)
 
52
 
53
  def using_app():
54
  ''' Returns if script is running through Blender application (GUI or background processing)'''
55
  return (not sys.argv[0].endswith('.py'))
56
 
 
57
  def setup_diffuse_transparent_material(target, color, object_transparent, backface_transparent):
58
  ''' Sets up diffuse/transparent material with backface culling in cycles'''
59
 
@@ -110,8 +113,10 @@ def setup_diffuse_transparent_material(target, color, object_transparent, backfa
110
  links.new(node_mix_backface.outputs[0], node_output.inputs[0])
111
  return
112
 
 
113
  ##################################################
114
 
 
115
  def setup_scene():
116
  global render
117
  global cycles_gpu
@@ -150,12 +155,13 @@ def setup_scene():
150
  if cycles_gpu:
151
  print('Activating GPU acceleration')
152
  bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
153
-
154
  if bpy.app.version[0] >= 3:
155
- cuda_devices = bpy.context.preferences.addons['cycles'].preferences.get_devices_for_type(compute_device_type = 'CUDA')
 
156
  else:
157
- (cuda_devices, opencl_devices) = bpy.context.preferences.addons['cycles'].preferences.get_devices()
158
-
159
 
160
  if (len(cuda_devices) < 1):
161
  print('ERROR: CUDA GPU acceleration not available')
@@ -178,7 +184,7 @@ def setup_scene():
178
  if bpy.app.version[0] < 3:
179
  scene.render.tile_x = 64
180
  scene.render.tile_y = 64
181
-
182
  # Disable Blender 3 denoiser to properly measure Cycles render speed
183
  if bpy.app.version[0] >= 3:
184
  scene.cycles.use_denoising = False
@@ -226,7 +232,6 @@ def setup_scene():
226
  bpy.ops.mesh.mark_freestyle_edge(clear=True)
227
  bpy.ops.object.mode_set(mode='OBJECT')
228
 
229
-
230
  # Setup freestyle mode for wireframe overlay rendering
231
  if wireframe:
232
  scene.render.use_freestyle = True
@@ -245,8 +250,10 @@ def setup_scene():
245
  # Output transparent image when no background is used
246
  scene.render.image_settings.color_mode = 'RGBA'
247
 
 
248
  ##################################################
249
 
 
250
  def setup_compositing():
251
 
252
  global compositor_image_scale
@@ -275,6 +282,7 @@ def setup_compositing():
275
 
276
  links.new(blend_node.outputs[0], tree.nodes['Composite'].inputs[0])
277
 
 
278
  def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
279
  '''Render image of given model file'''
280
  global smooth
@@ -288,13 +296,13 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
288
  # Import object into scene
289
  bpy.ops.import_scene.obj(filepath=path)
290
  object = bpy.context.selected_objects[0]
291
-
292
  object.rotation_euler = (radians(90.0), 0.0, radians(yaw))
293
- z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:,1])
294
  # z_top = np.max(np.array([vert.co for vert in object.data.vertices])[:,1])
295
  # blender_print(radians(90.0), z_bottom, z_top)
296
  object.location -= mathutils.Vector((0.0, 0.0, z_bottom))
297
-
298
  if quads:
299
  bpy.context.view_layer.objects.active = object
300
  bpy.ops.object.mode_set(mode='EDIT')
@@ -309,11 +317,11 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
309
  bpy.ops.object.mode_set(mode='EDIT')
310
  bpy.ops.mesh.mark_freestyle_edge(clear=False)
311
  bpy.ops.object.mode_set(mode='OBJECT')
312
-
313
  if correct:
314
- diffuse_color = (18/255., 139/255., 142/255.,1) #correct
315
  else:
316
- diffuse_color = (251/255., 60/255., 60/255.,1) #wrong
317
 
318
  setup_diffuse_transparent_material(object, diffuse_color, object_transparent, mouth_transparent)
319
 
@@ -336,10 +344,10 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
336
  bpy.ops.render.render(write_still=True)
337
 
338
  # Remove temporary output redirection
339
- # sys.stdout.flush()
340
- # os.close(1)
341
- # os.dup(old)
342
- # os.close(old)
343
 
344
  # Delete last selected object from scene
345
  object.select_set(True)
@@ -351,7 +359,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True):
351
  global quality_preview
352
 
353
  if not input_file.endswith('.obj'):
354
- print('ERROR: Invalid input: ' + input_file )
355
  return
356
 
357
  print('Processing: ' + input_file)
@@ -361,7 +369,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True):
361
  if quality_preview:
362
  output_file = output_file.replace('.png', '-preview.png')
363
 
364
- angle = 360.0/views
365
  pbar = tqdm(range(0, views))
366
  for view in pbar:
367
  pbar.set_description(f"{os.path.basename(output_file)} | View:{str(view)}")
@@ -369,8 +377,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True):
369
  output_file_view = f"{output_file}/{view:03d}.png"
370
  if not os.path.exists(os.path.join(output_dir, output_file_view)):
371
  render_file(input_file, input_dir, output_file_view, output_dir, yaw, correct)
372
-
373
  cmd = "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 -i " + os.path.join(output_dir, output_file, '%3d.png') + \
374
  " -shortest -filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" -map \"[out]\" -y " + output_dir+"/"+output_file+".mp4"
375
  os.system(cmd)
376
-
 
3
  from math import radians
4
  import mathutils
5
  import bmesh
6
+
7
  print(sys.exec_prefix)
8
  from tqdm import tqdm
9
  import numpy as np
 
30
  # diffuse_color = (18/255., 139/255., 142/255.,1) #correct
31
  # diffuse_color = (251/255., 60/255., 60/255.,1) #wrong
32
 
 
33
  smooth = False
34
 
35
  wireframe = False
 
47
  # Helper functions
48
  ##################################################
49
 
50
+
51
  def blender_print(*args, **kwargs):
52
+ print(*args, **kwargs, file=sys.stderr)
53
+
54
 
55
  def using_app():
56
  ''' Returns if script is running through Blender application (GUI or background processing)'''
57
  return (not sys.argv[0].endswith('.py'))
58
 
59
+
60
  def setup_diffuse_transparent_material(target, color, object_transparent, backface_transparent):
61
  ''' Sets up diffuse/transparent material with backface culling in cycles'''
62
 
 
113
  links.new(node_mix_backface.outputs[0], node_output.inputs[0])
114
  return
115
 
116
+
117
  ##################################################
118
 
119
+
120
  def setup_scene():
121
  global render
122
  global cycles_gpu
 
155
  if cycles_gpu:
156
  print('Activating GPU acceleration')
157
  bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
158
+
159
  if bpy.app.version[0] >= 3:
160
+ cuda_devices = bpy.context.preferences.addons[
161
+ 'cycles'].preferences.get_devices_for_type(compute_device_type='CUDA')
162
  else:
163
+ (cuda_devices, opencl_devices
164
+ ) = bpy.context.preferences.addons['cycles'].preferences.get_devices()
165
 
166
  if (len(cuda_devices) < 1):
167
  print('ERROR: CUDA GPU acceleration not available')
 
184
  if bpy.app.version[0] < 3:
185
  scene.render.tile_x = 64
186
  scene.render.tile_y = 64
187
+
188
  # Disable Blender 3 denoiser to properly measure Cycles render speed
189
  if bpy.app.version[0] >= 3:
190
  scene.cycles.use_denoising = False
 
232
  bpy.ops.mesh.mark_freestyle_edge(clear=True)
233
  bpy.ops.object.mode_set(mode='OBJECT')
234
 
 
235
  # Setup freestyle mode for wireframe overlay rendering
236
  if wireframe:
237
  scene.render.use_freestyle = True
 
250
  # Output transparent image when no background is used
251
  scene.render.image_settings.color_mode = 'RGBA'
252
 
253
+
254
  ##################################################
255
 
256
+
257
  def setup_compositing():
258
 
259
  global compositor_image_scale
 
282
 
283
  links.new(blend_node.outputs[0], tree.nodes['Composite'].inputs[0])
284
 
285
+
286
  def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
287
  '''Render image of given model file'''
288
  global smooth
 
296
  # Import object into scene
297
  bpy.ops.import_scene.obj(filepath=path)
298
  object = bpy.context.selected_objects[0]
299
+
300
  object.rotation_euler = (radians(90.0), 0.0, radians(yaw))
301
+ z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:, 1])
302
  # z_top = np.max(np.array([vert.co for vert in object.data.vertices])[:,1])
303
  # blender_print(radians(90.0), z_bottom, z_top)
304
  object.location -= mathutils.Vector((0.0, 0.0, z_bottom))
305
+
306
  if quads:
307
  bpy.context.view_layer.objects.active = object
308
  bpy.ops.object.mode_set(mode='EDIT')
 
317
  bpy.ops.object.mode_set(mode='EDIT')
318
  bpy.ops.mesh.mark_freestyle_edge(clear=False)
319
  bpy.ops.object.mode_set(mode='OBJECT')
320
+
321
  if correct:
322
+ diffuse_color = (18 / 255., 139 / 255., 142 / 255., 1) #correct
323
  else:
324
+ diffuse_color = (251 / 255., 60 / 255., 60 / 255., 1) #wrong
325
 
326
  setup_diffuse_transparent_material(object, diffuse_color, object_transparent, mouth_transparent)
327
 
 
344
  bpy.ops.render.render(write_still=True)
345
 
346
  # Remove temporary output redirection
347
+ # sys.stdout.flush()
348
+ # os.close(1)
349
+ # os.dup(old)
350
+ # os.close(old)
351
 
352
  # Delete last selected object from scene
353
  object.select_set(True)
 
359
  global quality_preview
360
 
361
  if not input_file.endswith('.obj'):
362
+ print('ERROR: Invalid input: ' + input_file)
363
  return
364
 
365
  print('Processing: ' + input_file)
 
369
  if quality_preview:
370
  output_file = output_file.replace('.png', '-preview.png')
371
 
372
+ angle = 360.0 / views
373
  pbar = tqdm(range(0, views))
374
  for view in pbar:
375
  pbar.set_description(f"{os.path.basename(output_file)} | View:{str(view)}")
 
377
  output_file_view = f"{output_file}/{view:03d}.png"
378
  if not os.path.exists(os.path.join(output_dir, output_file_view)):
379
  render_file(input_file, input_dir, output_file_view, output_dir, yaw, correct)
380
+
381
  cmd = "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 -i " + os.path.join(output_dir, output_file, '%3d.png') + \
382
  " -shortest -filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" -map \"[out]\" -y " + output_dir+"/"+output_file+".mp4"
383
  os.system(cmd)
 
lib/common/cloth_extraction.py CHANGED
@@ -36,11 +36,13 @@ def load_segmentation(path, shape):
36
  xy = np.vstack((x, y)).T
37
  coordinates.append(xy)
38
 
39
- segmentations.append({
40
- "type": val["category_name"],
41
- "type_id": val["category_id"],
42
- "coordinates": coordinates,
43
- })
 
 
44
 
45
  return segmentations
46
 
@@ -56,9 +58,8 @@ def smpl_to_recon_labels(recon, smpl, k=1):
56
  Returns a dictionary containing the bodypart and the corresponding indices
57
  """
58
  smpl_vert_segmentation = json.load(
59
- open(
60
- os.path.join(os.path.dirname(__file__),
61
- "smpl_vert_segmentation.json")))
62
  n = smpl.vertices.shape[0]
63
  y = np.array([None] * n)
64
  for key, val in smpl_vert_segmentation.items():
@@ -71,8 +72,7 @@ def smpl_to_recon_labels(recon, smpl, k=1):
71
 
72
  recon_labels = {}
73
  for key in smpl_vert_segmentation.keys():
74
- recon_labels[key] = list(
75
- np.argwhere(y_pred == key).flatten().astype(int))
76
 
77
  return recon_labels
78
 
@@ -139,8 +139,7 @@ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
139
  if type == 1 or type == 3 or type == 10:
140
  body_parts_to_remove += ["leftForeArm", "rightForeArm"]
141
  # No sleeves at all or lower body clothes
142
- elif (type == 5 or type == 6 or type == 12 or type == 13 or type == 8
143
- or type == 9):
144
  body_parts_to_remove += [
145
  "leftForeArm",
146
  "rightForeArm",
@@ -159,8 +158,8 @@ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
159
  ]
160
 
161
  verts_to_remove = list(
162
- itertools.chain.from_iterable(
163
- [recon_labels[part] for part in body_parts_to_remove]))
164
 
165
  label_mask = np.zeros(num_verts, dtype=bool)
166
  label_mask[verts_to_remove] = True
 
36
  xy = np.vstack((x, y)).T
37
  coordinates.append(xy)
38
 
39
+ segmentations.append(
40
+ {
41
+ "type": val["category_name"],
42
+ "type_id": val["category_id"],
43
+ "coordinates": coordinates,
44
+ }
45
+ )
46
 
47
  return segmentations
48
 
 
58
  Returns a dictionary containing the bodypart and the corresponding indices
59
  """
60
  smpl_vert_segmentation = json.load(
61
+ open(os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json"))
62
+ )
 
63
  n = smpl.vertices.shape[0]
64
  y = np.array([None] * n)
65
  for key, val in smpl_vert_segmentation.items():
 
72
 
73
  recon_labels = {}
74
  for key in smpl_vert_segmentation.keys():
75
+ recon_labels[key] = list(np.argwhere(y_pred == key).flatten().astype(int))
 
76
 
77
  return recon_labels
78
 
 
139
  if type == 1 or type == 3 or type == 10:
140
  body_parts_to_remove += ["leftForeArm", "rightForeArm"]
141
  # No sleeves at all or lower body clothes
142
+ elif (type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9):
 
143
  body_parts_to_remove += [
144
  "leftForeArm",
145
  "rightForeArm",
 
158
  ]
159
 
160
  verts_to_remove = list(
161
+ itertools.chain.from_iterable([recon_labels[part] for part in body_parts_to_remove])
162
+ )
163
 
164
  label_mask = np.zeros(num_verts, dtype=bool)
165
  label_mask[verts_to_remove] = True
lib/common/config.py CHANGED
@@ -100,6 +100,7 @@ _C.bni.thickness = 0.00
100
  _C.bni.hand_thres = 4e-2
101
  _C.bni.face_thres = 6e-2
102
  _C.bni.hps_type = "pixie"
 
103
 
104
  # kernel_size, stride, dilation, padding
105
 
@@ -170,10 +171,10 @@ _C.dataset.rp_type = "pifu900"
170
  _C.dataset.th_type = "train"
171
  _C.dataset.input_size = 512
172
  _C.dataset.rotation_num = 3
173
- _C.dataset.num_precomp = 10 # Number of segmentation classifiers
174
- _C.dataset.num_multiseg = 500 # Number of categories per classifier
175
- _C.dataset.num_knn = 10 # for loss/error
176
- _C.dataset.num_knn_dis = 20 # for accuracy
177
  _C.dataset.num_verts_max = 20000
178
  _C.dataset.zray_type = False
179
  _C.dataset.online_smpl = False
@@ -210,8 +211,7 @@ def get_cfg_defaults():
210
 
211
  # Alternatively, provide a way to import the defaults as
212
  # a global singleton:
213
- cfg = _C # users can `from config import cfg`
214
-
215
 
216
  # cfg = get_cfg_defaults()
217
  # cfg.merge_from_file('./configs/example.yaml')
@@ -244,9 +244,7 @@ def parse_args(args):
244
  def parse_args_extend(args):
245
  if args.resume:
246
  if not os.path.exists(args.log_dir):
247
- raise ValueError(
248
- "Experiment are set to resume mode, but log directory does not exist."
249
- )
250
 
251
  # load log's cfg
252
  cfg_file = os.path.join(args.log_dir, "cfg.yaml")
 
100
  _C.bni.hand_thres = 4e-2
101
  _C.bni.face_thres = 6e-2
102
  _C.bni.hps_type = "pixie"
103
+ _C.bni.texture_src = "image"
104
 
105
  # kernel_size, stride, dilation, padding
106
 
 
171
  _C.dataset.th_type = "train"
172
  _C.dataset.input_size = 512
173
  _C.dataset.rotation_num = 3
174
+ _C.dataset.num_precomp = 10 # Number of segmentation classifiers
175
+ _C.dataset.num_multiseg = 500 # Number of categories per classifier
176
+ _C.dataset.num_knn = 10 # for loss/error
177
+ _C.dataset.num_knn_dis = 20 # for accuracy
178
  _C.dataset.num_verts_max = 20000
179
  _C.dataset.zray_type = False
180
  _C.dataset.online_smpl = False
 
211
 
212
  # Alternatively, provide a way to import the defaults as
213
  # a global singleton:
214
+ cfg = _C # users can `from config import cfg`
 
215
 
216
  # cfg = get_cfg_defaults()
217
  # cfg.merge_from_file('./configs/example.yaml')
 
244
  def parse_args_extend(args):
245
  if args.resume:
246
  if not os.path.exists(args.log_dir):
247
+ raise ValueError("Experiment are set to resume mode, but log directory does not exist.")
 
 
248
 
249
  # load log's cfg
250
  cfg_file = os.path.join(args.log_dir, "cfg.yaml")
lib/common/imutils.py CHANGED
@@ -3,14 +3,13 @@ import mediapipe as mp
3
  import torch
4
  import numpy as np
5
  import torch.nn.functional as F
6
- from rembg import remove
7
- from rembg.session_factory import new_session
8
  from PIL import Image
9
- from torchvision.models import detection
10
-
11
  from lib.pymafx.core import constants
12
- from lib.common.cloth_extraction import load_segmentation
 
 
13
  from torchvision import transforms
 
14
 
15
 
16
  def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
@@ -24,42 +23,40 @@ def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
24
  return transforms.Compose(all_ops)
25
 
26
 
27
- def aug_matrix(w1, h1, w2, h2):
28
- dx = (w2 - w1) / 2.0
29
- dy = (h2 - h1) / 2.0
30
-
31
- matrix_trans = np.array([[1.0, 0, dx], [0, 1.0, dy], [0, 0, 1.0]])
32
-
33
- scale = np.min([float(w2) / w1, float(h2) / h1])
34
 
35
- M = get_affine_matrix(center=(w2 / 2.0, h2 / 2.0), translate=(0, 0), scale=scale)
36
-
37
- M = np.array(M + [0.0, 0.0, 1.0]).reshape(3, 3)
38
- M = M.dot(matrix_trans)
39
 
40
  return M
41
 
42
 
43
- def get_affine_matrix(center, translate, scale):
44
- cx, cy = center
45
- tx, ty = translate
46
-
47
- M = [1, 0, 0, 0, 1, 0]
48
- M = [x * scale for x in M]
49
 
50
- # Apply translation and of center translation: RSS * C^-1
51
- M[2] += M[0] * (-cx) + M[1] * (-cy)
52
- M[5] += M[3] * (-cx) + M[4] * (-cy)
 
 
 
 
 
 
 
53
 
54
- # Apply center translation: T * C * RSS * C^-1
55
- M[2] += cx + tx
56
- M[5] += cy + ty
57
  return M
58
 
59
 
60
  def load_img(img_file):
61
 
62
  img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
 
 
 
 
 
63
  if len(img.shape) == 2:
64
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
65
 
@@ -68,11 +65,10 @@ def load_img(img_file):
68
  else:
69
  img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
70
 
71
- return img
72
 
73
 
74
  def get_keypoints(image):
75
-
76
  def collect_xyv(x, body=True):
77
  lmk = x.landmark
78
  all_lmks = []
@@ -84,8 +80,8 @@ def get_keypoints(image):
84
  mp_holistic = mp.solutions.holistic
85
 
86
  with mp_holistic.Holistic(
87
- static_image_mode=True,
88
- model_complexity=2,
89
  ) as holistic:
90
  results = holistic.process(image)
91
 
@@ -93,9 +89,15 @@ def get_keypoints(image):
93
 
94
  result = {}
95
  result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps
96
- result["lhand"] = collect_xyv(results.left_hand_landmarks, False) if results.left_hand_landmarks else fake_kps
97
- result["rhand"] = collect_xyv(results.right_hand_landmarks, False) if results.right_hand_landmarks else fake_kps
98
- result["face"] = collect_xyv(results.face_landmarks, False) if results.face_landmarks else fake_kps
 
 
 
 
 
 
99
 
100
  return result
101
 
@@ -104,13 +106,21 @@ def get_pymafx(image, landmarks):
104
 
105
  # image [3,512,512]
106
 
107
- item = {'img_body': F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]}
 
 
 
108
 
109
  for part in ['lhand', 'rhand', 'face']:
110
  kp2d = landmarks[part]
111
  kp2d_valid = kp2d[kp2d[:, 3] > 0.]
112
  if len(kp2d_valid) > 0:
113
- bbox = [min(kp2d_valid[:, 0]), min(kp2d_valid[:, 1]), max(kp2d_valid[:, 0]), max(kp2d_valid[:, 1])]
 
 
 
 
 
114
  center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.]
115
  scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
116
 
@@ -141,20 +151,6 @@ def get_pymafx(image, landmarks):
141
  return item
142
 
143
 
144
- def expand_bbox(bbox, width, height, ratio=0.1):
145
-
146
- bbox = np.around(bbox).astype(np.int16)
147
- bbox_width = bbox[2] - bbox[0]
148
- bbox_height = bbox[3] - bbox[1]
149
-
150
- bbox[1] = max(bbox[1] - bbox_height * ratio, 0)
151
- bbox[3] = min(bbox[3] + bbox_height * ratio, height)
152
- bbox[0] = max(bbox[0] - bbox_width * ratio, 0)
153
- bbox[2] = min(bbox[2] + bbox_width * ratio, width)
154
-
155
- return bbox
156
-
157
-
158
  def remove_floats(mask):
159
 
160
  # 1. find all the contours
@@ -173,51 +169,48 @@ def remove_floats(mask):
173
  return new_mask
174
 
175
 
176
- def process_image(img_file, hps_type, single, input_res=512):
177
 
178
- img_raw = load_img(img_file)
179
-
180
- in_height, in_width = img_raw.shape[:2]
181
- M = aug_matrix(in_width, in_height, input_res * 2, input_res * 2)
182
-
183
- # from rectangle to square by padding (input_res*2, input_res*2)
184
- img_square = cv2.warpAffine(img_raw, M[0:2, :], (input_res * 2, input_res * 2), flags=cv2.INTER_CUBIC)
 
 
 
185
 
186
  # detection for bbox
187
- detector = detection.maskrcnn_resnet50_fpn(weights=detection.MaskRCNN_ResNet50_FPN_V2_Weights)
188
- detector.eval()
189
- predictions = detector([torch.from_numpy(img_square).permute(2, 0, 1) / 255.])[0]
190
 
191
  if single:
192
  top_score = predictions["scores"][predictions["labels"] == 1].max()
193
  human_ids = torch.where(predictions["scores"] == top_score)[0]
194
  else:
195
- human_ids = torch.logical_and(predictions["labels"] == 1, predictions["scores"] > 0.9).nonzero().squeeze(1)
 
196
 
197
  boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
198
  masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()
199
 
200
- width = boxes[:, 2] - boxes[:, 0] #(N,)
201
- height = boxes[:, 3] - boxes[:, 1] #(N,)
202
- center = np.array([(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]).T #(N,2)
203
- scale = np.array([width, height]).max(axis=0) / 90.
204
 
205
  img_icon_lst = []
206
  img_crop_lst = []
207
  img_hps_lst = []
208
  img_mask_lst = []
209
- uncrop_param_lst = []
210
  landmark_lst = []
211
  hands_visibility_lst = []
212
  img_pymafx_lst = []
213
 
214
  uncrop_param = {
215
- "center": center,
216
- "scale": scale,
217
  "ori_shape": [in_height, in_width],
218
  "box_shape": [input_res, input_res],
219
- "crop_shape": [input_res * 2, input_res * 2, 3],
220
- "M": M,
 
221
  }
222
 
223
  for idx in range(len(boxes)):
@@ -228,59 +221,74 @@ def process_image(img_file, hps_type, single, input_res=512):
228
  else:
229
  mask_detection = masks[0] * 0.
230
 
231
- img_crop, _ = crop(
232
- np.concatenate([img_square, (mask_detection < 0.4) * 255], axis=2), center[idx], scale[idx], [input_res, input_res])
233
-
234
- # get accurate segmentation mask of focus person
 
 
 
 
 
 
 
 
 
 
 
235
  img_rembg = remove(img_crop, post_process_mask=True, session=new_session("u2net"))
236
  img_mask = remove_floats(img_rembg[:, :, [3]])
237
 
238
- # required image tensors / arrays
239
-
240
- # img_icon (tensor): (-1, 1), [3,512,512]
241
- # img_hps (tensor): (-2.11, 2.44), [3,224,224]
242
-
243
- # img_np (array): (0, 255), [512,512,3]
244
- # img_rembg (array): (0, 255), [512,512,4]
245
- # img_mask (array): (0, 1), [512,512,1]
246
- # img_crop (array): (0, 255), [512,512,4]
247
-
248
  mean_icon = std_icon = (0.5, 0.5, 0.5)
249
  img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8)
250
- img_icon = transform_to_tensor(512, mean_icon, std_icon)(Image.fromarray(img_np)) * torch.tensor(img_mask).permute(
251
- 2, 0, 1)
252
- img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN, constants.IMG_NORM_STD)(Image.fromarray(img_np))
 
 
253
 
254
  landmarks = get_keypoints(img_np)
255
 
 
 
 
 
 
 
 
 
256
  if hps_type == 'pymafx':
257
  img_pymafx_lst.append(
258
  get_pymafx(
259
- transform_to_tensor(512, constants.IMG_NORM_MEAN, constants.IMG_NORM_STD)(Image.fromarray(img_np)),
260
- landmarks))
 
 
261
 
262
  img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0)
263
  img_icon_lst.append(img_icon)
264
  img_hps_lst.append(img_hps)
265
  img_mask_lst.append(torch.tensor(img_mask[..., 0]))
266
- uncrop_param_lst.append(uncrop_param)
267
  landmark_lst.append(landmarks['body'])
268
 
269
- hands_visibility = [True, True]
270
- if landmarks['lhand'][:, -1].mean() == 0.:
271
- hands_visibility[0] = False
272
- if landmarks['rhand'][:, -1].mean() == 0.:
273
- hands_visibility[1] = False
274
- hands_visibility_lst.append(hands_visibility)
 
 
 
275
 
276
  return_dict = {
277
- "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res]
278
- "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res]
279
- "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res]
280
- "img_raw": img_raw, #[H, W, 3]
281
- "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res]
282
  "uncrop_param": uncrop_param,
283
- "landmark": torch.stack(landmark_lst), #[N, 33, 4]
284
  "hands_visibility": hands_visibility_lst,
285
  }
286
 
@@ -302,250 +310,51 @@ def process_image(img_file, hps_type, single, input_res=512):
302
  return return_dict
303
 
304
 
305
- def get_transform(center, scale, res):
306
- """Generate transformation matrix."""
307
- h = 100 * scale
308
- t = np.zeros((3, 3))
309
- t[0, 0] = float(res[1]) / h
310
- t[1, 1] = float(res[0]) / h
311
- t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
312
- t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
313
- t[2, 2] = 1
314
-
315
- return t
316
-
317
-
318
- def transform(pt, center, scale, res, invert=0):
319
- """Transform pixel location to different reference."""
320
- t = get_transform(center, scale, res)
321
- if invert:
322
- t = np.linalg.inv(t)
323
- new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
324
- new_pt = np.dot(t, new_pt)
325
- return np.around(new_pt[:2]).astype(np.int16)
326
-
327
-
328
- def crop(img, center, scale, res):
329
- """Crop image according to the supplied bounding box."""
330
-
331
- img_height, img_width = img.shape[:2]
332
-
333
- # Upper left point
334
- ul = np.array(transform([0, 0], center, scale, res, invert=1))
335
-
336
- # Bottom right point
337
- br = np.array(transform(res, center, scale, res, invert=1))
338
-
339
- new_shape = [br[1] - ul[1], br[0] - ul[0]]
340
- if len(img.shape) > 2:
341
- new_shape += [img.shape[2]]
342
- new_img = np.zeros(new_shape)
343
-
344
- # Range to fill new array
345
- new_x = max(0, -ul[0]), min(br[0], img_width) - ul[0]
346
- new_y = max(0, -ul[1]), min(br[1], img_height) - ul[1]
347
-
348
- # Range to sample from original image
349
- old_x = max(0, ul[0]), min(img_width, br[0])
350
- old_y = max(0, ul[1]), min(img_height, br[1])
351
-
352
- new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
353
- new_img = F.interpolate(
354
- torch.tensor(new_img).permute(2, 0, 1).unsqueeze(0), res, mode='bilinear').permute(0, 2, 3,
355
- 1)[0].numpy().astype(np.uint8)
356
-
357
- return new_img, (old_x, new_x, old_y, new_y, new_shape)
358
-
359
-
360
- def crop_segmentation(org_coord, res, cropping_parameters):
361
- old_x, new_x, old_y, new_y, new_shape = cropping_parameters
362
 
363
- new_coord = np.zeros((org_coord.shape))
364
- new_coord[:, 0] = new_x[0] + (org_coord[:, 0] - old_x[0])
365
- new_coord[:, 1] = new_y[0] + (org_coord[:, 1] - old_y[0])
366
-
367
- new_coord[:, 0] = res[0] * (new_coord[:, 0] / new_shape[1])
368
- new_coord[:, 1] = res[1] * (new_coord[:, 1] / new_shape[0])
369
-
370
- return new_coord
371
-
372
-
373
- def corner_align(ul, br):
374
-
375
- if ul[1] - ul[0] != br[1] - br[0]:
376
- ul[1] = ul[0] + br[1] - br[0]
377
-
378
- return ul, br
379
-
380
-
381
- def uncrop(img, center, scale, orig_shape):
382
- """'Undo' the image cropping/resizing.
383
- This function is used when evaluating mask/part segmentation.
384
- """
385
-
386
- res = img.shape[:2]
387
-
388
- # Upper left point
389
- ul = np.array(transform([0, 0], center, scale, res, invert=1))
390
- # Bottom right point
391
- br = np.array(transform(res, center, scale, res, invert=1))
392
-
393
- # quick fix
394
- ul, br = corner_align(ul, br)
395
-
396
- # size of cropped image
397
- crop_shape = [br[1] - ul[1], br[0] - ul[0]]
398
- new_img = np.zeros(orig_shape, dtype=np.uint8)
399
-
400
- # Range to fill new array
401
- new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
402
- new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
403
 
404
- # Range to sample from original image
405
- old_x = max(0, ul[0]), min(orig_shape[1], br[0])
406
- old_y = max(0, ul[1]), min(orig_shape[0], br[1])
407
 
408
- img = np.array(Image.fromarray(img.astype(np.uint8)).resize(crop_shape))
 
409
 
410
- new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
 
411
 
412
- return new_img
413
 
 
414
 
415
- def rot_aa(aa, rot):
416
- """Rotate axis angle parameters."""
417
- # pose parameters
418
- R = np.array([
419
- [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
420
- [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
421
- [0, 0, 1],
422
- ])
423
- # find the rotation of the body in camera frame
424
- per_rdg, _ = cv2.Rodrigues(aa)
425
- # apply the global rotation to the global orientation
426
- resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
427
- aa = (resrot.T)[0]
428
- return aa
429
 
 
430
 
431
- def flip_img(img):
432
- """Flip rgb images or masks.
433
- channels come last, e.g. (256,256,3).
434
- """
435
- img = np.fliplr(img)
436
- return img
437
 
 
 
 
 
 
 
 
 
438
 
439
- def flip_kp(kp, is_smpl=False):
440
- """Flip keypoints."""
441
- if len(kp) == 24:
442
- if is_smpl:
443
- flipped_parts = constants.SMPL_JOINTS_FLIP_PERM
444
- else:
445
- flipped_parts = constants.J24_FLIP_PERM
446
- elif len(kp) == 49:
447
- if is_smpl:
448
- flipped_parts = constants.SMPL_J49_FLIP_PERM
449
- else:
450
- flipped_parts = constants.J49_FLIP_PERM
451
- kp = kp[flipped_parts]
452
- kp[:, 0] = -kp[:, 0]
453
- return kp
454
-
455
-
456
- def flip_pose(pose):
457
- """Flip pose.
458
- The flipping is based on SMPL parameters.
459
- """
460
- flipped_parts = constants.SMPL_POSE_FLIP_PERM
461
- pose = pose[flipped_parts]
462
- # we also negate the second and the third dimension of the axis-angle
463
- pose[1::3] = -pose[1::3]
464
- pose[2::3] = -pose[2::3]
465
- return pose
466
-
467
-
468
- def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
469
- # Normalize keypoints between -1, 1
470
- if not inv:
471
- ratio = 1.0 / crop_size
472
- kp_2d = 2.0 * kp_2d * ratio - 1.0
473
- else:
474
- ratio = 1.0 / crop_size
475
- kp_2d = (kp_2d + 1.0) / (2 * ratio)
476
-
477
- return kp_2d
478
-
479
-
480
- def visualize_landmarks(image, joints, color):
481
-
482
- img_w, img_h = image.shape[:2]
483
-
484
- for joint in joints:
485
- image = cv2.circle(image, (int(joint[0] * img_w), int(joint[1] * img_h)), 5, color)
486
-
487
- return image
488
-
489
-
490
- def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
491
- """
492
- param joints: [num_joints, 3]
493
- param joints_vis: [num_joints, 3]
494
- return: target, target_weight(1: visible, 0: invisible)
495
- """
496
- num_joints = joints.shape[0]
497
- device = joints.device
498
- cur_device = torch.device(device.type, device.index)
499
- if not hasattr(heatmap_size, "__len__"):
500
- # width height
501
- heatmap_size = [heatmap_size, heatmap_size]
502
- assert len(heatmap_size) == 2
503
- target_weight = np.ones((num_joints, 1), dtype=np.float32)
504
- if joints_vis is not None:
505
- target_weight[:, 0] = joints_vis[:, 0]
506
- target = torch.zeros(
507
- (num_joints, heatmap_size[1], heatmap_size[0]),
508
- dtype=torch.float32,
509
- device=cur_device,
510
  )
511
 
512
- tmp_size = sigma * 3
513
-
514
- for joint_id in range(num_joints):
515
- mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5)
516
- mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5)
517
- # Check that any part of the gaussian is in-bounds
518
- ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
519
- br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
520
- if (ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] or br[0] < 0 or br[1] < 0):
521
- # If not, just return the image as is
522
- target_weight[joint_id] = 0
523
- continue
524
-
525
- # # Generate gaussian
526
- size = 2 * tmp_size + 1
527
- # x = np.arange(0, size, 1, np.float32)
528
- # y = x[:, np.newaxis]
529
- # x0 = y0 = size // 2
530
- # # The gaussian is not normalized, we want the center value to equal 1
531
- # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
532
- # g = torch.from_numpy(g.astype(np.float32))
533
-
534
- x = torch.arange(0, size, dtype=torch.float32, device=cur_device)
535
- y = x.unsqueeze(-1)
536
- x0 = y0 = size // 2
537
- # The gaussian is not normalized, we want the center value to equal 1
538
- g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
539
-
540
- # Usable gaussian range
541
- g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
542
- g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
543
- # Image range
544
- img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
545
- img_y = max(0, ul[1]), min(br[1], heatmap_size[1])
546
-
547
- v = target_weight[joint_id]
548
- if v > 0.5:
549
- target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
550
-
551
- return target, target_weight
 
3
  import torch
4
  import numpy as np
5
  import torch.nn.functional as F
 
 
6
  from PIL import Image
 
 
7
  from lib.pymafx.core import constants
8
+
9
+ from rembg import remove
10
+ from rembg.session_factory import new_session
11
  from torchvision import transforms
12
+ from kornia.geometry.transform import get_affine_matrix2d, warp_affine
13
 
14
 
15
  def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
 
23
  return transforms.Compose(all_ops)
24
 
25
 
26
+ def get_affine_matrix_wh(w1, h1, w2, h2):
 
 
 
 
 
 
27
 
28
+ transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0)
29
+ center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0)
30
+ scale = torch.min(torch.tensor([w2 / w1, h2 / h1])).repeat(2).unsqueeze(0)
31
+ M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
32
 
33
  return M
34
 
35
 
36
+ def get_affine_matrix_box(boxes, w2, h2):
 
 
 
 
 
37
 
38
+ # boxes [left, top, right, bottom]
39
+ width = boxes[:, 2] - boxes[:, 0] #(N,)
40
+ height = boxes[:, 3] - boxes[:, 1] #(N,)
41
+ center = torch.tensor(
42
+ [(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]
43
+ ).T #(N,2)
44
+ scale = torch.min(torch.tensor([w2 / width, h2 / height]),
45
+ dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9 #(N,2)
46
+ transl = torch.tensor([w2 / 2.0 - center[:, 0], h2 / 2.0 - center[:, 1]]).unsqueeze(0) #(N,2)
47
+ M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
48
 
 
 
 
49
  return M
50
 
51
 
52
  def load_img(img_file):
53
 
54
  img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
55
+
56
+ # considering 16-bit image
57
+ if img.dtype == np.uint16:
58
+ img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
59
+
60
  if len(img.shape) == 2:
61
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
62
 
 
65
  else:
66
  img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
67
 
68
+ return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2]
69
 
70
 
71
  def get_keypoints(image):
 
72
  def collect_xyv(x, body=True):
73
  lmk = x.landmark
74
  all_lmks = []
 
80
  mp_holistic = mp.solutions.holistic
81
 
82
  with mp_holistic.Holistic(
83
+ static_image_mode=True,
84
+ model_complexity=2,
85
  ) as holistic:
86
  results = holistic.process(image)
87
 
 
89
 
90
  result = {}
91
  result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps
92
+ result["lhand"] = collect_xyv(
93
+ results.left_hand_landmarks, False
94
+ ) if results.left_hand_landmarks else fake_kps
95
+ result["rhand"] = collect_xyv(
96
+ results.right_hand_landmarks, False
97
+ ) if results.right_hand_landmarks else fake_kps
98
+ result["face"] = collect_xyv(
99
+ results.face_landmarks, False
100
+ ) if results.face_landmarks else fake_kps
101
 
102
  return result
103
 
 
106
 
107
  # image [3,512,512]
108
 
109
+ item = {
110
+ 'img_body':
111
+ F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]
112
+ }
113
 
114
  for part in ['lhand', 'rhand', 'face']:
115
  kp2d = landmarks[part]
116
  kp2d_valid = kp2d[kp2d[:, 3] > 0.]
117
  if len(kp2d_valid) > 0:
118
+ bbox = [
119
+ min(kp2d_valid[:, 0]),
120
+ min(kp2d_valid[:, 1]),
121
+ max(kp2d_valid[:, 0]),
122
+ max(kp2d_valid[:, 1])
123
+ ]
124
  center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.]
125
  scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
126
 
 
151
  return item
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def remove_floats(mask):
155
 
156
  # 1. find all the contours
 
169
  return new_mask
170
 
171
 
172
+ def process_image(img_file, hps_type, single, input_res, detector):
173
 
174
+ img_raw, (in_height, in_width) = load_img(img_file)
175
+ tgt_res = input_res * 2
176
+ M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res)
177
+ img_square = warp_affine(
178
+ img_raw,
179
+ M_square[:, :2], (tgt_res, ) * 2,
180
+ mode='bilinear',
181
+ padding_mode='zeros',
182
+ align_corners=True
183
+ )
184
 
185
  # detection for bbox
186
+ predictions = detector(img_square / 255.)[0]
 
 
187
 
188
  if single:
189
  top_score = predictions["scores"][predictions["labels"] == 1].max()
190
  human_ids = torch.where(predictions["scores"] == top_score)[0]
191
  else:
192
+ human_ids = torch.logical_and(predictions["labels"] == 1,
193
+ predictions["scores"] > 0.9).nonzero().squeeze(1)
194
 
195
  boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
196
  masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()
197
 
198
+ M_crop = get_affine_matrix_box(boxes, input_res, input_res)
 
 
 
199
 
200
  img_icon_lst = []
201
  img_crop_lst = []
202
  img_hps_lst = []
203
  img_mask_lst = []
 
204
  landmark_lst = []
205
  hands_visibility_lst = []
206
  img_pymafx_lst = []
207
 
208
  uncrop_param = {
 
 
209
  "ori_shape": [in_height, in_width],
210
  "box_shape": [input_res, input_res],
211
+ "square_shape": [tgt_res, tgt_res],
212
+ "M_square": M_square,
213
+ "M_crop": M_crop
214
  }
215
 
216
  for idx in range(len(boxes)):
 
221
  else:
222
  mask_detection = masks[0] * 0.
223
 
224
+ img_square_rgba = torch.cat(
225
+ [img_square.squeeze(0).permute(1, 2, 0),
226
+ torch.tensor(mask_detection < 0.4) * 255],
227
+ dim=2
228
+ )
229
+
230
+ img_crop = warp_affine(
231
+ img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2),
232
+ M_crop[idx:idx + 1, :2], (input_res, ) * 2,
233
+ mode='bilinear',
234
+ padding_mode='zeros',
235
+ align_corners=True
236
+ ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
237
+
238
+ # get accurate person segmentation mask
239
  img_rembg = remove(img_crop, post_process_mask=True, session=new_session("u2net"))
240
  img_mask = remove_floats(img_rembg[:, :, [3]])
241
 
 
 
 
 
 
 
 
 
 
 
242
  mean_icon = std_icon = (0.5, 0.5, 0.5)
243
  img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8)
244
+ img_icon = transform_to_tensor(512, mean_icon, std_icon)(
245
+ Image.fromarray(img_np)
246
+ ) * torch.tensor(img_mask).permute(2, 0, 1)
247
+ img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN,
248
+ constants.IMG_NORM_STD)(Image.fromarray(img_np))
249
 
250
  landmarks = get_keypoints(img_np)
251
 
252
+ # get hands visibility
253
+ hands_visibility = [True, True]
254
+ if landmarks['lhand'][:, -1].mean() == 0.:
255
+ hands_visibility[0] = False
256
+ if landmarks['rhand'][:, -1].mean() == 0.:
257
+ hands_visibility[1] = False
258
+ hands_visibility_lst.append(hands_visibility)
259
+
260
  if hps_type == 'pymafx':
261
  img_pymafx_lst.append(
262
  get_pymafx(
263
+ transform_to_tensor(512, constants.IMG_NORM_MEAN,
264
+ constants.IMG_NORM_STD)(Image.fromarray(img_np)), landmarks
265
+ )
266
+ )
267
 
268
  img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0)
269
  img_icon_lst.append(img_icon)
270
  img_hps_lst.append(img_hps)
271
  img_mask_lst.append(torch.tensor(img_mask[..., 0]))
 
272
  landmark_lst.append(landmarks['body'])
273
 
274
+ # required image tensors / arrays
275
+
276
+ # img_icon (tensor): (-1, 1), [3,512,512]
277
+ # img_hps (tensor): (-2.11, 2.44), [3,224,224]
278
+
279
+ # img_np (array): (0, 255), [512,512,3]
280
+ # img_rembg (array): (0, 255), [512,512,4]
281
+ # img_mask (array): (0, 1), [512,512,1]
282
+ # img_crop (array): (0, 255), [512,512,4]
283
 
284
  return_dict = {
285
+ "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res]
286
+ "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res]
287
+ "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res]
288
+ "img_raw": img_raw, #[1, 3, H, W]
289
+ "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res]
290
  "uncrop_param": uncrop_param,
291
+ "landmark": torch.stack(landmark_lst), #[N, 33, 4]
292
  "hands_visibility": hands_visibility_lst,
293
  }
294
 
 
310
  return return_dict
311
 
312
 
313
+ def blend_rgb_norm(norms, data):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
+ # norms [N, 3, res, res]
316
+ masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
317
+ norm_mask = F.interpolate(
318
+ torch.cat([norms, masks], dim=1).detach(),
319
+ size=data["uncrop_param"]["box_shape"],
320
+ mode="bilinear",
321
+ align_corners=False
322
+ )
323
+ final = data["img_raw"].type_as(norm_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
+ for idx in range(len(norms)):
 
 
326
 
327
+ norm_pred = (norm_mask[idx:idx + 1, :3, :, :] + 1.0) * 255.0 / 2.0
328
+ mask_pred = norm_mask[idx:idx + 1, 3:4, :, :].repeat(1, 3, 1, 1)
329
 
330
+ norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
331
+ mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)
332
 
333
+ final = final * (1.0 - mask_ori) + norm_ori * mask_ori
334
 
335
+ return final.detach().cpu()
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
+ def unwrap(image, uncrop_param, idx):
339
 
340
+ device = image.device
 
 
 
 
 
341
 
342
+ img_square = warp_affine(
343
+ image,
344
+ torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device),
345
+ uncrop_param["square_shape"],
346
+ mode='bilinear',
347
+ padding_mode='zeros',
348
+ align_corners=True
349
+ )
350
 
351
+ img_ori = warp_affine(
352
+ img_square,
353
+ torch.inverse(uncrop_param["M_square"])[:, :2].to(device),
354
+ uncrop_param["ori_shape"],
355
+ mode='bilinear',
356
+ padding_mode='zeros',
357
+ align_corners=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  )
359
 
360
+ return img_ori
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/common/libmesh/inside_mesh.py CHANGED
@@ -5,7 +5,7 @@ from .triangle_hash import TriangleHash as _TriangleHash
5
  def check_mesh_contains(mesh, points, hash_resolution=512):
6
  intersector = MeshIntersector(mesh, hash_resolution)
7
  contains, hole_points = intersector.query(points)
8
- return contains, hole_points
9
 
10
 
11
  class MeshIntersector:
@@ -25,8 +25,7 @@ class MeshIntersector:
25
  # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5))
26
 
27
  triangles2d = triangles[:, :, :2]
28
- self._tri_intersector2d = TriangleIntersector2d(
29
- triangles2d, resolution)
30
 
31
  def query(self, points):
32
  # Rescale points
@@ -38,8 +37,7 @@ class MeshIntersector:
38
 
39
  # cull points outside of the axis aligned bounding box
40
  # this avoids running ray tests unless points are close
41
- inside_aabb = np.all(
42
- (0 <= points) & (points <= self.resolution), axis=1)
43
  if not inside_aabb.any():
44
  return contains, hole_points
45
 
@@ -48,14 +46,14 @@ class MeshIntersector:
48
  points = points[mask]
49
 
50
  # Compute intersection depth and check order
51
- points_indices, tri_indices = self._tri_intersector2d.query(
52
- points[:, :2])
53
 
54
  triangles_intersect = self._triangles[tri_indices]
55
  points_intersect = points[points_indices]
56
 
57
  depth_intersect, abs_n_2 = self.compute_intersection_depth(
58
- points_intersect, triangles_intersect)
 
59
 
60
  # Count number of intersections in both directions
61
  smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
@@ -73,7 +71,7 @@ class MeshIntersector:
73
  # print('Warning: contains1 != contains2 for some points.')
74
  contains[mask] = (contains1 & contains2)
75
  hole_points[mask] = np.logical_xor(contains1, contains2)
76
- return contains, hole_points
77
 
78
  def compute_intersection_depth(self, points, triangles):
79
  t1 = triangles[:, 0, :]
@@ -150,7 +148,7 @@ class TriangleIntersector2d:
150
 
151
  sum_uv = u + v
152
  contains[mask] = (
153
- (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
154
- & (0 < sum_uv) & (sum_uv < abs_detA)
155
  )
156
  return contains
 
5
  def check_mesh_contains(mesh, points, hash_resolution=512):
6
  intersector = MeshIntersector(mesh, hash_resolution)
7
  contains, hole_points = intersector.query(points)
8
+ return contains, hole_points
9
 
10
 
11
  class MeshIntersector:
 
25
  # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5))
26
 
27
  triangles2d = triangles[:, :, :2]
28
+ self._tri_intersector2d = TriangleIntersector2d(triangles2d, resolution)
 
29
 
30
  def query(self, points):
31
  # Rescale points
 
37
 
38
  # cull points outside of the axis aligned bounding box
39
  # this avoids running ray tests unless points are close
40
+ inside_aabb = np.all((0 <= points) & (points <= self.resolution), axis=1)
 
41
  if not inside_aabb.any():
42
  return contains, hole_points
43
 
 
46
  points = points[mask]
47
 
48
  # Compute intersection depth and check order
49
+ points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])
 
50
 
51
  triangles_intersect = self._triangles[tri_indices]
52
  points_intersect = points[points_indices]
53
 
54
  depth_intersect, abs_n_2 = self.compute_intersection_depth(
55
+ points_intersect, triangles_intersect
56
+ )
57
 
58
  # Count number of intersections in both directions
59
  smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
 
71
  # print('Warning: contains1 != contains2 for some points.')
72
  contains[mask] = (contains1 & contains2)
73
  hole_points[mask] = np.logical_xor(contains1, contains2)
74
+ return contains, hole_points
75
 
76
  def compute_intersection_depth(self, points, triangles):
77
  t1 = triangles[:, 0, :]
 
148
 
149
  sum_uv = u + v
150
  contains[mask] = (
151
+ (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) & (0 < sum_uv) &
152
+ (sum_uv < abs_detA)
153
  )
154
  return contains
lib/common/libmesh/setup.py CHANGED
@@ -2,7 +2,4 @@ from setuptools import setup
2
  from Cython.Build import cythonize
3
  import numpy
4
 
5
-
6
- setup(name = 'libmesh',
7
- ext_modules = cythonize("*.pyx"),
8
- include_dirs=[numpy.get_include()])
 
2
  from Cython.Build import cythonize
3
  import numpy
4
 
5
+ setup(name='libmesh', ext_modules=cythonize("*.pyx"), include_dirs=[numpy.get_include()])
 
 
 
lib/common/libvoxelize/setup.py CHANGED
@@ -1,5 +1,4 @@
1
  from setuptools import setup
2
  from Cython.Build import cythonize
3
 
4
- setup(name = 'libvoxelize',
5
- ext_modules = cythonize("*.pyx"))
 
1
  from setuptools import setup
2
  from Cython.Build import cythonize
3
 
4
+ setup(name='libvoxelize', ext_modules=cythonize("*.pyx"))
 
lib/common/local_affine.py CHANGED
@@ -16,7 +16,6 @@ from lib.common.train_util import init_loss
16
 
17
  # reference: https://github.com/wuhaozhe/pytorch-nicp
18
  class LocalAffine(nn.Module):
19
-
20
  def __init__(self, num_points, batch_size=1, edges=None):
21
  '''
22
  specify the number of points, the number of points should be constant across the batch
@@ -26,8 +25,14 @@ class LocalAffine(nn.Module):
26
  add additional pooling on top of w matrix
27
  '''
28
  super(LocalAffine, self).__init__()
29
- self.A = nn.Parameter(torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1))
30
- self.b = nn.Parameter(torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(batch_size, num_points, 1, 1))
 
 
 
 
 
 
31
  self.edges = edges
32
  self.num_points = num_points
33
 
@@ -38,24 +43,23 @@ class LocalAffine(nn.Module):
38
  '''
39
  if self.edges is None:
40
  raise Exception("edges cannot be none when calculate stiff")
41
- idx1 = self.edges[:, 0]
42
- idx2 = self.edges[:, 1]
43
  affine_weight = torch.cat((self.A, self.b), dim=3)
44
- w1 = torch.index_select(affine_weight, dim=1, index=idx1)
45
- w2 = torch.index_select(affine_weight, dim=1, index=idx2)
46
  w_diff = (w1 - w2)**2
47
  w_rigid = (torch.linalg.det(self.A) - 1.0)**2
48
  return w_diff, w_rigid
49
 
50
  def forward(self, x):
51
  '''
52
- x should have shape of B * N * 3
53
  '''
54
  x = x.unsqueeze(3)
55
  out_x = torch.matmul(self.A, x)
56
  out_x = out_x + self.b
57
- stiffness, rigid = self.stiffness()
58
  out_x.squeeze_(3)
 
 
59
  return out_x, stiffness, rigid
60
 
61
 
@@ -75,10 +79,16 @@ def register(target_mesh, src_mesh, device):
75
  tgt_mesh = trimesh2meshes(target_mesh).to(device)
76
  src_verts = src_mesh.verts_padded().clone()
77
 
78
- local_affine_model = LocalAffine(src_mesh.verts_padded().shape[1],
79
- src_mesh.verts_padded().shape[0], src_mesh.edges_packed()).to(device)
 
 
80
 
81
- optimizer_cloth = torch.optim.Adam([{'params': local_affine_model.parameters()}], lr=1e-2, amsgrad=True)
 
 
 
 
82
  scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
83
  optimizer_cloth,
84
  mode="min",
@@ -90,28 +100,27 @@ def register(target_mesh, src_mesh, device):
90
 
91
  losses = init_loss()
92
 
93
- loop_cloth = tqdm(range(200))
94
 
95
  for i in loop_cloth:
96
 
97
  optimizer_cloth.zero_grad()
98
 
99
- deformed_verts, stiffness, rigid = local_affine_model(src_verts)
100
  src_mesh = src_mesh.update_padded(deformed_verts)
101
 
102
  # losses for laplacian, edge, normal consistency
103
  update_mesh_shape_prior_losses(src_mesh, losses)
104
 
105
  losses["cloth"]["value"] = chamfer_distance(
106
- x=src_mesh.verts_padded(),
107
- y=tgt_mesh.verts_padded())[0]
108
-
109
- losses["stiffness"]["value"] = torch.mean(stiffness)
110
  losses["rigid"]["value"] = torch.mean(rigid)
111
 
112
  # Weighted sum of the losses
113
  cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
114
- pbar_desc = "Register SMPL-X towards ECON --- "
115
 
116
  for k in losses.keys():
117
  if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
@@ -119,7 +128,7 @@ def register(target_mesh, src_mesh, device):
119
  losses[k]["value"] * losses[k]["weight"]
120
  pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | "
121
 
122
- pbar_desc += f"Total: {cloth_loss:.5f}"
123
  loop_cloth.set_description(pbar_desc)
124
 
125
  # update params
@@ -131,6 +140,7 @@ def register(target_mesh, src_mesh, device):
131
  src_mesh.verts_packed().detach().squeeze(0).cpu(),
132
  src_mesh.faces_packed().detach().squeeze(0).cpu(),
133
  process=False,
134
- maintains_order=True)
 
135
 
136
  return final
 
16
 
17
  # reference: https://github.com/wuhaozhe/pytorch-nicp
18
  class LocalAffine(nn.Module):
 
19
  def __init__(self, num_points, batch_size=1, edges=None):
20
  '''
21
  specify the number of points, the number of points should be constant across the batch
 
25
  add additional pooling on top of w matrix
26
  '''
27
  super(LocalAffine, self).__init__()
28
+ self.A = nn.Parameter(
29
+ torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)
30
+ )
31
+ self.b = nn.Parameter(
32
+ torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(
33
+ batch_size, num_points, 1, 1
34
+ )
35
+ )
36
  self.edges = edges
37
  self.num_points = num_points
38
 
 
43
  '''
44
  if self.edges is None:
45
  raise Exception("edges cannot be none when calculate stiff")
 
 
46
  affine_weight = torch.cat((self.A, self.b), dim=3)
47
+ w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
48
+ w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
49
  w_diff = (w1 - w2)**2
50
  w_rigid = (torch.linalg.det(self.A) - 1.0)**2
51
  return w_diff, w_rigid
52
 
53
  def forward(self, x):
54
  '''
55
+ x should have shape of B * N * 3 * 1
56
  '''
57
  x = x.unsqueeze(3)
58
  out_x = torch.matmul(self.A, x)
59
  out_x = out_x + self.b
 
60
  out_x.squeeze_(3)
61
+ stiffness, rigid = self.stiffness()
62
+
63
  return out_x, stiffness, rigid
64
 
65
 
 
79
  tgt_mesh = trimesh2meshes(target_mesh).to(device)
80
  src_verts = src_mesh.verts_padded().clone()
81
 
82
+ local_affine_model = LocalAffine(
83
+ src_mesh.verts_padded().shape[1],
84
+ src_mesh.verts_padded().shape[0], src_mesh.edges_packed()
85
+ ).to(device)
86
 
87
+ optimizer_cloth = torch.optim.Adam(
88
+ [{
89
+ 'params': local_affine_model.parameters()
90
+ }], lr=1e-2, amsgrad=True
91
+ )
92
  scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
93
  optimizer_cloth,
94
  mode="min",
 
100
 
101
  losses = init_loss()
102
 
103
+ loop_cloth = tqdm(range(100))
104
 
105
  for i in loop_cloth:
106
 
107
  optimizer_cloth.zero_grad()
108
 
109
+ deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
110
  src_mesh = src_mesh.update_padded(deformed_verts)
111
 
112
  # losses for laplacian, edge, normal consistency
113
  update_mesh_shape_prior_losses(src_mesh, losses)
114
 
115
  losses["cloth"]["value"] = chamfer_distance(
116
+ x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded()
117
+ )[0]
118
+ losses["stiff"]["value"] = torch.mean(stiffness)
 
119
  losses["rigid"]["value"] = torch.mean(rigid)
120
 
121
  # Weighted sum of the losses
122
  cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
123
+ pbar_desc = "Register SMPL-X -> d-BiNI -- "
124
 
125
  for k in losses.keys():
126
  if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
 
128
  losses[k]["value"] * losses[k]["weight"]
129
  pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | "
130
 
131
+ pbar_desc += f"TOTAL: {cloth_loss:.3f}"
132
  loop_cloth.set_description(pbar_desc)
133
 
134
  # update params
 
140
  src_mesh.verts_packed().detach().squeeze(0).cpu(),
141
  src_mesh.faces_packed().detach().squeeze(0).cpu(),
142
  process=False,
143
+ maintains_order=True
144
+ )
145
 
146
  return final
lib/common/render.py CHANGED
@@ -31,7 +31,8 @@ from pytorch3d.renderer import (
31
  )
32
  from pytorch3d.renderer.mesh import TexturesVertex
33
  from pytorch3d.structures import Meshes
34
- from lib.dataset.mesh_util import get_visibility, blend_rgb_norm
 
35
 
36
  import lib.common.render_utils as util
37
  import torch
@@ -74,20 +75,23 @@ def query_color(verts, faces, image, device):
74
 
75
  (xy, z) = verts.split([2, 1], dim=1)
76
  visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
77
- uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
78
  uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
79
  colors = (
80
- (torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) +
81
- 1.0) * 0.5 * 255.0)
 
 
 
82
  colors[visibility == 0.0] = (
83
  (Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0) *
84
- 0.5 * 255.0)[visibility == 0.0]
 
85
 
86
  return colors.detach().cpu()
87
 
88
 
89
  class cleanShader(torch.nn.Module):
90
-
91
  def __init__(self, blend_params=None):
92
  super().__init__()
93
  self.blend_params = blend_params if blend_params is not None else BlendParams()
@@ -103,7 +107,6 @@ class cleanShader(torch.nn.Module):
103
 
104
 
105
  class Render:
106
-
107
  def __init__(self, size=512, device=torch.device("cuda:0")):
108
  self.device = device
109
  self.size = size
@@ -119,21 +122,30 @@ class Render:
119
 
120
  self.cam_pos = {
121
  "frontback":
122
- torch.tensor([
123
- (0, self.mesh_y_center, self.dis),
124
- (0, self.mesh_y_center, -self.dis),
125
- ]),
 
 
126
  "four":
127
- torch.tensor([
128
- (0, self.mesh_y_center, self.dis),
129
- (self.dis, self.mesh_y_center, 0),
130
- (0, self.mesh_y_center, -self.dis),
131
- (-self.dis, self.mesh_y_center, 0),
132
- ]),
 
 
133
  "around":
134
- torch.tensor([(100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
135
- 100.0 * math.sin(np.pi / 180 * angle))
136
- for angle in range(0, 360, self.step)])
 
 
 
 
 
137
  }
138
 
139
  self.type = "color"
@@ -153,8 +165,8 @@ class Render:
153
 
154
  R, T = look_at_view_transform(
155
  eye=self.cam_pos[type][idx],
156
- at=((0, self.mesh_y_center, 0),),
157
- up=((0, 1, 0),),
158
  )
159
 
160
  cameras = FoVOrthographicCameras(
@@ -167,7 +179,7 @@ class Render:
167
  min_y=-100.0,
168
  max_x=100.0,
169
  min_x=-100.0,
170
- scale_xyz=(self.scale * np.ones(3),) * len(R),
171
  )
172
 
173
  return cameras
@@ -202,15 +214,17 @@ class Render:
202
  cull_backfaces=True,
203
  )
204
 
205
- self.silhouetteRas = MeshRasterizer(cameras=camera,
206
- raster_settings=self.raster_settings_silhouette)
207
- self.renderer = MeshRenderer(rasterizer=self.silhouetteRas,
208
- shader=SoftSilhouetteShader())
 
 
209
 
210
  elif type == "pointcloud":
211
- self.raster_settings_pcd = PointsRasterizationSettings(image_size=self.size,
212
- radius=0.006,
213
- points_per_pixel=10)
214
 
215
  self.pcdRas = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings_pcd)
216
  self.renderer = PointsRenderer(
@@ -230,8 +244,12 @@ class Render:
230
  V_lst = []
231
  F_lst = []
232
  for V, F in zip(verts, faces):
233
- V_lst.append(torch.tensor(V).float().to(self.device))
234
- F_lst.append(torch.tensor(F).long().to(self.device))
 
 
 
 
235
  self.meshes = Meshes(V_lst, F_lst).to(self.device)
236
  else:
237
  # array or tensor
@@ -248,7 +266,8 @@ class Render:
248
  # texture only support single mesh
249
  if len(self.meshes) == 1:
250
  self.meshes.textures = TexturesVertex(
251
- verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5)
 
252
 
253
  def get_image(self, cam_type="frontback", type="rgb", bg="gray"):
254
 
@@ -260,7 +279,8 @@ class Render:
260
 
261
  current_mesh = self.meshes[mesh_id]
262
  current_mesh.textures = TexturesVertex(
263
- verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5)
 
264
 
265
  if type == "depth":
266
  fragments = self.meshRas(current_mesh.extend(len(self.cam_pos[cam_type])))
@@ -276,7 +296,7 @@ class Render:
276
  print(f"unknown {type}")
277
 
278
  if cam_type == 'frontback':
279
- images[1] = torch.flip(images[1], dims=(-1,))
280
 
281
  # images [N_render, 3, res, res]
282
  img_lst.append(images.unsqueeze(1))
@@ -287,9 +307,8 @@ class Render:
287
  return list(meshes)
288
 
289
  def get_rendered_video_multi(self, data, save_path):
290
-
291
- width = data["img_raw"].shape[1]
292
- height = data["img_raw"].shape[0]
293
 
294
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
295
  video = cv2.VideoWriter(
@@ -302,14 +321,15 @@ class Render:
302
  pbar = tqdm(range(len(self.meshes)))
303
  pbar.set_description(colored(f"Normal Rendering {os.path.basename(save_path)}...", "blue"))
304
 
305
- mesh_renders = [] #[(N_cam, 3, res, res)*N_mesh]
306
 
307
  # render all the normals
308
  for mesh_id in pbar:
309
 
310
  current_mesh = self.meshes[mesh_id]
311
  current_mesh.textures = TexturesVertex(
312
- verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5)
 
313
 
314
  norm_lst = []
315
 
@@ -320,21 +340,33 @@ class Render:
320
  self.init_renderer(batch_cams, "mesh", "gray")
321
 
322
  norm_lst.append(
323
- self.renderer(current_mesh.extend(len(batch_cams_idx)))[..., :3].permute(
324
- 0, 3, 1, 2))
 
325
  mesh_renders.append(torch.cat(norm_lst).detach().cpu())
326
 
327
  # generate video frame by frame
328
  pbar = tqdm(range(len(self.cam_pos["around"])))
329
  pbar.set_description(colored(f"Video Exporting {os.path.basename(save_path)}...", "blue"))
 
330
  for cam_id in pbar:
331
- img_raw = data["img_raw"].astype(np.uint8)
332
  num_obj = len(mesh_renders) // 2
333
- img_smpl = blend_rgb_norm((torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data)
334
- img_cloth = blend_rgb_norm((torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data)
 
 
 
 
335
 
336
- top_img = cv2.resize(np.concatenate([img_raw, img_smpl], axis=1), (width, height // 2))
337
- final_img = np.concatenate([top_img, img_cloth], axis=0)
 
 
 
 
 
 
338
  video.write(final_img[:, :, ::-1])
339
 
340
  video.release()
 
31
  )
32
  from pytorch3d.renderer.mesh import TexturesVertex
33
  from pytorch3d.structures import Meshes
34
+ from lib.dataset.mesh_util import get_visibility
35
+ from lib.common.imutils import blend_rgb_norm
36
 
37
  import lib.common.render_utils as util
38
  import torch
 
75
 
76
  (xy, z) = verts.split([2, 1], dim=1)
77
  visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
78
+ uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
79
  uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
80
  colors = (
81
+ (
82
+ torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :,
83
+ 0].permute(1, 0) + 1.0
84
+ ) * 0.5 * 255.0
85
+ )
86
  colors[visibility == 0.0] = (
87
  (Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0) *
88
+ 0.5 * 255.0
89
+ )[visibility == 0.0]
90
 
91
  return colors.detach().cpu()
92
 
93
 
94
  class cleanShader(torch.nn.Module):
 
95
  def __init__(self, blend_params=None):
96
  super().__init__()
97
  self.blend_params = blend_params if blend_params is not None else BlendParams()
 
107
 
108
 
109
  class Render:
 
110
  def __init__(self, size=512, device=torch.device("cuda:0")):
111
  self.device = device
112
  self.size = size
 
122
 
123
  self.cam_pos = {
124
  "frontback":
125
+ torch.tensor(
126
+ [
127
+ (0, self.mesh_y_center, self.dis),
128
+ (0, self.mesh_y_center, -self.dis),
129
+ ]
130
+ ),
131
  "four":
132
+ torch.tensor(
133
+ [
134
+ (0, self.mesh_y_center, self.dis),
135
+ (self.dis, self.mesh_y_center, 0),
136
+ (0, self.mesh_y_center, -self.dis),
137
+ (-self.dis, self.mesh_y_center, 0),
138
+ ]
139
+ ),
140
  "around":
141
+ torch.tensor(
142
+ [
143
+ (
144
+ 100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
145
+ 100.0 * math.sin(np.pi / 180 * angle)
146
+ ) for angle in range(0, 360, self.step)
147
+ ]
148
+ )
149
  }
150
 
151
  self.type = "color"
 
165
 
166
  R, T = look_at_view_transform(
167
  eye=self.cam_pos[type][idx],
168
+ at=((0, self.mesh_y_center, 0), ),
169
+ up=((0, 1, 0), ),
170
  )
171
 
172
  cameras = FoVOrthographicCameras(
 
179
  min_y=-100.0,
180
  max_x=100.0,
181
  min_x=-100.0,
182
+ scale_xyz=(self.scale * np.ones(3), ) * len(R),
183
  )
184
 
185
  return cameras
 
214
  cull_backfaces=True,
215
  )
216
 
217
+ self.silhouetteRas = MeshRasterizer(
218
+ cameras=camera, raster_settings=self.raster_settings_silhouette
219
+ )
220
+ self.renderer = MeshRenderer(
221
+ rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
222
+ )
223
 
224
  elif type == "pointcloud":
225
+ self.raster_settings_pcd = PointsRasterizationSettings(
226
+ image_size=self.size, radius=0.006, points_per_pixel=10
227
+ )
228
 
229
  self.pcdRas = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings_pcd)
230
  self.renderer = PointsRenderer(
 
244
  V_lst = []
245
  F_lst = []
246
  for V, F in zip(verts, faces):
247
+ if not torch.is_tensor(V):
248
+ V_lst.append(torch.tensor(V).float().to(self.device))
249
+ F_lst.append(torch.tensor(F).long().to(self.device))
250
+ else:
251
+ V_lst.append(V.float().to(self.device))
252
+ F_lst.append(F.long().to(self.device))
253
  self.meshes = Meshes(V_lst, F_lst).to(self.device)
254
  else:
255
  # array or tensor
 
266
  # texture only support single mesh
267
  if len(self.meshes) == 1:
268
  self.meshes.textures = TexturesVertex(
269
+ verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5
270
+ )
271
 
272
  def get_image(self, cam_type="frontback", type="rgb", bg="gray"):
273
 
 
279
 
280
  current_mesh = self.meshes[mesh_id]
281
  current_mesh.textures = TexturesVertex(
282
+ verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
283
+ )
284
 
285
  if type == "depth":
286
  fragments = self.meshRas(current_mesh.extend(len(self.cam_pos[cam_type])))
 
296
  print(f"unknown {type}")
297
 
298
  if cam_type == 'frontback':
299
+ images[1] = torch.flip(images[1], dims=(-1, ))
300
 
301
  # images [N_render, 3, res, res]
302
  img_lst.append(images.unsqueeze(1))
 
307
  return list(meshes)
308
 
309
  def get_rendered_video_multi(self, data, save_path):
310
+
311
+ height, width = data["img_raw"].shape[2:]
 
312
 
313
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
314
  video = cv2.VideoWriter(
 
321
  pbar = tqdm(range(len(self.meshes)))
322
  pbar.set_description(colored(f"Normal Rendering {os.path.basename(save_path)}...", "blue"))
323
 
324
+ mesh_renders = [] #[(N_cam, 3, res, res)*N_mesh]
325
 
326
  # render all the normals
327
  for mesh_id in pbar:
328
 
329
  current_mesh = self.meshes[mesh_id]
330
  current_mesh.textures = TexturesVertex(
331
+ verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
332
+ )
333
 
334
  norm_lst = []
335
 
 
340
  self.init_renderer(batch_cams, "mesh", "gray")
341
 
342
  norm_lst.append(
343
+ self.renderer(current_mesh.extend(len(batch_cams_idx))
344
+ )[..., :3].permute(0, 3, 1, 2)
345
+ )
346
  mesh_renders.append(torch.cat(norm_lst).detach().cpu())
347
 
348
  # generate video frame by frame
349
  pbar = tqdm(range(len(self.cam_pos["around"])))
350
  pbar.set_description(colored(f"Video Exporting {os.path.basename(save_path)}...", "blue"))
351
+
352
  for cam_id in pbar:
353
+ img_raw = data["img_raw"]
354
  num_obj = len(mesh_renders) // 2
355
+ img_smpl = blend_rgb_norm(
356
+ (torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data
357
+ )
358
+ img_cloth = blend_rgb_norm(
359
+ (torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data
360
+ )
361
 
362
+ top_img = cv2.resize(
363
+ torch.cat([img_raw, img_smpl],
364
+ dim=-1).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8),
365
+ (width, height // 2)
366
+ )
367
+ final_img = np.concatenate(
368
+ [top_img, img_cloth.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)], axis=0
369
+ )
370
  video.write(final_img[:, :, ::-1])
371
 
372
  video.release()
lib/common/render_utils.py CHANGED
@@ -25,9 +25,7 @@ from pytorch3d.renderer.mesh import rasterize_meshes
25
  Tensor = NewType("Tensor", torch.Tensor)
26
 
27
 
28
- def solid_angles(points: Tensor,
29
- triangles: Tensor,
30
- thresh: float = 1e-8) -> Tensor:
31
  """Compute solid angle between the input points and triangles
32
  Follows the method described in:
33
  The Solid Angle of a Plane Triangle
@@ -55,9 +53,7 @@ def solid_angles(points: Tensor,
55
  norms = torch.norm(centered_tris, dim=-1)
56
 
57
  # Should be BxQxFx3
58
- cross_prod = torch.cross(centered_tris[:, :, :, 1],
59
- centered_tris[:, :, :, 2],
60
- dim=-1)
61
  # Should be BxQxF
62
  numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
63
  del cross_prod
@@ -67,8 +63,10 @@ def solid_angles(points: Tensor,
67
  dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
68
  del centered_tris
69
 
70
- denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
71
- dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
 
 
72
  del dot01, dot12, dot02, norms
73
 
74
  # Should be BxQ
@@ -80,9 +78,7 @@ def solid_angles(points: Tensor,
80
  return 2 * solid_angle
81
 
82
 
83
- def winding_numbers(points: Tensor,
84
- triangles: Tensor,
85
- thresh: float = 1e-8) -> Tensor:
86
  """Uses winding_numbers to compute inside/outside
87
  Robust inside-outside segmentation using generalized winding numbers
88
  Alec Jacobson,
@@ -109,8 +105,7 @@ def winding_numbers(points: Tensor,
109
  """
110
  # The generalized winding number is the sum of solid angles of the point
111
  # with respect to all triangles.
112
- return (1 / (4 * math.pi) *
113
- solid_angles(points, triangles, thresh=thresh).sum(dim=-1))
114
 
115
 
116
  def batch_contains(verts, faces, points):
@@ -124,8 +119,7 @@ def batch_contains(verts, faces, points):
124
  contains = torch.zeros(B, N)
125
 
126
  for i in range(B):
127
- contains[i] = torch.as_tensor(
128
- trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
129
 
130
  return 2.0 * (contains - 0.5)
131
 
@@ -155,8 +149,7 @@ def face_vertices(vertices, faces):
155
  bs, nv = vertices.shape[:2]
156
  bs, nf = faces.shape[:2]
157
  device = vertices.device
158
- faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
159
- nv)[:, None, None]
160
  vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
161
 
162
  return vertices[faces.long()]
@@ -168,7 +161,6 @@ class Pytorch3dRasterizer(nn.Module):
168
  x,y,z are in image space, normalized
169
  can only render squared image now
170
  """
171
-
172
  def __init__(self, image_size=224, blur_radius=0.0, faces_per_pixel=1):
173
  """
174
  use fixed raster_settings for rendering faces
@@ -189,8 +181,7 @@ class Pytorch3dRasterizer(nn.Module):
189
  def forward(self, vertices, faces, attributes=None):
190
  fixed_vertices = vertices.clone()
191
  fixed_vertices[..., :2] = -fixed_vertices[..., :2]
192
- meshes_screen = Meshes(verts=fixed_vertices.float(),
193
- faces=faces.long())
194
  raster_settings = self.raster_settings
195
  pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
196
  meshes_screen,
@@ -204,8 +195,9 @@ class Pytorch3dRasterizer(nn.Module):
204
  vismask = (pix_to_face > -1).float()
205
  D = attributes.shape[-1]
206
  attributes = attributes.clone()
207
- attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
208
- 3, attributes.shape[-1])
 
209
  N, H, W, K, _ = bary_coords.shape
210
  mask = pix_to_face == -1
211
  pix_to_face = pix_to_face.clone()
@@ -213,8 +205,7 @@ class Pytorch3dRasterizer(nn.Module):
213
  idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
214
  pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
215
  pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
216
- pixel_vals[mask] = 0 # Replace masked values in output.
217
  pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
218
- pixel_vals = torch.cat(
219
- [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
220
  return pixel_vals
 
25
  Tensor = NewType("Tensor", torch.Tensor)
26
 
27
 
28
+ def solid_angles(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor:
 
 
29
  """Compute solid angle between the input points and triangles
30
  Follows the method described in:
31
  The Solid Angle of a Plane Triangle
 
53
  norms = torch.norm(centered_tris, dim=-1)
54
 
55
  # Should be BxQxFx3
56
+ cross_prod = torch.cross(centered_tris[:, :, :, 1], centered_tris[:, :, :, 2], dim=-1)
 
 
57
  # Should be BxQxF
58
  numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
59
  del cross_prod
 
63
  dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
64
  del centered_tris
65
 
66
+ denominator = (
67
+ norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] + dot02 * norms[:, :, :, 1] +
68
+ dot12 * norms[:, :, :, 0]
69
+ )
70
  del dot01, dot12, dot02, norms
71
 
72
  # Should be BxQ
 
78
  return 2 * solid_angle
79
 
80
 
81
+ def winding_numbers(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor:
 
 
82
  """Uses winding_numbers to compute inside/outside
83
  Robust inside-outside segmentation using generalized winding numbers
84
  Alec Jacobson,
 
105
  """
106
  # The generalized winding number is the sum of solid angles of the point
107
  # with respect to all triangles.
108
+ return (1 / (4 * math.pi) * solid_angles(points, triangles, thresh=thresh).sum(dim=-1))
 
109
 
110
 
111
  def batch_contains(verts, faces, points):
 
119
  contains = torch.zeros(B, N)
120
 
121
  for i in range(B):
122
+ contains[i] = torch.as_tensor(trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
 
123
 
124
  return 2.0 * (contains - 0.5)
125
 
 
149
  bs, nv = vertices.shape[:2]
150
  bs, nf = faces.shape[:2]
151
  device = vertices.device
152
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
 
153
  vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
154
 
155
  return vertices[faces.long()]
 
161
  x,y,z are in image space, normalized
162
  can only render squared image now
163
  """
 
164
  def __init__(self, image_size=224, blur_radius=0.0, faces_per_pixel=1):
165
  """
166
  use fixed raster_settings for rendering faces
 
181
  def forward(self, vertices, faces, attributes=None):
182
  fixed_vertices = vertices.clone()
183
  fixed_vertices[..., :2] = -fixed_vertices[..., :2]
184
+ meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
 
185
  raster_settings = self.raster_settings
186
  pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
187
  meshes_screen,
 
195
  vismask = (pix_to_face > -1).float()
196
  D = attributes.shape[-1]
197
  attributes = attributes.clone()
198
+ attributes = attributes.view(
199
+ attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1]
200
+ )
201
  N, H, W, K, _ = bary_coords.shape
202
  mask = pix_to_face == -1
203
  pix_to_face = pix_to_face.clone()
 
205
  idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
206
  pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
207
  pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
208
+ pixel_vals[mask] = 0 # Replace masked values in output.
209
  pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
210
+ pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
 
211
  return pixel_vals
lib/common/seg3d_lossless.py CHANGED
@@ -31,7 +31,6 @@ logging.getLogger("lightning").setLevel(logging.ERROR)
31
 
32
 
33
  class Seg3dLossless(nn.Module):
34
-
35
  def __init__(
36
  self,
37
  query_func,
@@ -53,19 +52,14 @@ class Seg3dLossless(nn.Module):
53
  """
54
  super().__init__()
55
  self.query_func = query_func
56
- self.register_buffer(
57
- "b_min",
58
- torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
59
- self.register_buffer(
60
- "b_max",
61
- torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
62
 
63
  # ti.init(arch=ti.cuda)
64
  # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
65
 
66
  if type(resolutions[0]) is int:
67
- resolutions = torch.tensor([(res, res, res)
68
- for res in resolutions])
69
  else:
70
  resolutions = torch.tensor(resolutions)
71
  self.register_buffer("resolutions", resolutions)
@@ -87,45 +81,36 @@ class Seg3dLossless(nn.Module):
87
  ), f"resolution {resolution} need to be odd becuase of align_corner."
88
 
89
  # init first resolution
90
- init_coords = create_grid3D(0,
91
- resolutions[-1] - 1,
92
- steps=resolutions[0]) # [N, 3]
93
- init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
94
- 1) # [bz, N, 3]
95
  self.register_buffer("init_coords", init_coords)
96
 
97
  # some useful tensors
98
  calculated = torch.zeros(
99
- (self.resolutions[-1][2], self.resolutions[-1][1],
100
- self.resolutions[-1][0]),
101
  dtype=torch.bool,
102
  )
103
  self.register_buffer("calculated", calculated)
104
 
105
- gird8_offsets = (torch.stack(
106
- torch.meshgrid(
107
- [
108
- torch.tensor([-1, 0, 1]),
109
- torch.tensor([-1, 0, 1]),
110
- torch.tensor([-1, 0, 1]),
111
- ],
112
- indexing="ij",
113
- )).int().view(3, -1).t()) # [27, 3]
 
 
 
114
  self.register_buffer("gird8_offsets", gird8_offsets)
115
 
116
  # smooth convs
117
- self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
118
- out_channels=1,
119
- kernel_size=3)
120
- self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
121
- out_channels=1,
122
- kernel_size=5)
123
- self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
124
- out_channels=1,
125
- kernel_size=7)
126
- self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
127
- out_channels=1,
128
- kernel_size=9)
129
 
130
  @torch.no_grad()
131
  def batch_eval(self, coords, **kwargs):
@@ -144,7 +129,7 @@ class Seg3dLossless(nn.Module):
144
  # query function
145
  occupancys = self.query_func(**kwargs, points=coords2D)
146
  if type(occupancys) is list:
147
- occupancys = torch.stack(occupancys) # [bz, C, N]
148
  assert (
149
  len(occupancys.size()) == 3
150
  ), "query_func should return a occupancy with shape of [bz, C, N]"
@@ -175,10 +160,9 @@ class Seg3dLossless(nn.Module):
175
 
176
  # first step
177
  if torch.equal(resolution, self.resolutions[0]):
178
- coords = self.init_coords.clone() # torch.long
179
  occupancys = self.batch_eval(coords, **kwargs)
180
- occupancys = occupancys.view(self.batchsize, self.channels, D,
181
- H, W)
182
  if (occupancys > 0.5).sum() == 0:
183
  # return F.interpolate(
184
  # occupancys, size=(final_D, final_H, final_W),
@@ -239,23 +223,22 @@ class Seg3dLossless(nn.Module):
239
 
240
  with torch.no_grad():
241
  if torch.equal(resolution, self.resolutions[1]):
242
- is_boundary = (self.smooth_conv9x9(is_boundary.float())
243
- > 0)[0, 0]
244
  elif torch.equal(resolution, self.resolutions[2]):
245
- is_boundary = (self.smooth_conv7x7(is_boundary.float())
246
- > 0)[0, 0]
247
  else:
248
- is_boundary = (self.smooth_conv3x3(is_boundary.float())
249
- > 0)[0, 0]
250
 
251
  coords_accum = coords_accum.long()
252
  is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
253
  coords_accum[0, :, 0], ] = False
254
- point_coords = (is_boundary.permute(
255
- 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0))
256
- point_indices = (point_coords[:, :, 2] * H * W +
257
- point_coords[:, :, 1] * W +
258
- point_coords[:, :, 0])
 
 
259
 
260
  R, C, D, H, W = occupancys.shape
261
 
@@ -269,13 +252,15 @@ class Seg3dLossless(nn.Module):
269
  # put mask point predictions to the right places on the upsampled grid.
270
  R, C, D, H, W = occupancys.shape
271
  point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
272
- occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
273
- 2, point_indices, occupancys_topk).view(R, C, D, H, W))
 
 
 
274
 
275
  with torch.no_grad():
276
  voxels = coords / stride
277
- coords_accum = torch.cat([voxels, coords_accum],
278
- dim=1).unique(dim=1)
279
 
280
  return occupancys[0, 0]
281
 
@@ -300,18 +285,16 @@ class Seg3dLossless(nn.Module):
300
 
301
  # first step
302
  if torch.equal(resolution, self.resolutions[0]):
303
- coords = self.init_coords.clone() # torch.long
304
  occupancys = self.batch_eval(coords, **kwargs)
305
- occupancys = occupancys.view(self.batchsize, self.channels, D,
306
- H, W)
307
 
308
  if self.visualize:
309
  self.plot(occupancys, coords, final_D, final_H, final_W)
310
 
311
  with torch.no_grad():
312
  coords_accum = coords / stride
313
- calculated[coords[0, :, 2], coords[0, :, 1],
314
- coords[0, :, 0]] = True
315
 
316
  # next steps
317
  else:
@@ -338,35 +321,34 @@ class Seg3dLossless(nn.Module):
338
 
339
  with torch.no_grad():
340
  # TODO
341
- if self.use_shadow and torch.equal(resolution,
342
- self.resolutions[-1]):
343
  # larger z means smaller depth here
344
  depth_res = resolution[2].item()
345
- depth_index = torch.linspace(0,
346
- depth_res - 1,
347
- steps=depth_res).type_as(
348
- occupancys.device)
349
- depth_index_max = (torch.max(
350
- (occupancys > self.balance_value) *
351
- (depth_index + 1),
352
- dim=-1,
353
- keepdim=True,
354
- )[0] - 1)
355
  shadow = depth_index < depth_index_max
356
  is_boundary[shadow] = False
357
  is_boundary = is_boundary[0, 0]
358
  else:
359
- is_boundary = (self.smooth_conv3x3(is_boundary.float())
360
- > 0)[0, 0]
361
  # is_boundary = is_boundary[0, 0]
362
 
363
  is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
364
  coords_accum[0, :, 0], ] = False
365
- point_coords = (is_boundary.permute(
366
- 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0))
367
- point_indices = (point_coords[:, :, 2] * H * W +
368
- point_coords[:, :, 1] * W +
369
- point_coords[:, :, 0])
 
 
370
 
371
  R, C, D, H, W = occupancys.shape
372
  # interpolated value
@@ -388,28 +370,28 @@ class Seg3dLossless(nn.Module):
388
  # put mask point predictions to the right places on the upsampled grid.
389
  R, C, D, H, W = occupancys.shape
390
  point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
391
- occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
392
- 2, point_indices, occupancys_topk).view(R, C, D, H, W))
 
 
 
393
 
394
  with torch.no_grad():
395
  # conflicts
396
- conflicts = ((occupancys_interp - self.balance_value) *
397
- (occupancys_topk - self.balance_value) < 0)[0,
398
- 0]
 
399
 
400
  if self.visualize:
401
- self.plot(occupancys, coords, final_D, final_H,
402
- final_W)
403
 
404
  voxels = coords / stride
405
- coords_accum = torch.cat([voxels, coords_accum],
406
- dim=1).unique(dim=1)
407
- calculated[coords[0, :, 2], coords[0, :, 1],
408
- coords[0, :, 0]] = True
409
 
410
  while conflicts.sum() > 0:
411
- if self.use_shadow and torch.equal(resolution,
412
- self.resolutions[-1]):
413
  break
414
 
415
  with torch.no_grad():
@@ -426,25 +408,27 @@ class Seg3dLossless(nn.Module):
426
  )
427
 
428
  conflicts_boundary = (
429
- (conflicts_coords.int() +
430
- self.gird8_offsets.unsqueeze(1) *
431
- stride.int()).reshape(-1, 3).long().unique(dim=0))
432
- conflicts_boundary[:,
433
- 0] = conflicts_boundary[:, 0].clamp(
434
- 0,
435
- calculated.size(2) - 1)
436
- conflicts_boundary[:,
437
- 1] = conflicts_boundary[:, 1].clamp(
438
- 0,
439
- calculated.size(1) - 1)
440
- conflicts_boundary[:,
441
- 2] = conflicts_boundary[:, 2].clamp(
442
- 0,
443
- calculated.size(0) - 1)
444
-
445
- coords = conflicts_boundary[calculated[
446
- conflicts_boundary[:, 2], conflicts_boundary[:, 1],
447
- conflicts_boundary[:, 0], ] == False]
 
 
448
 
449
  if self.debug:
450
  self.plot(
@@ -458,9 +442,10 @@ class Seg3dLossless(nn.Module):
458
 
459
  coords = coords.unsqueeze(0)
460
  point_coords = coords / stride
461
- point_indices = (point_coords[:, :, 2] * H * W +
462
- point_coords[:, :, 1] * W +
463
- point_coords[:, :, 0])
 
464
 
465
  R, C, D, H, W = occupancys.shape
466
  # interpolated value
@@ -481,44 +466,37 @@ class Seg3dLossless(nn.Module):
481
 
482
  with torch.no_grad():
483
  # conflicts
484
- conflicts = ((occupancys_interp - self.balance_value) *
485
- (occupancys_topk - self.balance_value) <
486
- 0)[0, 0]
 
487
 
488
  # put mask point predictions to the right places on the upsampled grid.
489
- point_indices = point_indices.unsqueeze(1).expand(
490
- -1, C, -1)
491
- occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
492
- 2, point_indices, occupancys_topk).view(R, C, D, H, W))
 
 
493
 
494
  with torch.no_grad():
495
  voxels = coords / stride
496
- coords_accum = torch.cat([voxels, coords_accum],
497
- dim=1).unique(dim=1)
498
- calculated[coords[0, :, 2], coords[0, :, 1],
499
- coords[0, :, 0]] = True
500
 
501
  if self.visualize:
502
  this_stage_coords = torch.cat(this_stage_coords, dim=1)
503
- self.plot(occupancys, this_stage_coords, final_D, final_H,
504
- final_W)
505
 
506
  return occupancys[0, 0]
507
 
508
- def plot(self,
509
- occupancys,
510
- coords,
511
- final_D,
512
- final_H,
513
- final_W,
514
- title="",
515
- **kwargs):
516
  final = F.interpolate(
517
  occupancys.float(),
518
  size=(final_D, final_H, final_W),
519
  mode="trilinear",
520
  align_corners=True,
521
- ) # here true is correct!
522
  x = coords[0, :, 0].to("cpu")
523
  y = coords[0, :, 1].to("cpu")
524
  z = coords[0, :, 2].to("cpu")
@@ -548,20 +526,19 @@ class Seg3dLossless(nn.Module):
548
  sdf_all = sdf.permute(2, 1, 0)
549
 
550
  # shadow
551
- grad_v = (sdf_all > 0.5) * torch.linspace(
552
- resolution, 1, steps=resolution).to(sdf.device)
553
- grad_c = torch.ones_like(sdf_all) * torch.linspace(
554
- 0, resolution - 1, steps=resolution).to(sdf.device)
555
  max_v, max_c = grad_v.max(dim=2)
556
  shadow = grad_c > max_c.view(resolution, resolution, 1)
557
  keep = (sdf_all > 0.5) & (~shadow)
558
 
559
- p1 = keep.nonzero(as_tuple=False).t() # [3, N]
560
- p2 = p1.clone() # z
561
  p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
562
- p3 = p1.clone() # y
563
  p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
564
- p4 = p1.clone() # x
565
  p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
566
 
567
  v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
@@ -569,10 +546,10 @@ class Seg3dLossless(nn.Module):
569
  v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
570
  v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
571
 
572
- X = p1[0, :].long() # [N,]
573
- Y = p1[1, :].long() # [N,]
574
- Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * (
575
- v2 - 0.5) / (v2 - v1) # [N,]
576
  Z = Z.clamp(0, resolution)
577
 
578
  # normal
@@ -588,8 +565,7 @@ class Seg3dLossless(nn.Module):
588
 
589
  @torch.no_grad()
590
  def render_normal(self, resolution, X, Y, Z, norm):
591
- image = torch.ones((1, 3, resolution, resolution),
592
- dtype=torch.float32).to(norm.device)
593
  color = (norm + 1) / 2.0
594
  color = color.clamp(0, 1)
595
  image[0, :, Y, X] = color.t()
@@ -617,9 +593,9 @@ class Seg3dLossless(nn.Module):
617
  def export_mesh(self, occupancys):
618
 
619
  final = occupancys[1:, 1:, 1:].contiguous()
620
-
621
  verts, faces = marching_cubes(final.unsqueeze(0), isolevel=0.5)
622
  verts = verts[0].cpu().float()
623
- faces = faces[0].cpu().long()[:,[0,2,1]]
624
-
625
  return verts, faces
 
31
 
32
 
33
  class Seg3dLossless(nn.Module):
 
34
  def __init__(
35
  self,
36
  query_func,
 
52
  """
53
  super().__init__()
54
  self.query_func = query_func
55
+ self.register_buffer("b_min", torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
56
+ self.register_buffer("b_max", torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
 
 
 
 
57
 
58
  # ti.init(arch=ti.cuda)
59
  # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
60
 
61
  if type(resolutions[0]) is int:
62
+ resolutions = torch.tensor([(res, res, res) for res in resolutions])
 
63
  else:
64
  resolutions = torch.tensor(resolutions)
65
  self.register_buffer("resolutions", resolutions)
 
81
  ), f"resolution {resolution} need to be odd becuase of align_corner."
82
 
83
  # init first resolution
84
+ init_coords = create_grid3D(0, resolutions[-1] - 1, steps=resolutions[0]) # [N, 3]
85
+ init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1, 1) # [bz, N, 3]
 
 
 
86
  self.register_buffer("init_coords", init_coords)
87
 
88
  # some useful tensors
89
  calculated = torch.zeros(
90
+ (self.resolutions[-1][2], self.resolutions[-1][1], self.resolutions[-1][0]),
 
91
  dtype=torch.bool,
92
  )
93
  self.register_buffer("calculated", calculated)
94
 
95
+ gird8_offsets = (
96
+ torch.stack(
97
+ torch.meshgrid(
98
+ [
99
+ torch.tensor([-1, 0, 1]),
100
+ torch.tensor([-1, 0, 1]),
101
+ torch.tensor([-1, 0, 1]),
102
+ ],
103
+ indexing="ij",
104
+ )
105
+ ).int().view(3, -1).t()
106
+ ) # [27, 3]
107
  self.register_buffer("gird8_offsets", gird8_offsets)
108
 
109
  # smooth convs
110
+ self.smooth_conv3x3 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
111
+ self.smooth_conv5x5 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=5)
112
+ self.smooth_conv7x7 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=7)
113
+ self.smooth_conv9x9 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=9)
 
 
 
 
 
 
 
 
114
 
115
  @torch.no_grad()
116
  def batch_eval(self, coords, **kwargs):
 
129
  # query function
130
  occupancys = self.query_func(**kwargs, points=coords2D)
131
  if type(occupancys) is list:
132
+ occupancys = torch.stack(occupancys) # [bz, C, N]
133
  assert (
134
  len(occupancys.size()) == 3
135
  ), "query_func should return a occupancy with shape of [bz, C, N]"
 
160
 
161
  # first step
162
  if torch.equal(resolution, self.resolutions[0]):
163
+ coords = self.init_coords.clone() # torch.long
164
  occupancys = self.batch_eval(coords, **kwargs)
165
+ occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)
 
166
  if (occupancys > 0.5).sum() == 0:
167
  # return F.interpolate(
168
  # occupancys, size=(final_D, final_H, final_W),
 
223
 
224
  with torch.no_grad():
225
  if torch.equal(resolution, self.resolutions[1]):
226
+ is_boundary = (self.smooth_conv9x9(is_boundary.float()) > 0)[0, 0]
 
227
  elif torch.equal(resolution, self.resolutions[2]):
228
+ is_boundary = (self.smooth_conv7x7(is_boundary.float()) > 0)[0, 0]
 
229
  else:
230
+ is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]
 
231
 
232
  coords_accum = coords_accum.long()
233
  is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
234
  coords_accum[0, :, 0], ] = False
235
+ point_coords = (
236
+ is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
237
+ )
238
+ point_indices = (
239
+ point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
240
+ point_coords[:, :, 0]
241
+ )
242
 
243
  R, C, D, H, W = occupancys.shape
244
 
 
252
  # put mask point predictions to the right places on the upsampled grid.
253
  R, C, D, H, W = occupancys.shape
254
  point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
255
+ occupancys = (
256
+ occupancys.reshape(R, C,
257
+ D * H * W).scatter_(2, point_indices,
258
+ occupancys_topk).view(R, C, D, H, W)
259
+ )
260
 
261
  with torch.no_grad():
262
  voxels = coords / stride
263
+ coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
 
264
 
265
  return occupancys[0, 0]
266
 
 
285
 
286
  # first step
287
  if torch.equal(resolution, self.resolutions[0]):
288
+ coords = self.init_coords.clone() # torch.long
289
  occupancys = self.batch_eval(coords, **kwargs)
290
+ occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)
 
291
 
292
  if self.visualize:
293
  self.plot(occupancys, coords, final_D, final_H, final_W)
294
 
295
  with torch.no_grad():
296
  coords_accum = coords / stride
297
+ calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
 
298
 
299
  # next steps
300
  else:
 
321
 
322
  with torch.no_grad():
323
  # TODO
324
+ if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
 
325
  # larger z means smaller depth here
326
  depth_res = resolution[2].item()
327
+ depth_index = torch.linspace(0, depth_res - 1,
328
+ steps=depth_res).type_as(occupancys.device)
329
+ depth_index_max = (
330
+ torch.max(
331
+ (occupancys > self.balance_value) * (depth_index + 1),
332
+ dim=-1,
333
+ keepdim=True,
334
+ )[0] - 1
335
+ )
 
336
  shadow = depth_index < depth_index_max
337
  is_boundary[shadow] = False
338
  is_boundary = is_boundary[0, 0]
339
  else:
340
+ is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]
 
341
  # is_boundary = is_boundary[0, 0]
342
 
343
  is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
344
  coords_accum[0, :, 0], ] = False
345
+ point_coords = (
346
+ is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
347
+ )
348
+ point_indices = (
349
+ point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
350
+ point_coords[:, :, 0]
351
+ )
352
 
353
  R, C, D, H, W = occupancys.shape
354
  # interpolated value
 
370
  # put mask point predictions to the right places on the upsampled grid.
371
  R, C, D, H, W = occupancys.shape
372
  point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
373
+ occupancys = (
374
+ occupancys.reshape(R, C,
375
+ D * H * W).scatter_(2, point_indices,
376
+ occupancys_topk).view(R, C, D, H, W)
377
+ )
378
 
379
  with torch.no_grad():
380
  # conflicts
381
+ conflicts = (
382
+ (occupancys_interp - self.balance_value) *
383
+ (occupancys_topk - self.balance_value) < 0
384
+ )[0, 0]
385
 
386
  if self.visualize:
387
+ self.plot(occupancys, coords, final_D, final_H, final_W)
 
388
 
389
  voxels = coords / stride
390
+ coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
391
+ calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
 
 
392
 
393
  while conflicts.sum() > 0:
394
+ if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
 
395
  break
396
 
397
  with torch.no_grad():
 
408
  )
409
 
410
  conflicts_boundary = (
411
+ (
412
+ conflicts_coords.int() +
413
+ self.gird8_offsets.unsqueeze(1) * stride.int()
414
+ ).reshape(-1, 3).long().unique(dim=0)
415
+ )
416
+ conflicts_boundary[:, 0] = conflicts_boundary[:, 0].clamp(
417
+ 0,
418
+ calculated.size(2) - 1
419
+ )
420
+ conflicts_boundary[:, 1] = conflicts_boundary[:, 1].clamp(
421
+ 0,
422
+ calculated.size(1) - 1
423
+ )
424
+ conflicts_boundary[:, 2] = conflicts_boundary[:, 2].clamp(
425
+ 0,
426
+ calculated.size(0) - 1
427
+ )
428
+
429
+ coords = conflicts_boundary[calculated[conflicts_boundary[:, 2],
430
+ conflicts_boundary[:, 1],
431
+ conflicts_boundary[:, 0], ] == False]
432
 
433
  if self.debug:
434
  self.plot(
 
442
 
443
  coords = coords.unsqueeze(0)
444
  point_coords = coords / stride
445
+ point_indices = (
446
+ point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
447
+ point_coords[:, :, 0]
448
+ )
449
 
450
  R, C, D, H, W = occupancys.shape
451
  # interpolated value
 
466
 
467
  with torch.no_grad():
468
  # conflicts
469
+ conflicts = (
470
+ (occupancys_interp - self.balance_value) *
471
+ (occupancys_topk - self.balance_value) < 0
472
+ )[0, 0]
473
 
474
  # put mask point predictions to the right places on the upsampled grid.
475
+ point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
476
+ occupancys = (
477
+ occupancys.reshape(R, C,
478
+ D * H * W).scatter_(2, point_indices,
479
+ occupancys_topk).view(R, C, D, H, W)
480
+ )
481
 
482
  with torch.no_grad():
483
  voxels = coords / stride
484
+ coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
485
+ calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
 
 
486
 
487
  if self.visualize:
488
  this_stage_coords = torch.cat(this_stage_coords, dim=1)
489
+ self.plot(occupancys, this_stage_coords, final_D, final_H, final_W)
 
490
 
491
  return occupancys[0, 0]
492
 
493
+ def plot(self, occupancys, coords, final_D, final_H, final_W, title="", **kwargs):
 
 
 
 
 
 
 
494
  final = F.interpolate(
495
  occupancys.float(),
496
  size=(final_D, final_H, final_W),
497
  mode="trilinear",
498
  align_corners=True,
499
+ ) # here true is correct!
500
  x = coords[0, :, 0].to("cpu")
501
  y = coords[0, :, 1].to("cpu")
502
  z = coords[0, :, 2].to("cpu")
 
526
  sdf_all = sdf.permute(2, 1, 0)
527
 
528
  # shadow
529
+ grad_v = (sdf_all > 0.5) * torch.linspace(resolution, 1, steps=resolution).to(sdf.device)
530
+ grad_c = torch.ones_like(sdf_all) * torch.linspace(0, resolution - 1,
531
+ steps=resolution).to(sdf.device)
 
532
  max_v, max_c = grad_v.max(dim=2)
533
  shadow = grad_c > max_c.view(resolution, resolution, 1)
534
  keep = (sdf_all > 0.5) & (~shadow)
535
 
536
+ p1 = keep.nonzero(as_tuple=False).t() # [3, N]
537
+ p2 = p1.clone() # z
538
  p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
539
+ p3 = p1.clone() # y
540
  p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
541
+ p4 = p1.clone() # x
542
  p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
543
 
544
  v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
 
546
  v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
547
  v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
548
 
549
+ X = p1[0, :].long() # [N,]
550
+ Y = p1[1, :].long() # [N,]
551
+ Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * (v2 - 0.5
552
+ ) / (v2 - v1) # [N,]
553
  Z = Z.clamp(0, resolution)
554
 
555
  # normal
 
565
 
566
  @torch.no_grad()
567
  def render_normal(self, resolution, X, Y, Z, norm):
568
+ image = torch.ones((1, 3, resolution, resolution), dtype=torch.float32).to(norm.device)
 
569
  color = (norm + 1) / 2.0
570
  color = color.clamp(0, 1)
571
  image[0, :, Y, X] = color.t()
 
593
  def export_mesh(self, occupancys):
594
 
595
  final = occupancys[1:, 1:, 1:].contiguous()
596
+
597
  verts, faces = marching_cubes(final.unsqueeze(0), isolevel=0.5)
598
  verts = verts[0].cpu().float()
599
+ faces = faces[0].cpu().long()[:, [0, 2, 1]]
600
+
601
  return verts, faces
lib/common/seg3d_utils.py CHANGED
@@ -20,11 +20,7 @@ import torch.nn.functional as F
20
  import matplotlib.pyplot as plt
21
 
22
 
23
- def plot_mask2D(mask,
24
- title="",
25
- point_coords=None,
26
- figsize=10,
27
- point_marker_size=5):
28
  '''
29
  Simple plotting tool to show intermediate mask predictions and points
30
  where PointRend is applied.
@@ -46,26 +42,19 @@ def plot_mask2D(mask,
46
  plt.xlabel(W, fontsize=30)
47
  plt.xticks([], [])
48
  plt.yticks([], [])
49
- plt.imshow(mask.detach(),
50
- interpolation="nearest",
51
- cmap=plt.get_cmap('gray'))
52
  if point_coords is not None:
53
- plt.scatter(x=point_coords[0],
54
- y=point_coords[1],
55
- color="red",
56
- s=point_marker_size,
57
- clip_on=True)
58
  plt.xlim(-0.5, W - 0.5)
59
  plt.ylim(H - 0.5, -0.5)
60
  plt.show()
61
 
62
 
63
- def plot_mask3D(mask=None,
64
- title="",
65
- point_coords=None,
66
- figsize=1500,
67
- point_marker_size=8,
68
- interactive=True):
69
  '''
70
  Simple plotting tool to show intermediate mask predictions and points
71
  where PointRend is applied.
@@ -90,7 +79,8 @@ def plot_mask3D(mask=None,
90
 
91
  # marching cube to find surface
92
  verts, faces, normals, values = measure.marching_cubes_lewiner(
93
- mask, 0.5, gradient_direction='ascent')
 
94
 
95
  # create a mesh
96
  mesh = trimesh.Trimesh(verts, faces)
@@ -110,57 +100,49 @@ def plot_mask3D(mask=None,
110
  pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
111
  vis_list.append(pc)
112
 
113
- vp.show(*vis_list,
114
- bg="white",
115
- axes=1,
116
- interactive=interactive,
117
- azimuth=30,
118
- elevation=30)
119
 
120
 
121
  def create_grid3D(min, max, steps):
122
  if type(min) is int:
123
- min = (min, min, min) # (x, y, z)
124
  if type(max) is int:
125
- max = (max, max, max) # (x, y)
126
  if type(steps) is int:
127
- steps = (steps, steps, steps) # (x, y, z)
128
  arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
129
  arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
130
  arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
131
- gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX],
132
- indexing='ij')
133
- coords = torch.stack([gridW, girdH,
134
- gridD]) # [2, steps[0], steps[1], steps[2]]
135
- coords = coords.view(3, -1).t() # [N, 3]
136
  return coords
137
 
138
 
139
  def create_grid2D(min, max, steps):
140
  if type(min) is int:
141
- min = (min, min) # (x, y)
142
  if type(max) is int:
143
- max = (max, max) # (x, y)
144
  if type(steps) is int:
145
- steps = (steps, steps) # (x, y)
146
  arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
147
  arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
148
  girdH, gridW = torch.meshgrid([arrangeY, arrangeX], indexing='ij')
149
- coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
150
- coords = coords.view(2, -1).t() # [N, 2]
151
  return coords
152
 
153
 
154
  class SmoothConv2D(nn.Module):
155
-
156
  def __init__(self, in_channels, out_channels, kernel_size=3):
157
  super().__init__()
158
  assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
159
  self.padding = (kernel_size - 1) // 2
160
 
161
  weight = torch.ones(
162
- (in_channels, out_channels, kernel_size, kernel_size),
163
- dtype=torch.float32) / (kernel_size**2)
164
  self.register_buffer('weight', weight)
165
 
166
  def forward(self, input):
@@ -168,53 +150,49 @@ class SmoothConv2D(nn.Module):
168
 
169
 
170
  class SmoothConv3D(nn.Module):
171
-
172
  def __init__(self, in_channels, out_channels, kernel_size=3):
173
  super().__init__()
174
  assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
175
  self.padding = (kernel_size - 1) // 2
176
 
177
  weight = torch.ones(
178
- (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
179
- dtype=torch.float32) / (kernel_size**3)
180
  self.register_buffer('weight', weight)
181
 
182
  def forward(self, input):
183
  return F.conv3d(input, self.weight, padding=self.padding)
184
 
185
 
186
- def build_smooth_conv3D(in_channels=1,
187
- out_channels=1,
188
- kernel_size=3,
189
- padding=1):
190
- smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
191
- out_channels=out_channels,
192
- kernel_size=kernel_size,
193
- padding=padding)
194
  smooth_conv.weight.data = torch.ones(
195
- (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
196
- dtype=torch.float32) / (kernel_size**3)
197
  smooth_conv.bias.data = torch.zeros(out_channels)
198
  return smooth_conv
199
 
200
 
201
- def build_smooth_conv2D(in_channels=1,
202
- out_channels=1,
203
- kernel_size=3,
204
- padding=1):
205
- smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
206
- out_channels=out_channels,
207
- kernel_size=kernel_size,
208
- padding=padding)
209
  smooth_conv.weight.data = torch.ones(
210
- (in_channels, out_channels, kernel_size, kernel_size),
211
- dtype=torch.float32) / (kernel_size**2)
212
  smooth_conv.bias.data = torch.zeros(out_channels)
213
  return smooth_conv
214
 
215
 
216
- def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
217
- **kwargs):
218
  """
219
  Find `num_points` most uncertain points from `uncertainty_map` grid.
220
  Args:
@@ -233,28 +211,21 @@ def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
233
  # d_step = 1.0 / float(D)
234
 
235
  num_points = min(D * H * W, num_points)
236
- point_scores, point_indices = torch.topk(uncertainty_map.view(
237
- R, D * H * W),
238
- k=num_points,
239
- dim=1)
240
- point_coords = torch.zeros(R,
241
- num_points,
242
- 3,
243
- dtype=torch.float,
244
- device=uncertainty_map.device)
245
  # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
246
  # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
247
  # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
248
- point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
249
- point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
250
- point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
251
- print(f"resolution {D} x {H} x {W}", point_scores.min(),
252
- point_scores.max())
253
  return point_indices, point_coords
254
 
255
 
256
- def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
257
- clip_min):
258
  """
259
  Find `num_points` most uncertain points from `uncertainty_map` grid.
260
  Args:
@@ -276,28 +247,21 @@ def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
276
  uncertainty_map = uncertainty_map.view(D * H * W)
277
  indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
278
  num_points = min(num_points, indices.size(0))
279
- point_scores, point_indices = torch.topk(uncertainty_map[indices],
280
- k=num_points,
281
- dim=0)
282
  point_indices = indices[point_indices].unsqueeze(0)
283
 
284
- point_coords = torch.zeros(R,
285
- num_points,
286
- 3,
287
- dtype=torch.float,
288
- device=uncertainty_map.device)
289
  # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
290
  # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
291
  # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
292
- point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
293
- point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
294
- point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
295
  # print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
296
  return point_indices, point_coords
297
 
298
 
299
- def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
300
- **kwargs):
301
  """
302
  Find `num_points` most uncertain points from `uncertainty_map` grid.
303
  Args:
@@ -315,14 +279,8 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
315
  # w_step = 1.0 / float(W)
316
 
317
  num_points = min(H * W, num_points)
318
- point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
319
- k=num_points,
320
- dim=1)
321
- point_coords = torch.zeros(R,
322
- num_points,
323
- 2,
324
- dtype=torch.long,
325
- device=uncertainty_map.device)
326
  # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
327
  # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
328
  point_coords[:, :, 0] = (point_indices % W).to(torch.long)
@@ -331,8 +289,7 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
331
  return point_indices, point_coords
332
 
333
 
334
- def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
335
- clip_min):
336
  """
337
  Find `num_points` most uncertain points from `uncertainty_map` grid.
338
  Args:
@@ -353,16 +310,10 @@ def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
353
  uncertainty_map = uncertainty_map.view(H * W)
354
  indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
355
  num_points = min(num_points, indices.size(0))
356
- point_scores, point_indices = torch.topk(uncertainty_map[indices],
357
- k=num_points,
358
- dim=0)
359
  point_indices = indices[point_indices].unsqueeze(0)
360
 
361
- point_coords = torch.zeros(R,
362
- num_points,
363
- 2,
364
- dtype=torch.long,
365
- device=uncertainty_map.device)
366
  # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
367
  # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
368
  point_coords[:, :, 0] = (point_indices % W).to(torch.long)
@@ -388,7 +339,6 @@ def calculate_uncertainty(logits, classes=None, balance_value=0.5):
388
  if logits.shape[1] == 1:
389
  gt_class_logits = logits
390
  else:
391
- gt_class_logits = logits[
392
- torch.arange(logits.shape[0], device=logits.device),
393
- classes].unsqueeze(1)
394
  return -torch.abs(gt_class_logits - balance_value)
 
20
  import matplotlib.pyplot as plt
21
 
22
 
23
+ def plot_mask2D(mask, title="", point_coords=None, figsize=10, point_marker_size=5):
 
 
 
 
24
  '''
25
  Simple plotting tool to show intermediate mask predictions and points
26
  where PointRend is applied.
 
42
  plt.xlabel(W, fontsize=30)
43
  plt.xticks([], [])
44
  plt.yticks([], [])
45
+ plt.imshow(mask.detach(), interpolation="nearest", cmap=plt.get_cmap('gray'))
 
 
46
  if point_coords is not None:
47
+ plt.scatter(
48
+ x=point_coords[0], y=point_coords[1], color="red", s=point_marker_size, clip_on=True
49
+ )
 
 
50
  plt.xlim(-0.5, W - 0.5)
51
  plt.ylim(H - 0.5, -0.5)
52
  plt.show()
53
 
54
 
55
+ def plot_mask3D(
56
+ mask=None, title="", point_coords=None, figsize=1500, point_marker_size=8, interactive=True
57
+ ):
 
 
 
58
  '''
59
  Simple plotting tool to show intermediate mask predictions and points
60
  where PointRend is applied.
 
79
 
80
  # marching cube to find surface
81
  verts, faces, normals, values = measure.marching_cubes_lewiner(
82
+ mask, 0.5, gradient_direction='ascent'
83
+ )
84
 
85
  # create a mesh
86
  mesh = trimesh.Trimesh(verts, faces)
 
100
  pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
101
  vis_list.append(pc)
102
 
103
+ vp.show(*vis_list, bg="white", axes=1, interactive=interactive, azimuth=30, elevation=30)
 
 
 
 
 
104
 
105
 
106
  def create_grid3D(min, max, steps):
107
  if type(min) is int:
108
+ min = (min, min, min) # (x, y, z)
109
  if type(max) is int:
110
+ max = (max, max, max) # (x, y)
111
  if type(steps) is int:
112
+ steps = (steps, steps, steps) # (x, y, z)
113
  arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
114
  arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
115
  arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
116
+ gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX], indexing='ij')
117
+ coords = torch.stack([gridW, girdH, gridD]) # [2, steps[0], steps[1], steps[2]]
118
+ coords = coords.view(3, -1).t() # [N, 3]
 
 
119
  return coords
120
 
121
 
122
  def create_grid2D(min, max, steps):
123
  if type(min) is int:
124
+ min = (min, min) # (x, y)
125
  if type(max) is int:
126
+ max = (max, max) # (x, y)
127
  if type(steps) is int:
128
+ steps = (steps, steps) # (x, y)
129
  arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
130
  arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
131
  girdH, gridW = torch.meshgrid([arrangeY, arrangeX], indexing='ij')
132
+ coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
133
+ coords = coords.view(2, -1).t() # [N, 2]
134
  return coords
135
 
136
 
137
  class SmoothConv2D(nn.Module):
 
138
  def __init__(self, in_channels, out_channels, kernel_size=3):
139
  super().__init__()
140
  assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
141
  self.padding = (kernel_size - 1) // 2
142
 
143
  weight = torch.ones(
144
+ (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
145
+ ) / (kernel_size**2)
146
  self.register_buffer('weight', weight)
147
 
148
  def forward(self, input):
 
150
 
151
 
152
  class SmoothConv3D(nn.Module):
 
153
  def __init__(self, in_channels, out_channels, kernel_size=3):
154
  super().__init__()
155
  assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
156
  self.padding = (kernel_size - 1) // 2
157
 
158
  weight = torch.ones(
159
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32
160
+ ) / (kernel_size**3)
161
  self.register_buffer('weight', weight)
162
 
163
  def forward(self, input):
164
  return F.conv3d(input, self.weight, padding=self.padding)
165
 
166
 
167
+ def build_smooth_conv3D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
168
+ smooth_conv = torch.nn.Conv3d(
169
+ in_channels=in_channels,
170
+ out_channels=out_channels,
171
+ kernel_size=kernel_size,
172
+ padding=padding
173
+ )
 
174
  smooth_conv.weight.data = torch.ones(
175
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32
176
+ ) / (kernel_size**3)
177
  smooth_conv.bias.data = torch.zeros(out_channels)
178
  return smooth_conv
179
 
180
 
181
+ def build_smooth_conv2D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
182
+ smooth_conv = torch.nn.Conv2d(
183
+ in_channels=in_channels,
184
+ out_channels=out_channels,
185
+ kernel_size=kernel_size,
186
+ padding=padding
187
+ )
 
188
  smooth_conv.weight.data = torch.ones(
189
+ (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
190
+ ) / (kernel_size**2)
191
  smooth_conv.bias.data = torch.zeros(out_channels)
192
  return smooth_conv
193
 
194
 
195
+ def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, **kwargs):
 
196
  """
197
  Find `num_points` most uncertain points from `uncertainty_map` grid.
198
  Args:
 
211
  # d_step = 1.0 / float(D)
212
 
213
  num_points = min(D * H * W, num_points)
214
+ point_scores, point_indices = torch.topk(
215
+ uncertainty_map.view(R, D * H * W), k=num_points, dim=1
216
+ )
217
+ point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device)
 
 
 
 
 
218
  # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
219
  # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
220
  # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
221
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
222
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
223
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
224
+ print(f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
 
225
  return point_indices, point_coords
226
 
227
 
228
+ def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, clip_min):
 
229
  """
230
  Find `num_points` most uncertain points from `uncertainty_map` grid.
231
  Args:
 
247
  uncertainty_map = uncertainty_map.view(D * H * W)
248
  indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
249
  num_points = min(num_points, indices.size(0))
250
+ point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0)
 
 
251
  point_indices = indices[point_indices].unsqueeze(0)
252
 
253
+ point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device)
 
 
 
 
254
  # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
255
  # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
256
  # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
257
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
258
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
259
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
260
  # print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
261
  return point_indices, point_coords
262
 
263
 
264
+ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, **kwargs):
 
265
  """
266
  Find `num_points` most uncertain points from `uncertainty_map` grid.
267
  Args:
 
279
  # w_step = 1.0 / float(W)
280
 
281
  num_points = min(H * W, num_points)
282
+ point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)
283
+ point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device)
 
 
 
 
 
 
284
  # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
285
  # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
286
  point_coords[:, :, 0] = (point_indices % W).to(torch.long)
 
289
  return point_indices, point_coords
290
 
291
 
292
+ def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, clip_min):
 
293
  """
294
  Find `num_points` most uncertain points from `uncertainty_map` grid.
295
  Args:
 
310
  uncertainty_map = uncertainty_map.view(H * W)
311
  indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
312
  num_points = min(num_points, indices.size(0))
313
+ point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0)
 
 
314
  point_indices = indices[point_indices].unsqueeze(0)
315
 
316
+ point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device)
 
 
 
 
317
  # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
318
  # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
319
  point_coords[:, :, 0] = (point_indices % W).to(torch.long)
 
339
  if logits.shape[1] == 1:
340
  gt_class_logits = logits
341
  else:
342
+ gt_class_logits = logits[torch.arange(logits.shape[0], device=logits.device),
343
+ classes].unsqueeze(1)
 
344
  return -torch.abs(gt_class_logits - balance_value)
lib/common/train_util.py CHANGED
@@ -14,63 +14,62 @@
14
  #
15
  # Contact: ps-license@tuebingen.mpg.de
16
 
17
- import yaml
18
- import os.path as osp
19
  import torch
20
- import numpy as np
21
  from ..dataset.mesh_util import *
22
  from ..net.geometry import orthogonal
23
- import cv2, PIL
24
- from tqdm import tqdm
25
- import os
26
  from termcolor import colored
27
  import pytorch_lightning as pl
28
 
29
 
 
 
 
 
 
30
  def init_loss():
31
 
32
  losses = {
33
- # Cloth: Normal_recon - Normal_pred
34
  "cloth": {
35
  "weight": 1e3,
36
  "value": 0.0
37
  },
38
- # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
39
- "stiffness": {
40
  "weight": 1e5,
41
  "value": 0.0
42
  },
43
- # Cloth: det(R) = 1
44
  "rigid": {
45
  "weight": 1e5,
46
  "value": 0.0
47
  },
48
- # Cloth: edge length
49
  "edge": {
50
  "weight": 0,
51
  "value": 0.0
52
  },
53
- # Cloth: normal consistency
54
  "nc": {
55
  "weight": 0,
56
  "value": 0.0
57
  },
58
- # Cloth: laplacian smoonth
59
- "laplacian": {
60
  "weight": 1e2,
61
  "value": 0.0
62
  },
63
- # Body: Normal_pred - Normal_smpl
64
  "normal": {
65
  "weight": 1e0,
66
  "value": 0.0
67
  },
68
- # Body: Silhouette_pred - Silhouette_smpl
69
  "silhouette": {
70
  "weight": 1e0,
71
  "value": 0.0
72
  },
73
- # Joint: reprojected joints difference
74
  "joint": {
75
  "weight": 5e0,
76
  "value": 0.0
@@ -81,7 +80,6 @@ def init_loss():
81
 
82
 
83
  class SubTrainer(pl.Trainer):
84
-
85
  def save_checkpoint(self, filepath, weights_only=False):
86
  """Save model/training states as a checkpoint file through state-dump and file-write.
87
  Args:
@@ -101,214 +99,6 @@ class SubTrainer(pl.Trainer):
101
  pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
102
 
103
 
104
- def rename(old_dict, old_name, new_name):
105
- new_dict = {}
106
- for key, value in zip(old_dict.keys(), old_dict.values()):
107
- new_key = key if key != old_name else new_name
108
- new_dict[new_key] = old_dict[key]
109
- return new_dict
110
-
111
-
112
- def load_normal_networks(model, normal_path):
113
-
114
- pretrained_dict = torch.load(
115
- normal_path,
116
- map_location=model.device)["state_dict"]
117
- model_dict = model.state_dict()
118
-
119
- # 1. filter out unnecessary keys
120
- pretrained_dict = {
121
- k: v
122
- for k, v in pretrained_dict.items()
123
- if k in model_dict and v.shape == model_dict[k].shape
124
- }
125
-
126
- # # 2. overwrite entries in the existing state dict
127
- model_dict.update(pretrained_dict)
128
- # 3. load the new state dict
129
- model.load_state_dict(model_dict)
130
-
131
- del pretrained_dict
132
- del model_dict
133
-
134
- print(colored(f"Resume Normal weights from {normal_path}", "green"))
135
-
136
-
137
- def load_networks(model, mlp_path, normal_path=None):
138
-
139
- model_dict = model.state_dict()
140
- main_dict = {}
141
- normal_dict = {}
142
-
143
- # MLP part loading
144
- if os.path.exists(mlp_path) and mlp_path.endswith("ckpt"):
145
- main_dict = torch.load(
146
- mlp_path,
147
- map_location=model.device)["state_dict"]
148
-
149
- main_dict = {
150
- k: v
151
- for k, v in main_dict.items()
152
- if k in model_dict and v.shape == model_dict[k].shape and (
153
- "reconEngine" not in k) and ("normal_filter" not in k) and (
154
- "voxelization" not in k)
155
- }
156
- print(colored(f"Resume MLP weights from {mlp_path}", "green"))
157
-
158
- # normal network part loading
159
- if normal_path is not None and os.path.exists(normal_path) and normal_path.endswith("ckpt"):
160
- normal_dict = torch.load(
161
- normal_path,
162
- map_location=model.device)["state_dict"]
163
-
164
- for key in normal_dict.keys():
165
- normal_dict = rename(normal_dict, key,
166
- key.replace("netG", "netG.normal_filter"))
167
-
168
- normal_dict = {
169
- k: v
170
- for k, v in normal_dict.items()
171
- if k in model_dict and v.shape == model_dict[k].shape
172
- }
173
- print(colored(f"Resume normal model from {normal_path}", "green"))
174
-
175
- model_dict.update(main_dict)
176
- model_dict.update(normal_dict)
177
- model.load_state_dict(model_dict)
178
-
179
- # clean unused GPU memory
180
- del main_dict
181
- del normal_dict
182
- del model_dict
183
- torch.cuda.empty_cache()
184
-
185
-
186
- def reshape_sample_tensor(sample_tensor, num_views):
187
- if num_views == 1:
188
- return sample_tensor
189
- # Need to repeat sample_tensor along the batch dim num_views times
190
- sample_tensor = sample_tensor.unsqueeze(dim=1)
191
- sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
192
- sample_tensor = sample_tensor.view(
193
- sample_tensor.shape[0] * sample_tensor.shape[1],
194
- sample_tensor.shape[2],
195
- sample_tensor.shape[3],
196
- )
197
- return sample_tensor
198
-
199
-
200
- def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
201
- """Sets the learning rate to the initial LR decayed by schedule"""
202
- if epoch in schedule:
203
- lr *= gamma
204
- for param_group in optimizer.param_groups:
205
- param_group["lr"] = lr
206
- return lr
207
-
208
-
209
- def compute_acc(pred, gt, thresh=0.5):
210
- """
211
- return:
212
- IOU, precision, and recall
213
- """
214
- with torch.no_grad():
215
- vol_pred = pred > thresh
216
- vol_gt = gt > thresh
217
-
218
- union = vol_pred | vol_gt
219
- inter = vol_pred & vol_gt
220
-
221
- true_pos = inter.sum().float()
222
-
223
- union = union.sum().float()
224
- if union == 0:
225
- union = 1
226
- vol_pred = vol_pred.sum().float()
227
- if vol_pred == 0:
228
- vol_pred = 1
229
- vol_gt = vol_gt.sum().float()
230
- if vol_gt == 0:
231
- vol_gt = 1
232
- return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
233
-
234
- def calc_error(opt, net, cuda, dataset, num_tests):
235
- if num_tests > len(dataset):
236
- num_tests = len(dataset)
237
- with torch.no_grad():
238
- erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
239
- for idx in tqdm(range(num_tests)):
240
- data = dataset[idx * len(dataset) // num_tests]
241
- # retrieve the data
242
- image_tensor = data["img"].to(device=cuda)
243
- calib_tensor = data["calib"].to(device=cuda)
244
- sample_tensor = data["samples"].to(device=cuda).unsqueeze(0)
245
- if opt.num_views > 1:
246
- sample_tensor = reshape_sample_tensor(sample_tensor,
247
- opt.num_views)
248
- label_tensor = data["labels"].to(device=cuda).unsqueeze(0)
249
-
250
- res, error = net.forward(image_tensor,
251
- sample_tensor,
252
- calib_tensor,
253
- labels=label_tensor)
254
-
255
- IOU, prec, recall = compute_acc(res, label_tensor)
256
-
257
- # print(
258
- # '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
259
- # .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
260
- erorr_arr.append(error.item())
261
- IOU_arr.append(IOU.item())
262
- prec_arr.append(prec.item())
263
- recall_arr.append(recall.item())
264
-
265
- return (
266
- np.average(erorr_arr),
267
- np.average(IOU_arr),
268
- np.average(prec_arr),
269
- np.average(recall_arr),
270
- )
271
-
272
-
273
- def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
274
- if num_tests > len(dataset):
275
- num_tests = len(dataset)
276
- with torch.no_grad():
277
- error_color_arr = []
278
-
279
- for idx in tqdm(range(num_tests)):
280
- data = dataset[idx * len(dataset) // num_tests]
281
- # retrieve the data
282
- image_tensor = data["img"].to(device=cuda)
283
- calib_tensor = data["calib"].to(device=cuda)
284
- color_sample_tensor = data["color_samples"].to(
285
- device=cuda).unsqueeze(0)
286
-
287
- if opt.num_views > 1:
288
- color_sample_tensor = reshape_sample_tensor(
289
- color_sample_tensor, opt.num_views)
290
-
291
- rgb_tensor = data["rgbs"].to(device=cuda).unsqueeze(0)
292
-
293
- netG.filter(image_tensor)
294
- _, errorC = netC.forward(
295
- image_tensor,
296
- netG.get_im_feat(),
297
- color_sample_tensor,
298
- calib_tensor,
299
- labels=rgb_tensor,
300
- )
301
-
302
- # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
303
- # .format(idx, num_tests, errorG.item(), errorC.item()))
304
- error_color_arr.append(errorC.item())
305
-
306
- return np.average(error_color_arr)
307
-
308
-
309
- # pytorch lightning training related fucntions
310
-
311
-
312
  def query_func(opt, netG, features, points, proj_matrix=None):
313
  """
314
  - points: size of (bz, N, 3)
@@ -317,7 +107,7 @@ def query_func(opt, netG, features, points, proj_matrix=None):
317
  """
318
  assert len(points) == 1
319
  samples = points.repeat(opt.num_views, 1, 1)
320
- samples = samples.permute(0, 2, 1) # [bz, 3, N]
321
 
322
  # view specific query
323
  if proj_matrix is not None:
@@ -337,85 +127,25 @@ def query_func(opt, netG, features, points, proj_matrix=None):
337
 
338
  return preds
339
 
 
340
  def query_func_IF(batch, netG, points):
341
  """
342
  - points: size of (bz, N, 3)
343
  return: size of (bz, 1, N)
344
  """
345
-
346
  batch["samples_geo"] = points
347
  batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points)
348
-
349
  preds = netG(batch)
350
 
351
  return preds.unsqueeze(1)
352
 
353
 
354
- def isin(ar1, ar2):
355
- return (ar1[..., None] == ar2).any(-1)
356
-
357
-
358
- def in1d(ar1, ar2):
359
- mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
360
- mask[ar2.unique()] = True
361
- return mask[ar1]
362
-
363
  def batch_mean(res, key):
364
- return torch.stack([
365
- x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key])
366
- for x in res
367
- ]).mean()
368
-
369
-
370
- def tf_log_convert(log_dict):
371
- new_log_dict = log_dict.copy()
372
- for k, v in log_dict.items():
373
- new_log_dict[k.replace("_", "/")] = v
374
- del new_log_dict[k]
375
-
376
- return new_log_dict
377
-
378
-
379
- def bar_log_convert(log_dict, name=None, rot=None):
380
- from decimal import Decimal
381
-
382
- new_log_dict = {}
383
-
384
- if name is not None:
385
- new_log_dict["name"] = name[0]
386
- if rot is not None:
387
- new_log_dict["rot"] = rot[0]
388
-
389
- for k, v in log_dict.items():
390
- color = "yellow"
391
- if "loss" in k:
392
- color = "red"
393
- k = k.replace("loss", "L")
394
- elif "acc" in k:
395
- color = "green"
396
- k = k.replace("acc", "A")
397
- elif "iou" in k:
398
- color = "green"
399
- k = k.replace("iou", "I")
400
- elif "prec" in k:
401
- color = "green"
402
- k = k.replace("prec", "P")
403
- elif "recall" in k:
404
- color = "green"
405
- k = k.replace("recall", "R")
406
-
407
- if "lr" not in k:
408
- new_log_dict[colored(k.split("_")[1],
409
- color)] = colored(f"{v:.3f}", color)
410
- else:
411
- new_log_dict[colored(k.split("_")[1],
412
- color)] = colored(f"{Decimal(str(v)):.1E}",
413
- color)
414
-
415
- if "loss" in new_log_dict.keys():
416
- del new_log_dict["loss"]
417
-
418
- return new_log_dict
419
 
420
 
421
  def accumulate(outputs, rot_num, split):
@@ -430,160 +160,10 @@ def accumulate(outputs, rot_num, split):
430
  keyword = f"{dataset}/{metric}"
431
  if keyword not in hparam_log_dict.keys():
432
  hparam_log_dict[keyword] = 0
433
- for idx in range(split[dataset][0] * rot_num,
434
- split[dataset][1] * rot_num):
435
  hparam_log_dict[keyword] += outputs[idx][metric].item()
436
- hparam_log_dict[keyword] /= (split[dataset][1] -
437
- split[dataset][0]) * rot_num
438
 
439
  print(colored(hparam_log_dict, "green"))
440
 
441
  return hparam_log_dict
442
-
443
-
444
- def calc_error_N(outputs, targets):
445
- """calculate the error of normal (IGR)
446
-
447
- Args:
448
- outputs (torch.tensor): [B, 3, N]
449
- target (torch.tensor): [B, N, 3]
450
-
451
- # manifold loss and grad_loss in IGR paper
452
- grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
453
- normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
454
-
455
- Returns:
456
- torch.tensor: error of valid normals on the surface
457
- """
458
- # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
459
- outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
460
- targets = targets.reshape(-1, 3)[:, 2:3]
461
- with_normals = targets.sum(dim=1).abs() > 0.0
462
-
463
- # eikonal loss
464
- grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
465
- # normals loss
466
- normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()
467
-
468
- return grad_loss * 0.0 + normal_loss
469
-
470
-
471
- def calc_knn_acc(preds, carn_verts, labels, pick_num):
472
- """calculate knn accuracy
473
-
474
- Args:
475
- preds (torch.tensor): [B, 3, N]
476
- carn_verts (torch.tensor): [SMPLX_V_num, 3]
477
- labels (torch.tensor): [B, N_knn, N]
478
- """
479
- N_knn_full = labels.shape[1]
480
- preds = preds.permute(0, 2, 1).reshape(-1, 3)
481
- labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn]
482
- labels = labels[:, :pick_num]
483
-
484
- dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num]
485
- knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn]
486
- cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
487
- bool_col = torch.zeros_like(cat_mat)[:, 0]
488
- for i in range(pick_num * 2 - 1):
489
- bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
490
- acc = (bool_col > 0).sum() / len(bool_col)
491
-
492
- return acc
493
-
494
-
495
- def calc_acc_seg(output, target, num_multiseg):
496
- from pytorch_lightning.metrics import Accuracy
497
-
498
- return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
499
- target.flatten().cpu())
500
-
501
-
502
- def add_watermark(imgs, titles):
503
-
504
- # Write some Text
505
-
506
- font = cv2.FONT_HERSHEY_SIMPLEX
507
- bottomLeftCornerOfText = (350, 50)
508
- bottomRightCornerOfText = (800, 50)
509
- fontScale = 1
510
- fontColor = (1.0, 1.0, 1.0)
511
- lineType = 2
512
-
513
- for i in range(len(imgs)):
514
-
515
- title = titles[i + 1]
516
- cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
517
- fontColor, lineType)
518
-
519
- if i == 0:
520
- cv2.putText(
521
- imgs[i],
522
- str(titles[i][0]),
523
- bottomRightCornerOfText,
524
- font,
525
- fontScale,
526
- fontColor,
527
- lineType,
528
- )
529
-
530
- result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)
531
-
532
- return result
533
-
534
-
535
- def make_test_gif(img_dir):
536
-
537
- if img_dir is not None and len(os.listdir(img_dir)) > 0:
538
- for dataset in os.listdir(img_dir):
539
- for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
540
- img_lst = []
541
- im1 = None
542
- for file in sorted(
543
- os.listdir(osp.join(img_dir, dataset, subject))):
544
- if file[-3:] not in ["obj", "gif"]:
545
- img_path = os.path.join(img_dir, dataset, subject,
546
- file)
547
- if im1 == None:
548
- im1 = PIL.Image.open(img_path)
549
- else:
550
- img_lst.append(PIL.Image.open(img_path))
551
-
552
- print(os.path.join(img_dir, dataset, subject, "out.gif"))
553
- im1.save(
554
- os.path.join(img_dir, dataset, subject, "out.gif"),
555
- save_all=True,
556
- append_images=img_lst,
557
- duration=500,
558
- loop=0,
559
- )
560
-
561
-
562
- def export_cfg(logger, dir, cfg):
563
-
564
- cfg_export_file = osp.join(dir, f"cfg_{logger.version}.yaml")
565
-
566
- if not osp.exists(cfg_export_file):
567
- os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
568
- with open(cfg_export_file, "w+") as file:
569
- _ = yaml.dump(cfg, file)
570
-
571
-
572
- from yacs.config import CfgNode
573
-
574
- _VALID_TYPES = {tuple, list, str, int, float, bool}
575
-
576
-
577
- def convert_to_dict(cfg_node, key_list=[]):
578
- """ Convert a config node to dictionary """
579
- if not isinstance(cfg_node, CfgNode):
580
- if type(cfg_node) not in _VALID_TYPES:
581
- print(
582
- "Key {} with value {} is not a valid type; valid types: {}".
583
- format(".".join(key_list), type(cfg_node), _VALID_TYPES), )
584
- return cfg_node
585
- else:
586
- cfg_dict = dict(cfg_node)
587
- for k, v in cfg_dict.items():
588
- cfg_dict[k] = convert_to_dict(v, key_list + [k])
589
- return cfg_dict
 
14
  #
15
  # Contact: ps-license@tuebingen.mpg.de
16
 
 
 
17
  import torch
 
18
  from ..dataset.mesh_util import *
19
  from ..net.geometry import orthogonal
 
 
 
20
  from termcolor import colored
21
  import pytorch_lightning as pl
22
 
23
 
24
+ class Format:
25
+ end = '\033[0m'
26
+ start = '\033[4m'
27
+
28
+
29
  def init_loss():
30
 
31
  losses = {
32
+ # Cloth: chamfer distance
33
  "cloth": {
34
  "weight": 1e3,
35
  "value": 0.0
36
  },
37
+ # Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2)
38
+ "stiff": {
39
  "weight": 1e5,
40
  "value": 0.0
41
  },
42
+ # Cloth: det(R) = 1
43
  "rigid": {
44
  "weight": 1e5,
45
  "value": 0.0
46
  },
47
+ # Cloth: edge length
48
  "edge": {
49
  "weight": 0,
50
  "value": 0.0
51
  },
52
+ # Cloth: normal consistency
53
  "nc": {
54
  "weight": 0,
55
  "value": 0.0
56
  },
57
+ # Cloth: laplacian smoonth
58
+ "lapla": {
59
  "weight": 1e2,
60
  "value": 0.0
61
  },
62
+ # Body: Normal_pred - Normal_smpl
63
  "normal": {
64
  "weight": 1e0,
65
  "value": 0.0
66
  },
67
+ # Body: Silhouette_pred - Silhouette_smpl
68
  "silhouette": {
69
  "weight": 1e0,
70
  "value": 0.0
71
  },
72
+ # Joint: reprojected joints difference
73
  "joint": {
74
  "weight": 5e0,
75
  "value": 0.0
 
80
 
81
 
82
  class SubTrainer(pl.Trainer):
 
83
  def save_checkpoint(self, filepath, weights_only=False):
84
  """Save model/training states as a checkpoint file through state-dump and file-write.
85
  Args:
 
99
  pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def query_func(opt, netG, features, points, proj_matrix=None):
103
  """
104
  - points: size of (bz, N, 3)
 
107
  """
108
  assert len(points) == 1
109
  samples = points.repeat(opt.num_views, 1, 1)
110
+ samples = samples.permute(0, 2, 1) # [bz, 3, N]
111
 
112
  # view specific query
113
  if proj_matrix is not None:
 
127
 
128
  return preds
129
 
130
+
131
  def query_func_IF(batch, netG, points):
132
  """
133
  - points: size of (bz, N, 3)
134
  return: size of (bz, 1, N)
135
  """
136
+
137
  batch["samples_geo"] = points
138
  batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points)
139
+
140
  preds = netG(batch)
141
 
142
  return preds.unsqueeze(1)
143
 
144
 
 
 
 
 
 
 
 
 
 
145
  def batch_mean(res, key):
146
+ return torch.stack(
147
+ [x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res]
148
+ ).mean()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
  def accumulate(outputs, rot_num, split):
 
160
  keyword = f"{dataset}/{metric}"
161
  if keyword not in hparam_log_dict.keys():
162
  hparam_log_dict[keyword] = 0
163
+ for idx in range(split[dataset][0] * rot_num, split[dataset][1] * rot_num):
 
164
  hparam_log_dict[keyword] += outputs[idx][metric].item()
165
+ hparam_log_dict[keyword] /= (split[dataset][1] - split[dataset][0]) * rot_num
 
166
 
167
  print(colored(hparam_log_dict, "green"))
168
 
169
  return hparam_log_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/common/voxelize.py CHANGED
@@ -13,6 +13,7 @@ from lib.common.libmesh.inside_mesh import check_mesh_contains
13
 
14
  # From Occupancy Networks, Mescheder et. al. CVPR'19
15
 
 
16
  def make_3d_grid(bb_min, bb_max, shape):
17
  ''' Makes a 3D grid.
18
 
@@ -37,7 +38,7 @@ def make_3d_grid(bb_min, bb_max, shape):
37
 
38
  class VoxelGrid:
39
  def __init__(self, data, loc=(0., 0., 0.), scale=1):
40
- assert(data.shape[0] == data.shape[1] == data.shape[2])
41
  data = np.asarray(data, dtype=np.bool)
42
  loc = np.asarray(loc)
43
  self.data = data
@@ -53,7 +54,7 @@ class VoxelGrid:
53
 
54
  # Default scale, scales the mesh to [-0.45, 0.45]^3
55
  if scale is None:
56
- scale = (bounds[1] - bounds[0]).max()/0.9
57
 
58
  loc = np.asarray(loc)
59
  scale = float(scale)
@@ -61,7 +62,7 @@ class VoxelGrid:
61
  # Transform mesh
62
  mesh = mesh.copy()
63
  mesh.apply_translation(-loc)
64
- mesh.apply_scale(1/scale)
65
 
66
  # Apply method
67
  if method == 'ray':
@@ -75,7 +76,7 @@ class VoxelGrid:
75
  def down_sample(self, factor=2):
76
  if not (self.resolution % factor) == 0:
77
  raise ValueError('Resolution must be divisible by factor.')
78
- new_data = block_reduce(self.data, (factor,) * 3, np.max)
79
  return VoxelGrid(new_data, self.loc, self.scale)
80
 
81
  def to_mesh(self):
@@ -103,9 +104,9 @@ class VoxelGrid:
103
  f2 = f2_r | f2_l
104
  f3 = f3_r | f3_l
105
 
106
- assert(f1.shape == (nx + 1, ny, nz))
107
- assert(f2.shape == (nx, ny + 1, nz))
108
- assert(f3.shape == (nx, ny, nz + 1))
109
 
110
  # Determine if vertex present
111
  v = np.full(grid_shape, False)
@@ -146,53 +147,76 @@ class VoxelGrid:
146
  f2_r_x, f2_r_y, f2_r_z = np.where(f2_r)
147
  f3_r_x, f3_r_y, f3_r_z = np.where(f3_r)
148
 
149
- faces_1_l = np.stack([
150
- v_idx[f1_l_x, f1_l_y, f1_l_z],
151
- v_idx[f1_l_x, f1_l_y, f1_l_z + 1],
152
- v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1],
153
- v_idx[f1_l_x, f1_l_y + 1, f1_l_z],
154
- ], axis=1)
155
-
156
- faces_1_r = np.stack([
157
- v_idx[f1_r_x, f1_r_y, f1_r_z],
158
- v_idx[f1_r_x, f1_r_y + 1, f1_r_z],
159
- v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1],
160
- v_idx[f1_r_x, f1_r_y, f1_r_z + 1],
161
- ], axis=1)
162
-
163
- faces_2_l = np.stack([
164
- v_idx[f2_l_x, f2_l_y, f2_l_z],
165
- v_idx[f2_l_x + 1, f2_l_y, f2_l_z],
166
- v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1],
167
- v_idx[f2_l_x, f2_l_y, f2_l_z + 1],
168
- ], axis=1)
169
-
170
- faces_2_r = np.stack([
171
- v_idx[f2_r_x, f2_r_y, f2_r_z],
172
- v_idx[f2_r_x, f2_r_y, f2_r_z + 1],
173
- v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1],
174
- v_idx[f2_r_x + 1, f2_r_y, f2_r_z],
175
- ], axis=1)
176
-
177
- faces_3_l = np.stack([
178
- v_idx[f3_l_x, f3_l_y, f3_l_z],
179
- v_idx[f3_l_x, f3_l_y + 1, f3_l_z],
180
- v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z],
181
- v_idx[f3_l_x + 1, f3_l_y, f3_l_z],
182
- ], axis=1)
183
-
184
- faces_3_r = np.stack([
185
- v_idx[f3_r_x, f3_r_y, f3_r_z],
186
- v_idx[f3_r_x + 1, f3_r_y, f3_r_z],
187
- v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z],
188
- v_idx[f3_r_x, f3_r_y + 1, f3_r_z],
189
- ], axis=1)
190
-
191
- faces = np.concatenate([
192
- faces_1_l, faces_1_r,
193
- faces_2_l, faces_2_r,
194
- faces_3_l, faces_3_r,
195
- ], axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  vertices = self.loc + self.scale * vertices
198
  mesh = trimesh.Trimesh(vertices, faces, process=False)
@@ -200,7 +224,7 @@ class VoxelGrid:
200
 
201
  @property
202
  def resolution(self):
203
- assert(self.data.shape[0] == self.data.shape[1] == self.data.shape[2])
204
  return self.data.shape[0]
205
 
206
  def contains(self, points):
@@ -211,12 +235,9 @@ class VoxelGrid:
211
  # Discretize points to [0, nx-1]^3
212
  points_i = ((points + 0.5) * nx).astype(np.int32)
213
  # i1, i2, i3 have sizes (batch_size, T)
214
- i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2]
215
  # Only use indices inside bounding box
216
- mask = (
217
- (i1 >= 0) & (i2 >= 0) & (i3 >= 0)
218
- & (nx > i1) & (nx > i2) & (nx > i3)
219
- )
220
  # Prevent out of bounds error
221
  i1 = i1[mask]
222
  i2 = i2[mask]
@@ -254,7 +275,7 @@ def voxelize_surface(mesh, resolution):
254
  vertices = (vertices + 0.5) * resolution
255
 
256
  face_loc = vertices[faces]
257
- occ = np.full((resolution,) * 3, 0, dtype=np.int32)
258
  face_loc = face_loc.astype(np.float32)
259
 
260
  voxelize_mesh_(occ, face_loc)
@@ -264,9 +285,9 @@ def voxelize_surface(mesh, resolution):
264
 
265
 
266
  def voxelize_interior(mesh, resolution):
267
- shape = (resolution,) * 3
268
- bb_min = (0.5,) * 3
269
- bb_max = (resolution - 0.5,) * 3
270
  # Create points. Add noise to break symmetry
271
  points = make_3d_grid(bb_min, bb_max, shape=shape).numpy()
272
  points = points + 0.1 * (np.random.rand(*points.shape) - 0.5)
@@ -280,14 +301,9 @@ def check_voxel_occupied(occupancy_grid):
280
  occ = occupancy_grid
281
 
282
  occupied = (
283
- occ[..., :-1, :-1, :-1]
284
- & occ[..., :-1, :-1, 1:]
285
- & occ[..., :-1, 1:, :-1]
286
- & occ[..., :-1, 1:, 1:]
287
- & occ[..., 1:, :-1, :-1]
288
- & occ[..., 1:, :-1, 1:]
289
- & occ[..., 1:, 1:, :-1]
290
- & occ[..., 1:, 1:, 1:]
291
  )
292
  return occupied
293
 
@@ -296,14 +312,9 @@ def check_voxel_unoccupied(occupancy_grid):
296
  occ = occupancy_grid
297
 
298
  unoccupied = ~(
299
- occ[..., :-1, :-1, :-1]
300
- | occ[..., :-1, :-1, 1:]
301
- | occ[..., :-1, 1:, :-1]
302
- | occ[..., :-1, 1:, 1:]
303
- | occ[..., 1:, :-1, :-1]
304
- | occ[..., 1:, :-1, 1:]
305
- | occ[..., 1:, 1:, :-1]
306
- | occ[..., 1:, 1:, 1:]
307
  )
308
  return unoccupied
309
 
 
13
 
14
  # From Occupancy Networks, Mescheder et. al. CVPR'19
15
 
16
+
17
  def make_3d_grid(bb_min, bb_max, shape):
18
  ''' Makes a 3D grid.
19
 
 
38
 
39
  class VoxelGrid:
40
  def __init__(self, data, loc=(0., 0., 0.), scale=1):
41
+ assert (data.shape[0] == data.shape[1] == data.shape[2])
42
  data = np.asarray(data, dtype=np.bool)
43
  loc = np.asarray(loc)
44
  self.data = data
 
54
 
55
  # Default scale, scales the mesh to [-0.45, 0.45]^3
56
  if scale is None:
57
+ scale = (bounds[1] - bounds[0]).max() / 0.9
58
 
59
  loc = np.asarray(loc)
60
  scale = float(scale)
 
62
  # Transform mesh
63
  mesh = mesh.copy()
64
  mesh.apply_translation(-loc)
65
+ mesh.apply_scale(1 / scale)
66
 
67
  # Apply method
68
  if method == 'ray':
 
76
  def down_sample(self, factor=2):
77
  if not (self.resolution % factor) == 0:
78
  raise ValueError('Resolution must be divisible by factor.')
79
+ new_data = block_reduce(self.data, (factor, ) * 3, np.max)
80
  return VoxelGrid(new_data, self.loc, self.scale)
81
 
82
  def to_mesh(self):
 
104
  f2 = f2_r | f2_l
105
  f3 = f3_r | f3_l
106
 
107
+ assert (f1.shape == (nx + 1, ny, nz))
108
+ assert (f2.shape == (nx, ny + 1, nz))
109
+ assert (f3.shape == (nx, ny, nz + 1))
110
 
111
  # Determine if vertex present
112
  v = np.full(grid_shape, False)
 
147
  f2_r_x, f2_r_y, f2_r_z = np.where(f2_r)
148
  f3_r_x, f3_r_y, f3_r_z = np.where(f3_r)
149
 
150
+ faces_1_l = np.stack(
151
+ [
152
+ v_idx[f1_l_x, f1_l_y, f1_l_z],
153
+ v_idx[f1_l_x, f1_l_y, f1_l_z + 1],
154
+ v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1],
155
+ v_idx[f1_l_x, f1_l_y + 1, f1_l_z],
156
+ ],
157
+ axis=1
158
+ )
159
+
160
+ faces_1_r = np.stack(
161
+ [
162
+ v_idx[f1_r_x, f1_r_y, f1_r_z],
163
+ v_idx[f1_r_x, f1_r_y + 1, f1_r_z],
164
+ v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1],
165
+ v_idx[f1_r_x, f1_r_y, f1_r_z + 1],
166
+ ],
167
+ axis=1
168
+ )
169
+
170
+ faces_2_l = np.stack(
171
+ [
172
+ v_idx[f2_l_x, f2_l_y, f2_l_z],
173
+ v_idx[f2_l_x + 1, f2_l_y, f2_l_z],
174
+ v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1],
175
+ v_idx[f2_l_x, f2_l_y, f2_l_z + 1],
176
+ ],
177
+ axis=1
178
+ )
179
+
180
+ faces_2_r = np.stack(
181
+ [
182
+ v_idx[f2_r_x, f2_r_y, f2_r_z],
183
+ v_idx[f2_r_x, f2_r_y, f2_r_z + 1],
184
+ v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1],
185
+ v_idx[f2_r_x + 1, f2_r_y, f2_r_z],
186
+ ],
187
+ axis=1
188
+ )
189
+
190
+ faces_3_l = np.stack(
191
+ [
192
+ v_idx[f3_l_x, f3_l_y, f3_l_z],
193
+ v_idx[f3_l_x, f3_l_y + 1, f3_l_z],
194
+ v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z],
195
+ v_idx[f3_l_x + 1, f3_l_y, f3_l_z],
196
+ ],
197
+ axis=1
198
+ )
199
+
200
+ faces_3_r = np.stack(
201
+ [
202
+ v_idx[f3_r_x, f3_r_y, f3_r_z],
203
+ v_idx[f3_r_x + 1, f3_r_y, f3_r_z],
204
+ v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z],
205
+ v_idx[f3_r_x, f3_r_y + 1, f3_r_z],
206
+ ],
207
+ axis=1
208
+ )
209
+
210
+ faces = np.concatenate(
211
+ [
212
+ faces_1_l,
213
+ faces_1_r,
214
+ faces_2_l,
215
+ faces_2_r,
216
+ faces_3_l,
217
+ faces_3_r,
218
+ ], axis=0
219
+ )
220
 
221
  vertices = self.loc + self.scale * vertices
222
  mesh = trimesh.Trimesh(vertices, faces, process=False)
 
224
 
225
  @property
226
  def resolution(self):
227
+ assert (self.data.shape[0] == self.data.shape[1] == self.data.shape[2])
228
  return self.data.shape[0]
229
 
230
  def contains(self, points):
 
235
  # Discretize points to [0, nx-1]^3
236
  points_i = ((points + 0.5) * nx).astype(np.int32)
237
  # i1, i2, i3 have sizes (batch_size, T)
238
+ i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2]
239
  # Only use indices inside bounding box
240
+ mask = ((i1 >= 0) & (i2 >= 0) & (i3 >= 0) & (nx > i1) & (nx > i2) & (nx > i3))
 
 
 
241
  # Prevent out of bounds error
242
  i1 = i1[mask]
243
  i2 = i2[mask]
 
275
  vertices = (vertices + 0.5) * resolution
276
 
277
  face_loc = vertices[faces]
278
+ occ = np.full((resolution, ) * 3, 0, dtype=np.int32)
279
  face_loc = face_loc.astype(np.float32)
280
 
281
  voxelize_mesh_(occ, face_loc)
 
285
 
286
 
287
  def voxelize_interior(mesh, resolution):
288
+ shape = (resolution, ) * 3
289
+ bb_min = (0.5, ) * 3
290
+ bb_max = (resolution - 0.5, ) * 3
291
  # Create points. Add noise to break symmetry
292
  points = make_3d_grid(bb_min, bb_max, shape=shape).numpy()
293
  points = points + 0.1 * (np.random.rand(*points.shape) - 0.5)
 
301
  occ = occupancy_grid
302
 
303
  occupied = (
304
+ occ[..., :-1, :-1, :-1] & occ[..., :-1, :-1, 1:] & occ[..., :-1, 1:, :-1] &
305
+ occ[..., :-1, 1:, 1:] & occ[..., 1:, :-1, :-1] & occ[..., 1:, :-1, 1:] &
306
+ occ[..., 1:, 1:, :-1] & occ[..., 1:, 1:, 1:]
 
 
 
 
 
307
  )
308
  return occupied
309
 
 
312
  occ = occupancy_grid
313
 
314
  unoccupied = ~(
315
+ occ[..., :-1, :-1, :-1] | occ[..., :-1, :-1, 1:] | occ[..., :-1, 1:, :-1] |
316
+ occ[..., :-1, 1:, 1:] | occ[..., 1:, :-1, :-1] | occ[..., 1:, :-1, 1:] |
317
+ occ[..., 1:, 1:, :-1] | occ[..., 1:, 1:, 1:]
 
 
 
 
 
318
  )
319
  return unoccupied
320
 
lib/dataset/Evaluator.py CHANGED
@@ -37,7 +37,6 @@ class _PointFaceDistance(Function):
37
  """
38
  Torch autograd Function wrapper PointFaceDistance Cuda implementation
39
  """
40
-
41
  @staticmethod
42
  def forward(
43
  ctx,
@@ -92,12 +91,15 @@ class _PointFaceDistance(Function):
92
  grad_dists = grad_dists.contiguous()
93
  points, tris, idxs = ctx.saved_tensors
94
  min_triangle_area = ctx.min_triangle_area
95
- grad_points, grad_tris = _C.point_face_dist_backward(points, tris, idxs, grad_dists, min_triangle_area)
 
 
96
  return grad_points, None, grad_tris, None, None, None
97
 
98
 
99
- def _rand_barycentric_coords(size1, size2, dtype: torch.dtype,
100
- device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
101
  """
102
  Helper function to generate random barycentric coordinates which are uniformly
103
  distributed over a triangle.
@@ -167,19 +169,21 @@ def sample_points_from_meshes(meshes, num_samples: int = 10000):
167
  faces = meshes.faces_packed()
168
  mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
169
  num_meshes = len(meshes)
170
- num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.
171
 
172
  # Initialize samples tensor with fill value 0 for empty meshes.
173
  samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
174
 
175
  # Only compute samples for non empty meshes
176
  with torch.no_grad():
177
- areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero.
178
  max_faces = meshes.num_faces_per_mesh().max().item()
179
- areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces) # (N, F)
180
 
181
  # TODO (gkioxari) Confirm multinomial bug is not present with real data.
182
- samples_face_idxs = areas_padded.multinomial(num_samples, replacement=True) # (N, num_samples)
 
 
183
  samples_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
184
 
185
  # Randomly generate barycentric coords.
@@ -200,23 +204,25 @@ def point_mesh_distance(meshes, pcls, weighted=True):
200
  raise ValueError("meshes and pointclouds must be equal sized batches")
201
 
202
  # packed representation for pointclouds
203
- points = pcls.points_packed() # (P, 3)
204
  points_first_idx = pcls.cloud_to_packed_first_idx()
205
  max_points = pcls.num_points_per_cloud().max().item()
206
 
207
  # packed representation for faces
208
  verts_packed = meshes.verts_packed()
209
  faces_packed = meshes.faces_packed()
210
- tris = verts_packed[faces_packed] # (T, 3, 3)
211
  tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
212
 
213
  # point to face distance: shape (P,)
214
- point_to_face, idxs = _PointFaceDistance.apply(points, points_first_idx, tris, tris_first_idx, max_points, 5e-3)
 
 
215
 
216
  if weighted:
217
  # weight each example by the inverse of number of points in the example
218
- point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
219
- num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
220
  weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
221
  weights_p = 1.0 / weights_p.float()
222
  point_to_face = torch.sqrt(point_to_face) * weights_p
@@ -225,7 +231,6 @@ def point_mesh_distance(meshes, pcls, weighted=True):
225
 
226
 
227
  class Evaluator:
228
-
229
  def __init__(self, device):
230
 
231
  self.render = Render(size=512, device=device)
@@ -253,8 +258,8 @@ class Evaluator:
253
  self.render.meshes = self.tgt_mesh
254
  tgt_normal_imgs = self.render.get_image(cam_type="four", bg="black")
255
 
256
- src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
257
- tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
258
  src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True)
259
  tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True)
260
 
@@ -274,8 +279,11 @@ class Evaluator:
274
  # error_hf = ((((src_normal_arr - tgt_normal_arr) * sim_mask)**2).sum(dim=0).mean()) * 4.0
275
 
276
  normal_img = Image.fromarray(
277
- (torch.cat([src_normal_arr, tgt_normal_arr], dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0).astype(
278
- np.uint8))
 
 
 
279
  normal_img.save(normal_path)
280
 
281
  return error
@@ -291,7 +299,9 @@ class Evaluator:
291
  p2s_dist_all, _ = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
292
  p2s_dist = p2s_dist_all.sum()
293
 
294
- chamfer_dist = (point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist) * 0.5
 
 
295
 
296
  return chamfer_dist, p2s_dist
297
 
 
37
  """
38
  Torch autograd Function wrapper PointFaceDistance Cuda implementation
39
  """
 
40
  @staticmethod
41
  def forward(
42
  ctx,
 
91
  grad_dists = grad_dists.contiguous()
92
  points, tris, idxs = ctx.saved_tensors
93
  min_triangle_area = ctx.min_triangle_area
94
+ grad_points, grad_tris = _C.point_face_dist_backward(
95
+ points, tris, idxs, grad_dists, min_triangle_area
96
+ )
97
  return grad_points, None, grad_tris, None, None, None
98
 
99
 
100
+ def _rand_barycentric_coords(
101
+ size1, size2, dtype: torch.dtype, device: torch.device
102
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
103
  """
104
  Helper function to generate random barycentric coordinates which are uniformly
105
  distributed over a triangle.
 
169
  faces = meshes.faces_packed()
170
  mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
171
  num_meshes = len(meshes)
172
+ num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.
173
 
174
  # Initialize samples tensor with fill value 0 for empty meshes.
175
  samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
176
 
177
  # Only compute samples for non empty meshes
178
  with torch.no_grad():
179
+ areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero.
180
  max_faces = meshes.num_faces_per_mesh().max().item()
181
+ areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces) # (N, F)
182
 
183
  # TODO (gkioxari) Confirm multinomial bug is not present with real data.
184
+ samples_face_idxs = areas_padded.multinomial(
185
+ num_samples, replacement=True
186
+ ) # (N, num_samples)
187
  samples_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
188
 
189
  # Randomly generate barycentric coords.
 
204
  raise ValueError("meshes and pointclouds must be equal sized batches")
205
 
206
  # packed representation for pointclouds
207
+ points = pcls.points_packed() # (P, 3)
208
  points_first_idx = pcls.cloud_to_packed_first_idx()
209
  max_points = pcls.num_points_per_cloud().max().item()
210
 
211
  # packed representation for faces
212
  verts_packed = meshes.verts_packed()
213
  faces_packed = meshes.faces_packed()
214
+ tris = verts_packed[faces_packed] # (T, 3, 3)
215
  tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
216
 
217
  # point to face distance: shape (P,)
218
+ point_to_face, idxs = _PointFaceDistance.apply(
219
+ points, points_first_idx, tris, tris_first_idx, max_points, 5e-3
220
+ )
221
 
222
  if weighted:
223
  # weight each example by the inverse of number of points in the example
224
+ point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
225
+ num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
226
  weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
227
  weights_p = 1.0 / weights_p.float()
228
  point_to_face = torch.sqrt(point_to_face) * weights_p
 
231
 
232
 
233
  class Evaluator:
 
234
  def __init__(self, device):
235
 
236
  self.render = Render(size=512, device=device)
 
258
  self.render.meshes = self.tgt_mesh
259
  tgt_normal_imgs = self.render.get_image(cam_type="four", bg="black")
260
 
261
+ src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
262
+ tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1]
263
  src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True)
264
  tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True)
265
 
 
279
  # error_hf = ((((src_normal_arr - tgt_normal_arr) * sim_mask)**2).sum(dim=0).mean()) * 4.0
280
 
281
  normal_img = Image.fromarray(
282
+ (
283
+ torch.cat([src_normal_arr, tgt_normal_arr],
284
+ dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0
285
+ ).astype(np.uint8)
286
+ )
287
  normal_img.save(normal_path)
288
 
289
  return error
 
299
  p2s_dist_all, _ = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
300
  p2s_dist = p2s_dist_all.sum()
301
 
302
+ chamfer_dist = (
303
+ point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist
304
+ ) * 0.5
305
 
306
  return chamfer_dist, p2s_dist
307
 
lib/dataset/NormalDataset.py CHANGED
@@ -23,7 +23,6 @@ import torchvision.transforms as transforms
23
 
24
 
25
  class NormalDataset:
26
-
27
  def __init__(self, cfg, split="train"):
28
 
29
  self.split = split
@@ -44,8 +43,7 @@ class NormalDataset:
44
  if self.split != "train":
45
  self.rotations = range(0, 360, 120)
46
  else:
47
- self.rotations = np.arange(0, 360, 360 //
48
- self.opt.rotation_num).astype(np.int)
49
 
50
  self.datasets_dict = {}
51
 
@@ -54,26 +52,29 @@ class NormalDataset:
54
  dataset_dir = osp.join(self.root, dataset)
55
 
56
  self.datasets_dict[dataset] = {
57
- "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"),
58
- dtype=str),
59
  "scale": self.scales[dataset_id],
60
  }
61
 
62
  self.subject_list = self.get_subject_list(split)
63
 
64
  # PIL to tensor
65
- self.image_to_tensor = transforms.Compose([
66
- transforms.Resize(self.input_size),
67
- transforms.ToTensor(),
68
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
69
- ])
 
 
70
 
71
  # PIL to tensor
72
- self.mask_to_tensor = transforms.Compose([
73
- transforms.Resize(self.input_size),
74
- transforms.ToTensor(),
75
- transforms.Normalize((0.0, ), (1.0, )),
76
- ])
 
 
77
 
78
  def get_subject_list(self, split):
79
 
@@ -88,16 +89,12 @@ class NormalDataset:
88
  subject_list += np.loadtxt(split_txt, dtype=str).tolist()
89
 
90
  if self.split != "test":
91
- subject_list += subject_list[:self.bsize -
92
- len(subject_list) % self.bsize]
93
  print(colored(f"total: {len(subject_list)}", "yellow"))
94
 
95
- bug_list = sorted(
96
- np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())
97
 
98
- subject_list = [
99
- subject for subject in subject_list if (subject not in bug_list)
100
- ]
101
 
102
  # subject_list = ["thuman2/0008"]
103
  return subject_list
@@ -113,48 +110,41 @@ class NormalDataset:
113
  rotation = self.rotations[rid]
114
  subject = self.subject_list[mid].split("/")[1]
115
  dataset = self.subject_list[mid].split("/")[0]
116
- render_folder = "/".join(
117
- [dataset + f"_{self.opt.rotation_num}views", subject])
118
 
119
  if not osp.exists(osp.join(self.root, render_folder)):
120
  render_folder = "/".join([dataset + f"_36views", subject])
121
 
122
  # setup paths
123
  data_dict = {
124
- "dataset":
125
- dataset,
126
- "subject":
127
- subject,
128
- "rotation":
129
- rotation,
130
- "scale":
131
- self.datasets_dict[dataset]["scale"],
132
- "image_path":
133
- osp.join(self.root, render_folder, "render",
134
- f"{rotation:03d}.png"),
135
  }
136
 
137
  # image/normal/depth loader
138
  for name, channel in zip(self.in_total, self.in_total_dim):
139
 
140
  if f"{name}_path" not in data_dict.keys():
141
- data_dict.update({
142
- f"{name}_path":
143
- osp.join(self.root, render_folder, name,
144
- f"{rotation:03d}.png")
145
- })
146
-
147
- data_dict.update({
148
- name:
149
- self.imagepath2tensor(data_dict[f"{name}_path"],
150
- channel,
151
- inv=False,
152
- erasing=False)
153
- })
154
-
155
- path_keys = [
156
- key for key in data_dict.keys() if "_path" in key or "_dir" in key
157
- ]
158
 
159
  for key in path_keys:
160
  del data_dict[key]
@@ -172,10 +162,9 @@ class NormalDataset:
172
 
173
  # simulate occlusion
174
  if erasing:
175
- mask = kornia.augmentation.RandomErasing(p=0.2,
176
- scale=(0.01, 0.2),
177
- ratio=(0.3, 3.3),
178
- keepdim=True)(mask)
179
  image = (image * mask)[:channel]
180
 
181
  return (image * (0.5 - inv) * 2.0).float()
 
23
 
24
 
25
  class NormalDataset:
 
26
  def __init__(self, cfg, split="train"):
27
 
28
  self.split = split
 
43
  if self.split != "train":
44
  self.rotations = range(0, 360, 120)
45
  else:
46
+ self.rotations = np.arange(0, 360, 360 // self.opt.rotation_num).astype(np.int)
 
47
 
48
  self.datasets_dict = {}
49
 
 
52
  dataset_dir = osp.join(self.root, dataset)
53
 
54
  self.datasets_dict[dataset] = {
55
+ "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
 
56
  "scale": self.scales[dataset_id],
57
  }
58
 
59
  self.subject_list = self.get_subject_list(split)
60
 
61
  # PIL to tensor
62
+ self.image_to_tensor = transforms.Compose(
63
+ [
64
+ transforms.Resize(self.input_size),
65
+ transforms.ToTensor(),
66
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
67
+ ]
68
+ )
69
 
70
  # PIL to tensor
71
+ self.mask_to_tensor = transforms.Compose(
72
+ [
73
+ transforms.Resize(self.input_size),
74
+ transforms.ToTensor(),
75
+ transforms.Normalize((0.0, ), (1.0, )),
76
+ ]
77
+ )
78
 
79
  def get_subject_list(self, split):
80
 
 
89
  subject_list += np.loadtxt(split_txt, dtype=str).tolist()
90
 
91
  if self.split != "test":
92
+ subject_list += subject_list[:self.bsize - len(subject_list) % self.bsize]
 
93
  print(colored(f"total: {len(subject_list)}", "yellow"))
94
 
95
+ bug_list = sorted(np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())
 
96
 
97
+ subject_list = [subject for subject in subject_list if (subject not in bug_list)]
 
 
98
 
99
  # subject_list = ["thuman2/0008"]
100
  return subject_list
 
110
  rotation = self.rotations[rid]
111
  subject = self.subject_list[mid].split("/")[1]
112
  dataset = self.subject_list[mid].split("/")[0]
113
+ render_folder = "/".join([dataset + f"_{self.opt.rotation_num}views", subject])
 
114
 
115
  if not osp.exists(osp.join(self.root, render_folder)):
116
  render_folder = "/".join([dataset + f"_36views", subject])
117
 
118
  # setup paths
119
  data_dict = {
120
+ "dataset": dataset,
121
+ "subject": subject,
122
+ "rotation": rotation,
123
+ "scale": self.datasets_dict[dataset]["scale"],
124
+ "image_path": osp.join(self.root, render_folder, "render", f"{rotation:03d}.png"),
 
 
 
 
 
 
125
  }
126
 
127
  # image/normal/depth loader
128
  for name, channel in zip(self.in_total, self.in_total_dim):
129
 
130
  if f"{name}_path" not in data_dict.keys():
131
+ data_dict.update(
132
+ {
133
+ f"{name}_path":
134
+ osp.join(self.root, render_folder, name, f"{rotation:03d}.png")
135
+ }
136
+ )
137
+
138
+ data_dict.update(
139
+ {
140
+ name:
141
+ self.imagepath2tensor(
142
+ data_dict[f"{name}_path"], channel, inv=False, erasing=False
143
+ )
144
+ }
145
+ )
146
+
147
+ path_keys = [key for key in data_dict.keys() if "_path" in key or "_dir" in key]
148
 
149
  for key in path_keys:
150
  del data_dict[key]
 
162
 
163
  # simulate occlusion
164
  if erasing:
165
+ mask = kornia.augmentation.RandomErasing(
166
+ p=0.2, scale=(0.01, 0.2), ratio=(0.3, 3.3), keepdim=True
167
+ )(mask)
 
168
  image = (image * mask)[:channel]
169
 
170
  return (image * (0.5 - inv) * 2.0).float()
lib/dataset/NormalModule.py CHANGED
@@ -22,7 +22,6 @@ import pytorch_lightning as pl
22
 
23
 
24
  class NormalModule(pl.LightningDataModule):
25
-
26
  def __init__(self, cfg):
27
  super(NormalModule, self).__init__()
28
  self.cfg = cfg
@@ -40,7 +39,7 @@ class NormalModule(pl.LightningDataModule):
40
  self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
41
  self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
42
  self.test_dataset = NormalDataset(cfg=self.cfg, split="test")
43
-
44
  self.data_size = {
45
  "train": len(self.train_dataset),
46
  "val": len(self.val_dataset),
@@ -69,7 +68,7 @@ class NormalModule(pl.LightningDataModule):
69
  )
70
 
71
  return val_data_loader
72
-
73
  def val_dataloader(self):
74
 
75
  test_data_loader = DataLoader(
 
22
 
23
 
24
  class NormalModule(pl.LightningDataModule):
 
25
  def __init__(self, cfg):
26
  super(NormalModule, self).__init__()
27
  self.cfg = cfg
 
39
  self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
40
  self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
41
  self.test_dataset = NormalDataset(cfg=self.cfg, split="test")
42
+
43
  self.data_size = {
44
  "train": len(self.train_dataset),
45
  "val": len(self.val_dataset),
 
68
  )
69
 
70
  return val_data_loader
71
+
72
  def val_dataloader(self):
73
 
74
  test_data_loader = DataLoader(
lib/dataset/PointFeat.py CHANGED
@@ -6,7 +6,6 @@ from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection
6
 
7
 
8
  class PointFeat:
9
-
10
  def __init__(self, verts, faces):
11
 
12
  # verts [B, N_vert, 3]
@@ -23,7 +22,10 @@ class PointFeat:
23
 
24
  if verts.shape[1] == 10475:
25
  faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask]
26
- mouth_faces = (torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1, 1).to(self.device))
 
 
 
27
  self.faces = torch.cat([faces, mouth_faces], dim=1).long()
28
 
29
  self.verts = verts.float()
@@ -35,11 +37,15 @@ class PointFeat:
35
  points = points.float()
36
  residues, pts_ind = point_mesh_distance(self.mesh, Pointclouds(points), weighted=False)
37
 
38
- closest_triangles = torch.gather(self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
 
 
39
  bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles)
40
 
41
  feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
42
- closest_normals = torch.gather(feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
 
 
43
  shoot_verts = ((closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0))
44
 
45
  pts2shoot_normals = points - shoot_verts
 
6
 
7
 
8
  class PointFeat:
 
9
  def __init__(self, verts, faces):
10
 
11
  # verts [B, N_vert, 3]
 
22
 
23
  if verts.shape[1] == 10475:
24
  faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask]
25
+ mouth_faces = (
26
+ torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1,
27
+ 1).to(self.device)
28
+ )
29
  self.faces = torch.cat([faces, mouth_faces], dim=1).long()
30
 
31
  self.verts = verts.float()
 
37
  points = points.float()
38
  residues, pts_ind = point_mesh_distance(self.mesh, Pointclouds(points), weighted=False)
39
 
40
+ closest_triangles = torch.gather(
41
+ self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
42
+ ).view(-1, 3, 3)
43
  bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles)
44
 
45
  feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
46
+ closest_normals = torch.gather(
47
+ feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
48
+ ).view(-1, 3, 3)
49
  shoot_verts = ((closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0))
50
 
51
  pts2shoot_normals = points - shoot_verts
lib/dataset/TestDataset.py CHANGED
@@ -25,6 +25,7 @@ from lib.pixielib.utils.config import cfg as pixie_cfg
25
  from lib.pixielib.pixie import PIXIE
26
  from lib.pixielib.models.SMPLX import SMPLX as PIXIE_SMPLX
27
  from lib.common.imutils import process_image
 
28
  from lib.net.geometry import rotation_matrix_to_angle_axis, rot6d_to_rotmat
29
 
30
  from lib.pymafx.core import path_config
@@ -36,8 +37,9 @@ from lib.dataset.body_model import TetraSMPLModel
36
  from lib.dataset.mesh_util import get_visibility, SMPLX
37
  import torch.nn.functional as F
38
  from torchvision import transforms
 
 
39
  import os.path as osp
40
- import os
41
  import torch
42
  import glob
43
  import numpy as np
@@ -48,7 +50,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
48
 
49
 
50
  class TestDataset:
51
-
52
  def __init__(self, cfg, device):
53
 
54
  self.image_dir = cfg["image_dir"]
@@ -65,7 +66,9 @@ class TestDataset:
65
  keep_lst = sorted(glob.glob(f"{self.image_dir}/*"))
66
  img_fmts = ["jpg", "png", "jpeg", "JPG", "bmp"]
67
 
68
- self.subject_list = sorted([item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False)
 
 
69
 
70
  # smpl related
71
  self.smpl_data = SMPLX()
@@ -80,7 +83,16 @@ class TestDataset:
80
 
81
  self.smpl_model = PIXIE_SMPLX(pixie_cfg.model).to(self.device)
82
 
83
- print(colored(f"Use {self.hps_type.upper()} to estimate human pose and shape", "green"))
 
 
 
 
 
 
 
 
 
84
 
85
  self.render = Render(size=512, device=self.device)
86
 
@@ -90,7 +102,9 @@ class TestDataset:
90
  def compute_vis_cmap(self, smpl_verts, smpl_faces):
91
 
92
  (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=-1)
93
- smpl_vis = get_visibility(xy, z, torch.as_tensor(smpl_faces).long()[:, :, [0, 2, 1]]).unsqueeze(-1)
 
 
94
  smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type).unsqueeze(0)
95
 
96
  return {
@@ -109,7 +123,8 @@ class TestDataset:
109
  depth_FB[:, ~depth_mask[0]] = 0.
110
 
111
  # Important: index_long = depth_value - 1
112
- index_z = (((depth_FB + 1.) * 0.5 * self.vol_res) - 1).clip(0, self.vol_res - 1).permute(1, 2, 0)
 
113
  index_z_ceil = torch.ceil(index_z).long()
114
  index_z_floor = torch.floor(index_z).long()
115
  index_z_frac = torch.frac(index_z)
@@ -121,7 +136,7 @@ class TestDataset:
121
  F.one_hot(index_z_floor[..., 1], self.vol_res) * (1.0 - index_z_frac[..., 1])
122
 
123
  voxels[index_mask] *= 0
124
- voxels = torch.flip(voxels, [2]).permute(2, 0, 1).float() #[x-2, y-0, z-1]
125
 
126
  return {
127
  "depth_voxels": voxels.flip([
@@ -139,18 +154,25 @@ class TestDataset:
139
  smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])
140
 
141
  verts = (
142
- np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * scale.item() + trans.detach().cpu().numpy())
 
 
143
  faces = (
144
  np.loadtxt(
145
  osp.join(self.smpl_data.tedra_dir, "tetrahedrons_neutral_adult.txt"),
146
  dtype=np.int32,
147
- ) - 1)
 
148
 
149
  pad_v_num = int(8000 - verts.shape[0])
150
  pad_f_num = int(25100 - faces.shape[0])
151
 
152
- verts = (np.pad(verts, ((0, pad_v_num), (0, 0)), mode="constant", constant_values=0.0).astype(np.float32) * 0.5)
153
- faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant", constant_values=0.0).astype(np.int32)
 
 
 
 
154
 
155
  verts[:, 2] *= -1.0
156
 
@@ -168,7 +190,7 @@ class TestDataset:
168
  img_path = self.subject_list[index]
169
  img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
170
 
171
- arr_dict = process_image(img_path, self.hps_type, self.single, 512)
172
  arr_dict.update({"name": img_name})
173
 
174
  with torch.no_grad():
@@ -179,7 +201,10 @@ class TestDataset:
179
  preds_dict, _ = self.hps.forward(batch)
180
 
181
  arr_dict["smpl_faces"] = (
182
- torch.as_tensor(self.smpl_data.smplx_faces.astype(np.int64)).unsqueeze(0).long().to(self.device))
 
 
 
183
  arr_dict["type"] = self.smpl_type
184
 
185
  if self.hps_type == "pymafx":
@@ -198,13 +223,16 @@ class TestDataset:
198
  elif self.hps_type == "pixie":
199
  arr_dict.update(preds_dict)
200
  arr_dict["global_orient"] = preds_dict["global_pose"]
201
- arr_dict["betas"] = preds_dict["shape"] #200
202
  arr_dict["smpl_verts"] = preds_dict["vertices"]
203
  scale, tranX, tranY = preds_dict["cam"].split(1, dim=1)
204
  # 1.1435, 0.0128, 0.3520
205
 
206
  arr_dict["scale"] = scale.unsqueeze(1)
207
- arr_dict["trans"] = (torch.cat([tranX, tranY, torch.zeros_like(tranX)], dim=1).unsqueeze(1).to(self.device).float())
 
 
 
208
 
209
  # data_dict info (key-shape):
210
  # scale, tranX, tranY - tensor.float
@@ -230,4 +258,4 @@ class TestDataset:
230
 
231
  # render optimized mesh (normal, T_normal, image [-1,1])
232
  self.render.load_meshes(verts, faces)
233
- return self.render.get_image(type="depth")
 
25
  from lib.pixielib.pixie import PIXIE
26
  from lib.pixielib.models.SMPLX import SMPLX as PIXIE_SMPLX
27
  from lib.common.imutils import process_image
28
+ from lib.common.train_util import Format
29
  from lib.net.geometry import rotation_matrix_to_angle_axis, rot6d_to_rotmat
30
 
31
  from lib.pymafx.core import path_config
 
37
  from lib.dataset.mesh_util import get_visibility, SMPLX
38
  import torch.nn.functional as F
39
  from torchvision import transforms
40
+ from torchvision.models import detection
41
+
42
  import os.path as osp
 
43
  import torch
44
  import glob
45
  import numpy as np
 
50
 
51
 
52
  class TestDataset:
 
53
  def __init__(self, cfg, device):
54
 
55
  self.image_dir = cfg["image_dir"]
 
66
  keep_lst = sorted(glob.glob(f"{self.image_dir}/*"))
67
  img_fmts = ["jpg", "png", "jpeg", "JPG", "bmp"]
68
 
69
+ self.subject_list = sorted(
70
+ [item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False
71
+ )
72
 
73
  # smpl related
74
  self.smpl_data = SMPLX()
 
83
 
84
  self.smpl_model = PIXIE_SMPLX(pixie_cfg.model).to(self.device)
85
 
86
+ self.detector = detection.maskrcnn_resnet50_fpn(
87
+ weights=detection.MaskRCNN_ResNet50_FPN_V2_Weights
88
+ )
89
+ self.detector.eval()
90
+
91
+ print(
92
+ colored(
93
+ f"SMPL-X estimate with {Format.start} {self.hps_type.upper()} {Format.end}", "green"
94
+ )
95
+ )
96
 
97
  self.render = Render(size=512, device=self.device)
98
 
 
102
  def compute_vis_cmap(self, smpl_verts, smpl_faces):
103
 
104
  (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=-1)
105
+ smpl_vis = get_visibility(xy, z,
106
+ torch.as_tensor(smpl_faces).long()[:, :,
107
+ [0, 2, 1]]).unsqueeze(-1)
108
  smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type).unsqueeze(0)
109
 
110
  return {
 
123
  depth_FB[:, ~depth_mask[0]] = 0.
124
 
125
  # Important: index_long = depth_value - 1
126
+ index_z = (((depth_FB + 1.) * 0.5 * self.vol_res) - 1).clip(0, self.vol_res -
127
+ 1).permute(1, 2, 0)
128
  index_z_ceil = torch.ceil(index_z).long()
129
  index_z_floor = torch.floor(index_z).long()
130
  index_z_frac = torch.frac(index_z)
 
136
  F.one_hot(index_z_floor[..., 1], self.vol_res) * (1.0 - index_z_frac[..., 1])
137
 
138
  voxels[index_mask] *= 0
139
+ voxels = torch.flip(voxels, [2]).permute(2, 0, 1).float() #[x-2, y-0, z-1]
140
 
141
  return {
142
  "depth_voxels": voxels.flip([
 
154
  smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])
155
 
156
  verts = (
157
+ np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * scale.item() +
158
+ trans.detach().cpu().numpy()
159
+ )
160
  faces = (
161
  np.loadtxt(
162
  osp.join(self.smpl_data.tedra_dir, "tetrahedrons_neutral_adult.txt"),
163
  dtype=np.int32,
164
+ ) - 1
165
+ )
166
 
167
  pad_v_num = int(8000 - verts.shape[0])
168
  pad_f_num = int(25100 - faces.shape[0])
169
 
170
+ verts = (
171
+ np.pad(verts, ((0, pad_v_num),
172
+ (0, 0)), mode="constant", constant_values=0.0).astype(np.float32) * 0.5
173
+ )
174
+ faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant",
175
+ constant_values=0.0).astype(np.int32)
176
 
177
  verts[:, 2] *= -1.0
178
 
 
190
  img_path = self.subject_list[index]
191
  img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
192
 
193
+ arr_dict = process_image(img_path, self.hps_type, self.single, 512, self.detector)
194
  arr_dict.update({"name": img_name})
195
 
196
  with torch.no_grad():
 
201
  preds_dict, _ = self.hps.forward(batch)
202
 
203
  arr_dict["smpl_faces"] = (
204
+ torch.as_tensor(self.smpl_data.smplx_faces.astype(np.int64)).unsqueeze(0).long().to(
205
+ self.device
206
+ )
207
+ )
208
  arr_dict["type"] = self.smpl_type
209
 
210
  if self.hps_type == "pymafx":
 
223
  elif self.hps_type == "pixie":
224
  arr_dict.update(preds_dict)
225
  arr_dict["global_orient"] = preds_dict["global_pose"]
226
+ arr_dict["betas"] = preds_dict["shape"] #200
227
  arr_dict["smpl_verts"] = preds_dict["vertices"]
228
  scale, tranX, tranY = preds_dict["cam"].split(1, dim=1)
229
  # 1.1435, 0.0128, 0.3520
230
 
231
  arr_dict["scale"] = scale.unsqueeze(1)
232
+ arr_dict["trans"] = (
233
+ torch.cat([tranX, tranY, torch.zeros_like(tranX)],
234
+ dim=1).unsqueeze(1).to(self.device).float()
235
+ )
236
 
237
  # data_dict info (key-shape):
238
  # scale, tranX, tranY - tensor.float
 
258
 
259
  # render optimized mesh (normal, T_normal, image [-1,1])
260
  self.render.load_meshes(verts, faces)
261
+ return self.render.get_image(type="depth")
lib/dataset/body_model.py CHANGED
@@ -21,7 +21,6 @@ import os
21
 
22
 
23
  class SMPLModel:
24
-
25
  def __init__(self, model_path, age):
26
  """
27
  SMPL model.
@@ -49,20 +48,16 @@ class SMPLModel:
49
 
50
  if age == "kid":
51
  v_template_smil = np.load(
52
- os.path.join(os.path.dirname(model_path),
53
- "smpl/smpl_kid_template.npy"))
54
  v_template_smil -= np.mean(v_template_smil, axis=0)
55
- v_template_diff = np.expand_dims(v_template_smil - self.v_template,
56
- axis=2)
57
  self.shapedirs = np.concatenate(
58
- (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
59
- axis=2)
60
  self.beta_shape[0] += 1
61
 
62
- id_to_col = {
63
- self.kintree_table[1, i]: i
64
- for i in range(self.kintree_table.shape[1])
65
- }
66
  self.parent = {
67
  i: id_to_col[self.kintree_table[0, i]]
68
  for i in range(1, self.kintree_table.shape[1])
@@ -121,33 +116,30 @@ class SMPLModel:
121
  pose_cube = self.pose.reshape((-1, 1, 3))
122
  # rotation matrix for each joint
123
  self.R = self.rodrigues(pose_cube)
124
- I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
125
- (self.R.shape[0] - 1, 3, 3))
126
  lrotmin = (self.R[1:] - I_cube).ravel()
127
  # how pose affect body shape in zero pose
128
  v_posed = v_shaped + self.posedirs.dot(lrotmin)
129
  # world transformation of each joint
130
  G = np.empty((self.kintree_table.shape[1], 4, 4))
131
- G[0] = self.with_zeros(
132
- np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
133
  for i in range(1, self.kintree_table.shape[1]):
134
  G[i] = G[self.parent[i]].dot(
135
  self.with_zeros(
136
- np.hstack([
137
- self.R[i],
138
- ((self.J[i, :] - self.J[self.parent[i], :]).reshape(
139
- [3, 1])),
140
- ])))
 
 
 
141
  # remove the transformation due to the rest pose
142
- G = G - self.pack(
143
- np.matmul(
144
- G,
145
- np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
146
  # transformation of each vertex
147
  T = np.tensordot(self.weights, G, axes=[[1], [0]])
148
  rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
149
- v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
150
- 4])[:, :3]
151
  self.verts = v + self.trans.reshape([1, 3])
152
  self.G = G
153
 
@@ -171,19 +163,20 @@ class SMPLModel:
171
  r_hat = r / theta
172
  cos = np.cos(theta)
173
  z_stick = np.zeros(theta.shape[0])
174
- m = np.dstack([
175
- z_stick,
176
- -r_hat[:, 0, 2],
177
- r_hat[:, 0, 1],
178
- r_hat[:, 0, 2],
179
- z_stick,
180
- -r_hat[:, 0, 0],
181
- -r_hat[:, 0, 1],
182
- r_hat[:, 0, 0],
183
- z_stick,
184
- ]).reshape([-1, 3, 3])
185
- i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
186
- [theta.shape[0], 3, 3])
 
187
  A = np.transpose(r_hat, axes=[0, 2, 1])
188
  B = r_hat
189
  dot = np.matmul(A, B)
@@ -238,12 +231,7 @@ class SMPLModel:
238
 
239
 
240
  class TetraSMPLModel:
241
-
242
- def __init__(self,
243
- model_path,
244
- model_addition_path,
245
- age="adult",
246
- v_template=None):
247
  """
248
  SMPL model.
249
 
@@ -276,10 +264,7 @@ class TetraSMPLModel:
276
  self.posedirs_added = params_added["posedirs_added"]
277
  self.tetrahedrons = params_added["tetrahedrons"]
278
 
279
- id_to_col = {
280
- self.kintree_table[1, i]: i
281
- for i in range(self.kintree_table.shape[1])
282
- }
283
  self.parent = {
284
  i: id_to_col[self.kintree_table[0, i]]
285
  for i in range(1, self.kintree_table.shape[1])
@@ -291,14 +276,13 @@ class TetraSMPLModel:
291
 
292
  if age == "kid":
293
  v_template_smil = np.load(
294
- os.path.join(os.path.dirname(model_path),
295
- "smpl_kid_template.npy"))
296
  v_template_smil -= np.mean(v_template_smil, axis=0)
297
- v_template_diff = np.expand_dims(v_template_smil - self.v_template,
298
- axis=2)
299
  self.shapedirs = np.concatenate(
300
- (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
301
- axis=2)
302
  self.beta_shape[0] += 1
303
 
304
  self.pose = np.zeros(self.pose_shape)
@@ -356,50 +340,42 @@ class TetraSMPLModel:
356
  """
357
  # how beta affect body shape
358
  v_shaped = self.shapedirs.dot(self.beta) + self.v_template
359
- v_shaped_added = self.shapedirs_added.dot(
360
- self.beta) + self.v_template_added
361
  # joints location
362
  self.J = self.J_regressor.dot(v_shaped)
363
  pose_cube = self.pose.reshape((-1, 1, 3))
364
  # rotation matrix for each joint
365
  self.R = self.rodrigues(pose_cube)
366
- I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
367
- (self.R.shape[0] - 1, 3, 3))
368
  lrotmin = (self.R[1:] - I_cube).ravel()
369
  # how pose affect body shape in zero pose
370
  v_posed = v_shaped + self.posedirs.dot(lrotmin)
371
  v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin)
372
  # world transformation of each joint
373
  G = np.empty((self.kintree_table.shape[1], 4, 4))
374
- G[0] = self.with_zeros(
375
- np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
376
  for i in range(1, self.kintree_table.shape[1]):
377
  G[i] = G[self.parent[i]].dot(
378
  self.with_zeros(
379
- np.hstack([
380
- self.R[i],
381
- ((self.J[i, :] - self.J[self.parent[i], :]).reshape(
382
- [3, 1])),
383
- ])))
 
 
 
384
  # remove the transformation due to the rest pose
385
- G = G - self.pack(
386
- np.matmul(
387
- G,
388
- np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
389
  self.G = G
390
  # transformation of each vertex
391
  T = np.tensordot(self.weights, G, axes=[[1], [0]])
392
  rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
393
- v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
394
- 4])[:, :3]
395
  self.verts = v + self.trans.reshape([1, 3])
396
  T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]])
397
- rest_shape_added_h = np.hstack(
398
- (v_posed_added, np.ones([v_posed_added.shape[0], 1])))
399
- v_added = np.matmul(T_added,
400
- rest_shape_added_h.reshape([-1, 4,
401
- 1])).reshape([-1, 4
402
- ])[:, :3]
403
  self.verts_added = v_added + self.trans.reshape([1, 3])
404
 
405
  def rodrigues(self, r):
@@ -422,19 +398,20 @@ class TetraSMPLModel:
422
  r_hat = r / theta
423
  cos = np.cos(theta)
424
  z_stick = np.zeros(theta.shape[0])
425
- m = np.dstack([
426
- z_stick,
427
- -r_hat[:, 0, 2],
428
- r_hat[:, 0, 1],
429
- r_hat[:, 0, 2],
430
- z_stick,
431
- -r_hat[:, 0, 0],
432
- -r_hat[:, 0, 1],
433
- r_hat[:, 0, 0],
434
- z_stick,
435
- ]).reshape([-1, 3, 3])
436
- i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
437
- [theta.shape[0], 3, 3])
 
438
  A = np.transpose(r_hat, axes=[0, 2, 1])
439
  B = r_hat
440
  dot = np.matmul(A, B)
 
21
 
22
 
23
  class SMPLModel:
 
24
  def __init__(self, model_path, age):
25
  """
26
  SMPL model.
 
48
 
49
  if age == "kid":
50
  v_template_smil = np.load(
51
+ os.path.join(os.path.dirname(model_path), "smpl/smpl_kid_template.npy")
52
+ )
53
  v_template_smil -= np.mean(v_template_smil, axis=0)
54
+ v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2)
 
55
  self.shapedirs = np.concatenate(
56
+ (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2
57
+ )
58
  self.beta_shape[0] += 1
59
 
60
+ id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])}
 
 
 
61
  self.parent = {
62
  i: id_to_col[self.kintree_table[0, i]]
63
  for i in range(1, self.kintree_table.shape[1])
 
116
  pose_cube = self.pose.reshape((-1, 1, 3))
117
  # rotation matrix for each joint
118
  self.R = self.rodrigues(pose_cube)
119
+ I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3))
 
120
  lrotmin = (self.R[1:] - I_cube).ravel()
121
  # how pose affect body shape in zero pose
122
  v_posed = v_shaped + self.posedirs.dot(lrotmin)
123
  # world transformation of each joint
124
  G = np.empty((self.kintree_table.shape[1], 4, 4))
125
+ G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
 
126
  for i in range(1, self.kintree_table.shape[1]):
127
  G[i] = G[self.parent[i]].dot(
128
  self.with_zeros(
129
+ np.hstack(
130
+ [
131
+ self.R[i],
132
+ ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
133
+ ]
134
+ )
135
+ )
136
+ )
137
  # remove the transformation due to the rest pose
138
+ G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
 
 
 
139
  # transformation of each vertex
140
  T = np.tensordot(self.weights, G, axes=[[1], [0]])
141
  rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
142
+ v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
 
143
  self.verts = v + self.trans.reshape([1, 3])
144
  self.G = G
145
 
 
163
  r_hat = r / theta
164
  cos = np.cos(theta)
165
  z_stick = np.zeros(theta.shape[0])
166
+ m = np.dstack(
167
+ [
168
+ z_stick,
169
+ -r_hat[:, 0, 2],
170
+ r_hat[:, 0, 1],
171
+ r_hat[:, 0, 2],
172
+ z_stick,
173
+ -r_hat[:, 0, 0],
174
+ -r_hat[:, 0, 1],
175
+ r_hat[:, 0, 0],
176
+ z_stick,
177
+ ]
178
+ ).reshape([-1, 3, 3])
179
+ i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])
180
  A = np.transpose(r_hat, axes=[0, 2, 1])
181
  B = r_hat
182
  dot = np.matmul(A, B)
 
231
 
232
 
233
  class TetraSMPLModel:
234
+ def __init__(self, model_path, model_addition_path, age="adult", v_template=None):
 
 
 
 
 
235
  """
236
  SMPL model.
237
 
 
264
  self.posedirs_added = params_added["posedirs_added"]
265
  self.tetrahedrons = params_added["tetrahedrons"]
266
 
267
+ id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])}
 
 
 
268
  self.parent = {
269
  i: id_to_col[self.kintree_table[0, i]]
270
  for i in range(1, self.kintree_table.shape[1])
 
276
 
277
  if age == "kid":
278
  v_template_smil = np.load(
279
+ os.path.join(os.path.dirname(model_path), "smpl_kid_template.npy")
280
+ )
281
  v_template_smil -= np.mean(v_template_smil, axis=0)
282
+ v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2)
 
283
  self.shapedirs = np.concatenate(
284
+ (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2
285
+ )
286
  self.beta_shape[0] += 1
287
 
288
  self.pose = np.zeros(self.pose_shape)
 
340
  """
341
  # how beta affect body shape
342
  v_shaped = self.shapedirs.dot(self.beta) + self.v_template
343
+ v_shaped_added = self.shapedirs_added.dot(self.beta) + self.v_template_added
 
344
  # joints location
345
  self.J = self.J_regressor.dot(v_shaped)
346
  pose_cube = self.pose.reshape((-1, 1, 3))
347
  # rotation matrix for each joint
348
  self.R = self.rodrigues(pose_cube)
349
+ I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3))
 
350
  lrotmin = (self.R[1:] - I_cube).ravel()
351
  # how pose affect body shape in zero pose
352
  v_posed = v_shaped + self.posedirs.dot(lrotmin)
353
  v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin)
354
  # world transformation of each joint
355
  G = np.empty((self.kintree_table.shape[1], 4, 4))
356
+ G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
 
357
  for i in range(1, self.kintree_table.shape[1]):
358
  G[i] = G[self.parent[i]].dot(
359
  self.with_zeros(
360
+ np.hstack(
361
+ [
362
+ self.R[i],
363
+ ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
364
+ ]
365
+ )
366
+ )
367
+ )
368
  # remove the transformation due to the rest pose
369
+ G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
 
 
 
370
  self.G = G
371
  # transformation of each vertex
372
  T = np.tensordot(self.weights, G, axes=[[1], [0]])
373
  rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
374
+ v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
 
375
  self.verts = v + self.trans.reshape([1, 3])
376
  T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]])
377
+ rest_shape_added_h = np.hstack((v_posed_added, np.ones([v_posed_added.shape[0], 1])))
378
+ v_added = np.matmul(T_added, rest_shape_added_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
 
 
 
 
379
  self.verts_added = v_added + self.trans.reshape([1, 3])
380
 
381
  def rodrigues(self, r):
 
398
  r_hat = r / theta
399
  cos = np.cos(theta)
400
  z_stick = np.zeros(theta.shape[0])
401
+ m = np.dstack(
402
+ [
403
+ z_stick,
404
+ -r_hat[:, 0, 2],
405
+ r_hat[:, 0, 1],
406
+ r_hat[:, 0, 2],
407
+ z_stick,
408
+ -r_hat[:, 0, 0],
409
+ -r_hat[:, 0, 1],
410
+ r_hat[:, 0, 0],
411
+ z_stick,
412
+ ]
413
+ ).reshape([-1, 3, 3])
414
+ i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])
415
  A = np.transpose(r_hat, axes=[0, 2, 1])
416
  B = r_hat
417
  dot = np.matmul(A, B)
lib/dataset/mesh_util.py CHANGED
@@ -14,32 +14,33 @@
14
  #
15
  # Contact: ps-license@tuebingen.mpg.de
16
 
 
17
  import numpy as np
18
- import cv2
19
- import pymeshlab
20
  import torch
21
  import torchvision
22
  import trimesh
23
- import os
24
- from termcolor import colored
25
  import os.path as osp
26
  import _pickle as cPickle
 
27
  from scipy.spatial import cKDTree
28
 
29
  from pytorch3d.structures import Meshes
30
  import torch.nn.functional as F
31
  import lib.smplx as smplx
 
32
  from pytorch3d.renderer.mesh import rasterize_meshes
33
  from PIL import Image, ImageFont, ImageDraw
34
  from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
35
- import tinyobjloader
36
 
37
- from lib.common.imutils import uncrop
38
- from lib.common.render_utils import Pytorch3dRasterizer
39
 
 
 
 
40
 
41
- class SMPLX:
42
 
 
43
  def __init__(self):
44
 
45
  self.current_dir = osp.join(osp.dirname(__file__), "../../data/smpl_related")
@@ -54,10 +55,14 @@ class SMPLX:
54
 
55
  self.smplx_eyeball_fid_path = osp.join(self.current_dir, "smpl_data/eyeball_fid.npy")
56
  self.smplx_fill_mouth_fid_path = osp.join(self.current_dir, "smpl_data/fill_mouth_fid.npy")
57
- self.smplx_flame_vid_path = osp.join(self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy")
 
 
58
  self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl")
59
  self.front_flame_path = osp.join(self.current_dir, "smpl_data/FLAME_face_mask_ids.npy")
60
- self.smplx_vertex_lmkid_path = osp.join(self.current_dir, "smpl_data/smplx_vertex_lmkid.npy")
 
 
61
 
62
  self.smplx_faces = np.load(self.smplx_faces_path)
63
  self.smplx_verts = np.load(self.smplx_verts_path)
@@ -68,84 +73,51 @@ class SMPLX:
68
  self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid_path)
69
  self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid_path)
70
  self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True)
71
- self.smplx_mano_vid = np.concatenate([self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]])
 
 
72
  self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True)
73
  self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)]
74
 
75
  # hands
76
- self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(0, torch.tensor(self.smplx_mano_vid), 1.0)
 
 
77
  # face
78
- self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(
79
- 0, torch.tensor(self.smplx_front_flame_vid), 1.0)
80
- self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(
81
- 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0)
 
 
82
 
83
  self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, "rb"))
84
 
85
  self.model_dir = osp.join(self.current_dir, "models")
86
  self.tedra_dir = osp.join(self.current_dir, "../tedra_data")
87
 
88
- self.ghum_smpl_pairs = torch.tensor([
89
- (0, 24),
90
- (2, 26),
91
- (5, 25),
92
- (7, 28),
93
- (8, 27),
94
- (11, 16),
95
- (12, 17),
96
- (13, 18),
97
- (14, 19),
98
- (15, 20),
99
- (16, 21),
100
- (17, 39),
101
- (18, 44),
102
- (19, 36),
103
- (20, 41),
104
- (21, 35),
105
- (22, 40),
106
- (23, 1),
107
- (24, 2),
108
- (25, 4),
109
- (26, 5),
110
- (27, 7),
111
- (28, 8),
112
- (29, 31),
113
- (30, 34),
114
- (31, 29),
115
- (32, 32),
116
- ]).long()
117
 
118
  # smpl-smplx correspondence
119
  self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
120
  self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68]
121
- self.smpl_joint_ids_45 = (np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist())
122
-
123
- self.extra_joint_ids = (
124
- np.array([
125
- 61,
126
- 72,
127
- 66,
128
- 69,
129
- 58,
130
- 68,
131
- 57,
132
- 56,
133
- 64,
134
- 59,
135
- 67,
136
- 75,
137
- 70,
138
- 65,
139
- 60,
140
- 61,
141
- 63,
142
- 62,
143
- 76,
144
- 71,
145
- 72,
146
- 74,
147
- 73,
148
- ]) + 68)
149
 
150
  self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist())
151
 
@@ -222,27 +194,6 @@ def load_fit_body(fitted_path, scale, smpl_type="smplx", smpl_gender="neutral",
222
  return smpl_mesh, smpl_joints
223
 
224
 
225
- def create_grid_points_from_xyz_bounds(bound, res):
226
-
227
- min_x, max_x, min_y, max_y, min_z, max_z = bound
228
- x = torch.linspace(min_x, max_x, res)
229
- y = torch.linspace(min_y, max_y, res)
230
- z = torch.linspace(min_z, max_z, res)
231
- X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
232
-
233
- return torch.stack([X, Y, Z], dim=-1)
234
-
235
-
236
- def create_grid_points_from_xy_bounds(bound, res):
237
-
238
- min_x, max_x, min_y, max_y = bound
239
- x = torch.linspace(min_x, max_x, res)
240
- y = torch.linspace(min_y, max_y, res)
241
- X, Y = torch.meshgrid(x, y, indexing='ij')
242
-
243
- return torch.stack([X, Y], dim=-1)
244
-
245
-
246
  def apply_face_mask(mesh, face_mask):
247
 
248
  mesh.update_faces(face_mask)
@@ -277,7 +228,8 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
277
 
278
  part_extractor = PointFeat(
279
  torch.tensor(part_mesh.vertices).unsqueeze(0).to(device),
280
- torch.tensor(part_mesh.faces).unsqueeze(0).to(device))
 
281
 
282
  (part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device))
283
 
@@ -286,12 +238,20 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
286
  if region == "hand":
287
  _, idx = smpl_tree.query(full_mesh.vertices, k=1)
288
  full_lmkid = SMPL_container.smplx_vertex_lmkid[idx]
289
- remove_mask = torch.logical_and(remove_mask, torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0))
 
 
 
290
 
291
  elif region == "face":
292
  _, idx = smpl_tree.query(full_mesh.vertices, k=5)
293
- face_space_mask = torch.isin(torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid))
294
- remove_mask = torch.logical_and(remove_mask, face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0))
 
 
 
 
 
295
 
296
  BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1)
297
  full_mesh.update_faces(BNI_part_mask.detach().cpu())
@@ -303,109 +263,6 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
303
  return full_mesh
304
 
305
 
306
- def cross(triangles):
307
- """
308
- Returns the cross product of two edges from input triangles
309
- Parameters
310
- --------------
311
- triangles: (n, 3, 3) float
312
- Vertices of triangles
313
- Returns
314
- --------------
315
- crosses : (n, 3) float
316
- Cross product of two edge vectors
317
- """
318
- vectors = np.diff(triangles, axis=1)
319
- crosses = np.cross(vectors[:, 0], vectors[:, 1])
320
- return crosses
321
-
322
-
323
- def tri_area(triangles=None, crosses=None, sum=False):
324
- """
325
- Calculates the sum area of input triangles
326
- Parameters
327
- ----------
328
- triangles : (n, 3, 3) float
329
- Vertices of triangles
330
- crosses : (n, 3) float or None
331
- As a speedup don't re- compute cross products
332
- sum : bool
333
- Return summed area or individual triangle area
334
- Returns
335
- ----------
336
- area : (n,) float or float
337
- Individual or summed area depending on `sum` argument
338
- """
339
- if crosses is None:
340
- crosses = cross(triangles)
341
- area = (np.sum(crosses**2, axis=1)**.5) * .5
342
- if sum:
343
- return np.sum(area)
344
- return area
345
-
346
-
347
- def sample_surface(triangles, count, area=None):
348
- """
349
- Sample the surface of a mesh, returning the specified
350
- number of points
351
- For individual triangle sampling uses this method:
352
- http://mathworld.wolfram.com/TrianglePointPicking.html
353
- Parameters
354
- ---------
355
- triangles : (n, 3, 3) float
356
- Vertices of triangles
357
- count : int
358
- Number of points to return
359
- Returns
360
- ---------
361
- samples : (count, 3) float
362
- Points in space on the surface of mesh
363
- face_index : (count,) int
364
- Indices of faces for each sampled point
365
- """
366
-
367
- # len(mesh.faces) float, array of the areas
368
- # of each face of the mesh
369
- if area is None:
370
- area = tri_area(triangles)
371
-
372
- # total area (float)
373
- area_sum = np.sum(area)
374
- # cumulative area (len(mesh.faces))
375
- area_cum = np.cumsum(area)
376
- face_pick = np.random.random(count) * area_sum
377
- face_index = np.searchsorted(area_cum, face_pick)
378
-
379
- # pull triangles into the form of an origin + 2 vectors
380
- tri_origins = triangles[:, 0]
381
- tri_vectors = triangles[:, 1:].copy()
382
- tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))
383
-
384
- # pull the vectors for the faces we are going to sample from
385
- tri_origins = tri_origins[face_index]
386
- tri_vectors = tri_vectors[face_index]
387
-
388
- # randomly generate two 0-1 scalar components to multiply edge vectors by
389
- random_lengths = np.random.random((len(tri_vectors), 2, 1))
390
-
391
- # points will be distributed on a quadrilateral if we use 2 0-1 samples
392
- # if the two scalar components sum less than 1.0 the point will be
393
- # inside the triangle, so we find vectors longer than 1.0 and
394
- # transform them to be inside the triangle
395
- random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0
396
- random_lengths[random_test] -= 1.0
397
- random_lengths = np.abs(random_lengths)
398
-
399
- # multiply triangle edge vectors by the random lengths and sum
400
- sample_vector = (tri_vectors * random_lengths).sum(axis=1)
401
-
402
- # finally, offset by the origin to generate
403
- # (n,3) points in space on the triangle
404
- samples = torch.tensor(sample_vector + tri_origins).float()
405
-
406
- return samples, face_index
407
-
408
-
409
  def obj_loader(path, with_uv=True):
410
  # Create reader.
411
  reader = tinyobjloader.ObjReader()
@@ -424,8 +281,8 @@ def obj_loader(path, with_uv=True):
424
  f_vt = tri[:, [2, 5, 8]]
425
 
426
  if with_uv:
427
- face_uvs = vt[f_vt].mean(axis=1) #[m, 2]
428
- vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32) #[n, 2]
429
  vert_uvs[f_v.reshape(-1)] = vt[f_vt.reshape(-1)]
430
 
431
  return v, f_v, vert_uvs, face_uvs
@@ -434,7 +291,6 @@ def obj_loader(path, with_uv=True):
434
 
435
 
436
  class HoppeMesh:
437
-
438
  def __init__(self, verts, faces, uvs=None, texture=None):
439
  """
440
  The HoppeSDF calculates signed distance towards a predefined oriented point cloud
@@ -459,34 +315,20 @@ class HoppeMesh:
459
  - points: [n, 3]
460
  - return: [n, 4] rgba
461
  """
462
- triangles = self.verts[faces] #[n, 3, 3]
463
- barycentric = trimesh.triangles.points_to_barycentric(triangles, points) #[n, 3]
464
- vert_colors = self.vertex_colors[faces] #[n, 3, 4]
465
  point_colors = torch.tensor((barycentric[:, :, None] * vert_colors).sum(axis=1)).float()
466
  return point_colors
467
 
468
  def triangles(self):
469
- return self.verts[self.faces].numpy() #[n, 3, 3]
470
 
471
 
472
  def tensor2variable(tensor, device):
473
  return tensor.requires_grad_(True).to(device)
474
 
475
 
476
- class GMoF(torch.nn.Module):
477
-
478
- def __init__(self, rho=1):
479
- super(GMoF, self).__init__()
480
- self.rho = rho
481
-
482
- def extra_repr(self):
483
- return "rho = {}".format(self.rho)
484
-
485
- def forward(self, residual):
486
- dist = torch.div(residual, residual + self.rho**2)
487
- return self.rho**2 * dist
488
-
489
-
490
  def mesh_edge_loss(meshes, target_length: float = 0.0):
491
  """
492
  Computes mesh edge length regularization loss averaged across all meshes
@@ -508,10 +350,10 @@ def mesh_edge_loss(meshes, target_length: float = 0.0):
508
  return torch.tensor([0.0], dtype=torch.float32, device=meshes.device, requires_grad=True)
509
 
510
  N = len(meshes)
511
- edges_packed = meshes.edges_packed() # (sum(E_n), 3)
512
- verts_packed = meshes.verts_packed() # (sum(V_n), 3)
513
- edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
514
- num_edges_per_mesh = meshes.num_edges_per_mesh() # N
515
 
516
  # Determine the weight for each edge based on the number of edges in the
517
  # mesh it corresponds to.
@@ -531,99 +373,37 @@ def mesh_edge_loss(meshes, target_length: float = 0.0):
531
  return loss_all
532
 
533
 
534
- def remesh(obj, obj_path):
535
-
536
- obj.export(obj_path)
537
- ms = pymeshlab.MeshSet()
538
- ms.load_new_mesh(obj_path)
539
- # ms.meshing_decimation_quadric_edge_collapse(targetfacenum=100000)
540
- ms.meshing_isotropic_explicit_remeshing(targetlen=pymeshlab.Percentage(0.5), adaptive=True)
541
- ms.apply_coord_laplacian_smoothing()
542
- ms.save_current_mesh(obj_path[:-4] + "_remesh.obj")
543
- polished_mesh = trimesh.load_mesh(obj_path[:-4] + "_remesh.obj")
544
 
545
- return polished_mesh
546
-
547
-
548
- def poisson_remesh(obj_path):
549
-
550
- ms = pymeshlab.MeshSet()
551
- ms.load_new_mesh(obj_path)
552
- ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000)
553
- # ms.apply_coord_laplacian_smoothing()
554
- ms.save_current_mesh(obj_path)
555
- # ms.save_current_mesh(obj_path.replace(".obj", ".ply"))
556
- polished_mesh = trimesh.load_mesh(obj_path)
557
 
558
- return polished_mesh
559
 
560
 
561
  def poisson(mesh, obj_path, depth=10):
562
 
563
- from pypoisson import poisson_reconstruction
564
- faces, vertices = poisson_reconstruction(mesh.vertices, mesh.vertex_normals, depth=depth)
565
-
566
- new_meshes = trimesh.Trimesh(vertices, faces)
567
- new_mesh_lst = new_meshes.split(only_watertight=False)
568
- comp_num = [new_mesh.vertices.shape[0] for new_mesh in new_mesh_lst]
569
- final_mesh = new_mesh_lst[comp_num.index(max(comp_num))]
570
- final_mesh.export(obj_path)
 
571
 
572
- final_mesh = poisson_remesh(obj_path)
 
 
573
 
574
- return final_mesh
 
575
 
576
-
577
- def get_mask(tensor, dim):
578
-
579
- mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0
580
- mask = mask.type_as(tensor)
581
-
582
- return mask
583
-
584
-
585
- def blend_rgb_norm(norms, data):
586
-
587
- # norms [N, 3, res, res]
588
-
589
- masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
590
- norm_mask = F.interpolate(
591
- torch.cat([norms, masks], dim=1).detach().cpu(),
592
- size=data["uncrop_param"]["box_shape"],
593
- mode="bilinear",
594
- align_corners=False).permute(0, 2, 3, 1).numpy()
595
- final = data["img_raw"]
596
-
597
- for idx in range(len(norms)):
598
-
599
- norm_pred = (norm_mask[idx, :, :, :3] + 1.0) * 255.0 / 2.0
600
- mask_pred = np.repeat(norm_mask[idx, :, :, 3:4], 3, axis=-1)
601
-
602
- norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
603
- mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)
604
-
605
- final = final * (1.0 - mask_ori) + norm_ori * mask_ori
606
-
607
- return final.astype(np.uint8)
608
-
609
-
610
- def unwrap(image, uncrop_param, idx):
611
-
612
- img_uncrop = uncrop(
613
- image,
614
- uncrop_param["center"][idx],
615
- uncrop_param["scale"][idx],
616
- uncrop_param["crop_shape"],
617
- )
618
-
619
- img_orig = cv2.warpAffine(
620
- img_uncrop,
621
- np.linalg.inv(uncrop_param["M"])[:2, :],
622
- uncrop_param["ori_shape"][::-1],
623
- flags=cv2.INTER_CUBIC,
624
- )
625
-
626
- return img_orig
627
 
628
 
629
  # Losses to smooth / regularize the mesh shape
@@ -634,60 +414,7 @@ def update_mesh_shape_prior_losses(mesh, losses):
634
  # mesh normal consistency
635
  losses["nc"]["value"] = mesh_normal_consistency(mesh)
636
  # mesh laplacian smoothing
637
- losses["laplacian"]["value"] = mesh_laplacian_smoothing(mesh, method="uniform")
638
-
639
-
640
- def rename(old_dict, old_name, new_name):
641
- new_dict = {}
642
- for key, value in zip(old_dict.keys(), old_dict.values()):
643
- new_key = key if key != old_name else new_name
644
- new_dict[new_key] = old_dict[key]
645
- return new_dict
646
-
647
-
648
- def load_checkpoint(model, cfg):
649
-
650
- model_dict = model.state_dict()
651
- main_dict = {}
652
- normal_dict = {}
653
-
654
- device = torch.device(f"cuda:{cfg['test_gpus'][0]}")
655
-
656
- if os.path.exists(cfg.resume_path) and cfg.resume_path.endswith("ckpt"):
657
- main_dict = torch.load(cfg.resume_path, map_location=device)["state_dict"]
658
-
659
- main_dict = {
660
- k: v for k, v in main_dict.items() if k in model_dict and v.shape == model_dict[k].shape and
661
- ("reconEngine" not in k) and ("normal_filter" not in k) and ("voxelization" not in k)
662
- }
663
- print(colored(f"Resume MLP weights from {cfg.resume_path}", "green"))
664
-
665
- if os.path.exists(cfg.normal_path) and cfg.normal_path.endswith("ckpt"):
666
- normal_dict = torch.load(cfg.normal_path, map_location=device)["state_dict"]
667
-
668
- for key in normal_dict.keys():
669
- normal_dict = rename(normal_dict, key, key.replace("netG", "netG.normal_filter"))
670
-
671
- normal_dict = {k: v for k, v in normal_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
672
- print(colored(f"Resume normal model from {cfg.normal_path}", "green"))
673
-
674
- model_dict.update(main_dict)
675
- model_dict.update(normal_dict)
676
- model.load_state_dict(model_dict)
677
-
678
- model.netG = model.netG.to(device)
679
- model.reconEngine = model.reconEngine.to(device)
680
-
681
- model.netG.training = False
682
- model.netG.eval()
683
-
684
- del main_dict
685
- del normal_dict
686
- del model_dict
687
-
688
- torch.cuda.empty_cache()
689
-
690
- return model
691
 
692
 
693
  def read_smpl_constants(folder):
@@ -706,8 +433,10 @@ def read_smpl_constants(folder):
706
  smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
707
  """Load smpl faces & tetrahedrons"""
708
  smpl_faces = np.loadtxt(os.path.join(folder, "faces.txt"), dtype=np.int32) - 1
709
- smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] + smpl_vertex_code[smpl_faces[:, 1]] +
710
- smpl_vertex_code[smpl_faces[:, 2]]) / 3.0
 
 
711
  smpl_tetras = (np.loadtxt(os.path.join(folder, "tetrahedrons.txt"), dtype=np.int32) - 1)
712
 
713
  return_dict = {
@@ -720,19 +449,6 @@ def read_smpl_constants(folder):
720
  return return_dict
721
 
722
 
723
- def feat_select(feat, select):
724
-
725
- # feat [B, featx2, N]
726
- # select [B, 1, N]
727
- # return [B, feat, N]
728
-
729
- dim = feat.shape[1] // 2
730
- idx = torch.tile((1 - select), (1, dim, 1)) * dim + torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select)
731
- feat_select = torch.gather(feat, 1, idx.long())
732
-
733
- return feat_select
734
-
735
-
736
  def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel=1):
737
  """get the visibility of vertices
738
 
@@ -771,7 +487,9 @@ def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel
771
 
772
  for idx in range(N_body):
773
  Num_faces = len(faces[idx])
774
- vis_vertices_id = torch.unique(faces[idx][torch.unique(pix_to_face[idx][pix_to_face[idx] != -1]) - Num_faces * idx, :])
 
 
775
  vis_mask[idx, vis_vertices_id] = 1.0
776
 
777
  # print("------------------------\n")
@@ -825,7 +543,7 @@ def orthogonal(points, calibrations, transforms=None):
825
  """
826
  rot = calibrations[:, :3, :3]
827
  trans = calibrations[:, :3, 3:4]
828
- pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
829
  if transforms is not None:
830
  scale = transforms[:2, :2]
831
  shift = transforms[:2, 2:3]
@@ -925,37 +643,14 @@ def compute_normal_batch(vertices, faces):
925
  return vert_norm
926
 
927
 
928
- def calculate_mIoU(outputs, labels):
929
-
930
- SMOOTH = 1e-6
931
-
932
- outputs = outputs.int()
933
- labels = labels.int()
934
-
935
- intersection = ((outputs & labels).float().sum()) # Will be zero if Truth=0 or Prediction=0
936
- union = (outputs | labels).float().sum() # Will be zzero if both are 0
937
-
938
- iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our devision to avoid 0/0
939
-
940
- thresholded = (torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10) # This is equal to comparing with thresolds
941
-
942
- return (thresholded.mean().detach().cpu().numpy()
943
- ) # Or thresholded.mean() if you are interested in average across the batch
944
-
945
-
946
- def add_alpha(colors, alpha=0.7):
947
-
948
- colors_pad = np.pad(colors, ((0, 0), (0, 1)), mode="constant", constant_values=alpha)
949
-
950
- return colors_pad
951
-
952
-
953
  def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
954
 
955
  font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
956
  font = ImageFont.truetype(font_path, 30)
957
  grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), nrow=nrow, padding=0)
958
- grid_img = Image.fromarray(((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * 255.0).astype(np.uint8))
 
 
959
 
960
  if False:
961
  # add text
@@ -965,16 +660,20 @@ def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
965
  draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
966
 
967
  if type == "smpl":
968
- for col_id, col_txt in enumerate([
 
969
  "image",
970
  "smpl-norm(render)",
971
  "cloth-norm(pred)",
972
  "diff-norm",
973
  "diff-mask",
974
- ]):
 
975
  draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
976
  elif type == "cloth":
977
- for col_id, col_txt in enumerate(["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"]):
 
 
978
  draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
979
  for col_id, col_txt in enumerate(["0", "90", "180", "270"]):
980
  draw.text(
@@ -996,12 +695,9 @@ def clean_mesh(verts, faces):
996
  device = verts.device
997
 
998
  mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(), faces.detach().cpu().numpy())
999
- mesh_lst = mesh_lst.split(only_watertight=False)
1000
- comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst]
1001
-
1002
- mesh_clean = mesh_lst[comp_num.index(max(comp_num))]
1003
- final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device)
1004
- final_faces = torch.as_tensor(mesh_clean.faces).long().to(device)
1005
 
1006
  return final_verts, final_faces
1007
 
 
14
  #
15
  # Contact: ps-license@tuebingen.mpg.de
16
 
17
+ import os
18
  import numpy as np
 
 
19
  import torch
20
  import torchvision
21
  import trimesh
22
+ import open3d as o3d
23
+ import tinyobjloader
24
  import os.path as osp
25
  import _pickle as cPickle
26
+ from termcolor import colored
27
  from scipy.spatial import cKDTree
28
 
29
  from pytorch3d.structures import Meshes
30
  import torch.nn.functional as F
31
  import lib.smplx as smplx
32
+ from lib.common.render_utils import Pytorch3dRasterizer
33
  from pytorch3d.renderer.mesh import rasterize_meshes
34
  from PIL import Image, ImageFont, ImageDraw
35
  from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
 
36
 
 
 
37
 
38
+ class Format:
39
+ end = '\033[0m'
40
+ start = '\033[4m'
41
 
 
42
 
43
+ class SMPLX:
44
  def __init__(self):
45
 
46
  self.current_dir = osp.join(osp.dirname(__file__), "../../data/smpl_related")
 
55
 
56
  self.smplx_eyeball_fid_path = osp.join(self.current_dir, "smpl_data/eyeball_fid.npy")
57
  self.smplx_fill_mouth_fid_path = osp.join(self.current_dir, "smpl_data/fill_mouth_fid.npy")
58
+ self.smplx_flame_vid_path = osp.join(
59
+ self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy"
60
+ )
61
  self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl")
62
  self.front_flame_path = osp.join(self.current_dir, "smpl_data/FLAME_face_mask_ids.npy")
63
+ self.smplx_vertex_lmkid_path = osp.join(
64
+ self.current_dir, "smpl_data/smplx_vertex_lmkid.npy"
65
+ )
66
 
67
  self.smplx_faces = np.load(self.smplx_faces_path)
68
  self.smplx_verts = np.load(self.smplx_verts_path)
 
73
  self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid_path)
74
  self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid_path)
75
  self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True)
76
+ self.smplx_mano_vid = np.concatenate(
77
+ [self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]]
78
+ )
79
  self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True)
80
  self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)]
81
 
82
  # hands
83
+ self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
84
+ 0, torch.tensor(self.smplx_mano_vid), 1.0
85
+ )
86
  # face
87
+ self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
88
+ 0, torch.tensor(self.smplx_front_flame_vid), 1.0
89
+ )
90
+ self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
91
+ 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0
92
+ )
93
 
94
  self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, "rb"))
95
 
96
  self.model_dir = osp.join(self.current_dir, "models")
97
  self.tedra_dir = osp.join(self.current_dir, "../tedra_data")
98
 
99
+ self.ghum_smpl_pairs = torch.tensor(
100
+ [
101
+ (0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16), (12, 17), (13, 18), (14, 19),
102
+ (15, 20), (16, 21), (17, 39), (18, 44), (19, 36), (20, 41), (21, 35), (22, 40),
103
+ (23, 1), (24, 2), (25, 4), (26, 5), (27, 7), (28, 8), (29, 31), (30, 34), (31, 29),
104
+ (32, 32)
105
+ ]
106
+ ).long()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  # smpl-smplx correspondence
109
  self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
110
  self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68]
111
+ self.smpl_joint_ids_45 = np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist()
112
+
113
+ self.extra_joint_ids = np.array(
114
+ [
115
+ 61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72,
116
+ 74, 73
117
+ ]
118
+ )
119
+
120
+ self.extra_joint_ids += 68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist())
123
 
 
194
  return smpl_mesh, smpl_joints
195
 
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def apply_face_mask(mesh, face_mask):
198
 
199
  mesh.update_faces(face_mask)
 
228
 
229
  part_extractor = PointFeat(
230
  torch.tensor(part_mesh.vertices).unsqueeze(0).to(device),
231
+ torch.tensor(part_mesh.faces).unsqueeze(0).to(device)
232
+ )
233
 
234
  (part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device))
235
 
 
238
  if region == "hand":
239
  _, idx = smpl_tree.query(full_mesh.vertices, k=1)
240
  full_lmkid = SMPL_container.smplx_vertex_lmkid[idx]
241
+ remove_mask = torch.logical_and(
242
+ remove_mask,
243
+ torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0)
244
+ )
245
 
246
  elif region == "face":
247
  _, idx = smpl_tree.query(full_mesh.vertices, k=5)
248
+ face_space_mask = torch.isin(
249
+ torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid)
250
+ )
251
+ remove_mask = torch.logical_and(
252
+ remove_mask,
253
+ face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0)
254
+ )
255
 
256
  BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1)
257
  full_mesh.update_faces(BNI_part_mask.detach().cpu())
 
263
  return full_mesh
264
 
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  def obj_loader(path, with_uv=True):
267
  # Create reader.
268
  reader = tinyobjloader.ObjReader()
 
281
  f_vt = tri[:, [2, 5, 8]]
282
 
283
  if with_uv:
284
+ face_uvs = vt[f_vt].mean(axis=1) #[m, 2]
285
+ vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32) #[n, 2]
286
  vert_uvs[f_v.reshape(-1)] = vt[f_vt.reshape(-1)]
287
 
288
  return v, f_v, vert_uvs, face_uvs
 
291
 
292
 
293
  class HoppeMesh:
 
294
  def __init__(self, verts, faces, uvs=None, texture=None):
295
  """
296
  The HoppeSDF calculates signed distance towards a predefined oriented point cloud
 
315
  - points: [n, 3]
316
  - return: [n, 4] rgba
317
  """
318
+ triangles = self.verts[faces] #[n, 3, 3]
319
+ barycentric = trimesh.triangles.points_to_barycentric(triangles, points) #[n, 3]
320
+ vert_colors = self.vertex_colors[faces] #[n, 3, 4]
321
  point_colors = torch.tensor((barycentric[:, :, None] * vert_colors).sum(axis=1)).float()
322
  return point_colors
323
 
324
  def triangles(self):
325
+ return self.verts[self.faces].numpy() #[n, 3, 3]
326
 
327
 
328
  def tensor2variable(tensor, device):
329
  return tensor.requires_grad_(True).to(device)
330
 
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  def mesh_edge_loss(meshes, target_length: float = 0.0):
333
  """
334
  Computes mesh edge length regularization loss averaged across all meshes
 
350
  return torch.tensor([0.0], dtype=torch.float32, device=meshes.device, requires_grad=True)
351
 
352
  N = len(meshes)
353
+ edges_packed = meshes.edges_packed() # (sum(E_n), 3)
354
+ verts_packed = meshes.verts_packed() # (sum(V_n), 3)
355
+ edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
356
+ num_edges_per_mesh = meshes.num_edges_per_mesh() # N
357
 
358
  # Determine the weight for each edge based on the number of edges in the
359
  # mesh it corresponds to.
 
373
  return loss_all
374
 
375
 
376
+ def remesh_laplacian(mesh, obj_path):
 
 
 
 
 
 
 
 
 
377
 
378
+ mesh = mesh.simplify_quadratic_decimation(50000)
379
+ mesh = trimesh.smoothing.filter_humphrey(
380
+ mesh, alpha=0.1, beta=0.5, iterations=10, laplacian_operator=None
381
+ )
382
+ mesh.export(obj_path)
 
 
 
 
 
 
 
383
 
384
+ return mesh
385
 
386
 
387
  def poisson(mesh, obj_path, depth=10):
388
 
389
+ pcd_path = obj_path[:-4] + ".ply"
390
+ assert (mesh.vertex_normals.shape[1] == 3)
391
+ mesh.export(pcd_path)
392
+ pcl = o3d.io.read_point_cloud(pcd_path)
393
+ with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm:
394
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
395
+ pcl, depth=depth, n_threads=-1
396
+ )
397
+ print(colored(f"\n Poisson completion to {Format.start} {obj_path} {Format.end}", "yellow"))
398
 
399
+ # only keep the largest component
400
+ largest_mesh = keep_largest(trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles)))
401
+ largest_mesh.export(obj_path)
402
 
403
+ # mesh decimation for faster rendering
404
+ low_res_mesh = largest_mesh.simplify_quadratic_decimation(50000)
405
 
406
+ return low_res_mesh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
 
409
  # Losses to smooth / regularize the mesh shape
 
414
  # mesh normal consistency
415
  losses["nc"]["value"] = mesh_normal_consistency(mesh)
416
  # mesh laplacian smoothing
417
+ losses["lapla"]["value"] = mesh_laplacian_smoothing(mesh, method="uniform")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
 
420
  def read_smpl_constants(folder):
 
433
  smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
434
  """Load smpl faces & tetrahedrons"""
435
  smpl_faces = np.loadtxt(os.path.join(folder, "faces.txt"), dtype=np.int32) - 1
436
+ smpl_face_code = (
437
+ smpl_vertex_code[smpl_faces[:, 0]] + smpl_vertex_code[smpl_faces[:, 1]] +
438
+ smpl_vertex_code[smpl_faces[:, 2]]
439
+ ) / 3.0
440
  smpl_tetras = (np.loadtxt(os.path.join(folder, "tetrahedrons.txt"), dtype=np.int32) - 1)
441
 
442
  return_dict = {
 
449
  return return_dict
450
 
451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel=1):
453
  """get the visibility of vertices
454
 
 
487
 
488
  for idx in range(N_body):
489
  Num_faces = len(faces[idx])
490
+ vis_vertices_id = torch.unique(
491
+ faces[idx][torch.unique(pix_to_face[idx][pix_to_face[idx] != -1]) - Num_faces * idx, :]
492
+ )
493
  vis_mask[idx, vis_vertices_id] = 1.0
494
 
495
  # print("------------------------\n")
 
543
  """
544
  rot = calibrations[:, :3, :3]
545
  trans = calibrations[:, :3, 3:4]
546
+ pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
547
  if transforms is not None:
548
  scale = transforms[:2, :2]
549
  shift = transforms[:2, 2:3]
 
643
  return vert_norm
644
 
645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
  def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
647
 
648
  font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
649
  font = ImageFont.truetype(font_path, 30)
650
  grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), nrow=nrow, padding=0)
651
+ grid_img = Image.fromarray(
652
+ ((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * 255.0).astype(np.uint8)
653
+ )
654
 
655
  if False:
656
  # add text
 
660
  draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
661
 
662
  if type == "smpl":
663
+ for col_id, col_txt in enumerate(
664
+ [
665
  "image",
666
  "smpl-norm(render)",
667
  "cloth-norm(pred)",
668
  "diff-norm",
669
  "diff-mask",
670
+ ]
671
+ ):
672
  draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
673
  elif type == "cloth":
674
+ for col_id, col_txt in enumerate(
675
+ ["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"]
676
+ ):
677
  draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
678
  for col_id, col_txt in enumerate(["0", "90", "180", "270"]):
679
  draw.text(
 
695
  device = verts.device
696
 
697
  mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(), faces.detach().cpu().numpy())
698
+ largest_mesh = keep_largest(mesh_lst)
699
+ final_verts = torch.as_tensor(largest_mesh.vertices).float().to(device)
700
+ final_faces = torch.as_tensor(largest_mesh.faces).long().to(device)
 
 
 
701
 
702
  return final_verts, final_faces
703
 
lib/net/BasePIFuNet.py CHANGED
@@ -21,11 +21,10 @@ from .geometry import index, orthogonal, perspective
21
 
22
 
23
  class BasePIFuNet(pl.LightningModule):
24
-
25
  def __init__(
26
- self,
27
- projection_mode="orthogonal",
28
- error_term=nn.MSELoss(),
29
  ):
30
  """
31
  :param projection_mode:
 
21
 
22
 
23
  class BasePIFuNet(pl.LightningModule):
 
24
  def __init__(
25
+ self,
26
+ projection_mode="orthogonal",
27
+ error_term=nn.MSELoss(),
28
  ):
29
  """
30
  :param projection_mode:
lib/net/Discriminator.py CHANGED
@@ -9,17 +9,18 @@ from lib.torch_utils.ops.native_ops import FusedLeakyReLU, fused_leaky_relu, upf
9
 
10
 
11
  class DiscriminatorHead(nn.Module):
12
-
13
  def __init__(self, in_channel, disc_stddev=False):
14
  super().__init__()
15
 
16
  self.disc_stddev = disc_stddev
17
  stddev_dim = 1 if disc_stddev else 0
18
 
19
- self.conv_stddev = ConvLayer2d(in_channel=in_channel + stddev_dim,
20
- out_channel=in_channel,
21
- kernel_size=3,
22
- activate=True)
 
 
23
 
24
  self.final_linear = nn.Sequential(
25
  nn.Flatten(),
@@ -32,8 +33,8 @@ class DiscriminatorHead(nn.Module):
32
  inv_perm = torch.argsort(perm)
33
 
34
  batch, channel, height, width = x.shape
35
- x = x[
36
- perm] # shuffle inputs so that all views in a single trajectory don't get put together
37
 
38
  group = min(batch, stddev_group)
39
  stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
@@ -41,7 +42,7 @@ class DiscriminatorHead(nn.Module):
41
  stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
42
  stddev = stddev.repeat(group, 1, height, width)
43
 
44
- stddev = stddev[inv_perm] # reorder inputs
45
  x = x[inv_perm]
46
 
47
  out = torch.cat([x, stddev], 1)
@@ -56,7 +57,6 @@ class DiscriminatorHead(nn.Module):
56
 
57
 
58
  class ConvDecoder(nn.Module):
59
-
60
  def __init__(self, in_channel, out_channel, in_res, out_res):
61
  super().__init__()
62
 
@@ -68,20 +68,22 @@ class ConvDecoder(nn.Module):
68
  for i in range(log_size_in, log_size_out):
69
  out_ch = in_ch // 2
70
  self.layers.append(
71
- ConvLayer2d(in_channel=in_ch,
72
- out_channel=out_ch,
73
- kernel_size=3,
74
- upsample=True,
75
- bias=True,
76
- activate=True))
 
 
 
77
  in_ch = out_ch
78
 
79
  self.layers.append(
80
- ConvLayer2d(in_channel=in_ch,
81
- out_channel=out_channel,
82
- kernel_size=3,
83
- bias=True,
84
- activate=False))
85
  self.layers = nn.Sequential(*self.layers)
86
 
87
  def forward(self, x):
@@ -89,7 +91,6 @@ class ConvDecoder(nn.Module):
89
 
90
 
91
  class StyleDiscriminator(nn.Module):
92
-
93
  def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
94
  super().__init__()
95
 
@@ -104,7 +105,8 @@ class StyleDiscriminator(nn.Module):
104
  for i in range(log_size_in, log_size_out, -1):
105
  out_channels = int(min(in_channels * 2, ch_max))
106
  self.layers.append(
107
- ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True))
 
108
  in_channels = out_channels
109
  self.layers = nn.Sequential(*self.layers)
110
 
@@ -147,7 +149,6 @@ class Blur(nn.Module):
147
  Upsample factor.
148
 
149
  """
150
-
151
  def __init__(self, kernel, pad, upsample_factor=1):
152
  super().__init__()
153
 
@@ -177,7 +178,6 @@ class Upsample(nn.Module):
177
  Upsampling factor.
178
 
179
  """
180
-
181
  def __init__(self, kernel=[1, 3, 3, 1], factor=2):
182
  super().__init__()
183
 
@@ -208,7 +208,6 @@ class Downsample(nn.Module):
208
  Downsampling factor.
209
 
210
  """
211
-
212
  def __init__(self, kernel=[1, 3, 3, 1], factor=2):
213
  super().__init__()
214
 
@@ -250,7 +249,6 @@ class EqualLinear(nn.Module):
250
  Apply leakyReLU activation.
251
 
252
  """
253
-
254
  def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
255
  super().__init__()
256
 
@@ -300,7 +298,6 @@ class EqualConv2d(nn.Module):
300
  Use bias term.
301
 
302
  """
303
-
304
  def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
305
  super().__init__()
306
 
@@ -316,16 +313,20 @@ class EqualConv2d(nn.Module):
316
  self.bias = None
317
 
318
  def forward(self, input):
319
- out = F.conv2d(input,
320
- self.weight * self.scale,
321
- bias=self.bias,
322
- stride=self.stride,
323
- padding=self.padding)
 
 
324
  return out
325
 
326
  def __repr__(self):
327
- return (f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
328
- f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})")
 
 
329
 
330
 
331
  class EqualConvTranspose2d(nn.Module):
@@ -353,15 +354,16 @@ class EqualConvTranspose2d(nn.Module):
353
  Use bias term.
354
 
355
  """
356
-
357
- def __init__(self,
358
- in_channel,
359
- out_channel,
360
- kernel_size,
361
- stride=1,
362
- padding=0,
363
- output_padding=0,
364
- bias=True):
 
365
  super().__init__()
366
 
367
  self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
@@ -388,12 +390,13 @@ class EqualConvTranspose2d(nn.Module):
388
  return out
389
 
390
  def __repr__(self):
391
- return (f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
392
- f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})')
 
 
393
 
394
 
395
  class ConvLayer2d(nn.Sequential):
396
-
397
  def __init__(
398
  self,
399
  in_channel,
@@ -415,12 +418,15 @@ class ConvLayer2d(nn.Sequential):
415
  pad1 = p // 2 + 1
416
 
417
  layers.append(
418
- EqualConvTranspose2d(in_channel,
419
- out_channel,
420
- kernel_size,
421
- padding=0,
422
- stride=2,
423
- bias=bias and not activate))
 
 
 
424
  layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))
425
 
426
  if downsample:
@@ -431,23 +437,29 @@ class ConvLayer2d(nn.Sequential):
431
 
432
  layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
433
  layers.append(
434
- EqualConv2d(in_channel,
435
- out_channel,
436
- kernel_size,
437
- padding=0,
438
- stride=2,
439
- bias=bias and not activate))
 
 
 
440
 
441
  if (not downsample) and (not upsample):
442
  padding = kernel_size // 2
443
 
444
  layers.append(
445
- EqualConv2d(in_channel,
446
- out_channel,
447
- kernel_size,
448
- padding=padding,
449
- stride=1,
450
- bias=bias and not activate))
 
 
 
451
 
452
  if activate:
453
  layers.append(FusedLeakyReLU(out_channel, bias=bias))
@@ -472,7 +484,6 @@ class ConvResBlock2d(nn.Module):
472
  Apply downsampling via strided convolution in the second conv.
473
 
474
  """
475
-
476
  def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
477
  super().__init__()
478
 
 
9
 
10
 
11
  class DiscriminatorHead(nn.Module):
 
12
  def __init__(self, in_channel, disc_stddev=False):
13
  super().__init__()
14
 
15
  self.disc_stddev = disc_stddev
16
  stddev_dim = 1 if disc_stddev else 0
17
 
18
+ self.conv_stddev = ConvLayer2d(
19
+ in_channel=in_channel + stddev_dim,
20
+ out_channel=in_channel,
21
+ kernel_size=3,
22
+ activate=True
23
+ )
24
 
25
  self.final_linear = nn.Sequential(
26
  nn.Flatten(),
 
33
  inv_perm = torch.argsort(perm)
34
 
35
  batch, channel, height, width = x.shape
36
+ x = x[perm
37
+ ] # shuffle inputs so that all views in a single trajectory don't get put together
38
 
39
  group = min(batch, stddev_group)
40
  stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
 
42
  stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
43
  stddev = stddev.repeat(group, 1, height, width)
44
 
45
+ stddev = stddev[inv_perm] # reorder inputs
46
  x = x[inv_perm]
47
 
48
  out = torch.cat([x, stddev], 1)
 
57
 
58
 
59
  class ConvDecoder(nn.Module):
 
60
  def __init__(self, in_channel, out_channel, in_res, out_res):
61
  super().__init__()
62
 
 
68
  for i in range(log_size_in, log_size_out):
69
  out_ch = in_ch // 2
70
  self.layers.append(
71
+ ConvLayer2d(
72
+ in_channel=in_ch,
73
+ out_channel=out_ch,
74
+ kernel_size=3,
75
+ upsample=True,
76
+ bias=True,
77
+ activate=True
78
+ )
79
+ )
80
  in_ch = out_ch
81
 
82
  self.layers.append(
83
+ ConvLayer2d(
84
+ in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False
85
+ )
86
+ )
 
87
  self.layers = nn.Sequential(*self.layers)
88
 
89
  def forward(self, x):
 
91
 
92
 
93
  class StyleDiscriminator(nn.Module):
 
94
  def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
95
  super().__init__()
96
 
 
105
  for i in range(log_size_in, log_size_out, -1):
106
  out_channels = int(min(in_channels * 2, ch_max))
107
  self.layers.append(
108
+ ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)
109
+ )
110
  in_channels = out_channels
111
  self.layers = nn.Sequential(*self.layers)
112
 
 
149
  Upsample factor.
150
 
151
  """
 
152
  def __init__(self, kernel, pad, upsample_factor=1):
153
  super().__init__()
154
 
 
178
  Upsampling factor.
179
 
180
  """
 
181
  def __init__(self, kernel=[1, 3, 3, 1], factor=2):
182
  super().__init__()
183
 
 
208
  Downsampling factor.
209
 
210
  """
 
211
  def __init__(self, kernel=[1, 3, 3, 1], factor=2):
212
  super().__init__()
213
 
 
249
  Apply leakyReLU activation.
250
 
251
  """
 
252
  def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
253
  super().__init__()
254
 
 
298
  Use bias term.
299
 
300
  """
 
301
  def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
302
  super().__init__()
303
 
 
313
  self.bias = None
314
 
315
  def forward(self, input):
316
+ out = F.conv2d(
317
+ input,
318
+ self.weight * self.scale,
319
+ bias=self.bias,
320
+ stride=self.stride,
321
+ padding=self.padding
322
+ )
323
  return out
324
 
325
  def __repr__(self):
326
+ return (
327
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
328
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
329
+ )
330
 
331
 
332
  class EqualConvTranspose2d(nn.Module):
 
354
  Use bias term.
355
 
356
  """
357
+ def __init__(
358
+ self,
359
+ in_channel,
360
+ out_channel,
361
+ kernel_size,
362
+ stride=1,
363
+ padding=0,
364
+ output_padding=0,
365
+ bias=True
366
+ ):
367
  super().__init__()
368
 
369
  self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
 
390
  return out
391
 
392
  def __repr__(self):
393
+ return (
394
+ f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
395
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
396
+ )
397
 
398
 
399
  class ConvLayer2d(nn.Sequential):
 
400
  def __init__(
401
  self,
402
  in_channel,
 
418
  pad1 = p // 2 + 1
419
 
420
  layers.append(
421
+ EqualConvTranspose2d(
422
+ in_channel,
423
+ out_channel,
424
+ kernel_size,
425
+ padding=0,
426
+ stride=2,
427
+ bias=bias and not activate
428
+ )
429
+ )
430
  layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))
431
 
432
  if downsample:
 
437
 
438
  layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
439
  layers.append(
440
+ EqualConv2d(
441
+ in_channel,
442
+ out_channel,
443
+ kernel_size,
444
+ padding=0,
445
+ stride=2,
446
+ bias=bias and not activate
447
+ )
448
+ )
449
 
450
  if (not downsample) and (not upsample):
451
  padding = kernel_size // 2
452
 
453
  layers.append(
454
+ EqualConv2d(
455
+ in_channel,
456
+ out_channel,
457
+ kernel_size,
458
+ padding=padding,
459
+ stride=1,
460
+ bias=bias and not activate
461
+ )
462
+ )
463
 
464
  if activate:
465
  layers.append(FusedLeakyReLU(out_channel, bias=bias))
 
484
  Apply downsampling via strided convolution in the second conv.
485
 
486
  """
 
487
  def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
488
  super().__init__()
489
 
lib/net/FBNet.py CHANGED
@@ -51,17 +51,17 @@ def get_norm_layer(norm_type="instance"):
51
 
52
 
53
  def define_G(
54
- input_nc,
55
- output_nc,
56
- ngf,
57
- netG,
58
- n_downsample_global=3,
59
- n_blocks_global=9,
60
- n_local_enhancers=1,
61
- n_blocks_local=3,
62
- norm="instance",
63
- gpu_ids=[],
64
- last_op=nn.Tanh(),
65
  ):
66
  norm_layer = get_norm_layer(norm_type=norm)
67
  if netG == "global":
@@ -97,17 +97,20 @@ def define_G(
97
  return netG
98
 
99
 
100
- def define_D(input_nc,
101
- ndf,
102
- n_layers_D,
103
- norm='instance',
104
- use_sigmoid=False,
105
- num_D=1,
106
- getIntermFeat=False,
107
- gpu_ids=[]):
 
 
108
  norm_layer = get_norm_layer(norm_type=norm)
109
- netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D,
110
- getIntermFeat)
 
111
  if len(gpu_ids) > 0:
112
  assert (torch.cuda.is_available())
113
  netD.cuda(gpu_ids[0])
@@ -129,7 +132,6 @@ def print_network(net):
129
  # Generator
130
  ##############################################################################
131
  class LocalEnhancer(pl.LightningModule):
132
-
133
  def __init__(
134
  self,
135
  input_nc,
@@ -155,8 +157,9 @@ class LocalEnhancer(pl.LightningModule):
155
  n_blocks_global,
156
  norm_layer,
157
  ).model
158
- model_global = [model_global[i] for i in range(len(model_global) - 3)
159
- ] # get rid of final convolution layers
 
160
  self.model = nn.Sequential(*model_global)
161
 
162
  ###### local enhancer layers #####
@@ -224,17 +227,16 @@ class LocalEnhancer(pl.LightningModule):
224
 
225
 
226
  class GlobalGenerator(pl.LightningModule):
227
-
228
  def __init__(
229
- self,
230
- input_nc,
231
- output_nc,
232
- ngf=64,
233
- n_downsampling=3,
234
- n_blocks=9,
235
- norm_layer=nn.BatchNorm2d,
236
- padding_type="reflect",
237
- last_op=nn.Tanh(),
238
  ):
239
  assert n_blocks >= 0
240
  super(GlobalGenerator, self).__init__()
@@ -296,42 +298,49 @@ class GlobalGenerator(pl.LightningModule):
296
 
297
  # Defines the PatchGAN discriminator with the specified arguments.
298
  class NLayerDiscriminator(nn.Module):
299
-
300
- def __init__(self,
301
- input_nc,
302
- ndf=64,
303
- n_layers=3,
304
- norm_layer=nn.BatchNorm2d,
305
- use_sigmoid=False,
306
- getIntermFeat=False):
 
307
  super(NLayerDiscriminator, self).__init__()
308
  self.getIntermFeat = getIntermFeat
309
  self.n_layers = n_layers
310
 
311
  kw = 4
312
  padw = int(np.ceil((kw - 1.0) / 2))
313
- sequence = [[
314
- nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
315
- nn.LeakyReLU(0.2, True)
316
- ]]
 
 
317
 
318
  nf = ndf
319
  for n in range(1, n_layers):
320
  nf_prev = nf
321
  nf = min(nf * 2, 512)
322
- sequence += [[
323
- nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
324
- norm_layer(nf),
325
- nn.LeakyReLU(0.2, True)
326
- ]]
 
 
327
 
328
  nf_prev = nf
329
  nf = min(nf * 2, 512)
330
- sequence += [[
331
- nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
332
- norm_layer(nf),
333
- nn.LeakyReLU(0.2, True)
334
- ]]
 
 
335
 
336
  sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
337
 
@@ -359,27 +368,30 @@ class NLayerDiscriminator(nn.Module):
359
 
360
 
361
  class MultiscaleDiscriminator(pl.LightningModule):
362
-
363
- def __init__(self,
364
- input_nc,
365
- ndf=64,
366
- n_layers=3,
367
- norm_layer=nn.BatchNorm2d,
368
- use_sigmoid=False,
369
- num_D=3,
370
- getIntermFeat=False):
 
371
  super(MultiscaleDiscriminator, self).__init__()
372
  self.num_D = num_D
373
  self.n_layers = n_layers
374
  self.getIntermFeat = getIntermFeat
375
 
376
  for i in range(num_D):
377
- netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid,
378
- getIntermFeat)
 
379
  if getIntermFeat:
380
  for j in range(n_layers + 2):
381
- setattr(self, 'scale' + str(i) + '_layer' + str(j),
382
- getattr(netD, 'model' + str(j)))
 
383
  else:
384
  setattr(self, 'layer' + str(i), netD.model)
385
 
@@ -414,11 +426,11 @@ class MultiscaleDiscriminator(pl.LightningModule):
414
 
415
  # Define a resnet block
416
  class ResnetBlock(pl.LightningModule):
417
-
418
  def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
419
  super(ResnetBlock, self).__init__()
420
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation,
421
- use_dropout)
 
422
 
423
  def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
424
  conv_block = []
@@ -459,7 +471,6 @@ class ResnetBlock(pl.LightningModule):
459
 
460
 
461
  class Encoder(pl.LightningModule):
462
-
463
  def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
464
  super(Encoder, self).__init__()
465
  self.output_nc = output_nc
@@ -510,18 +521,17 @@ class Encoder(pl.LightningModule):
510
  inst_list = np.unique(inst.cpu().numpy().astype(int))
511
  for i in inst_list:
512
  for b in range(input.size()[0]):
513
- indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4
514
  for j in range(self.output_nc):
515
  output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
516
- indices[:, 3],]
517
  mean_feat = torch.mean(output_ins).expand_as(output_ins)
518
  outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
519
- indices[:, 3],] = mean_feat
520
  return outputs_mean
521
 
522
 
523
  class Vgg19(nn.Module):
524
-
525
  def __init__(self, requires_grad=False):
526
  super(Vgg19, self).__init__()
527
  vgg_pretrained_features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
@@ -555,7 +565,6 @@ class Vgg19(nn.Module):
555
 
556
 
557
  class VGG19FeatLayer(nn.Module):
558
-
559
  def __init__(self):
560
  super(VGG19FeatLayer, self).__init__()
561
  self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
@@ -593,7 +602,6 @@ class VGG19FeatLayer(nn.Module):
593
 
594
 
595
  class VGGLoss(pl.LightningModule):
596
-
597
  def __init__(self):
598
  super(VGGLoss, self).__init__()
599
  self.vgg = Vgg19().eval()
@@ -609,11 +617,7 @@ class VGGLoss(pl.LightningModule):
609
 
610
 
611
  class GANLoss(pl.LightningModule):
612
-
613
- def __init__(self,
614
- use_lsgan=True,
615
- target_real_label=1.0,
616
- target_fake_label=0.0):
617
  super(GANLoss, self).__init__()
618
  self.real_label = target_real_label
619
  self.fake_label = target_fake_label
@@ -628,16 +632,18 @@ class GANLoss(pl.LightningModule):
628
  def get_target_tensor(self, input, target_is_real):
629
  target_tensor = None
630
  if target_is_real:
631
- create_label = ((self.real_label_var is None) or
632
- (self.real_label_var.numel() != input.numel()))
 
633
  if create_label:
634
  real_tensor = self.tensor(input.size()).fill_(self.real_label)
635
  self.real_label_var = real_tensor
636
  self.real_label_var.requires_grad = False
637
  target_tensor = self.real_label_var
638
  else:
639
- create_label = ((self.fake_label_var is None) or
640
- (self.fake_label_var.numel() != input.numel()))
 
641
  if create_label:
642
  fake_tensor = self.tensor(input.size()).fill_(self.fake_label)
643
  self.fake_label_var = fake_tensor
@@ -659,7 +665,6 @@ class GANLoss(pl.LightningModule):
659
 
660
 
661
  class IDMRFLoss(pl.LightningModule):
662
-
663
  def __init__(self, featlayer=VGG19FeatLayer):
664
  super(IDMRFLoss, self).__init__()
665
  self.featlayer = featlayer()
@@ -678,7 +683,8 @@ class IDMRFLoss(pl.LightningModule):
678
  patch_size = 1
679
  patch_stride = 1
680
  patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(
681
- 3, patch_size, patch_stride)
 
682
  self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
683
  dims = self.patches_OIHW.size()
684
  self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
@@ -743,7 +749,8 @@ class IDMRFLoss(pl.LightningModule):
743
  self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer])
744
  for layer in self.feat_content_layers
745
  ]
746
- self.content_loss = functools.reduce(lambda x, y: x + y,
747
- content_loss_list) * self.lambda_content
 
748
 
749
  return self.style_loss + self.content_loss
 
51
 
52
 
53
  def define_G(
54
+ input_nc,
55
+ output_nc,
56
+ ngf,
57
+ netG,
58
+ n_downsample_global=3,
59
+ n_blocks_global=9,
60
+ n_local_enhancers=1,
61
+ n_blocks_local=3,
62
+ norm="instance",
63
+ gpu_ids=[],
64
+ last_op=nn.Tanh(),
65
  ):
66
  norm_layer = get_norm_layer(norm_type=norm)
67
  if netG == "global":
 
97
  return netG
98
 
99
 
100
+ def define_D(
101
+ input_nc,
102
+ ndf,
103
+ n_layers_D,
104
+ norm='instance',
105
+ use_sigmoid=False,
106
+ num_D=1,
107
+ getIntermFeat=False,
108
+ gpu_ids=[]
109
+ ):
110
  norm_layer = get_norm_layer(norm_type=norm)
111
+ netD = MultiscaleDiscriminator(
112
+ input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat
113
+ )
114
  if len(gpu_ids) > 0:
115
  assert (torch.cuda.is_available())
116
  netD.cuda(gpu_ids[0])
 
132
  # Generator
133
  ##############################################################################
134
  class LocalEnhancer(pl.LightningModule):
 
135
  def __init__(
136
  self,
137
  input_nc,
 
157
  n_blocks_global,
158
  norm_layer,
159
  ).model
160
+ model_global = [
161
+ model_global[i] for i in range(len(model_global) - 3)
162
+ ] # get rid of final convolution layers
163
  self.model = nn.Sequential(*model_global)
164
 
165
  ###### local enhancer layers #####
 
227
 
228
 
229
  class GlobalGenerator(pl.LightningModule):
 
230
  def __init__(
231
+ self,
232
+ input_nc,
233
+ output_nc,
234
+ ngf=64,
235
+ n_downsampling=3,
236
+ n_blocks=9,
237
+ norm_layer=nn.BatchNorm2d,
238
+ padding_type="reflect",
239
+ last_op=nn.Tanh(),
240
  ):
241
  assert n_blocks >= 0
242
  super(GlobalGenerator, self).__init__()
 
298
 
299
  # Defines the PatchGAN discriminator with the specified arguments.
300
  class NLayerDiscriminator(nn.Module):
301
+ def __init__(
302
+ self,
303
+ input_nc,
304
+ ndf=64,
305
+ n_layers=3,
306
+ norm_layer=nn.BatchNorm2d,
307
+ use_sigmoid=False,
308
+ getIntermFeat=False
309
+ ):
310
  super(NLayerDiscriminator, self).__init__()
311
  self.getIntermFeat = getIntermFeat
312
  self.n_layers = n_layers
313
 
314
  kw = 4
315
  padw = int(np.ceil((kw - 1.0) / 2))
316
+ sequence = [
317
+ [
318
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
319
+ nn.LeakyReLU(0.2, True)
320
+ ]
321
+ ]
322
 
323
  nf = ndf
324
  for n in range(1, n_layers):
325
  nf_prev = nf
326
  nf = min(nf * 2, 512)
327
+ sequence += [
328
+ [
329
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
330
+ norm_layer(nf),
331
+ nn.LeakyReLU(0.2, True)
332
+ ]
333
+ ]
334
 
335
  nf_prev = nf
336
  nf = min(nf * 2, 512)
337
+ sequence += [
338
+ [
339
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
340
+ norm_layer(nf),
341
+ nn.LeakyReLU(0.2, True)
342
+ ]
343
+ ]
344
 
345
  sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
346
 
 
368
 
369
 
370
  class MultiscaleDiscriminator(pl.LightningModule):
371
+ def __init__(
372
+ self,
373
+ input_nc,
374
+ ndf=64,
375
+ n_layers=3,
376
+ norm_layer=nn.BatchNorm2d,
377
+ use_sigmoid=False,
378
+ num_D=3,
379
+ getIntermFeat=False
380
+ ):
381
  super(MultiscaleDiscriminator, self).__init__()
382
  self.num_D = num_D
383
  self.n_layers = n_layers
384
  self.getIntermFeat = getIntermFeat
385
 
386
  for i in range(num_D):
387
+ netD = NLayerDiscriminator(
388
+ input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
389
+ )
390
  if getIntermFeat:
391
  for j in range(n_layers + 2):
392
+ setattr(
393
+ self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))
394
+ )
395
  else:
396
  setattr(self, 'layer' + str(i), netD.model)
397
 
 
426
 
427
  # Define a resnet block
428
  class ResnetBlock(pl.LightningModule):
 
429
  def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
430
  super(ResnetBlock, self).__init__()
431
+ self.conv_block = self.build_conv_block(
432
+ dim, padding_type, norm_layer, activation, use_dropout
433
+ )
434
 
435
  def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
436
  conv_block = []
 
471
 
472
 
473
  class Encoder(pl.LightningModule):
 
474
  def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
475
  super(Encoder, self).__init__()
476
  self.output_nc = output_nc
 
521
  inst_list = np.unique(inst.cpu().numpy().astype(int))
522
  for i in inst_list:
523
  for b in range(input.size()[0]):
524
+ indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4
525
  for j in range(self.output_nc):
526
  output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
527
+ indices[:, 3], ]
528
  mean_feat = torch.mean(output_ins).expand_as(output_ins)
529
  outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
530
+ indices[:, 3], ] = mean_feat
531
  return outputs_mean
532
 
533
 
534
  class Vgg19(nn.Module):
 
535
  def __init__(self, requires_grad=False):
536
  super(Vgg19, self).__init__()
537
  vgg_pretrained_features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
 
565
 
566
 
567
  class VGG19FeatLayer(nn.Module):
 
568
  def __init__(self):
569
  super(VGG19FeatLayer, self).__init__()
570
  self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
 
602
 
603
 
604
  class VGGLoss(pl.LightningModule):
 
605
  def __init__(self):
606
  super(VGGLoss, self).__init__()
607
  self.vgg = Vgg19().eval()
 
617
 
618
 
619
  class GANLoss(pl.LightningModule):
620
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
 
 
 
 
621
  super(GANLoss, self).__init__()
622
  self.real_label = target_real_label
623
  self.fake_label = target_fake_label
 
632
  def get_target_tensor(self, input, target_is_real):
633
  target_tensor = None
634
  if target_is_real:
635
+ create_label = (
636
+ (self.real_label_var is None) or (self.real_label_var.numel() != input.numel())
637
+ )
638
  if create_label:
639
  real_tensor = self.tensor(input.size()).fill_(self.real_label)
640
  self.real_label_var = real_tensor
641
  self.real_label_var.requires_grad = False
642
  target_tensor = self.real_label_var
643
  else:
644
+ create_label = (
645
+ (self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())
646
+ )
647
  if create_label:
648
  fake_tensor = self.tensor(input.size()).fill_(self.fake_label)
649
  self.fake_label_var = fake_tensor
 
665
 
666
 
667
  class IDMRFLoss(pl.LightningModule):
 
668
  def __init__(self, featlayer=VGG19FeatLayer):
669
  super(IDMRFLoss, self).__init__()
670
  self.featlayer = featlayer()
 
683
  patch_size = 1
684
  patch_stride = 1
685
  patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(
686
+ 3, patch_size, patch_stride
687
+ )
688
  self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
689
  dims = self.patches_OIHW.size()
690
  self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
 
749
  self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer])
750
  for layer in self.feat_content_layers
751
  ]
752
+ self.content_loss = functools.reduce(
753
+ lambda x, y: x + y, content_loss_list
754
+ ) * self.lambda_content
755
 
756
  return self.style_loss + self.content_loss
lib/net/GANLoss.py CHANGED
@@ -32,13 +32,12 @@ def logistic_loss(fake_pred, real_pred, mode):
32
 
33
 
34
  def r1_loss(real_pred, real_img):
35
- (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
36
  grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
37
  return grad_penalty
38
 
39
 
40
  class GANLoss(nn.Module):
41
-
42
  def __init__(
43
  self,
44
  opt,
@@ -64,7 +63,7 @@ class GANLoss(nn.Module):
64
  logits_fake = self.discriminator(disc_in_fake)
65
 
66
  disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')
67
-
68
  log = {
69
  "disc_loss": disc_loss.detach(),
70
  "logits_real": logits_real.mean().detach(),
 
32
 
33
 
34
  def r1_loss(real_pred, real_img):
35
+ (grad_real, ) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
36
  grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
37
  return grad_penalty
38
 
39
 
40
  class GANLoss(nn.Module):
 
41
  def __init__(
42
  self,
43
  opt,
 
63
  logits_fake = self.discriminator(disc_in_fake)
64
 
65
  disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')
66
+
67
  log = {
68
  "disc_loss": disc_loss.detach(),
69
  "logits_real": logits_real.mean().detach(),
lib/net/IFGeoNet.py CHANGED
@@ -8,20 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX
8
 
9
 
10
  class SelfAttention(torch.nn.Module):
11
-
12
  def __init__(self, in_channels, out_channels):
13
  super().__init__()
14
- self.conv = nn.Conv3d(in_channels,
15
- out_channels,
16
- 3,
17
- padding=1,
18
- padding_mode='replicate')
19
- self.attention = nn.Conv3d(in_channels,
20
- out_channels,
21
- kernel_size=3,
22
- padding=1,
23
- padding_mode='replicate',
24
- bias=False)
25
  with torch.no_grad():
26
  self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
27
 
@@ -32,38 +29,45 @@ class SelfAttention(torch.nn.Module):
32
 
33
 
34
  class IFGeoNet(nn.Module):
35
-
36
  def __init__(self, cfg, hidden_dim=256):
37
  super(IFGeoNet, self).__init__()
38
 
39
- self.conv_in_partial = nn.Conv3d(1, 16, 3, padding=1,
40
- padding_mode='replicate') # out: 256 ->m.p. 128
 
41
 
42
- self.conv_in_smpl = nn.Conv3d(1, 4, 3, padding=1,
43
- padding_mode='replicate') # out: 256 ->m.p. 128
 
44
 
45
  self.SA = SelfAttention(4, 4)
46
- self.conv_0_fusion = nn.Conv3d(16 + 4, 32, 3, padding=1,
47
- padding_mode='replicate') # out: 128
48
- self.conv_0_1_fusion = nn.Conv3d(32, 32, 3, padding=1,
49
- padding_mode='replicate') # out: 128 ->m.p. 64
50
-
51
- self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
52
- self.conv_0_1 = nn.Conv3d(32, 32, 3, padding=1,
53
- padding_mode='replicate') # out: 128 ->m.p. 64
54
-
55
- self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
56
- self.conv_1_1 = nn.Conv3d(64, 64, 3, padding=1,
57
- padding_mode='replicate') # out: 64 -> mp 32
58
-
59
- self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
60
- self.conv_2_1 = nn.Conv3d(128, 128, 3, padding=1,
61
- padding_mode='replicate') # out: 32 -> mp 16
62
- self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
63
- self.conv_3_1 = nn.Conv3d(128, 128, 3, padding=1,
64
- padding_mode='replicate') # out: 16 -> mp 8
65
- self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
66
- self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
 
 
 
 
 
 
67
 
68
  feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
69
  self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
@@ -97,21 +101,21 @@ class IFGeoNet(nn.Module):
97
  smooth_kernel_size=7,
98
  batch_size=cfg.batch_size,
99
  )
100
-
101
  self.l1_loss = nn.SmoothL1Loss()
102
 
103
  def forward(self, batch):
104
-
105
  if "body_voxels" in batch.keys():
106
  x_smpl = batch["body_voxels"]
107
  else:
108
  with torch.no_grad():
109
  self.voxelization.update_param(batch["voxel_faces"])
110
- x_smpl = self.voxelization(batch["voxel_verts"])[:, 0] #[B, 128, 128, 128]
111
-
112
  p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
113
- batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
114
- x = batch["depth_voxels"] #[B, 128, 128, 128]
115
 
116
  x = x.unsqueeze(1)
117
  x_smpl = x_smpl.unsqueeze(1)
@@ -119,63 +123,67 @@ class IFGeoNet(nn.Module):
119
  p = p.unsqueeze(1).unsqueeze(1)
120
 
121
  # partial inputs feature extraction
122
- feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners = True)
123
  net_partial = self.actvn(self.conv_in_partial(x))
124
  net_partial = self.partial_conv_in_bn(net_partial)
125
- net_partial = self.maxpool(net_partial) # out 64
126
 
127
  # smpl inputs feature extraction
128
  # feature_0_smpl = F.grid_sample(x_smpl, p, padding_mode='border', align_corners = True)
129
  net_smpl = self.actvn(self.conv_in_smpl(x_smpl))
130
  net_smpl = self.smpl_conv_in_bn(net_smpl)
131
- net_smpl = self.maxpool(net_smpl) # out 64
132
  net_smpl = self.SA(net_smpl)
133
-
134
  # Feature fusion
135
  net = self.actvn(self.conv_0_fusion(torch.concat([net_partial, net_smpl], dim=1)))
136
  net = self.actvn(self.conv_0_1_fusion(net))
137
  net = self.conv0_1_bn_fusion(net)
138
- feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners = True)
139
  # net = self.maxpool(net) # out 64
140
 
141
  net = self.actvn(self.conv_0(net))
142
  net = self.actvn(self.conv_0_1(net))
143
  net = self.conv0_1_bn(net)
144
- feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
145
- net = self.maxpool(net) # out 32
146
 
147
  net = self.actvn(self.conv_1(net))
148
  net = self.actvn(self.conv_1_1(net))
149
  net = self.conv1_1_bn(net)
150
- feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
151
- net = self.maxpool(net) # out 16
152
 
153
  net = self.actvn(self.conv_2(net))
154
  net = self.actvn(self.conv_2_1(net))
155
  net = self.conv2_1_bn(net)
156
- feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
157
- net = self.maxpool(net) # out 8
158
 
159
  net = self.actvn(self.conv_3(net))
160
  net = self.actvn(self.conv_3_1(net))
161
  net = self.conv3_1_bn(net)
162
- feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners = True)
163
- net = self.maxpool(net) # out 4
164
 
165
  net = self.actvn(self.conv_4(net))
166
  net = self.actvn(self.conv_4_1(net))
167
  net = self.conv4_1_bn(net)
168
- feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners = True) # out 2
169
 
170
  # here every channel corresponse to one feature.
171
 
172
- features = torch.cat((feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4,
173
- feature_5, feature_6),
174
- dim=1) # (B, features, 1,7,sample_num)
 
 
 
 
175
  shape = features.shape
176
  features = torch.reshape(
177
- features,
178
- (shape[0], shape[1] * shape[3], shape[4])) # (B, featues_per_sample, samples_num)
179
  # (B, featue_size, samples_num)
180
  features = torch.cat((features, p_features), dim=1)
181
 
@@ -183,7 +191,7 @@ class IFGeoNet(nn.Module):
183
  net = self.actvn(self.fc_1(net))
184
  net = self.actvn(self.fc_2(net))
185
  net = self.fc_out(net).squeeze(1)
186
-
187
  return net
188
 
189
  def compute_loss(self, prds, tgts):
 
8
 
9
 
10
  class SelfAttention(torch.nn.Module):
 
11
  def __init__(self, in_channels, out_channels):
12
  super().__init__()
13
+ self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate')
14
+ self.attention = nn.Conv3d(
15
+ in_channels,
16
+ out_channels,
17
+ kernel_size=3,
18
+ padding=1,
19
+ padding_mode='replicate',
20
+ bias=False
21
+ )
 
 
22
  with torch.no_grad():
23
  self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
24
 
 
29
 
30
 
31
  class IFGeoNet(nn.Module):
 
32
  def __init__(self, cfg, hidden_dim=256):
33
  super(IFGeoNet, self).__init__()
34
 
35
+ self.conv_in_partial = nn.Conv3d(
36
+ 1, 16, 3, padding=1, padding_mode='replicate'
37
+ ) # out: 256 ->m.p. 128
38
 
39
+ self.conv_in_smpl = nn.Conv3d(
40
+ 1, 4, 3, padding=1, padding_mode='replicate'
41
+ ) # out: 256 ->m.p. 128
42
 
43
  self.SA = SelfAttention(4, 4)
44
+ self.conv_0_fusion = nn.Conv3d(
45
+ 16 + 4, 32, 3, padding=1, padding_mode='replicate'
46
+ ) # out: 128
47
+ self.conv_0_1_fusion = nn.Conv3d(
48
+ 32, 32, 3, padding=1, padding_mode='replicate'
49
+ ) # out: 128 ->m.p. 64
50
+
51
+ self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
52
+ self.conv_0_1 = nn.Conv3d(
53
+ 32, 32, 3, padding=1, padding_mode='replicate'
54
+ ) # out: 128 ->m.p. 64
55
+
56
+ self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
57
+ self.conv_1_1 = nn.Conv3d(
58
+ 64, 64, 3, padding=1, padding_mode='replicate'
59
+ ) # out: 64 -> mp 32
60
+
61
+ self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
62
+ self.conv_2_1 = nn.Conv3d(
63
+ 128, 128, 3, padding=1, padding_mode='replicate'
64
+ ) # out: 32 -> mp 16
65
+ self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
66
+ self.conv_3_1 = nn.Conv3d(
67
+ 128, 128, 3, padding=1, padding_mode='replicate'
68
+ ) # out: 16 -> mp 8
69
+ self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
70
+ self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
71
 
72
  feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
73
  self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
 
101
  smooth_kernel_size=7,
102
  batch_size=cfg.batch_size,
103
  )
104
+
105
  self.l1_loss = nn.SmoothL1Loss()
106
 
107
  def forward(self, batch):
108
+
109
  if "body_voxels" in batch.keys():
110
  x_smpl = batch["body_voxels"]
111
  else:
112
  with torch.no_grad():
113
  self.voxelization.update_param(batch["voxel_faces"])
114
+ x_smpl = self.voxelization(batch["voxel_verts"])[:, 0] #[B, 128, 128, 128]
115
+
116
  p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
117
+ batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
118
+ x = batch["depth_voxels"] #[B, 128, 128, 128]
119
 
120
  x = x.unsqueeze(1)
121
  x_smpl = x_smpl.unsqueeze(1)
 
123
  p = p.unsqueeze(1).unsqueeze(1)
124
 
125
  # partial inputs feature extraction
126
+ feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True)
127
  net_partial = self.actvn(self.conv_in_partial(x))
128
  net_partial = self.partial_conv_in_bn(net_partial)
129
+ net_partial = self.maxpool(net_partial) # out 64
130
 
131
  # smpl inputs feature extraction
132
  # feature_0_smpl = F.grid_sample(x_smpl, p, padding_mode='border', align_corners = True)
133
  net_smpl = self.actvn(self.conv_in_smpl(x_smpl))
134
  net_smpl = self.smpl_conv_in_bn(net_smpl)
135
+ net_smpl = self.maxpool(net_smpl) # out 64
136
  net_smpl = self.SA(net_smpl)
137
+
138
  # Feature fusion
139
  net = self.actvn(self.conv_0_fusion(torch.concat([net_partial, net_smpl], dim=1)))
140
  net = self.actvn(self.conv_0_1_fusion(net))
141
  net = self.conv0_1_bn_fusion(net)
142
+ feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners=True)
143
  # net = self.maxpool(net) # out 64
144
 
145
  net = self.actvn(self.conv_0(net))
146
  net = self.actvn(self.conv_0_1(net))
147
  net = self.conv0_1_bn(net)
148
+ feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
149
+ net = self.maxpool(net) # out 32
150
 
151
  net = self.actvn(self.conv_1(net))
152
  net = self.actvn(self.conv_1_1(net))
153
  net = self.conv1_1_bn(net)
154
+ feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
155
+ net = self.maxpool(net) # out 16
156
 
157
  net = self.actvn(self.conv_2(net))
158
  net = self.actvn(self.conv_2_1(net))
159
  net = self.conv2_1_bn(net)
160
+ feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
161
+ net = self.maxpool(net) # out 8
162
 
163
  net = self.actvn(self.conv_3(net))
164
  net = self.actvn(self.conv_3_1(net))
165
  net = self.conv3_1_bn(net)
166
+ feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
167
+ net = self.maxpool(net) # out 4
168
 
169
  net = self.actvn(self.conv_4(net))
170
  net = self.actvn(self.conv_4_1(net))
171
  net = self.conv4_1_bn(net)
172
+ feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2
173
 
174
  # here every channel corresponse to one feature.
175
 
176
+ features = torch.cat(
177
+ (
178
+ feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
179
+ feature_6
180
+ ),
181
+ dim=1
182
+ ) # (B, features, 1,7,sample_num)
183
  shape = features.shape
184
  features = torch.reshape(
185
+ features, (shape[0], shape[1] * shape[3], shape[4])
186
+ ) # (B, featues_per_sample, samples_num)
187
  # (B, featue_size, samples_num)
188
  features = torch.cat((features, p_features), dim=1)
189
 
 
191
  net = self.actvn(self.fc_1(net))
192
  net = self.actvn(self.fc_2(net))
193
  net = self.fc_out(net).squeeze(1)
194
+
195
  return net
196
 
197
  def compute_loss(self, prds, tgts):
lib/net/IFGeoNet_nobody.py CHANGED
@@ -8,16 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX
8
 
9
 
10
  class SelfAttention(torch.nn.Module):
11
-
12
  def __init__(self, in_channels, out_channels):
13
  super().__init__()
14
  self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate')
15
- self.attention = nn.Conv3d(in_channels,
16
- out_channels,
17
- kernel_size=3,
18
- padding=1,
19
- padding_mode='replicate',
20
- bias=False)
 
 
21
  with torch.no_grad():
22
  self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
23
 
@@ -28,34 +29,39 @@ class SelfAttention(torch.nn.Module):
28
 
29
 
30
  class IFGeoNet(nn.Module):
31
-
32
  def __init__(self, cfg, hidden_dim=256):
33
  super(IFGeoNet, self).__init__()
34
 
35
- self.conv_in_partial = nn.Conv3d(1, 16, 3, padding=1,
36
- padding_mode='replicate') # out: 256 ->m.p. 128
 
37
 
38
  self.SA = SelfAttention(4, 4)
39
- self.conv_0_fusion = nn.Conv3d(16, 32, 3, padding=1, padding_mode='replicate') # out: 128
40
- self.conv_0_1_fusion = nn.Conv3d(32, 32, 3, padding=1,
41
- padding_mode='replicate') # out: 128 ->m.p. 64
42
-
43
- self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
44
- self.conv_0_1 = nn.Conv3d(32, 32, 3, padding=1,
45
- padding_mode='replicate') # out: 128 ->m.p. 64
46
-
47
- self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
48
- self.conv_1_1 = nn.Conv3d(64, 64, 3, padding=1,
49
- padding_mode='replicate') # out: 64 -> mp 32
50
-
51
- self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
52
- self.conv_2_1 = nn.Conv3d(128, 128, 3, padding=1,
53
- padding_mode='replicate') # out: 32 -> mp 16
54
- self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
55
- self.conv_3_1 = nn.Conv3d(128, 128, 3, padding=1,
56
- padding_mode='replicate') # out: 16 -> mp 8
57
- self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
58
- self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
 
 
 
 
 
59
 
60
  feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
61
  self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
@@ -95,8 +101,8 @@ class IFGeoNet(nn.Module):
95
  def forward(self, batch):
96
 
97
  p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
98
- batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
99
- x = batch["depth_voxels"] #[B, 128, 128, 128]
100
 
101
  x = x.unsqueeze(1)
102
  p_features = p.transpose(1, -1)
@@ -106,7 +112,7 @@ class IFGeoNet(nn.Module):
106
  feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True)
107
  net_partial = self.actvn(self.conv_in_partial(x))
108
  net_partial = self.partial_conv_in_bn(net_partial)
109
- net_partial = self.maxpool(net_partial) # out 64
110
 
111
  # Feature fusion
112
  net = self.actvn(self.conv_0_fusion(net_partial))
@@ -119,40 +125,44 @@ class IFGeoNet(nn.Module):
119
  net = self.actvn(self.conv_0_1(net))
120
  net = self.conv0_1_bn(net)
121
  feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
122
- net = self.maxpool(net) # out 32
123
 
124
  net = self.actvn(self.conv_1(net))
125
  net = self.actvn(self.conv_1_1(net))
126
  net = self.conv1_1_bn(net)
127
  feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
128
- net = self.maxpool(net) # out 16
129
 
130
  net = self.actvn(self.conv_2(net))
131
  net = self.actvn(self.conv_2_1(net))
132
  net = self.conv2_1_bn(net)
133
  feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
134
- net = self.maxpool(net) # out 8
135
 
136
  net = self.actvn(self.conv_3(net))
137
  net = self.actvn(self.conv_3_1(net))
138
  net = self.conv3_1_bn(net)
139
  feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
140
- net = self.maxpool(net) # out 4
141
 
142
  net = self.actvn(self.conv_4(net))
143
  net = self.actvn(self.conv_4_1(net))
144
  net = self.conv4_1_bn(net)
145
- feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2
146
 
147
  # here every channel corresponse to one feature.
148
 
149
- features = torch.cat((feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4,
150
- feature_5, feature_6),
151
- dim=1) # (B, features, 1,7,sample_num)
 
 
 
 
152
  shape = features.shape
153
  features = torch.reshape(
154
- features,
155
- (shape[0], shape[1] * shape[3], shape[4])) # (B, featues_per_sample, samples_num)
156
  # (B, featue_size, samples_num)
157
  features = torch.cat((features, p_features), dim=1)
158
 
@@ -167,4 +177,4 @@ class IFGeoNet(nn.Module):
167
 
168
  loss = self.l1_loss(prds, tgts)
169
 
170
- return loss
 
8
 
9
 
10
  class SelfAttention(torch.nn.Module):
 
11
  def __init__(self, in_channels, out_channels):
12
  super().__init__()
13
  self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate')
14
+ self.attention = nn.Conv3d(
15
+ in_channels,
16
+ out_channels,
17
+ kernel_size=3,
18
+ padding=1,
19
+ padding_mode='replicate',
20
+ bias=False
21
+ )
22
  with torch.no_grad():
23
  self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
24
 
 
29
 
30
 
31
  class IFGeoNet(nn.Module):
 
32
  def __init__(self, cfg, hidden_dim=256):
33
  super(IFGeoNet, self).__init__()
34
 
35
+ self.conv_in_partial = nn.Conv3d(
36
+ 1, 16, 3, padding=1, padding_mode='replicate'
37
+ ) # out: 256 ->m.p. 128
38
 
39
  self.SA = SelfAttention(4, 4)
40
+ self.conv_0_fusion = nn.Conv3d(16, 32, 3, padding=1, padding_mode='replicate') # out: 128
41
+ self.conv_0_1_fusion = nn.Conv3d(
42
+ 32, 32, 3, padding=1, padding_mode='replicate'
43
+ ) # out: 128 ->m.p. 64
44
+
45
+ self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
46
+ self.conv_0_1 = nn.Conv3d(
47
+ 32, 32, 3, padding=1, padding_mode='replicate'
48
+ ) # out: 128 ->m.p. 64
49
+
50
+ self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
51
+ self.conv_1_1 = nn.Conv3d(
52
+ 64, 64, 3, padding=1, padding_mode='replicate'
53
+ ) # out: 64 -> mp 32
54
+
55
+ self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
56
+ self.conv_2_1 = nn.Conv3d(
57
+ 128, 128, 3, padding=1, padding_mode='replicate'
58
+ ) # out: 32 -> mp 16
59
+ self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
60
+ self.conv_3_1 = nn.Conv3d(
61
+ 128, 128, 3, padding=1, padding_mode='replicate'
62
+ ) # out: 16 -> mp 8
63
+ self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
64
+ self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
65
 
66
  feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
67
  self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
 
101
  def forward(self, batch):
102
 
103
  p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
104
+ batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
105
+ x = batch["depth_voxels"] #[B, 128, 128, 128]
106
 
107
  x = x.unsqueeze(1)
108
  p_features = p.transpose(1, -1)
 
112
  feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True)
113
  net_partial = self.actvn(self.conv_in_partial(x))
114
  net_partial = self.partial_conv_in_bn(net_partial)
115
+ net_partial = self.maxpool(net_partial) # out 64
116
 
117
  # Feature fusion
118
  net = self.actvn(self.conv_0_fusion(net_partial))
 
125
  net = self.actvn(self.conv_0_1(net))
126
  net = self.conv0_1_bn(net)
127
  feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
128
+ net = self.maxpool(net) # out 32
129
 
130
  net = self.actvn(self.conv_1(net))
131
  net = self.actvn(self.conv_1_1(net))
132
  net = self.conv1_1_bn(net)
133
  feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
134
+ net = self.maxpool(net) # out 16
135
 
136
  net = self.actvn(self.conv_2(net))
137
  net = self.actvn(self.conv_2_1(net))
138
  net = self.conv2_1_bn(net)
139
  feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
140
+ net = self.maxpool(net) # out 8
141
 
142
  net = self.actvn(self.conv_3(net))
143
  net = self.actvn(self.conv_3_1(net))
144
  net = self.conv3_1_bn(net)
145
  feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
146
+ net = self.maxpool(net) # out 4
147
 
148
  net = self.actvn(self.conv_4(net))
149
  net = self.actvn(self.conv_4_1(net))
150
  net = self.conv4_1_bn(net)
151
+ feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2
152
 
153
  # here every channel corresponse to one feature.
154
 
155
+ features = torch.cat(
156
+ (
157
+ feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
158
+ feature_6
159
+ ),
160
+ dim=1
161
+ ) # (B, features, 1,7,sample_num)
162
  shape = features.shape
163
  features = torch.reshape(
164
+ features, (shape[0], shape[1] * shape[3], shape[4])
165
+ ) # (B, featues_per_sample, samples_num)
166
  # (B, featue_size, samples_num)
167
  features = torch.cat((features, p_features), dim=1)
168
 
 
177
 
178
  loss = self.l1_loss(prds, tgts)
179
 
180
+ return loss
lib/net/NormalNet.py CHANGED
@@ -35,7 +35,6 @@ class NormalNet(BasePIFuNet):
35
  4. Classification.
36
  5. During training, error is calculated on all stacks.
37
  """
38
-
39
  def __init__(self, cfg):
40
 
41
  super(NormalNet, self).__init__()
@@ -65,9 +64,11 @@ class NormalNet(BasePIFuNet):
65
  item[0] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"
66
  ]
67
  self.in_nmlF_dim = sum(
68
- [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"])
 
69
  self.in_nmlB_dim = sum(
70
- [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"])
 
71
 
72
  self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
73
  self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
@@ -134,18 +135,20 @@ class NormalNet(BasePIFuNet):
134
  if 'mrf' in self.F_losses:
135
  mrf_F_loss = self.mrf_loss(
136
  F.interpolate(prd_F, scale_factor=scale_factor, mode='bicubic', align_corners=True),
137
- F.interpolate(tgt_F, scale_factor=scale_factor, mode='bicubic', align_corners=True))
 
138
  total_loss["netF"] += self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
139
  total_loss["mrf_F"] = self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
140
  if 'mrf' in self.B_losses:
141
  mrf_B_loss = self.mrf_loss(
142
  F.interpolate(prd_B, scale_factor=scale_factor, mode='bicubic', align_corners=True),
143
- F.interpolate(tgt_B, scale_factor=scale_factor, mode='bicubic', align_corners=True))
 
144
  total_loss["netB"] += self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss
145
  total_loss["mrf_B"] = self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss
146
 
147
  if 'gan' in self.ALL_losses:
148
-
149
  total_loss["netD"] = 0.0
150
 
151
  pred_fake = self.netD.forward(prd_B)
@@ -154,8 +157,8 @@ class NormalNet(BasePIFuNet):
154
  loss_D_real = self.gan_loss(pred_real, True)
155
  loss_G_fake = self.gan_loss(pred_fake, True)
156
 
157
- total_loss["netD"] += 0.5 * (
158
- loss_D_fake + loss_D_real) * self.B_losses_ratio[self.B_losses.index('gan')]
159
  total_loss["D_fake"] = loss_D_fake * self.B_losses_ratio[self.B_losses.index('gan')]
160
  total_loss["D_real"] = loss_D_real * self.B_losses_ratio[self.B_losses.index('gan')]
161
 
@@ -167,8 +170,8 @@ class NormalNet(BasePIFuNet):
167
  for i in range(2):
168
  for j in range(len(pred_fake[i]) - 1):
169
  loss_G_GAN_Feat += self.l1_loss(pred_fake[i][j], pred_real[i][j].detach())
170
- total_loss["netB"] += loss_G_GAN_Feat * self.B_losses_ratio[self.B_losses.index(
171
- 'gan_feat')]
172
  total_loss["G_GAN_Feat"] = loss_G_GAN_Feat * self.B_losses_ratio[
173
  self.B_losses.index('gan_feat')]
174
 
 
35
  4. Classification.
36
  5. During training, error is calculated on all stacks.
37
  """
 
38
  def __init__(self, cfg):
39
 
40
  super(NormalNet, self).__init__()
 
64
  item[0] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"
65
  ]
66
  self.in_nmlF_dim = sum(
67
+ [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"]
68
+ )
69
  self.in_nmlB_dim = sum(
70
+ [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"]
71
+ )
72
 
73
  self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
74
  self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
 
135
  if 'mrf' in self.F_losses:
136
  mrf_F_loss = self.mrf_loss(
137
  F.interpolate(prd_F, scale_factor=scale_factor, mode='bicubic', align_corners=True),
138
+ F.interpolate(tgt_F, scale_factor=scale_factor, mode='bicubic', align_corners=True)
139
+ )
140
  total_loss["netF"] += self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
141
  total_loss["mrf_F"] = self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
142
  if 'mrf' in self.B_losses:
143
  mrf_B_loss = self.mrf_loss(
144
  F.interpolate(prd_B, scale_factor=scale_factor, mode='bicubic', align_corners=True),
145
+ F.interpolate(tgt_B, scale_factor=scale_factor, mode='bicubic', align_corners=True)
146
+ )
147
  total_loss["netB"] += self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss
148
  total_loss["mrf_B"] = self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss
149
 
150
  if 'gan' in self.ALL_losses:
151
+
152
  total_loss["netD"] = 0.0
153
 
154
  pred_fake = self.netD.forward(prd_B)
 
157
  loss_D_real = self.gan_loss(pred_real, True)
158
  loss_G_fake = self.gan_loss(pred_fake, True)
159
 
160
+ total_loss["netD"] += 0.5 * (loss_D_fake + loss_D_real
161
+ ) * self.B_losses_ratio[self.B_losses.index('gan')]
162
  total_loss["D_fake"] = loss_D_fake * self.B_losses_ratio[self.B_losses.index('gan')]
163
  total_loss["D_real"] = loss_D_real * self.B_losses_ratio[self.B_losses.index('gan')]
164
 
 
170
  for i in range(2):
171
  for j in range(len(pred_fake[i]) - 1):
172
  loss_G_GAN_Feat += self.l1_loss(pred_fake[i][j], pred_real[i][j].detach())
173
+ total_loss["netB"] += loss_G_GAN_Feat * self.B_losses_ratio[
174
+ self.B_losses.index('gan_feat')]
175
  total_loss["G_GAN_Feat"] = loss_G_GAN_Feat * self.B_losses_ratio[
176
  self.B_losses.index('gan_feat')]
177
 
lib/net/geometry.py CHANGED
@@ -19,12 +19,12 @@ import numpy as np
19
  import numbers
20
  from torch.nn import functional as F
21
  from einops.einops import rearrange
22
-
23
  """
24
  Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
25
  Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
26
  """
27
 
 
28
  def quaternion_to_rotation_matrix(quat):
29
  """Convert quaternion coefficients to rotation matrix.
30
  Args:
@@ -42,11 +42,13 @@ def quaternion_to_rotation_matrix(quat):
42
  wx, wy, wz = w * x, w * y, w * z
43
  xy, xz, yz = x * y, x * z, y * z
44
 
45
- rotMat = torch.stack([
46
- w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
47
- 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
48
- ],
49
- dim=1).view(B, 3, 3)
 
 
50
  return rotMat
51
 
52
 
@@ -56,7 +58,7 @@ def index(feat, uv):
56
  :param uv: [B, 2, N] uv coordinates in the image plane, range [0, 1]
57
  :return: [B, C, N] image features at the uv coordinates
58
  """
59
- uv = uv.transpose(1, 2) # [B, N, 2]
60
 
61
  (B, N, _) = uv.shape
62
  C = feat.shape[1]
@@ -64,14 +66,14 @@ def index(feat, uv):
64
  if uv.shape[-1] == 3:
65
  # uv = uv[:,:,[2,1,0]]
66
  # uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...]
67
- uv = uv.unsqueeze(2).unsqueeze(3) # [B, N, 1, 1, 3]
68
  else:
69
- uv = uv.unsqueeze(2) # [B, N, 1, 2]
70
 
71
  # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
72
  # for old versions, simply remove the aligned_corners argument.
73
- samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1]
74
- return samples.view(B, C, N) # [B, C, N]
75
 
76
 
77
  def orthogonal(points, calibrations, transforms=None):
@@ -84,7 +86,7 @@ def orthogonal(points, calibrations, transforms=None):
84
  """
85
  rot = calibrations[:, :3, :3]
86
  trans = calibrations[:, :3, 3:4]
87
- pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
88
  if transforms is not None:
89
  scale = transforms[:2, :2]
90
  shift = transforms[:2, 2:3]
@@ -102,7 +104,7 @@ def perspective(points, calibrations, transforms=None):
102
  """
103
  rot = calibrations[:, :3, :3]
104
  trans = calibrations[:, :3, 3:4]
105
- homo = torch.baddbmm(trans, rot, points) # [B, 3, N]
106
  xy = homo[:, :2, :] / homo[:, 2:3, :]
107
  if transforms is not None:
108
  scale = transforms[:2, :2]
@@ -187,7 +189,8 @@ def rotation_matrix_to_angle_axis(rotation_matrix):
187
  if rotation_matrix.shape[1:] == (3, 3):
188
  rot_mat = rotation_matrix.reshape(-1, 3, 3)
189
  hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape(
190
- 1, 3, 1).expand(rot_mat.shape[0], -1, -1)
 
191
  rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
192
 
193
  quaternion = rotation_matrix_to_quaternion(rotation_matrix)
@@ -222,8 +225,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
222
  raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))
223
 
224
  if not quaternion.shape[-1] == 4:
225
- raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(
226
- quaternion.shape))
 
227
  # unpack input and compute conversion
228
  q1: torch.Tensor = quaternion[..., 1]
229
  q2: torch.Tensor = quaternion[..., 2]
@@ -276,11 +280,13 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
276
  raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))
277
 
278
  if len(rotation_matrix.shape) > 3:
279
- raise ValueError("Input size must be a three dimensional tensor. Got {}".format(
280
- rotation_matrix.shape))
 
281
  if not rotation_matrix.shape[-2:] == (3, 4):
282
- raise ValueError("Input size must be a N x 3 x 4 tensor. Got {}".format(
283
- rotation_matrix.shape))
 
284
 
285
  rmat_t = torch.transpose(rotation_matrix, 1, 2)
286
 
@@ -347,8 +353,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
347
  mask_c3 = mask_c3.view(-1, 1).type_as(q3)
348
 
349
  q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
350
- q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa
351
- + t3_rep * mask_c3) # noqa
 
 
352
  q *= 0.5
353
  return q
354
 
@@ -389,6 +397,7 @@ def rot6d_to_rotmat(x):
389
  mat = torch.stack((b1, b2, b3), dim=-1)
390
  return mat
391
 
 
392
  def rotmat_to_rot6d(x):
393
  """Convert 3x3 rotation matrix to 6D rotation representation.
394
  Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
@@ -402,6 +411,7 @@ def rotmat_to_rot6d(x):
402
  x = x.reshape(batch_size, 6)
403
  return x
404
 
 
405
  def rotmat_to_angle(x):
406
  """Convert rotation to one-D angle.
407
  Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
@@ -440,12 +450,9 @@ def projection(pred_joints, pred_camera, retain_z=False):
440
  return pred_keypoints_2d
441
 
442
 
443
- def perspective_projection(points,
444
- rotation,
445
- translation,
446
- focal_length,
447
- camera_center,
448
- retain_z=False):
449
  """
450
  This function computes the perspective projection of a set of points.
451
  Input:
@@ -501,10 +508,12 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si
501
  weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
502
 
503
  # least squares
504
- Q = np.array([
505
- F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
506
- O - np.reshape(joints_2d, -1)
507
- ]).T
 
 
508
  c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
509
 
510
  # weighted least squares
@@ -558,15 +567,12 @@ def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_al
558
  S_i = S[i]
559
  joints_i = joints_2d[i]
560
  conf_i = joints_conf[i]
561
- trans[i] = estimate_translation_np(S_i,
562
- joints_i,
563
- conf_i,
564
- focal_length=focal_length[i],
565
- img_size=img_size[i])
566
  return torch.from_numpy(trans).to(device)
567
 
568
 
569
-
570
  def Rot_y(angle, category="torch", prepend_dim=True, device=None):
571
  """Rotate around y-axis by angle
572
  Args:
@@ -574,11 +580,13 @@ def Rot_y(angle, category="torch", prepend_dim=True, device=None):
574
  prepend_dim: prepend an extra dimension
575
  Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
576
  """
577
- m = np.array([
578
- [np.cos(angle), 0.0, np.sin(angle)],
579
- [0.0, 1.0, 0.0],
580
- [-np.sin(angle), 0.0, np.cos(angle)],
581
- ])
 
 
582
  if category == "torch":
583
  if prepend_dim:
584
  return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -600,11 +608,13 @@ def Rot_x(angle, category="torch", prepend_dim=True, device=None):
600
  prepend_dim: prepend an extra dimension
601
  Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
602
  """
603
- m = np.array([
604
- [1.0, 0.0, 0.0],
605
- [0.0, np.cos(angle), -np.sin(angle)],
606
- [0.0, np.sin(angle), np.cos(angle)],
607
- ])
 
 
608
  if category == "torch":
609
  if prepend_dim:
610
  return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -626,11 +636,13 @@ def Rot_z(angle, category="torch", prepend_dim=True, device=None):
626
  prepend_dim: prepend an extra dimension
627
  Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
628
  """
629
- m = np.array([
630
- [np.cos(angle), -np.sin(angle), 0.0],
631
- [np.sin(angle), np.cos(angle), 0.0],
632
- [0.0, 0.0, 1.0],
633
- ])
 
 
634
  if category == "torch":
635
  if prepend_dim:
636
  return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
@@ -672,7 +684,7 @@ def compute_twist_rotation(rotation_matrix, twist_axis):
672
  twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)
673
  twist_aa = quaternion_to_angle_axis(twist_quaternion)
674
 
675
- twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum(
676
- twist_axis, dim=1, keepdim=True)
677
 
678
- return twist_rotation, twist_angle
 
19
  import numbers
20
  from torch.nn import functional as F
21
  from einops.einops import rearrange
 
22
  """
23
  Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
24
  Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
25
  """
26
 
27
+
28
  def quaternion_to_rotation_matrix(quat):
29
  """Convert quaternion coefficients to rotation matrix.
30
  Args:
 
42
  wx, wy, wz = w * x, w * y, w * z
43
  xy, xz, yz = x * y, x * z, y * z
44
 
45
+ rotMat = torch.stack(
46
+ [
47
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
48
+ 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
49
+ ],
50
+ dim=1
51
+ ).view(B, 3, 3)
52
  return rotMat
53
 
54
 
 
58
  :param uv: [B, 2, N] uv coordinates in the image plane, range [0, 1]
59
  :return: [B, C, N] image features at the uv coordinates
60
  """
61
+ uv = uv.transpose(1, 2) # [B, N, 2]
62
 
63
  (B, N, _) = uv.shape
64
  C = feat.shape[1]
 
66
  if uv.shape[-1] == 3:
67
  # uv = uv[:,:,[2,1,0]]
68
  # uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...]
69
+ uv = uv.unsqueeze(2).unsqueeze(3) # [B, N, 1, 1, 3]
70
  else:
71
+ uv = uv.unsqueeze(2) # [B, N, 1, 2]
72
 
73
  # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
74
  # for old versions, simply remove the aligned_corners argument.
75
+ samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1]
76
+ return samples.view(B, C, N) # [B, C, N]
77
 
78
 
79
  def orthogonal(points, calibrations, transforms=None):
 
86
  """
87
  rot = calibrations[:, :3, :3]
88
  trans = calibrations[:, :3, 3:4]
89
+ pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
90
  if transforms is not None:
91
  scale = transforms[:2, :2]
92
  shift = transforms[:2, 2:3]
 
104
  """
105
  rot = calibrations[:, :3, :3]
106
  trans = calibrations[:, :3, 3:4]
107
+ homo = torch.baddbmm(trans, rot, points) # [B, 3, N]
108
  xy = homo[:, :2, :] / homo[:, 2:3, :]
109
  if transforms is not None:
110
  scale = transforms[:2, :2]
 
189
  if rotation_matrix.shape[1:] == (3, 3):
190
  rot_mat = rotation_matrix.reshape(-1, 3, 3)
191
  hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape(
192
+ 1, 3, 1
193
+ ).expand(rot_mat.shape[0], -1, -1)
194
  rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
195
 
196
  quaternion = rotation_matrix_to_quaternion(rotation_matrix)
 
225
  raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))
226
 
227
  if not quaternion.shape[-1] == 4:
228
+ raise ValueError(
229
+ "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
230
+ )
231
  # unpack input and compute conversion
232
  q1: torch.Tensor = quaternion[..., 1]
233
  q2: torch.Tensor = quaternion[..., 2]
 
280
  raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))
281
 
282
  if len(rotation_matrix.shape) > 3:
283
+ raise ValueError(
284
+ "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape)
285
+ )
286
  if not rotation_matrix.shape[-2:] == (3, 4):
287
+ raise ValueError(
288
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(rotation_matrix.shape)
289
+ )
290
 
291
  rmat_t = torch.transpose(rotation_matrix, 1, 2)
292
 
 
353
  mask_c3 = mask_c3.view(-1, 1).type_as(q3)
354
 
355
  q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
356
+ q /= torch.sqrt(
357
+ t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa
358
+ + t3_rep * mask_c3
359
+ ) # noqa
360
  q *= 0.5
361
  return q
362
 
 
397
  mat = torch.stack((b1, b2, b3), dim=-1)
398
  return mat
399
 
400
+
401
  def rotmat_to_rot6d(x):
402
  """Convert 3x3 rotation matrix to 6D rotation representation.
403
  Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
 
411
  x = x.reshape(batch_size, 6)
412
  return x
413
 
414
+
415
  def rotmat_to_angle(x):
416
  """Convert rotation to one-D angle.
417
  Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
 
450
  return pred_keypoints_2d
451
 
452
 
453
+ def perspective_projection(
454
+ points, rotation, translation, focal_length, camera_center, retain_z=False
455
+ ):
 
 
 
456
  """
457
  This function computes the perspective projection of a set of points.
458
  Input:
 
508
  weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
509
 
510
  # least squares
511
+ Q = np.array(
512
+ [
513
+ F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
514
+ O - np.reshape(joints_2d, -1)
515
+ ]
516
+ ).T
517
  c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
518
 
519
  # weighted least squares
 
567
  S_i = S[i]
568
  joints_i = joints_2d[i]
569
  conf_i = joints_conf[i]
570
+ trans[i] = estimate_translation_np(
571
+ S_i, joints_i, conf_i, focal_length=focal_length[i], img_size=img_size[i]
572
+ )
 
 
573
  return torch.from_numpy(trans).to(device)
574
 
575
 
 
576
  def Rot_y(angle, category="torch", prepend_dim=True, device=None):
577
  """Rotate around y-axis by angle
578
  Args:
 
580
  prepend_dim: prepend an extra dimension
581
  Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
582
  """
583
+ m = np.array(
584
+ [
585
+ [np.cos(angle), 0.0, np.sin(angle)],
586
+ [0.0, 1.0, 0.0],
587
+ [-np.sin(angle), 0.0, np.cos(angle)],
588
+ ]
589
+ )
590
  if category == "torch":
591
  if prepend_dim:
592
  return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
 
608
  prepend_dim: prepend an extra dimension
609
  Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
610
  """
611
+ m = np.array(
612
+ [
613
+ [1.0, 0.0, 0.0],
614
+ [0.0, np.cos(angle), -np.sin(angle)],
615
+ [0.0, np.sin(angle), np.cos(angle)],
616
+ ]
617
+ )
618
  if category == "torch":
619
  if prepend_dim:
620
  return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
 
636
  prepend_dim: prepend an extra dimension
637
  Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
638
  """
639
+ m = np.array(
640
+ [
641
+ [np.cos(angle), -np.sin(angle), 0.0],
642
+ [np.sin(angle), np.cos(angle), 0.0],
643
+ [0.0, 0.0, 1.0],
644
+ ]
645
+ )
646
  if category == "torch":
647
  if prepend_dim:
648
  return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)
 
684
  twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)
685
  twist_aa = quaternion_to_angle_axis(twist_quaternion)
686
 
687
+ twist_angle = torch.sum(twist_aa, dim=1,
688
+ keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True)
689
 
690
+ return twist_rotation, twist_angle
lib/net/net_util.py CHANGED
@@ -71,11 +71,10 @@ def init_weights(net, init_type="normal", init_gain=0.02):
71
  We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
72
  work better for some applications. Feel free to try yourself.
73
  """
74
-
75
- def init_func(m): # define the initialization function
76
  classname = m.__class__.__name__
77
- if hasattr(m, "weight") and (classname.find("Conv") != -1 or
78
- classname.find("Linear") != -1):
79
  if init_type == "normal":
80
  init.normal_(m.weight.data, 0.0, init_gain)
81
  elif init_type == "xavier":
@@ -85,17 +84,19 @@ def init_weights(net, init_type="normal", init_gain=0.02):
85
  elif init_type == "orthogonal":
86
  init.orthogonal_(m.weight.data, gain=init_gain)
87
  else:
88
- raise NotImplementedError("initialization method [%s] is not implemented" %
89
- init_type)
 
90
  if hasattr(m, "bias") and m.bias is not None:
91
  init.constant_(m.bias.data, 0.0)
92
- elif (classname.find("BatchNorm2d") !=
93
- -1): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
 
94
  init.normal_(m.weight.data, 1.0, init_gain)
95
  init.constant_(m.bias.data, 0.0)
96
 
97
  # print('initialize network with %s' % init_type)
98
- net.apply(init_func) # apply the initialization function <init_func>
99
 
100
 
101
  def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]):
@@ -110,7 +111,7 @@ def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]):
110
  """
111
  if len(gpu_ids) > 0:
112
  assert torch.cuda.is_available()
113
- net = torch.nn.DataParallel(net) # multi-GPUs
114
  init_weights(net, init_type, init_gain=init_gain)
115
  return net
116
 
@@ -127,13 +128,9 @@ def imageSpaceRotation(xy, rot):
127
  return (disp * xy).sum(dim=1)
128
 
129
 
130
- def cal_gradient_penalty(netD,
131
- real_data,
132
- fake_data,
133
- device,
134
- type="mixed",
135
- constant=1.0,
136
- lambda_gp=10.0):
137
  """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
138
 
139
  Arguments:
@@ -155,9 +152,11 @@ def cal_gradient_penalty(netD,
155
  interpolatesv = fake_data
156
  elif type == "mixed":
157
  alpha = torch.rand(real_data.shape[0], 1)
158
- alpha = (alpha.expand(real_data.shape[0],
159
- real_data.nelement() //
160
- real_data.shape[0]).contiguous().view(*real_data.shape))
 
 
161
  alpha = alpha.to(device)
162
  interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
163
  else:
@@ -172,9 +171,9 @@ def cal_gradient_penalty(netD,
172
  retain_graph=True,
173
  only_inputs=True,
174
  )
175
- gradients = gradients[0].view(real_data.size(0), -1) # flat the data
176
- gradient_penalty = ((
177
- (gradients + 1e-16).norm(2, dim=1) - constant)**2).mean() * lambda_gp # added eps
178
  return gradient_penalty, gradients
179
  else:
180
  return 0.0, None
@@ -201,13 +200,11 @@ def get_norm_layer(norm_type="instance"):
201
 
202
 
203
  class Flatten(nn.Module):
204
-
205
  def forward(self, input):
206
  return input.view(input.size(0), -1)
207
 
208
 
209
  class ConvBlock(nn.Module):
210
-
211
  def __init__(self, in_planes, out_planes, opt):
212
  super(ConvBlock, self).__init__()
213
  [k, s, d, p] = opt.conv3x3
@@ -258,5 +255,3 @@ class ConvBlock(nn.Module):
258
  out3 += residual
259
 
260
  return out3
261
-
262
-
 
71
  We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
72
  work better for some applications. Feel free to try yourself.
73
  """
74
+ def init_func(m): # define the initialization function
 
75
  classname = m.__class__.__name__
76
+ if hasattr(m,
77
+ "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
78
  if init_type == "normal":
79
  init.normal_(m.weight.data, 0.0, init_gain)
80
  elif init_type == "xavier":
 
84
  elif init_type == "orthogonal":
85
  init.orthogonal_(m.weight.data, gain=init_gain)
86
  else:
87
+ raise NotImplementedError(
88
+ "initialization method [%s] is not implemented" % init_type
89
+ )
90
  if hasattr(m, "bias") and m.bias is not None:
91
  init.constant_(m.bias.data, 0.0)
92
+ elif (
93
+ classname.find("BatchNorm2d") != -1
94
+ ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
95
  init.normal_(m.weight.data, 1.0, init_gain)
96
  init.constant_(m.bias.data, 0.0)
97
 
98
  # print('initialize network with %s' % init_type)
99
+ net.apply(init_func) # apply the initialization function <init_func>
100
 
101
 
102
  def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]):
 
111
  """
112
  if len(gpu_ids) > 0:
113
  assert torch.cuda.is_available()
114
+ net = torch.nn.DataParallel(net) # multi-GPUs
115
  init_weights(net, init_type, init_gain=init_gain)
116
  return net
117
 
 
128
  return (disp * xy).sum(dim=1)
129
 
130
 
131
+ def cal_gradient_penalty(
132
+ netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0
133
+ ):
 
 
 
 
134
  """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
135
 
136
  Arguments:
 
152
  interpolatesv = fake_data
153
  elif type == "mixed":
154
  alpha = torch.rand(real_data.shape[0], 1)
155
+ alpha = (
156
+ alpha.expand(real_data.shape[0],
157
+ real_data.nelement() //
158
+ real_data.shape[0]).contiguous().view(*real_data.shape)
159
+ )
160
  alpha = alpha.to(device)
161
  interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
162
  else:
 
171
  retain_graph=True,
172
  only_inputs=True,
173
  )
174
+ gradients = gradients[0].view(real_data.size(0), -1) # flat the data
175
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant)**
176
+ 2).mean() * lambda_gp # added eps
177
  return gradient_penalty, gradients
178
  else:
179
  return 0.0, None
 
200
 
201
 
202
  class Flatten(nn.Module):
 
203
  def forward(self, input):
204
  return input.view(input.size(0), -1)
205
 
206
 
207
  class ConvBlock(nn.Module):
 
208
  def __init__(self, in_planes, out_planes, opt):
209
  super(ConvBlock, self).__init__()
210
  [k, s, d, p] = opt.conv3x3
 
255
  out3 += residual
256
 
257
  return out3
 
 
lib/net/voxelize.py CHANGED
@@ -13,7 +13,6 @@ class VoxelizationFunction(Function):
13
  Definition of differentiable voxelization function
14
  Currently implemented only for cuda Tensors
15
  """
16
-
17
  @staticmethod
18
  def forward(
19
  ctx,
@@ -48,12 +47,15 @@ class VoxelizationFunction(Function):
48
  smpl_face_code = smpl_face_code.contiguous()
49
  smpl_tetrahedrons = smpl_tetrahedrons.contiguous()
50
 
51
- occ_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res,
52
- ctx.volume_res).fill_(0.0)
53
- semantic_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res,
54
- ctx.volume_res, 3).fill_(0.0)
55
- weight_sum_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res,
56
- ctx.volume_res).fill_(1e-3)
 
 
 
57
 
58
  # occ_volume [B, volume_res, volume_res, volume_res]
59
  # semantic_volume [B, volume_res, volume_res, volume_res, 3]
@@ -80,7 +82,6 @@ class Voxelization(nn.Module):
80
  """
81
  Wrapper around the autograd function VoxelizationFunction
82
  """
83
-
84
  def __init__(
85
  self,
86
  smpl_vertex_code,
@@ -151,21 +152,25 @@ class Voxelization(nn.Module):
151
  self.sigma,
152
  self.smooth_kernel_size,
153
  )
154
- return vol.permute((0, 4, 1, 2, 3)) # (bzyxc --> bcdhw)
155
 
156
  def vertices_to_faces(self, vertices):
157
  assert vertices.ndimension() == 3
158
  bs, nv = vertices.shape[:2]
159
- face = (self.smpl_face_indices_batch +
160
- (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None])
 
 
161
  vertices_ = vertices.reshape((bs * nv, 3))
162
  return vertices_[face.long()]
163
 
164
  def vertices_to_tetrahedrons(self, vertices):
165
  assert vertices.ndimension() == 3
166
  bs, nv = vertices.shape[:2]
167
- tets = (self.smpl_tetraderon_indices_batch +
168
- (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None])
 
 
169
  vertices_ = vertices.reshape((bs * nv, 3))
170
  return vertices_[tets.long()]
171
 
@@ -174,8 +179,9 @@ class Voxelization(nn.Module):
174
  assert face_verts.shape[2] == 3
175
  assert face_verts.shape[3] == 3
176
  bs, nf = face_verts.shape[:2]
177
- face_centers = (face_verts[:, :, 0, :] + face_verts[:, :, 1, :] +
178
- face_verts[:, :, 2, :]) / 3.0
 
179
  face_centers = face_centers.reshape((bs, nf, 3))
180
  return face_centers
181
 
 
13
  Definition of differentiable voxelization function
14
  Currently implemented only for cuda Tensors
15
  """
 
16
  @staticmethod
17
  def forward(
18
  ctx,
 
47
  smpl_face_code = smpl_face_code.contiguous()
48
  smpl_tetrahedrons = smpl_tetrahedrons.contiguous()
49
 
50
+ occ_volume = torch.cuda.FloatTensor(
51
+ ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res
52
+ ).fill_(0.0)
53
+ semantic_volume = torch.cuda.FloatTensor(
54
+ ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res, 3
55
+ ).fill_(0.0)
56
+ weight_sum_volume = torch.cuda.FloatTensor(
57
+ ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res
58
+ ).fill_(1e-3)
59
 
60
  # occ_volume [B, volume_res, volume_res, volume_res]
61
  # semantic_volume [B, volume_res, volume_res, volume_res, 3]
 
82
  """
83
  Wrapper around the autograd function VoxelizationFunction
84
  """
 
85
  def __init__(
86
  self,
87
  smpl_vertex_code,
 
152
  self.sigma,
153
  self.smooth_kernel_size,
154
  )
155
+ return vol.permute((0, 4, 1, 2, 3)) # (bzyxc --> bcdhw)
156
 
157
  def vertices_to_faces(self, vertices):
158
  assert vertices.ndimension() == 3
159
  bs, nv = vertices.shape[:2]
160
+ face = (
161
+ self.smpl_face_indices_batch +
162
+ (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]
163
+ )
164
  vertices_ = vertices.reshape((bs * nv, 3))
165
  return vertices_[face.long()]
166
 
167
  def vertices_to_tetrahedrons(self, vertices):
168
  assert vertices.ndimension() == 3
169
  bs, nv = vertices.shape[:2]
170
+ tets = (
171
+ self.smpl_tetraderon_indices_batch +
172
+ (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]
173
+ )
174
  vertices_ = vertices.reshape((bs * nv, 3))
175
  return vertices_[tets.long()]
176
 
 
179
  assert face_verts.shape[2] == 3
180
  assert face_verts.shape[3] == 3
181
  bs, nf = face_verts.shape[:2]
182
+ face_centers = (
183
+ face_verts[:, :, 0, :] + face_verts[:, :, 1, :] + face_verts[:, :, 2, :]
184
+ ) / 3.0
185
  face_centers = face_centers.reshape((bs, nf, 3))
186
  return face_centers
187
 
lib/pixielib/models/FLAME.py CHANGED
@@ -27,7 +27,6 @@ class FLAMETex(nn.Module):
27
  FLAME texture converted from BFM:
28
  https://github.com/TimoBolkart/BFM_to_FLAME
29
  """
30
-
31
  def __init__(self, config):
32
  super(FLAMETex, self).__init__()
33
  if config.tex_type == "BFM":
@@ -54,8 +53,7 @@ class FLAMETex(nn.Module):
54
  n_tex = config.n_tex
55
  num_components = texture_basis.shape[1]
56
  texture_mean = torch.from_numpy(texture_mean).float()[None, ...]
57
- texture_basis = torch.from_numpy(
58
- texture_basis[:, :n_tex]).float()[None, ...]
59
  self.register_buffer("texture_mean", texture_mean)
60
  self.register_buffer("texture_basis", texture_basis)
61
 
@@ -64,10 +62,8 @@ class FLAMETex(nn.Module):
64
  texcode: [batchsize, n_tex]
65
  texture: [bz, 3, 256, 256], range: 0-1
66
  """
67
- texture = self.texture_mean + (self.texture_basis *
68
- texcode[:, None, :]).sum(-1)
69
- texture = texture.reshape(texcode.shape[0], 512, 512,
70
- 3).permute(0, 3, 1, 2)
71
  texture = F.interpolate(texture, [256, 256])
72
  texture = texture[:, [2, 1, 0], :, :]
73
  return texture
@@ -78,13 +74,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture):
78
  TODO: pytorch version ==> grid sample
79
  """
80
  if smplx_texture.shape[0] != smplx_texture.shape[1]:
81
- print("SMPL-X texture not squared (%d != %d)" %
82
- (smplx_texture[0], smplx_texture[1]))
83
  return
84
  if smplx_texture.shape[0] != cached_data["target_resolution"]:
85
  print(
86
- "SMPL-X texture size does not match cached image resolution (%d != %d)"
87
- % (smplx_texture.shape[0], cached_data["target_resolution"]))
 
88
  return
89
  x_coords = cached_data["x_coords"]
90
  y_coords = cached_data["y_coords"]
@@ -98,11 +94,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture):
98
  flame_texture.shape[0],
99
  ).astype(int)
100
  source_tex_coords[:, 1] = np.clip(
101
- flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0,
102
- flame_texture.shape[1]).astype(int)
103
 
104
  smplx_texture[y_coords[target_pixel_ids].astype(int),
105
- x_coords[target_pixel_ids].astype(int), :, ] = flame_texture[
106
- source_tex_coords[:, 0], source_tex_coords[:, 1]]
 
 
107
 
108
  return smplx_texture
 
27
  FLAME texture converted from BFM:
28
  https://github.com/TimoBolkart/BFM_to_FLAME
29
  """
 
30
  def __init__(self, config):
31
  super(FLAMETex, self).__init__()
32
  if config.tex_type == "BFM":
 
53
  n_tex = config.n_tex
54
  num_components = texture_basis.shape[1]
55
  texture_mean = torch.from_numpy(texture_mean).float()[None, ...]
56
+ texture_basis = torch.from_numpy(texture_basis[:, :n_tex]).float()[None, ...]
 
57
  self.register_buffer("texture_mean", texture_mean)
58
  self.register_buffer("texture_basis", texture_basis)
59
 
 
62
  texcode: [batchsize, n_tex]
63
  texture: [bz, 3, 256, 256], range: 0-1
64
  """
65
+ texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
66
+ texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
 
 
67
  texture = F.interpolate(texture, [256, 256])
68
  texture = texture[:, [2, 1, 0], :, :]
69
  return texture
 
74
  TODO: pytorch version ==> grid sample
75
  """
76
  if smplx_texture.shape[0] != smplx_texture.shape[1]:
77
+ print("SMPL-X texture not squared (%d != %d)" % (smplx_texture[0], smplx_texture[1]))
 
78
  return
79
  if smplx_texture.shape[0] != cached_data["target_resolution"]:
80
  print(
81
+ "SMPL-X texture size does not match cached image resolution (%d != %d)" %
82
+ (smplx_texture.shape[0], cached_data["target_resolution"])
83
+ )
84
  return
85
  x_coords = cached_data["x_coords"]
86
  y_coords = cached_data["y_coords"]
 
94
  flame_texture.shape[0],
95
  ).astype(int)
96
  source_tex_coords[:, 1] = np.clip(
97
+ flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0, flame_texture.shape[1]
98
+ ).astype(int)
99
 
100
  smplx_texture[y_coords[target_pixel_ids].astype(int),
101
+ x_coords[target_pixel_ids].astype(int), :, ] = flame_texture[source_tex_coords[:,
102
+ 0],
103
+ source_tex_coords[:,
104
+ 1]]
105
 
106
  return smplx_texture
lib/pixielib/models/SMPLX.py CHANGED
@@ -209,452 +209,468 @@ extra_names = [
209
  SMPLX_names += extra_names
210
 
211
  part_indices = {}
212
- part_indices["body"] = np.array([
213
- 0,
214
- 1,
215
- 2,
216
- 3,
217
- 4,
218
- 5,
219
- 6,
220
- 7,
221
- 8,
222
- 9,
223
- 10,
224
- 11,
225
- 12,
226
- 13,
227
- 14,
228
- 15,
229
- 16,
230
- 17,
231
- 18,
232
- 19,
233
- 20,
234
- 21,
235
- 22,
236
- 23,
237
- 24,
238
- 123,
239
- 124,
240
- 125,
241
- 126,
242
- 127,
243
- 132,
244
- 134,
245
- 135,
246
- 136,
247
- 137,
248
- 138,
249
- 143,
250
- ])
251
- part_indices["torso"] = np.array([
252
- 0,
253
- 1,
254
- 2,
255
- 3,
256
- 6,
257
- 9,
258
- 12,
259
- 13,
260
- 14,
261
- 15,
262
- 16,
263
- 17,
264
- 18,
265
- 19,
266
- 22,
267
- 23,
268
- 24,
269
- 55,
270
- 56,
271
- 57,
272
- 58,
273
- 59,
274
- 76,
275
- 77,
276
- 78,
277
- 79,
278
- 80,
279
- 81,
280
- 82,
281
- 83,
282
- 84,
283
- 85,
284
- 86,
285
- 87,
286
- 88,
287
- 89,
288
- 90,
289
- 91,
290
- 92,
291
- 93,
292
- 94,
293
- 95,
294
- 96,
295
- 97,
296
- 98,
297
- 99,
298
- 100,
299
- 101,
300
- 102,
301
- 103,
302
- 104,
303
- 105,
304
- 106,
305
- 107,
306
- 108,
307
- 109,
308
- 110,
309
- 111,
310
- 112,
311
- 113,
312
- 114,
313
- 115,
314
- 116,
315
- 117,
316
- 118,
317
- 119,
318
- 120,
319
- 121,
320
- 122,
321
- 123,
322
- 124,
323
- 125,
324
- 126,
325
- 127,
326
- 128,
327
- 129,
328
- 130,
329
- 131,
330
- 132,
331
- 133,
332
- 134,
333
- 135,
334
- 136,
335
- 137,
336
- 138,
337
- 139,
338
- 140,
339
- 141,
340
- 142,
341
- 143,
342
- 144,
343
- ])
344
- part_indices["head"] = np.array([
345
- 12,
346
- 15,
347
- 22,
348
- 23,
349
- 24,
350
- 55,
351
- 56,
352
- 57,
353
- 58,
354
- 59,
355
- 60,
356
- 61,
357
- 62,
358
- 63,
359
- 64,
360
- 65,
361
- 66,
362
- 67,
363
- 68,
364
- 69,
365
- 70,
366
- 71,
367
- 72,
368
- 73,
369
- 74,
370
- 75,
371
- 76,
372
- 77,
373
- 78,
374
- 79,
375
- 80,
376
- 81,
377
- 82,
378
- 83,
379
- 84,
380
- 85,
381
- 86,
382
- 87,
383
- 88,
384
- 89,
385
- 90,
386
- 91,
387
- 92,
388
- 93,
389
- 94,
390
- 95,
391
- 96,
392
- 97,
393
- 98,
394
- 99,
395
- 100,
396
- 101,
397
- 102,
398
- 103,
399
- 104,
400
- 105,
401
- 106,
402
- 107,
403
- 108,
404
- 109,
405
- 110,
406
- 111,
407
- 112,
408
- 113,
409
- 114,
410
- 115,
411
- 116,
412
- 117,
413
- 118,
414
- 119,
415
- 120,
416
- 121,
417
- 122,
418
- 123,
419
- 125,
420
- 126,
421
- 134,
422
- 136,
423
- 137,
424
- ])
425
- part_indices["face"] = np.array([
426
- 55,
427
- 56,
428
- 57,
429
- 58,
430
- 59,
431
- 60,
432
- 61,
433
- 62,
434
- 63,
435
- 64,
436
- 65,
437
- 66,
438
- 67,
439
- 68,
440
- 69,
441
- 70,
442
- 71,
443
- 72,
444
- 73,
445
- 74,
446
- 75,
447
- 76,
448
- 77,
449
- 78,
450
- 79,
451
- 80,
452
- 81,
453
- 82,
454
- 83,
455
- 84,
456
- 85,
457
- 86,
458
- 87,
459
- 88,
460
- 89,
461
- 90,
462
- 91,
463
- 92,
464
- 93,
465
- 94,
466
- 95,
467
- 96,
468
- 97,
469
- 98,
470
- 99,
471
- 100,
472
- 101,
473
- 102,
474
- 103,
475
- 104,
476
- 105,
477
- 106,
478
- 107,
479
- 108,
480
- 109,
481
- 110,
482
- 111,
483
- 112,
484
- 113,
485
- 114,
486
- 115,
487
- 116,
488
- 117,
489
- 118,
490
- 119,
491
- 120,
492
- 121,
493
- 122,
494
- ])
495
- part_indices["upper"] = np.array([
496
- 12,
497
- 13,
498
- 14,
499
- 55,
500
- 56,
501
- 57,
502
- 58,
503
- 59,
504
- 60,
505
- 61,
506
- 62,
507
- 63,
508
- 64,
509
- 65,
510
- 66,
511
- 67,
512
- 68,
513
- 69,
514
- 70,
515
- 71,
516
- 72,
517
- 73,
518
- 74,
519
- 75,
520
- 76,
521
- 77,
522
- 78,
523
- 79,
524
- 80,
525
- 81,
526
- 82,
527
- 83,
528
- 84,
529
- 85,
530
- 86,
531
- 87,
532
- 88,
533
- 89,
534
- 90,
535
- 91,
536
- 92,
537
- 93,
538
- 94,
539
- 95,
540
- 96,
541
- 97,
542
- 98,
543
- 99,
544
- 100,
545
- 101,
546
- 102,
547
- 103,
548
- 104,
549
- 105,
550
- 106,
551
- 107,
552
- 108,
553
- 109,
554
- 110,
555
- 111,
556
- 112,
557
- 113,
558
- 114,
559
- 115,
560
- 116,
561
- 117,
562
- 118,
563
- 119,
564
- 120,
565
- 121,
566
- 122,
567
- ])
568
- part_indices["hand"] = np.array([
569
- 20,
570
- 21,
571
- 25,
572
- 26,
573
- 27,
574
- 28,
575
- 29,
576
- 30,
577
- 31,
578
- 32,
579
- 33,
580
- 34,
581
- 35,
582
- 36,
583
- 37,
584
- 38,
585
- 39,
586
- 40,
587
- 41,
588
- 42,
589
- 43,
590
- 44,
591
- 45,
592
- 46,
593
- 47,
594
- 48,
595
- 49,
596
- 50,
597
- 51,
598
- 52,
599
- 53,
600
- 54,
601
- 128,
602
- 129,
603
- 130,
604
- 131,
605
- 133,
606
- 139,
607
- 140,
608
- 141,
609
- 142,
610
- 144,
611
- ])
612
- part_indices["left_hand"] = np.array([
613
- 20,
614
- 25,
615
- 26,
616
- 27,
617
- 28,
618
- 29,
619
- 30,
620
- 31,
621
- 32,
622
- 33,
623
- 34,
624
- 35,
625
- 36,
626
- 37,
627
- 38,
628
- 39,
629
- 128,
630
- 129,
631
- 130,
632
- 131,
633
- 133,
634
- ])
635
- part_indices["right_hand"] = np.array([
636
- 21,
637
- 40,
638
- 41,
639
- 42,
640
- 43,
641
- 44,
642
- 45,
643
- 46,
644
- 47,
645
- 48,
646
- 49,
647
- 50,
648
- 51,
649
- 52,
650
- 53,
651
- 54,
652
- 139,
653
- 140,
654
- 141,
655
- 142,
656
- 144,
657
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
658
  # kinematic tree
659
  head_kin_chain = [15, 12, 9, 6, 3, 0]
660
 
@@ -691,13 +707,12 @@ class SMPLX(nn.Module):
691
  Given smplx parameters, this class generates a differentiable SMPLX function
692
  which outputs a mesh and 3D joints
693
  """
694
-
695
  def __init__(self, config):
696
  super(SMPLX, self).__init__()
697
  # print("creating the SMPLX Decoder")
698
  ss = np.load(config.smplx_model_path, allow_pickle=True)
699
  smplx_model = Struct(**ss)
700
-
701
  self.dtype = torch.float32
702
  self.register_buffer(
703
  "faces_tensor",
@@ -705,8 +720,8 @@ class SMPLX(nn.Module):
705
  )
706
  # The vertices of the template model
707
  self.register_buffer(
708
- "v_template",
709
- to_tensor(to_np(smplx_model.v_template), dtype=self.dtype))
710
  # The shape components and expression
711
  # expression space is the same as FLAME
712
  shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
@@ -721,21 +736,18 @@ class SMPLX(nn.Module):
721
  # The pose components
722
  num_pose_basis = smplx_model.posedirs.shape[-1]
723
  posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
724
- self.register_buffer("posedirs",
725
- to_tensor(to_np(posedirs), dtype=self.dtype))
726
  self.register_buffer(
727
- "J_regressor",
728
- to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype))
729
  parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
730
  parents[0] = -1
731
  self.register_buffer("parents", parents)
732
- self.register_buffer(
733
- "lbs_weights",
734
- to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
735
  # for face keypoints
736
  self.register_buffer(
737
- "lmk_faces_idx",
738
- torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long))
739
  self.register_buffer(
740
  "lmk_bary_coords",
741
  torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
@@ -746,24 +758,20 @@ class SMPLX(nn.Module):
746
  )
747
  self.register_buffer(
748
  "dynamic_lmk_bary_coords",
749
- torch.tensor(smplx_model.dynamic_lmk_bary_coords,
750
- dtype=self.dtype),
751
  )
752
  # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
753
- self.register_buffer("head_kin_chain",
754
- torch.tensor(head_kin_chain, dtype=torch.long))
755
 
756
  # -- initialize parameters
757
  # shape and expression
758
  self.register_buffer(
759
  "shape_params",
760
- nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype),
761
- requires_grad=False),
762
  )
763
  self.register_buffer(
764
  "expression_params",
765
- nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype),
766
- requires_grad=False),
767
  )
768
  # pose: represented as rotation matrx [number of joints, 3, 3]
769
  self.register_buffer(
@@ -824,8 +832,7 @@ class SMPLX(nn.Module):
824
  )
825
 
826
  if config.extra_joint_path:
827
- self.extra_joint_selector = JointsFromVerticesSelector(
828
- fname=config.extra_joint_path)
829
  self.use_joint_regressor = True
830
  self.keypoint_names = SMPLX_names
831
  if self.use_joint_regressor:
@@ -843,7 +850,8 @@ class SMPLX(nn.Module):
843
  self.register_buffer("target_idxs", torch.from_numpy(target))
844
  self.register_buffer(
845
  "extra_joint_regressor",
846
- torch.from_numpy(j14_regressor).to(torch.float32))
 
847
  self.part_indices = part_indices
848
 
849
  def forward(
@@ -880,23 +888,17 @@ class SMPLX(nn.Module):
880
  if expression_params is None:
881
  expression_params = self.expression_params.expand(batch_size, -1)
882
  if global_pose is None:
883
- global_pose = self.global_pose.unsqueeze(0).expand(
884
- batch_size, -1, -1, -1)
885
  if body_pose is None:
886
- body_pose = self.body_pose.unsqueeze(0).expand(
887
- batch_size, -1, -1, -1)
888
  if jaw_pose is None:
889
- jaw_pose = self.jaw_pose.unsqueeze(0).expand(
890
- batch_size, -1, -1, -1)
891
  if eye_pose is None:
892
- eye_pose = self.eye_pose.unsqueeze(0).expand(
893
- batch_size, -1, -1, -1)
894
  if left_hand_pose is None:
895
- left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(
896
- batch_size, -1, -1, -1)
897
  if right_hand_pose is None:
898
- right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(
899
- batch_size, -1, -1, -1)
900
 
901
  shape_components = torch.cat([shape_params, expression_params], dim=1)
902
  full_pose = torch.cat(
@@ -910,8 +912,7 @@ class SMPLX(nn.Module):
910
  ],
911
  dim=1,
912
  )
913
- template_vertices = self.v_template.unsqueeze(0).expand(
914
- batch_size, -1, -1)
915
  # smplx
916
  vertices, joints = lbs(
917
  shape_components,
@@ -926,10 +927,8 @@ class SMPLX(nn.Module):
926
  pose2rot=False,
927
  )
928
  # face dynamic landmarks
929
- lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
930
- batch_size, -1)
931
- lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(
932
- batch_size, -1, -1)
933
  dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
934
  vertices,
935
  full_pose,
@@ -939,14 +938,12 @@ class SMPLX(nn.Module):
939
  )
940
  lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
941
  lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
942
- landmarks = vertices2landmarks(vertices, self.faces_tensor,
943
- lmk_faces_idx, lmk_bary_coords)
944
 
945
  final_joint_set = [joints, landmarks]
946
  if hasattr(self, "extra_joint_selector"):
947
  # Add any extra joints that might be needed
948
- extra_joints = self.extra_joint_selector(vertices,
949
- self.faces_tensor)
950
  final_joint_set.append(extra_joints)
951
  # Create the final joint set
952
  joints = torch.cat(final_joint_set, dim=1)
@@ -978,16 +975,15 @@ class SMPLX(nn.Module):
978
  # -> Left elbow -> Left wrist
979
  kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
980
  else:
981
- raise NotImplementedError(
982
- f"pose_abs2rel does not support: {abs_joint}")
983
 
984
  batch_size = global_pose.shape[0]
985
  dtype = global_pose.dtype
986
  device = global_pose.device
987
  full_pose = torch.cat([global_pose, body_pose], dim=1)
988
- rel_rot_mat = (torch.eye(3, device=device,
989
- dtype=dtype).unsqueeze_(dim=0).repeat(
990
- batch_size, 1, 1))
991
  for idx in kin_chain[1:]:
992
  rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)
993
 
@@ -1027,11 +1023,8 @@ class SMPLX(nn.Module):
1027
  # -> Left elbow -> Left wrist
1028
  kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
1029
  else:
1030
- raise NotImplementedError(
1031
- f"pose_rel2abs does not support: {abs_joint}")
1032
- rel_rot_mat = torch.eye(3,
1033
- device=full_pose.device,
1034
- dtype=full_pose.dtype).unsqueeze_(dim=0)
1035
  for idx in kin_chain:
1036
  rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
1037
  abs_pose = rel_rot_mat[:, None, :, :]
 
209
  SMPLX_names += extra_names
210
 
211
  part_indices = {}
212
+ part_indices["body"] = np.array(
213
+ [
214
+ 0,
215
+ 1,
216
+ 2,
217
+ 3,
218
+ 4,
219
+ 5,
220
+ 6,
221
+ 7,
222
+ 8,
223
+ 9,
224
+ 10,
225
+ 11,
226
+ 12,
227
+ 13,
228
+ 14,
229
+ 15,
230
+ 16,
231
+ 17,
232
+ 18,
233
+ 19,
234
+ 20,
235
+ 21,
236
+ 22,
237
+ 23,
238
+ 24,
239
+ 123,
240
+ 124,
241
+ 125,
242
+ 126,
243
+ 127,
244
+ 132,
245
+ 134,
246
+ 135,
247
+ 136,
248
+ 137,
249
+ 138,
250
+ 143,
251
+ ]
252
+ )
253
+ part_indices["torso"] = np.array(
254
+ [
255
+ 0,
256
+ 1,
257
+ 2,
258
+ 3,
259
+ 6,
260
+ 9,
261
+ 12,
262
+ 13,
263
+ 14,
264
+ 15,
265
+ 16,
266
+ 17,
267
+ 18,
268
+ 19,
269
+ 22,
270
+ 23,
271
+ 24,
272
+ 55,
273
+ 56,
274
+ 57,
275
+ 58,
276
+ 59,
277
+ 76,
278
+ 77,
279
+ 78,
280
+ 79,
281
+ 80,
282
+ 81,
283
+ 82,
284
+ 83,
285
+ 84,
286
+ 85,
287
+ 86,
288
+ 87,
289
+ 88,
290
+ 89,
291
+ 90,
292
+ 91,
293
+ 92,
294
+ 93,
295
+ 94,
296
+ 95,
297
+ 96,
298
+ 97,
299
+ 98,
300
+ 99,
301
+ 100,
302
+ 101,
303
+ 102,
304
+ 103,
305
+ 104,
306
+ 105,
307
+ 106,
308
+ 107,
309
+ 108,
310
+ 109,
311
+ 110,
312
+ 111,
313
+ 112,
314
+ 113,
315
+ 114,
316
+ 115,
317
+ 116,
318
+ 117,
319
+ 118,
320
+ 119,
321
+ 120,
322
+ 121,
323
+ 122,
324
+ 123,
325
+ 124,
326
+ 125,
327
+ 126,
328
+ 127,
329
+ 128,
330
+ 129,
331
+ 130,
332
+ 131,
333
+ 132,
334
+ 133,
335
+ 134,
336
+ 135,
337
+ 136,
338
+ 137,
339
+ 138,
340
+ 139,
341
+ 140,
342
+ 141,
343
+ 142,
344
+ 143,
345
+ 144,
346
+ ]
347
+ )
348
+ part_indices["head"] = np.array(
349
+ [
350
+ 12,
351
+ 15,
352
+ 22,
353
+ 23,
354
+ 24,
355
+ 55,
356
+ 56,
357
+ 57,
358
+ 58,
359
+ 59,
360
+ 60,
361
+ 61,
362
+ 62,
363
+ 63,
364
+ 64,
365
+ 65,
366
+ 66,
367
+ 67,
368
+ 68,
369
+ 69,
370
+ 70,
371
+ 71,
372
+ 72,
373
+ 73,
374
+ 74,
375
+ 75,
376
+ 76,
377
+ 77,
378
+ 78,
379
+ 79,
380
+ 80,
381
+ 81,
382
+ 82,
383
+ 83,
384
+ 84,
385
+ 85,
386
+ 86,
387
+ 87,
388
+ 88,
389
+ 89,
390
+ 90,
391
+ 91,
392
+ 92,
393
+ 93,
394
+ 94,
395
+ 95,
396
+ 96,
397
+ 97,
398
+ 98,
399
+ 99,
400
+ 100,
401
+ 101,
402
+ 102,
403
+ 103,
404
+ 104,
405
+ 105,
406
+ 106,
407
+ 107,
408
+ 108,
409
+ 109,
410
+ 110,
411
+ 111,
412
+ 112,
413
+ 113,
414
+ 114,
415
+ 115,
416
+ 116,
417
+ 117,
418
+ 118,
419
+ 119,
420
+ 120,
421
+ 121,
422
+ 122,
423
+ 123,
424
+ 125,
425
+ 126,
426
+ 134,
427
+ 136,
428
+ 137,
429
+ ]
430
+ )
431
+ part_indices["face"] = np.array(
432
+ [
433
+ 55,
434
+ 56,
435
+ 57,
436
+ 58,
437
+ 59,
438
+ 60,
439
+ 61,
440
+ 62,
441
+ 63,
442
+ 64,
443
+ 65,
444
+ 66,
445
+ 67,
446
+ 68,
447
+ 69,
448
+ 70,
449
+ 71,
450
+ 72,
451
+ 73,
452
+ 74,
453
+ 75,
454
+ 76,
455
+ 77,
456
+ 78,
457
+ 79,
458
+ 80,
459
+ 81,
460
+ 82,
461
+ 83,
462
+ 84,
463
+ 85,
464
+ 86,
465
+ 87,
466
+ 88,
467
+ 89,
468
+ 90,
469
+ 91,
470
+ 92,
471
+ 93,
472
+ 94,
473
+ 95,
474
+ 96,
475
+ 97,
476
+ 98,
477
+ 99,
478
+ 100,
479
+ 101,
480
+ 102,
481
+ 103,
482
+ 104,
483
+ 105,
484
+ 106,
485
+ 107,
486
+ 108,
487
+ 109,
488
+ 110,
489
+ 111,
490
+ 112,
491
+ 113,
492
+ 114,
493
+ 115,
494
+ 116,
495
+ 117,
496
+ 118,
497
+ 119,
498
+ 120,
499
+ 121,
500
+ 122,
501
+ ]
502
+ )
503
+ part_indices["upper"] = np.array(
504
+ [
505
+ 12,
506
+ 13,
507
+ 14,
508
+ 55,
509
+ 56,
510
+ 57,
511
+ 58,
512
+ 59,
513
+ 60,
514
+ 61,
515
+ 62,
516
+ 63,
517
+ 64,
518
+ 65,
519
+ 66,
520
+ 67,
521
+ 68,
522
+ 69,
523
+ 70,
524
+ 71,
525
+ 72,
526
+ 73,
527
+ 74,
528
+ 75,
529
+ 76,
530
+ 77,
531
+ 78,
532
+ 79,
533
+ 80,
534
+ 81,
535
+ 82,
536
+ 83,
537
+ 84,
538
+ 85,
539
+ 86,
540
+ 87,
541
+ 88,
542
+ 89,
543
+ 90,
544
+ 91,
545
+ 92,
546
+ 93,
547
+ 94,
548
+ 95,
549
+ 96,
550
+ 97,
551
+ 98,
552
+ 99,
553
+ 100,
554
+ 101,
555
+ 102,
556
+ 103,
557
+ 104,
558
+ 105,
559
+ 106,
560
+ 107,
561
+ 108,
562
+ 109,
563
+ 110,
564
+ 111,
565
+ 112,
566
+ 113,
567
+ 114,
568
+ 115,
569
+ 116,
570
+ 117,
571
+ 118,
572
+ 119,
573
+ 120,
574
+ 121,
575
+ 122,
576
+ ]
577
+ )
578
+ part_indices["hand"] = np.array(
579
+ [
580
+ 20,
581
+ 21,
582
+ 25,
583
+ 26,
584
+ 27,
585
+ 28,
586
+ 29,
587
+ 30,
588
+ 31,
589
+ 32,
590
+ 33,
591
+ 34,
592
+ 35,
593
+ 36,
594
+ 37,
595
+ 38,
596
+ 39,
597
+ 40,
598
+ 41,
599
+ 42,
600
+ 43,
601
+ 44,
602
+ 45,
603
+ 46,
604
+ 47,
605
+ 48,
606
+ 49,
607
+ 50,
608
+ 51,
609
+ 52,
610
+ 53,
611
+ 54,
612
+ 128,
613
+ 129,
614
+ 130,
615
+ 131,
616
+ 133,
617
+ 139,
618
+ 140,
619
+ 141,
620
+ 142,
621
+ 144,
622
+ ]
623
+ )
624
+ part_indices["left_hand"] = np.array(
625
+ [
626
+ 20,
627
+ 25,
628
+ 26,
629
+ 27,
630
+ 28,
631
+ 29,
632
+ 30,
633
+ 31,
634
+ 32,
635
+ 33,
636
+ 34,
637
+ 35,
638
+ 36,
639
+ 37,
640
+ 38,
641
+ 39,
642
+ 128,
643
+ 129,
644
+ 130,
645
+ 131,
646
+ 133,
647
+ ]
648
+ )
649
+ part_indices["right_hand"] = np.array(
650
+ [
651
+ 21,
652
+ 40,
653
+ 41,
654
+ 42,
655
+ 43,
656
+ 44,
657
+ 45,
658
+ 46,
659
+ 47,
660
+ 48,
661
+ 49,
662
+ 50,
663
+ 51,
664
+ 52,
665
+ 53,
666
+ 54,
667
+ 139,
668
+ 140,
669
+ 141,
670
+ 142,
671
+ 144,
672
+ ]
673
+ )
674
  # kinematic tree
675
  head_kin_chain = [15, 12, 9, 6, 3, 0]
676
 
 
707
  Given smplx parameters, this class generates a differentiable SMPLX function
708
  which outputs a mesh and 3D joints
709
  """
 
710
  def __init__(self, config):
711
  super(SMPLX, self).__init__()
712
  # print("creating the SMPLX Decoder")
713
  ss = np.load(config.smplx_model_path, allow_pickle=True)
714
  smplx_model = Struct(**ss)
715
+
716
  self.dtype = torch.float32
717
  self.register_buffer(
718
  "faces_tensor",
 
720
  )
721
  # The vertices of the template model
722
  self.register_buffer(
723
+ "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)
724
+ )
725
  # The shape components and expression
726
  # expression space is the same as FLAME
727
  shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
 
736
  # The pose components
737
  num_pose_basis = smplx_model.posedirs.shape[-1]
738
  posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
739
+ self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
 
740
  self.register_buffer(
741
+ "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)
742
+ )
743
  parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
744
  parents[0] = -1
745
  self.register_buffer("parents", parents)
746
+ self.register_buffer("lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
 
 
747
  # for face keypoints
748
  self.register_buffer(
749
+ "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)
750
+ )
751
  self.register_buffer(
752
  "lmk_bary_coords",
753
  torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
 
758
  )
759
  self.register_buffer(
760
  "dynamic_lmk_bary_coords",
761
+ torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype),
 
762
  )
763
  # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
764
+ self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long))
 
765
 
766
  # -- initialize parameters
767
  # shape and expression
768
  self.register_buffer(
769
  "shape_params",
770
+ nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False),
 
771
  )
772
  self.register_buffer(
773
  "expression_params",
774
+ nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False),
 
775
  )
776
  # pose: represented as rotation matrx [number of joints, 3, 3]
777
  self.register_buffer(
 
832
  )
833
 
834
  if config.extra_joint_path:
835
+ self.extra_joint_selector = JointsFromVerticesSelector(fname=config.extra_joint_path)
 
836
  self.use_joint_regressor = True
837
  self.keypoint_names = SMPLX_names
838
  if self.use_joint_regressor:
 
850
  self.register_buffer("target_idxs", torch.from_numpy(target))
851
  self.register_buffer(
852
  "extra_joint_regressor",
853
+ torch.from_numpy(j14_regressor).to(torch.float32)
854
+ )
855
  self.part_indices = part_indices
856
 
857
  def forward(
 
888
  if expression_params is None:
889
  expression_params = self.expression_params.expand(batch_size, -1)
890
  if global_pose is None:
891
+ global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
 
892
  if body_pose is None:
893
+ body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
 
894
  if jaw_pose is None:
895
+ jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
 
896
  if eye_pose is None:
897
+ eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
 
898
  if left_hand_pose is None:
899
+ left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
 
900
  if right_hand_pose is None:
901
+ right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
 
902
 
903
  shape_components = torch.cat([shape_params, expression_params], dim=1)
904
  full_pose = torch.cat(
 
912
  ],
913
  dim=1,
914
  )
915
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
 
916
  # smplx
917
  vertices, joints = lbs(
918
  shape_components,
 
927
  pose2rot=False,
928
  )
929
  # face dynamic landmarks
930
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
931
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
 
 
932
  dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
933
  vertices,
934
  full_pose,
 
938
  )
939
  lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
940
  lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
941
+ landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)
 
942
 
943
  final_joint_set = [joints, landmarks]
944
  if hasattr(self, "extra_joint_selector"):
945
  # Add any extra joints that might be needed
946
+ extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
 
947
  final_joint_set.append(extra_joints)
948
  # Create the final joint set
949
  joints = torch.cat(final_joint_set, dim=1)
 
975
  # -> Left elbow -> Left wrist
976
  kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
977
  else:
978
+ raise NotImplementedError(f"pose_abs2rel does not support: {abs_joint}")
 
979
 
980
  batch_size = global_pose.shape[0]
981
  dtype = global_pose.dtype
982
  device = global_pose.device
983
  full_pose = torch.cat([global_pose, body_pose], dim=1)
984
+ rel_rot_mat = (
985
+ torch.eye(3, device=device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
986
+ )
987
  for idx in kin_chain[1:]:
988
  rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)
989
 
 
1023
  # -> Left elbow -> Left wrist
1024
  kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
1025
  else:
1026
+ raise NotImplementedError(f"pose_rel2abs does not support: {abs_joint}")
1027
+ rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze_(dim=0)
 
 
 
1028
  for idx in kin_chain:
1029
  rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
1030
  abs_pose = rel_rot_mat[:, None, :, :]
lib/pixielib/models/encoders.py CHANGED
@@ -5,14 +5,13 @@ import torch.nn.functional as F
5
 
6
 
7
  class ResnetEncoder(nn.Module):
8
-
9
  def __init__(self, append_layers=None):
10
  super(ResnetEncoder, self).__init__()
11
  from . import resnet
12
 
13
  # feature_size = 2048
14
  self.feature_dim = 2048
15
- self.encoder = resnet.load_ResNet50Model() # out: 2048
16
  # regressor
17
  self.append_layers = append_layers
18
 
@@ -25,7 +24,6 @@ class ResnetEncoder(nn.Module):
25
 
26
 
27
  class MLP(nn.Module):
28
-
29
  def __init__(self, channels=[2048, 1024, 1], last_op=None):
30
  super(MLP, self).__init__()
31
  layers = []
@@ -45,13 +43,12 @@ class MLP(nn.Module):
45
 
46
 
47
  class HRNEncoder(nn.Module):
48
-
49
  def __init__(self, append_layers=None):
50
  super(HRNEncoder, self).__init__()
51
  from . import hrnet
52
 
53
  self.feature_dim = 2048
54
- self.encoder = hrnet.load_HRNet(pretrained=True) # out: 2048
55
  # regressor
56
  self.append_layers = append_layers
57
 
 
5
 
6
 
7
  class ResnetEncoder(nn.Module):
 
8
  def __init__(self, append_layers=None):
9
  super(ResnetEncoder, self).__init__()
10
  from . import resnet
11
 
12
  # feature_size = 2048
13
  self.feature_dim = 2048
14
+ self.encoder = resnet.load_ResNet50Model() # out: 2048
15
  # regressor
16
  self.append_layers = append_layers
17
 
 
24
 
25
 
26
  class MLP(nn.Module):
 
27
  def __init__(self, channels=[2048, 1024, 1], last_op=None):
28
  super(MLP, self).__init__()
29
  layers = []
 
43
 
44
 
45
  class HRNEncoder(nn.Module):
 
46
  def __init__(self, append_layers=None):
47
  super(HRNEncoder, self).__init__()
48
  from . import hrnet
49
 
50
  self.feature_dim = 2048
51
+ self.encoder = hrnet.load_HRNet(pretrained=True) # out: 2048
52
  # regressor
53
  self.append_layers = append_layers
54
 
lib/pixielib/models/hrnet.py CHANGED
@@ -15,38 +15,42 @@ def load_HRNet(pretrained=False):
15
  hr_net_cfg_dict = {
16
  "use_old_impl": False,
17
  "pretrained_layers": ["*"],
18
- "stage1": {
19
- "num_modules": 1,
20
- "num_branches": 1,
21
- "num_blocks": [4],
22
- "num_channels": [64],
23
- "block": "BOTTLENECK",
24
- "fuse_method": "SUM",
25
- },
26
- "stage2": {
27
- "num_modules": 1,
28
- "num_branches": 2,
29
- "num_blocks": [4, 4],
30
- "num_channels": [48, 96],
31
- "block": "BASIC",
32
- "fuse_method": "SUM",
33
- },
34
- "stage3": {
35
- "num_modules": 4,
36
- "num_branches": 3,
37
- "num_blocks": [4, 4, 4],
38
- "num_channels": [48, 96, 192],
39
- "block": "BASIC",
40
- "fuse_method": "SUM",
41
- },
42
- "stage4": {
43
- "num_modules": 3,
44
- "num_branches": 4,
45
- "num_blocks": [4, 4, 4, 4],
46
- "num_channels": [48, 96, 192, 384],
47
- "block": "BASIC",
48
- "fuse_method": "SUM",
49
- },
 
 
 
 
50
  }
51
  hr_net_cfg = hr_net_cfg_dict
52
  model = HighResolutionNet(hr_net_cfg)
@@ -55,7 +59,6 @@ def load_HRNet(pretrained=False):
55
 
56
 
57
  class HighResolutionModule(nn.Module):
58
-
59
  def __init__(
60
  self,
61
  num_branches,
@@ -67,8 +70,7 @@ class HighResolutionModule(nn.Module):
67
  multi_scale_output=True,
68
  ):
69
  super(HighResolutionModule, self).__init__()
70
- self._check_branches(num_branches, blocks, num_blocks, num_inchannels,
71
- num_channels)
72
 
73
  self.num_inchannels = num_inchannels
74
  self.fuse_method = fuse_method
@@ -76,37 +78,33 @@ class HighResolutionModule(nn.Module):
76
 
77
  self.multi_scale_output = multi_scale_output
78
 
79
- self.branches = self._make_branches(num_branches, blocks, num_blocks,
80
- num_channels)
81
  self.fuse_layers = self._make_fuse_layers()
82
  self.relu = nn.ReLU(True)
83
 
84
- def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels,
85
- num_channels):
86
  if num_branches != len(num_blocks):
87
- error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
88
- num_branches, len(num_blocks))
89
  raise ValueError(error_msg)
90
 
91
  if num_branches != len(num_channels):
92
  error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
93
- num_branches, len(num_channels))
 
94
  raise ValueError(error_msg)
95
 
96
  if num_branches != len(num_inchannels):
97
  error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
98
- num_branches, len(num_inchannels))
 
99
  raise ValueError(error_msg)
100
 
101
- def _make_one_branch(self,
102
- branch_index,
103
- block,
104
- num_blocks,
105
- num_channels,
106
- stride=1):
107
  downsample = None
108
- if (stride != 1 or self.num_inchannels[branch_index] !=
109
- num_channels[branch_index] * block.expansion):
 
 
110
  downsample = nn.Sequential(
111
  nn.Conv2d(
112
  self.num_inchannels[branch_index],
@@ -115,8 +113,7 @@ class HighResolutionModule(nn.Module):
115
  stride=stride,
116
  bias=False,
117
  ),
118
- nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
119
- momentum=BN_MOMENTUM),
120
  )
121
 
122
  layers = []
@@ -126,13 +123,11 @@ class HighResolutionModule(nn.Module):
126
  num_channels[branch_index],
127
  stride,
128
  downsample,
129
- ))
130
- self.num_inchannels[
131
- branch_index] = num_channels[branch_index] * block.expansion
132
  for i in range(1, num_blocks[branch_index]):
133
- layers.append(
134
- block(self.num_inchannels[branch_index],
135
- num_channels[branch_index]))
136
 
137
  return nn.Sequential(*layers)
138
 
@@ -140,8 +135,7 @@ class HighResolutionModule(nn.Module):
140
  branches = []
141
 
142
  for i in range(num_branches):
143
- branches.append(
144
- self._make_one_branch(i, block, num_blocks, num_channels))
145
 
146
  return nn.ModuleList(branches)
147
 
@@ -167,9 +161,9 @@ class HighResolutionModule(nn.Module):
167
  bias=False,
168
  ),
169
  nn.BatchNorm2d(num_inchannels[i]),
170
- nn.Upsample(scale_factor=2**(j - i),
171
- mode="nearest"),
172
- ))
173
  elif j == i:
174
  fuse_layer.append(None)
175
  else:
@@ -188,7 +182,8 @@ class HighResolutionModule(nn.Module):
188
  bias=False,
189
  ),
190
  nn.BatchNorm2d(num_outchannels_conv3x3),
191
- ))
 
192
  else:
193
  num_outchannels_conv3x3 = num_inchannels[j]
194
  conv3x3s.append(
@@ -203,7 +198,8 @@ class HighResolutionModule(nn.Module):
203
  ),
204
  nn.BatchNorm2d(num_outchannels_conv3x3),
205
  nn.ReLU(True),
206
- ))
 
207
  fuse_layer.append(nn.Sequential(*conv3x3s))
208
  fuse_layers.append(nn.ModuleList(fuse_layer))
209
 
@@ -237,7 +233,6 @@ blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}
237
 
238
 
239
  class HighResolutionNet(nn.Module):
240
-
241
  def __init__(self, cfg, **kwargs):
242
  self.inplanes = 64
243
  super(HighResolutionNet, self).__init__()
@@ -245,19 +240,9 @@ class HighResolutionNet(nn.Module):
245
  self.use_old_impl = use_old_impl
246
 
247
  # stem net
248
- self.conv1 = nn.Conv2d(3,
249
- 64,
250
- kernel_size=3,
251
- stride=2,
252
- padding=1,
253
- bias=False)
254
  self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
255
- self.conv2 = nn.Conv2d(64,
256
- 64,
257
- kernel_size=3,
258
- stride=2,
259
- padding=1,
260
- bias=False)
261
  self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
262
  self.relu = nn.ReLU(inplace=True)
263
 
@@ -271,41 +256,29 @@ class HighResolutionNet(nn.Module):
271
  self.stage2_cfg = cfg.get("stage2", {})
272
  num_channels = self.stage2_cfg.get("num_channels", (32, 64))
273
  block = blocks_dict[self.stage2_cfg.get("block")]
274
- num_channels = [
275
- num_channels[i] * block.expansion for i in range(len(num_channels))
276
- ]
277
  stage2_num_channels = num_channels
278
- self.transition1 = self._make_transition_layer([stage1_out_channel],
279
- num_channels)
280
- self.stage2, pre_stage_channels = self._make_stage(
281
- self.stage2_cfg, num_channels)
282
 
283
  self.stage3_cfg = cfg.get("stage3")
284
  num_channels = self.stage3_cfg["num_channels"]
285
  block = blocks_dict[self.stage3_cfg["block"]]
286
- num_channels = [
287
- num_channels[i] * block.expansion for i in range(len(num_channels))
288
- ]
289
  stage3_num_channels = num_channels
290
- self.transition2 = self._make_transition_layer(pre_stage_channels,
291
- num_channels)
292
- self.stage3, pre_stage_channels = self._make_stage(
293
- self.stage3_cfg, num_channels)
294
 
295
  self.stage4_cfg = cfg.get("stage4")
296
  num_channels = self.stage4_cfg["num_channels"]
297
  block = blocks_dict[self.stage4_cfg["block"]]
298
- num_channels = [
299
- num_channels[i] * block.expansion for i in range(len(num_channels))
300
- ]
301
- self.transition3 = self._make_transition_layer(pre_stage_channels,
302
- num_channels)
303
  stage_4_out_channels = num_channels
304
 
305
  self.stage4, pre_stage_channels = self._make_stage(
306
- self.stage4_cfg,
307
- num_channels,
308
- multi_scale_output=not self.use_old_impl)
309
  stage4_num_channels = num_channels
310
 
311
  self.output_channels_dim = pre_stage_channels
@@ -316,35 +289,34 @@ class HighResolutionNet(nn.Module):
316
  self.avg_pooling = nn.AdaptiveAvgPool2d(1)
317
 
318
  if use_old_impl:
319
- in_dims = (2**2 * stage2_num_channels[-1] +
320
- 2**1 * stage3_num_channels[-1] +
321
- stage_4_out_channels[-1])
 
322
  else:
323
  # TODO: Replace with parameters
324
  in_dims = 4 * 384
325
  self.subsample_4 = self._make_subsample_layer(
326
- in_channels=stage4_num_channels[0], num_layers=3)
 
327
 
328
  self.subsample_3 = self._make_subsample_layer(
329
- in_channels=stage2_num_channels[-1], num_layers=2)
 
330
  self.subsample_2 = self._make_subsample_layer(
331
- in_channels=stage3_num_channels[-1], num_layers=1)
332
- self.conv_layers = self._make_conv_layer(in_channels=in_dims,
333
- num_layers=5)
334
 
335
  def get_output_dim(self):
336
- base_output = {
337
- f"layer{idx + 1}": val
338
- for idx, val in enumerate(self.output_channels_dim)
339
- }
340
  output = base_output.copy()
341
  for key in base_output:
342
  output[f"{key}_avg_pooling"] = output[key]
343
  output["concat"] = 2048
344
  return output
345
 
346
- def _make_transition_layer(self, num_channels_pre_layer,
347
- num_channels_cur_layer):
348
  num_branches_cur = len(num_channels_cur_layer)
349
  num_branches_pre = len(num_channels_pre_layer)
350
 
@@ -364,26 +336,24 @@ class HighResolutionNet(nn.Module):
364
  ),
365
  nn.BatchNorm2d(num_channels_cur_layer[i]),
366
  nn.ReLU(inplace=True),
367
- ))
 
368
  else:
369
  transition_layers.append(None)
370
  else:
371
  conv3x3s = []
372
  for j in range(i + 1 - num_branches_pre):
373
  inchannels = num_channels_pre_layer[-1]
374
- outchannels = (num_channels_cur_layer[i] if j == i -
375
- num_branches_pre else inchannels)
 
376
  conv3x3s.append(
377
  nn.Sequential(
378
- nn.Conv2d(inchannels,
379
- outchannels,
380
- 3,
381
- 2,
382
- 1,
383
- bias=False),
384
  nn.BatchNorm2d(outchannels),
385
  nn.ReLU(inplace=True),
386
- ))
 
387
  transition_layers.append(nn.Sequential(*conv3x3s))
388
 
389
  return nn.ModuleList(transition_layers)
@@ -410,24 +380,13 @@ class HighResolutionNet(nn.Module):
410
 
411
  return nn.Sequential(*layers)
412
 
413
- def _make_conv_layer(self,
414
- in_channels=2048,
415
- num_layers=3,
416
- num_filters=2048,
417
- stride=1):
418
 
419
  layers = []
420
  for i in range(num_layers):
421
 
422
- downsample = nn.Conv2d(in_channels,
423
- num_filters,
424
- stride=1,
425
- kernel_size=1,
426
- bias=False)
427
- layers.append(
428
- Bottleneck(in_channels,
429
- num_filters // 4,
430
- downsample=downsample))
431
  in_channels = num_filters
432
 
433
  return nn.Sequential(*layers)
@@ -444,18 +403,15 @@ class HighResolutionNet(nn.Module):
444
  kernel_size=3,
445
  stride=stride,
446
  padding=1,
447
- ))
 
448
  in_channels = 2 * in_channels
449
  layers.append(nn.BatchNorm2d(in_channels, momentum=BN_MOMENTUM))
450
  layers.append(nn.ReLU(inplace=True))
451
 
452
  return nn.Sequential(*layers)
453
 
454
- def _make_stage(self,
455
- layer_config,
456
- num_inchannels,
457
- multi_scale_output=True,
458
- log=False):
459
  num_modules = layer_config["num_modules"]
460
  num_branches = layer_config["num_branches"]
461
  num_blocks = layer_config["num_blocks"]
@@ -480,7 +436,8 @@ class HighResolutionNet(nn.Module):
480
  num_channels,
481
  fuse_method,
482
  reset_multi_scale_output,
483
- ))
 
484
  modules[-1].log = log
485
  num_inchannels = modules[-1].get_num_inchannels()
486
 
@@ -580,15 +537,14 @@ class HighResolutionNet(nn.Module):
580
  def load_weights(self, pretrained=""):
581
  pretrained = osp.expandvars(pretrained)
582
  if osp.isfile(pretrained):
583
- pretrained_state_dict = torch.load(
584
- pretrained, map_location=torch.device("cpu"))
585
 
586
  need_init_state_dict = {}
587
  for name, m in pretrained_state_dict.items():
588
- if (name.split(".")[0] in self.pretrained_layers
589
- or self.pretrained_layers[0] == "*"):
 
590
  need_init_state_dict[name] = m
591
- missing, unexpected = self.load_state_dict(need_init_state_dict,
592
- strict=False)
593
  elif pretrained:
594
  raise ValueError("{} is not exist!".format(pretrained))
 
15
  hr_net_cfg_dict = {
16
  "use_old_impl": False,
17
  "pretrained_layers": ["*"],
18
+ "stage1":
19
+ {
20
+ "num_modules": 1,
21
+ "num_branches": 1,
22
+ "num_blocks": [4],
23
+ "num_channels": [64],
24
+ "block": "BOTTLENECK",
25
+ "fuse_method": "SUM",
26
+ },
27
+ "stage2":
28
+ {
29
+ "num_modules": 1,
30
+ "num_branches": 2,
31
+ "num_blocks": [4, 4],
32
+ "num_channels": [48, 96],
33
+ "block": "BASIC",
34
+ "fuse_method": "SUM",
35
+ },
36
+ "stage3":
37
+ {
38
+ "num_modules": 4,
39
+ "num_branches": 3,
40
+ "num_blocks": [4, 4, 4],
41
+ "num_channels": [48, 96, 192],
42
+ "block": "BASIC",
43
+ "fuse_method": "SUM",
44
+ },
45
+ "stage4":
46
+ {
47
+ "num_modules": 3,
48
+ "num_branches": 4,
49
+ "num_blocks": [4, 4, 4, 4],
50
+ "num_channels": [48, 96, 192, 384],
51
+ "block": "BASIC",
52
+ "fuse_method": "SUM",
53
+ },
54
  }
55
  hr_net_cfg = hr_net_cfg_dict
56
  model = HighResolutionNet(hr_net_cfg)
 
59
 
60
 
61
  class HighResolutionModule(nn.Module):
 
62
  def __init__(
63
  self,
64
  num_branches,
 
70
  multi_scale_output=True,
71
  ):
72
  super(HighResolutionModule, self).__init__()
73
+ self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)
 
74
 
75
  self.num_inchannels = num_inchannels
76
  self.fuse_method = fuse_method
 
78
 
79
  self.multi_scale_output = multi_scale_output
80
 
81
+ self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
 
82
  self.fuse_layers = self._make_fuse_layers()
83
  self.relu = nn.ReLU(True)
84
 
85
+ def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
 
86
  if num_branches != len(num_blocks):
87
+ error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
 
88
  raise ValueError(error_msg)
89
 
90
  if num_branches != len(num_channels):
91
  error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
92
+ num_branches, len(num_channels)
93
+ )
94
  raise ValueError(error_msg)
95
 
96
  if num_branches != len(num_inchannels):
97
  error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
98
+ num_branches, len(num_inchannels)
99
+ )
100
  raise ValueError(error_msg)
101
 
102
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
 
 
 
 
 
103
  downsample = None
104
+ if (
105
+ stride != 1 or
106
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion
107
+ ):
108
  downsample = nn.Sequential(
109
  nn.Conv2d(
110
  self.num_inchannels[branch_index],
 
113
  stride=stride,
114
  bias=False,
115
  ),
116
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
 
117
  )
118
 
119
  layers = []
 
123
  num_channels[branch_index],
124
  stride,
125
  downsample,
126
+ )
127
+ )
128
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
129
  for i in range(1, num_blocks[branch_index]):
130
+ layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
 
 
131
 
132
  return nn.Sequential(*layers)
133
 
 
135
  branches = []
136
 
137
  for i in range(num_branches):
138
+ branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
 
139
 
140
  return nn.ModuleList(branches)
141
 
 
161
  bias=False,
162
  ),
163
  nn.BatchNorm2d(num_inchannels[i]),
164
+ nn.Upsample(scale_factor=2**(j - i), mode="nearest"),
165
+ )
166
+ )
167
  elif j == i:
168
  fuse_layer.append(None)
169
  else:
 
182
  bias=False,
183
  ),
184
  nn.BatchNorm2d(num_outchannels_conv3x3),
185
+ )
186
+ )
187
  else:
188
  num_outchannels_conv3x3 = num_inchannels[j]
189
  conv3x3s.append(
 
198
  ),
199
  nn.BatchNorm2d(num_outchannels_conv3x3),
200
  nn.ReLU(True),
201
+ )
202
+ )
203
  fuse_layer.append(nn.Sequential(*conv3x3s))
204
  fuse_layers.append(nn.ModuleList(fuse_layer))
205
 
 
233
 
234
 
235
  class HighResolutionNet(nn.Module):
 
236
  def __init__(self, cfg, **kwargs):
237
  self.inplanes = 64
238
  super(HighResolutionNet, self).__init__()
 
240
  self.use_old_impl = use_old_impl
241
 
242
  # stem net
243
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
 
 
 
 
 
244
  self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
245
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
 
 
 
 
 
246
  self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
247
  self.relu = nn.ReLU(inplace=True)
248
 
 
256
  self.stage2_cfg = cfg.get("stage2", {})
257
  num_channels = self.stage2_cfg.get("num_channels", (32, 64))
258
  block = blocks_dict[self.stage2_cfg.get("block")]
259
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
 
 
260
  stage2_num_channels = num_channels
261
+ self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
262
+ self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
 
 
263
 
264
  self.stage3_cfg = cfg.get("stage3")
265
  num_channels = self.stage3_cfg["num_channels"]
266
  block = blocks_dict[self.stage3_cfg["block"]]
267
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
 
 
268
  stage3_num_channels = num_channels
269
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
270
+ self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
 
 
271
 
272
  self.stage4_cfg = cfg.get("stage4")
273
  num_channels = self.stage4_cfg["num_channels"]
274
  block = blocks_dict[self.stage4_cfg["block"]]
275
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
276
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
 
 
 
277
  stage_4_out_channels = num_channels
278
 
279
  self.stage4, pre_stage_channels = self._make_stage(
280
+ self.stage4_cfg, num_channels, multi_scale_output=not self.use_old_impl
281
+ )
 
282
  stage4_num_channels = num_channels
283
 
284
  self.output_channels_dim = pre_stage_channels
 
289
  self.avg_pooling = nn.AdaptiveAvgPool2d(1)
290
 
291
  if use_old_impl:
292
+ in_dims = (
293
+ 2**2 * stage2_num_channels[-1] + 2**1 * stage3_num_channels[-1] +
294
+ stage_4_out_channels[-1]
295
+ )
296
  else:
297
  # TODO: Replace with parameters
298
  in_dims = 4 * 384
299
  self.subsample_4 = self._make_subsample_layer(
300
+ in_channels=stage4_num_channels[0], num_layers=3
301
+ )
302
 
303
  self.subsample_3 = self._make_subsample_layer(
304
+ in_channels=stage2_num_channels[-1], num_layers=2
305
+ )
306
  self.subsample_2 = self._make_subsample_layer(
307
+ in_channels=stage3_num_channels[-1], num_layers=1
308
+ )
309
+ self.conv_layers = self._make_conv_layer(in_channels=in_dims, num_layers=5)
310
 
311
  def get_output_dim(self):
312
+ base_output = {f"layer{idx + 1}": val for idx, val in enumerate(self.output_channels_dim)}
 
 
 
313
  output = base_output.copy()
314
  for key in base_output:
315
  output[f"{key}_avg_pooling"] = output[key]
316
  output["concat"] = 2048
317
  return output
318
 
319
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
 
320
  num_branches_cur = len(num_channels_cur_layer)
321
  num_branches_pre = len(num_channels_pre_layer)
322
 
 
336
  ),
337
  nn.BatchNorm2d(num_channels_cur_layer[i]),
338
  nn.ReLU(inplace=True),
339
+ )
340
+ )
341
  else:
342
  transition_layers.append(None)
343
  else:
344
  conv3x3s = []
345
  for j in range(i + 1 - num_branches_pre):
346
  inchannels = num_channels_pre_layer[-1]
347
+ outchannels = (
348
+ num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
349
+ )
350
  conv3x3s.append(
351
  nn.Sequential(
352
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
 
 
 
 
 
353
  nn.BatchNorm2d(outchannels),
354
  nn.ReLU(inplace=True),
355
+ )
356
+ )
357
  transition_layers.append(nn.Sequential(*conv3x3s))
358
 
359
  return nn.ModuleList(transition_layers)
 
380
 
381
  return nn.Sequential(*layers)
382
 
383
+ def _make_conv_layer(self, in_channels=2048, num_layers=3, num_filters=2048, stride=1):
 
 
 
 
384
 
385
  layers = []
386
  for i in range(num_layers):
387
 
388
+ downsample = nn.Conv2d(in_channels, num_filters, stride=1, kernel_size=1, bias=False)
389
+ layers.append(Bottleneck(in_channels, num_filters // 4, downsample=downsample))
 
 
 
 
 
 
 
390
  in_channels = num_filters
391
 
392
  return nn.Sequential(*layers)
 
403
  kernel_size=3,
404
  stride=stride,
405
  padding=1,
406
+ )
407
+ )
408
  in_channels = 2 * in_channels
409
  layers.append(nn.BatchNorm2d(in_channels, momentum=BN_MOMENTUM))
410
  layers.append(nn.ReLU(inplace=True))
411
 
412
  return nn.Sequential(*layers)
413
 
414
+ def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True, log=False):
 
 
 
 
415
  num_modules = layer_config["num_modules"]
416
  num_branches = layer_config["num_branches"]
417
  num_blocks = layer_config["num_blocks"]
 
436
  num_channels,
437
  fuse_method,
438
  reset_multi_scale_output,
439
+ )
440
+ )
441
  modules[-1].log = log
442
  num_inchannels = modules[-1].get_num_inchannels()
443
 
 
537
  def load_weights(self, pretrained=""):
538
  pretrained = osp.expandvars(pretrained)
539
  if osp.isfile(pretrained):
540
+ pretrained_state_dict = torch.load(pretrained, map_location=torch.device("cpu"))
 
541
 
542
  need_init_state_dict = {}
543
  for name, m in pretrained_state_dict.items():
544
+ if (
545
+ name.split(".")[0] in self.pretrained_layers or self.pretrained_layers[0] == "*"
546
+ ):
547
  need_init_state_dict[name] = m
548
+ missing, unexpected = self.load_state_dict(need_init_state_dict, strict=False)
 
549
  elif pretrained:
550
  raise ValueError("{} is not exist!".format(pretrained))
lib/pixielib/models/lbs.py CHANGED
@@ -30,8 +30,7 @@ def rot_mat_to_euler(rot_mats):
30
  # Calculates rotation matrix to euler angles
31
  # Careful for extreme cases of eular angles like [0.0, pi, 0.0]
32
 
33
- sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
34
- rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
35
  return torch.atan2(-rot_mats[:, 2, 0], sy)
36
 
37
 
@@ -86,15 +85,13 @@ def find_dynamic_lmk_idx_and_bcoords(
86
  # aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
87
  rot_mats = torch.index_select(pose, 1, head_kin_chain)
88
 
89
- rel_rot_mat = torch.eye(3, device=vertices.device,
90
- dtype=dtype).unsqueeze_(dim=0)
91
  for idx in range(len(head_kin_chain)):
92
  # rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
93
  rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat)
94
 
95
- y_rot_angle = torch.round(
96
- torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
97
- max=39)).to(dtype=torch.long)
98
  # print(y_rot_angle[0])
99
  neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
100
  mask = y_rot_angle.lt(-39).to(dtype=torch.long)
@@ -102,8 +99,7 @@ def find_dynamic_lmk_idx_and_bcoords(
102
  y_rot_angle = neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle
103
  # print(y_rot_angle[0])
104
 
105
- dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0,
106
- y_rot_angle)
107
  dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle)
108
 
109
  return dyn_lmk_faces_idx, dyn_lmk_b_coords
@@ -135,11 +131,11 @@ def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
135
  batch_size, num_verts = vertices.shape[:2]
136
  device = vertices.device
137
 
138
- lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
139
- batch_size, -1, 3)
140
 
141
- lmk_faces += (torch.arange(batch_size, dtype=torch.long,
142
- device=device).view(-1, 1, 1) * num_verts)
 
143
 
144
  lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
145
 
@@ -211,13 +207,11 @@ def lbs(
211
  # N x J x 3 x 3
212
  ident = torch.eye(3, dtype=dtype, device=device)
213
  if pose2rot:
214
- rot_mats = batch_rodrigues(pose.view(-1, 3),
215
- dtype=dtype).view([batch_size, -1, 3, 3])
216
 
217
  pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
218
  # (N x P) x (P, V * 3) -> N x V x 3
219
- pose_offsets = torch.matmul(pose_feature,
220
- posedirs).view(batch_size, -1, 3)
221
  else:
222
  pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
223
  rot_mats = pose.view(batch_size, -1, 3, 3)
@@ -234,12 +228,9 @@ def lbs(
234
  W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
235
  # (N x V x (J + 1)) x (N x (J + 1) x 16)
236
  num_joints = J_regressor.shape[0]
237
- T = torch.matmul(W, A.view(batch_size, num_joints,
238
- 16)).view(batch_size, -1, 4, 4)
239
 
240
- homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
241
- dtype=dtype,
242
- device=device)
243
  v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
244
  v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
245
 
@@ -318,8 +309,7 @@ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
318
  K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
319
 
320
  zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
321
- K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
322
- dim=1).view((batch_size, 3, 3))
323
 
324
  ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
325
  rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
@@ -335,9 +325,7 @@ def transform_mat(R, t):
335
  - T: Bx4x4 Transformation matrix
336
  """
337
  # No padding left or right, only add an extra row
338
- return torch.cat([F.pad(R, [0, 0, 0, 1]),
339
- F.pad(t, [0, 0, 0, 1], value=1)],
340
- dim=2)
341
 
342
 
343
  def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
@@ -370,15 +358,13 @@ def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
370
  rel_joints[:, 1:] -= joints[:, parents[1:]]
371
 
372
  transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
373
- rel_joints.reshape(-1, 3, 1)).reshape(
374
- -1, joints.shape[1], 4, 4)
375
 
376
  transform_chain = [transforms_mat[:, 0]]
377
  for i in range(1, parents.shape[0]):
378
  # Subtract the joint location at the rest pose
379
  # No need for rotation, since it's identity when at rest
380
- curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:,
381
- i])
382
  transform_chain.append(curr_res)
383
 
384
  transforms = torch.stack(transform_chain, dim=1)
@@ -392,21 +378,22 @@ def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
392
  joints_homogen = F.pad(joints, [0, 0, 0, 1])
393
 
394
  rel_transforms = transforms - F.pad(
395
- torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
 
396
 
397
  return posed_joints, rel_transforms
398
 
399
 
400
  class JointsFromVerticesSelector(nn.Module):
401
-
402
  def __init__(self, fname):
403
  """Selects extra joints from vertices"""
404
  super(JointsFromVerticesSelector, self).__init__()
405
 
406
  err_msg = ("Either pass a filename or triangle face ids, names and"
407
  " barycentrics")
408
- assert fname is not None or (face_ids is not None and bcs is not None
409
- and names is not None), err_msg
 
410
  if fname is not None:
411
  fname = os.path.expanduser(os.path.expandvars(fname))
412
  with open(fname, "r") as f:
@@ -422,13 +409,11 @@ class JointsFromVerticesSelector(nn.Module):
422
  assert len(bcs) == len(
423
  face_ids
424
  ), "The number of barycentric coordinates must be equal to the faces"
425
- assert len(names) == len(
426
- face_ids), "The number of names must be equal to the number of "
427
 
428
  self.names = names
429
  self.register_buffer("bcs", torch.tensor(bcs, dtype=torch.float32))
430
- self.register_buffer("face_ids",
431
- torch.tensor(face_ids, dtype=torch.long))
432
 
433
  def extra_joint_names(self):
434
  """Returns the names of the extra joints"""
@@ -439,8 +424,7 @@ class JointsFromVerticesSelector(nn.Module):
439
  return []
440
  vertex_ids = faces[self.face_ids].reshape(-1)
441
  # Should be BxNx3x3
442
- triangles = torch.index_select(vertices, 1, vertex_ids).reshape(
443
- -1, len(self.bcs), 3, 3)
444
  return (triangles * self.bcs[None, :, :, None]).sum(dim=2)
445
 
446
 
@@ -463,7 +447,6 @@ def to_np(array, dtype=np.float32):
463
 
464
 
465
  class Struct(object):
466
-
467
  def __init__(self, **kwargs):
468
  for key, val in kwargs.items():
469
  setattr(self, key, val)
 
30
  # Calculates rotation matrix to euler angles
31
  # Careful for extreme cases of eular angles like [0.0, pi, 0.0]
32
 
33
+ sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
 
34
  return torch.atan2(-rot_mats[:, 2, 0], sy)
35
 
36
 
 
85
  # aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
86
  rot_mats = torch.index_select(pose, 1, head_kin_chain)
87
 
88
+ rel_rot_mat = torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0)
 
89
  for idx in range(len(head_kin_chain)):
90
  # rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
91
  rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat)
92
 
93
+ y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
94
+ max=39)).to(dtype=torch.long)
 
95
  # print(y_rot_angle[0])
96
  neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
97
  mask = y_rot_angle.lt(-39).to(dtype=torch.long)
 
99
  y_rot_angle = neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle
100
  # print(y_rot_angle[0])
101
 
102
+ dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle)
 
103
  dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle)
104
 
105
  return dyn_lmk_faces_idx, dyn_lmk_b_coords
 
131
  batch_size, num_verts = vertices.shape[:2]
132
  device = vertices.device
133
 
134
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(batch_size, -1, 3)
 
135
 
136
+ lmk_faces += (
137
+ torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
138
+ )
139
 
140
  lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
141
 
 
207
  # N x J x 3 x 3
208
  ident = torch.eye(3, dtype=dtype, device=device)
209
  if pose2rot:
210
+ rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
 
211
 
212
  pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
213
  # (N x P) x (P, V * 3) -> N x V x 3
214
+ pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
 
215
  else:
216
  pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
217
  rot_mats = pose.view(batch_size, -1, 3, 3)
 
228
  W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
229
  # (N x V x (J + 1)) x (N x (J + 1) x 16)
230
  num_joints = J_regressor.shape[0]
231
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
 
232
 
233
+ homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device)
 
 
234
  v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
235
  v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
236
 
 
309
  K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
310
 
311
  zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
312
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))
 
313
 
314
  ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
315
  rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
 
325
  - T: Bx4x4 Transformation matrix
326
  """
327
  # No padding left or right, only add an extra row
328
+ return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
 
 
329
 
330
 
331
  def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
 
358
  rel_joints[:, 1:] -= joints[:, parents[1:]]
359
 
360
  transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
361
+ rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
 
362
 
363
  transform_chain = [transforms_mat[:, 0]]
364
  for i in range(1, parents.shape[0]):
365
  # Subtract the joint location at the rest pose
366
  # No need for rotation, since it's identity when at rest
367
+ curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i])
 
368
  transform_chain.append(curr_res)
369
 
370
  transforms = torch.stack(transform_chain, dim=1)
 
378
  joints_homogen = F.pad(joints, [0, 0, 0, 1])
379
 
380
  rel_transforms = transforms - F.pad(
381
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
382
+ )
383
 
384
  return posed_joints, rel_transforms
385
 
386
 
387
  class JointsFromVerticesSelector(nn.Module):
 
388
  def __init__(self, fname):
389
  """Selects extra joints from vertices"""
390
  super(JointsFromVerticesSelector, self).__init__()
391
 
392
  err_msg = ("Either pass a filename or triangle face ids, names and"
393
  " barycentrics")
394
+ assert fname is not None or (
395
+ face_ids is not None and bcs is not None and names is not None
396
+ ), err_msg
397
  if fname is not None:
398
  fname = os.path.expanduser(os.path.expandvars(fname))
399
  with open(fname, "r") as f:
 
409
  assert len(bcs) == len(
410
  face_ids
411
  ), "The number of barycentric coordinates must be equal to the faces"
412
+ assert len(names) == len(face_ids), "The number of names must be equal to the number of "
 
413
 
414
  self.names = names
415
  self.register_buffer("bcs", torch.tensor(bcs, dtype=torch.float32))
416
+ self.register_buffer("face_ids", torch.tensor(face_ids, dtype=torch.long))
 
417
 
418
  def extra_joint_names(self):
419
  """Returns the names of the extra joints"""
 
424
  return []
425
  vertex_ids = faces[self.face_ids].reshape(-1)
426
  # Should be BxNx3x3
427
+ triangles = torch.index_select(vertices, 1, vertex_ids).reshape(-1, len(self.bcs), 3, 3)
 
428
  return (triangles * self.bcs[None, :, :, None]).sum(dim=2)
429
 
430
 
 
447
 
448
 
449
  class Struct(object):
 
450
  def __init__(self, **kwargs):
451
  for key, val in kwargs.items():
452
  setattr(self, key, val)
lib/pixielib/models/moderators.py CHANGED
@@ -12,11 +12,7 @@ import torch.nn.functional as F
12
 
13
 
14
  class TempSoftmaxFusion(nn.Module):
15
-
16
- def __init__(self,
17
- channels=[2048 * 2, 1024, 1],
18
- detach_inputs=False,
19
- detach_feature=False):
20
  super(TempSoftmaxFusion, self).__init__()
21
  self.detach_inputs = detach_inputs
22
  self.detach_feature = detach_feature
@@ -63,11 +59,7 @@ class TempSoftmaxFusion(nn.Module):
63
 
64
 
65
  class GumbelSoftmaxFusion(nn.Module):
66
-
67
- def __init__(self,
68
- channels=[2048 * 2, 1024, 1],
69
- detach_inputs=False,
70
- detach_feature=False):
71
  super(GumbelSoftmaxFusion, self).__init__()
72
  self.detach_inputs = detach_inputs
73
  self.detach_feature = detach_feature
 
12
 
13
 
14
  class TempSoftmaxFusion(nn.Module):
15
+ def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False):
 
 
 
 
16
  super(TempSoftmaxFusion, self).__init__()
17
  self.detach_inputs = detach_inputs
18
  self.detach_feature = detach_feature
 
59
 
60
 
61
  class GumbelSoftmaxFusion(nn.Module):
62
+ def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False):
 
 
 
 
63
  super(GumbelSoftmaxFusion, self).__init__()
64
  self.detach_inputs = detach_inputs
65
  self.detach_feature = detach_feature
lib/pixielib/models/resnet.py CHANGED
@@ -22,16 +22,10 @@ from torchvision import models
22
 
23
 
24
  class ResNet(nn.Module):
25
-
26
  def __init__(self, block, layers, num_classes=1000):
27
  self.inplanes = 64
28
  super(ResNet, self).__init__()
29
- self.conv1 = nn.Conv2d(3,
30
- 64,
31
- kernel_size=7,
32
- stride=2,
33
- padding=3,
34
- bias=False)
35
  self.bn1 = nn.BatchNorm2d(64)
36
  self.relu = nn.ReLU(inplace=True)
37
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -98,12 +92,7 @@ class Bottleneck(nn.Module):
98
  super(Bottleneck, self).__init__()
99
  self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
100
  self.bn1 = nn.BatchNorm2d(planes)
101
- self.conv2 = nn.Conv2d(planes,
102
- planes,
103
- kernel_size=3,
104
- stride=stride,
105
- padding=1,
106
- bias=False)
107
  self.bn2 = nn.BatchNorm2d(planes)
108
  self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
109
  self.bn3 = nn.BatchNorm2d(planes * 4)
@@ -136,12 +125,7 @@ class Bottleneck(nn.Module):
136
 
137
  def conv3x3(in_planes, out_planes, stride=1):
138
  """3x3 convolution with padding"""
139
- return nn.Conv2d(in_planes,
140
- out_planes,
141
- kernel_size=3,
142
- stride=stride,
143
- padding=1,
144
- bias=False)
145
 
146
 
147
  class BasicBlock(nn.Module):
@@ -196,8 +180,7 @@ def load_ResNet50Model():
196
  model = ResNet(Bottleneck, [3, 4, 6, 3])
197
  copy_parameter_from_resnet(
198
  model,
199
- torchvision.models.resnet50(
200
- weights=models.ResNet50_Weights.DEFAULT).state_dict(),
201
  )
202
  return model
203
 
@@ -206,8 +189,7 @@ def load_ResNet101Model():
206
  model = ResNet(Bottleneck, [3, 4, 23, 3])
207
  copy_parameter_from_resnet(
208
  model,
209
- torchvision.models.resnet101(
210
- weights=models.ResNet101_Weights.DEFAULT).state_dict(),
211
  )
212
  return model
213
 
@@ -216,8 +198,7 @@ def load_ResNet152Model():
216
  model = ResNet(Bottleneck, [3, 8, 36, 3])
217
  copy_parameter_from_resnet(
218
  model,
219
- torchvision.models.resnet152(
220
- weights=models.ResNet152_Weights.DEFAULT).state_dict(),
221
  )
222
  return model
223
 
@@ -229,7 +210,6 @@ def load_ResNet152Model():
229
 
230
  class DoubleConv(nn.Module):
231
  """(convolution => [BN] => ReLU) * 2"""
232
-
233
  def __init__(self, in_channels, out_channels):
234
  super().__init__()
235
  self.double_conv = nn.Sequential(
@@ -247,11 +227,9 @@ class DoubleConv(nn.Module):
247
 
248
  class Down(nn.Module):
249
  """Downscaling with maxpool then double conv"""
250
-
251
  def __init__(self, in_channels, out_channels):
252
  super().__init__()
253
- self.maxpool_conv = nn.Sequential(
254
- nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
255
 
256
  def forward(self, x):
257
  return self.maxpool_conv(x)
@@ -259,20 +237,16 @@ class Down(nn.Module):
259
 
260
  class Up(nn.Module):
261
  """Upscaling then double conv"""
262
-
263
  def __init__(self, in_channels, out_channels, bilinear=True):
264
  super().__init__()
265
 
266
  # if bilinear, use the normal convolutions to reduce the number of channels
267
  if bilinear:
268
- self.up = nn.Upsample(scale_factor=2,
269
- mode="bilinear",
270
- align_corners=True)
271
  else:
272
- self.up = nn.ConvTranspose2d(in_channels // 2,
273
- in_channels // 2,
274
- kernel_size=2,
275
- stride=2)
276
 
277
  self.conv = DoubleConv(in_channels, out_channels)
278
 
@@ -282,9 +256,7 @@ class Up(nn.Module):
282
  diffY = x2.size()[2] - x1.size()[2]
283
  diffX = x2.size()[3] - x1.size()[3]
284
 
285
- x1 = F.pad(
286
- x1,
287
- [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
288
  # if you have padding issues, see
289
  # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
290
  # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
@@ -293,7 +265,6 @@ class Up(nn.Module):
293
 
294
 
295
  class OutConv(nn.Module):
296
-
297
  def __init__(self, in_channels, out_channels):
298
  super(OutConv, self).__init__()
299
  self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
@@ -303,7 +274,6 @@ class OutConv(nn.Module):
303
 
304
 
305
  class UNet(nn.Module):
306
-
307
  def __init__(self, n_channels, n_classes, bilinear=True):
308
  super(UNet, self).__init__()
309
  self.n_channels = n_channels
 
22
 
23
 
24
  class ResNet(nn.Module):
 
25
  def __init__(self, block, layers, num_classes=1000):
26
  self.inplanes = 64
27
  super(ResNet, self).__init__()
28
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
 
 
 
 
 
29
  self.bn1 = nn.BatchNorm2d(64)
30
  self.relu = nn.ReLU(inplace=True)
31
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
92
  super(Bottleneck, self).__init__()
93
  self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
94
  self.bn1 = nn.BatchNorm2d(planes)
95
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
 
 
 
 
 
96
  self.bn2 = nn.BatchNorm2d(planes)
97
  self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
98
  self.bn3 = nn.BatchNorm2d(planes * 4)
 
125
 
126
  def conv3x3(in_planes, out_planes, stride=1):
127
  """3x3 convolution with padding"""
128
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
 
 
 
 
 
129
 
130
 
131
  class BasicBlock(nn.Module):
 
180
  model = ResNet(Bottleneck, [3, 4, 6, 3])
181
  copy_parameter_from_resnet(
182
  model,
183
+ torchvision.models.resnet50(weights=models.ResNet50_Weights.DEFAULT).state_dict(),
 
184
  )
185
  return model
186
 
 
189
  model = ResNet(Bottleneck, [3, 4, 23, 3])
190
  copy_parameter_from_resnet(
191
  model,
192
+ torchvision.models.resnet101(weights=models.ResNet101_Weights.DEFAULT).state_dict(),
 
193
  )
194
  return model
195
 
 
198
  model = ResNet(Bottleneck, [3, 8, 36, 3])
199
  copy_parameter_from_resnet(
200
  model,
201
+ torchvision.models.resnet152(weights=models.ResNet152_Weights.DEFAULT).state_dict(),
 
202
  )
203
  return model
204
 
 
210
 
211
  class DoubleConv(nn.Module):
212
  """(convolution => [BN] => ReLU) * 2"""
 
213
  def __init__(self, in_channels, out_channels):
214
  super().__init__()
215
  self.double_conv = nn.Sequential(
 
227
 
228
  class Down(nn.Module):
229
  """Downscaling with maxpool then double conv"""
 
230
  def __init__(self, in_channels, out_channels):
231
  super().__init__()
232
+ self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
 
233
 
234
  def forward(self, x):
235
  return self.maxpool_conv(x)
 
237
 
238
  class Up(nn.Module):
239
  """Upscaling then double conv"""
 
240
  def __init__(self, in_channels, out_channels, bilinear=True):
241
  super().__init__()
242
 
243
  # if bilinear, use the normal convolutions to reduce the number of channels
244
  if bilinear:
245
+ self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
 
 
246
  else:
247
+ self.up = nn.ConvTranspose2d(
248
+ in_channels // 2, in_channels // 2, kernel_size=2, stride=2
249
+ )
 
250
 
251
  self.conv = DoubleConv(in_channels, out_channels)
252
 
 
256
  diffY = x2.size()[2] - x1.size()[2]
257
  diffX = x2.size()[3] - x1.size()[3]
258
 
259
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
 
 
260
  # if you have padding issues, see
261
  # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
262
  # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
 
265
 
266
 
267
  class OutConv(nn.Module):
 
268
  def __init__(self, in_channels, out_channels):
269
  super(OutConv, self).__init__()
270
  self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
 
274
 
275
 
276
  class UNet(nn.Module):
 
277
  def __init__(self, n_channels, n_classes, bilinear=True):
278
  super(UNet, self).__init__()
279
  self.n_channels = n_channels
lib/pixielib/pixie.py CHANGED
@@ -33,7 +33,6 @@ from .utils.config import cfg
33
 
34
 
35
  class PIXIE(object):
36
-
37
  def __init__(self, config=None, device="cuda:0"):
38
  if config is None:
39
  self.cfg = cfg
@@ -45,10 +44,7 @@ class PIXIE(object):
45
  self.param_list_dict = {}
46
  for lst in self.cfg.params.keys():
47
  param_list = cfg.params.get(lst)
48
- self.param_list_dict[lst] = {
49
- i: cfg.model.get("n_" + i)
50
- for i in param_list
51
- }
52
 
53
  # Build the models
54
  self._create_model()
@@ -97,24 +93,19 @@ class PIXIE(object):
97
  self.Regressor = {}
98
  for key in self.cfg.network.regressor.keys():
99
  n_output = sum(self.param_list_dict[f"{key}_list"].values())
100
- channels = ([2048] + self.cfg.network.regressor.get(key).channels +
101
- [n_output])
102
  if self.cfg.network.regressor.get(key).type == "mlp":
103
  self.Regressor[key] = MLP(channels=channels).to(self.device)
104
- self.model_dict[f"Regressor_{key}"] = self.Regressor[
105
- key].state_dict()
106
 
107
  # Build the extractors
108
  # to extract separate head/left hand/right hand feature from body feature
109
  self.Extractor = {}
110
  for key in self.cfg.network.extractor.keys():
111
- channels = [
112
- 2048
113
- ] + self.cfg.network.extractor.get(key).channels + [2048]
114
  if self.cfg.network.extractor.get(key).type == "mlp":
115
  self.Extractor[key] = MLP(channels=channels).to(self.device)
116
- self.model_dict[f"Extractor_{key}"] = self.Extractor[
117
- key].state_dict()
118
 
119
  # Build the moderators
120
  self.Moderator = {}
@@ -122,15 +113,13 @@ class PIXIE(object):
122
  share_part = key.split("_")[0]
123
  detach_inputs = self.cfg.network.moderator.get(key).detach_inputs
124
  detach_feature = self.cfg.network.moderator.get(key).detach_feature
125
- channels = [2048 * 2
126
- ] + self.cfg.network.moderator.get(key).channels + [2]
127
  self.Moderator[key] = TempSoftmaxFusion(
128
  detach_inputs=detach_inputs,
129
  detach_feature=detach_feature,
130
  channels=channels,
131
  ).to(self.device)
132
- self.model_dict[f"Moderator_{key}"] = self.Moderator[
133
- key].state_dict()
134
 
135
  # Build the SMPL-X body model, which we also use to represent faces and
136
  # hands, using the relevant parts only
@@ -147,9 +136,7 @@ class PIXIE(object):
147
  print(f"pixie trained model path: {model_path} does not exist!")
148
  exit()
149
  # eval mode
150
- for module in [
151
- self.Encoder, self.Regressor, self.Moderator, self.Extractor
152
- ]:
153
  for net in module.values():
154
  net.eval()
155
 
@@ -185,14 +172,14 @@ class PIXIE(object):
185
  # crop
186
  cropper_key = "hand" if "hand" in part_key else part_key
187
  points_scale = image.shape[-2:]
188
- cropped_image, tform = self.Cropper[cropper_key].crop(
189
- image, points_for_crop, points_scale)
190
  # transform points(must be normalized to [-1.1]) accordingly
191
  cropped_points_dict = {}
192
  for points_key in points_dict.keys():
193
  points = points_dict[points_key]
194
  cropped_points = self.Cropper[cropper_key].transform_points(
195
- points, tform, points_scale, normalize=True)
 
196
  cropped_points_dict[points_key] = cropped_points
197
  return cropped_image, cropped_points_dict
198
 
@@ -244,8 +231,7 @@ class PIXIE(object):
244
  # then predict share parameters
245
  feature[key][f"{key}_share"] = feature[key][key]
246
  share_dict = self.decompose_code(
247
- self.Regressor[f"{part}_share"](
248
- feature[key][f"{part}_share"]),
249
  self.param_list_dict[f"{part}_share_list"],
250
  )
251
  # compose parameters
@@ -257,13 +243,16 @@ class PIXIE(object):
257
  f_body = feature["body"]["body"]
258
  # extract part feature
259
  for part_name in ["head", "left_hand", "right_hand"]:
260
- feature["body"][f"{part_name}_share"] = self.Extractor[
261
- f"{part_name}_share"](f_body)
 
262
 
263
  # -- check if part crops are given, if not, crop parts by coarse body estimation
264
- if ("head_image" not in data[key].keys()
265
- or "left_hand_image" not in data[key].keys()
266
- or "right_hand_image" not in data[key].keys()):
 
 
267
  # - run without fusion to get coarse estimation, for cropping parts
268
  # body only
269
  body_dict = self.decompose_code(
@@ -272,29 +261,26 @@ class PIXIE(object):
272
  )
273
  # head share
274
  head_share_dict = self.decompose_code(
275
- self.Regressor["head" + "_share"](
276
- feature[key]["head" + "_share"]),
277
  self.param_list_dict["head" + "_share_list"],
278
  )
279
  # right hand share
280
  right_hand_share_dict = self.decompose_code(
281
- self.Regressor["hand" + "_share"](
282
- feature[key]["right_hand" + "_share"]),
283
  self.param_list_dict["hand" + "_share_list"],
284
  )
285
  # left hand share
286
  left_hand_share_dict = self.decompose_code(
287
- self.Regressor["hand" + "_share"](
288
- feature[key]["left_hand" + "_share"]),
289
  self.param_list_dict["hand" + "_share_list"],
290
  )
291
  # change the dict name from right to left
292
- left_hand_share_dict[
293
- "left_hand_pose"] = left_hand_share_dict.pop(
294
- "right_hand_pose")
295
- left_hand_share_dict[
296
- "left_wrist_pose"] = left_hand_share_dict.pop(
297
- "right_wrist_pose")
298
  param_dict[key] = {
299
  **body_dict,
300
  **head_share_dict,
@@ -304,21 +290,18 @@ class PIXIE(object):
304
  if body_only:
305
  param_dict["moderator_weight"] = None
306
  return param_dict
307
- prediction_body_only = self.decode(param_dict[key],
308
- param_type="body")
309
  # crop
310
  for part_name in ["head", "left_hand", "right_hand"]:
311
  part = part_name.split("_")[-1]
312
  points_dict = {
313
- "smplx_kpt":
314
- prediction_body_only["smplx_kpt"],
315
- "trans_verts":
316
- prediction_body_only["transformed_vertices"],
317
  }
318
- image_hd = torchvision.transforms.Resize(1024)(
319
- data["body"]["image"])
320
  cropped_image, cropped_joints_dict = self.part_from_body(
321
- image_hd, part_name, points_dict)
 
322
  data[key][part_name + "_image"] = cropped_image
323
 
324
  # -- encode features from part crops, then fuse feature using the weight from moderator
@@ -338,16 +321,12 @@ class PIXIE(object):
338
  self.Regressor[f"{part}_share"](f_part),
339
  self.param_list_dict[f"{part}_share_list"],
340
  )
341
- param_dict["body_" + part_name] = {
342
- **part_dict,
343
- **part_share_dict
344
- }
345
 
346
  # moderator to assign weight, then integrate features
347
- f_body_out, f_part_out, f_weight = self.Moderator[
348
- f"{part}_share"](feature["body"][f"{part_name}_share"],
349
- f_part,
350
- work=True)
351
  if copy_and_paste:
352
  # copy and paste strategy always trusts the results from part
353
  feature["body"][f"{part_name}_share"] = f_part
@@ -355,8 +334,9 @@ class PIXIE(object):
355
  # for hand, if part weight > 0.7 (very confident, then fully trust part)
356
  part_w = f_weight[:, [1]]
357
  part_w[part_w > 0.7] = 1.0
358
- f_body_out = (feature["body"][f"{part_name}_share"] *
359
- (1.0 - part_w) + f_part * part_w)
 
360
  feature["body"][f"{part_name}_share"] = f_body_out
361
  else:
362
  feature["body"][f"{part_name}_share"] = f_body_out
@@ -367,29 +347,24 @@ class PIXIE(object):
367
  # -- predict parameters from fused body feature
368
  # head share
369
  head_share_dict = self.decompose_code(
370
- self.Regressor["head" + "_share"](feature[key]["head" +
371
- "_share"]),
372
  self.param_list_dict["head" + "_share_list"],
373
  )
374
  # right hand share
375
  right_hand_share_dict = self.decompose_code(
376
- self.Regressor["hand" + "_share"](
377
- feature[key]["right_hand" + "_share"]),
378
  self.param_list_dict["hand" + "_share_list"],
379
  )
380
  # left hand share
381
  left_hand_share_dict = self.decompose_code(
382
- self.Regressor["hand" + "_share"](
383
- feature[key]["left_hand" + "_share"]),
384
  self.param_list_dict["hand" + "_share_list"],
385
  )
386
  # change the dict name from right to left
387
- left_hand_share_dict[
388
- "left_hand_pose"] = left_hand_share_dict.pop(
389
- "right_hand_pose")
390
- left_hand_share_dict[
391
- "left_wrist_pose"] = left_hand_share_dict.pop(
392
- "right_wrist_pose")
393
  param_dict["body"] = {
394
  **body_dict,
395
  **head_share_dict,
@@ -403,10 +378,10 @@ class PIXIE(object):
403
  if keep_local:
404
  # for local change that will not affect whole body and produce unnatral pose, trust part
405
  param_dict[key]["exp"] = param_dict["body_head"]["exp"]
406
- param_dict[key]["right_hand_pose"] = param_dict[
407
- "body_right_hand"]["right_hand_pose"]
408
- param_dict[key]["left_hand_pose"] = param_dict[
409
- "body_left_hand"]["right_hand_pose"]
410
 
411
  return param_dict
412
 
@@ -426,75 +401,70 @@ class PIXIE(object):
426
  if "pose" in key and "jaw" not in key:
427
  param_dict[key] = converter.batch_cont2matrix(param_dict[key])
428
  if param_type == "body" or param_type == "head":
429
- param_dict["jaw_pose"] = converter.batch_euler2matrix(
430
- param_dict["jaw_pose"])[:, None, :, :]
431
 
432
  # complement params if it's not in given param dict
433
  if param_type == "head":
434
  batch_size = param_dict["shape"].shape[0]
435
  param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
436
  param_dict["global_pose"] = param_dict["head_pose"]
437
- param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(
438
- 0).expand(
439
- batch_size, -1, -1,
440
- -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]]
441
  param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
442
- batch_size, -1, -1, -1)
443
- param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
444
- 0).expand(batch_size, -1, -1, -1)
445
- param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(
446
- 0).expand(batch_size, -1, -1, -1)
447
- param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
448
- 0).expand(batch_size, -1, -1, -1)
449
- param_dict[
450
- "right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(
451
- 0).expand(batch_size, -1, -1, -1)
 
 
 
 
452
  elif param_type == "hand":
453
  batch_size = param_dict["right_hand_pose"].shape[0]
454
- param_dict["abs_right_wrist_pose"] = param_dict[
455
- "right_wrist_pose"].clone()
456
  dtype = param_dict["right_hand_pose"].dtype
457
  device = param_dict["right_hand_pose"].device
458
- x_180_pose = (torch.eye(3, dtype=dtype,
459
- device=device).unsqueeze(0).repeat(
460
- 1, 1, 1))
461
  x_180_pose[0, 2, 2] = -1.0
462
  x_180_pose[0, 1, 1] = -1.0
463
- param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(
464
- batch_size, -1, -1, -1)
465
- param_dict["shape"] = self.smplx.shape_params.expand(
466
- batch_size, -1)
467
- param_dict["exp"] = self.smplx.expression_params.expand(
468
- batch_size, -1)
469
  param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand(
470
- batch_size, -1, -1, -1)
 
471
  param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
472
- batch_size, -1, -1, -1)
473
- param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(
474
- batch_size, -1, -1, -1)
475
- param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(
476
- 0).expand(
477
- batch_size, -1, -1,
478
- -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]]
479
- param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
480
- 0).expand(batch_size, -1, -1, -1)
481
- param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(
482
- 0).expand(batch_size, -1, -1, -1)
 
483
  elif param_type == "body":
484
  # the predcition from the head and hand share regressor is always absolute pose
485
  batch_size = param_dict["shape"].shape[0]
486
  param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
487
- param_dict["abs_right_wrist_pose"] = param_dict[
488
- "right_wrist_pose"].clone()
489
- param_dict["abs_left_wrist_pose"] = param_dict[
490
- "left_wrist_pose"].clone()
491
  # the body-hand share regressor is working for right hand
492
  # so we assume body network get the flipped feature for the left hand. then get the parameters
493
  # then we need to flip it back to left, which matches the input left hand
494
- param_dict["left_wrist_pose"] = util.flip_pose(
495
- param_dict["left_wrist_pose"])
496
- param_dict["left_hand_pose"] = util.flip_pose(
497
- param_dict["left_hand_pose"])
498
  else:
499
  exit()
500
 
@@ -508,8 +478,7 @@ class PIXIE(object):
508
  Returns:
509
  predictions: smplx predictions
510
  """
511
- if "jaw_pose" in param_dict.keys() and len(
512
- param_dict["jaw_pose"].shape) == 2:
513
  self.convert_pose(param_dict, param_type)
514
  elif param_dict["right_wrist_pose"].shape[-1] == 6:
515
  self.convert_pose(param_dict, param_type)
@@ -532,9 +501,8 @@ class PIXIE(object):
532
  # change absolute head&hand pose to relative pose according to rest body pose
533
  if param_type == "head" or param_type == "body":
534
  param_dict["body_pose"] = self.smplx.pose_abs2rel(
535
- param_dict["global_pose"],
536
- param_dict["body_pose"],
537
- abs_joint="head")
538
  if param_type == "hand" or param_type == "body":
539
  param_dict["body_pose"] = self.smplx.pose_abs2rel(
540
  param_dict["global_pose"],
@@ -550,7 +518,7 @@ class PIXIE(object):
550
  if self.cfg.model.check_pose:
551
  # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose)
552
  # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left)
553
- for pose_ind in [14]: # head [15-1, 20-1, 21-1]:
554
  curr_pose = param_dict["body_pose"][:, pose_ind]
555
  euler_pose = converter._compute_euler_from_matrix(curr_pose)
556
  for i, max_angle in enumerate([20, 70, 10]):
@@ -560,9 +528,7 @@ class PIXIE(object):
560
  min=-max_angle * np.pi / 180,
561
  max=max_angle * np.pi / 180,
562
  )] = 0.0
563
- param_dict[
564
- "body_pose"][:, pose_ind] = converter.batch_euler2matrix(
565
- euler_pose)
566
 
567
  # SMPLX
568
  verts, landmarks, joints = self.smplx(
@@ -594,8 +560,8 @@ class PIXIE(object):
594
 
595
  # change the order of face keypoints, to be the same as "standard" 68 keypoints
596
  prediction["face_kpt"] = torch.cat(
597
- [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]],
598
- dim=1)
599
 
600
  prediction.update(param_dict)
601
 
 
33
 
34
 
35
  class PIXIE(object):
 
36
  def __init__(self, config=None, device="cuda:0"):
37
  if config is None:
38
  self.cfg = cfg
 
44
  self.param_list_dict = {}
45
  for lst in self.cfg.params.keys():
46
  param_list = cfg.params.get(lst)
47
+ self.param_list_dict[lst] = {i: cfg.model.get("n_" + i) for i in param_list}
 
 
 
48
 
49
  # Build the models
50
  self._create_model()
 
93
  self.Regressor = {}
94
  for key in self.cfg.network.regressor.keys():
95
  n_output = sum(self.param_list_dict[f"{key}_list"].values())
96
+ channels = ([2048] + self.cfg.network.regressor.get(key).channels + [n_output])
 
97
  if self.cfg.network.regressor.get(key).type == "mlp":
98
  self.Regressor[key] = MLP(channels=channels).to(self.device)
99
+ self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict()
 
100
 
101
  # Build the extractors
102
  # to extract separate head/left hand/right hand feature from body feature
103
  self.Extractor = {}
104
  for key in self.cfg.network.extractor.keys():
105
+ channels = [2048] + self.cfg.network.extractor.get(key).channels + [2048]
 
 
106
  if self.cfg.network.extractor.get(key).type == "mlp":
107
  self.Extractor[key] = MLP(channels=channels).to(self.device)
108
+ self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict()
 
109
 
110
  # Build the moderators
111
  self.Moderator = {}
 
113
  share_part = key.split("_")[0]
114
  detach_inputs = self.cfg.network.moderator.get(key).detach_inputs
115
  detach_feature = self.cfg.network.moderator.get(key).detach_feature
116
+ channels = [2048 * 2] + self.cfg.network.moderator.get(key).channels + [2]
 
117
  self.Moderator[key] = TempSoftmaxFusion(
118
  detach_inputs=detach_inputs,
119
  detach_feature=detach_feature,
120
  channels=channels,
121
  ).to(self.device)
122
+ self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict()
 
123
 
124
  # Build the SMPL-X body model, which we also use to represent faces and
125
  # hands, using the relevant parts only
 
136
  print(f"pixie trained model path: {model_path} does not exist!")
137
  exit()
138
  # eval mode
139
+ for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]:
 
 
140
  for net in module.values():
141
  net.eval()
142
 
 
172
  # crop
173
  cropper_key = "hand" if "hand" in part_key else part_key
174
  points_scale = image.shape[-2:]
175
+ cropped_image, tform = self.Cropper[cropper_key].crop(image, points_for_crop, points_scale)
 
176
  # transform points(must be normalized to [-1.1]) accordingly
177
  cropped_points_dict = {}
178
  for points_key in points_dict.keys():
179
  points = points_dict[points_key]
180
  cropped_points = self.Cropper[cropper_key].transform_points(
181
+ points, tform, points_scale, normalize=True
182
+ )
183
  cropped_points_dict[points_key] = cropped_points
184
  return cropped_image, cropped_points_dict
185
 
 
231
  # then predict share parameters
232
  feature[key][f"{key}_share"] = feature[key][key]
233
  share_dict = self.decompose_code(
234
+ self.Regressor[f"{part}_share"](feature[key][f"{part}_share"]),
 
235
  self.param_list_dict[f"{part}_share_list"],
236
  )
237
  # compose parameters
 
243
  f_body = feature["body"]["body"]
244
  # extract part feature
245
  for part_name in ["head", "left_hand", "right_hand"]:
246
+ feature["body"][f"{part_name}_share"] = self.Extractor[f"{part_name}_share"](
247
+ f_body
248
+ )
249
 
250
  # -- check if part crops are given, if not, crop parts by coarse body estimation
251
+ if (
252
+ "head_image" not in data[key].keys() or
253
+ "left_hand_image" not in data[key].keys() or
254
+ "right_hand_image" not in data[key].keys()
255
+ ):
256
  # - run without fusion to get coarse estimation, for cropping parts
257
  # body only
258
  body_dict = self.decompose_code(
 
261
  )
262
  # head share
263
  head_share_dict = self.decompose_code(
264
+ self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
 
265
  self.param_list_dict["head" + "_share_list"],
266
  )
267
  # right hand share
268
  right_hand_share_dict = self.decompose_code(
269
+ self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
 
270
  self.param_list_dict["hand" + "_share_list"],
271
  )
272
  # left hand share
273
  left_hand_share_dict = self.decompose_code(
274
+ self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
 
275
  self.param_list_dict["hand" + "_share_list"],
276
  )
277
  # change the dict name from right to left
278
+ left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop(
279
+ "right_hand_pose"
280
+ )
281
+ left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
282
+ "right_wrist_pose"
283
+ )
284
  param_dict[key] = {
285
  **body_dict,
286
  **head_share_dict,
 
290
  if body_only:
291
  param_dict["moderator_weight"] = None
292
  return param_dict
293
+ prediction_body_only = self.decode(param_dict[key], param_type="body")
 
294
  # crop
295
  for part_name in ["head", "left_hand", "right_hand"]:
296
  part = part_name.split("_")[-1]
297
  points_dict = {
298
+ "smplx_kpt": prediction_body_only["smplx_kpt"],
299
+ "trans_verts": prediction_body_only["transformed_vertices"],
 
 
300
  }
301
+ image_hd = torchvision.transforms.Resize(1024)(data["body"]["image"])
 
302
  cropped_image, cropped_joints_dict = self.part_from_body(
303
+ image_hd, part_name, points_dict
304
+ )
305
  data[key][part_name + "_image"] = cropped_image
306
 
307
  # -- encode features from part crops, then fuse feature using the weight from moderator
 
321
  self.Regressor[f"{part}_share"](f_part),
322
  self.param_list_dict[f"{part}_share_list"],
323
  )
324
+ param_dict["body_" + part_name] = {**part_dict, **part_share_dict}
 
 
 
325
 
326
  # moderator to assign weight, then integrate features
327
+ f_body_out, f_part_out, f_weight = self.Moderator[f"{part}_share"](
328
+ feature["body"][f"{part_name}_share"], f_part, work=True
329
+ )
 
330
  if copy_and_paste:
331
  # copy and paste strategy always trusts the results from part
332
  feature["body"][f"{part_name}_share"] = f_part
 
334
  # for hand, if part weight > 0.7 (very confident, then fully trust part)
335
  part_w = f_weight[:, [1]]
336
  part_w[part_w > 0.7] = 1.0
337
+ f_body_out = (
338
+ feature["body"][f"{part_name}_share"] * (1.0 - part_w) + f_part * part_w
339
+ )
340
  feature["body"][f"{part_name}_share"] = f_body_out
341
  else:
342
  feature["body"][f"{part_name}_share"] = f_body_out
 
347
  # -- predict parameters from fused body feature
348
  # head share
349
  head_share_dict = self.decompose_code(
350
+ self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
 
351
  self.param_list_dict["head" + "_share_list"],
352
  )
353
  # right hand share
354
  right_hand_share_dict = self.decompose_code(
355
+ self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
 
356
  self.param_list_dict["hand" + "_share_list"],
357
  )
358
  # left hand share
359
  left_hand_share_dict = self.decompose_code(
360
+ self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
 
361
  self.param_list_dict["hand" + "_share_list"],
362
  )
363
  # change the dict name from right to left
364
+ left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop("right_hand_pose")
365
+ left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
366
+ "right_wrist_pose"
367
+ )
 
 
368
  param_dict["body"] = {
369
  **body_dict,
370
  **head_share_dict,
 
378
  if keep_local:
379
  # for local change that will not affect whole body and produce unnatral pose, trust part
380
  param_dict[key]["exp"] = param_dict["body_head"]["exp"]
381
+ param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][
382
+ "right_hand_pose"]
383
+ param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][
384
+ "right_hand_pose"]
385
 
386
  return param_dict
387
 
 
401
  if "pose" in key and "jaw" not in key:
402
  param_dict[key] = converter.batch_cont2matrix(param_dict[key])
403
  if param_type == "body" or param_type == "head":
404
+ param_dict["jaw_pose"] = converter.batch_euler2matrix(param_dict["jaw_pose"]
405
+ )[:, None, :, :]
406
 
407
  # complement params if it's not in given param dict
408
  if param_type == "head":
409
  batch_size = param_dict["shape"].shape[0]
410
  param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
411
  param_dict["global_pose"] = param_dict["head_pose"]
412
+ param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
413
+ batch_size, -1, -1, -1
414
+ )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
 
415
  param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
416
+ batch_size, -1, -1, -1
417
+ )
418
+ param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
419
+ batch_size, -1, -1, -1
420
+ )
421
+ param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
422
+ batch_size, -1, -1, -1
423
+ )
424
+ param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
425
+ batch_size, -1, -1, -1
426
+ )
427
+ param_dict["right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(0).expand(
428
+ batch_size, -1, -1, -1
429
+ )
430
  elif param_type == "hand":
431
  batch_size = param_dict["right_hand_pose"].shape[0]
432
+ param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
 
433
  dtype = param_dict["right_hand_pose"].dtype
434
  device = param_dict["right_hand_pose"].device
435
+ x_180_pose = (torch.eye(3, dtype=dtype, device=device).unsqueeze(0).repeat(1, 1, 1))
 
 
436
  x_180_pose[0, 2, 2] = -1.0
437
  x_180_pose[0, 1, 1] = -1.0
438
+ param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
439
+ param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1)
440
+ param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1)
 
 
 
441
  param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand(
442
+ batch_size, -1, -1, -1
443
+ )
444
  param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
445
+ batch_size, -1, -1, -1
446
+ )
447
+ param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
448
+ param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
449
+ batch_size, -1, -1, -1
450
+ )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
451
+ param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
452
+ batch_size, -1, -1, -1
453
+ )
454
+ param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
455
+ batch_size, -1, -1, -1
456
+ )
457
  elif param_type == "body":
458
  # the predcition from the head and hand share regressor is always absolute pose
459
  batch_size = param_dict["shape"].shape[0]
460
  param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
461
+ param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
462
+ param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone()
 
 
463
  # the body-hand share regressor is working for right hand
464
  # so we assume body network get the flipped feature for the left hand. then get the parameters
465
  # then we need to flip it back to left, which matches the input left hand
466
+ param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"])
467
+ param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"])
 
 
468
  else:
469
  exit()
470
 
 
478
  Returns:
479
  predictions: smplx predictions
480
  """
481
+ if "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2:
 
482
  self.convert_pose(param_dict, param_type)
483
  elif param_dict["right_wrist_pose"].shape[-1] == 6:
484
  self.convert_pose(param_dict, param_type)
 
501
  # change absolute head&hand pose to relative pose according to rest body pose
502
  if param_type == "head" or param_type == "body":
503
  param_dict["body_pose"] = self.smplx.pose_abs2rel(
504
+ param_dict["global_pose"], param_dict["body_pose"], abs_joint="head"
505
+ )
 
506
  if param_type == "hand" or param_type == "body":
507
  param_dict["body_pose"] = self.smplx.pose_abs2rel(
508
  param_dict["global_pose"],
 
518
  if self.cfg.model.check_pose:
519
  # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose)
520
  # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left)
521
+ for pose_ind in [14]: # head [15-1, 20-1, 21-1]:
522
  curr_pose = param_dict["body_pose"][:, pose_ind]
523
  euler_pose = converter._compute_euler_from_matrix(curr_pose)
524
  for i, max_angle in enumerate([20, 70, 10]):
 
528
  min=-max_angle * np.pi / 180,
529
  max=max_angle * np.pi / 180,
530
  )] = 0.0
531
+ param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose)
 
 
532
 
533
  # SMPLX
534
  verts, landmarks, joints = self.smplx(
 
560
 
561
  # change the order of face keypoints, to be the same as "standard" 68 keypoints
562
  prediction["face_kpt"] = torch.cat(
563
+ [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], dim=1
564
+ )
565
 
566
  prediction.update(param_dict)
567
 
lib/pixielib/utils/array_cropper.py CHANGED
@@ -23,15 +23,14 @@ def points2bbox(points, points_scale=None):
23
  bottom = np.max(points[:, 1])
24
  size = max(right - left, bottom - top)
25
  # + old_size*0.1])
26
- center = np.array(
27
- [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
28
  return center, size
29
  # translate center
30
 
31
 
32
  def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.0):
33
  trans_scale = (np.random.rand(2) * 2 - 1) * trans_scale
34
- center = center + trans_scale * bbox_size # 0.5
35
  scale = np.random.rand() * (scale[1] - scale[0]) + scale[0]
36
  size = int(bbox_size * scale)
37
  return center, size
@@ -48,27 +47,25 @@ def crop_array(image, center, bboxsize, crop_size):
48
  tform: 3x3 affine matrix
49
  """
50
  # points: top-left, top-right, bottom-right
51
- src_pts = np.array([
52
- [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
53
- [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
54
- [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
55
- ])
56
- DST_PTS = np.array([[0, 0], [crop_size - 1, 0],
57
- [crop_size - 1, crop_size - 1]])
 
58
 
59
  # estimate transformation between points
60
  tform = estimate_transform("similarity", src_pts, DST_PTS)
61
 
62
  # warp images
63
- cropped_image = warp(image,
64
- tform.inverse,
65
- output_shape=(crop_size, crop_size))
66
 
67
  return cropped_image, tform.params.T
68
 
69
 
70
  class Cropper(object):
71
-
72
  def __init__(self, crop_size, scale=[1, 1], trans_scale=0.0):
73
  self.crop_size = crop_size
74
  self.scale = scale
@@ -78,11 +75,9 @@ class Cropper(object):
78
  # points to bbox
79
  center, bbox_size = points2bbox(points, points_scale)
80
  # argument bbox.
81
- center, bbox_size = augment_bbox(center,
82
- bbox_size,
83
- scale=self.scale,
84
- trans_scale=self.trans_scale)
85
  # crop
86
- cropped_image, tform = crop_array(image, center, bbox_size,
87
- self.crop_size)
88
  return cropped_image, tform
 
23
  bottom = np.max(points[:, 1])
24
  size = max(right - left, bottom - top)
25
  # + old_size*0.1])
26
+ center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
 
27
  return center, size
28
  # translate center
29
 
30
 
31
  def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.0):
32
  trans_scale = (np.random.rand(2) * 2 - 1) * trans_scale
33
+ center = center + trans_scale * bbox_size # 0.5
34
  scale = np.random.rand() * (scale[1] - scale[0]) + scale[0]
35
  size = int(bbox_size * scale)
36
  return center, size
 
47
  tform: 3x3 affine matrix
48
  """
49
  # points: top-left, top-right, bottom-right
50
+ src_pts = np.array(
51
+ [
52
+ [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
53
+ [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
54
+ [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
55
+ ]
56
+ )
57
+ DST_PTS = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]])
58
 
59
  # estimate transformation between points
60
  tform = estimate_transform("similarity", src_pts, DST_PTS)
61
 
62
  # warp images
63
+ cropped_image = warp(image, tform.inverse, output_shape=(crop_size, crop_size))
 
 
64
 
65
  return cropped_image, tform.params.T
66
 
67
 
68
  class Cropper(object):
 
69
  def __init__(self, crop_size, scale=[1, 1], trans_scale=0.0):
70
  self.crop_size = crop_size
71
  self.scale = scale
 
75
  # points to bbox
76
  center, bbox_size = points2bbox(points, points_scale)
77
  # argument bbox.
78
+ center, bbox_size = augment_bbox(
79
+ center, bbox_size, scale=self.scale, trans_scale=self.trans_scale
80
+ )
 
81
  # crop
82
+ cropped_image, tform = crop_array(image, center, bbox_size, self.crop_size)
 
83
  return cropped_image, tform