|
""" |
|
Geneformer classifier. |
|
|
|
**Input data:** |
|
|
|
Cell state classifier: |
|
| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels |
|
| in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py) |
|
|
|
Gene classifier: |
|
| Dictionary in format {Gene_label: list(genes)} for gene labels |
|
| and single-cell transcriptomes as Geneformer rank value encodings |
|
| in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py) |
|
|
|
**Usage:** |
|
|
|
.. code-block :: python |
|
|
|
>>> from geneformer import Classifier |
|
>>> cc = Classifier(classifier="cell", # example of cell state classifier |
|
... cell_state_dict={"state_key": "disease", "states": "all"}, |
|
... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}, |
|
... training_args=training_args, |
|
... freeze_layers = 2, |
|
... num_crossval_splits = 1, |
|
... forward_batch_size=200, |
|
... nproc=16) |
|
>>> cc.prepare_data(input_data_file="path/to/input_data", |
|
... output_directory="path/to/output_directory", |
|
... output_prefix="output_prefix") |
|
>>> all_metrics = cc.validate(model_directory="path/to/model", |
|
... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset", |
|
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl", |
|
... output_directory="path/to/output_directory", |
|
... output_prefix="output_prefix", |
|
... predict_trainer=True)
|
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]}, |
|
... output_directory="path/to/output_directory", |
|
... output_prefix="output_prefix", |
|
... custom_class_order=["healthy","disease1","disease2"]) |
|
>>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl", |
|
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl", |
|
... title="disease", |
|
... output_directory="path/to/output_directory", |
|
... output_prefix="output_prefix", |
|
... custom_class_order=["healthy","disease1","disease2"]) |
|
""" |
|
|
|
import datetime |
|
import logging |
|
import os |
|
import pickle |
|
import subprocess |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import seaborn as sns |
|
from sklearn.model_selection import StratifiedKFold |
|
from tqdm.auto import tqdm, trange |
|
from transformers import Trainer |
|
from transformers.training_args import TrainingArguments |
|
|
|
from . import DataCollatorForCellClassification, DataCollatorForGeneClassification |
|
from . import classifier_utils as cu |
|
from . import evaluation_utils as eu |
|
from . import perturber_utils as pu |
|
from .tokenizer import TOKEN_DICTIONARY_FILE |
|
|
|
# Apply seaborn's default plotting theme globally. NOTE: this is an
# import-time side effect of this module (it changes matplotlib settings
# for the whole process).
sns.set()

