From 181d30e8e3ac13a2211168cf569193aec61f490d Mon Sep 17 00:00:00 2001 From: Christopher Baldassano Date: Tue, 20 Feb 2018 09:19:32 -0500 Subject: [PATCH 1/4] Fixing bug in event segmentation code --- brainiak/eventseg/event.py | 20 ++++++++++---------- tests/.flake8 | 2 +- tests/eventseg/test_event.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/brainiak/eventseg/event.py b/brainiak/eventseg/event.py index e0c12abb5..4373addd9 100644 --- a/brainiak/eventseg/event.py +++ b/brainiak/eventseg/event.py @@ -268,23 +268,23 @@ def _forward_backward(self, logprob): p_end[0, -2] = 1 # Forward pass - for t in range(t): - if t == 0: + for i in range(t): + if i == 0: log_alpha[0, :] = self._log(p_start) + logprob[0, :] else: - log_alpha[t, :] = self._log(np.exp(log_alpha[t - 1, :]) - .dot(P)) + logprob[t, :] + log_alpha[i, :] = self._log(np.exp(log_alpha[i - 1, :]) + .dot(P)) + logprob[i, :] - log_scale[t] = np.logaddexp.reduce(log_alpha[t, :]) - log_alpha[t] -= log_scale[t] + log_scale[i] = np.logaddexp.reduce(log_alpha[i, :]) + log_alpha[i] -= log_scale[i] # Backward pass log_beta[-1, :] = self._log(p_end) - log_scale[-1] - for t in reversed(range(t - 1)): - obs_weighted = log_beta[t + 1, :] + logprob[t + 1, :] + for i in reversed(range(t - 1)): + obs_weighted = log_beta[i + 1, :] + logprob[i + 1, :] offset = np.max(obs_weighted) - log_beta[t, :] = offset + self._log( - np.exp(obs_weighted - offset).dot(P.T)) - log_scale[t] + log_beta[i, :] = offset + self._log( + np.exp(obs_weighted - offset).dot(P.T)) - log_scale[i] # Combine and normalize log_gamma = log_alpha + log_beta diff --git a/tests/.flake8 b/tests/.flake8 index 49611972d..3c0105d98 100644 --- a/tests/.flake8 +++ b/tests/.flake8 @@ -12,5 +12,5 @@ ignore = W503, W504, # Print restriction only applies to libraries - T001 + T001, T003 diff --git a/tests/eventseg/test_event.py b/tests/eventseg/test_event.py index be6ef345e..dd60134ea 100644 --- a/tests/eventseg/test_event.py +++ b/tests/eventseg/test_event.py @@ -74,3 +74,17 @@ def test_weighted_var(): assert np.allclose( es.calc_weighted_event_var(D, weights, mean_pat), true_var),\ "Failed to compute variance with fractional weights" + + +def test_sym(): + es = brainiak.eventseg.event.EventSegment(4) + + evpat = np.repeat(np.arange(10).reshape(-1, 1), 4, axis=1) + es.set_event_patterns(evpat) + + D = np.repeat(np.arange(10).reshape(1, -1), 20, axis=0) + ev = es.find_events(D, var=1)[0] + + # Check that events 1-4 and 2-3 are symmetric + assert np.all(np.isclose(ev[:, :2], np.fliplr(np.flipud(ev[:, 2:])))),\ + "Fit with constant data is not symmetric" From 36c03d643b5ecaf7fc7ff5c43a97c6c4b1cb59ee Mon Sep 17 00:00:00 2001 From: Christopher Baldassano Date: Tue, 20 Feb 2018 09:24:44 -0500 Subject: [PATCH 2/4] Adding news fragment --- docs/newsfragments/X.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/newsfragments/X.bugfix diff --git a/docs/newsfragments/X.bugfix b/docs/newsfragments/X.bugfix new file mode 100644 index 000000000..2ff2c43a5 --- /dev/null +++ b/docs/newsfragments/X.bugfix @@ -0,0 +1 @@ +Fixed bug in eventseg that was causing fits to be asymmetric From 5157fc1f2636681a0e235ba1a41d4bdd213cac1c Mon Sep 17 00:00:00 2001 From: Christopher Baldassano Date: Tue, 20 Feb 2018 09:27:31 -0500 Subject: [PATCH 3/4] Updating news fragment title --- docs/newsfragments/X.bugfix => 339.bugfix | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/newsfragments/X.bugfix => 339.bugfix (100%) diff --git a/docs/newsfragments/X.bugfix b/339.bugfix similarity index 100% rename from docs/newsfragments/X.bugfix rename to 339.bugfix From 4c2c85bcd6f2741517172f4a2003165337ae3eb2 Mon Sep 17 00:00:00 2001 From: Christopher Baldassano Date: Tue, 20 Feb 2018 09:27:52 -0500 Subject: [PATCH 4/4] Updating news fragment title --- 339.bugfix => docs/newsfragments/339.bugfix | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename 339.bugfix => docs/newsfragments/339.bugfix (100%) diff --git a/339.bugfix b/docs/newsfragments/339.bugfix similarity index 100% rename from 339.bugfix rename to docs/newsfragments/339.bugfix