metadata
license: cc-by-4.0
language:
- ru
- kbd
- krc
metrics:
- accuracy
pipeline_tag: text-classification
datasets:
- alimboff/adiga_taulu_corpus
library_name: transformers
tags:
- code
Zehedz
Описание модели
Эта модель классифицирует тексты на три языка: Русский (rus_Cyrl
), Кабардино-Черкесский (kbd_Cyrl
) и Карачаево-Балкарский (krc_Cyrl
). Модель основана на архитектуре BERT и обучена на специализированном корпусе, охватывающем данные для каждого из указанных языков. Модель показывает высокую точность на этапе валидации и обладает высокой скоростью работы как на GPU, так и на CPU.
Результаты обучения
Epoch 1/3
Train loss: 0.0431 | accuracy: 0.9889
Val loss: 0.0014 | accuracy: 1.0000
----------
Epoch 2/3
Train loss: 0.0111 | accuracy: 0.9974
Val loss: 0.0023 | accuracy: 0.9994
----------
Epoch 3/3
Train loss: 0.0081 | accuracy: 0.9982
Val loss: 0.0013 | accuracy: 1.0000
Производительность
- Средняя скорость работы на GPU (CUDA): 0.008 секунд на одно предсказание
- Средняя скорость работы на CPU: 0.05 секунд на одно предсказание
Использование модели
Код для работы с моделью:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
model_path = 'BERT_v3/zehedz'
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=3, problem_type="single_label_classification")
tokenizer = BertTokenizer.from_pretrained(model_path)
def predict(text):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
encoding = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=512,
return_token_type_ids=False,
truncation=True,
padding='max_length',
return_attention_mask=True,
return_tensors='pt',
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
labels = ['kbd_Cyrl', 'rus_Cyrl', 'krc_Cyrl']
predicted_class = labels[torch.argmax(logits, dim=1).cpu().numpy()[0]]
return predicted_class
text = "Привет, как дела?"
print(predict(text))
Использование в API Space на Hugging Face
import torch
Эта модель идеально подходит для задач, связанных с автоматическим определением языка текста в многоязычных системах и приложениях.