Spaces:
Runtime error
Runtime error
import argparse | |
import glob | |
from pathlib import Path | |
from omegaconf import OmegaConf | |
import pytorch_lightning as pl | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from pytorch_lightning.loggers import TensorBoardLogger | |
from lightning import FontLightningModule | |
from utils import save_files | |
def load_configuration(path_config):
    """Load the training configuration referenced by a setting file.

    The setting file at ``path_config`` is expected to expose a ``config``
    node whose ``dataset``, ``model`` and ``logging`` entries are paths to
    further config files. Those three files are loaded and merged in that
    order (later files override earlier ones) into one hyperparameter config.

    Args:
        path_config: Path of the top-level setting file.

    Returns:
        ``(hp, pl_config)`` when the setting file has a ``lightning`` entry;
        ``pl_config`` is the ``pl_config`` node of that file if it has one,
        otherwise the whole loaded file.
        Only ``hp`` (not a tuple) when there is no ``lightning`` entry, so
        callers must handle both return shapes.
    """
    setting = OmegaConf.load(path_config)
    sources = setting.config

    # load hyperparameter: dataset first, then model and logging on top of it
    hp = OmegaConf.load(sources.dataset)
    for section in ('model', 'logging'):
        hp = OmegaConf.merge(hp, OmegaConf.load(getattr(sources, section)))

    # without lightning setting
    if not hasattr(sources, 'lightning'):
        return hp

    # with lightning setting
    pl_config = OmegaConf.load(sources.lightning)
    lightning_settings = pl_config.pl_config if hasattr(pl_config, 'pl_config') else pl_config
    return hp, lightning_settings
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour existing callers rely on.

    Returns:
        argparse.Namespace with the attributes ``config``, ``gpus`` and
        ``resume_checkpoint_path``.
    """
    parser = argparse.ArgumentParser(description='Code to train font style transfer')
    parser.add_argument("--config", type=str, default="./config/setting.yaml",
                        help="Config file for training")
    # The help used to claim "Will use all if not given", but the default is
    # '0,1', so the option is never None. The text now states what happens.
    parser.add_argument('-g', '--gpus', type=str, default='0,1',
                        help="Comma-separated ids of the gpus to use (e.g. '0,1,2,3'). "
                             "Defaults to '0,1'.")
    parser.add_argument('-p', '--resume_checkpoint_path', type=str, default=None,
                        help="path of checkpoint for resuming")
    return parser.parse_args(argv)
def main():
    """Train the font style transfer model as configured on the command line.

    Parses the arguments, loads the merged hyperparameters and the Lightning
    settings, builds ``FontLightningModule``, copies the files matched by
    ``hp.logging.savefiles`` into the log directory via ``save_files``, then
    creates the TensorBoard logger, the checkpoint callback and the trainer
    and runs ``fit``.

    Raises:
        ValueError: If the setting file has no ``lightning`` entry, because
            the trainer and checkpoint settings are required here.
    """
    args = parse_args()

    # load_configuration returns a bare ``hp`` (not a pair) when the setting
    # file has no lightning entry; unpacking that directly would fail in a
    # confusing way, so report the actual problem instead.
    loaded = load_configuration(args.config)
    if not isinstance(loaded, tuple):
        raise ValueError(
            f"'{args.config}' has no 'config.lightning' entry; "
            "trainer and checkpoint settings are required for training"
        )
    hp, pl_config = loaded
    logging_dir = Path(hp.logging.log_dir)

    # call lightning module
    font_pl = FontLightningModule(hp)

    # set logging
    tensorboard_dir = logging_dir / 'tensorboard'
    hp.logging['log_dir'] = tensorboard_dir
    savefiles = [
        matched
        for pattern in hp.logging.savefiles
        for matched in glob.glob(pattern)
    ]
    # parents=True: the log root may not exist yet on a first run.
    tensorboard_dir.mkdir(parents=True, exist_ok=True)
    save_files(str(logging_dir), savefiles)

    # set tensorboard logger
    logger = TensorBoardLogger(str(logging_dir), name=str(hp.logging.seed))

    # set checkpoint callback
    weights_save_path = logging_dir / 'checkpoint' / str(hp.logging.seed)
    # parents=True: the intermediate 'checkpoint' directory is not created
    # anywhere else in this function.
    weights_save_path.mkdir(parents=True, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(weights_save_path),
        **pl_config.checkpoint.callback
    )

    # set lightning trainer
    trainer = pl.Trainer(
        logger=logger,
        gpus=-1 if args.gpus is None else args.gpus,
        callbacks=[checkpoint_callback],
        **pl_config.trainer
    )

    # let's train
    if args.resume_checkpoint_path is None:
        trainer.fit(font_pl)
    else:
        # --resume_checkpoint_path used to be parsed but never used.
        # NOTE(review): ``ckpt_path`` assumes a Lightning version whose
        # ``Trainer.fit`` accepts it (1.5 or newer) -- confirm the pin.
        trainer.fit(font_pl, ckpt_path=args.resume_checkpoint_path)
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()