Fine-tuned rubert-base-cased for multi-label emotion classification.
The model was trained on the ru_go_emotions_ekman dataset. The comments were originally translated into Russian in seara/ru_go_emotions, a Russian translation of the GoEmotions dataset. The machine translation was generated with Google Translate.
The original 27 emotions from GoEmotions were mapped to the 6 basic emotions of Dr. Ekman's theory, with neutral kept as a separate seventh label.
Labels predicted by the classifier:

- 0: anger
- 1: disgust
- 2: fear
- 3: joy
- 4: sadness
- 5: surprise
- 6: neutral
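Because this is a multi-label classifier, it produces one logit per label, and each label is decided independently rather than by an argmax over the 7 classes. The sketch below shows plain-Python decoding under two assumptions: the raw logits for one text are already available (for example, from the model's output), and a label is accepted when its probability exceeds 0.5, a threshold this card does not state. `ID2LABEL`, `sigmoid` and `decode` are hypothetical names.

```python
import math

# Label IDs as listed in this model card.
ID2LABEL = {0: "anger", 1: "disgust", 2: "fear", 3: "joy",
            4: "sadness", 5: "surprise", 6: "neutral"}


def sigmoid(x):
    """Numerically stable logistic function mapping a logit to (0, 1)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def decode(logits, threshold=0.5):
    """Return every label whose independent sigmoid probability exceeds threshold.

    The default threshold of 0.5 is an assumption, not taken from the model card.
    """
    if len(logits) != len(ID2LABEL):
        raise ValueError(f"expected {len(ID2LABEL)} logits, got {len(logits)}")
    return [ID2LABEL[i] for i, z in enumerate(logits) if sigmoid(z) > threshold]
```

For example, `decode([2.1, -3.0, -2.5, 1.4, -4.0, -1.0, -2.2])` returns `["anger", "joy"]`: two labels can fire at once, and an all-negative vector returns an empty list.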
Mapping from the 27 GoEmotions emotions (plus neutral) to the 6 basic emotions of Dr. Ekman's theory:
GoEmotions | Ekman |
---|---|
admiration | joy |
amusement | joy |
anger | anger |
annoyance | anger |
approval | joy |
caring | joy |
confusion | surprise |
curiosity | surprise |
desire | joy |
disappointment | sadness |
disapproval | anger |
disgust | disgust |
embarrassment | sadness |
excitement | joy |
fear | fear |
gratitude | joy |
grief | sadness |
joy | joy |
love | joy |
nervousness | fear |
optimism | joy |
pride | joy |
realization | surprise |
relief | joy |
remorse | sadness |
sadness | sadness |
surprise | surprise |
neutral | neutral |
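The mapping table can be written as a lookup that turns a comment's GoEmotions labels into a multi-hot target over the 7 classifier labels. This is a minimal sketch and not necessarily how ru_go_emotions_ekman was built. It assumes that GoEmotions labels collapsing to the same Ekman emotion yield that label once. `GO_TO_EKMAN`, `to_ekman` and `to_multi_hot` are hypothetical names.

```python
# GoEmotions label -> Ekman label, transcribed from the mapping table.
GO_TO_EKMAN = {
    "admiration": "joy", "amusement": "joy", "anger": "anger",
    "annoyance": "anger", "approval": "joy", "caring": "joy",
    "confusion": "surprise", "curiosity": "surprise", "desire": "joy",
    "disappointment": "sadness", "disapproval": "anger", "disgust": "disgust",
    "embarrassment": "sadness", "excitement": "joy", "fear": "fear",
    "gratitude": "joy", "grief": "sadness", "joy": "joy", "love": "joy",
    "nervousness": "fear", "optimism": "joy", "pride": "joy",
    "realization": "surprise", "relief": "joy", "remorse": "sadness",
    "sadness": "sadness", "surprise": "surprise", "neutral": "neutral",
}

# Ekman labels in the classifier's label-ID order (0..6).
EKMAN_LABELS = ["anger", "disgust", "fear", "joy", "sadness", "surprise", "neutral"]


def to_ekman(go_labels):
    """Map GoEmotions labels to deduplicated Ekman labels, in label-ID order."""
    mapped = {GO_TO_EKMAN[label] for label in go_labels}
    return [label for label in EKMAN_LABELS if label in mapped]


def to_multi_hot(go_labels):
    """Encode the mapped Ekman labels as a 7-element 0/1 float target vector."""
    mapped = set(to_ekman(go_labels))
    return [1.0 if label in mapped else 0.0 for label in EKMAN_LABELS]
```

For instance, a comment tagged `amusement` and `gratitude` maps to the single label `joy`, while one tagged `nervousness` and `annoyance` maps to both `anger` and `fear`.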
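The random number generators were seeded with 42:

```python
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
```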
Training parameters:

- max_length: null
- batch_size: 32
- shuffle: True
- num_workers: 2
- pin_memory: False
- drop_last: False
- optimizer: adam
- lr: 0.00001
- weight_decay: 0
- problem_type: multi_label_classification
- num_epochs: 4
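Setting `problem_type: multi_label_classification` makes Hugging Face `transformers` sequence-classification heads compute the loss with `torch.nn.BCEWithLogitsLoss` on multi-hot float targets, so each of the 7 labels is an independent binary decision. The remaining keys appear to correspond to arguments of `torch.utils.data.DataLoader` (`batch_size`, `shuffle`, `num_workers`, `pin_memory`, `drop_last`) and `torch.optim.Adam` (`lr`, `weight_decay`). The dependency-free sketch below illustrates only that loss for a single example, averaged over labels. It is not the training code used for this model, and `log_sigmoid` and `bce_with_logits` are hypothetical names.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed directly from logits.

    Per label: loss_i = -[y_i * log(sigmoid(z_i)) + (1 - y_i) * log(1 - sigmoid(z_i))]
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for z, y in zip(logits, targets):
        # log(1 - sigmoid(z)) equals log(sigmoid(-z)), which avoids cancellation.
        total += -(y * log_sigmoid(z) + (1.0 - y) * log_sigmoid(-z))
    return total / len(logits)
```

With all-zero logits every label has probability 0.5, so the loss is log(2) ≈ 0.693 regardless of the targets. Confident, correct logits drive it toward 0.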
Evaluation results on the test split of ru_go_emotions_ekman:
Label | Precision | Recall | F1-Score | AUC-ROC | Support |
---|---|---|---|---|---|
anger | 0.56 | 0.44 | 0.49 | 0.86 | 726 |
disgust | 0.65 | 0.24 | 0.36 | 0.92 | 123 |
fear | 0.64 | 0.60 | 0.62 | 0.93 | 98 |
joy | 0.79 | 0.80 | 0.80 | 0.91 | 2104 |
sadness | 0.68 | 0.44 | 0.53 | 0.89 | 379 |
surprise | 0.60 | 0.52 | 0.56 | 0.88 | 677 |
neutral | 0.65 | 0.58 | 0.61 | 0.82 | 1787 |
micro avg | 0.69 | 0.62 | 0.65 | 0.92 | 5894 |
macro avg | 0.65 | 0.52 | 0.57 | 0.89 | 5894 |
weighted avg | 0.69 | 0.62 | 0.65 | 0.87 | 5894 |
samples avg | 0.65 | 0.64 | 0.64 | nan | 5894 |
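The macro and weighted rows can be recomputed from the per-class rows, which shows why they differ. Macro averaging gives the small disgust and fear classes the same weight as joy, whereas weighted averaging is dominated by the large joy and neutral classes. The small sketch below uses the F1-Score and Support columns. Because the per-class values are already rounded, the recomputed averages match the table only to two decimals. Micro averaging pools true and false positives across all classes and cannot be recomputed from this table alone.

```python
# Per-class F1 scores and supports copied from the evaluation table (rounded values).
F1 = {"anger": 0.49, "disgust": 0.36, "fear": 0.62, "joy": 0.80,
      "sadness": 0.53, "surprise": 0.56, "neutral": 0.61}
SUPPORT = {"anger": 726, "disgust": 123, "fear": 98, "joy": 2104,
           "sadness": 379, "surprise": 677, "neutral": 1787}


def macro_avg(scores):
    """Unweighted mean: every class counts equally."""
    return sum(scores.values()) / len(scores)


def weighted_avg(scores, support):
    """Support-weighted mean: frequent classes count more."""
    total = sum(support.values())
    return sum(scores[c] * support[c] for c in scores) / total


print(round(macro_avg(F1), 2))              # 0.57
print(round(weighted_avg(F1, SUPPORT), 2))  # 0.65
```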