Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sine Skewed distribution #1055

Merged
merged 18 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,14 @@ ProjectedNormal
:show-inheritance:
:member-order: bysource

SineSkewed
----------
.. autoclass:: numpyro.distributions.directional.SineSkewed
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

VonMises
--------
.. autoclass:: numpyro.distributions.directional.VonMises
Expand Down
4 changes: 3 additions & 1 deletion numpyro/compat/pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import warnings

from numpyro.compat.util import UnsupportedAPIWarning
from numpyro.primitives import module, param as _param, plate, sample # noqa: F401

from numpyro.primitives import module, plate, sample # noqa: F401 isort:skip
from numpyro.primitives import param as _param # noqa: F401 isort:skip

_PARAM_STORE = {}

Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
Uniform,
Weibull,
)
from numpyro.distributions.directional import ProjectedNormal, VonMises
from numpyro.distributions.directional import ProjectedNormal, SineSkewed, VonMises
from numpyro.distributions.discrete import (
Bernoulli,
BernoulliLogits,
Expand Down Expand Up @@ -145,6 +145,7 @@
"ProjectedNormal",
"PRNGIdentity",
"RightTruncatedDistribution",
"SineSkewed",
"SoftLaplace",
"StudentT",
"TransformedDistribution",
Expand Down
143 changes: 143 additions & 0 deletions numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,149 @@
)


class SineSkewed(Distribution):
r"""Sine Skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus
distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric)
base distribution.

Torus distributions are distributions with support on products of circles
(i.e., :math:`\bigotimes^d S^1` where :math:`S^1=[-\pi, \pi)` ). So, a 0-torus is a point, the 1-torus is a circle,
and the 2-torus is commonly associated with the donut shape.

The Sine Skewed X distribution is parameterized by a weight parameter for each dimension of the event of X.
For example with a von Mises distribution over a circle (1-torus), the Sine Skewed von Mises Distribution has one
skew parameter. The skewness parameters can be inferred using :class:`~numpyro.infer.HMC` or
:class:`~numpyro.infer.NUTS`. For example, the following will produce a uniform prior over
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
skewness for the 2-torus::

def model(obs):
# Sine priors
phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
phi_conc = numpyro.sample('phi_conc', Beta(2., 2.))
psi_conc = numpyro.sample('psi_conc', Beta(2., 2.))
corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

# SS prior
skew_phi = numpyro.sample('skew_phi', Uniform(-1., 1.))
psi_bound = 1 - skew_phi.abs()
skew_psi = numpyro.sample('skew_psi', Uniform(-1., 1.))
skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=-1)
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
with numpyro.plate('obs_plate'):
sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
phi_concentration=1000 * phi_conc,
psi_concentration=1000 * psi_conc,
weighted_correlation=corr_scale)
return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

To ensure the skewing does not alter the normalization constant of the (Sine Bivaraite von Mises) base
distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of
skewness to be less than or equal to one.
So for the above snippet it must hold that::

skew_phi.abs()+skew_psi.abs() <= 1
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved

We handle this in the prior by computing psi_bound and use it to scale skew_psi.
We do **not** use psi_bound as::

skew_psi = pyro.sample('skew_psi', Uniform(-psi_bound, psi_bound))
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved

as it would make the support for the Uniform distribution dynamic.
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
In the context of :class:`~pyro.infer.SVI`, this distribution can freely be used as a likelihood. But, when used as
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
latent variables it will lead to slow inference for 2 and higher dim toruses. This is because the base_dist
cannot be reparameterized.

.. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

.. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
must be less than or equal to one. See eq. 2.1 in [1].

** References: **
1. Sine-skewed toroidal distributions and their application in protein bioinformatics
Ameijeiras-Alonso, J., Ley, C. (2019)

:param torch.distributions.Distribution base_dist: base density on a d-dimensional torus. Supported base
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
distributions include: 1D :class:`~numpyro.distributions.VonMises`,
:class:`~numnumpyro.distributions.SineBivariateVonMises`, 1D :class:`~numpyro.distributions.ProjectedNormal`,
and :class:`~numpyro.distributions.Uniform` (-pi, pi).
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
:param torch.tensor skewness: skewness of the distribution.
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
"""

arg_constraints = {
"skewness": constraints.independent(constraints.interval(-1.0, 1.0), 1)
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
}

support = constraints.independent(
constraints.real, 1
) # TODO: add Circular constraint (issue 1070 and PR 1080)

def __init__(self, base_dist: Distribution, skewness, validate_args=None):
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
assert (
base_dist.event_shape == skewness.shape[-1:]
), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."

batch_shape = jnp.broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
event_shape = skewness.shape[-1:]
self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
self.base_dist = base_dist.expand(batch_shape)
super().__init__(batch_shape, event_shape, validate_args=validate_args)

def __repr__(self):
args_string = ", ".join(
[
"{}: {}".format(
p,
getattr(self, p)
if getattr(self, p).numel() == 1
else getattr(self, p).size(),
)
for p in self.arg_constraints.keys()
]
)
return (
self.__class__.__name__
+ "("
+ f"base_density: {str(self.base_dist)}, "
+ args_string
+ ")"
)

def sample(self, key, sample_shape=()):
base_key, skew_key = random.split(key)
bd = self.base_dist
ys = bd.sample(base_key, sample_shape)
u = random.uniform(skew_key, sample_shape + self.batch_shape)

