Add RecursiveLinearTransform for linear state space models. #1766

Merged: 10 commits, Mar 25, 2024
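
For context, a minimal usage sketch of the new transform, adapted from the docstring added below (not part of the diff itself): a Gaussian random walk is the simplest linear state space model, with transition matrix A = I.

from jax import numpy as jnp
import numpyro
from numpyro import distributions as dist

def gaussian_random_walk():
    # y_t = A y_{t - 1} + x_t with A = I and innovations x_t ~ Normal(0, 1).
    return numpyro.sample(
        "y",
        dist.TransformedDistribution(
            dist.Normal(0, 1).expand([10, 1]).to_event(1),
            dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
        ),
    )

numpyro.handlers.seed(gaussian_random_walk, 0)().shape  # (10, 1)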
3 changes: 3 additions & 0 deletions .gitignore
@@ -35,3 +35,6 @@ numpyro/examples/.data
# docs
docs/build
docs/.DS_Store
docs/source/examples
docs/source/tutorials
docs/source/getting_started.rst
20 changes: 19 additions & 1 deletion docs/source/distributions.rst
@@ -151,7 +151,7 @@ EulerMaruyama
:undoc-members:
:show-inheritance:
:member-order: bysource

Exponential
^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.Exponential
@@ -948,6 +948,24 @@ PowerTransform
:show-inheritance:
:member-order: bysource

RealFastFourierTransform
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: numpyro.distributions.transforms.RealFastFourierTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

RecursiveLinearTransform
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: numpyro.distributions.transforms.RecursiveLinearTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

ScaledUnitLowerCholeskyTransform
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.ScaledUnitLowerCholeskyTransform
92 changes: 92 additions & 0 deletions numpyro/distributions/transforms.py
@@ -1288,6 +1288,98 @@ def __eq__(self, other):
)


class RecursiveLinearTransform(Transform):
"""
Apply a linear transformation recursively such that
:math:`y_t = A y_{t - 1} + x_t` for :math:`t > 0`, where :math:`x_t` and :math:`y_t`
are vectors and :math:`A` is a square transition matrix. The series is initialized
by :math:`y_0 = 0`.

:param transition_matrix: Square transition matrix :math:`A` for successive states
or a batch of transition matrices.

**Example:**

.. doctest::

>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import distributions as dist
>>>
>>> def cauchy_random_walk():
... return numpyro.sample(
... "x",
... dist.TransformedDistribution(
... dist.Cauchy(0, 1).expand([10, 1]).to_event(1),
... dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
... ),
... )
>>>
>>> numpyro.handlers.seed(cauchy_random_walk, 0)().shape
(10, 1)
>>>
>>> def rocket_trajectory():
... scale = numpyro.sample(
... "scale",
... dist.HalfCauchy(1).expand([2]).to_event(1),
... )
... transition_matrix = jnp.array([[1, 1], [0, 1]])  # constant-velocity dynamics
... return numpyro.sample(
... "x",
... dist.TransformedDistribution(
... dist.Normal(0, scale).expand([10, 2]).to_event(1),
... dist.transforms.RecursiveLinearTransform(transition_matrix),
... ),
... )
>>>
>>> numpyro.handlers.seed(rocket_trajectory, 0)().shape
(10, 2)
"""

domain = constraints.real_matrix
codomain = constraints.real_matrix

def __init__(self, transition_matrix: jnp.ndarray) -> None:
self.transition_matrix = transition_matrix

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
# Move the time axis to the first position so we can scan over it.
x = jnp.moveaxis(x, -2, 0)

def f(y, x):
y = jnp.einsum("...ij,...j->...i", self.transition_matrix, y) + x
return y, y

_, y = lax.scan(f, jnp.zeros_like(x, shape=x.shape[1:]), x)
return jnp.moveaxis(y, 0, -2)

def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
# Move the time axis to the first position so we can scan over it in reverse.
y = jnp.moveaxis(y, -2, 0)

def f(y, prev):
x = y - jnp.einsum("...ij,...j->...i", self.transition_matrix, prev)
return prev, x

# Scan backwards: the carry holds y_t and the inputs supply y_{t - 1}, with y_0 = 0.
_, x = lax.scan(f, y[-1], jnp.roll(y, 1, axis=0).at[0].set(0), reverse=True)
return jnp.moveaxis(x, 0, -2)

def log_abs_det_jacobian(self, x: jnp.ndarray, y: jnp.ndarray, intermediates=None):
return jnp.zeros_like(x, shape=x.shape[:-2])
Member

Sounds reasonable to me; this is essentially a shear transformation, so the Jacobian determinant is 1.

Contributor Author

Yes. Because of the temporal structure of the transform, the Jacobian is triangular, and because x only enters additively, its diagonal entries are one, giving the unit Jacobian determinant.
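
A quick numerical check of this argument (not part of the PR, just an illustrative sketch with jax.jacfwd on a small unbatched example):

import jax.numpy as jnp
from jax import jacfwd, random
from numpyro.distributions.transforms import RecursiveLinearTransform

transform = RecursiveLinearTransform(random.normal(random.PRNGKey(0), (3, 3)))
x = random.normal(random.PRNGKey(1), (5, 3))  # five time steps, three states

# Flatten (time, state) so the Jacobian becomes a single square matrix.
jac = jacfwd(lambda flat: transform(flat.reshape(5, 3)).ravel())(x.ravel())
print(jnp.linalg.slogdet(jac).logabsdet)  # ~0.0: triangular Jacobian with unit diagonal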


def tree_flatten(self):
return (self.transition_matrix,), (
("transition_matrix",),
{},
)

def __eq__(self, other):
if not isinstance(other, RecursiveLinearTransform):
return False
return jnp.array_equal(self.transition_matrix, other.transition_matrix)


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
1 change: 1 addition & 0 deletions numpyro/infer/reparam.py
@@ -380,6 +380,7 @@ class ExplicitReparam(Reparam):
>>> mcmc.run(random.PRNGKey(2)) # doctest: +SKIP
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
"""

def __init__(self, transform):
if isinstance(transform, Iterable) and all(
isinstance(t, dist.transforms.Transform) for t in transform
2 changes: 1 addition & 1 deletion scripts/update_headers.py
@@ -7,7 +7,7 @@
import sys

root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
blacklist = ["/build/", "/dist/", "/pyro_api.egg"]
blacklist = ["/build/", "/dist/", "/pyro_api.egg", "/venv/"]
file_types = [("*.py", "# {}"), ("*.cpp", "// {}")]

parser = argparse.ArgumentParser()
12 changes: 12 additions & 0 deletions test/test_distributions.py
@@ -3222,3 +3222,15 @@ def test_lowrank_mvn_19885(capfd: pytest.CaptureFixture) -> None:
assert x.shape == (sample_size, batch_size, event_size)
log_prob = _assert_not_jax_issue_19885(capfd, distribution.log_prob, x)
assert log_prob.shape == (sample_size, batch_size)


def test_gaussian_random_walk_linear_recursive_equivalence():
dist1 = dist.GaussianRandomWalk(3.7, 15)
dist2 = dist.TransformedDistribution(
dist.Normal(0, 3.7).expand([15, 1]).to_event(2),
dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
)
x1 = dist1.sample(random.PRNGKey(7))
x2 = dist2.sample(random.PRNGKey(7))
assert jnp.allclose(x1, x2.squeeze())
assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))
42 changes: 40 additions & 2 deletions test/test_transforms.py
@@ -3,10 +3,11 @@

from collections import namedtuple
from functools import partial
import math

import pytest

from jax import jit, random, tree_map, vmap
from jax import jacfwd, jit, random, tree_map, vmap
import jax.numpy as jnp

from numpyro.distributions.flows import (
@@ -30,6 +31,7 @@
PermuteTransform,
PowerTransform,
RealFastFourierTransform,
RecursiveLinearTransform,
ReshapeTransform,
ScaledUnitLowerCholeskyTransform,
SigmoidTransform,
@@ -90,6 +92,11 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
(),
dict(transform_shape=(3, 4, 5), transform_ndims=3),
),
"recursive_linear": T(
RecursiveLinearTransform,
(jnp.eye(5),),
dict(),
),
"simplex_to_ordered": T(
SimplexToOrderedTransform,
(_a(1.0),),
@@ -277,6 +284,10 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims):
(PowerTransform(2.5), ()),
(RealFastFourierTransform(7), (7,)),
(RealFastFourierTransform((8, 9), 2), (8, 9)),
(
RecursiveLinearTransform(random.normal(random.key(17), (4, 4))),
(7, 4),
),
(ReshapeTransform((5, 2), (10,)), (10,)),
(ReshapeTransform((15,), (3, 5)), (3, 5)),
(ScaledUnitLowerCholeskyTransform(), (6,)),
@@ -312,4 +323,31 @@ def test_bijective_transforms(transform, shape):
atol = 1e-2
assert jnp.allclose(x1, x2, atol=atol)

assert transform.log_abs_det_jacobian(x1, y).shape == batch_shape
log_abs_det_jacobian = transform.log_abs_det_jacobian(x1, y)
assert log_abs_det_jacobian.shape == batch_shape

# Also check the Jacobian numerically for transforms with the same input and output
# size, unless they are explicitly excluded. E.g., the upper triangle of the
# CholeskyTransform output is zero, giving rise to a singular Jacobian.
skip_jacobian_check = (CholeskyTransform,)
size_x = int(x1.size / math.prod(batch_shape))
size_y = int(y.size / math.prod(batch_shape))
if size_x == size_y and not isinstance(transform, skip_jacobian_check):
jac = (
vmap(jacfwd(transform))(x1)
.reshape((-1,) + x1.shape[len(batch_shape) :])
.reshape(batch_shape + (size_y, size_x))
)
slogdet = jnp.linalg.slogdet(jac)
assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol)


def test_batched_recursive_linear_transform():
batch_shape = (4, 17)
x = random.normal(random.key(8), batch_shape + (10, 3))
# Get a batch of matrices with eigenvalues that don't blow up the sequence.
A = CorrCholeskyTransform()(random.normal(random.key(7), batch_shape + (3,)))
transform = RecursiveLinearTransform(A)
y = transform(x)
assert y.shape == x.shape
assert jnp.allclose(x, transform.inv(y), atol=1e-6)