-
-
Notifications
You must be signed in to change notification settings - Fork 984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement pyro.ops.streaming module #2856
Conversation
@eb8680 I've added you as a reviewer because these streaming classes create a new semigroup abstraction and you're the resident algebra expert. |
@mtsokol I believe you can use something like the following statistics in #2843: from pyro.ops.streaming import CountMeanVariance, StatsOfDict
...
stats = StatsOfDict(default=CountMeanVariance)
for mcmc_sample in ...: # learning loop
stats.update({
name: transformed_sample for name, transformed_sample in mcmc_sample.items()
})
result = stats.get() Let me know if it looks like you'll need any changes to this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat API! I'm a little confused by some of the types. Did you try running mypy locally?
self.count += 1 | ||
|
||
def merge(self, other: "CountStats"): | ||
assert isinstance(other, type(self)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: these assertions should no longer be necessary with type hints
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, good point more generally. However I'd like to argue that we should include both type hints and assertions until all common tools can leverage type hints. My reasoning is that I'd really like to catch errors as early as possible, e.g. when users (like me) are working in a jupyter notebook. I think until Jupyter dynamically checks types while editing we'll want extra guard rails especially for tricky interfaces like this.
pyro/ops/streaming.py
Outdated
""" | ||
def __init__( | ||
self, | ||
types: Dict[object, Type[StreamingStats]] = {}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be strengthened to Dict[str, Type[StreamingStats]]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I strengthened to Hashable
, but I think we do want to support e.g. integer keys among chains.
@eb8680 thanks for reviewing!
Sorry, I didn't run mypy locally, and some of the types are stale after refactoring. I'll fix... UPDATE ...fixed and ran mypy locally. |
Current #2857 draft isn't chain-aware and I'm wondering about it. It can be either handled by class CountMeanStats(StreamingStats):
def __init__(self, num_chains=1):
self.counts = [0] * num_chains
...
def update(self, sample, chain_index=0):
...
def get(self, group_by_chain=True):
# we can sum across chains so the update in self._statistics.update({
name: transformed_sample for name, transformed_sample in z_acc.items()
}, chain_index) Otherwise it can be handled by |
@mtsokol I think it's best to keep chain logic in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after merge conflicts are resolved
Thanks for reviewing @eb8680! Looks like I'll be using this right away in my mutation models 😄 |
Addresses #2843
This implements a new module
pyro.ops.streaming
to streamingly track various statistics. The first intended use case is the plannedStreamingMCMC
class which will track statistics rather than store samples. There are other potential uses in high-dimensional inference, e.g. recording statistics of gradients during SVI and computing sample moments from predictive when the samples don't fit in memory.Design choices
The two basic operations are
.update()
and.get()
. The third operation.merge()
will be useful for multiple-chain MCMC and computing things like rhat.I have restricted to the data type to dictionaries of tensors, which is the basic datatype in pyro.infer.mcmc and in much of NumPyro. We could easily generalize this to pytrees by adding classes
StatsOfList
andStatsOfTuple
.Tested