Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement pyro.ops.streaming module #2856

Merged
merged 12 commits into from
Jun 7, 2021
8 changes: 8 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ Statistical Utilities
:show-inheritance:
:member-order: bysource

Streaming Statistics
--------------------

.. automodule:: pyro.ops.streaming
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

State Space Model and GP Utilities
----------------------------------
Expand Down
277 changes: 277 additions & 0 deletions pyro/ops/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Hashable, Optional, Union

import torch

from pyro.ops.welford import WelfordCovariance


class StreamingStats(ABC):
    """
    Abstract base class for streamable statistics of trees of tensors.

    Derived classes must implement :meth:`update`, :meth:`merge`, and
    :meth:`get`.
    """

    @abstractmethod
    def update(self, sample) -> None:
        """
        Fold a single sample into the running state.

        This mutates ``self`` and returns nothing. Updates should be
        independent of order, i.e. samples should be exchangeable.

        :param sample: A sample value which is a nested dictionary of
            :class:`torch.Tensor` leaves. This can have arbitrary nesting and
            shape, but the shape is assumed to be constant across calls to
            ``.update()``.
        """
        raise NotImplementedError

    @abstractmethod
    def merge(self, other) -> "StreamingStats":
        """
        Merge two aggregate statistics, e.g. from different MCMC chains.

        This is a pure function: it returns a new :class:`StreamingStats`
        object and does not modify either ``self`` or ``other``.

        :param other: Another streaming stats instance of the same type.
        """
        # Runtime guard in addition to the type hint: when a subclass calls
        # this base implementation, a mismatched ``other`` fails here first.
        assert isinstance(other, type(self))
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Any:
        """
        Return the aggregate statistic.
        """
        raise NotImplementedError


class CountStats(StreamingStats):
    """
    Statistic that records only how many samples have been seen.

    For example::

        >>> stats = CountStats()
        >>> stats.update(torch.randn(3, 3))
        >>> stats.get()
        {'count': 1}
    """

    def __init__(self):
        self.count = 0
        super().__init__()

    def update(self, sample) -> None:
        # The sample's contents are ignored; only its arrival is counted.
        self.count = self.count + 1

    def merge(self, other: "CountStats") -> "CountStats":
        # Kept alongside the type hint on purpose: the runtime check catches
        # misuse early in interactive sessions (e.g. a Jupyter notebook)
        # where type hints are not checked while editing.
        assert isinstance(other, type(self))
        merged = CountStats()
        merged.count = self.count + other.count
        return merged

    def get(self) -> Dict[str, int]:
        """
        :returns: A dictionary with keys ``count: int``.
        :rtype: dict
        """
        return dict(count=self.count)


class StatsOfDict(StreamingStats):
    """
    Statistics of samples that are dictionaries with constant set of keys.

    For example the following are equivalent::

        # Version 1. Hand encode statistics.
        >>> a_stats = CountStats()
        >>> b_stats = CountMeanStats()
        >>> a_stats.update(torch.tensor(0.))
        >>> b_stats.update(torch.tensor([1., 2.]))
        >>> summary = {"a": a_stats.get(), "b": b_stats.get()}

        # Version 2. Collect samples into dictionaries.
        >>> stats = StatsOfDict({"a": CountStats, "b": CountMeanStats})
        >>> stats.update({"a": torch.tensor(0.), "b": torch.tensor([1., 2.])})
        >>> summary = stats.get()
        >>> summary
        {'a': {'count': 1}, 'b': {'count': 1, 'mean': tensor([1., 2.])}}

    :param dict types: Dictionary mapping key to type of statistic that should
        be recorded for values corresponding to that key. ``None`` (the
        default) is treated as an empty dictionary.
    :param default: Default type of statistics of values of the dictionary.
        Defaults to the inexpensive :class:`CountStats`.
    """

    def __init__(
        self,
        types: Optional[Dict[Hashable, Callable[[], StreamingStats]]] = None,
        default: Callable[[], StreamingStats] = CountStats,
    ):
        # Use a ``None`` sentinel instead of a shared mutable ``{}`` default.
        if types is None:
            types = {}
        # Keys not listed in ``types`` get a ``default()`` statistic lazily,
        # the first time they appear in a sample.
        self.stats: Dict[Hashable, StreamingStats] = defaultdict(default)
        self.stats.update({k: v() for k, v in types.items()})
        super().__init__()

    def update(self, sample: Dict[Hashable, Any]) -> None:
        """
        Forward each value of ``sample`` to the statistic stored under the
        same key, creating a default statistic for keys not seen before.
        """
        for k, v in sample.items():
            self.stats[k].update(v)

    def merge(self, other: "StatsOfDict") -> "StatsOfDict":
        """
        Return a new :class:`StatsOfDict` combining ``self`` and ``other``.

        Keys present on both sides are merged via the per-key statistic's own
        ``merge``; keys present on only one side are deep-copied. Neither
        ``self`` nor ``other`` is modified.
        """
        assert isinstance(other, type(self))
        # Deep copy keeps the default factory and all keys unique to ``self``.
        result = copy.deepcopy(self)
        # Iterate ``other`` directly (rather than an unordered set union) so
        # that keys new to ``self`` are appended in a deterministic order.
        for k, other_stat in other.stats.items():
            if k in self.stats:
                result.stats[k] = self.stats[k].merge(other_stat)
            else:
                result.stats[k] = copy.deepcopy(other_stat)
        return result

    def get(self) -> Dict[Hashable, Any]:
        """
        :returns: A dictionary of statistics. The keys of this dictionary are
            the same as the keys of the samples from which this object is
            updated.
        :rtype: dict
        """
        return {k: v.get() for k, v in self.stats.items()}


class StackStats(StreamingStats):
    """
    Statistic collecting a stream of tensors into a single stacked tensor.

    Every sample passed to :meth:`update` is retained, so memory use grows
    with the number of samples.
    """

    def __init__(self):
        self.samples = []
        # Call the base initializer like every other statistic in this module.
        super().__init__()

    def update(self, sample: torch.Tensor) -> None:
        """
        Append ``sample`` to the list of collected samples.

        :param torch.Tensor sample: The tensor to record. It is stored as
            given (no copy is made).
        """
        assert isinstance(sample, torch.Tensor)
        self.samples.append(sample)

    def merge(self, other: "StackStats") -> "StackStats":
        """
        Return a new :class:`StackStats` holding the samples of ``self``
        followed by the samples of ``other``. Neither input is modified.
        """
        assert isinstance(other, type(self))
        result = StackStats()
        # List concatenation builds a new list, so the inputs' lists are
        # left untouched (the tensors themselves are shared, not copied).
        result.samples = self.samples + other.samples
        return result

    def get(self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        :returns: A dictionary with keys ``count: int`` and (if any samples
            have been collected) ``samples: torch.Tensor``.
        :rtype: dict
        """
        if not self.samples:
            return {"count": 0}
        return {"count": len(self.samples), "samples": torch.stack(self.samples)}


class CountMeanStats(StreamingStats):
    """
    Statistic tracking the count and mean of a single :class:`torch.Tensor`.

    The mean tensor returned by :meth:`get` is never mutated by later calls
    to :meth:`update`, so results can safely be kept as snapshots.
    """

    def __init__(self):
        self.count = 0
        # Starts as the scalar 0; becomes a tensor on the first update.
        self.mean = 0
        super().__init__()

    def update(self, sample: torch.Tensor) -> None:
        """
        Fold ``sample`` into the running mean.

        :param torch.Tensor sample: The new observation. It is detached
            before being used.
        """
        assert isinstance(sample, torch.Tensor)
        self.count += 1
        # Out-of-place on purpose: an in-place ``+=`` would modify the tensor
        # previously handed out by ``get()``, silently changing results that
        # callers already hold.
        self.mean = self.mean + (sample.detach() - self.mean) / self.count

    def merge(self, other: "CountMeanStats") -> "CountMeanStats":
        """
        Return a new :class:`CountMeanStats` whose mean is the count-weighted
        average of the two input means. Neither input is modified.
        """
        assert isinstance(other, type(self))
        result = CountMeanStats()
        result.count = self.count + other.count
        # ``max(..., 1)`` avoids division by zero when both sides are empty.
        p = self.count / max(result.count, 1)
        q = other.count / max(result.count, 1)
        result.mean = p * self.mean + q * other.mean
        return result

    def get(self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        :returns: A dictionary with keys ``count: int`` and (if any samples
            have been collected) ``mean: torch.Tensor``.
        :rtype: dict
        """
        if self.count == 0:
            return {"count": 0}
        return {"count": self.count, "mean": self.mean}


class CountMeanVarianceStats(StreamingStats):
    """
    Statistic tracking the count, mean, and (diagonal) variance of a single
    :class:`torch.Tensor`.

    Samples are flattened before being handed to a
    :class:`~pyro.ops.welford.WelfordCovariance` accumulator and the results
    are reshaped back to the original sample shape in :meth:`get`.
    """

    def __init__(self):
        # Shape of the samples; ``None`` until the first update.
        self.shape = None
        self.welford = WelfordCovariance(diagonal=True)
        super().__init__()

    def update(self, sample: torch.Tensor) -> None:
        """
        Fold ``sample`` into the running statistics.

        :param torch.Tensor sample: The new observation. Its shape must match
            the shape of the first sample seen.
        """
        assert isinstance(sample, torch.Tensor)
        if self.shape is None:
            # The first sample fixes the expected shape.
            self.shape = sample.shape
        assert sample.shape == self.shape
        flat = sample.detach().reshape(-1)
        self.welford.update(flat)

    def merge(self, other: "CountMeanVarianceStats") -> "CountMeanVarianceStats":
        """
        Return a new :class:`CountMeanVarianceStats` combining ``self`` and
        ``other``. Neither input is modified.
        """
        assert isinstance(other, type(self))
        # If either side has seen no samples, the result is a copy of the
        # other side.
        if self.shape is None:
            return copy.deepcopy(other)
        if other.shape is None:
            return copy.deepcopy(self)

        merged = copy.deepcopy(self)
        ours, theirs, out = self.welford, other.welford, merged.welford

        out.n_samples = ours.n_samples + theirs.n_samples
        ours_weight = ours.n_samples / out.n_samples
        theirs_weight = theirs.n_samples / out.n_samples
        # Count-weighted average of the two means.
        out._mean = ours_weight * ours._mean + theirs_weight * theirs._mean
        # Pairwise combination of the second-moment accumulators.
        # NOTE(review): presumably ``_m2`` is the running sum of squared
        # deviations inside WelfordCovariance -- confirm against that class.
        cross_scale = ours.n_samples * theirs.n_samples / out.n_samples
        mean_gap = ours._mean - theirs._mean
        out._m2 = ours._m2 + theirs._m2 + cross_scale * mean_gap ** 2
        return merged

    def get(self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        :returns: A dictionary with keys ``count: int`` and (if any samples
            have been collected) ``mean: torch.Tensor`` and ``variance:
            torch.Tensor``.
        :rtype: dict
        """
        if self.shape is None:
            return {"count": 0}
        accumulator = self.welford
        summary = {"count": accumulator.n_samples}
        summary["mean"] = accumulator._mean.reshape(self.shape)
        summary["variance"] = accumulator.get_covariance(regularize=False).reshape(
            self.shape
        )
        return summary


# Public names exported by this module.
# Note this is ordered logically for sphinx rather than alphabetically.
__all__ = [
    "StreamingStats",
    "StatsOfDict",
    "StackStats",
    "CountStats",
    "CountMeanStats",
    "CountMeanVarianceStats",
]
19 changes: 10 additions & 9 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,45 @@ column_limit = 120
python_version = 3.6
warn_return_any = True
warn_unused_configs = True
warn_incomplete_stub = True

# Per-module options:

[mypy-pyro._version.*]
ignore_errors = True
warn_incomplete_stub = True

[mypy-pyro.contrib.*]
ignore_errors = True
warn_incomplete_stub = True

[mypy-pyro.distributions.*]
ignore_errors = True
warn_incomplete_stub = True
warn_unused_ignores = True

[mypy-pyro.generic.*]
ignore_errors = True
warn_incomplete_stub = True
warn_unused_ignores = True

[mypy-pyro.infer.*]
ignore_errors = True
warn_incomplete_stub = True
warn_unused_ignores = True

[mypy-pyro.nn.*]
ignore_errors = True
warn_incomplete_stub = True
warn_unused_ignores = True

[mypy-pyro.ops.*]
[mypy-pyro.ops.einsum]
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.ops.contract]
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.ops.tensor_utils]
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.optm.*]
warn_incomplete_stub = True
warn_unused_ignores = True

[mypy-pyro.params.*]
Expand All @@ -83,5 +85,4 @@ ignore_errors = True

[mypy-pyro.util.*]
ignore_errors = True
warn_incomplete_stub = True
warn_unused_ignores = True
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def assert_close(actual, expected, atol=1e-7, rtol=0, msg=''):
assert set(actual.keys()) == set(expected.keys())
for key, x_val in actual.items():
assert_close(x_val, expected[key], atol=atol, rtol=rtol,
msg='At key{}: {} vs {}'.format(key, x_val, expected[key]))
msg='At key {}: {} vs {}'.format(repr(key), x_val, expected[key]))
elif isinstance(actual, str):
assert actual == expected, msg
elif is_iterable(actual) and is_iterable(expected):
Expand Down
Loading