File size: 4,861 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Standalone induction-pattern detector for v75 pointer attention.

Lightweight wrapper around analyze_pointer_health's IPMR computation —
runs faster (no head-stats overhead), prints per-(layer, head) induction
match rates, and identifies the top-K heads that look most induction-like.

Usage:
  python analyze_induction.py --ckpt path/to/ckpt.pt --data validation.bin
"""
import argparse
import json
import os
import sys
from collections import defaultdict

import numpy as np

# Reuse the heavy lifting from the health probe
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from analyze_pointer_health import (
    load_model, collect_pointers, sample_random_data,
    compute_M3_IPMR, compute_M3_IPMR_per_head,
)
import torch


def _induction_output_path(ckpt_path):
    """Return the JSON report path derived from ``ckpt_path``.

    Only the final file extension is replaced, so a path without a ``.pt``
    suffix can never map back onto the checkpoint itself, and ``.pt``
    substrings in directory names are left alone.
    """
    root, _ext = os.path.splitext(ckpt_path)
    return root + '_induction.json'


def _make_synthetic_ids(vocab_size, n_batches, batch_size, seq_len):
    """Build random token ids with a few deliberately repeated 5-token spans.

    Returns an int64 array of shape ``(n_batches, batch_size, seq_len)``.
    A repeat is injected only when its destination window fits inside
    ``seq_len``; shorter sequences are returned as plain random tokens.
    """
    span = 5
    rng = np.random.RandomState(0xface)
    ids = rng.randint(0, vocab_size,
                      size=(n_batches, batch_size, seq_len),
                      dtype=np.int64)
    # (source start, destination start): source and destination windows never
    # overlap, so copying across all batches at once matches a per-row copy.
    for src, dst in ((50, 200), (100, 350)):
        if dst + span <= seq_len:
            ids[:, :, dst:dst + span] = ids[:, :, src:src + span]
    return ids


def _json_default(obj):
    """Convert NumPy scalars/arrays to plain Python values for ``json.dump``."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


def main():
    """Command-line entry point: measure induction-pattern match rates.

    Loads the checkpoint given by ``--ckpt``, obtains token ids (sampled from
    ``--data`` or a default validation file if one exists, otherwise synthetic
    ids with injected repeats), collects per-layer pointers for every batch,
    prints the global and top-K per-(layer, head) IPMR values, and writes the
    results to ``<ckpt without extension>_induction.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default=None)
    ap.add_argument('--n-batches', type=int, default=4)
    ap.add_argument('--batch-size', type=int, default=4)
    ap.add_argument('--seq-len', type=int, default=512)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--top-k', type=int, default=15,
                    help='Show top-K (layer, head) pairs by IPMR')
    ap.add_argument('--kshift-layers', default=None,
                    help='Override saved k_shift_layers (comma-separated indices, e.g. "0"). '
                         'Needed for ckpts saved before k_shift_layers was persisted in args.')
    ap.add_argument('--qshift-layers', default=None,
                    help='Override saved q_shift_layers (comma-separated)')
    args = ap.parse_args()

    print(f'[induction] loading {args.ckpt}', flush=True)
    m, kw, val_bpc, step = load_model(args.ckpt, device=args.device,
                                       kshift_layers_override=args.kshift_layers,
                                       qshift_layers_override=args.qshift_layers)
    print(f'[induction] val_bpc={val_bpc} step={step}', flush=True)

    if args.data is None:
        # No explicit path: fall back to the first default location that exists.
        candidates = ['/root/bitnet1/data_fineweb_edu/validation.bin',
                      './data_fineweb_edu/validation.bin']
        for c in candidates:
            if os.path.exists(c):
                args.data = c
                break
    elif not os.path.exists(args.data):
        # Make the fallback visible instead of silently ignoring a bad path.
        print(f'[induction] WARNING: --data {args.data} not found', flush=True)

    if args.data and os.path.exists(args.data):
        ids_np = sample_random_data(args.data, args.n_batches, args.batch_size, args.seq_len)
        print(f'[induction] sampled from {args.data}', flush=True)
    else:
        print('[induction] no validation data; using synthetic with planned repeats')
        # Sequences with deliberate repeated spans give induction a chance.
        ids_np = _make_synthetic_ids(kw['vocab_size'], args.n_batches,
                                     args.batch_size, args.seq_len)

    # Forward pass per batch; ids_np is indexed by batch number
    # (assumes sample_random_data returns the same batch-first layout as the
    # synthetic branch — TODO confirm).
    layer_pointers_all = [[] for _ in range(kw['n_layers'])]
    for c in range(args.n_batches):
        x = torch.from_numpy(ids_np[c]).to(args.device)
        per_layer = collect_pointers(m, x)
        for li in range(kw['n_layers']):
            layer_pointers_all[li].append(per_layer[li])
        print(f'[induction] forward {c+1}/{args.n_batches}', flush=True)

    # One stacked array per layer, with the batch index as the leading axis.
    layer_pointers = [np.stack(layer_pointers_all[li], axis=0) for li in range(kw['n_layers'])]

    ipmr_global, opps = compute_M3_IPMR(layer_pointers, ids_np)
    ipmr_per_lh, _ = compute_M3_IPMR_per_head(layer_pointers, ids_np)

    print('\n=== INDUCTION-PATTERN DETECTION ===')
    print(f'val_bpc={val_bpc} step={step}')
    print(f'Induction opportunities: {opps}')
    print(f'Global IPMR (any-head match rate): {ipmr_global*100:.3f}%')

    arr = np.array(ipmr_per_lh)
    if arr.size and arr.max() > 0:
        # Descending order over the flattened (layer, head) grid.
        flat_idx = np.argsort(arr.flatten())[::-1][:args.top_k]
        print(f'\nTop-{args.top_k} (layer, head) by IPMR:')
        for fi in flat_idx:
            l, h = divmod(int(fi), arr.shape[1])
            v = arr[l, h]
            if v > 0:
                print(f'  L{l:2d} H{h:2d}: {v*100:.3f}%')
    else:
        print('\nNo (layer, head) pair has any induction-pattern matches.')

    # Save JSON next to the checkpoint; _json_default keeps NumPy-typed
    # metrics from aborting the dump after all the work is done.
    out = _induction_output_path(args.ckpt)
    with open(out, 'w') as f:
        json.dump({
            'ckpt': args.ckpt, 'val_bpc': val_bpc, 'step': step,
            'global_ipmr': ipmr_global, 'opportunities': opps,
            'per_layer_head_ipmr': ipmr_per_lh,
        }, f, indent=2, default=_json_default)
    print(f'\nWrote {out}')


# Run the detector only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()