
itsanderz/nanoWM


nanoWorldModel

A minimal, from-scratch world model in PyTorch. The "nanoGPT" of world models.

8 core files, ~1,600 lines, plus two optional dashboard scripts. One pip install.

Built on the DINO-WM architecture: frozen DINOv2 encoder + ViT predictor + CEM planner.

Architecture

Image → [Frozen DINOv2] → Patch Tokens (256 × 384)
                              ↓
Actions + Proprio → [Small MLPs] → Action/Proprio Tokens
                              ↓
[ViT Predictor] → Predicted Next-Frame Tokens
                              ↓
MSE Loss vs Actual Next-Frame Tokens

What's trainable: Only the ViT predictor (~10M params) and action/proprio encoders (~100 params). The DINOv2 encoder (22M params) stays frozen.
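The token shapes and the loss in the diagram can be checked with a little arithmetic. This is a minimal plain-Python sketch. It assumes the ViT-S/14 variant of DINOv2 (patch size 14, which is consistent with the 256 × 384 tokens and ~22M parameters quoted above) and 224 × 224 inputs. `num_patch_tokens` and `token_mse` are illustrative helpers, not functions from encoder.py or world_model.py.

```python
def num_patch_tokens(image_size, patch_size):
    """Number of patch tokens a ViT produces for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    per_side = image_size // patch_size
    return per_side * per_side


def token_mse(pred, target):
    """Mean squared error over every element of two (tokens x dim) grids."""
    if len(pred) != len(target):
        raise ValueError("token counts differ")
    total, count = 0.0, 0
    for p_tok, t_tok in zip(pred, target):
        if len(p_tok) != len(t_tok):
            raise ValueError("embedding dims differ")
        for p, t in zip(p_tok, t_tok):
            total += (p - t) ** 2
            count += 1
    return total / count


# 224 / 14 = 16 patches per side -> 16 * 16 = 256 tokens, each 384-dim (ViT-S)
print(num_patch_tokens(224, 14), "x", 384)  # 256 x 384

# Toy example: 2 tokens of dim 2; one element differs by 2.0 -> 4.0 / 4 elements
pred = [[0.0, 1.0], [2.0, 3.0]]
target = [[0.0, 1.0], [2.0, 5.0]]
print(token_mse(pred, target))  # 1.0
```

Because the encoder is frozen, the target tokens are fixed DINOv2 features. Gradients from this loss therefore reach only the predictor and the action/proprio encoders.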

Files

File                 Lines  What it does
encoder.py              74  Frozen DINOv2 wrapper
predictor.py           194  Custom ViT transformer (causal attention)
world_model.py         280  Full system: encoder + predictor + loss
data.py                191  Trajectory dataset with sliding windows
train.py               280  Training loop with mixed precision
eval.py                148  Prediction quality metrics
plan.py                201  CEM planner for goal-reaching
cli.py                 215  Command-line interface
share.py               470  Live training dashboard (shareable URL)
sync_checkpoints.py    105  Auto-download checkpoints from remote
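predictor.py is described as using causal attention. This is a minimal sketch of one common way to build such a mask for video tokens. It assumes a frame-level (block-causal) mask, in which every patch token of frame t may attend to all tokens of frames 0..t, which is the usual choice for predictors of this kind. The repo's exact masking may differ, and `block_causal_mask` is a hypothetical helper.

```python
def block_causal_mask(num_frames, tokens_per_frame):
    """mask[i][j] is True when query token i may attend to key token j."""
    n = num_frames * tokens_per_frame
    mask = []
    for i in range(n):
        frame_i = i // tokens_per_frame
        # A key is visible iff its frame is not later than the query's frame.
        mask.append([j // tokens_per_frame <= frame_i for j in range(n)])
    return mask


for row in block_causal_mask(num_frames=2, tokens_per_frame=2):
    print("".join("1" if allowed else "0" for allowed in row))
# 1100
# 1100
# 1111
# 1111
```

PyTorch APIs disagree on boolean mask polarity. `torch.nn.functional.scaled_dot_product_attention` treats True as "may attend", which matches this sketch. `torch.nn.MultiheadAttention` treats True as "masked out", so it would need the negated matrix.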

Quick Start

# Install dependencies
pip install torch torchvision
pip install atlas-logger  # optional, for experiment tracking

# Train
python train.py --data_path /path/to/point_maze --epochs 100

# Train with Temporal Straightening (arXiv:2603.12231)
python train.py --data_path /path/to/point_maze --curv_weight 0.5

# Evaluate
python cli.py eval --data_path /path/to/point_maze --checkpoint checkpoints/best.pth

# Plan
python cli.py plan --checkpoint checkpoints/best.pth --start_img start.png --goal_img goal.png
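plan.py uses the cross-entropy method (CEM) for goal-reaching. The following is a minimal sketch of the generic CEM loop in plain Python, not the repo's implementation. It samples action sequences from a diagonal Gaussian, rolls each one out, and scores it by the squared distance between the predicted final state and the goal. It then refits the Gaussian to the lowest-cost elites. The `toy_rollout` point-mass dynamics and every hyperparameter below are hypothetical stand-ins for the world model's latent rollout and the planner's real settings.

```python
import random


def toy_rollout(start, actions):
    """Hypothetical stand-in for the world model: a 2-D point moved by each action."""
    x, y = start
    for ax, ay in actions:
        x, y = x + ax, y + ay
    return (x, y)


def cem_plan(start, goal, rollout, horizon=5, action_dim=2,
             num_samples=200, num_elites=20, iterations=10, seed=0):
    """Return the mean action sequence found by the cross-entropy method."""
    rng = random.Random(seed)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[1.0] * action_dim for _ in range(horizon)]

    def cost(actions):
        # Squared distance between the predicted final state and the goal.
        final = rollout(start, actions)
        return sum((f - g) ** 2 for f, g in zip(final, goal))

    for _ in range(iterations):
        # 1. Sample candidate action sequences from the current Gaussian.
        candidates = [
            [[rng.gauss(mean[t][d], std[t][d]) for d in range(action_dim)]
             for t in range(horizon)]
            for _ in range(num_samples)
        ]
        # 2. Keep the lowest-cost sequences as elites.
        elites = sorted(candidates, key=cost)[:num_elites]
        # 3. Refit each (time step, action dim) Gaussian to the elites.
        for t in range(horizon):
            for d in range(action_dim):
                vals = [seq[t][d] for seq in elites]
                m = sum(vals) / len(vals)
                var = sum((v - m) ** 2 for v in vals) / len(vals)
                mean[t][d] = m
                std[t][d] = max(var ** 0.5, 1e-3)  # floor keeps some exploration
    return mean


plan = cem_plan(start=(0.0, 0.0), goal=(3.0, -2.0), rollout=toy_rollout)
print(toy_rollout((0.0, 0.0), plan))  # should land close to (3.0, -2.0)
```

In the world-model setting, `rollout` would instead run the ViT predictor forward from the encoded start image. The cost would compare the predicted final tokens with the encoded goal image.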

Multi-GPU Training

Run parallel experiments on separate GPUs (no DDP needed):

CUDA_VISIBLE_DEVICES=0 python train.py --data_path ./data --curv_weight 0.0 --save_dir checkpoints/cw0.0 &
CUDA_VISIBLE_DEVICES=1 python train.py --data_path ./data --curv_weight 0.1 --save_dir checkpoints/cw0.1 &
CUDA_VISIBLE_DEVICES=2 python train.py --data_path ./data --curv_weight 0.5 --save_dir checkpoints/cw0.5 &
CUDA_VISIBLE_DEVICES=3 python train.py --data_path ./data --curv_weight 1.0 --save_dir checkpoints/cw1.0 &
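The same four-GPU sweep can also be driven from Python. This is a minimal standard-library sketch that assumes train.py accepts exactly the flags shown above. `build_sweep` and `launch` are hypothetical helpers. Each child process gets its own `CUDA_VISIBLE_DEVICES`, so every run sees a single GPU as device 0 and needs no DDP.

```python
import os
import subprocess
import sys


def build_sweep(curv_weights, data_path="./data", script="train.py"):
    """Return one (env, argv) pair per run, pinning run i to GPU i."""
    runs = []
    for gpu, cw in enumerate(curv_weights):
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
        argv = [sys.executable, script,
                "--data_path", data_path,
                "--curv_weight", str(cw),
                "--save_dir", f"checkpoints/cw{cw}"]
        runs.append((env, argv))
    return runs


def launch(runs):
    """Start all runs in parallel, wait for them, and return their exit codes."""
    procs = [subprocess.Popen(argv, env=env) for env, argv in runs]
    return [p.wait() for p in procs]


if __name__ == "__main__":
    for env, argv in build_sweep([0.0, 0.1, 0.5, 1.0]):
        print(env["CUDA_VISIBLE_DEVICES"], " ".join(argv[1:]))
    # launch(build_sweep([0.0, 0.1, 0.5, 1.0]))  # uncomment to actually train
```

Unlike the shell version, which backgrounds each job with `&`, `launch` blocks until every run has finished.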

Live Dashboard

Share training progress with anyone via a URL:

# Start dashboard (serves a self-contained HTML page with live charts)
python share.py --run runs/<run_name> --port 8080

# Auto-sync checkpoints to your local machine
python sync_checkpoints.py --url http://<remote-ip>:8080 --dest ./checkpoints
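share.py is described as a zero-dependency HTTP server, but its actual routes and page format are not documented here. The following is therefore only a sketch of the underlying idea using the standard library's `http.server`: a background thread serves the latest metrics as JSON, and a remote page or polling script fetches them. The `/metrics` route, the `METRICS` dict, `serve_metrics`, and `fetch_metrics` are all hypothetical.

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

# Hypothetical shared state, updated by the training loop.
METRICS = {"epoch": 0, "train_loss": None}
METRICS_LOCK = threading.Lock()


class MetricsHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path != "/metrics":
            self.send_error(404)
            return
        with METRICS_LOCK:
            body = json.dumps(METRICS).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        pass  # keep per-request logging out of the training output


def serve_metrics(port=8080):
    """Start the server on a daemon thread and return it."""
    server = ThreadingHTTPServer(("0.0.0.0", port), MetricsHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def fetch_metrics(host="127.0.0.1", port=8080, timeout=5.0):
    """Client side: poll the server, as a remote viewer or sync script might."""
    url = f"http://{host}:{port}/metrics"
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Call `serve_metrics()` once before training, then update `METRICS` under `METRICS_LOCK` after each logging step. Because the server thread is a daemon, it stops when training exits. Binding to `0.0.0.0` exposes the endpoint to the network, which is what makes the URL shareable.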

Novel Features

  • Temporal Straightening (arXiv:2603.12231, LeCun co-author): --curv_weight adds a curvature penalty on latent trajectories, which the paper reports makes gradient-based planning competitive with CEM.
  • Atlas Chronicle integration: Local-first experiment tracking with per-iteration metrics, system monitoring, and hyperparameter logging.
  • Live shareable dashboard: Zero-dependency HTTP server for remote training monitoring.
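This README does not reproduce the exact loss from arXiv:2603.12231, so the following is a hedged sketch of one common way to measure trajectory curvature. It takes the mean of 1 − cos(angle) between consecutive latent velocities v_t = z_{t+1} − z_t. A perfectly straight trajectory scores 0, and a right-angle turn scores 1. `curvature` and the way it is combined with the prediction loss are assumptions, not the repo's code.

```python
import math


def curvature(traj, eps=1e-8):
    """Mean (1 - cosine) between consecutive velocity vectors of a trajectory.

    traj: list of latent vectors z_0..z_T (each a list of floats), at least 3 points.
    """
    if len(traj) < 3:
        raise ValueError("need at least 3 points to measure curvature")
    # Velocities v_t = z_{t+1} - z_t
    vels = [[b - a for a, b in zip(z0, z1)] for z0, z1 in zip(traj, traj[1:])]
    total = 0.0
    for v0, v1 in zip(vels, vels[1:]):
        dot = sum(a * b for a, b in zip(v0, v1))
        norm = math.sqrt(sum(a * a for a in v0)) * math.sqrt(sum(b * b for b in v1))
        total += 1.0 - dot / (norm + eps)
    return total / (len(vels) - 1)


straight = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]
corner = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
print(round(curvature(straight), 6), round(curvature(corner), 6))  # 0.0 1.0

# Hypothetical combination with the prediction loss:
#   loss = pred_mse + curv_weight * curvature(latent_traj)
# With --curv_weight 0.0 the extra term vanishes.
```
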

Data Format

Uses the DINO-WM trajectory format:

data/point_maze/
├── actions.pth        (num_trajs, max_T, action_dim)
├── states.pth         (num_trajs, max_T, state_dim)
├── seq_lengths.pth    (num_trajs,)
└── obses/
    ├── episode_000.pth  (T, 224, 224, 3) uint8
    ├── episode_001.pth
    └── ...
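data.py builds training samples with sliding windows over these padded arrays. This is a minimal sketch of the indexing. It assumes each sample is `num_frames` consecutive frames that must lie entirely within the episode's valid length from seq_lengths.pth, so padding beyond that length is never used. `num_frames`, `stride`, and `window_indices` are hypothetical names, and the repo may also subsample frames.

```python
def window_indices(seq_lengths, num_frames, stride=1):
    """List (episode, start) pairs for every window fully inside an episode."""
    index = []
    for ep, length in enumerate(seq_lengths):
        # Episodes shorter than num_frames yield an empty range.
        for start in range(0, length - num_frames + 1, stride):
            index.append((ep, start))
    return index


# Two episodes of valid length 5 and 3, windows of 3 frames:
print(window_indices([5, 3], num_frames=3))
# [(0, 0), (0, 1), (0, 2), (1, 0)]
```

Each `(ep, start)` pair would then select frames `start:start + num_frames` from `obses/episode_{ep:03d}.pth` and the matching rows of actions.pth and states.pth.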

Dataset available from OSF.

References

  • DINO-WM — Zhou et al., 2024. The architecture this is based on.
  • DINOv2 — Oquab et al., 2024. The frozen visual encoder.
  • Temporal Straightening — arXiv:2603.12231. Curvature regularization for world models.

License

MIT
