MolCRAFT / sample_for_pocket.py
Atomu2014's picture
remove vina
3b77f6e
raw
history blame
8.5 kB
# import argparse
import os
# import shutil
# import torch
# from sklearn.metrics import roc_auc_score
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Compose
# import datetime, pytz
from core.config.config import Config, parse_config
from core.models.sbdd_train_loop import SBDDTrainLoop
from core.callbacks.basic import NormalizerCallback
from core.callbacks.validation_callback import (
DockingTestCallback,
OUT_DIR
)
import core.utils.transforms as trans
from core.datasets.utils import PDBProtein, parse_sdf_file
from core.datasets.pl_data import ProteinLigandData, torchify_dict
from core.datasets.pl_data import FOLLOW_BATCH
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
# from absl import logging
# import glob
# from core.evaluation.utils import scoring_func
# from core.evaluation.docking_vina import VinaDockingTask
# from posecheck import PoseCheck
# import numpy as np
# from rdkit import Chem
def get_dataloader_from_pdb(cfg):
assert cfg.evaluation.protein_path is not None and cfg.evaluation.ligand_path is not None
protein_fn, ligand_fn = cfg.evaluation.protein_path, cfg.evaluation.ligand_path
# load protein and ligand
protein = PDBProtein(protein_fn)
ligand_dict = parse_sdf_file(ligand_fn)
lig_pos = ligand_dict["pos"]
print('[DEBUG] get_dataloader')
print(lig_pos.shape, lig_pos.mean(axis=0))
pdb_block_pocket = protein.residues_to_pdb_block(
protein.query_residues_ligand(ligand_dict, cfg.dynamics.net_config.r_max)
)
pocket = PDBProtein(pdb_block_pocket)
pocket_dict = pocket.to_dict_atom()
data = ProteinLigandData.from_protein_ligand_dicts(
protein_dict=torchify_dict(pocket_dict),
ligand_dict=torchify_dict(ligand_dict),
)
data.protein_filename = protein_fn
data.ligand_filename = ligand_fn
# transform
protein_featurizer = trans.FeaturizeProteinAtom()
ligand_featurizer = trans.FeaturizeLigandAtom(cfg.data.transform.ligand_atom_mode)
transform_list = [
protein_featurizer,
ligand_featurizer,
]
transform = Compose(transform_list)
cfg.dynamics.protein_atom_feature_dim = protein_featurizer.feature_dim
cfg.dynamics.ligand_atom_feature_dim = ligand_featurizer.feature_dim
print(f"protein feature dim: {cfg.dynamics.protein_atom_feature_dim}, " +
f"ligand feature dim: {cfg.dynamics.ligand_atom_feature_dim}")
# dataloader
collate_exclude_keys = ["ligand_nbh_list"]
test_set = [transform(data)] * cfg.evaluation.num_samples
cfg.evaluation.num_samples = 1
test_loader = DataLoader(
test_set,
batch_size=cfg.evaluation.batch_size,
shuffle=False,
follow_batch=FOLLOW_BATCH,
exclude_keys=collate_exclude_keys
)
cfg.evaluation.docking_config.protein_root = os.path.dirname(os.path.abspath(protein_fn))
print(f"protein root: {cfg.evaluation.docking_config.protein_root}")
return test_loader
def call(protein_fn, ligand_fn,
num_samples=10, sample_steps=100, sample_num_atoms='prior',
beta1=1.5, sigma1_coord=0.03, sampling_strategy='end_back', seed=1234):
cfg = Config('./checkpoints/config.yaml')
seed_everything(cfg.seed)
cfg.evaluation.protein_path = protein_fn
cfg.evaluation.ligand_path = ligand_fn
cfg.test_only = True
cfg.no_wandb = True
cfg.evaluation.num_samples = num_samples
cfg.evaluation.sample_steps = sample_steps
cfg.evaluation.sample_num_atoms = sample_num_atoms # or 'prior'
cfg.dynamics.beta1 = beta1
cfg.dynamics.sigma1_coord = sigma1_coord
cfg.dynamics.sampling_strategy = sampling_strategy
cfg.seed = seed
cfg.train.max_grad_norm = 'Q'
# print(f"The config of this process is:\n{cfg}")
print(protein_fn, ligand_fn)
test_loader = get_dataloader_from_pdb(cfg)
# wandb_logger.log_hyperparams(cfg.todict())
model = SBDDTrainLoop(config=cfg)
trainer = pl.Trainer(
default_root_dir=cfg.accounting.logdir,
max_epochs=cfg.train.epochs,
check_val_every_n_epoch=cfg.train.ckpt_freq,
devices=1,
# logger=wandb_logger,
num_sanity_val_steps=0,
callbacks=[
NormalizerCallback(normalizer_dict=cfg.data.normalizer_dict),
DockingTestCallback(
dataset=None, # TODO: implement CrossDockGen & NewBenchmark
atom_decoder=cfg.data.atom_decoder,
atom_enc_mode=cfg.data.transform.ligand_atom_mode,
atom_type_one_hot=False,
single_bond=True,
docking_config=cfg.evaluation.docking_config,
),
],
)
trainer.test(model, dataloaders=test_loader, ckpt_path=cfg.evaluation.ckpt_path)
class Metrics:
def __init__(self, protein_fn, ref_ligand_fn, ligand_fn):
self.protein_fn = protein_fn
self.ref_ligand_fn = ref_ligand_fn
self.ligand_fn = ligand_fn
self.exhaustiveness = 16
def vina_dock(self, mol):
chem_results = {}
try:
# qed, logp, sa, lipinski, ring size, etc
chem_results.update(scoring_func.get_chem(mol))
chem_results['atom_num'] = mol.GetNumAtoms()
# docking
vina_task = VinaDockingTask.from_generated_mol(
mol, protein_filename=self.protein_fn)
score_only_results = vina_task.run(mode='score_only', exhaustiveness=self.exhaustiveness)
minimize_results = vina_task.run(mode='minimize', exhaustiveness=self.exhaustiveness)
docking_results = vina_task.run(mode='dock', exhaustiveness=self.exhaustiveness)
chem_results['vina_score'] = score_only_results[0]['affinity']
chem_results['vina_minimize'] = minimize_results[0]['affinity']
chem_results['vina_dock'] = docking_results[0]['affinity']
# chem_results['vina_dock_pose'] = docking_results[0]['pose']
return chem_results
except Exception as e:
print(e)
return chem_results
def pose_check(self, mol):
pc = PoseCheck()
pose_check_results = {}
protein_ready = False
try:
pc.load_protein_from_pdb(self.protein_fn)
protein_ready = True
except ValueError as e:
return pose_check_results
ligand_ready = False
try:
pc.load_ligands_from_mols([mol])
ligand_ready = True
except ValueError as e:
return pose_check_results
if ligand_ready:
try:
strain = pc.calculate_strain_energy()[0]
pose_check_results['strain'] = strain
except Exception as e:
pass
if protein_ready and ligand_ready:
try:
clash = pc.calculate_clashes()[0]
pose_check_results['clash'] = clash
except Exception as e:
pass
try:
df = pc.calculate_interactions()
columns = np.array([column[2] for column in df.columns])
flags = np.array([df[column][0] for column in df.columns])
def count_inter(inter_type):
if len(columns) == 0:
return 0
count = sum((columns == inter_type) & flags)
return count
# ['Hydrophobic', 'HBDonor', 'VdWContact', 'HBAcceptor']
hb_donor = count_inter('HBDonor')
hb_acceptor = count_inter('HBAcceptor')
vdw = count_inter('VdWContact')
hydrophobic = count_inter('Hydrophobic')
pose_check_results['hb_donor'] = hb_donor
pose_check_results['hb_acceptor'] = hb_acceptor
pose_check_results['vdw'] = vdw
pose_check_results['hydrophobic'] = hydrophobic
except Exception as e:
pass
for k, v in pose_check_results.items():
mol.SetProp(k, str(v))
return pose_check_results
def evaluate(self):
mol = Chem.SDMolSupplier(self.ligand_fn, removeHs=False)[0]
chem_results = self.vina_dock(mol)
pose_check_results = self.pose_check(mol)
chem_results.update(pose_check_results)
return chem_results