|
from langchain.tools import tool |
|
from langchain.agents import AgentExecutor, create_tool_calling_agent |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.messages import HumanMessage |
|
from langchain_mistralai.chat_models import ChatMistralAI |
|
import torch |
|
import os |
|
import sys |
|
import json |
|
sys.path.append(os.getcwd()) |
|
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig |
|
|
|
|
|
# Chat model that drives the agent.
# SECURITY: never hard-code API keys in source. When no `api_key` argument is
# given, ChatMistralAI reads it from the MISTRAL_API_KEY environment variable.
# NOTE(review): the key previously committed on this line must be revoked/rotated.
llm = ChatMistralAI(model='mistral-large-latest')

# Run pose estimation on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
@tool
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Extracts keypoints from a video file.

    Runs the pose-estimation model over the video and writes, for every frame
    in which a person is detected, the keypoints of the first detection to
    ``tmp/keypoints.json`` as a list of ``{"frame": <index>, "keypoints": [...]}``
    records. Frames without any detection are skipped.

    Args:
        video_path (str): path to the video file

    Returns:
        file_path (str): path to the JSON file containing the keypoints
    """
    save_folder = 'tmp'
    os.makedirs(save_folder, exist_ok=True)

    # stream=True makes the model yield one result per frame instead of
    # accumulating the results of the whole video in memory first.
    results = model(
        video_path,
        stream=True,
        save=True,
        show_conf=False,
        show_boxes=False,
        device=device,
    )

    keypoints = []
    for i, frame in enumerate(results):
        # NOTE(review): assumes `model` is an Ultralytics pose model whose
        # `keypoints.xy` is a (num_people, num_keypoints, 2) tensor — confirm.
        frame_keypoints = frame.keypoints
        if frame_keypoints is None or frame_keypoints.xy.numel() == 0:
            # No person detected: `xy[0]` would raise or yield an empty list
            # that the angle computation cannot handle.
            continue
        keypoints.append({'frame': i, 'keypoints': frame_keypoints.xy[0].tolist()})

    file_path = os.path.join(save_folder, 'keypoints.json')
    with open(file_path, 'w') as f:
        json.dump(keypoints, f)

    return file_path
|
|
|
def compute_right_knee_angle_list(json_path: str, window: int = 10) -> list[float]:
    """
    Computes the right-knee angle for every frame stored in a keypoints JSON file.

    The per-frame angles are passed through ``moving_average`` with the given
    window, and the result is plotted via ``save_knee_angle_fig`` (side effect).

    Args:
        json_path (str): path to the JSON file containing the keypoints, a list
            of records each holding a ``'keypoints'`` entry
        window (int): window size passed to ``moving_average`` (default 10)

    Returns:
        right_knee_angle_list (list[float]): list of smoothed knee angles
    """
    # `with` guarantees the file handle is closed (the old bare open() leaked it).
    with open(json_path) as f:
        keypoints_list = json.load(f)

    raw_angles = [compute_right_knee_angle(entry['keypoints']) for entry in keypoints_list]

    right_knee_angle_list = moving_average(raw_angles, window)
    save_knee_angle_fig(right_knee_angle_list)
    return right_knee_angle_list
|
|
|
def check_knee_angle(json_path: str, threshold: float = 90) -> bool:
    """
    Checks if the minimum knee angle is smaller than a threshold.

    If the minimum knee angle is smaller than ``threshold`` (90 degrees by
    default), the squat is considered correct, because it means the user is
    going deep enough.

    Args:
        json_path (str): path to the JSON file containing the keypoints
        threshold (float): angle in degrees that the knee must go below

    Returns:
        is_correct (bool): True if any computed knee angle is smaller than
            ``threshold``, False otherwise (including when no angle was computed)
    """
    angles_list = compute_right_knee_angle_list(json_path)
    return any(angle < threshold for angle in angles_list)
|
|
|
@tool
def check_squat(file_name: str) -> str:
    """
    Checks if the squat is correct.

    Uses check_knee_angle to check whether the user is going deep enough,
    i.e. whether the right knee angle drops below 90 degrees.

    Args:
        file_name (str): name of the uploaded video file, relative to the
            'uploaded' folder

    Returns:
        str: a message saying whether the squat is correct, or that the video
            file does not exist
    """
    upload_dir = os.path.realpath('uploaded')
    video_path = os.path.join('uploaded', file_name)

    # file_name comes from the user via the LLM: reject anything (e.g. '../x'
    # or an absolute path) that resolves outside the upload folder.
    try:
        inside_upload_dir = (
            os.path.commonpath([upload_dir, os.path.realpath(video_path)]) == upload_dir
        )
    except ValueError:  # e.g. paths on different drives on Windows
        inside_upload_dir = False

    # isfile (not exists) so that a directory is not handed to the pose model.
    if not inside_upload_dir or not os.path.isfile(video_path):
        return "The video file does not exist."

    # get_keypoints_from_keypoints is a LangChain tool object, not a plain
    # function, so call it through .invoke rather than the deprecated __call__.
    json_path = get_keypoints_from_keypoints.invoke({'video_path': video_path})

    if check_knee_angle(json_path):
        return "The squat is correct because your knee angle is smaller than 90 degrees."
    return "The squat is incorrect because your knee angle is greater than 90 degrees."
|
|
|
# Tools exposed to the agent; only check_squat can be called by the model.
tools = [check_squat]

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check their movement. Also explain your response and give advice.",
        ),
        # Prior conversation turns, if the caller supplies `chat_history`.
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        # Slot where the agent's intermediate tool calls and results are inserted.
        ("placeholder", "{agent_scratchpad}"),
    ]
)

agent = create_tool_calling_agent(llm, tools, prompt)

agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)