# MechJEPA 🧩

A world model with persistent mechanism memory for physical reasoning and out-of-distribution adaptation.

MechJEPA extends I-JEPA with:

1. **Mechanism Codebook** – a learned library of stable physical interaction patterns (push, collide, support)
2. **Action-Conditioned Dynamics** – AdaLN-based action conditioning via LeWorldModel
3. **System M** – surprise-triggered online adaptation for OOD robustness

This is the first demonstration of an A-B-M (Anticipate–Behave–Modulate) agent loop on the Push-T environment.

## Demo

Fully closed-loop evaluation using real Push-T environment renders (like LeWorldModel).

At every step: `env.render()` → VideoSAUR DINOv2 ViT → slot corrector → 4 slots → CEM planner → action → `env.step()`

Three conditions: Expert (left), Frozen MechJEPA (centre), A-B-M with System M (right).

Expert | Frozen | A-B-M
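The per-step loop above can be sketched as a receding-horizon controller. This is a minimal stand-in, not the repo's actual `eval_live_pusht.py`: `env`, `encoder`, and `planner` are hypothetical objects playing the roles of the Push-T environment, the VideoSAUR slot encoder, and the CEM planner.

```python
def closed_loop_episode(env, encoder, planner, horizon=10, max_steps=200):
    """One closed-loop episode: render -> encode to slots -> plan -> act.

    `env`, `encoder`, and `planner` are illustrative stand-ins; the
    repo's real interfaces may differ in detail.
    """
    env.reset()
    info = {}
    for _ in range(max_steps):
        frame = env.render()           # pixel observation
        slots = encoder(frame)         # e.g. 4 slot latents
        plan = planner.plan(slots, horizon=horizon)
        action = plan[0]               # receding horizon: execute only the first action
        obs, reward, done, info = env.step(action)
        if done:
            break
    return info
```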


πŸ† Results

### Phase 1: World Model Training (Push-T)

Trained on 18,685 expert episodes (≈1.98M samples) from the Push-T dataset with action conditioning:

| Metric | Value |
|---|---|
| Val Loss | 0.0018 |
| Batch Size | 4096 (H100) |
| GPU Utilization | 94–96% |
| Training Epochs | 50 |

### Phase 2: CEM Latent Planning (System B)

CEM planner optimises action sequences directly in the latent slot space:

| Planner | Mean Latent Error | Median | Max |
|---|---|---|---|
| CEM Planner | 0.0034 | 0.0032 | 0.0058 |
| Random Actions | 0.0434 | 0.0472 | 0.0869 |
| Improvement | 12.8× | | |
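The CEM loop behind these numbers can be sketched as follows. `predict(z, a)` is any one-step latent rollout callable (a toy stand-in for the repo's predictor), and the population, elite count, and iteration settings are illustrative, not the values used for the benchmark.

```python
import numpy as np

def cem_plan(predict, z0, z_goal, horizon=10, act_dim=2,
             pop=64, elites=8, iters=5, seed=0):
    """Cross-entropy method over action sequences in slot-latent space.

    Each candidate action sequence is rolled out with `predict`; cost is
    the final latent distance to `z_goal`. The sampling distribution is
    refit to the elite (lowest-cost) candidates each iteration.
    """
    rng = np.random.default_rng(seed)
    mu = np.zeros((horizon, act_dim))
    sigma = np.ones((horizon, act_dim))
    for _ in range(iters):
        # sample candidate action sequences around the current mean
        acts = rng.normal(mu, sigma, size=(pop, horizon, act_dim))
        costs = np.empty(pop)
        for i in range(pop):
            z = z0
            for a in acts[i]:
                z = predict(z, a)          # latent rollout, no pixels involved
            costs[i] = np.linalg.norm(z - z_goal)
        elite = acts[np.argsort(costs)[:elites]]
        mu = elite.mean(axis=0)
        sigma = elite.std(axis=0) + 1e-6   # keep a floor on exploration
    return mu  # refined action sequence, shape (horizon, act_dim)
```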

### Phase 3: System M – OOD Adaptation (A-B-M Loop)

Distribution shift: slot observations scaled by α = 1.4 (simulating a 40% heavier/larger block).

| Metric | Frozen Model | A-B-M Agent |
|---|---|---|
| Mean Surprise | 0.0340 | 0.0090 (3.8× ↓) |
| Mean Plan Error | 0.0483 | 0.0169 (2.9× ↓) |
| Total Adaptations | 0 | 29 |

### Summary Results

| Episode | Frozen | A-B-M | Adaptations |
|---|---|---|---|
| ep0 | 0.0452 | 0.0208 | 12 |
| ep10 | 0.0470 | 0.0168 | 1 |
| ep11 | 0.0510 | 0.0150 | 6 |
| ep12 | 0.0491 | 0.0165 | 7 |
| ep13 | 0.0490 | 0.0153 | 3 |

πŸ—οΈ Architecture

```
Pixel Observation
      │
      ▼
VideoSAUREncoder          ← Perceive: slots = {s₁, s₂, s₃, s₄}
      │
      ▼
MechanismCodebook         ← Retrieve: m_ij = closest mechanism for each slot pair
      │
      ▼
MechSlotPredictor (JEPA)  ← Predict: ẑ_{t+1} (action-conditioned via AdaLN)
      │
      ├──▶ CEMSolver (SWM)   ← System B: optimise action sequence toward goal
      │
      └──▶ ABMPolicy         ← System M: if surprise(ẑ, z) > τ → online adaptation
```

### Key Components

| Module | File | Description |
|---|---|---|
| `MechJEPA` | `mechjepa/model.py` | Top-level model |
| `MechanismCodebook` | `mechjepa/codebook.py` | VQ-based mechanism memory |
| `MechSlotPredictor` | `mechjepa/dynamics.py` | Transformer predictor with AdaLN action conditioning |
| `ActionAdaLN` | `mechjepa/dynamics.py` | Per-layer action modulation |
| `MechJEPACostModel` | `mechjepa/cost_model.py` | SWM-compatible cost model (`get_cost`) |
| `ABMPolicy` | `mechjepa/abm_policy.py` | System M policy (extends SWM `WorldModelPolicy`) |
| `SystemM` | `mechjepa/system_m.py` | Surprise monitor & adaptation trigger |
| `VideoSAUREncoder` | `mechjepa/encoder.py` | Pixel → slot encoder (VideoSAUR/C-JEPA) |
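The core retrieval step of a VQ-based memory like `MechanismCodebook` is a nearest-neighbour lookup against a learned code table. The sketch below shows only that lookup with random codes; the real module additionally learns the codes (e.g. with commitment losses or EMA updates), which is omitted here.

```python
import numpy as np

class VQCodebook:
    """Minimal vector-quantised memory: map a query vector to its nearest code.

    Illustrative stand-in for MechanismCodebook; codes here are random
    and frozen, whereas the real codebook is learned during training.
    """
    def __init__(self, n_codes=32, dim=8, seed=0):
        rng = np.random.default_rng(seed)
        self.codes = rng.normal(size=(n_codes, dim))

    def lookup(self, query):
        # squared Euclidean distance from the query to every code vector
        dists = ((self.codes - query) ** 2).sum(axis=1)
        idx = int(np.argmin(dists))
        return idx, self.codes[idx]   # discrete index + quantised vector
```

Snapping slot-pair interactions onto a small discrete set of codes is what gives the model a "mechanism" bottleneck: many raw interactions map to one reusable pattern.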

## 🚀 Quick Start

### Installation

```bash
git clone https://github.com/GerardCB/mech-jepa.git
cd mech-jepa
pip install -e .
pip install stable-worldmodel loguru einops
```

### Download Checkpoint

```bash
# Best 50-epoch Push-T checkpoint (35MB)
# Place in checkpoints/mechjepa_pusht_act_best.ckpt
```

### Reproduce Phase 2 (Planning Benchmark)

```bash
python scripts/plan_pusht.py \
    --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
    --data data/pusht_slots_actions.pkl \
    --ep 0 --horizon 10
```

### Reproduce Phase 3 (Closed-Loop A-B-M Demo)

```bash
# In-distribution: 3 episodes with real env rendering
python scripts/eval_live_pusht.py \
    --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
    --encoder data/pusht_videosaur_model.ckpt \
    --episodes 3

# OOD: 40% bigger T-block (real physics change)
python scripts/eval_live_pusht.py \
    --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
    --encoder data/pusht_videosaur_model.ckpt \
    --episodes 3 --ood_block_scale 1.4
```

### Latent-Space A-B-M (fast, no encoder needed)

```bash
python scripts/abm_pusht.py \
    --ckpt checkpoints/mechjepa_pusht_act_best.ckpt \
    --data data/pusht_slots_actions.pkl \
    --shift 1.4 --threshold 0.015 --episodes 5
```

### Train from Scratch (RunPod H100)

```bash
# Setup environment
bash scripts/runpod_pusht_setup.sh

# Train (batch_size=4096, mixed precision)
bash scripts/runpod_pusht_train.sh
```

## 📊 Figures

| Figure | Description |
|---|---|
| `results/surprise_comparison.png` | Per-step prediction surprise: Frozen vs A-B-M |
| `results/plan_err_comparison.png` | Per-step latent planning error |
| `results/summary_bar.png` | Summary grouped bar chart |
| `results/abm_demo.mp4` | Side-by-side slot trajectory video |

πŸ“ Repository Layout

```
mech-jepa/
├── mechjepa/
│   ├── model.py          # MechJEPA (top-level)
│   ├── dynamics.py       # MechSlotPredictor + ActionAdaLN
│   ├── codebook.py       # MechanismCodebook (VQ)
│   ├── cost_model.py     # MechJEPACostModel (SWM get_cost interface)
│   ├── abm_policy.py     # ABMPolicy (System M + SWM WorldModelPolicy)
│   ├── system_m.py       # SystemM (surprise monitor)
│   ├── encoder.py        # VideoSAUREncoder (pixel → slots)
│   └── data/
│       └── clevrer_slots.py   # PushTSlotDataset + data loaders
├── scripts/
│   ├── train_pusht.py         # Full-scale training (H100, batch 4096)
│   ├── plan_pusht.py          # Phase 2 CEM planning benchmark
│   ├── eval_live_pusht.py     # Phase 3 closed-loop eval (real env renders)
│   ├── abm_pusht.py           # Phase 3 latent-only A-B-M benchmark
│   ├── visualize_abm.py       # Figure + video generation
│   ├── runpod_pusht_setup.sh  # RunPod environment setup
│   └── runpod_pusht_train.sh  # RunPod training launcher
├── configs/
│   └── pusht.yaml             # Training hyperparameters
└── tests/
    └── test_model.py          # Unit tests
```

## 📖 Design Notes

### Why MechJEPA over LeWorldModel?

| | LeWorldModel | MechJEPA |
|---|---|---|
| Architecture | JEPA + AdaLN | JEPA + AdaLN + VQ Codebook |
| Bottleneck | None | Mechanism memory |
| OOD Adaptation | ✗ | System M (online) |
| Planning | CEM | CEM |
| OOD Recovery | None | 2.9× better |

### System M Design

System M is not a neural network: it is a conditional branch on the per-slot prediction error signal, following LeCun's proposal in "A Path Towards Autonomous Machine Intelligence".

When the world model predicts incorrectly (surprise > τ), System M takes a few gradient steps on only the mechanism codebook and predictor, then returns to planning. This narrow update prevents catastrophic forgetting while enabling rapid local adaptation.
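That conditional branch fits in a few lines. In this sketch, `adapt_fn` is a hypothetical callable standing in for "a few gradient steps on the codebook and predictor", and the default `tau` mirrors the `--threshold 0.015` flag used with `abm_pusht.py`:

```python
def system_m_step(z_pred, z_obs, adapt_fn, tau=0.015):
    """Surprise-gated adaptation: a conditional branch, not a network.

    `z_pred`/`z_obs` are flat lists of predicted and observed latent
    values; `adapt_fn` is a stand-in for the narrow gradient update.
    Returns (surprise, adapted).
    """
    # mean squared prediction error over the latents = surprise signal
    surprise = sum((p - o) ** 2 for p, o in zip(z_pred, z_obs)) / len(z_pred)
    if surprise > tau:
        adapt_fn()   # update only codebook + predictor; all else stays frozen
        return surprise, True
    return surprise, False
```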

Action Conditioning

Actions are embedded and injected into every transformer layer via ActionAdaLN:

LN(x) β†’ scale(a) * LN(x) + shift(a)

Initialized to identity (scale=1, shift=0) to ensure backward compatibility with action-free pretraining.
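A minimal sketch of that modulation, with zero-initialised projection weights so the layer starts as plain LayerNorm. The direct `act_dim → dim` linear maps and the shapes are illustrative assumptions, not the repo's exact module:

```python
import numpy as np

class ActionAdaLN:
    """Adaptive LayerNorm: LN(x) -> scale(a) * LN(x) + shift(a).

    W_scale and W_shift are zero-initialised, so scale(a) = 1 and
    shift(a) = 0 at init, i.e. the layer is exactly LayerNorm and is
    backward compatible with action-free pretraining.
    """
    def __init__(self, dim, act_dim):
        self.W_scale = np.zeros((act_dim, dim))
        self.W_shift = np.zeros((act_dim, dim))

    def __call__(self, x, a, eps=1e-5):
        # plain LayerNorm over the last dimension
        mu = x.mean(-1, keepdims=True)
        var = x.var(-1, keepdims=True)
        h = (x - mu) / np.sqrt(var + eps)
        # action-dependent modulation (identity at init)
        scale = 1.0 + a @ self.W_scale
        shift = a @ self.W_shift
        return scale * h + shift
```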


## 🔗 Related Work

## About

Causal World Models with Persistent Mechanism Memory. Mech-JEPA allows zero-shot causal transfer and surprise-triggered autonomous learning.
