Spaces:
Sleeping
Sleeping
Merge branch 'main' of https://huggingface.co/spaces/EntrepreneurFirst/FitnessEquation
Browse files
Modules/PoseEstimation/pose_agent.py
CHANGED
@@ -77,7 +77,7 @@ def check_knee_angle(json_path: str) -> bool:
|
|
77 |
return False
|
78 |
|
79 |
@tool
|
80 |
-
def check_squat(
|
81 |
"""
|
82 |
Checks if the squat is correct.
|
83 |
This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
|
@@ -89,8 +89,17 @@ def check_squat(video_path: str) -> bool:
|
|
89 |
Returns:
|
90 |
is_correct (bool): True if the squat is correct, False otherwise
|
91 |
"""
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
tools = [check_squat]
|
96 |
|
@@ -98,7 +107,7 @@ prompt = ChatPromptTemplate.from_messages(
|
|
98 |
[
|
99 |
(
|
100 |
"system",
|
101 |
-
"You are a helpful assistant. Make sure to use the
|
102 |
),
|
103 |
("placeholder", "{chat_history}"),
|
104 |
("human", "{input}"),
|
@@ -109,6 +118,4 @@ prompt = ChatPromptTemplate.from_messages(
|
|
109 |
# Construct the Tools agent
|
110 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
111 |
|
112 |
-
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
113 |
-
response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
|
114 |
-
print(response["output"])
|
|
|
77 |
return False
|
78 |
|
79 |
@tool
|
80 |
+
def check_squat(file_name: str) -> str:
|
81 |
"""
|
82 |
Checks if the squat is correct.
|
83 |
This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
|
|
|
89 |
Returns:
|
90 |
is_correct (bool): True if the squat is correct, False otherwise
|
91 |
"""
|
92 |
+
|
93 |
+
video_path = os.path.join('uploaded', file_name)
|
94 |
+
if os.path.exists(video_path):
|
95 |
+
json_path = get_keypoints_from_keypoints(video_path)
|
96 |
+
is_correct = check_knee_angle(json_path)
|
97 |
+
if is_correct:
|
98 |
+
return "The squat is correct because your knee angle is smaller than 90 degrees."
|
99 |
+
else:
|
100 |
+
return "The squat is incorrect because your knee angle is greater than 90 degrees."
|
101 |
+
else:
|
102 |
+
return "The video file does not exist."
|
103 |
|
104 |
tools = [check_squat]
|
105 |
|
|
|
107 |
[
|
108 |
(
|
109 |
"system",
|
110 |
+
"You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response",
|
111 |
),
|
112 |
("placeholder", "{chat_history}"),
|
113 |
("human", "{input}"),
|
|
|
118 |
# Construct the Tools agent
|
119 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
120 |
|
121 |
+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
|
|
|
Modules/rag.py
CHANGED
@@ -23,7 +23,7 @@ from langchain.retrievers import (
|
|
23 |
from huggingface_hub import login
|
24 |
login(token=os.getenv("HUGGING_FACE_TOKEN"))
|
25 |
|
26 |
-
def load_chunk_persist_pdf() -> Chroma:
|
27 |
|
28 |
pdf_folder_path = os.path.join(os.getcwd(),Path(f"data/pdf/{task}"))
|
29 |
documents = []
|
@@ -37,12 +37,13 @@ def load_chunk_persist_pdf() -> Chroma:
|
|
37 |
os.makedirs("data/chroma_store/", exist_ok=True)
|
38 |
vectorstore = Chroma.from_documents(
|
39 |
documents=chunked_documents,
|
40 |
-
embedding=MistralAIEmbeddings(
|
41 |
persist_directory= os.path.join(os.getcwd(),Path("data/chroma_store/"))
|
42 |
)
|
43 |
vectorstore.persist()
|
44 |
return vectorstore
|
45 |
|
|
|
46 |
zero2hero_vectorstore = load_chunk_persist_pdf("zero2hero")
|
47 |
bodyweight_vectorstore = load_chunk_persist_pdf("bodyweight")
|
48 |
nutrition_vectorstore = load_chunk_persist_pdf("nutrition")
|
@@ -51,13 +52,14 @@ zero2hero_retriever = zero2hero_vectorstore.as_retriever()
|
|
51 |
nutrition_retriever = nutrition_vectorstore.as_retriever()
|
52 |
bodyweight_retriever = bodyweight_vectorstore.as_retriever()
|
53 |
workout_retriever = workout_vectorstore.as_retriever()
|
|
|
54 |
|
55 |
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
|
56 |
|
57 |
prompt = ChatPromptTemplate.from_template(
|
58 |
"""
|
59 |
You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
|
60 |
-
You must adapt to the user
|
61 |
Use the following pieces of retrieved context to answer the question.
|
62 |
If you don't know the answer, use your common knowledge.
|
63 |
Use three sentences maximum and keep the answer concise.
|
@@ -73,7 +75,7 @@ prompt = ChatPromptTemplate.from_template(
|
|
73 |
def format_docs(docs):
|
74 |
return "\n\n".join(doc.page_content for doc in docs)
|
75 |
|
76 |
-
retriever = MergerRetriever(retrievers=[zero2hero_retriever, bodyweight_retriever, nutrition_retriever, workout_retriever])
|
77 |
|
78 |
rag_chain = (
|
79 |
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
@@ -84,6 +86,6 @@ rag_chain = (
|
|
84 |
|
85 |
|
86 |
|
87 |
-
print(rag_chain.invoke("
|
88 |
|
89 |
# print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program, and a nutrition program"))
|
|
|
23 |
from huggingface_hub import login
|
24 |
login(token=os.getenv("HUGGING_FACE_TOKEN"))
|
25 |
|
26 |
+
def load_chunk_persist_pdf(task) -> Chroma:
|
27 |
|
28 |
pdf_folder_path = os.path.join(os.getcwd(),Path(f"data/pdf/{task}"))
|
29 |
documents = []
|
|
|
37 |
os.makedirs("data/chroma_store/", exist_ok=True)
|
38 |
vectorstore = Chroma.from_documents(
|
39 |
documents=chunked_documents,
|
40 |
+
embedding=MistralAIEmbeddings(),
|
41 |
persist_directory= os.path.join(os.getcwd(),Path("data/chroma_store/"))
|
42 |
)
|
43 |
vectorstore.persist()
|
44 |
return vectorstore
|
45 |
|
46 |
+
personal_info_vectorstore = load_chunk_persist_pdf("personal_info")
|
47 |
zero2hero_vectorstore = load_chunk_persist_pdf("zero2hero")
|
48 |
bodyweight_vectorstore = load_chunk_persist_pdf("bodyweight")
|
49 |
nutrition_vectorstore = load_chunk_persist_pdf("nutrition")
|
|
|
52 |
nutrition_retriever = nutrition_vectorstore.as_retriever()
|
53 |
bodyweight_retriever = bodyweight_vectorstore.as_retriever()
|
54 |
workout_retriever = workout_vectorstore.as_retriever()
|
55 |
+
personal_info_retriever = personal_info_vectorstore.as_retriever()
|
56 |
|
57 |
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
|
58 |
|
59 |
prompt = ChatPromptTemplate.from_template(
|
60 |
"""
|
61 |
You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
|
62 |
+
You must adapt to the user according to personal informations in the context. A You are gentle and motivative.
|
63 |
Use the following pieces of retrieved context to answer the question.
|
64 |
If you don't know the answer, use your common knowledge.
|
65 |
Use three sentences maximum and keep the answer concise.
|
|
|
75 |
def format_docs(docs):
|
76 |
return "\n\n".join(doc.page_content for doc in docs)
|
77 |
|
78 |
+
retriever = MergerRetriever(retrievers=[zero2hero_retriever, bodyweight_retriever, nutrition_retriever, workout_retriever, personal_info_retriever])
|
79 |
|
80 |
rag_chain = (
|
81 |
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
|
|
86 |
|
87 |
|
88 |
|
89 |
+
print(rag_chain.invoke("WHi I'm Susan. Can you make a fitness program for me please?"))
|
90 |
|
91 |
# print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program, and a nutrition program"))
|
app.py
CHANGED
@@ -7,10 +7,10 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
7 |
from dotenv import load_dotenv
|
8 |
load_dotenv() # load .env api keys
|
9 |
import os
|
10 |
-
|
11 |
from Modules.rag import rag_chain
|
12 |
from Modules.router import router_chain
|
13 |
-
|
14 |
|
15 |
mistral_api_key = os.getenv("MISTRAL_API_KEY")
|
16 |
from Modules.PoseEstimation import pose_estimator
|
@@ -39,6 +39,9 @@ with col1:
|
|
39 |
|
40 |
if "messages" not in st.session_state:
|
41 |
st.session_state.messages = []
|
|
|
|
|
|
|
42 |
|
43 |
for message in st.session_state.messages:
|
44 |
with st.chat_message(message["role"]):
|
@@ -52,45 +55,53 @@ with col1:
|
|
52 |
with st.chat_message("assistant"):
|
53 |
# Build answer from LLM
|
54 |
direction = router_chain.invoke({"question":prompt})
|
|
|
55 |
if direction=='fitness_advices':
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
elif direction=='smalltalk':
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
69 |
st.markdown(response)
|
70 |
|
71 |
-
st.subheader("Movement Analysis")
|
72 |
-
# TO DO
|
73 |
# Second column containers
|
74 |
with col2:
|
75 |
-
st.subheader("Sports Agenda")
|
76 |
# TO DO
|
77 |
st.subheader("Video Analysis")
|
78 |
-
|
79 |
-
|
80 |
-
video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
|
81 |
if video_uploaded:
|
82 |
video_uploaded = save_uploaded_file(video_uploaded)
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
_left, mid, _right = st.columns(3)
|
85 |
with mid:
|
86 |
if os.path.exists('runs'):
|
87 |
-
|
|
|
|
|
|
|
88 |
else :
|
89 |
st.video(video_uploaded)
|
90 |
|
91 |
-
|
92 |
-
st.subheader("Graph Displayer")
|
93 |
if os.path.exists('fig'):
|
|
|
94 |
file_list = os.listdir('fig')
|
95 |
for file in file_list:
|
96 |
st.image(os.path.join('fig', file))
|
|
|
7 |
from dotenv import load_dotenv
|
8 |
load_dotenv() # load .env api keys
|
9 |
import os
|
10 |
+
import shutil
|
11 |
from Modules.rag import rag_chain
|
12 |
from Modules.router import router_chain
|
13 |
+
from Modules.PoseEstimation.pose_agent import agent_executor
|
14 |
|
15 |
mistral_api_key = os.getenv("MISTRAL_API_KEY")
|
16 |
from Modules.PoseEstimation import pose_estimator
|
|
|
39 |
|
40 |
if "messages" not in st.session_state:
|
41 |
st.session_state.messages = []
|
42 |
+
|
43 |
+
if "file_name" not in st.session_state:
|
44 |
+
st.session_state.file_name = None
|
45 |
|
46 |
for message in st.session_state.messages:
|
47 |
with st.chat_message(message["role"]):
|
|
|
55 |
with st.chat_message("assistant"):
|
56 |
# Build answer from LLM
|
57 |
direction = router_chain.invoke({"question":prompt})
|
58 |
+
|
59 |
if direction=='fitness_advices':
|
60 |
+
with st.spinner("Thinking..."):
|
61 |
+
response = rag_chain.invoke(
|
62 |
+
prompt
|
63 |
+
)
|
64 |
elif direction=='smalltalk':
|
65 |
+
with st.spinner("Thinking..."):
|
66 |
+
response = base_chain.invoke(
|
67 |
+
{"question":prompt}
|
68 |
+
).content
|
69 |
+
elif direction =='movement_analysis':
|
70 |
+
if st.session_state.file_name is not None:
|
71 |
+
prompt += "the file name is " + st.session_state.file_name
|
72 |
+
with st.spinner("Analyzing movement..."):
|
73 |
+
response = agent_executor.invoke(
|
74 |
+
{"input" : prompt}
|
75 |
+
)["output"]
|
76 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
77 |
st.markdown(response)
|
78 |
|
|
|
|
|
79 |
# Second column containers
|
80 |
with col2:
|
81 |
+
# st.subheader("Sports Agenda")
|
82 |
# TO DO
|
83 |
st.subheader("Video Analysis")
|
84 |
+
|
85 |
+
video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
|
|
|
86 |
if video_uploaded:
|
87 |
video_uploaded = save_uploaded_file(video_uploaded)
|
88 |
+
if video_uploaded.split("/")[-1] != st.session_state.file_name:
|
89 |
+
shutil.rmtree('fig', ignore_errors=True)
|
90 |
+
shutil.rmtree('runs', ignore_errors=True)
|
91 |
+
st.session_state.file_name = None
|
92 |
+
st.session_state.file_name = video_uploaded.split("/")[-1]
|
93 |
_left, mid, _right = st.columns(3)
|
94 |
with mid:
|
95 |
if os.path.exists('runs'):
|
96 |
+
predict_list = os.listdir(os.path.join('runs', 'pose'))
|
97 |
+
predict_list.sort()
|
98 |
+
predict_dir = predict_list[-1]
|
99 |
+
st.video(os.path.join('runs', 'pose', predict_dir, 'squat.mp4'), loop=True)
|
100 |
else :
|
101 |
st.video(video_uploaded)
|
102 |
|
|
|
|
|
103 |
if os.path.exists('fig'):
|
104 |
+
st.subheader("Graph Displayer")
|
105 |
file_list = os.listdir('fig')
|
106 |
for file in file_list:
|
107 |
st.image(os.path.join('fig', file))
|
data/Zero To Hero - Bienvenue dans ta nouvelle vie.pdf
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:1e59511fd335776dcb8b0e98e9e9afe6d22483df94c83755163d5d5d4a7f3809
|
3 |
-
size 60211845
|
|
|
|
|
|
|
|
data/{RaptorBodyweight.pdf β pdf/personal_info/Susan_Thompson.pdf}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:25e4eaa98480003224e09c48011e82678f9947382fc37c16011b75fab80f93a2
|
3 |
+
size 55515
|