# bitnet-1bitllm / analyze.py
# Author: hidude562
# Commit 4754707 (verified): 1bitllm code (checkpoints to follow)
"""Diagnostic tests on trained binary and FP32 models.
Outputs structured JSON that analyze_report.py compiles into a readable report.
Each test tries to reveal *mechanism*, not just measure BPC.
"""
import argparse, json, math, os, time
import numpy as np
import torch
import torch.nn.functional as F
from model_v18 import BitLMv18
from model_fp32 import FP32LM
from model_v16 import set_gumbel_tau
def load_binary_ckpt(path, device='cuda'):
    """Rebuild the binary BitLMv18 student from a training checkpoint.

    The checkpoint must be a dict holding an 'args' hyper-parameter dict
    (vocab_size, d_model, n_layers, n_heads, d_ff, seq_len) and a 'model'
    state dict.

    Note: ``weights_only=False`` lets torch.load unpickle arbitrary objects,
    so only pass checkpoints from a trusted source.

    Args:
        path: filesystem path of the checkpoint.
        device: torch device the checkpoint tensors and the model are mapped to.

    Returns:
        (model, checkpoint): the model in eval mode, plus the raw checkpoint
        dict so callers can read extra metadata from it.
    """
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']
    model_kwargs = dict(
        vocab_size=hparams['vocab_size'],
        d_model=hparams['d_model'],
        n_layers=hparams['n_layers'],
        n_heads=hparams['n_heads'],
        d_ff=hparams['d_ff'],
        max_seq_len=hparams['seq_len'],  # training seq_len doubles as the model's max length
    )
    model = BitLMv18(**model_kwargs).to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint
def load_fp32_ckpt(path, device='cuda'):
    """Rebuild the full-precision FP32LM teacher from a training checkpoint.

    Expects the same checkpoint layout as the binary loader: an 'args' dict
    with vocab_size, d_model, n_layers, n_heads, d_ff and seq_len, plus a
    'model' state dict. ``weights_only=False`` unpickles arbitrary objects,
    so the checkpoint must come from a trusted source.

    Returns:
        (model, checkpoint): FP32LM in eval mode on ``device`` and the raw
        checkpoint dict.
    """
    ckpt = torch.load(path, map_location=device, weights_only=False)
    hp = ckpt['args']
    net = FP32LM(
        vocab_size=hp['vocab_size'],
        d_model=hp['d_model'],
        n_layers=hp['n_layers'],
        n_heads=hp['n_heads'],
        d_ff=hp['d_ff'],
        max_seq_len=hp['seq_len'],
    )
    net = net.to(device)
    net.load_state_dict(ckpt['model'])
    net.eval()
    return net, ckpt
def sample_eval_batch(data, batch_size, seq_len, device='cuda'):
    """Draw ``batch_size`` random contiguous windows from ``data`` as (inputs, targets).

    Args:
        data: 1-D NumPy-indexable token array (main() passes a uint8 np.memmap).
        batch_size: number of windows to sample.
        seq_len: length of each input / target sequence.
        device: torch device the returned tensors are moved to.

    Returns:
        (x, y): int64 tensors of shape (batch_size, seq_len); ``y`` is ``x``
        shifted left by one token (next-token targets).

    Raises:
        ValueError: if ``data`` holds fewer than ``seq_len + 1`` tokens.
    """
    # One (x, y) pair needs a window of seq_len + 1 tokens, so valid start
    # indices are 0 .. len(data) - seq_len - 1; randint's high bound is exclusive.
    n_starts = len(data) - seq_len
    if n_starts < 1:
        raise ValueError(
            f"need at least seq_len + 1 = {seq_len + 1} tokens, got {len(data)}")
    ix = torch.randint(0, n_starts, (batch_size,))
    # Gather all windows with one fancy-indexing call instead of a Python loop.
    offsets = ix.numpy()[:, None] + np.arange(seq_len + 1)
    windows = torch.from_numpy(np.asarray(data[offsets], dtype=np.int64))
    x = windows[:, :-1].contiguous().to(device)
    y = windows[:, 1:].contiguous().to(device)
    return x, y
