Soumic
:hammer_and_pick: Move old code to app_v2.py, and rewrite app.py just like hyenadna finetune
31eb488
import logging
import os
import random
from typing import Any
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
import torch
from torch import nn
from datasets import load_dataset, IterableDataset
from huggingface_hub import PyTorchModelHubMixin
from dotenv import load_dotenv
from huggingface_hub import login
# Root logger shared by the whole script ("timber" is this project's name for it).
timber = logging.getLogger()
# Use level=logging.DEBUG instead to print more logs.
logging.basicConfig(level=logging.INFO)

# ANSI escape sequences for foreground colours 30..37, prepended to log messages.
black, red, green, yellow, blue, magenta, cyan, white = (
    f"\u001b[{code}m" for code in range(30, 38)
)

# Identifiers for the forward / reverse-complement input streams.
FORWARD, BACKWARD = "FORWARD_INPUT", "BACKWARD_INPUT"

# Prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def login_inside_huggingface_virtualmachine():
    """Best-effort login to the Hugging Face Hub using the ``HF_TOKEN`` environment variable.

    A local ``.env`` file is loaded first when one is found (useful on a
    developer laptop). On a Hugging Face Space the token is expected to already
    be present in the environment.

    The function never raises: problems are reported with ``print`` and the
    caller simply continues unauthenticated.
    """
    # load_dotenv() does not raise when no .env file exists; it returns False
    # when nothing was loaded. Report that honestly instead of claiming success.
    try:
        env_loaded = load_dotenv()  # Only useful on your laptop if .env exists
    except Exception as e:  # e.g. an unreadable or malformed .env file
        print(f"Warning: Could not load .env file. Exception: {e}")
    else:
        if env_loaded:
            print(".env file loaded successfully.")
        else:
            print("No variables loaded from a .env file; relying on the existing environment.")

    token = os.getenv("HF_TOKEN")
    if not token:
        print("Error during Hugging Face login: HF_TOKEN not found. "
              "Make sure to set it in the environment variables or .env file.")
        return

    try:
        login(token=token)
    except Exception as e:  # network / auth failure: keep going without a login
        print(f"Error during Hugging Face login: {e}")
        return
    print("Logged in to Hugging Face Hub successfully.")
# One-hot rows keyed by nucleotide character (both cases).
# N, H and '-' encode as all-zero rows. Built once at import time rather than
# on every call.
_ONE_HOT_ROWS = {
    "A": (1.0, 0.0, 0.0, 0.0), "a": (1.0, 0.0, 0.0, 0.0),
    "C": (0.0, 1.0, 0.0, 0.0), "c": (0.0, 1.0, 0.0, 0.0),
    "G": (0.0, 0.0, 1.0, 0.0), "g": (0.0, 0.0, 1.0, 0.0),
    "T": (0.0, 0.0, 0.0, 1.0), "t": (0.0, 0.0, 0.0, 1.0),
    "N": (0.0, 0.0, 0.0, 0.0), "n": (0.0, 0.0, 0.0, 0.0),
    "H": (0.0, 0.0, 0.0, 0.0), "h": (0.0, 0.0, 0.0, 0.0),
    "-": (0.0, 0.0, 0.0, 0.0),
}


def one_hot_e(dna_seq: str) -> np.ndarray:
    """One-hot encode a DNA sequence.

    Args:
        dna_seq: sequence over ``ACGTNH-`` (upper or lower case).

    Returns:
        float64 array of shape ``(4, len(dna_seq))``. Row ``k`` flags base
        ``"ACGT"[k]``; N, H and ``-`` give an all-zero column.

    Raises:
        KeyError: if ``dna_seq`` contains any other character.
    """
    if not dna_seq:
        # np.asarray([]) would be 1-D with shape (0,); keep the (4, L) contract.
        return np.zeros((4, 0))
    encoded = np.asarray([_ONE_HOT_ROWS[base] for base in dna_seq], dtype=np.float64)
    # (L, 4) -> (4, L): one channel per nucleotide.
    return encoded.transpose()
