import random

from mtranslate import translate
import streamlit as st
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline


LOGO = "https://raw.githubusercontent.com/nlp-en-es/assets/main/logo.png"

# Hugging Face Hub checkpoints offered in the "Model" select box, keyed by the
# label shown to the user.
MODELS = {
    "RoBERTa Base": {
        "url": "bertin-project/bertin-roberta-base-spanish"
    },
    "RoBERTa Base Gaussian": {
        "url": "bertin-project/bertin-base-gaussian"
    },
    "RoBERTa Base Random": {
        "url": "bertin-project/bertin-base-random"
    },
    "RoBERTa Base Stepwise": {
        "url": "bertin-project/bertin-base-stepwise"
    },
    "RoBERTa Base Gaussian Experiment": {
        "url": "bertin-project/bertin-base-gaussian-exp-512seqlen"
    },
    "RoBERTa Base Random Experiment": {
        "url": "bertin-project/bertin-base-random-exp-512seqlen"
    },
}

# Example sentences for the "Random" prompt mode. Each one contains exactly one
# "<mask>" (the RoBERTa-style mask token) marking the word the model should
# predict; the fill-mask pipeline cannot run on text without a mask token.
PROMPT_LIST = [
    "Fui a la librería a comprar un <mask>.",
    "¡Qué buen <mask> hace hoy!",
    "Hoy empiezan las vacaciones, vamos a la <mask>.",
    "Mi color favorito es el <mask>.",
    "Voy a <mask>, estoy muy cansada.",
    "Mañana vienen mis amigos de <mask>.",
    "¿Te apetece venir a <mask> conmigo?",
    "En verano hace mucho <mask>.",
    "En el bosque había <mask>.",
]


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_pipeline(model_url):
    """Build the fill-mask pipeline for ``model_url``, once per model.

    The pipeline is cached in memory keyed only by ``model_url`` so that
    predicting on a new text does not rebuild the model. ``allow_output_mutation``
    tells ``st.cache`` not to hash the returned pipeline on every cache hit.
    """
    model = AutoModelForMaskedLM.from_pretrained(model_url)
    tokenizer = AutoTokenizer.from_pretrained(model_url)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_url):
    """Return the fill-mask predictions for ``masked_text`` using ``model_url``.

    Despite its name this runs inference; the model itself comes from the
    cached ``load_pipeline``. The returned predictions are cached (and
    persisted) per ``(masked_text, model_url)`` pair.
    """
    nlp = load_pipeline(model_url)
    return nlp(masked_text)


def top_prediction(masked_text, model_url):
    """Return the highest-ranked completed sentence for ``masked_text``.

    Raises:
        ValueError: if ``masked_text`` does not contain exactly one of the
            model's mask tokens (the caller shows the message to the user).
    """
    mask_token = load_pipeline(model_url).tokenizer.mask_token
    if masked_text.count(mask_token) != 1:
        # Backticks keep the token (e.g. "<mask>") visible when rendered as
        # markdown instead of being treated as an HTML tag.
        raise ValueError(
            f"Please include exactly one `{mask_token}` token in the text."
        )
    return load_model(masked_text, model_url)[0]["sequence"]


# Page
st.set_page_config(page_title="BERTIN Demo", page_icon=LOGO)
st.title("BERTIN")

# Sidebar
st.sidebar.image(LOGO)

# Body
st.markdown(
    """
    BERTIN is a series of BERT-based models for Spanish. The models are trained with Flax and using TPUs sponsored by Google since this is part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organised by HuggingFace.
    """
)

model_name = st.selectbox("Model", MODELS.keys())
model_url = MODELS[model_name]["url"]

prompt = st.selectbox("Prompt", ["Random", "Custom"])
if prompt == "Custom":
    prompt_box = "Enter your masked text here..."
else:
    # Streamlit re-executes this script on every interaction, including the
    # "Fill the mask" click. Drawing a new random prompt on each rerun would
    # change the text area's default value, which makes Streamlit treat it as
    # a new widget and discard the user's edits. Keep the choice in session
    # state (Streamlit >= 0.84) and only re-draw it on request. The button is
    # evaluated first so that it is rendered on every run.
    reroll = st.button("Another random prompt")
    if reroll or "random_prompt" not in st.session_state:
        st.session_state["random_prompt"] = random.choice(PROMPT_LIST)
    prompt_box = st.session_state["random_prompt"]

text = st.text_area("Enter text", prompt_box)

if st.button("Fill the mask"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        try:
            result_sequence = top_prediction(text, model_url)
        except Exception as err:  # UI boundary: show the problem, not a traceback
            st.error(str(err))
        else:
            st.write(result_sequence)
            st.text("English translation")
            # mtranslate signature: translate(text, to_language, from_language)
            st.write(translate(result_sequence, "en", "es"))

st.markdown(
    """
    ### Team members
    - Javier de la Rosa ([versae](https://huggingface.co/versae))
    - Eduardo González ([edugp](https://huggingface.co/edugp))
    - Paulo Villegas ([paulo](https://huggingface.co/paulo))
    - Pablo González de Prado ([Pablogps](https://huggingface.co/Pablogps))
    - Manu Romero ([mrm8488](https://huggingface.co/mrm8488))
    - María Grandury ([mariagrandury](https://huggingface.co/mariagrandury))

    ### More information
    You can find more information about these models [here](https://huggingface.co/bertin-project/bertin-roberta-base-spanish).
    """
)