Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sine Skewed distribution #1055

Merged
merged 18 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -502,13 +502,21 @@ ProjectedNormal
:member-order: bysource

SineBivariateVonMises
---------------------
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.directional.SineBivariateVonMises
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

SineSkewed
^^^^^^^^^^
.. autoclass:: numpyro.distributions.directional.SineSkewed
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

VonMises
^^^^^^^^
.. autoclass:: numpyro.distributions.directional.VonMises
Expand Down Expand Up @@ -599,7 +607,7 @@ boolean
.. autodata:: numpyro.distributions.constraints.boolean

circular
--------
^^^^^^^^
.. autodata:: numpyro.distributions.constraints.circular

corr_cholesky
Expand Down Expand Up @@ -787,7 +795,7 @@ InvCholeskyTransform
:member-order: bysource

L1BallTransform
---------------
^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.L1BallTransform
:members:
:undoc-members:
Expand Down
4 changes: 3 additions & 1 deletion numpyro/compat/pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import warnings

from numpyro.compat.util import UnsupportedAPIWarning
from numpyro.primitives import module, param as _param, plate, sample # noqa: F401

from numpyro.primitives import module, plate, sample # noqa: F401 isort:skip
from numpyro.primitives import param as _param # noqa: F401 isort:skip

_PARAM_STORE = {}

Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from numpyro.distributions.directional import (
ProjectedNormal,
SineBivariateVonMises,
SineSkewed,
VonMises,
)
from numpyro.distributions.discrete import (
Expand Down Expand Up @@ -152,6 +153,7 @@
"PRNGIdentity",
"RightTruncatedDistribution",
"SineBivariateVonMises",
"SineSkewed",
"SoftLaplace",
"StudentT",
"TransformedDistribution",
Expand Down
132 changes: 130 additions & 2 deletions numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,136 @@ def variance(self):
PhiMarginalState = namedtuple("PhiMarginalState", ["i", "done", "phi", "key"])


class SineSkewed(Distribution):
"""Sine-skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus
distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric)
base distribution. Torus distributions are distributions with support on products of circles
(i.e., ⨂^d S^1 where S^1=[-pi,pi) ). So, a 0-torus is a point, the 1-torus is a circle,
and the 2-torus is commonly associated with the donut shape.

The sine skewed X distribution is parameterized by a weight parameter for each dimension of the event of X.
For example with a von Mises distribution over a circle (1-torus), the sine skewed von Mises distribution has one
skew parameter. The skewness parameters can be inferred using :class:`~numpyro.infer.HMC` or
:class:`~numpyro.infer.NUTS`. For example, the following will produce a prior over
skewness for the 2-torus,::

@numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()})
def model(obs):
# Sine priors
phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

# Skewing prior
ball_trans = L1BallTransform()
skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
skewness = ball_trans(skewness) # constraint sum |skewness_i| <= 1

with numpyro.plate('obs_plate'):
sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
phi_concentration=70 * phi_conc,
psi_concentration=70 * psi_conc,
weighted_correlation=corr_scale)
return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

To ensure the skewing does not alter the normalization constant of the (sine bivariate von Mises) base
distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of
skewness to be less than or equal to one. We can use the :class:`~numpyro.distriubtions.transforms.L1BallTransform`
to achieve this.

In the context of :class:`~pyro.infer.SVI`, this distribution can freely be used as a likelihood, but use as
latent variables it will lead to slow inference for 2 and higher dim toruses. This is because the base_dist
cannot be reparameterized.

.. note:: An event in the base distribution must be on a d-torus, so the event_shape must be `(d,)`.

.. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
must be less than or equal to one. See eq. 2.1 in [1].

** References: **
1. Sine-skewed toroidal distributions and their application in protein bioinformatics
Ameijeiras-Alonso, J., Ley, C. (2019)

:param numpyro.distributions.Distribution base_dist: base density on a d-dimensional torus. Supported base
distributions include: 1D :class:`~numpyro.distributions.VonMises`,
:class:`~numnumpyro.distributions.SineBivariateVonMises`, 1D :class:`~numpyro.distributions.ProjectedNormal`,
and :class:`~numpyro.distributions.Uniform` (-pi, pi).
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
:param jax.numpy.array skewness: skewness of the distribution.
"""

arg_constraints = {"skewness": constraints.l1_ball}

support = constraints.independent(constraints.circular, 1)

def __init__(self, base_dist: Distribution, skewness, validate_args=None):
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
assert (
base_dist.event_shape == skewness.shape[-1:]
), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."

batch_shape = jnp.broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
event_shape = skewness.shape[-1:]
self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
self.base_dist = base_dist.expand(batch_shape)
super().__init__(batch_shape, event_shape, validate_args=validate_args)

def __repr__(self):
args_string = ", ".join(
[
"{}: {}".format(
p,
getattr(self, p)
if getattr(self, p).numel() == 1
else getattr(self, p).size(),
)
for p in self.arg_constraints.keys()
]
)
return (
self.__class__.__name__
+ "("
+ f"base_density: {str(self.base_dist)}, "
+ args_string
+ ")"
)

def sample(self, key, sample_shape=()):
base_key, skew_key = random.split(key)
bd = self.base_dist
ys = bd.sample(base_key, sample_shape)
u = random.uniform(skew_key, sample_shape + self.batch_shape)

# Section 2.3 step 3 in [1]
mask = u <= 0.5 + 0.5 * (
self.skewness * jnp.sin((ys - bd.mean) % (2 * jnp.pi))
).sum(-1)
mask = mask[..., None]
samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + jnp.pi) % (
2 * jnp.pi
) - jnp.pi
return samples

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
if self.base_dist._validate_args:
self.base_dist._validate_sample(value)

