From e0408226b8ccab9613261692fd9c3ee706bfbe10 Mon Sep 17 00:00:00 2001
From: danielward27
Date: Wed, 4 Dec 2024 12:48:27 +0000
Subject: [PATCH 1/2] Remove fit_to_variational_target and rename Partial

---
 flowjax/bijections/__init__.py                |  4 +-
 flowjax/bijections/utils.py                   |  9 +--
 flowjax/train/__init__.py                     |  3 +-
 flowjax/train/variational_fit.py              | 81 -------------------
 tests/test_bijections/test_bijection_utils.py |  4 +-
 tests/test_train/test_variational_fit.py      | 42 ----------
 6 files changed, 9 insertions(+), 134 deletions(-)
 delete mode 100644 flowjax/train/variational_fit.py
 delete mode 100644 tests/test_train/test_variational_fit.py

diff --git a/flowjax/bijections/__init__.py b/flowjax/bijections/__init__.py
index fc9e7002..916e1219 100644
--- a/flowjax/bijections/__init__.py
+++ b/flowjax/bijections/__init__.py
@@ -19,9 +19,9 @@
     EmbedCondition,
     Flip,
     Identity,
+    Indexed,
     Invert,
     NumericalInverse,
-    Partial,
     Permute,
     Reshape,
 )
@@ -42,7 +42,7 @@
     "LeakyTanh",
     "Loc",
     "MaskedAutoregressive",
-    "Partial",
+    "Indexed",
     "Permute",
     "Power",
     "Planar",
diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py
index 3daa806b..50f26a04 100644
--- a/flowjax/bijections/utils.py
+++ b/flowjax/bijections/utils.py
@@ -102,14 +102,13 @@ def inverse_and_log_det(self, y, condition=None):
         return jnp.flip(y), jnp.array(0)
 
 
-class Partial(AbstractBijection):  # TODO rename to avoid confusion with functools
+class Indexed(AbstractBijection):
     """Applies bijection to specific indices of an input.
 
     Args:
-        bijection: Bijection that is compatible with the subset
-            of x indexed by idxs. idxs: Indices (Integer, a slice, or an ndarray
-            with integer/bool dtype) of the transformed portion.
-        idxs: The indexes to transform.
+        bijection: Bijection that is compatible with the subset of x indexed by idxs.
+        idxs: Indices (Integer, a slice, or an ndarray with integer/bool dtype) of the
+            transformed portion.
         shape: Shape of the bijection. Defaults to None.
     """
 
diff --git a/flowjax/train/__init__.py b/flowjax/train/__init__.py
index 5bef3da1..700e29ca 100644
--- a/flowjax/train/__init__.py
+++ b/flowjax/train/__init__.py
@@ -2,6 +2,5 @@
 
 from .loops import fit_to_data, fit_to_key_based_loss
 from .train_utils import step
-from .variational_fit import fit_to_variational_target
 
-__all__ = ["fit_to_key_based_loss", "fit_to_data", "fit_to_variational_target", "step"]
+__all__ = ["fit_to_key_based_loss", "fit_to_data", "step"]
diff --git a/flowjax/train/variational_fit.py b/flowjax/train/variational_fit.py
deleted file mode 100644
index c6f99f81..00000000
--- a/flowjax/train/variational_fit.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""Basic training script for fitting a flow using variational inference."""
-
-import warnings
-from collections.abc import Callable
-
-import equinox as eqx
-import jax.random as jr
-import optax
-import paramax
-from jaxtyping import PRNGKeyArray, PyTree
-from tqdm import tqdm
-
-from flowjax.train.train_utils import step
-
-
-def fit_to_variational_target(
-    key: PRNGKeyArray,
-    dist: PyTree,  # Custom losses may support broader types than AbstractDistribution
-    loss_fn: Callable,
-    *,
-    steps: int = 100,
-    learning_rate: float = 5e-4,
-    optimizer: optax.GradientTransformation | None = None,
-    return_best: bool = True,
-    show_progress: bool = True,
-) -> tuple[PyTree, list]:
-    """Train a distribution (e.g. a flow) by variational inference.
-
-    Args:
-        key: Jax key.
-        dist: Distribution object, trainable parameters are found using
-            equinox.is_inexact_array.
-        loss_fn: The loss function to optimize (e.g. the ElboLoss).
-        steps: The number of training steps to run. Defaults to 100.
-        learning_rate: Learning rate. Defaults to 5e-4.
-        optimizer: Optax optimizer. If provided, this overrides the default Adam
-            optimizer, and the learning_rate is ignored. Defaults to None.
-        return_best: Whether the result should use the parameters where the minimum loss
-            was reached (when True), or the parameters after the last update (when
-            False). Defaults to True.
-        show_progress: Whether to show progress bar. Defaults to True.
-
-    Returns:
-        A tuple containing the trained distribution and the losses.
-    """
-    warnings.warn(
-        "This function will be deprecated in 17.0.0. Please switch to using "
-        "``flowjax.train.loops.fit_to_key_based_loss``.",
-        DeprecationWarning,
-        stacklevel=2,
-    )  # TODO deprecate
-    if optimizer is None:
-        optimizer = optax.adam(learning_rate)
-
-    params, static = eqx.partition(
-        dist,
-        eqx.is_inexact_array,
-        is_leaf=lambda leaf: isinstance(leaf, paramax.NonTrainable),
-    )
-    opt_state = optimizer.init(params)
-
-    losses = []
-
-    best_params = params
-    keys = tqdm(jr.split(key, steps), disable=not show_progress)
-
-    for key in keys:
-        params, opt_state, loss = step(
-            params,
-            static,
-            key=key,
-            optimizer=optimizer,
-            opt_state=opt_state,
-            loss_fn=loss_fn,
-        )
-        losses.append(loss.item())
-        keys.set_postfix({"loss": loss.item()})
-        if loss.item() == min(losses):
-            best_params = params
-    params = best_params if return_best else params
-    return eqx.combine(params, static), losses
diff --git a/tests/test_bijections/test_bijection_utils.py b/tests/test_bijections/test_bijection_utils.py
index 5081c410..f47afd64 100644
--- a/tests/test_bijections/test_bijection_utils.py
+++ b/tests/test_bijections/test_bijection_utils.py
@@ -2,7 +2,7 @@
 import pytest
 from equinox import EquinoxRuntimeError
 
-from flowjax.bijections import Affine, Partial, Permute
+from flowjax.bijections import Affine, Indexed, Permute
 
 test_cases = {
     # name: idx, expected
@@ -24,7 +24,7 @@ def test_partial(idx, expected):
     "Check values only change where we expect."
     x = jnp.zeros(4)
     shape = x[idx].shape
-    bijection = Partial(Affine(jnp.ones(shape)), idx, x.shape)
+    bijection = Indexed(Affine(jnp.ones(shape)), idx, x.shape)
     y = bijection.transform(x)
     assert jnp.all((x != y) == expected)
 
diff --git a/tests/test_train/test_variational_fit.py b/tests/test_train/test_variational_fit.py
deleted file mode 100644
index b13531c8..00000000
--- a/tests/test_train/test_variational_fit.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import equinox as eqx
-import jax.numpy as jnp
-import jax.random as jr
-import pytest
-
-from flowjax.distributions import Normal, StandardNormal
-from flowjax.train.losses import ElboLoss
-from flowjax.train.variational_fit import fit_to_variational_target
-
-test_shapes = [(), (2,), (2, 3, 4)]
-
-
-@pytest.mark.parametrize("shape", test_shapes)
-def test_elbo_loss(shape):
-    "Check finite scaler loss."
-    target = StandardNormal(shape)
-    vi_dist = StandardNormal(shape)
-    loss = ElboLoss(target.log_prob, num_samples=100)
-    loss_val = loss(*eqx.partition(vi_dist, eqx.is_inexact_array), jr.key(0))
-    assert loss_val.shape == ()  # expect scalar loss
-    assert jnp.isfinite(loss_val)  # expect finite loss
-
-
-@pytest.mark.parametrize("shape", test_shapes)
-def test_fit_to_variational_target(shape):
-    "Check that loss decreases."
-    vi_dist = Normal(jnp.ones(shape))
-    target_dist = StandardNormal(shape)
-
-    loss = ElboLoss(target_dist.log_prob, 50)
-
-    vi_dist, losses = fit_to_variational_target(
-        key=jr.key(0),
-        dist=vi_dist,
-        loss_fn=loss,
-        show_progress=False,
-        learning_rate=0.1,
-    )
-    # We expect the loss to be decreasing
-    start, end = jnp.split(jnp.array(losses), 2)
-    assert jnp.mean(start) > jnp.mean(end)
-    assert isinstance(losses[0], float)

From 4159b8b73c9b0ff3d3dbd13ff9b7f2ad1700d2cf Mon Sep 17 00:00:00 2001
From: danielward27
Date: Wed, 4 Dec 2024 12:53:33 +0000
Subject: [PATCH 2/2] Update tests

---
 tests/test_bijections/test_bijections.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/test_bijections/test_bijections.py b/tests/test_bijections/test_bijections.py
index 258ca62f..24de4d96 100644
--- a/tests/test_bijections/test_bijections.py
+++ b/tests/test_bijections/test_bijections.py
@@ -20,11 +20,11 @@
     Exp,
     Flip,
     Identity,
+    Indexed,
     LeakyTanh,
     Loc,
     MaskedAutoregressive,
     NumericalInverse,
-    Partial,
     Permute,
     Planar,
     Power,
@@ -54,14 +54,14 @@
     "Permute (3D)": lambda: Permute(
         jnp.reshape(jr.permutation(KEY, jnp.arange(2 * 3 * 4)), (2, 3, 4)),
     ),
-    "Partial (int)": lambda: Partial(Affine(jnp.array(2), jnp.array(2)), 0, (DIM,)),
-    "Partial (bool array)": lambda: Partial(
+    "Indexed (int)": lambda: Indexed(Affine(jnp.array(2), jnp.array(2)), 0, (DIM,)),
+    "Indexed (bool array)": lambda: Indexed(
         Flip((2,)),
         jnp.array([True, False, True]),
         (DIM,),
     ),
-    "Partial (int array)": lambda: Partial(Flip((2,)), jnp.array([0, 2]), (DIM,)),
-    "Partial (slice)": lambda: Partial(Affine(jnp.zeros(2)), slice(0, 2), (DIM,)),
+    "Indexed (int array)": lambda: Indexed(Flip((2,)), jnp.array([0, 2]), (DIM,)),
+    "Indexed (slice)": lambda: Indexed(Affine(jnp.zeros(2)), slice(0, 2), (DIM,)),
     "Affine": lambda: Affine(jnp.ones(DIM), jnp.full(DIM, 2)),
     "Affine (pos and neg scales)": lambda: eqx.tree_at(
         lambda aff: aff.scale, Affine(scale=jnp.ones(3)), jnp.array([-1, 1, -2])
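
Migration note for downstream users: ``Partial`` becomes ``Indexed`` with unchanged
(bijection, idxs, shape) arguments, and ``fit_to_variational_target`` is superseded by
``flowjax.train.loops.fit_to_key_based_loss``, per the deprecation warning in the removed
module. A minimal sketch follows, assuming fit_to_key_based_loss mirrors the removed
function's keywords (loss_fn, steps, learning_rate, show_progress) and its
(trained pytree, losses) return value; check flowjax.train.loops for the authoritative
signature.

import jax.numpy as jnp
import jax.random as jr

from flowjax.bijections import Affine, Indexed
from flowjax.distributions import Normal, StandardNormal
from flowjax.train import fit_to_key_based_loss  # exported per patch 1
from flowjax.train.losses import ElboLoss

# Partial -> Indexed: identical constructor arguments.
# Applies Affine to elements 0 and 2 of a length-3 input.
bijection = Indexed(Affine(jnp.ones(2)), jnp.array([0, 2]), (3,))

# fit_to_variational_target -> fit_to_key_based_loss.
vi_dist = Normal(jnp.ones(2))
loss = ElboLoss(StandardNormal((2,)).log_prob, num_samples=50)
vi_dist, losses = fit_to_key_based_loss(
    jr.key(0),
    vi_dist,  # previously the ``dist`` argument
    loss_fn=loss,
    steps=100,
    learning_rate=0.1,  # keywords assumed to carry over from the removed function
    show_progress=False,
)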