Christina Theodoris committed on
Commit
d154fee
·
1 Parent(s): 1786b44

Add function to extract and plot cell embeddings

Browse files
examples/extract_and_plot_cell_embeddings.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
geneformer/__init__.py CHANGED
@@ -7,5 +7,6 @@ from .tokenizer import TranscriptomeTokenizer
7
  from .pretrainer import GeneformerPretrainer
8
  from .collator_for_classification import DataCollatorForGeneClassification
9
  from .collator_for_classification import DataCollatorForCellClassification
 
10
  from .in_silico_perturber import InSilicoPerturber
11
  from .in_silico_perturber_stats import InSilicoPerturberStats
 
7
  from .pretrainer import GeneformerPretrainer
8
  from .collator_for_classification import DataCollatorForGeneClassification
9
  from .collator_for_classification import DataCollatorForCellClassification
10
+ from .emb_extractor import EmbExtractor
11
  from .in_silico_perturber import InSilicoPerturber
12
  from .in_silico_perturber_stats import InSilicoPerturberStats
geneformer/emb_extractor.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer embedding extractor.
3
+
4
+ Usage:
5
+ from geneformer import EmbExtractor
6
+ embex = EmbExtractor(model_type="CellClassifier",
7
+ num_classes=3,
8
+ emb_mode="cell",
9
+ cell_emb_style="mean_pool",
10
+ filter_data={"cell_type":["cardiomyocyte"]},
11
+ max_ncells=1000,
12
+ emb_layer=-1,
13
+ emb_label=["disease","cell_type"],
14
+ labels_to_plot=["disease","cell_type"],
15
+ forward_batch_size=100,
16
+ nproc=16)
17
+ embs = embex.extract_embs("path/to/model",
18
+ "path/to/input_data",
19
+ "path/to/output_directory",
20
+ "output_prefix")
21
+ embex.plot_embs(embs=embs,
22
+ plot_style="heatmap",
23
+ output_directory="path/to/output_directory",
24
+ output_prefix="output_prefix",
25
+ max_ncells_to_plot=1000)
26
+
27
+ """
28
+
29
+ # imports
30
+ import logging
31
+ import anndata
32
+ import matplotlib.pyplot as plt
33
+ import numpy as np
34
+ import pandas as pd
35
+ import pickle
36
+ import scanpy as sc
37
+ import seaborn as sns
38
+ import torch
39
+ from collections import Counter
40
+ from pathlib import Path
41
+ from tqdm.notebook import trange
42
+ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
43
+
44
+ from .tokenizer import TOKEN_DICTIONARY_FILE
45
+
46
+ from .in_silico_perturber import load_and_filter, \
47
+ downsample_and_sort, \
48
+ load_model, \
49
+ quant_layers, \
50
+ downsample_and_sort, \
51
+ pad_tensor_list, \
52
+ get_model_input_size
53
+
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
+ # get cell embeddings excluding padding
58
def mean_nonpadding_embs(embs, original_lens):
    """
    Average embeddings over each sequence, ignoring padded positions.

    Parameters
    ----------
    embs : torch.Tensor
        Embeddings indexed as (batch, position, embedding dim).
        Positions at or beyond a sequence's original length are treated as padding.
    original_lens : torch.Tensor
        1D tensor holding the unpadded length of each sequence in the batch.

    Returns
    -------
    torch.Tensor
        (batch, embedding dim) tensor with the mean of the non-padding embeddings.
    """
    # build the mask on the device of the embeddings instead of assuming CUDA,
    # so the function also works for CPU tensors
    original_lens = original_lens.to(embs.device)
    positions = torch.arange(embs.size(1), device=embs.device).unsqueeze(0)

    # True for real tokens, False for padding
    mask = positions < original_lens.unsqueeze(1)

    # extend mask dimensions to match the embeddings tensor
    mask = mask.unsqueeze(2).expand_as(embs)

    # use the mask to zero out the embeddings in padded areas
    masked_embs = embs * mask.float()

    # sum and divide by the lengths to get the mean of non-padding embs
    mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
    return mean_embs
71
+
72
+ # average embedding position of goal cell states
73
def get_embs(model,
             filtered_input_data,
             emb_mode,
             layer_to_quant,
             pad_token_id,
             forward_batch_size):
    """
    Run the model over the input data in minibatches and collect embeddings.

    Parameters
    ----------
    model
        Model called as ``model(input_ids=...)`` whose output exposes ``hidden_states``.
    filtered_input_data
        Dataset supporting ``len()``, ``.select(indices)``, ``.set_format`` and
        the columns "length" and "input_ids".
    emb_mode : str
        Only "cell" (mean of non-padding embeddings per cell) is implemented here.
    layer_to_quant : int
        Index into ``outputs.hidden_states`` of the layer to extract.
    pad_token_id
        Token id passed to ``pad_tensor_list`` for padding.
    forward_batch_size : int
        Number of cells per forward pass.

    Returns
    -------
    torch.Tensor
        Per-cell embeddings from all minibatches concatenated along dim 0.

    Raises
    ------
    NotImplementedError
        If ``emb_mode`` is not "cell". (Previously this path failed with an
        UnboundLocalError after the first forward pass.)
    """
    if emb_mode != "cell":
        message = f"Embedding extraction is not implemented for emb_mode '{emb_mode}'."
        logger.error(message)
        raise NotImplementedError(message)

    model_input_size = get_model_input_size(model)
    total_batch_length = len(filtered_input_data)

    # run inputs on whichever device the model lives on rather than assuming CUDA
    device = next(model.parameters()).device

    # shrink the batch size by one when the final minibatch would otherwise
    # hold a single cell (presumably to avoid a trailing batch of size 1);
    # never shrink to 0, which would make the range step invalid
    if forward_batch_size > 1 \
            and ((total_batch_length-1)/forward_batch_size).is_integer():
        forward_batch_size = forward_batch_size-1

    embs_list = []
    for start in trange(0, total_batch_length, forward_batch_size):
        stop = min(start+forward_batch_size, total_batch_length)

        minibatch = filtered_input_data.select(list(range(start, stop)))
        # lengths are read before switching the dataset format to torch
        max_len = max(minibatch["length"])
        original_lens = torch.tensor(minibatch["length"]).to(device)
        minibatch.set_format(type="torch")

        input_data_minibatch = pad_tensor_list(minibatch["input_ids"],
                                               max_len,
                                               pad_token_id,
                                               model_input_size)

        with torch.no_grad():
            outputs = model(
                input_ids = input_data_minibatch.to(device)
            )

        embs_i = outputs.hidden_states[layer_to_quant]
        embs_list.append(mean_nonpadding_embs(embs_i, original_lens))

        # drop references to per-batch tensors before the next forward pass
        del outputs
        del minibatch
        del input_data_minibatch
        del embs_i
        torch.cuda.empty_cache()

    embs_stack = torch.cat(embs_list)
    return embs_stack
120
+
121
def label_embs(embs, downsampled_data, emb_labels):
    """
    Put embeddings into a dataframe and append the requested label columns.

    Parameters
    ----------
    embs : torch.Tensor
        Embeddings; moved to CPU before conversion.
    downsampled_data
        Mapping-like data indexed by column name to obtain each label column.
    emb_labels : None, list
        Column names to copy from ``downsampled_data``; None adds no labels.

    Returns
    -------
    pandas.DataFrame
        Embedding columns followed by one column per label.
    """
    labeled_df = pd.DataFrame(embs.cpu())

    # nothing to attach when no label columns were requested
    if emb_labels is None:
        return labeled_df

    for column_name in emb_labels:
        labeled_df[column_name] = downsampled_data[column_name]
    return labeled_df
128
+
129
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """
    Compute a UMAP of the embeddings and plot it colored by ``label``.

    Parameters
    ----------
    embs_df : pandas.DataFrame
        Dataframe whose first ``emb_dims`` columns are embedding values and
        which also contains the ``label`` column.
    emb_dims : int
        Number of leading embedding columns.
    label : str
        Column used to color the cells.
    output_file
        Value forwarded unchanged as ``save`` to ``sc.pl.umap``.
    kwargs_dict : None, dict
        Extra keyword arguments for ``sc.pl.umap``; override the defaults below.
    """
    # leading columns hold the embeddings; relabel rows/columns with string positions
    emb_values = embs_df.iloc[:, :emb_dims]
    emb_values.index = pd.RangeIndex(0, emb_values.shape[0], name=None).astype(str)
    emb_values.columns = pd.RangeIndex(0, emb_values.shape[1], name=None).astype(str)

    adata = anndata.AnnData(
        X=emb_values,
        obs={"cell_id": list(emb_values.index),
             f"{label}": list(embs_df[label])},
        var={"embs": emb_values.columns},
    )

    # PCA -> neighbor graph -> UMAP
    sc.tl.pca(adata, svd_solver='arpack')
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)

    sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3)
    sns.set_style("white")

    # caller-supplied kwargs take precedence over the defaults
    plot_kwargs = {"palette":"Set2", "size":200}
    if kwargs_dict is not None:
        plot_kwargs.update(kwargs_dict)

    sc.pl.umap(adata, color=label, save=output_file, **plot_kwargs)
147
+
148
+
149
def gen_heatmap_class_colors(labels, df):
    """
    Assign one cubehelix palette color to each distinct label.

    Parameters
    ----------
    labels : list
        One label per row of ``df``.
    df : pandas.DataFrame
        Dataframe whose index is reused for the returned series.

    Returns
    -------
    pandas.Series
        Color for each row, indexed like ``df``.
    """
    # distinct labels in first-seen order
    unique_labels = list(Counter(labels))
    pal = sns.cubehelix_palette(len(unique_labels), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)

    # the lookup table is keyed by the string form of each label
    lut = dict(zip(map(str, unique_labels), pal))

    # look labels up by their string form as well; mapping the raw values
    # silently produced NaN colors for non-string labels (e.g. ints)
    colors = pd.Series(labels, index=df.index).map(lambda value: lut[str(value)])
    return colors
154
+
155
def gen_heatmap_class_dict(classes, label_colors_series):
    """
    Build a {class: color} dictionary from per-row classes and colors.

    Parameters
    ----------
    classes : list
        Class of each row.
    label_colors_series : pandas.Series
        Color of each row.

    Returns
    -------
    dict
        Maps each distinct class to the color of its first occurrence.
    """
    pairs = pd.DataFrame({"classes": classes, "color": label_colors_series})

    # keep only the first row seen for every class
    first_per_class = pairs.drop_duplicates(subset=["classes"])

    return {
        class_name: color
        for class_name, color in zip(first_per_class["classes"], first_per_class["color"])
    }
159
+
160
def make_colorbar(embs_df, label):
    """
    Build the row-color annotation for a heatmap.

    Parameters
    ----------
    embs_df : pandas.DataFrame
        Dataframe containing the ``label`` column.
    label : str
        Column whose values are turned into colors.

    Returns
    -------
    label_colors : pandas.DataFrame
        Single-column dataframe (named ``label``) with one color per row.
    label_color_dict : dict
        Maps each distinct class to its color.
    """
    labels = list(embs_df[label])

    cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
    label_colors = pd.DataFrame(cell_type_colors, columns=[label])

    # sanity check: every color should be a 3-value entry without NaN
    for row_id, color in label_colors[label].items():
        try:
            is_valid = len(color) == 3 and not np.isnan(color).any()
        except TypeError:
            # e.g. a missing (NaN) color, which has no length
            is_valid = False
        if not is_valid:
            logger.warning(f"Unexpected color for row {row_id}: {color}")

    # create dictionary for colors and classes
    label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
    return label_colors, label_color_dict
177
+
178
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """
    Plot a clustered heatmap of the embeddings with a color bar for ``label``.

    Parameters
    ----------
    embs_df : pandas.DataFrame
        Dataframe whose first ``emb_dims`` columns are embedding values and
        which also contains the ``label`` column.
    emb_dims : int
        Number of leading embedding columns.
    label : str
        Column shown as the row color bar and used as the legend title.
    output_file
        Destination passed to ``plt.savefig``.
    kwargs_dict : None, dict
        Extra keyword arguments for ``sns.clustermap``; override the defaults below.
    """
    sns.set_style("white")
    sns.set(font_scale=2)
    # no separate plt.figure() here: clustermap creates its own figure, so a
    # figure opened beforehand would be left empty and never closed
    label_colors, label_color_dict = make_colorbar(embs_df, label)

    default_kwargs_dict = {"row_cluster": True,
                           "col_cluster": True,
                           "row_colors": label_colors,
                           "standard_scale": 1,
                           "linewidths": 0,
                           "xticklabels": False,
                           "yticklabels": False,
                           "figsize": (15,15),
                           "center": 0,
                           "cmap": "magma"}

    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)
    g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # zero-size bars only register one legend entry per class;
    # the loop variables must not reuse the name `label`, which is needed
    # below as the legend title
    for class_name, class_color in label_color_dict.items():
        g.ax_col_dendrogram.bar(0, 0, color=class_color, label=class_name, linewidth=0)

    g.ax_col_dendrogram.legend(title=f"{label}",
                               loc="lower center",
                               ncol=4,
                               bbox_to_anchor=(0.5, 1),
                               facecolor="white")

    plt.savefig(output_file, bbox_inches='tight')
213
+
214
class EmbExtractor:
    """
    Extract embeddings from tokenized input data and plot them.

    Constructor options are checked against ``valid_option_dict`` by
    ``validate_options``.
    """

    # allowed values (or allowed types) for each constructor option
    valid_option_dict = {
        "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
        "num_classes": {int},
        "emb_mode": {"cell","gene"},
        "cell_emb_style": {"mean_pool"},
        "filter_data": {None, dict},
        "max_ncells": {None, int},
        "emb_layer": {-1, 0},
        "emb_label": {None, list},
        "labels_to_plot": {None, list},
        "forward_batch_size": {int},
        "nproc": {int},
    }

    def __init__(
        self,
        model_type="Pretrained",
        num_classes=0,
        emb_mode="cell",
        cell_emb_style="mean_pool",
        filter_data=None,
        max_ncells=1000,
        emb_layer=-1,
        emb_label=None,
        labels_to_plot=None,
        forward_batch_size=100,
        nproc=4,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize embedding extractor.

        Parameters
        ----------
        model_type : {"Pretrained","GeneClassifier","CellClassifier"}
            Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
        num_classes : int
            If model is a gene or cell classifier, specify number of classes it was trained to classify.
            For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
        emb_mode : {"cell","gene"}
            Whether to output cell or gene embeddings.
        cell_emb_style : "mean_pool"
            Method for summarizing cell embeddings.
            Currently only option is mean pooling of gene embeddings for given cell.
        filter_data : None, dict
            Default is to extract embeddings from all input data.
            Otherwise, dictionary specifying .dataset column name and list of values to filter by.
            Non-list values are wrapped in a list in place (with a warning).
        max_ncells : None, int
            Maximum number of cells to extract embeddings from.
            Default is 1000 cells randomly sampled from input data.
            If None, will extract embeddings from all cells.
        emb_layer : {-1, 0}
            Embedding layer to extract.
            The last layer is most specifically weighted to optimize the given learning objective.
            Generally, it is best to extract the 2nd to last layer to get a more general representation.
            -1: 2nd to last layer
            0: last layer
        emb_label : None, list
            List of column name(s) in .dataset to add as labels to embedding output.
        labels_to_plot : None, list
            Cell labels to plot.
            Shown as color bar in heatmap.
            Shown as cell color in umap.
            Required by plot_embs for both plot styles.
        forward_batch_size : int
            Batch size for forward pass.
        nproc : int
            Number of CPU processes to use.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl ID:token).

        Raises
        ------
        ValueError
            If any option is outside its valid values/types.
        """

        self.model_type = model_type
        self.num_classes = num_classes
        self.emb_mode = emb_mode
        self.cell_emb_style = cell_emb_style
        self.filter_data = filter_data
        self.max_ncells = max_ncells
        self.emb_layer = emb_layer
        self.emb_label = emb_label
        self.labels_to_plot = labels_to_plot
        self.forward_batch_size = forward_batch_size
        self.nproc = nproc

        self.validate_options()

        # load token dictionary (Ensembl IDs:token)
        # NOTE: pickle.load can execute arbitrary code; only pass a trusted file
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        self.pad_token_id = self.gene_token_dict.get("<pad>")

    def validate_options(self):
        """
        Confirm options are within valid options and normalize filter_data.

        Raises
        ------
        ValueError
            If an option is neither one of its allowed values nor an instance
            of one of its allowed types.
        """
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = self.__dict__[attr_name]

            # lists/dicts are unhashable, so only other values are tested for
            # direct membership in the set of valid options
            if not isinstance(attr_value, (list, dict)) \
                    and attr_value in valid_options:
                continue

            # otherwise accept the value if it is an instance of an allowed type
            if any((option in (int, list, dict)) and isinstance(attr_value, option)
                   for option in valid_options):
                continue

            message = f"Invalid option for {attr_name}. " \
                      f"Valid options for {attr_name}: {valid_options}"
            logger.error(message)
            # a bare `raise` here had no active exception to re-raise
            raise ValueError(message)

        if self.filter_data is not None:
            for key, value in self.filter_data.items():
                if not isinstance(value, list):
                    # replacing the value of an existing key is safe while iterating
                    self.filter_data[key] = [value]
                    logger.warning(
                        "Values in filter_data dict must be lists. " \
                        f"Changing {key} value to list ([{value}]).")

    def extract_embs(self,
                     model_directory,
                     input_data_file,
                     output_directory,
                     output_prefix):
        """
        Extract embeddings from input data and save as results in output_directory.

        Parameters
        ----------
        model_directory : Path
            Path to directory containing model
        input_data_file : Path
            Path to directory containing .dataset inputs
        output_directory : Path
            Path to directory where embedding data will be saved as csv
        output_prefix : str
            Prefix for output file

        Returns
        -------
        pandas.DataFrame
            Embeddings with one column per requested label appended.
        """

        filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
        downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
        model = load_model(self.model_type, self.num_classes, model_directory)
        layer_to_quant = quant_layers(model)+self.emb_layer
        embs = get_embs(model,
                        downsampled_data,
                        self.emb_mode,
                        layer_to_quant,
                        self.pad_token_id,
                        self.forward_batch_size)
        embs_df = label_embs(embs, downsampled_data, self.emb_label)

        # save embeddings to output_path
        output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
        embs_df.to_csv(output_path)

        return embs_df

    def plot_embs(self,
                  embs,
                  plot_style,
                  output_directory,
                  output_prefix,
                  max_ncells_to_plot=1000,
                  kwargs_dict=None):

        """
        Plot embeddings, coloring by provided labels.

        Parameters
        ----------
        embs : pandas.core.frame.DataFrame
            Pandas dataframe containing embeddings output from extract_embs
        plot_style : str
            Style of plot: "heatmap" or "umap"
        output_directory : Path
            Path to directory where heatmap plots will be saved as pdf.
            NOTE(review): the umap branch passes only a filename suffix to
            plot_umap and does not use output_directory.
        output_prefix : str
            Prefix for output file
        max_ncells_to_plot : None, int
            Maximum number of cells to plot.
            Default is 1000 cells randomly sampled from embeddings.
            If None, will plot embeddings from all cells.
        kwargs_dict : dict
            Dictionary of kwargs to pass to plotting function.

        Raises
        ------
        ValueError
            If plot_style is invalid or labels_to_plot was not provided.
        """

        if plot_style not in ["heatmap","umap"]:
            message = "Invalid option for 'plot_style'. " \
                      "Valid options: {'heatmap','umap'}"
            logger.error(message)
            raise ValueError(message)

        if self.labels_to_plot is None:
            # both plotting helpers need a label column to color by
            if plot_style == "umap":
                message = "Plotting UMAP requires 'labels_to_plot'. "
            else:
                message = "Plotting heatmap requires 'labels_to_plot'. "
            logger.error(message)
            raise ValueError(message)

        # either limit may be None ("all cells"), so only compare real numbers
        if (max_ncells_to_plot is not None) \
                and (self.max_ncells is not None) \
                and (max_ncells_to_plot > self.max_ncells):
            max_ncells_to_plot = self.max_ncells
            logger.warning(
                "max_ncells_to_plot must be <= max_ncells. " \
                f"Changing max_ncells_to_plot to {self.max_ncells}.")

        # downsample only when fewer cells are requested than are present;
        # sampling more rows than exist would raise
        if (max_ncells_to_plot is not None) \
                and (max_ncells_to_plot < embs.shape[0]):
            embs = embs.sample(max_ncells_to_plot, axis=0)

        # label columns sit after the embedding columns
        label_len = 0 if self.emb_label is None else len(self.emb_label)
        emb_dims = embs.shape[1] - label_len
        emb_labels = [] if self.emb_label is None else embs.columns[emb_dims:]

        if plot_style == "umap":
            for label in self.labels_to_plot:
                if label not in emb_labels:
                    logger.warning(
                        f"Label {label} from labels_to_plot " \
                        f"not present in provided embeddings dataframe.")
                    continue
                # forwarded by plot_umap as the `save` argument of sc.pl.umap
                output_prefix_label = "_" + output_prefix + f"_umap_{label}"
                plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)

        if plot_style == "heatmap":
            for label in self.labels_to_plot:
                if label not in emb_labels:
                    logger.warning(
                        f"Label {label} from labels_to_plot " \
                        f"not present in provided embeddings dataframe.")
                    continue
                output_prefix_label = output_prefix + f"_heatmap_{label}"
                output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
                plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
