Spaces:
Sleeping
Sleeping
Amitontheweb
commited on
Commit
•
7c242ca
1
Parent(s):
78b413a
Update app.py
Browse files
app.py
CHANGED
@@ -6,12 +6,20 @@ import torch
|
|
6 |
import gradio as gr
|
7 |
|
8 |
|
9 |
-
# Load default model as GPT2
|
10 |
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
13 |
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
# Define functions
|
17 |
|
@@ -129,13 +137,17 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
|
|
129 |
def load_model(model_selected):
|
130 |
|
131 |
if model_selected == "gpt2":
|
132 |
-
tokenizer =
|
133 |
-
model =
|
134 |
#print (model_selected + " loaded")
|
135 |
|
136 |
if model_selected == "Gemma 2":
|
137 |
-
tokenizer =
|
138 |
-
model =
|
|
|
|
|
|
|
|
|
139 |
|
140 |
|
141 |
|
@@ -298,7 +310,7 @@ with gr.Blocks() as demo:
|
|
298 |
|
299 |
with gr.Column (scale=0, min_width=200) as Models_Strategy:
|
300 |
|
301 |
-
model_selected = gr.Radio (["gpt2", "Gemma 2"], label="ML Model", value="gpt2")
|
302 |
strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
|
303 |
|
304 |
|
|
|
6 |
import gradio as gr
|
7 |
|
8 |
|
9 |
+
# Load default model as GPT2 and other models
|
10 |
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
13 |
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
14 |
|
15 |
+
tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
16 |
+
model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
17 |
+
|
18 |
+
tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
|
19 |
+
model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
|
20 |
+
|
21 |
+
tokenizer_Mistral = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
22 |
+
model_Mistral = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
|
23 |
|
24 |
# Define functions
|
25 |
|
|
|
137 |
def load_model(model_selected):
|
138 |
|
139 |
if model_selected == "gpt2":
|
140 |
+
tokenizer = tokenizer_gpt2
|
141 |
+
model = model_gpt2
|
142 |
#print (model_selected + " loaded")
|
143 |
|
144 |
if model_selected == "Gemma 2":
|
145 |
+
tokenizer = tokenizer_gemma
|
146 |
+
model = model_gemma
|
147 |
+
|
148 |
+
if model_selected == "Mistral":
|
149 |
+
tokenizer = tokenizer_Mistral
|
150 |
+
model = model_Mistral
|
151 |
|
152 |
|
153 |
|
|
|
310 |
|
311 |
with gr.Column (scale=0, min_width=200) as Models_Strategy:
|
312 |
|
313 |
+
model_selected = gr.Radio (["gpt2", "Gemma 2", "Mistral"], label="ML Model", value="gpt2")
|
314 |
strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
|
315 |
|
316 |
|