diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 861ef22fe..f7c71daed 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -489,6 +489,14 @@ ProjectedNormal :show-inheritance: :member-order: bysource +SineBivariateVonMises +--------------------- +.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + VonMises ^^^^^^^^ .. autoclass:: numpyro.distributions.directional.VonMises diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 65e9d7bb9..428b40508 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -37,7 +37,11 @@ Uniform, Weibull, ) -from numpyro.distributions.directional import ProjectedNormal, VonMises +from numpyro.distributions.directional import ( + ProjectedNormal, + SineBivariateVonMises, + VonMises, +) from numpyro.distributions.discrete import ( Bernoulli, BernoulliLogits, @@ -145,6 +149,7 @@ "ProjectedNormal", "PRNGIdentity", "RightTruncatedDistribution", + "SineBivariateVonMises", "SoftLaplace", "StudentT", "TransformedDistribution", diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 48cef64dc..f27126587 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -1,22 +1,80 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from collections import namedtuple +import functools import math +from math import pi +import operator from jax import lax import jax.numpy as jnp import jax.random as random -from jax.scipy.special import erf, i0e, i1e +from jax.scipy import special +from jax.scipy.special import erf, i0e, i1e, logsumexp from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import ( is_prng_key, + lazy_property, promote_shapes, safe_normalize, validate_sample, von_mises_centered, ) +from numpyro.util import while_loop + + +def _numel(shape): + return functools.reduce(operator.mul, shape, 1) + + +def log_I1(orders: int, value, terms=250): + r"""Compute first n log modified bessel function of first kind + .. math :: + \log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk) + - \lgamma(v + k + 1)\right]) + :param orders: orders of the log modified bessel function. + :param value: values to compute modified bessel function for + :param terms: truncation of summation + :return: 0 to orders modified bessel function + """ + orders = orders + 1 + if value.ndim == 0: + vshape = jnp.shape([1]) + else: + vshape = value.shape + value = value.reshape(-1, 1) + flat_vshape = _numel(vshape) + + k = jnp.arange(terms) + lgammas_all = special.gammaln(jnp.arange(1.0, terms + orders + 1)) + assert lgammas_all.shape == (orders + terms,) # lgamma(0) = inf => start from 1 + + lvalues = jnp.log(value / 2) * k.reshape(1, -1) + assert lvalues.shape == (flat_vshape, terms) + + lfactorials = lgammas_all[:terms] + assert lfactorials.shape == (terms,) + + lgammas = lgammas_all.tile(orders).reshape((orders, -1)) + assert lgammas.shape == (orders, terms + orders) # lgamma(0) = inf => start from 1 + + indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1) + assert indices.shape == (orders, terms) + + seqs = logsumexp( + 2 * lvalues[None, :, :] + - lfactorials[None, None, :] + - jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :], + -1, + ) + assert seqs.shape == (orders, flat_vshape) + + i1s = lvalues[..., :orders].T + seqs + assert i1s.shape == (orders, flat_vshape) + return i1s.reshape(-1, *vshape) class VonMises(Distribution): @@ -92,6 +150,256 @@ def variance(self): ) +PhiMarginalState = namedtuple("PhiMarginalState", ["i", "done", "phi", "key"]) + + +class SineBivariateVonMises(Distribution): + r"""Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by + + .. math:: + C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2)) + + and + + .. math:: + C = (2\pi)^2 \sum_{i=0} {2i \choose i} + \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2), + + where :math:`I_i(\cdot)` is the modified bessel function of first kind, mu's are the locations of the distribution, + kappa's are the concentration and rho gives the correlation between angles :math:`x_1` and :math:`x_2`. + This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains. + + To infer parameters, use :class:`~numpyro.infer.hmc.NUTS` or :class:`~numpyro.infer.hmc.HMC` with priors that + avoid parameterizations where the distribution becomes bimodal; see note below. + + .. note:: Sample efficiency drops as + + .. math:: + \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1 + + because the distribution becomes increasingly bimodal. + + .. note:: The correlation and weighted_correlation params are mutually exclusive. + + .. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be used as a likelihood but not + for latent variables. + + ** References: ** + 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) + + :param np.ndarray phi_loc: location of first angle + :param np.ndarray psi_loc: location of second angle + :param np.ndarray phi_concentration: concentration of first angle + :param np.ndarray psi_concentration: concentration of second angle + :param np.ndarray correlation: correlation between the two angles + :param np.ndarray weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc) + to avoid bimodality (see note). + """ + + arg_constraints = { + "phi_loc": constraints.circular, + "psi_loc": constraints.circular, + "phi_concentration": constraints.positive, + "psi_concentration": constraints.positive, + "correlation": constraints.real, + } + support = constraints.independent(constraints.circular, 1) + max_sample_iter = 1000 + + def __init__( + self, + phi_loc, + psi_loc, + phi_concentration, + psi_concentration, + correlation=None, + weighted_correlation=None, + validate_args=None, + ): + assert (correlation is None) != (weighted_correlation is None) + + if weighted_correlation is not None: + correlation = ( + weighted_correlation * jnp.sqrt(phi_concentration * psi_concentration) + + 1e-8 + ) + + ( + self.phi_loc, + self.psi_loc, + self.phi_concentration, + self.psi_concentration, + self.correlation, + ) = promote_shapes( + phi_loc, psi_loc, phi_concentration, psi_concentration, correlation + ) + batch_shape = lax.broadcast_shapes( + jnp.shape(phi_loc), + jnp.shape(psi_loc), + jnp.shape(phi_concentration), + jnp.shape(psi_concentration), + jnp.shape(correlation), + ) + super().__init__(batch_shape, (2,), validate_args) + + @lazy_property + def norm_const(self): + corr = jnp.reshape(self.correlation, (1, -1)) + 1e-8 + conc = jnp.stack( + (self.phi_concentration, self.psi_concentration), axis=-1 + ).reshape(-1, 2) + m = jnp.arange(50).reshape(-1, 1) + num = special.gammaln(2 * m + 1.0) + den = special.gammaln(m + 1.0) + lbinoms = num - 2 * den + + fs = ( + lbinoms.reshape(-1, 1) + + 2 * m * jnp.log(corr) + - m * jnp.log(4 * jnp.prod(conc, axis=-1)) + ) + fs += log_I1(49, conc, terms=51).sum(-1) + mfs = fs.max() + norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(fs - mfs, 0) + return norm_const.reshape(jnp.shape(self.phi_loc)) + + @validate_sample + def log_prob(self, value): + indv = self.phi_concentration * jnp.cos( + value[..., 0] - self.phi_loc + ) + self.psi_concentration * jnp.cos(value[..., 1] - self.psi_loc) + corr = ( + self.correlation + * jnp.sin(value[..., 0] - self.phi_loc) + * jnp.sin(value[..., 1] - self.psi_loc) + ) + return indv + corr - self.norm_const + + def sample(self, key, sample_shape=()): + """ + ** References: ** + 1. A New Unified Approach for the Simulation of aWide Class of Directional Distributions + John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018) + """ + assert is_prng_key(key) + phi_key, psi_key = random.split(key) + + corr = self.correlation + conc = jnp.stack((self.phi_concentration, self.psi_concentration)) + + eig = 0.5 * (conc[0] - corr ** 2 / conc[1]) + eig = jnp.stack((jnp.zeros_like(eig), eig)) + eigmin = jnp.where(eig[1] < 0, eig[1], jnp.zeros_like(eig[1], dtype=eig.dtype)) + eig = eig - eigmin + b0 = self._bfind(eig) + + total = _numel(sample_shape) + phi_den = log_I1(0, conc[1]).squeeze(0) + batch_size = _numel(self.batch_shape) + phi_shape = (total, 2, batch_size) + phi_state = SineBivariateVonMises._phi_marginal( + phi_shape, + phi_key, + jnp.reshape(conc, (2, batch_size)), + jnp.reshape(corr, (batch_size,)), + jnp.reshape(eig, (2, batch_size)), + jnp.reshape(b0, (batch_size,)), + jnp.reshape(eigmin, (batch_size,)), + jnp.reshape(phi_den, (batch_size,)), + ) + + phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) + + alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2) + beta = jnp.arctan(corr / conc[1] * jnp.sin(phi)) + + psi = VonMises(beta, alpha).sample(psi_key) + + phi_psi = jnp.concatenate( + ( + (phi + self.phi_loc + pi) % (2 * pi) - pi, + (psi + self.psi_loc + pi) % (2 * pi) - pi, + ), + axis=1, + ) + phi_psi = jnp.transpose(phi_psi, (0, 2, 1)) + return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape) + + @staticmethod + def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den): + conc = jnp.broadcast_to(conc, shape) + eig = jnp.broadcast_to(eig, shape) + b0 = jnp.broadcast_to(b0, shape) + eigmin = jnp.broadcast_to(eigmin, shape) + phi_den = jnp.broadcast_to(phi_den, shape) + + def update_fn(curr): + i, done, phi, key = curr + phi_key, key = random.split(key) + accept_key, acg_key, phi_key = random.split(phi_key, 3) + + x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) + x /= jnp.linalg.norm( + x, axis=1, keepdims=True + ) # Angular Central Gaussian distribution + + lf = ( + conc[:, :1] * (x[:, :1] - 1) + + eigmin + + log_I1( + 0, jnp.sqrt(conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2) + ).squeeze(0) + - phi_den + ) + assert lf.shape == shape + + lg_inv = ( + 1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True)) + ) + assert lg_inv.shape == lf.shape + + accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv) + + phi = jnp.where(accepted, x, phi) + return PhiMarginalState(i + 1, done | accepted, phi, key) + + def cond_fn(curr): + return jnp.bitwise_and( + curr.i < SineBivariateVonMises.max_sample_iter, + jnp.logical_not(jnp.all(curr.done)), + ) + + phi_state = while_loop( + cond_fn, + update_fn, + PhiMarginalState( + i=jnp.array(0), + done=jnp.zeros(shape, dtype=bool), + phi=jnp.empty(shape, dtype=float), + key=rng_key, + ), + ) + return PhiMarginalState( + phi_state.i, phi_state.done, phi_state.phi, phi_state.key + ) + + @property + def mean(self): + """Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]""" + mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % ( + 2.0 * jnp.pi + ) - jnp.pi + print(mean.shape) + print(self.batch_shape) + return jnp.broadcast_to(mean, (*self.batch_shape, 2)) + + def _bfind(self, eig): + b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype) + g1 = jnp.sum(1 / (b + 2 * eig) ** 2, axis=0) + g2 = jnp.sum(-2 / (b + 2 * eig) ** 3, axis=0) + return jnp.where(jnp.linalg.norm(eig, axis=0) != 0, b - g1 / g2, b) + + class ProjectedNormal(Distribution): """ Projected isotropic normal distribution of arbitrary dimension. diff --git a/test/test_distributions.py b/test/test_distributions.py index 5bcdbd242..5a827ba84 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4,6 +4,7 @@ from collections import namedtuple from functools import partial import inspect +import math import os import numpy as np @@ -19,6 +20,7 @@ import numpyro.distributions as dist from numpyro.distributions import constraints, kl_divergence, transforms +from numpyro.distributions.directional import SineBivariateVonMises from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom from numpyro.distributions.flows import InverseAutoregressiveTransform from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit @@ -45,6 +47,12 @@ def _identity(x): return x +def _circ_mean(angles): + return jnp.arctan2( + jnp.mean(jnp.sin(angles), axis=0), jnp.mean(jnp.cos(angles), axis=0) + ) + + class T(namedtuple("TestCase", ["jax_dist", "sp_dist", "params"])): def __new__(cls, jax_dist, *params): sp_dist = get_sp_dist(jax_dist) @@ -333,6 +341,39 @@ def get_sp_dist(jax_dist): T(dist.VonMises, 2.0, 10.0), T(dist.VonMises, 2.0, jnp.array([150.0, 10.0])), T(dist.VonMises, jnp.array([1 / 3 * jnp.pi, -1.0]), jnp.array([20.0, 30.0])), + T( + SineBivariateVonMises, + jnp.array([0.0]), + jnp.array([0.0]), + jnp.array([5.0]), + jnp.array([6.0]), + jnp.array([2.0]), + ), + T( + SineBivariateVonMises, + jnp.array([3.003]), + jnp.array([-1.3430]), + jnp.array(5.0), + jnp.array([6.0]), + jnp.array([2.0]), + ), + T( + SineBivariateVonMises, + jnp.array(-1.232), + jnp.array(-1.3430), + jnp.array(3.4), + jnp.array(2.0), + jnp.array(1.0), + ), + T( + SineBivariateVonMises, + jnp.array([math.pi - 0.2, 1.0]), + jnp.array([0.0, 1.0]), + jnp.array([2.123, 20.0]), + jnp.array([7.0, 0.5]), + None, + jnp.array([0.2, 0.5]), + ), T(dist.ProjectedNormal, jnp.array([0.0, 0.0])), T(dist.ProjectedNormal, jnp.array([[2.0, 3.0]])), T(dist.ProjectedNormal, jnp.array([0.0, 0.0, 0.0])), @@ -1186,6 +1227,13 @@ def test_mean_var(jax_dist, sp_dist, params): expected_variance = 1 - jnp.sqrt(x ** 2 + y ** 2) assert_allclose(d_jax.variance, expected_variance, rtol=0.05, atol=1e-2) + elif jax_dist in [dist.SineBivariateVonMises]: + phi_loc = _circ_mean(samples[..., 0]) + psi_loc = _circ_mean(samples[..., 1]) + + assert_allclose( + d_jax.mean, jnp.stack((phi_loc, psi_loc), axis=-1), rtol=0.05, atol=1e-2 + ) else: if jnp.all(jnp.isfinite(d_jax.mean)): assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2) @@ -1220,6 +1268,11 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue + if ( + jax_dist is dist.SineBivariateVonMises + and dist_args[i] == "weighted_correlation" + ): + continue if params[i] is None: oob_params[i] = None valid_params[i] = None @@ -1489,7 +1542,6 @@ def test_constraints(constraint, x, expected): ) @pytest.mark.parametrize("shape", [(), (1,), (3,), (6,), (3, 1), (1, 3), (5, 3)]) def test_biject_to(constraint, shape): - transform = biject_to(constraint) event_dim = transform.domain.event_dim if isinstance(constraint, constraints._Interval):