Spaces:
Runtime error
Runtime error
mariagrandury
commited on
Commit
•
832ee1c
1
Parent(s):
3954682
Update application
Browse files
app.py
CHANGED
@@ -8,22 +8,22 @@ LOGO = "https://raw.githubusercontent.com/nlp-en-es/assets/main/logo.png"
|
|
8 |
|
9 |
MODELS = {
|
10 |
"RoBERTa Base": {
|
11 |
-
"url": "
|
12 |
},
|
13 |
"RoBERTa Base Gaussian": {
|
14 |
-
"url": "
|
15 |
},
|
16 |
"RoBERTa Base Random": {
|
17 |
-
"url": "
|
18 |
},
|
19 |
"RoBERTa Base Stepwise": {
|
20 |
-
"url": "
|
21 |
},
|
22 |
"RoBERTa Base Gaussian Experiment": {
|
23 |
-
"url": "
|
24 |
},
|
25 |
"RoBERTa Base Random Experiment": {
|
26 |
-
"url": "
|
27 |
}
|
28 |
}
|
29 |
|
@@ -41,9 +41,9 @@ PROMPT_LIST = [
|
|
41 |
|
42 |
|
43 |
@st.cache(show_spinner=False, persist=True)
|
44 |
-
def load_model(masked_text,
|
45 |
-
model = AutoModelForMaskedLM.from_pretrained(
|
46 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
47 |
nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
48 |
result = nlp(masked_text)
|
49 |
return result
|
@@ -66,7 +66,8 @@ st.markdown(
|
|
66 |
"""
|
67 |
)
|
68 |
|
69 |
-
model_name = st.selectbox("Model",
|
|
|
70 |
|
71 |
prompt = st.selectbox("Prompt", ["Random", "Custom"])
|
72 |
if prompt == "Custom":
|
@@ -78,7 +79,7 @@ text = st.text_area("Enter text", prompt_box)
|
|
78 |
if st.button("Fill the mask"):
|
79 |
with st.spinner(text="Getting results..."):
|
80 |
st.subheader("Result")
|
81 |
-
result = load_model(text,
|
82 |
if "error" in result:
|
83 |
if type(result["error"]) is str:
|
84 |
st.write(f'{result["error"]}.', end=" ")
|
|
|
8 |
|
9 |
MODELS = {
|
10 |
"RoBERTa Base": {
|
11 |
+
"url": "bertin-project/bertin-roberta-base-spanish"
|
12 |
},
|
13 |
"RoBERTa Base Gaussian": {
|
14 |
+
"url": "bertin-project/bertin-base-gaussian"
|
15 |
},
|
16 |
"RoBERTa Base Random": {
|
17 |
+
"url": "bertin-project/bertin-base-random"
|
18 |
},
|
19 |
"RoBERTa Base Stepwise": {
|
20 |
+
"url": "bertin-project/bertin-base-stepwise"
|
21 |
},
|
22 |
"RoBERTa Base Gaussian Experiment": {
|
23 |
+
"url": "bertin-project/bertin-base-gaussian-exp-512seqlen"
|
24 |
},
|
25 |
"RoBERTa Base Random Experiment": {
|
26 |
+
"url": "bertin-project/bertin-base-random-exp-512seqlen"
|
27 |
}
|
28 |
}
|
29 |
|
|
|
41 |
|
42 |
|
43 |
@st.cache(show_spinner=False, persist=True)
|
44 |
+
def load_model(masked_text, model_url):
|
45 |
+
model = AutoModelForMaskedLM.from_pretrained(model_url)
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained(model_url)
|
47 |
nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
48 |
result = nlp(masked_text)
|
49 |
return result
|
|
|
66 |
"""
|
67 |
)
|
68 |
|
69 |
+
model_name = st.selectbox("Model",MODELS.keys())
|
70 |
+
model_url = MODELS[model_name]["url"]
|
71 |
|
72 |
prompt = st.selectbox("Prompt", ["Random", "Custom"])
|
73 |
if prompt == "Custom":
|
|
|
79 |
if st.button("Fill the mask"):
|
80 |
with st.spinner(text="Getting results..."):
|
81 |
st.subheader("Result")
|
82 |
+
result = load_model(text, model_url)
|
83 |
if "error" in result:
|
84 |
if type(result["error"]) is str:
|
85 |
st.write(f'{result["error"]}.', end=" ")
|