Spaces:
Sleeping
Sleeping
Fix conflicts
Browse files- configs/gen_fbdd_v1.yaml +0 -42
- gen_fbdd_v1.yaml +0 -44
- inference.py.bak +0 -940
- mean_cluster.yaml +0 -38
- min_cluster.yaml +0 -38
- pocket_backbone.yaml +0 -47
- pocket_full_no_anchors_4a.yaml +0 -44
- pocket_full_no_anchors_fc_pdb_excluded.yaml +0 -47
- protein_backbone.yaml +0 -47
- protein_full_no_anchors_fc_pdb_excluded.yaml +0 -47
configs/gen_fbdd_v1.yaml
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
2 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
3 |
-
docking_batch_size: 20
|
4 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
5 |
-
different_schedules: false
|
6 |
-
inf_sched_alpha: 1
|
7 |
-
inf_sched_beta: 1
|
8 |
-
inference_steps: 20
|
9 |
-
initial_noise_std_proportion: 1.4601642460337794
|
10 |
-
limit_failures: 5
|
11 |
-
#model_dir: ./workdir/v1.1/score_model
|
12 |
-
#comment
|
13 |
-
no_final_step_noise: true
|
14 |
-
no_model: false
|
15 |
-
no_random: false
|
16 |
-
no_random_pocket: false
|
17 |
-
ode: false
|
18 |
-
old_filtering_model: true
|
19 |
-
old_score_model: false
|
20 |
-
resample_rdkit: false
|
21 |
-
samples_per_complex: 10
|
22 |
-
sigma_schedule: expbeta
|
23 |
-
temp_psi_rot: 0.9022615585677628
|
24 |
-
temp_psi_tor: 0.5946212391366862
|
25 |
-
temp_psi_tr: 0.727287304570729
|
26 |
-
temp_sampling_rot: 2.06391612594481
|
27 |
-
temp_sampling_tor: 7.044261621607846
|
28 |
-
temp_sampling_tr: 1.170050527854316
|
29 |
-
temp_sigma_data_rot: 0.7464326999906034
|
30 |
-
temp_sigma_data_tor: 0.6943254174849822
|
31 |
-
temp_sigma_data_tr: 0.9299802531572672
|
32 |
-
|
33 |
-
rmsd_threshold: 1.5
|
34 |
-
|
35 |
-
linker_ckpt:
|
36 |
-
pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
37 |
-
pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
|
38 |
-
geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
|
39 |
-
size_ckpt: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
40 |
-
linker_condition: 'none' # pocket
|
41 |
-
linker_batch_size: 64
|
42 |
-
linker_steps: 1000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gen_fbdd_v1.yaml
DELETED
@@ -1,44 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
5 |
-
different_schedules: false
|
6 |
-
inf_sched_alpha: 1
|
7 |
-
inf_sched_beta: 1
|
8 |
-
inference_steps: 20
|
9 |
-
initial_noise_std_proportion: 1.4601642460337794
|
10 |
-
limit_failures: 5
|
11 |
-
#model_dir: ./workdir/v1.1/score_model
|
12 |
-
#comment
|
13 |
-
no_final_step_noise: true
|
14 |
-
no_model: false
|
15 |
-
no_random: false
|
16 |
-
no_random_pocket: false
|
17 |
-
ode: false
|
18 |
-
old_filtering_model: true
|
19 |
-
old_score_model: false
|
20 |
-
resample_rdkit: false
|
21 |
-
samples_per_complex: 10
|
22 |
-
sigma_schedule: expbeta
|
23 |
-
temp_psi_rot: 0.9022615585677628
|
24 |
-
temp_psi_tor: 0.5946212391366862
|
25 |
-
temp_psi_tr: 0.727287304570729
|
26 |
-
temp_sampling_rot: 2.06391612594481
|
27 |
-
temp_sampling_tor: 7.044261621607846
|
28 |
-
temp_sampling_tr: 1.170050527854316
|
29 |
-
temp_sigma_data_rot: 0.7464326999906034
|
30 |
-
temp_sigma_data_tor: 0.6943254174849822
|
31 |
-
temp_sigma_data_tr: 0.9299802531572672
|
32 |
-
|
33 |
-
eps: 5
|
34 |
-
min_samples: 10
|
35 |
-
frag_dist_range: [2, 5]
|
36 |
-
distance_type: min
|
37 |
-
|
38 |
-
linker_ckpt:
|
39 |
-
pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
40 |
-
pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
|
41 |
-
geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
|
42 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
43 |
-
linker_condition: 'none' # pocket
|
44 |
-
linker_batch_size: 64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference.py.bak
DELETED
@@ -1,940 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import subprocess
|
3 |
-
import sys
|
4 |
-
from argparse import ArgumentParser, Namespace, FileType
|
5 |
-
import copy
|
6 |
-
import itertools
|
7 |
-
import os
|
8 |
-
from datetime import datetime
|
9 |
-
from pathlib import Path
|
10 |
-
from functools import partial, cache
|
11 |
-
import warnings
|
12 |
-
import yaml
|
13 |
-
from Bio.PDB import PDBParser
|
14 |
-
from sklearn.cluster import DBSCAN
|
15 |
-
|
16 |
-
from src import const
|
17 |
-
from src.datasets import (
|
18 |
-
collate_with_fragment_without_pocket_edges, get_dataloader, get_one_hot, parse_molecule, ProteinConditionedDataset
|
19 |
-
)
|
20 |
-
from src.lightning import DDPM
|
21 |
-
from src.linker_size_lightning import SizeClassifier
|
22 |
-
from src.utils import set_deterministic, FoundNaNException
|
23 |
-
from src.visualizer import save_sdf
|
24 |
-
|
25 |
-
# Ignore pandas deprecation warning around pyarrow
|
26 |
-
warnings.filterwarnings("ignore", category=DeprecationWarning,
|
27 |
-
message="(?s).*Pyarrow will become a required dependency of pandas.*")
|
28 |
-
import numpy as np
|
29 |
-
import pandas as pd
|
30 |
-
from pandarallel import pandarallel
|
31 |
-
import torch
|
32 |
-
from torch_geometric.loader import DataLoader
|
33 |
-
|
34 |
-
from Bio import SeqIO
|
35 |
-
from rdkit import RDLogger, Chem
|
36 |
-
from rdkit.Chem import RemoveAllHs
|
37 |
-
|
38 |
-
# TODO imports are a little odd, utils seems to shadow things
|
39 |
-
from utils.logging_utils import configure_logger, get_logger
|
40 |
-
from datasets.process_mols import create_mol_with_coords, read_molecule
|
41 |
-
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule
|
42 |
-
from utils.inference_utils import InferenceDataset
|
43 |
-
from utils.sampling import randomize_position, sampling
|
44 |
-
from utils.utils import get_model
|
45 |
-
from utils.visualise import PDBFile
|
46 |
-
from tqdm import tqdm
|
47 |
-
|
48 |
-
# Disable RDKit log output for every logger matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')
# Hide the TorchScript UserWarning about instance-level annotations.
warnings.filterwarnings("ignore", category=UserWarning,
                        message="The TorchScript type system doesn't support instance-level annotations on"
                                " empty non-base types in `__init__`")

# Prody logging is very verbose by default
prody_logger = logging.getLogger(".prody")
prody_logger.setLevel(logging.ERROR)

# Pandarallel initialization
# Default: one worker per CPU reported by the OS and no progress bar.
nb_workers = os.cpu_count()
progress_bar = False
# A trace function being installed is used as the signal that a debugger is attached;
# in that case run single-worker with a visible progress bar.
if hasattr(sys, 'gettrace') and sys.gettrace() is not None:  # Debug mode
    nb_workers = 1
    progress_bar = True
# NOTE: this runs at import time, so importing the module already configures pandarallel.
pandarallel.initialize(nb_workers=nb_workers, progress_bar=progress_bar)
|
64 |
-
|
65 |
-
|
66 |
-
def read_compound_library(file_path):
    """Load a compound library into a DataFrame.

    Args:
        file_path: ``pathlib.Path`` pointing at a ``.csv`` or ``.sdf`` file.

    Returns:
        * ``.csv``: the table exactly as returned by ``pd.read_csv``.
        * ``.sdf``: a DataFrame with columns ``X1`` (SMILES written by RDKit) and
          ``ID1`` (the record's ``_Name``; when the name is missing or blank the
          InChIKey derived from the SMILES is used instead).
        * any other suffix: ``None``.
    """
    suffix = file_path.suffix
    if suffix == '.csv':
        return pd.read_csv(file_path)
    if suffix != '.sdf':
        return None

    # Hand RDKit a plain string; a pathlib.Path is not a str and the supplier
    # takes a file name.
    supplier = Chem.SDMolSupplier(str(file_path), sanitize=False, removeHs=False)
    records = []
    for mol in supplier:
        if mol is None:
            # The supplier yields None for records it cannot parse; skip them
            # instead of failing on the attribute access below.
            continue
        name = mol.GetProp('_Name') if mol.HasProp('_Name') else ''
        # A blank title is stored as an empty string, which `isna()` would not
        # catch; map it to None so the InChIKey fallback below applies.
        records.append({'X1': Chem.MolToSmiles(mol), 'ID1': name if name.strip() else None})

    df = pd.DataFrame(records, columns=['X1', 'ID1'])
    # Use InChiKey as ID1 if None
    unnamed = df['ID1'].isna()
    df.loc[unnamed, 'ID1'] = df.loc[unnamed, 'X1'].apply(Chem.MolFromSmiles).apply(Chem.MolToInchiKey)
    return df
