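# Spaces: Sleeping
# NOTE(review): the line above appears to be a hosting-page status banner
# captured during extraction, not part of the program -- confirm and drop.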
import torch
from transformers import Trainer, TrainingArguments

from model import get_model_and_tokenizer
from data_loader import get_dataloader
from utils import load_config, set_seed
def main():
    """Run one training job end to end and save the resulting model.

    Steps, in order:
      1. Load ``configs/model_config.yaml`` and seed from
         ``config['training']['seed']``.
      2. Build the model and tokenizer from the config.
      3. Build the ``'train'`` and ``'validation'`` data loaders.
      4. Train with ``transformers.Trainer`` and write the final model to
         ``./final_model``.

    Config keys read from ``config['training']``: ``seed``, ``num_epochs``,
    ``batch_size`` and ``save_every``.
    """
    config = load_config('configs/model_config.yaml')
    train_cfg = config['training']
    set_seed(train_cfg['seed'])

    model, tokenizer = get_model_and_tokenizer(config)

    train_dataloader = get_dataloader(config, tokenizer, 'train')
    val_dataloader = get_dataloader(config, tokenizer, 'validation')

    # NOTE(review): with load_best_model_at_end=True the transformers docs
    # require save_steps to be a round multiple of eval_steps (1000 here);
    # confirm that config['training']['save_every'] satisfies this.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the pinned library version.
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=train_cfg['num_epochs'],
        # The same batch size is used for training and evaluation.
        per_device_train_batch_size=train_cfg['batch_size'],
        per_device_eval_batch_size=train_cfg['batch_size'],
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=1000,
        save_steps=train_cfg['save_every'],
        load_best_model_at_end=True,
    )

    # Only the underlying datasets are handed to the Trainer; presumably the
    # Trainer then does its own batching rather than reusing these loaders --
    # TODO confirm that any custom collation in get_dataloader is not needed.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        eval_dataset=val_dataloader.dataset,
    )

    trainer.train()
    trainer.save_model("./final_model")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()