# Hugging Face Spaces page residue captured with this file (Space status at capture time: "Runtime error").
import gradio as gr | |
import torch | |
from dual_regression_model import DualRegressionModel | |
import transformers | |
from transformers import pipeline | |
from functools import partial | |
# ---------------------------------------------------------------------------
# Model loading (module level, so it runs once when this file is executed)
# ---------------------------------------------------------------------------
# Classifier checkpoint folder (original note: A-pt-bs16-dbmdz-bert-base-italian-cased).
clf_model_tag = "clf_model/"
# Regressor backbone name passed to DualRegressionModel.
reg_model_tag = "distilbert-base-multilingual-cased"
# Despite the "folder" name, this is the path of a single .pt weights file.
reg_model_folder = "reg_model/regression_model.pt"

# Classification: tokenizer + sequence-classification head wrapped in a pipeline.
clf_tokenizer = transformers.AutoTokenizer.from_pretrained(clf_model_tag)
clf_model = transformers.AutoModelForSequenceClassification.from_pretrained(
    clf_model_tag
)
clf_pipeline = pipeline(
    task="text-classification",
    model=clf_model,
    tokenizer=clf_tokenizer,
)

# Regression: build the model on the backbone, then restore the trained weights.
reg_model = DualRegressionModel(model_name_or_path=reg_model_tag)
reg_model.load_model(reg_model_folder)
# Bounding box (degrees) of the OpenStreetMap view framing Italy.
LAT_MIN, LAT_MAX = 38, 46
LON_MIN, LON_MAX = 8, 18


def predict(text):
    """Classify the dialect region of *text* and predict where it comes from.

    Args:
        text: The sentence entered in the Gradio textbox.

    Returns:
        An HTML string containing the predicted region label, the predicted
        latitude/longitude, and an embedded OpenStreetMap view (framed by the
        LAT_/LON_ bounding box) with a marker at the predicted point.
    """
    # Classification: take the first result returned by the pipeline.
    # truncation=True keeps over-long inputs within the model's max length
    # instead of failing inside the model.
    clf_prediction = clf_pipeline(text, truncation=True)[0]

    # Regression: inference only, so skip autograd graph construction.
    # Assumes reg_model.tokenizer is a Hugging Face tokenizer (it is called with
    # return_tensors="pt"); truncation=True bounds the sequence length.
    reg_input = reg_model.tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        reg_prediction = reg_model(reg_input)
    latitude = reg_prediction["latitude"].item()
    longitude = reg_prediction["longitude"].item()

    # Build the HTML result.
    html_output = f"<h3>The identified region is: {clf_prediction['label']}</h3>"
    # Plot the predicted point on the map of Italy.
    html_output += f'<h3>Predicted point on map:</h3><p>Latitude: {latitude}</p><p>Longitude: {longitude}</p>'
    # OSM embed: bbox is min_lon,min_lat,max_lon,max_lat; marker is lat,lon.
    html_output += (
        '<iframe width="425" height="350" frameborder="0" scrolling="no" '
        'marginheight="0" marginwidth="0" '
        f'src="https://www.openstreetmap.org/export/embed.html?bbox={LON_MIN}%2C{LAT_MIN}%2C{LON_MAX}%2C{LAT_MAX}'
        f'&layer=mapnik&marker={latitude}%2C{longitude}" style="border: 1px solid black"></iframe>'
        f'<br/><small><a href="https://www.openstreetmap.org/#map=13/{latitude}/{longitude}">'
        'Visualizza mappa ingrandita</a></small>'
    )
    return html_output
# --------------------------------------------------------------------------------------------
# Gradio interface
# --------------------------------------------------------------------------------------------
# One text input, HTML output (the string produced by predict()), plus clickable examples.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Insert the text here..."),
    outputs=gr.HTML(),
    title="DANTE: Dialect ANalysis TEam",
    description="This is a demo of a classification and regression model for locating the italian dialect of a given text.",
    examples=[
        ["Bisognerebbe saperli materializzare .... !! Ma ovviamente .. belin .... NO SE PEU SCIUSCIA' E SCIORBI'"],
        ["Guaglio' Buongiorno! Azz! Vir te si scurdat puparuol e mulignane pero '!! E che se fa😑"],
        ["Il massimo...ghe ne minga par nisun"],
        ["Che poi a me la tuta piace na cifra da vede. Subisco un po' lo stigma sociale che noi con la fregna dovemo stà sempre apposto.",],
    ],
)

# Launch only when run as a script (`python app.py`), so importing this module
# (e.g. from tests or tooling) does not start a web server.
if __name__ == "__main__":
    iface.launch()