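# bitnet-1bitllm / _test_flex_attn2.py
# Author: hidude562
# Commit message: "1bitllm code (checkpoints to follow)"
# Revision: 4754707 (verified)
"""Isolate the gradient bug: compare FlexAttention against eager softmax attention
with no custom score_mod (identity only), no Gumbel noise, and no STE — i.e. plain
causal attention.

If gradients disagree here, FlexAttention's backward is broken.
If gradients agree here but disagree with our hybrid, the bug is in the score_mod
capture or the STE pattern.

Note: FlexAttention scales scores by 1/sqrt(Dh) by default, so the eager reference
must apply the same scale for a like-for-like comparison.
"""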
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
flex_c = torch.compile(flex_attention, dynamic=False)
def _eager_causal_attention(Q, K, V, causal_mask, scale):
    """Return softmax(scale * Q @ K^T, causally masked) @ V, computed eagerly.

    Args:
        Q, K, V: attention inputs (main() passes (B, H, T, Dh) tensors).
        causal_mask: bool (T, T) tensor, True where the key index exceeds the
            query index, i.e. positions to exclude.
        scale: multiplier applied to the raw Q K^T scores before softmax.
    """
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    # Large finite negative so masked logits contribute ~0 after softmax.
    scores = scores.masked_fill(causal_mask, -1e9)
    return torch.matmul(F.softmax(scores, dim=-1), V)


def _report_agreement(label, O_eager, O_flex, Q, K, V):
    """Print forward and gradient agreement between an eager and a Flex output.

    Both outputs are scored with an MSE loss against the same random target, and
    gradients are taken w.r.t. Q, K and V. torch.autograd.grad (default
    retain_graph=False) frees each output's graph, so neither output may be
    differentiated again after this call.
    """
    with torch.no_grad():
        diff = (O_eager - O_flex).abs()
    print(f'(A) forward agreement: max={diff.max().item():.2e} mean={diff.mean().item():.2e}')
    target = torch.randn_like(O_eager)
    ge = torch.autograd.grad(((O_eager - target) ** 2).mean(), [Q, K, V])
    gf = torch.autograd.grad(((O_flex - target) ** 2).mean(), [Q, K, V])
    print(f'(B) gradient agreement ({label}):')
    for name, e, f in zip(('Q', 'K', 'V'), ge, gf):
        cos = F.cosine_similarity(e.flatten(), f.flatten(), dim=0).item()
        rel = (e - f).norm().item() / e.norm().item()
        print(f' d{name}: cos_sim={cos:.6f} rel_err={rel:.4e} '
              f'enorm={e.norm().item():.4f} fnorm={f.norm().item():.4f}')


def main():
    """Compare eager vs FlexAttention causal attention, forward and backward.

    Run 1 compares an unscaled eager reference against FlexAttention's default
    scaling, so a mismatch is expected. Run 2 applies 1/sqrt(Dh) on both sides
    for a like-for-like check. Requires a CUDA device.
    """
    if not torch.cuda.is_available():
        raise SystemExit('This test requires a CUDA device.')
    torch.manual_seed(0)
    B, H, T, Dh = 2, 4, 256, 32
    device = 'cuda'
    dtype = torch.float32
    Q = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    K = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    V = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    # Eager mask: True strictly above the diagonal (key in the query's future).
    causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)

    def causal(b, h, q, kv):
        # Flex mask_mod: keep a (query, key) pair unless the key is in the future.
        return q >= kv

    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device)

    def score_mod(s, b, h, q, kv):
        # Identity score_mod: scores pass through unchanged.
        return s

    # Run 1: eager softmax(Q K^T) @ V (no scale) vs FlexAttention, which uses
    # its default scale of 1/sqrt(Dh) -- this is why the two differ.
    O_eager = _eager_causal_attention(Q, K, V, causal_mask, scale=1.0)
    O_flex = flex_c(Q, K, V, score_mod=score_mod, block_mask=block_mask)
    _report_agreement('no score_mod', O_eager, O_flex, Q, K, V)

    # Run 2: both sides use the standard 1/sqrt(Dh) scale.
    print('\n--- redo with both using 1/sqrt(Dh) scale ---')
    scale = 1.0 / (Dh ** 0.5)
    O_eager2 = _eager_causal_attention(Q, K, V, causal_mask, scale=scale)
    # Recompute rather than reuse O_flex: its autograd graph was freed by the
    # gradient call in run 1, and differentiating it again would raise.
    O_flex2 = flex_c(Q, K, V, score_mod=score_mod, block_mask=block_mask, scale=scale)
    _report_agreement('with scale', O_eager2, O_flex2, Q, K, V)
# Script entry point: run the comparison only when executed directly.
if __name__ == '__main__':
    main()