# import argparse import os # import shutil # import torch # from sklearn.metrics import roc_auc_score from torch_geometric.loader import DataLoader from torch_geometric.transforms import Compose # import datetime, pytz from core.config.config import Config, parse_config from core.models.sbdd_train_loop import SBDDTrainLoop from core.callbacks.basic import NormalizerCallback from core.callbacks.validation_callback import ( DockingTestCallback, OUT_DIR ) import core.utils.transforms as trans from core.datasets.utils import PDBProtein, parse_sdf_file from core.datasets.pl_data import ProteinLigandData, torchify_dict from core.datasets.pl_data import FOLLOW_BATCH import pytorch_lightning as pl from pytorch_lightning import seed_everything # from absl import logging # import glob # from core.evaluation.utils import scoring_func # from core.evaluation.docking_vina import VinaDockingTask # from posecheck import PoseCheck # import numpy as np # from rdkit import Chem def get_dataloader_from_pdb(cfg): assert cfg.evaluation.protein_path is not None and cfg.evaluation.ligand_path is not None protein_fn, ligand_fn = cfg.evaluation.protein_path, cfg.evaluation.ligand_path # load protein and ligand protein = PDBProtein(protein_fn) ligand_dict = parse_sdf_file(ligand_fn) lig_pos = ligand_dict["pos"] print('[DEBUG] get_dataloader') print(lig_pos.shape, lig_pos.mean(axis=0)) pdb_block_pocket = protein.residues_to_pdb_block( protein.query_residues_ligand(ligand_dict, cfg.dynamics.net_config.r_max) ) pocket = PDBProtein(pdb_block_pocket) pocket_dict = pocket.to_dict_atom() data = ProteinLigandData.from_protein_ligand_dicts( protein_dict=torchify_dict(pocket_dict), ligand_dict=torchify_dict(ligand_dict), ) data.protein_filename = protein_fn data.ligand_filename = ligand_fn # transform protein_featurizer = trans.FeaturizeProteinAtom() ligand_featurizer = trans.FeaturizeLigandAtom(cfg.data.transform.ligand_atom_mode) transform_list = [ protein_featurizer, ligand_featurizer, ] transform = Compose(transform_list) cfg.dynamics.protein_atom_feature_dim = protein_featurizer.feature_dim cfg.dynamics.ligand_atom_feature_dim = ligand_featurizer.feature_dim print(f"protein feature dim: {cfg.dynamics.protein_atom_feature_dim}, " + f"ligand feature dim: {cfg.dynamics.ligand_atom_feature_dim}") # dataloader collate_exclude_keys = ["ligand_nbh_list"] test_set = [transform(data)] * cfg.evaluation.num_samples cfg.evaluation.num_samples = 1 test_loader = DataLoader( test_set, batch_size=cfg.evaluation.batch_size, shuffle=False, follow_batch=FOLLOW_BATCH, exclude_keys=collate_exclude_keys ) cfg.evaluation.docking_config.protein_root = os.path.dirname(os.path.abspath(protein_fn)) print(f"protein root: {cfg.evaluation.docking_config.protein_root}") return test_loader def call(protein_fn, ligand_fn, num_samples=10, sample_steps=100, sample_num_atoms='prior', beta1=1.5, sigma1_coord=0.03, sampling_strategy='end_back', seed=1234): cfg = Config('./checkpoints/config.yaml') seed_everything(cfg.seed) cfg.evaluation.protein_path = protein_fn cfg.evaluation.ligand_path = ligand_fn cfg.test_only = True cfg.no_wandb = True cfg.evaluation.num_samples = num_samples cfg.evaluation.sample_steps = sample_steps cfg.evaluation.sample_num_atoms = sample_num_atoms # or 'prior' cfg.dynamics.beta1 = beta1 cfg.dynamics.sigma1_coord = sigma1_coord cfg.dynamics.sampling_strategy = sampling_strategy cfg.seed = seed cfg.train.max_grad_norm = 'Q' # print(f"The config of this process is:\n{cfg}") print(protein_fn, ligand_fn) test_loader = get_dataloader_from_pdb(cfg) # wandb_logger.log_hyperparams(cfg.todict()) model = SBDDTrainLoop(config=cfg) trainer = pl.Trainer( default_root_dir=cfg.accounting.logdir, max_epochs=cfg.train.epochs, check_val_every_n_epoch=cfg.train.ckpt_freq, devices=1, # logger=wandb_logger, num_sanity_val_steps=0, callbacks=[ NormalizerCallback(normalizer_dict=cfg.data.normalizer_dict), DockingTestCallback( dataset=None, # TODO: implement CrossDockGen & NewBenchmark atom_decoder=cfg.data.atom_decoder, atom_enc_mode=cfg.data.transform.ligand_atom_mode, atom_type_one_hot=False, single_bond=True, docking_config=cfg.evaluation.docking_config, ), ], ) trainer.test(model, dataloaders=test_loader, ckpt_path=cfg.evaluation.ckpt_path) class Metrics: def __init__(self, protein_fn, ref_ligand_fn, ligand_fn): self.protein_fn = protein_fn self.ref_ligand_fn = ref_ligand_fn self.ligand_fn = ligand_fn self.exhaustiveness = 16 def vina_dock(self, mol): chem_results = {} try: # qed, logp, sa, lipinski, ring size, etc chem_results.update(scoring_func.get_chem(mol)) chem_results['atom_num'] = mol.GetNumAtoms() # docking vina_task = VinaDockingTask.from_generated_mol( mol, protein_filename=self.protein_fn) score_only_results = vina_task.run(mode='score_only', exhaustiveness=self.exhaustiveness) minimize_results = vina_task.run(mode='minimize', exhaustiveness=self.exhaustiveness) docking_results = vina_task.run(mode='dock', exhaustiveness=self.exhaustiveness) chem_results['vina_score'] = score_only_results[0]['affinity'] chem_results['vina_minimize'] = minimize_results[0]['affinity'] chem_results['vina_dock'] = docking_results[0]['affinity'] # chem_results['vina_dock_pose'] = docking_results[0]['pose'] return chem_results except Exception as e: print(e) return chem_results def pose_check(self, mol): pc = PoseCheck() pose_check_results = {} protein_ready = False try: pc.load_protein_from_pdb(self.protein_fn) protein_ready = True except ValueError as e: return pose_check_results ligand_ready = False try: pc.load_ligands_from_mols([mol]) ligand_ready = True except ValueError as e: return pose_check_results if ligand_ready: try: strain = pc.calculate_strain_energy()[0] pose_check_results['strain'] = strain except Exception as e: pass if protein_ready and ligand_ready: try: clash = pc.calculate_clashes()[0] pose_check_results['clash'] = clash except Exception as e: pass try: df = pc.calculate_interactions() columns = np.array([column[2] for column in df.columns]) flags = np.array([df[column][0] for column in df.columns]) def count_inter(inter_type): if len(columns) == 0: return 0 count = sum((columns == inter_type) & flags) return count # ['Hydrophobic', 'HBDonor', 'VdWContact', 'HBAcceptor'] hb_donor = count_inter('HBDonor') hb_acceptor = count_inter('HBAcceptor') vdw = count_inter('VdWContact') hydrophobic = count_inter('Hydrophobic') pose_check_results['hb_donor'] = hb_donor pose_check_results['hb_acceptor'] = hb_acceptor pose_check_results['vdw'] = vdw pose_check_results['hydrophobic'] = hydrophobic except Exception as e: pass for k, v in pose_check_results.items(): mol.SetProp(k, str(v)) return pose_check_results def evaluate(self): mol = Chem.SDMolSupplier(self.ligand_fn, removeHs=False)[0] chem_results = self.vina_dock(mol) pose_check_results = self.pose_check(mol) chem_results.update(pose_check_results) return chem_results