-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add complex constraint and real Fourier transform. #1762
Conversation
numpyro/distributions/constraints.py
Outdated
@@ -638,6 +639,11 @@ def feasible_like(self, prototype): | |||
return jax.numpy.zeros_like(prototype) | |||
|
|||
|
|||
class _Real(_Complex): | |||
def __call__(self, x): | |||
return super().__call__(x) & (jax.numpy.isreal(x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Surprisingly, adding the &
leads to some warnings not being emitted in test/test_distributions.py::test_distribution_constraints
, but only when jit
-ed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think isreal
is a jax function, which will generate a tracer under jit
.
One option is to switch to np like in PositiveDefinite. It is better to have separate instances for real and complex though. I don't think that current constraints work for complex numbers.
numpyro/distributions/transforms.py
Outdated
def __init__( | ||
self, | ||
shape=None, | ||
ndims=1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe using event_dim
to match patterns in constraints. Could you elaborate on why we need this attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The motivation is to specify the number of dimensions along which to transform, e.g., ndim=1
gives the one-dimensional rfft, ndim=2
gives the two-dimensional rfft, etc. Maybe event_ndims
or transform_ndims
for consistency with reinterpreted_batch_ndims
in IndependentTransform
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I didn't notice that you are supporting n-dimensional fft. Thanks for explaning!
Re event_ndims: looks better to me. We used event_dim
just to be consistent with initial distribution designs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I like transform_lengths
and transform_ndims
better. It is up to you. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've kept it with transform_ndims
and transform_shape
for now. I was thinking shape
could be good because that's the terminology numpy uses (although they just call the parameter s
).
numpyro/distributions/transforms.py
Outdated
def log_abs_det_jacobian( | ||
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None | ||
) -> jnp.ndarray: | ||
return 0.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you return the whole batch shape for this?
7f89919
to
984f02e
Compare
numpyro/distributions/transforms.py
Outdated
def domain(self) -> constraints.Constraint: | ||
return constraints._IndependentConstraint( | ||
constraints._Real(), self.transform_ndims | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe constraints.independent(constraints.real, ...)
numpyro/distributions/transforms.py
Outdated
@property | ||
def codomain(self) -> constraints.Constraint: | ||
return constraints._IndependentConstraint( | ||
constraints._Complex(), self.transform_ndims |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar to above
numpyro/distributions/transforms.py
Outdated
def log_abs_det_jacobian( | ||
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None | ||
) -> jnp.ndarray: | ||
return jnp.ones_like(x, shape=x.shape[: -self.transform_ndims]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: zeros_like
I think you need to broadcast x.shape[: -self.transform_ndims]
and y.shape[: -self.transform_ndims]
numpyro/distributions/constraints.py
Outdated
class _Real(_SingletonConstraint): | ||
def __call__(self, x): | ||
# XXX: consider to relax this condition to [-inf, inf] interval | ||
return (x == x) & (x != float("inf")) & (x != float("-inf")) & np.isreal(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you remove np.isreal
here? We don't check for real input across all constraints. If x is a tracer, np.isreal might not work.
Awesome work!! thanks, Till! |
numpyro/distributions/transforms.py
Outdated
shape = jnp.broadcast_shapes( | ||
x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] | ||
) | ||
return jnp.ones_like(x, shape=shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops, wouldn't this zeros_like?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, sorry. Fixed.
Thank you for the fast review and merging! I just realized we probably want to add the option to unpack the complex coefficients with shape One option is to add an argument |
Interesting, I feel that it's better to have a separate transform for the unpack version. |
That could look like class RealFastFourierUnpackTransform(Transform):
"""
:param size: Size of the last dimension of the transform, required because the size
cannot be inferred from the shape of the coefficients.
"""
def __init__(self, size):
self.size = size
... If we added it with an Edit: Or did you mean having a transform (which includes unpacking) that inherits from the transform implemented in this PR? |
* Add `complex` constraint. * Add real fast Fourier transform. * Remove redundant domain and codomain definitions. * Use numpy for `isreal` check. * Return `log_abs_det_jacobian` with correct batch shape and add test. * Update parameter names for `RealFastFourierTransform`. * Remove `isreal` check. * Broadcast shapes in `log_abs_det_jacobian`. * Update construction of domain and codomain for `RealFastFourierTransform`. * Fix incorrect Jacobian.
This PR
_Complex
constraint,_Real
constraint to inherit from_Complex
and check that values are indeed real,RealFastFourierTransform
.The
RealFastFourierTransform
interface differs slightly fromjax.numpy.fft.rfftn
to respect that batch dimensions precede event dimensions. I.e., rather than specifying the axes along which to transform, one specifies the number of dimensions to transform akin toreinterpreted_batch_ndims
in theIndependentTransform
. The interface does not support specifying thenorm
parameter becausejit
-ing isn't happy with string arguments (cf. jax-ml/jax#3045).The motivation for this PR is to implement fast Gaussian processes for stationary kernels (cf. section 4 in https://arxiv.org/pdf/2301.08836v4.pdf).
Thank you for taking the time to review my recent PRs. Let me know if they create too much noise in your inbox.