Spaces:
Sleeping
Sleeping
Théo Rousseaux
commited on
Commit
•
9a8177e
1
Parent(s):
d105aff
agent pose new version
Browse files
Modules/PoseEstimation/pose_agent.py
CHANGED
@@ -7,7 +7,7 @@ import os
|
|
7 |
import sys
|
8 |
import json
|
9 |
sys.path.append(os.getcwd())
|
10 |
-
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average
|
11 |
|
12 |
# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
|
13 |
llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
|
@@ -54,9 +54,11 @@ def compute_right_knee_angle_list(json_path: str) -> list[float]:
|
|
54 |
for keypoints in keypoints_list:
|
55 |
right_knee_angle = compute_right_knee_angle(keypoints['keypoints'])
|
56 |
right_knee_angle_list.append(right_knee_angle)
|
57 |
-
return moving_average(right_knee_angle_list, 10)
|
58 |
|
59 |
-
|
|
|
|
|
|
|
60 |
def check_knee_angle(json_path: str) -> bool:
|
61 |
"""
|
62 |
Checks if the minimum knee angle is smaller than a threshold.
|
@@ -74,13 +76,29 @@ def check_knee_angle(json_path: str) -> bool:
|
|
74 |
return True
|
75 |
return False
|
76 |
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
prompt = ChatPromptTemplate.from_messages(
|
80 |
[
|
81 |
(
|
82 |
"system",
|
83 |
-
"You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement.",
|
84 |
),
|
85 |
("placeholder", "{chat_history}"),
|
86 |
("human", "{input}"),
|
@@ -92,5 +110,5 @@ prompt = ChatPromptTemplate.from_messages(
|
|
92 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
93 |
|
94 |
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
95 |
-
response = agent_executor.invoke({"input": f"Is my squat correct ? The
|
96 |
print(response["output"])
|
|
|
7 |
import sys
|
8 |
import json
|
9 |
sys.path.append(os.getcwd())
|
10 |
+
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig
|
11 |
|
12 |
# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
|
13 |
llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
|
|
|
54 |
for keypoints in keypoints_list:
|
55 |
right_knee_angle = compute_right_knee_angle(keypoints['keypoints'])
|
56 |
right_knee_angle_list.append(right_knee_angle)
|
|
|
57 |
|
58 |
+
right_knee_angle_list = moving_average(right_knee_angle_list, 10)
|
59 |
+
save_knee_angle_fig(right_knee_angle_list)
|
60 |
+
return right_knee_angle_list
|
61 |
+
|
62 |
def check_knee_angle(json_path: str) -> bool:
|
63 |
"""
|
64 |
Checks if the minimum knee angle is smaller than a threshold.
|
|
|
76 |
return True
|
77 |
return False
|
78 |
|
79 |
+
@tool
|
80 |
+
def check_squat(video_path: str) -> bool:
|
81 |
+
"""
|
82 |
+
Checks if the squat is correct.
|
83 |
+
This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
|
84 |
+
|
85 |
+
|
86 |
+
Args:
|
87 |
+
video_path (str): path to the video file
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
is_correct (bool): True if the squat is correct, False otherwise
|
91 |
+
"""
|
92 |
+
json_path = get_keypoints_from_keypoints(video_path)
|
93 |
+
return check_knee_angle(json_path)
|
94 |
+
|
95 |
+
tools = [check_squat]
|
96 |
|
97 |
prompt = ChatPromptTemplate.from_messages(
|
98 |
[
|
99 |
(
|
100 |
"system",
|
101 |
+
"You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement. Also explain your response",
|
102 |
),
|
103 |
("placeholder", "{chat_history}"),
|
104 |
("human", "{input}"),
|
|
|
110 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
111 |
|
112 |
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
113 |
+
response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
|
114 |
print(response["output"])
|
Modules/PoseEstimation/pose_estimation.ipynb
CHANGED
@@ -1062,7 +1062,12 @@
|
|
1062 |
"source": [
|
1063 |
"import matplotlib.pyplot as plt\n",
|
1064 |
"\n",
|
1065 |
-
"
|
|
|
|
|
|
|
|
|
|
|
1066 |
]
|
1067 |
},
|
1068 |
{
|
|
|
1062 |
"source": [
|
1063 |
"import matplotlib.pyplot as plt\n",
|
1064 |
"\n",
|
1065 |
+
"os.makedirs('fig', exist_ok=True)\n",
|
1066 |
+
"plt.plot(angles)\n",
|
1067 |
+
"plt.xlabel('Frame')\n",
|
1068 |
+
"plt.ylabel('Knee Angle')\n",
|
1069 |
+
"plt.title('Evolution of the knee angle')\n",
|
1070 |
+
"plt.savefig('fig/knee_angle.png')"
|
1071 |
]
|
1072 |
},
|
1073 |
{
|
Modules/PoseEstimation/pose_estimator.py
CHANGED
@@ -2,6 +2,7 @@ from ultralytics import YOLO
|
|
2 |
import numpy as np
|
3 |
import os
|
4 |
import json
|
|
|
5 |
|
6 |
id_joints_dict = {0: 'nose',
|
7 |
1: 'left_eye',
|
@@ -99,4 +100,12 @@ def moving_average(data, window_size):
|
|
99 |
for i in range(len(data) - window_size + 1):
|
100 |
avg.append(sum(data[i:i + window_size]) / window_size)
|
101 |
|
102 |
-
return avg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import numpy as np
|
3 |
import os
|
4 |
import json
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
|
7 |
id_joints_dict = {0: 'nose',
|
8 |
1: 'left_eye',
|
|
|
100 |
for i in range(len(data) - window_size + 1):
|
101 |
avg.append(sum(data[i:i + window_size]) / window_size)
|
102 |
|
103 |
+
return avg
|
104 |
+
|
105 |
+
def save_knee_angle_fig(angles):
|
106 |
+
os.makedirs('fig', exist_ok=True)
|
107 |
+
plt.plot(angles)
|
108 |
+
plt.xlabel('Frame')
|
109 |
+
plt.ylabel('Knee Angle')
|
110 |
+
plt.title('Evolution of the knee angle')
|
111 |
+
plt.savefig('fig/knee_angle.png')
|