Spaces:
Sleeping
Sleeping
File size: 1,294 Bytes
a4d6124 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import dataclasses

import torch
from transformers import Trainer, TrainingArguments

from model import get_model_and_tokenizer
from data_loader import get_dataloader
from utils import load_config, set_seed
def main():
config = load_config('configs/model_config.yaml')
set_seed(config['training']['seed'])
model, tokenizer = get_model_and_tokenizer(config)
train_dataloader = get_dataloader(config, tokenizer, 'train')
val_dataloader = get_dataloader(config, tokenizer, 'validation')
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=config['training']['num_epochs'],
per_device_train_batch_size=config['training']['batch_size'],
per_device_eval_batch_size=config['training']['batch_size'],
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
logging_steps=100,
evaluation_strategy="steps",
eval_steps=1000,
save_steps=config['training']['save_every'],
load_best_model_at_end=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataloader.dataset,
eval_dataset=val_dataloader.dataset,
)
trainer.train()
trainer.save_model("./final_model")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()