oweller2 committed
Commit 8686e3f
Parent(s): cf03b9b
Files changed (1):
  modeling_flexbert.py (+9 -3)
modeling_flexbert.py CHANGED
@@ -1536,7 +1536,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
         # Initialize weights and apply final processing
         self._init_weights(reset_params=False)
 
-    def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
+    def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
         # Handle the XOR condition
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
 
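A side note on the unchanged context in this hunk: the assert is an exclusive-or check, since `(module is None) != (reset_params is None)` is true exactly when one of the two arguments is provided. A minimal sketch of the same pattern (the stub name is illustrative, not from the repo):

def init_weights_stub(module=None, reset_params=None):
    # != on two booleans is Python's idiomatic XOR: exactly one must be set.
    assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"

init_weights_stub(reset_params=False)  # ok: only reset_params given
init_weights_stub(module=object())     # ok: only module given
# init_weights_stub() and init_weights_stub(module=object(), reset_params=True)
# would both raise AssertionError.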
@@ -1556,7 +1556,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 
         if not self.config.tie_word_embeddings:
             init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
-
+
     @classmethod
     def from_composer(
         cls,
@@ -1702,13 +1702,19 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
         )
 
         if self.pad_logits:
+            # Reshape logits to 3D if needed
+            new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
+            if len(new_logits.shape) == 2:
+                new_logits = new_logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
-                logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
+                logits=new_logits,
                 hidden_states=None,
                 attentions=None,
             )
         else:
+            if len(logits.shape) == 2:
+                logits = logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
                 logits=logits,
 
1536
  # Initialize weights and apply final processing
1537
  self._init_weights(reset_params=False)
1538
 
1539
+ [] def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1540
  # Handle the XOR condition
1541
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1542
 
 
1556
 
1557
  if not self.config.tie_word_embeddings:
1558
  init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1559
+
1560
  @classmethod
1561
  def from_composer(
1562
  cls,
 
1702
  )
1703
 
1704
  if self.pad_logits:
1705
+ # Reshape logits to 3D if needed
1706
+ new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
1707
+ if len(new_logits.shape) == 2:
1708
+ new_logits = new_logits.unsqueeze(0)
1709
  return CausalLMOutput(
1710
  loss=loss,
1711
+ logits=new_logits,
1712
  hidden_states=None,
1713
  attentions=None,
1714
  )
1715
  else:
1716
+ if len(logits.shape) == 2:
1717
+ logits = logits.unsqueeze(0)
1718
  return CausalLMOutput(
1719
  loss=loss,
1720
  logits=logits,
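The substantive change is the final hunk. FlexBERT's forward pass can run on unpadded (packed) tokens, so logits may come back 2D as [total_tokens, vocab_size]; `self.pad_inputs(...)` scatters them back into padded form. The commit now stores that result in `new_logits` and guards both branches so the returned `CausalLMOutput.logits` always carries a leading batch dimension, presumably because downstream consumers expect the [batch, seq_len, vocab] shape. A minimal, self-contained sketch of the guarded path; `pad_inputs` here is a toy stand-in for the repo's helper of the same name, and all shapes are illustrative assumptions:

import torch

def pad_inputs(unpadded, indices, batch_size, seq_len):
    # Toy stand-in: scatter rows for the real (non-padding) tokens back into
    # a padded [batch * seq_len, vocab] buffer, then restore the batch dim.
    # The real helper lives elsewhere in the repo and may differ in details.
    padded = torch.zeros(batch_size * seq_len, unpadded.shape[-1], dtype=unpadded.dtype)
    padded[indices] = unpadded
    return padded.view(batch_size, seq_len, -1), indices

# Assumed shapes: 5 real tokens from a 1 x 8 padded batch, vocab size 11.
logits = torch.randn(5, 11)              # unpadded logits: [total_tokens, vocab]
indices = torch.tensor([0, 1, 2, 3, 4])  # flat positions of the real tokens

new_logits = pad_inputs(logits, indices, batch_size=1, seq_len=8)[0]
# The guard this commit adds: promote [seq_len, vocab] to [1, seq_len, vocab]
# so the returned logits always have a batch dimension.
if len(new_logits.shape) == 2:
    new_logits = new_logits.unsqueeze(0)
assert new_logits.shape == (1, 8, 11)

Because `unsqueeze(0)` only fires when the tensor is still 2D, the check is idempotent and cheap, which is why the same two lines can be applied in both the `pad_logits` and plain branches.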