oweller2 committed
Commit 46797c8
1 Parent(s): 81b671b

no pad at inference

Files changed (1)
  1. modeling_flexbert.py +1 -3
modeling_flexbert.py CHANGED
@@ -1721,14 +1721,12 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         # only last token for inputs if past is defined
         if past_key_values is not None:
             input_ids = input_ids[:, -1].unsqueeze(-1)
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, -1:]
 
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache", True),
-            "attention_mask": attention_mask,
+            "attention_mask": None,
         }
 
     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
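
For context, a minimal sketch of what the method body looks like after this change, reconstructed only from the hunk above. The method signature and anything outside the hunk are assumptions following the usual Hugging Face prepare_inputs_for_generation shape. When past_key_values is set, only the last token is passed to the model and no attention mask is forwarded, so cached decoding runs without padding.

    # Sketch of the updated method; the signature is assumed, not shown in the hunk
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # only last token for inputs if past is defined
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            # no pad at inference: the mask is no longer sliced or forwarded
            "attention_mask": None,
        }

Presumably, since no padding tokens are involved at generation time, slicing and forwarding the mask added nothing, and dropping it simplifies the cached path.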