Add HSGP contribution module (pyro-ppl#1794)
* hsgp_init

* add licesnse

* simplyfy function names and add author reference

* feedback part 3

* feedback part 4

* fix name
juanitorduz authored May 13, 2024
1 parent 5eb134d commit 753553f
Showing 10 changed files with 1,977 additions and 0 deletions.
202 changes: 202 additions & 0 deletions docs/source/contrib.rst
@@ -97,3 +97,205 @@ Stochastic Support
:undoc-members:
:show-inheritance:
:member-order: bysource


Hilbert Space Gaussian Processes Approximation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This module contains helper functions for use in the Hilbert Space Gaussian Process (HSGP) approximation method
described in [1] and [2].

.. warning::
    This module is experimental. Currently, it only supports Gaussian processes with one-dimensional inputs.


**Why do we need an approximation?**

Gaussian processes do not scale well with the number of data points, because exact inference requires inverting
(or factorizing) the kernel matrix. The computational complexity of the Gaussian process model is therefore
:math:`\mathcal{O}(n^3)`, where :math:`n` is the number of data points. The HSGP approximation reduces this
to :math:`\mathcal{O}(mn + m)`, where :math:`m` is the number of basis functions used in the approximation.

**Approximation Strategy Steps:**

We strongly recommend reading [1] and [2] for a detailed explanation of the approximation method. In [3] you can find
a practical approach using NumPyro and PyMC.

Here we provide the main steps and ingredients of the approximation method:

1. Each stationary kernel :math:`k` has an associated spectral density :math:`S(\omega)`. There are closed formulas for the most common kernels. These formulas depend on the hyperparameters of the kernel (e.g. amplitudes and length scales).
2. We can approximate the spectral density :math:`S(\omega)` as a polynomial series in :math:`||\omega||^2`. We call :math:`\omega` the frequency.
3. We can interpret these polynomial terms as "powers" of the Laplacian operator. The key observation is that the Fourier transform of the (negative) Laplacian operator is :math:`||\omega||^2`.
4. Next, we impose Dirichlet boundary conditions on the Laplacian operator, which makes it self-adjoint with a discrete spectrum.
5. We identify the expansion in (2) with the sum of powers of the Laplacian operator in the eigenbasis of (4).
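
As a minimal sketch of step (4), assume the one-dimensional box :math:`[-L, L]`. The Dirichlet Laplacian then has the eigenpairs :math:`\sqrt{\lambda_j} = j\pi / (2L)` and :math:`\phi_j(x) = \sqrt{1/L}\,\sin(\sqrt{\lambda_j}(x + L))`. The helper names below are hypothetical and written in plain Python for illustration only; they are not this module's API. The check confirms numerically that the eigenfunctions vanish at the boundaries and are orthonormal.

```python
import math


def sqrt_eigenvalue(j: int, ell: float) -> float:
    """Square root of the j-th Dirichlet Laplacian eigenvalue on [-ell, ell] (j = 1, 2, ...)."""
    return j * math.pi / (2.0 * ell)


def eigenfunction(x: float, j: int, ell: float) -> float:
    """j-th Dirichlet Laplacian eigenfunction on [-ell, ell], evaluated at x."""
    return math.sqrt(1.0 / ell) * math.sin(sqrt_eigenvalue(j, ell) * (x + ell))


def inner_product(i: int, j: int, ell: float, num: int = 2_000) -> float:
    """Midpoint-rule approximation of the L2 inner product of phi_i and phi_j on [-ell, ell]."""
    h = 2.0 * ell / num
    total = 0.0
    for k in range(num):
        x = -ell + (k + 0.5) * h
        total += eigenfunction(x, i, ell) * eigenfunction(x, j, ell)
    return total * h


ell = 1.3  # illustrative half-width of the box (an assumption)
indices = (1, 2, 3)

# Dirichlet boundary conditions: every eigenfunction vanishes at -ell and ell.
boundary_values = [eigenfunction(s * ell, j, ell) for j in indices for s in (-1.0, 1.0)]

# Orthonormality: the Gram matrix of the first eigenfunctions is close to the identity.
gram = [[inner_product(i, j, ell) for j in indices] for i in indices]
```

Because the spectrum is discrete, only the values :math:`S(\sqrt{\lambda_j})` of the spectral density enter the approximation.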

For the one-dimensional case, the approximation formula in the non-centered parameterization is:

.. math::

    f(x) \approx \sum_{j = 1}^{m}
    \overbrace{\color{red}{\left(S(\sqrt{\lambda_j})\right)^{1/2}}}^{\text{all hyperparameters are here!}}
    \times
    \underbrace{\color{blue}{\phi_{j}(x)}}_{\text{easy to compute!}}
    \times
    \overbrace{\color{green}{\beta_{j}}}^{\sim \: \text{Normal}(0,1)}

where :math:`\lambda_j` are the eigenvalues of the Laplacian operator, :math:`\phi_{j}(x)` are the eigenfunctions of the
Laplacian operator, and :math:`\beta_{j}` are the coefficients of the expansion (see Eq. (8) in [2]). We expect this
to be a good approximation for a finite number :math:`m` of terms in the series as long as the input values :math:`x`
are not too close to the boundaries ``-ell`` and ``ell``.
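
The formula implies the low-rank kernel approximation :math:`k(x, x') \approx \sum_{j=1}^{m} S(\sqrt{\lambda_j}) \phi_j(x) \phi_j(x')`. The plain-Python sketch below compares it with the exact squared exponential kernel. It assumes the standard one-dimensional spectral density :math:`S(\omega) = \alpha^2 \sqrt{2\pi}\, \ell \exp(-\ell^2 \omega^2 / 2)` and the Dirichlet eigenpairs on ``[-ell, ell]``. The function names and the parameter values are illustrative assumptions, not part of this module.

```python
import math


def spectral_density_se(w: float, alpha: float, length: float) -> float:
    """Spectral density of the 1-D squared exponential kernel at frequency w."""
    return alpha**2 * math.sqrt(2.0 * math.pi) * length * math.exp(-0.5 * (length * w) ** 2)


def kernel_exact(x1: float, x2: float, alpha: float, length: float) -> float:
    """Exact squared exponential kernel."""
    return alpha**2 * math.exp(-0.5 * ((x1 - x2) / length) ** 2)


def kernel_hsgp(x1: float, x2: float, alpha: float, length: float, ell: float, m: int) -> float:
    """Rank-m approximation: sum_j S(sqrt(lambda_j)) * phi_j(x1) * phi_j(x2)."""
    total = 0.0
    for j in range(1, m + 1):
        sqrt_lam = j * math.pi / (2.0 * ell)  # square root of the j-th eigenvalue
        phi1 = math.sqrt(1.0 / ell) * math.sin(sqrt_lam * (x1 + ell))
        phi2 = math.sqrt(1.0 / ell) * math.sin(sqrt_lam * (x2 + ell))
        total += spectral_density_se(sqrt_lam, alpha, length) * phi1 * phi2
    return total


# Illustrative values (assumptions, not taken from the example on this page).
alpha, length, ell, m = 1.0, 0.3, 1.5, 40

# Inputs well inside [-ell, ell]: the approximation tracks the exact kernel.
grid = [-0.5, -0.25, 0.0, 0.25, 0.5]
max_error = max(
    abs(kernel_hsgp(a, b, alpha, length, ell, m) - kernel_exact(a, b, alpha, length))
    for a in grid
    for b in grid
)

# Input close to the boundary: the approximate variance collapses towards zero.
variance_near_boundary = kernel_hsgp(1.45, 1.45, alpha, length, ell, m)
```

With these values the error on the interior grid should be negligible, whereas the approximate variance near the boundary falls well below the exact value ``alpha**2``. This is why the inputs should stay away from ``-ell`` and ``ell``.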

.. note::
    The periodic kernel does not fit directly into the spectral density framework described above, but one can
    still adapt the method and find a similar approximation formula. See Appendix B in [2] for more details.

**Example:**

Here is an example of how to use the HSGP approximation method with NumPyro. We will use the squared exponential kernel.
Other kernels can be used similarly.

.. code-block:: python

    >>> from jax import random
    >>> import jax.numpy as jnp

    >>> import numpyro
    >>> from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
    >>> import numpyro.distributions as dist
    >>> from numpyro.infer import MCMC, NUTS

    >>> def generate_synthetic_data(rng_key, start, stop, num, scale):
    ...     """Generate synthetic data."""
    ...     x = jnp.linspace(start=start, stop=stop, num=num)
    ...     y = jnp.sin(4 * jnp.pi * x) + jnp.sin(7 * jnp.pi * x)
    ...     y_obs = y + scale * random.normal(rng_key, shape=(num,))
    ...     return x, y_obs

    >>> rng_key = random.PRNGKey(seed=42)
    >>> rng_key, rng_subkey = random.split(rng_key)
    >>> x, y_obs = generate_synthetic_data(
    ...     rng_key=rng_subkey, start=0, stop=1, num=80, scale=0.3
    ... )

    >>> def model(x, ell, m, non_centered, y=None):
    ...     # --- Priors ---
    ...     alpha = numpyro.sample("alpha", dist.InverseGamma(concentration=12, rate=10))
    ...     length = numpyro.sample("length", dist.InverseGamma(concentration=6, rate=1))
    ...     noise = numpyro.sample("noise", dist.InverseGamma(concentration=12, rate=10))
    ...     # --- Parametrization ---
    ...     f = hsgp_squared_exponential(
    ...         x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered
    ...     )
    ...     # --- Likelihood ---
    ...     with numpyro.plate("data", x.shape[0]):
    ...         numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y)

    >>> sampler = NUTS(model)
    >>> mcmc = MCMC(sampler=sampler, num_warmup=500, num_samples=1_000, num_chains=2)

    >>> rng_key, rng_subkey = random.split(rng_key)

    >>> ell = 1.3
    >>> m = 20
    >>> non_centered = True

    >>> mcmc.run(rng_subkey, x, ell, m, non_centered, y_obs)

    >>> mcmc.print_summary()

                    mean       std    median      5.0%     95.0%     n_eff     r_hat
         alpha      1.24      0.34      1.18      0.72      1.74   1804.01      1.00
       beta[0]     -0.10      0.66     -0.10     -1.24      0.92   1819.91      1.00
       beta[1]      0.00      0.71     -0.01     -1.09      1.26   1872.82      1.00
       beta[2]     -0.05      0.69     -0.03     -1.09      1.16   2105.88      1.00
       beta[3]      0.25      0.74      0.26     -0.98      1.42   2281.30      1.00
       beta[4]     -0.17      0.69     -0.17     -1.21      1.00   2551.39      1.00
       beta[5]      0.09      0.75      0.10     -1.13      1.30   3232.13      1.00
       beta[6]     -0.49      0.75     -0.49     -1.65      0.82   3042.31      1.00
       beta[7]      0.42      0.75      0.44     -0.78      1.65   2885.42      1.00
       beta[8]      0.69      0.71      0.71     -0.48      1.82   2811.68      1.00
       beta[9]     -1.43      0.75     -1.40     -2.63     -0.21   2858.68      1.00
      beta[10]      0.33      0.71      0.33     -0.77      1.51   2198.65      1.00
      beta[11]      1.09      0.73      1.11     -0.23      2.18   2765.99      1.00
      beta[12]     -0.91      0.72     -0.91     -2.06      0.31   2586.53      1.00
      beta[13]      0.05      0.70      0.04     -1.16      1.12   2569.59      1.00
      beta[14]     -0.44      0.71     -0.44     -1.58      0.73   2626.09      1.00
      beta[15]      0.69      0.73      0.70     -0.45      1.88   2626.32      1.00
      beta[16]      0.98      0.74      0.98     -0.15      2.28   2282.86      1.00
      beta[17]     -2.54      0.77     -2.52     -3.82     -1.29   3347.56      1.00
      beta[18]      1.35      0.66      1.35      0.30      2.46   2638.17      1.00
      beta[19]      1.10      0.54      1.09      0.25      2.01   2428.37      1.00
        length      0.07      0.01      0.07      0.06      0.09   2321.67      1.00
         noise      0.33      0.03      0.33      0.29      0.38   2472.83      1.00

    Number of divergences: 0

.. note::
    Additional examples with code can be found in [3], [4] and [5].

**References:**

1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
   Stat Comput 30, 419-446 (2020).

2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
   approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

3. `Orduz, J., A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods <https://juanitorduz.github.io/hsgp_intro>`_.

4. `Example: Hilbert space approximation for Gaussian processes <https://num.pyro.ai/en/stable/examples/hsgp.html>`_.

5. `Gelman, Vehtari, Simpson, et al., Bayesian workflow book - Birthdays <https://avehtari.github.io/casestudies/Birthdays/birthdays.html>`_.

.. note::
    The code of this module is based on the code of the example
    `Example: Hilbert space approximation for Gaussian processes <https://num.pyro.ai/en/stable/examples/hsgp.html>`_ by `Omar Sosa Rodríguez <https://github.com/omarfsosa>`_.

sqrt_eigenvalues
----------------
.. autofunction:: numpyro.contrib.hsgp.laplacian.sqrt_eigenvalues

eigenfunctions
--------------
.. autofunction:: numpyro.contrib.hsgp.laplacian.eigenfunctions

eigenfunctions_periodic
-----------------------
.. autofunction:: numpyro.contrib.hsgp.laplacian.eigenfunctions_periodic

spectral_density_squared_exponential
------------------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.spectral_density_squared_exponential

spectral_density_matern
-----------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.spectral_density_matern

diag_spectral_density_squared_exponential
-----------------------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_squared_exponential

diag_spectral_density_matern
----------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_matern

diag_spectral_density_periodic
------------------------------
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_periodic

hsgp_squared_exponential
------------------------
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_squared_exponential

hsgp_matern
-----------
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_matern

hsgp_periodic_non_centered
--------------------------
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_periodic_non_centered
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -37,6 +37,7 @@ NumPyro documentation
tutorials/bad_posterior_geometry
tutorials/truncated_distributions
tutorials/censoring
tutorials/hsgp_example

.. nbgallery::
:maxdepth: 1
940 changes: 940 additions & 0 deletions notebooks/source/hsgp_example.ipynb

Large diffs are not rendered by default.

Empty file.
181 changes: 181 additions & 0 deletions numpyro/contrib/hsgp/approximation.py
@@ -0,0 +1,181 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This module contains the low-rank approximation functions of the Hilbert space Gaussian process.
"""

from jaxlib.xla_extension import ArrayImpl

import jax.numpy as jnp

import numpyro
from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_matern,
    diag_spectral_density_periodic,
    diag_spectral_density_squared_exponential,
)
import numpyro.distributions as dist


def _non_centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
    with numpyro.plate("basis", m):
        beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

    return phi @ (spd * beta)


def _centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
    with numpyro.plate("basis", m):
        beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

    return phi @ beta


def linear_approximation(
    phi: ArrayImpl, spd: ArrayImpl, m: int, non_centered: bool = True
) -> ArrayImpl:
    """
    Linear approximation formula of the Hilbert space Gaussian process.

    See Eq. (8) in [1].

    **References:**

        1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl phi: Laplacian eigenfunctions
    :param ArrayImpl spd: square root of the diagonal of the spectral density evaluated at the square
        roots of the first `m` eigenvalues
    :param int m: number of eigenfunctions in the approximation
    :param bool non_centered: whether to use a non-centered parameterization
    :return: the low-rank approximation linear model
    :rtype: ArrayImpl
    """
    if non_centered:
        return _non_centered_approximation(phi, spd, m)
    return _centered_approximation(phi, spd, m)


def hsgp_squared_exponential(
    x: ArrayImpl,
    alpha: float,
    length: float,
    ell: float,
    m: int,
    non_centered: bool = True,
) -> ArrayImpl:
    """
    Hilbert space Gaussian process approximation using the squared exponential kernel.

    The main idea of the approach is to combine the associated spectral density of the
    squared exponential kernel and the spectrum of the Dirichlet Laplacian operator to
    obtain a low-rank approximation of the Gram matrix. For more details see [1, 2].

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020).

        2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl x: input data
    :param float alpha: amplitude of the squared exponential kernel
    :param float length: length scale of the squared exponential kernel
    :param float ell: positive value that parametrizes the length of the one-dimensional box so that the input data
        lies in the interval [-ell, ell]. We expect the approximation to be valid within this interval
    :param int m: number of eigenvalues to compute and include in the approximation
    :param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True
    :return: the low-rank approximation linear model
    :rtype: ArrayImpl
    """
    phi = eigenfunctions(x=x, ell=ell, m=m)
    spd = jnp.sqrt(
        diag_spectral_density_squared_exponential(
            alpha=alpha, length=length, ell=ell, m=m
        )
    )
    return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)


def hsgp_matern(
    x: ArrayImpl,
    nu: float,
    alpha: float,
    length: float,
    ell: float,
    m: int,
    non_centered: bool = True,
):
    """
    Hilbert space Gaussian process approximation using the Matérn kernel.

    The main idea of the approach is to combine the associated spectral density of the
    Matérn kernel and the spectrum of the Dirichlet Laplacian operator to obtain
    a low-rank approximation of the Gram matrix. For more details see [1, 2].

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020).

        2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl x: input data
    :param float nu: smoothness parameter
    :param float alpha: amplitude of the Matérn kernel
    :param float length: length scale of the Matérn kernel
    :param float ell: positive value that parametrizes the length of the one-dimensional box so that the input data
        lies in the interval [-ell, ell]. We expect the approximation to be valid within this interval.
    :param int m: number of eigenvalues to compute and include in the approximation
    :param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True.
    :return: the low-rank approximation linear model
    :rtype: ArrayImpl
    """
    phi = eigenfunctions(x=x, ell=ell, m=m)
    spd = jnp.sqrt(
        diag_spectral_density_matern(nu=nu, alpha=alpha, length=length, ell=ell, m=m)
    )
    return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)


def hsgp_periodic_non_centered(
    x: ArrayImpl, alpha: float, length: float, w0: float, m: int
) -> ArrayImpl:
    """
    Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization.

    See Appendix B in [1].

    **References:**

        1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayImpl x: input data
    :param float alpha: amplitude
    :param float length: length scale
    :param float w0: frequency of the periodic kernel
    :param int m: number of eigenvalues to compute and include in the approximation
    :return: the low-rank approximation linear model
    :rtype: ArrayImpl
    """
    q2 = diag_spectral_density_periodic(alpha=alpha, length=length, m=m)
    cosines, sines = eigenfunctions_periodic(x=x, w0=w0, m=m)

    with numpyro.plate("cos_basis", m):
        beta_cos = numpyro.sample("beta_cos", dist.Normal(0, 1))

    with numpyro.plate("sin_basis", m - 1):
        beta_sin = numpyro.sample("beta_sin", dist.Normal(0, 1))

    # The first eigenfunction for the sine component
    # is zero, so the first parameter wouldn't contribute to the approximation.
    # We set it to zero to identify the model and avoid divergences.
    zero = jnp.array([0.0])
    beta_sin = jnp.concatenate((zero, beta_sin))

    return cosines @ (q2 * beta_cos) + sines @ (q2 * beta_sin)