
Fix SpatialSoftmax input shape #150

Merged
alexander-soare merged 2 commits into huggingface:main from alexander-soare:fix_spatial_softmax_input_shape on May 8, 2024


@alexander-soare

Contributor

What this does

The input shape passed to SpatialSoftmax should be derived from the crop shape, not from the policy's full input image shape. The old code worked only by coincidence: the backbone's downsampling happened to produce a feature map of the same spatial size for the cropped and uncropped images, but that would not hold for larger input images.
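To see how such a coincidence can arise, one can track the spatial size through a stack of strided layers. The sketch below assumes a ResNet-18-like backbone (a stride-2 7x7 conv, a stride-2 max pool, then three stride-2 stages; stride-1 layers keep the size unchanged) and uses illustrative image and crop sizes that the PR itself does not state.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed ResNet-18-like downsampling layers as (kernel, stride, padding).
# This layout is an illustrative assumption, not taken from the PR.
RESNET18_DOWNSAMPLING = [
    (7, 2, 3),  # conv1
    (3, 2, 1),  # max pool
    (3, 2, 1),  # first conv of layer2
    (3, 2, 1),  # first conv of layer3
    (3, 2, 1),  # first conv of layer4
]


def feature_map_hw(height: int, width: int, layers=RESNET18_DOWNSAMPLING) -> tuple:
    """Spatial (H, W) of the final feature map for an input of size (height, width)."""
    for kernel, stride, padding in layers:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
    return height, width


print(feature_map_hw(96, 96))    # (3, 3)
print(feature_map_hw(84, 84))    # (3, 3): same as the uncropped image
print(feature_map_hw(480, 640))  # (15, 20)
print(feature_map_hw(440, 560))  # (14, 18): no longer matches
```

With a total stride of 32, small images and their crops can round to the same feature-map size, while larger images and crops generally do not, so probing the backbone with the full image shape would give SpatialSoftmax the wrong input shape.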

@Cadene left a comment

Contributor


Left a comment that you could address before merging.

Comment on lines 319 to 323
```diff
 feat_map_shape = tuple(
-    self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
+    self.backbone(
+        torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
+    ).shape[1:]
 )
```

Contributor


Thanks! Just some details on this line; it's difficult to understand.

Would it be better to break it into a few lines?

  • create the input
  • run the forward pass -> get the output
  • get the shape

Also, could you explain why we take the first dimension, [0], of

config.input_shapes["observation.image"][0]

And why do we use size=? It's quite unusual.

Finally, I wonder why we don't create our torch.zeros tensor with the device argument.
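On the [0] question: the fixed line builds a dummy input of shape (batch, channels, *crop). The sketch below assumes, as the diff implies, that config.input_shapes["observation.image"] is stored channel-first as (C, H, W) and that config.crop_shape is (H, W); the concrete numbers are hypothetical and chosen only for illustration.

```python
# Hypothetical config values, for illustration only.
input_shapes = {"observation.image": (3, 96, 96)}  # (C, H, W), channel-first
crop_shape = (84, 84)                               # (H, W) after cropping

image_shape = input_shapes["observation.image"]
num_channels = image_shape[0]  # index 0 of a (C, H, W) shape is the channel count

# New probe shape: batch of 1, original channel count, cropped spatial size.
dummy_input_shape = (1, num_channels, *crop_shape)
print(dummy_input_shape)  # (1, 3, 84, 84)

# Old probe shape: the full uncropped image shape.
old_dummy_input_shape = (1, *image_shape)
print(old_dummy_input_shape)  # (1, 3, 96, 96)
```

Only the channel count carries over from the image shape; the spatial dimensions are replaced by the crop.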

Contributor Author


I broke it down. To answer your questions:

And why do we use size=? It's quite unusual.

I just like being explicit with this argument because in NumPy, size is sometimes not the first positional argument. It's a habit of mine.

Finally, I wonder why we don't create our torch.zeros tensor with the device argument.

The device has not been specified at this point, so we default to the CPU.
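For reference, here is one way the shape probe could be split into the three steps the reviewer listed. This is a sketch, not the code that was merged: the backbone is a hypothetical two-layer CNN standing in for the real one, the shapes are illustrative assumptions, and the torch.no_grad() wrapper is an added choice for a shape-only forward pass. It keeps the explicit size= keyword and the CPU default discussed above.

```python
import torch
from torch import nn

# Hypothetical stand-ins for config.input_shapes["observation.image"] and config.crop_shape.
IMAGE_SHAPE = (3, 96, 96)  # (C, H, W)
CROP_SHAPE = (84, 84)      # (H, W)

# Hypothetical backbone: any module mapping (B, C, H, W) -> (B, C', H', W') would do.
backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
)


def infer_feature_map_shape(backbone: nn.Module, image_shape, crop_shape) -> tuple:
    """Return the backbone's (C', H', W') output shape for a cropped image."""
    num_channels = image_shape[0]  # (C, H, W): index 0 is the channel count
    # 1. Create the input: a batch of one zero image with the cropped spatial size.
    #    No device argument, so the tensor is on the CPU (the backbone must be too).
    dummy_input = torch.zeros(size=(1, num_channels, *crop_shape))
    # 2. Run the forward pass; no gradients are needed to read off a shape.
    with torch.no_grad():
        feature_map = backbone(dummy_input)
    # 3. Get the shape, dropping the batch dimension.
    return tuple(feature_map.shape[1:])


feat_map_shape = infer_feature_map_shape(backbone, IMAGE_SHAPE, CROP_SHAPE)
print(feat_map_shape)  # (32, 21, 21)
```

Probing the same backbone with the uncropped (96, 96) size instead would report (32, 24, 24), which is the kind of mismatch this PR fixes.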

@alexander-soare alexander-soare merged commit f5de57b into huggingface:main May 8, 2024
@alexander-soare alexander-soare deleted the fix_spatial_softmax_input_shape branch May 8, 2024 13:57
menhguin pushed a commit to menhguin/lerobot that referenced this pull request Feb 9, 2025
Kalcy-U referenced this pull request in Kalcy-U/lerobot May 13, 2025
ZoreAnuj pushed a commit to luckyrobots/lerobot that referenced this pull request Jul 29, 2025


2 participants