From f712984c031dd513fa1f8b492e3c485f2938fb05 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Tue, 11 Apr 2023 22:52:33 +0100 Subject: [PATCH 01/14] [WIP] jittable transforms --- numpyro/distributions/constraints.py | 74 +++++++++++++++++-- numpyro/distributions/transforms.py | 31 +++++++- test/test_constraints.py | 104 +++++++++++++++++++++++++++ 3 files changed, 200 insertions(+), 9 deletions(-) create mode 100644 test/test_constraints.py diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 9cc1ca1e0..6dee0abae 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -62,6 +62,7 @@ import numpy as np import jax.numpy +from jax.tree_util import register_pytree_node class Constraint(object): @@ -75,6 +76,10 @@ class Constraint(object): is_discrete = False event_dim = 0 + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) + def __call__(self, x): raise NotImplementedError @@ -94,8 +99,24 @@ def feasible_like(self, prototype): """ raise NotImplementedError + @classmethod + def tree_unflatten(cls, aux_data, params): + params_keys, aux_data = aux_data + self = cls.__new__(cls) + for k, v in zip(params_keys, params): + setattr(self, k, v) + + for k, v in aux_data.items(): + setattr(self, k, v) + return self + + +class ParameterFreeConstraint(Constraint): + def tree_flatten(self): + return (), ((), dict()) + -class _SingletonConstraint(Constraint): +class _SingletonConstraint(ParameterFreeConstraint): """ A constraint type which has only one canonical instance, like constraints.real, and unlike constraints.interval. @@ -202,8 +223,17 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) + def tree_flatten(self): + return (), ((), dict()) + + @classmethod + def tree_unflatten(cls, aux_data, params): + return cls() + class dependent_property(property, _Dependent): + # XXX: this should not need to be pytree-able since it simply wraps a method + # and thus is automatically present once the method's object is created def __init__( self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented ): @@ -243,8 +273,11 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.lower_bound,), (("lower_bound",), dict()) -class _Positive(_GreaterThan, _SingletonConstraint): + +class _Positive(_SingletonConstraint, _GreaterThan): def __init__(self): super().__init__(0.0) @@ -301,6 +334,12 @@ def __repr__(self): def feasible_like(self, prototype): return self.base_constraint.feasible_like(prototype) + def tree_flatten(self): + return (self.base_constraint,), ( + ("base_constraint",), + {"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims}, + ) + class _RealVector(_IndependentConstraint, _SingletonConstraint): def __init__(self): @@ -327,6 +366,9 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.upper_bound,), (("upper_bound",), dict()) + class _IntegerInterval(Constraint): is_discrete = True @@ -348,6 +390,12 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.lower_bound, self.upper_bound), ( + ("lower_bound", "upper_bound"), + dict(), + ) + class _IntegerGreaterThan(Constraint): is_discrete = True @@ -366,13 +414,16 @@ def __repr__(self): def feasible_like(self, prototype): return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) + def tree_flatten(self): + return (self.lower_bound,), (("lower_bound",), dict()) -class _IntegerPositive(_IntegerGreaterThan, _SingletonConstraint): + +class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): def __init__(self): super().__init__(1) -class _IntegerNonnegative(_IntegerGreaterThan, _SingletonConstraint): +class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan): def __init__(self): super().__init__(0) @@ -404,13 +455,19 @@ def __eq__(self, other): and self.upper_bound == other.upper_bound ) + def tree_flatten(self): + return (self.lower_bound, self.upper_bound), ( + ("lower_bound", "upper_bound"), + dict(), + ) + -class _Circular(_Interval, _SingletonConstraint): +class _Circular(_SingletonConstraint, _Interval): def __init__(self): super().__init__(-math.pi, math.pi) -class _UnitInterval(_Interval, _SingletonConstraint): +class _UnitInterval(_SingletonConstraint, _Interval): def __init__(self): super().__init__(0.0, 1.0) @@ -462,6 +519,9 @@ def feasible_like(self, prototype): value = jax.numpy.pad(jax.numpy.expand_dims(self.upper_bound, -1), pad_width) return jax.numpy.broadcast_to(value, prototype.shape) + def tree_flatten(self): + return (self.upper_bound,), (("upper_bound",), dict()) + class _L1Ball(_SingletonConstraint): """ @@ -546,7 +606,7 @@ def feasible_like(self, prototype): return jax.numpy.full_like(prototype, 1 / prototype.shape[-1]) -class _SoftplusPositive(_GreaterThan, _SingletonConstraint): +class _SoftplusPositive(_SingletonConstraint, _GreaterThan): def __init__(self): super().__init__(lower_bound=0.0) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index c8017eb57..76d666b22 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -13,7 +13,7 @@ import jax.numpy as jnp from jax.scipy.linalg import solve_triangular from jax.scipy.special import expit, logit -from jax.tree_util import tree_flatten, tree_map +from jax.tree_util import register_pytree_node, tree_flatten, tree_map from numpyro.distributions import constraints from numpyro.distributions.util import ( @@ -107,6 +107,19 @@ def __getstate__(self): return attrs +class ParameterFreeTransform(Transform): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) + + def tree_flatten(self): + return (), None + + @classmethod + def tree_unflatten(cls, aux_data, params): + return cls() + + class _InverseTransform(Transform): def __init__(self, transform): super().__init__() @@ -137,8 +150,15 @@ def forward_shape(self, shape): def inverse_shape(self, shape): return self._inv.forward_shape(shape) + def tree_flatten(self): + return self._inv, None -class AbsTransform(Transform): + @classmethod + def tree_unflatten(cls, aux_data, params): + return cls(params) + + +class AbsTransform(ParameterFreeTransform): domain = constraints.real codomain = constraints.positive @@ -215,6 +235,13 @@ def inverse_shape(self, shape): shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) ) + def tree_flatten(self): + return (self.loc, self.scale), None + + @classmethod + def tree_unflatten(cls, aux_data, params): + return cls(*params) + def _get_compose_transform_input_event_dim(parts): input_event_dim = parts[-1].domain.event_dim diff --git a/test/test_constraints.py b/test/test_constraints.py new file mode 100644 index 000000000..5a903db0b --- /dev/null +++ b/test/test_constraints.py @@ -0,0 +1,104 @@ +from collections import namedtuple + +import pytest + +from jax import jit, vmap +import jax.numpy as jnp + +from numpyro.distributions.constraints import ( + boolean, + circular, + corr_cholesky, + corr_matrix, + greater_than, + independent, + integer_greater_than, + integer_interval, + interval, + l1_ball, + less_than, + lower_cholesky, + multinomial, + nonnegative_integer, + open_interval, + ordered_vector, + positive, + positive_definite, + positive_integer, + positive_ordered_vector, + real, + real_matrix, + real_vector, + scaled_unit_lower_cholesky, + simplex, + softplus_lower_cholesky, + softplus_positive, + sphere, + unit_interval, +) + + +class T(namedtuple("TestCase", ["constraint_cls", "params"])): + pass + + +SINGLETON_CONSTRAINTS = ( + boolean, + circular, + corr_cholesky, + corr_matrix, + l1_ball, + lower_cholesky, + scaled_unit_lower_cholesky, + nonnegative_integer, + ordered_vector, + positive, + positive_definite, + positive_integer, + positive_ordered_vector, + real, + real_vector, + real_matrix, + simplex, + softplus_lower_cholesky, + softplus_positive, + sphere, + unit_interval, +) + +PARAMETRIZED_CONSTRAINTS = ( + greater_than, + less_than, + independent, + integer_interval, + integer_greater_than, + interval, + multinomial, + open_interval, +) + + +@pytest.mark.parametrize("constraint", SINGLETON_CONSTRAINTS) +def test_singleton_constrains_pytree(constraint): + # test that singleton constraints objects can be used as pytrees + def in_cst(constraint, x): + return x**2 + + def out_cst(constraint, x): + return constraint + + jitted_in_cst = jit(in_cst) + jitted_out_cst = jit(out_cst) + + assert jitted_in_cst(constraint, 1.0) == 1.0 + assert jitted_out_cst(constraint, 1.0) == constraint + + assert jnp.allclose( + vmap(in_cst, in_axes=(None, 0), out_axes=0)(constraint, jnp.ones(3)), + jnp.ones(3), + ) + + assert ( + vmap(out_cst, in_axes=(None, 0), out_axes=None)(constraint, jnp.ones(3)) + is constraint + ) From 0ca2586f79b2848a17d472a48ba89233710ad32d Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Tue, 11 Apr 2023 22:59:58 +0100 Subject: [PATCH 02/14] add licence to new test file --- test/test_constraints.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_constraints.py b/test/test_constraints.py index 5a903db0b..31c776e90 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + from collections import namedtuple import pytest From d857ca4eadbdaeb9e2137860cf3f8b8c52e19dcc Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 16 Apr 2023 17:51:07 +0100 Subject: [PATCH 03/14] turn BijectorConstraint into pytree --- numpyro/contrib/tfp/distributions.py | 7 +++++++ test/test_constraints.py | 2 ++ 2 files changed, 9 insertions(+) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index a3f910f7c..e061d8194 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -66,6 +66,13 @@ def __call__(self, x): def codomain(self): return _get_codomain(self.bijector) + def tree_flatten(self): + return self.bijector, () + + @classmethod + def tree_unflatten(cls, _, bijector): + return cls(bijector) + class BijectorTransform(Transform): """ diff --git a/test/test_constraints.py b/test/test_constraints.py index 31c776e90..3f1585a05 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -80,6 +80,8 @@ class T(namedtuple("TestCase", ["constraint_cls", "params"])): open_interval, ) +# TODO: BijectorConstraint + @pytest.mark.parametrize("constraint", SINGLETON_CONSTRAINTS) def test_singleton_constrains_pytree(constraint): From 34e5daca2745b35d7f8c351ebeed52a0d9798430 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 16 Apr 2023 18:15:47 +0100 Subject: [PATCH 04/14] test flattening/unflattening of parametrized constraints --- numpyro/distributions/constraints.py | 35 +++++++++ test/test_constraints.py | 107 ++++++++++++++++++--------- 2 files changed, 107 insertions(+), 35 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 6dee0abae..44f78205a 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -62,6 +62,7 @@ import numpy as np import jax.numpy +import jax.numpy as jnp from jax.tree_util import register_pytree_node @@ -276,6 +277,11 @@ def feasible_like(self, prototype): def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) + def __eq__(self, other): + return isinstance(other, _GreaterThan) and jnp.array_equal( + self.lower_bound, other.lower_bound + ) + class _Positive(_SingletonConstraint, _GreaterThan): def __init__(self): @@ -340,6 +346,13 @@ def tree_flatten(self): {"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims}, ) + def __eq__(self, other): + return ( + isinstance(other, _IndependentConstraint) + and self.base_constraint == other.base_constraint + and self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims + ) + class _RealVector(_IndependentConstraint, _SingletonConstraint): def __init__(self): @@ -369,6 +382,11 @@ def feasible_like(self, prototype): def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) + def __eq__(self, other): + return isinstance(other, _LessThan) and jnp.array_equal( + self.upper_bound, other.upper_bound + ) + class _IntegerInterval(Constraint): is_discrete = True @@ -396,6 +414,13 @@ def tree_flatten(self): dict(), ) + def __eq__(self, other): + return ( + isinstance(other, _IntegerInterval) + and jnp.array_equal(self.lower_bound, other.lower_bound) + and jnp.array_equal(self.upper_bound, other.upper_bound) + ) + class _IntegerGreaterThan(Constraint): is_discrete = True @@ -417,6 +442,11 @@ def feasible_like(self, prototype): def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) + def __eq__(self, other): + return isinstance(other, _IntegerGreaterThan) and jnp.array_equal( + self.lower_bound, other.lower_bound + ) + class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): def __init__(self): @@ -522,6 +552,11 @@ def feasible_like(self, prototype): def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) + def __eq__(self, other): + return isinstance(other, _Multinomial) and jnp.array_equal( + self.upper_bound, other.upper_bound + ) + class _L1Ball(_SingletonConstraint): """ diff --git a/test/test_constraints.py b/test/test_constraints.py index 3f1585a05..fa32feb71 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -41,49 +41,54 @@ ) +SINGLETON_CONSTRAINTS = { + "boolean": boolean, + "circular": circular, + "corr_cholesky": corr_cholesky, + "corr_matrix": corr_matrix, + "l1_ball": l1_ball, + "lower_cholesky": lower_cholesky, + "scaled_unit_lower_cholesky": scaled_unit_lower_cholesky, + "nonnegative_integer": nonnegative_integer, + "ordered_vector": ordered_vector, + "positive": positive, + "positive_definite": positive_definite, + "positive_integer": positive_integer, + "positive_ordered_vector": positive_ordered_vector, + "real": real, + "real_vector": real_vector, + "real_matrix": real_matrix, + "simplex": simplex, + "softplus_lower_cholesky": softplus_lower_cholesky, + "softplus_positive": softplus_positive, + "sphere": sphere, + "unit_interval": unit_interval, +} + +_a = jnp.asarray + + class T(namedtuple("TestCase", ["constraint_cls", "params"])): pass -SINGLETON_CONSTRAINTS = ( - boolean, - circular, - corr_cholesky, - corr_matrix, - l1_ball, - lower_cholesky, - scaled_unit_lower_cholesky, - nonnegative_integer, - ordered_vector, - positive, - positive_definite, - positive_integer, - positive_ordered_vector, - real, - real_vector, - real_matrix, - simplex, - softplus_lower_cholesky, - softplus_positive, - sphere, - unit_interval, -) - -PARAMETRIZED_CONSTRAINTS = ( - greater_than, - less_than, - independent, - integer_interval, - integer_greater_than, - interval, - multinomial, - open_interval, -) +PARAMETRIZED_CONSTRAINTS = { + "greater_than": T(greater_than, (_a(0.0),)), + "less_than": T(less_than, (_a(-1.0),)), + "independent": T(independent, (greater_than(jnp.zeros((2,))), 1)), + "integer_interval": T(integer_interval, (_a(-1), _a(1))), + "integer_greater_than": T(integer_greater_than, (_a(1),)), + "interval": T(interval, (_a(-1.0), _a(1.0))), + "multinomial": T(multinomial, (_a(1.0),),), + "open_interval": T(open_interval, (_a(-1.0), _a(1.0))), +} # TODO: BijectorConstraint -@pytest.mark.parametrize("constraint", SINGLETON_CONSTRAINTS) +@pytest.mark.parametrize( + "constraint", SINGLETON_CONSTRAINTS.values(), ids=SINGLETON_CONSTRAINTS.keys() +) def test_singleton_constrains_pytree(constraint): # test that singleton constraints objects can be used as pytrees def in_cst(constraint, x): @@ -107,3 +112,35 @@ def out_cst(constraint, x): vmap(out_cst, in_axes=(None, 0), out_axes=None)(constraint, jnp.ones(3)) is constraint ) + + +@pytest.mark.parametrize( + "cls, params", + PARAMETRIZED_CONSTRAINTS.values(), + ids=PARAMETRIZED_CONSTRAINTS.keys(), +) +def test_parametrized_constrains_pytree(cls, params): + constraint = cls(*params) + + # test that singleton constraints objects can be used as pytrees + def in_cst(constraint, x): + return x**2 + + def out_cst(constraint, x): + return constraint + + jitted_in_cst = jit(in_cst) + jitted_out_cst = jit(out_cst) + + assert jitted_in_cst(constraint, 1.0) == 1.0 + assert jitted_out_cst(constraint, 1.0) == constraint + + assert jnp.allclose( + vmap(in_cst, in_axes=(None, 0), out_axes=0)(constraint, jnp.ones(3)), + jnp.ones(3), + ) + + assert ( + vmap(out_cst, in_axes=(None, 0), out_axes=None)(constraint, jnp.ones(3)) + == constraint + ) From bc2d0973c398cb918e0007727231869558989e85 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 16 Apr 2023 18:20:00 +0100 Subject: [PATCH 05/14] cosmetic edits --- test/test_constraints.py | 96 +++++++++++++++------------------------- 1 file changed, 35 insertions(+), 61 deletions(-) diff --git a/test/test_constraints.py b/test/test_constraints.py index fa32feb71..d070bffba 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -8,61 +8,30 @@ from jax import jit, vmap import jax.numpy as jnp -from numpyro.distributions.constraints import ( - boolean, - circular, - corr_cholesky, - corr_matrix, - greater_than, - independent, - integer_greater_than, - integer_interval, - interval, - l1_ball, - less_than, - lower_cholesky, - multinomial, - nonnegative_integer, - open_interval, - ordered_vector, - positive, - positive_definite, - positive_integer, - positive_ordered_vector, - real, - real_matrix, - real_vector, - scaled_unit_lower_cholesky, - simplex, - softplus_lower_cholesky, - softplus_positive, - sphere, - unit_interval, -) - +from numpyro.distributions import constraints SINGLETON_CONSTRAINTS = { - "boolean": boolean, - "circular": circular, - "corr_cholesky": corr_cholesky, - "corr_matrix": corr_matrix, - "l1_ball": l1_ball, - "lower_cholesky": lower_cholesky, - "scaled_unit_lower_cholesky": scaled_unit_lower_cholesky, - "nonnegative_integer": nonnegative_integer, - "ordered_vector": ordered_vector, - "positive": positive, - "positive_definite": positive_definite, - "positive_integer": positive_integer, - "positive_ordered_vector": positive_ordered_vector, - "real": real, - "real_vector": real_vector, - "real_matrix": real_matrix, - "simplex": simplex, - "softplus_lower_cholesky": softplus_lower_cholesky, - "softplus_positive": softplus_positive, - "sphere": sphere, - "unit_interval": unit_interval, + "boolean": constraints.boolean, + "circular": constraints.circular, + "corr_cholesky": constraints.corr_cholesky, + "corr_matrix": constraints.corr_matrix, + "l1_ball": constraints.l1_ball, + "lower_cholesky": constraints.lower_cholesky, + "scaled_unit_lower_cholesky": constraints.scaled_unit_lower_cholesky, + "nonnegative_integer": constraints.nonnegative_integer, + "ordered_vector": constraints.ordered_vector, + "positive": constraints.positive, + "positive_definite": constraints.positive_definite, + "positive_integer": constraints.positive_integer, + "positive_ordered_vector": constraints.positive_ordered_vector, + "real": constraints.real, + "real_vector": constraints.real_vector, + "real_matrix": constraints.real_matrix, + "simplex": constraints.simplex, + "softplus_lower_cholesky": constraints.softplus_lower_cholesky, + "softplus_positive": constraints.softplus_positive, + "sphere": constraints.sphere, + "unit_interval": constraints.unit_interval, } _a = jnp.asarray @@ -73,14 +42,19 @@ class T(namedtuple("TestCase", ["constraint_cls", "params"])): PARAMETRIZED_CONSTRAINTS = { - "greater_than": T(greater_than, (_a(0.0),)), - "less_than": T(less_than, (_a(-1.0),)), - "independent": T(independent, (greater_than(jnp.zeros((2,))), 1)), - "integer_interval": T(integer_interval, (_a(-1), _a(1))), - "integer_greater_than": T(integer_greater_than, (_a(1),)), - "interval": T(interval, (_a(-1.0), _a(1.0))), - "multinomial": T(multinomial, (_a(1.0),),), - "open_interval": T(open_interval, (_a(-1.0), _a(1.0))), + "greater_than": T(constraints.greater_than, (_a(0.0),)), + "less_than": T(constraints.less_than, (_a(-1.0),)), + "independent": T( + constraints.independent, (constraints.greater_than(jnp.zeros((2,))), 1) + ), + "integer_interval": T(constraints.integer_interval, (_a(-1), _a(1))), + "integer_greater_than": T(constraints.integer_greater_than, (_a(1),)), + "interval": T(constraints.interval, (_a(-1.0), _a(1.0))), + "multinomial": T( + constraints.multinomial, + (_a(1.0),), + ), + "open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0))), } # TODO: BijectorConstraint From e9e6cb984790ebf6230df76fa4b5e3dd005b7f30 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 16 Apr 2023 18:47:24 +0100 Subject: [PATCH 06/14] fix typo --- test/test_constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_constraints.py b/test/test_constraints.py index d070bffba..048e58324 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -63,7 +63,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params"])): @pytest.mark.parametrize( "constraint", SINGLETON_CONSTRAINTS.values(), ids=SINGLETON_CONSTRAINTS.keys() ) -def test_singleton_constrains_pytree(constraint): +def test_singleton_constraint_pytree(constraint): # test that singleton constraints objects can be used as pytrees def in_cst(constraint, x): return x**2 @@ -93,7 +93,7 @@ def out_cst(constraint, x): PARAMETRIZED_CONSTRAINTS.values(), ids=PARAMETRIZED_CONSTRAINTS.keys(), ) -def test_parametrized_constrains_pytree(cls, params): +def test_parametrized_constraint_pytree(cls, params): constraint = cls(*params) # test that singleton constraints objects can be used as pytrees From ef486a0b6036ae65d8595c081c1b8c3b62f5e6ac Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 16 Apr 2023 19:14:45 +0100 Subject: [PATCH 07/14] implement tree_flatten/unflatten for transforms --- numpyro/distributions/flows.py | 23 +++++ numpyro/distributions/transforms.py | 127 ++++++++++++++++++++++------ test/test_transforms.py | 126 +++++++++++++++++++++++++++ 3 files changed, 252 insertions(+), 24 deletions(-) create mode 100644 test/test_transforms.py diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py index 781003649..41c8c1b37 100644 --- a/numpyro/distributions/flows.py +++ b/numpyro/distributions/flows.py @@ -93,6 +93,20 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): log_scale = intermediates return log_scale.sum(-1) + def tree_flatten(self): + return (self.log_scale_min_clip, self.log_scale_max_clip), ( + ("log_scale_min_clip", "log_scale_max_clip"), + {"arn": self.arn}, + ) + + def __eq__(self, other): + return ( + isinstance(other, InverseAutoregressiveTransform) + and self.arn is other.arn + and self.log_scale_min_clip == other.log_scale_min_clip + and self.log_scale_max_clip == other.log_scale_max_clip + ) + class BlockNeuralAutoregressiveTransform(Transform): """ @@ -139,3 +153,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): else: logdet = intermediates return logdet.sum(-1) + + def tree_flatten(self): + return (), ((), {"bn_arn": self.bn_arn}) + + def __eq__(self, other): + return ( + isinstance(other, BlockNeuralAutoregressiveTransform) + and self.bn_arn is other.bn_arn + ) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 76d666b22..9b9f817e1 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -60,6 +60,10 @@ class Transform(object): codomain = constraints.real _inv = None + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) + @property def inv(self): inv = None @@ -106,18 +110,24 @@ def __getstate__(self): attrs[k] = v return attrs + @classmethod + def tree_unflatten(cls, aux_data, params): + params_keys, aux_data = aux_data + self = cls.__new__(cls) + for k, v in zip(params_keys, params): + setattr(self, k, v) -class ParameterFreeTransform(Transform): - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) + for k, v in aux_data.items(): + setattr(self, k, v) + return self + +class ParameterFreeTransform(Transform): def tree_flatten(self): - return (), None + return (), ((), dict()) - @classmethod - def tree_unflatten(cls, aux_data, params): - return cls() + def __eq__(self, other): + return isinstance(other, type(self)) class _InverseTransform(Transform): @@ -151,7 +161,7 @@ def inverse_shape(self, shape): return self._inv.forward_shape(shape) def tree_flatten(self): - return self._inv, None + return (self._inv,), (("_inv",), dict()) @classmethod def tree_unflatten(cls, aux_data, params): @@ -236,11 +246,15 @@ def inverse_shape(self, shape): ) def tree_flatten(self): - return (self.loc, self.scale), None + return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) - @classmethod - def tree_unflatten(cls, aux_data, params): - return cls(*params) + def __eq__(self, other): + return ( + isinstance(other, AffineTransform) + and jnp.array_equal(self.loc, other.loc) + and jnp.array_equal(self.scale, other.scale) + and self.domain == other.domain + ) def _get_compose_transform_input_event_dim(parts): @@ -345,6 +359,12 @@ def inverse_shape(self, shape): shape = part.inverse_shape(shape) return shape + def tree_flatten(self): + return (self.parts,), (("parts",), {}) + + def __eq__(self, other): + return isinstance(other, ComposeTransform) and self.parts == other.parts + def _matrix_forward_shape(shape, offset=0): # Reshape from (..., N) to (..., D, D). @@ -369,7 +389,7 @@ def _matrix_inverse_shape(shape, offset=0): return shape[:-2] + (N,) -class CholeskyTransform(Transform): +class CholeskyTransform(ParameterFreeTransform): r""" Transform via the mapping :math:`y = cholesky(x)`, where `x` is a positive definite matrix. @@ -392,7 +412,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): ) -class CorrCholeskyTransform(Transform): +class CorrCholeskyTransform(ParameterFreeTransform): r""" Transforms a uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower @@ -462,7 +482,7 @@ def inverse_shape(self, shape): return _matrix_inverse_shape(shape, offset=-1) -class CorrMatrixCholeskyTransform(CholeskyTransform): +class CorrMatrixCholeskyTransform(ParameterFreeTransform): r""" Transform via the mapping :math:`y = cholesky(x)`, where `x` is a correlation matrix. @@ -509,8 +529,14 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y, intermediates=None): return x + def tree_flatten(self): + return (self.domain,), (("domain",), dict()) + + def __eq__(self, other): + return isinstance(other, ExpTransform) and self.domain == other.domain -class IdentityTransform(Transform): + +class IdentityTransform(ParameterFreeTransform): def __call__(self, x): return x @@ -572,8 +598,21 @@ def forward_shape(self, shape): def inverse_shape(self, shape): return self.base_transform.inverse_shape(shape) + def tree_flatten(self): + return (self.base_transform, self.reinterpreted_batch_ndims), ( + ("base_transform", "reinterpreted_batch_ndims"), + dict(), + ) -class L1BallTransform(Transform): + def __eq__(self, other): + return ( + isinstance(other, IndependentTransform) + and self.base_transform == other.base_transform + and self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims + ) + + +class L1BallTransform(ParameterFreeTransform): r""" Transforms a uncontrained real vector :math:`x` into the unit L1 ball. """ @@ -681,8 +720,17 @@ def inverse_shape(self, shape): raise ValueError("Too few dimensions on input") return lax.broadcast_shapes(shape, self.loc.shape, self.scale_tril.shape[:-1]) + def tree_flatten(self): + return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) + + def __eq__(self, other): + return isinstance(other, LowerCholeskyAffine) and ( + jnp.array_equal(self.loc, other.loc) + and jnp.array_equal(self.scale_tril, other.scale_tril) + ) + -class LowerCholeskyTransform(Transform): +class LowerCholeskyTransform(ParameterFreeTransform): """ Transform a real vector to a lower triangular cholesky factor, where the strictly lower triangular submatrix is @@ -750,7 +798,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): return (jnp.log(diag_softplus) * jnp.arange(n) - softplus(-diag)).sum(-1) -class OrderedTransform(Transform): +class OrderedTransform(ParameterFreeTransform): """ Transform a real vector to an ordered vector. @@ -808,6 +856,14 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y, intermediates=None): return jnp.full(jnp.shape(x)[:-1], 0.0) + def tree_flatten(self): + return (self.permutation,), (("permutation",), dict()) + + def __eq__(self, other): + return isinstance(other, PermuteTransform) and jnp.array_equal( + self.permutation, other.permutation + ) + class PowerTransform(Transform): domain = constraints.positive @@ -831,8 +887,16 @@ def forward_shape(self, shape): def inverse_shape(self, shape): return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + def tree_flatten(self): + return (self.exponent,), (("exponent",), dict()) + + def __eq__(self, other): + return isinstance(other, PowerTransform) and jnp.array_equal( + self.exponent, other.exponent + ) -class SigmoidTransform(Transform): + +class SigmoidTransform(ParameterFreeTransform): codomain = constraints.unit_interval def __call__(self, x): @@ -898,12 +962,20 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): J_logdet = (softplus(y) + softplus(-y)).sum(-1) return J_logdet + def tree_flatten(self): + return (self.anchor_point,), (("anchor_point",), dict()) + + def __eq__(self, other): + return isinstance(other, SimplexToOrderedTransform) and jnp.array_equal( + self.anchor_point, other.anchor_point + ) + def _softplus_inv(y): return jnp.log(-jnp.expm1(-y)) + y -class SoftplusTransform(Transform): +class SoftplusTransform(ParameterFreeTransform): r""" Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`. The inverse is computed as :math:`x = \log(\exp(y) - 1)`. @@ -921,7 +993,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): return -softplus(-x) -class SoftplusLowerCholeskyTransform(Transform): +class SoftplusLowerCholeskyTransform(ParameterFreeTransform): """ Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive @@ -955,7 +1027,7 @@ def inverse_shape(self, shape): return _matrix_inverse_shape(shape) -class StickBreakingTransform(Transform): +class StickBreakingTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.simplex @@ -1048,6 +1120,13 @@ def forward_shape(self, shape): def inverse_shape(self, shape): raise NotImplementedError + def tree_flatten(self): + # XXX: what if unpack_fn is a parametrized callable pytree? + return (), ((), {"unpack_fn": self.unpack_fn}) + + def __eq__(self, other): + return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn + ########################################################## # CONSTRAINT_REGISTRY diff --git a/test/test_transforms.py b/test/test_transforms.py new file mode 100644 index 000000000..d6bd57672 --- /dev/null +++ b/test/test_transforms.py @@ -0,0 +1,126 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +import pytest + +from jax import jit, vmap +import jax.numpy as jnp + +from numpyro.distributions.flows import ( + BlockNeuralAutoregressiveTransform, + InverseAutoregressiveTransform, +) +from numpyro.distributions.transforms import ( + AbsTransform, + AffineTransform, + CholeskyTransform, + ComposeTransform, + CorrCholeskyTransform, + CorrMatrixCholeskyTransform, + ExpTransform, + IdentityTransform, + IndependentTransform, + L1BallTransform, + LowerCholeskyAffine, + LowerCholeskyTransform, + OrderedTransform, + PermuteTransform, + PowerTransform, + ScaledUnitLowerCholeskyTransform, + SigmoidTransform, + SimplexToOrderedTransform, + SoftplusLowerCholeskyTransform, + SoftplusTransform, + StickBreakingTransform, + UnpackTransform, +) + + +def _unpack(x): + return (x,) + + +_a = jnp.asarray + + +def _smoke_neural_network(): + return None, None + + +class T(namedtuple("TestCase", ["transform_cls", "params"])): + pass + + +TRANSFORMS = { + "affine": T(AffineTransform, (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))), + "compose": T(ComposeTransform, ([ExpTransform(), ExpTransform()],)), + "independent": T( + IndependentTransform, + (AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), 1), + ), + "lower_cholesky_affine": T( + LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)) + ), + "permute": T(PermuteTransform, (jnp.array([1, 0]),)), + "power": T(PowerTransform, (_a(2.0),),), # fmt: skip + "simplex_to_ordered": T(SimplexToOrderedTransform, (_a(1.0),),), # fmt: skip + "unpack": T(UnpackTransform, (_unpack,)), + # unparametrized transforms + "abs": T(AbsTransform, ()), + "cholesky": T(CholeskyTransform, ()), + "corr_chol": T(CorrCholeskyTransform, ()), + "corr_matrix_chol": T(CorrMatrixCholeskyTransform, ()), + "exp": T(ExpTransform, ()), + "identity": T(IdentityTransform, ()), + "l1_ball": T(L1BallTransform, ()), + "lower_cholesky": T(LowerCholeskyTransform, ()), + "ordered": T(OrderedTransform, ()), + "scaled_unit_lower_cholesky": T(ScaledUnitLowerCholeskyTransform, ()), + "sigmoid": T(SigmoidTransform, ()), + "softplus": T(SoftplusTransform, ()), + "softplus_lower_cholesky": T(SoftplusLowerCholeskyTransform, ()), + "stick_breaking": T(StickBreakingTransform, ()), + # neural transforms + "iaf": T( + InverseAutoregressiveTransform, + (_smoke_neural_network, -1.0, 1.0), + ), + "bna": T( + BlockNeuralAutoregressiveTransform, + (_smoke_neural_network,), + ), +} + + +@pytest.mark.parametrize( + "cls, params", + TRANSFORMS.values(), + ids=TRANSFORMS.keys(), +) +def test_parametrized_transform_pytree(cls, params): + transform = cls(*params) + + # test that singleton transforms objects can be used as pytrees + def in_t(transform, x): + return x**2 + + def out_t(transform, x): + return transform + + jitted_in_t = jit(in_t) + jitted_out_t = jit(out_t) + + assert jitted_in_t(transform, 1.0) == 1.0 + assert jitted_out_t(transform, 1.0) == transform + + assert jnp.allclose( + vmap(in_t, in_axes=(None, 0), out_axes=0)(transform, jnp.ones(3)), + jnp.ones(3), + ) + + assert ( + vmap(out_t, in_axes=(None, 0), out_axes=None)(transform, jnp.ones(3)) + == transform + ) From d2007f9b4ff54a0396bf2058a7716754031aace7 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 16 Apr 2023 19:20:45 +0100 Subject: [PATCH 08/14] attempt to avoid confusing black --- test/test_transforms.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index d6bd57672..53b9cf21d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -64,8 +64,14 @@ class T(namedtuple("TestCase", ["transform_cls", "params"])): LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)) ), "permute": T(PermuteTransform, (jnp.array([1, 0]),)), - "power": T(PowerTransform, (_a(2.0),),), # fmt: skip - "simplex_to_ordered": T(SimplexToOrderedTransform, (_a(1.0),),), # fmt: skip + "power": T( + PowerTransform, + (_a(2.0),), + ), + "simplex_to_ordered": T( + SimplexToOrderedTransform, + (_a(1.0),), + ), "unpack": T(UnpackTransform, (_unpack,)), # unparametrized transforms "abs": T(AbsTransform, ()), From e0a2b3094ae3d24ab5340bcb83686d59544bc94e Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sun, 16 Apr 2023 19:50:00 +0100 Subject: [PATCH 09/14] add (un)flattening meths for BijectorTransform --- numpyro/contrib/tfp/distributions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index e061d8194..afce44f7f 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -113,6 +113,13 @@ def inverse_shape(self, shape): batch_shape = shape[: len(shape) - len(out_event_shape)] return batch_shape + in_shape + def tree_flatten(self): + return self.bijector, () + + @classmethod + def tree_unflatten(cls, _, bijector): + return cls(bijector) + @biject_to.register(BijectorConstraint) def _transform_to_bijector_constraint(constraint): From 48f3b2e5d6d61245ad2d9b57672f4f8a02581869 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Mon, 17 Apr 2023 10:34:29 +0100 Subject: [PATCH 10/14] fixup! implement tree_flatten/unflatten for transforms --- numpyro/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 9b9f817e1..134050386 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -482,7 +482,7 @@ def inverse_shape(self, shape): return _matrix_inverse_shape(shape, offset=-1) -class CorrMatrixCholeskyTransform(ParameterFreeTransform): +class CorrMatrixCholeskyTransform(CholeskyTransform): r""" Transform via the mapping :math:`y = cholesky(x)`, where `x` is a correlation matrix. From c468f0482befd0f04ac86c745e42589cb8f5bb07 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Thu, 25 May 2023 23:35:56 +0100 Subject: [PATCH 11/14] test vmapping over transforms/constraints --- numpyro/distributions/constraints.py | 36 ++++++---- numpyro/distributions/flows.py | 6 +- numpyro/distributions/transforms.py | 26 +++---- test/test_constraints.py | 51 +++++++++---- test/test_transforms.py | 103 +++++++++++++++++++-------- 5 files changed, 151 insertions(+), 71 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 44f78205a..f89be66a3 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -224,12 +224,18 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) - def tree_flatten(self): - return (), ((), dict()) + def __eq__(self, other): + return ( + type(self) is type(other) + and self._is_discrete == other._is_discrete + and self._event_dim == other._event_dim + ) - @classmethod - def tree_unflatten(cls, aux_data, params): - return cls() + def tree_flatten(self): + return (), ( + (), + dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim), + ) class dependent_property(property, _Dependent): @@ -278,7 +284,7 @@ def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) def __eq__(self, other): - return isinstance(other, _GreaterThan) and jnp.array_equal( + return isinstance(other, _GreaterThan) & jnp.array_equal( self.lower_bound, other.lower_bound ) @@ -349,8 +355,8 @@ def tree_flatten(self): def __eq__(self, other): return ( isinstance(other, _IndependentConstraint) - and self.base_constraint == other.base_constraint - and self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims + & (self.base_constraint == other.base_constraint) + & (self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims) ) @@ -383,7 +389,7 @@ def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) def __eq__(self, other): - return isinstance(other, _LessThan) and jnp.array_equal( + return isinstance(other, _LessThan) & jnp.array_equal( self.upper_bound, other.upper_bound ) @@ -417,8 +423,8 @@ def tree_flatten(self): def __eq__(self, other): return ( isinstance(other, _IntegerInterval) - and jnp.array_equal(self.lower_bound, other.lower_bound) - and jnp.array_equal(self.upper_bound, other.upper_bound) + & jnp.array_equal(self.lower_bound, other.lower_bound) + & jnp.array_equal(self.upper_bound, other.upper_bound) ) @@ -443,7 +449,7 @@ def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) def __eq__(self, other): - return isinstance(other, _IntegerGreaterThan) and jnp.array_equal( + return isinstance(other, _IntegerGreaterThan) & jnp.array_equal( self.lower_bound, other.lower_bound ) @@ -481,8 +487,8 @@ def feasible_like(self, prototype): def __eq__(self, other): return ( isinstance(other, _Interval) - and self.lower_bound == other.lower_bound - and self.upper_bound == other.upper_bound + & jnp.array_equal(self.lower_bound, other.lower_bound) + & jnp.array_equal(self.upper_bound, other.upper_bound) ) def tree_flatten(self): @@ -553,7 +559,7 @@ def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) def __eq__(self, other): - return isinstance(other, _Multinomial) and jnp.array_equal( + return isinstance(other, _Multinomial) & jnp.array_equal( self.upper_bound, other.upper_bound ) diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py index 41c8c1b37..d6f5ba7b9 100644 --- a/numpyro/distributions/flows.py +++ b/numpyro/distributions/flows.py @@ -102,9 +102,9 @@ def tree_flatten(self): def __eq__(self, other): return ( isinstance(other, InverseAutoregressiveTransform) - and self.arn is other.arn - and self.log_scale_min_clip == other.log_scale_min_clip - and self.log_scale_max_clip == other.log_scale_max_clip + & (self.arn is other.arn) + & jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip) + & jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip) ) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 134050386..82dc465a2 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -251,9 +251,9 @@ def tree_flatten(self): def __eq__(self, other): return ( isinstance(other, AffineTransform) - and jnp.array_equal(self.loc, other.loc) - and jnp.array_equal(self.scale, other.scale) - and self.domain == other.domain + & jnp.array_equal(self.loc, other.loc) + & jnp.array_equal(self.scale, other.scale) + & (self.domain == other.domain) ) @@ -363,7 +363,9 @@ def tree_flatten(self): return (self.parts,), (("parts",), {}) def __eq__(self, other): - return isinstance(other, ComposeTransform) and self.parts == other.parts + return isinstance(other, ComposeTransform) & jnp.logical_and( + *(p1 == p2 for p1, p2 in zip(self.parts, other.parts)) + ) def _matrix_forward_shape(shape, offset=0): @@ -533,7 +535,7 @@ def tree_flatten(self): return (self.domain,), (("domain",), dict()) def __eq__(self, other): - return isinstance(other, ExpTransform) and self.domain == other.domain + return isinstance(other, ExpTransform) & (self.domain == other.domain) class IdentityTransform(ParameterFreeTransform): @@ -607,8 +609,8 @@ def tree_flatten(self): def __eq__(self, other): return ( isinstance(other, IndependentTransform) - and self.base_transform == other.base_transform - and self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims + & (self.base_transform == other.base_transform) + & (self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims) ) @@ -724,9 +726,9 @@ def tree_flatten(self): return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) def __eq__(self, other): - return isinstance(other, LowerCholeskyAffine) and ( + return isinstance(other, LowerCholeskyAffine) & ( jnp.array_equal(self.loc, other.loc) - and jnp.array_equal(self.scale_tril, other.scale_tril) + & jnp.array_equal(self.scale_tril, other.scale_tril) ) @@ -860,7 +862,7 @@ def tree_flatten(self): return (self.permutation,), (("permutation",), dict()) def __eq__(self, other): - return isinstance(other, PermuteTransform) and jnp.array_equal( + return isinstance(other, PermuteTransform) & jnp.array_equal( self.permutation, other.permutation ) @@ -891,7 +893,7 @@ def tree_flatten(self): return (self.exponent,), (("exponent",), dict()) def __eq__(self, other): - return isinstance(other, PowerTransform) and jnp.array_equal( + return isinstance(other, PowerTransform) & jnp.array_equal( self.exponent, other.exponent ) @@ -966,7 +968,7 @@ def tree_flatten(self): return (self.anchor_point,), (("anchor_point",), dict()) def __eq__(self, other): - return isinstance(other, SimplexToOrderedTransform) and jnp.array_equal( + return isinstance(other, SimplexToOrderedTransform) & jnp.array_equal( self.anchor_point, other.anchor_point ) diff --git a/test/test_constraints.py b/test/test_constraints.py index 048e58324..3b828ac77 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -5,7 +5,7 @@ import pytest -from jax import jit, vmap +from jax import jit, tree_map, vmap import jax.numpy as jnp from numpyro.distributions import constraints @@ -37,24 +37,30 @@ _a = jnp.asarray -class T(namedtuple("TestCase", ["constraint_cls", "params"])): +class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): pass PARAMETRIZED_CONSTRAINTS = { - "greater_than": T(constraints.greater_than, (_a(0.0),)), - "less_than": T(constraints.less_than, (_a(-1.0),)), + "dependent": T( + type(constraints.dependent), (), dict(is_discrete=True, event_dim=2) + ), + "greater_than": T(constraints.greater_than, (_a(0.0),), dict()), + "less_than": T(constraints.less_than, (_a(-1.0),), dict()), "independent": T( - constraints.independent, (constraints.greater_than(jnp.zeros((2,))), 1) + constraints.independent, + (constraints.greater_than(jnp.zeros((2,))),), + dict(reinterpreted_batch_ndims=1), ), - "integer_interval": T(constraints.integer_interval, (_a(-1), _a(1))), - "integer_greater_than": T(constraints.integer_greater_than, (_a(1),)), - "interval": T(constraints.interval, (_a(-1.0), _a(1.0))), + "integer_interval": T(constraints.integer_interval, (_a(-1), _a(1)), dict()), + "integer_greater_than": T(constraints.integer_greater_than, (_a(1),), dict()), + "interval": T(constraints.interval, (_a(-1.0), _a(1.0)), dict()), "multinomial": T( constraints.multinomial, (_a(1.0),), + dict(), ), - "open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0))), + "open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()), } # TODO: BijectorConstraint @@ -89,12 +95,12 @@ def out_cst(constraint, x): @pytest.mark.parametrize( - "cls, params", + "cls, cst_args, cst_kwargs", PARAMETRIZED_CONSTRAINTS.values(), ids=PARAMETRIZED_CONSTRAINTS.keys(), ) -def test_parametrized_constraint_pytree(cls, params): - constraint = cls(*params) +def test_parametrized_constraint_pytree(cls, cst_args, cst_kwargs): + constraint = cls(*cst_args, **cst_kwargs) # test that singleton constraints objects can be used as pytrees def in_cst(constraint, x): @@ -118,3 +124,24 @@ def out_cst(constraint, x): vmap(out_cst, in_axes=(None, 0), out_axes=None)(constraint, jnp.ones(3)) == constraint ) + + if len(cst_args) > 0: + # test creating and manipulating vmapped constraints + vmapped_cst_args = tree_map(lambda x: x[None], cst_args) + + vmapped_csts = jit(vmap(lambda args: cls(*args, **cst_kwargs), in_axes=(0,)))( + vmapped_cst_args + ) + assert vmap(lambda x: x == constraint, in_axes=0)(vmapped_csts).all() + + twice_vmapped_cst_args = tree_map(lambda x: x[None], vmapped_cst_args) + + vmapped_csts = jit( + vmap( + vmap(lambda args: cls(*args, **cst_kwargs), in_axes=(0,)), + in_axes=(0,), + ), + )(twice_vmapped_cst_args) + assert vmap(vmap(lambda x: x == constraint, in_axes=0), in_axes=0)( + vmapped_csts + ).all() diff --git a/test/test_transforms.py b/test/test_transforms.py index 53b9cf21d..c3e39d320 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple +from functools import partial import pytest -from jax import jit, vmap +from jax import jit, tree_map, vmap import jax.numpy as jnp from numpyro.distributions.flows import ( @@ -49,64 +50,82 @@ def _smoke_neural_network(): return None, None -class T(namedtuple("TestCase", ["transform_cls", "params"])): +class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): pass TRANSFORMS = { - "affine": T(AffineTransform, (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))), - "compose": T(ComposeTransform, ([ExpTransform(), ExpTransform()],)), + "affine": T( + AffineTransform, (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), dict() + ), + "compose": T( + ComposeTransform, + ( + [ + AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), + ExpTransform(), + ], + ), + dict(), + ), "independent": T( IndependentTransform, - (AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), 1), + (AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),), + dict(reinterpreted_batch_ndims=1), ), "lower_cholesky_affine": T( - LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)) + LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)), dict() ), - "permute": T(PermuteTransform, (jnp.array([1, 0]),)), + "permute": T(PermuteTransform, (jnp.array([1, 0]),), dict()), "power": T( PowerTransform, (_a(2.0),), + dict(), ), "simplex_to_ordered": T( SimplexToOrderedTransform, (_a(1.0),), + dict(), ), - "unpack": T(UnpackTransform, (_unpack,)), + "unpack": T(UnpackTransform, (), dict(unpack_fn=_unpack)), # unparametrized transforms - "abs": T(AbsTransform, ()), - "cholesky": T(CholeskyTransform, ()), - "corr_chol": T(CorrCholeskyTransform, ()), - "corr_matrix_chol": T(CorrMatrixCholeskyTransform, ()), - "exp": T(ExpTransform, ()), - "identity": T(IdentityTransform, ()), - "l1_ball": T(L1BallTransform, ()), - "lower_cholesky": T(LowerCholeskyTransform, ()), - "ordered": T(OrderedTransform, ()), - "scaled_unit_lower_cholesky": T(ScaledUnitLowerCholeskyTransform, ()), - "sigmoid": T(SigmoidTransform, ()), - "softplus": T(SoftplusTransform, ()), - "softplus_lower_cholesky": T(SoftplusLowerCholeskyTransform, ()), - "stick_breaking": T(StickBreakingTransform, ()), + "abs": T(AbsTransform, (), dict()), + "cholesky": T(CholeskyTransform, (), dict()), + "corr_chol": T(CorrCholeskyTransform, (), dict()), + "corr_matrix_chol": T(CorrMatrixCholeskyTransform, (), dict()), + "exp": T(ExpTransform, (), dict()), + "identity": T(IdentityTransform, (), dict()), + "l1_ball": T(L1BallTransform, (), dict()), + "lower_cholesky": T(LowerCholeskyTransform, (), dict()), + "ordered": T(OrderedTransform, (), dict()), + "scaled_unit_lower_cholesky": T(ScaledUnitLowerCholeskyTransform, (), dict()), + "sigmoid": T(SigmoidTransform, (), dict()), + "softplus": T(SoftplusTransform, (), dict()), + "softplus_lower_cholesky": T(SoftplusLowerCholeskyTransform, (), dict()), + "stick_breaking": T(StickBreakingTransform, (), dict()), # neural transforms "iaf": T( - InverseAutoregressiveTransform, - (_smoke_neural_network, -1.0, 1.0), + # autoregressive_nn is a non-jittable arg, which does not fit well with + # the current test pipeline, which assumes jittable args, and non-jittable kwargs + partial(InverseAutoregressiveTransform, _smoke_neural_network), + (_a(-1.0), _a(1.0)), + dict(), ), "bna": T( - BlockNeuralAutoregressiveTransform, - (_smoke_neural_network,), + partial(BlockNeuralAutoregressiveTransform, _smoke_neural_network), + (), + dict(), ), } @pytest.mark.parametrize( - "cls, params", + "cls, transform_args, transform_kwargs", TRANSFORMS.values(), ids=TRANSFORMS.keys(), ) -def test_parametrized_transform_pytree(cls, params): - transform = cls(*params) +def test_parametrized_transform_pytree(cls, transform_args, transform_kwargs): + transform = cls(*transform_args, **transform_kwargs) # test that singleton transforms objects can be used as pytrees def in_t(transform, x): @@ -130,3 +149,29 @@ def out_t(transform, x): vmap(out_t, in_axes=(None, 0), out_axes=None)(transform, jnp.ones(3)) == transform ) + + if len(transform_args) > 0: + # test creating and manipulating vmapped constraints + # this test assumes jittable args, and non-jittable kwargs, which is + # not suited for all transforms, see InverseAutoregressiveTransform. + # TODO: split among jittable and non-jittable args/kwargs instead. + vmapped_transform_args = tree_map(lambda x: x[None], transform_args) + + vmapped_transform = jit( + vmap(lambda args: cls(*args, **transform_kwargs), in_axes=(0,)) + )(vmapped_transform_args) + assert vmap(lambda x: x == transform, in_axes=0)(vmapped_transform).all() + + twice_vmapped_transform_args = tree_map( + lambda x: x[None], vmapped_transform_args + ) + + vmapped_transform = jit( + vmap( + vmap(lambda args: cls(*args, **transform_kwargs), in_axes=(0,)), + in_axes=(0,), + ) + )(twice_vmapped_transform_args) + assert vmap(vmap(lambda x: x == transform, in_axes=0), in_axes=0)( + vmapped_transform + ).all() From 49b051cfd15d19abc96cbe7ec6d6566170ea1890 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Tue, 30 May 2023 23:11:10 +0100 Subject: [PATCH 12/14] Make constraints `__eq__` checks robust to arbitrary inputs --- numpyro/distributions/constraints.py | 50 +++++++++++++++------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index f89be66a3..d442a3150 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -284,9 +284,9 @@ def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) def __eq__(self, other): - return isinstance(other, _GreaterThan) & jnp.array_equal( - self.lower_bound, other.lower_bound - ) + if not isinstance(other, _GreaterThan): + return False + return jnp.array_equal(self.lower_bound, other.lower_bound) class _Positive(_SingletonConstraint, _GreaterThan): @@ -353,10 +353,11 @@ def tree_flatten(self): ) def __eq__(self, other): - return ( - isinstance(other, _IndependentConstraint) - & (self.base_constraint == other.base_constraint) - & (self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims) + if not isinstance(other, _IndependentConstraint): + return False + + return (self.base_constraint == other.base_constraint) & ( + self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims ) @@ -389,9 +390,9 @@ def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) def __eq__(self, other): - return isinstance(other, _LessThan) & jnp.array_equal( - self.upper_bound, other.upper_bound - ) + if not isinstance(other, _LessThan): + return False + return jnp.array_equal(self.upper_bound, other.upper_bound) class _IntegerInterval(Constraint): @@ -421,10 +422,11 @@ def tree_flatten(self): ) def __eq__(self, other): - return ( - isinstance(other, _IntegerInterval) - & jnp.array_equal(self.lower_bound, other.lower_bound) - & jnp.array_equal(self.upper_bound, other.upper_bound) + if not isinstance(other, _IntegerInterval): + return False + + return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( + self.upper_bound, other.upper_bound ) @@ -449,9 +451,9 @@ def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) def __eq__(self, other): - return isinstance(other, _IntegerGreaterThan) & jnp.array_equal( - self.lower_bound, other.lower_bound - ) + if not isinstance(other, _IntegerGreaterThan): + return False + return jnp.array_equal(self.lower_bound, other.lower_bound) class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): @@ -485,10 +487,10 @@ def feasible_like(self, prototype): ) def __eq__(self, other): - return ( - isinstance(other, _Interval) - & jnp.array_equal(self.lower_bound, other.lower_bound) - & jnp.array_equal(self.upper_bound, other.upper_bound) + if not isinstance(other, _Interval): + return False + return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( + self.upper_bound, other.upper_bound ) def tree_flatten(self): @@ -559,9 +561,9 @@ def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) def __eq__(self, other): - return isinstance(other, _Multinomial) & jnp.array_equal( - self.upper_bound, other.upper_bound - ) + if not isinstance(other, _Multinomial): + return False + return jnp.array_equal(self.upper_bound, other.upper_bound) class _L1Ball(_SingletonConstraint): From d1c4d7133aba726933f746947345707216f5bc6e Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Tue, 30 May 2023 23:25:02 +0100 Subject: [PATCH 13/14] make transforms equality check robust to arbitrary inputs --- numpyro/distributions/flows.py | 5 +-- numpyro/distributions/transforms.py | 48 ++++++++++++++++------------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py index d6f5ba7b9..cd9b21c35 100644 --- a/numpyro/distributions/flows.py +++ b/numpyro/distributions/flows.py @@ -100,9 +100,10 @@ def tree_flatten(self): ) def __eq__(self, other): + if not isinstance(other, InverseAutoregressiveTransform): + return False return ( - isinstance(other, InverseAutoregressiveTransform) - & (self.arn is other.arn) + (self.arn is other.arn) & jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip) & jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip) ) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 5173a2449..6bd86f14a 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -249,9 +249,10 @@ def tree_flatten(self): return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) def __eq__(self, other): + if not isinstance(other, AffineTransform): + return False return ( - isinstance(other, AffineTransform) - & jnp.array_equal(self.loc, other.loc) + jnp.array_equal(self.loc, other.loc) & jnp.array_equal(self.scale, other.scale) & (self.domain == other.domain) ) @@ -363,9 +364,9 @@ def tree_flatten(self): return (self.parts,), (("parts",), {}) def __eq__(self, other): - return isinstance(other, ComposeTransform) & jnp.logical_and( - *(p1 == p2 for p1, p2 in zip(self.parts, other.parts)) - ) + if not isinstance(other, ComposeTransform): + return False + return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) def _matrix_forward_shape(shape, offset=0): @@ -535,7 +536,9 @@ def tree_flatten(self): return (self.domain,), (("domain",), dict()) def __eq__(self, other): - return isinstance(other, ExpTransform) & (self.domain == other.domain) + if not isinstance(other, ExpTransform): + return False + return self.domain == other.domain class IdentityTransform(ParameterFreeTransform): @@ -607,10 +610,10 @@ def tree_flatten(self): ) def __eq__(self, other): - return ( - isinstance(other, IndependentTransform) - & (self.base_transform == other.base_transform) - & (self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims) + if not isinstance(other, IndependentTransform): + return False + return (self.base_transform == other.base_transform) & ( + self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims ) @@ -726,9 +729,10 @@ def tree_flatten(self): return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) def __eq__(self, other): - return isinstance(other, LowerCholeskyAffine) & ( - jnp.array_equal(self.loc, other.loc) - & jnp.array_equal(self.scale_tril, other.scale_tril) + if not isinstance(other, LowerCholeskyAffine): + return False + return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( + self.scale_tril, other.scale_tril ) @@ -862,9 +866,9 @@ def tree_flatten(self): return (self.permutation,), (("permutation",), dict()) def __eq__(self, other): - return isinstance(other, PermuteTransform) & jnp.array_equal( - self.permutation, other.permutation - ) + if not isinstance(other, PermuteTransform): + return False + return jnp.array_equal(self.permutation, other.permutation) class PowerTransform(Transform): @@ -893,9 +897,9 @@ def tree_flatten(self): return (self.exponent,), (("exponent",), dict()) def __eq__(self, other): - return isinstance(other, PowerTransform) & jnp.array_equal( - self.exponent, other.exponent - ) + if not isinstance(other, PowerTransform): + return False + return jnp.array_equal(self.exponent, other.exponent) class SigmoidTransform(ParameterFreeTransform): @@ -968,9 +972,9 @@ def tree_flatten(self): return (self.anchor_point,), (("anchor_point",), dict()) def __eq__(self, other): - return isinstance(other, SimplexToOrderedTransform) & jnp.array_equal( - self.anchor_point, other.anchor_point - ) + if not isinstance(other, SimplexToOrderedTransform): + return False + return jnp.array_equal(self.anchor_point, other.anchor_point) def forward_shape(self, shape): return shape[:-1] + (shape[-1] - 1,) From 9584a4a273ec44d72e3e99e2641424bca241503c Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Tue, 30 May 2023 23:25:33 +0100 Subject: [PATCH 14/14] test constraints and transforms equality checks --- test/test_constraints.py | 36 ++++++++++++++++++++++++++++++++++++ test/test_transforms.py | 20 ++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/test/test_constraints.py b/test/test_constraints.py index 3b828ac77..0b9d23f13 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -145,3 +145,39 @@ def out_cst(constraint, x): assert vmap(vmap(lambda x: x == constraint, in_axes=0), in_axes=0)( vmapped_csts ).all() + + +@pytest.mark.parametrize( + "cls, cst_args, cst_kwargs", + PARAMETRIZED_CONSTRAINTS.values(), + ids=PARAMETRIZED_CONSTRAINTS.keys(), +) +def test_parametrized_constraint_eq(cls, cst_args, cst_kwargs): + constraint = cls(*cst_args, **cst_kwargs) + constraint2 = cls(*cst_args, **cst_kwargs) + assert constraint == constraint2 + assert constraint != 1 + + # check that equality checks are robust to constraints parametrized + # by abstract values + @jit + def check_constraints(c1, c2): + return c1 == c2 + + assert check_constraints(constraint, constraint2) + + +@pytest.mark.parametrize( + "constraint", SINGLETON_CONSTRAINTS.values(), ids=SINGLETON_CONSTRAINTS.keys() +) +def test_singleton_constraint_eq(constraint): + assert constraint == constraint + assert constraint != 1 + + # check that equality checks are robust to constraints parametrized + # by abstract values + @jit + def check_constraints(c1, c2): + return c1 == c2 + + assert check_constraints(constraint, constraint) diff --git a/test/test_transforms.py b/test/test_transforms.py index c3e39d320..54316dd88 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -175,3 +175,23 @@ def out_t(transform, x): assert vmap(vmap(lambda x: x == transform, in_axes=0), in_axes=0)( vmapped_transform ).all() + + +@pytest.mark.parametrize( + "cls, transform_args, transform_kwargs", + TRANSFORMS.values(), + ids=TRANSFORMS.keys(), +) +def test_parametrized_transform_eq(cls, transform_args, transform_kwargs): + transform = cls(*transform_args, **transform_kwargs) + transform2 = cls(*transform_args, **transform_kwargs) + assert transform == transform2 + assert transform != 1.0 + + # check that equality checks are robust to transforms parametrized + # by abstract values + @jit + def check_transforms(t1, t2): + return t1 == t2 + + assert check_transforms(transform, transform2)