# Module-level logger used by the Classifier methods for warnings/errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
class Classifier: |
|
# Accepted values for each constructor option, checked by validate_options().
# Each set mixes literal values (e.g. "cell", None, 0) and types (e.g. int,
# float, dict); see validate_options() for how values are matched against them.
valid_option_dict = {
    "classifier": {"cell", "gene"},
    "cell_state_dict": {None, dict},
    "gene_class_dict": {None, dict},
    "filter_data": {None, dict},
    "rare_threshold": {int, float},
    "max_ncells": {None, int},
    "max_ncells_per_class": {None, int},
    "training_args": {None, dict},
    "freeze_layers": {int},
    # 0: no split, 1: single train/eval split, 5: 5-fold cross-validation
    "num_crossval_splits": {0, 1, 5},
    "eval_size": {int, float},
    "no_eval": {bool},
    "stratify_splits_col": {None, str},
    "forward_batch_size": {int},
    "nproc": {int},
}
|
|
|
def __init__(
    self,
    classifier=None,
    cell_state_dict=None,
    gene_class_dict=None,
    filter_data=None,
    rare_threshold=0,
    max_ncells=None,
    max_ncells_per_class=None,
    training_args=None,
    freeze_layers=0,
    num_crossval_splits=1,
    eval_size=0.2,
    stratify_splits_col=None,
    no_eval=False,
    forward_batch_size=100,
    nproc=4,
):
    """
    Initialize Geneformer classifier.

    **Parameters:**

    classifier : {"cell", "gene"}
        | Whether to fine-tune a cell state or gene classifier.
    cell_state_dict : None, dict
        | Cell states to fine-tune model to distinguish.
        | Two-item dictionary with keys: state_key and states
        | state_key: key specifying name of column in .dataset that defines the states to model
        | states: list of values in the state_key column that specifies the states to model
        | Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
        | Of note, if using "all", states will be defined after data is filtered.
        | Must have at least 2 states to model.
        | For example: {"state_key": "disease",
        |               "states": ["nf", "hcm", "dcm"]}
        |              or
        |              {"state_key": "disease",
        |               "states": "all"}
    gene_class_dict : None, dict
        | Gene classes to fine-tune model to distinguish.
        | Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
        |                        Gene_label_B: list(geneB1, geneB2, ...)}
        | Gene values should be Ensembl IDs.
        | Genes not present in the token dictionary are dropped (with a warning).
    filter_data : None, dict
        | Default is to fine-tune with all input data.
        | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
        | The dictionary is copied; the caller's object is not modified.
    rare_threshold : int, float
        | Threshold below which rare cell states should be removed.
        | For example, setting to 0.05 will remove cell states representing
        | < 5% of the total cells from the cell state classifier's possible classes.
    max_ncells : None, int
        | Maximum number of cells to use for fine-tuning.
        | Default is to fine-tune with all input data.
    max_ncells_per_class : None, int
        | Maximum number of cells per cell class to use for fine-tuning.
        | Of note, will be applied after max_ncells above.
        | (Only valid for cell classification.)
    training_args : None, dict
        | Training arguments for fine-tuning.
        | If None, defaults will be inferred for 6 layer Geneformer.
        | Otherwise, will use the Hugging Face defaults:
        | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
        | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
    freeze_layers : int
        | Number of layers to freeze from fine-tuning.
        | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
    num_crossval_splits : {0, 1, 5}
        | 0: train on all data without splitting
        | 1: split data into train and eval sets by designated eval_size
        | 5: split data into 5 folds of train and eval sets by designated eval_size
    eval_size : int, float
        | Proportion of data to hold out for evaluation (e.g. 0.2 if intending 80:20 train/eval split)
    stratify_splits_col : None, str
        | Name of column in .dataset to be used for stratified splitting.
        | Proportion of each class in this column will be the same in the splits as in the original dataset.
    no_eval : bool
        | If True, will skip eval step and use all data for training.
        | Otherwise, will perform eval during training.
    forward_batch_size : int
        | Batch size for forward pass (for evaluation, not training).
    nproc : int
        | Number of CPU processes to use.

    **Raises:**

    RuntimeError
        | If an option is invalid, if none of the genes in gene_class_dict are in the
        | token dictionary, or if a gene class has no genes in the token dictionary.
        | (The error is also logged.)
    """

    self.classifier = classifier
    self.cell_state_dict = cell_state_dict
    self.gene_class_dict = gene_class_dict
    # Copy so the state-key filter added below (and value normalization in
    # validate_options) never mutates the caller's dict, e.g. when the same
    # filter dict is reused for several classifiers.
    self.filter_data = dict(filter_data) if isinstance(filter_data, dict) else filter_data
    self.rare_threshold = rare_threshold
    self.max_ncells = max_ncells
    self.max_ncells_per_class = max_ncells_per_class
    self.training_args = training_args
    self.freeze_layers = freeze_layers
    self.num_crossval_splits = num_crossval_splits
    self.eval_size = eval_size
    self.stratify_splits_col = stratify_splits_col
    self.no_eval = no_eval
    self.forward_batch_size = forward_batch_size
    self.nproc = nproc

    if self.training_args is None:
        logger.warning(
            "Hyperparameter tuning is highly recommended for optimal results. "
            "No training_args provided; using default hyperparameters."
        )

    self.validate_options()

    if self.filter_data is None:
        self.filter_data = dict()

    # For cell classifiers with an explicit state list, restrict the data to those states.
    if self.classifier == "cell":
        if self.cell_state_dict["states"] != "all":
            self.filter_data[
                self.cell_state_dict["state_key"]
            ] = self.cell_state_dict["states"]

    # load token dictionary (Ensembl IDs:token)
    with open(TOKEN_DICTIONARY_FILE, "rb") as f:
        self.gene_token_dict = pickle.load(f)

    self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}

    # filter genes for gene classification for those in token dictionary
    if self.classifier == "gene":
        all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
        missing_genes = [
            gene for gene in all_gene_class_values if gene not in self.gene_token_dict
        ]
        if len(missing_genes) == len(all_gene_class_values):
            message = "None of the provided genes to classify are in token dictionary."
            logger.error(message)
            # RuntimeError preserves the exception type the previous bare
            # `raise` produced, but now carries the actual message.
            raise RuntimeError(message)
        elif len(missing_genes) > 0:
            logger.warning(
                f"Genes to classify {missing_genes} are not in token dictionary."
            )
        # Map Ensembl IDs to token IDs, dropping genes without a token so that
        # no None placeholders end up in the classes or in the training targets.
        self.gene_class_dict = {
            k: {self.gene_token_dict[gene] for gene in v if gene in self.gene_token_dict}
            for k, v in self.gene_class_dict.items()
        }
        empty_classes = [k for k, v in self.gene_class_dict.items() if len(v) == 0]
        if len(empty_classes) > 0:
            message = f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
            logger.error(message)
            raise RuntimeError(message)
