# Source: bitnet-1bitllm / _test_flex_score_mod.py
# Uploader: hidude562
# Commit message: 1bitllm code (checkpoints to follow)
# Commit: 4754707 (verified)
"""Isolate which operations in score_mod break Inductor lowering on this PT/triton."""
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# Module-level compiled wrapper around flex_attention, shared by every test case.
# dynamic=False is passed to torch.compile so shapes are specialized rather than
# traced dynamically.
flex_c = torch.compile(flex_attention, dynamic=False)
def try_score_mod(label, score_mod_fn, captures=None):
    """Run one compiled flex_attention forward/backward and print PASS or FAIL.

    Random bfloat16 Q/K/V tensors of shape (4, 16, 2048, 48) are created on
    CUDA, a causal block mask is built, and the module-level compiled
    ``flex_c`` is called with ``score_mod_fn`` under bfloat16 autocast. The
    summed output is then back-propagated so the backward lowering is
    exercised as well.

    Args:
        label: Human-readable case name, printed as a section header.
        score_mod_fn: Callable ``(score, b, h, q_idx, kv_idx) -> score``
            passed to flex_attention as ``score_mod``.
        captures: Accepted for call-site compatibility; not used here.

    Returns:
        None. The outcome is reported on stdout only.
    """
    print(f'\n=== {label} ===')

    # Each distinct score_mod callable forces Dynamo to recompile ``flex_c``.
    # Without a reset, a long run of cases can reach Dynamo's recompile limit,
    # after which the call is no longer compiled and a case could "PASS"
    # without ever going through Inductor. Start every case from a clean slate.
    torch.compiler.reset()

    B, H, T, Dh = 4, 16, 2048, 48
    device = 'cuda'
    Q = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True)
    K = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True)
    V = torch.randn(B, H, T, Dh, device=device, dtype=torch.bfloat16, requires_grad=True)

    def causal(b, h, q, kv):
        return q >= kv

    bm = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device)

    try:
        with torch.autocast('cuda', dtype=torch.bfloat16):
            O = flex_c(Q, K, V, score_mod=score_mod_fn, block_mask=bm, scale=1.0)
        O.sum().backward()
    except Exception as e:
        # Deliberately broad: this is a diagnostic harness and any failure of
        # the compiled call is the result being reported, not a crash.
        full_text = str(e)
        msg = full_text.split('\n')[0]
        # Prefer the first known marker found in the full text (checked in
        # list order) over the first line, which is often just a wrapper.
        for kw in ('AssertionError', 'wrong ndim', 'FlexibleLayout', 'NotImplementedError'):
            i = full_text.find(kw)
            if i >= 0:
                msg = full_text[i:i + 100]
                break
        if not msg.strip():
            # Exceptions raised without a message (e.g. a bare ``assert``)
            # stringify to '', which would print an uninformative "FAIL: ".
            msg = repr(e)[:100]
        print(f' FAIL: {msg}')
    else:
        print(f' PASS shape={O.shape}')
def main():
    """Run the score_mod cases in order, from trivial to the full expression.

    Each case is handed to ``try_score_mod`` with a label describing the
    expression under test. Cases build up incrementally (identity, scalar
    divide, tensor divide, per-head bias, positional terms, their
    combinations, and finally a captured 4-D tensor lookup) so the first
    failing case points at the operation that breaks lowering.
    """
    num_heads = 16
    # float32 powers of two: 1, 2, 4, ..., 2**(num_heads - 1), one per head.
    slopes_f = torch.tensor(
        [2 ** exponent for exponent in range(num_heads)],
        dtype=torch.float32,
        device='cuda',
    )
    tau_t = torch.tensor(0.1, device='cuda')    # 0-dim tensor
    tau_1 = torch.tensor([0.1], device='cuda')  # shape (1,)
    tau_squeezed = tau_1.squeeze()              # (1,) squeezed to 0-dim

    # (label, score_mod) pairs, executed strictly in this order.
    cases = [
        # Test 1: identity
        ('1. identity',
         lambda s, b, h, q, kv: s),
        # Test 2: scalar divide (Python float)
        ('2. s / 0.1 (python float)',
         lambda s, b, h, q, kv: s / 0.1),
        # Test 3: 0-dim tensor divide
        ('3. s / tau_t (0-dim tensor)',
         lambda s, b, h, q, kv: s / tau_t),
        # Test 4: (1,) tensor squeezed
        ('4. s / tau_1.squeeze()',
         lambda s, b, h, q, kv: s / tau_squeezed),
        # Test 5: subtract slopes-indexed bias
        ('5. s - slopes_f[h]',
         lambda s, b, h, q, kv: s - slopes_f[h]),
        # Test 6: q - kv
        ('6. s + (q - kv)',
         lambda s, b, h, q, kv: s + (q - kv)),
        # Test 7: abs(q - kv)
        ('7. s + abs(q - kv)',
         lambda s, b, h, q, kv: s + (q - kv).abs()),
        # Test 8: full ALiBi (slopes * |q-kv|, no .float())
        ('8. s - slopes_f[h] * (q - kv).abs()',
         lambda s, b, h, q, kv: s - slopes_f[h] * (q - kv).abs()),
        # Test 9: alibi + tau divide
        ('9. (s - slopes_f[h] * (q - kv).abs()) / tau_t',
         lambda s, b, h, q, kv: (s - slopes_f[h] * (q - kv).abs()) / tau_t),
    ]
    for case_label, case_fn in cases:
        try_score_mod(case_label, case_fn)

    # Test 10: with captured g
    # Allocated only now so the large float32 buffer is not resident while
    # the earlier cases run. Its (4, num_heads, 2048, 2048) shape matches the
    # B, H, T, T sizes hard-coded in try_score_mod.
    g = torch.zeros(4, num_heads, 2048, 2048, device='cuda', dtype=torch.float32)
    try_score_mod(
        '10. + g[b,h,q,kv]',
        lambda s, b, h, q, kv:
            (s + g[b, h, q, kv] - slopes_f[h] * (q - kv).abs()) / tau_t,
    )
# Run the diagnostic cases only when executed as a script, not on import.
if __name__ == '__main__':
    main()