Théo Rousseaux
UI
261f2d7
from langchain.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI
import torch
import os
import sys
import json
sys.path.append(os.getcwd())
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig
# The Mistral API key is NOT stored in source: when `api_key` is not passed,
# ChatMistralAI reads it from the `MISTRAL_API_KEY` environment variable.
llm = ChatMistralAI(model='mistral-large-latest')
# Run pose estimation on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@tool
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Extracts pose keypoints from every frame of a video file and saves them as JSON.
    Args:
        video_path (str): path to the video file
    Returns:
        file_path (str): path to the JSON file containing the keypoints
            (a list of {'frame': index, 'keypoints': [[x, y], ...]} entries)
    """
    save_folder = 'tmp'
    os.makedirs(save_folder, exist_ok=True)
    keypoints = []
    # stream=True yields results frame by frame instead of accumulating the
    # results of the whole video in memory (important for long videos).
    results = model(
        video_path,
        save=True,
        show_conf=False,
        show_boxes=False,
        device=device,
        stream=True,
    )
    for i, result in enumerate(results):
        frame_keypoints = result.keypoints
        # Guard against frames without any detection: indexing the first
        # detection would raise there. Such frames are skipped; the 'frame'
        # field keeps the original index of the frames that are kept.
        if frame_keypoints is None or len(frame_keypoints.xy) == 0:
            continue
        keypoints.append({
            'frame': i,
            # xy[0]: keypoint coordinates of the first detection only.
            'keypoints': frame_keypoints.xy[0].tolist(),
        })
    file_path = os.path.join(save_folder, 'keypoints.json')
    with open(file_path, 'w') as f:
        json.dump(keypoints, f)
    return file_path
def compute_right_knee_angle_list(json_path: str, window: int = 10) -> list[float]:
    """
    Computes the smoothed right-knee angle series from a keypoints JSON file.

    Each entry of the JSON list must have a 'keypoints' field (e.g. the file
    written by get_keypoints_from_keypoints). The per-frame angles are smoothed
    with a moving average, and a figure of the smoothed series is saved via
    save_knee_angle_fig.

    Args:
        json_path (str): path to the JSON file containing the keypoints
        window (int): moving-average window size, in frames (default 10)
    Returns:
        right_knee_angle_list (list[float]): smoothed knee angles, as returned by
            moving_average (its exact length/type depends on that helper)
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(json_path) as f:
        keypoints_list = json.load(f)
    right_knee_angle_list = [
        compute_right_knee_angle(entry['keypoints']) for entry in keypoints_list
    ]
    # Smooth frame-to-frame jitter of the pose estimation.
    right_knee_angle_list = moving_average(right_knee_angle_list, window)
    save_knee_angle_fig(right_knee_angle_list)
    return right_knee_angle_list
def check_knee_angle(json_path: str, threshold: float = 90) -> bool:
    """
    Checks whether the minimum knee angle is smaller than a threshold.

    If the knee angle goes below the threshold (90 degrees by default) in at
    least one frame, the squat is considered correct, because it means the
    user is going deep enough.

    Args:
        json_path (str): path to the JSON file containing the keypoints
        threshold (float): angle, in degrees, the knee must go below (default 90)
    Returns:
        is_correct (bool): True if the minimum knee angle is smaller than the
            threshold, False otherwise (including when no angle was computed)
    """
    angles_list = compute_right_knee_angle_list(json_path)
    # "Some angle is below the threshold" is equivalent to "min < threshold",
    # and any() stops at the first match.
    return any(angle < threshold for angle in angles_list)
@tool
def check_squat(file_name: str) -> str:
    """
    Checks if the squat in an uploaded video is correct.
    It checks if the user is going deep enough, i.e. if the right knee angle goes below 90 degrees.
    Args:
        file_name (str): name of the uploaded video file, relative to the 'uploaded' folder
    Returns:
        result (str): a sentence saying whether the squat is correct and why, or that the video file does not exist
    """
    upload_dir = os.path.realpath('uploaded')
    video_path = os.path.join('uploaded', file_name)
    # file_name is chosen by the agent (ultimately by the user), so refuse any
    # path that resolves outside the upload folder ('../x', absolute paths, symlinks).
    try:
        inside_upload_dir = os.path.commonpath(
            [upload_dir, os.path.realpath(video_path)]
        ) == upload_dir
    except ValueError:  # e.g. paths on different drives on Windows
        inside_upload_dir = False
    if not inside_upload_dir or not os.path.exists(video_path):
        return "The video file does not exist."
    # get_keypoints_from_keypoints is a LangChain tool object; invoke() is the
    # supported way to run it (calling a tool object directly is deprecated).
    json_path = get_keypoints_from_keypoints.invoke(video_path)
    if check_knee_angle(json_path):
        return "The squat is correct because your knee angle is smaller than 90 degrees."
    return "The squat is incorrect because your knee angle never goes below 90 degrees."
tools = [check_squat]
prompt = ChatPromptTemplate.from_messages([
    # System instructions: steer the model toward the check_squat tool and
    # ask it to explain its answer.
    ("system",
     "You are a helpful assistant. Make sure to use the check_squat tool "
     "if the user wants to check his movement. "
     "Also explain your response, giving advices"),
    # Slot for the previous chat messages.
    ("placeholder", "{chat_history}"),
    # The current user message.
    ("human", "{input}"),
    # Slot for the agent's intermediate tool calls and their results.
    ("placeholder", "{agent_scratchpad}"),
])
# Build the tool-calling agent and the executor that runs it with the same
# tool list (verbose=True prints each step).
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)