Nico8800 commited on
Commit
631a7fe
β€’
2 Parent(s): 36222dd 989534c

Merge branch 'main' of https://huggingface.co/spaces/EntrepreneurFirst/FitnessEquation

Browse files
Modules/PoseEstimation/pose_agent.py CHANGED
@@ -77,7 +77,7 @@ def check_knee_angle(json_path: str) -> bool:
77
  return False
78
 
79
  @tool
80
- def check_squat(video_path: str) -> bool:
81
  """
82
  Checks if the squat is correct.
83
  This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
@@ -89,8 +89,17 @@ def check_squat(video_path: str) -> bool:
89
  Returns:
90
  is_correct (bool): True if the squat is correct, False otherwise
91
  """
92
- json_path = get_keypoints_from_keypoints(video_path)
93
- return check_knee_angle(json_path)
 
 
 
 
 
 
 
 
 
94
 
95
  tools = [check_squat]
96
 
@@ -98,7 +107,7 @@ prompt = ChatPromptTemplate.from_messages(
98
  [
99
  (
100
  "system",
101
- "You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement. Also explain your response",
102
  ),
103
  ("placeholder", "{chat_history}"),
104
  ("human", "{input}"),
@@ -109,6 +118,4 @@ prompt = ChatPromptTemplate.from_messages(
109
  # Construct the Tools agent
110
  agent = create_tool_calling_agent(llm, tools, prompt)
111
 
112
- agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
113
- response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
114
- print(response["output"])
 
77
  return False
78
 
79
  @tool
80
+ def check_squat(file_name: str) -> str:
81
  """
82
  Checks if the squat is correct.
83
  This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
 
89
  Returns:
90
  is_correct (bool): True if the squat is correct, False otherwise
91
  """
92
+
93
+ video_path = os.path.join('uploaded', file_name)
94
+ if os.path.exists(video_path):
95
+ json_path = get_keypoints_from_keypoints(video_path)
96
+ is_correct = check_knee_angle(json_path)
97
+ if is_correct:
98
+ return "The squat is correct because your knee angle is smaller than 90 degrees."
99
+ else:
100
+ return "The squat is incorrect because your knee angle is greater than 90 degrees."
101
+ else:
102
+ return "The video file does not exist."
103
 
104
  tools = [check_squat]
105
 
 
107
  [
108
  (
109
  "system",
110
+ "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response",
111
  ),
112
  ("placeholder", "{chat_history}"),
113
  ("human", "{input}"),
 
118
  # Construct the Tools agent
119
  agent = create_tool_calling_agent(llm, tools, prompt)
120
 
121
+ agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 
 
Modules/rag.py CHANGED
@@ -23,7 +23,7 @@ from langchain.retrievers import (
23
  from huggingface_hub import login
24
  login(token=os.getenv("HUGGING_FACE_TOKEN"))
25
 
26
- def load_chunk_persist_pdf() -> Chroma:
27
 
28
  pdf_folder_path = os.path.join(os.getcwd(),Path(f"data/pdf/{task}"))
29
  documents = []
@@ -37,12 +37,13 @@ def load_chunk_persist_pdf() -> Chroma:
37
  os.makedirs("data/chroma_store/", exist_ok=True)
38
  vectorstore = Chroma.from_documents(
39
  documents=chunked_documents,
40
- embedding=MistralAIEmbeddings(api_key=mistral_api_key),
41
  persist_directory= os.path.join(os.getcwd(),Path("data/chroma_store/"))
42
  )
43
  vectorstore.persist()
44
  return vectorstore
45
 
 
46
  zero2hero_vectorstore = load_chunk_persist_pdf("zero2hero")
47
  bodyweight_vectorstore = load_chunk_persist_pdf("bodyweight")
48
  nutrition_vectorstore = load_chunk_persist_pdf("nutrition")
@@ -51,13 +52,14 @@ zero2hero_retriever = zero2hero_vectorstore.as_retriever()
51
  nutrition_retriever = nutrition_vectorstore.as_retriever()
52
  bodyweight_retriever = bodyweight_vectorstore.as_retriever()
53
  workout_retriever = workout_vectorstore.as_retriever()
 
54
 
55
  llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
56
 
57
  prompt = ChatPromptTemplate.from_template(
58
  """
59
  You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
60
- You must adapt to the user : if he is a beginner, use simple words. You are gentle and motivative.
61
  Use the following pieces of retrieved context to answer the question.
62
  If you don't know the answer, use your common knowledge.
63
  Use three sentences maximum and keep the answer concise.
@@ -73,7 +75,7 @@ prompt = ChatPromptTemplate.from_template(
73
  def format_docs(docs):
74
  return "\n\n".join(doc.page_content for doc in docs)
75
 
76
- retriever = MergerRetriever(retrievers=[zero2hero_retriever, bodyweight_retriever, nutrition_retriever, workout_retriever])
77
 
78
  rag_chain = (
79
  {"context": retriever | format_docs, "question": RunnablePassthrough()}
@@ -84,6 +86,6 @@ rag_chain = (
84
 
85
 
86
 
87
- print(rag_chain.invoke("What supplement could i buy to improve my sleep?"))
88
 
89
  # print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program, and a nutrition program"))
 
23
  from huggingface_hub import login
24
  login(token=os.getenv("HUGGING_FACE_TOKEN"))
25
 
26
+ def load_chunk_persist_pdf(task) -> Chroma:
27
 
28
  pdf_folder_path = os.path.join(os.getcwd(),Path(f"data/pdf/{task}"))
29
  documents = []
 
37
  os.makedirs("data/chroma_store/", exist_ok=True)
38
  vectorstore = Chroma.from_documents(
39
  documents=chunked_documents,
40
+ embedding=MistralAIEmbeddings(),
41
  persist_directory= os.path.join(os.getcwd(),Path("data/chroma_store/"))
42
  )
43
  vectorstore.persist()
44
  return vectorstore
45
 
46
+ personal_info_vectorstore = load_chunk_persist_pdf("personal_info")
47
  zero2hero_vectorstore = load_chunk_persist_pdf("zero2hero")
48
  bodyweight_vectorstore = load_chunk_persist_pdf("bodyweight")
49
  nutrition_vectorstore = load_chunk_persist_pdf("nutrition")
 
52
  nutrition_retriever = nutrition_vectorstore.as_retriever()
53
  bodyweight_retriever = bodyweight_vectorstore.as_retriever()
54
  workout_retriever = workout_vectorstore.as_retriever()
55
+ personal_info_retriever = personal_info_vectorstore.as_retriever()
56
 
57
  llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
58
 
59
  prompt = ChatPromptTemplate.from_template(
60
  """
61
  You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
62
+ You must adapt to the user according to personal informations in the context. A You are gentle and motivative.
63
  Use the following pieces of retrieved context to answer the question.
64
  If you don't know the answer, use your common knowledge.
65
  Use three sentences maximum and keep the answer concise.
 
75
  def format_docs(docs):
76
  return "\n\n".join(doc.page_content for doc in docs)
77
 
78
+ retriever = MergerRetriever(retrievers=[zero2hero_retriever, bodyweight_retriever, nutrition_retriever, workout_retriever, personal_info_retriever])
79
 
80
  rag_chain = (
81
  {"context": retriever | format_docs, "question": RunnablePassthrough()}
 
86
 
87
 
88
 
89
+ print(rag_chain.invoke("WHi I'm Susan. Can you make a fitness program for me please?"))
90
 
91
  # print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program, and a nutrition program"))
app.py CHANGED
@@ -7,10 +7,10 @@ from langchain_core.prompts import ChatPromptTemplate
7
  from dotenv import load_dotenv
8
  load_dotenv() # load .env api keys
9
  import os
10
-
11
  from Modules.rag import rag_chain
12
  from Modules.router import router_chain
13
- # from Modules.PoseEstimation.pose_agent import agent_executor
14
 
15
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
16
  from Modules.PoseEstimation import pose_estimator
@@ -39,6 +39,9 @@ with col1:
39
 
40
  if "messages" not in st.session_state:
41
  st.session_state.messages = []
 
 
 
42
 
43
  for message in st.session_state.messages:
44
  with st.chat_message(message["role"]):
@@ -52,45 +55,53 @@ with col1:
52
  with st.chat_message("assistant"):
53
  # Build answer from LLM
54
  direction = router_chain.invoke({"question":prompt})
 
55
  if direction=='fitness_advices':
56
- response = rag_chain.invoke(
57
- prompt
58
- )
 
59
  elif direction=='smalltalk':
60
- response = base_chain.invoke(
61
- {"question":prompt}
62
- ).content
63
- # elif direction =='movement_analysis':
64
- # response = agent_executor.invoke(
65
- # {"input" : instruction}
66
- # )["output"]
67
- print(type(response))
 
 
 
68
  st.session_state.messages.append({"role": "assistant", "content": response})
69
  st.markdown(response)
70
 
71
- st.subheader("Movement Analysis")
72
- # TO DO
73
  # Second column containers
74
  with col2:
75
- st.subheader("Sports Agenda")
76
  # TO DO
77
  st.subheader("Video Analysis")
78
- ask_video = st.empty()
79
- if video_uploaded is None:
80
- video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
81
  if video_uploaded:
82
  video_uploaded = save_uploaded_file(video_uploaded)
83
- ask_video.empty()
 
 
 
 
84
  _left, mid, _right = st.columns(3)
85
  with mid:
86
  if os.path.exists('runs'):
87
- st.video(os.path.join('runs', 'pose', 'predict', 'squat.mp4'), loop=True)
 
 
 
88
  else :
89
  st.video(video_uploaded)
90
 
91
-
92
- st.subheader("Graph Displayer")
93
  if os.path.exists('fig'):
 
94
  file_list = os.listdir('fig')
95
  for file in file_list:
96
  st.image(os.path.join('fig', file))
 
7
  from dotenv import load_dotenv
8
  load_dotenv() # load .env api keys
9
  import os
10
+ import shutil
11
  from Modules.rag import rag_chain
12
  from Modules.router import router_chain
13
+ from Modules.PoseEstimation.pose_agent import agent_executor
14
 
15
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
16
  from Modules.PoseEstimation import pose_estimator
 
39
 
40
  if "messages" not in st.session_state:
41
  st.session_state.messages = []
42
+
43
+ if "file_name" not in st.session_state:
44
+ st.session_state.file_name = None
45
 
46
  for message in st.session_state.messages:
47
  with st.chat_message(message["role"]):
 
55
  with st.chat_message("assistant"):
56
  # Build answer from LLM
57
  direction = router_chain.invoke({"question":prompt})
58
+
59
  if direction=='fitness_advices':
60
+ with st.spinner("Thinking..."):
61
+ response = rag_chain.invoke(
62
+ prompt
63
+ )
64
  elif direction=='smalltalk':
65
+ with st.spinner("Thinking..."):
66
+ response = base_chain.invoke(
67
+ {"question":prompt}
68
+ ).content
69
+ elif direction =='movement_analysis':
70
+ if st.session_state.file_name is not None:
71
+ prompt += "the file name is " + st.session_state.file_name
72
+ with st.spinner("Analyzing movement..."):
73
+ response = agent_executor.invoke(
74
+ {"input" : prompt}
75
+ )["output"]
76
  st.session_state.messages.append({"role": "assistant", "content": response})
77
  st.markdown(response)
78
 
 
 
79
  # Second column containers
80
  with col2:
81
+ # st.subheader("Sports Agenda")
82
  # TO DO
83
  st.subheader("Video Analysis")
84
+
85
+ video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
 
86
  if video_uploaded:
87
  video_uploaded = save_uploaded_file(video_uploaded)
88
+ if video_uploaded.split("/")[-1] != st.session_state.file_name:
89
+ shutil.rmtree('fig', ignore_errors=True)
90
+ shutil.rmtree('runs', ignore_errors=True)
91
+ st.session_state.file_name = None
92
+ st.session_state.file_name = video_uploaded.split("/")[-1]
93
  _left, mid, _right = st.columns(3)
94
  with mid:
95
  if os.path.exists('runs'):
96
+ predict_list = os.listdir(os.path.join('runs', 'pose'))
97
+ predict_list.sort()
98
+ predict_dir = predict_list[-1]
99
+ st.video(os.path.join('runs', 'pose', predict_dir, 'squat.mp4'), loop=True)
100
  else :
101
  st.video(video_uploaded)
102
 
 
 
103
  if os.path.exists('fig'):
104
+ st.subheader("Graph Displayer")
105
  file_list = os.listdir('fig')
106
  for file in file_list:
107
  st.image(os.path.join('fig', file))
data/Zero To Hero - Bienvenue dans ta nouvelle vie.pdf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1e59511fd335776dcb8b0e98e9e9afe6d22483df94c83755163d5d5d4a7f3809
3
- size 60211845
 
 
 
 
data/{RaptorBodyweight.pdf β†’ pdf/personal_info/Susan_Thompson.pdf} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cbc90aebadbad36b8f47cd9ba45a958d278e38e7481037acb9ecee5e3805d771
3
- size 26474997
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25e4eaa98480003224e09c48011e82678f9947382fc37c16011b75fab80f93a2
3
+ size 55515