CausalGrok: Grokking-Favorable Training Artifact Archive

Complete artifact bundle for the project "Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training".

This Hugging Face repository is the full preservation archive: every model checkpoint, per-run training artifact, mechanistic-interpretability output, log, figure, and the paper source. It mirrors the heavy artifacts that do not fit in the source repository.


What this work studies

A small-sample study on Camelyon17 from the WILDS benchmark: tumor classification on histopathology patches, where the out-of-distribution (OOD) data come from hospitals not seen during training. We train ResNet-18 under two regimes and ask whether grokking-favorable optimization (high weight decay, expanded init scale, Grokfast gradient-EMA filtering) changes the internal shortcut representation even when OOD accuracy does not improve.

  • Empirical result: across 14 runs, every model "ungrokks": OOD accuracy peaks early and then decays over training, and no delayed-generalization (grokking) transition occurs.
  • Interventional result: activation steering along the dominant between-hospital direction v_s gives a monotonic OOD response in 4/5 grokking-favorable seeds vs 1/3 standard seeds (Mann-Whitney p=0.071, not significant at this sample size). We read this as the shortcut becoming more concentrated, not eliminated, under heavy regularization.

The grokking-favorable recipe (code)

Two training configurations, both using pure cross-entropy with AdamW for 3000 epochs. They differ along three axes at once (weight decay, init scale, and Grokfast filtering), a deliberate and disclosed confound.

# code/experiments/causalgrok_camelyon_v2.py: get_config()
def get_config(condition):
    base = dict(seed=42, n_train=300, batch_size=32, img_size=96,
                n_classes=2, log_every=50, device="cuda")
    if condition == "standard":
        base.update(condition="standard", lr=1e-3, weight_decay=1e-4,
                    n_epochs=3000, init_scale=1.0, use_grokfast=False)
    elif condition == "grokking":
        base.update(condition="grokking", lr=1e-3, weight_decay=5e-3,
                    n_epochs=3000, init_scale=4.0, use_grokfast=True,
                    grokfast_alpha=0.98, grokfast_lamb=2.0)
    else:
        # fail loudly instead of returning a config with no optimizer settings
        raise ValueError(f"unknown condition: {condition!r}")
    return base

Init-scale rescaling (grokking-favorable condition only; every weight tensor with more than one dimension, i.e. the conv kernels and the linear head weight, is multiplied by 4 at init):

if cfg["init_scale"] != 1.0:
    for name, p in model.named_parameters():
        if "weight" in name and p.dim() > 1:
            p.data *= cfg["init_scale"]
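To see which tensors the 4× rescaling touches, here is a minimal sketch on a toy conv/BatchNorm/linear network. The network is a hypothetical stand-in for ResNet-18, not the archive's model, and apply_init_scale is a hypothetical helper. It applies the same name-and-dimension filter in place under torch.no_grad(): conv kernels and the linear head weight get scaled, while BatchNorm weights (1-D) and all biases stay unchanged.

```python
import torch
import torch.nn as nn


def apply_init_scale(model, init_scale):
    """Multiply every parameter named '*weight*' with more than one dim by init_scale; return their names."""
    scaled = []
    with torch.no_grad():
        for name, p in model.named_parameters():
            if "weight" in name and p.dim() > 1:
                p.mul_(init_scale)  # in-place, same effect as p.data *= init_scale
                scaled.append(name)
    return scaled


# Hypothetical toy stand-in for ResNet-18: conv -> BatchNorm -> pool -> linear head
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False),
    nn.BatchNorm2d(8),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)
before = {n: p.detach().clone() for n, p in toy.named_parameters()}
scaled = apply_init_scale(toy, 4.0)
print(scaled)  # ['0.weight', '4.weight']
```

In ResNet-18 the same filter selects every conv kernel, including the 1×1 downsample convs, plus the fc weight. It selects none of the BatchNorm affine parameters.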

Grokfast EMA (Lee et al. 2024, arXiv:2405.20233) amplifies the slowly varying component of the gradient. It is applied after loss.backward() and before optimizer.step():

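# code/utils/grokfast.py: gradfilter_ema()
def gradfilter_ema(model, grads_ema, alpha=0.98, lamb=2.0):
    if grads_ema is None:                  # first call: start with an empty EMA buffer
        grads_ema = {}
    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            if name not in grads_ema:
                grads_ema[name] = p.grad.data.detach().clone()   # seed the EMA with the first gradient
            else:
                grads_ema[name] = grads_ema[name] * alpha + p.grad.data * (1 - alpha)  # alpha=0.98
            p.grad.data = p.grad.data + grads_ema[name] * lamb   # lamb=2.0: add back the amplified slow component
    return grads_ema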
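This scalar illustration shows what "amplifies the slowly varying component" means at alpha=0.98, lamb=2.0. It is a minimal sketch that mirrors the update rule above on two synthetic gradient sequences; ema_filter_step and final_gain are hypothetical helpers. A constant gradient ends up multiplied by 1 + lamb = 3. A gradient whose sign flips every step is amplified only by about 1 + lamb·(1 − alpha)/(1 + alpha) ≈ 1.02.

```python
def ema_filter_step(grad, ema, alpha=0.98, lamb=2.0):
    """One scalar step of the Grokfast-EMA rule: returns (filtered_grad, updated_ema)."""
    ema = grad if ema is None else ema * alpha + grad * (1 - alpha)  # EMA seeded with the first gradient
    return grad + lamb * ema, ema


def final_gain(signal, steps=2000, alpha=0.98, lamb=2.0):
    """Feed signal(t) for `steps` steps and return |filtered / raw| at the last step."""
    ema = None
    for t in range(steps):
        g = signal(t)
        filtered, ema = ema_filter_step(g, ema, alpha, lamb)
    return abs(filtered / g)


slow = final_gain(lambda t: 1.0)          # constant (low-frequency) gradient
fast = final_gain(lambda t: (-1.0) ** t)  # gradient whose sign flips every step
print(f"slow gain {slow:.3f}, fast gain {fast:.3f}")  # slow gain 3.000, fast gain 1.020
```

With lamb = 2.0, the slow component of each update is therefore weighted roughly three times as heavily as the fast, noisy component before gradient clipping and the AdamW step.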

Training loss is cross-entropy only. An IRM-style invariance penalty (Arjovsky et al. 2019) is computed every epoch across the training-hospital environments purely as a diagnostic: it is logged but never added to the loss. One optimization step:

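import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
optimizer.zero_grad()                      # clear gradients left over from the previous step
logits  = model(imgs)
loss    = criterion(logits, labels)        # pure CE; irm_weight = 0.0 for every reported run
loss.backward()
if cfg["use_grokfast"]:
    # filter the raw gradients before clipping and the AdamW update
    grads_ema = gradfilter_ema(model, grads_ema, alpha=0.98, lamb=2.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()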
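This card does not reproduce the archive's exact IRM diagnostic. As an assumption, the sketch below uses the IRMv1 form from Arjovsky et al. (2019): the squared gradient of each environment's cross-entropy risk with respect to a dummy logit scale fixed at 1.0, averaged over training-hospital environments. The logits are detached, so the value can be logged without touching the gradients that reach optimizer.step(). irmv1_penalty and irm_diagnostic are hypothetical names.

```python
import torch
import torch.nn.functional as F


def irmv1_penalty(logits, labels):
    """IRMv1 penalty for one environment: (d risk / d scale)^2 at scale = 1.0."""
    with torch.enable_grad():  # works even if called inside a torch.no_grad() eval block
        scale = torch.ones((), requires_grad=True, device=logits.device)
        risk = F.cross_entropy(logits.detach() * scale, labels)  # detached: diagnostic only
        (grad,) = torch.autograd.grad(risk, [scale])
    return grad.pow(2)


def irm_diagnostic(env_batches):
    """Mean IRMv1 penalty over (logits, labels) pairs, one pair per training hospital."""
    return torch.stack([irmv1_penalty(lg, y) for lg, y in env_batches]).mean().item()


# Synthetic logits/labels for three hypothetical training hospitals
torch.manual_seed(0)
envs = [(torch.randn(16, 2), torch.randint(0, 2, (16,))) for _ in range(3)]
print(irm_diagnostic(envs))
```

Because the returned float is only logged, setting irm_weight = 0.0 and computing the penalty on detached logits give the same training dynamics as not computing it at all.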

Checkpoints are saved every 200 epochs (15 periodic + 1 final.pt per run).
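The per-run checkpoint file list can be regenerated from this schedule. A minimal sketch, assuming the epNNNNN.pt pattern (epoch zero-padded to five digits) inferred from the filenames listed in this card; checkpoint_names is a hypothetical helper.

```python
def checkpoint_names(n_epochs=3000, every=200):
    """Periodic checkpoint filenames (epoch zero-padded to 5 digits) followed by final.pt."""
    periodic = [f"ep{epoch:05d}.pt" for epoch in range(every, n_epochs + 1, every)]
    return periodic + ["final.pt"]


names = checkpoint_names()
print(len(names), names[0], names[-2], names[-1])  # 16 ep00200.pt ep03000.pt final.pt
```

Prefixing each name with runs/<run_id>/checkpoints/ gives the repository paths used in the download example further down.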

| Hyperparameter | Standard | Grokking-favorable |
|---|---|---|
| Optimizer | AdamW | AdamW |
| Learning rate | 1e-3 | 1e-3 |
| Weight decay | 1e-4 | 5e-3 (50×) |
| Init scale | 1.0 | 4.0 |
| Grokfast EMA | off | on (alpha 0.98, lamb 2.0) |
| Grad clip (max-norm) | 1.0 | 1.0 |
| Batch size | 32 | 32 |
| Epochs | 3000 | 3000 |
| IRM weight in loss | 0.0 (diagnostic) | 0.0 (diagnostic) |

Model: ResNet-18 (timm, no ImageNet pretraining), 96×96 input, 2-class head, 11,177,538 parameters.
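The parameter count can be checked by hand. This pure-Python sketch tallies ResNet-18's learnable parameters from its layer shapes. It assumes bias-free convs, BatchNorm with a learnable scale and shift, a 1×1 downsample conv when the channel count changes, and a linear head. Running statistics are buffers, so they are not counted, and the 96×96 input resolution does not affect the total. The helper names are hypothetical.

```python
def conv(cin, cout, k):
    return cin * cout * k * k  # ResNet convs have no bias


def bn(c):
    return 2 * c  # learnable gamma and beta; running mean/var are buffers


def basic_block(cin, cout, downsample):
    n = conv(cin, cout, 3) + bn(cout) + conv(cout, cout, 3) + bn(cout)
    if downsample:
        n += conv(cin, cout, 1) + bn(cout)  # 1x1 projection shortcut
    return n


def resnet18_params(n_classes):
    total = conv(3, 64, 7) + bn(64)  # stem
    cin = 64
    for cout in (64, 128, 256, 512):  # four stages, two BasicBlocks each
        total += basic_block(cin, cout, downsample=(cin != cout))
        total += basic_block(cout, cout, downsample=False)
        cin = cout
    return total + 512 * n_classes + n_classes  # fc head


n = resnet18_params(2)
print(f"{n:,} params, {n * 4 / 1e6:.1f} MB in fp32")  # 11,177,538 params, 44.7 MB in fp32
```

At 4 bytes per fp32 parameter, this is consistent with the ~44 MB per-checkpoint size listed in the repository layout. The same function gives 11,689,512 for the familiar 1000-class ResNet-18.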


Mechanistic-interpretability suite (code)

Four analyses, all run on the saved checkpoints. M1 probes six ResNet stages; M4, M5 and M6 operate on the global-average-pooled (avgpool) features, D=512.

  • M1: layer-wise linear probing (code/experiments/mechinterp_m1.py). Logistic-regression hospital and tumor probes at six ResNet stages.
  • M4: subspace ablation (code/experiments/mechinterp_m4_ablation.py). Project features orthogonal to a ~35-dim LDA-style hospital subspace, then re-classify with the original head.
  • M5: activation steering (code/experiments/mechinterp_m5_steering.py). Steer h' = h + alpha · sigma · v_s along the dominant between-hospital direction, alpha ∈ [-3, +3].
  • M6: targeted neuron ablation (code/experiments/mechinterp_m6_neuron_ablation.py). Zero the top-K hospital-discriminating channels vs random-K and morphology-K controls, K ∈ {0,4,8,16,32,64,128,256}.
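The feature-space operations behind M4, M5 and M6 can be sketched in NumPy on an N×D feature matrix. This is a minimal sketch under stated assumptions, not the archive's code. The basis comes from the top eigenvectors of the between-hospital scatter, a stand-in for the LDA-style construction, which this card does not specify; with G hospitals that scatter has rank at most G − 1. sigma is assumed to be the standard deviation of the features projected onto v_s. All function names are hypothetical, and the data are synthetic.

```python
import numpy as np


def between_hospital_basis(h, hospital, k):
    """Top-k eigenvectors of the between-hospital scatter of features h (N x D). Assumed construction."""
    mu = h.mean(axis=0)
    groups = np.unique(hospital)
    dev = np.stack([h[hospital == g].mean(axis=0) - mu for g in groups])  # G x D mean offsets
    counts = np.array([(hospital == g).sum() for g in groups], dtype=float)
    scatter = (dev * counts[:, None]).T @ dev                               # D x D, rank <= G - 1
    _, evecs = np.linalg.eigh(scatter)                                      # eigenvalues ascending
    return evecs[:, ::-1][:, :k]                                            # D x k, orthonormal columns


def ablate_subspace(h, basis):
    """M4-style: project features onto the orthogonal complement of span(basis)."""
    return h - (h @ basis) @ basis.T


def steer(h, v_s, alpha):
    """M5-style: h' = h + alpha * sigma * v_s, with sigma = std of h projected on v_s (assumed)."""
    sigma = (h @ v_s).std()
    return h + alpha * sigma * v_s


def ablate_channels(h, channels):
    """M6-style: zero the selected feature channels."""
    out = h.copy()
    out[:, channels] = 0.0
    return out


# Synthetic features: 3 hospitals, hospital 1 shifted along channel 0
rng = np.random.default_rng(0)
h = rng.normal(size=(300, 16))
hospital = np.repeat([0, 1, 2], 100)
h[hospital == 1, 0] += 3.0

basis = between_hospital_basis(h, hospital, k=2)
v_s = basis[:, 0]  # dominant between-hospital direction (sign is arbitrary)
steered = {a: steer(h, v_s, a) for a in (-3, 0, 3)}
```

In the actual analyses, the modified features would then be passed to the trained classification head (model.fc in timm's ResNet-18), and OOD accuracy would be recorded as a function of alpha or K.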

Repository layout

runs/<run_id>/
├── config.json                 # full hyperparameter config
├── checkpoints/ep00200.pt … ep03000.pt, final.pt   # ~44 MB each
├── results/history.json        # per-evaluation metrics, every log_every=50 epochs (61 rows)
├── results/summary.json        # final-summary fields
├── logs/train.log              # launch command + per-checkpoint log lines
├── wandb/                      # offline wandb run metadata
└── mechinterp/                 # m1/m4/m5/m6 JSON + PNG outputs

figures/   # 7 paper figures (PNG + PDF) + m6_summary.csv (88-row results table)
paper/     # main.tex, example_paper.bib, compiled PDF, style files
code/      # training + mechanistic-interpretability source
logs/      # top-level training and M1/M4/M5/M6 driver logs
docs/      # TRAINING_DETAILS.md: exhaustive code/hyperparameter/metric/results reference

.pt checkpoints total ~10 GB across 240 files.


Run inventory (n_train = 1000)

| Cond. | Seed | Run ID | Peak OOD acc. | Peak epoch | Final OOD acc. |
|---|---|---|---|---|---|
| Grok | 7 | 20260508-183413_grokking_n1000_s7 | 0.6876 | 50 | 0.5882 |
| Grok | 42 | 20260505-080445_grokking_n1000_s42 | 0.7336 | 350 | 0.6639 |
| Grok | 123 | 20260505-100720_grokking_n1000_s123 | 0.7270 | 350 | 0.6447 |
| Grok | 456 | 20260505-100720_grokking_n1000_s456 | 0.6722 | 1100 | 0.5224 |
| Grok | 2024 | 20260508-183413_grokking_n1000_s2024 | 0.7056 | 400 | 0.5506 |
| Std | 42 | 20260505-100720_standard_n1000_s42 | 0.7615 | 1 | 0.6482 |
| Std | 123 | 20260508-183413_standard_n1000_s123 | 0.8880* | 1 | 0.6645 |
| Std | 456 | 20260508-183413_standard_n1000_s456 | 0.7450 | 1050 | 0.5783 |

* Std s123 peaks at epoch 1, essentially from its random initialization, so this peak is treated as an artifact. Additional runs were made at n=300 and n=500; the full 14-run table and all metrics are in docs/TRAINING_DETAILS.md.
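The Peak OOD, Peak epoch and Final OOD columns can be derived from a per-evaluation accuracy curve. A minimal sketch, assuming history rows shaped like {"epoch": ..., "ood_acc": ...}. These key names are hypothetical, so check results/history.json for the actual field names before reusing this. Ties at the peak resolve to the earliest epoch, because Python's max returns the first maximal element.

```python
def ood_summary(history, key="ood_acc"):
    """Peak OOD accuracy, its epoch, final OOD accuracy and the peak-to-final drop.

    `history` is a list of per-evaluation dicts ordered by epoch; key names are hypothetical.
    """
    peak_row = max(history, key=lambda row: row[key])
    final = history[-1][key]
    return {"peak": peak_row[key], "peak_epoch": peak_row["epoch"],
            "final": final, "drop": peak_row[key] - final}


# Synthetic curve that peaks early and then decays (illustrative numbers, not archive data)
curve = [{"epoch": e, "ood_acc": a}
         for e, a in [(1, 0.60), (50, 0.68), (350, 0.72), (1000, 0.66), (3000, 0.62)]]
summary = ood_summary(curve)
print(summary)
```

An "ungrokking" run in the sense used above is one where the peak comes well before the last epoch and the drop is positive. A grokking transition would instead show the peak near the end of training.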


Loading a checkpoint

import torch, timm

model = timm.create_model("resnet18", pretrained=False, num_classes=2)
sd = torch.load("runs/20260505-080445_grokking_n1000_s42/checkpoints/ep00400.pt",
                map_location="cpu")
model.load_state_dict(sd)
model.eval()

Checkpoints are plain state_dict files; the init-scale rescaling is baked into the trained weights.

from huggingface_hub import hf_hub_download, snapshot_download
# one file
p = hf_hub_download("nileshsarkar-ai/CausalGrok",
                    "runs/20260505-080445_grokking_n1000_s42/checkpoints/ep00400.pt")
# whole archive
snapshot_download("nileshsarkar-ai/CausalGrok", local_dir="CausalGrok")

Dataset

Experiments use Camelyon17 from the WILDS benchmark. The raw dataset (~10 GB) is not mirrored here (public benchmark); code/utils/camelyon_data.py::get_camelyon_subsets auto-downloads it via the wilds package.


Citation

@misc{causalgrok2026,
  title  = {Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training},
  author = {Sarkar, Nilesh},
  year   = {2026},
  url    = {https://github.com/nileshsarkar-ai/CausalGrok}
}

License

CC BY 4.0. The Camelyon17 dataset retains its own WILDS license; this archive does not redistribute it.
