import logging
from collections import defaultdict
from typing import List, Callable
from gt4sd.properties import PropertyPredictorRegistry
from gt4sd.algorithms.prediction.paccmann.core import PaccMann, AffinityPredictor
import torch

import mols2grid
import pandas as pd

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def get_affinity_function(target: str) -> Callable:
    """Build a batched affinity scorer bound to a single protein target.

    Args:
        target: Protein target; it is repeated once per ligand when building
            the ``AffinityPredictor`` configuration.

    Returns:
        A callable that takes a list of ligands ``mols``. It draws
        ``len(mols)`` samples from a ``PaccMann`` model, stacks them with
        ``torch.stack`` and returns the result converted via ``.tolist()``.
    """

    def _score(mols):
        # One protein target entry per ligand so the two lists line up.
        predictor_config = AffinityPredictor(
            protein_targets=[target] * len(mols), ligands=mols
        )
        predictions = PaccMann(predictor_config).sample(len(mols))
        stacked = torch.stack(list(predictions))
        return stacked.tolist()

    return _score


# Property predictors looked up from the gt4sd registry, keyed by the name used
# as the grid column/tooltip label. Note that the "sa" label is backed by the
# registry key "sas".
EVAL_DICT = {
    label: PropertyPredictorRegistry.get_property_predictor(registry_key)
    for label, registry_key in (("qed", "qed"), ("sa", "sas"))
}


def draw_grid_generate(
    samples: List[str],
    properties: List[str],
    protein_target: str,
    n_cols: int = 3,
    size=(140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules.

    Args:
        samples: The generated samples (SMILES strings).
        properties: Names of the properties to show for each sample. Each name
            must be a key of ``EVAL_DICT`` or ``"affinity"``. The list is not
            modified.
        protein_target: Protein target for the affinity predictor. An empty
            string disables affinity prediction.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If ``"affinity"`` is requested but ``protein_target`` is
            empty.
        KeyError: If a requested property has no registered predictor.
    """
    # Per-call evaluator mapping: mutating the module-level EVAL_DICT would let
    # an affinity function for one target leak into later calls made with a
    # different (or no) target.
    evaluators = dict(EVAL_DICT)
    if protein_target != "":
        evaluators["affinity"] = get_affinity_function(protein_target)

    result = {
        "SMILES": samples,
        "Name": [f"Generated_{i}" for i in range(len(samples))],
    }

    if "affinity" in properties:
        if "affinity" not in evaluators:
            raise ValueError(
                "The 'affinity' property requires a non-empty protein_target."
            )
        # Affinity is scored for all samples in one batched call. Skip the call
        # for an empty batch, where stacking zero predictions would fail.
        result["affinity"] = evaluators["affinity"](samples) if samples else []

    # Build a new list instead of calling remove() on the caller's list. This
    # also drops every duplicate "affinity" entry, not just the first.
    per_sample_props = [prop for prop in properties if prop != "affinity"]
    for prop in per_sample_props:
        # Assigning (not appending) keeps columns the same length as `samples`
        # even if a property name is listed more than once.
        result[prop] = [
            f"{prop} = {evaluators[prop](sample)}" for sample in samples
        ]

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data