Update geneformer/perturber_utils.py
Browse files
geneformer/perturber_utils.py
CHANGED
|
@@ -155,9 +155,11 @@ def quant_layers(model):
|
|
| 155 |
layer_nums += [int(name.split("layer.")[1].split(".")[0])]
|
| 156 |
return int(max(layer_nums)) + 1
|
| 157 |
|
| 158 |
-
def
|
| 159 |
-
return
|
| 160 |
|
|
|
|
|
|
|
| 161 |
|
| 162 |
def get_model_input_size(model):
|
| 163 |
return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
|
|
|
|
| 155 |
layer_nums += [int(name.split("layer.")[1].split(".")[0])]
|
| 156 |
return int(max(layer_nums)) + 1
|
| 157 |
|
| 158 |
+
def get_model_emb_dims(model):
|
| 159 |
+
return model.config.hidden_size
|
| 160 |
|
| 161 |
+
def get_model_input_size(model):
|
| 162 |
+
return model.config.max_position_embeddings
|
| 163 |
|
| 164 |
def get_model_input_size(model):
|
| 165 |
return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
|