def one_hot_e_column(column: pd.Series) -> np.ndarray:
    """One-hot encode every sequence in ``column`` with ``one_hot_e``.

    Per-sequence encodings are stacked along a new leading axis and cast to
    float32.
    """
    per_sequence = list(map(one_hot_e, column))
    return np.asarray(per_sequence).astype(np.float32)
def reverse_dna_seq(dna_seq: str) -> str:
    """Return ``dna_seq`` with its characters in reverse order (no complementing)."""
    return "".join(reversed(dna_seq))
def complement_dna_seq(dna_seq: str) -> str:
    """Return the base-wise complement of ``dna_seq``, preserving case.

    A<->T and C<->G are swapped. N, H and '-' (and lowercase n, h) map to
    themselves.

    Raises:
        KeyError: on any other character.
    """
    partner = dict(zip("ACTGactgNH-nh", "TGACtgacNH-nh"))
    return "".join(map(partner.__getitem__, dna_seq))
def reverse_complement_dna_seq(dna_seq: str) -> str:
    """Return the reverse complement of ``dna_seq`` (complement, then reverse)."""
    complemented = complement_dna_seq(dna_seq)
    return complemented[::-1]
def reverse_complement_column(column: pd.Series) -> list:
    """Reverse-complement every sequence in ``column``.

    Args:
        column: iterable of DNA strings (typically a pandas Series).

    Returns:
        A plain Python list of reverse-complemented strings, in input order.
        No NumPy array is built, so the return annotation is ``list``.
    """
    return [reverse_complement_dna_seq(seq) for seq in column]
