Skip to content

Commit

Permalink
Replace uses of internal JAX NumPy utils with public API functions. (#…
Browse files Browse the repository at this point in the history
…487)

* Replace uses of internal JAX NumPy utils with public API functions.

* Bug fix to update to new jaxopt version
  • Loading branch information
junpenglao authored Feb 17, 2023
1 parent d7831b5 commit f16fd05
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions blackjax/optimizers/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def minimize_lbfgs(
state=LbfgsState(
iter_num=last_step_raveled.state.iter_num,
value=last_step_raveled.state.value,
grad=unravel_fn(last_step_raveled.state.grad),
stepsize=last_step_raveled.state.stepsize,
error=last_step_raveled.state.error,
s_history=unravel_fn_mapped(last_step_raveled.state.s_history),
Expand Down
5 changes: 3 additions & 2 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import jax.numpy as jnp
from jax import jit, lax
from jax._src.numpy.util import _promote_dtypes
from jax.flatten_util import ravel_pytree
from jax.random import normal
from jax.tree_util import tree_leaves
Expand Down Expand Up @@ -42,7 +41,9 @@ def linear_map(diag_or_dense_a, b, *, precision="highest"):
-------
The result vector of the matrix multiplication.
"""
diag_or_dense_a, b = _promote_dtypes(diag_or_dense_a, b)
dtype = jnp.result_type(diag_or_dense_a.dtype, b.dtype)
diag_or_dense_a = diag_or_dense_a.astype(dtype)
b = b.astype(dtype)
ndim = jnp.ndim(diag_or_dense_a)

if ndim <= 1:
Expand Down

0 comments on commit f16fd05

Please sign in to comment.