
PeppaKing8/minit2i-jax



MiniT2I: A Minimalist Baseline for Text-to-Image Generation


Official JAX/Flax training implementation of MiniT2I.

MiniT2I is a simple direct-RGB text-to-image generator: it trains a pixel-space MM-JiT denoiser with flow matching, conditioned on frozen FLAN-T5-Large text tokens. The recipe is intentionally plain, avoiding image tokenizers, cascaded generation, RL stages, and auxiliary losses. All training data is public and easy to prepare. For more details, please refer to our blog post.
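To make the flow-matching objective concrete, here is a minimal sketch of how one training pair can be built on the common linear (rectified-flow) path between a clean image and Gaussian noise. The time convention (t = 0 is data, t = 1 is noise), the velocity target, and the name flow_matching_pair are assumptions for illustration; diffusion.py may parameterize time or the prediction target differently.

```python
import random


def flow_matching_pair(x, noise, t):
    """Build a noisy input and velocity target on an assumed linear path.

    x_t = (1 - t) * x + t * noise and v = noise - x, so t = 0 is clean data,
    t = 1 is pure noise, and x_t = x + t * v.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    x_t = [(1.0 - t) * xi + t * ni for xi, ni in zip(x, noise)]
    v = [ni - xi for xi, ni in zip(x, noise)]
    return x_t, v


# Toy "image": three pixel values already normalized to [-1, 1].
rng = random.Random(0)
x = [-1.0, 0.0, 1.0]
noise = [rng.gauss(0.0, 1.0) for _ in x]
x_t, v = flow_matching_pair(x, noise, t=0.25)
```

Under this convention, a network trained to regress v from (x_t, t, text tokens) can generate images by integrating the learned velocity from t = 1 back to t = 0.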

This repository contains the original JAX/Flax diffusion training and evaluation code used for MiniT2I.

  • For the JAX Mean Flow distillation code used by the four-step MiniT2I-B/16-MF checkpoint, use the mean_flow_distill branch.
  • For a PyTorch/Diffusers implementation with inference and LoRA adaptation, see Hope7Happiness/minit2i-torch.

Model Zoo

| Model | Params | Patch | GenEval | DPG-Bench | Hugging Face |
|---|---|---|---|---|---|
| MiniT2I-B/16 | 258M + 341M text encoder | 16 | 0.873 | 84.2 | MiniT2I-B-16 |
| MiniT2I-L/16 | 912M + 341M text encoder | 16 | 0.883 | 85.9 | MiniT2I-L-16 |

The repository also includes our default B/32 baseline for ablation and reproduction studies.
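Because the denoiser works directly on RGB pixels, the patch size sets both the sequence length and the width of each input token. A minimal arithmetic sketch follows; the 256 × 256 resolution is only an illustrative assumption (this README does not state the training resolution), and the helper names are hypothetical.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


def patch_dim(patch_size: int, channels: int = 3) -> int:
    """Length of one flattened RGB patch (pixel space, no latent channels)."""
    return patch_size * patch_size * channels


# Illustrative resolution only; not taken from the MiniT2I configs.
for name, p in [("B/32", 32), ("B/16", 16)]:
    print(name, num_patch_tokens(256, p), "tokens of dim", patch_dim(p))
# B/32 64 tokens of dim 3072
# B/16 256 tokens of dim 768
```

At any fixed resolution, halving the patch size quadruples the token count, which makes B/32 much cheaper to run than B/16 at the same model width.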

Repository Layout

.
|-- main.py                     # JAX distributed train/eval entry point
|-- train.py                    # training loop, checkpointing, sampling, online eval
|-- diffusion.py                # flow-matching objective and samplers
|-- configs/                    # defaults plus b32/b16/l16 YAML recipes
|-- settings.py                 # local paths, checkpoints, eval assets, logging
|-- models/                     # MM-JiT, T5 encoder, Flax/Torch-compatible layers
|-- utils/                      # input pipeline, pjit sharding, checkpoints, logging
|-- evaluators/                 # FID, GenEval, and DPG-Bench dispatch
|-- external/                   # JAX evaluator/model ports used by benchmarks
`-- scripts/                    # install/train/eval launch helpers

Installation

This codebase is TPU-oriented and uses JAX distributed initialization.

Create a Python environment, install a TPU-compatible JAX build, then install the remaining dependencies:

python -m pip install "jax[tpu]" \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python -m pip install -r requirements.txt

The same commands are wrapped by:

bash scripts/install.sh
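After installation, a quick check that JAX imports and sees the expected accelerator can save a failed launch. The sketch below is not part of the repository and the function name is hypothetical; it only calls standard JAX device queries. On a correctly configured TPU VM the backend should be reported as tpu.

```python
import importlib.util


def jax_backend_report() -> dict:
    """Report whether JAX is importable and which backend/devices it sees."""
    if importlib.util.find_spec("jax") is None:
        return {"jax_installed": False}
    import jax

    return {
        "jax_installed": True,
        "backend": jax.default_backend(),          # "tpu", "gpu", or "cpu"
        "local_devices": jax.local_device_count(),
        "global_devices": jax.device_count(),
        "process_count": jax.process_count(),
    }


if __name__ == "__main__":
    print(jax_backend_report())
```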

Log in only to the services your run needs:

hf auth login
wandb login

Project-level paths, checkpoint roots, benchmark asset roots, and W&B defaults live in settings.py. Experiment-dependent fields such as load_from belong in run YAML configs or command-line overrides.
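One way to picture this split is as layered dictionaries, where configs/ defaults are overridden by the run YAML and then by command-line flags. The precedence order, the merge_config helper, and the example keys below are assumptions for illustration, not the repository's actual loader.

```python
import copy


def merge_config(base: dict, overrides: dict) -> dict:
    """Recursively merge overrides into a copy of base; later layers win."""
    out = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge_config(out[key], value)
        else:
            out[key] = copy.deepcopy(value)
    return out


# Hypothetical layers: defaults, then a run YAML, then command-line flags.
defaults = {"eval_only": False, "load_from": None,
            "dataset": {"use_cc12m": True, "use_dalle3": False}}
run_yaml = {"dataset": {"use_cc12m": False, "use_dalle3": True}}
cli = {"load_from": "/path/to/pretrained/checkpoint_or_run_dir"}

config = merge_config(merge_config(defaults, run_yaml), cli)
```

Keeping machine-specific paths out of these layers means the same run YAML can be reused on a different machine by editing only settings.py.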

Dataset Preparation

Training uses WebDataset tar shards. Each sample must contain an image under jpg or png and a caption under txt. The input pipeline resizes and center-crops images, normalizes them to [-1, 1], and tokenizes captions with the configured frozen text encoder.
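For a quick local test, a tiny shard in this layout can be written with Python's standard tarfile module. This is a sketch, not the repository's preparation script (see scripts/datasets/README.md for that); the function name write_shard and the zero-padded keys are illustrative. WebDataset groups tar members by their shared basename, so each key must be unique within a shard and should not contain dots.

```python
import io
import tarfile


def write_shard(path, samples):
    """Write (key, image_bytes, ext, caption) tuples into one tar shard.

    Each sample becomes two members, <key>.<ext> and <key>.txt; WebDataset
    treats the shared basename as the sample key and the extension as the field.
    """
    with tarfile.open(path, "w") as tar:
        for key, image_bytes, ext, caption in samples:
            if ext not in ("jpg", "png"):
                raise ValueError(f"unsupported image extension: {ext}")
            members = ((f"{key}.{ext}", image_bytes),
                       (f"{key}.txt", caption.encode("utf-8")))
            for name, data in members:
                info = tarfile.TarInfo(name=name)
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))
```

For example, write_shard("shard-000000.tar", [("000000", jpeg_bytes, "jpg", "a cat on a sofa")]) produces the members 000000.jpg and 000000.txt.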

See scripts/datasets/README.md for more details.
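The resize, center-crop, and normalization steps described above amount to simple arithmetic. The sketch below assumes a short-side resize followed by a centered square crop; the README does not specify the resampling filter, rounding, or target resolution, so those details and the helper names are assumptions.

```python
def resize_then_center_crop_box(width, height, size):
    """Scale so the short side equals `size`, then center a size x size crop.

    Returns (new_w, new_h, (left, top, right, bottom)) in resized coordinates.
    """
    scale = size / min(width, height)
    new_w = max(size, round(width * scale))
    new_h = max(size, round(height * scale))
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return new_w, new_h, (left, top, left + size, top + size)


def to_unit_range(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]


print(resize_then_center_crop_box(640, 480, 256))  # (341, 256, (42, 0, 298, 256))
print(to_unit_range([0, 255]))                     # [-1.0, 1.0]
```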

Training

We provide training configs for B/32, B/16, and L/16 under configs/.

For pretraining on CC12M, set CC12M_ROOT in settings.py, then keep the YAML focused on the training recipe:

eval_only: False

dataset:
  use_cc12m: True
  use_blip3_ft60k: False
  use_dalle3: False
  use_sharegpt4o: False

eval:
  on_training: False

Launch:

bash scripts/train.sh pretrain \
  --config configs/load_config.py:pretrain_b16 \
  --workdir /path/to/runs/minit2i-pretrain

For the B/32 ablation recipe:

bash scripts/train.sh pretrain_b32 \
  --config configs/load_config.py:pretrain_b32 \
  --workdir /path/to/runs/minit2i-b32-pretrain
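The --config value has the form <file>.py:<name>. This matches the convention of ml_collections' config_flags, where the text after the colon is passed to the config file's get_config(); we assume, but have not confirmed, that configs/load_config.py follows it. The sketch below only illustrates the parsing, and the YAML file names in RECIPES are hypothetical.

```python
def split_config_flag(flag: str) -> tuple[str, str]:
    """Split '<file>.py:<name>' into (file path, config name)."""
    path, sep, name = flag.partition(":")
    if not sep or not name:
        raise ValueError(f"expected '<file>.py:<config_name>', got {flag!r}")
    return path, name


# Hypothetical name -> YAML mapping; the real load_config.py may differ.
RECIPES = {
    "pretrain_b32": "configs/pretrain_b32.yaml",
    "pretrain_b16": "configs/pretrain_b16.yaml",
    "finetune_b16": "configs/finetune_b16.yaml",
}


def get_config(name: str) -> str:
    """Resolve a config name to its (hypothetical) YAML recipe path."""
    try:
        return RECIPES[name]
    except KeyError:
        raise ValueError(f"unknown config {name!r}; choose from {sorted(RECIPES)}") from None


path, name = split_config_flag("configs/load_config.py:pretrain_b16")
```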

For fine-tuning, pass the pretrained checkpoint with --load_from or set load_from in the run YAML, then enable the 120K mix sources:

eval_only: False
load_from: /path/to/pretrained/checkpoint_or_run_dir

dataset:
  use_cc12m: False
  use_blip3_ft60k: True
  use_dalle3: True
  use_sharegpt4o: True

Launch:

bash scripts/train.sh finetune \
  --config configs/load_config.py:finetune_b16 \
  --workdir /path/to/runs/minit2i-finetune \
  --load_from /path/to/pretrained/checkpoint_or_run_dir

Checkpoints are saved with Flax checkpointing directly under workdir. Both local absolute paths and gs:// (GCS bucket) paths are supported.

Evaluation

Evaluation is driven by the same entry point with eval_only: True. The checkpoint is passed with --load_from when using scripts/eval.sh, or with the experiment-level load_from field in a YAML config. It is restored in train.py.

load_from may point either to a Flax checkpoint directory (whose name starts with checkpoint_) or to a parent directory containing Flax checkpoints. If a parent directory is given, the latest checkpoint_* under it is restored.
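The parent-directory rule can be sketched as follows. This is an illustrative local-filesystem version, not the code in train.py: it assumes "latest" means the highest step number and that checkpoints are directories named checkpoint_<step>. gs:// paths would need a GCS-aware filesystem layer instead of os.listdir.

```python
import os
import re

# Flax-style step directories, e.g. checkpoint_10000.
_CKPT = re.compile(r"^checkpoint_(\d+)$")


def resolve_load_from(path: str) -> str:
    """Return path if it names a checkpoint_<step> directory; otherwise return
    its child checkpoint_<step> directory with the largest step."""
    path = path.rstrip("/")
    if _CKPT.match(os.path.basename(path)):
        return path
    candidates = []
    for entry in os.listdir(path):
        match = _CKPT.match(entry)
        if match and os.path.isdir(os.path.join(path, entry)):
            candidates.append((int(match.group(1)), entry))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint_<step> directories under {path}")
    # Compare steps numerically so checkpoint_10000 beats checkpoint_9000.
    _, newest = max(candidates)
    return os.path.join(path, newest)
```

Sorting by the parsed integer matters: a plain string sort would rank checkpoint_9000 above checkpoint_10000. Flax's legacy checkpoints module also provides latest_checkpoint(ckpt_dir, prefix="checkpoint_") for the same purpose.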

Download Checkpoints

Install the Hugging Face CLI if needed:

python -m pip install -U "huggingface_hub[cli]"

Download a JAX checkpoint:

hf download MiniT2I/MiniT2I-B-16-jax \
  --local-dir /path/to/checkpoints/MiniT2I-B-16-jax

hf download MiniT2I/MiniT2I-L-16-jax \
  --local-dir /path/to/checkpoints/MiniT2I-L-16-jax

The model architecture in the YAML config must match the checkpoint for successful parameter loading:

| Checkpoint | Config |
|---|---|
| MiniT2I-B/16 (MiniT2I/MiniT2I-B-16-jax) | configs/load_config.py:eval_b16 |
| MiniT2I-L/16 (MiniT2I/MiniT2I-L-16-jax) | configs/load_config.py:eval_l16 |

Run Evaluation

For B/16:

bash scripts/eval.sh eval \
  --config configs/load_config.py:eval_b16 \
  --workdir /path/to/runs/minit2i-eval \
  --load_from /path/to/checkpoints/MiniT2I-B-16-jax

For L/16:

bash scripts/eval.sh eval_l \
  --config configs/load_config.py:eval_l16 \
  --workdir /path/to/runs/minit2i-l-eval \
  --load_from /path/to/checkpoints/MiniT2I-L-16-jax
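To keep each checkpoint paired with its matching config, the checkpoint/config table and the two commands above can be generated from one small lookup. This helper is hypothetical and not part of the repository; it simply reproduces the example invocations, including their first positional argument (eval or eval_l) as shown.

```python
# Pairs copied from the checkpoint/config table and example commands above.
EVAL_PRESETS = {
    "B/16": ("MiniT2I/MiniT2I-B-16-jax", "configs/load_config.py:eval_b16", "eval"),
    "L/16": ("MiniT2I/MiniT2I-L-16-jax", "configs/load_config.py:eval_l16", "eval_l"),
}


def eval_command(model: str, ckpt_root: str, workdir: str) -> list[str]:
    """Build the scripts/eval.sh argv for a downloaded checkpoint."""
    repo_id, config, first_arg = EVAL_PRESETS[model]
    local_dir = f"{ckpt_root.rstrip('/')}/{repo_id.split('/')[1]}"
    return ["bash", "scripts/eval.sh", first_arg,
            "--config", config,
            "--workdir", workdir,
            "--load_from", local_dir]
```

The resulting list can be passed to subprocess.run(cmd, check=True) from a launcher script.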

The combined evaluator currently supports:

  • FID on MSCOCO-30K, when enabled.
  • GenEval, using a JAX Mask2Former detector and an optional CLIP color classifier (our JAX reproduction).
  • DPG-Bench, using a JAX mPLUG VQA evaluator (our JAX reproduction).
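For reference, FID compares Gaussian fits to Inception features of generated and reference images (here, MSCOCO-30K). With feature means μ_r, μ_g and covariances Σ_r, Σ_g, the standard definition is shown below; lower is better. Which Inception weights and preprocessing the JAX evaluator uses is not stated in this README.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```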

Evaluation Assets

We evaluate FID on MSCOCO-30K, GenEval, and DPG-Bench. For each benchmark, set its asset paths in settings.py, then enable it by setting enable: True in the config YAML. The CFG scale can be set per benchmark; for instance, to run GenEval with guidance scale 6.0:

eval:
  geneval:
    enable: True
    cfg_scale: 6.0

If a benchmark-specific value is missing, the evaluator falls back to the top-level eval.cfg_scale.
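The fallback rule, and the usual way a guidance scale is applied, can be sketched as below. Both functions are illustrative: the dpg key and the top-level value 4.0 are made up for the example, and the guidance form uncond + scale · (cond − uncond) is the common classifier-free-guidance parameterization (scale 1.0 means no guidance), assumed rather than confirmed for this codebase.

```python
def resolve_cfg_scale(eval_cfg: dict, benchmark: str) -> float:
    """Use eval.<benchmark>.cfg_scale if set, else the top-level eval.cfg_scale."""
    bench = eval_cfg.get(benchmark) or {}
    scale = bench.get("cfg_scale")
    return eval_cfg["cfg_scale"] if scale is None else scale


def apply_cfg(cond, uncond, scale):
    """Classifier-free guidance on model outputs: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


# Illustrative values; "dpg" is a hypothetical key name.
eval_cfg = {
    "cfg_scale": 4.0,
    "geneval": {"enable": True, "cfg_scale": 6.0},
    "dpg": {"enable": True},
}
```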

Sampling

During training or eval-only runs, setting eval_show_sample: True writes samples generated from the built-in visualization prompts. If W&B or TensorBoard logging is disabled, images are saved under:

<workdir>/writed_images/
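To gather those samples afterwards (for example, to build a contact sheet), a small helper can scan that directory. The file formats and any subdirectory layout are not documented here, so this sketch searches recursively for common image extensions; the function name is hypothetical.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def list_samples(workdir) -> list[Path]:
    """Return saved sample images under <workdir>/writed_images/, sorted by path."""
    root = Path(workdir) / "writed_images"
    if not root.is_dir():
        return []
    return sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
```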

Acknowledgments

This codebase builds on a number of open-source efforts:

We also thank GCP for providing the computational resources.

Citation

@misc{minit2i2026,
  title  = {MiniT2I: A Minimalist Baseline for Text-to-Image Generation},
  author = {Wang, Xianbang and Zhao, Hanhong and Lu, Yiyang and Zhou, Kangyang and Ma, Linrui and He, Kaiming},
  year   = {2026},
  url    = {https://peppaking8.github.io/#/post/minit2i}
}
