Commit
·
71aae6d
1
Parent(s):
8771224
fix: handle window_size passed as list
Browse files
mha.py
CHANGED
@@ -514,6 +514,10 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+
+        if isinstance(window_size, list):
+            window_size = tuple(window_size)
+
         if window_size != (-1, -1):
             assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 