Spaces:
Sleeping
Sleeping
Théo Rousseaux
commited on
Commit
•
989534c
1
Parent(s):
ed250fe
POSE AGENT WORKING
Browse files- Modules/PoseEstimation/pose_agent.py +14 -7
- app.py +34 -23
Modules/PoseEstimation/pose_agent.py
CHANGED
@@ -77,7 +77,7 @@ def check_knee_angle(json_path: str) -> bool:
|
|
77 |
return False
|
78 |
|
79 |
@tool
|
80 |
-
def check_squat(
|
81 |
"""
|
82 |
Checks if the squat is correct.
|
83 |
This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
|
@@ -89,8 +89,17 @@ def check_squat(video_path: str) -> bool:
|
|
89 |
Returns:
|
90 |
is_correct (bool): True if the squat is correct, False otherwise
|
91 |
"""
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
tools = [check_squat]
|
96 |
|
@@ -98,7 +107,7 @@ prompt = ChatPromptTemplate.from_messages(
|
|
98 |
[
|
99 |
(
|
100 |
"system",
|
101 |
-
"You are a helpful assistant. Make sure to use the
|
102 |
),
|
103 |
("placeholder", "{chat_history}"),
|
104 |
("human", "{input}"),
|
@@ -109,6 +118,4 @@ prompt = ChatPromptTemplate.from_messages(
|
|
109 |
# Construct the Tools agent
|
110 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
111 |
|
112 |
-
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
113 |
-
response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
|
114 |
-
print(response["output"])
|
|
|
77 |
return False
|
78 |
|
79 |
@tool
|
80 |
+
def check_squat(file_name: str) -> str:
|
81 |
"""
|
82 |
Checks if the squat is correct.
|
83 |
This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
|
|
|
89 |
Returns:
|
90 |
is_correct (bool): True if the squat is correct, False otherwise
|
91 |
"""
|
92 |
+
|
93 |
+
video_path = os.path.join('uploaded', file_name)
|
94 |
+
if os.path.exists(video_path):
|
95 |
+
json_path = get_keypoints_from_keypoints(video_path)
|
96 |
+
is_correct = check_knee_angle(json_path)
|
97 |
+
if is_correct:
|
98 |
+
return "The squat is correct because your knee angle is smaller than 90 degrees."
|
99 |
+
else:
|
100 |
+
return "The squat is incorrect because your knee angle is greater than 90 degrees."
|
101 |
+
else:
|
102 |
+
return "The video file does not exist."
|
103 |
|
104 |
tools = [check_squat]
|
105 |
|
|
|
107 |
[
|
108 |
(
|
109 |
"system",
|
110 |
+
"You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response",
|
111 |
),
|
112 |
("placeholder", "{chat_history}"),
|
113 |
("human", "{input}"),
|
|
|
118 |
# Construct the Tools agent
|
119 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
120 |
|
121 |
+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
|
|
|
app.py
CHANGED
@@ -7,10 +7,10 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
7 |
from dotenv import load_dotenv
|
8 |
load_dotenv() # load .env api keys
|
9 |
import os
|
10 |
-
|
11 |
from Modules.rag import rag_chain
|
12 |
from Modules.router import router_chain
|
13 |
-
|
14 |
|
15 |
mistral_api_key = os.getenv("MISTRAL_API_KEY")
|
16 |
from Modules.PoseEstimation import pose_estimator
|
@@ -39,6 +39,9 @@ with col1:
|
|
39 |
|
40 |
if "messages" not in st.session_state:
|
41 |
st.session_state.messages = []
|
|
|
|
|
|
|
42 |
|
43 |
for message in st.session_state.messages:
|
44 |
with st.chat_message(message["role"]):
|
@@ -52,45 +55,53 @@ with col1:
|
|
52 |
with st.chat_message("assistant"):
|
53 |
# Build answer from LLM
|
54 |
direction = router_chain.invoke({"question":prompt})
|
|
|
55 |
if direction=='fitness_advices':
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
elif direction=='smalltalk':
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
69 |
st.markdown(response)
|
70 |
|
71 |
-
st.subheader("Movement Analysis")
|
72 |
-
# TO DO
|
73 |
# Second column containers
|
74 |
with col2:
|
75 |
-
st.subheader("Sports Agenda")
|
76 |
# TO DO
|
77 |
st.subheader("Video Analysis")
|
78 |
-
|
79 |
-
|
80 |
-
video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
|
81 |
if video_uploaded:
|
82 |
video_uploaded = save_uploaded_file(video_uploaded)
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
_left, mid, _right = st.columns(3)
|
85 |
with mid:
|
86 |
if os.path.exists('runs'):
|
87 |
-
|
|
|
|
|
|
|
88 |
else :
|
89 |
st.video(video_uploaded)
|
90 |
|
91 |
-
|
92 |
-
st.subheader("Graph Displayer")
|
93 |
if os.path.exists('fig'):
|
|
|
94 |
file_list = os.listdir('fig')
|
95 |
for file in file_list:
|
96 |
st.image(os.path.join('fig', file))
|
|
|
7 |
from dotenv import load_dotenv
|
8 |
load_dotenv() # load .env api keys
|
9 |
import os
|
10 |
+
import shutil
|
11 |
from Modules.rag import rag_chain
|
12 |
from Modules.router import router_chain
|
13 |
+
from Modules.PoseEstimation.pose_agent import agent_executor
|
14 |
|
15 |
mistral_api_key = os.getenv("MISTRAL_API_KEY")
|
16 |
from Modules.PoseEstimation import pose_estimator
|
|
|
39 |
|
40 |
if "messages" not in st.session_state:
|
41 |
st.session_state.messages = []
|
42 |
+
|
43 |
+
if "file_name" not in st.session_state:
|
44 |
+
st.session_state.file_name = None
|
45 |
|
46 |
for message in st.session_state.messages:
|
47 |
with st.chat_message(message["role"]):
|
|
|
55 |
with st.chat_message("assistant"):
|
56 |
# Build answer from LLM
|
57 |
direction = router_chain.invoke({"question":prompt})
|
58 |
+
|
59 |
if direction=='fitness_advices':
|
60 |
+
with st.spinner("Thinking..."):
|
61 |
+
response = rag_chain.invoke(
|
62 |
+
prompt
|
63 |
+
)
|
64 |
elif direction=='smalltalk':
|
65 |
+
with st.spinner("Thinking..."):
|
66 |
+
response = base_chain.invoke(
|
67 |
+
{"question":prompt}
|
68 |
+
).content
|
69 |
+
elif direction =='movement_analysis':
|
70 |
+
if st.session_state.file_name is not None:
|
71 |
+
prompt += "the file name is " + st.session_state.file_name
|
72 |
+
with st.spinner("Analyzing movement..."):
|
73 |
+
response = agent_executor.invoke(
|
74 |
+
{"input" : prompt}
|
75 |
+
)["output"]
|
76 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
77 |
st.markdown(response)
|
78 |
|
|
|
|
|
79 |
# Second column containers
|
80 |
with col2:
|
81 |
+
# st.subheader("Sports Agenda")
|
82 |
# TO DO
|
83 |
st.subheader("Video Analysis")
|
84 |
+
|
85 |
+
video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
|
|
|
86 |
if video_uploaded:
|
87 |
video_uploaded = save_uploaded_file(video_uploaded)
|
88 |
+
if video_uploaded.split("/")[-1] != st.session_state.file_name:
|
89 |
+
shutil.rmtree('fig', ignore_errors=True)
|
90 |
+
shutil.rmtree('runs', ignore_errors=True)
|
91 |
+
st.session_state.file_name = None
|
92 |
+
st.session_state.file_name = video_uploaded.split("/")[-1]
|
93 |
_left, mid, _right = st.columns(3)
|
94 |
with mid:
|
95 |
if os.path.exists('runs'):
|
96 |
+
predict_list = os.listdir(os.path.join('runs', 'pose'))
|
97 |
+
predict_list.sort()
|
98 |
+
predict_dir = predict_list[-1]
|
99 |
+
st.video(os.path.join('runs', 'pose', predict_dir, 'squat.mp4'), loop=True)
|
100 |
else :
|
101 |
st.video(video_uploaded)
|
102 |
|
|
|
|
|
103 |
if os.path.exists('fig'):
|
104 |
+
st.subheader("Graph Displayer")
|
105 |
file_list = os.listdir('fig')
|
106 |
for file in file_list:
|
107 |
st.image(os.path.join('fig', file))
|