winglian committed
Commit
343ac84
1 Parent(s): 0c96727

fix check for flash attn branching (#377)

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -92,7 +92,7 @@ def forward(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    elif position_ids.shape[0] == 1:
+    elif attention_mask.shape[0] == 1:
         # special handling using sample packing
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
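
The change only swaps which tensor's batch dimension gates the sample-packing branch, from position_ids to attention_mask. For context, sample packing concatenates several short sequences into a single batch row, so before calling the varlen flash-attention kernel the per-sequence boundaries have to be recovered as cumulative sequence lengths. The sketch below is a rough illustration of that idea, not axolotl's actual get_cu_seqlens_from_pos_ids implementation; the helper name get_cu_seqlens_sketch and the example inputs are made up here. It assumes packed position_ids restart at 0 at the start of each sample.

import torch

def get_cu_seqlens_sketch(position_ids: torch.Tensor):
    # position_ids: (1, total_tokens) for a packed batch; each packed
    # sample's positions count up from 0, so a 0 marks a sample start.
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()   # offsets of sample starts
    total = torch.tensor([pos.numel()])
    seq_lens = torch.diff(starts, append=total)  # per-sample lengths
    # varlen flash-attention kernels take int32 cumulative lengths with a
    # leading 0, e.g. [0, len_0, len_0 + len_1, ...]
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
    )
    return cu_seqlens, int(seq_lens.max())

# e.g. position_ids [[0,1,2,3, 0,1,2, 0,1]] -> cu_seqlens [0,4,7,9], max_s 4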