Soumic committed on
Commit
78e1dd9
1 Parent(s): b056aeb

:rocket: Add dockerfile, app.py, and requirements.txt

Files changed (3)
  1. Dockerfile +16 -0
  2. app.py +337 -0
  3. requirements.txt +33 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Use the official PyTorch Docker image as a base (includes CUDA and PyTorch)
+ FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
+
+ # Set a working directory in the container
+ WORKDIR /workspace
+
+ # Install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --upgrade pip
+ RUN pip install -r requirements.txt
+
+ # Copy the training script
+ COPY app.py .
+
+ # Run the training script
+ CMD ["python", "app.py"]
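
Note: the base image above ships PyTorch 1.11.0 built against CUDA 11.3, while requirements.txt later pins torch==2.4.0, so pip will replace the preinstalled torch wheel during the build. A quick sanity check (a hypothetical snippet, not part of this commit) can confirm which build actually ends up in the container:

    # hypothetical sanity check: print the torch build the container will train with
    import torch

    print("torch:", torch.__version__)
    print("built with CUDA:", torch.version.cuda)
    print("CUDA available:", torch.cuda.is_available())

Running these three print lines inside the container before a long training job avoids surprises from driver or CUDA mismatches.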
app.py ADDED
@@ -0,0 +1,337 @@
+ from typing import Any
+
+ from pytorch_lightning import Trainer, LightningModule, LightningDataModule
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
+ from torch.utils.data import DataLoader, Dataset
+ from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
+ from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+ import torch
+ from torch import nn
+ from datasets import load_dataset
+
+ black = "\u001b[30m"
+ red = "\u001b[31m"
+ green = "\u001b[32m"
+ yellow = "\u001b[33m"
+ blue = "\u001b[34m"
+ magenta = "\u001b[35m"
+ cyan = "\u001b[36m"
+ white = "\u001b[37m"
+
+ FORWARD = "FORWARD_INPUT"
+ BACKWARD = "BACKWARD_INPUT"
+
+ DNA_BERT_6 = "zhihan1996/DNA_bert_6"
+
+
+ class CommonAttentionLayer(nn.Module):
+     def __init__(self, hidden_size, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.attention_linear = nn.Linear(hidden_size, 1)
+         pass
+
+     def forward(self, hidden_states):
+         # Apply linear layer
+         attn_weights = self.attention_linear(hidden_states)
+         # Apply softmax to get attention scores
+         attn_weights = torch.softmax(attn_weights, dim=1)
+         # Apply attention weights to hidden states
+         context_vector = torch.sum(attn_weights * hidden_states, dim=1)
+         return context_vector, attn_weights
+
+
+ class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
+     def forward(self, input, target):
+         return super().forward(input.squeeze(), target.float())
+
+
+ class MQtlDnaBERT6Classifier(nn.Module):
+     def __init__(self,
+                  bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
+                  hidden_size=768,
+                  num_classes=1,
+                  *args,
+                  **kwargs
+                  ):
+         super().__init__(*args, **kwargs)
+
+         self.model_name = "MQtlDnaBERT6Classifier"
+
+         self.bert_model = bert_model
+         self.attention = CommonAttentionLayer(hidden_size)
+         self.classifier = nn.Linear(hidden_size, num_classes)
+         pass
+
+     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids):
+         """
+         # torch.Size([128, 1, 512]) --> [128, 512]
+         input_ids = input_ids.squeeze(dim=1).to(DEVICE)
+         # torch.Size([16, 1, 512]) --> [16, 512]
+         attention_mask = attention_mask.squeeze(dim=1).to(DEVICE)
+         token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE)
+         """
+         bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids
+         )
+
+         last_hidden_state = bert_output.last_hidden_state
+         context_vector, ignore_attention_weight = self.attention(last_hidden_state)
+         y = self.classifier(context_vector)
+         return y
+
+
+ class TorchMetrics:
+     def __init__(self):
+         self.binary_accuracy = BinaryAccuracy()  # .to(device)
+         self.binary_auc = BinaryAUROC()  # .to(device)
+         self.binary_f1_score = BinaryF1Score()  # .to(device)
+         self.binary_precision = BinaryPrecision()  # .to(device)
+         self.binary_recall = BinaryRecall()  # .to(device)
+         pass
+
+     def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
+         # the torcheval maintainers renamed the first argument from preds to input
+         self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
+         pass
+
+     def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
+         b_accuracy = self.binary_accuracy.compute()
+         b_auc = self.binary_auc.compute()
+         b_f1_score = self.binary_f1_score.compute()
+         b_precision = self.binary_precision.compute()
+         b_recall = self.binary_recall.compute()
+         # timber.info(log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
+         log(f"{log_prefix}_accuracy", b_accuracy)
+         log(f"{log_prefix}_auc", b_auc)
+         log(f"{log_prefix}_f1_score", b_f1_score)
+         log(f"{log_prefix}_precision", b_precision)
+         log(f"{log_prefix}_recall", b_recall)
+
+         self.binary_accuracy.reset()
+         self.binary_auc.reset()
+         self.binary_f1_score.reset()
+         self.binary_precision.reset()
+         self.binary_recall.reset()
+         pass
+
+
+ class MQtlBertClassifierLightningModule(LightningModule):
+     def __init__(self,
+                  classifier: nn.Module,
+                  criterion=None,  # nn.BCEWithLogitsLoss(),
+                  regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both L1 and L2, else no regularization
+                  l1_lambda=0.001,
+                  l2_weight_decay=0.001,
+                  *args: Any,
+                  **kwargs: Any):
+         super().__init__(*args, **kwargs)
+         self.classifier = classifier
+         self.criterion = criterion
+         self.train_metrics = TorchMetrics()
+         self.validate_metrics = TorchMetrics()
+         self.test_metrics = TorchMetrics()
+
+         self.regularization = regularization
+         self.l1_lambda = l1_lambda
+         self.l2_weight_decay = l2_weight_decay
+         pass
+
+     def forward(self, x, *args: Any, **kwargs: Any) -> Any:
+         input_ids: torch.Tensor = x["input_ids"]
+         attention_mask: torch.Tensor = x["attention_mask"]
+         token_type_ids: torch.Tensor = x["token_type_ids"]
+         # print(f"\n{ type(input_ids) = }, {input_ids = }")
+         # print(f"{ type(attention_mask) = }, { attention_mask = }")
+         # print(f"{ type(token_type_ids) = }, { token_type_ids = }")
+
+         return self.classifier.forward(input_ids, attention_mask, token_type_ids)
+
+     def configure_optimizers(self) -> OptimizerLRScheduler:
+         # Add weight decay (L2 regularization) to the optimizer when requested
+         weight_decay = 0.0
+         if self.regularization == 2 or self.regularization == 3:
+             weight_decay = self.l2_weight_decay
+         return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay)
+
+     def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         # Loss and metrics on a training batch
+         x, y = batch
+         preds = self.forward(x)
+         loss = self.criterion(preds, y)
+
+         if self.regularization == 1 or self.regularization == 3:  # apply L1 regularization
+             l1_norm = sum(p.abs().sum() for p in self.parameters())
+             loss += self.l1_lambda * l1_norm
+
+         self.log("train_loss", loss)
+         # calculate the scores start
+         self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
+         # calculate the scores end
+         return loss
+
+     def on_train_epoch_end(self) -> None:
+         self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
+         pass
+
+     def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         # Metrics on a validation batch
+         # print(f"debug { batch = }")
+         x, y = batch
+         preds = self.forward(x)
+         loss = 0  # self.criterion(preds, y)
+         self.log("valid_loss", loss)
+         # calculate the scores start
+         self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
+         # calculate the scores end
+         return loss
+
+     def on_validation_epoch_end(self) -> None:
+         self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
+         return None
+
+     def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         # Loss and metrics on a test batch
+         x, y = batch
+         preds = self.forward(x)
+         loss = self.criterion(preds, y)
+         self.log("test_loss", loss)  # do we need this?
+         # calculate the scores start
+         self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
+         # calculate the scores end
+         return loss
+
+     def on_test_epoch_end(self) -> None:
+         self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
+         return None
+
+     pass
+
+
+ class DNABERTDataset(Dataset):
+     def __init__(self, dataset, tokenizer, max_length=512):
+         self.dataset = dataset
+         self.bert_tokenizer = tokenizer
+         self.max_length = max_length
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         sequence = self.dataset[idx]['sequence']  # Fetch the 'sequence' column
+         label = self.dataset[idx]['label']  # Fetch the 'label' column (or whatever target you use)
+
+         # Tokenize the sequence
+         encoded_sequence: BatchEncoding = self.bert_tokenizer(
+             sequence,
+             truncation=True,
+             padding='max_length',
+             max_length=self.max_length,
+             return_tensors='pt'
+         )
+
+         encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()}
+         return encoded_sequence_squeezed, label
+
+
+ class DNABERTDataModule(LightningDataModule):
+     def __init__(self, model_name=DNA_BERT_6, batch_size=8):
+         super().__init__()
+         self.tokenized_dataset = None
+         self.dataset = None
+         self.train_dataset: DNABERTDataset = None
+         self.validate_dataset: DNABERTDataset = None
+         self.test_dataset: DNABERTDataset = None
+         self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
+         self.batch_size = batch_size
+
+     def prepare_data(self):
+         # Download and prepare the dataset
+         self.dataset = load_dataset("fahimfarhan/mqtl-classification-dataset-binned-200")
+
+     def setup(self, stage=None):
+         self.train_dataset = DNABERTDataset(self.dataset['train'], self.tokenizer)
+         self.validate_dataset = DNABERTDataset(self.dataset['validate'], self.tokenizer)
+         self.test_dataset = DNABERTDataset(self.dataset['test'], self.tokenizer)
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=15)
+
+     def val_dataloader(self):
+         return DataLoader(self.validate_dataset, batch_size=self.batch_size, num_workers=15)
+
+     def test_dataloader(self) -> EVAL_DATALOADERS:
+         return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=15)
+
+
+ # Initialize DataModule
+ model_name = "zhihan1996/DNABERT-6"
+ data_module = DNABERTDataModule(model_name=model_name, batch_size=8)
+
+
+ def start_bert(classifier_model, model_save_path, criterion, WINDOW=200, batch_size=4,
+                dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
+     file_suffix = ""
+     if is_binned:
+         file_suffix = "_binned"
+
+     data_module = DNABERTDataModule(batch_size=batch_size)
+
+     # classifier_model = classifier_model.to(DEVICE)
+
+     classifier_module = MQtlBertClassifierLightningModule(
+         classifier=classifier_model,
+         regularization=2, criterion=criterion)
+
+     # if os.path.exists(model_save_path):
+     #     classifier_module.load_state_dict(torch.load(model_save_path))
+
+     classifier_module = classifier_module  # .double()
+
+     # Set up training arguments (currently unused; training is driven by the Lightning Trainer below)
+     training_args = TrainingArguments(
+         output_dir='./results',
+         evaluation_strategy="epoch",
+         per_device_train_batch_size=batch_size,
+         per_device_eval_batch_size=batch_size,
+         num_train_epochs=max_epochs,
+         logging_dir='./logs',
+         report_to="none",  # Disable reporting to WandB, etc.
+     )
+
+     # Prepare data using the DataModule
+     data_module.prepare_data()
+     data_module.setup()
+
+     # Initialize Trainer
+     # trainer = Trainer(
+     #     model=classifier_module,
+     #     args=training_args,
+     #     train_dataset=data_module.tokenized_dataset["train"],
+     #     eval_dataset=data_module.tokenized_dataset["test"],
+     # )
+
+     trainer = Trainer(max_epochs=max_epochs, precision="32")
+
+     # Train, test, and save the model
+     trainer.fit(model=classifier_module, datamodule=data_module)
+     trainer.test(model=classifier_module, datamodule=data_module)
+     torch.save(classifier_module.state_dict(), model_save_path)
+
+     classifier_module.push_to_hub("fahimfarhan/mqtl-classifier-model")
+     pass
+
+
+ if __name__ == "__main__":
+     dataset_folder_prefix = "inputdata/"
+     pytorch_model = MQtlDnaBERT6Classifier()
+     start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
+                criterion=ReshapedBCEWithLogitsLoss(), WINDOW=200, batch_size=4,
+                dataset_folder_prefix=dataset_folder_prefix, max_epochs=2)
+     pass
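
Before launching the full containerized run, a minimal smoke test of the model defined above can catch shape or tokenizer issues early. The sketch below is hypothetical (not part of this commit); it assumes the classes from app.py are importable, and the dummy input is written as whitespace-separated 6-mers because DNA-BERT-6's vocabulary is built from 6-mers:

    import torch
    from transformers import BertTokenizer

    # hypothetical smoke test: push one dummy sequence of 6-mers through the classifier
    tokenizer = BertTokenizer.from_pretrained("zhihan1996/DNA_bert_6")
    model = MQtlDnaBERT6Classifier()  # downloads the DNA-BERT-6 backbone from the Hub
    model.eval()

    sequence = "ACGTAC CGTACG GTACGA TACGAT"  # dummy 6-mer tokens, not real mQTL data
    encoded = tokenizer(sequence, truncation=True, padding="max_length",
                        max_length=512, return_tensors="pt")

    with torch.no_grad():
        logits = model(encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"])
    print(logits.shape)  # expected: torch.Size([1, 1]), one logit per sequence

The tokenizer call mirrors what DNABERTDataset does for each item, so this is effectively a batch of size one without the label.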
requirements.txt ADDED
@@ -0,0 +1,33 @@
+ accelerate  # required by HyenaDNA
+ transformers  # used by app.py (BertModel, BertTokenizer, TrainingArguments); not in the base image
+ datasets
+ pandas
+ polars
+ numpy
+ matplotlib
+ scipy
+ shap
+ scikit-learn
+ skorch==1.0.0
+ six
+ hyperopt
+ requests
+ pyyaml
+ Bio
+ plotly
+ Levenshtein
+ # pytorch
+ captum
+ torch==2.4.0
+ torchvision
+ torchaudio
+ torchsummary
+ torcheval
+ pydot
+ pydotplus
+ PySide2  # matplotlib dependency on Ubuntu; you may need something else for other OS/env setups
+ torchviz
+ gReLU @ git+https://github.com/Genentech/gReLU # @623fee8023aabcef89f0afeedbeafff4b71453af
+ # lightning[extra]  # because I got a warning in the console logs
+ pytorch-lightning  # used by app.py (Trainer, LightningModule, LightningDataModule)
+ torchmetrics
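
Once training inside the container finishes, start_bert writes the trained weights to weights_MQtlDnaBERT6Classifier.pth. A rough sketch of reloading them for inference (hypothetical, assuming the classes from app.py are importable) could look like:

    import torch

    # hypothetical reload: rebuild the same module graph that was trained, then restore the weights
    module = MQtlBertClassifierLightningModule(
        classifier=MQtlDnaBERT6Classifier(),
        regularization=2,
        criterion=ReshapedBCEWithLogitsLoss(),
    )
    state_dict = torch.load("weights_MQtlDnaBERT6Classifier.pth", map_location="cpu")
    module.load_state_dict(state_dict)
    module.eval()

From there, calling module(x) with a dict of input_ids, attention_mask, and token_type_ids (as in the smoke test after app.py) returns the raw logits.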