A world model with persistent mechanism memory for physical reasoning and out-of-distribution adaptation.
MechJEPA extends I-JEPA with:
- Mechanism Codebook: a learned library of stable physical interaction patterns (push, collide, support)
- Action-Conditioned Dynamics: AdaLN-based action conditioning via LeWorldModel
- System M: surprise-triggered online adaptation for OOD robustness
This is the first demonstration of an A-B-M (Anticipate-Behave-Modulate) agent loop on the Push-T environment.
Fully closed-loop evaluation using real Push-T environment renders (like LeWorldModel).
At every step: env.render() → VideoSAUR DINOv2 ViT → slot corrector → 4 slots → CEM planner → action → env.step()
Three conditions: Expert (left), Frozen MechJEPA (centre), A-B-M with System M (right).
Trained on 18,685 expert episodes (≈1.98M samples) from the Push-T dataset with action conditioning:
| Metric | Value |
|---|---|
| Val Loss | 0.0018 |
| Batch Size | 4096 (H100) |
| GPU Utilization | 94–96% |
| Training Epochs | 50 |
CEM planner optimises action sequences directly in the latent slot space:
| | Mean Latent Error | Median | Max |
|---|---|---|---|
| CEM Planner | 0.0034 | 0.0032 | 0.0058 |
| Random Actions | 0.0434 | 0.0472 | 0.0869 |
| Improvement | 12.8× | | |
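The planning step can be illustrated with a minimal cross-entropy method (CEM) sketch. This is not the repository's `CEMSolver`: the linear `toy_dynamics`, the 2-D latent state, and every hyperparameter below are assumptions chosen so the example runs standalone with NumPy.

```python
import numpy as np

def toy_dynamics(z, a):
    """Hypothetical latent dynamics: each action nudges the latent state."""
    return z + 0.1 * a

def rollout_cost(z0, actions, goal):
    """Squared latent distance to the goal after applying an action sequence."""
    z = z0
    for a in actions:
        z = toy_dynamics(z, a)
    return float(np.sum((z - goal) ** 2))

def cem_plan(z0, goal, horizon=10, pop=64, n_elite=8, iters=5, seed=0):
    """CEM: sample action sequences, keep the elites, refit the Gaussian."""
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, z0.shape[0]))
    std = np.ones_like(mean)
    for _ in range(iters):
        samples = mean + std * rng.standard_normal((pop,) + mean.shape)
        samples = np.clip(samples, -1.0, 1.0)  # assumed action bounds
        costs = np.array([rollout_cost(z0, s, goal) for s in samples])
        elite = samples[np.argsort(costs)[:n_elite]]
        mean, std = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return mean

z0 = np.zeros(2)
goal = np.array([0.5, -0.3])
plan = cem_plan(z0, goal)
planned_cost = rollout_cost(z0, plan, goal)
zero_cost = rollout_cost(z0, np.zeros_like(plan), goal)  # do-nothing baseline
```

Comparing `planned_cost` with `zero_cost` mirrors, on a toy scale, the CEM-versus-random comparison in the table above.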
Distribution shift: slot observations scaled by α=1.4 (simulating a 40% heavier/larger block)
| | Frozen Model | A-B-M Agent |
|---|---|---|
| Mean Surprise | 0.0340 | 0.0090 (3.8× ↓) |
| Mean Plan Error | 0.0483 | 0.0169 (2.9× ↓) |
| Total Adaptations | 0 | 29 |
| Episode | Frozen | A-B-M | Adaptations |
|---|---|---|---|
| ep0 | 0.0452 | 0.0208 | 12 |
| ep10 | 0.0470 | 0.0168 | 1 |
| ep11 | 0.0510 | 0.0150 | 6 |
| ep12 | 0.0491 | 0.0165 | 7 |
| ep13 | 0.0490 | 0.0153 | 3 |
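As a consistency check, the summary figures can be recomputed from the per-episode rows. The snippet below uses only numbers copied from the two tables above and assumes the per-episode columns report mean plan error, which matches the reported averages.

```python
# Per-episode values copied from the table above.
frozen = [0.0452, 0.0470, 0.0510, 0.0491, 0.0490]
abm = [0.0208, 0.0168, 0.0150, 0.0165, 0.0153]
adaptations = [12, 1, 6, 7, 3]

mean_frozen = sum(frozen) / len(frozen)
mean_abm = sum(abm) / len(abm)

# Ratios reported in the summary table.
plan_error_ratio = mean_frozen / mean_abm
surprise_ratio = 0.0340 / 0.0090
total_adaptations = sum(adaptations)

print(f"mean plan error: frozen={mean_frozen:.4f}, A-B-M={mean_abm:.4f}")
print(f"plan error ratio: {plan_error_ratio:.1f}x, surprise ratio: {surprise_ratio:.1f}x")
print(f"total adaptations: {total_adaptations}")
```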
Pixel Observation
       │
       ▼
VideoSAUREncoder         → Perceive: slots = {s₁, s₂, s₃, s₄}
       │
       ▼
MechanismCodebook        → Retrieve: m_ij = closest mechanism for each slot pair
       │
       ▼
MechSlotPredictor (JEPA) → Predict: ẑ_{t+1} (action-conditioned via AdaLN)
       │
       ├──▶ CEMSolver (SWM) → System B: optimise action sequence toward goal
       │
       └──▶ ABMPolicy       → System M: if surprise(ẑ, z) > τ → online adaptation
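The Retrieve step in the pipeline above can be sketched as a nearest-neighbour lookup into a codebook. This is only an illustration of VQ-style retrieval, not the repository's `MechanismCodebook`: the pair feature (a concatenation of two slots), the Euclidean distance, and the array sizes are all assumptions.

```python
import numpy as np

def retrieve_mechanisms(slots, codebook):
    """For each ordered slot pair (i, j), return the index of the nearest code."""
    n = slots.shape[0]
    indices = {}
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            # Hypothetical pair feature: concatenate the two slot vectors.
            pair = np.concatenate([slots[i], slots[j]])
            dists = np.linalg.norm(codebook - pair, axis=1)
            indices[(i, j)] = int(np.argmin(dists))
    return indices

rng = np.random.default_rng(0)
slots = rng.standard_normal((4, 8))      # 4 slots, 8-dim each (illustrative sizes)
codebook = rng.standard_normal((3, 16))  # 3 mechanism codes over pair features
m = retrieve_mechanisms(slots, codebook)
```

Here `m[(i, j)]` plays the role of m_ij: the index of the code selected for the slot pair (i, j).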
| Module | File | Description |
|---|---|---|
| `MechJEPA` | `mechjepa/model.py` | Top-level model |
| `MechanismCodebook` | `mechjepa/codebook.py` | VQ-based mechanism memory |
| `MechSlotPredictor` | `mechjepa/dynamics.py` | Transformer predictor with AdaLN action conditioning |
| `ActionAdaLN` | `mechjepa/dynamics.py` | Per-layer action modulation |
| `MechJEPACostModel` | `mechjepa/cost_model.py` | SWM-compatible cost model (`get_cost`) |
| `ABMPolicy` | `mechjepa/abm_policy.py` | System M policy (extends SWM `WorldModelPolicy`) |
| `SystemM` | `mechjepa/system_m.py` | Surprise monitor & adaptation trigger |
| `VideoSAUREncoder` | `mechjepa/encoder.py` | Pixel → slot encoder (VideoSAUR/C-JEPA) |
git clone https://github.com/GerardCB/mech-jepa.git
cd mech-jepa
pip install -e .
pip install stable-worldmodel loguru einops

# Best 50-epoch Push-T checkpoint (35MB)
# Place in checkpoints/mechjepa_pusht_act_best.ckpt

python scripts/plan_pusht.py \
  --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
  --data data/pusht_slots_actions.pkl \
  --ep 0 --horizon 10

# In-distribution: 3 episodes with real env rendering
python scripts/eval_live_pusht.py \
  --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
  --encoder data/pusht_videosaur_model.ckpt \
  --episodes 3
# OOD: 40% bigger T-block (real physics change)
python scripts/eval_live_pusht.py \
  --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
  --encoder data/pusht_videosaur_model.ckpt \
  --episodes 3 --ood_block_scale 42

python scripts/abm_pusht.py \
  --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
  --data data/pusht_slots_actions.pkl \
  --shift 1.4 --threshold 0.015 --episodes 5

# Setup environment
bash scripts/runpod_pusht_setup.sh
# Train (batch_size=4096, mixed precision)
bash scripts/runpod_pusht_train.sh

| Figure | Description |
|---|---|
| `results/surprise_comparison.png` | Per-step prediction surprise: Frozen vs A-B-M |
| `results/plan_err_comparison.png` | Per-step latent planning error |
| `results/summary_bar.png` | Summary grouped bar chart |
| `results/abm_demo.mp4` | Side-by-side slot trajectory video |
mech-jepa/
├── mechjepa/
│   ├── model.py              # MechJEPA (top-level)
│   ├── dynamics.py           # MechSlotPredictor + ActionAdaLN
│   ├── codebook.py           # MechanismCodebook (VQ)
│   ├── cost_model.py         # MechJEPACostModel (SWM get_cost interface)
│   ├── abm_policy.py         # ABMPolicy (System M + SWM WorldModelPolicy)
│   ├── system_m.py           # SystemM (surprise monitor)
│   ├── encoder.py            # VideoSAUREncoder (pixel → slots)
│   └── data/
│       └── clevrer_slots.py  # PushTSlotDataset + data loaders
├── scripts/
│   ├── train_pusht.py        # Full-scale training (H100, batch 4096)
│   ├── plan_pusht.py         # Phase 2 CEM planning benchmark
│   ├── eval_live_pusht.py    # Phase 3 closed-loop eval (real env renders)
│   ├── abm_pusht.py          # Phase 3 latent-only A-B-M benchmark
│   ├── visualize_abm.py      # Figure + video generation
│   ├── runpod_pusht_setup.sh # RunPod environment setup
│   └── runpod_pusht_train.sh # RunPod training launcher
├── configs/
│   └── pusht.yaml            # Training hyperparameters
└── tests/
    └── test_model.py         # Unit tests
| | LeWorldModel | MechJEPA |
|---|---|---|
| Architecture | JEPA + AdaLN | JEPA + AdaLN + VQ Codebook |
| Bottleneck | None | Mechanism memory |
| OOD Adaptation | ✗ | System M (online) |
| Planning | CEM | CEM |
| OOD Recovery | None | 2.9× lower plan error (vs. frozen model) |
System M is not a neural network; it is a conditional branch on the per-slot prediction error signal, following LeCun's proposal in "A Path Towards Autonomous Machine Intelligence".
When the world model's prediction error exceeds a threshold (surprise > τ), System M takes a few gradient steps on only the mechanism codebook and predictor, then returns to planning. Keeping the update this narrow limits the risk of catastrophic forgetting while still allowing rapid local adaptation.
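The trigger-then-adapt logic described above can be sketched in a few lines. This is a toy illustration under stated assumptions, not the repository's `SystemM`: a linear map `W` stands in for the adaptable parameters (codebook and predictor), surprise is taken to be mean squared error, and the learning rate and step count are made up. The threshold 0.015 and the shift 1.4 match the values passed to `scripts/abm_pusht.py` above.

```python
import numpy as np

def surprise(z_pred, z_obs):
    """Mean squared prediction error across all slots (assumed definition)."""
    return float(np.mean((z_pred - z_obs) ** 2))

def system_m_step(W, z_t, z_next, tau=0.015, lr=0.1, n_steps=5):
    """If surprise exceeds tau, take a few gradient steps on W only."""
    s = surprise(z_t @ W, z_next)
    adapted = s > tau
    if adapted:
        for _ in range(n_steps):
            err = z_t @ W - z_next
            grad = 2.0 * z_t.T @ err / err.size  # gradient of the MSE w.r.t. W
            W = W - lr * grad
    return W, s, adapted

rng = np.random.default_rng(0)
z_t = rng.standard_normal((4, 8))  # 4 slots, 8-dim each (illustrative sizes)
W = np.eye(8)                      # toy predictor that is exact in-distribution
z_next_ood = 1.4 * z_t             # shifted observation, alpha = 1.4

W_new, s_before, adapted = system_m_step(W, z_t, z_next_ood)
s_after = surprise(z_t @ W_new, z_next_ood)
_, s_in_dist, adapted_in_dist = system_m_step(W, z_t, z_t)  # no shift, no update
```

In this toy setting the shifted observation trips the threshold and the few gradient steps reduce the surprise, while the unshifted case leaves `W` untouched.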
Actions are embedded and injected into every transformer layer via ActionAdaLN:
LN(x) → scale(a) * LN(x) + shift(a)
Initialized to identity (scale=1, shift=0) to ensure backward compatibility with action-free pretraining.
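A minimal NumPy sketch of this modulation follows. It is not the repository's `ActionAdaLN`: the class name `ActionAdaLNSketch`, the linear action embedding, and the `1 + a·W` parameterisation of the scale are assumptions, chosen so that zero-initialised weights reproduce the identity initialisation (scale=1, shift=0) described above.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Plain LayerNorm over the last axis, without learned affine parameters."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

class ActionAdaLNSketch:
    """Hypothetical stand-in: scale and shift are linear functions of the action."""

    def __init__(self, action_dim, hidden_dim):
        # Zero weights give scale = 1 and shift = 0, i.e. plain LayerNorm at init.
        self.W_scale = np.zeros((action_dim, hidden_dim))
        self.W_shift = np.zeros((action_dim, hidden_dim))

    def __call__(self, x, a):
        scale = 1.0 + a @ self.W_scale
        shift = a @ self.W_shift
        return scale * layer_norm(x) + shift

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))  # 4 slot tokens, hidden size 16 (illustrative)
a = rng.standard_normal(2)        # 2-D action (illustrative)
layer = ActionAdaLNSketch(action_dim=2, hidden_dim=16)
out = layer(x, a)
```

At initialisation `out` equals `layer_norm(x)` for any action, so a predictor pretrained without actions would behave the same until the modulation weights are trained.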
- I-JEPA: Image-based JEPA
- LeWorldModel: Action-conditioned JEPA for embodied agents
- C-JEPA: Causal JEPA with VideoSAUR
- Stable WorldModel: Push-T evaluation framework

