Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jittable transforms #1575

Merged
merged 15 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def __call__(self, x):
def codomain(self):
return _get_codomain(self.bijector)

def tree_flatten(self):
return self.bijector, ()

@classmethod
def tree_unflatten(cls, _, bijector):
return cls(bijector)


class BijectorTransform(Transform):
"""
Expand Down Expand Up @@ -106,6 +113,13 @@ def inverse_shape(self, shape):
batch_shape = shape[: len(shape) - len(out_event_shape)]
return batch_shape + in_shape

def tree_flatten(self):
return self.bijector, ()

@classmethod
def tree_unflatten(cls, _, bijector):
return cls(bijector)


@biject_to.register(BijectorConstraint)
def _transform_to_bijector_constraint(constraint):
Expand Down
125 changes: 114 additions & 11 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
import numpy as np

import jax.numpy
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Constraint(object):
Expand All @@ -75,6 +77,10 @@ class Constraint(object):
is_discrete = False
event_dim = 0

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

def __call__(self, x):
raise NotImplementedError

Expand All @@ -94,8 +100,24 @@ def feasible_like(self, prototype):
"""
raise NotImplementedError

@classmethod
def tree_unflatten(cls, aux_data, params):
params_keys, aux_data = aux_data
self = cls.__new__(cls)
for k, v in zip(params_keys, params):
setattr(self, k, v)

for k, v in aux_data.items():
setattr(self, k, v)
return self


class ParameterFreeConstraint(Constraint):
def tree_flatten(self):
return (), ((), dict())


class _SingletonConstraint(Constraint):
class _SingletonConstraint(ParameterFreeConstraint):
"""
A constraint type which has only one canonical instance, like constraints.real,
and unlike constraints.interval.
Expand Down Expand Up @@ -202,8 +224,23 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement
event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

def __eq__(self, other):
return (
type(self) is type(other)
and self._is_discrete == other._is_discrete
and self._event_dim == other._event_dim
)

def tree_flatten(self):
return (), (
(),
dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim),
)


class dependent_property(property, _Dependent):
# XXX: this should not need to be pytree-able since it simply wraps a method
# and thus is automatically present once the method's object is created
def __init__(
self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
):
Expand Down Expand Up @@ -243,8 +280,16 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.lower_bound,), (("lower_bound",), dict())

def __eq__(self, other):
if not isinstance(other, _GreaterThan):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound)

class _Positive(_GreaterThan, _SingletonConstraint):

class _Positive(_SingletonConstraint, _GreaterThan):
def __init__(self):
super().__init__(0.0)
pierreglaser marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -301,6 +346,20 @@ def __repr__(self):
def feasible_like(self, prototype):
return self.base_constraint.feasible_like(prototype)

def tree_flatten(self):
return (self.base_constraint,), (
("base_constraint",),
{"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims},
)

def __eq__(self, other):
if not isinstance(other, _IndependentConstraint):
return False

return (self.base_constraint == other.base_constraint) & (
self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims
)


class _RealVector(_IndependentConstraint, _SingletonConstraint):
def __init__(self):
Expand All @@ -327,6 +386,14 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.upper_bound,), (("upper_bound",), dict())

def __eq__(self, other):
if not isinstance(other, _LessThan):
return False
return jnp.array_equal(self.upper_bound, other.upper_bound)


class _IntegerInterval(Constraint):
is_discrete = True
Expand All @@ -348,6 +415,20 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.lower_bound, self.upper_bound), (
("lower_bound", "upper_bound"),
dict(),
)

def __eq__(self, other):
if not isinstance(other, _IntegerInterval):
return False

return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal(
self.upper_bound, other.upper_bound
)


class _IntegerGreaterThan(Constraint):
is_discrete = True
Expand All @@ -366,13 +447,21 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.lower_bound,), (("lower_bound",), dict())

class _IntegerPositive(_IntegerGreaterThan, _SingletonConstraint):
def __eq__(self, other):
if not isinstance(other, _IntegerGreaterThan):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound)


class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan):
def __init__(self):
super().__init__(1)


class _IntegerNonnegative(_IntegerGreaterThan, _SingletonConstraint):
class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan):
def __init__(self):
super().__init__(0)

Expand All @@ -398,19 +487,25 @@ def feasible_like(self, prototype):
)

def __eq__(self, other):
return (
isinstance(other, _Interval)
and self.lower_bound == other.lower_bound
and self.upper_bound == other.upper_bound
if not isinstance(other, _Interval):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal(
self.upper_bound, other.upper_bound
)

def tree_flatten(self):
return (self.lower_bound, self.upper_bound), (
("lower_bound", "upper_bound"),
dict(),
)

class _Circular(_Interval, _SingletonConstraint):

class _Circular(_SingletonConstraint, _Interval):
def __init__(self):
super().__init__(-math.pi, math.pi)


class _UnitInterval(_Interval, _SingletonConstraint):
class _UnitInterval(_SingletonConstraint, _Interval):
def __init__(self):
super().__init__(0.0, 1.0)

Expand Down Expand Up @@ -462,6 +557,14 @@ def feasible_like(self, prototype):
value = jax.numpy.pad(jax.numpy.expand_dims(self.upper_bound, -1), pad_width)
return jax.numpy.broadcast_to(value, prototype.shape)

def tree_flatten(self):
return (self.upper_bound,), (("upper_bound",), dict())

def __eq__(self, other):
if not isinstance(other, _Multinomial):
return False
return jnp.array_equal(self.upper_bound, other.upper_bound)


class _L1Ball(_SingletonConstraint):
"""
Expand Down Expand Up @@ -546,7 +649,7 @@ def feasible_like(self, prototype):
return jax.numpy.full_like(prototype, 1 / prototype.shape[-1])


class _SoftplusPositive(_GreaterThan, _SingletonConstraint):
class _SoftplusPositive(_SingletonConstraint, _GreaterThan):
def __init__(self):
super().__init__(lower_bound=0.0)

Expand Down
24 changes: 24 additions & 0 deletions numpyro/distributions/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,21 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
log_scale = intermediates
return log_scale.sum(-1)

def tree_flatten(self):
return (self.log_scale_min_clip, self.log_scale_max_clip), (
("log_scale_min_clip", "log_scale_max_clip"),
{"arn": self.arn},
)

def __eq__(self, other):
if not isinstance(other, InverseAutoregressiveTransform):
return False
return (
(self.arn is other.arn)
& jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip)
& jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip)
)


class BlockNeuralAutoregressiveTransform(Transform):
"""
Expand Down Expand Up @@ -139,3 +154,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
else:
logdet = intermediates
return logdet.sum(-1)

def tree_flatten(self):
return (), ((), {"bn_arn": self.bn_arn})

def __eq__(self, other):
return (
isinstance(other, BlockNeuralAutoregressiveTransform)
and self.bn_arn is other.bn_arn
)
Loading