|
80 |
-
|
81 |
-
def read_protein_library(file_path):
    """Load a protein library into a DataFrame.

    Args:
        file_path: ``pathlib.Path`` pointing at a ``.csv`` or ``.fasta`` file.

    Returns:
        The ``pd.read_csv`` result for ``.csv``; for ``.fasta`` a DataFrame with one
        row per record holding ``X2`` (sequence as a string) and ``ID2`` (record id);
        ``None`` for any other suffix.
    """
    suffix = file_path.suffix
    if suffix == '.csv':
        return pd.read_csv(file_path)
    if suffix == '.fasta':
        rows = []
        for entry in SeqIO.parse(file_path, 'fasta'):
            rows.append({'X2': str(entry.seq), 'ID2': entry.id})
        return pd.DataFrame(rows)
    return None
|
90 |
-
|
91 |
-
def process_fragment_library(df):
    """Split multi-component SMILES into fragments and drop unusable entries.

    SMILES strings with separators (e.g., .) represent distinct molecular entities, such as ligands, ions, or
    co-crystallized molecules. Splitting them ensures that each entity is treated individually, allowing focused
    analysis of their roles in binding. Single atom fragments (e.g., counterions like [I-] or [Cl-]) are irrelevant in
    docking and are to be removed. This filtering focuses on structurally relevant fragments.

    Args:
        df: DataFrame with at least the columns ``X1`` (SMILES, or another input accepted by
            ``read_molecule``) and ``ID1``.

    Returns:
        A new DataFrame with one row per kept fragment. Rows that came from a split SMILES get a letter
        suffix on ``ID1`` (``<ID1>_A``, ``<ID1>_B``, ...). A ``mol`` column holds the object returned by
        ``read_molecule`` and ``X1`` is rewritten from that object with ``Chem.MolToSmiles``.
    """
    # Get subset of rows with SMILES containing separators.
    # na=False: entries that are not strings (e.g. a Path object) give NaN from
    # `.str.contains`, and a mask containing NaN cannot be used for boolean indexing.
    fragmented_rows = df['X1'].str.contains('.', regex=False, na=False)

    if fragmented_rows.any():
        df_fragmented = df[fragmented_rows].copy()

        # Split SMILES into lists and expand
        df_fragmented['X1'] = df_fragmented['X1'].str.split('.')
        df_fragmented = df_fragmented.explode('X1').reset_index(drop=True)

        # Append fragment index as alphabet (A, B, C... AA, AB...) to ID1 for rows with fragmented SMILES
        df_fragmented['ID1'] = df_fragmented.groupby('ID1').cumcount().apply(num_to_letter_code).radd(
            df_fragmented['ID1'] + '_')
        df = pd.concat([df[~fragmented_rows], df_fragmented]).sort_index().reset_index(drop=True)
    else:
        # Nothing to split: same ordering/index treatment as above, on a private copy
        # so the caller's frame is not modified by the column assignments below.
        df = df.sort_index().reset_index(drop=True)

    df['mol'] = df['X1'].apply(read_molecule, remove_confs=True)
    # `.copy()` after each filter so the later column assignment targets an
    # independent frame rather than a slice of the previous one.
    df = df.dropna(subset=['mol']).copy()

    # Remove single-atom fragments
    df = df[df['mol'].apply(lambda mol: mol.GetNumAtoms() > 1)].copy()
    # Canonicalize SMILES
    df['X1'] = df['mol'].apply(lambda mol: Chem.MolToSmiles(mol))

    return df
|
121 |
-
|
122 |
-
|
123 |
-
def check_one_to_one(df, ID_column, X_column):
    """Report whether ``ID_column`` and ``X_column`` map onto each other one-to-one.

    Every ID that is paired with more than one distinct X value, and every X value
    that is paired with more than one distinct ID, is printed to stdout.

    Returns:
        ``True`` when neither direction has a conflict, ``False`` otherwise.
    """

    def _multi_valued(key_col, value_col):
        # Keys associated with more than one distinct value.
        distinct_counts = df.groupby(key_col)[value_col].nunique()
        return distinct_counts[distinct_counts > 1]

    def _report(conflicts, key_col, value_col):
        if conflicts.empty:
            return
        print(f"Conflicting {key_col} -> multiple {value_col}:")
        for key in conflicts.index:
            print(f"{key_col}: {key}, {value_col} values: {df[df[key_col] == key][value_col].unique()}")

    ids_with_many_xs = _multi_valued(ID_column, X_column)
    xs_with_many_ids = _multi_valued(X_column, ID_column)

    _report(ids_with_many_xs, ID_column, X_column)
    _report(xs_with_many_ids, X_column, ID_column)

    return ids_with_many_xs.empty and xs_with_many_ids.empty
|
145 |
-
|
146 |
-
|
147 |
-
def num_to_letter_code(n):
    """Convert a zero-based integer into a spreadsheet-style letter code.

    0 -> 'A', 25 -> 'Z', 26 -> 'AA', 27 -> 'AB', ... A negative ``n`` yields ''.
    """
    letters = []
    while n >= 0:
        n, remainder = divmod(n, 26)
        letters.append(chr(65 + remainder))
        # Shift by one because there is no "zero" digit in this numbering.
        n -= 1
    # Letters were produced least-significant first.
    return ''.join(reversed(letters))