# ---------------- Test A: Layer ablation ----------------
def layer_ablation_bpc(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Skip each block in turn (identity / residual-only) and measure the BPC change.

    The same ``n_batches`` validation batches are used for the baseline and
    for every ablation, so ``delta_bpc`` reflects the removed block rather
    than sampling noise between two independent random draws.

    Args:
        m: model exposing ``m.blocks`` and returning ``(logits, loss)`` from ``m(x, y)``.
        val_data: token array passed to ``sample_eval_batch``.
        n_batches, bs, seq_len: size of the evaluation sample.
        device: torch device for the batches.

    Returns:
        {'baseline_bpc': float, 'per_layer': [{'layer', 'baseline_bpc',
        'ablated_bpc', 'delta_bpc'}, ...]}
    """
    m.eval()
    batches = [sample_eval_batch(val_data, bs, seq_len, device) for _ in range(n_batches)]

    def _mean_bpc():
        # Mean loss over the fixed batches; the loss is presumably mean
        # cross-entropy in nats (hence the division by ln 2 to get bits).
        with torch.no_grad():
            losses = [m(x, y)[1].item() for x, y in batches]
        return float(np.mean(losses)) / math.log(2)

    base_bpc = _mean_bpc()
    results = []
    for li, blk in enumerate(m.blocks):
        # nn.Module.__call__ resolves self.forward on the instance first, so an
        # instance attribute shadows the class method and turns the block into
        # an identity mapping.
        had_instance_forward = 'forward' in vars(blk)
        saved_forward = blk.forward
        blk.forward = lambda h, *args, **kwargs: h
        try:
            abl_bpc = _mean_bpc()
        finally:
            # Always restore, even if the ablated pass raises, so the model is
            # never left permanently mutated.
            if had_instance_forward:
                blk.forward = saved_forward
            else:
                del blk.forward
        results.append({'layer': li, 'baseline_bpc': base_bpc, 'ablated_bpc': abl_bpc,
                        'delta_bpc': abl_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'per_layer': results}
# ---------------- Test B: Weight saturation / flip-flop potential ----------------
def _quantile_from_sorted(sorted_vals, q):
    """Linear-interpolation quantile of an ascending 1-D tensor (torch.quantile's default rule)."""
    n = sorted_vals.numel()
    pos = q * (n - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, n - 1)
    lo_val = sorted_vals[lo].item()
    hi_val = sorted_vals[hi].item()
    return lo_val + (hi_val - lo_val) * (pos - lo)


def weight_saturation(m):
    """Summarize the |latent weight| distribution of every parameter with >= 2 dims.

    High |latent| means a 'locked' sign that is unlikely to flip; values near
    zero are 'flippable'. Each tensor is sorted once and all order statistics
    are read from that sort. This avoids torch.quantile, which rejects inputs
    larger than 2**24 elements and non-float dtypes.

    Returns:
        List of dicts, one per parameter, with keys name, shape, n, mean,
        median (lower median, as torch.median), q10, q90, q99,
        frac_below_0.01, frac_below_0.05, frac_above_0.5, max.
        Empty tensors are skipped.
    """
    stats = []
    with torch.no_grad():
        for name, p in m.named_parameters():
            if p.dim() < 2:
                continue
            # float() so half/bf16 latents are handled like fp32 ones.
            abs_vals = p.detach().abs().flatten().float()
            n = abs_vals.numel()
            if n == 0:
                continue
            sorted_vals = abs_vals.sort().values
            stats.append({
                'name': name, 'shape': list(p.shape), 'n': n,
                'mean': abs_vals.mean().item(),
                # torch.median returns the lower of the two middle values for even n.
                'median': sorted_vals[(n - 1) // 2].item(),
                'q10': _quantile_from_sorted(sorted_vals, 0.1),
                'q90': _quantile_from_sorted(sorted_vals, 0.9),
                'q99': _quantile_from_sorted(sorted_vals, 0.99),
                'frac_below_0.01': (abs_vals < 0.01).float().mean().item(),
                'frac_below_0.05': (abs_vals < 0.05).float().mean().item(),
                'frac_above_0.5': (abs_vals > 0.5).float().mean().item(),
                'max': sorted_vals[-1].item(),
            })
    return stats
# ---------------- Test C: Attention entropy per head/layer ----------------
def _attention_head_stats(attn, h, dist, mask):
    """Return per-head (mean entropy, mean max prob) of softmax(QK^T - ALiBi) for input h."""
    # NOTE(review): mirrors the score computation of the attention module as the
    # original analysis did; assumes the block feeds its raw input h straight
    # into q_proj/k_proj (no pre-norm) — TODO confirm against model_v18.
    B, T, _ = h.shape
    H, Dh = attn.n_heads, attn.head_dim
    q = attn.q_proj(h).view(B, T, H, Dh).transpose(1, 2)   # (B, H, T, Dh)
    k = attn.k_proj(h).view(B, T, H, Dh).transpose(1, 2)   # (B, H, T, Dh)
    scores = torch.matmul(q, k.transpose(-2, -1))          # (B, H, T, T)
    alibi = attn.alibi_slopes_int.view(1, H, 1, 1).float() * dist.view(1, 1, T, T)
    scores = (scores - alibi).masked_fill(mask, -1e9)
    probs = F.softmax(scores, dim=-1)
    entropy = -(probs * probs.clamp(min=1e-9).log()).sum(dim=-1)  # (B, H, T)
    max_p = probs.max(dim=-1).values                                # (B, H, T)
    # Average over batch and query positions, keeping the head axis.
    return entropy.mean(dim=(0, 2)).cpu().numpy(), max_p.mean(dim=(0, 2)).cpu().numpy()


def attention_entropy(m, val_data, n_batches=5, bs=8, seq_len=256, device='cuda'):
    """Per layer and head, measure how peaked the causal attention distribution is.

    For each head this reports the entropy (nats) and the max probability of
    softmax(raw integer scores - ALiBi penalty) under a causal mask, averaged
    over queries and batches. A causal query at position t sees t + 1 keys, so
    a uniform distribution has entropy log(t + 1) and a one-hot one has 0. The
    model's hard Gumbel attention does not use this softmax directly; it is a
    sharpness proxy for the scores.

    Each batch is pushed through the stack once, and every block's input is
    reused for that block's statistics. All layers therefore see the same
    tokens, and the cost is O(n_layers) block forwards per batch.

    Returns:
        One dict per layer with keys layer, entropy_per_head,
        max_prob_per_head, mean_entropy, mean_max_prob.
    """
    n_layers = len(m.blocks)
    T = seq_len  # sample_eval_batch always returns (bs, seq_len) inputs
    # Causal mask and |i - j| distances depend only on T: build them once.
    pos = torch.arange(T, device=device).float()
    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
    mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

    ent_per_layer = [[] for _ in range(n_layers)]
    maxp_per_layer = [[] for _ in range(n_layers)]
    with torch.no_grad():
        for _ in range(n_batches):
            x, _ = sample_eval_batch(val_data, bs, seq_len, device)
            h = m.embed(x)
            for li, blk in enumerate(m.blocks):
                ent, maxp = _attention_head_stats(blk.attn, h, dist, mask)
                ent_per_layer[li].append(ent)
                maxp_per_layer[li].append(maxp)
                h = blk(h)

    results = []
    for li in range(n_layers):
        ph_ent = np.stack(ent_per_layer[li]).mean(axis=0)
        ph_maxp = np.stack(maxp_per_layer[li]).mean(axis=0)
        results.append({'layer': li,
                        'entropy_per_head': ph_ent.tolist(),
                        'max_prob_per_head': ph_maxp.tolist(),
                        'mean_entropy': float(ph_ent.mean()),
                        'mean_max_prob': float(ph_maxp.mean())})
    return results
# ---------------- Test D: Student-teacher representation similarity ----------------
def student_teacher_similarity(student_m, teacher_m, val_data, n_batches=5, bs=16, seq_len=256, device='cuda'):
    """Per layer, the fraction of hidden units whose sign agrees between student and teacher.

    Both models run block by block in lockstep on the same random batches.
    Each block's output is binarized by sign (value >= 0 -> +, else -). A
    teacher value of exactly 0 therefore counts as +1, and the student's
    +/-1 states are unchanged. The elementwise agreement rate is then averaged
    over batches. This is an agreement fraction, not a cosine: for +/-1
    vectors, cosine = 2 * agreement - 1.

    Only the first min(len(student blocks), len(teacher blocks)) layers are
    compared, and layer i of one model is assumed to correspond to layer i of
    the other.

    Returns:
        [{'layer': i, 'sign_agreement': float}, ...]
    """
    student_m.eval()
    teacher_m.eval()
    n_common = min(len(student_m.blocks), len(teacher_m.blocks))
    sims = [[] for _ in range(n_common)]
    with torch.no_grad():
        for _ in range(n_batches):
            x, _ = sample_eval_batch(val_data, bs, seq_len, device)
            s = student_m.embed(x)
            # Teacher input mirrors FP32LM: token embedding + position embedding.
            positions = torch.arange(x.shape[1], device=device)
            t = teacher_m.embed(x) + teacher_m.pos(positions)
            for i, (s_blk, t_blk) in enumerate(zip(student_m.blocks, teacher_m.blocks)):
                s = s_blk(s)
                t = t_blk(t)
                agree = ((s >= 0) == (t >= 0)).float().mean().item()
                sims[i].append(agree)
    return [{'layer': i, 'sign_agreement': float(np.mean(layer_sims))}
            for i, layer_sims in enumerate(sims)]
# ---------------- Test E: Prediction error breakdown ----------------
def error_breakdown(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Greedy (argmax) next-byte accuracy, overall and per character class.

    Only targets with byte value < 128 (ASCII) are counted, both in the
    overall accuracy and in the per-class tallies. Counting is vectorized with
    torch.bincount: one host transfer per batch instead of per-token .item() calls.

    Returns:
        {'overall_accuracy': float,
         'per_class': {class_name: {'accuracy': float, 'n': int}, ...}}
    """
    m.eval()
    n_bins = 128
    per_char_correct = np.zeros(n_bins, dtype=np.int64)
    per_char_total = np.zeros(n_bins, dtype=np.int64)
    class_groups = {
        'space': {32},
        'newline': {10},
        'lowercase': set(range(97, 123)),
        'uppercase': set(range(65, 91)),
        'digit': set(range(48, 58)),
        'punct': {46, 44, 33, 63, 39, 34, 58, 59, 40, 41, 45},
    }
    with torch.no_grad():
        for _ in range(n_batches):
            x, y = sample_eval_batch(val_data, bs, seq_len, device)
            logits, _ = m(x, y)
            pred = logits.argmax(dim=-1).flatten()
            targets = y.flatten()
            hits = pred == targets
            keep = targets < n_bins
            kept_targets = targets[keep]
            per_char_total += torch.bincount(kept_targets, minlength=n_bins).cpu().numpy()
            per_char_correct += torch.bincount(kept_targets[hits[keep]],
                                               minlength=n_bins).cpu().numpy()
    per_class_acc = {}
    for name, chars in class_groups.items():
        idx = sorted(chars)
        tot = int(per_char_total[idx].sum())
        cor = int(per_char_correct[idx].sum())
        per_class_acc[name] = {'accuracy': cor / max(tot, 1), 'n': tot}
    overall_tot = int(per_char_total.sum())
    overall_cor = int(per_char_correct.sum())
    return {'overall_accuracy': float(overall_cor / max(overall_tot, 1)),
            'per_class': per_class_acc}
# ---------------- Main ----------------
def main():
    """CLI entry point: run the diagnostics on a student checkpoint and write JSON.

    Runs A (layer ablation), B (weight saturation), C (attention entropy) and
    E (error breakdown) on the binary student. When --teacher-ckpt is given it
    also loads the FP32 teacher and runs D (student-teacher sign agreement).
    The combined results are written to --out as indented JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--student-ckpt', required=True)
    ap.add_argument('--teacher-ckpt', default=None)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--out', required=True)
    ap.add_argument('--tau-eval', type=float, default=0.1,
                    help='Gumbel tau used for eval-mode forwards.')
    ap.add_argument('--device', default='cuda',
                    help='Torch device for models and eval batches (default: cuda).')
    args = ap.parse_args()
    device = args.device
    set_gumbel_tau(args.tau_eval)
    # Validation tokens are read as raw uint8 bytes through a read-only memmap.
    val = np.memmap(args.data, dtype=np.uint8, mode='r')
    print(f"Loading student {args.student_ckpt}")
    student, s_ck = load_binary_ckpt(args.student_ckpt, device=device)
    s_cfg = s_ck['args']
    out = {
        'student_ckpt': args.student_ckpt,
        'student_config': s_cfg,
        'student_step': s_ck.get('step'),
        'student_val_bpc': s_ck.get('val_bpc'),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }
    print("A. Layer ablation BPC...")
    out['layer_ablation'] = layer_ablation_bpc(student, val, device=device)
    print(f" baseline {out['layer_ablation']['baseline_bpc']:.4f}, {len(out['layer_ablation']['per_layer'])} layers")
    print("B. Weight saturation...")
    out['weight_saturation'] = weight_saturation(student)
    print(f" {len(out['weight_saturation'])} weight tensors analyzed")
    print("C. Attention entropy...")
    out['attention_entropy'] = attention_entropy(student, val, device=device)
    print(f" {len(out['attention_entropy'])} layers analyzed")
    print("E. Error breakdown...")
    out['error_breakdown'] = error_breakdown(student, val, device=device)
    print(f" overall acc {out['error_breakdown']['overall_accuracy']:.4f}")
    if args.teacher_ckpt:
        print(f"Loading teacher {args.teacher_ckpt}")
        teacher, t_ck = load_fp32_ckpt(args.teacher_ckpt, device=device)
        out['teacher_ckpt'] = args.teacher_ckpt
        out['teacher_val_bpc'] = t_ck.get('val_bpc')
        print("D. Student-teacher similarity...")
        out['student_teacher_similarity'] = student_teacher_similarity(
            student, teacher, val, device=device)
        print(" done")
    # Create the output directory if needed so a long run isn't lost at the last step.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # default=str stringifies any non-JSON values (e.g. odd types in the saved config).
    with open(args.out, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2, default=str)
    print(f"Wrote {args.out}")
if __name__ == "__main__":
    # Run the diagnostics CLI only when executed as a script, not on import.
    main()