Amitontheweb committed
Commit 4792edb · verified · 1 Parent(s): 2cc0ba7

Update app.py

Files changed (1)
  app.py  +7 -7
app.py CHANGED
@@ -20,8 +20,8 @@ model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
 tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
 model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
 
-tokenizer_Mistral = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
-model_Mistral = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
+tokenizer_gpt_neo = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+model_gpt_neo = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
 
 # Define functions
 
@@ -138,7 +138,7 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
 
 def load_model(model_selected):
 
-    if model_selected == "gpt2":
+    if model_selected == "GPT2":
         tokenizer = tokenizer_gpt2
         model = model_gpt2
         #print (model_selected + " loaded")
@@ -147,9 +147,9 @@ def load_model(model_selected):
         tokenizer = tokenizer_gemma
         model = model_gemma
 
-    if model_selected == "Mistral":
-        tokenizer = tokenizer_Mistral
-        model = model_Mistral
+    if model_selected == "Eleuther GPT Neo 1.3B":
+        tokenizer = tokenizer_gpt_neo
+        model = model_gpt_neo
 
 
 
@@ -311,7 +311,7 @@ with gr.Blocks() as demo:
 
     with gr.Column (scale=0, min_width=200) as Models_Strategy:
 
-        model_selected = gr.Radio (["gpt2", "Gemma 2", "Mistral"], label="ML Model", value="gpt2")
+        model_selected = gr.Radio (["GPT2", "Gemma 2", "Eleuther GPT Neo 1.3B"], label="ML Model", value="GPT2")
         strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
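
The commit swaps the mislabeled "Mistral" names for an EleutherAI/gpt-neo-1.3B tokenizer/model pair. Below is a minimal standalone sketch, not part of the commit, for checking that this pair loads and generates; the prompt and generation settings are illustrative only.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the checkpoint the commit switches to; identifiers mirror app.py.
tokenizer_gpt_neo = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model_gpt_neo = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

# Quick greedy generation to confirm the pair works end to end.
inputs = tokenizer_gpt_neo("The sun rises in the", return_tensors="pt")
output_ids = model_gpt_neo.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer_gpt_neo.decode(output_ids[0], skip_special_tokens=True))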
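
Because load_model compares the raw string coming from the gr.Radio component, the radio's choices, its default value, and the strings tested in load_model must all match exactly ("GPT2", "Gemma 2", "Eleuther GPT Neo 1.3B"). The sketch below shows that wiring in isolation under that assumption; echo_choice, the Status textbox, and the change-event hookup are stand-ins, not the app's actual layout.

import gradio as gr

# Stand-in handler: the real app routes the selected label through load_model().
def echo_choice(model_selected):
    return "Selected model: " + model_selected

with gr.Blocks() as demo:
    model_selected = gr.Radio(
        ["GPT2", "Gemma 2", "Eleuther GPT Neo 1.3B"],
        label="ML Model",
        value="GPT2",  # default must be one of the choices above
    )
    status = gr.Textbox(label="Status")
    # Fires whenever the radio selection changes.
    model_selected.change(echo_choice, inputs=model_selected, outputs=status)

demo.launch()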