"""Standalone induction-pattern detector for v75 pointer attention.

Lightweight wrapper around analyze_pointer_health's IPMR computation — runs
faster (no head-stats overhead), prints per-(layer, head) induction match
rates, and identifies the top-K heads that look most induction-like.

The results are also written as JSON next to the checkpoint, named
``<checkpoint path without extension>_induction.json``.

Usage:
    python analyze_induction.py --ckpt path/to/ckpt.pt --data validation.bin
"""
import argparse
import json
import os
import sys
from collections import defaultdict

import numpy as np

# Reuse the heavy lifting from the health probe
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from analyze_pointer_health import (
    load_model,
    collect_pointers,
    sample_random_data,
    compute_M3_IPMR,
    compute_M3_IPMR_per_head,
)
import torch

# Data files tried, in order, when --data is not given.
_DEFAULT_DATA_CANDIDATES = (
    '/root/bitnet1/data_fineweb_edu/validation.bin',
    './data_fineweb_edu/validation.bin',
)

# Synthetic-data repeats: each (src, dst) pair copies _REPEAT_SPAN tokens from
# position src to position dst, so a span seen earlier recurs later and an
# induction-style head has something to match.
_SYNTHETIC_REPEATS = ((50, 200), (100, 350))
_REPEAT_SPAN = 5


def _resolve_data_path(explicit):
    """Return an existing data file to sample from, or None for synthetic data.

    An explicitly requested path is used only if it exists (a warning is
    printed otherwise). With no explicit path, the first existing entry of
    _DEFAULT_DATA_CANDIDATES is used.
    """
    if explicit is not None:
        if os.path.exists(explicit):
            return explicit
        print(f'[induction] WARNING: --data {explicit} does not exist', flush=True)
        return None
    for candidate in _DEFAULT_DATA_CANDIDATES:
        if os.path.exists(candidate):
            return candidate
    return None


def _synthetic_ids(n_batches, batch_size, seq_len, vocab_size):
    """Build int64 token ids of shape (n_batches, batch_size, seq_len) with planted repeats.

    Tokens are drawn uniformly from [0, vocab_size) with a fixed seed, so the
    result is deterministic. Each repeat in _SYNTHETIC_REPEATS whose
    destination span does not fit inside seq_len is skipped rather than
    raising a shape-mismatch error on the slice assignment.
    """
    rng = np.random.RandomState(0xface)
    ids = rng.randint(0, vocab_size, size=(n_batches, batch_size, seq_len), dtype=np.int64)
    for src, dst in _SYNTHETIC_REPEATS:
        if dst + _REPEAT_SPAN > seq_len:
            print(f'[induction] seq_len={seq_len} too short for planted repeat at '
                  f'{dst}; skipped', flush=True)
            continue
        # Same copy for every (batch, row); src and dst spans never overlap.
        ids[:, :, dst:dst + _REPEAT_SPAN] = ids[:, :, src:src + _REPEAT_SPAN]
    return ids


def _induction_output_path(ckpt):
    """Return the JSON report path for *ckpt*: its extension swapped for '_induction.json'.

    os.path.splitext only touches the final extension. A plain
    str.replace('.pt', ...) would mangle '.pth' names. It would also return
    the checkpoint path itself for names without '.pt', so writing the
    report would overwrite the checkpoint.
    """
    root, _ext = os.path.splitext(ckpt)
    return root + '_induction.json'


