# Streamlit app: automatic title generation with a HuggingFace seq2seq model.
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --- Page layout -----------------------------------------------------------
st.title("Title Generation with Transformers")
st.write("")
st.write("Input your text here!")

default_value = "Ukrainian counterattacks: Kharkiv's regional administrator said a number of villages around Malaya Rogan were retaken by Ukrainian forces. Video verified by CNN shows Ukrainian troops in control of Vilkhivka, one of the settlements roughly 20 miles from the Russian border. The success of Ukrainian forces around Kharkiv has been mirrored further north, near the city of Sumy, where Ukrainian troops have liberated a number of settlements, according to videos geolocated and verified by CNN. A separate counterattack in the south also led to the liberation of two villages from Russian forces northwest of Mariupol, according to the Zaporizhzhia regional military administration."
sent = st.text_area("Text", default_value, height = 50)

# --- Model loading ---------------------------------------------------------
_MODEL_NAME = "deep-learning-analytics/automatic-title-generation"


@st.cache_resource
def _load_pipeline():
    """Load tokenizer and model once per process.

    Streamlit re-executes the whole script on every user interaction;
    without caching, the seq2seq model would be re-downloaded/re-built
    on each rerun.
    """
    tok = AutoTokenizer.from_pretrained(_MODEL_NAME)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_NAME)
    return tok, mdl


tokenizer, model = _load_pipeline()
def tokenize_data(text, max_len=120):
    """Tokenize *text* into model-ready tensors.

    Args:
        text: Source passage (any object; coerced to ``str``).
        max_len: Pad/truncate length in tokens. Default 120 matches the
            original hard-coded value, so existing callers are unaffected.

    Returns:
        dict with keys ``'input_ids'`` and ``'attention_mask'``, each a
        PyTorch tensor of shape (1, max_len).
    """
    # NOTE(review): T5-style tokenizers append </s> automatically, so this
    # explicit suffix likely produces a duplicate EOS token. Kept as-is for
    # parity with the model's original preprocessing — confirm before removing.
    input_ = str(text) + ' </s>'
    tokenized_inputs = tokenizer(
        input_,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Return a plain dict (original contract) exposing only the two keys
    # the generation step consumes.
    return {
        "input_ids": tokenized_inputs['input_ids'],
        "attention_mask": tokenized_inputs['attention_mask'],
    }
def generate_answers(text):
    """Generate a title for *text* via top-k/top-p sampling.

    Args:
        text: Source passage to condense into a title.

    Returns:
        The decoded title string with special tokens stripped. Output is
        stochastic (``do_sample=True``), so repeated calls may differ.
    """
    inputs = tokenize_data(text)
    # Inference only — disable autograd bookkeeping to save time and memory.
    with torch.no_grad():
        results = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            do_sample=True,
            max_length=120,
            top_k=120,
            top_p=0.98,
            num_return_sequences=1,
        )
    # `early_stopping=True` was removed: it applies only to beam search and
    # is a no-op (plus a runtime warning) under pure sampling.
    return tokenizer.decode(results[0], skip_special_tokens=True)
# Generate a title for the current text-area contents and render it on the page.
answer = generate_answers(sent) | |
st.write(answer) | |
# Alternative (unused) Gradio entry point kept for reference:
#iface = gr.Interface(fn=generate_answers, inputs=[gr.inputs.Textbox(lines=20)], outputs=["text"])
#iface.launch(inline=False, share=True)