diff --git a/numpyro/ops/indexing.py b/numpyro/ops/indexing.py index 3dfc4efd8b..c69fb225c3 100644 --- a/numpyro/ops/indexing.py +++ b/numpyro/ops/indexing.py @@ -1,14 +1,18 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from typing import Any + +from jax import Array import jax.numpy as jnp +from jax.typing import ArrayLike def _is_batched(arg): return jnp.ndim(arg) > 0 -def vindex(tensor, args): +def vindex(tensor: ArrayLike, args: tuple[Any, ...]) -> Array: """ Vectorized advanced indexing with broadcasting semantics. @@ -72,10 +76,10 @@ def vindex(tensor, args): This implementation is similar to the proposed notation ``x.vindex[]`` except for slightly different handling of ``Ellipsis``. - :param jnp.ndarray tensor: A tensor to be indexed. - :param tuple args: An index, as args to ``__getitem__``. + :param ArrayLike tensor: A tensor to be indexed. + :param tuple[Any, ...] args: An index, as args to ``__getitem__``. :returns: A nonstandard interpretation of ``tensor[args]``. - :rtype: jnp.ndarray + :rtype: Array """ if not isinstance(args, tuple): return tensor[args] @@ -140,8 +144,8 @@ class Vindex: :return: An object with a special :meth:`__getitem__` method. """ - def __init__(self, tensor): + def __init__(self, tensor: ArrayLike) -> None: self._tensor = tensor - def __getitem__(self, args): + def __getitem__(self, args: tuple[Any, ...]) -> Array: return vindex(self._tensor, args)