Update app.py
app.py CHANGED
@@ -1,26 +1,23 @@
 import torch
-from
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from peft import PeftModel
 import gradio as gr
+from threading import Thread
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import os
-from threading import Thread
-
 
+# Hugging Face model settings from environment variables
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
-
-MODEL_NAME = MODELS.split("/")[-1]
+BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"  # replace with the base model
+LORA_MODEL_PATH = "QLWD/test-3b"  # replace with the LoRA adapter repo path
 
-
+# UI title and description
+TITLE = "<h1><center>LoRA Fine-Tuned Model Test</center></h1>"
 
 DESCRIPTION = f"""
-<h3
+<h3>Model: <a href="https://huggingface.co/{LORA_MODEL_PATH}">LoRA fine-tuned model</a></h3>
 <center>
-<p
-<br>
-Feel free to test without log.
-</p>
+<p>Test the generation quality of the base model plus the LoRA adapter.</p>
 </center>
 """
 
@@ -36,41 +33,42 @@ text-align: center;
 }
 """
 
-
-
-
-
-
-
+# Load the base model and tokenizer
+base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
+
+# Apply the LoRA fine-tuned weights
+model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
 
+# Inference function
 @spaces.GPU(duration=2)
 def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
-    print(f'message is - {message}')
-    print(f'history is - {history}')
     conversation = []
     for prompt, answer in history:
         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
     conversation.append({"role": "user", "content": message})
 
-
-
+    # Build input_ids with the chat template
    input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(input_ids, return_tensors="pt").to(
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    inputs = tokenizer(input_ids, return_tensors="pt").to("cuda")
 
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    # Set the generation parameters
     generate_kwargs = dict(
-        inputs,
+        inputs,
         streamer=streamer,
         top_k=top_k,
         top_p=top_p,
         repetition_penalty=penalty,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
         temperature=temperature,
-        eos_token_id
+        eos_token_id=[151645, 151643],
     )
-
+
+    # Start generation in a background thread
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
@@ -79,70 +77,35 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
         buffer += new_text
         yield buffer
 
-
-
+# Gradio UI
 chatbot = gr.Chatbot(height=450)
 
 with gr.Blocks(css=CSS) as demo:
     gr.HTML(TITLE)
     gr.HTML(DESCRIPTION)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+
     gr.ChatInterface(
         fn=stream_chat,
         chatbot=chatbot,
         fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameter Settings", open=False, render=False),
         additional_inputs=[
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.8,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=128,
-                maximum=4096,
-                step=1,
-                value=1024,
-                label="Max new tokens",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                step=0.1,
-                value=0.8,
-                label="top_p",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1,
-                maximum=20,
-                step=1,
-                value=20,
-                label="top_k",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=2.0,
-                step=0.1,
-                value=1.0,
-                label="Repetition penalty",
-                render=False,
-            ),
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
+            gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False),
+            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.8, label="top_p", render=False),
+            gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k", render=False),
+            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Repetition penalty", render=False),
         ],
         examples=[
-            ["
-            ["
-            ["
-            ["
+            ["Help me write a short passage about learning"],
+            ["Explain the concept of quantum computing"],
+            ["Give me some Python programming tips"],
+            ["Create a fixed header with CSS and JavaScript"],
        ],
         cache_examples=False,
     )
 
-
+# Launch the Gradio app
 if __name__ == "__main__":
     demo.launch()
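For a quick local sanity check of the committed load-and-generate path without Gradio or ZeroGPU, a minimal sketch along these lines should work. It assumes the same Qwen/Qwen2.5-3B-Instruct base and QLWD/test-3b adapter as app.py, that `accelerate` is installed for device_map="auto", and that the hard-coded eos_token_id values 151645 and 151643 correspond to Qwen2.5's <|im_end|> and <|endoftext|> tokens:

# Minimal local test of the base-model + LoRA streaming path (no Gradio/Spaces).
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel

BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
LORA_MODEL_PATH = "QLWD/test-3b"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto"
)
model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)

# Render the conversation through the chat template, then tokenize.
conversation = [{"role": "user", "content": "Explain the concept of quantum computing"}]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Stream tokens from a background generation thread, as app.py does.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(
    inputs,
    streamer=streamer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.8,
    top_p=0.8,
    top_k=20,
    repetition_penalty=1.0,
    eos_token_id=[151645, 151643],  # Qwen2.5 <|im_end|> / <|endoftext|> (assumption)
))
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()

Passing `inputs` (a BatchEncoding, which is a mapping) as the first argument of dict() spreads input_ids and attention_mask into model.generate's keyword arguments; this is the same trick the committed stream_chat relies on.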