from transformers import AutoTokenizer, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

# Fine-tuning hyperparameters handed to TrainingArguments in CITDA._get_trainer.
CITDA_EPOCHS = 6
CITDA_WEIGHT_DECAY = 0.05  # weight-decay coefficient (regularization strength)
CITDA_BATCH_SIZE = 32  # used for both the train and eval per-device batch size
CITDA_LEARNINGRATE = 2e-5


class CITDA:
    """Fine-tune a sequence classifier with the Hugging Face ``Trainer``.

    The instance bundles a model, its tokenizer and an already-encoded
    dataset that exposes ``"train"`` and ``"validation"`` splits.
    """

    def __init__(self, model, labels, base_model_name, tokenizer, encoded_data):
        """Store the objects needed for training.

        Args:
            model: Model instance passed to ``Trainer``.
            labels: Label collection; stored on the instance as ``labels``.
            base_model_name: Accepted for interface compatibility; not used
                by this class.
            tokenizer: Tokenizer passed to ``Trainer`` and saved after
                training.
            encoded_data: Mapping-like dataset indexed by ``"train"`` and
                ``"validation"``.
        """
        self.labels = labels
        self.tokenizer = tokenizer
        self.model = model
        self.encoded_data = encoded_data

    def _get_trainer(self):
        """Build a ``Trainer`` configured from the module-level constants.

        Returns:
            A ``Trainer`` wired to the train/validation splits that evaluates
            and checkpoints once per epoch and reloads the best checkpoint
            (by weighted F1) when training ends.
        """

        def compute_metrics(pred):
            """Return accuracy and weighted F1 for one evaluation pass."""
            labels = pred.label_ids
            preds = pred.predictions.argmax(-1)
            f1 = f1_score(labels, preds, average="weighted")
            acc = accuracy_score(labels, preds)
            return {"accuracy": acc, "f1": f1}

        training_args = TrainingArguments(
            output_dir="results",
            num_train_epochs=CITDA_EPOCHS,
            learning_rate=CITDA_LEARNINGRATE,
            per_device_train_batch_size=CITDA_BATCH_SIZE,
            per_device_eval_batch_size=CITDA_BATCH_SIZE,
            # Evaluation and save strategies are kept identical ("epoch") so a
            # checkpoint exists for every evaluated epoch when the best model
            # is reloaded at the end.
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            weight_decay=CITDA_WEIGHT_DECAY,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            disable_tqdm=False,
            report_to="wandb",
        )
        trainer = Trainer(
            model=self.model,
            tokenizer=self.tokenizer,
            args=training_args,
            compute_metrics=compute_metrics,
            train_dataset=self.encoded_data["train"],
            eval_dataset=self.encoded_data["validation"],
        )
        return trainer

    def train(self, save_dir="./"):
        """Fine-tune the model, predict on validation data and save artifacts.

        Args:
            save_dir: Directory the model and tokenizer are written to.
                Defaults to the current working directory.

        Returns:
            Tuple ``(y_valid, y_pred)``: the validation split's ``"label"``
            column as a NumPy array, and the argmax over axis 1 of the
            predictions for that split.
        """
        trainer = self._get_trainer()
        trainer.train()
        # The returned metrics were never used, so they are not bound to a
        # name; the call itself is kept so the final evaluation pass still runs.
        trainer.evaluate()

        preds_output = trainer.predict(self.encoded_data["validation"])
        y_valid = np.array(self.encoded_data["validation"]["label"])
        y_pred = np.argmax(preds_output.predictions, axis=1)

        # Saving the fine-tuned model
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)
        return y_valid, y_pred