oweller2 committed
Commit 1400590
1 Parent(s): 0b90701
Files changed (1):
  modeling_flexbert.py +6 -4
modeling_flexbert.py CHANGED
@@ -1722,10 +1722,12 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
 
-        # Calculate positions before unpadding
-        if position_ids is None:
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
+        # Calculate sequence-local positions
+        seqlens = attention_mask.sum(dim=-1)  # Get length of each sequence
+        position_ids = torch.zeros_like(input_ids)
+        for i in range(len(seqlens)):
+            position_ids[i, :seqlens[i]] = torch.arange(seqlens[i], device=input_ids.device)
+
 
         batch_size, seq_len = input_ids.shape[:2]
         input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
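For context, a minimal standalone sketch of the position-id logic this hunk introduces. The helper name `sequence_local_position_ids` is hypothetical (not part of the repo); it assumes right-padded batches (pad tokens only at the end of each row), which is what the `[:seqlens[i]]` slice in the added loop presumes.

```python
import torch

def sequence_local_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the added hunk; assumes right padding.

    Positions restart at 0 for every sequence; pad slots stay 0.
    """
    seqlens = attention_mask.sum(dim=-1)  # number of valid tokens per row
    position_ids = torch.zeros_like(attention_mask)
    for i in range(len(seqlens)):
        n = int(seqlens[i])
        position_ids[i, :n] = torch.arange(n, device=attention_mask.device)
    return position_ids

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
print(sequence_local_position_ids(mask))
# tensor([[0, 1, 2, 0, 0],
#         [0, 1, 2, 3, 4]])
```

Under right padding, the removed cumsum variant yields the same positions for the valid tokens (it filled pad slots with 1 rather than 0, but `self.unpad_inputs` drops those slots immediately afterwards). The visible behavioral changes are that positions are now computed unconditionally rather than only when `position_ids is None`, and that the loop writes positions at the start of each row, so left-padded inputs would no longer be handled.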