"""Gradio chat demo for the THUDM/longwriter-glm4-9b model (Hugging Face Space)."""

import os
import subprocess
import sys
from threading import Thread

import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

# Installing flash_attn
# The variable is merged into the current environment instead of replacing it,
# so the child process keeps PATH, HOME, proxy settings, etc.  An argument list
# (no shell) and ``sys.executable -m pip`` make sure the package is installed
# for the interpreter that is running this script.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Bot-side text of the placeholder chat row created by the "Set Prompt" button.
# ``predict`` uses it to recognise (and skip) that row when building messages.
PROMPT_SET_MESSAGE = "Set prompt successfully"

model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    """Stop generation once the last generated token is an end-of-sequence id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is a stop id.

        The stop ids are read from ``model.config.eos_token_id``, which may be
        a single int, a list of ints, or ``None``; all three are handled.
        """
        stop_ids = model.config.eos_token_id
        if stop_ids is None:
            return False
        if isinstance(stop_ids, int):
            stop_ids = (stop_ids,)
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in stop_ids)


@spaces.GPU(duration=120)
def predict(history, prompt, max_length, top_p, temperature):
    """Stream the model's reply to the last user message into ``history``.

    Args:
        history: Chatbot history as a list of ``[user_msg, model_msg]`` pairs;
            the last pair is expected to hold the pending query with an empty
            model message.
        prompt: Optional system prompt text; added as a ``system`` message
            when non-empty.
        max_length: Value of the length slider; passed to ``generate`` as
            ``max_new_tokens`` (coerced to an int of at least 1).
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.

    Yields:
        The same ``history`` list after each streamed chunk has been appended
        to the last pair's model message.
    """
    stop = StopOnTokens()
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})

    query = None
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        # Skip only the confirmation row inserted by "Set Prompt".  Skipping
        # row 0 whenever a prompt is present would drop the user's real first
        # message (and leave ``query`` unset) if the prompt box was filled in
        # without pressing the button.
        if idx == 0 and model_msg == PROMPT_SET_MESSAGE:
            continue
        if idx == last_idx and not model_msg:
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    if query is None:
        # Nothing is waiting for an answer (empty history or the last turn is
        # already answered): leave the chat untouched.
        yield history
        return

    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        # The slider can deliver 0 or a float; generate needs a positive int.
        "max_new_tokens": max(1, int(max_length)),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
    # Generation runs in a worker thread so this generator can forward text
    # chunks to the UI as soon as the streamer produces them.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token and '<|user|>' not in new_token:
            history[-1][1] += new_token
        yield history
    # The streamer is exhausted only after generate() finishes; reap the thread.
    t.join()


with gr.Blocks() as demo:
    gr.Markdown(
        """
longwriter-glm4-9b Huggingface Space🤗
🤗 Model Hub | 🌐 Github | 📜 arxiv
⚠️ This is just a basic demo. Due to the scheduling limitations of Zero GPU, the output length is restricted to under 4K. If you wish to experience the full capabilities of the model (output exceeding 10K), please deploy the model yourself. Thank you for your understanding.
"""
    )
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # NOTE: this value is forwarded to ``max_new_tokens`` in predict().
            max_length = gr.Slider(0, 128000, value=4096, step=1.0, label="Maximum length(Input + Output)",
                                   interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    def user(query, history):
        """Clear the input box and append the query as a new, unanswered turn."""
        return "", history + [[query, ""]]

    def set_prompt(prompt_text):
        """Reset the chat to a single confirmation row for the new prompt."""
        return [[prompt_text, PROMPT_SET_MESSAGE]]

    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)

    # First record the user's turn immediately (unqueued), then stream the reply.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

demo.queue()
demo.launch()