make ProvenanceTensor behave more like a Tensor (closes #3218) #3220
Conversation
The data is now stored in the Tensor itself instead of an attribute. This fixes `torch.as_tensor` returning empty tensors when called with a ProvenanceTensor and a device as arguments.
This is important when using Tensors as keys in a dict, e.g. in the Pyro param store.
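A quick illustration of why this matters for dict keys (a minimal sketch using plain `torch`, independent of Pyro): tensors hash by object identity, not by value, so a store can only find a parameter if the very same tensor object comes back from every operation:

```python
import torch

# Tensors hash by object identity, so a dict like the Pyro param store
# only finds a key when the exact same object is reused.
key = torch.zeros(2)
store = {key: "my_param"}

assert store[key] == "my_param"     # same object: lookup succeeds
assert torch.zeros(2) not in store  # equal values, different object: miss
```

If an operation silently returned a fresh (empty) tensor instead of preserving the object's data, lookups like the one above would start missing.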
fritzo
left a comment
Thanks for this subtle fix! This looks good to me (after one minor comment), but I'm unsure how this will interact with other subclasses of tensor.
@ordabayevy could you also take a look as you've thought about this before?
The diff context under discussion:

```python
return data
return super().__new__(cls)
ret = data.view(data.shape)
ret._t = data.view(data.shape)  # this makes sure that detach_provenance always
```
Could this line be simplified to `ret._t = data`, or would that break something?
Right, thanks. Took me about four tries to get all the tests to pass, this was still a remnant of an earlier attempt.
Would you be able to add a regression test and decorate it with `@requires_cuda`? It won't run on CI, but it might help future maintainers.
fritzo
left a comment
LGTM, thanks for adding a test.
I'll leave this up a couple days before merging in case @ordabayevy has any comments.
Thanks for holding it up. I'll have a look at this later tonight.
@ilia-kats thanks for fixing this! What about trying to use `torch.Tensor._make_subclass`?

```diff
 class ProvenanceTensor(torch.Tensor):
     def __new__(cls, data, provenance=frozenset()):
         assert not isinstance(data, ProvenanceTensor)
         if not provenance:
             return data
-        return super().__new__(cls)
+        return torch.Tensor._make_subclass(cls, data)
```

And I believe we can remove the instance check from `__init__`:

```diff
     def __init__(self, data, provenance=frozenset()):
         assert isinstance(provenance, frozenset)
-        if isinstance(data, ProvenanceTensor):
-            provenance |= data._provenance
-            data = data._t
         self._t = data
         self._provenance = provenance
```
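The difference between the two constructions can be seen with a toy `torch.Tensor` subclass (a hypothetical `MyTensor`, not Pyro code): `super().__new__(cls)` allocates a fresh empty tensor, while `_make_subclass` re-wraps the existing data:

```python
import torch


class MyTensor(torch.Tensor):
    """Toy subclass used only to contrast the two constructions."""


data = torch.arange(3.0)

# super().__new__(cls) allocates a brand-new, empty tensor: data is lost.
empty = torch.Tensor.__new__(MyTensor)
assert empty.numel() == 0

# _make_subclass wraps the existing storage as the subclass, so the data
# lives in the tensor object itself.
wrapped = torch.Tensor._make_subclass(MyTensor, data)
assert isinstance(wrapped, MyTensor)
assert torch.equal(wrapped, data)
```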
also remove unnecessary check in __init__
@ordabayevy Thanks for the comment. I actually …
@ordabayevy ready to merge? I'll release today or tomorrow and will include this PR in the release.
Yeah, lgtm.
The data is now stored in the Tensor itself instead of an attribute. This fixes `torch.as_tensor` returning empty tensors when called with a ProvenanceTensor and a device as arguments.
This is super hacky, but I couldn't come up with a cleaner way. Note that this is the only way to use `pyro.infer.inspect.get_dependencies` when training on GPUs (I'm using it in a custom Messenger guide), since the `log_prob` function of some distributions (for example Gamma) calls `torch.as_tensor`.