Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use jax.tree_map in place of jax.tree_multimap
The latter is deprecated in jax and will be removed in jax-ml/jax#11382
- Loading branch information