import inspect

import torch
from transformers import Trainer, TrainingArguments

from model import get_model_and_tokenizer
from data_loader import get_dataloader
from utils import load_config, set_seed


def _eval_strategy_kwarg():
    """Return the keyword name TrainingArguments accepts for the eval strategy.

    transformers renamed ``evaluation_strategy`` to ``eval_strategy`` and later
    dropped the old name, so hard-coding either one breaks on some installed
    versions. Inspecting the signature keeps this script working on both.
    """
    params = inspect.signature(TrainingArguments).parameters
    return "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"


def main():
    """Fine-tune the configured model with the Hugging Face Trainer.

    Reads ``configs/model_config.yaml``, seeds via ``set_seed``, builds the
    model/tokenizer and the 'train' / 'validation' datasets, trains, and saves
    the final model to ``./final_model``.

    Config keys read from ``config['training']``: ``seed``, ``num_epochs``,
    ``batch_size``, ``save_every``, and optionally ``eval_every``
    (default 1000, the previously hard-coded value).
    """
    config = load_config('configs/model_config.yaml')
    train_cfg = config['training']
    set_seed(train_cfg['seed'])

    model, tokenizer = get_model_and_tokenizer(config)
    # Only the underlying datasets are passed to Trainer below; Trainer builds
    # its own DataLoaders from them.
    train_dataloader = get_dataloader(config, tokenizer, 'train')
    val_dataloader = get_dataloader(config, tokenizer, 'validation')

    # With load_best_model_at_end=True, save_steps must be a round multiple of
    # eval_steps, so the eval interval is configurable alongside save_every.
    eval_steps = train_cfg.get('eval_every', 1000)

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=train_cfg['num_epochs'],
        per_device_train_batch_size=train_cfg['batch_size'],
        per_device_eval_batch_size=train_cfg['batch_size'],
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        eval_steps=eval_steps,
        save_steps=train_cfg['save_every'],
        load_best_model_at_end=True,
        # Pass the strategy under whichever keyword this transformers
        # version supports ('eval_strategy' or legacy 'evaluation_strategy').
        **{_eval_strategy_kwarg(): "steps"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        eval_dataset=val_dataloader.dataset,
    )

    trainer.train()
    trainer.save_model("./final_model")


if __name__ == "__main__":
    main()