81 changes: 81 additions & 0 deletions brainiak/funcalign/srm.py
@@ -552,6 +552,87 @@ def _srm(self, data):
        sigma_s = self.comm.bcast(sigma_s)
        return sigma_s, w, mu, rho2, shared_response

    def score(self, data):
        """Calculate the log-likelihood of the test subjects' data for model
        selection.

        Parameters
        ----------

        data : list of 2D arrays, element i has shape=[voxels_i, samples]
            Each element in the list contains the fMRI data of one subject.

        Returns
        -------

        ll_score : float
            Log-likelihood of the test subjects' data.
        """
Reviewer (Member):

Might be better to have input validation here to make sure all d in data have the same number of samples?

I noticed that later on line 616 you are skipping any d in data that is None. I was wondering if this can be done earlier, such as removing None entries from the very beginning, or warning the user that there is None in the data list? Feel free to decide what's more user-friendly.
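
A rough sketch of what that up-front validation could look like, assuming data is the list passed to score(); the helper name _validate_score_input and the warning wording are illustrative, not part of the PR:

import logging

logger = logging.getLogger(__name__)

def _validate_score_input(data):
    # Hypothetical helper, not part of the PR: warn about None entries
    # instead of silently skipping them later.
    if any(d is None for d in data):
        logger.warning("Some entries in `data` are None; they will be "
                       "skipped when computing the log-likelihood.")
    # Check that every non-None subject has the same number of samples.
    sample_counts = {d.shape[1] for d in data if d is not None}
    if len(sample_counts) > 1:
        raise ValueError("All subjects must have the same number of "
                         "samples, got {}".format(sorted(sample_counts)))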

Reviewer (Member):

Ah, and this function probably wants to assert len(data) >= 2.
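
A minimal sketch of that check, assuming it sits at the top of score(); the error message is my own wording:

# Hypothetical check at the top of score(); message wording is illustrative.
assert len(data) >= 2, "score() needs data from at least two subjects"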

        local_min = min([d.shape[1] for d in data if d is not None],
                        default=sys.maxsize)
        samples = self.comm.allreduce(local_min, op=MPI.MIN)
        subjects = len(data)

        random_states = [
            np.random.RandomState(self.random_state_.randint(2 ** 32))
            for i in range(len(data))]

        # Initialize transforms for the new subjects (only the voxel counts
        # are used)
        _, voxels = _init_w_transforms(data, self.features, random_states,
                                       self.comm)

        # Compute the transform matrices for the test subjects
        w = []
        for subject in range(subjects):
            _w = self._update_transform_subject(data[subject], self.s_)
            w.append(_w)
Comment on lines +584 to +587
Reviewer (Member):

Not necessary, but I think the following is slightly faster:

w = [self._update_transform_subject(data[subject], self.s_) for subject in range(subjects)]


        x, mu, rho2, trace_xtx = self._init_structures(data, subjects)
        sigma_s = self.sigma_s_
Reviewer (Member):

Not necessary, but how about just using self.sigma_s_ directly on line 597?


        # Sum the inverted rho2 elements for computing W^T * Psi^-1 * W
        rho0 = (1 / rho2).sum()

        # Invert Sigma_s using Cholesky factorization
        (chol_sigma_s, lower_sigma_s) = scipy.linalg.cho_factor(
            sigma_s, check_finite=False)
Reviewer (Member):

Referred to in a comment above.
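
For context, the cho_factor/cho_solve pair used in the diff is the standard SciPy pattern for inverting a symmetric positive-definite matrix; here is an isolated sketch with a made-up 3x3 matrix (not model data):

import numpy as np
import scipy.linalg

# Made-up symmetric positive-definite matrix standing in for sigma_s.
a = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
# Factor once, then solve against the identity to obtain the inverse.
chol_a, lower_a = scipy.linalg.cho_factor(a, check_finite=False)
a_inv = scipy.linalg.cho_solve((chol_a, lower_a), np.identity(3),
                               check_finite=False)
assert np.allclose(a.dot(a_inv), np.identity(3))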

        inv_sigma_s = scipy.linalg.cho_solve(
            (chol_sigma_s, lower_sigma_s), np.identity(self.features),
            check_finite=False)

        # Invert (Sigma_s + rho_0 * I) using Cholesky factorization
        sigma_s_rhos = inv_sigma_s + np.identity(self.features) * rho0
        chol_sigma_s_rhos, lower_sigma_s_rhos = \
            scipy.linalg.cho_factor(sigma_s_rhos,
                                    check_finite=False)
        inv_sigma_s_rhos = scipy.linalg.cho_solve(
            (chol_sigma_s_rhos, lower_sigma_s_rhos),
            np.identity(self.features), check_finite=False)

        # Compute the sum of W_i^T * rho_i^-2 * X_i, and the sum of traces
        # of X_i^T * rho_i^-2 * X_i
        wt_invpsi_x = np.zeros((self.features, samples))
        trace_xt_invsigma2_x = 0.0
        for subject in range(subjects):
            if data[subject] is not None:
Reviewer (Member):

Referred to in a comment above.

                wt_invpsi_x += (w[subject].T.dot(x[subject])) \
                    / rho2[subject]
                trace_xt_invsigma2_x += trace_xtx[subject] / rho2[subject]

        wt_invpsi_x = self.comm.reduce(wt_invpsi_x, op=MPI.SUM)
        trace_xt_invsigma2_x = self.comm.reduce(trace_xt_invsigma2_x,
                                                op=MPI.SUM)
        log_det_psi = np.sum(np.log(rho2) * voxels)

        # Compute the log-likelihood
        ll_score = self._likelihood(
            chol_sigma_s_rhos, log_det_psi, chol_sigma_s,
            trace_xt_invsigma2_x, inv_sigma_s_rhos, wt_invpsi_x,
            samples)

        # Add the constant term
        ll_score -= 0.5*samples*np.sum(voxels)*np.log(2*np.pi)
        return ll_score
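
To illustrate the intended use of the new method, here is a rough sketch of fitting SRM on training runs and scoring held-out runs; the array sizes are arbitrary, and the constructor arguments follow the existing SRM signature (n_iter, features), so adjust them if this branch differs:

import numpy as np
from brainiak.funcalign.srm import SRM

# Synthetic data: 3 subjects, 100 voxels each, 50 samples per run.
voxels, samples, subjects = 100, 50, 3
train = [np.random.rand(voxels, samples) for _ in range(subjects)]
test = [np.random.rand(voxels, samples) for _ in range(subjects)]

model = SRM(n_iter=10, features=5)
model.fit(train)          # learn the transforms and shared response
ll = model.score(test)    # log-likelihood of the held-out data
print(ll)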


class DetSRM(BaseEstimator, TransformerMixin):
"""Deterministic Shared Response Model (DetSRM)