bishmoy commited on
Commit
0d2d09d
·
verified ·
1 Parent(s): 3c305fd

adjusted number of retrievals and llm inputs

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -4,6 +4,9 @@ import gradio as gr
4
  from ragatouille import RAGPretrainedModel
5
  from huggingface_hub import InferenceClient
6
 
 
 
 
7
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
8
 
9
  generate_kwargs = dict(
@@ -36,7 +39,7 @@ def get_prompt_text(question, context, formatted = True):
36
  return f"<s>" + f"[INST] {sys_instruction} " + f" {message} [/INST] </s> "
37
  return f"Context:\n {context} \n Given the following info, take a deep breath and lets think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
38
 
39
- def get_references(question, retriever, k = 10):
40
  rag_out = retriever.search(query=question, k=k)
41
  return rag_out
42
 
@@ -59,7 +62,7 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
59
  paper_title = f'''### [{title}](https://arxiv.org/abs/{rag_answer['document_id']})\n'''
60
  paper_abs = rag_answer['content']
61
  md_text_updated += paper_title + paper_abs + '\n---------------\n'+ '\n'
62
- prompt = get_prompt_text(message, '\n\n'.join(rag_cleaner(out) for out in rag_out))
63
  return md_text_updated, prompt
64
 
65
  def ask_llm(prompt):
 
4
  from ragatouille import RAGPretrainedModel
5
  from huggingface_hub import InferenceClient
6
 
7
+ retrieve_results = 10
8
+ llm_results = 5
9
+
10
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
11
 
12
  generate_kwargs = dict(
 
39
  return f"<s>" + f"[INST] {sys_instruction} " + f" {message} [/INST] </s> "
40
  return f"Context:\n {context} \n Given the following info, take a deep breath and lets think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
41
 
42
+ def get_references(question, retriever, k = retrieve_results):
43
  rag_out = retriever.search(query=question, k=k)
44
  return rag_out
45
 
 
62
  paper_title = f'''### [{title}](https://arxiv.org/abs/{rag_answer['document_id']})\n'''
63
  paper_abs = rag_answer['content']
64
  md_text_updated += paper_title + paper_abs + '\n---------------\n'+ '\n'
65
+ prompt = get_prompt_text(message, '\n\n'.join(rag_cleaner(out) for out in rag_out[:llm_results]))
66
  return md_text_updated, prompt
67
 
68
  def ask_llm(prompt):