diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e8558fa76f120147e1ade865b3d8fabfc12c231c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023 Active3DPose Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index 154df8298fab5ecf322016157858e08cd1bccbe1..3f092443ce78a70c908e87b5f600ad12d8ae8486 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,110 @@
----
-license: apache-2.0
----
+# MotionBERT
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2210.06551-b31b1b.svg)](https://arxiv.org/abs/2210.06551)
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motionbert-unified-pretraining-for-human/monocular-3d-human-pose-estimation-on-human3)](https://paperswithcode.com/sota/monocular-3d-human-pose-estimation-on-human3?p=motionbert-unified-pretraining-for-human)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motionbert-unified-pretraining-for-human/one-shot-3d-action-recognition-on-ntu-rgbd)](https://paperswithcode.com/sota/one-shot-3d-action-recognition-on-ntu-rgbd?p=motionbert-unified-pretraining-for-human)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motionbert-unified-pretraining-for-human/3d-human-pose-estimation-on-3dpw)](https://paperswithcode.com/sota/3d-human-pose-estimation-on-3dpw?p=motionbert-unified-pretraining-for-human)
+
+This is the official PyTorch implementation of the paper *"[Learning Human Motion Representations: A Unified Perspective](https://arxiv.org/pdf/2210.06551.pdf)"*.
+
+
+
+## Installation
+
+```bash
+conda create -n motionbert python=3.7 anaconda
+conda activate motionbert
+# Please install PyTorch according to your CUDA version.
+conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
+pip install -r requirements.txt
+```
+
+
+
+## Getting Started
+
+| Task | Document |
+| --------------------------------- | ------------------------------------------------------------ |
+| Pretrain | [docs/pretrain.md](docs/pretrain.md) |
+| 3D human pose estimation | [docs/pose3d.md](docs/pose3d.md) |
+| Skeleton-based action recognition | [docs/action.md](docs/action.md) |
+| Mesh recovery | [docs/mesh.md](docs/mesh.md) |
+
+
+
+## Applications
+
+### In-the-wild inference (for custom videos)
+
+Please refer to [docs/inference.md](docs/inference.md).
+
+### Using MotionBERT for *human-centric* video representations
+
+```python
+'''
+ x: 2D skeletons
+ type = <class 'torch.Tensor'>
+ shape = [batch size * frames * joints(17) * channels(3)]
+
+ MotionBERT: pretrained human motion encoder
+ type = <class 'torch.nn.Module'> (subclass)
+
+ E: encoded motion representation
+ type = <class 'torch.Tensor'>
+ shape = [batch size * frames * joints(17) * channels(512)]
+'''
+E = MotionBERT.get_representation(x)
+```
+
+
+
+> **Hints**
+>
+> 1. The model can handle different input lengths (up to 243 frames). There is no need to explicitly specify the input length elsewhere.
+> 2. The model uses 17 body keypoints ([H36M format](https://github.com/JimmySuen/integral-human-pose/blob/master/pytorch_projects/common_pytorch/dataset/hm36.py#L32)). If you are using other formats, please convert them before feeding them to MotionBERT.
+> 3. Please refer to [model_action.py](lib/model/model_action.py) and [model_mesh.py](lib/model/model_mesh.py) for examples of (easily) adapting MotionBERT to different downstream tasks.
+> 4. For RGB videos, you need to extract 2D poses ([inference.md](docs/inference.md)), convert the keypoint format ([dataset_wild.py](lib/data/dataset_wild.py)), and then feed them to MotionBERT ([infer_wild.py](infer_wild.py)).
+>
+
+
+
+## Model Zoo
+
+
+
+| Model | Download Link | Config | Performance |
+| ------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ---------------- |
+| MotionBERT (162MB) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgS425shtVi9e5reN?e=6UeBa2) | [pretrain/MB_pretrain.yaml](configs/pretrain/MB_pretrain.yaml) | - |
+| MotionBERT-Lite (61MB) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgS27Ydcbpxlkl0ng?e=rq2Btn) | [pretrain/MB_lite.yaml](configs/pretrain/MB_lite.yaml) | - |
+| 3D Pose (H36M-SH, scratch) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgSvNejMQ0OHxMGZC?e=KcwBk1) | [pose3d/MB_train_h36m.yaml](configs/pose3d/MB_train_h36m.yaml) | 39.2mm (MPJPE) |
+| 3D Pose (H36M-SH, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgSoTqtyR5Zsgi8_Z?e=rn4VJf) | [pose3d/MB_ft_h36m.yaml](configs/pose3d/MB_ft_h36m.yaml) | 37.2mm (MPJPE) |
+| Action Recognition (x-sub, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgTX23yT_NO7RiZz-?e=nX6w2j) | [action/MB_ft_NTU60_xsub.yaml](configs/action/MB_ft_NTU60_xsub.yaml) | 97.2% (Top1 Acc) |
+| Action Recognition (x-view, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgTaNiXw2Nal-g37M?e=lSkE4T) | [action/MB_ft_NTU60_xview.yaml](configs/action/MB_ft_NTU60_xview.yaml) | 93.0% (Top1 Acc) |
+| Mesh (with 3DPW, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgTmgYNslCDWMNQi9?e=WjcB1F) | [mesh/MB_ft_pw3d.yaml](configs/mesh/MB_ft_pw3d.yaml) | 88.1mm (MPVE) |
+
+In most use cases (especially with finetuning), `MotionBERT-Lite` gives similar performance with lower computational overhead.
+
+
+
+## TODO
+
+- [x] Scripts and docs for pretraining
+
+- [x] Demo for custom videos
+
+
+
+## Citation
+
+If you find our work useful for your project, please consider citing the paper:
+
+```bibtex
+@article{motionbert2022,
+ title = {Learning Human Motion Representations: A Unified Perspective},
+ author = {Zhu, Wentao and Ma, Xiaoxuan and Liu, Zhaoyang and Liu, Libin and Wu, Wayne and Wang, Yizhou},
+ year = {2022},
+ journal = {arXiv preprint arXiv:2210.06551},
+}
+```
+
diff --git a/configs/action/MB_ft_NTU120_oneshot.yaml b/configs/action/MB_ft_NTU120_oneshot.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ed25ac9719d0b7938338a21898cf2221a8b53513
--- /dev/null
+++ b/configs/action/MB_ft_NTU120_oneshot.yaml
@@ -0,0 +1,35 @@
+# General
+finetune: True
+partial_train: null
+
+# Training
+n_views: 2
+temp: 0.1
+
+epochs: 50
+batch_size: 32
+lr_backbone: 0.0001
+lr_head: 0.001
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+model_version: embed
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+num_joints: 17
+hidden_dim: 2048
+dropout_ratio: 0.1
+
+# Data
+clip_len: 100
+
+# Augmentation
+random_move: True
+scale_range_train: [1, 3]
+scale_range_test: [2, 2]
\ No newline at end of file
diff --git a/configs/action/MB_ft_NTU60_xsub.yaml b/configs/action/MB_ft_NTU60_xsub.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e26f8c42f58bbf754b3786f59bd6391b8f4a8dd4
--- /dev/null
+++ b/configs/action/MB_ft_NTU60_xsub.yaml
@@ -0,0 +1,35 @@
+# General
+finetune: True
+partial_train: null
+
+# Training
+epochs: 300
+batch_size: 32
+lr_backbone: 0.0001
+lr_head: 0.001
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+model_version: class
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+num_joints: 17
+hidden_dim: 2048
+dropout_ratio: 0.5
+
+# Data
+dataset: ntu60_hrnet
+data_split: xsub
+clip_len: 243
+action_classes: 60
+
+# Augmentation
+random_move: True
+scale_range_train: [1, 3]
+scale_range_test: [2, 2]
\ No newline at end of file
diff --git a/configs/action/MB_ft_NTU60_xview.yaml b/configs/action/MB_ft_NTU60_xview.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e1fe649f5cca6a30579a42790f4fea2f6f93f5bf
--- /dev/null
+++ b/configs/action/MB_ft_NTU60_xview.yaml
@@ -0,0 +1,35 @@
+# General
+finetune: True
+partial_train: null
+
+# Training
+epochs: 300
+batch_size: 32
+lr_backbone: 0.0001
+lr_head: 0.001
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+model_version: class
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+num_joints: 17
+hidden_dim: 2048
+dropout_ratio: 0.5
+
+# Data
+dataset: ntu60_hrnet
+data_split: xview
+clip_len: 243
+action_classes: 60
+
+# Augmentation
+random_move: True
+scale_range_train: [1, 3]
+scale_range_test: [2, 2]
\ No newline at end of file
diff --git a/configs/action/MB_train_NTU120_oneshot.yaml b/configs/action/MB_train_NTU120_oneshot.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9aa8d732177eff9a0ece5706c9658645944fb928
--- /dev/null
+++ b/configs/action/MB_train_NTU120_oneshot.yaml
@@ -0,0 +1,35 @@
+# General
+finetune: False
+partial_train: null
+
+# Training
+n_views: 2
+temp: 0.1
+
+epochs: 50
+batch_size: 32
+lr_backbone: 0.0001
+lr_head: 0.001
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+model_version: embed
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+num_joints: 17
+hidden_dim: 2048
+dropout_ratio: 0.1
+
+# Data
+clip_len: 100
+
+# Augmentation
+random_move: True
+scale_range_train: [1, 3]
+scale_range_test: [2, 2]
\ No newline at end of file
diff --git a/configs/action/MB_train_NTU60_xsub.yaml b/configs/action/MB_train_NTU60_xsub.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..adec630e16b18ce9d07ba2907f8ebd94e8ddc146
--- /dev/null
+++ b/configs/action/MB_train_NTU60_xsub.yaml
@@ -0,0 +1,35 @@
+# General
+finetune: False
+partial_train: null
+
+# Training
+epochs: 300
+batch_size: 32
+lr_backbone: 0.0001
+lr_head: 0.0001
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+model_version: class
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+num_joints: 17
+hidden_dim: 2048
+dropout_ratio: 0.5
+
+# Data
+dataset: ntu60_hrnet
+data_split: xsub
+clip_len: 243
+action_classes: 60
+
+# Augmentation
+random_move: True
+scale_range_train: [1, 3]
+scale_range_test: [2, 2]
\ No newline at end of file
diff --git a/configs/action/MB_train_NTU60_xview.yaml b/configs/action/MB_train_NTU60_xview.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c76d291a8fbe0b6472f9175fd04bbe755a4e304d
--- /dev/null
+++ b/configs/action/MB_train_NTU60_xview.yaml
@@ -0,0 +1,35 @@
+# General
+finetune: False
+partial_train: null
+
+# Training
+epochs: 300
+batch_size: 32
+lr_backbone: 0.0001
+lr_head: 0.0001
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+model_version: class
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+num_joints: 17
+hidden_dim: 2048
+dropout_ratio: 0.5
+
+# Data
+dataset: ntu60_hrnet
+data_split: xview
+clip_len: 243
+action_classes: 60
+
+# Augmentation
+random_move: True
+scale_range_train: [1, 3]
+scale_range_test: [2, 2]
\ No newline at end of file
diff --git a/configs/mesh/MB_ft_h36m.yaml b/configs/mesh/MB_ft_h36m.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a5f397ab68ee3a7d0f328712b5ca6cc3b5b748a8
--- /dev/null
+++ b/configs/mesh/MB_ft_h36m.yaml
@@ -0,0 +1,51 @@
+# General
+finetune: True
+partial_train: null
+train_pw3d: False
+warmup_h36m: 100
+
+# Training
+epochs: 60
+checkpoint_frequency: 20
+batch_size: 128
+batch_size_img: 512
+dropout: 0.1
+dropout_loc: 1
+lr_backbone: 0.00005
+lr_head: 0.0005
+weight_decay: 0.01
+lr_decay: 0.98
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+hidden_dim: 1024
+
+# Data
+data_root: data/mesh
+dt_file_h36m: mesh_det_h36m.pkl
+clip_len: 16
+data_stride: 8
+sample_stride: 1
+num_joints: 17
+
+# Loss
+lambda_3d: 0.5
+lambda_scale: 0
+lambda_3dv: 10
+lambda_lv: 0
+lambda_lg: 0
+lambda_a: 0
+lambda_av: 0
+lambda_pose: 1000
+lambda_shape: 1
+lambda_norm: 20
+loss_type: 'L1'
+
+# Augmentation
+flip: True
\ No newline at end of file
diff --git a/configs/mesh/MB_ft_pw3d.yaml b/configs/mesh/MB_ft_pw3d.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a59d6fe70f572416204f6f4f18ad82bb0f64a6c
--- /dev/null
+++ b/configs/mesh/MB_ft_pw3d.yaml
@@ -0,0 +1,53 @@
+# General
+finetune: True
+partial_train: null
+train_pw3d: True
+warmup_h36m: 20
+warmup_coco: 100
+
+# Training
+epochs: 60
+checkpoint_frequency: 20
+batch_size: 128
+batch_size_img: 512
+dropout: 0.1
+lr_backbone: 0.00005
+lr_head: 0.0005
+weight_decay: 0.01
+lr_decay: 0.98
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+hidden_dim: 1024
+
+# Data
+data_root: data/mesh
+dt_file_h36m: mesh_det_h36m.pkl
+dt_file_coco: mesh_det_coco.pkl
+dt_file_pw3d: mesh_det_pw3d.pkl
+clip_len: 16
+data_stride: 8
+sample_stride: 1
+num_joints: 17
+
+# Loss
+lambda_3d: 0.5
+lambda_scale: 0
+lambda_3dv: 10
+lambda_lv: 0
+lambda_lg: 0
+lambda_a: 0
+lambda_av: 0
+lambda_pose: 1000
+lambda_shape: 1
+lambda_norm: 20
+loss_type: 'L1'
+
+# Augmentation
+flip: True
\ No newline at end of file
diff --git a/configs/mesh/MB_train_h36m.yaml b/configs/mesh/MB_train_h36m.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..74669f73b09f157e326e47f1676899b7d9ba92db
--- /dev/null
+++ b/configs/mesh/MB_train_h36m.yaml
@@ -0,0 +1,51 @@
+# General
+finetune: False
+partial_train: null
+train_pw3d: False
+warmup_h36m: 100
+
+# Training
+epochs: 100
+checkpoint_frequency: 20
+batch_size: 128
+batch_size_img: 512
+dropout: 0.1
+dropout_loc: 1
+lr_backbone: 0.0001
+lr_head: 0.0001
+weight_decay: 0.01
+lr_decay: 0.98
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+hidden_dim: 1024
+
+# Data
+data_root: data/mesh
+dt_file_h36m: mesh_det_h36m.pkl
+clip_len: 16
+data_stride: 8
+sample_stride: 1
+num_joints: 17
+
+# Loss
+lambda_3d: 0.5
+lambda_scale: 0
+lambda_3dv: 10
+lambda_lv: 0
+lambda_lg: 0
+lambda_a: 0
+lambda_av: 0
+lambda_pose: 1000
+lambda_shape: 1
+lambda_norm: 20
+loss_type: 'L1'
+
+# Augmentation
+flip: True
\ No newline at end of file
diff --git a/configs/mesh/MB_train_pw3d.yaml b/configs/mesh/MB_train_pw3d.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..379d0f95d74676783554703ad23312b9e0fcec65
--- /dev/null
+++ b/configs/mesh/MB_train_pw3d.yaml
@@ -0,0 +1,53 @@
+# General
+finetune: False
+partial_train: null
+train_pw3d: True
+warmup_h36m: 20
+warmup_coco: 100
+
+# Training
+epochs: 60
+checkpoint_frequency: 20
+batch_size: 128
+batch_size_img: 512
+dropout: 0.1
+lr_backbone: 0.0001
+lr_head: 0.0001
+weight_decay: 0.01
+lr_decay: 0.98
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+hidden_dim: 1024
+
+# Data
+data_root: data/mesh
+dt_file_h36m: mesh_det_h36m.pkl
+dt_file_coco: mesh_det_coco.pkl
+dt_file_pw3d: mesh_det_pw3d.pkl
+clip_len: 16
+data_stride: 8
+sample_stride: 1
+num_joints: 17
+
+# Loss
+lambda_3d: 0.5
+lambda_scale: 0
+lambda_3dv: 10
+lambda_lv: 0
+lambda_lg: 0
+lambda_a: 0
+lambda_av: 0
+lambda_pose: 1000
+lambda_shape: 1
+lambda_norm: 20
+loss_type: 'L1'
+
+# Augmentation
+flip: True
diff --git a/configs/pose3d/MB_ft_h36m.yaml b/configs/pose3d/MB_ft_h36m.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b52f5a2c31057529959f871761e3430f1ada4371
--- /dev/null
+++ b/configs/pose3d/MB_ft_h36m.yaml
@@ -0,0 +1,50 @@
+# General
+train_2d: False
+no_eval: False
+finetune: True
+partial_train: null
+
+# Training
+epochs: 60
+checkpoint_frequency: 30
+batch_size: 32
+dropout: 0.0
+learning_rate: 0.0002
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+
+# Data
+data_root: data/motion3d/MB3D_f243s81/
+subset_list: [H36M-SH]
+dt_file: h36m_sh_conf_cam_source_final.pkl
+clip_len: 243
+data_stride: 81
+rootrel: True
+sample_stride: 1
+num_joints: 17
+no_conf: False
+gt_2d: False
+
+# Loss
+lambda_3d_velocity: 20.0
+lambda_scale: 0.5
+lambda_lv: 0.0
+lambda_lg: 0.0
+lambda_a: 0.0
+lambda_av: 0.0
+
+# Augmentation
+synthetic: False
+flip: True
+mask_ratio: 0.
+mask_T_ratio: 0.
+noise: False
diff --git a/configs/pose3d/MB_ft_h36m_global.yaml b/configs/pose3d/MB_ft_h36m_global.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb9aa5e18a1a6b27d07c4d1395e6feeb12857406
--- /dev/null
+++ b/configs/pose3d/MB_ft_h36m_global.yaml
@@ -0,0 +1,50 @@
+# General
+train_2d: False
+no_eval: False
+finetune: True
+partial_train: null
+
+# Training
+epochs: 60
+checkpoint_frequency: 30
+batch_size: 32
+dropout: 0.0
+learning_rate: 0.0002
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+
+# Data
+data_root: data/motion3d/MB3D_f243s81/
+subset_list: [H36M-SH]
+dt_file: h36m_sh_conf_cam_source_final.pkl
+clip_len: 243
+data_stride: 81
+rootrel: False
+sample_stride: 1
+num_joints: 17
+no_conf: False
+gt_2d: False
+
+# Loss
+lambda_3d_velocity: 20.0
+lambda_scale: 0.5
+lambda_lv: 0.0
+lambda_lg: 0.0
+lambda_a: 0.0
+lambda_av: 0.0
+
+# Augmentation
+synthetic: False
+flip: True
+mask_ratio: 0.
+mask_T_ratio: 0.
+noise: False
\ No newline at end of file
diff --git a/configs/pose3d/MB_ft_h36m_global_lite.yaml b/configs/pose3d/MB_ft_h36m_global_lite.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a2d5d976b8e0b52c1a166d92206aa16e95c0eb8
--- /dev/null
+++ b/configs/pose3d/MB_ft_h36m_global_lite.yaml
@@ -0,0 +1,50 @@
+# General
+train_2d: False
+no_eval: False
+finetune: True
+partial_train: null
+
+# Training
+epochs: 60
+checkpoint_frequency: 30
+batch_size: 32
+dropout: 0.0
+learning_rate: 0.0005
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+maxlen: 243
+dim_feat: 256
+mlp_ratio: 4
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+
+# Data
+data_root: data/motion3d/MB3D_f243s81/
+subset_list: [H36M-SH]
+dt_file: h36m_sh_conf_cam_source_final.pkl
+clip_len: 243
+data_stride: 81
+rootrel: False
+sample_stride: 1
+num_joints: 17
+no_conf: False
+gt_2d: False
+
+# Loss
+lambda_3d_velocity: 20.0
+lambda_scale: 0.5
+lambda_lv: 0.0
+lambda_lg: 0.0
+lambda_a: 0.0
+lambda_av: 0.0
+
+# Augmentation
+synthetic: False
+flip: True
+mask_ratio: 0.
+mask_T_ratio: 0.
+noise: False
\ No newline at end of file
diff --git a/configs/pose3d/MB_train_h36m.yaml b/configs/pose3d/MB_train_h36m.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..39a196e6a3089f0c88523e29a716c592369fc66a
--- /dev/null
+++ b/configs/pose3d/MB_train_h36m.yaml
@@ -0,0 +1,51 @@
+# General
+train_2d: False
+no_eval: False
+finetune: False
+partial_train: null
+
+# Training
+epochs: 120
+checkpoint_frequency: 30
+batch_size: 32
+dropout: 0.0
+learning_rate: 0.0002
+weight_decay: 0.01
+lr_decay: 0.99
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+
+# Data
+data_root: data/motion3d/MB3D_f243s81/
+subset_list: [H36M-SH]
+dt_file: h36m_sh_conf_cam_source_final.pkl
+clip_len: 243
+data_stride: 81
+rootrel: True
+sample_stride: 1
+num_joints: 17
+no_conf: False
+gt_2d: False
+
+# Loss
+lambda_3d_velocity: 20.0
+lambda_scale: 0.5
+lambda_lv: 0.0
+lambda_lg: 0.0
+lambda_a: 0.0
+lambda_av: 0.0
+
+# Augmentation
+synthetic: False
+flip: True
+mask_ratio: 0.
+mask_T_ratio: 0.
+noise: False
+
diff --git a/configs/pretrain/MB_lite.yaml b/configs/pretrain/MB_lite.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae06971e2b01a5e9a66a1c57ddfcb5a6bd31b6a5
--- /dev/null
+++ b/configs/pretrain/MB_lite.yaml
@@ -0,0 +1,53 @@
+# General
+train_2d: True
+no_eval: False
+finetune: False
+partial_train: null
+
+# Training
+epochs: 90
+checkpoint_frequency: 30
+batch_size: 64
+dropout: 0.0
+learning_rate: 0.0005
+weight_decay: 0.01
+lr_decay: 0.99
+pretrain_3d_curriculum: 30
+
+# Model
+maxlen: 243
+dim_feat: 256
+mlp_ratio: 4
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+
+# Data
+data_root: data/motion3d/MB3D_f243s81/
+subset_list: [AMASS, H36M-SH]
+dt_file: h36m_sh_conf_cam_source_final.pkl
+clip_len: 243
+data_stride: 81
+rootrel: True
+sample_stride: 1
+num_joints: 17
+no_conf: False
+gt_2d: False
+
+# Loss
+lambda_3d_velocity: 20.0
+lambda_scale: 0.5
+lambda_lv: 0.0
+lambda_lg: 0.0
+lambda_a: 0.0
+lambda_av: 0.0
+
+# Augmentation
+synthetic: True # synthetic: do not use 2D detection results; synthesize the 2D input from 3D instead
+flip: True
+mask_ratio: 0.05
+mask_T_ratio: 0.1
+noise: True
+noise_path: params/synthetic_noise.pth
+d2c_params_path: params/d2c_params.pkl
\ No newline at end of file
diff --git a/configs/pretrain/MB_pretrain.yaml b/configs/pretrain/MB_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..efc86eb7aa138893980d949ba43e0705606b46d3
--- /dev/null
+++ b/configs/pretrain/MB_pretrain.yaml
@@ -0,0 +1,53 @@
+# General
+train_2d: True
+no_eval: False
+finetune: False
+partial_train: null
+
+# Training
+epochs: 90
+checkpoint_frequency: 30
+batch_size: 64
+dropout: 0.0
+learning_rate: 0.0005
+weight_decay: 0.01
+lr_decay: 0.99
+pretrain_3d_curriculum: 30
+
+# Model
+maxlen: 243
+dim_feat: 512
+mlp_ratio: 2
+depth: 5
+dim_rep: 512
+num_heads: 8
+att_fuse: True
+
+# Data
+data_root: data/motion3d/MB3D_f243s81/
+subset_list: [AMASS, H36M-SH]
+dt_file: h36m_sh_conf_cam_source_final.pkl
+clip_len: 243
+data_stride: 81
+rootrel: True
+sample_stride: 1
+num_joints: 17
+no_conf: False
+gt_2d: False
+
+# Loss
+lambda_3d_velocity: 20.0
+lambda_scale: 0.5
+lambda_lv: 0.0
+lambda_lg: 0.0
+lambda_a: 0.0
+lambda_av: 0.0
+
+# Augmentation
+synthetic: True # synthetic: do not use 2D detection results; synthesize the 2D input from 3D instead
+flip: True
+mask_ratio: 0.05
+mask_T_ratio: 0.1
+noise: True
+noise_path: params/synthetic_noise.pth
+d2c_params_path: params/d2c_params.pkl
diff --git a/docs/action.md b/docs/action.md
new file mode 100644
index 0000000000000000000000000000000000000000..874f9415f9e9f62d3e9af446d6f1a6d2306666c8
--- /dev/null
+++ b/docs/action.md
@@ -0,0 +1,86 @@
+# Skeleton-based Action Recognition
+
+## Data
+
+The NTURGB+D 2D detection results are provided by [pyskl](https://github.com/kennymckormick/pyskl/blob/main/tools/data/README.md) using HRNet.
+
+1. Download [`ntu60_hrnet.pkl`](https://download.openmmlab.com/mmaction/pyskl/data/nturgbd/ntu60_hrnet.pkl) and [`ntu120_hrnet.pkl`](https://download.openmmlab.com/mmaction/pyskl/data/nturgbd/ntu120_hrnet.pkl) to `data/action/`.
+2. Download the 1-shot split [here](https://1drv.ms/f/s!AvAdh0LSjEOlfi-hqlHxdVMZxWM) and put it to `data/action/`.
+
+## Running
+
+### NTURGB+D
+
+**Train from scratch:**
+
+```shell
+# Cross-subject
+python train_action.py \
+--config configs/action/MB_train_NTU60_xsub.yaml \
+--checkpoint checkpoint/action/MB_train_NTU60_xsub
+
+# Cross-view
+python train_action.py \
+--config configs/action/MB_train_NTU60_xview.yaml \
+--checkpoint checkpoint/action/MB_train_NTU60_xview
+```
+
+**Finetune from pretrained MotionBERT:**
+
+```shell
+# Cross-subject
+python train_action.py \
+--config configs/action/MB_ft_NTU60_xsub.yaml \
+--pretrained checkpoint/pretrain/MB_release \
+--checkpoint checkpoint/action/FT_MB_release_MB_ft_NTU60_xsub
+
+# Cross-view
+python train_action.py \
+--config configs/action/MB_ft_NTU60_xview.yaml \
+--pretrained checkpoint/pretrain/MB_release \
+--checkpoint checkpoint/action/FT_MB_release_MB_ft_NTU60_xview
+```
+
+**Evaluate:**
+
+```bash
+# Cross-subject
+python train_action.py \
+--config configs/action/MB_train_NTU60_xsub.yaml \
+--evaluate checkpoint/action/MB_train_NTU60_xsub/best_epoch.bin
+
+# Cross-view
+python train_action.py \
+--config configs/action/MB_train_NTU60_xview.yaml \
+--evaluate checkpoint/action/MB_train_NTU60_xview/best_epoch.bin
+```
+
+### NTURGB+D-120 (1-shot)
+
+**Train from scratch:**
+
+```bash
+python train_action_1shot.py \
+--config configs/action/MB_train_NTU120_oneshot.yaml \
+--checkpoint checkpoint/action/MB_train_NTU120_oneshot
+```
+
+**Finetune from a pretrained model:**
+
+```bash
+python train_action_1shot.py \
+--config configs/action/MB_ft_NTU120_oneshot.yaml \
+--pretrained checkpoint/pretrain/MB_release \
+--checkpoint checkpoint/action/FT_MB_release_MB_ft_NTU120_oneshot
+```
+
+**Evaluate:**
+
+```bash
+python train_action_1shot.py \
+--config configs/action/MB_train_NTU120_oneshot.yaml \
+--evaluate checkpoint/action/MB_train_NTU120_oneshot/best_epoch.bin
+```
+
+
+
diff --git a/docs/inference.md b/docs/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..0333a71b5870ff3bb75ea277d33ec371b76322c2
--- /dev/null
+++ b/docs/inference.md
@@ -0,0 +1,48 @@
+# In-the-wild Inference
+
+## 2D Pose
+
+Please use [AlphaPose](https://github.com/MVIG-SJTU/AlphaPose#quick-start) to extract the 2D keypoints for your video first. We use the *Fast Pose* model trained on *Halpe* dataset ([Link](https://github.com/MVIG-SJTU/AlphaPose/blob/master/docs/MODEL_ZOO.md#halpe-dataset-26-keypoints)).
+
+Note: Currently only a single person is supported. If your video contains multiple people, you may need to use the [Pose Tracking Module for AlphaPose](https://github.com/MVIG-SJTU/AlphaPose/tree/master/trackers) and set `--focus` to specify the target person id.
+
+
+
+## 3D Pose
+
+| ![pose_1](https://github.com/motionbert/motionbert.github.io/blob/main/assets/pose_1.gif?raw=true) | ![pose_2](https://raw.githubusercontent.com/motionbert/motionbert.github.io/main/assets/pose_2.gif) |
+| ------------------------------------------------------------ | ------------------------------------------------------------ |
+
+
+1. Please download the checkpoint [here](https://1drv.ms/f/s!AvAdh0LSjEOlgT67igq_cIoYvO2y?e=bfEc73) and put it to `checkpoint/pose3d/FT_MB_lite_MB_ft_h36m_global_lite/`.
+1. Run the following command to infer from the extracted 2D poses:
+```bash
+python infer_wild.py \
+--vid_path <your_video.mp4> \
+--json_path <alphapose-results.json> \
+--out_path <output_path>
+```
+
+
+
+## Mesh
+
+| ![mesh_1](https://raw.githubusercontent.com/motionbert/motionbert.github.io/main/assets/mesh_1.gif) | ![mesh_2](https://github.com/motionbert/motionbert.github.io/blob/main/assets/mesh_2.gif?raw=true) |
+| ------------------------------------------------------------ | ----------- |
+
+1. Please download the checkpoint [here](https://1drv.ms/f/s!AvAdh0LSjEOlgTmgYNslCDWMNQi9?e=WjcB1F) and put it to `checkpoint/mesh/FT_MB_release_MB_ft_pw3d/`
+2. Run the following command to infer from the extracted 2D poses:
+```bash
+python infer_wild_mesh.py \
+--vid_path <your_video.mp4> \
+--json_path <alphapose-results.json> \
+--out_path <output_path> \
+--ref_3d_motion_path <3d-pose-results.npy> # Optional, use the estimated 3D motion for root trajectory.
+```
+
+
+
+
+
+
+
diff --git a/docs/mesh.md b/docs/mesh.md
new file mode 100644
index 0000000000000000000000000000000000000000..3d7e1fe27a1c050647f78ce814ee9c7609e86d7d
--- /dev/null
+++ b/docs/mesh.md
@@ -0,0 +1,61 @@
+# Human Mesh Recovery
+
+## Data
+
+1. Download the datasets [here](https://1drv.ms/f/s!AvAdh0LSjEOlfy-hqlHxdVMZxWM) and put them to `data/mesh/`. We use Human3.6M, COCO, and PW3D for training and testing. Descriptions of the joint regressors could be found in [SPIN](https://github.com/nkolot/SPIN/tree/master/data).
+2. Download the SMPL model (`basicModel_neutral_lbs_10_207_0_v1.0.0.pkl`) from [SMPLify](https://smplify.is.tue.mpg.de/), put it in `data/mesh/`, and rename it to `SMPL_NEUTRAL.pkl`.
+
+
+## Running
+
+**Train from scratch:**
+
+```bash
+# with 3DPW
+python train_mesh.py \
+--config configs/mesh/MB_train_pw3d.yaml \
+--checkpoint checkpoint/mesh/MB_train_pw3d
+
+# H36M
+python train_mesh.py \
+--config configs/mesh/MB_train_h36m.yaml \
+--checkpoint checkpoint/mesh/MB_train_h36m
+```
+
+**Finetune from a pretrained model:**
+
+```bash
+# with 3DPW
+python train_mesh.py \
+--config configs/mesh/MB_ft_pw3d.yaml \
+--pretrained checkpoint/pretrain/MB_release \
+--checkpoint checkpoint/mesh/FT_MB_release_MB_ft_pw3d
+
+# H36M
+python train_mesh.py \
+--config configs/mesh/MB_ft_h36m.yaml \
+--pretrained checkpoint/pretrain/MB_release \
+--checkpoint checkpoint/mesh/FT_MB_release_MB_ft_h36m
+
+```
+
+**Evaluate:**
+
+```bash
+# with 3DPW
+python train_mesh.py \
+--config configs/mesh/MB_train_pw3d.yaml \
+--evaluate checkpoint/mesh/MB_train_pw3d/best_epoch.bin
+
+# H36M
+python train_mesh.py \
+--config configs/mesh/MB_train_h36m.yaml \
+--evaluate checkpoint/mesh/MB_train_h36m/best_epoch.bin
+```
+
+
+
+
+
+
+
diff --git a/docs/pose3d.md b/docs/pose3d.md
new file mode 100644
index 0000000000000000000000000000000000000000..448e3ee148703d8eb26e1cce295c67a85276adaf
--- /dev/null
+++ b/docs/pose3d.md
@@ -0,0 +1,51 @@
+# 3D Human Pose Estimation
+
+## Data
+
+1. Download the finetuned Stacked Hourglass detections and our preprocessed H3.6M data (.pkl) [here](https://1drv.ms/u/s!AvAdh0LSjEOlgSMvoapR8XVTGcVj) and put it to `data/motion3d`.
+
+ > Note that the preprocessed data is only intended for reproducing our results more easily. If you want to use the dataset, please register to the [Human3.6m website](http://vision.imar.ro/human3.6m/) and download the dataset in its original format. Please refer to [LCN](https://github.com/CHUNYUWANG/lcn-pose#data) for how we prepare the H3.6M data.
+
+2. Slice the motion clips (len=243, stride=81)
+
+ ```bash
+ python tools/convert_h36m.py
+ ```
+
+## Running
+
+**Train from scratch:**
+
+```bash
+python train.py \
+--config configs/pose3d/MB_train_h36m.yaml \
+--checkpoint checkpoint/pose3d/MB_train_h36m
+```
+
+**Finetune from pretrained MotionBERT:**
+
+```bash
+python train.py \
+--config configs/pose3d/MB_ft_h36m.yaml \
+--pretrained checkpoint/pretrain/MB_release \
+--checkpoint checkpoint/pose3d/FT_MB_release_MB_ft_h36m
+```
+
+**Evaluate:**
+
+```bash
+python train.py \
+--config configs/pose3d/MB_train_h36m.yaml \
+--evaluate checkpoint/pose3d/MB_train_h36m/best_epoch.bin
+```
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/pretrain.md b/docs/pretrain.md
new file mode 100644
index 0000000000000000000000000000000000000000..bfec36d102d0bfab8bf496a3dff3ffa989c82cc7
--- /dev/null
+++ b/docs/pretrain.md
@@ -0,0 +1,59 @@
+# Pretrain
+
+## Data
+
+### AMASS
+
+1. Please download data from the [official website](https://amass.is.tue.mpg.de/download.php) (SMPL+H).
+2. We provide the preprocessing scripts as follows. Minor modifications might be necessary.
+ - [tools/compress_amass.py](../tools/compress_amass.py): downsample the frame rate
+ - [tools/preprocess_amass.py](../tools/preprocess_amass.py): render the mocap data and extract the 3D keypoints
+ - [tools/convert_amass.py](../tools/convert_amass.py): slice them to motion clips
+
+
+### Human 3.6M
+
+Please refer to [pose3d.md](pose3d.md#data).
+
+### InstaVariety
+
+1. Please download data from [human_dynamics](https://github.com/akanazawa/human_dynamics/blob/master/doc/insta_variety.md#generating-tfrecords) to `data/motion2d`.
+1. Use [tools/convert_insta.py](../tools/convert_insta.py) to preprocess the 2D keypoints (need to specify `name_action` ).
+
+### PoseTrack
+
+Please download PoseTrack18 from [MMPose](https://mmpose.readthedocs.io/en/latest/tasks/2d_body_keypoint.html#posetrack18) and unzip to `data/motion2d`.
+
+
+
+The processed directory tree should look like this:
+
+```
+.
+└── data/
+ ├── motion3d/
+ │ └── MB3D_f243s81/
+ │ ├── AMASS
+ │ └── H36M-SH
+ ├── motion2d/
+ │ ├── InstaVariety/
+ │ │ ├── motion_all.npy
+ │ │ └── id_all.npy
+ │ └── posetrack18_annotations/
+ │ ├── train
+ │ └── ...
+ └── ...
+```
+
+
+
+## Train
+
+```bash
+python train.py \
+--config configs/pretrain/MB_pretrain.yaml \
+-c checkpoint/pretrain/MB_pretrain
+```
+
+
+
diff --git a/infer_wild.py b/infer_wild.py
new file mode 100644
index 0000000000000000000000000000000000000000..17acd194e06db341001101f4c7bb70b6710bdae8
--- /dev/null
+++ b/infer_wild.py
@@ -0,0 +1,97 @@
+import os
+import numpy as np
+import argparse
+from tqdm import tqdm
+import imageio
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from lib.utils.tools import *
+from lib.utils.learning import *
+from lib.utils.utils_data import flip_data
+from lib.data.dataset_wild import WildDetDataset
+from lib.utils.vismo import render_and_save
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="configs/pose3d/MB_ft_h36m_global_lite.yaml", help="Path to the config file.")
+ parser.add_argument('-e', '--evaluate', default='checkpoint/pose3d/FT_MB_lite_MB_ft_h36m_global_lite/best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
+ parser.add_argument('-j', '--json_path', type=str, help='alphapose detection result json path')
+ parser.add_argument('-v', '--vid_path', type=str, help='video path')
+ parser.add_argument('-o', '--out_path', type=str, help='output path')
+ parser.add_argument('--pixel', action='store_true', help='align with pixle coordinates')
+ parser.add_argument('--focus', type=int, default=None, help='target person id')
+ parser.add_argument('--clip_len', type=int, default=243, help='clip length for network input')
+ opts = parser.parse_args()
+ return opts
+
+# Script body: load the pose backbone, run it over the detections in
+# --json_path and write a rendered video plus the raw predictions to --out_path.
+opts = parse_args()
+args = get_config(opts.config)
+
+# The backbone is wrapped in DataParallel and moved to the GPU only when CUDA is available.
+model_backbone = load_backbone(args)
+if torch.cuda.is_available():
+ model_backbone = nn.DataParallel(model_backbone)
+ model_backbone = model_backbone.cuda()
+
+print('Loading checkpoint', opts.evaluate)
+# map_location keeps every tensor on the storage it is deserialised to (no device remapping).
+checkpoint = torch.load(opts.evaluate, map_location=lambda storage, loc: storage)
+# NOTE(review): strict=True is applied to a model that is DataParallel-wrapped only on CUDA
+# hosts — confirm the checkpoint's key names also match the unwrapped model on CPU-only hosts.
+model_backbone.load_state_dict(checkpoint['model_pos'], strict=True)
+model_pos = model_backbone
+model_pos.eval()
+# Keyword arguments forwarded verbatim to the DataLoader below.
+testloader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': 8,
+ 'pin_memory': True,
+ 'prefetch_factor': 4,
+ 'persistent_workers': True,
+ 'drop_last': False
+}
+
+# The video is opened only for its metadata (fps and frame size).
+vid = imageio.get_reader(opts.vid_path, 'ffmpeg')
+fps_in = vid.get_meta_data()['fps']
+vid_size = vid.get_meta_data()['size']
+os.makedirs(opts.out_path, exist_ok=True)
+
+if opts.pixel:
+ # Keep relative scale with pixel coordinates
+ wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus)
+else:
+ # Scale to [-1,1]
+ wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus)
+
+test_loader = DataLoader(wild_dataset, **testloader_params)
+
+# One prediction array per batch is collected here and merged after the loop.
+results_all = []
+with torch.no_grad():
+ for batch_input in tqdm(test_loader):
+ N, T = batch_input.shape[:2]
+ if torch.cuda.is_available():
+ batch_input = batch_input.cuda()
+ if args.no_conf:
+ # Keep only the first two channels of the last axis (drops the remaining ones).
+ batch_input = batch_input[:, :, :, :2]
+ if args.flip:
+ # Flip augmentation: average the prediction with the flipped-back prediction of the flipped input.
+ batch_input_flip = flip_data(batch_input)
+ predicted_3d_pos_1 = model_pos(batch_input)
+ predicted_3d_pos_flip = model_pos(batch_input_flip)
+ predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip) # Flip back
+ predicted_3d_pos = (predicted_3d_pos_1 + predicted_3d_pos_2) / 2.0
+ else:
+ predicted_3d_pos = model_pos(batch_input)
+ if args.rootrel:
+ # Zero joint 0 in every frame.
+ predicted_3d_pos[:,:,0,:]=0 # [N,T,17,3]
+ else:
+ # Zero only the last coordinate of joint 0 in the first frame of each clip.
+ predicted_3d_pos[:,0,0,2]=0
+ pass
+ if args.gt_2d:
+ # Overwrite the predicted first two coordinates with the input 2D coordinates.
+ predicted_3d_pos[...,:2] = batch_input[...,:2]
+ results_all.append(predicted_3d_pos.cpu().numpy())
+
+# hstack joins the per-batch arrays along axis 1; concatenate then joins the
+# entries of the leading axis along axis 0, removing that leading axis.
+results_all = np.hstack(results_all)
+results_all = np.concatenate(results_all)
+render_and_save(results_all, '%s/X3D.mp4' % (opts.out_path), keep_imgs=False, fps=fps_in)
+if opts.pixel:
+ # Convert to pixel coordinates
+ results_all = results_all * (min(vid_size) / 2.0)
+ results_all[:,:,:2] = results_all[:,:,:2] + np.array(vid_size) / 2.0
+np.save('%s/X3D.npy' % (opts.out_path), results_all)
\ No newline at end of file
diff --git a/infer_wild_mesh.py b/infer_wild_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1c9b7d022a6b1cf771f84c644c75ee086f99c1
--- /dev/null
+++ b/infer_wild_mesh.py
@@ -0,0 +1,157 @@
+import os
+import os.path as osp
+import numpy as np
+import argparse
+import pickle
+from tqdm import tqdm
+import time
+import random
+import imageio
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from lib.utils.tools import *
+from lib.utils.learning import *
+from lib.utils.utils_data import flip_data
+from lib.utils.utils_mesh import flip_thetas_batch
+from lib.data.dataset_wild import WildDetDataset
+# from lib.model.loss import *
+from lib.model.model_mesh import MeshRegressor
+from lib.utils.vismo import render_and_save, motion2video_mesh
+from lib.utils.utils_smpl import *
+from scipy.optimize import least_squares
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="configs/mesh/MB_ft_pw3d.yaml", help="Path to the config file.")
+ parser.add_argument('-e', '--evaluate', default='checkpoint/mesh/FT_MB_release_MB_ft_pw3d/best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
+ parser.add_argument('-j', '--json_path', type=str, help='alphapose detection result json path')
+ parser.add_argument('-v', '--vid_path', type=str, help='video path')
+ parser.add_argument('-o', '--out_path', type=str, help='output path')
+ parser.add_argument('--ref_3d_motion_path', type=str, default=None, help='3D motion path')
+ parser.add_argument('--pixel', action='store_true', help='align with pixle coordinates')
+ parser.add_argument('--focus', type=int, default=None, help='target person id')
+ parser.add_argument('--clip_len', type=int, default=243, help='clip length for network input')
+ opts = parser.parse_args()
+ return opts
+
+def err(p, x, y):
+ return np.linalg.norm(p[0] * x + np.array([p[1], p[2], p[3]]) - y, axis=-1).mean()
+
+def solve_scale(x, y):
+ print('Estimating camera transformation.')
+ best_res = 100000
+ best_scale = None
+ for init_scale in tqdm(range(0,2000,5)):
+ p0 = [init_scale, 0.0, 0.0, 0.0]
+ est = least_squares(err, p0, args = (x.reshape(-1,3), y.reshape(-1,3)))
+ if est['fun'] < best_res:
+ best_res = est['fun']
+ best_scale = est['x'][0]
+ print('Pose matching error = %.2f mm.' % best_res)
+ return best_scale
+
+# Script body: load the mesh regressor, run it (with flip averaging) over the
+# detections in --json_path and render the resulting vertices to --out_path/mesh.mp4.
+opts = parse_args()
+args = get_config(opts.config)
+
+# root_rel
+# args.rootrel = True
+
+# NOTE(review): .cuda() is unconditional here, unlike the guarded move of the model
+# below, so this line requires a CUDA device.
+smpl = SMPL(args.data_root, batch_size=1).cuda()
+J_regressor = smpl.J_regressor_h36m
+
+# `end` is reused as the start timestamp for the two timing printouts.
+end = time.time()
+model_backbone = load_backbone(args)
+print(f'init backbone time: {(time.time()-end):02f}s')
+end = time.time()
+model = MeshRegressor(args, backbone=model_backbone, dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout)
+print(f'init whole model time: {(time.time()-end):02f}s')
+
+if torch.cuda.is_available():
+ model = nn.DataParallel(model)
+ model = model.cuda()
+
+# NOTE(review): parse_args in this file defines no --resume option, so the fallback
+# to opts.resume would raise AttributeError if --evaluate were given as an empty string.
+chk_filename = opts.evaluate if opts.evaluate else opts.resume
+print('Loading checkpoint', chk_filename)
+# map_location keeps every tensor on the storage it is deserialised to (no device remapping).
+checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
+model.load_state_dict(checkpoint['model'], strict=True)
+model.eval()
+
+# Keyword arguments forwarded verbatim to the DataLoader below.
+testloader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': 8,
+ 'pin_memory': True,
+ 'prefetch_factor': 4,
+ 'persistent_workers': True,
+ 'drop_last': False
+}
+
+# The video is opened only for its metadata (fps and frame size).
+vid = imageio.get_reader(opts.vid_path, 'ffmpeg')
+fps_in = vid.get_meta_data()['fps']
+vid_size = vid.get_meta_data()['size']
+os.makedirs(opts.out_path, exist_ok=True)
+
+if opts.pixel:
+ # Keep relative scale with pixel coordinates
+ wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus)
+else:
+ # Scale to [-1,1]
+ wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus)
+
+test_loader = DataLoader(wild_dataset, **testloader_params)
+
+# Per-batch vertices and regressed 3D keypoints, merged after the loop.
+verts_all = []
+reg3d_all = []
+with torch.no_grad():
+ for batch_input in tqdm(test_loader):
+ batch_size, clip_frames = batch_input.shape[:2]
+ # NOTE(review): the .float() cast is applied only on the CUDA path.
+ if torch.cuda.is_available():
+ batch_input = batch_input.cuda().float()
+ output = model(batch_input)
+ # Second forward pass on the flipped input.
+ batch_input_flip = flip_data(batch_input)
+ output_flip = model(batch_input_flip)
+ # 'theta' is split into its first 72 entries (pose) and the rest (shape, reshaped to 10 below).
+ output_flip_pose = output_flip[0]['theta'][:, :, :72]
+ output_flip_shape = output_flip[0]['theta'][:, :, 72:]
+ output_flip_pose = flip_thetas_batch(output_flip_pose)
+ output_flip_pose = output_flip_pose.reshape(-1, 72)
+ output_flip_shape = output_flip_shape.reshape(-1, 10)
+ # Rebuild the mesh from the flipped-back parameters: first 3 pose entries are
+ # passed as global_orient, the remaining ones as body_pose.
+ output_flip_smpl = smpl(
+ betas=output_flip_shape,
+ body_pose=output_flip_pose[:, 3:],
+ global_orient=output_flip_pose[:, :3],
+ pose2rot=True
+ )
+ output_flip_verts = output_flip_smpl.vertices.detach()
+ # Regress keypoints from the vertices with the regressor expanded over the batch.
+ J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device)
+ output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts) # (NT,17,3)
+ # Only the vertices are multiplied by 1000 here — presumably a unit conversion
+ # matching output[0]['verts']; TODO confirm.
+ output_flip_back = [{
+ 'verts': output_flip_verts.reshape(batch_size, clip_frames, -1, 3) * 1000.0,
+ 'kp_3d': output_flip_kp3d.reshape(batch_size, clip_frames, -1, 3),
+ }]
+ # Average the direct and the flipped-back results for 'verts' and 'kp_3d'.
+ output_final = [{}]
+ for k, v in output_flip_back[0].items():
+ output_final[0][k] = (output[0][k] + output_flip_back[0][k]) / 2.0
+ output = output_final
+ verts_all.append(output[0]['verts'].cpu().numpy())
+ reg3d_all.append(output[0]['kp_3d'].cpu().numpy())
+
+# hstack joins the per-batch arrays along axis 1; concatenate then joins the
+# entries of the leading axis along axis 0, removing that leading axis.
+verts_all = np.hstack(verts_all)
+verts_all = np.concatenate(verts_all)
+reg3d_all = np.hstack(reg3d_all)
+reg3d_all = np.concatenate(reg3d_all)
+
+if opts.ref_3d_motion_path:
+ # Optional: fit a scale between the reference motion and the regressed keypoints
+ # (both taken relative to their first joint), then move the mesh onto the scaled
+ # reference trajectory of that joint.
+ ref_pose = np.load(opts.ref_3d_motion_path)
+ x = ref_pose - ref_pose[:, :1]
+ y = reg3d_all - reg3d_all[:, :1]
+ scale = solve_scale(x, y)
+ root_cam = ref_pose[:, :1] * scale
+ verts_all = verts_all - reg3d_all[:,:1] + root_cam
+
+render_and_save(verts_all, osp.join(opts.out_path, 'mesh.mp4'), keep_imgs=False, fps=fps_in, draw_face=True)
+
diff --git a/lib/data/augmentation.py b/lib/data/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0818d641ccc52459141738d86be86c6fcced9662
--- /dev/null
+++ b/lib/data/augmentation.py
@@ -0,0 +1,99 @@
+import numpy as np
+import os
+import random
+import torch
+import copy
+import torch.nn as nn
+from lib.utils.tools import read_pkl
+from lib.utils.utils_data import flip_data, crop_scale_3d
+
+class Augmenter2D(object):
+ """
+ Make 2D augmentations on the fly. PyTorch batch-processing GPU version.
+ """
+ def __init__(self, args):
+ self.d2c_params = read_pkl(args.d2c_params_path)
+ self.noise = torch.load(args.noise_path)
+ self.mask_ratio = args.mask_ratio
+ self.mask_T_ratio = args.mask_T_ratio
+ self.num_Kframes = 27
+ self.noise_std = 0.002
+
+ def dis2conf(self, dis, a, b, m, s):
+ f = a/(dis+a)+b*dis
+ shift = torch.randn(*dis.shape)*s + m
+ # if torch.cuda.is_available():
+ shift = shift.to(dis.device)
+ return f + shift
+
+ def add_noise(self, motion_2d):
+ a, b, m, s = self.d2c_params["a"], self.d2c_params["b"], self.d2c_params["m"], self.d2c_params["s"]
+ if "uniform_range" in self.noise.keys():
+ uniform_range = self.noise["uniform_range"]
+ else:
+ uniform_range = 0.06
+ motion_2d = motion_2d[:,:,:,:2]
+ batch_size = motion_2d.shape[0]
+ num_frames = motion_2d.shape[1]
+ num_joints = motion_2d.shape[2]
+ mean = self.noise['mean'].float()
+ std = self.noise['std'].float()
+ weight = self.noise['weight'][:,None].float()
+ sel = torch.rand((batch_size, self.num_Kframes, num_joints, 1))
+ gaussian_sample = (torch.randn(batch_size, self.num_Kframes, num_joints, 2) * std + mean)
+ uniform_sample = (torch.rand((batch_size, self.num_Kframes, num_joints, 2))-0.5) * uniform_range
+ noise_mean = 0
+ delta_noise = torch.randn(num_frames, num_joints, 2) * self.noise_std + noise_mean
+ # if torch.cuda.is_available():
+ mean = mean.to(motion_2d.device)
+ std = std.to(motion_2d.device)
+ weight = weight.to(motion_2d.device)
+ gaussian_sample = gaussian_sample.to(motion_2d.device)
+ uniform_sample = uniform_sample.to(motion_2d.device)
+ sel = sel.to(motion_2d.device)
+ delta_noise = delta_noise.to(motion_2d.device)
+
+ delta = gaussian_sample*(sel=weight)
+ delta_expand = torch.nn.functional.interpolate(delta.unsqueeze(1), [num_frames, num_joints, 2], mode='trilinear', align_corners=True)[:,0]
+ delta_final = delta_expand + delta_noise
+ motion_2d = motion_2d + delta_final
+ dx = delta_final[:,:,:,0]
+ dy = delta_final[:,:,:,1]
+ dis2 = dx*dx+dy*dy
+ dis = torch.sqrt(dis2)
+ conf = self.dis2conf(dis, a, b, m, s).clip(0,1).reshape([batch_size, num_frames, num_joints, -1])
+ return torch.cat((motion_2d, conf), dim=3)
+
+ def add_mask(self, x):
+ ''' motion_2d: (N,T,17,3)
+ '''
+ N,T,J,C = x.shape
+ mask = torch.rand(N,T,J,1, dtype=x.dtype, device=x.device) > self.mask_ratio
+ mask_T = torch.rand(1,T,1,1, dtype=x.dtype, device=x.device) > self.mask_T_ratio
+ x = x * mask * mask_T
+ return x
+
+ def augment2D(self, motion_2d, mask=False, noise=False):
+ if noise:
+ motion_2d = self.add_noise(motion_2d)
+ if mask:
+ motion_2d = self.add_mask(motion_2d)
+ return motion_2d
+
+class Augmenter3D(object):
+ """
+ Make 3D augmentations when dataloaders get items. NumPy single motion version.
+ """
+ def __init__(self, args):
+ self.flip = args.flip
+ if hasattr(args, "scale_range_pretrain"):
+ self.scale_range_pretrain = args.scale_range_pretrain
+ else:
+ self.scale_range_pretrain = None
+
+ def augment3D(self, motion_3d):
+ if self.scale_range_pretrain:
+ motion_3d = crop_scale_3d(motion_3d, self.scale_range_pretrain)
+ if self.flip and random.random()>0.5:
+ motion_3d = flip_data(motion_3d)
+ return motion_3d
\ No newline at end of file
diff --git a/lib/data/datareader_h36m.py b/lib/data/datareader_h36m.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f20b6845ea476921fbcdf5dc96fc0aee630951
--- /dev/null
+++ b/lib/data/datareader_h36m.py
@@ -0,0 +1,136 @@
+# Adapted from Optimizing Network Structure for 3D Human Pose Estimation (ICCV 2019) (https://github.com/CHUNYUWANG/lcn-pose/blob/master/tools/data.py)
+
+import numpy as np
+import os, sys
+import random
+import copy
+from lib.utils.tools import read_pkl
+from lib.utils.utils_data import split_clips
+random.seed(0)
+
+class DataReaderH36M(object):
+ def __init__(self, n_frames, sample_stride, data_stride_train, data_stride_test, read_confidence=True, dt_root = 'data/motion3d', dt_file = 'h36m_cpn_cam_source.pkl'):
+ self.gt_trainset = None
+ self.gt_testset = None
+ self.split_id_train = None
+ self.split_id_test = None
+ self.test_hw = None
+ self.dt_dataset = read_pkl('%s/%s' % (dt_root, dt_file))
+ self.n_frames = n_frames
+ self.sample_stride = sample_stride
+ self.data_stride_train = data_stride_train
+ self.data_stride_test = data_stride_test
+ self.read_confidence = read_confidence
+
+ def read_2d(self):
+ trainset = self.dt_dataset['train']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2]
+ testset = self.dt_dataset['test']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2]
+ # map to [-1, 1]
+ for idx, camera_name in enumerate(self.dt_dataset['train']['camera_name']):
+ if camera_name == '54138969' or camera_name == '60457274':
+ res_w, res_h = 1000, 1002
+ elif camera_name == '55011271' or camera_name == '58860488':
+ res_w, res_h = 1000, 1000
+ else:
+ assert 0, '%d data item has an invalid camera name' % idx
+ trainset[idx, :, :] = trainset[idx, :, :] / res_w * 2 - [1, res_h / res_w]
+ for idx, camera_name in enumerate(self.dt_dataset['test']['camera_name']):
+ if camera_name == '54138969' or camera_name == '60457274':
+ res_w, res_h = 1000, 1002
+ elif camera_name == '55011271' or camera_name == '58860488':
+ res_w, res_h = 1000, 1000
+ else:
+ assert 0, '%d data item has an invalid camera name' % idx
+ testset[idx, :, :] = testset[idx, :, :] / res_w * 2 - [1, res_h / res_w]
+ if self.read_confidence:
+ if 'confidence' in self.dt_dataset['train'].keys():
+ train_confidence = self.dt_dataset['train']['confidence'][::self.sample_stride].astype(np.float32)
+ test_confidence = self.dt_dataset['test']['confidence'][::self.sample_stride].astype(np.float32)
+ if len(train_confidence.shape)==2: # (1559752, 17)
+ train_confidence = train_confidence[:,:,None]
+ test_confidence = test_confidence[:,:,None]
+ else:
+ # No conf provided, fill with 1.
+ train_confidence = np.ones(trainset.shape)[:,:,0:1]
+ test_confidence = np.ones(testset.shape)[:,:,0:1]
+ trainset = np.concatenate((trainset, train_confidence), axis=2) # [N, 17, 3]
+ testset = np.concatenate((testset, test_confidence), axis=2) # [N, 17, 3]
+ return trainset, testset
+
+ def read_3d(self):
+ train_labels = self.dt_dataset['train']['joint3d_image'][::self.sample_stride, :, :3].astype(np.float32) # [N, 17, 3]
+ test_labels = self.dt_dataset['test']['joint3d_image'][::self.sample_stride, :, :3].astype(np.float32) # [N, 17, 3]
+ # map to [-1, 1]
+ for idx, camera_name in enumerate(self.dt_dataset['train']['camera_name']):
+ if camera_name == '54138969' or camera_name == '60457274':
+ res_w, res_h = 1000, 1002
+ elif camera_name == '55011271' or camera_name == '58860488':
+ res_w, res_h = 1000, 1000
+ else:
+ assert 0, '%d data item has an invalid camera name' % idx
+ train_labels[idx, :, :2] = train_labels[idx, :, :2] / res_w * 2 - [1, res_h / res_w]
+ train_labels[idx, :, 2:] = train_labels[idx, :, 2:] / res_w * 2
+
+ for idx, camera_name in enumerate(self.dt_dataset['test']['camera_name']):
+ if camera_name == '54138969' or camera_name == '60457274':
+ res_w, res_h = 1000, 1002
+ elif camera_name == '55011271' or camera_name == '58860488':
+ res_w, res_h = 1000, 1000
+ else:
+ assert 0, '%d data item has an invalid camera name' % idx
+ test_labels[idx, :, :2] = test_labels[idx, :, :2] / res_w * 2 - [1, res_h / res_w]
+ test_labels[idx, :, 2:] = test_labels[idx, :, 2:] / res_w * 2
+
+ return train_labels, test_labels
+ def read_hw(self):
+ if self.test_hw is not None:
+ return self.test_hw
+ test_hw = np.zeros((len(self.dt_dataset['test']['camera_name']), 2))
+ for idx, camera_name in enumerate(self.dt_dataset['test']['camera_name']):
+ if camera_name == '54138969' or camera_name == '60457274':
+ res_w, res_h = 1000, 1002
+ elif camera_name == '55011271' or camera_name == '58860488':
+ res_w, res_h = 1000, 1000
+ else:
+ assert 0, '%d data item has an invalid camera name' % idx
+ test_hw[idx] = res_w, res_h
+ self.test_hw = test_hw
+ return test_hw
+
+ def get_split_id(self):
+ if self.split_id_train is not None and self.split_id_test is not None:
+ return self.split_id_train, self.split_id_test
+ vid_list_train = self.dt_dataset['train']['source'][::self.sample_stride] # (1559752,)
+ vid_list_test = self.dt_dataset['test']['source'][::self.sample_stride] # (566920,)
+ self.split_id_train = split_clips(vid_list_train, self.n_frames, data_stride=self.data_stride_train)
+ self.split_id_test = split_clips(vid_list_test, self.n_frames, data_stride=self.data_stride_test)
+ return self.split_id_train, self.split_id_test
+
+ def get_hw(self):
+# Only Testset HW is needed for denormalization
+ test_hw = self.read_hw() # train_data (1559752, 2) test_data (566920, 2)
+ split_id_train, split_id_test = self.get_split_id()
+ test_hw = test_hw[split_id_test][:,0,:] # (N, 2)
+ return test_hw
+
+ def get_sliced_data(self):
+ train_data, test_data = self.read_2d() # train_data (1559752, 17, 3) test_data (566920, 17, 3)
+ train_labels, test_labels = self.read_3d() # train_labels (1559752, 17, 3) test_labels (566920, 17, 3)
+ split_id_train, split_id_test = self.get_split_id()
+ train_data, test_data = train_data[split_id_train], test_data[split_id_test] # (N, 27, 17, 3)
+ train_labels, test_labels = train_labels[split_id_train], test_labels[split_id_test] # (N, 27, 17, 3)
+ # ipdb.set_trace()
+ return train_data, test_data, train_labels, test_labels
+
+ def denormalize(self, test_data):
+# data: (N, n_frames, 51) or data: (N, n_frames, 17, 3)
+ n_clips = test_data.shape[0]
+ test_hw = self.get_hw()
+ data = test_data.reshape([n_clips, -1, 17, 3])
+ assert len(data) == len(test_hw)
+ # denormalize (x,y,z) coordiantes for results
+ for idx, item in enumerate(data):
+ res_w, res_h = test_hw[idx]
+ data[idx, :, :, :2] = (data[idx, :, :, :2] + np.array([1, res_h / res_w])) * res_w / 2
+ data[idx, :, :, 2:] = data[idx, :, :, 2:] * res_w / 2
+ return data # [n_clips, -1, 17, 3]
diff --git a/lib/data/datareader_mesh.py b/lib/data/datareader_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cb1e87910a50bf36ee69c1dfd316075b2260b9f
--- /dev/null
+++ b/lib/data/datareader_mesh.py
@@ -0,0 +1,59 @@
+import numpy as np
+import os, sys
+import copy
+from lib.utils.tools import read_pkl
+from lib.utils.utils_data import split_clips
+
+class DataReaderMesh(object):
+ def __init__(self, n_frames, sample_stride, data_stride_train, data_stride_test, read_confidence=True, dt_root = 'data/mesh', dt_file = 'pw3d_det.pkl', res=[1920, 1920]):
+ self.split_id_train = None
+ self.split_id_test = None
+ self.dt_dataset = read_pkl('%s/%s' % (dt_root, dt_file))
+ self.n_frames = n_frames
+ self.sample_stride = sample_stride
+ self.data_stride_train = data_stride_train
+ self.data_stride_test = data_stride_test
+ self.read_confidence = read_confidence
+ self.res = res
+
+ def read_2d(self):
+ if self.res is not None:
+ res_w, res_h = self.res
+ offset = [1, res_h / res_w]
+ else:
+ res = np.array(self.dt_dataset['train']['img_hw'])[::self.sample_stride].astype(np.float32)
+ res_w, res_h = res.max(1)[:, None, None], res.max(1)[:, None, None]
+ offset = 1
+ trainset = self.dt_dataset['train']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2]
+ testset = self.dt_dataset['test']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2]
+ # res_w, res_h = self.res
+ trainset = trainset / res_w * 2 - offset
+ testset = testset / res_w * 2 - offset
+ if self.read_confidence:
+ train_confidence = self.dt_dataset['train']['confidence'][::self.sample_stride].astype(np.float32)
+ test_confidence = self.dt_dataset['test']['confidence'][::self.sample_stride].astype(np.float32)
+ if len(train_confidence.shape)==2:
+ train_confidence = train_confidence[:,:,None]
+ test_confidence = test_confidence[:,:,None]
+ trainset = np.concatenate((trainset, train_confidence), axis=2) # [N, 17, 3]
+ testset = np.concatenate((testset, test_confidence), axis=2) # [N, 17, 3]
+ return trainset, testset
+
+ def get_split_id(self):
+ if self.split_id_train is not None and self.split_id_test is not None:
+ return self.split_id_train, self.split_id_test
+ vid_list_train = self.dt_dataset['train']['source'][::self.sample_stride]
+ vid_list_test = self.dt_dataset['test']['source'][::self.sample_stride]
+ self.split_id_train = split_clips(vid_list_train, self.n_frames, self.data_stride_train)
+ self.split_id_test = split_clips(vid_list_test, self.n_frames, self.data_stride_test)
+ return self.split_id_train, self.split_id_test
+
+ def get_sliced_data(self):
+ train_data, test_data = self.read_2d()
+ train_labels, test_labels = self.read_3d()
+ split_id_train, split_id_test = self.get_split_id()
+ train_data, test_data = train_data[split_id_train], test_data[split_id_test] # (N, 27, 17, 3)
+ train_labels, test_labels = train_labels[split_id_train], test_labels[split_id_test] # (N, 27, 17, 3)
+ return train_data, test_data, train_labels, test_labels
+
+
\ No newline at end of file
diff --git a/lib/data/dataset_action.py b/lib/data/dataset_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..87bc5de62698baeb9d785139566b20ff5d4b5280
--- /dev/null
+++ b/lib/data/dataset_action.py
@@ -0,0 +1,206 @@
+import torch
+import numpy as np
+import os
+import random
+import copy
+from torch.utils.data import Dataset, DataLoader
+from lib.utils.utils_data import crop_scale, resample
+from lib.utils.tools import read_pkl
+
def get_action_names(file_path = "data/action/ntu_actions.txt"):
    """Read action names from a text file with one ``"<id>. <name>"`` entry per line.

    For each non-blank line, the text after the first ``'.'`` is taken and its
    first character (the separating space) is dropped.

    Args:
        file_path: Path of the action list file.

    Returns:
        List of action names in file order.

    Raises:
        IndexError: If a non-blank line contains no ``'.'``.
    """
    # Context manager so the handle is closed even if reading fails.
    with open(file_path, "r") as f:
        lines = f.read().split('\n')
    # Blank lines (typically the one produced by a trailing newline) are
    # skipped; previously they raised IndexError.
    return [line.split('.')[1][1:] for line in lines if line.strip()]
+
def make_cam(x, img_shape):
    '''
    Rescale coordinates by the longer image side: ``x / side * 2 - 1``.

    Input: x (M x T x V x C)
           img_shape (height, width)
    '''
    height, width = img_shape
    # Width is used when the two sides are equal.
    longer_side = width if width >= height else height
    return x / longer_side * 2 - 1
+
def coco2h36m(x):
    '''
    Re-order COCO joints into the H36M 17-joint layout.

    Input: x (M x T x V x C)

    COCO: {0-nose 1-Leye 2-Reye 3-Lear 4Rear 5-Lsho 6-Rsho 7-Lelb 8-Relb 9-Lwri 10-Rwri 11-Lhip 12-Rhip 13-Lkne 14-Rkne 15-Lank 16-Rank}

    H36M:
    0: 'root',
    1: 'rhip',
    2: 'rkne',
    3: 'rank',
    4: 'lhip',
    5: 'lkne',
    6: 'lank',
    7: 'belly',
    8: 'neck',
    9: 'nose',
    10: 'head',
    11: 'lsho',
    12: 'lelb',
    13: 'lwri',
    14: 'rsho',
    15: 'relb',
    16: 'rwri'

    Returns a new float64 array with the same shape as x.
    '''
    y = np.zeros(x.shape)
    # Joints copied one-to-one: H36M slot <- COCO slot.
    h36m_slots = [1, 2, 3, 4, 5, 6, 9, 11, 12, 13, 14, 15, 16]
    coco_slots = [12, 14, 16, 11, 13, 15, 0, 5, 7, 9, 6, 8, 10]
    y[:, :, h36m_slots, :] = x[:, :, coco_slots, :]
    # Joints synthesised as midpoints.
    y[:, :, 0, :] = (x[:, :, 11, :] + x[:, :, 12, :]) * 0.5   # root  = mid-hip
    y[:, :, 8, :] = (x[:, :, 5, :] + x[:, :, 6, :]) * 0.5     # neck  = mid-shoulder
    y[:, :, 7, :] = (y[:, :, 0, :] + y[:, :, 8, :]) * 0.5     # belly = mid(root, neck)
    y[:, :, 10, :] = (x[:, :, 1, :] + x[:, :, 2, :]) * 0.5    # head  = mid-eye
    return y
+
def random_move(data_numpy,
                angle_range=[-10., 10.],
                scale_range=[0.9, 1.1],
                transform_range=[-0.1, 0.1],
                move_time_candidate=[1]):
    """Apply a random, time-varying rotation/scale/translation to the xy channels.

    Rotation angle (degrees), scale and x/y translation are sampled uniformly
    at a few key frames and linearly interpolated in between; every frame is
    then transformed with its own 2x2 matrix plus offset.  Channels beyond the
    first two are left untouched.

    Args:
        data_numpy: Array laid out as (M, T, V, C) with C >= 2.
        angle_range: [low, high] rotation range in degrees.
        scale_range: [low, high] scale range.
        transform_range: [low, high] translation range (used for both x and y).
        move_time_candidate: Candidate numbers of interpolation segments; one
            is picked with ``random.choice``.

    Returns:
        A new (M, T, V, C) array.  The input array is not modified.
    """
    # Work on a private copy: np.transpose returns a view, so writing the
    # transformed xy back would otherwise modify the caller's array in place
    # (and stack augmentations on a dataset's stored samples).
    data_numpy = np.transpose(np.array(data_numpy, copy=True), (3,1,2,0)) # M,T,V,C-> C,T,V,M
    C, T, V, M = data_numpy.shape
    move_time = random.choice(move_time_candidate)
    # Key-frame positions: evenly spaced starts plus the end of the clip.
    node = np.arange(0, T, T * 1.0 / move_time).round().astype(int)
    node = np.append(node, T)
    num_node = len(node)
    A = np.random.uniform(angle_range[0], angle_range[1], num_node)
    S = np.random.uniform(scale_range[0], scale_range[1], num_node)
    T_x = np.random.uniform(transform_range[0], transform_range[1], num_node)
    T_y = np.random.uniform(transform_range[0], transform_range[1], num_node)
    a = np.zeros(T)
    s = np.zeros(T)
    t_x = np.zeros(T)
    t_y = np.zeros(T)
    # Linear interpolation of every parameter between consecutive key frames.
    for i in range(num_node - 1):
        start, stop = node[i], node[i + 1]
        length = stop - start
        a[start:stop] = np.linspace(A[i], A[i + 1], length) * np.pi / 180  # degrees -> radians
        s[start:stop] = np.linspace(S[i], S[i + 1], length)
        t_x[start:stop] = np.linspace(T_x[i], T_x[i + 1], length)
        t_y[start:stop] = np.linspace(T_y[i], T_y[i + 1], length)
    # theta[:, :, t] is the scaled rotation matrix of frame t.
    theta = np.array([[np.cos(a) * s, -np.sin(a) * s],
                      [np.sin(a) * s, np.cos(a) * s]])
    # perform transformation
    for i_frame in range(T):
        xy = data_numpy[0:2, i_frame, :, :]
        new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1))
        new_xy[0] += t_x[i_frame]
        new_xy[1] += t_y[i_frame]
        data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M)
    data_numpy = np.transpose(data_numpy, (3,1,2,0)) # C,T,V,M -> M,T,V,C
    return data_numpy
+
def human_tracking(x):
    """Re-assign the first two persons frame by frame based on pose distance.

    For every frame t >= 1, person 0's pose is compared (summed per-joint L2
    distance) with person 0's and with person 1's pose at frame t-1.  The
    running count of frames where the former distance is larger, taken modulo
    2, decides whether persons 0 and 1 are exchanged at that frame.

    Input: x (M x T x V x C).  With a single person x is returned unchanged;
    otherwise a new float64 array of the same shape is returned in which only
    the first two person slots are filled.
    """
    num_person, _ = x.shape[:2]
    if num_person == 1:
        return x
    current = x[0, 1:]
    dist_to_self = np.linalg.norm(current - x[0, :-1], axis=-1).sum(axis=-1)    # (T-1, V, C) -> (T-1)
    dist_to_other = np.linalg.norm(current - x[1, :-1], axis=-1).sum(axis=-1)
    # 1 where the identities are currently exchanged, 0 otherwise.
    swapped = (np.cumsum(dist_to_self > dist_to_other) % 2)[:, None, None]
    kept = 1 - swapped
    tracked = np.zeros(x.shape)
    # The first frame is taken as is for both persons.
    tracked[0][0] = x[0][0]
    tracked[1][0] = x[1][0]
    tracked[0, 1:] = x[1, 1:] * swapped + x[0, 1:] * kept
    tracked[1, 1:] = x[0, 1:] * swapped + x[1, 1:] * kept
    return tracked
+
class ActionDataset(Dataset):
    """Base skeleton-action dataset that pre-processes every sample in memory.

    The pickle returned by ``read_pkl(data_path)`` must provide
    ``'annotations'`` (per-sample dicts with ``frame_dir``, ``total_frames``,
    ``keypoint``, ``keypoint_score``, ``img_shape`` and ``label``) and, when
    ``check_split`` is true, ``'split'`` mapping split names to the
    ``frame_dir`` values that belong to them.

    Args:
        data_path: Path handed to ``read_pkl``.
        data_split: Split name (train/test etc.).  Training mode
            (``self.is_train``) is enabled when the name contains "train" or
            ``check_split`` is False, and disabled again when it contains
            "oneshot".
        n_frames: Target clip length handed to ``resample``.
        random_move: Stored on ``self.random_move`` for subclasses.
        scale_range: Stored on ``self.scale_range`` for subclasses.
        check_split: Keep only the samples listed in the requested split.

    After construction ``self.motions`` holds the stacked float32 samples
    (keypoints with the score appended as last channel, always two persons)
    and ``self.labels`` the matching labels.

    Note: the constructor seeds NumPy's global RNG with 0.
    """
    def __init__(self, data_path, data_split, n_frames=243, random_move=True, scale_range=[1,1], check_split=True): # data_split: train/test etc.
        np.random.seed(0)
        dataset = read_pkl(data_path)
        split_members = None
        if check_split:
            assert data_split in dataset['split'].keys()
            self.split = dataset['split'][data_split]
            # Hash the split once so the per-sample membership test below is
            # O(1) instead of a linear scan; fall back to the raw container
            # if its entries are not hashable.
            try:
                split_members = frozenset(self.split)
            except TypeError:
                split_members = self.split
        annotations = dataset['annotations']
        self.random_move = random_move
        self.is_train = "train" in data_split or (check_split==False)
        if "oneshot" in data_split:
            self.is_train = False
        self.scale_range = scale_range
        motions = []
        labels = []
        for sample in annotations:
            if check_split and sample['frame_dir'] not in split_members:
                continue
            # Frame indices used to bring the clip to n_frames (random only in training mode).
            resample_id = resample(ori_len=sample['total_frames'], target_len=n_frames, randomness=self.is_train)
            motion_cam = make_cam(x=sample['keypoint'], img_shape=sample['img_shape'])
            motion_cam = human_tracking(motion_cam)
            motion_cam = coco2h36m(motion_cam)
            motion_conf = sample['keypoint_score'][..., None]
            motion = np.concatenate((motion_cam[:,resample_id], motion_conf[:,resample_id]), axis=-1)
            if motion.shape[0]==1:                                  # Single person, make a fake zero person
                fake = np.zeros(motion.shape)
                motion = np.concatenate((motion, fake), axis=0)
            motions.append(motion.astype(np.float32))
            labels.append(sample['label'])
        self.motions = np.array(motions)
        self.labels = np.array(labels)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.motions)

    def __getitem__(self, index):
        raise NotImplementedError
