diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py
index 64250fda3..d306b1d01 100644
--- a/numpyro/contrib/tfp/distributions.py
+++ b/numpyro/contrib/tfp/distributions.py
@@ -1,6 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+import numpy as np
+
 from jax.dtypes import canonicalize_dtype
 import jax.numpy as jnp
 from tensorflow_probability.substrates.jax import bijectors as tfb
@@ -22,7 +24,7 @@ def _get_codomain(bijector):
         return constraints.positive
     elif bijector.__class__.__name__ == "GeneralizedPareto":
         loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
-        if not_jax_tracer(concentration) and jnp.all(concentration < 0):
+        if not_jax_tracer(concentration) and np.all(concentration < 0):
             return constraints.interval(loc, loc + scale / jnp.abs(concentration))
         # XXX: here we suppose concentration > 0
         # which is not true in general, but should cover enough usage cases
diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py
index aa16e57b7..54457b470 100644
--- a/numpyro/distributions/conjugate.py
+++ b/numpyro/distributions/conjugate.py
@@ -7,10 +7,9 @@
 from numpyro.distributions import constraints
 from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
-from numpyro.distributions.discrete import Binomial, Multinomial, Poisson
+from numpyro.distributions.discrete import BinomialProbs, MultinomialProbs, Poisson
 from numpyro.distributions.distribution import Distribution
 from numpyro.distributions.util import promote_shapes, validate_sample
-from numpyro.util import not_jax_tracer
 
 
 def _log_beta_1(alpha, value):
@@ -35,20 +34,23 @@ class BetaBinomial(Distribution):
                        'total_count': constraints.nonnegative_integer}
     has_enumerate_support = True
     is_discrete = True
+    enumerate_support = BinomialProbs.enumerate_support
 
     def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
+        self.concentration1, self.concentration0, self.total_count = promote_shapes(
+            concentration1, concentration0, total_count
+        )
         batch_shape = lax.broadcast_shapes(jnp.shape(concentration1), jnp.shape(concentration0),
                                            jnp.shape(total_count))
-        self.concentration1 = jnp.broadcast_to(concentration1, batch_shape)
-        self.concentration0 = jnp.broadcast_to(concentration0, batch_shape)
-        self.total_count, = promote_shapes(total_count, shape=batch_shape)
-        self._beta = Beta(self.concentration1, self.concentration0)
+        concentration1 = jnp.broadcast_to(concentration1, batch_shape)
+        concentration0 = jnp.broadcast_to(concentration0, batch_shape)
+        self._beta = Beta(concentration1, concentration0)
         super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)
 
     def sample(self, key, sample_shape=()):
         key_beta, key_binom = random.split(key)
         probs = self._beta.sample(key_beta, sample_shape)
-        return Binomial(self.total_count, probs).sample(key_binom)
+        return BinomialProbs(total_count=self.total_count, probs=probs).sample(key_binom)
 
     @validate_sample
     def log_prob(self, value):
@@ -68,18 +70,6 @@ def variance(self):
     def support(self):
         return constraints.integer_interval(0, self.total_count)
 
-    def enumerate_support(self, expand=True):
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
-            # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
-                raise NotImplementedError("Inhomogeneous total count not supported"
-                                          " by `enumerate_support`.")
-        values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
-        if expand:
-            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
-        return values
-
 
 class DirichletMultinomial(Distribution):
     r"""
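Why these hunks swap `jnp.all` for `np.all`: under omnistaging, `jnp` calls made while tracing return tracers even for constant inputs, so a Python `if` on the result can no longer branch. Keeping the check in NumPy, guarded by `not_jax_tracer`, keeps it concrete. A minimal sketch of the pattern, assuming only `numpyro.util.not_jax_tracer` from the codebase; `codomain_is_interval` is an illustrative stand-in, not a library function:

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpyro.util import not_jax_tracer

def codomain_is_interval(concentration):
    # np.all on a concrete NumPy value returns a plain bool that Python can
    # branch on; jnp.all under a jit trace would return a tracer instead.
    return not_jax_tracer(concentration) and np.all(concentration < 0)

print(codomain_is_interval(np.array([-1.0, -2.0])))  # True: checked eagerly

@jax.jit
def f(c):
    # inside jit, c is a tracer, so the guard short-circuits to False and the
    # caller falls through to the concentration > 0 branch the XXX comment describes
    return 1.0 if codomain_is_interval(c) else 0.0

print(f(jnp.array([-1.0, -2.0])))  # 0.0
```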
@@ -102,22 +92,18 @@ def __init__(self, concentration, total_count=1, validate_args=None):
             raise ValueError("`concentration` parameter must be at least one-dimensional.")
         batch_shape = lax.broadcast_shapes(jnp.shape(concentration)[:-1], jnp.shape(total_count))
-        self.concentration = jnp.broadcast_to(concentration, batch_shape + jnp.shape(concentration)[-1:])
-        self._dirichlet = Dirichlet(self.concentration)
+        concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
+        self.concentration, = promote_shapes(concentration, shape=concentration_shape)
         self.total_count, = promote_shapes(total_count, shape=batch_shape)
+        concentration = jnp.broadcast_to(self.concentration, concentration_shape)
+        self._dirichlet = Dirichlet(concentration)
         super().__init__(
             self._dirichlet.batch_shape, self._dirichlet.event_shape, validate_args=validate_args)
 
     def sample(self, key, sample_shape=()):
         key_dirichlet, key_multinom = random.split(key)
         probs = self._dirichlet.sample(key_dirichlet, sample_shape)
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
-            # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
-                raise NotImplementedError("Inhomogeneous total count not supported"
-                                          " by `sample`.")
-        return Multinomial(total_count, probs).sample(key_multinom)
+        return MultinomialProbs(total_count=self.total_count, probs=probs).sample(key_multinom)
 
     @validate_sample
     def log_prob(self, value):
@@ -157,9 +143,8 @@ class GammaPoisson(Distribution):
     is_discrete = True
 
     def __init__(self, concentration, rate=1., validate_args=None):
+        self.concentration, self.rate = promote_shapes(concentration, rate)
         self._gamma = Gamma(concentration, rate)
-        self.concentration = self._gamma.concentration
-        self.rate = self._gamma.rate
         super(GammaPoisson, self).__init__(self._gamma.batch_shape, validate_args=validate_args)
 
     def sample(self, key, sample_shape=()):
diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py
index db89b448e..35d50bfd1 100644
--- a/numpyro/distributions/constraints.py
+++ b/numpyro/distributions/constraints.py
@@ -50,7 +50,9 @@
     'Constraint',
 ]
 
-import jax.numpy as jnp
+import numpy as np
+
+import jax.numpy
 
 
 class Constraint(object):
@@ -79,6 +81,7 @@ def __call__(self, x):
 
 class _CorrCholesky(Constraint):
     def __call__(self, x):
+        jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
         tril = jnp.tril(x)
         lower_triangular = jnp.all(jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
         positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
@@ -89,6 +92,7 @@ def __call__(self, x):
 
 class _CorrMatrix(Constraint):
     def __call__(self, x):
+        jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
         # check for symmetric
         symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
         # check for the smallest eigenvalue is positive
@@ -129,7 +133,7 @@ def __init__(self, lower_bound, upper_bound):
         self.upper_bound = upper_bound
 
     def __call__(self, x):
-        return (x >= self.lower_bound) & (x <= self.upper_bound) & (x == jnp.floor(x))
+        return (x >= self.lower_bound) & (x <= self.upper_bound) & (x % 1 == 0)
 
 
 class _IntegerGreaterThan(Constraint):
@@ -151,6 +155,7 @@ def __call__(self, x):
 
 class _LowerCholesky(Constraint):
     def __call__(self, x):
+        jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
         tril = jnp.tril(x)
         lower_triangular = jnp.all(jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
         positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
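The `jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy` line is the backbone of this file's changes: each constraint picks its backend from the input's type, so NumPy parameters produce concrete booleans that `Distribution.__init__` can raise on, while JAX inputs stay traceable. The method-style rewrites below (`(x >= 0).all(...)`) achieve the same by duck typing. A small demonstration; `lower_cholesky_check` is an illustrative copy of the constraint body, not the library class:

```python
import numpy as np
import jax.numpy

def lower_cholesky_check(x):
    # dispatch on the input's type: NumPy in, NumPy out; JAX in, JAX out
    jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
    tril = jnp.tril(x)
    lower_triangular = jnp.all(jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
    positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
    return lower_triangular & positive_diagonal

x = np.array([[[1.0, 0.0], [0.5, 2.0]]])           # batch of one valid Cholesky factor
print(lower_cholesky_check(x))                     # [ True], a concrete np.ndarray
print(lower_cholesky_check(jax.numpy.asarray(x)))  # [ True], a device array, traceable
```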
@@ -162,16 +167,17 @@ def __init__(self, upper_bound):
         self.upper_bound = upper_bound
 
     def __call__(self, x):
-        return jnp.all(x >= 0, axis=-1) & (jnp.sum(x, -1) == self.upper_bound)
+        return (x >= 0).all(axis=-1) & (x.sum(axis=-1) == self.upper_bound)
 
 
 class _OrderedVector(Constraint):
     def __call__(self, x):
-        return jnp.all(x[..., 1:] > x[..., :-1], axis=-1)
+        return (x[..., 1:] > x[..., :-1]).all(axis=-1)
 
 
 class _PositiveDefinite(Constraint):
     def __call__(self, x):
+        jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
         # check for symmetric
         symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
         # check for the smallest eigenvalue is positive
@@ -182,18 +188,18 @@ def __call__(self, x):
 
 class _Real(Constraint):
     def __call__(self, x):
         # XXX: consider to relax this condition to [-inf, inf] interval
-        return jnp.isfinite(x)
+        return (x == x) & (x != float('inf')) & (x != float('-inf'))
 
 
 class _RealVector(Constraint):
     def __call__(self, x):
-        return jnp.all(jnp.isfinite(x), axis=-1)
+        return ((x == x) & (x != float('inf')) & (x != float('-inf'))).all(axis=-1)
 
 
 class _Simplex(Constraint):
     def __call__(self, x):
-        x_sum = jnp.sum(x, axis=-1)
-        return jnp.all(x >= 0, axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6)
+        x_sum = x.sum(axis=-1)
+        return (x >= 0).all(axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6)
 
 
 # TODO: Make types consistent
diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 0c7d7e61f..283b34d87 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -54,10 +54,11 @@ class Beta(Distribution):
     support = constraints.unit_interval
 
     def __init__(self, concentration1, concentration0, validate_args=None):
+        self.concentration1, self.concentration0 = promote_shapes(concentration1, concentration0)
         batch_shape = lax.broadcast_shapes(jnp.shape(concentration1), jnp.shape(concentration0))
-        self.concentration1 = jnp.broadcast_to(concentration1, batch_shape)
-        self.concentration0 = jnp.broadcast_to(concentration0, batch_shape)
-        self._dirichlet = Dirichlet(jnp.stack([self.concentration1, self.concentration0],
+        concentration1 = jnp.broadcast_to(concentration1, batch_shape)
+        concentration0 = jnp.broadcast_to(concentration0, batch_shape)
+        self._dirichlet = Dirichlet(jnp.stack([concentration1, concentration0],
                                               axis=-1))
         super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
 
@@ -687,8 +688,8 @@ class MultivariateNormal(Distribution):
 
     def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
                  validate_args=None):
-        if jnp.isscalar(loc):
-            loc = jnp.expand_dims(loc, axis=-1)
+        if jnp.ndim(loc) == 0:
+            loc, = promote_shapes(loc, shape=(1,))
         # temporary append a new axis to loc
         loc = loc[..., jnp.newaxis]
         if covariance_matrix is not None:
@@ -704,7 +705,7 @@ def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_
                             ' must be specified.')
         batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2])
         event_shape = jnp.shape(self.scale_tril)[-1:]
-        self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
+        self.loc = loc[..., 0]
         super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                                  event_shape=event_shape,
                                                  validate_args=validate_args)
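The recurring `promote_shapes` / `jnp.broadcast_to` split in these hunks is deliberate: public attributes keep the un-broadcast, possibly NumPy, parameters so validation stays concrete, while the materialized broadcast is confined to internal helpers. A sketch of the difference, assuming `promote_shapes` as patched in `numpyro/distributions/util.py` further down:

```python
import numpy as np
import jax.numpy as jnp
from numpyro.distributions.util import promote_shapes

c1 = np.ones(3)
c0 = np.ones((2, 1))
p1, p0 = promote_shapes(c1, c0)
# promote_shapes only prepends singleton axes and keeps NumPy inputs as NumPy
print(type(p1).__name__, p1.shape)  # ndarray (1, 3)
print(type(p0).__name__, p0.shape)  # ndarray (2, 1)
# jnp.broadcast_to materializes the full batch shape as a device array
b1 = jnp.broadcast_to(c1, (2, 3))
print(type(b1).__name__, b1.shape)  # a JAX array of shape (2, 3)
```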
@@ -731,7 +732,7 @@ def precision_matrix(self):
 
     @property
     def mean(self):
-        return self.loc
+        return jnp.broadcast_to(self.loc, self.shape())
 
     @property
     def variance(self):
@@ -817,7 +818,7 @@ def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
         loc, cov_factor, cov_diag = promote_shapes(loc[..., jnp.newaxis], cov_factor,
                                                    cov_diag[..., jnp.newaxis])
         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(cov_factor), jnp.shape(cov_diag))[:-2]
-        self.loc = jnp.broadcast_to(loc[..., 0], batch_shape + event_shape)
+        self.loc = loc[..., 0]
         self.cov_factor = cov_factor
         cov_diag = cov_diag[..., 0]
         self.cov_diag = cov_diag
@@ -937,22 +938,23 @@ class Pareto(TransformedDistribution):
     arg_constraints = {'scale': constraints.positive, 'alpha': constraints.positive}
 
     def __init__(self, scale, alpha, validate_args=None):
+        self.scale, self.alpha = promote_shapes(scale, alpha)
         batch_shape = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(alpha))
-        self.scale, self.alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to(alpha, batch_shape)
-        base_dist = Exponential(self.alpha)
-        transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
+        scale, alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to(alpha, batch_shape)
+        base_dist = Exponential(alpha)
+        transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
         super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)
 
     @property
     def mean(self):
         # mean is inf for alpha <= 1
-        a = lax.div(self.alpha * self.scale, (self.alpha - 1))
+        a = jnp.divide(self.alpha * self.scale, (self.alpha - 1))
         return jnp.where(self.alpha <= 1, jnp.inf, a)
 
     @property
     def variance(self):
         # var is inf for alpha <= 2
-        a = lax.div((self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2))
+        a = jnp.divide((self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2))
         return jnp.where(self.alpha <= 2, jnp.inf, a)
 
     # override the default behaviour to save computations
@@ -971,9 +973,9 @@ class StudentT(Distribution):
 
     def __init__(self, df, loc=0., scale=1., validate_args=None):
         batch_shape = lax.broadcast_shapes(jnp.shape(df), jnp.shape(loc), jnp.shape(scale))
-        self.df = jnp.broadcast_to(df, batch_shape)
-        self.loc, self.scale = promote_shapes(loc, scale, shape=batch_shape)
-        self._chi2 = Chi2(self.df)
+        self.df, self.loc, self.scale = promote_shapes(df, loc, scale, shape=batch_shape)
+        df = jnp.broadcast_to(df, batch_shape)
+        self._chi2 = Chi2(df)
         super(StudentT, self).__init__(batch_shape, validate_args=validate_args)
 
     def sample(self, key, sample_shape=()):
@@ -997,7 +999,7 @@ def mean(self):
 
     @property
     def variance(self):
-        var = jnp.where(self.df > 2, self.scale ** 2 * self.df / (self.df - 2.0), jnp.inf)
+        var = jnp.where(self.df > 2, jnp.divide(self.scale ** 2 * self.df, self.df - 2.0), jnp.inf)
         var = jnp.where(self.df <= 1, jnp.nan, var)
         return jnp.broadcast_to(var, self.batch_shape)
 
diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py
index 9e3517258..6a987e1fc 100644
--- a/numpyro/distributions/discrete.py
+++ b/numpyro/distributions/discrete.py
@@ -183,12 +183,14 @@ def support(self):
         return constraints.integer_interval(0, self.total_count)
 
     def enumerate_support(self, expand=True):
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
+        if not_jax_tracer(self.total_count):
+            total_count = np.amax(self.total_count)
             # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
+            if np.amin(self.total_count) != total_count:
                 raise NotImplementedError("Inhomogeneous total count not supported"
                                           " by `enumerate_support`.")
+        else:
+            total_count = jnp.amax(self.total_count)
         values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
         if expand:
             values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
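With the homogeneity check moved behind `not_jax_tracer`, `enumerate_support` still raises eagerly for concrete inhomogeneous counts but no longer breaks under tracing. A usage sketch of the post-patch behavior; the shapes follow from the code above:

```python
import numpy as np
import numpyro.distributions as dist

d = dist.Binomial(total_count=2, probs=np.array([0.2, 0.7]))
print(d.enumerate_support().shape)              # (3, 2): values 0..2 over the batch
print(d.enumerate_support(expand=False).shape)  # (3, 1)

d_inhom = dist.Binomial(total_count=np.array([2, 3]), probs=0.5)
# d_inhom.enumerate_support()  # NotImplementedError: inhomogeneous total count
```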
@@ -200,6 +202,7 @@ class BinomialLogits(Distribution):
                        'total_count': constraints.nonnegative_integer}
     has_enumerate_support = True
     is_discrete = True
+    enumerate_support = BinomialProbs.enumerate_support
 
     def __init__(self, logits, total_count=1, validate_args=None):
         self.logits, self.total_count = promote_shapes(logits, total_count)
@@ -235,18 +238,6 @@ def variance(self):
     def support(self):
         return constraints.integer_interval(0, self.total_count)
 
-    def enumerate_support(self, expand=True):
-        total_count = jnp.amax(self.total_count)
-        if not_jax_tracer(total_count):
-            # NB: the error can't be raised if inhomogeneous issue happens when tracing
-            if jnp.amin(self.total_count) != total_count:
-                raise NotImplementedError("Inhomogeneous total count not supported"
-                                          " by `enumerate_support`.")
-        values = jnp.arange(total_count + 1).reshape((-1,) + (1,) * len(self.batch_shape))
-        if expand:
-            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
-        return values
-
 
 def Binomial(total_count=1, probs=None, logits=None, validate_args=None):
     if probs is not None:
@@ -421,7 +412,11 @@ class OrderedLogistic(CategoricalProbs):
                        'cutpoints': constraints.ordered_vector}
 
     def __init__(self, predictor, cutpoints, validate_args=None):
-        predictor, self.cutpoints = promote_shapes(jnp.expand_dims(predictor, -1), cutpoints)
+        if jnp.ndim(predictor) == 0:
+            predictor, = promote_shapes(predictor, shape=(1,))
+        else:
+            predictor = predictor[..., None]
+        predictor, self.cutpoints = promote_shapes(predictor, cutpoints)
         self.predictor = predictor[..., 0]
         cumulative_probs = expit(cutpoints - predictor)
         # add two boundary points 0 and 1
diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py
index 9d6fcbbdf..0281bdac9 100644
--- a/numpyro/distributions/distribution.py
+++ b/numpyro/distributions/distribution.py
@@ -29,6 +29,8 @@
 from contextlib import contextmanager
 import warnings
 
+import numpy as np
+
 from jax import lax, tree_util
 import jax.numpy as jnp
 
@@ -150,9 +152,9 @@ def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
                     continue
                 if is_dependent(constraint):
                     continue  # skip constraints that cannot be checked
-                is_valid = jnp.all(constraint(getattr(self, param)))
+                is_valid = constraint(getattr(self, param))
                 if not_jax_tracer(is_valid):
-                    if not is_valid:
+                    if not np.all(is_valid):
                         raise ValueError("{} distribution got invalid {} parameter.".format(
                             self.__class__.__name__, param))
         super(Distribution, self).__init__()
@@ -255,7 +257,7 @@ def variance(self):
     def _validate_sample(self, value):
         mask = self.support(value)
         if not_jax_tracer(mask):
-            if not jnp.all(mask):
+            if not np.all(mask):
                 warnings.warn('Out-of-support values provided to log prob method. '
                               'The value argument should be within the support.')
         return mask
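The two `distribution.py` hunks are the payoff of keeping parameters in NumPy: `constraint(...)` now returns a concrete array for NumPy inputs, so `np.all` can drive a real `raise` or `warn`, while a traced result simply skips the check instead of crashing mid-trace. A sketch of both paths; the messages are quoted from the code above:

```python
import warnings
import numpy as np
import numpyro.distributions as dist

# concrete invalid parameter -> eager ValueError
try:
    dist.Normal(0.0, np.array([-1.0, 1.0]), validate_args=True)
except ValueError as e:
    print(e)  # Normal distribution got invalid scale parameter.

# concrete out-of-support value -> UserWarning from _validate_sample
d = dist.Normal(0.0, 1.0, validate_args=True)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    d.log_prob(np.nan)   # nan fails the (x == x) real-support check
    print(w[0].message)  # Out-of-support values provided to log prob method. ...
```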
diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py
index 4e7be1010..82facf7c9 100644
--- a/numpyro/distributions/transforms.py
+++ b/numpyro/distributions/transforms.py
@@ -4,6 +4,8 @@
 import math
 import warnings
 
+import numpy as np
+
 from jax import ops, tree_flatten, tree_map, vmap
 from jax.dtypes import canonicalize_dtype
 from jax.flatten_util import ravel_pytree
@@ -96,19 +98,19 @@ def codomain(self):
         elif self.domain is constraints.real_vector:
             return constraints.real_vector
         elif isinstance(self.domain, constraints.greater_than):
-            if not_jax_tracer(self.scale) and jnp.all(self.scale < 0):
+            if not_jax_tracer(self.scale) and np.all(self.scale < 0):
                 return constraints.less_than(self(self.domain.lower_bound))
            # we suppose scale > 0 for any tracer
             else:
                 return constraints.greater_than(self(self.domain.lower_bound))
         elif isinstance(self.domain, constraints.less_than):
-            if not_jax_tracer(self.scale) and jnp.all(self.scale < 0):
+            if not_jax_tracer(self.scale) and np.all(self.scale < 0):
                 return constraints.greater_than(self(self.domain.upper_bound))
            # we suppose scale > 0 for any tracer
             else:
                 return constraints.less_than(self(self.domain.upper_bound))
         elif isinstance(self.domain, constraints.interval):
-            if not_jax_tracer(self.scale) and jnp.all(self.scale < 0):
+            if not_jax_tracer(self.scale) and np.all(self.scale < 0):
                 return constraints.interval(self(self.domain.upper_bound),
                                             self(self.domain.lower_bound))
             else:
diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py
index 5a357428c..fca6f182a 100644
--- a/numpyro/distributions/util.py
+++ b/numpyro/distributions/util.py
@@ -4,6 +4,8 @@
 from functools import update_wrapper
 import math
 
+import numpy as np
+
 from jax import jit, lax, random, vmap
 from jax.dtypes import canonicalize_dtype
 from jax.lib import xla_bridge
@@ -232,6 +234,13 @@ def binary_cross_entropy_with_logits(x, y):
     return jnp.clip(x, 0) + jnp.log1p(jnp.exp(-jnp.abs(x))) - x * y
 
 
+def _reshape(x, shape):
+    if isinstance(x, (int, float, np.ndarray, np.generic)):
+        return np.reshape(x, shape)
+    else:
+        return jnp.reshape(x, shape)
+
+
 def promote_shapes(*args, shape=()):
     # adapted from lax.lax_numpy
     if len(args) < 2 and not shape:
@@ -239,7 +248,7 @@ def promote_shapes(*args, shape=()):
     else:
         shapes = [jnp.shape(arg) for arg in args]
         num_dims = len(lax.broadcast_shapes(shape, *shapes))
-        return [lax.reshape(arg, (1,) * (num_dims - len(s)) + s)
+        return [_reshape(arg, (1,) * (num_dims - len(s)) + s)
                 if len(s) < num_dims else arg for arg, s in zip(args, shapes)]
 
diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py
index ed32558ce..fc2e442dc 100644
--- a/numpyro/infer/autoguide.py
+++ b/numpyro/infer/autoguide.py
@@ -6,6 +6,8 @@
 from contextlib import ExitStack
 import warnings
 
+import numpy as np
+
 from jax import hessian, lax, random, tree_map
 from jax.experimental import stax
 from jax.flatten_util import ravel_pytree
@@ -647,7 +649,7 @@ def loss_fn(z):
         precision = hessian(loss_fn)(loc)
         scale_tril = cholesky_of_inverse(precision)
         if not_jax_tracer(scale_tril):
-            if jnp.any(jnp.isnan(scale_tril)):
+            if np.any(np.isnan(scale_tril)):
                 warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                               " samples from AutoLaplaceApproxmiation will be constant (equal to"
                               " the MAP point).")
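`_reshape` is the helper that makes all of this stick: `promote_shapes` now keeps Python scalars and NumPy arrays in NumPy instead of pushing them through `lax.reshape`, so parameters remain concrete until a genuinely traced value touches them. The same idea lets `AffineTransform.codomain` flip interval bounds eagerly for a concrete negative scale. A sketch, assuming the `domain` keyword of `AffineTransform` and `constraints.interval` as in the codebase:

```python
import numpy as np
from numpyro.distributions import constraints
from numpyro.distributions.transforms import AffineTransform
from numpyro.distributions.util import promote_shapes

# _reshape keeps NumPy (and scalar) inputs in NumPy
x, y = promote_shapes(1.0, np.zeros((2, 3)))
print(type(x).__name__, x.shape)  # ndarray (1, 1)

# a concrete negative scale flips the interval via the np.all guard above
t = AffineTransform(loc=0.0, scale=np.array(-2.0), domain=constraints.interval(0.0, 1.0))
print(t.codomain.lower_bound, t.codomain.upper_bound)  # -2.0 0.0
```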
Posterior" " samples from AutoLaplaceApproxmiation will be constant (equal to" " the MAP point).") diff --git a/test/test_distributions.py b/test/test_distributions.py index 9f070d692..0e0a51a8f 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -728,6 +728,16 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): with pytest.raises(ValueError): jax_dist(*oob_params, validate_args=True) + with pytest.raises(ValueError): + # test error raised under jit omnistaging + oob_params = jax.device_get(oob_params) + + def dist_gen_fn(): + d = jax_dist(*oob_params, validate_args=True) + return d + + jax.jit(dist_gen_fn)() + d = jax_dist(*valid_params, validate_args=True) # Test agreement of log density evaluation on randomly generated samples @@ -744,9 +754,36 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): # Out of support samples throw ValueError oob_samples = gen_values_outside_bounds(d.support, size=prepend_shape + d.batch_shape + d.event_shape) - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match="Out-of-support"): d.log_prob(oob_samples) + with pytest.warns(UserWarning, match="Out-of-support"): + # test warning work under jit omnistaging + oob_samples = jax.device_get(oob_samples) + valid_params = jax.device_get(valid_params) + + def log_prob_fn(): + d = jax_dist(*valid_params, validate_args=True) + return d.log_prob(oob_samples) + + jax.jit(log_prob_fn)() + + +def test_omnistaging_invalid_param(): + def f(x): + return dist.LogNormal(x, -np.ones(2), validate_args=True).log_prob(0) + + with pytest.raises(ValueError, match="got invalid"): + jax.jit(f)(0) + + +def test_omnistaging_invalid_sample(): + def f(x): + return dist.LogNormal(x, np.ones(2), validate_args=True).log_prob(-1) + + with pytest.warns(UserWarning, match="Out-of-support"): + jax.jit(f)(0) + def test_categorical_log_prob_grad(): data = jnp.repeat(jnp.arange(3), 10)