oweller2 committed
Commit bf29787
1 Parent(s): 7bf8a6c
Files changed (1)
  1. modeling_flexbert.py +6 -4
modeling_flexbert.py CHANGED
@@ -1657,6 +1657,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
+        print(hidden_states.shape)
 
         if self.compile_model:
             logits = self.compiled_lm_head(hidden_states)
@@ -1703,8 +1704,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
             # print(f"Padding logits: {logits.shape}")
             new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
             # print(new_logits.shape)
-            if new_logits.dim() == 1:
-                new_logits = new_logits.unsqueeze(0)
+            # if new_logits.dim() == 2:
+            #     new_logits = new_logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
                 logits=new_logits,
@@ -1713,8 +1714,9 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
             )
         else:
             print(f"Non-padding logits: {logits.shape}")
-            if logits.dim() == 1:
-                logits = logits.unsqueeze(0)
+            logits = logits.view(-1, logits.size(-1))
+            # if logits.dim() == 2:
+            #     logits = logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
                 logits=logits,
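The behavioural change in the non-padding branch is the switch from a conditional `unsqueeze(0)` on 1-D logits to an unconditional `view(-1, logits.size(-1))`. Below is a minimal sketch, separate from the commit itself, of how the two reshapes differ; the sizes (`total_tokens`, `vocab_size`) and the batched example are assumptions for illustration, not values from the repository.

import torch

total_tokens, vocab_size = 10, 32  # hypothetical sizes for illustration

# Logits from an unpadded path are typically already [total_tokens, vocab_size].
logits = torch.randn(total_tokens, vocab_size)

# Old guard: only acted on 1-D logits, so a 2-D tensor passed through unchanged.
old_logits = logits.unsqueeze(0) if logits.dim() == 1 else logits
print(old_logits.shape)  # torch.Size([10, 32])

# New behaviour: always flatten to [-1, vocab_size]; a no-op for 2-D logits.
new_logits = logits.view(-1, logits.size(-1))
print(new_logits.shape)  # torch.Size([10, 32])

# For a batched [batch, seq, vocab] tensor the new reshape collapses the
# leading dims, which the old guard never did.
batched = torch.randn(2, 5, vocab_size)
print(batched.view(-1, batched.size(-1)).shape)  # torch.Size([10, 32])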