|
153 |
-
|
154 |
-
|
155 |
-
def dock_fragments(args):
    """Dock every fragment against every protein and collect the sampled poses.

    Steps performed:
      1. Read the score-model (and optional confidence-model) hyper-parameters from the
         ``model_parameters.yml`` next to each checkpoint.
      2. Build the fragment/protein table, either from ``args.protein_ligand_csv`` or from the
         cross product of ``args.X1`` and ``args.X2`` (each a file or a literal string).
      3. Build the inference dataset(s), load the model weights and run reverse diffusion
         for each complex, ``args.samples_per_complex`` samples at a time.
      4. Turn each sampled pose into a molecule and, when ``args.save_docking`` is set,
         write it under ``<args.out_dir>/docking``.

    Relies on the module-level names ``log`` and ``device``, which are defined outside this block.

    Args:
        args: namespace providing the attributes read below (``score_ckpt``, ``confidence_ckpt``,
            ``out_dir``, ``protein_ligand_csv``, ``X1``, ``X2``, ``inference_steps``, ``actual_steps``,
            ``samples_per_complex``, ``n_poses``, ``save_docking``, ``save_visualisation``,
            ``choose_residue``, the ``temp_*`` values, ...).

    Returns:
        DataFrame with one row per sampled pose: the columns of the input table plus
        ``confidence``, ``ligand_mol`` and, when poses are saved, ``path``.

    Raises:
        AssertionError: if the CSV does not exist, neither CSV nor X1/X2 are given, or the
            ID/X mappings are not one-to-one.
        ValueError: if no complex produced any pose.
    """
    # NOTE(review): yaml.full_load is more permissive than yaml.safe_load; the
    # model_parameters.yml files are assumed to be trusted project files.
    with open(Path(args.score_ckpt).parent / 'model_parameters.yml') as f:
        score_model_args = Namespace(**yaml.full_load(f))
    if args.confidence_ckpt is not None:
        with open(Path(args.confidence_ckpt).parent / 'model_parameters.yml') as f:
            confidence_args = Namespace(**yaml.full_load(f))
    log.info(f"DiffFragDock will run on {device}")

    docking_out_dir = Path(args.out_dir, 'docking')
    docking_out_dir.mkdir(parents=True, exist_ok=True)
    if args.protein_ligand_csv is not None:
        csv_path = Path(args.protein_ligand_csv)
        assert csv_path.is_file(), f"File {args.protein_ligand_csv} does not exist"
        df = pd.read_csv(csv_path)
        df = process_fragment_library(df)
    else:
        assert args.X1 is not None and args.X2 is not None, "Either a .csv file or `X1` and `X2` must be provided."

        compound_df = pd.DataFrame(columns=['X1', 'ID1'])
        if Path(args.X1).is_file():
            compound_path = Path(args.X1)
            if compound_path.suffix in ['.csv', '.sdf']:
                compound_df[['X1', 'ID1']] = read_compound_library(compound_path)[['X1', 'ID1']]
            else:
                compound_df['X1'] = [compound_path]
                compound_df['ID1'] = [compound_path.stem]
        else:
            compound_df['X1'] = [args.X1]
            compound_df['ID1'] = 'compound_0'
        compound_df.dropna(subset=['X1'], inplace=True)
        compound_df.loc[compound_df['ID1'].isna(), 'ID1'] = compound_df.loc[compound_df['ID1'].isna(), 'X1'].apply(
            lambda x: Chem.MolToInchiKey(Chem.MolFromSmiles(x))
        )

        protein_df = pd.DataFrame(columns=['X2', 'ID2'])
        if Path(args.X2).is_file():
            protein_path = Path(args.X2)
            if protein_path.suffix in ['.csv', '.fasta']:
                protein_df[['X2', 'ID2']] = read_protein_library(protein_path)[['X2', 'ID2']]
            else:
                protein_df['protein_path'] = [protein_path]
                protein_df['ID2'] = [protein_path.stem]
        else:
            protein_df['X2'] = [args.X2]
            protein_df['ID2'] = 'protein_0'
        # Drop proteins that have neither a sequence nor a structure file. Dropping on
        # `X2` alone would discard a protein given as a structure file, because that
        # branch above only fills `protein_path` and leaves `X2` empty.
        # NOTE(review): presumably InferenceDataset consumes `protein_path` -- confirm.
        has_protein_input = protein_df['X2'].notna()
        if 'protein_path' in protein_df.columns:
            has_protein_input |= protein_df['protein_path'].notna()
        protein_df = protein_df[has_protein_input].copy()
        protein_df.loc[protein_df['ID2'].isna(), 'ID2'] = [
            f"protein_{i}" for i in range(protein_df['ID2'].isna().sum())
        ]

        compound_df = process_fragment_library(compound_df)
        df = compound_df.merge(protein_df, how='cross')

    # Identify duplicates based on 'X1' and 'X2'
    duplicates = df[df.duplicated(subset=['X1', 'X2'], keep=False)]
    if not duplicates.empty:
        print("Duplicate rows based on columns 'X1' and 'X2':\n", duplicates[['ID1', 'X1', 'ID2', 'X2']])
        print("Keeping the first occurrence of each duplicate.")
        df = df.drop_duplicates(subset=['X1', 'X2'])
    df['name'] = df['ID2'] + '-' + df['ID1']

    df = df.replace({pd.NA: None})
    # Check unique mappings between IDn and Xn
    assert check_one_to_one(df, 'ID1', 'X1'), "ID1-X1 mapping is not one-to-one."
    assert check_one_to_one(df, 'ID2', 'X2'), "ID2-X2 mapping is not one-to-one."

    # ---- Docking phase ----

    # preprocessing of complexes into geometric graphs
    test_dataset = InferenceDataset(
        df=df, out_dir=args.out_dir,
        lm_embeddings=True,
        receptor_radius=score_model_args.receptor_radius,
        remove_hs=True,  # score_model_args.remove_hs,
        c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors,
        all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius,
        atom_max_neighbors=score_model_args.atom_max_neighbors,
        knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph')
        else not score_model_args.not_knn_only_graph
    )
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    if args.confidence_ckpt is not None and not confidence_args.use_original_model_cache:
        log.info('Confidence model uses different type of graphs than the score model. '
                 'Loading (or creating if not existing) the data for the confidence model now.')
        confidence_test_dataset = InferenceDataset(
            df=df, out_dir=args.out_dir,
            lm_embeddings=True,
            receptor_radius=confidence_args.receptor_radius,
            remove_hs=True,  # confidence_args.remove_hs,
            c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors,
            all_atoms=confidence_args.all_atoms,
            atom_radius=confidence_args.atom_radius,
            atom_max_neighbors=confidence_args.atom_max_neighbors,
            precomputed_lm_embeddings=test_dataset.lm_embeddings,
            knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph')
            else not score_model_args.not_knn_only_graph
        )
    else:
        confidence_test_dataset = None

    t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)

    model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, old=args.old_score_model)
    state_dict = torch.load(Path(args.score_ckpt), map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    if args.confidence_ckpt is not None:
        confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True,
                                     confidence_mode=True, old=args.old_confidence_model)
        state_dict = torch.load(Path(args.confidence_ckpt), map_location='cpu', weights_only=True)
        confidence_model.load_state_dict(state_dict, strict=True)
        confidence_model = confidence_model.to(device)
        confidence_model.eval()
    else:
        confidence_model = None
        confidence_args = None

    # The same schedule is used for translation, rotation and torsion below.
    tr_schedule = get_t_schedule(inference_steps=args.inference_steps, sigma_schedule='expbeta')

    failures, skipped = 0, 0
    samples_per_complex = args.samples_per_complex
    test_ds_size = len(test_dataset)
    # From here on use the dataset's own table, which is what the loader iterates over.
    df = test_loader.dataset.df
    docking_dfs = []
    log.info(f'Size of fragment dataset: {test_ds_size}')
    for idx, orig_complex_graph in tqdm(enumerate(test_loader), total=test_ds_size):
        if not orig_complex_graph.success[0]:
            skipped += 1
            log.warning(
                f"The test dataset did not contain {df['name'].iloc[idx]}"
                f" for {df['X1'].iloc[idx]} and {df['X2'].iloc[idx]}. We are skipping this complex.")
            continue
        try:
            if confidence_test_dataset is not None:
                confidence_complex_graph = confidence_test_dataset[idx]
                if not confidence_complex_graph.success:
                    skipped += 1
                    log.warning(
                        f"The confidence dataset did not contain {orig_complex_graph.name}. We are skipping this complex.")
                    continue
                confidence_data_list = [copy.deepcopy(confidence_complex_graph) for _ in range(samples_per_complex)]
            else:
                confidence_data_list = None
            data_list = [copy.deepcopy(orig_complex_graph) for _ in range(samples_per_complex)]
            randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max,
                               initial_noise_std_proportion=args.initial_noise_std_proportion,
                               choose_residue=args.choose_residue)

            lig = orig_complex_graph.mol[0]

            # initialize visualisation
            if args.save_visualisation:
                visualization_list = []
                for graph in data_list:
                    pdb = PDBFile(lig)
                    pdb.add(lig, 0, 0)
                    pdb.add((orig_complex_graph['ligand'].pos + orig_complex_graph.original_center).detach().cpu(), 1,
                            0)
                    pdb.add((graph['ligand'].pos + graph.original_center).detach().cpu(), part=1, order=1)
                    visualization_list.append(pdb)
            else:
                visualization_list = None

            # run reverse diffusion
            data_list, confidence = sampling(data_list=data_list, model=model,
                                             inference_steps=args.actual_steps if args.actual_steps is not None
                                             else args.inference_steps,
                                             tr_schedule=tr_schedule, rot_schedule=tr_schedule,
                                             tor_schedule=tr_schedule,
                                             device=device, t_to_sigma=t_to_sigma, model_args=score_model_args,
                                             visualization_list=visualization_list, confidence_model=confidence_model,
                                             confidence_data_list=confidence_data_list,
                                             confidence_model_args=confidence_args,
                                             batch_size=args.n_poses, no_final_step_noise=args.no_final_step_noise,
                                             temp_sampling=[args.temp_sampling_tr, args.temp_sampling_rot,
                                                            args.temp_sampling_tor],
                                             temp_psi=[args.temp_psi_tr, args.temp_psi_rot, args.temp_psi_tor],
                                             temp_sigma_data=[args.temp_sigma_data_tr, args.temp_sigma_data_rot,
                                                              args.temp_sigma_data_tor])

            # Shift sampled ligand coordinates back by the complex's original center.
            ligand_pos = np.asarray(
                [complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for
                 complex_graph in data_list]
            )

            # save predictions
            # NOTE(review): this assumes `sampling` returned a confidence tensor; if it
            # returns None without a confidence model, the complex lands in `except`.
            n_samples = len(confidence)
            sample_df = pd.DataFrame([df.iloc[idx]] * n_samples)
            confidence = confidence[:, 0].cpu().numpy()
            sample_df['confidence'] = confidence
            if args.save_docking:
                sample_df['path'] = [
                    Path(
                        docking_out_dir, f"{df['name'].iloc[idx]}-confidence{confidence[i]:.2f}.sdf"
                    ) for i in range(n_samples)
                ]
            sample_df['ligand_mol'] = [
                create_mol_with_coords(
                    mol=RemoveAllHs(copy.deepcopy(lig)),
                    new_coords=pos,
                    path=sample_df['path'].iloc[i] if args.save_docking else None
                ) for i, pos in enumerate(ligand_pos)
            ]
            docking_dfs.append(sample_df)

        except Exception as e:
            # Deliberate best-effort: one failing complex must not abort the run.
            # Build the message explicitly -- passing the values as extra positional
            # arguments without %-placeholders makes logging fail to format the record.
            log.warning(f"Failed on {orig_complex_graph['name']}: {e}")
            failures += 1

    # Tear down DiffDock models and datasets
    model.cpu()
    del model
    if confidence_model is not None:
        confidence_model.cpu()
        del confidence_model
    del test_dataset
    if confidence_test_dataset is not None:
        del confidence_test_dataset
    del test_loader

    if not docking_dfs:
        # pd.concat on an empty list raises a ValueError with an unhelpful message;
        # raise the same exception type with the actual cause.
        raise ValueError(
            f"Docking produced no poses: {failures} failed and {skipped} skipped "
            f"out of {test_ds_size} complexes."
        )
    docking_df = pd.concat(docking_dfs, ignore_index=True)
    result_msg = f"""
    Failed for {failures} / {test_ds_size} complexes.
    Skipped {skipped} / {test_ds_size} complexes.
    """
    if failures or skipped:
        log.warning(result_msg)
    else:
        log.info(result_msg)
    log.info(f"Results saved in {docking_out_dir}")

    return docking_df
