Spaces:
Sleeping
Sleeping
Amitontheweb
committed on
Commit
•
12a5174
1
Parent(s):
48ddd53
Update app.py
Browse files
app.py
CHANGED
@@ -11,14 +11,14 @@ token = os.environ.get("HF_TOKEN")
|
|
11 |
# Load default model as GPT2 and other models
|
12 |
|
13 |
|
14 |
-
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
15 |
-
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
16 |
|
17 |
tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
18 |
model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
19 |
|
20 |
-
tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
|
21 |
-
model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
|
22 |
|
23 |
tokenizer_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
|
24 |
model_qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
|
@@ -143,9 +143,9 @@ def load_model(model_selected):
|
|
143 |
model = model_gpt2
|
144 |
#print (model_selected + " loaded")
|
145 |
|
146 |
-
if model_selected == "Gemma 2":
|
147 |
-
tokenizer = tokenizer_gemma
|
148 |
-
model = model_gemma
|
149 |
|
150 |
if model_selected == "Qwen2":
|
151 |
tokenizer = tokenizer_qwen
|
@@ -304,14 +304,14 @@ with gr.Blocks() as demo:
|
|
304 |
|
305 |
No_beam_group_list = [2]
|
306 |
|
307 |
-
|
308 |
-
|
309 |
|
310 |
with gr.Row():
|
311 |
|
312 |
with gr.Column (scale=0, min_width=200) as Models_Strategy:
|
313 |
|
314 |
-
model_selected = gr.Radio (["GPT2", "Gemma 2", "Qwen2"], label="ML Model", value="GPT2")
|
315 |
strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
|
316 |
|
317 |
|
|
|
11 |
# Load default model as GPT2 and other models
|
12 |
|
13 |
|
14 |
+
#tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
15 |
+
#model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
16 |
|
17 |
tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
18 |
model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
19 |
|
20 |
+
#tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
|
21 |
+
#model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
|
22 |
|
23 |
tokenizer_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
|
24 |
model_qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
|
|
|
143 |
model = model_gpt2
|
144 |
#print (model_selected + " loaded")
|
145 |
|
146 |
+
#if model_selected == "Gemma 2":
|
147 |
+
#tokenizer = tokenizer_gemma
|
148 |
+
#model = model_gemma
|
149 |
|
150 |
if model_selected == "Qwen2":
|
151 |
tokenizer = tokenizer_qwen
|
|
|
304 |
|
305 |
No_beam_group_list = [2]
|
306 |
|
307 |
+
tokenizer = tokenizer_gpt2
|
308 |
+
model = model_gpt2
|
309 |
|
310 |
with gr.Row():
|
311 |
|
312 |
with gr.Column (scale=0, min_width=200) as Models_Strategy:
|
313 |
|
314 |
+
model_selected = gr.Radio (["GPT2", "Qwen2"], label="ML Model", value="GPT2")
|
315 |
strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
|
316 |
|
317 |
|