# Source: bert-base-uncased-emotion / finetune-emotions.py
# Uploaded by sabersol; commit "model updated" (f80494f); file size 1.75 kB.
import torch
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
#import matplotlib.pyplot as plt
import seaborn as sns
import explainableai
import os
from dotenv import load_dotenv
load_dotenv()
# Name of the pretrained checkpoint handed to the model and tokenizer loaders.
BASE_MODEL_NAME = "bert-base-uncased"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
def save_confusion_matrix(y_valid, y_pred, output_path="confusion_matrix.png"):
    """Save the confusion matrix of a set of predictions as a heatmap image.

    Args:
        y_valid: Ground-truth labels, in any form accepted by
            ``sklearn.metrics.confusion_matrix``.
        y_pred: Predicted labels, aligned element-wise with ``y_valid``.
        output_path: Where the rendered figure is written. Defaults to
            ``"confusion_matrix.png"`` (relative to the working directory).
    """
    cm = confusion_matrix(y_valid, y_pred)
    # annot=True with fmt='d' prints each cell's count as an integer.
    ax = sns.heatmap(cm, annot=True, fmt='d')
    fig = ax.figure
    try:
        fig.savefig(output_path)
    finally:
        # No `ax` is passed, so the heatmap lands on the currently active
        # axes. Clear the figure so a later call does not stack a second
        # heatmap and colorbar on top of this one.
        fig.clf()
def get_encoded_data(tokenizer):
    """Load the "emotion" dataset and tokenize its text column.

    Args:
        tokenizer: Callable tokenizer; it is invoked with the batch's
            ``"text"`` values plus ``padding=True, truncation=True``.

    Returns:
        The mapped dataset object, with its output format set to ``"torch"``
        for the ``input_ids``, ``attention_mask`` and ``label`` columns.
    """
    def _encode(examples):
        # Pad to the longest item in the batch and truncate over-long inputs.
        return tokenizer(examples["text"], padding=True, truncation=True)

    raw_dataset = load_dataset("emotion")

    # batched=True with batch_size=None: presumably hands each split to
    # `_encode` as one batch, so padding is uniform across the whole split
    # (confirm against the installed `datasets` version).
    encoded_dataset = raw_dataset.map(_encode, batched=True, batch_size=None)

    tensor_columns = ["input_ids", "attention_mask", "label"]
    encoded_dataset.set_format("torch", columns=tensor_columns)
    return encoded_dataset
if __name__ == "__main__":
    # Class index -> emotion name. Presumably mirrors the label order of the
    # "emotion" dataset -- verify against the dataset's features.
    id2label = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
    # Inverse mapping. Without it the model config keeps auto-generated
    # "LABEL_i" names in label2id, which would disagree with id2label.
    label2id = {name: idx for idx, name in id2label.items()}
    labels = list(id2label.values())

    # Load the base checkpoint with a classification head sized to the label
    # set, then move it to the selected device.
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=BASE_MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        resume_download=True,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    encoded_data = get_encoded_data(tokenizer)

    # CITDA comes from the `explainableai` module (not visible here); its
    # train() result is unpacked as (true labels, predicted labels).
    citda = explainableai.CITDA(model, labels, BASE_MODEL_NAME, tokenizer, encoded_data)
    y_valid, y_pred = citda.train()

    save_confusion_matrix(y_valid, y_pred)
    print("y_valid=",len(y_valid), "y_pred=", len(y_pred))