Spaces:
Running
Running
import logging | |
from collections import defaultdict | |
from typing import List, Callable | |
from gt4sd.properties import PropertyPredictorRegistry | |
from gt4sd.algorithms.prediction.paccmann.core import PaccMann, AffinityPredictor | |
import torch | |
import mols2grid | |
import pandas as pd | |
# Module-level logger. The NullHandler keeps logging silent unless the host
# application configures handlers itself.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def get_affinity_function(target: str) -> Callable:
    """Return a batch scorer for PaccMann binding-affinity predictions.

    The returned callable takes a list of ligands, pairs every ligand with
    ``target``, and runs a single PaccMann ``AffinityPredictor`` over the
    whole batch.

    Args:
        target: Protein target that every ligand is scored against.

    Returns:
        A function mapping a list of ligands to the stacked predictions,
        converted to plain Python values with ``Tensor.tolist()``.
    """

    def score_ligands(mols):
        batch_size = len(mols)
        # The same protein target is repeated once per ligand.
        targets = [target] * batch_size
        model = PaccMann(AffinityPredictor(protein_targets=targets, ligands=mols))
        # Request one sample per ligand, then stack the tensors so they can be
        # converted in a single tolist() call.
        predictions = list(model.sample(batch_size))
        return torch.stack(predictions).tolist()

    return score_ligands
# Per-molecule property evaluators, keyed by the name used in `properties`.
# Note the key "sa" is backed by the registry predictor named "sas".
EVAL_DICT = {
    key: PropertyPredictorRegistry.get_property_predictor(predictor_name)
    for key, predictor_name in (("qed", "qed"), ("sa", "sas"))
}
def draw_grid_generate(
    samples: List[str],
    properties: List[str],
    protein_target: str,
    n_cols: int = 3,
    size=(140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules.

    Args:
        samples: The generated samples, used as the grid's "SMILES" column.
        properties: Names of the properties to display. Every name other than
            "affinity" must be a key of ``EVAL_DICT``. Duplicate names are
            ignored and the caller's list is not modified.
        protein_target: Protein target used for the "affinity" property; an
            empty string means no target was given.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If "affinity" is requested but ``protein_target`` is empty.
    """
    # Copy and de-duplicate (keeping order) instead of mutating the caller's
    # list. A duplicate name would otherwise append twice to one column and
    # give the DataFrame columns of unequal length.
    requested = list(dict.fromkeys(properties))
    want_affinity = "affinity" in requested
    per_molecule = [prop for prop in requested if prop != "affinity"]

    # The affinity scorer is built per call instead of being stored in the
    # shared EVAL_DICT, so a target from an earlier call can never be reused.
    if want_affinity and protein_target == "":
        raise ValueError(
            "The 'affinity' property requires a non-empty protein_target."
        )

    result = defaultdict(list)
    result.update(
        {"SMILES": samples, "Name": [f"Generated_{i}" for i in range(len(samples))]},
    )
    if want_affinity:
        # Affinity is predicted for the whole batch in one call and stored as
        # raw values (no "prop = value" label).
        result["affinity"] = get_affinity_function(protein_target)(samples)

    # Fill the remaining properties one molecule at a time.
    for sample in samples:
        for prop in per_molecule:
            value = EVAL_DICT[prop](sample)
            result[prop].append(f"{prop} = {value}")

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data