# "Spaces: Sleeping" — HuggingFace Spaces page-status header captured during
# extraction; it is not part of the script.
import json
from typing import List

import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, TFAutoModel
class TextClassificationPipeline:
    """Two-stage classifier: DistilBERT [CLS] embeddings fed into an XGBoost model."""

    def __init__(self, tokenizer, distilbert_model, xgb_model):
        # Components are injected so the (slow) model loading happens once, elsewhere.
        self.tokenizer = tokenizer
        self.distilbert_model = distilbert_model
        self.xgb_model = xgb_model

    def __call__(self, text):
        """Tokenize *text*, embed it with DistilBERT, and classify the [CLS] vector.

        Returns the raw XGBoost prediction array (one label per input).
        """
        encoded = self.tokenizer(
            text,
            return_tensors="tf",
            padding=True,
            truncation=True,
            max_length=128,
        )
        hidden = self.distilbert_model(**encoded)
        # The embedding at position 0 (the [CLS] token) is used as the
        # sentence representation for the downstream classifier.
        cls_vectors = hidden.last_hidden_state[:, 0, :].numpy()
        return self.xgb_model.predict(cls_vectors)
# Hugging Face repo that hosts both the DistilBERT weights and the XGBoost head.
HF_MODEL_ID = "AndresR2909/suicide-related-text-classification_distilbert_xgboost"

# Download the serialized XGBoost classifier from the model repo.
xgboost_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="xgboost_model.joblib")

# Load the models once, at module import time.
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
distilbert_model = TFAutoModel.from_pretrained(HF_MODEL_ID)
# NOTE(review): joblib.load unpickles the file — acceptable only because the
# artifact comes from this project's own repo; do not point it at untrusted files.
xgb_model = joblib.load(xgboost_path)

# Build the pipeline a single time at startup so every request reuses the
# already-loaded models.
pipeline = TextClassificationPipeline(tokenizer, distilbert_model, xgb_model)
# API entry point.
def predict_api(texts: List[str]) -> List[int]:
    """Classify a batch of texts with the preloaded pipeline.

    Args:
        texts: input strings to classify.

    Returns:
        One plain ``int`` label per input text (an empty list for empty input).
    """
    if not texts:
        return []
    # The tokenizer pads the whole batch, so one pipeline call classifies
    # every text at once instead of re-running DistilBERT per item.
    predictions = pipeline(texts)
    # Cast numpy integer scalars to plain ints so the result actually matches
    # the List[int] annotation and serializes cleanly to JSON.
    return [int(label) for label in predictions]
# Crear la interfaz de Gradio | |
def main(text):
    """Gradio handler: classify one text and return a human-readable label."""
    # predict_api expects a batch, so wrap the single input in a list.
    labels = predict_api([text])
    return "Normal" if labels[0] == 0 else "Relacionado con suicidio"
# Build the Gradio interface. The mis-encoded Spanish strings ("aqu铆",
# "Clasificaci贸n", ...) are repaired to proper UTF-8 here.
iface = gr.Interface(
    fn=main,
    inputs=gr.Textbox(lines=2, placeholder="Introduce un texto aquí..."),
    outputs="text",
    title="Clasificación de Texto (API)",
    description="Introduce un texto para obtener una predicción en formato JSON.",
)

# Extra Blocks layout intended for the API view.
# NOTE(review): these two textboxes are not wired to any callback, so this
# block renders nothing functional — confirm whether it is still needed.
with gr.Blocks() as blocks:
    gr.Textbox(lines=2, placeholder="Introduce un texto aquí...", label="Entrada de texto")
    gr.Textbox(label="Resultado", interactive=False)

# Launch the UI (share=True requests a public share link).
iface.launch(share=True)

# Mount the API (left disabled, as in the original).
#app = gr.mount_gradio_app(iface, blocks=blocks, path="/api/predict")