diff --git a/.conda/bin/build b/.conda/bin/build index 4171700db..1967ddb7f 100755 --- a/.conda/bin/build +++ b/.conda/bin/build @@ -5,7 +5,7 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" python_version=$1 if [ -z "$python_version" ] then - python_version=3.6 + python_version=3.7 fi export BRAINIAK_HOME=$DIR/../../ diff --git a/.conda/build.sh b/.conda/build.sh index c2a562cbf..68ae389f1 100755 --- a/.conda/build.sh +++ b/.conda/build.sh @@ -1,6 +1,8 @@ #!/bin/bash -# Install pymanopt via pip because there isn't a conda package +# Install from PyPI because there is no current conda package for the +# following. Explicitly install dependencies with no conda package as well +# because otherwise conda-build does not include them in the output package. PIP_NO_INDEX=False $PYTHON -m pip install pymanopt # NOTE: This is the recommended way to install packages diff --git a/.conda/meta.yaml b/.conda/meta.yaml index bbd2c4305..971ee237f 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -42,10 +42,6 @@ requirements: if req not in conda_package_nonexistent -%} - {{req}} {% endfor %} - {% for req in data.get('install_requires', []) - if req not in conda_package_nonexistent -%} - - {{req}} - {% endfor %} run: - python @@ -59,6 +55,7 @@ requirements: test: commands: - find $BRAINIAK_HOME/tests | grep pycache | xargs rm -rf + - pip install tensorflow tensorflow-probability - mpiexec -n 2 pytest $BRAINIAK_HOME # Known issue: https://github.com/travis-ci/travis-ci/issues/4704#issuecomment-348435959 diff --git a/README.rst b/README.rst index e4338eaae..0efdb9c61 100644 --- a/README.rst +++ b/README.rst @@ -33,7 +33,14 @@ If you have `Conda `_:: Otherwise, or if you want to compile from source, install the requirements (see `docs/installation`) and then install from PyPI:: - python3 -m pip install --no-use-pep517 brainiak + python3 -m pip install brainiak + +Note that to use the ``brainiak.matnormal`` package, you need to install +additional dependencies. 
As of October 2020, the required versions are not +available as Conda packages, so you should install from PyPI, even when using +Conda:: + + python3 -m pip install -U tensorflow tensorflow-probability Note that we do not support Windows. diff --git a/brainiak/funcalign/sssrm.py b/brainiak/funcalign/sssrm.py index 1fe3bd073..b43e94288 100644 --- a/brainiak/funcalign/sssrm.py +++ b/brainiak/funcalign/sssrm.py @@ -40,6 +40,8 @@ from pymanopt.solvers import ConjugateGradient from pymanopt import Problem from pymanopt.manifolds import Stiefel +import pymanopt + import gc from brainiak.utils import utils @@ -57,6 +59,10 @@ # https://github.com/pymc-devs/pymc3/pull/3767 theano.config.gcc.cxxflags = "-Wno-c++11-narrowing" +# FIXME workaround for pymanopt only working with tensorflow 1. +# We don't use pymanopt+TF so we just let pymanopt pretend TF doesn't exist. +pymanopt.tools.autodiff._tensorflow.tf = None + class SSSRM(BaseEstimator, ClassifierMixin, TransformerMixin): """Semi-Supervised Shared Response Model (SS-SRM) diff --git a/brainiak/matnormal/__init__.py b/brainiak/matnormal/__init__.py new file mode 100644 index 000000000..458caf473 --- /dev/null +++ b/brainiak/matnormal/__init__.py @@ -0,0 +1,252 @@ +""" +Some properties of the matrix-variate normal distribution +--------------------------------------------------------- + +.. 
math:: + \\DeclareMathOperator{\\Tr}{Tr} + \\newcommand{\\trp}{^{T}} % transpose + \\newcommand{\\trace}{\\text{Trace}} % trace + \\newcommand{\\inv}{^{-1}} + \\newcommand{\\mb}{\\mathbf{b}} + \\newcommand{\\M}{\\mathbf{M}} + \\newcommand{\\C}{\\mathbf{C}} + \\newcommand{\\G}{\\mathbf{G}} + \\newcommand{\\A}{\\mathbf{A}} + \\newcommand{\\R}{\\mathbf{R}} + \\renewcommand{\\S}{\\mathbf{S}} + \\newcommand{\\B}{\\mathbf{B}} + \\newcommand{\\Q}{\\mathbf{Q}} + \\newcommand{\\mH}{\\mathbf{H}} + \\newcommand{\\U}{\\mathbf{U}} + \\newcommand{\\mL}{\\mathbf{L}} + \\newcommand{\\diag}{\\mathrm{diag}} + \\newcommand{\\etr}{\\mathrm{etr}} + \\renewcommand{\\H}{\\mathbf{H}} + \\newcommand{\\vecop}{\\mathrm{vec}} + \\newcommand{\\I}{\\mathbf{I}} + \\newcommand{\\X}{\\mathbf{X}} + \\newcommand{\\Y}{\\mathbf{Y}} + \\newcommand{\\Z}{\\mathbf{Z}} + \\renewcommand{\\L}{\\mathbf{L}} + + +The matrix-variate normal distribution is a generalization to matrices of the +normal distribution. Another name for it is the multivariate normal +distribution with kronecker separable covariance. +The distributional intuition is as follows: if +:math:`X \\sim \\mathcal{MN}(M,R,C)` then +:math:`\\mathrm{vec}(X)\\sim\\mathcal{N}(\\mathrm{vec}(M), C \\otimes R)`, +where :math:`\\mathrm{vec}(\\cdot)` is the vectorization operator and +:math:`\\otimes` is the Kronecker product. If we think of X as a matrix of TRs +by voxels in the fMRI setting, then this model assumes that each voxel has the +same TR-by-TR covariance structure (represented by the matrix R), +and each volume has the same spatial covariance (represented by the matrix C). +This assumption allows us to model both covariances separately. +We can assume that the spatial covariance itself is kronecker-structured, +which means the spatial covariance of voxels is separable across the X, +Y and Z dimensions. + +The log-likelihood for the matrix-normal density is: + +.. 
math:: + 2\\log p(X\\mid \\M,\\R, \\C) = -mn\\log 2\\pi - m \\log|\\C| - n \\log|\\R| - + \\Tr\\left[\\C\\inv(\\X-\\M)\\trp\\R\\inv(\\X-\\M)\\right] + +Here :math:`X` and :math:`M` are both :math:`m\\times n` matrices, :math:`\\R` +is :math:`m\\times m` and :math:`\\C` is :math:`n\\times n`. + +The `brainiak.matnormal` package provides structure to infer models that +can be stated in the matrix-normal notation that are useful for fMRI analysis. +It provides a few interfaces. `MatnormModelBase` is intended as a +base class for matrix-variate models. It provides a wrapper for the tensorflow +optimizer that provides convergence checks based on thresholds on the function +value and gradient, and simple verbose outputs. It also provides an interface +for noise covariances (`CovBase`). Any class that follows the interface +can be used as a noise covariance in any of the matrix normal models. The +package includes a variety of noise covariances to work with. + +Matrix normal marginals +------------------------- + +Here we extend the multivariate gaussian marginalization identity to matrix +normals. This is used in a number of the models in the package. Below, we +use lowercase subscripts for sizes to make dimensionalities easier to track. +Uppercase subscripts for covariances help keep track of where they come from. + +.. math:: + \\mathbf{X}_{ij} &\\sim \\mathcal{MN}(\\mathbf{A}_{ij}, + \\Sigma_{\\mathbf{X}_i},\\Sigma_{\\mathbf{X}_j})\\\\ + \\mathbf{Y}_{jk} &\\sim \\mathcal{MN}(\\mathbf{B}_{jk}, + \\Sigma_{\\mathbf{Y}_j},\\Sigma_{\\mathbf{Y}_k})\\\\ + \\mathbf{Z}_{ik}\\mid\\mathbf{X}_{ij},\\mathbf{Y}_{jk} &\\sim + \\mathcal{MN}(\\mathbf{X}_{ij}\\mathbf{Y}_{jk} + \\mathbf{C}_{ik}, + \\Sigma_{\\mathbf{Z}_i}, \\Sigma_{\\mathbf{Z}_k})\\\\ + + +We vectorize, and convert to a form we recognize as +:math:`y \\sim \\mathcal{N}(Mx+b, \\Sigma)`. + +.. 
math:: + \\vecop(\\mathbf{Z}_{ik})\\mid\\mathbf{X}_{ij},\\mathbf{Y}_{jk} &\\sim + \\mathcal{N}(\\vecop(\\X_{ij}\\mathbf{Y}_{jk}+\\mathbf{C}_{ik}), + \\Sigma_{\\mathbf{Z}_k}\\otimes\\Sigma_{\\mathbf{Z}_i})\\\\ + \\vecop(\\mathbf{Z}_{ik})\\mid\\mathbf{X}_{ij},\\mathbf{Y}_{jk} + &\\sim \\mathcal{N}((\\I_k\\otimes\\X_{ij})\\vecop(\\mathbf{Y}_{jk}) + + \\vecop(\\mathbf{C}_{ik}), + \\Sigma_{\\mathbf{Z}_k}\\otimes\\Sigma_{\\mathbf{Z}_i}) + + +Now we can use the standard Gaussian marginalization identity: + +.. math:: + \\vecop(\\mathbf{Z}_{ik})\\mid\\mathbf{X}_{ij} \\sim + \\mathcal{N}((\\I_k\\otimes\\X_{ij})\\vecop(\\mathbf{B}_{jk}) + + \\vecop(\\mathbf{C}_{ik}), + \\Sigma_{\\mathbf{Z}_k}\\otimes\\Sigma_{\\mathbf{Z}_i} + + (\\I_k\\otimes\\X_{ij})(\\Sigma_{\\mathbf{Y}_k}\\otimes + \\Sigma_{\\mathbf{Y}_j})(\\I_k\\otimes\\X_{ij})\\trp ) + + +Collect terms using the mixed-product property of kronecker products: + +.. math:: + \\vecop(\\mathbf{Z}_{ik})\\mid\\mathbf{X}_{ij} \\sim + \\mathcal{N}(\\vecop(\\X_{ij}\\mathbf{B}_{jk}) + + \\vecop(\\mathbf{C}_{ik}), \\Sigma_{\\mathbf{Z}_k}\\otimes + \\Sigma_{\\mathbf{Z}_i} + \\Sigma_{\\mathbf{Y}_k}\\otimes + \\X_{ij}\\Sigma_{\\mathbf{Y}_j}\\X_{ij}\\trp) + + +Now, we can see that the marginal density is a matrix-variate normal only if +:math:`\\Sigma_{\\mathbf{Z}_k}= \\Sigma_{\\mathbf{Y}_k}` -- that is, the +variable we're marginalizing over has the same covariance in the dimension +we're *not* marginalizing over as the marginal density. Otherwise the density +is defined, but its covariance is a sum of two Kronecker products. So we let +:math:`\\Sigma_k:=\\Sigma_{\\mathbf{Z}_k}= \\Sigma_{\\mathbf{Y}_k}`, factor, +and transform it back into a matrix normal: + +.. 
math:: + \\vecop(\\mathbf{Z}_{ik})\\mid\\mathbf{X}_{ij} &\\sim + \\mathcal{N}(\\vecop(\\X\\mathbf{B}_{jk}) + \\vecop(\\mathbf{C}_{ik}), + \\Sigma_{k}\\otimes\\Sigma_{\\mathbf{Z}_i} + \\Sigma_{k}\\otimes + \\X\\Sigma_{\\mathbf{Y}_j}\\X\\trp)\\\\ + \\vecop(\\mathbf{Z}_{ik})\\mid\\mathbf{X}_{ij} &\\sim + \\mathcal{N}(\\vecop(\\X\\mathbf{B}_{jk}) + \\vecop(\\mathbf{C}_{ik}), + \\Sigma_{k}\\otimes(\\Sigma_{\\mathbf{Z}_i} + +\\X\\Sigma_{\\mathbf{Y}_j}\\X\\trp))\\\\ + \\mathbf{Z}_{ik}\\mid\\mathbf{X}_{ij} &\\sim + \\mathcal{MN}(\\X\\mathbf{B}_{jk} + \\mathbf{C}_{ik}, + \\Sigma_{\\mathbf{Z}_i} +\\X\\Sigma_{\\mathbf{Y}_j}\\X\\trp,\\Sigma_{k}) + + +We can do it in the other direction as well, because if +:math:`\\X \\sim \\mathcal{MN}(M, U, V)` then :math:`\\X\\trp \\sim +\\mathcal{MN}(M\\trp, V, U)`: + +.. math:: + \\mathbf{Z\\trp}_{ik}\\mid\\mathbf{X}_{ij},\\mathbf{Y}_{jk} &\\sim + \\mathcal{MN}(\\mathbf{Y}_{jk}\\trp\\mathbf{X}_{ij}\\trp + + \\mathbf{C}\\trp_{ik}, \\Sigma_{\\mathbf{Z}_k},\\Sigma_{\\mathbf{Z}_i})\\\\ + \\mbox{let } \\Sigma_i := + \\Sigma_{\\mathbf{Z}_i}=\\Sigma_{\\mathbf{X}_i} \\\\ + \\cdots\\\\ + \\mathbf{Z\\trp}_{ik}\\mid\\mathbf{Y}_{jk} &\\sim + \\mathcal{MN}(\\mathbf{Y}_{jk}\\trp\\mathbf{A}_{ij}\\trp + + \\mathbf{C}\\trp_{ik}, \\Sigma_{\\mathbf{Z}_k} + + \\Y\\trp\\Sigma_{\\mathbf{X}_j}\\Y,\\Sigma_{i})\\\\ + \\mathbf{Z}_{ik}\\mid\\mathbf{Y}_{jk} &\\sim + \\mathcal{MN}(\\mathbf{A}_{ij}\\mathbf{Y}_{jk}+ + \\mathbf{C}_{ik},\\Sigma_{i},\\Sigma_{\\mathbf{Z}_k} + + \\Y\\trp\\Sigma_{\\mathbf{X}_j}\\Y) + +These marginal likelihoods are implemented relatively efficiently in +`MatnormModelBase.matnorm_logp_marginal_row` and +`MatnormModelBase.matnorm_logp_marginal_col`. + +Partitioned matrix normal conditionals +-------------------------------------- + +Here we extend the multivariate gaussian conditional identity to matrix +normals. This is used for prediction in some models.
Below, we +use lowercase subscripts for sizes to make dimensionalities easier to track. +Uppercase subscripts for covariances help keep track where they come from. + + +Next, we do the same for the partitioned gaussian identity. First two +vectorized matrix-normals that form our partition: + +.. math:: + \\mathbf{X}_{ij} &\\sim \\mathcal{MN}(\\mathbf{A}_{ij}, \\Sigma_{i}, + \\Sigma_{j}) \\rightarrow \\vecop[\\mathbf{X}_{ij}] \\sim + \\mathcal{N}(\\vecop[\\mathbf{A}_{ij}], \\Sigma_{j}\\otimes\\Sigma_{i})\\\\ + \\mathbf{Y}_{ik} &\\sim \\mathcal{MN}(\\mathbf{B}_{ik}, \\Sigma_{i}, + \\Sigma_{k}) \\rightarrow \\vecop[\\mathbf{Y}_{ik}] \\sim + \\mathcal{N}(\\vecop[\\mathbf{B}_{ik}], \\Sigma_{k}\\otimes\\Sigma_{i})\\\\ + \\begin{bmatrix}\\vecop[\\mathbf{X}_{ij}] \\\\ \\vecop[\\mathbf{Y}_{ik}] + \\end{bmatrix} + & \\sim \\mathcal{N}\\left(\\vecop\\begin{bmatrix}\\mathbf{A}_{ij} + \\\\ \\mathbf{B}_{ik} + \\end{bmatrix} + , \\begin{bmatrix} \\Sigma_{j}\\otimes \\Sigma_i & + \\Sigma_{jk} \\otimes \\Sigma_i \\\\ + \\Sigma_{kj}\\otimes \\Sigma_i & \\Sigma_{k} \\otimes + \\Sigma_i\\end{bmatrix}\\right) + +We apply the standard partitioned Gaussian identity and simplify using the +properties of the :math:`\\vecop` operator and the mixed product property +of kronecker products: + +.. 
math:: + \\vecop[\\X_{ij}] \\mid \\vecop[\\Y_{ik}]\\sim + \\mathcal{N}(&\\vecop[\\A_{ij}] + (\\Sigma_{jk}\\otimes\\Sigma_i) + (\\Sigma_k\\inv\\otimes\\Sigma_i\\inv)(\\vecop[\\Y_{ik}]-\\vecop[\\B_{ik}]),\\\\ + & \\Sigma_j\\otimes\\Sigma_i - (\\Sigma_{jk}\\otimes\\Sigma_i) + (\\Sigma_k\\inv\\otimes\\Sigma_i\\inv) (\\Sigma_{kj}\\otimes\\Sigma_i))\\\\ + =\\mathcal{N}(&\\vecop[\\A_{ij}] + + (\\Sigma_{jk}\\Sigma_k\\inv\\otimes\\Sigma_i\\Sigma_i\\inv) + (\\vecop[\\Y_{ik}]-\\vecop[\\B_{ik}]), \\\\ + & \\Sigma_j\\otimes\\Sigma_i - + (\\Sigma_{jk}\\Sigma_k\\inv\\Sigma_{kj}\\otimes + \\Sigma_i\\Sigma_i\\inv \\Sigma_i))\\\\ + =\\mathcal{N}(&\\vecop[\\A_{ij}] + (\\Sigma_{jk}\\Sigma_k\\inv\\otimes\\I) + (\\vecop[\\Y_{ik}]-\\vecop[\\B_{ik}]), \\\\ + & \\Sigma_j\\otimes\\Sigma_i - + (\\Sigma_{jk}\\Sigma_k\\inv\\Sigma_{kj}\\otimes\\Sigma_i))\\\\ + =\\mathcal{N}(&\\vecop[\\A_{ij}] + + \\vecop[(\\Y_{ik}-\\B_{ik})\\Sigma_k\\inv\\Sigma_{kj}], + (\\Sigma_j-\\Sigma_{jk}\\Sigma_k\\inv\\Sigma_{kj})\\otimes\\Sigma_i) + + +Next, we recognize that this multivariate gaussian is equivalent to the +following matrix variate gaussian: + +.. 
math::
+    \\X\\trp_{ji} \\mid \\Y\\trp_{ki}\\sim \\mathcal{MN}(&\\A\\trp_{ji} +
+    \\Sigma_{jk}\\Sigma_k\\inv(\\Y\\trp_{ki}-\\B\\trp_{ki}),
+    \\Sigma_j-\\Sigma_{jk}\\Sigma_k\\inv\\Sigma_{kj}, \\Sigma_i)\\\\
+    \\Y\\trp_{ki} \\mid \\X\\trp_{ji}\\sim \\mathcal{MN}(&\\B\\trp_{ki} +
+    \\Sigma_{kj}\\Sigma_j\\inv(\\X\\trp_{ji}-\\A\\trp_{ji}),
+    \\Sigma_k-\\Sigma_{kj}\\Sigma_j\\inv\\Sigma_{jk}, \\Sigma_i)
+
+These conditional likelihoods are implemented relatively efficiently
+in `MatnormModelBase.matnorm_logp_conditional_row` and
+`MatnormModelBase.matnorm_logp_conditional_col`.
+
+"""
diff --git a/brainiak/matnormal/covs.py b/brainiak/matnormal/covs.py
new file mode 100644
index 000000000..84b386ce0
--- /dev/null
+++ b/brainiak/matnormal/covs.py
@@ -0,0 +1,686 @@
+import tensorflow as tf
+import numpy as np
+import abc
+import scipy.linalg
+import scipy.sparse
+import scipy.special
+import tensorflow_probability as tfp
+
+from brainiak.matnormal.utils import (
+    x_tx,
+    xx_t,
+    unflatten_cholesky_unique,
+    flatten_cholesky_unique,
+)
+from brainiak.utils.kronecker_solvers import (
+    tf_solve_lower_triangular_kron,
+    tf_solve_upper_triangular_kron,
+    tf_solve_lower_triangular_masked_kron,
+    tf_solve_upper_triangular_masked_kron,
+)
+
+__all__ = [
+    "CovBase",
+    "CovIdentity",
+    "CovAR1",
+    "CovIsotropic",
+    "CovDiagonal",
+    "CovDiagonalGammaPrior",
+    "CovUnconstrainedCholesky",
+    "CovUnconstrainedCholeskyWishartReg",
+    "CovUnconstrainedInvCholesky",
+    "CovKroneckerFactored",
+]
+
+
+class CovBase(abc.ABC):
+    """Base metaclass for residual covariances.
+    For more on abstract classes, see
+    https://docs.python.org/3/library/abc.html
+
+    Parameters
+    ----------
+
+    size: int
+        The size of the covariance matrix.
+
+    """
+
+    def __init__(self, size):
+        self.size = size
+
+        # Log-likelihood of this covariance (useful for regularization).
+        # Stored privately so that subclasses can override ``logp`` with a
+        # property computed from their current variables.
+        self._logp = tf.constant(0, dtype=tf.float64)
+
+    @property
+    def logp(self):
+        """Log-likelihood of this covariance (useful for regularization).
+        Zero unless a subclass assigns or overrides it.
+        """
+        return self._logp
+
+    @logp.setter
+    def logp(self, value):
+        self._logp = value
+
+    @abc.abstractmethod
+    def get_optimize_vars(self):
+        """ Returns a list of tf variables that need to get optimized to fit
+        this covariance
+        """
+        pass
+
+    @property
+    def logdet(self):
+        """ log determinant of this covariance
+        """
+        pass
+
+    @abc.abstractmethod
+    def solve(self, X):
+        """Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x`
+        """
+        pass
+
+    @property
+    def _prec(self):
+        """Expose the precision explicitly (mostly for testing /
+        visualization, materializing large covariances may be intractable)
+        """
+        return self.solve(tf.eye(self.size, dtype=tf.float64))
+
+    @property
+    def _cov(self):
+        """Expose the covariance explicitly (mostly for testing /
+        visualization, materializing large covariances may be intractable)
+        """
+        return tf.linalg.inv(self._prec)
+
+
+class CovIdentity(CovBase):
+    """Identity noise covariance.
+    """
+
+    def __init__(self, size):
+        super(CovIdentity, self).__init__(size)
+
+    @property
+    def logdet(self):
+        return tf.constant(0.0, "float64")
+
+    def get_optimize_vars(self):
+        """Returns a list of tf variables that need to get optimized to
+        fit this covariance
+        """
+        return []
+
+    def solve(self, X):
+        """Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x`
+        """
+        return X
+
+    @property
+    def _prec(self):
+        """Expose the precision explicitly (mostly for testing /
+        visualization, materializing large covariances may be intractable)
+        """
+        return tf.eye(self.size, dtype=tf.float64)
+
+    @property
+    def _cov(self):
+        """Expose the covariance explicitly (mostly for testing /
+        visualization, materializing large covariances may be intractable)
+        """
+        return tf.eye(self.size, dtype=tf.float64)
+
+
+class CovAR1(CovBase):
+    """AR(1) covariance parameterized by autoregressive parameter rho
+    and new noise sigma.
+
+    Parameters
+    ----------
+    size: int
+        size of covariance matrix
+    rho: float or None
+        initial value of autoregressive parameter (if None, initialize
+        randomly)
+    sigma: float or None
+        initial value of new noise parameter (if None, initialize randomly)
+    scan_onsets: array-like of int or None
+        start index of each run; the AR(1) structure is applied within each
+        run separately (if None, the whole matrix is treated as one run)
+
+    """
+
+    def __init__(self, size, rho=None, sigma=None, scan_onsets=None):
+
+        super(CovAR1, self).__init__(size)
+
+        # Similar to BRSA trick I think
+        if scan_onsets is None:
+            self.run_sizes = [size]
+            self.offdiag_template = tf.constant(
+                scipy.linalg.toeplitz(np.r_[0, 1, np.zeros(size - 2)]),
+                dtype=tf.float64
+            )
+            self.diag_template = tf.constant(
+                np.diag(np.r_[0, np.ones(size - 2), 0]))
+        else:
+            self.run_sizes = np.ediff1d(np.r_[scan_onsets, size])
+            sub_offdiags = [
+                scipy.linalg.toeplitz(np.r_[0, 1, np.zeros(r - 2)])
+                for r in self.run_sizes
+            ]
+            self.offdiag_template = tf.constant(
+                scipy.sparse.block_diag(sub_offdiags).toarray()
+            )
+            subdiags = [np.diag(np.r_[0, np.ones(r - 2), 0])
+                        for r in self.run_sizes]
+            self.diag_template = tf.constant(
+                scipy.sparse.block_diag(subdiags).toarray()
+            )
+
+        self._identity_mat = tf.constant(np.eye(size))
+
+        if sigma is None:
+            self.log_sigma = tf.Variable(
+                tf.random.normal([1], dtype=tf.float64), name="log_sigma"
+            )
+        else:
+            self.log_sigma = tf.Variable(np.log(sigma), name="log_sigma")
+
+        if rho is None:
+            self.rho_unc = tf.Variable(
+                tf.random.normal([1], dtype=tf.float64), name="rho_unc"
+            )
+        else:
+            self.rho_unc = tf.Variable(
+                scipy.special.logit(rho / 2 + 0.5), name="rho_unc"
+            )
+
+    @property
+    def logdet(self):
+        """ log-determinant of this covariance
+        """
+        # first, unconstrain rho and sigma
+        rho = 2 * tf.sigmoid(self.rho_unc) - 1
+        # now compute logdet
+        return tf.reduce_sum(
+            input_tensor=2
+            * tf.constant(self.run_sizes, dtype=tf.float64)
+            * self.log_sigma
+            - tf.math.log(1 - tf.square(rho))
+        )
+
+    @property
+    def _prec(self):
+        """Precision matrix corresponding to this AR(1) covariance.
+        We assume stationarity within block so no special case
+        for first/last element of a block. This makes constructing this
+        matrix easier.
+        reprsimil.BRSA says (I - rho1 * D + rho1**2 * F) / sigma**2 and we
+        use the same trick
+        """
+        rho = 2 * tf.sigmoid(self.rho_unc) - 1
+        sigma = tf.exp(self.log_sigma)
+
+        return (
+            self._identity_mat
+            - rho * self.offdiag_template
+            + rho ** 2 * self.diag_template
+        ) / tf.square(sigma)
+
+    def get_optimize_vars(self):
+        """ Returns a list of tf variables that need to get optimized to
+        fit this covariance
+        """
+        return [self.rho_unc, self.log_sigma]
+
+    def solve(self, X):
+        """Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x`
+        """
+        return tf.matmul(self._prec, X)
+
+
+class CovIsotropic(CovBase):
+    """Scaled identity (isotropic) noise covariance.
+
+    Parameters
+    ----------
+    size: int
+        size of covariance matrix
+    var: float or None
+        initial value of new variance parameter (if None, initialize randomly)
+
+    """
+
+    def __init__(self, size, var=None):
+        super(CovIsotropic, self).__init__(size)
+        if var is None:
+            self.log_var = tf.Variable(
+                tf.random.normal([1], dtype=tf.float64), name="sigma"
+            )
+        else:
+            self.log_var = tf.Variable(np.log(var), name="sigma")
+
+    @property
+    def var(self):
+        """Current variance ``exp(log_var)``, recomputed on every access so
+        that it tracks the optimized variable (and its gradient).
+        """
+        return tf.exp(self.log_var)
+
+    @property
+    def logdet(self):
+        return self.size * self.log_var
+
+    def get_optimize_vars(self):
+        """ Returns a list of tf variables that need to get optimized to fit
+        this covariance
+        """
+        return [self.log_var]
+
+    def solve(self, X):
+        """Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x`
+
+        Parameters
+        ----------
+        X: tf.Tensor
+            Tensor to multiply by inverse of this covariance
+
+        """
+        return X / self.var
+
+
+class CovDiagonal(CovBase):
+    """Uncorrelated (diagonal) noise covariance
+
+    Parameters
+    ----------
+    size: int
+        size of covariance matrix
+    diag_var: float or None
+        initial value of (diagonal) variance vector (if None, initialize
+        randomly)
+
+    """
+
+    def __init__(self, size, diag_var=None):
+        super(CovDiagonal, self).__init__(size)
+        if diag_var is None:
+            self.logprec = tf.Variable(
+                tf.random.normal([size], dtype=tf.float64), name="precisions"
+            )
+        else:
+            self.logprec = tf.Variable(
+                np.log(1 / diag_var), name="log-precisions")
+
+    @property
+    def logdet(self):
+        return -tf.reduce_sum(input_tensor=self.logprec)
+
+    def get_optimize_vars(self):
+        """ Returns a list of tf variables that need to get optimized to fit
+        this covariance
+        """
+        return [self.logprec]
+
+    def solve(self, X):
+        """Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x`
+
+        Parameters
+        ----------
+        X: tf.Tensor
+            Tensor to multiply by inverse of this covariance
+
+        """
+        prec = tf.exp(self.logprec)
+        prec_dimaugmented = tf.expand_dims(prec, -1)
+        return tf.multiply(prec_dimaugmented, X)
+
+
+class CovDiagonalGammaPrior(CovDiagonal):
+    """Uncorrelated (diagonal) noise covariance with an InverseGamma prior,
+    whose log-density is exposed as ``logp``.
+
+    Parameters
+    ----------
+    size: int
+        size of covariance matrix
+    sigma: float or None
+        initial value of the (diagonal) variance, passed to CovDiagonal as
+        ``diag_var`` (if None, initialize randomly)
+    alpha: float
+        concentration of the InverseGamma prior
+    beta: float
+        scale of the InverseGamma prior
+
+    """
+
+    def __init__(self, size, sigma=None, alpha=1.5, beta=1e-10):
+        super(CovDiagonalGammaPrior, self).__init__(size, sigma)
+
+        self.ig = tfp.distributions.InverseGamma(
+            concentration=tf.constant(alpha, dtype=tf.float64),
+            scale=tf.constant(beta, dtype=tf.float64),
+        )
+
+    @property
+    def logp(self):
+        """Log-density of the prior at the current parameters, recomputed
+        on every access so it tracks the optimized variable.
+        """
+        # NOTE(review): the prior is evaluated at the precisions
+        # exp(logprec), not at the variances -- confirm this is intended.
+        return tf.reduce_sum(
+            input_tensor=self.ig.log_prob(tf.exp(self.logprec)))
+
+
+class CovUnconstrainedCholesky(CovBase):
+    """Unconstrained noise covariance parameterized in terms of its cholesky
+    """
+
+    def __init__(self, size=None, Sigma=None):
+
+        if size is None and Sigma is None:
+            raise RuntimeError("Must pass either Sigma or size but not both")
+
+        if size is not None and Sigma is not None:
+            raise RuntimeError("Must pass either Sigma or size but not both")
+
+        if Sigma is not None:
+            size = Sigma.shape[0]
+
+        super(CovUnconstrainedCholesky, self).__init__(size)
+
+        # number of parameters in the triangular mat
+        npar = (size * (size + 1)) // 2
+
+        if Sigma is None:
+            self.L_flat = tf.Variable(
+                tf.random.normal([npar], dtype=tf.float64), name="L_flat"
+            )
+
+        else:
+            L = np.linalg.cholesky(Sigma)
+            self.L_flat = tf.Variable(
+                flatten_cholesky_unique(L), name="L_flat")
+
+        self.optimize_vars = [self.L_flat]
+
+    @property
+    def L(self):
+        """
+        Cholesky factor of this covariance
+        """
+        return unflatten_cholesky_unique(self.L_flat)
+
+    @property
+    def logdet(self):
+        return 2 * tf.reduce_sum(input_tensor=tf.math.log(
+            tf.linalg.diag_part(self.L)))
+
+    def get_optimize_vars(self):
+        """ Returns a list of tf variables that need to get optimized to fit
+        this covariance
+        """
+        return [self.L_flat]
+
+    def solve(self, X):
+        """Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x` (using cholesky solve)
+
+        Parameters
+        ----------
+        X: tf.Tensor
+            Tensor to multiply by inverse of this covariance
+
+        """
+        return tf.linalg.cholesky_solve(self.L, X)
+
+
+class CovUnconstrainedCholeskyWishartReg(CovUnconstrainedCholesky):
+    """Unconstrained noise covariance parameterized in terms of its
+    cholesky factor. Regularized using the trick from
+    Chung et al. 2015 such that as the covariance approaches
+    singularity, the likelihood goes to 0.
+
+    References
+    ----------
+    Chung, Y., Gelman, A., Rabe-Hesketh, S., Liu, J., & Dorie, V. (2015).
+    Weakly Informative Prior for Point Estimation of Covariance Matrices
+    in Hierarchical Models. Journal of Educational and Behavioral Statistics,
+    40(2), 136–157. https://doi.org/10.3102/1076998615570945
+    """
+
+    def __init__(self, size, Sigma=None):
+        if Sigma is None:
+            super(CovUnconstrainedCholeskyWishartReg, self).__init__(size)
+        else:
+            if size is not None and Sigma.shape[0] != size:
+                raise ValueError("size does not match the shape of Sigma")
+            # The parent accepts exactly one of size/Sigma; pass Sigma so
+            # that the cholesky factor is initialized from it.
+            super(CovUnconstrainedCholeskyWishartReg, self).__init__(
+                Sigma=Sigma)
+        size = self.size
+        self.wishartReg = tfp.distributions.WishartTriL(
+            df=tf.constant(size + 2, dtype=tf.float64),
+            scale_tril=tf.constant(1e5 * np.eye(size), dtype=tf.float64),
+        )
+
+    @property
+    def logp(self):
+        """Log-density of the Wishart regularizer at the current covariance,
+        recomputed on every access so it tracks the optimized variable.
+        """
+        return self.wishartReg.log_prob(xx_t(self.L))
+
+
+class CovUnconstrainedInvCholesky(CovBase):
+    """Unconstrained noise covariance parameterized
+    in terms of its precision cholesky. Use this over the
+    regular cholesky unless you have a good reason not to, since
+    this saves a cholesky solve on every step of optimization
+    """
+
+    def __init__(self, size=None, invSigma=None):
+
+        if size is None and invSigma is None:
+            raise RuntimeError(
+                "Must pass either invSigma or size but not both")
+
+        if size is not None and invSigma is not None:
+            raise RuntimeError(
+                "Must pass either invSigma or size but not both")
+
+        if invSigma is not None:
+            size = invSigma.shape[0]
+
+        super(CovUnconstrainedInvCholesky, self).__init__(size)
+
+        # number of parameters in the triangular mat
+        npar = (size * (size + 1)) // 2
+
+        if invSigma is None:
+            self.Linv_flat = tf.Variable(
+                tf.random.normal([npar], dtype=tf.float64), name="Linv_flat"
+            )
+
+        else:
+            # ``Linv`` is the inverse of the cholesky factor of the
+            # covariance and ``solve`` forms the precision as x_tx(Linv)
+            # (presumably Linv^T Linv), so invert cholesky(Sigma) here rather
+            # than using cholesky(invSigma), whose factor M gives M M^T.
+            Linv = np.linalg.inv(np.linalg.cholesky(np.linalg.inv(invSigma)))
+            self.Linv_flat = tf.Variable(
+                flatten_cholesky_unique(Linv), name="Linv_flat"
+            )
+
+    @property
+    def Linv(self):
+        """
+        Inverse of Cholesky factor of this covariance
+        """
+        return unflatten_cholesky_unique(self.Linv_flat)
+
+    @property
+    def logdet(self):
+        return -2 * tf.reduce_sum(
+            input_tensor=tf.math.log(tf.linalg.diag_part(self.Linv))
+        )
+
+    def get_optimize_vars(self):
+        """ Returns a list of tf variables that need to get optimized to fit
+        this covariance
+        """
+        return [self.Linv_flat]
+
+    def solve(self, X):
+        """Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x` (using cholesky solve)
+
+        Parameters
+        ----------
+        X: tf.Tensor
+            Tensor to multiply by inverse of this covariance
+
+        """
+        return tf.matmul(x_tx(self.Linv), X)
+
+
+class CovKroneckerFactored(CovBase):
+    """ Kronecker product noise covariance parameterized in terms
+    of its component cholesky factors
+    """
+
+    def __init__(self, sizes, Sigmas=None, mask=None):
+        """Initialize the kronecker factored covariance object.
+
+        Arguments
+        ---------
+        sizes : list
+            List of dimensions (int) of the factors
+            E.g. ``sizes = [2, 3]`` will create two factors of
+            sizes 2x2 and 3x3 giving us a 6x6 dimensional covariance
+        Sigmas : list (default : None)
+            Initial guess for the covariances. List of positive definite
+            covariance matrices the same sizes as sizes.
+        mask : int array (default : None)
+            1-D tensor with length equal to product of sizes with 1 for
+            valid elements and 0 for don't care
+
+        Returns
+        -------
+        None
+
+        Raises
+        ------
+        TypeError
+            If sizes is not a list
+        """
+        if not isinstance(sizes, list):
+            raise TypeError("sizes is not a list")
+
+        # Initialize the base class so that ``size`` and the default
+        # ``logp`` are set, as for every other covariance.
+        super(CovKroneckerFactored, self).__init__(
+            np.prod(np.array(sizes), dtype=np.int32))
+        self.sizes = sizes
+        self.nfactors = len(sizes)
+
+        npar = [(size * (size + 1)) // 2 for size in self.sizes]
+        if Sigmas is None:
+            self.Lflat = [
+                tf.Variable(
+                    tf.random.normal([npar[i]], dtype=tf.float64),
+                    name="L" + str(i) + "_flat",
+                )
+                for i in range(self.nfactors)
+            ]
+        else:
+            self.Lflat = [
+                tf.Variable(
+                    flatten_cholesky_unique(np.linalg.cholesky(Sigmas[i])),
+                    name="L" + str(i) + "_flat",
+                )
+                for i in range(self.nfactors)
+            ]
+        self.mask = mask
+
+    @property
+    def L(self):
+        return [unflatten_cholesky_unique(mat) for mat in self.Lflat]
+
+    def get_optimize_vars(self):
+        """ Returns a list of tf variables that need to get optimized
+            to fit this covariance
+        """
+        return self.Lflat
+
+    @property
+    def logdet(self):
+        """ log|Sigma| using the diagonals of the cholesky factors.
+        """
+        if self.mask is None:
+            n_list = tf.stack(
+                [tf.cast(tf.shape(input=mat)[0], dtype=tf.float64)
+                 for mat in self.L]
+            )
+            n_prod = tf.reduce_prod(input_tensor=n_list)
+            logdet = tf.stack(
+                [
+                    tf.reduce_sum(
+                        input_tensor=tf.math.log(
+                            tf.linalg.tensor_diag_part(mat))
+                    )
+                    for mat in self.L
+                ]
+            )
+            logdetfinal = tf.reduce_sum(
+                input_tensor=(logdet * n_prod) / n_list)
+        else:
+            n_list = [tf.shape(input=mat)[0] for mat in self.L]
+            mask_reshaped = tf.reshape(self.mask, n_list)
+            logdet = 0.0
+            for i in range(self.nfactors):
+                indices = list(range(self.nfactors))
+                indices.remove(i)
+                logdet += (tf.math.log(tf.linalg.tensor_diag_part(self.L[i])) *
+                           tf.cast(
+                               tf.reduce_sum(
+                                   input_tensor=mask_reshaped, axis=indices),
+                               dtype=tf.float64,
+                           ))
+            logdetfinal = tf.reduce_sum(input_tensor=logdet)
+        return 2.0 * logdetfinal
+
+    def solve(self, X):
+        """ Given this covariance :math:`\\Sigma` and some input :math:`X`,
+        compute :math:`\\Sigma^{-1}x` using triangular solves with the cholesky
+        factors.
+
+        Specifically, we solve :math:`L L^T x = y` by solving
+        :math:`L z = y` and :math:`L^T x = z`.
+
+        Parameters
+        ----------
+        X: tf.Tensor
+            Tensor to multiply by inverse of this covariance
+
+        """
+        if self.mask is None:
+            z = tf_solve_lower_triangular_kron(self.L, X)
+            x = tf_solve_upper_triangular_kron(self.L, z)
+        else:
+            z = tf_solve_lower_triangular_masked_kron(self.L, X, self.mask)
+            x = tf_solve_upper_triangular_masked_kron(self.L, z, self.mask)
+        return x
diff --git a/brainiak/matnormal/matnormal_likelihoods.py b/brainiak/matnormal/matnormal_likelihoods.py
new file mode 100644
index 000000000..55d3f3dff
--- /dev/null
+++ b/brainiak/matnormal/matnormal_likelihoods.py
@@ -0,0 +1,429 @@
+import tensorflow as tf
+from tensorflow import linalg as tlinalg
+from .utils import scaled_I
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def _condition(X):
+    """
+    Condition number (https://en.wikipedia.org/wiki/Condition_number)
+    used for diagnostics.
+
+    NOTE: this formulation is only defined for symmetric positive definite
+    matrices (which covariances should be, and what we're using this for)
+
+    Parameters
+    ----------
+    X: tf.Tensor
+        Symmetric tensor to compute condition number of
+
+    """
+    s = tf.linalg.svd(X, compute_uv=False)
+    return tf.reduce_max(input_tensor=s) / tf.reduce_min(input_tensor=s)
+
+
+def solve_det_marginal(x, sigma, A, Q):
+    """
+    Use matrix inversion lemma for the solve:
+
+    .. math::
+        (\\Sigma + AQA^T)^{-1} X =\\
+        (\\Sigma^{-1} - \\Sigma^{-1} A (Q^{-1} +
+        A^T \\Sigma^{-1} A)^{-1} A^T \\Sigma^{-1}) X
+
+    Use matrix determinant lemma for determinant:
+
+    .. math::
+        \\log|(\\Sigma + AQA^T)| = \\log|Q^{-1} + A^T \\Sigma^{-1} A| +
+        \\log|Q| + \\log|\\Sigma|
+
+    Parameters
+    ----------
+    x: tf.Tensor
+        Tensor to multiply the solve by
+    sigma: brainiak.matnormal.CovBase
+        Covariance object implementing solve and logdet
+    A: tf.Tensor
+        Factor multiplying the variable we marginalized out
+    Q: brainiak.matnormal.CovBase
+        Covariance object of marginalized variable,
+        implementing solve and logdet
+
+    Returns
+    -------
+    solve: tf.Tensor
+        :math:`(\\Sigma + AQA^T)^{-1} x`
+    logdet: tf.Tensor
+        :math:`\\log|\\Sigma + AQA^T|`
+    """
+
+    # Form the "lemma matrix" (Qinv + A^T Sigma^{-1} A) once; it is needed
+    # both for the diagnostics and for the cholesky factorization below.
+    lemma_mat = Q._prec + tf.matmul(A, sigma.solve(A), transpose_a=True)
+
+    # For diagnostics, we want to check condition numbers
+    # of things we invert. This includes Q and Sigma, as well
+    # as the "lemma factor" for lack of a better definition.
+    # These need SVDs and dense covariances, so only compute them when
+    # debug logging is actually enabled.
+    debug = logger.isEnabledFor(logging.DEBUG)
+    if debug:
+        logger.debug("Printing diagnostics for solve_det_marginal")
+        logger.debug(f"lemma_factor condition={_condition(lemma_mat)}")
+        logger.debug(f"Q condition={_condition(Q._cov)}")
+        logger.debug(f"sigma condition={_condition(sigma._cov)}")
+        logger.debug(
+            f"A max={tf.reduce_max(input_tensor=A)}, "
+            f"A min={tf.reduce_min(input_tensor=A)}"
+        )
+
+    # cholesky of (Qinv + A^T Sigma^{-1} A), which looks sort of like
+    # a schur complement but isn't, so we call it the "lemma factor"
+    # since we use it in woodbury and matrix determinant lemmas
+    lemma_factor = tlinalg.cholesky(lemma_mat)
+
+    # log|Sigma + A Q A^T| via the matrix determinant lemma
+    lemma_logdet = 2 * tf.reduce_sum(
+        input_tensor=tf.math.log(tlinalg.diag_part(lemma_factor))
+    )
+    logdet = Q.logdet + sigma.logdet + lemma_logdet
+
+    if debug:
+        logger.debug(f"Log-determinant of Q={Q.logdet}")
+        logger.debug(f"sigma logdet={sigma.logdet}")
+        logger.debug(f"lemma factor logdet={lemma_logdet}")
+
+    # A^T Sigma^{-1}
+    Atrp_Sinv = tf.matmul(A, sigma._prec, transpose_a=True)
+    # (Qinv + A^T Sigma^{-1} A)^{-1} A^T Sigma^{-1}
+    prod_term = tlinalg.cholesky_solve(lemma_factor, Atrp_Sinv)
+
+    solve = tf.matmul(
+        sigma.solve(scaled_I(1.0, sigma.size) - tf.matmul(A, prod_term)), x
+    )
+
+    return solve, logdet
+
+
+def solve_det_conditional(x, sigma, A, Q):
+    """
+    Use matrix inversion lemma for the solve:
+
+    .. math::
+        (\\Sigma - AQ^{-1}A^T)^{-1} X =\\
+        (\\Sigma^{-1} + \\Sigma^{-1} A (Q -
+        A^T \\Sigma^{-1} A)^{-1} A^T \\Sigma^{-1}) X
+
+    Use matrix determinant lemma for determinant:
+
+    .. math::
+        \\log|(\\Sigma - AQ^{-1}A^T)| =
+        \\log|Q - A^T \\Sigma^{-1} A| - \\log|Q| + \\log|\\Sigma|
+
+    Parameters
+    ----------
+    x: tf.Tensor
+        Tensor to multiply the solve by
+    sigma: brainiak.matnormal.CovBase
+        Covariance object implementing solve and logdet
+    A: tf.Tensor
+        Factor multiplying the variable we conditioned on
+    Q: brainiak.matnormal.CovBase
+        Covariance object of conditioning variable,
+        implementing solve and logdet
+
+    """
+
+    # (Q - A^T Sigma^{-1} A)
+    lemma_factor = tlinalg.cholesky(
+        Q._cov - tf.matmul(A, sigma.solve(A), transpose_a=True)
+    )
+
+    logdet = (
+        -Q.logdet
+        + sigma.logdet
+        + 2 *
+        tf.reduce_sum(input_tensor=tf.math.log(
+            tlinalg.diag_part(lemma_factor)))
+    )
+
+    # A^T Sigma^{-1}
+    Atrp_Sinv = tf.matmul(A, sigma._prec, transpose_a=True)
+    # (Q - A^T Sigma^{-1} A)^{-1} A^T Sigma^{-1}
+    prod_term = tlinalg.cholesky_solve(lemma_factor, Atrp_Sinv)
+
+    solve = tf.matmul(
+        sigma.solve(scaled_I(1.0, sigma.size) + tf.matmul(A, prod_term)), x
+    )
+
+    return solve, logdet
+
+
+def _mnorm_logp_internal(
+    colsize, rowsize, logdet_row, logdet_col, solve_row, solve_col
+):
+    """Construct logp from the solves and determinants.
+
+    Parameters
+    ----------------
+    colsize: int
+        Column dimension of observation tensor
+    rowsize: int
+        Row dimension of observation tensor
+    logdet_row: tf.Tensor (scalar)
+        log-determinant of row covariance
+    logdet_col: tf.Tensor (scalar)
+        log-determinant of column covariance
+    solve_row: tf.Tensor
+        Inverse row covariance multiplying the observation tensor
+    solve_col: tf.Tensor
+        Inverse column covariance multiplying the transpose of
+        the observation tensor
+    """
+    log2pi = 1.8378770664093453
+
+    # Only evaluate the (eager) diagnostics when debug logging is enabled.
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(f"column precision trace ={tlinalg.trace(solve_col)}")
+        logger.debug(f"row precision trace ={tlinalg.trace(solve_row)}")
+        logger.debug(f"row cov logdet ={logdet_row}")
+        logger.debug(f"col cov logdet ={logdet_col}")
+
+    denominator = (
+        -rowsize * colsize * log2pi - colsize * logdet_row -
+        rowsize * logdet_col
+    )
+    numerator = -tlinalg.trace(tf.matmul(solve_col, solve_row))
+    return 0.5 * (numerator + denominator)
+
+
+def matnorm_logp(x, row_cov, col_cov):
+    """Log likelihood for centered matrix-variate normal density.
+    Assumes that row_cov and col_cov follow the API defined in CovBase.
+
+    Parameters
+    ----------------
+    x: tf.Tensor
+        Observation tensor
+    row_cov: CovBase
+        Row covariance implementing the CovBase API
+    col_cov: CovBase
+        Column Covariance implementing the CovBase API
+
+    """
+
+    rowsize = tf.cast(tf.shape(input=x)[0], "float64")
+    colsize = tf.cast(tf.shape(input=x)[1], "float64")
+
+    # precompute sigma_col^{-1} * x'
+    solve_col = col_cov.solve(tf.transpose(a=x))
+    logdet_col = col_cov.logdet
+
+    # precompute sigma_row^{-1} * x
+    solve_row = row_cov.solve(x)
+    logdet_row = row_cov.logdet
+
+    return _mnorm_logp_internal(
+        colsize, rowsize, logdet_row, logdet_col, solve_row, solve_col
+    )
+
+
+def matnorm_logp_marginal_row(x, row_cov, col_cov, marg, marg_cov):
+    """
+    Log likelihood for marginal centered matrix-variate normal density.
+
+    .. 
math:: + X &\\sim \\mathcal{MN}(0, Q, C)\\ + + Y \\mid \\X &\\sim \\mathcal{MN}(AX, R, C),\\ + + Y &\\sim \\mathcal{MN}(0, R + AQA^T, C) + + This function efficiently computes the marginals by unpacking some + info in the covariance classes and then dispatching to + `solve_det_marginal`. + + Parameters + --------------- + x: tf.Tensor + Observation tensor + row_cov: CovBase + Row covariance implementing the CovBase API (:math:`R` above). + col_cov: CovBase + Column Covariance implementing the CovBase API (:math:`C` above). + marg: tf.Tensor + Marginal factor (:math:`A` above). + marg_cov: CovBase + Prior covariance implementing the CovBase API (:math:`Q` above). + """ + rowsize = tf.cast(tf.shape(input=x)[0], "float64") + colsize = tf.cast(tf.shape(input=x)[1], "float64") + + solve_col = col_cov.solve(tf.transpose(a=x)) + logdet_col = col_cov.logdet + + solve_row, logdet_row = solve_det_marginal(x, row_cov, marg, marg_cov) + + return _mnorm_logp_internal( + colsize, rowsize, logdet_row, logdet_col, solve_row, solve_col + ) + + +def matnorm_logp_marginal_col(x, row_cov, col_cov, marg, marg_cov): + """ + Log likelihood for centered marginal matrix-variate normal density. + + .. math:: + X &\\sim \\mathcal{MN}(0, R, Q)\\ + + Y \\mid \\X &\\sim \\mathcal{MN}(XA, R, C),\\ + + Y &\\sim \\mathcal{MN}(0, R, C + A^TQA) + + This function efficiently computes the marginals by unpacking some + info in the covariance classes and then dispatching to + `solve_det_marginal`. + + Parameters + --------------- + x: tf.Tensor + Observation tensor + row_cov: CovBase + Row covariance implementing the CovBase API (:math:`R` above). + col_cov: CovBase + Column Covariance implementing the CovBase API (:math:`C` above). + marg: tf.Tensor + Marginal factor (:math:`A` above). + marg_cov: CovBase + Prior covariance implementing the CovBase API (:math:`Q` above). 
+ + """ + rowsize = tf.cast(tf.shape(input=x)[0], "float64") + colsize = tf.cast(tf.shape(input=x)[1], "float64") + + solve_row = row_cov.solve(x) + logdet_row = row_cov.logdet + + solve_col, logdet_col = solve_det_marginal( + tf.transpose(a=x), col_cov, tf.transpose(a=marg), marg_cov + ) + + return _mnorm_logp_internal( + colsize, rowsize, logdet_row, logdet_col, solve_row, solve_col + ) + + +def matnorm_logp_conditional_row(x, row_cov, col_cov, cond, cond_cov): + """ + Log likelihood for centered conditional matrix-variate normal density. + + Consider the following partitioned matrix-normal density: + + .. math:: + \\begin{bmatrix} + \\operatorname{vec}\\left[\\mathbf{X}_{i j}\\right] \\\\ + \\operatorname{vec}\\left[\\mathbf{Y}_{i k}\\right]\\end{bmatrix} + \\sim \\mathcal{N}\\left(0,\\begin{bmatrix} \\Sigma_{j} \\otimes + \\Sigma_{i} & \\Sigma_{j k} \\otimes \\Sigma_{i}\\\\ + \\Sigma_{k j} \\otimes \\Sigma_{i} & \\Sigma_{k} \\otimes \\Sigma_{i} + \\end{bmatrix}\\right) + + Then we can write the conditional: + + .. math:: + \\mathbf{X}^T j i \\mid \\mathbf{Y}_{k i}^T + \\sim \\mathcal{M}\\ + \\mathcal{N}\\left(0, \\Sigma_{j}-\\Sigma_{j k} \\Sigma_{k}^{-1} + \\Sigma_{k j},\\ + \\Sigma_{i}\\right) + + This function efficiently computes the conditionals by unpacking some + info in the covariance classes and then dispatching to + `solve_det_conditional`. + + Parameters + --------------- + x: tf.Tensor + Observation tensor + row_cov: CovBase + Row covariance (:math:`\\Sigma_{i}` in the notation above). + col_cov: CovBase + Column covariance (:math:`\\Sigma_{j}` in the notation above). + cond: tf.Tensor + Off-diagonal block of the partitioned covariance (:math:`\\Sigma_{jk}` + in the notation above). + cond_cov: CovBase + Covariance of conditioning variable (:math:`\\Sigma_{k}` in the + notation above). 
+ + """ + rowsize = tf.cast(tf.shape(input=x)[0], "float64") + colsize = tf.cast(tf.shape(input=x)[1], "float64") + + solve_col = col_cov.solve(tf.transpose(a=x)) + logdet_col = col_cov.logdet + + solve_row, logdet_row = solve_det_conditional(x, row_cov, cond, cond_cov) + + return _mnorm_logp_internal( + colsize, rowsize, logdet_row, logdet_col, solve_row, solve_col + ) + + +def matnorm_logp_conditional_col(x, row_cov, col_cov, cond, cond_cov): + """ + Log likelihood for centered conditional matrix-variate normal density. + + Consider the following partitioned matrix-normal density: + + .. math:: + \\begin{bmatrix} + \\operatorname{vec}\\left[\\mathbf{X}_{i j}\\right] \\\\ + \\operatorname{vec}\\left[\\mathbf{Y}_{i k}\\right]\\end{bmatrix} + \\sim \\mathcal{N}\\left(0,\\begin{bmatrix} \\Sigma_{j} \\otimes + \\Sigma_{i} & \\Sigma_{j k} \\otimes \\Sigma_{i}\\\\ + \\Sigma_{k j} \\otimes \\Sigma_{i} & \\Sigma_{k} \\otimes \\Sigma_{i} + \\end{bmatrix}\\right) + + Then we can write the conditional: + + .. math:: + \\mathbf{X}_{i j} \\mid \\mathbf{Y}_{i k} \\sim \\mathcal{M}\\ + \\mathcal{N}\\left(0, \\Sigma_{i}, \\Sigma_{j}-\\Sigma_{j k}\\ + \\Sigma_{k}^{-1} \\Sigma_{k j}\\right) + + This function efficiently computes the conditionals by unpacking some + info in the covariance classes and then dispatching to + `solve_det_conditional`. + + Parameters + --------------- + x: tf.Tensor + Observation tensor + row_cov: CovBase + Row covariance (:math:`\\Sigma_{i}` in the notation above). + col_cov: CovBase + Column covariance (:math:`\\Sigma_{j}` in the notation above). + cond: tf.Tensor + Off-diagonal block of the partitioned covariance (:math:`\\Sigma_{jk}` + in the notation above). + cond_cov: CovBase + Covariance of conditioning variable (:math:`\\Sigma_{k}` in the + notation above). 
+ + """ + rowsize = tf.cast(tf.shape(input=x)[0], "float64") + colsize = tf.cast(tf.shape(input=x)[1], "float64") + + solve_row = row_cov.solve(x) + logdet_row = row_cov.logdet + + solve_col, logdet_col = solve_det_conditional( + tf.transpose(a=x), col_cov, tf.transpose(a=cond), cond_cov + ) + + return _mnorm_logp_internal( + colsize, rowsize, logdet_row, logdet_col, solve_row, solve_col + ) diff --git a/brainiak/matnormal/mnrsa.py b/brainiak/matnormal/mnrsa.py new file mode 100644 index 000000000..0b175bf45 --- /dev/null +++ b/brainiak/matnormal/mnrsa.py @@ -0,0 +1,175 @@ +import tensorflow as tf +from sklearn.base import BaseEstimator +from sklearn.linear_model import LinearRegression +from .covs import CovIdentity +from brainiak.utils.utils import cov2corr +import numpy as np +from brainiak.matnormal.matnormal_likelihoods import matnorm_logp_marginal_row +from brainiak.matnormal.utils import ( + pack_trainable_vars, + unpack_trainable_vars, + make_val_and_grad, + unflatten_cholesky_unique, + flatten_cholesky_unique, +) + +from scipy.optimize import minimize + +__all__ = ["MNRSA"] + + +class MNRSA(BaseEstimator): + """ Matrix normal version of RSA. + + The goal of this analysis is to find the covariance of the mapping from + some design matrix X to the fMRI signal Y. It does so by marginalizing over + the actual mapping (i.e. averaging over the uncertainty in it), which + happens to correct a bias imposed by structure in the design matrix on the + RSA estimate (see Cai et al., NIPS 2016). + + This implementation makes different choices about residual covariance + relative to `brainiak.reprsimil.BRSA`: Here, the noise covariance is + assumed to be kronecker-separable. Informally, this means that all voxels + have the same temporal covariance, and all time points have the same + spatial covariance. This is in contrast to BRSA, which allows different + temporal covariance for each voxel. 
On the other hand, computational + efficiencies enabled by this choice allow MNRSA to support a richer class + of space and time covariances (anything in `brainiak.matnormal.covs`). + + For users: in general, if you are worried about voxels each having + different temporal noise structure, you should use + `brainiak.reprsimil.BRSA`. If you are worried about between-voxel + correlations or temporal covariance structures that BRSA does not + support, you should use MNRSA. + + .. math:: + Y &\\sim \\mathcal{MN}(0, \\Sigma_t + XLL^TX^T+ + X_0X_0^T, \\Sigma_s)\\ + + U &= LL^T + + Parameters + ---------- + time_cov : subclass of CovBase + Temporal noise covariance class following CovBase interface. + space_cov : subclass of CovBase + Spatial noise covariance class following CovBase interface. + optimizer : string, default: 'L-BFGS-B' + Name of scipy optimizer to use. + optCtrl : dict, default: None + Additional arguments to pass to scipy.optimize.minimize. + + """ + + def __init__( + self, time_cov, space_cov, n_nureg=5, optimizer="L-BFGS-B", + optCtrl=None + ): + + self.n_T = time_cov.size + self.n_V = space_cov.size + self.n_nureg = n_nureg + + self.optMethod = optimizer + # Always set optCtrl (empty when None) so fit() can unpack it. + self.optCtrl = {} if optCtrl is None else optCtrl + + self.X_0 = tf.Variable( + tf.random.normal([self.n_T, n_nureg], dtype=tf.float64), name="X_0" + ) + + self.train_variables = [self.X_0] + + self.time_cov = time_cov + self.space_cov = space_cov + + self.train_variables.extend(self.time_cov.get_optimize_vars()) + self.train_variables.extend(self.space_cov.get_optimize_vars()) + + @property + def L(self): + """ + Cholesky factor of the RSA matrix. + """ + return unflatten_cholesky_unique(self.L_flat) + + def fit(self, X, y, naive_init=True): + """ Estimate dimension reduction and cognitive model parameters + + Parameters + ---------- + X: 2d array + Brain data matrix (TRs by voxels). Y in the math + y: 2d array or vector + Behavior data matrix (TRs by behavioral observations).
X in the math + naive_init: bool, default=True + If True, initialize the Cholesky factor of U from the covariance of + ordinary least-squares coefficients (naive RSA); otherwise + initialize it randomly. + + """ + + # In the method signature we follow sklearn discriminative API + # where brain is X and behavior is y. Internally we are + # generative so we flip this here + X, Y = y, X + + self.n_c = X.shape[1] + + if naive_init: + # initialize from naive RSA + m = LinearRegression(fit_intercept=False) + m.fit(X=X, y=Y) + self.naive_U_ = np.cov(m.coef_.T) + naiveRSA_L = np.linalg.cholesky(self.naive_U_) + self.L_flat = tf.Variable( + flatten_cholesky_unique(naiveRSA_L), name="L_flat", + dtype="float64" + ) + else: + chol_flat_size = (self.n_c * (self.n_c + 1)) // 2 + self.L_flat = tf.Variable( + tf.random.normal([chol_flat_size], dtype="float64"), + name="L_flat", + dtype="float64", + ) + + # Rebuild on every call so a refit does not keep a stale L_flat. + self.train_variables = [self.X_0] + self.train_variables.extend(self.time_cov.get_optimize_vars()) + self.train_variables.extend(self.space_cov.get_optimize_vars()) + self.train_variables.append(self.L_flat) + + def lossfn(theta): return -self.logp(X, Y) + val_and_grad = make_val_and_grad(lossfn, self.train_variables) + + x0 = pack_trainable_vars(self.train_variables) + + opt_results = minimize(fun=val_and_grad, x0=x0, + jac=True, method=self.optMethod, **self.optCtrl) + + unpacked_theta = unpack_trainable_vars( + opt_results.x, self.train_variables) + for var, val in zip(self.train_variables, unpacked_theta): + var.assign(val) + + self.U_ = self.L.numpy().dot(self.L.numpy().T) + self.C_ = cov2corr(self.U_) + + def logp(self, X, Y): + """ MNRSA Log-likelihood""" + + rsa_cov = CovIdentity(size=self.n_c + self.n_nureg) + x_stack = tf.concat([tf.matmul(X, self.L), self.X_0], 1) + return ( + self.time_cov.logp + + self.space_cov.logp + + rsa_cov.logp + + matnorm_logp_marginal_row( + Y, + row_cov=self.time_cov, + col_cov=self.space_cov, + marg=x_stack, +
marg_cov=rsa_cov, + ) + ) diff --git a/brainiak/matnormal/regression.py b/brainiak/matnormal/regression.py new file mode 100644 index 000000000..816f22a7a --- /dev/null +++ b/brainiak/matnormal/regression.py @@ -0,0 +1,146 @@ +import tensorflow as tf +import numpy as np +from sklearn.base import BaseEstimator +from brainiak.matnormal.matnormal_likelihoods import matnorm_logp +from brainiak.matnormal.utils import ( + pack_trainable_vars, + unpack_trainable_vars, + make_val_and_grad, +) +from scipy.optimize import minimize + +__all__ = ["MatnormalRegression"] + + +class MatnormalRegression(BaseEstimator): + """ This analysis allows maximum likelihood estimation of regression models + in the presence of both spatial and temporal covariance. + + .. math:: + Y \\sim \\mathcal{MN}(X\\beta, \\Sigma_t, \\Sigma_s) + + Parameters + ---------- + time_cov : subclass of CovBase + TR noise covariance :math:`\\Sigma_t` (CovBase interface). + space_cov : subclass of CovBase + Voxel noise covariance :math:`\\Sigma_s` (CovBase interface). + optimizer : string, default="L-BFGS-B" + Scipy optimizer to use. For other options, see "method" argument + of scipy.optimize.minimize + optCtrl: dict, default=None + Additional arguments to pass to scipy.optimize.minimize. + + """ + + def __init__(self, time_cov, space_cov, optimizer="L-BFGS-B", + optCtrl=None): + + self.optMethod = optimizer + # Always set optCtrl (empty when None) so fit() can unpack it. + self.optCtrl = {} if optCtrl is None else optCtrl + + self.time_cov = time_cov + self.space_cov = space_cov + + self.n_t = time_cov.size + self.n_v = space_cov.size + + def logp(self, X, Y): + """ Log likelihood of model (internal) + """ + y_hat = tf.matmul(X, self.beta) + resid = Y - y_hat + return matnorm_logp(resid, self.time_cov, self.space_cov) + + def fit(self, X, y, naive_init=True): + """ Compute the regression fit. + + Parameters + ---------- + X : np.array, TRs by conditions. + Design matrix + y : np.array, TRs by voxels.
+ fMRI data + """ + + self.n_c = X.shape[1] + + if naive_init: + # initialize to the least squares solution (basically all + # we need now is the cov) + sigma_inv_x = self.time_cov.solve(X) + sigma_inv_y = self.time_cov.solve(y) + + beta_init = np.linalg.solve( + (X.T).dot(sigma_inv_x), (X.T).dot(sigma_inv_y)) + + else: + beta_init = np.random.randn(self.n_c, self.n_v) + + self.beta = tf.Variable(beta_init, name="beta") + + self.train_variables = [self.beta] + self.train_variables.extend(self.time_cov.get_optimize_vars()) + self.train_variables.extend(self.space_cov.get_optimize_vars()) + + def lossfn(theta): + return -self.logp(X, y) + + val_and_grad = make_val_and_grad(lossfn, self.train_variables) + x0 = pack_trainable_vars(self.train_variables) + + opt_results = minimize( + fun=val_and_grad, x0=x0, jac=True, method=self.optMethod, + **self.optCtrl + ) + + unpacked_theta = unpack_trainable_vars( + opt_results.x, self.train_variables) + + for var, val in zip(self.train_variables, unpacked_theta): + var.assign(val) + + self.beta_ = self.beta.numpy() + + def predict(self, X): + """ Predict fMRI signal from design matrix. + + Parameters + ---------- + X : np.array, TRs by conditions. + Design matrix + + """ + + return X.dot(self.beta_) + + def calibrate(self, Y): + """ Decode design matrix from fMRI dataset, based on a previously + trained mapping. This method just does naive MLE: + + .. math:: + X = Y \\Sigma_s^{-1}B^T(B \\Sigma_s^{-1} B^T)^{-1} + + Parameters + ---------- + Y : np.array, TRs by voxels. + fMRI dataset + """ + + if Y.shape[1] <= self.n_c: + raise RuntimeError( + "More conditions than voxels! System is singular,\ + cannot decode." 
+ ) + + # Sigma_s^{-1} B' + Sigma_s_btrp = self.space_cov.solve(tf.transpose(a=self.beta)) + # Y Sigma_s^{-1} B' + Y_Sigma_Btrp = tf.matmul(Y, Sigma_s_btrp).numpy() + # (B Sigma_s^{-1} B')^{-1} + B_Sigma_Btrp = tf.matmul(self.beta, Sigma_s_btrp).numpy() + + X_test = np.linalg.solve(B_Sigma_Btrp.T, Y_Sigma_Btrp.T).T + + return X_test diff --git a/brainiak/matnormal/utils.py b/brainiak/matnormal/utils.py new file mode 100644 index 000000000..ca429dd79 --- /dev/null +++ b/brainiak/matnormal/utils.py @@ -0,0 +1,124 @@ +import tensorflow as tf +import tensorflow_probability as tfp +from scipy.stats import norm +from numpy.linalg import cholesky +import numpy as np + + +def rmn(rowcov, colcov): + """ + Generate random draws from a zero-mean matrix-normal distribution. + + Parameters + ----------- + rowcov : np.ndarray + Row covariance (assumed to be positive definite) + colcov : np.ndarray + Column covariance (assumed to be positive definite) + """ + + Z = norm.rvs(size=(rowcov.shape[0], colcov.shape[0])) + return cholesky(rowcov).dot(Z).dot(cholesky(colcov)) + + +def xx_t(x): + """ + Outer product + :math:`xx^T` + + Parameters + ----------- + x : tf.Tensor + + """ + return tf.matmul(x, x, transpose_b=True) + + +def x_tx(x): + """Inner product + :math:`x^T x` + + Parameters + ----------- + x : tf.Tensor + + """ + return tf.matmul(x, x, transpose_a=True) + + +def scaled_I(x, size): + """Scaled identity matrix + :math:`x I_{size}` + + Parameters + ------------ + x: float or coercable to float + Scale to multiply the identity matrix by + size: int or otherwise coercable to a size + Dimension of the scaled identity matrix to return + """ + return tf.linalg.tensor_diag(tf.ones([size], dtype=tf.float64) * x) + + +def flatten_cholesky_unique(L): + """ + Flattens nonzero-elements Cholesky (triangular) factor + into a vector, and logs diagonal to make parameterization + unique. Inverse of unflatten_cholesky_unique. 
+ """ + L_tf = tf.linalg.set_diag(L, tf.math.log(tf.linalg.diag_part(L))) + L_flat = tfp.math.fill_triangular_inverse(L_tf) + return L_flat + + +def unflatten_cholesky_unique(L_flat): + """ + Converts a vector of elements into a triangular matrix + (Cholesky factor). Exponentiates diagonal to make + parameterization unique. Inverse of flatten_cholesky_unique. + """ + L = tfp.math.fill_triangular(L_flat) + # exp diag for unique parameterization + L = tf.linalg.set_diag(L, tf.exp(tf.linalg.diag_part(L))) + return L + + +def pack_trainable_vars(trainable_vars): + """ + Pack trainable vars in a model into a single + vector that can be passed to scipy.optimize + """ + return tf.concat([tf.reshape(tv, (-1,)) for tv in trainable_vars], axis=0) + + +def unpack_trainable_vars(x, trainable_vars): + """ + Unpack trainable vars from a single vector as + used/returned by scipy.optimize + """ + + sizes = [tv.shape for tv in trainable_vars] + idxs = [np.prod(sz) for sz in sizes] + flatvars = tf.split(x, idxs) + return [tf.reshape(fv, tv.shape) for fv, tv in zip(flatvars, + trainable_vars)] + + +def make_val_and_grad(lossfn, train_vars): + """ + Makes a function that ouptuts the loss and gradient in a format compatible + with scipy.optimize.minimize + """ + + def val_and_grad(theta): + with tf.GradientTape() as tape: + tape.watch(train_vars) + unpacked_theta = unpack_trainable_vars(theta, train_vars) + for var, val in zip(train_vars, unpacked_theta): + var.assign(val) + loss = lossfn(theta) + grad = tape.gradient(loss, train_vars) + packed_grad = pack_trainable_vars(grad) + return loss.numpy(), packed_grad.numpy() + + return val_and_grad diff --git a/brainiak/utils/kronecker_solvers.py b/brainiak/utils/kronecker_solvers.py new file mode 100644 index 000000000..7ed9a7c9d --- /dev/null +++ b/brainiak/utils/kronecker_solvers.py @@ -0,0 +1,330 @@ +import tensorflow as tf + +__all__ = ["tf_kron_mult", "tf_masked_triangular_solve"] + + +def tf_solve_lower_triangular_kron(L, y): + """ 
Tensorflow function to solve L x = y + where L = kron(L[0], L[1] .. L[n-1]) + and L[i] are the lower triangular matrices + + Arguments + --------- + L : list of 2-D tensors + Each element of the list must be a tensorflow tensor and + must be a lower triangular matrix of dimension n_i x n_i + + y : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p + + Returns + ------- + x : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p + + """ + n = len(L) + if n == 1: + return tf.linalg.triangular_solve(L[0], y) + else: + x = y + na = L[0].get_shape().as_list()[0] + n_list = tf.stack( + [tf.cast(tf.shape(input=mat)[0], dtype=tf.float64) for mat in L] + ) + n_prod = tf.cast(tf.reduce_prod(input_tensor=n_list), dtype=tf.int32) + nb = tf.cast(n_prod / na, dtype=tf.int32) + col = tf.shape(input=x)[1] + + for i in range(na): + xt, xinb, xina = tf.split(x, [i * nb, nb, (na - i - 1) * nb], 0) + t = xinb / L[0][i, i] + xinb = tf_solve_lower_triangular_kron(L[1:], t) + xina = xina - tf.reshape( + tf.tile(tf.slice(L[0], [i + 1, i], + [na - i - 1, 1]), [1, nb * col]), + [(na - i - 1) * nb, col], + ) * tf.reshape( + tf.tile(tf.reshape(t, [-1, 1]), [na - i - 1, 1]), + [(na - i - 1) * nb, col], + ) + x = tf.concat(axis=0, values=[xt, xinb, xina]) + + return x + + +def tf_solve_upper_triangular_kron(L, y): + """ Tensorflow function to solve L^T x = y + where L = kron(L[0], L[1] .. 
L[n-1]) + and L[i] are the lower triangular matrices + + Arguments + --------- + L : list of 2-D tensors + Each element of the list must be a tensorflow tensor and + must be a lower triangular matrix of dimension n_i x n_i + + y : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p + + Returns + ------- + x : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p + + """ + n = len(L) + if n == 1: + return tf.linalg.triangular_solve(L[0], y, adjoint=True) + else: + x = y + na = L[0].get_shape().as_list()[0] + n_list = tf.stack( + [tf.cast(tf.shape(input=mat)[0], dtype=tf.float64) for mat in L] + ) + n_prod = tf.cast(tf.reduce_prod(input_tensor=n_list), dtype=tf.int32) + nb = tf.cast(n_prod / na, dtype=tf.int32) + col = tf.shape(input=x)[1] + + for i in range(na - 1, -1, -1): + xt, xinb, xina = tf.split(x, [i * nb, nb, (na - i - 1) * nb], 0) + t = xinb / L[0][i, i] + xinb = tf_solve_upper_triangular_kron(L[1:], t) + xt = xt - tf.reshape( + tf.tile(tf.transpose(a=tf.slice( + L[0], [i, 0], [1, i])), [1, nb * col]), + [i * nb, col], + ) * tf.reshape(tf.tile(tf.reshape(t, [-1, 1]), [i, 1]), + [i * nb, col]) + x = tf.concat(axis=0, values=[xt, xinb, xina]) + + return x + + +def tf_kron_mult(L, x): + """ Tensorflow multiply with kronecker product matrix + Returns kron(L[0], L[1] ...) 
* x + + Arguments + --------- + L : list of 2-D tensors + Each element of the list must be a tensorflow tensor and + must be a square matrix of dimension n_i x n_i + + x : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p + + Returns + ------- + y : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p + """ + n = len(L) + if n == 1: + return tf.matmul(L[0], x) + else: + na = L[0].get_shape().as_list()[0] + n_list = tf.stack( + [tf.cast(tf.shape(input=mat)[0], dtype=tf.float64) for mat in L] + ) + n_prod = tf.cast(tf.reduce_prod(input_tensor=n_list), dtype=tf.int32) + nb = tf.cast(n_prod / na, dtype=tf.int32) + col = tf.shape(input=x)[1] + xt = tf_kron_mult( + L[1:], tf.transpose(a=tf.reshape(tf.transpose(a=x), [-1, nb])) + ) + y = tf.zeros_like(x) + for i in range(na): + ya, yb, yc = tf.split(y, [i * nb, nb, (na - i - 1) * nb], 0) + yb = tf.reshape( + tf.matmul( + tf.reshape(xt, [nb * col, na]), + tf.transpose(a=tf.slice(L[0], [i, 0], [1, na])), + ), + [nb, col], + ) + y = tf.concat(axis=0, values=[ya, yb, yc]) + return y + + +def tf_masked_triangular_solve(L, y, mask, lower=True, adjoint=False): + """ Tensorflow function to solve L x = y + where L is a lower triangular matrix with a mask + + Arguments + --------- + L : 2-D tensor + Must be a tensorflow tensor and + must be a triangular matrix of dimension n x n + + y : 1-D or 2-D tensor + Dimension n x p + + mask : 1-D tensor + Dimension n x 1, should be 1 if element is valid, 0 if invalid + + lower : boolean (default : True) + True if L is lower triangular, False if upper triangular + + adjoint : boolean (default : False) + True if solving for L^T x = y, False if solving for Lx = y + + Returns + ------- + x : 1-D or 2-D tensor + Dimension n x p, values at rows for which mask == 0 are set to zero + + """ + + zero = tf.constant(0, dtype=tf.int32) + mask_mat = tf.compat.v1.where( + tf.not_equal( + tf.matmul(tf.reshape(mask, [-1, 1]), + tf.reshape(mask, [1, -1])), zero + ) + ) + q = tf.cast( + 
tf.sqrt(tf.cast(tf.shape(input=mask_mat)[0], dtype=tf.float64)), + dtype=tf.int32 + ) + L_masked = tf.reshape(tf.gather_nd(L, mask_mat), [q, q]) + + maskindex = tf.compat.v1.where(tf.not_equal(mask, zero)) + y_masked = tf.gather_nd(y, maskindex) + + x_s1 = tf.linalg.triangular_solve( + L_masked, y_masked, lower=lower, adjoint=adjoint) + x = tf.scatter_nd(maskindex, x_s1, tf.cast( + tf.shape(input=y), dtype=tf.int64)) + return x + + +def tf_solve_lower_triangular_masked_kron(L, y, mask): + """ Tensorflow function to solve L x = y + where L = kron(L[0], L[1] .. L[n-1]), + L[i] are the lower triangular matrices, + and mask is a binary elementwise mask on the full L + + Arguments + --------- + L : list of 2-D tensors + Each element of the list must be a tensorflow tensor and + must be a lower triangular matrix of dimension n_i x n_i + + y : 1-D or 2-D tensor + Dimension [n_0*n_1*..n_(m-1)), p] + + mask: 1-D tensor + Dimension [n_0*n_1*...n_(m-1)] with 1 for valid rows and 0 + for don't care + + Returns + ------- + x : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p, values at rows + for which mask == 0 are set to zero + + """ + n = len(L) + if n == 1: + return tf_masked_triangular_solve(L[0], y, mask, lower=True, + adjoint=False) + else: + x = y + na = L[0].get_shape().as_list()[0] + n_list = tf.stack( + [tf.cast(tf.shape(input=mat)[0], dtype=tf.float64) for mat in L] + ) + n_prod = tf.cast(tf.reduce_prod(input_tensor=n_list), dtype=tf.int32) + nb = tf.cast(n_prod / na, dtype=tf.int32) + col = tf.shape(input=x)[1] + + for i in range(na): + mask_b = tf.slice(mask, [i * nb], [nb]) + xt, xinb, xina = tf.split(x, [i * nb, nb, (na - i - 1) * nb], 0) + t = xinb / L[0][i, i] + + if tf.reduce_sum(input_tensor=mask_b) != nb: + xinb = tf_solve_lower_triangular_masked_kron(L[1:], t, mask_b) + t_masked = tf_kron_mult(L[1:], xinb) + + else: + # all valid - same as no mask + xinb = tf_solve_lower_triangular_kron(L[1:], t) + t_masked = t + xina = xina - tf.reshape( + 
tf.tile(tf.slice(L[0], [i + 1, i], + [na - i - 1, 1]), [1, nb * col]), + [(na - i - 1) * nb, col], + ) * tf.reshape( + tf.tile(tf.reshape(t_masked, [-1, 1]), [na - i - 1, 1]), + [(na - i - 1) * nb, col], + ) + + x = tf.concat(axis=0, values=[xt, xinb, xina]) + + return x + + +def tf_solve_upper_triangular_masked_kron(L, y, mask): + """ Tensorflow function to solve L^T x = y + where L = kron(L[0], L[1] .. L[n-1]) + and L[i] are the lower triangular matrices + + Arguments + --------- + L : list of 2-D tensors + Each element of the list must be a tensorflow tensor and + must be a lower triangular matrix of dimension n_i x n_i + + y : 1-D or 2-D tensor + Dimension [n_0*n_1*..n_(m-1)), p] + + mask: 1-D tensor + Dimension [n_0*n_1*...n_(m-1)] with 1 for valid rows + and 0 for don't care + + Returns + ------- + x : 1-D or 2-D tensor + Dimension (n_0*n_1*..n_(m-1)) x p, values at rows + for which mask == 0 are set to zero + + """ + n = len(L) + if n == 1: + return tf_masked_triangular_solve(L[0], y, mask, lower=True, + adjoint=True) + else: + x = y + na = L[0].get_shape().as_list()[0] + n_list = tf.stack( + [tf.cast(tf.shape(input=mat)[0], dtype=tf.float64) for mat in L] + ) + n_prod = tf.cast(tf.reduce_prod(input_tensor=n_list), dtype=tf.int32) + nb = tf.cast(n_prod / na, dtype=tf.int32) + col = tf.shape(input=x)[1] + L1_end_tr = [tf.transpose(a=x) for x in L[1:]] + + for i in range(na - 1, -1, -1): + mask_b = tf.slice(mask, [i * nb], [nb]) + xt, xinb, xina = tf.split(x, [i * nb, nb, (na - i - 1) * nb], 0) + t = xinb / L[0][i, i] + + if tf.reduce_sum(input_tensor=mask_b) != nb: + xinb = tf_solve_upper_triangular_masked_kron(L[1:], t, mask_b) + t_masked = tf_kron_mult(L1_end_tr, xinb) + else: + xinb = tf_solve_upper_triangular_kron(L[1:], t) + t_masked = t + + xt = xt - tf.reshape( + tf.tile(tf.transpose(a=tf.slice( + L[0], [i, 0], [1, i])), [1, nb * col]), + [i * nb, col], + ) * tf.reshape( + tf.tile(tf.reshape(t_masked, [-1, 1]), [i, 1]), [i * nb, col] + ) + x = 
tf.concat(axis=0, values=[xt, xinb, xina]) + + return x diff --git a/examples/matnormal/MN-RSA.ipynb b/examples/matnormal/MN-RSA.ipynb new file mode 100644 index 000000000..abb6178c4 --- /dev/null +++ b/examples/matnormal/MN-RSA.ipynb @@ -0,0 +1,386 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MN-RSA derivation and example\n", + "\n", + "$$\n", + "\\DeclareMathOperator{\\Tr}{Tr}\n", + "\\newcommand{\\trp}{{^\\T}} % transpose\n", + "\\newcommand{\\trace}{\\text{Trace}} % trace\n", + "\\newcommand{\\inv}{^{-1}}\n", + "\\newcommand{\\mb}{\\mathbf{b}}\n", + "\\newcommand{\\M}{\\mathbf{M}}\n", + "\\newcommand{\\G}{\\mathbf{G}}\n", + "\\newcommand{\\A}{\\mathbf{A}}\n", + "\\newcommand{\\R}{\\mathbf{R}}\n", + "\\renewcommand{\\S}{\\mathbf{S}}\n", + "\\newcommand{\\B}{\\mathbf{B}}\n", + "\\newcommand{\\Q}{\\mathbf{Q}}\n", + "\\newcommand{\\mH}{\\mathbf{H}}\n", + "\\newcommand{\\U}{\\mathbf{U}}\n", + "\\newcommand{\\mL}{\\mathbf{L}}\n", + "\\newcommand{\\diag}{\\mathrm{diag}}\n", + "\\newcommand{\\etr}{\\mathrm{etr}}\n", + "\\renewcommand{\\H}{\\mathbf{H}}\n", + "\\newcommand{\\vecop}{\\mathrm{vec}}\n", + "\\newcommand{\\I}{\\mathbf{I}}\n", + "\\newcommand{\\X}{\\mathbf{X}}\n", + "\\newcommand{\\Y}{\\mathbf{Y}}\n", + "\\newcommand{\\Z}{\\mathbf{Z}}\n", + "\\renewcommand{\\L}{\\mathbf{L}}\n", + "$$\n", + "\n", + "We write the generative model for beta-series RSA. Note that for indicator-coded design matrix $\\X$ this is exactly equivalent to reshaping your data and directly computing the correlation, but allows for other features like convolving $\\X$ with an HRF. Here is the model: \n", + "\n", + "$$\n", + "\\Y = \\X\\beta + \\epsilon\n", + "$$\n", + "\n", + "where $\\Y$ is a TRs-by-voxels matrix of fMRI data, $\\X$ is a timepoint-by-feature design matrix that usually identifies conditions in the experiment, $\\beta$ is a feature-by-voxel matrix, $\\epsilon$ is a matrix of random perturbations (i.e. the noise). 
In conventional correlation-based RSA $\\epsilon \\sim \\mathcal{N}(0, \\sigma^2 \\I)$, i.e. the residuals are i.i.d. In Cai et al.'s BRSA, $\\epsilon$ has temporal AR(1) noise structure and voxel-specific noise variance. Of research interest is the covariance of $\\beta$ in its row dimension, so we want to estimate as little as possible of anything else. We additionally import from Cai et al.'s BRSA the use of $\\X_0$, an unmodeled latent timecourse projected onto voxels by $\\beta_0$, as a way of capturing additional residual structure. \n", + "\n", + "The above model can be written as follows: \n", + "\n", + "$$\n", + "\\Y\\mid\\beta,\\X_0,\\beta_0,\\Sigma_t,\\sigma_s \\sim\\mathcal{MN}(\\X\\beta+\\X_0\\beta_0, \\Sigma_t, \\sigma_s\\trp\\mathbf{I}),\n", + "$$\n", + "\n", + "where $\\Sigma_t$ is the AR(1) temporal covariance matrix ($\\A\\inv$ in the BRSA paper), and $\\sigma_s$ is a spatial noise scale that allows each voxel to have its own noise level. This is not as general as voxel-specific AR coefficients, but it has far fewer parameters and allows us to tractably handle more complex temporal covariances. How this tradeoff plays out will differ across datasets. \n", + "\n", + "Now we add a matrix-normal prior on $\\beta$, which lets us marginalize $\\beta$ out using the following identity: if $\\beta\\sim\\mathcal{MN}(0,\\U,\\mathbf{V})$ is independent of $\\epsilon\\sim\\mathcal{MN}(0,\\Sigma,\\mathbf{V})$, then $\\X\\beta+\\epsilon\\sim\\mathcal{MN}(0,\\X\\U\\X\\trp+\\Sigma,\\mathbf{V})$. We parameterize the covariance of $\\beta$ in terms of its Cholesky factor $\\L$. \n", + "\n", + "$$\n", + "\\beta\\sim\\mathcal{MN}(0,\\L\\L\\trp, \\sigma_s\\trp\\I)\\\\\n", + "\\Y\\mid\\X_0,\\beta_0,\\Sigma_t,\\sigma_s \\sim\\mathcal{MN}(\\X_0\\beta_0, \\Sigma_t + \\X\\L\\L\\trp\\X\\trp , \\vec{\\sigma_s}\\trp\\mathbf{I})\\\\\n", + "$$\n", + "\n", + "Using the same identity, we can marginalize over $\\beta_0$. 
\n", + "\n", + "$$\n", + "\\beta_0\\sim\\mathcal{MN}(0,\\I, \\sigma_s\\I)\\\\\n", + "\\Y\\mid\\X_0,\\Sigma_t,\\sigma_s \\sim\\mathcal{MN}(0, \\Sigma_t + \\X\\L\\L\\trp\\X\\trp + \\X_0\\X_0\\trp , \\sigma_s\\mathbf{I})\n", + "$$\n", + "\n", + "Now, the temporal covariance is the sum of an autoregressive term, a low-rank noise term, and our term of interest. \n", + "\n", + "Next, we apply some computational tricks. Consider the matrix normal (log) density: \n", + "\n", + "$$\n", + "P(X; M, U, V) = \\frac{\\exp\\left(-\\frac12\\Tr\\left[V\\inv(X-M)\\trp U\\inv(X-M)\\right]\\right)}{(2\\pi)^{np/2}|U|^{p/2}|V|^{n/2}}\\\\\n", + "2 \\log P(X; M, U, V) = -\\Tr\\left[V\\inv(X-M)\\trp U\\inv(X-M)\\right]-np\\log 2\\pi-p\\log|U|-n\\log|V|\n", + "$$\n", + "\n", + "Here $n$ and $p$ are the row and column dimensions of $M$. Note that both the determinant and the inverse cost $O(n^3)$ and $O(p^3)$ for the two covariances. Furthermore, computing the determinant and then taking its log is numerically unstable. So instead we take the Cholesky decomposition of both covariances, at which point the log-determinant is just 2 times the sum of the logs of the diagonal elements of the Cholesky factor. Then, we recognize that the term inside of the trace can be computed by our favorite triangular matrix solver using the Cholesky factor we already paid for. Let $A = V, B = (X-M)\\trp$. Then solving $AW=B$ for $W$ gives us exactly $V^{-1}(X-M)\\trp$. We play the same exact trick for $A=U, B=(X-M)$ (though of course we center the data first, so $M=0$). \n", + "\n", + "Cai et al. additionally apply the matrix inversion lemma twice so that they invert something feature-by-feature instead of time-by-time. Doing this naively will not help us in this version because we're still stuck with computing the determinant (which is cubic in the number of timepoints). 
Here is the expression: \n", + "\n", + "$$\n", + "\\begin{aligned}\n", + "\\Sigma_Y :=& \\Sigma_t + \\X\\L\\L\\trp\\X\\trp + \\X_0\\X_0\\trp \\\\\n", + "\\mbox{let } \\Z :=& \\Sigma_t + \\X\\L\\L\\trp\\X\\trp\\\\\n", + "\\Sigma_Y\\inv =& (\\Z + \\X_0\\X_0\\trp)\\inv \\\\\n", + "=& \\Z\\inv - \\Z\\inv\\X_0(\\I + \\X_0\\trp\\Z\\inv\\X_0)\\inv\\X_0\\trp\\Z\\inv\\\\\n", + "\\Z\\inv =& \\Sigma_t\\inv - \\Sigma_t\\inv \\X\\L(\\I+\\L\\trp\\X\\trp\\Sigma_t\\inv\\X\\L)\\inv\\L\\trp\\X\\trp \\Sigma_t\\inv\\\\\n", + "\\end{aligned}\n", + "$$\n", + "\n", + "That said, if $\\Sigma_t\\inv$ and its determinant are cheap to compute (as in the case of AR(1), but not in general), we can also apply the matrix determinant lemma:\n", + "\n", + "$$\n", + "\\begin{aligned}\n", + "\\Sigma_Y :=& \\Sigma_t + \\X\\L\\L\\trp\\X\\trp + \\X_0\\X_0\\trp \\\\\n", + "\\mbox{let } \\Z :=& \\Sigma_t + \\X\\L\\L\\trp\\X\\trp\\\\\n", + "|\\Sigma_Y| =& |\\Z + \\X_0\\X_0\\trp| \\\\\n", + "=& |\\Z|\\times|\\I + \\X_0\\trp\\Z\\inv\\X_0|\\\\\n", + "\\Z\\inv =& \\Sigma_t\\inv - \\Sigma_t\\inv \\X\\L(\\I+\\L\\trp\\X\\trp\\Sigma_t\\inv\\X\\L)\\inv\\L\\trp\\X\\trp \\Sigma_t\\inv\\\\\n", + "|\\Z| =& |\\Sigma_t| \\times|\\I+\\L\\trp\\X\\trp\\Sigma_t\\inv\\X\\L|\\\\\n", + "=& \\frac{|\\I+\\L\\trp\\X\\trp\\Sigma_t\\inv\\X\\L|}{|\\Sigma_t\\inv|} \\\\\n", + "\\end{aligned}\n", + "$$\n", + "\n", + "Now we notice that we can still apply our Cholesky-inverse-solve trick, because with the lemmas applied the same small matrices (e.g. $\\I+\\L\\trp\\X\\trp\\Sigma_t\\inv\\X\\L$) appear both inside the inverses and inside the determinants, so one Cholesky factorization serves both. As long as the inverse and determinant of the temporal noise covariance are computable in better than cubic time, this is worth doing. Currently the lemma trick is not implemented in the code (but the Cholesky trick is). 
\n", + "\n", + "Now here is an example: " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAMsElEQVR4nO3dfcyddX3H8ffHFsEisTAeVEpWNISFMAekWVAXt4hzDAn1j/2BjgWmSf/Yg2hICA/LzJJlMdH4kGzREEDJRPwDUQnxga5qzLLZUGp5LBMGDFqKZTFTUh3Q8N0f52pSm7tre851Xefc/b1fyZ37nHOdc77fc/e+P/1d17l+55eqQlK7XjPvBiTNlyEgNc4QkBpnCEiNMwSkxhkCUuPmHgJJLk7yH0meSHLdwLXOSPL9JI8meSTJ1UPW26/uiiQ/TnLPCLVWJ7kzyWNJtid5+8D1Ptb9LB9OckeS43p+/luT7E7y8H63nZRkY5LHu+8nDlzvk93P88EkX0+yesh6+227JkklObmvekuZawgkWQH8E/DHwDnAB5KcM2DJvcA1VXUOcCHwlwPX2+dqYPsIdQA+B3ynqn4L+J0h6yY5HfgIsK6qzgVWAJf3XOZLwMUH3HYdsKmqzgI2ddeHrLcROLeq3gb8BLh+4HokOQN4L/BMj7WWNO+RwO8CT1TVk1X1MvBVYP1QxapqV1Vt7S6/yOQP5PSh6gEkWQO8D7h5yDpdrTcA7wJuAaiql6vqfwYuuxJ4XZKVwCrguT6fvKp+CPzsgJvXA7d1l28D3j9kvaq6t6r2dld/BKwZsl7nM8C1wOBn8807BE4Hnt3v+g4G/qPcJ8la4Hxg88ClPsvkH/PVgesAnAm8AHyx2/24OcnxQxWrqp3Ap5j8b7UL+HlV3TtUvf2cVlW7usvPA6eNUHOfDwHfHrJAkvXAzqp6YMg6+8w7BOYiyeuBrwEfrapfDFjnUmB3Vd0/VI0DrAQuAD5fVecDe+h3qPxrun3x9UzC583A8UmuGKreUmpy3vso574nuZHJLuXtA9ZYBdwA/O1QNQ407xDYCZyx3/U13W2DSXIMkwC4varuGrIW8E7gsiRPM9nVeXeSLw9Ybwewo6r2jW7uZBIKQ3kP8FRVvVBVrwB3Ae8YsN4+P03yJoDu++6hCya5CrgU+NMadsLNW5mE6gPd780aYGuSNw5VcN4hcB9wVpIzk7yWyUGlu4cqliRM9pe3V9Wnh6qzT1VdX1Vrqmotk9f2vaoa7H/KqnoeeDbJ2d1NFwGPDlWPyW7AhUlWdT/bixjnAOjdwJXd5SuBbw5ZLMnFTHbpLquqXw5Zq6oeqqpTq2pt93uzA7ig+7cdrOhcv4BLmBxx/U/gxoFr/R6ToeODwLbu65KRXucfAPeMUOc8YEv3Gr8BnDhwvb8DHgMeBv4ZOLbn57+DyfGGV7o/iA8Dv8HkXYHHgX8BThq43hNMjl3t+535wpD1Dtj+NHDykP+G6QpJatS8dwckzZkhIDXOEJAaZwhIjTMEpMYtTAgk2WA96y1arRbqLUwIAKO+cOst63pH82sbvd4ihYCkORj1ZKEVJxxfK09Z+vMYXn1xD685YekJb8c+9avee3mFlziGY3t/XusdXbWOlnr/yx5erpey1LaVvVY6hJWnrGbNP/zFET/uLR/cNkA3Ujs216aDbnN3
QGrcTCEw5ucDShrG1CEwh88HlDSAWUYCo34+oKRhzBICc/t8QEn9GfzAYJINSbYk2fLqi3uGLifpCM0SAof1+YBVdVNVrauqdQc7D0DS/MwSAqN+PqCkYUx9slBV7U3yV8B3maw8c2tVPdJbZ5JGMdMZg1X1LeBbPfUiaQ48Y1Bq3KhzB4596ldTzQN48ivnTVXPOQfSoTkSkBpnCEiNMwSkxhkCUuMMAalxhoDUOENAapwhIDXOEJAaZwhIjTMEpMYZAlLjDAGpcaPOIpzWtLMBnX0oHZojAalxhoDUOENAatwsy5CdkeT7SR5N8kiSq/tsTNI4ZjkwuBe4pqq2JjkBuD/Jxqp6tKfeJI1g6pFAVe2qqq3d5ReB7bgMmbTs9HJMIMla4Hxgcx/PJ2k8M58nkOT1wNeAj1bVL5bYvgHYAHAcq2YtJ6lnM40EkhzDJABur6q7lrrP/msRHsOxs5STNIBZ3h0IcAuwvao+3V9LksY0y0jgncCfAe9Osq37uqSnviSNZJYFSf8VSI+9SJoDzxiUGrcsZhFOy9mH0qE5EpAaZwhIjTMEpMYZAlLjDAGpcYaA1DhDQGqcISA1zhCQGmcISI0zBKTGGQJS4wwBqXFH9SzCaTn7UC1xJCA1zhCQGmcISI2bOQSSrEjy4yT39NGQpHH1MRK4mskSZJKWoVkXH1kDvA+4uZ92JI1t1pHAZ4FrgVd76EXSHMyyAtGlwO6quv8Q99uQZEuSLa/w0rTlJA1k1hWILkvyNPBVJisRffnAO7kWobTYpg6Bqrq+qtZU1VrgcuB7VXVFb51JGoXnCUiN62XuQFX9APhBH88laVyOBKTGOYuwR84+1HLkSEBqnCEgNc4QkBpnCEiNMwSkxhkCUuMMAalxhoDUOENAapwhIDXOEJAaZwhIjTMEpMY5i3ABOPtQ8+RIQGqcISA1zhCQGjfrCkSrk9yZ5LEk25O8va/GJI1j1gODnwO+U1V/kuS1wKoeepI0oqlDIMkbgHcBVwFU1cvAy/20JWkss+wOnAm8AHyxW5r85iTH99SXpJHMEgIrgQuAz1fV+cAe4LoD7+RahNJimyUEdgA7qmpzd/1OJqHwa1yLUFpss6xF+DzwbJKzu5suAh7tpStJo5n13YG/Bm7v3hl4Evjz2VuSNKaZQqCqtgHreupF0hx4xqDUOGcRLmPOPlQfHAlIjTMEpMYZAlLjDAGpcYaA1DhDQGqcISA1zhCQGmcISI0zBKTGGQJS4wwBqXGGgNQ4ZxE2aNrZgN99brrH/c3u357qcWO777wV825hLhwJSI0zBKTGGQJS42Zdi/BjSR5J8nCSO5Ic11djksYxdQgkOR34CLCuqs4FVgCX99WYpHHMujuwEnhdkpVMFiN9bvaWJI1plsVHdgKfAp4BdgE/r6p7+2pM0jhm2R04EVjPZGHSNwPHJ7liifu5FqG0wGbZHXgP8FRVvVBVrwB3Ae848E6uRSgttllC4BngwiSrkoTJWoTb+2lL0lhmOSawmclKxFuBh7rnuqmnviSNZNa1CD8OfLynXiTNgWcMSo1zFqEO27SzAf/+1IdGracj40hAapwhIDXOEJAaZwhIjTMEpMYZAlLjDAGpcYaA1DhDQGqcISA1zhCQGmcISI0zBKTGOYtQg3P24WJzJCA1zhCQGmcISI07ZAgkuTXJ7iQP73fbSUk2Jnm8+37isG1KGsrhjAS+BFx8wG3XAZuq6ixgU3dd0jJ0yBCoqh8CPzvg5vXAbd3l24D399yXpJFMe0zgtKra1V1+Hjitp34kjWzmA4NVVUAdbLtrEUqLbdoQ+GmSNwF033cf7I6uRSgttmlD4G7gyu7ylcA3+2lH0tgO5y3CO4B/B85OsiPJh4FPAH+Y5HEmqxN/Ytg2JQ3lkHMHquoDB9l0Uc+9SJoDzxiUGucsQi2ssWcf/hHnTfW45c6RgNQ4Q0BqnCEgNc4QkBpnCEiNMwSkxhkCUuMMAalxhoDUOENAapwh
IDXOEJAaZwhIjXMWoQ7bfeetmHcLh2Xa2YBPfmW6x73lg9umetyicCQgNc4QkBpnCEiNm3Ytwk8meSzJg0m+nmT1sG1KGsq0axFuBM6tqrcBPwGu77kvSSOZai3Cqrq3qvZ2V38ErBmgN0kj6OOYwIeAbx9so8uQSYttphBIciOwF7j9YPdxGTJpsU19slCSq4BLgYu6RUklLUNThUCSi4Frgd+vql/225KkMU27FuE/AicAG5NsS/KFgfuUNJBp1yK8ZYBeJM2BZwxKjXMWodSZdjbgcp996EhAapwhIDXOEJAaZwhIjTMEpMYZAlLjDAGpcYaA1DhDQGqcISA1zhCQGmcISI0zBKTGOYtQmtFyn33oSEBqnCEgNW6qZcj223ZNkkpy8jDtSRratMuQkeQM4L3AMz33JGlEUy1D1vkMk48dd80BaRmb6phAkvXAzqp6oOd+JI3siN8iTLIKuIHJrsDh3H8DsAHgOFYdaTlJA5tmJPBW4EzggSRPM1mReGuSNy51Z9cilBbbEY8Equoh4NR917sgWFdV/91jX5JGMu0yZJKOEtMuQ7b/9rW9dSNpdJ4xKDXOEJAa5yxCaU7GnH340g3/dtBtjgSkxhkCUuMMAalxhoDUOENAapwhIDXOEJAaZwhIjTMEpMYZAlLjDAGpcYaA1DhDQGpcqsb7xPAkLwD/dZDNJwNjfkSZ9ZZvvaP5tQ1V7zer6pSlNowaAv+fJFuqap31rLdItVqo5+6A1DhDQGrcIoXATdaz3gLWOurrLcwxAUnzsUgjAUlzYAhIjTMEpMYZAlLjDAGpcf8Hu2X5EIe6s0cAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import scipy\n", + "from scipy.stats import norm\n", + "from scipy.special import expit as inv_logit\n", + "import numpy as np\n", + "from numpy.linalg import cholesky\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def rmn(rowcov, colcov):\n", + " # generate random draws from a zero-mean matrix-normal distribution\n", + " Z = norm.rvs(norm.rvs(size=(rowcov.shape[0], colcov.shape[0])))\n", + " return cholesky(rowcov).dot(Z).dot(cholesky(colcov))\n", + "\n", + "\n", + "def make_ar1_with_lowrank_covmat(size, rank):\n", + " \"\"\" Generate a random covariance that is AR1 with added low rank structure\n", + " \"\"\"\n", + " sigma = np.abs(norm.rvs())\n", + " rho = np.random.uniform(-1, 0)\n", + " offdiag_template = scipy.linalg.toeplitz(np.r_[0, 1, np.zeros(size - 2)])\n", + " diag_template = np.diag(np.r_[0, np.ones(size - 2), 0])\n", + " I = np.eye(size)\n", + "\n", + " prec_matrix = (I - rho * offdiag_template + rho ** 2 * diag_template) / (sigma ** 2)\n", + " lowrank_matrix = norm.rvs(size=(size, rank))\n", + " return np.linalg.inv(prec_matrix) + lowrank_matrix.dot(lowrank_matrix.T)\n", + "\n", + "\n", + "def gen_data(n_T, n_V, space_cov, time_cov):\n", + "\n", + " n_C = 16\n", + " U = np.zeros([n_C, n_C])\n", + " U = np.eye(n_C) * 0.6\n", + " U[8:12, 8:12] = 0.8\n", + " for cond in range(8, 12):\n", + " U[cond, cond] = 1\n", + "\n", + " beta = rmn(U, space_cov)\n", + "\n", + " X = rmn(np.eye(n_T), np.eye(n_C))\n", + "\n", + " Y_hat = X.dot(beta)\n", + "\n", + " Y = Y_hat + rmn(time_cov, space_cov)\n", + "\n", + " return beta, X, Y, U\n", + "\n", + "\n", + "n_T = 100\n", + "n_V = 80\n", + "n_C = 16\n", + "\n", + "spacecov_true = np.diag(np.abs(norm.rvs(size=(n_V))))\n", + "timecov_true = make_ar1_with_lowrank_covmat(n_T, rank=7)\n", + "\n", + "true_beta, true_X, true_Y, true_U = gen_data(n_T, n_V, spacecov_true, timecov_true)\n", + 
"\n", + "%matplotlib inline\n", + "plt.matshow(true_U)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That is the target matrix. Now we noisify it using a simple synthetic brain data generator, and recover it with MN-RSA. We intentionally code up MN-RSA here from the building blocks the toolkit provides so we can illustrate how easy it is to build new models: " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQoAAADxCAYAAAAz6fmnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAbKElEQVR4nO3deZBd5Xnn8e9P3dqXloSEJJCwZCITYwYblYbBZuKFJRE2hVw1GRc4C8SM8UwFhzh4XBCmIMXUTJE4Y5uZMBiBiYmhwAyxx6pYbMYwVFKA2ReBARkT0WLRggxBe3c/88c5jS+93ff0Offq3tu/T9Up7r398J5X3dLT73nPe55XEYGZ2VgmHewOmFnrc6Iws7qcKMysLicKM6vLicLM6nKiMLO6nCjM2oik6yVtlfTMKF+XpP8paZOkpyStquK8ThRm7eW7wJoxvn4asDI/zgOuruKkThRmbSQi7gfeHCNkLfB3kXkQmCtpSdnzOlGYdZbDgVdq3vfmn5XSXbYBMxvb73xqZux4sz8p9tGn9m0E9tZ8tC4i1jWkYwU4UZg12PY3+3nozqVJsZOX/GJvRKwucbotwLKa90vzz0rxpYdZwwX9MZB0VGA98If53Y8TgLci4rWyjXpEYdZgAQxQzVPakm4GPgkskNQLXAZMBoiIbwMbgE8Dm4DdwB9VcV4nCrMGC4IDkTZHUbetiLPqfD2AP67kZDWcKMyaoKoRxcFyUOcoJK2R9Hy+iuyiitpcJuleSc9K2ijpgirarWm/S9Ljkv6h4nbnSrpN0s8lPSfpoxW1+5X8+/CMpJslTSvR1rBVgZLmS7pb0ov5f+dV1O7X8+/FU5J+KGluVX2u+dqFkkLSgvG0nSqAfiLpaFUHLVFI6gKuIltJdjRwlqSjK2i6D7gwIo4GTgD+uKJ2B10APFdhe4OuBO6IiN8EPlzFOSQdDvwJsDoijgG6gDNLNPldhq8KvAi4JyJWAvfk76to927gmIg4FngBuHgc7Y7WNpKWAb8NbB5nu4UMEElHqzqYI4rjgU0R8VJE7AduIVtVVkpEvBYRj+Wv/4XsH1zpBScAkpYCnwGuq6K9mnZ7gI8D3wGIiP0R8auKmu8GpkvqBmYAr463oVFWBa4Fbshf3wB8top2I+KuiOjL3z5IdpuvsDFWMn4T+Bo0/l9nAP0RSUerOpiJoiEryGpJWg4cBzxUUZPfIvvLVcl9rBorgG3A3+aXNddJmlm20YjYAvw12W/N18huld1Vtt0hFtXcfnsdWFRx+wBfAG6vqjFJa4EtEfFkVW3WM5B4tKqOXUchaRb
w98CfRsTbFbR3OrA1Ih4t3bnhuoFVwNURcRywi/EN4d8jny9YS5aIDgNmSvr9su2OJp9xr/TXoqRLyC4nb6qovRnAnwOXVtFeikicn/AcxcgasoIMQNJksiRxU0T8oIo2gROBMyS9THaZdJKkGytquxfojYjBkc9tZImjrFOAX0bEtog4APwA+FgF7dZ6Y/Cho/y/W6tqWNI5wOnA70V15eKPJEucT+Y/y6XAY5IWV9T+MBFwIPFoVQczUTwMrJS0QtIUskm29WUblSSya/3nIuIbZdsbFBEXR8TSiFhO1tefRkQlv50j4nXgFUlH5R+dDDxbQdObgRMkzci/LydT/UTseuDs/PXZwI+qaFTSGrLLvDMiYncVbQJExNMRcWhELM9/lr3Aqvxn0CCiP/FoVQctUeQTVecDd5L95b01IjZW0PSJwB+Q/cZ/Ij8+XUG7jfZl4CZJTwEfAf572QbzEcptwGPA02Q/73E/YJSvCnwAOEpSr6RzgSuAUyW9SDaCuaKidv8GmA3cnf8Mv11hn5sqgIFIO1qVvAGQWWMdc+yUuPXHC5NiP3TEq4+WfCisIbwy06zBsgVXrXtZkcKJwqwJBsKJwszG4BGFmdUViAPRdbC7UUpLLLiSdJ7bbVy7jWy73dptdNsjGRxR+PZoeY36wbndxrfdbu02uu0RiP6YlHS0Kl96mDVYVuGqdZNAiqYmigXzu2L5ssnDPj/i8G5Wf3jaexZ0PP9yeomAge6Rh2xTZs5j5iHLhi0U6d6TXm1of8/wa8vJc+Yxfcnwdrv2DgsdVdeevmGfTeueQ8+0JcPaHZiafn07af/If7Zpk+fQM32EticXuHYe4ds8ddpcZvcsHb4Yp8DynBjh5zdlxjxmzR/+PS5qpJsNo/29KGL3m73bIyJtcQSezCxk+bLJ/OzOZfUDgU994YvJ7e45pNhE0byn05/g7l0zPzl27qb0BDT7me3JsbuPTO/DjM3Fnn/bt3hWcuxoCXkk6k//d7hv/vBfHlW0C9A3rTH/QB++8av/nBoboZa+rEhRqveNqFBl1okGUNLRqsY9oqipUHUq2YM1D0taHxFVPMxk1jECsT/aezqwzIiiIRWqzDrN4GRmytGqyqS5kSpU/Zty3THrTP1ewj22fHHLeZDd3TCbaALR38KjhRRlep9UoSoi1kXE6ohYvbDg3QmzTjEQk5KOVlXmV/y7FarIEsSZwOcr6ZVZB8mWcLduEkgx7t43sEKVWUcZfCgs5ain3pIESUfkG2A9nm+eVEl1t1KTBhGxgWxT1CTPv7wgeSHVvddfm9yPT3/41ORYgD2r3pccu/iBPcmxk7e/kxz7ytr0qvaLf5behx2r0hdnAexZmD7JNvP19ILyB6antzt7y/BVqqOZtL9YUfsdH5+aHDt1pN0/KhBBJQuuEpck/BeyX9pX5xtfbQCWlz13e4+HzNpC2mKrhAVXKUsSApiTv+6hxIZPtXwbwqzBsp3Ckn8nL5D0SM37dRExWBA5ZUnCXwB3SfoyMJOs4HFpThRmTVBgMnN7yeK6ZwHfjYj/kW90/T1Jx0REqY3InCjMGixQVTUzU5YknEu+KXNEPJDvXr+AkhszeY7CrAn6mZR01JGyadZmso2ekPRBYBrZvraleERh1mBV1cyMiD5Jg0sSuoDrI2KjpMuBRyJiPXAhcK2kr5BNj5xTxXaMThRmDZbtFFbN4H2kJQkRcWnN62fJdsurlBOFWRO4wpWZjSlCLf0cRwonCrMmaPdSeE1NFAPdSq5vWWRZ9oYn7y7Uj9NOOys5duDJ55Jj+084Njl22TXPJMe+ceaHkmNnbC92u3zKrvTY3QvT/7L3/GJ/cuz0jcMeOh7V3g8enhwL8P4b0tse6JmZHPt0gT5khWt86WFmY2r/4rpOFGYNFjBxtxSUtCx/nPVZSRslXVBlx8w6xeDKzJSjVZUZUfQBF0bEY5JmA49KuttVuM2Ga+XCuSnGnSgi4jXgtfz1v0h6juz
pNicKsxpZPYrWHS2kqGSOQtJy4DjgoSraM+s0rXxZkaJ0opA0C/h74E8jYth+drVVuKfMnFf2dGZtJ5ujmKCXHgCSJpMliZsi4gcjxeRFN9YBpTeGNWtXE3YJtyQB3wGei4hvVNcls84SiL6BCXp7lOwJtT8ATpL0RH5UUvHXrNNM2E2KI+IfoYX/ZGYtwnc9ip5sTz/znv5VUmyRkvpFnt0AuP32m5NjP/nFtO0FAKbsTH++oe+YFcmxXenNsvMDBYe4BWaNpm9LD371t9LL5PcsSf9ezNm8NzkW4I2TDkuOVZEZtMcLdWNiT2aaWX0V1sw8aJwozJqglecfUjhRmDVYVgrPicLMxhLtf3vUicKswVy4xsyS+NLDzMbkOQozS+JEYWZj8joKM6svoM8rM9Pt7+mid838pNjFD+xJbrdISX0otiz7vmuvTY79xJfOS44dmJz+G+bArPTYWb3FyvXPeyptST3AO0f2JMd2703v89bjk0NRpC8Nz/6H9NCFD+wo1naiKucoJK0BriTbe/S6iLhihJjPAX+Rn/rJiPh82fN6RGHWBFUkCkldwFXAqUAv8LCk9bV1aiWtBC4GToyInZIOLX1iyj1mbmYJKqzCfTywKSJeioj9wC3A2iExXwSuioidABGxtYo/Q+lEIalL0uOS/qGKDpl1ogglHcACSY/UHLXXs4cDr9S8780/q/UB4AOS/knSg/mlSmlVXHpcADwHzKmgLbOOVGBl5vaIWF3iVN3ASuCTwFLgfkn/KiLSJ6NGUGpEIWkp8BngujLtmHWyCKq69NgCLKt5vzT/rFYvsD4iDkTEL4EXyBJHKWUvPb4FfA0Ydapd0nmDw6j+3QV2xDXrGKJ/YFLSUcfDwEpJKyRNAc4E1g+J+b9kowkkLSC7FHmp7J+gzJaCpwNbI+LRseIiYl1ErI6I1V0z0neLNuskBeYoxmgj+oDzgTvJLvdvjYiNki6XdEYediewQ9KzwL3Af46I0vd9y8xRnAickRfUnQbMkXRjRPx+2U6ZdZIq11FExAZgw5DPLq15HcCf5Udlxj2iiIiLI2JpRCwnGwL91EnCbASRzVOkHK3KC67MmsD1KICIuA+4r15c116Yu6k/qc3J299JPn//Cccmx0KxatlFlmX/v2vWJcd+ZtXvJMf2nZRepbroCHfS1p3JsdNnpS+fjknpHdm1ZEZybPeetL8/g/YsTB807z90VnrDG9NDA+rOP7Q6jyjMGs5Pj5pZgoEBJwozG0M2UelEYWZ1+NLDzOpq5VufKZwozJrAlx5mNqag/vLsVudEYdYEbX7l4URh1nAB4dujZlaPLz0K6NrTx+xntifFvrJ2UXK7y655plA/+o5JXxJdpFp2kWXZP37szuTYNWekP2u3d9H05FiAXauOSI6dtm1vcuzuw9L7sXdh+sB86R3FCjXN3pS+hHvfosaVQfBdDzMbk5/1MLP6guJP67WYsjUz50q6TdLPJT0n6aNVdcysk0z0ehRXAndExO/mNfzSnxc2m0haOAmkGHeikNQDfBw4ByDfkCS90IPZhKG2vz1a5tJjBbAN+Nt8A6DrJA2bNq6twr2/f3eJ05m1qaimuO7BVCZRdAOrgKsj4jhgF3DR0KDaKtxTunxlYhNUJB4tqkyi6AV6I+Kh/P1tZInDzIZR4tGaylThfh14RdJR+UcnA8+O8b+YTVxtPqIoe9fjy8BN+R2Pl4A/Kt8lsw7UwkkgRalEERFPAMkbqg5M7WL3kfOTYhf/bE9yP94480PJsQBdBe7NHJiVPhwsUi27yLLsO9bfmBx72soTk2MBDvzro+oH5fbNT6/CPW3bvuTYnufT5642n35IcizAYfenb2PZ/c6BQm0n80NhZpakzUcUZTcpNrMUobSjDklrJD0vaZOkYXcZa+L+naSQlDziH4sThVkTKNKOMduQuoCrgNOAo4GzJB09Qtxs4ALgoaFfGy8nCrNGS73jUf/y5HhgU0S8lK+EvgVYO0L
cfwX+EkivC1CHE4VZwyVedmSXHgsGVzLnR+2elocDr9S8780/+/WZpFXAsoj4cZV/Ak9mmjVD+mTm9ogY17yCpEnAN8ifv6qSE4VZMwxU0soWYFnN+6X5Z4NmA8cA90kCWAysl3RGRDxS5sROFGaNVl3hmoeBlZJWkCWIM4HPv3uaiLeABYPvJd0HfLVskgDPUZg1RRV3PSKiDzgfuBN4Drg1IjZKulzSGY3sv0cUZs1Q0YKriNgAbBjy2aWjxH6ymrM2OVFM2t/PjM1vJ8XuWJW21BtgxvZiF4A7P9CVHDurN73tIqPLItWyiyzLvv3Ff0rvBHDPngeTYx/bszw59qipryXHXnblOcmxS3/yVnIswNsrZyfHTt3ZV6jticQjCrMmqHdZ0eqcKMyaoYWrV6UoW4X7K5I2SnpG0s2SplXVMbOOEWS3R1OOFjXuRCHpcOBPgNURcQzQRXa7xsyGqOKux8FU9tKjG5gu6QBZqf5Xy3fJrAO1cBJIUaYU3hbgr4HNwGvAWxFxV1UdM+sobV4Kr8ylxzyyJ9dWAIcBMyUNK9vkcv020aVedrTypUeZycxTgF9GxLaIOAD8APjY0CCX6zejssI1B0uZRLEZOEHSDGVPoJxMtqzUzIZq80uPcU9mRsRDkm4DHgP6gMeBdVV1zKyTqIVvfaYoW4X7MuCyivpi1plafP4hRVNXZg5M7mLf4llJsXsWpl+vTUmvyJ4p8EOb99SvkmMnbd2ZHLtr1RHJsUVK6hd5dgPg5On9ybEv7kuvrNZV4FfozNfT+9A/Y0pyLMC0N9Of39g3t4H/HJwozKwuJwozq6fdLz1cuMbM6vKIwqwZ2nxE4URh1mgxwW+PmlkijyjMbCyi/ScznSjMmsGJwszG5JWZZpbEiaIAwUB32tLsma+nTxPvXlhsOcj0bek/tXeO7Elvd9bU5Nhp29KXQ++bn95ukZL6UGxZ9n+cu6V+UO5LvR9Njn37iPTtE2a/sCc5FmDf/DnJsT0b05fgF+W7HmZWn0cUZjamFq81kaLumF3S9ZK2Snqm5rP5ku6W9GL+33mN7aZZe6uqFJ6kNZKel7RJ0kUjfP3PJD0r6SlJ90h6XxX9T7m4/y6wZshnFwH3RMRK4J78vZmNpoIKV5K6gKuA04CjgbMkHT0k7HGyLTSOBW4D/qqK7tdNFBFxP/DmkI/XAjfkr28APltFZ8w6VUUjiuOBTRHxUkTsB24h+7f4roi4NyIGq1g/CCytov/jfXp0UUQM7kL7OrBotMDaKtwH9hetMGPWIdJHFAsG/73kx3k1rRwOvFLzvjf/bDTnArdX0f3Sk5kREdLouTAi1pHX0pzds7TNp3TMiitYin97RKwufc5s64zVwCfKtgXjH1G8IWlJ3qElwNYqOmPWsaqpwr0FWFbzfmn+2XtIOgW4BDgjIvaV7Dkw/kSxHjg7f3028KMqOmPWqSqao3gYWClphaQpZHv9rn/PeaTjgGvIkkRlv8BTbo/eDDwAHCWpV9K5wBXAqZJeJNsI6IqqOmTWkSoYUUREH3A+cCfZHjq3RsRGSZdLOiMP+zowC/g/kp6QtH6U5gqpO0cREWeN8qWTC58tQP1pF2sHpqdX4e75xf5C3Xj1t9KXRHfvTe9HTEqP3X3Y9OTYadvSR49HTX2tflCNItWyiyzLvmbpA8mxHx44Ljl236Ezk2MBpu5I/7ux6/3py/V5pn7Ie1Q0OxcRG4ANQz67tOb1KdWc6b28MtOs0fz0qJklcaIws3r89KiZ1eVLDzMbWwc8PepEYdYMThRmNhZX4TazNE4UZlaPor0zhROFWaN5S8Fiolvsmz85KXb2lr7kdqdvTK8ODdCzZEVy7Nbj09vdtWRGcuzehem/YXqeT2/3sivPSY4FmPl6f3JskWrZRZZlP/m1/50c+6kvfDE5FuCtI6ckxx7y5FuF2i6kvQcUHlGYNYMnM82svjZPFOOtwv11ST/PK/3+UNLcxnbTrI0l1qJo5VH
HeKtw3w0ck1f6fQG4uOJ+mXWWaipcHTTjqsIdEXflRTSgwkq/Zp1ocMFVO48oqpij+ALw/dG+mFcRPg9gygzvE2QTkwZaOAskGG/NTAAkXQL0ATeNFhMR6yJidUSsnjy1WHUis46QetnRwrlk3CMKSecApwMnR7T5sjOzBpuQC64krQG+BnyiZlciMxtNm/8qHW8V7r8BZgN355V+v93gfpq1tY6fzBylCvd3xnvC1Crck/anj9X2fnCsXdWGm7N5b3KsokDF7j3py6GX3vGr5NjNpx+S3u5Pii1D7p+RvsR59gt7kmOLVMsusiz73uuvTY4F+Ldf/lJy7Fu/OSe94ccLdCKANr8698pMsyaYkHMUZpbOhWvMrL4IX3qYWX3tPqIoteDKzBJVtOBK0hpJz0vaJOmiEb4+VdL3868/JGl5Fd13ojBrgipuj0rqAq4CTgOOBs6SdPSQsHOBnRHxG8A3gb+sov9OFGaNFsBApB1jOx7YFBEvRcR+4BZg7ZCYtcAN+evbgJMlpe+ePQonCrMm0EDaASyQ9EjNcV5NM4cDr9S8780/Y6SY/Anvt4D0hTij8GSmWTOk3/XYHhGrG9mV8fCIwqwJKlrCvQVYVvN+af7ZiDGSuoEeYEfZ/jtRmDVadY+ZPwyslLRC0hTgTGD9kJj1wNn5698FflrF093NLdcv6JuWNq+y4+Ppz1i8/4Zi5frfOOmw9OAC00B7Fqbn3dmb0mMPu39XcuzbK2cnxwJMezN9W4R989OfhZi6Y39ybJGS+kWe3QD4x/91TXLsb9z0n9Ibvjk9NFuZWX4hRUT0STofuBPoAq6PiI2SLgceiYj1ZM9hfU/SJrLKdGeWPjGeozBrjoqe9YiIDcCGIZ9dWvN6L/Dvqznbr42rCnfN1y6UFJIWVN0xs06iiKSjVY23CjeSlgG/DWyuuE9mnSUS11C0cF3NcVXhzn2TrMpV6/7pzFpExxeuGYmktcCWiHiygkVfZp2vhS8rUhROFJJmAH9OdtmREv/rcv0zXa7fJqBo/8I141lHcSSwAnhS0stkiz4ek7R4pODacv3dLtdvE9VgTYp6R4sqPKKIiKeBQwff58lidURsr7BfZp2ldXNAkvFW4TazAtr99uh4q3DXfn15Zb0x60QBJFafb1UtuzJz6kg3ZEcx0FNs7qPIbaiFD6Q/T7P/0FnJsfsWpfe5+50DybFTd6YvyQbYNzf9r0DPxp3Jsbve35Mce8iT6VsMFCqpT7Fl2Zt+7+rk2K6vpvdBtPZoIUXLJgqzjuJEYWZ1OVGY2ZiCyh4KO1icKMyawHMUZlafE4WZjSkCBtr72sOJwqwZ2jtPOFGYNYPnKMysPicKMxvT4E5hbUwVVPJOP5m0DfjnEb60AGjE06dut/Ftt1u7VbX9vohYmBLYM21xfOyIs+sHAne8+FePtuIGQM0t1z/KN1bSI4345rjdxrfdbu02uu1R+dLDzMYUQH973/ZwojBruIBwoqjCOrfb0HYb2Xa7tdvotkfW5pceTZ3MNJuIeqYsio8tHrP+07vueOXKcU9mSpoPfB9YDrwMfC4idg6J+QhwNTAH6Af+W0R8v17b3qTYrBmaU1z3IuCeiFgJ3JO/H2o38IcR8SGyjb2+JWluvYadKMyaoTmJYi1wQ/76BuCzw7sRL0TEi/nrV4GtQN3bvK0yR2HWuSKgv78ZZ1oUEa/lr18HFo0VLOl4YArwi3oNO1GYNUP6aGGBpEdq3q+LiHcnXyX9BBhpD51L3nu6CGn06rCSlgDfA86OqH9LxonCrBnSE8X2sSYzI+KU0b4m6Q1JSyLitTwRbB0lbg7wY+CSiHgwpVOeozBruKbtZr4eGFwrfjbwo6EBkqYAPwT+LiJuS23YicKs0QIiBpKOkq4ATpX0InBK/h5JqyVdl8d8Dvg4cI6kJ/LjI/Ua9qWHWTM04enRiNgBnDzC548A/yF/fSNwY9G2nSjMmqHNFzY6UZg1WvNujzaME4VZE4S
L65rZ2CpZdXlQOVGYNVoHlMJzojBrBtejMLOxBBAeUZjZmMIVrswsQbT57VFXuDJrMEl3kG0RkGJ7RKxpZH/Gw4nCzOryQ2FmVpcThZnV5URhZnU5UZhZXU4UZlbX/wf37nUQme9vRgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import tensorflow as tf\n", + "from brainiak.matnormal.covs import CovDiagonal, CovAR1, CovUnconstrainedCholesky\n", + "from brainiak.utils.utils import cov2corr\n", + "from brainiak.matnormal.utils import (\n", + " make_val_and_grad,\n", + " pack_trainable_vars,\n", + " unpack_trainable_vars,\n", + " unflatten_cholesky_unique,\n", + ")\n", + "from brainiak.matnormal.matnormal_likelihoods import matnorm_logp_marginal_row\n", + "from scipy.optimize import minimize\n", + "\n", + "space_cov = CovDiagonal(size=n_V)\n", + "time_cov = CovAR1(size=n_T)\n", + "\n", + "rsa_cov = CovUnconstrainedCholesky(size=n_C)\n", + "\n", + "params = (\n", + " rsa_cov.get_optimize_vars()\n", + " + time_cov.get_optimize_vars()\n", + " + space_cov.get_optimize_vars()\n", + ")\n", + "\n", + "# construct loss (marginal likelihood constructed automatically)\n", + "# note that params are ignored by this function but implicitly\n", + "# tracked by tf.GradientTape, and the remaining inputs are\n", + "# embedded via the closure mechanism\n", + "def loss(params):\n", + " return -(\n", + " time_cov.logp\n", + " + space_cov.logp\n", + " + rsa_cov.logp\n", + " + matnorm_logp_marginal_row(\n", + " true_Y, row_cov=time_cov, col_cov=space_cov, marg=true_X, marg_cov=rsa_cov\n", + " )\n", + " )\n", + "\n", + "\n", + "val_and_grad = make_val_and_grad(lossfn=loss, train_vars=params)\n", + "\n", + "x0 = pack_trainable_vars(params)\n", + "\n", + "opt_results = minimize(fun=val_and_grad, x0=x0, jac=True, method=\"L-BFGS-B\")\n", + "\n", + "fit_params = unpack_trainable_vars(opt_results.x, params)\n", + "\n", + "for var, val in zip(params, fit_params):\n", + " var.assign(val)\n", + "\n", + "U = rsa_cov._cov.numpy()\n", + "C = cov2corr(U)\n", + "plt.matshow(C)\n", + "plt.colorbar()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In practice, MN-RSA is already 
implemented in brainiak.matnormal, including the nuisance regressor estimation of Cai et al." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASW0lEQVR4nO3df2yc9X0H8Pfbd/6ZOLGDyU8bHApj/BA0LGJAUNc1HaT8Cms3iapMYUVimrZBKyTED2nd/pkqtWph2gRDhCYaWZAGKaVAQ9IU1HUMWBISEhIIIQmJk4BNSBzXjs/2+bM/7snmeHcO97nnec7h+35JUWyfP/587Tu//dzd870PzQwiEq6aai9ARKpLISASOIWASOAUAiKBUwiIBE4hIBK4qocAySUk3yO5m+T9CffqIPkKyR0k3yF5T5L9xvTNkHyL5Asp9Goh+QzJd0nuJHl1wv2+G/0st5NcTbIh5q//JMluktvHfGwGyfUk34/+b0243w+in+fbJH9KsiXJfmMuu5ekkWyLq18xVQ0BkhkA/wzgawAuBvBNkhcn2HIEwL1mdjGAqwD8VcL9TroHwM4U+gDAIwDWmtnvArg8yb4k5wG4G8BCM7sUQAbAbTG3WQFgybiP3Q9gg5ldAGBD9H6S/dYDuNTMLgOwC8ADCfcDyQ4A1wHYH2Ovoqp9JHAlgN1mtsfMhgA8DWBpUs3M7LCZbY7e7kPhF2ReUv0AgGQ7gBsBPJFkn6jXdABfArAcAMxsyMyOJdw2C6CRZBZAE4BDcX5xM/s1gE/HfXgpgJXR2ysB3JpkPzNbZ2Yj0buvA2hPsl/kxwDuA5D42XzVDoF5AA6Meb8LCf9SnkSyE8ACAG8k3OphFK7M0YT7AMB8AD0AfhLd/XiC5JSkmpnZQQA/ROGv1WEAvWa2Lql+Y8wys8PR2x8BmJVCz5O+DeAXSTYguRTAQTPbmmSfk6odAlVBciqAZwF8x8yOJ9jnJgDdZrYpqR7jZAFcAeBRM1sAoB/xHiqfIrovvhSF8JkLYArJ25PqV4wVzntP5dx3kg+hcJdyVYI9mgA8COBvk+oxXrVD4CCAjjHvt0cfSwzJWhQCYJWZrUmyF4BFAG4huQ+FuzpfIflUgv26AHSZ2cmjm2dQCIWkfBXAXjPrMbNhAGsAXJNgv5M+JjkHAKL/u5NuSPIOADcB+JYlu+HmCyiE6tbodtMOYDPJ2Uk1rHYI/DeAC0jOJ1mHwoNKzyfVjCRRuL+808x+lFSfk8zsATNrN7NOFL63X5lZYn8pzewjAAdIXhh9aDGAHUn1Q+FuwFUkm6Kf7WKk8wDo8wCWRW8vA/CzJJuRXILCXbpbzGwgyV5mts3MZppZZ3S76QJwRXTdJta0qv8A3IDCI64fAHgo4V7XonDo+DaALdG/G1L6Pr8M4IUU+nwRwMboe3wOQGvC/f4ewLsAtgP4VwD1MX/91Sg83jAc/ULcCeAsFJ4VeB/ALwHMSLjfbhQeuzp5m3ksyX7jLt8HoC3J65BRIxEJVLXvDohIlSkERAKnEBAJnEJAJHAKAZHATZoQIHmX+qnfZOsVQr9JEwIAUv3G1e+M7vd5/t5S7zeZQkBEqiDVk4XaZmSss6O26GU9R/I4+6xM0cve
2+t7TQXLsuRlw7l+1NYX32CXOTFS9OOnM9SSLXnZyEA/sk0l+uVc7ZDpHyq9lvwJ1GUai1422lj8Ojidmly+dL+RAdRlm4peZhnn35qa4tff0FA/6uom2BzpvE2P1hVf5/BgP2obSvfjiPN3qMTNc3ioH7UTfH9W4ucykVz/pxjO9RctLH2rTUBnRy3efLnj9J84zuLb73T1G2zz3dinb/3EVffh12e66lp2l/7lmrDuTd/W/f5LfHtRmj446qrLtxYPh9PWNfhunsz7fin759W76uqP+v5oWKb8X2YAGJ5a/I/lRLate7jkZbo7IBK4ikIgzdcHFJFkuEOgCq8PKCIJqORIINXXBxSRZFQSAlV7fUARiU/iDwySvIvkRpIbe474HgUXkeRUEgKf6fUBzexxM1toZgtLnQcgItVTSQik+vqAIpIM98lCZjZC8q8BvIzC5Jknzeyd2FYmIqmo6IxBM3sJwEsxrUVEqkBnDIoELtW9A+/tbXPtA9jw1HJXvyXnLHTV2cXnu+rOfa7HVcfcsKvuwDfK34cBAGdv9e1Y6r3ct5ErM+SbwJYd8NXl631/26buP+Gqy3b7hlgdvn6Oq66ur/y9ETbBj0RHAiKBUwiIBE4hIBI4hYBI4BQCIoFTCIgETiEgEjiFgEjgFAIigVMIiAROISASOIWASOAUAiKBS3UXoWXpmgrk3Q24dv9GV90NXz7XVce+flfd6MxWV13H6j2uut5Fvu/PuxtwcLrvZeWG5/punlO6fa9lWbvfN3nq+JXtrrq5a3zXn7U0l12T7S89JUlHAiKBUwiIBE4hIBK4SsaQdZB8heQOku+QvCfOhYlIOip5YHAEwL1mtplkM4BNJNeb2Y6Y1iYiKXAfCZjZYTPbHL3dB2AnNIZM5IwTy2MCJDsBLADwRhxfT0TSU3EIkJwK4FkA3zGz//eyq2NnEQ7nfM+ji0hyKgoBkrUoBMAqM1tT7HPGziKsrZ9SSTsRSUAlzw4QwHIAO83sR/EtSUTSVMmRwCIAfwbgKyS3RP9uiGldIpKSSgaS/gYAY1yLiFSBzhgUCRzNyp9r5jW9cY5d3XlH2XVW7ztg4aBvxt9Lrz7rqlty87dcdVbr3GU3tfwdmQBw7Pw6X79pvgO/hk98t7HsCV/d4Fm+v22ZnK9f25bfuuqGm33XQ76x/O/vrf/4R/Qd6yp6BepIQCRwCgGRwCkERAKnEBAJnEJAJHAKAZHAKQREAqcQEAmcQkAkcAoBkcApBEQCpxAQCZxCQCRwqc4iHGrJ4sOvzyy77tznelz9vLMBvbsB1/58latu0T1/4aqbcijnqmvu8s3q46hvl13Tvl5XXf95La662hO+mYl9ztmHQ631vrppvt2jTYcGy67hSOnrTkcCIoFTCIgETiEgErg45g5kSL5F8oU4FiQi6YrjSOAeFEaQicgZqNLhI+0AbgTwRDzLEZG0VXok8DCA+wD4npMRkaqrZALRTQC6zWzTaT7vf2cRjgxoFqHIZFPpBKJbSO4D8DQKk4ieGv9JY2cRZps0i1BksnGHgJk9YGbtZtYJ4DYAvzKz22NbmYikQucJiAQulr0DZvYqgFfj+Foiki4dCYgELtVdhJkc0LK7/B1szPlmCo7ObHXVeWcDencD/ucj/+Kqu37uF111o3/8+6467xDq/LQGV139Ed8uycFZvl19mWHfLslj5/lmQuZafT/P0Wxj+TU7S/+915GASOAUAiKBUwiIBE4hIBI4hYBI4BQCIoFTCIgETiEgEjiFgEjgFAIigVMIiAROISASOIWASODS3UXYP4SWNw+VXXfgGx2ufh2r97jqcpfMc9V5ZwN6dwO+fGiLq+66P/X1yzf4dlfmG3w3s7rDx1119Rnf7rwjFzW56jqfPuiqs6PHXHUjl8wvu6ZmqPRrAetIQCRwCgGRwCkERAJX6QSiFpLPkHyX5E6SV8e1MBFJR6UPDD4CYK2Z/QnJOgC+R1ZEpGrcIUByOoAvAbgD
AMxsCMBQPMsSkbRUcndgPoAeAD+JRpM/QVIjhkTOMJWEQBbAFQAeNbMFAPoB3D/+k8bOIhzKn6ignYgkoZIQ6ALQZWZvRO8/g0IonGLsLMK6TPkvlSwiyapkFuFHAA6QvDD60GIAO2JZlYikptJnB/4GwKromYE9AP688iWJSJoqCgEz2wJgYUxrEZEq0BmDIoGjmW/+mkdzS7stuPbususyudI7oCaSa/Ud6Ay0+bKxuav8OYsAMFrr2/XW2O3btbju31e46m78vSWuOu9MyHxTnasuM+A7XeXTy1pcdaO+UYSY+Ur5O2oBYLS5/AfYX9+1HL0Dh4re0HQkIBI4hYBI4BQCIoFTCIgETiEgEjiFgEjgFAIigVMIiAROISASOIWASOAUAiKBUwiIBE4hIBK4VGcR1uTyaPrgaNl1vZe3ufplJpi/NpHhab5dfRz17sj09fPOBvTuBnxx01pX3eO9c111R0d8r1s7NTPoqntsxc2uunkbel11RxbNcdU1Hyh/96hNMJ9RRwIigVMIiAROISASuEpnEX6X5Dskt5NcTbIhroWJSDrcIUByHoC7ASw0s0sBZADcFtfCRCQdld4dyAJoJJlFYRip70XTRKRqKhk+chDADwHsB3AYQK+ZrYtrYSKSjkruDrQCWIrCYNK5AKaQvL3I5/3fLMKRAf9KRSQRldwd+CqAvWbWY2bDANYAuGb8J50yizDbVEE7EUlCJSGwH8BVJJtIEoVZhDvjWZaIpKWSxwTeQGES8WYA26Kv9XhM6xKRlFQ6i/B7AL4X01pEpAp0xqBI4FLdRWiZGuRby39w0LsbcHC6b5ddwye+3YBN+3y7yfLTfCda5ht8V593NqB3N+Bd032nj3j7zcwed9XNfs337NVove96mLbXt9uxZnC4/KIJdrjqSEAkcAoBkcApBEQCpxAQCZxCQCRwCgGRwCkERAKnEBAJnEJAJHAKAZHAKQREAqcQEAmcQkAkcKnuIkQNXTvfsgPOmYJzfd9ew1Ffv/7zWlx19UfKny0HAHWHfbvlRmb4Zvx5ZwOmvfvwwY8vc9X1nt/oqpuxzXc95Np8/Rp7+squYV67CEWkBIWASOAUAiKBO20IkHySZDfJ7WM+NoPkepLvR//7XqpGRKrusxwJrACwZNzH7gewwcwuALAhel9EzkCnDQEz+zWAT8d9eCmAldHbKwHcGvO6RCQl3scEZpnZ4ejtjwDMimk9IpKyih8YNDMDUPJJyFNmEQ71V9pORGLmDYGPSc4BgOj/7lKfeMoswjrfySYikhxvCDwPYFn09jIAP4tnOSKSts/yFOFqAP8F4EKSXSTvBPB9AH9E8n0UphN/P9llikhSTntyvZl9s8RFi2Nei4hUgc4YFAlcursIzSbczVRKvt6XVVO68666/tm+GYa1J5wzE2fVu+rqM3TVZft8uxanZnyz87yzAb27Af9h1tuuupfrFrnqPLdpAKjtc8wUBJB37AK1rtK/QzoSEAmcQkAkcAoBkcApBEQCpxAQCZxCQCRwCgGRwCkERAKnEBAJnEJAJHAKAZHAKQREAqcQEAlcqrsIR+tq0D+v/B1zU/efcPWr3f+Jq27wa+e46vqcsw8zw75daEcuanLVTe3yzcB7bMXNrrrZrw246ryzAb27ATf93aOuuhuvucVVN9A+21Xn2bVoE+w41ZGASOAUAiKBUwiIBM47i/AHJN8l+TbJn5JsSXaZIpIU7yzC9QAuNbPLAOwC8EDM6xKRlLhmEZrZOjMbid59HUB7AmsTkRTE8ZjAtwH8otSFY8eQDQ9qDJnIZFNRCJB8CMAIgFWlPmfsGLLaBo0hE5ls3CcLkbwDwE0AFkdDSUXkDOQKAZJLANwH4A/MzHc6mIhMCt5ZhP8EoBnAepJbSD6W8DpFJCHeWYTLE1iLiFSBzhgUCVyquwg5Yqg/OnL6Txwn2+2bZXf8St/pC21bfuuqG2r1zRQ8dl6tq67z6YOuuu4/nOuqm7eh11U3
Wu+7mc3Y5rvevbMBvbsBX3zteV+/a2911eXOmVF2DSf4kehIQCRwCgGRwCkERAKnEBAJnEJAJHAKAZHAKQREAqcQEAmcQkAkcAoBkcApBEQCpxAQCZxCQCRwqe4iBCeeiVbK4evnuNrNXbPHVTd40TxX3dC0jKsu11r+zwQA7OgxV93MV1xlOLLIdz1M2zvoqsu1+WYR1vYNu+q8swG9uwFf/M1zrrrfWfmXZdcM7dIsQhEpQSEgEjjXGLIxl91L0ki2JbM8EUmadwwZSHYAuA7A/pjXJCIpco0hi/wYhZcd18wBkTOY6zEBkksBHDSzrTGvR0RSVvZThCSbADyIwl2Bz/L5dwG4CwDqGzXBXGSy8RwJfAHAfABbSe5DYSLxZpJFn2Q9ZRZhnWYRikw2ZR8JmNk2ADNPvh8FwUIz+yTGdYlISrxjyETkc8I7hmzs5Z2xrUZEUqczBkUCpxAQCVyquwithhieWv5Ou7o+3/lI1tLsqss3+rKx6ZBvt9xo1rdbbuSS+a66TJ9vnc0Hcq66mkHfrr7Gnj5XXX6G71ko7wxDz2xAwLcbEAB2LXu07Jor/62n5GU6EhAJnEJAJHAKAZHAKQREAqcQEAmcQkAkcAoBkcApBEQCpxAQCZxCQCRwCgGRwCkERAKnEBAJHM3Se8Vwkj0APixxcRuANF+iTP3O3H6f5+8tqX7nmtnZxS5INQQmQnKjmS1UP/WbTL1C6Ke7AyKBUwiIBG4yhcDj6qd+k7DX577fpHlMQESqYzIdCYhIFSgERAKnEBAJnEJAJHAKAZHA/Q+xnHtpDs71SgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASFElEQVR4nO3de2zd9XnH8c8T23Fw7s6NmoQEEIERxDWiUBBDhVKWRKR/jAk0ptAiMZWtpRUShaKt3V/r1Kql2gYd4qo2A22UDkQpJNBWCDbSJWkC5AJmIQshTuyGJA1OHN+e/XFONsfzcTjP+f1+x+H7fklRbJ/z+Pmeiz/+nePzPY+5uwCka1y9FwCgvggBIHGEAJA4QgBIHCEAJI4QABJX9xAws+vM7G0ze9fM7s651zwz+5WZbTazTWZ2R579hvRtMLPfmtlzBfSaZmZPmdlWM9tiZpfl3O/r5evyLTN7wswmZPz9HzGzTjN7a8jXWs1stZm1l/+fnnO/75avzzfM7GdmNi3PfkNOu9PM3MxmZtVvJHUNATNrkPSPkv5I0jmSbjKzc3Js2S/pTnc/R9Klkv4i535H3SFpSwF9JOmHkl5w97MlnZ9nXzM7RdJXJS1293MlNUi6MeM2j0m6btjX7pb0srufKenl8ud59lst6Vx3P0/SO5LuybmfzGyepGsl7ciw14jqfSRwiaR33X2bu/dKelLS8ryauXuHu68vf3xQpR+QU/LqJ0lmNlfSUkkP5dmn3GuqpCslPSxJ7t7r7vtzbtso6SQza5TUImlXlt/c3V+R9OGwLy+X9Hj548clfSHPfu6+yt37y5++Lmlunv3KfiDpLkm5v5qv3iFwiqT3h3y+Uzn/UB5lZgskXShpTc6t7lPpxhzMuY8knSapS9Kj5YcfD5nZxLyaufsHkr6n0m+rDkkH3H1VXv2GmOPuHeWPd0uaU0DPo74k6Rd5NjCz5ZI+cPeNefY5qt4hUBdmNknSTyV9zd1/n2OfZZI63X1dXj2GaZR0kaQH3P1CSd3K9lD5GOXH4stVCp82SRPN7Oa8+o3ES697L+S172Z2r0oPKVfm2KNF0jcl/XVePYardwh8IGnekM/nlr+WGzNrUikAVrr703n2knS5pOvNbLtKD3U+a2Y/ybHfTkk73f3o0c1TKoVCXq6R9J67d7l7n6SnJX0mx35H7TGzT0lS+f/OvBua2S2Slkn6U893w80ZKoXqxvL9Zq6k9WZ2cl4N6x0C/ynpTDM7zczGq/Sk0rN5NTMzU+nx8hZ3/35efY5y93vcfa67L1Dpsv3S3XP7TenuuyW9b2Znlb90taTNefVT6WHApWbWUr5ur1YxT4A+K2lF+eMVkp7Js5mZXafSQ7rr3f1Qnr3c/U13n+3uC8r3m52SLirftrk1res/SUtUesb1vyTdm3OvK1Q6dHxD0obyvyUFXc6rJD1XQJ8LJK0tX8Z/kzQ9535/I2mrpLck/VhSc8bf/wmVnm/oK/9A3Cpphkp/FWiX9JKk1pz7vavSc1dH7zM/yrPfsNO3S5qZ521o5UYAElXvhwMA6owQABJHCACJIwSAxBECQOLGTAiY2W30o99Y65VCvzETApIKveD0O6H7fZIvW+H9xlIIAKiDQl8sNLO1wRfMaxrxtK69A5o1o2HE097ZNiPUb6C5csb193SrcUK2G+zcKp82Wj9vjPVr7K582/X1daupaeR+g+NHWegoGg5X3gjZ19+tpsaR+w1MiP2usQoXr+/IR2pqnlSxbjB4fTb0jNxwtOtSkmwgtkG0b/LI9/eBQ91qaBnlvhm4+foOfKj+Q90jVgavrpgF85r0mxfnHf+Mw3zupi+G+h04L
fgmN8FcHGiO1fXMjP1QnrzmSKju4NzxobrWzR+F6vYvjIVtQ2/shjg0e+QfruNp3Rq7Ppv294TqOq6cGqqL3M/ee7TyVhkeDgCJqykEinx/QAD5CIdAHd4fEEAOajkSKPT9AQHko5YQqNv7AwLITu5PDJrZbWa21szWdu0dyLsdgCrVEgIf6/0B3f1Bd1/s7osrvQ4AQP3UEgKFvj8ggHyEXyzk7v1m9peSXlRp8swj7r4ps5UBKERNrxh09+clPZ/RWgDUAa8YBBJX6N6Bd7bNCO0DWP3Eo6F+Sy/+f3MeP5bui6rf3yBJPdNiT3zOeTL2KGrPTYtCdS2dsQ0vHZdPDtUNBLdwNByO1bV0xS5f1wWxzR/T347d7t1tsb0RJ+0J7DUZpRVHAkDiCAEgcYQAkDhCAEgcIQAkjhAAEkcIAIkjBIDEEQJA4ggBIHGEAJA4QgBIHCEAJK7QXYQDzeNCU4GiuwF/vu6FUN0VX/nzUN3Ula+H6g7c8OlQ3bT23lDd3kWx3XKzNsQm7XReHNtGOGNTbCLQYFPsd9v0Ve2hur5Fp4bq2l4LlWnyul1V1+zcV/m+wpEAkDhCAEgcIQAkrpYxZPPM7FdmttnMNpnZHVkuDEAxanlisF/Sne6+3swmS1pnZqvdfXNGawNQgPCRgLt3uPv68scHJW0RY8iAE04mzwmY2QJJF0pak8X3A1CcmkPAzCZJ+qmkr7n770c4/X9nEfb3dNfaDkDGagoBM2tSKQBWuvvTI51n6CzCxgkTa2kHIAe1/HXAJD0saYu7fz+7JQEoUi1HApdL+jNJnzWzDeV/SzJaF4CC1DKQ9FVJgVEoAMYSXjEIJK7QXYSSRp2JVkl0NmB0N+Crf/9PobqLZn85VBedndffEpuBd6Q1VKa958R2A07dNhCq624bH6qzgdiMv8NLFobq9p0dKlPz/tiB9IH51e9a7P3nytclRwJA4ggBIHGEAJA4QgBIHCEAJI4QABJHCACJIwSAxBECQOIIASBxhACQOEIASBwhACSu0F2EbtJAYAxez7TYbrnobMDobsD1f/VAqO78v7s9VDcutjlPU96L7bLrmRHb9TYYvJf1tMb69U2K1fWcGZu1OGVtbHdl3+RQmVp2V3/7jXZf4UgASBwhACSOEAASl8XcgQYz+62ZPZfFggAUK4sjgTtUGkEG4ARU6/CRuZKWSnoom+UAKFqtRwL3SbpLUuydMgHUXS0TiJZJ6nT3dcc5H7MIgTGs1glE15vZdklPqjSJ6CfDz8QsQmBsC4eAu9/j7nPdfYGkGyX90t1vzmxlAArB6wSAxGWyd8Ddfy3p11l8LwDF4kgASFyxuwgbpZ6Z1e/wmvPkplC/Azd8OlQXnQ0Y3Q248Rv3h+quveGWUN3B+bFdb+P6orsIY3Uz3zwSqtu+tClUt/C+WD+Ni9XtPW9KqM4GA7tARynhSABIHCEAJI4QABJHCACJIwSAxBECQOIIASBxhACQOEIASBwhACSOEAASRwgAiSMEgMQVuouwsdt18prqd1ztuWlRqN+09t5QXX9LbPZhdDZgdDfgqn99LFR3yT2xWYuj7UQbzaHZwd81FtsN2LI71q99RWw44Pzn+0N1LV2xO0zzvurv1w29lXfGciQAJI4QABJHCACJq3UC0TQze8rMtprZFjO7LKuFAShGrU8M/lDSC+7+x2Y2XlJLBmsCUKBwCJjZVElXSrpFkty9V1Ls6XgAdVPLw4HTJHVJerQ8mvwhM2PEEHCCqSUEGiVdJOkBd79QUreku4efaegswr4+ZhECY00tIbBT0k53X1P+/CmVQuEYQ2cRNjVxoACMNbXMItwt6X0zO6v8paslbc5kVQAKU+tfB74iaWX5LwPbJH2x9iUBKFJNIeDuGyQtzmgtAOqAVwwCiSt0F+HgeNPBueOrrmvpjM0G3LuoOVR3pDVUpinvxbbZRWcDRncD/uZvHwjVLb18eaju8BkzQ3UWu9nV1B3bBdq8L1YXF
ZopKOngqdXfXwY2V/59z5EAkDhCAEgcIQAkjhAAEkcIAIkjBIDEEQJA4ggBIHGEAJA4QgBIHCEAJI4QABJHCACJK3QXYcPhQbVu/qjquo7LYzPiZm3oCdXtPSe2q69nhoXqxvXF6qKzAaO7AX/+2jOhult3XBGqe+N3baG6q9raQ3XPvHhpqG5CcPfhoVmx38G9U6q/vwyOsnmXIwEgcYQAkDhCAEhcrbMIv25mm8zsLTN7wsxiD6YB1E04BMzsFElflbTY3c+V1CDpxqwWBqAYtT4caJR0kpk1qjSMdFftSwJQpFqGj3wg6XuSdkjqkHTA3VdltTAAxajl4cB0SctVGkzaJmmimd08wvn+bxZhP7MIgbGmlocD10h6z9273L1P0tOSPjP8TMfMImxkFiEw1tQSAjskXWpmLWZmKs0i3JLNsgAUpZbnBNaoNIl4vaQ3y9/rwYzWBaAgtc4i/Jakb2W0FgB1wCsGgcQVuotwYMI47V9Y/ZODA8HXIXZeHCucum0gVDcYvDYHG2O7CA/NjmV4dDZgdDfgw6e+Gqr7fPeyUF33QGwG5ZRtoTJN2h77q9eB06eE6lrf7q+6ZkdP5S2nHAkAiSMEgMQRAkDiCAEgcYQAkDhCAEgcIQAkjhAAEkcIAIkjBIDEEQJA4ggBIHGEAJC4QncRmksNvdUP0Gs4HOs3Y9ORUF132yiD20bR0xrbDTjzzdg6ZU2xssFYu+hswOhuwBf/4LlQ3emrbg3V+YWx3aPWH5uVOXtd7HbvnRr5sa183+RIAEgcIQAkjhAAEnfcEDCzR8ys08zeGvK1VjNbbWbt5f+n57tMAHn5OEcCj0m6btjX7pb0srufKenl8ucATkDHDQF3f0XSh8O+vFzS4+WPH5f0hYzXBaAg0ecE5rh7R/nj3ZLmZLQeAAWr+YlBd3dJFf/4f8wswiMf1doOQMaiIbDHzD4lSeX/Oyud8ZhZhM2Tgu0A5CUaAs9KWlH+eIWkZ7JZDoCifZw/ET4h6T8knWVmO83sVknfkfQ5M2tXaTrxd/JdJoC8HPdFyO5+U4WTrs54LQDqgFcMAokrdBfhYKN0aHZD1XUtXbFtb4NNsYyzgep3OkpS36TYLsLtS2O7AVt2xy5fU3f1t4EkXdXWHqqLzgaM7gbcdu3DobqLv/3lUN3sl94P1XUsmxeqi+xW7f/3yqdxJAAkjhAAEkcIAIkjBIDEEQJA4ggBIHGEAJA4QgBIHCEAJI4QABJHCACJIwSAxBECQOIK3UXY0ONq3Vr9/LWuC2K70Kaviu16O7xkYaiu58yeUN3C+2Iz6dpXxGbgNe+L7SJ85sVLQ3VTtoXKwrMBo7sB1337gVDd+SfdHqqLzOWUpIm7AvM8eyufxpEAkDhCAEgcIQAkLjqL8LtmttXM3jCzn5nZtHyXCSAv0VmEqyWd6+7nSXpH0j0ZrwtAQUKzCN19lbv3lz99XdLcHNYGoABZPCfwJUm/qHTiMWPI+rozaAcgSzWFgJndK6lf0spK5zlmDFnTxFraAchB+MVCZnaLpGWSri4PJQVwAgqFgJldJ+kuSX/o7oeyXRKAIkVnEf6DpMmSVpvZBjP7Uc7rBJCT6CzC2IgXAGMOrxgEElfoLkIbGFTT/up32k1/O7brrW/RqaG6fWeHyjRl7YRY4bjYLsL5z/cf/0wZmhDcfThpe+xPw9Yf2yUZnQ0Y3Q248Rv3h+qWXPMnobrOy2ZUXeOjjC/kSABIHCEAJI4QABJHCACJIwSAxBECQOIIASBxhACQOEIASBwhACSOEAASRwgAiSMEgMQVuouwb3KDOq6cWnVdd1vs3cvaXguVqXn/KFuuRtEX2/SmvedNCdW1dMVm9dlg7Po8NCv2O+PA6bHLN3tdbHdlx7J5obrobMDobsDnX/qXUN2Vt99Wdc24/sqXjSMBIHGEAJC40BiyIafdaWZuZjPzWR6AvEXHkMnM5km6VtKOjNcEoEChMWRlP1DpbceZOQCcwELPC
ZjZckkfuPvGjNcDoGBV/4nQzFokfVOlhwIf5/y3SbpNkpomT6+2HYCcRY4EzpB0mqSNZrZdpYnE683s5JHOPHQWYUMLswiBsabqIwF3f1PS7KOfl4Ngsbv/LsN1AShIdAwZgE+I6BiyoacvyGw1AArHKwaBxBECQOIK3UUokwaaqy87aU9sV9/kdbtCdQfmx2YYtuyOvW4ququveV9vqO7gqbGZib1TYrdD69uxmYm9U2N3z57W2Don7ordDpHZgFJsN6AkvXL/g1XXXPL5ys/bcyQAJI4QABJHCACJIwSAxBECQOIIASBxhACQOEIASBwhACSOEAASRwgAiSMEgMQRAkDizL24dww3sy5J/13h5JmSinyLMvqduP0+yZctr37z3X3WSCcUGgKjMbO17r6YfvQbS71S6MfDASBxhACQuLEUAtW/XQr9Uu33Sb5shfcbM88JAKiPsXQkAKAOCAEgcYQAkDhCAEgcIQAk7n8AHHtTv1Gh3LcAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from brainiak.matnormal.mnrsa import MNRSA\n", + "from brainiak.matnormal.covs import CovIdentity\n", + "from sklearn.linear_model import LinearRegression\n", + "\n", + "# beta_series RSA\n", + "model_linreg = LinearRegression(fit_intercept=False)\n", + "model_linreg.fit(true_X, true_Y)\n", + "beta_series = model_linreg.coef_\n", + "naive_RSA = np.corrcoef(beta_series.T)\n", + "\n", + "# MN-RSA\n", + "space_cov = CovDiagonal(size=n_V)\n", + "time_cov = CovAR1(size=n_T)\n", + "\n", + "model_matnorm = MNRSA(time_cov=time_cov, space_cov=space_cov, n_nureg=3)\n", + "\n", + "model_matnorm.fit(true_Y, true_X)\n", + "\n", + "# very similar on this toy data but as we show in the paper can be very different\n", + "# in other examples\n", + "plt.matshow(model_matnorm.C_)\n", + "plt.matshow(naive_RSA)" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python [conda env:brainiak] *", + "language": "python", + "name": "conda-env-brainiak-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/requirements-dev.txt b/requirements-dev.txt index bf2baa3eb..168c864c4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ --e . # Installing BrainIAK here ensures no requirement conflict. +-e .[matnormal] # Installing BrainIAK here ensures no requirement conflict. 
coverage flake8 flake8-print diff --git a/setup.py b/setup.py index 8627370a4..362bf61ee 100644 --- a/setup.py +++ b/setup.py @@ -142,6 +142,12 @@ def finalize_options(self): 'wheel', # See https://github.com/astropy/astropy-helpers/issues/501 'pydicom', ], + extras_require={ + 'matnormal': [ + 'tensorflow', + 'tensorflow_probability', + ], + }, author='Princeton Neuroscience Institute and Intel Corporation', author_email='mihai.capota@intel.com', url='http://brainiak.org', diff --git a/tests/conftest.py b/tests/conftest.py index 8b9888bad..de5ec6b56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,16 @@ from mpi4py import MPI +import pytest +import numpy +import random +import tensorflow def pytest_configure(config): config.option.xmlpath = "junit-{}.xml".format(MPI.COMM_WORLD.Get_rank()) + + +@pytest.fixture +def seeded_rng(): + random.seed(0) + numpy.random.seed(0) + tensorflow.random.set_seed(0) diff --git a/tests/matnormal/test_cov.py b/tests/matnormal/test_cov.py new file mode 100644 index 000000000..fd1c11ca0 --- /dev/null +++ b/tests/matnormal/test_cov.py @@ -0,0 +1,310 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose +from scipy.stats import norm, wishart, invgamma, invwishart + +import tensorflow as tf + +from brainiak.matnormal.covs import ( + CovIdentity, + CovAR1, + CovIsotropic, + CovDiagonal, + CovDiagonalGammaPrior, + CovUnconstrainedCholesky, + CovUnconstrainedCholeskyWishartReg, + CovUnconstrainedInvCholesky, + CovKroneckerFactored, +) + +# X is m x n, so A sould be m x p + +m = 8 +n = 4 +p = 3 + +rtol = 1e-7 +atol = 1e-7 + + +def logdet_sinv_np(X, sigma): + # logdet + sign, logdet = np.linalg.slogdet(sigma) + logdet_np = sign * logdet + # sigma-inv + sinv_np = np.linalg.inv(sigma) + # solve + sinvx_np = np.linalg.solve(sigma, X) + return logdet_np, sinv_np, sinvx_np + + +def logdet_sinv_np_mask(X, sigma, mask): + mask_indices = np.nonzero(mask)[0] + # logdet + _, logdet_np = 
np.linalg.slogdet(sigma[np.ix_(mask_indices, mask_indices)]) + # sigma-inv + sinv_np_ = np.linalg.inv(sigma[np.ix_(mask_indices, mask_indices)]) + # sigma-inverse * + sinvx_np_ = sinv_np_.dot(X[mask_indices, :]) + + sinv_np = np.zeros_like(sigma) + sinv_np[np.ix_(mask_indices, mask_indices)] = sinv_np_ + sinvx_np = np.zeros_like(X) + sinvx_np[mask_indices, :] = sinvx_np_ + + return logdet_np, sinv_np, sinvx_np + + +X = norm.rvs(size=(m, n)) +X_tf = tf.constant(X) +A = norm.rvs(size=(m, p)) +A_tf = tf.constant(A) +eye = tf.eye(m, dtype=tf.float64) + + +def test_CovConstant(seeded_rng): + + cov_np = wishart.rvs(df=m + 2, scale=np.eye(m)) + cov = CovUnconstrainedCholesky(Sigma=cov_np) + + # verify what we pass is what we get + cov_tf = cov._cov + assert_allclose(cov_tf, cov_np) + + # compute the naive version + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_CovIdentity(seeded_rng): + + cov = CovIdentity(size=m) + + # compute the naive version + cov_np = np.eye(m) + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_CovIsotropic(seeded_rng): + + cov = CovIsotropic(size=m) + + # compute the naive version + cov_np = cov._cov * np.eye(cov.size) + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + # test initialization + cov = CovIsotropic(var=0.123, size=3) + assert_allclose(np.exp(cov.log_var.numpy()), 0.123) + + +def test_CovDiagonal(seeded_rng): + + cov = CovDiagonal(size=m) + + # compute the naive version + cov_np = cov._cov + logdet_np, 
sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_CovDiagonal_initialized(seeded_rng): + + cov_np = np.diag(np.exp(np.random.normal(size=m))) + cov = CovDiagonal(size=m, diag_var=np.diag(cov_np)) + + # compute the naive version + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_CovDiagonalGammaPrior(seeded_rng): + + cov_np = np.diag(np.exp(np.random.normal(size=m))) + cov = CovDiagonalGammaPrior(size=m, sigma=np.diag(cov_np), alpha=1.5, + beta=1e-10) + + ig = invgamma(1.5, scale=1e-10) + + # compute the naive version + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + penalty_np = np.sum(ig.logpdf(1 / np.diag(cov_np))) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + assert_allclose(penalty_np, cov.logp, rtol=rtol) + + +def test_CovUnconstrainedCholesky(seeded_rng): + + cov = CovUnconstrainedCholesky(size=m) + + L = cov.L.numpy() + cov_np = L @ L.T + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_CovUnconstrainedCholeskyWishartReg(seeded_rng): + + cov = CovUnconstrainedCholeskyWishartReg(size=m) + + L = cov.L.numpy() + cov_np = L @ L.T + + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + # now compute the regularizer + reg = 
wishart.logpdf(cov_np, df=m + 2, scale=1e10 * np.eye(m)) + assert_allclose(reg, cov.logp, rtol=rtol) + + +def test_CovUnconstrainedInvCholesky(seeded_rng): + + init = invwishart.rvs(scale=np.eye(m), df=m + 2) + cov = CovUnconstrainedInvCholesky(invSigma=init) + + Linv = cov.Linv + L = np.linalg.inv(Linv) + cov_np = L @ L.T + + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_Cov2FactorKron(seeded_rng): + assert m % 2 == 0 + dim1 = int(m / 2) + dim2 = 2 + + with pytest.raises(TypeError) as excinfo: + cov = CovKroneckerFactored(sizes=dim1) + assert "sizes is not a list" in str(excinfo.value) + + cov = CovKroneckerFactored(sizes=[dim1, dim2]) + + L1 = (cov.L[0]).numpy() + L2 = (cov.L[1]).numpy() + cov_np = np.kron(np.dot(L1, L1.transpose()), np.dot(L2, L2.transpose())) + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_Cov3FactorKron(seeded_rng): + assert m % 4 == 0 + dim1 = int(m / 4) + dim2 = 2 + dim3 = 2 + cov = CovKroneckerFactored(sizes=[dim1, dim2, dim3]) + + L1 = (cov.L[0]).numpy() + L2 = (cov.L[1]).numpy() + L3 = (cov.L[2]).numpy() + cov_np = np.kron( + np.kron(np.dot(L1, L1.transpose()), np.dot(L2, L2.transpose())), + np.dot(L3, L3.transpose()), + ) + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinv_np, cov.solve(eye), rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_Cov3FactorMaskedKron(seeded_rng): + assert m % 4 == 0 + dim1 = int(m / 4) + dim2 = 2 + dim3 = 2 + + mask = np.random.binomial(1, 0.5, m).astype(np.int32) + + if sum(mask == 0): + mask[0] = 1 + mask_indices = 
np.nonzero(mask)[0] + + cov = CovKroneckerFactored(sizes=[dim1, dim2, dim3], mask=mask) + + L1 = (cov.L[0]).numpy() + L2 = (cov.L[1]).numpy() + L3 = (cov.L[2]).numpy() + cov_np_factor = np.kron(L1, np.kron(L2, L3))[ + np.ix_(mask_indices, mask_indices)] + cov_np = np.dot(cov_np_factor, cov_np_factor.transpose()) + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X[mask_indices, :], cov_np) + + assert_allclose(logdet_np, cov.logdet, rtol=rtol, atol=atol) + assert_allclose( + sinv_np, + cov.solve(eye).numpy()[np.ix_(mask_indices, mask_indices)], + rtol=rtol, + atol=atol, + ) + assert_allclose( + sinvx_np, cov.solve(X_tf).numpy()[ + mask_indices, :], rtol=rtol, atol=atol + ) + + +def test_CovAR1(seeded_rng): + + cov = CovAR1(size=m) + + cov_np = np.linalg.inv(cov.solve(eye)) + + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + # test initialization + cov = CovAR1(rho=0.3, sigma=1.3, size=3) + assert_allclose(np.exp(cov.log_sigma.numpy()), 1.3) + assert_allclose((2 * tf.sigmoid(cov.rho_unc) - 1).numpy(), 0.3) + + +def test_CovAR1_scan_onsets(seeded_rng): + + cov = CovAR1(size=m, scan_onsets=[0, m // 2]) + + # compute the naive version + cov_np = np.linalg.inv(cov.solve(eye)) + + logdet_np, sinv_np, sinvx_np = logdet_sinv_np(X, cov_np) + assert_allclose(logdet_np, cov.logdet, rtol=rtol) + assert_allclose(sinvx_np, cov.solve(X_tf), rtol=rtol) + + +def test_raises(seeded_rng): + + with pytest.raises(RuntimeError): + CovUnconstrainedCholesky(Sigma=np.eye(3), size=4) + + with pytest.raises(RuntimeError): + CovUnconstrainedCholesky() + + with pytest.raises(RuntimeError): + CovUnconstrainedInvCholesky(invSigma=np.eye(3), size=4) + + with pytest.raises(RuntimeError): + CovUnconstrainedInvCholesky() diff --git a/tests/matnormal/test_matnormal_logp.py b/tests/matnormal/test_matnormal_logp.py new file mode 100644 index 000000000..55c474a42 --- /dev/null +++ 
b/tests/matnormal/test_matnormal_logp.py @@ -0,0 +1,46 @@ +import numpy as np +from numpy.testing import assert_allclose +from scipy.stats import multivariate_normal +import tensorflow as tf + +from brainiak.matnormal.utils import rmn +from brainiak.matnormal.matnormal_likelihoods import matnorm_logp +from brainiak.matnormal.covs import CovIdentity, CovUnconstrainedCholesky + +# X is m x n, so A sould be m x p + +m = 5 +n = 4 +p = 3 + +rtol = 1e-7 + + +def test_against_scipy_mvn_row(seeded_rng): + + rowcov = CovUnconstrainedCholesky(size=m) + colcov = CovIdentity(size=n) + X = rmn(np.eye(m), np.eye(n)) + X_tf = tf.constant(X, "float64") + + rowcov_np = rowcov._cov + + scipy_answer = np.sum(multivariate_normal.logpdf( + X.T, np.zeros([m]), rowcov_np)) + tf_answer = matnorm_logp(X_tf, rowcov, colcov) + assert_allclose(scipy_answer, tf_answer, rtol=rtol) + + +def test_against_scipy_mvn_col(seeded_rng): + + rowcov = CovIdentity(size=m) + colcov = CovUnconstrainedCholesky(size=n) + X = rmn(np.eye(m), np.eye(n)) + X_tf = tf.constant(X, "float64") + + colcov_np = colcov._cov + + scipy_answer = np.sum(multivariate_normal.logpdf( + X, np.zeros([n]), colcov_np)) + tf_answer = matnorm_logp(X_tf, rowcov, colcov) + assert_allclose(scipy_answer, tf_answer, rtol=rtol) diff --git a/tests/matnormal/test_matnormal_logp_conditional.py b/tests/matnormal/test_matnormal_logp_conditional.py new file mode 100644 index 000000000..e85f34a0b --- /dev/null +++ b/tests/matnormal/test_matnormal_logp_conditional.py @@ -0,0 +1,78 @@ +import numpy as np +from numpy.testing import assert_allclose +from scipy.stats import wishart, multivariate_normal +import tensorflow as tf + +from brainiak.matnormal.utils import rmn +from brainiak.matnormal.matnormal_likelihoods import ( + matnorm_logp_conditional_col, + matnorm_logp_conditional_row, +) +from brainiak.matnormal.covs import CovIdentity, CovUnconstrainedCholesky + +# X is m x n, so A sould be m x p + +m = 5 +n = 4 +p = 3 + +rtol = 1e-7 + + +def 
test_against_scipy_mvn_row_conditional(seeded_rng): + + # have to be careful for constructing everything as a submatrix of a big + # PSD matrix, else no guarantee that anything's invertible. + cov_np = wishart.rvs(df=m + p + 2, scale=np.eye(m + p)) + + # rowcov = CovConstant(cov_np[0:m, 0:m]) + rowcov = CovUnconstrainedCholesky(Sigma=cov_np[0:m, 0:m]) + A = cov_np[0:m, m:] + + colcov = CovIdentity(size=n) + + Q = CovUnconstrainedCholesky(Sigma=cov_np[m:, m:]) + + X = rmn(np.eye(m), np.eye(n)) + + A_tf = tf.constant(A, "float64") + X_tf = tf.constant(X, "float64") + + Q_np = Q._cov + + rowcov_np = rowcov._cov - A.dot(np.linalg.inv(Q_np)).dot((A.T)) + + scipy_answer = np.sum(multivariate_normal.logpdf( + X.T, np.zeros([m]), rowcov_np)) + + tf_answer = matnorm_logp_conditional_row(X_tf, rowcov, colcov, A_tf, Q) + assert_allclose(scipy_answer, tf_answer, rtol=rtol) + + +def test_against_scipy_mvn_col_conditional(seeded_rng): + + # have to be careful for constructing everything as a submatrix of a big + # PSD matrix, else no guarantee that anything's invertible. 
+ cov_np = wishart.rvs(df=m + p + 2, scale=np.eye(m + p)) + + rowcov = CovIdentity(size=m) + colcov = CovUnconstrainedCholesky(Sigma=cov_np[0:n, 0:n]) + A = cov_np[n:, 0:n] + + Q = CovUnconstrainedCholesky(Sigma=cov_np[n:, n:]) + + X = rmn(np.eye(m), np.eye(n)) + + A_tf = tf.constant(A, "float64") + X_tf = tf.constant(X, "float64") + + Q_np = Q._cov + + colcov_np = colcov._cov - A.T.dot(np.linalg.inv(Q_np)).dot((A)) + + scipy_answer = np.sum(multivariate_normal.logpdf( + X, np.zeros([n]), colcov_np)) + + tf_answer = matnorm_logp_conditional_col(X_tf, rowcov, colcov, A_tf, Q) + + assert_allclose(scipy_answer, tf_answer, rtol=rtol) diff --git a/tests/matnormal/test_matnormal_logp_marginal.py b/tests/matnormal/test_matnormal_logp_marginal.py new file mode 100644 index 000000000..53ca2b67e --- /dev/null +++ b/tests/matnormal/test_matnormal_logp_marginal.py @@ -0,0 +1,66 @@ +import numpy as np +from numpy.testing import assert_allclose +from scipy.stats import multivariate_normal +import tensorflow as tf + +from brainiak.matnormal.utils import rmn +from brainiak.matnormal.matnormal_likelihoods import ( + matnorm_logp_marginal_col, + matnorm_logp_marginal_row, +) + +from brainiak.matnormal.covs import CovIdentity, CovUnconstrainedCholesky + +# X is m x n, so A sould be m x p + +m = 5 +n = 4 +p = 3 + +rtol = 1e-7 + + +def test_against_scipy_mvn_row_marginal(seeded_rng): + + rowcov = CovUnconstrainedCholesky(size=m) + colcov = CovIdentity(size=n) + Q = CovUnconstrainedCholesky(size=p) + + X = rmn(np.eye(m), np.eye(n)) + A = rmn(np.eye(m), np.eye(p)) + + A_tf = tf.constant(A, "float64") + X_tf = tf.constant(X, "float64") + + Q_np = Q._cov + + rowcov_np = rowcov._cov + A.dot(Q_np).dot(A.T) + + scipy_answer = np.sum(multivariate_normal.logpdf( + X.T, np.zeros([m]), rowcov_np)) + + tf_answer = matnorm_logp_marginal_row(X_tf, rowcov, colcov, A_tf, Q) + assert_allclose(scipy_answer, tf_answer, rtol=rtol) + + +def test_against_scipy_mvn_col_marginal(seeded_rng): + + rowcov = 
CovIdentity(size=m) + colcov = CovUnconstrainedCholesky(size=n) + Q = CovUnconstrainedCholesky(size=p) + + X = rmn(np.eye(m), np.eye(n)) + A = rmn(np.eye(p), np.eye(n)) + + A_tf = tf.constant(A, "float64") + X_tf = tf.constant(X, "float64") + + Q_np = Q._cov + + colcov_np = colcov._cov + A.T.dot(Q_np).dot(A) + + scipy_answer = np.sum(multivariate_normal.logpdf( + X, np.zeros([n]), colcov_np)) + + tf_answer = matnorm_logp_marginal_col(X_tf, rowcov, colcov, A_tf, Q) + assert_allclose(scipy_answer, tf_answer, rtol=rtol) diff --git a/tests/matnormal/test_matnormal_regression.py b/tests/matnormal/test_matnormal_regression.py new file mode 100644 index 000000000..556ee81e8 --- /dev/null +++ b/tests/matnormal/test_matnormal_regression.py @@ -0,0 +1,159 @@ +import pytest +import numpy as np +from scipy.stats import norm, wishart, pearsonr + +from brainiak.matnormal.covs import ( + CovIdentity, + CovUnconstrainedCholesky, + CovUnconstrainedInvCholesky, + CovDiagonal, +) +from brainiak.matnormal.regression import MatnormalRegression +from brainiak.matnormal.utils import rmn + +m = 100 +n = 4 +p = 5 + +corrtol = 0.8 # at least this much correlation between true and est to pass + + +def test_matnorm_regression_unconstrained(seeded_rng): + + # Y = XB + eps + # Y is m x p, B is n x p, eps is m x p + X = norm.rvs(size=(m, n)) + B = norm.rvs(size=(n, p)) + Y_hat = X.dot(B) + rowcov_true = np.eye(m) + colcov_true = wishart.rvs(p + 2, np.eye(p)) + + Y = Y_hat + rmn(rowcov_true, colcov_true) + + row_cov = CovIdentity(size=m) + col_cov = CovUnconstrainedCholesky(size=p) + + model = MatnormalRegression(time_cov=row_cov, space_cov=col_cov) + + model.fit(X, Y, naive_init=False) + + assert pearsonr(B.flatten(), model.beta_.flatten())[0] >= corrtol + + pred_y = model.predict(X) + assert pearsonr(pred_y.flatten(), Y_hat.flatten())[0] >= corrtol + + model = MatnormalRegression(time_cov=row_cov, space_cov=col_cov) + + model.fit(X, Y, naive_init=True) + + assert pearsonr(B.flatten(), 
model.beta_.flatten())[0] >= corrtol + + pred_y = model.predict(X) + assert pearsonr(pred_y.flatten(), Y_hat.flatten())[0] >= corrtol + + +def test_matnorm_regression_unconstrainedprec(seeded_rng): + + # Y = XB + eps + # Y is m x n, B is n x p, eps is m x p + X = norm.rvs(size=(m, n)) + B = norm.rvs(size=(n, p)) + Y_hat = X.dot(B) + rowcov_true = np.eye(m) + colcov_true = wishart.rvs(p + 2, np.eye(p)) + + Y = Y_hat + rmn(rowcov_true, colcov_true) + + row_cov = CovIdentity(size=m) + col_cov = CovUnconstrainedInvCholesky(size=p) + + model = MatnormalRegression(time_cov=row_cov, space_cov=col_cov) + + model.fit(X, Y, naive_init=False) + + assert pearsonr(B.flatten(), model.beta_.flatten())[0] >= corrtol + + pred_y = model.predict(X) + assert pearsonr(pred_y.flatten(), Y_hat.flatten())[0] >= corrtol + + +def test_matnorm_regression_optimizerChoice(seeded_rng): + + # Y = XB + eps + # Y is m x n, B is n x p, eps is m x p + X = norm.rvs(size=(m, n)) + B = norm.rvs(size=(n, p)) + Y_hat = X.dot(B) + rowcov_true = np.eye(m) + colcov_true = wishart.rvs(p + 2, np.eye(p)) + + Y = Y_hat + rmn(rowcov_true, colcov_true) + + row_cov = CovIdentity(size=m) + col_cov = CovUnconstrainedInvCholesky(size=p) + + model = MatnormalRegression(time_cov=row_cov, space_cov=col_cov, + optimizer="CG") + + model.fit(X, Y, naive_init=False) + + assert pearsonr(B.flatten(), model.beta_.flatten())[0] >= corrtol + + pred_y = model.predict(X) + assert pearsonr(pred_y.flatten(), Y_hat.flatten())[0] >= corrtol + + +def test_matnorm_regression_scaledDiag(seeded_rng): + + # Y = XB + eps + # Y is m x n, B is n x p, eps is m x p + X = norm.rvs(size=(m, n)) + B = norm.rvs(size=(n, p)) + Y_hat = X.dot(B) + + rowcov_true = np.eye(m) + colcov_true = np.diag(np.abs(norm.rvs(size=p))) + + Y = Y_hat + rmn(rowcov_true, colcov_true) + + row_cov = CovIdentity(size=m) + col_cov = CovDiagonal(size=p) + + model = MatnormalRegression(time_cov=row_cov, space_cov=col_cov) + + model.fit(X, Y, naive_init=False) + + assert 
pearsonr(B.flatten(), model.beta_.flatten())[0] >= corrtol + + pred_y = model.predict(X) + assert pearsonr(pred_y.flatten(), Y_hat.flatten())[0] >= corrtol + + # we only do calibration test on the scaled diag + # model because to hit corrtol on unconstrainedCov + # we'd need a lot more data, which would make the test slow + X_hat = model.calibrate(Y) + assert pearsonr(X_hat.flatten(), X.flatten())[0] >= corrtol + + +def test_matnorm_calibration_raises(seeded_rng): + + # Y = XB + eps + # Y is m x n, B is n x p, eps is m x p + X = norm.rvs(size=(2, 5)) + B = norm.rvs(size=(5, 3)) + Y_hat = X.dot(B) + + rowcov_true = np.eye(2) + colcov_true = np.diag(np.abs(norm.rvs(size=3))) + + Y = Y_hat + rmn(rowcov_true, colcov_true) + + row_cov = CovIdentity(size=2) + col_cov = CovDiagonal(size=3) + + model = MatnormalRegression(time_cov=row_cov, space_cov=col_cov) + + model.fit(X, Y, naive_init=False) + + with pytest.raises(RuntimeError): + model.calibrate(Y) diff --git a/tests/matnormal/test_matnormal_rsa.py b/tests/matnormal/test_matnormal_rsa.py new file mode 100644 index 000000000..f897c3c9e --- /dev/null +++ b/tests/matnormal/test_matnormal_rsa.py @@ -0,0 +1,76 @@ +from brainiak.matnormal.mnrsa import MNRSA +from brainiak.utils.utils import cov2corr +from brainiak.matnormal.covs import CovIdentity, CovDiagonal +from scipy.stats import norm +from brainiak.matnormal.utils import rmn +import numpy as np + + +def gen_U_nips2016_example(): + + n_C = 16 + U = np.zeros([n_C, n_C]) + U = np.eye(n_C) * 0.6 + U[8:12, 8:12] = 0.8 + for cond in range(8, 12): + U[cond, cond] = 1 + + return U + + +def gen_brsa_data_matnorm_model(U, n_T, n_V, space_cov, time_cov, n_nureg): + + n_C = U.shape[0] + beta = rmn(U, space_cov) + X = rmn(np.eye(n_T), np.eye(n_C)) + beta_0 = rmn(np.eye(n_nureg), space_cov) + X_0 = rmn(np.eye(n_T), np.eye(n_nureg)) + Y_hat = X.dot(beta) + X_0.dot(beta_0) + Y = Y_hat + rmn(time_cov, space_cov) + sizes = {"n_C": n_C, "n_T": n_T, "n_V": n_V} + train = {"beta": beta, 
"X": X, "Y": Y, "U": U, "X_0": X_0} + + return train, sizes + + +def test_brsa_rudimentary(seeded_rng): + """this test is super loose""" + + # this is Mingbo's synth example from the paper + U = gen_U_nips2016_example() + + n_T = 150 + n_V = 250 + n_nureg = 5 + + spacecov_true = np.eye(n_V) + + timecov_true = np.diag(np.abs(norm.rvs(size=(n_T)))) + + tr, sz = gen_brsa_data_matnorm_model( + U, + n_T=n_T, + n_V=n_V, + n_nureg=n_nureg, + space_cov=spacecov_true, + time_cov=timecov_true, + ) + + spacecov_model = CovIdentity(size=n_V) + timecov_model = CovDiagonal(size=n_T) + + model_matnorm = MNRSA(time_cov=timecov_model, space_cov=spacecov_model) + + model_matnorm.fit(tr["Y"], tr["X"], naive_init=False) + + RMSE = np.mean((model_matnorm.C_ - cov2corr(tr["U"])) ** 2) ** 0.5 + + assert RMSE < 0.1 + + model_matnorm = MNRSA(time_cov=timecov_model, space_cov=spacecov_model) + + model_matnorm.fit(tr["Y"], tr["X"], naive_init=True) + + RMSE = np.mean((model_matnorm.C_ - cov2corr(tr["U"])) ** 2) ** 0.5 + + assert RMSE < 0.1 diff --git a/tests/matnormal/test_matnormal_utils.py b/tests/matnormal/test_matnormal_utils.py new file mode 100644 index 000000000..ed1269410 --- /dev/null +++ b/tests/matnormal/test_matnormal_utils.py @@ -0,0 +1,28 @@ +from brainiak.matnormal.utils import (pack_trainable_vars, + unpack_trainable_vars, + flatten_cholesky_unique, + unflatten_cholesky_unique) +import tensorflow as tf +import numpy as np +import numpy.testing as npt + + +def test_pack_unpack(seeded_rng): + + shapes = [[2, 3], [3], [3, 4, 2], [1, 5]] + mats = [tf.random.stateless_normal( + shape=shape, seed=[0, 0]) for shape in shapes] + flatmats = pack_trainable_vars(mats) + unflatmats = unpack_trainable_vars(flatmats, mats) + for mat_in, mat_out in zip(mats, unflatmats): + assert tf.math.reduce_all(tf.equal(mat_in, mat_out)) + + +def test_cholesky_uncholesky(seeded_rng): + size = 3 + flat_chol_length = (size*(size+1))//2 + flatchol = np.random.normal(size=[flat_chol_length]) + unflatchol = 
unflatten_cholesky_unique(flatchol) + npt.assert_equal(unflatchol.shape, [3, 3]) + reflatchol = flatten_cholesky_unique(unflatchol) + npt.assert_allclose(flatchol, reflatchol)