@tylerjereddy
Contributor

  • A few minor simplifications/adjustments that seem to give about 5-10% better performance for non-tree KernelExplainer for the scenario described at Query, ENH: faster KernelExplainer #3943.

  • I don't think we have formal asv-style benchmarks in this project, but informal testing with 5 trials each (times in seconds) for the reproducer in the above issue (on an ARM Mac):

    • master: 52.803, 50.711, 51.503, 50.613, 51.348
    • this branch: 47.557, 47.128, 48.197, 47.651, 47.216
  • This is less impactful than the improvements at ENH: faster non-tree KernelExplainer #3944, but also less intrusive, and the improvements should compound with each other.

  • I added postdoc @arhall0 as a co-author on the commit, since they did the initial profiling work to find the ~10% bottleneck.

  • Since it appears that we only really need to identify the first point of difference between the arrays, it is probably possible to write an early-break algorithm that is more efficient, but that would probably increase complexity/need for compiled backend, so this is just a start to bump things a little bit.
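The early-break idea in the last bullet could be approximated in pure NumPy by scanning in chunks and stopping at the first chunk that differs. This is only a rough sketch: `any_mismatch_chunked` and its `chunk` parameter are hypothetical names that are not part of shap, and the sketch assumes 1-D numeric arrays.

```python
import numpy as np


def any_mismatch_chunked(a, b, chunk=1024):
    """Return True as soon as one chunk of `a` and `b` differs.

    Hypothetical helper (not part of shap). Assumes 1-D numeric arrays.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        return True
    for start in range(0, a.size, chunk):
        stop = start + chunk
        if not np.allclose(a[start:stop], b[start:stop], equal_nan=True):
            # Early break: later chunks are never inspected.
            return True
    return False


a = np.zeros(5000)
b = a.copy()
b[3] = 1.0  # mismatch in the first chunk, so only one chunk is compared
print(any_mismatch_chunked(a, b))
```

Whether this beats a single `np.allclose` call depends on where the first difference tends to occur; when the arrays are equal, the chunked loop does strictly more work.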

Checklist

  • All pre-commit checks pass.
  • Unit tests added (if fixing a bug or adding a new feature)

* A few minor simplifications/adjustments that seem to give about 5-10%
better performance for non-tree `KernelExplainer` for the scenario
described at shapgh-3943.

* I don't think we have formal `asv`-style benchmarks in this project,
but informal testing with 5 trials each (times in seconds)
for the reproducer in the above issue (on an ARM Mac):
- `master`:     52.803, 50.711, 51.503, 50.613, 51.348
- this branch:  47.557, 47.128, 48.197, 47.651, 47.216

* This is less impactful than the improvements at shapgh-3944, but also
less intrusive, and the improvements should compound with each other.

Co-authored-by: Aaron R Hall <arhall@lanl.gov>
return 0 if np.allclose(i, j, equal_nan=True) else 1
elif hasattr(i, "dtype") and hasattr(j, "dtype"):
if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
return 0 if np.allclose(i, j, equal_nan=True) else 1
Contributor Author

This part of the changes was needed because some "perfectly-normal" NumPy arrays in the testsuite were not caught by the condition above after removing the frompyfunc business (which, of course, doesn't actually make CPython code any faster).

Obviously, make sure you/we are comfortable that the current testsuite is robust for this codepath. It did seem to have many failures without this shim, so possibly coverage is "ok."
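A minimal sketch of the numeric-array branch discussed here, pulled out into a standalone function for illustration. `numeric_arrays_differ` is a hypothetical name, and raising `TypeError` for other inputs is an assumption of this sketch rather than what the PR does.

```python
import numpy as np


def numeric_arrays_differ(i, j):
    """Return 1 if two numeric arrays differ, else 0 (hypothetical helper)."""
    if hasattr(i, "dtype") and hasattr(j, "dtype"):
        if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
            # One whole-array comparison; NaNs in matching positions count as equal.
            return 0 if np.allclose(i, j, equal_nan=True) else 1
    # Assumption of this sketch only: anything else is rejected outright.
    raise TypeError("non-numeric inputs are outside this sketch")


x = np.array([1.0, np.nan, 3.0])
print(numeric_arrays_differ(x, x.copy()))                     # equal, NaN included
print(numeric_arrays_differ(x, np.array([1.0, np.nan, 4.0])))  # last entry differs
```

Note that `np.bool_` and `np.object_` are not subtypes of `np.number`, so arrays of those dtypes would not be handled by this branch.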

Comment on lines +498 to +501
return 0 if np.allclose(i, j, equal_nan=True) else 1
elif hasattr(i, "dtype") and hasattr(j, "dtype"):
if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
return 0 if np.allclose(i, j, equal_nan=True) else 1
Collaborator

Thanks for the comment. I started a new conversation on these lines 'cause I somehow couldn't comment + suggest on the 4 lines on the previous one. Isn't there also a case where i and j are not a subtype of np.number? Is there a chance that we return None here? I can't think of one but just to be sure could we do:

Suggested change
- return 0 if np.allclose(i, j, equal_nan=True) else 1
- elif hasattr(i, "dtype") and hasattr(j, "dtype"):
- if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
- return 0 if np.allclose(i, j, equal_nan=True) else 1
+ return 0 if np.allclose(i, j, equal_nan=True) else 1
+ elif hasattr(i, "dtype") and hasattr(j, "dtype"):
+ if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
+ return 0 if np.allclose(i, j, equal_nan=True) else 1
+ return 0 if i == j else 1

To have the general comparison as a fallback for all cases where if/else branches might be missing?

  x_group = x_group.todense()
- num_mismatches = np.sum(np.frompyfunc(self.not_equal, 2, 1)(x_group, self.data.data[:, inds]))
- varying[i] = num_mismatches > 0
+ varying[i] = self.not_equal(x_group, self.data.data[:, inds])
Collaborator

nice improvement
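To illustrate why dropping `np.frompyfunc` helps, the sketch below counts how often a Python-level comparison runs under each approach. `scalar_not_equal` and the `calls` counter are hypothetical stand-ins, not the shap implementation.

```python
import numpy as np

calls = 0


def scalar_not_equal(i, j):
    """Hypothetical per-element comparison that counts its invocations."""
    global calls
    calls += 1
    return 0 if np.isclose(i, j, equal_nan=True) else 1


x = np.arange(12, dtype=float).reshape(3, 4)
y = x.copy()
y[2, 3] = -1.0  # a single differing entry

# Old style: frompyfunc calls the Python function once per element.
elementwise = np.frompyfunc(scalar_not_equal, 2, 1)(x, y)
num_mismatches = np.sum(elementwise)

# New style: one vectorized comparison over the whole array.
varies_whole = not np.allclose(x, y, equal_nan=True)

print(calls, num_mismatches > 0, varies_whole)
```

Both approaches agree on whether the group varies, but the first pays Python call overhead for every element.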

@codecov

codecov bot commented Feb 8, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 64.71%. Comparing base (baace0f) to head (b3c8bb8).
Report is 39 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #3983      +/-   ##
==========================================
+ Coverage   64.67%   64.71%   +0.04%     
==========================================
  Files          92       92              
  Lines       12862    12884      +22     
==========================================
+ Hits         8318     8338      +20     
- Misses       4544     4546       +2     

tylerjereddy and others added 2 commits February 8, 2025 17:56
* Expand the handling of array/object types in `not_equal()`
based on reviewer feedback. Add a related test that fails
when using the naive `i == j` fallback.
return 0 if np.allclose(i, j, equal_nan=True) else 1
if np.issubdtype(i.dtype, np.bool_) and np.issubdtype(j.dtype, np.bool_):
return 0 if np.allclose(i, j, equal_nan=True) else 1
return 0 if all(i == j) else 1
Contributor Author

I tried to address #3983 (comment), but found that I actually needed a slight adjustment to the fallback for array-like inputs with all(). I also added custom handling for bool_, which seems "ok" with np.allclose().

np.testing.assert_allclose(sigm(shap_values.values.sum(1) + explainer.expected_value), pred, atol=1e-04)


@pytest.mark.parametrize("dt", [np.bool_, np.object_])
Contributor Author

This test is a bit weird in the sense that it doesn't really "fail before" and "pass after," but it does fail if the fallback I recently added uses return 0 if i == j else 1 instead of return 0 if all(i == j) else 1 because for the np.object_ dtype you'll get the typical ValueError: The truth value of an array with more than one element is ambiguous.

So it is more of a sniff test for things slipping through the cracks, but it isn't super picky beyond that, and in fact the original returning of None noted above in some cases still allows the test to pass. It may be possible to cook the test up in a way that is more stringent than that, but it is at least doing something.
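The failure mode described above can be reproduced outside of shap with a small `np.object_` array. This is an illustrative sketch with made-up data, not the test from the PR.

```python
import numpy as np

i = np.array(["a", "b", None], dtype=np.object_)
j = np.array(["a", "b", None], dtype=np.object_)

# Naive fallback: `i == j` is an element-wise boolean array, and using it
# directly as a condition raises ValueError for more than one element.
try:
    result = 0 if i == j else 1
except ValueError as exc:
    naive_error = str(exc)
else:
    naive_error = None

# Reducing with all() first gives a single truth value.
reduced = 0 if all(i == j) else 1

print(naive_error)
print(reduced)
```

The built-in `all()` iterates over the first axis only, so this sketch deliberately uses 1-D arrays.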

* Relax the stringency of the numeric closeness check
in `test_explainer_non_number_dtype`, which was failing in CI.
rf.fit(X, y)
explainer = shap.KernelExplainer(model=rf.predict_proba, data=X, random_state=seed)
shap_values = explainer(X)
np.testing.assert_allclose(shap_values.values.max(), 0.26548, rtol=1e-2)
Contributor Author

the rtol needed for CI to be happy is rather high; conversely, the assertion here isn't really needed for the purpose of the test, which is mostly to fail if there is fundamental logic issue in not_equal like using i == j when all(i == j) is needed.
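For a sense of how loose `rtol=1e-2` is, the sketch below compares the expected value from the test with a hypothetical observed value; `0.2670` is made up for illustration and is not a number reported by CI.

```python
import numpy as np

desired = 0.26548  # expected value used in the test
actual = 0.2670    # hypothetical observed value, for illustration only

# Passes: |actual - desired| = 0.00152 <= rtol * |desired| = 0.0026548
np.testing.assert_allclose(actual, desired, rtol=1e-2)

# The default rtol of 1e-7 would reject the same pair.
try:
    np.testing.assert_allclose(actual, desired, rtol=1e-7)
except AssertionError:
    strict_failed = True
else:
    strict_failed = False

print(strict_failed)
```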

Collaborator

@CloseChoice left a comment

LGTM, very nice.

@CloseChoice merged commit f1808f5 into shap:master Feb 11, 2025
18 checks passed
@tylerjereddy deleted the treddy_issue_3943_close_speedup branch February 11, 2025 17:40