import json
import os
import sys

import torch
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_mistralai.chat_models import ChatMistralAI

# Make the current working directory importable so the local ``Modules``
# package resolves. NOTE(review): assumes the script is launched from the
# project root — confirm.
sys.path.append(os.getcwd())

from Modules.PoseEstimation.pose_estimator import (  # noqa: E402
    compute_right_knee_angle,
    model,
    moving_average,
    save_knee_angle_fig,
)

# When ``api_key`` is not passed, ChatMistralAI reads the ``MISTRAL_API_KEY``
# environment variable. Never hard-code secrets in source: any key that was
# previously committed here must be treated as leaked and revoked.
llm = ChatMistralAI(model='mistral-large-latest')

device = 'cuda' if torch.cuda.is_available() else 'cpu'


@tool
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Extracts keypoints from a video file.

    Args:
        video_path (str): path to the video file

    Returns:
        file_path (str): path to the JSON file containing the keypoints
    """
    save_folder = 'tmp'
    os.makedirs(save_folder, exist_ok=True)

    # stream=True makes the model yield one result per frame instead of
    # accumulating every frame's result in memory before returning.
    results = model(video_path, save=True, show_conf=False, show_boxes=False,
                    device=device, stream=True)

    keypoints = []
    for i, frame in enumerate(results):
        frame_keypoints = frame.keypoints
        xy = frame_keypoints.xy if frame_keypoints is not None else []
        # Only the first detected person is kept. A frame without any
        # detection is stored with an empty list, so that 'frame' still
        # matches the video frame index.
        first_person = xy[0].tolist() if len(xy) > 0 else []
        keypoints.append({'frame': i, 'keypoints': first_person})

    file_path = os.path.join(save_folder, 'keypoints.json')
    with open(file_path, 'w') as f:
        json.dump(keypoints, f)
    return file_path


def compute_right_knee_angle_list(json_path: str, window: int = 10) -> list[float]:
    """
    Computes the smoothed right-knee angle for every frame of a keypoints file.

    Frames whose keypoint list is empty (no person detected) are skipped.
    The raw angles are smoothed with ``moving_average`` and the resulting
    curve is saved as a figure with ``save_knee_angle_fig``.

    Args:
        json_path (str): path to the JSON file written by
            get_keypoints_from_keypoints
        window (int): moving-average window size (default 10)

    Returns:
        right_knee_angle_list (list[float]): smoothed knee angles as returned
            by ``moving_average``, or an empty list when no frame contains
            keypoints (in that case no figure is saved)
    """
    with open(json_path) as f:
        frames = json.load(f)

    raw_angles = [
        compute_right_knee_angle(frame['keypoints'])
        for frame in frames
        if frame['keypoints']
    ]
    if not raw_angles:
        return []

    right_knee_angle_list = moving_average(raw_angles, window)
    save_knee_angle_fig(right_knee_angle_list)
    return right_knee_angle_list


def check_knee_angle(json_path: str, threshold: float = 90.0) -> bool:
    """
    Checks whether the minimum knee angle is smaller than a threshold.

    If the knee angle goes below 90 degrees (the default threshold), the squat
    is considered correct because the user is going deep enough.

    Args:
        json_path (str): path to the JSON file containing the keypoints
        threshold (float): angle in degrees that must be undercut (default 90)

    Returns:
        is_correct (bool): True if any smoothed knee angle is smaller than
            ``threshold``, False otherwise (including when no angles exist)
    """
    angles_list = compute_right_knee_angle_list(json_path)
    return any(angle < threshold for angle in angles_list)


@tool
def check_squat(file_name: str) -> str:
    """
    Checks if the squat is correct, i.e. whether the user goes deep enough.

    Extracts keypoints from the uploaded video and checks whether the right
    knee angle drops below 90 degrees.

    Args:
        file_name (str): name of the video file inside the 'uploaded' folder

    Returns:
        str: a message saying whether the squat is correct and why
    """
    upload_dir = os.path.realpath('uploaded')
    video_path = os.path.realpath(os.path.join(upload_dir, file_name))
    # file_name comes from the user/LLM: reject absolute paths and '..'
    # components that would resolve outside the upload folder.
    try:
        inside_upload_dir = os.path.commonpath([upload_dir, video_path]) == upload_dir
    except ValueError:  # e.g. different drives on Windows
        inside_upload_dir = False
    if not inside_upload_dir:
        return "Invalid file name: the video must be inside the uploaded folder."
    if not os.path.isfile(video_path):
        return "The video file does not exist."

    # Tools are run through .invoke(); calling a tool object directly is deprecated.
    json_path = get_keypoints_from_keypoints.invoke({'video_path': video_path})
    if check_knee_angle(json_path):
        return "The squat is correct because your knee angle is smaller than 90 degrees."
    return "The squat is incorrect because your knee angle never goes below 90 degrees."


tools = [check_squat]

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the check_squat tool "
            "if the user wants to check their movement. Also explain your "
            "response, giving advice.",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)

# Construct the Tools agent
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)