diff --git a/blackjax/types.py b/blackjax/types.py index 5f02bc661..dc2181a03 100644 --- a/blackjax/types.py +++ b/blackjax/types.py @@ -1,11 +1,7 @@ from typing import Any, Iterable, Mapping, Union import jax -import jax.numpy as jnp -import numpy as np - -#: JAX or Numpy array -Array = Union[np.ndarray, jnp.ndarray] +from chex import Array #: JAX PyTrees PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]