# LongWriter-glm4-9b Gradio demo: streams long-form chat completions.
import subprocess
from threading import Thread
import spaces
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
# Checkpoint served by this Space.
_MODEL_REPO = "THUDM/longwriter-glm4-9b"

# Load the chat model across available devices and its matching tokenizer.
# trust_remote_code is required: the GLM repo ships custom modeling/tokenizer code.
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_REPO, trust_remote_code=True, device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO, trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is an EOS id.

    The stop ids are read from ``model.config.eos_token_id`` at call time;
    depending on the config this may be a single int or a list of ints.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        # Some HF configs store a single int rather than a list; normalize
        # so the membership check below never raises TypeError.
        if isinstance(stop_ids, int):
            stop_ids = [stop_ids]
        # Only the most recently generated token (batch index 0) matters.
        return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)
@spaces.GPU(duration=280)
def predict(history, prompt, max_length, top_p, temperature):
    """Stream a model response for the latest user turn.

    Args:
        history: List of ``[user_msg, assistant_msg]`` chat pairs; the last
            pair is the pending turn (its assistant slot empty or None).
        prompt: Optional system prompt prepended as a ``system`` message.
        max_length: Passed to generation as ``max_new_tokens``.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.

    Yields:
        ``history`` with the last assistant slot extended after each
        streamed token chunk.
    """
    stop = StopOnTokens()
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    # Guard: if history is empty or the last turn already has a reply, the
    # loop below never assigns query — fall back to an empty query instead
    # of raising NameError at build_chat_input.
    query = ""
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            # When a prompt is set, history[0] is the "Set prompt" notice
            # inserted by the UI, not a real user turn — skip it.
            continue
        if idx == len(history) - 1 and not model_msg:
            # Last entry with no model reply yet: this is the new query.
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
    # Stop on EOS or on the role-switch tokens emitted by the chat template.
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
    # Run generation on a worker thread so this generator can consume the
    # streamer and yield partial results to the UI as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token and '<|user|>' in new_token:
            # Drop anything after the role-switch marker.
            new_token = new_token.split('<|user|>')[0]
        if new_token:
            history[-1][1] += new_token
        yield history
# --- Gradio UI: page header, chat view, controls, and event wiring ---
with gr.Blocks() as demo:
    # Static page header (title and usage notes, in Chinese).
    gr.Markdown(
        """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
医创客AI文案生成🤗
</div>
<div style="text-align: center; font-size: 15px; font-weight: black; margin-bottom: 20px; line-height: 1.5;">
医创客长文案生成工具
</div>
<br>
<div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
⚠️ 需要3-5分钟,可以生成结果.<br>
</div>
"""
    )
    # Conversation view: a list of [user_message, assistant_message] pairs.
    chatbot = gr.Chatbot()
    with gr.Row():
        # Left column: main input box plus submit button.
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...(Example: Write a 10000-word China travel guide)", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("提交")
        # Middle column: optional system-prompt editor.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt设置文案模版")
        # Right column: clear button and sampling controls for predict().
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 128000, value=10240, step=1.0, label="Maximum length(Input + Output)",
                                   interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
    # Append the new user turn with an empty assistant slot and clear the box.
    def user(query, history):
        return "", history + [[query, ""]]
    # Show the chosen system prompt as the first chat entry; predict() skips
    # this entry (idx == 0) when a prompt is set.
    def set_prompt(prompt_text):
        return [[prompt_text, "Set prompt successfully"]]
    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
    # Submit flow: first append the turn (unqueued, immediate), then stream
    # the model response into the chatbot.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    # Clear both the conversation and the prompt box.
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
demo.queue()
demo.launch()