File size: 1,746 Bytes
4e81dd6
 
 
 
 
 
 
7044b4a
 
 
4e81dd6
 
 
7044b4a
 
4e81dd6
7044b4a
 
4e81dd6
 
 
 
 
 
 
 
 
 
7044b4a
4e81dd6
f80494f
 
4e81dd6
 
 
f80494f
4e81dd6
 
 
 
 
 
 
 
7044b4a
4e81dd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
#import matplotlib.pyplot as plt
import seaborn as sns
import explainableai
import os
from dotenv import load_dotenv
load_dotenv()

# Checkpoint name handed to both the tokenizer and the classifier loaders
# in the script entry point below.
BASE_MODEL_NAME = "bert-base-uncased"

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

def save_confusion_matrix(y_valid, y_pred, path="confusion_matrix.png"):
    """Plot the confusion matrix of the predictions and save it as an image.

    Args:
        y_valid: Ground-truth labels, passed to ``confusion_matrix`` as the
            true values.
        y_pred: Predicted labels, passed to ``confusion_matrix`` as the
            predictions.
        path: Where the image is written. Defaults to
            ``"confusion_matrix.png"``, the file name that was previously
            hard-coded.
    """
    cm = confusion_matrix(y_valid, y_pred)
    # annot=True writes the value into each cell; fmt='d' formats it as an
    # integer count.
    ax = sns.heatmap(cm, annot=True, fmt='d')
    fig = ax.figure
    try:
        fig.savefig(path)
    finally:
        # No explicit axes are passed to heatmap, so it draws on whatever
        # figure is current. Clear that figure once the image is written so
        # a later call does not stack a second heatmap and colorbar on top.
        fig.clf()

def get_encoded_data(tokenizer):
    """Load the "emotion" dataset and tokenize its ``text`` column.

    Args:
        tokenizer: Callable applied to a batch of texts with ``padding=True``
            and ``truncation=True``.

    Returns:
        The mapped dataset, formatted to yield torch tensors for the
        ``input_ids``, ``attention_mask`` and ``label`` columns.
    """
    def _encode(examples):
        texts = examples["text"]
        return tokenizer(texts, padding=True, truncation=True)

    dataset = load_dataset("emotion")
    # NOTE(review): batch_size=None is expected to hand each split to
    # _encode as one batch, so padding=True would pad to the longest text in
    # that split -- confirm against the datasets documentation.
    encoded = dataset.map(_encode, batched=True, batch_size=None)
    tensor_columns = ["input_ids", "attention_mask", "label"]
    encoded.set_format("torch", columns=tensor_columns)
    return encoded

if __name__ == "__main__":
    # Script entry point: load the base model with a six-class head, encode
    # the dataset, train through explainableai.CITDA, then save and report
    # the confusion matrix.
    id2label = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
    # NOTE(review): the inverse mapping is passed explicitly so both
    # directions of the label mapping in the model config agree; confirm it
    # is not already derived from id2label upstream.
    label2id = {name: idx for idx, name in id2label.items()}
    # Class names ordered by id, independent of dict insertion order.
    labels = [id2label[idx] for idx in sorted(id2label)]

    # NOTE(review): resume_download is kept as in the original call; check
    # that the installed transformers/huggingface_hub still accept it.
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=BASE_MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        resume_download=True,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

    encoded_data = get_encoded_data(tokenizer)
    # explainableai.CITDA is project code not visible here; train() is used
    # as returning two sequences that are fed to the confusion matrix as the
    # true and predicted labels.
    citda = explainableai.CITDA(model, labels, BASE_MODEL_NAME, tokenizer, encoded_data)
    y_valid, y_pred = citda.train()

    save_confusion_matrix(y_valid, y_pred)
    print("y_valid=", len(y_valid), "y_pred=", len(y_pred))