# Soumic
# :rocket: Finetuned for 18hrs on my laptop
# a01b289
import os
import random
import huggingface_hub
import numpy as np
from datasets import load_dataset, Dataset
from dotenv import load_dotenv
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader, IterableDataset
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
# from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel
from transformers import TrainingArguments, Trainer
import torch
import logging
import wandb
# Module-wide logger (the root logger, since no name is passed to getLogger).
timber = logging.getLogger()
# logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO) # switch to level=logging.DEBUG for more verbose logs
# ANSI escape sequences that select a terminal foreground colour.
black = "\u001b[30m"
red = "\u001b[31m"
green = "\u001b[32m"
yellow = "\u001b[33m"
blue = "\u001b[34m"
magenta = "\u001b[35m"
cyan = "\u001b[36m"
white = "\u001b[37m"
# String tags; not referenced by any code visible in this file.
# NOTE(review): presumably used by other modules for strand/direction handling -- confirm.
FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"
# CUDA device when torch reports one is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hugging Face model id of the pretrained checkpoint loaded by start().
PRETRAINED_MODEL_NAME: str = "LongSafari/hyenadna-small-32k-seqlen-hf"
def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
    """Overwrite a random window of ``seq`` with ``DEBUG_MOTIF``.

    The motif replaces ``len(DEBUG_MOTIF)`` characters of ``seq`` (it is not
    inserted in between them), so the result always has the same length as
    ``seq``.

    Args:
        seq: The sequence to modify.
        DEBUG_MOTIF: The motif to write into the sequence.

    Returns:
        A copy of ``seq`` with one window replaced by ``DEBUG_MOTIF``.

    Raises:
        ValueError: If ``DEBUG_MOTIF`` is longer than ``seq``.
    """
    motif_len = len(DEBUG_MOTIF)
    last_start = len(seq) - motif_len
    if last_start < 0:
        raise ValueError(
            f"DEBUG_MOTIF (length {motif_len}) does not fit into seq (length {len(seq)})"
        )
    # randrange excludes its upper bound, so add 1 to make the last valid
    # start position reachable (and to support len(seq) == len(DEBUG_MOTIF)).
    rand_pos = random.randrange(last_start + 1)
    return seq[:rand_pos] + DEBUG_MOTIF + seq[rand_pos + motif_len:]
class PagingMQTLDataset(IterableDataset):
    """Iterable dataset that turns raw rows into tokenised training examples.

    Each row of the wrapped dataset must provide a ``'sequence'`` and a
    ``'label'`` entry. Rows whose sequence length differs from ``seq_len``
    are dropped. Every remaining row becomes a dict with ``"input_ids"`` and
    ``"labels"`` tensors.
    """

    def __init__(self,
                 m_dataset,
                 seq_len,
                 tokenizer,
                 max_length=512,
                 check_if_pipeline_is_ok_by_inserting_debug_motif=False):
        """Store the wrapped dataset, the tokenizer and the filtering options.

        Args:
            m_dataset: Iterable of rows (mappings with 'sequence' and 'label').
            seq_len: Exact sequence length a row must have to be kept.
            tokenizer: Callable returning a mapping with an "input_ids" entry.
            max_length: Stored on the instance; not read by this class.
            check_if_pipeline_is_ok_by_inserting_debug_motif: When true, rows
                labelled 1 get a fixed debug motif written into their sequence.
        """
        self.dataset = m_dataset
        self.seq_len = seq_len
        self.bert_tokenizer = tokenizer
        self.max_length = max_length
        self.debug_motif = "ATCGCCTA"
        self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif

    def __iter__(self):
        """Yield the preprocessed form of every row that is not skipped."""
        for example in map(self.preprocess, self.dataset):
            if example is None:
                continue
            yield example

    def preprocess(self, row):
        """Convert one row into model inputs, or return None to skip it."""
        dna = row['sequence']
        if len(dna) != self.seq_len:
            # Rows of unexpected length are skipped.
            return None

        target = row['label']
        if target == 1:
            if self.check_if_pipeline_is_ok_by_inserting_debug_motif:
                dna = insert_debug_motif_at_random_position(seq=dna, DEBUG_MOTIF=self.debug_motif)

        token_ids = self.bert_tokenizer(dna)["input_ids"]
        return {"input_ids": torch.tensor(token_ids), "labels": torch.tensor(target)}
