from functools import cache
from pathlib import Path

import esm
import numpy as np
import pandas as pd
import prody as pr
import torch
from esm import FastaBatchedDataset, pretrained
from rdkit.Chem import AddHs
from torch_geometric.data import Dataset, HeteroData

from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, \
    moad_extract_receptor_structure
from datasets.parse_chi import aa_idx2aa_short, get_onehot_sequence


def get_sequences_from_pdbfile(file_path):
    """Extract per-chain sequences from a PDB file, joined by ':'."""
    sequence = None
    # the prody parser requires a str path
    pdb = pr.parsePDB(str(file_path))
    seq = pdb.ca.getSequence()
    one_hot = get_onehot_sequence(seq)

    chain_ids = np.zeros(len(one_hot))
    res_chain_ids = pdb.ca.getChids()
    res_seg_ids = pdb.ca.getSegnames()
    res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
    ids = np.unique(res_chain_ids)

    for i, chain_id in enumerate(ids):
        chain_ids[res_chain_ids == chain_id] = i
        s_temp = np.argmax(one_hot[res_chain_ids == chain_id], axis=1)
        s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s_temp])
        if sequence is None:
            sequence = s
        else:
            sequence += (":" + s)

    return sequence


@cache
def process_protein(protein_string):
    input_path = Path(protein_string)
    # Check if the input looks like a path to a file
    if input_path.is_absolute() or len(input_path.parts) > 1:
        # Check if the input is an existing PDB file
        if input_path.is_file() and input_path.suffix == '.pdb':
            # Extract the sequence from the PDB file
            return get_sequences_from_pdbfile(input_path), str(input_path)
        raise FileNotFoundError(f"File {protein_string} not found or not a PDB file")
    # Otherwise assume the input is already a FASTA sequence
    return protein_string, None


def compute_esm_embeddings(model, alphabet, labels, sequences):
    # settings used for embedding extraction
    toks_per_batch = 4096
    repr_layers = [33]
    truncation_seq_length = 1022

    # Deduplicate by label to avoid redundant ESM computation,
    # keeping labels and sequences aligned in insertion order
    seen = set()
    unique_labels, unique_sequences = [], []
    for label, sequence in zip(labels, sequences):
        if label not in seen:
            seen.add(label)
            unique_labels.append(label)
            unique_sequences.append(sequence)

    # Create the dataset with unique labels and sequences
    dataset = FastaBatchedDataset(unique_labels, unique_sequences)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )

    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]
    embeddings = {}

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                # drop the BOS token and anything beyond the truncation length
                truncate_len = min(truncation_seq_length, len(strs[i]))
                embeddings[label] = representations[33][i, 1: truncate_len + 1].clone()

    return embeddings


def generate_esm_structure(model, filename, sequence):
    model.set_chunk_size(256)
    chunk_size = 256
    output = None

    while output is None:
        try:
            with torch.no_grad():
                output = model.infer_pdb(sequence)
            with open(filename, "w") as f:
                f.write(output)
            print("saved", filename)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on chunk_size', chunk_size)
                for p in model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
                chunk_size = chunk_size // 2
                if chunk_size > 2:
                    model.set_chunk_size(chunk_size)
                else:
                    print("Not enough memory for ESMFold")
                    break
            else:
                raise e
    return output is not None


class InferenceDataset(Dataset):
    def __init__(self, df, out_dir, lm_embeddings,
                 receptor_radius=30, c_alpha_max_neighbors=None, precomputed_lm_embeddings=None,
                 remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None,
                 knn_only_graph=False):
        super(InferenceDataset, self).__init__()
        self.receptor_radius = receptor_radius
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.remove_hs = remove_hs
        self.all_atoms = all_atoms
        self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
        self.knn_only_graph = knn_only_graph
        self.df = df

        # generate LM embeddings
        if lm_embeddings and (precomputed_lm_embeddings is None or precomputed_lm_embeddings[0] is None):
            print("Generating ESM language model embeddings")
            model_location = "esm2_t33_650M_UR50D"
            model, alphabet = pretrained.load_model_and_alphabet(model_location)
            model.eval()
            if torch.cuda.is_available():
                model = model.cuda()

            df[['protein_sequence', 'protein_path']] = df['X2'].apply(process_protein).apply(pd.Series)

            labels, sequences = [], []
            for i in range(len(df)):
                s = df['protein_sequence'].iloc[i].split(':')
                sequences.extend(s)
                labels.extend([df['name'].iloc[i] + '_chain_' + str(j) for j in range(len(s))])

            # TODO improve efficiency for repeated X2 values
            lm_embeddings = compute_esm_embeddings(model, alphabet, labels, sequences)

            self.lm_embeddings = []
            for i in range(len(df)):
                s = df['protein_sequence'].iloc[i].split(':')
                self.lm_embeddings.append(
                    [lm_embeddings[f"{df['name'].iloc[i]}_chain_{j}"] for j in range(len(s))]
                )
        elif not lm_embeddings:
            self.lm_embeddings = [None] * len(df)
        else:
            self.lm_embeddings = precomputed_lm_embeddings

        # generate structures with ESMFold for rows that have no PDB file
        if None in df['protein_path'].values:
            print("generating missing structures with ESMFold")
            model = esm.pretrained.esmfold_v1()
            model = model.eval().cuda()

            for i in range(len(df)):
                # TODO improve efficiency for repeated X2 values
                protein_sequence = df['protein_sequence'].iloc[i]
                protein_file = df['protein_path'].iloc[i]
                complex_name = df['name'].iloc[i]
                if protein_file is None:
                    protein_file = f"{out_dir}/{complex_name}/{complex_name}_esmfold.pdb"
                    if not Path(protein_file).is_file():
                        Path(protein_file).parent.mkdir(parents=True, exist_ok=True)
                        print("generating", protein_file)
                        generate_esm_structure(model, protein_file, protein_sequence)
                    # record the generated structure so that get() can load it
                    df.loc[df.index[i], 'protein_path'] = protein_file

    def len(self):
        return len(self.df)

    def get(self, idx):
        name = self.df['name'].iloc[idx]
        protein_file = self.df['protein_path'].iloc[idx]
        ligand_description = self.df['X1'].iloc[idx]
        mol = self.df['mol'].iloc[idx]
        lm_embedding = self.lm_embeddings[idx]

        # build the pytorch geometric heterogeneous graph
        complex_graph = HeteroData()
        complex_graph['name'] = name

        if mol is not None:
            mol = AddHs(mol)
            generate_conformer(mol)
        else:
            print(f'Failed to read molecule {ligand_description}. Skipping...')
            complex_graph['success'] = False
            return complex_graph

        try:
            # build the ligand graph from the generated conformer
            get_lig_graph_with_matching(mol, complex_graph, popsize=None, maxiter=None, matching=False,
                                        keep_original=False, num_conformers=1, remove_hs=self.remove_hs)

            # parse the receptor from the pdb file
            moad_extract_receptor_structure(
                path=protein_file,
                complex_graph=complex_graph,
                neighbor_cutoff=self.receptor_radius,
                max_neighbors=self.c_alpha_max_neighbors,
                lm_embeddings=lm_embedding,
                knn_only_graph=self.knn_only_graph,
                all_atoms=self.all_atoms,
                atom_cutoff=self.atom_radius,
                atom_max_neighbors=self.atom_max_neighbors)
        except Exception as e:
            print(f'Skipping {name} because of the error:')
            print(e)
            complex_graph['success'] = False
            return complex_graph

        # center the receptor (and its atoms) at the origin; center the ligand separately
        protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
        complex_graph['receptor'].pos -= protein_center
        if self.all_atoms:
            complex_graph['atom'].pos -= protein_center

        ligand_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
        complex_graph['ligand'].pos -= ligand_center

        complex_graph.original_center = protein_center
        complex_graph.mol = mol
        complex_graph['success'] = True
        return complex_graph
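

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original pipeline): shows
# how InferenceDataset might be constructed from a dataframe using the column
# names assumed by the code above ('name', 'X1', 'X2', 'mol'). The SMILES,
# sequence, and output directory below are placeholders; running this
# downloads ESM2 weights and, because no PDB path is given, requires a GPU
# for ESMFold structure generation.
if __name__ == "__main__":
    from rdkit import Chem

    example_df = pd.DataFrame({
        'name': ['example_complex'],
        'X1': ['CCO'],                                     # ligand description, e.g. a SMILES string
        'X2': ['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSR'],  # protein sequence or path to a .pdb file
        'mol': [Chem.MolFromSmiles('CCO')],                # pre-parsed RDKit molecule
    })
    dataset = InferenceDataset(example_df, out_dir='results/example', lm_embeddings=True)
    complex_graph = dataset.get(0)
    print('graph built:', bool(complex_graph['success']))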