from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, TrainerCallback, EsmConfig
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
import random
import datetime
class ProteinDataset(Dataset):
    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides
        self.mask_percentage = mask_percentage

    def __len__(self):
        return len(self.proteins)

    def mask_sequence(self, sequence):
        # Replace a random subset of residues with the mask token. ESM-2 tokenizes one
        # token per residue, so the masked and unmasked sequences stay aligned token-for-token.
        mask_indices = random.sample(range(len(sequence)), int(len(sequence) * self.mask_percentage))
        return ''.join([self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence)])

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        # Mask protein and peptide independently, then concatenate them into one complex sequence
        masked_protein = self.mask_sequence(protein_seq)
        masked_peptide = self.mask_sequence(peptide_seq)
        complex_seq = masked_protein + masked_peptide

        complex_input = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )
        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()

        # Tokenize the unmasked concatenation to obtain the target residues
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )["input_ids"].squeeze()

        # Compute the loss only at masked positions; -100 is ignored by the MLM loss
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Callback to update mask percentage after each epoch
class DynamicMaskingCallback(TrainerCallback):
    def __init__(self, dataset, increment=0.10):
        self.dataset = dataset
        self.increment = increment

    def on_epoch_end(self, args, state, control, **kwargs):
        # Raise the masking rate by `increment` after every epoch, capped at 100%
        self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
        print(f"Updated mask percentage to: {self.dataset.mask_percentage:.0%}")

# Loading the dataset
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')
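# Assumed input format (inferred from the column accesses below): one protein-peptide
# pair per row, with at least 'Protein1', 'Protein2', and 'Cluster' columns.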
# Splitting the data based on clusters, starting with cluster 0
test_clusters = [0] # Start with cluster 0
remaining_clusters = data[data['Cluster'] != 0]['Cluster'].unique().tolist()
random.shuffle(remaining_clusters)  # Shuffle the remaining clusters in place
# Determine the size of cluster 0 in the dataset
cluster_0_size = (data['Cluster'] == 0).mean()
# Add more clusters until reaching approximately 20% of the dataset
test_size = cluster_0_size
for cluster in remaining_clusters:
    cluster_size = (data['Cluster'] == cluster).mean()
    if test_size + cluster_size > 0.20:
        break
    test_clusters.append(cluster)
    test_size += cluster_size
# Creating test and train data based on the selected clusters
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]
proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()
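# Optional sanity check (a suggested addition, not in the original script): report how the
# cluster-level split came out before training starts.
print(f"Test clusters: {test_clusters}")
print(f"Train pairs: {len(proteins_train)}, test pairs: {len(proteins_test)} "
      f"(~{test_size:.1%} of the dataset by rows)")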
# Load tokenizer and model
model_name = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)
# Load model configuration and modify dropout rates
config = EsmConfig.from_pretrained("facebook/" + model_name)
# config.hidden_dropout_prob = 0.1 # Adjust hidden layer dropout
# config.attention_probs_dropout_prob = 0.1 # Adjust attention dropout
model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)
# Generate a timestamp for the output directory
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
output_dir = f'./interact_output_{timestamp}/'
# Calculate the total number of training steps
num_train_epochs = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs
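# Note: this estimate assumes single-device training; with multiple GPUs the Trainer takes
# fewer optimizer steps per epoch, so max_steps below would then cover more than 4 epochs.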
# Training arguments with cosine learning rate scheduler and gradient clipping
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=8,
    warmup_steps=10,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy='epoch',
    metric_for_best_model='eval_loss',
    save_total_limit=3,
    gradient_accumulation_steps=gradient_accumulation_steps,
    lr_scheduler_type='cosine',
    max_steps=total_steps,        # When set, max_steps takes precedence over num_train_epochs
    gradient_checkpointing=True,  # Enable gradient checkpointing to reduce activation memory
    max_grad_norm=1.0             # Gradient clipping
)
# Optimizer with added weight decay for regularization
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)
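# Note: AdamW over model.parameters() applies weight decay to every parameter, including
# biases and LayerNorm weights, whereas the Trainer's default optimizer exempts those.
# The unusual learning rate looks like the output of a hyperparameter search (assumption).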
# Instantiate the ProteinDataset for training and testing
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)
# Initialize DynamicMaskingCallback
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)
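# The callback only touches the training dataset, so evaluation always uses the
# initial 30% masking rate defined in ProteinDataset.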
# Trainer with the dynamic masking callback; gradient clipping is handled by
# max_grad_norm in the training arguments above
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, None),  # None lets the Trainer build the cosine schedule itself
    callbacks=[dynamic_masking_callback]
)
# Start training
trainer.train()
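# Optional follow-up (a suggested addition, not in the original script): persist the
# fine-tuned weights and tokenizer alongside the Trainer checkpoints.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)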