+
class NTURGBD(ActionDataset):
    """Action dataset whose items are ``(motion, label)`` pairs with optional augmentation."""
    def __init__(self, data_path, data_split, n_frames=243, random_move=True, scale_range=[1,1]):
        super().__init__(data_path, data_split, n_frames, random_move, scale_range)

    def __getitem__(self, idx):
        'Generates one sample of data'
        sample = self.motions[idx]                                  # (M,T,J,C)
        target = self.labels[idx]
        # Optional augmentation, then optional crop/scale normalisation.
        if self.random_move:
            sample = random_move(sample)
        if self.scale_range:
            sample = crop_scale(sample, scale_range=self.scale_range)
        return sample.astype(np.float32), target
+
class NTURGBD1Shot(ActionDataset):
    """Action dataset with the 20 one-shot evaluation classes removed.

    Samples whose label is one of the held-out classes are dropped and the
    remaining class ids (out of 0..119) are renumbered contiguously in
    ascending order.  ``self.labels`` is a Python list after construction.
    """
    def __init__(self, data_path, data_split, n_frames=243, random_move=True, scale_range=[1,1], check_split=False):
        super(NTURGBD1Shot, self).__init__(data_path, data_split, n_frames, random_move, scale_range, check_split)
        oneshot_classes = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114]
        held_out = set(oneshot_classes)
        # Sort explicitly: set iteration order is unspecified, and the
        # old -> new label mapping must be reproducible.
        new_classes = sorted(set(range(120)) - held_out)
        old2new = {cid: i for i, cid in enumerate(new_classes)}
        # Boolean mask of the samples to keep (explicit dtype also covers an empty dataset).
        keep = np.array([x not in held_out for x in self.labels], dtype=bool)
        self.motions = self.motions[keep]
        self.labels = [old2new[x] for x in self.labels[keep]]

    def __getitem__(self, idx):
        'Generates one sample of data'
        motion, label = self.motions[idx], self.labels[idx]         # (M,T,J,C)
        if self.random_move:
            motion = random_move(motion)
        if self.scale_range:
            result = crop_scale(motion, scale_range=self.scale_range)
        else:
            result = motion
        return result.astype(np.float32), label
