Spaces:
Sleeping
Sleeping
File size: 5,023 Bytes
6f8d4ee 83a2412 6f8d4ee 83a2412 6f8d4ee 630e57e aace9fa 630e57e aace9fa 630e57e aace9fa 630e57e 5f7b72c 630e57e 995512e 630e57e 995512e 630e57e b694633 83a2412 5f7b72c 83a2412 5f7b72c b694633 630e57e cbe73b1 630e57e b1f38dd 83a2412 630e57e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
import subprocess
import sys

# Installing flash_attn
# FLASH_ATTENTION_SKIP_CUDA_BUILD is layered on top of the inherited environment:
# a bare dict passed as ``env`` replaces the child's environment wholesale, which
# drops PATH, HOME, proxy settings, etc.
# ``sys.executable -m pip`` installs into the interpreter that is running this
# script, and the argument list avoids routing the command through a shell.
# The exit status is intentionally not checked, so a failed install does not
# abort start-up.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
from threading import Thread
import spaces
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
# Hub repository used for both the weights and the tokenizer.
_MODEL_REPO = "THUDM/longwriter-glm4-9b"

# Both objects are module-level singletons shared by StopOnTokens and predict().
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_REPO,
    trust_remote_code=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
    _MODEL_REPO,
    trust_remote_code=True,
)
class StopOnTokens(StoppingCriteria):
    """Stop generation once the newest token is one of the model's EOS ids.

    The EOS ids are read from the module-level ``model.config.eos_token_id`` on
    every call. Only the first sequence of ``input_ids`` is inspected.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of ``input_ids[0]`` is an EOS id.

        Args:
            input_ids: Token ids generated so far; indexed as ``input_ids[0][-1]``.
            scores: Unused; accepted to match the stopping-criteria call signature.
            **kwargs: Unused.
        """
        stop_ids = model.config.eos_token_id
        if stop_ids is None:
            # No EOS id configured: this criterion never fires.
            return False
        if isinstance(stop_ids, int):
            # A config may hold a single id instead of a list; iterating an int
            # would raise TypeError.
            stop_ids = [stop_ids]
        last_token = int(input_ids[0][-1])
        return last_token in stop_ids
def _split_history(history, prompt):
    """Split chatbot history into prior chat messages and the pending query.

    Returns:
        ``(messages, query)`` where ``messages`` is the list of role/content
        dicts for earlier turns (preceded by the system prompt when ``prompt``
        is non-empty) and ``query`` is the user text of the last row if that row
        has no reply yet, otherwise ``None``.
    """
    # Reply text that ``set_prompt`` writes into the acknowledgement row.
    prompt_ack = "Set prompt successfully"

    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})

    query = None
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        # Drop only the acknowledgement row. Skipping row 0 unconditionally
        # would discard a real first turn when the prompt box has text but
        # "Set Prompt" was never clicked, leaving no query to answer.
        if prompt and idx == 0 and model_msg == prompt_ack:
            continue
        if idx == last_idx and not model_msg:
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages, query


@spaces.GPU(duration=280)
def predict(history, prompt, max_length, top_p, temperature):
    """Stream the model's reply to the pending user turn into ``history``.

    Args:
        history: Chatbot state as a list of ``[user_msg, model_msg]`` rows. The
            last row is expected to carry the new query with an empty reply.
        prompt: Optional system prompt; sent as a ``system`` message when non-empty.
        max_length: Passed to ``model.generate`` as ``max_new_tokens``.
        top_p: Passed to ``model.generate`` as ``top_p``.
        temperature: Passed to ``model.generate`` as ``temperature``.

    Yields:
        ``history`` after each streamed chunk; chunks are appended in place to
        the reply slot of the last row. If there is no pending query the
        unchanged ``history`` is yielded once.
    """
    if not history:
        yield history
        return

    messages, query = _split_history(history, prompt)
    if query is None:
        # Last row already has a reply: nothing to generate.
        yield history
        return

    stop = StopOnTokens()
    device = next(model.parameters()).device
    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
        tokenizer.get_command("<|observation|>"),
    ]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        # The slider uses a float step and a minimum of 0; generation needs a
        # positive integer token budget.
        "max_new_tokens": max(1, int(max_length)),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }

    # Generation runs in a background thread so the streamer can be drained here.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    for new_token in streamer:
        # Chunks containing the role marker are not shown to the user.
        if new_token and '<|user|>' not in new_token:
            history[-1][1] += new_token
        yield history
    # The streamer is exhausted once generation ends; reap the worker thread.
    worker.join()
# Static banner shown above the chat: title, project links and the quota notice.
_HEADER_HTML = """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
longwriter-glm4-9b Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/THUDM/LongWriter-glm4-9b">🤗 Model Hub</a> |
<a href="https://github.com/THUDM/LongWriter">🌐 Github</a> |
<a href="https://arxiv.org/pdf/2408.07055">📜 arxiv </a>
</div>
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
⚠️ Due to the limitations of Huggingface ZERO GPUs, in order to output 10K characters in one go,
we need to request a 4-5 minute quota each time.
This will result in you only being able to use it once every 4 hours.
If you plan to use it long-term, please consider deploying the model or fork this space yourself.
</div>
"""

# NOTE(review): the container nesting below is reconstructed from the order of
# the statements — confirm against the rendered layout.
with gr.Blocks() as demo:
    gr.Markdown(_HEADER_HTML)
    chatbot = gr.Chatbot()

    with gr.Row():
        # Left: query box and submit button.
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(
                    show_label=False,
                    placeholder="Input...(Example: Write a 10000-word China travel guide)",
                    lines=10,
                    container=False,
                )
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        # Middle: system prompt editor.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                show_label=False,
                placeholder="Prompt",
                lines=10,
                container=False,
            )
            pBtn = gr.Button("Set Prompt")
        # Right: reset button and sampling controls.
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(
                minimum=0,
                maximum=128000,
                value=10240,
                step=1.0,
                label="Maximum length(Input + Output)",
                interactive=True,
            )
            top_p = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.8,
                step=0.01,
                label="Top P",
                interactive=True,
            )
            temperature = gr.Slider(
                minimum=0.01,
                maximum=1,
                value=0.6,
                step=0.01,
                label="Temperature",
                interactive=True,
            )

    def user(query, history):
        # Empties the input box and appends the query as a row with a blank reply.
        pending_row = [query, ""]
        return "", history + [pending_row]

    def set_prompt(prompt_text):
        # Replaces the chat with a single acknowledgement row for the prompt.
        ack_row = [prompt_text, "Set prompt successfully"]
        return [ack_row]

    pBtn.click(fn=set_prompt, inputs=[prompt_input], outputs=chatbot)

    # Submit first records the query, then hands the updated chat to predict().
    generation_inputs = [chatbot, prompt_input, max_length, top_p, temperature]
    submit_event = submitBtn.click(
        fn=user,
        inputs=[user_input, chatbot],
        outputs=[user_input, chatbot],
        queue=False,
    )
    submit_event.then(fn=predict, inputs=generation_inputs, outputs=chatbot)

    # Clears both the conversation and the prompt box.
    emptyBtn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[chatbot, prompt_input],
        queue=False,
    )

demo.queue()
demo.launch()
|