---
license: mit
---
# ESM-2 QLoRA for Binding Site Prediction
In this model we added more QLoRA adapter layers, applying QLoRA to all of the weight matrices. Again, the differences between the
train and test metrics are smaller for this model than for the model with fewer adapter layers (adapters on the query, key, and value
matrices only). This suggests that adapting more of the weight matrices in this larger ESM-2 model reduces overfitting and acts as a better
regularizer. For comparison, see [this model](https://huggingface.co/AmelieSchreiber/esm2_t12_35M_qlora_binding_sites_v0), which only
has QLoRA adapters on the query, key, and value matrices. This model was trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/1111K_binding_sites).
Note that this dataset is too small for a model of this size, so some overfitting is expected, but adding adapters to more layers
clearly reduces it.
## Testing for Overfitting
```
Train metrics:
{'eval_loss': 0.17861589789390564,
'eval_accuracy': 0.9336392007583741,
'eval_precision': 0.24007189695313816,
'eval_recall': 0.9234520216135872,
'eval_f1': 0.38107489676203077,
'eval_auc': 0.9286608447868842,
'eval_mcc': 0.4519203165484902}
Test metrics:
{'eval_loss': 0.2265990674495697,
'eval_accuracy': 0.913988661430497,
'eval_precision': 0.1725452162312655,
'eval_recall': 0.8272126203209694,
'eval_f1': 0.28553230637278637,
'eval_auc': 0.8715212375759034,
'eval_mcc': 0.3539008454498742}
```
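To make "the differences between the train and test metrics" concrete, the gaps can be computed directly from the numbers reported above. This is a small sketch that only uses those reported values; comparing against the query/key/value-only model would require the metrics from its own model card. It also checks that each reported F1 equals the harmonic mean of the reported precision and recall.

```python
# Train/test metrics copied from the output above.
train = {
    "eval_loss": 0.17861589789390564,
    "eval_accuracy": 0.9336392007583741,
    "eval_precision": 0.24007189695313816,
    "eval_recall": 0.9234520216135872,
    "eval_f1": 0.38107489676203077,
    "eval_auc": 0.9286608447868842,
    "eval_mcc": 0.4519203165484902,
}
test = {
    "eval_loss": 0.2265990674495697,
    "eval_accuracy": 0.913988661430497,
    "eval_precision": 0.1725452162312655,
    "eval_recall": 0.8272126203209694,
    "eval_f1": 0.28553230637278637,
    "eval_auc": 0.8715212375759034,
    "eval_mcc": 0.3539008454498742,
}

# Absolute train/test gap per metric; a smaller gap indicates less overfitting.
gaps = {name: abs(train[name] - test[name]) for name in train}
for name, gap in gaps.items():
    print(f"{name:<15} train={train[name]:.4f}  test={test[name]:.4f}  |gap|={gap:.4f}")


def f1_from_pr(precision: float, recall: float) -> float:
    """F1 is the harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Consistency check: each reported F1 matches its reported precision and recall.
for metrics in (train, test):
    assert abs(f1_from_pr(metrics["eval_precision"], metrics["eval_recall"]) - metrics["eval_f1"]) < 1e-4
```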
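To use this model, first install the dependencies (the `!` prefix runs shell commands in Jupyter or Colab notebooks; omit it in a regular terminal):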
```
!pip install transformers -q
!pip install peft -q
```
Then run:
```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
# Hugging Face Hub ID of the trained QLoRA adapter
model_path = "AmelieSchreiber/esm2_t12_35M_qlora_binding_sites_v1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"
# Load the base model with a 2-label token-classification head, then attach the adapter
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=2)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Ensure the model is in evaluation mode
loaded_model.eval()
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
# Tokenize the sequence (a single sequence needs no padding; truncation caps it at 1024 tokens)
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024)
# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits
# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)
# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}
# Print the predicted label for each residue, skipping special tokens
for token, prediction in zip(tokens, predictions[0].tolist()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
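The loop above prints `(token, label)` pairs. If you need residue positions instead, a small post-processing step can collect them. `binding_site_positions` and `to_segments` below are hypothetical helper names (not part of `transformers` or `peft`), and the toy tokens and labels are made up for illustration; the sketch assumes label `1` means "Binding site", as in `id2label` above.

```python
from typing import List, Sequence, Tuple

# Same special tokens that the loop above skips.
SPECIAL_TOKENS = {"<pad>", "<cls>", "<eos>"}


def binding_site_positions(tokens: Sequence[str], labels: Sequence[int]) -> List[int]:
    """Return 1-based residue positions whose predicted label is 1 ("Binding site")."""
    positions = []
    residue_index = 0
    for token, label in zip(tokens, labels):
        if token in SPECIAL_TOKENS:
            continue  # special tokens are not residues, so they do not advance the index
        residue_index += 1
        if label == 1:
            positions.append(residue_index)
    return positions


def to_segments(positions: Sequence[int]) -> List[Tuple[int, int]]:
    """Merge sorted positions into inclusive (start, end) runs of consecutive residues."""
    segments: List[Tuple[int, int]] = []
    for pos in positions:
        if segments and pos == segments[-1][1] + 1:
            segments[-1] = (segments[-1][0], pos)  # extend the current run
        else:
            segments.append((pos, pos))  # start a new run
    return segments


# Toy example (made-up labels, not real model output).
toy_tokens = ["<cls>", "M", "A", "V", "P", "E", "<eos>"]
toy_labels = [0, 0, 1, 1, 0, 1, 0]
positions = binding_site_positions(toy_tokens, toy_labels)
segments = to_segments(positions)
print(positions)  # [2, 3, 5]
print(segments)   # [(2, 3), (5, 5)]
```

With the inference script above, you would call `binding_site_positions(tokens, predictions[0].tolist())`.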