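"""Validate the custom_op wrapper:
  (1) eager forward + backward smoke test of the custom op (every output row
      is checked against the rows of V; gradient norms / finiteness are reported)
  (2) the same forward + backward under torch.compile (catches fake-tensor /
      opaque-op issues)

NOTE(review): the original summary described (1) as "eager forward + backward
correctness vs the autograd.Function path", but this script never calls a
reference implementation or compares against one -- TODO: add that comparison
or confirm it lives elsewhere.
"""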
| import torch |
| import torch.nn.functional as F |
| from triton_gumbel_hard_attn_v2 import gumbel_hard_attn_triton_v2 |
| from triton_gumbel_hard_attn_v2_op import gumbel_hard_attn_co |
|
|
|
|
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    An explicit raise is used instead of ``assert`` so the validation still
    runs when Python is started with ``-O`` (which strips assert statements).
    """
    if not condition:
        raise AssertionError(message)


def main():
    """Run the eager and torch.compile validation of ``gumbel_hard_attn_co``.

    Steps:
      (A) eager forward: the output must be finite and every output row must
          lie within 1e-2 (Euclidean distance) of some row of ``V`` from the
          same batch/head.
      (B) eager backward via ``torch.autograd.grad``: gradients w.r.t. Q, K, V
          must be finite.
      (C) forward + backward through ``torch.compile``: output and the
          accumulated ``.grad`` of Q, K, V must be finite.

    Diagnostics are always printed before a check is enforced, so a failing
    run still shows the offending numbers. ``PASS`` is printed only when every
    check succeeded.

    Raises:
        AssertionError: if any of the checks above fails.

    Requires a CUDA device (all tensors are created on ``'cuda'``).
    """
    torch.manual_seed(0)
    B, H, T, D = 2, 4, 256, 48
    device = 'cuda'

    # Q, K, V are (B, H, T, D) bf16 leaf tensors that require grad.
    Q = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    K = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    V = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    # Four slope values, matching H = 4 (presumably one per head -- confirm
    # against the op's signature).
    slopes = torch.tensor([1.0, 2.0, 4.0, 8.0], device=device)
    tau = 0.1

    # ---- (A) eager forward -------------------------------------------------
    O_co = gumbel_hard_attn_co(Q, K, V, slopes, tau)
    fwd_finite = torch.isfinite(O_co).all().item()
    print(f'(A) eager custom_op fwd: shape={tuple(O_co.shape)} '
          f'finite={fwd_finite} '
          f'mean={O_co.float().mean().item():.4f}')

    # Pairwise distances between every output row and every V row of the same
    # (batch, head): broadcasting (B,H,T,1,D) - (B,H,1,T,D) -> (B,H,T,T).
    diffs = (O_co.float().unsqueeze(-2) - V.float().unsqueeze(-3)).pow(2).sum(-1).sqrt()
    min_d = diffs.min(-1).values
    worst = min_d.max().item()
    print(f' min-dist to a V row: max={worst:.2e}')
    _require(fwd_finite, 'eager forward output contains non-finite values')
    _require(worst < 1e-2,
             f'eager output is not within 1e-2 of a V row (max min-dist={worst:.2e})')

    # ---- (B) eager backward ------------------------------------------------
    target = torch.randn_like(O_co)
    L = ((O_co - target) ** 2).mean()
    g_co = torch.autograd.grad(L, [Q, K, V])
    print('(B) eager gradient (finite + nontrivial):')
    bad = []
    for name, g in zip(['Q', 'K', 'V'], g_co):
        finite = torch.isfinite(g).all().item()
        print(f' d{name}: norm={g.norm().item():.4f} finite={finite}')
        if not finite:
            bad.append(f'd{name}')
    _require(not bad, f'eager gradients contain non-finite values: {", ".join(bad)}')

    # Make sure .grad starts empty before backward() below accumulates into it.
    Q.grad = K.grad = V.grad = None

    # ---- (C) torch.compile forward + backward ------------------------------
    def f(Q, K, V, slopes, tau):
        return gumbel_hard_attn_co(Q, K, V, slopes, tau)
    cf = torch.compile(f, mode='default', dynamic=False)

    print('(C) torch.compile fwd+bwd...')
    O_cc = cf(Q, K, V, slopes, tau)
    L = ((O_cc - target) ** 2).mean()
    L.backward()
    out_finite = torch.isfinite(O_cc).all().item()
    print(f' out shape={tuple(O_cc.shape)} finite={out_finite}')
    bad = []
    for name, leaf in zip(['Q', 'K', 'V'], [Q, K, V]):
        grad = leaf.grad
        finite = torch.isfinite(grad).all().item()
        print(f' d{name} norm={grad.norm().item():.4f} finite={finite}')
        if not finite:
            bad.append(f'd{name}')
    _require(out_finite, 'compiled forward output contains non-finite values')
    _require(not bad, f'compiled gradients contain non-finite values: {", ".join(bad)}')

    print('PASS')
|
|
|
|
# Run the validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|