Christina Theodoris
commited on
Commit
•
2f25aea
1
Parent(s):
fd93ebf
Add functions for extracting gene embeddings, move state_embs_dict outside isp, fix bugs in isp
Browse files- examples/in_silico_perturbation.ipynb +50 -11
- geneformer/emb_extractor.py +477 -190
- geneformer/in_silico_perturber.py +631 -1035
- geneformer/in_silico_perturber_stats.py +631 -313
- geneformer/perturber_utils.py +698 -0
- setup.py +1 -1
examples/in_silico_perturbation.ipynb
CHANGED
@@ -8,21 +8,62 @@
|
|
8 |
"outputs": [],
|
9 |
"source": [
|
10 |
"from geneformer import InSilicoPerturber\n",
|
11 |
-
"from geneformer import InSilicoPerturberStats"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
]
|
13 |
},
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": null,
|
17 |
-
"id": "
|
18 |
"metadata": {
|
19 |
"tags": []
|
20 |
},
|
21 |
"outputs": [],
|
22 |
"source": [
|
23 |
-
"# in silico perturbation in deletion mode to determine genes whose \n",
|
24 |
-
"# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
|
25 |
-
"# the embedding towards non-failing (nf) state\n",
|
26 |
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
27 |
" perturb_rank_shift=None,\n",
|
28 |
" genes_to_perturb=\"all\",\n",
|
@@ -32,11 +73,9 @@
|
|
32 |
" num_classes=3,\n",
|
33 |
" emb_mode=\"cell\",\n",
|
34 |
" cell_emb_style=\"mean_pool\",\n",
|
35 |
-
" filter_data=
|
36 |
-
" cell_states_to_model=
|
37 |
-
"
|
38 |
-
" 'goal_state': 'nf', \n",
|
39 |
-
" 'alt_states': ['hcm']},\n",
|
40 |
" max_ncells=2000,\n",
|
41 |
" emb_layer=0,\n",
|
42 |
" forward_batch_size=400,\n",
|
@@ -68,7 +107,7 @@
|
|
68 |
" genes_perturbed=\"all\",\n",
|
69 |
" combos=0,\n",
|
70 |
" anchor_gene=None,\n",
|
71 |
-
" cell_states_to_model=
|
72 |
]
|
73 |
},
|
74 |
{
|
|
|
8 |
"outputs": [],
|
9 |
"source": [
|
10 |
"from geneformer import InSilicoPerturber\n",
|
11 |
+
"from geneformer import InSilicoPerturberStats\n",
|
12 |
+
"from geneformer import EmbExtractor"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "markdown",
|
17 |
+
"id": "cbd6851c-060e-4967-b816-e605ffe58b23",
|
18 |
+
"metadata": {
|
19 |
+
"tags": []
|
20 |
+
},
|
21 |
+
"source": [
|
22 |
+
"### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"id": "c53e98cd-c603-4878-82ba-db471181bb55",
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"# first obtain start, goal, and alt embedding positions\n",
|
33 |
+
"# this function was changed to be separate from perturb_data\n",
|
34 |
+
"# to avoid repeating calcuations when parallelizing perturb_data\n",
|
35 |
+
"cell_states_to_model={\"state_key\": \"disease\", \n",
|
36 |
+
" \"start_state\": \"dcm\", \n",
|
37 |
+
" \"goal_state\": \"nf\", \n",
|
38 |
+
" \"alt_states\": [\"hcm\"]}\n",
|
39 |
+
"\n",
|
40 |
+
"filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
|
41 |
+
"\n",
|
42 |
+
"embex = EmbExtractor(model_type=\"CellClassifier\",\n",
|
43 |
+
" num_classes=3,\n",
|
44 |
+
" filter_data=filter_data_dict,\n",
|
45 |
+
" max_ncells=1000,\n",
|
46 |
+
" emb_layer=0,\n",
|
47 |
+
" summary_stat=\"exact_mean\",\n",
|
48 |
+
" forward_batch_size=256,\n",
|
49 |
+
" nproc=16)\n",
|
50 |
+
"\n",
|
51 |
+
"state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
|
52 |
+
" \"path/to/model\",\n",
|
53 |
+
" \"path/to/input_data\",\n",
|
54 |
+
" \"path/to/output_directory\",\n",
|
55 |
+
" \"output_prefix\")"
|
56 |
]
|
57 |
},
|
58 |
{
|
59 |
"cell_type": "code",
|
60 |
"execution_count": null,
|
61 |
+
"id": "981e1190-62da-4543-b7d3-6e2a2d6a6d56",
|
62 |
"metadata": {
|
63 |
"tags": []
|
64 |
},
|
65 |
"outputs": [],
|
66 |
"source": [
|
|
|
|
|
|
|
67 |
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
68 |
" perturb_rank_shift=None,\n",
|
69 |
" genes_to_perturb=\"all\",\n",
|
|
|
73 |
" num_classes=3,\n",
|
74 |
" emb_mode=\"cell\",\n",
|
75 |
" cell_emb_style=\"mean_pool\",\n",
|
76 |
+
" filter_data=filter_data_dict,\n",
|
77 |
+
" cell_states_to_model=cell_states_to_model,\n",
|
78 |
+
" state_embs_dict=state_embs_dict,\n",
|
|
|
|
|
79 |
" max_ncells=2000,\n",
|
80 |
" emb_layer=0,\n",
|
81 |
" forward_batch_size=400,\n",
|
|
|
107 |
" genes_perturbed=\"all\",\n",
|
108 |
" combos=0,\n",
|
109 |
" anchor_gene=None,\n",
|
110 |
+
" cell_states_to_model=cell_states_to_model)"
|
111 |
]
|
112 |
},
|
113 |
{
|
geneformer/emb_extractor.py
CHANGED
@@ -7,66 +7,62 @@ Usage:
|
|
7 |
num_classes=3,
|
8 |
emb_mode="cell",
|
9 |
cell_emb_style="mean_pool",
|
|
|
10 |
filter_data={"cell_type":["cardiomyocyte"]},
|
11 |
max_ncells=1000,
|
12 |
max_ncells_to_plot=1000,
|
13 |
emb_layer=-1,
|
14 |
emb_label=["disease","cell_type"],
|
15 |
labels_to_plot=["disease","cell_type"],
|
16 |
-
forward_batch_size=100,
|
17 |
nproc=16,
|
18 |
summary_stat=None)
|
19 |
embs = embex.extract_embs("path/to/model",
|
20 |
"path/to/input_data",
|
21 |
"path/to/output_directory",
|
22 |
"output_prefix")
|
23 |
-
embex.plot_embs(embs=embs,
|
24 |
plot_style="heatmap",
|
25 |
output_directory="path/to/output_directory",
|
26 |
output_prefix="output_prefix")
|
27 |
-
|
28 |
"""
|
29 |
|
30 |
# imports
|
31 |
import logging
|
|
|
|
|
|
|
|
|
32 |
import anndata
|
33 |
import matplotlib.pyplot as plt
|
34 |
import numpy as np
|
35 |
import pandas as pd
|
36 |
-
import pickle
|
37 |
-
from tdigest import TDigest
|
38 |
import scanpy as sc
|
39 |
import seaborn as sns
|
40 |
import torch
|
41 |
-
from
|
42 |
-
from pathlib import Path
|
43 |
from tqdm.auto import trange
|
44 |
-
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
45 |
|
|
|
46 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
47 |
|
48 |
-
from .in_silico_perturber import downsample_and_sort, \
|
49 |
-
gen_attention_mask, \
|
50 |
-
get_model_input_size, \
|
51 |
-
load_and_filter, \
|
52 |
-
load_model, \
|
53 |
-
mean_nonpadding_embs, \
|
54 |
-
pad_tensor_list, \
|
55 |
-
quant_layers
|
56 |
-
|
57 |
logger = logging.getLogger(__name__)
|
58 |
|
|
|
59 |
# extract embeddings
|
60 |
-
def get_embs(
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
total_batch_length = len(filtered_input_data)
|
69 |
-
|
70 |
if summary_stat is None:
|
71 |
embs_list = []
|
72 |
elif summary_stat is not None:
|
@@ -74,69 +70,173 @@ def get_embs(model,
|
|
74 |
example = filtered_input_data.select([i for i in range(1)])
|
75 |
example.set_format(type="torch")
|
76 |
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
for i in trange(0, total_batch_length, forward_batch_size):
|
81 |
-
max_range = min(i+forward_batch_size, total_batch_length)
|
82 |
|
83 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
84 |
-
|
85 |
-
|
|
|
86 |
minibatch.set_format(type="torch")
|
87 |
|
88 |
input_data_minibatch = minibatch["input_ids"]
|
89 |
-
input_data_minibatch = pad_tensor_list(
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
with torch.no_grad():
|
95 |
outputs = model(
|
96 |
-
input_ids
|
97 |
-
attention_mask
|
98 |
)
|
99 |
|
100 |
embs_i = outputs.hidden_states[layer_to_quant]
|
101 |
-
|
102 |
if emb_mode == "cell":
|
103 |
-
mean_embs = mean_nonpadding_embs(embs_i, original_lens)
|
104 |
if summary_stat is None:
|
105 |
-
embs_list
|
106 |
elif summary_stat is not None:
|
107 |
# update tdigests with current batch for each emb dim
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
del outputs
|
112 |
del minibatch
|
113 |
del input_data_minibatch
|
114 |
del embs_i
|
115 |
-
|
116 |
-
torch.cuda.empty_cache()
|
117 |
-
|
118 |
if summary_stat is None:
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
# calculate summary stat embs from approximated tdigests
|
121 |
elif summary_stat is not None:
|
122 |
-
if
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
return embs_stack
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
def test_emb(model, example, layer_to_quant):
|
131 |
with torch.no_grad():
|
132 |
-
outputs = model(
|
133 |
-
input_ids = example.to("cuda")
|
134 |
-
)
|
135 |
|
136 |
embs_test = outputs.hidden_states[layer_to_quant]
|
137 |
return embs_test.size()[2]
|
138 |
|
139 |
-
|
|
|
140 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
141 |
if emb_labels is not None:
|
142 |
for label in emb_labels:
|
@@ -144,94 +244,145 @@ def label_embs(embs, downsampled_data, emb_labels):
|
|
144 |
embs_df[label] = emb_label
|
145 |
return embs_df
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
148 |
-
only_embs_df = embs_df.iloc[
|
149 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
150 |
-
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
|
|
|
|
151 |
vars_dict = {"embs": only_embs_df.columns}
|
152 |
-
obs_dict = {"cell_id": list(only_embs_df.index),
|
153 |
-
f"{label}": list(embs_df[label])}
|
154 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
155 |
-
sc.tl.pca(adata, svd_solver=
|
156 |
sc.pp.neighbors(adata)
|
157 |
sc.tl.umap(adata)
|
158 |
-
sns.set(rc={
|
159 |
sns.set_style("white")
|
160 |
-
default_kwargs_dict = {"palette":"Set2", "size":200}
|
161 |
if kwargs_dict is not None:
|
162 |
default_kwargs_dict.update(kwargs_dict)
|
163 |
-
|
164 |
sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
165 |
|
|
|
166 |
def gen_heatmap_class_colors(labels, df):
|
167 |
-
pal = sns.cubehelix_palette(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
169 |
colors = pd.Series(labels, index=df.index).map(lut)
|
170 |
return colors
|
171 |
-
|
|
|
172 |
def gen_heatmap_class_dict(classes, label_colors_series):
|
173 |
-
class_color_dict_df = pd.DataFrame(
|
|
|
|
|
174 |
class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
|
175 |
-
return dict(zip(class_color_dict_df["classes"],class_color_dict_df["color"]))
|
176 |
-
|
177 |
-
def make_colorbar(embs_df, label):
|
178 |
|
|
|
|
|
179 |
labels = list(embs_df[label])
|
180 |
-
|
181 |
cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
|
182 |
label_colors = pd.DataFrame(cell_type_colors, columns=[label])
|
183 |
|
184 |
-
for i,row in label_colors.iterrows():
|
185 |
-
colors=row[0]
|
186 |
-
if len(colors)!=3 or any(np.isnan(colors)):
|
187 |
-
print(i,colors)
|
188 |
|
189 |
label_colors.isna().sum()
|
190 |
-
|
191 |
# create dictionary for colors and classes
|
192 |
label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
|
193 |
return label_colors, label_color_dict
|
194 |
-
|
|
|
195 |
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
196 |
sns.set_style("white")
|
197 |
sns.set(font_scale=2)
|
198 |
plt.figure(figsize=(15, 15), dpi=150)
|
199 |
label_colors, label_color_dict = make_colorbar(embs_df, label)
|
200 |
-
|
201 |
-
default_kwargs_dict = {
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
212 |
if kwargs_dict is not None:
|
213 |
default_kwargs_dict.update(kwargs_dict)
|
214 |
-
g = sns.clustermap(
|
|
|
|
|
215 |
|
216 |
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
217 |
|
218 |
for label_color in list(label_color_dict.keys()):
|
219 |
-
g.ax_col_dendrogram.bar(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
-
|
222 |
-
loc="lower center",
|
223 |
-
ncol=4,
|
224 |
-
bbox_to_anchor=(0.5, 1),
|
225 |
-
facecolor="white")
|
226 |
|
227 |
-
plt.savefig(output_file, bbox_inches='tight')
|
228 |
|
229 |
class EmbExtractor:
|
230 |
valid_option_dict = {
|
231 |
-
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
232 |
"num_classes": {int},
|
233 |
-
"emb_mode": {"cell","gene"},
|
234 |
"cell_emb_style": {"mean_pool"},
|
|
|
235 |
"filter_data": {None, dict},
|
236 |
"max_ncells": {None, int},
|
237 |
"emb_layer": {-1, 0},
|
@@ -239,14 +390,16 @@ class EmbExtractor:
|
|
239 |
"labels_to_plot": {None, list},
|
240 |
"forward_batch_size": {int},
|
241 |
"nproc": {int},
|
242 |
-
"summary_stat": {None, "mean", "median"},
|
243 |
}
|
|
|
244 |
def __init__(
|
245 |
self,
|
246 |
model_type="Pretrained",
|
247 |
num_classes=0,
|
248 |
emb_mode="cell",
|
249 |
cell_emb_style="mean_pool",
|
|
|
250 |
filter_data=None,
|
251 |
max_ncells=1000,
|
252 |
emb_layer=-1,
|
@@ -272,6 +425,9 @@ class EmbExtractor:
|
|
272 |
cell_emb_style : "mean_pool"
|
273 |
Method for summarizing cell embeddings.
|
274 |
Currently only option is mean pooling of gene embeddings for given cell.
|
|
|
|
|
|
|
275 |
filter_data : None, dict
|
276 |
Default is to extract embeddings from all input data.
|
277 |
Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
@@ -296,10 +452,11 @@ class EmbExtractor:
|
|
296 |
Batch size for forward pass.
|
297 |
nproc : int
|
298 |
Number of CPU processes to use.
|
299 |
-
summary_stat : {None, "mean", "median"}
|
300 |
-
If
|
301 |
-
|
302 |
-
|
|
|
303 |
token_dictionary_file : Path
|
304 |
Path to pickle file containing token dictionary (Ensembl ID:token).
|
305 |
"""
|
@@ -308,6 +465,7 @@ class EmbExtractor:
|
|
308 |
self.num_classes = num_classes
|
309 |
self.emb_mode = emb_mode
|
310 |
self.cell_emb_style = cell_emb_style
|
|
|
311 |
self.filter_data = filter_data
|
312 |
self.max_ncells = max_ncells
|
313 |
self.emb_layer = emb_layer
|
@@ -315,7 +473,12 @@ class EmbExtractor:
|
|
315 |
self.labels_to_plot = labels_to_plot
|
316 |
self.forward_batch_size = forward_batch_size
|
317 |
self.nproc = nproc
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
self.validate_options()
|
321 |
|
@@ -323,51 +486,49 @@ class EmbExtractor:
|
|
323 |
with open(token_dictionary_file, "rb") as f:
|
324 |
self.gene_token_dict = pickle.load(f)
|
325 |
|
|
|
326 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
327 |
-
|
328 |
-
|
329 |
def validate_options(self):
|
330 |
-
# first disallow options under development
|
331 |
-
if self.emb_mode == "gene":
|
332 |
-
logger.error(
|
333 |
-
"Extraction and plotting of gene-level embeddings currently under development. " \
|
334 |
-
"Current valid option for 'emb_mode': 'cell'"
|
335 |
-
)
|
336 |
-
raise
|
337 |
-
|
338 |
# confirm arguments are within valid options and compatible with each other
|
339 |
-
for attr_name,valid_options in self.valid_option_dict.items():
|
340 |
attr_value = self.__dict__[attr_name]
|
341 |
-
if
|
342 |
if attr_value in valid_options:
|
343 |
continue
|
344 |
valid_type = False
|
345 |
for option in valid_options:
|
346 |
-
if (option in [int,list,dict]) and isinstance(
|
|
|
|
|
347 |
valid_type = True
|
348 |
break
|
349 |
if valid_type:
|
350 |
continue
|
351 |
logger.error(
|
352 |
-
f"Invalid option for {attr_name}. "
|
353 |
f"Valid options for {attr_name}: {valid_options}"
|
354 |
)
|
355 |
raise
|
356 |
-
|
357 |
if self.filter_data is not None:
|
358 |
-
for key,value in self.filter_data.items():
|
359 |
-
if
|
360 |
self.filter_data[key] = [value]
|
361 |
logger.warning(
|
362 |
-
"Values in filter_data dict must be lists. "
|
363 |
-
f"Changing {key} value to list ([{value}])."
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
|
|
|
|
|
|
|
|
371 |
"""
|
372 |
Extract embeddings from input data and save as results in output_directory.
|
373 |
|
@@ -384,42 +545,165 @@ class EmbExtractor:
|
|
384 |
output_torch_embs : bool
|
385 |
Whether or not to also output the embeddings as a tensor.
|
386 |
Note, if true, will output embeddings as both dataframe and tensor.
|
|
|
|
|
387 |
"""
|
388 |
|
389 |
-
filtered_input_data = load_and_filter(
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
|
406 |
# save embeddings to output_path
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
else:
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
"""
|
424 |
Plot embeddings, coloring by provided labels.
|
425 |
|
@@ -440,60 +724,63 @@ class EmbExtractor:
|
|
440 |
kwargs_dict : dict
|
441 |
Dictionary of kwargs to pass to plotting function.
|
442 |
"""
|
443 |
-
|
444 |
-
if plot_style not in ["heatmap","umap"]:
|
445 |
logger.error(
|
446 |
-
"Invalid option for 'plot_style'. "
|
447 |
-
"Valid options: {'heatmap','umap'}"
|
448 |
)
|
449 |
raise
|
450 |
-
|
451 |
if (plot_style == "umap") and (self.labels_to_plot is None):
|
452 |
-
logger.error(
|
453 |
-
"Plotting UMAP requires 'labels_to_plot'. "
|
454 |
-
)
|
455 |
raise
|
456 |
-
|
457 |
if max_ncells_to_plot > self.max_ncells:
|
458 |
max_ncells_to_plot = self.max_ncells
|
459 |
logger.warning(
|
460 |
-
"max_ncells_to_plot must be <= max_ncells. "
|
461 |
-
f"Changing max_ncells_to_plot to {self.max_ncells}."
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
embs = embs.sample(max_ncells_to_plot, axis=0)
|
466 |
-
|
467 |
if self.emb_label is None:
|
468 |
label_len = 0
|
469 |
else:
|
470 |
label_len = len(self.emb_label)
|
471 |
-
|
472 |
emb_dims = embs.shape[1] - label_len
|
473 |
-
|
474 |
if self.emb_label is None:
|
475 |
emb_labels = None
|
476 |
else:
|
477 |
emb_labels = embs.columns[emb_dims:]
|
478 |
-
|
479 |
if plot_style == "umap":
|
480 |
for label in self.labels_to_plot:
|
481 |
if label not in emb_labels:
|
482 |
logger.warning(
|
483 |
-
f"Label {label} from labels_to_plot "
|
484 |
-
f"not present in provided embeddings dataframe."
|
|
|
485 |
continue
|
486 |
output_prefix_label = "_" + output_prefix + f"_umap_{label}"
|
487 |
-
output_file = (
|
|
|
|
|
488 |
plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
|
489 |
-
|
490 |
if plot_style == "heatmap":
|
491 |
for label in self.labels_to_plot:
|
492 |
if label not in emb_labels:
|
493 |
logger.warning(
|
494 |
-
f"Label {label} from labels_to_plot "
|
495 |
-
f"not present in provided embeddings dataframe."
|
|
|
496 |
continue
|
497 |
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
498 |
-
output_file = (
|
499 |
-
|
|
|
|
|
|
7 |
num_classes=3,
|
8 |
emb_mode="cell",
|
9 |
cell_emb_style="mean_pool",
|
10 |
+
gene_emb_style="mean_pool",
|
11 |
filter_data={"cell_type":["cardiomyocyte"]},
|
12 |
max_ncells=1000,
|
13 |
max_ncells_to_plot=1000,
|
14 |
emb_layer=-1,
|
15 |
emb_label=["disease","cell_type"],
|
16 |
labels_to_plot=["disease","cell_type"],
|
|
|
17 |
nproc=16,
|
18 |
summary_stat=None)
|
19 |
embs = embex.extract_embs("path/to/model",
|
20 |
"path/to/input_data",
|
21 |
"path/to/output_directory",
|
22 |
"output_prefix")
|
23 |
+
embex.plot_embs(embs=embs,
|
24 |
plot_style="heatmap",
|
25 |
output_directory="path/to/output_directory",
|
26 |
output_prefix="output_prefix")
|
27 |
+
|
28 |
"""
|
29 |
|
30 |
# imports
|
31 |
import logging
|
32 |
+
import pickle
|
33 |
+
from collections import Counter
|
34 |
+
from pathlib import Path
|
35 |
+
|
36 |
import anndata
|
37 |
import matplotlib.pyplot as plt
|
38 |
import numpy as np
|
39 |
import pandas as pd
|
|
|
|
|
40 |
import scanpy as sc
|
41 |
import seaborn as sns
|
42 |
import torch
|
43 |
+
from tdigest import TDigest
|
|
|
44 |
from tqdm.auto import trange
|
|
|
45 |
|
46 |
+
from . import perturber_utils as pu
|
47 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
logger = logging.getLogger(__name__)
|
50 |
|
51 |
+
|
52 |
# extract embeddings
|
53 |
+
def get_embs(
|
54 |
+
model,
|
55 |
+
filtered_input_data,
|
56 |
+
emb_mode,
|
57 |
+
layer_to_quant,
|
58 |
+
pad_token_id,
|
59 |
+
forward_batch_size,
|
60 |
+
summary_stat=None,
|
61 |
+
silent=False,
|
62 |
+
):
|
63 |
+
model_input_size = pu.get_model_input_size(model)
|
64 |
total_batch_length = len(filtered_input_data)
|
65 |
+
|
66 |
if summary_stat is None:
|
67 |
embs_list = []
|
68 |
elif summary_stat is not None:
|
|
|
70 |
example = filtered_input_data.select([i for i in range(1)])
|
71 |
example.set_format(type="torch")
|
72 |
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
73 |
+
if emb_mode == "cell":
|
74 |
+
# initiate tdigests for # of emb dims
|
75 |
+
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
76 |
+
if emb_mode == "gene":
|
77 |
+
gene_set = list(
|
78 |
+
{
|
79 |
+
element
|
80 |
+
for sublist in filtered_input_data["input_ids"]
|
81 |
+
for element in sublist
|
82 |
+
}
|
83 |
+
)
|
84 |
+
# initiate dict with genes as keys and tdigests for # of emb dims as values
|
85 |
+
embs_tdigests_dict = {
|
86 |
+
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
87 |
+
}
|
88 |
+
|
89 |
+
overall_max_len = 0
|
90 |
|
91 |
+
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
92 |
+
max_range = min(i + forward_batch_size, total_batch_length)
|
93 |
|
94 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
95 |
+
|
96 |
+
max_len = int(max(minibatch["length"]))
|
97 |
+
original_lens = torch.tensor(minibatch["length"], device="cuda")
|
98 |
minibatch.set_format(type="torch")
|
99 |
|
100 |
input_data_minibatch = minibatch["input_ids"]
|
101 |
+
input_data_minibatch = pu.pad_tensor_list(
|
102 |
+
input_data_minibatch, max_len, pad_token_id, model_input_size
|
103 |
+
)
|
104 |
+
|
|
|
105 |
with torch.no_grad():
|
106 |
outputs = model(
|
107 |
+
input_ids=input_data_minibatch.to("cuda"),
|
108 |
+
attention_mask=pu.gen_attention_mask(minibatch),
|
109 |
)
|
110 |
|
111 |
embs_i = outputs.hidden_states[layer_to_quant]
|
112 |
+
|
113 |
if emb_mode == "cell":
|
114 |
+
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
115 |
if summary_stat is None:
|
116 |
+
embs_list.append(mean_embs)
|
117 |
elif summary_stat is not None:
|
118 |
# update tdigests with current batch for each emb dim
|
119 |
+
accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
|
120 |
+
del mean_embs
|
121 |
+
elif emb_mode == "gene":
|
122 |
+
if summary_stat is None:
|
123 |
+
embs_list.append(embs_i)
|
124 |
+
elif summary_stat is not None:
|
125 |
+
for h in trange(len(minibatch)):
|
126 |
+
length_h = minibatch[h]["length"]
|
127 |
+
input_ids_h = minibatch[h]["input_ids"][0:length_h]
|
128 |
+
|
129 |
+
# double check dimensions before unsqueezing
|
130 |
+
embs_i_dim = embs_i.dim()
|
131 |
+
if embs_i_dim != 3:
|
132 |
+
logger.error(
|
133 |
+
f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
|
134 |
+
)
|
135 |
+
raise
|
136 |
+
|
137 |
+
embs_h = embs_i[h, :, :].unsqueeze(dim=1)
|
138 |
+
dict_h = dict(zip(input_ids_h, embs_h))
|
139 |
+
for k in dict_h.keys():
|
140 |
+
accumulate_tdigests(
|
141 |
+
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
142 |
+
)
|
143 |
+
|
144 |
+
overall_max_len = max(overall_max_len, max_len)
|
145 |
del outputs
|
146 |
del minibatch
|
147 |
del input_data_minibatch
|
148 |
del embs_i
|
149 |
+
|
150 |
+
torch.cuda.empty_cache()
|
151 |
+
|
152 |
if summary_stat is None:
|
153 |
+
if emb_mode == "cell":
|
154 |
+
embs_stack = torch.cat(embs_list, dim=0)
|
155 |
+
elif emb_mode == "gene":
|
156 |
+
embs_stack = pu.pad_tensor_list(
|
157 |
+
embs_list,
|
158 |
+
overall_max_len,
|
159 |
+
pad_token_id,
|
160 |
+
model_input_size,
|
161 |
+
1,
|
162 |
+
pu.pad_3d_tensor,
|
163 |
+
)
|
164 |
+
|
165 |
# calculate summary stat embs from approximated tdigests
|
166 |
elif summary_stat is not None:
|
167 |
+
if emb_mode == "cell":
|
168 |
+
if summary_stat == "mean":
|
169 |
+
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
170 |
+
elif summary_stat == "median":
|
171 |
+
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
172 |
+
embs_stack = torch.tensor(summary_emb_list)
|
173 |
+
elif emb_mode == "gene":
|
174 |
+
if summary_stat == "mean":
|
175 |
+
[
|
176 |
+
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
177 |
+
for gene in embs_tdigests_dict.keys()
|
178 |
+
]
|
179 |
+
elif summary_stat == "median":
|
180 |
+
[
|
181 |
+
update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
|
182 |
+
for gene in embs_tdigests_dict.keys()
|
183 |
+
]
|
184 |
+
return embs_tdigests_dict
|
185 |
|
186 |
return embs_stack
|
187 |
|
188 |
+
|
189 |
+
def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
|
190 |
+
# note: tdigest batch update known to be slow so updating serially
|
191 |
+
[
|
192 |
+
embs_tdigests[j].update(mean_embs[i, j].item())
|
193 |
+
for i in range(mean_embs.size(0))
|
194 |
+
for j in range(emb_dims)
|
195 |
+
]
|
196 |
+
|
197 |
+
|
198 |
+
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
|
199 |
+
embs_tdigests_dict[gene] = accumulate_tdigests(
|
200 |
+
embs_tdigests_dict[gene], gene_embs, emb_dims
|
201 |
+
)
|
202 |
+
|
203 |
+
|
204 |
+
def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
|
205 |
+
embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
|
206 |
+
|
207 |
+
|
208 |
+
def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
|
209 |
+
embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
|
210 |
+
|
211 |
+
|
212 |
+
def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
|
213 |
+
length_h = minibatch[h]["length"]
|
214 |
+
input_ids_h = minibatch[h]["input_ids"][0:length_h]
|
215 |
+
embs_h = embs_i[h, :, :].unsqueeze(dim=1)
|
216 |
+
dict_h = dict(zip(input_ids_h, embs_h))
|
217 |
+
[
|
218 |
+
update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
|
219 |
+
for k in dict_h.keys()
|
220 |
+
]
|
221 |
+
|
222 |
+
|
223 |
+
def tdigest_mean(embs_tdigests, emb_dims):
|
224 |
+
return [embs_tdigests[i].trimmed_mean(0, 100) for i in range(emb_dims)]
|
225 |
+
|
226 |
+
|
227 |
+
def tdigest_median(embs_tdigests, emb_dims):
|
228 |
+
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
229 |
+
|
230 |
+
|
231 |
def test_emb(model, example, layer_to_quant):
|
232 |
with torch.no_grad():
|
233 |
+
outputs = model(input_ids=example.to("cuda"))
|
|
|
|
|
234 |
|
235 |
embs_test = outputs.hidden_states[layer_to_quant]
|
236 |
return embs_test.size()[2]
|
237 |
|
238 |
+
|
239 |
+
def label_cell_embs(embs, downsampled_data, emb_labels):
|
240 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
241 |
if emb_labels is not None:
|
242 |
for label in emb_labels:
|
|
|
244 |
embs_df[label] = emb_label
|
245 |
return embs_df
|
246 |
|
247 |
+
|
248 |
+
def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
249 |
+
gene_set = {
|
250 |
+
element for sublist in downsampled_data["input_ids"] for element in sublist
|
251 |
+
}
|
252 |
+
gene_emb_dict = {k: [] for k in gene_set}
|
253 |
+
for i in range(embs.size()[0]):
|
254 |
+
length = downsampled_data[i]["length"]
|
255 |
+
dict_i = dict(
|
256 |
+
zip(
|
257 |
+
downsampled_data[i]["input_ids"][0:length],
|
258 |
+
embs[i, :, :].unsqueeze(dim=1),
|
259 |
+
)
|
260 |
+
)
|
261 |
+
for k in dict_i.keys():
|
262 |
+
gene_emb_dict[k].append(dict_i[k])
|
263 |
+
for k in gene_emb_dict.keys():
|
264 |
+
gene_emb_dict[k] = (
|
265 |
+
torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
|
266 |
+
.cpu()
|
267 |
+
.numpy()
|
268 |
+
)
|
269 |
+
embs_df = pd.DataFrame(gene_emb_dict).T
|
270 |
+
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
271 |
+
return embs_df
|
272 |
+
|
273 |
+
|
274 |
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
275 |
+
only_embs_df = embs_df.iloc[:, :emb_dims]
|
276 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
277 |
+
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
278 |
+
str
|
279 |
+
)
|
280 |
vars_dict = {"embs": only_embs_df.columns}
|
281 |
+
obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
|
|
|
282 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
283 |
+
sc.tl.pca(adata, svd_solver="arpack")
|
284 |
sc.pp.neighbors(adata)
|
285 |
sc.tl.umap(adata)
|
286 |
+
sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
|
287 |
sns.set_style("white")
|
288 |
+
default_kwargs_dict = {"palette": "Set2", "size": 200}
|
289 |
if kwargs_dict is not None:
|
290 |
default_kwargs_dict.update(kwargs_dict)
|
291 |
+
|
292 |
sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
293 |
|
294 |
+
|
295 |
def gen_heatmap_class_colors(labels, df):
|
296 |
+
pal = sns.cubehelix_palette(
|
297 |
+
len(Counter(labels).keys()),
|
298 |
+
light=0.9,
|
299 |
+
dark=0.1,
|
300 |
+
hue=1,
|
301 |
+
reverse=True,
|
302 |
+
start=1,
|
303 |
+
rot=-2,
|
304 |
+
)
|
305 |
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
306 |
colors = pd.Series(labels, index=df.index).map(lut)
|
307 |
return colors
|
308 |
+
|
309 |
+
|
310 |
def gen_heatmap_class_dict(classes, label_colors_series):
|
311 |
+
class_color_dict_df = pd.DataFrame(
|
312 |
+
{"classes": classes, "color": label_colors_series}
|
313 |
+
)
|
314 |
class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
|
315 |
+
return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
|
|
|
|
|
316 |
|
317 |
+
|
318 |
+
def make_colorbar(embs_df, label):
|
319 |
labels = list(embs_df[label])
|
320 |
+
|
321 |
cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
|
322 |
label_colors = pd.DataFrame(cell_type_colors, columns=[label])
|
323 |
|
324 |
+
for i, row in label_colors.iterrows():
|
325 |
+
colors = row[0]
|
326 |
+
if len(colors) != 3 or any(np.isnan(colors)):
|
327 |
+
print(i, colors)
|
328 |
|
329 |
label_colors.isna().sum()
|
330 |
+
|
331 |
# create dictionary for colors and classes
|
332 |
label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
|
333 |
return label_colors, label_color_dict
|
334 |
+
|
335 |
+
|
336 |
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
337 |
sns.set_style("white")
|
338 |
sns.set(font_scale=2)
|
339 |
plt.figure(figsize=(15, 15), dpi=150)
|
340 |
label_colors, label_color_dict = make_colorbar(embs_df, label)
|
341 |
+
|
342 |
+
default_kwargs_dict = {
|
343 |
+
"row_cluster": True,
|
344 |
+
"col_cluster": True,
|
345 |
+
"row_colors": label_colors,
|
346 |
+
"standard_scale": 1,
|
347 |
+
"linewidths": 0,
|
348 |
+
"xticklabels": False,
|
349 |
+
"yticklabels": False,
|
350 |
+
"figsize": (15, 15),
|
351 |
+
"center": 0,
|
352 |
+
"cmap": "magma",
|
353 |
+
}
|
354 |
+
|
355 |
if kwargs_dict is not None:
|
356 |
default_kwargs_dict.update(kwargs_dict)
|
357 |
+
g = sns.clustermap(
|
358 |
+
embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
|
359 |
+
)
|
360 |
|
361 |
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
362 |
|
363 |
for label_color in list(label_color_dict.keys()):
|
364 |
+
g.ax_col_dendrogram.bar(
|
365 |
+
0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
|
366 |
+
)
|
367 |
+
|
368 |
+
g.ax_col_dendrogram.legend(
|
369 |
+
title=f"{label}",
|
370 |
+
loc="lower center",
|
371 |
+
ncol=4,
|
372 |
+
bbox_to_anchor=(0.5, 1),
|
373 |
+
facecolor="white",
|
374 |
+
)
|
375 |
|
376 |
+
plt.savefig(output_file, bbox_inches="tight")
|
|
|
|
|
|
|
|
|
377 |
|
|
|
378 |
|
379 |
class EmbExtractor:
|
380 |
valid_option_dict = {
|
381 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
382 |
"num_classes": {int},
|
383 |
+
"emb_mode": {"cell", "gene"},
|
384 |
"cell_emb_style": {"mean_pool"},
|
385 |
+
"gene_emb_style": {"mean_pool"},
|
386 |
"filter_data": {None, dict},
|
387 |
"max_ncells": {None, int},
|
388 |
"emb_layer": {-1, 0},
|
|
|
390 |
"labels_to_plot": {None, list},
|
391 |
"forward_batch_size": {int},
|
392 |
"nproc": {int},
|
393 |
+
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
394 |
}
|
395 |
+
|
396 |
def __init__(
|
397 |
self,
|
398 |
model_type="Pretrained",
|
399 |
num_classes=0,
|
400 |
emb_mode="cell",
|
401 |
cell_emb_style="mean_pool",
|
402 |
+
gene_emb_style="mean_pool",
|
403 |
filter_data=None,
|
404 |
max_ncells=1000,
|
405 |
emb_layer=-1,
|
|
|
425 |
cell_emb_style : "mean_pool"
|
426 |
Method for summarizing cell embeddings.
|
427 |
Currently only option is mean pooling of gene embeddings for given cell.
|
428 |
+
gene_emb_style : "mean_pool"
|
429 |
+
Method for summarizing gene embeddings.
|
430 |
+
Currently only option is mean pooling of contextual gene embeddings for given gene.
|
431 |
filter_data : None, dict
|
432 |
Default is to extract embeddings from all input data.
|
433 |
Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
|
|
452 |
Batch size for forward pass.
|
453 |
nproc : int
|
454 |
Number of CPU processes to use.
|
455 |
+
summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
|
456 |
+
If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
|
457 |
+
If mean or median, outputs only approximated mean or median embedding of input data.
|
458 |
+
Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
459 |
+
Non-exact is slower but more memory-efficient.
|
460 |
token_dictionary_file : Path
|
461 |
Path to pickle file containing token dictionary (Ensembl ID:token).
|
462 |
"""
|
|
|
465 |
self.num_classes = num_classes
|
466 |
self.emb_mode = emb_mode
|
467 |
self.cell_emb_style = cell_emb_style
|
468 |
+
self.gene_emb_style = gene_emb_style
|
469 |
self.filter_data = filter_data
|
470 |
self.max_ncells = max_ncells
|
471 |
self.emb_layer = emb_layer
|
|
|
473 |
self.labels_to_plot = labels_to_plot
|
474 |
self.forward_batch_size = forward_batch_size
|
475 |
self.nproc = nproc
|
476 |
+
if (summary_stat is not None) and ("exact" in summary_stat):
|
477 |
+
self.summary_stat = None
|
478 |
+
self.exact_summary_stat = summary_stat
|
479 |
+
else:
|
480 |
+
self.summary_stat = summary_stat
|
481 |
+
self.exact_summary_stat = None
|
482 |
|
483 |
self.validate_options()
|
484 |
|
|
|
486 |
with open(token_dictionary_file, "rb") as f:
|
487 |
self.gene_token_dict = pickle.load(f)
|
488 |
|
489 |
+
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
490 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
491 |
+
|
|
|
492 |
def validate_options(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
# confirm arguments are within valid options and compatible with each other
|
494 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
495 |
attr_value = self.__dict__[attr_name]
|
496 |
+
if not isinstance(attr_value, (list, dict)):
|
497 |
if attr_value in valid_options:
|
498 |
continue
|
499 |
valid_type = False
|
500 |
for option in valid_options:
|
501 |
+
if (option in [int, list, dict, bool]) and isinstance(
|
502 |
+
attr_value, option
|
503 |
+
):
|
504 |
valid_type = True
|
505 |
break
|
506 |
if valid_type:
|
507 |
continue
|
508 |
logger.error(
|
509 |
+
f"Invalid option for {attr_name}. "
|
510 |
f"Valid options for {attr_name}: {valid_options}"
|
511 |
)
|
512 |
raise
|
513 |
+
|
514 |
if self.filter_data is not None:
|
515 |
+
for key, value in self.filter_data.items():
|
516 |
+
if not isinstance(value, list):
|
517 |
self.filter_data[key] = [value]
|
518 |
logger.warning(
|
519 |
+
"Values in filter_data dict must be lists. "
|
520 |
+
f"Changing {key} value to list ([{value}])."
|
521 |
+
)
|
522 |
+
|
523 |
+
    def extract_embs(
        self,
        model_directory,
        input_data_file,
        output_directory,
        output_prefix,
        output_torch_embs=False,
        cell_state=None,
    ):
        """
        Extract embeddings from input data and save as results in output_directory.

        Parameters
        ----------
        model_directory : Path
            Path to directory containing model (passed to pu.load_model).
        input_data_file : Path
            Path to directory containing .dataset inputs.
        output_directory : Path
            Path to directory where embedding data will be saved as csv.
        output_prefix : str
            Prefix for output file.
        output_torch_embs : bool
            Whether or not to also output the embeddings as a tensor.
            Note, if true, will output embeddings as both dataframe and tensor.
        cell_state : dict
            Cell state key and value for state embedding extraction.
            When provided, input data is additionally filtered to that state,
            no csv is written, and only the embedding tensor is returned
            (used by get_state_embs).
        """
        # load input data, apply user filters, optionally restrict to one
        # cell state, then cap the number of cells at max_ncells
        filtered_input_data = pu.load_and_filter(
            self.filter_data, self.nproc, input_data_file
        )
        if cell_state is not None:
            filtered_input_data = pu.filter_by_dict(
                filtered_input_data, cell_state, self.nproc
            )
        downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
        model = pu.load_model(self.model_type, self.num_classes, model_directory)
        # emb_layer is -1 or 0, i.e. second-to-last or last hidden layer
        layer_to_quant = pu.quant_layers(model) + self.emb_layer
        embs = get_embs(
            model,
            downsampled_data,
            self.emb_mode,
            layer_to_quant,
            self.pad_token_id,
            self.forward_batch_size,
            self.summary_stat,
        )

        # build the output dataframe; layout depends on emb_mode/summary_stat
        if self.emb_mode == "cell":
            if self.summary_stat is None:
                # one row per cell, with emb_label columns appended
                embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
            elif self.summary_stat is not None:
                embs_df = pd.DataFrame(embs.cpu().numpy()).T
        elif self.emb_mode == "gene":
            if self.summary_stat is None:
                embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
            elif self.summary_stat is not None:
                # rows indexed by gene token, mapped back to gene identifiers
                embs_df = pd.DataFrame(embs).T
                embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]

        # save embeddings to output_path
        if cell_state is None:
            output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
            embs_df.to_csv(output_path)

        # collapse to a single exact summary embedding when requested
        # NOTE(review): the hard-coded 255-row slice looks suspicious —
        # presumably it was meant to cover all embedding rows regardless of
        # count; confirm against the shape returned by get_embs.
        if self.exact_summary_stat == "exact_mean":
            embs = embs.mean(dim=0)
            embs_df = pd.DataFrame(
                embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
            ).T
        elif self.exact_summary_stat == "exact_median":
            embs = torch.median(embs, dim=0)[0]
            embs_df = pd.DataFrame(
                embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
            ).T

        if cell_state is not None:
            # state-embedding mode returns the tensor only
            return embs
        else:
            if output_torch_embs:
                return embs_df, embs
            else:
                return embs_df
|
607 |
+
|
608 |
+
def get_state_embs(
|
609 |
+
self,
|
610 |
+
cell_states_to_model,
|
611 |
+
model_directory,
|
612 |
+
input_data_file,
|
613 |
+
output_directory,
|
614 |
+
output_prefix,
|
615 |
+
output_torch_embs=True,
|
616 |
+
):
|
617 |
+
"""
|
618 |
+
Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
|
619 |
+
|
620 |
+
Parameters
|
621 |
+
----------
|
622 |
+
cell_states_to_model : None, dict
|
623 |
+
Cell states to model if testing perturbations that achieve goal state change.
|
624 |
+
Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
625 |
+
state_key: key specifying name of column in .dataset that defines the start/goal states
|
626 |
+
start_state: value in the state_key column that specifies the start state
|
627 |
+
goal_state: value in the state_key column taht specifies the goal end state
|
628 |
+
alt_states: list of values in the state_key column that specify the alternate end states
|
629 |
+
For example: {"state_key": "disease",
|
630 |
+
"start_state": "dcm",
|
631 |
+
"goal_state": "nf",
|
632 |
+
"alt_states": ["hcm", "other1", "other2"]}
|
633 |
+
model_directory : Path
|
634 |
+
Path to directory containing model
|
635 |
+
input_data_file : Path
|
636 |
+
Path to directory containing .dataset inputs
|
637 |
+
output_directory : Path
|
638 |
+
Path to directory where embedding data will be saved as csv
|
639 |
+
output_prefix : str
|
640 |
+
Prefix for output file
|
641 |
+
output_torch_embs : bool
|
642 |
+
Whether or not to also output the embeddings as a tensor.
|
643 |
+
Note, if true, will output embeddings as both dataframe and tensor.
|
644 |
+
|
645 |
+
Outputs
|
646 |
+
----------
|
647 |
+
Outputs state_embs_dict for use with in silico perturber.
|
648 |
+
Format is dictionary of embedding positions of each cell state to model shifts from/towards.
|
649 |
+
Keys specify each possible cell state to model.
|
650 |
+
Values are target embedding positions as torch.tensor.
|
651 |
+
For example: {"nf": emb_nf,
|
652 |
+
"hcm": emb_hcm,
|
653 |
+
"dcm": emb_dcm,
|
654 |
+
"other1": emb_other1,
|
655 |
+
"other2": emb_other2}
|
656 |
+
"""
|
657 |
+
|
658 |
+
pu.validate_cell_states_to_model(cell_states_to_model)
|
659 |
+
valid_summary_stats = ["exact_mean", "exact_median"]
|
660 |
+
if self.exact_summary_stat not in valid_summary_stats:
|
661 |
+
logger.error(
|
662 |
+
"For extracting state embs, summary_stat in EmbExtractor "
|
663 |
+
f"must be set to option in {valid_summary_stats}"
|
664 |
+
)
|
665 |
+
raise
|
666 |
+
|
667 |
+
state_embs_dict = dict()
|
668 |
+
state_key = cell_states_to_model["state_key"]
|
669 |
+
for k, v in cell_states_to_model.items():
|
670 |
+
if k == "state_key":
|
671 |
+
continue
|
672 |
+
elif (k == "start_state") or (k == "goal_state"):
|
673 |
+
state_embs_dict[v] = self.extract_embs(
|
674 |
+
model_directory,
|
675 |
+
input_data_file,
|
676 |
+
output_directory,
|
677 |
+
output_prefix,
|
678 |
+
output_torch_embs,
|
679 |
+
cell_state={state_key: v},
|
680 |
+
)
|
681 |
+
else: # k == "alt_states"
|
682 |
+
for alt_state in v:
|
683 |
+
state_embs_dict[alt_state] = self.extract_embs(
|
684 |
+
model_directory,
|
685 |
+
input_data_file,
|
686 |
+
output_directory,
|
687 |
+
output_prefix,
|
688 |
+
output_torch_embs,
|
689 |
+
cell_state={state_key: alt_state},
|
690 |
+
)
|
691 |
+
|
692 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
|
693 |
+
with open(output_path, "wb") as fp:
|
694 |
+
pickle.dump(state_embs_dict, fp)
|
695 |
+
|
696 |
+
return state_embs_dict
|
697 |
+
|
698 |
+
def plot_embs(
|
699 |
+
self,
|
700 |
+
embs,
|
701 |
+
plot_style,
|
702 |
+
output_directory,
|
703 |
+
output_prefix,
|
704 |
+
max_ncells_to_plot=1000,
|
705 |
+
kwargs_dict=None,
|
706 |
+
):
|
707 |
"""
|
708 |
Plot embeddings, coloring by provided labels.
|
709 |
|
|
|
724 |
kwargs_dict : dict
|
725 |
Dictionary of kwargs to pass to plotting function.
|
726 |
"""
|
727 |
+
|
728 |
+
if plot_style not in ["heatmap", "umap"]:
|
729 |
logger.error(
|
730 |
+
"Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
|
|
|
731 |
)
|
732 |
raise
|
733 |
+
|
734 |
if (plot_style == "umap") and (self.labels_to_plot is None):
|
735 |
+
logger.error("Plotting UMAP requires 'labels_to_plot'. ")
|
|
|
|
|
736 |
raise
|
737 |
+
|
738 |
if max_ncells_to_plot > self.max_ncells:
|
739 |
max_ncells_to_plot = self.max_ncells
|
740 |
logger.warning(
|
741 |
+
"max_ncells_to_plot must be <= max_ncells. "
|
742 |
+
f"Changing max_ncells_to_plot to {self.max_ncells}."
|
743 |
+
)
|
744 |
+
|
745 |
+
if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells):
|
746 |
embs = embs.sample(max_ncells_to_plot, axis=0)
|
747 |
+
|
748 |
if self.emb_label is None:
|
749 |
label_len = 0
|
750 |
else:
|
751 |
label_len = len(self.emb_label)
|
752 |
+
|
753 |
emb_dims = embs.shape[1] - label_len
|
754 |
+
|
755 |
if self.emb_label is None:
|
756 |
emb_labels = None
|
757 |
else:
|
758 |
emb_labels = embs.columns[emb_dims:]
|
759 |
+
|
760 |
if plot_style == "umap":
|
761 |
for label in self.labels_to_plot:
|
762 |
if label not in emb_labels:
|
763 |
logger.warning(
|
764 |
+
f"Label {label} from labels_to_plot "
|
765 |
+
f"not present in provided embeddings dataframe."
|
766 |
+
)
|
767 |
continue
|
768 |
output_prefix_label = "_" + output_prefix + f"_umap_{label}"
|
769 |
+
output_file = (
|
770 |
+
Path(output_directory) / output_prefix_label
|
771 |
+
).with_suffix(".pdf")
|
772 |
plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
|
773 |
+
|
774 |
if plot_style == "heatmap":
|
775 |
for label in self.labels_to_plot:
|
776 |
if label not in emb_labels:
|
777 |
logger.warning(
|
778 |
+
f"Label {label} from labels_to_plot "
|
779 |
+
f"not present in provided embeddings dataframe."
|
780 |
+
)
|
781 |
continue
|
782 |
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
783 |
+
output_file = (
|
784 |
+
Path(output_directory) / output_prefix_label
|
785 |
+
).with_suffix(".pdf")
|
786 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
geneformer/in_silico_perturber.py
CHANGED
@@ -8,614 +8,66 @@ Usage:
|
|
8 |
genes_to_perturb="all",
|
9 |
combos=0,
|
10 |
anchor_gene=None,
|
11 |
-
model_type="
|
12 |
num_classes=0,
|
13 |
emb_mode="cell",
|
14 |
cell_emb_style="mean_pool",
|
15 |
filter_data={"cell_type":["cardiomyocyte"]},
|
16 |
cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
|
|
|
17 |
max_ncells=None,
|
18 |
-
emb_layer
|
19 |
forward_batch_size=100,
|
20 |
-
nproc=
|
21 |
isp.perturb_data("path/to/model",
|
22 |
"path/to/input_data",
|
23 |
"path/to/output_directory",
|
24 |
"output_prefix")
|
25 |
"""
|
26 |
|
27 |
-
# imports
|
28 |
-
import itertools as it
|
29 |
import logging
|
30 |
-
|
|
|
|
|
31 |
import pickle
|
32 |
-
import re
|
33 |
-
import seaborn as sns; sns.set()
|
34 |
-
import torch
|
35 |
from collections import defaultdict
|
36 |
-
|
|
|
|
|
|
|
37 |
from tqdm.auto import trange
|
38 |
-
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
39 |
|
|
|
|
|
40 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
# load data and filter by defined criteria
|
46 |
-
def load_and_filter(filter_data, nproc, input_data_file):
|
47 |
-
data = load_from_disk(input_data_file)
|
48 |
-
if filter_data is not None:
|
49 |
-
for key,value in filter_data.items():
|
50 |
-
def filter_data_by_criteria(example):
|
51 |
-
return example[key] in value
|
52 |
-
data = data.filter(filter_data_by_criteria, num_proc=nproc)
|
53 |
-
if len(data) == 0:
|
54 |
-
logger.error(
|
55 |
-
"No cells remain after filtering. Check filtering criteria.")
|
56 |
-
raise
|
57 |
-
data_shuffled = data.shuffle(seed=42)
|
58 |
-
return data_shuffled
|
59 |
-
|
60 |
-
# load model to GPU
|
61 |
-
def load_model(model_type, num_classes, model_directory):
|
62 |
-
if model_type == "Pretrained":
|
63 |
-
model = BertForMaskedLM.from_pretrained(model_directory,
|
64 |
-
output_hidden_states=True,
|
65 |
-
output_attentions=False)
|
66 |
-
elif model_type == "GeneClassifier":
|
67 |
-
model = BertForTokenClassification.from_pretrained(model_directory,
|
68 |
-
num_labels=num_classes,
|
69 |
-
output_hidden_states=True,
|
70 |
-
output_attentions=False)
|
71 |
-
elif model_type == "CellClassifier":
|
72 |
-
model = BertForSequenceClassification.from_pretrained(model_directory,
|
73 |
-
num_labels=num_classes,
|
74 |
-
output_hidden_states=True,
|
75 |
-
output_attentions=False)
|
76 |
-
# put the model in eval mode for fwd pass
|
77 |
-
model.eval()
|
78 |
-
model = model.to("cuda:0")
|
79 |
-
return model
|
80 |
-
|
81 |
-
def quant_layers(model):
|
82 |
-
layer_nums = []
|
83 |
-
for name, parameter in model.named_parameters():
|
84 |
-
if "layer" in name:
|
85 |
-
layer_nums += [int(name.split("layer.")[1].split(".")[0])]
|
86 |
-
return int(max(layer_nums))+1
|
87 |
-
|
88 |
-
def get_model_input_size(model):
|
89 |
-
return int(re.split("\(|,",str(model.bert.embeddings.position_embeddings))[1])
|
90 |
-
|
91 |
-
def flatten_list(megalist):
|
92 |
-
return [item for sublist in megalist for item in sublist]
|
93 |
-
|
94 |
-
def measure_length(example):
|
95 |
-
example["length"] = len(example["input_ids"])
|
96 |
-
return example
|
97 |
-
|
98 |
-
def downsample_and_sort(data_shuffled, max_ncells):
|
99 |
-
num_cells = len(data_shuffled)
|
100 |
-
# if max number of cells is defined, then subsample to this max number
|
101 |
-
if max_ncells != None:
|
102 |
-
num_cells = min(max_ncells,num_cells)
|
103 |
-
data_subset = data_shuffled.select([i for i in range(num_cells)])
|
104 |
-
# sort dataset with largest cell first to encounter any memory errors earlier
|
105 |
-
data_sorted = data_subset.sort("length",reverse=True)
|
106 |
-
return data_sorted
|
107 |
-
|
108 |
-
def get_possible_states(cell_states_to_model):
|
109 |
-
possible_states = []
|
110 |
-
for key in ["start_state","goal_state"]:
|
111 |
-
possible_states += [cell_states_to_model[key]]
|
112 |
-
possible_states += cell_states_to_model.get("alt_states",[])
|
113 |
-
return possible_states
|
114 |
-
|
115 |
-
def forward_pass_single_cell(model, example_cell, layer_to_quant):
|
116 |
-
example_cell.set_format(type="torch")
|
117 |
-
input_data = example_cell["input_ids"]
|
118 |
-
with torch.no_grad():
|
119 |
-
outputs = model(
|
120 |
-
input_ids = input_data.to("cuda")
|
121 |
-
)
|
122 |
-
emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
|
123 |
-
del outputs
|
124 |
-
return emb
|
125 |
-
|
126 |
-
def perturb_emb_by_index(emb, indices):
|
127 |
-
mask = torch.ones(emb.numel(), dtype=torch.bool)
|
128 |
-
mask[indices] = False
|
129 |
-
return emb[mask]
|
130 |
-
|
131 |
-
def delete_indices(example):
|
132 |
-
indices = example["perturb_index"]
|
133 |
-
if any(isinstance(el, list) for el in indices):
|
134 |
-
indices = flatten_list(indices)
|
135 |
-
for index in sorted(indices, reverse=True):
|
136 |
-
del example["input_ids"][index]
|
137 |
-
return example
|
138 |
-
|
139 |
-
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
140 |
-
def overexpress_indices(example):
|
141 |
-
indices = example["perturb_index"]
|
142 |
-
if any(isinstance(el, list) for el in indices):
|
143 |
-
indices = flatten_list(indices)
|
144 |
-
for index in sorted(indices, reverse=True):
|
145 |
-
example["input_ids"].insert(0, example["input_ids"].pop(index))
|
146 |
-
return example
|
147 |
-
|
148 |
-
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
149 |
-
def overexpress_tokens(example):
|
150 |
-
# -100 indicates tokens to overexpress are not present in rank value encoding
|
151 |
-
if example["perturb_index"] != [-100]:
|
152 |
-
example = delete_indices(example)
|
153 |
-
[example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
|
154 |
-
|
155 |
-
return example
|
156 |
-
|
157 |
-
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
158 |
-
# indices_to_remove is list of indices to remove
|
159 |
-
indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
|
160 |
-
num_dims = emb.dim()
|
161 |
-
emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
|
162 |
-
sliced_emb = emb[emb_slice]
|
163 |
-
return sliced_emb
|
164 |
-
|
165 |
-
def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
|
166 |
-
output_batch = torch.stack([
|
167 |
-
remove_indices_from_emb(emb_batch[i, :, :], idxs, gene_dim-1) for
|
168 |
-
i, idxs in enumerate(list_of_indices_to_remove)
|
169 |
-
])
|
170 |
-
return output_batch
|
171 |
-
|
172 |
-
def make_perturbation_batch(example_cell,
|
173 |
-
perturb_type,
|
174 |
-
tokens_to_perturb,
|
175 |
-
anchor_token,
|
176 |
-
combo_lvl,
|
177 |
-
num_proc):
|
178 |
-
if tokens_to_perturb == "all":
|
179 |
-
if perturb_type in ["overexpress","activate"]:
|
180 |
-
range_start = 1
|
181 |
-
elif perturb_type in ["delete","inhibit"]:
|
182 |
-
range_start = 0
|
183 |
-
indices_to_perturb = [[i] for i in range(range_start, example_cell["length"][0])]
|
184 |
-
elif combo_lvl>0 and (anchor_token is not None):
|
185 |
-
example_input_ids = example_cell["input_ids "][0]
|
186 |
-
anchor_index = example_input_ids.index(anchor_token[0])
|
187 |
-
indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
|
188 |
-
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
189 |
-
else:
|
190 |
-
example_input_ids = example_cell["input_ids"][0]
|
191 |
-
indices_to_perturb = [[example_input_ids.index(token)] if token in example_input_ids else None for token in tokens_to_perturb]
|
192 |
-
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
193 |
-
|
194 |
-
# create all permutations of combo_lvl of modifiers from tokens_to_perturb
|
195 |
-
if combo_lvl>0 and (anchor_token is None):
|
196 |
-
if tokens_to_perturb != "all":
|
197 |
-
if len(tokens_to_perturb) == combo_lvl+1:
|
198 |
-
indices_to_perturb = [list(x) for x in it.combinations(indices_to_perturb, combo_lvl+1)]
|
199 |
-
else:
|
200 |
-
all_indices = [[i] for i in range(example_cell["length"][0])]
|
201 |
-
all_indices = [index for index in all_indices if index not in indices_to_perturb]
|
202 |
-
indices_to_perturb = [[[j for i in indices_to_perturb for j in i], x] for x in all_indices]
|
203 |
-
length = len(indices_to_perturb)
|
204 |
-
perturbation_dataset = Dataset.from_dict({"input_ids": example_cell["input_ids"]*length,
|
205 |
-
"perturb_index": indices_to_perturb})
|
206 |
-
if length<400:
|
207 |
-
num_proc_i = 1
|
208 |
-
else:
|
209 |
-
num_proc_i = num_proc
|
210 |
-
if perturb_type == "delete":
|
211 |
-
perturbation_dataset = perturbation_dataset.map(delete_indices, num_proc=num_proc_i)
|
212 |
-
elif perturb_type == "overexpress":
|
213 |
-
perturbation_dataset = perturbation_dataset.map(overexpress_indices, num_proc=num_proc_i)
|
214 |
-
return perturbation_dataset, indices_to_perturb
|
215 |
-
|
216 |
-
# perturbed cell emb removing the activated/overexpressed/inhibited gene emb
|
217 |
-
# so that only non-perturbed gene embeddings are compared to each other
|
218 |
-
# in original or perturbed context
|
219 |
-
def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
|
220 |
-
all_embs_list = []
|
221 |
-
|
222 |
-
# if making comparison batch for multiple perturbations in single cell
|
223 |
-
if perturb_group == False:
|
224 |
-
original_emb_list = [original_emb_batch]*len(indices_to_perturb)
|
225 |
-
# if making comparison batch for single perturbation in multiple cells
|
226 |
-
elif perturb_group == True:
|
227 |
-
original_emb_list = original_emb_batch
|
228 |
-
|
229 |
-
|
230 |
-
for i in range(len(original_emb_list)):
|
231 |
-
original_emb = original_emb_list[i]
|
232 |
-
indices = indices_to_perturb[i]
|
233 |
-
if indices == [-100]:
|
234 |
-
all_embs_list += [original_emb[:]]
|
235 |
-
continue
|
236 |
-
emb_list = []
|
237 |
-
start = 0
|
238 |
-
if any(isinstance(el, list) for el in indices):
|
239 |
-
indices = flatten_list(indices)
|
240 |
-
for i in sorted(indices):
|
241 |
-
emb_list += [original_emb[start:i]]
|
242 |
-
start = i+1
|
243 |
-
emb_list += [original_emb[start:]]
|
244 |
-
all_embs_list += [torch.cat(emb_list)]
|
245 |
-
len_set = set([emb.size()[0] for emb in all_embs_list])
|
246 |
-
if len(len_set) > 1:
|
247 |
-
max_len = max(len_set)
|
248 |
-
all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
|
249 |
-
return torch.stack(all_embs_list)
|
250 |
-
|
251 |
-
# average embedding position of goal cell states
|
252 |
-
def get_cell_state_avg_embs(model,
|
253 |
-
filtered_input_data,
|
254 |
-
cell_states_to_model,
|
255 |
-
layer_to_quant,
|
256 |
-
pad_token_id,
|
257 |
-
forward_batch_size,
|
258 |
-
num_proc):
|
259 |
-
|
260 |
-
model_input_size = get_model_input_size(model)
|
261 |
-
possible_states = get_possible_states(cell_states_to_model)
|
262 |
-
state_embs_dict = dict()
|
263 |
-
for possible_state in possible_states:
|
264 |
-
state_embs_list = []
|
265 |
-
original_lens = []
|
266 |
-
|
267 |
-
def filter_states(example):
|
268 |
-
state_key = cell_states_to_model["state_key"]
|
269 |
-
return example[state_key] in [possible_state]
|
270 |
-
filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
|
271 |
-
total_batch_length = len(filtered_input_data_state)
|
272 |
-
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
273 |
-
forward_batch_size = forward_batch_size-1
|
274 |
-
max_len = max(filtered_input_data_state["length"])
|
275 |
-
for i in range(0, total_batch_length, forward_batch_size):
|
276 |
-
max_range = min(i+forward_batch_size, total_batch_length)
|
277 |
-
|
278 |
-
state_minibatch = filtered_input_data_state.select([i for i in range(i, max_range)])
|
279 |
-
state_minibatch.set_format(type="torch")
|
280 |
-
|
281 |
-
input_data_minibatch = state_minibatch["input_ids"]
|
282 |
-
original_lens += state_minibatch["length"]
|
283 |
-
input_data_minibatch = pad_tensor_list(input_data_minibatch,
|
284 |
-
max_len,
|
285 |
-
pad_token_id,
|
286 |
-
model_input_size)
|
287 |
-
attention_mask = gen_attention_mask(state_minibatch, max_len)
|
288 |
-
|
289 |
-
with torch.no_grad():
|
290 |
-
outputs = model(
|
291 |
-
input_ids = input_data_minibatch.to("cuda"),
|
292 |
-
attention_mask = attention_mask
|
293 |
-
)
|
294 |
-
|
295 |
-
state_embs_i = outputs.hidden_states[layer_to_quant]
|
296 |
-
state_embs_list += [state_embs_i]
|
297 |
-
del outputs
|
298 |
-
del state_minibatch
|
299 |
-
del input_data_minibatch
|
300 |
-
del attention_mask
|
301 |
-
del state_embs_i
|
302 |
-
torch.cuda.empty_cache()
|
303 |
-
|
304 |
-
state_embs = torch.cat(state_embs_list)
|
305 |
-
avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
|
306 |
-
avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
|
307 |
-
state_embs_dict[possible_state] = avg_state_emb
|
308 |
-
return state_embs_dict
|
309 |
-
|
310 |
-
# quantify cosine similarity of perturbed vs original or alternate states
|
311 |
-
def quant_cos_sims(model,
|
312 |
-
perturb_type,
|
313 |
-
perturbation_batch,
|
314 |
-
forward_batch_size,
|
315 |
-
layer_to_quant,
|
316 |
-
original_emb,
|
317 |
-
tokens_to_perturb,
|
318 |
-
indices_to_perturb,
|
319 |
-
perturb_group,
|
320 |
-
cell_states_to_model,
|
321 |
-
state_embs_dict,
|
322 |
-
pad_token_id,
|
323 |
-
model_input_size,
|
324 |
-
nproc):
|
325 |
-
cos = torch.nn.CosineSimilarity(dim=2)
|
326 |
-
total_batch_length = len(perturbation_batch)
|
327 |
-
|
328 |
-
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
329 |
-
forward_batch_size = forward_batch_size-1
|
330 |
-
|
331 |
-
if perturb_group == False:
|
332 |
-
comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
|
333 |
-
|
334 |
-
if cell_states_to_model is None:
|
335 |
-
cos_sims = []
|
336 |
-
else:
|
337 |
-
possible_states = get_possible_states(cell_states_to_model)
|
338 |
-
cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for _ in range(len(possible_states))]))
|
339 |
-
|
340 |
-
# measure length of each element in perturbation_batch
|
341 |
-
perturbation_batch = perturbation_batch.map(
|
342 |
-
measure_length, num_proc=nproc
|
343 |
-
)
|
344 |
|
345 |
-
def compute_batch_embeddings(minibatch, _max_len = None):
|
346 |
-
minibatch_lengths = minibatch["length"]
|
347 |
-
minibatch_length_set = set(minibatch_lengths)
|
348 |
-
max_len = model_input_size
|
349 |
|
350 |
-
|
351 |
-
needs_pad_or_trunc = True
|
352 |
-
else:
|
353 |
-
needs_pad_or_trunc = False
|
354 |
-
max_len = max(minibatch_length_set)
|
355 |
-
|
356 |
-
|
357 |
-
if needs_pad_or_trunc == True:
|
358 |
-
if _max_len is None:
|
359 |
-
max_len = min(max(minibatch_length_set), max_len)
|
360 |
-
else:
|
361 |
-
max_len = _max_len
|
362 |
-
def pad_or_trunc_example(example):
|
363 |
-
example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
|
364 |
-
pad_token_id,
|
365 |
-
max_len)
|
366 |
-
return example
|
367 |
-
minibatch = minibatch.map(pad_or_trunc_example, num_proc=nproc)
|
368 |
-
|
369 |
-
minibatch.set_format(type="torch")
|
370 |
-
|
371 |
-
input_data_minibatch = minibatch["input_ids"]
|
372 |
-
attention_mask = gen_attention_mask(minibatch, max_len)
|
373 |
-
|
374 |
-
# extract embeddings for perturbation minibatch
|
375 |
-
with torch.no_grad():
|
376 |
-
outputs = model(
|
377 |
-
input_ids = input_data_minibatch.to("cuda"),
|
378 |
-
attention_mask = attention_mask
|
379 |
-
)
|
380 |
|
381 |
-
return outputs, max_len
|
382 |
-
|
383 |
-
for i in range(0, total_batch_length, forward_batch_size):
|
384 |
-
max_range = min(i+forward_batch_size, total_batch_length)
|
385 |
-
perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
|
386 |
-
outputs, mini_max_len = compute_batch_embeddings(perturbation_minibatch)
|
387 |
-
|
388 |
-
if len(indices_to_perturb)>1:
|
389 |
-
minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
|
390 |
-
else:
|
391 |
-
minibatch_emb = outputs.hidden_states[layer_to_quant]
|
392 |
-
|
393 |
-
if perturb_type == "overexpress":
|
394 |
-
# remove overexpressed genes to quantify effect on remaining genes
|
395 |
-
if perturb_group == False:
|
396 |
-
overexpressed_to_remove = 1
|
397 |
-
if perturb_group == True:
|
398 |
-
overexpressed_to_remove = len(tokens_to_perturb)
|
399 |
-
minibatch_emb = minibatch_emb[:, overexpressed_to_remove: ,:]
|
400 |
-
|
401 |
-
|
402 |
-
# if quantifying single perturbation in multiple different cells, pad original batch and extract embs
|
403 |
-
if perturb_group == True:
|
404 |
-
# pad minibatch of original batch to extract embeddings
|
405 |
-
# truncate to the (model input size - # tokens to overexpress) to ensure comparability
|
406 |
-
# since max input size of perturb batch will be reduced by # tokens to overexpress
|
407 |
-
original_minibatch = original_emb.select([i for i in range(i, max_range)])
|
408 |
-
original_outputs, orig_max_len = compute_batch_embeddings(original_minibatch, mini_max_len)
|
409 |
-
|
410 |
-
if len(indices_to_perturb)>1:
|
411 |
-
original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
|
412 |
-
else:
|
413 |
-
original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
|
414 |
-
|
415 |
-
# if we overexpress genes that aren't already expressed,
|
416 |
-
# we need to remove genes to make sure the embeddings are of a consistent size
|
417 |
-
# get rid of the bottom n genes/padding since those will get truncated anyways
|
418 |
-
# multiple perturbations is more complicated because if 1 out of n perturbed genes is expressed
|
419 |
-
# the idxs will still not be [-100]
|
420 |
-
if len(tokens_to_perturb) == 1:
|
421 |
-
indices_to_perturb_minibatch = [idx if idx != [-100] else [orig_max_len - 1]
|
422 |
-
for idx in indices_to_perturb[i:max_range]]
|
423 |
-
else:
|
424 |
-
num_perturbed = len(tokens_to_perturb)
|
425 |
-
indices_to_perturb_minibatch = []
|
426 |
-
end_range = [i for i in range(orig_max_len - tokens_to_perturb, orig_max_len)]
|
427 |
-
for idx in indices_to_perturb[i:i+max_range]:
|
428 |
-
if idx == [-100]:
|
429 |
-
indices_to_perturb_minibatch.append(end_range)
|
430 |
-
elif len(idx) < len(tokens_to_perturb):
|
431 |
-
indices_to_perturb_minibatch.append(idx + end_range[-num_perturbed:])
|
432 |
-
else:
|
433 |
-
indices_to_perturb_minibatch.append(idx)
|
434 |
-
|
435 |
-
original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
|
436 |
-
indices_to_perturb_minibatch,
|
437 |
-
gene_dim=1)
|
438 |
-
|
439 |
-
# cosine similarity between original emb and batch items
|
440 |
-
if cell_states_to_model is None:
|
441 |
-
if perturb_group == False:
|
442 |
-
minibatch_comparison = comparison_batch[i:max_range]
|
443 |
-
elif perturb_group == True:
|
444 |
-
minibatch_comparison = original_minibatch_emb
|
445 |
-
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
446 |
-
elif cell_states_to_model is not None:
|
447 |
-
if perturb_group == False:
|
448 |
-
original_emb = comparison_batch[i:max_range]
|
449 |
-
else:
|
450 |
-
original_minibatch_lengths = torch.tensor(original_minibatch["length"], device="cuda")
|
451 |
-
minibatch_lengths = torch.tensor(perturbation_minibatch["length"], device="cuda")
|
452 |
-
for state in possible_states:
|
453 |
-
if perturb_group == False:
|
454 |
-
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb,
|
455 |
-
minibatch_emb,
|
456 |
-
state_embs_dict[state],
|
457 |
-
perturb_group)
|
458 |
-
elif perturb_group == True:
|
459 |
-
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
|
460 |
-
minibatch_emb,
|
461 |
-
state_embs_dict[state],
|
462 |
-
perturb_group,
|
463 |
-
original_minibatch_lengths,
|
464 |
-
minibatch_lengths)
|
465 |
-
del outputs
|
466 |
-
del minibatch_emb
|
467 |
-
if cell_states_to_model is None:
|
468 |
-
del minibatch_comparison
|
469 |
-
if perturb_group == True:
|
470 |
-
del original_minibatch_emb
|
471 |
-
torch.cuda.empty_cache()
|
472 |
-
if cell_states_to_model is None:
|
473 |
-
cos_sims_stack = torch.cat(cos_sims)
|
474 |
-
return cos_sims_stack
|
475 |
-
else:
|
476 |
-
for state in possible_states:
|
477 |
-
cos_sims_vs_alt_dict[state] = torch.cat(cos_sims_vs_alt_dict[state])
|
478 |
-
return cos_sims_vs_alt_dict
|
479 |
-
|
480 |
-
|
481 |
-
# calculate cos sim shift of perturbation with respect to origin and alternative cell
|
482 |
-
def cos_sim_shift(original_emb,
|
483 |
-
minibatch_emb,
|
484 |
-
end_emb,
|
485 |
-
perturb_group,
|
486 |
-
original_minibatch_lengths = None,
|
487 |
-
minibatch_lengths = None):
|
488 |
-
cos = torch.nn.CosineSimilarity(dim=2)
|
489 |
-
if original_emb.size() != minibatch_emb.size():
|
490 |
-
logger.error(
|
491 |
-
f"Embeddings are not the same dimensions. " \
|
492 |
-
f"original_emb is {original_emb.size()}. " \
|
493 |
-
f"minibatch_emb is {minibatch_emb.size()}. "
|
494 |
-
)
|
495 |
-
raise
|
496 |
-
if not perturb_group:
|
497 |
-
original_emb = torch.mean(original_emb,dim=1,keepdim=True)
|
498 |
-
origin_v_end = torch.squeeze(cos(original_emb, end_emb))
|
499 |
-
else:
|
500 |
-
if original_minibatch_lengths is not None:
|
501 |
-
original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
|
502 |
-
# else:
|
503 |
-
# original_emb = torch.mean(original_emb,dim=1,keepdim=True)
|
504 |
-
|
505 |
-
end_emb = torch.unsqueeze(end_emb, 1)
|
506 |
-
origin_v_end = torch.squeeze(cos(original_emb, end_emb))
|
507 |
-
if minibatch_lengths is not None:
|
508 |
-
perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
|
509 |
-
else:
|
510 |
-
perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
|
511 |
-
perturb_v_end = cos(perturb_emb, end_emb)
|
512 |
-
perturb_v_end = torch.squeeze(perturb_v_end)
|
513 |
-
if (perturb_v_end-origin_v_end).numel() == 1:
|
514 |
-
return [([perturb_v_end-origin_v_end]).to("cpu")]
|
515 |
-
return [(perturb_v_end-origin_v_end).to("cpu")]
|
516 |
-
|
517 |
-
def pad_list(input_ids, pad_token_id, max_len):
|
518 |
-
input_ids = np.pad(input_ids,
|
519 |
-
(0, max_len-len(input_ids)),
|
520 |
-
mode='constant', constant_values=pad_token_id)
|
521 |
-
return input_ids
|
522 |
-
|
523 |
-
def pad_tensor(tensor, pad_token_id, max_len):
|
524 |
-
tensor = torch.nn.functional.pad(tensor, pad=(0,
|
525 |
-
max_len - tensor.numel()),
|
526 |
-
mode='constant',
|
527 |
-
value=pad_token_id)
|
528 |
-
return tensor
|
529 |
-
|
530 |
-
def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
|
531 |
-
if dim == 0:
|
532 |
-
pad = (0, 0, 0, max_len - tensor.size()[dim])
|
533 |
-
elif dim == 1:
|
534 |
-
pad = (0, max_len - tensor.size()[dim], 0, 0)
|
535 |
-
tensor = torch.nn.functional.pad(tensor, pad=pad,
|
536 |
-
mode='constant',
|
537 |
-
value=pad_token_id)
|
538 |
-
return tensor
|
539 |
-
|
540 |
-
def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
|
541 |
-
if isinstance(encoding, torch.Tensor):
|
542 |
-
encoding_len = tensor.size()[0]
|
543 |
-
elif isinstance(encoding, list):
|
544 |
-
encoding_len = len(encoding)
|
545 |
-
if encoding_len > max_len:
|
546 |
-
encoding = encoding[0:max_len]
|
547 |
-
elif encoding_len < max_len:
|
548 |
-
if isinstance(encoding, torch.Tensor):
|
549 |
-
encoding = pad_tensor(encoding, pad_token_id, max_len)
|
550 |
-
elif isinstance(encoding, list):
|
551 |
-
encoding = pad_list(encoding, pad_token_id, max_len)
|
552 |
-
return encoding
|
553 |
-
|
554 |
-
# pad list of tensors and convert to tensor
|
555 |
-
def pad_tensor_list(tensor_list, dynamic_or_constant, pad_token_id, model_input_size):
|
556 |
-
|
557 |
-
# Determine maximum tensor length
|
558 |
-
if dynamic_or_constant == "dynamic":
|
559 |
-
max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
|
560 |
-
elif type(dynamic_or_constant) == int:
|
561 |
-
max_len = dynamic_or_constant
|
562 |
-
else:
|
563 |
-
max_len = model_input_size
|
564 |
-
logger.warning(
|
565 |
-
"If padding style is constant, must provide integer value. " \
|
566 |
-
f"Setting padding to max input size {model_input_size}.")
|
567 |
-
|
568 |
-
# pad all tensors to maximum length
|
569 |
-
tensor_list = [pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list]
|
570 |
-
|
571 |
-
# return stacked tensors
|
572 |
-
return torch.stack(tensor_list)
|
573 |
-
|
574 |
-
def gen_attention_mask(minibatch_encoding, max_len = None):
|
575 |
-
if max_len == None:
|
576 |
-
max_len = max(minibatch_encoding["length"])
|
577 |
-
original_lens = minibatch_encoding["length"]
|
578 |
-
attention_mask = [[1]*original_len
|
579 |
-
+[0]*(max_len - original_len)
|
580 |
-
if original_len <= max_len
|
581 |
-
else [1]*max_len
|
582 |
-
for original_len in original_lens]
|
583 |
-
return torch.tensor(attention_mask).to("cuda")
|
584 |
-
|
585 |
-
# get cell embeddings excluding padding
|
586 |
-
def mean_nonpadding_embs(embs, original_lens):
|
587 |
-
# mask based on padding lengths
|
588 |
-
mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
|
589 |
-
|
590 |
-
# extend mask dimensions to match the embeddings tensor
|
591 |
-
mask = mask.unsqueeze(2).expand_as(embs)
|
592 |
-
|
593 |
-
# use the mask to zero out the embeddings in padded areas
|
594 |
-
masked_embs = embs * mask.float()
|
595 |
-
|
596 |
-
# sum and divide by the lengths to get the mean of non-padding embs
|
597 |
-
mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
|
598 |
-
return mean_embs
|
599 |
|
600 |
class InSilicoPerturber:
|
601 |
valid_option_dict = {
|
602 |
-
"perturb_type": {"delete","overexpress","inhibit","activate"},
|
603 |
"perturb_rank_shift": {None, 1, 2, 3},
|
604 |
"genes_to_perturb": {"all", list},
|
605 |
"combos": {0, 1},
|
606 |
"anchor_gene": {None, str},
|
607 |
-
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
608 |
"num_classes": {int},
|
609 |
-
"emb_mode": {"cell","cell_and_gene"},
|
610 |
"cell_emb_style": {"mean_pool"},
|
611 |
"filter_data": {None, dict},
|
612 |
"cell_states_to_model": {None, dict},
|
|
|
613 |
"max_ncells": {None, int},
|
614 |
"cell_inds_to_perturb": {"all", dict},
|
615 |
"emb_layer": {-1, 0},
|
616 |
"forward_batch_size": {int},
|
617 |
"nproc": {int},
|
618 |
}
|
|
|
619 |
def __init__(
|
620 |
self,
|
621 |
perturb_type="delete",
|
@@ -629,6 +81,7 @@ class InSilicoPerturber:
|
|
629 |
cell_emb_style="mean_pool",
|
630 |
filter_data=None,
|
631 |
cell_states_to_model=None,
|
|
|
632 |
max_ncells=None,
|
633 |
cell_inds_to_perturb="all",
|
634 |
emb_layer=-1,
|
@@ -676,13 +129,14 @@ class InSilicoPerturber:
|
|
676 |
For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
677 |
emb_mode : {"cell","cell_and_gene"}
|
678 |
Whether to output impact of perturbation on cell and/or gene embeddings.
|
|
|
679 |
cell_emb_style : "mean_pool"
|
680 |
Method for summarizing cell embeddings.
|
681 |
Currently only option is mean pooling of gene embeddings for given cell.
|
682 |
filter_data : None, dict
|
683 |
Default is to use all input data for in silico perturbation study.
|
684 |
Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
685 |
-
cell_states_to_model: None, dict
|
686 |
Cell states to model if testing perturbations that achieve goal state change.
|
687 |
Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
688 |
state_key: key specifying name of column in .dataset that defines the start/goal states
|
@@ -693,6 +147,15 @@ class InSilicoPerturber:
|
|
693 |
"start_state": "dcm",
|
694 |
"goal_state": "nf",
|
695 |
"alt_states": ["hcm", "other1", "other2"]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
696 |
max_ncells : None, int
|
697 |
Maximum number of cells to test.
|
698 |
If None, will test all cells.
|
@@ -705,8 +168,8 @@ class InSilicoPerturber:
|
|
705 |
Useful for splitting extremely large datasets across separate GPUs.
|
706 |
emb_layer : {-1, 0}
|
707 |
Embedding layer to use for quantification.
|
708 |
-
|
709 |
-
|
710 |
forward_batch_size : int
|
711 |
Batch size for forward pass.
|
712 |
nproc : int
|
@@ -721,23 +184,25 @@ class InSilicoPerturber:
|
|
721 |
self.combos = combos
|
722 |
self.anchor_gene = anchor_gene
|
723 |
if self.genes_to_perturb == "all":
|
724 |
-
self.perturb_group = False
|
725 |
else:
|
726 |
self.perturb_group = True
|
727 |
-
if (self.anchor_gene
|
728 |
self.anchor_gene = None
|
729 |
self.combos = 0
|
730 |
logger.warning(
|
731 |
-
"anchor_gene set to None and combos set to 0. "
|
732 |
-
"If providing list of genes to perturb, "
|
733 |
-
"list of genes_to_perturb will be perturbed together, "
|
734 |
-
"without anchor gene or combinations."
|
|
|
735 |
self.model_type = model_type
|
736 |
self.num_classes = num_classes
|
737 |
self.emb_mode = emb_mode
|
738 |
self.cell_emb_style = cell_emb_style
|
739 |
self.filter_data = filter_data
|
740 |
self.cell_states_to_model = cell_states_to_model
|
|
|
741 |
self.max_ncells = max_ncells
|
742 |
self.cell_inds_to_perturb = cell_inds_to_perturb
|
743 |
self.emb_layer = emb_layer
|
@@ -758,36 +223,47 @@ class InSilicoPerturber:
|
|
758 |
try:
|
759 |
self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
|
760 |
except KeyError:
|
761 |
-
logger.error(
|
762 |
-
f"Anchor gene {self.anchor_gene} not in token dictionary."
|
763 |
-
)
|
764 |
raise
|
765 |
|
766 |
if self.genes_to_perturb == "all":
|
767 |
self.tokens_to_perturb = "all"
|
768 |
else:
|
769 |
-
missing_genes = [
|
|
|
|
|
|
|
|
|
770 |
if len(missing_genes) == len(self.genes_to_perturb):
|
771 |
logger.error(
|
772 |
"None of the provided genes to perturb are in token dictionary."
|
773 |
)
|
774 |
raise
|
775 |
-
elif len(missing_genes)>0:
|
776 |
logger.warning(
|
777 |
-
f"Genes to perturb {missing_genes} are not in token dictionary."
|
778 |
-
|
|
|
|
|
|
|
779 |
|
780 |
def validate_options(self):
|
781 |
# first disallow options under development
|
782 |
if self.perturb_type in ["inhibit", "activate"]:
|
783 |
logger.error(
|
784 |
-
"In silico inhibition and activation currently under development. "
|
785 |
"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
786 |
)
|
787 |
raise
|
788 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
789 |
# confirm arguments are within valid options and compatible with each other
|
790 |
-
for attr_name,valid_options in self.valid_option_dict.items():
|
791 |
attr_value = self.__dict__[attr_name]
|
792 |
if type(attr_value) not in {list, dict}:
|
793 |
if attr_value in valid_options:
|
@@ -797,141 +273,120 @@ class InSilicoPerturber:
|
|
797 |
continue
|
798 |
valid_type = False
|
799 |
for option in valid_options:
|
800 |
-
if (option in [int,list,dict]) and isinstance(
|
|
|
|
|
801 |
valid_type = True
|
802 |
break
|
803 |
if valid_type:
|
804 |
continue
|
805 |
logger.error(
|
806 |
-
f"Invalid option for {attr_name}. "
|
807 |
f"Valid options for {attr_name}: {valid_options}"
|
808 |
)
|
809 |
raise
|
810 |
-
|
811 |
-
if self.perturb_type in ["delete","overexpress"]:
|
812 |
if self.perturb_rank_shift is not None:
|
813 |
if self.perturb_type == "delete":
|
814 |
logger.warning(
|
815 |
-
"perturb_rank_shift set to None. "
|
816 |
-
"If perturb type is delete then gene is deleted entirely "
|
817 |
-
"rather than shifted by quartile"
|
|
|
818 |
elif self.perturb_type == "overexpress":
|
819 |
logger.warning(
|
820 |
-
"perturb_rank_shift set to None. "
|
821 |
-
"If perturb type is overexpress then gene is moved to front "
|
822 |
-
"of rank value encoding rather than shifted by quartile"
|
|
|
823 |
self.perturb_rank_shift = None
|
824 |
-
|
825 |
if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
|
826 |
self.emb_mode = "cell"
|
827 |
logger.warning(
|
828 |
-
"emb_mode set to 'cell'. "
|
829 |
-
"Currently, analysis with anchor gene "
|
830 |
-
"only outputs effect on cell embeddings."
|
831 |
-
|
|
|
832 |
if self.cell_states_to_model is not None:
|
833 |
-
|
|
|
|
|
|
|
834 |
logger.warning(
|
835 |
-
"
|
836 |
-
"
|
837 |
-
"
|
838 |
-
"in the cell_states_to_model dictionary for future use. " \
|
839 |
-
"For example, cell_states_to_model={" \
|
840 |
-
"'state_key': 'disease', " \
|
841 |
-
"'start_state': 'dcm', " \
|
842 |
-
"'goal_state': 'nf', " \
|
843 |
-
"'alt_states': ['hcm', 'other1', 'other2']}"
|
844 |
)
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
-
"start_state": state_values[0][0],
|
857 |
-
"goal_state": state_values[1][0],
|
858 |
-
"alt_states": state_values[2:][0]
|
859 |
-
}
|
860 |
-
elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
|
861 |
-
if (self.cell_states_to_model["state_key"] is None) \
|
862 |
-
or (self.cell_states_to_model["start_state"] is None) \
|
863 |
-
or (self.cell_states_to_model["goal_state"] is None):
|
864 |
-
logger.error(
|
865 |
-
"Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
|
866 |
-
raise
|
867 |
-
|
868 |
-
if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
|
869 |
logger.error(
|
870 |
-
"
|
|
|
871 |
raise
|
872 |
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
-
)
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
|
884 |
-
else:
|
885 |
logger.error(
|
886 |
-
"
|
887 |
-
"
|
888 |
-
"
|
889 |
-
|
890 |
-
"'start_state': 'dcm', " \
|
891 |
-
"'goal_state': 'nf', " \
|
892 |
-
"'alt_states': ['hcm', 'other1', 'other2']}"
|
893 |
)
|
894 |
raise
|
895 |
|
896 |
-
|
897 |
-
self.anchor_gene = None
|
898 |
-
logger.warning(
|
899 |
-
"anchor_gene set to None. " \
|
900 |
-
"Currently, anchor gene not available " \
|
901 |
-
"when modeling multiple cell states.")
|
902 |
-
|
903 |
-
if self.perturb_type in ["inhibit","activate"]:
|
904 |
if self.perturb_rank_shift is None:
|
905 |
logger.error(
|
906 |
-
"If perturb_type is inhibit or activate then "
|
907 |
-
"quartile to shift by must be specified."
|
|
|
908 |
raise
|
909 |
-
|
910 |
if self.filter_data is not None:
|
911 |
-
for key,value in self.filter_data.items():
|
912 |
-
if
|
913 |
self.filter_data[key] = [value]
|
914 |
logger.warning(
|
915 |
-
"Values in filter_data dict must be lists. "
|
916 |
-
f"Changing {key} value to list ([{value}])."
|
917 |
-
|
|
|
918 |
if self.cell_inds_to_perturb != "all":
|
919 |
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
920 |
logger.error(
|
921 |
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
922 |
)
|
923 |
raise
|
924 |
-
if
|
925 |
-
|
926 |
-
|
927 |
-
|
|
|
928 |
raise
|
929 |
|
930 |
-
def perturb_data(
|
931 |
-
|
932 |
-
|
933 |
-
output_directory,
|
934 |
-
output_prefix):
|
935 |
"""
|
936 |
Perturb genes in input data and save as results in output_directory.
|
937 |
|
@@ -947,365 +402,506 @@ class InSilicoPerturber:
|
|
947 |
Prefix for output files
|
948 |
"""
|
949 |
|
950 |
-
|
951 |
-
|
952 |
-
|
953 |
-
|
954 |
-
|
955 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
956 |
else:
|
957 |
-
|
958 |
-
|
959 |
-
|
960 |
-
|
961 |
-
|
962 |
-
|
963 |
-
|
964 |
-
|
965 |
-
|
966 |
-
|
967 |
-
|
968 |
-
|
969 |
-
self.cell_states_to_model,
|
970 |
-
layer_to_quant,
|
971 |
-
self.pad_token_id,
|
972 |
-
self.forward_batch_size,
|
973 |
-
self.nproc)
|
974 |
-
# filter for start state cells
|
975 |
-
start_state = self.cell_states_to_model["start_state"]
|
976 |
-
def filter_for_origin(example):
|
977 |
-
return example[state_name] in [start_state]
|
978 |
-
|
979 |
-
filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
|
980 |
-
|
981 |
-
self.in_silico_perturb(model,
|
982 |
-
filtered_input_data,
|
983 |
-
layer_to_quant,
|
984 |
-
state_embs_dict,
|
985 |
-
output_directory,
|
986 |
-
output_prefix)
|
987 |
-
|
988 |
-
# determine effect of perturbation on other genes
|
989 |
-
def in_silico_perturb(self,
|
990 |
-
model,
|
991 |
-
filtered_input_data,
|
992 |
-
layer_to_quant,
|
993 |
-
state_embs_dict,
|
994 |
-
output_directory,
|
995 |
-
output_prefix):
|
996 |
-
|
997 |
-
output_path_prefix = f"{output_directory}in_silico_{self.perturb_type}_{output_prefix}_dict_1Kbatch"
|
998 |
-
model_input_size = get_model_input_size(model)
|
999 |
-
|
1000 |
-
# filter dataset for cells that have tokens to be perturbed
|
1001 |
-
if self.anchor_token is not None:
|
1002 |
-
def if_has_tokens_to_perturb(example):
|
1003 |
-
return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
|
1004 |
-
filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
|
1005 |
-
if len(filtered_input_data) == 0:
|
1006 |
-
logger.error(
|
1007 |
-
"No cells in dataset contain anchor gene.")
|
1008 |
-
raise
|
1009 |
-
else:
|
1010 |
-
logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
|
1011 |
-
|
1012 |
if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
|
1013 |
-
#
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
1018 |
-
|
1019 |
-
|
1020 |
-
|
1021 |
-
|
1022 |
-
|
1023 |
-
|
1024 |
-
|
1025 |
-
|
1026 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1027 |
if self.cell_inds_to_perturb != "all":
|
1028 |
-
|
1029 |
-
|
1030 |
-
|
1031 |
-
|
1032 |
-
|
1033 |
-
|
1034 |
-
|
1035 |
-
|
1036 |
-
|
1037 |
-
|
1038 |
-
|
1039 |
-
|
1040 |
-
|
1041 |
-
|
1042 |
-
|
1043 |
-
|
1044 |
-
|
1045 |
-
|
1046 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1047 |
else:
|
1048 |
-
|
1049 |
-
|
1050 |
-
|
1051 |
-
|
1052 |
-
|
1053 |
-
|
1054 |
-
|
1055 |
-
|
1056 |
-
|
1057 |
-
|
1058 |
-
|
1059 |
-
|
1060 |
-
|
1061 |
-
|
1062 |
-
|
1063 |
-
|
1064 |
-
|
1065 |
-
|
1066 |
-
|
1067 |
-
|
1068 |
-
|
1069 |
-
|
1070 |
-
|
1071 |
-
|
1072 |
-
|
1073 |
-
|
1074 |
-
|
1075 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1076 |
if self.cell_states_to_model is None:
|
1077 |
-
|
1078 |
-
|
1079 |
-
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
gene_list = filtered_input_data[j]["input_ids"]
|
1086 |
-
indices_removed = indices_to_perturb[j]
|
1087 |
-
padding_to_remove = max_padded_len - (original_length \
|
1088 |
-
- len(self.tokens_to_perturb) \
|
1089 |
-
- len(indices_removed))
|
1090 |
-
nonpadding_cos_sims_data = cos_sims_data[j][:-padding_to_remove]
|
1091 |
-
cell_cos_sim = torch.mean(nonpadding_cos_sims_data).item()
|
1092 |
-
cos_sims_dict[(perturbed_genes, "cell_emb")] += [cell_cos_sim]
|
1093 |
-
|
1094 |
-
if self.emb_mode == "cell_and_gene":
|
1095 |
-
for k in range(cos_sims_data.shape[1]):
|
1096 |
-
cos_sim_value = nonpadding_cos_sims_data[k]
|
1097 |
-
affected_gene = gene_list[k].item()
|
1098 |
-
cos_sims_dict[(perturbed_genes, affected_gene)] += [cos_sim_value.item()]
|
1099 |
else:
|
1100 |
-
|
1101 |
-
|
1102 |
-
|
1103 |
-
|
1104 |
-
|
1105 |
-
|
1106 |
-
|
1107 |
-
|
1108 |
-
|
1109 |
-
|
1110 |
-
|
1111 |
-
|
1112 |
-
|
1113 |
-
|
1114 |
-
|
1115 |
-
|
1116 |
-
|
1117 |
-
|
1118 |
-
|
1119 |
-
|
1120 |
-
|
1121 |
-
|
1122 |
-
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
if self.anchor_token is None:
|
1127 |
-
for combo_lvl in range(self.combos+1):
|
1128 |
-
perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
|
1129 |
-
self.perturb_type,
|
1130 |
-
self.tokens_to_perturb,
|
1131 |
-
self.anchor_token,
|
1132 |
-
combo_lvl,
|
1133 |
-
self.nproc)
|
1134 |
-
cos_sims_data = quant_cos_sims(model,
|
1135 |
-
self.perturb_type,
|
1136 |
-
perturbation_batch,
|
1137 |
-
self.forward_batch_size,
|
1138 |
-
layer_to_quant,
|
1139 |
-
original_emb,
|
1140 |
-
self.tokens_to_perturb,
|
1141 |
-
indices_to_perturb,
|
1142 |
-
self.perturb_group,
|
1143 |
-
self.cell_states_to_model,
|
1144 |
-
state_embs_dict,
|
1145 |
-
self.pad_token_id,
|
1146 |
-
model_input_size,
|
1147 |
-
self.nproc)
|
1148 |
-
|
1149 |
-
if self.cell_states_to_model is None:
|
1150 |
-
# update cos sims dict
|
1151 |
-
# key is tuple of (perturbed_gene, affected_gene)
|
1152 |
-
# or (perturbed_gene, "cell_emb") for avg cell emb change
|
1153 |
-
cos_sims_data = cos_sims_data.to("cuda")
|
1154 |
-
for j in range(cos_sims_data.shape[0]):
|
1155 |
-
if self.tokens_to_perturb != "all":
|
1156 |
-
j_index = torch.tensor(indices_to_perturb[j])
|
1157 |
-
if j_index.shape[0]>1:
|
1158 |
-
j_index = torch.squeeze(j_index)
|
1159 |
-
else:
|
1160 |
-
j_index = torch.tensor([j])
|
1161 |
-
|
1162 |
-
if self.perturb_type in ("overexpress", "activate"):
|
1163 |
-
perturbed_gene = torch.index_select(gene_list, 0, j_index + 1)
|
1164 |
-
else:
|
1165 |
-
perturbed_gene = torch.index_select(gene_list, 0, j_index)
|
1166 |
-
|
1167 |
-
if perturbed_gene.shape[0]==1:
|
1168 |
-
perturbed_gene = perturbed_gene.item()
|
1169 |
-
elif perturbed_gene.shape[0]>1:
|
1170 |
-
perturbed_gene = tuple(perturbed_gene.tolist())
|
1171 |
-
|
1172 |
-
cell_cos_sim = torch.mean(cos_sims_data[j]).item()
|
1173 |
-
cos_sims_dict[(perturbed_gene, "cell_emb")] += [cell_cos_sim]
|
1174 |
-
|
1175 |
-
# not_j_index = list(set(i for i in range(gene_list.shape[0])).difference(j_index))
|
1176 |
-
# gene_list_j = torch.index_select(gene_list, 0, j_index)
|
1177 |
-
if self.emb_mode == "cell_and_gene":
|
1178 |
-
for k in range(cos_sims_data.shape[1]):
|
1179 |
-
cos_sim_value = cos_sims_data[j][k]
|
1180 |
-
affected_gene = gene_list[k].item()
|
1181 |
-
cos_sims_dict[(perturbed_gene, affected_gene)] += [cos_sim_value.item()]
|
1182 |
-
else:
|
1183 |
-
# update cos sims dict
|
1184 |
-
# key is tuple of (perturbed_gene, "cell_emb")
|
1185 |
-
# value is list of tuples of cos sims for cell_states_to_model
|
1186 |
-
origin_state_key = self.cell_states_to_model["start_state"]
|
1187 |
-
cos_sims_origin = cos_sims_data[origin_state_key]
|
1188 |
-
|
1189 |
-
for j in range(cos_sims_origin.shape[0]):
|
1190 |
-
if (self.tokens_to_perturb != "all") or (combo_lvl>0):
|
1191 |
-
j_index = torch.tensor(indices_to_perturb[j])
|
1192 |
-
if j_index.shape[0]>1:
|
1193 |
-
j_index = torch.squeeze(j_index)
|
1194 |
-
else:
|
1195 |
-
j_index = torch.tensor([j])
|
1196 |
-
|
1197 |
-
if self.perturb_type in ("overexpress", "activate"):
|
1198 |
-
perturbed_gene = torch.index_select(gene_list, 0, j_index + 1)
|
1199 |
-
else:
|
1200 |
-
perturbed_gene = torch.index_select(gene_list, 0, j_index)
|
1201 |
-
|
1202 |
-
if perturbed_gene.shape[0]==1:
|
1203 |
-
perturbed_gene = perturbed_gene.item()
|
1204 |
-
elif perturbed_gene.shape[0]>1:
|
1205 |
-
perturbed_gene = tuple(perturbed_gene.tolist())
|
1206 |
-
|
1207 |
-
data_list = []
|
1208 |
-
for data in list(cos_sims_data.values()):
|
1209 |
-
data_item = data.to("cuda")
|
1210 |
-
cell_data = torch.mean(data_item[j]).item()
|
1211 |
-
data_list += [cell_data]
|
1212 |
-
cos_sims_dict[(perturbed_gene, "cell_emb")] += [tuple(data_list)]
|
1213 |
-
|
1214 |
-
elif self.anchor_token is not None:
|
1215 |
-
perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
|
1216 |
-
self.perturb_type,
|
1217 |
-
self.tokens_to_perturb,
|
1218 |
-
None, # first run without anchor token to test individual gene perturbations
|
1219 |
-
0,
|
1220 |
-
self.nproc)
|
1221 |
-
cos_sims_data = quant_cos_sims(model,
|
1222 |
-
self.perturb_type,
|
1223 |
-
perturbation_batch,
|
1224 |
-
self.forward_batch_size,
|
1225 |
-
layer_to_quant,
|
1226 |
-
original_emb,
|
1227 |
-
self.tokens_to_perturb,
|
1228 |
-
indices_to_perturb,
|
1229 |
-
self.perturb_group,
|
1230 |
-
self.cell_states_to_model,
|
1231 |
-
state_embs_dict,
|
1232 |
-
self.pad_token_id,
|
1233 |
-
model_input_size,
|
1234 |
-
self.nproc)
|
1235 |
-
cos_sims_data = cos_sims_data.to("cuda")
|
1236 |
-
|
1237 |
-
combo_perturbation_batch, combo_indices_to_perturb = make_perturbation_batch(example_cell,
|
1238 |
-
self.perturb_type,
|
1239 |
-
self.tokens_to_perturb,
|
1240 |
-
self.anchor_token,
|
1241 |
-
1,
|
1242 |
-
self.nproc)
|
1243 |
-
combo_cos_sims_data = quant_cos_sims(model,
|
1244 |
-
self.perturb_type,
|
1245 |
-
combo_perturbation_batch,
|
1246 |
-
self.forward_batch_size,
|
1247 |
-
layer_to_quant,
|
1248 |
-
original_emb,
|
1249 |
-
self.tokens_to_perturb,
|
1250 |
-
combo_indices_to_perturb,
|
1251 |
-
self.perturb_group,
|
1252 |
-
self.cell_states_to_model,
|
1253 |
-
state_embs_dict,
|
1254 |
-
self.pad_token_id,
|
1255 |
-
model_input_size,
|
1256 |
-
self.nproc)
|
1257 |
-
combo_cos_sims_data = combo_cos_sims_data.to("cuda")
|
1258 |
-
|
1259 |
-
# update cos sims dict
|
1260 |
-
# key is tuple of (perturbed_gene, "cell_emb") for avg cell emb change
|
1261 |
-
anchor_index = example_cell["input_ids"][0].index(self.anchor_token[0])
|
1262 |
-
anchor_cell_cos_sim = torch.mean(cos_sims_data[anchor_index]).item()
|
1263 |
-
non_anchor_indices = [k for k in range(cos_sims_data.shape[0]) if k != anchor_index]
|
1264 |
-
cos_sims_data = cos_sims_data[non_anchor_indices,:]
|
1265 |
-
|
1266 |
-
for j in range(cos_sims_data.shape[0]):
|
1267 |
-
|
1268 |
-
if j<anchor_index:
|
1269 |
-
j_index = torch.tensor([j])
|
1270 |
-
else:
|
1271 |
-
j_index = torch.tensor([j+1])
|
1272 |
-
|
1273 |
-
perturbed_gene = torch.index_select(gene_list, 0, j_index)
|
1274 |
-
perturbed_gene = perturbed_gene.item()
|
1275 |
-
|
1276 |
-
cell_cos_sim = torch.mean(cos_sims_data[j]).item()
|
1277 |
-
combo_cos_sim = torch.mean(combo_cos_sims_data[j]).item()
|
1278 |
-
cos_sims_dict[(perturbed_gene, "cell_emb")] += [(anchor_cell_cos_sim, # cos sim anchor gene alone
|
1279 |
-
cell_cos_sim, # cos sim deleted gene alone
|
1280 |
-
combo_cos_sim)] # cos sim anchor gene + deleted gene
|
1281 |
-
|
1282 |
-
# save dict to disk every 100 cells
|
1283 |
-
if (i/100).is_integer():
|
1284 |
-
with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
|
1285 |
-
pickle.dump(cos_sims_dict, fp)
|
1286 |
-
# reset and clear memory every 1000 cells
|
1287 |
-
if (i/1000).is_integer():
|
1288 |
-
pickle_batch = pickle_batch+1
|
1289 |
-
# clear memory
|
1290 |
-
del perturbed_gene
|
1291 |
-
del cos_sims_data
|
1292 |
-
if self.cell_states_to_model is None:
|
1293 |
-
del cell_cos_sim
|
1294 |
-
if self.cell_states_to_model is not None:
|
1295 |
-
del cell_data
|
1296 |
-
del data_list
|
1297 |
-
elif self.anchor_token is None:
|
1298 |
-
if self.emb_mode == "cell_and_gene":
|
1299 |
-
del affected_gene
|
1300 |
-
del cos_sim_value
|
1301 |
-
else:
|
1302 |
-
del combo_cos_sim
|
1303 |
-
del combo_cos_sims_data
|
1304 |
-
# reset dict
|
1305 |
-
del cos_sims_dict
|
1306 |
cos_sims_dict = defaultdict(list)
|
1307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1308 |
|
1309 |
-
|
1310 |
-
with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
|
1311 |
-
pickle.dump(cos_sims_dict, fp)
|
|
|
8 |
genes_to_perturb="all",
|
9 |
combos=0,
|
10 |
anchor_gene=None,
|
11 |
+
model_type="CellClassifier",
|
12 |
num_classes=0,
|
13 |
emb_mode="cell",
|
14 |
cell_emb_style="mean_pool",
|
15 |
filter_data={"cell_type":["cardiomyocyte"]},
|
16 |
cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
|
17 |
+
state_embs_dict ={"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2},
|
18 |
max_ncells=None,
|
19 |
+
emb_layer=0,
|
20 |
forward_batch_size=100,
|
21 |
+
nproc=16)
|
22 |
isp.perturb_data("path/to/model",
|
23 |
"path/to/input_data",
|
24 |
"path/to/output_directory",
|
25 |
"output_prefix")
|
26 |
"""
|
27 |
|
|
|
|
|
28 |
import logging
|
29 |
+
|
30 |
+
# imports
|
31 |
+
import os
|
32 |
import pickle
|
|
|
|
|
|
|
33 |
from collections import defaultdict
|
34 |
+
|
35 |
+
import seaborn as sns
|
36 |
+
import torch
|
37 |
+
from datasets import Dataset
|
38 |
from tqdm.auto import trange
|
|
|
39 |
|
40 |
+
from . import perturber_utils as pu
|
41 |
+
from .emb_extractor import get_embs
|
42 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
43 |
|
44 |
+
sns.set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
|
|
|
|
|
|
|
|
46 |
|
47 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
class InSilicoPerturber:
|
51 |
valid_option_dict = {
|
52 |
+
"perturb_type": {"delete", "overexpress", "inhibit", "activate"},
|
53 |
"perturb_rank_shift": {None, 1, 2, 3},
|
54 |
"genes_to_perturb": {"all", list},
|
55 |
"combos": {0, 1},
|
56 |
"anchor_gene": {None, str},
|
57 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
58 |
"num_classes": {int},
|
59 |
+
"emb_mode": {"cell", "cell_and_gene"},
|
60 |
"cell_emb_style": {"mean_pool"},
|
61 |
"filter_data": {None, dict},
|
62 |
"cell_states_to_model": {None, dict},
|
63 |
+
"state_embs_dict": {None, dict},
|
64 |
"max_ncells": {None, int},
|
65 |
"cell_inds_to_perturb": {"all", dict},
|
66 |
"emb_layer": {-1, 0},
|
67 |
"forward_batch_size": {int},
|
68 |
"nproc": {int},
|
69 |
}
|
70 |
+
|
71 |
def __init__(
|
72 |
self,
|
73 |
perturb_type="delete",
|
|
|
81 |
cell_emb_style="mean_pool",
|
82 |
filter_data=None,
|
83 |
cell_states_to_model=None,
|
84 |
+
state_embs_dict=None,
|
85 |
max_ncells=None,
|
86 |
cell_inds_to_perturb="all",
|
87 |
emb_layer=-1,
|
|
|
129 |
For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
130 |
emb_mode : {"cell","cell_and_gene"}
|
131 |
Whether to output impact of perturbation on cell and/or gene embeddings.
|
132 |
+
Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
133 |
cell_emb_style : "mean_pool"
|
134 |
Method for summarizing cell embeddings.
|
135 |
Currently only option is mean pooling of gene embeddings for given cell.
|
136 |
filter_data : None, dict
|
137 |
Default is to use all input data for in silico perturbation study.
|
138 |
Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
139 |
+
cell_states_to_model : None, dict
|
140 |
Cell states to model if testing perturbations that achieve goal state change.
|
141 |
Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
142 |
state_key: key specifying name of column in .dataset that defines the start/goal states
|
|
|
147 |
"start_state": "dcm",
|
148 |
"goal_state": "nf",
|
149 |
"alt_states": ["hcm", "other1", "other2"]}
|
150 |
+
state_embs_dict : None, dict
|
151 |
+
Embedding positions of each cell state to model shifts from/towards (e.g. mean or median).
|
152 |
+
Dictionary with keys specifying each possible cell state to model.
|
153 |
+
Values are target embedding positions as torch.tensor.
|
154 |
+
For example: {"nf": emb_nf,
|
155 |
+
"hcm": emb_hcm,
|
156 |
+
"dcm": emb_dcm,
|
157 |
+
"other1": emb_other1,
|
158 |
+
"other2": emb_other2}
|
159 |
max_ncells : None, int
|
160 |
Maximum number of cells to test.
|
161 |
If None, will test all cells.
|
|
|
168 |
Useful for splitting extremely large datasets across separate GPUs.
|
169 |
emb_layer : {-1, 0}
|
170 |
Embedding layer to use for quantification.
|
171 |
+
0: last layer (recommended for questions closely tied to model's training objective)
|
172 |
+
-1: 2nd to last layer (recommended for questions requiring more general representations)
|
173 |
forward_batch_size : int
|
174 |
Batch size for forward pass.
|
175 |
nproc : int
|
|
|
184 |
self.combos = combos
|
185 |
self.anchor_gene = anchor_gene
|
186 |
if self.genes_to_perturb == "all":
|
187 |
+
self.perturb_group = False
|
188 |
else:
|
189 |
self.perturb_group = True
|
190 |
+
if (self.anchor_gene is not None) or (self.combos != 0):
|
191 |
self.anchor_gene = None
|
192 |
self.combos = 0
|
193 |
logger.warning(
|
194 |
+
"anchor_gene set to None and combos set to 0. "
|
195 |
+
"If providing list of genes to perturb, "
|
196 |
+
"list of genes_to_perturb will be perturbed together, "
|
197 |
+
"without anchor gene or combinations."
|
198 |
+
)
|
199 |
self.model_type = model_type
|
200 |
self.num_classes = num_classes
|
201 |
self.emb_mode = emb_mode
|
202 |
self.cell_emb_style = cell_emb_style
|
203 |
self.filter_data = filter_data
|
204 |
self.cell_states_to_model = cell_states_to_model
|
205 |
+
self.state_embs_dict = state_embs_dict
|
206 |
self.max_ncells = max_ncells
|
207 |
self.cell_inds_to_perturb = cell_inds_to_perturb
|
208 |
self.emb_layer = emb_layer
|
|
|
223 |
try:
|
224 |
self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
|
225 |
except KeyError:
|
226 |
+
logger.error(f"Anchor gene {self.anchor_gene} not in token dictionary.")
|
|
|
|
|
227 |
raise
|
228 |
|
229 |
if self.genes_to_perturb == "all":
|
230 |
self.tokens_to_perturb = "all"
|
231 |
else:
|
232 |
+
missing_genes = [
|
233 |
+
gene
|
234 |
+
for gene in self.genes_to_perturb
|
235 |
+
if gene not in self.gene_token_dict.keys()
|
236 |
+
]
|
237 |
if len(missing_genes) == len(self.genes_to_perturb):
|
238 |
logger.error(
|
239 |
"None of the provided genes to perturb are in token dictionary."
|
240 |
)
|
241 |
raise
|
242 |
+
elif len(missing_genes) > 0:
|
243 |
logger.warning(
|
244 |
+
f"Genes to perturb {missing_genes} are not in token dictionary."
|
245 |
+
)
|
246 |
+
self.tokens_to_perturb = [
|
247 |
+
self.gene_token_dict.get(gene) for gene in self.genes_to_perturb
|
248 |
+
]
|
249 |
|
250 |
def validate_options(self):
|
251 |
# first disallow options under development
|
252 |
if self.perturb_type in ["inhibit", "activate"]:
|
253 |
logger.error(
|
254 |
+
"In silico inhibition and activation currently under development. "
|
255 |
"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
256 |
)
|
257 |
raise
|
258 |
+
if (self.combos > 0) and (self.anchor_token is None):
|
259 |
+
logger.error(
|
260 |
+
"Combination perturbation without anchor gene is currently under development. "
|
261 |
+
"Currently, must provide anchor gene for combination perturbation."
|
262 |
+
)
|
263 |
+
raise
|
264 |
+
|
265 |
# confirm arguments are within valid options and compatible with each other
|
266 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
267 |
attr_value = self.__dict__[attr_name]
|
268 |
if type(attr_value) not in {list, dict}:
|
269 |
if attr_value in valid_options:
|
|
|
273 |
continue
|
274 |
valid_type = False
|
275 |
for option in valid_options:
|
276 |
+
if (option in [bool, int, list, dict]) and isinstance(
|
277 |
+
attr_value, option
|
278 |
+
):
|
279 |
valid_type = True
|
280 |
break
|
281 |
if valid_type:
|
282 |
continue
|
283 |
logger.error(
|
284 |
+
f"Invalid option for {attr_name}. "
|
285 |
f"Valid options for {attr_name}: {valid_options}"
|
286 |
)
|
287 |
raise
|
288 |
+
|
289 |
+
if self.perturb_type in ["delete", "overexpress"]:
|
290 |
if self.perturb_rank_shift is not None:
|
291 |
if self.perturb_type == "delete":
|
292 |
logger.warning(
|
293 |
+
"perturb_rank_shift set to None. "
|
294 |
+
"If perturb type is delete then gene is deleted entirely "
|
295 |
+
"rather than shifted by quartile"
|
296 |
+
)
|
297 |
elif self.perturb_type == "overexpress":
|
298 |
logger.warning(
|
299 |
+
"perturb_rank_shift set to None. "
|
300 |
+
"If perturb type is overexpress then gene is moved to front "
|
301 |
+
"of rank value encoding rather than shifted by quartile"
|
302 |
+
)
|
303 |
self.perturb_rank_shift = None
|
304 |
+
|
305 |
if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
|
306 |
self.emb_mode = "cell"
|
307 |
logger.warning(
|
308 |
+
"emb_mode set to 'cell'. "
|
309 |
+
"Currently, analysis with anchor gene "
|
310 |
+
"only outputs effect on cell embeddings."
|
311 |
+
)
|
312 |
+
|
313 |
if self.cell_states_to_model is not None:
|
314 |
+
pu.validate_cell_states_to_model(self.cell_states_to_model)
|
315 |
+
|
316 |
+
if self.anchor_gene is not None:
|
317 |
+
self.anchor_gene = None
|
318 |
logger.warning(
|
319 |
+
"anchor_gene set to None. "
|
320 |
+
"Currently, anchor gene not available "
|
321 |
+
"when modeling multiple cell states."
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
)
|
323 |
+
|
324 |
+
if self.state_embs_dict is None:
|
325 |
+
logger.error(
|
326 |
+
"state_embs_dict must be provided for mode with cell_states_to_model. "
|
327 |
+
"Format is dictionary with keys specifying each possible cell state to model. "
|
328 |
+
"Values are target embedding positions as torch.tensor."
|
329 |
+
)
|
330 |
+
raise
|
331 |
+
|
332 |
+
for state_emb in self.state_embs_dict.values():
|
333 |
+
if not torch.is_tensor(state_emb):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
logger.error(
|
335 |
+
"state_embs_dict must be dictionary with values being torch.tensor."
|
336 |
+
)
|
337 |
raise
|
338 |
|
339 |
+
keys_absent = []
|
340 |
+
for k, v in self.cell_states_to_model.items():
|
341 |
+
if (k == "start_state") or (k == "goal_state"):
|
342 |
+
if v not in self.state_embs_dict.keys():
|
343 |
+
keys_absent.append(v)
|
344 |
+
if k == "alt_states":
|
345 |
+
for state in v:
|
346 |
+
if state not in self.state_embs_dict.keys():
|
347 |
+
keys_absent.append(state)
|
348 |
+
if len(keys_absent) > 0:
|
|
|
|
|
349 |
logger.error(
|
350 |
+
"Each start_state, goal_state, and alt_states in cell_states_to_model "
|
351 |
+
"must be a key in state_embs_dict with the value being "
|
352 |
+
"the state's embedding position as torch.tensor. "
|
353 |
+
f"Missing keys: {keys_absent}"
|
|
|
|
|
|
|
354 |
)
|
355 |
raise
|
356 |
|
357 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
if self.perturb_rank_shift is None:
|
359 |
logger.error(
|
360 |
+
"If perturb_type is inhibit or activate then "
|
361 |
+
"quartile to shift by must be specified."
|
362 |
+
)
|
363 |
raise
|
364 |
+
|
365 |
if self.filter_data is not None:
|
366 |
+
for key, value in self.filter_data.items():
|
367 |
+
if not isinstance(value, list):
|
368 |
self.filter_data[key] = [value]
|
369 |
logger.warning(
|
370 |
+
"Values in filter_data dict must be lists. "
|
371 |
+
f"Changing {key} value to list ([{value}])."
|
372 |
+
)
|
373 |
+
|
374 |
if self.cell_inds_to_perturb != "all":
|
375 |
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
376 |
logger.error(
|
377 |
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
378 |
)
|
379 |
raise
|
380 |
+
if (
|
381 |
+
self.cell_inds_to_perturb["start"] < 0
|
382 |
+
or self.cell_inds_to_perturb["end"] < 0
|
383 |
+
):
|
384 |
+
logger.error("cell_inds_to_perturb must be positive.")
|
385 |
raise
|
386 |
|
387 |
+
def perturb_data(
|
388 |
+
self, model_directory, input_data_file, output_directory, output_prefix
|
389 |
+
):
|
|
|
|
|
390 |
"""
|
391 |
Perturb genes in input data and save as results in output_directory.
|
392 |
|
|
|
402 |
Prefix for output files
|
403 |
"""
|
404 |
|
405 |
+
### format output path ###
|
406 |
+
output_path_prefix = os.path.join(
|
407 |
+
output_directory, f"in_silico_{self.perturb_type}_{output_prefix}"
|
408 |
+
)
|
409 |
+
|
410 |
+
### load model and define parameters ###
|
411 |
+
model = pu.load_model(self.model_type, self.num_classes, model_directory)
|
412 |
+
self.max_len = pu.get_model_input_size(model)
|
413 |
+
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
414 |
+
|
415 |
+
### filter input data ###
|
416 |
+
# general filtering of input data based on filter_data argument
|
417 |
+
filtered_input_data = pu.load_and_filter(
|
418 |
+
self.filter_data, self.nproc, input_data_file
|
419 |
+
)
|
420 |
+
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
421 |
+
|
422 |
+
if self.perturb_group is True:
|
423 |
+
self.isp_perturb_set(
|
424 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
425 |
+
)
|
426 |
else:
|
427 |
+
self.isp_perturb_all(
|
428 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
429 |
+
)
|
430 |
+
|
431 |
+
def apply_additional_filters(self, filtered_input_data):
|
432 |
+
# additional filtering of input data dependent on isp mode
|
433 |
+
if self.cell_states_to_model is not None:
|
434 |
+
# filter for cells with start_state and log result
|
435 |
+
filtered_input_data = pu.filter_data_by_start_state(
|
436 |
+
filtered_input_data, self.cell_states_to_model, self.nproc
|
437 |
+
)
|
438 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
|
440 |
+
# filter for cells with tokens_to_perturb and log result
|
441 |
+
filtered_input_data = pu.filter_data_by_tokens_and_log(
|
442 |
+
filtered_input_data,
|
443 |
+
self.tokens_to_perturb,
|
444 |
+
self.nproc,
|
445 |
+
"genes_to_perturb",
|
446 |
+
)
|
447 |
+
|
448 |
+
if self.anchor_token is not None:
|
449 |
+
# filter for cells with anchor gene and log result
|
450 |
+
filtered_input_data = pu.filter_data_by_tokens_and_log(
|
451 |
+
filtered_input_data, self.anchor_token, self.nproc, "anchor_gene"
|
452 |
+
)
|
453 |
+
|
454 |
+
# downsample and sort largest to smallest to encounter memory constraints earlier
|
455 |
+
filtered_input_data = pu.downsample_and_sort(
|
456 |
+
filtered_input_data, self.max_ncells
|
457 |
+
)
|
458 |
+
|
459 |
+
# slice dataset if cells_inds_to_perturb is not "all"
|
460 |
if self.cell_inds_to_perturb != "all":
|
461 |
+
filtered_input_data = pu.slice_by_inds_to_perturb(
|
462 |
+
filtered_input_data, self.cell_inds_to_perturb
|
463 |
+
)
|
464 |
+
|
465 |
+
return filtered_input_data
|
466 |
+
|
467 |
+
def isp_perturb_set(
|
468 |
+
self,
|
469 |
+
model,
|
470 |
+
filtered_input_data: Dataset,
|
471 |
+
layer_to_quant: int,
|
472 |
+
output_path_prefix: str,
|
473 |
+
):
|
474 |
+
def make_group_perturbation_batch(example):
|
475 |
+
example_input_ids = example["input_ids"]
|
476 |
+
example["tokens_to_perturb"] = self.tokens_to_perturb
|
477 |
+
indices_to_perturb = [
|
478 |
+
example_input_ids.index(token) if token in example_input_ids else None
|
479 |
+
for token in self.tokens_to_perturb
|
480 |
+
]
|
481 |
+
indices_to_perturb = [
|
482 |
+
item for item in indices_to_perturb if item is not None
|
483 |
+
]
|
484 |
+
if len(indices_to_perturb) > 0:
|
485 |
+
example["perturb_index"] = indices_to_perturb
|
486 |
+
else:
|
487 |
+
# -100 indicates tokens to overexpress are not present in rank value encoding
|
488 |
+
example["perturb_index"] = [-100]
|
489 |
+
if self.perturb_type == "delete":
|
490 |
+
example = pu.delete_indices(example)
|
491 |
+
elif self.perturb_type == "overexpress":
|
492 |
+
example = pu.overexpress_tokens(example, self.max_len)
|
493 |
+
example["n_overflow"] = pu.calc_n_overflow(
|
494 |
+
self.max_len,
|
495 |
+
example["length"],
|
496 |
+
self.tokens_to_perturb,
|
497 |
+
indices_to_perturb,
|
498 |
+
)
|
499 |
+
return example
|
500 |
+
|
501 |
+
total_batch_length = len(filtered_input_data)
|
502 |
+
if self.cell_states_to_model is None:
|
503 |
+
cos_sims_dict = defaultdict(list)
|
504 |
+
else:
|
505 |
+
cos_sims_dict = {
|
506 |
+
state: defaultdict(list)
|
507 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
508 |
+
}
|
509 |
+
|
510 |
+
perturbed_data = filtered_input_data.map(
|
511 |
+
make_group_perturbation_batch, num_proc=self.nproc
|
512 |
+
)
|
513 |
+
if self.perturb_type == "overexpress":
|
514 |
+
filtered_input_data = filtered_input_data.add_column(
|
515 |
+
"n_overflow", perturbed_data["n_overflow"]
|
516 |
+
)
|
517 |
+
# remove overflow genes from original data so that embeddings are comparable
|
518 |
+
# i.e. if original cell has genes 0:2047 and you want to overexpress new gene 2048,
|
519 |
+
# then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
|
520 |
+
# (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
|
521 |
+
# rather than only adding 2048)
|
522 |
+
filtered_input_data = filtered_input_data.map(
|
523 |
+
pu.truncate_by_n_overflow, num_proc=self.nproc
|
524 |
+
)
|
525 |
+
|
526 |
+
if self.emb_mode == "cell_and_gene":
|
527 |
+
stored_gene_embs_dict = defaultdict(list)
|
528 |
+
|
529 |
+
# iterate through batches
|
530 |
+
for i in trange(0, total_batch_length, self.forward_batch_size):
|
531 |
+
max_range = min(i + self.forward_batch_size, total_batch_length)
|
532 |
+
inds_select = [i for i in range(i, max_range)]
|
533 |
+
|
534 |
+
minibatch = filtered_input_data.select(inds_select)
|
535 |
+
perturbation_batch = perturbed_data.select(inds_select)
|
536 |
+
|
537 |
+
if self.cell_emb_style == "mean_pool":
|
538 |
+
full_original_emb = get_embs(
|
539 |
+
model,
|
540 |
+
minibatch,
|
541 |
+
"gene",
|
542 |
+
layer_to_quant,
|
543 |
+
self.pad_token_id,
|
544 |
+
self.forward_batch_size,
|
545 |
+
summary_stat=None,
|
546 |
+
silent=True,
|
547 |
+
)
|
548 |
+
indices_to_perturb = perturbation_batch["perturb_index"]
|
549 |
+
# remove indices that were perturbed
|
550 |
+
original_emb = pu.remove_perturbed_indices_set(
|
551 |
+
full_original_emb,
|
552 |
+
self.perturb_type,
|
553 |
+
indices_to_perturb,
|
554 |
+
self.tokens_to_perturb,
|
555 |
+
minibatch["length"],
|
556 |
+
)
|
557 |
+
full_perturbation_emb = get_embs(
|
558 |
+
model,
|
559 |
+
perturbation_batch,
|
560 |
+
"gene",
|
561 |
+
layer_to_quant,
|
562 |
+
self.pad_token_id,
|
563 |
+
self.forward_batch_size,
|
564 |
+
summary_stat=None,
|
565 |
+
silent=True,
|
566 |
+
)
|
567 |
+
|
568 |
+
# remove overexpressed genes
|
569 |
+
if self.perturb_type == "overexpress":
|
570 |
+
perturbation_emb = full_perturbation_emb[
|
571 |
+
:, len(self.tokens_to_perturb) :, :
|
572 |
+
]
|
573 |
+
|
574 |
+
elif self.perturb_type == "delete":
|
575 |
+
perturbation_emb = full_perturbation_emb[
|
576 |
+
:, : max(perturbation_batch["length"]), :
|
577 |
+
]
|
578 |
+
|
579 |
+
n_perturbation_genes = perturbation_emb.size()[1]
|
580 |
+
|
581 |
+
# if no goal states, the cosine similarties are the mean of gene cosine similarities
|
582 |
+
if (
|
583 |
+
self.cell_states_to_model is None
|
584 |
+
or self.emb_mode == "cell_and_gene"
|
585 |
+
):
|
586 |
+
gene_cos_sims = pu.quant_cos_sims(
|
587 |
+
perturbation_emb,
|
588 |
+
original_emb,
|
589 |
+
self.cell_states_to_model,
|
590 |
+
self.state_embs_dict,
|
591 |
+
emb_mode="gene",
|
592 |
+
)
|
593 |
+
|
594 |
+
# if there are goal states, the cosine similarities are the cell cosine similarities
|
595 |
+
if self.cell_states_to_model is not None:
|
596 |
+
original_cell_emb = pu.mean_nonpadding_embs(
|
597 |
+
full_original_emb,
|
598 |
+
torch.tensor(minibatch["length"], device="cuda"),
|
599 |
+
dim=1,
|
600 |
+
)
|
601 |
+
perturbation_cell_emb = pu.mean_nonpadding_embs(
|
602 |
+
full_perturbation_emb,
|
603 |
+
torch.tensor(perturbation_batch["length"], device="cuda"),
|
604 |
+
dim=1,
|
605 |
+
)
|
606 |
+
cell_cos_sims = pu.quant_cos_sims(
|
607 |
+
perturbation_cell_emb,
|
608 |
+
original_cell_emb,
|
609 |
+
self.cell_states_to_model,
|
610 |
+
self.state_embs_dict,
|
611 |
+
emb_mode="cell",
|
612 |
+
)
|
613 |
+
|
614 |
+
# get cosine similarities in gene embeddings
|
615 |
+
# if getting gene embeddings, need gene names
|
616 |
+
if self.emb_mode == "cell_and_gene":
|
617 |
+
gene_list = minibatch["input_ids"]
|
618 |
+
# need to truncate gene_list
|
619 |
+
gene_list = [
|
620 |
+
[g for g in genes if g not in self.tokens_to_perturb][
|
621 |
+
:n_perturbation_genes
|
622 |
+
]
|
623 |
+
for genes in gene_list
|
624 |
+
]
|
625 |
+
|
626 |
+
for cell_i, genes in enumerate(gene_list):
|
627 |
+
for gene_j, affected_gene in enumerate(genes):
|
628 |
+
if len(self.genes_to_perturb) > 1:
|
629 |
+
tokens_to_perturb = tuple(self.tokens_to_perturb)
|
630 |
+
else:
|
631 |
+
tokens_to_perturb = self.tokens_to_perturb
|
632 |
+
|
633 |
+
# fill in the gene cosine similarities
|
634 |
+
try:
|
635 |
+
stored_gene_embs_dict[
|
636 |
+
(tokens_to_perturb, affected_gene)
|
637 |
+
].append(gene_cos_sims[cell_i, gene_j].item())
|
638 |
+
except KeyError:
|
639 |
+
stored_gene_embs_dict[
|
640 |
+
(tokens_to_perturb, affected_gene)
|
641 |
+
] = gene_cos_sims[cell_i, gene_j].item()
|
642 |
else:
|
643 |
+
gene_list = None
|
644 |
+
|
645 |
+
if self.cell_states_to_model is None:
|
646 |
+
# calculate the mean of the gene cosine similarities for cell shift
|
647 |
+
# tensor of nonpadding lengths for each cell
|
648 |
+
if self.perturb_type == "overexpress":
|
649 |
+
# subtract number of genes that were overexpressed
|
650 |
+
# since they are removed before getting cos sims
|
651 |
+
n_overexpressed = len(self.tokens_to_perturb)
|
652 |
+
nonpadding_lens = [
|
653 |
+
x - n_overexpressed for x in perturbation_batch["length"]
|
654 |
+
]
|
655 |
+
else:
|
656 |
+
nonpadding_lens = perturbation_batch["length"]
|
657 |
+
cos_sims_data = pu.mean_nonpadding_embs(
|
658 |
+
gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
|
659 |
+
)
|
660 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
661 |
+
cos_sims_dict,
|
662 |
+
cos_sims_data,
|
663 |
+
filtered_input_data,
|
664 |
+
indices_to_perturb,
|
665 |
+
gene_list,
|
666 |
+
)
|
667 |
+
else:
|
668 |
+
cos_sims_data = cell_cos_sims
|
669 |
+
for state in cos_sims_dict.keys():
|
670 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
671 |
+
cos_sims_dict[state],
|
672 |
+
cos_sims_data[state],
|
673 |
+
filtered_input_data,
|
674 |
+
indices_to_perturb,
|
675 |
+
gene_list,
|
676 |
+
)
|
677 |
+
del minibatch
|
678 |
+
del perturbation_batch
|
679 |
+
del original_emb
|
680 |
+
del perturbation_emb
|
681 |
+
del cos_sims_data
|
682 |
+
|
683 |
+
torch.cuda.empty_cache()
|
684 |
+
|
685 |
+
pu.write_perturbation_dictionary(
|
686 |
+
cos_sims_dict,
|
687 |
+
f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
|
688 |
+
)
|
689 |
+
|
690 |
+
if self.emb_mode == "cell_and_gene":
|
691 |
+
pu.write_perturbation_dictionary(
|
692 |
+
stored_gene_embs_dict,
|
693 |
+
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
694 |
+
)
|
695 |
+
|
696 |
+
def isp_perturb_all(
|
697 |
+
self,
|
698 |
+
model,
|
699 |
+
filtered_input_data: Dataset,
|
700 |
+
layer_to_quant: int,
|
701 |
+
output_path_prefix: str,
|
702 |
+
):
|
703 |
+
pickle_batch = -1
|
704 |
+
if self.cell_states_to_model is None:
|
705 |
+
cos_sims_dict = defaultdict(list)
|
706 |
+
else:
|
707 |
+
cos_sims_dict = {
|
708 |
+
state: defaultdict(list)
|
709 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
710 |
+
}
|
711 |
+
|
712 |
+
if self.emb_mode == "cell_and_gene":
|
713 |
+
stored_gene_embs_dict = defaultdict(list)
|
714 |
+
for i in trange(len(filtered_input_data)):
|
715 |
+
example_cell = filtered_input_data.select([i])
|
716 |
+
full_original_emb = get_embs(
|
717 |
+
model,
|
718 |
+
example_cell,
|
719 |
+
"gene",
|
720 |
+
layer_to_quant,
|
721 |
+
self.pad_token_id,
|
722 |
+
self.forward_batch_size,
|
723 |
+
summary_stat=None,
|
724 |
+
silent=True,
|
725 |
+
)
|
726 |
+
|
727 |
+
# gene_list is used to assign cos sims back to genes
|
728 |
+
# need to remove the anchor gene
|
729 |
+
gene_list = example_cell["input_ids"][0][:]
|
730 |
+
if self.anchor_token is not None:
|
731 |
+
for token in self.anchor_token:
|
732 |
+
gene_list.remove(token)
|
733 |
+
|
734 |
+
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
|
735 |
+
example_cell,
|
736 |
+
self.perturb_type,
|
737 |
+
self.tokens_to_perturb,
|
738 |
+
self.anchor_token,
|
739 |
+
self.combos,
|
740 |
+
self.nproc,
|
741 |
+
)
|
742 |
+
|
743 |
+
full_perturbation_emb = get_embs(
|
744 |
+
model,
|
745 |
+
perturbation_batch,
|
746 |
+
"gene",
|
747 |
+
layer_to_quant,
|
748 |
+
self.pad_token_id,
|
749 |
+
self.forward_batch_size,
|
750 |
+
summary_stat=None,
|
751 |
+
silent=True,
|
752 |
+
)
|
753 |
+
|
754 |
+
num_inds_perturbed = 1 + self.combos
|
755 |
+
# need to remove overexpressed gene to quantify cosine shifts
|
756 |
+
if self.perturb_type == "overexpress":
|
757 |
+
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
758 |
+
gene_list = gene_list[
|
759 |
+
num_inds_perturbed:
|
760 |
+
] # index 0 is not overexpressed
|
761 |
+
|
762 |
+
elif self.perturb_type == "delete":
|
763 |
+
perturbation_emb = full_perturbation_emb
|
764 |
+
|
765 |
+
original_batch = pu.make_comparison_batch(
|
766 |
+
full_original_emb, indices_to_perturb, perturb_group=False
|
767 |
+
)
|
768 |
+
|
769 |
+
if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
|
770 |
+
gene_cos_sims = pu.quant_cos_sims(
|
771 |
+
perturbation_emb,
|
772 |
+
original_batch,
|
773 |
+
self.cell_states_to_model,
|
774 |
+
self.state_embs_dict,
|
775 |
+
emb_mode="gene",
|
776 |
+
)
|
777 |
+
if self.cell_states_to_model is not None:
|
778 |
+
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
779 |
+
full_original_emb, "mean_pool"
|
780 |
+
)
|
781 |
+
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
782 |
+
full_perturbation_emb, "mean_pool"
|
783 |
+
)
|
784 |
+
|
785 |
+
cell_cos_sims = pu.quant_cos_sims(
|
786 |
+
perturbation_cell_emb,
|
787 |
+
original_cell_emb,
|
788 |
+
self.cell_states_to_model,
|
789 |
+
self.state_embs_dict,
|
790 |
+
emb_mode="cell",
|
791 |
+
)
|
792 |
+
|
793 |
+
if self.emb_mode == "cell_and_gene":
|
794 |
+
# remove perturbed index for gene list
|
795 |
+
perturbed_gene_dict = {
|
796 |
+
gene: gene_list[:i] + gene_list[i + 1 :]
|
797 |
+
for i, gene in enumerate(gene_list)
|
798 |
+
}
|
799 |
+
|
800 |
+
for perturbation_i, perturbed_gene in enumerate(gene_list):
|
801 |
+
for gene_j, affected_gene in enumerate(
|
802 |
+
perturbed_gene_dict[perturbed_gene]
|
803 |
+
):
|
804 |
+
try:
|
805 |
+
stored_gene_embs_dict[
|
806 |
+
(perturbed_gene, affected_gene)
|
807 |
+
].append(gene_cos_sims[perturbation_i, gene_j].item())
|
808 |
+
except KeyError:
|
809 |
+
stored_gene_embs_dict[
|
810 |
+
(perturbed_gene, affected_gene)
|
811 |
+
] = gene_cos_sims[perturbation_i, gene_j].item()
|
812 |
+
|
813 |
if self.cell_states_to_model is None:
|
814 |
+
cos_sims_data = torch.mean(gene_cos_sims, dim=1)
|
815 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
816 |
+
cos_sims_dict,
|
817 |
+
cos_sims_data,
|
818 |
+
filtered_input_data,
|
819 |
+
indices_to_perturb,
|
820 |
+
gene_list,
|
821 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
822 |
else:
|
823 |
+
cos_sims_data = cell_cos_sims
|
824 |
+
for state in cos_sims_dict.keys():
|
825 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
826 |
+
cos_sims_dict[state],
|
827 |
+
cos_sims_data[state],
|
828 |
+
filtered_input_data,
|
829 |
+
indices_to_perturb,
|
830 |
+
gene_list,
|
831 |
+
)
|
832 |
+
|
833 |
+
# save dict to disk every 100 cells
|
834 |
+
if i % 100 == 0:
|
835 |
+
pu.write_perturbation_dictionary(
|
836 |
+
cos_sims_dict,
|
837 |
+
f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
|
838 |
+
)
|
839 |
+
if self.emb_mode == "cell_and_gene":
|
840 |
+
pu.write_perturbation_dictionary(
|
841 |
+
stored_gene_embs_dict,
|
842 |
+
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|
843 |
+
)
|
844 |
+
|
845 |
+
# reset and clear memory every 1000 cells
|
846 |
+
if i % 1000 == 0:
|
847 |
+
pickle_batch += 1
|
848 |
+
if self.cell_states_to_model is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
849 |
cos_sims_dict = defaultdict(list)
|
850 |
+
else:
|
851 |
+
cos_sims_dict = {
|
852 |
+
state: defaultdict(list)
|
853 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
854 |
+
}
|
855 |
+
|
856 |
+
if self.emb_mode == "cell_and_gene":
|
857 |
+
stored_gene_embs_dict = defaultdict(list)
|
858 |
+
|
859 |
+
torch.cuda.empty_cache()
|
860 |
+
|
861 |
+
pu.write_perturbation_dictionary(
|
862 |
+
cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
|
863 |
+
)
|
864 |
+
|
865 |
+
if self.emb_mode == "cell_and_gene":
|
866 |
+
pu.write_perturbation_dictionary(
|
867 |
+
stored_gene_embs_dict,
|
868 |
+
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|
869 |
+
)
|
870 |
+
|
871 |
+
def update_perturbation_dictionary(
|
872 |
+
self,
|
873 |
+
cos_sims_dict: defaultdict,
|
874 |
+
cos_sims_data: torch.Tensor,
|
875 |
+
filtered_input_data: Dataset,
|
876 |
+
indices_to_perturb: list[list[int]],
|
877 |
+
gene_list=None,
|
878 |
+
):
|
879 |
+
if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
|
880 |
+
logger.error(
|
881 |
+
f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
|
882 |
+
cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
|
883 |
+
len(gene_list) = {len(gene_list)}."
|
884 |
+
)
|
885 |
+
raise
|
886 |
+
|
887 |
+
if self.perturb_group is True:
|
888 |
+
if len(self.tokens_to_perturb) > 1:
|
889 |
+
perturbed_genes = tuple(self.tokens_to_perturb)
|
890 |
+
else:
|
891 |
+
perturbed_genes = self.tokens_to_perturb[0]
|
892 |
+
|
893 |
+
# if cell embeddings, can just append
|
894 |
+
# shape will be (batch size, 1)
|
895 |
+
cos_sims_data = torch.squeeze(cos_sims_data).tolist()
|
896 |
+
|
897 |
+
# handle case of single cell left
|
898 |
+
if not isinstance(cos_sims_data, list):
|
899 |
+
cos_sims_data = [cos_sims_data]
|
900 |
+
|
901 |
+
cos_sims_dict[(perturbed_genes, "cell_emb")] += cos_sims_data
|
902 |
+
|
903 |
+
else:
|
904 |
+
for i, cos in enumerate(cos_sims_data.tolist()):
|
905 |
+
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
906 |
|
907 |
+
return cos_sims_dict
|
|
|
|
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -6,9 +6,9 @@ Usage:
|
|
6 |
ispstats = InSilicoPerturberStats(mode="goal_state_shift",
|
7 |
combos=0,
|
8 |
anchor_gene=None,
|
9 |
-
cell_states_to_model={"state_key": "disease",
|
10 |
-
"start_state": "dcm",
|
11 |
-
"goal_state": "nf",
|
12 |
"alt_states": ["hcm", "other1", "other2"]})
|
13 |
ispstats.get_stats("path/to/input_data",
|
14 |
None,
|
@@ -17,88 +17,157 @@ Usage:
|
|
17 |
"""
|
18 |
|
19 |
|
20 |
-
import os
|
21 |
import logging
|
22 |
-
import
|
23 |
-
import pandas as pd
|
24 |
import pickle
|
25 |
import random
|
26 |
-
import statsmodels.stats.multitest as smt
|
27 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
28 |
from scipy.stats import ranksums
|
29 |
from sklearn.mixture import GaussianMixture
|
30 |
-
from tqdm.auto import
|
31 |
-
|
32 |
-
from .in_silico_perturber import flatten_list
|
33 |
|
|
|
34 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
35 |
|
36 |
GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
37 |
|
38 |
logger = logging.getLogger(__name__)
|
39 |
|
|
|
40 |
# invert dictionary keys/values
|
41 |
def invert_dict(dictionary):
|
42 |
return {v: k for k, v in dictionary.items()}
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# read raw dictionary files
|
45 |
-
def read_dictionaries(
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
file_path_list = []
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
for file in os.listdir(input_data_directory):
|
50 |
-
# process only _raw.pickle
|
51 |
-
if file.endswith(
|
52 |
-
file_found =
|
53 |
file_path_list += [f"{input_data_directory}/{file}"]
|
54 |
for file_path in tqdm(file_path_list):
|
55 |
with open(file_path, "rb") as fp:
|
56 |
cos_sims_dict = pickle.load(fp)
|
57 |
-
if
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
66 |
logger.error(
|
67 |
-
|
68 |
-
|
|
|
69 |
raise
|
70 |
-
|
|
|
|
|
|
|
|
|
71 |
|
72 |
# get complete gene list
|
73 |
-
def get_gene_list(dict_list,mode):
|
74 |
if mode == "cell":
|
75 |
position = 0
|
76 |
elif mode == "gene":
|
77 |
position = 1
|
78 |
gene_set = set()
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
gene_list = list(gene_set)
|
82 |
if mode == "gene":
|
83 |
gene_list.remove("cell_emb")
|
84 |
gene_list.sort()
|
85 |
return gene_list
|
86 |
|
|
|
87 |
def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
|
88 |
-
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def n_detections(token, dict_list, mode, anchor_token):
|
91 |
cos_sim_megalist = []
|
92 |
for dict_i in dict_list:
|
93 |
if mode == "cell":
|
94 |
-
cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
|
95 |
elif mode == "gene":
|
96 |
-
cos_sim_megalist += dict_i.get((anchor_token, token),[])
|
97 |
return len(cos_sim_megalist)
|
98 |
|
|
|
99 |
def get_fdr(pvalues):
|
100 |
return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
|
101 |
|
|
|
102 |
def get_impact_component(test_value, gaussian_mixture_model):
|
103 |
impact_border = gaussian_mixture_model.means_[0][0]
|
104 |
nonimpact_border = gaussian_mixture_model.means_[1][0]
|
@@ -114,236 +183,356 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
114 |
impact_component = 1
|
115 |
return impact_component
|
116 |
|
|
|
117 |
# aggregate data for single perturbation in multiple cells
|
118 |
-
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
|
119 |
-
names=["Cosine_shift"]
|
120 |
cos_sims_full_df = pd.DataFrame(columns=names)
|
121 |
|
122 |
cos_shift_data = []
|
123 |
token = cos_sims_df["Gene"][0]
|
124 |
for dict_i in dict_list:
|
125 |
-
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
126 |
cos_sims_full_df["Cosine_shift"] = cos_shift_data
|
127 |
-
return cos_sims_full_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
130 |
-
def isp_stats_to_goal_state(
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
135 |
alt_end_state_exists = False
|
136 |
-
elif (len(cell_states_to_model["alt_states"]) > 0) and (
|
|
|
|
|
137 |
alt_end_state_exists = True
|
138 |
-
|
139 |
# for single perturbation in multiple cells, there are no random perturbations to compare to
|
140 |
if genes_perturbed != "all":
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
names.remove("Shift_to_alt_end")
|
145 |
-
cos_sims_full_df = pd.DataFrame(columns=names)
|
146 |
-
|
147 |
-
cos_shift_data = []
|
148 |
token = cos_sims_df["Gene"][0]
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
if alt_end_state_exists
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
# sort by shift to desired state
|
158 |
-
cos_sims_full_df = cos_sims_full_df.sort_values(
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
162 |
elif genes_perturbed == "all":
|
163 |
-
|
|
|
|
|
|
|
|
|
164 |
for i in trange(cos_sims_df.shape[0]):
|
165 |
token = cos_sims_df["Gene"][i]
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
|
175 |
# downsample to improve speed of ranksums
|
176 |
if len(goal_end_random_megalist) > 100_000:
|
177 |
random.seed(42)
|
178 |
-
goal_end_random_megalist = random.sample(
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
cos_sims_full_df = pd.DataFrame(columns=names)
|
195 |
|
|
|
196 |
for i in trange(cos_sims_df.shape[0]):
|
197 |
token = cos_sims_df["Gene"][i]
|
198 |
name = cos_sims_df["Gene_name"][i]
|
199 |
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
name,
|
225 |
-
ensembl_id,
|
226 |
-
mean_goal_end,
|
227 |
-
mean_alt_end,
|
228 |
-
pval_goal_end,
|
229 |
-
pval_alt_end]
|
230 |
-
|
231 |
-
cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
|
232 |
-
cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
|
233 |
-
|
234 |
-
cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
|
235 |
-
if alt_end_state_exists == True:
|
236 |
-
cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
|
237 |
|
238 |
# quantify number of detections of each gene
|
239 |
-
cos_sims_full_df["N_Detections"] = [
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
|
|
|
|
|
|
248 |
return cos_sims_full_df
|
249 |
|
|
|
250 |
# stats comparing cos sim shifts of test perturbations vs null distribution
|
251 |
def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
|
252 |
cos_sims_full_df = cos_sims_df.copy()
|
253 |
|
254 |
cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
255 |
cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
256 |
-
cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
|
|
|
|
|
257 |
cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
258 |
cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
259 |
-
cos_sims_full_df["N_Detections_test"] = np.zeros(
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
262 |
for i in trange(cos_sims_df.shape[0]):
|
263 |
token = cos_sims_df["Gene"][i]
|
264 |
test_shifts = []
|
265 |
null_shifts = []
|
266 |
-
|
267 |
for dict_i in dict_list:
|
268 |
-
test_shifts += dict_i.get((token, "cell_emb"),[])
|
269 |
|
270 |
for dict_i in null_dict_list:
|
271 |
-
null_shifts += dict_i.get((token, "cell_emb"),[])
|
272 |
-
|
273 |
cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
|
274 |
cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
|
275 |
-
cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
|
276 |
-
|
277 |
-
|
278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
|
280 |
cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
|
281 |
|
282 |
-
cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
|
|
|
|
|
|
|
|
289 |
return cos_sims_full_df
|
290 |
|
|
|
291 |
# stats for identifying perturbations with largest effect within a given set of cells
|
292 |
# fits a mixture model to 2 components (impact vs. non-impact) and
|
293 |
# reports the most likely component for each test perturbation
|
294 |
# Note: because assumes given perturbation has a consistent effect in the cells tested,
|
295 |
# we recommend only using the mixture model strategy with uniform cell populations
|
296 |
def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
|
297 |
-
|
298 |
-
|
299 |
-
"Gene_name",
|
300 |
-
"Ensembl_ID"]
|
301 |
-
|
302 |
if combos == 0:
|
303 |
names += ["Test_avg_shift"]
|
304 |
elif combos == 1:
|
305 |
-
names += [
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
|
|
313 |
|
314 |
cos_sims_full_df = pd.DataFrame(columns=names)
|
315 |
avg_values = []
|
316 |
gene_names = []
|
317 |
-
|
318 |
for i in trange(cos_sims_df.shape[0]):
|
319 |
token = cos_sims_df["Gene"][i]
|
320 |
name = cos_sims_df["Gene_name"][i]
|
321 |
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
322 |
cos_shift_data = []
|
323 |
-
|
324 |
for dict_i in dict_list:
|
325 |
if (combos == 0) and (anchor_token is not None):
|
326 |
-
cos_shift_data += dict_i.get((anchor_token, token),[])
|
327 |
else:
|
328 |
-
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
329 |
-
|
330 |
# Extract values for current gene
|
331 |
if combos == 0:
|
332 |
test_values = cos_shift_data
|
333 |
elif combos == 1:
|
334 |
test_values = []
|
335 |
for tup in cos_shift_data:
|
336 |
-
test_values.append(tup[2])
|
337 |
-
|
338 |
if len(test_values) > 0:
|
339 |
avg_value = np.mean(test_values)
|
340 |
avg_values.append(avg_value)
|
341 |
gene_names.append(name)
|
342 |
-
|
343 |
# fit Gaussian mixture model to dataset of mean for each gene
|
344 |
avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
|
345 |
gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
|
346 |
-
|
347 |
for i in trange(cos_sims_df.shape[0]):
|
348 |
token = cos_sims_df["Gene"][i]
|
349 |
name = cos_sims_df["Gene_name"][i]
|
@@ -352,71 +541,95 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
|
|
352 |
|
353 |
for dict_i in dict_list:
|
354 |
if (combos == 0) and (anchor_token is not None):
|
355 |
-
cos_shift_data += dict_i.get((anchor_token, token),[])
|
356 |
else:
|
357 |
-
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
358 |
-
|
359 |
if combos == 0:
|
360 |
mean_test = np.mean(cos_shift_data)
|
361 |
-
impact_components = [
|
|
|
|
|
362 |
elif combos == 1:
|
363 |
-
anchor_cos_sim_megalist = [
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
mean_anchor = np.mean(anchor_cos_sim_megalist)
|
370 |
mean_token = np.mean(token_cos_sim_megalist)
|
371 |
mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
|
372 |
mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
|
373 |
mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
|
374 |
-
|
375 |
-
impact_components = [
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
|
|
383 |
if combos == 0:
|
384 |
data_i += [mean_test]
|
385 |
elif combos == 1:
|
386 |
-
data_i += [
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
|
|
397 |
# quantify number of detections of each gene
|
398 |
-
cos_sims_full_df["N_Detections"] = [
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
if combos == 0:
|
404 |
-
cos_sims_full_df = cos_sims_full_df.sort_values(
|
405 |
-
|
406 |
-
|
407 |
elif combos == 1:
|
408 |
-
cos_sims_full_df = cos_sims_full_df.sort_values(
|
409 |
-
|
410 |
-
|
411 |
return cos_sims_full_df
|
412 |
|
|
|
413 |
class InSilicoPerturberStats:
|
414 |
valid_option_dict = {
|
415 |
-
"mode": {
|
416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
"anchor_gene": {None, str},
|
418 |
"cell_states_to_model": {None, dict},
|
|
|
419 |
}
|
|
|
420 |
def __init__(
|
421 |
self,
|
422 |
mode="mixture_model",
|
@@ -424,6 +637,7 @@ class InSilicoPerturberStats:
|
|
424 |
combos=0,
|
425 |
anchor_gene=None,
|
426 |
cell_states_to_model=None,
|
|
|
427 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
428 |
gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
|
429 |
):
|
@@ -432,12 +646,13 @@ class InSilicoPerturberStats:
|
|
432 |
|
433 |
Parameters
|
434 |
----------
|
435 |
-
mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"}
|
436 |
Type of stats.
|
437 |
"goal_state_shift": perturbation vs. random for desired cell state shift
|
438 |
"vs_null": perturbation vs. null from provided null distribution dataset
|
439 |
"mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
|
440 |
"aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
|
|
|
441 |
genes_perturbed : "all", list
|
442 |
Genes perturbed in isp experiment.
|
443 |
Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
|
@@ -472,13 +687,14 @@ class InSilicoPerturberStats:
|
|
472 |
self.combos = combos
|
473 |
self.anchor_gene = anchor_gene
|
474 |
self.cell_states_to_model = cell_states_to_model
|
475 |
-
|
|
|
476 |
self.validate_options()
|
477 |
|
478 |
# load token dictionary (Ensembl IDs:token)
|
479 |
with open(token_dictionary_file, "rb") as f:
|
480 |
self.gene_token_dict = pickle.load(f)
|
481 |
-
|
482 |
# load gene name dictionary (gene name:Ensembl ID)
|
483 |
with open(gene_name_id_dictionary_file, "rb") as f:
|
484 |
self.gene_name_id_dict = pickle.load(f)
|
@@ -489,7 +705,7 @@ class InSilicoPerturberStats:
|
|
489 |
self.anchor_token = self.gene_token_dict[self.anchor_gene]
|
490 |
|
491 |
def validate_options(self):
|
492 |
-
for attr_name,valid_options in self.valid_option_dict.items():
|
493 |
attr_value = self.__dict__[attr_name]
|
494 |
if type(attr_value) not in {list, dict}:
|
495 |
if attr_name in {"anchor_gene"}:
|
@@ -498,35 +714,40 @@ class InSilicoPerturberStats:
|
|
498 |
continue
|
499 |
valid_type = False
|
500 |
for option in valid_options:
|
501 |
-
if (option in [int,list,dict]) and isinstance(
|
|
|
|
|
502 |
valid_type = True
|
503 |
break
|
504 |
-
if valid_type:
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
if self.cell_states_to_model is not None:
|
513 |
if len(self.cell_states_to_model.items()) == 1:
|
514 |
logger.warning(
|
515 |
-
"The single value dictionary for cell_states_to_model will be "
|
516 |
-
"replaced with a dictionary with named keys for start, goal, and alternate states. "
|
517 |
-
"Please specify state_key, start_state, goal_state, and alt_states "
|
518 |
-
"in the cell_states_to_model dictionary for future use. "
|
519 |
-
"For example, cell_states_to_model={"
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
)
|
525 |
-
for key,value in self.cell_states_to_model.items():
|
526 |
if (len(value) == 3) and isinstance(value, tuple):
|
527 |
-
if
|
|
|
|
|
|
|
|
|
528 |
if len(value[0]) == 1 and len(value[1]) == 1:
|
529 |
-
all_values = value[0]+value[1]+value[2]
|
530 |
if len(all_values) == len(set(all_values)):
|
531 |
continue
|
532 |
# reformat to the new named key format
|
@@ -535,75 +756,93 @@ class InSilicoPerturberStats:
|
|
535 |
"state_key": list(self.cell_states_to_model.keys())[0],
|
536 |
"start_state": state_values[0][0],
|
537 |
"goal_state": state_values[1][0],
|
538 |
-
"alt_states": state_values[2:][0]
|
539 |
}
|
540 |
-
elif set(self.cell_states_to_model.keys()) == {
|
541 |
-
|
542 |
-
|
543 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
544 |
logger.error(
|
545 |
-
"Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
|
|
|
546 |
raise
|
547 |
-
|
548 |
-
if
|
549 |
-
|
550 |
-
|
|
|
|
|
551 |
raise
|
552 |
|
553 |
if self.cell_states_to_model["alt_states"] is not None:
|
554 |
-
if
|
555 |
logger.error(
|
556 |
"self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
|
557 |
)
|
558 |
raise
|
559 |
-
if len(self.cell_states_to_model["alt_states"])!= len(
|
560 |
-
|
561 |
-
|
|
|
562 |
raise
|
563 |
|
564 |
else:
|
565 |
logger.error(
|
566 |
-
"cell_states_to_model must only have the following four keys: "
|
567 |
-
"'state_key', 'start_state', 'goal_state', 'alt_states'."
|
568 |
-
"For example, cell_states_to_model={"
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
)
|
574 |
raise
|
575 |
|
576 |
if self.anchor_gene is not None:
|
577 |
self.anchor_gene = None
|
578 |
logger.warning(
|
579 |
-
"anchor_gene set to None. "
|
580 |
-
"Currently, anchor gene not available "
|
581 |
-
"when modeling multiple cell states."
|
582 |
-
|
|
|
583 |
if self.combos > 0:
|
584 |
if self.anchor_gene is None:
|
585 |
logger.error(
|
586 |
-
"Currently, stats are only supported for combination "
|
587 |
-
"in silico perturbation run with anchor gene. Please add "
|
588 |
-
"anchor gene when using with combos > 0. "
|
|
|
589 |
raise
|
590 |
-
|
591 |
if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
|
592 |
logger.error(
|
593 |
-
|
594 |
-
|
|
|
595 |
raise
|
596 |
if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
|
597 |
logger.error(
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
|
|
|
|
|
|
|
|
607 |
"""
|
608 |
Get stats for in silico perturbation data and save as results in output_directory.
|
609 |
|
@@ -617,20 +856,22 @@ class InSilicoPerturberStats:
|
|
617 |
Path to directory where perturbation data will be saved as .csv
|
618 |
output_prefix : str
|
619 |
Prefix for output .csv
|
620 |
-
|
|
|
|
|
621 |
Outputs
|
622 |
----------
|
623 |
Definition of possible columns in .csv output file.
|
624 |
-
|
625 |
Of note, not all columns will be present in all output files.
|
626 |
Some columns are specific to particular perturbation modes.
|
627 |
-
|
628 |
"Gene": gene token
|
629 |
"Gene_name": gene name
|
630 |
"Ensembl_ID": gene Ensembl ID
|
631 |
"N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
|
632 |
"Sig": 1 if FDR<0.05, otherwise 0
|
633 |
-
|
634 |
"Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
|
635 |
"Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
|
636 |
"Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
|
@@ -639,7 +880,7 @@ class InSilicoPerturberStats:
|
|
639 |
pvalue compares shift caused by perturbing given gene compared to random genes
|
640 |
"Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
|
641 |
"Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
|
642 |
-
|
643 |
"Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
|
644 |
"Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
|
645 |
"Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
|
@@ -648,7 +889,7 @@ class InSilicoPerturberStats:
|
|
648 |
"Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
|
649 |
"N_Detections_test": "N_Detections" in cells from test distribution
|
650 |
"N_Detections_null": "N_Detections" in cells from null distribution
|
651 |
-
|
652 |
"Anchor_shift": cosine shift in response to given perturbation of anchor gene
|
653 |
"Test_token_shift": cosine shift in response to given perturbation of test gene
|
654 |
"Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
|
@@ -658,13 +899,27 @@ class InSilicoPerturberStats:
|
|
658 |
"Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
|
659 |
1: within impact component; 0: not within impact component
|
660 |
"Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
|
|
|
|
|
|
|
|
|
|
|
|
661 |
"""
|
662 |
|
663 |
-
if self.mode not in [
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
logger.error(
|
665 |
-
"Currently, only modes available are stats for goal_state_shift, "
|
666 |
-
"vs_null (comparing to null distribution),
|
667 |
-
"mixture_model (fitting mixture model for perturbations with or without impact
|
|
|
|
|
668 |
raise
|
669 |
|
670 |
self.gene_token_id_dict = invert_dict(self.gene_token_dict)
|
@@ -673,44 +928,107 @@ class InSilicoPerturberStats:
|
|
673 |
# obtain total gene list
|
674 |
if (self.combos == 0) and (self.anchor_token is not None):
|
675 |
# cos sim data for effect of gene perturbation on the embedding of each other gene
|
676 |
-
dict_list = read_dictionaries(
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
gene_list = get_gene_list(dict_list, "gene")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
else:
|
679 |
# cos sim data for effect of gene perturbation on the embedding of each cell
|
680 |
-
dict_list = read_dictionaries(
|
|
|
|
|
|
|
|
|
|
|
|
|
681 |
gene_list = get_gene_list(dict_list, "cell")
|
682 |
-
|
683 |
# initiate results dataframe
|
684 |
-
cos_sims_df_initial = pd.DataFrame(
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
|
|
|
|
|
|
|
|
|
|
694 |
|
695 |
if self.mode == "goal_state_shift":
|
696 |
-
cos_sims_df = isp_stats_to_goal_state(
|
697 |
-
|
|
|
|
|
|
|
|
|
|
|
698 |
elif self.mode == "vs_null":
|
699 |
-
null_dict_list
|
700 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
|
702 |
elif self.mode == "mixture_model":
|
703 |
-
cos_sims_df = isp_stats_mixture_model(
|
704 |
-
|
|
|
|
|
705 |
elif self.mode == "aggregate_data":
|
706 |
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
|
707 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
# save perturbation stats to output_path
|
709 |
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
710 |
cos_sims_df.to_csv(output_path)
|
711 |
|
712 |
def token_to_gene_name(self, item):
|
713 |
-
if isinstance(item,int):
|
714 |
-
return self.gene_id_name_dict.get(
|
715 |
-
|
716 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
ispstats = InSilicoPerturberStats(mode="goal_state_shift",
|
7 |
combos=0,
|
8 |
anchor_gene=None,
|
9 |
+
cell_states_to_model={"state_key": "disease",
|
10 |
+
"start_state": "dcm",
|
11 |
+
"goal_state": "nf",
|
12 |
"alt_states": ["hcm", "other1", "other2"]})
|
13 |
ispstats.get_stats("path/to/input_data",
|
14 |
None,
|
|
|
17 |
"""
|
18 |
|
19 |
|
|
|
20 |
import logging
|
21 |
+
import os
|
|
|
22 |
import pickle
|
23 |
import random
|
|
|
24 |
from pathlib import Path
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import pandas as pd
|
28 |
+
import statsmodels.stats.multitest as smt
|
29 |
from scipy.stats import ranksums
|
30 |
from sklearn.mixture import GaussianMixture
|
31 |
+
from tqdm.auto import tqdm, trange
|
|
|
|
|
32 |
|
33 |
+
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
34 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
35 |
|
36 |
GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
37 |
|
38 |
logger = logging.getLogger(__name__)
|
39 |
|
40 |
+
|
41 |
# invert dictionary keys/values
|
42 |
def invert_dict(dictionary):
    """Return a new dict mapping each value of *dictionary* back to its key."""
    inverted = {}
    for key, value in dictionary.items():
        inverted[value] = key
    return inverted
|
44 |
|
45 |
+
|
46 |
+
def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
    """Filter a raw cosine-similarity dict down to the entries of interest.

    For "cell" mode, keeps only non-empty entries whose key contains
    "cell_emb"; for "gene" mode, keeps all non-empty entries, optionally
    restricted to keys whose first element equals *anchor_token*.
    Returns the filtered dict wrapped in a single-element list.
    """
    if cell_or_gene_emb == "cell":
        # keep only non-empty (token, "cell_emb") entries
        filtered = {}
        for key, value in cos_sims_dict.items():
            if value and "cell_emb" in key:
                filtered[key] = value
        return [filtered]
    elif cell_or_gene_emb == "gene":
        filtered = {}
        if anchor_token is None:
            for key, value in cos_sims_dict.items():
                if value:
                    filtered[key] = value
        else:
            # restrict to pairs anchored on the requested token
            for key, value in cos_sims_dict.items():
                if value and key[0] == anchor_token:
                    filtered[key] = value
        return [filtered]
|
60 |
+
|
61 |
+
|
62 |
# read raw dictionary files
|
63 |
+
def read_dictionaries(
    input_data_directory,
    cell_or_gene_emb,
    anchor_token,
    cell_states_to_model,
    pickle_suffix,
):
    """Load and merge raw in silico perturbation pickles from a directory.

    Parameters
    ----------
    input_data_directory : str
        Directory containing raw pickle files produced by InSilicoPerturber.
    cell_or_gene_emb : {"cell", "gene"}
        Which embedding entries to extract from each pickle (see read_dict).
    anchor_token : int | None
        If not None in "gene" mode, restrict to pairs anchored on this token.
    cell_states_to_model : dict | None
        If provided, pickles are keyed per modeled state; data are merged
        into one dict per state value.
    pickle_suffix : str
        Filename suffix identifying raw data files (e.g. "_raw.pickle").

    Returns
    -------
    list of dict (when cell_states_to_model is None) or
    dict mapping state value -> merged data dict.

    Raises an error if no file with the given suffix is found.
    """
    file_found = False
    file_path_list = []
    if cell_states_to_model is None:
        dict_list = []
    else:
        validate_cell_states_to_model(cell_states_to_model)
        # drop the state_key entry and any None/empty states
        cell_states_to_model_valid = {
            state: value
            for state, value in cell_states_to_model.items()
            if state != "state_key"
            and cell_states_to_model[state] is not None
            and cell_states_to_model[state] != []
        }
        cell_states_list = []
        # flatten all state values into list
        for state in cell_states_to_model_valid:
            value = cell_states_to_model_valid[state]
            if isinstance(value, list):
                cell_states_list += value
            else:
                cell_states_list.append(value)
        state_dict = {state_value: dict() for state_value in cell_states_list}
    for file in os.listdir(input_data_directory):
        # process only files with given suffix (e.g. "_raw.pickle")
        if file.endswith(pickle_suffix):
            file_found = True
            file_path_list += [f"{input_data_directory}/{file}"]
    for file_path in tqdm(file_path_list):
        with open(file_path, "rb") as fp:
            cos_sims_dict = pickle.load(fp)
            if cell_states_to_model is None:
                dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
            else:
                # merge this shard's per-state data into the running state_dict
                for state_value in cell_states_list:
                    new_dict = read_dict(
                        cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
                    )[0]
                    for key in new_dict:
                        try:
                            state_dict[state_value][key] += new_dict[key]
                        except KeyError:
                            state_dict[state_value][key] = new_dict[key]
    if not file_found:
        logger.error(
            "No raw data for processing found within provided directory. "
            # bug fix: f-prefix was missing, so "{pickle_suffix}" printed literally
            f"Please ensure data files end with '{pickle_suffix}'."
        )
        raise
    if cell_states_to_model is None:
        return dict_list
    else:
        return state_dict
|
122 |
+
|
123 |
|
124 |
# get complete gene list
|
125 |
+
def get_gene_list(dict_list, mode):
    """Collect the sorted list of gene tokens present in the perturbation data.

    Parameters
    ----------
    dict_list : list | dict
        Either a list of data dicts, or (for goal-state runs) a dict mapping
        state value -> data dict.
    mode : {"cell", "gene"}
        "cell": take the perturbed token (key position 0);
        "gene": take the affected token (key position 1).

    Returns
    -------
    Sorted list of tokens with non-empty data; in "gene" mode the
    "cell_emb" pseudo-token is excluded.
    """
    if mode == "cell":
        position = 0
    elif mode == "gene":
        position = 1
    gene_set = set()
    if isinstance(dict_list, list):
        inner_dicts = dict_list
    elif isinstance(dict_list, dict):
        # goal-state runs store one data dict per modeled state
        inner_dicts = dict_list.values()
    else:
        logger.error(
            "dict_list should be a list, or if modeling shift to goal states, a dict. "
            f"{type(dict_list)} is not the correct format."
        )
        raise
    for dict_i in inner_dicts:
        gene_set.update([k[position] for k, v in dict_i.items() if v])
    gene_list = list(gene_set)
    if mode == "gene":
        # bug fix: "cell_emb" may be absent (e.g. after anchor filtering);
        # only remove it when present to avoid a ValueError
        if "cell_emb" in gene_set:
            gene_list.remove("cell_emb")
    gene_list.sort()
    return gene_list
|
148 |
|
149 |
+
|
150 |
def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
    # Map a tuple of gene tokens to their Ensembl IDs (np.nan for unknown tokens).
    try:
        return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
    except TypeError:
        # Fallback for a non-iterable (single-token) input.
        # NOTE(review): tuple() here iterates the looked-up Ensembl ID string,
        # producing a tuple of individual characters (and raises TypeError if
        # the lookup misses, since np.nan is not iterable) — confirm whether
        # the plain ID string was intended instead.
        return tuple(gene_token_id_dict.get(token_tuple, np.nan))
|
155 |
+
|
156 |
|
157 |
def n_detections(token, dict_list, mode, anchor_token):
    """Count the cosine-shift observations recorded for *token* across dicts.

    In "cell" mode, counts entries under (token, "cell_emb"); in "gene"
    mode, counts entries under (anchor_token, token).
    """
    total = 0
    for shard in dict_list:
        if mode == "cell":
            total += len(shard.get((token, "cell_emb"), []))
        elif mode == "gene":
            total += len(shard.get((anchor_token, token), []))
    return total
|
165 |
|
166 |
+
|
167 |
def get_fdr(pvalues):
    # Benjamini-Hochberg FDR correction at alpha=0.05; returns the list of
    # adjusted p-values (element [1] of statsmodels' multipletests output).
    return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
|
169 |
|
170 |
+
|
171 |
def get_impact_component(test_value, gaussian_mixture_model):
|
172 |
impact_border = gaussian_mixture_model.means_[0][0]
|
173 |
nonimpact_border = gaussian_mixture_model.means_[1][0]
|
|
|
183 |
impact_component = 1
|
184 |
return impact_component
|
185 |
|
186 |
+
|
187 |
# aggregate data for single perturbation in multiple cells
|
188 |
+
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
    """Gather all cosine shifts for the single perturbed gene into one column.

    Assumes a single perturbation (the first "Gene" entry of *cos_sims_df*)
    was applied across multiple cells; returns a one-column DataFrame of the
    raw per-cell cosine shifts.
    """
    perturbed_token = cos_sims_df["Gene"][0]

    shifts = []
    for shard in dict_list:
        shifts.extend(shard.get((perturbed_token, "cell_emb"), []))

    aggregated = pd.DataFrame(columns=["Cosine_shift"])
    aggregated["Cosine_shift"] = shifts
    return aggregated
|
198 |
+
|
199 |
+
|
200 |
+
def find(variable, x):
    """Return True if *x* is contained in (or equals) *variable*.

    Membership is tried first; if *variable* is not a container (TypeError),
    falls back to direct equality.
    Bug fix: previously returned None (instead of False) when *x* was not
    found in an iterable *variable*; now always returns a bool.
    """
    try:
        return x in variable  # iterable case: membership test
    except TypeError:
        return x == variable  # non-iterable case: direct comparison
|
206 |
+
|
207 |
+
|
208 |
+
def isp_aggregate_gene_shifts(
    cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
):
    # Aggregate cosine shifts per (perturbed, affected) key across all raw
    # dicts, then summarize each pair with mean, stdev, and detection count.
    # Output rows are ordered with cell-embedding effects first, then by
    # descending mean cosine shift.
    cos_shift_data = dict()
    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        for dict_i in dict_list:
            # keys whose perturbed element contains (or equals) this token;
            # `find` handles both single-token and tuple-of-token keys
            affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
            for key in affected_pairs:
                if key in cos_shift_data.keys():
                    cos_shift_data[key] += dict_i.get(key, [])
                else:
                    cos_shift_data[key] = dict_i.get(key, [])

    # per-key summary: [mean, stdev, n_observations]
    cos_data_mean = {
        k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
    }
    cos_sims_full_df = pd.DataFrame()
    cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
    # look up name/ID of the perturbed gene from the input dataframe
    cos_sims_full_df["Gene_name"] = [
        cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
        for k, v in cos_data_mean.items()
    ]
    cos_sims_full_df["Ensembl_ID"] = [
        cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
        for k, v in cos_data_mean.items()
    ]
    cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
    # map affected tokens to gene names / Ensembl IDs (np.nan when unknown,
    # e.g. for the "cell_emb" pseudo-token)
    cos_sims_full_df["Affected_Gene_name"] = [
        gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
        for token in cos_sims_full_df["Affected"]
    ]
    cos_sims_full_df["Affected_Ensembl_ID"] = [
        gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
    ]
    cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
    cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
    cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]

    # temporary boolean column used only as the primary sort key
    specific_val = "cell_emb"
    cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
    # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
    cos_sims_full_df = cos_sims_full_df.sort_values(
        by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
    ).drop("temp", axis=1)

    return cos_sims_full_df
|
255 |
+
|
256 |
|
257 |
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
258 |
+
def isp_stats_to_goal_state(
    cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
):
    """Summarize cosine shifts toward the goal cell state for each perturbation.

    Parameters
    ----------
    cos_sims_df : pandas.DataFrame
        One row per perturbed gene with "Gene", "Gene_name", and "Ensembl_ID"
        columns (a single row when one grouped perturbation was run).
    result_dict : dict
        Mapping of end state -> {(token, "cell_emb"): [cosine shifts]}; indexed
        here by cell_states_to_model["goal_state"] and each alternate state.
    cell_states_to_model : dict
        Named-key dict carrying at least "goal_state"; "alt_states" is optional
        and may be absent, empty, or [None].
    genes_perturbed : "all" or list
        "all" triggers perturbation-vs-random statistics; anything else is
        treated as a single grouped perturbation and only mean shifts are
        reported (there are no random perturbations to compare against).

    Returns
    -------
    pandas.DataFrame sorted by shift toward the goal state (and, for
    genes_perturbed == "all", by significance first).
    """
    # determine whether any alternate end states were modeled; absent, empty,
    # or [None] all mean "no alternate states"
    if (
        ("alt_states" not in cell_states_to_model.keys())
        or (len(cell_states_to_model["alt_states"]) == 0)
        or (cell_states_to_model["alt_states"] == [None])
    ):
        alt_end_state_exists = False
    elif (len(cell_states_to_model["alt_states"]) > 0) and (
        cell_states_to_model["alt_states"] != [None]
    ):
        alt_end_state_exists = True

    # for single perturbation in multiple cells, there are no random perturbations to compare to
    if genes_perturbed != "all":
        cos_sims_full_df = pd.DataFrame()

        cos_shift_data_end = []
        # single grouped perturbation: only the first (only) row is relevant
        token = cos_sims_df["Gene"][0]
        cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
            (token, "cell_emb"), []
        )
        cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                cos_shift_data_alt_state = []
                cos_shift_data_alt_state += result_dict.get(alt_state).get(
                    (token, "cell_emb"), []
                )
                cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
                    np.mean(cos_shift_data_alt_state)
                ]

        # sort by shift to desired state
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Shift_to_goal_end"], ascending=[False]
        )
        return cos_sims_full_df

    elif genes_perturbed == "all":
        # pool shifts across all genes to form the "random perturbation"
        # background distribution for the rank-sum tests below
        goal_end_random_megalist = []
        if alt_end_state_exists is True:
            alt_end_state_random_dict = {
                alt_state: [] for alt_state in cell_states_to_model["alt_states"]
            }
        for i in trange(cos_sims_df.shape[0]):
            token = cos_sims_df["Gene"][i]
            goal_end_random_megalist += result_dict[
                cell_states_to_model["goal_state"]
            ].get((token, "cell_emb"), [])
            if alt_end_state_exists is True:
                for alt_state in cell_states_to_model["alt_states"]:
                    alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
                        (token, "cell_emb"), []
                    )

        # downsample to improve speed of ranksums
        if len(goal_end_random_megalist) > 100_000:
            random.seed(42)
            goal_end_random_megalist = random.sample(
                goal_end_random_megalist, k=100_000
            )
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                if len(alt_end_state_random_dict[alt_state]) > 100_000:
                    random.seed(42)
                    alt_end_state_random_dict[alt_state] = random.sample(
                        alt_end_state_random_dict[alt_state], k=100_000
                    )

        # assemble output column order; when alternate states exist, the alt
        # shift columns are inserted before the p-value columns
        names = [
            "Gene",
            "Gene_name",
            "Ensembl_ID",
            "Shift_to_goal_end",
            "Goal_end_vs_random_pval",
        ]
        if alt_end_state_exists is True:
            [
                names.append(f"Shift_to_alt_end_{alt_state}")
                for alt_state in cell_states_to_model["alt_states"]
            ]
            # move the goal p-value after the alt shift columns
            names.append(names.pop(names.index("Goal_end_vs_random_pval")))
            [
                names.append(f"Alt_end_vs_random_pval_{alt_state}")
                for alt_state in cell_states_to_model["alt_states"]
            ]
        cos_sims_full_df = pd.DataFrame(columns=names)

        n_detections_dict = dict()
        for i in trange(cos_sims_df.shape[0]):
            token = cos_sims_df["Gene"][i]
            name = cos_sims_df["Gene_name"][i]
            ensembl_id = cos_sims_df["Ensembl_ID"][i]
            goal_end_cos_sim_megalist = result_dict[
                cell_states_to_model["goal_state"]
            ].get((token, "cell_emb"), [])
            # number of cells in which this gene's perturbation was observed
            n_detections_dict[token] = len(goal_end_cos_sim_megalist)
            mean_goal_end = np.mean(goal_end_cos_sim_megalist)
            # Wilcoxon rank-sum: this gene's shifts vs. the pooled background
            pval_goal_end = ranksums(
                goal_end_random_megalist, goal_end_cos_sim_megalist
            ).pvalue

            if alt_end_state_exists is True:
                alt_end_state_dict = {
                    alt_state: [] for alt_state in cell_states_to_model["alt_states"]
                }
                for alt_state in cell_states_to_model["alt_states"]:
                    alt_end_state_dict[alt_state] = result_dict[alt_state].get(
                        (token, "cell_emb"), []
                    )
                    alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
                        alt_end_state_dict[alt_state]
                    )
                    alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
                        alt_end_state_random_dict[alt_state],
                        alt_end_state_dict[alt_state],
                    ).pvalue

            results_dict = dict()
            results_dict["Gene"] = token
            results_dict["Gene_name"] = name
            results_dict["Ensembl_ID"] = ensembl_id
            results_dict["Shift_to_goal_end"] = mean_goal_end
            results_dict["Goal_end_vs_random_pval"] = pval_goal_end
            if alt_end_state_exists is True:
                for alt_state in cell_states_to_model["alt_states"]:
                    results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
                        f"{alt_state}_mean"
                    ]
                    results_dict[
                        f"Alt_end_vs_random_pval_{alt_state}"
                    ] = alt_end_state_dict[f"{alt_state}_pval"]

            cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
            cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])

        # Benjamini-Hochberg correction across all genes
        cos_sims_full_df["Goal_end_FDR"] = get_fdr(
            list(cos_sims_full_df["Goal_end_vs_random_pval"])
        )
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
                    list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
                )

        # quantify number of detections of each gene
        cos_sims_full_df["N_Detections"] = [
            n_detections_dict[token] for token in cos_sims_full_df["Gene"]
        ]

        # sort by shift to desired state
        cos_sims_full_df["Sig"] = [
            1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
        ]
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
            ascending=[False, False, True],
        )

        return cos_sims_full_df
