Spidartist committed on
Commit a030099
1 Parent(s): 237ec65
Files changed (3)
  1. app.py +120 -0
  2. model.py +274 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,120 @@
+ from __future__ import annotations
+
+ import pathlib
+ import tarfile
+
+ import gradio as gr
+
+ from model import AppModel
+
+ DESCRIPTION = '''# ViTPose
+ This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](https://github.com/ViTAE-Transformer/ViTPose).
+ Related app: [https://huggingface.co/spaces/Gradio-Blocks/ViTPose](https://huggingface.co/spaces/Gradio-Blocks/ViTPose)
+ '''
+
+
+ def set_example_video(example: list) -> dict:
+     return gr.Video.update(value=example[0])
+
+
+ def extract_tar() -> None:
+     if pathlib.Path('mmdet_configs/configs').exists():
+         return
+     with tarfile.open('mmdet_configs/configs.tar') as f:
+         f.extractall('mmdet_configs')
+
+
+ extract_tar()
+
+ model = AppModel()
+
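+ # Build the Gradio Blocks UI: the left column takes the input video and the
+ # model/threshold settings, the right column shows the rendered result and
+ # the redraw controls.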
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             input_video = gr.Video(label='Input Video',
+                                    format='mp4',
+                                    elem_id='input_video')
+             detector_name = gr.Dropdown(list(model.det_model.MODEL_DICT.keys()),
+                                         value=model.det_model.model_name,
+                                         label='Detector')
+             pose_model_name = gr.Dropdown(list(model.pose_model.MODEL_DICT.keys()),
+                                           value=model.pose_model.model_name,
+                                           label='Pose Model')
+             det_score_threshold = gr.Slider(0,
+                                             1,
+                                             step=0.05,
+                                             value=0.5,
+                                             label='Box Score Threshold')
+             max_num_frames = gr.Slider(1,
+                                        300,
+                                        step=1,
+                                        value=60,
+                                        label='Maximum Number of Frames')
+             predict_button = gr.Button(value='Predict')
+             pose_preds = gr.Variable()
+
+             paths = sorted(pathlib.Path('videos').rglob('*.mp4'))
+             example_videos = gr.Dataset(components=[input_video],
+                                         samples=[[path.as_posix()]
+                                                  for path in paths])
+
+         with gr.Column():
+             result = gr.Video(label='Result', format='mp4', elem_id='result')
+             vis_kpt_score_threshold = gr.Slider(
+                 0,
+                 1,
+                 step=0.05,
+                 value=0.3,
+                 label='Visualization Score Threshold')
+             vis_dot_radius = gr.Slider(1,
+                                        10,
+                                        step=1,
+                                        value=4,
+                                        label='Dot Radius')
+             vis_line_thickness = gr.Slider(1,
+                                            10,
+                                            step=1,
+                                            value=2,
+                                            label='Line Thickness')
+             redraw_button = gr.Button(value='Redraw')
+
+     detector_name.change(fn=model.det_model.set_model,
+                          inputs=detector_name,
+                          outputs=None)
+     pose_model_name.change(fn=model.pose_model.set_model,
+                            inputs=pose_model_name,
+                            outputs=None)
+     predict_button.click(fn=model.run,
+                          inputs=[
+                              input_video,
+                              detector_name,
+                              pose_model_name,
+                              det_score_threshold,
+                              max_num_frames,
+                              vis_kpt_score_threshold,
+                              vis_dot_radius,
+                              vis_line_thickness,
+                          ],
+                          outputs=[
+                              result,
+                              pose_preds,
+                          ])
+     redraw_button.click(fn=model.visualize_pose_results,
+                         inputs=[
+                             input_video,
+                             pose_preds,
+                             vis_kpt_score_threshold,
+                             vis_dot_radius,
+                             vis_line_thickness,
+                         ],
+                         outputs=result)
+
+     example_videos.click(fn=set_example_video,
+                          inputs=example_videos,
+                          outputs=input_video)
+
+ demo.queue().launch(show_api=False)
model.py ADDED
@@ -0,0 +1,274 @@
+ from __future__ import annotations
+
+ import os
+ import shlex
+ import subprocess
+ import sys
+ import tempfile
+
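+ # On Hugging Face Spaces, swap in the pinned mmcv-full build and the headless
+ # OpenCV wheel before mmdet/mmpose are imported below.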
+ if os.getenv('SYSTEM') == 'spaces':
+     import mim
+
+     mim.uninstall('mmcv-full', confirm_yes=True)
+     mim.install('mmcv-full==1.5.0', is_yes=True)
+
+     subprocess.call(shlex.split('pip uninstall -y opencv-python'))
+     subprocess.call(shlex.split('pip uninstall -y opencv-python-headless'))
+     subprocess.call(
+         shlex.split('pip install opencv-python-headless==4.5.5.64'))
+
+ import cv2
+ import huggingface_hub
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ sys.path.insert(0, 'ViTPose/')
+
+ from mmdet.apis import inference_detector, init_detector
+ from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
+                          process_mmdet_results, vis_pose_result)
+
+ HF_TOKEN = os.getenv('HF_TOKEN')
+
+
+ class DetModel:
+     MODEL_DICT = {
+         'YOLOX-tiny': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
+         },
+         'YOLOX-s': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
+         },
+         'YOLOX-l': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
+         },
+         'YOLOX-x': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
+         },
+     }
+
+     def __init__(self):
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         self._load_all_models_once()
+         self.model_name = 'YOLOX-l'
+         self.model = self._load_model(self.model_name)
+
+     def _load_all_models_once(self) -> None:
+         for name in self.MODEL_DICT:
+             self._load_model(name)
+
+     def _load_model(self, name: str) -> nn.Module:
+         dic = self.MODEL_DICT[name]
+         return init_detector(dic['config'], dic['model'], device=self.device)
+
+     def set_model(self, name: str) -> None:
+         if name == self.model_name:
+             return
+         self.model_name = name
+         self.model = self._load_model(name)
+
+     def detect_and_visualize(
+             self, image: np.ndarray,
+             score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
+         out = self.detect(image)
+         vis = self.visualize_detection_results(image, out, score_threshold)
+         return out, vis
+
+     def detect(self, image: np.ndarray) -> list[np.ndarray]:
+         image = image[:, :, ::-1]  # RGB -> BGR
+         out = inference_detector(self.model, image)
+         return out
+
+     def visualize_detection_results(
+             self,
+             image: np.ndarray,
+             detection_results: list[np.ndarray],
+             score_threshold: float = 0.3) -> np.ndarray:
+         person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79
+
+         image = image[:, :, ::-1]  # RGB -> BGR
+         vis = self.model.show_result(image,
+                                      person_det,
+                                      score_thr=score_threshold,
+                                      bbox_color=None,
+                                      text_color=(200, 200, 200),
+                                      mask_color=None)
+         return vis[:, :, ::-1]  # BGR -> RGB
+
+
+ class PoseModel:
+     MODEL_DICT = {
+         'ViTPose-B (single-task train)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+             'model': 'models/vitpose-b.pth',
+         },
+         'ViTPose-L (single-task train)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+             'model': 'models/vitpose-l.pth',
+         },
+         'ViTPose-B (multi-task train, COCO)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+             'model': 'models/vitpose-b-multi-coco.pth',
+         },
+         'ViTPose-L (multi-task train, COCO)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+             'model': 'models/vitpose-l-multi-coco.pth',
+         },
+     }
+
+     def __init__(self):
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         self.model_name = 'ViTPose-B (multi-task train, COCO)'
+         self.model = self._load_model(self.model_name)
+
+     def _load_all_models_once(self) -> None:
+         for name in self.MODEL_DICT:
+             self._load_model(name)
+
+     def _load_model(self, name: str) -> nn.Module:
+         dic = self.MODEL_DICT[name]
+         ckpt_path = huggingface_hub.hf_hub_download('hysts/ViTPose',
+                                                     dic['model'],
+                                                     use_auth_token=HF_TOKEN)
+         model = init_pose_model(dic['config'], ckpt_path, device=self.device)
+         return model
+
+     def set_model(self, name: str) -> None:
+         if name == self.model_name:
+             return
+         self.model_name = name
+         self.model = self._load_model(name)
+
+     def predict_pose_and_visualize(
+         self,
+         image: np.ndarray,
+         det_results: list[np.ndarray],
+         box_score_threshold: float,
+         kpt_score_threshold: float,
+         vis_dot_radius: int,
+         vis_line_thickness: int,
+     ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
+         out = self.predict_pose(image, det_results, box_score_threshold)
+         vis = self.visualize_pose_results(image, out, kpt_score_threshold,
+                                           vis_dot_radius, vis_line_thickness)
+         return out, vis
+
+     def predict_pose(
+             self,
+             image: np.ndarray,
+             det_results: list[np.ndarray],
+             box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
+         image = image[:, :, ::-1]  # RGB -> BGR
+         person_results = process_mmdet_results(det_results, 1)
+         out, _ = inference_top_down_pose_model(self.model,
+                                                image,
+                                                person_results=person_results,
+                                                bbox_thr=box_score_threshold,
+                                                format='xyxy')
+         return out
+
+     def visualize_pose_results(self,
+                                image: np.ndarray,
+                                pose_results: list[dict[str, np.ndarray]],
+                                kpt_score_threshold: float = 0.3,
+                                vis_dot_radius: int = 4,
+                                vis_line_thickness: int = 1) -> np.ndarray:
+         image = image[:, :, ::-1]  # RGB -> BGR
+         vis = vis_pose_result(self.model,
+                               image,
+                               pose_results,
+                               kpt_score_thr=kpt_score_threshold,
+                               radius=vis_dot_radius,
+                               thickness=vis_line_thickness)
+         return vis[:, :, ::-1]  # BGR -> RGB
+
+
+ class AppModel:
+     def __init__(self):
+         self.det_model = DetModel()
+         self.pose_model = PoseModel()
+
+     def run(
+         self, video_path: str, det_model_name: str, pose_model_name: str,
+         box_score_threshold: float, max_num_frames: int,
+         kpt_score_threshold: float, vis_dot_radius: int,
+         vis_line_thickness: int
+     ) -> tuple[str, list[list[dict[str, np.ndarray]]]]:
+         if video_path is None:
+             return
+         self.det_model.set_model(det_model_name)
+         self.pose_model.set_model(pose_model_name)
+
+         cap = cv2.VideoCapture(video_path)
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         fps = cap.get(cv2.CAP_PROP_FPS)
+
+         preds_all = []
+
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+         writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
+         for _ in range(max_num_frames):
+             ok, frame = cap.read()
+             if not ok:
+                 break
+             rgb_frame = frame[:, :, ::-1]
+             det_preds = self.det_model.detect(rgb_frame)
+             preds, vis = self.pose_model.predict_pose_and_visualize(
+                 rgb_frame, det_preds, box_score_threshold, kpt_score_threshold,
+                 vis_dot_radius, vis_line_thickness)
+             preds_all.append(preds)
+             writer.write(vis[:, :, ::-1])
+         cap.release()
+         writer.release()
+
+         return out_file.name, preds_all
+
+     def visualize_pose_results(self, video_path: str,
+                                pose_preds_all: list[list[dict[str,
+                                                               np.ndarray]]],
+                                kpt_score_threshold: float, vis_dot_radius: int,
+                                vis_line_thickness: int) -> str:
+         if video_path is None or pose_preds_all is None:
+             return
+         cap = cv2.VideoCapture(video_path)
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         fps = cap.get(cv2.CAP_PROP_FPS)
+
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+         writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
+         for pose_preds in pose_preds_all:
+             ok, frame = cap.read()
+             if not ok:
+                 break
+             rgb_frame = frame[:, :, ::-1]
+             vis = self.pose_model.visualize_pose_results(
+                 rgb_frame, pose_preds, kpt_score_threshold, vis_dot_radius,
+                 vis_line_thickness)
+             writer.write(vis[:, :, ::-1])
+         cap.release()
+         writer.release()
+
+         return out_file.name
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ mmcv-full==1.5.0
+ mmdet==2.24.1
+ mmpose==0.25.1
+ numpy==1.23.5
+ opencv-python-headless==4.5.5.64
+ openmim==0.1.5
+ timm==0.5.4
+ torch==1.11.0
+ torchvision==0.12.0
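For reference, a minimal sketch of driving the pipeline from model.py without the Gradio UI, assuming the ViTPose checkout, the unpacked mmdet_configs, and the checkpoints are available; the input video path below is hypothetical, while the model names are keys from the MODEL_DICT tables above:

from model import AppModel

# Sketch only: calls AppModel.run() exactly as the Predict button does.
model = AppModel()  # loads the YOLOX detectors and the default ViTPose model
result_path, pose_preds = model.run(
    'videos/sample.mp4',                    # hypothetical input video path
    'YOLOX-l',                              # detector key in DetModel.MODEL_DICT
    'ViTPose-B (multi-task train, COCO)',   # key in PoseModel.MODEL_DICT
    0.5,                                    # box score threshold
    60,                                     # maximum number of frames
    0.3,                                    # keypoint score threshold for drawing
    4,                                      # dot radius
    2,                                      # line thickness
)
print(result_path)  # temporary .mp4 with the rendered poses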