From 8ed4def265b5ad409633c3a6fe5ca26a10fe6021 Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Wed, 11 Sep 2024 15:53:33 -0700 Subject: [PATCH] Fix flax test for flax v0.9.0 (#11) --- jax_ml_stack/tests/test_nnx_with_optax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_ml_stack/tests/test_nnx_with_optax.py b/jax_ml_stack/tests/test_nnx_with_optax.py index d571e76..9c81aa5 100644 --- a/jax_ml_stack/tests/test_nnx_with_optax.py +++ b/jax_ml_stack/tests/test_nnx_with_optax.py @@ -46,7 +46,7 @@ def loss(model, x=x, y=y): return jnp.mean((model(x) - y) ** 2) initial_loss = loss(model) - grads = nnx.grad(loss, wrt=nnx.Param)(state.model) + grads = nnx.grad(loss)(state.model) state.update(grads) final_loss = loss(model)