libokj committed
Commit b0ea08b · Parent: 88a8966

Fix conflicts

configs/gen_fbdd_v1.yaml DELETED
@@ -1,42 +0,0 @@
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- docking_batch_size: 20
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- rmsd_threshold: 1.5
-
- linker_ckpt:
-   pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
-   geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
- size_ckpt: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- linker_condition: 'none' # pocket
- linker_batch_size: 64
- linker_steps: 1000
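For context, the driver script (preserved as inference.py.bak below) folds a YAML config such as this one over its argparse defaults, so every key above shadows a CLI flag of the same name. A minimal, runnable sketch of that merge, assuming a two-key stand-in config inlined as a string rather than a file on disk:

from argparse import ArgumentParser
import yaml

# Inline stand-in for a config such as configs/gen_fbdd_v1.yaml
config_dict = yaml.safe_load("""
inference_steps: 20
samples_per_complex: 10
""")

parser = ArgumentParser()
parser.add_argument('--inference_steps', type=int, default=40)
parser.add_argument('--samples_per_complex', type=int, default=1)
args = parser.parse_args([])

# Mirrors the merge loop at the bottom of inference.py: list values are
# appended element by element, scalar values override the parsed default
for key, value in config_dict.items():
    if isinstance(value, list):
        for v in value:
            getattr(args, key).append(v)
    else:
        setattr(args, key, value)

print(args.inference_steps, args.samples_per_complex)  # 20 10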
 
gen_fbdd_v1.yaml DELETED
@@ -1,44 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 5
- min_samples: 10
- frag_dist_range: [2, 5]
- distance_type: min
-
- linker_ckpt:
-   pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
-   geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- linker_condition: 'none' # pocket
- linker_batch_size: 64
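The extra keys in this variant (eps, min_samples, frag_dist_range, distance_type) drive the pose-clustering step in process_docking_results (see inference.py.bak below). A toy sketch of that logic, using a made-up 4x4 pose-distance matrix and smaller min_samples than the config's value of 10:

import numpy as np
from sklearn.cluster import DBSCAN

# Made-up symmetric distances between four docked poses
dist_matrix = np.array([
    [0.0, 1.2, 8.0, 7.5],
    [1.2, 0.0, 7.8, 8.1],
    [8.0, 7.8, 0.0, 3.0],
    [7.5, 8.1, 3.0, 0.0],
])
eps, min_samples, frag_dist_range = 5, 2, (2, 5)

# metric='precomputed' lets DBSCAN cluster directly on inter-pose distances
labels = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed').fit_predict(dist_matrix)

# Poses in the same cluster whose separation falls inside frag_dist_range
# become candidate fragment pairs for the linking step
pairs = [
    (i, j)
    for i in range(len(labels)) for j in range(i + 1, len(labels))
    if labels[i] == labels[j] != -1 and frag_dist_range[0] < dist_matrix[i, j] < frag_dist_range[1]
]
print(labels, pairs)  # [0 0 1 1] [(2, 3)]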
 
inference.py.bak DELETED
@@ -1,940 +0,0 @@
- import logging
- import subprocess
- import sys
- from argparse import ArgumentParser, Namespace, FileType
- import copy
- import itertools
- import os
- from datetime import datetime
- from pathlib import Path
- from functools import partial, cache
- import warnings
- import yaml
- from Bio.PDB import PDBParser
- from sklearn.cluster import DBSCAN
-
- from src import const
- from src.datasets import (
-     collate_with_fragment_without_pocket_edges, get_dataloader, get_one_hot, parse_molecule, ProteinConditionedDataset
- )
- from src.lightning import DDPM
- from src.linker_size_lightning import SizeClassifier
- from src.utils import set_deterministic, FoundNaNException
- from src.visualizer import save_sdf
-
- # Ignore pandas deprecation warning around pyarrow
- warnings.filterwarnings("ignore", category=DeprecationWarning,
-                         message="(?s).*Pyarrow will become a required dependency of pandas.*")
- import numpy as np
- import pandas as pd
- from pandarallel import pandarallel
- import torch
- from torch_geometric.loader import DataLoader
-
- from Bio import SeqIO
- from rdkit import RDLogger, Chem
- from rdkit.Chem import RemoveAllHs
-
- # TODO imports are a little odd, utils seems to shadow things
- from utils.logging_utils import configure_logger, get_logger
- from datasets.process_mols import create_mol_with_coords, read_molecule
- from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule
- from utils.inference_utils import InferenceDataset
- from utils.sampling import randomize_position, sampling
- from utils.utils import get_model
- from utils.visualise import PDBFile
- from tqdm import tqdm
-
- RDLogger.DisableLog('rdApp.*')
- warnings.filterwarnings("ignore", category=UserWarning,
-                         message="The TorchScript type system doesn't support instance-level annotations on"
-                                 " empty non-base types in `__init__`")
-
- # Prody logging is very verbose by default
- prody_logger = logging.getLogger(".prody")
- prody_logger.setLevel(logging.ERROR)
-
- # Pandarallel initialization
- nb_workers = os.cpu_count()
- progress_bar = False
- if hasattr(sys, 'gettrace') and sys.gettrace() is not None:  # Debug mode
-     nb_workers = 1
-     progress_bar = True
- pandarallel.initialize(nb_workers=nb_workers, progress_bar=progress_bar)
-
-
- def read_compound_library(file_path):
-     df = None
-     if file_path.suffix == '.csv':
-         df = pd.read_csv(file_path)
-     elif file_path.suffix == '.sdf':
-         supplier = Chem.SDMolSupplier(file_path, sanitize=False, removeHs=False)
-         # Convert to a dataframe
-         df = pd.DataFrame([{'X1': Chem.MolToSmiles(mol), 'ID1': mol.GetProp('_Name')} for mol in supplier])
-         # Use InChiKey as ID1 if None
-         df.loc[df['ID1'].isna(), 'ID1'] = df.loc[
-             df['ID1'].isna(), 'X1'
-         ].apply(Chem.MolFromSmiles).apply(Chem.MolToInchiKey)
-
-     return df
-
- def read_protein_library(file_path):
-     df = None
-     if file_path.suffix == '.csv':
-         df = pd.read_csv(file_path)
-     elif file_path.suffix == '.fasta':
-         records = list(SeqIO.parse(file_path, 'fasta'))
-         df = pd.DataFrame([{'X2': str(record.seq), 'ID2': record.id} for record in records])
-
-     return df
-
- def process_fragment_library(df):
-     """
-     SMILES strings with separators (e.g., .) represent distinct molecular entities, such as ligands, ions, or
-     co-crystallized molecules. Splitting them ensures that each entity is treated individually, allowing focused
-     analysis of their roles in binding. Single-atom fragments (e.g., counterions like [I-] or [Cl-]) are irrelevant in
-     docking and are to be removed. This filtering focuses on structurally relevant fragments.
-     """
-     # Get subset of rows with SMILES containing separators
-     fragmented_rows = df['X1'].str.contains('.', regex=False)
-     df_fragmented = df[fragmented_rows].copy()
-
-     # Split SMILES into lists and expand
-     df_fragmented['X1'] = df_fragmented['X1'].str.split('.')
-     df_fragmented = df_fragmented.explode('X1').reset_index(drop=True)
-
-     # Append fragment index as alphabet (A, B, C... AA, AB...) to ID1 for rows with fragmented SMILES
-     df_fragmented['ID1'] = df_fragmented.groupby('ID1').cumcount().apply(num_to_letter_code).radd(
-         df_fragmented['ID1'] + '_')
-     df = pd.concat([df[~fragmented_rows], df_fragmented]).sort_index().reset_index(drop=True)
-     df['mol'] = df['X1'].apply(read_molecule, remove_confs=True)
-     df = df.dropna(subset=['mol'])
-
-     # # Remove fragments with no carbon atoms
-     # df = df[df['mol'].swifter.apply(lambda mol: any(atom.GetSymbol() == 'C' for atom in mol.GetAtoms()))]
-     # Remove single-atom fragments
-     df = df[df['mol'].apply(lambda mol: mol.GetNumAtoms() > 1)]
-     # Canonicalize SMILES
-     df['X1'] = df['mol'].apply(lambda x: Chem.MolToSmiles(x))
-
-     return df
-
-
- def check_one_to_one(df, ID_column, X_column):
-     # Check for multiple X values for the same ID
-     id_to_x_conflicts = df.groupby(ID_column)[X_column].nunique()
-     conflicting_ids = id_to_x_conflicts[id_to_x_conflicts > 1]
-
-     # Check for multiple ID values for the same X
-     x_to_id_conflicts = df.groupby(X_column)[ID_column].nunique()
-     conflicting_xs = x_to_id_conflicts[x_to_id_conflicts > 1]
-
-     # Print conflicting mappings
-     if not conflicting_ids.empty:
-         print(f"Conflicting {ID_column} -> multiple {X_column}:")
-         for idx in conflicting_ids.index:
-             print(f"{ID_column}: {idx}, {X_column} values: {df[df[ID_column] == idx][X_column].unique()}")
-
-     if not conflicting_xs.empty:
-         print(f"Conflicting {X_column} -> multiple {ID_column}:")
-         for x in conflicting_xs.index:
-             print(f"{X_column}: {x}, {ID_column} values: {df[df[X_column] == x][ID_column].unique()}")
-
-     # Return whether the mappings are one-to-one
-     return conflicting_ids.empty and conflicting_xs.empty
-
-
- def num_to_letter_code(n):
-     result = ''
-     while n >= 0:
-         result = chr(65 + (n % 26)) + result
-         n = n // 26 - 1
-     return result
-
-
- def dock_fragments(args):
-     with open(Path(args.score_ckpt).parent / 'model_parameters.yml') as f:
-         score_model_args = Namespace(**yaml.full_load(f))
-     if args.confidence_ckpt is not None:
-         with open(Path(args.confidence_ckpt).parent / 'model_parameters.yml') as f:
-             confidence_args = Namespace(**yaml.full_load(f))
-     log.info(f"DiffFragDock will run on {device}")
-
-     docking_out_dir = Path(args.out_dir, 'docking')
-     docking_out_dir.mkdir(parents=True, exist_ok=True)
-     if args.protein_ligand_csv is not None:
-         csv_path = Path(args.protein_ligand_csv)
-         assert csv_path.is_file(), f"File {args.protein_ligand_csv} does not exist"
-         df = pd.read_csv(csv_path)
-         df = process_fragment_library(df)
-     else:
-         assert args.X1 is not None and args.X2 is not None, "Either a .csv file or `X1` and `X2` must be provided."
-
-         compound_df = pd.DataFrame(columns=['X1', 'ID1'])
-         if Path(args.X1).is_file():
-             compound_path = Path(args.X1)
-             if compound_path.suffix in ['.csv', '.sdf']:
-                 compound_df[['X1', 'ID1']] = read_compound_library(compound_path)[['X1', 'ID1']]
-             else:
-                 compound_df['X1'] = [compound_path]
-                 compound_df['ID1'] = [compound_path.stem]
-         else:
-             compound_df['X1'] = [args.X1]
-             compound_df['ID1'] = 'compound_0'
-         compound_df.dropna(subset=['X1'], inplace=True)
-         compound_df.loc[compound_df['ID1'].isna(), 'ID1'] = compound_df.loc[compound_df['ID1'].isna(), 'X1'].apply(
-             lambda x: Chem.MolToInchiKey(Chem.MolFromSmiles(x))
-         )
-
-         protein_df = pd.DataFrame(columns=['X2', 'ID2'])
-         if Path(args.X2).is_file():
-             protein_path = Path(args.X2)
-             if protein_path.suffix in ['.csv', '.fasta']:
-                 protein_df[['X2', 'ID2']] = read_protein_library(protein_path)[['X2', 'ID2']]
-             else:
-                 protein_df['protein_path'] = [protein_path]
-                 protein_df['ID2'] = [protein_path.stem]
-         else:
-             protein_df['X2'] = [args.X2]
-             protein_df['ID2'] = 'protein_0'
-         protein_df.dropna(subset=['X2'], inplace=True)
-         protein_df.loc[protein_df['ID2'].isna(), 'ID2'] = [
-             f"protein_{i}" for i in range(protein_df['ID2'].isna().sum())
-         ]
-
-         compound_df = process_fragment_library(compound_df)
-         df = compound_df.merge(protein_df, how='cross')
-
-     # Identify duplicates based on 'X1' and 'X2'
-     duplicates = df[df.duplicated(subset=['X1', 'X2'], keep=False)]
-     if not duplicates.empty:
-         print("Duplicate rows based on columns 'X1' and 'X2':\n", duplicates[['ID1', 'X1', 'ID2', 'X2']])
-         print("Keeping the first occurrence of each duplicate.")
-         df = df.drop_duplicates(subset=['X1', 'X2'])
-     df['name'] = df['ID2'] + '-' + df['ID1']
-
-     df = df.replace({pd.NA: None})
-     # Check unique mappings between IDn and Xn
-     assert check_one_to_one(df, 'ID1', 'X1'), "ID1-X1 mapping is not one-to-one."
-     assert check_one_to_one(df, 'ID2', 'X2'), "ID2-X2 mapping is not one-to-one."
-
-     """
-     Docking phase
-     """
-
-     # preprocessing of complexes into geometric graphs
-     test_dataset = InferenceDataset(
-         df=df, out_dir=args.out_dir,
-         lm_embeddings=True,
-         receptor_radius=score_model_args.receptor_radius,
-         remove_hs=True,  # score_model_args.remove_hs,
-         c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors,
-         all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius,
-         atom_max_neighbors=score_model_args.atom_max_neighbors,
-         knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph')
-         else not score_model_args.not_knn_only_graph
-     )
-     test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
-
-     if args.confidence_ckpt is not None and not confidence_args.use_original_model_cache:
-         log.info('Confidence model uses a different type of graph than the score model. '
-                  'Loading (or creating if not existing) the data for the confidence model now.')
-         confidence_test_dataset = InferenceDataset(
-             df=df, out_dir=args.out_dir,
-             lm_embeddings=True,
-             receptor_radius=confidence_args.receptor_radius,
-             remove_hs=True,  # confidence_args.remove_hs,
-             c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors,
-             all_atoms=confidence_args.all_atoms,
-             atom_radius=confidence_args.atom_radius,
-             atom_max_neighbors=confidence_args.atom_max_neighbors,
-             precomputed_lm_embeddings=test_dataset.lm_embeddings,
-             knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph')
-             else not score_model_args.not_knn_only_graph
-         )
-     else:
-         confidence_test_dataset = None
-
-     t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)
-
-     model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, old=args.old_score_model)
-     state_dict = torch.load(Path(args.score_ckpt), map_location='cpu', weights_only=True)
-     model.load_state_dict(state_dict, strict=True)
-     model = model.to(device)
-     model.eval()
-
-     if args.confidence_ckpt is not None:
-         confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True,
-                                      confidence_mode=True, old=args.old_confidence_model)
-         state_dict = torch.load(Path(args.confidence_ckpt), map_location='cpu', weights_only=True)
-         confidence_model.load_state_dict(state_dict, strict=True)
-         confidence_model = confidence_model.to(device)
-         confidence_model.eval()
-     else:
-         confidence_model = None
-         confidence_args = None
-
-     tr_schedule = get_t_schedule(inference_steps=args.inference_steps, sigma_schedule='expbeta')
-
-     failures, skipped = 0, 0
-     samples_per_complex = args.samples_per_complex
-     test_ds_size = len(test_dataset)
-     df = test_loader.dataset.df
-     docking_dfs = []
-     log.info(f'Size of fragment dataset: {test_ds_size}')
-     for idx, orig_complex_graph in tqdm(enumerate(test_loader), total=test_ds_size):
-         if not orig_complex_graph.success[0]:
-             skipped += 1
-             log.warning(
-                 f"The test dataset did not contain {df['name'].iloc[idx]}"
-                 f" for {df['X1'].iloc[idx]} and {df['X2'].iloc[idx]}. We are skipping this complex.")
-             continue
-         try:
-             if confidence_test_dataset is not None:
-                 confidence_complex_graph = confidence_test_dataset[idx]
-                 if not confidence_complex_graph.success:
-                     skipped += 1
-                     log.warning(
-                         f"The confidence dataset did not contain {orig_complex_graph.name}. We are skipping this complex.")
-                     continue
-                 confidence_data_list = [copy.deepcopy(confidence_complex_graph) for _ in range(samples_per_complex)]
-             else:
-                 confidence_data_list = None
-             data_list = [copy.deepcopy(orig_complex_graph) for _ in range(samples_per_complex)]
-             randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max,
-                                initial_noise_std_proportion=args.initial_noise_std_proportion,
-                                choose_residue=args.choose_residue)
-
-             lig = orig_complex_graph.mol[0]
-
-             # initialize visualisation
-             if args.save_visualisation:
-                 visualization_list = []
-                 for graph in data_list:
-                     pdb = PDBFile(lig)
-                     pdb.add(lig, 0, 0)
-                     pdb.add((orig_complex_graph['ligand'].pos + orig_complex_graph.original_center).detach().cpu(), 1,
-                             0)
-                     pdb.add((graph['ligand'].pos + graph.original_center).detach().cpu(), part=1, order=1)
-                     visualization_list.append(pdb)
-             else:
-                 visualization_list = None
-
-             # run reverse diffusion
-             data_list, confidence = sampling(data_list=data_list, model=model,
-                                              inference_steps=args.actual_steps if args.actual_steps is not None
-                                              else args.inference_steps,
-                                              tr_schedule=tr_schedule, rot_schedule=tr_schedule,
-                                              tor_schedule=tr_schedule,
-                                              device=device, t_to_sigma=t_to_sigma, model_args=score_model_args,
-                                              visualization_list=visualization_list, confidence_model=confidence_model,
-                                              confidence_data_list=confidence_data_list,
-                                              confidence_model_args=confidence_args,
-                                              batch_size=args.n_poses, no_final_step_noise=args.no_final_step_noise,
-                                              temp_sampling=[args.temp_sampling_tr, args.temp_sampling_rot,
-                                                             args.temp_sampling_tor],
-                                              temp_psi=[args.temp_psi_tr, args.temp_psi_rot, args.temp_psi_tor],
-                                              temp_sigma_data=[args.temp_sigma_data_tr, args.temp_sigma_data_rot,
-                                                               args.temp_sigma_data_tor])
-
-             ligand_pos = np.asarray(
-                 [complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for
-                  complex_graph in data_list]
-             )
-
-             # save predictions
-             n_samples = len(confidence)
-             sample_df = pd.DataFrame([df.iloc[idx]] * n_samples)
-             confidence = confidence[:, 0].cpu().numpy()
-             sample_df['confidence'] = confidence
-             if args.save_docking:
-                 sample_df['path'] = [
-                     Path(
-                         docking_out_dir, f"{df['name'].iloc[idx]}-confidence{confidence[i]:.2f}.sdf"
-                     ) for i in range(n_samples)
-                 ]
-             sample_df['ligand_mol'] = [
-                 create_mol_with_coords(
-                     mol=RemoveAllHs(copy.deepcopy(lig)),
-                     new_coords=pos,
-                     path=sample_df['path'].iloc[i] if args.save_docking else None
-                 ) for i, pos in enumerate(ligand_pos)
-             ]
-             # sample_df['ligand_pos'] = list(ligand_pos)
-             docking_dfs.append(sample_df)
-
-             # write_dir = f"{args.out_dir}/{df['name'].iloc[idx]}"
-             # for rank, pos in enumerate(ligand_pos):
-             #     mol_pred = copy.deepcopy(lig)
-             #     if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred)
-             #     if rank == 0: write_mol_with_coords(mol_pred, pos, Path(write_dir, f'rank{rank + 1}.sdf'))
-             #     write_mol_with_coords(mol_pred, pos,
-             #                           Path(write_dir, f'rank{rank + 1}_confidence{confidence[rank]:.2f}.sdf'))
-             # save visualisation frames
-             # if args.save_visualisation:
-             #     if confidence is not None:
-             #         for rank, batch_idx in enumerate(re_order):
-             #             visualization_list[batch_idx].write(
-             #                 Path(write_dir, f'rank{rank + 1}_reverseprocess.pdb'))
-             #     else:
-             #         for rank, batch_idx in enumerate(ligand_pos):
-             #             visualization_list[batch_idx].write(
-             #                 Path(write_dir, f'rank{rank + 1}_reverseprocess.pdb'))
-
-         except Exception as e:
-             log.warning(f"Failed on {orig_complex_graph['name']}: {e}")
-             failures += 1
-
-     # Tear down DiffDock models and datasets
-     model.cpu()
-     del model
-     if confidence_model is not None:
-         confidence_model.cpu()
-         del confidence_model
-     del test_dataset
-     if confidence_test_dataset is not None:
-         del confidence_test_dataset
-     del test_loader
-
-     docking_df = pd.concat(docking_dfs, ignore_index=True)
-     result_msg = f"""
-     Failed for {failures} / {test_ds_size} complexes.
-     Skipped {skipped} / {test_ds_size} complexes.
-     """
-     if failures or skipped:
-         log.warning(result_msg)
-     else:
-         log.info(result_msg)
-     log.info(f"Results saved in {docking_out_dir}")
-
-     return docking_df
-
-
- def calculate_mol_atomic_distances(mol1, mol2, distance_type='min'):
-     mol1_coords = [
-         mol1.GetConformer().GetAtomPosition(i) for i in range(mol1.GetNumAtoms())
-     ]
-     mol2_coords = [
-         mol2.GetConformer().GetAtomPosition(i) for i in range(mol2.GetNumAtoms())
-     ]
-     # Ensure numpy arrays
-     mol1_coords = np.array(mol1_coords)
-     mol2_coords = np.array(mol2_coords)
-
-     # Compute pairwise distances between all atoms of the two molecules
-     atom_pairwise_distances = np.linalg.norm(mol1_coords[:, None, :] - mol2_coords[None, :, :], axis=-1)
-     # if np.any(np.isnan(atom_pairwise_distances)):
-     #     import pdb
-     #     pdb.set_trace()  # Trigger a breakpoint if NaN is found
-     if distance_type == 'min':
-         return atom_pairwise_distances.min()
-     elif distance_type == 'mean':
-         return atom_pairwise_distances.mean()
-     elif distance_type is None:
-         return atom_pairwise_distances
-     else:
-         raise ValueError(f"Unsupported distance_type: {distance_type}")
-
-
- def process_docking_results(
-         df,
-         eps=5,  # Distance threshold for DBSCAN clustering
-         min_samples=5,  # Minimum number of samples for a cluster (enrichment)
-         frag_dist_range=(2, 5),  # Distance range for fragment linking
-         distance_type='min',  # Type of distance to compute between fragments
- ):
-     assert len(frag_dist_range) == 2, 'Distance range must be a tuple of two values in Angstroms (Å).'
-     frag_dist_range = sorted(frag_dist_range)
-
-     # The mols in df should have been processed to have no explicit hydrogens, except heavy hydrogen isotopes.
-     docking_summaries = []  # For saving intermediate docking results
-     fragment_combos = []  # Fragment pairs for the linking step
-     # 1. Cluster docking poses
-     # Compute pairwise distances of molecules defined by the closest non-heavy atoms
-     for protein, protein_df in df.groupby('X2'):
-         protein_id = protein_df['ID2'].iloc[0]
-         protein_path = protein_df['protein_path'].iloc[0]
-         protein_df['index'] = protein_df.index
-
-         log.info(f'Processing docking results for {protein_id}...')
-         protein_fragment_combos = []
-         dist_matrix = np.stack(
-             protein_df['ligand_mol'].parallel_apply(
-                 lambda mol1: [
-                     calculate_mol_atomic_distances(mol1, mol2, distance_type=distance_type)
-                     for mol2 in protein_df['ligand_mol']
-                 ]
-             )
-         )
-
-         # Perform DBSCAN clustering
-         dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed')
-         protein_df['cluster'] = dbscan.fit_predict(dist_matrix)
-         protein_df = protein_df.sort_values(
-             by=['X1', 'cluster', 'confidence'], ascending=[True, True, False]
-         )
-         # Add conformer number to ID1
-         protein_df['ID1'] = protein_df.groupby('ID1').cumcount().astype(str).radd(protein_df['ID1'] + '_')
-         if args.save_docking:
-             docking_summaries.append(
-                 protein_df[['name', 'ID2', 'X2', 'ID1', 'X1', 'cluster', 'confidence', 'path']]
-             )
-         # Filter out outlier poses
-         protein_df = protein_df[protein_df['cluster'] != -1]
-         # Keep only the highest confidence pose per protein per ligand per cluster
-         protein_df = protein_df.groupby(['X1', 'cluster']).first().reset_index()
-
-         # 2. Create fragment-linking pairs
-         for cluster, cluster_df in protein_df.groupby('cluster'):
-             if len(cluster_df) > 1:  # Skip clusters with only one pose
-                 pairs = list(itertools.combinations(cluster_df['index'], 2))
-                 for i, j in pairs:
-                     row1 = cluster_df[cluster_df['index'] == i].iloc[0]
-                     row2 = cluster_df[cluster_df['index'] == j].iloc[0]
-                     dist = dist_matrix[i, j]
-                     # Check if intermolecular distance is within the range
-                     if frag_dist_range[0] < dist < frag_dist_range[1]:
-                         combined_smiles = f"{row1['X1']}.{row2['X1']}"
-                         combined_mol = Chem.CombineMols(row1['ligand_mol'], row2['ligand_mol'])
-                         complex_name = f"{protein_id}-{row1['ID1']}-{row2['ID1']}"
-                         ligand_path = f"{row1['path']},{row2['path']}"
-                         protein_fragment_combos.append(
-                             (complex_name, protein, protein_path, combined_smiles, ligand_path, combined_mol, dist)
-                         )
-         log.info(f'Number of fragment pairs for {protein_id}: {len(protein_fragment_combos)}.')
-         fragment_combos.extend(protein_fragment_combos)
-
-     # Save intermediate docking results
-     if args.save_docking:
-         docking_summary_df = pd.concat(docking_summaries, ignore_index=True)
-         docking_summary_df.to_csv(Path(args.out_dir, 'docking_summary.csv'), index=False)
-         log.info(f'Saved intermediate docking results to {args.out_dir}')
-
-     # Convert fragment pair results to DataFrame
-     if fragment_combos:
-         linking_df = pd.DataFrame(
-             fragment_combos, columns=['name', 'X2', 'protein_path', 'X1', 'ligand_path', 'ligand_mol', 'distance']
-         )
-         linking_df[
-             ['name', 'X2', 'protein_path', 'X1', 'ligand_path', 'distance']
-         ].to_csv(Path(args.out_dir, 'linking_summary.csv'), index=False)
-         return linking_df
-     else:
-         raise ValueError('No eligible fragment pairs found for linking.')
-
- def get_pocket(mol, pdb_path, backbone_atoms_only=False):
-     struct = PDBParser().get_structure('', pdb_path)
-     residue_ids = []
-     atom_coords = []
-
-     for residue in struct.get_residues():
-         resid = residue.get_id()[1]
-         for atom in residue.get_atoms():
-             atom_coords.append(atom.get_coord())
-             residue_ids.append(resid)
-
-     residue_ids = np.array(residue_ids)
-     atom_coords = np.array(atom_coords)
-     mol_atom_coords = mol.GetConformer().GetPositions()
-
-     distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
-     contact_residues = np.unique(residue_ids[np.where(distances.min(1) <= 6)[0]])
-
-     pocket_coords = []
-     pocket_types = []
-
-     for residue in struct.get_residues():
-         resid = residue.get_id()[1]
-         if resid not in contact_residues:
-             continue
-
-         for atom in residue.get_atoms():
-             atom_name = atom.get_name()
-             atom_type = atom.element.upper()
-             atom_coord = atom.get_coord()
-
-             if backbone_atoms_only and atom_name not in {'N', 'CA', 'C', 'O'}:
-                 continue
-
-             pocket_coords.append(atom_coord.tolist())
-             pocket_types.append(atom_type)
-
-     pocket_pos = []
-     pocket_one_hot = []
-     pocket_charges = []
-
-     for coord, atom_type in zip(pocket_coords, pocket_types):
-         if atom_type not in const.GEOM_ATOM2IDX.keys():
-             continue
-
-         pocket_pos.append(coord)
-         pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
-         pocket_charges.append(const.GEOM_CHARGES[atom_type])
-
-     pocket_pos = np.array(pocket_pos)
-     pocket_one_hot = np.array(pocket_one_hot)
-     pocket_charges = np.array(pocket_charges)
-
-     return pocket_pos, pocket_one_hot, pocket_charges
-
-
- def generate_linker(
-         df, backbone_atoms_only, model,
-         output_dir, n_samples, n_steps, linker_size, anchors, max_batch_size, random_seed
- ):
-     # Setup
-     if random_seed is not None:
-         set_deterministic(random_seed)
-     output_dir = Path(output_dir, 'linking')
-     output_dir.mkdir(exist_ok=True, parents=True)
-
-     if linker_size.isdigit():
-         print(f'Will generate linkers with {linker_size} atoms')
-         linker_size = int(linker_size)
-
-         def sample_fn(_data):
-             return torch.ones(_data['positions'].shape[0], device=device, dtype=const.TORCH_INT) * linker_size
-
-     else:
-         boundaries = [x.strip() for x in linker_size.split(',')]
-         if len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit():
-             left = int(boundaries[0])
-             right = int(boundaries[1])
-             print(f'Will generate linkers with numbers of atoms sampled from U({left}, {right})')
-
-             def sample_fn(_data):
-                 shape = (len(_data['positions']),)
-                 return torch.randint(left, right + 1, shape, device=device, dtype=const.TORCH_INT)
-
-         else:
-             print('Will generate linkers with sampled numbers of atoms')
-             size_classifier = SizeClassifier.load_from_checkpoint(linker_size, map_location=device).eval().to(device)
-
-             def sample_fn(_data):
-                 out, _ = size_classifier.forward(_data, return_loss=False, with_pocket=True, adjust_shape=True)
-                 probabilities = torch.softmax(out, dim=1)
-                 distribution = torch.distributions.Categorical(probs=probabilities)
-                 samples = distribution.sample()
-                 sizes = []
-                 for label in samples.detach().cpu().numpy():
-                     sizes.append(size_classifier.linker_id2size[label])
-                 sizes = torch.tensor(sizes, device=samples.device, dtype=const.TORCH_INT)
-                 return sizes
-
-     if n_steps is not None:
-         model.edm.T = n_steps
-
-     if model.center_of_mass == 'anchors' and anchors is None:
-         print(
-             'Please pass anchor atom indices '
-             'or use another DiffLinker model that does not require information about anchors'
-         )
-         return
-
-     cached_parse_molecule = cache(parse_molecule)
-     dataset = []
-     for i, row in df.iterrows():
-         mol = row['ligand_mol']  # Hs already removed
-         # Parsing fragments data
-         frag_pos, frag_one_hot, frag_charges = cached_parse_molecule(mol, is_geom=ddpm.is_geom)
-         # Parsing pocket data
-         pocket_pos, pocket_one_hot, pocket_charges = get_pocket(mol, row['protein_path'], backbone_atoms_only)
-
-         positions = np.concatenate([frag_pos, pocket_pos], axis=0)
-         one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0)
-         charges = np.concatenate([frag_charges, pocket_charges], axis=0)
-         anchor_flags = np.zeros_like(charges)
-         if anchors is not None:
-             for anchor in anchors.split(','):
-                 anchor_flags[int(anchor.strip()) - 1] = 1
-
-         fragment_only_mask = np.concatenate([
-             np.ones_like(frag_charges),
-             np.zeros_like(pocket_charges),
-         ])
-         pocket_mask = np.concatenate([
-             np.zeros_like(frag_charges),
-             np.ones_like(pocket_charges),
-         ])
-         linker_mask = np.concatenate([
-             np.zeros_like(frag_charges),
-             np.zeros_like(pocket_charges),
-         ])
-         fragment_mask = np.concatenate([
-             np.ones_like(frag_charges),
-             np.ones_like(pocket_charges),
-         ])
-
-         dataset.extend([{
-             'name': row['name'],
-             'X1': row['X1'],
-             'X2': row['X2'],
-             'protein_path': row['protein_path'],
-             'ligand_path': row['ligand_path'],
-             'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
-             'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
-             'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
-             'anchors': torch.tensor(anchor_flags, dtype=const.TORCH_FLOAT, device=device),
-             'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
-             'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
-             'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
-             'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
-             'num_atoms': len(positions)
-         }] * n_samples)
-
-     dataset = ProteinConditionedDataset(data=dataset)
-     ddpm.val_dataset = dataset
-
-     global_batch_size = min(n_samples, max_batch_size)
-     dataloader = get_dataloader(
-         dataset, batch_size=global_batch_size, collate_fn=collate_with_fragment_without_pocket_edges
-     )
-
-     # df.drop(columns=['ligand_mol', 'protein_path'], inplace=True)
-     linking_dfs = []
-     # Sampling
-     print('Sampling...')
-     # TODO: update linking_summary.csv per batch
-     for batch_i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
-         effective_batch_size = len(data['positions'])
-         complex_name = data['name'][0]
-         batch_df = pd.DataFrame({
-             'name': data['name'],
-             'X1': data['X1'],
-             'X2': data['X2'],
-             'protein_path': data['protein_path'],
-             'ligand_path': data['ligand_path'],
-         })
-         chain = None
-         node_mask = None
-         for i in range(5):
-             try:
-                 chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
-                 break
-             except FoundNaNException:
-                 continue
-         if chain is None:
-             log.warning(f'Could not generate linker for {complex_name} in 5 attempts')
-             continue
-
-         x = chain[0][:, :, :ddpm.n_dims]
-         h = chain[0][:, :, ddpm.n_dims:]
-
-         # Put the molecule back to the initial orientation
-         com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
-         pos_masked = data['positions'] * com_mask
-         N = com_mask.sum(1, keepdims=True)
-         mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
-         x = x + mean * node_mask
-         node_mask[torch.where(data['pocket_mask'])] = 0
-
-         batch_df['out_path'] = [Path(output_dir, f'{complex_name}_{i}.sdf') for i in range(effective_batch_size)]
-         batch_df['one_hot'] = list(h.cpu())
-         batch_df['positions'] = list(x.cpu())
-         batch_df['node_mask'] = list(node_mask.cpu())
-         batch_df['X1^'] = batch_df.parallel_apply(
-             lambda row: save_sdf(
-                 row['out_path'], row['one_hot'], row['positions'], row['node_mask'], is_geom=ddpm.is_geom
-             ), axis=1
-         )
-         linking_dfs.append(batch_df[['name', 'protein_path', 'X2', 'ligand_path', 'X1', 'X1^', 'out_path']])
-         # for i in range(effective_batch_size):
-         #     # # Save XYZ file and generate SMILES
-         #     # out_xyz = Path(output_dir, f'{name}_{offset_idx+i}.xyz')
-         #     # smiles = save_xyz_files(out_xyz, h[i], x[i], node_mask[i], is_geom=ddpm.is_geom)
-         #     # # Convert XYZ to SDF
-         #     # out_sdf = Path(output_dir, name, f'output_{offset_idx+i}.sdf')
-         #     # with open(os.devnull, 'w') as devnull:
-         #     #     subprocess.run(f'obabel {out_xyz} -O {out_sdf} -q', shell=True, stdout=devnull)
-         #     # Save SDF file and generate SMILES
-         #     out_sdf = Path(output_dir, f'{data["name"][i]}.sdf')
-         #     smiles = save_sdf(out_sdf, h[i], x[i], node_mask[i], is_geom=ddpm.is_geom)
-         #
-         #     # Add experiment summary info
-         #     batch_df['X1^'] = smiles
-         #     batch_df['out_path'] = str(out_sdf)
-         #     linking_dfs.append(batch_df)
-
-     if linking_dfs:
-         linking_summary_df = pd.concat(linking_dfs, ignore_index=True)
-         linking_summary_df.to_csv(Path(output_dir.parent, 'linking_summary.csv'), index=False)
-         print(f'Saved experiment summary and generated molecules to {output_dir}')
-     else:
-         raise ValueError('No linkers generated.')
-
-
- if __name__ == "__main__":
-     parser = ArgumentParser()
-     # Fragment docking settings
-     parser.add_argument('--config', type=FileType(mode='r'), default='default_inference_args.yaml')
-     parser.add_argument('--protein_ligand_csv', type=str, default=None,
-                         help='Path to a .csv file specifying the input as described in the README. '
-                              'If this is not None, it will be used instead of the `X1` and `X2` parameters')
-     parser.add_argument('-n', '--name', type=str, default=None,
-                         help='Name that the experiment will be saved with')
-     parser.add_argument('--X1', type=str,
-                         help='Either a SMILES string or the path of a molecule file that rdkit can read')
-     parser.add_argument('--X2', type=str,
-                         help='Either a FASTA sequence or the path of a protein for ESMFold')
-
-     parser.add_argument('-l', '--log', '--loglevel', type=str, default='INFO', dest="loglevel",
-                         help='Log level. Default %(default)s')
-
-     parser.add_argument('--out_dir', type=str, default='results/',
-                         help='Directory where the outputs will be written to')
-     parser.add_argument('--save_docking', action='store_true', default=True,
-                         help='Save the intermediate docking results, including SDF files and a summary CSV.')
-     parser.add_argument('--save_visualisation', action='store_true', default=False,
-                         help='Save a pdb file with all of the steps of the reverse diffusion')
-     parser.add_argument('--samples_per_complex', type=int, default=10,
-                         help='Number of samples to generate')
-
-     # parser.add_argument('--model_dir', type=str, default=None,
-     #                     help='Path to folder with trained score model and hyperparameters')
-     parser.add_argument('--score_ckpt', type=str, default='best_ema_inference_epoch_model.pt',
-                         help='Checkpoint to use for the score model')
-     # parser.add_argument('--confidence_model_dir', type=str, default=None,
-     #                     help='Path to folder with trained confidence model and hyperparameters')
-     parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt',
-                         help='Checkpoint to use for the confidence model')
-
-     parser.add_argument('--n_poses', type=int, default=10, help='')
-     parser.add_argument('--no_final_step_noise', action='store_true', default=True,
-                         help='Use no noise in the final step of the reverse diffusion')
-     parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps')
-     parser.add_argument('--actual_steps', type=int, default=None,
-                         help='Number of denoising steps that are actually performed')
-
-     parser.add_argument('--old_score_model', action='store_true', default=False, help='')
-     parser.add_argument('--old_confidence_model', action='store_true', default=True, help='')
-     parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0,
-                         help='Initial noise std proportion')
-     parser.add_argument('--choose_residue', action='store_true', default=False, help='')
-
-     parser.add_argument('--temp_sampling_tr', type=float, default=1.0)
-     parser.add_argument('--temp_psi_tr', type=float, default=0.0)
-     parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5)
-     parser.add_argument('--temp_sampling_rot', type=float, default=1.0)
-     parser.add_argument('--temp_psi_rot', type=float, default=0.0)
-     parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5)
-     parser.add_argument('--temp_sampling_tor', type=float, default=1.0)
-     parser.add_argument('--temp_psi_tor', type=float, default=0.0)
-     parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5)
-
-     parser.add_argument('--gnina_minimize', action='store_true', default=False, help='')
-     parser.add_argument('--gnina_path', type=str, default='gnina', help='')
-     parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt',
-                         help='')  # To redirect gnina subprocesses' stdout away from the terminal window
-     parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='')
-     parser.add_argument('--gnina_autobox_add', type=float, default=4.0)
-     parser.add_argument('--gnina_poses_to_optimize', type=int, default=1)
-
-     # Linker generation settings
-     # parser.add_argument('--fragments', action='store', type=str, required=True,
-     #                     help='Path to the file with input fragments'
-     #                     )
-     # parser.add_argument(
-     #     '--protein', action='store', type=str, required=True,
-     #     help='Path to the file with the target protein'
-     # )
-     parser.add_argument(
-         '--backbone_atoms_only', action='store_true', required=False, default=False,
-         help='Flag to use only protein backbone atoms'
-     )
-     parser.add_argument(
-         '--linker_ckpt', action='store', type=str,
-         help='Path to the DiffLinker model'
-     )
-     parser.add_argument(
-         '--linker_size', action='store', type=str,
-         help='Linker size (int) or allowed size boundaries (comma-separated) or path to the size prediction model'
-     )
-     parser.add_argument(
-         '--n_linkers', action='store', type=int, required=False, default=5,
-         help='Number of linkers to generate'
-     )
-     parser.add_argument(
-         '--n_steps', action='store', type=int, required=False, default=1000,
-         help='Number of denoising steps'
-     )
-     parser.add_argument(
-         '--anchors', action='store', type=str, required=False, default=None,
-         help='Comma-separated indices of anchor atoms '
-              '(according to the order of atoms in the input fragments file, enumeration starts with 1)'
-     )
-     parser.add_argument(
-         '--max_batch_size', action='store', type=int, required=False, default=16,
-         help='Max batch size'
-     )
-     parser.add_argument(
-         '--random_seed', action='store', type=int, required=False, default=None,
-         help='Random seed'
-     )
-     parser.add_argument(
-         '--robust', action='store_true', required=False, default=False,
-         help='Robust sampling modification'
-     )
-     parser.add_argument(
-         '--dock', action='store_true', default=False,
-         help='Fragment docking with DiffDock'
-     )
-     parser.add_argument(
-         '--link', action='store_true', default=False,
-         help='Linker generation with DiffLinker'
-     )
-
-     args = parser.parse_args()
-     if args.config:
-         config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
-         arg_dict = args.__dict__
-         for key, value in config_dict.items():
-             if isinstance(value, list):
-                 for v in value:
-                     arg_dict[key].append(v)
-             else:
-                 arg_dict[key] = value
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     configure_logger(args.loglevel)
-     log = get_logger()
-
-     date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-     experiment_name = f"{date_time}_{args.name}"
-     args.out_dir = Path(args.out_dir, experiment_name)
-
-     if args.dock:
-         docking_df = dock_fragments(args)
-         linking_df = process_docking_results(
-             docking_df,
-             eps=args.eps, min_samples=args.min_samples,
-             frag_dist_range=args.frag_dist_range, distance_type=args.distance_type
-         )
-         if args.link:
-             ddpm = DDPM.load_from_checkpoint(args.linker_ckpt, map_location=device, robust=args.robust).eval().to(device)
-             generate_linker(
-                 linking_df,
-                 backbone_atoms_only=args.backbone_atoms_only,
-                 model=ddpm,
-                 output_dir=args.out_dir,
-                 n_samples=args.n_linkers,
-                 n_steps=args.n_steps,
-                 linker_size=args.linker_size,
-                 anchors=args.anchors,
-                 max_batch_size=args.max_batch_size,
-                 random_seed=args.random_seed,
-             )
-
-     elif args.link:
-         linking_df = pd.read_csv(args.protein_ligand_csv)
-         ddpm = DDPM.load_from_checkpoint(args.linker_ckpt, map_location=device, robust=args.robust).eval().to(device)
-         generate_linker(
-             linking_df,
-             backbone_atoms_only=args.backbone_atoms_only,
-             model=ddpm,
-             output_dir=args.out_dir,
-             n_samples=args.n_linkers,
-             n_steps=args.n_steps,
-             linker_size=args.linker_size,
-             anchors=args.anchors,
-             max_batch_size=args.max_batch_size,
-             random_seed=args.random_seed,
-         )
 
mean_cluster.yaml DELETED
@@ -1,38 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- linker_ckpt: resources/checkpoints/DiffLinker/pockets_difflinker_full.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 5
- min_samples: 10
- frag_dist_range: [2, 5]
- distance_type: min
 
min_cluster.yaml DELETED
@@ -1,38 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- linker_ckpt: resources/checkpoints/DiffLinker/pockets_difflinker_full.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 1
- min_samples: 10
- frag_dist_range: [5, 8]
- distance_type: mean
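mean_cluster.yaml and min_cluster.yaml differ only in their clustering knobs (eps, frag_dist_range, distance_type). A toy sketch of how distance_type changes the pose-pose distance fed to DBSCAN, following calculate_mol_atomic_distances in inference.py.bak, with two made-up two-atom poses:

import numpy as np

mol1_coords = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])  # fragment pose A
mol2_coords = np.array([[4.0, 0.0, 0.0], [5.5, 0.0, 0.0]])  # fragment pose B

# All pairwise atom-atom distances between the two poses
pairwise = np.linalg.norm(mol1_coords[:, None, :] - mol2_coords[None, :, :], axis=-1)
print(pairwise.min())   # 2.5 -> 'min': closest-atom distance, pairs well with a small eps
print(pairwise.mean())  # 4.0 -> 'mean': average over all atom pairs, systematically larger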
 
pocket_backbone.yaml DELETED
@@ -1,47 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 5
- min_samples: 10
- frag_dist_range:
-   - 2
-   - 5
- distance_type: min
-
- linker_ckpt:
-   pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   # resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
-   geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- linker_condition: 'protein' # pocket
- linker_batch_size: 20
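Here linker_size points at a size-prediction checkpoint, but generate_linker in inference.py.bak accepts three forms for this value. A sketch of just the dispatch logic, with no model loading; describe_linker_size is a hypothetical helper written for illustration:

def describe_linker_size(linker_size: str) -> str:
    if linker_size.isdigit():
        return f'fixed size of {int(linker_size)} atoms'
    boundaries = [x.strip() for x in linker_size.split(',')]
    if len(boundaries) == 2 and all(b.isdigit() for b in boundaries):
        return f'sizes sampled uniformly from [{boundaries[0]}, {boundaries[1]}]'
    # Anything else is treated as a path to a SizeClassifier checkpoint,
    # e.g. resources/checkpoints/DiffLinker/geom_size_gnn.ckpt above
    return f'sizes predicted by the checkpoint at {linker_size}'

print(describe_linker_size('8'))
print(describe_linker_size('5, 10'))
print(describe_linker_size('resources/checkpoints/DiffLinker/geom_size_gnn.ckpt'))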
 
pocket_full_no_anchors_4a.yaml DELETED
@@ -1,44 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 5
- min_samples: 10
- frag_dist_range: [2, 5]
- distance_type: min
-
- linker_ckpt:
-   # pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors.ckpt
-   pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
-   geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- linker_condition: 'protein' # pocket
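This config pairs a "no_anchors" DiffLinker checkpoint with linker_condition: 'protein'. When a model does require anchors, generate_linker expands the 1-based, comma-separated --anchors string into per-atom flags; a minimal numpy reproduction of that expansion, with a made-up six-atom system:

import numpy as np

charges = np.zeros(6)  # one entry per fragment/pocket atom
anchors = '1, 4'       # CLI-style --anchors value (1-based indices)

anchor_flags = np.zeros_like(charges)
for anchor in anchors.split(','):
    anchor_flags[int(anchor.strip()) - 1] = 1
print(anchor_flags)  # [1. 0. 0. 1. 0. 0.]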
 
pocket_full_no_anchors_fc_pdb_excluded.yaml DELETED
@@ -1,47 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 5
- min_samples: 10
- frag_dist_range:
-   - 2
-   - 5
- distance_type: min
-
- linker_ckpt:
-   pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   # resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
-   geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- linker_condition: 'protein' # pocket
- linker_batch_size: 20
 
protein_backbone.yaml DELETED
@@ -1,47 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 5
- min_samples: 10
- frag_dist_range:
-   - 2
-   - 5
- distance_type: min
-
- linker_ckpt:
-   pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
-   geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- linker_condition: 'protein' # pocket
- linker_batch_size: 20
- backbone_atoms_only: true
 
protein_full_no_anchors_fc_pdb_excluded.yaml DELETED
@@ -1,47 +0,0 @@
- actual_steps: 19
- score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
- confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
- #confidence_model_dir: ./workdir/v1.1/confidence_model
- different_schedules: false
- inf_sched_alpha: 1
- inf_sched_beta: 1
- inference_steps: 20
- initial_noise_std_proportion: 1.4601642460337794
- limit_failures: 5
- #model_dir: ./workdir/v1.1/score_model
- #comment
- no_final_step_noise: true
- no_model: false
- no_random: false
- no_random_pocket: false
- ode: false
- old_filtering_model: true
- old_score_model: false
- resample_rdkit: false
- samples_per_complex: 10
- sigma_schedule: expbeta
- temp_psi_rot: 0.9022615585677628
- temp_psi_tor: 0.5946212391366862
- temp_psi_tr: 0.727287304570729
- temp_sampling_rot: 2.06391612594481
- temp_sampling_tor: 7.044261621607846
- temp_sampling_tr: 1.170050527854316
- temp_sigma_data_rot: 0.7464326999906034
- temp_sigma_data_tor: 0.6943254174849822
- temp_sigma_data_tr: 0.9299802531572672
-
- eps: 5
- min_samples: 10
- frag_dist_range:
-   - 2
-   - 5
- distance_type: min
-
- linker_ckpt:
-   pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
-   # resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
-   geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
- linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
- linker_condition: 'protein' # pocket
- linker_batch_size: 20