A minimal, from-scratch world model in PyTorch. The "nanoGPT" of world models.
10 files (8 core modules plus the dashboard and checkpoint-sync utilities). ~2,150 lines. One pip install.
Built on the DINO-WM architecture: frozen DINOv2 encoder + ViT predictor + CEM planner.
Image → [Frozen DINOv2] → Patch Tokens (256 × 384)
↓
Actions + Proprio → [Small MLPs] → Action/Proprio Tokens
↓
[ViT Predictor] → Predicted Next-Frame Tokens
↓
MSE Loss vs Actual Next-Frame Tokens
What's trainable: Only the ViT predictor (~10M params) and action/proprio encoders (~100 params). The DINOv2 encoder (22M params) stays frozen.
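The training signal in the diagram is a plain regression on tokens. The sketch below, written in dependency-free Python so it runs without PyTorch, shows where the 256 comes from (224 × 224 frames cut into DINOv2's 14 × 14 patches give a 16 × 16 grid) and computes the MSE between predicted and actual next-frame tokens. `num_patch_tokens` and `next_frame_mse` are hypothetical helpers, not functions from `world_model.py`. In PyTorch the target tokens would typically be computed under `torch.no_grad()` so that only the predictor and the action/proprio encoders receive gradients; that is an assumption about how the repo does it.

```python
from typing import Sequence

# Assumed to match the numbers above: 224x224 frames, DINOv2 ViT-S/14 (384-dim tokens).
IMAGE_SIZE = 224
PATCH_SIZE = 14
EMBED_DIM = 384


def num_patch_tokens(image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> int:
    """Tokens produced by splitting a square image into square, non-overlapping patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size  # 224 // 14 = 16 patches per side
    return side * side  # 16 * 16 = 256 tokens


def next_frame_mse(pred: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> float:
    """Mean squared error over a (num_tokens, embed_dim) grid of token features.

    `target` plays the role of the frozen encoder's tokens for the next frame;
    in training only `pred` (the predictor output) would carry gradients.
    """
    if len(pred) != len(target):
        raise ValueError("token count mismatch")
    total, count = 0.0, 0
    for p_tok, t_tok in zip(pred, target):
        if len(p_tok) != len(t_tok):
            raise ValueError("embedding dim mismatch")
        for p, t in zip(p_tok, t_tok):
            total += (p - t) ** 2
            count += 1
    return total / count
```

In the real model both arguments would be `(256, 384)` tensors per frame, and the same average is what `torch.nn.functional.mse_loss` computes with its default `reduction="mean"`.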
| File | Lines | What it does |
|---|---|---|
| encoder.py | 74 | Frozen DINOv2 wrapper |
| predictor.py | 194 | Custom ViT predictor (causal attention) |
| world_model.py | 280 | Full system: encoder + predictor + loss |
| data.py | 191 | Trajectory dataset with sliding windows |
| train.py | 280 | Training loop with mixed precision |
| eval.py | 148 | Prediction quality metrics |
| plan.py | 201 | CEM planner for goal-reaching |
| cli.py | 215 | Command-line interface |
| share.py | 470 | Live training dashboard (shareable URL) |
| sync_checkpoints.py | 105 | Auto-download checkpoints from remote |
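`plan.py` is described only as a CEM planner for goal-reaching. As a rough sketch of the cross-entropy method itself: sample action sequences from a diagonal Gaussian, score each with a cost, refit the Gaussian to the lowest-cost elites, and repeat. Everything here is an assumption: the function and parameter names are hypothetical, and `toy_rollout` (a 2-D point that moves by each action) stands in for the world model. With the real model, the cost would presumably compare predicted latents against the goal image's DINOv2 tokens.

```python
import random
import statistics
from typing import Callable, List, Sequence

Plan = List[List[float]]  # shape (horizon, action_dim)


def cem_plan(
    cost_fn: Callable[[Plan], float],
    horizon: int,
    action_dim: int,
    num_samples: int = 200,
    num_elites: int = 20,
    iters: int = 30,
    init_std: float = 2.0,
    seed: int = 0,
) -> Plan:
    """Cross-entropy method over open-loop action sequences with a diagonal Gaussian."""
    rng = random.Random(seed)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[init_std] * action_dim for _ in range(horizon)]
    for _ in range(iters):
        # Sample candidate plans around the current mean.
        samples = [
            [[rng.gauss(mean[t][d], std[t][d]) for d in range(action_dim)] for t in range(horizon)]
            for _ in range(num_samples)
        ]
        # Keep the lowest-cost candidates.
        elites = sorted(samples, key=cost_fn)[:num_elites]
        # Refit mean and std per (timestep, action dim); a small floor keeps std > 0.
        for t in range(horizon):
            for d in range(action_dim):
                vals = [e[t][d] for e in elites]
                mean[t][d] = statistics.fmean(vals)
                std[t][d] = statistics.pstdev(vals) + 1e-6
    return mean


def toy_rollout(start: Sequence[float], plan: Plan) -> List[List[float]]:
    """Stand-in dynamics: the state moves by each action. Returns states 1..horizon."""
    state = list(start)
    states = []
    for action in plan:
        state = [s + a for s, a in zip(state, action)]
        states.append(state)
    return states


START, GOAL = [0.0, 0.0], [1.0, -1.0]


def goal_cost(plan: Plan) -> float:
    # Penalize distance to the goal at every step, so the best plan reaches
    # the goal on the first step and then stays put.
    return sum(
        sum((s - g) ** 2 for s, g in zip(state, GOAL)) for state in toy_rollout(START, plan)
    )


best_plan = cem_plan(goal_cost, horizon=2, action_dim=2)  # close to [[1, -1], [0, 0]]
```

Refitting only to the elites shrinks the sampling distribution around good plans on each iteration. In a tensor implementation, all `num_samples` candidates would be rolled out through the predictor as one batch rather than one `cost_fn` call at a time.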
# Install dependencies
pip install torch torchvision
pip install atlas-logger # optional, for experiment tracking
# Train
python train.py --data_path /path/to/point_maze --epochs 100
# Train with Temporal Straightening (arXiv:2603.12231)
python train.py --data_path /path/to/point_maze --curv_weight 0.5
# Evaluate
python cli.py eval --data_path /path/to/point_maze --checkpoint checkpoints/best.pth
# Plan
python cli.py plan --checkpoint checkpoints/best.pth --start_img start.png --goal_img goal.png

Run parallel experiments on separate GPUs (no DDP needed):
CUDA_VISIBLE_DEVICES=0 python train.py --data_path ./data --curv_weight 0.0 --save_dir checkpoints/cw0.0 &
CUDA_VISIBLE_DEVICES=1 python train.py --data_path ./data --curv_weight 0.1 --save_dir checkpoints/cw0.1 &
CUDA_VISIBLE_DEVICES=2 python train.py --data_path ./data --curv_weight 0.5 --save_dir checkpoints/cw0.5 &
CUDA_VISIBLE_DEVICES=3 python train.py --data_path ./data --curv_weight 1.0 --save_dir checkpoints/cw1.0 &

Share training progress with anyone via a URL:
# Start dashboard (serves a self-contained HTML page with live charts)
python share.py --run runs/<run_name> --port 8080
# Auto-sync checkpoints to your local machine
python sync_checkpoints.py --url http://<remote-ip>:8080 --dest ./checkpoints

- Temporal Straightening (arXiv:2603.12231, LeCun co-author): `--curv_weight` adds curvature regularization to latent trajectories, making gradient-based planning competitive with CEM.
- Atlas Chronicle integration: Local-first experiment tracking with per-iteration metrics, system monitoring, and hyperparameter logging.
- Live shareable dashboard: Zero-dependency HTTP server for remote training monitoring.
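The README does not spell out the curvature term, so the following is an assumption: a common way to measure how curved a latent trajectory is takes the average of 1 - cos(theta) between consecutive displacement vectors z[t+1] - z[t]. The plain-Python sketch computes that quantity and adds it to the prediction loss, scaled by a weight playing the role of `--curv_weight`. The exact formula in arXiv:2603.12231 and in `train.py` (which latents are used, how terms are averaged) may differ; `curvature_loss` and `total_loss` are hypothetical names.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def curvature_loss(latents: Sequence[Vector], eps: float = 1e-8) -> float:
    """Mean (1 - cos) of the angle between consecutive latent displacements.

    latents: a (T, D) trajectory with T >= 3. Gives about 0 for a straight path
    traversed in one direction, 1 for right-angle turns, about 2 for reversals.
    """
    if len(latents) < 3:
        raise ValueError("need at least 3 latent states to measure curvature")
    # Displacement ("velocity") vectors v_t = z_{t+1} - z_t.
    vels: List[List[float]] = [
        [b - a for a, b in zip(latents[t], latents[t + 1])] for t in range(len(latents) - 1)
    ]
    terms = []
    for v0, v1 in zip(vels, vels[1:]):
        # eps guards against zero-length steps.
        denom = math.sqrt(_dot(v0, v0)) * math.sqrt(_dot(v1, v1)) + eps
        terms.append(1.0 - _dot(v0, v1) / denom)
    return sum(terms) / len(terms)


def total_loss(pred_mse: float, latents: Sequence[Vector], curv_weight: float) -> float:
    """Prediction loss plus weighted curvature penalty (curv_weight = 0 disables it)."""
    return pred_mse + curv_weight * curvature_loss(latents)
```

With `curv_weight = 0.0` this reduces to the plain prediction loss, which matches the `cw0.0` baseline in the sweep commands above.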
Uses the DINO-WM trajectory format:
data/point_maze/
├── actions.pth (num_trajs, max_T, action_dim)
├── states.pth (num_trajs, max_T, state_dim)
├── seq_lengths.pth (num_trajs,)
└── obses/
├── episode_000.pth (T, 224, 224, 3) uint8
├── episode_001.pth
└── ...
Dataset available from OSF.
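`data.py` is described as a trajectory dataset with sliding windows. A minimal sketch of how such windows could be enumerated from `seq_lengths.pth`, assuming each trajectory is padded to `max_T` and only its first `seq_lengths[i]` steps are valid: every start index whose window (with an optional frame skip) fits inside the valid range becomes one training sample. `num_frames`, `frameskip`, and both function names are hypothetical, not the repo's API; plain lists stand in for the loaded tensors.

```python
from typing import List, Sequence, Tuple


def window_frames(start: int, num_frames: int, frameskip: int = 1) -> List[int]:
    """Frame indices covered by one window: start, start + frameskip, ..."""
    return [start + k * frameskip for k in range(num_frames)]


def sliding_window_index(
    seq_lengths: Sequence[int], num_frames: int, frameskip: int = 1
) -> List[Tuple[int, int]]:
    """All (trajectory_index, start_t) pairs whose window stays inside the valid length.

    Padded steps beyond seq_lengths[i] are never included, so whatever padding
    value actions.pth / states.pth use does not leak into training samples.
    """
    if num_frames < 1 or frameskip < 1:
        raise ValueError("num_frames and frameskip must be >= 1")
    span = (num_frames - 1) * frameskip + 1  # frames from first to last, inclusive
    index = []
    for traj, length in enumerate(seq_lengths):
        # range() is empty when the trajectory is shorter than one window.
        for start in range(length - span + 1):
            index.append((traj, start))
    return index
```

A `torch.utils.data.Dataset` built on this would map `__getitem__(i)` to `index[i]`, then slice the episode's observations, actions, and states at `window_frames(start, num_frames, frameskip)`.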
- DINO-WM: Zhou et al., 2024. The architecture this is based on.
- DINOv2: Oquab et al., 2024. The frozen visual encoder.
- Temporal Straightening: arXiv:2603.12231. Curvature regularization for world models.
License: MIT