class MqtlDataModule(LightningDataModule):
    """Lightning data module exposing prebuilt train/validation/test loaders.

    The three ``DataLoader`` objects are created once in the constructor and
    handed out unchanged by the ``*_dataloader`` hooks.
    """

    def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
        """Create one loader per split, all with the same settings."""
        super().__init__()
        self.batch_size = batch_size

        def make_loader(split_ds):
            # Same configuration for every split: no shuffling, one worker.
            return DataLoader(split_ds, batch_size=self.batch_size, shuffle=False, num_workers=1)

        self.train_loader = make_loader(train_ds)
        self.validate_loader = make_loader(val_ds)
        self.test_loader = make_loader(test_ds)

    def prepare_data(self):
        """Nothing to prepare."""

    def setup(self, stage: str) -> None:
        """Log the stage; the loaders already exist at this point."""
        timber.info(f"inside setup: {stage = }")

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Return the training loader built in the constructor."""
        return self.train_loader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Return the validation loader built in the constructor."""
        return self.validate_loader

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return the test loader built in the constructor."""
        return self.test_loader
def create_paging_train_val_test_datasets(tokenizer, WINDOW, is_debug, batch_size=1000):
    """Build streaming train/validation/test datasets for one window size.

    The local CSV files are used when the 4000-window training CSV exists on
    disk. Otherwise the splits are streamed from the
    ``fahimfarhan/mqtl-classification-datasets`` repository.

    Args:
        tokenizer: Tokenizer handed to every ``PagingMQTLDataset``.
        WINDOW: Sequence length; selects the ``*_binned_{WINDOW}`` splits and
            is used as the required row length.
        is_debug: Forwarded as the debug-motif switch of each dataset.
        batch_size: Accepted but not used by this function.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    data_files = {
        # small samples
        "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
        "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
        "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
        # medium samples
        "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
        "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
        "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
        # large samples
        "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
        "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
        "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
    }

    # The presence of this one file decides between local CSVs and the hub.
    if os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv"):
        dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
    else:
        dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)

    split_keys = (
        f"train_binned_{WINDOW}",
        f"validate_binned_{WINDOW}",
        f"test_binned_{WINDOW}",
    )
    train_dataset, val_dataset, test_dataset = (
        PagingMQTLDataset(
            dataset_map[split_key],
            check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
            tokenizer=tokenizer,
            seq_len=WINDOW,
        )
        for split_key in split_keys
    )
    return train_dataset, val_dataset, test_dataset
def login_inside_huggingface_virtualmachine():
    """Log in to the Hugging Face Hub and to Weights & Biases, best effort.

    Credentials are read from the ``HF_TOKEN`` and ``WAND_DB_API_KEY``
    environment variables, after an optional ``.env`` file has been loaded.
    Each step catches its own errors and only prints them, so a failed login
    never aborts the caller.
    """
    # Load the .env file, but don't crash if it's not found (e.g. in a Hugging Face Space).
    try:
        # load_dotenv() reports through its return value whether anything was loaded.
        if load_dotenv():
            print(".env file loaded successfully.")
        else:
            print("No variables were loaded from a .env file; using the existing environment.")
    except Exception as e:
        print(f"Warning: Could not load .env file. Exception: {e}")

    # Hugging Face Hub login.
    try:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
        huggingface_hub.login(token)
        print("Logged in to Hugging Face Hub successfully.")
    except Exception as e:
        print(f"Error during Hugging Face login: {e}")

    # Weights & Biases login.
    try:
        api_key = os.getenv("WAND_DB_API_KEY")
        # Never write the key itself to the log; only report whether it is present.
        timber.info("WAND_DB_API_KEY is %s", "set" if api_key else "not set")
        if not api_key:
            raise ValueError("WAND_DB_API_KEY not found. Make sure to set it in the environment variables or .env file.")
        wandb.login(key=api_key)
        print("Logged in to wand db successfully.")
    except Exception as e:
        print(f"Error during wand db Face login: {e}")
# scikit-learn is used here because torchmetrics.classification raised an
# index-out-of-bounds error in this pipeline.
def compute_metrics_using_sklearn(p):
    """Compute binary classification metrics for the Hugging Face Trainer.

    Args:
        p: A ``(predictions, labels)`` pair. ``predictions`` is expected to be
            a 2-D array of per-class scores with two columns. If it is a
            tuple, its first element is used.

    Returns:
        Dict with the keys ``accuracy``, ``roc_auc``, ``precision``,
        ``recall`` and ``f1``. If the metrics cannot be computed at all,
        every value is 0. If only ROC AUC fails (for example when the labels
        contain a single class), ``roc_auc`` is 0 and the other metrics are
        still reported.
    """
    try:
        pred, labels = p
        if isinstance(pred, tuple):
            # NOTE(review): assumes the class scores come first when the
            # model returns several outputs -- confirm for this model.
            pred = pred[0]
        pred = np.asarray(pred)

        # Predicted class = column with the highest score.
        pred_labels = np.argmax(pred, axis=1)
        # Score of the positive class (assumes binary classification, 2 output columns).
        pred_probs = pred[:, 1]

        accuracy = accuracy_score(y_true=labels, y_pred=pred_labels)
        recall = recall_score(y_true=labels, y_pred=pred_labels)
        precision = precision_score(y_true=labels, y_pred=pred_labels)
        f1 = f1_score(y_true=labels, y_pred=pred_labels)
        try:
            roc_auc = roc_auc_score(y_true=labels, y_score=pred_probs)
        except ValueError as auc_error:
            # Do not throw away the other metrics just because AUC is undefined.
            print(f"roc_auc_score could not be computed: {auc_error}")
            roc_auc = 0
        return {"accuracy": accuracy, "roc_auc": roc_auc, "precision": precision, "recall": recall, "f1": f1}
    except Exception as x:
        # Deliberate best effort: a metrics failure must not abort a long training run.
        print(f"compute_metrics_using_sklearn failed with exception: {x}")
        return {"accuracy": 0, "roc_auc": 0, "precision": 0, "recall": 0, "f1": 0}
