adds llama and mistral dropout support (#858)
* adds llama and mistral dropout support
* gracefully handle attention dropout if not available yet
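
Both patches derive the dropout probability the same way: dropout is applied only while the module is in training mode, and `getattr` falls back to 0.0 when the attention module does not yet expose an `attention_dropout` attribute (older transformers versions), which is what the second bullet refers to. A minimal sketch of that gating, lifted out of the patched forward for clarity (`attn` stands in for `self`):

# Minimal sketch of the dropout gating added in both patches.
# `attn` stands in for the attention module (`self` in the patched forward);
# getattr() degrades gracefully when `attention_dropout` is not defined yet.
def resolve_dropout_rate(attn) -> float:
    if not attn.training:
        return 0.0  # never drop attention weights at eval/inference time
    return getattr(attn, "attention_dropout", 0.0)
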
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -321,6 +321,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -330,7 +332,12 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+            qkv,
+            cu_seqlens,
+            max_seqlen,
+            dropout_p=dropout_rate,
+            softmax_scale=None,
+            causal=True,
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -353,7 +360,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
         )
@@ -366,6 +373,7 @@ def flashattn_forward(
             output = flash_attn_kvpacked_func(
                 query_states,
                 torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                 causal=is_causal,
             )
         else:
@@ -398,7 +406,7 @@ def flashattn_forward(
                 cu_seqlens_k,
                 max_seqlen_q,
                 max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                 softmax_scale=None,
                 causal=is_causal,
             )
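
The flash-attn kernels already accept a `dropout_p` argument, so the patch only needs to thread the resolved rate through each call site; the kernel applies whatever rate it is given, which is why the training-mode gate shown earlier is what keeps evaluation deterministic. A hedged, standalone sketch of one such call (dummy shapes, dtype, and rate chosen for illustration; assumes flash-attn is installed and a CUDA device is available):

# Hedged sketch: passing dropout_p into a flash-attn kernel call.
# Shapes, dtype, and the 0.1 rate are illustrative only; requires CUDA + flash-attn.
import torch
from flash_attn import flash_attn_kvpacked_func

bsz, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(bsz, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(bsz, seqlen, 2, nheads, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_kvpacked_func(q, kv, dropout_p=0.1, causal=True)  # (bsz, seqlen, nheads, headdim)
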
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py CHANGED
@@ -201,6 +201,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -213,7 +215,7 @@ def flashattn_forward(
             qkv,
             cu_seqlens,
             max_seqlen,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=True,
             window_size=window_size,
@@ -239,7 +241,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
             window_size=window_size,
@@ -253,6 +255,7 @@ def flashattn_forward(
             output = flash_attn_kvpacked_func(
                 query_states,
                 torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                 causal=is_causal,
                 window_size=window_size,
             )
@@ -286,7 +289,7 @@ def flashattn_forward(
                 cu_seqlens_k,
                 max_seqlen_q,
                 max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                 softmax_scale=None,
                 causal=is_causal,
                 window_size=window_size,
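
The rate itself comes from the attention module's `attention_dropout` attribute. Recent transformers releases expose this as a Llama/Mistral config field; on older releases the attribute is simply absent and the getattr fallback keeps the rate at 0.0. A hedged usage sketch, not part of this PR (model name and the 0.1 value are placeholders):

# Hedged usage sketch, not part of this PR: setting attention dropout through the
# transformers config so the patched forward picks it up during training.
# Model name and 0.1 rate are placeholders; older attention modules that never
# read this field are covered by the patch's getattr fallback.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
config.attention_dropout = 0.1
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", config=config)
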