bishmoy commited on
Commit
7055307
·
verified ·
1 Parent(s): 7b3311b

Added download, streaming and initial placeholder

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -5,6 +5,7 @@ from ragatouille import RAGPretrainedModel
5
  from huggingface_hub import InferenceClient
6
  import re
7
  from datetime import datetime
 
8
 
9
  retrieve_results = 10
10
 
@@ -16,15 +17,16 @@ generate_kwargs = dict(
16
  )
17
 
18
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
 
19
  try:
20
  gr.Info("Setting up retriever, please wait...")
21
- _ = RAG.search("what is Mistral?", k = 1)
22
  gr.Info("Retriever working successfully!")
23
  except:
24
  gr.Warning("Retriever not working!")
25
 
26
  mark_text = '# 🔍 Search Results\n'
27
- header_text = "# ArXiv RAG\n"
28
  try:
29
  with open("README.md", "r") as f:
30
  mdfile = f.read()
@@ -36,6 +38,12 @@ try:
36
  except:
37
  pass
38
 
 
 
 
 
 
 
39
  def rag_cleaner(inp):
40
  rank = inp['rank']
41
  title = inp['document_metadata']['title']
@@ -59,15 +67,15 @@ def get_rag(message):
59
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
60
  header = gr.Markdown(header_text)
61
  with gr.Group():
62
- msg = gr.Textbox(label = 'Search')
63
  with gr.Accordion("Advanced Settings", open=False):
64
  with gr.Row(equal_height = True):
65
  llm_model = gr.Dropdown(choices = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'None'], value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
66
  llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results to sent as context")
67
 
68
- output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True)
69
  input = gr.Textbox(show_label = False, visible = False)
70
- gr_md = gr.Markdown(mark_text)
71
 
72
  def update_with_rag_md(message, llm_results_use = 5):
73
  rag_out = get_rag(message)
@@ -76,8 +84,9 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
76
  rag_answer = rag_out[i]
77
  title = rag_answer['document_metadata']['title'].replace('\n','')
78
 
79
- score = round(rag_answer['score'], 2)
80
- paper_title = f'''### **{score}** | [{title}](https://arxiv.org/abs/{rag_answer['document_id']})\n'''
 
81
  paper_abs = rag_answer['content']
82
  authors = rag_answer['document_metadata']['authors'].replace('\n','')
83
  authors_formatted = f'*{authors}*' + ' \n\n'
@@ -90,9 +99,16 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
90
  if llm_model_picked == 'None':
91
  return gr.Textbox(visible = False)
92
  client = InferenceClient(llm_model_picked)
93
- output = client.text_generation(prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
94
- output = output.lstrip(' \n') if output.lstrip().startswith('\n') else output
95
- return gr.Textbox(output, visible = True)
 
 
 
 
 
 
 
96
 
97
  msg.submit(update_with_rag_md, [msg, llm_results], [gr_md, input]).success(ask_llm, [input, llm_model], output_text)
98
 
 
5
  from huggingface_hub import InferenceClient
6
  import re
7
  from datetime import datetime
8
+ import json
9
 
10
  retrieve_results = 10
11
 
 
17
  )
18
 
19
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
20
+
21
  try:
22
  gr.Info("Setting up retriever, please wait...")
23
+ rag_initial_output = RAG.search("what is Mistral?", k = 1)
24
  gr.Info("Retriever working successfully!")
25
  except:
26
  gr.Warning("Retriever not working!")
27
 
28
  mark_text = '# 🔍 Search Results\n'
29
+ header_text = "# ArXivCS RAG \n"
30
  try:
31
  with open("README.md", "r") as f:
32
  mdfile = f.read()
 
38
  except:
39
  pass
40
 
41
+ with open("sample_outputs.json", "r") as f:
42
+ sample_outputs = json.load(f)
43
+ output_placeholder = sample_outputs['output_placeholder']
44
+ md_text_initial = sample_outputs['search_placeholder']
45
+
46
+
47
  def rag_cleaner(inp):
48
  rank = inp['rank']
49
  title = inp['document_metadata']['title']
 
67
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
68
  header = gr.Markdown(header_text)
69
  with gr.Group():
70
+ msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
71
  with gr.Accordion("Advanced Settings", open=False):
72
  with gr.Row(equal_height = True):
73
  llm_model = gr.Dropdown(choices = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'None'], value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
74
  llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results to sent as context")
75
 
76
+ output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
77
  input = gr.Textbox(show_label = False, visible = False)
78
+ gr_md = gr.Markdown(mark_text + md_text_initial)
79
 
80
  def update_with_rag_md(message, llm_results_use = 5):
81
  rag_out = get_rag(message)
 
84
  rag_answer = rag_out[i]
85
  title = rag_answer['document_metadata']['title'].replace('\n','')
86
 
87
+ #score = round(rag_answer['score'], 2)
88
+ date = rag_answer['document_metadata']['_time']
89
+ paper_title = f'''### {date} | [{title}](https://arxiv.org/abs/{rag_answer['document_id']}) | [⬇️](https://arxiv.org/pdf/{rag_answer['document_id']})\n'''
90
  paper_abs = rag_answer['content']
91
  authors = rag_answer['document_metadata']['authors'].replace('\n','')
92
  authors_formatted = f'*{authors}*' + ' \n\n'
 
99
  if llm_model_picked == 'None':
100
  return gr.Textbox(visible = False)
101
  client = InferenceClient(llm_model_picked)
102
+ #output = client.text_generation(prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
103
+ stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
104
+ #output = output.lstrip(' \n') if output.lstrip().startswith('\n') else output
105
+ output = ""
106
+
107
+ for response in stream:
108
+ output += response.token.text
109
+ yield output
110
+ return output
111
+ #return gr.Textbox(output, visible = True)
112
 
113
  msg.submit(update_with_rag_md, [msg, llm_results], [gr_md, input]).success(ask_llm, [input, llm_model], output_text)
114