|
|
|
def validate_options(self):
    """
    Validate the classifier options set in __init__.

    Each option listed in ``valid_option_dict`` is accepted if it is an instance
    of one of the types listed for it (e.g. ``int``, ``str``, ``dict``) or equals
    one of the literal values listed for it (e.g. ``"cell"``, ``None``, ``5``).
    Also normalizes ``filter_data`` values to lists (rebinding ``filter_data`` to
    a new dict, so the caller's dict is not modified) and checks the
    classifier-specific class definitions.

    **Raises**

    RuntimeError
        | If any option is invalid. The error is also logged.
    """

    def _fail(message):
        # Log for users who only watch logs, then raise with the same message
        # (a bare `raise` here would only say "No active exception to reraise").
        logger.error(message)
        raise RuntimeError(message)

    for attr_name, valid_options in self.valid_option_dict.items():
        attr_value = self.__dict__[attr_name]
        # Types listed among the options (int, float, str, dict, bool, ...).
        valid_types = tuple(opt for opt in valid_options if isinstance(opt, type))
        if valid_types and isinstance(attr_value, valid_types):
            continue
        try:
            if attr_value in valid_options:
                continue
        except TypeError:
            # Unhashable values (lists, dicts, sets) cannot equal a literal option.
            pass
        _fail(
            f"Invalid option for {attr_name}. "
            f"Valid options for {attr_name}: {valid_options}"
        )

    if self.filter_data is not None:
        normalized_filter = {}
        for key, value in self.filter_data.items():
            if not isinstance(value, list):
                logger.warning(
                    "Values in filter_data dict must be lists. "
                    f"Changing {key} value to list ([{value}])."
                )
                value = [value]
            normalized_filter[key] = value
        self.filter_data = normalized_filter

    if self.classifier == "cell":
        if not isinstance(self.cell_state_dict, dict):
            _fail(
                "cell_state_dict must be provided for cell classifiers, e.g. "
                '{"state_key": "disease", "states": "all"}'
            )
        if set(self.cell_state_dict.keys()) != {"state_key", "states"}:
            _fail(
                "Invalid keys for cell_state_dict. "
                "The cell_state_dict should have only 2 keys: state_key and states"
            )

        states = self.cell_state_dict["states"]
        if states != "all":
            if not isinstance(states, list):
                _fail("States in cell_state_dict should be list of states to model.")
            if len(states) < 2:
                _fail(
                    "States in cell_state_dict should contain at least 2 states to classify."
                )

    if self.classifier == "gene":
        if not isinstance(self.gene_class_dict, dict):
            _fail("gene_class_dict must be provided for gene classifiers.")
        if len(self.gene_class_dict) < 2:
            _fail("Gene_class_dict should contain at least 2 gene classes to classify.")
|
|
|
def prepare_data(
    self,
    input_data_file,
    output_directory,
    output_prefix,
    split_id_dict=None,
    test_size=0,
    attr_to_split=None,
    attr_to_balance=None,
    max_trials=100,
    pval_threshold=0.1,
):
    """
    Prepare data for cell state or gene classification.

    Writes ``{output_prefix}_id_class_dict.pkl`` plus either
    ``{output_prefix}_labeled.dataset`` (no split) or
    ``{output_prefix}_labeled_train.dataset`` and ``{output_prefix}_labeled_test.dataset``
    (when split_id_dict is given, or for cell classifiers when 0 < test_size < 1).

    **Parameters**

    input_data_file : Path
        | Path to directory containing .dataset input
    output_directory : Path
        | Path to directory where prepared data will be saved
    output_prefix : str
        | Prefix for output file
    split_id_dict : None, dict
        | Dictionary of IDs for train and test splits
        | Three-item dictionary with keys: attr_key, train, test
        | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
        | train: list of IDs in the attr_key column to include in the train split
        | test: list of IDs in the attr_key column to include in the test split
        | For example: {"attr_key": "individual",
        |               "train": ["patient1", "patient2", "patient3", "patient4"],
        |               "test": ["patient5", "patient6"]}
    test_size : None, float
        | Proportion of data to be saved separately and held out for test set
        | (e.g. 0.2 if intending to hold out 20%)
        | 0 or None saves all data as a single labeled dataset.
        | The training set will be further split to train / validation in self.validate
        | Note: only available for CellClassifiers
    attr_to_split : None, str
        | Key for attribute on which to split data while balancing potential confounders
        | e.g. "patient_id" for splitting by patient while balancing other characteristics
        | Note: only available for CellClassifiers
    attr_to_balance : None, list
        | List of attribute keys on which to balance data while splitting on attr_to_split
        | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
        | Note: only available for CellClassifiers
    max_trials : None, int
        | Maximum number of trials of random splitting to try to achieve balanced other attributes
        | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
        | Note: only available for CellClassifiers
    pval_threshold : None, float
        | P-value threshold to use for attribute balancing across splits
        | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance

    **Raises**

    RuntimeError
        | If the input data already has the column reserved for class IDs
        | ("label" for cell classifiers, "labels" for gene classifiers).
    """

    # load and filter dataset
    data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)

    # The label column name is reserved for the numerical class IDs added below.
    reserved_col = "label" if self.classifier == "cell" else "labels"
    if reserved_col in data.features:
        message = f"Column name '{reserved_col}' must be reserved for class IDs. Please rename column."
        logger.error(message)
        raise RuntimeError(message)

    if self.classifier == "cell":
        # remove cell states representing < rare_threshold of cells
        data = cu.remove_rare(
            data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
        )
        # downsample max cells and max per class
        data = cu.downsample_and_shuffle(
            data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
        )
        # rename cell state column to "label"
        data = cu.rename_cols(data, self.cell_state_dict["state_key"])

    # convert classes to numerical labels and save as id_class_dict
    data, id_class_dict = cu.label_classes(
        self.classifier, data, self.gene_class_dict, self.nproc
    )

    # Suffixes are appended by formatting rather than Path.with_suffix, which would
    # replace everything after a "." in output_prefix (e.g. "v1.2" -> "v1.pkl") and
    # make the train and test datasets collide on the same path.
    output_path = Path(output_directory)
    with open(output_path / f"{output_prefix}_id_class_dict.pkl", "wb") as f:
        pickle.dump(id_class_dict, f)

    if split_id_dict is not None:
        attr_key = split_id_dict["attr_key"]
        data_dict = {
            split: pu.filter_by_dict(data, {attr_key: split_id_dict[split]}, self.nproc)
            for split in ("train", "test")
        }
    elif (
        self.classifier == "cell"
        and test_size is not None
        and 0 < test_size < 1
    ):
        data_dict, balance_df = cu.balance_attr_splits(
            data,
            attr_to_split,
            attr_to_balance,
            test_size,
            max_trials,
            pval_threshold,
            self.cell_state_dict["state_key"],
            self.nproc,
        )
        balance_df.to_csv(output_path / f"{output_prefix}_train_test_balance_df.csv")
    else:
        # No held-out test set requested (or gene classifier): save all data.
        data.save_to_disk(output_path / f"{output_prefix}_labeled.dataset")
        return

    for split in ("train", "test"):
        data_dict[split].save_to_disk(
            output_path / f"{output_prefix}_labeled_{split}.dataset"
        )
