Physics-informed neural networks and score-based inference for diffusion MRI microstructure.
Validated on WAND (Welsh Advanced Neuroimaging Database) — Siemens Connectom 300 mT/m scanner.
The Van Gelderen restricted diffusion model recovers compartment geometry from multi-delta AxCaliber data. This is a proper PINN: the physics of restricted diffusion inside cylinders constrains the inverse problem. Stejskal-Tanner (Gaussian diffusion) cannot recover axon radius.
| Parameter | Recovered | Expected WM |
|---|---|---|
| Axon radius R | 3.15 um | 2-5 um |
| Intra-cellular fraction | 0.46 | 0.4-0.7 |
| D intra-cellular | 4.6e-10 m^2/s | 1-2e-9 |
| Fiber direction | [0.98, 0.14, 0.14] | — |
Data: sub-00395, 4 AxCaliber acquisitions (delta = 18/30/42/55 ms, b up to 15,500 s/mm^2). Training: 168 seconds on DGX Spark Grace CPU.
Recovers spatially-varying D(x) from CHARMED data (7 shells, b=0-6000 s/mm^2) using the Stejskal-Tanner equation with direction-dependent signal prediction. Not a PINN — no PDE residual, just physics-based signal model.
| Metric | Recovered | Expected WM |
|---|---|---|
| MD | 7.4e-10 m^2/s | ~0.7e-9 |
| FA | 0.42 | 0.4-0.7 |
Key insight: log-space loss (not MSE) is critical for correct MD.
| Test | Result |
|---|---|
| Microstructure.jl compartments (Cylinder, Zeppelin, Iso, Sphere) | PASS at 1e-13 |
| KomaMRI signal properties (1000 random configurations) | 1000/1000 PASS |
| Van Gelderen restricted diffusion (112 physics tests) | ALL PASS |
| Capability | Method | Status |
|---|---|---|
| Composable signal models | G1Ball, C1Stick, G2Zeppelin, S1Dot + multi-compartment | Tested, dmipy-compatible |
| Orientation dispersion | Watson distribution on Fibonacci sphere grid | Tested |
| Multi-compartment fitting | NLLS via Optimisers.jl | Working |
| Axon radius estimation | AxCaliber PINN (Van Gelderen) | Validated on real data |
| Diffusion tensor field | Neural field + Stejskal-Tanner | FA=0.42, MD correct |
| Score-based posteriors | Denoising score matching + DDPM | 12.8 deg orientation error |
| Forward model surrogate | Supervised MLP regression | 0.96% error, spec passed |
| Mixture Density Networks | Gaussian mixture posteriors (Lux.jl) | Tested |
| Simulation-Based Calibration | Rank histogram diagnostics (Talts et al.) | Tested |
| Conformal prediction | Distribution-free coverage guarantees | Tested |
| OOD detection | Reconstruction error + Mahalanobis + entropy | Tested |
| Posterior predictive checks | Model adequacy diagnostics | Tested |
| Native SDE/ODE samplers | DifferentialEquations.jl | EM 21k samples/s |
| Phase processing | ROMEO unwrapping, B0 mapping, T2*/R2* (MriResearchTools) | Validated on WAND 7T MEGRE |
| Optimal Experimental Design | Fisher information, D/A/E-optimality, CRLB | Phase 3a complete |
| FEM Bloch-Torrey | SpinDoctor.jl for validation + future differentiable inversion | Integrated |
| Compartment Models | Composition | Inference |
|---|---|---|
G1Ball (isotropic)C1Stick (intra-axonal)G2Zeppelin (extra-cellular)S1Dot (stationary water)
|
MultiCompartmentModelConstrainedModelWatsonDistributionDistributedModel
|
fit_mcm (NLLS)ScoreNetwork (FiLM)MixtureDensityNetworksample_posterior
|
| PINNs | Tensor Field Recovery | Validation |
AxCaliberDatabuild_axcaliber_pinntrain_axcaliber_pinn!BlochTorreyResidual
|
DiffusionFieldProblemsolve_diffusion_field_v2extract_maps (FA, MD)
|
SBC · Conformal · OOD · PPC KomaMRI oracle SpinDoctor.jl FEM oracle Microstructure.jl compat |
| Phase Processing | Experimental Design | Simulation Backends |
process_phaseROMEO unwrapping B0/T2*/R2* mapping Brain masking + bias correction |
fisher_informationcrlb (Cramer-Rao)D/A/E-optimality compare_protocols
|
MCMRSimulator (Monte Carlo) KomaMRI (Bloch sequences) SpinDoctor (FEM Bloch-Torrey) |
using Pkg
Pkg.develop(url="https://github.com/m9h/dmijl")GPU support auto-detects via LuxCUDA:
using DMI
dev = select_device() # auto-detects GPU or falls back to CPUusing DMI
# Build a Ball+Stick model (like dmipy)
mcm = MultiCompartmentModel([C1Stick(), G1Ball()])
# Add constraints
cm = ConstrainedModel(mcm)
set_fixed_parameter(cm, "G1Ball_lambda_iso", 3.0e-9) # fix CSF diffusivity
# Fit to observed signal
acq = load_acquisition("data.bval", "data.bvec")
result = fit_mcm(cm, acq, signal; n_restarts=5)watson = WatsonDistribution(; n_grid=300)
dm = DistributedModel(C1Stick(), watson)
mcm = MultiCompartmentModel([dm, G1Ball()])
# Fits: lambda_par, mu, kappa (dispersion), lambda_iso, volume fractionsusing DMI, Lux, Random
# Load multi-delta AxCaliber data
data = AxCaliberData(
signals, # 4 signal vectors (one per acquisition)
bvalues, # 4 b-value vectors
bvecs, # 4 gradient direction matrices
deltas, # [11e-3, 11e-3, 11e-3, 11e-3] (small delta)
Deltas, # [18e-3, 30e-3, 42e-3, 55e-3] (big delta)
)
# Build and train PINN
model = build_axcaliber_pinn(; signal_dim=264, hidden_dim=128, depth=5)
ps, st = Lux.setup(MersenneTwister(42), model)
ps, st, geom, losses = train_axcaliber_pinn!(model, ps, st, data;
n_steps=5000, lambda_physics=1.0)
# geom.R — axon radius (meters)
# geom.D_intra — intra-cellular diffusivity
# geom.f_intra — intra-cellular fraction
# geom.mu — fiber orientation (unit vector)using DMI
problem = DiffusionFieldProblem(signal, bvalues, bvecs, delta, Delta, T2, voxel_size)
result = solve_diffusion_field_v2(problem;
output_type = :diagonal,
n_steps = 5000,
)
maps = extract_maps(result; grid_resolution=8)
# maps.FA, maps.MDSBI4DWI (Python/JAX) — normalizing flow NPE and score-based posteriors for the same microstructure models. Achieves 2.8 deg median orientation error on Ball+2Stick with neural spline flows (300k steps), meeting the Nottingham paper target.
The two projects share:
- Same forward models (Ball+Stick, NODDI, DTI)
- Same WAND Connectom validation data
- Cross-validated against Microstructure.jl (Ting Gong, MGH/Martinos)
DMI.jl integrates three complementary physics simulation backends, each providing independent ground truth for different aspects of the diffusion signal:
| Backend | Method | Strengths | Use in DMI.jl |
|---|---|---|---|
| MCMRSimulator.jl | Monte Carlo random walk | Fast geometry sweeps, packed substrates | Training data generation |
| KomaMRI.jl | Bloch equation simulation | Arbitrary pulse sequences, GPU | Sequence-level signal validation |
| SpinDoctor.jl | FEM Bloch-Torrey PDE | Per-compartment D, membrane permeability, neuron morphologies | Ground truth for restricted diffusion |
SpinDoctor solves the Bloch-Torrey partial differential equation on tetrahedral FEM meshes, giving deterministic (noise-free) diffusion signals for arbitrary 3D tissue geometries. We maintain a modernized fork with Julia 1.12 compatibility, Makie as an optional extension, and split ODE packages.
This is the same core physics as ReMiDi (Khole et al. 2025, BioMedAI-UCSC) and its successor Spinverse (2026), which implement a differentiable PyTorch version of SpinDoctor to recover 3D axonal geometries (bending, beading, fanning fibers) directly from dMRI signals via backpropagation through the physics. Jing-Rebecca Li (INRIA), the original SpinDoctor author, is a co-author on both.
The Julia ecosystem is well-positioned to replicate this capability: Julia's AD tools (Enzyme.jl, Zygote.jl) can differentiate through the FEM assembly and ODE solve natively, potentially with less engineering effort than the PyTorch re-implementation. Making SpinDoctor.jl differentiable end-to-end is a natural direction for DMI.jl, connecting simulation-based inference (our score-based posteriors and MDNs) with mesh-level microstructure reconstruction.
| Package | Role |
|---|---|
| Lux.jl | Neural network layers (pure functional) |
| Zygote.jl | Automatic differentiation |
| DifferentialEquations.jl | SDE/ODE solvers for reverse diffusion |
| KomaMRI.jl | Bloch simulation validation oracle |
| MCMRSimulator.jl | Monte Carlo forward simulation (Cottaar/Jbabdi/Miller, FMRIB) |
| SpinDoctor.jl | FEM Bloch-Torrey PDE (Li et al., NeuroImage 2019) |
| ROMEO.jl | Phase unwrapping (Dymerska et al., MRM 2020) |
| MriResearchTools.jl | Masking, bias correction, T2*, coil combination |
| Microstructure.jl | Cross-validation reference (Ting Gong, MGH/Martinos) |
See results/LEADERBOARD.md for all results across
both DMI.jl and SBI4DWI, including autoresearch sweeps.
MIT
Built with Julia's SciML ecosystem. Validated on WAND Connectom data.