101 commits
fb18dc2
initial commit of matrix normal base API , regression, and MNRSA
Mar 2, 2018
94dfcbb
add tensorflow to requirements
Mar 5, 2018
7264b3f
make the linter happy
Mar 12, 2018
aa54772
Merge pull request #1 from brainiak/master
narayanan2004 Mar 19, 2018
6e9fd68
Merge branch 'matnormal-regression-rsa' of https://github.com/mshvart…
Mar 19, 2018
34087a7
Fix style issues
Apr 2, 2018
4eb39f6
Merge branch 'master' into matnormal-regression-rsa
mshvartsman Apr 3, 2018
6ab49df
Merge pull request #1 from narayanan2004/matnormal-regression-rsa
mshvartsman Apr 3, 2018
347cdc8
Merge branch 'matnormal-regression-rsa' of github.com:mshvartsman/bra…
Apr 3, 2018
a75d866
more linter fixes
Apr 3, 2018
3dcf915
fix ambiguous varname
Apr 3, 2018
e9da333
linter fixes in tests
Apr 8, 2018
5ae24f1
broke this to make the linter happy, fixing
Apr 8, 2018
01aed20
more cleanup from hacky copypaste-squash
Apr 9, 2018
9ca2cef
More linter checks (for some reason run-checks.sh ignores /tests on m…
Apr 9, 2018
9ee9734
fixing sphinx complaints
Apr 9, 2018
94ff22c
original-style SRM
Apr 11, 2018
39fbc3b
Merge remote-tracking branch 'upstream/master' into matnormal-regress…
Apr 16, 2018
dfdc7d8
Merge branch 'matnormal-regression-rsa' into dpmnsrm
Apr 19, 2018
114721d
WIP dual probabilistic MN-SRM
Apr 20, 2018
e72998d
sync with upstream
mshvartsman Jul 18, 2019
56cc6c8
move kronecker solvers to their own file, utils.py was getting unwieldy
mshvartsman Jul 18, 2019
e5a2e94
initial refactor, all tests pass
mshvartsman Jan 1, 2020
bb57f74
remove CovTFWrap, use constant cholesky cov with passed Sigma instead
mshvartsman Jan 1, 2020
1ee89b3
linter and formatting fixes
mshvartsman Jan 1, 2020
835558a
merge upstream
mshvartsman Jan 2, 2020
231391b
add metaclass reference
mshvartsman Jan 8, 2020
d2bb285
further cleanup post refactor, doc changes, addressing minor comments
mshvartsman Jan 8, 2020
af2e0ac
Merge remote-tracking branch 'upstream/master' into matnormal-regress…
mshvartsman Mar 2, 2020
79092ee
linter and deprecation fixes, rename Matnorm to Matnormal in one spot
mshvartsman Mar 2, 2020
1fedd0c
fix missing comma
mshvartsman Mar 3, 2020
87eb10d
strict linter fixes
mshvartsman Mar 3, 2020
bc87a00
need old TF for things to work
mshvartsman Mar 3, 2020
05258b3
docstring formatting fix
mshvartsman Mar 3, 2020
40e7137
sync with master
mshvartsman Mar 4, 2020
2a075cc
merge in changes to base matnormal commit
mshvartsman Mar 4, 2020
1a88cbd
fix bad merge
mshvartsman Mar 4, 2020
06d68d4
wip cleanup
mshvartsman Jun 29, 2020
c362426
run tf1 -> tf2 conversion script
mshvartsman Jun 30, 2020
16bf442
run v1 -> v2 conversion on tests
mshvartsman Jun 30, 2020
e4f28d6
tf 1 -> 2 on solvers
mshvartsman Jun 30, 2020
dea4d10
test_cov passes
mshvartsman Jun 30, 2020
a5b2c7e
logp tests pass
mshvartsman Jun 30, 2020
d82a1b6
cov tests pass eager
mshvartsman Jul 27, 2020
b0eb0c6
matnorm test passes in eager
mshvartsman Jul 27, 2020
0ba7509
rest of likelihoods pass in eager
mshvartsman Jul 27, 2020
1ec4425
pack/unpack for fitting using scipy
mshvartsman Aug 5, 2020
d032b1a
do cholesky covs better now
mshvartsman Aug 5, 2020
ed63f6d
test for new packing/unpacking utils
mshvartsman Aug 5, 2020
1db0080
tests now pass eager mode
mshvartsman Aug 5, 2020
b134a05
pull out repeated code as a util
mshvartsman Aug 5, 2020
7ffbe91
Follow more standard sklearn API where est params come out as trailin…
mshvartsman Aug 5, 2020
9a6b4b7
mnrsa now works with eager, tests pass
mshvartsman Aug 5, 2020
b041460
final removal of tf print and session stuff
mshvartsman Aug 5, 2020
cee8c81
typo fixes
mshvartsman Aug 6, 2020
6e95082
likelihood docstrings
mshvartsman Aug 6, 2020
12eff1f
simplify make_val_and_grad
mshvartsman Aug 6, 2020
41eda2b
fix old tf1 stuff that would break in tf2
mshvartsman Aug 6, 2020
3f6dde5
fix for maintaining the graph correctly
mshvartsman Aug 6, 2020
429e329
More stringent tests by doing bad initialization
mshvartsman Aug 6, 2020
ad6ce7d
Update example with new way of doing things
mshvartsman Aug 6, 2020
a1fa69e
final typo fixes
mshvartsman Aug 6, 2020
dca4b51
Merge branch 'mnorm-eager' into matnormal-regression-rsa
mshvartsman Aug 6, 2020
ee57283
autoformat
mshvartsman Aug 6, 2020
0cf0a77
remove nb_black dependency
mshvartsman Aug 6, 2020
d3c0a5a
minor docstring cleanup
mshvartsman Aug 14, 2020
73d2945
fix the kron covs to work correctly with the new optimizer wrapper
mshvartsman Aug 14, 2020
98607ea
autoformat
mshvartsman Aug 14, 2020
4fcf557
correctly pass optimizer args
mshvartsman Aug 14, 2020
a1fe394
Make test linter happy
mshvartsman Aug 14, 2020
39646f7
Merge branch 'master' into matnormal-regression-rsa
mshvartsman Aug 14, 2020
495c656
maybe this will make travis use a recent TF?
mshvartsman Aug 14, 2020
c9a5688
workaround to be able to use pymanopt (for theano) in the presence of…
mshvartsman Aug 15, 2020
c599f92
Merge branch 'master' into matnormal-regression-rsa (pull in tf2.3 fi…
mshvartsman Aug 18, 2020
570f4a2
doc build fixes
mshvartsman Aug 22, 2020
5f04a82
doc cleanup and removal of unused functions
mshvartsman Aug 22, 2020
8a41b9d
fix linter issues introduced by fixing docbuild issues
mshvartsman Aug 22, 2020
b77ddf3
remove hard tf2.3 requirement (tensorflow_probability deps should res…
mshvartsman Aug 22, 2020
c0af0e6
add reproducible rng fixture, improve test coverage, fixup linear dec…
mshvartsman Aug 22, 2020
6250fc8
don't print debug on tests, improve cov
mshvartsman Aug 22, 2020
3dfec38
improve coverage
mshvartsman Aug 22, 2020
549bf2a
docstring cleanups
mshvartsman Aug 22, 2020
d0da34c
notation consistency fix for the example too
mshvartsman Aug 22, 2020
0022cc7
addressing @mihaic's comments
mshvartsman Aug 25, 2020
98d3220
Merge branch 'matnormal-regression-rsa' into mnsrm-tf2
mshvartsman Aug 28, 2020
32aabef
tf1->2 conversion script
mshvartsman Aug 28, 2020
3bbd540
wip pymanopt stuff
mshvartsman Sep 1, 2020
5aaf075
mn-srm ported to tf2
mshvartsman Dec 24, 2020
fee2c8d
consistent naming
mshvartsman Dec 24, 2020
2320c03
works now
mshvartsman Dec 26, 2020
3f40de1
notation consistency cleanups
mshvartsman Jan 1, 2021
9a62832
MNSRM-MargS with all options (some combos don't work and/or don't mak…
mshvartsman Jan 1, 2021
fb552a7
remove logging calls that would indirectly lead to explicit inverses
mshvartsman Jan 1, 2021
dcf026e
full complement of tests (incl nonidentifiable ones)
mshvartsman Jan 1, 2021
6e12707
remove srm-margs for now due to pymanopt dependency issue
mshvartsman Jan 1, 2021
1eb5e7d
rename
mshvartsman Jan 1, 2021
3a5d13a
remove ortho-s for initial version
mshvartsman Jan 28, 2021
dc476be
remove ortho-s and marg-s for now
mshvartsman Jan 28, 2021
ceb3b9f
remove prints
mshvartsman Jan 28, 2021
58a902d
correct size and noise
mshvartsman Jan 28, 2021
51860ac
Merge branch 'master' into feature-matnormal-dpsrm
mshvartsman Jan 28, 2021
30 changes: 30 additions & 0 deletions brainiak/matnormal/covs.py
@@ -620,3 +620,33 @@ def solve(self, X):
z = tf_solve_lower_triangular_masked_kron(self.L, X, self.mask)
x = tf_solve_upper_triangular_masked_kron(self.L, z, self.mask)
return x


class CovScaleMixin:
""" wraps a Cov, adds a scaler (e.g. for subject-specific variances)
"""
def __init__(self, base_cov, scale=1.0):
self._baseCov = base_cov
self._scale = scale

@property
def logdet(self):
""" log|Sigma|
"""
return self._baseCov.logdet + tf.math.log(self._scale) * self._baseCov.size

def solve(self, X):
"""Given this Sigma and some X, compute :math:`Sigma^{-1} * x`
"""
return self._baseCov.solve(X) / self._scale

def _cov(self):
"""return Sigma
"""
return self._baseCov._cov * self._scale

def _prec(self):
""" Sigma^{-1}. Override me with more efficient
implementation in subclasses
"""
return self._baseCov.Sigma_inv / self._scale
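
Reviewer note: below is a minimal usage sketch of the new CovScaleMixin, assuming the branch in this PR is installed. It relies only on the base-covariance interface used above (size, logdet, solve) and the CovIdentity constructor that appears in dpsrm.py further down; the shapes and scale value are illustrative, not part of the PR.

import numpy as np
import tensorflow as tf
from brainiak.matnormal.covs import CovIdentity, CovScaleMixin

# Wrap a 10-voxel identity covariance with a subject-specific scale.
# The scale is passed as a float64 tensor to match the covariance dtype.
base = CovIdentity(size=10)
scaled = CovScaleMixin(base_cov=base, scale=tf.constant(2.0, dtype=tf.float64))

X = tf.constant(np.random.normal(size=(10, 4)))
# solve() divides the base solve by the scale: (2 * I)^{-1} X = X / 2
print(scaled.solve(X))
# logdet adds size * log(scale) to the base logdet: 10 * log(2) here
print(scaled.logdet)
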
353 changes: 353 additions & 0 deletions brainiak/matnormal/dpsrm.py
@@ -0,0 +1,353 @@
import tensorflow as tf
from sklearn.base import BaseEstimator
from brainiak.matnormal.covs import (CovIdentity,
CovScaleMixin,
CovUnconstrainedCholesky)
import numpy as np
from brainiak.matnormal.matnormal_likelihoods import (
matnorm_logp_marginal_col)
from brainiak.matnormal.utils import pack_trainable_vars, make_val_and_grad
import logging
from scipy.optimize import minimize

logger = logging.getLogger(__name__)


def assert_monotonicity(fun, rtol=1e-3):
"""
Check that the loss is monotonically decreasing
after called function.
tol > 0 allows for some slop due to numerics
"""
def wrapper(classref, *args, **kwargs):
loss_before = classref.lossfn(None)
res = fun(classref, *args, **kwargs)
loss_after = classref.lossfn(None)
assert loss_after-loss_before <= abs(loss_before*rtol), f"loss increased on {fun}"
return res
return wrapper


class DPMNSRM(BaseEstimator):
"""Probabilistic SRM, aka SRM with marginalization over W (and optionally,
orthonormal S). In contrast to SRM (Chen et al. 2015), this estimates
far fewer parameters due to the W integral, and includes support for
arbitrary kronecker-structured residual covariance. Inference is
performed by ECM algorithm.
"""

def __init__(self, n_features=5, time_noise_cov=CovIdentity,
space_noise_cov=CovIdentity,
optMethod="L-BFGS-B", optCtrl={},
improvement_tol=1e-5, algorithm="ECME"):

self.k = n_features
# self.s_constraint = s_constraint
self.improvement_tol = improvement_tol
self.algorithm = algorithm
self.marg_cov_class = CovIdentity

if algorithm not in ["ECM", "ECME"]:
raise RuntimeError(
f"Unknown algorithm! Expected 'ECM' or 'ECME', got {algorithm}!")

self.time_noise_cov_class = time_noise_cov
self.space_noise_cov_class = space_noise_cov

self.optCtrl, self.optMethod = optCtrl, optMethod

def logp(self, X, S=None):
""" MatnormSRM marginal log-likelihood, integrating over W"""

if S is None:
S = self.S

subj_space_covs = [CovScaleMixin(base_cov=self.space_cov,
scale=1/self.rhoprec[j]) for j in range(self.n)]
return tf.reduce_sum(
input_tensor=[matnorm_logp_marginal_col(X[j],
row_cov=subj_space_covs[j],
col_cov=self.time_cov,
marg=S,
marg_cov=CovIdentity(size=self.k))
for j in range(self.n)], name="lik_logp")

def Q_fun(self, X, S=None):

if S is None:
S = self.S

# shorthands for readability
kpt = self.k + self.t
nv = self.n * self.v

mean = X - self.b - tf.matmul(self.w_prime,
tf.tile(tf.expand_dims(S, 0),
[self.n, 1, 1]))

# covs don't support batch ops (yet!) (TODO):
x_quad_form = -tf.linalg.trace(tf.reduce_sum(
input_tensor=[tf.matmul(self.time_cov.solve(
tf.transpose(a=mean[j])),
self.space_cov.solve(mean[j])) *
self.rhoprec[j]
for j in range(self.n)], axis=0))

w_quad_form = -tf.linalg.trace(tf.reduce_sum(
input_tensor=[tf.matmul(self.marg_cov.solve(
tf.transpose(a=self.w_prime[j])),
self.space_cov.solve(self.w_prime[j])) *
self.rhoprec[j]
for j in range(self.n)], axis=0))

s_quad_form = - \
tf.linalg.trace(tf.matmul(self.time_cov.solve(
tf.transpose(a=S)), S))
det_terms = -(nv+self.k) * self.time_cov.logdet -\
kpt*self.n*self.space_cov.logdet +\
kpt*self.v*tf.reduce_sum(input_tensor=tf.math.log(self.rhoprec)) -\
nv*self.marg_cov.logdet

trace_prod = -tf.reduce_sum(input_tensor=self.rhoprec / self.rhoprec_prime) *\
tf.linalg.trace(self.space_cov.solve(self.vcov_prime)) *\
(tf.linalg.trace(tf.matmul(self.wcov_prime, self.marg_cov._prec +
tf.matmul(S, self.time_cov.solve(
tf.transpose(a=S))))))

return 0.5 * (det_terms +
x_quad_form +
w_quad_form +
trace_prod +
s_quad_form)

@assert_monotonicity
def estep_margw(self, X):

wchol = tf.linalg.cholesky(tf.eye(self.k, dtype=tf.float64) +
tf.matmul(self.S, self.time_cov.solve(
tf.transpose(a=self.S))))

wcov_prime = tf.linalg.cholesky_solve(wchol, tf.eye(self.k, dtype=tf.float64))

stacked_rhs = tf.tile(tf.expand_dims(self.time_cov.solve(
tf.transpose(a=tf.linalg.cholesky_solve(wchol, self.S))), 0),
[self.n, 1, 1])

w_prime = tf.matmul(self.X-self.b, stacked_rhs)

# rhoprec doesn't change
# vcov doesn't change
self.w_prime.assign(w_prime, read_value=False)
self.wcov_prime.assign(wcov_prime, read_value=False)

@assert_monotonicity
def mstep_b_margw(self, X):
resids_transpose = [tf.transpose(X[j] - self.w_prime[j] @ self.S) for j in range(self.n)]
numerator = [tf.reduce_sum(tf.transpose(self.time_cov.solve(r)), axis=1) for r in resids_transpose]
denominator = tf.reduce_sum(self.time_cov._prec)

self.b.assign(tf.stack([n/denominator for n in numerator])[...,None], read_value=False)

@assert_monotonicity
def mstep_S(self, X):
wtw = tf.reduce_sum(
input_tensor=[tf.matmul(self.w_prime[j],
self.space_cov.solve(
self.w_prime[j]),
transpose_a=True) *
self.rhoprec[j] for j in range(self.n)], axis=0)

wtx = tf.reduce_sum(
input_tensor=[tf.matmul(self.w_prime[j],
self.space_cov.solve(
X[j]-self.b[j]),
transpose_a=True) *
self.rhoprec[j] for j in range(self.n)], axis=0)

self.S.assign(tf.linalg.solve(wtw + tf.reduce_sum(input_tensor=self.rhoprec_prime / self.rhoprec) *
tf.linalg.trace(self.space_cov.solve(self.vcov_prime)) *
self.wcov_prime + tf.eye(self.k, dtype=tf.float64), wtx), read_value=False)

@assert_monotonicity
def mstep_rhoprec_margw(self, X):

mean = X - self.b -\
tf.matmul(self.w_prime,
tf.tile(tf.expand_dims(self.S, 0),
[self.n, 1, 1]))

mean_trace = tf.stack(
[tf.linalg.trace(tf.matmul(self.time_cov.solve(
tf.transpose(a=mean[j])),
self.space_cov.solve(mean[j]))) for j in range(self.n)])

w_trace = tf.stack(
[tf.linalg.trace(tf.matmul(self.marg_cov.solve(
tf.transpose(a=self.w_prime[j])),
self.space_cov.solve(self.w_prime[j])))
for j in range(self.n)])

shared_term = (1/self.rhoprec_prime) *\
tf.linalg.trace(self.space_cov.solve(self.vcov_prime)) *\
(tf.linalg.trace(self.marg_cov.solve(self.wcov_prime)) +
tf.linalg.trace(self.S @ self.time_cov.solve(tf.transpose(self.S))))

rho_hat_unscaled = mean_trace + w_trace + shared_term

self.rhoprec.assign((self.v*(self.k+self.t)) / rho_hat_unscaled, read_value=False)

@assert_monotonicity
def mstep_covs(self):
for cov in [self.space_cov, self.time_cov, self.marg_cov]:
if len(cov.get_optimize_vars()) > 0:
val_and_grad = make_val_and_grad(
self.lossfn, cov.get_optimize_vars())

x0 = pack_trainable_vars(cov.get_optimize_vars())

opt_results = minimize(
fun=val_and_grad, x0=x0, jac=True, method=self.optMethod,
**self.optCtrl
)
assert opt_results.success, f"L-BFGS for covariances failed with message: {opt_results.message}"

def mstep_margw(self, X):
# closed form parts
self.mstep_b_margw(X)
self.mstep_rhoprec_margw(X)
self.mstep_S(X)

# L-BFGS for residual covs
self.mstep_covs()

def _init_vars(self, X, svd_init=False):
self.n = len(X)

self.v, self.t = X[0].shape

self.X = tf.constant(X, name="X")

if svd_init:
xinit = [np.linalg.svd(x) for x in X]
else:
xinit = [np.linalg.svd(np.random.normal(
size=(self.v, self.t))) for i in range(self.n)]

# parameters
self.b = tf.Variable(np.random.normal(size=(self.n, self.v, 1)),
name="b")
self.rhoprec = tf.Variable(np.ones(self.n), name="rhoprec")
self.space_cov = self.space_noise_cov_class(size=self.v)
self.time_cov = self.time_noise_cov_class(size=self.t)
self.marg_cov = self.marg_cov_class(size=self.k)
self.S = tf.Variable(np.average([s[2][:self.k, :] for s in xinit],0),
dtype=tf.float64, name="S")

# sufficient statistics
self.w_prime = tf.Variable(np.array([s[0][:, :self.k] for s in xinit]),
name="w_prime")
self.rhoprec_prime = tf.Variable(np.ones(self.n), name="rhoprec_prime")
self.wcov_prime = tf.Variable(np.eye(self.k), name="wcov_prime")
self.vcov_prime = tf.Variable(np.eye(self.v), name="vcov_prime")

def fit(self, X, max_iter=10, y=None, svd_init=False, rtol=1e-3, gtol=1e-7):
"""
find S marginalizing W

Parameters
----------
X: 2d array
Brain data matrix (voxels by TRs). Y in the math
n_iter: int, default=10
Max iterations to run
"""

# in case we get a list, and/or int16s or float32s
X = np.array(X).astype(np.float64)
self._init_vars(X, svd_init=svd_init)

if self.algorithm == "ECME":
self.lossfn = lambda theta: -self.logp(X)
loss_name = "-Marginal Lik"
elif self.algorithm == "ECM":
self.lossfn = lambda theta: -self.Q_fun(X)
loss_name = "-ELPD (Q)"


prevloss = self.lossfn(None)
converged = False
for em_iter in range(max_iter):

logger.info(f"Iter {em_iter}, {loss_name} at start {prevloss}")
# print(f"Iter {em_iter}, {loss_name} at start {q_start}")

# ESTEP
self.estep_margw(X)
currloss = self.lossfn(None)
logger.info(f"Iter {em_iter}, {loss_name} at estep end {currloss}")
            assert currloss - prevloss <= 0.1, f"{loss_name} increased in E-step!"
prevloss = currloss
# MSTEP
self.mstep_margw(X)

            currloss = self.lossfn(None)
            logger.info(f"Iter {em_iter}, {loss_name} at mstep end {currloss}")
            assert currloss - prevloss <= 0.1, f"{loss_name} increased in M-step!"

            if prevloss - currloss < abs(rtol * prevloss):
                converged = True
                converged_reason = "rtol"
                break
            elif self._loss_gradnorm() < gtol:
                converged = True
                converged_reason = "gtol"
                break

        if converged:
            logger.info(f"Converged in {em_iter} iterations by metric {converged_reason}")
        else:
            logger.warning("Not converged to tolerance! "
                           "Results may not be reliable")
self.w_ = self.w_prime.numpy()
self.s_ = self.S.numpy()
self.rho_ = 1/self.rhoprec.numpy()

self.final_loss_ = self.lossfn(None)
self.logp_ = self.logp(X)

def _loss_gradnorm(self):

params = [self.S, self.rhoprec] +\
self.space_cov.get_optimize_vars() +\
self.time_cov.get_optimize_vars() +\
self.marg_cov.get_optimize_vars()
if self.algorithm == "ECM":
            # ECME's marginal likelihood does not depend on the W sufficient
            # statistic, so w_prime only enters the gradient check under ECM
params.append(self.w_prime)

val_and_grad = make_val_and_grad(self.lossfn, params)
packed_params = pack_trainable_vars(params)
_, grad = val_and_grad(packed_params)
return np.linalg.norm(grad, np.inf)

def _condition(self, x):
s = np.linalg.svd(x, compute_uv=False)
return np.max(s)/np.min(s)

def transform(self, X, ortho_w=False):
if ortho_w:
w_local = [w @ np.linalg.svd(
w.T @ w)[0] / np.sqrt(np.linalg.svd(w.T @ w)[1]) for w in self.w_]
else:
w_local = self.w_

        vprec_w = [self.space_cov.solve(w).numpy(
        ) / r for (w, r) in zip(w_local, self.rho_)]
        vprec_x = [self.space_cov.solve(x).numpy(
        ) / r for (x, r) in zip(X, self.rho_)]
        conditions = [self._condition(w.T @ vw)
                      for (w, vw) in zip(w_local, vprec_w)]
        logger.info("Condition #s for transformation: %s", conditions)
return [np.linalg.solve(w.T @ vw, w.T @ vx) for (w, vw, vx) in zip(w_local, vprec_w, vprec_x)]
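
Reviewer note: a hedged end-to-end sketch of the estimator added in this file, using only names visible in the diff (DPMNSRM, n_features, algorithm, fit, transform, and the trailing-underscore attributes). The import path and the synthetic data shapes are assumptions for illustration, not part of the PR.

import numpy as np
from brainiak.matnormal.dpsrm import DPMNSRM

rng = np.random.default_rng(0)
# three synthetic subjects, 50 voxels by 40 TRs each (illustrative sizes only)
X = [rng.normal(size=(50, 40)) for _ in range(3)]

model = DPMNSRM(n_features=5, algorithm="ECME")
model.fit(X, max_iter=5)

print(model.s_.shape)   # shared response S: n_features by TRs
print(model.w_.shape)   # per-subject weights W: subjects by voxels by n_features
print(model.rho_)       # per-subject noise variances
shared_estimates = model.transform(X)  # per-subject estimates of the shared response
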
2 changes: 0 additions & 2 deletions brainiak/matnormal/matnormal_likelihoods.py
@@ -62,8 +62,6 @@ def solve_det_marginal(x, sigma, A, Q):
logging.DEBUG,
f"lemma_factor condition={lemma_cond}",
)
logging.log(logging.DEBUG, f"Q condition={_condition(Q._cov)}")
logging.log(logging.DEBUG, f"sigma condition={_condition(sigma._cov)}")
logging.log(
logging.DEBUG,
f"sigma max={tf.reduce_max(input_tensor=A)}," +