Bug: FlashAttention forward only supports head dimension at most 256
#3 opened by Xidong
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention forward only supports head dimension at most 256
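For reference, the cap can be reproduced directly with the flash-attn Python API; a minimal sketch (assumes flash-attn is installed and a CUDA GPU is available; the shapes are purely illustrative, not Zamba's):

```python
# Minimal sketch of the head-dimension cap (assumes flash-attn + a CUDA GPU;
# shapes are illustrative only, not Zamba's actual configuration).
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 1, 128, 8, 320  # headdim > 256

q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Raises: RuntimeError: FlashAttention forward only supports head dimension at most 256
out = flash_attn_func(q, k, v, causal=True)
```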
This isn't a bug. You can't use Flash Attention on our HF implementation because of our concat before the shared attn layer. We got around this by adding split-head support in our inference stack, which we're working on upstreaming to https://github.com/Zyphra/Zamba-torch
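For anyone curious, here is a rough sketch of why the concat trips the limit (hypothetical sizes and names, not the actual Zamba-torch code): concatenating the two streams before the shared attention projection doubles the attention width, so the per-head dimension lands above FA2's 256 ceiling.

```python
# Illustrative sketch only (hypothetical sizes, not the real Zamba config):
# the shared attention block sees a concatenation of two hidden streams,
# which doubles the attention width and hence the per-head dimension.
import torch

hidden_size = 3712          # hypothetical model width
num_heads = 16              # hypothetical head count
x = torch.randn(1, 128, hidden_size)
residual = torch.randn(1, 128, hidden_size)

# Concat before the shared attention layer -> 2 * hidden_size features.
attn_input = torch.cat([x, residual], dim=-1)

head_dim = attn_input.shape[-1] // num_heads
print(head_dim)             # 464 -> exceeds FlashAttention's 256 head-dim cap
```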
In the meantime, we're going to add an assertion to disable FA2 here until we figure out the HF port of our FA2 changes. Can you try with non-flash attention?
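For example, something along these lines should load with a non-flash attention backend (a sketch assuming a recent transformers release; the model id and dtype are placeholders, adjust for your setup):

```python
# Hedged example of loading without FlashAttention-2 (assumes a recent
# transformers; the model id below is a guess -- substitute the repo you use).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Zyphra/Zamba-7B-v1"  # assumed repo id, adjust as needed

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # or "sdpa"; anything but "flash_attention_2"
    # some checkpoints may also need trust_remote_code=True
)
```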
Ok, got it. Thanks!
Xidong changed discussion status to closed