import time

import gradio as gr

from gradio_molecule3d import Molecule3D

import numpy as np
from scipy.optimize import differential_evolution, NonlinearConstraint
from biotite.structure.io.pdb import PDBFile
from rdkit import Chem
from rdkit.Chem import AllChem
from biotite.structure import AtomArrayStack


def generate_input_conformer(
    ligand_smiles: str,
    addHs: bool = False,
    minimize_maxIters: int = -1,
) -> Chem.Mol:
    """Build an RDKit molecule with an embedded 3D conformer from a SMILES string.

    Explicit hydrogens are always added before embedding. Embedding is first
    attempted with ``useBasicKnowledge=True``; if that fails, it is retried
    with ``useBasicKnowledge=False``. When a conformer was obtained by either
    attempt and ``minimize_maxIters > 0``, the conformer is MMFF-minimized.

    Args:
        ligand_smiles: SMILES string describing the ligand.
        addHs: Currently ignored; hydrogens are always added because they are
            needed to generate sensible conformers. Kept so existing callers
            keep working.
        minimize_maxIters: Maximum number of MMFF iterations per optimization
            pass. Values ``<= 0`` skip minimization.

    Returns:
        The hydrogen-added molecule. If both embedding attempts fail, the
        molecule is returned without a conformer.

    Raises:
        ValueError: If ``ligand_smiles`` cannot be parsed by RDKit.
    """
    _mol = Chem.MolFromSmiles(ligand_smiles)
    if _mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a readable message instead of deep inside AddHs.
        raise ValueError(f"Could not parse ligand SMILES: {ligand_smiles!r}")
    # need to add Hs to generate sensible conformers
    _mol = Chem.AddHs(_mol)

    # Settings shared by both embedding attempts (fixed seed for repeatability).
    embed_kwargs = dict(useRandomCoords=True, maxAttempts=100, randomSeed=42)

    confid = AllChem.EmbedMolecule(_mol, useBasicKnowledge=True, **embed_kwargs)
    if confid == -1:
        # EmbedMolecule failed - retry with the less optimal approach.
        confid = AllChem.EmbedMolecule(
            _mol, useBasicKnowledge=False, **embed_kwargs
        )

    if confid != -1 and minimize_maxIters > 0:
        # Minimize whichever embedding succeeded.
        status = AllChem.MMFFOptimizeMolecule(_mol, maxIters=minimize_maxIters)
        # 0 if the optimization converged,
        # -1 if the forcefield could not be set up,
        # 1 if more iterations are required.
        if status == 1:
            # extend optimization to double the steps (extends by the same amount)
            AllChem.MMFFOptimizeMolecule(_mol, maxIters=minimize_maxIters)
    return _mol


def optimize_coordinate(points, bound_buffer=15, dmin=6.02, dmax=8.0, seed=None):
    """Find a coordinate close to all ``points`` while keeping a minimum gap.

    Differential evolution minimizes the sum of Euclidean distances from the
    candidate coordinate to every point, subject to the distance to the
    nearest point lying in ``[dmin, dmax]``. The search box is a cube of
    half-width ``bound_buffer`` centred on the centroid of ``points``.

    Args:
        points: Array-like of shape (N, 3) with the reference coordinates.
        bound_buffer: Half-width of the cubic search box around the centroid.
        dmin: Lower bound for the distance to the nearest point.
        dmax: Upper bound for the distance to the nearest point. Defaults to
            the previously hard-coded value of 8.
        seed: Optional seed forwarded to ``differential_evolution`` to make
            the search reproducible. ``None`` keeps the unseeded behavior.

    Returns:
        Tuple ``(x, fun)``: the best coordinate found and the corresponding
        sum of distances. The optimizer's success flag is not checked, so the
        coordinate may violate the constraint if no feasible point was found.

    Raises:
        ValueError: If ``dmin`` is greater than ``dmax``.
    """
    if dmin > dmax:
        raise ValueError(
            f"dmin ({dmin}) must not exceed dmax ({dmax}); the distance "
            "constraint would be infeasible"
        )

    centroid = np.average(points, axis=0)
    half_width = [bound_buffer] * 3
    bounds = list(zip(centroid - half_width, centroid + half_width))

    def distances(x):
        # Distance from candidate x to every reference point.
        return np.linalg.norm(points - x, axis=1)

    # Constraint: the nearest point must be between dmin and dmax away.
    con = NonlinearConstraint(lambda x: np.min(distances(x)), dmin, dmax)

    # Objective: minimize the summed distance to all points.
    def objective(x):
        return np.sum(distances(x))

    result = differential_evolution(objective, bounds, constraints=con, seed=seed)
    return result.x, result.fun


