Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from transformers import BertTokenizer, BertModel | |
import gradio as gr | |
class MultiTaskBertModel(nn.Module): | |
def __init__(self, num_labels_task1, num_labels_task2): | |
super(MultiTaskBertModel, self).__init__() | |
self.bert = BertModel.from_pretrained('bert-base-uncased') | |
self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1) | |
self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2) | |
def forward(self, input_ids, attention_mask): | |
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
pooled_output = outputs.pooler_output | |
logits_task1 = self.classifier_task1(pooled_output) | |
logits_task2 = self.classifier_task2(pooled_output) | |
return logits_task1, logits_task2 | |
# Загрузка сохраненной модели | |
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4) | |
model = torch.load('ticket_classifier.pth', weights_only=False, map_location=torch.device('cpu')) | |
model.eval() | |
# Загрузка токенизатора | |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') | |
# Функция для предсказания | |
label_mapping_type = {'варианты доставки': 0, | |
'варианты оплаты': 1, | |
'возврат средств': 2, | |
'восстановление пароля': 3, | |
'время доставки': 4, | |
'выбор адреса доставки': 5, | |
'жалоба': 6, | |
'изменение адреса доставки': 7, | |
'изменение заказа': 8, | |
'отзыв': 9, | |
'отмена заказа': 10, | |
'отслеживание возврата средств': 11, | |
'отслеживание заказа': 12, | |
'подписка на новостную рассылку': 13, | |
'политика возврата': 14, | |
'получение информации': 15, | |
'проблемы с оплатой': 16, | |
'проблемы с регистрацией': 17, | |
'проверка платы за отмену': 18, | |
'проверка счета': 19, | |
'проверка счетов': 20, | |
'размещение заказа': 21, | |
'редактирование учетной записи': 22, | |
'связь с человеком': 23, | |
'связь со службой поддержки': 24, | |
'смена учетной записи': 25, | |
'создание учетной записи': 26, | |
'удаление аккаунта': 27} | |
label_mapping_priority = {'высокий': 0, 'низкий': 1, 'средний': 2} | |
def get_key_by_value(dictionary, value): | |
reverse_dict = {v: k for k, v in dictionary.items()} | |
return reverse_dict.get(value) | |
def predict(text): | |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
input_ids = inputs["input_ids"] | |
attention_mask = inputs["attention_mask"] | |
# Получаем предсказания для двух задач | |
logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask) | |
# Преобразование логитов в предсказания классов | |
pred_task1 = torch.argmax(logits_task1, dim=1).item() | |
pred_task2 = torch.argmax(logits_task2, dim=1).item() | |
return {"Тема": get_key_by_value(label_mapping_type, pred_task1)} | |
iface = gr.Interface( | |
fn=predict, | |
inputs=gr.Textbox(lines=2, placeholder="Введите запрос для анализа..."), # Обновлено на gr.Textbox | |
outputs=gr.JSON(), # Обновлено на gr.JSON | |
title="Классификация запроса", | |
description="'варианты доставки', 'варианты оплаты', 'возврат средств', 'восстановление пароля', 'время доставки', 'выбор адреса доставки', 'жалоба', 'изменение адреса доставки', 'изменение заказа', 'отзыв', 'отмена заказа', 'отслеживание возврата средств', 'отслеживание заказа', 'подписка на новостную рассылку', 'политика возврата', 'получение информации', 'проблемы с оплатой', 'проблемы с регистрацией', 'проверка платы за отмену', 'проверка счета', 'размещение заказа', 'редактирование учетной записи', 'связь с человеком', 'связь со службой поддержки', 'смена учетной записи', 'создание учетной записи', 'удаление аккаунта'", | |
) | |
iface.launch(share=True) | |