Amitontheweb commited on
Commit
7c242ca
·
verified ·
1 Parent(s): 78b413a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -6,12 +6,20 @@ import torch
6
  import gradio as gr
7
 
8
 
9
- # Load default model as GPT2
10
 
11
 
12
  tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
13
  model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
14
 
 
 
 
 
 
 
 
 
15
 
16
  # Define functions
17
 
@@ -129,13 +137,17 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
129
  def load_model(model_selected):
130
 
131
  if model_selected == "gpt2":
132
- tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
133
- model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", pad_token_id=tokenizer.eos_token_id)
134
  #print (model_selected + " loaded")
135
 
136
  if model_selected == "Gemma 2":
137
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
138
- model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
 
 
 
 
139
 
140
 
141
 
@@ -298,7 +310,7 @@ with gr.Blocks() as demo:
298
 
299
  with gr.Column (scale=0, min_width=200) as Models_Strategy:
300
 
301
- model_selected = gr.Radio (["gpt2", "Gemma 2"], label="ML Model", value="gpt2")
302
  strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
303
 
304
 
 
6
  import gradio as gr
7
 
8
 
9
+ # Load default model as GPT2 and other models
10
 
11
 
12
  tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
13
  model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
14
 
15
+ tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
16
+ model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
17
+
18
+ tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
19
+ model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
20
+
21
+ tokenizer_Mistral = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
22
+ model_Mistral = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
23
 
24
  # Define functions
25
 
 
137
  def load_model(model_selected):
138
 
139
  if model_selected == "gpt2":
140
+ tokenizer = tokenizer_gpt2
141
+ model = model_gpt2
142
  #print (model_selected + " loaded")
143
 
144
  if model_selected == "Gemma 2":
145
+ tokenizer = tokenizer_gemma
146
+ model = model_gemma
147
+
148
+ if model_selected == "Mistral":
149
+ tokenizer = tokenizer_Mistral
150
+ model = model_Mistral
151
 
152
 
153
 
 
310
 
311
  with gr.Column (scale=0, min_width=200) as Models_Strategy:
312
 
313
+ model_selected = gr.Radio (["gpt2", "Gemma 2", "Mistral"], label="ML Model", value="gpt2")
314
  strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
315
 
316