-
-
Notifications
You must be signed in to change notification settings - Fork 984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement pyro.ops.streaming module #2856
Merged
Merged
Changes from 11 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
dff7517
Implement first version of pyro.ops.streaming
fritzo 27ea8c7
Refactor to simplify vector stat implementation
fritzo fe7ffb8
Fix typo
fritzo 212a9e8
Fix types; add StackStats
fritzo 93d748e
Fix doctests
fritzo b0f1096
Refine type hints
fritzo 9cd53c8
Relax StatsOfDict key type
fritzo 9072002
Fix tests
fritzo 7a4b568
Relax type hints to StatsOfDict
fritzo c684f4d
Merge branch 'dev' into ops-streaming
fritzo 7e53700
Enable type checking for pyro.ops.streaming
fritzo 171e762
Merge branch 'dev' into ops-streaming
fritzo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,277 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import copy | ||
from abc import ABC, abstractmethod | ||
from collections import defaultdict | ||
from typing import Any, Callable, Dict, Hashable, Union | ||
|
||
import torch | ||
|
||
from pyro.ops.welford import WelfordCovariance | ||
|
||
|
||
class StreamingStats(ABC): | ||
""" | ||
Abstract base class for streamable statistics of trees of tensors. | ||
|
||
Derived classes must implelement :meth:`update`, :meth:`merge`, and | ||
:meth:`get`. | ||
""" | ||
|
||
@abstractmethod | ||
def update(self, sample) -> None: | ||
""" | ||
Update state from a single sample. | ||
|
||
This mutates ``self`` and returns nothing. Updates should be | ||
independent of order, i.e. samples should be exchangeable. | ||
|
||
:param sample: A sample value which is a nested dictionary of | ||
:class:`torch.Tensor` leaves. This can have arbitrary nesting and | ||
shape shape, but assumes shape is constant across calls to | ||
``.update()``. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def merge(self, other) -> "StreamingStats": | ||
""" | ||
Select two aggregate statistics, e.g. from different MCMC chains. | ||
|
||
This is a pure function: it returns a new :class:`StreamingStats` | ||
object and does not modify either ``self`` or ``other``. | ||
|
||
:param other: Another streaming stats instance of the same type. | ||
""" | ||
assert isinstance(other, type(self)) | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get(self) -> Any: | ||
""" | ||
Return the aggregate statistic. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class CountStats(StreamingStats): | ||
""" | ||
Statistic tracking only the number of samples. | ||
|
||
For example:: | ||
|
||
>>> stats = CountStats() | ||
>>> stats.update(torch.randn(3, 3)) | ||
>>> stats.get() | ||
{'count': 1} | ||
""" | ||
|
||
def __init__(self): | ||
self.count = 0 | ||
super().__init__() | ||
|
||
def update(self, sample) -> None: | ||
self.count += 1 | ||
|
||
def merge(self, other: "CountStats") -> "CountStats": | ||
assert isinstance(other, type(self)) | ||
result = CountStats() | ||
result.count = self.count + other.count | ||
return result | ||
|
||
def get(self) -> Dict[str, int]: | ||
""" | ||
:returns: A dictionary with keys ``count: int``. | ||
:rtype: dict | ||
""" | ||
return {"count": self.count} | ||
|
||
|
||
class StatsOfDict(StreamingStats): | ||
fritzo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Statistics of samples that are dictionaries with constant set of keys. | ||
|
||
For example the following are equivalent:: | ||
|
||
# Version 1. Hand encode statistics. | ||
>>> a_stats = CountStats() | ||
>>> b_stats = CountMeanStats() | ||
>>> a_stats.update(torch.tensor(0.)) | ||
>>> b_stats.update(torch.tensor([1., 2.])) | ||
>>> summary = {"a": a_stats.get(), "b": b_stats.get()} | ||
|
||
# Version 2. Collect samples into dictionaries. | ||
>>> stats = StatsOfDict({"a": CountStats, "b": CountMeanStats}) | ||
>>> stats.update({"a": torch.tensor(0.), "b": torch.tensor([1., 2.])}) | ||
>>> summary = stats.get() | ||
>>> summary | ||
{'a': {'count': 1}, 'b': {'count': 1, 'mean': tensor([1., 2.])}} | ||
|
||
:param default: Default type of statistics of values of the dictionary. | ||
Defaults to the inexpensive :class:`CountStats`. | ||
:param dict types: Dictionary mapping key to type of statistic that should | ||
be recorded for values corresponding to that key. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
types: Dict[Hashable, Callable[[], StreamingStats]] = {}, | ||
default: Callable[[], StreamingStats] = CountStats, | ||
): | ||
self.stats: Dict[Hashable, StreamingStats] = defaultdict(default) | ||
self.stats.update({k: v() for k, v in types.items()}) | ||
super().__init__() | ||
|
||
def update(self, sample: Dict[Hashable, Any]) -> None: | ||
for k, v in sample.items(): | ||
self.stats[k].update(v) | ||
|
||
def merge(self, other: "StatsOfDict") -> "StatsOfDict": | ||
assert isinstance(other, type(self)) | ||
result = copy.deepcopy(self) | ||
fritzo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for k in set(self.stats).union(other.stats): | ||
if k not in self.stats: | ||
result.stats[k] = copy.deepcopy(other.stats[k]) | ||
elif k in other.stats: | ||
result.stats[k] = self.stats[k].merge(other.stats[k]) | ||
return result | ||
|
||
def get(self) -> Dict[Hashable, Any]: | ||
""" | ||
:returns: A dictionary of statistics. The keys of this dictionary are | ||
the same as the keys of the samples from which this object is | ||
updated. | ||
:rtype: dict | ||
""" | ||
return {k: v.get() for k, v in self.stats.items()} | ||
|
||
|
||
class StackStats(StreamingStats): | ||
""" | ||
Statistic collecting a stream of tensors into a single stacked tensor. | ||
""" | ||
|
||
def __init__(self): | ||
self.samples = [] | ||
|
||
def update(self, sample: torch.Tensor) -> None: | ||
assert isinstance(sample, torch.Tensor) | ||
self.samples.append(sample) | ||
|
||
def merge(self, other: "StackStats") -> "StackStats": | ||
assert isinstance(other, type(self)) | ||
result = StackStats() | ||
result.samples = self.samples + other.samples | ||
return result | ||
|
||
def get(self) -> Dict[str, Union[int, torch.Tensor]]: | ||
""" | ||
:returns: A dictionary with keys ``count: int`` and (if any samples | ||
have been collected) ``samples: torch.Tensor``. | ||
:rtype: dict | ||
""" | ||
if not self.samples: | ||
return {"count": 0} | ||
return {"count": len(self.samples), "samples": torch.stack(self.samples)} | ||
|
||
|
||
class CountMeanStats(StreamingStats): | ||
""" | ||
Statistic tracking the count and mean of a single :class:`torch.Tensor`. | ||
""" | ||
|
||
def __init__(self): | ||
self.count = 0 | ||
self.mean = 0 | ||
super().__init__() | ||
|
||
def update(self, sample: torch.Tensor) -> None: | ||
assert isinstance(sample, torch.Tensor) | ||
self.count += 1 | ||
self.mean += (sample.detach() - self.mean) / self.count | ||
|
||
def merge(self, other: "CountMeanStats") -> "CountMeanStats": | ||
assert isinstance(other, type(self)) | ||
result = CountMeanStats() | ||
result.count = self.count + other.count | ||
p = self.count / max(result.count, 1) | ||
q = other.count / max(result.count, 1) | ||
result.mean = p * self.mean + q * other.mean | ||
return result | ||
|
||
def get(self) -> Dict[str, Union[int, torch.Tensor]]: | ||
""" | ||
:returns: A dictionary with keys ``count: int`` and (if any samples | ||
have been collected) ``mean: torch.Tensor``. | ||
:rtype: dict | ||
""" | ||
if self.count == 0: | ||
return {"count": 0} | ||
return {"count": self.count, "mean": self.mean} | ||
|
||
|
||
class CountMeanVarianceStats(StreamingStats): | ||
""" | ||
Statistic tracking the count, mean, and (diagonal) variance of a single | ||
:class:`torch.Tensor`. | ||
""" | ||
|
||
def __init__(self): | ||
self.shape = None | ||
self.welford = WelfordCovariance(diagonal=True) | ||
super().__init__() | ||
|
||
def update(self, sample: torch.Tensor) -> None: | ||
assert isinstance(sample, torch.Tensor) | ||
if self.shape is None: | ||
self.shape = sample.shape | ||
assert sample.shape == self.shape | ||
self.welford.update(sample.detach().reshape(-1)) | ||
|
||
def merge(self, other: "CountMeanVarianceStats") -> "CountMeanVarianceStats": | ||
assert isinstance(other, type(self)) | ||
if self.shape is None: | ||
return copy.deepcopy(other) | ||
if other.shape is None: | ||
return copy.deepcopy(self) | ||
result = copy.deepcopy(self) | ||
res = result.welford | ||
lhs = self.welford | ||
rhs = other.welford | ||
res.n_samples = lhs.n_samples + rhs.n_samples | ||
lhs_weight = lhs.n_samples / res.n_samples | ||
rhs_weight = rhs.n_samples / res.n_samples | ||
res._mean = lhs_weight * lhs._mean + rhs_weight * rhs._mean | ||
res._m2 = ( | ||
lhs._m2 | ||
+ rhs._m2 | ||
+ (lhs.n_samples * rhs.n_samples / res.n_samples) | ||
* (lhs._mean - rhs._mean) ** 2 | ||
) | ||
return result | ||
|
||
def get(self) -> Dict[str, Union[int, torch.Tensor]]: | ||
""" | ||
:returns: A dictionary with keys ``count: int`` and (if any samples | ||
have been collected) ``mean: torch.Tensor`` and ``variance: | ||
torch.Tensor``. | ||
:rtype: dict | ||
""" | ||
if self.shape is None: | ||
return {"count": 0} | ||
count = self.welford.n_samples | ||
mean = self.welford._mean.reshape(self.shape) | ||
variance = self.welford.get_covariance(regularize=False).reshape(self.shape) | ||
return {"count": count, "mean": mean, "variance": variance} | ||
|
||
|
||
# Note this is ordered logically for sphinx rather than alphabetically. | ||
__all__ = [ | ||
"StreamingStats", | ||
"StatsOfDict", | ||
"StackStats", | ||
"CountStats", | ||
"CountMeanStats", | ||
"CountMeanVarianceStats", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: these assertions should no longer be necessary with type hints
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, good point more generally. However I'd like to argue that we should include both type hints and assertions until all common tools can leverage type hints. My reasoning is that I'd really like to catch errors as early as possible, e.g. when users (like me) are working in a jupyter notebook. I think until Jupyter dynamically checks types while editing we'll want extra guard rails especially for tricky interfaces like this.