|
import os
import subprocess
import sys
from threading import Thread

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
# Best-effort install of FlashAttention at startup, skipping the local CUDA
# kernel build (FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE).
# - The variable is merged into the current environment rather than replacing it:
#   a bare `env={...}` would drop PATH/HOME/virtualenv variables for the child.
# - `sys.executable -m pip` installs into the interpreter running this app, and
#   an argument list avoids going through a shell.
# - check=False keeps this best-effort, as before; a failed install surfaces later
#   when the model is loaded with attn_implementation="flash_attention_2".
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
|
|
|
# HTML header rendered at the top of the demo page: the project banner image
# (linking to the GitHub repo), a pointer to the GitHub project, and a usage
# notice naming the default checkpoint. The "Default model" text names v3;
# keep it in sync with the version passed to load_model() in the __main__ block.
BANNER_HTML = """
<p align="center">
<a href="https://github.com/ymcui/Chinese-LLaMA-Alpaca-3">
<img src="https://ymcui.com/images/chinese-llama-alpaca-3-banner.png" width="600"/>
</a>
</p>
<h3>
<center>Check our <a href='https://github.com/ymcui/Chinese-LLaMA-Alpaca-3' target='_blank'>Chinese-LLaMA-Alpaca-3 GitHub Project</a> for more information.
</center>
</h3>
<p>
<center><em>The demo is mainly for academic purposes. Illegal usages are prohibited. Default model: <a href="https://huggingface.co/hfl/llama-3-chinese-8b-instruct-v3">hfl/llama-3-chinese-8b-instruct-v3</a></em></center>
</p>
"""
|
|
|
# Bilingual (English / Chinese) system prompt. It pre-fills the "System Prompt"
# input and is the fallback stream_chat uses when that input is left empty.
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. 你是一个乐于助人的助手。"
|
|
|
|
|
def load_model(version):
    """Load the tokenizer and model for a checkpoint version into module globals.

    Sets the module-level ``tokenizer`` and ``model`` that ``stream_chat`` uses.
    The model is loaded in bfloat16 with ``device_map="auto"`` and the
    FlashAttention-2 attention implementation.

    Args:
        version: One of ``"v1"``, ``"v2"`` or ``"v3"``, selecting the matching
            ``hfl/llama-3-chinese-8b-instruct`` checkpoint.

    Returns:
        A short status message naming the loaded checkpoint.

    Raises:
        ValueError: If ``version`` is not a known version.
    """
    global tokenizer, model

    model_names = {
        "v1": "hfl/llama-3-chinese-8b-instruct",
        "v2": "hfl/llama-3-chinese-8b-instruct-v2",
        "v3": "hfl/llama-3-chinese-8b-instruct-v3",
    }
    try:
        model_name = model_names[version]
    except KeyError:
        # Reject unknown versions up front with the list of valid choices,
        # instead of failing later on an unassigned `model_name`.
        raise ValueError(
            f"Unknown model version {version!r}; expected one of {sorted(model_names)}"
        ) from None

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    return f"Model {model_name} loaded."
|
|
|
def _build_conversation(message, history, system_prompt):
    """Build the chat-template message list from Gradio tuple-style history.

    ``history`` is iterated as ``(user_text, assistant_text)`` pairs. A ``None``
    side of a pair is skipped rather than passed to the chat template as
    message content. An empty ``system_prompt`` falls back to
    ``DEFAULT_SYSTEM_PROMPT``.
    """
    conversation = [{"role": "system", "content": system_prompt or DEFAULT_SYSTEM_PROMPT}]
    for user_text, assistant_text in history:
        if user_text is not None:
            conversation.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            conversation.append({"role": "assistant", "content": assistant_text})
    conversation.append({"role": "user", "content": message})
    return conversation


