deep-learning-analytics's picture
Update app.py
8817b20
raw
history blame
2.25 kB
import torch
import streamlit as st
# --- Page header and input widget -------------------------------------------
st.title("Title Generation with Transformers")
st.write("")
st.write("Input your text here!")

# Sample news paragraph pre-filled in the text box so the app produces a title
# without the user having to type anything first.
default_value = "Ukrainian counterattacks: Kharkiv's regional administrator said a number of villages around Malaya Rogan were retaken by Ukrainian forces. Video verified by CNN shows Ukrainian troops in control of Vilkhivka, one of the settlements roughly 20 miles from the Russian border. The success of Ukrainian forces around Kharkiv has been mirrored further north, near the city of Sumy, where Ukrainian troops have liberated a number of settlements, according to videos geolocated and verified by CNN. A separate counterattack in the south also led to the liberation of two villages from Russian forces northwest of Mariupol, according to the Zaporizhzhia regional military administration."

# The previous height of 50 px is below the minimum text-area height enforced
# by recent Streamlit releases (68 px, as far as I recall -- confirm against the
# st.text_area docs) and was too small for the multi-sentence default anyway.
sent = st.text_area("Text", default_value, height=200)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hub identifier of the fine-tuned title-generation checkpoint.
_MODEL_NAME = "deep-learning-analytics/automatic-title-generation"

# Streamlit re-executes this script from the top on every widget interaction,
# so the checkpoint is loaded through a cache instead of being re-instantiated
# on each rerun. `st.cache_resource` is the current API for caching objects
# such as models; older Streamlit releases only provide `st.cache`.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def _load_title_generator(model_name):
    """Load the tokenizer and seq2seq model for *model_name*.

    The result is cached by Streamlit, so reruns of the script reuse the
    already-loaded objects.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSeq2SeqLM.from_pretrained(model_name),
    )


tokenizer, model = _load_title_generator(_MODEL_NAME)
def tokenize_data(text):
    """Encode *text* into fixed-length model inputs.

    The text is converted to ``str``, suffixed with ``' </s>'``, and run
    through the module-level ``tokenizer`` with padding and truncation to
    120 tokens.

    Args:
        text: The passage to encode; any object accepted by ``str()``.

    Returns:
        A dict with the keys ``"input_ids"`` and ``"attention_mask"``, each
        holding the corresponding PyTorch tensor from the tokenizer output.
    """
    # NOTE(review): many seq2seq tokenizers append the end-of-sequence token
    # on their own -- confirm the explicit '</s>' suffix is still wanted.
    source = " ".join((str(text), "</s>"))
    encoded = tokenizer(
        source,
        max_length=120,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    return {field: encoded[field] for field in ("input_ids", "attention_mask")}
def generate_answers(text):
    """Generate one title for *text* using the module-level model.

    The text is encoded with ``tokenize_data`` and decoded by sampling
    (``top_k=120``, ``top_p=0.98``), so repeated calls on the same input may
    return different titles.

    Args:
        text: Source passage to turn into a title.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens stripped.
    """
    inputs = tokenize_data(text)
    # `early_stopping` was dropped: it is a beam-search option and has no
    # effect when sampling with the default single beam; newer transformers
    # releases warn about it.
    results = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        do_sample=True,
        max_length=120,
        top_k=120,
        top_p=0.98,
        num_return_sequences=1,
    )
    return tokenizer.decode(results[0], skip_special_tokens=True)
# Only run the model when there is actual text to summarise; a blank box would
# otherwise be sent to the model and yield a sampled title from no content.
answer = generate_answers(sent) if sent.strip() else ""
st.write(answer)
#iface = gr.Interface(fn=generate_answers,inputs=[gr.inputs.Textbox(lines=20)], outputs=["text"])
#iface.launch(inline=False, share=True)