paper_reader / app.py
ayut's picture
display trace url
34ccb7c
raw
history blame
1.99 kB
import os
# Fail fast with a readable message when a required API key is missing.
# (Re-assigning os.environ[k] = os.getenv(k) does nothing when the key is set,
# and raises an opaque TypeError when it is not, because getenv returns None.)
for _required_key in ("OPENAI_API_KEY", "WANDB_API_KEY"):
    if _required_key not in os.environ:
        raise RuntimeError(
            f"Missing required environment variable {_required_key!r}; "
            "set it before starting the app."
        )
import streamlit as st
import weave
from rag.rag import SimpleRAGPipeline
# Browser-tab title/icon and layout for the app.
# page_icon is the llama emoji (the literal was previously mis-encoded UTF-8).
st.set_page_config(
    page_title="Chat with the Llama 3 paper!",
    page_icon="🦙",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
# NOTE(review): every browser session gets the same hard-coded id, and nothing
# visible in this file reads it -- confirm whether per-session ids are intended.
st.session_state["session_id"] = "123abc"

WANDB_PROJECT = "paper_reader"
# weave client for this project; used further down to attach user feedback
# (reactions) to traced pipeline calls.
weave_client = weave.init(WANDB_PROJECT)
# Page heading (speech-balloon + llama emoji; previously mis-encoded UTF-8).
st.title("Chat with the Llama 3 paper 💬🦙")
@st.cache_resource(show_spinner=False)
def load_rag_pipeline():
    """Build the RAG pipeline and its query engine.

    Wrapped in ``st.cache_resource`` so the build runs once per server process
    and the same object is reused across reruns, instead of rebuilding on
    every script execution.

    Returns:
        The ``SimpleRAGPipeline`` instance with its query engine built.
    """
    rag_pipeline = SimpleRAGPipeline()
    rag_pipeline.build_query_engine()
    return rag_pipeline


# Keep a per-session handle; only show the spinner when the session does not
# have a pipeline yet (the first load for this session).
if "rag_pipeline" not in st.session_state:
    with st.spinner("Loading the RAG pipeline..."):
        st.session_state.rag_pipeline = load_rag_pipeline()
rag_pipeline = st.session_state["rag_pipeline"]
# Question entry form. Widgets inside a form do not trigger a rerun until the
# submit button is pressed, so the pipeline only runs on explicit submission.
question_form = st.form("my_form")
query = question_form.text_area("Ask your question about the Llama 3 paper here:")
submitted = question_form.form_submit_button("Submit")
# On submission: run the pipeline, collect the full answer text, and stash both
# the raw output and the text in session state. The answer itself is rendered
# by the display section that follows, which also re-renders it on later reruns.
if submitted:
    if not query.strip():
        # Don't send an empty question to the pipeline.
        st.warning("Please enter a question before submitting.")
    else:
        with st.spinner("Generating answer..."):
            output = rag_pipeline.predict(query)
            # NOTE(review): response_gen is assumed to be a one-shot token
            # generator (it is iterated here). Drain it exactly once; passing it
            # to st.write_stream afterwards would stream nothing.
            text = "".join(output["response"].response_gen)
        st.session_state["last_output"] = output
        st.session_state["last_text"] = text
def _add_feedback_reaction(call_id, emoji):
    """Attach ``emoji`` as a feedback reaction to the weave call ``call_id``."""
    weave_client.call(call_id).feedback.add_reaction(emoji)


# Render the most recent answer (persisted in session state so it survives
# reruns, e.g. the rerun triggered by clicking a feedback button).
if "last_output" in st.session_state:
    output = st.session_state["last_output"]
    text = st.session_state["last_text"]
    st.write(text)
    url = output["url"]
    st.info(f"The weave trace url: {url}", icon="ℹ️")
    # use the weave client to retrieve the call and attach feedback.
    # The call id is bound explicitly via `args` rather than read from the
    # module-global `output` at click time, so each button always targets the
    # call it was rendered for.
    call_id = output["call_id"]
    st.button(
        ":thumbsup:",
        on_click=_add_feedback_reaction,
        args=(call_id, "👍"),
        key="up",
    )
    st.button(
        ":thumbsdown:",
        on_click=_add_feedback_reaction,
        args=(call_id, "👎"),
        key="down",
    )