Codex_Prime / src /train.py
dnnsdunca's picture
Create src/train.py
a4d6124 verified
raw
history blame
1.29 kB
import torch
from transformers import Trainer, TrainingArguments
from model import get_model_and_tokenizer
from data_loader import get_dataloader
from utils import load_config, set_seed
def main(config_path: str = 'configs/model_config.yaml') -> None:
    """Train the configured model with the Hugging Face ``Trainer`` and save it.

    Steps: load the config, call ``set_seed``, build the model/tokenizer and
    the train/validation dataloaders, run training with periodic evaluation
    and checkpointing, then write the final model (and tokenizer, when it
    supports ``save_pretrained``) to ``./final_model``.

    Args:
        config_path: Path passed to ``load_config``. Defaults to the path
            that used to be hard-coded, so ``main()`` behaves as before.
    """
    import warnings

    config = load_config(config_path)
    train_cfg = config['training']
    set_seed(train_cfg['seed'])

    model, tokenizer = get_model_and_tokenizer(config)
    # Only the ``.dataset`` of each loader is given to the Trainer below.
    train_dataloader = get_dataloader(config, tokenizer, 'train')
    val_dataloader = get_dataloader(config, tokenizer, 'validation')

    # With load_best_model_at_end=True, Transformers requires save_steps to be
    # a round multiple of eval_steps (every checkpoint needs a matching
    # evaluation). Keep the 1000-step evaluation cadence when the configured
    # save interval allows it; otherwise evaluate at the save cadence instead
    # of failing inside TrainingArguments.
    save_steps = train_cfg['save_every']
    eval_steps = 1000
    if isinstance(save_steps, (int, float)) and save_steps % eval_steps != 0:
        warnings.warn(
            f"save_every={save_steps} is not a multiple of eval_steps={eval_steps}; "
            f"evaluating every {save_steps} steps so that "
            "load_best_model_at_end can pair checkpoints with evaluations."
        )
        eval_steps = save_steps

    batch_size = train_cfg['batch_size']
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=train_cfg['num_epochs'],
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        # NOTE(review): recent Transformers releases rename this argument to
        # ``eval_strategy`` -- confirm against the pinned version.
        evaluation_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        eval_dataset=val_dataloader.dataset,
    )
    trainer.train()

    final_dir = "./final_model"
    trainer.save_model(final_dir)
    # Store the tokenizer next to the weights so the directory can be reloaded
    # on its own. Guarded because the tokenizer type is defined outside this
    # file -- presumably a Hugging Face tokenizer; verify in ``model``.
    if hasattr(tokenizer, 'save_pretrained'):
        tokenizer.save_pretrained(final_dir)
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()