DiWA: Diffusion Policy Adaptation with World Models
Akshay L Chandra1*, Iman Nematollahi1*, Chenguang Huang2, Tim Welschehold1, Wolfram Burgard2, Abhinav Valada1
1University of Freiburg, 2University of Technology Nuremberg
DiWA is an algorithmic framework for fine-tuning diffusion-based policies entirely inside frozen world models (learned from large play data).
- To begin, clone this repository locally:

  ```bash
  git clone --recurse-submodules https://github.com/acl21/diwa.git
  cd diwa
  ```
- Set environment variables for the dataset and logging directories (defaults are `dataset/` and `logs/`), and set your WandB entity (username or team name):

  ```bash
  source scripts/set_path.sh
  ```
  ⚠️ If you've already cloned the repo without `--recurse-submodules`, run:

  ```bash
  git submodule update --init --recursive
  ```
- Submodule: `lumos`

  This repository includes https://github.com/nematoli/lumos/ as a submodule for all things related to world model training and featurizing, tracking its `diwa` branch. If you want to inspect or update the submodule manually:

  ```bash
  cd DIWA_ROOT_DIR/lumos
  git checkout diwa
  git pull origin diwa
  ```
- Submodule: `calvin_env`

  This repository includes https://github.com/mees/calvin_env/ as a submodule for simulation experiments, tracking its `main` branch.
- (Optional) Submodule: `LIBERO`

  One can optionally add https://github.com/Lifelong-Robot-Learning/LIBERO as a submodule for simulation experiments. Simply uncomment the relevant lines in `.gitmodules` and the relevant install commands in `install.sh`.
- Create and activate the conda environment, then install the dependencies:

  ```bash
  cd DIWA_ROOT_DIR
  conda create -n diwa python=3.10
  conda activate diwa
  sh install.sh
  ```
To download and preprocess datasets for DiWA, please follow A.0 and A.1 here.
Note: You may skip world model training if you would like to use the default checkpoints (available for download here).

```bash
python scripts/train_wm.py trainer.devices=[<GPU-ID>]
python scripts/featurizer.py device=<GPU-ID>
```
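In case it helps to picture what the featurizer produces, here is a minimal, hypothetical sketch: each observation is passed through the frozen world-model encoder and the resulting latent features are cached for policy pre-training and fine-tuning. The `world_model.encoder` attribute and the tensor layout below are assumptions, not the actual `lumos` interface.

```python
import torch

@torch.no_grad()
def featurize_episode(world_model, frames):
    """Hypothetical sketch: encode one episode of camera frames into
    world-model latent features. `world_model.encoder` and the (T, C, H, W)
    frame layout are assumptions, not the actual lumos API."""
    world_model.eval()
    obs = frames.float() / 255.0          # normalize uint8 pixels to [0, 1]
    latents = world_model.encoder(obs)    # (T, D) latent feature vectors
    return latents.cpu()
```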
Optionally, one can visually assess the quality of the learned world model with our test script:

```bash
python scripts/tests/visionwm.py
```
Note: Before pre-training, please extract the featurized expert data with A.2 here, or download it here.
All configs relevant for pre-training can be found under `config/<env>/pretrain/<skill-name>`. To pretrain CALVIN's `close_drawer` skill, run:

```bash
python scripts/run.py --config-name=pre_diffusion_mlp_feat_vision --config-dir=config/calvin/pretrain/close_drawer
```
Before training the reward classifier, please generate the class-balanced classification data with A.3 here, or download pretrained reward classifiers here.

```bash
python scripts/rewcls/train_contrastive.py
```
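For intuition, below is a minimal, hypothetical sketch of how a success classifier over world-model features can serve as a reward signal during imagined rollouts, regardless of how it is trained (the script above uses a contrastive objective). The class, feature dimension, and 0.5 threshold are assumptions, not the repository's implementation.

```python
import torch
import torch.nn as nn

class SuccessClassifier(nn.Module):
    """Hypothetical binary success head over world-model latent features."""

    def __init__(self, feat_dim=512, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, latents):
        return self.net(latents).squeeze(-1)  # success logits


@torch.no_grad()
def reward_from_latents(classifier, latents):
    """Turn imagined latent states into sparse rewards: 1 if the classifier
    predicts task success, 0 otherwise (the 0.5 threshold is an assumption)."""
    probs = torch.sigmoid(classifier(latents))
    return (probs > 0.5).float()
```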
All configs relevant for fine-tuning can be found under `config/<env>/finetune/<skill-name>`. Set `base_policy_path` to the relevant pretrained policy checkpoint. To fine-tune CALVIN's `close_drawer` skill, run:

```bash
python scripts/run.py --config-name=ft_mb_ppo_diffusion_mlp_feat_vision --config-dir=cfg/calvin/finetune/close_drawer device="cuda:0" train.bc_loss_coeff=0.025 train.use_bc_loss=True seed=42
```
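As a conceptual aid, here is a heavily simplified, hypothetical sketch of what fine-tuning entirely inside the frozen world model looks like: the diffusion policy acts in imagination, the frozen dynamics model predicts latent futures, a reward model scores them, and a PPO-style update (optionally regularized toward the pretrained policy via `train.use_bc_loss` / `train.bc_loss_coeff`) adjusts the policy. All names below are illustrative, not the repository's API.

```python
import torch

def finetune_in_imagination(policy, world_model, reward_fn, ppo_update,
                            init_latents, horizon=32, iterations=1000):
    """Illustrative dream-rollout loop; every name here is hypothetical."""
    world_model.requires_grad_(False)          # the world model stays frozen
    for _ in range(iterations):
        latents, actions, rewards = [], [], []
        z = init_latents                       # latent states drawn from play data
        for _ in range(horizon):
            a = policy.sample(z)               # denoise an action with the diffusion policy
            z = world_model.step(z, a)         # imagine the next latent state
            latents.append(z)
            actions.append(a)
            rewards.append(reward_fn(z))       # e.g., a success classifier as above
        # PPO-style update on the imagined trajectories; optionally add a
        # behavior-cloning regularizer toward the pretrained policy
        # (cf. train.use_bc_loss / train.bc_loss_coeff).
        ppo_update(policy, latents, actions, rewards)
    return policy
```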
- To solve the `TypeError` you may face at line 72 in `calvin_env/calvin_env/envs/play_table_env.py`, replace line 20 with `from calvin_env import calvin_env`.
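That is, line 20 of `play_table_env.py` should read:

```python
# calvin_env/calvin_env/envs/play_table_env.py, line 20
from calvin_env import calvin_env
```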
If you find DiWA useful in your work, please leave a ⭐ and consider citing our work with:
```bibtex
@article{chandra2025diwa,
  title={DiWA: Diffusion Policy Adaptation with World Models},
  author={Chandra, Akshay L and Nematollahi, Iman and Huang, Chenguang and Welschehold, Tim and Burgard, Wolfram and Valada, Abhinav},
  journal={Conference on Robot Learning (CoRL)},
  year={2025},
}
```
This repository is released under the GPL-3.0 license. See LICENSE.
- DPPO, Ren et al.: Code base on top of which DiWA was built. Specifically, `sequence.py` in `diwa/dataset`, the DDPM, DDIM, Gaussian, MLP/U-Net, and ViT implementations in `diwa/model/`, and the PPO implementation in `diwa/agent/` are all borrowed.
- LUMOS, Nematollahi et al.: World model training.
- CALVIN, Mees et al.: Simulation experiments.
- LIBERO, Liu et al.: Simulation experiments.