bert-base-uncased-emotion / finetune-emotions.py
sabersol's picture
class created
4e81dd6
raw
history blame
1.77 kB
# Adapted from https://github.com/bhadreshpsavani/ExploringSentimentalAnalysis/blob/main/SentimentalAnalysisWithDistilbert.ipynb
import torch
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
#import matplotlib.pyplot as plt
import seaborn as sns
import explainableai
# Hugging Face checkpoint name used for both the tokenizer and the classifier.
BASE_MODEL_NAME = "bert-base-uncased"
# Use a CUDA device when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def save_confusion_matrix(y_valid, y_preds, labels=None, path="confusion_matrix.png"):
    """Render the confusion matrix of true vs. predicted labels as an annotated heatmap and save it.

    Args:
        y_valid: ground-truth class labels.
        y_preds: predicted class labels, same length as ``y_valid``.
        labels: optional sequence of class names used as tick labels. When
            given, classes are assumed to be integer ids ``0..len(labels)-1``
            (TODO confirm against the dataset) so the matrix is always square
            and aligned with the names, even if a class never occurs.
        path: output image file; defaults to the original filename.
    """
    if labels is None:
        cm = confusion_matrix(y_valid, y_preds)
        tick_labels = "auto"  # seaborn's default tick behavior
    else:
        cm = confusion_matrix(y_valid, y_preds, labels=list(range(len(labels))))
        tick_labels = list(labels)

    ax = sns.heatmap(cm, annot=True, fmt='d',
                     xticklabels=tick_labels, yticklabels=tick_labels)
    # sklearn puts true classes on rows and predictions on columns.
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    fig = ax.figure
    fig.savefig(path)
    # sns.heatmap draws on pyplot's current axes; clear the figure so a later
    # call starts fresh instead of stacking a second heatmap and colorbar.
    fig.clf()
def get_encoded_data(tokenizer):
    """Load the ``emotion`` dataset and tokenize all of its splits with *tokenizer*.

    Each split is passed to the tokenizer as one single batch
    (``batch_size=None``), so ``padding=True`` pads every example of a split
    to that split's longest sequence. The result is formatted to yield torch
    tensors for the ``input_ids``, ``attention_mask`` and ``label`` columns.
    """
    raw_splits = load_dataset("emotion")

    def _encode(examples):
        # Truncate over-long texts; pad within the (whole-split) batch.
        return tokenizer(examples["text"], padding=True, truncation=True)

    encoded = raw_splits.map(_encode, batched=True, batch_size=None)
    encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    return encoded
if __name__ == "__main__":
    # Class names indexed by class id.
    labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
    # transformers expects id2label as a {id: name} dict (and label2id as its
    # inverse); a list of one-entry dicts does not give the config that mapping.
    id2label = {i: name for i, name in enumerate(labels)}
    label2id = {name: i for i, name in enumerate(labels)}

    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=BASE_MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        resume_download=True,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    encoded_data = get_encoded_data(tokenizer)

    # CITDA is project code; train() is unpacked here into (true, predicted) labels.
    citda = explainableai.CITDA(model, labels, BASE_MODEL_NAME, tokenizer, encoded_data)
    y_valid, y_pred = citda.train()
    # Previously passed the undefined name `y_preds`, raising NameError after training.
    save_confusion_matrix(y_valid, y_pred)
    print("y_valid=", len(y_valid), "y_pred=", len(y_pred))