From 902623c62f5ee0a385f16a4303693f3803c5050d Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Sun, 30 Jul 2023 05:42:23 +0100
Subject: [PATCH] `vmap`-able `Distribution`s (#1529)

* Make `Normal` objects compatible with `jax.vmap`
* Sort imports
* Make `MultivariateNormal` compatible with `vmap`
* remove unused variables
* Pass `in_vmap` to base class constructor of mvnormal dists
* fix issue with docs formatting
* test vmapping multivariate normal twice
* temporarily disable failing test
* lint
* don't use `__init__` to restore state post-vmapping
* adapt numpyro control flows to new batch_shape handling
* fixup! adapt numpyro control flows to new batch_shape handling
* Do not mutate shapes of ExpandedDistribution for map-free ops
* improve post-scan batch_shape updating
* re-enable disabled test
* [WIP] vmap tests for arbitrary distributions
* WIP
* vmappable continuous distributions
* Fix incorrect unflattening of inverse transforms
* mark CAR's adj_matrix as auxiliary if sparse
* vmap support for discrete distributions
* add missing license header
* batch shape promotion for Bernoulli/Categorical Probs
* More distribution-specific batch-shape promotion logic
* linting
* fixup! linting
* fixup! linting
* implement `vmap` support for remaining distributions
* reference sparse warning using correct namespace
* uniformize tree_flatten/unflatten method across dists
* fixup! uniformize tree_flatten/unflatten method across dists
* remove normal-specific vmap tests
* decentralize batch_shape promotion
* fixup! uniformize tree_flatten/unflatten method across dists
* fixup! uniformize tree_flatten/unflatten method across dists
* fixup! uniformize tree_flatten/unflatten method across dists
* move `vmap_over` into a dedicated util module
* minor cosmetic changes in `vmap_util`
* [WIP] clarify test_vmap_dist
* Finish clarifying test_vmap_dist
* remove spurious deleted/added newlines
* minor improvements in tests
* use arg_constraints for batch shape promotion
* move batch shape promotion to vmap_util.py
* vmap_util.py -> batch_util.py
* have pytree_data_fields default to arg_constraints.keys()
* revert some unrelated modifications
---
 numpyro/contrib/control_flow/scan.py  |   3 +
 numpyro/distributions/batch_util.py   | 577 ++++++++++++++++++++++++++
 numpyro/distributions/conjugate.py    |   5 +
 numpyro/distributions/continuous.py   | 134 ++----
 numpyro/distributions/copula.py       |  30 +-
 numpyro/distributions/directional.py  |   1 +
 numpyro/distributions/discrete.py     |   8 +
 numpyro/distributions/distribution.py | 183 ++++--
 numpyro/distributions/mixtures.py     |  56 +--
 numpyro/distributions/truncated.py    |  84 +---
 test/test_distributions.py            | 334 ++++++++++++---
 11 files changed, 1026 insertions(+), 389 deletions(-)
 create mode 100644 numpyro/distributions/batch_util.py

diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py
index f0148d263..b04c51862 100644
--- a/numpyro/contrib/control_flow/scan.py
+++ b/numpyro/contrib/control_flow/scan.py
@@ -9,6 +9,7 @@
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten

 from numpyro import handlers
+from numpyro.distributions.batch_util import promote_batch_shape
 from numpyro.ops.pytree import PytreeTrace
 from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack
 from numpyro.util import not_jax_tracer
@@ -220,6 +221,7 @@ def body_fn(wrapped_carry, x, prefix=None):
         if first_var is None:
             first_var = name

+        site["fn"] = promote_batch_shape(site["fn"])
        # we haven't promoted shapes of values yet during `lax.scan`, so we do it here
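+        # (`promote_batch_shape` recomputes the distribution's batch_shape from the
+        # shapes of its parameters, which gained a leading time dimension inside
+        # `lax.scan`; as a hypothetical example, a Normal whose scalar loc was
+        # scanned over T steps is promoted from batch_shape () to (T,) here.)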
site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"]) @@ -308,6 +310,7 @@ def body_fn(wrapped_carry, x): for name, site in pytree_trace.trace.items(): if site["type"] != "sample": continue + site["fn"] = promote_batch_shape(site["fn"]) # we haven't promote shapes of values yet during `lax.scan`, so we do it here site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"]) return last_carry, (pytree_trace, ys) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py new file mode 100644 index 000000000..cc13dbf40 --- /dev/null +++ b/numpyro/distributions/batch_util.py @@ -0,0 +1,577 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import copy +from functools import singledispatch +from typing import Union + +from jax import tree_map +import jax.numpy as jnp + +from numpyro.distributions import constraints +from numpyro.distributions.conjugate import ( + BetaBinomial, + DirichletMultinomial, + GammaPoisson, + NegativeBinomial2, + NegativeBinomialLogits, + NegativeBinomialProbs, +) +from numpyro.distributions.constraints import Constraint +from numpyro.distributions.continuous import ( + CAR, + LKJ, + AsymmetricLaplaceQuantile, + Beta, + BetaProportion, + Chi2, + Gamma, + HalfCauchy, + HalfNormal, + InverseGamma, + Kumaraswamy, + LKJCholesky, + LogNormal, + LogUniform, + LowRankMultivariateNormal, + MultivariateStudentT, + Pareto, + RelaxedBernoulliLogits, + StudentT, + Uniform, +) +from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta +from numpyro.distributions.discrete import ( + CategoricalProbs, + DiscreteUniform, + OrderedLogistic, + ZeroInflatedPoisson, + ZeroInflatedProbs, +) +from numpyro.distributions.distribution import ( + Distribution, + ExpandedDistribution, + MaskedDistribution, + Unit, +) +from numpyro.distributions.transforms import ( + AffineTransform, + CorrCholeskyTransform, + PowerTransform, + Transform, +) +from numpyro.distributions.truncated import ( + LeftTruncatedDistribution, + RightTruncatedDistribution, + TwoSidedTruncatedDistribution, +) + + +@singledispatch +def vmap_over(d: Union[Distribution, Transform, Constraint], **kwargs): + raise NotImplementedError + + +@vmap_over.register +def _vmap_over_affine_transform( + dist: AffineTransform, loc=None, scale=None, domain=None +): + dist_axes = copy.copy(dist) + dist_axes.loc = loc + dist_axes.scale = scale + dist_axes.domain = domain + return dist_axes + + +@vmap_over.register +def _vmap_over_greater_than(dist: constraints._GreaterThan, lower_bound=None): + axes = copy.copy(dist) + axes.lower_bound = lower_bound + return axes + + +@vmap_over.register +def _vmap_over_less_than(dist: constraints._LessThan, upper_bound=None): + axes = copy.copy(dist) + axes.upper_bound = upper_bound + return axes + + +@vmap_over.register +def _vmap_over_interval( + dist: constraints._Interval, lower_bound=None, upper_bound=None +): + axes = copy.copy(dist) + axes.lower_bound = lower_bound + axes.upper_bound = upper_bound + return axes + + +@vmap_over.register +def _vmap_over_integer_interval( + dist: constraints._IntegerInterval, lower_bound=None, upper_bound=None +): + dist_axes = copy.copy(dist) + dist_axes.lower_bound = lower_bound + dist_axes.upper_bound = upper_bound + return dist_axes + + +@vmap_over.register +def _vmap_over_corr_cholesky_transform(dist: CorrCholeskyTransform): + return None + + +@vmap_over.register +def _vmap_over_power_transform(dist: PowerTransform, exponent=None): + axes = 
copy.copy(dist) + axes.exponent = exponent + return axes + + +@vmap_over.register +def _default_vmap_over(d: Distribution, **kwargs): + pytree_fields = type(d).gather_pytree_data_fields() + dist_axes = copy.copy(d) + + for f in pytree_fields: + setattr(dist_axes, f, kwargs.get(f, None)) + + return dist_axes + + +@vmap_over.register +def _(dist: AsymmetricLaplaceQuantile, loc=None, scale=None, quantile=None): + dist_axes = _default_vmap_over( + dist, + loc=loc, + scale=scale, + quantile=quantile, + _ald=vmap_over( + dist._ald, + loc=loc, + scale=scale if scale is not None else quantile, + asymmetry=quantile, + ), + ) + return dist_axes + + +@vmap_over.register +def _vmap_over_beta(dist: Beta, concentration1=None, concentration0=None): + dist_axes = _default_vmap_over( + dist, concentration1=concentration1, concentration0=concentration0 + ) + if concentration1 is not None or concentration0 is not None: + dist_axes._dirichlet = 0 + else: + dist_axes._dirichlet = None + return dist_axes + + +@vmap_over.register +def _vmap_over_beta_proportion(dist: BetaProportion, mean=None, concentration=None): + dist_axes = vmap_over.dispatch(Beta)( + dist, + concentration1=concentration if concentration is not None else mean, + concentration0=concentration if concentration is not None else mean, + ) + dist_axes.concentration = concentration + return dist_axes + + +@vmap_over.register +def _vmap_over_chi2(dist: Chi2, df=None): + dist_axes = vmap_over.dispatch(Gamma)(dist, rate=df, concentration=df) + dist_axes.df = df + return dist_axes + + +@vmap_over.register +def _vmap_over_gaussian_copula( + dist: GaussianCopula, + marginal_dist=None, + correlation_matrix=None, + correlation_cholesky=None, +): + dist_axes = _default_vmap_over( + dist, marginal_dist=marginal_dist, correlation_matrix=correlation_matrix + ) + dist_axes.base_dist = vmap_over( + dist.base_dist, + loc=correlation_matrix if correlation_matrix == 0 else correlation_cholesky, + scale_tril=correlation_matrix + if correlation_matrix == 0 + else correlation_cholesky, + covariance_matrix=correlation_matrix, + ) + return dist_axes + + +@vmap_over.register +def _vmap_over_gausian_copula_beta( + dist: GaussianCopulaBeta, + concentration1=None, + concentration0=None, + correlation_matrix=None, + correlation_cholesky=None, +): + d = vmap_over.dispatch(GaussianCopula)( + dist, + vmap_over( + dist.marginal_dist, + concentration1=concentration1, + concentration0=concentration0, + ), + correlation_matrix=correlation_matrix, + correlation_cholesky=correlation_cholesky, + ) + d.concentration1 = concentration1 + d.concentration0 = concentration0 + return d + + +@vmap_over.register +def _vmap_over_half_cauchy(dist: HalfCauchy, scale=None): + dist_axes = _default_vmap_over(dist, scale=scale) + dist_axes._cauchy = vmap_over(dist._cauchy, loc=scale, scale=scale) + return dist_axes + + +@vmap_over.register +def _vmap_over_inverse_gamma(dist: InverseGamma, concentration=None, rate=None): + dist_axes = _default_vmap_over(dist, concentration=concentration, rate=rate) + dist_axes.base_dist = vmap_over( + dist.base_dist, concentration=concentration, rate=rate + ) + dist_axes.transforms = None + return dist_axes + + +@vmap_over.register +def _vmap_over_uniform(dist: Uniform, low=None, high=None): + dist_axes = _default_vmap_over(dist, low=low, high=high) + dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high) + return dist_axes + + +@vmap_over.register +def _vmap_over_kumaraswamy(dist: Kumaraswamy, concentration0=None, concentration1=None): + 
dist_axes = _default_vmap_over( + dist, concentration0=concentration0, concentration1=concentration1 + ) + if isinstance(dist.base_dist, Uniform): + dist_axes.base_dist = vmap_over(dist.base_dist, low=None, high=None) + else: + assert isinstance(dist.base_dist, ExpandedDistribution) + dist_axes.base_dist = vmap_over(dist.base_dist, base_dist=None) + + dist_axes.transforms = [ + vmap_over(dist.transforms[0], exponent=concentration0), + vmap_over(dist.transforms[1], loc=None, scale=None), + vmap_over(dist.transforms[2], exponent=concentration1), + ] + return dist_axes + + +@vmap_over.register +def _vmap_over_lkj(dist: LKJ, concentration=None): + dist_axes = _default_vmap_over(dist, concentration=concentration) + dist_axes.base_dist = vmap_over(dist.base_dist, concentration=concentration) + dist_axes.transforms = None + return dist_axes + + +@vmap_over.register +def _vmap_over_lkj_cholesky(dist: LKJCholesky, concentration): + dist_axes = _default_vmap_over(dist, concentration=concentration) + if dist_axes.sample_method == "onion": + dist_axes._beta = vmap_over( + dist._beta, concentration1=None, concentration0=concentration + ) + elif dist_axes.sample_method == "cvine": + dist_axes._beta = vmap_over( + dist._beta, concentration1=concentration, concentration0=concentration + ) + return dist_axes + + +@vmap_over.register +def _vmap_over_lognormal(dist: LogNormal, loc=None, scale=None): + dist_axes = _default_vmap_over(dist, loc=loc, scale=scale) + dist_axes.transforms = None + dist_axes.base_dist = vmap_over(dist.base_dist, loc=loc, scale=scale) + return dist_axes + + +@vmap_over.register +def _vmap_over_loguniform(dist: LogUniform, low=None, high=None): + dist_axes = _default_vmap_over(dist, low=low, high=high) + dist_axes.base_dist = vmap_over(dist.base_dist, low=low, high=high) + dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high) + return dist_axes + + +@vmap_over.register +def _vmap_over_car( + dist: CAR, loc=None, correlation=None, conditional_precision=None, adj_matrix=None +): + dist_axes = _default_vmap_over( + dist, + loc=loc, + correlation=correlation, + conditional_precision=conditional_precision, + ) + if not dist.is_sparse: + dist_axes.adj_matrix = adj_matrix + dist_axes.precision_matrix = adj_matrix + else: + assert adj_matrix is None + return dist_axes + + +@vmap_over.register +def _vmap_over_multivariate_student_t( + dist: MultivariateStudentT, df=None, loc=None, scale_tril=None +): + dist_axes = _default_vmap_over(dist, df=df, loc=loc, scale_tril=scale_tril) + dist_axes._chi2 = vmap_over(dist._chi2, df=df) + return dist_axes + + +@vmap_over.register +def _vmap_over_low_rank_multivariate_normal( + dist: LowRankMultivariateNormal, loc=None, cov_factor=None, cov_diag=None +): + dist_axes = _default_vmap_over( + dist, loc=loc, cov_factor=cov_factor, cov_diag=cov_diag + ) + dist_axes._capacitance_tril = cov_diag if cov_diag is not None else cov_factor + return dist_axes + + +@vmap_over.register +def _vmap_over_pareto(dist: Pareto, scale=None, alpha=None): + dist_axes = _default_vmap_over(dist, scale=scale, alpha=alpha) + dist_axes.base_dist = vmap_over(dist.base_dist, rate=alpha) + dist_axes.transforms = [None, vmap_over(dist.transforms[1], loc=None, scale=scale)] + return dist_axes + + +@vmap_over.register +def _vmap_over_relaxed_bernoulli_logits( + dist: RelaxedBernoulliLogits, temperature=None, logits=None +): + dist_axes = _default_vmap_over(dist, temperature=temperature, logits=logits) + dist_axes.transforms = None + dist_axes.base_dist = 
vmap_over( + dist.base_dist, + loc=logits if logits is not None else temperature, + scale=temperature, + ) + return dist_axes + + +@vmap_over.register +def _vmap_over_student_t(dist: StudentT, df=None, loc=None, scale=None): + dist_axes = _default_vmap_over(dist, df=df, loc=loc, scale=scale) + dist_axes._chi2 = vmap_over(dist._chi2, df=df) + return dist_axes + + +@vmap_over.register +def _vmap_over_two_sided_truncated_distribution( + dist: TwoSidedTruncatedDistribution, low=None, high=None +): + dist_axes = _default_vmap_over(dist, low=low, high=high) + dist_axes.base_dist = None + dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high) + return dist_axes + + +@vmap_over.register +def _vmap_over_left_truncated_distribution(dist: LeftTruncatedDistribution, low=None): + dist_axes = _default_vmap_over(dist, low=low) + dist_axes.base_dist = None + dist_axes._support = vmap_over(dist._support, lower_bound=low) + return dist_axes + + +@vmap_over.register +def _vmap_over_right_truncated_distribution( + dist: RightTruncatedDistribution, high=None +): + dist_axes = _default_vmap_over(dist, high=high) + dist_axes.base_dist = None + dist_axes._support = vmap_over(dist._support, upper_bound=high) + return dist_axes + + +@vmap_over.register +def _vmap_over_beta_binomial( + dist: BetaBinomial, concentration1=None, concentration0=None, total_count=None +): + dist_axes = _default_vmap_over( + dist, + concentration1=concentration1, + concentration0=concentration0, + total_count=total_count, + ) + dist_axes._beta = vmap_over( + dist._beta, concentration1=concentration1, concentration0=concentration0 + ) + return dist_axes + + +@vmap_over.register +def _vmap_over_dirichlet_multinomial(dist: DirichletMultinomial, concentration=None): + dist_axes = _default_vmap_over(dist, concentration=concentration) + dist_axes._dirichlet = vmap_over(dist._dirichlet, concentration=concentration) + return dist_axes + + +@vmap_over.register +def _vmap_over_gamma_poisson(dist: GammaPoisson, concentration=None, rate=None): + dist_axes = _default_vmap_over(dist, concentration=concentration, rate=rate) + dist_axes._gamma = vmap_over(dist._gamma, concentration=concentration, rate=rate) + return dist_axes + + +@vmap_over.register +def _vmap_over_negative_binomial_probs( + dist: NegativeBinomialProbs, total_count=None, probs=None +): + dist_axes = vmap_over.dispatch(GammaPoisson)( + dist, concentration=total_count, rate=probs + ) + dist_axes.total_count = total_count + dist_axes.probs = probs + return dist_axes + + +@vmap_over.register +def _vmap_over_negative_binomial_logits( + dist: NegativeBinomialLogits, total_count=None, logits=None +): + dist_axes = vmap_over.dispatch(GammaPoisson)( + dist, concentration=total_count, rate=logits + ) + dist_axes.total_count = total_count + dist_axes.logits = logits + return dist_axes + + +@vmap_over.register +def _vmap_over_negative_binomial_2( + dist: NegativeBinomial2, mean=None, concentration=None +): + return vmap_over.dispatch(GammaPoisson)( + dist, + concentration=concentration, + rate=concentration if concentration is not None else mean, + ) + + +@vmap_over.register +def _vmap_over_ordered_logistic(dist: OrderedLogistic, predictor=None, cutpoints=None): + dist_axes = vmap_over.dispatch(CategoricalProbs)( + dist, probs=predictor if predictor is not None else cutpoints + ) + dist_axes.predictor = predictor + dist_axes.cutpoints = cutpoints + return dist_axes + + +@vmap_over.register +def _vmap_over_discrete_uniform(dist: DiscreteUniform, low=None, high=None): + 
dist_axes = _default_vmap_over(dist, low=low, high=high)
+    dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high)
+    return dist_axes
+
+
+@vmap_over.register
+def _vmap_over_zero_inflated_poisson(dist: ZeroInflatedPoisson, gate=None, rate=None):
+    dist_axes = vmap_over.dispatch(ZeroInflatedProbs)(
+        dist, base_dist=vmap_over(dist.base_dist, rate=rate), gate=gate
+    )
+    dist_axes.rate = rate
+    return dist_axes
+
+
+@vmap_over.register
+def _vmap_over_half_normal(dist: HalfNormal, scale=None):
+    dist_axes = _default_vmap_over(dist, scale=scale)
+    dist_axes._normal = vmap_over(dist._normal, loc=scale, scale=scale)
+    return dist_axes
+
+
+@singledispatch
+def promote_batch_shape(d: Distribution):
+    raise NotImplementedError
+
+
+@promote_batch_shape.register
+def _default_promote_batch_shape(d: Distribution):
+    attr_name = list(d.arg_constraints.keys())[0]
+    attr_event_dim = d.arg_constraints[attr_name].event_dim
+    attr = getattr(d, attr_name)
+    resolved_batch_shape = attr.shape[
+        : max(0, attr.ndim - d.event_dim - attr_event_dim)
+    ]
+    new_self = copy.deepcopy(d)
+    new_self._batch_shape = resolved_batch_shape
+    return new_self
+
+
+@promote_batch_shape.register
+def _promote_batch_shape_expanded(d: ExpandedDistribution):
+    orig_delta_batch_shape = d.batch_shape[
+        : len(d.batch_shape) - len(d.base_dist.batch_shape)
+    ]
+
+    new_self = copy.deepcopy(d)
+
+    # new dimensions coming from a vmap or numpyro scan/enum operation
+    promoted_base_dist = promote_batch_shape(new_self.base_dist)
+    new_shapes_elems = promoted_base_dist.batch_shape[
+        : len(promoted_base_dist.batch_shape) - len(d.base_dist.batch_shape)
+    ]
+
+    # The new dimensions are appended in front of the previous ExpandedDistribution
+    # batch dimensions. However, these batch dimensions are now present in
+    # the base distribution. Thus the dimensions present in the original
+    # ExpandedDistribution batch_shape, but not in the original base distribution
+    # batch_shape, are now intermediate dimensions: to maintain broadcastability,
+    # the attributes of the base distribution are expanded with these intermediate
+    # dimensions.
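+    # A hypothetical walk-through of the bookkeeping below: take a Normal with
+    # batch_shape (3,), expanded to (7, 3), then vmapped over a new leading axis
+    # of size 5. The promoted base distribution reports batch_shape (5, 3), so
+    # new_shapes_elems == (5,) and orig_delta_batch_shape == (7,): the code below
+    # sets the ExpandedDistribution batch_shape to (5, 7, 3) and reshapes the
+    # base distribution's arrays to batch_shape (5, 1, 3), keeping them
+    # broadcastable against samples of shape (5, 7, 3).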
+ new_self._batch_shape = (*new_shapes_elems, *d.batch_shape) + + new_self.base_dist._batch_shape = ( + *new_shapes_elems, + *tuple(1 for _ in orig_delta_batch_shape), + *d.base_dist.batch_shape, + ) + new_axes_locs = range( + len(new_shapes_elems), + len(new_shapes_elems) + len(orig_delta_batch_shape), + ) + new_base_dist = tree_map( + lambda x: jnp.expand_dims(x, axis=new_axes_locs), new_self.base_dist + ) + + new_self.base_dist = new_base_dist + return new_self + + +@promote_batch_shape.register +def _promote_batch_shape_masked(d: MaskedDistribution): + new_self = copy.copy(d) + new_base_dist = promote_batch_shape(d.base_dist) + new_self._batch_shape = new_base_dist.batch_shape + new_self.base_dist = new_base_dist + return new_self + + +@promote_batch_shape.register +def _promote_batch_shape_unit(d: Unit): + return d diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index d9eacac36..5b4233fb0 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -42,6 +42,7 @@ class BetaBinomial(Distribution): } has_enumerate_support = True enumerate_support = BinomialProbs.enumerate_support + pytree_data_fields = ("concentration1", "concentration0", "total_count", "_beta") def __init__( self, concentration1, concentration0, total_count=1, *, validate_args=None @@ -109,6 +110,8 @@ class DirichletMultinomial(Distribution): "concentration": constraints.independent(constraints.positive, 1), "total_count": constraints.nonnegative_integer, } + pytree_data_fields = ("concentration", "_dirichlet") + pytree_aux_fields = ("total_count",) def __init__(self, concentration, total_count=1, *, validate_args=None): if jnp.ndim(concentration) < 1: @@ -183,6 +186,7 @@ class GammaPoisson(Distribution): "rate": constraints.positive, } support = constraints.nonnegative_integer + pytree_data_fields = ("concentration", "rate", "_gamma") def __init__(self, concentration, rate=1.0, *, validate_args=None): self.concentration, self.rate = promote_shapes(concentration, rate) @@ -275,6 +279,7 @@ class NegativeBinomial2(GammaPoisson): "concentration": constraints.positive, } support = constraints.nonnegative_integer + pytree_data_fields = ("concentration",) def __init__(self, mean, concentration, *, validate_args=None): rate = concentration / mean diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index c7e20c97a..a0ea4349f 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -157,6 +157,7 @@ class Beta(Distribution): } reparametrized_params = ["concentration1", "concentration0"] support = constraints.unit_interval + pytree_data_fields = ("concentration0", "concentration1", "_dirichlet") def __init__(self, concentration1, concentration0, *, validate_args=None): self.concentration1, self.concentration0 = promote_shapes( @@ -308,6 +309,8 @@ class EulerMaruyama(Distribution): """ arg_constraints = {"t": constraints.ordered_vector} + pytree_data_fields = ("t", "init_dist") + pytree_aux_fields = ("sde_fn",) def __init__(self, t, sde_fn, init_dist, *, validate_args=None): self.t = t @@ -515,6 +518,7 @@ class GaussianRandomWalk(Distribution): arg_constraints = {"scale": constraints.positive} support = constraints.real_vector reparametrized_params = ["scale"] + pytree_aux_fields = ("num_steps",) def __init__(self, scale=1.0, num_steps=1, *, validate_args=None): assert ( @@ -551,18 +555,12 @@ def variance(self): self.batch_shape + self.event_shape, ) - def tree_flatten(self): - return 
(self.scale,), self.num_steps - - @classmethod - def tree_unflatten(cls, aux_data, params): - return cls(*params, num_steps=aux_data) - class HalfCauchy(Distribution): reparametrized_params = ["scale"] support = constraints.positive arg_constraints = {"scale": constraints.positive} + pytree_data_fields = ("_cauchy", "scale") def __init__(self, scale=1.0, *, validate_args=None): self._cauchy = Cauchy(0.0, scale) @@ -598,6 +596,7 @@ class HalfNormal(Distribution): reparametrized_params = ["scale"] support = constraints.positive arg_constraints = {"scale": constraints.positive} + pytree_data_fields = ("_normal", "scale") def __init__(self, scale=1.0, *, validate_args=None): self._normal = Normal(0.0, scale) @@ -663,9 +662,6 @@ def variance(self): a = (self.rate / (self.concentration - 1)) ** 2 / (self.concentration - 2) return jnp.where(self.concentration <= 2, jnp.inf, a) - def tree_flatten(self): - return super(TransformedDistribution, self).tree_flatten() - def cdf(self, x): return 1 - self.base_dist.cdf(1 / x) @@ -825,9 +821,6 @@ def variance(self): log_beta = betaln(1 + 2 / self.concentration1, self.concentration0) return self.concentration0 * jnp.exp(log_beta) - jnp.square(self.mean) - def tree_flatten(self): - return super(TransformedDistribution, self).tree_flatten() - class Laplace(Distribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} @@ -920,6 +913,7 @@ def model(y): # y has dimension N x d arg_constraints = {"concentration": constraints.positive} reparametrized_params = ["concentration"] support = constraints.corr_matrix + pytree_aux_fields = ("dimension", "sample_method") def __init__( self, dimension, concentration=1.0, sample_method="onion", *, validate_args=None @@ -941,14 +935,6 @@ def mean(self): self.batch_shape + (self.dimension, self.dimension), ) - def tree_flatten(self): - return (self.concentration,), (self.dimension, self.sample_method) - - @classmethod - def tree_unflatten(cls, aux_data, params): - dimension, sample_method = aux_data - return cls(dimension, *params, sample_method=sample_method) - class LKJCholesky(Distribution): r""" @@ -1003,6 +989,8 @@ def model(y): # y has dimension N x d arg_constraints = {"concentration": constraints.positive} reparametrized_params = ["concentration"] support = constraints.corr_cholesky + pytree_data_fields = ("_beta", "concentration") + pytree_aux_fields = ("dimension", "sample_method") def __init__( self, dimension, concentration=1.0, sample_method="onion", *, validate_args=None @@ -1152,14 +1140,6 @@ def log_prob(self, value): normalize_term = pi_constant + numerator - denominator return unnormalized - normalize_term - def tree_flatten(self): - return (self.concentration,), (self.dimension, self.sample_method) - - @classmethod - def tree_unflatten(cls, aux_data, params): - dimension, sample_method = aux_data - return cls(dimension, *params, sample_method=sample_method) - class LogNormal(TransformedDistribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} @@ -1181,9 +1161,6 @@ def mean(self): def variance(self): return (jnp.exp(self.scale**2) - 1) * jnp.exp(2 * self.loc + self.scale**2) - def tree_flatten(self): - return super(TransformedDistribution, self).tree_flatten() - def cdf(self, x): return self.base_dist.cdf(jnp.log(x)) @@ -1231,6 +1208,7 @@ def icdf(self, q): class LogUniform(TransformedDistribution): arg_constraints = {"low": constraints.positive, "high": constraints.positive} reparametrized_params = ["low", "high"] + pytree_data_fields = ("low", 
"high", "_support") def __init__(self, low, high, *, validate_args=None): base_dist = Uniform(jnp.log(low), jnp.log(high)) @@ -1255,9 +1233,6 @@ def variance(self): - self.mean**2 ) - def tree_flatten(self): - return super(TransformedDistribution, self).tree_flatten() - def cdf(self, x): return self.base_dist.cdf(jnp.log(x)) @@ -1530,14 +1505,6 @@ def variance(self): jnp.sum(self.scale_tril**2, axis=-1), self.batch_shape + self.event_shape ) - def tree_flatten(self): - return (self.loc, self.scale_tril), None - - @classmethod - def tree_unflatten(cls, aux_data, params): - loc, scale_tril = params - return cls(loc, scale_tril=scale_tril) - @staticmethod def infer_shapes( loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None @@ -1595,6 +1562,7 @@ class CAR(Distribution): "conditional_precision", "adj_matrix", ] + pytree_aux_fields = ("is_sparse", "adj_matrix") def __init__( self, @@ -1734,32 +1702,6 @@ def precision_matrix(self): correlation = jnp.expand_dims(self.correlation, (-2, -1)) return conditional_precision * (D - correlation * adj_matrix) - def tree_flatten(self): - if self.is_sparse: - return (self.loc, self.correlation, self.conditional_precision), ( - self.is_sparse, - self.adj_matrix, - ) - else: - return ( - self.loc, - self.correlation, - self.conditional_precision, - self.adj_matrix, - ), (self.is_sparse,) - - @classmethod - def tree_unflatten(cls, aux_data, params): - is_sparse = aux_data[0] - if is_sparse: - loc, correlation, conditional_precision = params - adj_matrix = aux_data[1] - else: - loc, correlation, conditional_precision, adj_matrix = params - return cls( - loc, correlation, conditional_precision, adj_matrix, is_sparse=is_sparse - ) - @staticmethod def infer_shapes(loc, correlation, conditional_precision, adj_matrix): event_shape = adj_matrix[-1:] @@ -1768,6 +1710,32 @@ def infer_shapes(loc, correlation, conditional_precision, adj_matrix): ) return batch_shape, event_shape + def tree_flatten(self): + data, aux = super().tree_flatten() + adj_matrix_data_idx = type(self).gather_pytree_data_fields().index("adj_matrix") + adj_matrix_aux_idx = type(self).gather_pytree_aux_fields().index("adj_matrix") + + if not self.is_sparse: + aux = list(aux) + aux[adj_matrix_aux_idx] = None + aux = tuple(aux) + else: + data = list(data) + data[adj_matrix_data_idx] = None + data = tuple(data) + return data, aux + + @classmethod + def tree_unflatten(cls, aux_data, params): + d = super().tree_unflatten(aux_data, params) + if not d.is_sparse: + adj_matrix_data_idx = cls.gather_pytree_data_fields().index("adj_matrix") + setattr(d, "adj_matrix", params[adj_matrix_data_idx]) + else: + adj_matrix_aux_idx = cls.gather_pytree_aux_fields().index("adj_matrix") + setattr(d, "adj_matrix", aux_data[adj_matrix_aux_idx]) + return d + class MultivariateStudentT(Distribution): arg_constraints = { @@ -1777,6 +1745,7 @@ class MultivariateStudentT(Distribution): } support = constraints.real_vector reparametrized_params = ["df", "loc", "scale_tril"] + pytree_data_fields = ("df", "loc", "scale_tril", "_chi2") def __init__( self, @@ -1922,6 +1891,7 @@ class LowRankMultivariateNormal(Distribution): } support = constraints.real_vector reparametrized_params = ["loc", "cov_factor", "cov_diag"] + pytree_data_fields = ("loc", "cov_factor", "cov_diag", "_capacitance_tril") def __init__(self, loc, cov_factor, cov_diag, *, validate_args=None): if jnp.ndim(loc) < 1: @@ -2132,9 +2102,6 @@ def cdf(self, value): def icdf(self, q): return self.scale / jnp.power(1 - q, 1 / self.alpha) - def 
tree_flatten(self): - return super(TransformedDistribution, self).tree_flatten() - class RelaxedBernoulliLogits(TransformedDistribution): arg_constraints = {"temperature": constraints.positive, "logits": constraints.real} @@ -2146,9 +2113,6 @@ def __init__(self, temperature, logits, *, validate_args=None): transforms = [SigmoidTransform()] super().__init__(base_dist, transforms, validate_args=validate_args) - def tree_flatten(self): - return super(TransformedDistribution, self).tree_flatten() - def RelaxedBernoulli(temperature, probs=None, logits=None, *, validate_args=None): if probs is None and logits is None: @@ -2223,6 +2187,7 @@ class StudentT(Distribution): } support = constraints.real reparametrized_params = ["df", "loc", "scale"] + pytree_data_fields = ("df", "loc", "scale", "_chi2") def __init__(self, df, loc=0.0, scale=1.0, *, validate_args=None): batch_shape = lax.broadcast_shapes( @@ -2295,6 +2260,7 @@ def icdf(self, q): class Uniform(Distribution): arg_constraints = {"low": constraints.dependent, "high": constraints.dependent} reparametrized_params = ["low", "high"] + pytree_data_fields = ("low", "high", "_support") def __init__(self, low=0.0, high=1.0, *, validate_args=None): self.low, self.high = promote_shapes(low, high) @@ -2330,22 +2296,6 @@ def mean(self): def variance(self): return (self.high - self.low) ** 2 / 12.0 - def tree_flatten(self): - if isinstance(self._support.lower_bound, (int, float)) and isinstance( - self._support.upper_bound, (int, float) - ): - aux_data = (self._support.lower_bound, self._support.upper_bound) - else: - aux_data = None - return (self.low, self.high), aux_data - - @classmethod - def tree_unflatten(cls, aux_data, params): - d = cls(*params) - if aux_data is not None: - d._support = constraints.interval(*aux_data) - return d - @staticmethod def infer_shapes(low=(), high=()): batch_shape = lax.broadcast_shapes(low, high) @@ -2415,6 +2365,7 @@ class BetaProportion(Beta): } reparametrized_params = ["mean", "concentration"] support = constraints.unit_interval + pytree_data_fields = ("concentration",) def __init__(self, mean, concentration, *, validate_args=None): self.concentration = jnp.broadcast_to( @@ -2450,6 +2401,7 @@ class AsymmetricLaplaceQuantile(Distribution): } reparametrized_params = ["loc", "scale", "quantile"] support = constraints.real + pytree_data_fields = ("loc", "scale", "quantile", "_ald") def __init__(self, loc=0.0, scale=1.0, quantile=0.5, *, validate_args=None): batch_shape = lax.broadcast_shapes( diff --git a/numpyro/distributions/copula.py b/numpyro/distributions/copula.py index a5052f0a7..98fec536a 100644 --- a/numpyro/distributions/copula.py +++ b/numpyro/distributions/copula.py @@ -33,6 +33,8 @@ class GaussianCopula(Distribution): "correlation_cholesky", ] + pytree_data_fields = ("marginal_dist", "base_dist") + def __init__( self, marginal_dist, @@ -105,20 +107,6 @@ def correlation_matrix(self): def correlation_cholesky(self): return self.base_dist.scale_tril - def tree_flatten(self): - marginal_flatten, marginal_aux = self.marginal_dist.tree_flatten() - return (marginal_flatten, self.base_dist.scale_tril), ( - type(self.marginal_dist), - marginal_aux, - ) - - @classmethod - def tree_unflatten(cls, aux_data, params): - marginal_flatten, correlation_cholesky = params - marginal_cls, marginal_aux = aux_data - marginal_dist = marginal_cls.tree_unflatten(marginal_aux, marginal_flatten) - return cls(marginal_dist, correlation_cholesky=correlation_cholesky) - class GaussianCopulaBeta(GaussianCopula): arg_constraints = { @@ 
-128,6 +116,7 @@ class GaussianCopulaBeta(GaussianCopula): "correlation_cholesky": constraints.corr_cholesky, } support = constraints.independent(constraints.unit_interval, 1) + pytree_data_fields = ("concentration1", "concentration0") def __init__( self, @@ -147,16 +136,3 @@ def __init__( correlation_cholesky, validate_args=validate_args, ) - - def tree_flatten(self): - return ( - (self.concentration1, self.concentration0), - self.base_dist.scale_tril, - ), None - - @classmethod - def tree_unflatten(cls, aux_data, params): - (concentration1, concentration0), correlation_cholesky = params - return cls( - concentration1, concentration0, correlation_cholesky=correlation_cholesky - ) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 2996808e8..b3e1dcd1a 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -214,6 +214,7 @@ def model(obs): """ arg_constraints = {"skewness": constraints.l1_ball} + pytree_data_fields = ("base_dist", "skewness") support = constraints.independent(constraints.circular, 1) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index abf8f70b1..0952c53da 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -407,6 +407,7 @@ def Categorical(probs=None, logits=None, *, validate_args=None): class DiscreteUniform(Distribution): arg_constraints = {"low": constraints.dependent, "high": constraints.dependent} has_enumerate_support = True + pytree_data_fields = ("low", "high", "_support") def __init__(self, low=0, high=1, *, validate_args=None): self.low, self.high = promote_shapes(low, high) @@ -504,6 +505,8 @@ class MultinomialProbs(Distribution): "probs": constraints.simplex, "total_count": constraints.nonnegative_integer, } + pytree_data_fields = ("probs",) + pytree_aux_fields = ("total_count", "total_count_max") def __init__( self, probs, total_count=1, *, total_count_max=None, validate_args=None @@ -568,6 +571,8 @@ class MultinomialLogits(Distribution): "logits": constraints.real_vector, "total_count": constraints.nonnegative_integer, } + pytree_data_fields = ("logits",) + pytree_aux_fields = ("total_count", "total_count_max") def __init__( self, logits, total_count=1, *, total_count_max=None, validate_args=None @@ -677,6 +682,7 @@ class Poisson(Distribution): """ arg_constraints = {"rate": constraints.positive} support = constraints.nonnegative_integer + pytree_aux_fields = ("is_sparse",) def __init__(self, rate, *, is_sparse=False, validate_args=None): self.rate = rate @@ -727,6 +733,7 @@ def cdf(self, value): class ZeroInflatedProbs(Distribution): arg_constraints = {"gate": constraints.unit_interval} + pytree_data_fields = ("base_dist", "gate") def __init__(self, base_dist, gate, *, validate_args=None): batch_shape = lax.broadcast_shapes(jnp.shape(gate), base_dist.batch_shape) @@ -828,6 +835,7 @@ class ZeroInflatedPoisson(ZeroInflatedProbs): arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive} support = constraints.nonnegative_integer + pytree_data_fields = ("rate",) # TODO: resolve inconsistent parameter order w.r.t. 
Pyro # and support `gate_logits` argument diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 1544407d0..4f27f0022 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -133,6 +133,8 @@ class Distribution(metaclass=DistributionMeta): has_enumerate_support = False reparametrized_params = [] _validate_args = False + pytree_data_fields = () + pytree_aux_fields = ("_batch_shape", "_event_shape") # register Distribution as a pytree # ref: https://github.com/google/jax/issues/2916 @@ -140,15 +142,73 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) tree_util.register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) + @classmethod + def gather_pytree_data_fields(cls): + bases = inspect.getmro(cls) + + all_pytree_data_fields = () + for base in bases: + if issubclass(base, Distribution): + all_pytree_data_fields += base.__dict__.get( + "pytree_data_fields", + tuple(base.__dict__.get("arg_constraints", {}).keys()), + ) + # remove duplicates + all_pytree_data_fields = tuple(set(all_pytree_data_fields)) + return all_pytree_data_fields + + @classmethod + def gather_pytree_aux_fields(cls): + bases = inspect.getmro(cls) + + all_pytree_aux_fields = () + for base in bases: + if issubclass(base, Distribution): + all_pytree_aux_fields += base.__dict__.get("pytree_aux_fields", ()) + # remove duplicates + all_pytree_aux_fields = tuple(set(all_pytree_aux_fields)) + return all_pytree_aux_fields + def tree_flatten(self): + all_pytree_data_fields_names = type(self).gather_pytree_data_fields() + all_pytree_data_fields_vals = tuple( + # getattr(self, attr_name) for attr_name in all_pytree_data_fields_names + self.__dict__.get(attr_name) + for attr_name in all_pytree_data_fields_names + ) + all_pytree_aux_fields_names = type(self).gather_pytree_aux_fields() + all_pytree_aux_fields_vals = tuple( + # getattr(self, attr_name) for attr_name in all_pytree_aux_fields_names + self.__dict__.get(attr_name) + for attr_name in all_pytree_aux_fields_names + ) return ( - tuple(getattr(self, param) for param in self.arg_constraints.keys()), - None, + all_pytree_data_fields_vals, + all_pytree_aux_fields_vals, ) @classmethod def tree_unflatten(cls, aux_data, params): - return cls(**dict(zip(cls.arg_constraints.keys(), params))) + pytree_data_fields = cls.gather_pytree_data_fields() + pytree_aux_fields = cls.gather_pytree_aux_fields() + + pytree_data_fields_dict = dict(zip(pytree_data_fields, params)) + pytree_aux_fields_dict = dict(zip(pytree_aux_fields, aux_data)) + + d = cls.__new__(cls) + + for k, v in pytree_data_fields_dict.items(): + setattr(d, k, v) + + for k, v in pytree_aux_fields_dict.items(): + setattr(d, k, v) + + Distribution.__init__( + d, + pytree_aux_fields_dict["_batch_shape"], + pytree_aux_fields_dict["_event_shape"], + ) + return d @staticmethod def set_default_validate_args(value): @@ -468,6 +528,11 @@ def is_discrete(self): class ExpandedDistribution(Distribution): arg_constraints = {} + pytree_data_fields = ("base_dist",) + pytree_aux_fields = ( + "_expanded_sizes", + "_interstitial_sizes", + ) def __init__(self, base_dist, batch_shape=()): if isinstance(base_dist, ExpandedDistribution): @@ -606,36 +671,6 @@ def variance(self): self.base_dist.variance, self.batch_shape + self.event_shape ) - def tree_flatten(self): - prepend_ndim = len(self.batch_shape) - len(self.base_dist.batch_shape) - base_dist = tree_util.tree_map( - lambda x: promote_shapes(x, shape=(1,) * prepend_ndim + 
jnp.shape(x))[0], - self.base_dist, - ) - base_flatten, base_aux = base_dist.tree_flatten() - return base_flatten, ( - type(self.base_dist), - base_aux, - self.batch_shape, - prepend_ndim, - ) - - @classmethod - def tree_unflatten(cls, aux_data, params): - base_cls, base_aux, batch_shape, prepend_ndim = aux_data - base_dist = base_cls.tree_unflatten(base_aux, params) - prepend_shape = base_dist.batch_shape[ - : len(base_dist.batch_shape) - len(batch_shape) - ] - if len(prepend_shape) == 0: - # in that case, no additional dimension was added - # to the flattened distribution, and the batch_shape - # manipulation happening during the flattening can be - # reverted - base_dist._batch_shape = base_dist.batch_shape[prepend_ndim:] - return cls(base_dist, batch_shape=batch_shape) - return cls(base_dist, batch_shape=prepend_shape + batch_shape) - class ImproperUniform(Distribution): """ @@ -690,6 +725,7 @@ class ImproperUniform(Distribution): arg_constraints = {} support = constraints.dependent + pytree_data_fields = ("support",) def __init__(self, support, batch_shape, event_shape, *, validate_args=None): self.support = constraints.independent( @@ -710,15 +746,6 @@ def _validate_sample(self, value): mask = jnp.all(jnp.reshape(mask, jnp.shape(mask)[:batch_dim] + (-1,)), -1) return mask - def tree_flatten(self): - raise NotImplementedError( - "Cannot flattening ImproperPrior distribution for general supports. " - "Please raising a feature request for your specific `support`. " - "Alternatively, you can use '.mask(False)' pattern. " - "For example, to define an improper prior over positive domain, " - "we can use the distribution `dist.LogNormal(0, 1).mask(False)`." - ) - class Independent(Distribution): """ @@ -744,6 +771,8 @@ class Independent(Distribution): """ arg_constraints = {} + pytree_data_fields = ("base_dist",) + pytree_aux_fields = ("reinterpreted_batch_ndims",) def __init__(self, base_dist, reinterpreted_batch_ndims, *, validate_args=None): if reinterpreted_batch_ndims > len(base_dist.batch_shape): @@ -807,20 +836,6 @@ def expand(self, batch_shape): self.reinterpreted_batch_ndims ) - def tree_flatten(self): - base_flatten, base_aux = self.base_dist.tree_flatten() - return base_flatten, ( - type(self.base_dist), - base_aux, - self.reinterpreted_batch_ndims, - ) - - @classmethod - def tree_unflatten(cls, aux_data, params): - base_cls, base_aux, reinterpreted_batch_ndims = aux_data - base_dist = base_cls.tree_unflatten(base_aux, params) - return cls(base_dist, reinterpreted_batch_ndims) - class MaskedDistribution(Distribution): """ @@ -834,6 +849,8 @@ class MaskedDistribution(Distribution): """ arg_constraints = {} + pytree_data_fields = ("base_dist", "_mask") + pytree_aux_fields = ("_mask",) def __init__(self, base_dist, mask): if isinstance(mask, bool): @@ -900,22 +917,30 @@ def variance(self): return self.base_dist.variance def tree_flatten(self): - base_flatten, base_aux = self.base_dist.tree_flatten() + data, aux = super().tree_flatten() + _mask_data_idx = type(self).gather_pytree_data_fields().index("_mask") + _mask_aux_idx = type(self).gather_pytree_aux_fields().index("_mask") + if isinstance(self._mask, bool): - return base_flatten, (type(self.base_dist), base_aux, self._mask) + data = list(data) + data[_mask_data_idx] = None + data = tuple(data) else: - return (base_flatten, self._mask), (type(self.base_dist), base_aux) + aux = list(aux) + aux[_mask_aux_idx] = None + aux = tuple(aux) + return data, aux @classmethod def tree_unflatten(cls, aux_data, params): - if len(aux_data) == 
2: - base_flatten, mask = params - base_cls, base_aux = aux_data + d = super().tree_unflatten(aux_data, params) + _mask_data_idx = cls.gather_pytree_data_fields().index("_mask") + _mask_aux_idx = cls.gather_pytree_aux_fields().index("_mask") + if aux_data[_mask_aux_idx] is None: + setattr(d, "_mask", params[_mask_data_idx]) else: - base_flatten = params - base_cls, base_aux, mask = aux_data - base_dist = base_cls.tree_unflatten(base_aux, base_flatten) - return cls(base_dist, mask) + setattr(d, "_mask", aux_data[_mask_aux_idx]) + return d class TransformedDistribution(Distribution): @@ -932,6 +957,7 @@ class TransformedDistribution(Distribution): """ arg_constraints = {} + pytree_data_fields = ("base_dist", "transforms") def __init__(self, base_distribution, transforms, *, validate_args=None): if isinstance(transforms, Transform): @@ -1056,14 +1082,6 @@ def mean(self): def variance(self): raise NotImplementedError - def tree_flatten(self): - raise NotImplementedError( - "Flatenning TransformedDistribution is only supported for some specific cases." - " Consider using `TransformReparam` to convert this distribution to the base_dist," - " which is supported in most situtations. In addition, please reach out to us with" - " your usage cases." - ) - class FoldedDistribution(TransformedDistribution): """ @@ -1086,16 +1104,6 @@ def log_prob(self, value): plus_minus = jnp.array([1.0, -1.0]).reshape((2,) + (1,) * dim) return logsumexp(self.base_dist.log_prob(plus_minus * value), axis=0) - def tree_flatten(self): - base_flatten, base_aux = self.base_dist.tree_flatten() - return base_flatten, (type(self.base_dist), base_aux) - - @classmethod - def tree_unflatten(cls, aux_data, params): - base_cls, base_aux = aux_data - base_dist = base_cls.tree_unflatten(base_aux, params) - return cls(base_dist) - class Delta(Distribution): arg_constraints = { @@ -1143,13 +1151,6 @@ def mean(self): def variance(self): return jnp.zeros(self.batch_shape + self.event_shape) - def tree_flatten(self): - return (self.v, self.log_density), self.event_dim - - @classmethod - def tree_unflatten(cls, aux_data, params): - return cls(*params, event_dim=aux_data) - class Unit(Distribution): """ diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index 5ad71df6b..b670d074a 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -183,6 +183,9 @@ class MixtureSameFamily(_MixtureBase): () """ + pytree_data_fields = ("_mixing_distribution", "_component_distribution") + pytree_aux_fields = ("_mixture_size",) + def __init__( self, mixing_distribution, component_distribution, *, validate_args=None ): @@ -229,28 +232,6 @@ def support(self): def is_discrete(self): return self.component_distribution.is_discrete - def tree_flatten(self): - mixing_flat, mixing_aux = self.mixing_distribution.tree_flatten() - component_flat, component_aux = self.component_distribution.tree_flatten() - params = (mixing_flat, component_flat) - aux_data = ( - (type(self.mixing_distribution), type(self.component_distribution)), - (mixing_aux, component_aux), - ) - return params, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, params): - mixing_params, component_params = params - child_clss, child_aux = aux_data - mixing_cls, component_cls = child_clss - mixing_aux, component_aux = child_aux - mixing_dist = mixing_cls.tree_unflatten(mixing_aux, mixing_params) - component_dist = component_cls.tree_unflatten(component_aux, component_params) - return cls( - mixing_distribution=mixing_dist, 
component_distribution=component_dist - ) - @property def component_mean(self): return self.component_distribution.mean @@ -308,6 +289,9 @@ class MixtureGeneral(_MixtureBase): () """ + pytree_data_fields = ("_mixing_distribution", "_component_distributions") + pytree_aux_fields = ("_mixture_size",) + def __init__( self, mixing_distribution, component_distributions, *, validate_args=None ): @@ -376,34 +360,6 @@ def support(self): def is_discrete(self): return self.component_distributions[0].is_discrete - def tree_flatten(self): - mixing_flat, mixing_aux = self.mixing_distribution.tree_flatten() - dists_flat, dists_aux = zip( - *(d.tree_flatten() for d in self.component_distributions) - ) - params = (mixing_flat, dists_flat) - aux_data = ( - ( - type(self.mixing_distribution), - tuple(type(d) for d in self.component_distributions), - ), - (mixing_aux, dists_aux), - ) - return params, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, params): - params_mix, params_dists = params - (cls_mix, cls_dists), (mixing_aux, dists_aux) = aux_data - mixing_dist = cls_mix.tree_unflatten(mixing_aux, params_mix) - distributions = [ - c.tree_unflatten(a, p) - for c, a, p in zip(cls_dists, dists_aux, params_dists) - ] - return cls( - mixing_distribution=mixing_dist, component_distributions=distributions - ) - @property def component_mean(self): return jnp.stack( diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index fc78b2d22..eeb2de398 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -30,6 +30,7 @@ class LeftTruncatedDistribution(Distribution): arg_constraints = {"low": constraints.real} reparametrized_params = ["low"] supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT) + pytree_data_fields = ("base_dist", "low", "_support") def __init__(self, base_dist, low=0.0, *, validate_args=None): assert isinstance(base_dist, self.supported_types) @@ -79,28 +80,6 @@ def log_prob(self, value): sign * (self._tail_prob_at_high - self._tail_prob_at_low) ) - def tree_flatten(self): - base_flatten, base_aux = self.base_dist.tree_flatten() - if isinstance(self._support.lower_bound, (int, float)): - return base_flatten, ( - type(self.base_dist), - base_aux, - self._support.lower_bound, - ) - else: - return (base_flatten, self.low), (type(self.base_dist), base_aux) - - @classmethod - def tree_unflatten(cls, aux_data, params): - if len(aux_data) == 2: - base_flatten, low = params - base_cls, base_aux = aux_data - else: - base_flatten = params - base_cls, base_aux, low = aux_data - base_dist = base_cls.tree_unflatten(base_aux, base_flatten) - return cls(base_dist, low=low) - @property def mean(self): if isinstance(self.base_dist, Normal): @@ -130,6 +109,7 @@ class RightTruncatedDistribution(Distribution): arg_constraints = {"high": constraints.real} reparametrized_params = ["high"] supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT) + pytree_data_fields = ("base_dist", "high", "_support") def __init__(self, base_dist, high=0.0, *, validate_args=None): assert isinstance(base_dist, self.supported_types) @@ -164,28 +144,6 @@ def sample(self, key, sample_shape=()): def log_prob(self, value): return self.base_dist.log_prob(value) - jnp.log(self._cdf_at_high) - def tree_flatten(self): - base_flatten, base_aux = self.base_dist.tree_flatten() - if isinstance(self._support.upper_bound, (int, float)): - return base_flatten, ( - type(self.base_dist), - base_aux, - self._support.upper_bound, - ) - else: 
- return (base_flatten, self.high), (type(self.base_dist), base_aux) - - @classmethod - def tree_unflatten(cls, aux_data, params): - if len(aux_data) == 2: - base_flatten, high = params - base_cls, base_aux = aux_data - else: - base_flatten = params - base_cls, base_aux, high = aux_data - base_dist = base_cls.tree_unflatten(base_aux, base_flatten) - return cls(base_dist, high=high) - @property def mean(self): if isinstance(self.base_dist, Normal): @@ -212,9 +170,13 @@ def var(self): class TwoSidedTruncatedDistribution(Distribution): - arg_constraints = {"low": constraints.dependent, "high": constraints.dependent} + arg_constraints = { + "low": constraints.dependent, + "high": constraints.dependent, + } reparametrized_params = ["low", "high"] supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT) + pytree_data_fields = ("base_dist", "low", "high", "_support") def __init__(self, base_dist, low=0.0, high=1.0, *, validate_args=None): assert isinstance(base_dist, self.supported_types) @@ -296,31 +258,6 @@ def log_prob(self, value): # cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high) return self.base_dist.log_prob(value) - self._log_diff_tail_probs - def tree_flatten(self): - base_flatten, base_aux = self.base_dist.tree_flatten() - if isinstance(self._support.lower_bound, (int, float)) and isinstance( - self._support.upper_bound, (int, float) - ): - return base_flatten, ( - type(self.base_dist), - base_aux, - self._support.lower_bound, - self._support.upper_bound, - ) - else: - return (base_flatten, self.low, self.high), (type(self.base_dist), base_aux) - - @classmethod - def tree_unflatten(cls, aux_data, params): - if len(aux_data) == 2: - base_flatten, low, high = params - base_cls, base_aux = aux_data - else: - base_flatten = params - base_cls, base_aux, low, high = aux_data - base_dist = base_cls.tree_unflatten(base_aux, base_flatten) - return cls(base_dist, low=low, high=high) - @property def mean(self): if isinstance(self.base_dist, Normal): @@ -430,10 +367,3 @@ def log_prob(self, value): sum_even = jnp.exp(logsumexp(even_terms, axis=-1)) sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1)) return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi) - - def tree_flatten(self): - return (), self.batch_shape - - @classmethod - def tree_unflatten(cls, aux_data, params): - return cls(batch_shape=aux_data) diff --git a/test/test_distributions.py b/test/test_distributions.py index a773153eb..dbcb21218 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -12,6 +12,7 @@ from numpy.testing import assert_allclose, assert_array_equal import pytest import scipy +from scipy.sparse import csr_matrix import scipy.stats as osp import jax @@ -20,7 +21,6 @@ import jax.random as random from jax.scipy.special import expit, logsumexp from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm -from jax.tree_util import tree_map import numpyro.distributions as dist from numpyro.distributions import ( @@ -29,6 +29,7 @@ kl_divergence, transforms, ) +from numpyro.distributions.batch_util import vmap_over from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom from numpyro.distributions.flows import InverseAutoregressiveTransform from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit @@ -149,6 +150,11 @@ def __init__(self, skewness, **kwargs): super().__init__(base_dist, skewness, **kwargs) +@vmap_over.register +def _vmap_over_sine_skewed_uniform(self: SineSkewedUniform, skewness=None): + return 
vmap_over.dispatch(dist.SineSkewed)(self, base_dist=None, skewness=skewness) + + class SineSkewedVonMises(dist.SineSkewed): def __init__(self, skewness, **kwargs): von_loc, von_conc = (np.array([0.0]), np.array([1.0])) @@ -156,6 +162,11 @@ def __init__(self, skewness, **kwargs): super().__init__(base_dist, skewness, **kwargs) +@vmap_over.register +def _vmap_over_sine_skewed_von_mises(self: SineSkewedVonMises, skewness=None): + return vmap_over.dispatch(dist.SineSkewed)(self, base_dist=None, skewness=skewness) + + class SineSkewedVonMisesBatched(dist.SineSkewed): def __init__(self, skewness, **kwargs): von_loc, von_conc = (np.array([0.0, -1.234]), np.array([1.0, 10.0])) @@ -163,65 +174,142 @@ def __init__(self, skewness, **kwargs): super().__init__(base_dist, skewness, **kwargs) -def _GaussianMixture(mixing_probs, loc, scale): - component_dist = dist.Normal(loc=loc, scale=scale) - mixing_distribution = dist.Categorical(probs=mixing_probs) - return dist.MixtureSameFamily( - mixing_distribution=mixing_distribution, - component_distribution=component_dist, - ) +@vmap_over.register +def _vmap_over_sine_skewed_von_mises_batched( + self: SineSkewedVonMisesBatched, skewness=None +): + return vmap_over.dispatch(dist.SineSkewed)(self, base_dist=None, skewness=skewness) + + +class _GaussianMixture(dist.MixtureSameFamily): + arg_constraints = {} + reparametrized_params = [] + + def __init__(self, mixing_probs, loc, scale): + component_dist = dist.Normal(loc=loc, scale=scale) + mixing_distribution = dist.Categorical(probs=mixing_probs) + super().__init__( + mixing_distribution=mixing_distribution, + component_distribution=component_dist, + ) + @property + def loc(self): + return self.component_distribution.loc -_GaussianMixture.arg_constraints = {} -_GaussianMixture.reparametrized_params = [] -_GaussianMixture.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ()) + @property + def scale(self): + return self.component_distribution.scale -def _Gaussian2DMixture(mixing_probs, loc, cov_matrix): - component_dist = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix) - mixing_distribution = dist.Categorical(probs=mixing_probs) - return dist.MixtureSameFamily( - mixing_distribution=mixing_distribution, - component_distribution=component_dist, +@vmap_over.register +def _vmap_over_gaussian_mixture(self: _GaussianMixture, loc=None, scale=None): + component_distribution = vmap_over( + self.component_distribution, loc=loc, scale=scale + ) + return vmap_over.dispatch(dist.MixtureSameFamily)( + self, _component_distribution=component_distribution ) -_Gaussian2DMixture.arg_constraints = {} -_Gaussian2DMixture.reparametrized_params = [] -_Gaussian2DMixture.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ()) +class _Gaussian2DMixture(dist.MixtureSameFamily): + arg_constraints = {} + reparametrized_params = [] + + def __init__(self, mixing_probs, loc, covariance_matrix): + component_dist = dist.MultivariateNormal( + loc=loc, covariance_matrix=covariance_matrix + ) + mixing_distribution = dist.Categorical(probs=mixing_probs) + super().__init__( + mixing_distribution=mixing_distribution, + component_distribution=component_dist, + ) + + @property + def loc(self): + return self.component_distribution.loc + + @property + def covariance_matrix(self): + return self.component_distribution.covariance_matrix -def _GeneralMixture(mixing_probs, locs, scales): - component_dists = [ - dist.Normal(loc=loc_, scale=scale_) for loc_, scale_ in zip(locs, scales) - ] - mixing_distribution = 
dist.Categorical(probs=mixing_probs)
-    return dist.MixtureGeneral(
-        mixing_distribution=mixing_distribution,
-        component_distributions=component_dists,
+@vmap_over.register
+def _vmap_over_gaussian_2d_mixture(self: _Gaussian2DMixture, loc=None):
+    component_distribution = vmap_over(self.component_distribution, loc=loc)
+    return vmap_over.dispatch(dist.MixtureSameFamily)(
+        self, _component_distribution=component_distribution
     )
 
 
-_GeneralMixture.arg_constraints = {}
-_GeneralMixture.reparametrized_params = []
-_GeneralMixture.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ())
+class _GeneralMixture(dist.MixtureGeneral):
+    arg_constraints = {}
+    reparametrized_params = []
+
+    def __init__(self, mixing_probs, locs, scales):
+        component_dists = [
+            dist.Normal(loc=loc_, scale=scale_) for loc_, scale_ in zip(locs, scales)
+        ]
+        mixing_distribution = dist.Categorical(probs=mixing_probs)
+        super().__init__(
+            mixing_distribution=mixing_distribution,
+            component_distributions=component_dists,
+        )
+
+    @property
+    def locs(self):
+        # hotfix for vmapping tests, which cannot easily check non-array attributes
+        return self.component_distributions[0].loc
+
+    @property
+    def scales(self):
+        return self.component_distributions[0].scale
 
 
-def _General2DMixture(mixing_probs, locs, cov_matrices):
-    component_dists = [
-        dist.MultivariateNormal(loc=loc_, covariance_matrix=cov_)
-        for loc_, cov_ in zip(locs, cov_matrices)
+@vmap_over.register
+def _vmap_over_general_mixture(self: _GeneralMixture, locs=None, scales=None):
+    component_distributions = [
+        vmap_over(d, loc=locs, scale=scales) for d in self.component_distributions
     ]
-    mixing_distribution = dist.Categorical(probs=mixing_probs)
-    return dist.MixtureGeneral(
-        mixing_distribution=mixing_distribution,
-        component_distributions=component_dists,
+    return vmap_over.dispatch(dist.MixtureGeneral)(
+        self, _component_distributions=component_distributions
     )
 
 
-_General2DMixture.arg_constraints = {}
-_General2DMixture.reparametrized_params = []
-_General2DMixture.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ())
+class _General2DMixture(dist.MixtureGeneral):
+    arg_constraints = {}
+    reparametrized_params = []
+
+    def __init__(self, mixing_probs, locs, covariance_matrices):
+        component_dists = [
+            dist.MultivariateNormal(loc=loc_, covariance_matrix=covariance_matrix)
+            for loc_, covariance_matrix in zip(locs, covariance_matrices)
+        ]
+        mixing_distribution = dist.Categorical(probs=mixing_probs)
+        super().__init__(
+            mixing_distribution=mixing_distribution,
+            component_distributions=component_dists,
+        )
+
+    @property
+    def locs(self):
+        # hotfix for vmapping tests, which cannot easily check non-array attributes
+        return self.component_distributions[0].loc
+
+    @property
+    def covariance_matrices(self):
+        return self.component_distributions[0].covariance_matrix
+
+
+@vmap_over.register
+def _vmap_over_general_2d_mixture(self: _General2DMixture, locs=None):
+    component_distributions = [
+        vmap_over(d, loc=locs) for d in self.component_distributions
+    ]
+    return vmap_over.dispatch(dist.MixtureGeneral)(
+        self, _component_distributions=component_distributions
+    )
 
 
 class _ImproperWrapper(dist.ImproperUniform):
@@ -236,12 +324,27 @@ def sample(self, key, sample_shape=()):
 
 class ZeroInflatedPoissonLogits(dist.discrete.ZeroInflatedLogits):
     arg_constraints = {"rate": constraints.positive, "gate_logits": constraints.real}
+    pytree_data_fields = ("rate",)
 
     def __init__(self, rate, gate_logits, *, validate_args=None):
        self.rate = rate
        super().__init__(dist.Poisson(rate), gate_logits, validate_args=validate_args)
 
 
+@vmap_over.register
+def _vmap_over_zero_inflated_poisson_logits(
+    self: ZeroInflatedPoissonLogits, rate=None, gate_logits=None
+):
+    dist_axes = vmap_over.dispatch(dist.discrete.ZeroInflatedLogits)(
+        self,
+        base_dist=vmap_over(self.base_dist, rate=rate),
+        gate_logits=gate_logits,
+        gate=gate_logits,
+    )
+    dist_axes.rate = rate
+    return dist_axes
+
+
 class SparsePoisson(dist.Poisson):
     def __init__(self, rate, *, validate_args=None):
         super().__init__(rate, is_sparse=True, validate_args=validate_args)
@@ -255,9 +358,15 @@ def __init__(self, loc, scale, validate_args=None):
         self.scale = scale
         super().__init__(dist.Normal(loc, scale), validate_args=validate_args)
 
-    @classmethod
-    def tree_unflatten(cls, aux_data, params):
-        return dist.FoldedDistribution.tree_unflatten(aux_data, params)
+
+@vmap_over.register
+def _vmap_over_folded_normal(self: "FoldedNormal", loc=None, scale=None):
+    d = vmap_over.dispatch(dist.FoldedDistribution)(
+        self, base_dist=vmap_over(self.base_dist, loc=loc, scale=scale)
+    )
+    d.loc = loc
+    d.scale = scale
+    return d
 
 
 class _SparseCAR(dist.CAR):
@@ -2658,14 +2767,6 @@ def f(x):
     lax.map(f, np.ones(3))
 
 
-def test_expand_pytree():
-    def g(x):
-        return dist.Normal(x, 1).expand([10, 3])
-
-    assert lax.map(g, jnp.ones((5, 3))).batch_shape == (5, 10, 3)
-    assert tree_map(lambda x: x[None], g(0)).batch_shape == (1, 10, 3)
-
-
 def test_expand_no_unnecessary_batch_shape_expansion():
     # ExpandedDistribution can mutate the `batch_shape` of
     # its base distribution in order to make ExpandedDistribution
@@ -2793,6 +2894,129 @@ def sample_binomial_withp0(key):
     jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))
 
 
+def _get_vmappable_dist_init_params(jax_dist):
+    if jax_dist.__name__ in ("_TruncatedCauchy", "_TruncatedNormal"):
+        return [2, 3]
+    elif issubclass(jax_dist, dist.Distribution):
+        init_parameters = list(inspect.signature(jax_dist.__init__).parameters.keys())[
+            1:
+        ]
+        vmap_over_parameters = list(
+            inspect.signature(vmap_over.dispatch(jax_dist)).parameters.keys()
+        )[1:]
+        return [
+            i
+            for i, name in enumerate(init_parameters)
+            if name in vmap_over_parameters
+        ]
+    else:
+        raise ValueError(f"Cannot determine vmappable parameters of {jax_dist}")
+
+
+def _allclose_or_equal(a1, a2):
+    if isinstance(a1, np.ndarray):
+        return np.allclose(a2, a1)
+    elif isinstance(a1, jnp.ndarray):
+        return jnp.allclose(a2, a1)
+    elif isinstance(a1, csr_matrix):
+        return np.allclose(a2.todense(), a1.todense())
+    else:
+        return a2 == a1 or a2 is a1
+
+
+def _tree_equal(t1, t2):
+    t = jax.tree_util.tree_map(_allclose_or_equal, t1, t2)
+    return jnp.all(jax.flatten_util.ravel_pytree(t)[0])
+
+
+@pytest.mark.parametrize(
+    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
+)
+def test_vmap_dist(jax_dist, sp_dist, params):
+    param_names = list(inspect.signature(jax_dist).parameters.keys())
+    vmappable_param_idxs = _get_vmappable_dist_init_params(jax_dist)
+    vmappable_param_idxs = vmappable_param_idxs[: len(params)]
+
+    if len(vmappable_param_idxs) == 0:
+        return
+
+    def make_jax_dist(*params):
+        return jax_dist(*params)
+
+    def sample(d: dist.Distribution):
+        return d.sample(random.PRNGKey(0))
+
+    d = make_jax_dist(*params)
+
+    if isinstance(d, _SparseCAR) and d.is_sparse:
+        # In this case, since csr arrays are not jittable,
+        # _SparseCAR has a csr_matrix as part of its pytree
+        # definition (not as a pytree leaf). This causes pytree
+        # operations like tree_map to fail, since these functions
+        # compare the pytree def of each of the arguments using ==,
+        # which is ambiguous for array-like objects.
+        return
+
+    in_out_axes_cases = [
+        # vmap over all args
+        (
+            tuple(0 if i in vmappable_param_idxs else None for i in range(len(params))),
+            0,
+        ),
+        # vmap over a single arg, out over all attributes of a distribution
+        *(
+            ([0 if i == idx else None for i in range(len(params))], 0)
+            for idx in vmappable_param_idxs
+            if params[idx] is not None
+        ),
+        # vmap over a single arg, out over the associated attribute of the distribution
+        *(
+            (
+                [0 if i == idx else None for i in range(len(params))],
+                vmap_over(d, **{param_names[idx]: 0}),
+            )
+            for idx in vmappable_param_idxs
+            if params[idx] is not None
+        ),
+        # vmap over a single arg, axis=1, (out single attribute, axis=1)
+        *(
+            (
+                [1 if i == idx else None for i in range(len(params))],
+                vmap_over(d, **{param_names[idx]: 1}),
+            )
+            for idx in vmappable_param_idxs
+            if isinstance(params[idx], jnp.ndarray) and jnp.array(params[idx]).ndim > 0
+            # skip this distribution because _GeneralMixture.__init__ turns
+            # 1d inputs into 0d attributes, which breaks the expectations of
+            # the in_axes=1 case; this case is only run for rank >= 1 tensors.
+            and jax_dist is not _GeneralMixture
+        ),
+    ]
+
+    for in_axes, out_axes in in_out_axes_cases:
+        batched_params = [
+            jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
+            if isinstance(ax, int)
+            else arg
+            for arg, ax in zip(params, in_axes)
+        ]
+        # Recreate the jax_dist to avoid side effects coming from `d.sample`
+        # triggering lazy_property computations, which, in a few cases, break
+        # vmap_over's expectations about which attributes are present to be vmapped.
+        d = make_jax_dist(*params)
+        batched_d = jax.vmap(make_jax_dist, in_axes=in_axes, out_axes=out_axes)(
+            *batched_params
+        )
+        eq = vmap(_tree_equal, in_axes=(out_axes, None))(
+            batched_d, d
+        )
+        assert eq == jnp.array([True])
+
+        samples_dist = sample(d)
+        samples_batched_dist = jax.vmap(sample, in_axes=(out_axes,))(batched_d)
+        assert samples_batched_dist.shape == (1, *samples_dist.shape)
+
+
 def test_multinomial_abstract_total_count():
     probs = jnp.array([0.2, 0.5, 0.3])
     key = random.PRNGKey(0)
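
Usage sketch for the API exercised by `test_vmap_dist` above: `vmap_over(d, **axes)` returns a copy of the distribution `d` whose parameter attributes hold axis specifications (integers or `None`) instead of arrays, so the result can serve as the `in_axes`/`out_axes` pytree of `jax.vmap`. The snippet below is illustrative only; the `locs` array is made up, and it assumes the `dist.Normal` registration added by this patch behaves as the test asserts:

    import jax
    import jax.numpy as jnp

    import numpyro.distributions as dist
    from numpyro.distributions.batch_util import vmap_over

    locs = jnp.linspace(-1.0, 1.0, 4)

    # Copy of a template Normal whose attributes are axis specs: loc=0, scale=None.
    axes = vmap_over(dist.Normal(0.0, 1.0), loc=0)

    # Build one Normal per entry of `locs`; `axes` tells vmap that only the
    # `loc` attribute of the returned distribution carries the mapped axis.
    batched = jax.vmap(lambda m: dist.Normal(m, 1.0), out_axes=axes)(locs)
    assert batched.loc.shape == (4,)

    # The batched distribution can be consumed under vmap as well, reusing
    # the same axis specification on the input side.
    samples = jax.vmap(lambda d: d.sample(jax.random.PRNGKey(0)), in_axes=(axes,))(
        batched
    )
    assert samples.shape == (4,)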