import os, sys |
import argparse |
import shutil |
import subprocess |
from omegaconf import OmegaConf |
import torch |
from pytorch_lightning import seed_everything |
from pytorch_lightning.trainer import Trainer |
from pytorch_lightning.strategies import DDPStrategy |
from pytorch_lightning.callbacks import Callback |
from pytorch_lightning.utilities import rank_zero_only |
from src.utils.train_util import instantiate_from_config |
import os |
# Configure PyTorch's CUDA caching allocator (max_split_size_mb=512).
# NOTE(review): presumably meant to limit allocator fragmentation; it only
# takes effect if set before CUDA is initialized in this process.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:512")
@rank_zero_only
def rank_zero_print(*args, **kwargs):
    """Print only from the global-rank-0 process.

    Positional and keyword arguments are forwarded unchanged to the built-in
    ``print`` (e.g. ``sep``, ``end``, ``file``, ``flush``), so existing
    positional-only callers behave exactly as before.
    """
    print(*args, **kwargs)
def get_parser(**parser_kwargs):
    """Build the command-line parser for this training script.

    Args:
        **parser_kwargs: Forwarded unchanged to ``argparse.ArgumentParser``
            (e.g. ``prog``, ``description``).

    Returns:
        argparse.ArgumentParser: parser defining ``--resume``,
        ``--resume_weights_only``, ``--base``, ``--name``, ``--num_nodes``,
        ``--gpus``, ``--seed`` and ``--logdir``.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-r", "--resume", type=str, default=None,
        help="resume from checkpoint",
    )
    parser.add_argument(
        "--resume_weights_only", action="store_true",
        help="only resume model weights",
    )
    parser.add_argument(
        "-b", "--base", type=str, default="base_config.yaml",
        help="path to base configs",
    )
    parser.add_argument(
        "-n", "--name", type=str, default="",
        help="experiment name",
    )
    parser.add_argument(
        "--num_nodes", type=int, default=1,
        help="number of nodes to use",
    )
    # Comma-separated ids such as "0,1,". The training entry point in this
    # file derives only the number of devices from this string.
    parser.add_argument(
        "--gpus", type=str, default="0,",
        help="gpu ids to use",
    )
    parser.add_argument(
        "-s", "--seed", type=int, default=42,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-l", "--logdir", type=str, default="logs",
        help="directory for logging data",
    )
    return parser
class ClearCacheCallback(Callback):
    """Call ``torch.cuda.empty_cache()`` after every training and validation batch.

    This releases unoccupied blocks held by PyTorch's CUDA caching allocator.
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        torch.cuda.empty_cache()

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        # `dataloader_idx` must have a default (as in the base Callback):
        # Lightning 2.x only passes it when there are multiple validation
        # dataloaders, so a required parameter raises TypeError otherwise.
        torch.cuda.empty_cache()
class SetupCallback(Callback):
    """Prepare the run's output directories and record the config at fit start.

    Only the global-rank-0 process touches the filesystem; every other rank
    returns immediately.
    """

    def __init__(self, resume, logdir, ckptdir, cfgdir, config):
        super().__init__()
        # NOTE(review): `resume` is stored but not read anywhere in this class.
        self.resume = resume
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank != 0:
            return
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)
        rank_zero_print("Project config")
        rank_zero_print(OmegaConf.to_yaml(self.config))
        config_path = os.path.join(self.cfgdir, "project.yaml")
        OmegaConf.save(self.config, config_path)
class CodeSnapshot(Callback):
    """Copy the project's source files into ``savedir`` when fitting starts.

    The files copied are those tracked by git (excluding ``configs/``) plus
    untracked files that are not git-ignored. Failure (no git, not a
    repository, unreadable file) is reported and does not stop training.

    Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60
    """

    def __init__(self, savedir):
        super().__init__()
        self.savedir = savedir

    def get_file_list(self):
        """Return the paths (relative to the working directory) to snapshot.

        ``-z`` makes git emit raw NUL-separated paths instead of quoting or
        escaping names with spaces or non-ASCII characters.
        ``os.fsdecode`` turns them back into paths usable with the filesystem.

        Raises:
            subprocess.CalledProcessError: git exited non-zero (e.g. the
                working directory is not inside a git repository).
            FileNotFoundError: the ``git`` executable was not found.
        """
        tracked = subprocess.check_output(
            ["git", "ls-files", "-z", "--", ":!:configs/*"]
        )
        untracked = subprocess.check_output(
            ["git", "ls-files", "-z", "--others", "--exclude-standard"]
        )
        names = set(tracked.split(b"\0")) | set(untracked.split(b"\0"))
        # The trailing NUL produces an empty entry; drop it.
        return [os.fsdecode(name) for name in names if name]

    @rank_zero_only
    def save_code_snapshot(self):
        """Copy each listed file into ``savedir``, preserving relative paths."""
        os.makedirs(self.savedir, exist_ok=True)
        for rel_path in self.get_file_list():
            # git may list paths that are not regular files here (e.g. a
            # tracked file deleted from the worktree, or a submodule dir).
            if not os.path.isfile(rel_path):
                continue
            dest = os.path.join(self.savedir, rel_path)
            os.makedirs(os.path.dirname(dest), exist_ok=True)
            shutil.copyfile(rel_path, dest)

    def on_fit_start(self, trainer, pl_module):
        # Best effort: a missing snapshot must never abort training.
        try:
            self.save_code_snapshot()
        except (subprocess.CalledProcessError, OSError):
            rank_zero_print(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )
