| """Diagnostic tests on trained binary and FP32 models. |
| |
| Outputs structured JSON that analyze_report.py compiles into a readable report. |
| Each test tries to reveal *mechanism*, not just measure BPC. |
| """ |
| import argparse, json, math, os, time |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from model_v18 import BitLMv18 |
| from model_fp32 import FP32LM |
| from model_v16 import set_gumbel_tau |
|
|
|
|
def load_binary_ckpt(path, device='cuda'):
    """Rebuild a ``BitLMv18`` model from a saved checkpoint.

    The checkpoint is read with ``torch.load`` and must provide the model
    hyper-parameters under ``'args'`` and the weights under ``'model'``.

    Args:
        path: Checkpoint file to load.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        A ``(model, checkpoint)`` tuple: the model with weights loaded and
        switched to eval mode, plus the raw checkpoint object.
    """
    # NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
    # objects -- only use this on checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']

    # These constructor arguments share their name with the checkpoint key;
    # only the sequence length is stored under a different name.
    shared_keys = ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'd_ff')
    model = BitLMv18(
        **{key: hparams[key] for key in shared_keys},
        max_seq_len=hparams['seq_len'],
    )
    model = model.to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint
|
|
|
|
def load_fp32_ckpt(path, device='cuda'):
    """Rebuild an ``FP32LM`` model from a saved checkpoint.

    Reads hyper-parameters from the checkpoint's ``'args'`` entry and the
    weights from its ``'model'`` entry.

    Args:
        path: Checkpoint file to load.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        A ``(model, checkpoint)`` tuple: the model with weights loaded and
        switched to eval mode, plus the raw checkpoint object.
    """
    # NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
    # objects -- only use this on checkpoints from a trusted source.
    state = torch.load(path, map_location=device, weights_only=False)
    settings = state['args']

    vocab_size = settings['vocab_size']
    d_model = settings['d_model']
    n_layers = settings['n_layers']
    n_heads = settings['n_heads']
    d_ff = settings['d_ff']
    max_seq_len = settings['seq_len']

    net = FP32LM(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        d_ff=d_ff,
        max_seq_len=max_seq_len,
    )
    net = net.to(device)
    net.load_state_dict(state['model'])
    net.eval()
    return net, state
|
|
|
|
def sample_eval_batch(data, batch_size, seq_len, device='cuda'):
    """Draw a random batch of next-token prediction windows from ``data``.

    Args:
        data: 1-D array-like of token ids supporting ``len`` and slicing
            (``main`` passes a ``uint8`` ``np.memmap``).
        batch_size: Number of windows to draw.
        seq_len: Length of each input window.
        device: Device the returned tensors are moved to.

    Returns:
        ``(x, y)`` int64 tensors of shape ``(batch_size, seq_len)`` where
        ``y`` is ``x`` shifted one position to the right in ``data``.

    Raises:
        RuntimeError: From ``torch.randint`` when ``data`` is too short to
            hold a single window (``len(data) - seq_len - 1 <= 0``).
    """
    # Plain Python ints keep the array slicing below independent of torch
    # scalar-tensor indexing.
    starts = torch.randint(0, len(data) - seq_len - 1, (batch_size,)).tolist()

    # Read each window once with one extra token, then split it into inputs
    # and targets instead of slicing the (possibly memory-mapped) array twice.
    windows = np.stack([np.asarray(data[s:s + seq_len + 1]) for s in starts])
    windows = torch.from_numpy(windows.astype(np.int64))

    x = windows[:, :-1].contiguous().to(device)
    y = windows[:, 1:].contiguous().to(device)
    return x, y
