From 7d5039378bbe02fe7363413c167060672119e497 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Wed, 18 Sep 2024 01:56:17 +0500 Subject: [PATCH] gh-1806: Implementation of Doubly Truncated Power Law and Lower Truncated Power Law (#1807) * implementation of DoublyTruncatedPowerLaw * implementation of LowerTruncatedPowerLaw * chore: mathematical description in docstrings * chore: mathematical details of `LowerTruncatedPowerLaw` * chore: Fix bug in DoublyTruncatedPowerLaw cdf and icdf calculation * chore: Refactor mean and variance calculation by using kth-moment in DoublyTruncatedPowerLaw * chore: Refactor mean and variance calculation in LowerTruncatedPowerLaw * chore: masking in icdf of LowerTruncatedPowerLaw * chore: entropy of LowerTruncatedPowerLaw * chore: `lax.sqaure` replaced with `jnp.sqaure` * chore: moments and entropy were extra and removed * chore: unit tests * fix: nan gradients fixed, values still diverging * Updated UpperTruncatedPowerLaw with adequate derivations, including fixing the discontinuity point for alpha equals minus one and correct tangents for lower and upper bounds. * Changed constrains of alpha of LowerTruncatedPowerLaw to the smaller minus one and also changed equation to keep equational uniformity UpperTruncatedPowerLaw to and improved stability of the calculation by integrating formula parts inside log to prevent NaN values due to negative values of some parts that could balance out in the end. * chore: code and docstring formated * chore: equation refactor and simplified * chore: equation refactor and simplified * chore: use numpy arrays and numpy constants * chore: high precision computation enable for powerlaws * chore: `__name__` attribute calls removed * chore: powerlaws shifted with truncated distributions * chore: spelling mistakes fixed with code spell checker pre-commit hook * fix typo: perforance->perforamce->performance * chore: explicit enabling/disabling of 64bit floating point numbers * chore: disable everytime and enable x64 for power laws * chore: disable x64 for every test * chore: linked explanation in comments for disabling x64 for future reference for devs * chore: high precision test handeled efficiently for DoublyTruncatedPowerLaw * chore: high precision exception handled in test_log_prob_gradient --------- Co-authored-by: David Ziegler <25408738+InfinityMod@users.noreply.github.com> --- .github/workflows/ci.yml | 3 + .pre-commit-config.yaml | 8 + README.md | 75 ++-- docker/README.md | 2 +- docs/source/distributions.rst | 16 + examples/annotation.py | 2 +- examples/ar2.py | 2 +- examples/holt_winters.py | 2 +- examples/mortality.py | 2 +- numpyro/distributions/__init__.py | 26 +- numpyro/distributions/continuous.py | 12 +- numpyro/distributions/distribution.py | 4 +- numpyro/distributions/gof.py | 4 +- numpyro/distributions/transforms.py | 6 +- numpyro/distributions/truncated.py | 533 ++++++++++++++++++++++++++ numpyro/infer/barker.py | 2 +- numpyro/infer/elbo.py | 2 +- numpyro/infer/ensemble.py | 4 +- numpyro/infer/inspect.py | 4 +- numpyro/infer/mcmc.py | 2 +- numpyro/infer/mixed_hmc.py | 2 +- numpyro/optim.py | 2 +- test/contrib/test_infer_discrete.py | 2 +- test/test_distributions.py | 40 +- 24 files changed, 672 insertions(+), 85 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 47fabe4e7..996344220 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,6 +72,9 @@ jobs: - name: Test with pytest run: | CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ + - name: Test x64 + run: | + JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k powerLaw test-inference: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bdbbe787a..b22262c06 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,11 @@ repos: - id: check-yaml - id: check-added-large-files exclude: notebooks/* + + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + stages: [commit, commit-msg] + args: + [--ignore-words-list, "Teh,aas", --check-filenames, --skip, "*.ipynb"] diff --git a/README.md b/README.md index 83fcc91ad..62a7fc5cd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ [![Build Status](https://github.com/pyro-ppl/numpyro/workflows/CI/badge.svg)](https://github.com/pyro-ppl/numpyro/actions) [![Documentation Status](https://readthedocs.org/projects/numpyro/badge/?version=latest)](https://numpyro.readthedocs.io/en/latest/?badge=latest) [![Latest Version](https://badge.fury.io/py/numpyro.svg)](https://pypi.python.org/pypi/numpyro) + # NumPyro Probabilistic programming powered by [JAX](https://github.com/google/jax) for autograd and JIT compilation to GPU/TPU/CPU. @@ -15,10 +16,10 @@ NumPyro is a lightweight probabilistic programming library that provides a NumPy NumPyro is designed to be *lightweight* and focuses on providing a flexible substrate that users can build on: - - **Pyro Primitives:** NumPyro programs can contain regular Python and NumPy code, in addition to [Pyro primitives](https://pyro.ai/examples/intro_part_i.html) like `sample` and `param`. The model code should look very similar to Pyro except for some minor differences between PyTorch and Numpy's API. See the [example](https://github.com/pyro-ppl/numpyro#a-simple-example---8-schools) below. - - **Inference algorithms:** NumPyro supports a number of inference algorithms, with a particular focus on MCMC algorithms like Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. Additional MCMC algorithms include [MixedHMC](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) (which can accommodate discrete latent variables) as well as [HMCECS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCECS) (which only computes the likelihood for subsets of the data in each iteration). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the verlet integrator that includes multiple gradient computations. With JAX, we can compose `jit` and `grad` to compile the entire integration step into an XLA optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree building stage in NUTS (this is possible using [Iterative NUTS](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)). There is also a basic Variational Inference implementation together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI). The variational inference implementation supports a number of features, including support for models with discrete latent variables (see [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) and [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO)). - - **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/google/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models. - - **Effect handlers:** Like Pyro, primitives like `sample` and `param` can be provided nonstandard interpretations using effect-handlers from the [numpyro.handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) module, and these can be easily extended to implement custom inference algorithms and inference utilities. +- **Pyro Primitives:** NumPyro programs can contain regular Python and NumPy code, in addition to [Pyro primitives](https://pyro.ai/examples/intro_part_i.html) like `sample` and `param`. The model code should look very similar to Pyro except for some minor differences between PyTorch and Numpy's API. See the [example](https://github.com/pyro-ppl/numpyro#a-simple-example---8-schools) below. +- **Inference algorithms:** NumPyro supports a number of inference algorithms, with a particular focus on MCMC algorithms like Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. Additional MCMC algorithms include [MixedHMC](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) (which can accommodate discrete latent variables) as well as [HMCECS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCECS) (which only computes the likelihood for subsets of the data in each iteration). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the verlet integrator that includes multiple gradient computations. With JAX, we can compose `jit` and `grad` to compile the entire integration step into an XLA optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree building stage in NUTS (this is possible using [Iterative NUTS](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)). There is also a basic Variational Inference implementation together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI). The variational inference implementation supports a number of features, including support for models with discrete latent variables (see [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) and [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO)). +- **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/google/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models. +- **Effect handlers:** Like Pyro, primitives like `sample` and `param` can be provided nonstandard interpretations using effect-handlers from the [numpyro.handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) module, and these can be easily extended to implement custom inference algorithms and inference utilities. ## A Simple Example - 8 Schools @@ -34,6 +35,7 @@ The data is given by: >>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) ``` + , where `y` are the treatment effects and `sigma` the standard error. We build a hierarchical model for the study where we assume that the group-level parameters `theta` for each school are sampled from a Normal distribution with unknown mean `mu` and standard deviation `tau`, while the observed data are in turn generated from a Normal distribution with mean and standard deviation given by `theta` (true effect) and `sigma`, respectively. This allows us to estimate the population-level parameters `mu` and `tau` by pooling from all the observations, while still allowing for individual variation amongst the schools using the group-level `theta` parameters. ```python @@ -88,7 +90,7 @@ Expected log joint density: -54.55 ``` -The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicates that the chain has not fully converged. The low value for the effective sample size (`n_eff`), particularly for `tau`, and the number of divergent transitions looks problematic. Fortunately, this is a common pathology that can be rectified by using a [non-centered paramaterization](https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html) for `tau` in our model. This is straightforward to do in NumPyro by using a [TransformedDistribution](https://num.pyro.ai/en/latest/distributions.html#transformeddistribution) instance together with a [reparameterization](https://num.pyro.ai/en/latest/handlers.html#reparam) effect handler. Let us rewrite the same model but instead of sampling `theta` from a `Normal(mu, tau)`, we will instead sample it from a base `Normal(0, 1)` distribution that is transformed using an [AffineTransform](https://num.pyro.ai/en/latest/distributions.html#affinetransform). Note that by doing so, NumPyro runs HMC by generating samples `theta_base` for the base `Normal(0, 1)` distribution instead. We see that the resulting chain does not suffer from the same pathology — the Gelman Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good! +The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicates that the chain has not fully converged. The low value for the effective sample size (`n_eff`), particularly for `tau`, and the number of divergent transitions looks problematic. Fortunately, this is a common pathology that can be rectified by using a [non-centered parameterization](https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html) for `tau` in our model. This is straightforward to do in NumPyro by using a [TransformedDistribution](https://num.pyro.ai/en/latest/distributions.html#transformeddistribution) instance together with a [reparameterization](https://num.pyro.ai/en/latest/handlers.html#reparam) effect handler. Let us rewrite the same model but instead of sampling `theta` from a `Normal(mu, tau)`, we will instead sample it from a base `Normal(0, 1)` distribution that is transformed using an [AffineTransform](https://num.pyro.ai/en/latest/distributions.html#affinetransform). Note that by doing so, NumPyro runs HMC by generating samples `theta_base` for the base `Normal(0, 1)` distribution instead. We see that the resulting chain does not suffer from the same pathology — the Gelman Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good! ```python >>> from numpyro.infer.reparam import TransformReparam @@ -165,23 +167,21 @@ Now, let us assume that we have a new school for which we have not observed any ## More Examples - For some more examples on specifying models and doing inference in NumPyro: - - [Bayesian Regression in NumPyro](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_regression.ipynb) - Start here to get acquainted with writing a simple model in NumPyro, MCMC inference API, effect handlers and writing custom inference utilities. - - [Time Series Forecasting](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/time_series_forecasting.ipynb) - Illustrates how to convert for loops in the model to JAX's `lax.scan` primitive for fast inference. - - [Annotation examples](https://num.pyro.ai/en/stable/examples/annotation.html) - Illustrates how to utilize the enumeration mechanism to perform inference for models with discrete latent variables. - - [Baseball example](https://github.com/pyro-ppl/numpyro/blob/master/examples/baseball.py) - Using NUTS for a simple hierarchical model. Compare this with the baseball example in [Pyro](https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py). - - [Hidden Markov Model](https://github.com/pyro-ppl/numpyro/blob/master/examples/hmm.py) in NumPyro as compared to [Stan](https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html). - - [Variational Autoencoder](https://github.com/pyro-ppl/numpyro/blob/master/examples/vae.py) - As a simple example that uses Variational Inference with neural networks. [Pyro implementation](https://github.com/pyro-ppl/pyro/blob/dev/examples/vae/vae.py) for comparison. - - [Gaussian Process](https://github.com/pyro-ppl/numpyro/blob/master/examples/gp.py) - Provides a simple example to use NUTS to sample from the posterior over the hyper-parameters of a Gaussian Process. - - [Horseshoe Regression](https://github.com/pyro-ppl/numpyro/blob/master/examples/horseshoe_regression.py) - Shows how to implemement generalized linear models equipped with a Horseshoe prior for both binary-valued and real-valued outputs. - - [Statistical Rethinking with NumPyro](https://github.com/fehiepsi/rethinking-numpyro) - [Notebooks](https://nbviewer.jupyter.org/github/fehiepsi/rethinking-numpyro/tree/master/notebooks/) containing translation of the code in Richard McElreath's [Statistical Rethinking](https://xcelab.net/rm/statistical-rethinking/) book second version, to NumPyro. - - Other model examples can be found in the [examples](https://num.pyro.ai/en/stable/) site. +- [Bayesian Regression in NumPyro](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_regression.ipynb) - Start here to get acquainted with writing a simple model in NumPyro, MCMC inference API, effect handlers and writing custom inference utilities. +- [Time Series Forecasting](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/time_series_forecasting.ipynb) - Illustrates how to convert for loops in the model to JAX's `lax.scan` primitive for fast inference. +- [Annotation examples](https://num.pyro.ai/en/stable/examples/annotation.html) - Illustrates how to utilize the enumeration mechanism to perform inference for models with discrete latent variables. +- [Baseball example](https://github.com/pyro-ppl/numpyro/blob/master/examples/baseball.py) - Using NUTS for a simple hierarchical model. Compare this with the baseball example in [Pyro](https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py). +- [Hidden Markov Model](https://github.com/pyro-ppl/numpyro/blob/master/examples/hmm.py) in NumPyro as compared to [Stan](https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html). +- [Variational Autoencoder](https://github.com/pyro-ppl/numpyro/blob/master/examples/vae.py) - As a simple example that uses Variational Inference with neural networks. [Pyro implementation](https://github.com/pyro-ppl/pyro/blob/dev/examples/vae/vae.py) for comparison. +- [Gaussian Process](https://github.com/pyro-ppl/numpyro/blob/master/examples/gp.py) - Provides a simple example to use NUTS to sample from the posterior over the hyper-parameters of a Gaussian Process. +- [Horseshoe Regression](https://github.com/pyro-ppl/numpyro/blob/master/examples/horseshoe_regression.py) - Shows how to implement generalized linear models equipped with a Horseshoe prior for both binary-valued and real-valued outputs. +- [Statistical Rethinking with NumPyro](https://github.com/fehiepsi/rethinking-numpyro) - [Notebooks](https://nbviewer.jupyter.org/github/fehiepsi/rethinking-numpyro/tree/master/notebooks/) containing translation of the code in Richard McElreath's [Statistical Rethinking](https://xcelab.net/rm/statistical-rethinking/) book second version, to NumPyro. +- Other model examples can be found in the [examples](https://num.pyro.ai/en/stable/) site. Pyro users will note that the API for model specification and inference is largely the same as Pyro, including the distributions API, by design. However, there are some important core differences (reflected in the internals) that users should be aware of. e.g. in NumPyro, there is no global parameter store or random state, to make it possible for us to leverage JAX's JIT compilation. Also, users may need to write their models in a more *functional* style that works better with JAX. Refer to [FAQs](#frequently-asked-questions) for a list of differences. - ## Overview of inference algorithms We provide an overview of most of the inference algorithms supported by NumPyro and offer some guidelines about which inference algorithms may be appropriate for different classes of models. @@ -200,23 +200,25 @@ As discussed above, model [reparameterization](https://num.pyro.ai/en/latest/rep Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see [restrictions](https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence)). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the [annotation example](https://num.pyro.ai/en/stable/examples/annotation.html). ### Nested Sampling + - [NestedSampler](https://num.pyro.ai/en/latest/contrib.html#nested-sampling) offers a wrapper for [jaxns](https://github.com/Joshuaalbert/jaxns). See [JAXNS's readthedocs](https://jaxns.readthedocs.io/en/latest/) for examples and [Nested Sampling for Gaussian Shells](https://num.pyro.ai/en/stable/examples/gaussian_shells.html) example for how to apply the sampler on numpyro models. Can handle arbitrary models, including ones with discrete RVs, and non-invertible transformations. ### Stochastic variational inference + - Variational objectives - - [Trace_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.Trace_ELBO) is our basic ELBO implementation. - - [TraceMeanField_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceMeanField_ELBO) is like `Trace_ELBO` but computes part of the ELBO analytically if doing so is possible. - - [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) offers variance reduction strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables. - - [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO) offers variable enumeration strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables when enumeration is possible. + - [Trace_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.Trace_ELBO) is our basic ELBO implementation. + - [TraceMeanField_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceMeanField_ELBO) is like `Trace_ELBO` but computes part of the ELBO analytically if doing so is possible. + - [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) offers variance reduction strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables. + - [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO) offers variable enumeration strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables when enumeration is possible. - Automatic guides (appropriate for models with continuous latent variables) - - [AutoNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoNormal) and [AutoDiagonalNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDiagonalNormal) are our basic mean-field guides. If the latent space is non-euclidean (due to e.g. a positivity constraint on one of the sample sites) an appropriate bijective transformation is automatically used under the hood to map between the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing. - - [AutoMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoMultivariateNormal) and [AutoLowRankMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLowRankMultivariateNormal) also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting. - - [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage. - - [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoIAFNormal) offer flexible variational distributions parameterized by normalizing flows. - - [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDAIS) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. - - [AutoSurrogateLikelihoodDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS) is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. - - [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoSemiDAIS) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDAIS) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. - - [AutoLaplaceApproximation](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation) can be used to compute a Laplace approximation. + - [AutoNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoNormal) and [AutoDiagonalNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDiagonalNormal) are our basic mean-field guides. If the latent space is non-euclidean (due to e.g. a positivity constraint on one of the sample sites) an appropriate bijective transformation is automatically used under the hood to map between the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing. + - [AutoMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoMultivariateNormal) and [AutoLowRankMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLowRankMultivariateNormal) also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting. + - [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage. + - [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoIAFNormal) offer flexible variational distributions parameterized by normalizing flows. + - [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDAIS) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. + - [AutoSurrogateLikelihoodDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS) is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. + - [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoSemiDAIS) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDAIS) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. + - [AutoLaplaceApproximation](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation) can be used to compute a Laplace approximation. ### Stein Variational Inference @@ -240,9 +242,11 @@ pip install numpyro[cpu] ``` To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command: + ``` pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` + If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda). To run **NumPyro on Cloud TPUs**, you can look at some [JAX on Cloud TPU examples](https://github.com/google/jax/tree/master/cloud_tpu_colabs). @@ -263,6 +267,7 @@ pip install -e .[dev] # contains additional dependencies for NumPyro developmen ``` You can also install NumPyro with conda: + ``` conda install -c conda-forge numpyro ``` @@ -308,21 +313,19 @@ conda install -c conda-forge numpyro For most small models, changes required to run inference in NumPyro should be minor. Additionally, we are working on [pyro-api](https://github.com/pyro-ppl/pyro-api) which allows you to write the same code and dispatch it to multiple backends, including NumPyro. This will necessarily be more restrictive, but has the advantage of being backend agnostic. See the [documentation](https://pyro-api.readthedocs.io/en/latest/dispatch.html#module-pyroapi.dispatch) for an example, and let us know your feedback. - 3. How can I contribute to the project? Thanks for your interest in the project! You can take a look at beginner friendly issues that are marked with the [good first issue](https://github.com/pyro-ppl/numpyro/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) tag on Github. Also, please feel to reach out to us on the [forum](https://forum.pyro.ai/). - ## Future / Ongoing Work In the near term, we plan to work on the following. Please open new issues for feature requests and enhancements: - - Improving robustness of inference on different models, profiling and performance tuning. - - Supporting more functionality as part of the [pyro-api](https://github.com/pyro-ppl/pyro-api) generic modeling interface. - - More inference algorithms, particularly those that require second order derivatives or use HMC. - - Integration with [Funsor](https://github.com/pyro-ppl/funsor) to support inference algorithms with delayed sampling. - - Other areas motivated by Pyro's research goals and application focus, and interest from the community. +- Improving robustness of inference on different models, profiling and performance tuning. +- Supporting more functionality as part of the [pyro-api](https://github.com/pyro-ppl/pyro-api) generic modeling interface. +- More inference algorithms, particularly those that require second order derivatives or use HMC. +- Integration with [Funsor](https://github.com/pyro-ppl/funsor) to support inference algorithms with delayed sampling. +- Other areas motivated by Pyro's research goals and application focus, and interest from the community. ## Citing NumPyro diff --git a/docker/README.md b/docker/README.md index da3cd7ff6..844b3c879 100644 --- a/docker/README.md +++ b/docker/README.md @@ -34,5 +34,5 @@ Design Choices: Future Work: - Right now the jax, jaxlib, and numpyro versions are manually specified, so they have to be updated every NumPyro release. There are two ways forward for this: - 1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemereal (not stored in source code). + 1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemeral (not stored in source code). 2. Alternative, one can create a Python script that will modify the Dockerfiles upon release accordingly (using a hook of some sort). diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index f568d458d..125c11ca4 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -662,6 +662,14 @@ VonMises Truncated Distributions ----------------------- +DoublyTruncatedPowerLaw +^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.distributions.truncated.DoublyTruncatedPowerLaw + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + LeftTruncatedDistribution ^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.LeftTruncatedDistribution @@ -670,6 +678,14 @@ LeftTruncatedDistribution :show-inheritance: :member-order: bysource +LowerTruncatedPowerLaw +^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.distributions.truncated.LowerTruncatedPowerLaw + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + RightTruncatedDistribution ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.RightTruncatedDistribution diff --git a/examples/annotation.py b/examples/annotation.py index 119409d30..ceb6a81fd 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -9,7 +9,7 @@ All models have discrete latent variables. Under the hood, we enumerate over (marginalize out) those discrete latent sites in inference. Those models have different -complexity so they are great refererences for those who are new to Pyro/NumPyro +complexity so they are great references for those who are new to Pyro/NumPyro enumeration mechanism. We recommend readers compare the implementations with the corresponding plate diagrams in [1] to see how concise a Pyro/NumPyro program is. diff --git a/examples/ar2.py b/examples/ar2.py index a73844c56..83da74a89 100644 --- a/examples/ar2.py +++ b/examples/ar2.py @@ -103,7 +103,7 @@ def run_inference(model, args, rng_key, y): def main(args): - # generate artifical dataset + # generate artificial dataset num_data = args.num_data rng_key = jax.random.PRNGKey(0) t = jnp.arange(0, num_data) diff --git a/examples/holt_winters.py b/examples/holt_winters.py index 946aec23e..c1f88d3a1 100644 --- a/examples/holt_winters.py +++ b/examples/holt_winters.py @@ -147,7 +147,7 @@ def predict(model, args, samples, rng_key, y, n_seasons): def main(args): - # generate artifical dataset + # generate artificial dataset rng_key, _ = random.split(random.PRNGKey(0)) T = args.T t = jnp.linspace(0, T + args.future, (T + args.future) * N_POINTS_PER_UNIT) diff --git a/examples/mortality.py b/examples/mortality.py index 406ece20e..f1c8bf500 100644 --- a/examples/mortality.py +++ b/examples/mortality.py @@ -57,7 +57,7 @@ dimensions of the age, space and time variables. This allows us to efficiently broadcast arrays in the likelihood. -As written above, the model includes a lot of centred random effects. The NUTS alogrithm benefits +As written above, the model includes a lot of centred random effects. The NUTS algorithm benefits from a non-centred reparamatrisation to overcome difficult posterior geometries [2]. Rather than manually writing out the non-centred parametrisation, we make use of the NumPyro's automatic reparametrisation in :class:`~numpyro.infer.reparam.LocScaleReparam`. diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 412dd1976..cb076a5c0 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -95,7 +95,9 @@ from numpyro.distributions.mixtures import Mixture, MixtureGeneral, MixtureSameFamily from numpyro.distributions.transforms import biject_to from numpyro.distributions.truncated import ( + DoublyTruncatedPowerLaw, LeftTruncatedDistribution, + LowerTruncatedPowerLaw, RightTruncatedDistribution, TruncatedCauchy, TruncatedDistribution, @@ -122,6 +124,7 @@ "Binomial", "BinomialLogits", "BinomialProbs", + "CAR", "Categorical", "CategoricalLogits", "CategoricalProbs", @@ -132,9 +135,10 @@ "DirichletMultinomial", "DiscreteUniform", "Distribution", + "DoublyTruncatedPowerLaw", "EulerMaruyama", - "Exponential", "ExpandedDistribution", + "Exponential", "FoldedDistribution", "Gamma", "GammaPoisson", @@ -152,29 +156,29 @@ "Independent", "InverseGamma", "Kumaraswamy", - "LKJ", - "LKJCholesky", "Laplace", "LeftTruncatedDistribution", + "LKJ", + "LKJCholesky", "Logistic", "LogNormal", "LogUniform", - "MatrixNormal", + "LowerTruncatedPowerLaw", + "LowRankMultivariateNormal", "MaskedDistribution", + "MatrixNormal", "Mixture", - "MixtureSameFamily", "MixtureGeneral", + "MixtureSameFamily", "Multinomial", "MultinomialLogits", "MultinomialProbs", "MultivariateNormal", - "CAR", "MultivariateStudentT", - "LowRankMultivariateNormal", - "Normal", - "NegativeBinomialProbs", - "NegativeBinomialLogits", "NegativeBinomial2", + "NegativeBinomialLogits", + "NegativeBinomialProbs", + "Normal", "OrderedLogistic", "Pareto", "Poisson", @@ -199,7 +203,7 @@ "Wishart", "WishartCholesky", "ZeroInflatedDistribution", - "ZeroInflatedPoisson", "ZeroInflatedNegativeBinomial2", + "ZeroInflatedPoisson", "ZeroSumNormal", ] diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 129889434..4ff685174 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -983,7 +983,7 @@ class LKJCholesky(Distribution): r""" LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by ``concentration`` parameter :math:`\eta` to make the probability of the - correlation matrix :math:`M` generated from a Cholesky factor propotional to + correlation matrix :math:`M` generated from a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that, when ``concentration == 1``, we have a uniform distribution over Cholesky factors of correlation matrices. @@ -1048,7 +1048,7 @@ def __init__( # We construct base distributions to generate samples for each method. # The purpose of this base distribution is to generate a distribution for - # correlation matrices which is propotional to `det(M)^{\eta - 1}`. + # correlation matrices which is proportional to `det(M)^{\eta - 1}`. # (note that this is not a unique way to define base distribution) # Both of the following methods have marginal distribution of each off-diagonal # element of sampled correlation matrices is Beta(eta + (D-2) / 2, eta + (D-2) / 2) @@ -1150,12 +1150,12 @@ def log_prob(self, value): # Generally, for a D dimensional matrix, we have: # Jacobian = L22^(D-2) * L33^(D-3) * ... * Ldd^0 # - # From [1], we know that probability of a correlation matrix is propotional to + # From [1], we know that probability of a correlation matrix is proportional to # determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1)) # On the other hand, Jabobian of the transformation from Cholesky factor to # correlation matrix is: # prod(L_ii ^ (D - i)) - # So the probability of a Cholesky factor is propotional to + # So the probability of a Cholesky factor is proportional to # prod(L_ii ^ (2 * concentration - 2 + D - i)) =: prod(L_ii ^ order_i) # with order_i = 2 * concentration - 2 + D - i, # i = 2..D (we omit the element i = 1 because L_11 = 1) @@ -1286,7 +1286,7 @@ def entropy(self): def _batch_solve_triangular(A, B): """ Extende solve_triangular for the case that B.ndim > A.ndim. - This is achived by first flattening the leading B.ndim - A.ndim dimensions of B and then + This is achieved by first flattening the leading B.ndim - A.ndim dimensions of B and then moving the first dimension to the end. @@ -1720,7 +1720,7 @@ def log_prob(self, value): D_rsqrt[..., None, :] * D_rsqrt[..., None] ) - # TODO: look into sparse eignvalue methods + # TODO: look into sparse eigenvalue methods if isinstance(adj_matrix_scaled, np.ndarray): lam = np.linalg.eigvalsh(adj_matrix_scaled) else: diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index cabf67c87..ae08eee91 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -519,10 +519,10 @@ def infer_shapes(cls, *args, **kwargs): def cdf(self, value): """ - The cummulative distribution function of this distribution. + The cumulative distribution function of this distribution. :param value: samples from this distribution. - :return: output of the cummulative distribution function evaluated at `value`. + :return: output of the cumulative distribution function evaluated at `value`. """ raise NotImplementedError diff --git a/numpyro/distributions/gof.py b/numpyro/distributions/gof.py index 7ef80b987..b1f917303 100644 --- a/numpyro/distributions/gof.py +++ b/numpyro/distributions/gof.py @@ -158,7 +158,7 @@ def unif01_goodness_of_fit(samples, *, plot=False): def exp_goodness_of_fit(samples, plot=False): """ - Transform exponentially distribued samples to Uniform(0,1) distribution and + Transform exponentially distributed samples to Uniform(0,1) distribution and assess goodness of fit via binned Pearson's chi^2 test. :param numpy.ndarray samples: A vector of real-valued samples from a @@ -353,7 +353,7 @@ def _chi2sf(x, s): F(x; s) = \frac{ \gamma( x/2, s/2 ) }{ \Gamma(s/2) }, with :math:`\gamma` is the incomplete gamma function defined above. - Therefore, the survival probability is givne by: + Therefore, the survival probability is given by: .. math:: 1 - \frac{ \gamma( x/2, s/2 ) }{ \Gamma(s/2) }. diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index ec20d7b53..290e504c1 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -405,7 +405,7 @@ def _matrix_forward_shape(shape, offset=0): N = shape[-1] D = round((0.25 + 2 * N) ** 0.5 - 0.5) if D * (D + 1) // 2 != N: - raise ValueError("Input is not a flattend lower-diagonal number") + raise ValueError("Input is not a flattened lower-diagonal number") D = D - offset return shape[:-1] + (D, D) @@ -447,7 +447,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): class CorrCholeskyTransform(ParameterFreeTransform): r""" - Transforms a uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the + Transforms a unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows: @@ -655,7 +655,7 @@ def __eq__(self, other): class L1BallTransform(ParameterFreeTransform): r""" - Transforms a uncontrained real vector :math:`x` into the unit L1 ball. + Transforms a unconstrained real vector :math:`x` into the unit L1 ball. """ domain = constraints.real_vector diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 236e2a000..fecdff2f7 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -365,3 +365,536 @@ def log_prob(self, value): sum_even = jnp.exp(logsumexp(even_terms, axis=-1)) sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1)) return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi) + + +class DoublyTruncatedPowerLaw(Distribution): + r"""Power law distribution with :math:`\alpha` index, and lower and upper bounds. + We can define the power law distribution as, + + .. math:: + f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)}, + + where, :math:`a` and :math:`b` are the lower and upper bounds respectively, + and :math:`Z(\alpha, a, b)` is the normalization constant. It is defined as, + + .. math:: + Z(\alpha, a, b) = \begin{cases} + \log(b) - \log(a) & \text{if } \alpha = -1, \\ + \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}. + \end{cases} + + :param alpha: index of the power law distribution + :param low: lower bound of the distribution + :param high: upper bound of the distribution + """ + + arg_constraints = { + "alpha": constraints.real, + "low": constraints.greater_than_eq(0), + "high": constraints.greater_than(0), + } + reparametrized_params = ["alpha", "low", "high"] + pytree_aux_fields = ("_support",) + pytree_data_fields = ("alpha", "low", "high") + + def __init__(self, alpha, low, high, *, validate_args=None): + self.alpha, self.low, self.high = promote_shapes(alpha, low, high) + self._support = constraints.interval(low, high) + batch_shape = lax.broadcast_shapes( + jnp.shape(alpha), jnp.shape(low), jnp.shape(high) + ) + super(DoublyTruncatedPowerLaw, self).__init__( + batch_shape=batch_shape, validate_args=validate_args + ) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + @validate_sample + def log_prob(self, value): + r"""Logarithmic probability distribution: + Z inequal minus one: + .. math:: + (x^\alpha) (\alpha + 1)/(b^(\alpha + 1) - a^(\alpha + 1)) + + Z equal minus one: + .. math:: + (x^\alpha)/(log(b) - log(a)) + Derivations are calculated by Wolfram Alpha via the Jacobian matrix accordingly. + """ + + @jax.custom_jvp + def f(x, alpha, low, high): + neq_neg1_mask = jnp.not_equal(alpha, -1.0) + neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) + # eq_neg1_alpha = jnp.where(~neq_neg1_mask, alpha, -1.0) + + def neq_neg1_fn(): + one_more_alpha = 1.0 + neq_neg1_alpha + return jnp.log( + jnp.power(x, neq_neg1_alpha) + * (one_more_alpha) + / (jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha)) + ) + + def eq_neg1_fn(): + return -jnp.log(x) - jnp.log(jnp.log(high) - jnp.log(low)) + + return jnp.where(neq_neg1_mask, neq_neg1_fn(), eq_neg1_fn()) + + @f.defjvp + def f_jvp(primals, tangents): + x, alpha, low, high = primals + x_t, alpha_t, low_t, high_t = tangents + + log_low = jnp.log(low) + log_high = jnp.log(high) + log_x = jnp.log(x) + + # Mask and alpha values + delta_eq_neg1 = 10e-4 + neq_neg1_mask = jnp.not_equal(alpha, -1.0) + neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) + eq_neg1_alpha = jnp.where(jnp.not_equal(alpha, 0.0), alpha, -1.0) + + primal_out = f(*primals) + + # Alpha tangent with approximation + # Variable part for all values alpha unequal -1 + def alpha_tangent_variable(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + return jnp.reciprocal(one_more_alpha) + ( + low_pow_one_more_alpha * log_low + - high_pow_one_more_alpha * log_high + ) / (high_pow_one_more_alpha - low_pow_one_more_alpha) + + # Alpha tangent + alpha_tangent = jnp.where( + neq_neg1_mask, + log_x + alpha_tangent_variable(neq_neg1_alpha), + # Approximate derivate with right an lefthand approximation + log_x + + ( + alpha_tangent_variable(alpha - delta_eq_neg1) + + alpha_tangent_variable(alpha + delta_eq_neg1) + ) + * 0.5, + ) + + # High and low tangents for alpha unequal -1 + one_more_alpha = 1.0 + neq_neg1_alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + change_sq = jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha) + low_tangent_neq_neg1_common = ( + jnp.square(one_more_alpha) * jnp.power(x, neq_neg1_alpha) / change_sq + ) + low_tangent_neq_neg1 = low_tangent_neq_neg1_common * jnp.power( + low, neq_neg1_alpha + ) + high_tangent_neq_neg1 = low_tangent_neq_neg1_common * jnp.power( + high, neq_neg1_alpha + ) + + # High and low tangents for alpha equal -1 + low_tangent_eq_neg1_common = jnp.power(x, eq_neg1_alpha) / jnp.square( + log_high - log_low + ) + low_tangent_eq_neg1 = low_tangent_eq_neg1_common / low + high_tangent_eq_neg1 = -low_tangent_eq_neg1_common / high + + # High and low tangents + low_tangent = jnp.where( + neq_neg1_mask, low_tangent_neq_neg1, low_tangent_eq_neg1 + ) + high_tangent = jnp.where( + neq_neg1_mask, high_tangent_neq_neg1, high_tangent_eq_neg1 + ) + + # Final tangents + tangent_out = ( + alpha / x * x_t + + alpha_tangent * alpha_t + + low_tangent * low_t + + high_tangent * high_t + ) + return primal_out, tangent_out + + return f(value, self.alpha, self.low, self.high) + + def cdf(self, value): + r"""Cumulated probability distribution: + Z inequal minus one: + + .. math:: + + \frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{\alpha + 1}} + + Z equal minus one: + + .. math:: + + \frac{\log(x) - \log(a)}{\log(b) - \log(a)} + + Derivations are calculated by Wolfram Alpha via the Jacobian matrix accordingly. + """ + + @jax.custom_jvp + def f(x, alpha, low, high): + neq_neg1_mask = jnp.not_equal(alpha, -1.0) + neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) + + def cdf_when_alpha_neq_neg1(): + one_more_alpha = 1.0 + neq_neg1_alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + return (jnp.power(x, one_more_alpha) - low_pow_one_more_alpha) / ( + jnp.power(high, one_more_alpha) - low_pow_one_more_alpha + ) + + def cdf_when_alpha_eq_neg1(): + return jnp.log(x / low) / jnp.log(high / low) + + cdf_val = jnp.where( + neq_neg1_mask, + cdf_when_alpha_neq_neg1(), + cdf_when_alpha_eq_neg1(), + ) + return jnp.clip(cdf_val, a_min=0.0, a_max=1.0) + + @f.defjvp + def f_jvp(primals, tangents): + x, alpha, low, high = primals + x_t, alpha_t, low_t, high_t = tangents + + log_low = jnp.log(low) + log_high = jnp.log(high) + log_x = jnp.log(x) + + delta_eq_neg1 = 10e-4 + neq_neg1_mask = jnp.not_equal(alpha, -1.0) + neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) + + # Calculate primal + primal_out = f(*primals) + + # Tangents for alpha not equals -1 + def x_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + return (one_more_alpha * jnp.power(x, alpha)) / ( + jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha) + ) + + def alpha_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + x_pow_one_more_alpha = jnp.power(x, one_more_alpha) + term1 = ( + x_pow_one_more_alpha * log_x - low_pow_one_more_alpha * log_low + ) / (high_pow_one_more_alpha - low_pow_one_more_alpha) + term2 = ( + (x_pow_one_more_alpha - low_pow_one_more_alpha) + * ( + high_pow_one_more_alpha * log_high + - low_pow_one_more_alpha * log_low + ) + ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha) + return term1 - term2 + + def low_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + x_pow_one_more_alpha = jnp.power(x, one_more_alpha) + change = high_pow_one_more_alpha - low_pow_one_more_alpha + term2 = one_more_alpha * jnp.power(low, alpha) / change + term1 = term2 * (x_pow_one_more_alpha - low_pow_one_more_alpha) / change + return term1 - term2 + + def high_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + x_pow_one_more_alpha = jnp.power(x, one_more_alpha) + return -( + one_more_alpha + * jnp.power(high, alpha) + * (x_pow_one_more_alpha - low_pow_one_more_alpha) + ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha) + + # Tangents for alpha equals -1 + def x_eq_neg1(): + return jnp.reciprocal(x * (log_high - log_low)) + + def low_eq_neg1(): + return (log_x - log_low) / ( + jnp.square(log_high - log_low) * low + ) - jnp.reciprocal((log_high - log_low) * low) + + def high_eq_neg1(): + return (log_x - log_low) / (jnp.square(log_high - log_low) * high) + + # Including approximation for alpha = -1 + tangent_out = ( + jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), x_eq_neg1()) * x_t + + jnp.where( + neq_neg1_mask, + alpha_neq_neg1(neq_neg1_alpha), + ( + alpha_neq_neg1(alpha - delta_eq_neg1) + + alpha_neq_neg1(alpha + delta_eq_neg1) + ) + * 0.5, + ) + * alpha_t + + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1()) + * low_t + + jnp.where( + neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1() + ) + * high_t + ) + + return primal_out, tangent_out + + return f(value, self.alpha, self.low, self.high) + + def icdf(self, q): + r"""Inverse cumulated probability distribution: + Z inequal minus one: + + .. math:: + a \left(\frac{b}{a}\right)^{q} + + Z equal minus one: + + .. math:: + \left(a^{1 + \alpha} + q (b^{1 + \alpha} - a^{1 + \alpha})\right)^{\frac{1}{1 + \alpha}} + + Derivations are calculated by Wolfram Alpha via the Jacobian matrix accordingly. + """ + + @jax.custom_jvp + def f(q, alpha, low, high): + neq_neg1_mask = jnp.not_equal(alpha, -1.0) + neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) + + def icdf_alpha_neq_neg1(): + one_more_alpha = 1.0 + neq_neg1_alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + return jnp.power( + low_pow_one_more_alpha + + q * (high_pow_one_more_alpha - low_pow_one_more_alpha), + jnp.reciprocal(one_more_alpha), + ) + + def icdf_alpha_eq_neg1(): + return jnp.power(high / low, q) * low + + icdf_val = jnp.where( + neq_neg1_mask, + icdf_alpha_neq_neg1(), + icdf_alpha_eq_neg1(), + ) + return icdf_val + + @f.defjvp + def f_jvp(primals, tangents): + x, alpha, low, high = primals + x_t, alpha_t, low_t, high_t = tangents + + log_low = jnp.log(low) + log_high = jnp.log(high) + high_over_low = jnp.divide(high, low) + + delta_eq_neg1 = 10e-4 + neq_neg1_mask = jnp.not_equal(alpha, -1.0) + neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) + + primal_out = f(*primals) + + # Tangents for alpha not equal -1 + def x_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + change = high_pow_one_more_alpha - low_pow_one_more_alpha + return ( + change + * jnp.power( + low_pow_one_more_alpha + x * change, + jnp.reciprocal(one_more_alpha) - 1, + ) + ) / one_more_alpha + + def alpha_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + factor0 = low_pow_one_more_alpha + x * ( + high_pow_one_more_alpha - low_pow_one_more_alpha + ) + term1 = jnp.power(factor0, jnp.reciprocal(one_more_alpha)) + term2 = ( + low_pow_one_more_alpha * log_low + + x + * ( + high_pow_one_more_alpha * log_high + - low_pow_one_more_alpha * log_low + ) + ) / (one_more_alpha * factor0) + term3 = jnp.log(factor0) / jnp.square(one_more_alpha) + return term1 * (term2 - term3) + + def low_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + return ( + (1.0 - x) + * jnp.power(low, alpha) + * jnp.power( + low_pow_one_more_alpha + + x * (high_pow_one_more_alpha - low_pow_one_more_alpha), + jnp.reciprocal(one_more_alpha) - 1, + ) + ) + + def high_neq_neg1(alpha): + one_more_alpha = 1.0 + alpha + low_pow_one_more_alpha = jnp.power(low, one_more_alpha) + high_pow_one_more_alpha = jnp.power(high, one_more_alpha) + return ( + x + * jnp.power(high, alpha) + * jnp.power( + low_pow_one_more_alpha + + x * (high_pow_one_more_alpha - low_pow_one_more_alpha), + jnp.reciprocal(one_more_alpha) - 1, + ) + ) + + # Tangents for alpha equals -1 + def dx_eq_neg1(): + return low * jnp.power(high_over_low, x) * (log_high - log_low) + + def low_eq_neg1(): + return ( + jnp.power(high_over_low, x) + - (high * x * jnp.power(high_over_low, x - 1)) / low + ) + + def high_eq_neg1(): + return x * jnp.power(high_over_low, x - 1) + + # Including approximation for alpha = -1 \ + tangent_out = ( + jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), dx_eq_neg1()) * x_t + + jnp.where( + neq_neg1_mask, + alpha_neq_neg1(neq_neg1_alpha), + ( + alpha_neq_neg1(alpha - delta_eq_neg1) + + alpha_neq_neg1(alpha + delta_eq_neg1) + ) + * 0.5, + ) + * alpha_t + + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1()) + * low_t + + jnp.where( + neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1() + ) + * high_t + ) + + return primal_out, tangent_out + + return f(q, self.alpha, self.low, self.high) + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + samples = self.icdf(u) + return samples + + +class LowerTruncatedPowerLaw(Distribution): + r"""Lower truncated power law distribution with :math:`\alpha` index. + We can define the power law distribution as, + + .. math:: + f(x; \alpha, a) = (-\alpha-1)a^{-\alpha - 1}x^{-\alpha}, + \qquad x \geq a, \qquad \alpha < -1, + + where, :math:`a` is the lower bound. The cdf of the distribution is given by, + + .. math:: + F(x; \alpha, a) = 1 - \left(\frac{x}{a}\right)^{1+\alpha}. + + The k-th moment of the distribution is given by, + + .. math:: + E[X^k] = \begin{cases} + \frac{-\alpha-1}{-\alpha-1-k}a^k & \text{if } k < -\alpha-1, \\ + \infty & \text{otherwise}. + \end{cases} + + :param alpha: index of the power law distribution + :param low: lower bound of the distribution + """ + + arg_constraints = { + "alpha": constraints.less_than(-1.0), + "low": constraints.greater_than(0.0), + } + reparametrized_params = ["alpha", "low"] + pytree_aux_fields = ("_support",) + + def __init__(self, alpha, low, *, validate_args=None): + self.alpha, self.low = promote_shapes(alpha, low) + batch_shape = lax.broadcast_shapes(jnp.shape(alpha), jnp.shape(low)) + self._support = constraints.greater_than(low) + super(LowerTruncatedPowerLaw, self).__init__( + batch_shape=batch_shape, validate_args=validate_args + ) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + @validate_sample + def log_prob(self, value): + one_more_alpha = 1.0 + self.alpha + return ( + self.alpha * jnp.log(value) + + jnp.log(-one_more_alpha) + - one_more_alpha * jnp.log(self.low) + ) + + def cdf(self, value): + cdf_val = jnp.where( + jnp.less_equal(value, self.low), + jnp.zeros_like(value), + 1.0 - jnp.power(value / self.low, 1.0 + self.alpha), + ) + return cdf_val + + def icdf(self, q): + nan_mask = jnp.logical_or(jnp.isnan(q), jnp.less(q, 0.0)) + nan_mask = jnp.logical_or(nan_mask, jnp.greater(q, 1.0)) + return jnp.where( + nan_mask, + jnp.nan, + self.low * jnp.power(1.0 - q, jnp.reciprocal(1.0 + self.alpha)), + ) + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + samples = self.icdf(u) + return samples diff --git a/numpyro/infer/barker.py b/numpyro/infer/barker.py index 9d5fa0f2b..52babdf11 100644 --- a/numpyro/infer/barker.py +++ b/numpyro/infer/barker.py @@ -87,7 +87,7 @@ class BarkerMH(MCMCKernel): :param bool dense_mass: Whether to use a dense (i.e. full-rank) or diagonal mass matrix. (defaults to ``dense_mass=False``). :param float target_accept_prob: The target acceptance probability that is used to guide - step size adapation. Defaults to ``target_accept_prob=0.4``. + step size adaptation. Defaults to ``target_accept_prob=0.4``. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index fd387fb8d..c0aaa3118 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -93,7 +93,7 @@ def loss_with_mutable_state( the course of fitting). :param kwargs: keyword arguments to the model / guide (these can possibly vary during the course of fitting). - :return: dictionay containing ELBO loss and the mutable state + :return: dictionary containing ELBO loss and the mutable state """ raise NotImplementedError("This ELBO objective does not support mutable state.") diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 033b2c7b0..64a2870ad 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -248,7 +248,7 @@ class AIES(EnsembleSampler): Defaults to False. :param moves: a dictionary mapping moves to their respective probabilities of being selected. Valid keys are `AIES.DEMove()` and `AIES.StretchMove()`. Both tend to work well in practice. - If the sum of probabilites exceeds 1, the probabilities will be normalized. Defaults to `{AIES.DEMove(): 1.0}`. + If the sum of probabilities exceeds 1, the probabilities will be normalized. Defaults to `{AIES.DEMove(): 1.0}`. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. @@ -455,7 +455,7 @@ class ESS(EnsembleSampler): :param bool randomize_split: whether or not to permute the chain order at each iteration. Defaults to True. :param moves: a dictionary mapping moves to their respective probabilities of being selected. - If the sum of probabilites exceeds 1, the probabilities will be normalized. Valid keys include: + If the sum of probabilities exceeds 1, the probabilities will be normalized. Valid keys include: `ESS.DifferentialMove()` -> default proposal, works well along a wide range of target distributions, `ESS.GaussianMove()` -> for approximately normally distributed targets, `ESS.KDEMove()` -> for multimodal posteriors - requires large `num_chains`, and they must be well initialized diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index 6b8d5d058..7b999ede5 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -229,7 +229,7 @@ def model_3(): for d, upstreams in prior_dependencies.items(): for u, p in upstreams.items(): if u not in observed: - # Note the folowing reverses: + # Note the following reverses: # u is henceforth downstream and d is henceforth upstream. posterior_dependencies[u][d] = p.copy() @@ -547,7 +547,7 @@ def render_graph(graph_specification, render_distributions=False): shape = "plain" rv_label = rv.replace( "$params", "" - ) # incase of neural network parameters + ) # in case of neural network parameters # use different symbol for Deterministic site node_style = ( diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index e18b35bc0..81efb1cbf 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -769,7 +769,7 @@ def print_summary(self, prob=0.9, exclude_deterministic=True): def transfer_states_to_host(self): """ - Reduce the memory footprint of collected samples by transfering them to the host device. + Reduce the memory footprint of collected samples by transferring them to the host device. """ self._states = device_get(self._states) self._states_flat = device_get(self._get_states_flat()) diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index 3e3d2ae59..2960e0216 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -104,7 +104,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): find_reasonable_step_size=None, ) - # In HMC, when `hmc_state.r` is not None, we will skip drawing a random momemtum at the + # In HMC, when `hmc_state.r` is not None, we will skip drawing a random momentum at the # beginning of an HMC step. The reason is we need to maintain `r` between each sub-trajectories. r = momentum_generator( state.hmc_state.z, state.hmc_state.adapt_state.mass_matrix_sqrt, rng_r diff --git a/numpyro/optim.py b/numpyro/optim.py index 0abc90ee3..eab3f594a 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -222,7 +222,7 @@ def __init__(self, *args, **kwargs): # TODO: currently, jax.scipy.optimize.minimize only supports 1D input, # so we need to add the following mechanism to transform params to flat_params -# and pass `unravel_fn` arround. +# and pass `unravel_fn` around. # When arbitrary pytree is supported in JAX, we can just simply use # identity functions for `init_fn` and `get_params`. _MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"]) diff --git a/test/contrib/test_infer_discrete.py b/test/contrib/test_infer_discrete.py index 69aa3f773..c37f4efae 100644 --- a/test/contrib/test_infer_discrete.py +++ b/test/contrib/test_infer_discrete.py @@ -408,7 +408,7 @@ def test_mcmc_model_side_enumeration(model, temperature): k: v[0] for k, v in mcmc.get_samples().items() if k in ["loc", "scale"] } - # MAP estimate discretes, conditioned on posterior sampled continous latents. + # MAP estimate discrete, conditioned on posterior sampled continuous latents. model = handlers.seed(model, rng_seed=1) actual_trace = handlers.trace( infer_discrete( diff --git a/test/test_distributions.py b/test/test_distributions.py index f074eae0d..b0c379290 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -914,6 +914,11 @@ def get_sp_dist(jax_dist): ] ), # Covariance ), + T(dist.LowerTruncatedPowerLaw, -np.pi, np.array([2.0, 5.0, 10.0, 50.0])), + T(dist.DoublyTruncatedPowerLaw, -1.0, 1.0, 2.0), + T(dist.DoublyTruncatedPowerLaw, np.pi, 5.0, 50.0), + T(dist.DoublyTruncatedPowerLaw, -1.0, 5.0, 50.0), + T(dist.DoublyTruncatedPowerLaw, np.pi, 1.0, 2.0), ] DIRECTIONAL = [ @@ -1063,6 +1068,8 @@ def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)): return random.bernoulli(key, shape=size) elif isinstance(constraint, constraints.greater_than): return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps + elif isinstance(constraint, constraints.less_than): + return constraint.upper_bound - jnp.exp(random.normal(key, size)) - eps elif isinstance(constraint, constraints.integer_interval): lower_bound = jnp.broadcast_to(constraint.lower_bound, size) upper_bound = jnp.broadcast_to(constraint.upper_bound, size) @@ -1129,6 +1136,8 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): return random.bernoulli(key, shape=size) - 2 elif isinstance(constraint, constraints.greater_than): return constraint.lower_bound - jnp.exp(random.normal(key, size)) + elif isinstance(constraint, constraints.less_than): + return constraint.upper_bound + jnp.exp(random.normal(key, size)) elif isinstance(constraint, constraints.integer_interval): lower_bound = jnp.broadcast_to(constraint.lower_bound, size) return random.randint(key, size, lower_bound - 1, lower_bound) @@ -1324,6 +1333,12 @@ def test_sample_gradient(jax_dist, sp_dist, params): "StudentT": ["df"], }.get(jax_dist.__name__, []) + if ( + jax_dist in [dist.DoublyTruncatedPowerLaw] + and jnp.result_type(float) == jnp.float32 + ): + pytest.skip("DoublyTruncatedPowerLaw is tested with x64 only.") + dist_args = [ p for p in ( @@ -1739,14 +1754,7 @@ def test_zero_inflated_logits_probs_agree(): gate_probs = expit(gate_logits) zi_logits = dist.ZeroInflatedDistribution(d, gate_logits=gate_logits) zi_probs = dist.ZeroInflatedDistribution(d, gate=gate_probs) - sample = np.random.randint( - 0, - 20, - ( - 1000, - 100, - ), - ) + sample = np.random.randint(0, 20, (1000, 100)) assert_allclose(zi_probs.log_prob(sample), zi_logits.log_prob(sample)) @@ -1830,6 +1838,11 @@ def test_log_prob_gradient(jax_dist, sp_dist, params): pytest.skip("we have separated tests for LKJCholesky distribution") if jax_dist is _ImproperWrapper: pytest.skip("no param for ImproperUniform to test for log_prob gradient") + if ( + jax_dist in [dist.DoublyTruncatedPowerLaw] + and jnp.result_type(float) == jnp.float32 + ): + pytest.skip("DoublyTruncatedPowerLaw is tested with x64 only.") rng_key = random.PRNGKey(0) value = jax_dist(*params).sample(rng_key) @@ -1852,6 +1865,8 @@ def fn(*args): params[i], dist.Distribution ): # skip taking grad w.r.t. base_dist continue + if jax_dist is dist.DoublyTruncatedPowerLaw and i != 0: + continue if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64): continue actual_grad = jax.grad(fn, i)(*params) @@ -1893,6 +1908,10 @@ def test_mean_var(jax_dist, sp_dist, params): pytest.skip("Truncated distributions do not has mean/var implemented") if jax_dist is dist.ProjectedNormal: pytest.skip("Mean is defined in submanifold") + if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: + pytest.skip( + f"{jax_dist.__name__} distribution does not has mean/var implemented" + ) n = ( 20000 @@ -2053,6 +2072,7 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): _General2DMixture, ): pytest.skip(f"{jax_dist.__name__} is a function, not a class") + dist_args = [p for p in inspect.getfullargspec(jax_dist.__init__)[0][1:]] valid_params, oob_params = list(params), list(params) @@ -3288,7 +3308,7 @@ def sample(d: dist.Distribution): def test_vmap_validate_args(): - # Test for #1684: vmapping distributions whould work when `validate_args=True` + # Test for #1684: vmapping distributions would work when `validate_args=True` v_dist = jax.vmap( lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True), in_axes=(0, 0), @@ -3371,7 +3391,7 @@ def _assert_not_jax_issue_19885( ) -> None: # jit-ing identity plus matrix multiplication leads to performance degradation as # discussed in https://github.com/google/jax/issues/19885. This assertion verifies - # that the issue does not affect perforance in numpyro. + # that the issue does not affect performance in numpyro. for jit in [True, False]: result = jax.jit(func)(*args, **kwargs) block_until_ready = getattr(result, "block_until_ready", None)