diff --git a/docs/api/bijections.rst b/docs/api/bijections.rst index 9d4859a1..6157d0ad 100644 --- a/docs/api/bijections.rst +++ b/docs/api/bijections.rst @@ -5,3 +5,4 @@ Bijections :members: :undoc-members: :show-inheritance: + :member-order: groupwise diff --git a/docs/api/losses.rst b/docs/api/losses.rst index 083b993f..f9b8965f 100644 --- a/docs/api/losses.rst +++ b/docs/api/losses.rst @@ -5,4 +5,3 @@ Loss functions from ``flowjax.train.losses``. .. automodule:: flowjax.train.losses :members: :undoc-members: - :show-inheritance: diff --git a/docs/api/training.rst b/docs/api/training.rst index 82f0abf8..88383c91 100644 --- a/docs/api/training.rst +++ b/docs/api/training.rst @@ -9,4 +9,4 @@ corresponding conditioning variables if appropriate), we can use ``fit_to_data`` Alternatively, we can use ``fit_to_variational_target`` to fit the flow to a function using variational inference. -.. autofunction:: flowjax.train.fit_to_variational_target \ No newline at end of file +.. autofunction:: flowjax.train.fit_to_variational_target diff --git a/docs/conf.py b/docs/conf.py index e380a354..8531079a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,7 +1,11 @@ """Configuration file for the Sphinx documentation builder.""" + +import builtins import sys from pathlib import Path +builtins.GENERATING_DOCUMENTATION = True # For processing ArrayLike + import jax # noqa Required to avoid circular import sys.path.insert(0, Path("..").resolve()) @@ -22,17 +26,23 @@ extensions = [ "sphinx.ext.viewcode", "sphinx.ext.autodoc", - "sphinx.ext.napoleon", "sphinx.ext.doctest", + "sphinx.ext.intersphinx", "nbsphinx", "sphinx_copybutton", + "sphinx.ext.napoleon", + "sphinx_autodoc_typehints", ] +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "jax": ("https://jax.readthedocs.io/en/latest/", None), +} + templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -add_module_names = False -napoleon_include_init_with_doc = False +# napoleon_include_init_with_doc = False # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output @@ -50,8 +60,12 @@ } pygments_style = "xcode" -autodoc_typehints = "none" -autodoc_member_order = "bysource" copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True + +napoleon_use_rtype = False +napoleon_attr_annotations = True + +autodoc_type_aliases = {"ArrayLike": "ArrayLike"} +add_module_names = False diff --git a/docs/faq.rst b/docs/faq.rst index f6eb3868..749ad64d 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,7 +4,7 @@ FAQ Freezing parameters ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Often it is useful to not train particular parameters. To achieve this we can provide a -``filter_spec`` to :py:func:`~flowjax.train.data_fit.fit_to_data`. For example, to avoid +``filter_spec`` to :func:`~flowjax.train.fit_to_data`. For example, to avoid training the base distribution, we could create a ``filter_spec`` as follows ..
testsetup:: diff --git a/flowjax/__init__.py b/flowjax/__init__.py index 05b31752..b528404b 100644 --- a/flowjax/__init__.py +++ b/flowjax/__init__.py @@ -1,5 +1,4 @@ """flowjax - Basic flowjax implementation in jax.""" - from importlib.metadata import version __version__ = version("flowjax") diff --git a/flowjax/_custom_types.py b/flowjax/_custom_types.py new file mode 100644 index 00000000..44130548 --- /dev/null +++ b/flowjax/_custom_types.py @@ -0,0 +1,13 @@ +# We do this for now due to an incompatibility between equinox abstract class +# extensions and the documentation generator sphinx +# https://github.com/patrick-kidger/equinox/issues/591. This will likely be fixable with +# https://peps.python.org/pep-0649/ in python 3.13 +import builtins + +if getattr(builtins, "GENERATING_DOCUMENTATION", False): + + class ArrayLike: + pass + +else: + from jaxtyping import ArrayLike # noqa: F401 diff --git a/flowjax/bijections/affine.py b/flowjax/bijections/affine.py index 2559d906..47e2e288 100644 --- a/flowjax/bijections/affine.py +++ b/flowjax/bijections/affine.py @@ -1,4 +1,5 @@ """Affine bijections.""" +from __future__ import annotations from collections.abc import Callable from typing import ClassVar @@ -20,12 +21,11 @@ class Affine(AbstractBijection): ``loc`` and ``scale`` should broadcast to the desired shape of the bijection. Args: - loc (ArrayLike): Location parameter. Defaults to 0. - scale (ArrayLike): Scale parameter. Defaults to 1. - positivity_constraint (AbstractBijection | None): Bijection with shape - matching the Affine bijection, that maps the scale parameter from an - unbounded domain to the positive domain. Defaults to - :class:`~flowjax.bijections.SoftPlus`. + loc: Location parameter. Defaults to 0. + scale: Scale parameter. Defaults to 1. + positivity_constraint: Bijection with shape matching the Affine bijection, that + maps the scale parameter from an unbounded domain to the positive domain. + Defaults to :class:`~flowjax.bijections.SoftPlus`. """ shape: tuple[int, ...] @@ -77,17 +77,16 @@ class TriangularAffine(AbstractBijection): triangular matrix, and :math:`b` is the bias vector. Args: - loc (ArrayLike): Location parameter. If this is scalar, it is broadcast to the - dimension inferred from arr. - arr (ArrayLike): Triangular matrix. - lower (bool): Whether the mask should select the lower or upper - triangular matrix (other elements ignored). Defaults to True (lower). - weight_normalisation (bool): If true, carry out weight normalisation. - positivity_constraint (AbstractBijection): Bijection with shape matching the - dimension of the triangular affine bijection, that maps the diagonal - entries of the array from an unbounded domain to the positive domain. - Also used for weight normalisation parameters, if used. Defaults to - SoftPlus. + loc: Location parameter. If this is scalar, it is broadcast to the dimension + inferred from arr. + arr: Triangular matrix. + lower: Whether the mask should select the lower or upper + triangular matrix (other elements ignored). Defaults to True (lower). + weight_normalisation: If true, carry out weight normalisation. + positivity_constraint: Bijection with shape matching the dimension of the + triangular affine bijection, that maps the diagonal entries of the array + from an unbounded domain to the positive domain. Also used for weight + normalisation parameters, if used. Defaults to SoftPlus. """ shape: tuple[int, ...] 
cond_shape: ClassVar[None] = None @@ -181,11 +180,11 @@ class AdditiveCondition(AbstractBijection): module with trainable parameters. Args: - module (Callable[[ArrayLike], ArrayLike]): A callable (e.g. a function or - callable module) that maps array with shape cond_shape, to a shape - that is broadcastable with the shape of the bijection. - shape (tuple[int, ...]): The shape of the bijection. - cond_shape (tuple[int, ...]): The condition shape of the bijection. + module: A callable (e.g. a function or callable module) that maps array with + shape cond_shape, to a shape that is broadcastable with the shape of the + bijection. + shape: The shape of the bijection. + cond_shape: The condition shape of the bijection. Example: Conditioning using a linear transformation diff --git a/flowjax/bijections/bijection.py b/flowjax/bijections/bijection.py index 97842cfa..d1f7cc8e 100644 --- a/flowjax/bijections/bijection.py +++ b/flowjax/bijections/bijection.py @@ -6,14 +6,13 @@ bijection can be used to invert the orientation if a fast inverse is desired (e.g. maximum likelihood fitting of flows). """ - import functools from abc import abstractmethod import equinox as eqx from jax import Array -from jax.typing import ArrayLike +from flowjax._custom_types import ArrayLike from flowjax.utils import arraylike_to_array @@ -98,10 +97,10 @@ def transform(self, x: ArrayLike, condition: ArrayLike | None = None) -> Array: """Apply the forward transformation. Args: - x (ArrayLike): Input with shape matching bijections.shape. - condition (ArrayLike | None, optional): Condition, with shape matching - bijection.cond_shape, required for conditional bijections. Defaults to - None. + x: Input with shape matching ``bijections.shape``. + condition: Condition, with shape matching ``bijection.cond_shape``, required + for conditional bijections and ignored for unconditional bijections. + Defaults to None. """ @abstractmethod @@ -113,8 +112,8 @@ def transform_and_log_det( """Apply transformation and compute the log absolute Jacobian determinant. Args: - x (ArrayLike): Input with shape matching the bijections shape - condition (ArrayLike | None, optional): . Defaults to None. + x: Input with shape matching the bijections shape + condition: . Defaults to None. """ @abstractmethod @@ -122,10 +121,9 @@ def inverse(self, y: ArrayLike, condition: ArrayLike | None = None) -> Array: """Compute the inverse transformation. Args: - y (ArrayLike): Input array with shape matching bijection.shape - condition (ArrayLike | None, optional): Condition array with shape matching - bijection.cond_shape. Required for conditional bijections. Defaults to - None. + y: Input array with shape matching bijection.shape + condition: Condition array with shape matching bijection.cond_shape. + Required for conditional bijections. Defaults to None. """ @abstractmethod @@ -137,8 +135,7 @@ def inverse_and_log_det( """Inverse transformation and corresponding log absolute jacobian determinant. Args: - y (ArrayLike): Input array with shape matching bijection.shape. - condition (ArrayLike | None, optional): Condition array with shape matching - bijection.cond_shape. Required for conditional bijections. Defaults to - None. + y: Input array with shape matching bijection.shape. + condition: Condition array with shape matching bijection.cond_shape. + Required for conditional bijections. Defaults to None. 
""" diff --git a/flowjax/bijections/block_autoregressive_network.py b/flowjax/bijections/block_autoregressive_network.py index b6ff14d1..62130608 100644 --- a/flowjax/bijections/block_autoregressive_network.py +++ b/flowjax/bijections/block_autoregressive_network.py @@ -49,16 +49,15 @@ class BlockAutoregressiveNetwork(AbstractBijection): densities (see https://github.com/danielward27/flowjax/issues/102). Args: - key (KeyArray): Jax PRNGKey - dim (int): Dimension of the distribution. - cond_dim (tuple[int, ...] | None): Dimension of conditioning variables. - depth (int): Number of hidden layers in the network. - block_dim (int): Block dimension (hidden layer size is `dim*block_dim`). - activation: (Bijection | Callable | None). Activation function, either - a scalar bijection or a callable that computes the activation for a - scalar value. Note that the activation should be bijective - to ensure invertibility of the network and in general should map - real -> real to ensure that when transforming a distribution (either + key: Jax PRNGKey + dim: Dimension of the distribution. + cond_dim: Dimension of conditioning variables. + depth: Number of hidden layers in the network. + block_dim: Block dimension (hidden layer size is `dim*block_dim`). + activation: Activation function, either a scalar bijection or a callable that + computes the activation for a scalar value. Note that the activation should + be bijective to ensure invertibility of the network and in general should + map real -> real to ensure that when transforming a distribution (either with the forward or inverse), the map is defined across the support of the base distribution. Defaults to ``LeakyTanh(3)``. """ diff --git a/flowjax/bijections/chain.py b/flowjax/bijections/chain.py index a8fae7d2..dd2316c9 100644 --- a/flowjax/bijections/chain.py +++ b/flowjax/bijections/chain.py @@ -9,8 +9,8 @@ class Chain(AbstractBijection): """Chain together arbitrary bijections to form another bijection. Args: - bijections (Sequence[Bijection]): Sequence of bijections. The bijection - shapes must match, and any none None condition shapes must match. + bijections: Sequence of bijections. The bijection shapes must match, and any + none None condition shapes must match. """ shape: tuple[int, ...] diff --git a/flowjax/bijections/concatenate.py b/flowjax/bijections/concatenate.py index 05014271..2c7d3ad2 100644 --- a/flowjax/bijections/concatenate.py +++ b/flowjax/bijections/concatenate.py @@ -15,9 +15,8 @@ class Concatenate(AbstractBijection): See also :class:`Stack`. Args: - bijections (Sequence[Bijection]): Bijections, to stack into a single - bijection. - axis (int): Axis along which to stack. Defaults to 0. + bijections: Bijections, to stack into a single bijection. + axis: Axis along which to stack. Defaults to 0. """ shape: tuple[int, ...] @@ -94,8 +93,8 @@ class Stack(AbstractBijection): See also :class:`Concatenate`. Args: - bijections (list[Bijection]): Bijections. - axis (int): Axis along which to stack. Defaults to 0. + bijections: Bijections. + axis: Axis along which to stack. Defaults to 0. """ shape: tuple[int, ...] diff --git a/flowjax/bijections/coupling.py b/flowjax/bijections/coupling.py index fb90b4ac..3c0a5ff9 100644 --- a/flowjax/bijections/coupling.py +++ b/flowjax/bijections/coupling.py @@ -18,17 +18,15 @@ class Coupling(AbstractBijection): """Coupling layer implementation (https://arxiv.org/abs/1605.08803). 
Args: - key (KeyArray): Jax PRNGKey - transformer (AbstractBijection): Unconditional bijection with shape () - to be parameterised by the conditioner neural netork. - untransformed_dim (int): Number of untransformed conditioning variables - (e.g. dim // 2). - dim (int): Total dimension. - cond_dim (int | None): Dimension of additional conditioning variables. - nn_width (int): Neural network hidden layer width. - nn_depth (int): Neural network hidden layer size. - nn_activation (Callable): Neural network activation function. - Defaults to jnn.relu. + key: Jax PRNGKey + transformer: Unconditional bijection with shape () to be parameterised by the + conditioner neural network. + untransformed_dim: Number of untransformed conditioning variables (e.g. dim//2). + dim: Total dimension. + cond_dim: Dimension of additional conditioning variables. + nn_width: Neural network hidden layer width. + nn_depth: Neural network hidden layer size. + nn_activation: Neural network activation function. Defaults to jnn.relu. """ shape: tuple[int, ...] diff --git a/flowjax/bijections/exp.py b/flowjax/bijections/exp.py index ba6a27e9..40e98f54 100644 --- a/flowjax/bijections/exp.py +++ b/flowjax/bijections/exp.py @@ -10,8 +10,7 @@ class Exp(AbstractBijection): """Elementwise exponential transform (forward) and log transform (inverse). Args: - shape (tuple[int, ...] | None): Shape of the bijection. - Defaults to None. + shape: Shape of the bijection. Defaults to (). """ shape: tuple[int, ...] = () diff --git a/flowjax/bijections/jax_transforms.py b/flowjax/bijections/jax_transforms.py index cb98a6ac..41a97b17 100644 --- a/flowjax/bijections/jax_transforms.py +++ b/flowjax/bijections/jax_transforms.py @@ -17,9 +17,9 @@ class Scan(AbstractBijection): to construct these using ``equinox.filter_vmap``. Args: - bijection (AbstractBijection): A bijection, in which the arrays leaves have - an additional leading axis to scan over. It is often can convenient to - create compatible bijections with ``equinox.filter_vmap``. + bijection: A bijection, in which the array leaves have an additional leading + axis to scan over. It is often convenient to create compatible + bijections with ``equinox.filter_vmap``. Example: Below is equivilent to ``Chain([Affine(p) for p in params])``. @@ -92,16 +92,16 @@ class Vmap(AbstractBijection): """Applies vmap to bijection methods to add a batch dimension to the bijection. Args: - bijection (AbstractBijection): The bijection to vectorize. - in_axis (int | None | Callable): Specify which axes of the bijection - parameters to vectorise over. It should be a PyTree of ``None``, ``int`` - with the tree structure being a prefix of the bijection, or a callable - mapping ``Leaf -> Union[None, int]``. Defaults to None. - axis_size (int, optional): The size of the new axis. This should be left - unspecified if in_axis is provided, as the size can be inferred from the - bijection parameters. Defaults to None. - in_axis_condition (int | None, optional): Optionally define an axis of - the conditioning variable to vectorize over. Defaults to None. + bijection: The bijection to vectorize. + in_axis: Specify which axes of the bijection parameters to vectorise over. It + should be a PyTree of ``None``, ``int`` with the tree structure being a + prefix of the bijection, or a callable mapping ``Leaf -> Union[None, int]``. + Defaults to None. + axis_size: The size of the new axis. This should be left unspecified if in_axis + is provided, as the size can be inferred from the bijection parameters.
+ Defaults to None. + in_axis_condition: Optionally define an axis of the conditioning variable to + vectorize over. Defaults to None. Example: The two most common use cases, are shown below: diff --git a/flowjax/bijections/masked_autoregressive.py b/flowjax/bijections/masked_autoregressive.py index cf99f053..22a901f9 100644 --- a/flowjax/bijections/masked_autoregressive.py +++ b/flowjax/bijections/masked_autoregressive.py @@ -27,14 +27,14 @@ class MaskedAutoregressive(AbstractBijection): - https://arxiv.org/abs/1705.07057v4 Args: - key (KeyArray): Jax PRNGKey - transformer (AbstractBijection): Bijection with shape () to be parameterised - by the autoregressive network. - dim (int): Dimension. - cond_dim (int | None): Dimension of any conditioning variables. - nn_width (int): Neural network width. - nn_depth (int): Neural network depth. - nn_activation (Callable): Neural network activation. Defaults to jnn.relu. + key: Jax PRNGKey + transformer: Bijection with shape () to be parameterised by the autoregressive + network. + dim: Dimension. + cond_dim: Dimension of any conditioning variables. + nn_width: Neural network width. + nn_depth: Neural network depth. + nn_activation: Neural network activation. Defaults to jnn.relu. """ shape: tuple[int, ...] diff --git a/flowjax/bijections/planar.py b/flowjax/bijections/planar.py index de35faf1..d40afc27 100644 --- a/flowjax/bijections/planar.py +++ b/flowjax/bijections/planar.py @@ -24,9 +24,9 @@ class Planar(AbstractBijection): In the conditional case they are parameterised by an MLP. Args: - key (Array): Jax random seed. - dim (int): Dimension of the bijection. - cond_dim (int | None, optional): Dimension of extra conditioning variables. + key: Jax random seed. + dim: Dimension of the bijection. + cond_dim: Dimension of extra conditioning variables. Defaults to None. **mlp_kwargs: Key word arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None. diff --git a/flowjax/bijections/rational_quadratic_spline.py b/flowjax/bijections/rational_quadratic_spline.py index 47a78754..1a373b5e 100644 --- a/flowjax/bijections/rational_quadratic_spline.py +++ b/flowjax/bijections/rational_quadratic_spline.py @@ -15,13 +15,13 @@ class RationalQuadraticSpline(AbstractBijection): """Scalar RationalQuadraticSpline transformation (https://arxiv.org/abs/1906.04032). Args: - knots (int): Number of knots. - interval (float): interval to transform, [-interval, interval]. - min_derivative (float): Minimum dervivative. Defaults to 1e-3. - softmax_adjust (float): Controls minimum bin width and height by - rescaling softmax output, e.g. 0=no adjustment, 1=average softmax output - with evenly spaced widths, >1 promotes more evenly spaced widths. - See ``real_to_increasing_on_interval``. Defaults to 1e-2. + knots: Number of knots. + interval: Interval to transform, [-interval, interval]. + min_derivative: Minimum derivative. Defaults to 1e-3. + softmax_adjust: Controls minimum bin width and height by rescaling softmax + output, e.g. 0=no adjustment, 1=average softmax output with evenly spaced + widths, >1 promotes more evenly spaced widths. See + ``real_to_increasing_on_interval``. Defaults to 1e-2.
""" shape: ClassVar[tuple] = () @@ -101,7 +101,6 @@ def transform_and_log_det(self, x, condition=None): def inverse(self, y, condition=None): # Following notation from the paper - # pylint: disable=C0103 x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives in_bounds = jnp.logical_and(y > -self.interval, y < self.interval) @@ -128,7 +127,6 @@ def inverse_and_log_det(self, y, condition=None): def derivative(self, x) -> Array: """The derivative dy/dx of the forward transformation.""" # Following notation from the paper (eq. 5) - # pylint: disable=C0103 x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives in_bounds = jnp.logical_and(x > -self.interval, x < self.interval) x_robust = jnp.where(in_bounds, x, 0) # To avoid nans diff --git a/flowjax/bijections/tanh.py b/flowjax/bijections/tanh.py index 79a0c3b6..4a975332 100644 --- a/flowjax/bijections/tanh.py +++ b/flowjax/bijections/tanh.py @@ -42,8 +42,8 @@ class LeakyTanh(AbstractBijection): so Tanh is not appropriate. Args: - max_val (float): Value above or below which the function becomes linear. - shape (tuple[int, ...] | None): The shape of the bijection. Defaults to (). + max_val: Value above or below which the function becomes linear. + shape: The shape of the bijection. Defaults to (). """ shape: tuple[int, ...] = () diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index 5dc9b2d4..58ed0ebb 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -1,4 +1,6 @@ """Utility bijections (embedding network, permutations, inversion etc.).""" +from __future__ import annotations + from collections.abc import Callable from typing import ClassVar @@ -8,6 +10,7 @@ from jax.typing import ArrayLike from flowjax.bijections.bijection import AbstractBijection +from flowjax.utils import arraylike_to_array class Invert(AbstractBijection): @@ -21,7 +24,7 @@ class Invert(AbstractBijection): achieve this aim. Args: - bijection (AbstractBijection): Bijection to invert. + bijection: Bijection to invert. """ bijection: AbstractBijection @@ -51,9 +54,9 @@ class Permute(AbstractBijection): """Permutation transformation. Args: - permutation (ArrayLike): An array with shape matching the array to - transform, with elements 0-(array.size-1) representing the new order - based on the flattened array (uses, C-like ordering). + permutation: An array with shape matching the array to transform, with elements + 0-(array.size-1) representing the new order based on the flattened array + (uses, C-like ordering). """ shape: tuple[int, ...] @@ -62,7 +65,7 @@ class Permute(AbstractBijection): inverse_permutation: tuple[Array, ...] def __init__(self, permutation: ArrayLike): - permutation = jnp.asarray(permutation) + permutation = arraylike_to_array(permutation) checkify.check( (permutation.ravel().sort() == jnp.arange(permutation.size)).all(), "Invalid permutation array provided.", @@ -97,8 +100,7 @@ class Flip(AbstractBijection): """Flip the input array. Condition argument is ignored. Args: - shape (tuple[int, ...]): The shape of the bijection. - Defaults to None. + shape: The shape of the bijection. Defaults to None. """ shape: tuple[int, ...] = () @@ -121,11 +123,11 @@ class Partial(AbstractBijection): """Applies bijection to specific indices of an input. Args: - bijection (AbstractBijection): Bijection that is compatible with the subset + bijection: Bijection that is compatible with the subset of x indexed by idxs. 
idxs: Indices (Integer, a slice, or an ndarray with integer/bool dtype) of the transformed portion. - idxs (int | slice | Array | tuple): The indexes to transform. - shape (tuple[int, ...] | None): Shape of the bijection. Defaults to None. + idxs: The indexes to transform. + shape: Shape of the bijection. Defaults to None. """ bijection: AbstractBijection @@ -166,7 +168,7 @@ class Identity(AbstractBijection): """The identity bijection. Args: - shape (tuple[int, ...]): The shape of the bijection. Defaults to (). + shape: The shape of the bijection. Defaults to (). """ shape: tuple[int, ...] = () @@ -192,12 +194,10 @@ class EmbedCondition(AbstractBijection): variable. The returned bijection has cond_dim equal to the raw condition size. Args: - bijection (AbstractBijection): Bijection with ``bijection.cond_dim`` equal - to the embedded size. - embedding_net (Callable): A callable (e.g. equinox module) that embeds a - conditioning variable to size ``bijection.cond_dim``. - raw_cond_shape (tuple[int, ...] | None): The dimension of the raw - conditioning variable. + bijection: Bijection with ``bijection.cond_dim`` equal to the embedded size. + embedding_net: A callable (e.g. equinox module) that embeds a conditioning + variable to size ``bijection.cond_dim``. + raw_cond_shape: The dimension of the raw conditioning variable. """ bijection: AbstractBijection diff --git a/flowjax/distributions.py b/flowjax/distributions.py index fc37a4c9..ed37367a 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -14,8 +14,8 @@ from jax.lax import stop_gradient from jax.numpy import linalg from jax.scipy import stats as jstats -from jax.typing import ArrayLike +from flowjax._custom_types import ArrayLike from flowjax.bijections import AbstractBijection, Affine, Chain, TriangularAffine from flowjax.utils import _get_ufunc_signature, arraylike_to_array, merge_cond_shapes @@ -30,16 +30,14 @@ class AbstractDistribution(eqx.Module): (1) Inherit from :class:`AbstractDistribution`. (2) Define the abstract attributes ``shape`` and ``cond_shape``. ``cond_shape`` should be ``None`` for unconditional distributions. - (3) Define the abstract methods :meth:`_sample` and :meth:`_log_prob`. + (3) Define the abstract methods `_sample` and `_log_prob`. See the source code for :class:`StandardNormal` for a simple concrete example. Attributes: - shape (AbstractVar[tuple[int, ...]]): Denotes the shape of a single sample from - the distribution. - cond_shape (AbstractVar[tuple[int, ...] | None]): The shape of an instance of - the conditioning variable. This should be None for unconditional - distributions. + shape: Tuple denoting the shape of a single sample from the distribution. + cond_shape: Tuple denoting the shape of an instance of the conditioning + variable. This should be None for unconditional distributions. """ @@ -74,8 +72,8 @@ def log_prob(self, x: ArrayLike, condition: ArrayLike | None = None) -> Array: Uses numpy-like broadcasting if additional leading dimensions are passed. Args: - x (ArrayLike): Points at which to evaluate density. - condition (ArrayLike | None): Conditioning variables. Defaults to None. + x: Points at which to evaluate density. + condition: Conditioning variables. Defaults to None. Returns: Array: Jax array of log probabilities. @@ -101,9 +99,9 @@ def sample( See the example for more information. Args: - key (Array): Jax random key. - condition (ArrayLike | None): Conditioning variables. Defaults to None. - sample_shape (tuple[int, ...]): Sample shape. Defaults to (). 
+ key: Jax random key. + condition: Conditioning variables. Defaults to None. + sample_shape: Sample shape. Defaults to (). Example: The below example shows the behaviour of sampling, for an unconditional @@ -176,9 +174,9 @@ def sample_and_log_prob( more information. Args: - key (Array): Jax random key. - condition (ArrayLike | None): Conditioning variables. Defaults to None. - sample_shape (tuple[int, ...]): Sample shape. Defaults to (). + key: Jax random key. + condition: Conditioning variables. Defaults to None. + sample_shape: Sample shape. Defaults to (). """ if self.cond_shape is not None: condition = arraylike_to_array(condition, err_name="condition") @@ -253,6 +251,10 @@ class AbstractTransformed(AbstractDistribution): Concete implementations should subclass :class:`AbstractTransformed`, and define the abstract attributes ``base_dist`` and ``bijection``. See the source code for :class:`Normal` as a simple example. + + Attributes: + base_dist: The base distribution. + bijection: The transformation to apply. """ base_dist: AbstractVar[AbstractDistribution] @@ -332,8 +334,8 @@ class Transformed(AbstractTransformed): lead to to unexpected results. Args: - base_dist (AbstractDistribution): Base distribution. - bijection (AbstractBijection): Bijection to transform distribution. + base_dist: Base distribution. + bijection: Bijection to transform distribution. Example: .. doctest:: @@ -355,7 +357,7 @@ class StandardNormal(AbstractDistribution): Note unlike :class:`Normal`, this has no trainable parameters. Args: - shape (tuple[int, ...]): The shape of the distribution. Defaults to (). + shape: The shape of the distribution. Defaults to (). """ shape: tuple[int, ...] = () @@ -374,8 +376,8 @@ class Normal(AbstractTransformed): ``loc`` and ``scale`` should broadcast to the desired shape of the distribution. Args: - loc (ArrayLike): Means. Defaults to 0. - scale (ArrayLike): Standard deviations. Defaults to 1. + loc: Means. Defaults to 0. + scale: Standard deviations. Defaults to 1. """ base_dist: StandardNormal @@ -406,9 +408,9 @@ class MultivariateNormal(AbstractTransformed): matrix. Args: - loc (ArrayLike): The location/mean parameter vector. If this is scalar it is - broadcast to the dimension implied by the covariance matrix. - covariance (ArrayLike, optional): Covariance matrix. + loc: The location/mean parameter vector. If this is scalar it is broadcast to + the dimension implied by the covariance matrix. + covariance: Covariance matrix. """ base_dist: StandardNormal @@ -420,6 +422,7 @@ def __init__(self, loc: ArrayLike, covariance: ArrayLike): @property def loc(self): + """Location (mean) of the distribution.""" return self.bijection.loc @property @@ -446,8 +449,8 @@ class Uniform(AbstractTransformed): ``minval`` and ``maxval`` should broadcast to the desired distribution shape. Args: - minval (ArrayLike): Minimum values. - maxval (ArrayLike): Maximum values. + minval: Minimum values. + maxval: Maximum values. """ base_dist: _StandardUniform @@ -494,8 +497,8 @@ class Gumbel(AbstractTransformed): ``loc`` and ``scale`` should broadcast to the dimension of the distribution. Args: - loc (ArrayLike): Location paramter. - scale (ArrayLike): Scale parameter. Defaults to 1.0. + loc: Location paramter. + scale: Scale parameter. Defaults to 1.0. """ base_dist: _StandardGumbel @@ -540,8 +543,8 @@ class Cauchy(AbstractTransformed): ``loc`` and ``scale`` should broadcast to the dimension of the distribution. Args: - loc (ArrayLike): Location paramter. - scale (ArrayLike): Scale parameter. 
Defaults to 1.0. + loc: Location paramter. + scale: Scale parameter. Defaults to 1.0. """ base_dist: _StandardCauchy @@ -595,9 +598,9 @@ class StudentT(AbstractTransformed): ``df``, ``loc`` and ``scale`` broadcast to the dimension of the distribution. Args: - df (ArrayLike): The degrees of freedom. - loc (ArrayLike): Location parameter. Defaults to 0.0. - scale (ArrayLike): Scale parameter. Defaults to 1.0. + df: The degrees of freedom. + loc: Location parameter. Defaults to 0.0. + scale: Scale parameter. Defaults to 1.0. """ base_dist: _StandardStudentT @@ -632,11 +635,11 @@ class SpecializeCondition(AbstractDistribution): # TODO check tested of the class. Args: - dist (AbstractDistribution): Conditional distribution to specialize. - condition (ArrayLike, optional): Instance of conditioning variable with - shape matching ``dist.cond_shape``. Defaults to None. - stop_gradient (bool): Whether to use ``jax.lax.stop_gradient`` to prevent - training of the condition array. Defaults to True. + dist: Conditional distribution to specialize. + condition: Instance of conditioning variable with shape matching + ``dist.cond_shape``. Defaults to None. + stop_gradient: Whether to use ``jax.lax.stop_gradient`` to prevent training of + the condition array. Defaults to True. """ shape: tuple[int, ...] diff --git a/flowjax/experimental/numpyro.py b/flowjax/experimental/numpyro.py index 01b75895..380062dc 100644 --- a/flowjax/experimental/numpyro.py +++ b/flowjax/experimental/numpyro.py @@ -3,6 +3,7 @@ Note these utilities require `numpyro `_ to be installed. """ +from __future__ import annotations from collections.abc import Callable from typing import Any @@ -41,7 +42,7 @@ class _VectorizedBijection: """Wrap a flowjax bijection to support vectorization. Args: - bijection (AbstractBijection): flowjax bijection to be wrapped. + bijection: flowjax bijection to be wrapped. """ def __init__(self, bijection: AbstractBijection): @@ -78,15 +79,14 @@ def vectorize(self, func, *, log_det=False): class TransformedToNumpyro(numpyro.distributions.Distribution): - """Convert a :class:`Transformed` flowjax distribution to a numpyro distribution. + """Convert a flowjax transformed distribution to a numpyro distribution. We assume the support of the distribution is unbounded. Args: - dist (AbstractTransformed): The distribution. - condition (ArrayLike | None, optional): Conditioning variables. Any - leading batch dimensions will be converted to a batch dimension in - the numpyro distribution. Defaults to None. + dist: The flowjax distribution. + condition: Conditioning variables. Any leading batch dimensions will be + converted to batch dimensions in the numpyro distribution. Defaults to None. """ def __init__( @@ -150,13 +150,11 @@ def register_params( context to have an effect, e.g. within a numpyro model or guide function. Args: - name (str): Name for the parameter set. - model (PyTree): The pytree (e.g. an equinox module, flowjax distribution, - or a flowjax bijection). - filter_spec (Callable | PyTree): Equinox `filter_spec` for specifying trainable - parameters. Either a callable `leaf -> bool`, or a PyTree with prefix - structure matching `dist` with True/False values. Defaults to - `eqx.is_inexact_array`. + name: Name for the parameter set. + model: The pytree (e.g. an equinox module, flowjax distribution/bijection). + filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either a + callable `leaf -> bool`, or a PyTree with prefix structure matching `dist` + with True/False values. 
Defaults to `eqx.is_inexact_array`. """ params, static = eqx.partition(model, filter_spec) diff --git a/flowjax/flows.py b/flowjax/flows.py index 5e536e56..22110f8c 100644 --- a/flowjax/flows.py +++ b/flowjax/flows.py @@ -53,19 +53,18 @@ def coupling_flow( """Create a coupling flow (https://arxiv.org/abs/1605.08803). Args: - key (Array): Jax random number generator key. - base_dist (AbstractDistribution): Base distribution, with ``base_dist.ndim==1``. - transformer (AbstractBijection): Bijection to be parameterised by - conditioner. Defaults to ``Affine()``. - cond_dim (int): Dimension of conditioning variables. Defaults to None. - flow_layers (int): Number of coupling layers. Defaults to 5. - nn_width (int): Conditioner hidden layer size. Defaults to 40. - nn_depth (int): Conditioner depth. Defaults to 2. - nn_activation (int): Conditioner activation function. Defaults to jnn.relu. - invert: (bool): Whether to invert the bijection. Broadly, True will - prioritise a faster `inverse` methods, leading to faster `log_prob`, - False will prioritise faster `transform` methods, leading to faster - `sample`. Defaults to True. + key: Jax random number generator key. + base_dist: Base distribution, with ``base_dist.ndim==1``. + transformer: Bijection to be parameterised by conditioner. Defaults to + ``Affine()``. + cond_dim: Dimension of conditioning variables. Defaults to None. + flow_layers: Number of coupling layers. Defaults to 8. + nn_width: Conditioner hidden layer size. Defaults to 50. + nn_depth: Conditioner depth. Defaults to 1. + nn_activation: Conditioner activation function. Defaults to jnn.relu. + invert: Whether to invert the bijection. Broadly, True will prioritise a faster + `inverse` method, leading to faster `log_prob`, False will prioritise + faster `transform` methods, leading to faster `sample`. Defaults to True. """ transformer = Affine() if transformer is None else transformer dim = base_dist.shape[-1] @@ -108,18 +107,18 @@ def masked_autoregressive_flow( Refs: https://arxiv.org/abs/1606.04934; https://arxiv.org/abs/1705.07057v4. Args: - key (Array): Random seed. - base_dist (AbstractDistribution): Base distribution, with ``base_dist.ndim==1``. - transformer (AbstractBijection): Bijection parameterised by autoregressive - network. Defaults to ``Affine()``. - cond_dim (int): _description_. Defaults to 0. - flow_layers (int): Number of flow layers. Defaults to 5. - nn_width (int): Number of hidden layers in neural network. Defaults to 40. - nn_depth (int): Depth of neural network. Defaults to 2. - nn_activation (Callable): _description_. Defaults to jnn.relu. - invert (bool): Whether to invert the bijection. Broadly, True will - prioritise a faster inverse, leading to faster `log_prob`, False will - prioritise faster forward, leading to faster `sample`. Defaults to True. + key: Random seed. + base_dist: Base distribution, with ``base_dist.ndim==1``. + transformer: Bijection parameterised by autoregressive network. Defaults to + ``Affine()``. + cond_dim: Dimension of the conditioning variable. Defaults to None. + flow_layers: Number of flow layers. Defaults to 8. + nn_width: Neural network hidden layer width. Defaults to 50. + nn_depth: Depth of neural network. Defaults to 1. + nn_activation: Neural network activation function. Defaults to jnn.relu. + invert: Whether to invert the bijection. Broadly, True will prioritise a faster + inverse, leading to faster `log_prob`, False will prioritise faster forward, + leading to faster `sample`. Defaults to True.
""" transformer = Affine() if transformer is None else transformer dim = base_dist.shape[-1] @@ -163,21 +162,18 @@ def block_neural_autoregressive_flow( controlled using the invert argument. Args: - key (Array): Jax PRNGKey. - base_dist (AbstractDistribution): Base distribution, with ``base_dist.ndim==1``. - cond_dim (int | None): Dimension of conditional variables. - nn_depth (int): Number of hidden layers within the networks. - Defaults to 1. - nn_block_dim (int): Block size. Hidden layer width is - dim*nn_block_dim. Defaults to 8. - flow_layers (int): Number of BNAF layers. Defaults to 1. - invert: (bool): Use `True` for access of ``log_prob`` only (e.g. - fitting by maximum likelihood), `False` for the forward direction - (``sample`` and ``sample_and_log_prob``) only (e.g. for fitting - variationally). - activation: (Bijection | Callable | None). Activation function used within - block neural autoregressive networks. Note this should be bijective and - in some use cases should map real -> real. For more information, see + key: Jax PRNGKey. + base_dist: Base distribution, with ``base_dist.ndim==1``. + cond_dim: Dimension of conditional variables. + nn_depth: Number of hidden layers within the networks. Defaults to 1. + nn_block_dim: Block size. Hidden layer width is dim*nn_block_dim. Defaults to 8. + flow_layers: Number of BNAF layers. Defaults to 1. + invert: Use `True` for access of ``log_prob`` only (e.g. fitting by maximum + likelihood), `False` for the forward direction (``sample`` and + ``sample_and_log_prob``) only (e.g. for fitting variationally). + activation: Activation function used within block neural autoregressive + networks. Note this should be bijective and in some use cases should map + real -> real. For more information, see :class:`~flowjax.bijections.block_autoregressive_network.BlockAutoregressiveNetwork`. Defaults to :class:`~flowjax.bijections.tanh.LeakyTanh`. """ @@ -215,14 +211,13 @@ def planar_flow( permutations. Note the definition here is inverted compared to the original paper. Args: - key (Array): Jax PRNGKey. - base_dist (AbstractDistribution): Base distribution, with ``base_dist.ndim==1``. - cond_dim (int): Dimension of conditioning variables. Defaults to None. - flow_layers (int): Number of flow layers. Defaults to 5. - invert: (bool): Whether to invert the bijection. Broadly, True will - prioritise a faster `inverse` methods, leading to faster `log_prob`, - False will prioritise faster `transform` methods, leading to faster - `sample`. Defaults to True + key: Jax PRNGKey. + base_dist: Base distribution, with ``base_dist.ndim==1``. + cond_dim: Dimension of conditioning variables. Defaults to None. + flow_layers: Number of flow layers. Defaults to 8. + invert: Whether to invert the bijection. Broadly, True will prioritise a faster + `inverse` methods, leading to faster `log_prob`, False will prioritise + faster `transform` methods, leading to faster `sample`. Defaults to True. **mlp_kwargs: Key word arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None. """ @@ -262,19 +257,17 @@ def triangular_spline_flow( transformations. Args: - key (Array): Jax random seed. - base_dist (AbstractDistribution): Base distribution, with ``base_dist.ndim==1``. - cond_dim (int | None): The number of conditioning features. - Defaults to None. - flow_layers (int): Number of flow layers. Defaults to 8. - knots (int): Number of knots in the splines. Defaults to 8. 
- tanh_max_val (float): Maximum absolute value beyond which we use linear - "tails" in the tanh function. Defaults to 3.0. - invert: (bool): Use `True` for access of `log_prob` only (e.g. - fitting by maximum likelihood), `False` for the forward direction - (sampling) only (e.g. for fitting variationally). - init (Callable | None): Initialisation method for the lower triangular - weights. Defaults to glorot_uniform(). + key: Jax random seed. + base_dist: Base distribution, with ``base_dist.ndim==1``. + cond_dim: The number of conditioning features. Defaults to None. + flow_layers: Number of flow layers. Defaults to 8. + knots: Number of knots in the splines. Defaults to 8. + tanh_max_val: Maximum absolute value beyond which we use linear "tails" in the + tanh function. Defaults to 3.0. + invert: Whether to invert the bijection before transforming the base + distribution. Defaults to True. + init: Initialisation method for the lower triangular weights. + Defaults to glorot_uniform(). """ init = init if init is not None else glorot_uniform() dim = base_dist.shape[-1] diff --git a/flowjax/masks.py b/flowjax/masks.py index a029e011..bb1a39e7 100644 --- a/flowjax/masks.py +++ b/flowjax/masks.py @@ -16,9 +16,9 @@ def rank_based_mask(in_ranks: Array, out_ranks: Array, *, eq: bool = False): """Forms mask matrix, with 1s where the out_ranks > or >= in_ranks. Args: - in_ranks (Array): Ranks of the inputs. - out_ranks (Array): Ranks of the outputs. - eq (bool): If true, compares with >= instead of >. Defaults to False. + in_ranks: Ranks of the inputs. + out_ranks: Ranks of the outputs. + eq: If true, compares with >= instead of >. Defaults to False. Returns: Array: Mask with shape `(len(out_ranks), len(in_ranks))` diff --git a/flowjax/nn/block_autoregressive.py b/flowjax/nn/block_autoregressive.py index 0193db50..d60d9786 100644 --- a/flowjax/nn/block_autoregressive.py +++ b/flowjax/nn/block_autoregressive.py @@ -17,13 +17,12 @@ class BlockAutoregressiveLinear(eqx.Module): conditioning variable) to the right of the block diagonal weight matrix. Args: - key (KeyArray): Random key - n_blocks (int): Number of diagonal blocks (dimension of original input). - block_shape (tuple): The shape of the (unconstrained) blocks. - cond_dim (int | None): Number of additional conditioning variables. - Defaults to None. - init (Callable | None): Default initialisation method for the weight - matrix. Defaults to ``glorot_uniform()``. + key: Random key + n_blocks: Number of diagonal blocks (dimension of original input). + block_shape: The shape of the (unconstrained) blocks. + cond_dim: Number of additional conditioning variables. Defaults to None. + init: Default initialisation method for the weight matrix. Defaults to + ``glorot_uniform()``. """ n_blocks: int diff --git a/flowjax/nn/masked_autoregressive.py b/flowjax/nn/masked_autoregressive.py index 4093e6c8..017543de 100644 --- a/flowjax/nn/masked_autoregressive.py +++ b/flowjax/nn/masked_autoregressive.py @@ -1,4 +1,6 @@ """Autoregressive linear layers and multilayer perceptron.""" +from __future__ import annotations + from collections.abc import Callable import jax.nn as jnn @@ -6,7 +8,6 @@ from equinox import Module from equinox.nn import Linear from jax import Array, random -from jax.random import KeyArray from jax.typing import ArrayLike from flowjax.masks import rank_based_mask @@ -20,15 +21,15 @@ class MaskedLinear(Module): """Masked linear neural network layer. Args: - mask (ArrayLike): Mask with shape (out_features, in_features). 
- key (KeyArray): Jax PRNGKey - use_bias (bool): Whether to include bias terms. Defaults to True. + mask: Mask with shape (out_features, in_features). + key: Jax random key. + use_bias: Whether to include bias terms. Defaults to True. """ linear: Linear mask: Array - def __init__(self, mask: ArrayLike, *, use_bias: bool = True, key: KeyArray): + def __init__(self, mask: ArrayLike, *, use_bias: bool = True, key: Array): mask = jnp.asarray(mask) self.linear = Linear(mask.shape[1], mask.shape[0], use_bias, key=key) self.mask = mask @@ -37,7 +38,7 @@ def __call__(self, x: ArrayLike): """Run the masked linear layer. Args: - x (ArrayLike): Array with shape ``(mask.shape[1], )`` + x: Array with shape ``(mask.shape[1], )`` """ x = jnp.asarray(x) x = self.linear.weight * self.mask @ x @@ -53,14 +54,13 @@ class AutoregressiveMLP(Module): nodes where in_ranks < out_ranks. Args: - in_ranks (ArrayLike): Ranks of the inputs. - hidden_ranks (ArrayLike): Ranks of the hidden layer(s). - out_ranks (ArrayLike): Ranks of the outputs. - depth (int): Number of hidden layers. - activation (Callable): Activation function. Defaults to jnn.relu. - final_activation (Callable): Final activation function. Defaults to - _identity. - key (KeyArray): Jax PRNGKey. + in_ranks: Ranks of the inputs. + hidden_ranks: Ranks of the hidden layer(s). + out_ranks: Ranks of the outputs. + depth: Number of hidden layers. + activation: Activation function. Defaults to jnn.relu. + final_activation: Final activation function. Defaults to _identity. + key: Jax PRNGKey. """ in_size: int diff --git a/flowjax/tasks.py b/flowjax/tasks.py index f2d2463b..4f92bb6f 100644 --- a/flowjax/tasks.py +++ b/flowjax/tasks.py @@ -1,4 +1,5 @@ """Example tasks.""" +from __future__ import annotations import equinox as eqx import jax @@ -8,6 +9,7 @@ from jax.typing import ArrayLike from flowjax.distributions import Uniform +from flowjax.utils import arraylike_to_array def two_moons(key: Array, n_samples, noise_std=0.2): @@ -40,7 +42,7 @@ def __init__(self, dim: int = 2, prior_bound: float = 10.0) -> None: @eqx.filter_jit def simulator(self, key: Array, theta: ArrayLike): """Carry out simulations.""" - theta = jnp.atleast_2d(jnp.asarray(theta)) + theta = jnp.atleast_2d(arraylike_to_array(theta)) key, subkey = jr.split(key) component = jr.bernoulli(subkey, shape=(theta.shape[0],)) scales = jnp.where(component, 0.1, 1) @@ -57,9 +59,9 @@ def sample_reference_posterior( """Sample the reference posterior given an observation. Uses the closed form solution with rejection sampling for samples outside prior - bound. + bound. 
""" - observation = jnp.asarray(observation) + observation = arraylike_to_array(observation) if observation.shape != (self.dim,): raise ValueError(f"Expected shape {(self.dim, )}, got {observation.shape}") diff --git a/flowjax/train/__init__.py b/flowjax/train/__init__.py index 2b133010..575ba878 100644 --- a/flowjax/train/__init__.py +++ b/flowjax/train/__init__.py @@ -1,5 +1,4 @@ """Utilities for training flows, fitting to samples or ysing variational inference.""" - from .data_fit import fit_to_data from .variational_fit import fit_to_variational_target diff --git a/flowjax/train/data_fit.py b/flowjax/train/data_fit.py index edef974d..dc7fdbc2 100644 --- a/flowjax/train/data_fit.py +++ b/flowjax/train/data_fit.py @@ -1,4 +1,6 @@ """Function to fit flows to samples from a distribution.""" +from __future__ import annotations + from collections.abc import Callable from typing import Any @@ -26,7 +28,7 @@ def fit_to_data( dist: PyTree, x: ArrayLike, *, - condition: ArrayLike = None, + condition: ArrayLike | None = None, loss_fn: Callable | None = None, max_epochs: int = 100, max_patience: int = 5, @@ -45,25 +47,26 @@ def fit_to_data( non-distribution pytrees as long as a compatible loss function is provided. Args: - key (KeyArray): Jax random seed. - dist (PyTree): The distribution to train. - x (ArrayLike): Samples from target distribution. - condition (ArrayLike | None): Conditioning variables. Defaults to None. - loss_fn (Callable | None): Loss function. Defaults to MaximumLikelihoodLoss. - max_epochs (int): Maximum number of epochs. Defaults to 100. - max_patience (int): Number of consecutive epochs with no validation - loss improvement after which training is terminated. Defaults to 5. - batch_size (int): Batch size. Defaults to 100. - val_prop (float): Proportion of data to use in validation set. Defaults to 0.1. - learning_rate (float): Adam learning rate. Defaults to 5e-4. - optimizer (optax.GradientTransformation): Optax optimizer. If provided, this - overrides the default Adam optimizer, and the learning_rate is ignored. - Defaults to None. - filter_spec (Callable | PyTree): Equinox `filter_spec` for specifying trainable - parameters. Either a callable `leaf -> bool`, or a PyTree with prefix - structure matching `dist` with True/False values. Defaults to - `eqx.is_inexact_array`. - show_progress (bool): Whether to show progress bar. Defaults to True. + key: Jax random seed. + dist: The distribution to train. + x: Samples from target distribution. + condition: Conditioning variables. Defaults to None. + loss_fn: Loss function. Defaults to MaximumLikelihoodLoss. + max_epochs: Maximum number of epochs. Defaults to 100. + max_patience: Number of consecutive epochs with no validation loss improvement + after which training is terminated. Defaults to 5. + batch_size: Batch size. Defaults to 100. + val_prop: Proportion of data to use in validation set. Defaults to 0.1. + learning_rate: Adam learning rate. Defaults to 5e-4. + optimizer: Optax optimizer. If provided, this overrides the default Adam + optimizer, and the learning_rate is ignored. Defaults to None. + filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either a + callable `leaf -> bool`, or a PyTree with prefix structure matching `dist` + with True/False values. Defaults to `eqx.is_inexact_array`. + show_progress: Whether to show progress bar. Defaults to True. + + Returns: + A tuple containing the trained distribution and the losses. 
""" data = (x,) if condition is None else (x, condition) data = tuple(jnp.asarray(a) for a in data) diff --git a/flowjax/train/losses.py b/flowjax/train/losses.py index 3c93cb39..34e224ee 100644 --- a/flowjax/train/losses.py +++ b/flowjax/train/losses.py @@ -3,6 +3,8 @@ The loss functions are callables, with the first two arguments being the partitioned distribution (see equinox.partition). """ +from __future__ import annotations + from collections.abc import Callable import equinox as eqx @@ -46,15 +48,15 @@ class ContrastiveLoss: variable (the simulator parameters), and ``condition`` for the conditioning variable (the simulator output/oberved data). - References: - - https://arxiv.org/abs/1905.07488 - - https://arxiv.org/abs/2002.03712 - Args: - prior (AbstractDistribution): The prior distribution over x (the target + prior: The prior distribution over x (the target variable). - n_contrastive (int): The number of contrastive samples/atoms to use when + n_contrastive: The number of contrastive samples/atoms to use when computing the loss. + + References: + - https://arxiv.org/abs/1905.07488 + - https://arxiv.org/abs/2002.03712 """ def __init__(self, prior: AbstractDistribution, n_contrastive: int): @@ -96,17 +98,15 @@ class ElboLoss: """The negative evidence lower bound (ELBO), approximated using samples. Args: - num_samples (int): Number of samples to use in the ELBO approximation. - target (Callable[[ArrayLike], Array]): The target, i.e. log posterior - density up to an additive constant / the negative of the potential - function, evaluated for a single point. - stick_the_landing (bool): Whether to use the (often) lower variance ELBO - gradient estimator introduced in https://arxiv.org/pdf/1703.09194.pdf. - Note for flows this requires evaluating the flow in both directions - (running the forward and inverse transformation). For some flow - architectures, this may be computationally expensive due to assymetrical - computational complexity between the forward and inverse transformation. - Defaults to False. + num_samples: Number of samples to use in the ELBO approximation. + target: The target, i.e. log posterior density up to an additive constant / the + negative of the potential function, evaluated for a single point. + stick_the_landing: Whether to use the (often) lower variance ELBO gradient + estimator introduced in https://arxiv.org/pdf/1703.09194.pdf. Note for flows + this requires evaluating the flow in both directions (running the forward + and inverse transformation). For some flow architectures, this may be + computationally expensive due to assymetrical computational complexity + between the forward and inverse transformation. Defaults to False. """ target: Callable[[ArrayLike], Array] @@ -134,9 +134,9 @@ def __call__( """Compute the ELBO loss. Args: - params (AbstractDistribution): The trainable parameters of the model. - static (AbstractDistribution): The static components of the model. - key (Array): Jax random seed. + params: The trainable parameters of the model. + static: The static components of the model. + key: Jax random seed. """ dist = eqx.combine(params, static) diff --git a/flowjax/train/train_utils.py b/flowjax/train/train_utils.py index cb561b14..669e4ba5 100644 --- a/flowjax/train/train_utils.py +++ b/flowjax/train/train_utils.py @@ -24,13 +24,13 @@ def step( """Carry out a training step. Args: - params (PyTree): Parameters for the model - static (PyTree): Static components of the model. 
+ params: Parameters for the model + static: Static components of the model. *args: Arguments passed to the loss function. - optimizer (optax.GradientTransformation): Optax optimizer. - opt_state (PyTree): Optimizer state. - loss_fn (Callable): The loss function. This should take params and static as - the first two arguments. + optimizer: Optax optimizer. + opt_state: Optimizer state. + loss_fn: The loss function. This should take params and static as the first two + arguments. Returns: tuple: (params, opt_state, loss_val) @@ -45,12 +45,12 @@ def train_val_split(key: Array, arrays: Sequence[Array], val_prop: float = 0.1): """Random train validation split for a sequence of arrays. Args: - key (KeyArray): Jax random key. - arrays (Sequence[Array]): Sequence of arrays, with matching size on axis 0. - val_prop (float): Proportion of data to use for validation. Defaults to 0.1. + key: Jax random key. + arrays: Sequence of arrays, with matching size on axis 0. + val_prop: Proportion of data to use for validation. Defaults to 0.1. Returns: - tuple[tuple]: (train_arrays, validation_arrays) + A tuple containing the train arrays and the validation arrays. """ if not 0 <= val_prop <= 1: raise ValueError("val_prop should be between 0 and 1.") @@ -80,8 +80,8 @@ def get_batches(arrays: Sequence[Array], batch_size: int): batch size to be equal to the data length. Args: - arrays (Sequence[Array]): Sequence of arrays, with shape matching on axis 0. - batch_size (int): The batch size. + arrays: Sequence of arrays, with shape matching on axis 0. + batch_size: The batch size. """ data_len = arrays[0].shape[0] if not all(arr.shape[0] == data_len for arr in arrays): @@ -101,7 +101,7 @@ def count_fruitless(losses: list[float]) -> int: """Count the number of epochs since the minimum loss in a list of losses. Args: - losses (list[float]): List of losses. + losses: List of losses. """ min_idx = jnp.argmin(jnp.array(losses)).item() diff --git a/flowjax/train/variational_fit.py b/flowjax/train/variational_fit.py index 3ec25246..cd38ee83 100644 --- a/flowjax/train/variational_fit.py +++ b/flowjax/train/variational_fit.py @@ -16,35 +16,33 @@ def fit_to_variational_target( key: Array, - *, dist: AbstractDistribution, loss_fn: Callable, + *, steps: int = 100, learning_rate: float = 5e-4, optimizer: optax.GradientTransformation | None = None, filter_spec: Callable | PyTree = eqx.is_inexact_array, show_progress: bool = True, -): +) -> tuple[AbstractDistribution, list]: """Train a distribution (e.g. a flow) by variational inference. Args: - key (Array): Jax PRNGKey. - dist (AbstractDistribution): Distribution object, trainable parameters are found - using equinox.is_inexact_array. - loss_fn (Callable | None): The loss function to optimize (e.g. the ElboLoss). - steps (int, optional): The number of training steps to run. Defaults to 100. - learning_rate (float, optional): Learning rate. Defaults to 5e-4. - optimizer (optax.GradientTransformation | None, optional): Optax optimizer. If - provided, this overrides the default Adam optimizer, and the learning_rate - is ignored. Defaults to None. - filter_spec (Callable | PyTree, optional): Equinox `filter_spec` for - specifying trainable parameters. Either a callable `leaf -> bool`, or a - PyTree with prefix structure matching `dist` with True/False values. - Defaults to eqx.is_inexact_array. - show_progress (bool, optional): Whether to show progress bar. Defaults to True. + key: Jax PRNGKey. 
+ dist: Distribution object, trainable parameters are found using + equinox.is_inexact_array. + loss_fn: The loss function to optimize (e.g. the ElboLoss). + steps: The number of training steps to run. Defaults to 100. + learning_rate: Learning rate. Defaults to 5e-4. + optimizer: Optax optimizer. If provided, this overrides the default Adam + optimizer, and the learning_rate is ignored. Defaults to None. + filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either + a callable `leaf -> bool`, or a PyTree with prefix structure matching `dist` + with True/False values. Defaults to eqx.is_inexact_array. + show_progress: Whether to show progress bar. Defaults to True. Returns: - tuple: (distribution, losses). + A tuple containing the trained distribution and the losses. """ if optimizer is None: optimizer = optax.adam(learning_rate) diff --git a/flowjax/utils.py b/flowjax/utils.py index 25ce260d..aa4f67be 100644 --- a/flowjax/utils.py +++ b/flowjax/utils.py @@ -1,4 +1,6 @@ """Utility functions.""" +from __future__ import annotations + from collections.abc import Sequence import equinox as eqx @@ -17,12 +19,12 @@ def real_to_increasing_on_interval( """Transform unconstrained vector to monotonically increasing positions on [-B, B]. Args: - arr (Array): Parameter vector. - B (float): Interval to transform output. Defaults to 1. - softmax_adjust (float): Rescales softmax output using (widths + - softmax_adjust/widths.size) / (1 + softmax_adjust). e.g. 0=no adjustment, - 1=average softmax output with evenly spaced widths, >1 promotes more evenly - spaced widths. + arr: Parameter vector. + B : Interval to transform output. Defaults to 1. + softmax_adjust : Rescales softmax output using + ``(widths + softmax_adjust/widths.size) / (1 + softmax_adjust)``. e.g. + 0=no adjustment, 1=average softmax output with evenly spaced widths, >1 + promotes more evenly spaced widths. """ if softmax_adjust < 0: raise ValueError("softmax_adjust should be >= 0.") @@ -92,12 +94,12 @@ def get_ravelled_bijection_constructor( construction of the bijection directly from the neural network output. Args: - bijection (AbstractBijection): Bijection to form constructor for. + bijection: Bijection to form constructor for. filter_spec: Filter function to specify parameters. Defaults to eqx.is_inexact_array. Returns: - tuple: The constructor, and the current parameter vector. + The constructor, and the current parameter vector. """ params, static = eqx.partition(bijection, filter_spec) current, unravel = ravel_pytree(params) @@ -109,7 +111,7 @@ def constructor(ravelled_params: Array): return constructor, current -def arraylike_to_array(arr, err_name: str = "input", **kwargs) -> Array: +def arraylike_to_array(arr: ArrayLike, err_name: str = "input", **kwargs) -> Array: """Check the input is arraylike and convert to a jax Array with ``jnp.asarray``. Combines ``jnp.asarray``, with an isinstance(arr, ArrayLike) check. This @@ -119,8 +121,7 @@ def arraylike_to_array(arr, err_name: str = "input", **kwargs) -> Array: Args: arr: Arraylike input to convert to a jax array. - err_name (str, optional): Name of the input in the error message. Defaults to - "input". + err_name: Name of the input in the error message. Defaults to "input". **kwargs: Key word arguments passed to jnp.asarray. 
""" if not isinstance(arr, ArrayLike): diff --git a/pyproject.toml b/pyproject.toml index a317bcea..2552ae8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ license = { file = "LICENSE" } name = "flowjax" readme = "README.md" requires-python = ">=3.10" -version = "10.1.0" +version = "11.0.0" [project.urls] repository = "https://github.com/danielward27/flowjax" @@ -29,6 +29,7 @@ dev = [ "sphinx", "sphinx-rtd-theme", "sphinx-copybutton", + "sphinx-autodoc-typehints", "nbsphinx", "ipython", "numpyro",
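The changes to docs/conf.py and the new flowjax/_custom_types.py work together: the documentation build sets a flag on ``builtins`` before the package is imported, so ``ArrayLike`` resolves to a lightweight placeholder instead of ``jaxtyping.ArrayLike`` (which sphinx autodoc currently struggles with, per the linked equinox issue). The self-contained Python sketch below illustrates only that pattern; the standalone layout and the final ``print`` are illustrative and not part of the diff or the package.

# Sketch of the documentation-build switch introduced in this diff.
# docs/conf.py sets the flag on builtins before the package is imported,
# so the conditional import below picks the placeholder class.
import builtins

builtins.GENERATING_DOCUMENTATION = True  # as done at the top of docs/conf.py

if getattr(builtins, "GENERATING_DOCUMENTATION", False):

    class ArrayLike:
        """Placeholder type used only while building the documentation."""

else:
    from jaxtyping import ArrayLike  # normal runtime behaviour

print(ArrayLike)  # during a documentation build this is the placeholder class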