From f44752171d92dd9f5675feae0f5e6e64adaf9431 Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 3 Jun 2021 10:19:26 +0200 Subject: [PATCH 01/23] Added BvM. --- examples/gp.py | 10 +- numpyro/__init__.py | 2 +- numpyro/compat/pyro.py | 4 +- numpyro/contrib/control_flow/scan.py | 18 +- numpyro/contrib/funsor/__init__.py | 16 +- numpyro/contrib/funsor/discrete.py | 3 +- numpyro/contrib/funsor/enum_messenger.py | 6 +- numpyro/contrib/funsor/infer_util.py | 8 +- numpyro/contrib/tfp/distributions.py | 6 +- numpyro/distributions/__init__.py | 10 +- numpyro/distributions/conjugate.py | 7 +- numpyro/distributions/continuous.py | 9 +- numpyro/distributions/directional.py | 236 +++++++++++++++++- numpyro/distributions/discrete.py | 2 +- numpyro/distributions/distribution.py | 7 +- numpyro/distributions/kl.py | 2 +- numpyro/distributions/transforms.py | 7 +- numpyro/distributions/truncated.py | 16 +- numpyro/infer/__init__.py | 2 +- numpyro/infer/autoguide.py | 13 +- numpyro/infer/hmc.py | 2 +- numpyro/infer/hmc_gibbs.py | 12 +- .../contrib/einstein/test_einstein_kernels.py | 2 +- test/contrib/test_funsor.py | 3 +- test/contrib/test_module.py | 2 +- test/infer/test_autoguide.py | 2 +- test/infer/test_hmc_util.py | 2 +- test/infer/test_infer_util.py | 4 +- test/infer/test_reparam.py | 7 +- test/test_diagnostics.py | 2 +- test/test_distributions.py | 4 +- test/test_distributions_util.py | 2 +- test/test_example_utils.py | 9 +- test/test_flows.py | 5 +- test/test_pickle.py | 11 +- 35 files changed, 290 insertions(+), 163 deletions(-) diff --git a/examples/gp.py b/examples/gp.py index 11b7a4d3a..0b3f7b7a1 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -27,15 +27,7 @@ import numpyro import numpyro.distributions as dist -from numpyro.infer import ( - MCMC, - NUTS, - init_to_feasible, - init_to_median, - init_to_sample, - init_to_uniform, - init_to_value, -) +from numpyro.infer import MCMC, NUTS, init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value matplotlib.use("Agg") # noqa: E402 diff --git a/numpyro/__init__.py b/numpyro/__init__.py index 4fc884250..8ff9bb4ef 100644 --- a/numpyro/__init__.py +++ b/numpyro/__init__.py @@ -17,7 +17,7 @@ plate_stack, prng_key, sample, - subsample, + subsample ) from numpyro.util import enable_x64, set_host_device_count, set_platform from numpyro.version import __version__ diff --git a/numpyro/compat/pyro.py b/numpyro/compat/pyro.py index 47f805596..b317a8d06 100644 --- a/numpyro/compat/pyro.py +++ b/numpyro/compat/pyro.py @@ -4,7 +4,9 @@ import warnings from numpyro.compat.util import UnsupportedAPIWarning -from numpyro.primitives import module, param as _param, plate, sample # noqa: F401 +from numpyro.primitives import module +from numpyro.primitives import param as _param # noqa: F401 +from numpyro.primitives import plate, sample _PARAM_STORE = {} diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 57165b834..fac5368b5 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -4,15 +4,7 @@ from collections import OrderedDict from functools import partial -from jax import ( - device_put, - lax, - random, - tree_flatten, - tree_map, - tree_multimap, - tree_unflatten, -) +from jax import device_put, lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten import jax.numpy as jnp from numpyro import handlers @@ -115,12 +107,8 @@ def scan_enum( history=1, first_available_dim=None, ): - from numpyro.contrib.funsor import ( - config_enumerate, - enum, - markov, - trace as packed_trace, - ) + from numpyro.contrib.funsor import config_enumerate, enum, markov + from numpyro.contrib.funsor import trace as packed_trace # amount number of steps to unroll history = min(history, length) diff --git a/numpyro/contrib/funsor/__init__.py b/numpyro/contrib/funsor/__init__.py index 53027ae12..999d6d222 100644 --- a/numpyro/contrib/funsor/__init__.py +++ b/numpyro/contrib/funsor/__init__.py @@ -12,20 +12,8 @@ ) from e from numpyro.contrib.funsor.discrete import infer_discrete -from numpyro.contrib.funsor.enum_messenger import ( - enum, - infer_config, - markov, - plate, - to_data, - to_funsor, - trace, -) -from numpyro.contrib.funsor.infer_util import ( - config_enumerate, - log_density, - plate_to_enum_plate, -) +from numpyro.contrib.funsor.enum_messenger import enum, infer_config, markov, plate, to_data, to_funsor, trace +from numpyro.contrib.funsor.infer_util import config_enumerate, log_density, plate_to_enum_plate funsor.set_backend("jax") diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index 59ed4513a..a6b706812 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -7,7 +7,8 @@ from jax import random import funsor -from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace +from numpyro.contrib.funsor.enum_messenger import enum +from numpyro.contrib.funsor.enum_messenger import trace as packed_trace from numpyro.contrib.funsor.infer_util import plate_to_enum_plate from numpyro.distributions.util import is_identically_one from numpyro.handlers import block, replay, seed, trace diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index c84f25b05..64d27717f 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -9,8 +9,10 @@ import jax.numpy as jnp import funsor -from numpyro.handlers import infer_config, trace as OrigTraceMessenger -from numpyro.primitives import Messenger, apply_stack, plate as OrigPlateMessenger +from numpyro.handlers import infer_config +from numpyro.handlers import trace as OrigTraceMessenger +from numpyro.primitives import Messenger, apply_stack +from numpyro.primitives import plate as OrigPlateMessenger funsor.set_backend("jax") diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 09d94b88f..786169def 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -8,11 +8,9 @@ import funsor import numpyro -from numpyro.contrib.funsor.enum_messenger import ( - infer_config, - plate as enum_plate, - trace as packed_trace, -) +from numpyro.contrib.funsor.enum_messenger import infer_config +from numpyro.contrib.funsor.enum_messenger import plate as enum_plate +from numpyro.contrib.funsor.enum_messenger import trace as packed_trace from numpyro.distributions.util import is_identically_one from numpyro.handlers import substitute diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 30b24f532..00f300217 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -4,10 +4,12 @@ import numpy as np import jax.numpy as jnp -from tensorflow_probability.substrates.jax import bijectors as tfb, distributions as tfd +from tensorflow_probability.substrates.jax import bijectors as tfb +from tensorflow_probability.substrates.jax import distributions as tfd import numpyro.distributions as numpyro_dist -from numpyro.distributions import Distribution as NumPyroDistribution, constraints +from numpyro.distributions import Distribution as NumPyroDistribution +from numpyro.distributions import constraints from numpyro.distributions.transforms import Transform, biject_to from numpyro.util import not_jax_tracer diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 65e9d7bb9..67871c206 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -8,7 +8,7 @@ NegativeBinomial2, NegativeBinomialLogits, NegativeBinomialProbs, - ZeroInflatedNegativeBinomial2, + ZeroInflatedNegativeBinomial2 ) from numpyro.distributions.continuous import ( LKJ, @@ -35,7 +35,7 @@ SoftLaplace, StudentT, Uniform, - Weibull, + Weibull ) from numpyro.distributions.directional import ProjectedNormal, VonMises from numpyro.distributions.discrete import ( @@ -58,7 +58,7 @@ Poisson, PRNGIdentity, ZeroInflatedDistribution, - ZeroInflatedPoisson, + ZeroInflatedPoisson ) from numpyro.distributions.distribution import ( Delta, @@ -69,7 +69,7 @@ Independent, MaskedDistribution, TransformedDistribution, - Unit, + Unit ) from numpyro.distributions.kl import kl_divergence from numpyro.distributions.transforms import biject_to @@ -80,7 +80,7 @@ TruncatedDistribution, TruncatedNormal, TruncatedPolyaGamma, - TwoSidedTruncatedDistribution, + TwoSidedTruncatedDistribution ) from . import constraints, transforms diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index 58f6351c8..5110d58c2 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -7,12 +7,7 @@ from numpyro.distributions import constraints from numpyro.distributions.continuous import Beta, Dirichlet, Gamma -from numpyro.distributions.discrete import ( - BinomialProbs, - MultinomialProbs, - Poisson, - ZeroInflatedDistribution, -) +from numpyro.distributions.discrete import BinomialProbs, MultinomialProbs, Poisson, ZeroInflatedDistribution from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 95797595e..b32ea1621 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -35,12 +35,7 @@ from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution, TransformedDistribution -from numpyro.distributions.transforms import ( - AffineTransform, - CorrMatrixCholeskyTransform, - ExpTransform, - PowerTransform, -) +from numpyro.distributions.transforms import AffineTransform, CorrMatrixCholeskyTransform, ExpTransform, PowerTransform from numpyro.distributions.util import ( cholesky_of_inverse, is_prng_key, @@ -49,7 +44,7 @@ promote_shapes, signed_stick_breaking_tril, validate_sample, - vec_to_tril_matrix, + vec_to_tril_matrix ) EULER_MASCHERONI = 0.5772156649015328606065120900824024310421 diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 50e60cdb8..f5906eee4 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -1,12 +1,17 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import functools import math +import operator +import warnings +from collections import namedtuple +from math import pi -from jax import lax import jax.numpy as jnp import jax.random as random -from jax.scipy.special import erf, i0e, i1e +from jax import lax +from jax.scipy.special import erf, i0e, i1e, logsumexp from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution @@ -16,7 +21,56 @@ safe_normalize, validate_sample, von_mises_centered, + lazy_property ) +from numpyro.util import while_loop + + +def _numel(shape): + return functools.reduce(operator.mul, shape, 1) + + +def log_I1(orders: int, value, terms=250): + r""" Compute first n log modified bessel function of first kind + .. math :: + \log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk) + - \lgamma(v + k + 1)\right]) + :param orders: orders of the log modified bessel function. + :param value: values to compute modified bessel function for + :param terms: truncation of summation + :return: 0 to orders modified bessel function + """ + orders = orders + 1 + if value.ndim == 0: + vshape = jnp.shape([1]) + else: + vshape = value.shape + value = value.reshape(-1, 1) + flat_vshape = _numel(vshape) + + k = jnp.arange(terms) + lgammas_all = lax.lgamma(jnp.arange(1., terms + orders + 1)) + assert lgammas_all.shape == (orders + terms,) # lgamma(0) = inf => start from 1 + + lvalues = lax.log(value / 2) * k.reshape(1, -1) + assert lvalues.shape == (flat_vshape, terms) + + lfactorials = lgammas_all[:terms] + assert lfactorials.shape == (terms,) + + lgammas = lgammas_all.tile(orders).reshape((orders, -1)) + assert lgammas.shape == (orders, terms + orders) # lgamma(0) = inf => start from 1 + + indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1) + assert indices.shape == (orders, terms) + + seqs = logsumexp(2 * lvalues[None, :, :] - lfactorials[None, None, :] + - jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :], -1) + assert seqs.shape == (orders, flat_vshape) + + i1s = lvalues[..., :orders].T + seqs + assert i1s.shape == (orders, flat_vshape) + return i1s.reshape(-1, *vshape) class VonMises(Distribution): @@ -56,9 +110,8 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): - return -( - jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration)) - ) + self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1) + return -(jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration))) + \ + self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1) @property def mean(self): @@ -75,6 +128,179 @@ def variance(self): ) +PhiMarginalState = namedtuple("PhiMarginalState", ['i', 'done', 'phi', 'key']) + + +class Sine(Distribution): + r""" Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by + .. math:: + C^{-1}\exp(\kappa_1\cos(x-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) + and + .. math:: + C = (2\pi)^2 \sum_{i=0} {2i \choose i} + \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2), + where I_i(\cdot) is the modified bessel function of first kind, mu's are the locations of the distribution, + kappa's are the concentration and rho gives the correlation between angles x_1 and x_2. + This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains. + To infer parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that + avoid parameterizations where the distribution becomes bimodal; see note below. + .. note:: Sample efficiency drops as + .. math:: + \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1 + because the distribution becomes increasingly bimodal. + .. note:: The correlation and weighted_correlation params are mutually exclusive. + .. note:: In the context of :class:`~pyro.infer.SVI`, this distribution can be used as a likelihood but not for + latent variables. + ** References: ** + 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) + :param jnp.Tensor phi_loc: location of first angle + :param jnp.Tensor psi_loc: location of second angle + :param jnp.Tensor phi_concentration: concentration of first angle + :param jnp.Tensor psi_concentration: concentration of second angle + :param jnp.Tensor correlation: correlation between the two angles + :param jnp.Tensor weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) + to avoid bimodality (see note). + """ + + arg_constraints = {'phi_loc': constraints.real, 'psi_loc': constraints.real, + 'phi_concentration': constraints.positive, 'psi_concentration': constraints.positive, + 'correlation': constraints.real} + support = constraints.independent(constraints.real, 1) + max_sample_iter = 10_000 + + def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration, correlation=None, + weighted_correlation=None, validate_args=None): + + assert (correlation is None) != (weighted_correlation is None) + + if weighted_correlation is not None: + correlation = weighted_correlation * jnp.sqrt(phi_concentration * psi_concentration) + 1e-8 + + self.phi_loc, self.psi_loc, self.phi_concentration, self.psi_concentration, self.correlation = promote_shapes( + phi_loc, psi_loc, + phi_concentration, + psi_concentration, + correlation) + batch_shape = lax.broadcast_shapes(phi_loc.shape, psi_loc.shape, phi_concentration.shape, + psi_concentration.shape, correlation.shape) + super().__init__(batch_shape, (2,), validate_args) + + if self._validate_args and jnp.any(phi_concentration * psi_concentration <= correlation ** 2): + warnings.warn( + f'{self.__class__.__name__} bimodal due to concentration-correlation relation, ' + f'sampling will likely fail.', UserWarning) + + @lazy_property + def norm_const(self): + corr = self.correlation.reshape(1, -1) + 1e-8 + conc = jnp.stack((self.phi_concentration, self.psi_concentration), axis=-1).reshape(-1, 2) + m = jnp.arange(50).reshape(-1, 1) + num = lax.lgamma(2 * m + 1.) + den = lax.lgamma(m + 1.) + lbinoms = num - 2 * den + + fs = lbinoms.reshape(-1, 1) + 2 * m * jnp.log(corr) - m * jnp.log(4 * jnp.prod(conc, axis=-1)) + fs += log_I1(49, conc, terms=51).sum(-1) + mfs = fs.max() + norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(fs - mfs, 0) + return norm_const.reshape(self.phi_loc.shape) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + indv = self.phi_concentration * jnp.cos(value[..., 0] - self.phi_loc) + self.psi_concentration * jnp.cos( + value[..., 1] - self.psi_loc) + corr = self.correlation * jnp.sin(value[..., 0] - self.phi_loc) * jnp.sin(value[..., 1] - self.psi_loc) + return indv + corr - self.norm_const + + def sample(self, key, sample_shape=()): + """ + ** References: ** + 1. A New Unified Approach for the Simulation of aWide Class of Directional Distributions + John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018) + """ + phi_key, psi_key = random.split(key) + + corr = self.correlation + conc = jnp.stack((self.phi_concentration, self.psi_concentration)) + + eig = 0.5 * (conc[0] - corr ** 2 / conc[1]) + eig = jnp.stack((jnp.zeros_like(eig), eig)) + eigmin = jnp.where(eig[1] < 0, eig[1], jnp.zeros_like(eig[1], dtype=eig.dtype)) + eig = eig - eigmin + b0 = self._bfind(eig) + + total = _numel(sample_shape) + phi_den = log_I1(0, conc[1]).squeeze(0) + phi_shape = (total, 2, _numel(self.batch_shape)) + phi_state = Sine._phi_marginal(phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den) + + if not jnp.all(phi_state.done): + raise ValueError("maximum number of iterations exceeded; " + "try increasing `SineBivariateVonMises.max_sample_iter`") + + phi = lax.atan2(phi_state.phi[:, :1], phi_state.phi[:, 1:]) + + alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) + beta = lax.atan(corr / conc[1] * jnp.sin(phi)) + + psi = VonMises(beta, alpha).sample(psi_key) + + phi_psi = jnp.concatenate(((phi + self.phi_loc + pi) % (2 * pi) - pi, + (psi + self.psi_loc + pi) % (2 * pi) - pi), axis=1) + phi_psi = jnp.transpose(phi_psi, (0, 2, 1)) + return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape) + + @staticmethod + def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den): + conc = jnp.broadcast_to(conc, shape) + eig = jnp.broadcast_to(eig, shape) + b0 = jnp.broadcast_to(b0, shape) + eigmin = jnp.broadcast_to(eigmin, shape) + phi_den = jnp.broadcast_to(phi_den, shape) + + def update_fn(curr): + i, done, phi, key = curr + phi_key, key = random.split(key) + accept_key, acg_key, phi_key = random.split(phi_key, 3) + + x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) + + x /= jnp.linalg.norm(x, axis=1)[:, None, :] # Angular Central Gaussian distribution + + lf = conc[:, :1] * (x[:, :1] - 1) + eigmin + log_I1(0, jnp.sqrt( + conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2)).squeeze(0) - phi_den + assert lf.shape == shape + + lg_inv = 1. - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True)) + assert lg_inv.shape == lf.shape + + accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv) + + phi = jnp.where(accepted, x, phi) + return PhiMarginalState(i + 1, done | accepted, phi, key) + + def cond_fn(curr): + return jnp.bitwise_and(curr.i < Sine.max_sample_iter, jnp.logical_not(jnp.all(curr.done))) + + phi_state = while_loop(cond_fn, update_fn, + PhiMarginalState(i=jnp.array(0), + done=jnp.zeros(shape, dtype=bool), + phi=jnp.empty(shape, dtype=float), + key=rng_key)) + return PhiMarginalState(phi_state.i, phi_state.done, phi_state.phi, phi_state.key) + + @property + def mean(self): + return jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + + def _bfind(self, eig): + b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype) + g1 = jnp.sum(1 / (b + 2 * eig) ** 2, axis=0) + g2 = jnp.sum(-2 / (b + 2 * eig) ** 3, axis=0) + return jnp.where(jnp.linalg.norm(eig, axis=0) != 0, b - g1 / g2, b) + + class ProjectedNormal(Distribution): """ Projected isotropic normal distribution of arbitrary dimension. diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 8596f71b9..3bbe95dd9 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -48,7 +48,7 @@ lazy_property, multinomial, promote_shapes, - validate_sample, + validate_sample ) from numpyro.util import not_jax_tracer diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index f65751d55..2e5ff9393 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -38,12 +38,7 @@ from jax.scipy.special import logsumexp from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform -from numpyro.distributions.util import ( - lazy_property, - promote_shapes, - sum_rightmost, - validate_sample, -) +from numpyro.distributions.util import lazy_property, promote_shapes, sum_rightmost, validate_sample from numpyro.util import not_jax_tracer from . import constraints diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py index c97bbeb63..091f60446 100644 --- a/numpyro/distributions/kl.py +++ b/numpyro/distributions/kl.py @@ -37,7 +37,7 @@ Distribution, ExpandedDistribution, Independent, - MaskedDistribution, + MaskedDistribution ) from numpyro.distributions.util import scale_and_mask, sum_rightmost diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 636e15b2d..6c22e75f5 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -16,12 +16,7 @@ from jax.tree_util import tree_flatten, tree_map from numpyro.distributions import constraints -from numpyro.distributions.util import ( - matrix_to_tril_vec, - signed_stick_breaking_tril, - sum_rightmost, - vec_to_tril_matrix, -) +from numpyro.distributions.util import matrix_to_tril_vec, signed_stick_breaking_tril, sum_rightmost, vec_to_tril_matrix from numpyro.util import not_jax_tracer __all__ = [ diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 1717b155a..3d02a0aee 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -8,21 +8,9 @@ from jax.tree_util import tree_map from numpyro.distributions import constraints -from numpyro.distributions.continuous import ( - Cauchy, - Laplace, - Logistic, - Normal, - SoftLaplace, - StudentT, -) +from numpyro.distributions.continuous import Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT from numpyro.distributions.distribution import Distribution -from numpyro.distributions.util import ( - is_prng_key, - lazy_property, - promote_shapes, - validate_sample, -) +from numpyro.distributions.util import is_prng_key, lazy_property, promote_shapes, validate_sample class LeftTruncatedDistribution(Distribution): diff --git a/numpyro/infer/__init__.py b/numpyro/infer/__init__.py index 11fa0a553..34867c125 100644 --- a/numpyro/infer/__init__.py +++ b/numpyro/infer/__init__.py @@ -10,7 +10,7 @@ init_to_median, init_to_sample, init_to_uniform, - init_to_value, + init_to_value ) from numpyro.infer.mcmc import MCMC from numpyro.infer.mixed_hmc import MixedHMC diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index e7a8dc1cb..435fff2fe 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -17,10 +17,7 @@ from numpyro import handlers import numpyro.distributions as dist from numpyro.distributions import constraints -from numpyro.distributions.flows import ( - BlockNeuralAutoregressiveTransform, - InverseAutoregressiveTransform, -) +from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform, InverseAutoregressiveTransform from numpyro.distributions.transforms import ( AffineTransform, ComposeTransform, @@ -28,13 +25,9 @@ LowerCholeskyAffine, PermuteTransform, UnpackTransform, - biject_to, -) -from numpyro.distributions.util import ( - cholesky_of_inverse, - periodic_repeat, - sum_rightmost, + biject_to ) +from numpyro.distributions.util import cholesky_of_inverse, periodic_repeat, sum_rightmost from numpyro.infer.elbo import Trace_ELBO from numpyro.infer.initialization import init_to_median from numpyro.infer.util import init_to_uniform, initialize_model diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index aff4493f5..e5fd7de64 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -15,7 +15,7 @@ euclidean_kinetic_energy, find_reasonable_step_size, velocity_verlet, - warmup_adapter, + warmup_adapter ) from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 1c1fab8a0..7cf9c96c9 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -8,17 +8,7 @@ import numpy as np -from jax import ( - device_put, - grad, - hessian, - jacfwd, - jacobian, - lax, - ops, - random, - value_and_grad, -) +from jax import device_put, grad, hessian, jacfwd, jacobian, lax, ops, random, value_and_grad from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.special import expit diff --git a/test/contrib/einstein/test_einstein_kernels.py b/test/contrib/einstein/test_einstein_kernels.py index 556080c5a..30088af3e 100644 --- a/test/contrib/einstein/test_einstein_kernels.py +++ b/test/contrib/einstein/test_einstein_kernels.py @@ -16,7 +16,7 @@ MixtureKernel, PrecondMatrixKernel, RandomFeatureKernel, - RBFKernel, + RBFKernel ) jnp.set_printoptions(precision=100) diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py index 85520bafb..3d86c82f0 100644 --- a/test/contrib/test_funsor.py +++ b/test/contrib/test_funsor.py @@ -15,7 +15,8 @@ import numpyro from numpyro.contrib.control_flow import scan from numpyro.contrib.funsor import config_enumerate, enum, markov, to_data, to_funsor -from numpyro.contrib.funsor.enum_messenger import NamedMessenger, plate as enum_plate +from numpyro.contrib.funsor.enum_messenger import NamedMessenger +from numpyro.contrib.funsor.enum_messenger import plate as enum_plate from numpyro.contrib.funsor.infer_util import log_density from numpyro.contrib.indexing import Vindex import numpyro.distributions as dist diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 2e411f62d..e8e5a6c02 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -17,7 +17,7 @@ flax_module, haiku_module, random_flax_module, - random_haiku_module, + random_haiku_module ) import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index bb16cd166..8b6353317 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -27,7 +27,7 @@ AutoLaplaceApproximation, AutoLowRankMultivariateNormal, AutoMultivariateNormal, - AutoNormal, + AutoNormal ) from numpyro.infer.initialization import init_to_median from numpyro.infer.reparam import TransformReparam diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index 8849028d4..d33ffa08e 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -25,7 +25,7 @@ parametric_draws, velocity_verlet, warmup_adapter, - welford_covariance, + welford_covariance ) from numpyro.util import control_flow_prims_disabled, fori_loop, optional diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index e3d58161c..da067706f 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -21,7 +21,7 @@ init_to_median, init_to_sample, init_to_uniform, - init_to_value, + init_to_value ) from numpyro.infer.reparam import TransformReparam from numpyro.infer.util import ( @@ -30,7 +30,7 @@ initialize_model, log_likelihood, potential_energy, - transform_fn, + transform_fn ) import numpyro.optim as optim diff --git a/test/infer/test_reparam.py b/test/infer/test_reparam.py index 98e992a86..2ad44a245 100644 --- a/test/infer/test_reparam.py +++ b/test/infer/test_reparam.py @@ -14,12 +14,7 @@ import numpyro.handlers as handlers from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO from numpyro.infer.autoguide import AutoIAFNormal -from numpyro.infer.reparam import ( - LocScaleReparam, - NeuTraReparam, - ProjectedNormalReparam, - TransformReparam, -) +from numpyro.infer.reparam import LocScaleReparam, NeuTraReparam, ProjectedNormalReparam, TransformReparam from numpyro.infer.util import initialize_model from numpyro.optim import Adam diff --git a/test/test_diagnostics.py b/test/test_diagnostics.py index d0a32df87..f389598e4 100644 --- a/test/test_diagnostics.py +++ b/test/test_diagnostics.py @@ -13,7 +13,7 @@ effective_sample_size, gelman_rubin, hpdi, - split_gelman_rubin, + split_gelman_rubin ) diff --git a/test/test_distributions.py b/test/test_distributions.py index bbaa0c6ac..9e36e52b9 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -27,14 +27,14 @@ PermuteTransform, PowerTransform, SoftplusTransform, - biject_to, + biject_to ) from numpyro.distributions.util import ( matrix_to_tril_vec, multinomial, signed_stick_breaking_tril, sum_rightmost, - vec_to_tril_matrix, + vec_to_tril_matrix ) from numpyro.nn import AutoregressiveNN diff --git a/test/test_distributions_util.py b/test/test_distributions_util.py index 830725e5e..bcd51dc4f 100644 --- a/test/test_distributions_util.py +++ b/test/test_distributions_util.py @@ -20,7 +20,7 @@ multinomial, safe_normalize, vec_to_tril_matrix, - von_mises_centered, + von_mises_centered ) diff --git a/test/test_example_utils.py b/test/test_example_utils.py index 65221d2d7..c170a3cee 100644 --- a/test/test_example_utils.py +++ b/test/test_example_utils.py @@ -3,14 +3,7 @@ import jax.numpy as jnp -from numpyro.examples.datasets import ( - BASEBALL, - COVTYPE, - JSB_CHORALES, - MNIST, - SP500, - load_dataset, -) +from numpyro.examples.datasets import BASEBALL, COVTYPE, JSB_CHORALES, MNIST, SP500, load_dataset from numpyro.util import fori_loop diff --git a/test/test_flows.py b/test/test_flows.py index 769644fc3..21d16e4b7 100644 --- a/test/test_flows.py +++ b/test/test_flows.py @@ -10,10 +10,7 @@ from jax import jacfwd, random from jax.experimental import stax -from numpyro.distributions.flows import ( - BlockNeuralAutoregressiveTransform, - InverseAutoregressiveTransform, -) +from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform, InverseAutoregressiveTransform from numpyro.distributions.util import matrix_to_tril_vec from numpyro.nn import AutoregressiveNN, BlockNeuralAutoregressiveNN diff --git a/test/test_pickle.py b/test/test_pickle.py index bf0d79793..21d947a75 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -10,16 +10,7 @@ import numpyro import numpyro.distributions as dist -from numpyro.infer import ( - HMC, - HMCECS, - MCMC, - NUTS, - SA, - BarkerMH, - DiscreteHMCGibbs, - MixedHMC, -) +from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, BarkerMH, DiscreteHMCGibbs, MixedHMC def normal_model(): From 529c4c5735953170ed0fbc7af0524b3133b097f8 Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 3 Jun 2021 11:10:19 +0200 Subject: [PATCH 02/23] Added BvM tests. --- numpyro/distributions/__init__.py | 3 ++- numpyro/distributions/directional.py | 6 +++--- test/test_distributions.py | 8 ++++++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 67871c206..991c9cbb1 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -37,7 +37,7 @@ Uniform, Weibull ) -from numpyro.distributions.directional import ProjectedNormal, VonMises +from numpyro.distributions.directional import ProjectedNormal, VonMises, SineBivariateVonMises from numpyro.distributions.discrete import ( Bernoulli, BernoulliLogits, @@ -145,6 +145,7 @@ "ProjectedNormal", "PRNGIdentity", "RightTruncatedDistribution", + "SineBivariateVonMises", "SoftLaplace", "StudentT", "TransformedDistribution", diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index f5906eee4..b746eae4e 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -131,7 +131,7 @@ def variance(self): PhiMarginalState = namedtuple("PhiMarginalState", ['i', 'done', 'phi', 'key']) -class Sine(Distribution): +class SineBivariateVonMises(Distribution): r""" Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by .. math:: C^{-1}\exp(\kappa_1\cos(x-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) @@ -233,7 +233,7 @@ def sample(self, key, sample_shape=()): total = _numel(sample_shape) phi_den = log_I1(0, conc[1]).squeeze(0) phi_shape = (total, 2, _numel(self.batch_shape)) - phi_state = Sine._phi_marginal(phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den) + phi_state = SineBivariateVonMises._phi_marginal(phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den) if not jnp.all(phi_state.done): raise ValueError("maximum number of iterations exceeded; " @@ -281,7 +281,7 @@ def update_fn(curr): return PhiMarginalState(i + 1, done | accepted, phi, key) def cond_fn(curr): - return jnp.bitwise_and(curr.i < Sine.max_sample_iter, jnp.logical_not(jnp.all(curr.done))) + return jnp.bitwise_and(curr.i < SineBivariateVonMises.max_sample_iter, jnp.logical_not(jnp.all(curr.done))) phi_state = while_loop(cond_fn, update_fn, PhiMarginalState(i=jnp.array(0), diff --git a/test/test_distributions.py b/test/test_distributions.py index 9e36e52b9..5b52de11b 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -5,6 +5,7 @@ from functools import partial import inspect import os +import math import numpy as np from numpy.testing import assert_allclose, assert_array_equal @@ -288,6 +289,13 @@ def get_sp_dist(jax_dist): T(dist.Pareto, 1.0, 2.0), T(dist.Pareto, jnp.array([1.0, 0.5]), jnp.array([0.3, 2.0])), T(dist.Pareto, jnp.array([[1.0], [3.0]]), jnp.array([1.0, 0.5])), + T(dist.SineBivariateVonMises, jnp.array([0.]), jnp.array([0.]), jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), + T(dist.SineBivariateVonMises, jnp.array([3.003]), jnp.array([-1.343]), + jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), + T(dist.SineBivariateVonMises, jnp.array([-math.pi/3]), jnp.array(-1), + jnp.array(.4), jnp.array(10.), jnp.array(.9)), + T(dist.SineBivariateVonMises, jnp.array([math.pi - .2, 1.]), jnp.array([0.,1.]), + jnp.array([5., 5.]), jnp.array([7., .5]), None, jnp.array([.5, .1])), T(dist.SoftLaplace, 1.0, 1.0), T(dist.SoftLaplace, jnp.array([-1.0, 50.0]), jnp.array([4.0, 100.0])), T(dist.StudentT, 1.0, 1.0, 0.5), From ec1bd1e4eb753d5c8af2474757a334e5bcf33898 Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 3 Jun 2021 11:19:19 +0200 Subject: [PATCH 03/23] Added comment with tests that need to be fixed. --- test/test_distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 5b52de11b..0dbc92eb6 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -290,8 +290,8 @@ def get_sp_dist(jax_dist): T(dist.Pareto, jnp.array([1.0, 0.5]), jnp.array([0.3, 2.0])), T(dist.Pareto, jnp.array([[1.0], [3.0]]), jnp.array([1.0, 0.5])), T(dist.SineBivariateVonMises, jnp.array([0.]), jnp.array([0.]), jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), - T(dist.SineBivariateVonMises, jnp.array([3.003]), jnp.array([-1.343]), - jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), + T(dist.SineBivariateVonMises, jnp.array([3.003]), jnp.array([-1.343]), # check test_gof, test_mean_var, + jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), # check test_distribution_constraints T(dist.SineBivariateVonMises, jnp.array([-math.pi/3]), jnp.array(-1), jnp.array(.4), jnp.array(10.), jnp.array(.9)), T(dist.SineBivariateVonMises, jnp.array([math.pi - .2, 1.]), jnp.array([0.,1.]), From 985825110b0b8f461f8cb5efc9b48ae71cbfe5a8 Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 10 Jun 2021 21:24:13 +0200 Subject: [PATCH 04/23] Added tests. --- examples/gp.py | 10 +- numpyro/__init__.py | 2 +- numpyro/compat/pyro.py | 4 +- numpyro/contrib/control_flow/scan.py | 18 +- numpyro/contrib/funsor/__init__.py | 16 +- numpyro/contrib/funsor/discrete.py | 3 +- numpyro/contrib/funsor/enum_messenger.py | 6 +- numpyro/contrib/funsor/infer_util.py | 8 +- numpyro/contrib/tfp/distributions.py | 6 +- numpyro/distributions/__init__.py | 10 +- numpyro/distributions/conjugate.py | 7 +- numpyro/distributions/continuous.py | 9 +- numpyro/distributions/directional.py | 19 +- numpyro/distributions/discrete.py | 2 +- numpyro/distributions/distribution.py | 7 +- numpyro/distributions/kl.py | 2 +- numpyro/distributions/transforms.py | 7 +- numpyro/distributions/truncated.py | 16 +- numpyro/infer/__init__.py | 2 +- numpyro/infer/autoguide.py | 13 +- numpyro/infer/hmc.py | 2 +- numpyro/infer/hmc_gibbs.py | 12 +- .../contrib/einstein/test_einstein_kernels.py | 2 +- test/contrib/test_funsor.py | 3 +- test/contrib/test_module.py | 2 +- test/infer/test_autoguide.py | 2 +- test/infer/test_hmc_util.py | 2 +- test/infer/test_infer_util.py | 4 +- test/infer/test_reparam.py | 7 +- test/test_diagnostics.py | 2 +- test/test_distributions.py | 239 +++++++++--------- test/test_distributions_util.py | 2 +- test/test_example_utils.py | 9 +- test/test_flows.py | 5 +- test/test_pickle.py | 11 +- 35 files changed, 287 insertions(+), 184 deletions(-) diff --git a/examples/gp.py b/examples/gp.py index 0b3f7b7a1..11b7a4d3a 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -27,7 +27,15 @@ import numpyro import numpyro.distributions as dist -from numpyro.infer import MCMC, NUTS, init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value +from numpyro.infer import ( + MCMC, + NUTS, + init_to_feasible, + init_to_median, + init_to_sample, + init_to_uniform, + init_to_value, +) matplotlib.use("Agg") # noqa: E402 diff --git a/numpyro/__init__.py b/numpyro/__init__.py index 8ff9bb4ef..4fc884250 100644 --- a/numpyro/__init__.py +++ b/numpyro/__init__.py @@ -17,7 +17,7 @@ plate_stack, prng_key, sample, - subsample + subsample, ) from numpyro.util import enable_x64, set_host_device_count, set_platform from numpyro.version import __version__ diff --git a/numpyro/compat/pyro.py b/numpyro/compat/pyro.py index b317a8d06..47f805596 100644 --- a/numpyro/compat/pyro.py +++ b/numpyro/compat/pyro.py @@ -4,9 +4,7 @@ import warnings from numpyro.compat.util import UnsupportedAPIWarning -from numpyro.primitives import module -from numpyro.primitives import param as _param # noqa: F401 -from numpyro.primitives import plate, sample +from numpyro.primitives import module, param as _param, plate, sample # noqa: F401 _PARAM_STORE = {} diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index fac5368b5..57165b834 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -4,7 +4,15 @@ from collections import OrderedDict from functools import partial -from jax import device_put, lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten +from jax import ( + device_put, + lax, + random, + tree_flatten, + tree_map, + tree_multimap, + tree_unflatten, +) import jax.numpy as jnp from numpyro import handlers @@ -107,8 +115,12 @@ def scan_enum( history=1, first_available_dim=None, ): - from numpyro.contrib.funsor import config_enumerate, enum, markov - from numpyro.contrib.funsor import trace as packed_trace + from numpyro.contrib.funsor import ( + config_enumerate, + enum, + markov, + trace as packed_trace, + ) # amount number of steps to unroll history = min(history, length) diff --git a/numpyro/contrib/funsor/__init__.py b/numpyro/contrib/funsor/__init__.py index 999d6d222..53027ae12 100644 --- a/numpyro/contrib/funsor/__init__.py +++ b/numpyro/contrib/funsor/__init__.py @@ -12,8 +12,20 @@ ) from e from numpyro.contrib.funsor.discrete import infer_discrete -from numpyro.contrib.funsor.enum_messenger import enum, infer_config, markov, plate, to_data, to_funsor, trace -from numpyro.contrib.funsor.infer_util import config_enumerate, log_density, plate_to_enum_plate +from numpyro.contrib.funsor.enum_messenger import ( + enum, + infer_config, + markov, + plate, + to_data, + to_funsor, + trace, +) +from numpyro.contrib.funsor.infer_util import ( + config_enumerate, + log_density, + plate_to_enum_plate, +) funsor.set_backend("jax") diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index a6b706812..59ed4513a 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -7,8 +7,7 @@ from jax import random import funsor -from numpyro.contrib.funsor.enum_messenger import enum -from numpyro.contrib.funsor.enum_messenger import trace as packed_trace +from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace from numpyro.contrib.funsor.infer_util import plate_to_enum_plate from numpyro.distributions.util import is_identically_one from numpyro.handlers import block, replay, seed, trace diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index 64d27717f..c84f25b05 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -9,10 +9,8 @@ import jax.numpy as jnp import funsor -from numpyro.handlers import infer_config -from numpyro.handlers import trace as OrigTraceMessenger -from numpyro.primitives import Messenger, apply_stack -from numpyro.primitives import plate as OrigPlateMessenger +from numpyro.handlers import infer_config, trace as OrigTraceMessenger +from numpyro.primitives import Messenger, apply_stack, plate as OrigPlateMessenger funsor.set_backend("jax") diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 786169def..09d94b88f 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -8,9 +8,11 @@ import funsor import numpyro -from numpyro.contrib.funsor.enum_messenger import infer_config -from numpyro.contrib.funsor.enum_messenger import plate as enum_plate -from numpyro.contrib.funsor.enum_messenger import trace as packed_trace +from numpyro.contrib.funsor.enum_messenger import ( + infer_config, + plate as enum_plate, + trace as packed_trace, +) from numpyro.distributions.util import is_identically_one from numpyro.handlers import substitute diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 00f300217..30b24f532 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -4,12 +4,10 @@ import numpy as np import jax.numpy as jnp -from tensorflow_probability.substrates.jax import bijectors as tfb -from tensorflow_probability.substrates.jax import distributions as tfd +from tensorflow_probability.substrates.jax import bijectors as tfb, distributions as tfd import numpyro.distributions as numpyro_dist -from numpyro.distributions import Distribution as NumPyroDistribution -from numpyro.distributions import constraints +from numpyro.distributions import Distribution as NumPyroDistribution, constraints from numpyro.distributions.transforms import Transform, biject_to from numpyro.util import not_jax_tracer diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 67871c206..65e9d7bb9 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -8,7 +8,7 @@ NegativeBinomial2, NegativeBinomialLogits, NegativeBinomialProbs, - ZeroInflatedNegativeBinomial2 + ZeroInflatedNegativeBinomial2, ) from numpyro.distributions.continuous import ( LKJ, @@ -35,7 +35,7 @@ SoftLaplace, StudentT, Uniform, - Weibull + Weibull, ) from numpyro.distributions.directional import ProjectedNormal, VonMises from numpyro.distributions.discrete import ( @@ -58,7 +58,7 @@ Poisson, PRNGIdentity, ZeroInflatedDistribution, - ZeroInflatedPoisson + ZeroInflatedPoisson, ) from numpyro.distributions.distribution import ( Delta, @@ -69,7 +69,7 @@ Independent, MaskedDistribution, TransformedDistribution, - Unit + Unit, ) from numpyro.distributions.kl import kl_divergence from numpyro.distributions.transforms import biject_to @@ -80,7 +80,7 @@ TruncatedDistribution, TruncatedNormal, TruncatedPolyaGamma, - TwoSidedTruncatedDistribution + TwoSidedTruncatedDistribution, ) from . import constraints, transforms diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index 5110d58c2..58f6351c8 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -7,7 +7,12 @@ from numpyro.distributions import constraints from numpyro.distributions.continuous import Beta, Dirichlet, Gamma -from numpyro.distributions.discrete import BinomialProbs, MultinomialProbs, Poisson, ZeroInflatedDistribution +from numpyro.distributions.discrete import ( + BinomialProbs, + MultinomialProbs, + Poisson, + ZeroInflatedDistribution, +) from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index b32ea1621..95797595e 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -35,7 +35,12 @@ from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution, TransformedDistribution -from numpyro.distributions.transforms import AffineTransform, CorrMatrixCholeskyTransform, ExpTransform, PowerTransform +from numpyro.distributions.transforms import ( + AffineTransform, + CorrMatrixCholeskyTransform, + ExpTransform, + PowerTransform, +) from numpyro.distributions.util import ( cholesky_of_inverse, is_prng_key, @@ -44,7 +49,7 @@ promote_shapes, signed_stick_breaking_tril, validate_sample, - vec_to_tril_matrix + vec_to_tril_matrix, ) EULER_MASCHERONI = 0.5772156649015328606065120900824024310421 diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index f5906eee4..d3c7ace07 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -1,27 +1,27 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from collections import namedtuple import functools import math +from math import pi import operator import warnings -from collections import namedtuple -from math import pi +from jax import lax import jax.numpy as jnp import jax.random as random -from jax import lax from jax.scipy.special import erf, i0e, i1e, logsumexp from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import ( is_prng_key, + lazy_property, promote_shapes, safe_normalize, validate_sample, von_mises_centered, - lazy_property ) from numpyro.util import while_loop @@ -131,7 +131,7 @@ def variance(self): PhiMarginalState = namedtuple("PhiMarginalState", ['i', 'done', 'phi', 'key']) -class Sine(Distribution): +class SineBivariateVonMises(Distribution): r""" Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by .. math:: C^{-1}\exp(\kappa_1\cos(x-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) @@ -166,7 +166,7 @@ class Sine(Distribution): 'phi_concentration': constraints.positive, 'psi_concentration': constraints.positive, 'correlation': constraints.real} support = constraints.independent(constraints.real, 1) - max_sample_iter = 10_000 + max_sample_iter = 1000 def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration, correlation=None, weighted_correlation=None, validate_args=None): @@ -233,13 +233,13 @@ def sample(self, key, sample_shape=()): total = _numel(sample_shape) phi_den = log_I1(0, conc[1]).squeeze(0) phi_shape = (total, 2, _numel(self.batch_shape)) - phi_state = Sine._phi_marginal(phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den) + phi_state = SineBivariateVonMises._phi_marginal(phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den) if not jnp.all(phi_state.done): raise ValueError("maximum number of iterations exceeded; " "try increasing `SineBivariateVonMises.max_sample_iter`") - phi = lax.atan2(phi_state.phi[:, :1], phi_state.phi[:, 1:]) + phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) beta = lax.atan(corr / conc[1] * jnp.sin(phi)) @@ -265,7 +265,6 @@ def update_fn(curr): accept_key, acg_key, phi_key = random.split(phi_key, 3) x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) - x /= jnp.linalg.norm(x, axis=1)[:, None, :] # Angular Central Gaussian distribution lf = conc[:, :1] * (x[:, :1] - 1) + eigmin + log_I1(0, jnp.sqrt( @@ -281,7 +280,7 @@ def update_fn(curr): return PhiMarginalState(i + 1, done | accepted, phi, key) def cond_fn(curr): - return jnp.bitwise_and(curr.i < Sine.max_sample_iter, jnp.logical_not(jnp.all(curr.done))) + return jnp.bitwise_and(curr.i < SineBivariateVonMises.max_sample_iter, jnp.logical_not(jnp.all(curr.done))) phi_state = while_loop(cond_fn, update_fn, PhiMarginalState(i=jnp.array(0), diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 3bbe95dd9..8596f71b9 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -48,7 +48,7 @@ lazy_property, multinomial, promote_shapes, - validate_sample + validate_sample, ) from numpyro.util import not_jax_tracer diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 2e5ff9393..f65751d55 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -38,7 +38,12 @@ from jax.scipy.special import logsumexp from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform -from numpyro.distributions.util import lazy_property, promote_shapes, sum_rightmost, validate_sample +from numpyro.distributions.util import ( + lazy_property, + promote_shapes, + sum_rightmost, + validate_sample, +) from numpyro.util import not_jax_tracer from . import constraints diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py index 091f60446..c97bbeb63 100644 --- a/numpyro/distributions/kl.py +++ b/numpyro/distributions/kl.py @@ -37,7 +37,7 @@ Distribution, ExpandedDistribution, Independent, - MaskedDistribution + MaskedDistribution, ) from numpyro.distributions.util import scale_and_mask, sum_rightmost diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 6c22e75f5..636e15b2d 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -16,7 +16,12 @@ from jax.tree_util import tree_flatten, tree_map from numpyro.distributions import constraints -from numpyro.distributions.util import matrix_to_tril_vec, signed_stick_breaking_tril, sum_rightmost, vec_to_tril_matrix +from numpyro.distributions.util import ( + matrix_to_tril_vec, + signed_stick_breaking_tril, + sum_rightmost, + vec_to_tril_matrix, +) from numpyro.util import not_jax_tracer __all__ = [ diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 3d02a0aee..1717b155a 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -8,9 +8,21 @@ from jax.tree_util import tree_map from numpyro.distributions import constraints -from numpyro.distributions.continuous import Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT +from numpyro.distributions.continuous import ( + Cauchy, + Laplace, + Logistic, + Normal, + SoftLaplace, + StudentT, +) from numpyro.distributions.distribution import Distribution -from numpyro.distributions.util import is_prng_key, lazy_property, promote_shapes, validate_sample +from numpyro.distributions.util import ( + is_prng_key, + lazy_property, + promote_shapes, + validate_sample, +) class LeftTruncatedDistribution(Distribution): diff --git a/numpyro/infer/__init__.py b/numpyro/infer/__init__.py index 34867c125..11fa0a553 100644 --- a/numpyro/infer/__init__.py +++ b/numpyro/infer/__init__.py @@ -10,7 +10,7 @@ init_to_median, init_to_sample, init_to_uniform, - init_to_value + init_to_value, ) from numpyro.infer.mcmc import MCMC from numpyro.infer.mixed_hmc import MixedHMC diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 435fff2fe..e7a8dc1cb 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -17,7 +17,10 @@ from numpyro import handlers import numpyro.distributions as dist from numpyro.distributions import constraints -from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform, InverseAutoregressiveTransform +from numpyro.distributions.flows import ( + BlockNeuralAutoregressiveTransform, + InverseAutoregressiveTransform, +) from numpyro.distributions.transforms import ( AffineTransform, ComposeTransform, @@ -25,9 +28,13 @@ LowerCholeskyAffine, PermuteTransform, UnpackTransform, - biject_to + biject_to, +) +from numpyro.distributions.util import ( + cholesky_of_inverse, + periodic_repeat, + sum_rightmost, ) -from numpyro.distributions.util import cholesky_of_inverse, periodic_repeat, sum_rightmost from numpyro.infer.elbo import Trace_ELBO from numpyro.infer.initialization import init_to_median from numpyro.infer.util import init_to_uniform, initialize_model diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index e5fd7de64..aff4493f5 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -15,7 +15,7 @@ euclidean_kinetic_energy, find_reasonable_step_size, velocity_verlet, - warmup_adapter + warmup_adapter, ) from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 7cf9c96c9..1c1fab8a0 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -8,7 +8,17 @@ import numpy as np -from jax import device_put, grad, hessian, jacfwd, jacobian, lax, ops, random, value_and_grad +from jax import ( + device_put, + grad, + hessian, + jacfwd, + jacobian, + lax, + ops, + random, + value_and_grad, +) from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.special import expit diff --git a/test/contrib/einstein/test_einstein_kernels.py b/test/contrib/einstein/test_einstein_kernels.py index 30088af3e..556080c5a 100644 --- a/test/contrib/einstein/test_einstein_kernels.py +++ b/test/contrib/einstein/test_einstein_kernels.py @@ -16,7 +16,7 @@ MixtureKernel, PrecondMatrixKernel, RandomFeatureKernel, - RBFKernel + RBFKernel, ) jnp.set_printoptions(precision=100) diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py index 3d86c82f0..85520bafb 100644 --- a/test/contrib/test_funsor.py +++ b/test/contrib/test_funsor.py @@ -15,8 +15,7 @@ import numpyro from numpyro.contrib.control_flow import scan from numpyro.contrib.funsor import config_enumerate, enum, markov, to_data, to_funsor -from numpyro.contrib.funsor.enum_messenger import NamedMessenger -from numpyro.contrib.funsor.enum_messenger import plate as enum_plate +from numpyro.contrib.funsor.enum_messenger import NamedMessenger, plate as enum_plate from numpyro.contrib.funsor.infer_util import log_density from numpyro.contrib.indexing import Vindex import numpyro.distributions as dist diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index e8e5a6c02..2e411f62d 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -17,7 +17,7 @@ flax_module, haiku_module, random_flax_module, - random_haiku_module + random_haiku_module, ) import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 8b6353317..bb16cd166 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -27,7 +27,7 @@ AutoLaplaceApproximation, AutoLowRankMultivariateNormal, AutoMultivariateNormal, - AutoNormal + AutoNormal, ) from numpyro.infer.initialization import init_to_median from numpyro.infer.reparam import TransformReparam diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index d33ffa08e..8849028d4 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -25,7 +25,7 @@ parametric_draws, velocity_verlet, warmup_adapter, - welford_covariance + welford_covariance, ) from numpyro.util import control_flow_prims_disabled, fori_loop, optional diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index da067706f..e3d58161c 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -21,7 +21,7 @@ init_to_median, init_to_sample, init_to_uniform, - init_to_value + init_to_value, ) from numpyro.infer.reparam import TransformReparam from numpyro.infer.util import ( @@ -30,7 +30,7 @@ initialize_model, log_likelihood, potential_energy, - transform_fn + transform_fn, ) import numpyro.optim as optim diff --git a/test/infer/test_reparam.py b/test/infer/test_reparam.py index 2ad44a245..98e992a86 100644 --- a/test/infer/test_reparam.py +++ b/test/infer/test_reparam.py @@ -14,7 +14,12 @@ import numpyro.handlers as handlers from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO from numpyro.infer.autoguide import AutoIAFNormal -from numpyro.infer.reparam import LocScaleReparam, NeuTraReparam, ProjectedNormalReparam, TransformReparam +from numpyro.infer.reparam import ( + LocScaleReparam, + NeuTraReparam, + ProjectedNormalReparam, + TransformReparam, +) from numpyro.infer.util import initialize_model from numpyro.optim import Adam diff --git a/test/test_diagnostics.py b/test/test_diagnostics.py index f389598e4..d0a32df87 100644 --- a/test/test_diagnostics.py +++ b/test/test_diagnostics.py @@ -13,7 +13,7 @@ effective_sample_size, gelman_rubin, hpdi, - split_gelman_rubin + split_gelman_rubin, ) diff --git a/test/test_distributions.py b/test/test_distributions.py index 9e36e52b9..7e18d8471 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1,9 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 - from collections import namedtuple from functools import partial import inspect +import math import os import numpy as np @@ -19,6 +19,7 @@ import numpyro.distributions as dist from numpyro.distributions import constraints, kl_divergence, transforms +from numpyro.distributions.directional import SineBivariateVonMises from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom from numpyro.distributions.flows import InverseAutoregressiveTransform from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit @@ -27,14 +28,14 @@ PermuteTransform, PowerTransform, SoftplusTransform, - biject_to + biject_to, ) from numpyro.distributions.util import ( matrix_to_tril_vec, multinomial, signed_stick_breaking_tril, sum_rightmost, - vec_to_tril_matrix + vec_to_tril_matrix, ) from numpyro.nn import AutoregressiveNN @@ -333,6 +334,11 @@ def get_sp_dist(jax_dist): T(dist.VonMises, 2.0, 10.0), T(dist.VonMises, 2.0, jnp.array([150.0, 10.0])), T(dist.VonMises, jnp.array([1 / 3 * jnp.pi, -1.0]), jnp.array([20.0, 30.0])), + T(SineBivariateVonMises, jnp.array([0.]), jnp.array([0.]), jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), + T(SineBivariateVonMises, jnp.array([3.003]), jnp.array([-1.3430]), jnp.array([5.]), jnp.array([6.]), + jnp.array([2.])), + T(SineBivariateVonMises, jnp.array([math.pi - .2, 1.]), jnp.array([0., 1.]), jnp.array([2.123, 20.]), + jnp.array([7., .5]), None, jnp.array([.2, .5])), T(dist.ProjectedNormal, jnp.array([0.0, 0.0])), T(dist.ProjectedNormal, jnp.array([[2.0, 3.0]])), T(dist.ProjectedNormal, jnp.array([0.0, 0.0, 0.0])), @@ -488,22 +494,22 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): elif isinstance(constraint, constraints.multinomial): n = size[-1] return ( - multinomial( - key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] - ) - + 1 + multinomial( + key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] + ) + + 1 ) elif constraint is constraints.corr_cholesky: return ( - signed_stick_breaking_tril( - random.uniform( - key, - size[:-2] + (size[-1] * (size[-1] - 1) // 2,), - minval=-1, - maxval=1, + signed_stick_breaking_tril( + random.uniform( + key, + size[:-2] + (size[-1] * (size[-1] - 1) // 2,), + minval=-1, + maxval=1, + ) ) - ) - + 1e-2 + + 1e-2 ) elif constraint is constraints.corr_matrix: cholesky = 1e-2 + signed_stick_breaking_tril( @@ -714,10 +720,10 @@ def g(params): ) def test_jit_log_likelihood(jax_dist, sp_dist, params): if jax_dist.__name__ in ( - "GaussianRandomWalk", - "_ImproperWrapper", - "LKJ", - "LKJCholesky", + "GaussianRandomWalk", + "_ImproperWrapper", + "LKJ", + "LKJCholesky", ): pytest.xfail(reason="non-jittable params") @@ -745,12 +751,12 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape if sp_dist is None: if isinstance( - jax_dist, - ( - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, - ), + jax_dist, + ( + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, + ), ): if isinstance(params[0], dist.Distribution): # new api @@ -1076,7 +1082,7 @@ def fn(*args): eps = 1e-3 for i in range(len(params)): if isinstance( - params[i], dist.Distribution + params[i], dist.Distribution ): # skip taking grad w.r.t. base_dist continue if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64): @@ -1105,10 +1111,10 @@ def test_mean_var(jax_dist, sp_dist, params): if jax_dist is FoldedNormal: pytest.skip("Folded distribution does not has mean/var implemented") if jax_dist in ( - _TruncatedNormal, - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, + _TruncatedNormal, + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, ): pytest.skip("Truncated distributions do not has mean/var implemented") if jax_dist is dist.ProjectedNormal: @@ -1121,9 +1127,9 @@ def test_mean_var(jax_dist, sp_dist, params): # check with suitable scipy implementation if available # XXX: VonMises is already tested below if ( - sp_dist - and not _is_batched_multivariate(d_jax) - and jax_dist not in [dist.VonMises] + sp_dist + and not _is_batched_multivariate(d_jax) + and jax_dist not in [dist.VonMises] ): d_sp = sp_dist(*params) try: @@ -1209,13 +1215,13 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): dependent_constraint = False for i in range(len(params)): if ( - jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) - and dist_args[i] != "concentration" + jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) + and dist_args[i] != "concentration" ): continue if ( - jax_dist is dist.TwoSidedTruncatedDistribution - and dist_args[i] == "base_dist" + jax_dist is dist.TwoSidedTruncatedDistribution + and dist_args[i] == "base_dist" ): continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": @@ -1258,9 +1264,9 @@ def dist_gen_fn(): # Test agreement of log density evaluation on randomly generated samples # with scipy's implementation when available. if ( - sp_dist - and not _is_batched_multivariate(d) - and not (d.event_shape and prepend_shape) + sp_dist + and not _is_batched_multivariate(d) + and not (d.event_shape and prepend_shape) ): valid_samples = gen_values_within_bounds( d.support, size=prepend_shape + d.batch_shape + d.event_shape @@ -1336,113 +1342,113 @@ def g(x): (constraints.boolean, jnp.array([1, 1]), jnp.array([True, True])), (constraints.boolean, jnp.array([-1, 1]), jnp.array([False, True])), ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not positive_diagonal & not unit_norm_row ( - constraints.corr_matrix, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_matrix, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not unit diagonal (constraints.greater_than(1), 3, True), ( - constraints.greater_than(1), - jnp.array([-1, 1, 5]), - jnp.array([False, False, True]), + constraints.greater_than(1), + jnp.array([-1, 1, 5]), + jnp.array([False, False, True]), ), (constraints.integer_interval(-3, 5), 0, True), ( - constraints.integer_interval(-3, 5), - jnp.array([-5, -3, 0, 1.1, 5, 7]), - jnp.array([False, True, True, False, True, False]), + constraints.integer_interval(-3, 5), + jnp.array([-5, -3, 0, 1.1, 5, 7]), + jnp.array([False, True, True, False, True, False]), ), (constraints.interval(-3, 5), 0, True), ( - constraints.interval(-3, 5), - jnp.array([-5, -3, 0, 5, 7]), - jnp.array([False, True, True, True, False]), + constraints.interval(-3, 5), + jnp.array([-5, -3, 0, 5, 7]), + jnp.array([False, True, True, True, False]), ), (constraints.less_than(1), -2, True), ( - constraints.less_than(1), - jnp.array([-1, 1, 5]), - jnp.array([True, False, False]), + constraints.less_than(1), + jnp.array([-1, 1, 5]), + jnp.array([True, False, False]), ), (constraints.lower_cholesky, jnp.array([[1.0, 0.0], [-2.0, 0.1]]), True), ( - constraints.lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.nonnegative_integer, 3, True), ( - constraints.nonnegative_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, True, True]), + constraints.nonnegative_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, True, True]), ), (constraints.positive, 3, True), (constraints.positive, jnp.array([-1, 0, 5]), jnp.array([False, False, True])), (constraints.positive_definite, jnp.array([[1.0, 0.3], [0.3, 1.0]]), True), ( - constraints.positive_definite, - jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), - jnp.array([False, False]), + constraints.positive_definite, + jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), + jnp.array([False, False]), ), (constraints.positive_integer, 3, True), ( - constraints.positive_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, False, True]), + constraints.positive_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, False, True]), ), (constraints.real, -1, True), ( - constraints.real, - jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), - jnp.array([False, False, False, True]), + constraints.real, + jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), + jnp.array([False, False, False, True]), ), (constraints.simplex, jnp.array([0.1, 0.3, 0.6]), True), ( - constraints.simplex, - jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), - jnp.array([True, False, False]), + constraints.simplex, + jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), + jnp.array([True, False, False]), ), (constraints.softplus_positive, 3, True), ( - constraints.softplus_positive, - jnp.array([-1, 0, 5]), - jnp.array([False, False, True]), + constraints.softplus_positive, + jnp.array([-1, 0, 5]), + jnp.array([False, False, True]), ), ( - constraints.softplus_lower_cholesky, - jnp.array([[1.0, 0.0], [-2.0, 0.1]]), - True, + constraints.softplus_lower_cholesky, + jnp.array([[1.0, 0.0], [-2.0, 0.1]]), + True, ), ( - constraints.softplus_lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.softplus_lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.unit_interval, 0.1, True), ( - constraints.unit_interval, - jnp.array([-5, 0, 0.5, 1, 7]), - jnp.array([False, True, True, True, False]), + constraints.unit_interval, + jnp.array([-5, 0, 0.5, 1, 7]), + jnp.array([False, True, True, True, False]), ), ( - constraints.sphere, - jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), - jnp.array([True, False]), + constraints.sphere, + jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), + jnp.array([True, False]), ), ], ) @@ -1488,7 +1494,6 @@ def test_constraints(constraint, x, expected): ) @pytest.mark.parametrize("shape", [(), (1,), (3,), (6,), (3, 1), (1, 3), (5, 3)]) def test_biject_to(constraint, shape): - transform = biject_to(constraint) event_dim = transform.domain.event_dim if isinstance(constraint, constraints._Interval): @@ -1547,9 +1552,9 @@ def inv_vec_transform(y): if constraint is constraints.corr_matrix: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - + jnp.identity(matrix.shape[-1]) + matrix + + jnp.swapaxes(matrix, -2, -1) + + jnp.identity(matrix.shape[-1]) ) return transform.inv(matrix) @@ -1568,9 +1573,9 @@ def inv_vec_transform(y): if constraint is constraints.positive_definite: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - - jnp.diag(jnp.diag(matrix)) + matrix + + jnp.swapaxes(matrix, -2, -1) + - jnp.diag(jnp.diag(matrix)) ) return transform.inv(matrix) @@ -1592,10 +1597,10 @@ def inv_vec_transform(y): (PowerTransform(2.0), ()), (SoftplusTransform(), ()), ( - LowerCholeskyAffine( - jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) - ), - (2,), + LowerCholeskyAffine( + jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) + ), + (2,), ), ], ) @@ -1649,7 +1654,7 @@ def test_composed_transform(batch_shape): log_det = t.log_abs_det_jacobian(x, y) assert log_det.shape == batch_shape expected_log_det = ( - jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 + jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 ) assert_allclose(log_det, expected_log_det) @@ -1668,9 +1673,9 @@ def test_composed_transform_1(batch_shape): assert log_det.shape == batch_shape z = t2(x * 2) expected_log_det = ( - jnp.log(2) * 6 - + t2.log_abs_det_jacobian(x * 2, z) - + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) + jnp.log(2) * 6 + + t2.log_abs_det_jacobian(x * 2, z) + + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) ) assert_allclose(log_det, expected_log_det) @@ -1681,8 +1686,8 @@ def test_composed_transform_1(batch_shape): def test_transformed_distribution(batch_shape, prepend_event_shape, sample_shape): base_dist = ( dist.Normal(0, 1) - .expand(batch_shape + prepend_event_shape + (6,)) - .to_event(1 + len(prepend_event_shape)) + .expand(batch_shape + prepend_event_shape + (6,)) + .to_event(1 + len(prepend_event_shape)) ) t1 = transforms.AffineTransform(0, 2) t2 = transforms.LowerCholeskyTransform() @@ -1794,7 +1799,7 @@ def test_unpack_transform(x_dim, y_dim): @pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS) def test_generated_sample_distribution( - jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) + jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) ): """On samplers that we do not get directly from JAX, (e.g. we only get Gumbel(0,1) but also provide samplers for Gumbel(loc, scale)), also test @@ -1852,8 +1857,8 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape): assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape # test expand of expand assert ( - expanded_dist.expand((3,) + new_batch_shape).batch_shape - == (3,) + new_batch_shape + expanded_dist.expand((3,) + new_batch_shape).batch_shape + == (3,) + new_batch_shape ) # test expand error if prepend_shape: diff --git a/test/test_distributions_util.py b/test/test_distributions_util.py index bcd51dc4f..830725e5e 100644 --- a/test/test_distributions_util.py +++ b/test/test_distributions_util.py @@ -20,7 +20,7 @@ multinomial, safe_normalize, vec_to_tril_matrix, - von_mises_centered + von_mises_centered, ) diff --git a/test/test_example_utils.py b/test/test_example_utils.py index c170a3cee..65221d2d7 100644 --- a/test/test_example_utils.py +++ b/test/test_example_utils.py @@ -3,7 +3,14 @@ import jax.numpy as jnp -from numpyro.examples.datasets import BASEBALL, COVTYPE, JSB_CHORALES, MNIST, SP500, load_dataset +from numpyro.examples.datasets import ( + BASEBALL, + COVTYPE, + JSB_CHORALES, + MNIST, + SP500, + load_dataset, +) from numpyro.util import fori_loop diff --git a/test/test_flows.py b/test/test_flows.py index 21d16e4b7..769644fc3 100644 --- a/test/test_flows.py +++ b/test/test_flows.py @@ -10,7 +10,10 @@ from jax import jacfwd, random from jax.experimental import stax -from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform, InverseAutoregressiveTransform +from numpyro.distributions.flows import ( + BlockNeuralAutoregressiveTransform, + InverseAutoregressiveTransform, +) from numpyro.distributions.util import matrix_to_tril_vec from numpyro.nn import AutoregressiveNN, BlockNeuralAutoregressiveNN diff --git a/test/test_pickle.py b/test/test_pickle.py index 21d947a75..bf0d79793 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -10,7 +10,16 @@ import numpyro import numpyro.distributions as dist -from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, BarkerMH, DiscreteHMCGibbs, MixedHMC +from numpyro.infer import ( + HMC, + HMCECS, + MCMC, + NUTS, + SA, + BarkerMH, + DiscreteHMCGibbs, + MixedHMC, +) def normal_model(): From bde8204fcbed38c1088439f5eab7d0d60ac94d30 Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 10 Jun 2021 21:26:21 +0200 Subject: [PATCH 05/23] Ran black --- numpyro/distributions/directional.py | 173 ++++++++++++------ test/test_distributions.py | 256 +++++++++++++++------------ 2 files changed, 261 insertions(+), 168 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index d3c7ace07..f10a73e1f 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -31,7 +31,7 @@ def _numel(shape): def log_I1(orders: int, value, terms=250): - r""" Compute first n log modified bessel function of first kind + r"""Compute first n log modified bessel function of first kind .. math :: \log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk) - \lgamma(v + k + 1)\right]) @@ -49,7 +49,7 @@ def log_I1(orders: int, value, terms=250): flat_vshape = _numel(vshape) k = jnp.arange(terms) - lgammas_all = lax.lgamma(jnp.arange(1., terms + orders + 1)) + lgammas_all = lax.lgamma(jnp.arange(1.0, terms + orders + 1)) assert lgammas_all.shape == (orders + terms,) # lgamma(0) = inf => start from 1 lvalues = lax.log(value / 2) * k.reshape(1, -1) @@ -64,8 +64,12 @@ def log_I1(orders: int, value, terms=250): indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1) assert indices.shape == (orders, terms) - seqs = logsumexp(2 * lvalues[None, :, :] - lfactorials[None, None, :] - - jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :], -1) + seqs = logsumexp( + 2 * lvalues[None, :, :] + - lfactorials[None, None, :] + - jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :], + -1, + ) assert seqs.shape == (orders, flat_vshape) i1s = lvalues[..., :orders].T + seqs @@ -110,8 +114,9 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): - return -(jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration))) + \ - self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1) + return -( + jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration)) + ) + self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1) @property def mean(self): @@ -128,11 +133,11 @@ def variance(self): ) -PhiMarginalState = namedtuple("PhiMarginalState", ['i', 'done', 'phi', 'key']) +PhiMarginalState = namedtuple("PhiMarginalState", ["i", "done", "phi", "key"]) class SineBivariateVonMises(Distribution): - r""" Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by + r"""Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by .. math:: C^{-1}\exp(\kappa_1\cos(x-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) and @@ -162,44 +167,78 @@ class SineBivariateVonMises(Distribution): to avoid bimodality (see note). """ - arg_constraints = {'phi_loc': constraints.real, 'psi_loc': constraints.real, - 'phi_concentration': constraints.positive, 'psi_concentration': constraints.positive, - 'correlation': constraints.real} + arg_constraints = { + "phi_loc": constraints.real, + "psi_loc": constraints.real, + "phi_concentration": constraints.positive, + "psi_concentration": constraints.positive, + "correlation": constraints.real, + } support = constraints.independent(constraints.real, 1) max_sample_iter = 1000 - def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration, correlation=None, - weighted_correlation=None, validate_args=None): + def __init__( + self, + phi_loc, + psi_loc, + phi_concentration, + psi_concentration, + correlation=None, + weighted_correlation=None, + validate_args=None, + ): assert (correlation is None) != (weighted_correlation is None) if weighted_correlation is not None: - correlation = weighted_correlation * jnp.sqrt(phi_concentration * psi_concentration) + 1e-8 - - self.phi_loc, self.psi_loc, self.phi_concentration, self.psi_concentration, self.correlation = promote_shapes( - phi_loc, psi_loc, - phi_concentration, - psi_concentration, - correlation) - batch_shape = lax.broadcast_shapes(phi_loc.shape, psi_loc.shape, phi_concentration.shape, - psi_concentration.shape, correlation.shape) + correlation = ( + weighted_correlation * jnp.sqrt(phi_concentration * psi_concentration) + + 1e-8 + ) + + ( + self.phi_loc, + self.psi_loc, + self.phi_concentration, + self.psi_concentration, + self.correlation, + ) = promote_shapes( + phi_loc, psi_loc, phi_concentration, psi_concentration, correlation + ) + batch_shape = lax.broadcast_shapes( + phi_loc.shape, + psi_loc.shape, + phi_concentration.shape, + psi_concentration.shape, + correlation.shape, + ) super().__init__(batch_shape, (2,), validate_args) - if self._validate_args and jnp.any(phi_concentration * psi_concentration <= correlation ** 2): + if self._validate_args and jnp.any( + phi_concentration * psi_concentration <= correlation ** 2 + ): warnings.warn( - f'{self.__class__.__name__} bimodal due to concentration-correlation relation, ' - f'sampling will likely fail.', UserWarning) + f"{self.__class__.__name__} bimodal due to concentration-correlation relation, " + f"sampling will likely fail.", + UserWarning, + ) @lazy_property def norm_const(self): corr = self.correlation.reshape(1, -1) + 1e-8 - conc = jnp.stack((self.phi_concentration, self.psi_concentration), axis=-1).reshape(-1, 2) + conc = jnp.stack( + (self.phi_concentration, self.psi_concentration), axis=-1 + ).reshape(-1, 2) m = jnp.arange(50).reshape(-1, 1) - num = lax.lgamma(2 * m + 1.) - den = lax.lgamma(m + 1.) + num = lax.lgamma(2 * m + 1.0) + den = lax.lgamma(m + 1.0) lbinoms = num - 2 * den - fs = lbinoms.reshape(-1, 1) + 2 * m * jnp.log(corr) - m * jnp.log(4 * jnp.prod(conc, axis=-1)) + fs = ( + lbinoms.reshape(-1, 1) + + 2 * m * jnp.log(corr) + - m * jnp.log(4 * jnp.prod(conc, axis=-1)) + ) fs += log_I1(49, conc, terms=51).sum(-1) mfs = fs.max() norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(fs - mfs, 0) @@ -208,9 +247,14 @@ def norm_const(self): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - indv = self.phi_concentration * jnp.cos(value[..., 0] - self.phi_loc) + self.psi_concentration * jnp.cos( - value[..., 1] - self.psi_loc) - corr = self.correlation * jnp.sin(value[..., 0] - self.phi_loc) * jnp.sin(value[..., 1] - self.psi_loc) + indv = self.phi_concentration * jnp.cos( + value[..., 0] - self.phi_loc + ) + self.psi_concentration * jnp.cos(value[..., 1] - self.psi_loc) + corr = ( + self.correlation + * jnp.sin(value[..., 0] - self.phi_loc) + * jnp.sin(value[..., 1] - self.psi_loc) + ) return indv + corr - self.norm_const def sample(self, key, sample_shape=()): @@ -233,11 +277,15 @@ def sample(self, key, sample_shape=()): total = _numel(sample_shape) phi_den = log_I1(0, conc[1]).squeeze(0) phi_shape = (total, 2, _numel(self.batch_shape)) - phi_state = SineBivariateVonMises._phi_marginal(phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den) + phi_state = SineBivariateVonMises._phi_marginal( + phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den + ) if not jnp.all(phi_state.done): - raise ValueError("maximum number of iterations exceeded; " - "try increasing `SineBivariateVonMises.max_sample_iter`") + raise ValueError( + "maximum number of iterations exceeded; " + "try increasing `SineBivariateVonMises.max_sample_iter`" + ) phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) @@ -246,8 +294,13 @@ def sample(self, key, sample_shape=()): psi = VonMises(beta, alpha).sample(psi_key) - phi_psi = jnp.concatenate(((phi + self.phi_loc + pi) % (2 * pi) - pi, - (psi + self.psi_loc + pi) % (2 * pi) - pi), axis=1) + phi_psi = jnp.concatenate( + ( + (phi + self.phi_loc + pi) % (2 * pi) - pi, + (psi + self.psi_loc + pi) % (2 * pi) - pi, + ), + axis=1, + ) phi_psi = jnp.transpose(phi_psi, (0, 2, 1)) return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape) @@ -265,13 +318,23 @@ def update_fn(curr): accept_key, acg_key, phi_key = random.split(phi_key, 3) x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) - x /= jnp.linalg.norm(x, axis=1)[:, None, :] # Angular Central Gaussian distribution - - lf = conc[:, :1] * (x[:, :1] - 1) + eigmin + log_I1(0, jnp.sqrt( - conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2)).squeeze(0) - phi_den + x /= jnp.linalg.norm(x, axis=1)[ + :, None, : + ] # Angular Central Gaussian distribution + + lf = ( + conc[:, :1] * (x[:, :1] - 1) + + eigmin + + log_I1( + 0, jnp.sqrt(conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2) + ).squeeze(0) + - phi_den + ) assert lf.shape == shape - lg_inv = 1. - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True)) + lg_inv = ( + 1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True)) + ) assert lg_inv.shape == lf.shape accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv) @@ -280,14 +343,24 @@ def update_fn(curr): return PhiMarginalState(i + 1, done | accepted, phi, key) def cond_fn(curr): - return jnp.bitwise_and(curr.i < SineBivariateVonMises.max_sample_iter, jnp.logical_not(jnp.all(curr.done))) - - phi_state = while_loop(cond_fn, update_fn, - PhiMarginalState(i=jnp.array(0), - done=jnp.zeros(shape, dtype=bool), - phi=jnp.empty(shape, dtype=float), - key=rng_key)) - return PhiMarginalState(phi_state.i, phi_state.done, phi_state.phi, phi_state.key) + return jnp.bitwise_and( + curr.i < SineBivariateVonMises.max_sample_iter, + jnp.logical_not(jnp.all(curr.done)), + ) + + phi_state = while_loop( + cond_fn, + update_fn, + PhiMarginalState( + i=jnp.array(0), + done=jnp.zeros(shape, dtype=bool), + phi=jnp.empty(shape, dtype=float), + key=rng_key, + ), + ) + return PhiMarginalState( + phi_state.i, phi_state.done, phi_state.phi, phi_state.key + ) @property def mean(self): diff --git a/test/test_distributions.py b/test/test_distributions.py index 7e18d8471..39983066e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -334,11 +334,31 @@ def get_sp_dist(jax_dist): T(dist.VonMises, 2.0, 10.0), T(dist.VonMises, 2.0, jnp.array([150.0, 10.0])), T(dist.VonMises, jnp.array([1 / 3 * jnp.pi, -1.0]), jnp.array([20.0, 30.0])), - T(SineBivariateVonMises, jnp.array([0.]), jnp.array([0.]), jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), - T(SineBivariateVonMises, jnp.array([3.003]), jnp.array([-1.3430]), jnp.array([5.]), jnp.array([6.]), - jnp.array([2.])), - T(SineBivariateVonMises, jnp.array([math.pi - .2, 1.]), jnp.array([0., 1.]), jnp.array([2.123, 20.]), - jnp.array([7., .5]), None, jnp.array([.2, .5])), + T( + SineBivariateVonMises, + jnp.array([0.0]), + jnp.array([0.0]), + jnp.array([5.0]), + jnp.array([6.0]), + jnp.array([2.0]), + ), + T( + SineBivariateVonMises, + jnp.array([3.003]), + jnp.array([-1.3430]), + jnp.array([5.0]), + jnp.array([6.0]), + jnp.array([2.0]), + ), + T( + SineBivariateVonMises, + jnp.array([math.pi - 0.2, 1.0]), + jnp.array([0.0, 1.0]), + jnp.array([2.123, 20.0]), + jnp.array([7.0, 0.5]), + None, + jnp.array([0.2, 0.5]), + ), T(dist.ProjectedNormal, jnp.array([0.0, 0.0])), T(dist.ProjectedNormal, jnp.array([[2.0, 3.0]])), T(dist.ProjectedNormal, jnp.array([0.0, 0.0, 0.0])), @@ -494,22 +514,22 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): elif isinstance(constraint, constraints.multinomial): n = size[-1] return ( - multinomial( - key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] - ) - + 1 + multinomial( + key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] + ) + + 1 ) elif constraint is constraints.corr_cholesky: return ( - signed_stick_breaking_tril( - random.uniform( - key, - size[:-2] + (size[-1] * (size[-1] - 1) // 2,), - minval=-1, - maxval=1, - ) + signed_stick_breaking_tril( + random.uniform( + key, + size[:-2] + (size[-1] * (size[-1] - 1) // 2,), + minval=-1, + maxval=1, ) - + 1e-2 + ) + + 1e-2 ) elif constraint is constraints.corr_matrix: cholesky = 1e-2 + signed_stick_breaking_tril( @@ -720,10 +740,10 @@ def g(params): ) def test_jit_log_likelihood(jax_dist, sp_dist, params): if jax_dist.__name__ in ( - "GaussianRandomWalk", - "_ImproperWrapper", - "LKJ", - "LKJCholesky", + "GaussianRandomWalk", + "_ImproperWrapper", + "LKJ", + "LKJCholesky", ): pytest.xfail(reason="non-jittable params") @@ -751,12 +771,12 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape if sp_dist is None: if isinstance( - jax_dist, - ( - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, - ), + jax_dist, + ( + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, + ), ): if isinstance(params[0], dist.Distribution): # new api @@ -1082,7 +1102,7 @@ def fn(*args): eps = 1e-3 for i in range(len(params)): if isinstance( - params[i], dist.Distribution + params[i], dist.Distribution ): # skip taking grad w.r.t. base_dist continue if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64): @@ -1111,10 +1131,10 @@ def test_mean_var(jax_dist, sp_dist, params): if jax_dist is FoldedNormal: pytest.skip("Folded distribution does not has mean/var implemented") if jax_dist in ( - _TruncatedNormal, - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, + _TruncatedNormal, + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, ): pytest.skip("Truncated distributions do not has mean/var implemented") if jax_dist is dist.ProjectedNormal: @@ -1127,9 +1147,9 @@ def test_mean_var(jax_dist, sp_dist, params): # check with suitable scipy implementation if available # XXX: VonMises is already tested below if ( - sp_dist - and not _is_batched_multivariate(d_jax) - and jax_dist not in [dist.VonMises] + sp_dist + and not _is_batched_multivariate(d_jax) + and jax_dist not in [dist.VonMises] ): d_sp = sp_dist(*params) try: @@ -1215,13 +1235,13 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): dependent_constraint = False for i in range(len(params)): if ( - jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) - and dist_args[i] != "concentration" + jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) + and dist_args[i] != "concentration" ): continue if ( - jax_dist is dist.TwoSidedTruncatedDistribution - and dist_args[i] == "base_dist" + jax_dist is dist.TwoSidedTruncatedDistribution + and dist_args[i] == "base_dist" ): continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": @@ -1264,9 +1284,9 @@ def dist_gen_fn(): # Test agreement of log density evaluation on randomly generated samples # with scipy's implementation when available. if ( - sp_dist - and not _is_batched_multivariate(d) - and not (d.event_shape and prepend_shape) + sp_dist + and not _is_batched_multivariate(d) + and not (d.event_shape and prepend_shape) ): valid_samples = gen_values_within_bounds( d.support, size=prepend_shape + d.batch_shape + d.event_shape @@ -1342,113 +1362,113 @@ def g(x): (constraints.boolean, jnp.array([1, 1]), jnp.array([True, True])), (constraints.boolean, jnp.array([-1, 1]), jnp.array([False, True])), ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not positive_diagonal & not unit_norm_row ( - constraints.corr_matrix, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_matrix, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not unit diagonal (constraints.greater_than(1), 3, True), ( - constraints.greater_than(1), - jnp.array([-1, 1, 5]), - jnp.array([False, False, True]), + constraints.greater_than(1), + jnp.array([-1, 1, 5]), + jnp.array([False, False, True]), ), (constraints.integer_interval(-3, 5), 0, True), ( - constraints.integer_interval(-3, 5), - jnp.array([-5, -3, 0, 1.1, 5, 7]), - jnp.array([False, True, True, False, True, False]), + constraints.integer_interval(-3, 5), + jnp.array([-5, -3, 0, 1.1, 5, 7]), + jnp.array([False, True, True, False, True, False]), ), (constraints.interval(-3, 5), 0, True), ( - constraints.interval(-3, 5), - jnp.array([-5, -3, 0, 5, 7]), - jnp.array([False, True, True, True, False]), + constraints.interval(-3, 5), + jnp.array([-5, -3, 0, 5, 7]), + jnp.array([False, True, True, True, False]), ), (constraints.less_than(1), -2, True), ( - constraints.less_than(1), - jnp.array([-1, 1, 5]), - jnp.array([True, False, False]), + constraints.less_than(1), + jnp.array([-1, 1, 5]), + jnp.array([True, False, False]), ), (constraints.lower_cholesky, jnp.array([[1.0, 0.0], [-2.0, 0.1]]), True), ( - constraints.lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.nonnegative_integer, 3, True), ( - constraints.nonnegative_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, True, True]), + constraints.nonnegative_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, True, True]), ), (constraints.positive, 3, True), (constraints.positive, jnp.array([-1, 0, 5]), jnp.array([False, False, True])), (constraints.positive_definite, jnp.array([[1.0, 0.3], [0.3, 1.0]]), True), ( - constraints.positive_definite, - jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), - jnp.array([False, False]), + constraints.positive_definite, + jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), + jnp.array([False, False]), ), (constraints.positive_integer, 3, True), ( - constraints.positive_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, False, True]), + constraints.positive_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, False, True]), ), (constraints.real, -1, True), ( - constraints.real, - jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), - jnp.array([False, False, False, True]), + constraints.real, + jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), + jnp.array([False, False, False, True]), ), (constraints.simplex, jnp.array([0.1, 0.3, 0.6]), True), ( - constraints.simplex, - jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), - jnp.array([True, False, False]), + constraints.simplex, + jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), + jnp.array([True, False, False]), ), (constraints.softplus_positive, 3, True), ( - constraints.softplus_positive, - jnp.array([-1, 0, 5]), - jnp.array([False, False, True]), + constraints.softplus_positive, + jnp.array([-1, 0, 5]), + jnp.array([False, False, True]), ), ( - constraints.softplus_lower_cholesky, - jnp.array([[1.0, 0.0], [-2.0, 0.1]]), - True, + constraints.softplus_lower_cholesky, + jnp.array([[1.0, 0.0], [-2.0, 0.1]]), + True, ), ( - constraints.softplus_lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.softplus_lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.unit_interval, 0.1, True), ( - constraints.unit_interval, - jnp.array([-5, 0, 0.5, 1, 7]), - jnp.array([False, True, True, True, False]), + constraints.unit_interval, + jnp.array([-5, 0, 0.5, 1, 7]), + jnp.array([False, True, True, True, False]), ), ( - constraints.sphere, - jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), - jnp.array([True, False]), + constraints.sphere, + jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), + jnp.array([True, False]), ), ], ) @@ -1552,9 +1572,9 @@ def inv_vec_transform(y): if constraint is constraints.corr_matrix: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - + jnp.identity(matrix.shape[-1]) + matrix + + jnp.swapaxes(matrix, -2, -1) + + jnp.identity(matrix.shape[-1]) ) return transform.inv(matrix) @@ -1573,9 +1593,9 @@ def inv_vec_transform(y): if constraint is constraints.positive_definite: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - - jnp.diag(jnp.diag(matrix)) + matrix + + jnp.swapaxes(matrix, -2, -1) + - jnp.diag(jnp.diag(matrix)) ) return transform.inv(matrix) @@ -1597,10 +1617,10 @@ def inv_vec_transform(y): (PowerTransform(2.0), ()), (SoftplusTransform(), ()), ( - LowerCholeskyAffine( - jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) - ), - (2,), + LowerCholeskyAffine( + jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) + ), + (2,), ), ], ) @@ -1654,7 +1674,7 @@ def test_composed_transform(batch_shape): log_det = t.log_abs_det_jacobian(x, y) assert log_det.shape == batch_shape expected_log_det = ( - jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 + jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 ) assert_allclose(log_det, expected_log_det) @@ -1673,9 +1693,9 @@ def test_composed_transform_1(batch_shape): assert log_det.shape == batch_shape z = t2(x * 2) expected_log_det = ( - jnp.log(2) * 6 - + t2.log_abs_det_jacobian(x * 2, z) - + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) + jnp.log(2) * 6 + + t2.log_abs_det_jacobian(x * 2, z) + + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) ) assert_allclose(log_det, expected_log_det) @@ -1686,8 +1706,8 @@ def test_composed_transform_1(batch_shape): def test_transformed_distribution(batch_shape, prepend_event_shape, sample_shape): base_dist = ( dist.Normal(0, 1) - .expand(batch_shape + prepend_event_shape + (6,)) - .to_event(1 + len(prepend_event_shape)) + .expand(batch_shape + prepend_event_shape + (6,)) + .to_event(1 + len(prepend_event_shape)) ) t1 = transforms.AffineTransform(0, 2) t2 = transforms.LowerCholeskyTransform() @@ -1799,7 +1819,7 @@ def test_unpack_transform(x_dim, y_dim): @pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS) def test_generated_sample_distribution( - jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) + jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) ): """On samplers that we do not get directly from JAX, (e.g. we only get Gumbel(0,1) but also provide samplers for Gumbel(loc, scale)), also test @@ -1857,8 +1877,8 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape): assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape # test expand of expand assert ( - expanded_dist.expand((3,) + new_batch_shape).batch_shape - == (3,) + new_batch_shape + expanded_dist.expand((3,) + new_batch_shape).batch_shape + == (3,) + new_batch_shape ) # test expand error if prepend_shape: From ebd876511ea4ceed5b23cfc96d0c655a93509c8f Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 10 Jun 2021 21:36:04 +0200 Subject: [PATCH 06/23] fixed license --- test/test_distributions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_distributions.py b/test/test_distributions.py index 39983066e..34cd5214e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1,5 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + from collections import namedtuple from functools import partial import inspect From 2aef87d06d0c1ffd087138386da9060dcb5cbc2c Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 10 Jun 2021 21:44:48 +0200 Subject: [PATCH 07/23] Added BvM to docs --- docs/source/distributions.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 845d22282..7b08cda41 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -477,6 +477,13 @@ ZeroInflatedNegativeBinomial2 Directional Distributions ========================= +SineBivariateVonMises +--------------------- +.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource ProjectedNormal --------------- From eccbdde8283e2de2deab812e616c4ff73361da36 Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 8 Jul 2021 11:44:54 +0200 Subject: [PATCH 08/23] Fixed docstring --- numpyro/distributions/directional.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index f10a73e1f..e4cfb0c0c 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -138,26 +138,38 @@ def variance(self): class SineBivariateVonMises(Distribution): r"""Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by + .. math:: C^{-1}\exp(\kappa_1\cos(x-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) + and + .. math:: C = (2\pi)^2 \sum_{i=0} {2i \choose i} \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2), + where I_i(\cdot) is the modified bessel function of first kind, mu's are the locations of the distribution, kappa's are the concentration and rho gives the correlation between angles x_1 and x_2. This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains. - To infer parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that + + To infer parameters, use :class:`~numpyro.infer.NUTS` or :class:`~numpyro.infer.HMC` with priors that avoid parameterizations where the distribution becomes bimodal; see note below. + .. note:: Sample efficiency drops as + .. math:: \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1 + because the distribution becomes increasingly bimodal. + .. note:: The correlation and weighted_correlation params are mutually exclusive. - .. note:: In the context of :class:`~pyro.infer.SVI`, this distribution can be used as a likelihood but not for + + .. note:: In the context of :class:`~numpyro.infer.SVI`, this distribution can be used as a likelihood but not for latent variables. + ** References: ** - 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) + 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) + :param jnp.Tensor phi_loc: location of first angle :param jnp.Tensor psi_loc: location of second angle :param jnp.Tensor phi_concentration: concentration of first angle @@ -281,12 +293,6 @@ def sample(self, key, sample_shape=()): phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den ) - if not jnp.all(phi_state.done): - raise ValueError( - "maximum number of iterations exceeded; " - "try increasing `SineBivariateVonMises.max_sample_iter`" - ) - phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) From 95830742161a30a11de541a9a21ec60c95ec9c52 Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 8 Jul 2021 12:04:51 +0200 Subject: [PATCH 09/23] Added math envs to docstring for `SineBivariateVonMises`. --- numpyro/distributions/directional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index e4cfb0c0c..5b3906ef4 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -140,7 +140,7 @@ class SineBivariateVonMises(Distribution): r"""Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by .. math:: - C^{-1}\exp(\kappa_1\cos(x-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) + C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) and @@ -148,8 +148,8 @@ class SineBivariateVonMises(Distribution): C = (2\pi)^2 \sum_{i=0} {2i \choose i} \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2), - where I_i(\cdot) is the modified bessel function of first kind, mu's are the locations of the distribution, - kappa's are the concentration and rho gives the correlation between angles x_1 and x_2. + where :math:`I_i(\cdot)` is the modified bessel function of first kind, mu's are the locations of the distribution, + kappa's are the concentration and rho gives the correlation between angles :math:`x_1` and :math:`x_2`. This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains. To infer parameters, use :class:`~numpyro.infer.NUTS` or :class:`~numpyro.infer.HMC` with priors that From 269ae670ca82f86831ec4ed67d96fde385034e5f Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 8 Jul 2021 12:40:57 +0200 Subject: [PATCH 10/23] Fixed `test_distribution_constraints` failures for `SineBivariateVonMises`. --- numpyro/distributions/directional.py | 13 +- test/test_distributions.py | 236 +++++++++++++-------------- 2 files changed, 118 insertions(+), 131 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 5b3906ef4..ba8c7ce5e 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -226,15 +226,6 @@ def __init__( ) super().__init__(batch_shape, (2,), validate_args) - if self._validate_args and jnp.any( - phi_concentration * psi_concentration <= correlation ** 2 - ): - warnings.warn( - f"{self.__class__.__name__} bimodal due to concentration-correlation relation, " - f"sampling will likely fail.", - UserWarning, - ) - @lazy_property def norm_const(self): corr = self.correlation.reshape(1, -1) + 1e-8 @@ -294,6 +285,7 @@ def sample(self, key, sample_shape=()): ) phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) + print(phi) alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) beta = lax.atan(corr / conc[1] * jnp.sin(phi)) @@ -370,7 +362,8 @@ def cond_fn(curr): @property def mean(self): - return jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + """Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]""" + return (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % (2.0 * jnp.pi) - jnp.pi def _bfind(self, eig): b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype) diff --git a/test/test_distributions.py b/test/test_distributions.py index bd50ac319..ab2957673 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4,7 +4,6 @@ from collections import namedtuple from functools import partial import inspect -import math import os import math @@ -291,13 +290,6 @@ def get_sp_dist(jax_dist): T(dist.Pareto, 1.0, 2.0), T(dist.Pareto, jnp.array([1.0, 0.5]), jnp.array([0.3, 2.0])), T(dist.Pareto, jnp.array([[1.0], [3.0]]), jnp.array([1.0, 0.5])), - T(dist.SineBivariateVonMises, jnp.array([0.]), jnp.array([0.]), jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), - T(dist.SineBivariateVonMises, jnp.array([3.003]), jnp.array([-1.343]), # check test_gof, test_mean_var, - jnp.array([5.]), jnp.array([6.]), jnp.array([2.])), # check test_distribution_constraints - T(dist.SineBivariateVonMises, jnp.array([-math.pi/3]), jnp.array(-1), - jnp.array(.4), jnp.array(10.), jnp.array(.9)), - T(dist.SineBivariateVonMises, jnp.array([math.pi - .2, 1.]), jnp.array([0.,1.]), - jnp.array([5., 5.]), jnp.array([7., .5]), None, jnp.array([.5, .1])), T(dist.SoftLaplace, 1.0, 1.0), T(dist.SoftLaplace, jnp.array([-1.0, 50.0]), jnp.array([4.0, 100.0])), T(dist.StudentT, 1.0, 1.0, 0.5), @@ -523,22 +515,22 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): elif isinstance(constraint, constraints.multinomial): n = size[-1] return ( - multinomial( - key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] - ) - + 1 + multinomial( + key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] + ) + + 1 ) elif constraint is constraints.corr_cholesky: return ( - signed_stick_breaking_tril( - random.uniform( - key, - size[:-2] + (size[-1] * (size[-1] - 1) // 2,), - minval=-1, - maxval=1, + signed_stick_breaking_tril( + random.uniform( + key, + size[:-2] + (size[-1] * (size[-1] - 1) // 2,), + minval=-1, + maxval=1, + ) ) - ) - + 1e-2 + + 1e-2 ) elif constraint is constraints.corr_matrix: cholesky = 1e-2 + signed_stick_breaking_tril( @@ -749,10 +741,10 @@ def g(params): ) def test_jit_log_likelihood(jax_dist, sp_dist, params): if jax_dist.__name__ in ( - "GaussianRandomWalk", - "_ImproperWrapper", - "LKJ", - "LKJCholesky", + "GaussianRandomWalk", + "_ImproperWrapper", + "LKJ", + "LKJCholesky", ): pytest.xfail(reason="non-jittable params") @@ -780,12 +772,12 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape if sp_dist is None: if isinstance( - jax_dist, - ( - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, - ), + jax_dist, + ( + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, + ), ): if isinstance(params[0], dist.Distribution): # new api @@ -1111,7 +1103,7 @@ def fn(*args): eps = 1e-3 for i in range(len(params)): if isinstance( - params[i], dist.Distribution + params[i], dist.Distribution ): # skip taking grad w.r.t. base_dist continue if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64): @@ -1140,10 +1132,10 @@ def test_mean_var(jax_dist, sp_dist, params): if jax_dist is FoldedNormal: pytest.skip("Folded distribution does not has mean/var implemented") if jax_dist in ( - _TruncatedNormal, - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, + _TruncatedNormal, + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, ): pytest.skip("Truncated distributions do not has mean/var implemented") if jax_dist is dist.ProjectedNormal: @@ -1156,9 +1148,9 @@ def test_mean_var(jax_dist, sp_dist, params): # check with suitable scipy implementation if available # XXX: VonMises is already tested below if ( - sp_dist - and not _is_batched_multivariate(d_jax) - and jax_dist not in [dist.VonMises] + sp_dist + and not _is_batched_multivariate(d_jax) + and jax_dist not in [dist.VonMises] ): d_sp = sp_dist(*params) try: @@ -1244,17 +1236,19 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): dependent_constraint = False for i in range(len(params)): if ( - jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) - and dist_args[i] != "concentration" + jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) + and dist_args[i] != "concentration" ): continue if ( - jax_dist is dist.TwoSidedTruncatedDistribution - and dist_args[i] == "base_dist" + jax_dist is dist.TwoSidedTruncatedDistribution + and dist_args[i] == "base_dist" ): continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue + if jax_dist is dist.SineBivariateVonMises and dist_args[i] == 'weighted_correlation': + continue if params[i] is None: oob_params[i] = None valid_params[i] = None @@ -1293,9 +1287,9 @@ def dist_gen_fn(): # Test agreement of log density evaluation on randomly generated samples # with scipy's implementation when available. if ( - sp_dist - and not _is_batched_multivariate(d) - and not (d.event_shape and prepend_shape) + sp_dist + and not _is_batched_multivariate(d) + and not (d.event_shape and prepend_shape) ): valid_samples = gen_values_within_bounds( d.support, size=prepend_shape + d.batch_shape + d.event_shape @@ -1371,113 +1365,113 @@ def g(x): (constraints.boolean, jnp.array([1, 1]), jnp.array([True, True])), (constraints.boolean, jnp.array([-1, 1]), jnp.array([False, True])), ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not positive_diagonal & not unit_norm_row ( - constraints.corr_matrix, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_matrix, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not unit diagonal (constraints.greater_than(1), 3, True), ( - constraints.greater_than(1), - jnp.array([-1, 1, 5]), - jnp.array([False, False, True]), + constraints.greater_than(1), + jnp.array([-1, 1, 5]), + jnp.array([False, False, True]), ), (constraints.integer_interval(-3, 5), 0, True), ( - constraints.integer_interval(-3, 5), - jnp.array([-5, -3, 0, 1.1, 5, 7]), - jnp.array([False, True, True, False, True, False]), + constraints.integer_interval(-3, 5), + jnp.array([-5, -3, 0, 1.1, 5, 7]), + jnp.array([False, True, True, False, True, False]), ), (constraints.interval(-3, 5), 0, True), ( - constraints.interval(-3, 5), - jnp.array([-5, -3, 0, 5, 7]), - jnp.array([False, True, True, True, False]), + constraints.interval(-3, 5), + jnp.array([-5, -3, 0, 5, 7]), + jnp.array([False, True, True, True, False]), ), (constraints.less_than(1), -2, True), ( - constraints.less_than(1), - jnp.array([-1, 1, 5]), - jnp.array([True, False, False]), + constraints.less_than(1), + jnp.array([-1, 1, 5]), + jnp.array([True, False, False]), ), (constraints.lower_cholesky, jnp.array([[1.0, 0.0], [-2.0, 0.1]]), True), ( - constraints.lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.nonnegative_integer, 3, True), ( - constraints.nonnegative_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, True, True]), + constraints.nonnegative_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, True, True]), ), (constraints.positive, 3, True), (constraints.positive, jnp.array([-1, 0, 5]), jnp.array([False, False, True])), (constraints.positive_definite, jnp.array([[1.0, 0.3], [0.3, 1.0]]), True), ( - constraints.positive_definite, - jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), - jnp.array([False, False]), + constraints.positive_definite, + jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), + jnp.array([False, False]), ), (constraints.positive_integer, 3, True), ( - constraints.positive_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, False, True]), + constraints.positive_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, False, True]), ), (constraints.real, -1, True), ( - constraints.real, - jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), - jnp.array([False, False, False, True]), + constraints.real, + jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), + jnp.array([False, False, False, True]), ), (constraints.simplex, jnp.array([0.1, 0.3, 0.6]), True), ( - constraints.simplex, - jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), - jnp.array([True, False, False]), + constraints.simplex, + jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), + jnp.array([True, False, False]), ), (constraints.softplus_positive, 3, True), ( - constraints.softplus_positive, - jnp.array([-1, 0, 5]), - jnp.array([False, False, True]), + constraints.softplus_positive, + jnp.array([-1, 0, 5]), + jnp.array([False, False, True]), ), ( - constraints.softplus_lower_cholesky, - jnp.array([[1.0, 0.0], [-2.0, 0.1]]), - True, + constraints.softplus_lower_cholesky, + jnp.array([[1.0, 0.0], [-2.0, 0.1]]), + True, ), ( - constraints.softplus_lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.softplus_lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.unit_interval, 0.1, True), ( - constraints.unit_interval, - jnp.array([-5, 0, 0.5, 1, 7]), - jnp.array([False, True, True, True, False]), + constraints.unit_interval, + jnp.array([-5, 0, 0.5, 1, 7]), + jnp.array([False, True, True, True, False]), ), ( - constraints.sphere, - jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), - jnp.array([True, False]), + constraints.sphere, + jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), + jnp.array([True, False]), ), ], ) @@ -1581,9 +1575,9 @@ def inv_vec_transform(y): if constraint is constraints.corr_matrix: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - + jnp.identity(matrix.shape[-1]) + matrix + + jnp.swapaxes(matrix, -2, -1) + + jnp.identity(matrix.shape[-1]) ) return transform.inv(matrix) @@ -1602,9 +1596,9 @@ def inv_vec_transform(y): if constraint is constraints.positive_definite: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - - jnp.diag(jnp.diag(matrix)) + matrix + + jnp.swapaxes(matrix, -2, -1) + - jnp.diag(jnp.diag(matrix)) ) return transform.inv(matrix) @@ -1626,10 +1620,10 @@ def inv_vec_transform(y): (PowerTransform(2.0), ()), (SoftplusTransform(), ()), ( - LowerCholeskyAffine( - jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) - ), - (2,), + LowerCholeskyAffine( + jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) + ), + (2,), ), ], ) @@ -1683,7 +1677,7 @@ def test_composed_transform(batch_shape): log_det = t.log_abs_det_jacobian(x, y) assert log_det.shape == batch_shape expected_log_det = ( - jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 + jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 ) assert_allclose(log_det, expected_log_det) @@ -1702,9 +1696,9 @@ def test_composed_transform_1(batch_shape): assert log_det.shape == batch_shape z = t2(x * 2) expected_log_det = ( - jnp.log(2) * 6 - + t2.log_abs_det_jacobian(x * 2, z) - + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) + jnp.log(2) * 6 + + t2.log_abs_det_jacobian(x * 2, z) + + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) ) assert_allclose(log_det, expected_log_det) @@ -1715,8 +1709,8 @@ def test_composed_transform_1(batch_shape): def test_transformed_distribution(batch_shape, prepend_event_shape, sample_shape): base_dist = ( dist.Normal(0, 1) - .expand(batch_shape + prepend_event_shape + (6,)) - .to_event(1 + len(prepend_event_shape)) + .expand(batch_shape + prepend_event_shape + (6,)) + .to_event(1 + len(prepend_event_shape)) ) t1 = transforms.AffineTransform(0, 2) t2 = transforms.LowerCholeskyTransform() @@ -1828,7 +1822,7 @@ def test_unpack_transform(x_dim, y_dim): @pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS) def test_generated_sample_distribution( - jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) + jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) ): """On samplers that we do not get directly from JAX, (e.g. we only get Gumbel(0,1) but also provide samplers for Gumbel(loc, scale)), also test @@ -1886,8 +1880,8 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape): assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape # test expand of expand assert ( - expanded_dist.expand((3,) + new_batch_shape).batch_shape - == (3,) + new_batch_shape + expanded_dist.expand((3,) + new_batch_shape).batch_shape + == (3,) + new_batch_shape ) # test expand error if prepend_shape: From d6d2b912bae22bc9f32305f6041c67ce2876cef3 Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 8 Jul 2021 13:32:50 +0200 Subject: [PATCH 11/23] Added circular mean to `test_mean_var` --- numpyro/distributions/directional.py | 1 - test/test_distributions.py | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index ba8c7ce5e..c7a42024e 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -285,7 +285,6 @@ def sample(self, key, sample_shape=()): ) phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) - print(phi) alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) beta = lax.atan(corr / conc[1] * jnp.sin(phi)) diff --git a/test/test_distributions.py b/test/test_distributions.py index ab2957673..c349e72fe 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1213,6 +1213,11 @@ def test_mean_var(jax_dist, sp_dist, params): expected_variance = 1 - jnp.sqrt(x ** 2 + y ** 2) assert_allclose(d_jax.variance, expected_variance, rtol=0.05, atol=1e-2) + elif jax_dist in [dist.SineBivariateVonMises]: + circ_mean = lambda angles: jnp.arctan2(jnp.mean(jnp.sin(angles), axis=0), jnp.mean(jnp.cos(angles), axis=0)) + phi_loc = circ_mean(samples[..., 0]) + psi_loc = circ_mean(samples[..., 1]) + assert_allclose(d_jax.mean, jnp.stack((phi_loc, psi_loc), axis=-1), rtol=0.05, atol=1e-2) else: if jnp.all(jnp.isfinite(d_jax.mean)): assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2) From ed40a6b194e999ac39ac7d16066eab4c7406898a Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 8 Jul 2021 13:36:41 +0200 Subject: [PATCH 12/23] Fixed lint. --- numpyro/distributions/__init__.py | 6 +- numpyro/distributions/directional.py | 5 +- test/test_distributions.py | 248 ++++++++++++++------------- 3 files changed, 137 insertions(+), 122 deletions(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index ed017d2fc..428b40508 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -37,7 +37,11 @@ Uniform, Weibull, ) -from numpyro.distributions.directional import ProjectedNormal, VonMises, SineBivariateVonMises +from numpyro.distributions.directional import ( + ProjectedNormal, + SineBivariateVonMises, + VonMises, +) from numpyro.distributions.discrete import ( Bernoulli, BernoulliLogits, diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index c7a42024e..5ae7258a5 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -6,7 +6,6 @@ import math from math import pi import operator -import warnings from jax import lax import jax.numpy as jnp @@ -362,7 +361,9 @@ def cond_fn(curr): @property def mean(self): """Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]""" - return (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % (2.0 * jnp.pi) - jnp.pi + return (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % ( + 2.0 * jnp.pi + ) - jnp.pi def _bfind(self, eig): b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype) diff --git a/test/test_distributions.py b/test/test_distributions.py index c349e72fe..d7876010c 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4,8 +4,8 @@ from collections import namedtuple from functools import partial import inspect -import os import math +import os import numpy as np from numpy.testing import assert_allclose, assert_array_equal @@ -47,6 +47,12 @@ def _identity(x): return x +def _circ_mean(angles): + return jnp.arctan2( + jnp.mean(jnp.sin(angles), axis=0), jnp.mean(jnp.cos(angles), axis=0) + ) + + class T(namedtuple("TestCase", ["jax_dist", "sp_dist", "params"])): def __new__(cls, jax_dist, *params): sp_dist = get_sp_dist(jax_dist) @@ -515,22 +521,22 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): elif isinstance(constraint, constraints.multinomial): n = size[-1] return ( - multinomial( - key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] - ) - + 1 + multinomial( + key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1] + ) + + 1 ) elif constraint is constraints.corr_cholesky: return ( - signed_stick_breaking_tril( - random.uniform( - key, - size[:-2] + (size[-1] * (size[-1] - 1) // 2,), - minval=-1, - maxval=1, - ) + signed_stick_breaking_tril( + random.uniform( + key, + size[:-2] + (size[-1] * (size[-1] - 1) // 2,), + minval=-1, + maxval=1, ) - + 1e-2 + ) + + 1e-2 ) elif constraint is constraints.corr_matrix: cholesky = 1e-2 + signed_stick_breaking_tril( @@ -741,10 +747,10 @@ def g(params): ) def test_jit_log_likelihood(jax_dist, sp_dist, params): if jax_dist.__name__ in ( - "GaussianRandomWalk", - "_ImproperWrapper", - "LKJ", - "LKJCholesky", + "GaussianRandomWalk", + "_ImproperWrapper", + "LKJ", + "LKJCholesky", ): pytest.xfail(reason="non-jittable params") @@ -772,12 +778,12 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape if sp_dist is None: if isinstance( - jax_dist, - ( - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, - ), + jax_dist, + ( + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, + ), ): if isinstance(params[0], dist.Distribution): # new api @@ -1103,7 +1109,7 @@ def fn(*args): eps = 1e-3 for i in range(len(params)): if isinstance( - params[i], dist.Distribution + params[i], dist.Distribution ): # skip taking grad w.r.t. base_dist continue if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64): @@ -1132,10 +1138,10 @@ def test_mean_var(jax_dist, sp_dist, params): if jax_dist is FoldedNormal: pytest.skip("Folded distribution does not has mean/var implemented") if jax_dist in ( - _TruncatedNormal, - dist.LeftTruncatedDistribution, - dist.RightTruncatedDistribution, - dist.TwoSidedTruncatedDistribution, + _TruncatedNormal, + dist.LeftTruncatedDistribution, + dist.RightTruncatedDistribution, + dist.TwoSidedTruncatedDistribution, ): pytest.skip("Truncated distributions do not has mean/var implemented") if jax_dist is dist.ProjectedNormal: @@ -1148,9 +1154,9 @@ def test_mean_var(jax_dist, sp_dist, params): # check with suitable scipy implementation if available # XXX: VonMises is already tested below if ( - sp_dist - and not _is_batched_multivariate(d_jax) - and jax_dist not in [dist.VonMises] + sp_dist + and not _is_batched_multivariate(d_jax) + and jax_dist not in [dist.VonMises] ): d_sp = sp_dist(*params) try: @@ -1214,10 +1220,11 @@ def test_mean_var(jax_dist, sp_dist, params): expected_variance = 1 - jnp.sqrt(x ** 2 + y ** 2) assert_allclose(d_jax.variance, expected_variance, rtol=0.05, atol=1e-2) elif jax_dist in [dist.SineBivariateVonMises]: - circ_mean = lambda angles: jnp.arctan2(jnp.mean(jnp.sin(angles), axis=0), jnp.mean(jnp.cos(angles), axis=0)) - phi_loc = circ_mean(samples[..., 0]) - psi_loc = circ_mean(samples[..., 1]) - assert_allclose(d_jax.mean, jnp.stack((phi_loc, psi_loc), axis=-1), rtol=0.05, atol=1e-2) + phi_loc = _circ_mean(samples[..., 0]) + psi_loc = _circ_mean(samples[..., 1]) + assert_allclose( + d_jax.mean, jnp.stack((phi_loc, psi_loc), axis=-1), rtol=0.05, atol=1e-2 + ) else: if jnp.all(jnp.isfinite(d_jax.mean)): assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2) @@ -1241,18 +1248,21 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): dependent_constraint = False for i in range(len(params)): if ( - jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) - and dist_args[i] != "concentration" + jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky) + and dist_args[i] != "concentration" ): continue if ( - jax_dist is dist.TwoSidedTruncatedDistribution - and dist_args[i] == "base_dist" + jax_dist is dist.TwoSidedTruncatedDistribution + and dist_args[i] == "base_dist" ): continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue - if jax_dist is dist.SineBivariateVonMises and dist_args[i] == 'weighted_correlation': + if ( + jax_dist is dist.SineBivariateVonMises + and dist_args[i] == "weighted_correlation" + ): continue if params[i] is None: oob_params[i] = None @@ -1292,9 +1302,9 @@ def dist_gen_fn(): # Test agreement of log density evaluation on randomly generated samples # with scipy's implementation when available. if ( - sp_dist - and not _is_batched_multivariate(d) - and not (d.event_shape and prepend_shape) + sp_dist + and not _is_batched_multivariate(d) + and not (d.event_shape and prepend_shape) ): valid_samples = gen_values_within_bounds( d.support, size=prepend_shape + d.batch_shape + d.event_shape @@ -1370,113 +1380,113 @@ def g(x): (constraints.boolean, jnp.array([1, 1]), jnp.array([True, True])), (constraints.boolean, jnp.array([-1, 1]), jnp.array([False, True])), ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_cholesky, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_cholesky, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not positive_diagonal & not unit_norm_row ( - constraints.corr_matrix, - jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), - jnp.array([True, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]), + jnp.array([True, False]), ), # NB: not lower_triangular ( - constraints.corr_matrix, - jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), - jnp.array([False, False]), + constraints.corr_matrix, + jnp.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]), + jnp.array([False, False]), ), # NB: not unit diagonal (constraints.greater_than(1), 3, True), ( - constraints.greater_than(1), - jnp.array([-1, 1, 5]), - jnp.array([False, False, True]), + constraints.greater_than(1), + jnp.array([-1, 1, 5]), + jnp.array([False, False, True]), ), (constraints.integer_interval(-3, 5), 0, True), ( - constraints.integer_interval(-3, 5), - jnp.array([-5, -3, 0, 1.1, 5, 7]), - jnp.array([False, True, True, False, True, False]), + constraints.integer_interval(-3, 5), + jnp.array([-5, -3, 0, 1.1, 5, 7]), + jnp.array([False, True, True, False, True, False]), ), (constraints.interval(-3, 5), 0, True), ( - constraints.interval(-3, 5), - jnp.array([-5, -3, 0, 5, 7]), - jnp.array([False, True, True, True, False]), + constraints.interval(-3, 5), + jnp.array([-5, -3, 0, 5, 7]), + jnp.array([False, True, True, True, False]), ), (constraints.less_than(1), -2, True), ( - constraints.less_than(1), - jnp.array([-1, 1, 5]), - jnp.array([True, False, False]), + constraints.less_than(1), + jnp.array([-1, 1, 5]), + jnp.array([True, False, False]), ), (constraints.lower_cholesky, jnp.array([[1.0, 0.0], [-2.0, 0.1]]), True), ( - constraints.lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.nonnegative_integer, 3, True), ( - constraints.nonnegative_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, True, True]), + constraints.nonnegative_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, True, True]), ), (constraints.positive, 3, True), (constraints.positive, jnp.array([-1, 0, 5]), jnp.array([False, False, True])), (constraints.positive_definite, jnp.array([[1.0, 0.3], [0.3, 1.0]]), True), ( - constraints.positive_definite, - jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), - jnp.array([False, False]), + constraints.positive_definite, + jnp.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]), + jnp.array([False, False]), ), (constraints.positive_integer, 3, True), ( - constraints.positive_integer, - jnp.array([-1.0, 0.0, 5.0]), - jnp.array([False, False, True]), + constraints.positive_integer, + jnp.array([-1.0, 0.0, 5.0]), + jnp.array([False, False, True]), ), (constraints.real, -1, True), ( - constraints.real, - jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), - jnp.array([False, False, False, True]), + constraints.real, + jnp.array([jnp.inf, jnp.NINF, jnp.nan, jnp.pi]), + jnp.array([False, False, False, True]), ), (constraints.simplex, jnp.array([0.1, 0.3, 0.6]), True), ( - constraints.simplex, - jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), - jnp.array([True, False, False]), + constraints.simplex, + jnp.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]), + jnp.array([True, False, False]), ), (constraints.softplus_positive, 3, True), ( - constraints.softplus_positive, - jnp.array([-1, 0, 5]), - jnp.array([False, False, True]), + constraints.softplus_positive, + jnp.array([-1, 0, 5]), + jnp.array([False, False, True]), ), ( - constraints.softplus_lower_cholesky, - jnp.array([[1.0, 0.0], [-2.0, 0.1]]), - True, + constraints.softplus_lower_cholesky, + jnp.array([[1.0, 0.0], [-2.0, 0.1]]), + True, ), ( - constraints.softplus_lower_cholesky, - jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), - jnp.array([False, False]), + constraints.softplus_lower_cholesky, + jnp.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]), + jnp.array([False, False]), ), (constraints.unit_interval, 0.1, True), ( - constraints.unit_interval, - jnp.array([-5, 0, 0.5, 1, 7]), - jnp.array([False, True, True, True, False]), + constraints.unit_interval, + jnp.array([-5, 0, 0.5, 1, 7]), + jnp.array([False, True, True, True, False]), ), ( - constraints.sphere, - jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), - jnp.array([True, False]), + constraints.sphere, + jnp.array([[1, 0, 0], [0.5, 0.5, 0]]), + jnp.array([True, False]), ), ], ) @@ -1580,9 +1590,9 @@ def inv_vec_transform(y): if constraint is constraints.corr_matrix: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - + jnp.identity(matrix.shape[-1]) + matrix + + jnp.swapaxes(matrix, -2, -1) + + jnp.identity(matrix.shape[-1]) ) return transform.inv(matrix) @@ -1601,9 +1611,9 @@ def inv_vec_transform(y): if constraint is constraints.positive_definite: # fill the upper triangular part matrix = ( - matrix - + jnp.swapaxes(matrix, -2, -1) - - jnp.diag(jnp.diag(matrix)) + matrix + + jnp.swapaxes(matrix, -2, -1) + - jnp.diag(jnp.diag(matrix)) ) return transform.inv(matrix) @@ -1625,10 +1635,10 @@ def inv_vec_transform(y): (PowerTransform(2.0), ()), (SoftplusTransform(), ()), ( - LowerCholeskyAffine( - jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) - ), - (2,), + LowerCholeskyAffine( + jnp.array([1.0, 2.0]), jnp.array([[0.6, 0.0], [1.5, 0.4]]) + ), + (2,), ), ], ) @@ -1682,7 +1692,7 @@ def test_composed_transform(batch_shape): log_det = t.log_abs_det_jacobian(x, y) assert log_det.shape == batch_shape expected_log_det = ( - jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 + jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 ) assert_allclose(log_det, expected_log_det) @@ -1701,9 +1711,9 @@ def test_composed_transform_1(batch_shape): assert log_det.shape == batch_shape z = t2(x * 2) expected_log_det = ( - jnp.log(2) * 6 - + t2.log_abs_det_jacobian(x * 2, z) - + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) + jnp.log(2) * 6 + + t2.log_abs_det_jacobian(x * 2, z) + + t2.log_abs_det_jacobian(z, t2(z)).sum(-1) ) assert_allclose(log_det, expected_log_det) @@ -1714,8 +1724,8 @@ def test_composed_transform_1(batch_shape): def test_transformed_distribution(batch_shape, prepend_event_shape, sample_shape): base_dist = ( dist.Normal(0, 1) - .expand(batch_shape + prepend_event_shape + (6,)) - .to_event(1 + len(prepend_event_shape)) + .expand(batch_shape + prepend_event_shape + (6,)) + .to_event(1 + len(prepend_event_shape)) ) t1 = transforms.AffineTransform(0, 2) t2 = transforms.LowerCholeskyTransform() @@ -1827,7 +1837,7 @@ def test_unpack_transform(x_dim, y_dim): @pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS) def test_generated_sample_distribution( - jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) + jax_dist, sp_dist, params, N_sample=100_000, key=random.PRNGKey(11) ): """On samplers that we do not get directly from JAX, (e.g. we only get Gumbel(0,1) but also provide samplers for Gumbel(loc, scale)), also test @@ -1885,8 +1895,8 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape): assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape # test expand of expand assert ( - expanded_dist.expand((3,) + new_batch_shape).batch_shape - == (3,) + new_batch_shape + expanded_dist.expand((3,) + new_batch_shape).batch_shape + == (3,) + new_batch_shape ) # test expand error if prepend_shape: From 8810a2ebb862c4f1f695c9b8eeea3df75713923b Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 9 Jul 2021 11:25:28 +0200 Subject: [PATCH 13/23] Added test cases with inconsistent size --- numpyro/distributions/directional.py | 22 +++++++++++++--------- test/test_distributions.py | 15 ++++++++++++--- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 5ae7258a5..9de0006da 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -217,17 +217,17 @@ def __init__( phi_loc, psi_loc, phi_concentration, psi_concentration, correlation ) batch_shape = lax.broadcast_shapes( - phi_loc.shape, - psi_loc.shape, - phi_concentration.shape, - psi_concentration.shape, - correlation.shape, + jnp.shape(phi_loc), + jnp.shape(psi_loc), + jnp.shape(phi_concentration), + jnp.shape(psi_concentration), + jnp.shape(correlation), ) super().__init__(batch_shape, (2,), validate_args) @lazy_property def norm_const(self): - corr = self.correlation.reshape(1, -1) + 1e-8 + corr = jnp.reshape(self.correlation, (1, -1)) + 1e-8 conc = jnp.stack( (self.phi_concentration, self.psi_concentration), axis=-1 ).reshape(-1, 2) @@ -244,7 +244,7 @@ def norm_const(self): fs += log_I1(49, conc, terms=51).sum(-1) mfs = fs.max() norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(fs - mfs, 0) - return norm_const.reshape(self.phi_loc.shape) + return jnp.reshape(norm_const, jnp.shape(self.phi_loc)) def log_prob(self, value): if self._validate_args: @@ -265,6 +265,8 @@ def sample(self, key, sample_shape=()): 1. A New Unified Approach for the Simulation of aWide Class of Directional Distributions John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018) """ + assert is_prng_key(key) + phi_key, psi_key = random.split(key) corr = self.correlation @@ -283,10 +285,10 @@ def sample(self, key, sample_shape=()): phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den ) - phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) + phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) - beta = lax.atan(corr / conc[1] * jnp.sin(phi)) + beta = jnp.arctan(corr / conc[1] * jnp.sin(phi)) psi = VonMises(beta, alpha).sample(psi_key) @@ -422,6 +424,8 @@ def mode(self): return safe_normalize(self.concentration) def sample(self, key, sample_shape=()): + assert is_prng_key(key) + shape = sample_shape + self.batch_shape + self.event_shape eps = random.normal(key, shape=shape) return safe_normalize(self.concentration + eps) diff --git a/test/test_distributions.py b/test/test_distributions.py index d7876010c..7770e4dc8 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -343,8 +343,8 @@ def get_sp_dist(jax_dist): T(dist.VonMises, jnp.array([1 / 3 * jnp.pi, -1.0]), jnp.array([20.0, 30.0])), T( SineBivariateVonMises, - jnp.array([0.0]), - jnp.array([0.0]), + jnp.array([0.]), + 0., jnp.array([5.0]), jnp.array([6.0]), jnp.array([2.0]), @@ -352,7 +352,7 @@ def get_sp_dist(jax_dist): T( SineBivariateVonMises, jnp.array([3.003]), - jnp.array([-1.3430]), + jnp.array(-1.3430), jnp.array([5.0]), jnp.array([6.0]), jnp.array([2.0]), @@ -366,6 +366,15 @@ def get_sp_dist(jax_dist): None, jnp.array([0.2, 0.5]), ), + T( + SineBivariateVonMises, + jnp.array([math.pi - 0.2, 1.0]), + jnp.array([0.0, -math.pi + .1]), + jnp.array([2.123, 20.0]), + jnp.array(0.5), + None, + jnp.array([0., 0.]), + ), T(dist.ProjectedNormal, jnp.array([0.0, 0.0])), T(dist.ProjectedNormal, jnp.array([[2.0, 3.0]])), T(dist.ProjectedNormal, jnp.array([0.0, 0.0, 0.0])), From 7ae4f500ddb96c1b264ab6d4c15ae5fac4dd1879 Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 9 Jul 2021 11:38:03 +0200 Subject: [PATCH 14/23] Moved `SineVBivariateVonMises` in docs. --- docs/source/distributions.rst | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 7b08cda41..8cc6996cd 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -477,17 +477,18 @@ ZeroInflatedNegativeBinomial2 Directional Distributions ========================= -SineBivariateVonMises ---------------------- -.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises + +ProjectedNormal +--------------- +.. autoclass:: numpyro.distributions.directional.ProjectedNormal :members: :undoc-members: :show-inheritance: :member-order: bysource -ProjectedNormal ---------------- -.. autoclass:: numpyro.distributions.directional.ProjectedNormal +SineBivariateVonMises +--------------------- +.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises :members: :undoc-members: :show-inheritance: From 2cdcdd3293e53a2b2c4640012c857c2c4e24cddf Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 20 Jul 2021 10:26:12 +0200 Subject: [PATCH 15/23] Fixed docstring paths, removed lax and fixed `SineBivariateVonMises` placement in docs. --- docs/source/distributions.rst | 13 ++++--- numpyro/distributions/directional.py | 58 ++++++++++++++-------------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 7b08cda41..8cc6996cd 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -477,17 +477,18 @@ ZeroInflatedNegativeBinomial2 Directional Distributions ========================= -SineBivariateVonMises ---------------------- -.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises + +ProjectedNormal +--------------- +.. autoclass:: numpyro.distributions.directional.ProjectedNormal :members: :undoc-members: :show-inheritance: :member-order: bysource -ProjectedNormal ---------------- -.. autoclass:: numpyro.distributions.directional.ProjectedNormal +SineBivariateVonMises +--------------------- +.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises :members: :undoc-members: :show-inheritance: diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 5ae7258a5..b8c4f04e0 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -10,6 +10,7 @@ from jax import lax import jax.numpy as jnp import jax.random as random +from jax.scipy import special from jax.scipy.special import erf, i0e, i1e, logsumexp from numpyro.distributions import constraints @@ -48,10 +49,10 @@ def log_I1(orders: int, value, terms=250): flat_vshape = _numel(vshape) k = jnp.arange(terms) - lgammas_all = lax.lgamma(jnp.arange(1.0, terms + orders + 1)) + lgammas_all = special.gammaln(jnp.arange(1.0, terms + orders + 1)) assert lgammas_all.shape == (orders + terms,) # lgamma(0) = inf => start from 1 - lvalues = lax.log(value / 2) * k.reshape(1, -1) + lvalues = jnp.log(value / 2) * k.reshape(1, -1) assert lvalues.shape == (flat_vshape, terms) lfactorials = lgammas_all[:terms] @@ -151,7 +152,7 @@ class SineBivariateVonMises(Distribution): kappa's are the concentration and rho gives the correlation between angles :math:`x_1` and :math:`x_2`. This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains. - To infer parameters, use :class:`~numpyro.infer.NUTS` or :class:`~numpyro.infer.HMC` with priors that + To infer parameters, use :class:`~numpyro.infer.hmc.NUTS` or :class:`~numpyro.infer.hmc.HMC` with priors that avoid parameterizations where the distribution becomes bimodal; see note below. .. note:: Sample efficiency drops as @@ -163,18 +164,18 @@ class SineBivariateVonMises(Distribution): .. note:: The correlation and weighted_correlation params are mutually exclusive. - .. note:: In the context of :class:`~numpyro.infer.SVI`, this distribution can be used as a likelihood but not for + .. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be used as a likelihood but not for latent variables. ** References: ** 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) - :param jnp.Tensor phi_loc: location of first angle - :param jnp.Tensor psi_loc: location of second angle - :param jnp.Tensor phi_concentration: concentration of first angle - :param jnp.Tensor psi_concentration: concentration of second angle - :param jnp.Tensor correlation: correlation between the two angles - :param jnp.Tensor weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) + :param jnp.array phi_loc: location of first angle + :param jnp.array psi_loc: location of second angle + :param jnp.array phi_concentration: concentration of first angle + :param jnp.array psi_concentration: concentration of second angle + :param jnp.array correlation: correlation between the two angles + :param jnp.array weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) to avoid bimodality (see note). """ @@ -185,7 +186,7 @@ class SineBivariateVonMises(Distribution): "psi_concentration": constraints.positive, "correlation": constraints.real, } - support = constraints.independent(constraints.real, 1) + support = constraints.independent(constraints.real, 1) # TODO: update to circular constraint @1080 max_sample_iter = 1000 def __init__( @@ -217,23 +218,23 @@ def __init__( phi_loc, psi_loc, phi_concentration, psi_concentration, correlation ) batch_shape = lax.broadcast_shapes( - phi_loc.shape, - psi_loc.shape, - phi_concentration.shape, - psi_concentration.shape, - correlation.shape, + jnp.shape(phi_loc), + jnp.shape(psi_loc), + jnp.shape(phi_concentration), + jnp.shape(psi_concentration), + jnp.shape(correlation), ) super().__init__(batch_shape, (2,), validate_args) @lazy_property def norm_const(self): - corr = self.correlation.reshape(1, -1) + 1e-8 + corr = jnp.reshape(self.correlation, (1, -1)) + 1e-8 conc = jnp.stack( (self.phi_concentration, self.psi_concentration), axis=-1 ).reshape(-1, 2) m = jnp.arange(50).reshape(-1, 1) - num = lax.lgamma(2 * m + 1.0) - den = lax.lgamma(m + 1.0) + num = special.gammaln(2 * m + 1.0) + den = special.gammaln(m + 1.0) lbinoms = num - 2 * den fs = ( @@ -244,11 +245,10 @@ def norm_const(self): fs += log_I1(49, conc, terms=51).sum(-1) mfs = fs.max() norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(fs - mfs, 0) - return norm_const.reshape(self.phi_loc.shape) + return norm_const.reshape(jnp.shape(self.phi_loc)) + @validate_sample def log_prob(self, value): - if self._validate_args: - self._validate_sample(value) indv = self.phi_concentration * jnp.cos( value[..., 0] - self.phi_loc ) + self.psi_concentration * jnp.cos(value[..., 1] - self.psi_loc) @@ -265,6 +265,7 @@ def sample(self, key, sample_shape=()): 1. A New Unified Approach for the Simulation of aWide Class of Directional Distributions John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018) """ + assert is_prng_key(key) phi_key, psi_key = random.split(key) corr = self.correlation @@ -283,10 +284,10 @@ def sample(self, key, sample_shape=()): phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den ) - phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) + phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) - beta = lax.atan(corr / conc[1] * jnp.sin(phi)) + beta = jnp.arctan(corr / conc[1] * jnp.sin(phi)) psi = VonMises(beta, alpha).sample(psi_key) @@ -314,9 +315,7 @@ def update_fn(curr): accept_key, acg_key, phi_key = random.split(phi_key, 3) x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) - x /= jnp.linalg.norm(x, axis=1)[ - :, None, : - ] # Angular Central Gaussian distribution + x /= jnp.linalg.norm(x, axis=1, keepdims=True) # Angular Central Gaussian distribution lf = ( conc[:, :1] * (x[:, :1] - 1) @@ -360,10 +359,11 @@ def cond_fn(curr): @property def mean(self): - """Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]""" - return (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % ( + """Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]""" + mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % ( 2.0 * jnp.pi ) - jnp.pi + return jnp.broadcast_to(mean, self.batch_shape) def _bfind(self, eig): b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype) From 667bc5b6aba33a1208e6ada96bb194a86814efda Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 20 Jul 2021 13:01:21 +0200 Subject: [PATCH 16/23] Fixed broadcasting for mean. --- numpyro/distributions/directional.py | 12 +++++++++--- test/test_distributions.py | 3 ++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index b8c4f04e0..3723fe01b 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -186,7 +186,9 @@ class SineBivariateVonMises(Distribution): "psi_concentration": constraints.positive, "correlation": constraints.real, } - support = constraints.independent(constraints.real, 1) # TODO: update to circular constraint @1080 + support = constraints.independent( + constraints.real, 1 + ) # TODO: @OlaRonning update to circular constraint @1080 max_sample_iter = 1000 def __init__( @@ -315,7 +317,9 @@ def update_fn(curr): accept_key, acg_key, phi_key = random.split(phi_key, 3) x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) - x /= jnp.linalg.norm(x, axis=1, keepdims=True) # Angular Central Gaussian distribution + x /= jnp.linalg.norm( + x, axis=1, keepdims=True + ) # Angular Central Gaussian distribution lf = ( conc[:, :1] * (x[:, :1] - 1) @@ -363,7 +367,9 @@ def mean(self): mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % ( 2.0 * jnp.pi ) - jnp.pi - return jnp.broadcast_to(mean, self.batch_shape) + print(mean.shape) + print(self.batch_shape) + return jnp.broadcast_to(mean, (*self.batch_shape, 2)) def _bfind(self, eig): b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype) diff --git a/test/test_distributions.py b/test/test_distributions.py index d7876010c..4664096cd 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -353,7 +353,7 @@ def get_sp_dist(jax_dist): SineBivariateVonMises, jnp.array([3.003]), jnp.array([-1.3430]), - jnp.array([5.0]), + jnp.array(5.0), jnp.array([6.0]), jnp.array([2.0]), ), @@ -1222,6 +1222,7 @@ def test_mean_var(jax_dist, sp_dist, params): elif jax_dist in [dist.SineBivariateVonMises]: phi_loc = _circ_mean(samples[..., 0]) psi_loc = _circ_mean(samples[..., 1]) + assert_allclose( d_jax.mean, jnp.stack((phi_loc, psi_loc), axis=-1), rtol=0.05, atol=1e-2 ) From b7679144e919740fd8a66290fd19b5ec2a6a7a54 Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 20 Jul 2021 13:05:24 +0200 Subject: [PATCH 17/23] Fixed lint. --- numpyro/distributions/directional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 3723fe01b..da041f333 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -164,8 +164,8 @@ class SineBivariateVonMises(Distribution): .. note:: The correlation and weighted_correlation params are mutually exclusive. - .. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be used as a likelihood but not for - latent variables. + .. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be used as a likelihood but not + for latent variables. ** References: ** 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) From e09d28fcf060b91e57a1347e6d47dfdae37b788a Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 20 Jul 2021 14:52:06 +0200 Subject: [PATCH 18/23] updated docstring param type to `jnp.ndarray` for `SineBivariateVonMises`. --- numpyro/distributions/directional.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index da041f333..ae6824515 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -170,12 +170,12 @@ class SineBivariateVonMises(Distribution): ** References: ** 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) - :param jnp.array phi_loc: location of first angle - :param jnp.array psi_loc: location of second angle - :param jnp.array phi_concentration: concentration of first angle - :param jnp.array psi_concentration: concentration of second angle - :param jnp.array correlation: correlation between the two angles - :param jnp.array weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) + :param jnp.ndarray phi_loc: location of first angle + :param jnp.ndarray psi_loc: location of second angle + :param jnp.ndarray phi_concentration: concentration of first angle + :param jnp.ndarray psi_concentration: concentration of second angle + :param jnp.ndarray correlation: correlation between the two angles + :param jnp.ndarray weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) to avoid bimodality (see note). """ From 8ab2f449faf301b5d59dc387df19992e2737931a Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 20 Jul 2021 14:54:43 +0200 Subject: [PATCH 19/23] `jnp.ndarray` -> `np.ndarray` --- numpyro/distributions/directional.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index ae6824515..7c623b59a 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -6,6 +6,7 @@ import math from math import pi import operator +import numpy as np from jax import lax import jax.numpy as jnp @@ -170,12 +171,12 @@ class SineBivariateVonMises(Distribution): ** References: ** 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) - :param jnp.ndarray phi_loc: location of first angle - :param jnp.ndarray psi_loc: location of second angle - :param jnp.ndarray phi_concentration: concentration of first angle - :param jnp.ndarray psi_concentration: concentration of second angle - :param jnp.ndarray correlation: correlation between the two angles - :param jnp.ndarray weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) + :param np.ndarray phi_loc: location of first angle + :param np.ndarray psi_loc: location of second angle + :param np.ndarray phi_concentration: concentration of first angle + :param np.ndarray psi_concentration: concentration of second angle + :param np.ndarray correlation: correlation between the two angles + :param np.ndarray weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) to avoid bimodality (see note). """ From 44a8854f5c34561da9c755418cd676475734ac4e Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 20 Jul 2021 14:55:12 +0200 Subject: [PATCH 20/23] `jnp.ndarray` -> `np.ndarray` --- numpyro/distributions/directional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 7c623b59a..7d69f55fb 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -6,6 +6,7 @@ import math from math import pi import operator + import numpy as np from jax import lax From 88f646aaafd641c5dac54b88c29467df4b251b65 Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 20 Jul 2021 15:04:23 +0200 Subject: [PATCH 21/23] removed numpy import from directional.py --- numpyro/distributions/directional.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 7d69f55fb..50469d848 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -7,8 +7,6 @@ from math import pi import operator -import numpy as np - from jax import lax import jax.numpy as jnp import jax.random as random From 488dcd88f808b61f58da9863abce7e467e6349b7 Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 9 Aug 2021 13:13:12 +0200 Subject: [PATCH 22/23] Fixed scalar params --- numpyro/distributions/directional.py | 12 ++++++++++-- test/test_distributions.py | 8 ++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 50469d848..a16939933 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -281,9 +281,17 @@ def sample(self, key, sample_shape=()): total = _numel(sample_shape) phi_den = log_I1(0, conc[1]).squeeze(0) - phi_shape = (total, 2, _numel(self.batch_shape)) + batch_size = _numel(self.batch_shape) + phi_shape = (total, 2, batch_size) phi_state = SineBivariateVonMises._phi_marginal( - phi_shape, phi_key, conc, corr, eig, b0, eigmin, phi_den + phi_shape, + phi_key, + jnp.reshape(conc, (2, batch_size)), + jnp.reshape(corr, (batch_size,)), + jnp.reshape(eig, (2, batch_size)), + jnp.reshape(b0, (batch_size,)), + jnp.reshape(eigmin, (batch_size,)), + jnp.reshape(phi_den, (batch_size,)), ) phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) diff --git a/test/test_distributions.py b/test/test_distributions.py index 4664096cd..1a3b188bf 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -357,6 +357,14 @@ def get_sp_dist(jax_dist): jnp.array([6.0]), jnp.array([2.0]), ), + T( + SineBivariateVonMises, + jnp.array(-1.232), + jnp.array(-1.3430), + jnp.array(3.4), + jnp.array(2.0), + jnp.array(1.0), + ), T( SineBivariateVonMises, jnp.array([math.pi - 0.2, 1.0]), From fb8e5dbe93359936770f710390e01d36d3447cff Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 19 Aug 2021 09:35:14 +0200 Subject: [PATCH 23/23] Changed `SineBivariateVonMises` support to circular --- numpyro/distributions/directional.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 8a34a2909..f27126587 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -197,15 +197,13 @@ class SineBivariateVonMises(Distribution): """ arg_constraints = { - "phi_loc": constraints.real, - "psi_loc": constraints.real, + "phi_loc": constraints.circular, + "psi_loc": constraints.circular, "phi_concentration": constraints.positive, "psi_concentration": constraints.positive, "correlation": constraints.real, } - support = constraints.independent( - constraints.real, 1 - ) # TODO: @OlaRonning update to circular constraint @1080 + support = constraints.independent(constraints.circular, 1) max_sample_iter = 1000 def __init__( @@ -218,7 +216,6 @@ def __init__( weighted_correlation=None, validate_args=None, ): - assert (correlation is None) != (weighted_correlation is None) if weighted_correlation is not None: