"""Isolate which operations in score_mod break Inductor lowering on this PT/triton.""" import torch from torch.nn.attention.flex_attention import flex_attention, create_block_mask flex_c = torch.compile(flex_attention, dynamic=False) def try_score_mod(label, score_mod_fn, captures=None): print(f'\n=== {label} ===') B, H, T, Dh = 4, 16, 2048, 48 device = 'cuda' Q = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True) K = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True) V = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True) def causal(b, h, q, kv): return q >= kv bm = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device) try: with torch.autocast('cuda', dtype=torch.bfloat16): O = flex_c(Q, K, V, score_mod=score_mod_fn, block_mask=bm, scale=1.0) O.sum().backward() print(f' PASS shape={O.shape}') except Exception as e: msg = str(e).split('\n')[0] # Find the actual assertion line if present s = str(e) for kw in ['AssertionError', 'wrong ndim', 'FlexibleLayout', 'NotImplementedError']: i = s.find(kw) if i >= 0: msg = s[i:i+100] break print(f' FAIL: {msg}') def main(): H = 16 slopes_f = torch.tensor([1 << i for i in range(H)], dtype=torch.float32, device='cuda') tau_t = torch.tensor(0.1, device='cuda') tau_1 = torch.tensor([0.1], device='cuda') # Test 1: identity try_score_mod('1. identity', lambda s, b, h, q, kv: s) # Test 2: scalar divide (Python float) try_score_mod('2. s / 0.1 (python float)', lambda s, b, h, q, kv: s / 0.1) # Test 3: 0-dim tensor divide try_score_mod('3. s / tau_t (0-dim tensor)', lambda s, b, h, q, kv: s / tau_t) # Test 4: (1,) tensor squeezed tau_sq = tau_1.squeeze() try_score_mod('4. s / tau_1.squeeze()', lambda s, b, h, q, kv: s / tau_sq) # Test 5: subtract slopes-indexed bias try_score_mod('5. s - slopes_f[h]', lambda s, b, h, q, kv: s - slopes_f[h]) # Test 6: q - kv try_score_mod('6. s + (q - kv)', lambda s, b, h, q, kv: s + (q - kv)) # Test 7: abs(q - kv) try_score_mod('7. s + abs(q - kv)', lambda s, b, h, q, kv: s + (q - kv).abs()) # Test 8: full ALiBi (slopes * |q-kv|, no .float()) try_score_mod('8. s - slopes_f[h] * (q - kv).abs()', lambda s, b, h, q, kv: s - slopes_f[h] * (q - kv).abs()) # Test 9: alibi + tau divide try_score_mod('9. (s - slopes_f[h] * (q - kv).abs()) / tau_t', lambda s, b, h, q, kv: (s - slopes_f[h] * (q - kv).abs()) / tau_t) # Test 10: with captured g g = torch.zeros(4, H, 2048, 2048, device='cuda', dtype=torch.float32) try_score_mod('10. + g[b,h,q,kv]', lambda s, b, h, q, kv: (s + g[b, h, q, kv] - slopes_f[h] * (q - kv).abs()) / tau_t) if __name__ == '__main__': main()