gabrielclark3330 committed
Commit 459aa64
1 Parent(s): ede06bd
instruct and base chat types
app.py
CHANGED
@@ -7,63 +7,113 @@ import torch
 model, model_instruct = None, None
 tokenizer, tokenizer_instruct = None, None
 
-
-
-
-
-
-
-
-
-
-                "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
-            )
-        selected_model = model
-        selected_tokenizer = tokenizer
-    else:
-        if model_instruct is None:  # Load only if not already loaded
-            tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
-            model_instruct = AutoModelForCausalLM.from_pretrained(
-                "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
-            )
-        selected_model = model_instruct
-        selected_tokenizer = tokenizer_instruct
+def generate_response_base(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
+    global model, tokenizer
+    if model is None:
+        tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
+        model = AutoModelForCausalLM.from_pretrained(
+            "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
+        )
+    selected_model = model
+    selected_tokenizer = tokenizer
 
     # Tokenize and generate response
     input_ids = selected_tokenizer(input_text, return_tensors="pt").input_ids.to(selected_model.device)
     outputs = selected_model.generate(
         input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=int(max_new_tokens),
         do_sample=True,
         temperature=temperature,
-        top_k=top_k,
+        top_k=int(top_k),
         top_p=top_p,
         repetition_penalty=repetition_penalty,
-        num_beams=num_beams,
+        num_beams=int(num_beams),
         length_penalty=length_penalty,
         num_return_sequences=1
     )
     response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
+    global model_instruct, tokenizer_instruct
+    if model_instruct is None:
+        tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
+        model_instruct = AutoModelForCausalLM.from_pretrained(
+            "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
+        )
+    selected_model = model_instruct
+    selected_tokenizer = tokenizer_instruct
+
+    # Build the sample
+    sample = []
+    for turn in chat_history:
+        if turn[0]:
+            sample.append({'role': 'user', 'content': turn[0]})
+        if turn[1]:
+            sample.append({'role': 'assistant', 'content': turn[1]})
+    # Format the chat sample
+    chat_sample = selected_tokenizer.apply_chat_template(sample, tokenize=False)
+    # Tokenize input and generate output
+    input_ids = selected_tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(selected_model.device)
+    outputs = selected_model.generate(
+        input_ids=input_ids,
+        max_new_tokens=int(max_new_tokens),
+        do_sample=True,
+        temperature=temperature,
+        top_k=int(top_k),
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        num_beams=int(num_beams),
+        length_penalty=length_penalty,
+        num_return_sequences=1
+    )
+    response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response
+
+def clear_text():
+    return ""
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Zamba2-7B Model Selector")
+    with gr.Tabs():
+        with gr.TabItem("Base Model"):
+            gr.Markdown("### Zamba2-7B Base Model")
+            input_text = gr.Textbox(lines=2, placeholder="Enter your input text...", label="Input Text")
+            output_text = gr.Textbox(label="Generated Response")
+            max_new_tokens = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+            temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
+            top_k = gr.Slider(1, 100, step=1, value=50, label="Top K")
+            top_p = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
+            repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
+            num_beams = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
+            length_penalty = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
+            submit_button = gr.Button("Generate Response")
+            submit_button.click(fn=generate_response_base, inputs=[input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty], outputs=output_text)
+            submit_button.click(fn=clear_text, outputs=input_text)
+        with gr.TabItem("Instruct Model"):
+            gr.Markdown("### Zamba2-7B Instruct Model")
+            chat_history = gr.Chatbot()
+            message = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+            max_new_tokens_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+            temperature_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
+            top_k_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
+            top_p_instruct = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
+            repetition_penalty_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
+            num_beams_instruct = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
+            length_penalty_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
+
+            def user_message(message, chat_history):
+                chat_history = chat_history + [[message, None]]
+                return "", chat_history
+
+            def bot_response(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
+                response = generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty)
+                chat_history[-1][1] = response
+                return chat_history
+
+            message.submit(user_message, [message, chat_history], [message, chat_history], queue=False).then(
+                bot_response, inputs=[chat_history, max_new_tokens_instruct, temperature_instruct, top_k_instruct, top_p_instruct, repetition_penalty_instruct, num_beams_instruct, length_penalty_instruct], outputs=[chat_history]
+            )
 
 if __name__ == "__main__":
     demo.launch()
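
The new instruct path converts the Gradio chat history (a list of [user, assistant] pairs) into the role-based message list that apply_chat_template expects before tokenizing and generating. Below is a minimal standalone sketch of that conversion, using a hypothetical two-turn history and assuming the same Zyphra/Zamba2-7B-instruct tokenizer and chat template the Space relies on:

from transformers import AutoTokenizer

# Hypothetical Gradio-style history: the last turn's assistant slot is still None,
# just as it is when the chat callback runs.
chat_history = [
    ["What is Zamba2?", "Zamba2 is a hybrid Mamba/transformer model from Zyphra."],
    ["How does the instruct variant differ from the base model?", None],
]

# Same conversion generate_response_instruct() performs
sample = []
for user_msg, assistant_msg in chat_history:
    if user_msg:
        sample.append({"role": "user", "content": user_msg})
    if assistant_msg:
        sample.append({"role": "assistant", "content": assistant_msg})

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
prompt = tokenizer.apply_chat_template(sample, tokenize=False)
print(prompt)  # the templated string that the Space then tokenizes and passes to generate()

In the Space itself, this string is tokenized with add_special_tokens=False and fed to selected_model.generate() with the slider-controlled sampling parameters.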