| """Isolate which operations in score_mod break Inductor lowering on this PT/triton.""" |
| import torch |
| from torch.nn.attention.flex_attention import flex_attention, create_block_mask |
|
|
# Wrap flex_attention with torch.compile once at module level so every probe
# goes through the same compiled entry point; dynamic=False asks the compiler
# to specialize on the concrete input shapes instead of tracing dynamic shapes.
flex_c = torch.compile(flex_attention, dynamic=False)
|
|
|
|
def try_score_mod(label, score_mod_fn, captures=None):
    """Run compiled flex_attention forward+backward with one score_mod and report.

    Builds random bfloat16 Q/K/V of shape (4, 16, 2048, 48) on CUDA and a causal
    block mask, calls the module-level compiled ``flex_c`` under bfloat16
    autocast, then backpropagates ``O.sum()``. Prints a ``PASS`` line with the
    output shape, or a ``FAIL`` line with a short excerpt of the error.

    Args:
        label: Human-readable name printed in the section header.
        score_mod_fn: Callable ``(score, b, h, q_idx, kv_idx) -> score`` passed
            to flex_attention as ``score_mod``.
        captures: Unused; kept so existing callers that pass it keep working.

    Returns:
        None. Results are reported on stdout only; exceptions raised by the
        compiled call or the backward pass are caught and summarized.
    """
    print(f'\n=== {label} ===')

    # Every probe passes a different score_mod to the same compiled callable,
    # which forces a recompile each time. Dynamo caps recompiles per code
    # object (8 by default); past that it can fall back to running the frame
    # uncompiled, which would print PASS without ever exercising Inductor
    # lowering. Resetting the compile caches keeps each probe independent.
    torch.compiler.reset()

    B, H, T, Dh = 4, 16, 2048, 48
    device = 'cuda'
    Q = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True)
    K = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True)
    V = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True)

    def causal(b, h, q, kv):
        return q >= kv

    # B=None / H=None: the mask is shared across batch and heads.
    bm = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device)
    try:
        with torch.autocast('cuda', dtype=torch.bfloat16):
            O = flex_c(Q, K, V, score_mod=score_mod_fn, block_mask=bm, scale=1.0)
        # Backward is part of the probe: lowering can fail there too.
        O.sum().backward()
        print(f' PASS shape={O.shape}')
    except Exception as e:
        # Deliberately broad: this is a diagnostic harness and any failure
        # mode of the compile/run is a result worth reporting.
        s = str(e)
        # Default to the first line; fall back to the exception type so an
        # exception with an empty message does not print a blank FAIL.
        msg = s.split('\n')[0] or type(e).__name__
        # Prefer the excerpt starting at the first recognizable keyword, which
        # is usually buried inside a long compiler traceback message.
        for kw in ('AssertionError', 'wrong ndim', 'FlexibleLayout', 'NotImplementedError'):
            i = s.find(kw)
            if i >= 0:
                msg = s[i:i + 100]
                break
        print(f' FAIL: {msg}')
|
|
|
|
def main():
    """Run the score_mod probes in order of increasing complexity.

    Each probe adds one ingredient (python-float divide, captured 0-dim
    tensor, per-head lookup, index arithmetic, ...) so the first FAIL points
    at the operation that breaks lowering. All captured tensors live on CUDA.
    """
    n_heads = 16
    cuda = 'cuda'
    slopes_f = torch.tensor([2 ** k for k in range(n_heads)], dtype=torch.float32, device=cuda)
    tau_t = torch.tensor(0.1, device=cuda)   # 0-dim tensor
    tau_1 = torch.tensor([0.1], device=cuda)  # shape (1,)
    tau_sq = tau_1.squeeze()                  # 0-dim view of tau_1

    # (label, score_mod) pairs; the lambdas close over the tensors above.
    probes = [
        ('1. identity',
         lambda s, b, h, q, kv: s),
        ('2. s / 0.1 (python float)',
         lambda s, b, h, q, kv: s / 0.1),
        ('3. s / tau_t (0-dim tensor)',
         lambda s, b, h, q, kv: s / tau_t),
        ('4. s / tau_1.squeeze()',
         lambda s, b, h, q, kv: s / tau_sq),
        ('5. s - slopes_f[h]',
         lambda s, b, h, q, kv: s - slopes_f[h]),
        ('6. s + (q - kv)',
         lambda s, b, h, q, kv: s + (q - kv)),
        ('7. s + abs(q - kv)',
         lambda s, b, h, q, kv: s + (q - kv).abs()),
        ('8. s - slopes_f[h] * (q - kv).abs()',
         lambda s, b, h, q, kv: s - slopes_f[h] * (q - kv).abs()),
        ('9. (s - slopes_f[h] * (q - kv).abs()) / tau_t',
         lambda s, b, h, q, kv: (s - slopes_f[h] * (q - kv).abs()) / tau_t),
    ]
    for label, score_mod in probes:
        try_score_mod(label, score_mod)

    # The dense bias table is large, so allocate it only once the cheaper
    # probes are done. Its dims presumably mirror the (B, H, T, T) sizes
    # hard-coded in try_score_mod -- keep them in sync if those change.
    g = torch.zeros(4, n_heads, 2048, 2048, device=cuda, dtype=torch.float32)
    try_score_mod(
        '10. + g[b,h,q,kv]',
        lambda s, b, h, q, kv: (s + g[b, h, q, kv] - slopes_f[h] * (q - kv).abs()) / tau_t,
    )
|
|
|
|
# Run the probes only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|