|
420 |
|
421 |
+
|
422 |
# stats comparing cos sim shifts of test perturbations vs null distribution
|
423 |
def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
    """Compare per-gene cosine shifts between test and null cell distributions.

    For each gene token in cos_sims_df, pools the (token, "cell_emb") cosine
    shifts across the test dictionaries (dict_list) and the null dictionaries
    (null_dict_list), then records the mean shift of each distribution, their
    difference, a Wilcoxon rank-sum p-value, the Benjamini-Hochberg FDR, and
    detection counts in each distribution.

    Returns a copy of cos_sims_df with the statistics columns appended, sorted
    by significance, test-vs-null shift, and FDR.
    """
    cos_sims_full_df = cos_sims_df.copy()

    n_genes = cos_sims_df.shape[0]
    # pre-allocate result columns so per-row .loc writes keep stable dtypes
    cos_sims_full_df["Test_avg_shift"] = np.zeros(n_genes, dtype=float)
    cos_sims_full_df["Null_avg_shift"] = np.zeros(n_genes, dtype=float)
    cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(n_genes, dtype=float)
    cos_sims_full_df["Test_vs_null_pval"] = np.zeros(n_genes, dtype=float)
    cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(n_genes, dtype=float)
    cos_sims_full_df["N_Detections_test"] = np.zeros(n_genes, dtype="uint32")
    cos_sims_full_df["N_Detections_null"] = np.zeros(n_genes, dtype="uint32")

    for i in trange(n_genes):
        token = cos_sims_df["Gene"][i]
        test_shifts = []
        null_shifts = []

        for dict_i in dict_list:
            test_shifts += dict_i.get((token, "cell_emb"), [])

        for dict_i in null_dict_list:
            null_shifts += dict_i.get((token, "cell_emb"), [])

        # compute each mean once per row (previously re-derived for the
        # difference column)
        mean_test = np.mean(test_shifts)
        mean_null = np.mean(null_shifts)

        cos_sims_full_df.loc[i, "Test_avg_shift"] = mean_test
        cos_sims_full_df.loc[i, "Null_avg_shift"] = mean_null
        cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = mean_test - mean_null
        cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
            test_shifts, null_shifts, nan_policy="omit"
        ).pvalue
        cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
        cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)

    # replace nan p-values (e.g. from empty shift lists) with 1 in one pass;
    # this whole-column np.where previously ran inside the loop above, making
    # the loop accidentally quadratic with an identical final result
    cos_sims_full_df.Test_vs_null_pval = np.where(
        np.isnan(cos_sims_full_df.Test_vs_null_pval),
        1,
        cos_sims_full_df.Test_vs_null_pval,
    )

    # Benjamini-Hochberg correction across all genes
    cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
        cos_sims_full_df["Test_vs_null_pval"]
    )

    cos_sims_full_df["Sig"] = [
        1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
    ]
    cos_sims_full_df = cos_sims_full_df.sort_values(
        by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
        ascending=[False, False, True],
    )
    return cos_sims_full_df
