alimboff commited on
Commit
90e9236
1 Parent(s): 98af1e4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +11 -62
README.md CHANGED
@@ -14,25 +14,25 @@ tags:
14
  - code
15
  ---
16
 
17
- ### README.md для модели классификации 3-х языков (русский, кабардинский, карачаево-балкарский)
18
 
19
  #### Описание модели
20
- Эта модель классифицирует тексты на три языка: русский (`rus_Cyrl`), кабардинский (`kbd_Cyrl`) и карачаево-балкарский (`krc_Cyrl`). Модель основана на архитектуре BERT и обучена на специализированном корпусе, охватывающем данные для каждого из указанных языков. Модель показывает высокую точность на этапе валидации и обладает высокой скоростью работы как на GPU, так и на CPU.
21
 
22
  #### Результаты обучения
23
 
24
  ```
25
  Epoch 1/3
26
- Train loss 0.0431 accuracy 0.9889
27
- Val loss 0.0014 accuracy 1.0000
28
  ----------
29
  Epoch 2/3
30
- Train loss 0.0111 accuracy 0.9974
31
- Val loss 0.0023 accuracy 0.9994
32
  ----------
33
  Epoch 3/3
34
- Train loss 0.0081 accuracy 0.9982
35
- Val loss 0.0013 accuracy 1.0000
36
  ```
37
 
38
  #### Производительность
@@ -42,7 +42,7 @@ Val loss 0.0013 accuracy 1.0000
42
 
43
  #### Использование модели
44
 
45
- ##### 1. Код для работы с моделью (возвращает один label):
46
 
47
  ```python
48
  import torch
@@ -81,59 +81,8 @@ def predict(text):
81
 
82
  return predicted_class
83
 
84
- while True:
85
- text = input("Текст>>> ")
86
- print(predict(text))
87
- ```
88
-
89
- ##### 2. Код для работы с моделью (возвращает вероятности через softmax):
90
-
91
- ```python
92
- import torch
93
- from transformers import BertTokenizer, BertForSequenceClassification
94
- import torch.nn.functional as F
95
-
96
- model_path = 'BERT_v3/zehedz'
97
-
98
- model = BertForSequenceClassification.from_pretrained(model_path, num_labels=3, problem_type="single_label_classification")
99
- tokenizer = BertTokenizer.from_pretrained(model_path)
100
-
101
- def predict(text):
102
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
103
- model.to(device)
104
- model.eval()
105
-
106
- encoding = tokenizer.encode_plus(
107
- text,
108
- add_special_tokens=True,
109
- max_length=512,
110
- return_token_type_ids=False,
111
- truncation=True,
112
- padding='max_length',
113
- return_attention_mask=True,
114
- return_tensors='pt',
115
- )
116
-
117
- input_ids = encoding['input_ids'].to(device)
118
- attention_mask = encoding['attention_mask'].to(device)
119
-
120
- with torch.no_grad():
121
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
122
- logits = outputs.logits
123
-
124
- probs = F.softmax(logits, dim=1).cpu().numpy()[0]
125
-
126
- labels = ['kbd_Cyrl', 'rus_Cyrl', 'krc_Cyrl']
127
- for i, label in enumerate(labels):
128
- print(f"Class: {label}, Probability: {probs[i]:.4f}")
129
-
130
- predicted_class = labels[torch.argmax(logits, dim=1).cpu().numpy()[0]]
131
-
132
- return predicted_class
133
-
134
- while True:
135
- text = input("Текст>>> ")
136
- predict(text)
137
  ```
138
 
139
  #### Использование в API Space на Hugging Face
 
14
  - code
15
  ---
16
 
17
+ ### Zehedz
18
 
19
  #### Описание модели
20
+ Эта модель классифицирует тексты на три языка: Русский (`rus_Cyrl`), Кабардино-Черкесский (`kbd_Cyrl`) и Карачаево-Балкарский (`krc_Cyrl`). Модель основана на архитектуре BERT и обучена на специализированном корпусе, охватывающем данные для каждого из указанных языков. Модель показывает высокую точность на этапе валидации и обладает высокой скоростью работы как на GPU, так и на CPU.
21
 
22
  #### Результаты обучения
23
 
24
  ```
25
  Epoch 1/3
26
+ Train loss: 0.0431 | accuracy: 0.9889
27
+ Val loss: 0.0014 | accuracy: 1.0000
28
  ----------
29
  Epoch 2/3
30
+ Train loss: 0.0111 | accuracy: 0.9974
31
+ Val loss: 0.0023 | accuracy: 0.9994
32
  ----------
33
  Epoch 3/3
34
+ Train loss: 0.0081 | accuracy: 0.9982
35
+ Val loss: 0.0013 | accuracy: 1.0000
36
  ```
37
 
38
  #### Производительность
 
42
 
43
  #### Использование модели
44
 
45
+ ##### Код для работы с моделью:
46
 
47
  ```python
48
  import torch
 
81
 
82
  return predicted_class
83
 
84
+ text = "Привет, как дела?"
85
+ print(predict(text))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  ```
87
 
88
  #### Использование в API Space на Hugging Face