File size: 4,975 Bytes
6f8d4ee
83a2412
6f8d4ee
83a2412
 
6f8d4ee
630e57e
 
aace9fa
630e57e
 
 
 
 
 
 
aace9fa
 
630e57e
 
aace9fa
630e57e
 
 
 
 
 
 
 
 
 
5f7b72c
630e57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
995512e
630e57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
995512e
630e57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b694633
83a2412
5f7b72c
83a2412
5f7b72c
b694633
630e57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1f38dd
83a2412
630e57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import subprocess

# Installing flash_attn
# The child environment must *extend* the current one: passing a bare dict as
# ``env`` replaces the whole environment, so pip would run without the
# parent's PATH / virtualenv / proxy variables.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

from threading import Thread
import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

# Module-level model and tokenizer, both loaded from the
# "THUDM/longwriter-glm4-9b" checkpoint at import time. They are shared by
# StopOnTokens and predict below.
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/longwriter-glm4-9b",
    device_map='auto',
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/longwriter-glm4-9b",
    trust_remote_code=True,
)


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires once the newest token is an EOS id.

    The candidate ids are read from the module-level
    ``model.config.eos_token_id`` on every call and iterated, so that value
    is expected to be a collection of ids. Only the first sequence
    (``input_ids[0]``) is inspected.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        newest_token = input_ids[0][-1]
        return any(newest_token == eos_id for eos_id in model.config.eos_token_id)


@spaces.GPU(duration=280)
def predict(history, prompt, max_length, top_p, temperature):
    """Stream the model's reply to the pending user message into ``history``.

    Args:
        history: Chat log as a list of ``[user_msg, model_msg]`` pairs. The
            last pair is expected to carry the pending query with an empty
            reply; that reply slot is extended in place as text arrives.
        prompt: Optional system prompt. When non-empty it is sent as a
            ``system`` message and the first earlier pair of ``history`` is
            left out of the context (the "Set Prompt" button seeds the chat
            with such a placeholder pair).
        max_length: Forwarded to ``model.generate`` as ``max_new_tokens``.
        top_p: Forwarded to ``model.generate`` as ``top_p``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.

    Yields:
        ``history`` after every chunk produced by the streamer. If there is
        no pending query (empty history, or the last pair already has a
        reply), ``history`` is yielded once unchanged and nothing is
        generated.
    """
    stop = StopOnTokens()
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})

    query = None
    has_query = False
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        # Look for the pending query *before* applying the prompt skip below.
        # Otherwise a single-pair history combined with a non-empty prompt
        # (prompt typed but "Set Prompt" never clicked) skips the only pair
        # and leaves ``query`` unbound.
        if idx == last_idx and not model_msg:
            query = user_msg
            has_query = True
            break
        if prompt and idx == 0:
            # First pair is the "Set Prompt" placeholder, not real dialogue.
            continue
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    if not has_query:
        # Nothing is awaiting a reply; avoid calling the model with an
        # undefined query.
        yield history
        return

    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
    # Generation runs in a background thread so this generator can consume
    # the streamer while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        # Drop empty chunks and any chunk containing the '<|user|>' marker.
        if new_token and '<|user|>' not in new_token:
            history[-1][1] += new_token
        yield history


# Build the Gradio UI: page header, chat window, query/prompt text boxes and
# sampling controls, then wire the three buttons to their handlers.
with gr.Blocks() as demo:
    # Static HTML header: title, project links and a quota warning.
    gr.Markdown(
        """
        <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
            longwriter-glm4-9b Huggingface Space🤗
        </div>
        <div style="text-align: center;">
            <a href="https://huggingface.co/THUDM/LongWriter-glm4-9b">🤗 Model Hub</a> |
            <a href="https://github.com/THUDM/LongWriter">🌐 Github</a> |
            <a href="https://arxiv.org/pdf/2408.07055">📜 arxiv </a>
        </div>
        <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
            ⚠️ Due to the limitations of Huggingface ZERO GPUs, in order to output 10K characters in one go, 
            we need to request a 4-5 minute quota each time. 
            This will result in you only being able to use it once every 4 hours. 
            If you plan to use it long-term, please consider deploying the model or fork this space yourself.
        </div>
        """
    )
    chatbot = gr.Chatbot()

    with gr.Row():
        # Left column: user query box and the submit button.
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        # Middle column: system prompt box and the button that applies it.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt")
        # Right column: history reset and generation parameters.
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # NOTE(review): the label says "Input + Output", but predict
            # forwards this value as max_new_tokens - confirm which is intended.
            max_length = gr.Slider(0, 128000, value=10240, step=1.0, label="Maximum length(Input + Output)",
                                   interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        """Return an empty string for the input box and ``history`` with a new ``[query, ""]`` pair appended."""
        return "", history + [[query, ""]]


    def set_prompt(prompt_text):
        """Return a fresh one-pair chat log that acknowledges the prompt text."""
        return [[prompt_text, "Set prompt successfully"]]


    # "Set Prompt" replaces the chat contents with the acknowledgement pair.
    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)

    # "Submit" first records the query via user (queue=False), then runs
    # predict, whose yielded histories update the chatbot.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    # "Clear History" sets both the chatbot and the prompt box to None.
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

# Enable request queuing on the app, then start serving it.
demo.queue().launch()