|
from transformers import AutoTokenizer, Trainer, TrainingArguments |
|
from sklearn.metrics import accuracy_score, f1_score |
|
|
|
import numpy as np |
|
|
|
# Default fine-tuning hyperparameters read by CITDA._get_trainer when it
# builds the TrainingArguments.
CITDA_EPOCHS = 6  # passed as num_train_epochs
CITDA_WEIGHT_DECAY = 0.05  # passed as weight_decay
CITDA_BATCH_SIZE = 32  # used for both the per-device train and eval batch size
CITDA_LEARNINGRATE = 2e-5  # passed as learning_rate
|
|
|
class CITDA:
    """Fine-tune a classification model with ``Trainer`` and score it on validation data.

    The instance bundles a model, its tokenizer and a pre-encoded dataset;
    ``train()`` runs the training loop, evaluates, saves both artefacts and
    returns the validation labels alongside the predicted class indices.
    """

    def __init__(self, model, labels, base_model_name, tokenizer, encoded_data):
        """Store the objects needed for training.

        Args:
            model: Model handed to ``Trainer``; must provide ``save_pretrained``.
            labels: Label collection for the task (stored; not read by this class).
            base_model_name: Name of the base checkpoint (stored; not read by
                this class).
            tokenizer: Tokenizer handed to ``Trainer``; must provide
                ``save_pretrained``.
            encoded_data: Mapping with ``"train"`` and ``"validation"`` entries.
                The validation entry must be indexable by ``"label"``.
        """
        self.labels = labels
        # Previously accepted but discarded; kept so callers can inspect it.
        self.base_model_name = base_model_name
        self.tokenizer = tokenizer
        self.model = model
        self.encoded_data = encoded_data

    @staticmethod
    def _compute_metrics(pred):
        """Return accuracy and weighted F1 for one evaluation pass.

        Args:
            pred: Object exposing ``label_ids`` and ``predictions``; the
                predicted class is the argmax over the last axis of
                ``predictions``.

        Returns:
            ``{"accuracy": ..., "f1": ...}``.
        """
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        f1 = f1_score(labels, preds, average="weighted")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1}

    def _get_trainer(self):
        """Build the ``Trainer`` configured with the module-level hyperparameters.

        Returns:
            A ``Trainer`` wired to ``self.model``, ``self.tokenizer`` and the
            ``"train"`` / ``"validation"`` entries of ``self.encoded_data``.
        """
        # NOTE(review): recent transformers releases rename
        # ``evaluation_strategy`` to ``eval_strategy`` and ``Trainer(tokenizer=)``
        # to ``processing_class`` -- confirm against the pinned version before
        # upgrading.
        training_args = TrainingArguments(
            output_dir="results",
            num_train_epochs=CITDA_EPOCHS,
            learning_rate=CITDA_LEARNINGRATE,
            per_device_train_batch_size=CITDA_BATCH_SIZE,
            per_device_eval_batch_size=CITDA_BATCH_SIZE,
            # Evaluation and saving share the "epoch" cadence; the best
            # checkpoint by "f1" is requested at the end of training.
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            weight_decay=CITDA_WEIGHT_DECAY,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            disable_tqdm=False,
            report_to="wandb",
        )
        return Trainer(
            model=self.model,
            tokenizer=self.tokenizer,
            args=training_args,
            compute_metrics=self._compute_metrics,
            train_dataset=self.encoded_data["train"],
            eval_dataset=self.encoded_data["validation"],
        )

    def train(self, *, save_dir="./"):
        """Train, evaluate, save the model and tokenizer, and return predictions.

        Args:
            save_dir: Directory passed to ``save_pretrained`` for both the
                model and the tokenizer. Defaults to the current directory,
                which is what this method always used before.

        Returns:
            Tuple ``(y_valid, y_pred)``: the validation ``"label"`` values as a
            NumPy array, and the argmax over axis 1 of the predictions made on
            the validation split.
        """
        trainer = self._get_trainer()
        trainer.train()
        # The return value was never used; the call is kept because it may
        # log evaluation metrics to the configured reporter.
        trainer.evaluate()
        preds_output = trainer.predict(self.encoded_data["validation"])

        y_valid = np.array(self.encoded_data["validation"]["label"])
        y_pred = np.argmax(preds_output.predictions, axis=1)

        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

        return y_valid, y_pred
|
|
|
|