|
|
|
def train_all_data(
    self,
    model_directory,
    prepared_input_data_file,
    id_class_dict_file,
    output_directory,
    output_prefix,
    save_eval_output=True,
):
    """
    Train cell state or gene classifier using all data.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    prepared_input_data_file : Path
        | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    output_directory : str, Path
        | Path to directory where model and eval data will be saved
        | (a dated subdirectory is created inside it)
    output_prefix : str
        | Prefix for output files
    save_eval_output : bool
        | Currently unused: no evaluation is performed when training on all data.
        | Kept for API compatibility.

    **Output**

    Returns trainer after fine-tuning with all data.
    """

    # load numerical id to class dictionary (id:class)
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)
    class_id_dict = {v: k for k, v in id_class_dict.items()}

    # load previously filtered and prepared data
    data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
    data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

    # Dated run directory. os.path.join/os.fspath accept both str and Path
    # (slicing a Path to check for a trailing "/" raises TypeError), and
    # os.makedirs avoids running a shell command built from user input.
    datestamp = datetime.datetime.now().strftime("%y%m%d")
    output_dir = os.path.join(
        os.fspath(output_directory),
        f"{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}",
    )
    os.makedirs(output_dir, exist_ok=True)

    # get number of classes for classifier
    num_classes = cu.get_num_classes(id_class_dict)

    if self.classifier == "gene":
        # One target per gene token, with the numerical ID of its class as label.
        targets = pu.flatten_list(self.gene_class_dict.values())
        labels = pu.flatten_list(
            [
                [class_id_dict[label]] * len(genes)
                for label, genes in self.gene_class_dict.items()
            ]
        )
        assert len(targets) == len(labels)
        data = cu.prep_gene_classifier_all_data(
            data, targets, labels, self.max_ncells, self.nproc
        )

    trainer = self.train_classifier(
        model_directory, num_classes, data, None, output_dir
    )

    return trainer
