# Hugging Face Space status (scraped page header): Running on T4
import os | |
import re | |
import json | |
import copy | |
import gradio as gr | |
from llama2 import GradioLLaMA2ChatPPManager | |
from llama2 import gen_text, gen_text_none_stream | |
from styles import MODEL_SELECTION_CSS | |
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS | |
from templates import templates | |
from constants import DEFAULT_GLOBAL_CTX | |
from pingpong import PingPong | |
from pingpong.context import CtxLastWindowStrategy | |
from pingpong.context import InternetSearchStrategy, SimilaritySearcher | |
# Hugging Face access token, passed as `hf_token` on every text-generation call.
# Read from the HF_TOKEN environment variable (a Space secret); None when unset.
TOKEN = os.environ.get('HF_TOKEN')
# Hub repository id of the chat model queried through the Inference API.
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'
def build_prompts(ppmanager, global_context, win_size=3):
    """Render a prompt from the most recent turns of a conversation.

    The manager is deep-copied first, so overriding its ``ctx`` with
    ``global_context`` never touches the caller's object.

    Args:
        ppmanager: conversation manager holding the ping/pong history.
        global_context: text installed as the copy's ``ctx`` before rendering.
        win_size: window size handed to ``CtxLastWindowStrategy``.

    Returns:
        Whatever ``CtxLastWindowStrategy(win_size)`` returns for the copy
        (used by callers as the prompt string).
    """
    scratch = copy.deepcopy(ppmanager)
    scratch.ctx = global_context
    return CtxLastWindowStrategy(win_size)(scratch)
