zhihan1996 commited on
Commit
69b2c8f
1 Parent(s): 634202c

Update bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +4 -4
bert_layers.py CHANGED
@@ -813,8 +813,8 @@ class BertForMaskedLM(BertPreTrainedModel):
813
  return MaskedLMOutput(
814
  loss=loss,
815
  logits=prediction_scores,
816
- hidden_states=None,
817
- attentions=None,
818
  )
819
 
820
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
@@ -972,7 +972,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
972
  return SequenceClassifierOutput(
973
  loss=loss,
974
  logits=logits,
975
- hidden_states=None,
976
- attentions=None,
977
  )
978
 
 
813
  return MaskedLMOutput(
814
  loss=loss,
815
  logits=prediction_scores,
816
+ hidden_states=outputs.hidden_states,
817
+ attentions=outputs.attention,
818
  )
819
 
820
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
 
972
  return SequenceClassifierOutput(
973
  loss=loss,
974
  logits=logits,
975
+ hidden_states=outputs.hidden_states,
976
+ attentions=outputs.attention,
977
  )
978