459
+
geneformer/in_silico_perturber.py CHANGED
@@ -41,6 +41,43 @@ from .tokenizer import TOKEN_DICTIONARY_FILE
41
 
42
  logger = logging.getLogger(__name__)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def quant_layers(model):
45
  layer_nums = []
46
  for name, parameter in model.named_parameters():
@@ -726,8 +763,8 @@ class InSilicoPerturber:
726
  Prefix for output files
727
  """
728
 
729
- filtered_input_data = self.load_and_filter(input_data_file)
730
- model = self.load_model(model_directory)
731
  layer_to_quant = quant_layers(model)+self.emb_layer
732
 
733
  if self.cell_states_to_model is None:
@@ -755,42 +792,6 @@ class InSilicoPerturber:
755
  state_embs_dict,
756
  output_directory,
757
  output_prefix)
758
-
759
- # load data and filter by defined criteria
760
- def load_and_filter(self, input_data_file):
761
- data = load_from_disk(input_data_file)
762
- if self.filter_data is not None:
763
- for key,value in self.filter_data.items():
764
- def filter_data_by_criteria(example):
765
- return example[key] in value
766
- data = data.filter(filter_data_by_criteria, num_proc=self.nproc)
767
- if len(data) == 0:
768
- logger.error(
769
- "No cells remain after filtering. Check filtering criteria.")
770
- raise
771
- data_shuffled = data.shuffle(seed=42)
772
- return data_shuffled
773
-
774
- # load model to GPU
775
- def load_model(self, model_directory):
776
- if self.model_type == "Pretrained":
777
- model = BertForMaskedLM.from_pretrained(model_directory,
778
- output_hidden_states=True,
779
- output_attentions=False)
780
- elif self.model_type == "GeneClassifier":
781
- model = BertForTokenClassification.from_pretrained(model_directory,
782
- num_labels=self.num_classes,
783
- output_hidden_states=True,
784
- output_attentions=False)
785
- elif self.model_type == "CellClassifier":
786
- model = BertForSequenceClassification.from_pretrained(model_directory,
787
- num_labels=self.num_classes,
788
- output_hidden_states=True,
789
- output_attentions=False)
790
- # put the model in eval mode for fwd pass
791
- model.eval()
792
- model = model.to("cuda:0")
793
- return model
794
 
795
  # determine effect of perturbation on other genes
796
  def in_silico_perturb(self,
 
41
 
42
  logger = logging.getLogger(__name__)
43
 
44
+
45
+ # load data and filter by defined criteria
46
def load_and_filter(filter_data, nproc, input_data_file):
    """
    Load a dataset from disk, filter it by the given criteria, and shuffle it.

    Parameters
    ----------
    filter_data : None, dict
        Maps a dataset column name to the collection of values to keep.
        None skips filtering.
    nproc : int
        Passed as ``num_proc`` to the dataset ``filter`` call.
    input_data_file : Path
        Location passed to ``load_from_disk``.

    Returns
    -------
    The filtered dataset shuffled with a fixed seed (42).

    Raises
    ------
    ValueError
        If no cells remain after filtering.
    """
    data = load_from_disk(input_data_file)
    if filter_data is not None:
        for key, value in filter_data.items():
            # the closure is consumed by filter() right away, within this
            # iteration, so it sees the current key/value
            def filter_data_by_criteria(example):
                return example[key] in value
            data = data.filter(filter_data_by_criteria, num_proc=nproc)
        if len(data) == 0:
            message = "No cells remain after filtering. Check filtering criteria."
            logger.error(message)
            # a bare `raise` here had no active exception to re-raise
            raise ValueError(message)
    data_shuffled = data.shuffle(seed=42)
    return data_shuffled
59
+
60
+ # load model to GPU
61
def load_model(model_type, num_classes, model_directory):
    """
    Load a model for a forward pass and move it to the GPU.

    Parameters
    ----------
    model_type : {"Pretrained","GeneClassifier","CellClassifier"}
        Selects which Bert model class is loaded.
    num_classes : int
        Passed as ``num_labels`` for the two classifier model types.
    model_directory : Path
        Location passed to ``from_pretrained``.

    Returns
    -------
    The loaded model in eval mode on device "cuda:0", configured to output
    hidden states and not attentions.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported options (previously this
        fell through and failed on an unbound ``model`` variable).
    """
    if model_type == "Pretrained":
        model = BertForMaskedLM.from_pretrained(model_directory,
                                                output_hidden_states=True,
                                                output_attentions=False)
    elif model_type == "GeneClassifier":
        model = BertForTokenClassification.from_pretrained(model_directory,
                                                           num_labels=num_classes,
                                                           output_hidden_states=True,
                                                           output_attentions=False)
    elif model_type == "CellClassifier":
        model = BertForSequenceClassification.from_pretrained(model_directory,
                                                              num_labels=num_classes,
                                                              output_hidden_states=True,
                                                              output_attentions=False)
    else:
        message = f"Invalid option for model_type: {model_type}. " \
                  "Valid options: {'Pretrained','GeneClassifier','CellClassifier'}"
        logger.error(message)
        raise ValueError(message)
    # put the model in eval mode for fwd pass
    model.eval()
    model = model.to("cuda:0")
    return model
80
+
81
  def quant_layers(model):
82
  layer_nums = []
83
  for name, parameter in model.named_parameters():
 
763
  Prefix for output files
764
  """
765
 
766
+ filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
767
+ model = load_model(self.model_type, self.num_classes, model_directory)
768
  layer_to_quant = quant_layers(model)+self.emb_layer
769
 
770
  if self.cell_states_to_model is None:
 
792
  state_embs_dict,
793
  output_directory,
794
  output_prefix)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795
 
796
  # determine effect of perturbation on other genes
797
  def in_silico_perturb(self,