| """Standalone induction-pattern detector for v75 pointer attention. |
| |
| Lightweight wrapper around analyze_pointer_health's IPMR computation — |
| runs faster (no head-stats overhead), prints per-(layer, head) induction |
| match rates, and identifies the top-K heads that look most induction-like. |
| |
| Usage: |
| python analyze_induction.py --ckpt path/to/ckpt.pt --data validation.bin |
| """ |
| import argparse |
| import json |
| import os |
| import sys |
| from collections import defaultdict |
|
|
| import numpy as np |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from analyze_pointer_health import ( |
| load_model, collect_pointers, sample_random_data, |
| compute_M3_IPMR, compute_M3_IPMR_per_head, |
| ) |
| import torch |
|
|
|
|
# Fallback validation files probed (in order) when --data is not given.
_DEFAULT_DATA_CANDIDATES = (
    '/root/bitnet1/data_fineweb_edu/validation.bin',
    './data_fineweb_edu/validation.bin',
)

# Synthetic fallback: (dst_start, src_start) pairs; tokens
# [src_start, src_start + _PLANTED_SPAN) are copied to
# [dst_start, dst_start + _PLANTED_SPAN) in every sequence so that the random
# data still contains exact repeated token spans.
_PLANTED_REPEATS = ((200, 50), (350, 100))
_PLANTED_SPAN = 5


def _parse_args(argv=None):
    """Build the CLI parser and parse *argv* (None means sys.argv[1:])."""
    ap = argparse.ArgumentParser(
        description='Per-(layer, head) induction-pattern match rates (IPMR).')
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default=None)
    ap.add_argument('--n-batches', type=int, default=4)
    ap.add_argument('--batch-size', type=int, default=4)
    ap.add_argument('--seq-len', type=int, default=512)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--top-k', type=int, default=15,
                    help='Show top-K (layer, head) pairs by IPMR')
    ap.add_argument('--kshift-layers', default=None,
                    help='Override saved k_shift_layers (comma-separated indices, e.g. "0"). '
                         'Needed for ckpts saved before k_shift_layers was persisted in args.')
    ap.add_argument('--qshift-layers', default=None,
                    help='Override saved q_shift_layers (comma-separated)')
    ap.add_argument('--out', default=None,
                    help='Output JSON path (default: <ckpt without extension>_induction.json)')
    return ap.parse_args(argv)


def _resolve_data_path(data):
    """Return an existing data file path, or None to request synthetic data.

    When *data* is None the first existing entry of _DEFAULT_DATA_CANDIDATES
    is used. An explicitly given path that does not exist triggers a warning
    instead of a silent fallback.
    """
    if data is None:
        return next((c for c in _DEFAULT_DATA_CANDIDATES if os.path.exists(c)), None)
    if os.path.exists(data):
        return data
    print(f'[induction] WARNING: --data {data} not found', flush=True)
    return None


def _synthetic_ids(vocab_size, n_batches, batch_size, seq_len, seed=0xface):
    """Uniform random token ids of shape (n_batches, batch_size, seq_len).

    Planted repeats from _PLANTED_REPEATS are copied into every sequence.
    Repeats whose destination span would not fit inside *seq_len* are skipped,
    rather than raising a broadcast error for short sequences.
    """
    rng = np.random.RandomState(seed)
    ids = rng.randint(0, vocab_size,
                      size=(n_batches, batch_size, seq_len),
                      dtype=np.int64)
    for dst, src in _PLANTED_REPEATS:
        # src < dst, so the source span fits whenever the destination does,
        # and source spans are never overwritten by an earlier copy.
        if dst + _PLANTED_SPAN <= seq_len:
            ids[:, :, dst:dst + _PLANTED_SPAN] = ids[:, :, src:src + _PLANTED_SPAN]
    return ids


def _collect_layer_pointers(m, ids_np, n_layers, device):
    """Run collect_pointers on each chunk ids_np[c] and stack results per layer.

    Returns a list of length *n_layers*; entry li is np.stack over chunks
    (axis 0) of what collect_pointers returned for layer li.
    """
    per_layer_chunks = [[] for _ in range(n_layers)]
    n_chunks = len(ids_np)
    for c in range(n_chunks):
        x = torch.from_numpy(ids_np[c]).to(device)
        per_layer = collect_pointers(m, x)
        for li in range(n_layers):
            per_layer_chunks[li].append(per_layer[li])
        print(f'[induction] forward {c+1}/{n_chunks}', flush=True)
    return [np.stack(chunks, axis=0) for chunks in per_layer_chunks]


