Fix SpatialSoftmax input shape #150
Cadene left a comment
Left a comment that you could address before merging.
```diff
 feat_map_shape = tuple(
-    self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
+    self.backbone(
+        torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
+    ).shape[1:]
 )
```
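To make the change concrete, the sketch below builds the dummy-input shape that each version of the line passes to the backbone. It assumes `input_shapes["observation.image"]` is stored channel-first as (C, H, W), which the new line implies because `crop_shape` supplies H and W; under that assumption, index `[0]` is the channel count. The values (3, 96, 96) and (84, 84) are illustrative assumptions, not taken from the PR.

```python
# Illustrative values (assumed, not from the PR): images are channel-first (C, H, W)
# and the policy crops them to crop_shape = (H, W) before they reach the backbone.
input_shapes = {"observation.image": (3, 96, 96)}
crop_shape = (84, 84)

# Old dummy input: a batch of 1 at the full, uncropped image size.
old_dummy_shape = (1, *input_shapes["observation.image"])

# New dummy input: keep the channel count (index 0 of the (C, H, W) shape),
# but use the crop's height and width, since that is what the backbone actually sees.
num_channels = input_shapes["observation.image"][0]
new_dummy_shape = (1, num_channels, *crop_shape)

print(old_dummy_shape)  # (1, 3, 96, 96)
print(new_dummy_shape)  # (1, 3, 84, 84)
```

Only the spatial dimensions change; the channel dimension is carried over from the configured input shape.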
Thanks! Just some details on this line; it's difficult to understand.
Could it be broken into a few lines?
- create the input
- run the forward pass -> get the output
- get the shape
Also, could you explain why we take the first dimension ([0]) of config.input_shapes["observation.image"]?
And why do we use size=? It's quite unusual.
Finally, I wonder why we don't create our torch.zeros tensor with the device argument.
I broke it down. To answer your questions:
And why do we use size=? It's quite unusual.
I just like being explicit with this argument because in NumPy it isn't always the first positional argument. Just a habit of mine.
Finally, I wonder why we don't create our torch.zeros tensor with the device argument.
The device is not specified at this point, so the tensor defaults to CPU.
What this does
The input shape for SpatialSoftmax should be derived from the crop shape, not from the policy's input image shape. This only worked before by lucky coincidence: after downsampling, the backbone produced a feature map of the same size for the cropped and uncropped images, but that would not hold for larger input images.
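The "lucky coincidence" can be checked with the standard output-size formula for convolution and pooling layers. The sketch below assumes a torchvision-style ResNet-18 backbone truncated before global pooling, a 96x96 input image and an 84x84 crop; these are assumptions for illustration, and the 240/216 pair is a hypothetical larger example. Stride-1 3x3 convolutions with padding 1 preserve spatial size, so only the stride-2 layers are listed.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one spatial axis of a conv/pool layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed backbone: torchvision-style ResNet-18 without its pooling/fc head.
# Only the stride-2 layers change the spatial size; each entry is (kernel, stride, padding).
RESNET18_DOWNSAMPLING = [
    (7, 2, 3),  # conv1
    (3, 2, 1),  # maxpool
    (3, 2, 1),  # first 3x3 conv of layer2
    (3, 2, 1),  # first 3x3 conv of layer3
    (3, 2, 1),  # first 3x3 conv of layer4
]


def feature_map_size(size: int) -> int:
    """Spatial size of the final feature map for a square input of side `size`."""
    for kernel, stride, padding in RESNET18_DOWNSAMPLING:
        size = conv_out(size, kernel, stride, padding)
    return size


# Assumed 96x96 image vs 84x84 crop: both collapse to 3x3, which hid the bug.
print(feature_map_size(96), feature_map_size(84))   # 3 3
# Hypothetical larger image and crop: the sizes now differ.
print(feature_map_size(240), feature_map_size(216))  # 8 7
```

Under these assumptions, the old code happened to produce the right feature-map shape for the small image, while for the larger pair SpatialSoftmax would be built for an 8x8 map yet receive 7x7 features at run time.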