Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions brainiak/eventseg/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,23 +268,23 @@ def _forward_backward(self, logprob):
p_end[0, -2] = 1

# Forward pass
for t in range(t):
if t == 0:
for i in range(t):
if i == 0:
log_alpha[0, :] = self._log(p_start) + logprob[0, :]
else:
log_alpha[t, :] = self._log(np.exp(log_alpha[t - 1, :])
.dot(P)) + logprob[t, :]
log_alpha[i, :] = self._log(np.exp(log_alpha[i - 1, :])
.dot(P)) + logprob[i, :]

log_scale[t] = np.logaddexp.reduce(log_alpha[t, :])
log_alpha[t] -= log_scale[t]
log_scale[i] = np.logaddexp.reduce(log_alpha[i, :])
log_alpha[i] -= log_scale[i]

# Backward pass
log_beta[-1, :] = self._log(p_end) - log_scale[-1]
for t in reversed(range(t - 1)):
obs_weighted = log_beta[t + 1, :] + logprob[t + 1, :]
for i in reversed(range(t - 1)):
obs_weighted = log_beta[i + 1, :] + logprob[i + 1, :]
offset = np.max(obs_weighted)
log_beta[t, :] = offset + self._log(
np.exp(obs_weighted - offset).dot(P.T)) - log_scale[t]
log_beta[i, :] = offset + self._log(
np.exp(obs_weighted - offset).dot(P.T)) - log_scale[i]

# Combine and normalize
log_gamma = log_alpha + log_beta
Expand Down
1 change: 1 addition & 0 deletions docs/newsfragments/339.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug in eventseg's forward-backward pass (the loop index shadowed the timepoint count) that was causing fits to be asymmetric.
2 changes: 1 addition & 1 deletion tests/.flake8
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ ignore =
W503,
W504,
# Print restriction only applies to libraries
T001
T001,
T003
14 changes: 14 additions & 0 deletions tests/eventseg/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_weighted_var():
assert np.allclose(
es.calc_weighted_event_var(D, weights, mean_pat), true_var),\
"Failed to compute variance with fractional weights"


def test_sym():
    """Check that a fit to constant data is mirror-symmetric across events.

    A four-event ``EventSegment`` is given identical patterns for every
    event, then fit (with ``var=1``) to data whose rows are all equal to
    that same pattern.  The first element of the ``find_events`` result
    (presumably the per-timepoint event probabilities -- confirm against
    ``EventSegment.find_events``) must have its first two columns equal to
    its last two columns reversed along both axes.
    """
    n_events = 4
    segmenter = brainiak.eventseg.event.EventSegment(n_events)

    # Every event gets the same pattern: the values 0..9 down the rows.
    ramp = np.arange(10)
    patterns = np.repeat(ramp.reshape(-1, 1), n_events, axis=1)
    segmenter.set_event_patterns(patterns)

    # Twenty identical rows, each one equal to that shared pattern.
    data = np.repeat(ramp.reshape(1, -1), 20, axis=0)
    ev = segmenter.find_events(data, var=1)[0]

    # Events 1-4 and 2-3 should mirror each other: reversing both axes of
    # the last two columns must reproduce the first two columns.
    leading = ev[:, :2]
    trailing_mirrored = ev[:, 2:][::-1, ::-1]
    assert np.allclose(leading, trailing_mirrored), \
        "Fit with constant data is not symmetric"