def optimize_decoy_coordinate(points, bound_buffer=15, dmin=6.02, decoy_min=4.0, decoy_max=4.98):
    """Search for a decoy coordinate that sits close to exactly one point.

    Differential evolution maximizes the sum of distances from the candidate
    to all ``points`` (by minimizing its negative) under two constraints:
    exactly one point is closer than ``dmin``, and the distance to the nearest
    point lies in ``[decoy_min, decoy_max]``. The search box is a cube of
    half-width ``bound_buffer`` centred on the centroid of ``points``.

    Args:
        points: Array-like of shape (N, 3) with the reference coordinates.
        bound_buffer: Half-width of the cubic search box around the centroid.
        dmin: Distance threshold below which a point counts as "close".
        decoy_min: Lower bound for the distance to the nearest point.
        decoy_max: Upper bound for the distance to the nearest point.

    Returns:
        Tuple ``(x, fun)``: the best coordinate found and the objective value
        there (the negated sum of distances).
    """
    center = np.average(points, axis=0)
    margin = [bound_buffer] * 3
    search_box = list(zip(center - margin, center + margin))

    def atom_distances(x):
        # Distance from candidate x to every reference point.
        return np.linalg.norm(points - x, axis=1)

    def close_atom_count(x):
        # Number of points strictly closer than dmin.
        return np.sum(atom_distances(x) < dmin)

    def nearest_distance(x):
        return np.min(atom_distances(x))

    def objective(x):
        # Negated so that minimizing pushes the decoy away from the points.
        return - np.sum(atom_distances(x))

    # All but one point must respect dmin ...
    exactly_one_close = NonlinearConstraint(close_atom_count, 1, 1)
    # ... and the one close point must fall inside the decoy distance window.
    nearest_in_window = NonlinearConstraint(nearest_distance, decoy_min, decoy_max)

    outcome = differential_evolution(
        objective,
        search_box,
        constraints=(exactly_one_close, nearest_in_window),
    )
    return outcome.x, outcome.fun


def add_decoy_atom(structure, decoy_pos):
    """Return ``structure`` with one extra decoy carbon atom appended.

    The decoy is a single atom (chain ``"q"``, element and atom name ``"C"``,
    residue name ``"GLY"``) placed at ``decoy_pos`` in every model.

    Args:
        structure: A biotite ``AtomArrayStack`` to extend.
        decoy_pos: Coordinate of the decoy atom; broadcast against an array of
            shape (depth, 1, 3).

    Returns:
        A new stack consisting of ``structure`` followed by the decoy atom.
    """
    # Give the decoy as many models as the input structure so the two stacks
    # can be concatenated; a fixed depth of 1 breaks on multi-model inputs.
    decoy = AtomArrayStack(depth=structure.stack_depth(), length=1)
    decoy.coord = np.ones_like(decoy.coord) * decoy_pos
    decoy.chain_id = ["q"]
    decoy.element = ["C"]
    decoy.atom_name = ["C"]
    decoy.res_name = ["GLY"]
    return structure + decoy


def set_protein_to_new_coord_plus_decoy_atom(input_pdb_file, new_coord, decoy_coord, output_file):
    """Write a copy of a PDB with all atoms moved to one point, plus a decoy atom.

    Args:
        input_pdb_file: PDB input passed straight to ``PDBFile.read``.
        new_coord: Coordinate assigned to every atom of the loaded structure.
        decoy_coord: Coordinate of the decoy atom appended via ``add_decoy_atom``.
        output_file: Destination passed to ``PDBFile.write``.
    """
    protein = PDBFile.read(input_pdb_file).get_structure()

    # Overwrite every atom position with the same target coordinate.
    target = np.array(new_coord)
    protein.coord = np.ones_like(protein.coord) * target

    # Append the decoy atom, then serialize the combined structure.
    with_decoy = add_decoy_atom(protein, decoy_coord)
    pdb_out = PDBFile()
    pdb_out.set_structure(with_decoy)
    pdb_out.write(output_file)

    
