| """Isolate the gradient bug: test FlexAttention vs eager softmax attention |
| WITHOUT any score_mod, no Gumbel, no STE. Pure standard attention. |
| |
| If gradients disagree here, FlexAttention's backward is broken. |
| If gradients agree here but disagree with our hybrid, the bug is in score_mod |
| capture or STE pattern. |
| """ |
| import torch |
| import torch.nn.functional as F |
| from torch.nn.attention.flex_attention import flex_attention, create_block_mask |
|
|
| flex_c = torch.compile(flex_attention, dynamic=False) |
|
|
|
|
| def main(): |
| torch.manual_seed(0) |
| B, H, T, Dh = 2, 4, 256, 32 |
| device = 'cuda' |
| dtype = torch.float32 |
|
|
| Q = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True) |
| K = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True) |
| V = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True) |
|
|
| causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1) |
| def causal(b, h, q, kv): return q >= kv |
| block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device) |
|
|
| |
| scores = torch.matmul(Q, K.transpose(-2, -1)) |
| scores = scores.masked_fill(causal_mask, -1e9) |
| A = F.softmax(scores, dim=-1) |
| O_eager = torch.matmul(A, V) |
|
|
| |
| def score_mod(s, b, h, q, kv): return s |
| O_flex = flex_c(Q, K, V, score_mod=score_mod, block_mask=block_mask) |
|
|
| diff = (O_eager - O_flex).abs() |
| print(f'(A) forward agreement: max={diff.max().item():.2e} mean={diff.mean().item():.2e}') |
|
|
| target = torch.randn_like(O_eager) |
| L_e = ((O_eager - target) ** 2).mean() |
| ge = torch.autograd.grad(L_e, [Q, K, V], retain_graph=False) |
|
|
| L_f = ((O_flex - target) ** 2).mean() |
| gf = torch.autograd.grad(L_f, [Q, K, V], retain_graph=False) |
|
|
| print(f'(B) gradient agreement (no score_mod):') |
| for name, e, f in zip(['Q', 'K', 'V'], ge, gf): |
| cos = F.cosine_similarity(e.flatten(), f.flatten(), dim=0).item() |
| rel = (e - f).norm().item() / e.norm().item() |
| print(f' d{name}: cos_sim={cos:.6f} rel_err={rel:.4e} ' |
| f'enorm={e.norm().item():.4f} fnorm={f.norm().item():.4f}') |
|
|
| |
| |
| |
| |
| print('\n--- redo with both using 1/sqrt(Dh) scale ---') |
| scale = 1.0 / (Dh ** 0.5) |
| Q.grad = K.grad = V.grad = None |
| scores2 = torch.matmul(Q, K.transpose(-2, -1)) * scale |
| scores2 = scores2.masked_fill(causal_mask, -1e9) |
| A2 = F.softmax(scores2, dim=-1) |
| O_eager2 = torch.matmul(A2, V) |
|
|
| diff2 = (O_eager2 - O_flex).abs() |
| print(f'(A) forward agreement: max={diff2.max().item():.2e} mean={diff2.mean().item():.2e}') |
|
|
| target2 = torch.randn_like(O_eager2) |
| L_e2 = ((O_eager2 - target2) ** 2).mean() |
| ge2 = torch.autograd.grad(L_e2, [Q, K, V], retain_graph=False) |
|
|
| L_f2 = ((O_flex - target2) ** 2).mean() |
| gf2 = torch.autograd.grad(L_f2, [Q, K, V], retain_graph=False) |
|
|
| print(f'(B) gradient agreement (with scale):') |
| for name, e, f in zip(['Q', 'K', 'V'], ge2, gf2): |
| cos = F.cosine_similarity(e.flatten(), f.flatten(), dim=0).item() |
| rel = (e - f).norm().item() / e.norm().item() |
| print(f' d{name}: cos_sim={cos:.6f} rel_err={rel:.4e} ' |
| f'enorm={e.norm().item():.4f} fnorm={f.norm().item():.4f}') |
|
|
|
|
# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|