AmelieSchreiber
commited on
Commit
•
7c24275
1
Parent(s):
fd834e4
Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,239 @@
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
4 |
+
|
5 |
+
# ESM-2 Finetuned for PPI Prediction
|
6 |
+
This is a finetuned version of `facebook/esm2_t33_650M_UR50D` using the masked language modeling objective.
|
7 |
+
The model was finetuned for four epochs on concatenated pairs of interacting proteins, clustered using persistent homology
|
8 |
+
landscapes as explained in [this post](https://huggingface.co/blog/AmelieSchreiber/faster-pha). The dataset consists of 10,000
|
9 |
+
protein pairs, which can be [found here](https://huggingface.co/datasets/AmelieSchreiber/pha_clustered_protein_complexes).
|
10 |
+
This is a very new method for clustering protein-protein complexes.
|
11 |
+
|
12 |
+
Using the MLM loss to predict pairs of interacting proteins was inspired by [this paper](https://arxiv.org/abs/2308.07136). However,
|
13 |
+
the authors do not finetune the models for this task. Thus we reasoned that improved performance on this method could be achieved
|
14 |
+
by finetuning the model on pairs of interacting proteins.
|
15 |
+
|
16 |
+
## Using the Model
|
17 |
+
To use the model, we follow [this blog post](https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2).
|
18 |
+
Below we see how to use the model for ranking potential binders for a target protein of interest. The lower the MLM loss average,
|
19 |
+
the more likely the two proteins are to interact with one another.
|
20 |
+
|
21 |
+
```python
|
22 |
+
import numpy as np
|
23 |
+
from transformers import AutoTokenizer, EsmForMaskedLM
|
24 |
+
import torch
|
25 |
+
|
26 |
+
# Load the base model and tokenizer
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
28 |
+
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50_v3")
|
29 |
+
|
30 |
+
# Ensure the model is in evaluation mode
|
31 |
+
model.eval()
|
32 |
+
|
33 |
+
# Define the protein of interest and its potential binders
|
34 |
+
protein_of_interest = "MLTEVMEVWHGLVIAVVSLFLQACFLTAINYLLSRHMAHKSEQILKAASLQVPRPSPGHHHPPAVKEMKETQTERDIPMSDSLYRHDSDTPSDSLDSSCSSPPACQATEDVDYTQVVFSDPGELKNDSPLDYENIKEITDYVNVNPERHKPSFWYFVNPALSEPAEYDQVAM"
|
35 |
+
potential_binders = [
|
36 |
+
# Known to interact
|
37 |
+
"MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL",
|
38 |
+
"MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ",
|
39 |
+
"EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN",
|
40 |
+
# Not known to interact
|
41 |
+
"MRQRLLPSVTSLLLVALLFPGSSQARHVNHSATEALGELRERAPGQGTNGFQLLRHAVKRDLLPPRTPPYQVHISHREARGPSFRICVDFLGPRWARGCSTGN",
|
42 |
+
"MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS"
|
43 |
+
] # Add potential binding sequences here
|
44 |
+
|
45 |
+
def compute_mlm_loss(protein, binder, iterations=5):
|
46 |
+
total_loss = 0.0
|
47 |
+
|
48 |
+
for _ in range(iterations):
|
49 |
+
# Concatenate protein sequences with a separator
|
50 |
+
concatenated_sequence = protein + binder
|
51 |
+
|
52 |
+
# Mask a subset of amino acids in the concatenated sequence (excluding the separator)
|
53 |
+
tokens = list(concatenated_sequence)
|
54 |
+
mask_rate = 0.35 # For instance, masking 35% of the sequence
|
55 |
+
num_mask = int(len(tokens) * mask_rate)
|
56 |
+
|
57 |
+
# Exclude the separator from potential mask indices
|
58 |
+
available_indices = [i for i, token in enumerate(tokens) if token != ":"]
|
59 |
+
probs = torch.ones(len(available_indices))
|
60 |
+
mask_indices = torch.multinomial(probs, num_mask, replacement=False)
|
61 |
+
|
62 |
+
for idx in mask_indices:
|
63 |
+
tokens[available_indices[idx]] = tokenizer.mask_token
|
64 |
+
|
65 |
+
masked_sequence = "".join(tokens)
|
66 |
+
inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length', add_special_tokens=False)
|
67 |
+
|
68 |
+
# Compute the MLM loss
|
69 |
+
with torch.no_grad():
|
70 |
+
outputs = model(**inputs, labels=inputs["input_ids"])
|
71 |
+
loss = outputs.loss
|
72 |
+
|
73 |
+
total_loss += loss.item()
|
74 |
+
|
75 |
+
# Return the average loss
|
76 |
+
return total_loss / iterations
|
77 |
+
|
78 |
+
# Compute MLM loss for each potential binder
|
79 |
+
mlm_losses = {}
|
80 |
+
for binder in potential_binders:
|
81 |
+
loss = compute_mlm_loss(protein_of_interest, binder)
|
82 |
+
mlm_losses[binder] = loss
|
83 |
+
|
84 |
+
# Rank binders based on MLM loss
|
85 |
+
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)
|
86 |
+
|
87 |
+
print("Ranking of Potential Binders:")
|
88 |
+
for idx, binder in enumerate(ranked_binders, 1):
|
89 |
+
print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")
|
90 |
+
```
|
91 |
+
|
92 |
+
## PPI Networks
|
93 |
+
|
94 |
+
To construct a protein-protein interaction network, try running the following code. Try adjusting the length of the connector
|
95 |
+
(0-25 for example). Try also adjusting the number of iterations and the masking percentage.
|
96 |
+
|
97 |
+
```python
|
98 |
+
import networkx as nx
|
99 |
+
import numpy as np
|
100 |
+
import torch
|
101 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmForMaskedLM
|
102 |
+
import plotly.graph_objects as go
|
103 |
+
from ipywidgets import interact
|
104 |
+
from ipywidgets import widgets
|
105 |
+
|
106 |
+
# Check if CUDA is available and set the default device accordingly
|
107 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
108 |
+
|
109 |
+
# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer
|
110 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
111 |
+
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50_v3")
|
112 |
+
|
113 |
+
# Send the model to the device (GPU or CPU)
|
114 |
+
model.to(device)
|
115 |
+
|
116 |
+
# Ensure the model is in evaluation mode
|
117 |
+
model.eval()
|
118 |
+
|
119 |
+
# Define Protein Sequences (Replace with your list)
|
120 |
+
all_proteins = [
|
121 |
+
"MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
|
122 |
+
"MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV",
|
123 |
+
"MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ",
|
124 |
+
|
125 |
+
"MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL",
|
126 |
+
"MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS",
|
127 |
+
"MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV",
|
128 |
+
|
129 |
+
"MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN",
|
130 |
+
"MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE",
|
131 |
+
"MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND",
|
132 |
+
|
133 |
+
"MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV",
|
134 |
+
"MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH",
|
135 |
+
"MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD",
|
136 |
+
|
137 |
+
"MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS",
|
138 |
+
"MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD",
|
139 |
+
"MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV"
|
140 |
+
]
|
141 |
+
|
142 |
+
def compute_average_mlm_loss(protein1, protein2, iterations=10):
|
143 |
+
total_loss = 0.0
|
144 |
+
connector = "G" * 25 # Connector sequence of G's
|
145 |
+
for _ in range(iterations):
|
146 |
+
concatenated_sequence = protein1 + connector + protein2
|
147 |
+
inputs = tokenizer(concatenated_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
148 |
+
|
149 |
+
mask_prob = 0.35
|
150 |
+
mask_indices = torch.rand(inputs["input_ids"].shape, device=device) < mask_prob
|
151 |
+
|
152 |
+
# Locate the positions of the connector 'G's and set their mask indices to False
|
153 |
+
connector_indices = tokenizer.encode(connector, add_special_tokens=False)
|
154 |
+
connector_length = len(connector_indices)
|
155 |
+
start_connector = len(tokenizer.encode(protein1, add_special_tokens=False))
|
156 |
+
end_connector = start_connector + connector_length
|
157 |
+
|
158 |
+
# Avoid masking the connector 'G's
|
159 |
+
mask_indices[0, start_connector:end_connector] = False
|
160 |
+
|
161 |
+
# Apply the mask to the input IDs
|
162 |
+
inputs["input_ids"][mask_indices] = tokenizer.mask_token_id
|
163 |
+
inputs = {k: v.to(device) for k, v in inputs.items()} # Send inputs to the device
|
164 |
+
|
165 |
+
with torch.no_grad():
|
166 |
+
outputs = model(**inputs, labels=inputs["input_ids"])
|
167 |
+
|
168 |
+
loss = outputs.loss
|
169 |
+
total_loss += loss.item()
|
170 |
+
|
171 |
+
return total_loss / iterations
|
172 |
+
|
173 |
+
# Compute all average losses to determine the maximum threshold for the slider
|
174 |
+
all_losses = []
|
175 |
+
for i, protein1 in enumerate(all_proteins):
|
176 |
+
for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
|
177 |
+
avg_loss = compute_average_mlm_loss(protein1, protein2)
|
178 |
+
all_losses.append(avg_loss)
|
179 |
+
|
180 |
+
# Set the maximum threshold to the maximum loss computed
|
181 |
+
max_threshold = max(all_losses)
|
182 |
+
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")
|
183 |
+
|
184 |
+
def plot_graph(threshold):
|
185 |
+
G = nx.Graph()
|
186 |
+
|
187 |
+
# Add all protein nodes to the graph
|
188 |
+
for i, protein in enumerate(all_proteins):
|
189 |
+
G.add_node(f"protein {i+1}")
|
190 |
+
|
191 |
+
# Loop through all pairs of proteins and calculate average MLM loss
|
192 |
+
loss_idx = 0 # Index to keep track of the position in the all_losses list
|
193 |
+
for i, protein1 in enumerate(all_proteins):
|
194 |
+
for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
|
195 |
+
avg_loss = all_losses[loss_idx]
|
196 |
+
loss_idx += 1
|
197 |
+
|
198 |
+
# Add an edge if the loss is below the threshold
|
199 |
+
if avg_loss < threshold:
|
200 |
+
G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))
|
201 |
+
|
202 |
+
# 3D Network Plot
|
203 |
+
# Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.
|
204 |
+
k_value = 2 # Lower value will bring nodes closer together
|
205 |
+
pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)
|
206 |
+
|
207 |
+
edge_x = []
|
208 |
+
edge_y = []
|
209 |
+
edge_z = []
|
210 |
+
for edge in G.edges():
|
211 |
+
x0, y0, z0 = pos[edge[0]]
|
212 |
+
x1, y1, z1 = pos[edge[1]]
|
213 |
+
edge_x.extend([x0, x1, None])
|
214 |
+
edge_y.extend([y0, y1, None])
|
215 |
+
edge_z.extend([z0, z1, None])
|
216 |
+
|
217 |
+
edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))
|
218 |
+
|
219 |
+
node_x = []
|
220 |
+
node_y = []
|
221 |
+
node_z = []
|
222 |
+
node_text = []
|
223 |
+
for node in G.nodes():
|
224 |
+
x, y, z = pos[node]
|
225 |
+
node_x.append(x)
|
226 |
+
node_y.append(y)
|
227 |
+
node_z.append(z)
|
228 |
+
node_text.append(node)
|
229 |
+
|
230 |
+
node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)
|
231 |
+
|
232 |
+
layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))
|
233 |
+
|
234 |
+
fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
|
235 |
+
fig.show()
|
236 |
+
|
237 |
+
# Create an interactive slider for the threshold value with a default of 8.50
|
238 |
+
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=8.25))
|
239 |
+
```
|