zhihan1996 commited on
Commit
782b462
1 Parent(s): c42a4fe

Update bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +2 -2
bert_layers.py CHANGED
@@ -18,7 +18,7 @@ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
18
  from transformers.activations import ACT2FN
19
  from transformers.modeling_outputs import (MaskedLMOutput,
20
  SequenceClassifierOutput)
21
- from transformers.models.bert.modeling_bert import BertPreTrainedModel
22
 
23
  from .bert_padding import (index_first_axis,
24
  index_put_first_axis, pad_input,
@@ -521,7 +521,7 @@ class BertPredictionHeadTransform(nn.Module):
521
  return hidden_states
522
 
523
 
524
- class BertModel(BertPreTrainedModel):
525
  """Overall BERT model.
526
 
527
  Args:
 
18
  from transformers.activations import ACT2FN
19
  from transformers.modeling_outputs import (MaskedLMOutput,
20
  SequenceClassifierOutput)
21
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel, PreTrainedModel
22
 
23
  from .bert_padding import (index_first_axis,
24
  index_put_first_axis, pad_input,
 
521
  return hidden_states
522
 
523
 
524
+ class BertModel(PreTrainedModel):
525
  """Overall BERT model.
526
 
527
  Args: