diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index a3f910f7c..afce44f7f 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -66,6 +66,13 @@ def __call__(self, x): def codomain(self): return _get_codomain(self.bijector) + def tree_flatten(self): + return self.bijector, () + + @classmethod + def tree_unflatten(cls, _, bijector): + return cls(bijector) + class BijectorTransform(Transform): """ @@ -106,6 +113,13 @@ def inverse_shape(self, shape): batch_shape = shape[: len(shape) - len(out_event_shape)] return batch_shape + in_shape + def tree_flatten(self): + return self.bijector, () + + @classmethod + def tree_unflatten(cls, _, bijector): + return cls(bijector) + @biject_to.register(BijectorConstraint) def _transform_to_bijector_constraint(constraint): diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 9cc1ca1e0..d442a3150 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -62,6 +62,8 @@ import numpy as np import jax.numpy +import jax.numpy as jnp +from jax.tree_util import register_pytree_node class Constraint(object): @@ -75,6 +77,10 @@ class Constraint(object): is_discrete = False event_dim = 0 + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) + def __call__(self, x): raise NotImplementedError @@ -94,8 +100,24 @@ def feasible_like(self, prototype): """ raise NotImplementedError + @classmethod + def tree_unflatten(cls, aux_data, params): + params_keys, aux_data = aux_data + self = cls.__new__(cls) + for k, v in zip(params_keys, params): + setattr(self, k, v) + + for k, v in aux_data.items(): + setattr(self, k, v) + return self + + +class ParameterFreeConstraint(Constraint): + def tree_flatten(self): + return (), ((), dict()) + -class _SingletonConstraint(Constraint): +class _SingletonConstraint(ParameterFreeConstraint): """ A constraint type which has only one canonical instance, like constraints.real, and unlike constraints.interval. @@ -202,8 +224,23 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) + def __eq__(self, other): + return ( + type(self) is type(other) + and self._is_discrete == other._is_discrete + and self._event_dim == other._event_dim + ) + + def tree_flatten(self): + return (), ( + (), + dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim), + ) + class dependent_property(property, _Dependent): + # XXX: this should not need to be pytree-able since it simply wraps a method + # and thus is automatically present once the method's object is created def __init__( self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented ): @@ -243,8 +280,16 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.lower_bound,), (("lower_bound",), dict()) + + def __eq__(self, other): + if not isinstance(other, _GreaterThan): + return False + return jnp.array_equal(self.lower_bound, other.lower_bound) -class _Positive(_GreaterThan, _SingletonConstraint): + +class _Positive(_SingletonConstraint, _GreaterThan): def __init__(self): super().__init__(0.0) @@ -301,6 +346,20 @@ def __repr__(self): def feasible_like(self, prototype): return self.base_constraint.feasible_like(prototype) + def tree_flatten(self): + return (self.base_constraint,), ( + ("base_constraint",), + {"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims}, + ) + + def __eq__(self, other): + if not isinstance(other, _IndependentConstraint): + return False + + return (self.base_constraint == other.base_constraint) & ( + self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims + ) + class _RealVector(_IndependentConstraint, _SingletonConstraint): def __init__(self): @@ -327,6 +386,14 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.upper_bound,), (("upper_bound",), dict()) + + def __eq__(self, other): + if not isinstance(other, _LessThan): + return False + return jnp.array_equal(self.upper_bound, other.upper_bound) + class _IntegerInterval(Constraint): is_discrete = True @@ -348,6 +415,20 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.lower_bound, self.upper_bound), ( + ("lower_bound", "upper_bound"), + dict(), + ) + + def __eq__(self, other): + if not isinstance(other, _IntegerInterval): + return False + + return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( + self.upper_bound, other.upper_bound + ) + class _IntegerGreaterThan(Constraint): is_discrete = True @@ -366,13 +447,21 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.lower_bound,), (("lower_bound",), dict()) -class _IntegerPositive(_IntegerGreaterThan, _SingletonConstraint): + def __eq__(self, other): + if not isinstance(other, _IntegerGreaterThan): + return False + return jnp.array_equal(self.lower_bound, other.lower_bound) + + +class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): def __init__(self): super().__init__(1) -class _IntegerNonnegative(_IntegerGreaterThan, _SingletonConstraint): +class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan): def __init__(self): super().__init__(0) @@ -398,19 +487,25 @@ def feasible_like(self, prototype): ) def __eq__(self, other): - return ( - isinstance(other, _Interval) - and self.lower_bound == other.lower_bound - and self.upper_bound == other.upper_bound + if not isinstance(other, _Interval): + return False + return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( + self.upper_bound, other.upper_bound ) + def tree_flatten(self): + return (self.lower_bound, self.upper_bound), ( + ("lower_bound", "upper_bound"), + dict(), + ) -class _Circular(_Interval, _SingletonConstraint): + +class _Circular(_SingletonConstraint, _Interval): def __init__(self): super().__init__(-math.pi, math.pi) -class _UnitInterval(_Interval, _SingletonConstraint): +class _UnitInterval(_SingletonConstraint, _Interval): def __init__(self): super().__init__(0.0, 1.0) @@ -462,6 +557,14 @@ def feasible_like(self, prototype): value = jax.numpy.pad(jax.numpy.expand_dims(self.upper_bound, -1), pad_width) return jax.numpy.broadcast_to(value, prototype.shape) + def tree_flatten(self): + return (self.upper_bound,), (("upper_bound",), dict()) + + def __eq__(self, other): + if not isinstance(other, _Multinomial): + return False + return jnp.array_equal(self.upper_bound, other.upper_bound) + class _L1Ball(_SingletonConstraint): """ @@ -546,7 +649,7 @@ def feasible_like(self, prototype): return jax.numpy.full_like(prototype, 1 / prototype.shape[-1]) -class _SoftplusPositive(_GreaterThan, _SingletonConstraint): +class _SoftplusPositive(_SingletonConstraint, _GreaterThan): def __init__(self): super().__init__(lower_bound=0.0) diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py index 781003649..cd9b21c35 100644 --- a/numpyro/distributions/flows.py +++ b/numpyro/distributions/flows.py @@ -93,6 +93,21 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): log_scale = intermediates return log_scale.sum(-1) + def tree_flatten(self): + return (self.log_scale_min_clip, self.log_scale_max_clip), ( + ("log_scale_min_clip", "log_scale_max_clip"), + {"arn": self.arn}, + ) + + def __eq__(self, other): + if not isinstance(other, InverseAutoregressiveTransform): + return False + return ( + (self.arn is other.arn) + & jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip) + & jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip) + ) + class BlockNeuralAutoregressiveTransform(Transform): """ @@ -139,3 +154,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): else: logdet = intermediates return logdet.sum(-1) + + def tree_flatten(self): + return (), ((), {"bn_arn": self.bn_arn}) + + def __eq__(self, other): + return ( + isinstance(other, BlockNeuralAutoregressiveTransform) + and self.bn_arn is other.bn_arn + ) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index c656c2cc4..6bd86f14a 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -13,7 +13,7 @@ import jax.numpy as jnp from jax.scipy.linalg import solve_triangular from jax.scipy.special import expit, logit -from jax.tree_util import tree_flatten, tree_map +from jax.tree_util import register_pytree_node, tree_flatten, tree_map from numpyro.distributions import constraints from numpyro.distributions.util import ( @@ -60,6 +60,10 @@ class Transform(object): codomain = constraints.real _inv = None + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) + @property def inv(self): inv = None @@ -106,6 +110,25 @@ def __getstate__(self): attrs[k] = v return attrs + @classmethod + def tree_unflatten(cls, aux_data, params): + params_keys, aux_data = aux_data + self = cls.__new__(cls) + for k, v in zip(params_keys, params): + setattr(self, k, v) + + for k, v in aux_data.items(): + setattr(self, k, v) + return self + + +class ParameterFreeTransform(Transform): + def tree_flatten(self): + return (), ((), dict()) + + def __eq__(self, other): + return isinstance(other, type(self)) + class _InverseTransform(Transform): def __init__(self, transform): @@ -137,8 +160,15 @@ def forward_shape(self, shape): def inverse_shape(self, shape): return self._inv.forward_shape(shape) + def tree_flatten(self): + return (self._inv,), (("_inv",), dict()) + + @classmethod + def tree_unflatten(cls, aux_data, params): + return cls(params) -class AbsTransform(Transform): + +class AbsTransform(ParameterFreeTransform): domain = constraints.real codomain = constraints.positive @@ -215,6 +245,18 @@ def inverse_shape(self, shape): shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) ) + def tree_flatten(self): + return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) + + def __eq__(self, other): + if not isinstance(other, AffineTransform): + return False + return ( + jnp.array_equal(self.loc, other.loc) + & jnp.array_equal(self.scale, other.scale) + & (self.domain == other.domain) + ) + def _get_compose_transform_input_event_dim(parts): input_event_dim = parts[-1].domain.event_dim @@ -318,6 +360,14 @@ def inverse_shape(self, shape): shape = part.inverse_shape(shape) return shape + def tree_flatten(self): + return (self.parts,), (("parts",), {}) + + def __eq__(self, other): + if not isinstance(other, ComposeTransform): + return False + return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) + def _matrix_forward_shape(shape, offset=0): # Reshape from (..., N) to (..., D, D). @@ -342,7 +392,7 @@ def _matrix_inverse_shape(shape, offset=0): return shape[:-2] + (N,) -class CholeskyTransform(Transform): +class CholeskyTransform(ParameterFreeTransform): r""" Transform via the mapping :math:`y = cholesky(x)`, where `x` is a positive definite matrix. @@ -365,7 +415,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): ) -class CorrCholeskyTransform(Transform): +class CorrCholeskyTransform(ParameterFreeTransform): r""" Transforms a uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower @@ -482,8 +532,16 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y, intermediates=None): return x + def tree_flatten(self): + return (self.domain,), (("domain",), dict()) + + def __eq__(self, other): + if not isinstance(other, ExpTransform): + return False + return self.domain == other.domain -class IdentityTransform(Transform): + +class IdentityTransform(ParameterFreeTransform): def __call__(self, x): return x @@ -545,8 +603,21 @@ def forward_shape(self, shape): def inverse_shape(self, shape): return self.base_transform.inverse_shape(shape) + def tree_flatten(self): + return (self.base_transform, self.reinterpreted_batch_ndims), ( + ("base_transform", "reinterpreted_batch_ndims"), + dict(), + ) + + def __eq__(self, other): + if not isinstance(other, IndependentTransform): + return False + return (self.base_transform == other.base_transform) & ( + self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims + ) -class L1BallTransform(Transform): + +class L1BallTransform(ParameterFreeTransform): r""" Transforms a uncontrained real vector :math:`x` into the unit L1 ball. """ @@ -654,8 +725,18 @@ def inverse_shape(self, shape): raise ValueError("Too few dimensions on input") return lax.broadcast_shapes(shape, self.loc.shape, self.scale_tril.shape[:-1]) + def tree_flatten(self): + return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) + + def __eq__(self, other): + if not isinstance(other, LowerCholeskyAffine): + return False + return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( + self.scale_tril, other.scale_tril + ) + -class LowerCholeskyTransform(Transform): +class LowerCholeskyTransform(ParameterFreeTransform): """ Transform a real vector to a lower triangular cholesky factor, where the strictly lower triangular submatrix is @@ -723,7 +804,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): return (jnp.log(diag_softplus) * jnp.arange(n) - softplus(-diag)).sum(-1) -class OrderedTransform(Transform): +class OrderedTransform(ParameterFreeTransform): """ Transform a real vector to an ordered vector. @@ -781,6 +862,14 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y, intermediates=None): return jnp.full(jnp.shape(x)[:-1], 0.0) + def tree_flatten(self): + return (self.permutation,), (("permutation",), dict()) + + def __eq__(self, other): + if not isinstance(other, PermuteTransform): + return False + return jnp.array_equal(self.permutation, other.permutation) + class PowerTransform(Transform): domain = constraints.positive @@ -804,8 +893,16 @@ def forward_shape(self, shape): def inverse_shape(self, shape): return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + def tree_flatten(self): + return (self.exponent,), (("exponent",), dict()) + + def __eq__(self, other): + if not isinstance(other, PowerTransform): + return False + return jnp.array_equal(self.exponent, other.exponent) -class SigmoidTransform(Transform): + +class SigmoidTransform(ParameterFreeTransform): codomain = constraints.unit_interval def __call__(self, x): @@ -871,6 +968,14 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): J_logdet = (softplus(y) + softplus(-y)).sum(-1) return J_logdet + def tree_flatten(self): + return (self.anchor_point,), (("anchor_point",), dict()) + + def __eq__(self, other): + if not isinstance(other, SimplexToOrderedTransform): + return False + return jnp.array_equal(self.anchor_point, other.anchor_point) + def forward_shape(self, shape): return shape[:-1] + (shape[-1] - 1,) @@ -882,7 +987,7 @@ def _softplus_inv(y): return jnp.log(-jnp.expm1(-y)) + y -class SoftplusTransform(Transform): +class SoftplusTransform(ParameterFreeTransform): r""" Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`. The inverse is computed as :math:`x = \log(\exp(y) - 1)`. @@ -900,7 +1005,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): return -softplus(-x) -class SoftplusLowerCholeskyTransform(Transform): +class SoftplusLowerCholeskyTransform(ParameterFreeTransform): """ Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive @@ -934,7 +1039,7 @@ def inverse_shape(self, shape): return _matrix_inverse_shape(shape) -class StickBreakingTransform(Transform): +class StickBreakingTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.simplex @@ -1027,6 +1132,13 @@ def forward_shape(self, shape): def inverse_shape(self, shape): raise NotImplementedError + def tree_flatten(self): + # XXX: what if unpack_fn is a parametrized callable pytree? + return (), ((), {"unpack_fn": self.unpack_fn}) + + def __eq__(self, other): + return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn + ########################################################## # CONSTRAINT_REGISTRY diff --git a/test/test_constraints.py b/test/test_constraints.py new file mode 100644 index 000000000..0b9d23f13 --- /dev/null +++ b/test/test_constraints.py @@ -0,0 +1,183 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +import pytest + +from jax import jit, tree_map, vmap +import jax.numpy as jnp + +from numpyro.distributions import constraints + +SINGLETON_CONSTRAINTS = { + "boolean": constraints.boolean, + "circular": constraints.circular, + "corr_cholesky": constraints.corr_cholesky, + "corr_matrix": constraints.corr_matrix, + "l1_ball": constraints.l1_ball, + "lower_cholesky": constraints.lower_cholesky, + "scaled_unit_lower_cholesky": constraints.scaled_unit_lower_cholesky, + "nonnegative_integer": constraints.nonnegative_integer, + "ordered_vector": constraints.ordered_vector, + "positive": constraints.positive, + "positive_definite": constraints.positive_definite, + "positive_integer": constraints.positive_integer, + "positive_ordered_vector": constraints.positive_ordered_vector, + "real": constraints.real, + "real_vector": constraints.real_vector, + "real_matrix": constraints.real_matrix, + "simplex": constraints.simplex, + "softplus_lower_cholesky": constraints.softplus_lower_cholesky, + "softplus_positive": constraints.softplus_positive, + "sphere": constraints.sphere, + "unit_interval": constraints.unit_interval, +} + +_a = jnp.asarray + + +class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): + pass + + +PARAMETRIZED_CONSTRAINTS = { + "dependent": T( + type(constraints.dependent), (), dict(is_discrete=True, event_dim=2) + ), + "greater_than": T(constraints.greater_than, (_a(0.0),), dict()), + "less_than": T(constraints.less_than, (_a(-1.0),), dict()), + "independent": T( + constraints.independent, + (constraints.greater_than(jnp.zeros((2,))),), + dict(reinterpreted_batch_ndims=1), + ), + "integer_interval": T(constraints.integer_interval, (_a(-1), _a(1)), dict()), + "integer_greater_than": T(constraints.integer_greater_than, (_a(1),), dict()), + "interval": T(constraints.interval, (_a(-1.0), _a(1.0)), dict()), + "multinomial": T( + constraints.multinomial, + (_a(1.0),), + dict(), + ), + "open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()), +} + +# TODO: BijectorConstraint + + +@pytest.mark.parametrize( + "constraint", SINGLETON_CONSTRAINTS.values(), ids=SINGLETON_CONSTRAINTS.keys() +) +def test_singleton_constraint_pytree(constraint): + # test that singleton constraints objects can be used as pytrees + def in_cst(constraint, x): + return x**2 + + def out_cst(constraint, x): + return constraint + + jitted_in_cst = jit(in_cst) + jitted_out_cst = jit(out_cst) + + assert jitted_in_cst(constraint, 1.0) == 1.0 + assert jitted_out_cst(constraint, 1.0) == constraint + + assert jnp.allclose( + vmap(in_cst, in_axes=(None, 0), out_axes=0)(constraint, jnp.ones(3)), + jnp.ones(3), + ) + + assert ( + vmap(out_cst, in_axes=(None, 0), out_axes=None)(constraint, jnp.ones(3)) + is constraint + ) + + +@pytest.mark.parametrize( + "cls, cst_args, cst_kwargs", + PARAMETRIZED_CONSTRAINTS.values(), + ids=PARAMETRIZED_CONSTRAINTS.keys(), +) +def test_parametrized_constraint_pytree(cls, cst_args, cst_kwargs): + constraint = cls(*cst_args, **cst_kwargs) + + # test that singleton constraints objects can be used as pytrees + def in_cst(constraint, x): + return x**2 + + def out_cst(constraint, x): + return constraint + + jitted_in_cst = jit(in_cst) + jitted_out_cst = jit(out_cst) + + assert jitted_in_cst(constraint, 1.0) == 1.0 + assert jitted_out_cst(constraint, 1.0) == constraint + + assert jnp.allclose( + vmap(in_cst, in_axes=(None, 0), out_axes=0)(constraint, jnp.ones(3)), + jnp.ones(3), + ) + + assert ( + vmap(out_cst, in_axes=(None, 0), out_axes=None)(constraint, jnp.ones(3)) + == constraint + ) + + if len(cst_args) > 0: + # test creating and manipulating vmapped constraints + vmapped_cst_args = tree_map(lambda x: x[None], cst_args) + + vmapped_csts = jit(vmap(lambda args: cls(*args, **cst_kwargs), in_axes=(0,)))( + vmapped_cst_args + ) + assert vmap(lambda x: x == constraint, in_axes=0)(vmapped_csts).all() + + twice_vmapped_cst_args = tree_map(lambda x: x[None], vmapped_cst_args) + + vmapped_csts = jit( + vmap( + vmap(lambda args: cls(*args, **cst_kwargs), in_axes=(0,)), + in_axes=(0,), + ), + )(twice_vmapped_cst_args) + assert vmap(vmap(lambda x: x == constraint, in_axes=0), in_axes=0)( + vmapped_csts + ).all() + + +@pytest.mark.parametrize( + "cls, cst_args, cst_kwargs", + PARAMETRIZED_CONSTRAINTS.values(), + ids=PARAMETRIZED_CONSTRAINTS.keys(), +) +def test_parametrized_constraint_eq(cls, cst_args, cst_kwargs): + constraint = cls(*cst_args, **cst_kwargs) + constraint2 = cls(*cst_args, **cst_kwargs) + assert constraint == constraint2 + assert constraint != 1 + + # check that equality checks are robust to constraints parametrized + # by abstract values + @jit + def check_constraints(c1, c2): + return c1 == c2 + + assert check_constraints(constraint, constraint2) + + +@pytest.mark.parametrize( + "constraint", SINGLETON_CONSTRAINTS.values(), ids=SINGLETON_CONSTRAINTS.keys() +) +def test_singleton_constraint_eq(constraint): + assert constraint == constraint + assert constraint != 1 + + # check that equality checks are robust to constraints parametrized + # by abstract values + @jit + def check_constraints(c1, c2): + return c1 == c2 + + assert check_constraints(constraint, constraint) diff --git a/test/test_transforms.py b/test/test_transforms.py new file mode 100644 index 000000000..54316dd88 --- /dev/null +++ b/test/test_transforms.py @@ -0,0 +1,197 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple +from functools import partial + +import pytest + +from jax import jit, tree_map, vmap +import jax.numpy as jnp + +from numpyro.distributions.flows import ( + BlockNeuralAutoregressiveTransform, + InverseAutoregressiveTransform, +) +from numpyro.distributions.transforms import ( + AbsTransform, + AffineTransform, + CholeskyTransform, + ComposeTransform, + CorrCholeskyTransform, + CorrMatrixCholeskyTransform, + ExpTransform, + IdentityTransform, + IndependentTransform, + L1BallTransform, + LowerCholeskyAffine, + LowerCholeskyTransform, + OrderedTransform, + PermuteTransform, + PowerTransform, + ScaledUnitLowerCholeskyTransform, + SigmoidTransform, + SimplexToOrderedTransform, + SoftplusLowerCholeskyTransform, + SoftplusTransform, + StickBreakingTransform, + UnpackTransform, +) + + +def _unpack(x): + return (x,) + + +_a = jnp.asarray + + +def _smoke_neural_network(): + return None, None + + +class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): + pass + + +TRANSFORMS = { + "affine": T( + AffineTransform, (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), dict() + ), + "compose": T( + ComposeTransform, + ( + [ + AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), + ExpTransform(), + ], + ), + dict(), + ), + "independent": T( + IndependentTransform, + (AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),), + dict(reinterpreted_batch_ndims=1), + ), + "lower_cholesky_affine": T( + LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)), dict() + ), + "permute": T(PermuteTransform, (jnp.array([1, 0]),), dict()), + "power": T( + PowerTransform, + (_a(2.0),), + dict(), + ), + "simplex_to_ordered": T( + SimplexToOrderedTransform, + (_a(1.0),), + dict(), + ), + "unpack": T(UnpackTransform, (), dict(unpack_fn=_unpack)), + # unparametrized transforms + "abs": T(AbsTransform, (), dict()), + "cholesky": T(CholeskyTransform, (), dict()), + "corr_chol": T(CorrCholeskyTransform, (), dict()), + "corr_matrix_chol": T(CorrMatrixCholeskyTransform, (), dict()), + "exp": T(ExpTransform, (), dict()), + "identity": T(IdentityTransform, (), dict()), + "l1_ball": T(L1BallTransform, (), dict()), + "lower_cholesky": T(LowerCholeskyTransform, (), dict()), + "ordered": T(OrderedTransform, (), dict()), + "scaled_unit_lower_cholesky": T(ScaledUnitLowerCholeskyTransform, (), dict()), + "sigmoid": T(SigmoidTransform, (), dict()), + "softplus": T(SoftplusTransform, (), dict()), + "softplus_lower_cholesky": T(SoftplusLowerCholeskyTransform, (), dict()), + "stick_breaking": T(StickBreakingTransform, (), dict()), + # neural transforms + "iaf": T( + # autoregressive_nn is a non-jittable arg, which does not fit well with + # the current test pipeline, which assumes jittable args, and non-jittable kwargs + partial(InverseAutoregressiveTransform, _smoke_neural_network), + (_a(-1.0), _a(1.0)), + dict(), + ), + "bna": T( + partial(BlockNeuralAutoregressiveTransform, _smoke_neural_network), + (), + dict(), + ), +} + + +@pytest.mark.parametrize( + "cls, transform_args, transform_kwargs", + TRANSFORMS.values(), + ids=TRANSFORMS.keys(), +) +def test_parametrized_transform_pytree(cls, transform_args, transform_kwargs): + transform = cls(*transform_args, **transform_kwargs) + + # test that singleton transforms objects can be used as pytrees + def in_t(transform, x): + return x**2 + + def out_t(transform, x): + return transform + + jitted_in_t = jit(in_t) + jitted_out_t = jit(out_t) + + assert jitted_in_t(transform, 1.0) == 1.0 + assert jitted_out_t(transform, 1.0) == transform + + assert jnp.allclose( + vmap(in_t, in_axes=(None, 0), out_axes=0)(transform, jnp.ones(3)), + jnp.ones(3), + ) + + assert ( + vmap(out_t, in_axes=(None, 0), out_axes=None)(transform, jnp.ones(3)) + == transform + ) + + if len(transform_args) > 0: + # test creating and manipulating vmapped constraints + # this test assumes jittable args, and non-jittable kwargs, which is + # not suited for all transforms, see InverseAutoregressiveTransform. + # TODO: split among jittable and non-jittable args/kwargs instead. + vmapped_transform_args = tree_map(lambda x: x[None], transform_args) + + vmapped_transform = jit( + vmap(lambda args: cls(*args, **transform_kwargs), in_axes=(0,)) + )(vmapped_transform_args) + assert vmap(lambda x: x == transform, in_axes=0)(vmapped_transform).all() + + twice_vmapped_transform_args = tree_map( + lambda x: x[None], vmapped_transform_args + ) + + vmapped_transform = jit( + vmap( + vmap(lambda args: cls(*args, **transform_kwargs), in_axes=(0,)), + in_axes=(0,), + ) + )(twice_vmapped_transform_args) + assert vmap(vmap(lambda x: x == transform, in_axes=0), in_axes=0)( + vmapped_transform + ).all() + + +@pytest.mark.parametrize( + "cls, transform_args, transform_kwargs", + TRANSFORMS.values(), + ids=TRANSFORMS.keys(), +) +def test_parametrized_transform_eq(cls, transform_args, transform_kwargs): + transform = cls(*transform_args, **transform_kwargs) + transform2 = cls(*transform_args, **transform_kwargs) + assert transform == transform2 + assert transform != 1.0 + + # check that equality checks are robust to transforms parametrized + # by abstract values + @jit + def check_transforms(t1, t2): + return t1 == t2 + + assert check_transforms(transform, transform2)