|
|
|
def validate(
    self,
    model_directory,
    prepared_input_data_file,
    id_class_dict_file,
    output_directory,
    output_prefix,
    split_id_dict=None,
    attr_to_split=None,
    attr_to_balance=None,
    max_trials=100,
    pval_threshold=0.1,
    save_eval_output=True,
    predict_eval=True,
    predict_trainer=False,
):
    """
    (Cross-)validate cell state or gene classifier.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    prepared_input_data_file : Path
        | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    output_directory : str, Path
        | Path to directory where model and eval data will be saved
        | (a dated subdirectory is created inside it)
    output_prefix : str
        | Prefix for output files
    split_id_dict : None, dict
        | Dictionary of IDs for train and eval splits
        | Three-item dictionary with keys: attr_key, train, eval
        | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
        | train: list of IDs in the attr_key column to include in the train split
        | eval: list of IDs in the attr_key column to include in the eval split
        | For example: {"attr_key": "individual",
        |               "train": ["patient1", "patient2", "patient3", "patient4"],
        |               "eval": ["patient5", "patient6"]}
        | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
    attr_to_split : None, str
        | Key for attribute on which to split data while balancing potential confounders
        | e.g. "patient_id" for splitting by patient while balancing other characteristics
        | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
    attr_to_balance : None, list
        | List of attribute keys on which to balance data while splitting on attr_to_split
        | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
    max_trials : None, int
        | Maximum number of trials of random splitting to try to achieve balanced other attribute
        | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
    pval_threshold : None, float
        | P-value threshold to use for attribute balancing across splits
        | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
    save_eval_output : bool
        | Whether to save cross-fold eval output
        | Saves as pickle file of dictionary of eval metrics
    predict_eval : bool
        | Whether or not to save eval predictions
        | Saves as a pickle file of self.evaluate predictions
    predict_trainer : bool
        | Whether or not to save eval predictions from trainer
        | Saves as a pickle file of trainer predictions

    **Output**

    Returns dictionary with "conf_matrix" (confusion matrix summed over splits),
    "macro_f1" and "acc" (one value per split) and "all_roc_metrics"
    (None unless the classifier has exactly 2 classes).

    **Raises**

    RuntimeError
        | If num_crossval_splits is 0, or if a gene class has fewer than 5 members.
    """

    if self.num_crossval_splits == 0:
        message = "num_crossval_splits must be 1 or 5 to validate."
        logger.error(message)
        raise RuntimeError(message)

    # ensure number of genes in each class is >= 5 if validating model
    if self.classifier == "gene":
        insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
        if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
            message = f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
            logger.error(message)
            raise RuntimeError(message)

    # load numerical id to class dictionary (id:class)
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)
    class_id_dict = {v: k for k, v in id_class_dict.items()}

    # load previously filtered and prepared data
    data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
    data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

    # Dated run directory. os.path.join/os.fspath accept both str and Path
    # (slicing a Path to check for a trailing "/" raises TypeError), and
    # os.makedirs avoids running a shell command built from user input.
    datestamp = datetime.datetime.now().strftime("%y%m%d")
    output_dir = os.path.join(
        os.fspath(output_directory),
        f"{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}",
    )
    os.makedirs(output_dir, exist_ok=True)

    # get number of classes for classifier
    num_classes = cu.get_num_classes(id_class_dict)

    def fit_and_evaluate(train_data, eval_data, ksplit_output_dir):
        # Fine-tune on one split, then score the fine-tuned model on its eval set.
        trainer = self.train_classifier(
            model_directory,
            num_classes,
            train_data,
            eval_data,
            ksplit_output_dir,
            predict_trainer,
        )
        return self.evaluate_model(
            trainer.model,
            num_classes,
            id_class_dict,
            eval_data,
            predict_eval,
            ksplit_output_dir,
            output_prefix,
        )

    results = []
    all_conf_mat = np.zeros((num_classes, num_classes))
    iteration_num = 1
    if self.classifier == "cell":
        for i in trange(self.num_crossval_splits):
            print(
                f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
            )
            ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
            if self.num_crossval_splits == 1:
                # single split: by user-defined IDs, by attribute balance, or randomly
                if split_id_dict is not None:
                    attr_key = split_id_dict["attr_key"]
                    data_dict = {
                        "train": pu.filter_by_dict(
                            data, {attr_key: split_id_dict["train"]}, self.nproc
                        ),
                        "test": pu.filter_by_dict(
                            data, {attr_key: split_id_dict["eval"]}, self.nproc
                        ),
                    }
                elif attr_to_split is not None:
                    data_dict, balance_df = cu.balance_attr_splits(
                        data,
                        attr_to_split,
                        attr_to_balance,
                        self.eval_size,
                        max_trials,
                        pval_threshold,
                        self.cell_state_dict["state_key"],
                        self.nproc,
                    )
                    balance_df.to_csv(
                        os.path.join(
                            output_dir, f"{output_prefix}_train_valid_balance_df.csv"
                        )
                    )
                else:
                    data_dict = data.train_test_split(
                        test_size=self.eval_size,
                        stratify_by_column=self.stratify_splits_col,
                        seed=42,
                    )
                train_data = data_dict["train"]
                eval_data = data_dict["test"]
            else:
                # k-fold: fold i holds out a contiguous block of the shuffled data
                # starting at i * fold_size. Bounds must be ints: range() rejects
                # the float bounds that num_cells * 0.2 would produce.
                num_cells = len(data)
                fold_size = num_cells // self.num_crossval_splits
                num_eval = min(int(self.eval_size * num_cells), fold_size)
                start = i * fold_size
                end = start + num_eval
                eval_data = data.select(list(range(start, end)))
                # Everything outside [start, end) trains; built directly instead of
                # an O(n^2) "not in list" filter.
                train_data = data.select(
                    list(range(start)) + list(range(end, num_cells))
                )
            result = fit_and_evaluate(train_data, eval_data, ksplit_output_dir)
            results.append(result)
            all_conf_mat = all_conf_mat + result["conf_mat"]
            iteration_num = iteration_num + 1

    elif self.classifier == "gene":
        # set up (cross-)validation splits over gene tokens, stratified by class
        targets = pu.flatten_list(self.gene_class_dict.values())
        labels = pu.flatten_list(
            [
                [class_id_dict[label]] * len(genes)
                for label, genes in self.gene_class_dict.items()
            ]
        )
        assert len(targets) == len(labels)
        n_splits = int(1 / self.eval_size)
        skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)

        for train_index, eval_index in tqdm(skf.split(targets, labels)):
            print(
                f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
            )
            ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
            # filter data for examples containing classes for this split
            # subsample to max_ncells and relabel data in column "labels"
            train_data, eval_data = cu.prep_gene_classifier_split(
                data,
                targets,
                labels,
                train_index,
                eval_index,
                self.max_ncells,
                iteration_num,
                self.nproc,
            )
            result = fit_and_evaluate(train_data, eval_data, ksplit_output_dir)
            results.append(result)
            all_conf_mat = all_conf_mat + result["conf_mat"]
            # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
            if iteration_num == self.num_crossval_splits:
                break
            iteration_num = iteration_num + 1

    all_conf_mat_df = pd.DataFrame(
        all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": [result["macro_f1"] for result in results],
        "acc": [result["acc"] for result in results],
    }
    all_roc_metrics = None  # roc metrics not reported for multiclass
    if num_classes == 2:
        mean_fpr = np.linspace(0, 1, 100)
        all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
        all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
        all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
        mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
            all_tpr, all_roc_auc, all_tpr_wt
        )
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": mean_fpr,
            "all_roc_auc": all_roc_auc,
            "roc_auc": roc_auc,
            "roc_auc_sd": roc_auc_sd,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics
    if save_eval_output is True:
        # Suffix appended by formatting: Path.with_suffix would truncate a
        # prefix containing "." (e.g. "v1.2" -> "v1.pkl").
        eval_metrics_output_path = os.path.join(
            output_dir, f"{output_prefix}_eval_metrics_dict.pkl"
        )
        with open(eval_metrics_output_path, "wb") as f:
            pickle.dump(all_metrics, f)

    return all_metrics