def _read_lines(path):
    """Return the text of ``path`` split on newline characters.

    The file is decoded as UTF-8 and closed before returning (the previous
    bare ``open`` calls leaked both file handles). A trailing newline in the
    file produces a trailing empty string, exactly as ``str.split`` does.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read().split("\n")


# One example prompt per line; each becomes a button in the initial popup.
examples = _read_lines("examples.txt")
ex_btns = []

# One channel title per line; each becomes a history button in the left pane.
channels = _read_lines("channels.txt")
channel_btns = []
# Matches "[name]" slots; the capture group is the text between the brackets.
_PLACEHOLDER_PATTERN = re.compile(r"\[([^\]]*)\]")


def get_placeholders(text):
    """Return every substring enclosed in square brackets, in order of appearance.

    Templates mark fill-in slots as ``[name]``. For
    ``"Write about [topic] in [style]"`` this returns ``["topic", "style"]``.
    Each match ends at the first ``]`` (no nesting). A slot that appears
    several times is returned once per occurrence, and ``[]`` yields ``""``.
    """
    return _PLACEHOLDER_PATTERN.findall(text)
def fill_up_placeholders(txt):
    """Show a chosen template and configure up to three placeholder textboxes.

    Returns a 5-tuple matching the ``gr.Examples`` outputs:

    - an update showing the template markdown with ``txt`` as its value;
    - three textbox updates, where box ``i`` is visible (hinting the i-th
      placeholder name) only if the template has more than ``i`` placeholders;
    - the instruction-box value: ``""`` when the template has placeholders,
      otherwise the template text itself.
    """
    names = get_placeholders(txt)
    count = len(names)

    template_update = gr.update(visible=True, value=txt)
    slot_updates = [
        gr.update(
            visible=slot < count,
            placeholder=names[slot] if slot < count else "",
        )
        for slot in range(3)
    ]
    instruction_value = txt if count == 0 else ""
    return (template_update, *slot_updates, instruction_value)
def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cuda"):
    """Augment the latest user turn with web search results, streaming progress.

    1. On a deep copy of ``ppmanager``, the latest user message is rewritten
       into a request for a Google search query. The model answers it
       (non-streaming) to produce ``instruction``.
    2. ``InternetSearchStrategy`` runs with that query over ``ppmanager``.
       Each intermediate step yields ``("", step_ppm.build_uis())`` so the
       chatbot can display search progress.
    3. A final ``(search_prompt, ppmanager.build_uis())`` is yielded. Here
       ``search_prompt`` is built from the last search step, or from
       ``ppmanager`` itself if the strategy produced no steps.

    Args:
        ppmanager: conversation manager whose last pingpong is the user question.
        serper_api_key: API key handed to ``InternetSearchStrategy``.
        global_context: context placed at the top of the final prompt.
        ctx_num_lconv: number of recent exchanges kept in both prompts.
        device: device passed to ``SimilaritySearcher.from_pretrained``.

    Yields:
        (prompt, chatbot UI data) pairs. Intermediate items carry ``""`` as the
        prompt; the last item carries the search-augmented prompt.
    """
    query_ppm = copy.deepcopy(ppmanager)
    user_msg = query_ppm.pingpongs[-1].ping
    query_ppm.pingpongs[-1].ping = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
    # Empty global context: only the recent history matters for the query rewrite.
    query_prompt = build_prompts(query_ppm, "", win_size=ctx_num_lconv)
    instruction = gen_text_none_stream(query_prompt, hf_model=MODEL_ID, hf_token=TOKEN)

    # NOTE(review): the similarity model is loaded on every call; consider caching it.
    searcher = SimilaritySearcher.from_pretrained(device=device)
    search_steps = InternetSearchStrategy(
        searcher,
        instruction=instruction,
        serper_api_key=serper_api_key,
    )(ppmanager)

    last_step_ppm = None
    for last_step_ppm, _ in search_steps:
        yield "", last_step_ppm.build_uis()

    # Without this fallback an empty strategy would crash in build_prompts(None, ...).
    if last_step_ppm is None:
        last_step_ppm = ppmanager

    search_prompt = build_prompts(last_step_ppm, global_context, ctx_num_lconv)
    yield search_prompt, ppmanager.build_uis()
async def rollback_last(
    idx, local_data, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
    internet_option, serper_api_key
):
    """Regenerate the model's reply to the last user message of channel ``idx``.

    The last exchange is dropped and its user message is re-added with an
    empty reply. The reply is then streamed again, optionally after an
    internet search.

    Args:
        idx: index of the active channel within ``local_data``.
        local_data: JSON-serializable conversation dumps, indexed by channel.
        chat_state: dict whose ``"ppmanager_type"`` class rebuilds a
            conversation via ``from_json``.
        global_context: text placed at the top of every prompt.
        res_temp, res_topk, res_rpen, res_mnts, res_sample: forwarded to
            ``gen_text`` as temperature, top_k, repetition_penalty,
            max_new_tokens and do_sample.
        ctx_num_lconv: number of recent exchanges kept in the prompt window.
        internet_option: ``"on"`` to augment the prompt with web search results.
        serper_api_key: API key used by the internet search.

    Yields:
        5-tuples, one value per ``regenerate.click`` output:
        (context-inspector text, chatbot UI data, serialized local data,
        regenerate-button update, internet-mode radio value).
    """
    use_internet = internet_option == "on"
    res = [
        chat_state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]
    ppm = res[idx]

    if not ppm.pingpongs:
        # Nothing to regenerate: keep the inspector and radio as they are.
        yield gr.update(), ppm.build_uis(), str(res), gr.update(interactive=False), gr.update()
        return

    # Replace the last exchange with the same user message and an empty reply.
    last_user_message = ppm.pingpongs[-1].ping
    ppm.pingpongs = ppm.pingpongs[:-1]
    ppm.add_pingpong(PingPong(last_user_message, ""))
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)

    # The prompt actually sent to the model; the inspector shows this one.
    model_prompt = prompt
    if use_internet:
        for step_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
            # internet_search's final item carries the search-augmented prompt.
            model_prompt = step_prompt
            # Exactly five values: one per regenerate.click output component.
            yield prompt, uis, str(res), gr.update(interactive=False), "off"

    parameters = {
        'max_new_tokens': res_mnts,
        'do_sample': res_sample,
        'return_full_text': False,
        'temperature': res_temp,
        'top_k': res_topk,
        'repetition_penalty': res_rpen
    }
    async for result in gen_text(
        model_prompt,
        hf_model=MODEL_ID, hf_token=TOKEN,
        parameters=parameters
    ):
        ppm.append_pong(result)
        yield model_prompt, ppm.build_uis(), str(res), gr.update(interactive=False), "off"

    # Streaming finished: re-enable the regenerate button.
    yield model_prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
def reset_chat(idx, ld, state):
    """Clear the conversation history of channel ``idx``.

    Returns a 5-tuple for the ``clean.click`` outputs:

    - an empty instruction box;
    - an empty chatbot;
    - the serialized local data with channel ``idx`` emptied;
    - an update showing the example block;
    - an update disabling the regenerate button.
    """
    managers = [
        state["ppmanager_type"].from_json(json.dumps(raw))
        for raw in ld
    ]
    managers[idx].pingpongs = []

    serialized = str(managers)
    show_examples = gr.update(visible=True)
    disable_regen = gr.update(interactive=False)
    return "", [], serialized, show_examples, disable_regen
async def chat_stream(
    idx, local_data, instruction_txtbox, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
    internet_option, serper_api_key
):
    """Append a new user message to channel ``idx`` and stream the model's reply.

    Args:
        idx: index of the active channel within ``local_data``.
        local_data: JSON-serializable conversation dumps, indexed by channel.
        instruction_txtbox: the user's new message.
        chat_state: dict whose ``"ppmanager_type"`` class rebuilds a
            conversation via ``from_json``.
        global_context: text placed at the top of every prompt.
        res_temp, res_topk, res_rpen, res_mnts, res_sample: forwarded to
            ``gen_text`` as temperature, top_k, repetition_penalty,
            max_new_tokens and do_sample.
        ctx_num_lconv: number of recent exchanges kept in the prompt window.
        internet_option: ``"on"`` to augment the prompt with web search results.
        serper_api_key: API key used by the internet search.

    Yields:
        6-tuples, one value per ``instruction_txtbox.submit`` output:
        (instruction-box value, context-inspector text, chatbot UI data,
        serialized local data, regenerate-button update, internet-mode radio
        value). The radio is always reset to ``"off"``. Regenerate stays
        disabled while streaming and is re-enabled by the last item.
    """
    use_internet = internet_option == "on"
    res = [
        chat_state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]
    ppm = res[idx]
    ppm.add_pingpong(PingPong(instruction_txtbox, ""))
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)

    # The prompt actually sent to the model; the inspector shows this one
    # (previously it showed `prompt` even when the search prompt was used).
    model_prompt = prompt
    if use_internet:
        for step_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
            # internet_search's final item carries the search-augmented prompt.
            model_prompt = step_prompt
            yield "", prompt, uis, str(res), gr.update(interactive=False), "off"

    parameters = {
        'max_new_tokens': res_mnts,
        'do_sample': res_sample,
        'return_full_text': False,
        'temperature': res_temp,
        'top_k': res_topk,
        'repetition_penalty': res_rpen
    }
    async for result in gen_text(
        model_prompt,
        hf_model=MODEL_ID, hf_token=TOKEN,
        parameters=parameters
    ):
        ppm.append_pong(result)
        yield "", model_prompt, ppm.build_uis(), str(res), gr.update(interactive=False), "off"

    # Streaming finished: re-enable the regenerate button.
    yield "", model_prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
def channel_num(btn_title):
    """Return the index of the channel titled ``btn_title`` in ``channels``.

    If the title occurs more than once, the last occurrence wins. If it does
    not occur at all, 0 (the first channel) is returned.
    """
    positions = [pos for pos, title in enumerate(channels) if title == btn_title]
    return positions[-1] if positions else 0
def set_chatbot(btn, ld, state):
    """Switch the chat view to the channel whose title is ``btn``.

    Returns a 4-tuple: (chatbot UI data for that channel, its index, an
    example-block visibility update, a regenerate-button update). The example
    block is shown and regenerate disabled exactly when the channel has no
    messages yet.
    """
    selected = channel_num(btn)
    managers = [
        state["ppmanager_type"].from_json(json.dumps(raw))
        for raw in ld
    ]
    current = managers[selected]
    has_history = len(current.pingpongs) > 0

    return (
        current.build_uis(),
        selected,
        gr.update(visible=not has_history),
        gr.update(interactive=has_history),
    )
def set_example(btn):
    """Copy a clicked example's text into the instruction box and hide the example block."""
    hide_examples = gr.update(visible=False)
    return btn, hide_examples
