| """Single-GPU 300M training-step benchmark. |
| |
| No DeepSpeed, no accelerate — just torch + the model. Measures forward + |
| backward + optimizer step throughput on 1× RTX 5090 (32 GB). |
| |
| Run: |
| cd /root/bitnet1/code |
| /venv/main/bin/python _bench_300m_1gpu.py [--max-steps 50] [--per-gpu-bs 1] [--grad-accum 4] |
| """ |
import os
import time
import math
import argparse

# These defaults are placed in the environment before torch is imported.
# setdefault() leaves any value the caller has already exported untouched.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
os.environ.setdefault('TORCHINDUCTOR_CACHE_DIR', '/root/bitnet1/inductor_cache')

import numpy as np
import torch
import torch.nn.functional as F
import torch._inductor.config as _ic

# Inductor options are set once at import time, so they are already in
# effect when main() calls torch.compile.
_ic.max_autotune_gemm_backends = "ATEN"
_ic.coordinate_descent_tuning = True
_ic.epilogue_fusion = True


# Project-local model code.
from model_v47b import BitLMv47B
import model_v16 as v16
|
|
|
|
def main():
    """Benchmark single-GPU training-step throughput for the 300M config.

    Parses the command-line flags, builds ``BitLMv47B`` on ``cuda:0``
    (wrapped in ``torch.compile`` unless ``--eager`` is given) and runs
    ``--max-steps`` optimizer steps. Each step accumulates gradients over
    ``--grad-accum`` micro-batches of ``--per-gpu-bs`` random windows read
    from the memory-mapped training file. Progress lines and a final
    tokens/second figure are printed to stdout.

    Timing: the clock is restarted after step 1 (reported as
    warmup+compile) and again after step 5. The throughput figure always
    counts the steps since the most recent restart, so runs shorter than
    six steps still report a consistent number.

    Exits through ``argparse`` (status 2) when a count flag is below 1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--max-steps', type=int, default=50)
    ap.add_argument('--per-gpu-bs', type=int, default=1)
    ap.add_argument('--grad-accum', type=int, default=4)
    ap.add_argument('--checkpoint', action='store_true', default=True)
    ap.add_argument('--no-checkpoint', dest='checkpoint', action='store_false')
    ap.add_argument('--eager', action='store_true')
    args = ap.parse_args()

    # Non-positive counts leave nothing to time (and zero accumulation steps
    # would mean an optimizer step with no gradients), so reject them here.
    for flag, value in (('--max-steps', args.max_steps),
                        ('--per-gpu-bs', args.per_gpu_bs),
                        ('--grad-accum', args.grad_accum)):
        if value < 1:
            ap.error(f'{flag} must be >= 1 (got {value})')

    device = 'cuda:0'
    torch.cuda.set_device(0)

    # Fixed architecture for this benchmark.
    vocab_size = 16384
    d_model = 1536
    n_layers = 16
    n_heads = 24
    d_ff = 1536
    seq_len = 2048
    bs = args.per_gpu_bs
    ga = args.grad_accum
    eff_tok = bs * ga * seq_len  # tokens consumed per optimizer step

    print(f'[bench] arch: d_model={d_model} n_layers={n_layers} n_heads={n_heads} '
          f'd_ff={d_ff} seq_len={seq_len}', flush=True)
    print(f'[bench] bs={bs} grad_accum={ga} eff_tok/step={eff_tok}', flush=True)

    v16.set_gumbel_tau(0.1)
    m = BitLMv47B(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers,
                  n_heads=n_heads, d_ff=d_ff, max_seq_len=seq_len,
                  slope_groups=8).to(device)
    n_params = sum(p.numel() for p in m.parameters())
    print(f'[bench] {n_params/1e6:.1f}M params', flush=True)

    # Make sure every block's `alibi_bias` tensor is float32.
    for blk in m.blocks:
        if blk.attn.alibi_bias.dtype != torch.float32:
            blk.attn.alibi_bias = blk.attn.alibi_bias.float()

    if not args.eager:
        m = torch.compile(m, mode='default', dynamic=False, fullgraph=False)
        print('[bench] compiled', flush=True)

    # Parameters whose name contains one of these substrings get no weight
    # decay; everything else is decayed.
    no_decay_keys = ('embed', 'codebook', 'logit_scale')
    body, small = [], []
    for name, param in m.named_parameters():
        if any(key in name for key in no_decay_keys):
            small.append(param)
        else:
            body.append(param)
    opt = torch.optim.AdamW(
        [{'params': body, 'weight_decay': 0.1},
         {'params': small, 'weight_decay': 0.0}],
        lr=4e-4, betas=(0.9, 0.95), eps=1e-8, fused=True,
    )
    # Built once: the same flat list the optimizer groups hold, used for
    # gradient clipping on every step.
    clip_params = body + small

    rng = np.random.RandomState(42)
    train_arr = np.memmap('/root/bitnet1/data_fineweb_edu/train.bin',
                          dtype=np.uint16, mode='r')

    def get_batch():
        """Return ``bs`` random (input, next-token target) windows on the GPU."""
        starts = rng.randint(0, len(train_arr) - seq_len - 1, size=bs)
        x = torch.empty(bs, seq_len, dtype=torch.int64, pin_memory=True)
        y = torch.empty(bs, seq_len, dtype=torch.int64, pin_memory=True)
        for i, s in enumerate(starts):
            x[i].copy_(torch.from_numpy(train_arr[s:s+seq_len].astype(np.int64)))
            # Targets are the inputs shifted by one token.
            y[i].copy_(torch.from_numpy(train_arr[s+1:s+1+seq_len].astype(np.int64)))
        return x.to(device, non_blocking=True), y.to(device, non_blocking=True)

    def tok_per_s(tokens, seconds):
        """Tokens per second; only a zero-length interval is special-cased."""
        return tokens / seconds if seconds > 0 else 0.0

    print(f'[bench] running {args.max_steps} steps...', flush=True)
    # perf_counter is monotonic, so intervals cannot be skewed by wall-clock
    # adjustments during the run.
    t0 = time.perf_counter()
    train_started = t0
    timed_from_step = 0  # last step after which the clock was restarted
    for step in range(1, args.max_steps + 1):
        opt.zero_grad(set_to_none=True)
        # The running loss stays on the GPU: calling .item() per micro-batch
        # would force a host/device sync inside the loop being timed.
        loss_sum = torch.zeros((), dtype=torch.float32, device=device)
        for _ in range(ga):
            x, y = get_batch()
            with torch.autocast('cuda', dtype=torch.bfloat16):
                logits, _ = m(x, None, use_checkpoint=args.checkpoint)
                loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1)) / ga
            loss.backward()
            loss_sum += loss.detach().float()
        torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
        opt.step()
        torch.cuda.synchronize()
        accum = loss_sum.item()

        if step == 1:
            print(f'[bench] step 1 done (warmup+compile): {time.perf_counter()-t0:.1f}s, '
                  f'loss={accum:.3f}', flush=True)
            train_started = time.perf_counter()
            timed_from_step = 1
        elif step == 5:
            # Restart the clock so longer runs exclude the first five steps.
            train_started = time.perf_counter()
            timed_from_step = 5
            print(f'[bench] step 5 loss={accum:.3f}', flush=True)
        elif step % 10 == 0:
            elapsed = time.perf_counter() - train_started
            steps_done = step - timed_from_step
            tok = steps_done * eff_tok
            print(f'[bench] step {step} loss={accum:.3f} '
                  f'{tok_per_s(tok, elapsed):.0f} tok/s ({elapsed:.1f}s)', flush=True)

    elapsed = time.perf_counter() - train_started
    # Count steps from the last clock restart rather than assuming the
    # step-5 restart happened; short runs never reach it.
    steps_done = args.max_steps - timed_from_step
    if steps_done < 1:
        print(f'\n[bench] FINAL: no steady-state steps were timed '
              f'(--max-steps {args.max_steps}); run at least 2 steps', flush=True)
        return
    tok = steps_done * eff_tok
    print(f'\n[bench] FINAL: {tok_per_s(tok, elapsed):.0f} tok/s steady-state '
          f'({steps_done} steps, {elapsed:.1f}s, {tok:,} tokens, {n_params/1e6:.1f}M params)',
          flush=True)
|
|
|
|
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|