From e431b6ea172ffbed41acc45b247bf1ef0a6eaba4 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Wed, 10 Jul 2024 18:42:10 +0100 Subject: [PATCH] Allow uneven interval in spline --- .../bijections/rational_quadratic_spline.py | 25 +++++++++------- flowjax/flows.py | 4 +-- flowjax/wrappers.py | 2 +- pyproject.toml | 2 +- .../test_rational_quadratic_spline.py | 30 ++++++++++++++----- 5 files changed, 40 insertions(+), 23 deletions(-) diff --git a/flowjax/bijections/rational_quadratic_spline.py b/flowjax/bijections/rational_quadratic_spline.py index 000a5a23..f7d71423 100644 --- a/flowjax/bijections/rational_quadratic_spline.py +++ b/flowjax/bijections/rational_quadratic_spline.py @@ -13,7 +13,7 @@ def _real_to_increasing_on_interval( arr: Float[Array, " dim"], - interval: float | int = 1, + interval: tuple[int | float, int | float], softmax_adjust: float = 1e-2, *, pad_with_ends: bool = True, @@ -35,10 +35,11 @@ def _real_to_increasing_on_interval( widths = jax.nn.softmax(arr) widths = (widths + softmax_adjust / widths.size) / (1 + softmax_adjust) widths = widths.at[0].set(widths[0] / 2) - pos = 2 * interval * jnp.cumsum(widths) - interval + scale = interval[1] - interval[0] + pos = interval[0] + scale * jnp.cumsum(widths) if pad_with_ends: - pos = jnp.pad(pos, pad_width=1, constant_values=(-interval, interval)) + pos = jnp.pad(pos, pad_width=1, constant_values=interval) return pos @@ -48,7 +49,8 @@ class RationalQuadraticSpline(AbstractBijection): Args: knots: Number of knots. - interval: interval to transform, [-interval, interval]. + interval: Interval to transform, if a scalar value, uses [-interval, interval], + if a tuple, uses [interval[0], interval[1]] min_derivative: Minimum dervivative. Defaults to 1e-3. softmax_adjust: Controls minimum bin width and height by rescaling softmax output, e.g. 
0=no adjustment, 1=average softmax output with evenly spaced @@ -57,7 +59,7 @@ class RationalQuadraticSpline(AbstractBijection): """ knots: int - interval: float | int + interval: tuple[int | float, int | float] softmax_adjust: float | int min_derivative: float x_pos: Array | wrappers.AbstractUnwrappable[Array] @@ -70,11 +72,12 @@ def __init__( self, *, knots: int, - interval: float | int, + interval: float | int | tuple[int | float, int | float], min_derivative: float = 1e-3, softmax_adjust: float | int = 1e-2, ): self.knots = knots + interval = interval if isinstance(interval, tuple) else (-interval, interval) self.interval = interval self.softmax_adjust = softmax_adjust self.min_derivative = min_derivative @@ -96,7 +99,7 @@ def __init__( def transform(self, x, condition=None): # Following notation from the paper x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives - in_bounds = jnp.logical_and(x >= -self.interval, x <= self.interval) + in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1]) x_robust = jnp.where(in_bounds, x, 0) # To avoid nans k = jnp.searchsorted(x_pos, x_robust) - 1 # k is bin number xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k]) @@ -107,7 +110,7 @@ def transform(self, x, condition=None): y = yk + num / den # eq. 
4 # avoid numerical precision issues transforming from in -> out of bounds - y = jnp.clip(y, -self.interval, self.interval) + y = jnp.clip(y, self.interval[0], self.interval[1]) return jnp.where(in_bounds, y, x) def transform_and_log_det(self, x, condition=None): @@ -118,7 +121,7 @@ def transform_and_log_det(self, x, condition=None): def inverse(self, y, condition=None): # Following notation from the paper x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives - in_bounds = jnp.logical_and(y >= -self.interval, y <= self.interval) + in_bounds = jnp.logical_and(y >= self.interval[0], y <= self.interval[1]) y_robust = jnp.where(in_bounds, y, 0) # To avoid nans k = jnp.searchsorted(y_pos, y_robust) - 1 xk, xk1, yk, yk1 = x_pos[k], x_pos[k + 1], y_pos[k], y_pos[k + 1] @@ -134,7 +137,7 @@ def inverse(self, y, condition=None): x = xi * (xk1 - xk) + xk # avoid numerical precision issues transforming from in -> out of bounds - x = jnp.clip(x, -self.interval, self.interval) + x = jnp.clip(x, self.interval[0], self.interval[1]) return jnp.where(in_bounds, x, y) def inverse_and_log_det(self, y, condition=None): @@ -146,7 +149,7 @@ def derivative(self, x) -> Array: """The derivative dy/dx of the forward transformation.""" # Following notation from the paper (eq. 
5) x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives - in_bounds = jnp.logical_and(x >= -self.interval, x <= self.interval) + in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1]) x_robust = jnp.where(in_bounds, x, 0) # To avoid nans k = jnp.searchsorted(x_pos, x_robust) - 1 xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k]) diff --git a/flowjax/flows.py b/flowjax/flows.py index 1af3e70b..615b4af0 100644 --- a/flowjax/flows.py +++ b/flowjax/flows.py @@ -38,11 +38,11 @@ Vmap, ) from flowjax.distributions import AbstractDistribution, Transformed -from flowjax.wrappers import BijectionReparam, NonTrainable, WeightNormalization +from flowjax.wrappers import BijectionReparam, WeightNormalization, non_trainable def _affine_with_min_scale(min_scale: float = 1e-2) -> Affine: - scale_reparam = Chain([SoftPlus(), NonTrainable(Loc(min_scale))]) + scale_reparam = Chain([SoftPlus(), non_trainable(Loc(min_scale))]) return eqx.tree_at( where=lambda aff: aff.scale, pytree=Affine(), diff --git a/flowjax/wrappers.py b/flowjax/wrappers.py index f03b1275..ca1716e3 100644 --- a/flowjax/wrappers.py +++ b/flowjax/wrappers.py @@ -140,7 +140,7 @@ def _apply_inverse_and_check_valid(bijection, arr): jnp.logical_and(jnp.isfinite(arr), ~jnp.isfinite(param_inv)), "Non-finite value(s) introduced when reparameterizing. 
This suggests " "the parameter vector passed to BijectionReparam was incompatible with " - f"the bijection used for reparmeterizing ({type(bijection).__name__}).", + f"the bijection used for reparameterizing ({type(bijection).__name__}).", ) diff --git a/pyproject.toml b/pyproject.toml index 0fd0e4af..0830ac50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ license = { file = "LICENSE" } name = "flowjax" readme = "README.md" requires-python = ">=3.10" -version = "12.4.0" +version = "12.5.0" [project.urls] repository = "https://github.com/danielward27/flowjax" diff --git a/tests/test_bijections/test_rational_quadratic_spline.py b/tests/test_bijections/test_rational_quadratic_spline.py index 33e19ddc..905db3b2 100644 --- a/tests/test_bijections/test_rational_quadratic_spline.py +++ b/tests/test_bijections/test_rational_quadratic_spline.py @@ -6,12 +6,13 @@ from jax.tree_util import tree_map from flowjax.bijections import RationalQuadraticSpline +from flowjax.bijections.rational_quadratic_spline import _real_to_increasing_on_interval -def test_RationalQuadraticSpline_tails(): +@pytest.mark.parametrize("interval", [3, (-4, 5)]) +def test_RationalQuadraticSpline_tails(interval): key = jr.PRNGKey(0) - x = jnp.array([-20, 0.1, 2, 20]) - spline = RationalQuadraticSpline(knots=10, interval=3) + spline = RationalQuadraticSpline(knots=10, interval=interval) # Change to random initialisation, rather than identity. 
spline = tree_map( @@ -19,14 +20,27 @@ def test_RationalQuadraticSpline_tails(): spline, ) + x = jr.uniform(key, (5,), minval=spline.interval[0], maxval=spline.interval[1]) y = vmap(spline.transform)(x) - expected_changed = jnp.array([True, False, False, True]) # identity padding - assert ((jnp.abs(y - x) <= 1e-5) == expected_changed).all() + assert pytest.approx(x, abs=1e-5) != y + + # Outside interval, default to identity + x = jnp.array([spline.interval[0] - 1, spline.interval[1] + 1]) + y = vmap(spline.transform)(x) + assert pytest.approx(x, abs=1e-5) == y -def test_RationalQuadraticSpline_init(): +@pytest.mark.parametrize("interval", [3, (-4, 5)]) +def test_RationalQuadraticSpline_init(interval): # Test it is initialized at the identity - x = jnp.array([-1, 0.1, 2, 1]) - spline = RationalQuadraticSpline(knots=10, interval=3) + x = jnp.array([-7, 0.1, 2, 1]) + spline = RationalQuadraticSpline(knots=10, interval=interval) y = vmap(spline.transform)(x) assert pytest.approx(x, abs=1e-6) == y + + +def test_real_to_increasing_on_interval(): + y = _real_to_increasing_on_interval(jnp.array([-3.0, -4, 5, 0, 0]), (-3, 7)) + assert y.max() == 7 + assert y.min() == -3 + assert jnp.all(jnp.diff(y) > 0)