# bitnet-1bitllm / analyze_induction.py
# hidude562's picture
# 1bitllm code (checkpoints to follow)
# 4754707 verified
"""Standalone induction-pattern detector for v75 pointer attention.
Lightweight wrapper around analyze_pointer_health's IPMR computation —
runs faster (no head-stats overhead), prints per-(layer, head) induction
match rates, and identifies the top-K heads that look most induction-like.
Usage:
python analyze_induction.py --ckpt path/to/ckpt.pt --data validation.bin
"""
import argparse
import json
import os
import sys
from collections import defaultdict
import numpy as np
# Reuse the heavy lifting from the health probe
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from analyze_pointer_health import (
load_model, collect_pointers, sample_random_data,
compute_M3_IPMR, compute_M3_IPMR_per_head,
)
import torch
def _resolve_data_path(data_arg):
    """Return an existing validation-data path, or None to request synthetic data.

    An explicitly supplied path is used when it exists; if it does not, a
    warning naming the path is printed and None is returned. When no path is
    supplied, a short list of conventional locations is probed in order.
    """
    if data_arg:
        if os.path.exists(data_arg):
            return data_arg
        print(f'[induction] WARNING: --data {data_arg} not found', flush=True)
        return None
    candidates = ('/root/bitnet1/data_fineweb_edu/validation.bin',
                  './data_fineweb_edu/validation.bin')
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def _synthetic_ids(vocab_size, n_batches, batch_size, seq_len, span=5):
    """Build random token ids with deliberate repeated spans.

    Returns an int64 array of shape (n_batches, batch_size, seq_len) drawn
    from a fixed seed. Two `span`-token windows are copied to later positions
    (50 -> 200 and 100 -> 350) so that induction has something to match.
    A repeat is skipped when its destination does not fit inside `seq_len`,
    which keeps short sequences from raising a broadcast error.
    """
    rng = np.random.RandomState(0xface)
    ids = rng.randint(0, vocab_size,
                      size=(n_batches, batch_size, seq_len),
                      dtype=np.int64)
    for src, dst in ((50, 200), (100, 350)):
        if dst + span <= seq_len:
            # Source and destination windows never overlap, so a single
            # vectorised copy over all batches/rows is safe.
            ids[:, :, dst:dst + span] = ids[:, :, src:src + span]
    return ids


def _collect_layer_pointers(model, ids_np, n_layers, n_batches, device):
    """Run one forward pass per batch and stack the per-layer pointer outputs.

    Returns a list with one array per layer, where the leading axis indexes
    the batch. Assumes `ids_np[c]` is a NumPy array accepted by
    `torch.from_numpy` and that `collect_pointers` returns one stackable
    entry per layer -- TODO confirm against analyze_pointer_health.
    """
    per_layer_chunks = [[] for _ in range(n_layers)]
    # Analysis only: no gradients are needed for the forward passes.
    with torch.no_grad():
        for c in range(n_batches):
            x = torch.from_numpy(ids_np[c]).to(device)
            per_layer = collect_pointers(model, x)
            for li in range(n_layers):
                per_layer_chunks[li].append(per_layer[li])
            print(f'[induction] forward {c+1}/{n_batches}', flush=True)
    return [np.stack(chunks, axis=0) for chunks in per_layer_chunks]


def _print_top_heads(ipmr_per_lh, top_k):
    """Print the `top_k` (layer, head) pairs with the highest non-zero IPMR."""
    arr = np.array(ipmr_per_lh)
    if not (arr.size and arr.max() > 0):
        print('\nNo (layer, head) pair has any induction-pattern matches.')
        return
    # Descending order over the flattened (layer, head) grid.
    order = np.argsort(arr.flatten())[::-1][:top_k]
    print(f'\nTop-{top_k} (layer, head) by IPMR:')
    for flat in order:
        layer, head = divmod(int(flat), arr.shape[1])
        value = arr[layer, head]
        if value > 0:
            print(f'  L{layer:2d} H{head:2d}: {value*100:.3f}%')


def _induction_json_path(ckpt_path):
    """Derive the report path by swapping the checkpoint extension.

    Only the final extension is replaced, so the result can never equal the
    checkpoint path itself (which would overwrite the checkpoint), and
    directory names containing '.pt' are left untouched.
    """
    root, _ext = os.path.splitext(ckpt_path)
    return root + '_induction.json'


def _json_default(obj):
    """Convert NumPy / torch values to plain Python for json.dump."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


def main():
    """Command-line entry point: measure induction-pattern match rates.

    Loads the checkpoint given by --ckpt, obtains token ids (sampled from
    --data or a default validation file when one exists, otherwise synthetic
    ids with planted repeats), collects pointers for every layer, prints the
    global and per-(layer, head) IPMR, and writes a JSON report next to the
    checkpoint as ``<ckpt stem>_induction.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default=None)
    ap.add_argument('--n-batches', type=int, default=4)
    ap.add_argument('--batch-size', type=int, default=4)
    ap.add_argument('--seq-len', type=int, default=512)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--top-k', type=int, default=15,
                    help='Show top-K (layer, head) pairs by IPMR')
    ap.add_argument('--kshift-layers', default=None,
                    help='Override saved k_shift_layers (comma-separated indices, e.g. "0"). '
                         'Needed for ckpts saved before k_shift_layers was persisted in args.')
    ap.add_argument('--qshift-layers', default=None,
                    help='Override saved q_shift_layers (comma-separated)')
    args = ap.parse_args()

    print(f'[induction] loading {args.ckpt}', flush=True)
    m, kw, val_bpc, step = load_model(args.ckpt, device=args.device,
                                      kshift_layers_override=args.kshift_layers,
                                      qshift_layers_override=args.qshift_layers)
    print(f'[induction] val_bpc={val_bpc} step={step}', flush=True)

    data_path = _resolve_data_path(args.data)
    if data_path is not None:
        ids_np = sample_random_data(data_path, args.n_batches, args.batch_size, args.seq_len)
        print(f'[induction] sampled from {data_path}', flush=True)
    else:
        print('[induction] no validation data; using synthetic with planned repeats')
        # Build sequences with deliberate 5-token span repeats to give
        # induction a chance.
        ids_np = _synthetic_ids(kw['vocab_size'], args.n_batches,
                                args.batch_size, args.seq_len)

    # Forward pass per batch
    layer_pointers = _collect_layer_pointers(m, ids_np, kw['n_layers'],
                                             args.n_batches, args.device)
    ipmr_global, opps = compute_M3_IPMR(layer_pointers, ids_np)
    ipmr_per_lh, _ = compute_M3_IPMR_per_head(layer_pointers, ids_np)

    print('\n=== INDUCTION-PATTERN DETECTION ===')
    print(f'val_bpc={val_bpc} step={step}')
    print(f'Induction opportunities: {opps}')
    print(f'Global IPMR (any-head match rate): {ipmr_global*100:.3f}%')
    _print_top_heads(ipmr_per_lh, args.top_k)

    # Save JSON
    out = _induction_json_path(args.ckpt)
    with open(out, 'w', encoding='utf-8') as f:
        json.dump({
            'ckpt': args.ckpt, 'val_bpc': val_bpc, 'step': step,
            'global_ipmr': ipmr_global, 'opportunities': opps,
            'per_layer_head_ipmr': ipmr_per_lh,
        }, f, indent=2, default=_json_default)
    print(f'\nWrote {out}')
# Run the detector only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()