1 S-Lab, Nanyang Technological University   2 SenseTime Research
✉ Corresponding Author.
🔥 More coming soon!
# Create and activate an isolated Python 3.10 environment for UAE.
conda create -n uae python=3.10 -y
conda activate uae
# uv is used as a faster drop-in replacement for pip below.
pip install uv
# PyTorch 2.2.0 / torchvision 0.17.0 built against CUDA 12.1.
uv pip install torch==2.2.0 torchvision==0.17.0 torchaudio --index-url https://download.pytorch.org/whl/cu121
# Training/eval dependencies (pinned versions used for the reported results).
uv pip install timm==0.9.16 accelerate==0.23.0 torchdiffeq==0.2.5 wandb
# NOTE(review): numpy is pinned below 2.0 — presumably for compatibility
# with the pinned torch 2.2 build; confirm before relaxing.
uv pip install "numpy<2" transformers einops omegaconf
# Metrics dependency (PSNR/SSIM/FID) used by the evaluation script.
uv pip install torchmetrics

# Reconstruction evaluation on ImageNet and MS-COCO.
# Replace the PATH_TO_* placeholders with your local paths.
python eval_uae.py \
  --config unified_ae/configs/stage1_infer.yaml \
  --checkpoint PATH_TO_WEIGHTS \
  --imagenet-path PATH_TO_IMAGENET \
  --coco-path PATH_TO_COCO \
  --batch-size 16 \
  --num-workers 8 \
  --image-size 256 \
  --freq-ratio 1.0 \
  --log-file logs/uae_eval_metrics.txt
# Expected Results:
ImageNet: PSNR=29.588 dB | SSIM=0.8789 | rFID=0.193
MS-COCO: PSNR=29.484 dB | SSIM=0.8846 | rFID=0.157

There are four sub-stages to train our UAE model.
Follow the scripts to step-by-step reproduce our results. The per-stage checkpoints and training logs are provided here: link.
# sub-stage 1
# Weights & Biases credentials/config — replace the placeholders.
export WANDB_API_KEY=YOUR_KEY
export WANDB_ENTITY=YOUR_ID
export WANDB_PROJECT=PROJECT_NAME
# ImageNet train/val roots — replace the placeholders.
DATA_ROOT=PATH_TO_TRAIN_OF_IMGNET
VAL_ROOT=PATH_TO_VAL_OF_IMGNET

# Set --mixed-precision to "bf16" or "no" for your hardware.
accelerate launch train_uae.py \
  --config unified_ae/configs/stage1_train.yaml \
  --stage-key sub_stage1 \
  --data-path "$DATA_ROOT" \
  --val-path "$VAL_ROOT" \
  --results-dir results/sub_stage1 \
  --mixed-precision bf16 \
  --wandb --wandb-name uae_1
# After this stage you will get a model with FID=103.870, PSNR=18.025 dB.
# sub-stage 2
# Weights & Biases credentials/config — replace the placeholders.
export WANDB_API_KEY=YOUR_KEY
export WANDB_ENTITY=YOUR_ID
export WANDB_PROJECT=PROJECT_NAME
# ImageNet train/val roots — replace the placeholders.
DATA_ROOT=PATH_TO_TRAIN_OF_IMGNET
VAL_ROOT=PATH_TO_VAL_OF_IMGNET

# Set --mixed-precision to "bf16" or "no" for your hardware.
accelerate launch train_uae.py \
  --config unified_ae/configs/stage1_train.yaml \
  --stage-key sub_stage2 \
  --data-path "$DATA_ROOT" \
  --val-path "$VAL_ROOT" \
  --results-dir results/sub_stage2 \
  --mixed-precision bf16 \
  --wandb --wandb-name uae_2
# After this stage you will get a model with FID=0.968, PSNR=27.356 dB.
# sub-stage 3
# Weights & Biases credentials/config — replace the placeholders.
export WANDB_API_KEY=YOUR_KEY
export WANDB_ENTITY=YOUR_ID
export WANDB_PROJECT=PROJECT_NAME
# ImageNet train/val roots — replace the placeholders.
DATA_ROOT=PATH_TO_TRAIN_OF_IMGNET
VAL_ROOT=PATH_TO_VAL_OF_IMGNET

# Set --mixed-precision to "bf16" or "no" for your hardware.
accelerate launch train_uae.py \
  --config unified_ae/configs/stage1_train.yaml \
  --stage-key sub_stage3 \
  --data-path "$DATA_ROOT" \
  --val-path "$VAL_ROOT" \
  --results-dir results/sub_stage3 \
  --mixed-precision bf16 \
  --wandb --wandb-name uae_3
# After this stage you will get a model with FID=0.530, PSNR=30.110 dB.
# sub-stage 4
# Weights & Biases credentials/config — replace the placeholders.
export WANDB_API_KEY=YOUR_KEY
export WANDB_ENTITY=YOUR_ID
export WANDB_PROJECT=PROJECT_NAME
# ImageNet train/val roots — replace the placeholders.
DATA_ROOT=PATH_TO_TRAIN_OF_IMGNET
VAL_ROOT=PATH_TO_VAL_OF_IMGNET

# Set --mixed-precision to "bf16" or "no" for your hardware.
accelerate launch train_uae.py \
  --config unified_ae/configs/stage1_train.yaml \
  --stage-key sub_stage4 \
  --data-path "$DATA_ROOT" \
  --val-path "$VAL_ROOT" \
  --results-dir results/sub_stage4 \
  --mixed-precision bf16 \
  --wandb --wandb-name uae_4
# After this stage you will get a model with FID=0.166, PSNR=29.499 dB.
@misc{fan2025uae,
title={The Prism Hypothesis: Harmonizing Semantic and Pixel Representations via Unified Autoencoding},
author={Weichen Fan and Haiwen Diao and Quan Wang and Dahua Lin and Ziwei Liu},
year={2025},
eprint={2512.19693},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2512.19693},
}