# Spaces: Sleeping
from langchain.tools import tool | |
from langchain.agents import AgentExecutor, create_tool_calling_agent | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.messages import HumanMessage | |
from langchain_mistralai.chat_models import ChatMistralAI | |
import torch | |
import os | |
import sys | |
import json | |
sys.path.append(os.getcwd()) | |
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig | |
# When api_key is not passed, ChatMistralAI reads the key from the
# `MISTRAL_API_KEY` environment variable. Credentials must never be committed
# to source control; set the variable in the deployment environment instead.
# NOTE(review): a key was previously hard-coded on this line and should be
# treated as leaked and revoked.
llm = ChatMistralAI(model='mistral-large-latest')
# Run the pose model on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Runs the pose model on a video and writes the per-frame keypoints to JSON.

    Args:
        video_path (str): path to the input video file
    Returns:
        file_path (str): path to the written JSON file, always
            ``tmp/keypoints.json``. It contains a list with one
            ``{"frame": <index>, "keypoints": <list>}`` object per result
            produced by the model.
    """
    output_dir = 'tmp'
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): `model` comes from Modules.PoseEstimation; presumably an
    # Ultralytics pose model, where save=True also writes annotated output.
    predictions = model(video_path, save=True, show_conf=False, show_boxes=False, device=device)
    # Only the first entry of `keypoints.xy` is kept per frame (presumably the
    # first detected person — confirm against the model's output layout).
    records = [
        {'frame': index, 'keypoints': result.keypoints.xy[0].tolist()}
        for index, result in enumerate(predictions)
    ]
    output_path = os.path.join(output_dir, 'keypoints.json')
    with open(output_path, 'w') as handle:
        json.dump(records, handle)
    return output_path
def compute_right_knee_angle_list(json_path: str) -> list[float]:
    """
    Computes the smoothed right-knee angle for every frame in a keypoints JSON file.

    Each frame's raw angle comes from ``compute_right_knee_angle``. The series
    is smoothed with ``moving_average`` using a window of 10. A figure of the
    smoothed series is saved via ``save_knee_angle_fig`` as a side effect.

    Args:
        json_path (str): path to a JSON file holding a list of per-frame
            objects, each with a ``"keypoints"`` entry (the format written
            by ``get_keypoints_from_keypoints``)
    Returns:
        right_knee_angle_list (list[float]): the smoothed knee angles, exactly
            as returned by ``moving_average``. NOTE(review): the concrete
            sequence type depends on that helper — confirm.
    """
    # Context manager closes the file promptly instead of leaving it to the GC.
    with open(json_path) as f:
        keypoints_list = json.load(f)
    raw_angles = [compute_right_knee_angle(entry['keypoints']) for entry in keypoints_list]
    right_knee_angle_list = moving_average(raw_angles, 10)
    save_knee_angle_fig(right_knee_angle_list)
    return right_knee_angle_list
def check_knee_angle(json_path: str, threshold: float = 90) -> bool:
    """
    Checks if the minimum knee angle is smaller than a threshold.

    The squat is considered correct (deep enough) when the smoothed right-knee
    angle drops strictly below ``threshold`` in at least one frame.

    Args:
        json_path (str): path to the JSON file containing the keypoints
        threshold (float): angle the knee must go below; defaults to 90
            degrees, the value this check has always used
    Returns:
        is_correct (bool): True if any angle is smaller than ``threshold``,
            False otherwise (including when no angles were computed)
    """
    angles_list = compute_right_knee_angle_list(json_path)
    # any() short-circuits on the first qualifying frame, like an early return.
    return any(angle < threshold for angle in angles_list)
def check_squat(file_name: str) -> str:
    """
    Checks if the squat in an uploaded video is correct.

    Looks up the video inside the ``uploaded`` directory and extracts pose
    keypoints from it. It then checks whether the user goes deep enough, i.e.
    whether the knee angle gets smaller than 90 degrees at some point.

    Args:
        file_name (str): name of the video file, relative to the ``uploaded`` directory
    Returns:
        message (str): a sentence saying whether the squat is correct, or that
            the video file does not exist
    """
    upload_dir = os.path.realpath('uploaded')
    video_path = os.path.join('uploaded', file_name)
    # file_name is chosen by the LLM from user chat input, so reject anything
    # that resolves outside the upload directory. That covers "../..."
    # traversal and absolute paths, which make os.path.join drop the base.
    # isfile (not exists) also rejects directories, e.g. an empty file_name.
    resolved_path = os.path.realpath(video_path)
    if not resolved_path.startswith(upload_dir + os.sep) or not os.path.isfile(resolved_path):
        return "The video file does not exist."
    json_path = get_keypoints_from_keypoints(video_path)
    if check_knee_angle(json_path):
        return "The squat is correct because your knee angle is smaller than 90 degrees."
    return "The squat is incorrect because your knee angle is greater than 90 degrees."
# Expose check_squat to the agent as a LangChain tool. tool(...) wraps the plain
# function in a StructuredTool: its name and argument schema come from the
# signature, and its description comes from the docstring. AgentExecutor is
# typed to take BaseTool instances and looks them up by `.name`.
tools = [tool(check_squat)]
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response, giving advices",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)
# Construct the Tools agent
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)