Markus28 commited on
Commit
59c0808
·
1 Parent(s): e93b0fd

feat: added return_dict

Browse files
Files changed (1) hide show
  1. modeling_bert.py +4 -0
modeling_bert.py CHANGED
@@ -379,6 +379,7 @@ class BertModel(BertPreTrainedModel):
379
  task_type_ids=None,
380
  attention_mask=None,
381
  masked_tokens_mask=None,
 
382
  ):
383
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
384
  we only want the output for the masked tokens. This means that we only compute the last
@@ -429,6 +430,9 @@ class BertModel(BertPreTrainedModel):
429
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
430
  pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
431
 
 
 
 
432
  return BaseModelOutputWithPoolingAndCrossAttentions(
433
  last_hidden_state=sequence_output,
434
  pooler_output=pooled_output,
 
379
  task_type_ids=None,
380
  attention_mask=None,
381
  masked_tokens_mask=None,
382
+ return_dict=True,
383
  ):
384
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
385
  we only want the output for the masked tokens. This means that we only compute the last
 
430
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
431
  pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
432
 
433
+ if not return_dict:
434
+ return (sequence_output, pooled_output)
435
+
436
  return BaseModelOutputWithPoolingAndCrossAttentions(
437
  last_hidden_state=sequence_output,
438
  pooler_output=pooled_output,