|
|
|
def train_classifier(
    self,
    model_directory,
    num_classes,
    train_data,
    eval_data,
    output_directory,
    predict=False,
):
    """
    Fine-tune model for cell state or gene classification.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    num_classes : int
        | Number of classes for classifier
    train_data : Dataset
        | Loaded training .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    eval_data : None, Dataset
        | (Optional) Loaded evaluation .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    output_directory : Path
        | Path to directory where fine-tuned model will be saved (created if missing)
    predict : bool
        | Whether or not to save eval predictions from trainer
        | Ignored (with a warning) for gene classifiers or when there is no eval data.

    **Output**

    Returns the Hugging Face Trainer after fine-tuning.

    **Raises**

    RuntimeError
        | If a model is already saved in output_directory.
    """

    # Validate and prepare data
    train_data, eval_data = cu.validate_and_clean_cols(
        train_data, eval_data, self.classifier
    )

    if (self.no_eval is True) and (eval_data is not None):
        logger.warning(
            "no_eval set to True; model will be trained without evaluation."
        )
        eval_data = None

    if (self.classifier == "gene") and (predict is True):
        logger.warning(
            "Predictions during training not currently available for gene classifiers; setting predict to False."
        )
        predict = False

    # trainer.predict needs an eval set; without one there is nothing to predict on.
    if (predict is True) and (eval_data is None):
        logger.warning("No eval data available; setting predict to False.")
        predict = False

    # Refuse to overwrite a previously saved model. Recent transformers versions
    # save weights as model.safetensors rather than pytorch_model.bin.
    for weights_name in ("pytorch_model.bin", "model.safetensors"):
        if os.path.isfile(os.path.join(output_directory, weights_name)):
            message = "Model already saved to this designated output directory."
            logger.error(message)
            raise RuntimeError(message)
    # make output directory (no shell command built from the path)
    os.makedirs(output_directory, exist_ok=True)

    # Load model
    if self.classifier == "cell":
        model_type = "CellClassifier"
    elif self.classifier == "gene":
        model_type = "GeneClassifier"
    model = pu.load_model(model_type, num_classes, model_directory, "train")

    def_training_args, def_freeze_layers = cu.get_default_train_args(
        model, self.classifier, train_data, output_directory
    )

    if self.training_args is not None:
        def_training_args.update(self.training_args)
    # Log ~10 times per epoch, but at least every step: tiny training sets would
    # otherwise round to logging_steps=0, which TrainingArguments rejects.
    def_training_args["logging_steps"] = max(
        1,
        round(len(train_data) / def_training_args["per_device_train_batch_size"] / 10),
    )
    def_training_args["output_dir"] = output_directory
    if eval_data is None:
        def_training_args["evaluation_strategy"] = "no"
        def_training_args["load_best_model_at_end"] = False
    training_args_init = TrainingArguments(**def_training_args)

    if self.freeze_layers is not None:
        def_freeze_layers = self.freeze_layers

    # Freeze the first def_freeze_layers encoder layers.
    if def_freeze_layers > 0:
        for module in model.bert.encoder.layer[:def_freeze_layers]:
            for param in module.parameters():
                param.requires_grad = False

    # define the data collator
    if self.classifier == "cell":
        data_collator = DataCollatorForCellClassification()
    elif self.classifier == "gene":
        data_collator = DataCollatorForGeneClassification()

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=eval_data,
        compute_metrics=cu.compute_metrics,
    )

    # train the classifier
    trainer.train()
    trainer.save_model(output_directory)
    if predict is True:
        # make eval predictions and save predictions and metrics
        predictions = trainer.predict(eval_data)
        prediction_output_path = os.path.join(output_directory, "predictions.pkl")
        with open(prediction_output_path, "wb") as f:
            pickle.dump(predictions, f)
        trainer.save_metrics("eval", predictions.metrics)
    return trainer
