CausalGrok: Grokking-Favorable Training Artifact Archive
Complete artifact bundle for the project "Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training".
This Hugging Face repository is the full preservation archive: every model checkpoint, per-run training artifact, mechanistic-interpretability output, log, figure, and the paper source. It mirrors the heavy artifacts that do not fit in the source repository.
- Code (browsable): https://github.com/nileshsarkar-ai/CausalGrok
- This archive (heavy artifacts): 240 checkpoints (~10 GB), run JSONs, logs, figures
What this work studies
A small-sample study on Camelyon17 (WILDS hospital-shift OOD pathology classification). We train ResNet-18 under two regimes and ask whether grokking-favorable optimization (high weight decay, expanded init scale, Grokfast gradient-EMA filtering) changes the internal shortcut representation, even when out-of-distribution (OOD) accuracy does not improve.
- Empirical result: across 14 runs, every model ungroks: OOD accuracy peaks early and decays during training; no delayed-generalization (grokking) transition occurs.
- Interventional result: activation steering along the dominant between-hospital direction v_s gives a monotonic OOD response in 4/5 grokking-favorable seeds vs 1/3 standard seeds (Mann-Whitney p=0.071, non-significant at this sample size). We read this as shortcut concentration, not elimination, under heavy regularization.
The grokking-favorable recipe (code)
Two training configurations, both pure cross-entropy + AdamW, 3000 epochs. They differ along three axes (a deliberate, disclosed confound).
# code/experiments/causalgrok_camelyon_v2.py: get_config()
def get_config(condition):
    # Settings shared by both conditions.
    base = dict(seed=42, n_train=300, batch_size=32, img_size=96,
                n_classes=2, log_every=50, device="cuda")
    if condition == "standard":
        base.update(condition="standard", lr=1e-3, weight_decay=1e-4,
                    n_epochs=3000, init_scale=1.0, use_grokfast=False)
    elif condition == "grokking":
        base.update(condition="grokking", lr=1e-3, weight_decay=5e-3,
                    n_epochs=3000, init_scale=4.0, use_grokfast=True,
                    grokfast_alpha=0.98, grokfast_lamb=2.0)
    return base
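The axes on which the two configurations differ can be recovered mechanically by diffing the two dictionaries. The sketch below re-declares both configs from the values above so that it runs on its own (the condition label is left out, since it differs by construction); diff_configs is a hypothetical helper, not part of the repository.

```python
# Shared settings, copied from get_config() above.
BASE = dict(seed=42, n_train=300, batch_size=32, img_size=96,
            n_classes=2, log_every=50, device="cuda",
            lr=1e-3, n_epochs=3000)

STANDARD = dict(BASE, weight_decay=1e-4, init_scale=1.0, use_grokfast=False)
GROKKING = dict(BASE, weight_decay=5e-3, init_scale=4.0, use_grokfast=True,
                grokfast_alpha=0.98, grokfast_lamb=2.0)


def diff_configs(a, b):
    """Return {key: (a_value, b_value)} for every key whose value differs.

    A key that is missing on one side is reported with None on that side.
    """
    keys = sorted(set(a) | set(b))
    return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}


changed = diff_configs(STANDARD, GROKKING)
for key, (std_value, grok_value) in changed.items():
    print(key, std_value, grok_value)
```

The diff lists weight_decay, init_scale, and use_grokfast, plus the two Grokfast settings that only exist when the filter is on.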
Init-scale rescaling (grokking-favorable only; every multi-dim weight tensor is scaled 4× at init):
if cfg["init_scale"] != 1.0:
    for name, p in model.named_parameters():
        # Only weight tensors with more than one dimension (conv / linear) are rescaled.
        if "weight" in name and p.dim() > 1:
            p.data *= cfg["init_scale"]
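The filter "weight" in name and p.dim() > 1 selects convolution and linear weight tensors but skips biases and 1-D BatchNorm scales. A minimal sketch, with plain shape tuples standing in for tensors; the parameter names follow the usual ResNet naming, and the shape table is illustrative rather than read from a checkpoint.

```python
# (name, shape) pairs standing in for model.named_parameters().
PARAMS = [
    ("conv1.weight", (64, 3, 7, 7)),   # 4-D conv kernel
    ("bn1.weight",   (64,)),           # 1-D BatchNorm scale
    ("bn1.bias",     (64,)),           # 1-D BatchNorm shift
    ("fc.weight",    (2, 512)),        # 2-D classifier weight
    ("fc.bias",      (2,)),            # 1-D classifier bias
]


def is_rescaled(name, shape):
    """Mirror the condition above: "weight" in name and p.dim() > 1."""
    return "weight" in name and len(shape) > 1


rescaled = [name for name, shape in PARAMS if is_rescaled(name, shape)]
print(rescaled)
```

BatchNorm scales contain "weight" in their name but are 1-D, so the dimension check is what keeps them at their default initialization.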
Grokfast EMA amplifies the slow-varying gradient component (Lee et al. 2024, arXiv:2405.20233). It is applied after loss.backward() and before optimizer.step():
# code/utils/grokfast.py: gradfilter_ema()
for name, p in model.named_parameters():
    if p.requires_grad and p.grad is not None:
        if name not in grads_ema:
            # First call: initialise the EMA with the current gradient.
            grads_ema[name] = p.grad.data.detach().clone()
        else:
            grads_ema[name] = grads_ema[name] * alpha + p.grad.data * (1 - alpha)  # alpha=0.98
        # Add the amplified slow component back onto the raw gradient.
        p.grad.data = p.grad.data + grads_ema[name] * lamb  # lamb=2.0
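To see what the filter does, it helps to run the same update on scalar gradients. The sketch below is a one-dimensional toy version of the rule above, not the repository's function: a constant gradient (the slow component) is amplified by roughly 1 + lamb, while a sign-flipping gradient (the fast component) passes through almost unchanged because its EMA stays near zero.

```python
def grokfast_ema_1d(grads, alpha=0.98, lamb=2.0):
    """Apply the EMA filter to a sequence of scalar gradients.

    The EMA is initialised to the first gradient and then updated as
    ema_t = alpha * ema_{t-1} + (1 - alpha) * g_t.
    Each returned value is g_t + lamb * ema_t.
    """
    ema = None
    filtered = []
    for g in grads:
        ema = g if ema is None else alpha * ema + (1 - alpha) * g
        filtered.append(g + lamb * ema)
    return filtered


steady = grokfast_ema_1d([1.0] * 200)        # slow (constant) gradient
noisy = grokfast_ema_1d([1.0, -1.0] * 100)   # fast (sign-flipping) gradient

print(steady[-1])       # close to 1 + lamb = 3
print(abs(noisy[-1]))   # close to 1: almost no amplification
```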
Training loss is cross-entropy only. An IRM-style invariance penalty (Arjovsky et al. 2019) is computed every epoch across training-hospital environments as a diagnostic only; it is logged, never added to the loss:
criterion = nn.CrossEntropyLoss()
logits = model(imgs)
loss = criterion(logits, labels)  # pure CE; irm_weight = 0.0 for every reported run
loss.backward()
if cfg["use_grokfast"]:
    grads_ema = gradfilter_ema(model, grads_ema, alpha=0.98, lamb=2.0)
# Clipping runs after the Grokfast filter, in both conditions.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
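Because the Grokfast filter runs before clipping, the amplified gradient is still bounded to a global L2 norm of 1.0 before the optimizer step. The sketch below shows global-norm clipping on a flat list of gradient values. clip_by_global_norm is a hypothetical stand-in for torch.nn.utils.clip_grad_norm_, written on the assumption that it scales by max_norm / (total_norm + eps) only when that factor is below 1.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale grads so their L2 norm is at most max_norm.

    Returns (clipped_grads, norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0])   # norm 5: rescaled
small, _ = clip_by_global_norm([0.3, 0.4])               # norm 0.5: left as is

print(norm_before, clipped)
print(small)
```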
Checkpoints are saved every 200 epochs (15 periodic + 1 final.pt per run).
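The per-run file list follows directly from the schedule. A small sketch that enumerates it, assuming the five-digit zero-padded naming seen in the checkpoint paths of this archive (checkpoint_names is an illustrative helper, not a repository function):

```python
def checkpoint_names(n_epochs=3000, every=200):
    """List per-run checkpoint file names: periodic saves plus final.pt."""
    periodic = [f"ep{epoch:05d}.pt" for epoch in range(every, n_epochs + 1, every)]
    return periodic + ["final.pt"]


names = checkpoint_names()
print(len(names), names[0], names[-2], names[-1])
# 16 ep00200.pt ep03000.pt final.pt
```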
| Hyperparameter | Standard | Grokking-favorable |
|---|---|---|
| Optimizer | AdamW | AdamW |
| Learning rate | 1e-3 | 1e-3 |
| Weight decay | 1e-4 | 5e-3 (50×) |
| Init scale | 1.0 | 4.0 |
| Grokfast EMA | off | on (alpha 0.98, lamb 2.0) |
| Grad clip (max-norm) | 1.0 | 1.0 |
| Batch size | 32 | 32 |
| Epochs | 3000 | 3000 |
| IRM weight in loss | 0.0 (diagnostic) | 0.0 (diagnostic) |
Model: ResNet-18 (timm, no ImageNet pretraining), 96×96 input, 2-class head, 11,177,538 parameters.
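The parameter count can be checked by hand. The sketch below tallies trainable parameters for the standard ResNet-18 layout (bias-free convolutions, BatchNorm with one scale and one shift per channel, two basic blocks per stage), which is assumed here to match the timm model. The helper names are illustrative.

```python
def conv_params(c_in, c_out, k):
    """Weights of a k x k convolution without bias."""
    return c_in * c_out * k * k


def bn_params(channels):
    """BatchNorm scale and shift (running statistics are buffers, not parameters)."""
    return 2 * channels


def basic_block(c_in, c_out, downsample):
    """Two 3x3 conv + BN pairs, plus an optional 1x1 conv + BN shortcut."""
    n = conv_params(c_in, c_out, 3) + bn_params(c_out)
    n += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if downsample:
        n += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return n


def resnet18_params(num_classes):
    total = conv_params(3, 64, 7) + bn_params(64)          # stem
    c_in = 64
    for c_out in (64, 128, 256, 512):                      # four stages
        total += basic_block(c_in, c_out, downsample=(c_in != c_out))
        total += basic_block(c_out, c_out, downsample=False)
        c_in = c_out
    total += 512 * num_classes + num_classes               # linear head
    return total


print(resnet18_params(2))      # 11177538
```

With a 1000-class head the same tally gives the familiar ImageNet ResNet-18 figure, which is a useful sanity check on the layout assumption.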
Mechanistic-interpretability suite (code)
Four probes on avgpool features (D=512); run on the saved checkpoints.
- M1: layer-wise linear probing (code/experiments/mechinterp_m1.py): logistic-regression hospital and tumor probes at six ResNet stages.
- M4: subspace ablation (code/experiments/mechinterp_m4_ablation.py): project features orthogonal to a ~35-dim LDA-style hospital subspace, re-classify with the original head.
- M5: activation steering (code/experiments/mechinterp_m5_steering.py): steer h' = h + alpha · sigma · v_s along the dominant between-hospital direction, alpha ∈ [-3, +3].
- M6: targeted neuron ablation (code/experiments/mechinterp_m6_neuron_ablation.py): zero top-K hospital-discriminating channels vs random-K and morphology-K controls, K ∈ {0,4,8,16,32,64,128,256}.
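The M4 and M5 interventions are simple linear operations on a feature vector. The sketch below shows both on plain Python lists: steering adds alpha · sigma · v_s to h, and ablation removes the part of h that lies in a subspace. It makes three assumptions: v_s has unit length; sigma is a scalar scale, which the text above does not define (the standard deviation of features projected onto v_s would be one choice); and the subspace basis is orthonormal. The function names are illustrative, not the repository's API.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def steer(h, v_s, alpha, sigma):
    """M5-style steering: h' = h + alpha * sigma * v_s."""
    return [hi + alpha * sigma * vi for hi, vi in zip(h, v_s)]


def project_out(h, basis):
    """M4-style ablation: remove the span of an orthonormal basis from h."""
    out = list(h)
    for u in basis:
        coeff = dot(h, u)   # valid because the basis vectors are orthonormal
        out = [oi - coeff * ui for oi, ui in zip(out, u)]
    return out


h = [1.0, 2.0, 3.0]
v_s = [1.0, 0.0, 0.0]                     # toy unit-length direction

for alpha in (-3, 0, 3):                  # ends and midpoint of [-3, +3]
    print(alpha, steer(h, v_s, alpha, sigma=0.5))

ablated = project_out(h, [v_s])
print(ablated, dot(ablated, v_s))         # component along v_s is now zero
```

In the real probes h would be a 512-dim avgpool feature and the ablated subspace would have about 35 basis vectors; the arithmetic is the same.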
Repository layout
runs/<run_id>/
├── config.json                                        # full hyperparameter config
├── checkpoints/ep00200.pt … ep03000.pt, final.pt      # ~44 MB each
├── results/history.json                               # per-checkpoint metrics (61 rows)
├── results/summary.json                               # final-summary fields
├── logs/train.log                                     # launch command + per-checkpoint log lines
├── wandb/                                             # offline wandb run metadata
└── mechinterp/                                        # m1/m4/m5/m6 JSON + PNG outputs
figures/    # 7 paper figures (PNG + PDF) + m6_summary.csv (88-row results table)
paper/      # main.tex, example_paper.bib, compiled PDF, style files
code/       # training + mechanistic-interpretability source
logs/       # top-level training and M1/M4/M5/M6 driver logs
docs/       # TRAINING_DETAILS.md: exhaustive code/hyperparameter/metric/results reference
.pt checkpoints total ~10 GB across 240 files.
Run inventory (n=1000)
| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD |
|---|---|---|---|---|---|
| Grok | 7 | 20260508-183413_grokking_n1000_s7 | 0.6876 | 50 | 0.5882 |
| Grok | 42 | 20260505-080445_grokking_n1000_s42 | 0.7336 | 350 | 0.6639 |
| Grok | 123 | 20260505-100720_grokking_n1000_s123 | 0.7270 | 350 | 0.6447 |
| Grok | 456 | 20260505-100720_grokking_n1000_s456 | 0.6722 | 1100 | 0.5224 |
| Grok | 2024 | 20260508-183413_grokking_n1000_s2024 | 0.7056 | 400 | 0.5506 |
| Std | 42 | 20260505-100720_standard_n1000_s42 | 0.7615 | 1 | 0.6482 |
| Std | 123 | 20260508-183413_standard_n1000_s123 | 0.8880* | 1 | 0.6645 |
| Std | 456 | 20260508-183413_standard_n1000_s456 | 0.7450 | 1050 | 0.5783 |
* Std s123 peaks at epoch 1, on the random initialization (artifact). Additional runs use n=300 and n=500; the full 14-run table and all metrics are in docs/TRAINING_DETAILS.md.
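Run IDs encode a timestamp, the condition, the training-set size, and the seed, so they can be split programmatically. The sketch below parses that pattern and computes the peak-to-final OOD drop for the eight runs in the table. The field names (date, time, n_train) are assumptions chosen for illustration, and the accuracy values are copied from the table, not read from history.json.

```python
import re

RUN_ID = re.compile(
    r"^(?P<date>\d{8})-(?P<time>\d{6})_(?P<condition>[a-z]+)"
    r"_n(?P<n_train>\d+)_s(?P<seed>\d+)$"
)


def parse_run_id(run_id):
    """Split a run ID such as 20260505-080445_grokking_n1000_s42 into fields."""
    m = RUN_ID.match(run_id)
    if m is None:
        raise ValueError(f"unrecognised run ID: {run_id!r}")
    fields = m.groupdict()
    fields["n_train"] = int(fields["n_train"])
    fields["seed"] = int(fields["seed"])
    return fields


# (run ID, peak OOD, final OOD), copied from the table above.
RUNS = [
    ("20260508-183413_grokking_n1000_s7",    0.6876, 0.5882),
    ("20260505-080445_grokking_n1000_s42",   0.7336, 0.6639),
    ("20260505-100720_grokking_n1000_s123",  0.7270, 0.6447),
    ("20260505-100720_grokking_n1000_s456",  0.6722, 0.5224),
    ("20260508-183413_grokking_n1000_s2024", 0.7056, 0.5506),
    ("20260505-100720_standard_n1000_s42",   0.7615, 0.6482),
    ("20260508-183413_standard_n1000_s123",  0.8880, 0.6645),
    ("20260508-183413_standard_n1000_s456",  0.7450, 0.5783),
]

# Peak-to-final drop in OOD accuracy for each run.
drops = {run_id: round(peak - final, 4) for run_id, peak, final in RUNS}

for run_id, drop in drops.items():
    info = parse_run_id(run_id)
    print(info["condition"], info["seed"], drop)
```

Every drop is positive, matching the observation that OOD accuracy ends below its peak in all listed runs.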
Loading a checkpoint
import torch, timm
model = timm.create_model("resnet18", pretrained=False, num_classes=2)
sd = torch.load("runs/20260505-080445_grokking_n1000_s42/checkpoints/ep00400.pt",
                map_location="cpu")
model.load_state_dict(sd)
model.eval()
Checkpoints are plain state_dict files; the init-scale rescaling is baked into the trained weights.
To download files from this Hugging Face repository:
from huggingface_hub import hf_hub_download, snapshot_download
# one file
p = hf_hub_download("nileshsarkar-ai/CausalGrok",
                    "runs/20260505-080445_grokking_n1000_s42/checkpoints/ep00400.pt")
# whole archive
snapshot_download("nileshsarkar-ai/CausalGrok", local_dir="CausalGrok")
Dataset
Experiments use Camelyon17 from the WILDS benchmark. The raw dataset (~10 GB) is not mirrored here (public benchmark); code/utils/camelyon_data.py::get_camelyon_subsets auto-downloads it via the wilds package.
Citation
@misc{causalgrok2026,
title = {Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training},
author = {Sarkar, Nilesh},
year = {2026},
url = {https://github.com/nileshsarkar-ai/CausalGrok}
}
License
CC BY 4.0. The Camelyon17 dataset retains its own WILDS license; this archive does not redistribute it.