-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow pickle autoguide #1169
Allow pickle autoguide #1169
Conversation
@@ -106,6 +106,16 @@ def inverse_shape(self, shape): | |||
""" | |||
return shape | |||
|
|||
# Allow for pickle serialization of transforms. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This follows Pyro patch for PyTorch transforms.
@lumip This supposes to fix your pickling issue. Could you double-check? Thanks! |
@@ -104,6 +104,11 @@ def _create_plates(self, *args, **kwargs): | |||
) | |||
return self.plates | |||
|
|||
def __getstate__(self): | |||
state = self.__dict__.copy() | |||
state.pop("plates", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
plates
is generated during execution, so we don't need to cache. Actually, we typically call the guide under jit
, so plates
will hold abstract values which cannot be pickled.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test looks good 👍 It would be nice to be able to unpickle and use a guide without the model (via guide()
), but I guess this is a separate issue.
Thanks for reviewing, @fritzo! You can use
or
We need it because we need Using
cc @lumip |
return_sites=["param", "x"], | ||
) | ||
samples = predictive(random.PRNGKey(1), None, 1) | ||
assert set(samples.keys()) == {"param", "x"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe better to assert that pickled_guide
returns the same samples as guide
instead of just verifying that all sites are present?
This seems to work well in some simple tests I ran. Unfortunately I can't right now verify that in our more elaborate settings but I may get around to that tomorrow. Unfortunately I'm not all too familiar with the code that was changed, so I cannot review those changes in detail. Edit: Have now tried it in a more complicated environment and there I ran into the issue that we have some additional decorators around our model which prevent me from pickling the guard (I get a |
Thanks for reviewing, @lumip! I think we can just remove model from a pickled guide (if the guide already has prototype trace). Let me address that. For now, you can just set guide.model = None or delete that attribute before pickling. |
@lumip Have been thinking about this for a while, it seems that it's better to let users decide to pickle guide with model or guide without model. If we remove |
Resolves #1160
TODO: