diff --git a/brainiak/funcalign/srm.py b/brainiak/funcalign/srm.py
index 023e71872..2cc452cf6 100644
--- a/brainiak/funcalign/srm.py
+++ b/brainiak/funcalign/srm.py
@@ -552,6 +552,112 @@ def _srm(self, data):
         sigma_s = self.comm.bcast(sigma_s)
         return sigma_s, w, mu, rho2, shared_response
 
+    def score(self, data):
+        """Compute the log-likelihood of test subjects' data
+
+        The fitted shared response ``s_`` and covariance ``sigma_s_`` are
+        kept fixed. A mapping is computed for each test subject with
+        ``_update_transform_subject`` and the log-likelihood of the data,
+        including the constant term, is returned. Scores can be compared
+        across models for model selection.
+
+        Parameters
+        ----------
+
+        data : list of 2D arrays, element i has shape=[voxels_i, samples]
+            Each element in the list contains the fMRI data of one subject.
+            ``samples`` must equal the number of samples in ``s_``. An
+            element may be None when that subject's data is held by another
+            MPI process.
+
+
+        Returns
+        -------
+
+        ll_score : Log-likelihood of the test subjects' data
+
+
+        Raises
+        ------
+
+        ValueError
+            If a subject's number of samples differs from that of ``s_``.
+        """
+        samples = self.s_.shape[1]
+        subjects = len(data)
+
+        # Each subject's mapping is computed against s_ and its projection is
+        # accumulated into a (features, samples) array, so the sample counts
+        # must agree; raise a clear error rather than a numpy shape error.
+        for subject_data in data:
+            if subject_data is not None and subject_data.shape[1] != samples:
+                raise ValueError("The number of samples of a test subject "
+                                 "does not match the fitted shared response.")
+
+        # Voxels per subject, summed across MPI processes so that every
+        # process knows the size of every subject (non-local entries are 0).
+        # Computed directly rather than via _init_w_transforms, whose random
+        # orthogonal matrices were discarded and whose seeding advanced
+        # self.random_state_ as a side effect of scoring.
+        local_voxels = np.array(
+            [d.shape[0] if d is not None else 0 for d in data], dtype=int)
+        voxels = self.comm.allreduce(local_voxels, op=MPI.SUM)
+
+        x, _, rho2, trace_xtx = self._init_structures(data, subjects)
+        # NOTE(review): rho2 is the initial noise variance returned by
+        # _init_structures, not an estimate for the test subjects; confirm
+        # this is the intended noise model for scoring.
+        sigma_s = self.sigma_s_
+
+        # Sum the inverted rho2 elements for computing W^T * Psi^-1 * W
+        rho0 = (1 / rho2).sum()
+
+        # Invert Sigma_s using Cholesky factorization
+        (chol_sigma_s, lower_sigma_s) = scipy.linalg.cho_factor(
+            sigma_s, check_finite=False)
+        inv_sigma_s = scipy.linalg.cho_solve(
+            (chol_sigma_s, lower_sigma_s), np.identity(self.features),
+            check_finite=False)
+
+        # Invert (Sigma_s^-1 + rho_0 * I) using Cholesky factorization
+        sigma_s_rhos = inv_sigma_s + np.identity(self.features) * rho0
+        chol_sigma_s_rhos, lower_sigma_s_rhos = \
+            scipy.linalg.cho_factor(sigma_s_rhos,
+                                    check_finite=False)
+        inv_sigma_s_rhos = scipy.linalg.cho_solve(
+            (chol_sigma_s_rhos, lower_sigma_s_rhos),
+            np.identity(self.features), check_finite=False)
+
+        # Over the locally held subjects, compute the sum of
+        # W_i^T * rho_i^-2 * X_i and the sum of traces of
+        # X_i^T * rho_i^-2 * X_i
+        wt_invpsi_x = np.zeros((self.features, samples))
+        trace_xt_invsigma2_x = 0.0
+        for subject in range(subjects):
+            if data[subject] is None:
+                continue
+            # Map this test subject onto the fitted shared response
+            w_subject = self._update_transform_subject(data[subject], self.s_)
+            wt_invpsi_x += w_subject.T.dot(x[subject]) / rho2[subject]
+            trace_xt_invsigma2_x += trace_xtx[subject] / rho2[subject]
+
+        # allreduce rather than reduce: reduce returns None on every process
+        # except the root, and the likelihood below is evaluated on all.
+        wt_invpsi_x = self.comm.allreduce(wt_invpsi_x, op=MPI.SUM)
+        trace_xt_invsigma2_x = self.comm.allreduce(trace_xt_invsigma2_x,
+                                                   op=MPI.SUM)
+        log_det_psi = np.sum(np.log(rho2) * voxels)
+
+        # Compute the log-likelihood
+        ll_score = self._likelihood(
+            chol_sigma_s_rhos, log_det_psi, chol_sigma_s,
+            trace_xt_invsigma2_x, inv_sigma_s_rhos, wt_invpsi_x,
+            samples)
+
+        # Add the constant term -0.5 * samples * sum(voxels) * log(2 * pi)
+        ll_score -= 0.5 * samples * np.sum(voxels) * np.log(2 * np.pi)
+        return ll_score
+
 
 class DetSRM(BaseEstimator, TransformerMixin):
     """Deterministic Shared Response Model (DetSRM)