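# Gradio demo for LongWriter-GLM4-9B that streams long-form generations to a
# chat UI. Runs on a Hugging Face ZeroGPU Space (see the @spaces.GPU decorator).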
from threading import Thread
import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

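# Load LongWriter-GLM4-9B; trust_remote_code pulls in the custom GLM model and
# tokenizer classes, and device_map='auto' places the weights on the available GPU(s).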
model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is an EOS token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        # The config may store a single id or a list of ids; normalize to a list.
        if isinstance(stop_ids, int):
            stop_ids = [stop_ids]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


@spaces.GPU(duration=280)
def predict(history, prompt, max_length, top_p, temperature):
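    """Stream the model's reply for the pending user turn in `history`."""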
    stop = StopOnTokens()
    messages = []
    query = ""
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        # The first history entry is the placeholder that set_prompt() adds;
        # skip it when a system prompt is active.
        if prompt and idx == 0:
            continue
        # The final entry is the pending user turn with an empty reply slot.
        if idx == len(history) - 1 and not model_msg:
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
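    # GLM-4 emits special role-separator tokens; treat <|user|> and
    # <|observation|> as additional EOS ids so generation stops when the
    # model starts a new turn.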
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
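    # Run generation on a worker thread so this generator can consume the
    # streamer and yield partial output to the UI.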
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
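    # Stream tokens to the chat window, trimming any trailing <|user|> marker
    # that was decoded before the stopping criteria fired.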
    for new_token in streamer:
        if new_token and '<|user|>' in new_token:
            new_token = new_token.split('<|user|>')[0]
        if new_token:
            history[-1][1] += new_token
        yield history


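# Build the Gradio UI: chat window, input box, prompt setter, and sampling controls.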
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
            医创客 AI Copywriting Generator 🤗
        </div>
        <div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
            医创客 long-form copywriting tool
        </div>
        <br>
        <div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
        ⚠️ Generation takes 3-5 minutes to produce a result.<br>
        </div>
        """
    )
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input... (Example: Write a 10000-word China travel guide)", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt设置文案模版")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 128000, value=10240, step=1.0, label="Maximum new tokens",
                                   interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
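        # Append the query to history with an empty reply slot and clear the input box.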
        return "", history + [[query, ""]]


    def set_prompt(prompt_text):
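        # Replace the chat history with a placeholder entry; predict() skips it
        # when a system prompt is set.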
        return [[prompt_text, "Prompt set successfully"]]


    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)

    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

demo.queue()
demo.launch()