import subprocess

# Installing flash_attn
subprocess.run('pip install flash-attn --no-build-isolation',
               env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
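# The env var above tells flash-attn's setup to skip compiling its CUDA extension, so the
# install can succeed on the CPU-only build machine of a ZeroGPU Space. Note that env={...}
# replaces the whole environment; env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}
# would additionally keep PATH intact (a suggestion, not part of the code above).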

from threading import Thread

import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)
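# Note: without an explicit torch_dtype, from_pretrained loads the weights in float32
# (~36 GB for 9B parameters); passing torch_dtype=torch.bfloat16 would roughly halve that
# on GPU (optional; the demo as written leaves the default).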


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

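# predict() rebuilds the Gradio chat history into GLM chat messages, runs model.generate
# on a background thread, and streams the newly generated tokens into the last chatbot turn.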
@spaces.GPU(duration=120)  # ZeroGPU: request a GPU for this call (the 120 s duration is an assumption)
def predict(history, prompt, max_length, top_p, temperature):
    stop = StopOnTokens()
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            continue
        if idx == len(history) - 1 and not model_msg:
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.0,
        "eos_token_id": eos_token_id,
    }
    # Run generation on a background thread so tokens can be streamed as they are produced
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token and '<|user|>' not in new_token:
            history[-1][1] += new_token
        yield history

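# Gradio UI: a chatbot pane, the main input box with a submit button, a system-prompt box,
# and the sampling controls.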
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
            LongWriter-glm4-9b Hugging Face Space 🤗
        </div>
        <div style="text-align: center;">
            <a href="https://huggingface.co/THUDM/LongWriter-glm4-9b">🤗 Model Hub</a> |
            <a href="https://github.com/THUDM/LongWriter">🌐 GitHub</a> |
            <a href="https://arxiv.org/pdf/2408.07055">📜 arXiv</a>
        </div>
        <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
            ⚠️ This is a basic demo. Due to the scheduling limits of ZeroGPU, the output length is restricted to under 4K tokens. To experience the model's full capability (outputs exceeding 10K tokens), please deploy the model yourself. Thank you for your understanding.
        </div>
        """
    )
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 128000, value=4096, step=1.0, label="Maximum length (Input + Output)", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
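    # user() appends the new query to the chat history with an empty assistant slot;
    # predict() then streams the model's reply into that slot.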
    def user(query, history):
        return "", history + [[query, ""]]

    def set_prompt(prompt_text):
        return [[prompt_text, "Set prompt successfully"]]
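    # Clicking "Set Prompt" places the system prompt into the chatbot as its first turn;
    # when a prompt is set, predict() skips that turn (idx == 0) so it is not resent as a user message.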
    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)

    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

demo.queue()
demo.launch()