# Repository page header (not part of the program):
#   bitnet-1bitllm / _bench_300m_1gpu.py
#   hidude562's picture
#   1bitllm code (checkpoints to follow)
#   4754707 verified
"""Single-GPU 300M training-step benchmark.
No DeepSpeed, no accelerate — just torch + the model. Measures forward +
backward + optimizer step throughput on 1× RTX 5090 (32 GB).
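Run:
    cd /root/bitnet1/code
    /venv/main/bin/python _bench_300m_1gpu.py [--max-steps 50] [--per-gpu-bs 1]
        [--grad-accum 4] [--no-checkpoint] [--eager]

Token windows are read from /root/bitnet1/data_fineweb_edu/train.bin (uint16).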
"""
import os, time, math, argparse
# Both variables are set before `import torch` below. setdefault() leaves any
# value that is already exported in the environment untouched.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
os.environ.setdefault('TORCHINDUCTOR_CACHE_DIR', '/root/bitnet1/inductor_cache')
import numpy as np
import torch
import torch.nn.functional as F
import torch._inductor.config as _ic
# Inductor options are assigned at import time, i.e. before main() calls
# torch.compile on the model.
_ic.max_autotune_gemm_backends = "ATEN"
_ic.coordinate_descent_tuning = True
_ic.epilogue_fusion = True
# Project-local modules: the model class under test, and the module whose
# set_gumbel_tau() is called once in main() before the model is built.
from model_v47b import BitLMv47B
import model_v16 as v16
# Steps up to and including this one are treated as warmup (compile/autotune)
# and excluded from the steady-state throughput figure when the run is long
# enough to have steps after it.
_WARMUP_STEPS = 5


def _tokens_per_sec(tokens, elapsed):
    """Return tokens / elapsed, or 0.0 if no time has elapsed."""
    return tokens / elapsed if elapsed > 0 else 0.0


def _sample_batch(train_arr, rng, bs, seq_len, device):
    """Draw `bs` random (input, next-token target) windows and move them to `device`."""
    starts = rng.randint(0, len(train_arr) - seq_len - 1, size=bs)
    # One read of seq_len + 1 tokens per row serves both x and the shifted y.
    windows = np.empty((bs, seq_len + 1), dtype=np.int64)
    for row, s in enumerate(starts):
        windows[row] = train_arr[s:s + seq_len + 1]
    # Stage through pinned host memory so the host-to-device copy can be async.
    tokens = torch.from_numpy(windows).pin_memory().to(device, non_blocking=True)
    return tokens[:, :-1].contiguous(), tokens[:, 1:].contiguous()


def main():
    """Run the single-GPU training-step benchmark and print tokens/second.

    Command-line flags:
        --max-steps N       optimizer steps to run (default 50, must be >= 1)
        --per-gpu-bs N      sequences per micro-batch (default 1, must be >= 1)
        --grad-accum N      micro-batches per optimizer step (default 4, must be >= 1)
        --checkpoint / --no-checkpoint
                            value forwarded as `use_checkpoint` to the model
                            (default: enabled)
        --eager             skip torch.compile

    Timing: the timer starts after step 1 and is restarted after step
    `_WARMUP_STEPS` whenever further steps follow, so the final figure covers
    only post-warmup steps. Runs too short to get past warmup still report a
    (flagged) figure measured from step 1 instead of a negative or undefined one.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--max-steps', type=int, default=50)
    ap.add_argument('--per-gpu-bs', type=int, default=1)
    ap.add_argument('--grad-accum', type=int, default=4)
    ap.add_argument('--checkpoint', action='store_true', default=True)
    ap.add_argument('--no-checkpoint', dest='checkpoint', action='store_false')
    ap.add_argument('--eager', action='store_true')
    args = ap.parse_args()

    # Non-positive counts previously led to a crash or a meaningless report.
    for flag, value in (('--max-steps', args.max_steps),
                        ('--per-gpu-bs', args.per_gpu_bs),
                        ('--grad-accum', args.grad_accum)):
        if value < 1:
            ap.error(f'{flag} must be >= 1 (got {value})')

    device = 'cuda:0'
    torch.cuda.set_device(0)

    # 300M architecture (same as production)
    vocab_size = 16384
    d_model = 1536
    n_layers = 16
    n_heads = 24
    d_ff = 1536
    seq_len = 2048
    bs = args.per_gpu_bs
    ga = args.grad_accum
    eff_tok = bs * ga * seq_len  # tokens consumed per optimizer step
    print(f'[bench] arch: d_model={d_model} n_layers={n_layers} n_heads={n_heads} '
          f'd_ff={d_ff} seq_len={seq_len}', flush=True)
    print(f'[bench] bs={bs} grad_accum={ga} eff_tok/step={eff_tok}', flush=True)

    v16.set_gumbel_tau(0.1)
    m = BitLMv47B(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers,
                  n_heads=n_heads, d_ff=d_ff, max_seq_len=seq_len,
                  slope_groups=8).to(device)
    n_params = sum(p.numel() for p in m.parameters())
    print(f'[bench] {n_params/1e6:.1f}M params', flush=True)

    # Cast ALiBi to fp32
    for blk in m.blocks:
        if blk.attn.alibi_bias.dtype != torch.float32:
            blk.attn.alibi_bias = blk.attn.alibi_bias.float()

    if not args.eager:
        m = torch.compile(m, mode='default', dynamic=False, fullgraph=False)
        print('[bench] compiled', flush=True)

    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay_keys = ('embed', 'codebook', 'logit_scale')
    body, small = [], []
    for name, p in m.named_parameters():
        (small if any(k in name for k in no_decay_keys) else body).append(p)
    opt = torch.optim.AdamW(
        [{'params': body, 'weight_decay': 0.1},
         {'params': small, 'weight_decay': 0.0}],
        lr=4e-4, betas=(0.9, 0.95), eps=1e-8, fused=True,
    )
    # Same parameters, same order as iterating opt.param_groups; built once
    # instead of on every step.
    clip_params = body + small

    # Token windows are sampled at random offsets from the training file; the
    # loss is only printed as a sanity check, throughput is what is measured.
    rng = np.random.RandomState(42)
    train_arr = np.memmap('/root/bitnet1/data_fineweb_edu/train.bin',
                          dtype=np.uint16, mode='r')

    print(f'[bench] running {args.max_steps} steps...', flush=True)
    t0 = time.perf_counter()  # monotonic clock, unaffected by wall-clock changes
    timer_start = None        # perf_counter value at the end of step `timer_step`
    timer_step = 0
    for step in range(1, args.max_steps + 1):
        opt.zero_grad(set_to_none=True)
        # Accumulate the loss on device and read it once per step, so that
        # micro-batches are not separated by a host sync.
        accum = torch.zeros((), device=device, dtype=torch.float32)
        for _ in range(ga):
            x, y = _sample_batch(train_arr, rng, bs, seq_len, device)
            with torch.autocast('cuda', dtype=torch.bfloat16):
                logits, _ = m(x, None, use_checkpoint=args.checkpoint)
                loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1)) / ga
            loss.backward()
            accum += loss.detach().float()
        torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
        opt.step()
        torch.cuda.synchronize()
        step_loss = accum.item()
        now = time.perf_counter()

        if step == 1:
            print(f'[bench] step 1 done (warmup+compile): {now-t0:.1f}s, '
                  f'loss={step_loss:.3f}', flush=True)
            timer_start, timer_step = now, 1
        elif step == _WARMUP_STEPS:
            # Reset timer to skip remaining compile, but only if there are
            # later steps left to measure.
            if step < args.max_steps:
                timer_start, timer_step = now, step
            print(f'[bench] step {step} loss={step_loss:.3f}', flush=True)
        elif step % 10 == 0:
            elapsed = now - timer_start
            tok = (step - timer_step) * eff_tok
            print(f'[bench] step {step} loss={step_loss:.3f} '
                  f'{_tokens_per_sec(tok, elapsed):.0f} tok/s ({elapsed:.1f}s)', flush=True)

    elapsed = time.perf_counter() - timer_start
    steps_done = args.max_steps - timer_step
    if steps_done < 1:
        # Only the compile/warmup step ran; there is nothing to report.
        print('\n[bench] FINAL: no timed steps (only step 1 ran; use --max-steps >= 2)',
              flush=True)
        return
    tok = steps_done * eff_tok
    print(f'\n[bench] FINAL: {_tokens_per_sec(tok, elapsed):.0f} tok/s steady-state '
          f'({steps_done} steps, {elapsed:.1f}s, {tok:,} tokens, {n_params/1e6:.1f}M params)',
          flush=True)
    if timer_step < _WARMUP_STEPS:
        print(f'[bench] note: run ended within the {_WARMUP_STEPS}-step warmup window; '
              f'figure is timed from step {timer_step + 1} and may include compile time',
              flush=True)
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()