def get_final_template(
    txt, placeholder_txt1, placeholder_txt2, placeholder_txt3
):
    """Substitute the user's placeholder values into template ``txt``.

    The i-th slot reported by ``get_placeholders`` (i < 3) is replaced by the
    i-th value, covering every occurrence of the exact ``[name]`` text. An
    empty value leaves its slot untouched.

    Returns:
        (filled prompt, "", "", ""): the trailing empty strings clear the
        three placeholder textboxes.
    """
    filled = txt
    values = (placeholder_txt1, placeholder_txt2, placeholder_txt3)
    # zip stops at the shorter input, so at most three slots are considered.
    for name, value in zip(get_placeholders(txt), values):
        if value == "":
            continue
        filled = filled.replace(f"[{name}]", value)
    return filled, "", "", ""
# JS callback run after every history-changing event. It stores the serialized
# conversations (the `local_data` JSON value) in the browser under 'local_data'.
# NOTE(review): `setStorage` is not defined in this file; presumably the page
# gets it from GET_LOCAL_STORAGE — confirm.
SAVE_LOCAL_DATA_JS = "(v)=>{ setStorage('local_data',v) }"

with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
    with gr.Column() as chat_view:
        # Per-session state: index of the active channel, and the conversation
        # manager class used to rebuild histories from JSON.
        idx = gr.State(0)
        chat_state = gr.State({
            "ppmanager_type": GradioLLaMA2ChatPPManager
        })
        # Hidden holder of every channel's history. It is saved to the browser
        # after each chat, regenerate and clean action.
        local_data = gr.JSON({}, visible=False)

        gr.Markdown("## LLaMA2 70B with Gradio Chat and Hugging Face Inference API", elem_classes=["center"])
        # NOTE(review): this text advertises stopping generation, but no Stop
        # button is wired up. A button whose click passes
        # `cancels=[send_event, regen_event]` would provide it.
        gr.Markdown(
            "This space demonstrates how to build feature rich chatbot UI in [Gradio](https://www.gradio.app/). Supported features "
            "include • multiple chatting channels, • chat history save/restoration, • stop generating text response, • regenerate the "
            "last conversation, • clean the chat history, • dynamic kick-starting prompt templates, • adjusting text generation parameters, "
            "• inspecting the actual prompt that the model sees. The underlying Large Language Model is the [Meta AI](https://ai.meta.com/)'s "
            "[LLaMA2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) which is hosted as [Hugging Face Inference API](https://huggingface.co/inference-api), "
            "and [Text Generation Inference](https://github.com/huggingface/text-generation-inference) is the underlying serving framework. ",
            elem_classes=["center"]
        )
        gr.Markdown(
            "***NOTE:*** If you are subscribed to [PRO](https://huggingface.co/pricing#pro), you can simply duplicate this space and use your "
            "Hugging Face Access Token to run the same application. Just add `HF_TOKEN` secret with the Token value according to [this guide](https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables)",
            elem_classes=["center"]
        )

        with gr.Row():
            # Left pane: channel (history) buttons and internet-search options.
            with gr.Column(scale=1, min_width=180):
                gr.Markdown("GradioChat", elem_id="left-top")

                with gr.Column(elem_id="left-pane"):
                    with gr.Accordion("Histories", elem_id="chat-history-accordion", open=True):
                        # The first channel starts highlighted as the active one.
                        channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))
                        for channel in channels[1:]:
                            channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))

                    internet_option = gr.Radio(choices=["on", "off"], value="off", label="internet mode")
                    serper_api_key = gr.Textbox(
                        value=os.getenv("SERPER_API_KEY"),
                        placeholder="Get one by visiting serper.dev",
                        label="Serper api key",
                        visible=False
                    )

            # Right pane: example popup, aux buttons, chat, templates and settings.
            with gr.Column(scale=8, elem_id="right-pane"):
                with gr.Column(
                    elem_id="initial-popup", visible=False
                ) as example_block:
                    with gr.Row(scale=1):
                        with gr.Column(elem_id="initial-popup-left-pane"):
                            gr.Markdown("GradioChat", elem_id="initial-popup-title")
                            gr.Markdown("Making the community's best AI chat models available to everyone.")
                        with gr.Column(elem_id="initial-popup-right-pane"):
                            gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
                            gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")

                    with gr.Column(scale=1):
                        gr.Markdown("Examples")
                        with gr.Row():
                            for example in examples:
                                ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))

                with gr.Column(elem_id="aux-btns-popup", visible=True):
                    with gr.Row():
                        regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
                        clean = gr.Button("Clean", elem_classes=["aux-btn"])

                with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
                    # Receives the prompt text from chat_stream / rollback_last.
                    context_inspector = gr.Textbox(
                        "",
                        elem_id="aux-viewer-inspector",
                        label="",
                        lines=30,
                        max_lines=50,
                    )

                chatbot = gr.Chatbot(elem_id='chatbot', label="LLaMA2-70B-Chat")
                instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")

                with gr.Accordion("Example Templates", open=False):
                    # Hidden textbox receiving the raw template clicked in gr.Examples.
                    template_txt = gr.Textbox(visible=False)
                    template_md = gr.Markdown(label="Chosen Template", visible=False, elem_classes="template-txt")

                    with gr.Row():
                        placeholder_txt1 = gr.Textbox(label="placeholder #1", visible=False, interactive=True)
                        placeholder_txt2 = gr.Textbox(label="placeholder #2", visible=False, interactive=True)
                        placeholder_txt3 = gr.Textbox(label="placeholder #3", visible=False, interactive=True)

                    for template in templates:
                        with gr.Tab(template['title']):
                            gr.Examples(
                                template['template'],
                                inputs=[template_txt],
                                outputs=[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3, instruction_txtbox],
                                run_on_click=True,
                                fn=fill_up_placeholders,
                            )

                with gr.Accordion("Control Panel", open=False) as control_panel:
                    with gr.Column():
                        with gr.Column():
                            gr.Markdown("#### Global context")
                            with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=True):
                                global_context = gr.Textbox(
                                    DEFAULT_GLOBAL_CTX,
                                    lines=5,
                                    max_lines=10,
                                    interactive=True,
                                    elem_id="global-context"
                                )

                            gr.Markdown("#### GenConfig for **response** text generation")
                            with gr.Row():
                                res_temp = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="temp", interactive=True)
                                res_topk = gr.Slider(20, 1000, 50, step=1, label="top_k", interactive=True)
                                res_rpen = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="rep_penalty", interactive=True)
                                res_mnts = gr.Slider(64, 8192, 512, step=1, label="new_tokens", interactive=True)
                                res_sample = gr.Radio([True, False], value=True, label="sample", interactive=True)

                        with gr.Column():
                            gr.Markdown("#### Context managements")
                            with gr.Row():
                                ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)

    # ---- Event wiring ----

    # Submit: hide the example popup, enable regenerate, stream the reply, then persist.
    send_event = instruction_txtbox.submit(
        lambda: [
            gr.update(visible=False),
            gr.update(interactive=True)
        ],
        None,
        [example_block, regenerate]
    ).then(
        chat_stream,
        [idx, local_data, instruction_txtbox, chat_state,
         global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
         internet_option, serper_api_key],
        [instruction_txtbox, context_inspector, chatbot, local_data, regenerate, internet_option]
    ).then(
        None, local_data, None,
        _js=SAVE_LOCAL_DATA_JS
    )

    # Regenerate: re-stream the reply to the last user message, then persist.
    regen_event = regenerate.click(
        rollback_last,
        [idx, local_data, chat_state,
         global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
         internet_option, serper_api_key],
        [context_inspector, chatbot, local_data, regenerate, internet_option]
    ).then(
        None, local_data, None,
        _js=SAVE_LOCAL_DATA_JS
    )

    # Channel switch: load the selected history, then update the left-pane highlight in JS.
    for btn in channel_btns:
        btn.click(
            set_chatbot,
            [btn, local_data, chat_state],
            [chatbot, idx, example_block, regenerate]
        ).then(
            None, btn, None,
            _js=UPDATE_LEFT_BTNS_STATE
        )

    for btn in ex_btns:
        btn.click(
            set_example,
            [btn],
            [instruction_txtbox, example_block]
        )

    # Clean: empty the active channel's history, then persist.
    clean.click(
        reset_chat,
        [idx, local_data, chat_state],
        [instruction_txtbox, chatbot, local_data, example_block, regenerate]
    ).then(
        None, local_data, None,
        _js=SAVE_LOCAL_DATA_JS
    )

    placeholder_txts = [placeholder_txt1, placeholder_txt2, placeholder_txt3]
    template_inputs = [template_txt, *placeholder_txts]

    # Live preview: editing any placeholder re-renders the template client-side (JS only, fn=None).
    for placeholder_txt in placeholder_txts:
        placeholder_txt.change(
            inputs=template_inputs,
            outputs=[template_md],
            show_progress=False,
            _js=UPDATE_PLACEHOLDERS,
            fn=None
        )

    # Enter in any placeholder box writes the filled-in template to the
    # instruction box and clears the placeholder boxes.
    for placeholder_txt in placeholder_txts:
        placeholder_txt.submit(
            inputs=template_inputs,
            outputs=[instruction_txtbox, *placeholder_txts],
            fn=get_final_template
        )

    # On page load, restore chat history from browser storage.
    demo.load(
        None,
        inputs=None,
        outputs=[chatbot, local_data],
        _js=GET_LOCAL_STORAGE,
    )

demo.queue(concurrency_count=5, max_size=256).launch()