|
411 |
-
|
412 |
-
|
413 |
-
def calculate_mol_atomic_distances(mol1, mol2, distance_type='min'):
    """Compute distances between the atoms of two molecules.

    All atoms returned by each molecule's conformer are used (no element filtering).

    Args:
        mol1, mol2: molecules exposing ``GetConformer().GetPositions()``; each is expected
            to give an ``(n_atoms, 3)`` coordinate array.
        distance_type: ``'min'`` for the smallest pairwise distance, ``'mean'`` for the
            average pairwise distance, ``None`` for the full ``(n1, n2)`` distance matrix.

    Returns:
        A scalar for ``'min'``/``'mean'``, or the pairwise distance matrix for ``None``.

    Raises:
        ValueError: if ``distance_type`` is not one of the supported values.
    """
    # Reject bad arguments before doing the O(n1 * n2) work.
    if distance_type not in ('min', 'mean', None):
        raise ValueError(f"Unsupported distance_type: {distance_type}")

    # Fetch each conformer once and read all coordinates in a single call, the
    # same way get_pocket reads ligand coordinates.
    mol1_coords = np.asarray(mol1.GetConformer().GetPositions(), dtype=float)
    mol2_coords = np.asarray(mol2.GetConformer().GetPositions(), dtype=float)

    # Broadcast to (n1, n2, 3) differences, then take the norm over the xyz axis.
    atom_pairwise_distances = np.linalg.norm(mol1_coords[:, None, :] - mol2_coords[None, :, :], axis=-1)

    if distance_type == 'min':
        return atom_pairwise_distances.min()
    if distance_type == 'mean':
        return atom_pairwise_distances.mean()
    return atom_pairwise_distances
|
437 |
-
|
438 |
-
|
439 |
-
def process_docking_results(
        df,
        eps=5,  # Distance threshold for DBSCAN clustering
        min_samples=5,  # Minimum number of samples for a cluster (enrichment)
        frag_dist_range=(2, 5),  # Distance range for fragment linking
        distance_type='min',  # Type of distance to compute between fragments
):
    """Cluster docked poses per protein and build fragment pairs for the linking step.

    For every protein (rows grouped by ``X2``):
      1. compute the pose-to-pose distance matrix with ``calculate_mol_atomic_distances``;
      2. cluster the poses with DBSCAN on that precomputed matrix;
      3. drop outliers (cluster ``-1``) and keep the highest-confidence pose per
         (fragment, cluster);
      4. within each cluster, pair up poses whose distance lies strictly inside
         ``frag_dist_range``.

    Reads the module-level ``args`` (``save_docking``, ``out_dir``) and ``log``, which are
    defined outside this block. Writes ``linking_summary.csv`` and, when
    ``args.save_docking`` is set, ``docking_summary.csv`` into ``args.out_dir``.

    Args:
        df: pose table with columns ``X2``, ``ID2``, ``protein_path``, ``X1``, ``ID1``,
            ``name``, ``confidence``, ``ligand_mol`` (and ``path`` when poses were saved).
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN minimum samples per cluster.
        frag_dist_range: two distances; pairs are kept when ``low < distance < high``.
        distance_type: forwarded to ``calculate_mol_atomic_distances``.

    Returns:
        DataFrame with columns ``name``, ``X2``, ``protein_path``, ``X1``, ``ligand_path``,
        ``ligand_mol``, ``distance``; one row per fragment pair.

    Raises:
        AssertionError: if ``frag_dist_range`` does not have exactly two values.
        ValueError: if no fragment pair satisfies the distance criterion.
    """
    assert len(frag_dist_range) == 2, 'Distance range must be a tuple of two values in Angstroms (Å).'
    frag_dist_range = sorted(frag_dist_range)

    # The mols in df should have been processed to have no explicit hydrogens, except heavy hydrogen isotopes.
    docking_summaries = []  # For saving intermediate docking results
    fragment_combos = []  # Fragment pairs for the linking step
    # 1. Cluster docking poses
    for protein, protein_df in df.groupby('X2'):
        # Work on an independent, 0..n-1 indexed copy of the group. The 'index' column
        # is used below to address `dist_matrix`, which is positional within this
        # protein; the labels of the full table would be wrong (or out of range) for
        # every protein after the first.
        protein_df = protein_df.reset_index(drop=True)
        protein_id = protein_df['ID2'].iloc[0]
        protein_path = protein_df['protein_path'].iloc[0]
        protein_df['index'] = protein_df.index

        log.info(f'Processing docking results for {protein_id}...')
        protein_fragment_combos = []
        dist_matrix = np.stack(
            protein_df['ligand_mol'].parallel_apply(
                lambda mol1: [
                    calculate_mol_atomic_distances(mol1, mol2, distance_type=distance_type)
                    for mol2 in protein_df['ligand_mol']
                ]
            )
        )

        # Perform DBSCAN clustering
        dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed')
        protein_df['cluster'] = dbscan.fit_predict(dist_matrix)
        protein_df = protein_df.sort_values(
            by=['X1', 'cluster', 'confidence'], ascending=[True, True, False]
        )
        # Add conformer number to ID1. The result has to be assigned back; computing it
        # without assignment leaves ID1 unchanged, so the same two fragments paired in
        # different clusters would produce identical complex names.
        protein_df['ID1'] = protein_df.groupby('ID1').cumcount().astype(str).radd(protein_df['ID1'] + '_')
        if args.save_docking:
            docking_summaries.append(
                protein_df[['name', 'ID2', 'X2', 'ID1', 'X1', 'cluster', 'confidence', 'path']]
            )
        # Filter out outlier poses
        protein_df = protein_df[protein_df['cluster'] != -1]
        # Keep only the highest confidence pose per protein per ligand per cluster
        protein_df = protein_df.groupby(['X1', 'cluster']).first().reset_index()
        # 'path' only exists when the docking step wrote pose files.
        has_path = 'path' in protein_df.columns

        # 2. Create fragment-linking pairs
        for cluster, cluster_df in protein_df.groupby('cluster'):
            if len(cluster_df) < 2:  # Skip clusters with only one pose
                continue
            poses = [row for _, row in cluster_df.iterrows()]
            for row1, row2 in itertools.combinations(poses, 2):
                dist = dist_matrix[row1['index'], row2['index']]
                # Check if intermolecular distance is within the range
                if not frag_dist_range[0] < dist < frag_dist_range[1]:
                    continue
                combined_smiles = f"{row1['X1']}.{row2['X1']}"
                combined_mol = Chem.CombineMols(row1['ligand_mol'], row2['ligand_mol'])
                complex_name = f"{protein_id}-{row1['ID1']}-{row2['ID1']}"
                ligand_path = f"{row1['path']},{row2['path']}" if has_path else None
                protein_fragment_combos.append(
                    (complex_name, protein, protein_path, combined_smiles, ligand_path, combined_mol, dist)
                )
        log.info(f'Number of fragment pairs for {protein_id}: {len(protein_fragment_combos)}.')
        fragment_combos.extend(protein_fragment_combos)

    # Save intermediate docking results (nothing to concatenate when df had no groups)
    if args.save_docking and docking_summaries:
        docking_summary_df = pd.concat(docking_summaries, ignore_index=True)
        docking_summary_df.to_csv(Path(args.out_dir, 'docking_summary.csv'), index=False)
        log.info(f'Saved intermediate docking results to {args.out_dir}')

    # Convert fragment pair results to DataFrame
    if not fragment_combos:
        raise ValueError('No eligible fragment pairs found for linking.')

    linking_df = pd.DataFrame(
        fragment_combos, columns=['name', 'X2', 'protein_path', 'X1', 'ligand_path', 'ligand_mol', 'distance']
    )
    # The molecule objects are kept in the returned frame but left out of the CSV.
    linking_df[
        ['name', 'X2', 'protein_path', 'X1', 'ligand_path', 'distance']
    ].to_csv(Path(args.out_dir, 'linking_summary.csv'), index=False)
    return linking_df
