fix check for flash attn branching (#377)
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
@@ -92,7 +92,7 @@ def forward(
                 qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
             )
             output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-        elif
+        elif attention_mask.shape[0] == 1:
            # special handling using sample packing
            qkv = rearrange(qkv, "b s ... -> (b s) ...")
            cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
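For context on the sample-packing branch this commit fixes: when multiple sequences are packed into one row, the varlen flash-attention call needs cumulative sequence lengths (cu_seqlens) and the maximum sequence length. Below is a minimal, illustrative sketch of how those could be derived from packed position_ids that restart at 0 for each sub-sequence; the helper name and shapes are assumptions for illustration, not the repository's get_cu_seqlens_from_pos_ids implementation.

import torch

def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
    """Illustrative only: derive flash-attn style cu_seqlens from packed position_ids.

    Assumes a single packed row where position_ids restart at 0 for each
    sub-sequence, e.g. [0, 1, 2, 0, 1, 2, 3] for two packed sequences.
    """
    pos = position_ids.view(-1)
    # A new sub-sequence starts wherever the position id resets to 0.
    starts = torch.nonzero(pos == 0, as_tuple=False).view(-1)
    # Boundaries are each start offset plus the end of the packed row.
    boundaries = torch.cat([starts, torch.tensor([pos.numel()], device=pos.device)])
    cu_seqlens = boundaries.to(torch.int32)
    # Longest sub-sequence, needed as max_s by the varlen attention kernel.
    max_seqlen = int((boundaries[1:] - boundaries[:-1]).max())
    return cu_seqlens, max_seqlen

# Example: two packed sequences of lengths 3 and 4.
pos_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
print(cu_seqlens_from_pos_ids_sketch(pos_ids))  # (tensor([0, 3, 7], dtype=torch.int32), 4)

Detecting the packed case via attention_mask.shape[0] == 1 (a single collapsed batch row) is what routes execution into this branch instead of the dense-attention path.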