pytorch-3dunet

PyTorch implementation of the 3D U-Net and its variants.

The code allows for training the U-Net on both semantic segmentation (binary and multi-class) and regression problems (e.g. denoising, learning deconvolutions).

2D U-Net

2D U-Net is also supported; see 2DUnet_confocal or 2DUnet_dsb2018 for example configurations. Just make sure to keep the singleton z-dimension in your H5 dataset (i.e. (1, Y, X) instead of (Y, X)), because data loading / data augmentation requires tensors of rank 3. The 2D U-Net itself uses standard 2D convolutional layers instead of 3D convolutions with kernel size (1, 3, 3), for performance reasons.

Input Data Format

The input data should be stored in HDF5 files. The HDF5 files for training should contain two datasets: raw and label. The raw dataset contains the input data, while the label dataset contains the ground truth labels. The format of the raw and label datasets depends on whether the problem is 2D or 3D, as well as whether the data is single-channel or multi-channel. Please refer to the table below:

                 2D            3D
single-channel   (1, Y, X)     (Z, Y, X)
multi-channel    (C, 1, Y, X)  (C, Z, Y, X)
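
As a concrete illustration, here is a minimal sketch (using h5py and numpy; the file name and array shapes are made up) that writes a single-channel 3D training file with the expected raw and label datasets. For a 2D problem you would keep the singleton z-dimension, i.e. shape (1, Y, X).

import h5py
import numpy as np

# Hypothetical single-channel 3D volume (Z, Y, X) and its ground truth labels.
raw = np.random.rand(64, 128, 128).astype(np.float32)        # (Z, Y, X)
label = np.random.randint(0, 2, raw.shape).astype(np.int64)   # (Z, Y, X)

# For a 2D problem keep the singleton z-dimension: (1, Y, X) instead of (Y, X),
# e.g. raw_2d = raw_2d[np.newaxis, ...]

with h5py.File('train_sample.h5', 'w') as f:
    f.create_dataset('raw', data=raw, compression='gzip')
    f.create_dataset('label', data=label, compression='gzip')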

Prerequisites

  • Miniconda
  • Python 3.11+
  • NVIDIA GPU (optional but recommended for training/prediction speedup)

Running on Windows/OSX

pytorch-3dunet is a cross-platform package and runs on Windows and OS X as well.

Installation

The easiest way to install pytorch-3dunet package is via conda:

# Create a new conda environment "3dunet" with the latest Python version from the conda-forge channel
conda create -n 3dunet python -c conda-forge -y

# Activate the conda environment
conda activate 3dunet

# pytorch-3dunet does not include PyTorch dependencies, so that one can install the desired PyTorch version (with/without CUDA support) separately
pip install torch torchvision
# you may need to adjust the command above depending on your GPU and the CUDA version you want to use, e.g. for CUDA 12.6:
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
# or for CPU-only version:
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu

# Install the latest pytorch-3dunet package from conda-forge channel
conda install -c conda-forge pytorch-3dunet

After installation the following commands will be accessible within the conda environment: train3dunet for training the network and predict3dunet for prediction (see below).

One can also install directly from source, i.e. go to the checkout directory and run:

pip install -e .
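
As a quick sanity check after installation, you can try instantiating a model from Python. The import path and constructor arguments below reflect the package layout as an assumption; check the source if the import fails.

# Assumed import path; verify against the installed package if this fails.
from pytorch3dunet.unet3d.model import UNet3D

# Minimal single-channel binary-segmentation network (final_sigmoid as discussed
# in the prediction tips below).
model = UNet3D(in_channels=1, out_channels=1, final_sigmoid=True)
print(model)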

Installation tips

The PyTorch package ships with its own CUDA runtime libraries, so you don't need to install CUDA separately on your system. However, you must ensure that the PyTorch/CUDA version you choose is compatible with your GPU's compute capability. See the PyTorch installation guide for more details.

Train

Given that the pytorch-3dunet package was installed via conda as described above, you can train the network by simply invoking:

train3dunet --config <CONFIG>

where CONFIG is the path to a YAML configuration file that specifies all aspects of the training process.

In order to train on your own data, just provide the paths to your HDF5 training and validation datasets in the config. See the example configs in the repository for segmentation and regression tasks.

You can monitor training progress with TensorBoard: tensorboard --logdir <checkpoint_dir>/logs/ (you need tensorflow installed in your conda env), where checkpoint_dir is the checkpoint directory specified in the config.

Training tips

  1. When training with binary-based losses (BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss), the target data has to be 4D (one binary target mask per output channel of the network).
  2. When training with WeightedCrossEntropyLoss or CrossEntropyLoss, the target dataset has to be a 3D label image, as expected by the loss (see the PyTorch documentation for CrossEntropyLoss: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html). A sketch illustrating both target layouts follows this list.
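
The following sketch (numpy only; shapes are illustrative) shows the difference between the two target layouts: a 4D one-hot / multi-channel mask for the binary-based losses versus a 3D integer label image for the cross-entropy losses.

import numpy as np

# 3D integer label image (Z, Y, X) with values in {0, ..., C-1}:
# this is what CrossEntropyLoss / WeightedCrossEntropyLoss expect.
labels = np.random.randint(0, 3, size=(32, 64, 64))

# 4D one-hot target (C, Z, Y, X): one binary mask per output channel,
# as expected by BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss.
num_classes = labels.max() + 1
one_hot = np.stack([(labels == c).astype(np.float32) for c in range(num_classes)])
print(labels.shape, one_hot.shape)  # (32, 64, 64) (3, 32, 64, 64)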

Prediction

Given that the pytorch-3dunet package was installed via conda as described above, one can run prediction via:

predict3dunet --config <CONFIG>

In order to predict on your own data, just provide the path to your model as well as paths to HDF5 test files (see example test_config_segmentation.yaml).

Prediction tips

  1. If you're running prediction for a large dataset, consider using LazyHDF5Dataset and LazyPredictor in the config. This will save memory by loading data on the fly at the cost of slower prediction time. See test_config_lazy for an example config.
  2. If your model predicts multiple classes (see e.g. train_config_multiclass), consider saving only the final segmentation instead of the multi-channel probability maps, which can be time and space consuming. To do so, set save_segmentation: true in the predictor section of the config (see test_config_multiclass). If you do keep the probability maps, the sketch after this list shows one way to collapse them into a segmentation.
  3. If the model was trained with binary losses (BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss) set final_sigmoid=True in the model part of the config so that the sigmoid is applied to the logits in inference mode.
  4. If the model was trained with multi-class losses (WeightedCrossEntropyLoss, CrossEntropyLoss) set final_sigmoid=False so that Softmax normalization is applied to the logits in inference mode.
  5. For fast prediction use the same patch_shape and stride_shape in the config. When doing so make sure to add a non-zero halo_shape around each patch in order to avoid checkerboard artifacts in the prediction (see e.g.: https://github.com/wolny/pytorch-3dunet/blob/master/resources/3DUnet_lightsheet_boundary/test_config.yml#L41)
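
If you saved multi-channel probability maps rather than the final segmentation, a label volume can be recovered with a simple argmax (or a threshold in the binary case). The sketch below assumes the predictions were written to a dataset called predictions and uses made-up file names; check the output file produced on your data for the actual dataset name.

import h5py
import numpy as np

with h5py.File('my_volume_predictions.h5', 'r') as f:
    probs = f['predictions'][...]  # assumed dataset name; shape (C, Z, Y, X)

if probs.shape[0] == 1:
    # Binary case (sigmoid output): threshold the single channel.
    segmentation = (probs[0] > 0.5).astype(np.uint16)
else:
    # Multi-class case (softmax output): take the most likely class per voxel.
    segmentation = np.argmax(probs, axis=0).astype(np.uint16)

with h5py.File('my_volume_segmentation.h5', 'w') as f:
    f.create_dataset('segmentation', data=segmentation, compression='gzip')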

Data Parallelism

By default, if multiple GPUs are available, training/prediction will be run on all of them using DataParallel. If training/prediction on all available GPUs is not desirable, restrict the number of GPUs with CUDA_VISIBLE_DEVICES, e.g.

CUDA_VISIBLE_DEVICES=0,1 train3dunet --config <CONFIG>

or

CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>

Supported Loss Functions

Semantic Segmentation

  • BCEWithLogitsLoss (binary cross-entropy)
  • DiceLoss (standard Dice loss, defined as 1 - DiceCoefficient, used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the Dice loss per channel and averages the values; a generic sketch is shown below)
  • BCEDiceLoss (linear combination of BCE and Dice losses, i.e. alpha * BCE + beta * Dice; alpha and beta can be specified in the loss section of the config)
  • CrossEntropyLoss (one can specify class weights via the weight: [w_1, ..., w_k] in the loss section of the config)
  • WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the paper below for a detailed explanation)
  • GeneralizedDiceLoss (see 'Generalized Dice Loss (GDL)' in the paper below for a detailed explanation) Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise, use the standard DiceLoss.

For a detailed explanation of some of the supported loss functions see: Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations.
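
For intuition, here is a minimal, generic sketch of the per-channel Dice loss described above (1 - Dice coefficient, averaged over channels). It is not the repository's exact implementation; see the package's losses module for the real thing.

import torch

def dice_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """probs, target: (N, C, Z, Y, X); target holds one binary mask per channel."""
    dims = (0, 2, 3, 4)                                   # sum over batch and spatial dims
    intersection = (probs * target).sum(dims)
    denominator = probs.sum(dims) + target.sum(dims)
    dice_per_channel = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice_per_channel.mean()                    # average over channels

# A BCEDice-style combination (alpha * BCE + beta * Dice) can then be built as:
# loss = alpha * torch.nn.functional.binary_cross_entropy_with_logits(logits, target) \
#        + beta * dice_loss(torch.sigmoid(logits), target)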

Regression

  • MSELoss (mean squared error loss)
  • L1Loss (mean absolute error loss)
  • SmoothL1Loss (less sensitive to outliers than MSELoss)
  • WeightedSmoothL1Loss (extension of SmoothL1Loss which allows weighting the voxel values above/below a given threshold differently; see the sketch below)
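
As a rough illustration of the last point, a weighted SmoothL1 can be built on top of PyTorch's smooth_l1_loss by scaling the per-voxel loss where the target falls below (or above) a threshold. This is a hedged sketch, not the package's implementation; the parameter names are made up.

import torch
import torch.nn.functional as F

def weighted_smooth_l1(pred, target, threshold=0.0, weight_below=10.0):
    """Per-voxel SmoothL1, up-weighted where the target is below `threshold`."""
    loss = F.smooth_l1_loss(pred, target, reduction='none')
    weights = torch.where(target < threshold,
                          torch.full_like(target, weight_below),
                          torch.ones_like(target))
    return (weights * loss).mean()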

Supported Evaluation Metrics

Semantic Segmentation

  • MeanIoU (mean intersection over union)
  • DiceCoefficient (computes the per-channel Dice coefficient and returns the average)

If a 3D U-Net was trained to predict cell boundaries, one can use the following instance segmentation metrics, which are computed by running connected components on the thresholded boundary map and comparing the resulting instances to the ground truth instance segmentation (a sketch of this thresholding step is shown below the list):
  • BoundaryAveragePrecision (Average Precision applied to the boundary probability maps: thresholds the output from the network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth)
  • AdaptedRandError (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
  • AveragePrecision (see https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric)

If not specified, MeanIoU will be used by default.
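
To make the boundary-based instance metrics above more concrete, here is a hedged sketch of the threshold + connected-components step using scipy (the threshold value is arbitrary); the actual metric implementations in the package may differ.

import numpy as np
from scipy import ndimage

def boundary_map_to_instances(boundary_probs: np.ndarray, threshold: float = 0.4) -> np.ndarray:
    """Threshold a boundary probability map (Z, Y, X) and label the connected
    non-boundary regions, yielding an instance segmentation."""
    foreground = boundary_probs < threshold   # cells are where there is no boundary
    instances, _ = ndimage.label(foreground)
    return instances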

Regression

  • PSNR (peak signal to noise ratio)
  • MSE (mean squared error)
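
For reference, PSNR is a simple function of the MSE and the data range; a generic numpy version (not the package's implementation) looks like this:

import numpy as np

def psnr(pred: np.ndarray, target: np.ndarray) -> float:
    mse = np.mean((pred - target) ** 2)
    data_range = target.max() - target.min()
    return float(10 * np.log10(data_range ** 2 / mse))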

Examples

Cell boundary predictions for lightsheet images of Arabidopsis thaliana lateral root

Training/predictions configs can be found in 3DUnet_lightsheet_boundary. Pre-trained model weights available here. In order to use the pre-trained model on your own data:

  • download the best_checkpoint.pytorch from the above link
  • add the path to the downloaded model and the path to your data in test_config.yml
  • run predict3dunet --config test_config.yml
  • optionally fine-tune the pre-trained model with your own data, by setting the pre_trained attribute in the YAML config to point to the best_checkpoint.pytorch path

The data used for training can be downloaded from the following OSF project:

Sample z-slice predictions on the test set (top: raw input, bottom: boundary predictions):

Cell boundary predictions for confocal images of Arabidopsis thaliana ovules

Training/predictions configs can be found in 3DUnet_confocal_boundary. Pre-trained model weights available here. In order to use the pre-trained model on your own data:

  • download the best_checkpoint.pytorch from the above link
  • add the path to the downloaded model and the path to your data in test_config.yml
  • run predict3dunet --config test_config.yml
  • optionally fine-tune the pre-trained model with your own data, by setting the pre_trained attribute in the YAML config to point to the best_checkpoint.pytorch path

The data used for training can be downloaded from the following OSF project:

Sample z-slice predictions on the test set (top: raw input, bottom: boundary predictions):

Nuclei predictions for lightsheet images of Arabidopsis thaliana lateral root

Training/predictions configs can be found in 3DUnet_lightsheet_nuclei. Pre-trained model weights available here. In order to use the pre-trained model on your own data:

  • download the best_checkpoint.pytorch from the above link
  • add the path to the downloaded model and the path to your data in test_config.yml
  • run predict3dunet --config test_config.yml
  • optionally fine-tune the pre-trained model with your own data, by setting the pre_trained attribute in the YAML config to point to the best_checkpoint.pytorch path

The training and validation sets can be downloaded from the following OSF project: https://osf.io/thxzn/

Sample z-slice predictions on the test set (top: raw input, bottom: nuclei predictions):

2D nuclei predictions for Kaggle DSB2018

The data can be downloaded from: https://www.kaggle.com/c/data-science-bowl-2018/data

Training/predictions configs can be found in 2DUnet_dsb2018.

Sample predictions on the test image (top: raw input, bottom: nuclei predictions):

Contribute

If you want to contribute back, please make a pull request.

Cite

If you use this code for your research, please cite as:

@article {10.7554/eLife.57613,
article_type = {journal},
title = {Accurate and versatile 3D segmentation of plant tissues at cellular resolution},
author = {Wolny, Adrian and Cerrone, Lorenzo and Vijayan, Athul and Tofanelli, Rachele and Barro, Amaya Vilches and Louveaux, Marion and Wenzl, Christian and Strauss, Sören and Wilson-Sánchez, David and Lymbouridou, Rena and Steigleder, Susanne S and Pape, Constantin and Bailoni, Alberto and Duran-Nebreda, Salva and Bassel, George W and Lohmann, Jan U and Tsiantis, Miltos and Hamprecht, Fred A and Schneitz, Kay and Maizel, Alexis and Kreshuk, Anna},
editor = {Hardtke, Christian S and Bergmann, Dominique C and Bergmann, Dominique C and Graeff, Moritz},
volume = 9,
year = 2020,
month = {jul},
pub_date = {2020-07-29},
pages = {e57613},
citation = {eLife 2020;9:e57613},
doi = {10.7554/eLife.57613},
url = {https://doi.org/10.7554/eLife.57613},
keywords = {instance segmentation, cell segmentation, deep learning, image analysis},
journal = {eLife},
issn = {2050-084X},
publisher = {eLife Sciences Publications, Ltd},
}

Development

A development environment can be created via conda:

conda env create --file environment.yaml
conda activate 3dunet
pip install -e .

Tests can be run via pytest. The device the tests should be run on can be specified with the --device argument (cpu, mps, or cuda - default: cpu). Linting is done via ruff (see pyproject.toml for configuration).

Release new version on conda-forge channel

To release a new version of pytorch-3dunet on the conda-forge channel, follow these steps:

  1. In the main branch: run bumpversion patch (or major or minor) - this will bump the version in .bumpversion.cfg and __version__.py and create a new tag
  2. Run git push && git push --tags to push the changes to GitHub
  3. Make a new release on GitHub
  4. (Optional) Make sure that the new release version is in sync with the version in .bumpversion.cfg and __version__.py
  5. Generate the checksums for the new release using: curl -sL https://github.com/wolny/pytorch-3dunet/archive/refs/tags/VERSION.tar.gz | openssl sha256. Replace VERSION with the new release version
  6. Fork the conda-forge feedstock repository (https://github.com/conda-forge/pytorch-3dunet-feedstock)
  7. Clone the forked repository and create a new PR with the following changes:
    • Update the version in recipe/meta.yaml to the new release version
    • Update the sha256 in recipe/meta.yaml to the new checksum
  8. Wait for the checks to pass. Once the PR is merged, the new version will be available on the conda-forge channel