AmelieSchreiber
commited on
Commit
•
d8bde09
1
Parent(s):
a081f7f
Update README.md
Browse files
README.md
CHANGED
@@ -87,4 +87,152 @@ ranked_binders = sorted(mlm_losses, key=mlm_losses.get)
|
|
87 |
print("Ranking of Potential Binders:")
|
88 |
for idx, binder in enumerate(ranked_binders, 1):
|
89 |
print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
```
|
|
|
87 |
print("Ranking of Potential Binders:")
|
88 |
for idx, binder in enumerate(ranked_binders, 1):
|
89 |
print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")
|
90 |
+
```
|
91 |
+
|
92 |
+
## PPI Networks
|
93 |
+
|
94 |
+
To construct a protein-protein interaction network, try running the following code:
|
95 |
+
|
96 |
+
```python
|
97 |
+
import networkx as nx
|
98 |
+
import numpy as np
|
99 |
+
import torch
|
100 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmForMaskedLM
|
101 |
+
import plotly.graph_objects as go
|
102 |
+
from ipywidgets import interact
|
103 |
+
from ipywidgets import widgets
|
104 |
+
|
105 |
+
# Check if CUDA is available and set the default device accordingly
|
106 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
107 |
+
|
108 |
+
# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer
|
109 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
110 |
+
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50")
|
111 |
+
|
112 |
+
# Send the model to the device (GPU or CPU)
|
113 |
+
model.to(device)
|
114 |
+
|
115 |
+
# Ensure the model is in evaluation mode
|
116 |
+
model.eval()
|
117 |
+
|
118 |
+
# Define Protein Sequences (Replace with your list)
|
119 |
+
all_proteins = [
|
120 |
+
"MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
|
121 |
+
"MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV",
|
122 |
+
"MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ",
|
123 |
+
|
124 |
+
"MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL",
|
125 |
+
"MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS",
|
126 |
+
"MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV",
|
127 |
+
|
128 |
+
"MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN",
|
129 |
+
"MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE",
|
130 |
+
"MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND",
|
131 |
+
|
132 |
+
"MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV",
|
133 |
+
"MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH",
|
134 |
+
"MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD",
|
135 |
+
|
136 |
+
"MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS",
|
137 |
+
"MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD",
|
138 |
+
"MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV"
|
139 |
+
]
|
140 |
+
|
141 |
+
def compute_average_mlm_loss(protein1, protein2, iterations=10):
|
142 |
+
total_loss = 0.0
|
143 |
+
connector = "G" * 25 # Connector sequence of G's
|
144 |
+
for _ in range(iterations):
|
145 |
+
concatenated_sequence = protein1 + connector + protein2
|
146 |
+
inputs = tokenizer(concatenated_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
147 |
+
|
148 |
+
mask_prob = 0.35
|
149 |
+
mask_indices = torch.rand(inputs["input_ids"].shape, device=device) < mask_prob
|
150 |
+
|
151 |
+
# Locate the positions of the connector 'G's and set their mask indices to False
|
152 |
+
connector_indices = tokenizer.encode(connector, add_special_tokens=False)
|
153 |
+
connector_length = len(connector_indices)
|
154 |
+
start_connector = len(tokenizer.encode(protein1, add_special_tokens=False))
|
155 |
+
end_connector = start_connector + connector_length
|
156 |
+
|
157 |
+
# Avoid masking the connector 'G's
|
158 |
+
mask_indices[0, start_connector:end_connector] = False
|
159 |
+
|
160 |
+
# Apply the mask to the input IDs
|
161 |
+
inputs["input_ids"][mask_indices] = tokenizer.mask_token_id
|
162 |
+
inputs = {k: v.to(device) for k, v in inputs.items()} # Send inputs to the device
|
163 |
+
|
164 |
+
with torch.no_grad():
|
165 |
+
outputs = model(**inputs, labels=inputs["input_ids"])
|
166 |
+
|
167 |
+
loss = outputs.loss
|
168 |
+
total_loss += loss.item()
|
169 |
+
|
170 |
+
return total_loss / iterations
|
171 |
+
|
172 |
+
# Compute all average losses to determine the maximum threshold for the slider
|
173 |
+
all_losses = []
|
174 |
+
for i, protein1 in enumerate(all_proteins):
|
175 |
+
for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
|
176 |
+
avg_loss = compute_average_mlm_loss(protein1, protein2)
|
177 |
+
all_losses.append(avg_loss)
|
178 |
+
|
179 |
+
# Set the maximum threshold to the maximum loss computed
|
180 |
+
max_threshold = max(all_losses)
|
181 |
+
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")
|
182 |
+
|
183 |
+
def plot_graph(threshold):
|
184 |
+
G = nx.Graph()
|
185 |
+
|
186 |
+
# Add all protein nodes to the graph
|
187 |
+
for i, protein in enumerate(all_proteins):
|
188 |
+
G.add_node(f"protein {i+1}")
|
189 |
+
|
190 |
+
# Loop through all pairs of proteins and calculate average MLM loss
|
191 |
+
loss_idx = 0 # Index to keep track of the position in the all_losses list
|
192 |
+
for i, protein1 in enumerate(all_proteins):
|
193 |
+
for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
|
194 |
+
avg_loss = all_losses[loss_idx]
|
195 |
+
loss_idx += 1
|
196 |
+
|
197 |
+
# Add an edge if the loss is below the threshold
|
198 |
+
if avg_loss < threshold:
|
199 |
+
G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))
|
200 |
+
|
201 |
+
# 3D Network Plot
|
202 |
+
# Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.
|
203 |
+
k_value = 2 # Lower value will bring nodes closer together
|
204 |
+
pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)
|
205 |
+
|
206 |
+
edge_x = []
|
207 |
+
edge_y = []
|
208 |
+
edge_z = []
|
209 |
+
for edge in G.edges():
|
210 |
+
x0, y0, z0 = pos[edge[0]]
|
211 |
+
x1, y1, z1 = pos[edge[1]]
|
212 |
+
edge_x.extend([x0, x1, None])
|
213 |
+
edge_y.extend([y0, y1, None])
|
214 |
+
edge_z.extend([z0, z1, None])
|
215 |
+
|
216 |
+
edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))
|
217 |
+
|
218 |
+
node_x = []
|
219 |
+
node_y = []
|
220 |
+
node_z = []
|
221 |
+
node_text = []
|
222 |
+
for node in G.nodes():
|
223 |
+
x, y, z = pos[node]
|
224 |
+
node_x.append(x)
|
225 |
+
node_y.append(y)
|
226 |
+
node_z.append(z)
|
227 |
+
node_text.append(node)
|
228 |
+
|
229 |
+
node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)
|
230 |
+
|
231 |
+
layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))
|
232 |
+
|
233 |
+
fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
|
234 |
+
fig.show()
|
235 |
+
|
236 |
+
# Create an interactive slider for the threshold value with a default of 8.50
|
237 |
+
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=8.25))
|
238 |
```
|