oweller2 committed on
Commit
b9219f0
1 Parent(s): 322b01b

try unpadding in inference

Files changed (1)
  1. modeling_flexbert.py +26 -6
modeling_flexbert.py CHANGED
@@ -1724,12 +1724,32 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         if attention_mask is not None:
             attention_mask = attention_mask[:, -1:]
 
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache", True),
-            "attention_mask": attention_mask,
-        }
+        # Handle unpadding for the last token if needed
+        if self.unpad_embeddings:
+            batch_size, seq_len = input_ids.shape[:2]
+            if attention_mask is None:
+                # create all ones, except for padding (TODO?)
+                attention_mask = torch.ones_like(input_ids)
+            input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
+                input_ids, attention_mask, None, None
+            )
+            return {
+                "input_ids": input_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache", True),
+                "attention_mask": None,  # FA handles this
+                "indices": indices,
+                "cu_seqlens": cu_seqlens,
+                "max_seqlen": max_seqlen,
+                "position_ids": position_ids,
+            }
+        else:
+            return {
+                "input_ids": input_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache", True),
+                "attention_mask": attention_mask,
+            }
 
     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
         """Returns the number of parameters in the model.