forked from pyro-ppl/numpyro
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HSGP contribution module (pyro-ppl#1794)
* hsgp_init * add licesnse * simplyfy function names and add author reference * feedback part 3 * feedback part 4 * fix name
- Loading branch information
1 parent
5eb134d
commit 753553f
Showing
10 changed files
with
1,977 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
""" | ||
This module contains the low-rank approximation functions of the Hilbert space Gaussian process. | ||
""" | ||
|
||
from jaxlib.xla_extension import ArrayImpl | ||
|
||
import jax.numpy as jnp | ||
|
||
import numpyro | ||
from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic | ||
from numpyro.contrib.hsgp.spectral_densities import ( | ||
diag_spectral_density_matern, | ||
diag_spectral_density_periodic, | ||
diag_spectral_density_squared_exponential, | ||
) | ||
import numpyro.distributions as dist | ||
|
||
|
||
def _non_centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl: | ||
with numpyro.plate("basis", m): | ||
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0)) | ||
|
||
return phi @ (spd * beta) | ||
|
||
|
||
def _centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl: | ||
with numpyro.plate("basis", m): | ||
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd)) | ||
|
||
return phi @ beta | ||
|
||
|
||
def linear_approximation( | ||
phi: ArrayImpl, spd: ArrayImpl, m: int, non_centered: bool = True | ||
) -> ArrayImpl: | ||
""" | ||
Linear approximation formula of the Hilbert space Gaussian process. | ||
See Eq. (8) in [1]. | ||
**References:** | ||
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space | ||
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023). | ||
:param ArrayImpl phi: laplacian eigenfunctions | ||
:param ArrayImpl spd: square root of the diagonal of the spectral density evaluated at square | ||
root of the first `m` eigenvalues. | ||
:param int m: number of eigenfunctions in the approximation | ||
:param bool non_centered: whether to use a non-centered parameterization | ||
:return: The low-rank approximation linear model | ||
:rtype: ArrayImpl | ||
""" | ||
if non_centered: | ||
return _non_centered_approximation(phi, spd, m) | ||
return _centered_approximation(phi, spd, m) | ||
|
||
|
||
def hsgp_squared_exponential( | ||
x: ArrayImpl, | ||
alpha: float, | ||
length: float, | ||
ell: float, | ||
m: int, | ||
non_centered: bool = True, | ||
) -> ArrayImpl: | ||
""" | ||
Hilbert space Gaussian process approximation using the squared exponential kernel. | ||
The main idea of the approach is to combine the associated spectral density of the | ||
squared exponential kernel and the spectrum of the Dirichlet Laplacian operator to | ||
obtain a low-rank approximation of the Gram matrix. For more details see [1, 2]. | ||
**References:** | ||
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. | ||
Stat Comput 30, 419-446 (2020). | ||
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space | ||
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023). | ||
:param ArrayImpl x: input data | ||
:param float alpha: amplitude of the squared exponential kernel | ||
:param float length: length scale of the squared exponential kernel | ||
:param float ell: positive value that parametrizes the length of the one-dimensional box so that the input data | ||
lies in the interval [-ell, ell]. We expect the approximation to be valid within this interval | ||
:param int m: number of eigenvalues to compute and include in the approximation | ||
:param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True | ||
:return: the low-rank approximation linear model | ||
:rtype: ArrayImpl | ||
""" | ||
phi = eigenfunctions(x=x, ell=ell, m=m) | ||
spd = jnp.sqrt( | ||
diag_spectral_density_squared_exponential( | ||
alpha=alpha, length=length, ell=ell, m=m | ||
) | ||
) | ||
return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered) | ||
|
||
|
||
def hsgp_matern( | ||
x: ArrayImpl, | ||
nu: float, | ||
alpha: float, | ||
length: float, | ||
ell: float, | ||
m: int, | ||
non_centered: bool = True, | ||
): | ||
""" | ||
Hilbert space Gaussian process approximation using the Matérn kernel. | ||
The main idea of the approach is to combine the associated spectral density of the | ||
Matérn kernel kernel and the spectrum of the Dirichlet Laplacian operator to obtain | ||
a low-rank approximation of the Gram matrix. For more details see [1, 2]. | ||
**References:** | ||
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. | ||
Stat Comput 30, 419-446 (2020). | ||
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space | ||
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023). | ||
:param ArrayImpl x: input data | ||
:param float nu: smoothness parameter | ||
:param float alpha: amplitude of the squared exponential kernel | ||
:param float length: length scale of the squared exponential kernel | ||
:param float ell: positive value that parametrizes the length of the one-dimensional box so that the input data | ||
lies in the interval [-ell, ell]. We expect the approximation to be valid within this interval. | ||
:param int m: number of eigenvalues to compute and include in the approximation | ||
:param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True. | ||
:return: the low-rank approximation linear model | ||
:rtype: ArrayImpl | ||
""" | ||
phi = eigenfunctions(x=x, ell=ell, m=m) | ||
spd = jnp.sqrt( | ||
diag_spectral_density_matern(nu=nu, alpha=alpha, length=length, ell=ell, m=m) | ||
) | ||
return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered) | ||
|
||
|
||
def hsgp_periodic_non_centered( | ||
x: ArrayImpl, alpha: float, length: float, w0: float, m: int | ||
) -> ArrayImpl: | ||
""" | ||
Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization. | ||
See Appendix B in [1]. | ||
**References:** | ||
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space | ||
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023). | ||
:param ArrayImpl x: input data | ||
:param float alpha: amplitude | ||
:param float length: length scale | ||
:param float w0: frequency of the periodic kernel | ||
:param int m: number of eigenvalues to compute and include in the approximation | ||
:return: the low-rank approximation linear model | ||
:rtype: ArrayImpl | ||
""" | ||
q2 = diag_spectral_density_periodic(alpha=alpha, length=length, m=m) | ||
cosines, sines = eigenfunctions_periodic(x=x, w0=w0, m=m) | ||
|
||
with numpyro.plate("cos_basis", m): | ||
beta_cos = numpyro.sample("beta_cos", dist.Normal(0, 1)) | ||
|
||
with numpyro.plate("sin_basis", m - 1): | ||
beta_sin = numpyro.sample("beta_sin", dist.Normal(0, 1)) | ||
|
||
# The first eigenfunction for the sine component | ||
# is zero, so the first parameter wouldn't contribute to the approximation. | ||
# We set it to zero to identify the model and avoid divergences. | ||
zero = jnp.array([0.0]) | ||
beta_sin = jnp.concatenate((zero, beta_sin)) | ||
|
||
return cosines @ (q2 * beta_cos) + sines @ (q2 * beta_sin) |
Oops, something went wrong.