\ No newline at end of file
diff --git a/lib/data/dataset_mesh.py b/lib/data/dataset_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..c496a3ac34648a39076379508d67625099f589b3
--- /dev/null
+++ b/lib/data/dataset_mesh.py
@@ -0,0 +1,97 @@
+import torch
+import numpy as np
+import glob
+import os
+import io
+import random
+import pickle
+from torch.utils.data import Dataset, DataLoader
+from lib.data.augmentation import Augmenter3D
+from lib.utils.tools import read_pkl
+from lib.utils.utils_data import flip_data, crop_scale
+from lib.utils.utils_mesh import flip_thetas
+from lib.utils.utils_smpl import SMPL
+from torch.utils.data import Dataset, DataLoader
+from lib.data.datareader_h36m import DataReaderH36M
+from lib.data.datareader_mesh import DataReaderMesh
+from lib.data.dataset_action import random_move
+
class SMPLDataset(Dataset):
    """Base mesh dataset: loads 2D clips plus SMPL pose/shape for one split.

    A data reader is chosen from ``dataset`` ("h36m", "coco" or "pw3d"); both
    splits are read and gathered into clips, then the one named by
    ``data_split`` is kept on ``self.motion_2d`` / ``self.motion_smpl_3d``.

    Note: the constructor seeds both ``random`` and NumPy's global RNG with 0.
    """
    def __init__(self, args, data_split, dataset): # data_split: train/test; dataset: h36m, coco, pw3d
        random.seed(0)
        np.random.seed(0)
        self.clip_len = args.clip_len
        self.data_split = data_split
        if dataset == "h36m":
            datareader = DataReaderH36M(
                n_frames=self.clip_len,
                sample_stride=args.sample_stride,
                data_stride_train=args.data_stride,
                data_stride_test=self.clip_len,
                dt_root=args.data_root,
                dt_file=args.dt_file_h36m,
            )
        elif dataset == "coco":
            # Single-frame clips.
            datareader = DataReaderMesh(
                n_frames=1,
                sample_stride=args.sample_stride,
                data_stride_train=1,
                data_stride_test=1,
                dt_root=args.data_root,
                dt_file=args.dt_file_coco,
                res=[640, 640],
            )
        elif dataset == "pw3d":
            datareader = DataReaderMesh(
                n_frames=self.clip_len,
                sample_stride=args.sample_stride,
                data_stride_train=args.data_stride,
                data_stride_test=self.clip_len,
                dt_root=args.data_root,
                dt_file=args.dt_file_pw3d,
                res=[1920, 1920],
            )
        else:
            raise Exception("Mesh dataset undefined.")

        idx_train, idx_test = datareader.get_split_id()             # Index of clips
        input_train, input_test = datareader.read_2d()
        input_train = input_train[idx_train]                        # Input: (N, T, 17, 3)
        input_test = input_test[idx_test]
        self.motion_2d = {'train': input_train, 'test': input_test}[data_split]

        raw = datareader.dt_dataset
        smpl_by_split = {
            'train': {
                'pose': raw['train']['smpl_pose'][idx_train],       # (N, T, 72)
                'shape': raw['train']['smpl_shape'][idx_train],     # (N, T, 10)
            },
            'test': {
                'pose': raw['test']['smpl_pose'][idx_test],         # (N, T, 72)
                'shape': raw['test']['smpl_shape'][idx_test],       # (N, T, 10)
            },
        }
        self.motion_smpl_3d = smpl_by_split[data_split]
        self.smpl = SMPL(
            args.data_root,
            batch_size=1,
        )

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.motion_2d)

    def __getitem__(self, index):
        raise NotImplementedError
+
class MotionSMPL(SMPLDataset):
    """Mesh dataset item: 2D clip plus SMPL parameters, regressed 3D joints and vertices."""
    def __init__(self, args, data_split, dataset):
        super().__init__(args, data_split, dataset)
        self.flip = args.flip

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        kp_2d = self.motion_2d[index]                               # motion_2d: (T,17,3)
        kp_2d[:, :, 2] = np.clip(kp_2d[:, :, 2], 0, 1)              # third channel clamped to [0, 1]
        pose = self.motion_smpl_3d['pose'][index].reshape(-1, 24, 3)    # (T, 24, 3)
        betas = self.motion_smpl_3d['shape'][index]                 # (T,10)

        # Training augmentation - random flipping
        if self.data_split == "train" and self.flip and random.random() > 0.5:
            kp_2d = flip_data(kp_2d)
            pose = flip_thetas(pose)

        pose = torch.from_numpy(pose).reshape(-1, 72).float()
        betas = torch.from_numpy(betas).reshape(-1, 10).float()
        # First 3 pose values are the global orientation, the rest the body pose.
        smpl_out = self.smpl(
            betas=betas,
            body_pose=pose[:, 3:],
            global_orient=pose[:, :3],
            pose2rot=True
        )
        verts = smpl_out.vertices.detach()*1000.0
        regressor = self.smpl.J_regressor_h36m
        regressor = regressor[None, :].expand(verts.shape[0], -1, -1).to(verts.device)
        joints = torch.matmul(regressor, verts)                     # motion_3d: (T,17,3)
        # Express vertices and joints relative to the first regressed joint.
        root = joints[:, :1, :]
        verts = verts - root
        joints = joints - root                                      # motion_3d: (T,17,3)
        theta = torch.cat((pose, betas), -1)
        motion_smpl_3d = {
            'theta': theta,                                         # smpl pose and shape
            'kp_3d': joints,                                        # 3D keypoints
            'verts': verts,                                         # 3D mesh vertices
        }
        return kp_2d, motion_smpl_3d
\ No newline at end of file
diff --git a/lib/data/dataset_motion_2d.py b/lib/data/dataset_motion_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..b136f33507de96323d517452def1fe1686743700
--- /dev/null
+++ b/lib/data/dataset_motion_2d.py
@@ -0,0 +1,148 @@
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+import os
+import random
+import copy
+import json
+from collections import defaultdict
+from lib.utils.utils_data import crop_scale, flip_data, resample, split_clips
+
def posetrack2h36m(x):
    '''
    Re-order PoseTrack joints into the H36M 17-joint layout.

    Input: x (T x V x C)

    PoseTrack keypoints = [ 'nose',
                            'head_bottom',
                            'head_top',
                            'left_ear',
                            'right_ear',
                            'left_shoulder',
                            'right_shoulder',
                            'left_elbow',
                            'right_elbow',
                            'left_wrist',
                            'right_wrist',
                            'left_hip',
                            'right_hip',
                            'left_knee',
                            'right_knee',
                            'left_ankle',
                            'right_ankle']
    H36M:
    0: 'root',
    1: 'rhip',
    2: 'rkne',
    3: 'rank',
    4: 'lhip',
    5: 'lkne',
    6: 'lank',
    7: 'belly',
    8: 'neck',
    9: 'nose',
    10: 'head',
    11: 'lsho',
    12: 'lelb',
    13: 'lwri',
    14: 'rsho',
    15: 'relb',
    16: 'rwri'

    Returns a new float64 array with the same shape as x.  For the two
    synthesised joints (root, belly) channel 2 is the minimum of the channel-2
    values of the joints they are built from, not their mean.
    '''
    y = np.zeros(x.shape)
    # Joints copied one-to-one: H36M slot <- PoseTrack slot.
    h36m_slots = [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    posetrack_slots = [12, 14, 16, 11, 13, 15, 1, 0, 2, 5, 7, 9, 6, 8, 10]
    y[:, h36m_slots, :] = x[:, posetrack_slots, :]
    # Synthesised joints (midpoints).
    y[:, 0, :] = (x[:, 11, :] + x[:, 12, :]) * 0.5      # root  = mid-hip
    y[:, 7, :] = (y[:, 0, :] + y[:, 8, :]) * 0.5        # belly = mid(root, neck)
    # Channel 2 of the synthesised joints: the weaker of the two sources.
    y[:, 0, 2] = np.minimum(x[:, 11, 2], x[:, 12, 2])
    y[:, 7, 2] = np.minimum(y[:, 0, 2], y[:, 8, 2])
    return y
+
+
class PoseTrackDataset2D(Dataset):
    """2D motion clips built from the PoseTrack18 training annotations.

    Every annotation file is read, poses are grouped per ``track_id`` and each
    track contributes at most one 30-frame clip, which is cropped/scaled,
    converted to the H36M layout and filtered.  Items are ``(clip, clip)``.
    """
    def __init__(self, flip=True, scale_range=[0.25, 1]):
        super().__init__()
        self.flip = flip
        data_root = "data/motion2d/posetrack18_annotations/train/"
        file_list = sorted(os.listdir(data_root))
        all_motions = []
        all_motions_filtered = []
        self.scale_range = scale_range
        for filename in file_list:
            with open(os.path.join(data_root, filename), 'r') as file:
                json_dict = json.load(file)
            annots = json_dict['annotations']
            imgs = json_dict['images']      # not used below; lookup kept as in the original
            tracks = defaultdict(list)
            for annot in annots:
                track_id = annot['track_id']
                tracks[track_id].append(np.array(annot['keypoints']).reshape(-1, 3))
            all_motions.extend(tracks.values())
        for track in all_motions:
            if len(track) < 30:
                continue
            clip = np.array(track[:30])
            if np.sum(clip[:, :, 2]) <= 306:                        # Valid joint num threshold
                continue
            clip = crop_scale(clip, self.scale_range)
            clip = posetrack2h36m(clip)
            # Zero the whole joint wherever its third channel is 0.
            clip[clip[:, :, 2] == 0] = 0
            if np.sum(clip[:, 0, 2]) < 30:
                continue                                            # Root all visible (needed for framewise rootrel)
            all_motions_filtered.append(clip)
        self.motions_2d = np.array(all_motions_filtered)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.motions_2d)

    def __getitem__(self, index):
        'Generates one sample of data'
        clip = torch.FloatTensor(self.motions_2d[index])
        if self.flip and random.random() > 0.5:
            clip = flip_data(clip)
        return clip, clip
+
class InstaVDataset2D(Dataset):
    """2D motion clips cut from the pre-extracted InstaVariety arrays.

    Clips are built with ``split_clips`` and kept only when the third channel
    of joint 0 in the first frame exceeds ``valid_threshold``.  Items are
    ``(clip, clip)`` after crop/scale and optional random flip.
    """
    def __init__(self, n_frames=81, data_stride=27, flip=True, valid_threshold=0.0, scale_range=[0.25, 1]):
        super().__init__()
        self.flip = flip
        self.scale_range = scale_range
        motion_all = np.load('data/motion2d/InstaVariety/motion_all.npy')
        id_all = np.load('data/motion2d/InstaVariety/id_all.npy')
        clip_index = split_clips(id_all, n_frames, data_stride)
        clips = motion_all[clip_index]                              # [N, T, 17, 3]
        keep = clips[:, 0, 0, 2] > valid_threshold
        self.motions_2d = clips[keep]

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.motions_2d)

    def __getitem__(self, index):
        'Generates one sample of data'
        clip = crop_scale(self.motions_2d[index], self.scale_range)
        # Zero the whole joint wherever its third channel is 0.
        clip[clip[:, :, 2] == 0] = 0
        if self.flip and random.random() > 0.5:
            clip = flip_data(clip)
        clip = torch.FloatTensor(clip)
        return clip, clip
+
\ No newline at end of file
diff --git a/lib/data/dataset_motion_3d.py b/lib/data/dataset_motion_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2de10dc797f2c8d024ec2030af8cdd061fb18f4
--- /dev/null
+++ b/lib/data/dataset_motion_3d.py
@@ -0,0 +1,68 @@
+import torch
+import numpy as np
+import glob
+import os
+import io
+import random
+import pickle
+from torch.utils.data import Dataset, DataLoader
+from lib.data.augmentation import Augmenter3D
+from lib.utils.tools import read_pkl
+from lib.utils.utils_data import flip_data
+
class MotionDataset(Dataset):
    """Base dataset that indexes the files under ``<data_root>/<subset>/<data_split>``.

    ``self.file_list`` holds the full paths, subset by subset in the order of
    ``subset_list`` and name-sorted within each subset.

    Note: the constructor seeds NumPy's global RNG with 0.
    """
    def __init__(self, args, subset_list, data_split): # data_split: train/test
        np.random.seed(0)
        self.data_root = args.data_root
        self.subset_list = subset_list
        self.data_split = data_split
        collected = []
        for subset in self.subset_list:
            split_dir = os.path.join(self.data_root, subset, self.data_split)
            collected.extend(os.path.join(split_dir, name) for name in sorted(os.listdir(split_dir)))
        self.file_list = collected

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.file_list)

    def __getitem__(self, index):
        raise NotImplementedError
+
class MotionDataset3D(MotionDataset):
    """Dataset of (2D input, 3D label) motion pairs read from per-clip pickle files."""
    def __init__(self, args, subset_list, data_split):
        super().__init__(args, subset_list, data_split)
        self.flip = args.flip
        self.synthetic = args.synthetic
        self.aug = Augmenter3D(args)
        self.gt_2d = args.gt_2d

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        sample = read_pkl(self.file_list[index])
        target_3d = sample["data_label"]
        split = self.data_split
        if split == "train":
            if self.synthetic or self.gt_2d:
                # No 2D detection, use GT xy and c=1.
                target_3d = self.aug.augment3D(target_3d)
                input_2d = np.zeros(target_3d.shape, dtype=np.float32)
                input_2d[:, :, :2] = target_3d[:, :, :2]
                input_2d[:, :, 2] = 1
            elif sample["data_input"] is not None:                  # Have 2D detection
                input_2d = sample["data_input"]
                # Training augmentation - random flipping (only applied on this branch)
                if self.flip and random.random() > 0.5:
                    input_2d = flip_data(input_2d)
                    target_3d = flip_data(target_3d)
            else:
                raise ValueError('Training illegal.')
        elif split == "test":
            input_2d = sample["data_input"]
            if self.gt_2d:
                # Overwrite the stored input with GT xy and c=1.
                input_2d[:, :, :2] = target_3d[:, :, :2]
                input_2d[:, :, 2] = 1
        else:
            raise ValueError('Data split unknown.')
        return torch.FloatTensor(input_2d), torch.FloatTensor(target_3d)
\ No newline at end of file
diff --git a/lib/data/dataset_wild.py b/lib/data/dataset_wild.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a462c4759af73f247c7964fc3c6f53579d64a02
--- /dev/null
+++ b/lib/data/dataset_wild.py
@@ -0,0 +1,102 @@
+import torch
+import numpy as np
+import ipdb
+import glob
+import os
+import io
+import math
+import random
+import json
+import pickle
+import math
+from torch.utils.data import Dataset, DataLoader
+from lib.utils.utils_data import crop_scale
+
def halpe2h36m(x):
    '''
    Re-order Halpe-26 body keypoints into the H36M 17-joint layout.

    Input: x (T x V x C)
    //Halpe 26 body keypoints
    {0, "Nose"},
    {1, "LEye"},
    {2, "REye"},
    {3, "LEar"},
    {4, "REar"},
    {5, "LShoulder"},
    {6, "RShoulder"},
    {7, "LElbow"},
    {8, "RElbow"},
    {9, "LWrist"},
    {10, "RWrist"},
    {11, "LHip"},
    {12, "RHip"},
    {13, "LKnee"},
    {14, "Rknee"},
    {15, "LAnkle"},
    {16, "RAnkle"},
    {17, "Head"},
    {18, "Neck"},
    {19, "Hip"},
    {20, "LBigToe"},
    {21, "RBigToe"},
    {22, "LSmallToe"},
    {23, "RSmallToe"},
    {24, "LHeel"},
    {25, "RHeel"},

    Returns a new float64 array of shape (T, 17, C).
    '''
    T, V, C = x.shape
    y = np.zeros([T,17,C])
    # Joints copied one-to-one: H36M slot <- Halpe slot.
    h36m_slots = [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    halpe_slots = [19, 12, 14, 16, 11, 13, 15, 18, 0, 17, 5, 7, 9, 6, 8, 10]
    y[:, h36m_slots, :] = x[:, halpe_slots, :]
    # H36M slot 7 is the midpoint of Halpe "Neck" and "Hip".
    y[:, 7, :] = (x[:, 18, :] + x[:, 19, :]) * 0.5
    return y
+
def read_input(json_path, vid_size, scale_range, focus):
    """Load detections from a JSON file and return them as H36M-layout keypoints.

    Args:
        json_path: JSON file holding a list of detection items, each with
            ``'keypoints'`` (flat list, reshaped to rows of 3 values) and
            ``'idx'``.
        vid_size: Optional ``(w, h)``.  When given, xy are centred on the
            frame centre and divided by half of the shorter side.
        scale_range: Optional range; when truthy it is forwarded to
            ``crop_scale``.
        focus: Optional ``idx`` value; when not None only items with that
            ``idx`` are kept.

    Returns:
        float32 array of keypoints, one entry per kept item (17 joints each).

    Raises:
        ValueError: If no item is selected (empty file or no item matches
            ``focus``).
    """
    with open(json_path, "r") as read_file:
        results = json.load(read_file)
    # 'idx' is only looked up when a focus is requested.
    kpts_all = [
        np.array(item['keypoints']).reshape([-1,3])
        for item in results
        if focus is None or item['idx'] == focus
    ]
    if not kpts_all:
        # Previously this surfaced as an opaque tuple-unpacking ValueError
        # inside halpe2h36m; fail with an actionable message instead.
        raise ValueError("No keypoints found in %s (focus=%r)." % (json_path, focus))
    kpts_all = np.array(kpts_all)
    kpts_all = halpe2h36m(kpts_all)
    if vid_size:
        w, h = vid_size
        scale = min(w,h) / 2.0
        kpts_all[:,:,:2] = kpts_all[:,:,:2] - np.array([w, h]) / 2.0
        kpts_all[:,:,:2] = kpts_all[:,:,:2] / scale
    motion = kpts_all
    if scale_range:
        motion = crop_scale(kpts_all, scale_range)
    return motion.astype(np.float32)
+
class WildDetDataset(Dataset):
    """Serves one detection sequence as consecutive clips of at most ``clip_len`` frames.

    The sequence is loaded once with ``read_input``; the last clip may be
    shorter than ``clip_len``.
    """
    def __init__(self, json_path, clip_len=243, vid_size=None, scale_range=None, focus=None):
        self.json_path = json_path
        self.clip_len = clip_len
        self.vid_all = read_input(json_path, vid_size, scale_range, focus)

    def __len__(self):
        'Denotes the total number of samples'
        total_frames = len(self.vid_all)
        return math.ceil(total_frames / self.clip_len)

    def __getitem__(self, index):
        'Generates one sample of data'
        start = index * self.clip_len
        # Clamp the end so the final clip stops at the last frame.
        stop = min((index + 1) * self.clip_len, len(self.vid_all))
        return self.vid_all[start:stop]
\ No newline at end of file
diff --git a/lib/model/DSTformer.py b/lib/model/DSTformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af23881e5cc71a41b07722acd462de804bca6c8
--- /dev/null
+++ b/lib/model/DSTformer.py
@@ -0,0 +1,362 @@
+import torch
+import torch.nn as nn
+import math
+import warnings
+import random
+import numpy as np
+from collections import OrderedDict
+from functools import partial
+from itertools import repeat
+from lib.model.drop import DropPath
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from N(mean, std^2) truncated to [a, b].

    Sampling uses the inverse-CDF method and runs under ``torch.no_grad()``.
    The same tensor object is returned.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the two cut-offs.
        cdf_low = norm_cdf((a - mean) / std)
        cdf_high = norm_cdf((b - mean) / std)

        # Uniform samples in [2*cdf_low - 1, 2*cdf_high - 1] pushed through
        # erfinv give a truncated standard normal (up to a sqrt(2) factor).
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1).erfinv_()

        # Scale/shift to the requested std and mean.
        tensor.mul_(std * math.sqrt(2.)).add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
    return tensor
