Théo Rousseaux commited on
Commit
9a8177e
1 Parent(s): d105aff

agent pose new version

Browse files
Modules/PoseEstimation/pose_agent.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  import sys
8
  import json
9
  sys.path.append(os.getcwd())
10
- from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average
11
 
12
  # If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
13
  llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
@@ -54,9 +54,11 @@ def compute_right_knee_angle_list(json_path: str) -> list[float]:
54
  for keypoints in keypoints_list:
55
  right_knee_angle = compute_right_knee_angle(keypoints['keypoints'])
56
  right_knee_angle_list.append(right_knee_angle)
57
- return moving_average(right_knee_angle_list, 10)
58
 
59
- @tool
 
 
 
60
  def check_knee_angle(json_path: str) -> bool:
61
  """
62
  Checks if the minimum knee angle is smaller than a threshold.
@@ -74,13 +76,29 @@ def check_knee_angle(json_path: str) -> bool:
74
  return True
75
  return False
76
 
77
- tools = [get_keypoints_from_keypoints, check_knee_angle]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  prompt = ChatPromptTemplate.from_messages(
80
  [
81
  (
82
  "system",
83
- "You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement.",
84
  ),
85
  ("placeholder", "{chat_history}"),
86
  ("human", "{input}"),
@@ -92,5 +110,5 @@ prompt = ChatPromptTemplate.from_messages(
92
  agent = create_tool_calling_agent(llm, tools, prompt)
93
 
94
  agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
95
- response = agent_executor.invoke({"input": f"Is my squat correct ? The json file is in tmp/keypoints.json."})
96
  print(response["output"])
 
7
  import sys
8
  import json
9
  sys.path.append(os.getcwd())
10
+ from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig
11
 
12
  # If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
13
  llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
 
54
  for keypoints in keypoints_list:
55
  right_knee_angle = compute_right_knee_angle(keypoints['keypoints'])
56
  right_knee_angle_list.append(right_knee_angle)
 
57
 
58
+ right_knee_angle_list = moving_average(right_knee_angle_list, 10)
59
+ save_knee_angle_fig(right_knee_angle_list)
60
+ return right_knee_angle_list
61
+
62
  def check_knee_angle(json_path: str) -> bool:
63
  """
64
  Checks if the minimum knee angle is smaller than a threshold.
 
76
  return True
77
  return False
78
 
79
+ @tool
80
+ def check_squat(video_path: str) -> bool:
81
+ """
82
+ Checks if the squat is correct.
83
+ This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
84
+
85
+
86
+ Args:
87
+ video_path (str): path to the video file
88
+
89
+ Returns:
90
+ is_correct (bool): True if the squat is correct, False otherwise
91
+ """
92
+ json_path = get_keypoints_from_keypoints(video_path)
93
+ return check_knee_angle(json_path)
94
+
95
+ tools = [check_squat]
96
 
97
  prompt = ChatPromptTemplate.from_messages(
98
  [
99
  (
100
  "system",
101
+ "You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement. Also explain your response",
102
  ),
103
  ("placeholder", "{chat_history}"),
104
  ("human", "{input}"),
 
110
  agent = create_tool_calling_agent(llm, tools, prompt)
111
 
112
  agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
113
+ response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
114
  print(response["output"])
Modules/PoseEstimation/pose_estimation.ipynb CHANGED
@@ -1062,7 +1062,12 @@
1062
  "source": [
1063
  "import matplotlib.pyplot as plt\n",
1064
  "\n",
1065
- "plt.plot(angles)"
 
 
 
 
 
1066
  ]
1067
  },
1068
  {
 
1062
  "source": [
1063
  "import matplotlib.pyplot as plt\n",
1064
  "\n",
1065
+ "os.makedirs('fig', exist_ok=True)\n",
1066
+ "plt.plot(angles)\n",
1067
+ "plt.xlabel('Frame')\n",
1068
+ "plt.ylabel('Knee Angle')\n",
1069
+ "plt.title('Evolution of the knee angle')\n",
1070
+ "plt.savefig('fig/knee_angle.png')"
1071
  ]
1072
  },
1073
  {
Modules/PoseEstimation/pose_estimator.py CHANGED
@@ -2,6 +2,7 @@ from ultralytics import YOLO
2
  import numpy as np
3
  import os
4
  import json
 
5
 
6
  id_joints_dict = {0: 'nose',
7
  1: 'left_eye',
@@ -99,4 +100,12 @@ def moving_average(data, window_size):
99
  for i in range(len(data) - window_size + 1):
100
  avg.append(sum(data[i:i + window_size]) / window_size)
101
 
102
- return avg
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import os
4
  import json
5
+ import matplotlib.pyplot as plt
6
 
7
  id_joints_dict = {0: 'nose',
8
  1: 'left_eye',
 
100
  for i in range(len(data) - window_size + 1):
101
  avg.append(sum(data[i:i + window_size]) / window_size)
102
 
103
+ return avg
104
+
105
+ def save_knee_angle_fig(angles):
106
+ os.makedirs('fig', exist_ok=True)
107
+ plt.plot(angles)
108
+ plt.xlabel('Frame')
109
+ plt.ylabel('Knee Angle')
110
+ plt.title('Evolution of the knee angle')
111
+ plt.savefig('fig/knee_angle.png')