TahaRasouli commited on
Commit
9310453
1 Parent(s): dc4a67d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -5
app.py CHANGED
@@ -1,7 +1,191 @@
1
- from fastapi import FastAPI
2
 
3
- app = FastAPI()
 
 
 
 
 
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
 
3
+ import streamlit as st
4
+ from phi.assistant import Assistant
5
+ from phi.document import Document
6
+ from phi.document.reader.pdf import PDFReader
7
+ from phi.document.reader.website import WebsiteReader
8
+ from phi.utils.log import logger
9
 
10
+ from assistant import get_groq_assistant # type: ignore
11
+
12
+ st.set_page_config(
13
+ page_title="ISW RAG",
14
+ page_icon=":books:",
15
+ )
16
+ st.title("RAG with Llama3 on Groq")
17
+ st.markdown("Built at ISW")
18
+
19
+ import os
20
+
21
+ from groq import Groq
22
+
23
+ client = Groq(
24
+ api_key=os.environ.get("GROQ_API_KEY"),
25
+ )
26
+
27
+ chat_completion = client.chat.completions.create(
28
+ messages=[
29
+ {
30
+ "role": "user",
31
+ "content": "Explain the importance of fast language models",
32
+ }
33
+ ],
34
+ model="llama3-8b-8192",
35
+ )
36
+
37
+ print(chat_completion.choices[0].message.content)
38
+
39
+ print(chat_completion.choices[0].message.content)
40
+
41
+ def restart_assistant():
42
+ st.session_state["rag_assistant"] = None
43
+ st.session_state["rag_assistant_run_id"] = None
44
+ if "url_scrape_key" in st.session_state:
45
+ st.session_state["url_scrape_key"] += 1
46
+ if "file_uploader_key" in st.session_state:
47
+ st.session_state["file_uploader_key"] += 1
48
+ st.rerun()
49
+
50
+
51
+ def main() -> None:
52
+ # Get LLM model
53
+ llm_model = st.sidebar.selectbox("Select LLM", options=["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768"])
54
+ # Set assistant_type in session state
55
+ if "llm_model" not in st.session_state:
56
+ st.session_state["llm_model"] = llm_model
57
+ # Restart the assistant if assistant_type has changed
58
+ elif st.session_state["llm_model"] != llm_model:
59
+ st.session_state["llm_model"] = llm_model
60
+ restart_assistant()
61
+
62
+ # Get Embeddings model
63
+ embeddings_model = st.sidebar.selectbox(
64
+ "Select Embeddings",
65
+ options=["nomic-embed-text", "text-embedding-3-small"],
66
+ help="When you change the embeddings model, the documents will need to be added again.",
67
+ )
68
+ # Set assistant_type in session state
69
+ if "embeddings_model" not in st.session_state:
70
+ st.session_state["embeddings_model"] = embeddings_model
71
+ # Restart the assistant if assistant_type has changed
72
+ elif st.session_state["embeddings_model"] != embeddings_model:
73
+ st.session_state["embeddings_model"] = embeddings_model
74
+ st.session_state["embeddings_model_updated"] = True
75
+ restart_assistant()
76
+
77
+ # Get the assistant
78
+ rag_assistant: Assistant
79
+ if "rag_assistant" not in st.session_state or st.session_state["rag_assistant"] is None:
80
+ logger.info(f"---*--- Creating {llm_model} Assistant ---*---")
81
+ rag_assistant = get_groq_assistant(llm_model=llm_model, embeddings_model=embeddings_model)
82
+ st.session_state["rag_assistant"] = rag_assistant
83
+ else:
84
+ rag_assistant = st.session_state["rag_assistant"]
85
+
86
+ # Create assistant run (i.e. log to database) and save run_id in session state
87
+ try:
88
+ st.session_state["rag_assistant_run_id"] = rag_assistant.create_run()
89
+ except Exception:
90
+ st.warning("Could not create assistant, is the database running?")
91
+ return
92
+
93
+ # Load existing messages
94
+ assistant_chat_history = rag_assistant.memory.get_chat_history()
95
+ if len(assistant_chat_history) > 0:
96
+ logger.debug("Loading chat history")
97
+ st.session_state["messages"] = assistant_chat_history
98
+ else:
99
+ logger.debug("No chat history found")
100
+ st.session_state["messages"] = [{"role": "assistant", "content": "Upload a doc and ask me questions..."}]
101
+
102
+ # Prompt for user input
103
+ if prompt := st.chat_input():
104
+ st.session_state["messages"].append({"role": "user", "content": prompt})
105
+
106
+ # Display existing chat messages
107
+ for message in st.session_state["messages"]:
108
+ if message["role"] == "system":
109
+ continue
110
+ with st.chat_message(message["role"]):
111
+ st.write(message["content"])
112
+
113
+ # If last message is from a user, generate a new response
114
+ last_message = st.session_state["messages"][-1]
115
+ if last_message.get("role") == "user":
116
+ question = last_message["content"]
117
+ with st.chat_message("assistant"):
118
+ response = ""
119
+ resp_container = st.empty()
120
+ for delta in rag_assistant.run(question):
121
+ response += delta # type: ignore
122
+ resp_container.markdown(response)
123
+ st.session_state["messages"].append({"role": "assistant", "content": response})
124
+
125
+ # Load knowledge base
126
+ if rag_assistant.knowledge_base:
127
+ # -*- Add websites to knowledge base
128
+ if "url_scrape_key" not in st.session_state:
129
+ st.session_state["url_scrape_key"] = 0
130
+
131
+ input_url = st.sidebar.text_input(
132
+ "Add URL to Knowledge Base", type="default", key=st.session_state["url_scrape_key"]
133
+ )
134
+ add_url_button = st.sidebar.button("Add URL")
135
+ if add_url_button:
136
+ if input_url is not None:
137
+ alert = st.sidebar.info("Processing URLs...", icon="ℹ️")
138
+ if f"{input_url}_scraped" not in st.session_state:
139
+ scraper = WebsiteReader(max_links=2, max_depth=1)
140
+ web_documents: List[Document] = scraper.read(input_url)
141
+ if web_documents:
142
+ rag_assistant.knowledge_base.load_documents(web_documents, upsert=True)
143
+ else:
144
+ st.sidebar.error("Could not read website")
145
+ st.session_state[f"{input_url}_uploaded"] = True
146
+ alert.empty()
147
+
148
+ # Add PDFs to knowledge base
149
+ if "file_uploader_key" not in st.session_state:
150
+ st.session_state["file_uploader_key"] = 100
151
+
152
+ uploaded_file = st.sidebar.file_uploader(
153
+ "Add a PDF :page_facing_up:", type="pdf", key=st.session_state["file_uploader_key"]
154
+ )
155
+ if uploaded_file is not None:
156
+ alert = st.sidebar.info("Processing PDF...", icon="🧠")
157
+ rag_name = uploaded_file.name.split(".")[0]
158
+ if f"{rag_name}_uploaded" not in st.session_state:
159
+ reader = PDFReader()
160
+ rag_documents: List[Document] = reader.read(uploaded_file)
161
+ if rag_documents:
162
+ rag_assistant.knowledge_base.load_documents(rag_documents, upsert=True)
163
+ else:
164
+ st.sidebar.error("Could not read PDF")
165
+ st.session_state[f"{rag_name}_uploaded"] = True
166
+ alert.empty()
167
+
168
+ if rag_assistant.knowledge_base and rag_assistant.knowledge_base.vector_db:
169
+ if st.sidebar.button("Clear Knowledge Base"):
170
+ rag_assistant.knowledge_base.vector_db.clear()
171
+ st.sidebar.success("Knowledge base cleared")
172
+
173
+ if rag_assistant.storage:
174
+ rag_assistant_run_ids: List[str] = rag_assistant.storage.get_all_run_ids()
175
+ new_rag_assistant_run_id = st.sidebar.selectbox("Run ID", options=rag_assistant_run_ids)
176
+ if st.session_state["rag_assistant_run_id"] != new_rag_assistant_run_id:
177
+ logger.info(f"---*--- Loading {llm_model} run: {new_rag_assistant_run_id} ---*---")
178
+ st.session_state["rag_assistant"] = get_groq_assistant(
179
+ llm_model=llm_model, embeddings_model=embeddings_model, run_id=new_rag_assistant_run_id
180
+ )
181
+ st.rerun()
182
+
183
+ if st.sidebar.button("New Run"):
184
+ restart_assistant()
185
+
186
+ if "embeddings_model_updated" in st.session_state:
187
+ st.sidebar.info("Please add documents again as the embeddings model has changed.")
188
+ st.session_state["embeddings_model_updated"] = False
189
+
190
+
191
+ main()