+
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`.  The sampling method works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The tensor that was passed in.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
+
+
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    they are falsy.  The same dropout module is applied after both layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+
+
class Attention(nn.Module):
    """Multi-head self-attention with several spatial/temporal variants.

    ``st_mode`` selects what ``forward`` does with the (B*T, N, C) input:

    * 'vanilla' / 'spatial': attention over the N tokens of every frame.
    * 'temporal': attention over the ``seqlen`` frames for every token.
    * 'series': spatial attention followed by temporal attention (the qkv
      projection is applied again in between).
    * 'parallel': spatial and temporal branches mixed with learned,
      softmax-normalised per-channel weights (``ts_attn``).
    * 'coupling': joint attention over all ``seqlen * N`` tokens.

    Any other mode raises ``NotImplementedError`` in ``forward``.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., st_mode='vanilla'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # Sub-module creation order is kept as in the original (it determines
        # parameter order and RNG consumption during initialisation).
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.mode = st_mode
        if self.mode == 'parallel':
            self.ts_attn = nn.Linear(dim*2, dim*2)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_count_s = None
        self.attn_count_t = None

    def _project_qkv(self, x):
        # Project x (B, N, C) and split into per-head q, k, v of shape (B, H, N, C // H).
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        return qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

    def forward(self, x, seqlen=1):
        B, N, C = x.shape
        mode = self.mode

        if mode == 'series':
            q, k, v = self._project_qkv(x)
            x = self.forward_spatial(q, k, v)
            q, k, v = self._project_qkv(x)
            x = self.forward_temporal(q, k, v, seqlen=seqlen)
        elif mode == 'parallel':
            q, k, v = self._project_qkv(x)
            x_t = self.forward_temporal(q, k, v, seqlen=seqlen)
            x_s = self.forward_spatial(q, k, v)

            # Per-sample, per-channel mixing weights from the token-averaged branches.
            alpha = torch.cat([x_s, x_t], dim=-1).mean(dim=1, keepdim=True)
            alpha = self.ts_attn(alpha).reshape(B, 1, C, 2).softmax(dim=-1)
            x = x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0]
        elif mode == 'coupling':
            q, k, v = self._project_qkv(x)
            x = self.forward_coupling(q, k, v, seqlen=seqlen)
        elif mode == 'vanilla' or mode == 'spatial':
            q, k, v = self._project_qkv(x)
            x = self.forward_spatial(q, k, v)
        elif mode == 'temporal':
            q, k, v = self._project_qkv(x)
            x = self.forward_temporal(q, k, v, seqlen=seqlen)
        else:
            raise NotImplementedError(self.mode)
        return self.proj_drop(self.proj(x))

    def reshape_T(self, x, seqlen=1, inverse=False):
        if inverse:
            # (B, H, T*N, C) -> (B*T, H, N, C)
            TN, C = x.shape[-2:]
            x = x.reshape(-1, self.num_heads, seqlen, TN // seqlen, C).transpose(1,2)
            return x.reshape(-1, self.num_heads, TN // seqlen, C) #(BT, H, N, C)
        # (B*T, H, N, C) -> (B, H, T*N, C)
        N, C = x.shape[-2:]
        x = x.reshape(-1, seqlen, self.num_heads, N, C).transpose(1,2)
        return x.reshape(-1, self.num_heads, seqlen*N, C) #(B, H, TN, c)

    def forward_coupling(self, q, k, v, seqlen=8):
        BT, _, N, C = q.shape
        q = self.reshape_T(q, seqlen)
        k = self.reshape_T(k, seqlen)
        v = self.reshape_T(v, seqlen)

        weights = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = self.reshape_T(weights @ v, seqlen, inverse=True)
        return out.transpose(1,2).reshape(BT, N, C*self.num_heads)

    def forward_spatial(self, q, k, v):
        B, _, N, C = q.shape
        weights = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = weights @ v
        return out.transpose(1,2).reshape(B, N, C*self.num_heads)

    def forward_temporal(self, q, k, v, seqlen=8):
        B, _, N, C = q.shape

        def to_time_major(t):
            # (B*T, H, N, C) -> (B, H, N, T, C)
            return t.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4)

        qt, kt, vt = to_time_major(q), to_time_major(k), to_time_major(v)

        weights = ((qt @ kt.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = weights @ vt #(B, H, N, T, C)
        return out.permute(0, 3, 2, 1, 4).reshape(B, N, C*self.num_heads)

    def count_attn(self, attn):
        # Accumulate the per-sample mean of the two mixing weights
        # (index 0 -> attn_count_s, index 1 -> attn_count_t) across calls.
        attn = attn.detach().cpu().numpy().mean(axis=1)
        attn_t = attn[:, :, 1].mean(axis=1)
        attn_s = attn[:, :, 0].mean(axis=1)
        if self.attn_count_s is None:
            self.attn_count_s = attn_s
            self.attn_count_t = attn_t
        else:
            self.attn_count_s = np.concatenate([self.attn_count_s, attn_s], axis=0)
            self.attn_count_t = np.concatenate([self.attn_count_t, attn_t], axis=0)
+
class Block(nn.Module):
    """Transformer block made of a spatial and a temporal attention+MLP stage.

    ``st_mode`` chooses the arrangement: 'stage_st' (spatial then temporal),
    'stage_ts' (temporal then spatial) or 'stage_para' (both on the same
    input, merged by averaging or - with ``att_fuse`` - by learned
    softmax weights).  Any other value raises ``NotImplementedError`` in
    ``forward``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., mlp_out_ratio=1., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, st_mode='stage_st', att_fuse=False):
        super().__init__()
        # assert 'stage' in st_mode
        self.st_mode = st_mode
        attn_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                           attn_drop=attn_drop, proj_drop=drop)
        # Sub-module creation order is kept as in the original.
        self.norm1_s = norm_layer(dim)
        self.norm1_t = norm_layer(dim)
        self.attn_s = Attention(dim, st_mode="spatial", **attn_kwargs)
        self.attn_t = Attention(dim, st_mode="temporal", **attn_kwargs)

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2_s = norm_layer(dim)
        self.norm2_t = norm_layer(dim)
        mlp_kwargs = dict(in_features=dim,
                          hidden_features=int(dim * mlp_ratio),
                          out_features=int(dim * mlp_out_ratio),
                          act_layer=act_layer, drop=drop)
        self.mlp_s = MLP(**mlp_kwargs)
        self.mlp_t = MLP(**mlp_kwargs)
        self.att_fuse = att_fuse
        if self.att_fuse:
            self.ts_attn = nn.Linear(dim*2, dim*2)

    def _spatial_stage(self, x, seqlen):
        # Pre-norm residual spatial attention followed by its MLP.
        x = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen))
        return x + self.drop_path(self.mlp_s(self.norm2_s(x)))

    def _temporal_stage(self, x, seqlen):
        # Pre-norm residual temporal attention followed by its MLP.
        x = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen))
        return x + self.drop_path(self.mlp_t(self.norm2_t(x)))

    def forward(self, x, seqlen=1):
        mode = self.st_mode
        if mode == 'stage_st':
            return self._temporal_stage(self._spatial_stage(x, seqlen), seqlen)
        if mode == 'stage_ts':
            return self._spatial_stage(self._temporal_stage(x, seqlen), seqlen)
        if mode != 'stage_para':
            raise NotImplementedError(self.st_mode)
        # Parallel: both stages see the same input (temporal evaluated first).
        x_t = self._temporal_stage(x, seqlen)
        x_s = self._spatial_stage(x, seqlen)
        if not self.att_fuse:
            return (x_s + x_t)*0.5
        # x_s, x_t: [BF, J, dim]
        alpha = torch.cat([x_s, x_t], dim=-1)
        BF, J = alpha.shape[:2]
        # alpha = alpha.mean(dim=1, keepdim=True)
        alpha = self.ts_attn(alpha).reshape(BF, J, -1, 2).softmax(dim=-1)
        return x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0]
+
class DSTformer(nn.Module):
    """Dual-stream spatio-temporal transformer over skeleton sequences.

    Every layer runs a spatial->temporal block and a temporal->spatial block
    on the same input and merges them, either with learned per-token softmax
    weights (``att_fuse=True``) or by averaging.

    Args:
        dim_in: Channels per joint in the input.
        dim_out: Output channels of the head; ``<= 0`` makes the head an identity.
        dim_feat: Token width inside the blocks.
        dim_rep: Width of the representation produced by ``pre_logits``; when
            falsy, ``pre_logits`` is an identity and the head consumes
            ``dim_feat`` features.
        depth: Number of dual-stream layers.
        num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
        norm_layer: Forwarded to every ``Block``.
        drop_path_rate: Maximum stochastic-depth rate; rises linearly over depth.
        num_joints: Number of joints (size of the spatial position embedding).
        maxlen: Maximum number of frames (size of the temporal embedding).
        att_fuse: Enable the learned fusion of the two streams.
    """
    def __init__(self, dim_in=3, dim_out=3, dim_feat=256, dim_rep=512,
                 depth=5, num_heads=8, mlp_ratio=4,
                 num_joints=17, maxlen=243,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, att_fuse=True):
        super().__init__()
        self.dim_out = dim_out
        self.dim_feat = dim_feat
        self.joints_embed = nn.Linear(dim_in, dim_feat)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks_st = nn.ModuleList([
            Block(
                dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                st_mode="stage_st")
            for i in range(depth)])
        self.blocks_ts = nn.ModuleList([
            Block(
                dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                st_mode="stage_ts")
            for i in range(depth)])
        self.norm = norm_layer(dim_feat)
        if dim_rep:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(dim_feat, dim_rep)),
                ('act', nn.Tanh())
            ]))
            head_in_features = dim_rep
        else:
            self.pre_logits = nn.Identity()
            # Without a representation layer the head sees the block features
            # directly (previously the head was still built with dim_rep inputs).
            head_in_features = dim_feat
        self.head = nn.Linear(head_in_features, dim_out) if dim_out > 0 else nn.Identity()
        self.temp_embed = nn.Parameter(torch.zeros(1, maxlen, 1, dim_feat))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_joints, dim_feat))
        trunc_normal_(self.temp_embed, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)
        self.att_fuse = att_fuse
        if self.att_fuse:
            # Created after self.apply(...) so the custom init below is kept:
            # zero weights and equal biases give a 50/50 mix at the start.
            self.ts_attn = nn.ModuleList([nn.Linear(dim_feat*2, 2) for i in range(depth)])
            for i in range(depth):
                self.ts_attn[i].weight.data.fill_(0)
                self.ts_attn[i].bias.data.fill_(0.5)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, dim_out, global_pool=''):
        """Replace the head with a fresh ``Linear`` producing ``dim_out`` channels.

        The head consumes the output of ``pre_logits``, so its input width is
        taken from the current ``pre_logits.fc`` when that layer exists and
        from ``dim_feat`` otherwise (previously always ``dim_feat``, which did
        not match the representation width when ``dim_rep != dim_feat``).
        ``global_pool`` is accepted for interface compatibility and unused.
        """
        self.dim_out = dim_out
        rep_fc = getattr(self.pre_logits, 'fc', None)
        in_features = rep_fc.out_features if isinstance(rep_fc, nn.Linear) else self.dim_feat
        self.head = nn.Linear(in_features, dim_out) if dim_out > 0 else nn.Identity()

    def forward(self, x, return_rep=False):
        """Map x of shape (B, F, J, C) to (B, F, J, dim_out), or to the
        ``pre_logits`` representation when ``return_rep`` is true."""
        B, F, J, C = x.shape
        x = x.reshape(-1, J, C)
        BF = x.shape[0]
        x = self.joints_embed(x)
        x = x + self.pos_embed
        _, J, C = x.shape
        # Add the temporal embedding of the first F frames.
        x = x.reshape(-1, F, J, C) + self.temp_embed[:,:F,:,:]
        x = x.reshape(BF, J, C)
        x = self.pos_drop(x)
        for idx, (blk_st, blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)):
            x_st = blk_st(x, F)
            x_ts = blk_ts(x, F)
            if self.att_fuse:
                # Per-token softmax weights over the two streams.
                alpha = self.ts_attn[idx](torch.cat([x_st, x_ts], dim=-1))
                alpha = alpha.softmax(dim=-1)
                x = x_st * alpha[:,:,0:1] + x_ts * alpha[:,:,1:2]
            else:
                x = (x_st + x_ts)*0.5
        x = self.norm(x)
        x = x.reshape(B, F, J, -1)
        x = self.pre_logits(x)         # [B, F, J, dim_rep] ([B, F, J, dim_feat] without a representation layer)
        if return_rep:
            return x
        x = self.head(x)
        return x

    def get_representation(self, x):
        return self.forward(x, return_rep=True)
+
\ No newline at end of file
diff --git a/lib/model/drop.py b/lib/model/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..efbed356f80c4e94d0ecb7f1f9a64c3c3e232887
--- /dev/null
+++ b/lib/model/drop.py
@@ -0,0 +1,43 @@
+""" DropBlock, DropPath
+PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
+Papers:
+DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+Code:
+DropBlock impl inspired by two Tensorflow impl that I liked:
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Returns ``x`` unchanged when not training or when ``drop_prob`` is 0.
+    Otherwise one Bernoulli(keep_prob) draw is made per entry of the first
+    dimension and broadcast over the remaining ones; kept samples are scaled
+    by ``1 / keep_prob`` and dropped samples become zero.
+    """
+    if not training or drop_prob == 0.:
+        return x
+    keep_prob = 1 - drop_prob
+    # One random value per sample, broadcastable to any tensor rank.
+    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+    keep_mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
+    keep_mask = (keep_mask + keep_prob).floor_()  # binarize to {0, 1}
+    return x.div(keep_prob) * keep_mask
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Args:
+        drop_prob: probability of dropping a sample's path. ``None`` (the
+            default) and ``0`` both disable dropping.
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        # Short-circuit here: forwarding drop_prob=None to drop_path() would
+        # evaluate `1 - None` and raise TypeError in training mode.
+        if not self.drop_prob or not self.training:
+            return x
+        return drop_path(x, self.drop_prob, self.training)
\ No newline at end of file
diff --git a/lib/model/loss.py b/lib/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4397ce1665f05debad6f46dc89be01d5481bf37c
--- /dev/null
+++ b/lib/model/loss.py
@@ -0,0 +1,204 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+# Numpy-based errors
+
+def mpjpe(predicted, target):
+    """
+    Mean per-joint position error (i.e. mean Euclidean distance),
+    often referred to as "Protocol #1" in many papers.
+
+    The Euclidean norm is taken over the last axis and then averaged over
+    axis 1, so any leading axis 0 is kept in the result.
+    """
+    assert predicted.shape == target.shape
+    last_axis = len(target.shape) - 1
+    per_joint = np.linalg.norm(predicted - target, axis=last_axis)
+    return np.mean(per_joint, axis=1)
+
+def p_mpjpe(predicted, target):
+    """
+    Pose error: MPJPE after rigid alignment (scale, rotation, and translation),
+    often referred to as "Protocol #2" in many papers.
+
+    Both inputs are 3-D arrays (batch, joints, coords); the alignment is
+    solved per batch entry with an SVD and the result has one value per entry.
+    """
+    assert predicted.shape == target.shape
+
+    # Centre both point sets on their per-sample mean.
+    tgt_mean = np.mean(target, axis=1, keepdims=True)
+    pred_mean = np.mean(predicted, axis=1, keepdims=True)
+    tgt_c = target - tgt_mean
+    pred_c = predicted - pred_mean
+
+    # Normalise each centred set to unit Frobenius norm.
+    tgt_norm = np.sqrt(np.sum(tgt_c**2, axis=(1, 2), keepdims=True))
+    pred_norm = np.sqrt(np.sum(pred_c**2, axis=(1, 2), keepdims=True))
+    tgt_c /= tgt_norm
+    pred_c /= pred_norm
+
+    cov = tgt_c.transpose(0, 2, 1) @ pred_c
+    U, sing, Vt = np.linalg.svd(cov)
+    Ut = U.transpose(0, 2, 1)
+    V = Vt.transpose(0, 2, 1)
+    rot = V @ Ut
+
+    # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1
+    flip = np.sign(np.expand_dims(np.linalg.det(rot), axis=1))
+    V[:, :, -1] *= flip
+    sing[:, -1] *= flip.flatten()
+    rot = V @ Ut  # Rotation
+
+    trace = np.expand_dims(np.sum(sing, axis=1, keepdims=True), axis=2)
+    scale = trace * tgt_norm / pred_norm  # Scale
+    shift = tgt_mean - scale * (pred_mean @ rot)  # Translation
+
+    # Apply the similarity transform to the prediction, then plain MPJPE.
+    aligned = scale * (predicted @ rot) + shift
+    last_axis = len(target.shape) - 1
+    return np.mean(np.linalg.norm(aligned - target, axis=last_axis), axis=1)
+
+
+# PyTorch-based errors (for losses)
+
+def loss_mpjpe(predicted, target):
+    """
+    Mean per-joint position error (i.e. mean Euclidean distance),
+    often referred to as "Protocol #1" in many papers.
+
+    Returns a scalar tensor: the norm over the last dim, averaged over all
+    remaining elements.
+    """
+    assert predicted.shape == target.shape
+    last_dim = len(target.shape) - 1
+    joint_err = torch.norm(predicted - target, dim=last_dim)
+    return torch.mean(joint_err)
+
+def weighted_mpjpe(predicted, target, w):
+    """
+    Weighted mean per-joint position error (i.e. mean Euclidean distance)
+
+    ``w`` multiplies the per-joint distances before averaging; its first
+    dimension must match that of ``predicted``.
+    """
+    assert predicted.shape == target.shape
+    assert w.shape[0] == predicted.shape[0]
+    last_dim = len(target.shape) - 1
+    joint_err = torch.norm(predicted - target, dim=last_dim)
+    return torch.mean(w * joint_err)
+
+def loss_2d_weighted(predicted, target, conf):
+    """Confidence-weighted 2D error on the first two channels of 4-D inputs.
+
+    The (x, y) difference is multiplied by ``conf`` before its norm over the
+    last dim is taken; the mean of those norms is returned.
+    """
+    assert predicted.shape == target.shape
+    pred_xy = predicted[:, :, :, :2]
+    tgt_xy = target[:, :, :, :2]
+    weighted = (pred_xy - tgt_xy) * conf
+    return torch.norm(weighted, dim=-1).mean()
+
+def n_mpjpe(predicted, target):
+    """
+    Normalized MPJPE (scale only), adapted from:
+    https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py
+
+    A per-frame scale factor <target, predicted> / <predicted, predicted>
+    (summed over dim 3, averaged over dim 2) is applied to the prediction
+    before the regular MPJPE loss is computed.
+    """
+    assert predicted.shape == target.shape
+    pred_energy = torch.sum(predicted**2, dim=3, keepdim=True)
+    pred_energy = torch.mean(pred_energy, dim=2, keepdim=True)
+    cross_term = torch.sum(target*predicted, dim=3, keepdim=True)
+    cross_term = torch.mean(cross_term, dim=2, keepdim=True)
+    rescaled = (cross_term / pred_energy) * predicted
+    return loss_mpjpe(rescaled, target)
+
+def weighted_bonelen_loss(predict_3d_length, gt_3d_length):
+    """Mean squared bone-length difference, scaled by 0.001."""
+    squared_diff = torch.pow(predict_3d_length - gt_3d_length, 2)
+    return 0.001 * squared_diff.mean()
+
+def weighted_boneratio_loss(predict_3d_length, gt_3d_length):
+    """Mean squared relative bone-length error (difference / ground truth), scaled by 0.1."""
+    relative_err = (predict_3d_length - gt_3d_length) / gt_3d_length
+    return 0.1 * torch.pow(relative_err, 2).mean()
+
+def get_limb_lens(x):
+    '''
+    Length of each of the 16 limbs listed in `limbs_id` (pairs of joint indices).
+
+    Input: (N, T, 17, 3)
+    Output: (N, T, 16)
+    '''
+    limbs_id = [[0,1], [1,2], [2,3],
+                [0,4], [4,5], [5,6],
+                [0,7], [7,8], [8,9], [9,10],
+                [8,11], [11,12], [12,13],
+                [8,14], [14,15], [15,16]
+               ]
+    first_joint = [pair[0] for pair in limbs_id]
+    second_joint = [pair[1] for pair in limbs_id]
+    # Gather both endpoints of every limb and take the difference vector.
+    limb_vecs = x[:, :, first_joint, :] - x[:, :, second_joint, :]
+    return torch.norm(limb_vecs, dim=-1)
+
+def loss_limb_var(x):
+    '''
+    Mean over limbs (and batch) of the variance of each limb length along time.
+
+    Input: (N, T, 17, 3)
+    Output: scalar tensor; exactly zero when T <= 1, where no variance
+    along time can be computed.
+    '''
+    if x.shape[1] <= 1:
+        # new_zeros builds the scalar directly with x's dtype and device; the
+        # legacy torch.FloatTensor(1) route was always float32 and CPU-first.
+        return x.new_zeros(())
+    limb_lens = get_limb_lens(x)
+    limb_lens_var = torch.var(limb_lens, dim=1)
+    return torch.mean(limb_lens_var)
+
+def loss_limb_gt(x, gt):
+    '''
+    L1 distance between predicted and ground-truth limb lengths.
+
+    Input: (N, T, 17, 3), (N, T, 17, 3)
+    '''
+    criterion = nn.L1Loss()
+    pred_lens = get_limb_lens(x)
+    gt_lens = get_limb_lens(gt)  # (N, T, 16)
+    return criterion(pred_lens, gt_lens)
+
+def loss_velocity(predicted, target):
+    """
+    Mean per-joint velocity error (i.e. mean Euclidean distance of the 1st derivative)
+
+    Velocities are first differences along dim 1. With at most one frame
+    there is no velocity, and a zero scalar is returned.
+    """
+    assert predicted.shape == target.shape
+    if predicted.shape[1] <= 1:
+        # new_zeros keeps the dtype/device of the input; the legacy
+        # torch.FloatTensor(1) route was always float32 and CPU-first.
+        return predicted.new_zeros(())
+    velocity_predicted = predicted[:, 1:] - predicted[:, :-1]
+    velocity_target = target[:, 1:] - target[:, :-1]
+    return torch.mean(torch.norm(velocity_predicted - velocity_target, dim=-1))
+
+def loss_joint(predicted, target):
+    """Mean absolute (L1) error between two equally-shaped tensors."""
+    assert predicted.shape == target.shape
+    criterion = nn.L1Loss()
+    return criterion(predicted, target)
+
+def get_angles(x):
+    '''
+    Angle between each pair of limbs listed in `angle_id`.
+
+    Every entry of `angle_id` holds two indices into `limbs_id`; every
+    entry of `limbs_id` holds the two joint indices of one limb.
+
+    Input: (N, T, 17, 3)
+    Output: (N, T, 18), angles in radians (one per `angle_id` entry)
+    '''
+    limbs_id = [[0,1], [1,2], [2,3],
+                [0,4], [4,5], [5,6],
+                [0,7], [7,8], [8,9], [9,10],
+                [8,11], [11,12], [12,13],
+                [8,14], [14,15], [15,16]
+               ]
+    angle_id = [[ 0, 3],
+                [ 0, 6],
+                [ 3, 6],
+                [ 0, 1],
+                [ 1, 2],
+                [ 3, 4],
+                [ 4, 5],
+                [ 6, 7],
+                [ 7, 10],
+                [ 7, 13],
+                [ 8, 13],
+                [10, 13],
+                [ 7, 8],
+                [ 8, 9],
+                [10, 11],
+                [11, 12],
+                [13, 14],
+                [14, 15] ]
+    eps = 1e-7
+    limbs = x[:,:,limbs_id,:]                       # (N, T, 16, 2, 3)
+    limbs = limbs[:,:,:,0,:]-limbs[:,:,:,1,:]       # (N, T, 16, 3) limb vectors
+    angles = limbs[:,:,angle_id,:]                  # (N, T, 18, 2, 3)
+    angle_cos = F.cosine_similarity(angles[:,:,:,0,:], angles[:,:,:,1,:], dim=-1)
+    # Clamp strictly inside (-1, 1): the slope of acos is unbounded at +/-1.
+    return torch.acos(angle_cos.clamp(-1+eps, 1-eps))
+
+def loss_angle(x, gt):
+    '''
+    L1 distance between the limb angles of prediction and ground truth.
+
+    Input: (N, T, 17, 3), (N, T, 17, 3)
+    '''
+    criterion = nn.L1Loss()
+    pred_angles = get_angles(x)
+    gt_angles = get_angles(gt)
+    return criterion(pred_angles, gt_angles)
+
+def loss_angle_velocity(x, gt):
+    """
+    Mean per-angle velocity error: L1 distance between the first differences
+    (along dim 1) of the limb angles of ``x`` and ``gt``.
+
+    With at most one frame there is no velocity, and a zero scalar is returned.
+    """
+    assert x.shape == gt.shape
+    if x.shape[1] <= 1:
+        # new_zeros keeps the dtype/device of the input; the legacy
+        # torch.FloatTensor(1) route was always float32 and CPU-first.
+        return x.new_zeros(())
+    x_a = get_angles(x)
+    gt_a = get_angles(gt)
+    x_av = x_a[:, 1:] - x_a[:, :-1]
+    gt_av = gt_a[:, 1:] - gt_a[:, :-1]
+    return nn.L1Loss()(x_av, gt_av)
+
diff --git a/lib/model/loss_mesh.py b/lib/model/loss_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..82f615f1484d163b750eaf80eaa3560bfeadd773
--- /dev/null
+++ b/lib/model/loss_mesh.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import ipdb
+from lib.utils.utils_mesh import batch_rodrigues
+from lib.model.loss import *
+
+class MeshLoss(nn.Module):
+    """Collection of 3D-keypoint and SMPL-parameter losses for mesh regression.
+
+    Args:
+        loss_type: 'MSE' or 'L1'; selects the criterion used for the SMPL
+            pose/shape regression terms.
+        device: device the criterion modules are moved to, and on which the
+            zero fallbacks in ``smpl_losses`` are created.
+
+    Raises:
+        ValueError: if ``loss_type`` is neither 'MSE' nor 'L1'.
+    """
+    def __init__(
+        self,
+        loss_type='MSE',
+        device='cuda',
+    ):
+        super(MeshLoss, self).__init__()
+        self.device = device
+        self.loss_type = loss_type
+        if loss_type == 'MSE':
+            self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
+            self.criterion_regr = nn.MSELoss().to(self.device)
+        elif loss_type == 'L1':
+            self.criterion_keypoints = nn.L1Loss(reduction='none').to(self.device)
+            self.criterion_regr = nn.L1Loss().to(self.device)
+        else:
+            # Fail at construction: otherwise the criteria stay undefined and
+            # the first smpl_losses() call dies with an AttributeError.
+            raise ValueError("Unsupported loss_type: {!r} (expected 'MSE' or 'L1')".format(loss_type))
+
+    def forward(
+        self,
+        smpl_output,
+        data_gt,
+    ):
+        """Compute all loss terms for one batch.
+
+        Args:
+            smpl_output: sequence of prediction dicts; only the last one is
+                used. It must provide 'theta' and 'kp_3d'.
+            data_gt: dict with ground-truth 'theta' and 'kp_3d'.
+
+        Returns:
+            dict mapping loss names to tensors. The SMPL terms
+            ('loss_shape', 'loss_pose', 'loss_norm') are only present when
+            there is at least one frame.
+        """
+        # to reduce time dimension
+        reduce = lambda x: x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
+        data_3d_theta = reduce(data_gt['theta'])
+
+        preds = smpl_output[-1]
+        pred_theta = reduce(preds['theta'])
+        # Root-relative keypoints: subtract joint 0 from every joint.
+        preds_local = preds['kp_3d'] - preds['kp_3d'][:, :, 0:1,:]  # (N, T, 17, 3)
+        gt_local = data_gt['kp_3d'] - data_gt['kp_3d'][:, :, 0:1,:]
+        # theta layout: first 72 values are pose, the remainder is shape.
+        real_shape, pred_shape = data_3d_theta[:, 72:], pred_theta[:, 72:]
+        real_pose, pred_pose = data_3d_theta[:, :72], pred_theta[:, :72]
+        loss_dict = {}
+        loss_dict['loss_3d_pos'] = loss_mpjpe(preds_local, gt_local)
+        loss_dict['loss_3d_scale'] = n_mpjpe(preds_local, gt_local)
+        loss_dict['loss_3d_velocity'] = loss_velocity(preds_local, gt_local)
+        loss_dict['loss_lv'] = loss_limb_var(preds_local)
+        loss_dict['loss_lg'] = loss_limb_gt(preds_local, gt_local)
+        loss_dict['loss_a'] = loss_angle(preds_local, gt_local)
+        loss_dict['loss_av'] = loss_angle_velocity(preds_local, gt_local)
+
+        if pred_theta.shape[0] > 0:
+            loss_pose, loss_shape = self.smpl_losses(pred_pose, pred_shape, real_pose, real_shape)
+            loss_norm = torch.norm(pred_theta, dim=-1).mean()
+            loss_dict['loss_shape'] = loss_shape
+            loss_dict['loss_pose'] = loss_pose
+            loss_dict['loss_norm'] = loss_norm
+        return loss_dict
+
+    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas):
+        """Regression losses on SMPL pose (compared as rotation matrices) and shape.
+
+        Both pose inputs are reshaped to (-1, 3) and converted with
+        ``batch_rodrigues`` into (-1, 24, 3, 3) rotation matrices.
+
+        Returns:
+            (pose loss, shape loss); one-element zero tensors when the batch
+            is empty.
+        """
+        pred_rotmat_valid = batch_rodrigues(pred_rotmat.reshape(-1,3)).reshape(-1, 24, 3, 3)
+        gt_rotmat_valid = batch_rodrigues(gt_pose.reshape(-1,3)).reshape(-1, 24, 3, 3)
+        pred_betas_valid = pred_betas
+        gt_betas_valid = gt_betas
+        if len(pred_rotmat_valid) > 0:
+            loss_regr_pose = self.criterion_regr(pred_rotmat_valid, gt_rotmat_valid)
+            loss_regr_betas = self.criterion_regr(pred_betas_valid, gt_betas_valid)
+        else:
+            # Same shape (1,) and dtype as before, created directly on the device.
+            loss_regr_pose = torch.zeros(1, device=self.device)
+            loss_regr_betas = torch.zeros(1, device=self.device)
+        return loss_regr_pose, loss_regr_betas
diff --git a/lib/model/loss_supcon.py b/lib/model/loss_supcon.py
new file mode 100644
index 0000000000000000000000000000000000000000..17117d4210160679dee0bdc4a4b7d97c433e1d43
--- /dev/null
+++ b/lib/model/loss_supcon.py
@@ -0,0 +1,98 @@
+"""
+Author: Yonglong Tian (yonglong@mit.edu)
+Date: May 07, 2020
+"""
+from __future__ import print_function
+
+import torch
+import torch.nn as nn
+
+
+class SupConLoss(nn.Module):
+    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+    It also supports the unsupervised contrastive loss in SimCLR"""
+    def __init__(self, temperature=0.07, contrast_mode='all',
+                 base_temperature=0.07):
+        super(SupConLoss, self).__init__()
+        self.temperature = temperature
+        self.contrast_mode = contrast_mode
+        self.base_temperature = base_temperature
+
+    def forward(self, features, labels=None, mask=None):
+        """Compute loss for model. If both `labels` and `mask` are None,
+        it degenerates to SimCLR unsupervised loss:
+        https://arxiv.org/pdf/2002.05709.pdf
+
+        Args:
+            features: hidden vector of shape [bsz, n_views, ...].
+            labels: ground truth of shape [bsz].
+            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
+                has the same class as sample i. Can be asymmetric.
+        Returns:
+            A loss scalar. Anchors without any positive contribute 0.
+        Raises:
+            ValueError: if `features` has fewer than 3 dims, if both `labels`
+                and `mask` are given, if the number of labels does not match
+                the batch size, or if `contrast_mode` is unknown.
+        """
+        # Use the tensor's own device: torch.device('cuda') means the current
+        # default GPU, which is wrong when features live on e.g. cuda:1.
+        device = features.device
+
+        if len(features.shape) < 3:
+            raise ValueError('`features` needs to be [bsz, n_views, ...],'
+                             'at least 3 dimensions are required')
+        if len(features.shape) > 3:
+            features = features.view(features.shape[0], features.shape[1], -1)
+
+        batch_size = features.shape[0]
+        if labels is not None and mask is not None:
+            raise ValueError('Cannot define both `labels` and `mask`')
+        elif labels is None and mask is None:
+            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
+        elif labels is not None:
+            labels = labels.contiguous().view(-1, 1)
+            if labels.shape[0] != batch_size:
+                raise ValueError('Num of labels does not match num of features')
+            mask = torch.eq(labels, labels.T).float().to(device)
+        else:
+            mask = mask.float().to(device)
+
+        contrast_count = features.shape[1]
+        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
+        if self.contrast_mode == 'one':
+            anchor_feature = features[:, 0]
+            anchor_count = 1
+        elif self.contrast_mode == 'all':
+            anchor_feature = contrast_feature
+            anchor_count = contrast_count
+        else:
+            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
+
+        # compute logits
+        anchor_dot_contrast = torch.div(
+            torch.matmul(anchor_feature, contrast_feature.T),
+            self.temperature)
+        # for numerical stability
+        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+        logits = anchor_dot_contrast - logits_max.detach()
+
+        # tile mask
+        mask = mask.repeat(anchor_count, contrast_count)
+        # mask-out self-contrast cases
+        logits_mask = torch.scatter(
+            torch.ones_like(mask),
+            1,
+            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
+            0
+        )
+        mask = mask * logits_mask
+
+        # compute log_prob
+        exp_logits = torch.exp(logits) * logits_mask
+        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+        # compute mean of log-likelihood over positive
+        # An anchor with no positives has a zero numerator; dividing by its
+        # zero count would yield NaN and poison the batch mean, so use 1.
+        pos_per_anchor = mask.sum(1)
+        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
+                                     torch.ones_like(pos_per_anchor),
+                                     pos_per_anchor)
+        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor
+
+        # loss
+        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
+        loss = loss.view(anchor_count, batch_size).mean()
+
+        return loss
diff --git a/lib/model/model_action.py b/lib/model/model_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..785ec2671b4bdb8a1b47190753b9b46534d4280f
--- /dev/null
+++ b/lib/model/model_action.py
@@ -0,0 +1,71 @@
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ActionHeadClassification(nn.Module):
+    """Classification head: pools features over time and persons, then
+    applies Linear -> BatchNorm1d -> ReLU -> Linear to produce class logits."""
+    def __init__(self, dropout_ratio=0., dim_rep=512, num_classes=60, num_joints=17, hidden_dim=2048):
+        super(ActionHeadClassification, self).__init__()
+        self.dropout = nn.Dropout(p=dropout_ratio)
+        self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1)
+        self.relu = nn.ReLU(inplace=True)
+        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, num_classes)
+
+    def forward(self, feat):
+        '''
+        Input: (N, M, T, J, C)
+        Output: (N, num_classes)
+        '''
+        N, M, T, J, C = feat.shape
+        dropped = self.dropout(feat)
+        # Move T last and average it away: (N, M, T, J, C) -> (N, M, J, C)
+        time_pooled = dropped.permute(0, 1, 3, 4, 2).mean(dim=-1)
+        # Flatten joints/channels, then average over the M axis: (N, J*C)
+        pooled = time_pooled.reshape(N, M, -1).mean(dim=1)
+        hidden = self.relu(self.bn(self.fc1(pooled)))
+        return self.fc2(hidden)
+
+class ActionHeadEmbed(nn.Module):
+    """Embedding head: pools features over time and persons, projects them
+    with one linear layer and L2-normalises the result."""
+    def __init__(self, dropout_ratio=0., dim_rep=512, num_joints=17, hidden_dim=2048):
+        super(ActionHeadEmbed, self).__init__()
+        self.dropout = nn.Dropout(p=dropout_ratio)
+        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)
+
+    def forward(self, feat):
+        '''
+        Input: (N, M, T, J, C)
+        Output: (N, hidden_dim), unit L2 norm along the last dim
+        '''
+        N, M, T, J, C = feat.shape
+        dropped = self.dropout(feat)
+        # Move T last and average it away: (N, M, T, J, C) -> (N, M, J, C)
+        time_pooled = dropped.permute(0, 1, 3, 4, 2).mean(dim=-1)
+        # Flatten joints/channels, then average over the M axis: (N, J*C)
+        pooled = time_pooled.reshape(N, M, -1).mean(dim=1)
+        return F.normalize(self.fc1(pooled), dim=-1)
+
+class ActionNet(nn.Module):
+    """Backbone plus an action head.
+
+    Args:
+        backbone: module exposing ``get_representation(x)``.
+        dim_rep: per-joint feature width produced by the backbone.
+        num_classes: number of classes (``version='class'`` only).
+        dropout_ratio: dropout probability used inside the head.
+        version: 'class' for ``ActionHeadClassification``, 'embed' for
+            ``ActionHeadEmbed``.
+        hidden_dim: hidden/embedding width of the head.
+        num_joints: number of joints per frame.
+
+    Raises:
+        Exception: if ``version`` is neither 'class' nor 'embed'.
+    """
+    def __init__(self, backbone, dim_rep=512, num_classes=60, dropout_ratio=0., version='class', hidden_dim=2048, num_joints=17):
+        super(ActionNet, self).__init__()
+        self.backbone = backbone
+        self.feat_J = num_joints
+        if version=='class':
+            # hidden_dim was silently dropped for this head; forward it so
+            # the argument takes effect (the default of 2048 is unchanged).
+            self.head = ActionHeadClassification(dropout_ratio=dropout_ratio, dim_rep=dim_rep, num_classes=num_classes, num_joints=num_joints, hidden_dim=hidden_dim)
+        elif version=='embed':
+            self.head = ActionHeadEmbed(dropout_ratio=dropout_ratio, dim_rep=dim_rep, hidden_dim=hidden_dim, num_joints=num_joints)
+        else:
+            raise Exception('Version Error.')
+
+    def forward(self, x):
+        '''
+        Input: (N, M, T, J, C), e.g. (N, M, T, 17, 3)
+        Output: whatever the selected head returns for (N, M, T, J, C') features
+        '''
+        N, M, T, J, C = x.shape
+        # Fold the M axis into the batch so the backbone sees 4-D input.
+        x = x.reshape(N*M, T, J, C)
+        feat = self.backbone.get_representation(x)
+        feat = feat.reshape([N, M, T, self.feat_J, -1])  # (N, M, T, J, C)
+        out = self.head(feat)
+        return out
\ No newline at end of file
diff --git a/lib/model/model_mesh.py b/lib/model/model_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff579db07a3611c98deea341f3ccd8e87aea33c
--- /dev/null
+++ b/lib/model/model_mesh.py
@@ -0,0 +1,101 @@
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from lib.utils.utils_smpl import SMPL
+from lib.utils.utils_mesh import rotation_matrix_to_angle_axis, rot6d_to_rotmat
+
+class SMPLRegressor(nn.Module):
+    """Regress SMPL pose (per frame) and shape (per clip) from joint features
+    and run the SMPL layer to obtain vertices and 17 regressed joints."""
+    def __init__(self, args, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.):
+        super(SMPLRegressor, self).__init__()
+        param_pose_dim = 24 * 6  # 24 rotations, 6 values each
+        in_dim = num_joints*dim_rep
+        self.dropout = nn.Dropout(p=dropout_ratio)
+        self.fc1 = nn.Linear(in_dim, hidden_dim)
+        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))
+        self.fc2 = nn.Linear(in_dim, hidden_dim)
+        self.bn1 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
+        self.bn2 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.head_pose = nn.Linear(hidden_dim, param_pose_dim)
+        self.head_shape = nn.Linear(hidden_dim, 10)
+        # Small-gain init keeps the initial residuals close to zero.
+        nn.init.xavier_uniform_(self.head_pose.weight, gain=0.01)
+        nn.init.xavier_uniform_(self.head_shape.weight, gain=0.01)
+        self.smpl = SMPL(
+            args.data_root,
+            batch_size=64,
+            create_transl=False,
+        )
+        # Mean parameters serve as the base the heads add their output to.
+        mean_params = np.load(self.smpl.smpl_mean_params)
+        mean_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
+        mean_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
+        self.register_buffer('init_pose', mean_pose)
+        self.register_buffer('init_shape', mean_shape)
+        self.J_regressor = self.smpl.J_regressor_h36m
+
+    def forward(self, feat, init_pose=None, init_shape=None):
+        """Predict SMPL parameters, vertices and joints.
+
+        Args:
+            feat: features of shape (N, T, J, C).
+            init_pose, init_shape: accepted but not used; the registered
+                mean-parameter buffers are always the starting point.
+
+        Returns:
+            A one-element list holding a dict with 'theta' (N*T, 72+10),
+            'verts' (N*T, 6890, 3) and 'kp_3d' (N*T, 17, 3).
+        """
+        N, T, J, C = feat.shape
+        NT = N * T
+        seq_feat = feat.reshape(N, T, -1)
+
+        # Pose branch: one prediction per frame.
+        pose_hidden = self.dropout(seq_feat.reshape(NT, -1))  # (N*T, J*C)
+        pose_hidden = self.relu1(self.bn1(self.fc1(pose_hidden)))  # (NT, C)
+
+        # Shape branch: average over time, one prediction per clip.
+        clip_feat = seq_feat.permute(0,2,1)  # (N, T, J*C) -> (N, J*C, T)
+        clip_feat = self.pool2(clip_feat).reshape(N, -1)  # (N, J*C)
+        shape_hidden = self.dropout(clip_feat)
+        shape_hidden = self.relu2(self.bn2(self.fc2(shape_hidden)))  # (N, C)
+
+        # Heads predict residuals on top of the mean parameters.
+        base_pose = self.init_pose.expand(NT, -1)  # (NT, C)
+        base_shape = self.init_shape.expand(N, -1)  # (N, C)
+        pred_pose = self.head_pose(pose_hidden) + base_pose
+        pred_shape = self.head_shape(shape_hidden) + base_shape
+        # Repeat the per-clip shape for each of the T frames: (N, 10) -> (N*T, 10)
+        pred_shape = pred_shape.expand(T, N, -1).permute(1, 0, 2).reshape(NT, -1)
+        pred_rotmat = rot6d_to_rotmat(pred_pose).view(-1, 24, 3, 3)
+        pred_output = self.smpl(
+            betas=pred_shape,
+            body_pose=pred_rotmat[:, 1:],
+            global_orient=pred_rotmat[:, 0].unsqueeze(1),
+            pose2rot=False
+        )
+        # NOTE(review): the factor 1000 presumably converts metres to millimetres - confirm.
+        pred_vertices = pred_output.vertices*1000.0
+        assert self.J_regressor is not None
+        regressor = self.J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
+        pred_joints = torch.matmul(regressor, pred_vertices)
+        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)
+        theta = torch.cat([pose, pred_shape], dim=1)
+        return [{
+            'theta' : theta,          # (N*T, 72+10)
+            'verts' : pred_vertices,  # (N*T, 6890, 3)
+            'kp_3d' : pred_joints,    # (N*T, 17, 3)
+        }]
+
+class MeshRegressor(nn.Module):
+    """Backbone followed by an ``SMPLRegressor`` head; restores the (N, T)
+    leading dimensions on every output of the head."""
+    def __init__(self, args, backbone, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.5):
+        super(MeshRegressor, self).__init__()
+        self.backbone = backbone
+        self.feat_J = num_joints
+        self.head = SMPLRegressor(args, dim_rep, num_joints, hidden_dim, dropout_ratio)
+
+    def forward(self, x, init_pose=None, init_shape=None, n_iter=3):
+        '''
+        Input: (N x T x 17 x 3)
+        Output: list of dicts with 'theta' (N, T, -1), 'verts' (N, T, -1, 3)
+        and 'kp_3d' (N, T, -1, 3).
+        init_pose, init_shape and n_iter are accepted but not used.
+        '''
+        N, T, J, C = x.shape
+        rep = self.backbone.get_representation(x)
+        rep = rep.reshape([N, T, self.feat_J, -1])  # (N, T, J, C)
+        smpl_output = self.head(rep)
+        # Un-flatten the N*T batch produced by the head.
+        trailing_dims = (('theta', (-1,)), ('verts', (-1, 3)), ('kp_3d', (-1, 3)))
+        for entry in smpl_output:
+            for key, tail in trailing_dims:
+                entry[key] = entry[key].reshape(N, T, *tail)
+        return smpl_output
\ No newline at end of file
diff --git a/lib/utils/learning.py b/lib/utils/learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..191e6697919a338f59ec53263b6edc3f300a2783
--- /dev/null
+++ b/lib/utils/learning.py
@@ -0,0 +1,102 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+from functools import partial
+from lib.model.DSTformer import DSTformer
+
+class AverageMeter(object):
+    """Tracks the latest value plus a running weighted sum, count and mean."""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """Clear all statistics back to zero."""
+        self.val = self.avg = self.sum = self.count = 0
+
+    def update(self, val, n=1):
+        """Record ``val`` as observed ``n`` times and refresh the mean."""
+        self.val = val
+        self.count += n
+        self.sum += val * n
+        self.avg = self.sum / self.count
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k
+
+    Returns a list with one single-element tensor per entry of ``topk``,
+    holding the percentage of samples whose target is among the top-k scores.
+    """
+    with torch.no_grad():
+        batch_size = target.size(0)
+        # Indices of the max(topk) highest scores, laid out as (maxk, batch).
+        _, top_idx = output.topk(max(topk), 1, True, True)
+        top_idx = top_idx.t()
+        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
+        percent = 100.0 / batch_size
+        return [
+            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent)
+            for k in topk
+        ]
+
+def load_pretrained_weights(model, checkpoint):
+    """Load pretrained weights into a model.
+    Incompatible layers (unmatched in name or size) will be ignored
+    Args:
+    - model (nn.Module): network model, which must not be nn.DataParallel
+    - checkpoint (dict): either a state dict or a dict holding one under
+      the key 'state_dict'
+    Returns:
+    - the same model, with every matching entry overwritten
+    """
+    import collections
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+    model_dict = model.state_dict()
+    new_state_dict = collections.OrderedDict()
+    matched_layers = []
+    for k, v in state_dict.items():
+        # If the pretrained state_dict was saved as nn.DataParallel,
+        # keys would contain "module.", which should be ignored.
+        if k.startswith('module.'):
+            k = k[7:]
+        if k in model_dict and model_dict[k].size() == v.size():
+            new_state_dict[k] = v
+            matched_layers.append(k)
+    # Start from the model's own state so the strict load always has every key.
+    model_dict.update(new_state_dict)
+    model.load_state_dict(model_dict, strict=True)
+    print('load_weight', len(matched_layers))
+    return model
+
+def partial_train_layers(model, partial_list):
+    """Train partial layers of a given model.
+
+    A parameter stays trainable only if its name contains at least one of
+    the substrings in ``partial_list``; all others are frozen.
+    """
+    for name, p in model.named_parameters():
+        p.requires_grad = any(trainable in name for trainable in partial_list)
+    return model
+
+def load_backbone(args):
+    """Build the backbone named by ``args.backbone``.
+
+    When ``args`` has no ``backbone`` attribute it is set to 'DSTformer'
+    (so ``args`` is modified). The alternative backbones are imported only
+    when selected.
+
+    Raises:
+        Exception: for an unknown backbone name.
+    """
+    if not hasattr(args, "backbone"):
+        args.backbone = 'DSTformer' # Default
+    name = args.backbone
+    if name == 'DSTformer':
+        return DSTformer(dim_in=3, dim_out=3, dim_feat=args.dim_feat, dim_rep=args.dim_rep,
+                         depth=args.depth, num_heads=args.num_heads, mlp_ratio=args.mlp_ratio, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                         maxlen=args.maxlen, num_joints=args.num_joints)
+    if name == 'TCN':
+        from lib.model.model_tcn import PoseTCN
+        return PoseTCN()
+    if name == 'poseformer':
+        from lib.model.model_poseformer import PoseTransformer
+        return PoseTransformer(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3, embed_dim_ratio=32, depth=4,
+                               num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0, attn_mask=None)
+    if name == 'mixste':
+        from lib.model.model_mixste import MixSTE2
+        return MixSTE2(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3, embed_dim_ratio=512, depth=8,
+                       num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0)
+    if name == 'stgcn':
+        from lib.model.model_stgcn import Model as STGCN
+        return STGCN()
+    raise Exception("Undefined backbone type.")
\ No newline at end of file
diff --git a/lib/utils/tools.py b/lib/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2b780f0b184584923e54643372a26cca3e9c277
--- /dev/null
+++ b/lib/utils/tools.py
@@ -0,0 +1,69 @@
+import numpy as np
+import os, sys
+import pickle
+import yaml
+from easydict import EasyDict as edict
+from typing import Any, IO
+
+ROOT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
+
+class TextLogger:
+    """Minimal line-oriented file logger; the file is emptied on creation."""
+    def __init__(self, log_path):
+        self.log_path = log_path
+        # Opening in "w" mode creates the file or truncates an existing one.
+        with open(self.log_path, "w") as handle:
+            handle.write("")
+
+    def log(self, log):
+        """Append ``log`` followed by a newline to the log file."""
+        line = log + "\n"
+        with open(self.log_path, "a+") as handle:
+            handle.write(line)
+
+class Loader(yaml.SafeLoader):
+    """YAML Loader with `!include` constructor."""
+
+    def __init__(self, stream: IO) -> None:
+        """Initialise Loader.
+
+        Remembers the directory of ``stream`` (taken from ``stream.name``) in
+        ``self._root``; streams without a ``name`` fall back to the current
+        directory.
+        """
+        try:
+            stream_name = stream.name
+        except AttributeError:
+            self._root = os.path.curdir
+        else:
+            self._root = os.path.split(stream_name)[0]
+
+        super().__init__(stream)
+
+def construct_include(loader: Loader, node: yaml.Node) -> Any:
+    """Include file referenced at node.
+
+    The path is resolved relative to ``loader._root``. '.yaml'/'.yml' files
+    are parsed with ``Loader`` (so nested includes work), '.json' files with
+    ``json.load``; any other file is returned as raw text.
+    """
+    # `json` is not imported at module level, so the JSON branch used to
+    # raise NameError; import it locally to keep this block self-contained.
+    import json
+
+    filename = os.path.abspath(os.path.join(loader._root, loader.construct_scalar(node)))
+    extension = os.path.splitext(filename)[1].lstrip('.')
+
+    with open(filename, 'r') as f:
+        if extension in ('yaml', 'yml'):
+            return yaml.load(f, Loader)
+        elif extension in ('json', ):
+            return json.load(f)
+        else:
+            return f.read()
+
+def get_config(config_path):
+    """Load a YAML config (with `!include` support) into an EasyDict.
+
+    The file name without directory and extension is stored as ``config.name``.
+    """
+    # Registered on every call.
+    yaml.add_constructor('!include', construct_include, Loader)
+    with open(config_path, 'r') as stream:
+        raw_config = yaml.load(stream, Loader=Loader)
+    config = edict(raw_config)
+    config_filename = os.path.split(config_path)[1]
+    config.name = os.path.splitext(config_filename)[0]
+    return config
+
+def ensure_dir(path):
+    """
+    Create a directory (and missing parents) unless the path already exists.
+    :param path: directory path to create
+    :return: None
+    """
+    if not os.path.exists(path):
+        # exist_ok closes the race where another process creates the
+        # directory between the check above and this call.
+        os.makedirs(path, exist_ok=True)
+
+def read_pkl(data_url):
+    """Load and return the object pickled in the file at ``data_url``.
+
+    The file is opened with a context manager so the handle is closed even
+    when unpickling raises.
+    """
+    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
+    with open(data_url, 'rb') as file:
+        return pickle.load(file)
\ No newline at end of file
diff --git a/lib/utils/utils_data.py b/lib/utils/utils_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..df7b61efacfa737191237fdeba63f77a6408b701
--- /dev/null
+++ b/lib/utils/utils_data.py
@@ -0,0 +1,112 @@
+import os
+import torch
+import torch.nn.functional as F
+import numpy as np
+import copy
+
def crop_scale(motion, scale_range=(1, 1)):
    '''
    Crop to the bounding square of the valid joints and normalize to [-1, 1].

    Motion: [(M), T, 17, 3]. The last channel is read as a validity flag:
    joints whose third value is 0 are ignored when computing the box.
    scale_range: (low, high) bounds of the random factor multiplied onto the
    box size.
    Returns a new array of the same shape (the input is not modified);
    all zeros when fewer than 4 valid joints exist or the box is degenerate.
    '''
    result = copy.deepcopy(motion)
    valid_coords = motion[motion[..., 2] != 0][:, :2]
    if len(valid_coords) < 4:
        return np.zeros(motion.shape)
    # Vectorised reductions instead of Python min()/max() over array elements.
    xmin, ymin = valid_coords.min(axis=0)
    xmax, ymax = valid_coords.max(axis=0)
    ratio = np.random.uniform(low=scale_range[0], high=scale_range[1], size=1)[0]
    scale = max(xmax - xmin, ymax - ymin) * ratio
    if scale == 0:
        return np.zeros(motion.shape)
    # Top-left corner of the square of side `scale` centred on the box.
    xs = (xmin + xmax - scale) / 2
    ys = (ymin + ymax - scale) / 2
    result[..., :2] = (motion[..., :2] - [xs, ys]) / scale
    result[..., :2] = (result[..., :2] - 0.5) * 2
    # Clipping is applied to every channel, including the third one.
    result = np.clip(result, -1, 1)
    return result
+
def crop_scale_3d(motion, scale_range=[1, 1]):
    '''
    Motion: [T, 17, 3]. (x, y, z)
    Normalize to [-1, 1] using the x/y bounding square of the whole clip.
    Z is made relative to the first frame's joint 0 and divided by the same
    scale. Returns a new array; all zeros when the box is degenerate.
    '''
    x_all = motion[..., 0]
    y_all = motion[..., 1]
    xmin, xmax = np.min(x_all), np.max(x_all)
    ymin, ymax = np.min(y_all), np.max(y_all)
    ratio = np.random.uniform(low=scale_range[0], high=scale_range[1], size=1)[0]
    # Unlike crop_scale, the box size is divided by the random ratio here.
    scale = max(xmax - xmin, ymax - ymin) / ratio
    if scale == 0:
        return np.zeros(motion.shape)

    normalized = copy.deepcopy(motion)
    depth = normalized[:, :, 2] - normalized[0, 0, 2]
    corner = [(xmin + xmax - scale) / 2, (ymin + ymax - scale) / 2]
    normalized[..., :2] = (motion[..., :2] - corner) / scale
    normalized[..., 2] = depth / scale
    # The final affine map is applied to all three channels.
    return (normalized - 0.5) * 2
+
def flip_data(data):
    """
    Mirror a pose sequence horizontally.

    data: [N, F, 17, D] or [F, 17, D]. X (horizontal coordinate) is the first channel in D.
    Return
        result: a flipped copy with the same shape; the input is left untouched.
    """
    left_joints = [4, 5, 6, 11, 12, 13]
    right_joints = [1, 2, 3, 14, 15, 16]
    destination = left_joints + right_joints
    source = right_joints + left_joints

    flipped_data = copy.deepcopy(data)
    # Negate x for every joint ...
    flipped_data[..., 0] *= -1
    # ... then exchange the left and right joint slots.
    flipped_data[..., destination, :] = flipped_data[..., source, :]
    return flipped_data
+
def resample(ori_len, target_len, replay=False, randomness=True):
    """Return ``target_len`` frame indices into a sequence of ``ori_len`` frames.

    replay=True: a random contiguous clip (returned as a ``range``) when the
        sequence is longer than ``target_len``, otherwise indices that wrap
        around (``i % ori_len``) as an array.
    replay=False: indices spread evenly over the sequence. With
        ``randomness`` each index is jittered inside its interval (or, when
        upsampling, randomly rounded down/up); without it the evenly spaced
        positions are truncated to int.
    """
    if replay:
        if ori_len > target_len:
            # randint's upper bound is exclusive; +1 makes the last valid
            # start (ori_len - target_len) reachable.
            st = np.random.randint(ori_len - target_len + 1)
            return range(st, st + target_len)  # Random clipping from sequence
        return np.array(range(target_len)) % ori_len  # Replay padding

    if not randomness:
        return np.linspace(0, ori_len, num=target_len, endpoint=False, dtype=int)

    even = np.linspace(0, ori_len, num=target_len, endpoint=False)
    if ori_len < target_len:
        low = np.floor(even)
        high = np.ceil(even)
        sel = np.random.randint(2, size=even.shape)
        result = np.sort(sel * low + (1 - sel) * high)
    else:
        # Spacing of `even`, computed directly so that target_len == 1
        # (where even[1] does not exist) works too.
        interval = ori_len / max(target_len, 1)
        result = np.random.random(even.shape) * interval + even
    return np.clip(result, a_min=0, a_max=ori_len - 1).astype(np.uint32)
+
+def split_clips(vid_list, n_frames, data_stride):
+ result = []
+ n_clips = 0
+ st = 0
+ i = 0
+ saved = set()
+ while i(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:,
+ 2], norm_quat[:,
+ 3]
+
+ batch_size = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w * x, w * y, w * z
+ xy, xz, yz = x * y, x * z, y * z
+
+ rotMat = torch.stack([
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
+ w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
+ w2 - x2 - y2 + z2
+ ],
+ dim=1).view(batch_size, 3, 3)
+ return rotMat
+
+
def rotation_matrix_to_angle_axis(rotation_matrix):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert rotation matrices to Rodrigues (angle-axis) vectors.

    A batch of 3x3 matrices is first padded with a constant [0, 0, 1] column
    to the 3x4 layout expected by ``rotation_matrix_to_quaternion``.
    NaN entries of the result are replaced by 0.

    Args:
        rotation_matrix (Tensor): rotation matrix.

    Returns:
        Tensor: Rodrigues vector transformation.

    Shape:
        - Input: :math:`(N, 3, 4)` or :math:`(N, 3, 3)`
        - Output: :math:`(N, 3)`

    Example:
        >>> input = torch.rand(2, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_angle_axis(input)  # Nx3
    """
    if rotation_matrix.shape[1:] == (3, 3):
        rot_mat = rotation_matrix.reshape(-1, 3, 3)
        batch = rot_mat.shape[0]
        extra_column = torch.tensor([0, 0, 1], dtype=torch.float32,
                                    device=rotation_matrix.device)
        hom = extra_column.reshape(1, 3, 1).expand(batch, -1, -1)
        rotation_matrix = torch.cat([rot_mat, hom], dim=-1)

    aa = quaternion_to_angle_axis(rotation_matrix_to_quaternion(rotation_matrix))
    aa[torch.isnan(aa)] = 0.0
    return aa
+
+
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert (w, x, y, z) quaternions to angle-axis vectors.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        quaternion (torch.Tensor): tensor with quaternions, scalar part first.

    Return:
        torch.Tensor: tensor with angle axis of rotation.

    Raises:
        TypeError: if the input is not a tensor.
        ValueError: if the last dimension is not 4.

    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`

    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))
    if quaternion.shape[-1] != 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
                         .format(quaternion.shape))

    real = quaternion[..., 0]
    imag = (quaternion[..., 1], quaternion[..., 2], quaternion[..., 3])
    sin_sq = imag[0] * imag[0] + imag[1] * imag[1] + imag[2] * imag[2]
    sin_half = torch.sqrt(sin_sq)

    # For a negative scalar part both atan2 arguments are negated.
    angle = 2.0 * torch.where(
        real < 0.0,
        torch.atan2(-sin_half, -real),
        torch.atan2(sin_half, real))

    # angle / sin is undefined for a zero vector part; use the constant 2 there.
    factor = torch.where(sin_sq > 0.0,
                         angle / sin_half,
                         2.0 * torch.ones_like(sin_half))

    angle_axis = torch.zeros_like(quaternion)[..., :3]
    for axis, component in enumerate(imag):
        angle_axis[..., axis] += component * factor
    return angle_axis
