Add RandomResizedCrop layer #21913
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances Keras's image preprocessing and neural network architecture capabilities.
Code Review
This pull request introduces adaptive average and max pooling layers (1D, 2D, and 3D) to Keras, allowing users to specify a target output size regardless of the input dimensions. The changes include adding a compute_adaptive_pooling_window_sizes utility function, implementing backend-specific adaptive pooling operations for JAX, NumPy, and TensorFlow, and adding placeholder NotImplementedError for OpenVINO. New BaseAdaptivePooling classes and concrete AdaptiveAveragePoolingND and AdaptiveMaxPoolingND layers are introduced, along with comprehensive unit tests covering shape transformations, data formats, configuration serialization, and numerical correctness. Additionally, a RandomResizedCrop preprocessing layer is added, which randomly crops and resizes images to a target size, with support for bounding box transformations and backend-specific handling for OpenVINO. Review comments suggest refactoring the JAX adaptive pooling implementations to reduce code duplication by using a single function that accepts the reduction operation and initial value, similar to the NumPy backend. Another comment requests using f-strings for error messages in the new pooling layers for consistency.
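For context, the target-size semantics these layers implement can be sketched as a plain-NumPy reference (a minimal sketch under the usual adaptive-pooling convention, where output bin `i` covers input positions `[floor(i*L/out), ceil((i+1)*L/out))`; the helper name here is hypothetical, not part of the PR):

```python
import numpy as np

def adaptive_avg_pool1d_reference(x, out_l):
    # Reference semantics for adaptive average pooling over the last axis:
    # output bin i averages input positions [floor(i*L/out), ceil((i+1)*L/out)).
    l = x.shape[-1]
    out = np.empty(x.shape[:-1] + (out_l,), dtype=x.dtype)
    for i in range(out_l):
        start = (i * l) // out_l
        end = -((-(i + 1) * l) // out_l)  # ceil division
        out[..., i] = x[..., start:end].mean(axis=-1)
    return out

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(adaptive_avg_pool1d_reference(x, 2))  # windows [1,2,3] and [3,4,5] -> [2.0, 4.0]
```

Note that adjacent windows may overlap when the input length is not divisible by the output size, which is why the backend implementations below compute two window sizes ("small" and "big") and gather the right one per output position.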
```python
def test_average_pooling3d_numerical(self):
    """Test AdaptiveAveragePooling3D numerical correctness."""
    inputs = np.array(
        [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],
        dtype="float32",
    )
    layer = layers.AdaptiveAveragePooling3D(
        output_size=2, data_format="channels_first"
    )
    outputs = layer(inputs)

    expected = outputs
    np.testing.assert_allclose(outputs, expected, atol=1e-4)

def test_max_pooling3d_numerical(self):
    """Test AdaptiveMaxPooling3D numerical correctness."""
    inputs = np.array(
        [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],
        dtype="float32",
    )
    layer = layers.AdaptiveMaxPooling3D(
        output_size=2, data_format="channels_first"
    )
    outputs = layer(inputs)

    expected = outputs
    np.testing.assert_allclose(outputs, expected, atol=1e-4)
```
The numerical correctness tests for AdaptiveAveragePooling3D and AdaptiveMaxPooling3D are currently placeholders and don't verify the correctness of the output. The expected output is set to the layer's output, which means the assertion np.testing.assert_allclose(outputs, expected, ...) will always pass.
To make these tests meaningful, please provide pre-computed, concrete expected values for the outputs, similar to how it's done for the 1D and 2D pooling tests.
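To illustrate, with `output_size=1` the expected values are easy to derive by hand. The sketch below uses plain NumPy as a stand-in for the layer call (the real test would compare against `layers.AdaptiveAveragePooling3D(...)` and `layers.AdaptiveMaxPooling3D(...)` outputs):

```python
import numpy as np

# The same 1x1x2x2x2 channels_first input from the test above, but pooled to
# output_size=1 so the expected outputs can be precomputed by hand.
inputs = np.array(
    [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],
    dtype="float32",
)

# Global average over the 2x2x2 volume: (1 + 2 + ... + 8) / 8 = 4.5
expected_avg = np.array([[[[[4.5]]]]], dtype="float32")
# Global max over the volume: 8.0
expected_max = np.array([[[[[8.0]]]]], dtype="float32")

# NumPy stands in for the layer call here; in the real test these arrays
# would be compared against the Keras layer outputs.
np.testing.assert_allclose(
    inputs.mean(axis=(2, 3, 4), keepdims=True), expected_avg, atol=1e-4
)
np.testing.assert_allclose(
    inputs.max(axis=(2, 3, 4), keepdims=True), expected_max, atol=1e-4
)
```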
```python
def _adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL -> NLC

    n, l, c = inputs.shape
    out_l = output_size[0]

    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)

    small_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
        )
        / small
    )

    big_pool = (
        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
        / big
    )

    combined = jnp.concatenate([small_pool, big_pool], axis=1)
    out = jnp.take(combined, gather, axis=1)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 2, 1))

    return out


def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))

    n, l, c = inputs.shape
    out_l = output_size[0]

    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)

    small_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
    )

    big_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
    )

    combined = jnp.concatenate([small_pool, big_pool], axis=1)
    out = jnp.take(combined, gather, axis=1)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 2, 1))

    return out


def _adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_h_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
        )
        / small_h
    )

    big_h_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
        )
        / big_h
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_w_pool = (
        lax.reduce_window(
            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
        )
        / small_w
    )

    big_w_pool = (
        lax.reduce_window(
            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
        )
        / big_w
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
    out = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 3, 1, 2))

    return out


def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_h_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
    )

    big_h_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_w_pool = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
    )

    big_w_pool = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
    out = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 3, 1, 2))

    return out


def _adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_d_pool = (
        lax.reduce_window(
            inputs,
            0.0,
            lax.add,
            (1, small_d, 1, 1, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / small_d
    )

    big_d_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
        )
        / big_d
    )

    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_h_pool = (
        lax.reduce_window(
            pooled_d,
            0.0,
            lax.add,
            (1, 1, small_h, 1, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / small_h
    )

    big_h_pool = (
        lax.reduce_window(
            pooled_d,
            0.0,
            lax.add,
            (1, 1, big_h, 1, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / big_h
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_w_pool = (
        lax.reduce_window(
            pooled_h,
            0.0,
            lax.add,
            (1, 1, 1, small_w, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / small_w
    )

    big_w_pool = (
        lax.reduce_window(
            pooled_h,
            0.0,
            lax.add,
            (1, 1, 1, big_w, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / big_w
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
    out = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 4, 1, 2, 3))

    return out


def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_d_pool = lax.reduce_window(
        inputs,
        -jnp.inf,
        lax.max,
        (1, small_d, 1, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    big_d_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )

    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_h_pool = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, small_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    big_h_pool = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, big_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_w_pool = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, small_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    big_w_pool = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, big_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
    out = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 4, 1, 2, 3))

    return out


def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
    dims = inputs.ndim - 2
    if dims == 1:
        return _adaptive_avg_pool1d(inputs, output_size, data_format)
    if dims == 2:
        return _adaptive_avg_pool2d(inputs, output_size, data_format)
    if dims == 3:
        return _adaptive_avg_pool3d(inputs, output_size, data_format)
    raise ValueError("adaptive_avg_pool supports only 1D/2D/3D inputs")


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
    dims = inputs.ndim - 2
    if dims == 1:
        return _adaptive_max_pool1d(inputs, output_size, data_format)
    if dims == 2:
        return _adaptive_max_pool2d(inputs, output_size, data_format)
    if dims == 3:
        return _adaptive_max_pool3d(inputs, output_size, data_format)
    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")
```
There is significant code duplication across the adaptive pooling implementations for average and max pooling (e.g., _adaptive_avg_pool1d and _adaptive_max_pool1d). The main differences are the reduction function (lax.add vs lax.max) and the initial value (0.0 vs -jnp.inf).
To improve maintainability, consider refactoring this by creating a single implementation for each dimension that takes the reduction function and initial value as arguments. This pattern is already used in the NumPy backend implementation in this PR and would make the JAX implementation cleaner and easier to maintain.
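A minimal sketch of the suggested shape, shown here in plain NumPy rather than JAX (function names are illustrative, not the PR's actual API): a single per-dimension implementation takes the reduction as a parameter, and the avg/max entry points become one-liners.

```python
import numpy as np

def _adaptive_pool1d(inputs, output_size, pool_fn):
    # Single implementation parameterized by the reduction, in the spirit of
    # passing (lax.add, 0.0) vs (lax.max, -inf) to one JAX function.
    # Assumes channels_last (N, L, C) input for brevity.
    n, l, c = inputs.shape
    out = np.empty((n, output_size, c), dtype=inputs.dtype)
    for i in range(output_size):
        start = (i * l) // output_size
        end = -((-(i + 1) * l) // output_size)  # ceil division
        out[:, i, :] = pool_fn(inputs[:, start:end, :], axis=1)
    return out

def adaptive_avg_pool1d(inputs, output_size):
    return _adaptive_pool1d(inputs, output_size, np.mean)

def adaptive_max_pool1d(inputs, output_size):
    return _adaptive_pool1d(inputs, output_size, np.max)

x = np.arange(1.0, 7.0).reshape(1, 6, 1)  # values 1..6
print(adaptive_avg_pool1d(x, 3).ravel())  # [1.5, 3.5, 5.5]
print(adaptive_max_pool1d(x, 3).ravel())  # [2.0, 4.0, 6.0]
```

In the JAX version the shared function would additionally take the `reduce_window` initial value and divide by the window size only for the average case.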
```python
raise TypeError(
    "`output_size` must be an integer. Received output_size={} "
    "of type {}".format(output_size, type(output_size))
)
```
For consistency with the other new pooling layer files in this pull request, please use an f-string for this error message instead of .format().
```diff
-raise TypeError(
-    "`output_size` must be an integer. Received output_size={} "
-    "of type {}".format(output_size, type(output_size))
-)
+raise TypeError(
+    f"`output_size` must be an integer. Received: {output_size} "
+    f"of type {type(output_size)}"
+)
```
Codecov Report: ❌ Patch coverage is

```
@@            Coverage Diff             @@
##           master   #21913      +/-   ##
==========================================
- Coverage   76.30%   70.49%    -5.81%
==========================================
  Files         580      589        +9
  Lines       60031    60811      +780
  Branches     9433     9537      +104
==========================================
- Hits        45805    42871     -2934
- Misses      11750    15514     +3764
+ Partials     2476     2426       -50
```
Add `keras.layers.RandomResizedCrop` layer that:

- samples a random crop area from `scale=(0.08, 1.0)` and aspect ratio from `ratio=(0.75, 1.33)`
- resizes the crop to the target size via `backend.image.resize`

Closes #21822
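The crop-parameter sampling this kind of layer performs can be sketched as follows. This is a hedged illustration following the torchvision-style convention (area fraction from `scale`, log-uniform aspect ratio from `ratio`, retry if the crop doesn't fit); the PR's exact sampling logic may differ, and `sample_crop` is a hypothetical helper, not part of the Keras API:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_crop(img_h, img_w, scale=(0.08, 1.0), ratio=(0.75, 1.33)):
    # Pick a target area fraction from `scale` and a log-uniform aspect
    # ratio from `ratio`, retrying a few times if the crop doesn't fit.
    area = img_h * img_w
    for _ in range(10):
        target_area = area * rng.uniform(*scale)
        aspect = np.exp(rng.uniform(np.log(ratio[0]), np.log(ratio[1])))
        w = int(round(np.sqrt(target_area * aspect)))
        h = int(round(np.sqrt(target_area / aspect)))
        if 0 < w <= img_w and 0 < h <= img_h:
            top = int(rng.integers(0, img_h - h + 1))
            left = int(rng.integers(0, img_w - w + 1))
            return top, left, h, w
    # Fallback: use the full image (the crop is then just a resize).
    return 0, 0, img_h, img_w

top, left, h, w = sample_crop(224, 224)
print(top, left, h, w)  # a crop box guaranteed to lie inside the image
```

The sampled box is then cut out of the image and resized to the layer's target size, which is where `backend.image.resize` comes in.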