import torch
from typing import Optional

# Importing the Triton kernels registers the custom torch.ops.attention ops.
from .triton_attention import (
    fused_mha_with_paged_cache,
    fused_mha_with_cache,
)

dtype_int = torch.int32


def fused_mha_interface(
    query_states: torch.Tensor,  # [batch, q_len, heads, head_dim]
    key_states: torch.Tensor,    # [batch, kv_len, heads, head_dim]
    value_states: torch.Tensor,  # [batch, kv_len, heads, head_dim]
    k_cache: torch.Tensor,       # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD], or [num_pages, page_size, n, d] for paged attention
    v_cache: torch.Tensor,       # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
    position_ids: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,  # [b, max_num_pages_per_seq]: location of each sequence's pages in the cache
    max_seq_len: Optional[int] = None,
) -> torch.Tensor:
    """
    Replacement for _flash_attention_forward(...) that uses Triton's
    fused_mha_with_paged_cache under the hood.

    Returns:
        [batch, q_len, heads, head_dim]
    """
    # Unpack shapes.
    b, ql, n_heads, head_dim = query_states.shape
    _, kvl, n_kv_heads, _ = key_states.shape

    # Flatten the head dimension: the fused kernels expect [batch, seq, heads * head_dim].
    q = query_states.reshape(b, ql, n_heads * head_dim)
    k = key_states.reshape(b, kvl, n_kv_heads * head_dim)
    v = value_states.reshape(b, kvl, n_kv_heads * head_dim)

    if position_ids is not None:
        if ql == 1:
            # Generate phase (single token): use the last position for each sequence.
            input_pos = position_ids[:, -1]
        else:
            # Context phase (multiple tokens): use the starting position for each sequence.
            input_pos = position_ids[:, 0]
    else:
        # Fallback: assume all sequences start from position 0.
        input_pos = torch.zeros(b, device=q.device, dtype=torch.int32)

    freqs_cis = None

    if page_table is None:
        y = torch.ops.attention.fused_mha_with_cache(
            q,
            k,
            v,
            input_pos,
            k_cache,
            v_cache,
            freqs_cis,
        )
    else:
        batch_size = b
        # cache_loc: identity mapping [0, 1, ..., b-1].
        cache_loc = torch.arange(batch_size, device=q.device, dtype=dtype_int)
        # input_positions: assume pure context (all sequences start from 0).
        input_positions = torch.zeros(batch_size, device=q.device, dtype=dtype_int)
        # seq_len: each sequence has length kvl.
        seq_len = torch.full((batch_size,), kvl, device=q.device, dtype=dtype_int)
        # seq_start: flattened starting index of each sequence (exclusive cumulative sum).
        seq_start = (seq_len.cumsum(0) - seq_len).to(dtype=dtype_int)

        assert max_seq_len is not None, "max_seq_len must be provided when using paged attention."
        y = torch.ops.attention.fused_mha_with_paged_cache(
            q,
            k,
            v,
            input_positions,
            cache_loc,
            seq_len,
            seq_start,
            page_table,
            max_seq_len,
            k_cache,
            v_cache,
            freqs_cis,
        )

    y = y.view(b, ql, n_heads, head_dim)
    return y


def main():
    # --- Test hyperparameters ---------------------------------------------------------
    batch_size = 1
    q_len = 1
    kv_len = 1
    num_heads = 16
    n_kv_heads = 16
    head_dim = 128
    max_batch_size = 1
    max_seq_len = 1024
    page_size = 256  # only relevant for the paged-cache path; unused in this test
    device = "cuda"

    # --- Random query, key, value tensors ---------------------------------------------
    query_states = torch.randn(batch_size, q_len, num_heads, head_dim, device=device)
    key_states = torch.randn(batch_size, kv_len, n_kv_heads, head_dim, device=device)
    value_states = torch.randn(batch_size, kv_len, n_kv_heads, head_dim, device=device)

    k_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)
    v_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)

    attn_out = fused_mha_interface(
        query_states,
        key_states,
        value_states,
        k_cache=k_cache,
        v_cache=v_cache,
    )

    expected_shape = (batch_size, q_len, num_heads, head_dim)
    print(f"[test] output shape: {attn_out.shape} (expected {expected_shape})")
    if attn_out.shape == expected_shape:
        print("[test] ✅ Success: output tensor has correct shape.")
    else:
        print("[test] ❌ Failure: shape mismatch.")


if __name__ == "__main__":
    main()
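

# Hedged sketch (not invoked by main()): a minimal smoke test for the paged-cache path.
# The [num_pages, page_size, n_heads, head_dim] cache layout follows the comment on
# k_cache above, but the way pages are assigned to sequences (one contiguous block of
# pages per sequence) is an assumption for illustration, not a confirmed contract of
# torch.ops.attention.fused_mha_with_paged_cache.
def _paged_smoke_test():
    batch_size = 1
    kv_len = 1
    num_heads = 16
    head_dim = 128
    max_seq_len = 1024
    page_size = 256
    device = "cuda"

    num_pages_per_seq = max_seq_len // page_size
    num_pages = batch_size * num_pages_per_seq

    query_states = torch.randn(batch_size, 1, num_heads, head_dim, device=device)
    key_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
    value_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)

    # Paged caches: [num_pages, page_size, n_heads, head_dim] (assumed layout).
    k_cache = torch.randn(num_pages, page_size, num_heads, head_dim, device=device)
    v_cache = torch.randn(num_pages, page_size, num_heads, head_dim, device=device)

    # page_table: [batch, max_num_pages_per_seq]; assume sequence i owns pages
    # [i * num_pages_per_seq, ..., (i + 1) * num_pages_per_seq - 1].
    page_table = torch.arange(num_pages, device=device, dtype=torch.int32).view(
        batch_size, num_pages_per_seq
    )

    attn_out = fused_mha_interface(
        query_states,
        key_states,
        value_states,
        k_cache=k_cache,
        v_cache=v_cache,
        page_table=page_table,
        max_seq_len=max_seq_len,
    )
    print(f"[paged sketch] output shape: {attn_out.shape}")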