Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions brainiak/searchlight/searchlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,18 +420,37 @@ def run_block_function(self, block_fn, extra_block_fn_params=None,
processes = usable_cpus
else:
processes = min(pool_size, usable_cpus)
with Pool(processes) as pool:

if processes > 1:
with Pool(processes) as pool:
for idx, block in enumerate(self.blocks):
result = pool.apply_async(
block_fn,
([subproblem[idx] for subproblem in self.subproblems],
self.submasks[idx],
self.sl_rad,
self.bcast_var,
extra_block_fn_params))
results.append((block[0], result))
local_outputs = [(result[0], result[1].get())
for result in results]
else:
# If we only are using one CPU core, no need to create a Pool,
# cause an underlying fork(), and send the data to that process.
# Just do it here in serial. This will save copying the memory
# and will stop a fork() which can cause problems in some MPI
# implementations.
for idx, block in enumerate(self.blocks):
result = pool.apply_async(
block_fn,
([subproblem[idx] for subproblem in self.subproblems],
self.submasks[idx],
self.sl_rad,
self.bcast_var,
extra_block_fn_params))
subprob_list = [subproblem[idx]
for subproblem in self.subproblems]
result = block_fn(
subprob_list,
self.submasks[idx],
self.sl_rad,
self.bcast_var,
extra_block_fn_params)
results.append((block[0], result))
local_outputs = [(result[0], result[1].get())
for result in results]
local_outputs = [(result[0], result[1]) for result in results]

# Collect results
global_outputs = self.comm.gather(local_outputs)
Expand Down
30 changes: 30 additions & 0 deletions tests/searchlight/test_searchlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,36 @@ def test_searchlight_with_cube():
assert global_outputs[i, j, k] is None


def test_searchlight_with_cube_poolsize_1():
    """Run the cube searchlight with ``pool_size=1``.

    Same scenario as the default-pool cube test, but forcing a single
    worker so ``run_block_function`` takes its serial branch instead of
    creating a ``multiprocessing.Pool``. On the root rank, the result must
    be 1.0 at the centre voxel (13, 13, 13) and ``None`` everywhere else.
    """
    sl = Searchlight(sl_rad=3)
    comm = MPI.COMM_WORLD
    rank = comm.rank
    size = comm.size
    dim0, dim1, dim2 = (50, 50, 50)
    ntr = 30
    nsubj = 3
    # Use the builtin ``bool``/``object`` dtypes: the ``np.bool`` and
    # ``np.object`` aliases were removed in NumPy 1.24.
    mask = np.zeros((dim0, dim1, dim2), dtype=bool)
    # Each rank only holds the subjects assigned to it round-robin;
    # the other slots are None.
    data = [np.empty((dim0, dim1, dim2, ntr), dtype=object)
            if i % size == rank
            else None
            for i in range(0, nsubj)]

    # Put a spot in the mask
    mask[10:17, 10:17, 10:17] = True

    sl.distribute(data, mask)
    global_outputs = sl.run_searchlight(cube_sfn, pool_size=1)

    if rank == 0:
        assert global_outputs[13, 13, 13] == 1.0
        # Clear the one expected hit so the sweep below can check that
        # every other voxel is empty.
        global_outputs[13, 13, 13] = None

        for idx in np.ndindex(global_outputs.shape):
            assert global_outputs[idx] is None


def diamond_sfn(l, msk, myrad, bcast_var):
assert not np.any(msk[~Diamond(3).mask_])
if np.all(msk[Diamond(3).mask_]):
Expand Down