# satia / app.py
# Author: stinoco
# Commit cef1d7f: Added classification models for subcategories
import gradio as gr
import numpy as np
from transformers import pipeline
from utils.tokenizer import Tokenizer
from utils.lstm import LSTM
from utils.load_model import load_model
from utils.production_model import ProductionModel
# Load models
## Transformers: general-category text classifier, plus a token-classification
## pipeline (not referenced by predict()).
pipeline_clf = pipeline(
    "text-classification",
    model="stinoco/beto-sentiment-analysis-finetuned",
    return_all_scores=True,
)
pipeline_pos = pipeline(
    "token-classification",
    model="sagorsarker/codeswitch-spaeng-pos-lince",
)
## LSTM: one subcategory classifier per general category, loaded in this order.
(
    clf_marketing,
    clf_cliente,
    clf_conforme,
    clf_devoluciones,
    clf_entrega,
    clf_financiamiento,
    clf_otros,
    clf_stock,
    clf_ventas,
) = (
    load_model(model_name)
    for model_name in (
        'marketing',
        'cliente',
        'conforme',
        'devoluciones',
        'entrega',
        'financiamiento',
        'otros',
        'stock',
        'ventas',
    )
)
# PREDICT
def predict(text):
    """Classify a complaint into general categories and their subcategories.

    The general classifier scores every category. The top 4 are shown, and the
    remaining probability mass is lumped into 'Resto'. For each general
    category named by the top-scoring label, the matching subcategory row is
    made visible and filled with that category's subcategory prediction. All
    other rows are hidden and cleared.

    Args:
        text: The complaint text entered by the user.

    Returns:
        A dict mapping Gradio components to their new values/updates, covering
        ``macro_output``, every subcategory component and every subcategory row.
    """
    # General category name (as it appears inside classifier labels) ->
    # (row holding its output, output component, subcategory classifier).
    # Built at call time because the components are module-level names created
    # in the DEMO section, after this function is defined.
    subcategories = {
        'Atención al cliente': (row_cliente, cliente_component, clf_cliente),
        'Conforme': (row_conforme, conforme_component, clf_conforme),
        'Devoluciones': (row_devoluciones, devoluciones_component, clf_devoluciones),
        'Entrega': (row_entrega, entrega_component, clf_entrega),
        'Financiamiento': (row_financiamiento, financiamiento_component, clf_financiamiento),
        'Otros': (row_otros, otros_component, clf_otros),
        'Stock': (row_stock, stock_component, clf_stock),
        'Trade Marketing': (row_marketing, marketing_component, clf_marketing),
        'Ventas': (row_ventas, ventas_component, clf_ventas),
    }

    # Text classification: [0] selects the list of {'label', 'score'} dicts
    # for this single input text.
    classes = pipeline_clf(text)[0]
    scores = {element['label']: element['score'] for element in classes}
    top = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:4]
    macro_probas = dict(top)
    # Everything outside the top 4 is lumped into 'Resto'. Clamp at 0: float
    # rounding can push 1 - sum(...) slightly below zero.
    macro_probas['Resto'] = max(0.0, 1 - sum(macro_probas.values()))

    # The predicted category is the best *real* label. 'Resto' is only an
    # aggregate of the tail and never names a category, so it must not win
    # here. A label may name several categories joined by ' - '.
    macro_labels = top[0][0].split(' - ') if top else []

    output = {macro_output: macro_probas}
    for category, (row, component, classifier) in subcategories.items():
        if category in macro_labels:
            output[row] = gr.update(visible=True)
            output[component] = classifier.predict([text])
        else:
            output[row] = gr.update(visible=False)
            output[component] = None
    return output
# DEMO
with gr.Blocks(title='Modelo NPS') as demo:
    # Explicit concatenation keeps the Markdown text free of leading
    # indentation (indented lines would render as a code block).
    gr.Markdown(
        '\n'
        '# <center>Modelo de Clasificación NPS</center>\n'
        'Este es un modelo para categorizar reclamos de NPS, prueba escribiendo reclamos abajo!\n'
    )
    with gr.Column() as text_col:
        with gr.Row():
            text_input = gr.Textbox(placeholder="Ingresa el reclamo acá", label='Reclamo')
        with gr.Row():
            macro_output = gr.Label(label='Categorías Generales')
    # One hidden row per general category. predict() reveals the rows that
    # match the predicted category. These module-level names are referenced
    # by predict(), so they must keep these exact names.
    with gr.Row():
        with gr.Row(visible=False) as row_cliente:
            cliente_component = gr.Label(label='Categorías Atención al Cliente')
        with gr.Row(visible=False) as row_conforme:
            conforme_component = gr.Label(label='Categorías Conforme')
        with gr.Row(visible=False) as row_devoluciones:
            devoluciones_component = gr.Label(label='Categorías Devoluciones')
        with gr.Row(visible=False) as row_entrega:
            entrega_component = gr.Label(label='Categorías Entrega')
        with gr.Row(visible=False) as row_financiamiento:
            financiamiento_component = gr.Label(label='Categorías Financiamiento')
        with gr.Row(visible=False) as row_otros:
            otros_component = gr.Label(label='Categorías Otros')
        with gr.Row(visible=False) as row_stock:
            stock_component = gr.Label(label='Categorías Stock')
        with gr.Row(visible=False) as row_marketing:
            marketing_component = gr.Label(label='Categorías Trade Marketing')
        with gr.Row(visible=False) as row_ventas:
            ventas_component = gr.Label(label='Categorías Ventas')
    # Gradio requires every component used as a key in predict()'s returned
    # dict to be listed in `outputs`.
    outputs = [
        macro_output, cliente_component, conforme_component, devoluciones_component,
        entrega_component, financiamiento_component, otros_component, stock_component,
        marketing_component, ventas_component, row_cliente, row_conforme,
        row_devoluciones, row_entrega, row_financiamiento, row_otros,
        row_stock, row_marketing, row_ventas,
    ]
    button = gr.Button('Submit')
    button.click(fn=predict, inputs=text_input, outputs=outputs)
    gr.Examples(
        examples=[
            ['sale mas a cuenta comprar en los supermercados que a la cervecería'],
            ['llega las latas abolladas sucias'],
            ['vendedor no viene presencialmente solo por whatsapp'],
            ['mejorar la atención de los repartidores porque roban'],
            ['seria bueno mas promociones y publicidad'],
        ],
        inputs=text_input,
    )

if __name__ == '__main__':
    demo.launch()