class TorchMetrics:
    """Epoch-level accumulator for binary-classification metrics.

    Tracks accuracy, AUROC, F1, precision and recall as torchmetrics objects on
    ``device``. Call ``update_on_each_step`` once per batch and
    ``compute_and_reset_on_epoch_end`` once per epoch.
    """

    def __init__(self, device=DEVICE):
        self.binary_accuracy = BinaryAccuracy().to(device)
        self.binary_auc = BinaryAUROC().to(device)
        self.binary_f1_score = BinaryF1Score().to(device)
        self.binary_precision = BinaryPrecision().to(device)
        self.binary_recall = BinaryRecall().to(device)

    def _metrics(self):
        """Return the metric objects in a fixed order: accuracy, AUROC, F1, precision, recall."""
        return (
            self.binary_accuracy,
            self.binary_auc,
            self.binary_f1_score,
            self.binary_precision,
            self.binary_recall,
        )

    def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):
        """Feed one batch of predictions and targets into every metric."""
        for metric in self._metrics():
            metric.update(preds=batch_predicted_labels, target=batch_actual_labels)

    def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
        """Compute all metrics, print them via ``timber``, report them via ``log``, then reset.

        Args:
            log: callable ``log(name, value)``, e.g. ``LightningModule.log``.
            log_prefix: prefix for every metric name (e.g. "train").
            log_color: ANSI colour prepended to the console message.
        """
        b_accuracy, b_auc, b_f1_score, b_precision, b_recall = [
            metric.compute() for metric in self._metrics()
        ]
        timber.info(
            log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
        named_values = (
            ("accuracy", b_accuracy),
            ("auc", b_auc),
            ("f1_score", b_f1_score),
            ("precision", b_precision),
            ("recall", b_recall),
        )
        for suffix, value in named_values:
            log(f"{log_prefix}_{suffix}", value)
        for metric in self._metrics():
            metric.reset()
def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
    """Overwrite a uniformly random window of ``seq`` with ``DEBUG_MOTIF``.

    The result has the same length as ``seq``. Every start position from 0 up
    to and including ``len(seq) - len(DEBUG_MOTIF)`` is possible, so the motif
    may also end exactly at the end of ``seq``.

    Raises:
        ValueError: if the motif is longer than the sequence.
    """
    motif_len = len(DEBUG_MOTIF)
    last_start = len(seq) - motif_len
    if last_start < 0:
        raise ValueError(
            f"Motif of length {motif_len} does not fit in a sequence of length {len(seq)}")
    # randint is inclusive on both ends (randrange(0, last_start) skipped the last slot).
    rand_pos = random.randint(0, last_start)
    return seq[:rand_pos] + DEBUG_MOTIF + seq[rand_pos + motif_len:]
class PagingMQTLDataset(IterableDataset):
    """Streams ``(tokenized_sequence, label)`` pairs from an iterable of rows.

    Each row must provide ``'sequence'`` and ``'label'`` keys. Rows whose
    sequence length differs from ``seq_len`` are silently skipped.
    """

    def __init__(self,
                 m_dataset,
                 seq_len,
                 tokenizer,
                 max_length=512,
                 check_if_pipeline_is_ok_by_inserting_debug_motif=False,
                 kmer=None):
        """
        Args:
            m_dataset: iterable of row mappings (e.g. a streaming ``datasets`` split).
            seq_len: required sequence length; rows of any other length are dropped.
            tokenizer: Hugging Face tokenizer applied to each sequence.
            max_length: pad/truncate length, in tokens.
            check_if_pipeline_is_ok_by_inserting_debug_motif: if True, plant
                ``self.debug_motif`` at a random position in every sequence with label 1.
            kmer: if set (e.g. 6), rewrite the sequence as space-separated overlapping
                k-mers before tokenizing. Default None keeps the raw sequence.
                NOTE(review): DNABERT-6's vocabulary consists of 6-mers, so a raw
                sequence probably tokenizes mostly to [UNK]; consider kmer=6 — confirm.
        """
        self.dataset = m_dataset
        self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
        self.debug_motif = "ATCGCCTA"
        self.seq_len = seq_len
        self.bert_tokenizer = tokenizer
        self.max_length = max_length
        self.kmer = kmer

    def __iter__(self):
        for row in self.dataset:
            processed = self.preprocess(row)
            if processed is not None:
                yield processed

    @staticmethod
    def _to_kmers(sequence, k):
        """Return the overlapping k-mers of ``sequence`` joined by single spaces."""
        return " ".join(sequence[i:i + k] for i in range(len(sequence) - k + 1))

    def preprocess(self, row):
        """Tokenize one row; return ``(dict_of_1d_tensors, label)`` or None if the row is skipped."""
        sequence = row['sequence']
        if len(sequence) != self.seq_len:
            return None  # skip problematic row!
        label = row['label']
        if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
            sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
        if self.kmer:
            sequence = self._to_kmers(sequence, self.kmer)
        encoded_sequence: BatchEncoding = self.bert_tokenizer(
            sequence,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        # return_tensors='pt' on a single string adds a leading batch dim of 1.
        # Drop only that dim (a bare squeeze() would also drop others of size 1),
        # so the DataLoader can re-batch the rows.
        encoded_sequence_squeezed = {key: val.squeeze(0) for key, val in encoded_sequence.items()}
        return encoded_sequence_squeezed, label
class MqtlDataModule(LightningDataModule):
    """Lightning data module around pre-built train / validation / test datasets.

    The three DataLoaders are created once, eagerly, in ``__init__``: no
    shuffling, one worker each. The ``*_dataloader`` hooks return them
    unchanged.
    """

    def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
        super().__init__()
        self.batch_size = batch_size
        self.train_loader = self._make_loader(train_ds)
        self.validate_loader = self._make_loader(val_ds)
        self.test_loader = self._make_loader(test_ds)

    def _make_loader(self, dataset):
        """Build an unshuffled, single-worker DataLoader that uses ``self.batch_size``."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=1)

    def prepare_data(self):
        """Nothing to download or prepare: the datasets are passed in ready-made."""

    def setup(self, stage: str) -> None:
        """Only logs the stage; the loaders already exist."""
        timber.info(f"inside setup: {stage = }")

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return self.train_loader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return self.validate_loader

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return self.test_loader
class MQtlBertClassifierLightningModule(LightningModule):
    """Lightning wrapper that trains and evaluates a binary mQTL sequence classifier.

    Args:
        classifier: module called as ``classifier.forward(input_ids, attention_mask, token_type_ids)``.
            It should return one logit per example; shape presumably (batch, 1) — TODO confirm.
        criterion: loss callable ``criterion(logits, labels)``. It must be set before
            training/testing, despite the None default.
        regularization: 1 = L1 penalty added to the training loss, 2 = L2 via Adam
            weight decay, 3 = both, anything else = none.
        l1_lambda: coefficient of the L1 penalty.
        l2_wright_decay: Adam weight decay. The name (sic) is kept for backward compatibility.
        learning_rate: keyword-only Adam learning rate. The default 1e-3 equals the
            previously hard-coded value.
    """

    def __init__(self,
                 classifier: nn.Module,
                 criterion=None,  # e.g. nn.BCEWithLogitsLoss()
                 regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
                 l1_lambda=0.001,
                 l2_wright_decay=0.001,
                 *args: Any,
                 learning_rate: float = 1e-3,
                 **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.classifier = classifier
        self.criterion = criterion
        self.train_metrics = TorchMetrics()
        self.validate_metrics = TorchMetrics()
        self.test_metrics = TorchMetrics()
        self.regularization = regularization
        self.l1_lambda = l1_lambda
        self.l2_weight_decay = l2_wright_decay
        self.learning_rate = learning_rate

    def forward(self, x, *args: Any, **kwargs: Any) -> Any:
        """Run the classifier on a tokenized batch dict (input_ids / attention_mask / token_type_ids)."""
        input_ids: torch.tensor = x["input_ids"]
        attention_mask: torch.tensor = x["attention_mask"]
        token_type_ids: torch.tensor = x["token_type_ids"]
        return self.classifier.forward(input_ids, attention_mask, token_type_ids)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # L2 regularization is implemented as Adam weight decay (regularization 2 or 3).
        weight_decay = self.l2_weight_decay if self.regularization in (2, 3) else 0.0
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=weight_decay)

    @staticmethod
    def _flatten_logits(preds):
        """Flatten logits to 1-D so they line up with (batch,) labels.

        ``squeeze()`` would turn a final batch of size 1 into a 0-d tensor, which
        the metrics reject as a shape mismatch.
        """
        return preds.reshape(-1)

    def _logits_and_loss(self, batch):
        """Unpack ``(x, y)``, run the model and return ``(logits, labels, loss)``."""
        x, y = batch
        preds = self.forward(x)
        return preds, y, self.criterion(preds, y)

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        preds, y, loss = self._logits_and_loss(batch)
        if self.regularization in (1, 3):  # apply l1 regularization
            l1_norm = sum(p.abs().sum() for p in self.parameters())
            # Out-of-place add: do not mutate the criterion's output tensor in place.
            loss = loss + self.l1_lambda * l1_norm
        self.log("train_loss", loss)
        self.train_metrics.update_on_each_step(batch_predicted_labels=self._flatten_logits(preds),
                                               batch_actual_labels=y)
        return loss

    def on_train_epoch_end(self) -> None:
        self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        preds, y, loss = self._logits_and_loss(batch)
        self.log("valid_loss", loss)
        self.validate_metrics.update_on_each_step(batch_predicted_labels=self._flatten_logits(preds),
                                                  batch_actual_labels=y)
        return loss

    def on_validation_epoch_end(self) -> None:
        self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
        return None

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # Loss and metrics on a test batch
        preds, y, loss = self._logits_and_loss(batch)
        self.log("test_loss", loss)
        self.test_metrics.update_on_each_step(batch_predicted_labels=self._flatten_logits(preds),
                                              batch_actual_labels=y)
        return loss

    def on_test_epoch_end(self) -> None:
        self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
        return None
# Hugging Face Hub id of the pretrained DNABERT (k=6) checkpoint used for both tokenizer and encoder.
DNA_BERT_6: str = "zhihan1996/DNA_bert_6"
class CommonAttentionLayer(nn.Module):
    """Attention pooling over the sequence axis.

    Each position is scored with a linear layer, the scores are softmaxed over
    dim=1, and the softmax-weighted sum of the hidden states is returned.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attention_linear = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: tensor expected to be (batch, seq_len, hidden_size).
            attention_mask: optional (batch, seq_len) tensor. Positions where it is 0
                (e.g. padding) get zero attention weight. None attends to every
                position, as before.

        Returns:
            ``(context_vector, attn_weights)`` of shapes (batch, hidden_size) and
            (batch, seq_len, 1).
        """
        attn_scores = self.attention_linear(hidden_states)
        if attention_mask is not None:
            is_padding = attention_mask.unsqueeze(-1) == 0
            # Use the dtype's most negative finite value rather than -inf, so a fully
            # masked row yields uniform weights instead of NaN.
            attn_scores = attn_scores.masked_fill(is_padding, torch.finfo(attn_scores.dtype).min)
        attn_weights = torch.softmax(attn_scores, dim=1)
        context_vector = torch.sum(attn_weights * hidden_states, dim=1)
        return context_vector, attn_weights


class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    """BCEWithLogitsLoss that reshapes the logits to the target's shape and casts targets to float.

    This lets a model emit (batch, 1) logits against integer (batch,) labels.
    ``reshape`` is used instead of ``squeeze()``, so a final batch of size 1
    keeps shape (1,) instead of collapsing to a scalar, which
    BCEWithLogitsLoss rejects.
    """

    def forward(self, input, target):
        return super().forward(input.reshape(target.shape), target.float())


class DnaBert6MQTLClassifier(nn.Module, PyTorchModelHubMixin):
    """DNABERT-6 encoder followed by masked attention pooling and a linear head.

    Args:
        seq_len: DNA window length (stored as an attribute; not used in forward).
        model_repository_name: name used by the training script for the local
            save directory / Hub repo.
        bert_model: encoder to use. If None (default), ``DNA_BERT_6`` is loaded with
            ``BertModel.from_pretrained`` when this classifier is constructed. This
            avoids loading the model at import time and sharing one encoder
            between all instances.
        hidden_size: encoder hidden width (768 for BERT-base).
        num_classes: number of output logits (1 for binary BCE-with-logits).
    """

    def __init__(self,
                 seq_len: int, model_repository_name: str,
                 bert_model=None,
                 hidden_size=768,
                 num_classes=1,
                 *args,
                 **kwargs
                 ):
        super().__init__(*args, **kwargs)
        self.seq_len = seq_len
        self.model_repository_name = model_repository_name
        self.model_name = "MQtlDnaBERT6Classifier"
        if bert_model is None:
            bert_model = BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
        self.bert_model = bert_model
        self.attention = CommonAttentionLayer(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids):
        """Encode the tokens, pool over the non-padding positions and return ``num_classes`` logits per example."""
        bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        # Pass the mask so padded positions do not receive attention weight.
        context_vector, _attention_weights = self.attention(bert_output.last_hidden_state, attention_mask)
        return self.classifier(context_vector)
# Directory that holds the binned CSV splits on the author's development machine.
_LOCAL_INPUT_DIR = "/home/soumic/Codes/mqtl-classification/src/inputdata"


def _local_split_files(window):
    """Map split names (``train_binned_<window>``, ...) to the local CSV paths for ``window``."""
    return {
        f"{split}_binned_{window}": f"{_LOCAL_INPUT_DIR}/dataset_{window}_{split}_binned.csv"
        for split in ("train", "validate", "test")
    }


def _load_dataset_map(window, is_my_laptop):
    """Open streaming splits for ``window``.

    Local CSVs are used on the author's laptop when all three exist; otherwise
    the Hub dataset is streamed.
    """
    local_files = _local_split_files(window)
    if is_my_laptop and all(os.path.isfile(path) for path in local_files.values()):
        return load_dataset("csv", data_files=local_files, streaming=True)
    if is_my_laptop:
        timber.warning(f"Local CSV splits for window {window} not found; "
                       f"streaming from the Hugging Face Hub instead.")
    return load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)


def _save_and_push(classifier_model, window, is_my_laptop):
    """Save the model under ``model_repository_name`` locally, then push it to ``fahimfarhan/<name>``."""
    model_subdirectory = classifier_model.model_repository_name
    classifier_model.save_pretrained(model_subdirectory)
    origin = "zephyrus" if is_my_laptop else "huggingface space"
    commit_message = f":tada: Push model for window size {window} from {origin}"
    classifier_model.push_to_hub(
        repo_id=f"fahimfarhan/{classifier_model.model_repository_name}",
        commit_message=commit_message
    )


def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW=200,
               is_binned=True, is_debug=False, max_epochs=10, batch_size=8):
    """Train, test, save and push ``classifier_model`` on the binned mQTL splits for ``WINDOW``.

    Args:
        classifier_model: PyTorchModelHubMixin model exposing ``model_repository_name``.
            A checkpoint is loaded from that name first when available (best effort).
        criterion: loss passed to the Lightning module.
        m_optimizer: unused; the optimizer is built in ``configure_optimizers``.
            Kept for compatibility.
        WINDOW: sequence length, also the suffix of the split names.
        is_binned: unused (only binned splits exist). Kept for compatibility.
        is_debug: plant a debug motif in positive sequences to sanity-check the pipeline.
        max_epochs: number of training epochs for the Lightning Trainer.
        batch_size: DataLoader batch size for all three splits.
    """
    # Presence of this file identifies the author's laptop. It selects the data
    # source and the commit message.
    is_my_laptop = os.path.isfile(f"{_LOCAL_INPUT_DIR}/dataset_4000_test_binned.csv")
    dataset_map = _load_dataset_map(WINDOW, is_my_laptop)

    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
    split_datasets = {
        split: PagingMQTLDataset(dataset_map[f"{split}_binned_{WINDOW}"],
                                 check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
                                 tokenizer=tokenizer,
                                 seq_len=WINDOW)
        for split in ("train", "validate", "test")
    }
    data_module = MqtlDataModule(train_ds=split_datasets["train"],
                                 val_ds=split_datasets["validate"],
                                 test_ds=split_datasets["test"],
                                 batch_size=batch_size)

    try:
        classifier_model = classifier_model.from_pretrained(classifier_model.model_repository_name)
    except Exception as x:
        # Best effort: without a previous checkpoint, train from the given weights.
        timber.warning(f"Could not load a previous checkpoint, training from scratch: {x}")

    classifier_module = MQtlBertClassifierLightningModule(
        classifier=classifier_model,
        regularization=2, criterion=criterion)

    trainer = Trainer(max_epochs=max_epochs, precision="32")
    trainer.fit(model=classifier_module, datamodule=data_module)
    timber.info("\n\n")
    trainer.test(model=classifier_module, datamodule=data_module)
    timber.info("\n\n")

    _save_and_push(classifier_model, WINDOW, is_my_laptop)
if __name__ == '__main__':
    # Script entry point: log in, build the DNABERT-6 classifier for 1000-bp windows, then train and publish it.
    login_inside_huggingface_virtualmachine()
    WINDOW = 1000
    some_model = DnaBert6MQTLClassifier(seq_len=WINDOW, model_repository_name="dnabert-6-mqtl-classifier")
    criterion = ReshapedBCEWithLogitsLoss()
    training_settings = dict(is_debug=False, max_epochs=20, batch_size=16)
    start_bert(classifier_model=some_model, criterion=criterion, WINDOW=WINDOW, **training_settings)