Skip to content

Commit

Permalink
Add type hints to the primitives module (pyro-ppl#1940)
Browse files Browse the repository at this point in the history
* initial type hints

* more types

* all basic hints

* fix hints

* use Union

* use as array

* distribution like type

* fix import

* add tests for the distributions type

* feedback

* fix __call__ types

* typo  #shame

* fix test

* space

* message type

* rename message type

* remove unnecessary casting

* improve import

* remove asarray
  • Loading branch information
juanitorduz authored Dec 23, 2024
1 parent 8e1d9b2 commit e71aa62
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 64 deletions.
16 changes: 6 additions & 10 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,22 @@
import numpyro.distributions as dist


def _non_centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)


def _centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

return phi @ beta


def linear_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int], non_centered: bool = True
phi: Array, spd: Array, m: int, non_centered: bool = True
) -> Array:
"""
Linear approximation formula of the Hilbert space Gaussian process.
Expand All @@ -52,10 +48,10 @@ def linear_approximation(
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
:param ArrayLike phi: laplacian eigenfunctions
:param ArrayLike spd: square root of the diagonal of the spectral density evaluated at square
:param Array phi: laplacian eigenfunctions
:param Array spd: square root of the diagonal of the spectral density evaluated at square
root of the first `m` eigenvalues.
:param int | list[int] m: number of eigenfunctions in the approximation
:param int m: number of eigenfunctions in the approximation
:param bool non_centered: whether to use a non-centered parameterization
:return: The low-rank approximation linear model
:rtype: Array
Expand Down
19 changes: 10 additions & 9 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Union

import numpy as np
from numpy.typing import NDArray

import jax
from jax import device_get
Expand All @@ -25,7 +26,7 @@
]


def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def _compute_chain_variance_stats(x: NDArray) -> tuple[NDArray, NDArray]:
# compute within-chain variance and variance estimator
# input has shape C x N x sample_shape
C, N = x.shape[:2]
Expand All @@ -41,7 +42,7 @@ def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray
return var_within, var_estimator


def gelman_rubin(x: np.ndarray) -> np.ndarray:
def gelman_rubin(x: NDArray) -> NDArray:
"""
Computes R-hat over chains of samples ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand All @@ -60,7 +61,7 @@ def gelman_rubin(x: np.ndarray) -> np.ndarray:
return rhat


def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
def split_gelman_rubin(x: NDArray) -> NDArray:
"""
Computes split R-hat over chains of samples ``x``, where the first dimension
of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -97,7 +98,7 @@ def _fft_next_fast_len(target: int) -> int:
target += 1


def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocorrelation(x: NDArray, axis: int = 0, bias: bool = True) -> NDArray:
"""
Computes the autocorrelation of samples at dimension ``axis``.
Expand Down Expand Up @@ -137,11 +138,11 @@ def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarr
autocorr = autocorr / np.arange(N, 0.0, -1)

with np.errstate(invalid="ignore", divide="ignore"):
autocorr = autocorr / autocorr[..., :1]
autocorr = (autocorr / autocorr[..., :1]).astype(np.float64)
return np.swapaxes(autocorr, axis, -1)


def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocovariance(x: NDArray, axis: int = 0, bias: bool = True) -> NDArray:
"""
Computes the autocovariance of samples at dimension ``axis``.
Expand All @@ -154,7 +155,7 @@ def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarra
return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)


def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
def effective_sample_size(x: NDArray, bias: bool = True) -> NDArray:
"""
Computes effective sample size of input ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -202,7 +203,7 @@ def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
return n_eff


def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
def hpdi(x: NDArray, prob: float = 0.90, axis: int = 0) -> NDArray:
"""
Computes "highest posterior density interval" (HPDI) which is the narrowest
interval with probability mass ``prob``.
Expand Down Expand Up @@ -285,7 +286,7 @@ def summary(


def print_summary(
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, NDArray], prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
Expand Down
62 changes: 50 additions & 12 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from collections import OrderedDict
from contextlib import contextmanager
import functools
import inspect
from typing import Any, Protocol, runtime_checkable
import warnings

import numpy as np
Expand All @@ -37,6 +37,7 @@
from jax import lax, tree_util
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform
from numpyro.distributions.util import (
Expand Down Expand Up @@ -270,7 +271,7 @@ def validate_args(self, strict: bool = True) -> None:
raise RuntimeError("Cannot validate arguments inside jitted code.")

@property
def batch_shape(self):
def batch_shape(self) -> tuple[int, ...]:
"""
Returns the shape over which the distribution parameters are batched.
Expand All @@ -280,7 +281,7 @@ def batch_shape(self):
return self._batch_shape

