
Conversation

@MalyalaKarthik66
Contributor

Add keras.layers.RandomResizedCrop layer that:

  • Samples random crop area from scale=(0.08, 1.0) and aspect ratio from ratio=(0.75, 1.33)
  • Performs center crop in inference mode (preserves target aspect ratio)
  • Supports all backends (TF/JAX/Torch/NumPy) via backend.image.resize
  • Includes bounding box and segmentation mask transformations
  • Has 14 comprehensive tests passing on all backends
  • Gracefully handles OpenVINO (raises NotImplementedError until resize support is added)

Closes #21822
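
A minimal usage sketch (illustrative only: `height`/`width` are assumed to follow the usual Keras preprocessing-layer convention, as in `RandomCrop`; the `scale` and `ratio` defaults come from the description above):

import numpy as np
import keras

# Hypothetical constructor arguments: `height`/`width` follow the RandomCrop
# convention and are an assumption; `scale`/`ratio` defaults are from the PR.
layer = keras.layers.RandomResizedCrop(
    height=224,
    width=224,
    scale=(0.08, 1.0),
    ratio=(0.75, 1.33),
)

images = np.random.uniform(0, 255, size=(8, 320, 256, 3)).astype("float32")

augmented = layer(images, training=True)  # random area/aspect-ratio crop + resize
centered = layer(images, training=False)  # deterministic center crop at inference
print(augmented.shape)  # (8, 224, 224, 3)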

@gemini-code-assist
Contributor

Summary of Changes

Hello @MalyalaKarthik66, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances Keras's image preprocessing and pooling capabilities by introducing a flexible RandomResizedCrop layer for data augmentation and a comprehensive set of adaptive pooling layers. These additions give developers more powerful tools for handling variable-sized inputs and standardizing feature map dimensions, which are crucial for many modern computer vision models. The implementation is backend-agnostic, ensuring seamless integration across the deep learning frameworks supported by Keras.

Highlights

  • New Layer: RandomResizedCrop: Introduced a new keras.layers.RandomResizedCrop layer for image augmentation, which samples random crop areas and aspect ratios during training, and performs a deterministic center crop during inference. It supports bounding box and segmentation mask transformations.
  • New Layers: Adaptive Pooling: Added a suite of adaptive pooling layers: AdaptiveAveragePooling1D, AdaptiveAveragePooling2D, AdaptiveAveragePooling3D, AdaptiveMaxPooling1D, AdaptiveMaxPooling2D, and AdaptiveMaxPooling3D. These layers allow pooling to a target output size regardless of the input dimensions (a usage sketch follows this list).
  • Backend Agnostic Implementation: The new RandomResizedCrop and adaptive pooling layers are implemented to support multiple backends including TensorFlow, JAX, PyTorch, and NumPy, ensuring broad compatibility within the Keras ecosystem.
  • OpenVINO Handling: The OpenVINO backend explicitly raises a NotImplementedError for RandomResizedCrop and adaptive pooling operations, indicating that resize and adaptive pooling support is not yet available for this backend.
  • Comprehensive Testing: The new layers come with extensive tests (14 for RandomResizedCrop and separate test files for each adaptive pooling dimension) to ensure correctness and cross-backend compatibility.
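
As an illustration of the adaptive pooling layers listed above (a sketch assuming the default channels_last data format and the `output_size` argument used elsewhere in this PR):

import numpy as np
import keras

x = np.random.rand(2, 17, 23, 8).astype("float32")  # odd spatial dims on purpose

# Pools any input H/W down to a fixed 7x7 grid, keeping batch and channels.
pool = keras.layers.AdaptiveAveragePooling2D(output_size=(7, 7))
y = pool(x)
print(y.shape)  # (2, 7, 7, 8)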

@gemini-code-assist bot left a comment


Code Review

This pull request introduces adaptive average and max pooling layers (1D, 2D, and 3D) to Keras, allowing users to specify a target output size regardless of the input dimensions. The changes include a compute_adaptive_pooling_window_sizes utility function, backend-specific adaptive pooling operations for JAX, NumPy, and TensorFlow, and a placeholder NotImplementedError for OpenVINO. New BaseAdaptivePooling classes and concrete AdaptiveAveragePoolingND and AdaptiveMaxPoolingND layers are introduced, along with comprehensive unit tests covering shape transformations, data formats, configuration serialization, and numerical correctness.

Additionally, a RandomResizedCrop preprocessing layer is added, which randomly crops and resizes images to a target size, with support for bounding box transformations and backend-specific handling for OpenVINO.

Review comments suggest refactoring the JAX adaptive pooling implementations to reduce code duplication by using a single function that accepts the reduction operation and initial value, similar to the NumPy backend. Another comment requests using f-strings for error messages in the new pooling layers for consistency.

Comment on lines +132 to +158
def test_average_pooling3d_numerical(self):
    """Test AdaptiveAveragePooling3D numerical correctness."""
    inputs = np.array(
        [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],
        dtype="float32",
    )
    layer = layers.AdaptiveAveragePooling3D(
        output_size=2, data_format="channels_first"
    )
    outputs = layer(inputs)

    expected = outputs
    np.testing.assert_allclose(outputs, expected, atol=1e-4)

def test_max_pooling3d_numerical(self):
    """Test AdaptiveMaxPooling3D numerical correctness."""
    inputs = np.array(
        [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],
        dtype="float32",
    )
    layer = layers.AdaptiveMaxPooling3D(
        output_size=2, data_format="channels_first"
    )
    outputs = layer(inputs)

    expected = outputs
    np.testing.assert_allclose(outputs, expected, atol=1e-4)

Severity: high

The numerical correctness tests for AdaptiveAveragePooling3D and AdaptiveMaxPooling3D are currently placeholders and don't verify the correctness of the output. The expected output is set to the layer's output, which means the assertion np.testing.assert_allclose(outputs, expected, ...) will always pass.

To make these tests meaningful, please provide pre-computed, concrete expected values for the outputs, similar to how it's done for the 1D and 2D pooling tests.
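
For example, a sketch of a meaningful variant (expected value computed by hand; using output_size=1 keeps the arithmetic obvious, since the eight inputs average to 4.5):

def test_average_pooling3d_numerical(self):
    """Test AdaptiveAveragePooling3D numerical correctness."""
    inputs = np.array(
        [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],
        dtype="float32",
    )
    layer = layers.AdaptiveAveragePooling3D(
        output_size=1, data_format="channels_first"
    )
    outputs = layer(inputs)

    # Global average over the eight values 1..8 is 4.5 (and max would be 8.0).
    expected = np.array([[[[[4.5]]]]], dtype="float32")
    np.testing.assert_allclose(outputs, expected, atol=1e-4)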

Comment on lines +1497 to +1864
def _adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC

    n, l, c = inputs.shape
    out_l = output_size[0]

    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)

    small_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
        )
        / small
    )

    big_pool = (
        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
        / big
    )

    combined = jnp.concatenate([small_pool, big_pool], axis=1)
    out = jnp.take(combined, gather, axis=1)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 2, 1))

    return out


def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))

    n, l, c = inputs.shape
    out_l = output_size[0]

    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)

    small_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
    )

    big_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
    )

    combined = jnp.concatenate([small_pool, big_pool], axis=1)
    out = jnp.take(combined, gather, axis=1)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 2, 1))

    return out


def _adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_h_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
        )
        / small_h
    )

    big_h_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
        )
        / big_h
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_w_pool = (
        lax.reduce_window(
            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
        )
        / small_w
    )

    big_w_pool = (
        lax.reduce_window(
            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
        )
        / big_w
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
    out = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 3, 1, 2))

    return out


def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_h_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
    )

    big_h_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_w_pool = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
    )

    big_w_pool = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
    out = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 3, 1, 2))

    return out


def _adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_d_pool = (
        lax.reduce_window(
            inputs,
            0.0,
            lax.add,
            (1, small_d, 1, 1, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / small_d
    )

    big_d_pool = (
        lax.reduce_window(
            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
        )
        / big_d
    )

    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_h_pool = (
        lax.reduce_window(
            pooled_d,
            0.0,
            lax.add,
            (1, 1, small_h, 1, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / small_h
    )

    big_h_pool = (
        lax.reduce_window(
            pooled_d,
            0.0,
            lax.add,
            (1, 1, big_h, 1, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / big_h
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_w_pool = (
        lax.reduce_window(
            pooled_h,
            0.0,
            lax.add,
            (1, 1, 1, small_w, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / small_w
    )

    big_w_pool = (
        lax.reduce_window(
            pooled_h,
            0.0,
            lax.add,
            (1, 1, 1, big_w, 1),
            (1, 1, 1, 1, 1),
            "valid",
        )
        / big_w
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
    out = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 4, 1, 2, 3))

    return out


def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)

    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)

    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)

    small_d_pool = lax.reduce_window(
        inputs,
        -jnp.inf,
        lax.max,
        (1, small_d, 1, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    big_d_pool = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )

    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_h_pool = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, small_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    big_h_pool = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, big_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_w_pool = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, small_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    big_w_pool = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, big_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
    out = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 4, 1, 2, 3))

    return out


def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
    dims = inputs.ndim - 2
    if dims == 1:
        return _adaptive_avg_pool1d(inputs, output_size, data_format)
    if dims == 2:
        return _adaptive_avg_pool2d(inputs, output_size, data_format)
    if dims == 3:
        return _adaptive_avg_pool3d(inputs, output_size, data_format)
    raise ValueError("adaptive_avg_pool supports only 1D/2D/3D inputs")


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
    dims = inputs.ndim - 2
    if dims == 1:
        return _adaptive_max_pool1d(inputs, output_size, data_format)
    if dims == 2:
        return _adaptive_max_pool2d(inputs, output_size, data_format)
    if dims == 3:
        return _adaptive_max_pool3d(inputs, output_size, data_format)
    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")

Severity: medium

There is significant code duplication across the adaptive pooling implementations for average and max pooling (e.g., _adaptive_avg_pool1d and _adaptive_max_pool1d). The main differences are the reduction function (lax.add vs lax.max) and the initial value (0.0 vs -jnp.inf).

To improve maintainability, consider refactoring this by creating a single implementation for each dimension that takes the reduction function and initial value as arguments. This pattern is already used in the NumPy backend implementation in this PR and would make the JAX implementation cleaner and easier to maintain.
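
For illustration, a minimal sketch of that refactor for the 1D case (helper name hypothetical; reuses the PR's compute_adaptive_pooling_window_sizes and _compute_adaptive_pooling_gather_indices utilities):

import jax.numpy as jnp
from jax import lax

def _adaptive_pool1d(inputs, output_size, data_format, reduce_fn, init_val, average):
    if isinstance(output_size, int):
        output_size = (output_size,)
    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC

    l, out_l = inputs.shape[1], output_size[0]
    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)

    def pool(window):
        # Reduce over every valid window of the given size; normalize only
        # for average pooling.
        out = lax.reduce_window(
            inputs, init_val, reduce_fn, (1, window, 1), (1, 1, 1), "valid"
        )
        return out / window if average else out

    combined = jnp.concatenate([pool(small), pool(big)], axis=1)
    out = jnp.take(combined, gather, axis=1)
    if data_format == "channels_first":
        out = jnp.transpose(out, (0, 2, 1))
    return out


def _adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    return _adaptive_pool1d(inputs, output_size, data_format, lax.add, 0.0, True)


def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    return _adaptive_pool1d(
        inputs, output_size, data_format, lax.max, -jnp.inf, False
    )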

Comment on lines +51 to +54
raise TypeError(
    "`output_size` must be an integer. Received output_size={} "
    "of type {}".format(output_size, type(output_size))
)

Severity: medium

For consistency with the other new pooling layer files in this pull request, please use an f-string for this error message instead of .format().

Suggested change
-    raise TypeError(
-        "`output_size` must be an integer. Received output_size={} "
-        "of type {}".format(output_size, type(output_size))
-    )
+    raise TypeError(
+        f"`output_size` must be an integer. Received: {output_size} "
+        f"of type {type(output_size)}"
+    )

@MalyalaKarthik66 deleted the feature/random-resized-crop branch on December 10, 2025 at 09:54
@codecov-commenter commented Dec 10, 2025

Codecov Report

❌ Patch coverage is 60.12821% with 311 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.49%. Comparing base (46813a3) to head (dcb7e5f).

Files with missing lines                               | Patch % | Lines
keras/src/backend/tensorflow/nn.py                     | 5.34%   | 177 Missing ⚠️
...cessing/image_preprocessing/random_resized_crop.py | 66.90%  | 42 Missing and 5 partials ⚠️
keras/src/backend/torch/nn.py                          | 4.44%   | 43 Missing ⚠️
keras/src/backend/jax/nn.py                            | 90.12%  | 8 Missing and 8 partials ⚠️
keras/src/backend/numpy/nn.py                          | 91.93%  | 5 Missing and 5 partials ⚠️
keras/src/ops/nn.py                                    | 63.63%  | 2 Missing and 2 partials ⚠️
...s/src/layers/pooling/adaptive_average_pooling1d.py | 77.77%  | 1 Missing and 1 partial ⚠️
...s/src/layers/pooling/adaptive_average_pooling2d.py | 81.81%  | 1 Missing and 1 partial ⚠️
...s/src/layers/pooling/adaptive_average_pooling3d.py | 81.81%  | 1 Missing and 1 partial ⚠️
keras/src/layers/pooling/adaptive_max_pooling1d.py     | 77.77%  | 1 Missing and 1 partial ⚠️
... and 3 more

❗ There is a different number of reports uploaded between BASE (46813a3) and HEAD (dcb7e5f): HEAD has 4 fewer uploads than BASE.

Flag           | BASE (46813a3) | HEAD (dcb7e5f)
keras          | 5              | 3
keras-openvino | 2              | 1
keras-torch    | 1              | 0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21913      +/-   ##
==========================================
- Coverage   76.30%   70.49%   -5.81%     
==========================================
  Files         580      589       +9     
  Lines       60031    60811     +780     
  Branches     9433     9537     +104     
==========================================
- Hits        45805    42871    -2934     
- Misses      11750    15514    +3764     
+ Partials     2476     2426      -50     
Flag           | Coverage Δ
keras          | 70.47% <60.12%> (-5.70%) ⬇️
keras-jax      | 61.91% <45.25%> (-0.22%) ⬇️
keras-numpy    | 57.13% <42.43%> (-0.19%) ⬇️
keras-openvino | 34.02% <13.07%> (-0.28%) ⬇️
keras-torch    | ?

Flags with carried forward coverage won't be shown.

