Upload in_silico_perturber_stats.py
Browse filesFix bug in selecting a gene with "aggregate_data" option
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -192,16 +192,27 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
192 |
|
193 |
|
194 |
# aggregate data for single perturbation in multiple cells
|
195 |
-
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
|
196 |
-
names = ["Cosine_shift"]
|
197 |
-
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
|
207 |
def find(variable, x):
|
@@ -1017,8 +1028,8 @@ class InSilicoPerturberStats:
|
|
1017 |
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
1018 |
)
|
1019 |
|
1020 |
-
elif self.mode == "aggregate_data":
|
1021 |
-
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
|
1022 |
|
1023 |
elif self.mode == "aggregate_gene_shifts":
|
1024 |
cos_sims_df = isp_aggregate_gene_shifts(
|
|
|
192 |
|
193 |
|
194 |
# aggregate data for single perturbation in multiple cells
|
195 |
+
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
196 |
+
names = ["Cosine_shift", "Gene"]
|
197 |
+
cos_sims_full_dfs = []
|
198 |
|
199 |
+
|
200 |
+
gene_ids_df = cos_sims_df.loc[np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :]
|
201 |
+
tokens = gene_ids_df["Gene"]
|
202 |
+
symbols = gene_ids_df["Gene_name"]
|
203 |
+
|
204 |
+
for token, symbol in zip(tokens, symbols):
|
205 |
+
cos_shift_data = []
|
206 |
+
for dict_i in dict_list:
|
207 |
+
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
208 |
+
|
209 |
+
df = pd.DataFrame(columns=names)
|
210 |
+
df["Cosine_shift"] = cos_shift_data
|
211 |
+
df["Gene"] = symbol
|
212 |
+
cos_sims_full_dfs.append(df)
|
213 |
+
|
214 |
+
|
215 |
+
return pd.concat(cos_sims_full_dfs)
|
216 |
|
217 |
|
218 |
def find(variable, x):
|
|
|
1028 |
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
1029 |
)
|
1030 |
|
1031 |
+
elif self.mode == "aggregate_data":
|
1032 |
+
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
|
1033 |
|
1034 |
elif self.mode == "aggregate_gene_shifts":
|
1035 |
cos_sims_df = isp_aggregate_gene_shifts(
|