dia-critic / app.py
Florin Bobiș
remove copyright
f99df08
raw
history blame
No virus
1.61 kB
import streamlit as st
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import time
@st.cache_resource
def load_model():
    """Return the mT5 Romanian diacritics-restoration model.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of being reloaded every time
    the script executes. Downloaded files are stored under ``cache/``.
    """
    return MT5ForConditionalGeneration.from_pretrained(
        'iliemihai/mt5-base-romanian-diacritics',
        cache_dir='cache/',
    )
@st.cache_resource
def load_tokenizer():
    """Return the tokenizer that pairs with the diacritics model.

    Uses the same checkpoint and ``cache/`` directory as ``load_model`` and
    is likewise cached with ``st.cache_resource`` across reruns.
    ``legacy=False`` selects transformers' non-legacy T5 tokenization
    behaviour instead of the older one.
    """
    tok = T5Tokenizer.from_pretrained(
        'iliemihai/mt5-base-romanian-diacritics',
        cache_dir='cache/',
        legacy=False,
    )
    return tok
def initialize_app():
    """Configure the Streamlit page (title, favicon, About menu) and draw the header.

    Called as the very first step of ``main``: older Streamlit releases
    reject ``st.set_page_config`` once any other ``st.*`` command has run.
    """
    about_text = "### Contact\n ✉️florinbobis@gmail.com"
    st.set_page_config(
        page_title="Dia-critic",
        page_icon="public/favicon.ico",
        menu_items={"About": about_text},
    )
    st.title("🖋️Dia-critic")
def generate_text(text, max_length=256):
    """Restore Romanian diacritics in *text* using the cached mT5 model.

    Args:
        text: The input text (presumably Romanian typed without diacritics).
        max_length: Maximum number of input tokens passed to the model.
            Longer input is truncated by the tokenizer, so anything past this
            limit will not appear in the result. Defaults to 256, the
            previously hard-coded limit.

    Returns:
        The decoded model output with special tokens stripped.
    """
    model = load_model()
    tokenizer = load_tokenizer()
    inputs = tokenizer(text, max_length=max_length, truncation=True, return_tensors="pt")
    input_ids = inputs["input_ids"]
    # Without an explicit budget, generate() uses the generation config's
    # max_length. The transformers default is 20 tokens unless the checkpoint
    # overrides it, which cuts the corrected text off after a few words.
    # The output should be roughly as long as the input, so allow up to twice
    # the input token count as headroom for diacritics that tokenize into
    # extra pieces. EOS still ends generation early.
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
        max_new_tokens=2 * input_ids.shape[-1],
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def main():
    """Render the Dia-critic page and restore diacritics when requested.

    The page shows a text area with a live character count. When the
    "Corectează" button is pressed:
    - Non-blank input is run through ``generate_text`` and the result is
      shown inside a bordered container.
    - Empty or whitespace-only input shows a warning instead.
    """
    initialize_app()
    input_text = st.text_area("Introduceți textul mai jos")
    st.write(f'{len(input_text)} caractere.')

    if not st.button("Corectează"):
        return

    # Whitespace-only input has nothing to correct; treat it like an empty
    # field instead of spending a model call on it.
    if not input_text.strip():
        st.warning("Câmpul este gol!")
        return

    with st.spinner('Sarcină în desfășurare...'):
        res = generate_text(input_text)
    with st.container(border=True):
        st.markdown(res)
# Entry point: Streamlit executes this script as __main__ on every rerun
# (e.g. via `streamlit run app.py`), so the guard renders the page each time.
if __name__ == "__main__":
    main()