This repository contains the implementation of Probabilistic SAM. Our model learns a latent variable space that captures uncertainty and annotator variability in medical images. At inference, the model samples from this latent space, producing diverse masks that reflect the inherent ambiguity in medical image segmentation.
Recent advances in promptable segmentation, such as the Segment Anything Model (SAM), have enabled flexible, high-quality mask generation across a wide range of visual domains. However, SAM and similar models remain fundamentally deterministic, producing a single segmentation per object per prompt, and fail to capture the inherent ambiguity present in many real-world tasks. This limitation is particularly troublesome in medical imaging, where multiple plausible segmentations may exist due to annotation uncertainty or inter-expert variability. In this paper, we introduce Probabilistic SAM, a probabilistic extension of SAM that models a distribution over segmentations conditioned on both the input image and prompt. By incorporating a latent variable space and training with a variational objective, our model learns to generate diverse and plausible segmentation masks reflecting the variability in human annotations. The architecture integrates prior and posterior networks into the SAM framework, with sampled latent codes modulating the prompt embeddings. This latent space supports efficient sampling at inference time, enabling uncertainty-aware outputs with minimal overhead. We evaluate Probabilistic SAM on the LIDC-IDRI lung nodule dataset and demonstrate its ability to produce diverse outputs that align with expert disagreement, outperforming existing probabilistic baselines on uncertainty-aware metrics.
Given a CT slice and a bounding box prompt, Probabilistic SAM produces multiple plausible segmentation masks rather than a single prediction. A prior network maps image embeddings to a Gaussian latent space, from which latent vectors are sampled and used to modulate the prompt embeddings before decoding each mask.
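The variational training scheme described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repository's actual code: the module names (`GaussianNet`, `project`), the embedding dimensions, and the stand-in decoder are all hypothetical, and the real model would use SAM's image encoder and mask decoder in their place. At training time the latent code is drawn from a posterior network that also sees the ground-truth annotation, and the loss combines a reconstruction term with the KL divergence between posterior and prior.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianNet(nn.Module):
    """Maps an embedding to the mean and log-variance of a diagonal Gaussian."""
    def __init__(self, in_dim, latent_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, 2 * latent_dim)

    def forward(self, x):
        mu, logvar = self.fc(x).chunk(2, dim=-1)
        return mu, logvar

def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) between two diagonal Gaussians, summed over latent dims."""
    return 0.5 * (
        logvar_p - logvar_q
        + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
        - 1.0
    ).sum(dim=-1)

# Illustrative sizes only -- not the real SAM embedding dimensions.
emb_dim, latent_dim, batch = 32, 8, 4
prior_net = GaussianNet(emb_dim, latent_dim)          # sees the image embedding only
posterior_net = GaussianNet(emb_dim + 1, latent_dim)  # additionally sees the annotation
project = nn.Linear(latent_dim, emb_dim)              # injects z into the prompt embedding

img_emb = torch.randn(batch, emb_dim)                 # pooled image embedding (stand-in)
mask_feat = torch.randn(batch, 1)                     # pooled ground-truth mask feature (stand-in)
prompt_emb = torch.randn(batch, emb_dim)              # bounding-box prompt embedding (stand-in)

mu_p, logvar_p = prior_net(img_emb)
mu_q, logvar_q = posterior_net(torch.cat([img_emb, mask_feat], dim=-1))

# Reparameterised sample from the posterior during training.
z = mu_q + torch.randn_like(mu_q) * (0.5 * logvar_q).exp()
modulated_prompt = prompt_emb + project(z)            # latent code modulates the prompt

# Stand-in decoder: the real model runs SAM's mask decoder here.
logits = modulated_prompt.sum(dim=-1, keepdim=True)
target = torch.rand(batch, 1).round()

recon = F.binary_cross_entropy_with_logits(logits, target)
kl = kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p).mean()
loss = recon + 1.0 * kl                               # ELBO-style objective, beta = 1
```

At inference, the posterior network is dropped and `z` is drawn from the prior instead, so each sample yields a different plausible mask at negligible extra cost.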
A brief summary of our results is shown below, comparing Probabilistic SAM against several baselines. In the table, the best scores are bolded and the second-best scores are italicized.
| Model | GED (↓) | DSC (↑) | IoU (↑) |
|---|---|---|---|
| Dropout U-Net | 0.5156 | 0.5591 | 0.3880 |
| Dropout SAM | 0.5025 | *0.6799* | 0.5150 |
| Probabilistic U-Net | *0.3349* | 0.5818 | *0.5557* |
| Probabilistic SAM | **0.2910** | **0.8255** | **0.7849** |
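The GED column reports the (squared) Generalized Energy Distance, which compares the set of sampled predictions against the set of expert annotations using a distance such as `d(a, b) = 1 - IoU(a, b)`. A minimal NumPy sketch of this metric (with the common convention that IoU of two empty masks is 1) could look like:

```python
import numpy as np

def iou(a, b):
    """IoU of two binary masks; defined as 1 when both masks are empty."""
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return 1.0 if union == 0 else inter / union

def ged(samples, annotations):
    """Squared Generalized Energy Distance with d(a, b) = 1 - IoU(a, b).

    samples: list of predicted binary masks; annotations: list of expert masks.
    """
    d = lambda a, b: 1.0 - iou(a, b)
    cross = np.mean([d(s, y) for s in samples for y in annotations])
    ss = np.mean([d(s, t) for s in samples for t in samples])
    yy = np.mean([d(y, t) for y in annotations for t in annotations])
    return 2.0 * cross - ss - yy

# Toy check: identical samples and annotations give a GED of zero.
m = np.zeros((4, 4), dtype=bool)
m[1:3, 1:3] = True
print(ged([m, m], [m, m]))  # 0.0
```

Lower GED means the diversity of the model's samples better matches the diversity of the expert annotations, which is why it is the natural headline metric for this task.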
We evaluate Probabilistic SAM on the task of lung nodule segmentation using the LIDC-IDRI dataset. This dataset contains thoracic CT scans along with ground truth annotations from four expert radiologists.
The code is written in Python using the PyTorch framework, and training requires a GPU. To train your own Probabilistic SAM, simply clone this repository and run `main.py`.
Thanks to Stefan Knegt for open-sourcing his PyTorch implementation of Probabilistic U-Net, which served as a helpful guide in developing Probabilistic SAM, and for providing a link to pre-processed LIDC-IDRI data.