Théo Rousseaux commited on
Commit
85b06be
1 Parent(s): 42ef85f

streamlit chat

Browse files
Files changed (3) hide show
  1. Modules/rag.py +5 -1
  2. app.py +21 -37
  3. requirements.txt +4 -1
Modules/rag.py CHANGED
@@ -4,6 +4,7 @@ from dotenv import load_dotenv
4
  load_dotenv() # load .env api keys
5
 
6
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
 
7
 
8
  from langchain_community.document_loaders import PyPDFLoader
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -24,6 +25,9 @@ from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddi
24
  from langchain_community.tools import DuckDuckGoSearchRun
25
  from pathlib import Path
26
 
 
 
 
27
  def load_chunk_persist_pdf() -> Chroma:
28
 
29
  pdf_folder_path = os.path.join(os.getcwd(),Path("data/pdf/"))
@@ -38,7 +42,7 @@ def load_chunk_persist_pdf() -> Chroma:
38
  os.makedirs("data/chroma_store/", exist_ok=True)
39
  vectorstore = Chroma.from_documents(
40
  documents=chunked_documents,
41
- embedding=MistralAIEmbeddings(),
42
  persist_directory= os.path.join(os.getcwd(),Path("data/chroma_store/"))
43
  )
44
  vectorstore.persist()
 
4
  load_dotenv() # load .env api keys
5
 
6
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
7
+ print("mistral_api_key", mistral_api_key)
8
 
9
  from langchain_community.document_loaders import PyPDFLoader
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
25
  from langchain_community.tools import DuckDuckGoSearchRun
26
  from pathlib import Path
27
 
28
+ from huggingface_hub import login
29
+ login(token=os.getenv("HUGGING_FACE_TOKEN"))
30
+
31
  def load_chunk_persist_pdf() -> Chroma:
32
 
33
  pdf_folder_path = os.path.join(os.getcwd(),Path("data/pdf/"))
 
42
  os.makedirs("data/chroma_store/", exist_ok=True)
43
  vectorstore = Chroma.from_documents(
44
  documents=chunked_documents,
45
+ embedding=MistralAIEmbeddings(api_key=mistral_api_key),
46
  persist_directory= os.path.join(os.getcwd(),Path("data/chroma_store/"))
47
  )
48
  vectorstore.persist()
app.py CHANGED
@@ -16,13 +16,6 @@ mistral_api_key = os.getenv("MISTRAL_API_KEY")
16
  from Modules.PoseEstimation import pose_estimator
17
  from utils import save_uploaded_file
18
 
19
- def format_messages(messages):
20
- formatted_messages = ""
21
- for message in messages:
22
- role = message["role"]
23
- content = message["content"]
24
- formatted_messages += f"{role}: {content}\n"
25
- return formatted_messages
26
 
27
  st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
28
  # Create two columns
@@ -41,41 +34,31 @@ base_chain = prompt | llm
41
 
42
  # First column containers
43
  with col1:
44
- st.subheader("Audio Recorder")
45
- recorded = False
46
- temp_path = 'data/temp_audio/audio_file.wav'
47
- wav_audio_data = st_audiorec()
48
- if wav_audio_data is not None:
49
- with open(temp_path, 'wb') as f:
50
- # Write the audio data to the file
51
- f.write(wav_audio_data)
52
- instruction = transcribe(temp_path)
53
- print(instruction)
54
- recorded = True
55
 
 
 
56
 
57
- st.subheader("LLM answering")
58
- if recorded:
59
- if "messages" not in st.session_state:
60
- st.session_state.messages = []
61
- for message in st.session_state.messages:
62
- with st.chat_message(message["role"]):
63
- st.markdown(message["content"])
64
 
65
- st.session_state.messages.append({"role": "user", "content": instruction})
 
66
  with st.chat_message("user"):
67
- st.markdown(instruction)
68
 
69
  with st.chat_message("assistant"):
70
  # Build answer from LLM
71
- direction = router_chain.invoke({"question":instruction})
72
  if direction=='fitness_advices':
73
  response = rag_chain.invoke(
74
- instruction
75
  )
76
  elif direction=='smalltalk':
77
  response = base_chain.invoke(
78
- {"question":instruction}
79
  ).content
80
  # elif direction =='movement_analysis':
81
  # response = agent_executor.invoke(
@@ -100,13 +83,14 @@ with col2:
100
  ask_video.empty()
101
  _left, mid, _right = st.columns(3)
102
  with mid:
103
- st.video(video_uploaded)
104
- apply_pose = st.button("Apply Pose Estimation")
105
-
106
- if apply_pose:
107
- with st.spinner("Processing video"):
108
- keypoints = pose_estimator.get_keypoints_from_keypoints(pose_estimator.model, video_uploaded)
109
 
110
 
111
  st.subheader("Graph Displayer")
112
- # TO DO
 
 
 
 
16
  from Modules.PoseEstimation import pose_estimator
17
  from utils import save_uploaded_file
18
 
 
 
 
 
 
 
 
19
 
20
  st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
21
  # Create two columns
 
34
 
35
  # First column containers
36
  with col1:
37
+
38
+ st.subheader("LLM answering")
 
 
 
 
 
 
 
 
 
39
 
40
+ if "messages" not in st.session_state:
41
+ st.session_state.messages = []
42
 
43
+ for message in st.session_state.messages:
44
+ with st.chat_message(message["role"]):
45
+ st.markdown(message["content"])
 
 
 
 
46
 
47
+ if prompt := st.chat_input("What is up?"):
48
+ st.session_state.messages.append({"role": "user", "content": prompt})
49
  with st.chat_message("user"):
50
+ st.markdown(prompt)
51
 
52
  with st.chat_message("assistant"):
53
  # Build answer from LLM
54
+ direction = router_chain.invoke({"question":prompt})
55
  if direction=='fitness_advices':
56
  response = rag_chain.invoke(
57
+ prompt
58
  )
59
  elif direction=='smalltalk':
60
  response = base_chain.invoke(
61
+ {"question":prompt}
62
  ).content
63
  # elif direction =='movement_analysis':
64
  # response = agent_executor.invoke(
 
83
  ask_video.empty()
84
  _left, mid, _right = st.columns(3)
85
  with mid:
86
+ if os.path.exists('runs'):
87
+ st.video(os.path.join('runs', 'pose', 'predict', 'squat.mp4'), loop=True)
88
+ else :
89
+ st.video(video_uploaded)
 
 
90
 
91
 
92
  st.subheader("Graph Displayer")
93
+ if os.path.exists('fig'):
94
+ file_list = os.listdir('fig')
95
+ for file in file_list:
96
+ st.image(os.path.join('fig', file))
requirements.txt CHANGED
@@ -13,4 +13,7 @@ chromadb
13
  langgraph
14
  langchainhub
15
  pypdf
16
- duckduckgo-search
 
 
 
 
13
  langgraph
14
  langchainhub
15
  pypdf
16
+ duckduckgo-search
17
+ python-dotenv
18
+ pypdf
19
+ chromadb