|
480 |
|
481 |
+
|
482 |
# stats for identifying perturbations with largest effect within a given set of cells
|
483 |
# fits a mixture model to 2 components (impact vs. non-impact) and
|
484 |
# reports the most likely component for each test perturbation
|
485 |
# Note: because assumes given perturbation has a consistent effect in the cells tested,
|
486 |
# we recommend only using the mixture model strategy with uniform cell populations
|
487 |
def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
|
488 |
+
names = ["Gene", "Gene_name", "Ensembl_ID"]
|
489 |
+
|
|
|
|
|
|
|
490 |
if combos == 0:
|
491 |
names += ["Test_avg_shift"]
|
492 |
elif combos == 1:
|
493 |
+
names += [
|
494 |
+
"Anchor_shift",
|
495 |
+
"Test_token_shift",
|
496 |
+
"Sum_of_indiv_shifts",
|
497 |
+
"Combo_shift",
|
498 |
+
"Combo_minus_sum_shift",
|
499 |
+
]
|
500 |
+
|
501 |
+
names += ["Impact_component", "Impact_component_percent"]
|
502 |
|
503 |
cos_sims_full_df = pd.DataFrame(columns=names)
|
504 |
avg_values = []
|
505 |
gene_names = []
|
506 |
+
|
507 |
for i in trange(cos_sims_df.shape[0]):
|
508 |
token = cos_sims_df["Gene"][i]
|
509 |
name = cos_sims_df["Gene_name"][i]
|
510 |
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
511 |
cos_shift_data = []
|
512 |
+
|
513 |
for dict_i in dict_list:
|
514 |
if (combos == 0) and (anchor_token is not None):
|
515 |
+
cos_shift_data += dict_i.get((anchor_token, token), [])
|
516 |
else:
|
517 |
+
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
518 |
+
|
519 |
# Extract values for current gene
|
520 |
if combos == 0:
|
521 |
test_values = cos_shift_data
|
522 |
elif combos == 1:
|
523 |
test_values = []
|
524 |
for tup in cos_shift_data:
|
525 |
+
test_values.append(tup[2])
|
526 |
+
|
527 |
if len(test_values) > 0:
|
528 |
avg_value = np.mean(test_values)
|
529 |
avg_values.append(avg_value)
|
530 |
gene_names.append(name)
|
531 |
+
|
532 |
# fit Gaussian mixture model to dataset of mean for each gene
|
533 |
avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
|
534 |
gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
|
535 |
+
|
536 |
for i in trange(cos_sims_df.shape[0]):
|
537 |
token = cos_sims_df["Gene"][i]
|
538 |
name = cos_sims_df["Gene_name"][i]
|
|
|
541 |
|
542 |
for dict_i in dict_list:
|
543 |
if (combos == 0) and (anchor_token is not None):
|
544 |
+
cos_shift_data += dict_i.get((anchor_token, token), [])
|
545 |
else:
|
546 |
+
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
547 |
+
|
548 |
if combos == 0:
|
549 |
mean_test = np.mean(cos_shift_data)
|
550 |
+
impact_components = [
|
551 |
+
get_impact_component(value, gm) for value in cos_shift_data
|
552 |
+
]
|
553 |
elif combos == 1:
|
554 |
+
anchor_cos_sim_megalist = [
|
555 |
+
anchor for anchor, token, combo in cos_shift_data
|
556 |
+
]
|
557 |
+
token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
|
558 |
+
anchor_plus_token_cos_sim_megalist = [
|
559 |
+
1 - ((1 - anchor) + (1 - token))
|
560 |
+
for anchor, token, combo in cos_shift_data
|
561 |
+
]
|
562 |
+
combo_anchor_token_cos_sim_megalist = [
|
563 |
+
combo for anchor, token, combo in cos_shift_data
|
564 |
+
]
|
565 |
+
combo_minus_sum_cos_sim_megalist = [
|
566 |
+
combo - (1 - ((1 - anchor) + (1 - token)))
|
567 |
+
for anchor, token, combo in cos_shift_data
|
568 |
+
]
|
569 |
|
570 |
mean_anchor = np.mean(anchor_cos_sim_megalist)
|
571 |
mean_token = np.mean(token_cos_sim_megalist)
|
572 |
mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
|
573 |
mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
|
574 |
mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
|
575 |
+
|
576 |
+
impact_components = [
|
577 |
+
get_impact_component(value, gm)
|
578 |
+
for value in combo_anchor_token_cos_sim_megalist
|
579 |
+
]
|
580 |
+
|
581 |
+
impact_component = get_impact_component(mean_test, gm)
|
582 |
+
impact_component_percent = np.mean(impact_components) * 100
|
583 |
+
|
584 |
+
data_i = [token, name, ensembl_id]
|
585 |
if combos == 0:
|
586 |
data_i += [mean_test]
|
587 |
elif combos == 1:
|
588 |
+
data_i += [
|
589 |
+
mean_anchor,
|
590 |
+
mean_token,
|
591 |
+
mean_sum,
|
592 |
+
mean_test,
|
593 |
+
mean_combo_minus_sum,
|
594 |
+
]
|
595 |
+
data_i += [impact_component, impact_component_percent]
|
596 |
+
|
597 |
+
cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
|
598 |
+
cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
|
599 |
+
|
600 |
# quantify number of detections of each gene
|
601 |
+
cos_sims_full_df["N_Detections"] = [
|
602 |
+
n_detections(i, dict_list, "gene", anchor_token)
|
603 |
+
for i in cos_sims_full_df["Gene"]
|
604 |
+
]
|
605 |
+
|
606 |
if combos == 0:
|
607 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
608 |
+
by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
|
609 |
+
)
|
610 |
elif combos == 1:
|
611 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
612 |
+
by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
|
613 |
+
)
|
614 |
return cos_sims_full_df
|
615 |
|
616 |
+
|
617 |
class InSilicoPerturberStats:
|
618 |
valid_option_dict = {
|
619 |
+
"mode": {
|
620 |
+
"goal_state_shift",
|
621 |
+
"vs_null",
|
622 |
+
"mixture_model",
|
623 |
+
"aggregate_data",
|
624 |
+
"aggregate_gene_shifts",
|
625 |
+
},
|
626 |
+
"genes_perturbed": {"all", list},
|
627 |
+
"combos": {0, 1},
|
628 |
"anchor_gene": {None, str},
|
629 |
"cell_states_to_model": {None, dict},
|
630 |
+
"pickle_suffix": {None, str},
|
631 |
}
|
632 |
+
|
633 |
def __init__(
|
634 |
self,
|
635 |
mode="mixture_model",
|
|
|
637 |
combos=0,
|
638 |
anchor_gene=None,
|
639 |
cell_states_to_model=None,
|
640 |
+
pickle_suffix="_raw.pickle",
|
641 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
642 |
gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
|
643 |
):
|
|
|
646 |
|
647 |
Parameters
|
648 |
----------
|
649 |
+
mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data","aggregate_gene_shifts"}
|
650 |
Type of stats.
|
651 |
"goal_state_shift": perturbation vs. random for desired cell state shift
|
652 |
"vs_null": perturbation vs. null from provided null distribution dataset
|
653 |
"mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
|
654 |
"aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
|
655 |
+
"aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
|
656 |
genes_perturbed : "all", list
|
657 |
Genes perturbed in isp experiment.
|
658 |
Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
|
|
|
687 |
self.combos = combos
|
688 |
self.anchor_gene = anchor_gene
|
689 |
self.cell_states_to_model = cell_states_to_model
|
690 |
+
self.pickle_suffix = pickle_suffix
|
691 |
+
|
692 |
self.validate_options()
|
693 |
|
694 |
# load token dictionary (Ensembl IDs:token)
|
695 |
with open(token_dictionary_file, "rb") as f:
|
696 |
self.gene_token_dict = pickle.load(f)
|
697 |
+
|
698 |
# load gene name dictionary (gene name:Ensembl ID)
|
699 |
with open(gene_name_id_dictionary_file, "rb") as f:
|
700 |
self.gene_name_id_dict = pickle.load(f)
|
|
|
705 |
self.anchor_token = self.gene_token_dict[self.anchor_gene]
|
706 |
|
707 |
def validate_options(self):
|
708 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
709 |
attr_value = self.__dict__[attr_name]
|
710 |
if type(attr_value) not in {list, dict}:
|
711 |
if attr_name in {"anchor_gene"}:
|
|
|
714 |
continue
|
715 |
valid_type = False
|
716 |
for option in valid_options:
|
717 |
+
if (option in [str, int, list, dict]) and isinstance(
|
718 |
+
attr_value, option
|
719 |
+
):
|
720 |
valid_type = True
|
721 |
break
|
722 |
+
if not valid_type:
|
723 |
+
logger.error(
|
724 |
+
f"Invalid option for {attr_name}. "
|
725 |
+
f"Valid options for {attr_name}: {valid_options}"
|
726 |
+
)
|
727 |
+
raise
|
728 |
+
|
|
|
729 |
if self.cell_states_to_model is not None:
|
730 |
if len(self.cell_states_to_model.items()) == 1:
|
731 |
logger.warning(
|
732 |
+
"The single value dictionary for cell_states_to_model will be "
|
733 |
+
"replaced with a dictionary with named keys for start, goal, and alternate states. "
|
734 |
+
"Please specify state_key, start_state, goal_state, and alt_states "
|
735 |
+
"in the cell_states_to_model dictionary for future use. "
|
736 |
+
"For example, cell_states_to_model={"
|
737 |
+
"'state_key': 'disease', "
|
738 |
+
"'start_state': 'dcm', "
|
739 |
+
"'goal_state': 'nf', "
|
740 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
741 |
)
|
742 |
+
for key, value in self.cell_states_to_model.items():
|
743 |
if (len(value) == 3) and isinstance(value, tuple):
|
744 |
+
if (
|
745 |
+
isinstance(value[0], list)
|
746 |
+
and isinstance(value[1], list)
|
747 |
+
and isinstance(value[2], list)
|
748 |
+
):
|
749 |
if len(value[0]) == 1 and len(value[1]) == 1:
|
750 |
+
all_values = value[0] + value[1] + value[2]
|
751 |
if len(all_values) == len(set(all_values)):
|
752 |
continue
|
753 |
# reformat to the new named key format
|
|
|
756 |
"state_key": list(self.cell_states_to_model.keys())[0],
|
757 |
"start_state": state_values[0][0],
|
758 |
"goal_state": state_values[1][0],
|
759 |
+
"alt_states": state_values[2:][0],
|
760 |
}
|
761 |
+
elif set(self.cell_states_to_model.keys()) == {
|
762 |
+
"state_key",
|
763 |
+
"start_state",
|
764 |
+
"goal_state",
|
765 |
+
"alt_states",
|
766 |
+
}:
|
767 |
+
if (
|
768 |
+
(self.cell_states_to_model["state_key"] is None)
|
769 |
+
or (self.cell_states_to_model["start_state"] is None)
|
770 |
+
or (self.cell_states_to_model["goal_state"] is None)
|
771 |
+
):
|
772 |
logger.error(
|
773 |
+
"Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
|
774 |
+
)
|
775 |
raise
|
776 |
+
|
777 |
+
if (
|
778 |
+
self.cell_states_to_model["start_state"]
|
779 |
+
== self.cell_states_to_model["goal_state"]
|
780 |
+
):
|
781 |
+
logger.error("All states must be unique.")
|
782 |
raise
|
783 |
|
784 |
if self.cell_states_to_model["alt_states"] is not None:
|
785 |
+
if not isinstance(self.cell_states_to_model["alt_states"], list):
|
786 |
logger.error(
|
787 |
"self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
|
788 |
)
|
789 |
raise
|
790 |
+
if len(self.cell_states_to_model["alt_states"]) != len(
|
791 |
+
set(self.cell_states_to_model["alt_states"])
|
792 |
+
):
|
793 |
+
logger.error("All states must be unique.")
|
794 |
raise
|
795 |
|
796 |
else:
|
797 |
logger.error(
|
798 |
+
"cell_states_to_model must only have the following four keys: "
|
799 |
+
"'state_key', 'start_state', 'goal_state', 'alt_states'."
|
800 |
+
"For example, cell_states_to_model={"
|
801 |
+
"'state_key': 'disease', "
|
802 |
+
"'start_state': 'dcm', "
|
803 |
+
"'goal_state': 'nf', "
|
804 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
805 |
)
|
806 |
raise
|
807 |
|
808 |
if self.anchor_gene is not None:
|
809 |
self.anchor_gene = None
|
810 |
logger.warning(
|
811 |
+
"anchor_gene set to None. "
|
812 |
+
"Currently, anchor gene not available "
|
813 |
+
"when modeling multiple cell states."
|
814 |
+
)
|
815 |
+
|
816 |
if self.combos > 0:
|
817 |
if self.anchor_gene is None:
|
818 |
logger.error(
|
819 |
+
"Currently, stats are only supported for combination "
|
820 |
+
"in silico perturbation run with anchor gene. Please add "
|
821 |
+
"anchor gene when using with combos > 0. "
|
822 |
+
)
|
823 |
raise
|
824 |
+
|
825 |
if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
|
826 |
logger.error(
|
827 |
+
"Mixture model mode requires multiple gene perturbations to fit model "
|
828 |
+
"so is incompatible with a single grouped perturbation."
|
829 |
+
)
|
830 |
raise
|
831 |
if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
|
832 |
logger.error(
|
833 |
+
"Simple data aggregation mode is for single perturbation in multiple cells "
|
834 |
+
"so is incompatible with a genes_perturbed being 'all'."
|
835 |
+
)
|
836 |
+
raise
|
837 |
+
|
838 |
+
def get_stats(
|
839 |
+
self,
|
840 |
+
input_data_directory,
|
841 |
+
null_dist_data_directory,
|
842 |
+
output_directory,
|
843 |
+
output_prefix,
|
844 |
+
null_dict_list=None,
|
845 |
+
):
|
846 |
"""
|
847 |
Get stats for in silico perturbation data and save as results in output_directory.
|
848 |
|
|
|
856 |
Path to directory where perturbation data will be saved as .csv
|
857 |
output_prefix : str
|
858 |
Prefix for output .csv
|
859 |
+
null_dict_list: dict
|
860 |
+
List of loaded null distribution dictionaries if more than one comparison vs. the null is to be performed
|
861 |
+
|
862 |
Outputs
|
863 |
----------
|
864 |
Definition of possible columns in .csv output file.
|
865 |
+
|
866 |
Of note, not all columns will be present in all output files.
|
867 |
Some columns are specific to particular perturbation modes.
|
868 |
+
|
869 |
"Gene": gene token
|
870 |
"Gene_name": gene name
|
871 |
"Ensembl_ID": gene Ensembl ID
|
872 |
"N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
|
873 |
"Sig": 1 if FDR<0.05, otherwise 0
|
874 |
+
|
875 |
"Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
|
876 |
"Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
|
877 |
"Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
|
|
|
880 |
pvalue compares shift caused by perturbing given gene compared to random genes
|
881 |
"Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
|
882 |
"Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
|
883 |
+
|
884 |
"Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
|
885 |
"Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
|
886 |
"Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
|
|
|
889 |
"Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
|
890 |
"N_Detections_test": "N_Detections" in cells from test distribution
|
891 |
"N_Detections_null": "N_Detections" in cells from null distribution
|
892 |
+
|
893 |
"Anchor_shift": cosine shift in response to given perturbation of anchor gene
|
894 |
"Test_token_shift": cosine shift in response to given perturbation of test gene
|
895 |
"Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
|
|
|
899 |
"Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
|
900 |
1: within impact component; 0: not within impact component
|
901 |
"Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
902 |
+
|
903 |
+
In case of aggregating gene shifts:
|
904 |
+
"Perturbed": ID(s) of gene(s) being perturbed
|
905 |
+
"Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
|
906 |
+
"Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
|
907 |
+
"Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
|
908 |
"""
|
909 |
|
910 |
+
if self.mode not in [
|
911 |
+
"goal_state_shift",
|
912 |
+
"vs_null",
|
913 |
+
"mixture_model",
|
914 |
+
"aggregate_data",
|
915 |
+
"aggregate_gene_shifts",
|
916 |
+
]:
|
917 |
logger.error(
|
918 |
+
"Currently, only modes available are stats for goal_state_shift, "
|
919 |
+
"vs_null (comparing to null distribution), "
|
920 |
+
"mixture_model (fitting mixture model for perturbations with or without impact), "
|
921 |
+
"and aggregating data for single perturbations or for gene embedding shifts."
|
922 |
+
)
|
923 |
raise
|
924 |
|
925 |
self.gene_token_id_dict = invert_dict(self.gene_token_dict)
|
|
|
928 |
# obtain total gene list
|
929 |
if (self.combos == 0) and (self.anchor_token is not None):
|
930 |
# cos sim data for effect of gene perturbation on the embedding of each other gene
|
931 |
+
dict_list = read_dictionaries(
|
932 |
+
input_data_directory,
|
933 |
+
"gene",
|
934 |
+
self.anchor_token,
|
935 |
+
self.cell_states_to_model,
|
936 |
+
self.pickle_suffix,
|
937 |
+
)
|
938 |
gene_list = get_gene_list(dict_list, "gene")
|
939 |
+
elif (
|
940 |
+
(self.combos == 0)
|
941 |
+
and (self.anchor_token is None)
|
942 |
+
and (self.mode == "aggregate_gene_shifts")
|
943 |
+
):
|
944 |
+
dict_list = read_dictionaries(
|
945 |
+
input_data_directory,
|
946 |
+
"gene",
|
947 |
+
self.anchor_token,
|
948 |
+
self.cell_states_to_model,
|
949 |
+
self.pickle_suffix,
|
950 |
+
)
|
951 |
+
gene_list = get_gene_list(dict_list, "cell")
|
952 |
else:
|
953 |
# cos sim data for effect of gene perturbation on the embedding of each cell
|
954 |
+
dict_list = read_dictionaries(
|
955 |
+
input_data_directory,
|
956 |
+
"cell",
|
957 |
+
self.anchor_token,
|
958 |
+
self.cell_states_to_model,
|
959 |
+
self.pickle_suffix,
|
960 |
+
)
|
961 |
gene_list = get_gene_list(dict_list, "cell")
|
962 |
+
|
963 |
# initiate results dataframe
|
964 |
+
cos_sims_df_initial = pd.DataFrame(
|
965 |
+
{
|
966 |
+
"Gene": gene_list,
|
967 |
+
"Gene_name": [self.token_to_gene_name(item) for item in gene_list],
|
968 |
+
"Ensembl_ID": [
|
969 |
+
token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
|
970 |
+
if self.genes_perturbed != "all"
|
971 |
+
else self.gene_token_id_dict[genes[1]]
|
972 |
+
if isinstance(genes, tuple)
|
973 |
+
else self.gene_token_id_dict[genes]
|
974 |
+
for genes in gene_list
|
975 |
+
],
|
976 |
+
},
|
977 |
+
index=[i for i in range(len(gene_list))],
|
978 |
+
)
|
979 |
|
980 |
if self.mode == "goal_state_shift":
|
981 |
+
cos_sims_df = isp_stats_to_goal_state(
|
982 |
+
cos_sims_df_initial,
|
983 |
+
dict_list,
|
984 |
+
self.cell_states_to_model,
|
985 |
+
self.genes_perturbed,
|
986 |
+
)
|
987 |
+
|
988 |
elif self.mode == "vs_null":
|
989 |
+
if null_dict_list is None:
|
990 |
+
null_dict_list = read_dictionaries(
|
991 |
+
null_dist_data_directory,
|
992 |
+
"cell",
|
993 |
+
self.anchor_token,
|
994 |
+
self.cell_states_to_model,
|
995 |
+
self.pickle_suffix,
|
996 |
+
)
|
997 |
+
cos_sims_df = isp_stats_vs_null(
|
998 |
+
cos_sims_df_initial, dict_list, null_dict_list
|
999 |
+
)
|
1000 |
|
1001 |
elif self.mode == "mixture_model":
|
1002 |
+
cos_sims_df = isp_stats_mixture_model(
|
1003 |
+
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
1004 |
+
)
|
1005 |
+
|
1006 |
elif self.mode == "aggregate_data":
|
1007 |
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
|
1008 |
|
1009 |
+
elif self.mode == "aggregate_gene_shifts":
|
1010 |
+
cos_sims_df = isp_aggregate_gene_shifts(
|
1011 |
+
cos_sims_df_initial,
|
1012 |
+
dict_list,
|
1013 |
+
self.gene_token_id_dict,
|
1014 |
+
self.gene_id_name_dict,
|
1015 |
+
)
|
1016 |
+
|
1017 |
# save perturbation stats to output_path
|
1018 |
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
1019 |
cos_sims_df.to_csv(output_path)
|
1020 |
|
1021 |
def token_to_gene_name(self, item):
|
1022 |
+
if isinstance(item, int):
|
1023 |
+
return self.gene_id_name_dict.get(
|
1024 |
+
self.gene_token_id_dict.get(item, np.nan), np.nan
|
1025 |
+
)
|
1026 |
+
if isinstance(item, tuple):
|
1027 |
+
return tuple(
|
1028 |
+
[
|
1029 |
+
self.gene_id_name_dict.get(
|
1030 |
+
self.gene_token_id_dict.get(i, np.nan), np.nan
|
1031 |
+
)
|
1032 |
+
for i in item
|
1033 |
+
]
|
1034 |
+
)
|
geneformer/perturber_utils.py
ADDED
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools as it
|
2 |
+
import logging
|
3 |
+
import pickle
|
4 |
+
import re
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import seaborn as sns
|
10 |
+
import torch
|
11 |
+
from datasets import Dataset, load_from_disk
|
12 |
+
from transformers import (
|
13 |
+
BertForMaskedLM,
|
14 |
+
BertForSequenceClassification,
|
15 |
+
BertForTokenClassification,
|
16 |
+
)
|
17 |
+
|
18 |
+
# apply seaborn's default theme globally (affects any downstream plots)
sns.set()

