paraphrase_es / app.py
milyiyo's picture
Update app.py
2f251fd
raw
history blame
2.75 kB
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Hub identifiers, named once so each tokenizer/model pair cannot drift apart.
_PARAPHRASE_MODEL_ID = "prithivida/parrot_paraphraser_on_T5"
_ES_TO_EN_MODEL_ID = 'Helsinki-NLP/opus-mt-es-en'
_EN_TO_ES_MODEL_ID = 'Helsinki-NLP/opus-mt-en-es'

# Optional Hugging Face Hub token. Read with .get() so that an unset
# AUTH_TOKEN variable no longer aborts start-up with a KeyError; when it is
# unset, None is passed and the library's default credential handling applies.
_auth_token = os.environ.get("AUTH_TOKEN")


def _load_translator(task, model_id):
    """Build a translation pipeline whose model and tokenizer share *model_id*."""
    return pipeline(
        task,
        model=AutoModelForSeq2SeqLM.from_pretrained(model_id),
        tokenizer=AutoTokenizer.from_pretrained(model_id),
    )


# Paraphrasing model and its tokenizer (used directly by `paraphrase`).
tokenizer = AutoTokenizer.from_pretrained(
    _PARAPHRASE_MODEL_ID, use_auth_token=_auth_token)
model = AutoModelForSeq2SeqLM.from_pretrained(
    _PARAPHRASE_MODEL_ID, use_auth_token=_auth_token)

# Translation pipelines used to route Spanish text through the paraphraser:
# Spanish -> English before paraphrasing, English -> Spanish afterwards.
pln_es_to_en = _load_translator('translation_es_to_en', _ES_TO_EN_MODEL_ID)
pln_en_to_es = _load_translator('translation_en_to_es', _EN_TO_ES_MODEL_ID)
def paraphrase(sentence: str, lang: str, count: str):
    """Generate paraphrases of *sentence*.

    When ``lang`` is ``'ES'`` the sentence is translated to English, the
    paraphrases are generated on the English text, and each paraphrase is
    translated back to Spanish. For any other ``lang`` value the sentence is
    passed to the paraphraser as-is and the paraphrases are returned
    untranslated, so the input and output language always match.

    Args:
        sentence: Text to paraphrase. ``None``, empty, or whitespace-only
            text yields an empty result.
        lang: ``'ES'`` for Spanish input/output; anything else is treated
            as English.
        count: Number of paraphrases to return. Converted with ``int()``,
            so numeric strings and floats are accepted (presumably the UI
            number field delivers a float -- confirm against the Gradio
            version in use). ``None`` or a value <= 0 yields an empty result.

    Returns:
        ``{'result': [...]}`` with one string per generated paraphrase.
        Generation uses sampling (``do_sample=True``), so the strings are
        not expected to be identical between calls.

    Raises:
        ValueError: If ``count`` is not ``None`` and cannot be converted
            to an integer.
    """
    # A cleared number field or a missing sentence means nothing was asked
    # for; answer with the empty payload instead of failing in int()/strip().
    if count is None or sentence is None:
        return {'result': []}

    p_count = int(count)
    if p_count <= 0 or not sentence.strip():
        return {'result': []}

    # Decide the language once so the input and output translation steps
    # cannot disagree (e.g. when no language was selected).
    is_spanish = lang == 'ES'

    sentence_input = sentence
    if is_spanish:
        sentence_input = pln_es_to_en(sentence_input)[0]['translation_text']

    text = f"paraphrase: {sentence_input} </s>"
    encoding = tokenizer.encode_plus(text, padding=True, return_tensors="pt")

    outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=512,
        do_sample=True,
        top_k=120,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=p_count,
    )

    res = [
        tokenizer.decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for output in outputs
    ]

    if not is_spanish:
        return {'result': res}

    # The paraphrases were generated in English; bring them back to Spanish.
    res_es = [pln_en_to_es(x)[0]['translation_text'] for x in res]
    return {'result': res_es}
def paraphrase_dummy(sentence: str, lang: str, count: str):
    """Return an empty paraphrase payload, ignoring every argument.

    Has the same signature and result shape as `paraphrase` but does no
    work. It is not referenced in the visible part of this file.
    """
    no_paraphrases = []
    return {'result': no_paraphrases}
# Input widgets, listed in the same order as the parameters of `paraphrase`.
sentence_box = gr.inputs.Textbox(lines=2, placeholder=None, label='Sentence')
language_dropdown = gr.inputs.Dropdown(
    ['ES', 'EN'], type="value", label='Language')
count_field = gr.inputs.Number(default=3, label='Paraphrases count')

# The result dictionary returned by `paraphrase` is shown as JSON.
json_output = gr.outputs.JSON(label=None)

iface = gr.Interface(
    fn=paraphrase,
    inputs=[sentence_box, language_dropdown, count_field],
    outputs=[json_output],
)
iface.launch()