diff --git a/docs/source/ops.rst b/docs/source/ops.rst index cd8815f7c6..8a16e6c50f 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -93,6 +93,14 @@ Statistical Utilities :show-inheritance: :member-order: bysource +Streaming Statistics +-------------------- + +.. automodule:: pyro.ops.streaming + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource State Space Model and GP Utilities ---------------------------------- diff --git a/pyro/ops/streaming.py b/pyro/ops/streaming.py new file mode 100644 index 0000000000..1886b265a1 --- /dev/null +++ b/pyro/ops/streaming.py @@ -0,0 +1,277 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import copy +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Callable, Dict, Hashable, Union + +import torch + +from pyro.ops.welford import WelfordCovariance + + +class StreamingStats(ABC): + """ + Abstract base class for streamable statistics of trees of tensors. + + Derived classes must implelement :meth:`update`, :meth:`merge`, and + :meth:`get`. + """ + + @abstractmethod + def update(self, sample) -> None: + """ + Update state from a single sample. + + This mutates ``self`` and returns nothing. Updates should be + independent of order, i.e. samples should be exchangeable. + + :param sample: A sample value which is a nested dictionary of + :class:`torch.Tensor` leaves. This can have arbitrary nesting and + shape shape, but assumes shape is constant across calls to + ``.update()``. + """ + raise NotImplementedError + + @abstractmethod + def merge(self, other) -> "StreamingStats": + """ + Select two aggregate statistics, e.g. from different MCMC chains. + + This is a pure function: it returns a new :class:`StreamingStats` + object and does not modify either ``self`` or ``other``. + + :param other: Another streaming stats instance of the same type. + """ + assert isinstance(other, type(self)) + raise NotImplementedError + + @abstractmethod + def get(self) -> Any: + """ + Return the aggregate statistic. + """ + raise NotImplementedError + + +class CountStats(StreamingStats): + """ + Statistic tracking only the number of samples. + + For example:: + + >>> stats = CountStats() + >>> stats.update(torch.randn(3, 3)) + >>> stats.get() + {'count': 1} + """ + + def __init__(self): + self.count = 0 + super().__init__() + + def update(self, sample) -> None: + self.count += 1 + + def merge(self, other: "CountStats") -> "CountStats": + assert isinstance(other, type(self)) + result = CountStats() + result.count = self.count + other.count + return result + + def get(self) -> Dict[str, int]: + """ + :returns: A dictionary with keys ``count: int``. + :rtype: dict + """ + return {"count": self.count} + + +class StatsOfDict(StreamingStats): + """ + Statistics of samples that are dictionaries with constant set of keys. + + For example the following are equivalent:: + + # Version 1. Hand encode statistics. + >>> a_stats = CountStats() + >>> b_stats = CountMeanStats() + >>> a_stats.update(torch.tensor(0.)) + >>> b_stats.update(torch.tensor([1., 2.])) + >>> summary = {"a": a_stats.get(), "b": b_stats.get()} + + # Version 2. Collect samples into dictionaries. + >>> stats = StatsOfDict({"a": CountStats, "b": CountMeanStats}) + >>> stats.update({"a": torch.tensor(0.), "b": torch.tensor([1., 2.])}) + >>> summary = stats.get() + >>> summary + {'a': {'count': 1}, 'b': {'count': 1, 'mean': tensor([1., 2.])}} + + :param default: Default type of statistics of values of the dictionary. + Defaults to the inexpensive :class:`CountStats`. + :param dict types: Dictionary mapping key to type of statistic that should + be recorded for values corresponding to that key. + """ + + def __init__( + self, + types: Dict[Hashable, Callable[[], StreamingStats]] = {}, + default: Callable[[], StreamingStats] = CountStats, + ): + self.stats: Dict[Hashable, StreamingStats] = defaultdict(default) + self.stats.update({k: v() for k, v in types.items()}) + super().__init__() + + def update(self, sample: Dict[Hashable, Any]) -> None: + for k, v in sample.items(): + self.stats[k].update(v) + + def merge(self, other: "StatsOfDict") -> "StatsOfDict": + assert isinstance(other, type(self)) + result = copy.deepcopy(self) + for k in set(self.stats).union(other.stats): + if k not in self.stats: + result.stats[k] = copy.deepcopy(other.stats[k]) + elif k in other.stats: + result.stats[k] = self.stats[k].merge(other.stats[k]) + return result + + def get(self) -> Dict[Hashable, Any]: + """ + :returns: A dictionary of statistics. The keys of this dictionary are + the same as the keys of the samples from which this object is + updated. + :rtype: dict + """ + return {k: v.get() for k, v in self.stats.items()} + + +class StackStats(StreamingStats): + """ + Statistic collecting a stream of tensors into a single stacked tensor. + """ + + def __init__(self): + self.samples = [] + + def update(self, sample: torch.Tensor) -> None: + assert isinstance(sample, torch.Tensor) + self.samples.append(sample) + + def merge(self, other: "StackStats") -> "StackStats": + assert isinstance(other, type(self)) + result = StackStats() + result.samples = self.samples + other.samples + return result + + def get(self) -> Dict[str, Union[int, torch.Tensor]]: + """ + :returns: A dictionary with keys ``count: int`` and (if any samples + have been collected) ``samples: torch.Tensor``. + :rtype: dict + """ + if not self.samples: + return {"count": 0} + return {"count": len(self.samples), "samples": torch.stack(self.samples)} + + +class CountMeanStats(StreamingStats): + """ + Statistic tracking the count and mean of a single :class:`torch.Tensor`. + """ + + def __init__(self): + self.count = 0 + self.mean = 0 + super().__init__() + + def update(self, sample: torch.Tensor) -> None: + assert isinstance(sample, torch.Tensor) + self.count += 1 + self.mean += (sample.detach() - self.mean) / self.count + + def merge(self, other: "CountMeanStats") -> "CountMeanStats": + assert isinstance(other, type(self)) + result = CountMeanStats() + result.count = self.count + other.count + p = self.count / max(result.count, 1) + q = other.count / max(result.count, 1) + result.mean = p * self.mean + q * other.mean + return result + + def get(self) -> Dict[str, Union[int, torch.Tensor]]: + """ + :returns: A dictionary with keys ``count: int`` and (if any samples + have been collected) ``mean: torch.Tensor``. + :rtype: dict + """ + if self.count == 0: + return {"count": 0} + return {"count": self.count, "mean": self.mean} + + +class CountMeanVarianceStats(StreamingStats): + """ + Statistic tracking the count, mean, and (diagonal) variance of a single + :class:`torch.Tensor`. + """ + + def __init__(self): + self.shape = None + self.welford = WelfordCovariance(diagonal=True) + super().__init__() + + def update(self, sample: torch.Tensor) -> None: + assert isinstance(sample, torch.Tensor) + if self.shape is None: + self.shape = sample.shape + assert sample.shape == self.shape + self.welford.update(sample.detach().reshape(-1)) + + def merge(self, other: "CountMeanVarianceStats") -> "CountMeanVarianceStats": + assert isinstance(other, type(self)) + if self.shape is None: + return copy.deepcopy(other) + if other.shape is None: + return copy.deepcopy(self) + result = copy.deepcopy(self) + res = result.welford + lhs = self.welford + rhs = other.welford + res.n_samples = lhs.n_samples + rhs.n_samples + lhs_weight = lhs.n_samples / res.n_samples + rhs_weight = rhs.n_samples / res.n_samples + res._mean = lhs_weight * lhs._mean + rhs_weight * rhs._mean + res._m2 = ( + lhs._m2 + + rhs._m2 + + (lhs.n_samples * rhs.n_samples / res.n_samples) + * (lhs._mean - rhs._mean) ** 2 + ) + return result + + def get(self) -> Dict[str, Union[int, torch.Tensor]]: + """ + :returns: A dictionary with keys ``count: int`` and (if any samples + have been collected) ``mean: torch.Tensor`` and ``variance: + torch.Tensor``. + :rtype: dict + """ + if self.shape is None: + return {"count": 0} + count = self.welford.n_samples + mean = self.welford._mean.reshape(self.shape) + variance = self.welford.get_covariance(regularize=False).reshape(self.shape) + return {"count": count, "mean": mean, "variance": variance} + + +# Note this is ordered logically for sphinx rather than alphabetically. +__all__ = [ + "StreamingStats", + "StatsOfDict", + "StackStats", + "CountStats", + "CountMeanStats", + "CountMeanVarianceStats", +] diff --git a/setup.cfg b/setup.cfg index 9785fe5387..a47927d363 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,7 +62,15 @@ warn_unused_ignores = True ignore_errors = True warn_unused_ignores = True -[mypy-pyro.ops.*] +[mypy-pyro.ops.einsum] +ignore_errors = True +warn_unused_ignores = True + +[mypy-pyro.ops.contract] +ignore_errors = True +warn_unused_ignores = True + +[mypy-pyro.ops.tensor_utils] ignore_errors = True warn_unused_ignores = True diff --git a/tests/common.py b/tests/common.py index da40d8c230..481820c9db 100644 --- a/tests/common.py +++ b/tests/common.py @@ -225,7 +225,7 @@ def assert_close(actual, expected, atol=1e-7, rtol=0, msg=''): assert set(actual.keys()) == set(expected.keys()) for key, x_val in actual.items(): assert_close(x_val, expected[key], atol=atol, rtol=rtol, - msg='At key{}: {} vs {}'.format(key, x_val, expected[key])) + msg='At key {}: {} vs {}'.format(repr(key), x_val, expected[key])) elif isinstance(actual, str): assert actual == expected, msg elif is_iterable(actual) and is_iterable(expected): diff --git a/tests/ops/test_streaming.py b/tests/ops/test_streaming.py new file mode 100644 index 0000000000..500776e77a --- /dev/null +++ b/tests/ops/test_streaming.py @@ -0,0 +1,109 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import functools + +import pytest +import torch + +from pyro.ops.streaming import ( + CountMeanStats, + CountMeanVarianceStats, + CountStats, + StackStats, + StatsOfDict, +) +from tests.common import assert_close + + +def generate_data(num_samples): + shapes = {"aaa": (), "bbb": (4,), "ccc": (3, 2), "ddd": (5, 1)} + return [{k: torch.randn(v) for k, v in shapes.items()} for _ in range(num_samples)] + + +EXAMPLE_STATS = [ + CountStats, + functools.partial(StatsOfDict, default=CountMeanStats), + functools.partial(StatsOfDict, default=CountMeanVarianceStats), + functools.partial(StatsOfDict, default=StackStats), + StatsOfDict, + functools.partial( + StatsOfDict, {"aaa": CountMeanStats, "bbb": CountMeanVarianceStats} + ), +] +EXAMPLE_STATS_IDS = [ + "CountStats", + "CountMeanStats", + "CountMeanVarianceStats", + "StackStats", + "StatsOfDict1", + "StatsOfDict2", +] + + +def sort_samples_in_place(x): + for key, value in list(x.items()): + if isinstance(key, str) and key == "samples": + x[key] = value.sort(0).values + elif isinstance(value, dict): + sort_samples_in_place(value) + + +@pytest.mark.parametrize("size", [0, 10]) +@pytest.mark.parametrize("make_stats", EXAMPLE_STATS, ids=EXAMPLE_STATS_IDS) +def test_update_get(make_stats, size): + samples = generate_data(size) + + expected_stats = make_stats() + for sample in samples: + expected_stats.update(sample) + expected = expected_stats.get() + + actual_stats = make_stats() + for i in torch.randperm(len(samples)).tolist(): + actual_stats.update(samples[i]) + actual = actual_stats.get() + + # Sort samples in case of StackStats. + sort_samples_in_place(expected) + sort_samples_in_place(actual) + + assert_close(actual, expected) + + +@pytest.mark.parametrize("left_size, right_size", [(3, 5), (0, 8), (8, 0), (0, 0)]) +@pytest.mark.parametrize("make_stats", EXAMPLE_STATS, ids=EXAMPLE_STATS_IDS) +def test_update_merge_get(make_stats, left_size, right_size): + left_samples = generate_data(left_size) + right_samples = generate_data(right_size) + + expected_stats = make_stats() + for sample in left_samples + right_samples: + expected_stats.update(sample) + expected = expected_stats.get() + + left_stats = make_stats() + for sample in left_samples: + left_stats.update(sample) + right_stats = make_stats() + for sample in right_samples: + right_stats.update(sample) + actual_stats = left_stats.merge(right_stats) + assert isinstance(actual_stats, type(expected_stats)) + + actual = actual_stats.get() + assert_close(actual, expected) + + +def test_stats_of_dict(): + stats = StatsOfDict(types={"aaa": CountMeanStats}, default=CountStats) + stats.update({"aaa": torch.tensor(0.0)}) + stats.update({"aaa": torch.tensor(1.0), "bbb": torch.randn(3, 3)}) + stats.update({"aaa": torch.tensor(2.0), "bbb": torch.randn(3, 3)}) + actual = stats.get() + + expected = { + "aaa": {"count": 3, "mean": torch.tensor(1.0)}, + "bbb": {"count": 2}, + } + assert_close(actual, expected)