Spaces:
Build error
Build error
import os | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
# Hub id of the T5 checkpoint used for paraphrasing; shared by the tokenizer
# and the model so the two can never drift apart.
_PARAPHRASER_CHECKPOINT = "prithivida/parrot_paraphraser_on_T5"

# Optional Hub access token. Read with .get() so that a missing AUTH_TOKEN
# variable no longer aborts start-up with a KeyError; None is forwarded
# instead, leaving the decision about whether credentials are required to
# from_pretrained itself.
_hub_auth_token = os.environ.get("AUTH_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    _PARAPHRASER_CHECKPOINT, use_auth_token=_hub_auth_token)
model = AutoModelForSeq2SeqLM.from_pretrained(
    _PARAPHRASER_CHECKPOINT, use_auth_token=_hub_auth_token)
def _load_translator(task, checkpoint):
    """Return a ``pipeline`` for ``task`` built from ``checkpoint``.

    Both the seq2seq model and its tokenizer are loaded from the same
    checkpoint id.
    """
    return pipeline(
        task,
        model=AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )


# Spanish -> English translation pipeline.
pln_es_to_en = _load_translator(
    'translation_es_to_en', 'Helsinki-NLP/opus-mt-es-en')
# English -> Spanish translation pipeline.
pln_en_to_es = _load_translator(
    'translation_en_to_es', 'Helsinki-NLP/opus-mt-en-es')
def paraphrase(sentence: str, lang: str, count: str):
    """Generate sampled paraphrases of ``sentence``.

    The paraphrase model is always prompted with the text
    ``"paraphrase: <sentence> </s>"``. When ``lang`` is ``'ES'`` the sentence
    is first translated with ``pln_es_to_en`` and every generated paraphrase
    is translated back with ``pln_en_to_es``. For any other ``lang`` value the
    text is used and returned untranslated, so input and output are always in
    the same language.

    Args:
        sentence: Text to paraphrase.
        lang: ``'ES'`` to round-trip through the translators; any other value
            (normally ``'EN'``) skips translation in both directions.
        count: Number of paraphrases to sample. Any value accepted by
            ``int()`` is allowed.

    Returns:
        A dict ``{'result': [...]}`` holding the paraphrases. The list is
        empty when ``count`` cannot be converted to a positive integer or
        when ``sentence`` is missing, empty, or whitespace only.
    """
    # Treat an unusable count (None, non-numeric text, NaN/inf) like a
    # non-positive one instead of raising from int().
    try:
        p_count = int(count)
    except (TypeError, ValueError, OverflowError):
        return {'result': []}
    if p_count <= 0 or not sentence or not sentence.strip():
        return {'result': []}

    # Decide once whether translation is needed so that the input and output
    # sides cannot disagree for an unexpected ``lang`` value.
    translate = lang == 'ES'

    sentence_input = sentence
    if translate:
        sentence_input = pln_es_to_en(sentence_input)[0]['translation_text']

    text = f"paraphrase: {sentence_input} </s>"
    encoding = tokenizer.encode_plus(text, padding=True, return_tensors="pt")
    outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=512,
        do_sample=True,
        top_k=120,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=p_count,
    )

    res = [
        tokenizer.decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for output in outputs
    ]
    if not translate:
        return {'result': res}
    return {'result': [pln_en_to_es(x)[0]['translation_text'] for x in res]}
def paraphrase_dummy(sentence: str, lang: str, count: str):
    """Stub with the same parameters as ``paraphrase``.

    All arguments are ignored; the return value is always a payload with an
    empty ``'result'`` list.
    """
    no_paraphrases = []
    return {'result': no_paraphrases}
# Input widgets, listed in the positional order of
# paraphrase(sentence, lang, count).
sentence_box = gr.inputs.Textbox(
    lines=2, placeholder=None, label='Sentence')
language_menu = gr.inputs.Dropdown(
    ['ES', 'EN'], type="value", label='Language')
count_field = gr.inputs.Number(
    default=3, label='Paraphrases count')

# The dict returned by paraphrase is shown through a JSON output component.
result_view = gr.outputs.JSON(label=None)

iface = gr.Interface(
    fn=paraphrase,
    inputs=[sentence_box, language_menu, count_field],
    outputs=[result_view],
)
iface.launch()