
andytimm/nutpieR


nutpieR

R bindings for the nuts-rs NUTS sampler, using BridgeStan for Stan model evaluation. It is the R counterpart of the official nutpie package for Python.

R is not in the hot loop -- parallel chains run entirely in Rust.

Installation

Precompiled binaries are available from R-universe (no Rust toolchain needed):

install.packages("nutpieR",
  repos = c("https://andytimm.r-universe.dev", "https://cloud.r-project.org"))

(CRAN is listed so that the imported packages digest, jsonlite, and posterior can be resolved; the andytimm.r-universe.dev repository hosts only nutpieR itself.)

Or install from source (requires a Rust toolchain):

remotes::install_github("andytimm/nutpieR")

System requirements

  • C++ toolchain (the same one Stan requires, since Stan models are compiled to C++ at runtime via BridgeStan):
    • Windows: Rtools
    • macOS: Xcode Command Line Tools (xcode-select --install)
    • Linux: build-essential (Debian/Ubuntu) or equivalent
  • Rust: rustc >= 1.85 and cargo (install via rustup); only needed when installing from source
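If you are installing from source, you can ask R whether the Rust requirement is met before running the install. The sketch below uses only base R. The helpers `parse_rustc_version()` and `check_toolchain()` are hypothetical names for illustration, not part of nutpieR, and the code assumes that `rustc --version` prints a line containing an `x.y.z` version number.

```r
# Hypothetical helper (not part of nutpieR): pull the first x.y.z version
# number out of a line such as "rustc 1.85.0 (...)". Returns NULL if none.
parse_rustc_version <- function(line) {
  m <- regmatches(line, regexpr("[0-9]+\\.[0-9]+\\.[0-9]+", line))
  if (length(m) == 0) return(NULL)
  package_version(m[1])
}

# Hypothetical helper: report whether rustc (>= min_rustc) and cargo are on the PATH.
check_toolchain <- function(min_rustc = "1.85") {
  paths <- Sys.which(c("rustc", "cargo"))  # "" for programs that are not found
  rustc_ok <- FALSE
  if (nzchar(paths[["rustc"]])) {
    out <- system2("rustc", "--version", stdout = TRUE)
    ver <- parse_rustc_version(out)
    # numeric_version objects compare correctly against version strings
    rustc_ok <- !is.null(ver) && ver >= min_rustc
  }
  c(rustc = rustc_ok, cargo = nzchar(paths[["cargo"]]))
}

check_toolchain()
```

For the C++ side, `pkgbuild::has_build_tools()` from the pkgbuild package performs a comparable check.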

Usage

library(nutpieR)

# Compile a Stan model (inline here for a runnable demo; .stan files work too)
model <- nutpie_compile_model(code = "
  data { int<lower=0> N; array[N] int<lower=0,upper=1> y; }
  parameters { real<lower=0,upper=1> theta; }
  model { theta ~ beta(1, 1); y ~ bernoulli(theta); }
")

# Sample
draws <- nutpie_sample(
  model,
  data = list(N = 10, y = c(0, 1, 0, 0, 0, 0, 0, 0, 0, 1)),
  num_draws = 1000,
  num_chains = 4,
  seed = 604
)

# Returns a posterior::draws_array, so posterior and bayesplot tools work on it directly
posterior::summarize_draws(draws)

After sampling:

dim(draws)                          # (num_draws, num_chains, n_params)
posterior::variables(draws)
attr(draws, "num_warmup")
nutpie_diagnostics(draws)           # divergences, treedepth, energy, ...
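The returned object is a three-dimensional array indexed as (iteration, chain, variable). To see that layout without compiling a model, the sketch below builds a stand-in `draws_array` from random numbers with posterior's `as_draws_array()`. The parameter names `mu` and `sigma` are hypothetical and the values are not real sampler output. Because `[` on a `draws_array` is type-preserving (it keeps all three dimensions), `extract_variable_matrix()` is a convenient way to get a single parameter as an iterations-by-chains matrix.

```r
library(posterior)

set.seed(604)
num_draws  <- 1000
num_chains <- 4
vars <- c("mu", "sigma")  # hypothetical parameter names

# Stand-in for nutpie_sample() output: random numbers, not posterior draws.
arr <- array(rnorm(num_draws * num_chains * length(vars)),
             dim = c(num_draws, num_chains, length(vars)),
             dimnames = list(NULL, NULL, vars))
draws <- as_draws_array(arr)

dim(draws)        # 1000 4 2, i.e. (num_draws, num_chains, n_params)
nchains(draws)    # 4
variables(draws)  # "mu" "sigma"

# One parameter as an iterations x chains matrix; colMeans() gives per-chain means.
mu <- extract_variable_matrix(draws, "mu")
dim(mu)           # 1000 4
colMeans(mu)
```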

In practice, you'll usually compile from a .stan file:

model <- nutpie_compile_model(stan_file = "my_model.stan")

Diagnostics

# Sampler diagnostics (divergences, treedepth, energy, etc.)
nutpie_diagnostics(draws)

# Access warmup draws (if save_warmup = TRUE)
dat <- list(N = 10, y = c(0, 1, 0, 0, 0, 0, 0, 0, 0, 1))
draws <- nutpie_sample(model, data = dat, save_warmup = TRUE)
nutpie_warmup_draws(draws)
nutpie_warmup_diagnostics(draws)

Sampling parameters

draws <- nutpie_sample(
  model,
  data = dat,
  num_draws = 1000,       # post-warmup draws per chain
  num_warmup = 1000,      # warmup draws per chain
  num_chains = 4,         # number of chains
  cores = 4,              # parallel cores
  seed = 604,             # RNG seed
  max_treedepth = 10,     # maximum tree depth
  target_accept = 0.8,    # target acceptance rate
  refresh = 100,          # progress every N draws (0 = off)
  save_warmup = FALSE,    # keep warmup draws (see nutpie_warmup_draws())
  store_divergences = FALSE,  # store details of divergent transitions
  store_mass_matrix = FALSE   # store the mass matrix
)

Low-rank mass matrix adaptation

For models with correlated parameters, nutpieR supports low-rank modified mass matrix adaptation from nuts-rs. This captures posterior correlations more effectively than the default diagonal mass matrix, and can significantly improve sampling efficiency on challenging geometries.

draws <- nutpie_sample(
  model,
  data = dat,
  adaptation = "low_rank",               # enable low-rank adaptation
  mass_matrix_gamma = 1e-5,              # regularisation (default)
  mass_matrix_eigval_cutoff = 2.0        # eigenvalue cutoff (default)
)

The older argument low_rank_modified_mass_matrix = TRUE still works but is deprecated. Mass-matrix and warmup defaults are inherited from nuts-rs; pass num_warmup explicitly to override the default warmup length.
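To build intuition for why a low-rank correction helps, here is a deliberately simplified illustration in base R. It is not the nuts-rs algorithm, which also uses gradient information. The helper `low_rank_directions()` is hypothetical, it treats `gamma` as a plain ridge term, and it assumes the eigenvalue cutoff means that directions with eigenvalues inside `[1/cutoff, cutoff]` are left to the diagonal. After rescaling each parameter to unit variance (what a diagonal mass matrix accounts for), whatever structure remains is correlation, and it shows up as eigenvalues far from 1.

```r
# Illustrative only, NOT the nuts-rs implementation: find the directions that a
# diagonal mass matrix misses, using draws alone.
low_rank_directions <- function(draws, eigval_cutoff = 2.0, gamma = 1e-5) {
  # Rescale each column to unit variance: the part a diagonal mass matrix captures.
  z <- scale(draws)
  # The remaining covariance is a correlation matrix; add a small ridge term.
  S <- cov(z) + gamma * diag(ncol(z))
  e <- eigen(S, symmetric = TRUE)
  # Eigenvalues near 1 are already handled by the diagonal; keep the rest.
  keep <- e$values > eigval_cutoff | e$values < 1 / eigval_cutoff
  list(values = e$values[keep], vectors = e$vectors[, keep, drop = FALSE])
}

# Two strongly correlated parameters (correlation about 0.9).
set.seed(604)
n   <- 2000
rho <- 0.9
x1  <- rnorm(n)
x2  <- rho * x1 + sqrt(1 - rho^2) * rnorm(n)
res <- low_rank_directions(cbind(x1, x2))
res$values
```

With a correlation near 0.9 the eigenvalues are roughly 1.9 and 0.1. Under the assumed cutoff rule only the narrow direction (about 0.1) falls outside [0.5, 2] and is kept, while independent parameters produce no retained directions.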

How it works

nutpieR compiles Stan models via the BridgeStan Rust crate and samples using the nuts-rs NUTS sampler. During sampling, Rust calls the compiled Stan shared library directly through BridgeStan's C ABI -- R is not involved in the sampling loop. Each chain runs on its own thread via rayon.

Results are transferred from Rust to R via Apache Arrow with a single copy into R-allocated memory (no extra intermediate buffer), and returned as a posterior::draws_array.

Compiled artifacts are cached (following cmdstanr's convention), so repeat compilation calls for the same model return in under a second. See ?nutpie_compile_model for cache controls and ?nutpie_clear_cache to invalidate the cache.
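The caching idea can be sketched in a few lines of base R. This is a minimal sketch under stated assumptions, not nutpieR's implementation: it assumes the cache key is a hash of the Stan source (here an MD5 via `tools::md5sum()`), stores results as `.rds` files in a temporary directory, and takes a hypothetical `compile_fn` in place of the real compiler. `cached_compile()` is likewise a hypothetical name.

```r
# Hypothetical sketch (not nutpieR's code): memoize an expensive compile step
# on a hash of the model source.
cached_compile <- function(code, compile_fn,
                           cache_dir = file.path(tempdir(), "stan-cache")) {
  dir.create(cache_dir, showWarnings = FALSE, recursive = TRUE)
  # Hash the source text by writing it to a temporary file.
  src <- tempfile(fileext = ".stan")
  on.exit(unlink(src))
  writeLines(code, src)
  key <- unname(tools::md5sum(src))
  artifact <- file.path(cache_dir, paste0(key, ".rds"))
  if (file.exists(artifact)) {
    return(readRDS(artifact))  # cache hit: skip compilation
  }
  result <- compile_fn(code)   # cache miss: compile once...
  saveRDS(result, artifact)    # ...and store the result under its hash
  result
}

# Usage with a stand-in compiler that counts how often it runs.
counter <- new.env()
counter$n <- 0
fake_compile <- function(code) {
  counter$n <- counter$n + 1
  list(code = code)
}
code <- "parameters { real theta; } model { theta ~ normal(0, 1); }"
m1 <- cached_compile(code, fake_compile)
m2 <- cached_compile(code, fake_compile)  # served from the cache
counter$n  # 1
```

A real cache would likely also need to account for anything else that affects the build, such as compiler settings, which is one reason to use the package's own cache controls rather than a hand-rolled scheme.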

License

MIT
