AmelieSchreiber committed on
Commit
d8bde09
1 Parent(s): a081f7f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +148 -0
README.md CHANGED
@@ -87,4 +87,152 @@ ranked_binders = sorted(mlm_losses, key=mlm_losses.get)
87
  print("Ranking of Potential Binders:")
88
  for idx, binder in enumerate(ranked_binders, 1):
89
  print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  ```
 
87
  print("Ranking of Potential Binders:")
88
  for idx, binder in enumerate(ranked_binders, 1):
89
  print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")
90
+ ```
91
+
92
+ ## PPI Networks
93
+
94
+ To construct a protein-protein interaction network, try running the following code:
95
+
96
+ ```python
97
+ import networkx as nx
98
+ import numpy as np
99
+ import torch
100
+ from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmForMaskedLM
101
+ import plotly.graph_objects as go
102
+ from ipywidgets import interact
103
+ from ipywidgets import widgets
104
+
105
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer is taken from the ESM-2 650M checkpoint; the masked-LM weights are
# loaded from the "AmelieSchreiber/esm_mlmppi_ph50" checkpoint
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50")

# Place the model on the selected device and switch it to evaluation mode
model = model.to(device).eval()
117
+
118
+ # Define Protein Sequences (Replace with your list)
119
+ all_proteins = [
120
+ "MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
121
+ "MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV",
122
+ "MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ",
123
+
124
+ "MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL",
125
+ "MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS",
126
+ "MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV",
127
+
128
+ "MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN",
129
+ "MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE",
130
+ "MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND",
131
+
132
+ "MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV",
133
+ "MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH",
134
+ "MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD",
135
+
136
+ "MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS",
137
+ "MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD",
138
+ "MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV"
139
+ ]
140
+
141
def compute_average_mlm_loss(protein1, protein2, iterations=10):
    """Return the mean masked-LM loss for two proteins joined by a glycine linker.

    The sequences are concatenated as ``protein1 + "G" * 25 + protein2`` and
    tokenized once (truncated to at most 1024 tokens). On every iteration
    roughly 35% of the residue tokens are replaced by the mask token and the
    model is scored on recovering the *original* tokens at exactly those
    positions. Linker tokens and special tokens are never masked.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        protein1: Amino-acid sequence of the first protein.
        protein2: Amino-acid sequence of the second protein.
        iterations: Number of independent random maskings to average over.

    Returns:
        The average loss over all iterations as a Python float.

    Raises:
        ValueError: If ``iterations`` is smaller than 1, or if the tokenized
            input contains no residue token that may be masked.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    mask_prob = 0.35
    connector = "G" * 25  # Connector sequence of G's

    # The model input is identical for every iteration, so tokenize only once.
    encoded = tokenizer(
        protein1 + connector + protein2,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    # Move the tensors to the target device *before* building the mask so the
    # boolean index and the tensor it indexes live on the same device.
    encoded = {key: value.to(device) for key, value in encoded.items()}
    original_ids = encoded["input_ids"]

    # Positions that may be masked: everything except special tokens ...
    special_flags = tokenizer.get_special_tokens_mask(
        original_ids[0].tolist(), already_has_special_tokens=True
    )
    maskable = ~torch.tensor(
        special_flags, dtype=torch.bool, device=device
    ).unsqueeze(0)

    # ... and except the connector. Its start is shifted by the number of
    # special tokens the tokenizer placed in front of the sequence.
    prefix_len = next(
        (pos for pos, flag in enumerate(special_flags) if not flag), 0
    )
    start_connector = prefix_len + len(
        tokenizer.encode(protein1, add_special_tokens=False)
    )
    end_connector = start_connector + len(
        tokenizer.encode(connector, add_special_tokens=False)
    )
    maskable[0, start_connector:end_connector] = False

    candidates = maskable.nonzero(as_tuple=False)
    if candidates.numel() == 0:
        raise ValueError("no residue tokens available for masking")

    total_loss = 0.0
    for _ in range(iterations):
        mask_indices = (
            torch.rand(original_ids.shape, device=device) < mask_prob
        ) & maskable
        if not mask_indices.any():
            # With no scored position the loss would be NaN; force one.
            pick = torch.randint(len(candidates), (1,)).item()
            row, col = candidates[pick].tolist()
            mask_indices[row, col] = True

        # Corrupt a copy so the unmasked ids stay available as targets.
        masked_ids = original_ids.clone()
        masked_ids[mask_indices] = tokenizer.mask_token_id

        # -100 is ignored by the loss, so only masked positions are scored.
        labels = original_ids.clone()
        labels[~mask_indices] = -100

        model_inputs = {**encoded, "input_ids": masked_ids}
        with torch.no_grad():
            outputs = model(**model_inputs, labels=labels)

        total_loss += outputs.loss.item()

    return total_loss / iterations
171
+
172
# Score every unordered pair of distinct proteins exactly once. Pairs are
# visited in (first, second) order with first < second, so position k of
# all_losses corresponds to the k-th such pair.
all_losses = [
    compute_average_mlm_loss(all_proteins[first], all_proteins[second])
    for first in range(len(all_proteins))
    for second in range(first + 1, len(all_proteins))
]

# The largest observed loss is used as the upper bound of the threshold slider
max_threshold = max(all_losses)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")
182
+
183
def plot_graph(threshold):
    """Draw a 3D interaction graph with an edge for every pair below ``threshold``.

    One node is created per entry of the module-level ``all_proteins``
    (labelled "protein 1", "protein 2", ...). The precomputed pair losses in
    the module-level ``all_losses`` are read in pair order (i < j), and a pair
    is connected when its loss is strictly below ``threshold``. The figure is
    shown with Plotly; nothing is returned.

    Args:
        threshold: Upper bound (exclusive) on the loss for a pair to be linked.
    """
    protein_count = len(all_proteins)
    labels = [f"protein {number}" for number in range(1, protein_count + 1)]

    graph = nx.Graph()
    graph.add_nodes_from(labels)

    # Walk the pairs in the same order in which all_losses was filled, so the
    # flat position of a pair is also its index into all_losses.
    index_pairs = [
        (first, second)
        for first in range(protein_count)
        for second in range(first + 1, protein_count)
    ]
    for flat_index, (first, second) in enumerate(index_pairs):
        pair_loss = all_losses[flat_index]
        if pair_loss < threshold:
            graph.add_edge(labels[first], labels[second], weight=round(pair_loss, 3))

    # 3D spring layout with a fixed seed; k sets the preferred node spacing
    # (a lower value pulls the nodes closer together).
    k_value = 2
    coords = nx.spring_layout(graph, dim=3, seed=42, k=k_value)

    # Each edge contributes its two endpoints followed by None, which breaks
    # the line so consecutive edges are not joined to each other.
    edge_x, edge_y, edge_z = [], [], []
    for source, target in graph.edges():
        (sx, sy, sz), (tx, ty, tz) = coords[source], coords[target]
        edge_x += [sx, tx, None]
        edge_y += [sy, ty, None]
        edge_z += [sz, tz, None]

    edge_trace = go.Scatter3d(
        x=edge_x,
        y=edge_y,
        z=edge_z,
        mode='lines',
        line=dict(width=0.5, color='grey'),
    )

    # Node coordinates, with the node label doubling as hover text
    node_text = list(graph.nodes())
    node_x = [coords[name][0] for name in node_text]
    node_y = [coords[name][1] for name in node_text]
    node_z = [coords[name][2] for name in node_text]

    node_trace = go.Scatter3d(
        x=node_x,
        y=node_y,
        z=node_z,
        mode='markers',
        marker=dict(size=5),
        hoverinfo='text',
        hovertext=node_text,
    )

    hidden_background = dict(showbackground=False)
    layout = go.Layout(
        title='Protein Interaction Graph',
        title_x=0.5,
        scene=dict(
            xaxis=dict(hidden_background),
            yaxis=dict(hidden_background),
            zaxis=dict(hidden_background),
        ),
    )

    go.Figure(data=[edge_trace, node_trace], layout=layout).show()
235
+
236
+ # Create an interactive slider for the threshold value with a default of 8.25
237
+ interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=8.25))
238
  ```