oweller2 commited on
Commit
c9fb228
·
1 Parent(s): bf29787
Files changed (1) hide show
  1. modeling_flexbert.py +4 -0
modeling_flexbert.py CHANGED
@@ -1700,6 +1700,10 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1700
  shift_labels.view(-1)
1701
  )
1702
 
 
 
 
 
1703
  if self.pad_logits:
1704
  # print(f"Padding logits: {logits.shape}")
1705
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
 
1700
  shift_labels.view(-1)
1701
  )
1702
 
1703
+ if self.unpad_embeddings:
1704
+ # reshape to batch size
1705
+ logits = logits.view(batch_size, -1, logits.size(-1))
1706
+
1707
  if self.pad_logits:
1708
  # print(f"Padding logits: {logits.shape}")
1709
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]