"""Fine-tune a BERT sequence classifier on the `emotion` dataset via explainableai.CITDA.

After training, the confusion matrix of the validation predictions returned by
``CITDA.train()`` is written to ``confusion_matrix.png``.
"""
import torch
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import explainableai
import os
from dotenv import load_dotenv

# Populate os.environ from a local .env file, if one exists.
# NOTE(review): presumably supplies settings/tokens for downstream libraries — confirm.
load_dotenv()

BASE_MODEL_NAME = "bert-base-uncased"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


def save_confusion_matrix(y_valid, y_pred, labels=None, path="confusion_matrix.png"):
    """Plot the confusion matrix of ``y_valid`` vs ``y_pred`` as a heatmap and save it.

    Args:
        y_valid: Ground-truth class labels.
        y_pred: Predicted class labels, same length as ``y_valid``.
        labels: Optional class names used as tick labels. They are applied only
            when their count matches the matrix size; otherwise seaborn's
            default numeric ticks are used.
        path: Output image path (defaults to ``"confusion_matrix.png"``).
    """
    cm = confusion_matrix(y_valid, y_pred)

    # Draw into a dedicated figure instead of matplotlib's implicit "current"
    # figure, so repeated calls don't overdraw each other or leak figures.
    fig, ax = plt.subplots()
    if labels is not None and len(labels) == cm.shape[0]:
        tick_labels = list(labels)
    else:
        tick_labels = "auto"
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        ax=ax,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
    )
    # sklearn's convention: rows are true labels, columns are predictions.
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.savefig(path)
    plt.close(fig)


def get_encoded_data(tokenizer):
    """Load the ``emotion`` dataset and tokenize its ``text`` column.

    Args:
        tokenizer: A Hugging Face tokenizer used to encode each text.

    Returns:
        The ``DatasetDict`` with ``input_ids``, ``attention_mask`` and ``label``
        exposed as torch tensors.
    """
    def tokenize(batch):
        return tokenizer(batch["text"], padding=True, truncation=True)

    emotions = load_dataset("emotion")
    # batch_size=None maps each split as a single batch, so padding=True pads
    # every example in a split to that split's longest (truncated) sequence.
    emotions_encoded = emotions.map(tokenize, batched=True, batch_size=None)
    emotions_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    return emotions_encoded


def main():
    """Build the model and tokenizer, train via CITDA, and save the confusion matrix."""
    labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

    # Transformers expects id2label as an int -> name mapping (and label2id as
    # its inverse), not a list of single-entry dicts.
    id2label = dict(enumerate(labels))
    label2id = {name: idx for idx, name in id2label.items()}

    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=BASE_MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        resume_download=True,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    encoded_data = get_encoded_data(tokenizer)

    citda = explainableai.CITDA(model, labels, BASE_MODEL_NAME, tokenizer, encoded_data)
    y_valid, y_pred = citda.train()

    save_confusion_matrix(y_valid, y_pred, labels)
    print("y_valid=", len(y_valid), "y_pred=", len(y_pred))


if __name__ == "__main__":
    main()