
DiWA: Diffusion Policy Adaptation with World Models


DiWA is an algorithmic framework for fine-tuning diffusion-based policies entirely inside a frozen world model learned from large-scale play data.

🔎 Overview

(Overview figure.)

💻 Installation

  1. Clone this repository locally:
git clone --recurse-submodules https://github.com/acl21/diwa.git
cd diwa
  2. Set environment variables for the dataset and logging directories (defaults: dataset/ and logs/) and your WandB entity (username or team name); a sketch of the variables this script sets follows this list:
source scripts/set_path.sh
  3. ⚠️ If you've already cloned the repo without --recurse-submodules, run:
git submodule update --init --recursive
  • Submodule: lumos

    This repository includes https://github.com/nematoli/lumos/ as a submodule for all things related to world model training and featurizing, tracking its diwa branch. If you want to inspect or update the submodule manually:

    cd DIWA_ROOT_DIR/lumos
    git checkout diwa
    git pull origin diwa
  • Submodule: calvin_env

    This repository includes https://github.com/mees/calvin_env/ as a submodule for simulation experiments, tracking its main branch.

  • (Optional) Submodule: LIBERO

    One can optionally add https://github.com/Lifelong-Robot-Learning/LIBERO as a submodule for simulation experiments. Simply uncomment the relevant lines in .gitmodules and the relevant install commands in install.sh.

  4. Create and activate the conda environment, then install the dependencies:
cd DIWA_ROOT_DIR
conda create -n diwa python=3.10
conda activate diwa
sh install.sh 
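
For reference, step 2's scripts/set_path.sh exports the dataset/logging paths and WandB entity mentioned above. Below is a minimal sketch of what it might set — the variable names here are illustrative assumptions, not the script's actual contents, so check the script itself for the real exports:

# Hypothetical variable names — see scripts/set_path.sh for the actual exports.
export DIWA_DATA_DIR=$(pwd)/dataset     # dataset root (default: dataset/)
export DIWA_LOG_DIR=$(pwd)/logs         # logging root (default: logs/)
export WANDB_ENTITY=<username-or-team>  # WandB entity used for logging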

🛠️ Usage

0. Dataset

To download and preprocess datasets for DiWA, please follow A.0 and A.1 here.

1. World Model

Note: You may skip world model training if you would like to use the default checkpoints (available for download here).

1.1 Training

python scripts/train_wm.py trainer.devices=[<GPU-ID>]
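
For example, to train on GPU 0 (the bracketed list is a Hydra-style override; adding more IDs should select more GPUs, assuming the trainer supports multi-device runs):

python scripts/train_wm.py trainer.devices=[0]
python scripts/train_wm.py trainer.devices=[0,1]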

1.2 Featurizer

python scripts/featurizer.py device=<GPU-ID>
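
For instance, on GPU 0 (the exact device format — a bare ID like 0 or a string like "cuda:0" — is an assumption here; adjust to whatever the featurizer config expects):

python scripts/featurizer.py device=0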

1.3 (Optional) Qualitative Test

You can visually assess the quality of the learned world model with our test script:

python scripts/tests/visionwm.py

2. Diffusion Policy Training

Note: Before pre-training, please extract the featurized expert data with A.2 here, or download it here.

All configs relevant for pre-training can be found under config/<env>/pretrain/<skill-name>. To pretrain CALVIN's close_drawer skill, run:

python scripts/run.py --config-name=pre_diffusion_mlp_feat_vision --config-dir=config/calvin/pretrain/close_drawer
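
The same pattern applies to other skills by swapping the config directory. For example, assuming a hypothetical open_drawer config directory exists under config/calvin/pretrain/:

python scripts/run.py --config-name=pre_diffusion_mlp_feat_vision --config-dir=config/calvin/pretrain/open_drawer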

3. Reward Estimation

Before training the reward classifier, please generate the class-balanced classification data with A.3 here, or download pretrained reward classifiers here.

python scripts/rewcls/train_contrastive.py

4. Fine-tuning inside World Model

All configs relevant for fine-tuning can be found under config/<env>/finetune/<skill-name>. Set base_policy_path to the relevant pretrained policy checkpoint. To fine-tune CALVIN's close_drawer skill, run:

python scripts/run.py --config-name=ft_mb_ppo_diffusion_mlp_feat_vision --config-dir=config/calvin/finetune/close_drawer device="cuda:0" train.bc_loss_coeff=0.025 train.use_bc_loss=True seed=42
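
To fine-tune over several random seeds, a small shell loop over the same command works — a usage sketch reusing only the overrides shown above:

for seed in 0 1 2; do
    python scripts/run.py --config-name=ft_mb_ppo_diffusion_mlp_feat_vision \
        --config-dir=config/calvin/finetune/close_drawer \
        device="cuda:0" train.bc_loss_coeff=0.025 train.use_bc_loss=True seed=$seed
done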

⚠️ Known Issues

  1. If you face a TypeError at line 72 of calvin_env/calvin_env/envs/play_table_env.py, replace line 20 with from calvin_env import calvin_env.

📝 Citation

If you find DiWA useful in your work, please leave a ⭐ and consider citing our work with:

@inproceedings{chandra2025diwa,
    title={DiWA: Diffusion Policy Adaptation with World Models},
    author={Chandra, Akshay L and Nematollahi, Iman and Huang, Chenguang and Welschehold, Tim and Burgard, Wolfram and Valada, Abhinav},
    booktitle={Conference on Robot Learning (CoRL)},
    year={2025},
}

🏷️ License

This repository is released under the GPL-3.0 license. See LICENSE.

✨ Acknowledgement
