Fix AutoModel not loading model correctly due to config_class inconsistency

#26
Files changed (1)
  1. bert_layers.py +3 -1
bert_layers.py CHANGED

```diff
@@ -24,6 +24,7 @@ from transformers.modeling_utils import PreTrainedModel
 from .bert_padding import (index_first_axis,
                            index_put_first_axis, pad_input,
                            unpad_input, unpad_input_only)
+from .configuration_bert import BertConfig
 
 try:
     from .flash_attn_triton import flash_attn_qkvpacked_func
@@ -564,7 +565,8 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-
+    config_class = BertConfig
+
     def __init__(self, config, add_pooling_layer=True):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
```
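
For context on why this one-line attribute fixes `AutoModel` loading: `PreTrainedModel.from_pretrained` uses the class-level `config_class` to decide which config class to instantiate from `config.json`, so a missing or mismatched `config_class` pairs the checkpoint with the wrong config. Below is a minimal runnable sketch of that mechanism; `MyConfig` and `MyModel` are hypothetical names for illustration, not part of this repo.

```python
# Sketch of the config_class mechanism this PR relies on. Only the
# config_class attribute mirrors the change to BertModel above; the
# rest is a made-up toy model.
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):
    model_type = "my-model"

    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

class MyModel(PreTrainedModel):
    # Without this line, from_pretrained() falls back to whatever
    # config_class the parent class declares, so a checkpoint saved
    # with MyConfig gets reloaded against a different config class --
    # the same inconsistency that broke AutoModel loading here.
    config_class = MyConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.proj(x)

# Round trip: save_pretrained() writes config.json from MyConfig, and
# from_pretrained() re-reads it via MyModel.config_class.
model = MyModel(MyConfig(hidden_size=32))
model.save_pretrained("/tmp/my-model")
reloaded = MyModel.from_pretrained("/tmp/my-model")
assert isinstance(reloaded.config, MyConfig)
```

Setting `config_class = BertConfig` (the repo's own `configuration_bert.BertConfig`) on `BertModel` applies the same pattern, so `AutoModel` resolves the checkpoint against the config class it was saved with.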