Théo Rousseaux committed on
Commit
989534c
1 Parent(s): ed250fe

POSE AGENT WORKING

Browse files
Files changed (2) hide show
  1. Modules/PoseEstimation/pose_agent.py +14 -7
  2. app.py +34 -23
Modules/PoseEstimation/pose_agent.py CHANGED
@@ -77,7 +77,7 @@ def check_knee_angle(json_path: str) -> bool:
77
  return False
78
 
79
  @tool
80
- def check_squat(video_path: str) -> bool:
81
  """
82
  Checks if the squat is correct.
83
  This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
@@ -89,8 +89,17 @@ def check_squat(video_path: str) -> bool:
89
  Returns:
90
  is_correct (bool): True if the squat is correct, False otherwise
91
  """
92
- json_path = get_keypoints_from_keypoints(video_path)
93
- return check_knee_angle(json_path)
 
 
 
 
 
 
 
 
 
94
 
95
  tools = [check_squat]
96
 
@@ -98,7 +107,7 @@ prompt = ChatPromptTemplate.from_messages(
98
  [
99
  (
100
  "system",
101
- "You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement. Also explain your response",
102
  ),
103
  ("placeholder", "{chat_history}"),
104
  ("human", "{input}"),
@@ -109,6 +118,4 @@ prompt = ChatPromptTemplate.from_messages(
109
  # Construct the Tools agent
110
  agent = create_tool_calling_agent(llm, tools, prompt)
111
 
112
- agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
113
- response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
114
- print(response["output"])
 
77
  return False
78
 
79
  @tool
80
+ def check_squat(file_name: str) -> str:
81
  """
82
  Checks if the squat is correct.
83
  This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
 
89
  Returns:
90
  is_correct (bool): True if the squat is correct, False otherwise
91
  """
92
+
93
+ video_path = os.path.join('uploaded', file_name)
94
+ if os.path.exists(video_path):
95
+ json_path = get_keypoints_from_keypoints(video_path)
96
+ is_correct = check_knee_angle(json_path)
97
+ if is_correct:
98
+ return "The squat is correct because your knee angle is smaller than 90 degrees."
99
+ else:
100
+ return "The squat is incorrect because your knee angle is greater than 90 degrees."
101
+ else:
102
+ return "The video file does not exist."
103
 
104
  tools = [check_squat]
105
 
 
107
  [
108
  (
109
  "system",
110
+ "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response",
111
  ),
112
  ("placeholder", "{chat_history}"),
113
  ("human", "{input}"),
 
118
  # Construct the Tools agent
119
  agent = create_tool_calling_agent(llm, tools, prompt)
120
 
121
+ agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 
 
app.py CHANGED
@@ -7,10 +7,10 @@ from langchain_core.prompts import ChatPromptTemplate
7
  from dotenv import load_dotenv
8
  load_dotenv() # load .env api keys
9
  import os
10
-
11
  from Modules.rag import rag_chain
12
  from Modules.router import router_chain
13
- # from Modules.PoseEstimation.pose_agent import agent_executor
14
 
15
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
16
  from Modules.PoseEstimation import pose_estimator
@@ -39,6 +39,9 @@ with col1:
39
 
40
  if "messages" not in st.session_state:
41
  st.session_state.messages = []
 
 
 
42
 
43
  for message in st.session_state.messages:
44
  with st.chat_message(message["role"]):
@@ -52,45 +55,53 @@ with col1:
52
  with st.chat_message("assistant"):
53
  # Build answer from LLM
54
  direction = router_chain.invoke({"question":prompt})
 
55
  if direction=='fitness_advices':
56
- response = rag_chain.invoke(
57
- prompt
58
- )
 
59
  elif direction=='smalltalk':
60
- response = base_chain.invoke(
61
- {"question":prompt}
62
- ).content
63
- # elif direction =='movement_analysis':
64
- # response = agent_executor.invoke(
65
- # {"input" : instruction}
66
- # )["output"]
67
- print(type(response))
 
 
 
68
  st.session_state.messages.append({"role": "assistant", "content": response})
69
  st.markdown(response)
70
 
71
- st.subheader("Movement Analysis")
72
- # TO DO
73
  # Second column containers
74
  with col2:
75
- st.subheader("Sports Agenda")
76
  # TO DO
77
  st.subheader("Video Analysis")
78
- ask_video = st.empty()
79
- if video_uploaded is None:
80
- video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
81
  if video_uploaded:
82
  video_uploaded = save_uploaded_file(video_uploaded)
83
- ask_video.empty()
 
 
 
 
84
  _left, mid, _right = st.columns(3)
85
  with mid:
86
  if os.path.exists('runs'):
87
- st.video(os.path.join('runs', 'pose', 'predict', 'squat.mp4'), loop=True)
 
 
 
88
  else :
89
  st.video(video_uploaded)
90
 
91
-
92
- st.subheader("Graph Displayer")
93
  if os.path.exists('fig'):
 
94
  file_list = os.listdir('fig')
95
  for file in file_list:
96
  st.image(os.path.join('fig', file))
 
7
  from dotenv import load_dotenv
8
  load_dotenv() # load .env api keys
9
  import os
10
+ import shutil
11
  from Modules.rag import rag_chain
12
  from Modules.router import router_chain
13
+ from Modules.PoseEstimation.pose_agent import agent_executor
14
 
15
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
16
  from Modules.PoseEstimation import pose_estimator
 
39
 
40
  if "messages" not in st.session_state:
41
  st.session_state.messages = []
42
+
43
+ if "file_name" not in st.session_state:
44
+ st.session_state.file_name = None
45
 
46
  for message in st.session_state.messages:
47
  with st.chat_message(message["role"]):
 
55
  with st.chat_message("assistant"):
56
  # Build answer from LLM
57
  direction = router_chain.invoke({"question":prompt})
58
+
59
  if direction=='fitness_advices':
60
+ with st.spinner("Thinking..."):
61
+ response = rag_chain.invoke(
62
+ prompt
63
+ )
64
  elif direction=='smalltalk':
65
+ with st.spinner("Thinking..."):
66
+ response = base_chain.invoke(
67
+ {"question":prompt}
68
+ ).content
69
+ elif direction =='movement_analysis':
70
+ if st.session_state.file_name is not None:
71
+ prompt += "the file name is " + st.session_state.file_name
72
+ with st.spinner("Analyzing movement..."):
73
+ response = agent_executor.invoke(
74
+ {"input" : prompt}
75
+ )["output"]
76
  st.session_state.messages.append({"role": "assistant", "content": response})
77
  st.markdown(response)
78
 
 
 
79
  # Second column containers
80
  with col2:
81
+ # st.subheader("Sports Agenda")
82
  # TO DO
83
  st.subheader("Video Analysis")
84
+
85
+ video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
 
86
  if video_uploaded:
87
  video_uploaded = save_uploaded_file(video_uploaded)
88
+ if video_uploaded.split("/")[-1] != st.session_state.file_name:
89
+ shutil.rmtree('fig', ignore_errors=True)
90
+ shutil.rmtree('runs', ignore_errors=True)
91
+ st.session_state.file_name = None
92
+ st.session_state.file_name = video_uploaded.split("/")[-1]
93
  _left, mid, _right = st.columns(3)
94
  with mid:
95
  if os.path.exists('runs'):
96
+ predict_list = os.listdir(os.path.join('runs', 'pose'))
97
+ predict_list.sort()
98
+ predict_dir = predict_list[-1]
99
+ st.video(os.path.join('runs', 'pose', predict_dir, 'squat.mp4'), loop=True)
100
  else :
101
  st.video(video_uploaded)
102
 
 
 
103
  if os.path.exists('fig'):
104
+ st.subheader("Graph Displayer")
105
  file_list = os.listdir('fig')
106
  for file in file_list:
107
  st.image(os.path.join('fig', file))