Add HSGP contribution module #1794

Merged
merged 6 commits into from
May 13, 2024
200 changes: 200 additions & 0 deletions docs/source/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,203 @@ Stochastic Support
:undoc-members:
:show-inheritance:
:member-order: bysource


Hilbert Space Gaussian Processes Approximation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This module contains helper functions to use the Hilbert Space Gaussian Process (HSGP) approximation method
described in [1] and [2].

.. warning::
This module is experimental. Currently, it only supports Gaussian processes with one-dimensional inputs.


**Why do we need an approximation?**

Gaussian processes do not scale well with the number of data points, because exact inference requires inverting the
:math:`n \times n` kernel matrix. The computational complexity of the exact Gaussian process model is
:math:`\mathcal{O}(n^3)`, where :math:`n` is the number of data points. The HSGP approximation method reduces this
complexity to :math:`\mathcal{O}(mn + m)`, where :math:`m` is the number of basis functions used in the approximation.

**Approximation Strategy Steps:**

We strongly recommend reading [1] and [2] for a detailed explanation of the approximation method. In [3] you can find
a practical approach using NumPyro and PyMC.

Here we provide the main steps and ingredients of the approximation method:

1. Each stationary kernel :math:`k` has an associated spectral density :math:`S(\omega)`. There are closed formulas for the most common kernels. These formulas depend on the hyperparameters of the kernel (e.g. amplitudes and length scales).
2. We can approximate the spectral density :math:`S(\omega)` as a polynomial series in :math:`||\omega||^2`. We call :math:`\omega` the frequency.
3. We can interpret these polynomial terms as "powers" of the Laplacian operator. The key observation is that the Fourier transform of the negative Laplacian operator :math:`-\nabla^2` is :math:`||\omega||^2`.
4. Next, we impose Dirichlet boundary conditions on the Laplacian operator, which makes it self-adjoint with a discrete spectrum.
5. We identify the expansion in (2) with the sum of powers of the Laplacian operator in the eigenbasis of (4).

For the one-dimensional case, the approximation formula in the non-centered parameterization is:

.. math::

f(x) \approx \sum_{j = 1}^{m}
\overbrace{\color{red}{\left(S(\sqrt{\lambda_j})\right)^{1/2}}}^{\text{all hyperparameters are here!}}
\times
\underbrace{\color{blue}{\phi_{j}(x)}}_{\text{easy to compute!}}
\times
\overbrace{\color{green}{\beta_{j}}}^{\sim \: \text{Normal}(0,1)}

where :math:`\lambda_j` are the eigenvalues of the Laplacian operator, :math:`\phi_{j}(x)` are the eigenfunctions of the
Laplacian operator, and :math:`\beta_{j}` are the coefficients of the expansion (see Eq. (8) in [2]).
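
As a rough illustration of the formula above, the following sketch checks numerically that
:math:`\sum_{j=1}^{m} S(\sqrt{\lambda_j}) \phi_j(x) \phi_j(x')` reproduces the squared exponential kernel.
It assumes the Dirichlet eigenpairs on :math:`[-L, L]` given in [2], namely :math:`\sqrt{\lambda_j} = j\pi / (2L)` and
:math:`\phi_j(x) = \sqrt{1/L} \sin(\sqrt{\lambda_j}(x + L))`, and the one-dimensional spectral density
:math:`S(\omega) = \alpha^2 \sqrt{2\pi} \, \ell \exp(-\ell^2 \omega^2 / 2)`. The helper names are hypothetical and
written in plain Python; they are not the functions of this module.

```python
import math


def sqrt_eigenvalue(j: int, ell: float) -> float:
    """Square root of the j-th Dirichlet Laplacian eigenvalue on [-ell, ell]."""
    return j * math.pi / (2.0 * ell)


def eigenfunction(x: float, j: int, ell: float) -> float:
    """j-th Dirichlet Laplacian eigenfunction on [-ell, ell], evaluated at x."""
    return math.sqrt(1.0 / ell) * math.sin(sqrt_eigenvalue(j, ell) * (x + ell))


def spectral_density_se(w: float, alpha: float, length: float) -> float:
    """Assumed 1-D spectral density of the squared exponential kernel."""
    return (
        alpha**2 * math.sqrt(2.0 * math.pi) * length * math.exp(-0.5 * (length * w) ** 2)
    )


def kernel_se(x1: float, x2: float, alpha: float, length: float) -> float:
    """Exact squared exponential kernel."""
    return alpha**2 * math.exp(-0.5 * ((x1 - x2) / length) ** 2)


def kernel_hsgp(
    x1: float, x2: float, alpha: float, length: float, ell: float, m: int
) -> float:
    """Rank-m approximation: sum_j S(sqrt(lambda_j)) * phi_j(x1) * phi_j(x2)."""
    return sum(
        spectral_density_se(sqrt_eigenvalue(j, ell), alpha, length)
        * eigenfunction(x1, j, ell)
        * eigenfunction(x2, j, ell)
        for j in range(1, m + 1)
    )


# Illustrative values only: inputs live well inside the interval [-ell, ell].
alpha, length, ell, m = 1.0, 0.2, 1.3, 20
grid = [-0.5 + 0.05 * i for i in range(21)]
max_abs_error = max(
    abs(kernel_hsgp(a, b, alpha, length, ell, m) - kernel_se(a, b, alpha, length))
    for a in grid
    for b in grid
)
print(f"max abs error on the grid: {max_abs_error:.2e}")
```

In this sketch the hyperparameters enter only through ``spectral_density_se``, while the eigenfunctions depend only on
the inputs and on ``ell``, which is why the basis can be computed once and reused during sampling.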

.. note::
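The periodic kernel does not fit directly into this framework, since its spectrum is discrete rather than given by a
continuous spectral density. One can still adapt the approach and find a similar approximation formula.
See Appendix B in [2] for more details.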
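
A minimal sketch of the periodic case, under stated assumptions: the kernel is taken to be
:math:`k(\tau) = \alpha^2 \exp(-2 \sin^2(\omega_0 \tau / 2) / \ell^2)`, which admits the cosine series
:math:`k(\tau) = \sum_{j \geq 0} \tilde{q}_j^2 \cos(j \omega_0 \tau)` with weights built from modified Bessel functions of
the first kind. The helper names are hypothetical, and the Bessel function is evaluated with a truncated power series
in plain Python rather than with the implementation used by this module.

```python
import math


def bessel_i(order: int, z: float, terms: int = 40) -> float:
    """Modified Bessel function of the first kind via its truncated power series."""
    return sum(
        (z / 2.0) ** (2 * k + order) / (math.factorial(k) * math.factorial(k + order))
        for k in range(terms)
    )


def periodic_weights(alpha: float, length: float, m: int) -> list:
    """Series weights q_j^2 for j = 0, ..., m - 1 (assumed parameterization)."""
    a = length ** (-2)
    return [
        (1.0 if j == 0 else 2.0) * alpha**2 * bessel_i(j, a) / math.exp(a)
        for j in range(m)
    ]


def kernel_periodic(tau: float, alpha: float, length: float, w0: float) -> float:
    """Exact periodic kernel under the assumed parameterization."""
    return alpha**2 * math.exp(-2.0 * math.sin(0.5 * w0 * tau) ** 2 / length**2)


def kernel_periodic_series(
    tau: float, alpha: float, length: float, w0: float, m: int
) -> float:
    """Truncated cosine series with m terms."""
    q2 = periodic_weights(alpha, length, m)
    return sum(q2[j] * math.cos(j * w0 * tau) for j in range(m))


# Illustrative values only.
alpha, length, w0, m = 1.5, 1.0, 2.0 * math.pi, 10
taus = [0.05 * i for i in range(41)]
max_abs_error_periodic = max(
    abs(
        kernel_periodic_series(t, alpha, length, w0, m)
        - kernel_periodic(t, alpha, length, w0)
    )
    for t in taus
)
print(f"max abs error on the grid: {max_abs_error_periodic:.2e}")
```

Since :math:`\cos(j \omega_0 (x - x')) = \cos(j \omega_0 x)\cos(j \omega_0 x') + \sin(j \omega_0 x)\sin(j \omega_0 x')`,
each term of the series splits into a cosine feature and a sine feature, which is why the periodic approximation uses
both families of basis functions.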

**Example:**

Here is an example of how to use the HSGP approximation method with NumPyro. We will use the squared exponential kernel.
Other kernels can be used similarly.

.. code-block:: python

>>> from jax import random
>>> import jax.numpy as jnp

>>> import numpyro
>>> from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS


>>> def generate_synthetic_data(rng_key, start, stop: float, num, scale):
...     """Generate synthetic data."""
...     x = jnp.linspace(start=start, stop=stop, num=num)
...     y = jnp.sin(4 * jnp.pi * x) + jnp.sin(7 * jnp.pi * x)
...     y_obs = y + scale * random.normal(rng_key, shape=(num,))
...     return x, y_obs


>>> rng_key = random.PRNGKey(seed=42)
>>> rng_key, rng_subkey = random.split(rng_key)
>>> x, y_obs = generate_synthetic_data(
...     rng_key=rng_subkey, start=0, stop=1, num=80, scale=0.3
... )


>>> def model(x, ell, m, non_centered, y=None):
...     # --- Priors ---
...     alpha = numpyro.sample("alpha", dist.InverseGamma(concentration=12, rate=10))
...     length = numpyro.sample("length", dist.InverseGamma(concentration=6, rate=1))
...     noise = numpyro.sample("noise", dist.InverseGamma(concentration=12, rate=10))
...     # --- Parametrization ---
...     f = hsgp_squared_exponential(
...         x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered
...     )
...     # --- Likelihood ---
...     with numpyro.plate("data", x.shape[0]):
...         numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y)


>>> sampler = NUTS(model)
>>> mcmc = MCMC(sampler=sampler, num_warmup=500, num_samples=1_000, num_chains=2)

>>> rng_key, rng_subkey = random.split(rng_key)

>>> ell = 1.3
>>> m = 20
>>> non_centered = True

>>> mcmc.run(rng_subkey, x, ell, m, non_centered, y_obs)

>>> mcmc.print_summary()

mean std median 5.0% 95.0% n_eff r_hat
alpha 1.24 0.34 1.18 0.72 1.74 1804.01 1.00
beta[0] -0.10 0.66 -0.10 -1.24 0.92 1819.91 1.00
beta[1] 0.00 0.71 -0.01 -1.09 1.26 1872.82 1.00
beta[2] -0.05 0.69 -0.03 -1.09 1.16 2105.88 1.00
beta[3] 0.25 0.74 0.26 -0.98 1.42 2281.30 1.00
beta[4] -0.17 0.69 -0.17 -1.21 1.00 2551.39 1.00
beta[5] 0.09 0.75 0.10 -1.13 1.30 3232.13 1.00
beta[6] -0.49 0.75 -0.49 -1.65 0.82 3042.31 1.00
beta[7] 0.42 0.75 0.44 -0.78 1.65 2885.42 1.00
beta[8] 0.69 0.71 0.71 -0.48 1.82 2811.68 1.00
beta[9] -1.43 0.75 -1.40 -2.63 -0.21 2858.68 1.00
beta[10] 0.33 0.71 0.33 -0.77 1.51 2198.65 1.00
beta[11] 1.09 0.73 1.11 -0.23 2.18 2765.99 1.00
beta[12] -0.91 0.72 -0.91 -2.06 0.31 2586.53 1.00
beta[13] 0.05 0.70 0.04 -1.16 1.12 2569.59 1.00
beta[14] -0.44 0.71 -0.44 -1.58 0.73 2626.09 1.00
beta[15] 0.69 0.73 0.70 -0.45 1.88 2626.32 1.00
beta[16] 0.98 0.74 0.98 -0.15 2.28 2282.86 1.00
beta[17] -2.54 0.77 -2.52 -3.82 -1.29 3347.56 1.00
beta[18] 1.35 0.66 1.35 0.30 2.46 2638.17 1.00
beta[19] 1.10 0.54 1.09 0.25 2.01 2428.37 1.00
length 0.07 0.01 0.07 0.06 0.09 2321.67 1.00
noise 0.33 0.03 0.33 0.29 0.38 2472.83 1.00

Number of divergences: 0


.. note::
Additional examples with code can be found in [3], [4] and [5].

**References:**

1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
Stat Comput 30, 419-446 (2020).

2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

3. `Orduz, J., A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods <https://juanitorduz.github.io/hsgp_intro>`_.

4. `Example: Hilbert space approximation for Gaussian processes <https://num.pyro.ai/en/stable/examples/hsgp.html>`_.

5. `Gelman, Vehtari, Simpson, et al., Bayesian workflow book - Birthdays <https://avehtari.github.io/casestudies/Birthdays/birthdays.html>`_.

.. note::
The code of this module is based on the code of the example
`Example: Hilbert space approximation for Gaussian processes <https://num.pyro.ai/en/stable/examples/hsgp.html>`_ by `Omar Sosa Rodríguez <https://github.com/omarfsosa>`_.

sqrt_eigenvalues
----------------
.. autofunction:: numpyro.contrib.hsgp.laplacian.sqrt_eigenvalues

eigenfunctions
--------------
.. autofunction:: numpyro.contrib.hsgp.laplacian.eigenfunctions

eigenfunctions_periodic
-----------------------
.. autofunction:: numpyro.contrib.hsgp.laplacian.eigenfunctions_periodic

spectral_density_squared_exponential
------------------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.spectral_density_squared_exponential

spectral_density_matern
-----------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.spectral_density_matern

diag_spectral_density_squared_exponential
-----------------------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_squared_exponential

diag_spectral_density_matern
----------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_matern

diag_spectral_density_periodic
------------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_periodic

hsgp_squared_exponential
------------------------
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_squared_exponential

hsgp_matern
-----------
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_matern

hsgp_periodic_non_centered
--------------------------
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_periodic_non_centered
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ NumPyro documentation
tutorials/bad_posterior_geometry
tutorials/truncated_distributions
tutorials/censoring
tutorials/hsgp_example

.. nbgallery::
:maxdepth: 1
Expand Down
940 changes: 940 additions & 0 deletions notebooks/source/hsgp_example.ipynb

Large diffs are not rendered by default.

Empty file.
179 changes: 179 additions & 0 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This module contains the low-rank approximation functions of the Hilbert space Gaussian process.
"""

from jaxlib.xla_extension import ArrayImpl

import jax.numpy as jnp

import numpyro
from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_matern,
    diag_spectral_density_periodic,
    diag_spectral_density_squared_exponential,
)
import numpyro.distributions as dist


def _non_centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
    with numpyro.plate("basis", m):
        beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

    return phi @ (spd * beta)


def _centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
    with numpyro.plate("basis", m):
        beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

    return phi @ beta


def linear_approximation(
    phi: ArrayImpl, spd: ArrayImpl, m: int, non_centered: bool = True
) -> ArrayImpl:
    """
    Linear approximation formula of the Hilbert space Gaussian process.

    See Eq. (8) in [1].

    **References:**

        1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl phi: Laplacian eigenfunctions
    :param ArrayImpl spd: square root of the diagonal of the spectral density evaluated at the
        square roots of the first `m` eigenvalues.
    :param int m: number of eigenfunctions in the approximation
    :param bool non_centered: whether to use a non-centered parameterization
    :return: The low-rank approximation linear model
    :rtype: ArrayImpl
    """
    if non_centered:
        return _non_centered_approximation(phi, spd, m)
    return _centered_approximation(phi, spd, m)


def hsgp_squared_exponential(
    x: ArrayImpl,
    alpha: float,
    length: float,
    ell: float,
    m: int,
    non_centered: bool = True,
) -> ArrayImpl:
    """
    Hilbert space Gaussian process approximation using the squared exponential kernel.

    The main idea of the approach is to combine the associated spectral density of the
    squared exponential kernel and the spectrum of the Dirichlet Laplacian operator to
    obtain a low-rank approximation of the Gram matrix. For more details see [1, 2].

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020).

        2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl x: input data
    :param float alpha: amplitude of the squared exponential kernel
    :param float length: length scale of the squared exponential kernel
    :param float ell: The length of the interval divided by 2
    :param int m: The number of eigenvalues to compute
    :param bool non_centered: Whether to use a non-centered parameterization. By default, it is set to True.
    :return: The low-rank approximation linear model
    :rtype: ArrayImpl
    """
    phi = eigenfunctions(x=x, ell=ell, m=m)
    spd = jnp.sqrt(
        diag_spectral_density_squared_exponential(
            alpha=alpha, length=length, ell=ell, m=m
        )
    )
    return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)


def hsgp_matern(
    x: ArrayImpl,
    nu: float,
    alpha: float,
    length: float,
    ell: float,
    m: int,
    non_centered: bool = True,
) -> ArrayImpl:
    """
    Hilbert space Gaussian process approximation using the Matérn kernel.

    The main idea of the approach is to combine the associated spectral density of the
    Matérn kernel and the spectrum of the Dirichlet Laplacian operator to obtain
    a low-rank approximation of the Gram matrix. For more details see [1, 2].

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020).

        2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl x: input data
    :param float nu: smoothness parameter
    :param float alpha: amplitude of the Matérn kernel
    :param float length: length scale of the Matérn kernel
    :param float ell: The length of the interval divided by 2
    :param int m: The number of eigenvalues to compute
    :param bool non_centered: Whether to use a non-centered parameterization. By default, it is set to True.
    :return: The low-rank approximation linear model
    :rtype: ArrayImpl
    """
    phi = eigenfunctions(x=x, ell=ell, m=m)
    spd = jnp.sqrt(
        diag_spectral_density_matern(nu=nu, alpha=alpha, length=length, ell=ell, m=m)
    )
    return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)


def hsgp_periodic_non_centered(
    x: ArrayImpl, alpha: float, length: float, w0: float, m: int
) -> ArrayImpl:
    """
    Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization.

    See Appendix B in [1].

    **References:**

        1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl x: input data
    :param float alpha: amplitude
    :param float length: length scale
    :param float w0: frequency of the periodic kernel
    :param int m: number of eigenfunctions in the approximation
    :return: The low-rank approximation linear model
    :rtype: ArrayImpl
    """
    q2 = diag_spectral_density_periodic(alpha=alpha, length=length, m=m)
    cosines, sines = eigenfunctions_periodic(x=x, w0=w0, m=m)

    with numpyro.plate("cos_basis", m):
        beta_cos = numpyro.sample("beta_cos", dist.Normal(0, 1))

    with numpyro.plate("sin_basis", m - 1):
        beta_sin = numpyro.sample("beta_sin", dist.Normal(0, 1))

    # The first eigenfunction for the sine component
    # is zero, so the first parameter wouldn't contribute to the approximation.
    # We set it to zero to identify the model and avoid divergences.
    zero = jnp.array([0.0])
    beta_sin = jnp.concatenate((zero, beta_sin))

    return cosines @ (q2 * beta_cos) + sines @ (q2 * beta_sin)
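
A note on the two parameterizations in ``_centered_approximation`` and ``_non_centered_approximation``: both place the
same prior on the scaled basis weights, and differ only in which variable the sampler sees. The sketch below
illustrates this with the standard library alone, for a single basis function and a made-up value of ``spd``;
it is an illustration of the idea, not part of the module.

```python
import random
import statistics

rng = random.Random(0)

# Hypothetical value of sqrt(S(sqrt(lambda_j))) for a single basis function.
spd = 0.3
num_draws = 20_000

# Centered: the weight is sampled directly with scale spd.
weights_centered = [rng.gauss(0.0, spd) for _ in range(num_draws)]

# Non-centered: a standard normal variable is sampled and then scaled by spd.
weights_non_centered = [spd * rng.gauss(0.0, 1.0) for _ in range(num_draws)]

std_centered = statistics.pstdev(weights_centered)
std_non_centered = statistics.pstdev(weights_non_centered)
print(f"centered std: {std_centered:.3f}, non-centered std: {std_non_centered:.3f}")
```

Both sample standard deviations should be close to ``spd``. In the non-centered form the sampled variable ``beta`` has
unit scale regardless of the kernel hyperparameters, which often gives the sampler an easier posterior geometry; this
is presumably why ``non_centered=True`` is the default.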