import subprocess
from threading import Thread
# NOTE: keep `spaces` imported before torch/transformers, matching the original order
# (ZeroGPU's `spaces` package is imported ahead of the CUDA-using libraries here).
import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Load the checkpoint and tokenizer once, at import time. trust_remote_code=True
# lets transformers run the model repository's own Python code (required for
# helpers such as tokenizer.build_chat_input / tokenizer.get_command used below).
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    """Stop generation once the newest token of the first sequence is an EOS id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        # config.eos_token_id may be a single int, a list of ints, or None;
        # iterating a bare int would raise TypeError.
        if stop_ids is None:
            return False
        if isinstance(stop_ids, int):
            stop_ids = [stop_ids]
        # Only row 0 is inspected: predict() always builds a single-sequence input.
        return int(input_ids[0][-1]) in stop_ids


@spaces.GPU(duration=280)
def predict(history, prompt, max_length, top_p, temperature):
    """Stream a model reply into the last row of the chat history.

    Args:
        history: Chatbot rows ``[user_msg, model_msg]``. The last row is the
            pending query with an empty reply (appended by ``user``).
        prompt: Optional system prompt. When set, row 0 is treated as the
            "Set prompt" confirmation row and is not replayed as a turn.
        max_length: Generation budget from the UI slider; used as
            ``max_new_tokens`` (output tokens only).
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.

    Yields:
        ``history`` each time a non-empty decoded chunk has been appended to
        ``history[-1][1]``.
    """
    stop = StopOnTokens()
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})

    query = None
    for idx, (user_msg, model_msg) in enumerate(history):
        # Detect the pending query BEFORE the prompt-row skip. If a prompt was
        # typed but never applied via "Set Prompt", the query itself is row 0;
        # skipping row 0 first left `query` unbound (UnboundLocalError).
        if idx == len(history) - 1 and not model_msg:
            query = user_msg
            break
        if prompt and idx == 0:
            continue
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    if query is None:
        # Every row already has a reply: there is nothing to answer.
        return

    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device
    )
    # timeout=600: the streamer iterator raises if no text arrives for 600 s.
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
    # Also end generation when the model starts a new user/observation turn.
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
        tokenizer.get_command("<|observation|>"),
    ]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        # The slider allows 0 and may deliver a float; max_new_tokens must be
        # a positive integer.
        "max_new_tokens": max(1, int(max_length)),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
    # Generate in a background thread so chunks can be yielded as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        # Drop any text the model emits after starting a new user turn.
        if new_token and '<|user|>' in new_token:
            new_token = new_token.split('<|user|>')[0]
        if new_token:
            history[-1][1] += new_token
            yield history
    # The streamer is exhausted once generate() finishes; reap the thread.
    t.join()


with gr.Blocks() as demo:
    gr.Markdown(
        """
医创客AI文案生成🤗
医创客长文案生成工具

⚠️ 需要3-5分钟,可以生成结果.
"""
    )
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...(Example: Write a 10000-word China travel guide)", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("提交")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt设置文案模版")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # NOTE: despite its label, predict() uses this value as max_new_tokens
            # (output tokens only).
            max_length = gr.Slider(0, 128000, value=10240, step=1.0, label="Maximum length(Input + Output)", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    def user(query, history):
        """Clear the input box and append ``query`` as a pending chat row."""
        # "Clear History" resets the chatbot to None; guard in case that value
        # is handed back here instead of an empty list.
        return "", (history or []) + [[query, ""]]

    def set_prompt(prompt_text):
        """Replace the chat history with a single confirmation row for the prompt."""
        return [[prompt_text, "Set prompt successfully"]]

    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
    # Append the query without queueing, then stream the reply via predict().
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

demo.queue()
demo.launch()