# module-level logger; handler/level configuration is left to the application
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
# load data and filter by defined criteria
def load_and_filter(filter_data, nproc, input_data_file):
    """Load a dataset from disk and apply the optional attribute filters."""
    dataset = load_from_disk(input_data_file)
    if filter_data is None:
        return dataset
    return filter_by_dict(dataset, filter_data, nproc)
|
29 |
+
|
30 |
+
|
31 |
+
def filter_by_dict(data, filter_data, nproc):
    """Filter `data` so that, for each key, the example's value is in the allowed set.

    Raises ValueError if no cells remain after filtering.
    """
    for key, value in filter_data.items():

        def filter_data_by_criteria(example):
            return example[key] in value

        data = data.filter(filter_data_by_criteria, num_proc=nproc)
    if len(data) == 0:
        # was a bare `raise` with no active exception, which surfaces as an
        # opaque "RuntimeError: No active exception to re-raise"
        logger.error("No cells remain after filtering. Check filtering criteria.")
        raise ValueError(
            "No cells remain after filtering. Check filtering criteria."
        )
    return data
|
42 |
+
|
43 |
+
|
44 |
+
def filter_data_by_tokens(filtered_input_data, tokens, nproc):
    """Keep only cells whose input_ids contain every token in ``tokens``."""

    def contains_all_tokens(example):
        present = set(example["input_ids"]).intersection(tokens)
        return len(present) == len(tokens)

    return filtered_input_data.filter(contains_all_tokens, num_proc=nproc)
