Commit: "model_type on top"
File changed: configuration_bert.py (+2, −2)
Summary: move the hard-coded `model_type` out of `FlexBertConfig.__init__` (where it was a constructor parameter assigned to `self.model_type`) and declare it as a class attribute instead, matching the Transformers convention for config classes.
@@ -40,6 +40,8 @@ class BertConfig(TransformersBertConfig):
 
 
 class FlexBertConfig(TransformersBertConfig):
+    model_type = "flex_bert"
+
     def __init__(
         self,
         attention_layer: str = "base",
@@ -97,7 +99,6 @@ class FlexBertConfig(TransformersBertConfig):
         pad_logits: bool = False,
         compile_model: bool = False,
         masked_prediction: bool = False,
-        model_type: str = "flex_bert",
         **kwargs,
     ):
         """
@@ -214,7 +215,6 @@ class FlexBertConfig(TransformersBertConfig):
         self.pad_logits = pad_logits
         self.compile_model = compile_model
         self.masked_prediction = masked_prediction
-        self.model_type = model_type
 
         if loss_kwargs.get("return_z_loss", False):
             if loss_function != "fa_cross_entropy":
Resulting code after the change:

 40 |
 41 |
 42 | class FlexBertConfig(TransformersBertConfig):
 43 |     model_type = "flex_bert"
 44 |
 45 |     def __init__(
 46 |         self,
 47 |         attention_layer: str = "base",
    | ...
 99 |         pad_logits: bool = False,
100 |         compile_model: bool = False,
101 |         masked_prediction: bool = False,
102 |         **kwargs,
103 |     ):
104 |         """
    | ...
215 |         self.pad_logits = pad_logits
216 |         self.compile_model = compile_model
217 |         self.masked_prediction = masked_prediction
218 |
219 |         if loss_kwargs.get("return_z_loss", False):
220 |             if loss_function != "fa_cross_entropy":