DmitriySv's picture
Update app.py
b009f2e verified
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import gradio as gr
class MultiTaskBertModel(nn.Module):
def __init__(self, num_labels_task1, num_labels_task2):
super(MultiTaskBertModel, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1)
self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
logits_task1 = self.classifier_task1(pooled_output)
logits_task2 = self.classifier_task2(pooled_output)
return logits_task1, logits_task2
# Загрузка сохраненной модели
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4)
model = torch.load('ticket_classifier.pth', weights_only=False, map_location=torch.device('cpu'))
model.eval()
# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Функция для предсказания
label_mapping_type = {'варианты доставки': 0,
'варианты оплаты': 1,
'возврат средств': 2,
'восстановление пароля': 3,
'время доставки': 4,
'выбор адреса доставки': 5,
'жалоба': 6,
'изменение адреса доставки': 7,
'изменение заказа': 8,
'отзыв': 9,
'отмена заказа': 10,
'отслеживание возврата средств': 11,
'отслеживание заказа': 12,
'подписка на новостную рассылку': 13,
'политика возврата': 14,
'получение информации': 15,
'проблемы с оплатой': 16,
'проблемы с регистрацией': 17,
'проверка платы за отмену': 18,
'проверка счета': 19,
'проверка счетов': 20,
'размещение заказа': 21,
'редактирование учетной записи': 22,
'связь с человеком': 23,
'связь со службой поддержки': 24,
'смена учетной записи': 25,
'создание учетной записи': 26,
'удаление аккаунта': 27}
label_mapping_priority = {'высокий': 0, 'низкий': 1, 'средний': 2}
def get_key_by_value(dictionary, value):
reverse_dict = {v: k for k, v in dictionary.items()}
return reverse_dict.get(value)
def predict(text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# Получаем предсказания для двух задач
logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask)
# Преобразование логитов в предсказания классов
pred_task1 = torch.argmax(logits_task1, dim=1).item()
pred_task2 = torch.argmax(logits_task2, dim=1).item()
return {"Тема": get_key_by_value(label_mapping_type, pred_task1)}
iface = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=2, placeholder="Введите запрос для анализа..."), # Обновлено на gr.Textbox
outputs=gr.JSON(), # Обновлено на gr.JSON
title="Классификация запроса",
description="'варианты доставки', 'варианты оплаты', 'возврат средств', 'восстановление пароля', 'время доставки', 'выбор адреса доставки', 'жалоба', 'изменение адреса доставки', 'изменение заказа', 'отзыв', 'отмена заказа', 'отслеживание возврата средств', 'отслеживание заказа', 'подписка на новостную рассылку', 'политика возврата', 'получение информации', 'проблемы с оплатой', 'проблемы с регистрацией', 'проверка платы за отмену', 'проверка счета', 'размещение заказа', 'редактирование учетной записи', 'связь с человеком', 'связь со службой поддержки', 'смена учетной записи', 'создание учетной записи', 'удаление аккаунта'",
)
iface.launch(share=True)