From 723e6bf61b6ae4942048afcc70f75ff5f995f84c Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Fri, 8 Nov 2024 22:45:37 -0500 Subject: [PATCH] Fix AdaBelief implementation. --- optax/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 69b24bdc1..562763dd3 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -689,7 +689,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params mu = otu.tree_update_moment(updates, state.mu, b1, 1) - prediction_error = jax.tree.map(lambda g, m: g - m, updates, state.mu) + prediction_error = otu.tree_sub(updates, mu) nu = otu.tree_update_moment_per_elem_norm(prediction_error, state.nu, b2, 2) nu = jax.tree.map(lambda v: v + eps_root, nu) count_inc = numerics.safe_increment(state.count)