glmhmm: Generalized Linear Model Hidden Markov Models in JAX
glmhmm is a Python package for fitting GLM-HMMs using JAX. It allows latent state transitions and emission probabilities to be driven by time-varying external covariates (GLMs), making it ideal for analyzing complex time-series data such as neural recordings or behavioral choices.
pip install git+https://github.com/YSanchezAraujo/glmhmm.git

The core data structure is `GLMHMMData`, which holds lists of arrays. This design supports fitting multiple independent sequences (e.g., different experimental sessions) simultaneously.
- Lists of Arrays: All fields are `List[jnp.ndarray]`, where each element corresponds to one sequence (session).
- Matching Lengths: For a given sequence $i$, all active fields (transition inputs, emission inputs, observations) must have compatible time dimensions $T_i$.
- Missing Data: If you don't have data for a specific modality (e.g., no Gaussian data), the corresponding fields default to `None` and are ignored during fitting.
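The list-of-arrays layout and the matching-length rule can be illustrated with a small sketch. It uses plain NumPy so it runs without the package installed; `check_session_lengths` is a hypothetical helper written for this example, not part of glmhmm. Since `jnp.ndarray` exposes the same `.shape` attribute, the same check should carry over to JAX arrays.

```python
import numpy as np

def check_session_lengths(x_list, y_list):
    """Hypothetical helper: confirm each session's inputs and observations share a time dimension."""
    if len(x_list) != len(y_list):
        raise ValueError("x_list and y_list must contain the same number of sessions")
    lengths = []
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        if x.shape[0] != y.shape[0]:
            raise ValueError(f"session {i}: x has {x.shape[0]} rows but y has {y.shape[0]}")
        lengths.append(x.shape[0])
    return lengths

rng = np.random.default_rng(0)

# Two sessions with their own lengths T_0 = 500 and T_1 = 320.
# Within a session, inputs and observations share the same T_i.
x_bern = [rng.normal(size=(500, 2)), rng.normal(size=(320, 2))]
y_bern = [rng.integers(0, 2, size=500), rng.integers(0, 2, size=320)]

session_lengths = check_session_lengths(x_bern, y_bern)
print(session_lengths)
```

Running a check like this before constructing `GLMHMMData` is one way to catch a misaligned session early, rather than during fitting.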
| Field | Shape per Sequence | Description |
|---|---|---|
| `x_trans` | | Covariates driving state transitions. |
| `x_bern` | | Covariates for Bernoulli emissions. |
| `y_bern` | | Binary observations (0 or 1). |
| `x_multinom` | | Covariates for Multinomial emissions. |
| `y_multinom` | | Integer class labels. |
| `x_negbin` | | Covariates for Negative Binomial emissions. |
| `y_negbin` | | Integer counts. |
| `x_gamma` | | Covariates for Gamma emissions. |
| `y_gamma` | | Positive continuous values. |
| `x_gauss` | | Covariates for Gaussian emissions (e.g., aligned to neural samples). |
| `y_gauss` | | Continuous multivariate observations. |

Note: For Gaussian data, we support "variable rate" alignment, where multiple Gaussian samples map to a single HMM state step. See `gauss_first_idx` / `gauss_last_idx` in `test_fit_compatibility.py` for advanced usage. For the standard 1-to-1 mapping, these indices are not needed.

Note: The model treats transition inputs as driving the step $t \rightarrow t+1$. Align your transition inputs to these steps: row $t$ holds the covariates for the transition $t \rightarrow t+1$, so a sequence with $T_i$ observations has $T_i - 1$ rows of transition inputs.

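The transition alignment can be sketched as follows. This is a minimal NumPy illustration, not package code: `align_transition_inputs` is a hypothetical helper, and it assumes the covariate recorded at step $t$ is the one that drives $t \rightarrow t+1$. If your design uses the covariate at $t+1$ instead, slice with `covariates[1:]`.

```python
import numpy as np

def align_transition_inputs(covariates):
    """Hypothetical helper: keep rows 0..T-2 so that row t drives the transition t -> t+1."""
    covariates = np.asarray(covariates)
    if covariates.shape[0] < 2:
        raise ValueError("need at least two time steps to define a transition")
    # Drop the final row: there is no transition out of the last observation.
    return covariates[:-1]

T, D = 6, 2
per_step = np.arange(T * D, dtype=float).reshape(T, D)  # one covariate row per observation

x_trans = align_transition_inputs(per_step)
print(per_step.shape, x_trans.shape)
```

The result has one row fewer than the number of observations, which matches the $T_i - 1$ requirement for transition inputs.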
Modeling an agent's binary choices (Left/Right) driven by Bias and Stimulus strength.
import jax
import jax.numpy as jnp
from glmhmm import fit_glmhmm
from glmhmm.structs import GLMHMMData, GLMHMMConfig
from glmhmm.utils import simulate_data_bernoulli_glmhmm
from glmhmm.m_step import compute_A_from_theta_and_inputs
from glmhmm.fit import e_step
key = jax.random.PRNGKey(42)
N_states = 3
T = 1000
N_bern_features = 2
N_trans_features = 2
# Generate synthetic data
true_params, sim_data = simulate_data_bernoulli_glmhmm(
    key, N_states, T, N_bern_features, N_trans_features
)
# Unwrap simulated data for GLMHMMData
# x_inputs from utils is a list of arrays (one per step), so we concatenate
# to get a single array of shape (T-1, D) for the session.
x_trans_concat = jnp.concatenate(sim_data.x_inputs, axis=0)
data = GLMHMMData(
    x_bern=[sim_data.x_bern],
    y_bern=[sim_data.y_bern],
    x_trans=[x_trans_concat],  # Transitions driven by inputs (length T-1)
)
config = GLMHMMConfig(
    num_states=N_states,
    num_em_iters=100,
    l2_reg_bern=1.0,   # Regularization for emission weights
    l2_reg_trans=0.5,  # Regularization for transition weights
    tol=1e-6,
)
params, lml_history = fit_glmhmm(data, config)
print(f"Converged LML: {lml_history[-1]}")
print("Bernoulli Weights (D x K):")
print(params.W_bern)
# Get the posteriors
gammas, xis, final_lml = e_step(params, data)
# Compute the input-driven transition matrix
A_timevar = compute_A_from_theta_and_inputs(x_trans_concat, params.theta)

Modeling choices between 3 options (e.g., Left, Right, No-Go).
# X_session1 and y_session1 are your own arrays for one session:
# X_session1 shape: (T, D)
# y_session1 shape: (T,) with values in {0, 1, 2}
data_multi = GLMHMMData(
    x_multinom=[X_session1],
    y_multinom=[y_session1],
)
config_multi = GLMHMMConfig(
    num_states=2,
    num_em_iters=50,
    l2_reg_multinom=1.0,
)
params, _ = fit_glmhmm(data_multi, config_multi)
# W_multinom shape: (D, C, K)
# The last category (C-1) is fixed to 0 for identifiability.
print(params.W_multinom.shape)

Compare models with different numbers of states.
from glmhmm.model_selection_factorized_bq import compute_bic
for k in [2, 3, 4]:
    cfg = GLMHMMConfig(num_states=k, num_em_iters=100)
    params, h = fit_glmhmm(data, cfg)
    # Compute BIC
    bic = compute_bic(data, params, max_log_lik=h[-1])
    print(f"K={k}, BIC={bic:.2f}")

For rigorous model comparison, use `compute_factorized_evidence` to estimate the marginal likelihood.
from glmhmm.fit import e_step
from glmhmm.model_selection_factorized_bq import compute_factorized_evidence
# 1. Get posteriors from fitted model
gammas, xis, _ = e_step(params, data)
# 2. Run BQ
log_ev, var = compute_factorized_evidence(
    data, params, config, gammas, xis
)
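
For context on the BIC comparison earlier in this section, the textbook definition is $\mathrm{BIC} = p \ln n - 2 \ln \hat{L}$, where $p$ is the number of free parameters, $n$ the number of observations, and $\ln \hat{L}$ the maximized log-likelihood. The sketch below implements only that formula. `bic_from_log_lik` is a hypothetical helper, and how `compute_bic` counts parameters or observations is not shown here, so treat this as an illustration rather than the package's implementation. The numbers are made up for the example and are not results from any fit.

```python
import math

def bic_from_log_lik(max_log_lik, num_params, num_obs):
    """Textbook BIC: p * ln(n) - 2 * ln(L_hat). Lower is better."""
    return num_params * math.log(num_obs) - 2.0 * max_log_lik

# Illustrative values only: {num_states: (max log-likelihood, free parameter count)}
candidates = {2: (-650.0, 10), 3: (-640.0, 18)}
num_obs = 1000

scores = {k: bic_from_log_lik(ll, p, num_obs) for k, (ll, p) in candidates.items()}
best_k = min(scores, key=scores.get)

for k, score in scores.items():
    print(f"K={k}, BIC={score:.2f}")
print(f"Preferred K under BIC: {best_k}")
```

In this toy case the larger model gains 10 nats of log-likelihood but pays for 8 extra parameters, so the smaller model has the lower BIC.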