Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -107,9 +107,40 @@ def reset_chat(idx, ld, state):
|
|
107 |
gr.update(interactive=False),
|
108 |
)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
async def chat_stream(
|
111 |
idx, local_data, instruction_txtbox, chat_state,
|
112 |
-
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
|
|
|
113 |
):
|
114 |
res = [
|
115 |
chat_state["ppmanager_type"].from_json(json.dumps(ppm))
|
@@ -121,6 +152,14 @@ async def chat_stream(
|
|
121 |
PingPong(instruction_txtbox, "")
|
122 |
)
|
123 |
prompt = build_prompts(ppm, global_context, ctx_num_lconv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
async for result in gen_text(
|
125 |
prompt, hf_model=MODEL_ID, hf_token=TOKEN,
|
126 |
parameters={
|
@@ -283,14 +322,15 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
283 |
elem_id="global-context"
|
284 |
)
|
285 |
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
294 |
|
295 |
gr.Markdown("#### GenConfig for **response** text generation")
|
296 |
with gr.Row():
|
@@ -315,7 +355,8 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
315 |
).then(
|
316 |
chat_stream,
|
317 |
[idx, local_data, instruction_txtbox, chat_state,
|
318 |
-
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
|
|
|
319 |
[instruction_txtbox, context_inspector, chatbot, local_data, regenerate]
|
320 |
).then(
|
321 |
None, local_data, None,
|
@@ -346,7 +387,8 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
346 |
regen_event = regenerate.click(
|
347 |
rollback_last,
|
348 |
[idx, local_data, chat_state,
|
349 |
-
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
|
|
|
350 |
[context_inspector, chatbot, local_data, regenerate]
|
351 |
).then(
|
352 |
None, local_data, None,
|
|
|
107 |
gr.update(interactive=False),
|
108 |
)
|
109 |
|
110 |
+
def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cpu"):
|
111 |
+
internet_search_ppm = copy.deepcopy(ppm)
|
112 |
+
internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, "
|
113 |
+
f"give me an appropriate query to answer my question for google search. "
|
114 |
+
f"You should not say more than query. You should not say any words except the query."
|
115 |
+
|
116 |
+
internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
|
117 |
+
internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
|
118 |
+
|
119 |
+
instruction = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
|
120 |
+
###
|
121 |
+
|
122 |
+
searcher = SimilaritySearcher.from_pretrained(device=device)
|
123 |
+
iss = InternetSearchStrategy(
|
124 |
+
searcher,
|
125 |
+
instruction=instruction,
|
126 |
+
serper_api_key=serper_api_key
|
127 |
+
)(ppmanager)
|
128 |
+
|
129 |
+
step_ppm = None
|
130 |
+
while True:
|
131 |
+
try:
|
132 |
+
step_ppm, _ = next(iss)
|
133 |
+
yield "", step_ppm.build_uis()
|
134 |
+
except StopIteration:
|
135 |
+
break
|
136 |
+
|
137 |
+
search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
|
138 |
+
yield search_prompt, ppmanager.build_uis()
|
139 |
+
|
140 |
async def chat_stream(
|
141 |
idx, local_data, instruction_txtbox, chat_state,
|
142 |
+
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
143 |
+
internet_option, serper_api_key
|
144 |
):
|
145 |
res = [
|
146 |
chat_state["ppmanager_type"].from_json(json.dumps(ppm))
|
|
|
152 |
PingPong(instruction_txtbox, "")
|
153 |
)
|
154 |
prompt = build_prompts(ppm, global_context, ctx_num_lconv)
|
155 |
+
|
156 |
+
#######
|
157 |
+
if internet_option:
|
158 |
+
search_prompt = None
|
159 |
+
for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
|
160 |
+
search_prompt = tmp_prompt
|
161 |
+
yield "", uis, prompt, str(res)
|
162 |
+
|
163 |
async for result in gen_text(
|
164 |
prompt, hf_model=MODEL_ID, hf_token=TOKEN,
|
165 |
parameters={
|
|
|
322 |
elem_id="global-context"
|
323 |
)
|
324 |
|
325 |
+
gr.Markdown("#### Internet search")
|
326 |
+
with gr.Row():
|
327 |
+
internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
|
328 |
+
serper_api_key = gr.Textbox(
|
329 |
+
value= os.getenv("SERPER_API_KEY"),
|
330 |
+
placeholder="Get one by visiting serper.dev",
|
331 |
+
label="Serper api key",
|
332 |
+
visible=False
|
333 |
+
)
|
334 |
|
335 |
gr.Markdown("#### GenConfig for **response** text generation")
|
336 |
with gr.Row():
|
|
|
355 |
).then(
|
356 |
chat_stream,
|
357 |
[idx, local_data, instruction_txtbox, chat_state,
|
358 |
+
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
359 |
+
internet_option, serper_api_key],
|
360 |
[instruction_txtbox, context_inspector, chatbot, local_data, regenerate]
|
361 |
).then(
|
362 |
None, local_data, None,
|
|
|
387 |
regen_event = regenerate.click(
|
388 |
rollback_last,
|
389 |
[idx, local_data, chat_state,
|
390 |
+
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
391 |
+
internet_option, serper_api_key],
|
392 |
[context_inspector, chatbot, local_data, regenerate]
|
393 |
).then(
|
394 |
None, local_data, None,
|