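"""Streamlit app: RAG with Llama3 on Groq, built on phidata's Assistant.

Users pick an LLM and embeddings model in the sidebar, add websites or PDFs to
a knowledge base, and chat with an assistant that answers from those documents.
"""
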
from typing import List

import streamlit as st
from phi.assistant import Assistant
from phi.document import Document
from phi.document.reader.pdf import PDFReader
from phi.document.reader.website import WebsiteReader
from phi.utils.log import logger

from assistant import get_groq_assistant  # type: ignore

st.set_page_config(
    page_title="ISW RAG",
    page_icon=":books:",
)
st.title("RAG with Llama3 on Groq")
st.markdown("Built at ISW")
import os

from groq import Groq

client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)
chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": "Explain the importance of fast language models",
        }
    ],
    model="llama3-8b-8192",
)
print(chat_completion.choices[0].message.content)
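
# Clear the stored assistant and run, and bump the widget keys so the URL
# input and file uploader are re-created empty on the next script run.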
def restart_assistant():
    st.session_state["rag_assistant"] = None
    st.session_state["rag_assistant_run_id"] = None
    if "url_scrape_key" in st.session_state:
        st.session_state["url_scrape_key"] += 1
    if "file_uploader_key" in st.session_state:
        st.session_state["file_uploader_key"] += 1
    st.rerun()


def main() -> None:
    # Get LLM model
    llm_model = st.sidebar.selectbox("Select LLM", options=["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768"])
    # Set llm_model in session state
    if "llm_model" not in st.session_state:
        st.session_state["llm_model"] = llm_model
    # Restart the assistant if llm_model has changed
    elif st.session_state["llm_model"] != llm_model:
        st.session_state["llm_model"] = llm_model
        restart_assistant()

    # Get Embeddings model
    embeddings_model = st.sidebar.selectbox(
        "Select Embeddings",
        options=["nomic-embed-text", "text-embedding-3-small"],
        help="When you change the embeddings model, the documents will need to be added again.",
    )
    # Set embeddings_model in session state
    if "embeddings_model" not in st.session_state:
        st.session_state["embeddings_model"] = embeddings_model
    # Restart the assistant if embeddings_model has changed
    elif st.session_state["embeddings_model"] != embeddings_model:
        st.session_state["embeddings_model"] = embeddings_model
        st.session_state["embeddings_model_updated"] = True
        restart_assistant()

    # Get the assistant
    rag_assistant: Assistant
    if "rag_assistant" not in st.session_state or st.session_state["rag_assistant"] is None:
        logger.info(f"---*--- Creating {llm_model} Assistant ---*---")
        rag_assistant = get_groq_assistant(llm_model=llm_model, embeddings_model=embeddings_model)
        st.session_state["rag_assistant"] = rag_assistant
    else:
        rag_assistant = st.session_state["rag_assistant"]

    # Create assistant run (i.e. log to database) and save run_id in session state
    try:
        st.session_state["rag_assistant_run_id"] = rag_assistant.create_run()
    except Exception:
        st.warning("Could not create assistant, is the database running?")
        return
    # Load existing messages
    assistant_chat_history = rag_assistant.memory.get_chat_history()
    if len(assistant_chat_history) > 0:
        logger.debug("Loading chat history")
        st.session_state["messages"] = assistant_chat_history
    else:
        logger.debug("No chat history found")
        st.session_state["messages"] = [{"role": "assistant", "content": "Upload a doc and ask me questions..."}]

    # Prompt for user input
    if prompt := st.chat_input():
        st.session_state["messages"].append({"role": "user", "content": prompt})

    # Display existing chat messages
    for message in st.session_state["messages"]:
        if message["role"] == "system":
            continue
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # If last message is from a user, generate a new response
    last_message = st.session_state["messages"][-1]
    if last_message.get("role") == "user":
        question = last_message["content"]
        with st.chat_message("assistant"):
            response = ""
            resp_container = st.empty()
            for delta in rag_assistant.run(question):
                response += delta  # type: ignore
                resp_container.markdown(response)
            st.session_state["messages"].append({"role": "assistant", "content": response})
    # Load knowledge base
    if rag_assistant.knowledge_base:
        # -*- Add websites to knowledge base
        if "url_scrape_key" not in st.session_state:
            st.session_state["url_scrape_key"] = 0

        input_url = st.sidebar.text_input(
            "Add URL to Knowledge Base", type="default", key=st.session_state["url_scrape_key"]
        )
        add_url_button = st.sidebar.button("Add URL")
        if add_url_button:
            if input_url is not None:
                alert = st.sidebar.info("Processing URLs...", icon="ℹ️")
                if f"{input_url}_scraped" not in st.session_state:
                    scraper = WebsiteReader(max_links=2, max_depth=1)
                    web_documents: List[Document] = scraper.read(input_url)
                    if web_documents:
                        rag_assistant.knowledge_base.load_documents(web_documents, upsert=True)
                    else:
                        st.sidebar.error("Could not read website")
st.session_state[f"{input_url}_uploaded"] = True | |
                alert.empty()
        # Add PDFs to knowledge base
        if "file_uploader_key" not in st.session_state:
            st.session_state["file_uploader_key"] = 100

        uploaded_file = st.sidebar.file_uploader(
            "Add a PDF :page_facing_up:", type="pdf", key=st.session_state["file_uploader_key"]
        )
        if uploaded_file is not None:
            alert = st.sidebar.info("Processing PDF...", icon="🧠")
            rag_name = uploaded_file.name.split(".")[0]
            if f"{rag_name}_uploaded" not in st.session_state:
                reader = PDFReader()
                rag_documents: List[Document] = reader.read(uploaded_file)
                if rag_documents:
                    rag_assistant.knowledge_base.load_documents(rag_documents, upsert=True)
                else:
                    st.sidebar.error("Could not read PDF")
                st.session_state[f"{rag_name}_uploaded"] = True
            alert.empty()
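
    # Sidebar button to wipe the vector store backing the knowledge base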
    if rag_assistant.knowledge_base and rag_assistant.knowledge_base.vector_db:
        if st.sidebar.button("Clear Knowledge Base"):
            rag_assistant.knowledge_base.vector_db.clear()
            st.sidebar.success("Knowledge base cleared")
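
    # If runs are persisted in storage, let the user reload a previous run by ID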
    if rag_assistant.storage:
        rag_assistant_run_ids: List[str] = rag_assistant.storage.get_all_run_ids()
        new_rag_assistant_run_id = st.sidebar.selectbox("Run ID", options=rag_assistant_run_ids)
        if st.session_state["rag_assistant_run_id"] != new_rag_assistant_run_id:
            logger.info(f"---*--- Loading {llm_model} run: {new_rag_assistant_run_id} ---*---")
            st.session_state["rag_assistant"] = get_groq_assistant(
                llm_model=llm_model, embeddings_model=embeddings_model, run_id=new_rag_assistant_run_id
            )
            st.rerun()

    if st.sidebar.button("New Run"):
        restart_assistant()

    if "embeddings_model_updated" in st.session_state:
        st.sidebar.info("Please add documents again as the embeddings model has changed.")
        st.session_state["embeddings_model_updated"] = False


main()