# File size: 1,294 Bytes
# Revision: a4d6124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from transformers import Trainer, TrainingArguments
from model import get_model_and_tokenizer
from data_loader import get_dataloader
from utils import load_config, set_seed

def main():
    """Fine-tune the configured model with the Hugging Face ``Trainer``.

    Loads ``configs/model_config.yaml``, seeds the run, builds the model,
    tokenizer and the train/validation data, trains, and saves the final
    model to ``./final_model``. Checkpoints are written to ``./results``
    and logs to ``./logs``.

    Config keys read under ``training``: ``seed``, ``num_epochs``,
    ``batch_size`` and ``save_every``.

    Raises:
        KeyError: if one of the config keys above is missing.
    """
    config = load_config('configs/model_config.yaml')
    training_cfg = config['training']
    seed = training_cfg['seed']
    set_seed(seed)

    model, tokenizer = get_model_and_tokenizer(config)
    train_dataloader = get_dataloader(config, tokenizer, 'train')
    val_dataloader = get_dataloader(config, tokenizer, 'validation')

    # load_best_model_at_end requires save_steps to be a round multiple of
    # eval_steps, otherwise TrainingArguments raises ValueError. Keep the
    # 1000-step evaluation cadence when the configured save interval is
    # compatible with it; otherwise evaluate at every checkpoint instead.
    save_steps = training_cfg['save_every']
    eval_steps = 1000
    if save_steps % eval_steps != 0:
        eval_steps = save_steps

    # NOTE(review): recent transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the pinned transformers version.
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=training_cfg['num_epochs'],
        per_device_train_batch_size=training_cfg['batch_size'],
        per_device_eval_batch_size=training_cfg['batch_size'],
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        load_best_model_at_end=True,
        # Trainer re-seeds from args.seed (default 42); pass the configured
        # seed so it actually governs training, not just model construction.
        seed=seed,
    )

    # Only the underlying datasets are handed over; Trainer builds its own
    # data loaders from them using the batch sizes in `training_args`.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        eval_dataset=val_dataloader.dataset,
    )

    trainer.train()
    trainer.save_model("./final_model")

# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()