Skip to content

Commit

Permalink
Fix flax test for flax v0.9.0 (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored Sep 11, 2024
1 parent a9f2473 commit 8ed4def
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax_ml_stack/tests/test_nnx_with_optax.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def loss(model, x=x, y=y):
return jnp.mean((model(x) - y) ** 2)

initial_loss = loss(model)
grads = nnx.grad(loss, wrt=nnx.Param)(state.model)
grads = nnx.grad(loss)(state.model)
state.update(grads)
final_loss = loss(model)

Expand Down

0 comments on commit 8ed4def

Please sign in to comment.