|
50 |
+
|
51 |
+
|
52 |
+
def logging_filtered_data_len(filtered_input_data, filtered_tokens_categ):
    """Log the number of remaining cells; error out if none remain.

    Raises ValueError when the filtered dataset is empty.
    """
    if len(filtered_input_data) == 0:
        logger.error(f"No cells in dataset contain {filtered_tokens_categ}.")
        # was a bare `raise` with no active exception -> explicit error instead
        raise ValueError(f"No cells in dataset contain {filtered_tokens_categ}.")
    else:
        logger.info(f"# cells with {filtered_tokens_categ}: {len(filtered_input_data)}")
|
58 |
+
|
59 |
+
|
60 |
+
def filter_data_by_tokens_and_log(
    filtered_input_data, tokens, nproc, filtered_tokens_categ
):
    """Filter cells to those containing all ``tokens`` and log the resulting count."""
    # filter for cells with anchor gene
    filtered = filter_data_by_tokens(filtered_input_data, tokens, nproc)
    # logging length of filtered data
    logging_filtered_data_len(filtered, filtered_tokens_categ)
    return filtered
|
69 |
+
|
70 |
+
|
71 |
+
def filter_data_by_start_state(filtered_input_data, cell_states_to_model, nproc):
    """Keep only cells in the modeled start state.

    Validates that the start state exists in the dataset first, to prevent
    a futile (empty) filtering pass. Raises ValueError if it does not.
    """
    state_key = cell_states_to_model["state_key"]
    state_values = filtered_input_data[state_key]
    start_state = cell_states_to_model["start_state"]
    if start_state not in state_values:
        msg = (
            f"Start state {start_state} is not present "
            f"in the dataset's {state_key} attribute."
        )
        logger.error(msg)
        # was a bare `raise` with no active exception -> explicit error instead
        raise ValueError(msg)

    # filter for start state cells
    def filter_for_origin(example):
        return example[state_key] in [start_state]

    return filtered_input_data.filter(filter_for_origin, num_proc=nproc)
|
89 |
+
|
90 |
+
|
91 |
+
def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
    """Select the [start, end) slice of cells requested for perturbation.

    Raises ValueError when 'start' is past the end of the dataset; clamps
    'end' (with a warning) when it exceeds the dataset length.
    """
    if cell_inds_to_perturb["start"] >= len(filtered_input_data):
        msg = "cell_inds_to_perturb['start'] is larger than the filtered dataset."
        logger.error(msg)
        # was a bare `raise` with no active exception -> explicit error instead
        raise ValueError(msg)
    if cell_inds_to_perturb["end"] > len(filtered_input_data):
        logger.warning(
            "cell_inds_to_perturb['end'] is larger than the filtered dataset. \
            Setting to the end of the filtered dataset."
        )
        # NOTE(review): this mutates the caller's dict in place — confirm
        # callers do not reuse the original 'end' value afterwards.
        cell_inds_to_perturb["end"] = len(filtered_input_data)
    filtered_input_data = filtered_input_data.select(
        [i for i in range(cell_inds_to_perturb["start"], cell_inds_to_perturb["end"])]
    )
    return filtered_input_data
|
107 |
+
|
108 |
+
|
109 |
+
# load model to GPU
def load_model(model_type, num_classes, model_directory):
    """Load a Geneformer model of the given type onto the GPU in eval mode.

    model_type: "Pretrained", "GeneClassifier", or "CellClassifier".
    num_classes: number of labels (classifier types only).
    Raises ValueError for an unrecognized model_type.
    """
    if model_type == "Pretrained":
        model = BertForMaskedLM.from_pretrained(
            model_directory, output_hidden_states=True, output_attentions=False
        )
    elif model_type == "GeneClassifier":
        model = BertForTokenClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=True,
            output_attentions=False,
        )
    elif model_type == "CellClassifier":
        model = BertForSequenceClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=True,
            output_attentions=False,
        )
    else:
        # previously fell through with `model` unbound -> confusing NameError
        raise ValueError(f"Unknown model_type: {model_type}")
    # put the model in eval mode for fwd pass (disables dropout)
    model.eval()
    model = model.to("cuda:0")
    return model
