oweller2 committed
Commit
f66abc1
1 Parent(s): 6d1817e
Files changed (2)
  1. attention.py +1 -1
  2. modeling_flexbert.py +19 -11
attention.py CHANGED
@@ -863,7 +863,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
863
  qkv = self.Wqkv(hidden_states)
864
 
865
  # only needed for inference when we have KV cache
866
- seqlen_offset = 0
867
 
868
  # (total_seqlen, 3, nheads, headdim)
869
  qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
 
863
  qkv = self.Wqkv(hidden_states)
864
 
865
  # only needed for inference when we have KV cache
866
+ seqlen_offset = max_seqlen * (len(cu_seqlens) - 2) if len(cu_seqlens) > 1 else 0
867
 
868
  # (total_seqlen, 3, nheads, headdim)
869
  qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
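For intuition, a minimal sketch of what the new seqlen_offset expression evaluates to; the concrete cu_seqlens and max_seqlen values below are made up for illustration, and the comment in the diff suggests the offset is only relevant when decoding with a KV cache (presumably as a position offset for the rotary embedding).

import torch

# Toy unpadded batch: cu_seqlens holds cumulative sequence lengths
# (batch_size + 1 entries), max_seqlen is the longest sequence.
cu_seqlens = torch.tensor([0, 5, 9, 12], dtype=torch.int32)  # 3 sequences: lengths 5, 4, 3
max_seqlen = 5

# The committed expression: zero offset for a single sequence, otherwise
# max_seqlen scaled by (number of sequences - 1).
seqlen_offset = max_seqlen * (len(cu_seqlens) - 2) if len(cu_seqlens) > 1 else 0
print(seqlen_offset)  # 10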
modeling_flexbert.py CHANGED
@@ -1715,20 +1715,28 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1715
  def prepare_inputs_for_generation(
1716
  self,
1717
  input_ids: torch.Tensor,
1718
- past_key_values: Optional[torch.FloatTensor] = None,
1719
  attention_mask: Optional[torch.Tensor] = None,
 
1720
  **kwargs
1721
  ) -> dict:
1722
- # only last token for inputs if past is defined
1723
- if past_key_values is not None:
1724
- input_ids = input_ids[:, -1].unsqueeze(-1)
1725
-
1726
- return {
1727
- "input_ids": input_ids,
1728
- "past_key_values": past_key_values,
1729
- "use_cache": kwargs.get("use_cache", True),
1730
- "attention_mask": None,
1731
- }
 
 
 
 
 
 
 
 
1732
 
1733
  def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1734
  """Returns the number of parameters in the model.
 
1715
  def prepare_inputs_for_generation(
1716
  self,
1717
  input_ids: torch.Tensor,
 
1718
  attention_mask: Optional[torch.Tensor] = None,
1719
+ position_ids: Optional[torch.Tensor] = None,
1720
  **kwargs
1721
  ) -> dict:
1722
+ if attention_mask is None:
1723
+ attention_mask = torch.ones_like(input_ids)
1724
+
1725
+ batch_size, seq_len = input_ids.shape[:2]
1726
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
1727
+ input_ids, attention_mask, position_ids, None
1728
+ )
1729
+ breakpoint()
1730
+ return {
1731
+ "input_ids": input_ids,
1732
+ "attention_mask": attention_mask,
1733
+ "position_ids": position_ids,
1734
+ "indices": indices,
1735
+ "cu_seqlens": cu_seqlens,
1736
+ "max_seqlen": max_seqlen,
1737
+ "batch_size": batch_size,
1738
+ "seq_len": seq_len
1739
+ }
1740
 
1741
  def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1742
  """Returns the number of parameters in the model.