def start():
    """Fine-tune the pretrained classifier, evaluate it, then save and push it.

    Steps: log in to the external services, load tokenizer and model, build
    the streaming datasets for a window of 4000, train with the Hugging Face
    ``Trainer``, evaluate on the test split, and finally save the model
    locally and push it to the hub. The save/push step runs even when the
    test evaluation fails.
    """
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    login_inside_huggingface_virtualmachine()

    WINDOW = 4000
    batch_size = 100  # not referenced below
    model_local_directory = f"my-awesome-model-{WINDOW}"
    model_remote_repository = f"fahimfarhan/hyenadna-sm-32k-mqtl-classifier-seq-len-{WINDOW}"
    is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")

    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
    classifier_model = AutoModelForSequenceClassification.from_pretrained(
        PRETRAINED_MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    training_args = TrainingArguments(
        output_dir="output_hyena_dna-mqtl_classification",
        num_train_epochs=1,
        # Small dataset: train 18k + val 2k = 20k rows; a run takes about 17h.
        # Lower this (e.g. to 10) to smoke-test the pipeline quickly.
        max_steps=2_0000,
        run_name="laptop_run_hyena_dna-mqtl_classification",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=32,
        gradient_checkpointing=True,
        learning_rate=1e-3,
        # safetensors saving failed at the end of training with:
        #   RuntimeError: The weights trying to be saved contained shared tensors
        #   [{'hyena.backbone.layers.0.mixer.filter_fn.implicit_filter.3.freq', ...}]
        #   that are mismatching the transformers base configuration. Try saving
        #   using `safe_serialization=False` or remove this tensor sharing.
        save_safetensors=False,
        # Computing sklearn metrics may slow training down; drop the next two
        # settings if that becomes a problem.
        evaluation_strategy="epoch",
        logging_strategy="epoch",
    )

    # Smoke-test alternative kept for reference (synthetic in-memory dataset):
    #   max_length = 32_000
    #   sequence = 'ACTG' * int(max_length / 4)
    #   # sequence = 'ACTG' * int(1000)  # seq_len = 4000 works
    #   sequence = [sequence] * 8  # 8 identical samples
    #   tokenized = tokenizer(sequence)["input_ids"]
    #   labels = [0, 1] * 4
    #   run_the_code_ds = Dataset.from_dict({"input_ids": tokenized, "labels": labels})
    #   run_the_code_ds.set_format("pt")
    #   train_ds, val_ds, test_ds = run_the_code_ds, run_the_code_ds, run_the_code_ds
    train_ds, val_ds, test_ds = create_paging_train_val_test_datasets(tokenizer, WINDOW=WINDOW, is_debug=False)

    trainer = Trainer(
        model=classifier_model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics_using_sklearn,
    )

    # Train (validation runs according to the evaluation strategy above).
    result = trainer.train()
    try:
        print(f"{result = }")
    except Exception as x:
        print(f"{x = }")

    # Test, then always save and publish the model.
    try:
        test_results = trainer.evaluate(eval_dataset=test_ds)
        print(f"{test_results = }")
    except Exception as oome:
        print(f"{oome = }")
    finally:
        model_name = "HyenaDnaMQtlClassifier"  # not referenced below
        classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)

        # The commit message records where the run happened.
        commit_message = (
            f":tada: Push model for window size {WINDOW} from zephyrus"
            if is_my_laptop
            else f":tada: Push model for window size {WINDOW} from huggingface space"
        )
        # Passing subfolder=f"my-awesome-model-{WINDOW}" did not work here.
        classifier_model.push_to_hub(
            repo_id=model_remote_repository,
            commit_message=commit_message,
            safe_serialization=False,
        )
def interprete_demo():
    """Load the fine-tuned model from the hub as a starting point for interpretation.

    Currently a stub: the model is only loaded, and any loading error is
    printed instead of raised.
    """
    WINDOW = 4000
    is_my_laptop = True
    batch_size = 100
    model_local_directory = f"my-awesome-model-{WINDOW}"
    model_remote_repository = f"fahimfarhan/hyenadna-sm-32k-mqtl-classifier-seq-len-{WINDOW}"

    try:
        classifier_model = AutoModel.from_pretrained(model_remote_repository)
        # todo: use captum / gentech-grelu to interpret the model
    except Exception as x:
        print(x)
# Script entry point: run the full fine-tuning pipeline.
if __name__ == '__main__':
    start()