from strings import TITLE, ABSTRACT, BOTTOM_LINE, DEFAULT_EXAMPLES, SPECIAL_STRS
from styles import PARENT_BLOCK_CSS

import gradio as gr

from model import load_model
from gen import get_output_batch, StreamModel
from utils import (
    generate_prompt,
    post_processes_batch,
    post_process_stream,
    get_generation_config,
    common_post_process,
)

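# Load the base LLaMA-13B weights and apply the Alpaca-LoRA adapter on top,
# then wrap the pair in StreamModel so responses can be streamed token by token.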
model, tokenizer = load_model(
    base="decapoda-research/llama-13b-hf",
    finetuned="chansung/alpaca-lora-13b"
)

model = StreamModel(model, tokenizer)

# chat_batch reads this module-level config, but the original code never
# defined it. ASSUMPTION: get_generation_config() accepts no arguments and
# returns a default GenerationConfig; adjust the call to your utils signature.
generation_config = get_generation_config()


def chat_stream(
    context,
    instruction,
    state_chatbot,
):
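    """Stream one chat turn, yielding (state, chatbot, context) updates."""
    # Clean the instruction for display and build the full model prompt from
    # the instruction, the running chat history, and the optional context.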
    instruction_display = common_post_process(instruction)
    instruction_prompt = generate_prompt(instruction, state_chatbot, context)
    bot_response = model(
        instruction_prompt,
        max_tokens=256,
        temperature=1.0,
        top_p=0.9
    )

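    # The special "continue" instruction is not echoed back into the chat log.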
    instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
    state_chatbot = state_chatbot + [(instruction_display, None)]

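    # Buffer emitted text so a hallucinated "### Instruction:" header can be
    # detected and trimmed before it reaches the UI.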
    prev_index = 0
    agg_tokens = ""
    cutoff_idx = 0
    for tokens in bot_response:
        tokens = tokens.strip()
        cur_token = tokens[prev_index:]

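        # A "#" may be the start of a "### Instruction:" header; start
        # aggregating until there is enough text to decide.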
if "#" in cur_token and agg_tokens == "": |
|
cutoff_idx = tokens.find("#") |
|
agg_tokens = tokens[cutoff_idx:] |
|
|
|
if agg_tokens != "": |
|
if len(agg_tokens) < len("### Instruction:") : |
|
agg_tokens = agg_tokens + cur_token |
|
elif len(agg_tokens) >= len("### Instruction:"): |
|
if tokens.find("### Instruction:") > -1: |
|
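                    # The model started a new turn on its own: cut the
                    # response at the header and stop streaming.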
                    processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())

                    state_chatbot[-1] = (
                        instruction_display,
                        processed_response
                    )
                    yield (state_chatbot, state_chatbot, context)
                    break
                else:
                    # False alarm: the "#" was ordinary text, so reset the buffer.
                    agg_tokens = ""
                    cutoff_idx = 0

if agg_tokens == "": |
|
processed_response, to_exit = post_process_stream(tokens) |
|
state_chatbot[-1] = (instruction_display, processed_response) |
|
yield (state_chatbot, state_chatbot, context) |
|
|
|
if to_exit: |
|
break |
|
        prev_index = len(tokens)

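    # For the special "summarize" instruction, the final answer also replaces
    # the contents of the context textbox.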
    yield (
        state_chatbot,
        state_chatbot,
        gr.Textbox.update(value=tokens) if instruction_display == SPECIAL_STRS["summarize"] else context
    )


def chat_batch(
    contexts,
    instructions,
    state_chatbots,
):
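    """Answer a batch of turns in a single non-streaming generation pass."""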
    state_results = []
    ctx_results = []

    instruct_prompts = [
        generate_prompt(instruct, histories, ctx)
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]

    bot_responses = get_output_batch(
        model, tokenizer, instruct_prompts, generation_config
    )
    bot_responses = post_processes_batch(bot_responses)

    for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
        new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
        ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
        state_results.append(new_state_chatbot)

    return (state_results, state_results, ctx_results)


def reset_textbox():
    return gr.Textbox.update(value='')

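# Build the UI: context accordion, chat window, helper buttons, and examples.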
with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")

with gr.Accordion("Context Setting", open=False): |
|
context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context") |
|
hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False) |
|
|
|
        chatbot = gr.Chatbot(elem_id='chatbot', label="Alpaca-LoRA")
        instruction_txtbox = gr.Textbox(placeholder="What do you want to say to the AI?", label="Instruction")
        send_prompt_btn = gr.Button(value="Send Prompt")

with gr.Accordion("Helper Buttons", open=False): |
|
gr.Markdown(f"`Continue` lets AI to complete the previous incomplete answers. `Summarize` lets AI to summarize the conversations so far.") |
|
continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False) |
|
summrize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False) |
|
|
|
            continue_btn = gr.Button(value="Continue")
            summarize_btn = gr.Button(value="Summarize")

gr.Markdown("#### Examples") |
|
for idx, examples in enumerate(DEFAULT_EXAMPLES): |
|
with gr.Accordion(examples["title"], open=False): |
|
gr.Examples( |
|
examples=examples["examples"], |
|
inputs=[ |
|
hidden_txtbox, instruction_txtbox |
|
], |
|
label=None |
|
) |
|
|
|
gr.Markdown(f"{BOTTOM_LINE}") |
|
|
|
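    # Wire the buttons: each streams a chat response, and a second handler
    # clears the instruction textbox.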
    send_prompt_btn.click(
        chat_stream,
        [context_txtbox, instruction_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    send_prompt_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    continue_btn.click(
        chat_stream,
        [context_txtbox, continue_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    continue_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    summarize_btn.click(
        chat_stream,
        [context_txtbox, summarize_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    summarize_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

demo.queue(
    concurrency_count=2,
    max_size=100,
).launch(
    max_threads=2,
    server_name="0.0.0.0",
)