|
133 |
+
|
134 |
+
|
135 |
+
def quant_layers(model):
    """Return the number of encoder layers present in the model."""
    layer_nums = [
        int(name.split("layer.")[1].split(".")[0])
        for name, _ in model.named_parameters()
        if "layer" in name
    ]
    # layer indices are zero-based, so the count is max index + 1
    return int(max(layer_nums)) + 1
|
141 |
+
|
142 |
+
|
143 |
+
def get_model_input_size(model):
    """Return the model's max input length, read from the position embedding repr."""
    # raw string: "\(" in a normal string is an invalid escape sequence (SyntaxWarning)
    return int(re.split(r"\(|,", str(model.bert.embeddings.position_embeddings))[1])
|
145 |
+
|
146 |
+
|
147 |
+
def flatten_list(megalist):
    """Flatten one level of nesting from a list of lists."""
    return list(it.chain.from_iterable(megalist))
|
149 |
+
|
150 |
+
|
151 |
+
def measure_length(example):
    """Record the token count of the example under the 'length' key."""
    example["length"] = len(example["input_ids"])
    return example
|
154 |
+
|
155 |
+
|
156 |
+
def downsample_and_sort(data, max_ncells):
    """Optionally subsample to ``max_ncells`` cells, then sort longest-first."""
    num_cells = len(data)
    # if a max is defined, shuffle (reproducibly) before subsampling
    if max_ncells is not None:
        if num_cells > max_ncells:
            data = data.shuffle(seed=42)
            num_cells = max_ncells
    subset = data.select([i for i in range(num_cells)])
    # largest cell first so any memory errors are encountered early
    return subset.sort("length", reverse=True)
