Support data subsampling #734
Looks mostly correct to me, but you might want to add tests that exercise replay and nested/reused plates, like the model in this Pyro test.
Part of the extra complexity in Pyro comes from supporting sequential plates, which I guess isn't a goal here, and part of it comes from treating subsample sites as
Thanks @eb8680 ! I didn't even think about this. Thanks for pointing it out. :)
I didn't plan for that. I have already looked at its implementation but don't understand how it works yet. Let me see whether it is easy to address.
The implementation in Pyro is here in Here's a simplified version that does not include parameter subsampling:

    def postprocess_message(self, msg):  # new method for plate
        if msg["type"] == "subsample" and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                assert event_dim >= 0
                dim = self.dim - event_dim
                shape = msg["value"].shape
                if len(shape) >= -dim and shape[dim] != 1:
                    if is_validation_enabled() and shape[dim] != self.size:
                        statement = "pyro.subsample(..., event_dim={})".format(event_dim)
                        raise ValueError(
                            "Inside pyro.plate({}, {}, dim={}) invalid shape of {}: {}"
                            .format(self.name, self.size, self.dim, statement, shape))
                    if self.subsample_size < self.size:
                        value = msg["value"]
                        new_value = value.index_select(dim, self._indices)
                        msg["value"] = new_value
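To make the dimension arithmetic in the snippet above concrete, here is a minimal NumPy sketch of the same selection step. The function name `subsample_along_plate` and its arguments are hypothetical, chosen only for illustration; this is not the Pyro or NumPyro implementation, and it assumes a negative `plate_dim` counted from the right of the batch shape.

```python
import numpy as np


def subsample_along_plate(value, indices, plate_dim, event_dim=0):
    """Select `indices` along the axis that a plate at `plate_dim` occupies.

    Hypothetical helper: `plate_dim` is negative and counts from the right of
    the batch shape; `event_dim` is the number of rightmost event dimensions.
    """
    assert plate_dim < 0 and event_dim >= 0
    dim = plate_dim - event_dim  # shift left past the event dimensions
    shape = value.shape
    # Leave the value alone if it lacks that axis or broadcasts along it.
    if len(shape) < -dim or shape[dim] == 1:
        return value
    return np.take(value, indices, axis=dim)


data = np.arange(12.0).reshape(4, 3)  # 4 data points, each an event of size 3
idx = np.array([0, 2])                # stand-in for the plate's subsample indices
batch = subsample_along_plate(data, idx, plate_dim=-1, event_dim=1)
broadcast = subsample_along_plate(np.ones((1, 3)), idx, plate_dim=-1, event_dim=1)
```

With `event_dim=1` the plate axis is the second from the right, so the call picks rows 0 and 2 of `data`; a value with size 1 along that axis is returned unchanged, mirroring the `shape[dim] != 1` guard.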
Thanks for the snippet, Eli! My intention was just to make a quick PR for this, but it is worth extending it to support the subsample and param primitives. Pyro also allows subsample_size to be None in the model and not None in the guide. I remember @fritzo said that he would use NumPyro if we had this feature implemented. That's good motivation. :))
@eb8680 I have added |
Not sure if this first subsample PR is the right place for thorough tests of SVI, but I found Pyro's `test_subsample_gradient()` to be invaluable in detecting subsampling issues with very little compute cost. In particular, it helped us correctly scale the different parts of the `ScoreParts` tuple.
Does this PR aim to support subsampling in SVI? If so, do you already have good tests for `poutine.scale()`?
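As a rough illustration of why the scaling matters, the sketch below checks that a minibatch sum rescaled by N / k averages to the full-data sum over all possible minibatches. The helper `scaled_subsample_sum` and the toy numbers are assumptions made for this example; it is not Pyro's `test_subsample_gradient()` and does not touch gradients or `ScoreParts`.

```python
from fractions import Fraction
from itertools import combinations


def scaled_subsample_sum(terms, indices):
    """Minibatch estimate of sum(terms): rescale the partial sum by N / k."""
    scale = Fraction(len(terms), len(indices))
    return scale * sum(terms[i] for i in indices)


# Stand-ins for per-datum log-likelihood terms (exact rationals avoid rounding).
terms = [Fraction(t) for t in (3, -1, 4, 1, -5)]
full = sum(terms)

# Average the scaled estimate over every minibatch of size 2.
batches = list(combinations(range(len(terms)), 2))
mean_estimate = sum(scaled_subsample_sum(terms, b) for b in batches) / len(batches)
```

Here `mean_estimate` equals `full` exactly, which is the unbiasedness property a subsampling test would want to hold for each rescaled term.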
@fritzo Thanks for reviewing! I just add
My main motivation is to support #724, where subsample is required (see this comment). But I believe SVI should work.
Looks great, thanks for implementing this!
LGTM. @eb8680, did you have any other comments?
Looks good, thanks for adding the extra tests!
    with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace):
        with numpyro.plate("a", len(data)):
            subsample_data = numpyro.subsample(data, event_dim=0)
This test checks that the `subsample` primitive works under the `replay` handler.
    @@ -192,7 +192,7 @@ def __init__(self, fn=None, guide_trace=None):
             super(replay, self).__init__(fn)

         def process_message(self, msg):
    -        if msg['name'] in self.guide_trace and msg['type'] in ('sample', 'plate'):
    +        if msg['type'] in ('sample', 'plate') and msg['name'] in self.guide_trace:
I switched the order of the two operands of `and` because some primitives, such as `subsample`, do not have a `name` field.
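The effect of the reordering can be sketched with plain dictionaries. The function names `should_replay_old` and `should_replay_new` and the message contents are hypothetical stand-ins for the handler's condition, not NumPyro code.

```python
def should_replay_old(msg, guide_trace):
    # Original order: reads msg['name'] before looking at the type.
    return msg['name'] in guide_trace and msg['type'] in ('sample', 'plate')


def should_replay_new(msg, guide_trace):
    # Reordered: `and` short-circuits, so msg['name'] is only read for
    # message types that are expected to carry a name.
    return msg['type'] in ('sample', 'plate') and msg['name'] in guide_trace


guide_trace = {'a': {'type': 'plate', 'value': [0, 2]}}
plate_msg = {'type': 'plate', 'name': 'a'}
subsample_msg = {'type': 'subsample', 'value': [1.0, 2.0]}  # no 'name' key

try:
    should_replay_old(subsample_msg, guide_trace)
    old_order_failed = False
except KeyError:
    old_order_failed = True
```

With the old order, a message without a `name` key raises `KeyError`; with the new order the type check fails first and the name is never looked up.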
Addresses #556. This will also be helpful for #724.
@fritzo The implementation here is a bit simpler than Pyro's, but I am not sure whether I have missed anything. Could you help me review this PR?