Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting Sine Bivaraite von Mises from Pyro #1063

Merged
merged 27 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f447521
Added BvM.
OlaRonning Jun 3, 2021
529c4c5
Added BvM tests.
OlaRonning Jun 3, 2021
ec1bd1e
Added comment with tests that need to be fixed.
OlaRonning Jun 3, 2021
9858251
Added tests.
OlaRonning Jun 10, 2021
bde8204
Ran black
OlaRonning Jun 10, 2021
ebd8765
fixed license
OlaRonning Jun 10, 2021
2aef87d
Added BvM to docs
OlaRonning Jun 10, 2021
9fb1bbb
Merge remote-tracking branch 'origin/feature/bvm_dist' into feature/b…
OlaRonning Jul 8, 2021
d37e965
Merge branch 'master' of github.com:pyro-ppl/numpyro into feature/bvm…
OlaRonning Jul 8, 2021
eccbdde
Fixed docstring
OlaRonning Jul 8, 2021
9583074
Added math envs to docstring for `SineBivariateVonMises`.
OlaRonning Jul 8, 2021
269ae67
Fixed `test_distribution_constraints` failures for `SineBivariateVonM…
OlaRonning Jul 8, 2021
d6d2b91
Added circular mean to `test_mean_var`
OlaRonning Jul 8, 2021
ed40a6b
Fixed lint.
OlaRonning Jul 8, 2021
8810a2e
Added test cases with inconsistent size
OlaRonning Jul 9, 2021
7ae4f50
Moved `SineVBivariateVonMises` in docs.
OlaRonning Jul 9, 2021
2cdcdd3
Fixed docstring paths, removed lax and fixed `SineBivariateVonMises` …
OlaRonning Jul 20, 2021
667bc5b
Fixed broadcasting for mean.
OlaRonning Jul 20, 2021
b767914
Fixed lint.
OlaRonning Jul 20, 2021
e09d28f
updated docstring param type to `jnp.ndarray` for `SineBivariateVonMi…
OlaRonning Jul 20, 2021
8ab2f44
`jnp.ndarray` -> `np.ndarray`
OlaRonning Jul 20, 2021
44a8854
`jnp.ndarray` -> `np.ndarray`
OlaRonning Jul 20, 2021
88f646a
removed numpy import from directional.py
OlaRonning Jul 20, 2021
689477d
Merge remote-tracking branch 'origin/feature/bvm_dist' into feature/b…
OlaRonning Aug 9, 2021
488dcd8
Fixed scalar params
OlaRonning Aug 9, 2021
48be020
Merge branch 'master' of github.com:pyro-ppl/numpyro into feature/bvm…
OlaRonning Aug 19, 2021
fb8e5db
Changed `SineBivariateVonMises` support to circular
OlaRonning Aug 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,14 @@ ProjectedNormal
:show-inheritance:
:member-order: bysource

SineBivariateVonMises
---------------------
.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

VonMises
--------
.. autoclass:: numpyro.distributions.directional.VonMises
Expand Down
7 changes: 6 additions & 1 deletion numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
Uniform,
Weibull,
)
from numpyro.distributions.directional import ProjectedNormal, VonMises
from numpyro.distributions.directional import (
ProjectedNormal,
SineBivariateVonMises,
VonMises,
)
from numpyro.distributions.discrete import (
Bernoulli,
BernoulliLogits,
Expand Down Expand Up @@ -145,6 +149,7 @@
"ProjectedNormal",
"PRNGIdentity",
"RightTruncatedDistribution",
"SineBivariateVonMises",
"SoftLaplace",
"StudentT",
"TransformedDistribution",
Expand Down
313 changes: 312 additions & 1 deletion numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,80 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
import functools
import math
from math import pi
import operator

from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import erf, i0e, i1e
from jax.scipy import special
from jax.scipy.special import erf, i0e, i1e, logsumexp

from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
is_prng_key,
lazy_property,
promote_shapes,
safe_normalize,
validate_sample,
von_mises_centered,
)
from numpyro.util import while_loop


def _numel(shape):
return functools.reduce(operator.mul, shape, 1)


def log_I1(orders: int, value, terms=250):
r"""Compute first n log modified bessel function of first kind
.. math ::
\log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk)
- \lgamma(v + k + 1)\right])
:param orders: orders of the log modified bessel function.
:param value: values to compute modified bessel function for
:param terms: truncation of summation
:return: 0 to orders modified bessel function
"""
orders = orders + 1
if value.ndim == 0:
vshape = jnp.shape([1])
else:
vshape = value.shape
value = value.reshape(-1, 1)
flat_vshape = _numel(vshape)

k = jnp.arange(terms)
lgammas_all = special.gammaln(jnp.arange(1.0, terms + orders + 1))
assert lgammas_all.shape == (orders + terms,) # lgamma(0) = inf => start from 1

lvalues = jnp.log(value / 2) * k.reshape(1, -1)
assert lvalues.shape == (flat_vshape, terms)

lfactorials = lgammas_all[:terms]
assert lfactorials.shape == (terms,)

lgammas = lgammas_all.tile(orders).reshape((orders, -1))
assert lgammas.shape == (orders, terms + orders) # lgamma(0) = inf => start from 1

indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1)
assert indices.shape == (orders, terms)

seqs = logsumexp(
2 * lvalues[None, :, :]
- lfactorials[None, None, :]
- jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :],
-1,
)
assert seqs.shape == (orders, flat_vshape)

i1s = lvalues[..., :orders].T + seqs
assert i1s.shape == (orders, flat_vshape)
return i1s.reshape(-1, *vshape)


class VonMises(Distribution):
Expand Down Expand Up @@ -75,6 +133,259 @@ def variance(self):
)


PhiMarginalState = namedtuple("PhiMarginalState", ["i", "done", "phi", "key"])


class SineBivariateVonMises(Distribution):
r"""Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by

.. math::
C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))

and

.. math::
C = (2\pi)^2 \sum_{i=0} {2i \choose i}
\left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),

where :math:`I_i(\cdot)` is the modified bessel function of first kind, mu's are the locations of the distribution,
kappa's are the concentration and rho gives the correlation between angles :math:`x_1` and :math:`x_2`.
This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.

To infer parameters, use :class:`~numpyro.infer.hmc.NUTS` or :class:`~numpyro.infer.hmc.HMC` with priors that
avoid parameterizations where the distribution becomes bimodal; see note below.
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved

.. note:: Sample efficiency drops as

.. math::
\frac{\rho}{\kappa_1\kappa_2} \rightarrow 1

because the distribution becomes increasingly bimodal.

.. note:: The correlation and weighted_correlation params are mutually exclusive.

.. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be used as a likelihood but not
for latent variables.

** References: **
1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002)

:param np.ndarray phi_loc: location of first angle
:param np.ndarray psi_loc: location of second angle
:param np.ndarray phi_concentration: concentration of first angle
:param np.ndarray psi_concentration: concentration of second angle
:param np.ndarray correlation: correlation between the two angles
:param np.ndarray weighted_correlation: set correlation to weigthed_corr * sqrt(phi_conc*psi_conc)
to avoid bimodality (see note).
"""

arg_constraints = {
"phi_loc": constraints.real,
"psi_loc": constraints.real,
"phi_concentration": constraints.positive,
"psi_concentration": constraints.positive,
"correlation": constraints.real,
}
support = constraints.independent(
constraints.real, 1
) # TODO: @OlaRonning update to circular constraint @1080
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
max_sample_iter = 1000

def __init__(
self,
phi_loc,
psi_loc,
phi_concentration,
psi_concentration,
correlation=None,
weighted_correlation=None,
validate_args=None,
):

assert (correlation is None) != (weighted_correlation is None)

if weighted_correlation is not None:
correlation = (
weighted_correlation * jnp.sqrt(phi_concentration * psi_concentration)
+ 1e-8
)

(
self.phi_loc,
self.psi_loc,
self.phi_concentration,
self.psi_concentration,
self.correlation,
) = promote_shapes(
phi_loc, psi_loc, phi_concentration, psi_concentration, correlation
)
batch_shape = lax.broadcast_shapes(
jnp.shape(phi_loc),
jnp.shape(psi_loc),
jnp.shape(phi_concentration),
jnp.shape(psi_concentration),
jnp.shape(correlation),
)
super().__init__(batch_shape, (2,), validate_args)

@lazy_property
def norm_const(self):
corr = jnp.reshape(self.correlation, (1, -1)) + 1e-8
conc = jnp.stack(
(self.phi_concentration, self.psi_concentration), axis=-1
).reshape(-1, 2)
m = jnp.arange(50).reshape(-1, 1)
num = special.gammaln(2 * m + 1.0)
den = special.gammaln(m + 1.0)
lbinoms = num - 2 * den

fs = (
lbinoms.reshape(-1, 1)
+ 2 * m * jnp.log(corr)
- m * jnp.log(4 * jnp.prod(conc, axis=-1))
)
fs += log_I1(49, conc, terms=51).sum(-1)
mfs = fs.max()
norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(fs - mfs, 0)
return norm_const.reshape(jnp.shape(self.phi_loc))

@validate_sample
def log_prob(self, value):
indv = self.phi_concentration * jnp.cos(
value[..., 0] - self.phi_loc
) + self.psi_concentration * jnp.cos(value[..., 1] - self.psi_loc)
corr = (
self.correlation
* jnp.sin(value[..., 0] - self.phi_loc)
* jnp.sin(value[..., 1] - self.psi_loc)
)
return indv + corr - self.norm_const

def sample(self, key, sample_shape=()):
"""
** References: **
1. A New Unified Approach for the Simulation of aWide Class of Directional Distributions
John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
"""
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
assert is_prng_key(key)
phi_key, psi_key = random.split(key)

corr = self.correlation
conc = jnp.stack((self.phi_concentration, self.psi_concentration))

eig = 0.5 * (conc[0] - corr ** 2 / conc[1])
eig = jnp.stack((jnp.zeros_like(eig), eig))
eigmin = jnp.where(eig[1] < 0, eig[1], jnp.zeros_like(eig[1], dtype=eig.dtype))
eig = eig - eigmin
b0 = self._bfind(eig)

total = _numel(sample_shape)
phi_den = log_I1(0, conc[1]).squeeze(0)
batch_size = _numel(self.batch_shape)
phi_shape = (total, 2, batch_size)
phi_state = SineBivariateVonMises._phi_marginal(
phi_shape,
phi_key,
jnp.reshape(conc, (2, batch_size)),
jnp.reshape(corr, (batch_size,)),
jnp.reshape(eig, (2, batch_size)),
jnp.reshape(b0, (batch_size,)),
jnp.reshape(eigmin, (batch_size,)),
jnp.reshape(phi_den, (batch_size,)),
)

phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1])

alpha = jnp.sqrt(conc[1] ** 2 + (corr * jnp.sin(phi)) ** 2)
beta = jnp.arctan(corr / conc[1] * jnp.sin(phi))

psi = VonMises(beta, alpha).sample(psi_key)

phi_psi = jnp.concatenate(
(
(phi + self.phi_loc + pi) % (2 * pi) - pi,
(psi + self.psi_loc + pi) % (2 * pi) - pi,
),
axis=1,
)
phi_psi = jnp.transpose(phi_psi, (0, 2, 1))
return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape)

@staticmethod
def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
conc = jnp.broadcast_to(conc, shape)
eig = jnp.broadcast_to(eig, shape)
b0 = jnp.broadcast_to(b0, shape)
eigmin = jnp.broadcast_to(eigmin, shape)
phi_den = jnp.broadcast_to(phi_den, shape)

def update_fn(curr):
i, done, phi, key = curr
phi_key, key = random.split(key)
accept_key, acg_key, phi_key = random.split(phi_key, 3)

x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
x /= jnp.linalg.norm(
x, axis=1, keepdims=True
) # Angular Central Gaussian distribution

lf = (
conc[:, :1] * (x[:, :1] - 1)
+ eigmin
+ log_I1(
0, jnp.sqrt(conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2)
).squeeze(0)
- phi_den
)
assert lf.shape == shape

lg_inv = (
1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True))
)
assert lg_inv.shape == lf.shape

accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv)

phi = jnp.where(accepted, x, phi)
return PhiMarginalState(i + 1, done | accepted, phi, key)

def cond_fn(curr):
return jnp.bitwise_and(
curr.i < SineBivariateVonMises.max_sample_iter,
jnp.logical_not(jnp.all(curr.done)),
)

phi_state = while_loop(
cond_fn,
update_fn,
PhiMarginalState(
i=jnp.array(0),
done=jnp.zeros(shape, dtype=bool),
phi=jnp.empty(shape, dtype=float),
key=rng_key,
),
)
return PhiMarginalState(
phi_state.i, phi_state.done, phi_state.phi, phi_state.key
)

@property
def mean(self):
"""Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]"""
mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % (
2.0 * jnp.pi
) - jnp.pi
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
print(mean.shape)
print(self.batch_shape)
return jnp.broadcast_to(mean, (*self.batch_shape, 2))

def _bfind(self, eig):
b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype)
g1 = jnp.sum(1 / (b + 2 * eig) ** 2, axis=0)
g2 = jnp.sum(-2 / (b + 2 * eig) ** 3, axis=0)
return jnp.where(jnp.linalg.norm(eig, axis=0) != 0, b - g1 / g2, b)


class ProjectedNormal(Distribution):
"""
Projected isotropic normal distribution of arbitrary dimension.
Expand Down
Loading