import json
import sys
from typing import Optional

# This import must be on top to set the environment variables before importing other modules
import env_consts
import time
import os

from lightning.pytorch import seed_everything
import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.profilers import AdvancedProfiler

from dockformerpp.config import model_config
from dockformerpp.data.data_modules import OpenFoldDataModule, DockFormerDataModule
from dockformerpp.model.model import AlphaFold
from dockformerpp.utils import residue_constants
from dockformerpp.utils.exponential_moving_average import ExponentialMovingAverage
from dockformerpp.utils.loss import AlphaFoldLoss, lddt_ca
from dockformerpp.utils.lr_schedulers import AlphaFoldLRScheduler
from dockformerpp.utils.script_utils import get_latest_checkpoint
from dockformerpp.utils.superimposition import superimpose
from dockformerpp.utils.tensor_utils import tensor_tree_map
from dockformerpp.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
    rmsd,
)


class ModelWrapper(pl.LightningModule):
    def __init__(self, config):
        """Build the network, its loss and the EMA shadow of its weights.

        Args:
            config: Model configuration. The whole object is handed to
                ``AlphaFold``; ``config.loss`` and ``config.ema.decay`` are
                read here, and the object is kept on ``self.config``.
        """
        super().__init__()
        self.config = config

        # Network and training objective.
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)

        # Exponential moving average of the model weights.
        self.ema = ExponentialMovingAverage(model=self.model, decay=config.ema.decay)

        # Training weights stashed while validation runs with the EMA weights.
        self.cached_weights = None
        # Step the LR scheduler resumes from; stays -1 until
        # resume_last_lr_step() is called.
        self.last_lr_step = -1

        # Per-metric lists of recent training values, flushed as a mean once
        # a list reaches log_agg_every_n_steps entries.
        self.aggregated_metrics = {}
        self.log_agg_every_n_steps = 50  # match Trainer(log_every_n_steps=50)

    def forward(self, batch):
        """Run the wrapped ``AlphaFold`` model on ``batch`` and return its outputs."""
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        """Log the loss terms and the structure metrics of one step.

        Every entry of ``loss_breakdown`` and every metric returned by
        ``_compute_validation_metrics`` is logged under ``train/...`` or
        ``val/...``. During training each value is additionally appended to a
        per-metric window in ``self.aggregated_metrics``; as soon as any window
        holds ``self.log_agg_every_n_steps`` values, the mean of every window
        is logged under its ``*_agg`` key and all windows are emptied.

        Args:
            loss_breakdown: Mapping from loss name to its value for this step.
            batch: Batch passed on to ``_compute_validation_metrics``.
            outputs: Model outputs for ``batch``.
            train: ``True`` for a training step, ``False`` for validation.
                Superimposition metrics are only requested when ``False``.
        """
        phase = "train" if train else "val"
        windows = self.aggregated_metrics

        for loss_name, indiv_loss in loss_breakdown.items():
            # Per-step value while training, epoch aggregate while validating.
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True, sync_dist=True
            )
            if not train:
                continue

            windows.setdefault(f"{phase}/{loss_name}_agg", []).append(float(indiv_loss))
            self.log(
                f"{phase}/{loss_name}_epoch",
                indiv_loss,
                on_step=False, on_epoch=True, logger=True, sync_dist=True
            )

        # Metrics are for monitoring only, so no gradients are tracked.
        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for metric_name, metric_value in other_metrics.items():
            mean_value = torch.mean(metric_value)
            if train:
                windows.setdefault(f"{phase}/{metric_name}_agg", []).append(float(mean_value))
            self.log(
                f"{phase}/{metric_name}",
                mean_value,
                on_step=False, on_epoch=True, logger=True, sync_dist=True
            )

        window_full = any(
            len(values) >= self.log_agg_every_n_steps for values in windows.values()
        )
        if train and window_full:
            for agg_key, values in windows.items():
                window_mean = sum(values) / len(values)
                print("logging agg", agg_key, len(values), window_mean, flush=True)
                self.log(agg_key, window_mean, on_step=True, on_epoch=False, logger=True, sync_dist=True)
                # Start a fresh window for this metric.
                self.aggregated_metrics[agg_key] = []

    def training_step(self, batch, batch_idx):
        """Run one training step and return the total loss.

        Args:
            batch: Feature dictionary; ``batch["aatype"]`` is used to find the
                device of the batch.
            batch_idx: Index of the batch (unused).

        Returns:
            The total loss computed by ``self.loss``.
        """
        # Keep the EMA copy on the same device as the incoming batch.
        batch_device = batch["aatype"].device
        if self.ema.device != batch_device:
            self.ema.to(batch_device)

        outputs = self(batch)

        # Remove the recycling dimension: keep the last entry of the trailing
        # axis of every tensor in the batch.
        last_cycle_batch = tensor_tree_map(lambda t: t[..., -1], batch)

        total_loss, loss_breakdown = self.loss(
            outputs, last_cycle_batch, _return_breakdown=True
        )
        self._log(loss_breakdown, last_cycle_batch, outputs)

        return total_loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Refresh the EMA weights from the current model weights.

        The hook arguments are accepted but not used.
        NOTE(review): the exact point at which Lightning fires this hook
        relative to the optimizer step is not visible here — confirm against
        the Lightning hook documentation.
        """
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch with the EMA weights and log the results.

        On the first call of a validation run the current model weights are
        cloned into ``self.cached_weights`` and the EMA weights are loaded into
        the model; ``on_validation_epoch_end`` restores them.

        Args:
            batch: Feature dictionary for this validation batch.
            batch_idx: Index of the batch (unused).
        """
        if self.cached_weights is None:
            # state_dict() hands out references to the live weights, so they
            # must be cloned before load_state_dict() overwrites them.
            def _snapshot(t):
                return t.detach().clone()

            self.cached_weights = tensor_tree_map(_snapshot, self.model.state_dict())
            ema_params = self.ema.state_dict()["params"]
            self.model.load_state_dict(ema_params)

        outputs = self(batch)

        # Keep the last entry of the trailing (recycling) axis of every tensor.
        last_cycle_batch = tensor_tree_map(lambda t: t[..., -1], batch)
        last_cycle_batch["use_clamped_fape"] = 0.

        # Only the per-term breakdown is needed; the total loss is discarded.
        _, loss_breakdown = self.loss(
            outputs, last_cycle_batch, _return_breakdown=True
        )
        self._log(loss_breakdown, last_cycle_batch, outputs, train=False)
        
    def on_validation_epoch_end(self):
        """Restore the training weights that ``validation_step`` stashed.

        ``validation_step`` loads the EMA weights into the model on its first
        call and keeps the original weights in ``self.cached_weights``. If this
        hook fires without any validation step having run, nothing was swapped
        and ``cached_weights`` is still ``None``; passing that to
        ``load_state_dict`` would raise, so the restore is skipped.
        """
        if self.cached_weights is None:
            return

        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
        batch,
        outputs,
        superimposition_metrics=False
    ):
        """Compute monitoring metrics comparing predictions with ground truth.

        Always computed: CA lDDT and CA dRMSD for the joined complex and for
        the ``r`` and ``l`` proteins, precision/recall of the thresholded
        inter-chain contact prediction, and weighted absolute affinity errors
        for the 2D head, the CLS head and their average.

        When ``superimposition_metrics`` is true, alignment RMSD and GDT
        scores after superimposition are added, together with the RMSD of the
        ``l`` protein after aligning on the ``r`` protein.

        Args:
            batch: Feature dictionary. Keys read: ``atom37_atom_exists_in_gt``,
                ``atom37_gt_positions``, ``protein_r_mask``, ``protein_l_mask``,
                ``structural_mask``, ``gt_inter_contacts``, ``inter_pair_mask``,
                ``affinity`` and ``affinity_loss_factor``.
            outputs: Model outputs. Keys read: ``final_atom_positions``,
                ``inter_contact_logits``, ``affinity_2d_logits`` and
                ``affinity_cls_logits``.
            superimposition_metrics: Whether to add the superimposition-based
                metrics.

        Returns:
            Dict mapping metric name to a tensor.
        """
        metrics = {}

        # Per-atom masks: the joined one comes from the batch, the per-protein
        # ones repeat each mask entry 37 times and take the joined mask's shape.
        atom_mask_joined = batch["atom37_atom_exists_in_gt"]
        atom_mask_r = torch.repeat_interleave(
            batch["protein_r_mask"], 37, dim=-1
        ).view(*atom_mask_joined.shape)
        atom_mask_l = torch.repeat_interleave(
            batch["protein_l_mask"], 37, dim=-1
        ).view(*atom_mask_joined.shape)

        pred_positions = outputs["final_atom_positions"]
        gt_positions = batch["atom37_gt_positions"]

        atom_masks = (
            ("joined", atom_mask_joined),
            ("r", atom_mask_r),
            ("l", atom_mask_l),
        )
        for suffix, atom_mask in atom_masks:
            metrics[f"lddt_ca_{suffix}"] = lddt_ca(
                pred_positions,
                gt_positions,
                atom_mask,
                eps=self.config.globals.eps,
                per_residue=False,
            )

        # CA-only coordinates for the distance- and alignment-based metrics.
        ca_idx = residue_constants.atom_order["CA"]
        gt_coords_ca = gt_positions[..., ca_idx, :]
        pred_coords_ca = pred_positions[..., ca_idx, :]

        residue_mask_keys = (
            ("joined", "structural_mask"),  # still required here to compute n
            ("r", "protein_r_mask"),
            ("l", "protein_l_mask"),
        )
        for suffix, mask_key in residue_mask_keys:
            metrics[f"drmsd_ca_{suffix}"] = drmsd(
                pred_coords_ca,
                gt_coords_ca,
                mask=batch[mask_key],
            )

        # --- inter contacts
        gt_contacts = batch["gt_inter_contacts"]
        contact_probs = torch.sigmoid(
            outputs["inter_contact_logits"].clone().detach()
        ).squeeze(-1)
        # Threshold at 0.5, then zero the predictions outside the pair mask.
        pred_contacts = (contact_probs > 0.5).float() * batch["inter_pair_mask"]

        gt_positive = gt_contacts == 1
        pred_positive = pred_contacts == 1
        tp = torch.sum(gt_positive & pred_positive)
        fp = torch.sum((gt_contacts == 0) & pred_positive)
        fn = torch.sum(gt_positive & (pred_contacts == 0))

        # With an empty denominator tp is 0 as well, so the metric becomes 0.
        if (tp + fn) > 0:
            recall = tp / (tp + fn)
        else:
            recall = tp.float()
        if (tp + fp) > 0:
            precision = tp / (tp + fp)
        else:
            precision = tp.float()

        metrics["inter_contacts_recall"] = recall.clone().detach()
        metrics["inter_contacts_precision"] = precision.clone().detach()

        # --- Affinity
        gt_affinity = batch["affinity"].squeeze(-1)
        affinity_linspace = torch.linspace(0, 15, 32, device=batch["affinity"].device)
        aff_loss_factor = batch["affinity_loss_factor"].squeeze()

        def expected_affinity(logits):
            # Softmax-weighted sum of the bin values in affinity_linspace.
            probs = torch.softmax(logits.clone().detach(), -1)
            return torch.sum(probs * affinity_linspace, dim=-1)

        def weighted_abs_error(pred_affinity):
            # Absolute error, weighted and normalised by aff_loss_factor.
            abs_error = torch.abs(gt_affinity - pred_affinity)
            return (abs_error * aff_loss_factor).sum() / aff_loss_factor.sum()

        pred_affinity_2d = expected_affinity(outputs["affinity_2d_logits"])
        pred_affinity_cls = expected_affinity(outputs["affinity_cls_logits"])

        metrics["affinity_dist_2d"] = weighted_abs_error(pred_affinity_2d)
        metrics["affinity_dist_cls"] = weighted_abs_error(pred_affinity_cls)
        metrics["affinity_dist_avg"] = weighted_abs_error(
            (pred_affinity_cls + pred_affinity_2d) / 2
        )

        if not superimposition_metrics:
            return metrics

        structural_mask = batch["structural_mask"]
        superimposed_pred, joined_rmsd, _, _ = superimpose(
            gt_coords_ca, pred_coords_ca, structural_mask,
        )
        metrics["alignment_rmsd_joined"] = joined_rmsd
        metrics["gdt_ts_joined"] = gdt_ts(superimposed_pred, gt_coords_ca, structural_mask)
        metrics["gdt_ha_joined"] = gdt_ha(superimposed_pred, gt_coords_ca, structural_mask)

        _, l_rmsd, _, _ = superimpose(
            gt_coords_ca, pred_coords_ca, batch["protein_l_mask"],
        )
        metrics["alignment_rmsd_l"] = l_rmsd

        _, r_rmsd, rots, transs = superimpose(
            gt_coords_ca, pred_coords_ca, batch["protein_r_mask"],
        )
        metrics["alignment_rmsd_r"] = r_rmsd

        # Apply the transform found for the r protein to the whole prediction,
        # then measure the RMSD over the l protein only.
        pred_in_r_frame = pred_coords_ca @ rots + transs[:, None, :]
        l_by_r_alignment_rmsds = rmsd(
            gt_coords_ca, pred_in_r_frame, mask=batch["protein_l_mask"]
        )
        metrics["alignment_rmsd_l_by_r"] = l_by_r_alignment_rmsds.mean()

        metrics["alignment_rmsd_l_by_r_under_2"] = torch.mean((l_by_r_alignment_rmsds < 2).float())
        metrics["alignment_rmsd_l_by_r_under_5"] = torch.mean((l_by_r_alignment_rmsds < 5).float())

        print("ligand rmsd:", l_by_r_alignment_rmsds)

        return metrics

    def configure_optimizers(self,
        learning_rate: Optional[float] = None,
        eps: float = 1e-5,
    ) -> dict:
        """Create the Adam optimizer and its step-wise LR scheduler.

        Args:
            learning_rate: Learning rate given to Adam. Defaults to
                ``self.config.globals.max_lr`` when ``None``.
            eps: Adam epsilon.

        Returns:
            A dict with the ``"optimizer"`` and an ``"lr_scheduler"`` entry
            holding the ``AlphaFoldLRScheduler``, stepped at ``"step"``
            interval. (The previous ``torch.optim.Adam`` return annotation did
            not match this.)
        """
        if learning_rate is None:
            learning_rate = self.config.globals.max_lr

        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        # When resuming mid-schedule the scheduler is built with a
        # last_epoch other than -1, so make sure every param group carries
        # 'initial_lr'.
        # NOTE(review): presumably AlphaFoldLRScheduler follows the torch LR
        # scheduler convention that requires this key on resume — confirm.
        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        # Decay schedule, in optimizer steps.
        start_decay_after_n_steps = 10000
        decay_every_n_steps = 10000

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step,
            max_lr=self.config.globals.max_lr,
            start_decay_after_n_steps=start_decay_after_n_steps,
            decay_every_n_steps=decay_every_n_steps,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        """Restore the EMA state stored under ``checkpoint["ema"]``."""
        self.ema.load_state_dict(checkpoint["ema"])

    def on_save_checkpoint(self, checkpoint):
        """Store the EMA state under ``checkpoint["ema"]`` (read back by ``on_load_checkpoint``)."""
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        """Record the step the LR scheduler should resume from.

        The value is passed as ``last_epoch`` to the scheduler built in
        ``configure_optimizers``, so it must be set before that method runs.
        """
        self.last_lr_step = lr_step


