Add a score function for SRM #478
base: master
@@ -552,6 +552,87 @@ def _srm(self, data):

```python
        sigma_s = self.comm.bcast(sigma_s)
        return sigma_s, w, mu, rho2, shared_response

    def score(self, data):
        """Calculate the log-likelihood of test subjects' data

        The score can be used for model selection.

        Parameters
        ----------

        data : list of 2D arrays, element i has shape=[voxels_i, samples]
            Each element in the list contains the fMRI data of one subject.

        Returns
        -------

        ll_score : float
            Log-likelihood of the test subjects' data.
        """
        local_min = min([d.shape[1] for d in data if d is not None],
                        default=sys.maxsize)
        samples = self.comm.allreduce(local_min, op=MPI.MIN)
        subjects = len(data)

        random_states = [
            np.random.RandomState(self.random_state_.randint(2 ** 32))
            for i in range(len(data))]

        # Voxel counts of the test subjects (the random initial
        # transforms returned alongside them are not needed)
        _, voxels = _init_w_transforms(data, self.features, random_states,
                                       self.comm)

        # Compute the transform matrices for the test subjects
        w = []
        for subject in range(subjects):
            _w = self._update_transform_subject(data[subject], self.s_)
            w.append(_w)
```
Comment on lines +584 to +587

Member: not necessary but I think the following is slightly faster
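For context on the loop above: assuming `_update_transform_subject` solves the orthogonal Procrustes problem against the learned shared response `self.s_` (taking `W_i = U V^T` from the SVD of `X_i S^T`), the standalone sketch below checks that this recovers a subject's orthonormal transform exactly from noise-free data. The helper name `procrustes_transform` and the dimensions are hypothetical, chosen for illustration only.

```python
import numpy as np


def procrustes_transform(x_i, s):
    """Hypothetical stand-in for _update_transform_subject.

    Returns the voxels-by-features matrix with orthonormal columns that
    minimizes ||x_i - W s||_F, computed from the SVD of x_i s^T.
    """
    a = x_i.dot(s.T)
    u, _, vt = np.linalg.svd(a, full_matrices=False)
    return u.dot(vt)


rng = np.random.RandomState(0)
voxels, features, samples = 50, 5, 200

# Ground-truth transform with orthonormal columns, plus a shared response
w_true, _ = np.linalg.qr(rng.randn(voxels, features))
s = rng.randn(features, samples)
x_i = w_true.dot(s)  # noise-free subject data

w_hat = procrustes_transform(x_i, s)
print(np.allclose(w_hat, w_true))  # True
```

With noise-free data, `x_i s^T = W (S S^T)` is already a polar decomposition, so the orthonormal factor returned by the SVD is exactly `W`; with noisy data the result is the least-squares orthonormal fit instead.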
```python
        x, mu, rho2, trace_xtx = self._init_structures(data, subjects)
        sigma_s = self.sigma_s_
```
Member: not necessary but how about just use
```python
        # Sum the inverted rho2 elements for computing W^T * Psi^-1 * W
        rho0 = (1 / rho2).sum()

        # Invert Sigma_s using Cholesky factorization
        (chol_sigma_s, lower_sigma_s) = scipy.linalg.cho_factor(
            sigma_s, check_finite=False)
```
Member: referred to in a comment above
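The scalar `rho0` is enough here because each `W_i` has orthonormal columns and `Psi` is block diagonal with blocks `rho_i^2 * I`, so `W^T * Psi^-1 * W` collapses to `rho0 * I` with `rho0 = sum_i 1 / rho_i^2`. A small numerical check of that identity, with made-up voxel counts and noise variances:

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.RandomState(1)
features = 4
voxels = [30, 40, 25]             # illustrative voxel counts per subject
rho2 = np.array([0.5, 2.0, 1.3])  # illustrative noise variances rho_i^2

# One transform with orthonormal columns per subject, stacked into W
w = np.vstack([np.linalg.qr(rng.randn(v, features))[0] for v in voxels])

# Psi = blockdiag(rho_i^2 * I_{voxels_i})
psi = block_diag(*[r * np.eye(v) for r, v in zip(rho2, voxels)])

rho0 = (1 / rho2).sum()
lhs = w.T.dot(np.linalg.solve(psi, w))  # W^T Psi^-1 W
print(np.allclose(lhs, rho0 * np.eye(features)))  # True
```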
```python
        inv_sigma_s = scipy.linalg.cho_solve(
            (chol_sigma_s, lower_sigma_s), np.identity(self.features),
            check_finite=False)

        # Invert (Sigma_s^-1 + rho_0 * I) using Cholesky factorization
        sigma_s_rhos = inv_sigma_s + np.identity(self.features) * rho0
        chol_sigma_s_rhos, lower_sigma_s_rhos = \
            scipy.linalg.cho_factor(sigma_s_rhos,
                                    check_finite=False)
        inv_sigma_s_rhos = scipy.linalg.cho_solve(
            (chol_sigma_s_rhos, lower_sigma_s_rhos),
            np.identity(self.features), check_finite=False)
```
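The inversions above rely on `scipy.linalg.cho_factor` and `scipy.linalg.cho_solve` for symmetric positive-definite matrices. The factor's diagonal also yields the log-determinant, `log|A| = 2 * sum(log(c_jj))`, which is presumably why `chol_sigma_s` and `chol_sigma_s_rhos` are passed to `self._likelihood` later. A standalone sketch with an illustrative random `Sigma_s` and an arbitrary `rho0`:

```python
import numpy as np
import scipy.linalg

rng = np.random.RandomState(2)
features = 6
rho0 = 3.2  # illustrative value of sum_i 1 / rho_i^2

# A random symmetric positive-definite Sigma_s
b = rng.randn(features, features)
sigma_s = b.dot(b.T) + features * np.eye(features)

# Sigma_s^-1 via its Cholesky factor
chol, lower = scipy.linalg.cho_factor(sigma_s, check_finite=False)
inv_sigma_s = scipy.linalg.cho_solve((chol, lower), np.identity(features),
                                     check_finite=False)

# (Sigma_s^-1 + rho0 * I)^-1, mirroring the score() code
m = inv_sigma_s + rho0 * np.identity(features)
chol_m, lower_m = scipy.linalg.cho_factor(m, check_finite=False)
inv_m = scipy.linalg.cho_solve((chol_m, lower_m), np.identity(features),
                               check_finite=False)

# Only the diagonal of the factor is valid/needed for the log-determinant
log_det_sigma_s = 2 * np.log(np.diag(chol)).sum()

print(np.allclose(inv_sigma_s, np.linalg.inv(sigma_s)))            # True
print(np.allclose(inv_m, np.linalg.inv(m)))                        # True
print(np.isclose(log_det_sigma_s, np.linalg.slogdet(sigma_s)[1]))  # True
```

Note that `cho_factor` leaves arbitrary values in the unused triangle of its output, so only the diagonal (or the matching triangle) should be read from it.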
```python
        # Compute the sum of W_i^T * rho_i^-2 * X_i, and the sum of traces
        # of X_i^T * rho_i^-2 * X_i
        wt_invpsi_x = np.zeros((self.features, samples))
        trace_xt_invsigma2_x = 0.0
        for subject in range(subjects):
            if data[subject] is not None:
```
Member: referred to in a comment above
```python
                wt_invpsi_x += (w[subject].T.dot(x[subject])) \
                    / rho2[subject]
                trace_xt_invsigma2_x += trace_xtx[subject] / rho2[subject]

        wt_invpsi_x = self.comm.reduce(wt_invpsi_x, op=MPI.SUM)
        trace_xt_invsigma2_x = self.comm.reduce(trace_xt_invsigma2_x,
                                                op=MPI.SUM)
        log_det_psi = np.sum(np.log(rho2) * voxels)

        # Compute the log-likelihood
        ll_score = self._likelihood(
            chol_sigma_s_rhos, log_det_psi, chol_sigma_s,
            trace_xt_invsigma2_x, inv_sigma_s_rhos, wt_invpsi_x,
            samples)

        # Add the constant term
        ll_score -= 0.5*samples*np.sum(voxels)*np.log(2*np.pi)
        return ll_score


class DetSRM(BaseEstimator, TransformerMixin):
    """Deterministic Shared Response Model (DetSRM)
```
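As a sanity check of the closed form, the sketch below evaluates the same ingredients as `score()` (`rho0`, `wt_invpsi_x`, `trace_xt_invsigma2_x`, `log_det_psi` and the `2*pi` constant) on a synthetic SRM and compares the result with the brute-force Gaussian log-density of the stacked data, whose columns are modeled as `N(0, W Sigma_s W^T + Psi)`. The helper `srm_log_likelihood`, the dimensions and the parameters are hypothetical; the formula is written out here from the matrix determinant lemma and the Woodbury identity rather than copied from `_likelihood`.

```python
import numpy as np
from scipy.stats import multivariate_normal


def srm_log_likelihood(x, w, rho2, sigma_s):
    """Closed-form SRM log-likelihood under a zero-mean model (hypothetical).

    x: list of [voxels_i, samples] arrays
    w: list of [voxels_i, features] arrays with orthonormal columns
    rho2: per-subject noise variances; sigma_s: shared response covariance
    """
    features = sigma_s.shape[0]
    samples = x[0].shape[1]
    voxels = np.array([xi.shape[0] for xi in x])

    rho0 = (1 / rho2).sum()
    m = np.linalg.inv(sigma_s) + rho0 * np.identity(features)
    wt_invpsi_x = sum(wi.T.dot(xi) / r for wi, xi, r in zip(w, x, rho2))
    trace_term = sum(np.sum(xi ** 2) / r for xi, r in zip(x, rho2))
    log_det_psi = np.sum(np.log(rho2) * voxels)

    # Matrix determinant lemma: log|W Sigma_s W^T + Psi|
    log_det = (np.linalg.slogdet(m)[1] + log_det_psi
               + np.linalg.slogdet(sigma_s)[1])
    ll = -0.5 * samples * log_det - 0.5 * trace_term
    # Woodbury identity: correction to the quadratic form
    ll += 0.5 * np.trace(wt_invpsi_x.T.dot(np.linalg.solve(m, wt_invpsi_x)))
    ll -= 0.5 * samples * np.sum(voxels) * np.log(2 * np.pi)
    return ll


rng = np.random.RandomState(3)
features, samples = 3, 20
voxels = [8, 10, 6]
rho2 = np.array([0.7, 1.5, 0.4])
b = rng.randn(features, features)
sigma_s = b.dot(b.T) + np.eye(features)

w = [np.linalg.qr(rng.randn(v, features))[0] for v in voxels]
s = rng.multivariate_normal(np.zeros(features), sigma_s, size=samples).T
x = [wi.dot(s) + np.sqrt(r) * rng.randn(v, samples)
     for wi, r, v in zip(w, rho2, voxels)]

# Brute force: each column of the stacked data is N(0, W Sigma_s W^T + Psi)
w_all = np.vstack(w)
psi = np.diag(np.concatenate([np.full(v, r) for v, r in zip(voxels, rho2)]))
cov = w_all.dot(sigma_s).dot(w_all.T) + psi
brute = multivariate_normal(mean=np.zeros(sum(voxels)), cov=cov).logpdf(
    np.vstack(x).T).sum()

print(np.isclose(srm_log_likelihood(x, w, rho2, sigma_s), brute))  # True
```

The closed form only ever factorizes `features x features` matrices, whereas the brute-force version needs the full `sum(voxels) x sum(voxels)` covariance, which is the point of the `rho0` shortcut for real fMRI data.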
Might be better to have input validation here to make sure all `d` in `data` have the same number of samples? I noticed that later, on line 616, you are skipping any `d` in `data` that is `None`. I was wondering if this could be done earlier, such as removing `None` from the very beginning, or warning the user that there is a `None` in the `data` list? Feel free to decide what's more user-friendly.

Ah, and this function probably wants to assert `len(data) >= 2`.