Spaces: Running on Zero
Commit: Feature(MInference): fix the func name
Browse files
Browse files
minference/ops/block_sparse_flash_attention.py
CHANGED
@@ -444,7 +444,7 @@ def test_flash_attention(
 444     print('========================================\n')
 445
 446
 447 -def (                                                  <!-- old function name lost in page extraction -->
 447 +def block_sparse_attention(
 448     query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
 449     key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
 450     value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
minference/ops/pit_sparse_flash_attention_v2.py
CHANGED
@@ -693,7 +693,7 @@ def test_flash_attention(
 693     torch.testing.assert_close(output_flash, output_triton_sparse, atol=1e-2, rtol=0)
 694
 695
 696 -def (                                                  <!-- old function name lost in page extraction -->
 696 +def vertical_slash_sparse_attention(
 697     query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
 698     key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
 699     value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]