Skip to content

Commit

Permalink
Use public JAX API for PRNGKeyArray (#498)
Browse files Browse the repository at this point in the history
* Use public JAX API for PRNGKeyArray

* More type clean ups
- also fix import
  • Loading branch information
junpenglao committed Mar 12, 2024
1 parent a287632 commit 2193c95
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions blackjax/types.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from typing import Any, Iterable, Mapping, Union

import jax._src.prng as prng
import jax
import jax.numpy as jnp
import numpy as np

#: JAX or Numpy array
Array = Union[np.ndarray, jnp.ndarray]

#: JAX PyTrees
PyTree = Union[Array, Iterable[Array], Mapping[Any, Array]]
# It is not currently tested but we also support recursive PyTrees.
# Once recursive typing is fully supported (https://github.com/python/mypy/issues/731), we can uncomment the line below.
# PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]
PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]

#: JAX PRNGKey
PRNGKey = prng.PRNGKeyArray
PRNGKey = jax.random.PRNGKeyArray

0 comments on commit 2193c95

Please sign in to comment.