import os.path as osp
import warnings
warnings.filterwarnings('ignore')
from typing import Optional
from pathlib import Path
from models.maplocnet import MapLocNet
import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
from module import GenericModule
from logger import logger, pl_logger, EXPERIMENTS_PATH
from module import GenericModule
from dataset import UavMapDatasetModule
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# print(osp.join(osp.dirname(__file__), "conf"))


class CleanProgressBar(pl.callbacks.TQDMProgressBar):
    """TQDM progress bar that hides the ``v_num`` and ``loss`` entries."""

    def get_metrics(self, trainer, model):
        metrics = super().get_metrics(trainer, model)
        # Drop the experiment version number and the raw loss from the bar;
        # ``pop(..., None)`` keeps this safe when a key is absent.
        for hidden_key in ("v_num", "loss"):
            metrics.pop(hidden_key, None)
        return metrics


class SeedingCallback(pl.callbacks.Callback):
    """Re-seed all RNGs at the start of every train/validation/test epoch.

    The seed is ``module.cfg.experiment.seed``. While ``trainer.training`` is
    true, the current epoch index is added to it. This gives every training
    epoch a distinct but reproducible random stream. The offset is skipped when
    ``overfit_batches`` is enabled, so overfitting runs see the same data order
    each epoch.
    """

    def on_epoch_start_(self, trainer, module):
        base_seed = module.cfg.experiment.seed
        overfitting = module.cfg.training.trainer.get("overfit_batches", 0) > 0
        use_epoch_offset = trainer.training and not overfitting
        seed = base_seed + trainer.current_epoch if use_epoch_offset else base_seed

        # Mute the Lightning logger while re-seeding so the seed message is not
        # printed every epoch. NOTE(review): the original author observed this
        # may not actually suppress the message.
        pl_logger.disabled = True
        try:
            pl.seed_everything(seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *hook_args, **hook_kwargs):
        self.on_epoch_start_(*hook_args, **hook_kwargs)

    def on_validation_epoch_start(self, *hook_args, **hook_kwargs):
        self.on_epoch_start_(*hook_args, **hook_kwargs)

    def on_test_epoch_start(self, *hook_args, **hook_kwargs):
        self.on_epoch_start_(*hook_args, **hook_kwargs)


class ConsoleLogger(pl.callbacks.Callback):
    """Write a console message whenever a new training epoch begins."""

    @rank_zero_only
    def on_train_epoch_start(self, trainer, module):
        # ``rank_zero_only`` restricts this to the rank-0 process, so multi-GPU
        # runs emit the message once rather than once per process.
        epoch = module.current_epoch
        experiment_name = module.cfg.experiment.name
        logger.info(
            "New training epoch %d for experiment '%s'.", epoch, experiment_name
        )

    # @rank_zero_only
    # def on_validation_epoch_end(self, trainer, module):
    #     results = {
    #         **dict(module.metrics_val.items()),
    #         **dict(module.losses_val.items()),
    #     }
    #     results = [f"{k} {v.compute():.3E}" for k, v in results.items()]
    #     logger.info(f'[Validation] {{{", ".join(results)}}}')


def find_last_checkpoint_path(experiment_dir):
    """Return the path of the "last" checkpoint in *experiment_dir*, or ``None``.

    The file name comes from the class-level
    ``ModelCheckpoint.CHECKPOINT_NAME_LAST`` and
    ``ModelCheckpoint.FILE_EXTENSION``. A checkpointer whose instance attribute
    overrides the "last" name is therefore not picked up here.
    """
    checkpoint_cls = pl.callbacks.ModelCheckpoint
    last_name = checkpoint_cls.CHECKPOINT_NAME_LAST + checkpoint_cls.FILE_EXTENSION
    candidate = osp.join(experiment_dir, last_name)
    return candidate if osp.exists(candidate) else None


def prepare_experiment_dir(experiment_dir, cfg, rank):
    """Prepare *experiment_dir* for a new or resumed run.

    If a "last" checkpoint exists, the run is treated as a resume. If a
    ``config.yaml`` was saved earlier, its ``experiment``, ``data``, ``model``
    and ``training`` sections must equal those of *cfg*. On rank 0, the
    directory is then created (if needed) and *cfg* is written to
    ``config.yaml``.

    Args:
        experiment_dir: directory holding checkpoints and the saved config.
        cfg: the current OmegaConf config.
        rank: process rank; only rank 0 logs and writes files.

    Returns:
        The checkpoint path to resume from, or ``None`` for a fresh run.

    Raises:
        ValueError: when resuming with a config whose compared sections differ
            from the saved one.
    """
    config_path = osp.join(experiment_dir, "config.yaml")
    resume_path = find_last_checkpoint_path(experiment_dir)
    is_main_process = rank == 0

    if resume_path is not None:
        if is_main_process:
            logger.info("Resuming the training from checkpoint %s", resume_path)
        if osp.exists(config_path):
            with open(config_path, "r") as fp:
                previous_cfg = OmegaConf.create(fp.read())
            sections = ["experiment", "data", "model", "training"]
            current_view = OmegaConf.masked_copy(cfg, sections)
            previous_view = OmegaConf.masked_copy(previous_cfg, sections)
            if current_view != previous_view:
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{current_view} vs "
                    f"{previous_view}"
                )

    if is_main_process:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(config_path, "w") as fp:
            OmegaConf.save(cfg, fp)
    return resume_path


