"""Validate the ``gumbel_hard_attn_co`` custom_op wrapper on a CUDA device.

Checks performed:
  (A) eager forward: output is finite and every output row lies within
      ``_ROW_MATCH_TOL`` (L2 distance) of some row of V;
  (B) eager backward: dQ, dK, dV from ``torch.autograd.grad`` are finite;
  (C) under ``torch.compile`` (catches fake-tensor / opaque-op issues):
      the same forward checks, plus finite gradients accumulated into
      ``.grad`` by ``.backward()``.

NOTE(review): ``gumbel_hard_attn_triton_v2`` (the autograd.Function path)
is imported but not compared against yet. The output presumably depends on
Gumbel sampling, so a direct comparison would need a shared seed passed to
both paths -- TODO.
"""
import torch
import torch.nn.functional as F

from triton_gumbel_hard_attn_v2 import gumbel_hard_attn_triton_v2
from triton_gumbel_hard_attn_v2_op import gumbel_hard_attn_co

# Maximum allowed L2 distance between an output row and its nearest V row.
_ROW_MATCH_TOL = 1e-2


def _require(cond, msg):
    """Raise ``AssertionError(msg)`` if *cond* is false.

    Used instead of bare ``assert`` so the checks still run under ``python -O``.
    """
    if not cond:
        raise AssertionError(msg)


def _max_min_dist_to_v_rows(O, V):
    """Return the worst-case distance from an output row to its closest V row.

    For each row i of O (second-to-last axis), compute the L2 distance to
    every row j of V that shares the same leading indices, take the minimum
    over j, then return the maximum of those minima as a Python float.
    """
    # No graph needed: this is a pure diagnostic, and the broadcast
    # intermediate has shape (..., T, T, D).
    with torch.no_grad():
        diffs = (O.float().unsqueeze(-2) - V.float().unsqueeze(-3)).pow(2).sum(-1).sqrt()
        return diffs.min(-1).values.max().item()


def _check_grads(names, grads, where):
    """Print norm/finiteness for each gradient and require it to exist and be finite.

    Args:
        names: iterable of short labels ('Q', 'K', 'V'), zipped with *grads*.
        grads: gradient tensors (entries may be None if backward did not reach them).
        where: context label used in failure messages.
    """
    for name, g in zip(names, grads):
        _require(g is not None, f'{where}: d{name} is None (no gradient reached it)')
        finite = torch.isfinite(g).all().item()
        print(f'    d{name}: norm={g.norm().item():.4f} finite={finite}')
        _require(finite, f'{where}: d{name} has non-finite values')


def _check_one_of_v_rows(O, V, where):
    """Print and require that every row of O coincides (within tolerance) with a V row."""
    max_d = _max_min_dist_to_v_rows(O, V)
    print(f'    min-dist to a V row: max={max_d:.2e}')
    _require(max_d < _ROW_MATCH_TOL,
             f'{where}: an output row is {max_d:.2e} from its nearest V row '
             f'(tol {_ROW_MATCH_TOL:.0e})')


def main():
    """Run the eager and torch.compile checks; print 'PASS' only if all hold."""
    if not torch.cuda.is_available():
        raise RuntimeError('this validation script requires a CUDA device')

    torch.manual_seed(0)
    B, H, T, D = 2, 4, 256, 48
    device = 'cuda'
    Q = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    K = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    V = torch.randn(B, H, T, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    # Four values for H == 4 heads; presumably one slope per head -- TODO confirm.
    slopes = torch.tensor([1.0, 2.0, 4.0, 8.0], device=device)
    tau = 0.1

    # === Eager: custom_op forward output is one-of-V-rows + finite gradients ===
    O_co = gumbel_hard_attn_co(Q, K, V, slopes, tau)
    fwd_finite = torch.isfinite(O_co).all().item()
    print(f'(A) eager custom_op fwd: shape={tuple(O_co.shape)} '
          f'finite={fwd_finite} '
          f'mean={O_co.float().mean().item():.4f}')
    _require(fwd_finite, 'eager: forward output has non-finite values')
    # Each output row should equal some V row exactly (hard attention).
    _check_one_of_v_rows(O_co, V, 'eager')

    target = torch.randn_like(O_co)
    loss = ((O_co - target) ** 2).mean()
    # autograd.grad returns gradients without touching .grad.
    g_co = torch.autograd.grad(loss, [Q, K, V])
    print('(B) eager gradient (finite + nontrivial):')
    _check_grads('QKV', g_co, 'eager')

    # === Under torch.compile ===
    # Start from a clean slate so .grad below reflects only the compiled backward.
    Q.grad = K.grad = V.grad = None

    def f(Q, K, V, slopes, tau):
        return gumbel_hard_attn_co(Q, K, V, slopes, tau)

    cf = torch.compile(f, mode='default', dynamic=False)
    print('(C) torch.compile fwd+bwd...')
    O_cc = cf(Q, K, V, slopes, tau)
    loss = ((O_cc - target) ** 2).mean()
    loss.backward()

    cc_finite = torch.isfinite(O_cc).all().item()
    print(f'    out shape={tuple(O_cc.shape)} '
          f'finite={cc_finite}')
    _require(O_cc.shape == O_co.shape,
             f'compiled: output shape {tuple(O_cc.shape)} != eager {tuple(O_co.shape)}')
    _require(cc_finite, 'compiled: forward output has non-finite values')
    _check_one_of_v_rows(O_cc, V, 'compiled')
    _check_grads('QKV', (Q.grad, K.grad, V.grad), 'compiled')
    print('PASS')


if __name__ == '__main__':
    main()