import gc
import logging
import os

import gradio as gr

from interface.hddr_llama_onnx_interface import LlamaOnnxInterface
from interface.empty_stub_interface import EmptyStubInterface
from ChatApp.app_modules.utils import (
    reset_textbox,
    transfer_input,
    reset_state,
    delete_last_conversation,
    cancel_outputing,
)
from ChatApp.app_modules.presets import (
    small_and_beautiful_theme,
    title,
    description_top,
    description,
)
from ChatApp.app_modules.overwrites import postprocess
|
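# Verbose logging; each record carries the source file and line number.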
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)
|
empty_stub_model_name = "_Empty Stub_"

# Paths are resolved relative to the repository root, one level above this file.
top_directory = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
tokenizer_path = os.path.join(top_directory, "tokenizer.model")
|
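# Selectable models: display name -> ONNX graph, tokenizer, and precomputed
# embedding weights consumed by LlamaOnnxInterface.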
available_models = {
    "Llama-2 13B Float16": {
        "onnx_file": os.path.join(top_directory, "FP16", "LlamaV2_13B_float16.onnx"),
        "tokenizer_path": tokenizer_path,
        "embedding_file": os.path.join(top_directory, "embeddings.pth"),
    },
    "Llama-2 13B FP32": {
        # Assumed filename for the FP32 export; verify against the exported artifact.
        "onnx_file": os.path.join(top_directory, "FP32", "LlamaV2_13B_float32.onnx"),
        "tokenizer_path": tokenizer_path,
        "embedding_file": os.path.join(top_directory, "embeddings.pth"),
    },
}
|
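# Start with the no-op stub so the UI can come up without a model loaded;
# change_model_listener swaps in a real interface once one is selected.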
interface = EmptyStubInterface()
interface.initialize()
|
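# Use the customized postprocess from overwrites when rendering chatbot output.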
gr.Chatbot.postprocess = postprocess

with open("ChatApp/assets/custom.css", "r", encoding="utf-8") as f:
    custom_css = f.read()
|
def change_model_listener(new_model_name):
    """Dispose of the current interface and initialize one for the new model."""
    if new_model_name is None:
        new_model_name = empty_stub_model_name

    global interface

    # Release the old model first so only one set of session buffers
    # is resident in memory at a time.
    if interface is not None:
        interface.shutdown()
        del interface
        gc.collect()

    logging.info(f"Creating a new model [{new_model_name}]")
    if new_model_name == empty_stub_model_name:
        interface = EmptyStubInterface()
        interface.initialize()
    else:
        d = available_models[new_model_name]
        interface = LlamaOnnxInterface(
            onnx_file=d["onnx_file"],
            tokenizer_path=d["tokenizer_path"],
            embedding_file=d["embedding_file"],
        )
        interface.initialize()

    return new_model_name
|
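# Thin generator wrappers: Gradio captures these functions at wiring time, and
# they look up the module-level `interface` on each call, so events keep
# working after change_model_listener rebinds it.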
def interface_predict(*args):
    yield from interface.predict(*args)


def interface_retry(*args):
    yield from interface.retry(*args)
|
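# Build the UI: chat pane plus input controls on the left, model selection and
# sampling parameters on the right.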
with gr.Blocks(css=custom_css, theme=small_and_beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")
    gr.Markdown(description_top)
|
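    # Left column: transcript, input box, and conversation controls.
    # Right column: model dropdown and generation parameters.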
    with gr.Row():
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot", height=900)
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Enter text")
                with gr.Column(min_width=70, scale=1):
                    submit_button = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    cancel_button = gr.Button("Stop")
            with gr.Row():
                empty_button = gr.Button("🧹 New Conversation")
                retry_button = gr.Button("🔄 Regenerate")
                delete_last_button = gr.Button("🗑️ Remove Last Turn")
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="Parameter Setting"):
                    gr.Markdown("# Model")
                    model_name = gr.Dropdown(
                        choices=[empty_stub_model_name] + list(available_models.keys()),
                        label="Model",
                        show_label=False,
                    )
                    model_name.change(
                        change_model_listener, inputs=[model_name], outputs=[model_name]
                    )
|
gr.Markdown("# Parameters") |
|
top_p = gr.Slider( |
|
minimum=-0, |
|
maximum=1.0, |
|
value=0.9, |
|
step=0.05, |
|
interactive=True, |
|
label="Top-p", |
|
) |
|
temperature = gr.Slider( |
|
minimum=0.1, |
|
maximum=2.0, |
|
value=0.75, |
|
step=0.1, |
|
interactive=True, |
|
label="Temperature", |
|
) |
|
max_length_tokens = gr.Slider( |
|
minimum=0, |
|
maximum=512, |
|
value=256, |
|
step=8, |
|
interactive=True, |
|
label="Max Generation Tokens", |
|
) |
|
max_context_length_tokens = gr.Slider( |
|
minimum=0, |
|
maximum=4096, |
|
value=2048, |
|
step=128, |
|
interactive=True, |
|
label="Max History Tokens", |
|
) |
|
gr.Markdown(description) |
|
|
|
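    # Shared kwargs for the two generation events: predict consumes the staged
    # user_question, while retry re-submits the raw user_input.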
    predict_args = dict(
        fn=interface_predict,
        inputs=[
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
|
    retry_args = dict(
        fn=interface_retry,
        inputs=[
            user_input,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
|
    reset_args = dict(
        fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
    )

    transfer_input_args = dict(
        fn=transfer_input,
        inputs=[user_input],
        outputs=[user_question, user_input, submit_button],
        show_progress=True,
    )
|
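    # Event wiring: Enter and the Send button both stage the input first, then
    # stream the model response.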
    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submit_button.click(**transfer_input_args).then(**predict_args)

    empty_button.click(
        reset_state,
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    empty_button.click(**reset_args)

    predict_event3 = retry_button.click(**retry_args)

    delete_last_button.click(
        delete_last_conversation,
        [chatbot, history],
        [chatbot, history, status_display],
        show_progress=True,
    )
    # Stop interrupts any in-flight generation event.
    cancel_button.click(
        cancel_outputing,
        [],
        [status_display],
        cancels=[predict_event1, predict_event2, predict_event3],
    )

    # Populate the dropdown (and load the stub) when the page first renders.
    demo.load(change_model_listener, inputs=None, outputs=model_name)
|
demo.title = "Llama-2 Chat UI"

# Gradio 3.x queue API: a single worker so one generation runs at a time.
demo.queue(concurrency_count=1).launch()