def _make_checkpoint_callbacks(cfg: DictConfig, experiment_dir):
    """Build the per-epoch and per-1000-step ``ModelCheckpoint`` callbacks.

    Both callbacks write into *experiment_dir*. Both forward
    ``cfg.training.checkpointing`` as extra keyword arguments.
    """
    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    # Give the step checkpointer its own "last" name, presumably so it does not
    # overwrite the epoch checkpointer's "last" file in the same directory.
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"
    return checkpointing_epoch, checkpointing_step


def _split_data_cfg_across_gpus(cfg: DictConfig) -> None:
    """Turn the train/val data sizes in *cfg* into per-process values for DDP.

    The configured values are treated as totals over all GPUs. Batch size is
    floor-divided by ``cfg.experiment.gpus``. Worker count is ceil-divided, so
    no process ends up with fewer than its share. ``cfg.data`` is mutated in
    place.

    Raises:
        ValueError: if a batch size is smaller than the number of GPUs. The
            per-process batch size would otherwise become 0.
    """
    num_gpus = cfg.experiment.gpus
    for split in ("train", "val"):
        split_cfg = cfg.data[split]
        if split_cfg.batch_size < num_gpus:
            raise ValueError(
                f"data.{split}.batch_size={split_cfg.batch_size} is smaller than "
                f"the number of GPUs ({num_gpus}) and cannot be split across them."
            )
        split_cfg.batch_size = split_cfg.batch_size // num_gpus
        # Integer ceil-division (same result as the previous float formula
        # for non-negative counts, without going through float).
        split_cfg.num_workers = (split_cfg.num_workers + num_gpus - 1) // num_gpus


def train(cfg: DictConfig) -> None:
    """Train the model described by *cfg* with PyTorch Lightning.

    The steps, in order:

    1. Resolve the config.
    2. Validate the GPU request and seed the RNGs.
    3. Build the model, or load it from ``training.finetune_from_checkpoint``.
    4. Prepare the experiment directory, resuming from its last checkpoint
       when one exists.
    5. For multi-GPU runs, set up DDP and split the data settings per process.
    6. Run ``Trainer.fit``.

    Args:
        cfg: Hydra/OmegaConf config. It is mutated in place:
            - interpolations are resolved;
            - ``experiment.gpus`` is set to 0 when it is unset or 0;
            - under DDP, data batch sizes and worker counts become per-process.

    Raises:
        ValueError: in any of these cases:
            - a GPU is requested but CUDA is unavailable;
            - the run resumes with a different config;
            - a batch size cannot be split across the GPUs.
    """
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank

    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    pl.seed_everything(cfg.experiment.seed, workers=True)

    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)
    checkpointing_epoch, checkpointing_step = _make_checkpoint_callbacks(
        cfg, experiment_dir
    )

    strategy = None
    if cfg.experiment.gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        _split_data_cfg_across_gpus(cfg)

    datamodule = UavMapDatasetModule(cfg.data)

    tb = pl.loggers.TensorBoardLogger(
        EXPERIMENTS_PATH, name=cfg.experiment.name, version=""
    )

    # Early stopping is currently disabled. To enable it, add
    # ``EarlyStopping(monitor=cfg.training.checkpointing.monitor, patience=5)``
    # to this list.
    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if cfg.experiment.gpus > 0:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        # Respect the CPU fallback above. A hard-coded "gpu" accelerator made
        # CPU-only runs impossible despite the "Will train on CPU" path.
        accelerator="gpu" if cfg.experiment.gpus > 0 else "cpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)


@hydra.main(
    config_path=osp.join(osp.dirname(__file__), "conf"),
    config_name="maplocnet.yaml",
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point: dump the composed config, then start training.

    The config is written to ``maplocnet.yaml`` relative to the current working
    directory. NOTE(review): depending on the Hydra version and settings, that
    directory may be the run's output directory rather than the launch
    directory.
    """
    OmegaConf.save(cfg, "maplocnet.yaml")
    train(cfg)


if __name__ == "__main__":
    main()