gabrielclark3330 committed on
Commit 459aa64
Parent: ede06bd

instruct and base chat types

Files changed (1)
app.py: +92 −42
app.py CHANGED
@@ -7,63 +7,113 @@ import torch
 model, model_instruct = None, None
 tokenizer, tokenizer_instruct = None, None

-# Define the response function with lazy loading
-def generate_response(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty, model_choice):
-    global model, model_instruct, tokenizer, tokenizer_instruct
-
-    # Lazy loading of the selected model
-    if model_choice == "Zamba2-7B":
-        if model is None:  # Load only if not already loaded
-            tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
-            model = AutoModelForCausalLM.from_pretrained(
-                "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
-            )
-        selected_model = model
-        selected_tokenizer = tokenizer
-    else:
-        if model_instruct is None:  # Load only if not already loaded
-            tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
-            model_instruct = AutoModelForCausalLM.from_pretrained(
-                "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
-            )
-        selected_model = model_instruct
-        selected_tokenizer = tokenizer_instruct
+def generate_response_base(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
+    global model, tokenizer
+    if model is None:
+        tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
+        model = AutoModelForCausalLM.from_pretrained(
+            "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
+        )
+    selected_model = model
+    selected_tokenizer = tokenizer

     # Tokenize and generate response
     input_ids = selected_tokenizer(input_text, return_tensors="pt").input_ids.to(selected_model.device)
     outputs = selected_model.generate(
         input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=int(max_new_tokens),
         do_sample=True,
         temperature=temperature,
-        top_k=top_k,
+        top_k=int(top_k),
         top_p=top_p,
         repetition_penalty=repetition_penalty,
-        num_beams=num_beams,
+        num_beams=int(num_beams),
         length_penalty=length_penalty,
         num_return_sequences=1
     )
     response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response

-# Gradio interface with model selection
-demo = gr.Interface(
-    fn=generate_response,
-    inputs=[
-        gr.Textbox(lines=1, placeholder="Enter your input text...", label="Input Text"),
-        gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens"),
-        gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature"),
-        gr.Slider(1, 100, step=1, value=50, label="Top K"),
-        gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P"),
-        gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty"),
-        gr.Slider(1, 10, step=1, value=5, label="Number of Beams"),
-        gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty"),
-        gr.Dropdown(["Zamba2-7B", "Zamba2-7B-instruct"], label="Model Choice")
-    ],
-    outputs=gr.Textbox(label="Generated Response"),
-    title="Zamba2-7B Model Selector",
-    description="Choose a model and ask a question with customizable parameters."
-)
+def generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
+    global model_instruct, tokenizer_instruct
+    if model_instruct is None:
+        tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
+        model_instruct = AutoModelForCausalLM.from_pretrained(
+            "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
+        )
+    selected_model = model_instruct
+    selected_tokenizer = tokenizer_instruct
+
+    # Build the sample
+    sample = []
+    for turn in chat_history:
+        if turn[0]:
+            sample.append({'role': 'user', 'content': turn[0]})
+        if turn[1]:
+            sample.append({'role': 'assistant', 'content': turn[1]})
+    # Format the chat sample
+    chat_sample = selected_tokenizer.apply_chat_template(sample, tokenize=False)
+    # Tokenize input and generate output
+    input_ids = selected_tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(selected_model.device)
+    outputs = selected_model.generate(
+        input_ids=input_ids,
+        max_new_tokens=int(max_new_tokens),
+        do_sample=True,
+        temperature=temperature,
+        top_k=int(top_k),
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        num_beams=int(num_beams),
+        length_penalty=length_penalty,
+        num_return_sequences=1
+    )
+    response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response
+
+def clear_text():
+    return ""
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Zamba2-7B Model Selector")
+    with gr.Tabs():
+        with gr.TabItem("Base Model"):
+            gr.Markdown("### Zamba2-7B Base Model")
+            input_text = gr.Textbox(lines=2, placeholder="Enter your input text...", label="Input Text")
+            output_text = gr.Textbox(label="Generated Response")
+            max_new_tokens = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+            temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
+            top_k = gr.Slider(1, 100, step=1, value=50, label="Top K")
+            top_p = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
+            repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
+            num_beams = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
+            length_penalty = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
+            submit_button = gr.Button("Generate Response")
+            submit_button.click(fn=generate_response_base, inputs=[input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty], outputs=output_text)
+            submit_button.click(fn=clear_text, outputs=input_text)
+        with gr.TabItem("Instruct Model"):
+            gr.Markdown("### Zamba2-7B Instruct Model")
+            chat_history = gr.Chatbot()
+            message = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+            max_new_tokens_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+            temperature_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
+            top_k_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
+            top_p_instruct = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
+            repetition_penalty_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
+            num_beams_instruct = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
+            length_penalty_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
+
+            def user_message(message, chat_history):
+                chat_history = chat_history + [[message, None]]
+                return "", chat_history
+
+            def bot_response(chat_history):
+                response = generate_response_instruct(chat_history, max_new_tokens_instruct, temperature_instruct, top_k_instruct, top_p_instruct, repetition_penalty_instruct, num_beams_instruct, length_penalty_instruct)
+                chat_history[-1][1] = response
+                return chat_history
+
+            message.submit(user_message, [message, chat_history], [message, chat_history], queue=False).then(
+                bot_response, inputs=[chat_history], outputs=[chat_history]
+            )

 if __name__ == "__main__":
     demo.launch()
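Both generate functions lazy-load their weights through module globals, so the first request pays the load cost and later requests reuse the cached model. Gradio can run handlers concurrently, and two first requests arriving together could each see `model is None` and load the weights twice; if that matters for the Space, a module-level lock closes the gap. A minimal sketch under that assumption (the `ensure_base_loaded` helper and the lock are hypothetical, not part of this commit):

import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model, tokenizer = None, None
_load_lock = threading.Lock()  # hypothetical guard, not in the commit

def ensure_base_loaded():
    """Load Zyphra/Zamba2-7B at most once, even under concurrent requests."""
    global model, tokenizer
    with _load_lock:
        if model is None:
            # Mirrors the from_pretrained call in the committed code
            tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
            model = AutoModelForCausalLM.from_pretrained(
                "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
            )
    return model, tokenizer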
 
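On the new instruct path: `gr.Chatbot` hands the handler its history as a list of `[user, assistant]` pairs, with `None` in the assistant slot of the final pair while the reply is pending, and `apply_chat_template` renders the accumulated turns into the single prompt string the model sees. A standalone sketch of that conversion (the example history is invented; the template call mirrors the committed code):

from transformers import AutoTokenizer

# gr.Chatbot-style history: [user, assistant] pairs; the last assistant
# slot is still None while the reply is being generated.
history = [["Hi, who are you?", "I'm a helpful assistant."],
           ["Tell me a joke.", None]]

sample = []
for user_msg, bot_msg in history:
    if user_msg:
        sample.append({'role': 'user', 'content': user_msg})
    if bot_msg:
        sample.append({'role': 'assistant', 'content': bot_msg})

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
# Render the turns into one prompt string; the chat template supplies the
# model's special tokens itself.
chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)
print(chat_sample)

The rendered string already carries the template's special tokens, which is presumably why the commit tokenizes it with `add_special_tokens=False`; likewise, the `int()` casts added around `max_new_tokens`, `top_k`, and `num_beams` guard against Gradio sliders delivering floats where `generate` expects integers.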
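One caveat in the instruct tab as committed: `bot_response` lists only `chat_history` in `inputs=`, so the `*_instruct` names it forwards are the `gr.Slider` component objects from the enclosing scope rather than their current values, and `int(max_new_tokens)` inside `generate_response_instruct` will fail on a component object. Gradio only materializes a component's value when the component appears in an event's `inputs` list. A sketch of the likely fix, reusing the commit's own names (not tested against this Space):

# Sketch: take the slider values as parameters and route the sliders
# through the event's inputs so Gradio passes values, not components.
def bot_response(chat_history, max_new_tokens, temperature, top_k, top_p,
                 repetition_penalty, num_beams, length_penalty):
    response = generate_response_instruct(
        chat_history, max_new_tokens, temperature, top_k, top_p,
        repetition_penalty, num_beams, length_penalty,
    )
    chat_history[-1][1] = response
    return chat_history

message.submit(user_message, [message, chat_history], [message, chat_history], queue=False).then(
    bot_response,
    inputs=[chat_history, max_new_tokens_instruct, temperature_instruct,
            top_k_instruct, top_p_instruct, repetition_penalty_instruct,
            num_beams_instruct, length_penalty_instruct],
    outputs=[chat_history],
)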