Skip to content

Commit

Permalink
Merge pull request #196 from danielward27/rename_partial_and_rm_fit_t…
Browse files Browse the repository at this point in the history
…o_variational

Rename partial and rm fit to variational
  • Loading branch information
danielward27 authored Dec 4, 2024
2 parents d0b3139 + 4159b8b commit 92af68d
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 139 deletions.
4 changes: 2 additions & 2 deletions flowjax/bijections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
EmbedCondition,
Flip,
Identity,
Indexed,
Invert,
NumericalInverse,
Partial,
Permute,
Reshape,
)
Expand All @@ -42,7 +42,7 @@
"LeakyTanh",
"Loc",
"MaskedAutoregressive",
"Partial",
"Indexed",
"Permute",
"Power",
"Planar",
Expand Down
9 changes: 4 additions & 5 deletions flowjax/bijections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,13 @@ def inverse_and_log_det(self, y, condition=None):
return jnp.flip(y), jnp.array(0)


class Partial(AbstractBijection): # TODO rename to avoid confusion with functools
class Indexed(AbstractBijection):
"""Applies bijection to specific indices of an input.
Args:
bijection: Bijection that is compatible with the subset
of x indexed by idxs. idxs: Indices (Integer, a slice, or an ndarray
with integer/bool dtype) of the transformed portion.
idxs: The indexes to transform.
bijection: Bijection that is compatible with the subset of x indexed by idxs.
idxs: Indices (Integer, a slice, or an ndarray with integer/bool dtype) of the
transformed portion.
shape: Shape of the bijection. Defaults to None.
"""

Expand Down
3 changes: 1 addition & 2 deletions flowjax/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@

from .loops import fit_to_data, fit_to_key_based_loss
from .train_utils import step
from .variational_fit import fit_to_variational_target

__all__ = ["fit_to_key_based_loss", "fit_to_data", "fit_to_variational_target", "step"]
__all__ = ["fit_to_key_based_loss", "fit_to_data", "step"]
81 changes: 0 additions & 81 deletions flowjax/train/variational_fit.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_bijections/test_bijection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from equinox import EquinoxRuntimeError

from flowjax.bijections import Affine, Partial, Permute
from flowjax.bijections import Affine, Indexed, Permute

test_cases = {
# name: idx, expected
Expand All @@ -24,7 +24,7 @@ def test_partial(idx, expected):
"Check values only change where we expect."
x = jnp.zeros(4)
shape = x[idx].shape
bijection = Partial(Affine(jnp.ones(shape)), idx, x.shape)
bijection = Indexed(Affine(jnp.ones(shape)), idx, x.shape)
y = bijection.transform(x)
assert jnp.all((x != y) == expected)

Expand Down
10 changes: 5 additions & 5 deletions tests/test_bijections/test_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
Exp,
Flip,
Identity,
Indexed,
LeakyTanh,
Loc,
MaskedAutoregressive,
NumericalInverse,
Partial,
Permute,
Planar,
Power,
Expand Down Expand Up @@ -54,14 +54,14 @@
"Permute (3D)": lambda: Permute(
jnp.reshape(jr.permutation(KEY, jnp.arange(2 * 3 * 4)), (2, 3, 4)),
),
"Partial (int)": lambda: Partial(Affine(jnp.array(2), jnp.array(2)), 0, (DIM,)),
"Partial (bool array)": lambda: Partial(
"Indexed (int)": lambda: Indexed(Affine(jnp.array(2), jnp.array(2)), 0, (DIM,)),
"Indexed (bool array)": lambda: Indexed(
Flip((2,)),
jnp.array([True, False, True]),
(DIM,),
),
"Partial (int array)": lambda: Partial(Flip((2,)), jnp.array([0, 2]), (DIM,)),
"Partial (slice)": lambda: Partial(Affine(jnp.zeros(2)), slice(0, 2), (DIM,)),
"Indexed (int array)": lambda: Indexed(Flip((2,)), jnp.array([0, 2]), (DIM,)),
"Indexed (slice)": lambda: Indexed(Affine(jnp.zeros(2)), slice(0, 2), (DIM,)),
"Affine": lambda: Affine(jnp.ones(DIM), jnp.full(DIM, 2)),
"Affine (pos and neg scales)": lambda: eqx.tree_at(
lambda aff: aff.scale, Affine(scale=jnp.ones(3)), jnp.array([-1, 1, -2])
Expand Down
42 changes: 0 additions & 42 deletions tests/test_train/test_variational_fit.py

This file was deleted.

0 comments on commit 92af68d

Please sign in to comment.