|
import datetime |
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning import loggers |
|
|
|
from src import config |
|
|
|
|
|
def _get_wandb_logger(trainer_config: config.TrainerConfig):
    """Build the Weights & Biases logger for one training run.

    The run name is ``<MODEL_NAME>-<timestamp>``, where the timestamp is
    ``str(datetime.datetime.now())`` (naive local time). It is prefixed with
    ``debug-`` when ``trainer_config.debug`` is truthy.

    Args:
        trainer_config: Trainer settings. Its ``debug`` flag controls the
            run-name prefix. The dict returned by
            ``trainer_config._model_config.to_dict()`` is passed as the
            logger's ``config``.

    Returns:
        A ``loggers.WandbLogger`` that logs to ``config.WANDB_ENTITY`` under
        the project ``config.MODEL_NAME``. Local files go to
        ``config.WANDB_LOG_PATH``.
    """
    started_at = datetime.datetime.now()
    # Debug runs get a visible prefix so they are easy to filter out in W&B.
    prefix = "debug-" if trainer_config.debug else ""
    run_name = f"{prefix}{config.MODEL_NAME}-{started_at}"

    return loggers.WandbLogger(
        entity=config.WANDB_ENTITY,
        save_dir=config.WANDB_LOG_PATH,
        project=config.MODEL_NAME,
        name=run_name,
        # NOTE(review): reads a private attribute of TrainerConfig; a public
        # accessor would be preferable if one exists — confirm.
        config=trainer_config._model_config.to_dict(),
    )
|
|
|
|
|
def get_trainer(
    trainer_config: config.TrainerConfig,
    *,
    gradient_clip_val: float = 1.0,
    debug_batch_limit: int = 5,
):
    """Create the PyTorch Lightning trainer for a run.

    In debug mode (``trainer_config.debug`` truthy) the trainer runs a
    single epoch. Training and validation are each capped at
    ``debug_batch_limit`` batches. Otherwise it runs
    ``trainer_config.epochs`` epochs over the full datasets.

    Args:
        trainer_config: Trainer settings. Provides ``debug``, ``epochs`` and
            ``log_every_n_steps``. It is also used to build the W&B logger.
        gradient_clip_val: Gradient clipping value forwarded to
            ``pl.Trainer``. Defaults to 1.0, the previously hard-coded value.
        debug_batch_limit: Number of train/val batches per epoch in debug
            mode. Defaults to 5, the previously hard-coded value.

    Returns:
        A configured ``pl.Trainer`` with ``accelerator="auto"``.
    """
    debug = trainer_config.debug

    # Lightning reads an int limit as "this many batches" and a float as
    # "this fraction of batches". 1.0 therefore means the full dataset.
    batch_limit = debug_batch_limit if debug else 1.0

    return pl.Trainer(
        max_epochs=1 if debug else trainer_config.epochs,
        logger=_get_wandb_logger(trainer_config),
        log_every_n_steps=trainer_config.log_every_n_steps,
        gradient_clip_val=gradient_clip_val,
        limit_train_batches=batch_limit,
        limit_val_batches=batch_limit,
        accelerator="auto",
    )
|
|