if __name__ == "__main__":
    # Make modules under the launch directory importable. NOTE(review):
    # presumably needed so instantiate_from_config can import dotted targets
    # such as "train.SetupCallback" used below.
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    if unknown:
        # parse_known_args drops unrecognized flags; surface them so typos
        # do not go unnoticed.
        rank_zero_print(f"Ignoring unrecognized arguments: {unknown}")

    # Run directory: <logdir>/<config file stem>[-<name>] with
    # checkpoints/, configs/ and code/ subdirectories.
    cfg_name = os.path.splitext(os.path.basename(opt.base))[0]
    exp_name = f"-{opt.name}" if opt.name else ""
    logdir = os.path.join(opt.logdir, cfg_name + exp_name)
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    codedir = os.path.join(logdir, "code")

    seed_everything(opt.seed)

    config = OmegaConf.load(opt.base)
    lightning_config = config.lightning
    # Edited in place, so these values also appear in the config that
    # SetupCallback saves (it receives `config` when callbacks are built).
    trainer_config = lightning_config.trainer
    trainer_config["accelerator"] = "cuda"
    rank_zero_print(f"Running on GPUs {opt.gpus}")
    # NOTE: only the *count* of ids in --gpus is used (e.g. "2,3" -> 2
    # devices); choose physical GPUs with CUDA_VISIBLE_DEVICES.
    trainer_config["devices"] = len(opt.gpus.strip(",").split(","))

    model = instantiate_from_config(config.model)
    if opt.resume and opt.resume_weights_only:
        # Weights only: no ckpt_path is passed to trainer.fit() below, so
        # optimizer and loop state start fresh.
        model = model.__class__.load_from_checkpoint(opt.resume, **config.model.params)
    model.logdir = logdir

    trainer_kwargs = dict()

    default_logger_cfg = {
        "target": "pytorch_lightning.loggers.TensorBoardLogger",
        "params": {
            "name": "tensorboard",
            "save_dir": logdir,
            "version": "0",
        }
    }
    # Like modelcheckpoint/callbacks below, an optional `lightning.logger`
    # section overrides the defaults.
    if "logger" in lightning_config:
        logger_cfg = lightning_config.logger
    else:
        logger_cfg = OmegaConf.create()
    logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    default_modelckpt_cfg = {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckptdir,
            "filename": "{step:08}",
            "verbose": True,
            "save_last": True,
            "every_n_train_steps": 5000,
            "save_top_k": -1,
        }
    }
    if "modelcheckpoint" in lightning_config:
        modelckpt_cfg = lightning_config.modelcheckpoint
    else:
        modelckpt_cfg = OmegaConf.create()
    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)

    default_callbacks_cfg = {
        "setup_callback": {
            "target": "train.SetupCallback",
            "params": {
                "resume": opt.resume,
                "logdir": logdir,
                "ckptdir": ckptdir,
                "cfgdir": cfgdir,
                "config": config,
            }
        },
        "learning_rate_logger": {
            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
            }
        },
        "code_snapshot": {
            "target": "train.CodeSnapshot",
            "params": {
                "savedir": codedir,
            }
        },
    }
    default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg
    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()
    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
    trainer_kwargs["callbacks"] = [
        instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
    trainer_kwargs["callbacks"].append(ClearCacheCallback())

    # precision/strategy are defaults only: a value under `lightning.trainer`
    # takes precedence (passing both would raise a duplicate-keyword TypeError).
    trainer_args = dict(trainer_config)
    trainer_args.setdefault("precision", "32-true")
    if "strategy" not in trainer_args:
        trainer_args["strategy"] = DDPStrategy(find_unused_parameters=True)
    trainer = Trainer(**trainer_args, **trainer_kwargs, num_nodes=opt.num_nodes)
    trainer.logdir = logdir

    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup("fit")

    base_lr = config.model.base_learning_rate
    if "accumulate_grad_batches" in lightning_config.trainer:
        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
    else:
        accumulate_grad_batches = 1
    rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}")
    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches

    # The configured learning rate is used as-is (no scaling by device
    # count, batch size or gradient accumulation).
    model.learning_rate = base_lr
    rank_zero_print("++++ NOT USING LR SCALING ++++")
    rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}")

    # A full resume restores trainer/optimizer state from the checkpoint;
    # with --resume_weights_only the weights were already loaded above.
    ckpt_path = opt.resume if opt.resume and not opt.resume_weights_only else None
    trainer.fit(model, data, ckpt_path=ckpt_path)