Adding "count_include_pad" argument to flax.linen.pooling.avg_pool #2451
Codecov Report: All modified and coverable lines are covered by tests ✅

Coverage Diff        main    #2451      +/-
- Coverage         79.66%   78.83%   -0.84%
  Files                49       49
  Lines              4982     5070      +88
+ Hits               3969     3997      +28
- Misses             1013     1073      +60
Hey @dslisleedh, thanks for creating this PR!
To @cgarciae: Sorry, I tested it myself but didn't share the results in the PR. Here are the results. This is my first time creating a PR, so please tell me if I missed anything. Thank you.
Just noticed there aren't any tests for
I found the PoolTest class in ./tests/linen/test_linen.py. When I ran those tests with the edits from this PR, there were no problems.
@dslisleedh can you create one or more tests under
…iv_shape for it raises error when there's no batch dimension
To @cgarciae: I added some code to PoolTest, and here is the code I tested. When I tested with the previous code, an error occurred in avg_pool for inputs without a batch dimension, so I corrected the PR. Thanks for making sure I tested it. Below are the test results with the modified code. Thank you.
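One way to keep the averaging divisor independent of whether a batch axis is present is to count the non-padded elements per window over the spatial axes only and let broadcasting handle any leading batch axis. A minimal sketch in JAX, assuming the divisor is built by pooling an array of ones with `jax.lax.reduce_window`; the helper `valid_counts` is hypothetical and is not the code from this PR:

```python
import jax.numpy as jnp
from jax import lax


def valid_counts(spatial_shape, window_shape, padding="SAME"):
  """Hypothetical helper: number of non-padded elements in each window.

  Pools an array of ones over the spatial axes (stride 1) and returns an
  array of shape (*out_spatial, 1), which broadcasts against pooled outputs
  of shape (*out_spatial, C) as well as (N, *out_spatial, C).
  """
  ones = jnp.ones(tuple(spatial_shape) + (1,), jnp.float32)
  dims = tuple(window_shape) + (1,)  # size-1 window on the feature axis
  return lax.reduce_window(ones, 0.0, lax.add, dims, (1,) * len(dims), padding)


counts = valid_counts((2, 2), (2, 2), padding="SAME")  # values 4, 2, 2, 1

# Stand-ins for pooled window sums with and without a batch axis.
pooled_batched = jnp.ones((3, 2, 2, 4))  # (N, H, W, C)
pooled_unbatched = jnp.ones((2, 2, 4))   # (H, W, C)
print((pooled_batched / counts).shape)    # (3, 2, 2, 4)
print((pooled_unbatched / counts).shape)  # (2, 2, 4)
```

Because `counts` carries no batch axis, the same division works for both input ranks.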
Awesome @dslisleedh! Can you commit the changes to the tests?
…iv_shape for it raises error when there's no batch dimension
To @cgarciae Sure :)
@dslisleedh can you add this test? None of the other tests used

  @parameterized.parameters(
      {'count_include_pad': True},
      {'count_include_pad': False})
  def test_avg_pool_padding_same(self, count_include_pad):
    # 2x2 input; with "SAME" padding and a 2x2 window (stride 1), one row and
    # one column of zeros are padded on the bottom/right.
    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
    pool = lambda x: nn.avg_pool(
        x, (2, 2), padding="SAME", count_include_pad=count_include_pad)
    y = pool(x)
    if count_include_pad:
      # Every window sum is divided by the full window size (4).
      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
    else:
      # Each window sum is divided by its number of non-padded elements.
      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
    np.testing.assert_allclose(y, expected_y)
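The expected values in this test follow from where "SAME" padding goes: for a 2x2 window with stride 1 on a 2x2 input, `jax.lax` pads one row and one column of zeros on the high (bottom/right) side only. A small NumPy sketch, independent of Flax and using a hypothetical helper name, recomputes both expected arrays by hand:

```python
import numpy as np


def same_pad_avg_2x2(x2d, count_include_pad):
  """Hypothetical helper: 2x2, stride-1 average pool with "SAME" padding.

  For window 2 and stride 1 the total padding per axis is 1, and lax puts it
  all on the high side, so zeros are appended to the bottom and right.
  """
  h, w = x2d.shape
  padded = np.pad(x2d, ((0, 1), (0, 1)))                # zero padding, high side
  mask = np.pad(np.ones_like(x2d), ((0, 1), (0, 1)))    # 1 marks a real element
  out = np.empty((h, w))
  for i in range(h):
    for j in range(w):
      window_sum = padded[i:i + 2, j:j + 2].sum()
      n = 4 if count_include_pad else mask[i:i + 2, j:j + 2].sum()
      out[i, j] = window_sum / n
  return out


x = np.array([[1.0, 2.0], [3.0, 4.0]])
include = same_pad_avg_2x2(x, True)   # 10/4, 6/4, 7/4, 4/4
exclude = same_pad_avg_2x2(x, False)  # 10/4, 6/2, 7/2, 4/1
```

Only the top-left window sees no padding, which is why it is 10/4 in both cases.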
@cgarciae Oh, I forgot about that. Thank you. Here is the result of your code.
cgarciae left a comment:
This is great @dslisleedh, thanks for going through with this!
What does this PR do?
Currently, flax.linen.pooling.avg_pool divides each window sum by the full window size, so padded tokens are counted in the average. This PR adds a count_include_pad argument that controls whether padded tokens are included in that count.
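A minimal sketch of what the argument changes, written directly against `jax.lax.reduce_window` rather than Flax. It assumes a channels-last layout (..., *spatial, C) and computes the non-padded count by pooling an array of ones; `avg_pool_sketch` is a hypothetical name and this is not necessarily how the PR implements it:

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def avg_pool_sketch(x, window_shape, strides=None, padding="VALID",
                    count_include_pad=True):
  """Hypothetical average pool over the spatial axes of x: (..., *spatial, C)."""
  nd = len(window_shape)
  strides = tuple(strides or (1,) * nd)
  n_lead = x.ndim - nd - 1  # number of leading (batch) axes, possibly 0
  if n_lead < 0:
    raise ValueError("input needs at least len(window_shape) + 1 axes")
  # Window size 1 / stride 1 on the leading axes and on the feature axis.
  dims = (1,) * n_lead + tuple(window_shape) + (1,)
  strd = (1,) * n_lead + strides + (1,)
  window_sum = lax.reduce_window(x, 0.0, lax.add, dims, strd, padding)
  if count_include_pad:
    # Padded positions count: always divide by the full window size.
    return window_sum / np.prod(window_shape)
  # Count only real elements per window by pooling ones over the spatial
  # shape; the (*out_spatial, 1) result broadcasts over batch and features.
  ones = jnp.ones(x.shape[n_lead:-1] + (1,), x.dtype)
  counts = lax.reduce_window(ones, 0.0, lax.add, tuple(window_shape) + (1,),
                             strides + (1,), padding)
  return window_sum / counts


x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
y_incl = avg_pool_sketch(x, (2, 2), padding="SAME", count_include_pad=True)
y_excl = avg_pool_sketch(x, (2, 2), padding="SAME", count_include_pad=False)
# y_incl holds 10/4, 6/4, 7/4, 4/4; y_excl holds 10/4, 6/2, 7/2, 4/1,
# matching the expected values in the test above.
```

The count depends only on the spatial shape, so it is computed once per call and broadcast. With "VALID" padding no window touches padding, so both settings give the same result.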
Checklist
(No quality testing = no merge!)