Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make ProvenanceTensor behave more like a Tensor (closes #3218) #3220

Merged
merged 8 commits into from
May 28, 2023
Merged

make ProvenanceTensor behave more like a Tensor (closes #3218) #3220

merged 8 commits into from
May 28, 2023

Conversation

ilia-kats
Copy link
Contributor

@ilia-kats ilia-kats commented May 25, 2023

The data is now stored in the Tensor itself instead of an attribute. This fixes torch.to_tensor returning empty tensors when called with a ProvenanceTensor and and a device as arguments.

This is super hacky, but I couldn't come up with a cleaner way. Note that this is the only way to use pyro.infer.inspect.get_dependencies when training on GPUs (I'm using it in a custom Messenger guide), since the log_prob function of some distributions (for example Gamma) calls torch.to_tensor.

ilia-kats added 4 commits May 25, 2023 11:56
the data is now stored in the Tensor itself instead of an attribute.
This fixes torch.to_tensor returning empty tensors when called with a
ProvenanceTensor and and a device as arguments
this is important when using Tensors as keys in a dict, e.g.
the Pyro param store
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this subtle fix! This looks good to me (after one minor comment), but I'm unsure how this will interact with other subclasses of tensor.

@ordabayevy could you also take a look as you've thought about this before?

@@ -46,15 +46,21 @@ def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs):
assert not isinstance(data, ProvenanceTensor)
if not provenance:
return data
return super().__new__(cls)
ret = data.view(data.shape)
ret._t = data.view(data.shape) # this makes sure that detach_provenance always
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this line be simplified to

ret._t = data

or would that break something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, thanks. Took me about four tries to get all the tests to pass, this was still a remnant of an earlier attempt.

@fritzo
Copy link
Member

fritzo commented May 27, 2023

Would you be able to add a regression test and decorate it with @requires_cuda? It won't run on CI, but it might help future maintainers of ProvenanceTensor preserve your intent in the face of future changes.

fritzo
fritzo previously approved these changes May 27, 2023
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for adding a test.

I'll leave this up a couple days before merging in case @ordabayevy has any comments.

@fritzo fritzo mentioned this pull request May 27, 2023
@ordabayevy
Copy link
Member

Thanks for holding it up. I'll have a look at this later tonight.

@ordabayevy
Copy link
Member

@ilia-kats thanks for fixing this!

What about trying to use torch.Tensor._make_subclass as is used by torch.nn.Parameter:

class ProvenanceTensor(torch.Tensor):
         assert not isinstance(data, ProvenanceTensor)
         if not provenance:
             return data
-        return super().__new__(cls)
+        return torch.Tensor._make_subclass(cls, data)

And I believe we can remove instance check from the __init__ method since data being a ProvenanceTensor is prohibited in the __new__ method anyways? WDYT @fritzo @ilia-kats ?

     def __init__(self, data, provenance=frozenset()):
         assert isinstance(provenance, frozenset)
-        if isinstance(data, ProvenanceTensor):
-            provenance |= data._provenance
-            data = data._t
         self._t = data
         self._provenance = provenance

also remove unnecessary check in __init__
@ilia-kats
Copy link
Contributor Author

@ordabayevy Thanks for the comment. I actually Tensor.as_subclass in the public API, which also seems to work.

@fritzo
Copy link
Member

fritzo commented May 28, 2023

@ordabayevy ready to merge? I'll release today or tomorrow and will include this PR in the release

@ordabayevy
Copy link
Member

Yeah, lgtm.

@ordabayevy ordabayevy merged commit 831c463 into pyro-ppl:dev May 28, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants