Skip to content

Commit

Permalink
gh-1806: Implementation of Doubly Truncated Power Law and Lower Trunc…
Browse files Browse the repository at this point in the history
…ated Power Law (#1807)

* implementation of DoublyTruncatedPowerLaw

* implementation of LowerTruncatedPowerLaw

* chore: mathematical description in docstrings

* chore: mathematical details of `LowerTruncatedPowerLaw`

* chore: Fix bug in DoublyTruncatedPowerLaw cdf and icdf calculation

* chore: Refactor mean and variance calculation by using kth-moment in DoublyTruncatedPowerLaw

* chore: Refactor mean and variance calculation in LowerTruncatedPowerLaw

* chore: masking in icdf of LowerTruncatedPowerLaw

* chore: entropy of LowerTruncatedPowerLaw

* chore: `lax.sqaure` replaced with `jnp.sqaure`

* chore: moments and entropy were extra and removed

* chore: unit tests

* fix: nan gradients fixed, values still diverging

* Updated UpperTruncatedPowerLaw with adequate derivations, including fixing the discontinuity point for alpha equals minus one and correct tangents for lower and upper bounds.

* Changed constrains of alpha of LowerTruncatedPowerLaw to the smaller minus one and also changed equation to keep equational uniformity UpperTruncatedPowerLaw to and improved stability of the calculation by integrating formula parts inside log to prevent NaN values due to negative values of some parts that could balance out in the end.

* chore: code and docstring formated

* chore: equation refactor and simplified

* chore: equation refactor and simplified

* chore: use numpy arrays and numpy constants

* chore: high precision computation enable for powerlaws

* chore: `__name__` attribute calls removed

* chore: powerlaws shifted with truncated distributions

* chore: spelling mistakes fixed with code spell checker pre-commit hook

* fix typo: perforance->perforamce->performance

* chore: explicit enabling/disabling of 64bit floating point numbers

* chore: disable everytime and enable x64 for power laws

* chore: disable x64 for every test

* chore: linked explanation in comments for disabling x64 for future reference for devs

* chore: high precision test handeled efficiently for DoublyTruncatedPowerLaw

* chore: high precision exception handled in test_log_prob_gradient

---------

Co-authored-by: David Ziegler <[email protected]>
  • Loading branch information
Qazalbash and InfinityMod authored Sep 17, 2024
1 parent ca8fb39 commit 7d50393
Show file tree
Hide file tree
Showing 24 changed files with 672 additions and 85 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ jobs:
- name: Test with pytest
run: |
CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
- name: Test x64
run: |
JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k powerLaw
test-inference:
Expand Down
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ repos:
- id: check-yaml
- id: check-added-large-files
exclude: notebooks/*

- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
args:
[--ignore-words-list, "Teh,aas", --check-filenames, --skip, "*.ipynb"]
75 changes: 39 additions & 36 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ Design Choices:
Future Work:

- Right now the jax, jaxlib, and numpyro versions are manually specified, so they have to be updated every NumPyro release. There are two ways forward for this:
1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemereal (not stored in source code).
1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemeral (not stored in source code).
2. Alternative, one can create a Python script that will modify the Dockerfiles upon release accordingly (using a hook of some sort).
16 changes: 16 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,14 @@ VonMises
Truncated Distributions
-----------------------

DoublyTruncatedPowerLaw
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.DoublyTruncatedPowerLaw
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

LeftTruncatedDistribution
^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.LeftTruncatedDistribution
Expand All @@ -670,6 +678,14 @@ LeftTruncatedDistribution
:show-inheritance:
:member-order: bysource

LowerTruncatedPowerLaw
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.LowerTruncatedPowerLaw
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

RightTruncatedDistribution
^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.RightTruncatedDistribution
Expand Down
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
All models have discrete latent variables. Under the hood, we enumerate over
(marginalize out) those discrete latent sites in inference. Those models have different
complexity so they are great refererences for those who are new to Pyro/NumPyro
complexity so they are great references for those who are new to Pyro/NumPyro
enumeration mechanism. We recommend readers compare the implementations with the
corresponding plate diagrams in [1] to see how concise a Pyro/NumPyro program is.
Expand Down
2 changes: 1 addition & 1 deletion examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def run_inference(model, args, rng_key, y):


def main(args):
# generate artifical dataset
# generate artificial dataset
num_data = args.num_data
rng_key = jax.random.PRNGKey(0)
t = jnp.arange(0, num_data)
Expand Down
2 changes: 1 addition & 1 deletion examples/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def predict(model, args, samples, rng_key, y, n_seasons):


def main(args):
# generate artifical dataset
# generate artificial dataset
rng_key, _ = random.split(random.PRNGKey(0))
T = args.T
t = jnp.linspace(0, T + args.future, (T + args.future) * N_POINTS_PER_UNIT)
Expand Down
2 changes: 1 addition & 1 deletion examples/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
dimensions of the age, space and time variables. This allows us to efficiently broadcast arrays
in the likelihood.
As written above, the model includes a lot of centred random effects. The NUTS alogrithm benefits
As written above, the model includes a lot of centred random effects. The NUTS algorithm benefits
from a non-centred reparamatrisation to overcome difficult posterior geometries [2]. Rather than
manually writing out the non-centred parametrisation, we make use of the NumPyro's automatic
reparametrisation in :class:`~numpyro.infer.reparam.LocScaleReparam`.
Expand Down
26 changes: 15 additions & 11 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@
from numpyro.distributions.mixtures import Mixture, MixtureGeneral, MixtureSameFamily
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.truncated import (
DoublyTruncatedPowerLaw,
LeftTruncatedDistribution,
LowerTruncatedPowerLaw,
RightTruncatedDistribution,
TruncatedCauchy,
TruncatedDistribution,
Expand All @@ -122,6 +124,7 @@
"Binomial",
"BinomialLogits",
"BinomialProbs",
"CAR",
"Categorical",
"CategoricalLogits",
"CategoricalProbs",
Expand All @@ -132,9 +135,10 @@
"DirichletMultinomial",
"DiscreteUniform",
"Distribution",
"DoublyTruncatedPowerLaw",
"EulerMaruyama",
"Exponential",
"ExpandedDistribution",
"Exponential",
"FoldedDistribution",
"Gamma",
"GammaPoisson",
Expand All @@ -152,29 +156,29 @@
"Independent",
"InverseGamma",
"Kumaraswamy",
"LKJ",
"LKJCholesky",
"Laplace",
"LeftTruncatedDistribution",
"LKJ",
"LKJCholesky",
"Logistic",
"LogNormal",
"LogUniform",
"MatrixNormal",
"LowerTruncatedPowerLaw",
"LowRankMultivariateNormal",
"MaskedDistribution",
"MatrixNormal",
"Mixture",
"MixtureSameFamily",
"MixtureGeneral",
"MixtureSameFamily",
"Multinomial",
"MultinomialLogits",
"MultinomialProbs",
"MultivariateNormal",
"CAR",
"MultivariateStudentT",
"LowRankMultivariateNormal",
"Normal",
"NegativeBinomialProbs",
"NegativeBinomialLogits",
"NegativeBinomial2",
"NegativeBinomialLogits",
"NegativeBinomialProbs",
"Normal",
"OrderedLogistic",
"Pareto",
"Poisson",
Expand All @@ -199,7 +203,7 @@
"Wishart",
"WishartCholesky",
"ZeroInflatedDistribution",
"ZeroInflatedPoisson",
"ZeroInflatedNegativeBinomial2",
"ZeroInflatedPoisson",
"ZeroSumNormal",
]
12 changes: 6 additions & 6 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ class LKJCholesky(Distribution):
r"""
LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is
controlled by ``concentration`` parameter :math:`\eta` to make the probability of the
correlation matrix :math:`M` generated from a Cholesky factor propotional to
correlation matrix :math:`M` generated from a Cholesky factor proportional to
:math:`\det(M)^{\eta - 1}`. Because of that, when ``concentration == 1``, we have a
uniform distribution over Cholesky factors of correlation matrices.
Expand Down Expand Up @@ -1048,7 +1048,7 @@ def __init__(

# We construct base distributions to generate samples for each method.
# The purpose of this base distribution is to generate a distribution for
# correlation matrices which is propotional to `det(M)^{\eta - 1}`.
# correlation matrices which is proportional to `det(M)^{\eta - 1}`.
# (note that this is not a unique way to define base distribution)
# Both of the following methods have marginal distribution of each off-diagonal
# element of sampled correlation matrices is Beta(eta + (D-2) / 2, eta + (D-2) / 2)
Expand Down Expand Up @@ -1150,12 +1150,12 @@ def log_prob(self, value):
# Generally, for a D dimensional matrix, we have:
# Jacobian = L22^(D-2) * L33^(D-3) * ... * Ldd^0
#
# From [1], we know that probability of a correlation matrix is propotional to
# From [1], we know that probability of a correlation matrix is proportional to
# determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
# On the other hand, Jabobian of the transformation from Cholesky factor to
# correlation matrix is:
# prod(L_ii ^ (D - i))
# So the probability of a Cholesky factor is propotional to
# So the probability of a Cholesky factor is proportional to
# prod(L_ii ^ (2 * concentration - 2 + D - i)) =: prod(L_ii ^ order_i)
# with order_i = 2 * concentration - 2 + D - i,
# i = 2..D (we omit the element i = 1 because L_11 = 1)
Expand Down Expand Up @@ -1286,7 +1286,7 @@ def entropy(self):
def _batch_solve_triangular(A, B):
"""
Extende solve_triangular for the case that B.ndim > A.ndim.
This is achived by first flattening the leading B.ndim - A.ndim dimensions of B and then
This is achieved by first flattening the leading B.ndim - A.ndim dimensions of B and then
moving the first dimension to the end.
Expand Down Expand Up @@ -1720,7 +1720,7 @@ def log_prob(self, value):
D_rsqrt[..., None, :] * D_rsqrt[..., None]
)

# TODO: look into sparse eignvalue methods
# TODO: look into sparse eigenvalue methods
if isinstance(adj_matrix_scaled, np.ndarray):
lam = np.linalg.eigvalsh(adj_matrix_scaled)
else:
Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,10 @@ def infer_shapes(cls, *args, **kwargs):

def cdf(self, value):
"""
The cummulative distribution function of this distribution.
The cumulative distribution function of this distribution.
:param value: samples from this distribution.
:return: output of the cummulative distribution function evaluated at `value`.
:return: output of the cumulative distribution function evaluated at `value`.
"""
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/gof.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def unif01_goodness_of_fit(samples, *, plot=False):

def exp_goodness_of_fit(samples, plot=False):
"""
Transform exponentially distribued samples to Uniform(0,1) distribution and
Transform exponentially distributed samples to Uniform(0,1) distribution and
assess goodness of fit via binned Pearson's chi^2 test.
:param numpy.ndarray samples: A vector of real-valued samples from a
Expand Down Expand Up @@ -353,7 +353,7 @@ def _chi2sf(x, s):
F(x; s) = \frac{ \gamma( x/2, s/2 ) }{ \Gamma(s/2) },
with :math:`\gamma` is the incomplete gamma function defined above.
Therefore, the survival probability is givne by:
Therefore, the survival probability is given by:
.. math::
1 - \frac{ \gamma( x/2, s/2 ) }{ \Gamma(s/2) }.
Expand Down
6 changes: 3 additions & 3 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _matrix_forward_shape(shape, offset=0):
N = shape[-1]
D = round((0.25 + 2 * N) ** 0.5 - 0.5)
if D * (D + 1) // 2 != N:
raise ValueError("Input is not a flattend lower-diagonal number")
raise ValueError("Input is not a flattened lower-diagonal number")
D = D - offset
return shape[:-1] + (D, D)

Expand Down Expand Up @@ -447,7 +447,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):

class CorrCholeskyTransform(ParameterFreeTransform):
r"""
Transforms a uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
Transforms a unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
triangular matrix with positive diagonals and unit Euclidean norm for each row.
The transform is processed as follows:
Expand Down Expand Up @@ -655,7 +655,7 @@ def __eq__(self, other):

class L1BallTransform(ParameterFreeTransform):
r"""
Transforms a uncontrained real vector :math:`x` into the unit L1 ball.
Transforms a unconstrained real vector :math:`x` into the unit L1 ball.
"""

domain = constraints.real_vector
Expand Down
Loading

0 comments on commit 7d50393

Please sign in to comment.