gabrielclark3330 committed on
Commit
bcc5c70
1 Parent(s): 459aa64

Use instruct models for the 2.7B and 7B sizes

Files changed (1)
  app.py +123 -71
app.py CHANGED
@@ -3,26 +3,50 @@ import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

- # Define models as None to delay loading
- model, model_instruct = None, None
- tokenizer, tokenizer_instruct = None, None

- def generate_response_base(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
-     global model, tokenizer
-     if model is None:
-         tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
-         model = AutoModelForCausalLM.from_pretrained(
-             "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
-         )
-     selected_model = model
-     selected_tokenizer = tokenizer

-     # Tokenize and generate response
-     input_ids = selected_tokenizer(input_text, return_tensors="pt").input_ids.to(selected_model.device)
-     outputs = selected_model.generate(
          input_ids=input_ids,
          max_new_tokens=int(max_new_tokens),
          do_sample=True,
          temperature=temperature,
          top_k=int(top_k),
          top_p=top_p,
@@ -31,34 +55,27 @@ def generate_response_base(input_text, max_new_tokens, temperature, top_k, top_p
          length_penalty=length_penalty,
          num_return_sequences=1
      )
-     response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return response

- def generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
-     global model_instruct, tokenizer_instruct
-     if model_instruct is None:
-         tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
-         model_instruct = AutoModelForCausalLM.from_pretrained(
-             "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
-         )
-     selected_model = model_instruct
-     selected_tokenizer = tokenizer_instruct
-
-     # Build the sample
      sample = []
      for turn in chat_history:
          if turn[0]:
              sample.append({'role': 'user', 'content': turn[0]})
          if turn[1]:
              sample.append({'role': 'assistant', 'content': turn[1]})
-     # Format the chat sample
-     chat_sample = selected_tokenizer.apply_chat_template(sample, tokenize=False)
-     # Tokenize input and generate output
-     input_ids = selected_tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(selected_model.device)
-     outputs = selected_model.generate(
          input_ids=input_ids,
          max_new_tokens=int(max_new_tokens),
          do_sample=True,
          temperature=temperature,
          top_k=int(top_k),
          top_p=top_p,
@@ -67,53 +84,88 @@ def generate_response_instruct(chat_history, max_new_tokens, temperature, top_k,
          length_penalty=length_penalty,
          num_return_sequences=1
      )
-     response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return response
-
- def clear_text():
-     return ""

  with gr.Blocks() as demo:
-     gr.Markdown("# Zamba2-7B Model Selector")
      with gr.Tabs():
-         with gr.TabItem("Base Model"):
-             gr.Markdown("### Zamba2-7B Base Model")
-             input_text = gr.Textbox(lines=2, placeholder="Enter your input text...", label="Input Text")
-             output_text = gr.Textbox(label="Generated Response")
-             max_new_tokens = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-             temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
-             top_k = gr.Slider(1, 100, step=1, value=50, label="Top K")
-             top_p = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
-             repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
-             num_beams = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
-             length_penalty = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
-             submit_button = gr.Button("Generate Response")
-             submit_button.click(fn=generate_response_base, inputs=[input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty], outputs=output_text)
-             submit_button.click(fn=clear_text, outputs=input_text)
-         with gr.TabItem("Instruct Model"):
              gr.Markdown("### Zamba2-7B Instruct Model")
-             chat_history = gr.Chatbot()
-             message = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
-             max_new_tokens_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-             temperature_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
-             top_k_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
-             top_p_instruct = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
-             repetition_penalty_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
-             num_beams_instruct = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
-             length_penalty_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

-             def user_message(message, chat_history):
                  chat_history = chat_history + [[message, None]]
-                 return "", chat_history

-             def bot_response(chat_history):
-                 response = generate_response_instruct(chat_history, max_new_tokens_instruct, temperature_instruct, top_k_instruct, top_p_instruct, repetition_penalty_instruct, num_beams_instruct, length_penalty_instruct)
                  chat_history[-1][1] = response
-                 return chat_history

-             message.submit(user_message, [message, chat_history], [message, chat_history], queue=False).then(
-                 bot_response, inputs=[chat_history], outputs=[chat_history]
              )

  if __name__ == "__main__":
-     demo.launch()
app.py (new version; added lines marked with +)
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

+ model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
+ model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"

+ tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
+ model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
+     model_name_2_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
+ )

+ tokenizer_7B_instruct = AutoTokenizer.from_pretrained(model_name_7B_instruct)
+ model_7B_instruct = AutoModelForCausalLM.from_pretrained(
+     model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
+ )
+
+ def extract_assistant_response(generated_text):
+     assistant_token = '<|im_start|> assistant'
+     end_token = '<|im_end|>'
+     start_idx = generated_text.rfind(assistant_token)
+     if start_idx == -1:
+         # Assistant token not found
+         return generated_text.strip()
+     start_idx += len(assistant_token)
+     end_idx = generated_text.find(end_token, start_idx)
+     if end_idx == -1:
+         # End token not found, return from start_idx to end
+         return generated_text[start_idx:].strip()
+     else:
+         return generated_text[start_idx:end_idx].strip()
+
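extract_assistant_response keeps only the text of the most recent assistant turn in the decoded output. A minimal sketch of its behaviour, assuming the decoded string carries the '<|im_start|> assistant' and '<|im_end|>' markers exactly as the function expects (the surrounding text is made up for illustration, not taken from the Zamba2 chat template):

    # Hypothetical decoded generation; only the marker strings are assumed.
    decoded = (
        "<|im_start|> user\nWhat is Zamba2?<|im_end|>\n"
        "<|im_start|> assistant\nZamba2 is a language model from Zyphra.<|im_end|>"
    )
    print(extract_assistant_response(decoded))
    # -> "Zamba2 is a language model from Zyphra."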
+ def generate_response_2_7B_instruct(chat_history, max_new_tokens):
+     sample = []
+     for turn in chat_history:
+         if turn[0]:
+             sample.append({'role': 'user', 'content': turn[0]})
+         if turn[1]:
+             sample.append({'role': 'assistant', 'content': turn[1]})
+     chat_sample = tokenizer_2_7B_instruct.apply_chat_template(sample, tokenize=False)
+     input_ids = tokenizer_2_7B_instruct(chat_sample, return_tensors='pt', add_special_tokens=False).to(model_2_7B_instruct.device)
+     outputs = model_2_7B_instruct.generate(**input_ids, max_new_tokens=int(max_new_tokens), return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
+     """
+     outputs = model_2_7B_instruct.generate(
          input_ids=input_ids,
          max_new_tokens=int(max_new_tokens),
          do_sample=True,
+         use_cache=True,
          temperature=temperature,
          top_k=int(top_k),
          top_p=top_p,

          length_penalty=length_penalty,
          num_return_sequences=1
      )
+     """
+     generated_text = tokenizer_2_7B_instruct.decode(outputs[0])
+     assistant_response = extract_assistant_response(generated_text)
+     return assistant_response

+ def generate_response_7B_instruct(chat_history, max_new_tokens):
      sample = []
      for turn in chat_history:
          if turn[0]:
              sample.append({'role': 'user', 'content': turn[0]})
          if turn[1]:
              sample.append({'role': 'assistant', 'content': turn[1]})
+     chat_sample = tokenizer_7B_instruct.apply_chat_template(sample, tokenize=False)
+     input_ids = tokenizer_7B_instruct(chat_sample, return_tensors='pt', add_special_tokens=False).to(model_7B_instruct.device)
+     outputs = model_7B_instruct.generate(**input_ids, max_new_tokens=int(max_new_tokens), return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
+     """
+     outputs = model_7B_instruct.generate(
          input_ids=input_ids,
          max_new_tokens=int(max_new_tokens),
          do_sample=True,
+         use_cache=True,
          temperature=temperature,
          top_k=int(top_k),
          top_p=top_p,

          length_penalty=length_penalty,
          num_return_sequences=1
      )
+     """
+     generated_text = tokenizer_7B_instruct.decode(outputs[0])
+     assistant_response = extract_assistant_response(generated_text)
+     return assistant_response
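Both generate_response_* helpers flatten the Gradio chat history (a list of [user, assistant] pairs) into the role/content messages that apply_chat_template expects, then decode greedily (num_beams=1, do_sample=False); the old sampling arguments survive only in the commented-out block. A small sketch of just the history-to-messages conversion, on made-up history and with no model call:

    # Hypothetical history in the format gr.Chatbot / gr.State store it:
    # [user, assistant] pairs, with None for the reply still to be generated.
    chat_history = [
        ["Hello!", "Hi, how can I help?"],
        ["Summarize Zamba2 in one line.", None],
    ]

    sample = []
    for turn in chat_history:
        if turn[0]:
            sample.append({'role': 'user', 'content': turn[0]})
        if turn[1]:
            sample.append({'role': 'assistant', 'content': turn[1]})

    # sample == [
    #     {'role': 'user', 'content': 'Hello!'},
    #     {'role': 'assistant', 'content': 'Hi, how can I help?'},
    #     {'role': 'user', 'content': 'Summarize Zamba2 in one line.'},
    # ]
    # apply_chat_template(sample, tokenize=False) then renders the prompt string
    # that is tokenized and passed to model.generate.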
 
  with gr.Blocks() as demo:
+     gr.Markdown("# Zamba2 Model Selector")
      with gr.Tabs():
+         with gr.TabItem("2.7B Instruct Model"):
+             gr.Markdown("### Zamba2-2.7B Instruct Model")
+             with gr.Column():
+                 chat_history_2_7B_instruct = gr.State([])
+                 chatbot_2_7B_instruct = gr.Chatbot()
+                 message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+                 with gr.Accordion("Generation Parameters", open=False):
+                     max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+                     # temperature_2_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
+                     # top_k_2_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
+                     # top_p_2_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
+                     # repetition_penalty_2_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
+                     # num_beams_2_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
+                     # length_penalty_2_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
+
+                 def user_message_2_7B_instruct(message, chat_history):
                      chat_history = chat_history + [[message, None]]
+                     return gr.update(value=""), chat_history, chat_history
+
+                 def bot_response_2_7B_instruct(chat_history, max_new_tokens):
+                     response = generate_response_2_7B_instruct(chat_history, max_new_tokens)
+                     chat_history[-1][1] = response
+                     return chat_history, chat_history
+
+                 send_button_2_7B_instruct = gr.Button("Send")
+                 send_button_2_7B_instruct.click(
+                     fn=user_message_2_7B_instruct,
+                     inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
+                     outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
+                 ).then(
+                     fn=bot_response_2_7B_instruct,
+                     inputs=[
+                         chat_history_2_7B_instruct,
+                         max_new_tokens_2_7B_instruct
+                     ],
+                     outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
+                 )
+         with gr.TabItem("7B Instruct Model"):
              gr.Markdown("### Zamba2-7B Instruct Model")
+             with gr.Column():
+                 chat_history_7B_instruct = gr.State([])
+                 chatbot_7B_instruct = gr.Chatbot()
+                 message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+                 with gr.Accordion("Generation Parameters", open=False):
+                     max_new_tokens_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+                     # temperature_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
+                     # top_k_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
+                     # top_p_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
+                     # repetition_penalty_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
+                     # num_beams_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
+                     # length_penalty_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

+                 def user_message_7B_instruct(message, chat_history):
                      chat_history = chat_history + [[message, None]]
+                     return gr.update(value=""), chat_history, chat_history

+                 def bot_response_7B_instruct(chat_history, max_new_tokens):
+                     response = generate_response_7B_instruct(chat_history, max_new_tokens)
                      chat_history[-1][1] = response
+                     return chat_history, chat_history

+                 send_button_7B_instruct = gr.Button("Send")
+                 send_button_7B_instruct.click(
+                     fn=user_message_7B_instruct,
+                     inputs=[message_7B_instruct, chat_history_7B_instruct],
+                     outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
+                 ).then(
+                     fn=bot_response_7B_instruct,
+                     inputs=[
+                         chat_history_7B_instruct,
+                         max_new_tokens_7B_instruct
+                     ],
+                     outputs=[chat_history_7B_instruct, chatbot_7B_instruct]
                  )

  if __name__ == "__main__":
+     demo.queue().launch()
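Both tabs wire the UI the same way: the Send button first appends the user message to a gr.State history and clears the textbox, then a chained .then() call generates the reply and refreshes the Chatbot. A stripped-down sketch of that pattern, with a hypothetical echo_bot standing in for the generate_response_* calls:

    import gradio as gr

    def add_user_message(message, history):
        # Append the new user turn, clear the textbox, update state and display.
        history = history + [[message, None]]
        return gr.update(value=""), history, history

    def echo_bot(history):
        # Placeholder for generate_response_2_7B_instruct / generate_response_7B_instruct.
        history[-1][1] = f"You said: {history[-1][0]}"
        return history, history

    with gr.Blocks() as demo:
        history_state = gr.State([])
        chatbot = gr.Chatbot()
        message = gr.Textbox(lines=2, label="Your Message")
        send = gr.Button("Send")
        send.click(
            fn=add_user_message,
            inputs=[message, history_state],
            outputs=[message, history_state, chatbot],
        ).then(
            fn=echo_bot,
            inputs=[history_state],
            outputs=[history_state, chatbot],
        )

    if __name__ == "__main__":
        demo.queue().launch()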