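"""Training entry point for MapLocNet.

Builds a GenericModule and the UAV map datamodule from a Hydra/OmegaConf
config and trains it with PyTorch Lightning, using checkpointing, seeding,
and console-logging callbacks."""
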
import os.path as osp
import warnings

warnings.filterwarnings("ignore")

from pathlib import Path

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities import rank_zero_only

from dataset import UavMapDatasetModule
from logger import EXPERIMENTS_PATH, logger, pl_logger
from models.maplocnet import MapLocNet
from module import GenericModule
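
# Progress bar that hides the version number and the raw running loss
# from the displayed metrics.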
class CleanProgressBar(pl.callbacks.TQDMProgressBar):
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)  # don't show the version number
        items.pop("loss", None)
        return items
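
# Re-seeds all workers at the start of every train/val/test epoch; during
# regular training (i.e. not when overfitting a fixed set of batches) the
# seed is offset by the epoch index.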
class SeedingCallback(pl.callbacks.Callback):
    def on_epoch_start_(self, trainer, module):
        seed = module.cfg.experiment.seed
        is_overfit = module.cfg.training.trainer.get("overfit_batches", 0) > 0
        if trainer.training and not is_overfit:
            seed = seed + trainer.current_epoch
        # Temporarily disable the logging (does not seem to work?)
        pl_logger.disabled = True
        try:
            pl.seed_everything(seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_validation_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_test_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)
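
# Logs the start of every training epoch to the console.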
class ConsoleLogger(pl.callbacks.Callback):
    def on_train_epoch_start(self, trainer, module):
        logger.info(
            "New training epoch %d for experiment '%s'.",
            module.current_epoch,
            module.cfg.experiment.name,
        )

    # @rank_zero_only
    # def on_validation_epoch_end(self, trainer, module):
    #     results = {
    #         **dict(module.metrics_val.items()),
    #         **dict(module.losses_val.items()),
    #     }
    #     results = [f"{k} {v.compute():.3E}" for k, v in results.items()]
    #     logger.info(f'[Validation] {{{", ".join(results)}}}')
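
# Return the path of the "last" checkpoint written by ModelCheckpoint in the
# experiment directory, or None if training has not produced one yet.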
def find_last_checkpoint_path(experiment_dir):
    cls = pl.callbacks.ModelCheckpoint
    path = osp.join(experiment_dir, cls.CHECKPOINT_NAME_LAST + cls.FILE_EXTENSION)
    if osp.exists(path):
        return path
    else:
        return None
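
# Create the experiment directory and save the config; when resuming from a
# previous run, verify that the experiment/data/model/training config matches
# the one used before.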
def prepare_experiment_dir(experiment_dir, cfg, rank):
    config_path = osp.join(experiment_dir, "config.yaml")
    last_checkpoint_path = find_last_checkpoint_path(experiment_dir)
    if last_checkpoint_path is not None:
        if rank == 0:
            logger.info(
                "Resuming the training from checkpoint %s", last_checkpoint_path
            )
        if osp.exists(config_path):
            with open(config_path, "r") as fp:
                cfg_prev = OmegaConf.create(fp.read())
            compare_keys = ["experiment", "data", "model", "training"]
            if OmegaConf.masked_copy(cfg, compare_keys) != OmegaConf.masked_copy(
                cfg_prev, compare_keys
            ):
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{OmegaConf.masked_copy(cfg, compare_keys)} vs "
                    f"{OmegaConf.masked_copy(cfg_prev, compare_keys)}"
                )
    if rank == 0:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(config_path, "w") as fp:
            OmegaConf.save(cfg, fp)
    return last_checkpoint_path
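
# Build the model, datamodule, callbacks, and Trainer from the config and run
# the fit loop, optionally resuming from the last checkpoint.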
def train(cfg: DictConfig) -> None:
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank
    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    pl.seed_everything(cfg.experiment.seed, workers=True)

    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)

    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"

    # Create the EarlyStopping callback (currently not added to the callback list).
    early_stopping_callback = EarlyStopping(
        monitor=cfg.training.checkpointing.monitor, patience=5
    )

    strategy = None
    if cfg.experiment.gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        # Split the global batch size and worker count across GPUs.
        for split in ["train", "val"]:
            cfg.data[split].batch_size = (
                cfg.data[split].batch_size // cfg.experiment.gpus
            )
            cfg.data[split].num_workers = int(
                (cfg.data[split].num_workers + cfg.experiment.gpus - 1)
                / cfg.experiment.gpus
            )

    # data = data_modules[cfg.data.get("name", "mapillary")](cfg.data)
    datamodule = UavMapDatasetModule(cfg.data)

    tb_args = {"name": cfg.experiment.name, "version": ""}
    tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)

    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        # early_stopping_callback,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if cfg.experiment.gpus > 0:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        # strategy=ddp_find_unused_parameters_true,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        accelerator="gpu" if cfg.experiment.gpus > 0 else "cpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)
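
# Hydra entry point: dumps the config to maplocnet.yaml and starts training.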
# NOTE: config_path/config_name below are assumed (a local "conf" directory
# and a "maplocnet" config); adjust them to match the repository layout.
@hydra.main(config_path="conf", config_name="maplocnet")
def main(cfg: DictConfig) -> None:
    OmegaConf.save(config=cfg, f="maplocnet.yaml")
    train(cfg)


if __name__ == "__main__":
    main()