Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex constraint and real Fourier transform. #1762

Merged
merged 10 commits into from
Mar 16, 2024

Conversation

tillahoffmann
Copy link
Contributor

@tillahoffmann tillahoffmann commented Mar 15, 2024

This PR

  • adds a _Complex constraint,
  • modifies the _Real constraint to inherit from _Complex and check that values are indeed real,
  • adds RealFastFourierTransform.

The RealFastFourierTransform interface differs slightly from jax.numpy.fft.rfftn to respect that batch dimensions precede event dimensions. I.e., rather than specifying the axes along which to transform, one specifies the number of dimensions to transform akin to reinterpreted_batch_ndims in the IndependentTransform. The interface does not support specifying the norm parameter because jit-ing isn't happy with string arguments (cf. jax-ml/jax#3045).

The motivation for this PR is to implement fast Gaussian processes for stationary kernels (cf. section 4 in https://arxiv.org/pdf/2301.08836v4.pdf).

Thank you for taking the time to review my recent PRs. Let me know if they create too much noise in your inbox.

@@ -638,6 +639,11 @@ def feasible_like(self, prototype):
return jax.numpy.zeros_like(prototype)


class _Real(_Complex):
def __call__(self, x):
return super().__call__(x) & (jax.numpy.isreal(x))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surprisingly, adding the & leads to some warnings not being emitted in test/test_distributions.py::test_distribution_constraints, but only when jit-ed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think isreal is a jax function, which will generate a tracer under jit.

One option is to switch to np like in PositiveDefinite. It is better to have separate instances for real and complex though. I don't think that current constraints work for complex numbers.

def __init__(
self,
shape=None,
ndims=1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe using event_dim to match patterns in constraints. Could you elaborate on why we need this attribute?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivation is to specify the number of dimensions along which to transform, e.g., ndim=1 gives the one-dimensional rfft, ndim=2 gives the two-dimensional rfft, etc. Maybe event_ndims or transform_ndims for consistency with reinterpreted_batch_ndims in IndependentTransform?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I didn't notice that you are supporting n-dimensional fft. Thanks for explaning!
Re event_ndims: looks better to me. We used event_dim just to be consistent with initial distribution designs.

Copy link
Member

@fehiepsi fehiepsi Mar 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I like transform_lengths and transform_ndims better. It is up to you. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept it with transform_ndims and transform_shape for now. I was thinking shape could be good because that's the terminology numpy uses (although they just call the parameter s).

def log_abs_det_jacobian(
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
) -> jnp.ndarray:
return 0.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you return the whole batch shape for this?

numpyro/distributions/transforms.py Outdated Show resolved Hide resolved
numpyro/distributions/transforms.py Outdated Show resolved Hide resolved
@tillahoffmann tillahoffmann force-pushed the rfft branch 2 times, most recently from 7f89919 to 984f02e Compare March 15, 2024 18:20
def domain(self) -> constraints.Constraint:
return constraints._IndependentConstraint(
constraints._Real(), self.transform_ndims
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe constraints.independent(constraints.real, ...)

@property
def codomain(self) -> constraints.Constraint:
return constraints._IndependentConstraint(
constraints._Complex(), self.transform_ndims
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar to above

def log_abs_det_jacobian(
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
) -> jnp.ndarray:
return jnp.ones_like(x, shape=x.shape[: -self.transform_ndims])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: zeros_like

I think you need to broadcast x.shape[: -self.transform_ndims] and y.shape[: -self.transform_ndims]

class _Real(_SingletonConstraint):
def __call__(self, x):
# XXX: consider to relax this condition to [-inf, inf] interval
return (x == x) & (x != float("inf")) & (x != float("-inf")) & np.isreal(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove np.isreal here? We don't check for real input across all constraints. If x is a tracer, np.isreal might not work.

@fehiepsi
Copy link
Member

Awesome work!! thanks, Till!

shape = jnp.broadcast_shapes(
x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
)
return jnp.ones_like(x, shape=shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, wouldn't this zeros_like?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry. Fixed.

@fehiepsi fehiepsi merged commit 136f7c0 into pyro-ppl:master Mar 16, 2024
4 checks passed
@tillahoffmann tillahoffmann deleted the rfft branch March 16, 2024 15:20
@tillahoffmann
Copy link
Contributor Author

Thank you for the fast review and merging!

I just realized we probably want to add the option to unpack the complex coefficients with shape (..., n // 2 + 1) for a signal of shape (..., n) to a real tensor with shape (..., n) because the rest of the library is designed for real-valued tensors. This would be equivalent to what we're doing in Stan here.

One option is to add an argument unpack (or some more informative name) to the transform to indicate if real or complex values should be returned. Another is to add a separate unpacking transform. What do you think?

@fehiepsi
Copy link
Member

Interesting, I feel that it's better to have a separate transform for the unpack version.

@tillahoffmann
Copy link
Contributor Author

tillahoffmann commented Mar 16, 2024

That could look like

class RealFastFourierUnpackTransform(Transform):
    """
    :param size: Size of the last dimension of the transform, required because the size
        cannot be inferred from the shape of the coefficients.
    """
    def __init__(self, size):
        self.size = size

    ...

If we added it with an unpack parameter, we could infer the size from the argument x. I agree that having a separate transform seems like a good idea, but it would also be nice if one doesn't have to specify the shape ahead of time, e.g., if one wants to use the same transform object on different size input. Not sure which one is the better compromise.

Edit: Or did you mean having a transform (which includes unpacking) that inherits from the transform implemented in this PR?

OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* Add `complex` constraint.

* Add real fast Fourier transform.

* Remove redundant domain and codomain definitions.

* Use numpy for `isreal` check.

* Return `log_abs_det_jacobian` with correct batch shape and add test.

* Update parameter names for `RealFastFourierTransform`.

* Remove `isreal` check.

* Broadcast shapes in `log_abs_det_jacobian`.

* Update construction of domain and codomain for `RealFastFourierTransform`.

* Fix incorrect Jacobian.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants