DreamGenX
committed on
Respect sliding_window=None (#1214)
Browse files
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
CHANGED
@@ -94,7 +94,7 @@ def _prepare_decoder_attention_mask(
|
|
94 |
sliding_window,
|
95 |
): # pylint: disable=unused-argument
|
96 |
# [bsz, seq_len]
|
97 |
-
if attention_mask is None:
|
98 |
return attention_mask
|
99 |
|
100 |
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
@@ -151,7 +151,7 @@ def flashattn_forward(
|
|
151 |
)
|
152 |
|
153 |
use_sliding_windows = (
|
154 |
-
|
155 |
and kv_seq_len > self.config.sliding_window
|
156 |
)
|
157 |
|
|
|
94 |
sliding_window,
|
95 |
): # pylint: disable=unused-argument
|
96 |
# [bsz, seq_len]
|
97 |
+
if attention_mask is None or sliding_window is None:
|
98 |
return attention_mask
|
99 |
|
100 |
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
|
|
151 |
)
|
152 |
|
153 |
use_sliding_windows = (
|
154 |
+
getattr(self.config, "sliding_window") is not None
|
155 |
and kv_seq_len > self.config.sliding_window
|
156 |
)
|
157 |
|