oweller2 committed f66abc1 (parent: 6d1817e): update
Files changed:
- attention.py +1 -1
- modeling_flexbert.py +19 -11
attention.py CHANGED

```diff
@@ -863,7 +863,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
         qkv = self.Wqkv(hidden_states)
 
         # only needed for inference when we have KV cache
-        seqlen_offset = 0
+        seqlen_offset = max_seqlen * (len(cu_seqlens) - 2) if len(cu_seqlens) > 1 else 0
 
         # (total_seqlen, 3, nheads, headdim)
         qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
```
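For context, here is a minimal sketch of what the new expression evaluates to, with invented example values (not from the commit). In the unpadded layout, `cu_seqlens` holds the cumulative boundaries of the sequences packed into one buffer, so `len(cu_seqlens) - 1` is the number of sequences; per the comment in the diff, the offset only matters for KV-cached inference, where it shifts the rotary-embedding positions applied to newly decoded tokens:

```python
import torch

# Invented example: three sequences of lengths 5, 7, and 4 packed into one
# unpadded buffer; cu_seqlens marks their cumulative boundaries.
cu_seqlens = torch.tensor([0, 5, 12, 16], dtype=torch.int32)
max_seqlen = 7

# Old code: seqlen_offset = 0
# New code from this commit:
seqlen_offset = max_seqlen * (len(cu_seqlens) - 2) if len(cu_seqlens) > 1 else 0
print(seqlen_offset)  # 7 * (4 - 2) = 14
```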
modeling_flexbert.py CHANGED

```diff
@@ -1715,20 +1715,28 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.Tensor,
-        past_key_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs
     ) -> dict:
-        [10 removed lines (old 1722-1731); their content is collapsed in this view]
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        batch_size, seq_len = input_ids.shape[:2]
+        input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
+            input_ids, attention_mask, position_ids, None
+        )
+        breakpoint()
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "indices": indices,
+            "cu_seqlens": cu_seqlens,
+            "max_seqlen": max_seqlen,
+            "batch_size": batch_size,
+            "seq_len": seq_len
+        }
 
     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
         """Returns the number of parameters in the model.
```
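The returned dict mirrors the unpadded-tensor convention used elsewhere in FlexBERT: valid tokens are flattened into one buffer, and `indices`, `batch_size`, and `seq_len` carry enough information to re-pad the output afterwards. Below is a toy round-trip illustrating that contract; it is a sketch assuming `self.unpad_inputs` follows the usual FlashAttention-style convention, not the model's actual helper. Note also that the committed body still calls `breakpoint()`, so as written every generation step will drop into the debugger:

```python
import torch
import torch.nn.functional as F

# Toy batch: two sequences padded to length 4 with pad id 0 (illustrative).
input_ids = torch.tensor([[11, 12, 13, 0],
                          [21, 22,  0, 0]])
attention_mask = (input_ids != 0).int()

# Unpadding: keep only valid tokens, remembering where they came from.
seqlens = attention_mask.sum(dim=1)                              # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten()).squeeze(-1)    # tensor([0, 1, 2, 4, 5])
unpadded_ids = input_ids.flatten()[indices]                      # tensor([11, 12, 13, 21, 22])
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)    # tensor([0, 3, 5])
max_seqlen = int(seqlens.max())                                  # 3

# Re-padding the (here, identity) output uses indices plus the original
# shape, which is why the dict above carries batch_size and seq_len.
flat = torch.zeros(input_ids.numel(), dtype=unpadded_ids.dtype)
flat[indices] = unpadded_ids
repadded = flat.view(input_ids.shape)
assert torch.equal(repadded, input_ids)
```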