Spaces:
Runtime error
Runtime error
import os | |
import random | |
import huggingface_hub | |
import numpy as np | |
from datasets import load_dataset, Dataset | |
from dotenv import load_dotenv | |
from pytorch_lightning import LightningDataModule | |
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS | |
from torch.utils.data import DataLoader, IterableDataset | |
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score | |
# from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, BertModel | |
from transformers import TrainingArguments, Trainer | |
import torch | |
import logging | |
import wandb | |
# Script-wide logger: the root logger, configured at INFO level.
logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs
timber = logging.getLogger()

# ANSI escape sequences for coloured terminal text.
black, red, green, yellow = "\u001b[30m", "\u001b[31m", "\u001b[32m", "\u001b[33m"
blue, magenta, cyan, white = "\u001b[34m", "\u001b[35m", "\u001b[36m", "\u001b[37m"

FORWARD, BACKWARD = "FORWARD_INPUT", "BACKWARD_INPUT"

# Prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

PRETRAINED_MODEL_NAME: str = "zhihan1996/DNA_bert_6"
def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF, rng=None):
    """Return ``seq`` with ``DEBUG_MOTIF`` written over a randomly chosen window.

    The motif overwrites ``len(DEBUG_MOTIF)`` characters, so the result has the
    same length as ``seq``. Every start offset from 0 through
    ``len(seq) - len(DEBUG_MOTIF)`` (inclusive) can be chosen, including the one
    that places the motif flush with the end of the sequence.

    Args:
        seq: sequence to modify (sliced and concatenated with the motif).
        DEBUG_MOTIF: motif to write; must not be longer than ``seq``.
        rng: optional ``random.Random`` instance for reproducible placement;
            the module-level ``random`` generator is used when omitted.

    Returns:
        The modified sequence.

    Raises:
        ValueError: if ``DEBUG_MOTIF`` is longer than ``seq``.
    """
    motif_len = len(DEBUG_MOTIF)
    last_start = len(seq) - motif_len
    if last_start < 0:
        raise ValueError(
            f"DEBUG_MOTIF (length {motif_len}) is longer than seq (length {len(seq)})"
        )
    chooser = random if rng is None else rng
    # randint's upper bound is inclusive, so the final window is reachable too.
    pos = chooser.randint(0, last_start)
    output = seq[:pos] + DEBUG_MOTIF + seq[pos + motif_len:]
    assert len(output) == len(seq)  # invariant: overwrite, never insert
    return output
class PagingMQTLDataset(IterableDataset):
    """Stream ``{"input_ids", "labels"}`` examples from an iterable of rows.

    Each row must provide ``"sequence"`` and ``"label"`` keys. Rows whose
    sequence length differs from ``seq_len`` are skipped.

    Args:
        m_dataset: iterable of row mappings (e.g. a streaming ``datasets`` split).
        seq_len: required sequence length; rows with any other length are dropped.
        tokenizer: callable returning a mapping with an ``"input_ids"`` entry; it is
            called with ``max_length=`` and ``truncation=True`` keyword arguments.
        max_length: maximum number of token ids kept per example.
        check_if_pipeline_is_ok_by_inserting_debug_motif: when true, the fixed
            ``debug_motif`` is written into every sequence whose label is 1, as a
            sanity check that the pipeline can learn an obvious signal.
    """

    def __init__(self,
                 m_dataset,
                 seq_len,
                 tokenizer,
                 max_length=512,
                 check_if_pipeline_is_ok_by_inserting_debug_motif=False):
        self.dataset = m_dataset
        self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
        self.debug_motif = "ATCGCCTA"
        self.seq_len = seq_len
        self.bert_tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        """Yield preprocessed examples, silently dropping rows that were skipped."""
        for row in self.dataset:
            processed = self.preprocess(row)
            if processed is not None:
                yield processed

    def preprocess(self, row):
        """Convert one row into tensors, or return ``None`` if the row is skipped.

        Returns:
            ``{"input_ids": LongTensor, "labels": tensor(label)}`` or ``None`` when
            ``len(row["sequence"]) != self.seq_len``.
        """
        sequence = row['sequence']
        if len(sequence) != self.seq_len:
            return None  # skip problematic row!
        label = row['label']
        if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
            sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
        # Previously max_length was stored but never used; enforce it here so an
        # over-long tokenization cannot exceed the model's input size.
        # NOTE(review): DNABERT-6 uses a 6-mer vocabulary — confirm whether the raw
        # nucleotide string should first be split into overlapping, space-separated
        # 6-mers before tokenizing.
        encoded = self.bert_tokenizer(sequence, max_length=self.max_length, truncation=True)
        return {
            "input_ids": torch.tensor(encoded["input_ids"]),
            "labels": torch.tensor(label),
        }