|
|
|
|
| |
def layer_ablation_bpc(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Measure how much BPC degrades when each block is replaced by identity.

    The model is evaluated once unmodified and once per block with that
    block's ``forward`` temporarily swapped for a pass-through, so only the
    residual path remains for that layer.

    The same sampled batches are used for the baseline and for every ablated
    run, so ``delta_bpc`` reflects the ablation and not batch-to-batch
    sampling noise.

    Args:
        m: Model exposing ``blocks`` and callable as ``m(x, y)`` returning a
            ``(output, loss)`` pair; the loss is divided by ``ln 2`` to report
            bits.
        val_data: Token array passed to ``sample_eval_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device the batches are placed on.

    Returns:
        ``{'baseline_bpc': float, 'per_layer': [...]}`` where each per-layer
        entry has ``layer``, ``baseline_bpc``, ``ablated_bpc`` and
        ``delta_bpc``.
    """
    m.eval()
    ln2 = math.log(2)

    # Sample once and reuse: a paired comparison removes sampling variance
    # from the reported deltas.
    batches = [sample_eval_batch(val_data, bs, seq_len, device)
               for _ in range(n_batches)]

    def _mean_bpc():
        losses = []
        with torch.no_grad():
            for x, y in batches:
                _, loss = m(x, y)
                losses.append(loss.item())
        return float(np.mean(losses)) / ln2

    base_bpc = _mean_bpc()

    results = []
    for li, blk in enumerate(m.blocks):
        # Remember whether the instance already carried its own ``forward`` so
        # the block can be put back exactly as it was found.
        had_override = 'forward' in vars(blk)
        saved_override = vars(blk).get('forward')

        # Identity stand-in; extra arguments are accepted and ignored so the
        # patch also works if the model passes more than the hidden state.
        blk.forward = lambda x, *args, **kwargs: x
        try:
            abl_bpc = _mean_bpc()
        finally:
            # Always undo the patch, even if evaluation raised.
            if had_override:
                blk.forward = saved_override
            else:
                del blk.forward

        results.append({'layer': li, 'baseline_bpc': base_bpc, 'ablated_bpc': abl_bpc,
                        'delta_bpc': abl_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'per_layer': results}
|
|
|
|
| |
def _sorted_quantile(sorted_vals, q):
    """Return the ``q``-quantile of an ascending 1-D tensor (linear interpolation)."""
    last = sorted_vals.numel() - 1
    pos = q * last
    lo = int(math.floor(pos))
    hi = min(lo + 1, last)
    lo_val = sorted_vals[lo].item()
    hi_val = sorted_vals[hi].item()
    return lo_val + (hi_val - lo_val) * (pos - lo)


def weight_saturation(m):
    """Summarise the distribution of ``|w|`` for every weight tensor of ``m``.

    High ``|latent|`` = 'locked sign' (won't flip easily). Near zero =
    'flippable'.

    Every parameter with at least two dimensions and at least one element is
    included. Values are cast to float32 and sorted once; median, quantiles
    and max are read from that single sorted copy. This avoids
    ``torch.quantile``, which only accepts float/double inputs and rejects
    very large tensors.

    Args:
        m: Module whose ``named_parameters()`` are inspected.

    Returns:
        A list with one dict per tensor containing ``name``, ``shape``, ``n``,
        ``mean``, ``median`` (lower middle element for even ``n``, matching
        ``Tensor.median``), ``q10``/``q90``/``q99`` (linear interpolation),
        the fraction of entries below 0.01 / below 0.05 / above 0.5, and
        ``max``.
    """
    stats = []
    with torch.no_grad():
        for name, p in m.named_parameters():
            if p.dim() < 2 or p.numel() == 0:
                continue
            abs_vals = p.detach().float().abs().flatten()
            ordered = torch.sort(abs_vals).values
            n = ordered.numel()
            stats.append({
                'name': name, 'shape': list(p.shape), 'n': n,
                'mean': abs_vals.mean().item(),
                'median': ordered[(n - 1) // 2].item(),
                'q10': _sorted_quantile(ordered, 0.1),
                'q90': _sorted_quantile(ordered, 0.9),
                'q99': _sorted_quantile(ordered, 0.99),
                'frac_below_0.01': (abs_vals < 0.01).float().mean().item(),
                'frac_below_0.05': (abs_vals < 0.05).float().mean().item(),
                'frac_above_0.5': (abs_vals > 0.5).float().mean().item(),
                'max': ordered[-1].item(),
            })
    return stats
|
|
|
|
| |
def attention_entropy(m, val_data, n_batches=5, bs=8, seq_len=256, device='cuda'):
    """Per layer and head, measure how peaked the attention score softmax is.

    For each block the raw scores ``Q @ K^T`` are formed from that block's
    ``q_proj`` / ``k_proj``, a distance penalty ``alibi_slopes_int * |i - j|``
    is subtracted, future positions are masked, and the softmax over keys is
    taken. The entropy and the maximum probability of that distribution are
    averaged over batch and query positions. This is a sharpness proxy for
    the raw scores; it does not go through the block's own attention forward.

    With the causal mask, query ``t`` sees ``t + 1`` keys, so its entropy is
    at most ``log(t + 1)`` (uniform) and 0 for a one-hot distribution.

    Each batch is pushed through the stack once and statistics are collected
    at every layer on the way, so all layers are measured on the same batches.

    Args:
        m: Model exposing ``embed`` and ``blocks``; each block must expose
            ``attn`` with ``n_heads``, ``head_dim``, ``q_proj``, ``k_proj``
            and ``alibi_slopes_int``.
        val_data: Token array passed to ``sample_eval_batch``.
        n_batches: Number of batches to average over.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device used for the batches and helper tensors.

    Returns:
        One dict per layer with ``layer``, ``entropy_per_head``,
        ``max_prob_per_head``, ``mean_entropy`` and ``mean_max_prob``.
    """
    m.eval()
    n_layers = len(m.blocks)
    entropy_runs = [[] for _ in range(n_layers)]
    max_prob_runs = [[] for _ in range(n_layers)]

    with torch.no_grad():
        for _ in range(n_batches):
            x, _ = sample_eval_batch(val_data, bs, seq_len, device)
            xe = m.embed(x)
            B, T = xe.shape[0], xe.shape[1]

            # Position-only tensors are shared by every layer of this batch.
            pos = torch.arange(T, device=device).float()
            dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().view(1, 1, T, T)
            future = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

            for li, blk in enumerate(m.blocks):
                attn = blk.attn
                H, Dh = attn.n_heads, attn.head_dim
                Q = attn.q_proj(xe).view(B, T, H, Dh).transpose(1, 2)
                K = attn.k_proj(xe).view(B, T, H, Dh).transpose(1, 2)

                scores = torch.matmul(Q, K.transpose(-2, -1))
                scores = scores - attn.alibi_slopes_int.view(1, H, 1, 1).float() * dist
                scores = scores.masked_fill(future, -1e9)

                probs = F.softmax(scores, dim=-1)
                max_p = probs.max(dim=-1).values
                # Clamp inside the log only, so masked (zero) entries add 0.
                entropies = -(probs * probs.clamp(min=1e-9).log()).sum(dim=-1)

                # Both are (B, H, T); reduce over batch and query positions.
                entropy_runs[li].append(entropies.mean(dim=(0, 2)).cpu().numpy())
                max_prob_runs[li].append(max_p.mean(dim=(0, 2)).cpu().numpy())

                # Advance the hidden state to feed the next layer; the output
                # of the final block is never needed.
                if li + 1 < n_layers:
                    xe = blk(xe)

    results = []
    for li in range(n_layers):
        ph_ent = np.stack(entropy_runs[li]).mean(axis=0)
        ph_maxp = np.stack(max_prob_runs[li]).mean(axis=0)
        results.append({'layer': li,
                        'entropy_per_head': ph_ent.tolist(),
                        'max_prob_per_head': ph_maxp.tolist(),
                        'mean_entropy': float(ph_ent.mean()),
                        'mean_max_prob': float(ph_maxp.mean())})
    return results
|
|
|
|
| |
def student_teacher_similarity(student_m, teacher_m, val_data, n_batches=5, bs=16, seq_len=256, device='cuda'):
    """Per layer, fraction of student activations equal to ``sign(teacher)``.

    Both models process the same batch. After layer ``i`` the teacher's
    hidden state is reduced to its sign (zeros count as ``+1``) and compared
    element-wise, by exact equality, with the student's hidden state. Only
    the first ``min(len(student.blocks), len(teacher.blocks))`` layers are
    compared, and no further layers are run.

    NOTE(review): the exact-equality test is only meaningful if the student's
    hidden state is itself exactly +1/-1 -- confirm against the model.

    Args:
        student_m: Model exposing ``embed`` and ``blocks``.
        teacher_m: Model exposing ``embed``, ``pos`` and ``blocks``.
        val_data: Token array passed to ``sample_eval_batch``.
        n_batches: Number of batches to average over.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device used for the batches and position indices.

    Returns:
        A list of ``{'layer': i, 'sign_agreement': float}``, one per compared
        layer.

    Raises:
        ValueError: If a student and teacher hidden state differ in shape.
    """
    student_m.eval()
    teacher_m.eval()

    # zip stops at the shallower model.
    layer_pairs = list(zip(student_m.blocks, teacher_m.blocks))
    sims = [[] for _ in layer_pairs]

    with torch.no_grad():
        for _ in range(n_batches):
            x, _ = sample_eval_batch(val_data, bs, seq_len, device)

            s = student_m.embed(x)
            positions = torch.arange(x.shape[1], device=device)
            t = teacher_m.embed(x) + teacher_m.pos(positions)

            # Advance both stacks in lockstep and compare immediately, so no
            # per-layer copies of the activations have to be kept.
            for i, (s_blk, t_blk) in enumerate(layer_pairs):
                s = s_blk(s)
                t = t_blk(t)
                if s.shape != t.shape:
                    raise ValueError(
                        f"layer {i}: student hidden shape {tuple(s.shape)} "
                        f"does not match teacher {tuple(t.shape)}")
                target = torch.sign(t)  # new tensor; ``t`` itself is untouched
                target[target == 0] = 1
                sims[i].append((s == target).float().mean().item())

    return [{'layer': i, 'sign_agreement': float(np.mean(layer_sims))}
            for i, layer_sims in enumerate(sims)]
|
|
|
|
| |
def error_breakdown(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Report next-token argmax accuracy overall and per character class.

    Targets are treated as byte codes; only targets below 128 are counted.
    Counts are accumulated per target code with a vectorised ``np.bincount``
    per batch, not a per-token Python loop, and then summed over the
    character classes defined below.

    Args:
        m: Model callable as ``m(x, y)`` returning ``(logits, loss)``.
        val_data: Token array passed to ``sample_eval_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device the batches are placed on.

    Returns:
        ``{'overall_accuracy': float, 'per_class': {name: {'accuracy': float,
        'n': int}}}``. A class with no occurrences gets accuracy 0.
    """
    m.eval()
    n_codes = 128
    per_char_correct = np.zeros(n_codes)
    per_char_total = np.zeros(n_codes)
    class_groups = {
        'space': {32},
        'newline': {10},
        'lowercase': set(range(97, 123)),
        'uppercase': set(range(65, 91)),
        'digit': set(range(48, 58)),
        'punct': {46, 44, 33, 63, 39, 34, 58, 59, 40, 41, 45},
    }
    with torch.no_grad():
        for _ in range(n_batches):
            x, y = sample_eval_batch(val_data, bs, seq_len, device)
            logits, _ = m(x, y)

            # One device-to-host transfer per batch instead of two per token.
            target = y.reshape(-1).cpu().numpy()
            guess = logits.argmax(dim=-1).reshape(-1).cpu().numpy()

            in_range = target < n_codes
            target = target[in_range]
            hit = guess[in_range] == target

            per_char_total += np.bincount(target, minlength=n_codes)
            per_char_correct += np.bincount(target[hit], minlength=n_codes)

    per_class_acc = {}
    for name, chars in class_groups.items():
        codes = sorted(chars)
        tot = per_char_total[codes].sum()
        cor = per_char_correct[codes].sum()
        per_class_acc[name] = {'accuracy': float(cor / max(tot, 1)), 'n': int(tot)}

    overall_tot = per_char_total.sum()
    overall_cor = per_char_correct.sum()
    return {'overall_accuracy': float(overall_cor / max(overall_tot, 1)),
            'per_class': per_class_acc}
|
|
|
|
| |
def main():
    """Command-line entry point: run the diagnostics and write a JSON report.

    The student-only tests (layer ablation, weight saturation, attention
    entropy, error breakdown) always run. The student-teacher comparison runs
    only when ``--teacher-ckpt`` is given. All results are collected into one
    dict and written to ``--out``.

    ``--device`` selects where models and batches live; it defaults to
    ``cuda``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--student-ckpt', required=True)
    ap.add_argument('--teacher-ckpt', default=None)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--out', required=True)
    ap.add_argument('--tau-eval', type=float, default=0.1,
                    help='Gumbel tau used for eval-mode forwards.')
    ap.add_argument('--device', default='cuda',
                    help='Torch device for models and evaluation batches.')
    args = ap.parse_args()
    device = args.device

    set_gumbel_tau(args.tau_eval)
    # Read-only memory map: the validation file is sampled, never fully loaded.
    val = np.memmap(args.data, dtype=np.uint8, mode='r')

    print(f"Loading student {args.student_ckpt}")
    student, s_ck = load_binary_ckpt(args.student_ckpt, device)
    s_cfg = s_ck['args']

    out = {
        'student_ckpt': args.student_ckpt,
        'student_config': s_cfg,
        'student_step': s_ck.get('step'),
        'student_val_bpc': s_ck.get('val_bpc'),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }

    print("A. Layer ablation BPC...")
    out['layer_ablation'] = layer_ablation_bpc(student, val, device=device)
    print(f" baseline {out['layer_ablation']['baseline_bpc']:.4f}, {len(out['layer_ablation']['per_layer'])} layers")

    print("B. Weight saturation...")
    out['weight_saturation'] = weight_saturation(student)
    print(f" {len(out['weight_saturation'])} weight tensors analyzed")

    print("C. Attention entropy...")
    out['attention_entropy'] = attention_entropy(student, val, device=device)
    print(f" {len(out['attention_entropy'])} layers analyzed")

    print("E. Error breakdown...")
    out['error_breakdown'] = error_breakdown(student, val, device=device)
    print(f" overall acc {out['error_breakdown']['overall_accuracy']:.4f}")

    # Test D needs a teacher, so it is optional and runs last.
    if args.teacher_ckpt:
        print(f"Loading teacher {args.teacher_ckpt}")
        teacher, t_ck = load_fp32_ckpt(args.teacher_ckpt, device)
        out['teacher_ckpt'] = args.teacher_ckpt
        out['teacher_val_bpc'] = t_ck.get('val_bpc')
        print("D. Student-teacher similarity...")
        out['student_teacher_similarity'] = student_teacher_similarity(
            student, teacher, val, device=device)
        print(" done")

    with open(args.out, 'w') as f:
        # default=str stringifies anything json cannot encode natively (e.g.
        # non-JSON values inside the checkpoint config) instead of failing.
        json.dump(out, f, indent=2, default=str)
    print(f"Wrote {args.out}")
|
|
|
|
# Run the diagnostics only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|