Hieu Ngoc Giap committed
Commit 2b07837 • 1 Parent(s): ecf1cc2
Upload folder using huggingface_hub
Files changed:
- pose/config/config.yaml +29 -0
- pose/config/load_cfg.py +41 -0
- pose/logs/__init__.py +32 -0
- pose/poetry.lock +0 -0
- pose/pyproject.toml +48 -0
- pose/src/inference/__init__.py +0 -0
- pose/src/inference/__pycache__/__init__.cpython-39.pyc +0 -0
- pose/src/inference/__pycache__/base.cpython-39.pyc +0 -0
- pose/src/inference/__pycache__/decode.cpython-39.pyc +0 -0
- pose/src/inference/__pycache__/pose_inference.cpython-39.pyc +0 -0
- pose/src/inference/base.py +57 -0
- pose/src/inference/decode.py +537 -0
- pose/src/inference/pose_inference.py +132 -0
- pose/src/weights/pose_model_scratch.pth +3 -0
pose/config/config.yaml
ADDED
@@ -0,0 +1,29 @@
# -------------- Data -------------
data_root_path: ./data/raw
train_mask_data_path: ./data/raw/mask/train2014/mask_COCO_train2014_
val_mask_data_path: ./data/raw/mask/val2014/mask_COCO_val2014_
label_file: label.json
label_subset_file: label_subset.json

# -------------- Model -------------
model_weight_path: ./src/weights/pose_model_scratch.pth

# -------------- Logging ----------
logging_file: ./logs/logging_file.log

# ------------- Hyperparameters ------------
hyperparameters:
  train_batch_size: 8
  val_batch_size: 8
  lr: 0.001
  betas: [0.9, 0.999]
  weight_decay: 0.0001
  epochs: 20

# -------------- DVC remote ---------------
dvc_remote_name: gcs-storage
dvc_remote_url: gs://human-pose-data-bucket/data


# -------------- MLflow --------------
experiment_name: openpose-human-pose-training
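For reference, the hyperparameters block above maps one-to-one onto a torch.optim.Adam constructor. A minimal sketch, assuming it is run from the pose/ directory so that config/config.yaml resolves; the nn.Linear stands in for the real network:

# Illustrative sketch only: wiring the hyperparameters block into an optimizer.
import torch
import yaml

with open("config/config.yaml") as f:  # assumes the pose/ directory is the CWD
    hp = yaml.safe_load(f)["hyperparameters"]

model = torch.nn.Linear(4, 2)  # placeholder model, not the actual OpenPose net
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=hp["lr"],                      # 0.001
    betas=tuple(hp["betas"]),         # (0.9, 0.999)
    weight_decay=hp["weight_decay"],  # 0.0001
)
print(optimizer)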
pose/config/load_cfg.py
ADDED
@@ -0,0 +1,41 @@
from pathlib import Path
from typing import Optional

import yaml

ROOT = Path(__file__).resolve().parent.parent
CONFIG_FILE_PATH = ROOT / "config" / "config.yaml"


class DictDotNotation(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def _find_config_file() -> Path:
    """Locate the configuration file."""
    if CONFIG_FILE_PATH.is_file():
        return CONFIG_FILE_PATH
    raise FileNotFoundError(f"Config file not found at {CONFIG_FILE_PATH}")


def load_config_file(cfg_path: Optional[Path] = None) -> Optional[dict]:
    if not cfg_path:
        cfg_path = _find_config_file()

    if cfg_path:
        with open(cfg_path, "r") as f:
            yaml_data = yaml.safe_load(f)
        if not yaml_data:
            raise ValueError("Invalid or empty YAML configuration")
        return yaml_data


def configure() -> DictDotNotation:
    cfg = load_config_file()
    cfg = DictDotNotation(cfg)
    return cfg


cfg = configure()
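A short usage sketch for the module above, assuming the pose/ directory is on sys.path (the repo's own modules import it as `from config import cfg`). Note that DictDotNotation only wraps the top level, so nested mappings such as hyperparameters are still accessed as plain dicts:

# Illustrative usage of the cfg object exported by load_cfg.py.
from config.load_cfg import cfg  # assumed import path; config/__init__.py may re-export it

print(cfg.model_weight_path)          # ./src/weights/pose_model_scratch.pth
print(cfg.logging_file)               # ./logs/logging_file.log
print(cfg.hyperparameters["lr"])      # 0.001 (nested dicts keep item access)
print(cfg.hyperparameters["epochs"])  # 20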
pose/logs/__init__.py
ADDED
@@ -0,0 +1,32 @@
from typing import TypeVar

import loguru

from config import cfg

log_level = "DEBUG"
log_format = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS zz}</green> | "
    "<level>{level: <8}</level> | "
    "<yellow>Line {line: >4} ({file}):</yellow> <b>{message}</b>"
)

_T_logoru_logger = TypeVar("_T_logoru_logger", bound=loguru._logger.Logger)


def logger_handler(
    use_log_file: bool = True, file: str = "./logs/logging_file.log"
) -> _T_logoru_logger:
    if use_log_file:
        loguru.logger.add(
            file,
            level=log_level,
            format=log_format,
            colorize=False,
            backtrace=True,
            diagnose=True,
        )
    return loguru.logger


log = logger_handler(file=cfg.logging_file)
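A small sketch of how the shared log object is meant to be used, again assuming the pose/ directory is on sys.path so that logs and config resolve as top-level packages:

# Illustrative usage of the project-wide logger configured above.
from logs import log

log.info("Starting pose inference ...")
log.debug("Debug messages also reach the file sink (log_level is DEBUG)")

try:
    raise RuntimeError("example failure")
except RuntimeError:
    # backtrace=True and diagnose=True on the file sink produce a detailed traceback.
    log.exception("Something went wrong during inference")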
pose/poetry.lock
ADDED
The diff for this file is too large to render.
pose/pyproject.toml
ADDED
@@ -0,0 +1,48 @@
[tool.poetry]
name = "human-pose-estimation-development"
version = "0.1.0"
description = "Development stage of human pose estimation system"
authors = ["To Duc Thanh"]
license = "MIT"
readme = "README.md"
package-mode = false

[tool.poetry.dependencies]
python = ">=3.9,<3.9.7"
pyyaml = "^6.0.1"
torch = "^2.2.2"
torchvision = "^0.17.2"
opencv-python = "^4.9.0.80"
tqdm = "^4.66.2"
loguru = "^0.7.2"
matplotlib = "^3.8.4"
python-dotenv = "^1.0.1"
mlflow = "^2.11.3"
pynvml = "^11.5.0"
dvc = {extras = ["gdrive", "gs"], version = "^3.48.4"}
scipy = "^1.13.0"
minio = "^7.2.5"


[tool.poetry.group.dev.dependencies]
isort = "^5.13.2"
pytest = "^8.1.1"
pre-commit = "^3.7.0"
jupyterlab = "^4.1.5"
ruff = "^0.3.5"


[tool.poetry.group.deploy.dependencies]
torch = "^2.2.2"
torchvision = "^0.17.2"
opencv-python = "^4.9.0.80"
scipy = "^1.13.0"
loguru = "^0.7.2"
matplotlib = "^3.8.4"
fastapi = "^0.110.1"
uvicorn = {extras = ["standard"], version = "^0.29.0"}
python-multipart = "^0.0.9"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
pose/src/inference/__init__.py
ADDED
File without changes
pose/src/inference/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (207 Bytes).
pose/src/inference/__pycache__/base.cpython-39.pyc
ADDED
Binary file (2.15 kB).
pose/src/inference/__pycache__/decode.cpython-39.pyc
ADDED
Binary file (12.5 kB).
pose/src/inference/__pycache__/pose_inference.cpython-39.pyc
ADDED
Binary file (3.9 kB).
pose/src/inference/base.py
ADDED
@@ -0,0 +1,57 @@
from abc import (
    ABC,
    abstractmethod,
)
from typing import Union

import numpy as np
import torch


class PoseInferenceBase(ABC):
    @abstractmethod
    def preprocess(
        self, img: Union[np.ndarray, str], *args, **kwargs
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Preprocesses the input image before inference.

        Args:
            img (Union[np.ndarray, str]): The input image as a NumPy array or a path to the image file.

        Returns:
            Union[torch.Tensor, np.ndarray]: The preprocessed image tensor or array ready for inference.
        """
        raise NotImplementedError

    @abstractmethod
    def process(
        self, img: Union[np.ndarray, str], *args, **kwargs
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Performs inference on the input image.

        Args:
            img (Union[np.ndarray, str]): The input image as a NumPy array or a path to the image file.

        Returns:
            Union[torch.Tensor, np.ndarray]: The output of the inference process.
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess(
        self, oriImg: np.ndarray, heatmaps: np.ndarray, pafs: np.ndarray, *args, **kwargs
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Postprocesses the inference results.

        Args:
            oriImg (np.ndarray): The original input image.
            heatmaps (np.ndarray): The heatmaps generated by the inference.
            pafs (np.ndarray): The Part Affinity Fields (PAFs) generated by the inference.

        Returns:
            Union[torch.Tensor, np.ndarray]: The postprocessed results.
        """
        raise NotImplementedError
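The concrete implementation of this interface is PoseInference in pose_inference.py below; as an aside, a toy subclass shows what the ABC enforces (the class name and return values here are made up for illustration, and the import path is an assumption):

# Illustration only: the abstract base class requires all three hooks.
import numpy as np

from src.inference.base import PoseInferenceBase  # assumed import path


class EchoPoseInference(PoseInferenceBase):
    def preprocess(self, img, *args, **kwargs):
        return np.zeros((1, 3, 368, 368), dtype=np.float32)  # dummy tensor-shaped array

    def process(self, img, *args, **kwargs):
        return self.preprocess(img)

    def postprocess(self, oriImg, heatmaps, pafs, *args, **kwargs):
        return oriImg


# Leaving any of the three methods unimplemented makes instantiation raise TypeError.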
pose/src/inference/decode.py
ADDED
@@ -0,0 +1,537 @@
import math

from typing import (
    List,
    Tuple,
)

import cv2
import matplotlib.cm
import numpy as np

from scipy.ndimage.filters import (
    gaussian_filter,
    maximum_filter,
)
from scipy.ndimage.morphology import generate_binary_structure

# It is better to use 0.1 as threshold when evaluation, but 0.3 for demo
# purpose.
cmap = matplotlib.cm.get_cmap("hsv")

# Heatmap indices to find each limb (joint connection). Eg: limb_type=1 is
# Neck->LShoulder, so joint_to_limb_heatmap_relationship[1] represents the
# indices of heatmaps to look for joints: neck=1, LShoulder=5
joint_to_limb_heatmap_relationship = [
    [1, 2],
    [1, 5],
    [2, 3],
    [3, 4],
    [5, 6],
    [6, 7],
    [1, 8],
    [8, 9],
    [9, 10],
    [1, 11],
    [11, 12],
    [12, 13],
    [1, 0],
    [0, 14],
    [14, 16],
    [0, 15],
    [15, 17],
    [2, 16],
    [5, 17],
]

# PAF indices containing the x and y coordinates of the PAF for a given limb.
# Eg: limb_type=1 is Neck->LShoulder, so
# PAFneckLShoulder_x=paf_xy_coords_per_limb[1][0] and
# PAFneckLShoulder_y=paf_xy_coords_per_limb[1][1]
paf_xy_coords_per_limb = [
    [12, 13],
    [20, 21],
    [14, 15],
    [16, 17],
    [22, 23],
    [24, 25],
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7],
    [8, 9],
    [10, 11],
    [28, 29],
    [30, 31],
    [34, 35],
    [32, 33],
    [36, 37],
    [18, 19],
    [26, 27],
]

# Color code used to plot different joints and limbs (eg: joint_type=3 and
# limb_type=3 will use colors[3])
colors = [
    [255, 0, 0],
    [255, 85, 0],
    [255, 170, 0],
    [255, 255, 0],
    [170, 255, 0],
    [85, 255, 0],
    [0, 255, 0],
    [0, 255, 85],
    [0, 255, 170],
    [0, 255, 255],
    [0, 170, 255],
    [0, 85, 255],
    [0, 0, 255],
    [85, 0, 255],
    [170, 0, 255],
    [255, 0, 255],
    [255, 0, 170],
    [255, 0, 85],
    [255, 0, 0],
]

NUM_JOINTS = 18
NUM_LIMBS = len(joint_to_limb_heatmap_relationship)


def find_peaks(param: dict, img: np.ndarray) -> np.ndarray:
    """
    Finds local maxima in a (grayscale) image whose values are above a given threshold.

    Args:
        param (dict): A dictionary containing parameters.
            - 'thre1' (float): Threshold for peak detection.
        img (np.ndarray): Input grayscale image (2D array) where peaks are to be found.

    Returns:
        np.ndarray: A 2D array containing the [x, y] coordinates of each peak found in the image.
    """
    peaks_binary = (maximum_filter(img, footprint=generate_binary_structure(2, 1)) == img) * (
        img > param["thre1"]
    )
    # Note reverse ([::-1]): we return [[x y], [x y]...] instead of [[y x], [y x]...]
    return np.array(np.nonzero(peaks_binary)[::-1]).T


def compute_resized_coords(coords: Tuple[float, float], resizeFactor: float) -> np.ndarray:
    """
    Computes the new coordinates of a cell in an array after resizing the array.

    Args:
        coords (Tuple[float, float]): Coordinates (indices) of a cell in some input array.
        resizeFactor (float): Resize coefficient, indicating how much bigger the destination array is compared to the original one.

    Returns:
        np.ndarray: Coordinates in an array of size `shape_dest`, expressing the array indices of the closest point to 'coords'
            if an image of size `shape_source` was resized to `shape_dest`.
    """

    # 1) Add 0.5 to coords to get coordinates of center of the pixel (e.g.
    # index [0,0] represents the pixel at location [0.5,0.5])
    # 2) Transform those coordinates to shape_dest, by multiplying by resizeFactor
    # 3) That number represents the location of the pixel center in the new array,
    # so subtract 0.5 to get coordinates of the array index/indices (revert step 1)
    return (np.array(coords, dtype=float) + 0.5) * resizeFactor - 0.5


def NMS(
    param: dict,
    heatmaps: np.ndarray,
    upsampFactor: float = 1.0,
    bool_refine_center: bool = True,
    bool_gaussian_filt: bool = False,
) -> List[np.ndarray]:
    """
    Performs Non-Maxima Suppression (NMS) to find peaks (local maxima) in a set of grayscale images.

    Args:
        param (dict): Additional parameters for NMS.
        heatmaps (np.ndarray): Set of grayscale images on which to find local maxima. A 3D numpy array with dimensions
            image_height x image_width x num_heatmaps.
        upsampFactor (float): Size ratio between CPM (Convolutional Pose Machine) heatmap output and the input image size.
            For example, upsampFactor=16 if the original image was 480x640 and heatmaps are 30x40xN.
        bool_refine_center (bool): Flag indicating whether to refine the center of the peak. Defaults to True.
            If True, the function upsamples a small patch around each low-res peak and fine-tunes the location of the peak
            at the resolution of the original input image. If False, simply returns the low-res peak found upscaled by upsampFactor.
        bool_gaussian_filt (bool): Flag indicating whether to apply a 1D Gaussian filter (smoothing) to each upsampled patch
            before fine-tuning the location of each peak. Defaults to False.

    Returns:
        np.ndarray: A NUM_JOINTS x 4 numpy array where each row represents a joint type (0=nose, 1=neck...) and the columns
            indicate the {x,y} position, the score (probability), and a unique id (counter).
    """

    joint_list_per_joint_type = []
    cnt_total_joints = 0

    # For every peak found, win_size specifies how many pixels in each
    # direction from the peak we take to obtain the patch that will be
    # upsampled. Eg: win_size=1 -> patch is 3x3; win_size=2 -> 5x5
    # (for BICUBIC interpolation to be accurate, win_size needs to be >=2!)
    win_size = 2

    for joint in range(NUM_JOINTS):
        map_orig = heatmaps[:, :, joint]
        peak_coords = find_peaks(param, map_orig)
        peaks = np.zeros((len(peak_coords), 4))
        for i, peak in enumerate(peak_coords):
            if bool_refine_center:
                x_min, y_min = np.maximum(0, peak - win_size)
                x_max, y_max = np.minimum(np.array(map_orig.T.shape) - 1, peak + win_size)

                # Take a small patch around each peak and only upsample that tiny region
                patch = map_orig[y_min : y_max + 1, x_min : x_max + 1]
                map_upsamp = cv2.resize(
                    patch, None, fx=upsampFactor, fy=upsampFactor, interpolation=cv2.INTER_CUBIC
                )

                # Gaussian filtering takes an average of 0.8ms/peak (and there might be
                # more than one peak per joint!) -> For now, skip it (it's
                # accurate enough)
                map_upsamp = (
                    gaussian_filter(map_upsamp, sigma=3) if bool_gaussian_filt else map_upsamp
                )

                # Obtain the coordinates of the maximum value in the patch
                location_of_max = np.unravel_index(map_upsamp.argmax(), map_upsamp.shape)
                # Remember that peaks indicates [x,y] -> need to reverse it for
                # [y,x]
                location_of_patch_center = compute_resized_coords(
                    peak[::-1] - [y_min, x_min], upsampFactor
                )
                # Calculate the offset wrt to the patch center where the actual
                # maximum is
                refined_center = location_of_max - location_of_patch_center
                peak_score = map_upsamp[location_of_max]
            else:
                refined_center = [0, 0]
                # Flip peak coordinates since they are [x,y] instead of [y,x]
                peak_score = map_orig[tuple(peak[::-1])]
            peaks[i, :] = tuple(
                [
                    int(round(x))
                    for x in compute_resized_coords(peak_coords[i], upsampFactor)
                    + refined_center[::-1]
                ]
            ) + (peak_score, cnt_total_joints)
            cnt_total_joints += 1
        joint_list_per_joint_type.append(peaks)

    return joint_list_per_joint_type


def find_connected_joints(
    param: dict,
    paf_upsamp: np.ndarray,
    joint_list_per_joint_type: List[np.ndarray],
    num_intermed_pts: int = 10,
) -> List[np.ndarray]:
    """
    For every type of limb (e.g., forearm, shin, etc.), looks for every potential
    pair of joints (e.g., every wrist-elbow combination) and evaluates the PAFs to
    determine which pairs are indeed body limbs.

    Args:
        paf_upsamp (np.ndarray): PAFs upsampled to the original input image resolution.
        joint_list_per_joint_type (List[np.ndarray]): List of joint lists per joint type. See the 'return' doc of NMS().
        num_intermed_pts (int): Number of intermediate points to take between joint_src and joint_dst, at which
            the PAFs will be evaluated. Defaults to 10.

    Returns:
        List[np.ndarray]: List of NUM_LIMBS rows. For every limb_type (a row) we store a list of all limbs of that type found
            (e.g., all the right forearms). For each limb (each item in connected_limbs[limb_type]), we store 5 cells:
            {joint_src_id, joint_dst_id}: a unique number associated with each joint,
            limb_score_penalizing_long_dist: a score of how good a connection of the joints is, penalized if the limb length is too long,
            {joint_src_index, joint_dst_index}: the index of the joint within all the joints of that type found
            (e.g., the 3rd right elbow found)
    """
    connected_limbs = []

    # Auxiliary array to access paf_upsamp quickly
    limb_intermed_coords = np.empty((4, num_intermed_pts), dtype=np.intp)
    for limb_type in range(NUM_LIMBS):
        # List of all joints of type A found, where A is specified by limb_type
        # (eg: a right forearm starts in a right elbow)
        joints_src = joint_list_per_joint_type[joint_to_limb_heatmap_relationship[limb_type][0]]
        # List of all joints of type B found, where B is specified by limb_type
        # (eg: a right forearm ends in a right wrist)
        joints_dst = joint_list_per_joint_type[joint_to_limb_heatmap_relationship[limb_type][1]]
        if len(joints_src) == 0 or len(joints_dst) == 0:
            # No limbs of this type found (eg: no right forearms found because
            # we didn't find any right wrists or right elbows)
            connected_limbs.append([])
        else:
            connection_candidates = []
            # Specify the paf index that contains the x-coord of the paf for
            # this limb
            limb_intermed_coords[2, :] = paf_xy_coords_per_limb[limb_type][0]
            # And the y-coord paf index
            limb_intermed_coords[3, :] = paf_xy_coords_per_limb[limb_type][1]
            for i, joint_src in enumerate(joints_src):
                # Try every possible joints_src[i]-joints_dst[j] pair and see
                # if it's a feasible limb
                for j, joint_dst in enumerate(joints_dst):
                    # Subtract the position of both joints to obtain the
                    # direction of the potential limb
                    limb_dir = joint_dst[:2] - joint_src[:2]
                    # Compute the distance/length of the potential limb (norm
                    # of limb_dir)
                    limb_dist = np.sqrt(np.sum(limb_dir**2)) + 1e-8
                    limb_dir = limb_dir / limb_dist  # Normalize limb_dir to be a unit vector

                    # Linearly distribute num_intermed_pts points from the x
                    # coordinate of joint_src to the x coordinate of joint_dst
                    limb_intermed_coords[1, :] = np.round(
                        np.linspace(joint_src[0], joint_dst[0], num=num_intermed_pts)
                    )
                    limb_intermed_coords[0, :] = np.round(
                        np.linspace(joint_src[1], joint_dst[1], num=num_intermed_pts)
                    )  # Same for the y coordinate
                    intermed_paf = paf_upsamp[
                        limb_intermed_coords[0, :],
                        limb_intermed_coords[1, :],
                        limb_intermed_coords[2:4, :],
                    ].T

                    score_intermed_pts = intermed_paf.dot(limb_dir)
                    score_penalizing_long_dist = score_intermed_pts.mean() + min(
                        0.5 * paf_upsamp.shape[0] / limb_dist - 1, 0
                    )
                    # Criterion 1: At least 80% of the intermediate points have
                    # a score higher than thre2
                    criterion1 = (
                        np.count_nonzero(score_intermed_pts > param["thre2"])
                        > 0.8 * num_intermed_pts
                    )
                    # Criterion 2: Mean score, penalized for large limb
                    # distances (larger than half the image height), is
                    # positive
                    criterion2 = score_penalizing_long_dist > 0
                    if criterion1 and criterion2:
                        # Last value is the combined paf(+limb_dist) + heatmap
                        # scores of both joints
                        connection_candidates.append(
                            [
                                i,
                                j,
                                score_penalizing_long_dist,
                                score_penalizing_long_dist + joint_src[2] + joint_dst[2],
                            ]
                        )

            # Sort connection candidates based on their
            # score_penalizing_long_dist
            connection_candidates = sorted(connection_candidates, key=lambda x: x[2], reverse=True)
            connections = np.empty((0, 5))
            # There can only be as many limbs as the smallest number of source
            # or destination joints (eg: only 2 forearms if there's 5 wrists
            # but 2 elbows)
            max_connections = min(len(joints_src), len(joints_dst))
            # Traverse all potential joint connections (sorted by their score)
            for potential_connection in connection_candidates:
                i, j, s = potential_connection[0:3]
                # Make sure joints_src[i] or joints_dst[j] haven't already been
                # connected to other joints_dst or joints_src
                if i not in connections[:, 3] and j not in connections[:, 4]:
                    # [joint_src_id, joint_dst_id, limb_score_penalizing_long_dist, joint_src_index, joint_dst_index]
                    connections = np.vstack(
                        [connections, [joints_src[i][3], joints_dst[j][3], s, i, j]]
                    )
                    # Exit if we've already established max_connections
                    # connections (each joint can't be connected to more than
                    # one joint)
                    if len(connections) >= max_connections:
                        break
            connected_limbs.append(connections)

    return connected_limbs


def group_limbs_of_same_person(
    connected_limbs: List[np.ndarray], joint_list: np.ndarray
) -> np.ndarray:
    """
    Associate limbs belonging to the same person together.

    Args:
        connected_limbs (List[np.ndarray]): List of connected limbs.
            See the 'return' doc of find_connected_joints().
        joint_list (np.ndarray): Unraveled version of joint_list_per_joint.
            See the 'return' doc of NMS().

    Returns:
        np.ndarray: A 2D array of size num_people x (NUM_JOINTS+2). For each person found:
            - First NUM_JOINTS columns contain the index (in joint_list) of the joints associated with that person
              (or -1 if their i-th joint wasn't found).
            - 2nd-to-last column: Overall score of the joints+limbs that belong to this person.
    """
    person_to_joint_assoc = []

    for limb_type in range(NUM_LIMBS):
        joint_src_type, joint_dst_type = joint_to_limb_heatmap_relationship[limb_type]

        for limb_info in connected_limbs[limb_type]:
            person_assoc_idx = []
            for person, person_limbs in enumerate(person_to_joint_assoc):
                if (
                    person_limbs[joint_src_type] == limb_info[0]
                    or person_limbs[joint_dst_type] == limb_info[1]
                ):
                    person_assoc_idx.append(person)

            # If one of the joints has been associated to a person, and either
            # the other joint is also associated with the same person or not
            # associated to anyone yet:
            if len(person_assoc_idx) == 1:
                person_limbs = person_to_joint_assoc[person_assoc_idx[0]]
                # If the other joint is not associated to anyone yet,
                if person_limbs[joint_dst_type] != limb_info[1]:
                    # Associate it with the current person
                    person_limbs[joint_dst_type] = limb_info[1]
                    # Increase the number of limbs associated to this person
                    person_limbs[-1] += 1
                    # And update the total score (+= heatmap score of joint_dst
                    # + score of connecting joint_src with joint_dst)
                    person_limbs[-2] += joint_list[limb_info[1].astype(int), 2] + limb_info[2]
            elif len(person_assoc_idx) == 2:  # if found 2 and disjoint, merge them
                person1_limbs = person_to_joint_assoc[person_assoc_idx[0]]
                person2_limbs = person_to_joint_assoc[person_assoc_idx[1]]
                membership = ((person1_limbs >= 0) & (person2_limbs >= 0))[:-2]
                if (
                    not membership.any()
                ):  # If both people have no same joints connected, merge them into a single person
                    # Update which joints are connected
                    person1_limbs[:-2] += person2_limbs[:-2] + 1
                    # Update the overall score and total count of joints
                    # connected by summing their counters
                    person1_limbs[-2:] += person2_limbs[-2:]
                    # Add the score of the current joint connection to the
                    # overall score
                    person1_limbs[-2] += limb_info[2]
                    person_to_joint_assoc.pop(person_assoc_idx[1])
                else:  # Same case as len(person_assoc_idx)==1 above
                    person1_limbs[joint_dst_type] = limb_info[1]
                    person1_limbs[-1] += 1
                    person1_limbs[-2] += joint_list[limb_info[1].astype(int), 2] + limb_info[2]
            else:  # No person has claimed any of these joints, create a new person
                # Initialize person info to all -1 (no joint associations)
                row = -1 * np.ones(20)
                # Store the joint info of the new connection
                row[joint_src_type] = limb_info[0]
                row[joint_dst_type] = limb_info[1]
                # Total count of connected joints for this person: 2
                row[-1] = 2
                # Compute overall score: score joint_src + score joint_dst + score connection
                # {joint_src,joint_dst}
                row[-2] = sum(joint_list[limb_info[:2].astype(int), 2]) + limb_info[2]
                person_to_joint_assoc.append(row)

    # Delete people who have very few parts connected
    people_to_delete = []
    for person_id, person_info in enumerate(person_to_joint_assoc):
        if person_info[-1] < 3 or person_info[-2] / person_info[-1] < 0.2:
            people_to_delete.append(person_id)
    # Traverse the list in reverse order so we delete indices starting from the
    # last one (otherwise, removing item for example 0 would modify the indices of
    # the remaining people to be deleted!)
    for index in people_to_delete[::-1]:
        person_to_joint_assoc.pop(index)

    # Appending items to a np.array can be very costly (allocating new memory, copying over the array, then adding new row)
    # Instead, we treat the set of people as a list (fast to append items) and
    # only convert to np.array at the end
    return np.array(person_to_joint_assoc)


def plot_pose(
    img_orig: np.ndarray,
    joint_list: np.ndarray,
    person_to_joint_assoc: np.ndarray,
    bool_fast_plot: bool = True,
    plot_ear_to_shoulder: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    canvas = img_orig.copy()  # Make a copy so we don't modify the original image

    # to_plot is the location of all joints found overlaid on top of the
    # original image
    to_plot = canvas.copy() if bool_fast_plot else cv2.addWeighted(img_orig, 0.3, canvas, 0.7, 0)

    limb_thickness = 4
    # Last 2 limbs connect ears with shoulders and this looks very weird.
    # Disabled by default to be consistent with original rtpose output
    which_limbs_to_plot = NUM_LIMBS if plot_ear_to_shoulder else NUM_LIMBS - 2
    for limb_type in range(which_limbs_to_plot):
        for person_joint_info in person_to_joint_assoc:
            joint_indices = person_joint_info[
                joint_to_limb_heatmap_relationship[limb_type]
            ].astype(int)
            if -1 in joint_indices:
                # Only draw actual limbs (connected joints), skip if not
                # connected
                continue
            # joint_coords[:,0] represents Y coords of both joints;
            # joint_coords[:,1], X coords
            joint_coords = joint_list[joint_indices, 0:2]

            for joint in joint_coords:
                cv2.circle(canvas, tuple(joint[0:2].astype(int)), 2, (255, 255, 255), thickness=-1)
            # mean along the axis=0 computes meanYcoord and meanXcoord -> Round
            # and make int to avoid errors
            coords_center = tuple(np.round(np.mean(joint_coords, 0)).astype(int))
            # joint_coords[0,:] is the coords of joint_src; joint_coords[1,:]
            # is the coords of joint_dst
            limb_dir = joint_coords[0, :] - joint_coords[1, :]
            limb_length = np.linalg.norm(limb_dir)
            # Get the angle of limb_dir in degrees using atan2(limb_dir_x, limb_dir_y)
            angle = math.degrees(math.atan2(limb_dir[1], limb_dir[0]))

            # For faster plotting, just plot over canvas instead of constantly
            # copying it
            cur_canvas = canvas if bool_fast_plot else canvas.copy()
            polygon = cv2.ellipse2Poly(
                coords_center, (int(limb_length / 2), limb_thickness), int(angle), 0, 360, 1
            )
            cv2.fillConvexPoly(cur_canvas, polygon, colors[limb_type])
            if not bool_fast_plot:
                canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

    return to_plot, canvas


def decode_pose(
    img_orig: np.ndarray, heatmaps: np.ndarray, pafs: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    param = {"thre1": 0.1, "thre2": 0.05, "thre3": 0.5}

    # Bottom-up approach:
    # Step 1: find all joints in the image (organized by joint type: [0]=nose, [1]=neck...)
    joint_list_per_joint_type = NMS(param, heatmaps, img_orig.shape[0] / float(heatmaps.shape[0]))
    # joint_list is an unravel'd version of joint_list_per_joint, where we add
    # a 5th column to indicate the joint_type (0=nose, 1=neck...)
    joint_list = np.array(
        [
            tuple(peak) + (joint_type,)
            for joint_type, joint_peaks in enumerate(joint_list_per_joint_type)
            for peak in joint_peaks
        ]
    )

    # Step 2: find which joints go together to form limbs (which wrists go with which elbows)
    paf_upsamp = cv2.resize(
        pafs,
        (img_orig.shape[1], img_orig.shape[0]),
        interpolation=cv2.INTER_CUBIC,
    )
    connected_limbs = find_connected_joints(param, paf_upsamp, joint_list_per_joint_type)

    # Step 3: associate limbs that belong to the same person
    person_to_joint_assoc = group_limbs_of_same_person(connected_limbs, joint_list)

    # Step 4: plot results
    to_plot, canvas = plot_pose(img_orig, joint_list, person_to_joint_assoc)

    return to_plot, canvas, joint_list, person_to_joint_assoc
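A self-contained check of two of the helpers above on toy data, assuming they are imported from this module (the import path mirrors the repo layout but is an assumption); the array size and threshold are arbitrary, chosen only for illustration:

# Toy sanity check for find_peaks() and compute_resized_coords().
import numpy as np

from src.inference.decode import compute_resized_coords, find_peaks  # assumed import path

heatmap = np.zeros((8, 8), dtype=np.float32)
heatmap[5, 2] = 0.9  # one bright blob at row 5, column 2

peaks = find_peaks({"thre1": 0.1}, heatmap)
print(peaks)  # [[2 5]] -> peaks are returned in [x, y] order, as documented above

# Map the low-resolution peak onto an 8x-upsampled grid (as NMS does with upsampFactor).
print(compute_resized_coords(peaks[0], resizeFactor=8.0))  # [19.5 43.5]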
pose/src/inference/pose_inference.py
ADDED
@@ -0,0 +1,132 @@
from typing import (
    Optional,
    Tuple,
    Union,
)

import cv2
import numpy as np
import torch

from loguru import logger

from src.models.networks import OpenPoseNet

from .base import PoseInferenceBase
from .decode import decode_pose


class PoseInference(PoseInferenceBase):
    def __init__(self, model_weight_path: str, device: Optional[str] = None):
        super().__init__()
        self.net = OpenPoseNet()
        if not device:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        net_weights = torch.load(model_weight_path, map_location=self.device)
        keys = list(net_weights.keys())

        weights_load = {}
        for i in range(len(keys)):
            weights_load[list(self.net.state_dict().keys())[i]] = net_weights[list(keys)[i]]

        state = self.net.state_dict()
        state.update(weights_load)
        self.net.load_state_dict(state)
        self.net.eval()

        logger.info(f"Load model successfully to device '{self.device}' for inference")

    def preprocess(
        self,
        img: Union[np.ndarray, str],
        size: Tuple[int] = (368, 368),
        color_mean: Tuple[float] = (0.485, 0.456, 0.406),
        color_std: Tuple[float] = (0.229, 0.224, 0.225),
    ) -> torch.Tensor:
        if isinstance(img, str):
            original_img = cv2.imread(img)
        elif isinstance(img, np.ndarray):
            original_img = img
        else:
            raise ValueError("'img' parameter must be of type string or numpy array")

        original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(original_img, size, interpolation=cv2.INTER_CUBIC)
        img = img.astype(np.float32) / 255.0

        preprocessed_img = img.copy()

        for i in range(3):
            preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - color_mean[i]
            preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / color_std[i]

        img = preprocessed_img.transpose((2, 0, 1)).astype(np.float32)
        img = torch.from_numpy(img)
        img = img.unsqueeze(0)

        return img, original_img

    def process(
        self,
        img: Union[np.ndarray, str],
        size: Tuple[int] = (368, 368),
        color_mean: Tuple[float] = (0.485, 0.456, 0.406),
        color_std: Tuple[float] = (0.229, 0.224, 0.225),
    ):
        preprocessed_img, original_img = self.preprocess(
            img=img, size=size, color_mean=color_mean, color_std=color_std
        )
        # Run model
        predicted_outputs, _ = self.net(preprocessed_img)

        shape = original_img.shape
        heatmaps = PoseInference._generate_heatmap(predicted_outputs, size, shape)
        pafs = PoseInference._generate_part_affinity_fields(predicted_outputs, size, shape)

        result_img = self.postprocess(original_img, heatmaps, pafs)
        result_img = cv2.cvtColor(result_img, cv2.COLOR_RGB2BGR)
        return result_img

    def postprocess(
        self,
        oriImg: np.ndarray,
        heatmaps: np.ndarray,
        pafs: np.ndarray,
    ) -> np.ndarray:
        _, result_img, _, _ = decode_pose(oriImg, heatmaps, pafs)
        return result_img

    @staticmethod
    def _generate_heatmap(
        predicted_outputs: torch.Tensor,
        size: Tuple[int],
        oriImg_shape: Tuple[int],
    ) -> np.ndarray:
        _heatmaps = predicted_outputs[1][0].detach().numpy().transpose(1, 2, 0)
        _heatmaps = cv2.resize(_heatmaps, size, interpolation=cv2.INTER_CUBIC)
        _heatmaps = cv2.resize(
            _heatmaps,
            (oriImg_shape[1], oriImg_shape[0]),
            interpolation=cv2.INTER_CUBIC,
        )
        logger.info("Generate heatmap ...")
        return _heatmaps

    @staticmethod
    def _generate_part_affinity_fields(
        predicted_outputs: torch.Tensor,
        size: Tuple[int],
        oriImg_shape: Tuple[int],
    ) -> np.ndarray:
        _pafs = predicted_outputs[0][0].detach().numpy().transpose(1, 2, 0)
        _pafs = cv2.resize(_pafs, size, interpolation=cv2.INTER_CUBIC)
        _pafs = cv2.resize(
            _pafs,
            (oriImg_shape[1], oriImg_shape[0]),
            interpolation=cv2.INTER_CUBIC,
        )
        logger.info("Generate part affinity fields ...")
        return _pafs
ADDED
@@ -0,0 +1,3 @@
|
version https://git-lfs.github.com/spec/v1
oid sha256:bb202d924c7c2b3b3943c879a34fdc539f0e29648764df920bc50d21681272fb
size 209282651