oweller2 commited on
Commit
1f61dbc
1 Parent(s): 3cd62e1
Files changed (1) hide show
  1. modeling_flexbert.py +19 -22
modeling_flexbert.py CHANGED
@@ -1708,28 +1708,25 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1708
  attentions=None,
1709
  )
1710
 
1711
- def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
1712
- input_shape = input_ids.shape
1713
- effective_batch_size = input_shape[0]
1714
- breakpoint()
1715
-
1716
- # add a dummy token
1717
- if self.config.pad_token_id is None:
1718
- raise ValueError("The PAD token should be defined for generation")
1719
-
1720
- attention_mask = torch.cat(
1721
- [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
1722
- dim=-1,
1723
- )
1724
- dummy_token = torch.full(
1725
- (effective_batch_size, 1),
1726
- self.config.pad_token_id,
1727
- dtype=torch.long,
1728
- device=input_ids.device,
1729
- )
1730
- input_ids = torch.cat([input_ids, dummy_token], dim=1)
1731
-
1732
- return {"input_ids": input_ids, "attention_mask": attention_mask}
1733
 
1734
  def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1735
  """Returns the number of parameters in the model.
 
1708
  attentions=None,
1709
  )
1710
 
1711
+ def prepare_inputs_for_generation(
1712
+ self,
1713
+ input_ids: torch.Tensor,
1714
+ past_key_values: Optional[torch.FloatTensor] = None,
1715
+ attention_mask: Optional[torch.Tensor] = None,
1716
+ **kwargs
1717
+ ) -> dict:
1718
+ # only last token for inputs if past is defined
1719
+ if past_key_values is not None:
1720
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1721
+ if attention_mask is not None:
1722
+ attention_mask = attention_mask[:, -1:]
1723
+
1724
+ return {
1725
+ "input_ids": input_ids,
1726
+ "past_key_values": past_key_values,
1727
+ "use_cache": kwargs.get("use_cache", True),
1728
+ "attention_mask": attention_mask,
1729
+ }
 
 
 
1730
 
1731
  def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1732
  """Returns the number of parameters in the model.