# Author: hpoghos — commit f949b3f ("add code")
from pathlib import Path
from pytorch_lightning import Callback
import os
import torch
from lightning_fabric.utilities.cloud_io import get_filesystem
from pytorch_lightning.cli import LightningArgumentParser
from pytorch_lightning import LightningModule, Trainer
from lightning_utilities.core.imports import RequirementCache
from omegaconf import OmegaConf
# jsonargparse (with its ``signatures`` extra, version >= 4.17.0) is optional:
# everything that depends on it is imported only when that requirement is met.
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.17.0")

if not _JSONARGPARSE_SIGNATURES_AVAILABLE:
    # Fallback placeholders so module-level annotations that mention these names
    # (e.g. ``config: Namespace`` in SaveConfigCallback) still evaluate without
    # jsonargparse. They are injected via ``locals()`` (the module namespace at top
    # level) instead of plain assignment — presumably to keep static type checkers
    # from flagging a redefinition of the imported names.
    locals()["ArgumentParser"] = object
    locals()["Namespace"] = object
else:
    import docstring_parser
    from jsonargparse import (
        ActionConfigFile,
        ArgumentParser,
        Namespace,
        class_from_function,
        register_unresolvable_import_paths,
        set_config_read_mode,
    )

    # Required until fix https://github.com/pytorch/pytorch/issues/74483
    register_unresolvable_import_paths(torch)
    # Allow jsonargparse to read config files through fsspec-backed paths.
    set_config_read_mode(fsspec_enabled=True)
class SaveConfigCallback(Callback):
    """Saves a LightningCLI config to ``log_dir`` when training starts.

    The config is written only once, by the global-rank-zero process, on the first
    ``setup`` call; later ``setup`` calls are no-ops. After writing, rank zero reloads
    the saved YAML with OmegaConf and passes it to ``log_hyperparams`` of every logger
    attached to the trainer (nothing is logged when the trainer has no logger).

    Args:
        parser: The parser object used to parse the configuration.
        config: The parsed configuration that will be saved.
        log_dir: Directory the config file is written to (a local path or any location
            understood by ``get_filesystem``).
        config_filename: Filename for the config file.
        overwrite: Whether to overwrite an existing config file.
        multifile: When input is multiple config files, saved config preserves this structure.

    Raises:
        RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
        ValueError: If ``log_dir`` is ``None`` when ``setup`` runs.
    """

    def __init__(
        self,
        parser: LightningArgumentParser,
        config: Namespace,
        log_dir: str,
        config_filename: str = "config.yaml",
        overwrite: bool = False,
        multifile: bool = False,
    ) -> None:
        self.parser = parser
        self.config = config
        self.config_filename = config_filename
        self.overwrite = overwrite
        self.multifile = multifile
        # Set once the config has been written; synced across ranks in ``setup``.
        self.already_saved = False
        self.log_dir = log_dir

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Write the config (rank zero only) on the first call; later calls do nothing."""
        if self.already_saved:
            return

        log_dir = self.log_dir
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if log_dir is None:
            raise ValueError(f"{self.__class__.__name__} requires a `log_dir` to save the config to.")
        config_path = os.path.join(log_dir, self.config_filename)
        fs = get_filesystem(log_dir)

        if not self.overwrite:
            # Only rank 0 checks the filesystem; the result is broadcast so that all
            # ranks agree on whether to abort.
            file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
            file_exists = trainer.strategy.broadcast(file_exists)
            if file_exists:
                raise RuntimeError(
                    f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
                    " results of a previous run. You can delete the previous config file,"
                    " set `LightningCLI(save_config_callback=None)` to disable config saving,"
                    ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.'
                )

        if trainer.is_global_zero:
            # Save only on rank zero to avoid race conditions. The `log_dir` needs to be
            # created here: the logger usually does it, but nothing has been logged yet.
            fs.makedirs(log_dir, exist_ok=True)
            self.parser.save(
                self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
            )
            self.already_saved = True
            self._log_saved_config(trainer, fs, config_path)

        # Broadcast so that all ranks are in sync on future calls to .setup()
        self.already_saved = trainer.strategy.broadcast(self.already_saved)

    @staticmethod
    def _log_saved_config(trainer: Trainer, fs, config_path: str) -> None:
        """Reload the saved YAML and pass it to every logger attached to ``trainer``."""
        loggers = trainer.loggers
        if not loggers:
            # Trainer was built without a logger (e.g. ``logger=False``): nothing to log to.
            return
        # Read back through ``fs`` (the filesystem the file was written with) so a
        # non-local ``log_dir`` works; a bare path passed to OmegaConf.load is opened locally.
        with fs.open(config_path, "r", encoding="utf-8") as f:
            hparams = OmegaConf.load(f)
        for logger in loggers:
            logger.log_hyperparams(hparams)