|
|
|
def evaluate_model(
    self,
    model,
    num_classes,
    id_class_dict,
    eval_data,
    predict=False,
    output_directory=None,
    output_prefix=None,
):
    """
    Evaluate the fine-tuned model.

    **Parameters**

    model : nn.Module
        | Loaded fine-tuned model (e.g. trainer.model)
    num_classes : int
        | Number of classes for classifier
    id_class_dict : dict
        | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    eval_data : Dataset
        | Loaded evaluation .dataset input
    predict : bool
        | Whether or not to save eval predictions
        | (as {output_prefix}_pred_dict.pkl in output_directory)
    output_directory : Path
        | Path to directory where eval data will be saved (required if predict is True)
    output_prefix : str
        | Prefix for output files

    **Output**

    Returns dictionary with keys "conf_mat", "macro_f1", "acc" and "roc_metrics",
    as computed by evaluation_utils.get_metrics.
    """

    # define the class labels (numerical IDs) passed to the metrics
    labels = id_class_dict.keys()
    y_pred, y_true, logits_list = eu.classifier_predict(
        model, self.classifier, eval_data, self.forward_batch_size
    )
    conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
        y_pred, y_true, logits_list, num_classes, labels
    )
    if predict is True:
        pred_dict = {
            "pred_ids": y_pred,
            "label_ids": y_true,
            "predictions": logits_list,
        }
        # Suffix appended by formatting: Path.with_suffix would replace everything
        # after a "." in output_prefix (e.g. "v1.2_pred_dict" -> "v1.pkl").
        pred_dict_output_path = Path(output_directory) / f"{output_prefix}_pred_dict.pkl"
        with open(pred_dict_output_path, "wb") as f:
            pickle.dump(pred_dict, f)
    return {
        "conf_mat": conf_mat,
        "macro_f1": macro_f1,
        "acc": acc,
        "roc_metrics": roc_metrics,
    }
|
|
|
def evaluate_saved_model(
    self,
    model_directory,
    id_class_dict_file,
    test_data_file,
    output_directory,
    output_prefix,
    predict=True,
):
    """
    Evaluate a saved fine-tuned model on test data.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    test_data_file : Path
        | Path to directory containing test .dataset
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    predict : bool
        | Whether or not to save eval predictions

    **Returns**

    dict
        | "conf_matrix": confusion matrix DataFrame indexed/labeled by class name,
        | "macro_f1", "acc", and "all_roc_metrics" (ROC metrics dict for binary
        | classifiers, otherwise None). The same dict is pickled to
        | {output_directory}/{output_prefix}_test_metrics_dict.pkl

    **Raises**

    ValueError
        | If self.classifier is neither "cell" nor "gene".
    """
    # Resolve the model type first so an invalid configuration fails before
    # any data or weights are loaded.
    if self.classifier == "cell":
        model_type = "CellClassifier"
    elif self.classifier == "gene":
        model_type = "GeneClassifier"
    else:
        raise ValueError(
            f"Unrecognized classifier type {self.classifier!r}; "
            "expected 'cell' or 'gene'."
        )

    # NOTE(review): pickle.load can execute arbitrary code; only load
    # id_class_dict files produced by Classifier.prepare_data.
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    num_classes = cu.get_num_classes(id_class_dict)

    test_data = pu.load_and_filter(None, self.nproc, test_data_file)

    model = pu.load_model(model_type, num_classes, model_directory, "eval")

    # evaluate_model returns a single dict of metrics (not a list of per-split
    # results as produced during cross-validation).
    result = self.evaluate_model(
        model,
        num_classes,
        id_class_dict,
        test_data,
        predict=predict,
        output_directory=output_directory,
        output_prefix=output_prefix,
    )

    class_labels = list(id_class_dict.values())
    all_conf_mat_df = pd.DataFrame(
        result["conf_mat"],
        columns=class_labels,
        index=class_labels,
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": result["macro_f1"],
        "acc": result["acc"],
    }

    all_roc_metrics = None  # ROC metrics are only reported for binary classifiers
    if num_classes == 2:
        roc = result["roc_metrics"]
        # The aggregator takes per-split lists; a single test evaluation is
        # exactly one "split", so wrap each value in a one-element list.
        # (Iterating `result` directly would yield its string keys.)
        all_tpr = [roc["interp_tpr"]]
        all_roc_auc = [roc["auc"]]
        all_tpr_wt = [roc["tpr_wt"]]
        mean_tpr, _, _ = eu.get_cross_valid_roc_metrics(
            all_tpr, all_roc_auc, all_tpr_wt
        )
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": np.linspace(0, 1, 100),
            "all_roc_auc": all_roc_auc,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics

    # Append the extension directly: Path.with_suffix() would replace any
    # dotted tail of output_prefix (e.g. "run.v2" -> "run.pkl").
    test_metrics_output_path = (
        Path(output_directory) / f"{output_prefix}_test_metrics_dict.pkl"
    )
    with open(test_metrics_output_path, "wb") as f:
        pickle.dump(all_metrics, f)

    return all_metrics
