oweller2
commited on
Commit
•
1400590
1
Parent(s):
0b90701
fix
Browse files- modeling_flexbert.py +6 -4
modeling_flexbert.py
CHANGED
@@ -1722,10 +1722,12 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1722 |
if attention_mask is None:
|
1723 |
attention_mask = torch.ones_like(input_ids)
|
1724 |
|
1725 |
-
# Calculate positions
|
1726 |
-
|
1727 |
-
|
1728 |
-
|
|
|
|
|
1729 |
|
1730 |
batch_size, seq_len = input_ids.shape[:2]
|
1731 |
input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
|
|
|
1722 |
if attention_mask is None:
|
1723 |
attention_mask = torch.ones_like(input_ids)
|
1724 |
|
1725 |
+
# Calculate sequence-local positions
|
1726 |
+
seqlens = attention_mask.sum(dim=-1) # Get length of each sequence
|
1727 |
+
position_ids = torch.zeros_like(input_ids)
|
1728 |
+
for i in range(len(seqlens)):
|
1729 |
+
position_ids[i, :seqlens[i]] = torch.arange(seqlens[i], device=input_ids.device)
|
1730 |
+
|
1731 |
|
1732 |
batch_size, seq_len = input_ids.shape[:2]
|
1733 |
input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
|