+
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert 3x4 rotation matrix to 4d quaternion vector (w, x, y, z).

    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Four candidate quaternions are built, one per diagonal sign pattern, and
    a per-sample mask picks exactly one of them.

    Args:
        rotation_matrix (Tensor): the rotation matrix to convert.
        eps (float): threshold below which the (2, 2) entry counts as small.

    Return:
        Tensor: the rotation in quaternion

    Raises:
        TypeError: if the input is not a tensor.
        ValueError: if the input has more than 3 dims or is not N x 3 x 4.

    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 4)`

    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    if len(rotation_matrix.shape) > 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape))
    if rotation_matrix.shape[-2:] != (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4 tensor. Got {}".format(
                rotation_matrix.shape))

    # All entries below are read from the transposed matrix.
    m = torch.transpose(rotation_matrix, 1, 2)
    m00, m01, m02 = m[:, 0, 0], m[:, 0, 1], m[:, 0, 2]
    m10, m11, m12 = m[:, 1, 0], m[:, 1, 1], m[:, 1, 2]
    m20, m21, m22 = m[:, 2, 0], m[:, 2, 1], m[:, 2, 2]

    t0 = 1 + m00 - m11 - m22
    cand0 = torch.stack([m12 - m21, t0, m01 + m10, m20 + m02], -1)

    t1 = 1 - m00 + m11 - m22
    cand1 = torch.stack([m20 - m02, m01 + m10, t1, m12 + m21], -1)

    t2 = 1 - m00 - m11 + m22
    cand2 = torch.stack([m01 - m10, m20 + m02, m12 + m21, t2], -1)

    t3 = 1 + m00 + m11 + m22
    cand3 = torch.stack([t3, m12 - m21, m20 - m02, m01 - m10], -1)

    d2_small = m22 < eps
    d0_gt_d1 = m00 > m11
    d0_lt_neg_d1 = m00 < -m11

    # Mutually exclusive selectors, converted to 0/1 column vectors.
    sel0 = (d2_small & d0_gt_d1).view(-1, 1).type_as(cand0)
    sel1 = (d2_small & ~d0_gt_d1).view(-1, 1).type_as(cand1)
    sel2 = (~d2_small & d0_lt_neg_d1).view(-1, 1).type_as(cand2)
    sel3 = (~d2_small & ~d0_lt_neg_d1).view(-1, 1).type_as(cand3)

    picked = cand0 * sel0 + cand1 * sel1 + cand2 * sel2 + cand3 * sel3
    picked_t = (t0.unsqueeze(-1) * sel0 + t1.unsqueeze(-1) * sel1
                + t2.unsqueeze(-1) * sel2 + t3.unsqueeze(-1) * sel3)
    q = picked / torch.sqrt(picked_t)
    q = q * 0.5
    return q
+
+
def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.):
    """
    This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py

    Find camera translation that brings 3D joints S closest to the corresponding 2D joints_2d
    under a pinhole camera whose optical centre is the image centre.
    Input:
        S: (J, 3) 3D joint locations
        joints_2d: (J, 2) 2D joint locations
        joints_conf: (J,) per-joint confidence used as least-squares weight
        focal_length: focal length, shared by x and y
        img_size: image side length
    Returns:
        (3,) camera translation vector
    Raises:
        numpy.linalg.LinAlgError: if the normal equations are singular.
    """

    num_joints = S.shape[0]
    # focal length
    f = np.array([focal_length, focal_length])
    # optical center
    center = np.array([img_size / 2., img_size / 2.])

    # transformations: every per-joint quantity is interleaved as (x, y) pairs
    Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
    XY = np.reshape(S[:, 0:2], -1)
    O = np.tile(center, num_joints)
    F = np.tile(f, num_joints)
    weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
    uv = np.reshape(joints_2d, -1)

    # least squares
    Q = np.array([F * np.tile(np.array([1, 0]), num_joints),
                  F * np.tile(np.array([0, 1]), num_joints),
                  O - uv]).T
    c = (uv - O) * Z - F * XY

    # weighted least squares: scale each row directly rather than building
    # a dense diagonal weight matrix
    Q = weight2[:, None] * Q
    c = weight2 * c

    # square matrix
    A = np.dot(Q.T, Q)
    b = np.dot(Q.T, c)

    # solution
    trans = np.linalg.solve(A, b)

    return trans
+
+
def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.):
    """
    This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py

    Batched wrapper around ``estimate_translation_np``.
    Input:
        S: (B, 49, 3) 3D joint locations
        joints_2d: (B, 49, 3) 2D joint locations and confidence (last channel)
    Returns:
        (B, 3) float32 camera translation vectors on the device of S
    """
    device = S.device
    # Use only joints 25:49 (GT joints)
    joints_3d_np = S[:, 25:, :].cpu().numpy()
    joints_2d_np = joints_2d[:, 25:, :].cpu().numpy()
    confidences = joints_2d_np[:, :, -1]
    locations = joints_2d_np[:, :, :-1]

    batch_size = joints_3d_np.shape[0]
    trans = np.zeros((batch_size, 3), dtype=np.float32)
    # Solve one small least-squares problem per example in the batch.
    for index, (S_i, joints_i, conf_i) in enumerate(zip(joints_3d_np, locations, confidences)):
        trans[index] = estimate_translation_np(
            S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
    return torch.from_numpy(trans).to(device)
+
+
def rot6d_to_rotmat_spin(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """
    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt: normalise a1, remove its component from a2, normalise.
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    # dim must be explicit: without it torch.cross uses the first axis of
    # size 3, which is the batch axis whenever the batch size is 3.
    b3 = torch.cross(b1, b2, dim=1)
    return torch.stack((b1, b2, b3), dim=-1)
+
+
def rot6d_to_rotmat(x):
    """Convert 6D rotation representations to rotation matrices.

    The input is viewed as (-1, 3, 2); its two columns are orthonormalised
    and completed with their cross product. Returns a (B, 3, 3) tensor whose
    columns are the resulting basis vectors.
    """
    pairs = x.view(-1, 3, 2)
    first, second = pairs[:, :, 0], pairs[:, :, 1]

    # First basis vector: the normalised first column.
    b1 = F.normalize(first, dim=1, eps=1e-6)

    # Second basis vector: the second column minus its projection on b1.
    projection = torch.sum(b1 * second, dim=1, keepdim=True)
    b2 = F.normalize(second - projection * b1, dim=-1, eps=1e-6)

    # Third basis vector completes the frame.
    b3 = torch.cross(b1, b2, dim=1)
    return torch.stack([b1, b2, b3], dim=-1)
+
+
def rigid_transform_3D(A, B):
    """Estimate the similarity transform mapping point set A onto B.

    A, B: (n, 3) arrays of corresponding points.
    Returns (c, R, t) — scale, rotation matrix and translation — such that
    ``c * R @ a + t`` approximates the matching point of B.
    """
    n, _ = A.shape
    mean_A = np.mean(A, axis=0)
    mean_B = np.mean(B, axis=0)
    H = np.dot(np.transpose(A - mean_A), B - mean_B) / n
    U, s, V = np.linalg.svd(H)

    def _rotation():
        return np.dot(np.transpose(V), np.transpose(U))

    R = _rotation()
    if np.linalg.det(R) < 0:
        # Flip the last singular direction to avoid returning a reflection.
        s[-1] = -s[-1]
        V[2] = -V[2]
        R = _rotation()

    total_variance = np.var(A, axis=0).sum()
    c = 1 / total_variance * np.sum(s)

    t = -np.dot(c * R, np.transpose(mean_A)) + np.transpose(mean_B)
    return c, R, t
+
+
def rigid_align(A, B):
    """Return A after applying the similarity transform that best maps it onto B."""
    scale, rotation, translation = rigid_transform_3D(A, B)
    rotated = np.dot(scale * rotation, np.transpose(A))
    return np.transpose(rotated) + translation
+
def compute_error(output, target):
    """Return batch-averaged (MPJPE, MPVE) between prediction and target.

    ``output[0]`` and ``target`` are dicts with 'verts' (reshaped to
    (-1, 6890, 3)) and 'kp_3d' (reshaped to (-1, 17, 3)). Joints and vertices
    are both made relative to joint 0 of their own skeleton before the
    Euclidean errors are averaged. Results are CPU scalar tensors.
    """
    with torch.no_grad():
        prediction = output[0]
        pred_j3ds = prediction['kp_3d'].reshape(-1, 17, 3)
        target_j3ds = target['kp_3d'].reshape(-1, 17, 3)
        pred_root = pred_j3ds[:, :1, :]
        target_root = target_j3ds[:, :1, :]

        # mpve
        vert_diff = ((prediction['verts'].reshape(-1, 6890, 3) - pred_root)
                     - (target['verts'].reshape(-1, 6890, 3) - target_root))
        mpves = torch.sqrt((vert_diff ** 2).sum(dim=-1)).mean(dim=-1).cpu()

        # mpjpe
        joint_diff = (pred_j3ds - pred_root) - (target_j3ds - target_root)
        mpjpes = torch.sqrt((joint_diff ** 2).sum(dim=-1)).mean(dim=-1).cpu()
        return mpjpes.mean(), mpves.mean()
+
def compute_error_frames(output, target):
    """Return per-sample (MPJPE, MPVE) between prediction and target.

    Same computation as ``compute_error`` but without the final batch mean:
    each returned CPU tensor has one entry per sample after the inputs are
    reshaped to (-1, 17, 3) / (-1, 6890, 3).
    """
    with torch.no_grad():
        prediction = output[0]
        pred_j3ds = prediction['kp_3d'].reshape(-1, 17, 3)
        target_j3ds = target['kp_3d'].reshape(-1, 17, 3)
        pred_root = pred_j3ds[:, :1, :]
        target_root = target_j3ds[:, :1, :]

        # mpve
        vert_diff = ((prediction['verts'].reshape(-1, 6890, 3) - pred_root)
                     - (target['verts'].reshape(-1, 6890, 3) - target_root))
        mpves = torch.sqrt((vert_diff ** 2).sum(dim=-1)).mean(dim=-1).cpu()

        # mpjpe
        joint_diff = (pred_j3ds - pred_root) - (target_j3ds - target_root)
        mpjpes = torch.sqrt((joint_diff ** 2).sum(dim=-1)).mean(dim=-1).cpu()
        return mpjpes, mpves
+
def evaluate_mesh(results):
    """Compute mesh and joint errors from a dict of stacked numpy results.

    Reads 'verts', 'verts_gt' (reshaped to (-1, 6890, 3)) and 'kp_3d',
    'kp_3d_gt' (reshaped to (-1, 17, 3)). Everything is made relative to
    joint 0 first. Returns a dict with the mean of: 'mpve', 'mpjpe' and
    'pa_mpjpe' (14-joint subset), 'mpjpe_17j' and 'pa_mpjpe_17j' (all 17
    joints). The 'pa_' variants rigidly align each prediction to its target.
    """
    def _mean_point_error(a, b):
        # Mean Euclidean distance over the point axis -> shape (N, ).
        return np.mean(np.sqrt(np.square(a - b).sum(axis=2)), axis=1)

    pred_verts = results['verts'].reshape(-1, 6890, 3)
    target_verts = results['verts_gt'].reshape(-1, 6890, 3)
    pred_j3ds = results['kp_3d'].reshape(-1, 17, 3)
    target_j3ds = results['kp_3d_gt'].reshape(-1, 17, 3)
    pred_root = pred_j3ds[:, :1, :]
    target_root = target_j3ds[:, :1, :]

    # mpve
    mpve = np.mean(_mean_point_error(pred_verts - pred_root, target_verts - target_root))

    # mpjpe-17 & mpjpe-14
    h36m_17_to_14 = (1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16)
    pred_17j = pred_j3ds - pred_root
    target_17j = target_j3ds - target_root
    pred_14j = pred_17j[:, h36m_17_to_14, :].copy()
    target_14j = target_17j[:, h36m_17_to_14, :].copy()

    mpjpe = _mean_point_error(pred_14j, target_14j)
    mpjpe_17j = _mean_point_error(pred_17j, target_17j)

    # Procrustes-aligned variants: align every sample independently.
    aligned_14j = np.array([rigid_align(p, t) for p, t in zip(pred_14j, target_14j)])
    aligned_17j = np.array([rigid_align(p, t) for p, t in zip(pred_17j, target_17j)])
    pa_mpjpe = _mean_point_error(aligned_14j, target_14j)
    pa_mpjpe_17j = _mean_point_error(aligned_17j, target_17j)

    return {
        'mpve': mpve.mean(),
        'mpjpe': mpjpe.mean(),
        'pa_mpjpe': pa_mpjpe.mean(),
        'mpjpe_17j': mpjpe_17j.mean(),
        'pa_mpjpe_17j': pa_mpjpe_17j.mean(),
    }
+
+
def rectify_pose(pose):
    """
    Rectify "upside down" people in global coord

    The first three entries (root axis-angle) are post-multiplied by a
    rotation of pi about the x axis; the rest of the pose is unchanged.

    Args:
        pose (72,): Pose.

    Returns:
        Rotated pose (a copy; the input is not modified).
    """
    rectified = pose.copy()
    half_turn_x, _ = cv2.Rodrigues(np.array([np.pi, 0, 0]))
    root_rotation, _ = cv2.Rodrigues(rectified[:3])
    corrected_root, _ = cv2.Rodrigues(root_rotation.dot(half_turn_x))
    rectified[:3] = corrected_root.reshape(3)
    return rectified
+
def flip_thetas(thetas):
    """Flip thetas.

    Parameters
    ----------
    thetas : numpy.ndarray
        Joints in shape (F, num_thetas, 3)

    Returns
    -------
    numpy.ndarray
        Flipped thetas with shape (F, num_thetas, 3); the input is not modified.

    """
    # Joint pairs which defines the pairs of joint to be swapped when the image is flipped horizontally.
    theta_pairs = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23))
    first = [pair[0] for pair in theta_pairs]
    second = [pair[1] for pair in theta_pairs]

    thetas_flip = thetas.copy()
    # reflect horizontally: negate the 2nd and 3rd components
    thetas_flip[:, :, 1:3] = -1 * thetas_flip[:, :, 1:3]
    # change left-right parts in one shot; the indexed right-hand side is a copy
    thetas_flip[:, first + second, :] = thetas_flip[:, second + first, :]
    return thetas_flip
+
def flip_thetas_batch(thetas):
    """Flip thetas in batch.

    Works for both ``numpy.ndarray`` and ``torch.Tensor`` inputs.

    Parameters
    ----------
    thetas : numpy.ndarray or torch.Tensor
        Joints in shape (N, F, num_thetas*3) with num_thetas == 24

    Returns
    -------
    numpy.ndarray or torch.Tensor
        Flipped thetas with shape (N, F, num_thetas*3); the input is not modified.

    """
    # Joint pairs which defines the pairs of joint to be swapped when the image is flipped horizontally.
    theta_pairs = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23))
    first = [pair[0] for pair in theta_pairs]
    second = [pair[1] for pair in theta_pairs]

    thetas_flip = copy.deepcopy(thetas).reshape(*thetas.shape[:2], 24, 3)
    # reflect horizontally
    thetas_flip[:, :, :, 1] = -1 * thetas_flip[:, :, :, 1]
    thetas_flip[:, :, :, 2] = -1 * thetas_flip[:, :, :, 2]
    # change left-right parts; indexing with a list yields a copy in both
    # numpy and torch, so no library-specific .copy()/.clone() is needed
    thetas_flip[:, :, first + second, :] = thetas_flip[:, :, second + first, :]

    return thetas_flip.reshape(*thetas.shape[:2], -1)
+
+# def smpl_aa_to_ortho6d(smpl_aa):
+# # [...,72] -> [...,144]
+# rot_aa = smpl_aa.reshape([-1,24,3])
+# rotmat = axis_angle_to_matrix(rot_aa)
+# rot6d = matrix_to_rotation_6d(rotmat)
+# rot6d = rot6d.reshape(-1,24*6)
+# return rot6d
\ No newline at end of file
diff --git a/lib/utils/utils_smpl.py b/lib/utils/utils_smpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2215dd8517a1bdbfc1acf34c3ce13befc486fb67
--- /dev/null
+++ b/lib/utils/utils_smpl.py
@@ -0,0 +1,88 @@
+# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/models/hmr.py
+# Adhere to their licence to use this script
+
+import torch
+import numpy as np
+import os.path as osp
+from smplx import SMPL as _SMPL
+from smplx.utils import ModelOutput, SMPLOutput
+from smplx.lbs import vertices2joints
+
+
+# Map joints to SMPL joints
+JOINT_MAP = {
+ 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
+ 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
+ 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
+ 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
+ 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
+ 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
+ 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
+ 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
+ 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
+ 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
+ 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
+ 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
+ 'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
+ 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
+ 'Spine (H36M)': 51, 'Jaw (H36M)': 52,
+ 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
+ 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
+}
+JOINT_NAMES = [
+ 'OP Nose', 'OP Neck', 'OP RShoulder',
+ 'OP RElbow', 'OP RWrist', 'OP LShoulder',
+ 'OP LElbow', 'OP LWrist', 'OP MidHip',
+ 'OP RHip', 'OP RKnee', 'OP RAnkle',
+ 'OP LHip', 'OP LKnee', 'OP LAnkle',
+ 'OP REye', 'OP LEye', 'OP REar',
+ 'OP LEar', 'OP LBigToe', 'OP LSmallToe',
+ 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
+ 'Right Ankle', 'Right Knee', 'Right Hip',
+ 'Left Hip', 'Left Knee', 'Left Ankle',
+ 'Right Wrist', 'Right Elbow', 'Right Shoulder',
+ 'Left Shoulder', 'Left Elbow', 'Left Wrist',
+ 'Neck (LSP)', 'Top of Head (LSP)',
+ 'Pelvis (MPII)', 'Thorax (MPII)',
+ 'Spine (H36M)', 'Jaw (H36M)',
+ 'Head (H36M)', 'Nose', 'Left Eye',
+ 'Right Eye', 'Left Ear', 'Right Ear'
+]
+
+JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
+SMPL_MODEL_DIR = 'data/mesh'
+H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
+H36M_TO_J14 = H36M_TO_J17[:14]
+
+
class SMPL(_SMPL):
    """ Extension of the official SMPL implementation to support more joints """

    def __init__(self, *args, **kwargs):
        """Build the base model and load the extra joint regressors.

        The model directory is taken from the first positional argument or,
        if none is given, from the ``model_path`` keyword. It must contain
        'J_regressor_extra.npy' and 'J_regressor_h36m_correct.npy'.
        """
        super(SMPL, self).__init__(*args, **kwargs)
        # Accept the model path positionally or by keyword.
        model_dir = args[0] if args else kwargs['model_path']
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        # Only the path is stored here; the file is not read by this class.
        self.smpl_mean_params = osp.join(model_dir, 'smpl_mean_params.npz')
        J_regressor_extra = np.load(osp.join(model_dir, 'J_regressor_extra.npy'))
        self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
        J_regressor_h36m = np.load(osp.join(model_dir, 'J_regressor_h36m_correct.npy'))
        self.register_buffer('J_regressor_h36m', torch.tensor(J_regressor_h36m, dtype=torch.float32))
        # Plain attribute (not a buffer): index tensor that reorders the
        # concatenated joints into JOINT_NAMES order.
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run the base model and return an SMPLOutput with remapped joints.

        The regressed extra joints are appended to the base model's joints
        and the result is reordered with ``joint_map``.
        """
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
        joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
        joints = joints[:, self.joint_map, :]
        output = SMPLOutput(vertices=smpl_output.vertices,
                            global_orient=smpl_output.global_orient,
                            body_pose=smpl_output.body_pose,
                            joints=joints,
                            betas=smpl_output.betas,
                            full_pose=smpl_output.full_pose)
        return output
+
+
def get_smpl_faces():
    """Instantiate the SMPL model from SMPL_MODEL_DIR and return its ``faces``."""
    model = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False)
    faces = model.faces
    return faces
\ No newline at end of file
diff --git a/lib/utils/vismo.py b/lib/utils/vismo.py
new file mode 100644
index 0000000000000000000000000000000000000000..92b290dc7dd7075f9914bc5578b1a5c5b120ddb1
--- /dev/null
+++ b/lib/utils/vismo.py
@@ -0,0 +1,345 @@
+import numpy as np
+import os
+import cv2
+import math
+import copy
+import imageio
+import io
+from tqdm import tqdm
+from PIL import Image
+from lib.utils.tools import ensure_dir
+import matplotlib
+import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D
+from lib.utils.utils_smpl import *
+import ipdb
+
def render_and_save(motion_input, save_path, keep_imgs=False, fps=25, color="#F96706#FB8D43#FDB381", with_conf=False, draw_face=False):
    """Render a motion to a video at ``save_path``, choosing the renderer by shape.

    If the last axis has size 2 or 3 the input is treated as (T, J, D) and
    transposed to (J, D, T). Then: 2 channels (or ``with_conf``) -> 2D
    skeleton video; 6890 points -> mesh video; otherwise -> 3D skeleton
    video. The input array is not modified.
    """
    ensure_dir(os.path.dirname(save_path))
    motion = copy.deepcopy(motion_input)
    if motion.shape[-1] in (2, 3):
        motion = np.transpose(motion, (1, 2, 0))  # (T,17,D) -> (17,D,T)

    is_2d = motion.shape[1] == 2 or with_conf
    if is_2d:
        colors = hex2rgb(color)
        if with_conf:
            motion_full = motion
        else:
            # Pad a third channel of ones so every joint counts as visible.
            n_joints, _, n_frames = motion.shape
            motion_full = np.ones([n_joints, 3, n_frames])
            motion_full[:, :2, :] = motion
        motion_full[:, :2, :] = pixel2world_vis_motion(motion_full[:, :2, :])
        motion2video(motion_full, save_path=save_path, colors=colors, fps=fps)
        return

    if motion.shape[0] == 6890:
        motion2video_mesh(motion, save_path=save_path, keep_imgs=keep_imgs, fps=fps, draw_face=draw_face)
        return

    motion_world = pixel2world_vis_motion(motion, dim=3)
    motion2video_3d(motion_world, save_path=save_path, keep_imgs=keep_imgs, fps=fps)
+
def pixel2world_vis(pose):
    """Map a (17, 2) pose from [-1, 1] coordinates to a 512-pixel canvas."""
    shifted = pose + [1, 1]
    return shifted * 512 / 2
+
def pixel2world_vis_motion(motion, dim=2, is_tensor=False):
    """Map a (J, dim, N) motion from [-1, 1] coordinates to a 512-pixel canvas.

    x and y are shifted by 1 before scaling by 256; with ``dim != 2`` a third
    channel is scaled without a shift. ``is_tensor`` converts the offset to
    a torch tensor so it can be added to tensor input.
    """
    n_frames = motion.shape[-1]
    n_channels = 2 if dim == 2 else 3
    offset = np.ones([n_channels, n_frames]).astype(np.float32)
    if dim != 2:
        offset[2, :] = 0
    if is_tensor:
        offset = torch.tensor(offset)
    return (motion + offset) * 512 / 2
+
def vis_data_batch(data_input, data_label, n_render=10, save_path='doodle/vis_train_data/'):
    '''
    Render the first ``n_render`` input/label pairs as videos under ``save_path``.

    data_input: [N,T,17,2/3]  (only the first two channels are rendered)
    data_label: [N,T,17,3]
    Files are written as ``input_<i>.mp4`` and ``gt_<i>.mp4``.
    '''
    # `pathlib` is not imported in this module, so the previous
    # pathlib.Path(...).mkdir call raised NameError; os.makedirs is equivalent.
    os.makedirs(save_path, exist_ok=True)
    for i in range(min(len(data_input), n_render)):
        render_and_save(data_input[i][:, :, :2], '%s/input_%d.mp4' % (save_path, i))
        render_and_save(data_label[i], '%s/gt_%d.mp4' % (save_path, i))
+
def get_img_from_fig(fig, dpi=120):
    """Rasterise a matplotlib figure and return it as an RGBA image array.

    The figure is saved as a tightly cropped PNG into memory, decoded with
    OpenCV and converted from BGR to RGBA.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight", pad_inches=0)
        buf.seek(0)
        encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGBA)
+
def rgb2rgba(color):
    """Return the first three components of ``color`` plus an alpha of 255."""
    return tuple(color[channel] for channel in range(3)) + (255,)
+
def hex2rgb(hex, number_of_colors=3):
    """Parse concatenated '#RRGGBB' colours into a list of [r, g, b] lists.

    ``hex`` is consumed left to right: leading '#' characters are stripped,
    six hex digits are read, and the process repeats ``number_of_colors``
    times.
    """
    remaining = hex
    rgb = []
    for _ in range(number_of_colors):
        remaining = remaining.lstrip('#')
        digits, remaining = remaining[0:6], remaining[6:]
        rgb.append([int(digits[start:start + 2], 16) for start in (0, 2, 4)])
    return rgb
+
def joints2image(joints_position, colors, transparency=False, H=1000, W=1000, nr_joints=49, imtype=np.uint8, grayscale=False, bg_color=(255, 255, 255)):
    """Draw one 2D skeleton on a canvas.

    joints_position: [J, 2] or [J, 3] pixel coordinates; when a third value
        is present it is read as a confidence, and joints/limbs with
        confidence 0 are not drawn.
    colors: three colours used for the left / middle / right body parts.
    transparency: draw on a 4-channel zero canvas instead of ``bg_color``.
    nr_joints: ignored; the layout is always derived from
        ``joints_position.shape[0]`` (kept for backward compatibility).
    Returns [canvas, canvas_cropped]; the cropped version keeps only the
    column range reported by ``bounding_box``.
    Raises ValueError if the joint count is not 49, 17 or 15.
    """
    nr_joints = joints_position.shape[0]
    if nr_joints not in (49, 17, 15):
        raise ValueError("Only support number of joints be 49 or 17 or 15")

    # Left / middle / right colours are shared by every layout.
    L = rgb2rgba(colors[0]) if transparency else colors[0]
    M = rgb2rgba(colors[1]) if transparency else colors[1]
    R = rgb2rgba(colors[2]) if transparency else colors[2]

    if nr_joints == 49:  # full joints(49): basic(15) + eyes(2) + toes(2) + hands(30)
        limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7],
                   [8, 9], [8, 13], [9, 10], [10, 11], [11, 12], [13, 14], [14, 15], [15, 16],
                   ]  # [0, 17], [0, 18]] #ignore eyes
        colors_joints = [M, M, L, L, L, R, R,
                         R, M, L, L, L, L, R, R, R,
                         R, R, L] + [L] * 15 + [R] * 15
        colors_limbs = [M, L, R, M, L, L, R,
                        R, L, R, L, L, L, R, R, R,
                        R, R]
    elif nr_joints == 15:  # basic joints(15) + (eyes(2))
        limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7],
                   [8, 9], [8, 12], [9, 10], [10, 11], [12, 13], [13, 14]]
        # [0, 15], [0, 16] two eyes are not drawn
        colors_joints = [M, M, L, L, L, R, R,
                         R, M, L, L, L, R, R, R]
        colors_limbs = [M, L, R, M, L, L, R,
                        R, L, R, L, L, R, R]
    else:
        # H36M, 0: 'root', 1: 'rhip', 2: 'rkne', 3: 'rank', 4: 'lhip',
        # 5: 'lkne', 6: 'lank', 7: 'belly', 8: 'neck', 9: 'nose',
        # 10: 'head', 11: 'lsho', 12: 'lelb', 13: 'lwri', 14: 'rsho',
        # 15: 'relb', 16: 'rwri'
        limbSeq = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], [8, 9], [8, 11], [8, 14], [9, 10], [11, 12], [12, 13], [14, 15], [15, 16]]
        colors_joints = [M, R, R, R, L, L, L, M, M, M, M, L, L, L, R, R, R]
        colors_limbs = [R, R, R, L, L, L, M, M, M, L, R, M, L, L, R, R]

    if transparency:
        canvas = np.zeros(shape=(H, W, 4))
    else:
        canvas = np.ones(shape=(H, W, 3)) * np.array(bg_color).reshape([1, 1, 3])

    joints_radius = 7
    for i, joint_color in enumerate(colors_joints):
        if i in (17, 18):
            continue
        radius = 2 if i > 18 else joints_radius
        joint = joints_position[i]
        # If there is confidence, skip joints with zero confidence
        if len(joint) == 3 and joint[2] == 0:
            continue
        cv2.circle(canvas, (int(joint[0]), int(joint[1])), radius, joint_color, thickness=-1)

    stickwidth = 2
    for limb, limb_color in zip(limbSeq, colors_limbs):
        point1 = joints_position[limb[0]]
        point2 = joints_position[limb[1]]
        # If there is confidence, skip limbs touching a zero-confidence joint
        if len(point1) == 3 and min(point1[2], point2[2]) == 0:
            continue
        cur_canvas = canvas.copy()
        X = [point1[1], point2[1]]
        Y = [point1[0], point2[0]]
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
        # Each limb is a thin filled ellipse blended onto the canvas.
        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(alpha), 0, 360, 1)
        cv2.fillConvexPoly(cur_canvas, polygon, limb_color)
        canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

    # Only the bounding box of the finished canvas matters, so compute it once.
    bb = bounding_box(canvas)
    canvas_cropped = canvas[:, bb[2]:bb[3], :]

    canvas = canvas.astype(imtype)
    canvas_cropped = canvas_cropped.astype(imtype)
    if grayscale:
        conversion = cv2.COLOR_RGBA2GRAY if transparency else cv2.COLOR_RGB2GRAY
        canvas = cv2.cvtColor(canvas, conversion)
        canvas_cropped = cv2.cvtColor(canvas_cropped, conversion)
    return [canvas, canvas_cropped]
+
+
def motion2video(motion, save_path, colors, h=512, w=512, bg_color=(255, 255, 255), transparency=False, motion_tgt=None, fps=25, save_frame=False, grayscale=False, show_progress=True, as_array=False):
    """Render a 2D skeleton motion frame by frame.

    motion: (J, D, T) joint coordinates (frames on the last axis).
    motion_tgt: optional second motion, blended on top with weight 0.3.
    save_frame: also save each cropped frame as PNG into
        ``save_path[:-4] + '-frames'``.
    as_array: collect frames into a (T, h, w, 3) array and return it instead
        of writing a video; otherwise a video is written to ``save_path``
        and ``None`` is returned.
    """
    nr_joints = motion.shape[0]
    vlen = motion.shape[-1]

    out_array = np.zeros([vlen, h, w, 3]) if as_array else None
    videowriter = None if as_array else imageio.get_writer(save_path, fps=fps)

    try:
        if save_frame:
            frames_dir = save_path[:-4] + '-frames'
            ensure_dir(frames_dir)

        iterator = range(vlen)
        if show_progress:
            iterator = tqdm(iterator)
        for i in iterator:
            [img, img_cropped] = joints2image(motion[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale)
            if motion_tgt is not None:
                img_tgt = joints2image(motion_tgt[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale)[0]
                # Blend once; the cropped frame is a column crop of the same blend.
                img = cv2.addWeighted(img_tgt, 0.3, img, 0.7, 0)
                bb = bounding_box(img)
                img_cropped = img[:, bb[2]:bb[3], :]
            if save_frame:
                save_image(img_cropped, os.path.join(frames_dir, "%04d.png" % i))
            if as_array:
                out_array[i] = img
            else:
                videowriter.append_data(img)
    finally:
        # Close the writer even if rendering fails part-way through.
        if videowriter is not None:
            videowriter.close()

    return out_array
+
def motion2video_3d(motion, save_path, fps=25, keep_imgs = False):
    """Render a 3D skeleton motion to a video at ``save_path``.

    motion: (17,3,N) joint coordinates, frames on the last axis.
    keep_imgs: accepted for interface compatibility; not used.
    """
    joint_pairs = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], [8, 9], [8, 11], [8, 14], [9, 10], [11, 12], [12, 13], [14, 15], [15, 16]]
    joint_pairs_left = [[8, 11], [11, 12], [12, 13], [0, 4], [4, 5], [5, 6]]
    joint_pairs_right = [[8, 14], [14, 15], [15, 16], [0, 1], [1, 2], [2, 3]]

    color_mid = "#00457E"
    color_left = "#02315E"
    color_right = "#2F70AF"

    vlen = motion.shape[-1]
    videowriter = imageio.get_writer(save_path, fps=fps)
    try:
        for f in tqdm(range(vlen)):
            j3d = motion[:, :, f]
            fig = plt.figure(0, figsize=(10, 10))
            ax = plt.axes(projection="3d")
            ax.set_xlim(-512, 0)
            ax.set_ylim(-256, 256)
            ax.set_zlim(-512, 0)
            ax.view_init(elev=12., azim=80)
            plt.tick_params(left=False, right=False, labelleft=False,
                            labelbottom=False, bottom=False)
            for limb in joint_pairs:
                xs, ys, zs = [np.array([j3d[limb[0], j], j3d[limb[1], j]]) for j in range(3)]
                if limb in joint_pairs_left:
                    limb_color = color_left
                elif limb in joint_pairs_right:
                    limb_color = color_right
                else:
                    limb_color = color_mid
                # axis transformation for visualization
                ax.plot(-xs, -zs, -ys, color=limb_color, lw=3, marker='o', markerfacecolor='w', markersize=3, markeredgewidth=2)

            frame_vis = get_img_from_fig(fig)
            videowriter.append_data(frame_vis)
            # Without this, each frame adds another axes to figure 0 and the
            # figure keeps growing for the whole clip.
            plt.close(fig)
    finally:
        videowriter.close()
+
def motion2video_mesh(motion, save_path, fps=25, keep_imgs = False, draw_face=True):
    """Render a mesh (or 17-joint skeleton) motion to a video at ``save_path``.

    motion: (P, 3, N) point coordinates, frames on the last axis. With
        P == 17 a skeleton is drawn; otherwise the points are drawn as a
        triangulated surface (``draw_face``) or as a scatter plot.
    keep_imgs: accepted for interface compatibility; not used.
    """
    vlen = motion.shape[-1]
    draw_skele = (motion.shape[0] == 17)
    joint_pairs = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], [8, 9], [8, 11], [8, 14], [9, 10], [11, 12], [12, 13], [14, 15], [15, 16]]

    # One cubic view volume for the whole clip so the camera stays fixed.
    X, Y, Z = motion[:, 0], motion[:, 1], motion[:, 2]
    max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max() / 2.0
    mid_x = (X.max() + X.min()) * 0.5
    mid_y = (Y.max() + Y.min()) * 0.5
    mid_z = (Z.max() + Z.min()) * 0.5

    videowriter = imageio.get_writer(save_path, fps=fps)
    try:
        # The face list is only needed for the surface branch; skip loading
        # the SMPL model otherwise.
        smpl_faces = get_smpl_faces() if (draw_face and not draw_skele) else None

        for f in tqdm(range(vlen)):
            j3d = motion[:, :, f]
            fig = plt.figure(0, figsize=(8, 8))
            ax = plt.axes(projection="3d", proj_type='ortho')
            ax.set_xlim(mid_x - max_range, mid_x + max_range)
            ax.set_ylim(mid_y - max_range, mid_y + max_range)
            ax.set_zlim(mid_z - max_range, mid_z + max_range)
            ax.view_init(elev=-90, azim=-90)
            # Strip every margin, tick and axis so only the body is visible.
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.axis('off')
            plt.xticks([])
            plt.yticks([])

            if draw_skele:
                for limb in joint_pairs:
                    xs, ys, zs = [np.array([j3d[limb[0], j], j3d[limb[1], j]]) for j in range(3)]
                    # axis transformation for visualization
                    ax.plot(-xs, -zs, -ys, c=[0, 0, 0], lw=3, marker='o', markerfacecolor='w', markersize=3, markeredgewidth=2)
            elif draw_face:
                ax.plot_trisurf(j3d[:, 0], j3d[:, 1], triangles=smpl_faces, Z=j3d[:, 2], color=(166 / 255.0, 188 / 255.0, 218 / 255.0, 0.9))
            else:
                ax.scatter(j3d[:, 0], j3d[:, 1], j3d[:, 2], s=3, c='w', edgecolors='grey')
            frame_vis = get_img_from_fig(fig, dpi=128)
            videowriter.append_data(frame_vis)
            # Close the figure so axes do not pile up on figure 0 across frames.
            plt.close(fig)
    finally:
        videowriter.close()
+
def save_image(image_numpy, image_path):
    """Write the array ``image_numpy`` to ``image_path`` via PIL."""
    Image.fromarray(image_numpy).save(image_path)
+
def bounding_box(img):
    """Return (row_min, row_max, col_min, col_max) of the non-zero entries of ``img``."""
    nonzero = np.nonzero(img != 0)
    rows = nonzero[0]
    cols = nonzero[1]
    return np.min(rows), np.max(rows), np.min(cols), np.max(cols)
diff --git a/params/d2c_params.pkl b/params/d2c_params.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..0a45d3604e9852cba80e9acd630dde2424d6187b
--- /dev/null
+++ b/params/d2c_params.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b02023c3fc660f4808c735e2f8a9eae1206a411f1ad7e3429d33719da1cd0d1
+size 184
diff --git a/params/synthetic_noise.pth b/params/synthetic_noise.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f4d1739dbce24feeef8c70c0a305b760ca0df605
--- /dev/null
+++ b/params/synthetic_noise.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c801dfb859b08cf2ed96012176b0dcc7af2358d1a5d18a7c72b6e944416297b
+size 1997
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2f2f952dc8fe17f327de4cf2877a98e88e4e90c4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+tensorboardX
+tqdm
+easydict
+prettytable
+chumpy
+opencv-python
+imageio-ffmpeg
+matplotlib==3.1.1
+roma
+ipdb
+pytorch-metric-learning # For one-hot action recognition
+smplx[all] # For mesh recovery
diff --git a/tools/compress_amass.py b/tools/compress_amass.py
new file mode 100644
index 0000000000000000000000000000000000000000..3acc72c4121ea4a3fb8a2ea3e09d9efd0b1ee296
--- /dev/null
+++ b/tools/compress_amass.py
@@ -0,0 +1,62 @@
+import numpy as np
+import os
+import pickle
+
+# Directory that is scanned recursively for input files by traverse().
+raw_dir = './data/AMASS/amass_202203/'
+# Directory that receives one resampled archive per input sequence.
+processed_dir = './data/AMASS/amass_fps60'
+os.makedirs(processed_dir, exist_ok=True)
+
+# Filled by traverse() with every non-directory path found under raw_dir.
+files = []
+# Running total of frames kept after subsampling.
+length = 0
+# Frame rate the sequences are subsampled towards.
+target_fps = 60
+
+def traverse(f):
+    """Recursively collect every non-directory path below ``f``.
+
+    Each path found is appended to the module-level ``files`` list; the
+    function itself returns nothing.
+    """
+    for entry in os.listdir(f):
+        candidate = os.path.join(f, entry)
+        if os.path.isdir(candidate):
+            # Directory: descend into it.
+            traverse(candidate)
+        else:
+            # Anything that is not a directory is recorded as a file.
+            files.append(candidate)
+
+traverse(raw_dir)
+
+print('files:', len(files))
+
+fnames = []
+all_motions = []
+
+# Subsample every readable sequence towards ``target_fps``, log one CSV row
+# per kept sequence and store the result both per file and in one big list.
+with open('data/AMASS/fps.csv', 'w') as f:
+    print('fname_new, len_ori, fps, len_new', file=f)
+    for fname in sorted(files):
+        try:
+            # Use the archive as a context manager so its file handle is
+            # released as soon as the arrays have been copied out.
+            with np.load(fname) as raw_x:
+                x = dict(raw_x)
+            fps = x['mocap_framerate']
+            len_ori = len(x['trans'])
+            sample_stride = round(fps / target_fps)
+            if sample_stride < 1:
+                # A stride of 0 is not a valid slice step; report it
+                # explicitly instead of failing inside the slicing below.
+                raise ValueError('frame rate %s is too low for a %d fps target' % (fps, target_fps))
+            x['mocap_framerate'] = target_fps
+            x['trans'] = x['trans'][::sample_stride]
+            x['dmpls'] = x['dmpls'][::sample_stride]
+            x['poses'] = x['poses'][::sample_stride]
+            # Drop the first two path components and flatten the rest into
+            # one file name.
+            fname_new = '_'.join(fname.split('/')[2:])
+            len_new = len(x['trans'])
+
+            length += len_new
+            print(fname_new, ',', len_ori, ',', fps, ',', len_new, file=f)
+            fnames.append(fname_new)
+            all_motions.append(x)
+            # NOTE(review): the dict is passed positionally, so NumPy stores
+            # it as a single object entry rather than one array per key --
+            # confirm this layout is what the readers expect.
+            np.savez('%s/%s' % (processed_dir, fname_new), x)
+        except Exception as err:
+            # Best effort: inputs that cannot be read or lack the expected
+            # keys are skipped, but no longer silently.
+            print('WARNING: skipping %s (%s: %s)' % (fname, type(err).__name__, err))
+
+print('poseFrame:', length)
+print('motions:', len(fnames))
+
+with open("data/AMASS/all_motions_fps%d.pkl" % target_fps, "wb") as myprofile:
+    pickle.dump(all_motions, myprofile)
+
diff --git a/tools/convert_amass.py b/tools/convert_amass.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c9d0a8de4bd4c2d7273f9824221aa1e520dfa1
--- /dev/null
+++ b/tools/convert_amass.py
@@ -0,0 +1,67 @@
+import os
+import sys
+import copy
+import pickle
+import ipdb
+import torch
+import numpy as np
+sys.path.insert(0, os.getcwd())
+from lib.utils.utils_data import split_clips
+from tqdm import tqdm
+
+# NOTE: pickle.load can execute arbitrary code; only run this on a file
+# produced by the project's own preprocessing step.
+with open('data/AMASS/amass_joints_h36m_60.pkl', 'rb') as joints_file:
+    joints_all = pickle.load(joints_file)
+
+joints_cam = []
+vid_list = []
+vid_len_list = []
+scale_factor = 0.298
+# Applied by right-multiplication to row vectors: (x, y, z) -> (x, -z, y).
+real2cam = np.array([[1, 0, 0],
+                     [0, 0, 1],
+                     [0, -1, 0]], dtype=np.float32)
+
+for i, item in enumerate(joints_all):  # each item: (17,N,3)
+    item = item.astype(np.float32)
+    vid_len = item.shape[1]
+    vid_len_list.append(vid_len)
+    # One sequence id per frame, so that clips can be cut per sequence.
+    vid_list.extend([i] * vid_len)
+    item = np.transpose(item, (1, 0, 2))  # (17,N,3) -> (N,17,3)
+    motion_cam = item @ real2cam
+    motion_cam *= scale_factor
+    joints_cam.append(motion_cam)
+
+joints_cam_all = np.vstack(joints_cam)
+# split_clips is imported directly from lib.utils.utils_data; this module
+# has no ``datareader`` object to call it on.
+split_id = split_clips(vid_list, n_frames=243, data_stride=81)
+print(joints_cam_all.shape)  # (N,17,3)
+
+max_x, min_x = np.max(joints_cam_all[:, :, 0]), np.min(joints_cam_all[:, :, 0])
+max_y, min_y = np.max(joints_cam_all[:, :, 1]), np.min(joints_cam_all[:, :, 1])
+max_z, min_z = np.max(joints_cam_all[:, :, 2]), np.min(joints_cam_all[:, :, 2])
+print(max_x, min_x)
+print(max_y, min_y)
+print(max_z, min_z)
+
+joints_cam_clip = joints_cam_all[split_id]
+# Presumably (num_clips, 243, 17, 3) given n_frames=243 -- TODO confirm
+# against split_clips.
+print(joints_cam_clip.shape)
+
+root_path = "data/motion3d/MB3D_f243s81/AMASS"
+subset_name = "train"
+save_path = os.path.join(root_path, subset_name)
+os.makedirs(save_path, exist_ok=True)
+
+# One pickle per clip; this data has no input part, only the label.
+num_clips = len(joints_cam_clip)
+for i in tqdm(range(num_clips)):
+    motion = joints_cam_clip[i]
+    data_dict = {
+        "data_input": None,
+        "data_label": motion
+    }
+    with open(os.path.join(save_path, "%08d.pkl" % i), "wb") as myprofile:
+        pickle.dump(data_dict, myprofile)
+
+
diff --git a/tools/convert_h36m.py b/tools/convert_h36m.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2997ea40266d41be5abcac29a1b44a846273af0
--- /dev/null
+++ b/tools/convert_h36m.py
@@ -0,0 +1,38 @@
+import os
+import sys
+import pickle
+import numpy as np
+import random
+sys.path.insert(0, os.getcwd())
+from lib.utils.tools import read_pkl
+from lib.data.datareader_h36m import DataReaderH36M
+from tqdm import tqdm
+
+
+def save_clips(subset_name, root_path, train_data, train_labels):
+    """Dump every (input, label) clip pair to its own pickle file.
+
+    Files are written to ``<root_path>/<subset_name>/<index:08d>.pkl``; each
+    holds a dict with the keys ``data_input`` and ``data_label``.  One file
+    is written per element of ``train_data``.
+    """
+    out_dir = os.path.join(root_path, subset_name)
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+    num_clips = len(train_data)
+    for clip_idx in tqdm(range(num_clips)):
+        payload = {
+            "data_input": train_data[clip_idx],
+            "data_label": train_labels[clip_idx],
+        }
+        clip_path = os.path.join(out_dir, "%08d.pkl" % clip_idx)
+        with open(clip_path, "wb") as clip_file:
+            pickle.dump(payload, clip_file)
+
+# Slice the dataset into clips and write the train and test subsets to disk.
+datareader = DataReaderH36M(
+    n_frames=243,
+    sample_stride=1,
+    data_stride_train=81,
+    data_stride_test=243,
+    dt_file='h36m_sh_conf_cam_source_final.pkl',
+    dt_root='data/motion3d/',
+)
+train_data, test_data, train_labels, test_labels = datareader.get_sliced_data()
+print(train_data.shape, test_data.shape)
+assert len(train_data) == len(train_labels)
+assert len(test_data) == len(test_labels)
+
+root_path = "data/motion3d/MB3D_f243s81/H36M-SH"
+if not os.path.exists(root_path):
+    os.makedirs(root_path)
+
+for subset_name, clips, labels in (
+    ("train", train_data, train_labels),
+    ("test", test_data, test_labels),
+):
+    save_clips(subset_name, root_path, clips, labels)
+
diff --git a/tools/convert_insta.py b/tools/convert_insta.py
new file mode 100644
index 0000000000000000000000000000000000000000..9135c2c7e20e37a0e4fb2ae4edcfd62105068c8b
--- /dev/null
+++ b/tools/convert_insta.py
@@ -0,0 +1,79 @@
+from __future__ import print_function
+import os
+import sys
+import random
+import copy
+import argparse
+import math
+import pickle
+import json
+import glob
+import numpy as np
+sys.path.insert(0, os.getcwd())
+from lib.utils.utils_data import crop_scale
+
+
+def parse_args():
+    """Parse the command line, echo the parsed parameters and return them.
+
+    ``--name_action`` is mandatory: the script joins it onto the dataset
+    root, which fails with a TypeError when the value is missing.
+
+    Returns:
+        The ``argparse.Namespace`` holding ``name_action``.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--name_action', type=str, required=True)
+    args = parser.parse_args()
+    print("\nParameters:")
+    for attr, value in sorted(vars(args).items()):
+        print("\t{}={}".format(attr.upper(), value))
+    return args
+
+def json2pose(json_dict):
+    """Convert one per-frame joint dictionary into a ``(17, 3)`` array.
+
+    Row ``i`` holds ``(x, y, logits)`` of the joint named at position ``i``
+    of the table below.  'Belly' and 'Head' are never read from
+    ``json_dict``; their rows stay all-zero.
+    """
+    joint_names = (
+        'Hip',
+        'R Hip',
+        'R Knee',
+        'R Ankle',
+        'L Hip',
+        'L Knee',
+        'L Ankle',
+        'Belly',
+        'Neck',
+        'Nose',
+        'Head',
+        'L Shoulder',
+        'L Elbow',
+        'L Wrist',
+        'R Shoulder',
+        'R Elbow',
+        'R Wrist',
+    )
+    zero_filled = ('Belly', 'Head')
+    pose_h36m = np.zeros([17, 3])
+    for row, joint in enumerate(joint_names):
+        if joint in zero_filled:
+            # Not looked up in the input; the row keeps its zeros.
+            continue
+        entry = json_dict[joint]
+        pose_h36m[row] = entry['x'], entry['y'], entry['logits']
+    return pose_h36m
+
+def load_motion(json_path):
+    """Read the JSON file at ``json_path`` and convert it with ``json2pose``.
+
+    The file is opened in a ``with`` block so its handle is closed right
+    after parsing; this function is called once per frame file.
+    """
+    with open(json_path, 'r') as json_file:
+        json_dict = json.load(json_file)
+    pose_h36m = json2pose(json_dict)
+    return pose_h36m
+
+
+args = parse_args()
+dataset_root = 'data/Motion2d/InstaVariety/InstaVariety_tracks/'
+action_motions = []
+dir_action = os.path.join(dataset_root, args.name_action)
+for name_vid in sorted(os.listdir(dir_action)):
+    dir_vid = os.path.join(dir_action, name_vid)
+    if not os.path.isdir(dir_vid):
+        # The result pickle below is written into dir_action itself, so a
+        # second run would otherwise call os.listdir() on that file.
+        continue
+    for name_clip in sorted(os.listdir(dir_vid)):
+        motion_path = os.path.join(dir_vid, name_clip)
+        motion_list = sorted(glob.glob(motion_path+'/*.json'))
+        if not motion_list:
+            continue
+        motion = np.array([load_motion(i) for i in motion_list])
+        motion = crop_scale(motion)
+        # Shift x/y so that joint 0 of the first frame sits at the origin.
+        motion[:,:,:2] = motion[:,:,:2] - motion[0:1,0:1,:2]
+        # Zero all channels of every joint whose third channel is 0.
+        motion[motion[:,:,2]==0] = 0
+        action_motions.append(motion)
+    print("%s Done, %d vids processed" % (name_vid, len(action_motions)))
+print("%s Done, %d vids processed" % (args.name_action, len(action_motions)))
+with open(os.path.join(dir_action, '%s.pkl' % args.name_action), 'wb') as f:
+    pickle.dump(action_motions, f)
diff --git a/tools/preprocess_amass.py b/tools/preprocess_amass.py
new file mode 100644
index 0000000000000000000000000000000000000000..399d48ccac91cc4b5996149cdc226861f2ed19f8
--- /dev/null
+++ b/tools/preprocess_amass.py
@@ -0,0 +1,64 @@
+import torch
+import numpy as np
+import os
+from os import path as osp
+from human_body_prior.body_model.body_model import BodyModel
+import copy
+import pickle
+import ipdb
+import pandas as pd
+
+df = pd.read_csv('./data/AMASS/fps.csv', sep=',',header=None)
+# The file is read without a header, so row 0 is the header text; skip it.
+fname_list = list(df[0][1:])
+
+processed_dir = './data/AMASS/amass_fps60/'
+J_reg_dir = 'data/AMASS/J_regressor_h36m_correct.npy'
+all_motions = 'data/AMASS/all_motions_fps60.pkl'
+
+# NOTE: pickle.load can execute arbitrary code; only run this on a file
+# produced by the project's own compression step.
+with open(all_motions, 'rb') as motion_file:
+    motion_data = pickle.load(motion_file)
+J_reg = np.load(J_reg_dir)
+all_joints = []
+
+# Sequences are pushed through the body model in slices of at most this
+# many frames.
+max_len = 2916
+# number of body parameters
+num_betas = 16
+# number of DMPL parameters
+num_dmpls = 8
+comp_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+with open('data/AMASS/clip_list.csv', 'w') as f:
+    print('clip_id, fname, clip_len', file=f)
+    for i, bdata in enumerate(motion_data):
+        if i%200==0:
+            print(i, 'seqs done.')
+        subject_gender = bdata['gender']
+        if str(subject_gender) not in ('female', 'male'):
+            # Any other value falls back to the female model.
+            subject_gender = 'female'
+
+        bm_fname = osp.join('data/AMASS/body_models/smplh/{}/model.npz'.format(subject_gender))
+        dmpl_fname = osp.join('data/AMASS/body_models/dmpls/{}/model.npz'.format(subject_gender))
+
+        bm = BodyModel(bm_fname=bm_fname, num_betas=num_betas, num_dmpls=num_dmpls, dmpl_fname=dmpl_fname).to(comp_device)
+        time_length = len(bdata['trans'])
+
+        # Stepping by max_len visits only non-empty slices.  Counting
+        # ``time_length // max_len + 1`` slices would add an empty trailing
+        # one whenever the length is an exact multiple of max_len.
+        for start in range(0, time_length, max_len):
+            end = min(start + max_len, time_length)
+            body_parms = {
+                'root_orient': torch.Tensor(bdata['poses'][start:end, :3]).to(comp_device), # controls the global root orientation
+                'pose_body': torch.Tensor(bdata['poses'][start:end, 3:66]).to(comp_device), # controls the body
+                'pose_hand': torch.Tensor(bdata['poses'][start:end, 66:]).to(comp_device), # controls the finger articulation
+                'trans': torch.Tensor(bdata['trans'][start:end]).to(comp_device), # controls the global body position
+                'betas': torch.Tensor(np.repeat(bdata['betas'][:num_betas][np.newaxis], repeats=(end-start), axis=0)).to(comp_device), # controls the body shape. Body shape is static
+                'dmpls': torch.Tensor(bdata['dmpls'][start:end, :num_dmpls]).to(comp_device) # controls soft tissue dynamics
+            }
+            body_trans_root = bm(**body_parms)
+            mesh = body_trans_root.v.cpu().numpy()
+            kpts = np.dot(J_reg, mesh) # (17,T,3)
+            all_joints.append(kpts)
+            # clip_id is the index of this slice inside all_joints.
+            print(len(all_joints)-1, ',', fname_list[i], ',', end-start, file=f)
+
+# Write through a context manager so the pickle is flushed and closed.
+with open('data/AMASS/amass_joints_h36m_60.pkl','wb') as joints_file:
+    pickle.dump(all_joints, joints_file)
+print(len(all_joints))
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5950ee6c1587462c5ba7cb80978e22ab8fbb15a8
--- /dev/null
+++ b/train.py
@@ -0,0 +1,383 @@
+import os
+import numpy as np
+import argparse
+import errno
+import math
+import pickle
+import tensorboardX
+from tqdm import tqdm
+from time import time
+import copy
+import random
+import prettytable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from lib.utils.tools import *
+from lib.utils.learning import *
+from lib.utils.utils_data import flip_data
+from lib.data.dataset_motion_2d import PoseTrackDataset2D, InstaVDataset2D
+from lib.data.dataset_motion_3d import MotionDataset3D
+from lib.data.augmentation import Augmenter2D
+from lib.data.datareader_h36m import DataReaderH36M
+from lib.model.loss import *
+
+def parse_args():
+    """Build the training command-line interface and return the parsed options."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
+    cli.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
+    cli.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
+    cli.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
+    cli.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
+    cli.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
+    cli.add_argument('-sd', '--seed', default=0, type=int, help='random seed')
+    return cli.parse_args()
+
+def set_random_seed(seed):
+    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
+    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
+        seed_fn(seed)
+
+def save_checkpoint(chk_path, epoch, lr, optimizer, model_pos, min_loss):
+    """Serialize the training state to ``chk_path``.
+
+    The stored ``'epoch'`` is ``epoch + 1``.  The optimizer and the model
+    are stored as their ``state_dict()``; ``lr`` and ``min_loss`` are stored
+    as given.
+    """
+    print('Saving checkpoint to', chk_path)
+    state = {
+        'epoch': epoch + 1,
+        'lr': lr,
+        'optimizer': optimizer.state_dict(),
+        'model_pos': model_pos.state_dict(),
+        'min_loss': min_loss,
+    }
+    torch.save(state, chk_path)
+
+def evaluate(args, model_pos, test_loader, datareader):
+ """Run ``model_pos`` over ``test_loader`` and print per-action error tables.
+
+ Returns ``(e1, e2, results_all)``: the mean over actions of the ``mpjpe``
+ errors, the mean over actions of the ``p_mpjpe`` errors, and the array of
+ denormalized predictions.
+ """
+ print('INFO: Testing')
+ results_all = []
+ model_pos.eval()
+ with torch.no_grad():
+ for batch_input, batch_gt in tqdm(test_loader):
+ N, T = batch_gt.shape[:2]
+ if torch.cuda.is_available():
+ batch_input = batch_input.cuda()
+ if args.no_conf:
+ # Keep only the first two input channels.
+ batch_input = batch_input[:, :, :, :2]
+ if args.flip:
+ # Average the prediction on the input with the flipped-back
+ # prediction on the flipped input.
+ batch_input_flip = flip_data(batch_input)
+ predicted_3d_pos_1 = model_pos(batch_input)
+ predicted_3d_pos_flip = model_pos(batch_input_flip)
+ predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip) # Flip back
+ predicted_3d_pos = (predicted_3d_pos_1+predicted_3d_pos_2) / 2
+ else:
+ predicted_3d_pos = model_pos(batch_input)
+ if args.rootrel:
+ predicted_3d_pos[:,:,0,:] = 0 # [N,T,17,3]
+ else:
+ batch_gt[:,0,0,2] = 0
+
+ if args.gt_2d:
+ # Overwrite the first two predicted channels with the input.
+ predicted_3d_pos[...,:2] = batch_input[...,:2]
+ results_all.append(predicted_3d_pos.cpu().numpy())
+ results_all = np.concatenate(results_all)
+ results_all = datareader.denormalize(results_all)
+ _, split_id_test = datareader.get_split_id()
+ actions = np.array(datareader.dt_dataset['test']['action'])
+ factors = np.array(datareader.dt_dataset['test']['2.5d_factor'])
+ gts = np.array(datareader.dt_dataset['test']['joints_2.5d_image'])
+ sources = np.array(datareader.dt_dataset['test']['source'])
+
+ # Regroup the per-frame annotations into the same clips as the predictions.
+ num_test_frames = len(actions)
+ frames = np.array(range(num_test_frames))
+ action_clips = actions[split_id_test]
+ factor_clips = factors[split_id_test]
+ source_clips = sources[split_id_test]
+ frame_clips = frames[split_id_test]
+ gt_clips = gts[split_id_test]
+ assert len(results_all)==len(action_clips)
+
+ # Per-frame accumulators: summed errors and how often each frame was scored.
+ e1_all = np.zeros(num_test_frames)
+ e2_all = np.zeros(num_test_frames)
+ oc = np.zeros(num_test_frames)
+ results = {}
+ results_procrustes = {}
+ action_names = sorted(set(datareader.dt_dataset['test']['action']))
+ for action in action_names:
+ results[action] = []
+ results_procrustes[action] = []
+ # Clips whose source (minus its last 6 characters) is listed here are skipped.
+ block_list = ['s_09_act_05_subact_02',
+ 's_09_act_10_subact_02',
+ 's_09_act_13_subact_01']
+ for idx in range(len(action_clips)):
+ source = source_clips[idx][0][:-6]
+ if source in block_list:
+ continue
+ frame_list = frame_clips[idx]
+ action = action_clips[idx][0]
+ factor = factor_clips[idx][:,None,None]
+ gt = gt_clips[idx]
+ pred = results_all[idx]
+ # In-place multiply on the row taken from results_all, so the returned
+ # results_all carries this scaling for every clip that is not skipped.
+ pred *= factor
+
+ # Root-relative Errors
+ pred = pred - pred[:,0:1,:]
+ gt = gt - gt[:,0:1,:]
+ err1 = mpjpe(pred, gt)
+ err2 = p_mpjpe(pred, gt)
+ e1_all[frame_list] += err1
+ e2_all[frame_list] += err2
+ oc[frame_list] += 1
+ # Average each frame over the clips that scored it; frames whose
+ # accumulated error is not positive are left out.
+ for idx in range(num_test_frames):
+ if e1_all[idx] > 0:
+ err1 = e1_all[idx] / oc[idx]
+ err2 = e2_all[idx] / oc[idx]
+ action = actions[idx]
+ results[action].append(err1)
+ results_procrustes[action].append(err2)
+ final_result = []
+ final_result_procrustes = []
+ summary_table = prettytable.PrettyTable()
+ summary_table.field_names = ['test_name'] + action_names
+ for action in action_names:
+ final_result.append(np.mean(results[action]))
+ final_result_procrustes.append(np.mean(results_procrustes[action]))
+ summary_table.add_row(['P1'] + final_result)
+ summary_table.add_row(['P2'] + final_result_procrustes)
+ print(summary_table)
+ # Overall scores are the unweighted means of the per-action means.
+ e1 = np.mean(np.array(final_result))
+ e2 = np.mean(np.array(final_result_procrustes))
+ print('Protocol #1 Error (MPJPE):', e1, 'mm')
+ print('Protocol #2 Error (P-MPJPE):', e2, 'mm')
+ print('----------')
+ return e1, e2, results_all
+
+def train_epoch(args, model_pos, train_loader, losses, optimizer, has_3d, has_gt):
+ """Run one optimization pass of ``model_pos`` over ``train_loader``.
+
+ ``losses`` maps loss names to meters that are updated in place with the
+ per-batch values.  ``has_3d`` selects between the combined 3D objective
+ and the confidence-weighted 2D objective; ``has_gt`` gates the noise
+ augmentation.  Nothing is returned.
+ """
+ model_pos.train()
+ for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):
+ batch_size = len(batch_input)
+ if torch.cuda.is_available():
+ batch_input = batch_input.cuda()
+ batch_gt = batch_gt.cuda()
+ # Input and target preparation does not need gradients.
+ with torch.no_grad():
+ if args.no_conf:
+ batch_input = batch_input[:, :, :, :2]
+ if not has_3d:
+ conf = copy.deepcopy(batch_input[:,:,:,2:]) # For 2D data, weight/confidence is at the last channel
+ if args.rootrel:
+ # Express the target relative to joint 0 of each frame.
+ batch_gt = batch_gt - batch_gt[:,:,0:1,:]
+ else:
+ batch_gt[:,:,:,2] = batch_gt[:,:,:,2] - batch_gt[:,0:1,0:1,2] # Place the depth of first frame root to 0.
+ if args.mask or args.noise:
+ # Noise is only requested when this loader has ground truth.
+ batch_input = args.aug.augment2D(batch_input, noise=(args.noise and has_gt), mask=args.mask)
+ # Predict 3D poses
+ predicted_3d_pos = model_pos(batch_input) # (N, T, 17, 3)
+
+ optimizer.zero_grad()
+ if has_3d:
+ loss_3d_pos = loss_mpjpe(predicted_3d_pos, batch_gt)
+ loss_3d_scale = n_mpjpe(predicted_3d_pos, batch_gt)
+ loss_3d_velocity = loss_velocity(predicted_3d_pos, batch_gt)
+ loss_lv = loss_limb_var(predicted_3d_pos)
+ loss_lg = loss_limb_gt(predicted_3d_pos, batch_gt)
+ loss_a = loss_angle(predicted_3d_pos, batch_gt)
+ loss_av = loss_angle_velocity(predicted_3d_pos, batch_gt)
+ # Position loss has weight 1; every other term is scaled by its lambda.
+ loss_total = loss_3d_pos + \
+ args.lambda_scale * loss_3d_scale + \
+ args.lambda_3d_velocity * loss_3d_velocity + \
+ args.lambda_lv * loss_lv + \
+ args.lambda_lg * loss_lg + \
+ args.lambda_a * loss_a + \
+ args.lambda_av * loss_av
+ losses['3d_pos'].update(loss_3d_pos.item(), batch_size)
+ losses['3d_scale'].update(loss_3d_scale.item(), batch_size)
+ losses['3d_velocity'].update(loss_3d_velocity.item(), batch_size)
+ losses['lv'].update(loss_lv.item(), batch_size)
+ losses['lg'].update(loss_lg.item(), batch_size)
+ losses['angle'].update(loss_a.item(), batch_size)
+ losses['angle_velocity'].update(loss_av.item(), batch_size)
+ losses['total'].update(loss_total.item(), batch_size)
+ else:
+ # 2D data: the only loss is the weighted one that uses ``conf``.
+ loss_2d_proj = loss_2d_weighted(predicted_3d_pos, batch_gt, conf)
+ loss_total = loss_2d_proj
+ losses['2d_proj'].update(loss_2d_proj.item(), batch_size)
+ losses['total'].update(loss_total.item(), batch_size)
+ loss_total.backward()
+ optimizer.step()
+
+def _load_model_checkpoint(model, chk_filename):
+    """Load ``chk_filename`` and copy its ``'model_pos'`` weights into ``model``.
+
+    Tensors are mapped onto CPU storage while loading.  Returns the whole
+    checkpoint dict so the caller can read its bookkeeping entries.
+    """
+    print('Loading checkpoint', chk_filename)
+    checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
+    model.load_state_dict(checkpoint['model_pos'], strict=True)
+    return checkpoint
+
+
+def train_with_config(args, opts):
+    """Train the pose model described by ``args`` and/or evaluate it.
+
+    Args:
+        args: Config object as returned by ``get_config``.  ``args.mask``
+            and, when augmentation is active, ``args.aug`` are set here.
+        opts: Parsed command-line options (see ``parse_args``).
+            ``opts.resume`` is overwritten when a ``latest_epoch.bin``
+            exists in the checkpoint directory and ``args.finetune`` is off.
+
+    Raises:
+        RuntimeError: If the checkpoint directory cannot be created.
+    """
+    print(args)
+    try:
+        os.makedirs(opts.checkpoint)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint)
+    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
+
+    print('Loading dataset...')
+    loader_common = {
+        'batch_size': args.batch_size,
+        'num_workers': 12,
+        'pin_memory': True,
+        'prefetch_factor': 4,
+        'persistent_workers': True
+    }
+    trainloader_params = dict(loader_common, shuffle=True)
+    testloader_params = dict(loader_common, shuffle=False)
+
+    train_dataset = MotionDataset3D(args, args.subset_list, 'train')
+    test_dataset = MotionDataset3D(args, args.subset_list, 'test')
+    train_loader_3d = DataLoader(train_dataset, **trainloader_params)
+    test_loader = DataLoader(test_dataset, **testloader_params)
+
+    if args.train_2d:
+        posetrack = PoseTrackDataset2D()
+        posetrack_loader_2d = DataLoader(posetrack, **trainloader_params)
+        instav = InstaVDataset2D()
+        instav_loader_2d = DataLoader(instav, **trainloader_params)
+
+    datareader = DataReaderH36M(n_frames=args.clip_len, sample_stride=args.sample_stride, data_stride_train=args.data_stride, data_stride_test=args.clip_len, dt_root = 'data/motion3d', dt_file=args.dt_file)
+    min_loss = 100000
+    model_backbone = load_backbone(args)
+    model_params = 0
+    for parameter in model_backbone.parameters():
+        model_params = model_params + parameter.numel()
+    print('INFO: Trainable parameter count:', model_params)
+
+    if torch.cuda.is_available():
+        model_backbone = nn.DataParallel(model_backbone)
+        model_backbone = model_backbone.cuda()
+
+    # Checkpoint selection:
+    #   * without finetuning, an existing latest_epoch.bin forces a resume;
+    #   * an explicit --evaluate / --resume file wins over everything else;
+    #   * when finetuning without either, start from the pretrained weights.
+    checkpoint = None
+    if not args.finetune:
+        chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
+        if os.path.exists(chk_filename):
+            opts.resume = chk_filename
+    if opts.resume or opts.evaluate:
+        chk_filename = opts.evaluate if opts.evaluate else opts.resume
+        checkpoint = _load_model_checkpoint(model_backbone, chk_filename)
+    elif args.finetune:
+        chk_filename = os.path.join(opts.pretrained, opts.selection)
+        checkpoint = _load_model_checkpoint(model_backbone, chk_filename)
+    model_pos = model_backbone
+
+    if args.partial_train:
+        model_pos = partial_train_layers(model_pos, args.partial_train)
+
+    if not opts.evaluate:
+        lr = args.learning_rate
+        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model_pos.parameters()), lr=lr, weight_decay=args.weight_decay)
+        lr_decay = args.lr_decay
+        st = 0
+        if args.train_2d:
+            print('INFO: Training on {}(3D)+{}(2D) batches'.format(len(train_loader_3d), len(instav_loader_2d) + len(posetrack_loader_2d)))
+        else:
+            print('INFO: Training on {}(3D) batches'.format(len(train_loader_3d)))
+        if opts.resume:
+            st = checkpoint['epoch']
+            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
+                optimizer.load_state_dict(checkpoint['optimizer'])
+            else:
+                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
+            lr = checkpoint['lr']
+            if 'min_loss' in checkpoint and checkpoint['min_loss'] is not None:
+                min_loss = checkpoint['min_loss']
+
+        args.mask = (args.mask_ratio > 0 and args.mask_T_ratio > 0)
+        if args.mask or args.noise:
+            args.aug = Augmenter2D(args)
+
+        chk_path_latest = os.path.join(opts.checkpoint, 'latest_epoch.bin')
+        chk_path_best = os.path.join(opts.checkpoint, 'best_epoch.bin')
+        loss_names = ('3d_pos', '3d_scale', '2d_proj', 'lg', 'lv', 'total',
+                      '3d_velocity', 'angle', 'angle_velocity')
+        # (tensorboard tag, key in ``losses``)
+        loss_scalars = (
+            ('loss_3d_pos', '3d_pos'),
+            ('loss_2d_proj', '2d_proj'),
+            ('loss_3d_scale', '3d_scale'),
+            ('loss_3d_velocity', '3d_velocity'),
+            ('loss_lv', 'lv'),
+            ('loss_lg', 'lg'),
+            ('loss_a', 'angle'),
+            ('loss_av', 'angle_velocity'),
+            ('loss_total', 'total'),
+        )
+
+        # Training
+        for epoch in range(st, args.epochs):
+            print('Training epoch %d.' % epoch)
+            start_time = time()
+            losses = {name: AverageMeter() for name in loss_names}
+
+            # Curriculum Learning
+            if args.train_2d and (epoch >= args.pretrain_3d_curriculum):
+                train_epoch(args, model_pos, posetrack_loader_2d, losses, optimizer, has_3d=False, has_gt=True)
+                train_epoch(args, model_pos, instav_loader_2d, losses, optimizer, has_3d=False, has_gt=False)
+            train_epoch(args, model_pos, train_loader_3d, losses, optimizer, has_3d=True, has_gt=True)
+            elapsed = (time() - start_time) / 60
+
+            if args.no_eval:
+                print('[%d] time %.2f lr %f 3d_train %f' % (
+                    epoch + 1,
+                    elapsed,
+                    lr,
+                    losses['3d_pos'].avg))
+            else:
+                e1, e2, results_all = evaluate(args, model_pos, test_loader, datareader)
+                print('[%d] time %.2f lr %f 3d_train %f e1 %f e2 %f' % (
+                    epoch + 1,
+                    elapsed,
+                    lr,
+                    losses['3d_pos'].avg,
+                    e1, e2))
+                train_writer.add_scalar('Error P1', e1, epoch + 1)
+                train_writer.add_scalar('Error P2', e2, epoch + 1)
+            # Training losses do not depend on the evaluation, so they are
+            # logged for every epoch.
+            for tag, key in loss_scalars:
+                train_writer.add_scalar(tag, losses[key].avg, epoch + 1)
+
+            # Decay learning rate exponentially
+            lr *= lr_decay
+            for param_group in optimizer.param_groups:
+                param_group['lr'] *= lr_decay
+
+            # Save checkpoints
+            chk_path = os.path.join(opts.checkpoint, 'epoch_{}.bin'.format(epoch))
+
+            save_checkpoint(chk_path_latest, epoch, lr, optimizer, model_pos, min_loss)
+            if (epoch + 1) % args.checkpoint_frequency == 0:
+                save_checkpoint(chk_path, epoch, lr, optimizer, model_pos, min_loss)
+            # ``e1`` only exists when the epoch was evaluated; without this
+            # guard a run with args.no_eval fails with a NameError here.
+            if not args.no_eval and e1 < min_loss:
+                min_loss = e1
+                save_checkpoint(chk_path_best, epoch, lr, optimizer, model_pos, min_loss)
+
+    if opts.evaluate:
+        e1, e2, results_all = evaluate(args, model_pos, test_loader, datareader)
+
+# Script entry point: parse the CLI, seed the RNGs, load the config named by
+# --config and hand both to train_with_config().
+if __name__ == "__main__":
+ opts = parse_args()
+ set_random_seed(opts.seed)
+ args = get_config(opts.config)
+ train_with_config(args, opts)
\ No newline at end of file
diff --git a/train_action.py b/train_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..e105c26e14580eb9aa04f979f3cb0560fb504b41
--- /dev/null
+++ b/train_action.py
@@ -0,0 +1,243 @@
+import os
+import numpy as np
+import time
+import sys
+import argparse
+import errno
+from collections import OrderedDict
+import tensorboardX
+from tqdm import tqdm
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DataLoader
+
+from lib.utils.tools import *
+from lib.utils.learning import *
+from lib.model.loss import *
+from lib.data.dataset_action import NTURGBD
+from lib.model.model_action import ActionNet
+
+# Seed the Python, NumPy and PyTorch RNGs with 0 when the module is loaded.
+random.seed(0)
+np.random.seed(0)
+torch.manual_seed(0)
+
+def parse_args():
+    """Build the action-training command-line interface and return the parsed options.
+
+    ``--print_freq`` is parsed as an integer: it is used as the right-hand
+    side of ``%`` against a batch counter, which fails for the string that
+    argparse yields when no ``type`` is given.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
+    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
+    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
+    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
+    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
+    parser.add_argument('-freq', '--print_freq', default=100, type=int)
+    parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
+    opts = parser.parse_args()
+    return opts
+
+def validate(test_loader, model, criterion, print_freq=None):
+    """Evaluate ``model`` on ``test_loader`` and return averaged metrics.
+
+    Args:
+        test_loader: Iterable of ``(batch_input, batch_gt)`` pairs.
+        model: Network called on ``batch_input``.
+        criterion: Loss called as ``criterion(output, batch_gt)``.
+        print_freq: Print a progress line every ``print_freq`` batches.
+            When omitted, the module-level ``opts.print_freq`` set by the
+            script entry point is used if present, else 100.
+
+    Returns:
+        ``(loss_avg, top1_avg, top5_avg)`` taken from the running meters.
+    """
+    if print_freq is None:
+        # Keep the script behaviour (global ``opts``) while letting the
+        # function also work when this module is imported.
+        print_freq = getattr(globals().get('opts'), 'print_freq', 100)
+    # The CLI may hand the value over as a string.
+    print_freq = int(print_freq)
+    model.eval()
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+    with torch.no_grad():
+        end = time.time()
+        for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)):
+            batch_size = len(batch_input)
+            if torch.cuda.is_available():
+                batch_gt = batch_gt.cuda()
+                batch_input = batch_input.cuda()
+            output = model(batch_input) # (N, num_classes)
+            loss = criterion(output, batch_gt)
+
+            # update metric
+            losses.update(loss.item(), batch_size)
+            acc1, acc5 = accuracy(output, batch_gt, topk=(1, 5))
+            top1.update(acc1[0], batch_size)
+            top5.update(acc5[0], batch_size)
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if (idx+1) % print_freq == 0:
+                # Report the 1-based batch count, like the training log does.
+                print('Test: [{0}/{1}]\t'
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
+                       idx + 1, len(test_loader), batch_time=batch_time,
+                       loss=losses, top1=top1, top5=top5))
+    return losses.avg, top1.avg, top5.avg
+
+
+def _save_action_checkpoint(chk_path, epoch, scheduler, optimizer, model, best_acc):
+    """Serialize the action-model training state to ``chk_path``.
+
+    The stored ``'epoch'`` is ``epoch + 1`` and ``'lr'`` is the scheduler's
+    last learning-rate list.
+    """
+    torch.save({
+        'epoch': epoch+1,
+        'lr': scheduler.get_last_lr(),
+        'optimizer': optimizer.state_dict(),
+        'model': model.state_dict(),
+        'best_acc' : best_acc
+    }, chk_path)
+
+
+def train_with_config(args, opts):
+    """Train the action-recognition model described by ``args`` and/or evaluate it.
+
+    Args:
+        args: Config object as returned by ``get_config``.
+        opts: Parsed command-line options (see ``parse_args``).
+            ``opts.resume`` is overwritten when a ``latest_epoch.bin``
+            exists in the checkpoint directory.
+
+    Raises:
+        RuntimeError: If the checkpoint directory cannot be created.
+    """
+    print(args)
+    try:
+        os.makedirs(opts.checkpoint)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint)
+    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
+    # The CLI may hand the value over as a string.
+    print_freq = int(opts.print_freq)
+    model_backbone = load_backbone(args)
+    if args.finetune:
+        if opts.resume or opts.evaluate:
+            # The full model checkpoint is loaded further down.
+            pass
+        else:
+            chk_filename = os.path.join(opts.pretrained, opts.selection)
+            print('Loading backbone', chk_filename)
+            checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos']
+            model_backbone = load_pretrained_weights(model_backbone, checkpoint)
+    if args.partial_train:
+        model_backbone = partial_train_layers(model_backbone, args.partial_train)
+    model = ActionNet(backbone=model_backbone, dim_rep=args.dim_rep, num_classes=args.action_classes, dropout_ratio=args.dropout_ratio, version=args.model_version, hidden_dim=args.hidden_dim, num_joints=args.num_joints)
+    criterion = torch.nn.CrossEntropyLoss()
+    if torch.cuda.is_available():
+        model = nn.DataParallel(model)
+        model = model.cuda()
+        criterion = criterion.cuda()
+    best_acc = 0
+    model_params = 0
+    for parameter in model.parameters():
+        model_params = model_params + parameter.numel()
+    print('INFO: Trainable parameter count:', model_params)
+    print('Loading dataset...')
+    loader_common = {
+        'batch_size': args.batch_size,
+        'num_workers': 8,
+        'pin_memory': True,
+        'prefetch_factor': 4,
+        'persistent_workers': True
+    }
+    trainloader_params = dict(loader_common, shuffle=True)
+    testloader_params = dict(loader_common, shuffle=False)
+    data_path = 'data/action/%s.pkl' % args.dataset
+    ntu60_xsub_train = NTURGBD(data_path=data_path, data_split=args.data_split+'_train', n_frames=args.clip_len, random_move=args.random_move, scale_range=args.scale_range_train)
+    ntu60_xsub_val = NTURGBD(data_path=data_path, data_split=args.data_split+'_val', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test)
+
+    train_loader = DataLoader(ntu60_xsub_train, **trainloader_params)
+    test_loader = DataLoader(ntu60_xsub_val, **testloader_params)
+
+    # An existing latest_epoch.bin in the checkpoint directory forces a resume.
+    chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
+    if os.path.exists(chk_filename):
+        opts.resume = chk_filename
+    if opts.resume or opts.evaluate:
+        chk_filename = opts.evaluate if opts.evaluate else opts.resume
+        print('Loading checkpoint', chk_filename)
+        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
+        model.load_state_dict(checkpoint['model'], strict=True)
+
+    if not opts.evaluate:
+        # The model is only wrapped in DataParallel when CUDA is available,
+        # so ``.module`` must not be assumed to exist.
+        net = model.module if isinstance(model, nn.DataParallel) else model
+        optimizer = optim.AdamW(
+            [     {"params": filter(lambda p: p.requires_grad, net.backbone.parameters()), "lr": args.lr_backbone},
+                  {"params": filter(lambda p: p.requires_grad, net.head.parameters()), "lr": args.lr_head},
+            ],      lr=args.lr_backbone,
+                    weight_decay=args.weight_decay
+        )
+
+        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
+        st = 0
+        print('INFO: Training on {} batches'.format(len(train_loader)))
+        if opts.resume:
+            st = checkpoint['epoch']
+            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
+                optimizer.load_state_dict(checkpoint['optimizer'])
+            else:
+                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
+            if 'best_acc' in checkpoint and checkpoint['best_acc'] is not None:
+                best_acc = checkpoint['best_acc']
+
+        chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
+        best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin')
+
+        # Training
+        for epoch in range(st, args.epochs):
+            print('Training epoch %d.' % epoch)
+            losses_train = AverageMeter()
+            top1 = AverageMeter()
+            top5 = AverageMeter()
+            batch_time = AverageMeter()
+            data_time = AverageMeter()
+            model.train()
+            end = time.time()
+            for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):    # (N, 2, T, 17, 3)
+                data_time.update(time.time() - end)
+                batch_size = len(batch_input)
+                if torch.cuda.is_available():
+                    batch_gt = batch_gt.cuda()
+                    batch_input = batch_input.cuda()
+                output = model(batch_input) # (N, num_classes)
+                optimizer.zero_grad()
+                loss_train = criterion(output, batch_gt)
+                losses_train.update(loss_train.item(), batch_size)
+                acc1, acc5 = accuracy(output, batch_gt, topk=(1, 5))
+                top1.update(acc1[0], batch_size)
+                top5.update(acc5[0], batch_size)
+                loss_train.backward()
+                optimizer.step()
+                batch_time.update(time.time() - end)
+                end = time.time()
+                if (idx + 1) % print_freq == 0:
+                    print('Train: [{0}][{1}/{2}]\t'
+                        'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                        'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                        'loss {loss.val:.3f} ({loss.avg:.3f})\t'
+                        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
+                        epoch, idx + 1, len(train_loader), batch_time=batch_time,
+                        data_time=data_time, loss=losses_train, top1=top1))
+                    sys.stdout.flush()
+
+            test_loss, test_top1, test_top5 = validate(test_loader, model, criterion)
+
+            train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1)
+            train_writer.add_scalar('train_top1', top1.avg, epoch + 1)
+            train_writer.add_scalar('train_top5', top5.avg, epoch + 1)
+            train_writer.add_scalar('test_loss', test_loss, epoch + 1)
+            train_writer.add_scalar('test_top1', test_top1, epoch + 1)
+            train_writer.add_scalar('test_top5', test_top5, epoch + 1)
+
+            scheduler.step()
+
+            # Save latest checkpoint.  It records the best accuracy known
+            # before this epoch's result is compared below.
+            print('Saving checkpoint to', chk_path)
+            _save_action_checkpoint(chk_path, epoch, scheduler, optimizer, model, best_acc)
+
+            # Save best checkpoint.
+            if test_top1 > best_acc:
+                best_acc = test_top1
+                print("save best checkpoint")
+                _save_action_checkpoint(best_chk_path, epoch, scheduler, optimizer, model, best_acc)
+
+    if opts.evaluate:
+        test_loss, test_top1, test_top5 = validate(test_loader, model, criterion)
+        print('Loss {loss:.4f} \t'
+              'Acc@1 {top1:.3f} \t'
+              'Acc@5 {top5:.3f} \t'.format(loss=test_loss, top1=test_top1, top5=test_top5))
+
+# Script entry point.  ``opts`` is deliberately a module-level name:
+# validate() reads ``opts.print_freq`` from the module scope.
+if __name__ == "__main__":
+ opts = parse_args()
+ args = get_config(opts.config)
+ train_with_config(args, opts)
\ No newline at end of file
diff --git a/train_action_1shot.py b/train_action_1shot.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f07902cb46b3d6b9d6258f59090e5932f756d5d
--- /dev/null
+++ b/train_action_1shot.py
@@ -0,0 +1,243 @@
+import os
+import numpy as np
+import time
+import sys
+import argparse
+import errno
+from collections import OrderedDict
+import tensorboardX
+from tqdm import tqdm
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DataLoader
+
+from lib.utils.tools import *
+from lib.utils.learning import *
+from lib.model.loss import *
+from lib.data.dataset_action import NTURGBD, NTURGBD1Shot
+from lib.model.model_action import ActionNet
+
+from lib.model.loss_supcon import SupConLoss
+from pytorch_metric_learning import samplers
+
+random.seed(0)  # Seed Python's `random` module at import time with a fixed value.
+np.random.seed(0)  # Seed NumPy's global random generator with the same fixed value.
+torch.manual_seed(0)  # Seed PyTorch's random generator with the same fixed value.
+
+def parse_args():
+    """Parse the command-line options of the one-shot action training script.
+
+    Returns:
+        argparse.Namespace with the attributes ``config``, ``checkpoint``,
+        ``pretrained``, ``resume``, ``evaluate``, ``print_freq`` (int) and
+        ``selection``.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
+    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
+    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
+    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
+    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
+    # type=int: without it a value passed on the command line stays a str, and
+    # arithmetic such as `(idx + 1) % opts.print_freq` raises TypeError.
+    parser.add_argument('-freq', '--print_freq', default=100, type=int, help='print training statistics every N batches')
+    parser.add_argument('-ms', '--selection', default='best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
+    return parser.parse_args()
+
+def extract_feats(dataloader_x, model):
+    """Run ``model`` over every batch of ``dataloader_x`` with gradients disabled.
+
+    Args:
+        dataloader_x: iterable yielding ``(batch_input, batch_gt)`` pairs.
+        model: callable applied to each (optionally CUDA-moved) input batch.
+
+    Returns:
+        Tuple ``(all_feats, all_gts)``: the per-batch model outputs and the
+        per-batch labels, each concatenated with ``torch.cat``. Labels are not
+        moved to the GPU; they stay wherever the loader produced them.
+    """
+    all_feats = []
+    all_gts = []
+    use_cuda = torch.cuda.is_available()
+    with torch.no_grad():
+        # Iterate the loader directly (the batch index was never used) so tqdm
+        # can read its length and show real progress instead of a bare counter.
+        for batch_input, batch_gt in tqdm(dataloader_x):  # (N, 2, T, 17, 3)
+            if use_cuda:
+                batch_input = batch_input.cuda()
+            all_feats.append(model(batch_input))
+            all_gts.append(batch_gt)
+    return torch.cat(all_feats), torch.cat(all_gts)
+
+def validate(anchor_loader, test_loader, model):
+    """Nearest-anchor classification accuracy of ``model`` on ``test_loader``.
+
+    Features of both loaders are extracted with ``extract_feats``; each test
+    sample is assigned the label of the anchor sample whose feature has the
+    highest cosine similarity to it.
+
+    Args:
+        anchor_loader: loader of the labelled anchor (exemplar) samples.
+        test_loader: loader of the samples to classify.
+        model: feature extractor; it is switched to eval mode here.
+
+    Returns:
+        Tensor holding the fraction of test samples whose predicted label
+        equals the ground truth (0-dim for 1-D label tensors).
+    """
+    # Inference mode for dropout / normalisation layers. The training loop
+    # calls model.train() again at the start of every epoch.
+    model.eval()
+    train_feats, train_labels = extract_feats(anchor_loader, model)
+    test_feats, test_labels = extract_feats(test_loader, model)
+    # Broadcasting anchors on dim 1 against tests on dim 0 yields one
+    # similarity per (anchor, test) pair - assumes 2-D (samples, dim) features.
+    similarity = F.cosine_similarity(train_feats.unsqueeze(1), test_feats.unsqueeze(0), dim=-1)
+    # Features may live on the GPU while labels were left on the CPU, so the
+    # index tensor is moved to the labels' device before indexing.
+    nearest_anchor = torch.argmax(similarity, dim=0).to(train_labels.device)
+    pred = train_labels[nearest_anchor]
+    assert len(pred) == len(test_labels)
+    return (pred == test_labels).sum(dim=0) / len(pred)
+
+def train_with_config(args, opts):
+    """Train the one-shot action model with the SupCon loss, or only evaluate it.
+
+    Args:
+        args: experiment configuration (the object returned by ``get_config``).
+        opts: parsed command-line options (see ``parse_args``).
+
+    Side effects:
+        Creates ``opts.checkpoint`` and writes TensorBoard scalars to its
+        ``logs`` sub-directory. When training, saves ``latest_epoch.bin`` after
+        every epoch and ``best_epoch.bin`` whenever the test accuracy improves.
+        If ``latest_epoch.bin`` already exists it is used as the resume point.
+
+    Raises:
+        RuntimeError: if the checkpoint directory cannot be created.
+    """
+    print(args)
+    try:
+        os.makedirs(opts.checkpoint)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise RuntimeError('Unable to create checkpoint directory: {}'.format(opts.checkpoint)) from e
+    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
+    model_backbone = load_backbone(args)
+    if args.finetune and not (opts.resume or opts.evaluate):
+        # Honour -ms/--selection (default 'best_epoch.bin') rather than a hard-coded name.
+        chk_filename = os.path.join(opts.pretrained, getattr(opts, 'selection', 'best_epoch.bin'))
+        print('Loading backbone', chk_filename)
+        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
+        prefix = 'module.'
+        new_state_dict = OrderedDict()
+        for k, v in checkpoint['model_pos'].items():
+            # Strip the 'module.' prefix only when present, so keys saved
+            # without it are not truncated.
+            new_state_dict[k[len(prefix):] if k.startswith(prefix) else k] = v
+        model_backbone.load_state_dict(new_state_dict, strict=True)
+    if args.partial_train:
+        model_backbone = partial_train_layers(model_backbone, args.partial_train)
+    model = ActionNet(backbone=model_backbone, dim_rep=args.dim_rep, dropout_ratio=args.dropout_ratio, version=args.model_version, hidden_dim=args.hidden_dim, num_joints=args.num_joints)
+    criterion = SupConLoss(temperature=args.temp)
+
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        model = nn.DataParallel(model)
+        model = model.cuda()
+        criterion = criterion.cuda()
+
+    # An existing latest checkpoint in the output directory implies resuming.
+    chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
+    if os.path.exists(chk_filename):
+        opts.resume = chk_filename
+    if opts.resume or opts.evaluate:
+        chk_filename = opts.evaluate if opts.evaluate else opts.resume
+        print('Loading checkpoint', chk_filename)
+        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
+        model.load_state_dict(checkpoint['model'], strict=True)
+
+    best_acc = 0
+    # Count only parameters that will actually be optimised, matching the label.
+    model_params = sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
+    print('INFO: Trainable parameter count:', model_params)
+    print('Loading dataset...')
+
+    # Anchor and test loaders share the same deterministic settings.
+    evalloader_params = {
+        'batch_size': args.batch_size,
+        'shuffle': False,
+        'num_workers': 8,
+        'pin_memory': True,
+        'prefetch_factor': 4,
+        'persistent_workers': True
+    }
+    data_path_1shot = 'data/action/ntu120_hrnet_oneshot.pkl'
+    ntu60_1shot_anchor = NTURGBD(data_path=data_path_1shot, data_split='oneshot_train', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test)
+    ntu60_1shot_test = NTURGBD(data_path=data_path_1shot, data_split='oneshot_val', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test)
+    anchor_loader = DataLoader(ntu60_1shot_anchor, **evalloader_params)
+    test_loader = DataLoader(ntu60_1shot_test, **evalloader_params)
+
+    if not opts.evaluate:
+        # Load training data (auxiliary set)
+        data_path = 'data/action/ntu120_hrnet.pkl'
+        ntu120_1shot_train = NTURGBD1Shot(data_path=data_path, data_split='', n_frames=args.clip_len, random_move=args.random_move, scale_range=args.scale_range_train, check_split=False)
+        sampler = samplers.MPerClassSampler(ntu120_1shot_train.labels, m=args.n_views, batch_size=args.batch_size, length_before_new_iter=len(ntu120_1shot_train))
+        trainloader_params = dict(evalloader_params, sampler=sampler)
+        train_loader = DataLoader(ntu120_1shot_train, **trainloader_params)
+        # The model is wrapped in DataParallel only when CUDA is available.
+        net = model.module if isinstance(model, nn.DataParallel) else model
+        optimizer = optim.AdamW(
+            [
+                {"params": filter(lambda p: p.requires_grad, net.backbone.parameters()), "lr": args.lr_backbone},
+                {"params": filter(lambda p: p.requires_grad, net.head.parameters()), "lr": args.lr_head},
+            ],
+            lr=args.lr_backbone,
+            weight_decay=args.weight_decay
+        )
+        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
+        st = 0
+        print('INFO: Training on {} batches'.format(len(train_loader)))
+        if opts.resume:
+            st = checkpoint['epoch']
+            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
+                optimizer.load_state_dict(checkpoint['optimizer'])
+            else:
+                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
+            if 'best_acc' in checkpoint and checkpoint['best_acc'] is not None:
+                best_acc = checkpoint['best_acc']
+
+        # A value given on the command line may arrive as a str.
+        print_freq = int(opts.print_freq)
+
+        def save_checkpoint(path, epoch):
+            # Snapshot of everything needed to resume after `epoch`.
+            torch.save({
+                'epoch': epoch + 1,
+                'lr': scheduler.get_last_lr(),
+                'optimizer': optimizer.state_dict(),
+                'model': model.state_dict(),
+                'best_acc': best_acc
+            }, path)
+
+        # Training
+        for epoch in range(st, args.epochs):
+            print('Training epoch %d.' % epoch)
+            losses_train = AverageMeter()
+            batch_time = AverageMeter()
+            data_time = AverageMeter()
+
+            model.train()
+            end = time.time()
+
+            for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):
+                data_time.update(time.time() - end)
+                batch_size = len(batch_input)
+                if use_cuda:
+                    batch_gt = batch_gt.cuda()
+                    batch_input = batch_input.cuda()
+                feat = model(batch_input)
+                feat = feat.reshape(batch_size, -1, args.hidden_dim)
+                optimizer.zero_grad()
+                loss_train = criterion(feat, batch_gt)
+                losses_train.update(loss_train.item(), batch_size)
+                loss_train.backward()
+                optimizer.step()
+                batch_time.update(time.time() - end)
+                end = time.time()
+                if (idx + 1) % print_freq == 0:
+                    print('Train: [{0}][{1}/{2}]\t'
+                          'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                          'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                          'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format(
+                              epoch, idx + 1, len(train_loader), batch_time=batch_time,
+                              data_time=data_time, loss=losses_train))
+                    sys.stdout.flush()
+            # Plain float so checkpoints do not carry a device-bound tensor.
+            test_top1 = float(validate(anchor_loader, test_loader, model))
+            train_writer.add_scalar('train_loss_supcon', losses_train.avg, epoch + 1)
+            train_writer.add_scalar('test_top1', test_top1, epoch + 1)
+            scheduler.step()
+
+            # Update the running best before saving anything, so that
+            # latest_epoch.bin also records this epoch's result.
+            is_best = bool(test_top1 > best_acc)
+            if is_best:
+                best_acc = test_top1
+
+            # Save latest checkpoint.
+            chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
+            print('Saving checkpoint to', chk_path)
+            save_checkpoint(chk_path, epoch)
+
+            # Save best checkpoint
+            if is_best:
+                print("save best checkpoint")
+                save_checkpoint(os.path.join(opts.checkpoint, 'best_epoch.bin'), epoch)
+    if opts.evaluate:
+        test_top1 = validate(anchor_loader, test_loader, model)
+        print(test_top1)
+if __name__ == "__main__":  # Script entry point: runs only when the file is executed directly.
+ opts = parse_args()  # Parsed command-line options.
+ args = get_config(opts.config)  # Configuration loaded by get_config from the --config path.
+ train_with_config(args, opts)  # Train, or only evaluate when --evaluate is given.
+
\ No newline at end of file
diff --git a/train_mesh.py b/train_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c73dd5856f4bc492f1fa47dad903fd1227c60b6
--- /dev/null
+++ b/train_mesh.py
@@ -0,0 +1,437 @@
+import os
+import random
+import copy
+import time
+import sys
+import shutil
+import argparse
+import errno
+import math
+import numpy as np
+from collections import defaultdict, OrderedDict
+import tensorboardX
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import StepLR
+
+from lib.utils.tools import *
+from lib.model.loss import *
+from lib.model.loss_mesh import *
+from lib.utils.utils_mesh import *
+from lib.utils.utils_smpl import *
+from lib.utils.utils_data import *
+from lib.utils.learning import *
+from lib.data.dataset_mesh import MotionSMPL
+from lib.model.model_mesh import MeshRegressor
+from torch.utils.data import DataLoader
+
+def parse_args():
+    """Parse the command-line options of the mesh training script.
+
+    Returns:
+        argparse.Namespace with the attributes ``config``, ``checkpoint``,
+        ``pretrained``, ``resume``, ``evaluate``, ``print_freq`` (int),
+        ``selection`` and ``seed``.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
+    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
+    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
+    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
+    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
+    # type=int so a value given on the command line is a number like the
+    # default; callers that still wrap it in int() keep working.
+    parser.add_argument('-freq', '--print_freq', default=100, type=int, help='print statistics every N batches')
+    parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
+    parser.add_argument('-sd', '--seed', default=0, type=int, help='random seed')
+    return parser.parse_args()
+
+def set_random_seed(seed):
+    """Seed the Python, NumPy and PyTorch random number generators with ``seed``."""
+    # Same order as before: stdlib first, then NumPy, then PyTorch.
+    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
+        seed_fn(seed)
+
+def validate(test_loader, model, criterion, dataset_name='h36m'):
+    """Evaluate ``model`` on ``test_loader`` and report loss and mesh errors.
+
+    NOTE: this function reads the module-level ``args`` (``data_root``,
+    ``flip`` and the ``lambda_*`` loss weights) and ``opts`` (``print_freq``)
+    that the ``__main__`` block binds; they are not parameters.
+
+    Args:
+        test_loader: yields ``(batch_input, batch_gt)``, where ``batch_gt`` is
+            a dict with at least ``'theta'``, ``'kp_3d'`` and ``'verts'``.
+        model: network under evaluation; switched to eval mode here.
+        criterion: callable returning a dict of named loss terms.
+        dataset_name: label used only in the printed messages.
+
+    Returns:
+        Tuple ``(avg_loss, mpjpe, pa_mpjpe, mpve, losses_dict)``. The three
+        error values come from ``evaluate_mesh``; ``losses_dict`` maps each
+        loss name to its ``AverageMeter``.
+    """
+    model.eval()
+    print(f'===========> validating {dataset_name}')
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    loss_names = ('loss_3d_pos', 'loss_3d_scale', 'loss_3d_velocity', 'loss_lv', 'loss_lg',
+                  'loss_a', 'loss_av', 'loss_pose', 'loss_shape', 'loss_norm')
+    losses_dict = {name: AverageMeter() for name in loss_names}
+    mpjpes = AverageMeter()
+    mpves = AverageMeter()
+    results = defaultdict(list)
+    use_cuda = torch.cuda.is_available()
+    smpl = SMPL(args.data_root, batch_size=1)
+    if use_cuda:
+        # Guarded like the batch tensors below, so CPU-only runs do not crash here.
+        smpl = smpl.cuda()
+    J_regressor = smpl.J_regressor_h36m
+    print_freq = int(opts.print_freq)
+    # Defined up front so the summary print works even for an empty loader.
+    loss_str = ''
+    with torch.no_grad():
+        end = time.time()
+        for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)):
+            batch_size, clip_len = batch_input.shape[:2]
+            if use_cuda:
+                batch_gt['theta'] = batch_gt['theta'].cuda().float()
+                batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float()
+                batch_gt['verts'] = batch_gt['verts'].cuda().float()
+                batch_input = batch_input.cuda().float()
+            output = model(batch_input)
+            output_final = output
+            if args.flip:
+                # Test-time augmentation: predict on the flipped input, flip the
+                # pose back, rebuild vertices and joints, then average the two.
+                batch_input_flip = flip_data(batch_input)
+                output_flip = model(batch_input_flip)
+                output_flip_pose = output_flip[0]['theta'][:, :, :72]
+                output_flip_shape = output_flip[0]['theta'][:, :, 72:]
+                output_flip_pose = flip_thetas_batch(output_flip_pose)
+                output_flip_pose = output_flip_pose.reshape(-1, 72)
+                output_flip_shape = output_flip_shape.reshape(-1, 10)
+                output_flip_smpl = smpl(
+                    betas=output_flip_shape,
+                    body_pose=output_flip_pose[:, 3:],
+                    global_orient=output_flip_pose[:, :3],
+                    pose2rot=True
+                )
+                output_flip_verts = output_flip_smpl.vertices.detach()*1000.0
+                J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device)
+                output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts)  # (NT,17,3)
+                output_flip_back = [{
+                    'theta': torch.cat((output_flip_pose.reshape(batch_size, clip_len, -1), output_flip_shape.reshape(batch_size, clip_len, -1)), dim=-1),
+                    'verts': output_flip_verts.reshape(batch_size, clip_len, -1, 3),
+                    'kp_3d': output_flip_kp3d.reshape(batch_size, clip_len, -1, 3),
+                }]
+                output_final = [{}]
+                # Average over the keys that were actually rebuilt, so an extra
+                # key in the model output cannot raise KeyError.
+                for k in output_flip_back[0]:
+                    output_final[0][k] = (output[0][k] + output_flip_back[0][k])*0.5
+                output = output_final
+            loss_dict = criterion(output, batch_gt)
+            loss = args.lambda_3d * loss_dict['loss_3d_pos'] + \
+                args.lambda_scale * loss_dict['loss_3d_scale'] + \
+                args.lambda_3dv * loss_dict['loss_3d_velocity'] + \
+                args.lambda_lv * loss_dict['loss_lv'] + \
+                args.lambda_lg * loss_dict['loss_lg'] + \
+                args.lambda_a * loss_dict['loss_a'] + \
+                args.lambda_av * loss_dict['loss_av'] + \
+                args.lambda_shape * loss_dict['loss_shape'] + \
+                args.lambda_pose * loss_dict['loss_pose'] + \
+                args.lambda_norm * loss_dict['loss_norm']
+            # update metric
+            losses.update(loss.item(), batch_size)
+            for k, v in loss_dict.items():
+                losses_dict[k].update(v.item(), batch_size)
+            loss_str = ''.join(
+                '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k])
+                for k in loss_dict
+            )
+            mpjpe, mpve = compute_error(output, batch_gt)
+            mpjpes.update(mpjpe, batch_size)
+            mpves.update(mpve, batch_size)
+
+            # Move predictions and the matching ground truth to NumPy on the CPU.
+            for key in list(output[0].keys()):
+                output[0][key] = output[0][key].detach().cpu().numpy()
+                batch_gt[key] = batch_gt[key].detach().cpu().numpy()
+            results['kp_3d'].append(output[0]['kp_3d'])
+            results['verts'].append(output[0]['verts'])
+            results['kp_3d_gt'].append(batch_gt['kp_3d'])
+            results['verts_gt'].append(batch_gt['verts'])
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if idx % print_freq == 0:
+                print('Test: [{0}/{1}]\t'
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                      '{2}'
+                      'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
+                      'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
+                          idx, len(test_loader), loss_str, batch_time=batch_time,
+                          loss=losses, mpves=mpves, mpjpes=mpjpes))
+
+    print(f'==> start concating results of {dataset_name}')
+    for term in results.keys():
+        results[term] = np.concatenate(results[term])
+    print(f'==> start evaluating {dataset_name}...')
+    error_dict = evaluate_mesh(results)
+    err_str = ''.join('{}: {:.2f}mm \t'.format(err_key, err_val) for err_key, err_val in error_dict.items())
+    print(f'=======================> {dataset_name} validation done: ', loss_str)
+    print(f'=======================> {dataset_name} validation done: ', err_str)
+    return losses.avg, error_dict['mpjpe'], error_dict['pa_mpjpe'], error_dict['mpve'], losses_dict
+
+
+def train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch):
+    """Run one optimisation pass over ``train_loader``.
+
+    Args:
+        args: configuration providing the ``lambda_*`` loss weights.
+        opts: command-line options; ``print_freq`` sets the logging interval.
+        model: network to train; switched to training mode here.
+        train_loader: yields ``(batch_input, batch_gt)``, where ``batch_gt`` is
+            a dict with at least ``'theta'``, ``'kp_3d'`` and ``'verts'``.
+        losses_train: meter updated in place with the weighted total loss.
+        losses_dict: dict of meters, one per loss term, updated in place.
+        mpjpes, mpves: meters updated in place with ``compute_error`` results.
+        criterion: callable returning a dict of named loss terms.
+        optimizer: optimiser stepped once per batch.
+        batch_time, data_time: meters for per-batch and data-loading time.
+        epoch: epoch index, used only in the log line.
+
+    Returns:
+        None. All results are accumulated in the meters passed in.
+    """
+    model.train()
+    use_cuda = torch.cuda.is_available()
+    print_freq = int(opts.print_freq)
+    end = time.time()
+    for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):
+        data_time.update(time.time() - end)
+        batch_size = len(batch_input)
+
+        if use_cuda:
+            batch_gt['theta'] = batch_gt['theta'].cuda().float()
+            batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float()
+            batch_gt['verts'] = batch_gt['verts'].cuda().float()
+            batch_input = batch_input.cuda().float()
+        output = model(batch_input)
+        optimizer.zero_grad()
+        loss_dict = criterion(output, batch_gt)
+        loss_train = args.lambda_3d * loss_dict['loss_3d_pos'] + \
+            args.lambda_scale * loss_dict['loss_3d_scale'] + \
+            args.lambda_3dv * loss_dict['loss_3d_velocity'] + \
+            args.lambda_lv * loss_dict['loss_lv'] + \
+            args.lambda_lg * loss_dict['loss_lg'] + \
+            args.lambda_a * loss_dict['loss_a'] + \
+            args.lambda_av * loss_dict['loss_av'] + \
+            args.lambda_shape * loss_dict['loss_shape'] + \
+            args.lambda_pose * loss_dict['loss_pose'] + \
+            args.lambda_norm * loss_dict['loss_norm']
+        losses_train.update(loss_train.item(), batch_size)
+        for k, v in loss_dict.items():
+            losses_dict[k].update(v.item(), batch_size)
+
+        mpjpe, mpve = compute_error(output, batch_gt)
+        mpjpes.update(mpjpe, batch_size)
+        mpves.update(mpve, batch_size)
+
+        loss_train.backward()
+        optimizer.step()
+
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % print_freq == 0:
+            # Build the per-term summary only when it is printed; the meters
+            # already hold the values from this batch.
+            loss_str = ''.join(
+                '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k])
+                for k in loss_dict
+            )
+            print('Train: [{0}][{1}/{2}]\t'
+                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
+                  '{3}'
+                  'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
+                  'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
+                      epoch, idx + 1, len(train_loader), loss_str, batch_time=batch_time,
+                      data_time=data_time, loss=losses_train, mpves=mpves, mpjpes=mpjpes))
+            sys.stdout.flush()
+
+def train_with_config(args, opts):
+    """Train the mesh regressor on the configured datasets, or only evaluate it.
+
+    Datasets are enabled by the presence of ``dt_file_h36m``, ``dt_file_pw3d``
+    and ``dt_file_coco`` on ``args``. The best checkpoint is chosen by the
+    pw3d MPJPE when pw3d is configured, otherwise by the h36m MPJPE.
+
+    Args:
+        args: experiment configuration (the object returned by ``get_config``).
+        opts: parsed command-line options (see ``parse_args``).
+
+    Side effects:
+        Creates ``opts.checkpoint`` and copies the config file into it when the
+        directory is newly created. Writes TensorBoard scalars to its ``logs``
+        sub-directory. When training, saves ``latest_epoch.bin`` every epoch,
+        ``epoch_<n>.bin`` every ``args.checkpoint_frequency`` epochs and
+        ``best_epoch.bin`` whenever the selection metric improves.
+
+    Raises:
+        RuntimeError: if the checkpoint directory cannot be created.
+    """
+    print(args)
+    try:
+        os.makedirs(opts.checkpoint)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise RuntimeError('Unable to create checkpoint directory: {}'.format(opts.checkpoint)) from e
+    else:
+        # Only a freshly created directory receives a copy of the config. A
+        # failure here is a copy problem and surfaces as its own OSError
+        # rather than as a directory-creation error.
+        shutil.copy(opts.config, opts.checkpoint)
+    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
+    model_backbone = load_backbone(args)
+    if args.finetune and not (opts.resume or opts.evaluate):
+        chk_filename = os.path.join(opts.pretrained, opts.selection)
+        print('Loading backbone', chk_filename)
+        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos']
+        model_backbone = load_pretrained_weights(model_backbone, checkpoint)
+    if args.partial_train:
+        model_backbone = partial_train_layers(model_backbone, args.partial_train)
+    model = MeshRegressor(args, backbone=model_backbone, dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout, num_joints=args.num_joints)
+    criterion = MeshLoss(loss_type = args.loss_type)
+    best_jpe = 9999.0
+    model_params = sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
+    print('INFO: Trainable parameter count:', model_params)
+    print('Loading dataset...')
+
+    def loader_params(batch_size, shuffle):
+        # DataLoader settings shared by every loader built in this function.
+        return {
+            'batch_size': batch_size,
+            'shuffle': shuffle,
+            'num_workers': 8,
+            'pin_memory': True,
+            'prefetch_factor': 4,
+            'persistent_workers': True
+        }
+
+    trainloader_params = loader_params(args.batch_size, True)
+    testloader_params = loader_params(args.batch_size, False)
+    if hasattr(args, "dt_file_h36m"):
+        mesh_train = MotionSMPL(args, data_split='train', dataset="h36m")
+        mesh_val = MotionSMPL(args, data_split='test', dataset="h36m")
+        train_loader = DataLoader(mesh_train, **trainloader_params)
+        test_loader = DataLoader(mesh_val, **testloader_params)
+        print('INFO: Training on {} batches (h36m)'.format(len(train_loader)))
+
+    if hasattr(args, "dt_file_pw3d"):
+        if args.train_pw3d:
+            mesh_train_pw3d = MotionSMPL(args, data_split='train', dataset="pw3d")
+            train_loader_pw3d = DataLoader(mesh_train_pw3d, **trainloader_params)
+            print('INFO: Training on {} batches (pw3d)'.format(len(train_loader_pw3d)))
+        mesh_val_pw3d = MotionSMPL(args, data_split='test', dataset="pw3d")
+        test_loader_pw3d = DataLoader(mesh_val_pw3d, **testloader_params)
+
+    if hasattr(args, "dt_file_coco"):
+        # batch_size_img is read only when COCO is configured, so configs
+        # without image data do not need to define it.
+        trainloader_img_params = loader_params(args.batch_size_img, True)
+        testloader_img_params = loader_params(args.batch_size_img, False)
+        mesh_train_coco = MotionSMPL(args, data_split='train', dataset="coco")
+        mesh_val_coco = MotionSMPL(args, data_split='test', dataset="coco")
+        train_loader_coco = DataLoader(mesh_train_coco, **trainloader_img_params)
+        test_loader_coco = DataLoader(mesh_val_coco, **testloader_img_params)
+        print('INFO: Training on {} batches (coco)'.format(len(train_loader_coco)))
+
+    if torch.cuda.is_available():
+        model = nn.DataParallel(model)
+        model = model.cuda()
+
+    # An existing latest checkpoint in the output directory implies resuming.
+    chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
+    if os.path.exists(chk_filename):
+        opts.resume = chk_filename
+    if opts.resume or opts.evaluate:
+        chk_filename = opts.evaluate if opts.evaluate else opts.resume
+        print('Loading checkpoint', chk_filename)
+        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
+        model.load_state_dict(checkpoint['model'], strict=True)
+    if not opts.evaluate:
+        # The model is wrapped in DataParallel only when CUDA is available.
+        net = model.module if isinstance(model, nn.DataParallel) else model
+        optimizer = optim.AdamW(
+            [
+                {"params": filter(lambda p: p.requires_grad, net.backbone.parameters()), "lr": args.lr_backbone},
+                {"params": filter(lambda p: p.requires_grad, net.head.parameters()), "lr": args.lr_head},
+            ],
+            lr=args.lr_backbone,
+            weight_decay=args.weight_decay
+        )
+        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
+        st = 0
+        if opts.resume:
+            st = checkpoint['epoch']
+            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
+                optimizer.load_state_dict(checkpoint['optimizer'])
+            else:
+                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
+            if 'best_jpe' in checkpoint and checkpoint['best_jpe'] is not None:
+                best_jpe = checkpoint['best_jpe']
+
+        def save_checkpoint(path, epoch):
+            # Snapshot of everything needed to resume after `epoch`.
+            torch.save({
+                'epoch': epoch + 1,
+                'lr': scheduler.get_last_lr(),
+                'optimizer': optimizer.state_dict(),
+                'model': model.state_dict(),
+                'best_jpe': best_jpe
+            }, path)
+
+        loss_names = ('loss_3d_pos', 'loss_3d_scale', 'loss_3d_velocity', 'loss_lv', 'loss_lg',
+                      'loss_a', 'loss_av', 'loss_pose', 'loss_shape', 'loss_norm')
+
+        # Training
+        for epoch in range(st, args.epochs):
+            print('Training epoch %d.' % epoch)
+            losses_train = AverageMeter()
+            losses_dict = {name: AverageMeter() for name in loss_names}
+            mpjpes = AverageMeter()
+            mpves = AverageMeter()
+            batch_time = AverageMeter()
+            data_time = AverageMeter()
+            # Reset each epoch so a dataset that was not validated this epoch
+            # cannot contribute a stale or undefined value to best-model selection.
+            test_mpjpe = None
+            test_mpjpe_pw3d = None
+
+            if hasattr(args, "dt_file_h36m") and epoch < args.warmup_h36m:
+                train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
+                test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, test_losses_dict = validate(test_loader, model, criterion, 'h36m')
+                for k, v in test_losses_dict.items():
+                    train_writer.add_scalar('test_loss/'+k, v.avg, epoch + 1)
+                train_writer.add_scalar('test_loss', test_loss, epoch + 1)
+                train_writer.add_scalar('test_mpjpe', test_mpjpe, epoch + 1)
+                train_writer.add_scalar('test_pa_mpjpe', test_pa_mpjpe, epoch + 1)
+                train_writer.add_scalar('test_mpve', test_mpve, epoch + 1)
+
+            if hasattr(args, "dt_file_coco") and epoch < args.warmup_coco:
+                train_epoch(args, opts, model, train_loader_coco, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
+
+            if hasattr(args, "dt_file_pw3d"):
+                if args.train_pw3d:
+                    train_epoch(args, opts, model, train_loader_pw3d, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
+                test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, test_losses_dict_pw3d = validate(test_loader_pw3d, model, criterion, 'pw3d')
+                for k, v in test_losses_dict_pw3d.items():
+                    train_writer.add_scalar('test_loss_pw3d/'+k, v.avg, epoch + 1)
+                train_writer.add_scalar('test_loss_pw3d', test_loss_pw3d, epoch + 1)
+                train_writer.add_scalar('test_mpjpe_pw3d', test_mpjpe_pw3d, epoch + 1)
+                train_writer.add_scalar('test_pa_mpjpe_pw3d', test_pa_mpjpe_pw3d, epoch + 1)
+                train_writer.add_scalar('test_mpve_pw3d', test_mpve_pw3d, epoch + 1)
+
+            for k, v in losses_dict.items():
+                train_writer.add_scalar('train_loss/'+k, v.avg, epoch + 1)
+            train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1)
+            train_writer.add_scalar('train_mpjpe', mpjpes.avg, epoch + 1)
+            train_writer.add_scalar('train_mpve', mpves.avg, epoch + 1)
+
+            # Decay learning rate exponentially
+            scheduler.step()
+
+            # Update the running best before saving anything, so that the
+            # latest and periodic checkpoints also record this epoch's result.
+            best_jpe_cur = test_mpjpe_pw3d if hasattr(args, "dt_file_pw3d") else test_mpjpe
+            is_best = best_jpe_cur is not None and best_jpe_cur < best_jpe
+            if is_best:
+                best_jpe = best_jpe_cur
+
+            # Save latest checkpoint.
+            chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
+            print('Saving checkpoint to', chk_path)
+            save_checkpoint(chk_path, epoch)
+
+            # Save checkpoint if necessary.
+            if (epoch+1) % args.checkpoint_frequency == 0:
+                chk_path = os.path.join(opts.checkpoint, 'epoch_{}.bin'.format(epoch))
+                print('Saving checkpoint to', chk_path)
+                save_checkpoint(chk_path, epoch)
+
+            # Save best checkpoint.
+            if is_best:
+                print("save best checkpoint")
+                save_checkpoint(os.path.join(opts.checkpoint, 'best_epoch.bin'), epoch)
+
+    if opts.evaluate:
+        if hasattr(args, "dt_file_h36m"):
+            test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, _ = validate(test_loader, model, criterion, 'h36m')
+        if hasattr(args, "dt_file_pw3d"):
+            test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, _ = validate(test_loader_pw3d, model, criterion, 'pw3d')
+
+if __name__ == "__main__":  # Script entry point: runs only when the file is executed directly.
+ opts = parse_args()  # Module-level name: validate() reads `opts.print_freq` as a global, so do not rename.
+ set_random_seed(opts.seed)  # Seed Python, NumPy and PyTorch RNGs from --seed before anything else runs.
+ args = get_config(opts.config)  # Module-level name: validate() reads `args` (data_root, flip, lambda_*) as a global.
+ train_with_config(args, opts)  # Train, or only evaluate when --evaluate is given.
\ No newline at end of file