Amitontheweb commited on
Commit
12a5174
1 Parent(s): 48ddd53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -11,14 +11,14 @@ token = os.environ.get("HF_TOKEN")
11
  # Load default model as GPT2 and other models
12
 
13
 
14
- tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
15
- model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
16
 
17
  tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
18
  model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
19
 
20
- tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
21
- model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
22
 
23
  tokenizer_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
24
  model_qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
@@ -143,9 +143,9 @@ def load_model(model_selected):
143
  model = model_gpt2
144
  #print (model_selected + " loaded")
145
 
146
- if model_selected == "Gemma 2":
147
- tokenizer = tokenizer_gemma
148
- model = model_gemma
149
 
150
  if model_selected == "Qwen2":
151
  tokenizer = tokenizer_qwen
@@ -304,14 +304,14 @@ with gr.Blocks() as demo:
304
 
305
  No_beam_group_list = [2]
306
 
307
- #tokenizer = tokenizer_gpt2
308
- #model = model_gpt2
309
 
310
  with gr.Row():
311
 
312
  with gr.Column (scale=0, min_width=200) as Models_Strategy:
313
 
314
- model_selected = gr.Radio (["GPT2", "Gemma 2", "Qwen2"], label="ML Model", value="GPT2")
315
  strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
316
 
317
 
 
11
  # Load default model as GPT2 and other models
12
 
13
 
14
+ #tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
15
+ #model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
16
 
17
  tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
18
  model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
19
 
20
+ #tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
21
+ #model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
22
 
23
  tokenizer_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
24
  model_qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
 
143
  model = model_gpt2
144
  #print (model_selected + " loaded")
145
 
146
+ #if model_selected == "Gemma 2":
147
+ #tokenizer = tokenizer_gemma
148
+ #model = model_gemma
149
 
150
  if model_selected == "Qwen2":
151
  tokenizer = tokenizer_qwen
 
304
 
305
  No_beam_group_list = [2]
306
 
307
+ tokenizer = tokenizer_gpt2
308
+ model = model_gpt2
309
 
310
  with gr.Row():
311
 
312
  with gr.Column (scale=0, min_width=200) as Models_Strategy:
313
 
314
+ model_selected = gr.Radio (["GPT2", "Qwen2"], label="ML Model", value="GPT2")
315
  strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
316
 
317