# bitnet-1bitllm / _test_v2_co.py
# hidude562's picture
# 1bitllm code (checkpoints to follow)
# 4754707 verified
"""Validate the custom_op wrapper:
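(1) eager forward + backward sanity checks: every output row must match a V row,
    and gradient norms/finiteness are reported (the autograd.Function path is
    imported but not compared against here)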
(2) under torch.compile (catches fake-tensor / opaque-op issues)
"""
import torch
import torch.nn.functional as F
from triton_gumbel_hard_attn_v2 import gumbel_hard_attn_triton_v2
from triton_gumbel_hard_attn_v2_op import gumbel_hard_attn_co
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used instead of a bare ``assert`` so the checks still run under
    ``python -O``, which strips assert statements.
    """
    if not condition:
        raise AssertionError(message)


def _report_grad(label, name, grad):
    """Print the norm/finiteness of one gradient and fail if it is unusable.

    A gradient is rejected when it contains NaN/Inf or when its norm is not
    strictly positive (an all-zero gradient means nothing flowed back).
    """
    norm = grad.norm().item()
    finite = torch.isfinite(grad).all().item()
    print(f' d{name}: norm={norm:.4f} finite={finite}')
    _require(finite, f'{label}: d{name} contains non-finite values')
    _require(norm > 0.0, f'{label}: d{name} is identically zero')


def main():
    """Smoke-test ``gumbel_hard_attn_co`` in eager mode and under torch.compile.

    Steps:
      (A) eager forward: the output must be finite and every output row must
          lie within 1e-2 (L2 distance) of some row of V in the same
          (batch, head) slice.
      (B) eager backward: gradients w.r.t. Q, K, V must be finite and non-zero.
      (C) torch.compile forward + backward: output and gradients must be
          finite and the gradients non-zero.

    Raises:
        RuntimeError: if no CUDA device is available.
        AssertionError: if any of the checks above fails.
    """
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA device required to validate gumbel_hard_attn_co')

    torch.manual_seed(0)
    B, H, T, D = 2, 4, 256, 48
    device = 'cuda'
    Q = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    K = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    V = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    # One slope per head (H == 4).
    slopes = torch.tensor([1.0, 2.0, 4.0, 8.0], device=device)
    tau = 0.1

    # === Eager: custom_op forward output is one-of-V-rows + finite gradients ===
    O_co = gumbel_hard_attn_co(Q, K, V, slopes, tau)
    fwd_finite = torch.isfinite(O_co).all().item()
    print(f'(A) eager custom_op fwd: shape={tuple(O_co.shape)} '
          f'finite={fwd_finite} '
          f'mean={O_co.float().mean().item():.4f}')
    _require(fwd_finite, '(A): eager forward output contains non-finite values')

    # Each output row should equal some V row exactly.
    # Broadcasting pairs every output position with every V row of the same
    # (batch, head): (B, H, T, 1, D) - (B, H, 1, T, D) -> (B, H, T, T, D),
    # assuming the op returns a tensor shaped like V.
    diffs = (O_co.float().unsqueeze(-2) - V.float().unsqueeze(-3)).pow(2).sum(-1).sqrt()
    min_d = diffs.min(-1).values
    worst = min_d.max().item()
    print(f' min-dist to a V row: max={worst:.2e}')
    # NaN compares False, so a NaN distance is rejected as well.
    _require(worst < 1e-2, f'(A): output row is not a V row (max min-dist={worst:.2e})')

    target = torch.randn_like(O_co)
    loss = ((O_co - target) ** 2).mean()
    g_co = torch.autograd.grad(loss, [Q, K, V])
    print('(B) eager gradient (finite + nontrivial):')
    for name, g in zip(['Q', 'K', 'V'], g_co):
        _report_grad('(B)', name, g)

    # === Under torch.compile ===
    # Start from empty .grad slots so the values read below come from this
    # backward pass only.
    Q.grad = K.grad = V.grad = None

    def f(Q, K, V, slopes, tau):
        return gumbel_hard_attn_co(Q, K, V, slopes, tau)

    cf = torch.compile(f, mode='default', dynamic=False)
    print('(C) torch.compile fwd+bwd...')
    O_cc = cf(Q, K, V, slopes, tau)
    loss = ((O_cc - target) ** 2).mean()
    loss.backward()
    cc_finite = torch.isfinite(O_cc).all().item()
    print(f' out shape={tuple(O_cc.shape)} finite={cc_finite}')
    _require(cc_finite, '(C): compiled forward output contains non-finite values')
    for name, leaf in zip(['Q', 'K', 'V'], [Q, K, V]):
        _require(leaf.grad is not None, f'(C): no gradient reached {name}')
        _report_grad('(C)', name, leaf.grad)
    print('PASS')
# Script entry point: run the validation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()