|
524 |
-
|
525 |
-
def get_pocket(mol, pdb_path, backbone_atoms_only=False):
    """Extract the protein atoms of the residues that surround a ligand.

    A residue number counts as a contact when any atom carrying that number lies within
    6 (in the structure's coordinate units) of any ligand atom. Residues are identified
    by their sequence number only (``residue.get_id()[1]``).

    Args:
        mol: ligand exposing ``GetConformer().GetPositions()``.
        pdb_path: structure file handed to ``PDBParser().get_structure``.
        backbone_atoms_only: when true, keep only atoms named N, CA, C or O.

    Returns:
        Tuple ``(pocket_pos, pocket_one_hot, pocket_charges)`` of numpy arrays covering the
        kept atoms whose element is a key of ``const.GEOM_ATOM2IDX``.
    """
    structure = PDBParser().get_structure('', pdb_path)

    # Pass 1: flatten the structure into parallel per-atom lists.
    per_atom_resid = []
    per_atom_xyz = []
    for res in structure.get_residues():
        number = res.get_id()[1]
        for atm in res.get_atoms():
            per_atom_xyz.append(atm.get_coord())
            per_atom_resid.append(number)
    per_atom_resid = np.array(per_atom_resid)
    per_atom_xyz = np.array(per_atom_xyz)

    # Distance of every protein atom to every ligand atom; an atom is "close" when
    # its nearest ligand atom is within the cutoff.
    ligand_xyz = mol.GetConformer().GetPositions()
    protein_to_ligand = np.linalg.norm(per_atom_xyz[:, None, :] - ligand_xyz[None, :, :], axis=-1)
    close_atoms = np.where(protein_to_ligand.min(1) <= 6)[0]
    contact_residues = np.unique(per_atom_resid[close_atoms])

    # Pass 2: collect (coordinate, element) for the atoms of the contact residues.
    backbone_names = {'N', 'CA', 'C', 'O'}
    selected = []
    for res in structure.get_residues():
        if res.get_id()[1] not in contact_residues:
            continue
        for atm in res.get_atoms():
            name = atm.get_name()
            element = atm.element.upper()
            xyz = atm.get_coord()
            if backbone_atoms_only and name not in backbone_names:
                continue
            selected.append((xyz.tolist(), element))

    # Keep only the elements the model has an index for, then featurise.
    kept = [(xyz, element) for xyz, element in selected if element in const.GEOM_ATOM2IDX.keys()]
    pocket_pos = np.array([xyz for xyz, _ in kept])
    pocket_one_hot = np.array([get_one_hot(element, const.GEOM_ATOM2IDX) for _, element in kept])
    pocket_charges = np.array([const.GEOM_CHARGES[element] for _, element in kept])

    return pocket_pos, pocket_one_hot, pocket_charges