def override_config(base_config, overriding_config):
    """Recursively apply ``overriding_config`` on top of ``base_config``, in place.

    A dict-valued override is merged into the existing entry when that entry
    is itself a container (it offers ``items`` and item assignment). In every
    other case the override replaces the entry, including when the key is
    missing from ``base_config`` or its current value is a scalar or ``None``.
    Previously those two cases raised ``KeyError``/``TypeError`` for dict
    overrides even though scalar overrides for new keys were accepted.

    Args:
        base_config: Mutable mapping-like config supporting ``in``, item
            access and item assignment.
        overriding_config: Mapping of overrides, possibly nested.

    Returns:
        ``base_config`` itself, after modification.
    """
    for key, value in overriding_config.items():
        current = base_config[key] if key in base_config else None
        can_merge = (
            isinstance(value, dict)
            and hasattr(current, "items")
            and hasattr(current, "__setitem__")
        )
        if can_merge:
            base_config[key] = override_config(current, value)
        else:
            base_config[key] = value
    return base_config


def train(override_config_path: str):
    """Train the model as described by a JSON run configuration.

    Keys read from the run configuration:
        * ``train_output_dir`` (required): output root; checkpoints go to its
          ``checkpoint`` subdirectory and the W&B run id to
          ``wandb_run_id.txt``.
        * ``stage`` (default ``"initial_training"``): passed to ``model_config``.
        * ``override_conf`` (default ``{}``): applied with ``override_config``.
        * ``accumulate_grad_batches`` (default ``1``).
        * either ``train_input_dir`` + ``val_input_dir`` (optionally
          ``train_epoch_len``, default ``1000``) for ``OpenFoldDataModule``,
          or ``train_input_file`` + ``val_input_file`` for
          ``DockFormerDataModule``.
        * ``ckpt_path`` (default: latest checkpoint in the checkpoint dir).
        * ``multi_node`` (optional): dict with ``num_nodes`` and ``devices``;
          switches the strategy to ``"ddp"``.
        * ``check_val_every_n_epoch`` (default ``10``).

    Args:
        override_config_path: Path to the JSON run configuration.
    """
    # Use a context manager so the config file handle is closed right away.
    with open(override_config_path, "r") as config_file:
        run_config = json.load(config_file)
    seed = 42
    seed_everything(seed, workers=True)
    output_dir = run_config["train_output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    print("Starting train", time.time())
    config = model_config(
        run_config.get("stage", "initial_training"),
        train=True,
        low_prec=True
    )
    config = override_config(config, run_config.get("override_conf", {}))
    accumulate_grad_batches = run_config.get("accumulate_grad_batches", 1)
    print("config loaded", time.time())

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    model_module = ModelWrapper(config)
    print("model loaded", time.time())

    # for debugging memory:
    # torch.cuda.memory._record_memory_history()

    if "train_input_dir" in run_config:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=seed,
            train_data_dir=run_config["train_input_dir"],
            val_data_dir=run_config["val_input_dir"],
            train_epoch_len=run_config.get("train_epoch_len", 1000),
        )
    else:
        data_module = DockFormerDataModule(
            config=config.data,
            batch_seed=seed,
            train_data_file=run_config["train_input_file"],
            val_data_file=run_config["val_input_file"],
        )
    print("data module loaded", time.time())

    checkpoint_dir = os.path.join(output_dir, "checkpoint")
    ckpt_path = run_config.get("ckpt_path", get_latest_checkpoint(checkpoint_dir))

    if ckpt_path:
        print(f"Resuming from checkpoint: {ckpt_path}")
        # Only the step counter is needed here, so keep the checkpoint tensors
        # on the CPU instead of materialising them on the device they were
        # saved from, and drop the checkpoint as soon as the value is read.
        sd = torch.load(ckpt_path, map_location="cpu")
        last_global_step = int(sd['global_step'])
        del sd
        model_module.resume_last_lr_step(last_global_step)

    # Do we need this?
    data_module.prepare_data()
    data_module.setup("fit")

    callbacks = []

    # Rolling checkpoint, refreshed every 250 training steps.
    mc = ModelCheckpoint(
        dirpath=checkpoint_dir,
        every_n_train_steps=250,
        auto_insert_metric_name=False,
        save_top_k=1,
        save_on_train_epoch_end=True,  # before validation
    )

    # Best checkpoint according to the lowest val/alignment_rmsd_l_by_r.
    mc2 = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="step{step}_rmsd{val/alignment_rmsd_l_by_r:.2f}",
        monitor="val/alignment_rmsd_l_by_r",
        mode="min",
        save_top_k=1,
        every_n_epochs=1,
        auto_insert_metric_name=False
    )
    callbacks.append(mc)
    callbacks.append(mc2)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks.append(lr_monitor)

    loggers = []

    wandb_project_name = "DockFormerPP"
    wandb_run_id_path = os.path.join(output_dir, "wandb_run_id.txt")

    # Initialize WandbLogger and save run_id.
    # SLURM_LOCALID is the rank within a node and SLURM_PROCID the rank across
    # the whole job, so they back the local and the global rank respectively.
    local_rank = int(os.getenv('LOCAL_RANK', os.getenv("SLURM_LOCALID", '0')))
    global_rank = int(os.getenv('GLOBAL_RANK', os.getenv("SLURM_PROCID", '0')))
    print("ranks", os.getenv('LOCAL_RANK', 'd0'), os.getenv('local_rank', 'd0'), os.getenv('GLOBAL_RANK', 'd0'),
          os.getenv('global_rank', 'd0'), os.getenv("SLURM_PROCID", 'd0'), os.getenv('SLURM_LOCALID', 'd0'), flush=True)
    if local_rank == 0 and global_rank == 0 and not os.path.exists(wandb_run_id_path):
        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir)
        # Write to a temporary file and rename it into place: the other ranks
        # poll for the final path and must never see it empty or half-written.
        tmp_run_id_path = wandb_run_id_path + ".tmp"
        with open(tmp_run_id_path, 'w') as f:
            f.write(wandb_logger.experiment.id)
        os.replace(tmp_run_id_path, wandb_run_id_path)
        wandb_logger.experiment.config.update(run_config, allow_val_change=True)
    else:
        # Necessary for multi-node training https://github.com/rstrudel/segmenter/issues/22
        while not os.path.exists(wandb_run_id_path):
            print(f"Waiting for run_id file to be created ({local_rank})", flush=True)
            time.sleep(1)
        with open(wandb_run_id_path, 'r') as f:
            run_id = f.read().strip()
        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir, resume='must', id=run_id)
    loggers.append(wandb_logger)

    strategy_params = {"strategy": "auto"}
    if run_config.get("multi_node", False):
        strategy_params["strategy"] = "ddp"
        # strategy_params["strategy"] = "ddp_find_unused_parameters_true" # this causes issues with checkpointing...
        strategy_params["num_nodes"] = run_config["multi_node"]["num_nodes"]
        strategy_params["devices"] = run_config["multi_node"]["devices"]

    trainer = pl.Trainer(
        accelerator=device_name,
        default_root_dir=output_dir,
        **strategy_params,
        reload_dataloaders_every_n_epochs=1,
        accumulate_grad_batches=accumulate_grad_batches,
        check_val_every_n_epoch=run_config.get("check_val_every_n_epoch", 10),
        callbacks=callbacks,
        logger=loggers,
        # profiler=AdvancedProfiler(),
    )

    print("Starting fit", time.time())
    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )

    # profiler_results = trainer.profiler.summary()
    # print(profiler_results)

    # torch.cuda.memory._dump_snapshot("my_train_snapshot.pickle")
    # view on https://pytorch.org/memory_viz


if __name__ == "__main__":
    # Use the run config given on the command line, else the run_config.json
    # that sits next to this script.
    cli_args = sys.argv[1:]
    if cli_args:
        train(cli_args[0])
    else:
        train(os.path.join(os.path.dirname(__file__), "run_config.json"))