|
import torch |
|
from sklearn.metrics import confusion_matrix |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from datasets import load_dataset |
|
|
|
import seaborn as sns |
|
import explainableai |
|
import os |
|
from dotenv import load_dotenv |
|
# Load variables from a local .env file (if present) into the process
# environment before the model/dataset code below runs.
load_dotenv()

# Hugging Face checkpoint name shared by the tokenizer, the classifier and
# the CITDA wrapper in the script entry point.
BASE_MODEL_NAME = "bert-base-uncased"

# Run on the GPU whenever CUDA is usable; fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
|
|
|
def save_confusion_matrix(y_valid, y_pred, path="confusion_matrix.png", tick_labels=None):
    """Render the confusion matrix of ``y_pred`` vs. ``y_valid`` as a heatmap image.

    Args:
        y_valid: Ground-truth labels, in any form accepted by
            ``sklearn.metrics.confusion_matrix``.
        y_pred: Predicted labels, aligned element-wise with ``y_valid``.
        path: Output image file. Defaults to ``"confusion_matrix.png"``,
            the name this function always used before.
        tick_labels: Optional class names for both axes. They are ordered like
            the matrix rows and columns, i.e. the sorted set of labels that
            appear in ``y_valid`` or ``y_pred``. ``None`` keeps seaborn's
            automatic tick labels.
    """
    # matplotlib is a hard dependency of seaborn. pyplot is imported only to
    # create and close a dedicated figure.
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y_valid, y_pred)
    ticks = "auto" if tick_labels is None else tick_labels

    # Use a fresh figure. Without an explicit ``ax``, sns.heatmap draws on
    # pyplot's *current* axes, so a second call would stack a new heatmap and a
    # second colorbar onto the previous image.
    fig, ax = plt.subplots()
    try:
        sns.heatmap(cm, annot=True, fmt="d", ax=ax, xticklabels=ticks, yticklabels=ticks)
        # sklearn puts true labels on rows and predictions on columns.
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        fig.savefig(path)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
|
|
|
def get_encoded_data(tokenizer):
    """Download the ``emotion`` dataset and tokenize its ``text`` column.

    Args:
        tokenizer: Callable tokenizer applied to lists of strings (in this
            script, the ``AutoTokenizer`` loaded for ``BASE_MODEL_NAME``).

    Returns:
        The object returned by ``load_dataset("emotion")`` after tokenization.
        It is formatted as torch tensors exposing only the ``input_ids``,
        ``attention_mask`` and ``label`` columns.
    """
    def _encode_batch(examples):
        # Truncate over-long texts; pad every text in the batch to the
        # batch's longest sequence.
        return tokenizer(examples["text"], padding=True, truncation=True)

    raw_splits = load_dataset("emotion")

    # batch_size=None feeds each split to _encode_batch as a single batch.
    # padding=True therefore pads a whole split to that split's longest text.
    tokenized = raw_splits.map(_encode_batch, batched=True, batch_size=None)

    tensor_columns = ["input_ids", "attention_mask", "label"]
    tokenized.set_format("torch", columns=tensor_columns)
    return tokenized
|
|
|
def main():
    """Build the classifier, run ``explainableai.CITDA`` training, and save a confusion matrix.

    What ``CITDA.train()`` does internally is not visible here. This script
    only relies on it returning ``(y_valid, y_pred)`` sequences that are
    accepted by ``save_confusion_matrix`` and ``len``.
    """
    id2label = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
    # Pass the inverse mapping too. Otherwise the model config keeps generic
    # "LABEL_i" keys in label2id, which contradict id2label.
    label2id = {name: idx for idx, name in id2label.items()}
    labels = list(id2label.values())

    # resume_download is no longer passed: it is deprecated in current
    # transformers/huggingface_hub releases, where downloads always resume.
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=BASE_MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

    encoded_data = get_encoded_data(tokenizer)

    citda = explainableai.CITDA(model, labels, BASE_MODEL_NAME, tokenizer, encoded_data)
    y_valid, y_pred = citda.train()

    save_confusion_matrix(y_valid, y_pred)
    print("y_valid=", len(y_valid), "y_pred=", len(y_pred))


if __name__ == "__main__":
    main()
|
|