QLWD committed on
Commit 232a3c7
1 Parent(s): 66be2b0

Update app.py

Files changed (1)
  1. app.py +42 -79
app.py CHANGED
@@ -1,26 +1,23 @@
 import torch
-from PIL import Image
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from peft import PeftModel
 import gradio as gr
+from threading import Thread
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import os
-from threading import Thread
-
 
+# Read the Hugging Face model info from environment variables
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
-MODEL_ID = "Qwen/Qwen2-7B-Instruct"
-MODELS = os.environ.get("MODELS")
-MODEL_NAME = MODELS.split("/")[-1]
+BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"  # replace with the base model
+LORA_MODEL_PATH = "QLWD/test-3b"  # replace with the LoRA adapter repo path
 
-TITLE = "<h1><center>Qwen2-7B-instruct</center></h1>"
+# Interface title and description
+TITLE = "<h1><center>LoRA Fine-Tuned Model Test</center></h1>"
 
 DESCRIPTION = f"""
-<h3>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></h3>
+<h3>Model: <a href="https://huggingface.co/{LORA_MODEL_PATH}">LoRA fine-tuned model</a></h3>
 <center>
-<p>Qwen is the large language model built by Alibaba Cloud.
-<br>
-Feel free to test without log.
-</p>
+<p>Tests the generation quality of the base model with the LoRA adapter applied.</p>
 </center>
 """
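Note on the hunk above: `HF_TOKEN` is read from the environment but never forwarded to `from_pretrained`, so the Space can only pull public repos. A minimal sketch of forwarding it, assuming the standard `token=` keyword of recent `transformers` releases (not part of this commit):

```python
# Sketch (not in this commit): forward HF_TOKEN so gated/private repos load too.
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

HF_TOKEN = os.environ.get("HF_TOKEN", None)  # None is fine for public repos
BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
```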
 
@@ -36,41 +33,42 @@ text-align: center;
 }
 """
 
-model = AutoModelForCausalLM.from_pretrained(
-    MODELS,
-    torch_dtype=torch.float16,
-    device_map="auto",
-)
-tokenizer = AutoTokenizer.from_pretrained(MODELS)
+# Load the base model and tokenizer
+base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
+
+# Load the LoRA fine-tuned weights
+model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
 
+# Inference function
 @spaces.GPU(duration=2)
 def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
-    print(f'message is - {message}')
-    print(f'history is - {history}')
     conversation = []
     for prompt, answer in history:
         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
     conversation.append({"role": "user", "content": message})
 
-    print(f"Conversation is -\n{conversation}")
-
+    # Generate input_ids with the chat template
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(input_ids, return_tensors="pt").to(0)
+    inputs = tokenizer(input_ids, return_tensors="pt").to("cuda")
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
+    # Set generation parameters
     generate_kwargs = dict(
         inputs,
         streamer=streamer,
         top_k=top_k,
         top_p=top_p,
         repetition_penalty=penalty,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
-        eos_token_id = [151645, 151643],
+        eos_token_id=[151645, 151643],
     )
 
+    # Start the generation thread
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
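Note on the hunk above: the base model is loaded with `device_map="auto"`, so the extra `model.to("cuda" ...)` is redundant and can fight accelerate's placement. A sketch of an alternative, under the assumption that the adapter can simply be merged for inference (`merge_and_unload()` is a standard PEFT method; whether merging is wanted here is a judgment call):

```python
# Sketch: let device_map="auto" handle placement, then optionally fold the
# LoRA deltas into the base weights so inference runs without PEFT overhead.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct", torch_dtype=torch.float16, device_map="auto"
)
model = PeftModel.from_pretrained(base_model, "QLWD/test-3b")
model = model.merge_and_unload()  # plain CausalLM afterwards; no extra .to(...) needed
```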
 
@@ -79,70 +77,35 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
         buffer += new_text
         yield buffer
 
-
-
+# Gradio interface
 chatbot = gr.Chatbot(height=450)
 
 with gr.Blocks(css=CSS) as demo:
     gr.HTML(TITLE)
     gr.HTML(DESCRIPTION)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+
     gr.ChatInterface(
         fn=stream_chat,
         chatbot=chatbot,
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.8,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=128,
-                maximum=4096,
-                step=1,
-                value=1024,
-                label="Max new tokens",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                step=0.1,
-                value=0.8,
-                label="top_p",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1,
-                maximum=20,
-                step=1,
-                value=20,
-                label="top_k",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=2.0,
-                step=0.1,
-                value=1.0,
-                label="Repetition penalty",
-                render=False,
-            ),
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
+            gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False),
+            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.8, label="top_p", render=False),
+            gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k", render=False),
+            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Repetition penalty", render=False),
         ],
         examples=[
-            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
-            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
-            ["Tell me a random fun fact about the Roman Empire."],
-            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
+            ["Help me generate a sentence about studying."],
+            ["Explain the concept of quantum computing."],
+            ["Give me some Python programming tips."],
+            ["Create a sticky header with CSS and JavaScript."],
         ],
         cache_examples=False,
     )
 
-
+# Launch the Gradio app
 if __name__ == "__main__":
     demo.launch()
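The hard-coded `eos_token_id=[151645, 151643]` corresponds, on Qwen2-family tokenizers, to `<|im_end|>` and `<|endoftext|>`; deriving the ids from the tokenizer keeps the code correct if `BASE_MODEL_ID` changes. A sketch (the token names are a Qwen-specific assumption):

```python
# Sketch: derive the stop ids instead of hard-coding 151645/151643.
stop_ids = [
    tokenizer.convert_tokens_to_ids("<|im_end|>"),  # 151645 on Qwen2 tokenizers
    tokenizer.eos_token_id,                         # 151643, <|endoftext|>
]
```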
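Only the tail of the streaming loop (`buffer += new_text` / `yield buffer`) appears in the hunk context above; for reference, the usual consumer pattern around a `TextIteratorStreamer` is sketched below as an excerpt of the `stream_chat` body. `model.generate` runs in the background thread while the generator drains the streamer and yields the growing reply to Gradio:

```python
# Excerpt sketch: inside stream_chat, after generate_kwargs is assembled.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

buffer = ""
for new_text in streamer:  # blocks until the next decoded chunk (or timeout)
    buffer += new_text
    yield buffer           # Gradio re-renders the partial reply on each yield
thread.join()              # ensure generation has finished cleanly
```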
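One last flag: `@spaces.GPU(duration=2)` requests a ZeroGPU slot for only two seconds, while `max_new_tokens` can go up to 4096, so long replies will likely be cut off mid-generation. The fix is a larger duration; the figure below is an illustrative assumption, not a measured requirement:

```python
@spaces.GPU(duration=120)  # seconds; size to the longest expected generation
def stream_chat(message, history, temperature, max_new_tokens, top_p, top_k, penalty):
    ...
```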