Allow device auto map
#8
by
Jackmin108
- opened
- modeling_bert.py +1 -0
modeling_bert.py
CHANGED
@@ -956,6 +956,7 @@ class JinaBertPreTrainedModel(PreTrainedModel):
|
|
956 |
load_tf_weights = load_tf_weights_in_bert
|
957 |
base_model_prefix = "bert"
|
958 |
supports_gradient_checkpointing = True
|
|
|
959 |
|
960 |
def _init_weights(self, module):
|
961 |
"""Initialize the weights"""
|
|
|
956 |
load_tf_weights = load_tf_weights_in_bert
|
957 |
base_model_prefix = "bert"
|
958 |
supports_gradient_checkpointing = True
|
959 |
+
_no_split_modules = ["JinaBertLayer"]
|
960 |
|
961 |
def _init_weights(self, module):
|
962 |
"""Initialize the weights"""
|