ENH: minor speedups in non-tree KernelExplainer #3983
* A few minor simplifications/adjustments that seem to give about 5-10% better performance for non-tree `KernelExplainer` for the scenario described at shapgh-3943.
* I don't think we have formal `asv`-style benchmarks in this project, but informal testing with 5 trials each (times in seconds) for the reproducer in the above issue (on an ARM Mac):
  - `master`: 52.803, 50.711, 51.503, 50.613, 51.348
  - this branch: 47.557, 47.128, 48.197, 47.651, 47.216
* This is less impactful than the improvements at shapgh-3944, but also less intrusive, and the improvements should compound with each other.

Co-authored-by: Aaron R Hall <arhall@lanl.gov>
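For anyone wanting to repeat the informal comparison, here is a minimal sketch of a 5-trial timing loop, plus the mean speedup implied by the numbers quoted above. `time_trials` is a hypothetical helper written for this note, not something in the project; the reproducer from the linked issue is not included here.

```python
import statistics
import time


def time_trials(fn, n_trials=5):
    """Run fn() n_trials times and return the wall-clock duration of each run."""
    timings = []
    for _ in range(n_trials):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return timings


# Timings (seconds) copied from the commit message above.
master = [52.803, 50.711, 51.503, 50.613, 51.348]
branch = [47.557, 47.128, 48.197, 47.651, 47.216]

# Fractional reduction in mean runtime on this branch relative to master.
speedup = 1 - statistics.mean(branch) / statistics.mean(master)
print(f"mean speedup: {speedup:.1%}")
```

On these numbers the mean reduction comes out at roughly 7.5%, which sits inside the 5-10% range claimed.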
        return 0 if np.allclose(i, j, equal_nan=True) else 1
    elif hasattr(i, "dtype") and hasattr(j, "dtype"):
        if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
            return 0 if np.allclose(i, j, equal_nan=True) else 1
This part of the changes was needed because some "perfectly-normal" NumPy arrays in the testsuite were not caught by the condition above after removing the frompyfunc business (which, of course, doesn't actually make CPython code any faster).
Obviously, make sure you/we are comfortable that the current testsuite is robust for this codepath. It did seem to have many failures without this shim, so possibly coverage is "ok."
        return 0 if np.allclose(i, j, equal_nan=True) else 1
    elif hasattr(i, "dtype") and hasattr(j, "dtype"):
        if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
            return 0 if np.allclose(i, j, equal_nan=True) else 1
Thanks for the comment. I started a new conversation on these lines 'cause I somehow couldn't comment + suggest on the 4 lines on the previous one. Isn't there also a case where i and j are not a subtype of np.number? Is there a chance that we return None here? I can't think of one but just to be sure could we do:
        return 0 if np.allclose(i, j, equal_nan=True) else 1
    elif hasattr(i, "dtype") and hasattr(j, "dtype"):
        if np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number):
            return 0 if np.allclose(i, j, equal_nan=True) else 1
    return 0 if i == j else 1
To have the general comparison as a fallback for all cases where if/else branches might be missing?
      x_group = x_group.todense()
    - num_mismatches = np.sum(np.frompyfunc(self.not_equal, 2, 1)(x_group, self.data.data[:, inds]))
    - varying[i] = num_mismatches > 0
    + varying[i] = self.not_equal(x_group, self.data.data[:, inds])
nice improvement
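To illustrate why dropping `np.frompyfunc` helps, here is a small sketch contrasting the two styles on made-up data. `scalar_not_equal`, `background`, and `x_group` are stand-ins invented for this example, not the actual `KernelExplainer` internals.

```python
import numpy as np

rng = np.random.default_rng(0)
background = rng.normal(size=(200, 3))  # stand-in for the background data block
x_group = background[0:1, :]            # one sample, broadcast against every row


def scalar_not_equal(a, b):
    """Scalar comparison: 0 if the two values match (NaN == NaN), else 1."""
    return 0 if np.isclose(a, b, equal_nan=True) else 1


# Old style: frompyfunc broadcasts the inputs but still makes one
# Python-level call per element pair, then sums the object-dtype result.
elementwise = np.frompyfunc(scalar_not_equal, 2, 1)
varying_old = np.sum(elementwise(x_group, background)) > 0

# New style: one vectorized comparison over the whole block.
varying_new = not np.allclose(x_group, background, equal_nan=True)

print(bool(varying_old), varying_new)
```

Both paths agree on whether the feature group varies; the second one avoids the per-element Python call overhead.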
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

    @@            Coverage Diff             @@
    ##           master    #3983      +/-   ##
    ==========================================
    + Coverage   64.67%   64.71%   +0.04%
    ==========================================
      Files          92       92
      Lines       12862    12884      +22
    ==========================================
    + Hits         8318     8338      +20
    - Misses       4544     4546       +2
* Expand the handling of array/object types in `not_equal()` based on reviewer feedback. Add a related test that fails when using the naive `i == j` fallback.
            return 0 if np.allclose(i, j, equal_nan=True) else 1
        if np.issubdtype(i.dtype, np.bool_) and np.issubdtype(j.dtype, np.bool_):
            return 0 if np.allclose(i, j, equal_nan=True) else 1
        return 0 if all(i == j) else 1
I tried to address #3983 (comment), but found that I actually needed a slight adjustment to the fallback for array-like inputs with all(). I also added custom handling for bool_, which seems "ok" with np.allclose().
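Putting the pieces from this thread together, the dtype dispatch might look roughly like the sketch below. This is an assumption-laden reconstruction, not the merged code: `not_equal_sketch` is a hypothetical name, the surrounding branches of the real `not_equal()` are not shown in the diff, and `np.all` is swapped in for the builtin `all` so that 2-D inputs also reduce to a single boolean.

```python
import numpy as np


def not_equal_sketch(i, j):
    """Return 0 if i and j match (NaNs treated as equal), else 1."""
    if hasattr(i, "dtype") and hasattr(j, "dtype"):
        both_numeric = np.issubdtype(i.dtype, np.number) and np.issubdtype(j.dtype, np.number)
        both_bool = np.issubdtype(i.dtype, np.bool_) and np.issubdtype(j.dtype, np.bool_)
        if both_numeric or both_bool:
            # Numeric and boolean arrays can go through the tolerance-based check.
            return 0 if np.allclose(i, j, equal_nan=True) else 1
        # Other array dtypes (e.g. object): reduce the element-wise comparison
        # explicitly so the condition is a single boolean.
        return 0 if np.all(i == j) else 1
    # Plain Python scalars and other objects without a dtype.
    return 0 if i == j else 1


print(not_equal_sketch(np.array([1.0, np.nan]), np.array([1.0, np.nan])))
print(not_equal_sketch(np.array(["a", "b"], dtype=object), np.array(["a", "c"], dtype=object)))
```

Every path returns 0 or 1, so there is no way to fall through and return `None`, which was the concern raised earlier in the review.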
    np.testing.assert_allclose(sigm(shap_values.values.sum(1) + explainer.expected_value), pred, atol=1e-04)


    @pytest.mark.parametrize("dt", [np.bool_, np.object_])
This test is a bit weird in the sense that it doesn't really "fail before" and "pass after," but it does fail if the fallback I recently added uses return 0 if i == j else 1 instead of return 0 if all(i == j) else 1 because for the np.object_ dtype you'll get the typical ValueError: The truth value of an array with more than one element is ambiguous.
So it is more of a sniff test for things slipping through the cracks, but it isn't super picky beyond that, and in fact the original returning of None noted above in some cases still allows the test to pass. It may be possible to cook the test up in a way that is more stringent than that, but it is at least doing something.
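The failure mode described above is easy to reproduce in isolation. The sketch below uses two small made-up `np.object_` arrays (not the test's actual data) to show the naive fallback raising and the `all()` version reducing cleanly.

```python
import numpy as np

i = np.array(["a", "b"], dtype=np.object_)
j = np.array(["a", "b"], dtype=np.object_)

# Naive fallback: `i == j` is an element-wise boolean array, so using it
# directly as a condition raises ValueError for more than one element.
try:
    result = 0 if i == j else 1
except ValueError as exc:
    naive_error = str(exc)
else:
    naive_error = None

# Reducing with all() collapses the comparison to a single boolean first.
reduced = 0 if all(i == j) else 1

print(naive_error)
print(reduced)
```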
* Relax the stringency of the numeric closeness check in `test_explainer_non_number_dtype`, which was failing in CI.
    rf.fit(X, y)
    explainer = shap.KernelExplainer(model=rf.predict_proba, data=X, random_state=seed)
    shap_values = explainer(X)
    np.testing.assert_allclose(shap_values.values.max(), 0.26548, rtol=1e-2)
The rtol needed for CI to be happy is rather high; conversely, the assertion here isn't really needed for the purpose of the test, which is mostly to fail if there is a fundamental logic issue in not_equal, like using i == j when all(i == j) is needed.
CloseChoice left a comment
LGTM, very nice.
A few minor simplifications/adjustments that seem to give about 5-10% better performance for non-tree `KernelExplainer` for the scenario described at Query, ENH: faster KernelExplainer #3943.

I don't think we have formal `asv`-style benchmarks in this project, but informal testing with 5 trials each (times in seconds) for the reproducer in the above issue (on an ARM Mac):

- `master`: 52.803, 50.711, 51.503, 50.613, 51.348
- this branch: 47.557, 47.128, 48.197, 47.651, 47.216

This is less impactful than the improvements at ENH: faster non-tree KernelExplainer #3944, but also less intrusive, and the improvements should compound with each other.
I added postdoc @arhall0 as a co-author on the commit, since they did the initial profiling work to find the ~10% bottleneck.
Since it appears that we only really need to identify the first point of difference between the arrays, it is probably possible to write an early-break algorithm that is more efficient, but that would probably increase complexity/need for compiled backend, so this is just a start to bump things a little bit.
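As a rough illustration of the early-break idea without a compiled backend, one option is to compare in chunks and stop at the first chunk that differs. This is only a sketch under assumptions: `any_mismatch_chunked` and `chunk_rows` are hypothetical names, it handles numeric inputs only, and whether it actually beats a single `np.allclose` call would need measuring on the real workload.

```python
import numpy as np


def any_mismatch_chunked(a, b, chunk_rows=1024):
    """Return True as soon as one chunk of rows differs between a and b."""
    # Broadcast first so a single sample row can be compared to every data row.
    a, b = np.broadcast_arrays(np.atleast_1d(a), np.atleast_1d(b))
    for start in range(0, a.shape[0], chunk_rows):
        stop = start + chunk_rows
        if not np.allclose(a[start:stop], b[start:stop], equal_nan=True):
            return True  # early exit: later chunks are never examined
    return False


data = np.zeros((5000, 2))
sample = np.zeros((1, 2))
print(any_mismatch_chunked(sample, data))

data[10, 0] = 1.0  # a difference in the first chunk triggers the early exit
print(any_mismatch_chunked(sample, data))
```

The chunk size trades per-call overhead against how much work is wasted after the first mismatch; a smaller chunk exits sooner but makes more NumPy calls when the arrays are identical.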