def predict(input_sequence, input_ligand, input_msa, input_protein):
    """Run inference for the Gradio app and return output files plus metrics.

    A ligand conformer is generated from the SMILES string and written to an
    SDF file. One coordinate is then optimized relative to the ligand's heavy
    atoms and assigned to every protein atom, a decoy atom coordinate is
    optimized as well, and the resulting structure is written to a PDB file.

    Args:
        input_sequence: Protein sequence from the UI; not used here.
        input_ligand: Ligand SMILES string.
        input_msa: MSA file from the UI; not used here.
        input_protein: Input protein PDB, forwarded to
            ``set_protein_to_new_coord_plus_decoy_atom``.

    Returns:
        Tuple ``(files, metrics, run_time)`` where ``files`` is
        ``[pdb_path, sdf_path]``, ``metrics`` is a dict of floats, and
        ``run_time`` is the elapsed wall-clock time in seconds.

    Raises:
        ValueError: If no 3D conformer could be generated for the ligand.
    """
    start_time = time.time()

    # Do inference here
    mol = generate_input_conformer(input_ligand, minimize_maxIters=500)
    if mol.GetNumConformers() == 0:
        raise ValueError(
            f"Could not generate a 3D conformer for ligand: {input_ligand!r}"
        )

    ligand_file = "test_docking_pose.sdf"
    molwriter = Chem.SDWriter(ligand_file)
    try:
        molwriter.write(mol)
    finally:
        # Close so the record is flushed to disk before the path is returned.
        molwriter.close()

    # get only non hydrogen atoms
    heavy_atom_mask = [at.GetAtomicNum() != 1 for at in mol.GetAtoms()]
    mol_coords = mol.GetConformer().GetPositions()[heavy_atom_mask]
    # get opt coords
    new_coord, min_dist_sum = optimize_coordinate(mol_coords)
    # get mindist to protein
    min_dist = np.min(np.linalg.norm(mol_coords - new_coord, axis=1))
    # decoy coord
    decoy_coord, _ = optimize_decoy_coordinate(mol_coords)
    decoy_min_dist = np.min(np.linalg.norm(mol_coords - decoy_coord, axis=1))

    # save protein
    output_file = "test_out.pdb"
    set_protein_to_new_coord_plus_decoy_atom(input_protein, new_coord, decoy_coord, output_file)

    # return an output pdb file with the protein and ligand with resname LIG or UNK.
    # also return any metrics you want to log, metrics will not be used for evaluation but might be useful for users
    # Plain floats keep the metrics independent of NumPy scalar types.
    metrics = {
        "min_dist": float(min_dist),
        "min_dist_sum": float(min_dist_sum),
        "decoy_min_dist": float(decoy_min_dist),
    }

    end_time = time.time()
    run_time = end_time - start_time
    return [output_file, ligand_file], metrics, run_time

# Gradio UI definition: input widgets, one example row, result widgets, and
# the button wiring that calls `predict`.
with gr.Blocks() as app:

    gr.Markdown("# Template for inference")

    gr.Markdown("Title, description, and other information about the model")   
    # First row: free-text inputs (protein sequence and ligand SMILES).
    with gr.Row():
        input_sequence = gr.Textbox(lines=3, label="Input Protein sequence (FASTA)")
        input_ligand = gr.Textbox(lines=3, label="Input ligand SMILES")
    # Second row: file inputs (MSA and protein structure).
    with gr.Row():
        input_msa = gr.File(label="Input Protein MSA (A3M)")
        input_protein = gr.File(label="Input protein monomer")
        
    
    # define any options here

    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    # One example row; its values are listed in the same order as the input
    # components given in the second argument.
    gr.Examples(
        [
            [
                "",
                "COc1ccc(cc1)n2c3c(c(n2)C(=O)N)CCN(C3=O)c4ccc(cc4)N5CCCCC5=O",
                "empty_file.a3m",
                "test_input.pdb"
            ],
        ],
        [input_sequence, input_ligand, input_msa, input_protein],
    )
    # Rendering settings handed to Molecule3D, one entry per model index.
    # NOTE(review): presumably model 0 is the PDB and model 1 the SDF, matching
    # the order of the file list returned by `predict` - confirm against
    # gradio_molecule3d.
    reps =    [
    {
      "model": 0,
      "style": "sphere",
      "color": "grayCarbon",
    },
        {
      "model": 1,
      "style": "stick",
      "color": "greenCarbon",
    }
        
  ]
    
    # Output widgets, filled in order by the three values `predict` returns.
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    # Clicking the button runs `predict` on the four inputs above.
    btn.click(predict, inputs=[input_sequence, input_ligand, input_msa, input_protein], outputs=[out, metrics, run_time])

# Launches the app at module level; there is no `__main__` guard, so this
# also runs when the module is imported.
app.launch()