def _print_top_heads(ipmr_per_lh, top_k):
    """Print up to *top_k* (layer, head) pairs with the highest positive IPMR.

    Assumes ipmr_per_lh converts to a 2-D array indexed [layer, head]
    (the original code relies on arr.shape[1] as the head count).
    """
    arr = np.asarray(ipmr_per_lh)
    if not (arr.size and arr.max() > 0):
        print('\nNo (layer, head) pair has any induction-pattern matches.')
        return
    print(f'\nTop-{top_k} (layer, head) by IPMR:')
    n_heads = arr.shape[1]
    # argsort over the flattened array, descending.
    for fi in np.argsort(arr, axis=None)[::-1][:top_k]:
        l, h = divmod(int(fi), n_heads)
        v = arr[l, h]
        # Zero-rate pairs are skipped, so fewer than top_k rows may print.
        if v > 0:
            print(f' L{l:2d} H{h:2d}: {v*100:.3f}%')


def _default_output_path(ckpt):
    """Return '<ckpt minus its final extension>_induction.json'.

    Unlike str.replace('.pt', ...), this never returns *ckpt* itself (which
    would overwrite the checkpoint) and never touches directory components.
    """
    return os.path.splitext(ckpt)[0] + '_induction.json'


def _json_default(o):
    """json.dump fallback that converts numpy scalars/arrays to Python types."""
    if isinstance(o, np.generic):
        return o.item()
    if isinstance(o, np.ndarray):
        return o.tolist()
    raise TypeError(f'Object of type {type(o).__name__} is not JSON serializable')


def main():
    """CLI entry point: load a checkpoint, measure IPMR, print and save results.

    Steps: load the model, sample token ids (real data if available, else
    synthetic with planted repeats), collect per-layer pointers, compute the
    global and per-(layer, head) induction match rates, print a top-K table,
    and write a JSON summary (path from --out or derived from --ckpt).
    """
    args = _parse_args()

    print(f'[induction] loading {args.ckpt}', flush=True)
    m, kw, val_bpc, step = load_model(args.ckpt, device=args.device,
                                      kshift_layers_override=args.kshift_layers,
                                      qshift_layers_override=args.qshift_layers)
    print(f'[induction] val_bpc={val_bpc} step={step}', flush=True)

    data_path = _resolve_data_path(args.data)
    if data_path is not None:
        ids_np = sample_random_data(data_path, args.n_batches, args.batch_size, args.seq_len)
        print(f'[induction] sampled from {data_path}', flush=True)
    else:
        print('[induction] no validation data; using synthetic with planned repeats',
              flush=True)
        ids_np = _synthetic_ids(kw['vocab_size'], args.n_batches,
                                args.batch_size, args.seq_len)

    layer_pointers = _collect_layer_pointers(m, ids_np, kw['n_layers'], args.device)

    ipmr_global, opps = compute_M3_IPMR(layer_pointers, ids_np)
    ipmr_per_lh, _ = compute_M3_IPMR_per_head(layer_pointers, ids_np)

    print('\n=== INDUCTION-PATTERN DETECTION ===')
    print(f'val_bpc={val_bpc} step={step}')
    print(f'Induction opportunities: {opps}')
    print(f'Global IPMR (any-head match rate): {ipmr_global*100:.3f}%')

    _print_top_heads(ipmr_per_lh, args.top_k)

    out = args.out or _default_output_path(args.ckpt)
    with open(out, 'w') as f:
        json.dump({
            'ckpt': args.ckpt, 'val_bpc': val_bpc, 'step': step,
            'global_ipmr': ipmr_global, 'opportunities': opps,
            'per_layer_head_ipmr': ipmr_per_lh,
            # None means synthetic ids were used instead of a data file.
            'data': data_path,
        }, f, indent=2, default=_json_default)
    print(f'\nWrote {out}')
|
|
|
|
if __name__ == "__main__":
    # Only run the CLI when executed directly, never on import.
    main()
|
|