import json
import logging
import os
import re
import time
from typing import List, Tuple

import numpy
import torch
from rdkit import Chem

from dockformerpp.model.model import AlphaFold
from dockformerpp.utils import residue_constants, protein
from dockformerpp.utils.consts import POSSIBLE_ATOM_TYPES, POSSIBLE_BOND_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES

# Module-level logging: give the root logger a default handler (a no-op when
# the root logger is already configured) and let this module's logger emit
# INFO and above.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


def count_models_to_evaluate(model_checkpoint_path):
    """Return how many comma-separated checkpoint paths were supplied.

    A falsy argument (``None`` or ``""``) counts as zero models.
    """
    if not model_checkpoint_path:
        return 0
    return len(model_checkpoint_path.split(","))


def get_model_basename(model_path):
    """Return the file name of *model_path* without its final extension.

    The path is normalized first, so a trailing separator does not yield an
    empty name.
    """
    normalized = os.path.normpath(model_path)
    stem, _extension = os.path.splitext(os.path.basename(normalized))
    return stem


def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if needed) and return the directory that receives predictions.

    Args:
        output_dir: Root output directory.
        model_name: Sub-directory name, used only when several models are evaluated.
        multiple_model_mode: When true, predictions go under
            ``<output_dir>/predictions/<model_name>``; otherwise under
            ``<output_dir>/predictions``.

    Returns:
        The prediction directory path. It exists when this function returns.
    """
    path_parts = [output_dir, "predictions"]
    if multiple_model_mode:
        path_parts.append(model_name)
    prediction_dir = os.path.join(*path_parts)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir


def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the most recent ``*.ckpt`` entry in *checkpoint_dir*.

    Entries are ranked by ``os.path.getctime``. On Windows that is the creation
    time; on Unix it is the last inode/metadata change time, so renaming or
    chmod-ing a checkpoint can make it count as the newest.

    Args:
        checkpoint_dir: Directory to scan. ``None`` or a path that is not an
            existing directory is treated as "no checkpoint".

    Returns:
        The joined path of the newest ``.ckpt`` entry, or ``None`` when there
        is none.
    """
    # isdir (rather than exists) avoids NotADirectoryError from listdir when
    # the path names a regular file.
    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
        return None
    candidates = [
        os.path.join(checkpoint_dir, name)
        for name in os.listdir(checkpoint_dir)
        if name.endswith(".ckpt")
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getctime)


def load_models_from_command_line(config, model_device, model_checkpoint_path, output_dir):
    """Lazily build, load and yield one model per checkpoint path.

    This is a generator: nothing (including argument validation) runs until
    the caller starts iterating.

    Args:
        config: Passed unchanged to ``AlphaFold(config)``.
        model_device: Device each loaded model is moved to via ``.to()``.
        model_checkpoint_path: Comma-separated checkpoint file paths.
        output_dir: Root output directory; see ``make_output_directory``.

    Yields:
        ``(model, output_directory)`` for each checkpoint, with the model in
        eval mode on *model_device*.

    Raises:
        ValueError: if *model_checkpoint_path* is empty or ``None`` (raised on
            the first ``next()``).
        AssertionError: if a listed checkpoint file does not exist.
    """
    if not model_checkpoint_path:
        raise ValueError("model_checkpoint_path must be specified.")

    multiple_model_mode = count_models_to_evaluate(model_checkpoint_path) > 1
    if multiple_model_mode:
        logger.info("evaluating multiple models")

    for path in model_checkpoint_path.split(","):
        # Check the file before paying for model construction.
        assert os.path.isfile(path), f"Model checkpoint not found at {path}"
        checkpoint_basename = get_model_basename(path)

        # Load onto CPU: checkpoints saved from a GPU would otherwise be
        # materialized on that GPU (or fail on CPU-only hosts). The weights
        # are copied into the model's parameters and moved to model_device below.
        state_dict = torch.load(path, map_location="cpu")
        if "ema" in state_dict:
            # The public weights have had this done to them already
            state_dict = state_dict["ema"]["params"]

        model = AlphaFold(config).eval()
        model.load_state_dict(state_dict)
        # Drop the extra copy of the weights so it is not kept alive while the
        # caller uses the model between generator steps.
        del state_dict

        model = model.to(model_device)
        logger.info(f"Loaded Model parameters at {path}...")

        output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
        yield model, output_directory


def parse_fasta(data):
    """Parse FASTA-formatted text into parallel lists of tags and sequences.

    Each record's tag is the leading run of word characters of its header
    line (text after the first non-word character, e.g. a space or ``|``, is
    dropped). Multi-line sequences are joined into one string.

    NOTE(review): a record whose header has no sequence line shifts the
    tag/sequence pairing of every later record, because the output lists are
    built by alternating slicing.

    Args:
        data: The FASTA file contents as a string.

    Returns:
        ``(tags, seqs)``, two lists of equal length in file order.
    """
    # Normalize CRLF / lone CR line endings so '\r' never leaks into sequences.
    data = data.replace('\r\n', '\n').replace('\r', '\n')
    # Remove any '>' that ends a line (e.g. a bare '>' header with no name).
    data = re.sub(r'>$', '', data, flags=re.M)
    lines = [
        line.replace('\n', '')
        for record in data.split('>')
        for line in record.strip().split('\n', 1)
    ][1:]  # drop the (normally empty) text preceding the first '>'
    tags, seqs = lines[::2], lines[1::2]

    tags = [re.split(r'\W| \|', t)[0] for t in tags]

    return tags, seqs


