# Hugging Face Spaces page status at capture time: Runtime error
# Import standard-library, git, gradio and transformers dependencies
import os

import git
import gradio as gr
import transformers
# Load the AraBERT preprocessor from a local checkout of the aub-mind/arabert
# repository, which is then imported as the `arabert` package.
# Clone only when the checkout is missing: re-running `git clone` on every
# start wastes time on restarts and fails once the destination exists.
if not os.path.isdir("arabert"):
    git.Repo.clone_from("https://github.com/aub-mind/arabert", "arabert")
from arabert.preprocess import ArabertPreprocessor

arabert_prep = ArabertPreprocessor(
    model_name="bert-base-arabert", keep_emojis=False
)
# Load the fine-tuned BERT2BERT encoder-decoder checkpoint and its tokenizer.
# Both come from the same Hub repository id, kept in one constant so the
# tokenizer and the model cannot drift apart.
from transformers import EncoderDecoderModel, AutoTokenizer

MODEL_CHECKPOINT = "tareknaous/bert2bert-empathetic-response-msa"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = EncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()
def generate_response(text, minimum_length, k, p, temperature):
    """Generate a reply to a single user utterance with the BERT2BERT model.

    Args:
        text: The user's message.
        minimum_length: Minimum length, in tokens, of the generated sequence.
        k: Top-k sampling cutoff (0 disables top-k filtering in transformers).
        p: Nucleus (top-p) sampling threshold in [0, 1].
        temperature: Sampling temperature. Values <= 0 cannot be used for
            sampling, so they fall back to greedy decoding.

    Returns:
        The decoded reply as a string, with special tokens removed and
        AraBERT segmentation undone via ``arabert_prep.desegment``.
    """
    text_clean = arabert_prep.preprocess(text)
    inputs = tokenizer(text_clean, return_tensors="pt")

    # UI widgets may hand back whole numbers as int and fractions as float;
    # transformers' logits warpers can reject the wrong numeric type
    # (e.g. top_k must be an int), so normalize the types explicitly.
    temperature = float(temperature)
    generation_kwargs = {"min_length": int(minimum_length)}
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            top_k=int(k),
            top_p=float(p),
            temperature=temperature,
        )
    else:
        # Sampling requires a strictly positive temperature; the T -> 0 limit
        # of sampling is greedy decoding, so use that instead of crashing.
        generation_kwargs["do_sample"] = False

    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_kwargs,
    )
    # Decode the single generated sequence directly rather than string-hacking
    # repr(list): skip_special_tokens strips [CLS]/[SEP]/[PAD] reliably and
    # leaves apostrophes in the generated text intact.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return str(arabert_prep.desegment(response))
title = 'BERT2BERT Response Generation in Arabic'
description = (
    'This demo is for a BERT2BERT model trained for single-turn '
    'open-domain dialogue response generation in Modern Standard Arabic'
)
# Build the UI from the top-level component classes: the `gr.inputs`
# namespace was deprecated in Gradio 3 and removed in Gradio 4.
# Each slider gets an explicit starting value. Without one it starts at its
# minimum, and a Temperature of 0 cannot be used for sampling. Top-K 50,
# Top-P 1.0 and Temperature 1.0 mirror transformers' sampling defaults.
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(),
        gr.Slider(5, 20, value=5, step=1, label='Minimum Output Length'),
        gr.Slider(0, 1000, value=50, step=10, label='Top-K'),
        gr.Slider(0, 1, value=1.0, step=0.1, label='Top-P'),
        gr.Slider(0, 3, value=1.0, step=0.1, label='Temperature'),
    ],
    outputs="text",
    title=title,
    description=description,
)
demo.launch()