def _generate_into(streamer, errors, generate_kwargs):
    """Run ``model.generate`` (intended for a worker thread), capturing failures.

    On an exception, the error is appended to ``errors`` and ``streamer.end()``
    is called. The consumer iterating the streamer then stops promptly instead of
    waiting for the streamer's timeout, and the caller re-raises the error.
    """
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # handed back to the consuming thread, not swallowed
        errors.append(exc)
        streamer.end()


@spaces.GPU(duration=50)
def stream_chat(message: str, history: list, system_prompt: str, model_version: str, temperature: float, max_new_tokens: int):
    """Stream a chat reply from the loaded model, yielding the cumulative text.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs from the chatbot.
        system_prompt: System prompt; ``DEFAULT_SYSTEM_PROMPT`` is used when empty.
        model_version: Value of the (non-interactive) version selector. Not used
            here: generation always uses the globals set by ``load_model``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate (coerced to ``int``).

    Yields:
        The reply generated so far, growing as new text is streamed.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker thread
            is re-raised here once the stream has ended.
    """
    conversation = _build_conversation(message, history, system_prompt)

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # Stop on either the tokenizer's EOS token or Llama-3's <|eot_id|> end-of-turn token.
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "eos_token_id": terminators,
        "pad_token_id": tokenizer.eos_token_id,
        # UI sliders may deliver a float; use an integer token budget.
        "max_new_tokens": int(max_new_tokens),
        "temperature": temperature,
        "top_k": 40,
        "top_p": 0.9,
        "num_beams": 1,
        "repetition_penalty": 1.1,
        # Temperature 0 selects greedy decoding instead of sampling.
        "do_sample": temperature != 0,
    }

    # generate() blocks until completion, so it runs in a worker thread while this
    # generator drains the streamer. Worker failures are collected in `errors`.
    errors = []
    generation_thread = Thread(target=_generate_into, args=(streamer, errors, generate_kwargs))
    generation_thread.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output

    generation_thread.join()
    if errors:
        raise errors[0]
|
|
|
# Chat display component, handed to the ChatInterface constructed below.
chatbot = gr.Chatbot(height=500)

with gr.Blocks() as demo:
    gr.HTML(BANNER_HTML)

    # Collapsible container that holds the extra generation controls.
    parameters_accordion = gr.Accordion(label="Parameters / 参数设置", open=False, render=False)

    # The controls below are passed to stream_chat positionally after
    # (message, history), in this order: system prompt, model version,
    # temperature, max new tokens.
    system_prompt_box = gr.Text(
        value=DEFAULT_SYSTEM_PROMPT,
        label="System Prompt / 系统提示词",
        render=False,
    )
    version_selector = gr.Radio(
        choices=["v1", "v2", "v3"],
        label="Model Version / 模型版本",
        value="v3",
        interactive=False,  # display-only; stream_chat does not switch models
        render=False,
    )
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1.5,
        step=0.1,
        value=0.6,
        label="Temperature / 温度系数",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128,
        maximum=2048,
        step=1,
        value=512,
        label="Max new tokens / 最大生成长度",
        render=False,
    )

    # NOTE(review): retry_btn / undo_btn / clear_btn are Gradio 4.x ChatInterface
    # options; confirm the pinned Gradio version still accepts them.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[
            system_prompt_box,
            version_selector,
            temperature_slider,
            max_tokens_slider,
        ],
        cache_examples=False,
        submit_btn="Send / 发送",
        stop_btn="Stop / 停止",
        retry_btn="🔄 Retry / 重试",
        undo_btn="↩️ Undo / 撤销",
        clear_btn="🗑️ Clear / 清空",
    )
|
|
|
if __name__ == "__main__":
    # Populate the module-level `tokenizer` / `model` globals (read by stream_chat
    # on every request) with the default v3 checkpoint, then start the web UI.
    # NOTE(review): when this module is imported rather than run as a script,
    # no model is loaded; confirm the deployment always runs it as __main__.
    load_model("v3")
    demo.launch()