"""Standalone smoke test for the FlexAttention hybrid path. Two checks: (A) eager path vs hybrid path, forward output agreement (should match exactly because hard path is identical and STE.detach() makes forward = O_hard regardless of soft path) (B) gradient direction agreement (should be high cos_sim — backward goes through soft path in both cases, just computed differently) Runs in fp32 throughout (no autocast) to eliminate precision noise. Run on VM: /venv/main/bin/python _test_flex_attn.py """ import torch import torch.nn.functional as F from torch.nn.attention.flex_attention import flex_attention, create_block_mask flex_attention_c = torch.compile(flex_attention, dynamic=False) def gumbel_hard_attention_eager(scores, mask, tau, g): s = scores.masked_fill(mask, -1e9) y_soft = F.softmax((s + g) / tau, dim=-1) y_hard = torch.zeros_like(y_soft) y_hard.scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0) return y_soft + (y_hard - y_soft).detach() def attn_eager(Q, K, V, alibi_bias, causal_mask, tau, g): scores = torch.matmul(Q, K.transpose(-2, -1)) scores = scores - alibi_bias A = gumbel_hard_attention_eager(scores, causal_mask, tau, g) return torch.matmul(A, V) def attn_flex_hybrid(Q, K, V, slopes_f, tau, g, block_mask): B, H, T, Dh = Q.shape def score_mod(s, b, h, q, kv): bias = slopes_f[h] * (q - kv).abs().float() return (s + g[b, h, q, kv] - bias) / tau O_soft = flex_attention_c(Q, K, V, score_mod=score_mod, block_mask=block_mask, scale=1.0) # Chunked fp32 argmax (no_grad — argmax has no gradient anyway). BLOCK_Q = 128 argmax = torch.empty(B, H, T, dtype=torch.int64, device=Q.device) K_T = K.transpose(-2, -1) pos_kv = torch.arange(T, device=Q.device).view(1, 1, 1, T) with torch.no_grad(): for q_start in range(0, T, BLOCK_Q): q_end = min(q_start + BLOCK_Q, T) Q_chunk = Q[:, :, q_start:q_end, :].float() scores_chunk = torch.matmul(Q_chunk, K_T.float()) pos_q = torch.arange(q_start, q_end, device=Q.device).view(1, 1, -1, 1) bias = slopes_f.view(1, H, 1, 1) * (pos_q - pos_kv).abs().float() scores_chunk = scores_chunk - bias + g[:, :, q_start:q_end, :].float() causal_chunk = pos_kv > pos_q scores_chunk.masked_fill_(causal_chunk, -1e9) argmax[:, :, q_start:q_end] = scores_chunk.argmax(-1) idx = argmax.unsqueeze(-1).expand(-1, -1, -1, Dh) O_hard = V.gather(2, idx) return O_soft + (O_hard - O_soft).detach() def main(): torch.manual_seed(0) B, H, T, Dh = 2, 4, 256, 32 device = 'cuda' # fp32 throughout — eliminates bf16 ALiBi precision noise so eager and # hybrid see the same scores. dtype = torch.float32 Q = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True) K = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True) V = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True) # Modest slopes so bf16 cast inside any internal operation isn't lossy. slopes_f = torch.tensor([1.0, 2.0, 4.0, 8.0], device=device) pos = torch.arange(T, device=device) dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().float() alibi_bias = (slopes_f.view(H, 1, 1) * dist.view(1, T, T)).unsqueeze(0) causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1) tau = torch.tensor(0.1, device=device) g = -torch.log(-torch.log(torch.rand(B, H, T, T, device=device, dtype=dtype).clamp_(min=1e-9)) + 1e-9) def causal(b, h, q, kv): return q >= kv block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device) # === Forward agreement === O_eager = attn_eager(Q, K, V, alibi_bias, causal_mask, tau, g) O_flex = attn_flex_hybrid(Q, K, V, slopes_f, tau, g, block_mask) diff = (O_eager - O_flex).abs() print(f'(A) forward max abs diff: {diff.max().item():.2e} ' f'mean abs diff: {diff.mean().item():.2e}') print(f' eager mean: {O_eager.abs().mean().item():.4f} ' f'flex mean: {O_flex.abs().mean().item():.4f}') # === Gradient agreement === target = torch.randn_like(O_eager) L_eager = ((O_eager - target) ** 2).mean() grads_eager = torch.autograd.grad(L_eager, [Q, K, V], retain_graph=False) Q.grad = K.grad = V.grad = None L_flex = ((O_flex - target) ** 2).mean() grads_flex = torch.autograd.grad(L_flex, [Q, K, V], retain_graph=False) print(f'(B) gradient agreement:') for name, ge, gf in zip(['Q', 'K', 'V'], grads_eager, grads_flex): cos = F.cosine_similarity(ge.flatten(), gf.flatten(), dim=0).item() rel_err = (ge - gf).norm().item() / (ge.norm().item() + 1e-12) print(f' d{name}: cos_sim={cos:.6f} rel_err={rel_err:.4e} ' f'eager_norm={ge.norm().item():.4f} flex_norm={gf.norm().item():.4f}') if __name__ == '__main__': main()