Add RecursiveLinearTransform for linear state space models. #1766
numpyro/distributions/transforms.py
Outdated
are vectors and :math:`A` is a transition matrix. The series is initialized by
:math:`y_0 = 0`.

:param transition_matrix: Transition matrix :math:`A` for successive states.
Maybe 'matrix' -> square matrix for clarity. Currently, the bias x is time-dependent, but the matrix is constant. Do you plan to make the class name more verbose to reflect that?
> Maybe 'matrix' -> square matrix for clarity.

👍

> Currently, the bias x is time-dependent, but the matrix is constant. Do you plan to make the class name more verbose to reflect that?

Sorry, I explained poorly. The x here is the argument of the transform rather than the bias in the AffineTransform, for example. Keeping the transition matrix constant means that the transform can be applied to sequences of arbitrary length.
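To illustrate the point about sequence length, here is a minimal NumPy sketch of the recursion $y_t = A y_{t-1} + x_t$ with $y_0 = 0$. The function name `recursive_linear` is hypothetical, and this is a plain reference loop rather than the `lax.scan` implementation under review; it only shows that a single constant `A` can be applied to inputs with any number of time steps.

```python
import numpy as np


def recursive_linear(transition_matrix, x):
    """Reference loop for y_t = A @ y_{t-1} + x_t, starting from y_0 = 0.

    `x` has shape (n, p) with time along the first axis and
    `transition_matrix` has shape (p, p). Name and signature are illustrative.
    """
    y_prev = np.zeros(x.shape[-1])
    ys = []
    for x_t in x:
        y_prev = transition_matrix @ y_prev + x_t
        ys.append(y_prev)
    return np.stack(ys)


A = np.array([[1.0, 1.0], [0.0, 1.0]])
# The same constant A handles sequences of different lengths.
short = recursive_linear(A, np.ones((3, 2)))
long = recursive_linear(A, np.ones((7, 2)))
print(short.shape, long.shape)  # (3, 2) (7, 2)
```

Because the recursion only looks one step back, the first three states of the longer sequence coincide with the shorter one.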
x is the (time-dependent) bias (or noise if we place a normal distribution over it) of a linear dynamical model. Typically, people use other notations, something like x_t = A x_{t-1} + b_t.
Fair. I was using the x/y notation here to stick with the typical arguments of the Transform classes. Do you have an idea for a more explanatory name?
The name looks 👌 to me. :)
numpyro/distributions/transforms.py
Outdated
@@ -1347,26 +1347,25 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
         x = jnp.moveaxis(x, -2, 0)

         def f(y, x):
-            y = (self.transition_matrix * y[..., None, :]).sum(axis=-1) + x
+            y = y @ self.transition_matrix.T + x
I think you might want jnp.swapaxes(self.transition_matrix, -1, -2)
I've updated it to use einsum because we have shapes (..., p, p) for the transition matrices A and (..., n, p) for the states x. Inside the scan function, we are dealing with state of shape (..., p) because we're scanning along the n dimension. There may very well be a better way to do this.
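The shape argument above can be checked with a small sketch. NumPy is used here only as a stand-in (`np.einsum` and `np.swapaxes` follow the same conventions as their `jnp` counterparts), and the batch size, dimensions, and variable names are assumptions for illustration. For a batch of matrices, `.T` reverses all axes, whereas swapping only the last two axes or contracting with `einsum` gives the intended per-batch matrix-vector product.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, p = 4, 3
A = rng.normal(size=(batch, p, p))  # batched transition matrices, shape (..., p, p)
y = rng.normal(size=(batch, p))     # one time slice of the state, shape (..., p)

# Reference: explicit loop over the batch dimension.
expected = np.stack([A[b] @ y[b] for b in range(batch)])

# einsum contracts the last axis of A with the last axis of y, per batch element.
via_einsum = np.einsum("...ij,...j->...i", A, y)

# Transposing only the trailing two axes keeps the batch axis in place.
A_t = np.swapaxes(A, -1, -2)
via_swapaxes = (y[..., None, :] @ A_t)[..., 0, :]

# `.T` reverses every axis, so the batch axis ends up last.
print(A.T.shape, A_t.shape)  # (3, 3, 4) (4, 3, 3)
```

Both the `einsum` and `swapaxes` variants agree with the explicit loop; which one reads better is a matter of taste.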
numpyro/distributions/transforms.py
Outdated
-            x = y_t - (self.transition_matrix * y_tm1[..., None, :]).sum(axis=-1)
-            return y_tm1, x
+        def f(y, prev):
+            x = y - prev @ self.transition_matrix.T
similarly, jnp.swapaxes(self.transition_matrix, -1, -2)
         _, x = lax.scan(f, y[-1], jnp.roll(y, 1, axis=0).at[0].set(0), reverse=True)
         return jnp.moveaxis(x, 0, -2)

     def log_abs_det_jacobian(self, x: jnp.ndarray, y: jnp.ndarray, intermediates=None):
-        slogdet = jnp.linalg.slogdet(self.transition_matrix)
-        return jnp.broadcast_to(slogdet.logabsdet, x.shape[:-2]) * x.shape[-2]
+        return jnp.zeros_like(x, shape=x.shape[:-2])
sounds reasonable to me, this is sort of a shear transformation, so the Jacobian determinant is 1.
Yes, because of the temporal nature of the transform, the Jacobian is triangular. Because x only appears additively, the diagonal is one, leading to a unit Jacobian determinant.
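The unit-Jacobian argument can be verified numerically. The sketch below is a NumPy reference, not the PR's code, and `forward` and `inverse` are hypothetical helper names. Because the map from `x` to `y` is linear, its Jacobian can be assembled by pushing basis vectors through the recursion. With time-major flattening the result should be lower triangular with ones on the diagonal, so its log absolute determinant is zero regardless of `A`.

```python
import numpy as np


def forward(A, x):
    # y_t = A @ y_{t-1} + x_t with y_0 = 0; x has shape (n, p).
    y_prev = np.zeros(x.shape[-1])
    out = []
    for x_t in x:
        y_prev = A @ y_prev + x_t
        out.append(y_prev)
    return np.stack(out)


def inverse(A, y):
    # x_t = y_t - A @ y_{t-1}; shift y by one step and pad with zeros.
    y_prev = np.concatenate([np.zeros_like(y[:1]), y[:-1]])
    return y - y_prev @ A.T


rng = np.random.default_rng(1)
n, p = 5, 2
A = rng.normal(size=(p, p))
x = rng.normal(size=(n, p))

# Round trip: the inverse recovers the input.
assert np.allclose(inverse(A, forward(A, x)), x)

# The map is linear, so column k of the Jacobian is the image of basis vector k.
basis = np.eye(n * p).reshape(n * p, n, p)
jac = np.stack([forward(A, e).ravel() for e in basis], axis=1)

sign, logabsdet = np.linalg.slogdet(jac)
print(np.allclose(jac, np.tril(jac)), np.allclose(np.diag(jac), 1.0), logabsdet)
```

The printed log absolute determinant should be zero up to floating-point error, which matches returning zeros from `log_abs_det_jacobian`.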
55b24c7 to c99fec1
Thanks, @tillahoffmann!
…pl#1766)
* Format reparam module to comply with style guide.
* Add `RealFastFourierTransform` to documentation.
* Ignore `venv` directory for `update_headers.py` script.
* Ignore autogenerated documentation sources.
* Add numerical Jacobian check for bijective transforms.
* Add `RecursiveLinearTransform`.
* Use matrix multiplication operator and fix Jacobian.
* Use non-trivial transition matrix in test.
* Specify that transition matrices must (batches of) square matrices.
* Fix `scan` implementation for batched transition matrices and add test.
This PR adds a `RecursiveLinearTransform`, which is a linear transformation applied recursively such that $y_t = A y_{t - 1} + x_t$ for $t > 0$, where $x_t$ and $y_t$ are $p$-vectors and $A$ is a $p\times p$ transition matrix. The series is initialized by $y_0 = 0$.

This transform can be used to easily declare linear state space models, e.g., a Cauchy random walk or a Kalman-style model for a rocket with state `y = (position, velocity)`.

This PR also makes a few minor changes (happy to factor out if you prefer):

* Formats `ExplicitReparam` from "Add explicit reparametrizer." #1754 to comply with the stricter linting.
* Adds `RealFastFourierTransform` from "Add complex constraint and real Fourier transform." #1762 to the documentation.
* Ignores the `venv` directory in the `update_headers.py` script.
* Adds `docs/source/{examples,tutorials,getting_started.rst}` to `.gitignore`.
* Checks the `log_abs_det_jacobian` implementation using autodiff.
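The two models named in the description can be sketched in plain NumPy. This is an assumption-laden illustration rather than the PR's code: the helper `recursive_linear`, the unit time step in the rocket's transition matrix, and the noise scales are all invented here, and no NumPyro API is used.

```python
import numpy as np


def recursive_linear(A, x):
    # Hypothetical helper: y_t = A @ y_{t-1} + x_t with y_0 = 0, x of shape (n, p).
    y_prev = np.zeros(x.shape[-1])
    out = []
    for x_t in x:
        y_prev = A @ y_prev + x_t
        out.append(y_prev)
    return np.stack(out)


rng = np.random.default_rng(0)
n = 10

# Cauchy random walk: p = 1, A is the 1x1 identity, increments are Cauchy draws.
increments = rng.standard_cauchy(size=(n, 1))
walk = recursive_linear(np.eye(1), increments)
# With A = I the recursion reduces to a cumulative sum of the increments.
assert np.allclose(walk, np.cumsum(increments, axis=0))

# Rocket with state y = (position, velocity) and an assumed unit time step:
# position picks up the previous velocity, velocity follows a random walk.
A_rocket = np.array([[1.0, 1.0], [0.0, 1.0]])
noise = rng.normal(scale=[0.1, 1.0], size=(n, 2))
trajectory = recursive_linear(A_rocket, noise)
print(trajectory.shape)  # (10, 2)
```

In a probabilistic model the increments would be the base distribution's samples and the recursion would be the transform applied to them; the loop above only mirrors that structure.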