# bitnet-1bitllm / _test_flex_attn.py
# hidude562's picture
# 1bitllm code (checkpoints to follow)
# 4754707 verified
"""Standalone smoke test for the FlexAttention hybrid path.
Two checks:
(A) eager path vs hybrid path, forward output agreement (should match
exactly because hard path is identical and STE.detach() makes forward
= O_hard regardless of soft path)
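(B) gradient direction agreement (should be high cos_sim — dQ/dK flow
through the soft path in both cases, just computed differently; dV uses the
hard one-hot weights in the eager path but the soft weights in the hybrid
path, so it agrees only approximately)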
Runs in fp32 throughout (no autocast) to eliminate precision noise.
Run on VM:
/venv/main/bin/python _test_flex_attn.py
"""
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# Module-level torch.compile wrapper around flex_attention; attn_flex_hybrid
# calls this instead of the uncompiled function. dynamic=False is passed to
# torch.compile — presumably to specialize on the fixed test shapes rather
# than trace with dynamic shapes (confirm against the torch.compile docs).
flex_attention_c = torch.compile(flex_attention, dynamic=False)
def gumbel_hard_attention_eager(scores, mask, tau, g):
    """Return straight-through Gumbel hard attention weights (eager reference).

    Args:
        scores: attention logits; the last dimension is the key axis.
        mask: boolean tensor, True where a position must be excluded. Those
            logits are overwritten with -1e9 before noise is added.
        tau: softmax temperature (the noisy logits are divided by it).
        g: noise tensor added to the masked logits.

    Returns:
        A tensor shaped like ``scores`` whose forward value is (up to float
        rounding) the one-hot vector at the argmax of the soft weights, while
        gradients flow only through the soft weights.
    """
    masked_logits = scores.masked_fill(mask, -1e9)
    soft = torch.softmax((masked_logits + g) / tau, dim=-1)

    # One-hot encoding of the winning key for every query row.
    winners = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft)
    hard.scatter_(-1, winners, 1.0)

    # Straight-through estimator: the detached correction makes the forward
    # value hard while leaving ``soft`` as the only differentiable term.
    return soft + (hard - soft).detach()
def attn_eager(Q, K, V, alibi_bias, causal_mask, tau, g):
    """Reference attention: dense logits, ALiBi penalty, hard Gumbel weights.

    Computes ``Q @ K^T`` (no extra scaling), subtracts ``alibi_bias``, turns
    the result into straight-through hard weights with
    ``gumbel_hard_attention_eager`` (using ``causal_mask``, ``tau`` and the
    noise ``g``), and applies those weights to ``V``.

    Returns:
        The attention output ``weights @ V``.
    """
    logits = Q @ K.transpose(-2, -1) - alibi_bias
    weights = gumbel_hard_attention_eager(logits, causal_mask, tau, g)
    return weights @ V
def attn_flex_hybrid(Q, K, V, slopes_f, tau, g, block_mask, block_q=128):
    """Hybrid attention: FlexAttention soft path plus chunked hard argmax path.

    The soft output comes from the compiled ``flex_attention_c`` with a
    ``score_mod`` that adds the noise ``g``, subtracts a per-head distance
    penalty ``slopes_f[h] * |q - kv|`` and divides by ``tau``. The hard output
    gathers, for every query, the single value row whose noisy biased logit is
    largest among keys at or before the query position. The two are combined
    with a straight-through estimator so the forward value is the hard output
    while gradients flow through the soft output only.

    Note: because the hard output is detached, the gradient w.r.t. ``V`` here
    comes from the soft attention weights.

    Args:
        Q, K, V: tensors of shape (B, H, T, Dh) (unpacked from ``Q.shape``).
        slopes_f: per-head bias slopes, indexed by head (H entries).
        tau: softmax temperature used in the soft path. It is omitted from the
            argmax path, since dividing by a positive temperature does not
            change the argmax.
        g: noise tensor indexed as ``g[b, h, q, kv]``.
        block_mask: block mask forwarded to ``flex_attention_c``.
        block_q: number of query rows processed per chunk when computing the
            argmax (bounds the size of the temporary score matrix).

    Returns:
        Tensor of shape (B, H, T, Dh).
    """
    B, H, T, Dh = Q.shape

    def score_mod(s, b, h, q, kv):
        bias = slopes_f[h] * (q - kv).abs().float()
        return (s + g[b, h, q, kv] - bias) / tau

    O_soft = flex_attention_c(Q, K, V, score_mod=score_mod,
                              block_mask=block_mask, scale=1.0)

    # Chunked fp32 argmax (no_grad — argmax has no gradient anyway).
    argmax = torch.empty(B, H, T, dtype=torch.int64, device=Q.device)
    pos_kv = torch.arange(T, device=Q.device).view(1, 1, 1, T)
    with torch.no_grad():
        # Loop-invariant operands: cast/reshape once instead of per chunk.
        K_T = K.transpose(-2, -1).float()
        slopes = slopes_f.view(1, H, 1, 1)
        for q_start in range(0, T, block_q):
            q_end = min(q_start + block_q, T)
            Q_chunk = Q[:, :, q_start:q_end, :].float()
            scores_chunk = torch.matmul(Q_chunk, K_T)
            pos_q = torch.arange(q_start, q_end, device=Q.device).view(1, 1, -1, 1)
            bias = slopes * (pos_q - pos_kv).abs().float()
            scores_chunk = scores_chunk - bias + g[:, :, q_start:q_end, :].float()
            # Exclude keys that lie after the query position.
            causal_chunk = pos_kv > pos_q
            scores_chunk.masked_fill_(causal_chunk, -1e9)
            argmax[:, :, q_start:q_end] = scores_chunk.argmax(-1)

    # Pick the winning value row for every query.
    idx = argmax.unsqueeze(-1).expand(-1, -1, -1, Dh)
    O_hard = V.gather(2, idx)
    # Straight-through: forward value is O_hard, gradient path is O_soft.
    return O_soft + (O_hard - O_soft).detach()
def main():
    """Run the eager-vs-hybrid smoke test and print agreement statistics.

    Builds random fp32 Q/K/V, ALiBi-style distance bias, a causal mask and
    Gumbel noise on a CUDA device, then prints:
      (A) forward max/mean absolute difference between ``attn_eager`` and
          ``attn_flex_hybrid``;
      (B) per-input gradient cosine similarity, relative error and norms for
          an MSE loss against a shared random target.

    Nothing is asserted; the numbers are meant to be inspected by hand.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError('_test_flex_attn.py requires a CUDA device')

    torch.manual_seed(0)
    B, H, T, Dh = 2, 4, 256, 32
    device = 'cuda'
    # fp32 throughout — eliminates bf16 ALiBi precision noise so eager and
    # hybrid see the same scores.
    dtype = torch.float32
    Q = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    K = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)
    V = torch.randn(B, H, T, Dh, device=device, dtype=dtype, requires_grad=True)

    # Modest slopes so bf16 cast inside any internal operation isn't lossy.
    slopes_f = torch.tensor([1.0, 2.0, 4.0, 8.0], device=device)
    pos = torch.arange(T, device=device)
    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().float()
    alibi_bias = (slopes_f.view(H, 1, 1) * dist.view(1, T, T)).unsqueeze(0)
    # True strictly above the diagonal, i.e. for keys after the query.
    causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
    tau = torch.tensor(0.1, device=device)
    # Gumbel noise -log(-log(u)); the clamp and the 1e-9 offset keep both
    # logarithms away from log(0).
    g = -torch.log(-torch.log(torch.rand(B, H, T, T, device=device,
                                         dtype=dtype).clamp_(min=1e-9)) + 1e-9)

    def causal(b, h, q, kv):
        return q >= kv

    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device)

    # === Forward agreement ===
    O_eager = attn_eager(Q, K, V, alibi_bias, causal_mask, tau, g)
    O_flex = attn_flex_hybrid(Q, K, V, slopes_f, tau, g, block_mask)
    diff = (O_eager - O_flex).abs()
    print(f'(A) forward max abs diff: {diff.max().item():.2e} '
          f'mean abs diff: {diff.mean().item():.2e}')
    print(f' eager mean: {O_eager.abs().mean().item():.4f} '
          f'flex mean: {O_flex.abs().mean().item():.4f}')

    # === Gradient agreement ===
    # torch.autograd.grad returns the gradients directly and does not
    # accumulate into .grad, so no reset is needed between the two calls.
    target = torch.randn_like(O_eager)
    L_eager = ((O_eager - target) ** 2).mean()
    grads_eager = torch.autograd.grad(L_eager, [Q, K, V], retain_graph=False)
    L_flex = ((O_flex - target) ** 2).mean()
    grads_flex = torch.autograd.grad(L_flex, [Q, K, V], retain_graph=False)

    print('(B) gradient agreement:')
    for name, ge, gf in zip(['Q', 'K', 'V'], grads_eager, grads_flex):
        cos = F.cosine_similarity(ge.flatten(), gf.flatten(), dim=0).item()
        rel_err = (ge - gf).norm().item() / (ge.norm().item() + 1e-12)
        print(f' d{name}: cos_sim={cos:.6f} rel_err={rel_err:.4e} '
              f'eager_norm={ge.norm().item():.4f} flex_norm={gf.norm().item():.4f}')
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()