class MqtlDataModule(LightningDataModule):
    """Lightning data module that serves three fixed train/validation/test loaders.

    The ``DataLoader`` objects are built once, eagerly, in ``__init__``; each
    ``*_dataloader`` hook returns that same instance. Every loader uses the same
    batch size, ``shuffle=False`` and ``num_workers=1``.
    """

    def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
        super().__init__()
        self.batch_size = batch_size

        def make_loader(dataset):
            # One place for the settings shared by every split
            # (collate_fn / persistent_workers stay at DataLoader defaults).
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=1)

        built = [make_loader(ds) for ds in (train_ds, val_ds, test_ds)]
        self.train_loader, self.validate_loader, self.test_loader = built

    def prepare_data(self):
        """No-op: the datasets arrive fully constructed."""

    def setup(self, stage: str) -> None:
        """Log which Lightning stage is being set up; nothing else is needed."""
        timber.info(f"inside setup: {stage = }")

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Return the training loader created in ``__init__``."""
        return self.train_loader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Return the validation loader created in ``__init__``."""
        return self.validate_loader

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return the test loader created in ``__init__``."""
        return self.test_loader
def create_paging_train_val_test_datasets(tokenizer, WINDOW, is_debug, batch_size=1000):
    """Build streaming train/validation/test ``PagingMQTLDataset`` objects.

    The local CSV files are used when the 4000-bp training CSV exists at its
    hard-coded laptop path. Otherwise the ``fahimfarhan/mqtl-classification-datasets``
    Hub dataset is streamed.

    Args:
        tokenizer: passed through to each ``PagingMQTLDataset``.
        WINDOW: sequence length; selects the ``*_binned_{WINDOW}`` splits and is
            the required ``seq_len`` of every row.
        is_debug: forwarded as ``check_if_pipeline_is_ok_by_inserting_debug_motif``.
        batch_size: currently unused; kept for call-site compatibility.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``.
    """
    local_csv_files = {
        # small samples
        "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
        "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
        "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
        # medium samples
        "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
        "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
        "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
        # large samples
        "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
        "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
        "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
    }

    if os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv"):
        streamed_splits = load_dataset("csv", data_files=local_csv_files, streaming=True)
    else:
        streamed_splits = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)

    def wrap_split(split_prefix):
        # Same construction for every split; only the split name differs.
        return PagingMQTLDataset(
            streamed_splits[f"{split_prefix}_binned_{WINDOW}"],
            check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
            tokenizer=tokenizer,
            seq_len=WINDOW,
        )

    return tuple(wrap_split(prefix) for prefix in ("train", "validate", "test"))
def login_inside_huggingface_virtualmachine():
    """Best-effort login to the Hugging Face Hub and to Weights & Biases.

    Loads a ``.env`` file if one is available, then reads ``HF_TOKEN`` and
    ``WAND_DB_API_KEY`` from the environment. Every failure (missing variable,
    rejected credentials, ...) is printed rather than raised, so the caller
    continues even without a login.
    """
    # Load the .env file, but don't crash if it's not found (e.g., in a Hugging Face Space).
    try:
        load_dotenv()  # only useful on the laptop, where a .env file exists
        print(".env file loaded successfully.")
    except Exception as e:
        print(f"Warning: Could not load .env file. Exception: {e}")

    # Hugging Face Hub login.
    try:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
        huggingface_hub.login(token)
        print("Logged in to Hugging Face Hub successfully.")
    except Exception as e:
        print(f"Error during Hugging Face login: {e}")

    # Weights & Biases login.
    try:
        api_key = os.getenv("WAND_DB_API_KEY")
        # Never write the secret itself to the log; only record whether it was provided.
        timber.info(f"WAND_DB_API_KEY is set: {bool(api_key)}")
        if not api_key:
            raise ValueError("WAND_DB_API_KEY not found. Make sure to set it in the environment variables or .env file.")
        wandb.login(key=api_key)
        print("Logged in to wand db successfully.")
    except Exception as e:
        print(f"Error during wand db Face login: {e}")
# Use sklearn metrics because torchmetrics.classification raised an index-out-of-bounds error here.
def compute_metrics_using_sklearn(p):
    """Compute binary-classification metrics for ``transformers.Trainer``.

    Args:
        p: ``(predictions, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            ``predictions`` holds one row of two class scores (logits) per example,
            or is a tuple whose first element is that array.

    Returns:
        dict with ``accuracy``, ``roc_auc``, ``precision``, ``recall`` and ``f1``.
        If the metrics cannot be computed at all, every value is 0 and the error
        is printed. If only ROC AUC is undefined (e.g. a single class present in
        ``labels``), ``roc_auc`` is 0 while the other metrics are still reported.
    """
    try:
        pred, labels = p
        if isinstance(pred, tuple):
            # Models that return extra outputs yield a tuple; logits come first.
            pred = pred[0]
        logits = np.asarray(pred, dtype=np.float64)
        labels = np.asarray(labels)

        # Predicted class labels.
        pred_labels = np.argmax(logits, axis=1)

        # Softmax probability of the positive class. The raw class-1 logit alone is
        # not a valid ranking score: P(class 1) depends on both logits.
        shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
        exp_scores = np.exp(shifted)
        pred_probs = exp_scores[:, 1] / exp_scores.sum(axis=1)

        accuracy = accuracy_score(y_true=labels, y_pred=pred_labels)
        recall = recall_score(y_true=labels, y_pred=pred_labels, zero_division=0)
        precision = precision_score(y_true=labels, y_pred=pred_labels, zero_division=0)
        f1 = f1_score(y_true=labels, y_pred=pred_labels, zero_division=0)
    except Exception as x:
        print(f"compute_metrics_using_sklearn failed with exception: {x}")
        return {"accuracy": 0, "roc_auc": 0, "precision": 0, "recall": 0, "f1": 0}

    try:
        roc_auc = roc_auc_score(y_true=labels, y_score=pred_probs)
    except ValueError as x:
        # Raised e.g. when only one class is present; keep the other metrics.
        print(f"compute_metrics_using_sklearn: roc_auc undefined: {x}")
        roc_auc = 0

    return {"accuracy": accuracy, "roc_auc": roc_auc, "precision": precision, "recall": recall, "f1": f1}
def start():
    """Fine-tune DNABERT-6 on the mQTL data, test it, then save and upload the model.

    Pipeline:
    1. Log in to the Hugging Face Hub and W&B.
    2. Load the tokenizer and a 2-label sequence-classification model from
       ``PRETRAINED_MODEL_NAME``.
    3. Build streaming train/validation/test datasets for ``WINDOW``-bp sequences.
    4. Train with per-epoch evaluation on the validation split, then evaluate on
       the test split.
    5. Save the model locally and push it to
       ``fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}``.

    Step 5 runs in a ``finally`` block, so it also happens when test evaluation fails.
    """
    import inspect  # only used for the TrainingArguments compatibility check below

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    login_inside_huggingface_virtualmachine()

    WINDOW = 4000
    model_local_directory = f"my-awesome-model-{WINDOW}"
    model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
    # The local CSV only exists on the author's laptop; used to pick the commit message.
    is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")

    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
    classifier_model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, num_labels=2)

    # `evaluation_strategy` was renamed to `eval_strategy` in transformers 4.41 and the
    # old keyword was later removed (passing it then raises TypeError). Use whichever
    # keyword the installed TrainingArguments accepts.
    accepted_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_key = "eval_strategy" if "eval_strategy" in accepted_params else "evaluation_strategy"

    args = {
        "output_dir": "output_dnabert-6-mqtl_classification",
        "num_train_epochs": 1,
        # max_steps overrides num_train_epochs; data is train 36k + val 4k = 40k rows.
        "max_steps": 20_000,
        "run_name": "laptop_run_dna-bert-6-mqtl_classification",  # Override run_name here
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 32,
        "gradient_checkpointing": True,
        # NOTE(review): 1e-3 is far above rates commonly used to fine-tune BERT-style models — confirm intended.
        "learning_rate": 1e-3,
        # Per the author, disabling safetensors serialization avoided a runtime error.
        "save_safetensors": False,
        # Per-epoch evaluation runs the sklearn metrics; drop these two entries if that slows training too much.
        eval_strategy_key: "epoch",
        "logging_strategy": "epoch",
    }
    training_args = TrainingArguments(**args)

    train_ds, val_ds, test_ds = create_paging_train_val_test_datasets(tokenizer, WINDOW=WINDOW, is_debug=False)

    trainer = Trainer(
        model=classifier_model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics_using_sklearn,
    )

    # train, and validate
    result = trainer.train()
    print(f"{result = }")

    # testing
    try:
        test_results = trainer.evaluate(eval_dataset=test_ds)
        print(f"{test_results = }")
    except Exception as oome:
        # Best-effort: a failure here (e.g. out of memory) must not prevent saving the model.
        print(f"{oome = }")
    finally:
        # save the model
        classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)
        # push to the hub
        commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
        if is_my_laptop:
            commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
        classifier_model.push_to_hub(
            repo_id=model_remote_repository,
            commit_message=commit_message,
            safe_serialization=False,
        )
def interprete_demo():
    """Load the fine-tuned mQTL classifier from the Hub as a starting point for interpretation.

    Returns:
        The loaded sequence-classification model, or ``None`` if loading failed
        (the exception is printed).
    """
    WINDOW = 4000
    model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
    try:
        # start() trains and pushes an AutoModelForSequenceClassification; loading it
        # with plain AutoModel would drop the classification head being interpreted.
        classifier_model = AutoModelForSequenceClassification.from_pretrained(model_remote_repository)
    except Exception as x:
        print(x)
        return None
    # todo: use captum / gentech-grelu to interpret the model
    return classifier_model
def _main():
    """Entry point when the file is executed as a script: run the training pipeline."""
    start()


if __name__ == '__main__':
    _main()