ayut commited on
Commit
c35885d
β€’
1 Parent(s): c0fdffe
Files changed (2) hide show
  1. app.py +25 -10
  2. rag/rag.py +1 -1
app.py CHANGED
@@ -15,9 +15,11 @@ st.set_page_config(
15
  menu_items=None,
16
  )
17
 
 
 
18
  WANDB_PROJECT = "paper_reader"
19
 
20
- weave.init(f"{WANDB_PROJECT}")
21
 
22
  st.title("Chat with the Llama 3 paper πŸ’¬πŸ¦™")
23
 
@@ -33,16 +35,29 @@ with st.spinner('Loading the RAG pipeline...'):
33
  if "rag_pipeline" not in st.session_state.keys():
34
  st.session_state.rag_pipeline = load_rag_pipeline()
35
 
36
- rag_pipeline = st.session_state["rag_pipeline"]
37
-
38
-
39
- def generate_response(query):
40
- response = rag_pipeline.predict(query)
41
- st.write_stream(response.response_gen)
42
-
43
 
44
  with st.form("my_form"):
45
  query = st.text_area("Ask your question about the Llama 3 paper here:")
46
  submitted = st.form_submit_button("Submit")
47
- if submitted:
48
- generate_response(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  menu_items=None,
16
  )
17
 
18
+ st.session_state['session_id'] = '123abc'
19
+
20
  WANDB_PROJECT = "paper_reader"
21
 
22
+ weave_client = weave.init(f"{WANDB_PROJECT}")
23
 
24
  st.title("Chat with the Llama 3 paper πŸ’¬πŸ¦™")
25
 
 
35
  if "rag_pipeline" not in st.session_state.keys():
36
  st.session_state.rag_pipeline = load_rag_pipeline()
37
 
38
+ rag_pipeline = st.session_state["rag_pipeline"]
 
 
 
 
 
 
39
 
40
  with st.form("my_form"):
41
  query = st.text_area("Ask your question about the Llama 3 paper here:")
42
  submitted = st.form_submit_button("Submit")
43
+
44
+ if submitted:
45
+ with st.spinner('Generating answer...'):
46
+ output = rag_pipeline.predict(query)
47
+ st.session_state["last_output"] = output
48
+ text = ""
49
+ for t in output["response"].response_gen:
50
+ text += t
51
+ st.session_state["last_text"] = text
52
+
53
+ st.write_stream(output["response"].response_gen)
54
+
55
+
56
+ if "last_output" in st.session_state:
57
+ output = st.session_state["last_output"]
58
+ text = st.session_state["last_text"]
59
+ st.write(text)
60
+
61
+ # use the weave client to retrieve the call and attach feedback
62
+ st.button(":thumbsup:", on_click=lambda: weave_client.call(output['call_id']).feedback.add_reaction("πŸ‘"), key='up')
63
+ st.button(":thumbsdown:", on_click=lambda: weave_client.call(output['call_id']).feedback.add_reaction("πŸ‘Ž"), key='down')
rag/rag.py CHANGED
@@ -150,7 +150,7 @@ class SimpleRAGPipeline(weave.Model):
150
  @weave.op()
151
  def predict(self, question: str):
152
  response = self.query_engine.query(question)
153
- return response
154
 
155
 
156
  if __name__ == "__main__":
 
150
  @weave.op()
151
  def predict(self, question: str):
152
  response = self.query_engine.query(question)
153
+ return {"response": response, 'call_id': weave.get_current_call().id}
154
 
155
 
156
  if __name__ == "__main__":