diff --git a/brainiak/funcalign/rsrm.py b/brainiak/funcalign/rsrm.py index 9e2d419a..929ace53 100644 --- a/brainiak/funcalign/rsrm.py +++ b/brainiak/funcalign/rsrm.py @@ -97,7 +97,8 @@ class RSRM(BaseEstimator, TransformerMixin): The number of voxels may be different between subjects. However, the number of timepoints for the alignment data must be the same across - subjects. + subjects. Note that unlike SRM, RSRM does not handle vector shifts + (intercepts) across subjects. The Robust Shared Response Model is approximated using the Block-Coordinate Descent (BCD) algorithm proposed in [Turek2017]_. @@ -380,7 +381,7 @@ def _objective_function(X, W, R, S, gamma): func = .0 for i in range(subjs): func += 0.5 * np.sum((X[i] - W[i].dot(R) - S[i])**2) \ - + gamma * np.sum(np.abs(S[i])) + + gamma * np.sum(np.abs(S[i])) return func @staticmethod @@ -473,7 +474,7 @@ def _update_shared_response(X, S, W, features): # Project the subject data with the individual component removed into # the shared subspace and average over all subjects. for i in range(subjs): - R += W[i].T.dot(X[i]-S[i]) + R += W[i].T.dot(X[i] - S[i]) R /= subjs return R diff --git a/brainiak/funcalign/srm.py b/brainiak/funcalign/srm.py index 3ed56a1d..82ce1010 100644 --- a/brainiak/funcalign/srm.py +++ b/brainiak/funcalign/srm.py @@ -296,8 +296,8 @@ def transform(self, X, y=None): s = [None] * len(X) for subject in range(len(X)): if X[subject] is not None: - s[subject] = self.w_[subject].T.dot(X[subject]) - + s[subject] = self.w_[subject].T.dot( + X[subject] - self.mu_[subject][:, np.newaxis]) return s def _init_structures(self, data, subjects): @@ -419,7 +419,10 @@ def _update_transform_subject(Xi, S): def transform_subject(self, X): """Transform a new subject using the existing model. - The subject is assumed to have recieved equivalent stimulation + The subject is assumed to have received equivalent stimulation. 
In + particular, to transform the new subject X with w and mu, one can do + the following: + shared_X = w.T @ (X - mu[:, np.newaxis]) Parameters ---------- @@ -432,6 +435,8 @@ def transform_subject(self, X): w : 2D array, shape=[voxels, features] Orthogonal mapping `W_{new}` for new subject + mu : 1D array, shape=[voxels] + The voxel means for the new subject """ # Check if the model exist @@ -442,10 +447,10 @@ def transform_subject(self, X): if X.shape[1] != self.s_.shape[1]: raise ValueError("The number of timepoints(TRs) does not match the" "one in the model.") - - w = self._update_transform_subject(X, self.s_) - - return w + # get the intercept for mean centering, as procrustes doesn't handle it + mu = np.mean(X, axis=1) + w = self._update_transform_subject(X - mu[:, np.newaxis], self.s_) + return w, mu def save(self, file): """Save fitted SRM to .npz file. @@ -562,7 +567,7 @@ def _srm(self, data): for subject in range(subjects): if data[subject] is not None: wt_invpsi_x += (w[subject].T.dot(x[subject])) \ - / rho2[subject] + / rho2[subject] trace_xt_invsigma2_x += trace_xtx[subject] / rho2[subject] wt_invpsi_x = self.comm.reduce(wt_invpsi_x, op=MPI.SUM) @@ -650,6 +655,9 @@ class DetSRM(BaseEstimator, TransformerMixin): s_ : array, shape=[features, samples] The shared response. + mu_ : list of array, element i has shape=[voxels_i] + The voxel means over the samples for each subject. + random_state_: `RandomState` Random number generator initialized using rand_seed @@ -742,8 +750,8 @@ def transform(self, X, y=None): s = [None] * len(X) for subject in range(len(X)): - s[subject] = self.w_[subject].T.dot(X[subject]) - + s[subject] = self.w_[subject].T.dot( + X[subject] - self.mu_[subject][:, np.newaxis]) return s def _objective_function(self, data, w, s): @@ -818,11 +826,16 @@ def _update_transform_subject(Xi, S): Wi : array, shape=[voxels, features] The orthogonal transform (mapping) :math:`W_i` for the subject. 
+ + mu : 1D array, shape=[voxels] + The voxel means for the new subject """ - A = Xi.dot(S.T) + # estimate the intercept and center the data + mu = np.mean(Xi, axis=1) + A = (Xi - mu[:, np.newaxis]).dot(S.T) # Solve the Procrustes problem U, _, V = np.linalg.svd(A, full_matrices=False) - return U.dot(V) + return U.dot(V), mu def transform_subject(self, X): """Transform a new subject using the existing model. @@ -880,13 +893,20 @@ def _srm(self, data): np.random.RandomState(self.random_state_.randint(2 ** 32)) for i in range(len(data))] + # compute subject specific intercept + self.mu_ = [np.mean(data[s], axis=1) for s in range(subjects)] + # center the data + data = [data[s] - self.mu_[s][:, np.newaxis] + for s in range(subjects)] + # Initialization step: initialize the outputs with initial values, # voxels with the number of voxels in each subject. w, _ = _init_w_transforms(data, self.features, random_states) shared_response = self._compute_shared_response(data, w) if logger.isEnabledFor(logging.INFO): # Calculate the current objective function value - objective = self._objective_function(data, w, shared_response) + objective = self._objective_function( + data, w, shared_response) logger.info('Objective function %f' % objective) # Main loop of the algorithm @@ -907,7 +927,8 @@ def _srm(self, data): if logger.isEnabledFor(logging.INFO): # Calculate the current objective function value - objective = self._objective_function(data, w, shared_response) + objective = self._objective_function( + data, w, shared_response) logger.info('Objective function %f' % objective) return w, shared_response diff --git a/brainiak/funcalign/sssrm.py b/brainiak/funcalign/sssrm.py index b43e9428..e0c22c16 100644 --- a/brainiak/funcalign/sssrm.py +++ b/brainiak/funcalign/sssrm.py @@ -126,6 +126,8 @@ class SSSRM(BaseEstimator, ClassifierMixin, TransformerMixin): The number of voxels may be different between subjects. 
However, the number of samples for the alignment data must be the same across subjects. The number of labeled samples per subject can be different. + Note that unlike SRM, SSSRM does not handle vector shifts (intercepts) + across subjects. The Semi-Supervised Shared Response Model is approximated using the Block-Coordinate Descent (BCD) algorithm proposed in [Turek2016]_. diff --git a/tests/funcalign/test_srm.py b/tests/funcalign/test_srm.py index b109a77b..10b91369 100644 --- a/tests/funcalign/test_srm.py +++ b/tests/funcalign/test_srm.py @@ -44,7 +44,7 @@ def test_can_instantiate(tmp_path): W = [] Q, R = np.linalg.qr(np.random.random((voxels, features))) W.append(Q) - X.append(Q.dot(S) + 0.1*np.random.random((voxels, samples))) + X.append(Q.dot(S) + 0.1 * np.random.random((voxels, samples))) # Check that transform does NOT run before fitting the model with pytest.raises(NotFittedError): @@ -59,7 +59,7 @@ def test_can_instantiate(tmp_path): for subject in range(1, subjects): Q, R = np.linalg.qr(np.random.random((voxels, features))) W.append(Q) - X.append(Q.dot(S) + 0.1*np.random.random((voxels, samples))) + X.append(Q.dot(S) + 0.1 * np.random.random((voxels, samples))) # Check that runs with 2 subject s.fit(X) @@ -81,7 +81,7 @@ def test_can_instantiate(tmp_path): difference = np.linalg.norm(X[subject] - s.w_[subject].dot(s.s_), 'fro') datanorm = np.linalg.norm(X[subject], 'fro') - assert difference/datanorm < 1.0, "Model seems incorrectly computed." + assert difference / datanorm < 1.0, "Model seems incorrectly computed." assert s.s_.shape[0] == features, ( "Invalid computation of SRM! (wrong # features in S)") assert s.s_.shape[1] == samples, ( @@ -106,7 +106,7 @@ def test_can_instantiate(tmp_path): # Check that it does not run without enough samples (TRs). 
with pytest.raises(ValueError): - s.set_params(features=(samples+1)) + s.set_params(features=(samples + 1)) s.fit(X) print("Test: not enough samples") @@ -162,12 +162,12 @@ def test_new_subject(): W = [] Q, R = np.linalg.qr(np.random.random((voxels, features))) W.append(Q) - X.append(Q.dot(S) + 0.1*np.random.random((voxels, samples))) + X.append(Q.dot(S) + 0.1 * np.random.random((voxels, samples))) for subject in range(1, subjects): Q, R = np.linalg.qr(np.random.random((voxels, features))) W.append(Q) - X.append(Q.dot(S) + 0.1*np.random.random((voxels, samples))) + X.append(Q.dot(S) + 0.1 * np.random.random((voxels, samples))) # Check that transform does NOT run before fitting the model with pytest.raises(NotFittedError): @@ -182,11 +182,15 @@ def test_new_subject(): s.transform_subject(X[0].T) # Check that it does run to compute a new subject - new_w = s.transform_subject(X[0]) + new_w, new_mu = s.transform_subject(X[0]) + np.shape(new_w) + np.shape(new_mu) assert new_w.shape[1] == features, ( - "Invalid computation of SRM! (wrong # features for new subject)") + "Invalid computation of SRM! (wrong # features for new subject)") assert new_w.shape[0] == voxels, ( - "Invalid computation of SRM! (wrong # voxels for new subject)") + "Invalid computation of SRM! (wrong # voxels for new subject)") + assert new_mu.shape[0] == voxels, ( + "Invalid computation of SRM! (wrong # voxels for new subject)") # Check that these analyses work with the deterministic SRM too ds = brainiak.funcalign.srm.DetSRM(n_iter=5, features=features) @@ -204,11 +208,13 @@ def test_new_subject(): ds.transform_subject(X[0].T) # Check that it does run to compute a new subject - new_w = ds.transform_subject(X[0]) + new_w, new_mu = ds.transform_subject(X[0]) assert new_w.shape[1] == features, ( - "Invalid computation of SRM! (wrong # features for new subject)") + "Invalid computation of SRM! (wrong # features for new subject)") assert new_w.shape[0] == voxels, ( - "Invalid computation of SRM! 
(wrong # voxels for new subject)") + "Invalid computation of SRM! (wrong # voxels for new subject)") + assert new_mu.shape[0] == voxels, ( + "Invalid computation of SRM! (wrong # voxels for new subject)") def test_det_srm(): @@ -239,7 +245,7 @@ def test_det_srm(): W = [] Q, R = np.linalg.qr(np.random.random((voxels, features))) W.append(Q) - X.append(Q.dot(S) + 0.1*np.random.random((voxels, samples))) + X.append(Q.dot(S) + 0.1 * np.random.random((voxels, samples))) # Check that transform does NOT run before fitting the model with pytest.raises(NotFittedError): @@ -254,7 +260,7 @@ def test_det_srm(): for subject in range(1, subjects): Q, R = np.linalg.qr(np.random.random((voxels, features))) W.append(Q) - X.append(Q.dot(S) + 0.1*np.random.random((voxels, samples))) + X.append(Q.dot(S) + 0.1 * np.random.random((voxels, samples))) # Check that runs with 2 subject model.fit(X) @@ -274,7 +280,7 @@ def test_det_srm(): - model.w_[subject].dot(model.s_), 'fro') datanorm = np.linalg.norm(X[subject], 'fro') - assert difference/datanorm < 1.0, "Model seems incorrectly computed." + assert difference / datanorm < 1.0, "Model seems incorrectly computed." assert model.s_.shape[0] == features, ( "Invalid computation of DetSRM! (wrong # features in S)") assert model.s_.shape[1] == samples, ( @@ -294,11 +300,13 @@ def test_det_srm(): "Invalid computation of DetSRM! (wrong # samples after transform)") # Check that it does run to compute a new subject - new_w = model.transform_subject(X[0]) + new_w, new_mu = model.transform_subject(X[0]) assert new_w.shape[1] == features, ( - "Invalid computation of SRM! (wrong # features for new subject)") + "Invalid computation of SRM! (wrong # features for new subject)") assert new_w.shape[0] == voxels, ( - "Invalid computation of SRM! (wrong # voxels for new subject)") + "Invalid computation of SRM! (wrong # voxels for new subject)") + assert new_mu.shape[0] == voxels, ( + "Invalid computation of SRM! 
(wrong # voxels for new subject)") # Check that it does NOT run with non-matching number of subjects with pytest.raises(ValueError): @@ -307,7 +315,7 @@ def test_det_srm(): # Check that it does not run without enough samples (TRs). with pytest.raises(ValueError): - model.set_params(features=(samples+1)) + model.set_params(features=(samples + 1)) model.fit(X) print("Test: not enough samples") @@ -317,3 +325,99 @@ def test_det_srm(): with pytest.raises(ValueError): model.fit(X) print("Test: different number of samples per subject") + + +def test_vector_shift_srm(): + import brainiak.funcalign.srm + import numpy as np + from scipy.linalg import qr + np.random.seed(0) + nvoxels = 2 + ntps = 4 + nsubjs = 3 + # initialize S + S = np.random.uniform(size=(nvoxels, ntps)) + # preallocate + W_truth = [None] * nsubjs + X = [None] * nsubjs + intercept = [None] * nsubjs + # make simulated data, such that each subject is + # X_i = W_i.T @ S + intercept + # ... where W_i is a random ortho matrix and intercept is a random vector + for i in range(nsubjs): + # sample an ortho matrix + H = np.random.randn(nvoxels, nvoxels) + W_truth[i], _ = qr(H) + # sample an intercept + intercept[i] = np.random.randn(nvoxels) + # simulate the i-th subject + X[i] = W_truth[i].T @ S + intercept[i][:, np.newaxis] + # make a new subject + # sample an ortho matrix + H = np.random.randn(nvoxels, nvoxels) + W_new, _ = qr(H) + # sample an intercept + intercept_new = np.random.randn(nvoxels) + # simulate a new subject + X_new = W_new.T @ S + intercept_new[:, np.newaxis] + # fit SRM on the training set, X + srm = brainiak.funcalign.srm.SRM(features=nvoxels) + X_shared = srm.fit_transform(X) + # map the new subject to the pre-trained SRM -- estimate W and intercept + W_hat, mu_hat = srm.transform_subject(X_new) + SX_new = W_hat.T @ (X_new - mu_hat[:, np.newaxis]) + # check all subjects in the training set are aligned (small frobenius diff) + for i in np.arange(nsubjs): + assert np.linalg.norm(X_shared[0] - 
X_shared[i], ord='fro') < 1e-10, ( + 'subject is misaligned with subject 0') + # check the new subject is also aligned (small frobenius diff) + assert np.linalg.norm(X_shared[0] - SX_new, ord='fro') < 1e-10, ( + 'the new subject is misaligned with subject 0') + + +def test_vector_shift_detsrm(): + import brainiak.funcalign.srm + import numpy as np + from scipy.linalg import qr + np.random.seed(0) + nvoxels = 2 + ntps = 4 + nsubjs = 3 + # initialize S + S = np.random.uniform(size=(nvoxels, ntps)) + # preallocate + W_truth = [None] * nsubjs + X = [None] * nsubjs + intercept = [None] * nsubjs + # make simulated data, such that each subject is + # X_i = W_i.T @ S + intercept + # ... where W_i is a random ortho matrix and intercept is a random vector + for i in range(nsubjs): + # sample an ortho matrix + H = np.random.randn(nvoxels, nvoxels) + W_truth[i], _ = qr(H) + # sample an intercept + intercept[i] = np.random.randn(nvoxels) + # simulate the i-th subject + X[i] = W_truth[i].T @ S + intercept[i][:, np.newaxis] + # make a new subject + # sample an ortho matrix + H = np.random.randn(nvoxels, nvoxels) + W_new, _ = qr(H) + # sample an intercept + intercept_new = np.random.randn(nvoxels) + # simulate a new subject + X_new = W_new.T @ S + intercept_new[:, np.newaxis] + # fit SRM on the training set, X + detsrm = brainiak.funcalign.srm.DetSRM(features=nvoxels) + X_shared = detsrm.fit_transform(X) + # map the new subject to the pre-trained SRM -- estimate W and intercept + W_hat, mu_hat = detsrm.transform_subject(X_new) + SX_new = W_hat.T @ (X_new - mu_hat[:, np.newaxis]) + # check all subjects in the training set are aligned (small frobenius diff) + for i in np.arange(nsubjs): + assert np.linalg.norm(X_shared[0] - X_shared[i], ord='fro') < 1e-10, ( + 'subject is misaligned with subject 0') + # check the new subject is also aligned (small frobenius diff) + assert np.linalg.norm(X_shared[0] - SX_new, ord='fro') < 1e-10, ( + 'the new subject is misaligned with subject 0')