bishmoy commited on
Commit
5628a76
·
verified ·
1 Parent(s): 2e45345

code clean up

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -24,11 +24,13 @@ try:
24
  gr.Info("Setting up retriever, please wait...")
25
  rag_initial_output = RAG.search("what is Mistral?", k = 1)
26
  gr.Info("Retriever working successfully!")
 
27
  except:
28
  gr.Warning("Retriever not working!")
29
 
30
  mark_text = '# 🔍 Search Results\n'
31
  header_text = "# ArXivCS RAG \n"
 
32
  try:
33
  with open("README.md", "r") as f:
34
  mdfile = f.read()
@@ -37,6 +39,7 @@ try:
37
  date = match.group().split(': ')[1]
38
  formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
39
  header_text += f'Index Last Updated: {formatted_date}\n'
 
40
  except:
41
  pass
42
 
@@ -45,6 +48,7 @@ if show_examples:
45
  sample_outputs = json.load(f)
46
  output_placeholder = sample_outputs['output_placeholder']
47
  md_text_initial = sample_outputs['search_placeholder']
 
48
  else:
49
  output_placeholder = None
50
  md_text_initial = ''
@@ -61,6 +65,7 @@ def get_prompt_text(question, context, formatted = True, llm_model_picked = 'mis
61
  if formatted:
62
  sys_instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and lets think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
63
  message = f"Question: {question}"
 
64
  if 'mistralai' in llm_model_picked:
65
  return f"<s>" + f"[INST] {sys_instruction}" + f" {message}[/INST]"
66
 
@@ -74,12 +79,14 @@ def get_references(question, retriever, k = retrieve_results):
74
  return rag_out
75
 
76
  def get_rag(message):
77
- return get_references(message, RAG)
78
 
79
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
80
  header = gr.Markdown(header_text)
 
81
  with gr.Group():
82
  msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
 
83
  with gr.Accordion("Advanced Settings", open=False):
84
  with gr.Row(equal_height = True):
85
  llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
@@ -97,7 +104,6 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
97
  rag_answer = rag_out[i]
98
  title = rag_answer['document_metadata']['title'].replace('\n','')
99
 
100
- #score = round(rag_answer['score'], 2)
101
  date = rag_answer['document_metadata']['_time']
102
  paper_title = f'''### {date} | [{title}](https://arxiv.org/abs/{rag_answer['document_id']}) | [⬇️](https://arxiv.org/pdf/{rag_answer['document_id']})\n'''
103
  paper_abs = rag_answer['content']
@@ -111,6 +117,7 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
111
  def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
112
  model_disabled_text = "LLM Model is disabled"
113
  output = ""
 
114
  if llm_model_picked == 'None':
115
  if stream_outputs:
116
  for out in model_disabled_text:
@@ -123,10 +130,10 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
123
  client = InferenceClient(llm_model_picked)
124
  try:
125
  stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
 
126
  except:
127
  gr.Warning("LLM Inference rate limit reached, try again later!")
128
  return ""
129
- #output = output.lstrip(' \n') if output.lstrip().startswith('\n') else output
130
 
131
  if stream_outputs:
132
  for response in stream:
@@ -139,4 +146,4 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
139
 
140
  msg.submit(update_with_rag_md, [msg, llm_results, llm_model], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
141
 
142
- demo.queue(default_concurrency_limit=10).launch()
 
24
  gr.Info("Setting up retriever, please wait...")
25
  rag_initial_output = RAG.search("what is Mistral?", k = 1)
26
  gr.Info("Retriever working successfully!")
27
+
28
  except:
29
  gr.Warning("Retriever not working!")
30
 
31
  mark_text = '# 🔍 Search Results\n'
32
  header_text = "# ArXivCS RAG \n"
33
+
34
  try:
35
  with open("README.md", "r") as f:
36
  mdfile = f.read()
 
39
  date = match.group().split(': ')[1]
40
  formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
41
  header_text += f'Index Last Updated: {formatted_date}\n'
42
+
43
  except:
44
  pass
45
 
 
48
  sample_outputs = json.load(f)
49
  output_placeholder = sample_outputs['output_placeholder']
50
  md_text_initial = sample_outputs['search_placeholder']
51
+
52
  else:
53
  output_placeholder = None
54
  md_text_initial = ''
 
65
  if formatted:
66
  sys_instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and lets think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
67
  message = f"Question: {question}"
68
+
69
  if 'mistralai' in llm_model_picked:
70
  return f"<s>" + f"[INST] {sys_instruction}" + f" {message}[/INST]"
71
 
 
79
  return rag_out
80
 
81
  def get_rag(message):
82
+ return get_references(message, RAG)
83
 
84
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
85
  header = gr.Markdown(header_text)
86
+
87
  with gr.Group():
88
  msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
89
+
90
  with gr.Accordion("Advanced Settings", open=False):
91
  with gr.Row(equal_height = True):
92
  llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
 
104
  rag_answer = rag_out[i]
105
  title = rag_answer['document_metadata']['title'].replace('\n','')
106
 
 
107
  date = rag_answer['document_metadata']['_time']
108
  paper_title = f'''### {date} | [{title}](https://arxiv.org/abs/{rag_answer['document_id']}) | [⬇️](https://arxiv.org/pdf/{rag_answer['document_id']})\n'''
109
  paper_abs = rag_answer['content']
 
117
  def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
118
  model_disabled_text = "LLM Model is disabled"
119
  output = ""
120
+
121
  if llm_model_picked == 'None':
122
  if stream_outputs:
123
  for out in model_disabled_text:
 
130
  client = InferenceClient(llm_model_picked)
131
  try:
132
  stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
133
+
134
  except:
135
  gr.Warning("LLM Inference rate limit reached, try again later!")
136
  return ""
 
137
 
138
  if stream_outputs:
139
  for response in stream:
 
146
 
147
  msg.submit(update_with_rag_md, [msg, llm_results, llm_model], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
148
 
149
+ demo.queue().launch()