Skip to content

Commit

Permalink
Extra test
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed Sep 2, 2024
1 parent 15238b9 commit e53c865
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def test_Parameterize():
assert pytest.approx(unwrap(unwrappable)) == jnp.zeros((3, 2))


def test_nested_Parameterized():
param = Parameterize(
jnp.square,
Parameterize(jnp.square, Parameterize(jnp.square, 2)),
)
assert unwrap(param) == jnp.square(jnp.square(jnp.square(2)))


def test_NonTrainable_and_non_trainable():
dist1 = eqx.tree_at(lambda dist: dist.bijection, Normal(), replace_fn=NonTrainable)
dist2 = non_trainable(Normal())
Expand Down

0 comments on commit e53c865

Please sign in to comment.