def _json_default(obj):
    """json.dump fallback: convert numpy arrays/scalars to native Python values."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


def _print_top_heads(ipmr_per_lh, top_k):
    """Print up to top_k (layer, head) pairs with the highest nonzero IPMR.

    ipmr_per_lh is indexed as [layer][head] (it is unpacked with
    divmod(flat_index, n_heads) below), i.e. assumed to form a 2-D grid.
    """
    arr = np.array(ipmr_per_lh)
    if arr.size == 0 or not arr.max() > 0:
        print('\nNo (layer, head) pair has any induction-pattern matches.')
        return
    n_heads = arr.shape[1]
    ranked = np.argsort(arr.flatten())[::-1][:top_k]  # highest rate first
    print(f'\nTop-{top_k} (layer, head) by IPMR:')
    for flat in ranked:
        layer, head = divmod(int(flat), n_heads)
        value = arr[layer, head]
        if value > 0:  # zero-rate heads are not worth listing
            print(f' L{layer:2d} H{head:2d}: {value*100:.3f}%')


def main():
    """CLI entry point: load a checkpoint, run forward passes, report IPMR.

    Steps: parse arguments, load the model via load_model, pick real
    validation data (or fall back to seeded synthetic ids with planted
    repeats), collect per-layer pointers for each batch, compute the global
    and per-(layer, head) IPMR, print a summary and write it to JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default=None)
    ap.add_argument('--n-batches', type=int, default=4)
    ap.add_argument('--batch-size', type=int, default=4)
    ap.add_argument('--seq-len', type=int, default=512)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--top-k', type=int, default=15,
                    help='Show top-K (layer, head) pairs by IPMR')
    ap.add_argument('--kshift-layers', default=None,
                    help='Override saved k_shift_layers (comma-separated indices, e.g. "0"). '
                         'Needed for ckpts saved before k_shift_layers was persisted in args.')
    ap.add_argument('--qshift-layers', default=None,
                    help='Override saved q_shift_layers (comma-separated)')
    args = ap.parse_args()

    # Fail fast, before the (potentially slow) checkpoint load.
    if min(args.n_batches, args.batch_size, args.seq_len) < 1:
        ap.error('--n-batches, --batch-size and --seq-len must be >= 1')
    if args.top_k < 0:
        ap.error('--top-k must be >= 0')

    print(f'[induction] loading {args.ckpt}', flush=True)
    m, kw, val_bpc, step = load_model(args.ckpt, device=args.device,
                                      kshift_layers_override=args.kshift_layers,
                                      qshift_layers_override=args.qshift_layers)
    print(f'[induction] val_bpc={val_bpc} step={step}', flush=True)

    data_path = _resolve_data_path(args.data)
    if data_path is not None:
        # Presumably returns ids shaped (n_batches, batch_size, seq_len), like
        # the synthetic branch — TODO confirm against analyze_pointer_health.
        ids_np = sample_random_data(data_path, args.n_batches, args.batch_size, args.seq_len)
        print(f'[induction] sampled from {data_path}', flush=True)
    else:
        print('[induction] no validation data; using synthetic with planned repeats')
        ids_np = _synthetic_ids(args.n_batches, args.batch_size, args.seq_len,
                                kw['vocab_size'])

    # Forward pass per batch; collect_pointers returns one entry per layer.
    n_layers = kw['n_layers']
    layer_pointers_all = [[] for _ in range(n_layers)]
    for c in range(args.n_batches):
        x = torch.from_numpy(ids_np[c]).to(args.device)
        per_layer = collect_pointers(m, x)
        for li in range(n_layers):
            layer_pointers_all[li].append(per_layer[li])
        print(f'[induction] forward {c+1}/{args.n_batches}', flush=True)
    # Stack batches along a new leading axis, per layer.
    layer_pointers = [np.stack(chunks, axis=0) for chunks in layer_pointers_all]

    ipmr_global, opps = compute_M3_IPMR(layer_pointers, ids_np)
    ipmr_per_lh, _ = compute_M3_IPMR_per_head(layer_pointers, ids_np)

    print('\n=== INDUCTION-PATTERN DETECTION ===')
    print(f'val_bpc={val_bpc} step={step}')
    print(f'Induction opportunities: {opps}')
    print(f'Global IPMR (any-head match rate): {ipmr_global*100:.3f}%')
    _print_top_heads(ipmr_per_lh, args.top_k)

    # Save JSON. Serialize fully first so a serialization error cannot leave
    # a truncated report on disk.
    out = _induction_output_path(args.ckpt)
    payload = json.dumps({
        'ckpt': args.ckpt,
        'val_bpc': val_bpc,
        'step': step,
        'global_ipmr': ipmr_global,
        'opportunities': opps,
        'per_layer_head_ipmr': ipmr_per_lh,
    }, indent=2, default=_json_default)
    with open(out, 'w') as f:
        f.write(payload)
    print(f'\nWrote {out}')


if __name__ == '__main__':
    main()