import torch

from typing import Optional

# Fused attention kernels (invoked below via torch.ops.attention.*).
from .triton_attention import (
    fused_mha_with_paged_cache,
    fused_mha_with_cache,
)

dtype_int = torch.int32


def fused_mha_interface(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    max_seq_len: Optional[int] = None,
) -> torch.Tensor:
    """
    Replacement for _flash_attention_forward(...) that uses Triton's
    fused_mha_with_cache / fused_mha_with_paged_cache under the hood.

    Returns: [batch, q_len, heads, head_dim]
    """

    b, ql, n_heads, head_dim = query_states.shape
    _, kvl, n_kv_heads, _ = key_states.shape

    # Flatten heads into the hidden dimension: [batch, seq, heads * head_dim].
    q = query_states.reshape(b, ql, n_heads * head_dim)
    k = key_states.reshape(b, kvl, n_kv_heads * head_dim)
    v = value_states.reshape(b, kvl, n_kv_heads * head_dim)

    if position_ids is not None:
        if ql == 1:
            # Single-token decode: use the last provided position id.
            input_pos = position_ids[:, -1]
        else:
            # Prefill / multi-token chunk: use the first position id of the chunk.
            input_pos = position_ids[:, 0]
    else:
        input_pos = torch.zeros(b, device=q.device, dtype=torch.int32)

    # No rotary-embedding table is passed to the kernels.
    freqs_cis = None

    if page_table is None:
        # Contiguous (non-paged) KV-cache path.
        y = torch.ops.attention.fused_mha_with_cache(
            q, k, v,
            input_pos,
            k_cache, v_cache,
            freqs_cis,
        )
    else:
        batch_size = b

        # Per-sequence metadata for the paged kernel: one cache slot per
        # sequence, zero input positions, kvl tokens per sequence, and start
        # offsets computed as an exclusive prefix sum of seq_len.
        cache_loc = torch.arange(batch_size, device=q.device, dtype=dtype_int)
        input_positions = torch.zeros(batch_size, device=q.device, dtype=dtype_int)
        seq_len = torch.full((batch_size,), kvl, device=q.device, dtype=dtype_int)
        seq_start = (seq_len.cumsum(0) - seq_len).to(dtype=dtype_int)

        assert max_seq_len is not None, "max_seq_len must be provided when using paged attention."

        y = torch.ops.attention.fused_mha_with_paged_cache(
            q, k, v,
            input_positions, cache_loc,
            seq_len, seq_start,
            page_table, max_seq_len,
            k_cache, v_cache,
            freqs_cis,
        )

    # Restore the per-head layout: [batch, q_len, heads, head_dim].
    y = y.view(b, ql, n_heads, head_dim)

    return y

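# Optional extra smoke test for the contiguous-cache path, this time passing
# explicit position_ids (main() below relies on the zero default). The helper
# name and the chosen write position are illustrative only; shapes mirror
# main(), and it is not called anywhere in this module.
def example_decode_with_position_ids():
    device = "cuda"
    batch_size, q_len, kv_len = 1, 1, 1
    num_heads, head_dim = 16, 128
    max_batch_size, max_seq_len = 1, 1024

    q = torch.randn(batch_size, q_len, num_heads, head_dim, device=device)
    k = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
    v = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
    k_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)
    v_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)

    # position_ids has shape [batch, q_len]; with q_len == 1, fused_mha_interface
    # forwards position_ids[:, -1] to the kernel as input_pos. int32 matches the
    # dtype used for the default input_pos above.
    position_ids = torch.tensor([[5]], device=device, dtype=torch.int32)

    return fused_mha_interface(
        q, k, v,
        k_cache=k_cache,
        v_cache=v_cache,
        position_ids=position_ids,
    )

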
def main():
    # Smoke test: single-token decode through the contiguous-cache path.
    batch_size = 1
    q_len = 1
    kv_len = 1
    num_heads = 16
    n_kv_heads = 16
    head_dim = 128

    max_batch_size = 1
    max_seq_len = 1024

    page_size = 256

    device = "cuda"

    query_states = torch.randn(batch_size, q_len, num_heads, head_dim, device=device)
    key_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
    value_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)

    k_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)
    v_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)

    attn_out = fused_mha_interface(
        query_states,
        key_states,
        value_states,
        k_cache=k_cache,
        v_cache=v_cache,
    )

    expected_shape = (batch_size, q_len, num_heads, head_dim)
    print(f"[test] output shape: {attn_out.shape} (expected {expected_shape})")

    if attn_out.shape == expected_shape:
        print("[test] Success: output tensor has correct shape.")
    else:
        print("[test] Failure: shape mismatch.")


if __name__ == "__main__":
    main()