|
|
|
def plot_conf_mat(
    self,
    conf_mat_dict,
    output_directory,
    output_prefix,
    custom_class_order=None,
):
    """
    Plot confusion matrix results of evaluating the fine-tuned model.

    Produces one plot per model via eu.plot_confusion_matrix.

    **Parameters**

    conf_mat_dict : dict
        | Dictionary of model_name : confusion_matrix_DataFrame
        | (all_metrics["conf_matrix"] from self.validate)
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    """
    for name, conf_mat_df in conf_mat_dict.items():
        eu.plot_confusion_matrix(
            conf_mat_df,
            name,
            output_directory,
            output_prefix,
            custom_class_order,
        )
|
|
|
def plot_roc(
    self,
    roc_metric_dict,
    model_style_dict,
    title,
    output_directory,
    output_prefix,
):
    """
    Plot ROC curve results of evaluating the fine-tuned model.

    Thin wrapper around eu.plot_ROC; all arguments are forwarded unchanged.

    **Parameters**

    roc_metric_dict : dict
        | Dictionary of model_name : roc_metrics
        | (all_metrics["all_roc_metrics"] from self.validate)
    model_style_dict : dict[dict]
        | Dictionary of model_name : dictionary of style_attribute : style
        | where style includes color and linestyle
        | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
    title : str
        | Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    """
    # Forwarded positionally, in the same order as this method's parameters.
    plot_args = (
        roc_metric_dict,
        model_style_dict,
        title,
        output_directory,
        output_prefix,
    )
    eu.plot_ROC(*plot_args)
|
|
|
def plot_predictions(
    self,
    predictions_file,
    id_class_dict_file,
    title,
    output_directory,
    output_prefix,
    custom_class_order=None,
    kwargs_dict=None,
):
    """
    Plot prediction results of evaluating the fine-tuned model.

    **Parameters**

    predictions_file : path
        | Path of model predictions output to plot
        | (saved output from self.validate if predict=True)
        | (or saved output from self.evaluate_saved_model)
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    title : str
        | Title for legend containing class labels.
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    kwargs_dict : None, dict
        | Dictionary of kwargs to pass to plotting function.

    **Raises**

    ValueError
        | If predictions_file holds a dict missing any of the keys
        | "pred_ids", "label_ids", "predictions", or if the prediction
        | matrix does not have one column per class in id_class_dict.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this pipeline (self.validate / self.evaluate_saved_model).
    with open(predictions_file, "rb") as f:
        predictions = pickle.load(f)

    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    if isinstance(predictions, dict):
        # Format written by evaluate_model when predict=True.
        required_keys = ("pred_ids", "label_ids", "predictions")
        missing = [key for key in required_keys if key not in predictions]
        if missing:
            raise ValueError(
                f"Predictions dict in {predictions_file} is missing keys: {missing}"
            )
        predictions_logits = np.array(predictions["predictions"])
        true_ids = predictions["label_ids"]
    else:
        # Otherwise expect an object exposing .predictions and .label_ids,
        # such as the pickled result of trainer.predict.
        predictions_logits = np.asarray(predictions.predictions)
        true_ids = predictions.label_ids

    num_classes = len(id_class_dict)
    if predictions_logits.ndim != 2 or predictions_logits.shape[1] != num_classes:
        raise ValueError(
            f"Predictions have shape {predictions_logits.shape}; expected "
            f"(n_examples, {num_classes}) to match the classes in "
            f"{id_class_dict_file}."
        )

    classes = list(id_class_dict.values())
    true_labels = [id_class_dict[idx] for idx in true_ids]

    predictions_df = pd.DataFrame(predictions_logits, columns=classes)
    if custom_class_order is not None:
        predictions_df = predictions_df.reindex(columns=custom_class_order)
    predictions_df["true"] = true_labels

    # Sort rows by their true class: in custom_class_order when given,
    # otherwise in id_class_dict order.
    class_order = custom_class_order if custom_class_order is not None else classes
    class_rank = {label: rank for rank, label in enumerate(class_order)}
    predictions_df = predictions_df.sort_values(
        by=["true"], key=lambda col: col.map(class_rank)
    )

    eu.plot_predictions(
        predictions_df, title, output_directory, output_prefix, kwargs_dict
    )
|
|