Update
Browse files- README.md +1 -1
- app.py +101 -227
- model.py +0 -74
- requirements.txt +5 -5
- style.css +1 -1
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🦙
|
|
4 |
colorFrom: indigo
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: other
|
|
|
4 |
colorFrom: indigo
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.46.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: other
|
app.py
CHANGED
@@ -1,18 +1,21 @@
|
|
|
|
1 |
from typing import Iterator
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
11 |
MAX_MAX_NEW_TOKENS = 2048
|
12 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
13 |
-
MAX_INPUT_TOKEN_LENGTH =
|
14 |
|
15 |
-
DESCRIPTION = """
|
16 |
# Llama-2 13B Chat
|
17 |
|
18 |
This Space demonstrates model [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta, a Llama 2 model with 13B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
|
@@ -33,248 +36,119 @@ this demo is governed by the original [license](https://huggingface.co/spaces/hu
|
|
33 |
"""
|
34 |
|
35 |
if not torch.cuda.is_available():
|
36 |
-
DESCRIPTION +=
|
37 |
-
|
38 |
|
39 |
-
def clear_and_save_textbox(message: str) -> tuple[str, str]:
|
40 |
-
return '', message
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
try:
|
52 |
-
message, _ = history.pop()
|
53 |
-
except IndexError:
|
54 |
-
message = ''
|
55 |
-
return history, message or ''
|
56 |
|
57 |
|
58 |
def generate(
|
59 |
message: str,
|
60 |
-
|
61 |
system_prompt: str,
|
62 |
-
max_new_tokens: int,
|
63 |
-
temperature: float,
|
64 |
-
top_p: float,
|
65 |
-
top_k: int,
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
gr.Markdown(DESCRIPTION)
|
96 |
-
gr.DuplicateButton(value='Duplicate Space for private use',
|
97 |
-
elem_id='duplicate-button')
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
container=False,
|
104 |
-
show_label=False,
|
105 |
-
placeholder='Type a message...',
|
106 |
-
scale=10,
|
107 |
-
)
|
108 |
-
submit_button = gr.Button('Submit',
|
109 |
-
variant='primary',
|
110 |
-
scale=1,
|
111 |
-
min_width=0)
|
112 |
-
with gr.Row():
|
113 |
-
retry_button = gr.Button('🔄 Retry', variant='secondary')
|
114 |
-
undo_button = gr.Button('↩️ Undo', variant='secondary')
|
115 |
-
clear_button = gr.Button('🗑️ Clear', variant='secondary')
|
116 |
|
117 |
-
saved_input = gr.State()
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
label=
|
125 |
minimum=1,
|
126 |
maximum=MAX_MAX_NEW_TOKENS,
|
127 |
step=1,
|
128 |
value=DEFAULT_MAX_NEW_TOKENS,
|
129 |
-
)
|
130 |
-
|
131 |
-
label=
|
132 |
minimum=0.1,
|
133 |
maximum=4.0,
|
134 |
step=0.1,
|
135 |
-
value=
|
136 |
-
)
|
137 |
-
|
138 |
-
label=
|
139 |
minimum=0.05,
|
140 |
maximum=1.0,
|
141 |
step=0.05,
|
142 |
-
value=0.
|
143 |
-
)
|
144 |
-
|
145 |
-
label=
|
146 |
minimum=1,
|
147 |
maximum=1000,
|
148 |
step=1,
|
149 |
value=50,
|
150 |
-
)
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
gr.Markdown(LICENSE)
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
inputs=textbox,
|
171 |
-
outputs=[textbox, saved_input],
|
172 |
-
api_name=False,
|
173 |
-
queue=False,
|
174 |
-
).then(
|
175 |
-
fn=display_input,
|
176 |
-
inputs=[saved_input, chatbot],
|
177 |
-
outputs=chatbot,
|
178 |
-
api_name=False,
|
179 |
-
queue=False,
|
180 |
-
).then(
|
181 |
-
fn=check_input_token_length,
|
182 |
-
inputs=[saved_input, chatbot, system_prompt],
|
183 |
-
api_name=False,
|
184 |
-
queue=False,
|
185 |
-
).success(
|
186 |
-
fn=generate,
|
187 |
-
inputs=[
|
188 |
-
saved_input,
|
189 |
-
chatbot,
|
190 |
-
system_prompt,
|
191 |
-
max_new_tokens,
|
192 |
-
temperature,
|
193 |
-
top_p,
|
194 |
-
top_k,
|
195 |
-
],
|
196 |
-
outputs=chatbot,
|
197 |
-
api_name=False,
|
198 |
-
)
|
199 |
-
|
200 |
-
button_event_preprocess = submit_button.click(
|
201 |
-
fn=clear_and_save_textbox,
|
202 |
-
inputs=textbox,
|
203 |
-
outputs=[textbox, saved_input],
|
204 |
-
api_name=False,
|
205 |
-
queue=False,
|
206 |
-
).then(
|
207 |
-
fn=display_input,
|
208 |
-
inputs=[saved_input, chatbot],
|
209 |
-
outputs=chatbot,
|
210 |
-
api_name=False,
|
211 |
-
queue=False,
|
212 |
-
).then(
|
213 |
-
fn=check_input_token_length,
|
214 |
-
inputs=[saved_input, chatbot, system_prompt],
|
215 |
-
api_name=False,
|
216 |
-
queue=False,
|
217 |
-
).success(
|
218 |
-
fn=generate,
|
219 |
-
inputs=[
|
220 |
-
saved_input,
|
221 |
-
chatbot,
|
222 |
-
system_prompt,
|
223 |
-
max_new_tokens,
|
224 |
-
temperature,
|
225 |
-
top_p,
|
226 |
-
top_k,
|
227 |
-
],
|
228 |
-
outputs=chatbot,
|
229 |
-
api_name=False,
|
230 |
-
)
|
231 |
-
|
232 |
-
retry_button.click(
|
233 |
-
fn=delete_prev_fn,
|
234 |
-
inputs=chatbot,
|
235 |
-
outputs=[chatbot, saved_input],
|
236 |
-
api_name=False,
|
237 |
-
queue=False,
|
238 |
-
).then(
|
239 |
-
fn=display_input,
|
240 |
-
inputs=[saved_input, chatbot],
|
241 |
-
outputs=chatbot,
|
242 |
-
api_name=False,
|
243 |
-
queue=False,
|
244 |
-
).then(
|
245 |
-
fn=generate,
|
246 |
-
inputs=[
|
247 |
-
saved_input,
|
248 |
-
chatbot,
|
249 |
-
system_prompt,
|
250 |
-
max_new_tokens,
|
251 |
-
temperature,
|
252 |
-
top_p,
|
253 |
-
top_k,
|
254 |
-
],
|
255 |
-
outputs=chatbot,
|
256 |
-
api_name=False,
|
257 |
-
)
|
258 |
-
|
259 |
-
undo_button.click(
|
260 |
-
fn=delete_prev_fn,
|
261 |
-
inputs=chatbot,
|
262 |
-
outputs=[chatbot, saved_input],
|
263 |
-
api_name=False,
|
264 |
-
queue=False,
|
265 |
-
).then(
|
266 |
-
fn=lambda x: x,
|
267 |
-
inputs=[saved_input],
|
268 |
-
outputs=textbox,
|
269 |
-
api_name=False,
|
270 |
-
queue=False,
|
271 |
-
)
|
272 |
-
|
273 |
-
clear_button.click(
|
274 |
-
fn=lambda: ([], ''),
|
275 |
-
outputs=[chatbot, saved_input],
|
276 |
-
queue=False,
|
277 |
-
api_name=False,
|
278 |
-
)
|
279 |
-
|
280 |
-
demo.queue(max_size=20).launch()
|
|
|
1 |
+
from threading import Thread
|
2 |
from typing import Iterator
|
3 |
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
+
from transformers import (
|
7 |
+
AutoConfig,
|
8 |
+
AutoModelForCausalLM,
|
9 |
+
AutoTokenizer,
|
10 |
+
TextIteratorStreamer,
|
11 |
+
)
|
12 |
+
|
13 |
+
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
14 |
MAX_MAX_NEW_TOKENS = 2048
|
15 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
16 |
+
MAX_INPUT_TOKEN_LENGTH = 4096
|
17 |
|
18 |
+
DESCRIPTION = """\
|
19 |
# Llama-2 13B Chat
|
20 |
|
21 |
This Space demonstrates model [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta, a Llama 2 model with 13B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
|
|
|
36 |
"""
|
37 |
|
38 |
if not torch.cuda.is_available():
|
39 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
|
|
40 |
|
|
|
|
|
41 |
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
model_id = "meta-llama/Llama-2-13b-chat-hf"
|
44 |
+
config = AutoConfig.from_pretrained(model_id)
|
45 |
+
config.pretraining_tp = 1
|
46 |
+
model = AutoModelForCausalLM.from_pretrained(
|
47 |
+
model_id, config=config, torch_dtype=torch.float16, load_in_4bit=True, device_map="auto"
|
48 |
+
)
|
49 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
50 |
+
tokenizer.use_default_system_prompt = False
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
|
53 |
def generate(
|
54 |
message: str,
|
55 |
+
chat_history: list[tuple[str, str]],
|
56 |
system_prompt: str,
|
57 |
+
max_new_tokens: int = 1024,
|
58 |
+
temperature: float = 0.6,
|
59 |
+
top_p: float = 0.9,
|
60 |
+
top_k: int = 50,
|
61 |
+
repetition_penalty: float = 1.2,
|
62 |
+
) -> Iterator[str]:
|
63 |
+
conversation = []
|
64 |
+
if system_prompt:
|
65 |
+
conversation.append({"role": "system", "content": system_prompt})
|
66 |
+
for user, assistant in chat_history:
|
67 |
+
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
|
68 |
+
conversation.append({"role": "user", "content": message})
|
69 |
+
|
70 |
+
chat = tokenizer.apply_chat_template(conversation, tokenize=False)
|
71 |
+
inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to("cuda")
|
72 |
+
if len(inputs) > MAX_INPUT_TOKEN_LENGTH:
|
73 |
+
inputs = inputs[-MAX_INPUT_TOKEN_LENGTH:]
|
74 |
+
gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
75 |
+
|
76 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
77 |
+
generate_kwargs = dict(
|
78 |
+
inputs,
|
79 |
+
streamer=streamer,
|
80 |
+
max_new_tokens=max_new_tokens,
|
81 |
+
do_sample=True,
|
82 |
+
top_p=top_p,
|
83 |
+
top_k=top_k,
|
84 |
+
temperature=temperature,
|
85 |
+
num_beams=1,
|
86 |
+
repetition_penalty=repetition_penalty,
|
87 |
+
)
|
88 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
89 |
+
t.start()
|
|
|
|
|
|
|
90 |
|
91 |
+
outputs = []
|
92 |
+
for text in streamer:
|
93 |
+
outputs.append(text)
|
94 |
+
yield "".join(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
|
|
96 |
|
97 |
+
chat_interface = gr.ChatInterface(
|
98 |
+
fn=generate,
|
99 |
+
additional_inputs=[
|
100 |
+
gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6),
|
101 |
+
gr.Slider(
|
102 |
+
label="Max new tokens",
|
103 |
minimum=1,
|
104 |
maximum=MAX_MAX_NEW_TOKENS,
|
105 |
step=1,
|
106 |
value=DEFAULT_MAX_NEW_TOKENS,
|
107 |
+
),
|
108 |
+
gr.Slider(
|
109 |
+
label="Temperature",
|
110 |
minimum=0.1,
|
111 |
maximum=4.0,
|
112 |
step=0.1,
|
113 |
+
value=0.6,
|
114 |
+
),
|
115 |
+
gr.Slider(
|
116 |
+
label="Top-p (nucleus sampling)",
|
117 |
minimum=0.05,
|
118 |
maximum=1.0,
|
119 |
step=0.05,
|
120 |
+
value=0.9,
|
121 |
+
),
|
122 |
+
gr.Slider(
|
123 |
+
label="Top-k",
|
124 |
minimum=1,
|
125 |
maximum=1000,
|
126 |
step=1,
|
127 |
value=50,
|
128 |
+
),
|
129 |
+
gr.Slider(
|
130 |
+
label="Repetition penalty",
|
131 |
+
minimum=1.0,
|
132 |
+
maximum=2.0,
|
133 |
+
step=0.05,
|
134 |
+
value=1.2,
|
135 |
+
),
|
136 |
+
],
|
137 |
+
stop_btn=None,
|
138 |
+
examples=[
|
139 |
+
["Hello there! How are you doing?"],
|
140 |
+
["Can you explain briefly to me what is the Python programming language?"],
|
141 |
+
["Explain the plot of Cinderella in a sentence."],
|
142 |
+
["How many hours does it take a man to eat a Helicopter?"],
|
143 |
+
["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
|
144 |
+
],
|
145 |
+
)
|
146 |
+
|
147 |
+
with gr.Blocks(css="style.css") as demo:
|
148 |
+
gr.Markdown(DESCRIPTION)
|
149 |
+
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
|
150 |
+
chat_interface.render()
|
151 |
gr.Markdown(LICENSE)
|
152 |
|
153 |
+
if __name__ == "__main__":
|
154 |
+
demo.queue(max_size=20).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.py
DELETED
@@ -1,74 +0,0 @@
|
|
1 |
-
from threading import Thread
|
2 |
-
from typing import Iterator
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
6 |
-
|
7 |
-
model_id = 'meta-llama/Llama-2-13b-chat-hf'
|
8 |
-
|
9 |
-
if torch.cuda.is_available():
|
10 |
-
config = AutoConfig.from_pretrained(model_id)
|
11 |
-
config.pretraining_tp = 1
|
12 |
-
model = AutoModelForCausalLM.from_pretrained(
|
13 |
-
model_id,
|
14 |
-
config=config,
|
15 |
-
torch_dtype=torch.float16,
|
16 |
-
load_in_4bit=True,
|
17 |
-
device_map='auto'
|
18 |
-
)
|
19 |
-
else:
|
20 |
-
model = None
|
21 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
22 |
-
|
23 |
-
|
24 |
-
def get_prompt(message: str, chat_history: list[tuple[str, str]],
|
25 |
-
system_prompt: str) -> str:
|
26 |
-
texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
|
27 |
-
# The first user input is _not_ stripped
|
28 |
-
do_strip = False
|
29 |
-
for user_input, response in chat_history:
|
30 |
-
user_input = user_input.strip() if do_strip else user_input
|
31 |
-
do_strip = True
|
32 |
-
texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
|
33 |
-
message = message.strip() if do_strip else message
|
34 |
-
texts.append(f'{message} [/INST]')
|
35 |
-
return ''.join(texts)
|
36 |
-
|
37 |
-
|
38 |
-
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
|
39 |
-
prompt = get_prompt(message, chat_history, system_prompt)
|
40 |
-
input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
|
41 |
-
return input_ids.shape[-1]
|
42 |
-
|
43 |
-
|
44 |
-
def run(message: str,
|
45 |
-
chat_history: list[tuple[str, str]],
|
46 |
-
system_prompt: str,
|
47 |
-
max_new_tokens: int = 1024,
|
48 |
-
temperature: float = 0.8,
|
49 |
-
top_p: float = 0.95,
|
50 |
-
top_k: int = 50) -> Iterator[str]:
|
51 |
-
prompt = get_prompt(message, chat_history, system_prompt)
|
52 |
-
inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
|
53 |
-
|
54 |
-
streamer = TextIteratorStreamer(tokenizer,
|
55 |
-
timeout=10.,
|
56 |
-
skip_prompt=True,
|
57 |
-
skip_special_tokens=True)
|
58 |
-
generate_kwargs = dict(
|
59 |
-
inputs,
|
60 |
-
streamer=streamer,
|
61 |
-
max_new_tokens=max_new_tokens,
|
62 |
-
do_sample=True,
|
63 |
-
top_p=top_p,
|
64 |
-
top_k=top_k,
|
65 |
-
temperature=temperature,
|
66 |
-
num_beams=1,
|
67 |
-
)
|
68 |
-
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
69 |
-
t.start()
|
70 |
-
|
71 |
-
outputs = []
|
72 |
-
for text in streamer:
|
73 |
-
outputs.append(text)
|
74 |
-
yield ''.join(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
accelerate==0.
|
2 |
-
bitsandbytes==0.
|
3 |
-
gradio==3.
|
4 |
protobuf==3.20.3
|
5 |
-
scipy==1.11.
|
6 |
sentencepiece==0.1.99
|
7 |
torch==2.0.1
|
8 |
-
transformers==4.
|
|
|
1 |
+
accelerate==0.23.0
|
2 |
+
bitsandbytes==0.41.1
|
3 |
+
gradio==3.46.0
|
4 |
protobuf==3.20.3
|
5 |
+
scipy==1.11.2
|
6 |
sentencepiece==0.1.99
|
7 |
torch==2.0.1
|
8 |
+
transformers==4.34.0
|
style.css
CHANGED
@@ -9,7 +9,7 @@ h1 {
|
|
9 |
border-radius: 100vh;
|
10 |
}
|
11 |
|
12 |
-
|
13 |
max-width: 900px;
|
14 |
margin: auto;
|
15 |
padding-top: 1.5rem;
|
|
|
9 |
border-radius: 100vh;
|
10 |
}
|
11 |
|
12 |
+
.contain {
|
13 |
max-width: 900px;
|
14 |
margin: auto;
|
15 |
padding-top: 1.5rem;
|