Official implementation of the ICML 2025 paper (poster) "Bayesian Basis Function Approximation for Scalable Gaussian Process Priors in Deep Generative Models".
Overview of DGBFGP. (a) A conditional generative model incorporates covariate information. (b) Additive GP priors are used but are not scalable. (c) We replace exact GP with a basis function approximation for scalability. (d) Samples are generated by the decoder.
High-dimensional time-series datasets are common in domains such as healthcare and economics. Variational autoencoder (VAE) models, where latent variables are modeled with a Gaussian process (GP) prior, have become a prominent model class to analyze such correlated datasets. However, their applications are challenged by the inherent cubic time complexity that requires specific GP approximation techniques, as well as the general challenge of modeling both shared and individual-specific correlations across time. Though inducing points enhance GP prior VAE scalability, optimizing them remains challenging, especially since discrete covariates resist gradient‑based methods. In this work, we propose a scalable basis function approximation technique for GP prior VAEs that mitigates these challenges and results in linear time complexity, with a global parametrization that eliminates the need for amortized variational inference and the associated amortization gap, making it well-suited for conditional generation tasks where accuracy and efficiency are crucial. Empirical evaluations on synthetic and real-world benchmark datasets demonstrate that our approach not only improves scalability and interpretability but also drastically enhances predictive performance.
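To make the idea of a basis function approximation concrete, here is a minimal NumPy sketch for a one-dimensional squared-exponential (SE) kernel. It assumes the Hilbert-space (Laplacian eigenfunction) construction, which is one standard choice; the abstract above does not state which basis the paper uses, so this is an illustration and not the repository's code. The names `se_spectral_density`, `basis_features`, `approx_se_kernel`, `num_basis`, and `half_width` are hypothetical.

```python
import numpy as np


def se_spectral_density(omega, lengthscale=1.0, variance=1.0):
    """Spectral density of the 1-D squared-exponential kernel."""
    return (variance * np.sqrt(2.0 * np.pi) * lengthscale
            * np.exp(-0.5 * (lengthscale * omega) ** 2))


def basis_features(x, num_basis, half_width):
    """Evaluate Laplacian eigenfunctions on [-half_width, half_width].

    Returns the (N, M) feature matrix and the M square-root eigenvalues.
    """
    j = np.arange(1, num_basis + 1)
    sqrt_eigvals = np.pi * j / (2.0 * half_width)
    phi = np.sqrt(1.0 / half_width) * np.sin(sqrt_eigvals * (x[:, None] + half_width))
    return phi, sqrt_eigvals


def approx_se_kernel(x, num_basis=64, half_width=5.0, lengthscale=1.0, variance=1.0):
    """Low-rank kernel approximation: Phi diag(S(sqrt(lambda))) Phi^T."""
    phi, sqrt_eigvals = basis_features(x, num_basis, half_width)
    weights = se_spectral_density(sqrt_eigvals, lengthscale, variance)
    return (phi * weights) @ phi.T


# Compare against the exact SE kernel (unit lengthscale and variance).
x = np.linspace(-1.0, 1.0, 50)
K_exact = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)
K_approx = approx_se_kernel(x)
max_abs_err = np.abs(K_exact - K_approx).max()
```

The feature matrix has shape (N, M), so for a fixed number of basis functions M the cost of building it grows linearly with the number of inputs N. The full N x N matrix is formed here only to check the approximation against the exact kernel.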
To install requirements:
conda env create -f environment.yml
First, create a directory to store all datasets:
mkdir ../data
Download the MNIST digits and unzip them into the data folder.
Clone the original SPRITES repository and follow its data generation steps.
- Rotated MNIST
python Rotated_MNIST_generate.py
- Health MNIST
python Health_MNIST_generate.py
- Physionet
python Physionet_generate.py
- SPRITES
python SPRITES_generate.py
To train DGBFGP on different datasets, run these commands:
- Rotated MNIST
python main.py --f=./config/config_RotatedMNIST.txt
- Health MNIST
python main.py --f=./config/config_HealthMNIST.txt
- Physionet
python main.py --f=./config/config_Physionet.txt
- SPRITES
python main.py --f=./config/config_SPRITES.txt
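If you want to train on all four datasets in sequence, the commands above can be assembled programmatically. The following is a small sketch, not part of the repository; the helper name `training_commands` is hypothetical, and the config file names are taken from the commands listed above.

```python
# Dataset names as they appear in the config file names above.
DATASETS = ["RotatedMNIST", "HealthMNIST", "Physionet", "SPRITES"]


def training_commands(config_dir="./config"):
    """Build one argument list per dataset, mirroring the commands above."""
    return [
        ["python", "main.py", f"--f={config_dir}/config_{name}.txt"]
        for name in DATASETS
    ]


commands = training_commands()
for cmd in commands:
    # Print instead of executing so the sketch is safe to run anywhere.
    print(" ".join(cmd))
```

To actually launch the runs, each argument list could be passed to `subprocess.run(cmd, check=True)` from the repository root.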
The config files in the config folder contain the hyperparameter and modeling choices used for training.
| Argument | Explanation |
|---|---|
| train_data_source_path | Path to training data |
| val_data_source_path | Path to validation data |
| test_data_source_path | Path to test data |
| csv_file_data | File that contains data (with missing values if any) |
| csv_file_data_gt | File that contains ground truth data |
| csv_file_label | File that contains auxiliary covariates |
| mask_file | File that contains the mask for missing values (if any) |
| dataset_type | Name of the dataset |
| P | Number of subjects |
| T | Number of time points |
| y_num_dim | Data space dimensionality |
| x_num_dim | Auxiliary data space dimensionality |
| M | Number of basis functions |
| C | Number of categories for the categorical aux. variable in each interaction kernel, e.g., [2, 4] |
| latent_dim | Latent space dimensionality |
| se_idx | Indices of auxiliary dimensions to use for SE kernel |
| interactions | List of interactions between auxiliary dimensions, e.g., [[0, 3], [1, 4]] |
| id_covariate | Index of the id covariate |
| lr | Learning rate |
| n_epoch | Number of epochs |
| batch_size | Batch size |
| output_dir | Output directory |
| model_file | File to save the trained model |
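Several of the arguments in the table refer to one another: `se_idx` and `interactions` index into the auxiliary dimensions, and `C` is described as giving a category count for each interaction kernel. The sketch below checks those relationships on a plain dictionary before training starts. It is an illustration under stated assumptions: the function `check_config` is hypothetical, the example values are made up, indices are assumed to be zero-based and smaller than `x_num_dim`, and `C` is assumed to have one entry per interaction. How the repository itself parses its config files is not shown here.

```python
def check_config(cfg):
    """Return a list of human-readable problems found in a config dict."""
    problems = []
    x_dim = cfg["x_num_dim"]

    # Assumption: se_idx entries are zero-based indices into the auxiliary dimensions.
    for idx in cfg["se_idx"]:
        if not 0 <= idx < x_dim:
            problems.append(f"se_idx entry {idx} is outside [0, {x_dim})")

    # Assumption: each interaction lists indices of auxiliary dimensions.
    for pair in cfg["interactions"]:
        for idx in pair:
            if not 0 <= idx < x_dim:
                problems.append(f"interaction {pair} uses index {idx} outside [0, {x_dim})")

    # Assumption: C holds one category count per interaction kernel.
    if len(cfg["C"]) != len(cfg["interactions"]):
        problems.append("C should have one entry per interaction")

    return problems


# Illustrative values only; they do not come from any shipped config file.
example_cfg = {
    "x_num_dim": 5,
    "se_idx": [0, 1],
    "interactions": [[0, 3], [1, 4]],
    "C": [2, 4],
}
problems = check_config(example_cfg)
```

An empty list means the checked fields are mutually consistent under the assumptions above; any other result lists what to fix before launching a run.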
@inproceedings{
balk2025bayesian,
title={Bayesian Basis Function Approximation for Scalable Gaussian Process Priors in Deep Generative Models},
author={Mehmet Yi{\u{g}}it Bal{\i}k and Maksim Sinelnikov and Priscilla Ong and Harri L{\"a}hdesm{\"a}ki},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=dewZTXKwli}
}
This project is licensed under the MIT License.