|
167 |
+
|
168 |
+
|
169 |
+
def get_possible_states(cell_states_to_model):
    """List the start state, goal state, then any alternate states."""
    states = [
        cell_states_to_model["start_state"],
        cell_states_to_model["goal_state"],
    ]
    return states + cell_states_to_model.get("alt_states", [])
|
175 |
+
|
176 |
+
|
177 |
+
def forward_pass_single_cell(model, example_cell, layer_to_quant):
    """Run one cell through the model and return the requested hidden layer."""
    example_cell.set_format(type="torch")
    token_ids = example_cell["input_ids"]
    with torch.no_grad():
        outputs = model(input_ids=token_ids.to("cuda"))
    hidden = torch.squeeze(outputs.hidden_states[layer_to_quant])
    # release the full set of hidden states held by the model output
    del outputs
    return hidden
|
185 |
+
|
186 |
+
|
187 |
+
def perturb_emb_by_index(emb, indices):
    """Return ``emb`` with the flattened positions in ``indices`` removed."""
    keep_mask = torch.ones(emb.numel(), dtype=torch.bool)
    keep_mask[indices] = False
    return emb[keep_mask]
|
191 |
+
|
192 |
+
|
193 |
+
def delete_indices(example):
    """Delete the tokens at the given perturb_index positions from input_ids."""
    idx_list = example["perturb_index"]
    if any(isinstance(el, list) for el in idx_list):
        idx_list = flatten_list(idx_list)
    # delete from the end so earlier indices remain valid
    for idx in sorted(idx_list, reverse=True):
        del example["input_ids"][idx]

    example["length"] = len(example["input_ids"])
    return example
|
202 |
+
|
203 |
+
|
204 |
+
# for genes_to_perturb = "all" where only genes within cell are overexpressed
def overexpress_indices(example):
    """Move the tokens at perturb_index to the front of the rank value encoding."""
    idx_list = example["perturb_index"]
    if any(isinstance(el, list) for el in idx_list):
        idx_list = flatten_list(idx_list)
    # pop from the end first so earlier indices remain valid
    for idx in sorted(idx_list, reverse=True):
        example["input_ids"].insert(0, example["input_ids"].pop(idx))

    example["length"] = len(example["input_ids"])
    return example
|
214 |
+
|
215 |
+
|
216 |
+
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
def overexpress_tokens(example, max_len):
    """Prepend tokens_to_perturb to the encoding, removing any prior occurrences.

    Truncates to ``max_len`` afterwards (the original embedding must then also
    be truncated to remain comparable).
    """
    # -100 indicates tokens to overexpress are not present in rank value encoding
    if example["perturb_index"] != [-100]:
        example = delete_indices(example)
    # plain loop instead of a list comprehension used only for its side effects;
    # reversed iteration preserves the tokens' original relative order up front
    for token in example["tokens_to_perturb"][::-1]:
        example["input_ids"].insert(0, token)

    # truncate to max input size, must also truncate original emb to be comparable
    if len(example["input_ids"]) > max_len:
        example["input_ids"] = example["input_ids"][0:max_len]

    example["length"] = len(example["input_ids"])
    return example
|
232 |
+
|
233 |
+
|
234 |
+
def calc_n_overflow(max_len, example_len, tokens_to_perturb, indices_to_perturb):
    """Tokens exceeding the model input size after net overexpression additions."""
    net_added = len(tokens_to_perturb) - len(indices_to_perturb)
    return example_len + net_added - max_len
|
238 |
+
|
239 |
+
|
240 |
+
def truncate_by_n_overflow(example):
    """Trim input_ids by the precomputed n_overflow and refresh the length."""
    keep = example["length"] - example["n_overflow"]
    example["input_ids"] = example["input_ids"][0:keep]
    example["length"] = len(example["input_ids"])
    return example
|
245 |
+
|
246 |
+
|
247 |
+
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
    """Remove the given positions along ``gene_dim`` from an embedding tensor.

    indices_to_remove is a list of integer indices to drop.
    """
    # set membership: the list version made this loop O(n^2)
    removal = set(indices_to_remove)
    indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in removal]
    num_dims = emb.dim()
    # index with a tuple: indexing with a plain list of slices/lists is
    # deprecated multi-dimensional indexing in torch
    emb_slice = tuple(
        slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)
    )
    return emb[emb_slice]
|
258 |
+
|
259 |
+
|
260 |
+
def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
    """Remove per-example positions from a batched embedding, re-padding to a common length."""
    inner_dim = gene_dim - 1
    trimmed = [
        remove_indices_from_emb(emb_batch[i, :, :], idxes, inner_dim)
        for i, idxes in enumerate(list_of_indices_to_remove)
    ]
    # add padding given genes are sometimes added that are or are not in original cell
    batch_max = max(t.size()[inner_dim] for t in trimmed)
    repadded = [pad_xd_tensor(t, 0.000, batch_max, inner_dim) for t in trimmed]
    return torch.stack(repadded)
|
271 |
+
|
272 |
+
|
273 |
+
# removes perturbed indices
# need to handle the various cases where a set of genes is overexpressed
def remove_perturbed_indices_set(
    emb,
    perturb_type: str,
    indices_to_perturb: list[list],
    tokens_to_perturb: list[list],
    original_lengths: list[int],
    input_ids=None,
):
    """Drop the perturbed gene positions from a batched embedding.

    For overexpression, [-100] marks a token absent from the original cell
    (nothing to remove there); when every cell lacks the single overexpressed
    token, the embedding is returned unchanged.
    """
    if perturb_type == "overexpress":
        num_perturbed = len(tokens_to_perturb)
        if num_perturbed == 1:
            indices_to_perturb_orig = [
                idx if idx != [-100] else [None] for idx in indices_to_perturb
            ]
            # BUG FIX: `v is [None]` compared identity with a fresh list and was
            # always False, so the early return never triggered; use equality.
            if all(v == [None] for v in indices_to_perturb_orig):
                return emb
        else:
            indices_to_perturb_orig = []

            for idx_list in indices_to_perturb:
                indices_to_perturb_orig.append(
                    [idx if idx != [-100] else [None] for idx in idx_list]
                )

    else:
        indices_to_perturb_orig = indices_to_perturb

    emb = remove_indices_from_emb_batch(emb, indices_to_perturb_orig, gene_dim=1)

    return emb
|
305 |
+
|
306 |
+
|
307 |
+
def make_perturbation_batch(
    example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
) -> tuple[Dataset, list[int]]:
    """Expand one cell into a dataset of perturbed copies, one per perturbation.

    Returns the perturbation dataset and the list of perturbed index sets
    (each entry is the positions perturbed in the corresponding copy).
    """
    # case 1: perturb every gene in the cell individually
    if combo_lvl == 0 and tokens_to_perturb == "all":
        if perturb_type in ["overexpress", "activate"]:
            # position 0 is already the top rank, so start at 1
            range_start = 1
        elif perturb_type in ["delete", "inhibit"]:
            range_start = 0
        indices_to_perturb = [
            [i] for i in range(range_start, example_cell["length"][0])
        ]
    # elif combo_lvl > 0 and anchor_token is None:
    ## to implement
    # case 2: pair the anchor gene with every other gene in the cell
    elif combo_lvl > 0 and (anchor_token is not None):
        example_input_ids = example_cell["input_ids"][0]
        anchor_index = example_input_ids.index(anchor_token[0])
        indices_to_perturb = [
            sorted([anchor_index, i]) if i != anchor_index else None
            for i in range(example_cell["length"][0])
        ]
        indices_to_perturb = [item for item in indices_to_perturb if item is not None]
    # case 3: perturb an explicit list of tokens (those present in this cell)
    else:
        example_input_ids = example_cell["input_ids"][0]
        indices_to_perturb = [
            [example_input_ids.index(token)] if token in example_input_ids else None
            for token in tokens_to_perturb
        ]
        indices_to_perturb = [item for item in indices_to_perturb if item is not None]

    # create all permutations of combo_lvl of modifiers from tokens_to_perturb
    if combo_lvl > 0 and (anchor_token is None):
        if tokens_to_perturb != "all":
            if len(tokens_to_perturb) == combo_lvl + 1:
                indices_to_perturb = [
                    list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
                ]
            else:
                # combine the full token set with each remaining gene in the cell
                all_indices = [[i] for i in range(example_cell["length"][0])]
                all_indices = [
                    index for index in all_indices if index not in indices_to_perturb
                ]
                indices_to_perturb = [
                    [[j for i in indices_to_perturb for j in i], x] for x in all_indices
                ]

    length = len(indices_to_perturb)
    # replicate the cell once per perturbation, tagged with its index set
    perturbation_dataset = Dataset.from_dict(
        {
            "input_ids": example_cell["input_ids"] * length,
            "perturb_index": indices_to_perturb,
        }
    )

    # multiprocessing overhead is not worth it for small batches
    if length < 400:
        num_proc_i = 1
    else:
        num_proc_i = num_proc

    if perturb_type == "delete":
        perturbation_dataset = perturbation_dataset.map(
            delete_indices, num_proc=num_proc_i
        )
    elif perturb_type == "overexpress":
        perturbation_dataset = perturbation_dataset.map(
            overexpress_indices, num_proc=num_proc_i
        )

    perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)

    return perturbation_dataset, indices_to_perturb
|
377 |
+
|
378 |
+
|
379 |
+
# perturbed cell emb removing the activated/overexpressed/inhibited gene emb
# so that only non-perturbed gene embeddings are compared to each other
# in original or perturbed context
def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
    """Build the batch of original embeddings with perturbed positions removed,
    padded to a common length, for comparison against perturbed embeddings.
    """
    all_embs_list = []

    # if making comparison batch for multiple perturbations in single cell
    if perturb_group is False:
        # squeeze if single cell
        if original_emb_batch.ndim == 3 and original_emb_batch.size()[0] == 1:
            original_emb_batch = torch.squeeze(original_emb_batch)
        original_emb_list = [original_emb_batch] * len(indices_to_perturb)
    # if making comparison batch for single perturbation in multiple cells
    elif perturb_group is True:
        original_emb_list = original_emb_batch

    for original_emb, indices in zip(original_emb_list, indices_to_perturb):
        # [-100] marks a token absent from this cell: nothing to remove
        if indices == [-100]:
            all_embs_list += [original_emb[:]]
            continue

        emb_list = []
        start = 0
        if any(isinstance(el, list) for el in indices):
            indices = flatten_list(indices)

        # removes indices that were perturbed from the original embedding
        # by concatenating the slices between them
        for i in sorted(indices):
            emb_list += [original_emb[start:i]]
            start = i + 1

        emb_list += [original_emb[start:]]
        all_embs_list += [torch.cat(emb_list)]

    # cells can lose different numbers of positions; pad to the batch max
    len_set = set([emb.size()[0] for emb in all_embs_list])
    if len(len_set) > 1:
        max_len = max(len_set)
        # NOTE(review): pad value None is forwarded to F.pad — confirm the
        # intended fill value before relying on it.
        all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
    return torch.stack(all_embs_list)
|
418 |
+
|
419 |
+
|
420 |
+
def pad_list(input_ids, pad_token_id, max_len):
    """Right-pad a list of token ids to max_len; returns a numpy array."""
    pad_width = (0, max_len - len(input_ids))
    return np.pad(
        input_ids, pad_width, mode="constant", constant_values=pad_token_id
    )
|
428 |
+
|
429 |
+
|
430 |
+
def pad_xd_tensor(tensor, pad_token_id, max_len, dim):
    """Pad ``tensor`` along ``dim`` up to max_len with pad_token_id."""
    extra = max_len - tensor.size()[dim]
    # F.pad expects (before, after) pairs starting from the LAST dimension:
    # initialize all to zero, then set the "after" slot for `dim`
    pad_config = [0, 0] * tensor.dim()
    pad_config[-2 * dim - 1] = extra
    return torch.nn.functional.pad(
        tensor, pad=pad_config, mode="constant", value=pad_token_id
    )
|
440 |
+
|
441 |
+
|
442 |
+
def pad_tensor(tensor, pad_token_id, max_len):
    """Right-pad a 1-D tensor to max_len with pad_token_id."""
    fill = max_len - tensor.numel()
    return torch.nn.functional.pad(
        tensor, pad=(0, fill), mode="constant", value=pad_token_id
    )
|
448 |
+
|
449 |
+
|
450 |
+
def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
    """Pad a 2-D tensor along ``dim`` (0=rows, 1=cols) to max_len.

    Raises ValueError for any other dim.
    """
    if dim == 0:
        pad = (0, 0, 0, max_len - tensor.size()[dim])
    elif dim == 1:
        pad = (0, max_len - tensor.size()[dim], 0, 0)
    else:
        # previously fell through with `pad` unbound -> UnboundLocalError
        raise ValueError("dim must be 0 or 1 for a 2D tensor.")
    tensor = torch.nn.functional.pad(
        tensor, pad=pad, mode="constant", value=pad_token_id
    )
    return tensor
|
459 |
+
|
460 |
+
|
461 |
+
def pad_3d_tensor(tensor, pad_token_id, max_len, dim):
    """Pad a 3-D tensor along dim 1 or 2 to max_len.

    Raises Exception for dim 0 (batch) and ValueError for any other dim.
    """
    if dim == 0:
        raise Exception("dim 0 usually does not need to be padded.")
    if dim == 1:
        pad = (0, 0, 0, max_len - tensor.size()[dim])
    elif dim == 2:
        pad = (0, max_len - tensor.size()[dim], 0, 0)
    else:
        # previously fell through with `pad` unbound -> UnboundLocalError
        raise ValueError("dim must be 1 or 2 for a 3D tensor.")
    tensor = torch.nn.functional.pad(
        tensor, pad=pad, mode="constant", value=pad_token_id
    )
    return tensor
|
472 |
+
|
473 |
+
|
474 |
+
def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
    """Force an encoding (tensor or list) to exactly max_len tokens."""
    if isinstance(encoding, torch.Tensor):
        encoding_len = encoding.size()[0]
    elif isinstance(encoding, list):
        encoding_len = len(encoding)
    if encoding_len > max_len:
        return encoding[0:max_len]
    if encoding_len < max_len:
        if isinstance(encoding, torch.Tensor):
            return pad_tensor(encoding, pad_token_id, max_len)
        return pad_list(encoding, pad_token_id, max_len)
    return encoding
|
487 |
+
|
488 |
+
|
489 |
+
# pad list of tensors and convert to tensor
def pad_tensor_list(
    tensor_list,
    dynamic_or_constant,
    pad_token_id,
    model_input_size,
    dim=None,
    padding_func=None,
):
    """Pad every tensor in the list to a common length and combine into one tensor."""
    # resolve the target length
    if dynamic_or_constant == "dynamic":
        max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
    elif isinstance(dynamic_or_constant, int):
        max_len = dynamic_or_constant
    else:
        max_len = model_input_size
        logger.warning(
            "If padding style is constant, must provide integer value. "
            f"Setting padding to max input size {model_input_size}."
        )

    # pad all tensors to the target length
    if dim is None:
        padded = [pad_tensor(t, pad_token_id, max_len) for t in tensor_list]
    else:
        padded = [padding_func(t, pad_token_id, max_len, dim) for t in tensor_list]

    # 3-D tensors already carry a batch dim, so concatenate instead of stacking
    if padding_func == pad_3d_tensor:
        return torch.cat(padded, 0)
    return torch.stack(padded)
|
524 |
+
|
525 |
+
|
526 |
+
def gen_attention_mask(minibatch_encoding, max_len=None):
    """Build a 1/0 attention mask on the GPU from per-example lengths."""
    original_lens = minibatch_encoding["length"]
    if max_len is None:
        max_len = max(original_lens)
    attention_mask = []
    for original_len in original_lens:
        if original_len <= max_len:
            attention_mask.append(
                [1] * original_len + [0] * (max_len - original_len)
            )
        else:
            attention_mask.append([1] * max_len)
    return torch.tensor(attention_mask, device="cuda")
|
537 |
+
|
538 |
+
|
539 |
+
# get cell embeddings excluding padding
def mean_nonpadding_embs(embs, original_lens, dim=1):
    """Mean-pool embeddings over non-padding positions only."""
    # positions before each original length are real tokens; the rest are padding
    valid = torch.arange(embs.size(dim), device=embs.device) < original_lens.unsqueeze(1)
    if embs.dim() == 3:
        # zero out padding positions, then divide by the true lengths
        zeroed = embs.masked_fill(~valid.unsqueeze(2), 0.0)
        mean_embs = zeroed.sum(dim) / original_lens.view(-1, 1).float()

    elif embs.dim() == 2:
        zeroed = embs.masked_fill(~valid, 0.0)
        mean_embs = zeroed.sum(dim) / original_lens.float()
    return mean_embs
|
554 |
+
|
555 |
+
|
556 |
+
# get cell embeddings when there is no padding
def compute_nonpadded_cell_embedding(embs, cell_emb_style):
    """Mean-pool over the gene dimension (second-to-last axis)."""
    if cell_emb_style == "mean_pool":
        gene_axis = embs.ndim - 2
        return torch.mean(embs, dim=gene_axis)
|
560 |
+
|
561 |
+
|
562 |
+
# quantify shifts for a set of genes
def quant_cos_sims(
    perturbation_emb,
    original_emb,
    cell_states_to_model,
    state_embs_dict,
    emb_mode="gene",
):
    """Cosine similarity of perturbed vs. original embeddings.

    Without cell states: returns the direct perturbed-vs-original similarity.
    With cell states: returns a dict mapping each possible state to the shift
    in similarity toward that state's embedding caused by the perturbation.
    """
    # gene embeddings are (batch, gene, dim); cell embeddings are (batch, dim)
    if emb_mode == "gene":
        cos = torch.nn.CosineSimilarity(dim=2)
    elif emb_mode == "cell":
        cos = torch.nn.CosineSimilarity(dim=1)

    if cell_states_to_model is None:
        cos_sims = cos(perturbation_emb, original_emb).to("cuda")
    else:
        possible_states = get_possible_states(cell_states_to_model)
        cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
        for state in possible_states:
            cos_sims[state] = cos_sim_shift(
                original_emb,
                perturbation_emb,
                state_embs_dict[state].to("cuda"),  # required to move to cuda here
                cos,
            )

    return cos_sims
|
589 |
+
|
590 |
+
|
591 |
+
# calculate cos sim shift of perturbation with respect to origin and alternative cell
def cos_sim_shift(original_emb, perturbed_emb, end_emb, cos):
    """Change in similarity to the end state caused by the perturbation."""
    baseline = cos(original_emb, end_emb)
    shifted = cos(perturbed_emb, end_emb)
    return shifted - baseline
|
597 |
+
|
598 |
+
|
599 |
+
def concatenate_cos_sims(cos_sims):
    """Concatenate chunked cosine-similarity tensors (per state when a dict)."""
    if isinstance(cos_sims, list):
        return torch.cat(cos_sims)
    for state in cos_sims:
        cos_sims[state] = torch.cat(cos_sims[state])
    return cos_sims
|
606 |
+
|
607 |
+
|
608 |
+
def write_perturbation_dictionary(cos_sims_dict: defaultdict, output_path_prefix: str):
    """Pickle the raw cosine-similarity dictionary to <prefix>_raw.pickle."""
    out_path = f"{output_path_prefix}_raw.pickle"
    with open(out_path, "wb") as handle:
        pickle.dump(cos_sims_dict, handle)
|
611 |
+
|
612 |
+
|
613 |
+
def tensor_list_to_pd(tensor_list):
    """Concatenate tensors (moved to CPU) into a pandas DataFrame."""
    stacked = torch.cat(tensor_list)
    return pd.DataFrame(stacked.cpu().numpy())
|
617 |
+
|
618 |
+
|
619 |
+
def validate_cell_states_to_model(cell_states_to_model):
    """Validate the cell_states_to_model dict (and warn about the legacy format).

    Expects named keys 'state_key', 'start_state', 'goal_state', and optional
    'alt_states'. Mutates the dict in place to set alt_states=[] when None.
    Raises ValueError on any invalid configuration.
    """
    if cell_states_to_model is not None:
        if len(cell_states_to_model.items()) == 1:
            logger.warning(
                "The single value dictionary for cell_states_to_model will be "
                "replaced with a dictionary with named keys for start, goal, and alternate states. "
                "Please specify state_key, start_state, goal_state, and alt_states "
                "in the cell_states_to_model dictionary for future use. "
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            for key, value in cell_states_to_model.items():
                if (len(value) == 3) and isinstance(value, tuple):
                    if (
                        isinstance(value[0], list)
                        and isinstance(value[1], list)
                        and isinstance(value[2], list)
                    ):
                        if len(value[0]) == 1 and len(value[1]) == 1:
                            all_values = value[0] + value[1] + value[2]
                            if len(all_values) == len(set(all_values)):
                                continue
            # reformat to the new named key format
            # NOTE(review): this reassignment is local only and is never
            # returned, so the reformatted dict does not reach the caller —
            # confirm whether callers re-validate or rely on the legacy shape.
            state_values = flatten_list(list(cell_states_to_model.values()))

            cell_states_to_model = {
                "state_key": list(cell_states_to_model.keys())[0],
                "start_state": state_values[0][0],
                "goal_state": state_values[1][0],
                "alt_states": state_values[2:][0],
            }
        elif set(cell_states_to_model.keys()).issuperset(
            {"state_key", "start_state", "goal_state"}
        ):
            if (
                (cell_states_to_model["state_key"] is None)
                or (cell_states_to_model["start_state"] is None)
                or (cell_states_to_model["goal_state"] is None)
            ):
                logger.error(
                    "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
                )
                # was a bare `raise` with no active exception -> explicit error
                raise ValueError(
                    "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
                )

            if (
                cell_states_to_model["start_state"]
                == cell_states_to_model["goal_state"]
            ):
                logger.error("All states must be unique.")
                raise ValueError("All states must be unique.")

            if "alt_states" in set(cell_states_to_model.keys()):
                if cell_states_to_model["alt_states"] is not None:
                    if not isinstance(cell_states_to_model["alt_states"], list):
                        logger.error(
                            "cell_states_to_model['alt_states'] must be a list (even if it is one element)."
                        )
                        raise ValueError(
                            "cell_states_to_model['alt_states'] must be a list (even if it is one element)."
                        )
                    if len(cell_states_to_model["alt_states"]) != len(
                        set(cell_states_to_model["alt_states"])
                    ):
                        logger.error("All states must be unique.")
                        raise ValueError("All states must be unique.")
                else:
                    cell_states_to_model["alt_states"] = []

        else:
            logger.error(
                "cell_states_to_model must only have the following four keys: "
                "'state_key', 'start_state', 'goal_state', 'alt_states'."
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            raise ValueError(
                "cell_states_to_model must only have the keys "
                "'state_key', 'start_state', 'goal_state', 'alt_states'."
            )
|
setup.py
CHANGED
@@ -2,7 +2,7 @@ from setuptools import setup
|
|
2 |
|
3 |
setup(
|
4 |
name="geneformer",
|
5 |
-
version="0.0
|
6 |
author="Christina Theodoris",
|
7 |
author_email="christina.theodoris@gladstone.ucsf.edu",
|
8 |
description="Geneformer is a transformer model pretrained \
|
|
|
2 |
|
3 |
setup(
|
4 |
name="geneformer",
|
5 |
+
version="0.1.0",
|
6 |
author="Christina Theodoris",
|
7 |
author_email="christina.theodoris@gladstone.ucsf.edu",
|
8 |
description="Geneformer is a transformer model pretrained \
|