
Allow pickle autoguide #1169

Merged · 3 commits · Oct 8, 2021
Conversation

@fehiepsi (Member) commented Sep 25, 2021:

Resolves #1160

TODO:

  • allow pickling unpack_latent
  • allow pickling postprocess_fn
  • add tests (a rough round-trip sketch follows below)
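As an illustrative round-trip sketch of what these items enable (a toy model with an AutoNormal guide; names and hyperparameters here are made up, not taken from the actual test suite):

```python
import pickle

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal


def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


data = jnp.ones(10)
guide = AutoNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)

# Round-trip the trained guide through pickle and draw posterior samples from it.
unpickled_guide = pickle.loads(pickle.dumps(guide))
samples = unpickled_guide.sample_posterior(
    random.PRNGKey(1), svi_result.params, sample_shape=(5,)
)
assert samples["loc"].shape == (5,)
```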

@fehiepsi fehiepsi added the WIP label Sep 25, 2021
@@ -106,6 +106,16 @@ def inverse_shape(self, shape):
"""
return shape

# Allow for pickle serialization of transforms.
@fehiepsi (Member, Author) commented Sep 25, 2021:
This follows Pyro's patch for PyTorch transforms.
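For context, a minimal sketch of the general idea (an illustrative guess at the mechanics, not the literal patch): a transform may cache its inverse behind a weakref, which pickle cannot handle, so `__getstate__` swaps it for something serializable.

```python
import weakref


class PicklableTransformSketch:
    """Toy stand-in for a transform class; illustrative only."""

    def __init__(self):
        self._inv = None  # may later hold a weakref to a cached inverse transform

    # Allow for pickle serialization of transforms.
    def __getstate__(self):
        state = self.__dict__.copy()
        # weakref objects cannot be pickled; dereference the cached inverse
        # (it may be None if the inverse has already been garbage-collected).
        if isinstance(state.get("_inv"), weakref.ref):
            state["_inv"] = state["_inv"]()
        return state
```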

@fehiepsi (Member, Author) commented:

@lumip This should fix your pickling issue. Could you double-check? Thanks!

@@ -104,6 +104,11 @@ def _create_plates(self, *args, **kwargs):
        )
        return self.plates

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("plates", None)
        return state
@fehiepsi (Member, Author) commented Sep 25, 2021:
plates is generated during execution, so we don't need to cache it. Moreover, we typically call the guide under jit, so plates would hold abstract values which cannot be pickled.
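A toy illustration of that point (not numpyro code): a per-call cache can safely be dropped in `__getstate__`, and dropping it also sidesteps values that pickle rejects.

```python
import pickle


class GuideLike:
    """Toy class whose 'plates' attribute is a per-call cache."""

    def __init__(self):
        self.plates = None

    def _create_plates(self):
        # A lambda stands in for jit tracer values, which pickle rejects.
        self.plates = {"batch": lambda: None}
        return self.plates

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("plates", None)  # safe to drop: rebuilt on the next call
        return state


g = GuideLike()
g._create_plates()
restored = pickle.loads(pickle.dumps(g))  # succeeds despite the unpicklable cache
assert "plates" not in restored.__dict__
```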

@fritzo (Member) left a comment:
Test looks good 👍 It would be nice to be able to unpickle and use a guide without the model (via guide()), but I guess this is a separate issue.

@fehiepsi (Member, Author) commented:
Thanks for reviewing, @fritzo! You can use the guide to sample from the posterior with either

guide.sample_posterior(rng_key, svi_result.params, sample_shape=(num_samples,))

or

Predictive(guide, params=svi_result.params, num_samples=num_samples)(rng_key)

This is needed because params and rng_key are required to call guide().

Using Predictive(model, guide, params) is unnecessary. It is just a shortcut for

posterior_samples = guide.sample_posterior(...)
x = Predictive(model, posterior_samples, ...)(rng_key, ...)['x']

cc @lumip
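Putting those two options together, a hedged end-to-end sketch (reusing the illustrative `guide`, `data`, and `svi_result` from the earlier sketch):

```python
import pickle

from jax import random
from numpyro.infer import Predictive

# A previously trained guide survives a pickle round trip.
unpickled_guide = pickle.loads(pickle.dumps(guide))

# Option 1: draw latent samples directly from the guide.
posterior_samples = unpickled_guide.sample_posterior(
    random.PRNGKey(2), svi_result.params, sample_shape=(1000,)
)

# Option 2: the same via Predictive, treating the guide as the model
# (pass whatever model arguments the guide expects after the RNG key).
posterior_samples = Predictive(
    unpickled_guide, params=svi_result.params, num_samples=1000
)(random.PRNGKey(2), data)
```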

return_sites=["param", "x"],
)
samples = predictive(random.PRNGKey(1), None, 1)
assert set(samples.keys()) == {"param", "x"}
A contributor left a review comment on the test above:
Maybe better to assert that pickled_guide returns the same samples as guide instead of just verifying that all sites are present?
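A sketch of that stronger check, under the assumption that `predictive` is the object built from the original guide in the visible fragment and `pickled_predictive` (a hypothetical name) is built identically from the round-tripped guide:

```python
from jax import random
from numpy.testing import assert_allclose

expected = predictive(random.PRNGKey(1), None, 1)
actual = pickled_predictive(random.PRNGKey(1), None, 1)

# The same RNG key and parameters should yield identical samples.
for name in expected:
    assert_allclose(actual[name], expected[name])
```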

@lumip (Contributor) commented Sep 28, 2021:

This seems to work well in some simple tests I ran. Unfortunately I can't verify it in our more elaborate settings right now, but I may get around to that tomorrow.

Unfortunately I'm not all too familiar with the code that was changed, so I cannot review those changes in detail.

Edit: I have now tried it in a more complicated setting and ran into the issue that we have some additional decorators around our model which prevent me from pickling the guide (I get an AttributeError: Can't pickle local object 'guard_model.<locals>.wrapped_model', where guard_model is our decorator). So I guess the pickling approach works, but only if the user provides a model that is pickle-able as well.

@fehiepsi (Member, Author) commented:
Thanks for reviewing, @lumip! I think we can just remove model from a pickled guide (if the guide already has a prototype trace). Let me address that. For now, you can just set guide.model = None or delete that attribute before pickling.
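A minimal sketch of that workaround, assuming `guide` is the autoguide wrapping the decorated model and has already been run once; the file name and the `rebuilt_model` placeholder are illustrative:

```python
import pickle

# The decorated model is a local closure and cannot be pickled, so drop the
# guide's reference to it before serializing, then restore it in memory.
model_backup = guide.model
guide.model = None
try:
    with open("guide.pkl", "wb") as f:
        pickle.dump(guide, f)
finally:
    guide.model = model_backup  # keep the in-memory guide intact

# Later: load the guide; reattach a model only if it is needed again.
with open("guide.pkl", "rb") as f:
    restored_guide = pickle.load(f)
restored_guide.model = rebuilt_model  # optional; `rebuilt_model` is a placeholder
```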

@fehiepsi (Member, Author) commented Oct 8, 2021:

@lumip I have been thinking about this for a while; it seems better to let users decide whether to pickle the guide with the model or without it. If we removed the model automatically, someone might be surprised that they can no longer access the model attribute from the guide. :(

Successfully merging this pull request may close these issues:

  • AutoGuides do not permit sampling using only learned parameters