@property
def event_shape(self):
def event_shape(self) -> tuple[int, ...]:
"""
Returns the shape of a single sample from the distribution without
batching.
Expand All @@ -291,24 +292,24 @@ def event_shape(self):
return self._event_shape

@property
def event_dim(self):
def event_dim(self) -> int:
"""
:return: Number of dimensions of individual events.
:rtype: int
"""
return len(self.event_shape)

@property
def has_rsample(self):
def has_rsample(self) -> bool:
return set(self.reparametrized_params) == set(self.arg_constraints)

def rsample(self, key, sample_shape=()):
def rsample(self, key, sample_shape=()) -> ArrayLike:
if self.has_rsample:
return self.sample(key, sample_shape=sample_shape)

raise NotImplementedError

def shape(self, sample_shape=()):
def shape(self, sample_shape=()) -> tuple[int, ...]:
"""
The tensor shape of samples from this distribution.
Expand All @@ -323,7 +324,7 @@ def shape(self, sample_shape=()):
"""
return sample_shape + self.batch_shape + self.event_shape

def sample(self, key, sample_shape=()):
def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
"""
Returns a sample from the distribution having shape given by
`sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
Expand Down Expand Up @@ -361,14 +362,14 @@ def log_prob(self, value):
raise NotImplementedError

@property
def mean(self):
def mean(self) -> ArrayLike:
"""
Mean of the distribution.
"""
raise NotImplementedError

@property
def variance(self):
def variance(self) -> ArrayLike:
"""
Variance of the distribution.
"""
Expand Down Expand Up @@ -540,7 +541,7 @@ def infer_shapes(cls, *args, **kwargs):
event_shape = ()
return batch_shape, event_shape

def cdf(self, value):
def cdf(self, value: ArrayLike) -> ArrayLike:
"""
The cumulative distribution function of this distribution.
Expand All @@ -549,7 +550,7 @@ def cdf(self, value):
"""
raise NotImplementedError

def icdf(self, q):
def icdf(self, q: ArrayLike) -> ArrayLike:
"""
The inverse cumulative distribution function of this distribution.
Expand All @@ -563,6 +564,43 @@ def is_discrete(self):
return self.support.is_discrete


@runtime_checkable
class DistributionLike(Protocol):
    """A protocol for typing distributions.

    Used to type objects of type numpyro.distributions.Distribution, funsor.Funsor
    or tensorflow_probability.distributions.Distribution.

    The protocol is decorated with ``runtime_checkable``, so it can be used in
    ``isinstance`` checks; such checks only test that the members declared
    below are present, not their signatures or return types.
    """

    # Unlike the other members this is not a bare ``...`` stub: it forwards the
    # call to the next ``__call__`` found through ``super()``.
    # NOTE(review): presumably declared so that callable distribution-like
    # objects match the protocol -- confirm the intent of delegating to super().
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return super().__call__(*args, **kwargs)

    # Shape over which the distribution parameters are batched
    # (mirrors ``Distribution.batch_shape``).
    @property
    def batch_shape(self) -> tuple[int, ...]: ...

    # Shape of a single sample from the distribution without batching
    # (mirrors ``Distribution.event_shape``).
    @property
    def event_shape(self) -> tuple[int, ...]: ...

    # Number of dimensions of individual events
    # (mirrors ``Distribution.event_dim``).
    @property
    def event_dim(self) -> int: ...

    # Draws a sample using ``key``; ``Distribution.sample`` documents the
    # result shape as ``sample_shape + batch_shape + event_shape``.
    def sample(
        self, key: ArrayLike, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike: ...

    # ``log_prob`` evaluated at ``value``; the protocol only fixes the
    # array-like input and output types.
    def log_prob(self, value: ArrayLike) -> ArrayLike: ...

    # Mean of the distribution.
    @property
    def mean(self) -> ArrayLike: ...

    # Variance of the distribution.
    @property
    def variance(self) -> ArrayLike: ...

    # Cumulative distribution function evaluated at ``value``.
    def cdf(self, value: ArrayLike) -> ArrayLike: ...

    # Inverse cumulative distribution function evaluated at ``q``.
    def icdf(self, q: ArrayLike) -> ArrayLike: ...


class ExpandedDistribution(Distribution):
arg_constraints = {}
pytree_data_fields = ("base_dist",)
Expand Down
Loading

0 comments on commit e71aa62

Please sign in to comment.