oweller2 commited on
Commit
7bf8a6c
1 Parent(s): 6369df1
Files changed (1) hide show
  1. modeling_flexbert.py +4 -3
modeling_flexbert.py CHANGED
@@ -1702,7 +1702,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1702
  if self.pad_logits:
1703
  # print(f"Padding logits: {logits.shape}")
1704
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
1705
- if len(new_logits.shape) == 2:
 
1706
  new_logits = new_logits.unsqueeze(0)
1707
  return CausalLMOutput(
1708
  loss=loss,
@@ -1711,8 +1712,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1711
  attentions=None,
1712
  )
1713
  else:
1714
- # print(f"Non-padding logits: {logits.shape}")
1715
- if len(logits.shape) == 2:
1716
  logits = logits.unsqueeze(0)
1717
  return CausalLMOutput(
1718
  loss=loss,
 
1702
  if self.pad_logits:
1703
  # print(f"Padding logits: {logits.shape}")
1704
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
1705
+ # print(new_logits.shape)
1706
+ if new_logits.dim() == 1:
1707
  new_logits = new_logits.unsqueeze(0)
1708
  return CausalLMOutput(
1709
  loss=loss,
 
1712
  attentions=None,
1713
  )
1714
  else:
1715
+ print(f"Non-padding logits: {logits.shape}")
1716
+ if logits.dim() == 1:
1717
  logits = logits.unsqueeze(0)
1718
  return CausalLMOutput(
1719
  loss=loss,