# Eq. 2.1 in [1]
skew_prob = jnp.log1p(
(self.skewness * jnp.sin((value - self.base_dist.mean) % (2 * jnp.pi))).sum(
-1
)
)
return self.base_dist.log_prob(value) + skew_prob

@property
def mean(self):
"""Mean of the base distribution"""
return self.base_dist.mean


class SineBivariateVonMises(Distribution):
r"""Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by

Expand Down Expand Up @@ -389,8 +519,6 @@ def mean(self):
mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % (
2.0 * jnp.pi
) - jnp.pi
print(mean.shape)
print(self.batch_shape)
return jnp.broadcast_to(mean, (*self.batch_shape, 2))

def _bfind(self, eig):
Expand Down
57 changes: 48 additions & 9 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,27 @@ def _TruncatedNormal(loc, scale, low, high):
_TruncatedNormal.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ())


class SineSkewedUniform(dist.SineSkewed):
def __init__(self, skewness, **kwargs):
lower, upper = (jnp.array([-math.pi, -math.pi]), jnp.array([math.pi, math.pi]))
base_dist = dist.Uniform(lower, upper, **kwargs).to_event(lower.ndim)
super().__init__(base_dist, skewness, **kwargs)


class SineSkewedVonMises(dist.SineSkewed):
def __init__(self, skewness, **kwargs):
von_loc, von_conc = (jnp.array([0.0]), jnp.array([1.0]))
base_dist = dist.VonMises(von_loc, von_conc, **kwargs).to_event(von_loc.ndim)
super().__init__(base_dist, skewness, **kwargs)


class SineSkewedVonMisesBatched(dist.SineSkewed):
def __init__(self, skewness, **kwargs):
von_loc, von_conc = (jnp.array([0.0, -1.234]), jnp.array([1.0, 10.0]))
base_dist = dist.VonMises(von_loc, von_conc, **kwargs).to_event(von_loc.ndim)
super().__init__(base_dist, skewness, **kwargs)


def _GaussianMixture(mixing_probs, loc, scale):
component_dist = dist.Normal(loc=loc, scale=scale)
mixing_distribution = dist.Categorical(probs=mixing_probs)
Expand Down Expand Up @@ -434,6 +455,9 @@ def get_sp_dist(jax_dist):
T(dist.ProjectedNormal, jnp.array([[2.0, 3.0]])),
T(dist.ProjectedNormal, jnp.array([0.0, 0.0, 0.0])),
T(dist.ProjectedNormal, jnp.array([[-1.0, 2.0, 3.0]])),
T(SineSkewedUniform, jnp.array([-math.pi / 4, 0.1])),
T(SineSkewedVonMises, jnp.array([0.342355])),
T(SineSkewedVonMisesBatched, jnp.array([[0.342355, -0.0001], [0.91, 0.09]])),
]

DISCRETE = [
Expand Down Expand Up @@ -561,6 +585,12 @@ def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
elif constraint is constraints.sphere:
x = random.normal(key, size)
return x / jnp.linalg.norm(x, axis=-1)
elif constraint is constraints.l1_ball:
key1, key2 = random.split(key)
sign = random.bernoulli(key1)
bounds = [0, (-1) ** sign * 0.5]
return random.uniform(key, size, float, *sorted(bounds))

else:
raise NotImplementedError("{} not implemented.".format(constraint))

Expand Down Expand Up @@ -622,6 +652,11 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
x = random.normal(key, size)
x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
return 2 * x
elif constraint is constraints.l1_ball:
key1, key2 = random.split(key)
sign = random.bernoulli(key1)
bounds = [(-1) ** sign * 1.1, (-1) ** sign * 2]
return random.uniform(key, size, float, *sorted(bounds))
else:
raise NotImplementedError("{} not implemented.".format(constraint))

Expand Down Expand Up @@ -839,15 +874,13 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
rng_key = random.PRNGKey(0)
samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
truncated_dists = (
dist.LeftTruncatedDistribution,
dist.RightTruncatedDistribution,
dist.TwoSidedTruncatedDistribution,
)
if sp_dist is None:
if isinstance(
jax_dist,
(
dist.LeftTruncatedDistribution,
dist.RightTruncatedDistribution,
dist.TwoSidedTruncatedDistribution,
),
):
if isinstance(jax_dist, truncated_dists):
if isinstance(params[0], dist.Distribution):
# new api
loc, scale, low, high = (
Expand Down Expand Up @@ -1200,6 +1233,8 @@ def test_mean_var(jax_dist, sp_dist, params):
pytest.skip("Improper distribution does not has mean/var implemented")
if jax_dist is FoldedNormal:
pytest.skip("Folded distribution does not has mean/var implemented")
if "SineSkewed" in jax_dist.__name__:
pytest.skip("Skewed Distribution are not symmetric about location.")
if jax_dist in (
_TruncatedNormal,
dist.LeftTruncatedDistribution,
Expand Down Expand Up @@ -1316,6 +1351,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
and dist_args[i] != "concentration"
):
continue
if "SineSkewed" in jax_dist.__name__ and dist_args[i] != "skewness":
continue
if (
jax_dist is dist.TwoSidedTruncatedDistribution
and dist_args[i] == "base_dist"
Expand Down Expand Up @@ -1347,7 +1384,9 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
assert jax_dist(*oob_params)

# Invalid parameter values throw ValueError
if not dependent_constraint and jax_dist is not _ImproperWrapper:
if not dependent_constraint and (
jax_dist is not _ImproperWrapper and "SineSkewed" not in jax_dist.__name__
):
with pytest.raises(ValueError):
jax_dist(*oob_params, validate_args=True)

Expand Down