|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
# Hugging Face checkpoint used for both the tokenizer and the seq2seq model.
MODEL_NAME = "VietAI/vit5-large-vietnews-summarization"


def _load_summarizer():
    """Load and return ``(tokenizer, model)`` for ``MODEL_NAME``."""
    return (
        AutoTokenizer.from_pretrained(MODEL_NAME),
        AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
    )


# Streamlit re-executes the whole script on every widget interaction, so an
# uncached load would re-instantiate the checkpoint on each rerun. Wrap the
# loader with st.cache_resource when the installed Streamlit provides it;
# otherwise fall back to the plain (uncached) loader.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_summarizer = _cache_resource(_load_summarizer)

tokenizer, model = _load_summarizer()
|
|
|
def preprocess(inp):
    """Tokenize a document for the summarization model.

    The input is wrapped as ``"summarize: " + inp + " </s>"`` and passed to
    the module-level ``tokenizer`` with ``return_tensors="pt"``.

    Args:
        inp: Document text to summarize.

    Returns:
        A ``(input_ids, attention_mask)`` tuple taken from the tokenizer
        output.
    """
    text = "summarize: " + inp + " </s>"
    features = tokenizer(text, return_tensors="pt")
    # The tokenizer output key is the singular "attention_mask"; the plural
    # spelling used previously raised KeyError on every call.
    return features["input_ids"], features["attention_mask"]
|
def predict(input_ids, attention_masks):
    """Generate a summary string from tokenized input.

    Calls the module-level ``model.generate`` with a 256-token length cap
    and ``early_stopping=True``, then decodes the generated sequences with
    the module-level ``tokenizer``.

    Args:
        input_ids: Token ids as returned by ``preprocess``.
        attention_masks: Attention mask as returned by ``preprocess``.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed and tokenization spaces cleaned up.
    """
    generated = model.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        max_length=256,
        early_stopping=True,
    )
    decoded = tokenizer.batch_decode(
        generated,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # Only the first decoded sequence is returned.
    return decoded[0]
|
|
|
if __name__ == '__main__':
    # Streamlit UI: a title, a text area for the document, and the summary.
    st.title("ViT5 News Abstractive Summarization (Vietnamese)")

    with st.container():
        txt = st.text_area('Enter long documment...', ' ')

        # The text area defaults to a single space, so without this guard the
        # model would run on blank input at first render and whenever the
        # box is cleared. Only summarize when there is real text.
        if txt.strip():
            inp_ids, attn_mask = preprocess(txt)
            st.write('Summary:', predict(inp_ids, attn_mask))
|
|