def update_timings(timing_dict, output_file=None):
    """Merge a dictionary of run-step times into a JSON file and return its path.

    Existing contents are read, shallow-merged with ``dict.update`` (keys in
    *timing_dict* replace existing top-level keys), and written back.

    Args:
        timing_dict: Mapping of timing entries to store.
        output_file: Target JSON path. Defaults to ``timings.json`` in the
            current working directory, resolved at call time.

    Returns:
        The path that was written.
    """
    if output_file is None:
        # Resolved per call: a default computed in the signature would freeze
        # the working directory at import time.
        output_file = os.path.join(os.getcwd(), "timings.json")

    timings = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                loaded = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
            else:
                # Valid JSON that is not an object (e.g. a list) cannot be
                # merged into; treat it like unreadable content.
                if isinstance(loaded, dict):
                    timings = loaded
                else:
                    logger.info(f"Overwriting non-standard JSON in {output_file}.")
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file


def run_model(model, batch, tag, output_dir):
    """Run one forward pass with gradients disabled and record how long it took.

    The wall-clock duration of ``model(batch)`` is logged and merged into
    ``<output_dir>/timings.json`` under ``{tag: {"inference": seconds}}``.

    Returns:
        Whatever ``model(batch)`` returned.
    """
    logger.info(f"Running inference for {tag}...")
    with torch.no_grad():
        started_at = time.perf_counter()
        prediction = model(batch)
        elapsed = time.perf_counter() - started_at
    logger.info(f"Inference time: {elapsed}")
    timings_file = os.path.join(output_dir, "timings.json")
    update_timings({tag: {"inference": elapsed}}, timings_file)
    return prediction


def get_molecule_from_output(atoms_atype: List[int], atom_chiralities: List[int], atom_charges: List[int],
                             bonds: List[Tuple[int, int, int]], atom_positions: List[Tuple[float, float, float]]):
    """Build an editable RDKit molecule with one conformer from per-atom predictions.

    Args:
        atoms_atype: Per-atom indices into ``POSSIBLE_ATOM_TYPES``.
        atom_chiralities: Per-atom indices into ``POSSIBLE_CHIRALITIES``.
        atom_charges: Per-atom indices into ``POSSIBLE_CHARGES``.
        bonds: ``(atom1, atom2, bond_type_idx)`` triples; ``bond_type_idx``
            indexes ``POSSIBLE_BOND_TYPES``.
        atom_positions: One ``(x, y, z)`` per atom, as a sequence of triples or
            an array-like of shape ``(num_atoms, 3)``.

    Returns:
        A ``Chem.RWMol`` containing the atoms, bonds and one conformer.

    Raises:
        AssertionError: if the per-atom inputs differ in length.
    """
    assert len(atoms_atype) == len(atom_chiralities) == len(atom_charges) == len(atom_positions), \
        "per-atom inputs must all have the same length"

    mol = Chem.RWMol()
    for atype_idx, chirality_idx, charge_idx in zip(atoms_atype, atom_chiralities, atom_charges):
        new_atom = Chem.Atom(POSSIBLE_ATOM_TYPES[atype_idx])
        new_atom.SetChiralTag(POSSIBLE_CHIRALITIES[chirality_idx])
        new_atom.SetFormalCharge(POSSIBLE_CHARGES[charge_idx])
        mol.AddAtom(new_atom)

    for atom1, atom2, bond_type_idx in bonds:
        mol.AddBond(int(atom1), int(atom2), POSSIBLE_BOND_TYPES[bond_type_idx])

    # numpy.asarray accepts both ndarrays and plain sequences of triples (as the
    # type hint advertises); the previous `.astype` call only worked on ndarrays.
    positions = numpy.asarray(atom_positions, dtype=float)
    conf = Chem.Conformer(len(atoms_atype))
    for i, pos in enumerate(positions):
        conf.SetAtomPosition(i, pos)
    mol.AddConformer(conf)
    return mol


def save_output_structure(aatype, residue_index, chain_index, plddt, final_atom_protein_positions, final_atom_mask,
                          output_path):
    """Write the predicted protein structure to *output_path* as PDB text.

    Per-residue pLDDT values are repeated across the
    ``residue_constants.atom_type_num`` atom slots and stored as B-factors.

    Args:
        aatype, residue_index, chain_index, final_atom_protein_positions,
        final_atom_mask: Passed to ``protein.from_prediction`` with
            ``remove_leading_feature_dimension=False``.
        plddt: Per-residue confidence; must support ``plddt[..., None]``
            (assumed to be a numpy array — TODO confirm with callers).
        output_path: Destination file; overwritten if it exists.
    """
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    unrelaxed_protein = protein.from_prediction(
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        atom_mask=final_atom_mask,
        atom_positions=final_atom_protein_positions,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=False,
    )

    with open(output_path, 'w') as fp:
        fp.write(protein.to_pdb(unrelaxed_protein))

    # Report through the module logger, like the rest of this file, instead of
    # a bare print that bypasses logging configuration.
    logger.info(f"Output written to {output_path}")