|
579 |
-
|
580 |
-
|
581 |
-
def generate_linker(
|
582 |
-
df, backbone_atoms_only, model,
|
583 |
-
output_dir, n_samples, n_steps, linker_size, anchors, max_batch_size, random_seed
|
584 |
-
):
|
585 |
-
# Setup
|
586 |
-
if random_seed is not None:
|
587 |
-
set_deterministic(random_seed)
|
588 |
-
output_dir = Path(output_dir, 'linking')
|
589 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
590 |
-
|
591 |
-
if linker_size.isdigit():
|
592 |
-
print(f'Will generate linkers with {linker_size} atoms')
|
593 |
-
linker_size = int(linker_size)
|
594 |
-
|
595 |
-
def sample_fn(_data):
|
596 |
-
return torch.ones(_data['positions'].shape[0], device=device, dtype=const.TORCH_INT) * linker_size
|
597 |
-
|
598 |
-
else:
|
599 |
-
boundaries = [x.strip() for x in linker_size.split(',')]
|
600 |
-
if len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit():
|
601 |
-
left = int(boundaries[0])
|
602 |
-
right = int(boundaries[1])
|
603 |
-
print(f'Will generate linkers with numbers of atoms sampled from U({left}, {right})')
|
604 |
-
|
605 |
-
def sample_fn(_data):
|
606 |
-
shape = len(_data['positions']),
|
607 |
-
return torch.randint(left, right + 1, shape, device=device, dtype=const.TORCH_INT)
|
608 |
-
|
609 |
-
else:
|
610 |
-
print(f'Will generate linkers with sampled numbers of atoms')
|
611 |
-
size_classifier = SizeClassifier.load_from_checkpoint(linker_size, map_location=device).eval().to(device)
|
612 |
-
|
613 |
-
def sample_fn(_data):
|
614 |
-
out, _ = size_classifier.forward(_data, return_loss=False, with_pocket=True, adjust_shape=True)
|
615 |
-
probabilities = torch.softmax(out, dim=1)
|
616 |
-
distribution = torch.distributions.Categorical(probs=probabilities)
|
617 |
-
samples = distribution.sample()
|
618 |
-
sizes = []
|
619 |
-
for label in samples.detach().cpu().numpy():
|
620 |
-
sizes.append(size_classifier.linker_id2size[label])
|
621 |
-
sizes = torch.tensor(sizes, device=samples.device, dtype=const.TORCH_INT)
|
622 |
-
return sizes
|
623 |
-
|
624 |
-
if n_steps is not None:
|
625 |
-
model.edm.T = n_steps
|
626 |
-
|
627 |
-
if model.center_of_mass == 'anchors' and anchors is None:
|
628 |
-
print(
|
629 |
-
'Please pass anchor atoms indices '
|
630 |
-
'or use another DiffLinker model that does not require information about anchors'
|
631 |
-
)
|
632 |
-
return
|
633 |
-
|
634 |
-
cached_parse_molecule = cache(parse_molecule)
|
635 |
-
dataset = []
|
636 |
-
for i, row in df.iterrows():
|
637 |
-
mol = row['ligand_mol'] # Hs already removed
|
638 |
-
# Parsing fragments data
|
639 |
-
frag_pos, frag_one_hot, frag_charges = cached_parse_molecule(mol, is_geom=ddpm.is_geom)
|
640 |
-
# Parsing pocket data
|
641 |
-
pocket_pos, pocket_one_hot, pocket_charges = get_pocket(mol, row['protein_path'], backbone_atoms_only)
|
642 |
-
|
643 |
-
positions = np.concatenate([frag_pos, pocket_pos], axis=0)
|
644 |
-
one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0)
|
645 |
-
charges = np.concatenate([frag_charges, pocket_charges], axis=0)
|
646 |
-
anchor_flags = np.zeros_like(charges)
|
647 |
-
if anchors is not None:
|
648 |
-
for anchor in anchors.split(','):
|
649 |
-
anchor_flags[int(anchor.strip()) - 1] = 1
|
650 |
-
|
651 |
-
fragment_only_mask = np.concatenate([
|
652 |
-
np.ones_like(frag_charges),
|
653 |
-
np.zeros_like(pocket_charges),
|
654 |
-
])
|
655 |
-
pocket_mask = np.concatenate([
|
656 |
-
np.zeros_like(frag_charges),
|
657 |
-
np.ones_like(pocket_charges),
|
658 |
-
])
|
659 |
-
linker_mask = np.concatenate([
|
660 |
-
np.zeros_like(frag_charges),
|
661 |
-
np.zeros_like(pocket_charges),
|
662 |
-
])
|
663 |
-
fragment_mask = np.concatenate([
|
664 |
-
np.ones_like(frag_charges),
|
665 |
-
np.ones_like(pocket_charges),
|
666 |
-
])
|
667 |
-
|
668 |
-
dataset.extend([{
|
669 |
-
'name': row['name'],
|
670 |
-
'X1': row['X1'],
|
671 |
-
'X2': row['X2'],
|
672 |
-
'protein_path': row['protein_path'],
|
673 |
-
'ligand_path': row['ligand_path'],
|
674 |
-
'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
|
675 |
-
'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
|
676 |
-
'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
|
677 |
-
'anchors': torch.tensor(anchor_flags, dtype=const.TORCH_FLOAT, device=device),
|
678 |
-
'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
|
679 |
-
'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
|
680 |
-
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
681 |
-
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
682 |
-
'num_atoms': len(positions)
|
683 |
-
}] * n_samples)
|
684 |
-
|
685 |
-
dataset = ProteinConditionedDataset(data=dataset)
|
686 |
-
ddpm.val_dataset = dataset
|
687 |
-
|
688 |
-
global_batch_size = min(n_samples, max_batch_size)
|
689 |
-
dataloader = get_dataloader(
|
690 |
-
dataset, batch_size=global_batch_size, collate_fn=collate_with_fragment_without_pocket_edges
|
691 |
-
)
|
692 |
-
|
693 |
-
# df.drop(columns=['ligand_mol', 'protein_path'], inplace=True)
|
694 |
-
linking_dfs = []
|
695 |
-
# Sampling
|
696 |
-
print('Sampling...')
|
697 |
-
# TODO: update linking_summary.csv per batch
|
698 |
-
for batch_i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
|
699 |
-
effective_batch_size = len(data['positions'])
|
700 |
-
complex_name = data['name'][0]
|
701 |
-
batch_df = pd.DataFrame({
|
702 |
-
'name': data['name'],
|
703 |
-
'X1': data['X1'],
|
704 |
-
'X2': data['X2'],
|
705 |
-
'protein_path': data['protein_path'],
|
706 |
-
'ligand_path': data['ligand_path'],
|
707 |
-
})
|
708 |
-
chain = None
|
709 |
-
node_mask = None
|
710 |
-
for i in range(5):
|
711 |
-
try:
|
712 |
-
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
|
713 |
-
break
|
714 |
-
except FoundNaNException:
|
715 |
-
continue
|
716 |
-
if chain is None:
|
717 |
-
log.warning(f'Could not generate linker for {complex_name} in 5 attempts')
|
718 |
-
continue
|
719 |
-
|
720 |
-
x = chain[0][:, :, :ddpm.n_dims]
|
721 |
-
h = chain[0][:, :, ddpm.n_dims:]
|
722 |
-
|
723 |
-
# Put the molecule back to the initial orientation
|
724 |
-
com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
|
725 |
-
pos_masked = data['positions'] * com_mask
|
726 |
-
N = com_mask.sum(1, keepdims=True)
|
727 |
-
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
|
728 |
-
x = x + mean * node_mask
|
729 |
-
node_mask[torch.where(data['pocket_mask'])] = 0
|
730 |
-
|
731 |
-
batch_df['out_path'] = [Path(output_dir, f'{complex_name}_{i}.sdf') for i in range(effective_batch_size)]
|
732 |
-
batch_df['one_hot'] = list(h.cpu())
|
733 |
-
batch_df['positions'] = list(x.cpu())
|
734 |
-
batch_df['node_mask'] = list(node_mask.cpu())
|
735 |
-
batch_df['X1^'] = batch_df.parallel_apply(
|
736 |
-
lambda row: save_sdf(
|
737 |
-
row['out_path'], row['one_hot'], row['positions'], row['node_mask'], is_geom=ddpm.is_geom
|
738 |
-
), axis=1
|
739 |
-
)
|
740 |
-
linking_dfs.append(batch_df[['name', 'protein_path', 'X2', 'ligand_path', 'X1', 'X1^', 'out_path']])
|
741 |
-
# for i in range(effective_batch_size):
|
742 |
-
# # # Save XYZ file and generate SMILES
|
743 |
-
# # out_xyz = Path(output_dir, f'{name}_{offset_idx+i}.xyz')
|
744 |
-
# # smiles = save_xyz_files(out_xyz, h[i], x[i], node_mask[i], is_geom=ddpm.is_geom)
|
745 |
-
# # # Convert XYZ to SDF
|
746 |
-
# # out_sdf = Path(output_dir, name, f'output_{offset_idx+i}.sdf')
|
747 |
-
# # with open(os.devnull, 'w') as devnull:
|
748 |
-
# # subprocess.run(f'obabel {out_xyz} -O {out_sdf} -q', shell=True, stdout=devnull)
|
749 |
-
# # Save SDF file and generate SMILES
|
750 |
-
# out_sdf = Path(output_dir, f'{data["name"][i]}.sdf')
|
751 |
-
# smiles = save_sdf(out_sdf, h[i], x[i], node_mask[i], is_geom=ddpm.is_geom)
|
752 |
-
#
|
753 |
-
# # Add experiment summary info
|
754 |
-
# batch_df['X1^'] = smiles
|
755 |
-
# batch_df['out_path'] = str(out_sdf)
|
756 |
-
# linking_dfs.append(batch_df)
|
757 |
-
|
758 |
-
if linking_dfs:
|
759 |
-
linking_summary_df = pd.concat(linking_dfs, ignore_index=True)
|
760 |
-
linking_summary_df.to_csv(Path(output_dir.parent, 'linking_summary.csv'), index=False)
|
761 |
-
print(f'Saved experiment summary and generated molecules to {output_dir}')
|
762 |
-
else:
|
763 |
-
raise ValueError('No linkers generated.')
|
764 |
-
|
765 |
-
|
766 |
-
if __name__ == "__main__":
|
767 |
-
parser = ArgumentParser()
|
768 |
-
# Fragment docking settings
|
769 |
-
parser.add_argument('--config', type=FileType(mode='r'), default='default_inference_args.yaml')
|
770 |
-
parser.add_argument('--protein_ligand_csv', type=str, default=None,
|
771 |
-
help='Path to a .csv file specifying the input as described in the README. '
|
772 |
-
'If this is not None, it will be used instead of the `X1` and `X2` parameters')
|
773 |
-
parser.add_argument('-n', '--name', type=str, default=None,
|
774 |
-
help='Name that the experiment will be saved with')
|
775 |
-
parser.add_argument('--X1', type=str,
|
776 |
-
help='Either a SMILES string or the path of a molecule file that rdkit can read')
|
777 |
-
parser.add_argument('--X2', type=str,
|
778 |
-
help='Either a FASTA sequence or the path of a protein for ESMFold')
|
779 |
-
|
780 |
-
parser.add_argument('-l', '--log', '--loglevel', type=str, default='INFO', dest="loglevel",
|
781 |
-
help='Log level. Default %(default)s')
|
782 |
-
|
783 |
-
parser.add_argument('--out_dir', type=str, default='results/',
|
784 |
-
help='Directory where the outputs will be written to')
|
785 |
-
parser.add_argument('--save_docking', action='store_true', default=True,
|
786 |
-
help='Save the intermediate docking results including SDF files and a summary CSV.')
|
787 |
-
parser.add_argument('--save_visualisation', action='store_true', default=False,
|
788 |
-
help='Save a pdb file with all of the steps of the reverse diffusion')
|
789 |
-
parser.add_argument('--samples_per_complex', type=int, default=10,
|
790 |
-
help='Number of samples to generate')
|
791 |
-
|
792 |
-
# parser.add_argument('--model_dir', type=str, default=None,
|
793 |
-
# help='Path to folder with trained score model and hyperparameters')
|
794 |
-
parser.add_argument('--score_ckpt', type=str, default='best_ema_inference_epoch_model.pt',
|
795 |
-
help='Checkpoint to use for the score model')
|
796 |
-
# parser.add_argument('--confidence_model_dir', type=str, default=None,
|
797 |
-
# help='Path to folder with trained confidence model and hyperparameters')
|
798 |
-
parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt',
|
799 |
-
help='Checkpoint to use for the confidence model')
|
800 |
-
|
801 |
-
parser.add_argument('--n_poses', type=int, default=10, help='')
|
802 |
-
parser.add_argument('--no_final_step_noise', action='store_true', default=True,
|
803 |
-
help='Use no noise in the final step of the reverse diffusion')
|
804 |
-
parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps')
|
805 |
-
parser.add_argument('--actual_steps', type=int, default=None,
|
806 |
-
help='Number of denoising steps that are actually performed')
|
807 |
-
|
808 |
-
parser.add_argument('--old_score_model', action='store_true', default=False, help='')
|
809 |
-
parser.add_argument('--old_confidence_model', action='store_true', default=True, help='')
|
810 |
-
parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0,
|
811 |
-
help='Initial noise std proportion')
|
812 |
-
parser.add_argument('--choose_residue', action='store_true', default=False, help='')
|
813 |
-
|
814 |
-
parser.add_argument('--temp_sampling_tr', type=float, default=1.0)
|
815 |
-
parser.add_argument('--temp_psi_tr', type=float, default=0.0)
|
816 |
-
parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5)
|
817 |
-
parser.add_argument('--temp_sampling_rot', type=float, default=1.0)
|
818 |
-
parser.add_argument('--temp_psi_rot', type=float, default=0.0)
|
819 |
-
parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5)
|
820 |
-
parser.add_argument('--temp_sampling_tor', type=float, default=1.0)
|
821 |
-
parser.add_argument('--temp_psi_tor', type=float, default=0.0)
|
822 |
-
parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5)
|
823 |
-
|
824 |
-
parser.add_argument('--gnina_minimize', action='store_true', default=False, help='')
|
825 |
-
parser.add_argument('--gnina_path', type=str, default='gnina', help='')
|
826 |
-
parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt',
|
827 |
-
help='') # To redirect gnina subprocesses stdouts from the terminal window
|
828 |
-
parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='')
|
829 |
-
parser.add_argument('--gnina_autobox_add', type=float, default=4.0)
|
830 |
-
parser.add_argument('--gnina_poses_to_optimize', type=int, default=1)
|
831 |
-
|
832 |
-
# Linker generation settings
|
833 |
-
# parser.add_argument('--fragments', action='store', type=str, required=True,
|
834 |
-
# help='Path to the file with input fragments'
|
835 |
-
# )
|
836 |
-
# parser.add_argument(
|
837 |
-
# '--protein', action='store', type=str, required=True,
|
838 |
-
# help='Path to the file with the target protein'
|
839 |
-
# )
|
840 |
-
parser.add_argument(
|
841 |
-
'--backbone_atoms_only', action='store_true', required=False, default=False,
|
842 |
-
help='Flag if to use only protein backbone atoms'
|
843 |
-
)
|
844 |
-
parser.add_argument(
|
845 |
-
'--linker_ckpt', action='store', type=str,
|
846 |
-
help='Path to the DiffLinker model'
|
847 |
-
)
|
848 |
-
parser.add_argument(
|
849 |
-
'--linker_size', action='store', type=str,
|
850 |
-
help='Linker size (int) or allowed size boundaries (comma-separated) or path to the size prediction model'
|
851 |
-
)
|
852 |
-
parser.add_argument(
|
853 |
-
'--n_linkers', action='store', type=int, required=False, default=5,
|
854 |
-
help='Number of linkers to generate'
|
855 |
-
)
|
856 |
-
parser.add_argument(
|
857 |
-
'--n_steps', action='store', type=int, required=False, default=1000,
|
858 |
-
help='Number of denoising steps'
|
859 |
-
)
|
860 |
-
parser.add_argument(
|
861 |
-
'--anchors', action='store', type=str, required=False, default=None,
|
862 |
-
help='Comma-separated indices of anchor atoms '
|
863 |
-
'(according to the order of atoms in the input fragments file, enumeration starts with 1)'
|
864 |
-
)
|
865 |
-
parser.add_argument(
|
866 |
-
'--max_batch_size', action='store', type=int, required=False, default=16,
|
867 |
-
help='Max batch size'
|
868 |
-
)
|
869 |
-
parser.add_argument(
|
870 |
-
'--random_seed', action='store', type=int, required=False, default=None,
|
871 |
-
help='Random seed'
|
872 |
-
)
|
873 |
-
parser.add_argument(
|
874 |
-
'--robust', action='store_true', required=False, default=False,
|
875 |
-
help='Robust sampling modification'
|
876 |
-
)
|
877 |
-
parser.add_argument(
|
878 |
-
'--dock', action='store_true', default=False,
|
879 |
-
help='Fragment docking with DiffDock'
|
880 |
-
)
|
881 |
-
parser.add_argument(
|
882 |
-
'--link', action='store_true', default=False,
|
883 |
-
help='Linker generation with DiffLinker'
|
884 |
-
)
|
885 |
-
|
886 |
-
args = parser.parse_args()
|
887 |
-
if args.config:
|
888 |
-
config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
|
889 |
-
arg_dict = args.__dict__
|
890 |
-
for key, value in config_dict.items():
|
891 |
-
if isinstance(value, list):
|
892 |
-
for v in value:
|
893 |
-
arg_dict[key].append(v)
|
894 |
-
else:
|
895 |
-
arg_dict[key] = value
|
896 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
897 |
-
configure_logger(args.loglevel)
|
898 |
-
log = get_logger()
|
899 |
-
|
900 |
-
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
901 |
-
experiment_name = f"{date_time}_{args.name}"
|
902 |
-
args.out_dir = Path(args.out_dir, experiment_name)
|
903 |
-
|
904 |
-
if args.dock:
|
905 |
-
docking_df = dock_fragments(args)
|
906 |
-
linking_df = process_docking_results(
|
907 |
-
docking_df,
|
908 |
-
eps=args.eps, min_samples=args.min_samples,
|
909 |
-
frag_dist_range=args.frag_dist_range, distance_type=args.distance_type
|
910 |
-
)
|
911 |
-
if args.link:
|
912 |
-
ddpm = DDPM.load_from_checkpoint(args.linker_ckpt, map_location=device, robust=args.robust).eval().to(device)
|
913 |
-
generate_linker(
|
914 |
-
linking_df,
|
915 |
-
backbone_atoms_only=args.backbone_atoms_only,
|
916 |
-
model=ddpm,
|
917 |
-
output_dir=args.out_dir,
|
918 |
-
n_samples=args.n_linkers,
|
919 |
-
n_steps=args.n_steps,
|
920 |
-
linker_size=args.linker_size,
|
921 |
-
anchors=args.anchors,
|
922 |
-
max_batch_size=args.max_batch_size,
|
923 |
-
random_seed=args.random_seed,
|
924 |
-
)
|
925 |
-
|
926 |
-
if args.link:
|
927 |
-
linking_df = pd.read_csv(args.protein_ligand_csv)
|
928 |
-
ddpm = DDPM.load_from_checkpoint(args.linker_ckpt, map_location=device, robust=args.robust).eval().to(device)
|
929 |
-
generate_linker(
|
930 |
-
linking_df,
|
931 |
-
backbone_atoms_only=args.backbone_atoms_only,
|
932 |
-
model=ddpm,
|
933 |
-
output_dir=args.out_dir,
|
934 |
-
n_samples=args.n_linkers,
|
935 |
-
n_steps=args.n_steps,
|
936 |
-
linker_size=args.linker_size,
|
937 |
-
anchors=args.anchors,
|
938 |
-
max_batch_size=args.max_batch_size,
|
939 |
-
random_seed=args.random_seed,
|
940 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mean_cluster.yaml
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
linker_ckpt: resources/checkpoints/DiffLinker/pockets_difflinker_full.ckpt
|
5 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
6 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
7 |
-
different_schedules: false
|
8 |
-
inf_sched_alpha: 1
|
9 |
-
inf_sched_beta: 1
|
10 |
-
inference_steps: 20
|
11 |
-
initial_noise_std_proportion: 1.4601642460337794
|
12 |
-
limit_failures: 5
|
13 |
-
#model_dir: ./workdir/v1.1/score_model
|
14 |
-
#comment
|
15 |
-
no_final_step_noise: true
|
16 |
-
no_model: false
|
17 |
-
no_random: false
|
18 |
-
no_random_pocket: false
|
19 |
-
ode: false
|
20 |
-
old_filtering_model: true
|
21 |
-
old_score_model: false
|
22 |
-
resample_rdkit: false
|
23 |
-
samples_per_complex: 10
|
24 |
-
sigma_schedule: expbeta
|
25 |
-
temp_psi_rot: 0.9022615585677628
|
26 |
-
temp_psi_tor: 0.5946212391366862
|
27 |
-
temp_psi_tr: 0.727287304570729
|
28 |
-
temp_sampling_rot: 2.06391612594481
|
29 |
-
temp_sampling_tor: 7.044261621607846
|
30 |
-
temp_sampling_tr: 1.170050527854316
|
31 |
-
temp_sigma_data_rot: 0.7464326999906034
|
32 |
-
temp_sigma_data_tor: 0.6943254174849822
|
33 |
-
temp_sigma_data_tr: 0.9299802531572672
|
34 |
-
|
35 |
-
eps: 5
|
36 |
-
min_samples: 10
|
37 |
-
frag_dist_range: (2, 5)
|
38 |
-
distance_type: min
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
min_cluster.yaml
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
linker_ckpt: resources/checkpoints/DiffLinker/pockets_difflinker_full.ckpt
|
5 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
6 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
7 |
-
different_schedules: false
|
8 |
-
inf_sched_alpha: 1
|
9 |
-
inf_sched_beta: 1
|
10 |
-
inference_steps: 20
|
11 |
-
initial_noise_std_proportion: 1.4601642460337794
|
12 |
-
limit_failures: 5
|
13 |
-
#model_dir: ./workdir/v1.1/score_model
|
14 |
-
#comment
|
15 |
-
no_final_step_noise: true
|
16 |
-
no_model: false
|
17 |
-
no_random: false
|
18 |
-
no_random_pocket: false
|
19 |
-
ode: false
|
20 |
-
old_filtering_model: true
|
21 |
-
old_score_model: false
|
22 |
-
resample_rdkit: false
|
23 |
-
samples_per_complex: 10
|
24 |
-
sigma_schedule: expbeta
|
25 |
-
temp_psi_rot: 0.9022615585677628
|
26 |
-
temp_psi_tor: 0.5946212391366862
|
27 |
-
temp_psi_tr: 0.727287304570729
|
28 |
-
temp_sampling_rot: 2.06391612594481
|
29 |
-
temp_sampling_tor: 7.044261621607846
|
30 |
-
temp_sampling_tr: 1.170050527854316
|
31 |
-
temp_sigma_data_rot: 0.7464326999906034
|
32 |
-
temp_sigma_data_tor: 0.6943254174849822
|
33 |
-
temp_sigma_data_tr: 0.9299802531572672
|
34 |
-
|
35 |
-
eps: 1
|
36 |
-
min_samples: 10
|
37 |
-
frag_dist_range: (5, 8)
|
38 |
-
distance_type: mean
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pocket_backbone.yaml
DELETED
@@ -1,47 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
5 |
-
different_schedules: false
|
6 |
-
inf_sched_alpha: 1
|
7 |
-
inf_sched_beta: 1
|
8 |
-
inference_steps: 20
|
9 |
-
initial_noise_std_proportion: 1.4601642460337794
|
10 |
-
limit_failures: 5
|
11 |
-
#model_dir: ./workdir/v1.1/score_model
|
12 |
-
#comment
|
13 |
-
no_final_step_noise: true
|
14 |
-
no_model: false
|
15 |
-
no_random: false
|
16 |
-
no_random_pocket: false
|
17 |
-
ode: false
|
18 |
-
old_filtering_model: true
|
19 |
-
old_score_model: false
|
20 |
-
resample_rdkit: false
|
21 |
-
samples_per_complex: 10
|
22 |
-
sigma_schedule: expbeta
|
23 |
-
temp_psi_rot: 0.9022615585677628
|
24 |
-
temp_psi_tor: 0.5946212391366862
|
25 |
-
temp_psi_tr: 0.727287304570729
|
26 |
-
temp_sampling_rot: 2.06391612594481
|
27 |
-
temp_sampling_tor: 7.044261621607846
|
28 |
-
temp_sampling_tr: 1.170050527854316
|
29 |
-
temp_sigma_data_rot: 0.7464326999906034
|
30 |
-
temp_sigma_data_tor: 0.6943254174849822
|
31 |
-
temp_sigma_data_tr: 0.9299802531572672
|
32 |
-
|
33 |
-
eps: 5
|
34 |
-
min_samples: 10
|
35 |
-
frag_dist_range:
|
36 |
-
- 2
|
37 |
-
- 5
|
38 |
-
distance_type: min
|
39 |
-
|
40 |
-
linker_ckpt:
|
41 |
-
pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
42 |
-
pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
43 |
-
# resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
|
44 |
-
geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
|
45 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
46 |
-
linker_condition: 'protein' # pocket
|
47 |
-
linker_batch_size: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pocket_full_no_anchors_4a.yaml
DELETED
@@ -1,44 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
5 |
-
different_schedules: false
|
6 |
-
inf_sched_alpha: 1
|
7 |
-
inf_sched_beta: 1
|
8 |
-
inference_steps: 20
|
9 |
-
initial_noise_std_proportion: 1.4601642460337794
|
10 |
-
limit_failures: 5
|
11 |
-
#model_dir: ./workdir/v1.1/score_model
|
12 |
-
#comment
|
13 |
-
no_final_step_noise: true
|
14 |
-
no_model: false
|
15 |
-
no_random: false
|
16 |
-
no_random_pocket: false
|
17 |
-
ode: false
|
18 |
-
old_filtering_model: true
|
19 |
-
old_score_model: false
|
20 |
-
resample_rdkit: false
|
21 |
-
samples_per_complex: 10
|
22 |
-
sigma_schedule: expbeta
|
23 |
-
temp_psi_rot: 0.9022615585677628
|
24 |
-
temp_psi_tor: 0.5946212391366862
|
25 |
-
temp_psi_tr: 0.727287304570729
|
26 |
-
temp_sampling_rot: 2.06391612594481
|
27 |
-
temp_sampling_tor: 7.044261621607846
|
28 |
-
temp_sampling_tr: 1.170050527854316
|
29 |
-
temp_sigma_data_rot: 0.7464326999906034
|
30 |
-
temp_sigma_data_tor: 0.6943254174849822
|
31 |
-
temp_sigma_data_tr: 0.9299802531572672
|
32 |
-
|
33 |
-
eps: 5
|
34 |
-
min_samples: 10
|
35 |
-
frag_dist_range: (2, 5)
|
36 |
-
distance_type: min
|
37 |
-
|
38 |
-
linker_ckpt:
|
39 |
-
# pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
40 |
-
pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors.ckpt
|
41 |
-
pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
|
42 |
-
geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
|
43 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
44 |
-
linker_condition: 'protein' # pocket
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pocket_full_no_anchors_fc_pdb_excluded.yaml
DELETED
@@ -1,47 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
5 |
-
different_schedules: false
|
6 |
-
inf_sched_alpha: 1
|
7 |
-
inf_sched_beta: 1
|
8 |
-
inference_steps: 20
|
9 |
-
initial_noise_std_proportion: 1.4601642460337794
|
10 |
-
limit_failures: 5
|
11 |
-
#model_dir: ./workdir/v1.1/score_model
|
12 |
-
#comment
|
13 |
-
no_final_step_noise: true
|
14 |
-
no_model: false
|
15 |
-
no_random: false
|
16 |
-
no_random_pocket: false
|
17 |
-
ode: false
|
18 |
-
old_filtering_model: true
|
19 |
-
old_score_model: false
|
20 |
-
resample_rdkit: false
|
21 |
-
samples_per_complex: 10
|
22 |
-
sigma_schedule: expbeta
|
23 |
-
temp_psi_rot: 0.9022615585677628
|
24 |
-
temp_psi_tor: 0.5946212391366862
|
25 |
-
temp_psi_tr: 0.727287304570729
|
26 |
-
temp_sampling_rot: 2.06391612594481
|
27 |
-
temp_sampling_tor: 7.044261621607846
|
28 |
-
temp_sampling_tr: 1.170050527854316
|
29 |
-
temp_sigma_data_rot: 0.7464326999906034
|
30 |
-
temp_sigma_data_tor: 0.6943254174849822
|
31 |
-
temp_sigma_data_tr: 0.9299802531572672
|
32 |
-
|
33 |
-
eps: 5
|
34 |
-
min_samples: 10
|
35 |
-
frag_dist_range:
|
36 |
-
- 2
|
37 |
-
- 5
|
38 |
-
distance_type: min
|
39 |
-
|
40 |
-
linker_ckpt:
|
41 |
-
pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
42 |
-
pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
43 |
-
# resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
|
44 |
-
geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
|
45 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
46 |
-
linker_condition: 'protein' # pocket
|
47 |
-
linker_batch_size: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
protein_backbone.yaml
DELETED
@@ -1,47 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
5 |
-
different_schedules: false
|
6 |
-
inf_sched_alpha: 1
|
7 |
-
inf_sched_beta: 1
|
8 |
-
inference_steps: 20
|
9 |
-
initial_noise_std_proportion: 1.4601642460337794
|
10 |
-
limit_failures: 5
|
11 |
-
#model_dir: ./workdir/v1.1/score_model
|
12 |
-
#comment
|
13 |
-
no_final_step_noise: true
|
14 |
-
no_model: false
|
15 |
-
no_random: false
|
16 |
-
no_random_pocket: false
|
17 |
-
ode: false
|
18 |
-
old_filtering_model: true
|
19 |
-
old_score_model: false
|
20 |
-
resample_rdkit: false
|
21 |
-
samples_per_complex: 10
|
22 |
-
sigma_schedule: expbeta
|
23 |
-
temp_psi_rot: 0.9022615585677628
|
24 |
-
temp_psi_tor: 0.5946212391366862
|
25 |
-
temp_psi_tr: 0.727287304570729
|
26 |
-
temp_sampling_rot: 2.06391612594481
|
27 |
-
temp_sampling_tor: 7.044261621607846
|
28 |
-
temp_sampling_tr: 1.170050527854316
|
29 |
-
temp_sigma_data_rot: 0.7464326999906034
|
30 |
-
temp_sigma_data_tor: 0.6943254174849822
|
31 |
-
temp_sigma_data_tr: 0.9299802531572672
|
32 |
-
|
33 |
-
eps: 5
|
34 |
-
min_samples: 10
|
35 |
-
frag_dist_range:
|
36 |
-
- 2
|
37 |
-
- 5
|
38 |
-
distance_type: min
|
39 |
-
|
40 |
-
linker_ckpt:
|
41 |
-
pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
42 |
-
pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
|
43 |
-
geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
|
44 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
45 |
-
linker_condition: 'protein' # pocket
|
46 |
-
linker_batch_size: 20
|
47 |
-
backbone_atoms_only: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
protein_full_no_anchors_fc_pdb_excluded.yaml
DELETED
@@ -1,47 +0,0 @@
|
|
1 |
-
actual_steps: 19
|
2 |
-
score_ckpt: resources/checkpoints/DiffDock/score_model/best_ema_inference_epoch_model.pt
|
3 |
-
confidence_ckpt: resources/checkpoints/DiffDock/confidence_model/best_model_epoch75.pt
|
4 |
-
#confidence_model_dir: ./workdir/v1.1/confidence_model
|
5 |
-
different_schedules: false
|
6 |
-
inf_sched_alpha: 1
|
7 |
-
inf_sched_beta: 1
|
8 |
-
inference_steps: 20
|
9 |
-
initial_noise_std_proportion: 1.4601642460337794
|
10 |
-
limit_failures: 5
|
11 |
-
#model_dir: ./workdir/v1.1/score_model
|
12 |
-
#comment
|
13 |
-
no_final_step_noise: true
|
14 |
-
no_model: false
|
15 |
-
no_random: false
|
16 |
-
no_random_pocket: false
|
17 |
-
ode: false
|
18 |
-
old_filtering_model: true
|
19 |
-
old_score_model: false
|
20 |
-
resample_rdkit: false
|
21 |
-
samples_per_complex: 10
|
22 |
-
sigma_schedule: expbeta
|
23 |
-
temp_psi_rot: 0.9022615585677628
|
24 |
-
temp_psi_tor: 0.5946212391366862
|
25 |
-
temp_psi_tr: 0.727287304570729
|
26 |
-
temp_sampling_rot: 2.06391612594481
|
27 |
-
temp_sampling_tor: 7.044261621607846
|
28 |
-
temp_sampling_tr: 1.170050527854316
|
29 |
-
temp_sigma_data_rot: 0.7464326999906034
|
30 |
-
temp_sigma_data_tor: 0.6943254174849822
|
31 |
-
temp_sigma_data_tr: 0.9299802531572672
|
32 |
-
|
33 |
-
eps: 5
|
34 |
-
min_samples: 10
|
35 |
-
frag_dist_range:
|
36 |
-
- 2
|
37 |
-
- 5
|
38 |
-
distance_type: min
|
39 |
-
|
40 |
-
linker_ckpt:
|
41 |
-
pocket_full: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
42 |
-
pocket_bb: resources/checkpoints/DiffLinker/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt
|
43 |
-
# resources/checkpoints/DiffLinker/pockets_difflinker_backbone.ckpt
|
44 |
-
geom: resources/checkpoints/DiffLinker/geom_difflinker.ckpt
|
45 |
-
linker_size: resources/checkpoints/DiffLinker/geom_size_gnn.ckpt
|
46 |
-
linker_condition: 'protein' # pocket
|
47 |
-
linker_batch_size: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|