habdine commited on
Commit
d6cd0d2
1 Parent(s): 6376152

Delete pdb2graph.py

Browse files
Files changed (1) hide show
  1. pdb2graph.py +0 -178
pdb2graph.py DELETED
@@ -1,178 +0,0 @@
1
- import multiprocessing
2
- import os
3
- from tqdm import tqdm
4
- from sklearn.preprocessing import MultiLabelBinarizer
5
-
6
- try:
7
- from torch_geometric.data import Data
8
- except ImportError:
9
- raise Exception('You need to install torch geometric and its dependecies to use this model please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html')
10
- import torch
11
-
12
- import numpy as np
13
-
14
- from .conversion import convert_nx_to_pyg_data
15
-
16
- try:
17
- from graphein.protein.config import ProteinGraphConfig, DSSPConfig
18
- from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor
19
- from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure
20
- from graphein.protein.edges.distance import (add_peptide_bonds,
21
- add_hydrogen_bond_interactions,
22
- add_disulfide_interactions,
23
- add_ionic_interactions,
24
- add_delaunay_triangulation,
25
- add_distance_threshold,
26
- add_sequence_distance_edges,
27
- add_k_nn_edges)
28
- except ImportError:
29
- raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')
30
-
31
- from functools import partial
32
- from .graphs import *
33
- from .utils_dataset import *
34
- import os
35
- import sys
36
- import subprocess
37
- import wget
38
-
39
-
40
- class PDB2Graph():
41
- def __init__(self, root, output_folder, config, n_processors=int(multiprocessing.cpu_count())):
42
- self.root = root
43
- self.output_folder = output_folder
44
- self.map_secondary_structure = {'-':0, 'H':1, 'B':2, 'E':3, 'G':4, 'I':5, 'T':6, 'S':7}
45
- self.init_ohe_edge_type()
46
- self.config = config
47
- self.features = ['phi', 'psi', 'rsa', 'asa', 'ss', 'expasy']
48
- self.n_processors = n_processors
49
- self.raw_dir = root
50
- self.processed_dir = self._processed_dir()
51
- self.raw_file_names = self._raw_file_names()
52
- self.processed_file_names = self._processed_file_names()
53
-
54
-
55
- def _processed_dir(self):
56
- #processed_dir = os.path.join(os.path.split(self.root)[0], "processed_new")
57
- if not os.path.exists(self.output_folder):
58
- os.makedirs(self.output_folder)
59
- return self.output_folder
60
-
61
- def _raw_file_names(self):
62
- return os.listdir(self.raw_dir)
63
-
64
- def _processed_file_names(self):
65
- return [self.pdb2pathdata(pdb_path.split(".")[0]) for pdb_path in self.raw_file_names]
66
-
67
- def create_nx_graph(self, path_to_structure):
68
- return construct_graph(self.config, pdb_path = path_to_structure)
69
-
70
- def create_pyg_graph(self, path_to_structure):
71
- pyg_graph = convert_nx_to_pyg_data(self.create_nx_graph(path_to_structure))
72
-
73
- graph = Data(edge_index = pyg_graph.edge_index,
74
- num_nodes = len(pyg_graph.node_id),
75
- node_id = pyg_graph.node_id,
76
- name = pyg_graph.name[0],
77
- sequence = getattr(pyg_graph, f"sequence_{pyg_graph.chain_id[0]}"),
78
- distance_matrix = pyg_graph.dist_mat,
79
- distance = pyg_graph.distance,
80
- coordinates = torch.FloatTensor(np.array(pyg_graph.coords[0])))
81
- #create the features
82
- x = np.array([np.argmax(pyg_graph.amino_acid_one_hot, axis=1)]).reshape(-1,1)
83
- for feat in self.features:
84
- if feat == "ss":
85
- feature = np.array([[self.map_secondary_structure.get(feat_node, 0)] \
86
- for feat_node in pyg_graph[feat]])
87
- else:
88
- feature = np.array(pyg_graph[feat])
89
- if len(feature.shape) == 1:
90
- feature = feature.reshape(-1,1)
91
- x = np.concatenate((x, feature), axis = 1)
92
- graph.edge_type = self.mlb.transform(pyg_graph.kind)
93
- graph.x = torch.FloatTensor(x)
94
- # y = self.annotations[graph.name.split("_")[0]]
95
- # if self.task == 'GeneOntology' :
96
- # graph.y_mf = torch.FloatTensor(y["mf"])
97
- # graph.y_cc = torch.FloatTensor(y["cc"])
98
- # graph.y_bp = torch.FloatTensor(y["bp"])
99
- # else:
100
- # graph.y_ec = torch.FloatTensor(y["ec"])
101
- return graph
102
-
103
- def init_ohe_edge_type(self):
104
- self.mlb = MultiLabelBinarizer(classes = ['peptide_bond', 'sequence_distance_2', 'sequence_distance_3'
105
- , 'distance_threshold', 'delaunay', 'hbond', 'k_nn'])
106
- self.mlb.fit([['peptide_bond', 'sequence_distance_2', 'sequence_distance_3'
107
- , 'distance_threshold', 'delaunay', 'hbond', 'k_nn']])
108
-
109
- def process(self):
110
- """Convert the PDB files into torch geometric graphs"""
111
- # self.pdb2graph = PDB2Graph(self.config)
112
- to_be_processed = self.get_files_to_process()
113
-
114
- # pool = multiprocessing.Pool(self.n_processors)
115
- # for _ in tqdm(pool.imap_unordered(self.graph_creation, to_be_processed), total=len(to_be_processed)):
116
- # continue
117
- # pool.close()
118
- # pool.join()
119
-
120
-
121
-
122
- processes = []
123
- for prot in tqdm(to_be_processed):
124
- p = multiprocessing.Process(target=self.graph_creation, args=(prot,))
125
- processes.append(p)
126
- p.start()
127
-
128
- for process in processes:
129
- process.join()
130
-
131
-
132
- def graph_creation(self, pdb):
133
- """Create a graph from the PDB file"""
134
-
135
- # Define the path_to_structure from the pdb name file
136
- path_to_structure = self.pdb2pathstructure(pdb)
137
-
138
- # Convert the structure into a graph
139
- g = self.create_pyg_graph(path_to_structure)
140
- # Save the graph
141
- torch.save(g, os.path.join(self.output_folder, self.pdb2pathdata(pdb)))
142
-
143
- return None
144
-
145
- def pdb2pathdata(self, pdb):
146
- return pdb+'.pt'
147
-
148
- def pdb2pathstructure(self, pdb):
149
- return os.path.join(self.raw_dir, pdb+'.pdb')
150
-
151
- def get_files_to_process(self):
152
- RAW_FILES = self.processed_file_names
153
- PROCESSED_FILES = os.listdir(self.processed_dir)
154
- to_be_processed = set(RAW_FILES).difference(set(PROCESSED_FILES))
155
- to_be_processed = [path.split('.')[0] for path in to_be_processed]
156
- return to_be_processed
157
-
158
- def download_alphafold_structure(
159
- uniprot_id: str,
160
- out_dir: str,
161
- version: int = 4
162
- ):
163
-
164
- BASE_URL = "https://alphafold.ebi.ac.uk/files/"
165
- uniprot_id = uniprot_id.upper()
166
-
167
- query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.pdb"
168
- structure_filename = os.path.join(out_dir, f"AF-{uniprot_id}-F1-model_v{version}.pdb")
169
- if os.path.exists(structure_filename):
170
- return structure_filename
171
- try:
172
- structure_filename = wget.download(query_url, out=out_dir)
173
- except:
174
- print('Error.. could not download: ', f"AF-{uniprot_id}-F1-model_v{version}.pdb")
175
- return None
176
- return structure_filename
177
-
178
-