# Section 2.3 step 3 in [1]
mask = u <= 0.5 + 0.5 * (
self.skewness * jnp.sin((ys - bd.mean) % (2 * jnp.pi))
).sum(-1)
mask = mask[..., None]
samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + jnp.pi) % (
2 * jnp.pi
) - jnp.pi
return samples

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
if self.base_dist._validate_args:
self.base_dist._validate_sample(value)

# Eq. 2.1 in [1]
skew_prob = jnp.log(
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
1
+ (
self.skewness * jnp.sin((value - self.base_dist.mean) % (2 * jnp.pi))
).sum(-1)
)
return self.base_dist.log_prob(value) + skew_prob

@property
def mean(self):
return self.base_dist.mean


class VonMises(Distribution):
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
reparametrized_params = ["loc"]
Expand Down
48 changes: 38 additions & 10 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import namedtuple
from functools import partial
import inspect
import math
import os

import numpy as np
Expand Down Expand Up @@ -86,6 +87,27 @@ def _TruncatedNormal(loc, scale, low, high):
_TruncatedNormal.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ())


class SineSkewedUniform(dist.SineSkewed):
def __init__(self, skewness, **kwargs):
lower, upper = (jnp.array([-math.pi, -math.pi]), jnp.array([math.pi, math.pi]))
base_dist = dist.Uniform(lower, upper, **kwargs).to_event(lower.ndim)
super().__init__(base_dist, skewness, **kwargs)


class SineSkewedVonMises(dist.SineSkewed):
def __init__(self, skewness, **kwargs):
von_loc, von_conc = (jnp.array([0.0]), jnp.array([1.0]))
base_dist = dist.VonMises(von_loc, von_conc, **kwargs).to_event(von_loc.ndim)
super().__init__(base_dist, skewness, **kwargs)


class SineSkewedVonMisesBatched(dist.SineSkewed):
def __init__(self, skewness, **kwargs):
von_loc, von_conc = (jnp.array([0.0, -1.234]), jnp.array([1.0, 10.0]))
base_dist = dist.VonMises(von_loc, von_conc, **kwargs).to_event(von_loc.ndim)
super().__init__(base_dist, skewness, **kwargs)


class _ImproperWrapper(dist.ImproperUniform):
def sample(self, key, sample_shape=()):
transform = biject_to(self.support)
Expand Down Expand Up @@ -337,6 +359,9 @@ def get_sp_dist(jax_dist):
T(dist.ProjectedNormal, jnp.array([[2.0, 3.0]])),
T(dist.ProjectedNormal, jnp.array([0.0, 0.0, 0.0])),
T(dist.ProjectedNormal, jnp.array([[-1.0, 2.0, 3.0]])),
T(SineSkewedUniform, jnp.array([-math.pi / 4, 0.1])),
T(SineSkewedVonMises, jnp.array([0.342355])),
T(SineSkewedVonMisesBatched, jnp.array([[0.342355, -0.0001], [0.91, 0.09]])),
]

DISCRETE = [
Expand Down Expand Up @@ -743,15 +768,13 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
rng_key = random.PRNGKey(0)
samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
truncated_dists = (
dist.LeftTruncatedDistribution,
dist.RightTruncatedDistribution,
dist.TwoSidedTruncatedDistribution,
)
if sp_dist is None:
if isinstance(
jax_dist,
(
dist.LeftTruncatedDistribution,
dist.RightTruncatedDistribution,
dist.TwoSidedTruncatedDistribution,
),
):
if isinstance(jax_dist, truncated_dists):
if isinstance(params[0], dist.Distribution):
# new api
loc, scale, low, high = (
Expand Down Expand Up @@ -1104,6 +1127,8 @@ def test_mean_var(jax_dist, sp_dist, params):
pytest.skip("Improper distribution does not has mean/var implemented")
if jax_dist is FoldedNormal:
pytest.skip("Folded distribution does not has mean/var implemented")
if "SineSkewed" in jax_dist.__name__:
pytest.skip("Skewed Distribution are not symmetric about location.")
if jax_dist in (
_TruncatedNormal,
dist.LeftTruncatedDistribution,
Expand Down Expand Up @@ -1213,6 +1238,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
and dist_args[i] != "concentration"
):
continue
if "SineSkewed" in jax_dist.__name__ and dist_args[i] != "skewness":
continue
if (
jax_dist is dist.TwoSidedTruncatedDistribution
and dist_args[i] == "base_dist"
Expand All @@ -1239,7 +1266,9 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
assert jax_dist(*oob_params)

# Invalid parameter values throw ValueError
if not dependent_constraint and jax_dist is not _ImproperWrapper:
if not dependent_constraint and (
jax_dist is not _ImproperWrapper and "SineSkewed" not in jax_dist.__name__
):
with pytest.raises(ValueError):
jax_dist(*oob_params, validate_args=True)

Expand Down Expand Up @@ -1488,7 +1517,6 @@ def test_constraints(constraint, x, expected):
)
@pytest.mark.parametrize("shape", [(), (1,), (3,), (6,), (3, 1), (1, 3), (5, 3)])
def test_biject_to(constraint, shape):

transform = biject_to(constraint)
event_dim = transform.domain.event_dim
if isinstance(constraint, constraints._Interval):
Expand Down