Simon Duerr committed
Commit 486fd8a
1 Parent(s): 3e6dce4

gradio update
Files changed:
- app.py  +461 -0
- datasets/esm_embedding_preparation.py  +73 -72
- datasets/pdbbind.py  +432 -133
- datasets/process_mols.py  +6 -1
- examples/1a46_ligand.sdf  +179 -0
- examples/1a46_protein_processed.pdb  +0 -0
- examples/1cbr_ligand.sdf  +119 -0
- examples/1cbr_protein.pdb  +0 -0
- requirements.txt  +29 -0
app.py
ADDED
@@ -0,0 +1,461 @@
import gradio as gr
import os

import copy
import os
import torch

import time
from argparse import ArgumentParser, Namespace, FileType
from rdkit.Chem import RemoveHs
from functools import partial
import numpy as np
import pandas as pd
from rdkit import RDLogger
from rdkit.Chem import MolFromSmiles, AddHs
from torch_geometric.loader import DataLoader
import yaml

from datasets.process_mols import (
    read_molecule,
    generate_conformer,
    write_mol_with_coords,
)
from datasets.pdbbind import PDBBind
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule
from utils.sampling import randomize_position, sampling
from utils.utils import get_model
from utils.visualise import PDBFile
from tqdm import tqdm
from datasets.esm_embedding_preparation import esm_embedding_prep
import subprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(f"workdir/paper_score_model/model_parameters.yml") as f:
    score_model_args = Namespace(**yaml.full_load(f))

with open(f"workdir/paper_confidence_model/model_parameters.yml") as f:
    confidence_args = Namespace(**yaml.full_load(f))

t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)

model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True)
state_dict = torch.load(
    f"workdir/paper_score_model/best_ema_inference_epoch_model.pt",
    map_location=torch.device("cpu"),
)
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()

confidence_model = get_model(
    confidence_args,
    device,
    t_to_sigma=t_to_sigma,
    no_parallel=True,
    confidence_mode=True,
)
state_dict = torch.load(
    f"workdir/paper_confidence_model/best_model_epoch75.pt",
    map_location=torch.device("cpu"),
)
confidence_model.load_state_dict(state_dict, strict=True)
confidence_model = confidence_model.to(device)
confidence_model.eval()
tr_schedule = get_t_schedule(inference_steps=10)
rot_schedule = tr_schedule
tor_schedule = tr_schedule
print("common t schedule", tr_schedule)
failures, skipped, confidences_list, names_list, run_times, min_self_distances_list = (
    0,
    0,
    [],
    [],
    [],
    [],
)
N = 10


def get_pdb(pdb_code="", filepath=""):
    if pdb_code is None or pdb_code == "":
        try:
            return filepath.name
        except AttributeError as e:
            return None
    else:
        os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
        return f"{pdb_code}.pdb"


def get_ligand(smiles="", filepath=""):
    if smiles is None or smiles == "":
        try:
            return filepath.name
        except AttributeError as e:
            return None
    else:
        return smiles


def read_mol(molpath):
    with open(molpath, "r") as fp:
        lines = fp.readlines()
    mol = ""
    for l in lines:
        mol += l
    return mol


def molecule(input_pdb, ligand_pdb):

    structure = read_mol(input_pdb)
    mol = read_mol(ligand_pdb)

    x = (
        """<!DOCTYPE html>
        <html>
        <head>
        <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
        <style>
        body{
            font-family:sans-serif
        }
        .mol-container {
            width: 600px;
            height: 600px;
            position: relative;
            mx-auto:0
        }
        .mol-container select{
            background-image:None;
        }
        </style>
        <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
        </head>
        <body>
        <button id="startanimation">Replay diffusion process</button>
        <div id="container" class="mol-container"></div>

        <script>
        let ligand = `"""
        + mol
        + """`
        let structure = `"""
        + structure
        + """`

        let viewer = null;

        $(document).ready(function () {
            let element = $("#container");
            let config = { backgroundColor: "white" };
            viewer = $3Dmol.createViewer(element, config);
            viewer.addModel( structure, "pdb" );
            viewer.setStyle({}, {cartoon: {color: "gray"}});
            viewer.zoomTo();
            viewer.zoom(0.7);
            viewer.addModelsAsFrames(ligand, "pdb");
            viewer.animate({loop: "forward",reps: 1});

            viewer.getModel(1).setStyle({stick:{colorscheme:"magentaCarbon"}});
            viewer.render();

        })

        $("#startanimation").click(function() {
            viewer.animate({loop: "forward",reps: 1});
        });
        </script>
        </body></html>"""
    )

    return f"""<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""


def esm(protein_path, out_file):
    esm_embedding_prep(out_file, protein_path)
    # create args object with defaults
    os.environ["HOME"] = "esm/model_weights"

    subprocess.call(
        f"python esm/scripts/extract.py esm2_t33_650M_UR50D {out_file} data/esm2_output --repr_layers 33 --include per_tok",
        shell=True,
    )


def update(inp, file, ligand_inp, ligand_file):
    pdb_path = get_pdb(inp, file)
    ligand_path = get_ligand(ligand_inp, ligand_file)

    esm(
        pdb_path,
        f"data/{os.path.basename(pdb_path)}_prepared_for_esm.fasta",
    )

    protein_path_list = [pdb_path]
    ligand_descriptions = [ligand_path]
    no_random = False
    ode = False
    no_final_step_noise = False
    out_dir = "results/test"
    test_dataset = PDBBind(
        transform=None,
        root="",
        protein_path_list=protein_path_list,
        ligand_descriptions=ligand_descriptions,
        receptor_radius=score_model_args.receptor_radius,
        cache_path="data/cache",
        remove_hs=score_model_args.remove_hs,
        max_lig_size=None,
        c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors,
        matching=False,
        keep_original=False,
        popsize=score_model_args.matching_popsize,
        maxiter=score_model_args.matching_maxiter,
        all_atoms=score_model_args.all_atoms,
        atom_radius=score_model_args.atom_radius,
        atom_max_neighbors=score_model_args.atom_max_neighbors,
        esm_embeddings_path="data/esm2_output",
        require_ligand=True,
        num_workers=1,
        keep_local_structures=False,
    )
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
    confidence_test_dataset = PDBBind(
        transform=None,
        root="",
        protein_path_list=protein_path_list,
        ligand_descriptions=ligand_descriptions,
        receptor_radius=confidence_args.receptor_radius,
        cache_path="data/cache",
        remove_hs=confidence_args.remove_hs,
        max_lig_size=None,
        c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors,
        matching=False,
        keep_original=False,
        popsize=confidence_args.matching_popsize,
        maxiter=confidence_args.matching_maxiter,
        all_atoms=confidence_args.all_atoms,
        atom_radius=confidence_args.atom_radius,
        atom_max_neighbors=confidence_args.atom_max_neighbors,
        esm_embeddings_path="data/esm2_output",
        require_ligand=True,
        num_workers=1,
    )
    confidence_complex_dict = {d.name: d for d in confidence_test_dataset}
    for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
        if (
            confidence_model is not None
            and not (
                confidence_args.use_original_model_cache
                or confidence_args.transfer_weights
            )
            and orig_complex_graph.name[0] not in confidence_complex_dict.keys()
        ):
            skipped += 1
            print(
                f"HAPPENING | The confidence dataset did not contain {orig_complex_graph.name[0]}. We are skipping this complex."
            )
            continue
        try:
            data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)]
            randomize_position(
                data_list,
                score_model_args.no_torsion,
                no_random,
                score_model_args.tr_sigma_max,
            )
            pdb = None
            lig = orig_complex_graph.mol[0]
            visualization_list = []
            for graph in data_list:
                pdb = PDBFile(lig)
                pdb.add(lig, 0, 0)
                pdb.add(
                    (
                        orig_complex_graph["ligand"].pos
                        + orig_complex_graph.original_center
                    )
                    .detach()
                    .cpu(),
                    1,
                    0,
                )
                pdb.add(
                    (graph["ligand"].pos + graph.original_center).detach().cpu(),
                    part=1,
                    order=1,
                )
                visualization_list.append(pdb)

            start_time = time.time()
            if confidence_model is not None and not (
                confidence_args.use_original_model_cache
                or confidence_args.transfer_weights
            ):
                confidence_data_list = [
                    copy.deepcopy(confidence_complex_dict[orig_complex_graph.name[0]])
                    for _ in range(N)
                ]
            else:
                confidence_data_list = None

            data_list, confidence = sampling(
                data_list=data_list,
                model=model,
                inference_steps=10,
                tr_schedule=tr_schedule,
                rot_schedule=rot_schedule,
                tor_schedule=tor_schedule,
                device=device,
                t_to_sigma=t_to_sigma,
                model_args=score_model_args,
                no_random=no_random,
                ode=ode,
                visualization_list=visualization_list,
                confidence_model=confidence_model,
                confidence_data_list=confidence_data_list,
                confidence_model_args=confidence_args,
                batch_size=1,
                no_final_step_noise=no_final_step_noise,
            )
            ligand_pos = np.asarray(
                [
                    complex_graph["ligand"].pos.cpu().numpy()
                    + orig_complex_graph.original_center.cpu().numpy()
                    for complex_graph in data_list
                ]
            )
            run_times.append(time.time() - start_time)

            if confidence is not None and isinstance(
                confidence_args.rmsd_classification_cutoff, list
            ):
                confidence = confidence[:, 0]
            if confidence is not None:
                confidence = confidence.cpu().numpy()
                re_order = np.argsort(confidence)[::-1]
                confidence = confidence[re_order]
                confidences_list.append(confidence)
                ligand_pos = ligand_pos[re_order]
            write_dir = (
                f'{out_dir}/index{idx}_{data_list[0]["name"][0].replace("/","-")}'
            )
            os.makedirs(write_dir, exist_ok=True)
            for rank, pos in enumerate(ligand_pos):
                mol_pred = copy.deepcopy(lig)
                if score_model_args.remove_hs:
                    mol_pred = RemoveHs(mol_pred)
                if rank == 0:
                    write_mol_with_coords(
                        mol_pred, pos, os.path.join(write_dir, f"rank{rank+1}.sdf")
                    )
                write_mol_with_coords(
                    mol_pred,
                    pos,
                    os.path.join(
                        write_dir, f"rank{rank+1}_confidence{confidence[rank]:.2f}.sdf"
                    ),
                )
            self_distances = np.linalg.norm(
                ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1
            )
            self_distances = np.where(
                np.eye(self_distances.shape[2]), np.inf, self_distances
            )
            min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))

            filenames = []
            if confidence is not None:
                for rank, batch_idx in enumerate(re_order):
                    visualization_list[batch_idx].write(
                        os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
                    )
                    filenames.append(
                        os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
                    )
            else:
                for rank, batch_idx in enumerate(ligand_pos):
                    visualization_list[batch_idx].write(
                        os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
                    )
                    filenames.append(
                        os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
                    )
            names_list.append(orig_complex_graph.name[0])
        except Exception as e:
            print("Failed on", orig_complex_graph["name"], e)
            failures += 1
            return None

    labels = [f"rank {i+1}" for i in range(len(filenames))]
    return (
        molecule(pdb_path, filenames[0]),
        gr.Dropdown.update(choices=labels, value="rank 1"),
        filenames,
        pdb_path,
    )


def updateView(out, filenames, pdb):
    i = int(out.replace("rank", ""))
    return molecule(pdb, filenames[i])


demo = gr.Blocks()

with demo:
    gr.Markdown("# DiffDock")
    gr.Markdown(
        ">**DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking**, Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi, arXiv:2210.01776 [GitHub](https://github.com/gcorso/diffdock)"
    )
    gr.Markdown("Runs the diffusion model `10` times with `10` inference steps")
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Protein")
                inp = gr.Textbox(
                    placeholder="PDB Code or upload file below", label="Input structure"
                )
                file = gr.File(file_count="single", label="Input PDB")
            with gr.Column():
                gr.Markdown("## Ligand")
                ligand_inp = gr.Textbox(
                    placeholder="Provide SMILES input or upload mol2/sdf file below",
                    label="SMILES string",
                )
                ligand_file = gr.File(file_count="single", label="Input Ligand")

    btn = gr.Button("Run predictions")

    gr.Markdown("## Output")
    pdb = gr.Variable()
    filenames = gr.Variable()
    out = gr.Dropdown(interactive=True, label="Ranked samples")
    mol = gr.HTML()
    gr.Examples(
        [
            [
                None,
                "examples/1a46_protein_processed.pdb",
                None,
                "examples/1a46_ligand.sdf",
            ]
        ],
        [inp, file, ligand_inp, ligand_file],
        [mol, out],
        # cache_examples=True,
    )
    btn.click(
        fn=update,
        inputs=[inp, file, ligand_inp, ligand_file],
        outputs=[mol, out, filenames, pdb],
    )
    out.change(fn=updateView, inputs=[out, filenames, pdb], outputs=mol)
demo.launch()
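For reference, the update() callback above can also be driven without the UI. This is only a minimal sketch, assuming the checkpoints under workdir/ and the ESM weights are in place exactly as loaded at the top of app.py; the PDB code and SMILES are illustrative placeholders, not part of the commit:

# Hypothetical smoke test for the update() callback defined in app.py.
# "1a46" is fetched by get_pdb(); the SMILES string is passed through get_ligand().
html_view, dropdown_update, ranked_pdbs, protein_pdb = update(
    inp="1a46",
    file=None,
    ligand_inp="CC(=O)Oc1ccccc1C(=O)O",  # illustrative ligand SMILES
    ligand_file=None,
)
print(protein_pdb, ranked_pdbs[0])  # path of the best-ranked reverse-diffusion PDB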
datasets/esm_embedding_preparation.py
CHANGED

@@ -9,79 +9,80 @@ from Bio.SeqRecord import SeqRecord
 from tqdm import tqdm
 from Bio import SeqIO
 
-parser = ArgumentParser()
-parser.add_argument('--out_file', type=str, default="data/prepared_for_esm.fasta")
-parser.add_argument('--protein_ligand_csv', type=str, default='data/protein_ligand_example_csv.csv', help='Path to a .csv specifying the input as described in the main README')
-parser.add_argument('--protein_path', type=str, default=None, help='Path to a single PDB file. If this is not None then it will be used instead of the --protein_ligand_csv')
-args = parser.parse_args()
-
-biopython_parser = PDBParser()
-
-three_to_one = {'ALA': 'A',
-    'ARG': 'R',
-    'ASN': 'N',
-    'ASP': 'D',
-    'CYS': 'C',
-    'GLN': 'Q',
-    'GLU': 'E',
-    'GLY': 'G',
-    'HIS': 'H',
-    'ILE': 'I',
-    'LEU': 'L',
-    'LYS': 'K',
-    'MET': 'M',
-    'MSE': 'M',  # MSE this is almost the same AA as MET. The sulfur is just replaced by Selen
-    'PHE': 'F',
-    'PRO': 'P',
-    'PYL': 'O',
-    'SER': 'S',
-    'SEC': 'U',
-    'THR': 'T',
-    'TRP': 'W',
-    'TYR': 'Y',
-    'VAL': 'V',
-    'ASX': 'B',
-    'GLX': 'Z',
-    'XAA': 'X',
-    'XLE': 'J'}
-
-if args.protein_path is not None:
-    file_paths = [args.protein_path]
-else:
-    df = pd.read_csv(args.protein_ligand_csv)
-    file_paths = list(set(df['protein_path'].tolist()))
-sequences = []
-ids = []
-for file_path in tqdm(file_paths):
-    structure = biopython_parser.get_structure('random_id', file_path)
-    structure = structure[0]
-    for i, chain in enumerate(structure):
-        seq = ''
-        for res_idx, residue in enumerate(chain):
-            if residue.get_resname() == 'HOH':
-                continue
-            residue_coords = []
-            c_alpha, n, c = None, None, None
-            for atom in residue:
-                if atom.name == 'CA':
-                    c_alpha = list(atom.get_vector())
-                if atom.name == 'N':
-                    n = list(atom.get_vector())
-                if atom.name == 'C':
-                    c = list(atom.get_vector())
-            if c_alpha != None and n != None and c != None:  # only append residue if it is an amino acid
-                try:
-                    seq += three_to_one[residue.get_resname()]
-                except Exception as e:
-                    seq += '-'
-                    print("encountered unknown AA: ", residue.get_resname(), ' in the complex ', file_path, '. Replacing it with a dash - .')
-        sequences.append(seq)
-        ids.append(f'{os.path.basename(file_path)}_chain_{i}')
-records = []
-for (index, seq) in zip(ids,sequences):
-    record = SeqRecord(Seq(seq), str(index))
-    record.description = ''
-    records.append(record)
-SeqIO.write(records, args.out_file, "fasta")
+
+
+def esm_embedding_prep(out_file, protein_path):
+    biopython_parser = PDBParser()
+
+    three_to_one = {
+        "ALA": "A",
+        "ARG": "R",
+        "ASN": "N",
+        "ASP": "D",
+        "CYS": "C",
+        "GLN": "Q",
+        "GLU": "E",
+        "GLY": "G",
+        "HIS": "H",
+        "ILE": "I",
+        "LEU": "L",
+        "LYS": "K",
+        "MET": "M",
+        "MSE": "M",  # MSE this is almost the same AA as MET. The sulfur is just replaced by Selen
+        "PHE": "F",
+        "PRO": "P",
+        "PYL": "O",
+        "SER": "S",
+        "SEC": "U",
+        "THR": "T",
+        "TRP": "W",
+        "TYR": "Y",
+        "VAL": "V",
+        "ASX": "B",
+        "GLX": "Z",
+        "XAA": "X",
+        "XLE": "J",
+    }
+
+    file_paths = [protein_path]
+    sequences = []
+    ids = []
+    for file_path in tqdm(file_paths):
+        structure = biopython_parser.get_structure("random_id", file_path)
+        structure = structure[0]
+        for i, chain in enumerate(structure):
+            seq = ""
+            for res_idx, residue in enumerate(chain):
+                if residue.get_resname() == "HOH":
+                    continue
+                residue_coords = []
+                c_alpha, n, c = None, None, None
+                for atom in residue:
+                    if atom.name == "CA":
+                        c_alpha = list(atom.get_vector())
+                    if atom.name == "N":
+                        n = list(atom.get_vector())
+                    if atom.name == "C":
+                        c = list(atom.get_vector())
+                if (
+                    c_alpha != None and n != None and c != None
+                ):  # only append residue if it is an amino acid
+                    try:
+                        seq += three_to_one[residue.get_resname()]
+                    except Exception as e:
+                        seq += "-"
+                        print(
+                            "encountered unknown AA: ",
+                            residue.get_resname(),
+                            " in the complex ",
+                            file_path,
+                            ". Replacing it with a dash - .",
+                        )
+            sequences.append(seq)
+            ids.append(f"{os.path.basename(file_path)}_chain_{i}")
+    records = []
+    for (index, seq) in zip(ids, sequences):
+        record = SeqRecord(Seq(seq), str(index))
+        record.description = ""
+        records.append(record)
+    SeqIO.write(records, out_file, "fasta")
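The change above turns the former argparse script into an importable helper, which app.py calls before running the ESM extraction. A minimal sketch of a direct call; the two file paths are placeholders taken from the examples shipped in this commit:

# Illustrative use of the refactored helper; paths are placeholders.
from datasets.esm_embedding_preparation import esm_embedding_prep

esm_embedding_prep(
    out_file="data/1a46_protein_processed.pdb_prepared_for_esm.fasta",
    protein_path="examples/1a46_protein_processed.pdb",
)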
datasets/pdbbind.py
CHANGED
@@ -16,8 +16,15 @@ from torch_geometric.loader import DataLoader, DataListLoader
|
|
16 |
from torch_geometric.transforms import BaseTransform
|
17 |
from tqdm import tqdm
|
18 |
|
19 |
-
from datasets.process_mols import
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
from utils.diffusion_utils import modify_conformer, set_time
|
22 |
from utils.utils import read_strings_from_txt
|
23 |
from utils import so3, torus
|
@@ -34,32 +41,87 @@ class NoiseTransform(BaseTransform):
|
|
34 |
t_tr, t_rot, t_tor = t, t, t
|
35 |
return self.apply_noise(data, t_tr, t_rot, t_tor)
|
36 |
|
37 |
-
def apply_noise(
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
|
42 |
set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)
|
43 |
|
44 |
-
tr_update =
|
|
|
|
|
|
|
|
|
45 |
rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update
|
46 |
-
torsion_updates =
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
torsion_updates = None if self.no_torsion else torsion_updates
|
48 |
-
modify_conformer(
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
data.
|
53 |
-
data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
return data
|
55 |
|
56 |
|
57 |
class PDBBind(Dataset):
|
58 |
-
def __init__(
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
super(PDBBind, self).__init__(root, transform)
|
65 |
self.pdbbind_dir = root
|
@@ -75,37 +137,67 @@ class PDBBind(Dataset):
|
|
75 |
self.protein_path_list = protein_path_list
|
76 |
self.ligand_descriptions = ligand_descriptions
|
77 |
self.keep_local_structures = keep_local_structures
|
78 |
-
if
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
if all_atoms:
|
81 |
-
cache_path +=
|
82 |
-
self.full_cache_path = os.path.join(
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
self.popsize, self.maxiter = popsize, maxiter
|
92 |
self.matching, self.keep_original = matching, keep_original
|
93 |
self.num_conformers = num_conformers
|
94 |
self.all_atoms = all_atoms
|
95 |
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
|
96 |
-
if not os.path.exists(
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
os.makedirs(self.full_cache_path, exist_ok=True)
|
99 |
if protein_path_list is None or ligand_descriptions is None:
|
100 |
self.preprocessing()
|
101 |
else:
|
102 |
self.inference_preprocessing()
|
103 |
|
104 |
-
print(
|
105 |
-
|
|
|
|
|
|
|
106 |
self.complex_graphs = pickle.load(f)
|
107 |
if require_ligand:
|
108 |
-
with open(
|
|
|
|
|
109 |
self.rdkit_ligands = pickle.load(f)
|
110 |
|
111 |
print_statistics(self.complex_graphs)
|
@@ -122,18 +214,20 @@ class PDBBind(Dataset):
|
|
122 |
return copy.deepcopy(self.complex_graphs[idx])
|
123 |
|
124 |
def preprocessing(self):
|
125 |
-
print(
|
|
|
|
|
126 |
|
127 |
complex_names_all = read_strings_from_txt(self.split_path)
|
128 |
if self.limit_complexes is not None and self.limit_complexes != 0:
|
129 |
-
complex_names_all = complex_names_all[:self.limit_complexes]
|
130 |
-
print(f
|
131 |
|
132 |
if self.esm_embeddings_path is not None:
|
133 |
id_to_embeddings = torch.load(self.esm_embeddings_path)
|
134 |
chain_embeddings_dictlist = defaultdict(list)
|
135 |
for key, embedding in id_to_embeddings.items():
|
136 |
-
key_name = key.split(
|
137 |
if key_name in complex_names_all:
|
138 |
chain_embeddings_dictlist[key_name].append(embedding)
|
139 |
lm_embeddings_chains_all = []
|
@@ -144,58 +238,98 @@ class PDBBind(Dataset):
|
|
144 |
|
145 |
if self.num_workers > 1:
|
146 |
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
147 |
-
for i in range(len(complex_names_all)//1000+1):
|
148 |
-
if os.path.exists(
|
|
|
|
|
149 |
continue
|
150 |
-
complex_names = complex_names_all[1000*i:1000*(i+1)]
|
151 |
-
lm_embeddings_chains = lm_embeddings_chains_all[
|
|
|
|
|
152 |
complex_graphs, rdkit_ligands = [], []
|
153 |
if self.num_workers > 1:
|
154 |
p = Pool(self.num_workers, maxtasksperchild=1)
|
155 |
p.__enter__()
|
156 |
-
with tqdm(
|
|
|
|
|
|
|
157 |
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
158 |
-
for t in map_fn(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
complex_graphs.extend(t[0])
|
160 |
rdkit_ligands.extend(t[1])
|
161 |
pbar.update()
|
162 |
-
if self.num_workers > 1:
|
|
|
163 |
|
164 |
-
with open(
|
|
|
|
|
165 |
pickle.dump((complex_graphs), f)
|
166 |
-
with open(
|
|
|
|
|
167 |
pickle.dump((rdkit_ligands), f)
|
168 |
|
169 |
complex_graphs_all = []
|
170 |
-
for i in range(len(complex_names_all)//1000+1):
|
171 |
-
with open(
|
|
|
|
|
172 |
l = pickle.load(f)
|
173 |
complex_graphs_all.extend(l)
|
174 |
-
with open(
|
|
|
|
|
175 |
pickle.dump((complex_graphs_all), f)
|
176 |
|
177 |
rdkit_ligands_all = []
|
178 |
for i in range(len(complex_names_all) // 1000 + 1):
|
179 |
-
with open(
|
|
|
|
|
180 |
l = pickle.load(f)
|
181 |
rdkit_ligands_all.extend(l)
|
182 |
-
with open(
|
|
|
|
|
183 |
pickle.dump((rdkit_ligands_all), f)
|
184 |
else:
|
185 |
complex_graphs, rdkit_ligands = [], []
|
186 |
-
with tqdm(total=len(complex_names_all), desc=
|
187 |
-
for t in map(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
complex_graphs.extend(t[0])
|
189 |
rdkit_ligands.extend(t[1])
|
190 |
pbar.update()
|
191 |
-
with open(
|
|
|
|
|
192 |
pickle.dump((complex_graphs), f)
|
193 |
-
with open(
|
|
|
|
|
194 |
pickle.dump((rdkit_ligands), f)
|
195 |
|
196 |
def inference_preprocessing(self):
|
197 |
ligands_list = []
|
198 |
-
print(
|
199 |
for ligand_description in tqdm(self.ligand_descriptions):
|
200 |
mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
|
201 |
if mol is not None:
|
@@ -211,70 +345,126 @@ class PDBBind(Dataset):
|
|
211 |
ligands_list.append(mol)
|
212 |
|
213 |
if self.esm_embeddings_path is not None:
|
214 |
-
print(
|
215 |
lm_embeddings_chains_all = []
|
216 |
-
if not os.path.exists(self.esm_embeddings_path):
|
|
|
|
|
|
|
217 |
for protein_path in self.protein_path_list:
|
218 |
-
embeddings_paths = sorted(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
lm_embeddings_chains = []
|
220 |
for embeddings_path in embeddings_paths:
|
221 |
-
lm_embeddings_chains.append(
|
|
|
|
|
222 |
lm_embeddings_chains_all.append(lm_embeddings_chains)
|
223 |
else:
|
224 |
lm_embeddings_chains_all = [None] * len(self.protein_path_list)
|
225 |
|
226 |
-
print(
|
227 |
if self.num_workers > 1:
|
228 |
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
229 |
-
for i in range(len(self.protein_path_list)//1000+1):
|
230 |
-
if os.path.exists(
|
|
|
|
|
231 |
continue
|
232 |
-
protein_paths_chunk = self.protein_path_list[1000*i:1000*(i+1)]
|
233 |
-
ligand_description_chunk = self.ligand_descriptions[
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
236 |
complex_graphs, rdkit_ligands = [], []
|
237 |
if self.num_workers > 1:
|
238 |
p = Pool(self.num_workers, maxtasksperchild=1)
|
239 |
p.__enter__()
|
240 |
-
with tqdm(
|
|
|
|
|
|
|
241 |
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
242 |
-
for t in map_fn(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
complex_graphs.extend(t[0])
|
244 |
rdkit_ligands.extend(t[1])
|
245 |
pbar.update()
|
246 |
-
if self.num_workers > 1:
|
|
|
247 |
|
248 |
-
with open(
|
|
|
|
|
249 |
pickle.dump((complex_graphs), f)
|
250 |
-
with open(
|
|
|
|
|
251 |
pickle.dump((rdkit_ligands), f)
|
252 |
|
253 |
complex_graphs_all = []
|
254 |
-
for i in range(len(self.protein_path_list)//1000+1):
|
255 |
-
with open(
|
|
|
|
|
256 |
l = pickle.load(f)
|
257 |
complex_graphs_all.extend(l)
|
258 |
-
with open(
|
|
|
|
|
259 |
pickle.dump((complex_graphs_all), f)
|
260 |
|
261 |
rdkit_ligands_all = []
|
262 |
for i in range(len(self.protein_path_list) // 1000 + 1):
|
263 |
-
with open(
|
|
|
|
|
264 |
l = pickle.load(f)
|
265 |
rdkit_ligands_all.extend(l)
|
266 |
-
with open(
|
|
|
|
|
267 |
pickle.dump((rdkit_ligands_all), f)
|
268 |
else:
|
269 |
complex_graphs, rdkit_ligands = [], []
|
270 |
-
with tqdm(
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
complex_graphs.extend(t[0])
|
273 |
rdkit_ligands.extend(t[1])
|
274 |
pbar.update()
|
275 |
-
with open(
|
|
|
|
|
276 |
pickle.dump((complex_graphs), f)
|
277 |
-
with open(
|
|
|
|
|
278 |
pickle.dump((rdkit_ligands), f)
|
279 |
|
280 |
def get_complex(self, par):
|
@@ -285,51 +475,94 @@ class PDBBind(Dataset):
|
|
285 |
|
286 |
if ligand is not None:
|
287 |
rec_model = parse_pdb_from_path(name)
|
288 |
-
name = f
|
289 |
ligs = [ligand]
|
290 |
else:
|
291 |
try:
|
292 |
rec_model = parse_receptor(name, self.pdbbind_dir)
|
293 |
except Exception as e:
|
294 |
-
print(f
|
295 |
print(e)
|
296 |
return [], []
|
297 |
|
298 |
ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)
|
299 |
complex_graphs = []
|
300 |
for i, lig in enumerate(ligs):
|
301 |
-
if
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
303 |
continue
|
304 |
complex_graph = HeteroData()
|
305 |
-
complex_graph[
|
306 |
try:
|
307 |
-
get_lig_graph_with_matching(
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
continue
|
313 |
|
314 |
-
get_rec_graph(
|
315 |
-
|
316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
|
318 |
except Exception as e:
|
319 |
-
print(f
|
320 |
print(e)
|
321 |
raise e
|
322 |
continue
|
323 |
|
324 |
-
protein_center = torch.mean(
|
325 |
-
|
|
|
|
|
326 |
if self.all_atoms:
|
327 |
-
complex_graph[
|
328 |
|
329 |
if (not self.matching) or self.num_conformers == 1:
|
330 |
-
complex_graph[
|
331 |
else:
|
332 |
-
for p in complex_graph[
|
333 |
p -= protein_center
|
334 |
|
335 |
complex_graph.original_center = protein_center
|
@@ -341,11 +574,18 @@ def print_statistics(complex_graphs):
|
|
341 |
statistics = ([], [], [], [])
|
342 |
|
343 |
for complex_graph in complex_graphs:
|
344 |
-
lig_pos =
|
345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
molecule_center = torch.mean(lig_pos, dim=0)
|
347 |
radius_molecule = torch.max(
|
348 |
-
torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1)
|
|
|
349 |
distance_center = torch.linalg.vector_norm(molecule_center)
|
350 |
statistics[0].append(radius_protein)
|
351 |
statistics[1].append(radius_molecule)
|
@@ -355,52 +595,111 @@ def print_statistics(complex_graphs):
|
|
355 |
else:
|
356 |
statistics[3].append(0)
|
357 |
|
358 |
-
name = [
|
359 |
-
|
|
|
|
|
|
|
|
|
|
|
360 |
for i in range(4):
|
361 |
array = np.asarray(statistics[i])
|
362 |
-
print(
|
|
|
|
|
363 |
|
364 |
|
365 |
def construct_loader(args, t_to_sigma):
|
366 |
-
transform = NoiseTransform(
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
|
382 |
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
|
383 |
-
train_loader = loader_class(
|
384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
return train_loader, val_loader
|
387 |
|
388 |
|
389 |
def read_mol(pdbbind_dir, name, remove_hs=False):
|
390 |
-
lig = read_molecule(
|
|
|
|
|
|
|
|
|
391 |
if lig is None: # read mol2 file if sdf file cannot be sanitized
|
392 |
-
lig = read_molecule(
|
|
|
|
|
|
|
|
|
393 |
return lig
|
394 |
|
395 |
|
396 |
def read_mols(pdbbind_dir, name, remove_hs=False):
|
397 |
ligs = []
|
398 |
for file in os.listdir(os.path.join(pdbbind_dir, name)):
|
399 |
-
if file.endswith(".sdf") and
|
400 |
-
lig = read_molecule(
|
401 |
-
|
402 |
-
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
if lig is not None:
|
405 |
ligs.append(lig)
|
406 |
-
return ligs
|
|
|
16 |
from torch_geometric.transforms import BaseTransform
|
17 |
from tqdm import tqdm
|
18 |
|
19 |
+
from datasets.process_mols import (
|
20 |
+
read_molecule,
|
21 |
+
get_rec_graph,
|
22 |
+
generate_conformer,
|
23 |
+
get_lig_graph_with_matching,
|
24 |
+
extract_receptor_structure,
|
25 |
+
parse_receptor,
|
26 |
+
parse_pdb_from_path,
|
27 |
+
)
|
28 |
from utils.diffusion_utils import modify_conformer, set_time
|
29 |
from utils.utils import read_strings_from_txt
|
30 |
from utils import so3, torus
|
|
|
41 |
t_tr, t_rot, t_tor = t, t, t
|
42 |
return self.apply_noise(data, t_tr, t_rot, t_tor)
|
43 |
|
44 |
+
def apply_noise(
|
45 |
+
self,
|
46 |
+
data,
|
47 |
+
t_tr,
|
48 |
+
t_rot,
|
49 |
+
t_tor,
|
50 |
+
tr_update=None,
|
51 |
+
rot_update=None,
|
52 |
+
torsion_updates=None,
|
53 |
+
):
|
54 |
+
if not torch.is_tensor(data["ligand"].pos):
|
55 |
+
data["ligand"].pos = random.choice(data["ligand"].pos)
|
56 |
|
57 |
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
|
58 |
set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)
|
59 |
|
60 |
+
tr_update = (
|
61 |
+
torch.normal(mean=0, std=tr_sigma, size=(1, 3))
|
62 |
+
if tr_update is None
|
63 |
+
else tr_update
|
64 |
+
)
|
65 |
rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update
|
66 |
+
torsion_updates = (
|
67 |
+
np.random.normal(
|
68 |
+
loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()
|
69 |
+
)
|
70 |
+
if torsion_updates is None
|
71 |
+
else torsion_updates
|
72 |
+
)
|
73 |
torsion_updates = None if self.no_torsion else torsion_updates
|
74 |
+
modify_conformer(
|
75 |
+
data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates
|
76 |
+
)
|
77 |
+
|
78 |
+
data.tr_score = -tr_update / tr_sigma**2
|
79 |
+
data.rot_score = (
|
80 |
+
torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma))
|
81 |
+
.float()
|
82 |
+
.unsqueeze(0)
|
83 |
+
)
|
84 |
+
data.tor_score = (
|
85 |
+
None
|
86 |
+
if self.no_torsion
|
87 |
+
else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float()
|
88 |
+
)
|
89 |
+
data.tor_sigma_edge = (
|
90 |
+
None
|
91 |
+
if self.no_torsion
|
92 |
+
else np.ones(data["ligand"].edge_mask.sum()) * tor_sigma
|
93 |
+
)
|
94 |
return data
|
95 |
|
96 |
|
97 |
class PDBBind(Dataset):
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
root,
|
101 |
+
transform=None,
|
102 |
+
cache_path="data/cache",
|
103 |
+
split_path="data/",
|
104 |
+
limit_complexes=0,
|
105 |
+
receptor_radius=30,
|
106 |
+
num_workers=1,
|
107 |
+
c_alpha_max_neighbors=None,
|
108 |
+
popsize=15,
|
109 |
+
maxiter=15,
|
110 |
+
matching=True,
|
111 |
+
keep_original=False,
|
112 |
+
max_lig_size=None,
|
113 |
+
remove_hs=False,
|
114 |
+
num_conformers=1,
|
115 |
+
all_atoms=False,
|
116 |
+
atom_radius=5,
|
117 |
+
atom_max_neighbors=None,
|
118 |
+
esm_embeddings_path=None,
|
119 |
+
require_ligand=False,
|
120 |
+
ligands_list=None,
|
121 |
+
protein_path_list=None,
|
122 |
+
ligand_descriptions=None,
|
123 |
+
keep_local_structures=False,
|
124 |
+
):
|
125 |
|
126 |
super(PDBBind, self).__init__(root, transform)
|
127 |
self.pdbbind_dir = root
|
|
|
137 |
self.protein_path_list = protein_path_list
|
138 |
self.ligand_descriptions = ligand_descriptions
|
139 |
self.keep_local_structures = keep_local_structures
|
140 |
+
if (
|
141 |
+
matching
|
142 |
+
or protein_path_list is not None
|
143 |
+
and ligand_descriptions is not None
|
144 |
+
):
|
145 |
+
cache_path += "_torsion"
|
146 |
if all_atoms:
|
147 |
+
cache_path += "_allatoms"
|
148 |
+
self.full_cache_path = os.path.join(
|
149 |
+
cache_path,
|
150 |
+
f"limit{self.limit_complexes}"
|
151 |
+
f"_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}"
|
152 |
+
f"_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}"
|
153 |
+
f"_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}"
|
154 |
+
+ (
|
155 |
+
""
|
156 |
+
if not all_atoms
|
157 |
+
else f"_atomRad{atom_radius}_atomMax{atom_max_neighbors}"
|
158 |
+
)
|
159 |
+
+ ("" if not matching or num_conformers == 1 else f"_confs{num_conformers}")
|
160 |
+
+ ("" if self.esm_embeddings_path is None else f"_esmEmbeddings")
|
161 |
+
+ ("" if not keep_local_structures else f"_keptLocalStruct")
|
162 |
+
+ (
|
163 |
+
""
|
164 |
+
if protein_path_list is None or ligand_descriptions is None
|
165 |
+
else str(
|
166 |
+
binascii.crc32(
|
167 |
+
"".join(ligand_descriptions + protein_path_list).encode()
|
168 |
+
)
|
169 |
+
)
|
170 |
+
),
|
171 |
+
)
|
172 |
self.popsize, self.maxiter = popsize, maxiter
|
173 |
self.matching, self.keep_original = matching, keep_original
|
174 |
self.num_conformers = num_conformers
|
175 |
self.all_atoms = all_atoms
|
176 |
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
|
177 |
+
if not os.path.exists(
|
178 |
+
os.path.join(self.full_cache_path, "heterographs.pkl")
|
179 |
+
) or (
|
180 |
+
require_ligand
|
181 |
+
and not os.path.exists(
|
182 |
+
os.path.join(self.full_cache_path, "rdkit_ligands.pkl")
|
183 |
+
)
|
184 |
+
):
|
185 |
os.makedirs(self.full_cache_path, exist_ok=True)
|
186 |
if protein_path_list is None or ligand_descriptions is None:
|
187 |
self.preprocessing()
|
188 |
else:
|
189 |
self.inference_preprocessing()
|
190 |
|
191 |
+
print(
|
192 |
+
"loading data from memory: ",
|
193 |
+
os.path.join(self.full_cache_path, "heterographs.pkl"),
|
194 |
+
)
|
195 |
+
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), "rb") as f:
|
196 |
self.complex_graphs = pickle.load(f)
|
197 |
if require_ligand:
|
198 |
+
with open(
|
199 |
+
os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), "rb"
|
200 |
+
) as f:
|
201 |
self.rdkit_ligands = pickle.load(f)
|
202 |
|
203 |
print_statistics(self.complex_graphs)
|
|
|
214 |
return copy.deepcopy(self.complex_graphs[idx])
|
215 |
|
216 |
def preprocessing(self):
|
217 |
+
print(
|
218 |
+
f"Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]"
|
219 |
+
)
|
220 |
|
221 |
complex_names_all = read_strings_from_txt(self.split_path)
|
222 |
if self.limit_complexes is not None and self.limit_complexes != 0:
|
223 |
+
complex_names_all = complex_names_all[: self.limit_complexes]
|
224 |
+
print(f"Loading {len(complex_names_all)} complexes.")
|
225 |
|
226 |
if self.esm_embeddings_path is not None:
|
227 |
id_to_embeddings = torch.load(self.esm_embeddings_path)
|
228 |
chain_embeddings_dictlist = defaultdict(list)
|
229 |
for key, embedding in id_to_embeddings.items():
|
230 |
+
key_name = key.split("_")[0]
|
231 |
if key_name in complex_names_all:
|
232 |
chain_embeddings_dictlist[key_name].append(embedding)
|
233 |
lm_embeddings_chains_all = []
|
|
|
238 |
|
239 |
if self.num_workers > 1:
|
240 |
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
241 |
+
for i in range(len(complex_names_all) // 1000 + 1):
|
242 |
+
if os.path.exists(
|
243 |
+
os.path.join(self.full_cache_path, f"heterographs{i}.pkl")
|
244 |
+
):
|
245 |
continue
|
246 |
+
complex_names = complex_names_all[1000 * i : 1000 * (i + 1)]
|
247 |
+
lm_embeddings_chains = lm_embeddings_chains_all[
|
248 |
+
1000 * i : 1000 * (i + 1)
|
249 |
+
]
|
250 |
complex_graphs, rdkit_ligands = [], []
|
251 |
if self.num_workers > 1:
|
252 |
p = Pool(self.num_workers, maxtasksperchild=1)
|
253 |
p.__enter__()
|
254 |
+
with tqdm(
|
255 |
+
total=len(complex_names),
|
256 |
+
desc=f"loading complexes {i}/{len(complex_names_all)//1000+1}",
|
257 |
+
) as pbar:
|
258 |
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
259 |
+
for t in map_fn(
|
260 |
+
self.get_complex,
|
261 |
+
zip(
|
262 |
+
complex_names,
|
263 |
+
lm_embeddings_chains,
|
264 |
+
[None] * len(complex_names),
|
265 |
+
[None] * len(complex_names),
|
266 |
+
),
|
267 |
+
):
|
268 |
complex_graphs.extend(t[0])
|
269 |
rdkit_ligands.extend(t[1])
|
270 |
pbar.update()
|
271 |
+
if self.num_workers > 1:
|
272 |
+
p.__exit__(None, None, None)
|
273 |
|
274 |
+
with open(
|
275 |
+
os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "wb"
|
276 |
+
) as f:
|
277 |
pickle.dump((complex_graphs), f)
|
278 |
+
with open(
|
279 |
+
os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "wb"
|
280 |
+
) as f:
|
281 |
pickle.dump((rdkit_ligands), f)
|
282 |
|
283 |
complex_graphs_all = []
|
284 |
+
for i in range(len(complex_names_all) // 1000 + 1):
|
285 |
+
with open(
|
286 |
+
os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "rb"
|
287 |
+
) as f:
|
288 |
l = pickle.load(f)
|
289 |
complex_graphs_all.extend(l)
|
290 |
+
with open(
|
291 |
+
os.path.join(self.full_cache_path, f"heterographs.pkl"), "wb"
|
292 |
+
) as f:
|
293 |
pickle.dump((complex_graphs_all), f)
|
294 |
|
295 |
rdkit_ligands_all = []
|
296 |
for i in range(len(complex_names_all) // 1000 + 1):
|
297 |
+
with open(
|
298 |
+
os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "rb"
|
299 |
+
) as f:
|
300 |
l = pickle.load(f)
|
301 |
rdkit_ligands_all.extend(l)
|
302 |
+
with open(
|
303 |
+
os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), "wb"
|
304 |
+
) as f:
|
305 |
pickle.dump((rdkit_ligands_all), f)
|
306 |
else:
|
307 |
complex_graphs, rdkit_ligands = [], []
|
308 |
+
with tqdm(total=len(complex_names_all), desc="loading complexes") as pbar:
|
309 |
+
for t in map(
|
310 |
+
self.get_complex,
|
311 |
+
zip(
|
312 |
+
complex_names_all,
|
313 |
+
lm_embeddings_chains_all,
|
314 |
+
[None] * len(complex_names_all),
|
315 |
+
[None] * len(complex_names_all),
|
316 |
+
),
|
317 |
+
):
|
318 |
complex_graphs.extend(t[0])
|
319 |
rdkit_ligands.extend(t[1])
|
320 |
pbar.update()
|
321 |
+
with open(
|
322 |
+
os.path.join(self.full_cache_path, "heterographs.pkl"), "wb"
|
323 |
+
) as f:
|
324 |
pickle.dump((complex_graphs), f)
|
325 |
+
with open(
|
326 |
+
os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), "wb"
|
327 |
+
) as f:
|
328 |
pickle.dump((rdkit_ligands), f)
|
329 |
|
330 |
def inference_preprocessing(self):
|
331 |
ligands_list = []
|
332 |
+
print("Reading molecules and generating local structures with RDKit")
|
333 |
for ligand_description in tqdm(self.ligand_descriptions):
|
334 |
mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
|
335 |
if mol is not None:
|
|
|
345 |
ligands_list.append(mol)
|
346 |
|
347 |
if self.esm_embeddings_path is not None:
|
348 |
+
print("Reading language model embeddings.")
|
349 |
lm_embeddings_chains_all = []
|
350 |
+
if not os.path.exists(self.esm_embeddings_path):
|
351 |
+
raise Exception(
|
352 |
+
"ESM embeddings path does not exist: ", self.esm_embeddings_path
|
353 |
+
)
|
354 |
for protein_path in self.protein_path_list:
|
355 |
+
embeddings_paths = sorted(
|
356 |
+
glob.glob(
|
357 |
+
os.path.join(
|
358 |
+
self.esm_embeddings_path, os.path.basename(protein_path)
|
359 |
+
)
|
360 |
+
+ "*"
|
361 |
+
)
|
362 |
+
)
|
363 |
lm_embeddings_chains = []
|
364 |
for embeddings_path in embeddings_paths:
|
365 |
+
lm_embeddings_chains.append(
|
366 |
+
torch.load(embeddings_path)["representations"][33]
|
367 |
+
)
|
368 |
lm_embeddings_chains_all.append(lm_embeddings_chains)
|
369 |
else:
|
370 |
lm_embeddings_chains_all = [None] * len(self.protein_path_list)
|
371 |
|
372 |
+
print("Generating graphs for ligands and proteins")
|
373 |
if self.num_workers > 1:
|
374 |
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
375 |
+
for i in range(len(self.protein_path_list) // 1000 + 1):
|
376 |
+
if os.path.exists(
|
377 |
+
os.path.join(self.full_cache_path, f"heterographs{i}.pkl")
|
378 |
+
):
|
379 |
continue
|
380 |
+
protein_paths_chunk = self.protein_path_list[1000 * i : 1000 * (i + 1)]
|
381 |
+
ligand_description_chunk = self.ligand_descriptions[
|
382 |
+
1000 * i : 1000 * (i + 1)
|
383 |
+
]
|
384 |
+
ligands_chunk = ligands_list[1000 * i : 1000 * (i + 1)]
|
385 |
+
lm_embeddings_chains = lm_embeddings_chains_all[
|
386 |
+
1000 * i : 1000 * (i + 1)
|
387 |
+
]
|
388 |
complex_graphs, rdkit_ligands = [], []
|
389 |
if self.num_workers > 1:
|
390 |
p = Pool(self.num_workers, maxtasksperchild=1)
|
391 |
p.__enter__()
|
392 |
+
with tqdm(
|
393 |
+
total=len(protein_paths_chunk),
|
394 |
+
desc=f"loading complexes {i}/{len(protein_paths_chunk)//1000+1}",
|
395 |
+
) as pbar:
|
396 |
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
397 |
+
for t in map_fn(
|
398 |
+
self.get_complex,
|
399 |
+
zip(
|
400 |
+
protein_paths_chunk,
|
401 |
+
lm_embeddings_chains,
|
402 |
+
ligands_chunk,
|
403 |
+
ligand_description_chunk,
|
404 |
+
),
|
405 |
+
):
|
406 |
complex_graphs.extend(t[0])
|
407 |
rdkit_ligands.extend(t[1])
|
408 |
pbar.update()
|
409 |
+
if self.num_workers > 1:
|
410 |
+
p.__exit__(None, None, None)
|
411 |
|
412 |
+
with open(
|
413 |
+
os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "wb"
|
414 |
+
) as f:
|
415 |
pickle.dump((complex_graphs), f)
|
416 |
+
with open(
|
417 |
+
os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "wb"
|
418 |
+
) as f:
|
419 |
pickle.dump((rdkit_ligands), f)
|
420 |
|
421 |
complex_graphs_all = []
|
422 |
+
for i in range(len(self.protein_path_list) // 1000 + 1):
|
423 |
+
with open(
|
424 |
+
os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "rb"
|
425 |
+
) as f:
|
426 |
l = pickle.load(f)
|
427 |
complex_graphs_all.extend(l)
|
428 |
+
with open(
|
429 |
+
os.path.join(self.full_cache_path, f"heterographs.pkl"), "wb"
|
430 |
+
) as f:
|
431 |
pickle.dump((complex_graphs_all), f)
|
432 |
|
433 |
rdkit_ligands_all = []
|
434 |
for i in range(len(self.protein_path_list) // 1000 + 1):
|
435 |
+
with open(
|
436 |
+
os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "rb"
|
437 |
+
) as f:
|
438 |
l = pickle.load(f)
|
439 |
rdkit_ligands_all.extend(l)
|
440 |
+
with open(
|
441 |
+
os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), "wb"
|
442 |
+
) as f:
|
443 |
pickle.dump((rdkit_ligands_all), f)
|
444 |
else:
|
445 |
complex_graphs, rdkit_ligands = [], []
|
446 |
+
with tqdm(
|
447 |
+
total=len(self.protein_path_list), desc="loading complexes"
|
448 |
+
) as pbar:
|
449 |
+
for t in map(
|
450 |
+
self.get_complex,
|
451 |
+
zip(
|
452 |
+
self.protein_path_list,
|
453 |
+
lm_embeddings_chains_all,
|
454 |
+
ligands_list,
|
455 |
+
self.ligand_descriptions,
|
456 |
+
),
|
457 |
+
):
|
458 |
complex_graphs.extend(t[0])
|
459 |
rdkit_ligands.extend(t[1])
|
460 |
pbar.update()
|
461 |
+
with open(
|
462 |
+
os.path.join(self.full_cache_path, "heterographs.pkl"), "wb"
|
463 |
+
) as f:
|
464 |
pickle.dump((complex_graphs), f)
|
465 |
+
with open(
|
466 |
+
os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), "wb"
|
467 |
+
) as f:
|
468 |
pickle.dump((rdkit_ligands), f)
|
469 |
|
470 |
def get_complex(self, par):
|
|
|
475 |
|
476 |
if ligand is not None:
|
477 |
rec_model = parse_pdb_from_path(name)
|
478 |
+
name = f"{name}____{ligand_description}"
|
479 |
ligs = [ligand]
|
480 |
else:
|
481 |
try:
|
482 |
rec_model = parse_receptor(name, self.pdbbind_dir)
|
483 |
except Exception as e:
|
484 |
+
print(f"Skipping {name} because of the error:")
|
485 |
print(e)
|
486 |
return [], []
|
487 |
|
488 |
ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)
|
489 |
complex_graphs = []
|
490 |
for i, lig in enumerate(ligs):
|
491 |
+
if (
|
492 |
+
self.max_lig_size is not None
|
493 |
+
and lig.GetNumHeavyAtoms() > self.max_lig_size
|
494 |
+
):
|
495 |
+
print(
|
496 |
+
f"Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data."
|
497 |
+
)
|
498 |
continue
|
499 |
complex_graph = HeteroData()
|
500 |
+
complex_graph["name"] = name
|
501 |
try:
|
502 |
+
get_lig_graph_with_matching(
|
503 |
+
lig,
|
504 |
+
complex_graph,
|
505 |
+
self.popsize,
|
506 |
+
self.maxiter,
|
507 |
+
self.matching,
|
508 |
+
self.keep_original,
|
509 |
+
self.num_conformers,
|
510 |
+
remove_hs=self.remove_hs,
|
511 |
+
)
|
512 |
+
print(lm_embedding_chains)
|
513 |
+
(
|
514 |
+
rec,
|
515 |
+
rec_coords,
|
516 |
+
c_alpha_coords,
|
517 |
+
n_coords,
|
518 |
+
c_coords,
|
519 |
+
lm_embeddings,
|
520 |
+
) = extract_receptor_structure(
|
521 |
+
copy.deepcopy(rec_model),
|
522 |
+
lig,
|
523 |
+
lm_embedding_chains=lm_embedding_chains,
|
524 |
+
)
|
525 |
+
if lm_embeddings is not None and len(c_alpha_coords) != len(
|
526 |
+
lm_embeddings
|
527 |
+
):
|
528 |
+
print(
|
529 |
+
f"LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}."
|
530 |
+
)
|
531 |
continue
|
532 |
|
533 |
+
get_rec_graph(
|
534 |
+
rec,
|
535 |
+
rec_coords,
|
536 |
+
c_alpha_coords,
|
537 |
+
n_coords,
|
538 |
+
c_coords,
|
539 |
+
complex_graph,
|
540 |
+
rec_radius=self.receptor_radius,
|
541 |
+
c_alpha_max_neighbors=self.c_alpha_max_neighbors,
|
542 |
+
all_atoms=self.all_atoms,
|
543 |
+
atom_radius=self.atom_radius,
|
544 |
+
atom_max_neighbors=self.atom_max_neighbors,
|
545 |
+
remove_hs=self.remove_hs,
|
546 |
+
lm_embeddings=lm_embeddings,
|
547 |
+
)
|
548 |
|
549 |
except Exception as e:
|
550 |
+
print(f"Skipping {name} because of the error:")
|
551 |
print(e)
|
552 |
raise e
|
553 |
continue
|
554 |
|
555 |
+
protein_center = torch.mean(
|
556 |
+
complex_graph["receptor"].pos, dim=0, keepdim=True
|
557 |
+
)
|
558 |
+
complex_graph["receptor"].pos -= protein_center
|
559 |
if self.all_atoms:
|
560 |
+
complex_graph["atom"].pos -= protein_center
|
561 |
|
562 |
if (not self.matching) or self.num_conformers == 1:
|
563 |
+
complex_graph["ligand"].pos -= protein_center
|
564 |
else:
|
565 |
+
for p in complex_graph["ligand"].pos:
|
566 |
p -= protein_center
|
567 |
|
568 |
complex_graph.original_center = protein_center
|
|
|
574 |
statistics = ([], [], [], [])
|
575 |
|
576 |
for complex_graph in complex_graphs:
|
577 |
+
lig_pos = (
|
578 |
+
complex_graph["ligand"].pos
|
579 |
+
if torch.is_tensor(complex_graph["ligand"].pos)
|
580 |
+
else complex_graph["ligand"].pos[0]
|
581 |
+
)
|
582 |
+
radius_protein = torch.max(
|
583 |
+
torch.linalg.vector_norm(complex_graph["receptor"].pos, dim=1)
|
584 |
+
)
|
585 |
molecule_center = torch.mean(lig_pos, dim=0)
|
586 |
radius_molecule = torch.max(
|
587 |
+
torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1)
|
588 |
+
)
|
589 |
distance_center = torch.linalg.vector_norm(molecule_center)
|
590 |
statistics[0].append(radius_protein)
|
591 |
statistics[1].append(radius_molecule)
|
|
|
595 |
else:
|
596 |
statistics[3].append(0)
|
597 |
|
598 |
+
name = [
|
599 |
+
"radius protein",
|
600 |
+
"radius molecule",
|
601 |
+
"distance protein-mol",
|
602 |
+
"rmsd matching",
|
603 |
+
]
|
604 |
+
print("Number of complexes: ", len(complex_graphs))
|
605 |
for i in range(4):
|
606 |
array = np.asarray(statistics[i])
|
607 |
+
print(
|
608 |
+
f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}"
|
609 |
+
)
|
610 |
|
611 |
|
612 |
 def construct_loader(args, t_to_sigma):
+    transform = NoiseTransform(
+        t_to_sigma=t_to_sigma, no_torsion=args.no_torsion, all_atom=args.all_atoms
+    )
+
+    common_args = {
+        "transform": transform,
+        "root": args.data_dir,
+        "limit_complexes": args.limit_complexes,
+        "receptor_radius": args.receptor_radius,
+        "c_alpha_max_neighbors": args.c_alpha_max_neighbors,
+        "remove_hs": args.remove_hs,
+        "max_lig_size": args.max_lig_size,
+        "matching": not args.no_torsion,
+        "popsize": args.matching_popsize,
+        "maxiter": args.matching_maxiter,
+        "num_workers": args.num_workers,
+        "all_atoms": args.all_atoms,
+        "atom_radius": args.atom_radius,
+        "atom_max_neighbors": args.atom_max_neighbors,
+        "esm_embeddings_path": args.esm_embeddings_path,
+    }
+
+    train_dataset = PDBBind(
+        cache_path=args.cache_path,
+        split_path=args.split_train,
+        keep_original=True,
+        num_conformers=args.num_conformers,
+        **common_args,
+    )
+    val_dataset = PDBBind(
+        cache_path=args.cache_path,
+        split_path=args.split_val,
+        keep_original=True,
+        **common_args,
+    )

     loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
+    train_loader = loader_class(
+        dataset=train_dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_dataloader_workers,
+        shuffle=True,
+        pin_memory=args.pin_memory,
+    )
+    val_loader = loader_class(
+        dataset=val_dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_dataloader_workers,
+        shuffle=True,
+        pin_memory=args.pin_memory,
+    )

     return train_loader, val_loader
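For orientation, construct_loader only reads plain attributes off args, so it can be driven from a simple Namespace. The sketch below is illustrative only: every value and path is a placeholder assumption, not a setting taken from this Space.

from argparse import Namespace

args = Namespace(
    # dataset / graph construction (paths and split files are assumed)
    data_dir="data/PDBBind_processed", cache_path="data/cache",
    split_train="data/splits/train", split_val="data/splits/val",
    limit_complexes=0, receptor_radius=30.0, c_alpha_max_neighbors=10,
    remove_hs=True, max_lig_size=None, no_torsion=False,
    matching_popsize=20, matching_maxiter=20, num_workers=1,
    all_atoms=False, atom_radius=5.0, atom_max_neighbors=8,
    esm_embeddings_path=None, num_conformers=1,
    # loader settings
    batch_size=4, num_dataloader_workers=1, pin_memory=False,
)

# With a local PDBBind copy in place, the two loaders are then built with:
# train_loader, val_loader = construct_loader(args, t_to_sigma)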
 def read_mol(pdbbind_dir, name, remove_hs=False):
+    lig = read_molecule(
+        os.path.join(pdbbind_dir, name, f"{name}_ligand.sdf"),
+        remove_hs=remove_hs,
+        sanitize=True,
+    )
     if lig is None:  # read mol2 file if sdf file cannot be sanitized
+        lig = read_molecule(
+            os.path.join(pdbbind_dir, name, f"{name}_ligand.mol2"),
+            remove_hs=remove_hs,
+            sanitize=True,
+        )
     return lig


 def read_mols(pdbbind_dir, name, remove_hs=False):
     ligs = []
     for file in os.listdir(os.path.join(pdbbind_dir, name)):
+        if file.endswith(".sdf") and "rdkit" not in file:
+            lig = read_molecule(
+                os.path.join(pdbbind_dir, name, file),
+                remove_hs=remove_hs,
+                sanitize=True,
+            )
+            if lig is None and os.path.exists(
+                os.path.join(pdbbind_dir, name, file[:-4] + ".mol2")
+            ):  # read mol2 file if sdf file cannot be sanitized
+                print(
+                    "Using the .sdf file failed. We found a .mol2 file instead and are trying to use that."
+                )
+                lig = read_molecule(
+                    os.path.join(pdbbind_dir, name, file[:-4] + ".mol2"),
+                    remove_hs=remove_hs,
+                    sanitize=True,
+                )
             if lig is not None:
                 ligs.append(lig)
+    return ligs
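A hedged usage sketch of the two helpers above on a PDBBind-style folder layout (pdbbind_dir/<complex>/<complex>_ligand.sdf or .mol2). The directory and complex name are assumptions and require a local copy of the data.

pdbbind_dir = "data/PDBBind_processed"   # assumed download location
complex_name = "1a46"                    # assumed complex folder

lig = read_mol(pdbbind_dir, complex_name, remove_hs=True)
ligs = read_mols(pdbbind_dir, complex_name, remove_hs=True)
if lig is not None:
    print(complex_name, "->", lig.GetNumAtoms(), "atoms;", len(ligs), "ligand file(s) found")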
datasets/process_mols.py
CHANGED
@@ -490,8 +490,10 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
     if molecule_file.endswith('.mol2'):
         mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
     elif molecule_file.endswith('.sdf'):
+        print(molecule_file)
         supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
         mol = supplier[0]
+        print(mol)
     elif molecule_file.endswith('.pdbqt'):
         with open(molecule_file) as file:
             pdbqt_data = file.readlines()
@@ -505,6 +507,8 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
         return ValueError('Expect the format of the molecule_file to be '
                           'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

+    print(sanitize, calc_charges, remove_hs)
+
     try:
         if sanitize or calc_charges:
             Chem.SanitizeMol(mol)
@@ -518,7 +522,8 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F

         if remove_hs:
             mol = Chem.RemoveHs(mol, sanitize=sanitize)
-    except:
+    except Exception as e:
+        print(e)
         return None

     return mol
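The change above replaces a bare except with one that prints the RDKit error, which is what actually surfaces sanitization failures in the Space's logs. Below is a small standalone sketch of the same pattern (not code from the repository), using an in-memory molecule with a pentavalent-nitrogen SMILES that RDKit refuses to sanitize, so no input file is needed.

from rdkit import Chem

mol = Chem.MolFromSmiles("CN(=O)=O", sanitize=False)  # parses, but fails valence rules
try:
    Chem.SanitizeMol(mol)
    mol = Chem.RemoveHs(mol, sanitize=True)
except Exception as e:
    print(e)   # the explicit-valence error is printed instead of failing silently
    mol = None

print(mol)     # None -> the caller falls back to the .mol2 file, as read_mols does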
examples/1a46_ligand.sdf
ADDED
@@ -0,0 +1,179 @@
1a46_ligand
-I-interpret-

 85 88  0  0  0  0  0  0  0  0999 V2000
    [V2000 connection table omitted: 85 atom records and 88 bond records]
M  END
$$$$
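The example ligand added above can be checked locally the same way read_molecule reads .sdf input, with sanitization deferred. The path assumes the script is run from the repository root; the expected counts come from the header line of the file.

from rdkit import Chem

supplier = Chem.SDMolSupplier("examples/1a46_ligand.sdf", sanitize=False, removeHs=False)
mol = supplier[0]
print(mol.GetNumAtoms(), "atoms,", mol.GetNumBonds(), "bonds")  # header says 85 and 88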
examples/1a46_protein_processed.pdb
ADDED
The diff for this file is too large to render.
See raw diff
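The matching receptor file is too large for the diff view; a quick way to inspect it locally is biopandas (pinned in requirements.txt below). This is a sketch under the assumption that it is run from the repository root, not code from the Space itself.

from biopandas.pdb import PandasPdb

ppdb = PandasPdb().read_pdb("examples/1a46_protein_processed.pdb")
atoms = ppdb.df["ATOM"]
print(len(atoms), "atom records,", atoms["chain_id"].nunique(), "chain(s)")
print(atoms[["x_coord", "y_coord", "z_coord"]].mean())  # rough receptor centroid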
examples/1cbr_ligand.sdf
ADDED
@@ -0,0 +1,119 @@
1cbr_ligand

Created by X-TOOL on Fri Nov 18 12:01:53 2016
 49 49  0  0  0  0  0  0  0  0999 V2000
    [V2000 connection table omitted: 49 atom records and 49 bond records]
M  END
> <MOLECULAR_FORMULA>
C20H27O2

> <MOLECULAR_WEIGHT>
299.2

> <NUM_HB_ATOMS>
2

> <NUM_ROTOR>
0

> <XLOGP2>
3.40

$$$$
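Unlike the first example, this ligand carries SD data fields (shown above); RDKit attaches them to the molecule as string properties. A short sketch, again assuming the file is read from the repository root.

from rdkit import Chem

mol = Chem.SDMolSupplier("examples/1cbr_ligand.sdf", sanitize=False, removeHs=False)[0]
print(mol.GetNumAtoms(), "atoms")  # header says 49
print(mol.GetProp("MOLECULAR_FORMULA"), mol.GetProp("MOLECULAR_WEIGHT"), mol.GetProp("XLOGP2"))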
examples/1cbr_protein.pdb
ADDED
The diff for this file is too large to render.
See raw diff
requirements.txt
ADDED
@@ -0,0 +1,27 @@
biopandas==0.4.1
biopython==1.79
e3nn==0.5.0
jinja2==3.1.2
joblib==1.2.0
markupsafe==2.1.1
mpmath==1.2.1
networkx==2.8.7
opt-einsum==3.3.0
opt-einsum-fx==0.1.4
packaging==21.3
pandas==1.5.0
scikit-learn==1.1.2
scipy==1.9.1
spyrmsd==0.5.2
sympy==1.11.1
torch==1.12.1
numpy==1.23.1
torchaudio==0.12.1
torchvision==0.13.1
rdkit-pypi==2022.3.5
torch-scatter
torch-sparse
torch-cluster
torch-spline-conv
torch-geometric
-f https://data.pyg.org/whl/torch-1.12.0+cu102.html
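A small sanity check (not part of the repository) that the pinned stack above resolved correctly; the PyTorch-Geometric extensions listed at the end are pulled from the find-links URL on the last line of the file.

import torch
import torch_geometric
import rdkit
import e3nn

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torch_geometric", torch_geometric.__version__)
print("rdkit", rdkit.__version__)
print("e3nn", e3nn.__version__)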