
Add RecursiveLinearTransform for linear state space models. #1766

Merged: 10 commits merged into pyro-ppl:master on Mar 25, 2024

Conversation

@tillahoffmann (Contributor) commented Mar 20, 2024

This PR adds a RecursiveLinearTransform, which is a linear transformation applied recursively such that $y_t = A y_{t - 1} + x_t$ for $t > 0$, where $x_t$ and $y_t$ are $p$-vectors and $A$ is a $p\times p$ transition matrix. The series is initialized by $y_0 = 0$.
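
Unrolling the recursion for the first few steps (with $y_0 = 0$) makes the structure explicit: $y_1 = x_1$, $y_2 = A x_1 + x_2$, $y_3 = A^2 x_1 + A x_2 + x_3$, and in general $y_t = \sum_{s=1}^{t} A^{t - s} x_s$, so each state is a matrix-weighted sum of all inputs up to time $t$.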

This transform can be used to easily declare linear state space models, e.g., a Cauchy random walk is

>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import distributions as dist
>>>
>>> def cauchy_random_walk():
...     return numpyro.sample(
...         "x",
...         dist.TransformedDistribution(
...             dist.Cauchy(0, 1).expand([10, 1]).to_event(1),
...             dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
...         ),
...     )

A Kalman-style model for a rocket with state y = (position, velocity) is

>>> def rocket_trajectory():
...     scale = numpyro.sample(
...         "scale",
...         dist.HalfCauchy(1).expand([2]).to_event(1),
...     )
...     transition_matrix = jnp.array([[1, 1], [0, 1]])
...     return numpyro.sample(
...         "x",
...         dist.TransformedDistribution(
...             dist.Normal(0, scale).expand([10, 2]).to_event(1),
...             dist.transforms.RecursiveLinearTransform(transition_matrix),
...         ),
...     )
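
A minimal usage sketch (not part of the PR description; it assumes the model above and the standard seed handler) for drawing a single trajectory, which has shape (10, 2) given the expansion used in the example:

>>> with numpyro.handlers.seed(rng_seed=0):
...     trajectory = rocket_trajectory()
>>> trajectory.shape
(10, 2)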

This PR also makes a few minor changes (happy to factor these out if you prefer); they are reflected in the commit list below.

Review thread on the docstring excerpt:

    are vectors and :math:`A` is a transition matrix. The series is initialized by
    :math:`y_0 = 0`.

    :param transition_matrix: Transition matrix :math:`A` for successive states.

Member:

Maybe 'matrix' -> 'square matrix' for clarity. Currently, the bias x is time-dependent, but the matrix is constant. Do you plan to make the class name more verbose to reflect that?

Contributor (author):

> Maybe 'matrix' -> 'square matrix' for clarity.

👍

> Currently, the bias x is time-dependent, but the matrix is constant. Do you plan to make the class name more verbose to reflect that?

Sorry, I explained that poorly. The x here is the argument of the transform rather than a bias (as in AffineTransform, for example). Keeping the transition matrix constant means that the transform can be applied to sequences of arbitrary length.
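
For illustration (a sketch, not from the comment itself): because the matrix only acts on the trailing state dimension, one transform instance handles sequences of any length.

>>> transform = dist.transforms.RecursiveLinearTransform(jnp.eye(2))
>>> transform(jnp.ones((5, 2))).shape
(5, 2)
>>> transform(jnp.ones((20, 2))).shape
(20, 2)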

Member:

x is the (time-dependent) bias (or noise, if we place a normal distribution over it) of a linear dynamical model. Typically, people use other notation, something like x_t = A x_{t-1} + b_t.

Contributor (author):

Fair. I was using the x/y notation here to stick with the typical arguments of the Transform classes. Do you have an idea for a more explanatory name?

Member:

The name looks 👌 to me. :)

@@ -1347,26 +1347,25 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
         x = jnp.moveaxis(x, -2, 0)

         def f(y, x):
-            y = (self.transition_matrix * y[..., None, :]).sum(axis=-1) + x
+            y = y @ self.transition_matrix.T + x

Member:

I think you might want jnp.swapaxes(self.transition_matrix, -1, -2)
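
(For context, a quick shape check, not part of the original comment: `.T` reverses all axes of a batched array, whereas jnp.swapaxes only exchanges the last two, which is what a batch of transition matrices needs.)

>>> A = jnp.ones((3, 4, 5))
>>> A.T.shape
(5, 4, 3)
>>> jnp.swapaxes(A, -1, -2).shape
(3, 5, 4)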

Contributor (author):

I've updated it to use einsum because we have shapes (..., p, p) for the transition matrices A and (..., n, p) for the states x. Inside the scan function, we are dealing with a state of shape (..., p) because we're scanning along the n dimension. There may very well be a better way to do this.
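
A rough sketch of that einsum-based scan (illustrative names, not the exact PR code); the leading batch dimensions of the transition matrix broadcast against the carried state:

from jax import lax
from jax import numpy as jnp

def forward(transition_matrix, x):
    # x has shape (..., n, p): move the time axis to the front so lax.scan
    # iterates over it, leaving a carry of shape (..., p).
    x = jnp.moveaxis(x, -2, 0)

    def f(y, xt):
        # transition_matrix has shape (..., p, p); contract its last axis with
        # the state while broadcasting any leading batch dimensions.
        y = jnp.einsum("...ij,...j->...i", transition_matrix, y) + xt
        return y, y

    _, ys = lax.scan(f, jnp.zeros_like(x[0]), x)
    return jnp.moveaxis(ys, 0, -2)

# e.g. a batch of 3 independent 2-dimensional chains of length 10:
# forward(jnp.ones((3, 2, 2)), jnp.ones((3, 10, 2))).shape == (3, 10, 2)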

-            x = y_t - (self.transition_matrix * y_tm1[..., None, :]).sum(axis=-1)
-            return y_tm1, x
+        def f(y, prev):
+            x = y - prev @ self.transition_matrix.T

Member:

similarly, jnp.swapaxes(self.transition_matrix, -1, -2)


         _, x = lax.scan(f, y[-1], jnp.roll(y, 1, axis=0).at[0].set(0), reverse=True)
         return jnp.moveaxis(x, 0, -2)

     def log_abs_det_jacobian(self, x: jnp.ndarray, y: jnp.ndarray, intermediates=None):
-        slogdet = jnp.linalg.slogdet(self.transition_matrix)
-        return jnp.broadcast_to(slogdet.logabsdet, x.shape[:-2]) * x.shape[-2]
+        return jnp.zeros_like(x, shape=x.shape[:-2])

Member:

Sounds reasonable to me; this is sort of a shear transformation, so the Jacobian determinant is 1.

Contributor (author):

Yes, because of the temporal nature of the transform, the Jacobian is triangular. Because x only appears additively, the diagonal entries are all one, which gives the unit Jacobian determinant.
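
A quick numerical check along these lines (a sketch only; the shapes and transition matrix are chosen for illustration): flattening a length-5, 2-dimensional sequence gives a 10x10 Jacobian whose log-determinant should come out as zero.

import jax
from jax import numpy as jnp
from numpyro.distributions.transforms import RecursiveLinearTransform

transform = RecursiveLinearTransform(jnp.array([[1.0, 1.0], [0.0, 1.0]]))
x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
# Jacobian of the flattened map R^10 -> R^10.
jacobian = jax.jacobian(lambda flat: transform(flat.reshape(5, 2)).ravel())(x.ravel())
# Block lower-triangular with identity blocks on the diagonal, so logabsdet ~= 0.
print(jnp.linalg.slogdet(jacobian).logabsdet)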

@fehiepsi (Member):
Thanks, @tillahoffmann!

@fehiepsi merged commit ad6861a into pyro-ppl:master on Mar 25, 2024
4 checks passed
@tillahoffmann deleted the recursive-linear branch on March 25, 2024 at 16:54
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request on May 6, 2024:
Add RecursiveLinearTransform for linear state space models. (pyro-ppl#1766)

* Format reparam module to comply with style guide.

* Add `RealFastFourierTransform` to documentation.

* Ignore `venv` directory for `update_headers.py` script.

* Ignore autogenerated documentation sources.

* Add numerical Jacobian check for bijective transforms.

* Add `RecursiveLinearTransform`.

* Use matrix multiplication operator and fix Jacobian.

* Use non-trivial transition matrix in test.

* Specify that transition matrices must be (batches of) square matrices.

* Fix `scan` implementation for batched transition matrices and add test.