zhihan1996
commited on
Commit
•
69b2c8f
1
Parent(s):
634202c
Update bert_layers.py
Browse files- bert_layers.py +4 -4
bert_layers.py
CHANGED
@@ -813,8 +813,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
813 |
return MaskedLMOutput(
|
814 |
loss=loss,
|
815 |
logits=prediction_scores,
|
816 |
-
hidden_states=
|
817 |
-
attentions=
|
818 |
)
|
819 |
|
820 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
@@ -972,7 +972,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
972 |
return SequenceClassifierOutput(
|
973 |
loss=loss,
|
974 |
logits=logits,
|
975 |
-
hidden_states=
|
976 |
-
attentions=
|
977 |
)
|
978 |
|
|
|
813 |
return MaskedLMOutput(
|
814 |
loss=loss,
|
815 |
logits=prediction_scores,
|
816 |
+
hidden_states=outputs.hidden_states,
|
817 |
+
attentions=outputs.attention,
|
818 |
)
|
819 |
|
820 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
|
|
972 |
return SequenceClassifierOutput(
|
973 |
loss=loss,
|
974 |
logits=logits,
|
975 |
+
hidden_states=outputs.hidden_states,
|
976 |
+
attentions=outputs.attention,
|
977 |
)
|
978 |
|