Soumic committed on
Commit
31eb488
1 Parent(s): f70ddaf

:hammer_and_pick: Move old code to app_v2.py, and rewrite app.py just like the hyenadna finetune script

Files changed (2)
  1. app.py +184 -351
  2. app_v2.py +504 -0
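Overview of the change: the old PyTorch Lightning pipeline moves verbatim into app_v2.py, while app.py is rewritten around the plain Hugging Face Trainer API, mirroring the hyenadna finetune script. A minimal sketch of that new flow is below; the tiny in-memory dataset and the short max_steps are illustrative placeholders rather than values from this commit, and the real script instead streams the mQTL CSV splits through PagingMQTLDataset and reports metrics via compute_metrics_using_sklearn.

import torch
from datasets import Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

PRETRAINED_MODEL_NAME = "zhihan1996/DNA_bert_6"

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, num_labels=2)

# Toy in-memory dataset, only to show the expected {"input_ids", "labels"} format.
sequences = ["ACTG" * 50] * 8
labels = [0, 1] * 4
toy_ds = Dataset.from_dict({"input_ids": tokenizer(sequences)["input_ids"], "labels": labels})
toy_ds.set_format("pt")

training_args = TrainingArguments(
    output_dir="output_dnabert-6-mqtl_classification",
    max_steps=10,                     # keep the smoke test short
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=1e-3,
    save_safetensors=False,
    report_to="none",
)

trainer = Trainer(model=model, args=training_args, train_dataset=toy_ds, eval_dataset=toy_ds)
trainer.train()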
app.py CHANGED
@@ -1,24 +1,20 @@
1
- import logging
2
  import os
3
  import random
4
- from typing import Any
5
 
 
6
  import numpy as np
7
- import pandas as pd
8
- from pytorch_lightning import Trainer, LightningModule, LightningDataModule
9
- from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
10
- from torch.nn.utils.rnn import pad_sequence
11
- from torch.utils.data import DataLoader, Dataset
12
- from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
- from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
14
- from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
- import torch
16
- from torch import nn
17
- from datasets import load_dataset, IterableDataset
18
- from huggingface_hub import PyTorchModelHubMixin
19
-
20
  from dotenv import load_dotenv
21
- from huggingface_hub import login
22
 
23
  timber = logging.getLogger()
24
  # logging.basicConfig(level=logging.DEBUG)
@@ -38,121 +34,7 @@ BACKWARD = "BACKWARD_INPUT"
38
 
39
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
 
41
-
42
- def login_inside_huggingface_virtualmachine():
43
- # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
44
- try:
45
- load_dotenv() # Only useful on your laptop if .env exists
46
- print(".env file loaded successfully.")
47
- except Exception as e:
48
- print(f"Warning: Could not load .env file. Exception: {e}")
49
-
50
- # Try to get the token from environment variables
51
- try:
52
- token = os.getenv("HF_TOKEN")
53
-
54
- if not token:
55
- raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
56
-
57
- # Log in to Hugging Face Hub
58
- login(token)
59
- print("Logged in to Hugging Face Hub successfully.")
60
-
61
- except Exception as e:
62
- print(f"Error during Hugging Face login: {e}")
63
- # Handle the error appropriately (e.g., exit or retry)
64
-
65
-
66
- def one_hot_e(dna_seq: str) -> np.ndarray:
67
- mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
68
- 'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
69
- 'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
70
- 'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
71
- 'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
72
- 'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
73
-
74
- size_of_a_seq: int = len(dna_seq)
75
-
76
- # forward = np.zeros(shape=(size_of_a_seq, 4))
77
-
78
- forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
79
- encoded = np.asarray(forward_list)
80
- encoded_transposed = encoded.transpose() # todo: Needs review
81
- return encoded_transposed
82
-
83
-
84
- def one_hot_e_column(column: pd.Series) -> np.ndarray:
85
- tmp_list: list = [one_hot_e(seq) for seq in column]
86
- encoded_column = np.asarray(tmp_list).astype(np.float32)
87
- return encoded_column
88
-
89
-
90
- def reverse_dna_seq(dna_seq: str) -> str:
91
- # m_reversed = ""
92
- # for i in range(0, len(dna_seq)):
93
- # m_reversed = dna_seq[i] + m_reversed
94
- # return m_reversed
95
- return dna_seq[::-1]
96
-
97
-
98
- def complement_dna_seq(dna_seq: str) -> str:
99
- comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
100
- "a": "t", "c": "g", "t": "a", "g": "c",
101
- "N": "N", "H": "H", "-": "-",
102
- "n": "n", "h": "h"
103
- }
104
-
105
- comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
106
- comp_dna_seq: str = "".join(comp_dna_seq_list)
107
- return comp_dna_seq
108
-
109
-
110
- def reverse_complement_dna_seq(dna_seq: str) -> str:
111
- return reverse_dna_seq(complement_dna_seq(dna_seq))
112
-
113
-
114
- def reverse_complement_column(column: pd.Series) -> np.ndarray:
115
- rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
116
- return rc_column
117
-
118
-
119
- class TorchMetrics:
120
- def __init__(self, device=DEVICE):
121
- self.binary_accuracy = BinaryAccuracy().to(device)
122
- self.binary_auc = BinaryAUROC().to(device)
123
- self.binary_f1_score = BinaryF1Score().to(device)
124
- self.binary_precision = BinaryPrecision().to(device)
125
- self.binary_recall = BinaryRecall().to(device)
126
- pass
127
-
128
- def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
129
- self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
130
- self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
131
- self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
132
- self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
133
- self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
134
- pass
135
-
136
- def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
137
- b_accuracy = self.binary_accuracy.compute()
138
- b_auc = self.binary_auc.compute()
139
- b_f1_score = self.binary_f1_score.compute()
140
- b_precision = self.binary_precision.compute()
141
- b_recall = self.binary_recall.compute()
142
- timber.info(
143
- log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
144
- log(f"{log_prefix}_accuracy", b_accuracy)
145
- log(f"{log_prefix}_auc", b_auc)
146
- log(f"{log_prefix}_f1_score", b_f1_score)
147
- log(f"{log_prefix}_precision", b_precision)
148
- log(f"{log_prefix}_recall", b_recall)
149
-
150
- self.binary_accuracy.reset()
151
- self.binary_auc.reset()
152
- self.binary_f1_score.reset()
153
- self.binary_precision.reset()
154
- self.binary_recall.reset()
155
- pass
156
 
157
 
158
  def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
@@ -194,16 +76,12 @@ class PagingMQTLDataset(IterableDataset):
194
  label = row['label'] # Fetch the 'label' column (or whatever target you use)
195
  if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
196
  sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
197
- # Tokenize the sequence
198
- encoded_sequence: BatchEncoding = self.bert_tokenizer(
199
- sequence,
200
- truncation=True,
201
- padding='max_length',
202
- max_length=self.max_length,
203
- return_tensors='pt'
204
- )
205
- encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()}
206
- return encoded_sequence_squeezed, label
207
 
208
 
209
  class MqtlDataModule(LightningDataModule):
@@ -244,173 +122,17 @@ class MqtlDataModule(LightningDataModule):
244
  return self.test_loader
245
 
246
 
247
- class MQtlBertClassifierLightningModule(LightningModule):
248
- def __init__(self,
249
- classifier: nn.Module,
250
- criterion=None, # nn.BCEWithLogitsLoss(),
251
- regularization: int = 2, # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
252
- l1_lambda=0.001,
253
- l2_wright_decay=0.001,
254
- *args: Any,
255
- **kwargs: Any):
256
- super().__init__(*args, **kwargs)
257
- self.classifier = classifier
258
- self.criterion = criterion
259
- self.train_metrics = TorchMetrics()
260
- self.validate_metrics = TorchMetrics()
261
- self.test_metrics = TorchMetrics()
262
-
263
- self.regularization = regularization
264
- self.l1_lambda = l1_lambda
265
- self.l2_weight_decay = l2_wright_decay
266
- pass
267
-
268
- def forward(self, x, *args: Any, **kwargs: Any) -> Any:
269
- input_ids: torch.tensor = x["input_ids"]
270
- attention_mask: torch.tensor = x["attention_mask"]
271
- token_type_ids: torch.tensor = x["token_type_ids"]
272
- # print(f"\n{ type(input_ids) = }, {input_ids = }")
273
- # print(f"{ type(attention_mask) = }, { attention_mask = }")
274
- # print(f"{ type(token_type_ids) = }, { token_type_ids = }")
275
-
276
- return self.classifier.forward(input_ids, attention_mask, token_type_ids)
277
-
278
- def configure_optimizers(self) -> OptimizerLRScheduler:
279
- # Here we add weight decay (L2 regularization) to the optimizer
280
- weight_decay = 0.0
281
- if self.regularization == 2 or self.regularization == 3:
282
- weight_decay = self.l2_weight_decay
283
- return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay) # , weight_decay=0.005)
284
-
285
- def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
286
- # Accuracy on training batch data
287
- x, y = batch
288
- preds = self.forward(x)
289
- loss = self.criterion(preds, y)
290
-
291
- if self.regularization == 1 or self.regularization == 3: # apply l1 regularization
292
- l1_norm = sum(p.abs().sum() for p in self.parameters())
293
- loss += self.l1_lambda * l1_norm
294
-
295
- self.log("train_loss", loss)
296
- # calculate the scores start
297
- self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
298
- # calculate the scores end
299
- return loss
300
-
301
- def on_train_epoch_end(self) -> None:
302
- self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
303
- pass
304
-
305
- def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
306
- # Accuracy on validation batch data
307
- # print(f"debug { batch = }")
308
- x, y = batch
309
- preds = self.forward(x)
310
- loss = self.criterion(preds, y)
311
- self.log("valid_loss", loss)
312
- # calculate the scores start
313
- self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
314
- # calculate the scores end
315
- return loss
316
-
317
- def on_validation_epoch_end(self) -> None:
318
- self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
319
- return None
320
-
321
- def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
322
- # Accuracy on validation batch data
323
- x, y = batch
324
- preds = self.forward(x)
325
- loss = self.criterion(preds, y)
326
- self.log("test_loss", loss) # do we need this?
327
- # calculate the scores start
328
- self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
329
- # calculate the scores end
330
- return loss
331
-
332
- def on_test_epoch_end(self) -> None:
333
- self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
334
- return None
335
-
336
- pass
337
-
338
-
339
- DNA_BERT_6 = "zhihan1996/DNA_bert_6"
340
-
341
-
342
- class CommonAttentionLayer(nn.Module):
343
- def __init__(self, hidden_size, *args, **kwargs):
344
- super().__init__(*args, **kwargs)
345
- self.attention_linear = nn.Linear(hidden_size, 1)
346
- pass
347
-
348
- def forward(self, hidden_states):
349
- # Apply linear layer
350
- attn_weights = self.attention_linear(hidden_states)
351
- # Apply softmax to get attention scores
352
- attn_weights = torch.softmax(attn_weights, dim=1)
353
- # Apply attention weights to hidden states
354
- context_vector = torch.sum(attn_weights * hidden_states, dim=1)
355
- return context_vector, attn_weights
356
-
357
-
358
- class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
359
- def forward(self, input, target):
360
- return super().forward(input.squeeze(), target.float())
361
-
362
-
363
- class DnaBert6MQTLClassifier(nn.Module, PyTorchModelHubMixin):
364
- def __init__(self,
365
- seq_len: int, model_repository_name: str,
366
- bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
367
- hidden_size=768,
368
- num_classes=1,
369
- *args,
370
- **kwargs
371
- ):
372
- super().__init__(*args, **kwargs)
373
- self.seq_len = seq_len
374
- self.model_repository_name = model_repository_name
375
-
376
- self.model_name = "MQtlDnaBERT6Classifier"
377
-
378
- self.bert_model = bert_model
379
- self.attention = CommonAttentionLayer(hidden_size)
380
- self.classifier = nn.Linear(hidden_size, num_classes)
381
- pass
382
-
383
- def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids):
384
- """
385
- # torch.Size([128, 1, 512]) --> [128, 512]
386
- input_ids = input_ids.squeeze(dim=1).to(DEVICE)
387
- # torch.Size([16, 1, 512]) --> [16, 512]
388
- attention_mask = attention_mask.squeeze(dim=1).to(DEVICE)
389
- token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE)
390
- """
391
- bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
392
- input_ids=input_ids,
393
- attention_mask=attention_mask,
394
- token_type_ids=token_type_ids
395
- )
396
-
397
- last_hidden_state = bert_output.last_hidden_state
398
- context_vector, ignore_attention_weight = self.attention(last_hidden_state)
399
- y = self.classifier(context_vector)
400
- return y
401
-
402
-
403
- def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW=200,
404
- is_binned=True, is_debug=False, max_epochs=10, batch_size=8):
405
- file_suffix = ""
406
- if is_binned:
407
- file_suffix = "_binned"
408
-
409
  data_files = {
410
  # small samples
411
  "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
412
  "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
413
  "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
414
  # large samples
415
  "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
416
  "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
@@ -418,14 +140,12 @@ def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW
418
  }
419
 
420
  dataset_map = None
421
- is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
422
  if is_my_laptop:
423
  dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
424
  else:
425
  dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
426
 
427
- tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
428
-
429
  train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
430
  check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
431
  tokenizer=tokenizer,
@@ -439,66 +159,179 @@ def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW
439
  check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
440
  tokenizer=tokenizer,
441
  seq_len=WINDOW)
 
 
442
 
443
- data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
444
 
445
- classifier_model = classifier_model #.to(DEVICE)
 
446
  try:
447
- classifier_model = classifier_model.from_pretrained(classifier_model.model_repository_name)
448
- except Exception as x:
449
- print(x)
450
-
451
- classifier_module = MQtlBertClassifierLightningModule(
452
- classifier=classifier_model,
453
- regularization=2, criterion=criterion)
454
 
455
- # if os.path.exists(model_save_path):
456
- # classifier_module.load_state_dict(torch.load(model_save_path))
 
457
 
458
- classifier_module = classifier_module # .double()
 
459
 
460
- trainer = Trainer(max_epochs=max_epochs, precision="32")
461
- trainer.fit(model=classifier_module, datamodule=data_module)
462
- timber.info("\n\n")
463
- trainer.test(model=classifier_module, datamodule=data_module)
464
- timber.info("\n\n")
465
- # torch.save(classifier_module.state_dict(), model_save_path) # deprecated, use classifier_model.save_pretrained(model_subdirectory) instead
466
 
467
- # save locally
468
- model_subdirectory = classifier_model.model_repository_name
469
- classifier_model.save_pretrained(model_subdirectory)
470
 
471
- # push to the hub
472
- commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
473
- if is_my_laptop:
474
- commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
475
 
476
- classifier_model.push_to_hub(
477
- repo_id=f"fahimfarhan/{classifier_model.model_repository_name}",
478
- # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
479
- commit_message=commit_message # f":tada: Push model for window size {WINDOW}"
480
- )
481
 
482
- # reload
483
- # classifier_model = classifier_model.from_pretrained(f"fahimfarhan/{classifier_model.model_repository_name}")
484
- # classifier_model = classifier_model.from_pretrained(model_subdirectory)
485
 
 
 
486
  pass
487
 
488
 
489
- if __name__ == '__main__':
490
  login_inside_huggingface_virtualmachine()
491
 
492
- WINDOW = 1000
493
- some_model = DnaBert6MQTLClassifier(seq_len=WINDOW, model_repository_name="dnabert-6-mqtl-classifier")
494
- criterion = ReshapedBCEWithLogitsLoss()
495
-
496
- start_bert(
497
- classifier_model=some_model,
498
- criterion=criterion,
499
- WINDOW=WINDOW,
500
- is_debug=False,
501
- max_epochs=20,
502
- batch_size=16
503
  )
504
  pass
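For contrast with the rewritten file that follows: the Lightning pipeline above fed (BatchEncoding, label) tuples into DataLoaders, whereas the new PagingMQTLDataset yields plain {"input_ids", "labels"} dicts that the Hugging Face Trainer can batch directly. A minimal sketch of that pattern is below; the class name and constructor arguments are illustrative, not part of the commit.

import torch
from torch.utils.data import IterableDataset


class ToyTokenizedDataset(IterableDataset):
    """Illustrative stand-in for the rewritten PagingMQTLDataset output format."""

    def __init__(self, sequences, labels, tokenizer):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer

    def __iter__(self):
        for sequence, label in zip(self.sequences, self.labels):
            # Yield the same dict shape the new preprocess() returns, so the
            # Hugging Face Trainer needs no custom collate_fn.
            input_ids = self.tokenizer(sequence)["input_ids"]
            yield {
                "input_ids": torch.tensor(input_ids),
                "labels": torch.tensor(label),
            }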
1
  import os
2
  import random
 
3
 
4
+ import huggingface_hub
5
  import numpy as np
6
+ from datasets import load_dataset, Dataset
7
  from dotenv import load_dotenv
8
+ from pytorch_lightning import LightningDataModule
9
+ from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
10
+ from torch.utils.data import DataLoader, IterableDataset
11
+ from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
12
+ # from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, BertModel
14
+ from transformers import TrainingArguments, Trainer
15
+ import torch
16
+ import logging
17
+ import wandb
18
 
19
  timber = logging.getLogger()
20
  # logging.basicConfig(level=logging.DEBUG)
 
34
 
35
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
 
37
+ PRETRAINED_MODEL_NAME: str = "zhihan1996/DNA_bert_6"
38
 
39
 
40
  def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
 
76
  label = row['label'] # Fetch the 'label' column (or whatever target you use)
77
  if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
78
  sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
79
+
80
+ input_ids = self.bert_tokenizer(sequence)["input_ids"]
81
+ tokenized_tensor = torch.tensor(input_ids)
82
+ label_tensor = torch.tensor(label)
83
+ output_dict = {"input_ids": tokenized_tensor, "labels": label_tensor} # so this is now you do it?
84
+ return output_dict # tokenized_tensor, label_tensor
85
 
86
 
87
  class MqtlDataModule(LightningDataModule):
 
122
  return self.test_loader
123
 
124
 
125
+ def create_paging_train_val_test_datasets(tokenizer, WINDOW, is_debug, batch_size=1000):
126
  data_files = {
127
  # small samples
128
  "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
129
  "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
130
  "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
131
+ # medium samples
132
+ "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
133
+ "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
134
+ "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
135
+
136
  # large samples
137
  "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
138
  "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
 
140
  }
141
 
142
  dataset_map = None
143
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
144
  if is_my_laptop:
145
  dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
146
  else:
147
  dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
148
 
 
 
149
  train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
150
  check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
151
  tokenizer=tokenizer,
 
159
  check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
160
  tokenizer=tokenizer,
161
  seq_len=WINDOW)
162
+ # data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
163
+ return train_dataset, val_dataset, test_dataset
164
 
 
165
 
166
+ def login_inside_huggingface_virtualmachine():
167
+ # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
168
  try:
169
+ load_dotenv() # Only useful on your laptop if .env exists
170
+ print(".env file loaded successfully.")
171
+ except Exception as e:
172
+ print(f"Warning: Could not load .env file. Exception: {e}")
173
 
174
+ # Try to get the token from environment variables
175
+ try:
176
+ token = os.getenv("HF_TOKEN")
177
 
178
+ if not token:
179
+ raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
180
 
181
+ # Log in to Hugging Face Hub
182
+ huggingface_hub.login(token)
183
+ print("Logged in to Hugging Face Hub successfully.")
184
 
185
+ except Exception as e:
186
+ print(f"Error during Hugging Face login: {e}")
187
+ # Handle the error appropriately (e.g., exit or retry)
188
 
189
+ # wand db login
190
+ try:
191
+ api_key = os.getenv("WAND_DB_API_KEY")
192
+ timber.info(f"{api_key = }")
193
 
194
+ if not api_key:
195
+ raise ValueError("WAND_DB_API_KEY not found. Make sure to set it in the environment variables or .env file.")
196
 
197
+ # Log in to Hugging Face Hub
198
+ wandb.login(key=api_key)
199
+ print("Logged in to wand db successfully.")
200
 
201
+ except Exception as e:
202
+ print(f"Error during wand db Face login: {e}")
203
  pass
204
 
205
 
206
+ # use sklearn cz torchmetrics.classification gave array index out of bound exception :/ (whatever it is called in python)
207
+ def compute_metrics_using_sklearn(p):
208
+ try:
209
+ pred, labels = p
210
+
211
+ # Get predicted class labels
212
+ pred_labels = np.argmax(pred, axis=1)
213
+
214
+ # Get predicted probabilities for the positive class
215
+ pred_probs = pred[:, 1] # Assuming binary classification and 2 output classes
216
+
217
+ accuracy = accuracy_score(y_true=labels, y_pred=pred_labels)
218
+ recall = recall_score(y_true=labels, y_pred=pred_labels)
219
+ precision = precision_score(y_true=labels, y_pred=pred_labels)
220
+ f1 = f1_score(y_true=labels, y_pred=pred_labels)
221
+ roc_auc = roc_auc_score(y_true=labels, y_score=pred_probs)
222
+
223
+ return {"accuracy": accuracy, "roc_auc": roc_auc, "precision": precision, "recall": recall, "f1": f1}
224
+
225
+ except Exception as x:
226
+ print(f"compute_metrics_using_sklearn failed with exception: {x}")
227
+ return {"accuracy": 0, "roc_auc": 0, "precision": 0, "recall": 0, "f1": 0}
228
+
229
+
230
+ def start():
231
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
232
+
233
  login_inside_huggingface_virtualmachine()
234
+ WINDOW = 4000
235
+ batch_size = 100
236
+ model_local_directory = f"my-awesome-model-{WINDOW}"
237
+ model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
238
+
239
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
240
+
241
+ tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
242
+ classifier_model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, num_labels=2)
243
+ args = {
244
+ "output_dir": "output_dnabert-6-mqtl_classification",
245
+ "num_train_epochs": 1,
246
+ "max_steps": 100, # train 36k + val 4k = 40k
247
+ # Set the number of steps you expect to train, originally 1000, takes too much time. So I set it to 10 to run faster and check my code/pipeline
248
+ "run_name": "laptop_run_dna-bert-6-mqtl_classification", # Override run_name here
249
+ "per_device_train_batch_size": 1,
250
+ "gradient_accumulation_steps": 32,
251
+ "gradient_checkpointing": True,
252
+ "learning_rate": 1e-3,
253
+ "save_safetensors": False, # I added it. this solves the runtime error!
254
+ # not sure if it is a good idea. sklearn may slow down training, causing time loss... if so, disable these 2 lines below
255
+ "evaluation_strategy": "epoch", # To calculate metrics per epoch
256
+ "logging_strategy": "epoch" # Extra: to log training data stats for loss
257
+ }
258
 
259
+ training_args = TrainingArguments(**args)
260
+ # train_dataset, eval_dataset, test_dataset = create_data_module(tokenizer=tokenizer, WINDOW=WINDOW,
261
+ # batch_size=batch_size,
262
+ # is_debug=False)
263
+ """ # example code
264
+ max_length = 32_000
265
+ sequence = 'ACTG' * int(max_length / 4)
266
+ # sequence = 'ACTG' * int(1000) # seq_len = 4000 it works!
267
+ sequence = [sequence] * 8 # Create 8 identical samples
268
+ tokenized = tokenizer(sequence)["input_ids"]
269
+ labels = [0, 1] * 4
270
+
271
+ # Create a dataset for training
272
+ run_the_code_ds = Dataset.from_dict({"input_ids": tokenized, "labels": labels})
273
+ run_the_code_ds.set_format("pt")
274
+ """
275
+
276
+ train_ds, val_ds, test_ds = create_paging_train_val_test_datasets(tokenizer, WINDOW=WINDOW, is_debug=False)
277
+ # train_ds, val_ds, test_ds = run_the_code_ds, run_the_code_ds, run_the_code_ds
278
+ # train_ds.set_format("pt") # doesn't work!
279
+
280
+ trainer = Trainer(
281
+ model=classifier_model,
282
+ args=training_args,
283
+ train_dataset=train_ds,
284
+ eval_dataset=val_ds,
285
+ compute_metrics=compute_metrics_using_sklearn # torch_metrics.compute_metrics
286
  )
287
+ # train, and validate
288
+ result = trainer.train()
289
+ try:
290
+ print(f"{result = }")
291
+ except Exception as x:
292
+ print(f"{x = }")
293
+
294
+ # testing
295
+ try:
296
+ # with torch.no_grad(): # didn't work :/
297
+ test_results = trainer.evaluate(eval_dataset=test_ds)
298
+ print(f"{test_results = }")
299
+ except Exception as oome:
300
+ print(f"{oome = }")
301
+ finally:
302
+ # save the model
303
+ model_name = "DnaBert6MQtlClassifier"
304
+
305
+ classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)
306
+
307
+ # push to the hub
308
+ commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
309
+ if is_my_laptop:
310
+ commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
311
+
312
+ classifier_model.push_to_hub(
313
+ repo_id=model_remote_repository,
314
+ # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
315
+ commit_message=commit_message, # f":tada: Push model for window size {WINDOW}"
316
+ safe_serialization=False
317
+ )
318
+ pass
319
+
320
+
321
+ def interprete_demo():
322
+ is_my_laptop = True
323
+ WINDOW = 4000
324
+ batch_size = 100
325
+ model_local_directory = f"my-awesome-model-{WINDOW}"
326
+ model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
327
+
328
+ try:
329
+ classifier_model = AutoModel.from_pretrained(model_remote_repository)
330
+ # todo: use captum / gentech-grelu to interpret the model
331
+ except Exception as x:
332
+ print(x)
333
+
334
+
335
+ if __name__ == '__main__':
336
+ start()
337
  pass
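One note before the moved-out file below: the Trainer hands compute_metrics_using_sklearn a (predictions, label_ids) pair, where predictions are the raw two-class logits. A standalone sanity check of that contract, with made-up numbers rather than results from any run:

import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

logits = np.array([[2.0, -1.0],   # argmax -> class 0
                   [-0.5, 1.5],   # argmax -> class 1
                   [0.1, 0.3]])   # argmax -> class 1
labels = np.array([0, 1, 0])

pred_labels = np.argmax(logits, axis=1)  # array([0, 1, 1])
pred_scores = logits[:, 1]               # positive-class scores; ROC-AUC only needs their ranking

print(accuracy_score(y_true=labels, y_pred=pred_labels))  # 2 of 3 correct, so ~0.667
print(roc_auc_score(y_true=labels, y_score=pred_scores))  # the positive example outranks both negatives, so 1.0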
app_v2.py ADDED
@@ -0,0 +1,504 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from pytorch_lightning import Trainer, LightningModule, LightningDataModule
9
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ from torch.utils.data import DataLoader, Dataset
12
+ from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
+ from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
14
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
+ import torch
16
+ from torch import nn
17
+ from datasets import load_dataset, IterableDataset
18
+ from huggingface_hub import PyTorchModelHubMixin
19
+
20
+ from dotenv import load_dotenv
21
+ from huggingface_hub import login
22
+
23
+ timber = logging.getLogger()
24
+ # logging.basicConfig(level=logging.DEBUG)
25
+ logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
26
+
27
+ black = "\u001b[30m"
28
+ red = "\u001b[31m"
29
+ green = "\u001b[32m"
30
+ yellow = "\u001b[33m"
31
+ blue = "\u001b[34m"
32
+ magenta = "\u001b[35m"
33
+ cyan = "\u001b[36m"
34
+ white = "\u001b[37m"
35
+
36
+ FORWARD = "FORWARD_INPUT"
37
+ BACKWARD = "BACKWARD_INPUT"
38
+
39
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+
42
+ def login_inside_huggingface_virtualmachine():
43
+ # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
44
+ try:
45
+ load_dotenv() # Only useful on your laptop if .env exists
46
+ print(".env file loaded successfully.")
47
+ except Exception as e:
48
+ print(f"Warning: Could not load .env file. Exception: {e}")
49
+
50
+ # Try to get the token from environment variables
51
+ try:
52
+ token = os.getenv("HF_TOKEN")
53
+
54
+ if not token:
55
+ raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
56
+
57
+ # Log in to Hugging Face Hub
58
+ login(token)
59
+ print("Logged in to Hugging Face Hub successfully.")
60
+
61
+ except Exception as e:
62
+ print(f"Error during Hugging Face login: {e}")
63
+ # Handle the error appropriately (e.g., exit or retry)
64
+
65
+
66
+ def one_hot_e(dna_seq: str) -> np.ndarray:
67
+ mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
68
+ 'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
69
+ 'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
70
+ 'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
71
+ 'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
72
+ 'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
73
+
74
+ size_of_a_seq: int = len(dna_seq)
75
+
76
+ # forward = np.zeros(shape=(size_of_a_seq, 4))
77
+
78
+ forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
79
+ encoded = np.asarray(forward_list)
80
+ encoded_transposed = encoded.transpose() # todo: Needs review
81
+ return encoded_transposed
82
+
83
+
84
+ def one_hot_e_column(column: pd.Series) -> np.ndarray:
85
+ tmp_list: list = [one_hot_e(seq) for seq in column]
86
+ encoded_column = np.asarray(tmp_list).astype(np.float32)
87
+ return encoded_column
88
+
89
+
90
+ def reverse_dna_seq(dna_seq: str) -> str:
91
+ # m_reversed = ""
92
+ # for i in range(0, len(dna_seq)):
93
+ # m_reversed = dna_seq[i] + m_reversed
94
+ # return m_reversed
95
+ return dna_seq[::-1]
96
+
97
+
98
+ def complement_dna_seq(dna_seq: str) -> str:
99
+ comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
100
+ "a": "t", "c": "g", "t": "a", "g": "c",
101
+ "N": "N", "H": "H", "-": "-",
102
+ "n": "n", "h": "h"
103
+ }
104
+
105
+ comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
106
+ comp_dna_seq: str = "".join(comp_dna_seq_list)
107
+ return comp_dna_seq
108
+
109
+
110
+ def reverse_complement_dna_seq(dna_seq: str) -> str:
111
+ return reverse_dna_seq(complement_dna_seq(dna_seq))
112
+
113
+
114
+ def reverse_complement_column(column: pd.Series) -> np.ndarray:
115
+ rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
116
+ return rc_column
117
+
118
+
119
+ class TorchMetrics:
120
+ def __init__(self, device=DEVICE):
121
+ self.binary_accuracy = BinaryAccuracy().to(device)
122
+ self.binary_auc = BinaryAUROC().to(device)
123
+ self.binary_f1_score = BinaryF1Score().to(device)
124
+ self.binary_precision = BinaryPrecision().to(device)
125
+ self.binary_recall = BinaryRecall().to(device)
126
+ pass
127
+
128
+ def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
129
+ self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
130
+ self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
131
+ self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
132
+ self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
133
+ self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
134
+ pass
135
+
136
+ def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
137
+ b_accuracy = self.binary_accuracy.compute()
138
+ b_auc = self.binary_auc.compute()
139
+ b_f1_score = self.binary_f1_score.compute()
140
+ b_precision = self.binary_precision.compute()
141
+ b_recall = self.binary_recall.compute()
142
+ timber.info(
143
+ log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
144
+ log(f"{log_prefix}_accuracy", b_accuracy)
145
+ log(f"{log_prefix}_auc", b_auc)
146
+ log(f"{log_prefix}_f1_score", b_f1_score)
147
+ log(f"{log_prefix}_precision", b_precision)
148
+ log(f"{log_prefix}_recall", b_recall)
149
+
150
+ self.binary_accuracy.reset()
151
+ self.binary_auc.reset()
152
+ self.binary_f1_score.reset()
153
+ self.binary_precision.reset()
154
+ self.binary_recall.reset()
155
+ pass
156
+
157
+
158
+ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
159
+ start = 0
160
+ end = len(seq)
161
+ rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
162
+ random_end = rand_pos + len(DEBUG_MOTIF)
163
+ output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
164
+ assert len(seq) == len(output)
165
+ return output
166
+
167
+
168
+ class PagingMQTLDataset(IterableDataset):
169
+ def __init__(self,
170
+ m_dataset,
171
+ seq_len,
172
+ tokenizer,
173
+ max_length=512,
174
+ check_if_pipeline_is_ok_by_inserting_debug_motif=False):
175
+ self.dataset = m_dataset
176
+ self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
177
+ self.debug_motif = "ATCGCCTA"
178
+ self.seq_len = seq_len
179
+
180
+ self.bert_tokenizer = tokenizer
181
+ self.max_length = max_length
182
+ pass
183
+
184
+ def __iter__(self):
185
+ for row in self.dataset:
186
+ processed = self.preprocess(row)
187
+ if processed is not None:
188
+ yield processed
189
+
190
+ def preprocess(self, row):
191
+ sequence = row['sequence'] # Fetch the 'sequence' column
192
+ if len(sequence) != self.seq_len:
193
+ return None # skip problematic row!
194
+ label = row['label'] # Fetch the 'label' column (or whatever target you use)
195
+ if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
196
+ sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
197
+ # Tokenize the sequence
198
+ encoded_sequence: BatchEncoding = self.bert_tokenizer(
199
+ sequence,
200
+ truncation=True,
201
+ padding='max_length',
202
+ max_length=self.max_length,
203
+ return_tensors='pt'
204
+ )
205
+ encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()}
206
+ return encoded_sequence_squeezed, label
207
+
208
+
209
+ class MqtlDataModule(LightningDataModule):
210
+ def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
211
+ super().__init__()
212
+ self.batch_size = batch_size
213
+ self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
214
+ # collate_fn=collate_fn,
215
+ num_workers=1,
216
+ # persistent_workers=True
217
+ )
218
+ self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
219
+ # collate_fn=collate_fn,
220
+ num_workers=1,
221
+ # persistent_workers=True
222
+ )
223
+ self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
224
+ # collate_fn=collate_fn,
225
+ num_workers=1,
226
+ # persistent_workers=True
227
+ )
228
+ pass
229
+
230
+ def prepare_data(self):
231
+ pass
232
+
233
+ def setup(self, stage: str) -> None:
234
+ timber.info(f"inside setup: {stage = }")
235
+ pass
236
+
237
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
238
+ return self.train_loader
239
+
240
+ def val_dataloader(self) -> EVAL_DATALOADERS:
241
+ return self.validate_loader
242
+
243
+ def test_dataloader(self) -> EVAL_DATALOADERS:
244
+ return self.test_loader
245
+
246
+
247
+ class MQtlBertClassifierLightningModule(LightningModule):
248
+ def __init__(self,
249
+ classifier: nn.Module,
250
+ criterion=None, # nn.BCEWithLogitsLoss(),
251
+ regularization: int = 2, # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
252
+ l1_lambda=0.001,
253
+ l2_wright_decay=0.001,
254
+ *args: Any,
255
+ **kwargs: Any):
256
+ super().__init__(*args, **kwargs)
257
+ self.classifier = classifier
258
+ self.criterion = criterion
259
+ self.train_metrics = TorchMetrics()
260
+ self.validate_metrics = TorchMetrics()
261
+ self.test_metrics = TorchMetrics()
262
+
263
+ self.regularization = regularization
264
+ self.l1_lambda = l1_lambda
265
+ self.l2_weight_decay = l2_wright_decay
266
+ pass
267
+
268
+ def forward(self, x, *args: Any, **kwargs: Any) -> Any:
269
+ input_ids: torch.tensor = x["input_ids"]
270
+ attention_mask: torch.tensor = x["attention_mask"]
271
+ token_type_ids: torch.tensor = x["token_type_ids"]
272
+ # print(f"\n{ type(input_ids) = }, {input_ids = }")
273
+ # print(f"{ type(attention_mask) = }, { attention_mask = }")
274
+ # print(f"{ type(token_type_ids) = }, { token_type_ids = }")
275
+
276
+ return self.classifier.forward(input_ids, attention_mask, token_type_ids)
277
+
278
+ def configure_optimizers(self) -> OptimizerLRScheduler:
279
+ # Here we add weight decay (L2 regularization) to the optimizer
280
+ weight_decay = 0.0
281
+ if self.regularization == 2 or self.regularization == 3:
282
+ weight_decay = self.l2_weight_decay
283
+ return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay) # , weight_decay=0.005)
284
+
285
+ def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
286
+ # Accuracy on training batch data
287
+ x, y = batch
288
+ preds = self.forward(x)
289
+ loss = self.criterion(preds, y)
290
+
291
+ if self.regularization == 1 or self.regularization == 3: # apply l1 regularization
292
+ l1_norm = sum(p.abs().sum() for p in self.parameters())
293
+ loss += self.l1_lambda * l1_norm
294
+
295
+ self.log("train_loss", loss)
296
+ # calculate the scores start
297
+ self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
298
+ # calculate the scores end
299
+ return loss
300
+
301
+ def on_train_epoch_end(self) -> None:
302
+ self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
303
+ pass
304
+
305
+ def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
306
+ # Accuracy on validation batch data
307
+ # print(f"debug { batch = }")
308
+ x, y = batch
309
+ preds = self.forward(x)
310
+ loss = self.criterion(preds, y)
311
+ self.log("valid_loss", loss)
312
+ # calculate the scores start
313
+ self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
314
+ # calculate the scores end
315
+ return loss
316
+
317
+ def on_validation_epoch_end(self) -> None:
318
+ self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
319
+ return None
320
+
321
+ def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
322
+ # Accuracy on validation batch data
323
+ x, y = batch
324
+ preds = self.forward(x)
325
+ loss = self.criterion(preds, y)
326
+ self.log("test_loss", loss) # do we need this?
327
+ # calculate the scores start
328
+ self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
329
+ # calculate the scores end
330
+ return loss
331
+
332
+ def on_test_epoch_end(self) -> None:
333
+ self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
334
+ return None
335
+
336
+ pass
337
+
338
+
339
+ DNA_BERT_6 = "zhihan1996/DNA_bert_6"
340
+
341
+
342
+ class CommonAttentionLayer(nn.Module):
343
+ def __init__(self, hidden_size, *args, **kwargs):
344
+ super().__init__(*args, **kwargs)
345
+ self.attention_linear = nn.Linear(hidden_size, 1)
346
+ pass
347
+
348
+ def forward(self, hidden_states):
349
+ # Apply linear layer
350
+ attn_weights = self.attention_linear(hidden_states)
351
+ # Apply softmax to get attention scores
352
+ attn_weights = torch.softmax(attn_weights, dim=1)
353
+ # Apply attention weights to hidden states
354
+ context_vector = torch.sum(attn_weights * hidden_states, dim=1)
355
+ return context_vector, attn_weights
356
+
357
+
358
+ class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
359
+ def forward(self, input, target):
360
+ return super().forward(input.squeeze(), target.float())
361
+
362
+
363
+ class DnaBert6MQTLClassifier(nn.Module, PyTorchModelHubMixin):
364
+ def __init__(self,
365
+ seq_len: int, model_repository_name: str,
366
+ bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
367
+ hidden_size=768,
368
+ num_classes=1,
369
+ *args,
370
+ **kwargs
371
+ ):
372
+ super().__init__(*args, **kwargs)
373
+ self.seq_len = seq_len
374
+ self.model_repository_name = model_repository_name
375
+
376
+ self.model_name = "MQtlDnaBERT6Classifier"
377
+
378
+ self.bert_model = bert_model
379
+ self.attention = CommonAttentionLayer(hidden_size)
380
+ self.classifier = nn.Linear(hidden_size, num_classes)
381
+ pass
382
+
383
+ def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids):
384
+ """
385
+ # torch.Size([128, 1, 512]) --> [128, 512]
386
+ input_ids = input_ids.squeeze(dim=1).to(DEVICE)
387
+ # torch.Size([16, 1, 512]) --> [16, 512]
388
+ attention_mask = attention_mask.squeeze(dim=1).to(DEVICE)
389
+ token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE)
390
+ """
391
+ bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
392
+ input_ids=input_ids,
393
+ attention_mask=attention_mask,
394
+ token_type_ids=token_type_ids
395
+ )
396
+
397
+ last_hidden_state = bert_output.last_hidden_state
398
+ context_vector, ignore_attention_weight = self.attention(last_hidden_state)
399
+ y = self.classifier(context_vector)
400
+ return y
401
+
402
+
403
+ def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW=200,
404
+ is_binned=True, is_debug=False, max_epochs=10, batch_size=8):
405
+ file_suffix = ""
406
+ if is_binned:
407
+ file_suffix = "_binned"
408
+
409
+ data_files = {
410
+ # small samples
411
+ "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
412
+ "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
413
+ "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
414
+ # large samples
415
+ "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
416
+ "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
417
+ "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
418
+ }
419
+
420
+ dataset_map = None
421
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
422
+ if is_my_laptop:
423
+ dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
424
+ else:
425
+ dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
426
+
427
+ tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
428
+
429
+ train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
430
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
431
+ tokenizer=tokenizer,
432
+ seq_len=WINDOW
433
+ )
434
+ val_dataset = PagingMQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
435
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
436
+ tokenizer=tokenizer,
437
+ seq_len=WINDOW)
438
+ test_dataset = PagingMQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
439
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
440
+ tokenizer=tokenizer,
441
+ seq_len=WINDOW)
442
+
443
+ data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
444
+
445
+ classifier_model = classifier_model #.to(DEVICE)
446
+ try:
447
+ classifier_model = classifier_model.from_pretrained(classifier_model.model_repository_name)
448
+ except Exception as x:
449
+ print(x)
450
+
451
+ classifier_module = MQtlBertClassifierLightningModule(
452
+ classifier=classifier_model,
453
+ regularization=2, criterion=criterion)
454
+
455
+ # if os.path.exists(model_save_path):
456
+ # classifier_module.load_state_dict(torch.load(model_save_path))
457
+
458
+ classifier_module = classifier_module # .double()
459
+
460
+ trainer = Trainer(max_epochs=max_epochs, precision="32")
461
+ trainer.fit(model=classifier_module, datamodule=data_module)
462
+ timber.info("\n\n")
463
+ trainer.test(model=classifier_module, datamodule=data_module)
464
+ timber.info("\n\n")
465
+ # torch.save(classifier_module.state_dict(), model_save_path) # deprecated, use classifier_model.save_pretrained(model_subdirectory) instead
466
+
467
+ # save locally
468
+ model_subdirectory = classifier_model.model_repository_name
469
+ classifier_model.save_pretrained(model_subdirectory)
470
+
471
+ # push to the hub
472
+ commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
473
+ if is_my_laptop:
474
+ commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
475
+
476
+ classifier_model.push_to_hub(
477
+ repo_id=f"fahimfarhan/{classifier_model.model_repository_name}",
478
+ # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
479
+ commit_message=commit_message # f":tada: Push model for window size {WINDOW}"
480
+ )
481
+
482
+ # reload
483
+ # classifier_model = classifier_model.from_pretrained(f"fahimfarhan/{classifier_model.model_repository_name}")
484
+ # classifier_model = classifier_model.from_pretrained(model_subdirectory)
485
+
486
+ pass
487
+
488
+
489
+ if __name__ == '__main__':
490
+ login_inside_huggingface_virtualmachine()
491
+
492
+ WINDOW = 1000
493
+ some_model = DnaBert6MQTLClassifier(seq_len=WINDOW, model_repository_name="dnabert-6-mqtl-classifier")
494
+ criterion = ReshapedBCEWithLogitsLoss()
495
+
496
+ start_bert(
497
+ classifier_model=some_model,
498
+ criterion=criterion,
499
+ WINDOW=WINDOW,
500
+ is_debug=False,
501
+ max_epochs=20,
502
+ batch_size=16
503
+ )
504
+ pass
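Not part of the commit: once start() in the new app.py has pushed a checkpoint, reloading it for a quick prediction (or for the interpretation step interprete_demo() hints at) could look roughly like the sketch below. The repo id follows the fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW} pattern used above and is assumed to exist for the chosen WINDOW.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

WINDOW = 4000  # must match a window size that was actually trained and pushed
repo_id = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(repo_id)
model.eval()

sequence = "ACGT" * (WINDOW // 4)  # placeholder sequence of the expected length
inputs = tokenizer(sequence, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.softmax(dim=-1))  # [P(class 0), P(class 1)] for the single input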