import copy
import itertools
import math
import os
import pickle
import random
from argparse import Namespace
from functools import partial

import numpy as np
import pandas as pd
import torch
import yaml
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from datasets.pdbbind import PDBBind
from utils.diffusion_utils import get_t_schedule, t_to_sigma as t_to_sigma_compl
from utils.sampling import randomize_position, sampling
from utils.utils import get_model
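

# Wraps a plain Python list of Data objects as a torch_geometric Dataset so that it
# can be served through a DataLoader (used in ConfidenceDataset.preprocessing()).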
class ListDataset(Dataset):
    def __init__(self, data_list):
        super().__init__()
        self.data_list = data_list

    def len(self) -> int:
        return len(self.data_list)

    def get(self, idx: int) -> Data:
        return self.data_list[idx]
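

# Builds the path of the cached preprocessed complex graphs from the dataset arguments,
# so that runs with different preprocessing settings do not collide on disk.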
def get_cache_path(args, split):
    cache_path = args.cache_path
    if not args.no_torsion:
        cache_path += '_torsion'
    if args.all_atoms:
        cache_path += '_allatoms'
    split_path = args.split_train if split == 'train' else args.split_val
    cache_path = os.path.join(cache_path, f'limit{args.limit_complexes}_INDEX{os.path.splitext(os.path.basename(split_path))[0]}_maxLigSize{args.max_lig_size}_H{int(not args.remove_hs)}_recRad{args.receptor_radius}_recMax{args.c_alpha_max_neighbors}'
                              + ('' if not args.all_atoms else f'_atomRad{args.atom_radius}_atomMax{args.atom_max_neighbors}')
                              + ('' if args.no_torsion or args.num_conformers == 1 else f'_confs{args.num_conformers}')
                              + ('' if args.esm_embeddings_path is None else '_esmEmbeddings'))
    return cache_path
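

# Loads the training arguments of the original (score) model and returns them together
# with the path of that model's cached complex graphs.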
def get_args_and_cache_path(original_model_dir, split):
    with open(f'{original_model_dir}/model_parameters.yml') as f:
        model_args = Namespace(**yaml.full_load(f))
    return model_args, get_cache_path(model_args, split)
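

# Dataset for training the confidence model: for every complex it stores a set of ligand
# poses generated with the original score model together with their RMSDs to the reference
# ligand, and serves (pose, label) pairs where the label indicates whether the RMSD is
# below the classification cutoff.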
class ConfidenceDataset(Dataset):
    def __init__(self, cache_path, original_model_dir, split, device, limit_complexes,
                 inference_steps, samples_per_complex, all_atoms,
                 args, balance=False, use_original_model_cache=True, rmsd_classification_cutoff=2,
                 cache_ids_to_combine=None, cache_creation_id=None):
        super(ConfidenceDataset, self).__init__()
        self.device = device
        self.inference_steps = inference_steps
        self.limit_complexes = limit_complexes
        self.all_atoms = all_atoms
        self.original_model_dir = original_model_dir
        self.balance = balance
        self.use_original_model_cache = use_original_model_cache
        self.rmsd_classification_cutoff = rmsd_classification_cutoff
        self.cache_ids_to_combine = cache_ids_to_combine
        self.cache_creation_id = cache_creation_id
        self.samples_per_complex = samples_per_complex

        self.original_model_args, original_model_cache = get_args_and_cache_path(original_model_dir, split)
        self.complex_graphs_cache = original_model_cache if self.use_original_model_cache else get_cache_path(args, split)
        print('Using the cached complex graphs of the original model args' if self.use_original_model_cache
              else 'Not using the cached complex graphs of the original model args. Instead, using the complex graphs at the location given by the dataset parameters passed to confidence_train.py.')
        print(self.complex_graphs_cache)
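
        # If the cached complex graphs do not exist yet, build them once with PDBBind; the
        # construction below only populates the cache on disk, and the graphs are loaded
        # from the resulting pickle right afterwards.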
        if not os.path.exists(os.path.join(self.complex_graphs_cache, "heterographs.pkl")):
            print(f'HAPPENING | Complex graphs path does not exist yet: {os.path.join(self.complex_graphs_cache, "heterographs.pkl")}. For that reason, we are now creating the dataset.')
            PDBBind(transform=None, root=args.data_dir, limit_complexes=args.limit_complexes,
                    receptor_radius=args.receptor_radius,
                    cache_path=args.cache_path, split_path=args.split_val if split == 'val' else args.split_train,
                    remove_hs=args.remove_hs, max_lig_size=None,
                    c_alpha_max_neighbors=args.c_alpha_max_neighbors,
                    matching=not args.no_torsion, keep_original=True,
                    popsize=args.matching_popsize,
                    maxiter=args.matching_maxiter,
                    all_atoms=args.all_atoms,
                    atom_radius=args.atom_radius,
                    atom_max_neighbors=args.atom_max_neighbors,
                    esm_embeddings_path=args.esm_embeddings_path,
                    require_ligand=True)

        print(f'HAPPENING | Loading complex graphs from: {os.path.join(self.complex_graphs_cache, "heterographs.pkl")}')
        with open(os.path.join(self.complex_graphs_cache, "heterographs.pkl"), 'rb') as f:
            complex_graphs = pickle.load(f)
        self.complex_graph_dict = {d.name: d for d in complex_graphs}
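
        # Cache of the score-model poses and RMSDs that this confidence dataset is built from.
        # If it does not exist yet (for the requested cache_creation_id), it is generated by
        # running the score model in self.preprocessing().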
        self.full_cache_path = os.path.join(cache_path, f'model_{os.path.splitext(os.path.basename(original_model_dir))[0]}'
                                                        f'_split_{split}_limit_{limit_complexes}')
        if (not os.path.exists(os.path.join(self.full_cache_path, "ligand_positions.pkl")) and self.cache_creation_id is None) or \
                (not os.path.exists(os.path.join(self.full_cache_path, f"ligand_positions_id{self.cache_creation_id}.pkl")) and self.cache_creation_id is not None):
            os.makedirs(self.full_cache_path, exist_ok=True)
            self.preprocessing(original_model_cache)
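
        # Load the generated ligand positions and RMSDs: either from a single cache, or by
        # combining several caches identified by cache_ids_to_combine.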
        if self.cache_ids_to_combine is None:
            print(f'HAPPENING | Loading positions and rmsds from: {os.path.join(self.full_cache_path, "ligand_positions.pkl")}')
            with open(os.path.join(self.full_cache_path, "ligand_positions.pkl"), 'rb') as f:
                self.full_ligand_positions, self.rmsds = pickle.load(f)
            if os.path.exists(os.path.join(self.full_cache_path, "complex_names_in_same_order.pkl")):
                with open(os.path.join(self.full_cache_path, "complex_names_in_same_order.pkl"), 'rb') as f:
                    generated_rmsd_complex_names = pickle.load(f)
            else:
                print('HAPPENING | The path', os.path.join(self.full_cache_path, "complex_names_in_same_order.pkl"),
                      'does not exist. \n => We assume this means that the ligand_positions.pkl was created by a code '
                      'version that did not yet save the complex names. We therefore fall back to the complex names of '
                      'the dataset that the original model used to create the ligand positions and RMSDs.')
                with open(os.path.join(original_model_cache, "heterographs.pkl"), 'rb') as f:
                    original_model_complex_graphs = pickle.load(f)
                generated_rmsd_complex_names = [d.name for d in original_model_complex_graphs]
            assert (len(self.rmsds) == len(generated_rmsd_complex_names))
        else:
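            # Combine several generation runs: every run stores its own positions, RMSDs and
            # complex names, so the entries are first aligned by complex name and then the
            # samples of all runs are concatenated per complex.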
            all_rmsds_unsorted, all_full_ligand_positions_unsorted, all_names_unsorted = [], [], []
            for cache_id in self.cache_ids_to_combine:
                print(f'HAPPENING | Loading positions and rmsds from cache_id from the path: {os.path.join(self.full_cache_path, "ligand_positions_id" + str(cache_id) + ".pkl")}')
                if not os.path.exists(os.path.join(self.full_cache_path, f"ligand_positions_id{cache_id}.pkl")):
                    raise Exception(f'The generated ligand positions with cache_id do not exist: {cache_id}')  # be careful when changing this error message since it is sometimes caught in a try/except
                with open(os.path.join(self.full_cache_path, f"ligand_positions_id{cache_id}.pkl"), 'rb') as f:
                    full_ligand_positions, rmsds = pickle.load(f)
                with open(os.path.join(self.full_cache_path, f"complex_names_in_same_order_id{cache_id}.pkl"), 'rb') as f:
                    names_unsorted = pickle.load(f)
                all_names_unsorted.append(names_unsorted)
                all_rmsds_unsorted.append(rmsds)
                all_full_ligand_positions_unsorted.append(full_ligand_positions)
            names_order = list(set(sum(all_names_unsorted, [])))
            all_rmsds, all_full_ligand_positions = [], []
            for rmsds_unsorted, full_ligand_positions_unsorted, names_unsorted in zip(all_rmsds_unsorted, all_full_ligand_positions_unsorted, all_names_unsorted):
                name_to_pos_dict = {name: (pos, rmsd) for name, pos, rmsd in zip(names_unsorted, full_ligand_positions_unsorted, rmsds_unsorted)}
                intermediate_rmsds = [name_to_pos_dict[name][1] for name in names_order]
                all_rmsds.append(intermediate_rmsds)
                intermediate_pos = [name_to_pos_dict[name][0] for name in names_order]
                all_full_ligand_positions.append(intermediate_pos)
            self.full_ligand_positions, self.rmsds = [], []
            for positions_tuple in zip(*all_full_ligand_positions):
                self.full_ligand_positions.append(np.concatenate(positions_tuple, axis=0))
            for rmsds_tuple in zip(*all_rmsds):
                self.rmsds.append(np.concatenate(rmsds_tuple, axis=0))
            generated_rmsd_complex_names = names_order
        print('Number of complex graphs: ', len(self.complex_graph_dict))
        print('Number of RMSDs and positions for the complex graphs: ', len(self.full_ligand_positions))

        self.all_samples_per_complex = samples_per_complex * (1 if self.cache_ids_to_combine is None else len(self.cache_ids_to_combine))
        self.positions_rmsds_dict = {name: (pos, rmsd) for name, pos, rmsd in zip(generated_rmsd_complex_names, self.full_ligand_positions, self.rmsds)}
        self.dataset_names = list(set(self.positions_rmsds_dict.keys()) & set(self.complex_graph_dict.keys()))
        if limit_complexes > 0:
            self.dataset_names = self.dataset_names[:limit_complexes]

    def len(self):
        return len(self.dataset_names)
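
    # Returns one training example: a deep copy of the complex graph with the ligand placed
    # at one of the generated poses, labelled by whether its RMSD is below the classification
    # cutoff. With balance=True, positive and negative poses are drawn with equal probability.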
    def get(self, idx):
        complex_graph = copy.deepcopy(self.complex_graph_dict[self.dataset_names[idx]])
        positions, rmsds = self.positions_rmsds_dict[self.dataset_names[idx]]

        if self.balance:
            if isinstance(self.rmsd_classification_cutoff, list):
                raise ValueError("a list for --rmsd_classification_cutoff can only be used without --balance")
            label = random.randint(0, 1)
            success = rmsds < self.rmsd_classification_cutoff
            n_success = np.count_nonzero(success)
            if label == 0 and n_success != self.all_samples_per_complex:
                # sample a negative pose
                sample = random.randint(0, self.all_samples_per_complex - n_success - 1)
                lig_pos = positions[~success][sample]
                complex_graph['ligand'].pos = torch.from_numpy(lig_pos)
            else:
                # sample a positive pose
                if n_success > 0:  # if there is no successful pose, the matched complex is kept as-is
                    sample = random.randint(0, n_success - 1)
                    lig_pos = positions[success][sample]
                    complex_graph['ligand'].pos = torch.from_numpy(lig_pos)
            complex_graph.y = torch.tensor(label).float()
        else:
            sample = random.randint(0, self.all_samples_per_complex - 1)
            complex_graph['ligand'].pos = torch.from_numpy(positions[sample])
            complex_graph.y = torch.tensor(rmsds[sample] < self.rmsd_classification_cutoff).float().unsqueeze(0)
            if isinstance(self.rmsd_classification_cutoff, list):
                complex_graph.y_binned = torch.tensor(np.logical_and(rmsds[sample] < self.rmsd_classification_cutoff + [math.inf], rmsds[sample] >= [0] + self.rmsd_classification_cutoff), dtype=torch.float).unsqueeze(0)
                complex_graph.y = torch.tensor(rmsds[sample] < self.rmsd_classification_cutoff[0]).unsqueeze(0).float()
            complex_graph.rmsd = torch.tensor(rmsds[sample]).unsqueeze(0).float()
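
        # The confidence model scores final poses, so the diffusion time features are set to
        # zero for every node type and for the whole complex.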
        complex_graph['ligand'].node_t = {'tr': 0 * torch.ones(complex_graph['ligand'].num_nodes),
                                          'rot': 0 * torch.ones(complex_graph['ligand'].num_nodes),
                                          'tor': 0 * torch.ones(complex_graph['ligand'].num_nodes)}
        complex_graph['receptor'].node_t = {'tr': 0 * torch.ones(complex_graph['receptor'].num_nodes),
                                            'rot': 0 * torch.ones(complex_graph['receptor'].num_nodes),
                                            'tor': 0 * torch.ones(complex_graph['receptor'].num_nodes)}
        if self.all_atoms:
            complex_graph['atom'].node_t = {'tr': 0 * torch.ones(complex_graph['atom'].num_nodes),
                                            'rot': 0 * torch.ones(complex_graph['atom'].num_nodes),
                                            'tor': 0 * torch.ones(complex_graph['atom'].num_nodes)}
        complex_graph.complex_t = {'tr': 0 * torch.ones(1), 'rot': 0 * torch.ones(1), 'tor': 0 * torch.ones(1)}
        return complex_graph
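
    # Runs the original score model on every cached complex to generate samples_per_complex
    # candidate poses, computes their RMSD to the reference ligand, and pickles the positions,
    # RMSDs and complex names into the confidence cache.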
    def preprocessing(self, original_model_cache):
        t_to_sigma = partial(t_to_sigma_compl, args=self.original_model_args)

        model = get_model(self.original_model_args, self.device, t_to_sigma=t_to_sigma, no_parallel=True)
        state_dict = torch.load(f'{self.original_model_dir}/best_model.pt', map_location=torch.device('cpu'))
        model.load_state_dict(state_dict, strict=True)
        model = model.to(self.device)
        model.eval()

        tr_schedule = get_t_schedule(inference_steps=self.inference_steps)
        rot_schedule = tr_schedule
        tor_schedule = tr_schedule
        print('common t schedule', tr_schedule)

        print('HAPPENING | Loading the cached complexes of the original model to create the confidence dataset RMSDs and predicted positions. Doing that from:', os.path.join(original_model_cache, "heterographs.pkl"))
        with open(os.path.join(original_model_cache, "heterographs.pkl"), 'rb') as f:
            complex_graphs = pickle.load(f)
        dataset = ListDataset(complex_graphs)
        loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

        rmsds, full_ligand_positions, names = [], [], []
        for orig_complex_graph in tqdm(loader):
            data_list = [copy.deepcopy(orig_complex_graph) for _ in range(self.samples_per_complex)]
            randomize_position(data_list, self.original_model_args.no_torsion, False, self.original_model_args.tr_sigma_max)
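
            # Reverse-diffusion sampling can fail with an SVD convergence error; retry up to
            # 5 times and, if it keeps failing, fall back to the randomized input poses.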
            predictions_list = None
            failed_convergence_counter = 0
            while predictions_list is None:
                try:
                    predictions_list, confidences = sampling(data_list=data_list, model=model, inference_steps=self.inference_steps,
                                                             tr_schedule=tr_schedule, rot_schedule=rot_schedule, tor_schedule=tor_schedule,
                                                             device=self.device, t_to_sigma=t_to_sigma, model_args=self.original_model_args)
                except Exception as e:
                    if 'failed to converge' in str(e):
                        failed_convergence_counter += 1
                        if failed_convergence_counter > 5:
                            print('| WARNING: SVD failed to converge 5 times - skipping the complex')
                            break
                        print('| WARNING: SVD failed to converge - trying again with a new sample')
                    else:
                        raise e
            if failed_convergence_counter > 5:
                predictions_list = data_list
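
            # Compute the RMSD of every predicted pose to the original ligand positions
            # (hydrogens filtered out), with both coordinate sets expressed relative to the
            # original complex center.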
            if self.original_model_args.no_torsion:
                orig_complex_graph['ligand'].orig_pos = (orig_complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy())

            filterHs = torch.not_equal(predictions_list[0]['ligand'].x[:, 0], 0).cpu().numpy()

            if isinstance(orig_complex_graph['ligand'].orig_pos, list):
                orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]

            ligand_pos = np.asarray([complex_graph['ligand'].pos.cpu().numpy()[filterHs] for complex_graph in predictions_list])
            orig_ligand_pos = np.expand_dims(orig_complex_graph['ligand'].orig_pos[filterHs] - orig_complex_graph.original_center.cpu().numpy(), axis=0)
            rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))

            rmsds.append(rmsd)
            full_ligand_positions.append(np.asarray([complex_graph['ligand'].pos.cpu().numpy() for complex_graph in predictions_list]))
            names.append(orig_complex_graph.name[0])
            assert len(orig_complex_graph.name) == 1  # the line above assumes the name list always has length 1; if this fails, check which names it contains
        with open(os.path.join(self.full_cache_path, f"ligand_positions{'' if self.cache_creation_id is None else '_id' + str(self.cache_creation_id)}.pkl"), 'wb') as f:
            pickle.dump((full_ligand_positions, rmsds), f)
        with open(os.path.join(self.full_cache_path, f"complex_names_in_same_order{'' if self.cache_creation_id is None else '_id' + str(self.cache_creation_id)}.pkl"), 'wb') as f:
            pickle.dump(names, f)