Fix AutoModel not loading model correctly due to config_class inconsistency
#26
by
liamclarkza
- opened
- bert_layers.py +3 -1
bert_layers.py
CHANGED
@@ -24,6 +24,7 @@ from transformers.modeling_utils import PreTrainedModel
|
|
24 |
from .bert_padding import (index_first_axis,
|
25 |
index_put_first_axis, pad_input,
|
26 |
unpad_input, unpad_input_only)
|
|
|
27 |
|
28 |
try:
|
29 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
@@ -564,7 +565,8 @@ class BertModel(BertPreTrainedModel):
|
|
564 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
565 |
```
|
566 |
"""
|
567 |
-
|
|
|
568 |
def __init__(self, config, add_pooling_layer=True):
|
569 |
super(BertModel, self).__init__(config)
|
570 |
self.embeddings = BertEmbeddings(config)
|
|
|
24 |
from .bert_padding import (index_first_axis,
|
25 |
index_put_first_axis, pad_input,
|
26 |
unpad_input, unpad_input_only)
|
27 |
+
from .configuration_bert import BertConfig
|
28 |
|
29 |
try:
|
30 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
|
|
565 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
566 |
```
|
567 |
"""
|
568 |
+
config_class = BertConfig
|
569 |
+
|
570 |
def __init__(self, config, add_pooling_layer=True):
|
571 |
super(BertModel, self).__init__(config)
|
572 |
self.embeddings = BertEmbeddings(config)
|