Skip to content

Commit

Permalink
More type clean ups
Browse files Browse the repository at this point in the history
- also fix import
  • Loading branch information
junpenglao committed Feb 27, 2023
1 parent 56cc5f5 commit 3224185
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions blackjax/types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Any, Iterable, Mapping, Union

import jax
import jax.numpy as jnp
import numpy as np

#: JAX or Numpy array
Array = Union[np.ndarray, jnp.ndarray]

#: JAX PyTrees
PyTree = Union[Array, Iterable[Array], Mapping[Any, Array]]
# It is not currently tested but we also support recursive PyTrees.
# Once recursive typing is fully supported (https://github.com/python/mypy/issues/731), we can uncomment the line below.
# PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]
PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]

#: JAX PRNGKey
PRNGKey = jax.random.PRNGKeyArray

0 comments on commit 3224185

Please sign in to comment.