File size: 1,771 Bytes
4e81dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Modified from https://github.com/bhadreshpsavani/ExploringSentimentalAnalysis/blob/main/SentimentalAnalysisWithDistilbert.ipynb

import torch
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
#import matplotlib.pyplot as plt
import seaborn as sns
import explainableai

BASE_MODEL_NAME = "bert-base-uncased"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def save_confusion_matrix(y_valid, y_preds, filename="confusion_matrix.png"):
    """Render the confusion matrix of true vs. predicted labels and save it as an image.

    Args:
        y_valid: Ground-truth labels.
        y_preds: Predicted labels, aligned element-wise with ``y_valid``.
        filename: Path of the image to write. Defaults to
            "confusion_matrix.png" in the current working directory.
    """
    cm = confusion_matrix(y_valid, y_preds)
    # fmt='d' annotates each cell with its integer count.
    ax = sns.heatmap(cm, annot=True, fmt='d')
    fig = ax.figure
    fig.savefig(filename)
    # sns.heatmap draws onto the currently active axes; clear the figure so a
    # later call starts from a blank canvas instead of drawing over this plot
    # (and stacking another colorbar next to it).
    fig.clf()

def get_encoded_data(tokenizer):
    """Load the "emotion" dataset and tokenize it for use with PyTorch.

    Args:
        tokenizer: A Hugging Face tokenizer; it is applied to the "text"
            column of every split.

    Returns:
        The mapped dataset from ``load_dataset("emotion")``, with its format
        set to "torch" for the "input_ids", "attention_mask" and "label"
        columns.
    """
    def _encode_batch(batch):
        # Tokenize a batch of raw texts, padding and truncating as configured.
        text = batch["text"]
        return tokenizer(text, padding=True, truncation=True)

    raw = load_dataset("emotion")
    # batch_size=None hands each split to _encode_batch as a single batch
    # (see datasets `Dataset.map` docs), so padding applies across the split.
    encoded = raw.map(_encode_batch, batched=True, batch_size=None)
    torch_columns = ["input_ids", "attention_mask", "label"]
    encoded.set_format("torch", columns=torch_columns)
    return encoded
if __name__ == "__main__":
    # Class names indexed by integer label id. The order is assumed to match
    # the "emotion" dataset's label encoding — TODO confirm against the
    # dataset's `label` feature names.
    labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
    # transformers expects id2label as a dict {int id: name} (and label2id as
    # its inverse); a list of single-entry dicts is rejected by the config.
    id2label = {i: label for i, label in enumerate(labels)}
    label2id = {label: i for i, label in enumerate(labels)}
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=BASE_MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        resume_download=True,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

    encoded_data = get_encoded_data(tokenizer)
    # CITDA comes from the project-local explainableai package; train() is
    # presumably returning (true labels, predicted labels) — verify.
    citda = explainableai.CITDA(model, labels, BASE_MODEL_NAME, tokenizer, encoded_data)
    y_valid, y_pred = citda.train()

    # Pass the unpacked `y_pred` (previously referenced an undefined `y_preds`).
    save_confusion_matrix(y_valid, y_pred)
    print("y_valid=",len(y_valid), "y_pred=", len(y_pred))