bishmoy commited on
Commit
838a4fb
·
verified ·
1 Parent(s): d631d86

added option to disable streaming output

Browse files
Files changed (1) hide show
  1. app.py +21 -16
app.py CHANGED
@@ -76,7 +76,8 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
76
  with gr.Accordion("Advanced Settings", open=False):
77
  with gr.Row(equal_height = True):
78
  llm_model = gr.Dropdown(choices = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'None'], value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
79
- llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results to sent as context")
 
80
 
81
  output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
82
  input = gr.Textbox(show_label = False, visible = False)
@@ -100,31 +101,35 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
100
  prompt = get_prompt_text(message, '\n\n'.join(rag_cleaner(out) for out in rag_out[:llm_results_use]))
101
  return md_text_updated, prompt
102
 
103
- def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
104
  model_disabled_text = "LLM Model is disabled"
105
  output = ""
106
  if llm_model_picked == 'None':
107
- for out in model_disabled_text:
108
- output += out
109
- yield output
110
- return output
111
-
 
 
 
112
  client = InferenceClient(llm_model_picked)
113
- #output = client.text_generation(prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
114
  try:
115
- stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
116
  except:
117
  gr.Warning("LLM Inference rate limit reached, try again later!")
118
  return ""
119
  #output = output.lstrip(' \n') if output.lstrip().startswith('\n') else output
120
 
 
 
 
 
 
 
 
121
 
122
- for response in stream:
123
- output += response.token.text
124
- yield output
125
- return output
126
- #return gr.Textbox(output, visible = True)
127
 
128
- msg.submit(update_with_rag_md, [msg, llm_results], [gr_md, input]).success(ask_llm, [input, llm_model], output_text)
129
 
130
- demo.launch(debug = True)
 
76
  with gr.Accordion("Advanced Settings", open=False):
77
  with gr.Row(equal_height = True):
78
  llm_model = gr.Dropdown(choices = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'None'], value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
79
+ llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
80
+ stream_results = gr.Checkbox(value = True, label = "Stream output")
81
 
82
  output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
83
  input = gr.Textbox(show_label = False, visible = False)
 
101
  prompt = get_prompt_text(message, '\n\n'.join(rag_cleaner(out) for out in rag_out[:llm_results_use]))
102
  return md_text_updated, prompt
103
 
104
+ def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
105
  model_disabled_text = "LLM Model is disabled"
106
  output = ""
107
  if llm_model_picked == 'None':
108
+ if stream_outputs:
109
+ for out in model_disabled_text:
110
+ output += out
111
+ yield output
112
+ return output
113
+ else:
114
+ return model_disabled_text
115
+
116
  client = InferenceClient(llm_model_picked)
 
117
  try:
118
+ stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
119
  except:
120
  gr.Warning("LLM Inference rate limit reached, try again later!")
121
  return ""
122
  #output = output.lstrip(' \n') if output.lstrip().startswith('\n') else output
123
 
124
+ if stream_outputs:
125
+ for response in stream:
126
+ output += response
127
+ yield output
128
+ return output
129
+ else:
130
+ return stream
131
 
 
 
 
 
 
132
 
133
+ msg.submit(update_with_rag_md, [msg, llm_results], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
134
 
135
+ demo.queue(default_concurrency_limit=10).launch()