|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
from torch_geometric.loader import DataLoader |
|
from torch_geometric.transforms import Compose |
|
|
|
|
|
|
|
from core.config.config import Config, parse_config |
|
from core.models.sbdd_train_loop import SBDDTrainLoop |
|
from core.callbacks.basic import NormalizerCallback |
|
from core.callbacks.validation_callback import ( |
|
DockingTestCallback, |
|
OUT_DIR |
|
) |
|
|
|
import core.utils.transforms as trans |
|
from core.datasets.utils import PDBProtein, parse_sdf_file |
|
from core.datasets.pl_data import ProteinLigandData, torchify_dict |
|
from core.datasets.pl_data import FOLLOW_BATCH |
|
|
|
import pytorch_lightning as pl |
|
|
|
from pytorch_lightning import seed_everything |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_dataloader_from_pdb(cfg): |
|
assert cfg.evaluation.protein_path is not None and cfg.evaluation.ligand_path is not None |
|
protein_fn, ligand_fn = cfg.evaluation.protein_path, cfg.evaluation.ligand_path |
|
|
|
|
|
protein = PDBProtein(protein_fn) |
|
ligand_dict = parse_sdf_file(ligand_fn) |
|
lig_pos = ligand_dict["pos"] |
|
|
|
print('[DEBUG] get_dataloader') |
|
print(lig_pos.shape, lig_pos.mean(axis=0)) |
|
|
|
pdb_block_pocket = protein.residues_to_pdb_block( |
|
protein.query_residues_ligand(ligand_dict, cfg.dynamics.net_config.r_max) |
|
) |
|
pocket = PDBProtein(pdb_block_pocket) |
|
pocket_dict = pocket.to_dict_atom() |
|
|
|
data = ProteinLigandData.from_protein_ligand_dicts( |
|
protein_dict=torchify_dict(pocket_dict), |
|
ligand_dict=torchify_dict(ligand_dict), |
|
) |
|
data.protein_filename = protein_fn |
|
data.ligand_filename = ligand_fn |
|
|
|
|
|
protein_featurizer = trans.FeaturizeProteinAtom() |
|
ligand_featurizer = trans.FeaturizeLigandAtom(cfg.data.transform.ligand_atom_mode) |
|
transform_list = [ |
|
protein_featurizer, |
|
ligand_featurizer, |
|
] |
|
transform = Compose(transform_list) |
|
cfg.dynamics.protein_atom_feature_dim = protein_featurizer.feature_dim |
|
cfg.dynamics.ligand_atom_feature_dim = ligand_featurizer.feature_dim |
|
print(f"protein feature dim: {cfg.dynamics.protein_atom_feature_dim}, " + |
|
f"ligand feature dim: {cfg.dynamics.ligand_atom_feature_dim}") |
|
|
|
|
|
collate_exclude_keys = ["ligand_nbh_list"] |
|
test_set = [transform(data)] * cfg.evaluation.num_samples |
|
cfg.evaluation.num_samples = 1 |
|
test_loader = DataLoader( |
|
test_set, |
|
batch_size=cfg.evaluation.batch_size, |
|
shuffle=False, |
|
follow_batch=FOLLOW_BATCH, |
|
exclude_keys=collate_exclude_keys |
|
) |
|
|
|
cfg.evaluation.docking_config.protein_root = os.path.dirname(os.path.abspath(protein_fn)) |
|
print(f"protein root: {cfg.evaluation.docking_config.protein_root}") |
|
|
|
return test_loader |
|
|
|
|
|
def call(protein_fn, ligand_fn, |
|
num_samples=10, sample_steps=100, sample_num_atoms='prior', |
|
beta1=1.5, sigma1_coord=0.03, sampling_strategy='end_back', seed=1234): |
|
|
|
cfg = Config('./checkpoints/config.yaml') |
|
seed_everything(cfg.seed) |
|
|
|
cfg.evaluation.protein_path = protein_fn |
|
cfg.evaluation.ligand_path = ligand_fn |
|
cfg.test_only = True |
|
cfg.no_wandb = True |
|
cfg.evaluation.num_samples = num_samples |
|
cfg.evaluation.sample_steps = sample_steps |
|
cfg.evaluation.sample_num_atoms = sample_num_atoms |
|
cfg.dynamics.beta1 = beta1 |
|
cfg.dynamics.sigma1_coord = sigma1_coord |
|
cfg.dynamics.sampling_strategy = sampling_strategy |
|
cfg.seed = seed |
|
cfg.train.max_grad_norm = 'Q' |
|
|
|
|
|
|
|
print(protein_fn, ligand_fn) |
|
test_loader = get_dataloader_from_pdb(cfg) |
|
|
|
|
|
model = SBDDTrainLoop(config=cfg) |
|
|
|
trainer = pl.Trainer( |
|
default_root_dir=cfg.accounting.logdir, |
|
max_epochs=cfg.train.epochs, |
|
check_val_every_n_epoch=cfg.train.ckpt_freq, |
|
devices=1, |
|
|
|
num_sanity_val_steps=0, |
|
callbacks=[ |
|
NormalizerCallback(normalizer_dict=cfg.data.normalizer_dict), |
|
DockingTestCallback( |
|
dataset=None, |
|
atom_decoder=cfg.data.atom_decoder, |
|
atom_enc_mode=cfg.data.transform.ligand_atom_mode, |
|
atom_type_one_hot=False, |
|
single_bond=True, |
|
docking_config=cfg.evaluation.docking_config, |
|
), |
|
], |
|
) |
|
|
|
trainer.test(model, dataloaders=test_loader, ckpt_path=cfg.evaluation.ckpt_path) |
|
|
|
|
|
class Metrics: |
|
def __init__(self, protein_fn, ref_ligand_fn, ligand_fn): |
|
self.protein_fn = protein_fn |
|
self.ref_ligand_fn = ref_ligand_fn |
|
self.ligand_fn = ligand_fn |
|
self.exhaustiveness = 16 |
|
|
|
def vina_dock(self, mol): |
|
chem_results = {} |
|
|
|
try: |
|
|
|
chem_results.update(scoring_func.get_chem(mol)) |
|
chem_results['atom_num'] = mol.GetNumAtoms() |
|
|
|
|
|
vina_task = VinaDockingTask.from_generated_mol( |
|
mol, protein_filename=self.protein_fn) |
|
score_only_results = vina_task.run(mode='score_only', exhaustiveness=self.exhaustiveness) |
|
minimize_results = vina_task.run(mode='minimize', exhaustiveness=self.exhaustiveness) |
|
docking_results = vina_task.run(mode='dock', exhaustiveness=self.exhaustiveness) |
|
|
|
chem_results['vina_score'] = score_only_results[0]['affinity'] |
|
chem_results['vina_minimize'] = minimize_results[0]['affinity'] |
|
chem_results['vina_dock'] = docking_results[0]['affinity'] |
|
|
|
|
|
return chem_results |
|
except Exception as e: |
|
print(e) |
|
|
|
return chem_results |
|
|
|
def pose_check(self, mol): |
|
pc = PoseCheck() |
|
|
|
pose_check_results = {} |
|
|
|
protein_ready = False |
|
try: |
|
pc.load_protein_from_pdb(self.protein_fn) |
|
protein_ready = True |
|
except ValueError as e: |
|
return pose_check_results |
|
|
|
ligand_ready = False |
|
try: |
|
pc.load_ligands_from_mols([mol]) |
|
ligand_ready = True |
|
except ValueError as e: |
|
return pose_check_results |
|
|
|
if ligand_ready: |
|
try: |
|
strain = pc.calculate_strain_energy()[0] |
|
pose_check_results['strain'] = strain |
|
except Exception as e: |
|
pass |
|
|
|
if protein_ready and ligand_ready: |
|
try: |
|
clash = pc.calculate_clashes()[0] |
|
pose_check_results['clash'] = clash |
|
except Exception as e: |
|
pass |
|
|
|
try: |
|
df = pc.calculate_interactions() |
|
columns = np.array([column[2] for column in df.columns]) |
|
flags = np.array([df[column][0] for column in df.columns]) |
|
|
|
def count_inter(inter_type): |
|
if len(columns) == 0: |
|
return 0 |
|
count = sum((columns == inter_type) & flags) |
|
return count |
|
|
|
|
|
hb_donor = count_inter('HBDonor') |
|
hb_acceptor = count_inter('HBAcceptor') |
|
vdw = count_inter('VdWContact') |
|
hydrophobic = count_inter('Hydrophobic') |
|
|
|
pose_check_results['hb_donor'] = hb_donor |
|
pose_check_results['hb_acceptor'] = hb_acceptor |
|
pose_check_results['vdw'] = vdw |
|
pose_check_results['hydrophobic'] = hydrophobic |
|
except Exception as e: |
|
pass |
|
|
|
for k, v in pose_check_results.items(): |
|
mol.SetProp(k, str(v)) |
|
|
|
return pose_check_results |
|
|
|
def evaluate(self): |
|
mol = Chem.SDMolSupplier(self.ligand_fn, removeHs=False)[0] |
|
|
|
chem_results = self.vina_dock(mol) |
|
pose_check_results = self.pose_check(mol) |
|
chem_results.update(pose_check_results) |
|
|
|
return chem_results |
|
|
|
|