"""Isolate the gradient bug: compare FlexAttention with eager softmax attention.

No Gumbel noise, no STE, and only an identity score_mod -- pure standard causal
attention. If gradients disagree here, FlexAttention's backward is broken. If
gradients agree here but disagree with our hybrid, the bug is in score_mod
capture or in the STE pattern.

Note: ``flex_attention`` scales scores by ``1/sqrt(Dh)`` when ``scale`` is not
given, so every comparison below passes the same explicit ``scale`` to both the
eager reference and FlexAttention.
"""
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

flex_c = torch.compile(flex_attention, dynamic=False)


def _identity_score_mod(s, b, h, q, kv):
    """Return the raw score unchanged (exercises the score_mod code path as a no-op)."""
    return s


def _causal(b, h, q, kv):
    """Block-mask predicate: query position ``q`` may attend to key positions ``kv <= q``."""
    return q >= kv


def _eager_attention(Q, K, V, causal_mask, scale):
    """Reference ``softmax(scale * Q @ K^T) @ V`` with a causal mask.

    Args:
        Q, K, V: attention inputs; assumed (B, H, T, Dh) as built in ``main``.
        causal_mask: bool (T, T) tensor, True where a key position must be hidden.
        scale: multiplier applied to the raw ``Q @ K^T`` scores before softmax.
    """
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    # -1e9 rather than -inf: every row keeps its diagonal, so no row is fully masked.
    scores = scores.masked_fill(causal_mask, -1e9)
    return torch.matmul(F.softmax(scores, dim=-1), V)


def _compare(label, Q, K, V, causal_mask, block_mask, scale):
    """Print forward and gradient agreement between eager and FlexAttention.

    Both outputs are recomputed here on every call. ``torch.autograd.grad``
    (default ``retain_graph=False``) frees each graph after one backward pass,
    so a cached output tensor cannot be backpropagated through a second time.
    """
    O_eager = _eager_attention(Q, K, V, causal_mask, scale)
    # Explicit scale: leaving it as None would make flex silently use 1/sqrt(Dh).
    O_flex = flex_c(Q, K, V, score_mod=_identity_score_mod,
                    block_mask=block_mask, scale=scale)

    diff = (O_eager - O_flex).detach().abs()
    print(f'(A) forward agreement: max={diff.max().item():.2e} '
          f'mean={diff.mean().item():.2e}')

    target = torch.randn_like(O_eager)
    grads_eager = torch.autograd.grad(((O_eager - target) ** 2).mean(), [Q, K, V])
    grads_flex = torch.autograd.grad(((O_flex - target) ** 2).mean(), [Q, K, V])

    print(f'(B) gradient agreement ({label}):')
    for name, ge, gf in zip(['Q', 'K', 'V'], grads_eager, grads_flex):
        cos = F.cosine_similarity(ge.flatten(), gf.flatten(), dim=0).item()
        rel = (ge - gf).norm().item() / ge.norm().item()
        print(f'  d{name}: cos_sim={cos:.6f}  rel_err={rel:.4e}  '
              f'enorm={ge.norm().item():.4f}  fnorm={gf.norm().item():.4f}')


def main():
    """Compare eager and FlexAttention outputs/gradients, unscaled and 1/sqrt(Dh)-scaled."""
    torch.manual_seed(0)
    B, H, T, Dh = 2, 4, 256, 32
    device = 'cuda'
    dtype = torch.float32

    Q = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    K = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    V = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)

    # True strictly above the diagonal: those key positions are hidden in eager mode.
    causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
    block_mask = create_block_mask(_causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device)

    # Previously eager ran unscaled while flex used its default 1/sqrt(Dh) scale,
    # which made this first comparison meaningless. Pin scale=1.0 on both sides.
    print('--- both unscaled (scale=1) ---')
    _compare('no score_mod, scale=1', Q, K, V, causal_mask, block_mask, scale=1.0)

    print('\n--- redo with both using 1/sqrt(Dh) scale ---')
    _compare('with scale', Q, K, V, causal_mask, block_mask, scale=1.0 / (Dh ** 0.5))


if __name__ == '__main__':
    main()