From 771fdb40e8fff32da3f39cc77bf906105b3c65e4 Mon Sep 17 00:00:00 2001 From: Hejia Zhang Date: Tue, 26 Feb 2019 00:14:35 -0500 Subject: [PATCH 1/7] Add MDMS method and examples --- brainiak/funcalign/mdms.py | 2674 +++++++++++++++++ docs/newsfragments/mdms.jinja | 34 + examples/funcalign/download-data.sh | 4 +- .../mdms_time_segment_matching_distributed.py | 200 ++ .../mdms_time_segment_matching_example.ipynb | 505 ++++ 5 files changed, 3416 insertions(+), 1 deletion(-) create mode 100644 brainiak/funcalign/mdms.py create mode 100644 docs/newsfragments/mdms.jinja create mode 100644 examples/funcalign/mdms_time_segment_matching_distributed.py create mode 100644 examples/funcalign/mdms_time_segment_matching_example.ipynb diff --git a/brainiak/funcalign/mdms.py b/brainiak/funcalign/mdms.py new file mode 100644 index 000000000..db0466a1e --- /dev/null +++ b/brainiak/funcalign/mdms.py @@ -0,0 +1,2674 @@ +# Copyright 2016 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""multi-dataset multi-subject (MDMS) SRM analysis + +The implementations are based on the following publications: +.. [Zhang2018] "Transfer learning on fMRI datasets", + H. Zhang, P.-H. Chen, P. Ramadge + The 21st International Conference on Artificial Intelligence and Statistics (AISTATS), 2018. + http://proceedings.mlr.press/v84/zhang18b/zhang18b.pdf +""" + +# Authors: Hejia Zhang (Princeton Neuroscience Institute), 2018 + +import logging + +import numpy as np +import scipy +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils import assert_all_finite +from sklearn.exceptions import NotFittedError +from mpi4py import MPI +import sys, json, os, glob +from scipy import sparse as sp +import matplotlib.pyplot as plt +import networkx as nx +import pickle as pkl + + +__all__ = [ + "DetMDMS", + "MDMS", + "Dataset" +] + +logging.basicConfig(filename='mdms.log', + filemode='a', + format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s', + datefmt='%H:%M:%S', + level=logging.DEBUG) +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +def _init_w_transforms(voxels, features, random_states, datasets): + """Initialize the mappings (W_s) for the MDMS with random orthogonal matrices. + + Parameters + ---------- + + voxels : dict of int, voxels[s] is number of voxels where s is the name + of the subject. + A dict with the number of voxels for each subject. + + features : int + The number of features in the model. + + random_states : dict of `RandomState`s + One `RandomState` instance per subject. + + datasets : a Dataset object + The Dataset object containing datasets structures. + + comm : mpi4py.MPI.Intracomm + The MPI communicator containing the data + + Returns + ------- + + w : dict of array, w[s] has shape=[voxels[s], features] where s is the name + of the subject. + The initialized orthogonal transforms (mappings) :math:`W_s` for each + subject. + + + Note + ---- + + This function assumes that the numpy random number generator was + initialized. 
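+
+    A minimal illustrative sketch of the per-subject step (hypothetical
+    sizes; mirrors the QR initialization performed below):
+
+    >>> import numpy as np
+    >>> rs = np.random.RandomState(0)
+    >>> q, _ = np.linalg.qr(rs.random_sample((1000, 50)))  # voxels x features
+    >>> np.allclose(q.T.dot(q), np.eye(50))  # columns are orthonormal
+    True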
+ + Not thread safe. + """ + w = {} + subjects = datasets.get_subjects_list() + + # Set Wi to a random orthogonal voxels by features matrix + for subject in subjects: + rnd_matrix = random_states[subject].random_sample(( + voxels[subject], features)) + q, r = np.linalg.qr(rnd_matrix) + w[subject] = q + return w + + +def _sanity_check(X, datasets, comm): + """Check if the input data and datasets information have valid shape/configuration. + + Parameters + ---------- + + X : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + X[d] is a list of data of dataset d, where d is the name of the dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of + subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + + datasets : a Dataset object + The Dataset object containing datasets structures. + + comm : mpi4py.MPI.Intracomm + The MPI communicator containing the data. + + + Returns + ------- + + voxels_ : dict of int, voxels_[s] is number of voxels where s is the name + of the subject. + A dict with the number of voxels for each subject. + + samples_ : dict of int, samples_[d] is number of samples where d is the name + of the dataset. + A dict with the number of samples for each dataset. + + + """ + # Check the number of subjects and all ranks have all datasets in the Dataset object + ds_list = datasets.get_datasets_list() + for (ds, ns) in datasets.num_subj_dataset.items(): + if ns < 1: + raise ValueError("Dataset {} should have positive num_subj_dataset".format(ds)) + if ds not in X: + raise ValueError("Dataset {} not in all ranks".format(ds)) + if X[ds] is not None and len(X[ds]) < ns: + raise ValueError("Dataset {} does not have enough subjects: Need equal to or more " + "than {0:d} subjects but got {0:d} to train the model.".format(ds, ns, len(X[ds]))) + + # Collect size information + shape0, shape1, data_exist = {}, {}, {} + for ds in ds_list: + shape0[ds] = np.zeros((datasets.num_subj,), dtype=np.int) + shape1[ds] = np.zeros((datasets.num_subj,), dtype=np.int) + data_exist[ds] = np.zeros((datasets.num_subj,), dtype=np.int) + for ds in ds_list: + ds_idx = datasets.dataset_to_idx[ds] + if X[ds] is not None: + for subj in range(datasets.num_subj): + if datasets.dok_matrix[subj,ds_idx] != 0: + if datasets.built_from_data: + idx = datasets.idx_to_subject[subj] + if not idx in X[ds]: + raise Exception('Subject {} in dataset {} is missing.'.format(idx, ds)) + else: + idx = datasets.dok_matrix[subj,ds_idx] - 1 + if len(X[ds]) <= idx: + raise ValueError("Dataset {} does not have enough subjects: Need more " + "than {0:d} subjects but got {0:d} to train the model.".format(ds, idx, len(X[ds]))) + if X[ds][idx] is not None: + assert_all_finite(X[ds][idx]) + shape0[ds][subj] = X[ds][idx].shape[0] + shape1[ds][subj] = X[ds][idx].shape[1] + data_exist[ds][subj] = 1 + for ds in ds_list: + shape0[ds] = comm.allreduce(shape0[ds], op=MPI.SUM) + shape1[ds] = comm.allreduce(shape1[ds], op=MPI.SUM) + data_exist[ds] = comm.allreduce(data_exist[ds], op=MPI.SUM) + + # Check if all required data appears once and only once + # Also remove size information of data that is not in 'datasets' + for ds in ds_list: + ds_idx = datasets.dataset_to_idx[ds] + for subj in range(datasets.num_subj): + if datasets.dok_matrix[subj,ds_idx] != 0: + if data_exist[ds][subj] == 0: 
+ raise ValueError("Data of subject {} in dataset {} is missing.".format(datasets.dok_matrix[subj,ds_idx]-1, ds)) + elif data_exist[ds][subj] > 1: + raise ValueError("Data of subject {} in dataset {} appears more than once.".format(datasets.dok_matrix[subj,ds_idx]-1, ds)) + else: + shape0[ds][subj] = 0 + shape1[ds][subj] = 0 + + # Check if each subject has same number of voxels across different datasets + voxels_ = {} + for subj in range(datasets.num_subj): + all_vxs_tmp = [v[subj] for v in shape0.values() if v[subj] != 0] + subj_name = datasets.idx_to_subject[subj] + voxels_[subj_name] = np.min(all_vxs_tmp) + if any([v != voxels_[subj_name] for v in all_vxs_tmp]): + raise ValueError("Subject {} has different number of voxels across" + "datasets.".format(subj_name)) + + # Check if all subjects have same number of TRs within the same dataset + samples_ = {} + for ds in ds_list: + all_trs_tmp = [t for t in shape1[ds] if t != 0] + samples_[ds] = np.min(all_trs_tmp) + if any([t != samples_[ds] for t in all_trs_tmp]): + raise ValueError("Different number of samples between subjects" + "in dataset {}.".format(ds)) + + return voxels_, samples_ + + +class MDMS(BaseEstimator, TransformerMixin): + """multi-dataset multi-subject (MDMS) SRM analysis + + Given multi-dataset multi-subject data, factorize it as a shared response S among all + subjects per dataset and an orthogonal transform W across all datasets per subject: + + .. math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\dots M + + Parameters + ---------- + + n_iter : int, default: 10 + Number of iterations to run the algorithm. + + features : int, default: 50 + Number of features to compute. + + rand_seed : int, default: 0 + Seed for initializing the random number generator. + + comm : mpi4py.MPI.Intracomm + The MPI communicator containing the data + + Attributes + ---------- + + w_ : dict of array, w_[s] has shape=[voxels_[s], features], where + s is the name of the subject. + The orthogonal transforms (mappings) for each subject. + + s_ : dict of array, s_[d] has shape=[features, samples_[d]], where + d is the name of the dataset. + The shared response for each dataset. + + voxels_ : dict of int, voxels_[s] is number of voxels where s is the name + of the subject. + A dict with the number of voxels for each subject. + + samples_ : dict of int, samples_[d] is number of samples where d is the name + of the dataset. + A dict with the number of samples for each dataset. + + sigma_s_ : dict of array, sigma_s_[d] has shape=[features, features] + The covariance of the shared response Normal distribution for each dataset. + + mu_ : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name + of the subject. + The voxel means over the samples in all datasets for each subject. + + rho2_ : dict of dict of float, rho2_[d][s] is a float, where d is the name + of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho_{di}^2` for each subject in each dataset. + + comm : mpi4py.MPI.Intracomm + The MPI communicator containing the data + + random_state_: `RandomState` + Random number generator initialized using rand_seed + + Note + ---- + + The number of voxels may be different between subjects within a dataset + and number of samples may be different between datasets. However, the + number of samples must be the same across subjects within a dataset and + number of voxels must be the same across datasets for the same subject. 
+ + The probabilistic multi-dataset multi-subject model is approximated using the + Expectation Maximization (EM) algorithm proposed in [Zhang2018]_. The + implementation follows the optimizations published in [Anderson2016]_. + + The run-time complexity is :math:`O(I (V T K + V K^2 + K^3))` and the + memory complexity is :math:`O(V T)` with I - the number of iterations, + V - the sum of voxels from all subjects, T - the sum of samples from + all datasets, and K - the number of features (typically, :math:`V \\gg T \\gg K`). + """ + + def __init__(self, n_iter=10, features=50, rand_seed=0, + comm=MPI.COMM_SELF): + self.n_iter = n_iter + self.features = features + self.rand_seed = rand_seed + self.comm = comm + self.logger = logger + return + + + def fit(self, X, datasets=None, y=None): + """Compute the probabilistic multi-dataset multi-subject (MDMS) SRM analysis + + Parameters + ---------- + X : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + 'datasets' must be defined in this case. + X[d] is a list of data of dataset d, where d is the name of the dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + 'datasets' can be omitted in this case. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of + subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + datasets : (optional) a Dataset object + The Dataset object containing datasets structure. + If not defined, the structure will be inferred from X. + + y : not used + """ + if self.comm.Get_rank() == 0: + self.logger.info('Starting Probabilistic MDMS') + + # Check if datasets is initialized + if datasets is not None and datasets.matrix is None: + raise NotFittedError('Dataset object is not initialized.') + + # Check X format + if type(X) != dict: + raise Exception('X should be a dict.') + format_X = type(next(iter(X.values()))) + if format_X != dict and format_X != list: + raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') + if format_X == list and (datasets is None or datasets.built_from_data is None or datasets.built_from_data): + raise Exception("Argument 'datasets' must be defined and built from json " + "files when X is a dict of list of 2D arrays. ") + if format_X == dict and datasets is not None: + datasets.built_from_data = True + for v in X.values(): + if type(v) != format_X: + raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') + + # Infer datasets structure from data + if datasets is None: + datasets = Dataset() + datasets.build_from_data(X) + + self.voxels_, self.samples_ = _sanity_check(X, datasets, self.comm) + + # Run MDMS + self.sigma_s_, self.w_, self.mu_, self.rho2_, self.s_ = self._mdms(X, datasets) + + return self + + + def transform(self, X, subjects, centered=True, y=None): + """Use the model to transform new data to Shared Response space + + Parameters + ---------- + X : list of 2D arrays, element i has shape=[voxels_i, samples_i] + Each element in the list contains the new fMRI data of one subject + + subjects : list of string, element i is the name of subject of X[i] + + centered : bool, if the data in X is already centered. + If centered = False, the voxel means computed during mode fitting + will be subtracted before transformation. 
+ + y : not used (as it is unsupervised learning) + + + Returns + ------- + s : list of 2D arrays, element i has shape=[features_i, samples_i] + Shared responses from input data (X) + """ + + # Check if X and subjects have the same length + if len(X) != len(subjects): + raise ValueError("X and subjects must have the same length.") + + # Check if the model exist + if not hasattr(self, 'w_'): + raise NotFittedError("The model fit has not been run yet.") + + # Check if the subject exist in the fitted model and has the right number of voxels + for idx in range(len(X)): + if not subjects[idx] in self.w_: + raise NotFittedError("The model has not been fitted to subject {}.".format(subjects[idx])) + if X[idx] is not None and self.w_[subjects[idx]].shape[0] != X[idx].shape[0]: + raise ValueError("{}-th element of data has inconsistent number of" + "voxels with fitted model. Model has {} voxels while data has {}" + ".".format(idx, self.w_[subjects[idx]].shape[0], X[idx].shape[0])) + + s = [None] * len(X) + for idx in range(len(X)): + if X[idx] is not None: + if centered: + s[idx] = self.w_[subjects[idx]].T.dot(X[idx]) + else: + s[idx] = self.w_[subjects[idx]].T.dot(X[idx]-self.mu_[subjects[idx]][:, None]) + + return s + + + def _init_structures(self, data, datasets): + """Initializes data structures for MDMS and preprocess the data. + + + Parameters + ---------- + data : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + 'datasets' must be defined in this case. + X[d] is a list of data of dataset d, where d is the name of the dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + 'datasets' can be omitted in this case. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of + subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + datasets : a Dataset object + The Dataset object containing datasets structures. + + + Returns + ------- + x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the subject + Demeaned data for each subject. + + mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name + of the subject. + The voxel means over the samples in all datasets for each subject. + + rho2 : dict of dict of float, rho2_[d][s] is a float, where d is the name + of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho_{di}^2` for each subject in each dataset. + + trace_xtx : dict of dict of float, trace_xtx[d][s] is a float, where + d is the name of the dataset and s is the name of the subject. + The squared Frobenius norm of the demeaned data in `x`. 
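+
+        Note
+        ----
+
+        The per-subject voxel mean is the sample-count-weighted average of
+        the per-dataset means, i.e. the mean over all of that subject's
+        samples concatenated across datasets. A minimal sketch:
+
+        >>> import numpy as np
+        >>> x1, x2 = np.ones((2, 10)), 3 * np.ones((2, 30))
+        >>> mu = (10 * x1.mean(1) + 30 * x2.mean(1)) / 40
+        >>> np.allclose(mu, np.hstack((x1, x2)).mean(1))
+        True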
+ """ + x = {} + mu = {} + rho2 = {} + trace_xtx = {} + + # re-arrange data to x + for ds_idx, ds in datasets.idx_to_dataset.items(): + x[ds] = {} + for subj in range(datasets.num_subj): + if datasets.dok_matrix[subj,ds_idx] != 0: + if datasets.built_from_data: + x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.idx_to_subject[subj]] + else: + x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.dok_matrix[subj,ds_idx]-1] + del data + + # compute mean + # collect mean from each MPI worker + weights = {} + mu_tmp = {} + for subj in datasets.subject_to_idx.keys(): + weights[subj], mu_tmp[subj] = {}, {} + for ds in x.keys(): + if subj in x[ds]: + if x[ds][subj] is not None: + mu_tmp[subj][ds] = np.mean(x[ds][subj], 1) + weights[subj][ds] = x[ds][subj].shape[1] + else: + mu_tmp[subj][ds] = np.zeros((self.voxels_[subj],)) + weights[subj][ds] = 0 + # collect mean from all MPI workers + for subj in datasets.subject_to_idx.keys(): + for ds in mu_tmp[subj].keys(): + mu_tmp[subj][ds] = self.comm.allreduce(mu_tmp[subj][ds], op=MPI.SUM) + weights[subj][ds] = self.comm.allreduce(weights[subj][ds], op=MPI.SUM) + # compute final mean + for subj in datasets.subject_to_idx.keys(): + mu[subj] = np.zeros((self.voxels_[subj],)) + nsample = np.sum(list(weights[subj].values())) + for ds in mu_tmp[subj].keys(): + mu[subj] += weights[subj][ds] * mu_tmp[subj][ds] / nsample + del weights, mu_tmp + + # subtract mean from x and compute trace_xtx, initialize rho2 + for ds in x.keys(): + rho2[ds], trace_xtx[ds] = {}, {} + for subj in x[ds].keys(): + rho2[ds][subj] = 1 + if x[ds][subj] is not None: + x[ds][subj] -= mu[subj][:,None] + trace_xtx[ds][subj] = np.sum(x[ds][subj] ** 2) + else: + trace_xtx[ds][subj] = 0 + + return x, mu, rho2, trace_xtx + + + def _likelihood(self, chol_sigma_s_rhos, log_det_psi, chol_sigma_s, + trace_xt_invsigma2_x, inv_sigma_s_rhos, wt_invpsi_x, + samples): + """Calculate the log-likelihood function of one dataset + + + Parameters + ---------- + + chol_sigma_s_rhos : array, shape=[features, features] + Cholesky factorization of the matrix (Sigma_S + sum_i(1/rho_i^2) + * I) + + log_det_psi : float + Determinant of diagonal matrix Psi (containing the rho_i^2 value + voxels_i times). + + chol_sigma_s : array, shape=[features, features] + Cholesky factorization of the matrix Sigma_S + + trace_xt_invsigma2_x : float + Trace of :math:`\\sum_i (||X_i||_F^2/\\rho_i^2)` + + inv_sigma_s_rhos : array, shape=[features, features] + Inverse of :math:`(\\Sigma_S + \\sum_i(1/\\rho_i^2) * I)` + + wt_invpsi_x : array, shape=[features, samples] + + samples : int + The total number of samples in the data. + + + Returns + ------- + + loglikehood : float + The log-likelihood value. + """ + log_det = (np.log(np.diag(chol_sigma_s_rhos) ** 2).sum() + log_det_psi + + np.log(np.diag(chol_sigma_s) ** 2).sum()) + loglikehood = -0.5 * samples * log_det - 0.5 * trace_xt_invsigma2_x + loglikehood += 0.5 * np.trace( + wt_invpsi_x.T.dot(inv_sigma_s_rhos).dot(wt_invpsi_x)) + + # + const --> -0.5*nTR*sum(voxel[subjects])*math.log(2*math.pi) + + return loglikehood + + @staticmethod + def _update_transform_subject(Xi, S): + """Updates the mappings `W_i` for one subject. + + Parameters + ---------- + + Xi : array, shape=[voxels, timepoints] + The fMRI data :math:`X_i` for aligning the subject. + + S : array, shape=[features, timepoints] + The shared response. + + Returns + ------- + + Wi : array, shape=[voxels, features] + The orthogonal transform (mapping) :math:`W_i` for the subject. 
+ """ + A = Xi.dot(S.T) + # Solve the Procrustes problem + U, _, V = np.linalg.svd(A, full_matrices=False) + return U.dot(V) + + + def transform_subject(self, X, dataset): + """Transform a new subject using the existing model. + The subject is assumed to have received equivalent stimulation + of some dataset in the fitted model. + + Parameters + ---------- + + X : 2D array, shape=[voxels, timepoints] + The fMRI data of the new subject. + + dataset : string, name of the dataset in the fitted model that + has the same stimulation as the new subject + + Returns + ------- + + w : 2D array, shape=[voxels, features] + Orthogonal mapping `W_{new}` for new subject + + """ + # Check if the model exist + if not hasattr(self, 'w_'): + raise NotFittedError("The model fit has not been run yet.") + + # Check if the dataset is in the model + if not dataset in self.s_: + raise NotFittedError("Dataset {} is not in the model yet.".format(dataset)) + + # Check the number of TRs in the subject + if X.shape[1] != self.s_[dataset].shape[1]: + raise ValueError("The number of timepoints(TRs) does not match the" + "one in the model.") + + w = self._update_transform_subject(X, self.s_[dataset]) + + return w + + + def _mdms(self, data, datasets): + """Expectation-Maximization algorithm for fitting the probabilistic MDMS. + + Parameters + ---------- + + data : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + 'datasets' must be defined in this case. + X[d] is a list of data of dataset d, where d is the name of the dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + 'datasets' can be omitted in this case. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of + subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + datasets : a Dataset object + The Dataset object containing datasets structures. + + + Returns + ------- + + sigma_s : dict of array, sigma_s[d] has shape=[features, features] where + d is the name of dataset. + The covariance :math:`\\Sigma_s` of the shared response Normal + distribution for each dataset. + + w : dict of array, w[s] has shape=[voxels_[s], features] where s is the name + of the subject. + The orthogonal transforms (mappings) :math:`W_s` for each subject. + + mu : dict of array, mu[s] has shape=[voxels_[s]] where s is the name + of the subject. + The voxel means :math:`\\mu_i` over the samples in all datasets + for each subject. + + rho2 : dict of dict of float, rho2[d][s] is a float, where d is the name + of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho2{di}^2` for each subject + in each dataset. + + s : dict of array, s[d] has shape=[features, samples_[d]] where d is the + name of the dataset. + The shared response for each dataset. 
+ """ + # get information from datasets structures + ds_list, subj_list = datasets.get_datasets_list(), datasets.get_subjects_list() + subj_ds_list = datasets.subjects_in_dataset_all() + ds_subj_list = datasets.datasets_with_subject_all() + + # initialize random states + self.random_state_ = np.random.RandomState(self.rand_seed) + random_states = { + subj_list[i] : np.random.RandomState(self.random_state_.randint(2 ** 32)) + for i in range(datasets.num_subj)} + + # assign ds to different ranks for parallel computing + rank = self.comm.Get_rank() + size = self.comm.Get_size() + + ds_rank = set() + if datasets.num_dataset <= size: + if rank < datasets.num_dataset: + ds_rank.add(ds_list[rank]) + else: + ds_rank_len = datasets.num_dataset // size + if rank != size - 1: + ds_rank.update(set(ds_list[ds_rank_len*rank:ds_rank_len*(rank+1)])) + else: + ds_rank.update(set(ds_list[ds_rank_len*rank:])) + + + # Initialization step: initialize the outputs with initial values + # and trace_xtx with the ||X_i||_F^2 of each subject in each dataset. + w = _init_w_transforms(self.voxels_, self.features, random_states, + datasets) + x, mu, rho2, trace_xtx = self._init_structures(data, datasets) + del data + # broadcast values in trace_xtx to all ranks + for subj in subj_list: + for ds in ds_subj_list[subj]: + trace_xtx[ds][subj] = self.comm.allreduce(trace_xtx[ds][subj], op=MPI.SUM) + + shared_response, sigma_s, rho0 = {}, {}, {} + for ds in ds_list: + shared_response[ds] = np.zeros((self.features, self.samples_[ds])) + if ds in ds_rank: + sigma_s[ds] = np.identity(self.features) + rho0[ds] = 0.0 + else: + sigma_s[ds] = np.zeros((self.features, self.features)) + + + # Main loop of the algorithm (run) + for iteration in range(self.n_iter): + if rank == 0: + self.logger.info('Iteration %d' % (iteration + 1)) + + # E-step and some M-step: update shared_response and sigma_s of each dataset + loglike = 0. 
+ + # for multi-thread computation + chol_sigma_s, chol_sigma_s_rhos, inv_sigma_s_rhos = {}, {}, {} + wt_invpsi_x, trace_xt_invsigma2_x, trace_sigma_s = {}, {}, {} + for ds in ds_list: + chol_sigma_s[ds] = np.zeros((self.features, self.features)) + chol_sigma_s_rhos[ds] = np.zeros((self.features, self.features)) + inv_sigma_s_rhos[ds] = np.zeros((self.features, self.features)) + wt_invpsi_x[ds] = np.zeros((self.features, self.samples_[ds])) + trace_xt_invsigma2_x[ds] = 0.0 + trace_sigma_s[ds] = 0 + + # iterate through all ds in this rank + for ds in ds_rank: + # Sum the inverted the rho2 elements for computing W^T * Psi^-1 * W + rho0[ds] = np.sum([1/v for v in rho2[ds].values()]) + + # Invert Sigma_s[ds] using Cholesky factorization + (chol_sigma_s[ds], lower_sigma_s) = scipy.linalg.cho_factor( + sigma_s[ds], check_finite=False) + inv_sigma_s = scipy.linalg.cho_solve( + (chol_sigma_s[ds], lower_sigma_s), np.identity(self.features), + check_finite=False) + + # Invert (Sigma_s[ds] + rho_0 * I) using Cholesky factorization + sigma_s_rhos = inv_sigma_s + np.identity(self.features) * rho0[ds] + (chol_sigma_s_rhos[ds], lower_sigma_s_rhos) = \ + scipy.linalg.cho_factor(sigma_s_rhos, check_finite=False) + inv_sigma_s_rhos[ds] = scipy.linalg.cho_solve( + (chol_sigma_s_rhos[ds], lower_sigma_s_rhos), + np.identity(self.features), check_finite=False) + + # collect info from all ranks + for ds in ds_list: + chol_sigma_s[ds] = self.comm.allreduce(chol_sigma_s[ds], op=MPI.SUM) + chol_sigma_s_rhos[ds] = self.comm.allreduce(chol_sigma_s_rhos[ds], op=MPI.SUM) + inv_sigma_s_rhos[ds] = self.comm.allreduce(inv_sigma_s_rhos[ds], op=MPI.SUM) + + # Compute the sum of W_i^T * rho_i^-2 * X_i, and the sum of traces + # of X_i^T * rho_i^-2 * X_i + for ds in ds_list: + for subj in subj_ds_list[ds]: + if x[ds][subj] is not None: + wt_invpsi_x[ds] += (w[subj].T.dot(x[ds][subj])) / rho2[ds][subj] + trace_xt_invsigma2_x[ds] += trace_xtx[ds][subj] / rho2[ds][subj] + + # collect data from all ranks + for ds in ds_list: + wt_invpsi_x[ds] = self.comm.allreduce(wt_invpsi_x[ds], op=MPI.SUM) + trace_xt_invsigma2_x[ds] = self.comm.allreduce(trace_xt_invsigma2_x[ds], + op=MPI.SUM) + + # compute shared response and Sigma_s of ds in this rank + for ds in ds_list: + if ds in ds_rank: + log_det_psi = np.sum([np.log(rho2[ds][subj])*self.voxels_[subj] for subj in rho2[ds]]) + + # Update the shared response + shared_response[ds] = sigma_s[ds].dot( + np.identity(self.features) - rho0[ds] * inv_sigma_s_rhos[ds]).dot( + wt_invpsi_x[ds]) + + # Update Sigma_s and compute its trace + sigma_s[ds] = (inv_sigma_s_rhos[ds] + + shared_response[ds].dot(shared_response[ds].T) / self.samples_[ds]) + trace_sigma_s[ds] = self.samples_[ds] * np.trace(sigma_s[ds]) + + # calculate log likelihood to check convergence + loglike += self._likelihood( + chol_sigma_s_rhos[ds], log_det_psi, chol_sigma_s[ds], + trace_xt_invsigma2_x[ds], inv_sigma_s_rhos[ds], wt_invpsi_x[ds], + self.samples_[ds]) + + else: + shared_response[ds] = np.zeros((self.features, self.samples_[ds])) + sigma_s[ds] = np.zeros((self.features, self.features)) + trace_sigma_s[ds] = 0 + + # collect parameters from all ranks + for ds in ds_list: + shared_response[ds] = self.comm.allreduce(shared_response[ds], op=MPI.SUM) + trace_sigma_s[ds] = self.comm.allreduce(trace_sigma_s[ds], op=MPI.SUM) + sigma_s[ds] = self.comm.allreduce(sigma_s[ds], op=MPI.SUM) + + + # The rest of M-step: update w and rho2 + # Update each subject's mapping transform W_i and error variance + # rho_di^2 + for subj in 
subj_list: + # update w + a_subject = np.zeros((self.voxels_[subj], self.features)) + # use x data from all ranks + for ds in ds_subj_list[subj]: + if x[ds][subj] is not None: + a_subject += x[ds][subj].dot(shared_response[ds].T) + # collect a_subject from all ranks + a_subject = self.comm.allreduce(a_subject, op=MPI.SUM) + # compute w in one rank and broadcast + if rank == 0: + perturbation = np.zeros(a_subject.shape) + np.fill_diagonal(perturbation, 0.0001) + u_subject, _, v_subject = np.linalg.svd( + a_subject + perturbation, full_matrices=False) + w[subj] = u_subject.dot(v_subject) + else: + w[subj] = None + w[subj] = self.comm.bcast(w[subj], root=0) + # update rho2 + # compute trace_xtws_tmp of data in this rank + trace_xtws_tmp = {} + for ds in ds_subj_list[subj]: + if x[ds][subj] is not None: + trace_xtws_tmp[ds] = np.trace(x[ds][subj].T.dot(w[subj]).dot(shared_response[ds])) + else: + trace_xtws_tmp[ds] = 0.0 + # collect trace_xtws_tmp in all ranks + for ds in ds_subj_list[subj]: + trace_xtws_tmp[ds] = self.comm.allreduce(trace_xtws_tmp[ds], op=MPI.SUM) + # compute rho2 + if rank == 0: + for ds in ds_subj_list[subj]: + rho2[ds][subj] = trace_xtx[ds][subj] + rho2[ds][subj] += -2 * trace_xtws_tmp[ds] + rho2[ds][subj] += trace_sigma_s[ds] + rho2[ds][subj] /= self.samples_[ds] * self.voxels_[subj] + # broadcast to all ranks + for ds in ds_subj_list[subj]: + rho2[ds][subj] = self.comm.bcast(rho2[ds][subj], root=0) + + + # collect loglikelihood + loglike = self.comm.allreduce(loglike, op=MPI.SUM) + if rank == 0: + if self.logger.isEnabledFor(logging.INFO): + self.logger.info('Objective function %f' % loglike) + + return sigma_s, w, mu, rho2, shared_response + + + def save(self, file): + """Save the MDMS object to a file (as pickle) + + Parameters + ---------- + + file : The name (including full path) of the file that the object + will be saved to. + + Returns + ------- + + None + + Note + ---- + + The MPI communicator cannot be saved, so it will not be saved. When + restored, self.comm will be initialized to MPI.COMM_SELF + + """ + # get attributes from object + variables = self.__dict__.keys() + data = {k:getattr(self, k) for k in variables} + # remove attributes that cannot be pickled + del data['comm'] + del data['logger'] + if 'random_state_' in data: + del data['random_state_'] + # save attributes to file + with open(file, 'wb') as f: + pkl.dump(data, f, pkl.HIGHEST_PROTOCOL) + print ('MDMS object saved to {}.'.format(file)) + return + + + def restore(self, file): + """Restore the MDMS object from a (pickle) file + + Parameters + ---------- + + file : The name (including full path) of the file that the object + will be restored from. + + Returns + ------- + + None + + Note + ---- + + The MPI communicator cannot be saved, so self.comm is initialized to + MPI.COMM_SELF + + """ + # get attributes from file + with open(file, 'rb') as f: + data = pkl.load(f) + # set attributes to object + for (k, v) in data.items(): + setattr(self, k, v) + # set attributes that were not pickled + self.comm = MPI.COMM_SELF + self.random_state_ = np.random.RandomState(self.rand_seed) + self.logger = logger + print ('MDMS object restored from {}.'.format(file)) + return + + +class DetMDMS(BaseEstimator, TransformerMixin): + """Deterministic multi-dataset multi-subject (MDMS) SRM analysis (DetMDMS) + + Given multi-dataset multi-subject data, factorize it as a shared response S among all + subjects per dataset and an orthogonal transform W across all datasets per subject: + + .. 
math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\dots M + + Parameters + ---------- + + n_iter : int, default: 10 + Number of iterations to run the algorithm. + + features : int, default: 50 + Number of features to compute. + + rand_seed : int, default: 0 + Seed for initializing the random number generator. + + comm : mpi4py.MPI.Intracomm + The MPI communicator containing the data + + Attributes + ---------- + + w_ : dict of array, w_[s] has shape=[voxels_[s], features], where + s is the name of the subject. + The orthogonal transforms (mappings) for each subject. + + s_ : dict of array, s_[d] has shape=[features, samples_[d]], where + d is the name of the dataset. + The shared response for each dataset. + + voxels_ : dict of int, voxels_[s] is number of voxels where s is the name + of the subject. + A dict with the number of voxels for each subject. + + samples_ : dict of int, samples_[d] is number of samples where d is the name + of the dataset. + A dict with the number of samples for each dataset. + + mu_ : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name + of the subject. + The voxel means over the samples in all datasets for each subject. + + random_state_: `RandomState` + Random number generator initialized using rand_seed + + comm : mpi4py.MPI.Intracomm + The MPI communicator containing the data + + Note + ---- + + The number of voxels may be different between subjects within a dataset + and number of samples may be different between datasets. However, the + number of samples must be the same across subjects within a dataset and + number of voxels must be the same across datasets for the same subject. + + The probabilistic multi-dataset multi-subject model is approximated using the + Block Coordinate Descent (BCD) algorithm proposed in [Zhang2018]_. + + The run-time complexity is :math:`O(I (V T K + V K^2))` and the memory + complexity is :math:`O(V T)` with I - the number of iterations, V - the + sum of number of voxels from all subjects, T - the sum of number of + samples from all datasets, K - the number of features (typically, + :math:`V \\gg T \\gg K`), and N - the number of subjects. + """ + + def __init__(self, n_iter=10, features=50, rand_seed=0, + comm=MPI.COMM_SELF): + self.n_iter = n_iter + self.features = features + self.rand_seed = rand_seed + self.comm = comm + self.logger = logger + return + + + def fit(self, X, datasets=None, demean=True, y=None): + """Compute the Deterministic Shared Response Model + + Parameters + ---------- + + X : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + 'datasets' must be defined in this case. + X[d] is a list of data of dataset d, where d is the name of the dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + 'datasets' can be omitted in this case. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of + subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + datasets : (optional) a Dataset object + The Dataset object containing datasets structure. + If not defined, the structure will be inferred from X. + + demean : (optional) If True, compute voxel means for each subject + and subtract from data. If False, voxel means are set to zero + and data values are not changed. 
+ + y : not used + """ + if self.comm.Get_rank() == 0: + self.logger.info('Starting Deterministic SRM') + + # Check if datasets is initialized + if datasets is not None and datasets.matrix is None: + raise NotFittedError('Dataset object is not initialized.') + + # Check X format + if type(X) != dict: + raise Exception('X should be a dict.') + format_X = type(next(iter(X.values()))) + if format_X != dict and format_X != list: + raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') + if format_X == list and (datasets is None or datasets.built_from_data is None or datasets.built_from_data): + raise Exception("Argument 'datasets' must be defined and built from json " + "files when X is a dict of list of 2D arrays. ") + if format_X == dict and datasets is not None: + datasets.built_from_data = True + for v in X.values(): + if type(v) != format_X: + raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') + + # Infer datasets structure from data + if datasets is None: + datasets = Dataset() + datasets.build_from_data(X) + + self.voxels_, self.samples_ = _sanity_check(X, datasets, self.comm) + + # Run MDMS + self.w_, self.s_, self.mu_ = self._mdms(X, datasets, demean) + + return self + + + def transform(self, X, subjects, centered=True, y=None): + """Use the model to transform new data to Shared Response space + + Parameters + ---------- + + X : list of 2D arrays, element i has shape=[voxels_i, samples_i] + Each element in the list contains the new fMRI data of one subject + + subjects : list of string, element i is the name of subject of X[i] + + centered : (optional) bool, if the data in X is already centered. + If centered = False, the voxel means computed during mode fitting + will be subtracted before transformation. + + y : not used (as it is unsupervised learning) + + + Returns + ------- + + s : list of 2D arrays, element i has shape=[features_i, samples_i] + Shared responses from input data (X) + """ + + # Check if X and subjects have the same length + if len(X) != len(subjects): + raise ValueError("X and subjects must have the same length.") + + # Check if the model exist + if not hasattr(self, 'w_'): + raise NotFittedError("The model fit has not been run yet.") + + # Check if the subject exist in the fitted model and has the right number of voxels + for idx in range(len(X)): + if not subjects[idx] in self.w_: + raise NotFittedError("The model has not been fitted to subject {}.".format(subjects[idx])) + if X[idx] is not None and self.w_[subjects[idx]].shape[0] != X[idx].shape[0]: + raise ValueError("{}-th element of data has inconsistent number of" + "voxels with fitted model. Model has {} voxels while data has {}" + ".".format(idx, self.w_[subjects[idx]].shape[0], X[idx].shape[0])) + + if not centered and self.mu_ is None: + raise Exception('Mean values are not computed during model fitting. ' + 'Please center the data to be transformed beforehand.') + + + s = [None] * len(X) + for idx in range(len(X)): + if X[idx] is not None: + if centered: + s[idx] = self.w_[subjects[idx]].T.dot(X[idx]) + else: + s[idx] = self.w_[subjects[idx]].T.dot(X[idx]-self.mu_[subjects[idx]][:, None]) + + return s + + + def _preprocess_data(self, data, datasets, demean): + """Preprocess and demean the data. + + + Parameters + ---------- + data : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + 'datasets' must be defined in this case. 
+ X[d] is a list of data of dataset d, where d is the name of the dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + 'datasets' can be omitted in this case. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of + subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + datasets : a Dataset object + The Dataset object containing datasets structures. + + demean : If True, compute voxel means for each subject + and subtract from data. If False, voxel means are set to zero + and data values are not changed. + + Returns + ------- + x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the subject + Demeaned data for each subject. + + mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name + of the subject. + The voxel means over the samples in all datasets for each subject. + """ + x = {} + mu = {} + + # re-arrange data to x + for ds_idx, ds in datasets.idx_to_dataset.items(): + x[ds] = {} + for subj in range(datasets.num_subj): + if datasets.dok_matrix[subj,ds_idx] != 0: + if datasets.built_from_data: + x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.idx_to_subject[subj]] + else: + x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.dok_matrix[subj,ds_idx]-1] + del data + + # compute mean + if demean: + # collect mean from each MPI worker + weights = {} + mu_tmp = {} + for subj in datasets.subject_to_idx.keys(): + weights[subj], mu_tmp[subj] = {}, {} + for ds in x.keys(): + if subj in x[ds]: + if x[ds][subj] is not None: + mu_tmp[subj][ds] = np.mean(x[ds][subj], 1) + weights[subj][ds] = x[ds][subj].shape[1] + else: + mu_tmp[subj][ds] = np.zeros((self.voxels_[subj],)) + weights[subj][ds] = 0 + # collect mean from all MPI workers + for subj in datasets.subject_to_idx.keys(): + for ds in mu_tmp[subj].keys(): + mu_tmp[subj][ds] = self.comm.allreduce(mu_tmp[subj][ds], op=MPI.SUM) + weights[subj][ds] = self.comm.allreduce(weights[subj][ds], op=MPI.SUM) + # compute final mean + for subj in datasets.subject_to_idx.keys(): + mu[subj] = np.zeros((self.voxels_[subj],)) + nsample = np.sum(list(weights[subj].values())) + for ds in mu_tmp[subj].keys(): + mu[subj] += weights[subj][ds] * mu_tmp[subj][ds] / nsample + del weights, mu_tmp + + # subtract mean from x + for ds in x.keys(): + for subj in x[ds].keys(): + if x[ds][subj] is not None: + x[ds][subj] -= mu[subj][:,None] + + else: + mu = None + + return x, mu + + + def _objective_function(self, data, subj_ds_list, w, s, num_sample): + """Calculate the objective function + + Parameters + ---------- + + data : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the subject + Demeaned data for each subject. + + subj_ds_list : dict of list of string, subj_ds_list[d] is a list + of names of subjects in dataset d, where d is the name + of the subject. + + w : dict of array, w[s] has shape=[voxels_[s], features], where + s is the name of the subject. + The orthogonal transforms (mappings) for each subject. + + s : dict of array, s[d] has shape=[features, samples_[d]], where + d is the name of the dataset. + The shared response for each dataset. + + num_sample : int, total number of samples across all datasets and datasets + + Returns + ------- + + objective : float + The objective function value. 
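+            This equals
+            :math:`\\frac{1}{2T} \\sum_d \\sum_{i \\in d} ||X_{di} - W_i S_d||_F^2`,
+            where :math:`T` is ``num_sample``.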
+ + Note + ---- + + In the multi nodes mode where data is scattered in different nodes, + objective needs to be reduced (summed) afterwards. + """ + objective = 0.0 + for ds in subj_ds_list.keys(): + for subj in subj_ds_list[ds]: + if data[ds][subj] is not None: + objective += \ + np.linalg.norm(data[ds][subj] - w[subj].dot(s[ds]), 'fro') ** 2 + + return 0.5 * objective / num_sample + + + def _compute_shared_response(self, data, subj_ds_list, w): + """ Compute the shared response S of all datasets + + Parameters + ---------- + + data : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the subject + Demeaned data for each subject. + + subj_ds_list : dict of list of string, subj_ds_list[d] is a list + of names of subjects in dataset d, where d is the name + of the subject. + + w : dict of array, w[s] has shape=[voxels_[s], features] where + s is the name of the subject. + The orthogonal transforms (mappings) for each subject. + + Returns + ------- + + s : dict of array, s[d] has shape=[features, samples_[d]] where + d is the name of the dataset. + The shared response for each dataset. + + Note + ---- + + In the multi nodes mode where data is scattered in different nodes, + s needs to be gathered afterwards. + + To get the final s, the returned s[d] needs to be devided by number + of subjects in dataset d. + """ + s = {} + for ds in subj_ds_list.keys(): + s[ds] = np.zeros((self.features, self.samples_[ds])) + for subj in subj_ds_list[ds]: + if data[ds][subj] is not None: + s[ds] += w[subj].T.dot(data[ds][subj]) + return s + + + @staticmethod + def _update_transform_subject(Xi, S): + """Updates the mappings `W_i` for one subject. + + Parameters + ---------- + + Xi : array, shape=[voxels, timepoints] + The fMRI data :math:`X_i` for aligning the subject. + + S : array, shape=[features, timepoints] + The shared response. + + Returns + ------- + + Wi : array, shape=[voxels, features] + The orthogonal transform (mapping) :math:`W_i` for the subject. + """ + A = Xi.dot(S.T) + # Solve the Procrustes problem + U, _, V = np.linalg.svd(A, full_matrices=False) + return U.dot(V) + + + def transform_subject(self, X, dataset): + """Transform a new subject using the existing model. + The subject is assumed to have received equivalent stimulation + of some dataset in the fitted model. + + Parameters + ---------- + + X : 2D array, shape=[voxels, timepoints] + The fMRI data of the new subject. + + dataset : string, name of the dataset in the fitted model that + has the same stimulation as the new subject + + Returns + ------- + + w : 2D array, shape=[voxels, features] + Orthogonal mapping `W_{new}` for new subject + + """ + # Check if the model exist + if not hasattr(self, 'w_'): + raise NotFittedError("The model fit has not been run yet.") + + # Check if the dataset is in the model + if not dataset in self.s_: + raise NotFittedError("Dataset {} is not in the model yet.".format(dataset)) + + # Check the number of TRs in the subject + if X.shape[1] != self.s_[dataset].shape[1]: + raise ValueError("The number of timepoints(TRs) does not match the" + "one in the model.") + + w = self._update_transform_subject(X, self.s_[dataset]) + + return w + + + def _mdms(self, data, datasets, demean): + """Block Coordinate Descent algorithm for fitting the deterministic MDMS. 
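+
+        Alternates between the Procrustes update
+        :math:`W_s \\leftarrow U V^T` with
+        :math:`U \\Sigma V^T = \\sum_d X_{ds} S_d^T`, and the shared-response
+        update :math:`S_d \\leftarrow \\frac{1}{N_d} \\sum_{s \\in d} W_s^T X_{ds}`,
+        where :math:`N_d` is the number of subjects in dataset :math:`d`.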
+ + Parameters + ---------- + + data : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + data[d] is a list of data of dataset d, where d is the name of the dataset. + Element i in the list has shape=[voxels_i, samples_[d]] + which is the fMRI data of the i'th subject in dataset d. + 2) When it is a dict of dict of 2D arrays: + data[d][s] has shape=[voxels_[s], samples_[d]], which is the fMRI data of + subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + datasets : a Dataset object + The Dataset object containing datasets structure. + + demean : If True, compute voxel means for each subject + and subtract from data. If False, voxel means are set to zero + and data values are not changed. + + Returns + ------- + + w : dict of array, w[s] has shape=[voxels_[s], features], where + s is the name of the subject. + The orthogonal transforms (mappings) for each subject. + + s : dict of array, s[d] has shape=[features, samples_[d]], where + d is the name of the dataset. + The shared response for each dataset. + """ + + # get information from datasets structure + ds_list, subj_list = datasets.get_datasets_list(), datasets.get_subjects_list() + subj_ds_list = datasets.subjects_in_dataset_all() + ds_subj_list = datasets.datasets_with_subject_all() + num_sample = np.sum([datasets.num_subj_dataset[ds]*self.samples_[ds] for ds in ds_list]) + + # initialize random states + self.random_state_ = np.random.RandomState(self.rand_seed) + random_states = { + subj_list[i] : np.random.RandomState(self.random_state_.randint(2 ** 32)) + for i in range(datasets.num_subj)} + + rank = self.comm.Get_rank() + size = self.comm.Get_size() + + # Initialization step: + # 1) preprocess data + # 2) initialize the outputs with initial values + + w = _init_w_transforms(self.voxels_, self.features, random_states, + datasets) + x, mu = self._preprocess_data(data, datasets, demean) + del data + # compute shared_response from data in this rank + shared_response = self._compute_shared_response(x, subj_ds_list, w) + # collect shared_response data from all ranks + for ds in ds_list: + shared_response[ds] = self.comm.allreduce(shared_response[ds], op=MPI.SUM) + shared_response[ds] /= datasets.num_subj_dataset[ds] + + if self.logger.isEnabledFor(logging.INFO): + # Calculate the current objective function value + objective = self._objective_function(x, subj_ds_list, w, shared_response, num_sample) + objective = self.comm.allreduce(objective, op=MPI.SUM) + if rank == 0: + self.logger.info('Objective function %f' % objective) + + # Main loop of the algorithm + for iteration in range(self.n_iter): + if rank == 0: + self.logger.info('Iteration %d' % (iteration + 1)) + + # Update each subject's mapping transform W_s: + for subj in subj_list: + a_subject = np.zeros((self.voxels_[subj], self.features)) + # use x data from all ranks + for ds in ds_subj_list[subj]: + if x[ds][subj] is not None: + a_subject += x[ds][subj].dot(shared_response[ds].T) + # collect a_subject from all ranks + a_subject = self.comm.allreduce(a_subject, op=MPI.SUM) + # compute w in one rank and broadcast + if rank == 0: + perturbation = np.zeros(a_subject.shape) + np.fill_diagonal(perturbation, 0.0001) + u_subject, _, v_subject = np.linalg.svd( + a_subject + perturbation, full_matrices=False) + w[subj] = u_subject.dot(v_subject) + else: + w[subj] = None + w[subj] = self.comm.bcast(w[subj], root=0) + + # Update the each dataset's shared response S_d: + # compute 
shared_response from data in this rank + shared_response = self._compute_shared_response(x, subj_ds_list, w) + # collect shared_response data from all ranks + for ds in ds_list: + shared_response[ds] = self.comm.allreduce(shared_response[ds], op=MPI.SUM) + shared_response[ds] /= datasets.num_subj_dataset[ds] + + if self.logger.isEnabledFor(logging.INFO): + # Calculate the current objective function value + objective = self._objective_function(x, subj_ds_list, w, shared_response, num_sample) + objective = self.comm.allreduce(objective, op=MPI.SUM) + if rank == 0: + self.logger.info('Objective function %f' % objective) + + return w, shared_response, mu + + + def save(self, file): + """Save the DetMDMS object to a file (as pickle) + + Parameters + ---------- + + file : The name (including full path) of the file that the object + will be saved to. + + Returns + ------- + + None + + Note + ---- + + The MPI communicator cannot be saved, so it will not be saved. When + restored, self.comm will be initialized to MPI.COMM_SELF + + """ + # get attributes from object + variables = self.__dict__.keys() + data = {k:getattr(self, k) for k in variables} + # remove attributes that cannot be pickled + del data['comm'] + del data['logger'] + if 'random_state_' in data: + del data['random_state_'] + # save attributes to file + with open(file, 'wb') as f: + pkl.dump(data, f, pkl.HIGHEST_PROTOCOL) + print ('DetMDMS object saved to {}.'.format(file)) + return + + + def restore(self, file): + """Restore the DetMDMS object from a (pickle) file + + Parameters + ---------- + + file : The name (including full path) of the file that the object + will be restored from. + + Returns + ------- + + None + + Note + ---- + + The MPI communicator cannot be saved, so self.comm is initialized to + MPI.COMM_SELF + + """ + # get attributes from file + with open(file, 'rb') as f: + data = pkl.load(f) + # set attributes to object + for (k, v) in data.items(): + setattr(self, k, v) + # set attributes that were not pickled + self.comm = MPI.COMM_SELF + self.random_state_ = np.random.RandomState(self.rand_seed) + self.logger = logger + print ('DetMDMS object restored from {}.'.format(file)) + return + + +class Dataset(object): + """Datasets structure organizer + + Given multi-dataset multi-subject data or JSON files with subject names + in each dataset, infer datasets structure in different formats, such as + a graph where each dataset is a node and each edge is number of shared + subjects between the two datasets. + + This organizer is used in the MDMS or DetMDMS [Zhang2018]_ and can also + be used as a standalone datasets organizer. + + + Parameters + ---------- + + file : (optional) string, default: None + JSON file name (including full path) or folder name with JSON files. + + Each JSON file should contain a dict or a list of dict where each dict + has information of one dataset. Each dict must have 'dataset', + 'num_of_subj', and 'subjects' where 'dataset' is the name of the dataset, + 'num_of_subj' is the number of subjects in the dataset, and 'subjects' + is a list of strings with names of subjects in the dataset in the same + order as in the dataset. All datasets in all JSON files will be added + to the organizer. + + Example of a JSON file: + [{'dataset':'MyData','num_of_subj':3,'subjects':['Adam','Bob','Carol']}, + {'dataset':'MyData2','num_of_subj':2,'subjects':['Tom','Bob']}] + + + data : (optional) dict of dict of 2D array, default: None + Multi-dataset multi-subject data used to build the organizer. 
+ + data[d][s] has shape=[voxels[s], samples[d]], where d is the name of + the dataset and s is the name of the subject. + + + Attributes + ---------- + + num_subj : int, + Total number of subjects + + num_dataset : int, + Total number of datasets + + dataset_to_idx : dict of int, dataset_to_idx[d] is the column index of dataset d + in self.matrix, where d is the name of the dataset. + Dataset name to column index of matrix, 0-indexed + + idx_to_dataset : dict of string, idx_to_dataset[i] is name of the dataset mapped + to the i'th column in self.matrix. + Column index of metrix to dataset name, 0-indexed + + subject_to_idx : dict of int, subject_to_idx[s] is the row index of subject s + in self.matrix, where s is the name of the subject. + Subject name to row index of matrix, 0-indexed + + idx_to_subject : dict of string, idx_to_subject[i] is name of the subject mapped + to the i'th row in self.matrix. + Row index to subject name, 0-indexed + + connected : list of list of string, each element is a list of name of connected + datasets (datasets can be connected through shared subjects). + + num_graph : int, + Number of connected dataset graphs + If 1, then all datasets are connected. + + adj_matrix : 2D csc sparse matrix of shape [num_dataset, num_dataset], + Weighted adjacency matrix of all datasets, where each node is a dataset and + weights on edges are number of shared subjects between the two datasets. + Mapping between dataset name and dataset index is in self.dataset_to_idx. + + num_subj_dataset : dict of int, num_subj_dataset[d] is an int where d is the name + of a dataset. + Number of subjects of each dataset + + subj_in_dataset : dict of list of string, subj_in_dataset[d] is a list of name + of subjects in dataset d in the same order as in d, where d is the name + of a dataset. If any subject is removed from the organizer, the name will + be replaced with None as a placeholder. + Name of subjects in each dataset + + matrix : 2D coo sparse matrix of shape [num_subj, num_dataset], + Dataset-subject membership matrix. + If built from JSON files, subject self.idx_to_subject[i] is the + self.matrix[i,j]'th subject in self.idx_to_dataset[j], 1-indexed + If built from multi-dataset multi-subject data, self.matrix[i,j] = 1 if + subject self.idx_to_subject[i] is in dataset self.idx_to_dataset[j]. + + dok_matrix : 2D dok sparse matrix of shape [num_subj, num_dataset], + Dataset-subject membership matrix. + It has the same content as self.matrix, but in Dictionary Of Keys format + for fast access of individual access. + + built_from_data : bool, + If the object is built from multi-dataset multi-subject data + If True, the object is built from data; if False, it is built from + JSON files. + + + Note + ---- + + Example usage can be found in BrainIAK MDMS example jupyter notebook. 
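+
+    Example
+    -------
+
+    A minimal illustrative sketch (hypothetical voxel/TR sizes); the same
+    structure can also be loaded from JSON files via the `file` argument:
+
+    >>> import numpy as np
+    >>> from brainiak.funcalign.mdms import Dataset
+    >>> data = {'MyData': {'Adam': np.zeros((500, 100)),
+    ...                    'Bob': np.zeros((600, 100))},
+    ...         'MyData2': {'Bob': np.zeros((600, 80))}}
+    >>> ds = Dataset(data=data)
+    >>> ds.num_dataset, ds.num_subj
+    (2, 2)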
+ + """ + + def __init__(self, file=None, data=None): + self.num_subj = 0 + self.num_dataset = 0 + self.dataset_to_idx = {} + self.idx_to_dataset = {} + self.subject_to_idx = {} + self.idx_to_subject = {} + self.connected = [] + self.num_graph = 0 + self.adj_matrix = None + self.num_subj_dataset = {} + self.subj_in_dataset = {} + self.matrix = None + self.dok_matrix = None + self.built_from_data = None + + if file is not None and data is not None: + raise Exception('Dataset object can only be built from data OR json files.') + + if file is not None: + self.add(file) + + if data is not None: + self.build_from_data(data) + return + + + def add(self, file): + """Add JSON file(s) to the organizer + + Parameters + ---------- + + file : string, default: None + JSON file name (including full path) or folder name with JSON files. + + Each JSON file should contain a dict or a list of dict where each dict + has information of one dataset. Each dict must have 'dataset', + 'num_of_subj', and 'subjects' where 'dataset' is the name of the dataset, + 'num_of_subj' is the number of subjects in the dataset, and 'subjects' + is a list of strings with names of subjects in the dataset in the same + order as in the dataset. All datasets in all JSON files will be added + to the organizer. If some datasets are already in the organizer, the + information of those datasets will be replaced with this new version. + + Example of a JSON file: + [{'dataset':'MyData','num_of_subj':3,'subjects':['Adam','Bob','Carol']}, + {'dataset':'MyData2','num_of_subj':2,'subjects':['Tom','Bob']}] + + Returns + ------- + + None + """ + # sanity check + if self.built_from_data is not None and self.built_from_data: + raise Exception('This Dataset object was already initialized with fMRI datasets.') + + # file can be json file name or folder name + # parse json filenames + if os.path.isfile(file): + # file + files = [file] + elif os.path.isdir(file): + # path + files = glob.glob(os.path.join(file, '*.json')) + if not files: + raise Exception('The path must contain JSON files.') + else: + raise Exception('Argument must be a filename or a path.') + + mem = [] # collect info of all datasets + for f in files: + tmp = json.load(open(f,'r')) + if type(tmp) == list: + # multiple datasets + mem.extend(tmp) + elif type(tmp) == dict: + # one dataset + mem.append(tmp) + else: + raise Exception('JSON file must be in list or dict format.') + + # separate datasets into new datasets and datasets to update + new_ds, new_sub, replace_ds, ds_dict = set(), set(), set(), {} + for m in mem: + # sanity check + if m['num_of_subj'] <= 0: + raise Exception('Number of subjects in dataset ' + m['dataset'] + ' must be positive.') + if m['num_of_subj'] != len(m['subjects']): + raise Exception('Number of subjects in dataset ' + m['dataset'] + ' does not agree.') + if m['dataset'] in new_ds or m['dataset'] in replace_ds: + raise Exception('Dataset ' + m['dataset'] + ' appears more than once.') + if len(m['subjects']) != len(set(m['subjects'])): + raise Exception('Dataset ' + m['dataset'] + ' has duplicate subjects.') + + # if the dataset is already in the matrix + if m['dataset'] in self.dataset_to_idx: + replace_ds.add(m['dataset']) + else: + new_ds.add(m['dataset']) + + # save subjects info into a dict + ds_dict[m['dataset']] = m['subjects'] + + # add new subjects in this dataset + for subj in m['subjects']: + if not subj in self.subject_to_idx: + new_sub.add(subj) + + # add number of subjects info if mem passes all the sanity check + for m in mem: + 
self.num_subj_dataset[m['dataset']] = m['num_of_subj'] + + del mem + + # construct or update the matrix + if self.matrix is None: + # construct a new matrix + self._construct_matrix(new_ds, new_sub, ds_dict) + else: + # add new datasets + self._add_new_dataset(new_ds, new_sub, ds_dict) + if replace_ds: + # replace some old datasets + self._replace_dataset(replace_ds, ds_dict) + self._compute_connected() + + self.built_from_data = False + + return + + + def build_from_data(self, data): + """Use multi-dataset multi-subject data to initialize the organizer + + Parameters + ---------- + + data : dict of dict of 2D array + Multi-dataset multi-subject data used to build the organizer. + data[d][s] has shape=[voxels[s], samples[d]], where d is the name of + the dataset and s is the name of the subject. + + Returns + ------- + + None + """ + # sanity check + if self.built_from_data is not None and not self.built_from_data: + raise Exception('This Dataset object was already initialized with JSON files.') + + # find out which datasets and subjects are in the data + if not type(data) == dict: + raise Exception('To build Dataset object from data, data must be a dict of dict ' + 'where data[d][s] is the fMRI data of dataset d and subject s.') + datasets = set(data.keys()) + subjects = set() + for ds in data: + if not type(data[ds]) == dict: + raise Exception('To build Dataset object from data, data must be a dict of dict ' + 'where data[d][s] is the fMRI data of dataset d and subject s.') + subjects.update(set(data[ds].keys())) + + # set attributes + self.num_dataset = len(datasets) + self.num_subj = len(subjects) + + for idx, subj in enumerate(subjects): + self.subject_to_idx[subj] = idx + self.idx_to_subject[idx] = subj + for idx, ds in enumerate(datasets): + self.dataset_to_idx[ds] = idx + self.idx_to_dataset[idx] = ds + + for ds in datasets: + self.num_subj_dataset[ds] = len(data[ds]) + self.subj_in_dataset[ds] = list(data[ds].keys()) + + # fill in sparse matrix + coo_data, row, col = [], [], [] + for ds in datasets: + col_idx = self.dataset_to_idx[ds] + for subj in data[ds].keys(): + coo_data.append(1) + col.append(col_idx) + row.append(self.subject_to_idx[subj]) + self.matrix = sp.coo_matrix((coo_data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.dok_matrix = self.matrix.todok(copy=True) + # compute connectivity + self._compute_connected() + + self.built_from_data = True + return + + + def remove_dataset(self, datasets): + """Remove some datasets from the organizer + + Parameters + ---------- + + datasets : set or list of string, each element is name of a dataset + Name of datasets to be removed + + Returns + ------- + + removed_subjects : list of string, each element is name of a subject + Name of subjects removed because of the removal of datasets. 
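+
+        Note
+        ----
+
+        As a hypothetical illustration, with the two datasets from the
+        JSON example in add() ('MyData' with Adam, Bob and Carol;
+        'MyData2' with Tom and Bob), removing 'MyData2' also drops Tom,
+        since he appears in no other dataset. Assuming ds is a Dataset
+        object built from that JSON file:
+
+        >>> ds.remove_dataset(['MyData2'])
+        ['Tom']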
+ """ + # sanity check + for ds in datasets: + if not ds in self.dataset_to_idx: + raise Exception('Dataset '+ ds + ' does not exist.') + + # extract data from the sparse matrix + data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() + + # remove datasets from data + data, row, col, subj_to_check = self._remove_datasets_from_data(datasets, data, row, col) + + # if all datasets are removed + if not data: + removed_subjects = list(self.subject_to_idx.keys()) + self.reset() + return removed_subjects + + # find subjects not in any dataset + removed_subjects = [] + for subj in subj_to_check: + if not self.subject_to_idx[subj] in row: + removed_subjects.append(subj) + + # re-arrange subject indices + row = self._remove_subjects_by_re_indexing(removed_subjects, row) + + # re-arrange dataset indices + col = self._remove_datasets_by_re_indexing(datasets, col) + + # re-construct the matrix + self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.dok_matrix = self.matrix.todok(copy=True) + + # compute connectivity + self._compute_connected() + + return removed_subjects + + + def remove_subject(self, subjects): + """Remove some subjects from the organizer + + Parameters + ---------- + + subjects : set or list of string, each element is name of a subject + Name of subjects to be removed + + Returns + ------- + + removed_datasets : list of string, each element is name of a dataset + Name of datasets removed because of the removal of subjects. + """ + # sanity check + for subj in subjects: + if not subj in self.subject_to_idx: + raise Exception('Subject ' + subj + ' does not exist.') + + # extract data from the sparse matrix + data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() + + # remove subjects from data + data, row, col = self._remove_subjects_from_data(subjects, data, row, col) + + # if all subjects are removed + if not data: + removed_datasets = list(self.dataset_to_idx.keys()) + self.reset() + return removed_datasets + + # find datasets without any subject + removed_datasets = [] + for (k,v) in self.num_subj_dataset.items(): + if not v: + removed_datasets.append(k) + for k in removed_datasets: + del self.num_subj_dataset[k] # remove from num_subj_dataset + del self.subj_in_dataset[k] + + # re-arrange subject indices + row = self._remove_subjects_by_re_indexing(subjects, row) + + # re-arrange dataset indices + col = self._remove_datasets_by_re_indexing(removed_datasets, col) + + # re-construct the matrix + self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.dok_matrix = self.matrix.todok(copy=True) + + # compute connectivity + self._compute_connected() + + return removed_datasets + + + def num_shared_subjects_between_datasets(self, ds1, ds2): + """Get number of shared subjects (subjects in both ds1 and ds2) + between two datasets (ds1 and ds2) + + Parameters + ---------- + + ds1, ds2 : string, + Name of two datasets + + Returns + ------- + + num_shared : int, + Number of shared subjects between ds1 and ds2 + """ + # sanity check + for ds in [ds1, ds2]: + if not ds in self.dataset_to_idx: + raise Exception('Dataset ' + ds + 'does not exist.') + # find number of shared subjects + idx1, idx2 = self.dataset_to_idx[ds1], self.dataset_to_idx[ds2] + return self.adj_matrix[idx1, idx2] + + + def shared_subjects_between_datasets(self, ds1, ds2): + """Get name of shared subjects (subjects in both ds1 and ds2) + between two datasets (ds1 and 
ds2) + + Parameters + ---------- + + ds1, ds2 : string, + Name of two datasets + + Returns + ------- + + shared : list of string, + Name of subjects shared between ds1 and ds2 + """ + # sanity check + for ds in [ds1, ds2]: + if not ds in self.dataset_to_idx: + raise Exception('Dataset ' + ds + 'does not exist.') + if self.matrix is None: + raise Exception('Dataset object not initialized.') + # find shared subjects + matrix_csc = self.matrix.tocsc(copy=True) + subj1 = set(matrix_csc[:,self.dataset_to_idx[ds1]].indices) # indices of subjects in ds1 + subj2 = set(matrix_csc[:,self.dataset_to_idx[ds2]].indices) # indices of subjects in ds2 + return [self.idx_to_subject[subj] for subj in subj1.intersection(subj2)] + + + def datasets_with_subject(self, subj): + """Get all datasets with some subject ('subj') + + Parameters + ---------- + + subj : string, + Name of the subject + + Returns + ------- + + datasets : list of string, + Name of datasets with subject 'subj' + """ + # sanity check + if not subj in self.subject_to_idx: + raise Exception('Subject ' + subj + 'does not exist.') + if self.matrix is None: + raise Exception('Dataset object not initialized.') + # find datasets with subject + matrix_csr = self.matrix.tocsr(copy=True) + indices = matrix_csr[self.subject_to_idx[subj],:].indices + return [self.idx_to_dataset[ds] for ds in indices] + + + def datasets_with_subject_all(self): + """For each subject, get a list of datasets with that subject + + Parameters + ---------- + + None + + Returns + ------- + + ds_subj_list : dict of list of string, ds_subj_list[s] is a list where s + is the name of a subject + List of datasets with subject s for each subject s + """ + if self.matrix is None: + raise Exception('Dataset object not initialized.') + ds_subj_list = {} + matrix_csr = self.matrix.tocsr(copy=True) + for subj in range(self.num_subj): + subj_name = self.idx_to_subject[subj] + indices = matrix_csr[subj,:].indices + ds_subj_list[subj_name] = [self.idx_to_dataset[ds] for ds in indices] + return ds_subj_list + + + def subjects_in_dataset(self, dataset): + """Get all subjects in some dataset ('dataset') + + Parameters + ---------- + + dataset : string, + Name of the dataset + + Returns + ------- + + subjects : list of string, + Name of subjects in dataset 'dataset' + """ + #sanity check + if not dataset in self.dataset_to_idx: + raise Exception('Dataset ' + dataset + 'does not exist.') + if self.matrix is None: + raise Exception('Dataset object not initialized.') + # find subjects in dataset + matrix_csc = self.matrix.tocsc(copy=True) + indices = matrix_csc[:,self.dataset_to_idx[dataset]].indices + return [self.idx_to_subject[subj] for subj in indices] + + + def subjects_in_dataset_all(self): + """For each dataset, get a list of subjects in that dataset + + Parameters + ---------- + + None + + Returns + ------- + + subj_ds_list : dict of list of string, subj_ds_list[d] is a list where d + is the name of a dataset + List of subjects in dataset d for each dataset d + """ + if self.matrix is None: + raise Exception('Dataset object not initialized.') + subj_ds_list = {} + matrix_csc = self.matrix.tocsc(copy=True) + for ds in range(self.num_dataset): + ds_name = self.idx_to_dataset[ds] + indices = matrix_csc[:,ds].indices + subj_ds_list[ds_name] = [self.idx_to_subject[subj] for subj in indices] + return subj_ds_list + + + def get_subjects_list(self): + """Get a list of all subjects in the organizer + + Parameters + ---------- + + None + + Returns + ------- + + subj_list : list of string, + Name of 
all subjects in the organizer + """ + return list(self.subject_to_idx.keys()) + + + def get_datasets_list(self): + """Get a list of all datasets in the organizer + + Parameters + ---------- + + None + + Returns + ------- + + ds_list : list of string, + Name of all datasets in the organizer + """ + return list(self.dataset_to_idx.keys()) + + + def visualize_graph(self, font_size=14): + """Visualize the organizer as a graph where each node is a dataset + and the edge is number of shared subjects between the two datasets + + Parameters + ---------- + + font_size : (optional) float, default = 14 + Font size of labels in the graph + + Returns + ------- + + None + """ + if self.adj_matrix is None: + raise Exception('Dataset object not initialized.') + # build graph from adjacency matrix + G = nx.from_numpy_matrix(self.adj_matrix.toarray()) + # assign edge labels + edge_labels=dict([((u,v,),self.adj_matrix[u,v]) for u,v in G.edges]) + pos=nx.spring_layout(G) + nx.draw(G, pos=pos,with_labels=False) + labels=nx.draw_networkx_labels(G,labels = self.idx_to_dataset, pos=pos, font_size=font_size) + edge_labels=nx.draw_networkx_edge_labels(G,edge_labels=edge_labels, pos=pos, font_size=font_size) + plt.show() + return + + + def reset(self): + """Reset all attributes in the organizer + + Parameters + ---------- + + None + + Returns + ------- + + None + """ + self.num_subj = 0 + self.num_dataset = 0 + self.dataset_to_idx = {} + self.idx_to_dataset = {} + self.subject_to_idx = {} + self.idx_to_subject = {} + self.connected = [] + self.num_graph = 0 + self.adj_matrix = None + self.num_subj_dataset = {} + self.subj_in_dataset = {} + self.matrix = None + self.adj_matrix = None + self.built_from_data = None + return + + + def save(self, file): + """Save the Dataset object to a file (as pickle) + + Parameters + ---------- + + file : The name (including full path) of the file that the object + will be saved to. + + Returns + ------- + + None + """ + # get attributes from object + variables = self.__dict__.keys() + data = {k:getattr(self, k) for k in variables} + # save attributes to file + with open(file, 'wb') as f: + pkl.dump(data, f, pkl.HIGHEST_PROTOCOL) + print ('Dataset object saved to {}.'.format(file)) + return + + + def restore(self, file): + """Restore the Dataset object from a (pickle) file + + Parameters + ---------- + + file : The name (including full path) of the file that the object + will be restored from. 
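+            Typically a file previously written by the save() method.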
+ + Returns + ------- + + None + """ + # get attributes from file + with open(file, 'rb') as f: + data = pkl.load(f) + # set attributes to object + for (k, v) in data.items(): + setattr(self, k, v) + print ('Dataset object restored from {}.'.format(file)) + return + + + def _compute_connected(self): + """Compute the weighted adjacency matrix and connectivity + + Parameters + ---------- + + None + + Returns + ------- + + None + """ + # build the weighted adjacency matrix (how many shared subjects between each pair of datasets) + matrix_csc = self.matrix.tocsc(copy=True) + row, col, data = [], [], [] + for i in range(self.num_dataset): + for j in range(i+1, self.num_dataset): + tmp = matrix_csc[:,i].multiply(matrix_csc[:,j]).nnz + if tmp != 0: + row.extend([i,j]) + col.extend([j,i]) + data.extend([tmp, tmp]) + self.adj_matrix = sp.csc_matrix((data, (row, col)),shape=(self.num_dataset, self.num_dataset)) + + # find out which datasets are connected + not_connected = set(range(self.num_dataset)) + connected = [] + dq = set() + for idx in range(self.num_dataset): + if idx in not_connected: + tmp = [] + dq.add(idx) + while dq: + n = dq.pop() + not_connected.remove(n) + tmp.append(n) + for neighbor in self.adj_matrix[:,n].indices: + if neighbor in not_connected: + dq.add(neighbor) + if not dq: + connected.append(tmp) + + # convert connected datasets from idx to dataset names + self.connected = [] + for idx, graph in enumerate(connected): + self.connected.append([]) + for node in graph: + self.connected[idx].append(self.idx_to_dataset[node]) + + # count number of connected graphs + self.num_graph = len(self.connected) + + return + + + def _construct_matrix(self, new_ds, new_sub, ds_dict): + """Initialize the organizer with some datasets and subjects + + Parameters + ---------- + + new_ds : set or list of string, + Name of all new datasets to add + + new_sub : set or list of string, + Name of all new subjects to add + + ds_dict : dict of list of string, ds_dict[d] is a list of subject names + in dataset d in the same order as in the dataset, where d is the name of + the dataset + + Returns + ------- + + None + """ + # fill in datasets and subjects info + self.num_subj = len(new_sub) + self.num_dataset = len(new_ds) + for idx, subj in enumerate(new_sub): + self.subject_to_idx[subj] = idx + self.idx_to_subject[idx] = subj + for idx, ds in enumerate(new_ds): + self.dataset_to_idx[ds] = idx + self.idx_to_dataset[idx] = ds + + # fill in sparse matrix + data, row, col = [], [], [] + for ds in new_ds: + self.subj_in_dataset[ds] = ds_dict[ds] + col_idx = self.dataset_to_idx[ds] + for idx, subj in enumerate(ds_dict[ds]): + data.append(idx+1) + col.append(col_idx) + row.append(self.subject_to_idx[subj]) + self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.dok_matrix = self.matrix.todok(copy=True) + + # compute connectivity + self._compute_connected() + return + + + def _add_new_dataset(self, new_ds, new_sub, ds_dict): + """Add some new datasets into the organizer when the organizer was + already initialized and the new datasets are not in it yet. 
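+        For example, this is the path taken when add() is called again
+        with a JSON file describing a dataset that the organizer has not
+        seen before.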
+ + Parameters + ---------- + + new_ds : set or list of string, + Name of all new datasets to add + + new_sub : set or list of string, + Name of all new subjects to add + + ds_dict : dict of list of string, ds_dict[d] is a list of subject names + in dataset d in the same order as in the dataset, where d is the name of + the dataset + + Returns + ------- + + None + """ + # fill in new datasets and subjects info + for idx, subj in enumerate(new_sub): + self.subject_to_idx[subj] = self.num_subj + idx + self.idx_to_subject[self.num_subj + idx] = subj + for idx, ds in enumerate(new_ds): + self.dataset_to_idx[ds] = self.num_dataset + idx + self.idx_to_dataset[self.num_dataset + idx] = ds + self.num_subj += len(new_sub) + self.num_dataset += len(new_ds) + + # fill in sparse matrix + data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() + for ds in new_ds: + self.subj_in_dataset[ds] = ds_dict[ds] + col_idx = self.dataset_to_idx[ds] + for idx, subj in enumerate(ds_dict[ds]): + data.append(idx+1) + col.append(col_idx) + row.append(self.subject_to_idx[subj]) + self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.dok_matrix = self.matrix.todok(copy=True) + return + + + def _replace_dataset(self, replace_ds, ds_dict): + """Replace information of some datasets with information in ds_dict assuming + those datasets are already in the organizer + + Parameters + ---------- + + replace_ds : set or list of string, + Name of all datasets to replace + + ds_dict : dict of list of string, ds_dict[d] is a list of subject names + in dataset d in the same order as in the dataset, where d is the name of + the dataset + + Returns + ------- + + None + """ + # extract data from the sparse matrix + data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() + + # remove data of datasets to be replaced from the coo sparse matrix + data, row, col, subj_to_check = self._remove_datasets_from_data(replace_ds, data, row, col) + + # add data of datasets to replace + for ds in replace_ds: + self.subj_in_dataset[ds] = ds_dict[ds] + col_idx = self.dataset_to_idx[ds] + for idx, subj in enumerate(ds_dict[ds]): + data.append(idx+1) + col.append(col_idx) + row.append(self.subject_to_idx[subj]) + subj_to_check.discard(subj) + + # finalize subj to remove (subjects not in any datasets) + subj_to_remove = [] + for subj in subj_to_check: + if not self.subject_to_idx[subj] in row: + subj_to_remove.append(subj) + + # remove those subjects and re-arrange subject indices + row = self._remove_subjects_by_re_indexing(subj_to_remove, row) + + # re-construct the matrix + self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.dok_matrix = self.matrix.todok(copy=True) + return + + + def _remove_subjects_from_data(self, subjects, data, row, col): + """Remove some subjects by deleting their data + + Parameters + ---------- + + subjects : set or list of string, + Name of subjects to be removed + + data, row, col : list of int, + Data extracted from sparse matrix self.matrix + + Returns + ------- + + data, row, col : list of int, + Data can be used to construct a sparse matrix after removal of those + subjects + + Note + ---- + + Subjects are not re-indexed. Need to call _remove_subjects_by_re_indexing() + afterwards to re-index. 
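+
+        For illustration (hypothetical COO triplets): with data=[1, 2, 1],
+        row=[0, 0, 1] and col=[0, 1, 0], removing the subject with row
+        index 1 leaves data=[1, 2], row=[0, 0] and col=[0, 1]; row index 1
+        simply disappears until _remove_subjects_by_re_indexing() is
+        called.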
+ """ + len_data = len(data) + row_to_remove = set() # subject indices (row indices) to remove + subjects = set(subjects) + for subj in subjects: + row_to_remove.add(self.subject_to_idx[subj]) + idx_to_remove = [] # data indices to remove from data, row, col lists + for idx, row_idx in enumerate(row): + if row_idx in row_to_remove: + idx_to_remove.append(idx) + self.num_subj_dataset[self.idx_to_dataset[col[idx]]] -= 1 + for ds in self.subj_in_dataset.keys(): + for idx in range(len(self.subj_in_dataset[ds])): + if self.subj_in_dataset[ds][idx] in subjects: + self.subj_in_dataset[ds][idx] = None + + # remove data + data = [data[i] for i in range(len_data) if not i in idx_to_remove] + row = [row[i] for i in range(len_data) if not i in idx_to_remove] + col = [col[i] for i in range(len_data) if not i in idx_to_remove] + + return data, row, col + + + def _remove_datasets_from_data(self, datasets, data, row, col): + """Remove some datasets by deleting their data + + Parameters + ---------- + + datasets : set or list of string, + Name of datasets to be removed + + data, row, col : list of int, + Data extracted from sparse matrix self.matrix + + Returns + ------- + + data, row, col : list of int, + Data can be used to construct a sparse matrix after removal of those + datasets + + subj_to_check : set of string, + Name of subjects that are possibly not in any datasets (and thus need + to be removed) after removal of those datasets + + Note + ---- + + Datasets are not re-indexed. Need to call _remove_datasets_by_re_indexing() + afterwards to re-index. + """ + len_data = len(data) + col_to_remove = set() # dataset indices (column indices) to remove + for ds in datasets: + col_to_remove.add(self.dataset_to_idx[ds]) + idx_to_remove = [] # data indices to remove from data, row, col lists + subj_to_check = set() # possible subject indices to remove after removing datasets + for idx, col_idx in enumerate(col): + if col_idx in col_to_remove: + idx_to_remove.append(idx) + subj_to_check.add(self.idx_to_subject[row[idx]]) + # remove info in dict + for ds in datasets: + del self.subj_in_dataset[ds] + del self.num_subj_dataset[ds] + # remove data + data = [data[i] for i in range(len_data) if not i in idx_to_remove] + row = [row[i] for i in range(len_data) if not i in idx_to_remove] + col = [col[i] for i in range(len_data) if not i in idx_to_remove] + + return data, row, col, subj_to_check + + + def _remove_subjects_by_re_indexing(self, subjects, row): + """Re-index all subjects after removal of data of some subjects + so that the subject indexing are still contiguous. + + Parameters + ---------- + + subjects : set or list of string, + Name of subjects where their data in self.matrix are removed + already and need to be removed from indexing + + row : list of int, row indices as in a sparse matrix + Row (subject) indices before re-indexing of subjects + + Returns + ------- + + row : list of int, row indices as in a sparse matrix + Row (subject) indices after re-indexing of subjects + + Note + ---- + + Data of subjects 'subjects' must be removed already. 
If not, + need to call _remove_subjects_from_data() beforehand + """ + # remaining subjects after moving 'subjects' + remained = set(self.subject_to_idx.keys()) - set(subjects) + # re-indexing + new_subject_to_idx, new_idx_to_subject = {}, {} + for idx, subj in enumerate(remained): + new_idx_to_subject[idx] = subj + new_subject_to_idx[subj] = idx + # map indices based on new indexing + for idx, r in enumerate(row): + subj = self.idx_to_subject[r] + new_r = new_subject_to_idx[subj] + row[idx] = new_r + # update mapping + self.subject_to_idx, self.idx_to_subject = new_subject_to_idx, new_idx_to_subject + # update total number of subjects + self.num_subj -= len(subjects) + return row + + + def _remove_datasets_by_re_indexing(self, datasets, col): + """Re-index all datasets after removal of data of some datasets + so that the dataset indexing are still contiguous. + + Parameters + ---------- + + datasets : set or list of string, + Name of datasets where their data in self.matrix are removed + already and need to be removed from indexing + + col : list of int, col indices as in a sparse matrix + Col (dataset) indices before re-indexing of datasets + + Returns + ------- + + col : list of int, col indices as in a sparse matrix + Col (dataset) indices after re-indexing of datasets + + Note + ---- + + Data of datasets 'datasets' must be removed already. If not, + need to call _remove_datasets_from_data() beforehand + """ + # remaining datasets after moving 'datasets' + remained = set(self.dataset_to_idx.keys()) - set(datasets) + # re-indexing + new_dataset_to_idx, new_idx_to_dataset = {}, {} + for idx, ds in enumerate(remained): + new_idx_to_dataset[idx] = ds + new_dataset_to_idx[ds] = idx + # map indices based on new indexing + for idx, c in enumerate(col): + ds = self.idx_to_dataset[c] + new_c = new_dataset_to_idx[ds] + col[idx] = new_c + # update mapping + self.dataset_to_idx, self.idx_to_dataset = new_dataset_to_idx, new_idx_to_dataset + # update total number of datasets + self.num_dataset -= len(datasets) + return col + + + + + diff --git a/docs/newsfragments/mdms.jinja b/docs/newsfragments/mdms.jinja new file mode 100644 index 000000000..72dc6150c --- /dev/null +++ b/docs/newsfragments/mdms.jinja @@ -0,0 +1,34 @@ +{% for section, _ in sections|dictsort(by='key') %} +{% set underline = "-" %}{% if section %}{{section}} +{{ underline * section|length }}{% set underline = "~" %} + +{% endif %} + +{% if sections[section] %} +{% for category in definitions if category in sections[section]%} +{{ definitions[category]['name'] }} +{{ underline * definitions[category]['name']|length }} + +{% if definitions[category]['showcontent'] %} +{% for text, values in sections[section][category]|dictsort(by='value') %} +{% set issue_joiner = joiner(', ') %} +- {{ text }} ({% for value in values|sort %}{{ issue_joiner() }}`{{ value }} `_{% endfor %}) +{% endfor %} + +{% else %} +- {{ sections[section][category]['']|sort|join(', ') }} + +{% endif %} +{% if sections[section][category]|length == 0 %} +srm: Add multi-dataset multi-subject (MDMS) SRM analysis method and examples. + +{% else %} +{% endif %} + +{% endfor %} +{% else %} +srm: Add multi-dataset multi-subject (MDMS) SRM analysis method and examples. 
+ + +{% endif %} +{% endfor %} diff --git a/examples/funcalign/download-data.sh b/examples/funcalign/download-data.sh index 7ea1a030c..070208930 100755 --- a/examples/funcalign/download-data.sh +++ b/examples/funcalign/download-data.sh @@ -2,4 +2,6 @@ curl --location --create-dirs -o data/movie_data.mat https://www.dropbox.com/s/2 curl --location --create-dirs -o data/label.mat https://www.dropbox.com/s/ogd26q6fro4l2d2/label.mat?dl=0 curl --location --create-dirs -o data/image_data.mat https://www.dropbox.com/s/l818vr6o8huatxj/image_data.mat?dl=0 curl --location --create-dirs -o data/sl_movie_data.mat https://www.dropbox.com/s/2ahgqgu5ljmqbw5/movie_data.mat?dl=0 -curl --location --create-dirs -o data/MNI152_T1_3mm_brain_mask.nii https://www.dropbox.com/s/3zk78ok9wtd3v3o/MNI152_T1_3mm_brain_mask.nii?dl=0 \ No newline at end of file +curl --location --create-dirs -o data/MNI152_T1_3mm_brain_mask.nii https://www.dropbox.com/s/3zk78ok9wtd3v3o/MNI152_T1_3mm_brain_mask.nii?dl=0 +curl --location --create-dirs -o data/multi_dataset.json https://www.dropbox.com/s/9lzqt4fybngqynm/multi_dataset.json?dl=0 +curl --location --create-dirs -o data/multi_dataset.pickle https://www.dropbox.com/s/8l9qgazux8t44cu/multi_dataset.pickle?dl=0 \ No newline at end of file diff --git a/examples/funcalign/mdms_time_segment_matching_distributed.py b/examples/funcalign/mdms_time_segment_matching_distributed.py new file mode 100644 index 000000000..61453cef5 --- /dev/null +++ b/examples/funcalign/mdms_time_segment_matching_distributed.py @@ -0,0 +1,200 @@ +# Copyright 2016 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distributed Multi-dataset multi-subject (MDMS) SRM analysis Example. + +This example runs MDMS on time segment matching experiment. +To get a better understanding of the code, please look at +mdms_time_segment_matching_example.ipynb first. + +Example Usage +------- +If run 4 ranks: + $ mpirun -n 4 python3 mdms_time_segment_matching_distributed.py + +Author +------- +Hejia Zhang (Princeton University ELE Department) + +Notes +------- +It's an implementation of: +.. [Zhang2018] "Transfer learning on fMRI datasets", + H. Zhang, P.-H. Chen, P. Ramadge + The 21st International Conference on Artificial Intelligence and Statistics (AISTATS), 2018. 
+ http://proceedings.mlr.press/v84/zhang18b/zhang18b.pdf +""" + +import numpy as np +from mpi4py import MPI +from scipy.stats import stats +import pickle as pkl +from brainiak.fcma.util import compute_correlation +from brainiak.funcalign.mdms import MDMS, Dataset + + +# parameters +features = 75 # number of features, k +n_iter = 30 # number of iterations of EM +test_ds = 'milky' + +# MPI parameters, do not need to change +comm = MPI.COMM_WORLD +rank = comm.rank +size = comm.size +if rank == 0: + print ('comm size : {}'.format(size)) + +# load and preprocess data in rank 0 +if rank == 0: + # load data + with open('data/multi_dataset.pickle','rb') as f: + all_data = pkl.load(f) + + # load dataset structure + ds_struct = Dataset('data/multi_dataset.json') + + # separate train and test data + # save info of test data to rank 0, and the testing will run at rank 0 + test_subj_list = ds_struct.subj_in_dataset[test_ds] + test_data = all_data[test_ds] + + # remove test dataset from the dataset structure without changing the data and MDMS will handle it automatically + _ = ds_struct.remove_dataset([test_ds]) + + # remove subjects in test_ds that are not in any training dataset + train_subj = set(ds_struct.get_subjects_list()) # all subjects in training set + test_subj_idx_to_keep = [] # index of subjects to keep + for idx, subj in enumerate(test_subj_list): + if subj in train_subj: + test_subj_idx_to_keep.append(idx) + test_subj_list = [test_subj_list[idx] for idx in test_subj_idx_to_keep] + test_data = [test_data[idx] for idx in test_subj_idx_to_keep] + + # compute voxels mean and std of each subject from training data and use them to standardize training and testing data + mean, std = {}, {} # mean and std of each subject + matrix_csr = ds_struct.matrix.tocsr(copy=True) + for subj in range(ds_struct.num_subj): # iterate through all subjects + subj_name = ds_struct.idx_to_subject[subj] + indices = matrix_csr[subj,:].indices # indices of datasets with this subject + # aggregate all data from this subject + for idx, ds_idx in enumerate(indices): + if idx == 0: + mtx_tmp = all_data[ds_struct.idx_to_dataset[ds_idx]][ds_struct.dok_matrix[subj,ds_idx]-1] + else: + mtx_tmp = np.concatenate((mtx_tmp, all_data[ds_struct.idx_to_dataset[ds_idx]][ds_struct.dok_matrix[subj,ds_idx]-1]),axis=1) + # compute mean and std + mean[subj_name] = np.mean(mtx_tmp, axis=1) + std[subj_name] = np.std(mtx_tmp, axis=1) + # standardize training data + for ds_idx in indices: + ds_name, idx_in_ds = ds_struct.idx_to_dataset[ds_idx], ds_struct.dok_matrix[subj,ds_idx]-1 + all_data[ds_name][idx_in_ds] = np.nan_to_num((all_data[ds_name][idx_in_ds]-mean[subj_name][:,None])/std[subj_name][:,None]) + + # use the mean and std computed from training data to standardize testing data + for idx, subj in enumerate(test_subj_list): + test_data[idx] = np.nan_to_num((test_data[idx]-mean[subj][:,None])/std[subj][:,None]) + + # delete testing data from 'all_data' to save space + del all_data[test_ds] + + # get the membership and compute the tag for MPI communication for every data point in 'all_data' + data_mem = {} + tag_s = 0 # tag start from 0 + for ds in all_data: + length = len(all_data[ds]) + mem = np.random.randint(low=0,high=size,size=length) # which rank it belongs to + tag = list(range(tag_s, tag_s+length)) + tag_s += length + data_mem[ds] = [mem, tag] + +else: + ds_struct = None + data_mem = None + +# broadcast data_mem and ds_struct to all ranks and initialize data in each rank +data_mem = comm.bcast(data_mem, root=0) +ds_struct = 
comm.bcast(ds_struct, root=0) + +data = {} +for ds in data_mem: + data[ds] = [None]*len(data_mem[ds][0]) + +# distribute data +if rank == 0: + for ds in data: + for idx, (mem, tag) in enumerate(zip(data_mem[ds][0], data_mem[ds][1])): + if mem != 0: + comm.send(all_data[ds][idx], dest=mem, tag=tag) + else: + data[ds][idx] = all_data[ds][idx] + del all_data +else: + for ds in data: + for idx, (mem, tag) in enumerate(zip(data_mem[ds][0], data_mem[ds][1])): + if mem == rank: + data[ds][idx] = comm.recv(source=0, tag=tag) + +# Fit MDMS model +model = MDMS(features=features, n_iter=n_iter, comm=comm) +model.fit(data, ds_struct) + +# run the testing in rank 0 +if rank == 0: + # define time segment matching experiment + def time_segment_matching(data, win_size=6): + nsubjs = len(data) + (ndim, nsample) = data[0].shape + accu = np.zeros(shape=nsubjs) + nseg = nsample - win_size + # mysseg prediction prediction + trn_data = np.zeros((ndim*win_size, nseg),order='f') + # the trn data also include the tst data, but will be subtracted when + # calculating A + for m in range(nsubjs): + for w in range(win_size): + trn_data[w*ndim:(w+1)*ndim,:] += data[m][:,w:(w+nseg)] + for tst_subj in range(nsubjs): + tst_data = np.zeros((ndim*win_size, nseg),order='f') + for w in range(win_size): + tst_data[w*ndim:(w+1)*ndim,:] = data[tst_subj][:,w:(w+nseg)] + + A = np.nan_to_num(stats.zscore((trn_data - tst_data),axis=0, ddof=1)) + B = np.nan_to_num(stats.zscore(tst_data,axis=0, ddof=1)) + + # compute correlation matrix + corr_mtx = compute_correlation(B.T,A.T) + + for i in range(nseg): + for j in range(nseg): + if abs(i-j) Date: Tue, 26 Feb 2019 16:25:31 -0500 Subject: [PATCH 2/7] Adjust MDMS coding style --- brainiak/funcalign/mdms.py | 2110 ++++++++++------- .../mdms_time_segment_matching_distributed.py | 1 - 2 files changed, 1289 insertions(+), 822 deletions(-) diff --git a/brainiak/funcalign/mdms.py b/brainiak/funcalign/mdms.py index db0466a1e..194996fc4 100644 --- a/brainiak/funcalign/mdms.py +++ b/brainiak/funcalign/mdms.py @@ -16,50 +16,54 @@ The implementations are based on the following publications: .. [Zhang2018] "Transfer learning on fMRI datasets", H. Zhang, P.-H. Chen, P. Ramadge - The 21st International Conference on Artificial Intelligence and Statistics (AISTATS), 2018. + The 21st International Conference on Artificial Intelligence and + Statistics (AISTATS), 2018. http://proceedings.mlr.press/v84/zhang18b/zhang18b.pdf """ # Authors: Hejia Zhang (Princeton Neuroscience Institute), 2018 import logging - import numpy as np import scipy from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils import assert_all_finite from sklearn.exceptions import NotFittedError from mpi4py import MPI -import sys, json, os, glob +import sys +import json +import os +import glob from scipy import sparse as sp import matplotlib.pyplot as plt import networkx as nx import pickle as pkl - __all__ = [ "DetMDMS", - "MDMS", - "Dataset" + "MDMS" ] logging.basicConfig(filename='mdms.log', - filemode='a', - format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s', - datefmt='%H:%M:%S', - level=logging.DEBUG) + filemode='a', + format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(' + 'message)s', + datefmt='%H:%M:%S', + level=logging.DEBUG) logger = logging.getLogger(__name__) logger.addHandler(logging.StreamHandler(sys.stdout)) + def _init_w_transforms(voxels, features, random_states, datasets): - """Initialize the mappings (W_s) for the MDMS with random orthogonal matrices. 
+ """Initialize the mappings (W_s) for the MDMS with random orthogonal + matrices. Parameters ---------- voxels : dict of int, voxels[s] is number of voxels where s is the name of the subject. - A dict with the number of voxels for each subject. + A dict with the number of voxels for each subject. features : int The number of features in the model. @@ -76,8 +80,9 @@ def _init_w_transforms(voxels, features, random_states, datasets): Returns ------- - w : dict of array, w[s] has shape=[voxels[s], features] where s is the name - of the subject. + w : dict of array, w[s] has shape=[voxels[s], features] where s is the + name + of the subject. The initialized orthogonal transforms (mappings) :math:`W_s` for each subject. @@ -103,24 +108,25 @@ def _init_w_transforms(voxels, features, random_states, datasets): def _sanity_check(X, datasets, comm): - """Check if the input data and datasets information have valid shape/configuration. + """Check if the input data and datasets information have valid shape/ + configuration. Parameters ---------- - X : dict of list of 2D arrays or dict of dict of 2D arrays - 1) When it is a dict of list of 2D arrays: - X[d] is a list of data of dataset d, where d is the name of the dataset. - Element i in the list has shape=[voxels_i, samples_d] - which is the fMRI data of the i'th subject in d. - 2) When it is a dict of dict of 2D arrays: - X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of - subject s in dataset d, where s is the name of the subject and - d is the name of the dataset. - + X : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + X[d] is a list of data of dataset d, where d is the name of the + dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data + of subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. - datasets : a Dataset object - The Dataset object containing datasets structures. + datasets : a Dataset object + The Dataset object containing datasets structures. comm : mpi4py.MPI.Intracomm The MPI communicator containing the data. @@ -131,98 +137,190 @@ def _sanity_check(X, datasets, comm): voxels_ : dict of int, voxels_[s] is number of voxels where s is the name of the subject. - A dict with the number of voxels for each subject. - - samples_ : dict of int, samples_[d] is number of samples where d is the name - of the dataset. - A dict with the number of samples for each dataset. - + A dict with the number of voxels for each subject. + samples_ : dict of int, samples_[d] is number of samples where d is the + name of the dataset. + A dict with the number of samples for each dataset. 
""" - # Check the number of subjects and all ranks have all datasets in the Dataset object + # Check the number of subjects and all ranks have all datasets in the + # Dataset object ds_list = datasets.get_datasets_list() - for (ds, ns) in datasets.num_subj_dataset.items(): + for (ds, ns) in datasets.num_subj_dataset.items(): if ns < 1: - raise ValueError("Dataset {} should have positive num_subj_dataset".format(ds)) + raise ValueError("Dataset {} should have positive " + "num_subj_dataset".format(ds)) if ds not in X: raise ValueError("Dataset {} not in all ranks".format(ds)) if X[ds] is not None and len(X[ds]) < ns: - raise ValueError("Dataset {} does not have enough subjects: Need equal to or more " - "than {0:d} subjects but got {0:d} to train the model.".format(ds, ns, len(X[ds]))) + raise ValueError("Dataset {} does not have enough subjects: Need" + " equal to or more than {0:d} subjects but " + "got {0:d} to train the model." + .format(ds, ns, len(X[ds]))) # Collect size information + shape0, shape1, data_exist = _collect_size_information(X, datasets, comm) + + # Check if all required data appears once and only once + # Also remove size information of data that is not in 'datasets' + shape0, shape1 = _check_missing_data(datasets, shape0, + shape1, data_exist) + + # Check if each subject has same number of voxels across different + # datasets + voxels_ = {} + for subj in range(datasets.num_subj): + all_vxs_tmp = [v[subj] for v in shape0.values() if v[subj] != 0] + subj_name = datasets.idx_to_subject[subj] + voxels_[subj_name] = np.min(all_vxs_tmp) + if any([v != voxels_[subj_name] for v in all_vxs_tmp]): + raise ValueError("Subject {} has different number of voxels " + "across datasets.".format(subj_name)) + + # Check if all subjects have same number of TRs within the same dataset + samples_ = {} + for ds in ds_list: + all_trs_tmp = [t for t in shape1[ds] if t != 0] + samples_[ds] = np.min(all_trs_tmp) + if any([t != samples_[ds] for t in all_trs_tmp]): + raise ValueError("Different number of samples between subjects" + "in dataset {}.".format(ds)) + + return voxels_, samples_ + + +def _collect_size_information(X, datasets, comm): + """Collect the shape of datasets and check if all data required are in X. + + Parameters + ---------- + + X : dict of list of 2D arrays or dict of dict of 2D arrays + 1) When it is a dict of list of 2D arrays: + X[d] is a list of data of dataset d, where d is the name of the + dataset. + Element i in the list has shape=[voxels_i, samples_d] + which is the fMRI data of the i'th subject in d. + 2) When it is a dict of dict of 2D arrays: + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data + of subject s in dataset d, where s is the name of the subject and + d is the name of the dataset. + + datasets : a Dataset object + The Dataset object containing datasets structures. + + comm : mpi4py.MPI.Intracomm + The MPI communicator containing the data. + + + Returns + ------- + + shape0 : dict of list, shape0[d] has shape [num_subj] + Size of the 1st dimension of each 2D data array. + + shape1 : dict of list, shape1[d] has shape [num_subj] + Size of the 2nd dimension of each 2D data array. + + data_exist : dict of list, data_exist[d] has shape [num_subj] + How many times the same 2D data array appears in the dataset. 
+ """ shape0, shape1, data_exist = {}, {}, {} + ds_list = datasets.get_datasets_list() for ds in ds_list: + # initialization shape0[ds] = np.zeros((datasets.num_subj,), dtype=np.int) shape1[ds] = np.zeros((datasets.num_subj,), dtype=np.int) data_exist[ds] = np.zeros((datasets.num_subj,), dtype=np.int) - for ds in ds_list: ds_idx = datasets.dataset_to_idx[ds] + # collect size information of each dataset if X[ds] is not None: for subj in range(datasets.num_subj): - if datasets.dok_matrix[subj,ds_idx] != 0: + if datasets.dok_matrix[subj, ds_idx] != 0: if datasets.built_from_data: idx = datasets.idx_to_subject[subj] - if not idx in X[ds]: - raise Exception('Subject {} in dataset {} is missing.'.format(idx, ds)) + if idx not in X[ds]: + raise Exception('Subject {} in dataset {} is ' + 'missing.'.format(idx, ds)) else: - idx = datasets.dok_matrix[subj,ds_idx] - 1 + idx = datasets.dok_matrix[subj, ds_idx] - 1 if len(X[ds]) <= idx: - raise ValueError("Dataset {} does not have enough subjects: Need more " - "than {0:d} subjects but got {0:d} to train the model.".format(ds, idx, len(X[ds]))) + raise ValueError("Dataset {} does not have " + "enough subjects: Need more " + "than {0:d} subjects but got " + "{0:d} to train the model.". + format(ds, idx, len(X[ds]))) if X[ds][idx] is not None: - assert_all_finite(X[ds][idx]) + assert_all_finite(X[ds][idx]) shape0[ds][subj] = X[ds][idx].shape[0] shape1[ds][subj] = X[ds][idx].shape[1] data_exist[ds][subj] = 1 - for ds in ds_list: + # reduce from all ranks shape0[ds] = comm.allreduce(shape0[ds], op=MPI.SUM) shape1[ds] = comm.allreduce(shape1[ds], op=MPI.SUM) data_exist[ds] = comm.allreduce(data_exist[ds], op=MPI.SUM) - # Check if all required data appears once and only once - # Also remove size information of data that is not in 'datasets' + return shape0, shape1, data_exist + + +def _check_missing_data(datasets, shape0, shape1, data_exist): + """Check if all required data appears once and only once. + Also remove size information of data that is not in 'datasets' + + Parameters + ---------- + + datasets : a Dataset object + The Dataset object containing datasets structures. + + shape0 : dict of list, shape0[d] has shape [num_subj] + Size of the 1st dimension of each 2D data array. + + shape1 : dict of list, shape1[d] has shape [num_subj] + Size of the 2nd dimension of each 2D data array. + + data_exist : dict of list, data_exist[d] has shape [num_subj] + How many times the same 2D data array appears in the dataset. + + + Returns + ------- + + shape0 : dict of list, shape0[d] has shape [num_subj] + Size of the 1st dimension of each 2D data array. + + shape1 : dict of list, shape1[d] has shape [num_subj] + Size of the 2nd dimension of each 2D data array. + """ + ds_list = datasets.get_datasets_list() for ds in ds_list: ds_idx = datasets.dataset_to_idx[ds] for subj in range(datasets.num_subj): - if datasets.dok_matrix[subj,ds_idx] != 0: + if datasets.dok_matrix[subj, ds_idx] != 0: if data_exist[ds][subj] == 0: - raise ValueError("Data of subject {} in dataset {} is missing.".format(datasets.dok_matrix[subj,ds_idx]-1, ds)) + raise ValueError("Data of subject {} in dataset {} is " + "missing.".format(datasets.dok_matrix[ + subj, ds_idx]-1, ds)) elif data_exist[ds][subj] > 1: - raise ValueError("Data of subject {} in dataset {} appears more than once.".format(datasets.dok_matrix[subj,ds_idx]-1, ds)) + raise ValueError("Data of subject {} in dataset {} " + "appears more than once." 
+ .format(datasets.dok_matrix[ + subj, ds_idx]-1, ds)) else: shape0[ds][subj] = 0 shape1[ds][subj] = 0 - - # Check if each subject has same number of voxels across different datasets - voxels_ = {} - for subj in range(datasets.num_subj): - all_vxs_tmp = [v[subj] for v in shape0.values() if v[subj] != 0] - subj_name = datasets.idx_to_subject[subj] - voxels_[subj_name] = np.min(all_vxs_tmp) - if any([v != voxels_[subj_name] for v in all_vxs_tmp]): - raise ValueError("Subject {} has different number of voxels across" - "datasets.".format(subj_name)) - - # Check if all subjects have same number of TRs within the same dataset - samples_ = {} - for ds in ds_list: - all_trs_tmp = [t for t in shape1[ds] if t != 0] - samples_[ds] = np.min(all_trs_tmp) - if any([t != samples_[ds] for t in all_trs_tmp]): - raise ValueError("Different number of samples between subjects" - "in dataset {}.".format(ds)) - - return voxels_, samples_ + return shape0, shape1 class MDMS(BaseEstimator, TransformerMixin): """multi-dataset multi-subject (MDMS) SRM analysis - Given multi-dataset multi-subject data, factorize it as a shared response S among all - subjects per dataset and an orthogonal transform W across all datasets per subject: + Given multi-dataset multi-subject data, factorize it as a shared + response S among all subjects per dataset and an orthogonal transform W + across all datasets per subject: - .. math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\dots M + .. math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\ + dots M Parameters ---------- @@ -252,22 +350,24 @@ class MDMS(BaseEstimator, TransformerMixin): voxels_ : dict of int, voxels_[s] is number of voxels where s is the name of the subject. - A dict with the number of voxels for each subject. + A dict with the number of voxels for each subject. - samples_ : dict of int, samples_[d] is number of samples where d is the name - of the dataset. + samples_ : dict of int, samples_[d] is number of samples where d is the + name of the dataset. A dict with the number of samples for each dataset. sigma_s_ : dict of array, sigma_s_[d] has shape=[features, features] - The covariance of the shared response Normal distribution for each dataset. + The covariance of the shared response Normal distribution for each + dataset. - mu_ : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name + mu_ : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name of the subject. The voxel means over the samples in all datasets for each subject. - rho2_ : dict of dict of float, rho2_[d][s] is a float, where d is the name - of the dataset and s is the name of the subject. - The estimated noise variance :math:`\\rho_{di}^2` for each subject in each dataset. + rho2_ : dict of dict of float, rho2_[d][s] is a float, where d is the + name of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho_{di}^2` for each subject + in each dataset. comm : mpi4py.MPI.Intracomm The MPI communicator containing the data @@ -278,19 +378,22 @@ class MDMS(BaseEstimator, TransformerMixin): Note ---- - The number of voxels may be different between subjects within a dataset - and number of samples may be different between datasets. However, the - number of samples must be the same across subjects within a dataset and - number of voxels must be the same across datasets for the same subject. 
+ The number of voxels may be different between subjects within a + dataset and number of samples may be different between datasets. + However, the number of samples must be the same across subjects + within a dataset and number of voxels must be the same across + datasets for the same subject. - The probabilistic multi-dataset multi-subject model is approximated using the - Expectation Maximization (EM) algorithm proposed in [Zhang2018]_. The - implementation follows the optimizations published in [Anderson2016]_. + The probabilistic multi-dataset multi-subject model is approximated + using the Expectation Maximization (EM) algorithm proposed in + [Zhang2018]_. The implementation follows the optimizations published + in [Anderson2016]_. The run-time complexity is :math:`O(I (V T K + V K^2 + K^3))` and the memory complexity is :math:`O(V T)` with I - the number of iterations, - V - the sum of voxels from all subjects, T - the sum of samples from - all datasets, and K - the number of features (typically, :math:`V \\gg T \\gg K`). + V - the sum of voxels from all subjects, T - the sum of samples from + all datasets, and K - the number of features (typically, :math:`V \\ + gg T \\gg K`). """ def __init__(self, n_iter=10, features=50, rand_seed=0, @@ -302,23 +405,24 @@ def __init__(self, n_iter=10, features=50, rand_seed=0, self.logger = logger return - def fit(self, X, datasets=None, y=None): - """Compute the probabilistic multi-dataset multi-subject (MDMS) SRM analysis + """Compute the probabilistic multi-dataset multi-subject (MDMS) SRM + analysis Parameters ---------- X : dict of list of 2D arrays or dict of dict of 2D arrays 1) When it is a dict of list of 2D arrays: 'datasets' must be defined in this case. - X[d] is a list of data of dataset d, where d is the name of the dataset. - Element i in the list has shape=[voxels_i, samples_d] + X[d] is a list of data of dataset d, where d is the name of + the dataset. + Element i in the list has shape=[voxels_i, samples_d] which is the fMRI data of the i'th subject in d. 2) When it is a dict of dict of 2D arrays: 'datasets' can be omitted in this case. - X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of - subject s in dataset d, where s is the name of the subject and - d is the name of the dataset. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI + data of subject s in dataset d, where s is the name of the + subject and d is the name of the dataset. datasets : (optional) a Dataset object The Dataset object containing datasets structure. @@ -338,15 +442,20 @@ def fit(self, X, datasets=None, y=None): raise Exception('X should be a dict.') format_X = type(next(iter(X.values()))) if format_X != dict and format_X != list: - raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') - if format_X == list and (datasets is None or datasets.built_from_data is None or datasets.built_from_data): - raise Exception("Argument 'datasets' must be defined and built from json " - "files when X is a dict of list of 2D arrays. ") + raise Exception('X should be a dict of dict of arrays or dict of' + ' list of arrays.') + if format_X == list and (datasets is None or + datasets.built_from_data is None or + datasets.built_from_data): + raise Exception("Argument 'datasets' must be defined and built " + "from json " + "files when X is a dict of list of 2D arrays. 
") if format_X == dict and datasets is not None: datasets.built_from_data = True for v in X.values(): if type(v) != format_X: - raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') + raise Exception('X should be a dict of dict of arrays or ' + 'dict of list of arrays.') # Infer datasets structure from data if datasets is None: @@ -354,25 +463,26 @@ def fit(self, X, datasets=None, y=None): datasets.build_from_data(X) self.voxels_, self.samples_ = _sanity_check(X, datasets, self.comm) - + # Run MDMS - self.sigma_s_, self.w_, self.mu_, self.rho2_, self.s_ = self._mdms(X, datasets) + self.sigma_s_, self.w_, self.mu_, \ + self.rho2_, self.s_ = self._mdms(X, datasets) return self - def transform(self, X, subjects, centered=True, y=None): """Use the model to transform new data to Shared Response space Parameters ---------- X : list of 2D arrays, element i has shape=[voxels_i, samples_i] - Each element in the list contains the new fMRI data of one subject + Each element in the list contains the new fMRI data of one + subject subjects : list of string, element i is the name of subject of X[i] centered : bool, if the data in X is already centered. - If centered = False, the voxel means computed during mode fitting + If centered = False, the voxel means computed during mode fitting will be subtracted before transformation. y : not used (as it is unsupervised learning) @@ -392,68 +502,128 @@ def transform(self, X, subjects, centered=True, y=None): if not hasattr(self, 'w_'): raise NotFittedError("The model fit has not been run yet.") - # Check if the subject exist in the fitted model and has the right number of voxels + # Check if the subject exist in the fitted model and has the right + # number of voxels for idx in range(len(X)): if not subjects[idx] in self.w_: - raise NotFittedError("The model has not been fitted to subject {}.".format(subjects[idx])) - if X[idx] is not None and self.w_[subjects[idx]].shape[0] != X[idx].shape[0]: - raise ValueError("{}-th element of data has inconsistent number of" - "voxels with fitted model. Model has {} voxels while data has {}" - ".".format(idx, self.w_[subjects[idx]].shape[0], X[idx].shape[0])) + raise NotFittedError("The model has not been fitted to " + "subject {}.".format(subjects[idx])) + if X[idx] is not None and (self.w_[subjects[idx]]. + shape[0] != X[idx].shape[0]): + raise ValueError("{}-th element of data has inconsistent " + "number of voxels with fitted model. Model" + " has {} voxels while data has {}.". + format(idx, self.w_[subjects[idx]].shape[0], + X[idx].shape[0])) s = [None] * len(X) for idx in range(len(X)): if X[idx] is not None: if centered: s[idx] = self.w_[subjects[idx]].T.dot(X[idx]) - else: - s[idx] = self.w_[subjects[idx]].T.dot(X[idx]-self.mu_[subjects[idx]][:, None]) + else: + s[idx] = self.w_[subjects[idx]].T.\ + dot(X[idx] - self.mu_[subjects[idx]][:, None]) return s + def _compute_mean(self, x, datasets): + """Compute the mean of data. - def _init_structures(self, data, datasets): - """Initializes data structures for MDMS and preprocess the data. + Parameters + ---------- + x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the + subject. + Demeaned data for each subject. + datasets : a Dataset object + The Dataset object containing datasets structures. + + Returns + ------- + + mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the + name of the subject. 
+ The voxel means over the samples in all datasets for each + subject. + """ + # collect mean from each MPI worker + weights = {} + mu_tmp = {} + for subj in datasets.subject_to_idx.keys(): + weights[subj], mu_tmp[subj] = {}, {} + for ds in x.keys(): + if subj in x[ds]: + if x[ds][subj] is not None: + mu_tmp[subj][ds] = np.mean(x[ds][subj], 1) + weights[subj][ds] = x[ds][subj].shape[1] + else: + mu_tmp[subj][ds] = np.zeros((self.voxels_[subj],)) + weights[subj][ds] = 0 + # collect mean from all MPI workers + for subj in datasets.subject_to_idx.keys(): + for ds in mu_tmp[subj].keys(): + mu_tmp[subj][ds] = self.comm.allreduce(mu_tmp[subj][ds], + op=MPI.SUM) + weights[subj][ds] = self.comm.allreduce(weights[subj][ds], + op=MPI.SUM) + # compute final mean + mu = {} + for subj in datasets.subject_to_idx.keys(): + mu[subj] = np.zeros((self.voxels_[subj],)) + nsample = np.sum(list(weights[subj].values())) + for ds in mu_tmp[subj].keys(): + mu[subj] += weights[subj][ds] * mu_tmp[subj][ds] / nsample + return mu + + def _init_structures(self, data, datasets, ds_subj_list): + """Initializes data structures for MDMS and preprocess the data. Parameters ---------- data : dict of list of 2D arrays or dict of dict of 2D arrays 1) When it is a dict of list of 2D arrays: 'datasets' must be defined in this case. - X[d] is a list of data of dataset d, where d is the name of the dataset. - Element i in the list has shape=[voxels_i, samples_d] + X[d] is a list of data of dataset d, where d is the name of + the dataset. + Element i in the list has shape=[voxels_i, samples_d] which is the fMRI data of the i'th subject in d. 2) When it is a dict of dict of 2D arrays: 'datasets' can be omitted in this case. - X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of - subject s in dataset d, where s is the name of the subject and - d is the name of the dataset. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI + data of subject s in dataset d, where s is the name of the + subject and d is the name of the dataset. datasets : a Dataset object The Dataset object containing datasets structures. + ds_subj_list : dict of list of string, ds_subj_list[s] is a list + of names of datasets with subject s, where s is the name + of the subject. Returns ------- x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] - where d is the name of the dataset and s is the name of the subject + where d is the name of the dataset and s is the name of the + subject. Demeaned data for each subject. - mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name + mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name of the subject. - The voxel means over the samples in all datasets for each subject. + The voxel means over the samples in all datasets for each + subject. - rho2 : dict of dict of float, rho2_[d][s] is a float, where d is the name - of the dataset and s is the name of the subject. - The estimated noise variance :math:`\\rho_{di}^2` for each subject in each dataset. + rho2 : dict of dict of float, rho2_[d][s] is a float, where d is the + name of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho_{di}^2` for each + subject in each dataset. - trace_xtx : dict of dict of float, trace_xtx[d][s] is a float, where + trace_xtx : dict of dict of float, trace_xtx[d][s] is a float, where d is the name of the dataset and s is the name of the subject. The squared Frobenius norm of the demeaned data in `x`. 
""" x = {} - mu = {} rho2 = {} trace_xtx = {} @@ -461,39 +631,19 @@ def _init_structures(self, data, datasets): for ds_idx, ds in datasets.idx_to_dataset.items(): x[ds] = {} for subj in range(datasets.num_subj): - if datasets.dok_matrix[subj,ds_idx] != 0: + if datasets.dok_matrix[subj, ds_idx] != 0: if datasets.built_from_data: - x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.idx_to_subject[subj]] + x[ds][datasets. + idx_to_subject[subj]] =\ + data[ds][datasets.idx_to_subject[subj]] else: - x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.dok_matrix[subj,ds_idx]-1] + x[ds][datasets. + idx_to_subject[subj]] =\ + data[ds][datasets.dok_matrix[subj, ds_idx]-1] del data # compute mean - # collect mean from each MPI worker - weights = {} - mu_tmp = {} - for subj in datasets.subject_to_idx.keys(): - weights[subj], mu_tmp[subj] = {}, {} - for ds in x.keys(): - if subj in x[ds]: - if x[ds][subj] is not None: - mu_tmp[subj][ds] = np.mean(x[ds][subj], 1) - weights[subj][ds] = x[ds][subj].shape[1] - else: - mu_tmp[subj][ds] = np.zeros((self.voxels_[subj],)) - weights[subj][ds] = 0 - # collect mean from all MPI workers - for subj in datasets.subject_to_idx.keys(): - for ds in mu_tmp[subj].keys(): - mu_tmp[subj][ds] = self.comm.allreduce(mu_tmp[subj][ds], op=MPI.SUM) - weights[subj][ds] = self.comm.allreduce(weights[subj][ds], op=MPI.SUM) - # compute final mean - for subj in datasets.subject_to_idx.keys(): - mu[subj] = np.zeros((self.voxels_[subj],)) - nsample = np.sum(list(weights[subj].values())) - for ds in mu_tmp[subj].keys(): - mu[subj] += weights[subj][ds] * mu_tmp[subj][ds] / nsample - del weights, mu_tmp + mu = self._compute_mean(x, datasets) # subtract mean from x and compute trace_xtx, initialize rho2 for ds in x.keys(): @@ -501,13 +651,18 @@ def _init_structures(self, data, datasets): for subj in x[ds].keys(): rho2[ds][subj] = 1 if x[ds][subj] is not None: - x[ds][subj] -= mu[subj][:,None] - trace_xtx[ds][subj] = np.sum(x[ds][subj] ** 2) + x[ds][subj] -= mu[subj][:, None] + trace_xtx[ds][subj] = np.sum(x[ds][subj] ** 2) else: trace_xtx[ds][subj] = 0 - return x, mu, rho2, trace_xtx + # broadcast values in trace_xtx to all ranks + for subj in ds_subj_list.keys(): + for ds in ds_subj_list[subj]: + trace_xtx[ds][subj] = self.comm.allreduce( + trace_xtx[ds][subj], op=MPI.SUM) + return x, mu, rho2, trace_xtx def _likelihood(self, chol_sigma_s_rhos, log_det_psi, chol_sigma_s, trace_xt_invsigma2_x, inv_sigma_s_rhos, wt_invpsi_x, @@ -547,8 +702,8 @@ def _likelihood(self, chol_sigma_s_rhos, log_det_psi, chol_sigma_s, loglikehood : float The log-likelihood value. """ - log_det = (np.log(np.diag(chol_sigma_s_rhos) ** 2).sum() + log_det_psi - + np.log(np.diag(chol_sigma_s) ** 2).sum()) + log_det = (np.log(np.diag(chol_sigma_s_rhos) ** 2).sum() + + log_det_psi + np.log(np.diag(chol_sigma_s) ** 2).sum()) loglikehood = -0.5 * samples * log_det - 0.5 * trace_xt_invsigma2_x loglikehood += 0.5 * np.trace( wt_invpsi_x.T.dot(inv_sigma_s_rhos).dot(wt_invpsi_x)) @@ -581,7 +736,6 @@ def _update_transform_subject(Xi, S): U, _, V = np.linalg.svd(A, full_matrices=False) return U.dot(V) - def transform_subject(self, X, dataset): """Transform a new subject using the existing model. The subject is assumed to have received equivalent stimulation @@ -593,7 +747,7 @@ def transform_subject(self, X, dataset): X : 2D array, shape=[voxels, timepoints] The fMRI data of the new subject. 
- dataset : string, name of the dataset in the fitted model that + dataset : string, name of the dataset in the fitted model that has the same stimulation as the new subject Returns @@ -608,21 +762,325 @@ def transform_subject(self, X, dataset): raise NotFittedError("The model fit has not been run yet.") # Check if the dataset is in the model - if not dataset in self.s_: - raise NotFittedError("Dataset {} is not in the model yet.".format(dataset)) + if dataset not in self.s_: + raise NotFittedError("Dataset {} is not in the model yet." + .format(dataset)) # Check the number of TRs in the subject if X.shape[1] != self.s_[dataset].shape[1]: - raise ValueError("The number of timepoints(TRs) does not match the" - "one in the model.") + raise ValueError("The number of timepoints(TRs) does not match" + " the one in the model.") w = self._update_transform_subject(X, self.s_[dataset]) return w + def _compute_shared_response(self, x, w, shared_response, sigma_s, + rho2, trace_xtx, ds_list, subj_ds_list, + ds_rank, rank): + """Part of E step in MDMS. Update shared response and sigma_s for + each dataset. + + Parameters + ---------- + + x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the + subject. + Demeaned data for each subject. + + w : dict of array, w[s] has shape=[voxels_[s], features] where s is + the name of the subject. + The orthogonal transforms (mappings) :math:`W_s` for each + subject. + + shared_response : dict of array, shared_response[d] has + shape=[features, samples_[d]] where d is the name of the dataset. + The shared response for each dataset. + + sigma_s : dict of array, sigma_s[d] has shape=[features, features] + where d is the name of dataset. + The covariance :math:`\\Sigma_s` of the shared response Normal + distribution for each dataset. + + rho2 : dict of dict of float, rho2[d][s] is a float, where d is the + name of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho2{di}^2` for each + subject in each dataset. + + trace_xtx : dict of dict of float, trace_xtx[d][s] is a float, where + d is the name of the dataset and s is the name of the subject. + The squared Frobenius norm of the demeaned data in `x`. + + ds_list : list of string, names of all datasets + + subj_ds_list : dict of list of string, subj_ds_list[d] is a list + of names of subjects in dataset d, where d is the name + of the dataset. + + ds_rank : set of string, name of datasets assigned to be processed + by this rank. + + rank : int, the current MPI rank + + Returns + ------- + + shared_response : dict of array, shared_response[d] has + shape=[features, samples_[d]] where d is the name of the dataset. + The shared response for each dataset. + + trace_sigma_s : dict of float, trace of sigma_s for each dataset. + + sigma_s : dict of array, sigma_s[d] has shape=[features, features] + where d is the name of dataset. + The covariance :math:`\\Sigma_s` of the shared response Normal + distribution for each dataset. + """ + loglike = 0. 
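+
+        # The block below is the per-dataset E-step.  For every dataset d
+        # owned by this rank: rho0[d] accumulates sum_i 1/rho2[d][i],
+        # Sigma_s and (Sigma_s^-1 + rho0*I) are inverted through their
+        # Cholesky factors, W_i^T X_di / rho2[d][i] is accumulated into
+        # wt_invpsi_x[d], and the shared response becomes
+        # S_d = Sigma_s (I - rho0 * inv_sigma_s_rhos) wt_invpsi_x[d].
+        # Datasets owned by other ranks keep zero placeholders so that the
+        # allreduce sums reconstruct the full values on every rank.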
+ other_ds = set(ds_list) - ds_rank + + # for multi-thread computation + chol_sigma_s = {ds: np.zeros((self.features, self.features)) for ds + in other_ds} + chol_sigma_s_rhos = {ds: np.zeros((self.features, self.features)) + for ds in other_ds} + inv_sigma_s_rhos = {ds: np.zeros((self.features, self.features)) + for ds in other_ds} + rho0 = {ds: 0.0 for ds in other_ds} + wt_invpsi_x = {ds: np.zeros((self.features, self.samples_[ds])) + for ds in ds_list} + trace_xt_invsigma2_x = {ds: 0.0 for ds in ds_list} + trace_sigma_s = {ds: 0 for ds in ds_list} + + # iterate through all ds in this rank + for ds in ds_rank: + # Sum the inverted the rho2 elements for computing W^T * + # Psi^-1 * W + rho0[ds] = np.sum([1/v for v in rho2[ds].values()]) + + # Invert Sigma_s[ds] using Cholesky factorization + (chol_sigma_s[ds], lower_sigma_s) = scipy.linalg.cho_factor( + sigma_s[ds], check_finite=False) + inv_sigma_s = scipy.linalg.cho_solve( + (chol_sigma_s[ds], lower_sigma_s), + np.identity(self.features), + check_finite=False) + + # Invert (Sigma_s[ds] + rho_0 * I) using Cholesky + # factorization + sigma_s_rhos = inv_sigma_s + np.identity(self.features) *\ + rho0[ds] + (chol_sigma_s_rhos[ds], lower_sigma_s_rhos) = \ + scipy.linalg.cho_factor(sigma_s_rhos, check_finite=False) + inv_sigma_s_rhos[ds] = scipy.linalg.cho_solve( + (chol_sigma_s_rhos[ds], lower_sigma_s_rhos), + np.identity(self.features), check_finite=False) + + # collect info from all ranks + chol_sigma_s = {ds: self.comm. + allreduce(chol_sigma_s[ds], op=MPI.SUM) + for ds in ds_list} + chol_sigma_s_rhos = {ds: self.comm. + allreduce(chol_sigma_s_rhos[ds], op=MPI.SUM) + for ds in ds_list} + inv_sigma_s_rhos = {ds: self.comm. + allreduce(inv_sigma_s_rhos[ds], op=MPI.SUM) + for ds in ds_list} + + # Compute the sum of W_i^T * rho_i^-2 * X_i, and the sum of + # traces of X_i^T * rho_i^-2 * X_i + for ds in ds_list: + for subj in subj_ds_list[ds]: + if x[ds][subj] is not None: + wt_invpsi_x[ds] += (w[subj].T.dot(x[ds][subj])) /\ + rho2[ds][subj] + trace_xt_invsigma2_x[ds] += trace_xtx[ds][subj] /\ + rho2[ds][subj] + + # collect data from all ranks + for ds in ds_list: + wt_invpsi_x[ds] = self.comm.allreduce(wt_invpsi_x[ds], + op=MPI.SUM) + trace_xt_invsigma2_x[ds] = self.comm.allreduce( + trace_xt_invsigma2_x[ds], op=MPI.SUM) + + # compute shared response and Sigma_s of ds in this rank + for ds in ds_rank: + log_det_psi = np.sum([np.log(rho2[ds][subj]) * self. 
+ voxels_[subj] for subj + in rho2[ds]]) + + # Update the shared response + shared_response[ds] = sigma_s[ds].dot( + np.identity(self.features) - rho0[ds] * + inv_sigma_s_rhos[ds]).dot( + wt_invpsi_x[ds]) + + # Update Sigma_s and compute its trace + sigma_s[ds] = (inv_sigma_s_rhos[ds] + + shared_response[ds].dot( + shared_response[ds].T) / + self.samples_[ds]) + trace_sigma_s[ds] = self.samples_[ds] *\ + np.trace(sigma_s[ds]) + + # calculate log likelihood to check convergence + loglike += self._likelihood( + chol_sigma_s_rhos[ds], log_det_psi, chol_sigma_s[ds], + trace_xt_invsigma2_x[ds], inv_sigma_s_rhos[ds], + wt_invpsi_x[ds], self.samples_[ds]) + + for ds in other_ds: + shared_response[ds] = np.zeros((self.features, + self.samples_[ds])) + sigma_s[ds] = np.zeros((self.features, self.features)) + trace_sigma_s[ds] = 0 + + # collect parameters from all ranks + for ds in ds_list: + shared_response[ds] = self.comm.allreduce( + shared_response[ds], op=MPI.SUM) + trace_sigma_s[ds] = self.comm.allreduce( + trace_sigma_s[ds], op=MPI.SUM) + sigma_s[ds] = self.comm.allreduce(sigma_s[ds], op=MPI.SUM) + + # collect loglikelihood + loglike = self.comm.allreduce(loglike, op=MPI.SUM) + if rank == 0 and self.logger.isEnabledFor(logging.INFO): + self.logger.info('Objective function %f' % loglike) + + return shared_response, trace_sigma_s, sigma_s + + def _compute_w(self, x, shared_response, ds_subj_list, rank): + """Compute transformation matrix W for each subject. + + Parameters + ---------- + + x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the + subject. + Demeaned data for each subject. + + shared_response : dict of array, shared_response[d] has + shape=[features, samples_[d]] where d is the name of the dataset. + The shared response for each dataset. + + ds_subj_list : dict of list of string, ds_subj_list[s] is a list + of names of datasets with subject s, where s is the name + of the subject. + + rank : int, the current MPI rank + + Returns + ------- + + w : dict of array, w[s] has shape=[voxels_[s], features] where s is + the name of the subject. + The orthogonal transforms (mappings) :math:`W_s` for each + subject. + """ + w = {} + for subj in ds_subj_list.keys(): + # update w + a_subject = np.zeros((self.voxels_[subj], self.features)) + # use x data from all ranks + for ds in ds_subj_list[subj]: + if x[ds][subj] is not None: + a_subject += x[ds][subj].dot(shared_response[ds].T) + # collect a_subject from all ranks + a_subject = self.comm.allreduce(a_subject, op=MPI.SUM) + # compute w in one rank and broadcast + if rank == 0: + perturbation = np.zeros(a_subject.shape) + np.fill_diagonal(perturbation, 0.0001) + u_subject, _, v_subject = np.linalg.svd( + a_subject + perturbation, full_matrices=False) + w[subj] = u_subject.dot(v_subject) + else: + w[subj] = None + w[subj] = self.comm.bcast(w[subj], root=0) + return w + + def _compute_rho2(self, x, shared_response, w, ds_subj_list, ds_list, + trace_xtx, trace_sigma_s, rank): + """Compute the estimated noise variance for each subject in each + dataset. + + Parameters + ---------- + + x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the + subject. + Demeaned data for each subject. + + shared_response : dict of array, shared_response[d] has + shape=[features, samples_[d]] where d is the name of the dataset. + The shared response for each dataset. 
+ + w : dict of array, w[s] has shape=[voxels_[s], features] where s is + the name of the subject. + The orthogonal transforms (mappings) :math:`W_s` for each + subject. + + ds_subj_list : dict of list of string, ds_subj_list[s] is a list + of names of datasets with subject s, where s is the name + of the subject. + + ds_list : list of string, names of all datasets + + trace_xtx : dict of dict of float, trace_xtx[d][s] is a float, where + d is the name of the dataset and s is the name of the subject. + The squared Frobenius norm of the demeaned data in `x`. + + trace_sigma_s : dict of int, trace of sigma_s for each dataset + + rank : int, the current MPI rank + + Returns + ------- + + rho2 : dict of dict of float, rho2_[d][s] is a float, where d is the + name of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho_{di}^2` for each + subject in each dataset. + """ + # update rho2 + rho2 = {d: {} for d in ds_list} + for subj in ds_subj_list.keys(): + # compute trace_xtws_tmp of data in this rank + trace_xtws_tmp = {} + for ds in ds_subj_list[subj]: + if x[ds][subj] is not None: + trace_xtws_tmp[ds] = np.trace(x[ds][subj].T.dot( + w[subj]).dot(shared_response[ds])) + else: + trace_xtws_tmp[ds] = 0.0 + # collect trace_xtws_tmp in all ranks + for ds in ds_subj_list[subj]: + trace_xtws_tmp[ds] = self.comm.allreduce( + trace_xtws_tmp[ds], op=MPI.SUM) + # compute rho2 + for ds in ds_subj_list[subj]: + if rank == 0: + rho2[ds][subj] = trace_xtx[ds][subj] + rho2[ds][subj] += -2 * trace_xtws_tmp[ds] + rho2[ds][subj] += trace_sigma_s[ds] + rho2[ds][subj] /= self.samples_[ds] *\ + self.voxels_[subj] + else: + rho2[ds][subj] = None + # broadcast to all ranks + for ds in ds_subj_list[subj]: + rho2[ds][subj] = self.comm.bcast(rho2[ds][subj], root=0) + + return rho2 def _mdms(self, data, datasets): - """Expectation-Maximization algorithm for fitting the probabilistic MDMS. + """Expectation-Maximization algorithm for fitting the probabilistic + MDMS. Parameters ---------- @@ -630,55 +1088,57 @@ def _mdms(self, data, datasets): data : dict of list of 2D arrays or dict of dict of 2D arrays 1) When it is a dict of list of 2D arrays: 'datasets' must be defined in this case. - X[d] is a list of data of dataset d, where d is the name of the dataset. - Element i in the list has shape=[voxels_i, samples_d] + X[d] is a list of data of dataset d, where d is the name of + the dataset. + Element i in the list has shape=[voxels_i, samples_d] which is the fMRI data of the i'th subject in d. 2) When it is a dict of dict of 2D arrays: 'datasets' can be omitted in this case. - X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of - subject s in dataset d, where s is the name of the subject and - d is the name of the dataset. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI + data of subject s in dataset d, where s is the name of the + subject and d is the name of the dataset. datasets : a Dataset object The Dataset object containing datasets structures. - Returns ------- - sigma_s : dict of array, sigma_s[d] has shape=[features, features] where - d is the name of dataset. + sigma_s : dict of array, sigma_s[d] has shape=[features, features] + where d is the name of dataset. The covariance :math:`\\Sigma_s` of the shared response Normal distribution for each dataset. - w : dict of array, w[s] has shape=[voxels_[s], features] where s is the name - of the subject. - The orthogonal transforms (mappings) :math:`W_s` for each subject. 
+ w : dict of array, w[s] has shape=[voxels_[s], features] where s is + the name of the subject. + The orthogonal transforms (mappings) :math:`W_s` for each + subject. - mu : dict of array, mu[s] has shape=[voxels_[s]] where s is the name + mu : dict of array, mu[s] has shape=[voxels_[s]] where s is the name of the subject. - The voxel means :math:`\\mu_i` over the samples in all datasets + The voxel means :math:`\\mu_i` over the samples in all datasets for each subject. - rho2 : dict of dict of float, rho2[d][s] is a float, where d is the name - of the dataset and s is the name of the subject. - The estimated noise variance :math:`\\rho2{di}^2` for each subject - in each dataset. + rho2 : dict of dict of float, rho2[d][s] is a float, where d is the + name of the dataset and s is the name of the subject. + The estimated noise variance :math:`\\rho2{di}^2` for each + subject in each dataset. - s : dict of array, s[d] has shape=[features, samples_[d]] where d is the - name of the dataset. + s : dict of array, s[d] has shape=[features, samples_[d]] where d is + the name of the dataset. The shared response for each dataset. """ # get information from datasets structures - ds_list, subj_list = datasets.get_datasets_list(), datasets.get_subjects_list() + ds_list, subj_list = datasets.get_datasets_list(),\ + datasets.get_subjects_list() subj_ds_list = datasets.subjects_in_dataset_all() ds_subj_list = datasets.datasets_with_subject_all() # initialize random states self.random_state_ = np.random.RandomState(self.rand_seed) - random_states = { - subj_list[i] : np.random.RandomState(self.random_state_.randint(2 ** 32)) - for i in range(datasets.num_subj)} + random_states = {subj_list[i]: np.random.RandomState( + self.random_state_.randint(2 ** 32)) + for i in range(datasets.num_subj)} # assign ds to different ranks for parallel computing rank = self.comm.Get_rank() @@ -691,177 +1151,49 @@ def _mdms(self, data, datasets): else: ds_rank_len = datasets.num_dataset // size if rank != size - 1: - ds_rank.update(set(ds_list[ds_rank_len*rank:ds_rank_len*(rank+1)])) + ds_rank.update(set(ds_list[ds_rank_len*rank: + ds_rank_len*(rank+1)])) else: ds_rank.update(set(ds_list[ds_rank_len*rank:])) - - # Initialization step: initialize the outputs with initial values + # Initialization step: initialize the outputs with initial values # and trace_xtx with the ||X_i||_F^2 of each subject in each dataset. 
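+        # trace_xtx caches these norms once (allreduced inside
+        # _init_structures) so the rho2 update in _compute_rho2 never
+        # needs another pass over the raw data.  After initialization,
+        # each EM iteration below alternates _compute_shared_response
+        # (the E-step plus the sigma_s part of the M-step, parallelized
+        # over the datasets in ds_rank) with _compute_w and _compute_rho2
+        # (the rest of the M-step, synchronized with allreduce/bcast).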
w = _init_w_transforms(self.voxels_, self.features, random_states, - datasets) - x, mu, rho2, trace_xtx = self._init_structures(data, datasets) + datasets) + x, mu, rho2, trace_xtx = self._init_structures(data, datasets, + ds_subj_list) del data - # broadcast values in trace_xtx to all ranks - for subj in subj_list: - for ds in ds_subj_list[subj]: - trace_xtx[ds][subj] = self.comm.allreduce(trace_xtx[ds][subj], op=MPI.SUM) - shared_response, sigma_s, rho0 = {}, {}, {} + shared_response, sigma_s = {}, {} for ds in ds_list: - shared_response[ds] = np.zeros((self.features, self.samples_[ds])) + shared_response[ds] = np.zeros((self.features, + self.samples_[ds])) if ds in ds_rank: sigma_s[ds] = np.identity(self.features) - rho0[ds] = 0.0 else: sigma_s[ds] = np.zeros((self.features, self.features)) + # Main loop of the algorithm (run) + for iteration in range(self.n_iter): + if rank == 0: + self.logger.info('Iteration %d' % (iteration + 1)) - # Main loop of the algorithm (run) - for iteration in range(self.n_iter): - if rank == 0: - self.logger.info('Iteration %d' % (iteration + 1)) - - # E-step and some M-step: update shared_response and sigma_s of each dataset - loglike = 0. - - # for multi-thread computation - chol_sigma_s, chol_sigma_s_rhos, inv_sigma_s_rhos = {}, {}, {} - wt_invpsi_x, trace_xt_invsigma2_x, trace_sigma_s = {}, {}, {} - for ds in ds_list: - chol_sigma_s[ds] = np.zeros((self.features, self.features)) - chol_sigma_s_rhos[ds] = np.zeros((self.features, self.features)) - inv_sigma_s_rhos[ds] = np.zeros((self.features, self.features)) - wt_invpsi_x[ds] = np.zeros((self.features, self.samples_[ds])) - trace_xt_invsigma2_x[ds] = 0.0 - trace_sigma_s[ds] = 0 - - # iterate through all ds in this rank - for ds in ds_rank: - # Sum the inverted the rho2 elements for computing W^T * Psi^-1 * W - rho0[ds] = np.sum([1/v for v in rho2[ds].values()]) - - # Invert Sigma_s[ds] using Cholesky factorization - (chol_sigma_s[ds], lower_sigma_s) = scipy.linalg.cho_factor( - sigma_s[ds], check_finite=False) - inv_sigma_s = scipy.linalg.cho_solve( - (chol_sigma_s[ds], lower_sigma_s), np.identity(self.features), - check_finite=False) - - # Invert (Sigma_s[ds] + rho_0 * I) using Cholesky factorization - sigma_s_rhos = inv_sigma_s + np.identity(self.features) * rho0[ds] - (chol_sigma_s_rhos[ds], lower_sigma_s_rhos) = \ - scipy.linalg.cho_factor(sigma_s_rhos, check_finite=False) - inv_sigma_s_rhos[ds] = scipy.linalg.cho_solve( - (chol_sigma_s_rhos[ds], lower_sigma_s_rhos), - np.identity(self.features), check_finite=False) - - # collect info from all ranks - for ds in ds_list: - chol_sigma_s[ds] = self.comm.allreduce(chol_sigma_s[ds], op=MPI.SUM) - chol_sigma_s_rhos[ds] = self.comm.allreduce(chol_sigma_s_rhos[ds], op=MPI.SUM) - inv_sigma_s_rhos[ds] = self.comm.allreduce(inv_sigma_s_rhos[ds], op=MPI.SUM) - - # Compute the sum of W_i^T * rho_i^-2 * X_i, and the sum of traces - # of X_i^T * rho_i^-2 * X_i - for ds in ds_list: - for subj in subj_ds_list[ds]: - if x[ds][subj] is not None: - wt_invpsi_x[ds] += (w[subj].T.dot(x[ds][subj])) / rho2[ds][subj] - trace_xt_invsigma2_x[ds] += trace_xtx[ds][subj] / rho2[ds][subj] - - # collect data from all ranks - for ds in ds_list: - wt_invpsi_x[ds] = self.comm.allreduce(wt_invpsi_x[ds], op=MPI.SUM) - trace_xt_invsigma2_x[ds] = self.comm.allreduce(trace_xt_invsigma2_x[ds], - op=MPI.SUM) - - # compute shared response and Sigma_s of ds in this rank - for ds in ds_list: - if ds in ds_rank: - log_det_psi = np.sum([np.log(rho2[ds][subj])*self.voxels_[subj] for subj in 
rho2[ds]]) - - # Update the shared response - shared_response[ds] = sigma_s[ds].dot( - np.identity(self.features) - rho0[ds] * inv_sigma_s_rhos[ds]).dot( - wt_invpsi_x[ds]) - - # Update Sigma_s and compute its trace - sigma_s[ds] = (inv_sigma_s_rhos[ds] - + shared_response[ds].dot(shared_response[ds].T) / self.samples_[ds]) - trace_sigma_s[ds] = self.samples_[ds] * np.trace(sigma_s[ds]) - - # calculate log likelihood to check convergence - loglike += self._likelihood( - chol_sigma_s_rhos[ds], log_det_psi, chol_sigma_s[ds], - trace_xt_invsigma2_x[ds], inv_sigma_s_rhos[ds], wt_invpsi_x[ds], - self.samples_[ds]) - - else: - shared_response[ds] = np.zeros((self.features, self.samples_[ds])) - sigma_s[ds] = np.zeros((self.features, self.features)) - trace_sigma_s[ds] = 0 - - # collect parameters from all ranks - for ds in ds_list: - shared_response[ds] = self.comm.allreduce(shared_response[ds], op=MPI.SUM) - trace_sigma_s[ds] = self.comm.allreduce(trace_sigma_s[ds], op=MPI.SUM) - sigma_s[ds] = self.comm.allreduce(sigma_s[ds], op=MPI.SUM) - + # E-step and some M-step: update shared_response and sigma_s of + # each dataset + shared_response, trace_sigma_s, sigma_s = self.\ + _compute_shared_response(x, w, shared_response, sigma_s, + rho2, trace_xtx, ds_list, + subj_ds_list, ds_rank, rank) # The rest of M-step: update w and rho2 # Update each subject's mapping transform W_i and error variance # rho_di^2 - for subj in subj_list: - # update w - a_subject = np.zeros((self.voxels_[subj], self.features)) - # use x data from all ranks - for ds in ds_subj_list[subj]: - if x[ds][subj] is not None: - a_subject += x[ds][subj].dot(shared_response[ds].T) - # collect a_subject from all ranks - a_subject = self.comm.allreduce(a_subject, op=MPI.SUM) - # compute w in one rank and broadcast - if rank == 0: - perturbation = np.zeros(a_subject.shape) - np.fill_diagonal(perturbation, 0.0001) - u_subject, _, v_subject = np.linalg.svd( - a_subject + perturbation, full_matrices=False) - w[subj] = u_subject.dot(v_subject) - else: - w[subj] = None - w[subj] = self.comm.bcast(w[subj], root=0) - # update rho2 - # compute trace_xtws_tmp of data in this rank - trace_xtws_tmp = {} - for ds in ds_subj_list[subj]: - if x[ds][subj] is not None: - trace_xtws_tmp[ds] = np.trace(x[ds][subj].T.dot(w[subj]).dot(shared_response[ds])) - else: - trace_xtws_tmp[ds] = 0.0 - # collect trace_xtws_tmp in all ranks - for ds in ds_subj_list[subj]: - trace_xtws_tmp[ds] = self.comm.allreduce(trace_xtws_tmp[ds], op=MPI.SUM) - # compute rho2 - if rank == 0: - for ds in ds_subj_list[subj]: - rho2[ds][subj] = trace_xtx[ds][subj] - rho2[ds][subj] += -2 * trace_xtws_tmp[ds] - rho2[ds][subj] += trace_sigma_s[ds] - rho2[ds][subj] /= self.samples_[ds] * self.voxels_[subj] - # broadcast to all ranks - for ds in ds_subj_list[subj]: - rho2[ds][subj] = self.comm.bcast(rho2[ds][subj], root=0) - - - # collect loglikelihood - loglike = self.comm.allreduce(loglike, op=MPI.SUM) - if rank == 0: - if self.logger.isEnabledFor(logging.INFO): - self.logger.info('Objective function %f' % loglike) - - return sigma_s, w, mu, rho2, shared_response + w = self._compute_w(x, shared_response, ds_subj_list, rank) + rho2 = self._compute_rho2(x, shared_response, w, ds_subj_list, + ds_list, trace_xtx, trace_sigma_s, + rank) + return sigma_s, w, mu, rho2, shared_response def save(self, file): """Save the MDMS object to a file (as pickle) @@ -869,7 +1201,7 @@ def save(self, file): Parameters ---------- - file : The name (including full path) of the file that the object + file : The 
name (including full path) of the file that the object will be saved to. Returns @@ -886,7 +1218,7 @@ def save(self, file): """ # get attributes from object variables = self.__dict__.keys() - data = {k:getattr(self, k) for k in variables} + data = {k: getattr(self, k) for k in variables} # remove attributes that cannot be pickled del data['comm'] del data['logger'] @@ -895,9 +1227,8 @@ def save(self, file): # save attributes to file with open(file, 'wb') as f: pkl.dump(data, f, pkl.HIGHEST_PROTOCOL) - print ('MDMS object saved to {}.'.format(file)) - return - + self.logger.info('MDMS object saved to {}.'.format(file)) + return def restore(self, file): """Restore the MDMS object from a (pickle) file @@ -905,7 +1236,7 @@ def restore(self, file): Parameters ---------- - file : The name (including full path) of the file that the object + file : The name (including full path) of the file that the object will be restored from. Returns @@ -916,7 +1247,7 @@ def restore(self, file): Note ---- - The MPI communicator cannot be saved, so self.comm is initialized to + The MPI communicator cannot be saved, so self.comm is initialized to MPI.COMM_SELF """ @@ -930,17 +1261,20 @@ def restore(self, file): self.comm = MPI.COMM_SELF self.random_state_ = np.random.RandomState(self.rand_seed) self.logger = logger - print ('MDMS object restored from {}.'.format(file)) + self.logger.info('MDMS object restored from {}.'.format(file)) return class DetMDMS(BaseEstimator, TransformerMixin): - """Deterministic multi-dataset multi-subject (MDMS) SRM analysis (DetMDMS) + """Deterministic multi-dataset multi-subject (MDMS) SRM analysis + (DetMDMS) - Given multi-dataset multi-subject data, factorize it as a shared response S among all - subjects per dataset and an orthogonal transform W across all datasets per subject: + Given multi-dataset multi-subject data, factorize it as a shared + response S among all subjects per dataset and an orthogonal transform W + across all datasets per subject: - .. math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\dots M + .. math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\ + dots M Parameters ---------- @@ -970,13 +1304,13 @@ class DetMDMS(BaseEstimator, TransformerMixin): voxels_ : dict of int, voxels_[s] is number of voxels where s is the name of the subject. - A dict with the number of voxels for each subject. + A dict with the number of voxels for each subject. - samples_ : dict of int, samples_[d] is number of samples where d is the name - of the dataset. + samples_ : dict of int, samples_[d] is number of samples where d is the + name of the dataset. A dict with the number of samples for each dataset. - mu_ : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name + mu_ : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name of the subject. The voxel means over the samples in all datasets for each subject. @@ -989,23 +1323,26 @@ class DetMDMS(BaseEstimator, TransformerMixin): Note ---- - The number of voxels may be different between subjects within a dataset - and number of samples may be different between datasets. However, the - number of samples must be the same across subjects within a dataset and - number of voxels must be the same across datasets for the same subject. - - The probabilistic multi-dataset multi-subject model is approximated using the - Block Coordinate Descent (BCD) algorithm proposed in [Zhang2018]_. 
- - The run-time complexity is :math:`O(I (V T K + V K^2))` and the memory - complexity is :math:`O(V T)` with I - the number of iterations, V - the - sum of number of voxels from all subjects, T - the sum of number of - samples from all datasets, K - the number of features (typically, + The number of voxels may be different between subjects within a + dataset and number of samples may be different between datasets. + However, the number of samples must be the same across subjects + within a dataset and number of voxels must be the same across + datasets for the same subject. + + The probabilistic multi-dataset multi-subject model is approximated + using the Block Coordinate Descent (BCD) algorithm proposed in + [Zhang2018]_. + + The run-time complexity is :math:`O(I (V T K + V K^2))` and the + memory complexity is :math:`O(V T)` with I - the number of + iterations, V - the sum of number of voxels from all subjects, T - + the sum of number of samples from all datasets, K - the number of + features (typically, :math:`V \\gg T \\gg K`), and N - the number of subjects. """ def __init__(self, n_iter=10, features=50, rand_seed=0, - comm=MPI.COMM_SELF): + comm=MPI.COMM_SELF): self.n_iter = n_iter self.features = features self.rand_seed = rand_seed @@ -1013,7 +1350,6 @@ def __init__(self, n_iter=10, features=50, rand_seed=0, self.logger = logger return - def fit(self, X, datasets=None, demean=True, y=None): """Compute the Deterministic Shared Response Model @@ -1023,14 +1359,15 @@ def fit(self, X, datasets=None, demean=True, y=None): X : dict of list of 2D arrays or dict of dict of 2D arrays 1) When it is a dict of list of 2D arrays: 'datasets' must be defined in this case. - X[d] is a list of data of dataset d, where d is the name of the dataset. - Element i in the list has shape=[voxels_i, samples_d] + X[d] is a list of data of dataset d, where d is the name of + the dataset. + Element i in the list has shape=[voxels_i, samples_d] which is the fMRI data of the i'th subject in d. 2) When it is a dict of dict of 2D arrays: 'datasets' can be omitted in this case. - X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of - subject s in dataset d, where s is the name of the subject and - d is the name of the dataset. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI + data of subject s in dataset d, where s is the name of the + subject and d is the name of the dataset. datasets : (optional) a Dataset object The Dataset object containing datasets structure. @@ -1054,15 +1391,20 @@ def fit(self, X, datasets=None, demean=True, y=None): raise Exception('X should be a dict.') format_X = type(next(iter(X.values()))) if format_X != dict and format_X != list: - raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') - if format_X == list and (datasets is None or datasets.built_from_data is None or datasets.built_from_data): - raise Exception("Argument 'datasets' must be defined and built from json " - "files when X is a dict of list of 2D arrays. ") + raise Exception('X should be a dict of dict of arrays or dict of' + ' list of arrays.') + if format_X == list and (datasets is None or + datasets.built_from_data is None or + datasets.built_from_data): + raise Exception("Argument 'datasets' must be defined and built " + "from json files when X is a dict of list of 2D " + "arrays. 
") if format_X == dict and datasets is not None: datasets.built_from_data = True for v in X.values(): if type(v) != format_X: - raise Exception('X should be a dict of dict of arrays or dict of list of arrays.') + raise Exception('X should be a dict of dict of arrays or ' + 'dict of list of arrays.') # Infer datasets structure from data if datasets is None: @@ -1076,7 +1418,6 @@ def fit(self, X, datasets=None, demean=True, y=None): return self - def transform(self, X, subjects, centered=True, y=None): """Use the model to transform new data to Shared Response space @@ -1084,17 +1425,17 @@ def transform(self, X, subjects, centered=True, y=None): ---------- X : list of 2D arrays, element i has shape=[voxels_i, samples_i] - Each element in the list contains the new fMRI data of one subject + Each element in the list contains the new fMRI data of one + subject. subjects : list of string, element i is the name of subject of X[i] centered : (optional) bool, if the data in X is already centered. - If centered = False, the voxel means computed during mode fitting - will be subtracted before transformation. + If centered = False, the voxel means computed during mode + fitting will be subtracted before transformation. y : not used (as it is unsupervised learning) - Returns ------- @@ -1110,48 +1451,108 @@ def transform(self, X, subjects, centered=True, y=None): if not hasattr(self, 'w_'): raise NotFittedError("The model fit has not been run yet.") - # Check if the subject exist in the fitted model and has the right number of voxels + # Check if the subject exist in the fitted model and has the right + # number of voxels for idx in range(len(X)): if not subjects[idx] in self.w_: - raise NotFittedError("The model has not been fitted to subject {}.".format(subjects[idx])) - if X[idx] is not None and self.w_[subjects[idx]].shape[0] != X[idx].shape[0]: - raise ValueError("{}-th element of data has inconsistent number of" - "voxels with fitted model. Model has {} voxels while data has {}" - ".".format(idx, self.w_[subjects[idx]].shape[0], X[idx].shape[0])) + raise NotFittedError("The model has not been fitted to " + "subject {}.".format(subjects[idx])) + if X[idx] is not None and (self.w_[subjects[idx]].shape[0] != + X[idx].shape[0]): + raise ValueError("{}-th element of data has inconsistent " + "number of voxels with fitted model. Model " + "has {} voxels while data has {}." + .format(idx, self.w_[subjects[idx]]. + shape[0], X[idx].shape[0])) if not centered and self.mu_ is None: - raise Exception('Mean values are not computed during model fitting. ' - 'Please center the data to be transformed beforehand.') - + raise Exception('Mean values are not computed during model ' + 'fitting. Please center the data to be ' + 'transformed beforehand.') s = [None] * len(X) for idx in range(len(X)): if X[idx] is not None: if centered: s[idx] = self.w_[subjects[idx]].T.dot(X[idx]) - else: - s[idx] = self.w_[subjects[idx]].T.dot(X[idx]-self.mu_[subjects[idx]][:, None]) + else: + s[idx] = self.w_[subjects[idx]].T.\ + dot(X[idx] - self.mu_ + [subjects[idx]][:, None]) return s + def _compute_mean(self, x, datasets): + """Compute the mean of data. + + Parameters + ---------- + x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] + where d is the name of the dataset and s is the name of the + subject. + Demeaned data for each subject. + + datasets : a Dataset object + The Dataset object containing datasets structures. 
+ + Returns + ------- + + mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the + name of the subject. + The voxel means over the samples in all datasets for each + subject. + """ + # collect mean from each MPI worker + weights = {} + mu_tmp = {} + for subj in datasets.subject_to_idx.keys(): + weights[subj], mu_tmp[subj] = {}, {} + for ds in x.keys(): + if subj in x[ds]: + if x[ds][subj] is not None: + mu_tmp[subj][ds] = np.mean(x[ds][subj], 1) + weights[subj][ds] = x[ds][subj].shape[1] + else: + mu_tmp[subj][ds] = np.zeros(( + self.voxels_[subj],)) + weights[subj][ds] = 0 + + # collect mean from all MPI workers + for subj in datasets.subject_to_idx.keys(): + for ds in mu_tmp[subj].keys(): + mu_tmp[subj][ds] = self.comm.allreduce( + mu_tmp[subj][ds], op=MPI.SUM) + weights[subj][ds] = self.comm.allreduce( + weights[subj][ds], op=MPI.SUM) + + # compute final mean + mu = {} + for subj in datasets.subject_to_idx.keys(): + mu[subj] = np.zeros((self.voxels_[subj],)) + nsample = np.sum(list(weights[subj].values())) + for ds in mu_tmp[subj].keys(): + mu[subj] += weights[subj][ds] * mu_tmp[subj][ds] /\ + nsample + return mu def _preprocess_data(self, data, datasets, demean): """Preprocess and demean the data. - Parameters ---------- data : dict of list of 2D arrays or dict of dict of 2D arrays 1) When it is a dict of list of 2D arrays: 'datasets' must be defined in this case. - X[d] is a list of data of dataset d, where d is the name of the dataset. - Element i in the list has shape=[voxels_i, samples_d] + X[d] is a list of data of dataset d, where d is the name of + the dataset. + Element i in the list has shape=[voxels_i, samples_d] which is the fMRI data of the i'th subject in d. 2) When it is a dict of dict of 2D arrays: 'datasets' can be omitted in this case. - X[d][s] has shape=[voxels_s, samples_d], which is the fMRI data of - subject s in dataset d, where s is the name of the subject and - d is the name of the dataset. + X[d][s] has shape=[voxels_s, samples_d], which is the fMRI + data of subject s in dataset d, where s is the name of the + subject and d is the name of the dataset. datasets : a Dataset object The Dataset object containing datasets structures. @@ -1163,78 +1564,55 @@ def _preprocess_data(self, data, datasets, demean): Returns ------- x : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] - where d is the name of the dataset and s is the name of the subject + where d is the name of the dataset and s is the name of the + subject. Demeaned data for each subject. - mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the name - of the subject. - The voxel means over the samples in all datasets for each subject. + mu : dict of array, mu_[s] has shape=[voxels_[s]] where s is the + name of the subject. + The voxel means over the samples in all datasets for each + subject. 
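+
+        Note
+        ----
+
+        If `demean` is False the data are returned unchanged and `mu` is
+        None; `transform` can then only be applied to data that is
+        already centered (centered=True).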
""" x = {} - mu = {} # re-arrange data to x for ds_idx, ds in datasets.idx_to_dataset.items(): x[ds] = {} for subj in range(datasets.num_subj): - if datasets.dok_matrix[subj,ds_idx] != 0: + if datasets.dok_matrix[subj, ds_idx] != 0: if datasets.built_from_data: - x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.idx_to_subject[subj]] + x[ds][datasets.idx_to_subject[subj]] = \ + data[ds][datasets.idx_to_subject[subj]] else: - x[ds][datasets.idx_to_subject[subj]] = data[ds][datasets.dok_matrix[subj,ds_idx]-1] + x[ds][datasets.idx_to_subject[subj]] = \ + data[ds][datasets.dok_matrix[subj, ds_idx]-1] del data # compute mean if demean: - # collect mean from each MPI worker - weights = {} - mu_tmp = {} - for subj in datasets.subject_to_idx.keys(): - weights[subj], mu_tmp[subj] = {}, {} - for ds in x.keys(): - if subj in x[ds]: - if x[ds][subj] is not None: - mu_tmp[subj][ds] = np.mean(x[ds][subj], 1) - weights[subj][ds] = x[ds][subj].shape[1] - else: - mu_tmp[subj][ds] = np.zeros((self.voxels_[subj],)) - weights[subj][ds] = 0 - # collect mean from all MPI workers - for subj in datasets.subject_to_idx.keys(): - for ds in mu_tmp[subj].keys(): - mu_tmp[subj][ds] = self.comm.allreduce(mu_tmp[subj][ds], op=MPI.SUM) - weights[subj][ds] = self.comm.allreduce(weights[subj][ds], op=MPI.SUM) - # compute final mean - for subj in datasets.subject_to_idx.keys(): - mu[subj] = np.zeros((self.voxels_[subj],)) - nsample = np.sum(list(weights[subj].values())) - for ds in mu_tmp[subj].keys(): - mu[subj] += weights[subj][ds] * mu_tmp[subj][ds] / nsample - del weights, mu_tmp - + mu = self._compute_mean(x, datasets) # subtract mean from x for ds in x.keys(): for subj in x[ds].keys(): if x[ds][subj] is not None: - x[ds][subj] -= mu[subj][:,None] - + x[ds][subj] -= mu[subj][:, None] else: mu = None return x, mu - def _objective_function(self, data, subj_ds_list, w, s, num_sample): """Calculate the objective function Parameters ---------- - data : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] - where d is the name of the dataset and s is the name of the subject + data : dict of dict of array, x[d][s] has shape=[voxels[s], + samples[d]] where d is the name of the dataset and s is the name + of the subject. Demeaned data for each subject. - subj_ds_list : dict of list of string, subj_ds_list[d] is a list + subj_ds_list : dict of list of string, subj_ds_list[d] is a list of names of subjects in dataset d, where d is the name of the subject. @@ -1246,7 +1624,8 @@ def _objective_function(self, data, subj_ds_list, w, s, num_sample): d is the name of the dataset. The shared response for each dataset. - num_sample : int, total number of samples across all datasets and datasets + num_sample : int, total number of samples across all datasets and + datasets Returns ------- @@ -1257,18 +1636,18 @@ def _objective_function(self, data, subj_ds_list, w, s, num_sample): Note ---- - In the multi nodes mode where data is scattered in different nodes, + In the multi nodes mode where data is scattered in different nodes, objective needs to be reduced (summed) afterwards. 
""" objective = 0.0 for ds in subj_ds_list.keys(): for subj in subj_ds_list[ds]: if data[ds][subj] is not None: - objective += \ - np.linalg.norm(data[ds][subj] - w[subj].dot(s[ds]), 'fro') ** 2 + objective += np.linalg.norm(data[ds][subj] - + w[subj].dot(s[ds]), + 'fro') ** 2 return 0.5 * objective / num_sample - def _compute_shared_response(self, data, subj_ds_list, w): """ Compute the shared response S of all datasets @@ -1276,13 +1655,14 @@ def _compute_shared_response(self, data, subj_ds_list, w): Parameters ---------- - data : dict of dict of array, x[d][s] has shape=[voxels[s], samples[d]] - where d is the name of the dataset and s is the name of the subject + data : dict of dict of array, data[d][s] has shape=[voxels[s], + samples[d]] where d is the name of the dataset and s is the name + of the subject Demeaned data for each subject. - subj_ds_list : dict of list of string, subj_ds_list[d] is a list + subj_ds_list : dict of list of string, subj_ds_list[d] is a list of names of subjects in dataset d, where d is the name - of the subject. + of the dataset. w : dict of array, w[s] has shape=[voxels_[s], features] where s is the name of the subject. @@ -1298,21 +1678,20 @@ def _compute_shared_response(self, data, subj_ds_list, w): Note ---- - In the multi nodes mode where data is scattered in different nodes, + In the multi nodes mode where data is scattered in different nodes, s needs to be gathered afterwards. - To get the final s, the returned s[d] needs to be devided by number + To get the final s, the returned s[d] needs to be devided by number of subjects in dataset d. """ s = {} - for ds in subj_ds_list.keys(): + for ds in subj_ds_list.keys(): s[ds] = np.zeros((self.features, self.samples_[ds])) for subj in subj_ds_list[ds]: if data[ds][subj] is not None: s[ds] += w[subj].T.dot(data[ds][subj]) return s - @staticmethod def _update_transform_subject(Xi, S): """Updates the mappings `W_i` for one subject. @@ -1337,7 +1716,6 @@ def _update_transform_subject(Xi, S): U, _, V = np.linalg.svd(A, full_matrices=False) return U.dot(V) - def transform_subject(self, X, dataset): """Transform a new subject using the existing model. The subject is assumed to have received equivalent stimulation @@ -1349,7 +1727,7 @@ def transform_subject(self, X, dataset): X : 2D array, shape=[voxels, timepoints] The fMRI data of the new subject. - dataset : string, name of the dataset in the fitted model that + dataset : string, name of the dataset in the fitted model that has the same stimulation as the new subject Returns @@ -1364,34 +1742,86 @@ def transform_subject(self, X, dataset): raise NotFittedError("The model fit has not been run yet.") # Check if the dataset is in the model - if not dataset in self.s_: - raise NotFittedError("Dataset {} is not in the model yet.".format(dataset)) + if dataset not in self.s_: + raise NotFittedError("Dataset {} is not in the model yet." 
+ .format(dataset)) # Check the number of TRs in the subject if X.shape[1] != self.s_[dataset].shape[1]: - raise ValueError("The number of timepoints(TRs) does not match the" - "one in the model.") + raise ValueError("The number of timepoints(TRs) does not match " + "the one in the model.") w = self._update_transform_subject(X, self.s_[dataset]) return w + def _compute_w_subj(self, x, ds_subj_list, shared_response, rank): + """ Compute the transformation matrix W of all subjects + + Parameters + ---------- + + x : dict of dict of array, x[d][s] has shape=[voxels[s], + samples[d]] where d is the name of the dataset and s is the name + of the subject + Demeaned data for each subject. + + ds_subj_list : dict of list of string, ds_subj_list[s] is a list + of names of datasets with subject s, where s is the name + of the subject. + + shared_response : dict of array, shared_response[d] has + shape=[features, samples_[d]] where d is the name of the dataset. + The shared response for each dataset. + + rank: int, current MPI rank + + Returns + ------- + + w : dict of array, w[d] has shape=[voxels_[s], features] where + s is the name of the subject. + The transformation matrix for each subject. + + """ + w = {} + for subj in ds_subj_list.keys(): + a_subject = np.zeros((self.voxels_[subj], self.features)) + # use x data from all ranks + for ds in ds_subj_list[subj]: + if x[ds][subj] is not None: + a_subject += x[ds][subj].dot(shared_response[ds].T) + # collect a_subject from all ranks + a_subject = self.comm.allreduce(a_subject, op=MPI.SUM) + # compute w in one rank and broadcast + if rank == 0: + perturbation = np.zeros(a_subject.shape) + np.fill_diagonal(perturbation, 0.0001) + u_subject, _, v_subject = np.linalg.svd( + a_subject + perturbation, full_matrices=False) + w[subj] = u_subject.dot(v_subject) + else: + w[subj] = None + w[subj] = self.comm.bcast(w[subj], root=0) + return w def _mdms(self, data, datasets, demean): - """Block Coordinate Descent algorithm for fitting the deterministic MDMS. + """Block Coordinate Descent algorithm for fitting the deterministic + MDMS. Parameters ---------- data : dict of list of 2D arrays or dict of dict of 2D arrays 1) When it is a dict of list of 2D arrays: - data[d] is a list of data of dataset d, where d is the name of the dataset. - Element i in the list has shape=[voxels_i, samples_[d]] + data[d] is a list of data of dataset d, where d is the name + of the dataset. + Element i in the list has shape=[voxels_i, samples_[d]] which is the fMRI data of the i'th subject in dataset d. 2) When it is a dict of dict of 2D arrays: - data[d][s] has shape=[voxels_[s], samples_[d]], which is the fMRI data of - subject s in dataset d, where s is the name of the subject and - d is the name of the dataset. + data[d][s] has shape=[voxels_[s], samples_[d]], which is the + fMRI data of subject s in dataset d, where s is the name of + the subject and d is the name of the dataset. datasets : a Dataset object The Dataset object containing datasets structure. 
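As a rough usage sketch of the DetMDMS API above (the subject and dataset
names, sizes, and random data below are invented for illustration and are
not part of the patch): the dict-of-dict form of X lets the Dataset
structure be inferred from the data, and the shared subject 'sub02' keeps
the two datasets connected.

import numpy as np
from brainiak.funcalign.mdms import DetMDMS

rng = np.random.RandomState(0)
voxels = {'sub01': 800, 'sub02': 900, 'sub03': 850}
# X[d][s] has shape=[voxels_s, samples_d]; 'sub02' appears in both
# datasets, so the dataset graph stays connected.
X = {'story': {s: rng.randn(voxels[s], 300) for s in ('sub01', 'sub02')},
     'movie': {s: rng.randn(voxels[s], 250) for s in ('sub02', 'sub03')}}

model = DetMDMS(n_iter=10, features=50)
model.fit(X)  # datasets structure inferred from the dict-of-dict input

# Project new, already centered runs of fitted subjects into the
# shared space.
new_runs = [rng.randn(voxels['sub01'], 100),
            rng.randn(voxels['sub03'], 100)]
shared = model.transform(new_runs, subjects=['sub01', 'sub03'])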
@@ -1413,38 +1843,41 @@ def _mdms(self, data, datasets, demean): """ # get information from datasets structure - ds_list, subj_list = datasets.get_datasets_list(), datasets.get_subjects_list() + ds_list, subj_list = datasets.get_datasets_list(),\ + datasets.get_subjects_list() subj_ds_list = datasets.subjects_in_dataset_all() ds_subj_list = datasets.datasets_with_subject_all() - num_sample = np.sum([datasets.num_subj_dataset[ds]*self.samples_[ds] for ds in ds_list]) + num_sample = np.sum([datasets.num_subj_dataset[ds] * + self.samples_[ds] for ds in ds_list]) # initialize random states self.random_state_ = np.random.RandomState(self.rand_seed) - random_states = { - subj_list[i] : np.random.RandomState(self.random_state_.randint(2 ** 32)) - for i in range(datasets.num_subj)} + random_states = {subj_list[i]: np.random.RandomState( + self.random_state_.randint(2 ** 32)) + for i in range(datasets.num_subj)} rank = self.comm.Get_rank() - size = self.comm.Get_size() - # Initialization step: + # Initialization step: # 1) preprocess data # 2) initialize the outputs with initial values w = _init_w_transforms(self.voxels_, self.features, random_states, - datasets) + datasets) x, mu = self._preprocess_data(data, datasets, demean) del data # compute shared_response from data in this rank shared_response = self._compute_shared_response(x, subj_ds_list, w) # collect shared_response data from all ranks for ds in ds_list: - shared_response[ds] = self.comm.allreduce(shared_response[ds], op=MPI.SUM) + shared_response[ds] = self.comm.allreduce(shared_response[ds], + op=MPI.SUM) shared_response[ds] /= datasets.num_subj_dataset[ds] if self.logger.isEnabledFor(logging.INFO): # Calculate the current objective function value - objective = self._objective_function(x, subj_ds_list, w, shared_response, num_sample) + objective = self._objective_function(x, subj_ds_list, w, + shared_response, num_sample) objective = self.comm.allreduce(objective, op=MPI.SUM) if rank == 0: self.logger.info('Objective function %f' % objective) @@ -1455,50 +1888,36 @@ def _mdms(self, data, datasets, demean): self.logger.info('Iteration %d' % (iteration + 1)) # Update each subject's mapping transform W_s: - for subj in subj_list: - a_subject = np.zeros((self.voxels_[subj], self.features)) - # use x data from all ranks - for ds in ds_subj_list[subj]: - if x[ds][subj] is not None: - a_subject += x[ds][subj].dot(shared_response[ds].T) - # collect a_subject from all ranks - a_subject = self.comm.allreduce(a_subject, op=MPI.SUM) - # compute w in one rank and broadcast - if rank == 0: - perturbation = np.zeros(a_subject.shape) - np.fill_diagonal(perturbation, 0.0001) - u_subject, _, v_subject = np.linalg.svd( - a_subject + perturbation, full_matrices=False) - w[subj] = u_subject.dot(v_subject) - else: - w[subj] = None - w[subj] = self.comm.bcast(w[subj], root=0) + w = self._compute_w_subj(x, ds_subj_list, shared_response, rank) # Update the each dataset's shared response S_d: # compute shared_response from data in this rank - shared_response = self._compute_shared_response(x, subj_ds_list, w) + shared_response = self._compute_shared_response( + x, subj_ds_list, w) # collect shared_response data from all ranks for ds in ds_list: - shared_response[ds] = self.comm.allreduce(shared_response[ds], op=MPI.SUM) + shared_response[ds] = self.comm.allreduce( + shared_response[ds], op=MPI.SUM) shared_response[ds] /= datasets.num_subj_dataset[ds] if self.logger.isEnabledFor(logging.INFO): # Calculate the current objective function value - objective = 
self._objective_function(x, subj_ds_list, w, shared_response, num_sample) + objective = self._objective_function(x, subj_ds_list, w, + shared_response, + num_sample) objective = self.comm.allreduce(objective, op=MPI.SUM) if rank == 0: self.logger.info('Objective function %f' % objective) return w, shared_response, mu - def save(self, file): """Save the DetMDMS object to a file (as pickle) Parameters ---------- - file : The name (including full path) of the file that the object + file : The name (including full path) of the file that the object will be saved to. Returns @@ -1515,7 +1934,7 @@ def save(self, file): """ # get attributes from object variables = self.__dict__.keys() - data = {k:getattr(self, k) for k in variables} + data = {k: getattr(self, k) for k in variables} # remove attributes that cannot be pickled del data['comm'] del data['logger'] @@ -1524,9 +1943,8 @@ def save(self, file): # save attributes to file with open(file, 'wb') as f: pkl.dump(data, f, pkl.HIGHEST_PROTOCOL) - print ('DetMDMS object saved to {}.'.format(file)) - return - + self.logger.info('DetMDMS object saved to {}.'.format(file)) + return def restore(self, file): """Restore the DetMDMS object from a (pickle) file @@ -1534,7 +1952,7 @@ def restore(self, file): Parameters ---------- - file : The name (including full path) of the file that the object + file : The name (including full path) of the file that the object will be restored from. Returns @@ -1545,7 +1963,7 @@ def restore(self, file): Note ---- - The MPI communicator cannot be saved, so self.comm is initialized to + The MPI communicator cannot be saved, so self.comm is initialized to MPI.COMM_SELF """ @@ -1559,7 +1977,7 @@ def restore(self, file): self.comm = MPI.COMM_SELF self.random_state_ = np.random.RandomState(self.rand_seed) self.logger = logger - print ('DetMDMS object restored from {}.'.format(file)) + self.logger.info('DetMDMS object restored from {}.'.format(file)) return @@ -1567,11 +1985,11 @@ class Dataset(object): """Datasets structure organizer Given multi-dataset multi-subject data or JSON files with subject names - in each dataset, infer datasets structure in different formats, such as - a graph where each dataset is a node and each edge is number of shared - subjects between the two datasets. + in each dataset, infer datasets structure in different formats, such as + a graph where each dataset is a node and each edge is number of shared + subjects between the two datasets. - This organizer is used in the MDMS or DetMDMS [Zhang2018]_ and can also + This organizer is used in the MDMS or DetMDMS [Zhang2018]_ and can also be used as a standalone datasets organizer. @@ -1581,91 +1999,94 @@ class Dataset(object): file : (optional) string, default: None JSON file name (including full path) or folder name with JSON files. - Each JSON file should contain a dict or a list of dict where each dict - has information of one dataset. Each dict must have 'dataset', - 'num_of_subj', and 'subjects' where 'dataset' is the name of the dataset, - 'num_of_subj' is the number of subjects in the dataset, and 'subjects' - is a list of strings with names of subjects in the dataset in the same - order as in the dataset. All datasets in all JSON files will be added - to the organizer. - - Example of a JSON file: - [{'dataset':'MyData','num_of_subj':3,'subjects':['Adam','Bob','Carol']}, + Each JSON file should contain a dict or a list of dict where each + dict has information of one dataset. 
Each dict must have 'dataset', + 'num_of_subj', and 'subjects' where 'dataset' is the name of the + dataset, 'num_of_subj' is the number of subjects in the dataset, and + 'subjects' is a list of strings with names of subjects in the + dataset in the same order as in the dataset. All datasets in all + JSON files will be added to the organizer. + + Example of a JSON file: + [{'dataset':'MyData','num_of_subj':3,'subjects': + ['Adam','Bob','Carol']}, {'dataset':'MyData2','num_of_subj':2,'subjects':['Tom','Bob']}] - data : (optional) dict of dict of 2D array, default: None Multi-dataset multi-subject data used to build the organizer. - data[d][s] has shape=[voxels[s], samples[d]], where d is the name of - the dataset and s is the name of the subject. + data[d][s] has shape=[voxels[s], samples[d]], where d is the name of + the dataset and s is the name of the subject. Attributes ---------- - num_subj : int, - Total number of subjects + num_subj : int, + Total number of subjects - num_dataset : int, + num_dataset : int, Total number of datasets - dataset_to_idx : dict of int, dataset_to_idx[d] is the column index of dataset d - in self.matrix, where d is the name of the dataset. + dataset_to_idx : dict of int, dataset_to_idx[d] is the column index of + dataset d in self.matrix, where d is the name of the dataset. Dataset name to column index of matrix, 0-indexed - idx_to_dataset : dict of string, idx_to_dataset[i] is name of the dataset mapped - to the i'th column in self.matrix. + idx_to_dataset : dict of string, idx_to_dataset[i] is name of the + dataset mapped to the i'th column in self.matrix. Column index of metrix to dataset name, 0-indexed - subject_to_idx : dict of int, subject_to_idx[s] is the row index of subject s - in self.matrix, where s is the name of the subject. + subject_to_idx : dict of int, subject_to_idx[s] is the row index of + subject s in self.matrix, where s is the name of the subject. Subject name to row index of matrix, 0-indexed - idx_to_subject : dict of string, idx_to_subject[i] is name of the subject mapped - to the i'th row in self.matrix. + idx_to_subject : dict of string, idx_to_subject[i] is name of the + subject mapped to the i'th row in self.matrix. Row index to subject name, 0-indexed - connected : list of list of string, each element is a list of name of connected - datasets (datasets can be connected through shared subjects). - - num_graph : int, - Number of connected dataset graphs - If 1, then all datasets are connected. - - adj_matrix : 2D csc sparse matrix of shape [num_dataset, num_dataset], - Weighted adjacency matrix of all datasets, where each node is a dataset and - weights on edges are number of shared subjects between the two datasets. - Mapping between dataset name and dataset index is in self.dataset_to_idx. - - num_subj_dataset : dict of int, num_subj_dataset[d] is an int where d is the name - of a dataset. + connected : list of list of string, each element is a list of name of + connected datasets (datasets can be connected through shared + subjects). + + num_graph : int, + Number of connected dataset graphs + If 1, then all datasets are connected. + + adj_matrix : 2D csc sparse matrix of shape [num_dataset, num_dataset], + Weighted adjacency matrix of all datasets, where each node is a + dataset and weights on edges are number of shared subjects between + the two datasets. + Mapping between dataset name and dataset index is in + self.dataset_to_idx. 
+ + num_subj_dataset : dict of int, num_subj_dataset[d] is an int where d is + the name of a dataset. Number of subjects of each dataset - subj_in_dataset : dict of list of string, subj_in_dataset[d] is a list of name - of subjects in dataset d in the same order as in d, where d is the name - of a dataset. If any subject is removed from the organizer, the name will - be replaced with None as a placeholder. + subj_in_dataset : dict of list of string, subj_in_dataset[d] is a list + of name of subjects in dataset d in the same order as in d, where d + is the name of a dataset. If any subject is removed from the + organizer, the name will be replaced with None as a placeholder. Name of subjects in each dataset - matrix : 2D coo sparse matrix of shape [num_subj, num_dataset], + matrix : 2D coo sparse matrix of shape [num_subj, num_dataset], Dataset-subject membership matrix. - If built from JSON files, subject self.idx_to_subject[i] is the + If built from JSON files, subject self.idx_to_subject[i] is the self.matrix[i,j]'th subject in self.idx_to_dataset[j], 1-indexed - If built from multi-dataset multi-subject data, self.matrix[i,j] = 1 if - subject self.idx_to_subject[i] is in dataset self.idx_to_dataset[j]. + If built from multi-dataset multi-subject data, self.matrix[i,j] = 1 + if subject self.idx_to_subject[i] is in dataset self.idx_to_dataset + [j]. - dok_matrix : 2D dok sparse matrix of shape [num_subj, num_dataset], + dok_matrix : 2D dok sparse matrix of shape [num_subj, num_dataset], Dataset-subject membership matrix. - It has the same content as self.matrix, but in Dictionary Of Keys format - for fast access of individual access. - + It has the same content as self.matrix, but in Dictionary Of Keys + format for fast access of individual access. + built_from_data : bool, If the object is built from multi-dataset multi-subject data - If True, the object is built from data; if False, it is built from + If True, the object is built from data; if False, it is built from JSON files. - Note ---- @@ -1674,23 +2095,24 @@ class Dataset(object): """ def __init__(self, file=None, data=None): - self.num_subj = 0 - self.num_dataset = 0 - self.dataset_to_idx = {} - self.idx_to_dataset = {} - self.subject_to_idx = {} - self.idx_to_subject = {} - self.connected = [] - self.num_graph = 0 - self.adj_matrix = None - self.num_subj_dataset = {} + self.num_subj = 0 + self.num_dataset = 0 + self.dataset_to_idx = {} + self.idx_to_dataset = {} + self.subject_to_idx = {} + self.idx_to_subject = {} + self.connected = [] + self.num_graph = 0 + self.adj_matrix = None + self.num_subj_dataset = {} self.subj_in_dataset = {} - self.matrix = None + self.matrix = None self.dok_matrix = None self.built_from_data = None if file is not None and data is not None: - raise Exception('Dataset object can only be built from data OR json files.') + raise Exception('Dataset object can only be built from data OR ' + 'JSON files.') if file is not None: self.add(file) @@ -1699,7 +2121,6 @@ def __init__(self, file=None, data=None): self.build_from_data(data) return - def add(self, file): """Add JSON file(s) to the organizer @@ -1709,17 +2130,19 @@ def add(self, file): file : string, default: None JSON file name (including full path) or folder name with JSON files. - Each JSON file should contain a dict or a list of dict where each dict - has information of one dataset. 
Each dict must have 'dataset', - 'num_of_subj', and 'subjects' where 'dataset' is the name of the dataset, - 'num_of_subj' is the number of subjects in the dataset, and 'subjects' - is a list of strings with names of subjects in the dataset in the same - order as in the dataset. All datasets in all JSON files will be added - to the organizer. If some datasets are already in the organizer, the - information of those datasets will be replaced with this new version. - - Example of a JSON file: - [{'dataset':'MyData','num_of_subj':3,'subjects':['Adam','Bob','Carol']}, + Each JSON file should contain a dict or a list of dict where each + dict has information of one dataset. Each dict must have 'dataset', + 'num_of_subj', and 'subjects' where 'dataset' is the name of the + dataset, 'num_of_subj' is the number of subjects in the dataset, and + 'subjects' is a list of strings with names of subjects in the + dataset in the same order as in the dataset. All datasets in all + JSON files will be added to the organizer. If some datasets are + already in the organizer, the information of those datasets will be + replaced with this new version. + + Example of a JSON file: + [{'dataset':'MyData','num_of_subj':3,'subjects': + ['Adam','Bob','Carol']}, {'dataset':'MyData2','num_of_subj':2,'subjects':['Tom','Bob']}] Returns @@ -1729,7 +2152,8 @@ def add(self, file): """ # sanity check if self.built_from_data is not None and self.built_from_data: - raise Exception('This Dataset object was already initialized with fMRI datasets.') + raise Exception('This Dataset object was already initialized ' + 'with fMRI datasets.') # file can be json file name or folder name # parse json filenames @@ -1744,9 +2168,9 @@ def add(self, file): else: raise Exception('Argument must be a filename or a path.') - mem = [] # collect info of all datasets + mem = [] # collect info of all datasets for f in files: - tmp = json.load(open(f,'r')) + tmp = json.load(open(f, 'r')) if type(tmp) == list: # multiple datasets mem.extend(tmp) @@ -1756,56 +2180,9 @@ def add(self, file): else: raise Exception('JSON file must be in list or dict format.') - # separate datasets into new datasets and datasets to update - new_ds, new_sub, replace_ds, ds_dict = set(), set(), set(), {} - for m in mem: - # sanity check - if m['num_of_subj'] <= 0: - raise Exception('Number of subjects in dataset ' + m['dataset'] + ' must be positive.') - if m['num_of_subj'] != len(m['subjects']): - raise Exception('Number of subjects in dataset ' + m['dataset'] + ' does not agree.') - if m['dataset'] in new_ds or m['dataset'] in replace_ds: - raise Exception('Dataset ' + m['dataset'] + ' appears more than once.') - if len(m['subjects']) != len(set(m['subjects'])): - raise Exception('Dataset ' + m['dataset'] + ' has duplicate subjects.') - - # if the dataset is already in the matrix - if m['dataset'] in self.dataset_to_idx: - replace_ds.add(m['dataset']) - else: - new_ds.add(m['dataset']) - - # save subjects info into a dict - ds_dict[m['dataset']] = m['subjects'] - - # add new subjects in this dataset - for subj in m['subjects']: - if not subj in self.subject_to_idx: - new_sub.add(subj) - - # add number of subjects info if mem passes all the sanity check - for m in mem: - self.num_subj_dataset[m['dataset']] = m['num_of_subj'] - - del mem - - # construct or update the matrix - if self.matrix is None: - # construct a new matrix - self._construct_matrix(new_ds, new_sub, ds_dict) - else: - # add new datasets - self._add_new_dataset(new_ds, new_sub, ds_dict) - if replace_ds: - # 
replace some old datasets - self._replace_dataset(replace_ds, ds_dict) - self._compute_connected() - - self.built_from_data = False - + self._add_mem(mem) # add the information read from JSON files return - def build_from_data(self, data): """Use multi-dataset multi-subject data to initialize the organizer @@ -1814,8 +2191,8 @@ def build_from_data(self, data): data : dict of dict of 2D array Multi-dataset multi-subject data used to build the organizer. - data[d][s] has shape=[voxels[s], samples[d]], where d is the name of - the dataset and s is the name of the subject. + data[d][s] has shape=[voxels[s], samples[d]], where d is the name of + the dataset and s is the name of the subject. Returns ------- @@ -1824,18 +2201,21 @@ def build_from_data(self, data): """ # sanity check if self.built_from_data is not None and not self.built_from_data: - raise Exception('This Dataset object was already initialized with JSON files.') + raise Exception('This Dataset object was already initialized ' + 'with JSON files.') # find out which datasets and subjects are in the data if not type(data) == dict: - raise Exception('To build Dataset object from data, data must be a dict of dict ' - 'where data[d][s] is the fMRI data of dataset d and subject s.') + raise Exception('To build Dataset object from data, data must be' + ' a dict of dict where data[d][s] is the fMRI ' + 'data of dataset d and subject s.') datasets = set(data.keys()) subjects = set() for ds in data: if not type(data[ds]) == dict: - raise Exception('To build Dataset object from data, data must be a dict of dict ' - 'where data[d][s] is the fMRI data of dataset d and subject s.') + raise Exception('To build Dataset object from data, data ' + 'must be a dict of dict where data[d][s] is ' + 'the fMRI data of dataset d and subject s.') subjects.update(set(data[ds].keys())) # set attributes @@ -1861,14 +2241,14 @@ def build_from_data(self, data): coo_data.append(1) col.append(col_idx) row.append(self.subject_to_idx[subj]) - self.matrix = sp.coo_matrix((coo_data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.matrix = sp.coo_matrix((coo_data, (row, col)), + shape=(self.num_subj, self.num_dataset)) self.dok_matrix = self.matrix.todok(copy=True) # compute connectivity self._compute_connected() self.built_from_data = True - return - + return def remove_dataset(self, datasets): """Remove some datasets from the organizer @@ -1887,14 +2267,16 @@ def remove_dataset(self, datasets): """ # sanity check for ds in datasets: - if not ds in self.dataset_to_idx: - raise Exception('Dataset '+ ds + ' does not exist.') + if ds not in self.dataset_to_idx: + raise Exception('Dataset ' + ds + ' does not exist.') # extract data from the sparse matrix - data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() - + data, row, col = self.matrix.data.tolist(), self.matrix.row.\ + tolist(), self.matrix.col.tolist() + # remove datasets from data - data, row, col, subj_to_check = self._remove_datasets_from_data(datasets, data, row, col) + data, row, col, subj_to_check = self._remove_datasets_from_data( + datasets, data, row, col) # if all datasets are removed if not data: @@ -1909,13 +2291,14 @@ def remove_dataset(self, datasets): removed_subjects.append(subj) # re-arrange subject indices - row = self._remove_subjects_by_re_indexing(removed_subjects, row) + row = self._remove_subjects_by_re_indexing(removed_subjects, row) # re-arrange dataset indices col = self._remove_datasets_by_re_indexing(datasets, col) # re-construct the 
matrix - self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.matrix = sp.coo_matrix((data, (row, col)), + shape=(self.num_subj, self.num_dataset)) self.dok_matrix = self.matrix.todok(copy=True) # compute connectivity @@ -1923,7 +2306,6 @@ def remove_dataset(self, datasets): return removed_subjects - def remove_subject(self, subjects): """Remove some subjects from the organizer @@ -1941,38 +2323,41 @@ def remove_subject(self, subjects): """ # sanity check for subj in subjects: - if not subj in self.subject_to_idx: + if subj not in self.subject_to_idx: raise Exception('Subject ' + subj + ' does not exist.') # extract data from the sparse matrix - data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() - + data, row, col = self.matrix.data.tolist(), self.matrix.row.\ + tolist(), self.matrix.col.tolist() + # remove subjects from data - data, row, col = self._remove_subjects_from_data(subjects, data, row, col) + data, row, col = self._remove_subjects_from_data( + subjects, data, row, col) # if all subjects are removed if not data: removed_datasets = list(self.dataset_to_idx.keys()) self.reset() return removed_datasets - + # find datasets without any subject removed_datasets = [] - for (k,v) in self.num_subj_dataset.items(): + for (k, v) in self.num_subj_dataset.items(): if not v: removed_datasets.append(k) for k in removed_datasets: - del self.num_subj_dataset[k] # remove from num_subj_dataset + del self.num_subj_dataset[k] # remove from num_subj_dataset del self.subj_in_dataset[k] # re-arrange subject indices - row = self._remove_subjects_by_re_indexing(subjects, row) + row = self._remove_subjects_by_re_indexing(subjects, row) # re-arrange dataset indices col = self._remove_datasets_by_re_indexing(removed_datasets, col) # re-construct the matrix - self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.matrix = sp.coo_matrix((data, (row, col)), + shape=(self.num_subj, self.num_dataset)) self.dok_matrix = self.matrix.todok(copy=True) # compute connectivity @@ -1980,9 +2365,8 @@ def remove_subject(self, subjects): return removed_datasets - def num_shared_subjects_between_datasets(self, ds1, ds2): - """Get number of shared subjects (subjects in both ds1 and ds2) + """Get number of shared subjects (subjects in both ds1 and ds2) between two datasets (ds1 and ds2) Parameters @@ -1994,46 +2378,47 @@ def num_shared_subjects_between_datasets(self, ds1, ds2): Returns ------- - num_shared : int, + num_shared : int, Number of shared subjects between ds1 and ds2 """ # sanity check for ds in [ds1, ds2]: - if not ds in self.dataset_to_idx: + if ds not in self.dataset_to_idx: raise Exception('Dataset ' + ds + 'does not exist.') # find number of shared subjects idx1, idx2 = self.dataset_to_idx[ds1], self.dataset_to_idx[ds2] return self.adj_matrix[idx1, idx2] - def shared_subjects_between_datasets(self, ds1, ds2): - """Get name of shared subjects (subjects in both ds1 and ds2) + """Get name of shared subjects (subjects in both ds1 and ds2) between two datasets (ds1 and ds2) Parameters ---------- - ds1, ds2 : string, + ds1, ds2 : string, Name of two datasets Returns ------- - shared : list of string, + shared : list of string, Name of subjects shared between ds1 and ds2 """ # sanity check for ds in [ds1, ds2]: - if not ds in self.dataset_to_idx: - raise Exception('Dataset ' + ds + 'does not exist.') + if ds not in self.dataset_to_idx: + raise Exception('Dataset ' + ds + 'does not exist.') if 
self.matrix is None: - raise Exception('Dataset object not initialized.') - # find shared subjects + raise Exception('Dataset object not initialized.') + # find shared subjects matrix_csc = self.matrix.tocsc(copy=True) - subj1 = set(matrix_csc[:,self.dataset_to_idx[ds1]].indices) # indices of subjects in ds1 - subj2 = set(matrix_csc[:,self.dataset_to_idx[ds2]].indices) # indices of subjects in ds2 - return [self.idx_to_subject[subj] for subj in subj1.intersection(subj2)] - + # indices of subjects in ds1 + subj1 = set(matrix_csc[:, self.dataset_to_idx[ds1]].indices) + # indices of subjects in ds2 + subj2 = set(matrix_csc[:, self.dataset_to_idx[ds2]].indices) + return [self.idx_to_subject[subj] for subj in + subj1.intersection(subj2)] def datasets_with_subject(self, subj): """Get all datasets with some subject ('subj') @@ -2041,26 +2426,25 @@ def datasets_with_subject(self, subj): Parameters ---------- - subj : string, + subj : string, Name of the subject Returns ------- - datasets : list of string, + datasets : list of string, Name of datasets with subject 'subj' """ # sanity check - if not subj in self.subject_to_idx: - raise Exception('Subject ' + subj + 'does not exist.') + if subj not in self.subject_to_idx: + raise Exception('Subject ' + subj + 'does not exist.') if self.matrix is None: raise Exception('Dataset object not initialized.') # find datasets with subject matrix_csr = self.matrix.tocsr(copy=True) - indices = matrix_csr[self.subject_to_idx[subj],:].indices + indices = matrix_csr[self.subject_to_idx[subj], :].indices return [self.idx_to_dataset[ds] for ds in indices] - def datasets_with_subject_all(self): """For each subject, get a list of datasets with that subject @@ -2072,8 +2456,8 @@ def datasets_with_subject_all(self): Returns ------- - ds_subj_list : dict of list of string, ds_subj_list[s] is a list where s - is the name of a subject + ds_subj_list : dict of list of string, ds_subj_list[s] is a list + where s is the name of a subject. 
List of datasets with subject s for each subject s """ if self.matrix is None: @@ -2082,10 +2466,10 @@ def datasets_with_subject_all(self): matrix_csr = self.matrix.tocsr(copy=True) for subj in range(self.num_subj): subj_name = self.idx_to_subject[subj] - indices = matrix_csr[subj,:].indices - ds_subj_list[subj_name] = [self.idx_to_dataset[ds] for ds in indices] - return ds_subj_list - + indices = matrix_csr[subj, :].indices + ds_subj_list[subj_name] = [self.idx_to_dataset[ds] + for ds in indices] + return ds_subj_list def subjects_in_dataset(self, dataset): """Get all subjects in some dataset ('dataset') @@ -2093,26 +2477,25 @@ def subjects_in_dataset(self, dataset): Parameters ---------- - dataset : string, + dataset : string, Name of the dataset Returns ------- - subjects : list of string, + subjects : list of string, Name of subjects in dataset 'dataset' """ - #sanity check - if not dataset in self.dataset_to_idx: - raise Exception('Dataset ' + dataset + 'does not exist.') + # sanity check + if dataset not in self.dataset_to_idx: + raise Exception('Dataset ' + dataset + 'does not exist.') if self.matrix is None: raise Exception('Dataset object not initialized.') # find subjects in dataset matrix_csc = self.matrix.tocsc(copy=True) - indices = matrix_csc[:,self.dataset_to_idx[dataset]].indices + indices = matrix_csc[:, self.dataset_to_idx[dataset]].indices return [self.idx_to_subject[subj] for subj in indices] - def subjects_in_dataset_all(self): """For each dataset, get a list of subjects in that dataset @@ -2124,8 +2507,8 @@ def subjects_in_dataset_all(self): Returns ------- - subj_ds_list : dict of list of string, subj_ds_list[d] is a list where d - is the name of a dataset + subj_ds_list : dict of list of string, subj_ds_list[d] is a list + where d is the name of a dataset. 
List of subjects in dataset d for each dataset d """ if self.matrix is None: @@ -2134,10 +2517,10 @@ def subjects_in_dataset_all(self): matrix_csc = self.matrix.tocsc(copy=True) for ds in range(self.num_dataset): ds_name = self.idx_to_dataset[ds] - indices = matrix_csc[:,ds].indices - subj_ds_list[ds_name] = [self.idx_to_subject[subj] for subj in indices] - return subj_ds_list - + indices = matrix_csc[:, ds].indices + subj_ds_list[ds_name] = [self.idx_to_subject[subj] + for subj in indices] + return subj_ds_list def get_subjects_list(self): """Get a list of all subjects in the organizer @@ -2150,12 +2533,11 @@ def get_subjects_list(self): Returns ------- - subj_list : list of string, + subj_list : list of string, Name of all subjects in the organizer """ return list(self.subject_to_idx.keys()) - def get_datasets_list(self): """Get a list of all datasets in the organizer @@ -2167,14 +2549,13 @@ def get_datasets_list(self): Returns ------- - ds_list : list of string, + ds_list : list of string, Name of all datasets in the organizer """ return list(self.dataset_to_idx.keys()) - def visualize_graph(self, font_size=14): - """Visualize the organizer as a graph where each node is a dataset + """Visualize the organizer as a graph where each node is a dataset and the edge is number of shared subjects between the two datasets Parameters @@ -2193,15 +2574,17 @@ def visualize_graph(self, font_size=14): # build graph from adjacency matrix G = nx.from_numpy_matrix(self.adj_matrix.toarray()) # assign edge labels - edge_labels=dict([((u,v,),self.adj_matrix[u,v]) for u,v in G.edges]) - pos=nx.spring_layout(G) - nx.draw(G, pos=pos,with_labels=False) - labels=nx.draw_networkx_labels(G,labels = self.idx_to_dataset, pos=pos, font_size=font_size) - edge_labels=nx.draw_networkx_edge_labels(G,edge_labels=edge_labels, pos=pos, font_size=font_size) - plt.show() + edge_labels = dict([((u, v), self.adj_matrix[u, v]) + for u, v in G.edges]) + pos = nx.spring_layout(G) + nx.draw(G, pos=pos, with_labels=False) + _ = nx.draw_networkx_labels(G, labels=self.idx_to_dataset, + pos=pos, font_size=font_size) + _ = nx.draw_networkx_edge_labels(G, edge_labels=edge_labels, + pos=pos, font_size=font_size) + plt.show() return - def reset(self): """Reset all attributes in the organizer @@ -2215,22 +2598,21 @@ def reset(self): None """ - self.num_subj = 0 - self.num_dataset = 0 - self.dataset_to_idx = {} - self.idx_to_dataset = {} - self.subject_to_idx = {} - self.idx_to_subject = {} - self.connected = [] - self.num_graph = 0 - self.adj_matrix = None - self.num_subj_dataset = {} + self.num_subj = 0 + self.num_dataset = 0 + self.dataset_to_idx = {} + self.idx_to_dataset = {} + self.subject_to_idx = {} + self.idx_to_subject = {} + self.connected = [] + self.num_graph = 0 + self.adj_matrix = None + self.num_subj_dataset = {} self.subj_in_dataset = {} - self.matrix = None + self.matrix = None self.adj_matrix = None - self.built_from_data = None - return - + self.built_from_data = None + return def save(self, file): """Save the Dataset object to a file (as pickle) @@ -2238,7 +2620,7 @@ def save(self, file): Parameters ---------- - file : The name (including full path) of the file that the object + file : The name (including full path) of the file that the object will be saved to. 
Returns @@ -2248,13 +2630,12 @@ def save(self, file): """ # get attributes from object variables = self.__dict__.keys() - data = {k:getattr(self, k) for k in variables} + data = {k: getattr(self, k) for k in variables} # save attributes to file with open(file, 'wb') as f: pkl.dump(data, f, pkl.HIGHEST_PROTOCOL) - print ('Dataset object saved to {}.'.format(file)) - return - + self.logger.info('Dataset object saved to {}.'.format(file)) + return def restore(self, file): """Restore the Dataset object from a (pickle) file @@ -2262,7 +2643,7 @@ def restore(self, file): Parameters ---------- - file : The name (including full path) of the file that the object + file : The name (including full path) of the file that the object will be restored from. Returns @@ -2276,10 +2657,9 @@ def restore(self, file): # set attributes to object for (k, v) in data.items(): setattr(self, k, v) - print ('Dataset object restored from {}.'.format(file)) + self.logger.info('Dataset object restored from {}.'.format(file)) return - def _compute_connected(self): """Compute the weighted adjacency matrix and connectivity @@ -2293,31 +2673,50 @@ def _compute_connected(self): None """ - # build the weighted adjacency matrix (how many shared subjects between each pair of datasets) + # build the weighted adjacency matrix (how many shared subjects + # between each pair of datasets) matrix_csc = self.matrix.tocsc(copy=True) row, col, data = [], [], [] for i in range(self.num_dataset): for j in range(i+1, self.num_dataset): - tmp = matrix_csc[:,i].multiply(matrix_csc[:,j]).nnz + tmp = matrix_csc[:, i].multiply(matrix_csc[:, j]).nnz if tmp != 0: - row.extend([i,j]) - col.extend([j,i]) + row.extend([i, j]) + col.extend([j, i]) data.extend([tmp, tmp]) - self.adj_matrix = sp.csc_matrix((data, (row, col)),shape=(self.num_dataset, self.num_dataset)) + self.adj_matrix = sp.csc_matrix((data, (row, col)), + shape=(self.num_dataset, + self.num_dataset)) + + self._compute_num_connect_graph() + return + + def _compute_num_connect_graph(self): + """Compute which datasets are connected + + Parameters + ---------- + + None + + Returns + ------- + None + """ # find out which datasets are connected not_connected = set(range(self.num_dataset)) connected = [] dq = set() for idx in range(self.num_dataset): if idx in not_connected: - tmp = [] - dq.add(idx) - while dq: + tmp = [] + dq.add(idx) + while dq: n = dq.pop() not_connected.remove(n) tmp.append(n) - for neighbor in self.adj_matrix[:,n].indices: + for neighbor in self.adj_matrix[:, n].indices: if neighbor in not_connected: dq.add(neighbor) if not dq: @@ -2332,9 +2731,77 @@ def _compute_connected(self): # count number of connected graphs self.num_graph = len(self.connected) - return + def _add_mem(self, mem): + """Add information from JSON files to the organizer + + Parameters + ---------- + + mem : list of dict, information from JSON files + + Returns + ------- + + None + """ + # separate datasets into new datasets and datasets to update + new_ds, new_sub, replace_ds, ds_dict = set(), set(), set(), {} + for m in mem: + # sanity check + err_case = [m['num_of_subj'] <= 0, + m['num_of_subj'] != len(m['subjects']), + m['dataset'] in new_ds or m['dataset'] in replace_ds, + len(m['subjects']) != len(set(m['subjects']))] + err_msg = ['Number of subjects in dataset {} must be positive.'. + format(m['dataset']), + 'Number of subjects in dataset {} does not agree.'. + format(m['dataset']), + 'Dataset {} appears more than once.'. + format(m['dataset']), + 'Dataset {} has duplicate subjects.'. 
+ format(m['dataset'])] + + for err, msg in zip(err_case, err_msg): + if err: + raise Exception(msg) + + # if the dataset is already in the matrix + if m['dataset'] in self.dataset_to_idx: + replace_ds.add(m['dataset']) + else: + new_ds.add(m['dataset']) + + # save subjects info into a dict + ds_dict[m['dataset']] = m['subjects'] + + # add new subjects in this dataset + for subj in m['subjects']: + if subj not in self.subject_to_idx: + new_sub.add(subj) + + # add number of subjects info if mem passes all the sanity check + for m in mem: + self.num_subj_dataset[m['dataset']] = m['num_of_subj'] + + del mem + + # construct or update the matrix + if self.matrix is None: + # construct a new matrix + self._construct_matrix(new_ds, new_sub, ds_dict) + else: + # add new datasets + self._add_new_dataset(new_ds, new_sub, ds_dict) + if replace_ds: + # replace some old datasets + self._replace_dataset(replace_ds, ds_dict) + self._compute_connected() + + self.built_from_data = False + + return def _construct_matrix(self, new_ds, new_sub, ds_dict): """Initialize the organizer with some datasets and subjects @@ -2342,15 +2809,15 @@ def _construct_matrix(self, new_ds, new_sub, ds_dict): Parameters ---------- - new_ds : set or list of string, + new_ds : set or list of string, Name of all new datasets to add new_sub : set or list of string, Name of all new subjects to add - ds_dict : dict of list of string, ds_dict[d] is a list of subject names - in dataset d in the same order as in the dataset, where d is the name of - the dataset + ds_dict : dict of list of string, ds_dict[d] is a list of subject + names in dataset d in the same order as in the dataset, where d is + the name of the dataset. Returns ------- @@ -2376,14 +2843,14 @@ def _construct_matrix(self, new_ds, new_sub, ds_dict): data.append(idx+1) col.append(col_idx) row.append(self.subject_to_idx[subj]) - self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.matrix = sp.coo_matrix((data, (row, col)), + shape=(self.num_subj, self.num_dataset)) self.dok_matrix = self.matrix.todok(copy=True) # compute connectivity self._compute_connected() return - def _add_new_dataset(self, new_ds, new_sub, ds_dict): """Add some new datasets into the organizer when the organizer was already initialized and the new datasets are not in it yet. @@ -2391,15 +2858,15 @@ def _add_new_dataset(self, new_ds, new_sub, ds_dict): Parameters ---------- - new_ds : set or list of string, + new_ds : set or list of string, Name of all new datasets to add new_sub : set or list of string, Name of all new subjects to add - ds_dict : dict of list of string, ds_dict[d] is a list of subject names - in dataset d in the same order as in the dataset, where d is the name of - the dataset + ds_dict : dict of list of string, ds_dict[d] is a list of subject + names in dataset d in the same order as in the dataset, where d is + the name of the dataset. 
Returns ------- @@ -2412,12 +2879,13 @@ def _add_new_dataset(self, new_ds, new_sub, ds_dict): self.idx_to_subject[self.num_subj + idx] = subj for idx, ds in enumerate(new_ds): self.dataset_to_idx[ds] = self.num_dataset + idx - self.idx_to_dataset[self.num_dataset + idx] = ds + self.idx_to_dataset[self.num_dataset + idx] = ds self.num_subj += len(new_sub) - self.num_dataset += len(new_ds) + self.num_dataset += len(new_ds) - # fill in sparse matrix - data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() + # fill in sparse matrix + data, row, col = self.matrix.data.tolist(), self.matrix.row.\ + tolist(), self.matrix.col.tolist() for ds in new_ds: self.subj_in_dataset[ds] = ds_dict[ds] col_idx = self.dataset_to_idx[ds] @@ -2425,24 +2893,24 @@ def _add_new_dataset(self, new_ds, new_sub, ds_dict): data.append(idx+1) col.append(col_idx) row.append(self.subject_to_idx[subj]) - self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.matrix = sp.coo_matrix((data, (row, col)), + shape=(self.num_subj, self.num_dataset)) self.dok_matrix = self.matrix.todok(copy=True) return - def _replace_dataset(self, replace_ds, ds_dict): - """Replace information of some datasets with information in ds_dict assuming - those datasets are already in the organizer + """Replace information of some datasets with information in ds_dict + assuming those datasets are already in the organizer Parameters ---------- - replace_ds : set or list of string, + replace_ds : set or list of string, Name of all datasets to replace - ds_dict : dict of list of string, ds_dict[d] is a list of subject names - in dataset d in the same order as in the dataset, where d is the name of - the dataset + ds_dict : dict of list of string, ds_dict[d] is a list of subject + names in dataset d in the same order as in the dataset, where d is + the name of the dataset. Returns ------- @@ -2450,10 +2918,12 @@ def _replace_dataset(self, replace_ds, ds_dict): None """ # extract data from the sparse matrix - data, row, col = self.matrix.data.tolist(), self.matrix.row.tolist(), self.matrix.col.tolist() + data, row, col = self.matrix.data.tolist(), self.matrix.row.\ + tolist(), self.matrix.col.tolist() # remove data of datasets to be replaced from the coo sparse matrix - data, row, col, subj_to_check = self._remove_datasets_from_data(replace_ds, data, row, col) + data, row, col, subj_to_check = self._remove_datasets_from_data( + replace_ds, data, row, col) # add data of datasets to replace for ds in replace_ds: @@ -2475,11 +2945,11 @@ def _replace_dataset(self, replace_ds, ds_dict): row = self._remove_subjects_by_re_indexing(subj_to_remove, row) # re-construct the matrix - self.matrix = sp.coo_matrix((data, (row, col)), shape=(self.num_subj, self.num_dataset)) + self.matrix = sp.coo_matrix((data, (row, col)), + shape=(self.num_subj, self.num_dataset)) self.dok_matrix = self.matrix.todok(copy=True) return - def _remove_subjects_from_data(self, subjects, data, row, col): """Remove some subjects by deleting their data @@ -2495,22 +2965,24 @@ def _remove_subjects_from_data(self, subjects, data, row, col): Returns ------- - data, row, col : list of int, - Data can be used to construct a sparse matrix after removal of those + data, row, col : list of int, + Data can be used to construct a sparse matrix after removal of those subjects Note ---- - Subjects are not re-indexed. Need to call _remove_subjects_by_re_indexing() - afterwards to re-index. + Subjects are not re-indexed. 
Need to call + _remove_subjects_by_re_indexing() afterwards to re-index. """ len_data = len(data) - row_to_remove = set() # subject indices (row indices) to remove + # subject indices (row indices) to remove + row_to_remove = set() subjects = set(subjects) for subj in subjects: row_to_remove.add(self.subject_to_idx[subj]) - idx_to_remove = [] # data indices to remove from data, row, col lists + # data indices to remove from data, row, col lists + idx_to_remove = [] for idx, row_idx in enumerate(row): if row_idx in row_to_remove: idx_to_remove.append(idx) @@ -2521,13 +2993,12 @@ def _remove_subjects_from_data(self, subjects, data, row, col): self.subj_in_dataset[ds][idx] = None # remove data - data = [data[i] for i in range(len_data) if not i in idx_to_remove] - row = [row[i] for i in range(len_data) if not i in idx_to_remove] - col = [col[i] for i in range(len_data) if not i in idx_to_remove] + data = [data[i] for i in range(len_data) if i not in idx_to_remove] + row = [row[i] for i in range(len_data) if i not in idx_to_remove] + col = [col[i] for i in range(len_data) if i not in idx_to_remove] return data, row, col - def _remove_datasets_from_data(self, datasets, data, row, col): """Remove some datasets by deleting their data @@ -2543,26 +3014,28 @@ def _remove_datasets_from_data(self, datasets, data, row, col): Returns ------- - data, row, col : list of int, - Data can be used to construct a sparse matrix after removal of those + data, row, col : list of int, + Data can be used to construct a sparse matrix after removal of those datasets subj_to_check : set of string, - Name of subjects that are possibly not in any datasets (and thus need - to be removed) after removal of those datasets + Name of subjects that are possibly not in any datasets (and thus + need to be removed) after removal of those datasets. Note ---- - Datasets are not re-indexed. Need to call _remove_datasets_by_re_indexing() - afterwards to re-index. + Datasets are not re-indexed. Need to call + _remove_datasets_by_re_indexing() afterwards to re-index. """ len_data = len(data) - col_to_remove = set() # dataset indices (column indices) to remove + col_to_remove = set() # dataset indices (column indices) to remove for ds in datasets: col_to_remove.add(self.dataset_to_idx[ds]) - idx_to_remove = [] # data indices to remove from data, row, col lists - subj_to_check = set() # possible subject indices to remove after removing datasets + # data indices to remove from data, row, col lists + idx_to_remove = [] + # possible subject indices to remove after removing datasets + subj_to_check = set() for idx, col_idx in enumerate(col): if col_idx in col_to_remove: idx_to_remove.append(idx) @@ -2572,22 +3045,21 @@ def _remove_datasets_from_data(self, datasets, data, row, col): del self.subj_in_dataset[ds] del self.num_subj_dataset[ds] # remove data - data = [data[i] for i in range(len_data) if not i in idx_to_remove] - row = [row[i] for i in range(len_data) if not i in idx_to_remove] - col = [col[i] for i in range(len_data) if not i in idx_to_remove] + data = [data[i] for i in range(len_data) if i not in idx_to_remove] + row = [row[i] for i in range(len_data) if i not in idx_to_remove] + col = [col[i] for i in range(len_data) if i not in idx_to_remove] return data, row, col, subj_to_check - def _remove_subjects_by_re_indexing(self, subjects, row): - """Re-index all subjects after removal of data of some subjects + """Re-index all subjects after removal of data of some subjects so that the subject indexing are still contiguous. 
Parameters ---------- subjects : set or list of string, - Name of subjects where their data in self.matrix are removed + Name of subjects where their data in self.matrix are removed already and need to be removed from indexing row : list of int, row indices as in a sparse matrix @@ -2602,7 +3074,7 @@ def _remove_subjects_by_re_indexing(self, subjects, row): Note ---- - Data of subjects 'subjects' must be removed already. If not, + Data of subjects 'subjects' must be removed already. If not, need to call _remove_subjects_from_data() beforehand """ # remaining subjects after moving 'subjects' @@ -2618,21 +3090,21 @@ def _remove_subjects_by_re_indexing(self, subjects, row): new_r = new_subject_to_idx[subj] row[idx] = new_r # update mapping - self.subject_to_idx, self.idx_to_subject = new_subject_to_idx, new_idx_to_subject + self.subject_to_idx, self.idx_to_subject = new_subject_to_idx, \ + new_idx_to_subject # update total number of subjects self.num_subj -= len(subjects) return row - def _remove_datasets_by_re_indexing(self, datasets, col): - """Re-index all datasets after removal of data of some datasets + """Re-index all datasets after removal of data of some datasets so that the dataset indexing are still contiguous. Parameters ---------- datasets : set or list of string, - Name of datasets where their data in self.matrix are removed + Name of datasets where their data in self.matrix are removed already and need to be removed from indexing col : list of int, col indices as in a sparse matrix @@ -2647,7 +3119,7 @@ def _remove_datasets_by_re_indexing(self, datasets, col): Note ---- - Data of datasets 'datasets' must be removed already. If not, + Data of datasets 'datasets' must be removed already. If not, need to call _remove_datasets_from_data() beforehand """ # remaining datasets after moving 'datasets' @@ -2663,12 +3135,8 @@ def _remove_datasets_by_re_indexing(self, datasets, col): new_c = new_dataset_to_idx[ds] col[idx] = new_c # update mapping - self.dataset_to_idx, self.idx_to_dataset = new_dataset_to_idx, new_idx_to_dataset + self.dataset_to_idx, self.idx_to_dataset = new_dataset_to_idx, \ + new_idx_to_dataset # update total number of datasets self.num_dataset -= len(datasets) return col - - - - - diff --git a/examples/funcalign/mdms_time_segment_matching_distributed.py b/examples/funcalign/mdms_time_segment_matching_distributed.py index 61453cef5..bacd913fc 100644 --- a/examples/funcalign/mdms_time_segment_matching_distributed.py +++ b/examples/funcalign/mdms_time_segment_matching_distributed.py @@ -43,7 +43,6 @@ from brainiak.fcma.util import compute_correlation from brainiak.funcalign.mdms import MDMS, Dataset - # parameters features = 75 # number of features, k n_iter = 30 # number of iterations of EM From 12842a66b03b33a039699e91fbdada00077f8e56 Mon Sep 17 00:00:00 2001 From: Hejia Zhang Date: Tue, 26 Feb 2019 22:01:36 -0500 Subject: [PATCH 3/7] add test for MDMS --- brainiak/funcalign/mdms.py | 77 ++----- .../mdms_time_segment_matching_example.ipynb | 62 ++++- examples/funcalign/requirements.txt | 1 + tests/funcalign/test_mdms_distributed.py | 211 ++++++++++++++++++ 4 files changed, 285 insertions(+), 66 deletions(-) create mode 100644 tests/funcalign/test_mdms_distributed.py diff --git a/brainiak/funcalign/mdms.py b/brainiak/funcalign/mdms.py index 194996fc4..839144242 100644 --- a/brainiak/funcalign/mdms.py +++ b/brainiak/funcalign/mdms.py @@ -35,8 +35,6 @@ import os import glob from scipy import sparse as sp -import matplotlib.pyplot as plt -import networkx as nx import 
pickle as pkl __all__ = [ @@ -405,7 +403,7 @@ def __init__(self, n_iter=10, features=50, rand_seed=0, self.logger = logger return - def fit(self, X, datasets=None, y=None): + def fit(self, X, datasets, y=None): """Compute the probabilistic multi-dataset multi-subject (MDMS) SRM analysis @@ -424,9 +422,10 @@ def fit(self, X, datasets=None, y=None): data of subject s in dataset d, where s is the name of the subject and d is the name of the dataset. - datasets : (optional) a Dataset object + datasets : (optional) a Dataset object The Dataset object containing datasets structure. - If not defined, the structure will be inferred from X. + If you only have X, call datasets.build_from_data(X) with full + data to infer datasets. y : not used """ @@ -434,7 +433,7 @@ def fit(self, X, datasets=None, y=None): self.logger.info('Starting Probabilistic MDMS') # Check if datasets is initialized - if datasets is not None and datasets.matrix is None: + if datasets is None or datasets.matrix is None: raise NotFittedError('Dataset object is not initialized.') # Check X format @@ -444,24 +443,18 @@ def fit(self, X, datasets=None, y=None): if format_X != dict and format_X != list: raise Exception('X should be a dict of dict of arrays or dict of' ' list of arrays.') - if format_X == list and (datasets is None or - datasets.built_from_data is None or + if format_X == list and (datasets.built_from_data is None or datasets.built_from_data): raise Exception("Argument 'datasets' must be defined and built " - "from json " - "files when X is a dict of list of 2D arrays. ") - if format_X == dict and datasets is not None: + "from JSON files when X is a dict of list of 2D " + "arrays. ") + if format_X == dict: datasets.built_from_data = True for v in X.values(): if type(v) != format_X: raise Exception('X should be a dict of dict of arrays or ' 'dict of list of arrays.') - # Infer datasets structure from data - if datasets is None: - datasets = Dataset() - datasets.build_from_data(X) - self.voxels_, self.samples_ = _sanity_check(X, datasets, self.comm) # Run MDMS @@ -505,7 +498,7 @@ def transform(self, X, subjects, centered=True, y=None): # Check if the subject exist in the fitted model and has the right # number of voxels for idx in range(len(X)): - if not subjects[idx] in self.w_: + if subjects[idx] not in self.w_: raise NotFittedError("The model has not been fitted to " "subject {}.".format(subjects[idx])) if X[idx] is not None and (self.w_[subjects[idx]]. @@ -1350,7 +1343,7 @@ def __init__(self, n_iter=10, features=50, rand_seed=0, self.logger = logger return - def fit(self, X, datasets=None, demean=True, y=None): + def fit(self, X, datasets, demean=True, y=None): """Compute the Deterministic Shared Response Model Parameters @@ -1371,7 +1364,8 @@ def fit(self, X, datasets=None, demean=True, y=None): datasets : (optional) a Dataset object The Dataset object containing datasets structure. - If not defined, the structure will be inferred from X. + If you only have X, call datasets.build_from_data(X) with full + data to infer datasets. demean : (optional) If True, compute voxel means for each subject and subtract from data. 
If False, voxel means are set to zero @@ -1383,7 +1377,7 @@ def fit(self, X, datasets=None, demean=True, y=None): self.logger.info('Starting Deterministic SRM') # Check if datasets is initialized - if datasets is not None and datasets.matrix is None: + if datasets is None or datasets.matrix is None: raise NotFittedError('Dataset object is not initialized.') # Check X format @@ -1393,24 +1387,18 @@ def fit(self, X, datasets=None, demean=True, y=None): if format_X != dict and format_X != list: raise Exception('X should be a dict of dict of arrays or dict of' ' list of arrays.') - if format_X == list and (datasets is None or - datasets.built_from_data is None or + if format_X == list and (datasets.built_from_data is None or datasets.built_from_data): raise Exception("Argument 'datasets' must be defined and built " "from json files when X is a dict of list of 2D " "arrays. ") - if format_X == dict and datasets is not None: + if format_X == dict: datasets.built_from_data = True for v in X.values(): if type(v) != format_X: raise Exception('X should be a dict of dict of arrays or ' 'dict of list of arrays.') - # Infer datasets structure from data - if datasets is None: - datasets = Dataset() - datasets.build_from_data(X) - self.voxels_, self.samples_ = _sanity_check(X, datasets, self.comm) # Run MDMS @@ -1454,7 +1442,7 @@ def transform(self, X, subjects, centered=True, y=None): # Check if the subject exist in the fitted model and has the right # number of voxels for idx in range(len(X)): - if not subjects[idx] in self.w_: + if subjects[idx] not in self.w_: raise NotFittedError("The model has not been fitted to " "subject {}.".format(subjects[idx])) if X[idx] is not None and (self.w_[subjects[idx]].shape[0] != @@ -2554,37 +2542,6 @@ def get_datasets_list(self): """ return list(self.dataset_to_idx.keys()) - def visualize_graph(self, font_size=14): - """Visualize the organizer as a graph where each node is a dataset - and the edge is number of shared subjects between the two datasets - - Parameters - ---------- - - font_size : (optional) float, default = 14 - Font size of labels in the graph - - Returns - ------- - - None - """ - if self.adj_matrix is None: - raise Exception('Dataset object not initialized.') - # build graph from adjacency matrix - G = nx.from_numpy_matrix(self.adj_matrix.toarray()) - # assign edge labels - edge_labels = dict([((u, v), self.adj_matrix[u, v]) - for u, v in G.edges]) - pos = nx.spring_layout(G) - nx.draw(G, pos=pos, with_labels=False) - _ = nx.draw_networkx_labels(G, labels=self.idx_to_dataset, - pos=pos, font_size=font_size) - _ = nx.draw_networkx_edge_labels(G, edge_labels=edge_labels, - pos=pos, font_size=font_size) - plt.show() - return - def reset(self): """Reset all attributes in the organizer diff --git a/examples/funcalign/mdms_time_segment_matching_example.ipynb b/examples/funcalign/mdms_time_segment_matching_example.ipynb index 07747e90a..ae7d90ca6 100644 --- a/examples/funcalign/mdms_time_segment_matching_example.ipynb +++ b/examples/funcalign/mdms_time_segment_matching_example.ipynb @@ -34,11 +34,14 @@ "metadata": {}, "outputs": [], "source": [ - "%matplotlib inline\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline \n", "import numpy as np\n", "from scipy.stats import stats\n", "from brainiak.fcma.util import compute_correlation\n", "import pickle as pkl\n", + "import networkx as nx\n", "from brainiak.funcalign.mdms import MDMS, Dataset" ] }, @@ -153,9 +156,51 @@ "print ('------ Number of datasets ------')\n", "print 
(ds_struct.num_dataset)\n", "print ('------ Number of subjects ------')\n", - "print (ds_struct.num_subj)\n", - "print ('------ Visualize connectivity between datasets: datasets as nodes, number of shared subjects as edges------')\n", - "ds_struct.visualize_graph()" + "print (ds_struct.num_subj)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize connectivity between datasets: datasets as nodes, number of shared subjects as edges\n", + "def visualize_graph(ds_struct, font_size=14):\n", + " \"\"\"Visualize the Dataset object as a graph where each node is a dataset\n", + " and the edge is number of shared subjects between the two datasets\n", + "\n", + " Parameters\n", + " ----------\n", + "\n", + " ds_struct : a Dataset object\n", + " \n", + " font_size : (optional) float, default = 14\n", + " Font size of labels in the graph\n", + "\n", + " Returns\n", + " -------\n", + "\n", + " None\n", + " \"\"\"\n", + " if ds_struct.adj_matrix is None:\n", + " raise Exception('Dataset object not initialized.')\n", + " # build graph from adjacency matrix\n", + " G = nx.from_numpy_matrix(ds_struct.adj_matrix.toarray())\n", + " # assign edge labels\n", + " edge_labels = dict([((u, v), ds_struct.adj_matrix[u, v])\n", + " for u, v in G.edges])\n", + " pos = nx.spring_layout(G)\n", + " nx.draw(G, pos=pos, with_labels=False)\n", + " _ = nx.draw_networkx_labels(G, labels=ds_struct.idx_to_dataset,\n", + " pos=pos, font_size=font_size)\n", + " _ = nx.draw_networkx_edge_labels(G, edge_labels=edge_labels,\n", + " pos=pos, font_size=font_size)\n", + " plt.show()\n", + " return\n", + "\n", + "# draw the graph\n", + "visualize_graph(ds_struct)" ] }, { @@ -251,7 +296,12 @@ "\n", "# 2) When data is a dict of dict of 2D arrays, you don't need the ds_struct, but you need to remove all data not meant \n", "# to be used during the training phase.\n", - "# model.fit(data) # uncomment this line if you have this kind of dataset structure" + "# Uncomment the following lines if you have this kind of dataset structure. Note that ds_struct should only be built\n", + "# on MPI ranks with full data.\n", + "\n", + "# ds_struct = Dataset()\n", + "# ds_struct.build_from_data(data)\n", + "# model.fit(data) " ] }, { @@ -497,7 +547,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.6.5" } }, "nbformat": 4, diff --git a/examples/funcalign/requirements.txt b/examples/funcalign/requirements.txt index bf4f33004..dee6184e3 100644 --- a/examples/funcalign/requirements.txt +++ b/examples/funcalign/requirements.txt @@ -1,3 +1,4 @@ matplotlib nilearn notebook +networkx \ No newline at end of file diff --git a/tests/funcalign/test_mdms_distributed.py b/tests/funcalign/test_mdms_distributed.py new file mode 100644 index 000000000..d64755a13 --- /dev/null +++ b/tests/funcalign/test_mdms_distributed.py @@ -0,0 +1,211 @@ +# Copyright 2016 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from sklearn.exceptions import NotFittedError +import pytest +from mpi4py import MPI +from sklearn.datasets import make_spd_matrix + + +def test_distributed_mdms(): # noqa: C901 + import brainiak.funcalign.mdms + s = brainiak.funcalign.mdms.MDMS() + assert s, "Invalid MDMS instance!" + + import numpy as np + np.random.seed(0) + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + nrank = comm.Get_size() + + # set parameters + subj_ds_list = {'D1': ['Adam', 'Bob', 'Carol'], 'D2': ['Tom', 'Bob']} + voxels = {'Adam': 80, 'Bob': 100, 'Carol': 120, 'Tom': 90} + samples = {'D1': 100, 'D2': 50} + features = 3 + noise_level = 0.1 + + s = brainiak.funcalign.mdms.MDMS(n_iter=5, features=features, comm=comm) + assert s, "Invalid MDMS instance!" + + # generate data on rank 0 + if rank == 0: + # generate S + S = {} + for ds in samples: + mean = np.zeros((samples[ds],)) + cov = make_spd_matrix(samples[ds]) + S[ds] = np.random.multivariate_normal(mean, cov, size=features) + # generate W + W = {} + for subj in voxels: + rnd_matrix = np.random.rand(voxels[subj], features) + W[subj], _ = np.linalg.qr(rnd_matrix) + # compute X with noise + X = {} + for ds in samples: + X[ds] = {} + for subj in subj_ds_list[ds]: + noise = np.random.normal(loc=0, scale=noise_level * + abs(np.random.randn()), + size=(voxels[subj], samples[ds])) + X[ds][subj] = W[subj].dot(S[ds]) + noise + # compute data structure + ds_struct = brainiak.funcalign.mdms.Dataset() + ds_struct.build_from_data(X) + assert ds_struct, "Invalid Dataset instance!" + + else: + X = {} + for ds in samples: + X[ds] = {} + for subj in subj_ds_list[ds]: + X[ds][subj] = None + ds_struct = None + + # MDMS: broadcast ds_struct + ds_struct = comm.bcast(ds_struct) + + # Check that transform does NOT run before fitting the model + with pytest.raises(NotFittedError): + s.transform([X['D1']['Adam']], ['Adam']) + if rank == 0: + print("Test: transforming before fitting the model") + + # Check that it does NOT run with wrong X structure + with pytest.raises(Exception): + s.fit({'D1': X['D1'], 'D2': [X['D2']['Bob']]}, ds_struct) + if rank == 0: + print("Test: running MDMS with wrong X data structure") + + # random distribution of data, otherwise None + if rank == 0: + data_mem = {} + tag = 0 # tag start from 0 + for ds in X: + data_mem[ds] = {} + for subj in X[ds]: + data_mem[ds][subj] = [np.random.randint(low=0, high=nrank), + tag] + tag += 1 + else: + data_mem = None + data_mem = comm.bcast(data_mem) + if rank == 0: + X_new = {} + for ds in X: + X_new[ds] = {} + for subj in X[ds]: + mem, tag = data_mem[ds][subj] + if mem != 0: + X_new[ds][subj] = None + comm.send(X[ds][subj], dest=mem, tag=tag) + else: + X_new[ds][subj] = X[ds][subj] + X = X_new + else: + for ds in X: + for subj in X[ds]: + mem, tag = data_mem[ds][subj] + if mem == rank: + X[ds][subj] = comm.recv(source=0, tag=tag) + + # Check that runs with 4 subject + s.fit(X, ds_struct) + assert len(s.s_) == len(samples), ( + "Invalid computation of MDMS! (wrong # datasets in S)") + + assert len(s.w_) == len(voxels), ( + "Invalid computation of MDMS! (wrong # subjects in W)") + + # Check W + for subj in voxels: + assert s.w_[subj].shape[0] == voxels[subj], ( + "Invalid computation of MDMS! (wrong # voxels in W)") + assert s.w_[subj].shape[1] == features, ( + "Invalid computation of MDMS! (wrong # features in W)") + ortho = np.linalg.norm(s.w_[subj].T.dot(s.w_[subj]) + - np.eye(s.w_[subj].shape[1]), + 'fro') + assert ortho < 1e-7, "A Wi mapping is not orthonormal in MDMS." 
+ + # Check S + for ds in samples: + assert s.s_[ds].shape[0] == features, ( + "Invalid computation of MDMS! (wrong # features in S)") + assert s.s_[ds].shape[1] == samples[ds], ( + "Invalid computation of MDMS! (wrong # samples in S)") + + # Check X reconstruction + for ds in X: + for subj in X[ds]: + if X[ds][subj] is not None: + difference = np.linalg.norm(X[ds][subj] - + s.w_[subj].dot(s.s_[ds]), + 'fro') + datanorm = np.linalg.norm(X[ds][subj], 'fro') + assert difference/datanorm < 2.0, ( + "Model seems incorrectly computed.") + + # Check that it does run to compute the shared response of each + # dataset after the model computation + for ds in samples: + data, subjects = [], [] + for subj in X[ds]: + data.append(X[ds][subj]) + subjects.append(subj) + new_s = s.transform(data, subjects) + + assert len(new_s) == len(data), ( + "Invalid computation of MDMS! (wrong #" + " subjects after transform)") + for subj in range(len(new_s)): + if new_s[subj] is not None: + assert new_s[subj].shape[0] == features, ( + "Invalid computation of MDMS! (wrong # features after " + "transform)") + assert new_s[subj].shape[1] == samples[ds], ( + "Invalid computation of MDMS! (wrong # samples after " + "transform)") + + # Check that it does NOT run with non-matching number of subjects + with pytest.raises(ValueError): + s.transform(data, subjects+['new']) + if rank == 0: + print("Test: transforming with non-matching number of subjects") + + # Check that it does not run with different number of voxels for the same + # subjects across datasets + # Only subject 'Bob' is in two datasets, so we change his data + if X['D1']['Bob'] is not None: + tmp = X['D1']['Bob'] + X['D1']['Bob'] = X['D1']['Bob'][: -2, :] + else: + tmp = None + with pytest.raises(ValueError): + s.fit(X, ds_struct) + if rank == 0: + print("Test: different number of voxels for the same subject") + + # Check that it does not run with different number of samples (TRs) + # within the same dataset + X['D1']['Bob'] = tmp # put back the data + if X['D2']['Tom'] is not None: + X['D2']['Tom'] = X['D2']['Tom'][:, : -2] + with pytest.raises(ValueError): + s.fit(X, ds_struct) + if rank == 0: + print("Test: different number of samples within dataset") + + +test_distributed_mdms() From 6df3c1b2a12489e03b989e73d9f07a80b2694ca3 Mon Sep 17 00:00:00 2001 From: Hejia Zhang Date: Tue, 26 Feb 2019 23:13:14 -0500 Subject: [PATCH 4/7] update MDMS test --- brainiak/funcalign/mdms.py | 9 +-- .../mdms_time_segment_matching_example.ipynb | 2 +- tests/funcalign/test_mdms_distributed.py | 70 ++++++++----------- 3 files changed, 36 insertions(+), 45 deletions(-) diff --git a/brainiak/funcalign/mdms.py b/brainiak/funcalign/mdms.py index 839144242..baa18009e 100644 --- a/brainiak/funcalign/mdms.py +++ b/brainiak/funcalign/mdms.py @@ -297,13 +297,14 @@ def _check_missing_data(datasets, shape0, shape1, data_exist): if datasets.dok_matrix[subj, ds_idx] != 0: if data_exist[ds][subj] == 0: raise ValueError("Data of subject {} in dataset {} is " - "missing.".format(datasets.dok_matrix[ - subj, ds_idx]-1, ds)) + "missing." + .format(datasets.idx_to_subject[subj], + ds)) elif data_exist[ds][subj] > 1: raise ValueError("Data of subject {} in dataset {} " "appears more than once." 
- .format(datasets.dok_matrix[ - subj, ds_idx]-1, ds)) + .format(datasets.idx_to_subject[subj], + ds)) else: shape0[ds][subj] = 0 shape1[ds][subj] = 0 diff --git a/examples/funcalign/mdms_time_segment_matching_example.ipynb b/examples/funcalign/mdms_time_segment_matching_example.ipynb index ae7d90ca6..32e674a6c 100644 --- a/examples/funcalign/mdms_time_segment_matching_example.ipynb +++ b/examples/funcalign/mdms_time_segment_matching_example.ipynb @@ -547,7 +547,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.6.3" } }, "nbformat": 4, diff --git a/tests/funcalign/test_mdms_distributed.py b/tests/funcalign/test_mdms_distributed.py index d64755a13..71555deee 100644 --- a/tests/funcalign/test_mdms_distributed.py +++ b/tests/funcalign/test_mdms_distributed.py @@ -52,73 +52,63 @@ def test_distributed_mdms(): # noqa: C901 rnd_matrix = np.random.rand(voxels[subj], features) W[subj], _ = np.linalg.qr(rnd_matrix) # compute X with noise - X = {} + all_X = {} for ds in samples: - X[ds] = {} + all_X[ds] = {} for subj in subj_ds_list[ds]: noise = np.random.normal(loc=0, scale=noise_level * abs(np.random.randn()), size=(voxels[subj], samples[ds])) - X[ds][subj] = W[subj].dot(S[ds]) + noise + all_X[ds][subj] = W[subj].dot(S[ds]) + noise # compute data structure ds_struct = brainiak.funcalign.mdms.Dataset() - ds_struct.build_from_data(X) + ds_struct.build_from_data(all_X) assert ds_struct, "Invalid Dataset instance!" - - else: - X = {} - for ds in samples: - X[ds] = {} - for subj in subj_ds_list[ds]: - X[ds][subj] = None - ds_struct = None - - # MDMS: broadcast ds_struct - ds_struct = comm.bcast(ds_struct) - - # Check that transform does NOT run before fitting the model - with pytest.raises(NotFittedError): - s.transform([X['D1']['Adam']], ['Adam']) - if rank == 0: - print("Test: transforming before fitting the model") - - # Check that it does NOT run with wrong X structure - with pytest.raises(Exception): - s.fit({'D1': X['D1'], 'D2': [X['D2']['Bob']]}, ds_struct) - if rank == 0: - print("Test: running MDMS with wrong X data structure") - - # random distribution of data, otherwise None - if rank == 0: + # To distribute data later data_mem = {} tag = 0 # tag start from 0 - for ds in X: + for ds in all_X: data_mem[ds] = {} - for subj in X[ds]: + for subj in all_X[ds]: data_mem[ds][subj] = [np.random.randint(low=0, high=nrank), tag] tag += 1 else: + ds_struct = None data_mem = None + + # broadcast ds_struct and data_mem + ds_struct = comm.bcast(ds_struct) data_mem = comm.bcast(data_mem) + + # random distribution of data, otherwise None + X = {} + for ds in samples: + X[ds] = {} + for subj in subj_ds_list[ds]: + X[ds][subj] = None + if rank == 0: - X_new = {} for ds in X: - X_new[ds] = {} for subj in X[ds]: mem, tag = data_mem[ds][subj] if mem != 0: - X_new[ds][subj] = None - comm.send(X[ds][subj], dest=mem, tag=tag) + comm.send(all_X[ds][subj], dest=mem, tag=tag) else: - X_new[ds][subj] = X[ds][subj] - X = X_new + X[ds][subj] = all_X[ds][subj] + del all_X else: for ds in X: for subj in X[ds]: mem, tag = data_mem[ds][subj] - if mem == rank: - X[ds][subj] = comm.recv(source=0, tag=tag) + if mem == rank: + X[ds][subj] = comm.recv(source=0, tag=tag) + + # Check that transform does NOT run before fitting the model + with pytest.raises(NotFittedError): + s.transform([X['D1']['Adam']], ['Adam']) + if rank == 0: + print("Test: transforming before fitting the model") # Check that runs with 4 subject s.fit(X, ds_struct) From 
7f70c523580c2969d9475872ae2e92f6bf9ff4c2 Mon Sep 17 00:00:00 2001 From: Hejia Zhang Date: Wed, 27 Feb 2019 02:41:57 -0500 Subject: [PATCH 5/7] edit docs --- brainiak/funcalign/mdms.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/brainiak/funcalign/mdms.py b/brainiak/funcalign/mdms.py index baa18009e..129bdf93b 100644 --- a/brainiak/funcalign/mdms.py +++ b/brainiak/funcalign/mdms.py @@ -14,6 +14,7 @@ """multi-dataset multi-subject (MDMS) SRM analysis The implementations are based on the following publications: + .. [Zhang2018] "Transfer learning on fMRI datasets", H. Zhang, P.-H. Chen, P. Ramadge The 21st International Conference on Artificial Intelligence and @@ -318,8 +319,8 @@ class MDMS(BaseEstimator, TransformerMixin): response S among all subjects per dataset and an orthogonal transform W across all datasets per subject: - .. math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\ - dots M + .. math:: + X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\dots M\\ Parameters ---------- @@ -1260,15 +1261,11 @@ def restore(self, file): class DetMDMS(BaseEstimator, TransformerMixin): - """Deterministic multi-dataset multi-subject (MDMS) SRM analysis - (DetMDMS) + """Deterministic multi-dataset multi-subject (MDMS) Given multi-dataset multi-subject data, factorize it as a shared response S among all subjects per dataset and an orthogonal transform W - across all datasets per subject: - - .. math:: X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\ - dots M + across all datasets per subject. Parameters ---------- @@ -1331,8 +1328,8 @@ class DetMDMS(BaseEstimator, TransformerMixin): memory complexity is :math:`O(V T)` with I - the number of iterations, V - the sum of number of voxels from all subjects, T - the sum of number of samples from all datasets, K - the number of - features (typically, - :math:`V \\gg T \\gg K`), and N - the number of subjects. + features (typically, :math:`V \\gg T \\gg K`), and + N - the number of subjects. """ def __init__(self, n_iter=10, features=50, rand_seed=0, @@ -1978,6 +1975,9 @@ class Dataset(object): a graph where each dataset is a node and each edge is number of shared subjects between the two datasets. + .. math:: + X_{ds} \\approx W_s S_d, \\forall s=1 \\dots N, \\forall d=1 \\dots M\\ + This organizer is used in the MDMS or DetMDMS [Zhang2018]_ and can also be used as a standalone datasets organizer. From 6bdb01121e473057914160fc27fc0a7b0a6d78e5 Mon Sep 17 00:00:00 2001 From: Mingbo Cai Date: Mon, 14 Aug 2023 16:52:46 +0900 Subject: [PATCH 6/7] Update mdms.py change `np.int` to `np.int32` to conform to update of numpy. 
--- brainiak/funcalign/mdms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/brainiak/funcalign/mdms.py b/brainiak/funcalign/mdms.py index 129bdf93b..3a33cda8d 100644 --- a/brainiak/funcalign/mdms.py +++ b/brainiak/funcalign/mdms.py @@ -228,9 +228,9 @@ def _collect_size_information(X, datasets, comm): ds_list = datasets.get_datasets_list() for ds in ds_list: # initialization - shape0[ds] = np.zeros((datasets.num_subj,), dtype=np.int) - shape1[ds] = np.zeros((datasets.num_subj,), dtype=np.int) - data_exist[ds] = np.zeros((datasets.num_subj,), dtype=np.int) + shape0[ds] = np.zeros((datasets.num_subj,), dtype=np.int32) + shape1[ds] = np.zeros((datasets.num_subj,), dtype=np.int32) + data_exist[ds] = np.zeros((datasets.num_subj,), dtype=np.int32) ds_idx = datasets.dataset_to_idx[ds] # collect size information of each dataset if X[ds] is not None: From 365ca3712a82c10a2a82be67512a06166c9c43fe Mon Sep 17 00:00:00 2001 From: Mingbo Cai Date: Mon, 14 Aug 2023 16:54:17 +0900 Subject: [PATCH 7/7] Update mdms.py correct formatting --- brainiak/funcalign/mdms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/brainiak/funcalign/mdms.py b/brainiak/funcalign/mdms.py index 3a33cda8d..5aaa1a122 100644 --- a/brainiak/funcalign/mdms.py +++ b/brainiak/funcalign/mdms.py @@ -80,8 +80,7 @@ def _init_w_transforms(voxels, features, random_states, datasets): ------- w : dict of array, w[s] has shape=[voxels[s], features] where s is the - name - of the subject. + name of the subject. The initialized orthogonal transforms (mappings) :math:`W_s` for each subject.