Support data subsampling #734

Merged
merged 12 commits into from
Sep 18, 2020

Conversation

fehiepsi
Member

@fehiepsi fehiepsi commented Sep 11, 2020

Addresses #556. This will also be helpful for #724.

@fritzo The implementation here is a bit simpler than Pyro's, but I am not sure whether I have missed anything. Could you help me review this PR?

@fehiepsi fehiepsi added the WIP label Sep 11, 2020
@eb8680
Member

eb8680 commented Sep 11, 2020

Looks mostly correct to me, but you might want to add tests that exercise replay and nested/reused plates, like the model in this Pyro test.

The implementation here is a bit simpler than Pyro's, but I am not sure whether I have missed anything.

Part of the extra complexity in Pyro comes from supporting sequential plates, which I guess isn't a goal here, and part of it comes from treating subsample sites as sample sites, which was just a poor early design choice in Pyro. Also, are you planning to add a numpyro.subsample primitive?

@fehiepsi
Member Author

fehiepsi commented Sep 11, 2020

add tests that exercise replay and nested/reused plates

Thanks @eb8680 ! I didn't even think about this. Thanks for pointing it out. :)

planning to add a numpyro.subsample primitive?

I didn't have that plan. I have looked at its implementation but don't yet understand how it works. Let me see if it is easy to address.

@eb8680
Member

eb8680 commented Sep 11, 2020

Let me see if it is easy to address.

The implementation in Pyro is here in SubsampleMessenger._postprocess_message: https://github.com/pyro-ppl/pyro/blob/dev/pyro/poutine/subsample_messenger.py#L134

Here's a simplified version that does not include parameter subsampling:

    def postprocess_message(self, msg):  # new method for plate
        if msg["type"] == "subsample" and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                assert event_dim >= 0
                dim = self.dim - event_dim
                shape = msg["value"].shape
                if len(shape) >= -dim and shape[dim] != 1:
                    if is_validation_enabled() and shape[dim] != self.size:
                        statement = "pyro.subsample(..., event_dim={})".format(event_dim)
                        raise ValueError(
                            "Inside pyro.plate({}, {}, dim={}) invalid shape of {}: {}"
                            .format(self.name, self.size, self.dim, statement, shape))
                    if self.subsample_size < self.size:
                        value = msg["value"]
                        new_value = value.index_select(dim, self._indices)
                        msg["value"] = new_value
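
The indexing arithmetic in the snippet above can be followed with a shape-only sketch in plain Python. This is not Pyro or NumPyro code: `subsampled_shape` and its parameter names are hypothetical, it tracks only how the shape changes (not the values), and unlike the snippet it always validates rather than checking `is_validation_enabled()`.

```python
def subsampled_shape(shape, plate_dim, plate_size, subsample_size, event_dim):
    """Return the shape a value would have after plate subsampling.

    Hypothetical helper that mirrors only the shape logic of the snippet.
    plate_dim is negative (counted from the right); event_dim >= 0 is the
    number of rightmost dimensions that are event dimensions.
    """
    assert plate_dim < 0 and event_dim >= 0
    # Shift the plate dim left past the event dims, as in `self.dim - event_dim`.
    dim = plate_dim - event_dim
    # Values that do not reach the plate dim, or have size 1 there, are left alone.
    if len(shape) < -dim or shape[dim] == 1:
        return tuple(shape)
    # Assumption: validation is always on in this sketch.
    if shape[dim] != plate_size:
        raise ValueError(
            "expected size {} at dim {}, got shape {}".format(plate_size, dim, shape))
    if subsample_size < plate_size:
        new_shape = list(shape)
        new_shape[dim] = subsample_size  # index_select keeps subsample_size entries
        return tuple(new_shape)
    return tuple(shape)


# A data matrix of 100 rows with 3 features, inside a plate of size 100 at dim=-1.
print(subsampled_shape((100, 3), plate_dim=-1, plate_size=100,
                       subsample_size=10, event_dim=1))
```

With `event_dim=1` the feature axis counts as an event dimension, so the plate axis is the second from the right and only the row count changes.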

@fehiepsi fehiepsi added invalid This doesn't seem right and removed awaiting review labels Sep 11, 2020
@fehiepsi
Member Author

Thanks for the snippet, Eli! My intention was just to make a quick PR for this, but it is worth extending it to support the subsample and param primitives. Also, Pyro allows subsample_size to be None in the model and not None in the guide. I remember @fritzo said that he would use NumPyro if we had this feature implemented. That's a good motivation. :))

@fehiepsi fehiepsi added WIP and removed invalid This doesn't seem right labels Sep 11, 2020
@fehiepsi
Member Author

@eb8680 I have added a subsample primitive, with subsampling support for both it and param. I also made a small change to allow subsample_size=None in the model while it is not None in the guide.

Member

@fritzo fritzo left a comment

Not sure if this first subsample PR is the right place for thorough tests of SVI, but I found Pyro's test_subsample_gradient() to be invaluable in detecting subsampling issues with very little compute cost. In particular it helped us correctly scale the different parts of the ScoreParts tuple.

Does this PR aim to support subsampling in SVI? If so, do you already have good tests for poutine.scale()?
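
As background for why scaling matters under subsampling, here is a small plain-Python check. It assumes the usual convention that a subsampled log-likelihood is multiplied by size / subsample_size; the function name and the numbers are made up for illustration and are not taken from either library's tests.

```python
from itertools import combinations


def scaled_subsample_estimate(log_probs, indices):
    """Estimate the full-data log-likelihood from a subsample.

    Assumption: the subsample sum is scaled by size / subsample_size.
    """
    scale = len(log_probs) / len(indices)
    return scale * sum(log_probs[i] for i in indices)


# Made-up per-datum log-probabilities for a dataset of size 5.
log_probs = [-0.5, -1.25, -2.0, -0.75, -3.5]
full = sum(log_probs)

# Average the scaled estimate over every possible subsample of size 2.
subsample_size = 2
estimates = [
    scaled_subsample_estimate(log_probs, idx)
    for idx in combinations(range(len(log_probs)), subsample_size)
]
mean_estimate = sum(estimates) / len(estimates)

# With uniform subsampling the scaled estimate is unbiased, so these agree.
print(full, mean_estimate)
```

If the scale factor were dropped or applied twice, the averaged estimate would no longer match the full sum, which is the kind of error a subsample gradient test is meant to catch.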

numpyro/handlers.py Show resolved Hide resolved
numpyro/primitives.py Outdated Show resolved Hide resolved
@fehiepsi
Member Author

@fritzo Thanks for reviewing! I just added a subsample_gradient test. We have some tests for handlers.scale in test_mask and test_scale, but NumPyro does not support ScoreParts yet.

Does this PR aim to support subsampling in SVI?

My main motivation is to support #724, where subsample is required (see this comment). But SVI should work, I believe.

setup.py Show resolved Hide resolved
Member

@fritzo fritzo left a comment

Looks great, thanks for implementing this!

numpyro/primitives.py Show resolved Hide resolved
test/test_handlers.py Outdated Show resolved Hide resolved
fritzo
fritzo previously approved these changes Sep 17, 2020
@fritzo
Member

fritzo commented Sep 17, 2020

LGTM, @eb8680 did you have any other comments?

eb8680
eb8680 previously approved these changes Sep 17, 2020
Member

@eb8680 eb8680 left a comment

Looks good, thanks for adding the extra tests!

@fehiepsi
Member Author

fehiepsi commented Sep 18, 2020

Thanks for reviewing, @eb8680 and @fritzo! I just observed something wrong in Predictive for models with multiple subsample statements. Edit: it turns out to be my mistake; only the replay handler needs to be fixed. I'll try to fix it and add the failing test in this PR.

@fehiepsi fehiepsi dismissed stale reviews from eb8680 and fritzo via b95bf9b September 18, 2020 05:54

    with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace):
        with numpyro.plate("a", len(data)):
            subsample_data = numpyro.subsample(data, event_dim=0)
Member Author

@fehiepsi fehiepsi Sep 18, 2020

This test checks that the subsample primitive works under the replay handler.

    @@ -192,7 +192,7 @@ def __init__(self, fn=None, guide_trace=None):
             super(replay, self).__init__(fn)

         def process_message(self, msg):
    -        if msg['name'] in self.guide_trace and msg['type'] in ('sample', 'plate'):
    +        if msg['type'] in ('sample', 'plate') and msg['name'] in self.guide_trace:
Member Author

I switched the order of the two sides of the "and" because some primitives, such as subsample, do not have a name field.
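
The effect of the reordering can be reproduced with plain dictionaries. The sketch below is only an illustration: `old_check`, `new_check`, and the message contents are simplified stand-ins rather than NumPyro's actual message objects, and it relies on nothing beyond Python's short-circuit evaluation of `and`.

```python
# Simplified stand-in for a guide trace keyed by site name.
guide_trace = {"x": {"type": "sample", "value": 1.0}}


def old_check(msg, guide_trace):
    # Looks up msg['name'] first, so a message without that key fails.
    return msg['name'] in guide_trace and msg['type'] in ('sample', 'plate')


def new_check(msg, guide_trace):
    # Tests the type first; `and` short-circuits before msg['name'] is read.
    return msg['type'] in ('sample', 'plate') and msg['name'] in guide_trace


sample_msg = {"type": "sample", "name": "x"}
subsample_msg = {"type": "subsample"}  # hypothetical message with no 'name' key

print(new_check(sample_msg, guide_trace), new_check(subsample_msg, guide_trace))

try:
    old_check(subsample_msg, guide_trace)
except KeyError as err:
    print("old ordering raised KeyError:", err)
```

Both orderings agree for messages that carry a name; they differ only in whether a nameless message is skipped quietly or raises.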

@fehiepsi fehiepsi merged commit cce00d1 into pyro-ppl:master Sep 18, 2020
@OlaRonning OlaRonning mentioned this pull request Sep 22, 2020