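"""Builds a PyTorch Lightning Trainer wired to a Weights & Biases logger.

Run names are derived from the model name and start time; debug runs are
prefixed with "debug-" and limited to a single epoch over a few batches.
"""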
import datetime

import pytorch_lightning as pl
from pytorch_lightning import loggers

from src import config


def _get_wandb_logger(trainer_config: config.TrainerConfig):
    """Create a W&B logger whose run name encodes the model and start time."""
    name = f"{config.MODEL_NAME}-{datetime.datetime.now()}"
    if trainer_config.debug:
        name = "debug-" + name
    return loggers.WandbLogger(
        entity=config.WANDB_ENTITY,
        save_dir=config.WANDB_LOG_PATH,
        project=config.MODEL_NAME,
        name=name,
        # Record the model hyperparameters in the W&B run config.
        config=trainer_config._model_config.to_dict(),
    )


def get_trainer(trainer_config: config.TrainerConfig):
    """Build a Trainer; debug mode runs one epoch over at most 5 batches."""
    return pl.Trainer(
        max_epochs=1 if trainer_config.debug else trainer_config.epochs,
        logger=_get_wandb_logger(trainer_config),
        log_every_n_steps=trainer_config.log_every_n_steps,
        gradient_clip_val=1.0,
        limit_train_batches=5 if trainer_config.debug else 1.0,
        limit_val_batches=5 if trainer_config.debug else 1.0,
        accelerator="auto",
    )
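

# Example usage (sketch, assuming a `trainer_config` instance of
# config.TrainerConfig plus a LightningModule/LightningDataModule defined
# elsewhere in the project):
#
#     trainer = get_trainer(trainer_config)
#     trainer.fit(model, datamodule=datamodule)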