| """Standalone smoke test for the FlexAttention hybrid path. |
| |
| Two checks: |
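  (A) eager path vs hybrid path, forward output agreement. Should match to
      within fp32 rounding (barring exact argmax ties): both paths pick the
      same per-query argmax key, and the straight-through
      `soft + (hard - soft).detach()` makes the forward value equal the hard
      output regardless of the soft path.
  (B) gradient agreement. dQ and dK should have cos_sim close to 1, because
      both paths backpropagate through the same soft attention weights (just
      computed differently). dV is expected to differ somewhat: the eager path
      multiplies by the one-hot forward value of the attention matrix, while
      the hybrid path's dV comes from the soft weights.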
| |
Runs in fp32 throughout (no autocast) to keep precision noise small.
| |
| Run on VM: |
| /venv/main/bin/python _test_flex_attn.py |
| """ |
| import torch |
| import torch.nn.functional as F |
| from torch.nn.attention.flex_attention import flex_attention, create_block_mask |
|
|
|
|
| flex_attention_c = torch.compile(flex_attention, dynamic=False) |
|
|
|
|
| def gumbel_hard_attention_eager(scores, mask, tau, g): |
| s = scores.masked_fill(mask, -1e9) |
| y_soft = F.softmax((s + g) / tau, dim=-1) |
| y_hard = torch.zeros_like(y_soft) |
| y_hard.scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0) |
| return y_soft + (y_hard - y_soft).detach() |
|
|
|
|
| def attn_eager(Q, K, V, alibi_bias, causal_mask, tau, g): |
| scores = torch.matmul(Q, K.transpose(-2, -1)) |
| scores = scores - alibi_bias |
| A = gumbel_hard_attention_eager(scores, causal_mask, tau, g) |
| return torch.matmul(A, V) |
|
|
|
|
def attn_flex_hybrid(Q, K, V, slopes_f, tau, g, block_mask, *, block_q=128):
    """Hard Gumbel attention: FlexAttention soft path + chunked argmax hard path.

    The forward value is the hard output (V gathered at each query's argmax
    key); gradients flow only through the soft FlexAttention output, via the
    straight-through combination ``O_soft + (O_hard - O_soft).detach()``.

    Args:
        Q, K, V: attention inputs; B, H, T, Dh are unpacked from ``Q.shape``
            (K and V are assumed to share B, H and T -- TODO confirm if this
            is ever used with differing key length).
        slopes_f: per-head ALiBi slopes, indexed by head (length H).
        tau: Gumbel-softmax temperature (assumed > 0).
        g: Gumbel noise indexed as ``g[b, h, q, kv]``.
        block_mask: BlockMask for the soft path. NOTE: the hard path always
            applies a causal mask (keys with kv > q are excluded), so
            ``block_mask`` must encode the same causal pattern or the two
            paths will disagree.
        block_q: number of query rows per chunk in the hard path; bounds the
            temporary score buffer to (B, H, block_q, T).

    Returns:
        Tensor shaped like ``Q`` (B, H, T, Dh).
    """
    B, H, T, Dh = Q.shape

    def score_mod(s, b, h, q, kv):
        # (QK^T + gumbel - alibi) / tau, the same logits the dense path uses.
        bias = slopes_f[h] * (q - kv).abs().float()
        return (s + g[b, h, q, kv] - bias) / tau

    # Soft path (differentiable). scale=1.0 overrides FlexAttention's default
    # 1/sqrt(Dh) scaling so the logits are the raw dot products.
    O_soft = flex_attention_c(Q, K, V, score_mod=score_mod,
                              block_mask=block_mask, scale=1.0)

    # Hard path: per-query argmax of the same logits, computed in fp32 in
    # query chunks so the full (T, T) score matrix is never materialized.
    # Division by tau is omitted: it does not change the argmax when tau > 0.
    hard_idx = torch.empty(B, H, T, dtype=torch.int64, device=Q.device)
    with torch.no_grad():
        # Loop-invariant tensors, hoisted out of the chunk loop.
        K_T = K.transpose(-2, -1).float()
        slopes = slopes_f.view(1, H, 1, 1)
        pos_kv = torch.arange(T, device=Q.device).view(1, 1, 1, T)
        for q_start in range(0, T, block_q):
            q_end = min(q_start + block_q, T)
            pos_q = torch.arange(q_start, q_end, device=Q.device).view(1, 1, -1, 1)
            scores_chunk = torch.matmul(Q[:, :, q_start:q_end, :].float(), K_T)
            bias = slopes * (pos_q - pos_kv).abs().float()
            scores_chunk = scores_chunk - bias + g[:, :, q_start:q_end, :].float()
            # Causal: keys after the query position are excluded.
            scores_chunk.masked_fill_(pos_kv > pos_q, -1e9)
            hard_idx[:, :, q_start:q_end] = scores_chunk.argmax(-1)

    # O_hard[b, h, q, :] = V[b, h, hard_idx[b, h, q], :]
    O_hard = V.gather(2, hard_idx.unsqueeze(-1).expand(-1, -1, -1, Dh))
    # Straight-through: forward value O_hard, gradient of O_soft.
    return O_soft + (O_hard - O_soft).detach()
|
|
|
|
def main():
    """Run checks (A) and (B) and print the results (nothing is asserted).

    Builds random fp32 Q/K/V on CUDA with per-head ALiBi biases, a causal
    mask and Gumbel noise, then compares ``attn_eager`` against
    ``attn_flex_hybrid``: (A) forward outputs, and (B) gradients of an MSE
    loss with respect to Q, K and V.

    Raises:
        SystemExit: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise SystemExit('this smoke test requires a CUDA device')
    torch.manual_seed(0)
    B, H, T, Dh = 2, 4, 256, 32
    device = 'cuda'
    # fp32 everywhere (no autocast) so differences are not dominated by
    # low-precision rounding.
    dtype = torch.float32

    Q = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    K = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    V = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)

    # Per-head ALiBi slopes 1, 2, 4, ... (one power of two per head, so the
    # slope count always matches H). The eager path gets the dense
    # (1, H, T, T) bias; the flex path recomputes it from the slopes.
    slopes_f = 2.0 ** torch.arange(H, device=device, dtype=torch.float32)
    pos = torch.arange(T, device=device)
    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().float()
    alibi_bias = (slopes_f.view(H, 1, 1) * dist.view(1, T, T)).unsqueeze(0)
    # True strictly above the diagonal: keys after the query are masked out.
    causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
    tau = torch.tensor(0.1, device=device)
    # Gumbel noise -log(-log(U)); the clamp and epsilon keep both logs finite.
    uniform = torch.rand(B, H, T, T, device=device, dtype=dtype).clamp_(min=1e-9)
    g = -torch.log(-torch.log(uniform) + 1e-9)

    def causal(b, h, q, kv):
        return q >= kv

    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device)

    # (A) forward agreement
    O_eager = attn_eager(Q, K, V, alibi_bias, causal_mask, tau, g)
    O_flex = attn_flex_hybrid(Q, K, V, slopes_f, tau, g, block_mask)
    diff = (O_eager - O_flex).abs()
    print(f'(A) forward max abs diff: {diff.max().item():.2e} '
          f'mean abs diff: {diff.mean().item():.2e}')
    print(f' eager mean: {O_eager.abs().mean().item():.4f} '
          f'flex mean: {O_flex.abs().mean().item():.4f}')

    # (B) gradient agreement. torch.autograd.grad returns the gradients
    # directly and never writes .grad, so no reset is needed between passes.
    # dQ/dK should agree closely; dV is expected to differ somewhat because
    # the eager dV uses the one-hot forward value of the attention matrix,
    # while the hybrid dV comes from the soft weights.
    target = torch.randn_like(O_eager)
    L_eager = ((O_eager - target) ** 2).mean()
    grads_eager = torch.autograd.grad(L_eager, [Q, K, V])
    L_flex = ((O_flex - target) ** 2).mean()
    grads_flex = torch.autograd.grad(L_flex, [Q, K, V])

    print('(B) gradient agreement:')
    for name, ge, gf in zip(['Q', 'K', 'V'], grads_eager, grads_flex):
        ge_norm = ge.norm().item()
        gf_norm = gf.norm().item()
        cos = F.cosine_similarity(ge.flatten(), gf.flatten(), dim=0).item()
        rel_err = (ge - gf).norm().item() / (ge_norm + 1e-12)
        print(f' d{name}: cos_sim={cos:.6f} rel_err={rel_err:.4e} '
              f'eager_norm={ge_norm:.4f} flex_norm={gf_norm:.4f}')
|
|
|
|
if __name__ == '__main__':
    # Run the smoke test only when executed as a script; importing the
    # module just defines the functions and the compiled flex wrapper.
    main()
|
|