Support vector lengthscales for RBF and Matern kernels #1819
Changes from 2 commits
@@ -0,0 +1,10 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+import numpy as np
+
+import jax
+
+ARRAY_TYPE = Union[jax.Array, np.ndarray]  # jax.Array covers tracers
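For context, a minimal usage sketch (hypothetical, not part of this diff): the alias can annotate inputs that may be NumPy arrays, concrete JAX arrays, or tracers, because `isinstance(x, jax.Array)` also matches traced values under `jax.jit`.

# Hypothetical sketch of how the alias could be used; not code from this PR.
from typing import Union

import numpy as np

import jax
import jax.numpy as jnp

ARRAY_TYPE = Union[jax.Array, np.ndarray]  # jax.Array covers tracers


def double(x: ARRAY_TYPE) -> jax.Array:
    # Accepts NumPy arrays, concrete JAX arrays, and tracers inside jit.
    assert isinstance(x, (jax.Array, np.ndarray))
    return jnp.asarray(x) * 2.0


double(np.ones(3))            # NumPy input
jax.jit(double)(jnp.ones(3))  # traced input: still passes the isinstance check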
@@ -74,23 +74,50 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]:


 @pytest.mark.parametrize(
-    argnames="x1, x2, length, ell",
+    argnames="x1, x2, length, ell, xfail",
     argvalues=[
-        (np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0),
+        (np.array([[1.0]]), np.array([[0.0]]), 1.0, 5.0, False),
         (
             np.array([[1.5, 1.25]]),
             np.array([[0.0, 0.0]]),
-            np.array([1.0]),
+            1.0,
             5.0,
+            False,
+        ),
+        (
+            np.array([[1.5, 1.25]]),
+            np.array([[0.0, 0.0]]),
+            np.array([1.0, 0.5]),
+            5.0,
+            False,
+        ),
+        (
+            np.array([[1.5, 1.25]]),
+            np.array([[0.0, 0.0]]),
+            np.array(
+                [[1.0, 0.5], [0.5, 1.0]]
+            ),  # different length scale for each point/dimension
+            5.0,
+            False,
+        ),
+        (
+            np.array([[1.5, 1.25, 1.0]]),
+            np.array([[0.0, 0.0, 0.0]]),
+            np.array([[1.0, 0.5], [0.5, 1.0]]),  # invalid length scale
+            5.0,
+            True,
         ),
     ],
     ids=[
-        "1d",
-        "2d,1d-length",
+        "1d,scalar-length",
+        "2d,scalar-length",
+        "2d,vector-length",
+        "2d,matrix-length",
+        "2d,invalid-length",
     ],
 )
 def test_kernel_approx_squared_exponential(
-    x1: ArrayImpl, x2: ArrayImpl, length: ArrayImpl, ell: float
+    x1: ArrayImpl, x2: ArrayImpl, length: float | ArrayImpl, ell: float, xfail: bool
 ):
     """ensure that the approximation of the squared exponential kernel is accurate,
     matching the exact kernel implementation from sklearn.
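The last case above pairs 3-dimensional inputs with a (2, 2) lengthscale; per the xfail handling added in the next hunk, the spectral-density helper is expected to reject it. A minimal sketch of that expectation (the import path is an assumption, not shown in this diff):

# Sketch of the "invalid length scale" case; import path is assumed.
import numpy as np
import pytest

from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)

length = np.array([[1.0, 0.5], [0.5, 1.0]])  # shape (2, 2) vs dim=3 inputs
with pytest.raises(ValueError):
    diag_spectral_density_squared_exponential(1.0, length, 5.0, 100, 3)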
@@ -100,13 +127,26 @@ def test_kernel_approx_squared_exponential(
     assert x1.shape == x2.shape
     m = 100  # large enough to ensure the approximation is accurate
     dim = x1.shape[-1]
+    if xfail:
+        with pytest.raises(ValueError):
+            diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)
+        return
     spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)

     eig_f1 = eigenfunctions(x1, ell=ell, m=m)
     eig_f2 = eigenfunctions(x2, ell=ell, m=m)
     approx = (eig_f1 * eig_f2) @ spd
-    exact = RBF(length)(x1, x2)
-    assert jnp.isclose(approx, exact, rtol=1e-3)
+
+    def _exact_rbf(length):
+        return RBF(length)(x1, x2).squeeze(axis=-1)
+
+    if isinstance(length, float | int):
+        exact = _exact_rbf(length)
+    elif length.ndim == 1:
+        exact = _exact_rbf(length)
+    else:
+        exact = np.apply_along_axis(_exact_rbf, axis=0, arr=length)
+    assert jnp.isclose(approx, exact, rtol=1e-3).all()


 @pytest.mark.parametrize(
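As a rough standalone illustration of what the hunk above verifies (a sketch, not the PR's test code; the numpyro.contrib.hsgp import locations are assumptions, everything else mirrors the diff): with this PR's changes, the low-rank approximation built from `eigenfunctions` and `diag_spectral_density_squared_exponential` should match sklearn's exact RBF kernel even when the lengthscale is a per-dimension vector.

# Sketch only: import locations are assumed; the numbers mirror the
# "2d,vector-length" case from the parametrization above.
import numpy as np
import jax.numpy as jnp
from sklearn.gaussian_process.kernels import RBF

from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)

x1 = np.array([[1.5, 1.25]])
x2 = np.array([[0.0, 0.0]])
length = np.array([1.0, 0.5])  # one lengthscale per input dimension
ell, m = 5.0, 100  # m large enough for the approximation to be accurate
dim = x1.shape[-1]

spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)
eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
approx = (eig_f1 * eig_f2) @ spd
exact = RBF(length)(x1, x2).squeeze(axis=-1)
assert jnp.isclose(approx, exact, rtol=1e-3).all()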
@@ -118,14 +158,32 @@ def test_kernel_approx_squared_exponential(
             np.array([[1.5, 1.25]]),
             np.array([[0.0, 0.0]]),
             3 / 2,
-            np.array([1.0]),
+            np.array([0.25, 0.5]),
             5.0,
         ),
         (
             np.array([[1.5, 1.25]]),
             np.array([[0.0, 0.0]]),
             5 / 2,
-            np.array([1.0]),
+            np.array([0.25, 0.5]),
             5.0,
         ),
+        (
+            np.array([[1.5, 1.25]]),
+            np.array([[0.0, 0.0]]),
+            3 / 2,
+            np.array(
+                [[1.0, 0.5], [0.5, 1.0]]
+            ),  # different length scale for each point/dimension
+            5.0,
+        ),
+        (
+            np.array([[1.5, 1.25]]),
+            np.array([[0.0, 0.0]]),
+            5 / 2,
+            np.array(
+                [[1.0, 0.5], [0.5, 1.0]]
+            ),  # different length scale for each point/dimension
+            5.0,
+        ),
     ],
@@ -134,6 +192,8 @@ def test_kernel_approx_squared_exponential(
         "1d,nu=5/2",
         "2d,nu=3/2,1d-length",
         "2d,nu=5/2,1d-length",
+        "2d,nu=3/2,2d-length",
+        "2d,nu=5/2,2d-length",
     ],
 )
 def test_kernel_approx_squared_matern(
@@ -154,8 +214,17 @@ def test_kernel_approx_squared_matern(
     eig_f1 = eigenfunctions(x1, ell=ell, m=m)
     eig_f2 = eigenfunctions(x2, ell=ell, m=m)
     approx = (eig_f1 * eig_f2) @ spd
-    exact = Matern(length_scale=length, nu=nu)(x1, x2)
-    assert jnp.isclose(approx, exact, rtol=1e-3)
+
+    def _exact_matern(length):
+        return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1)
+
+    if isinstance(length, float | int):
+        exact = _exact_matern(length)
+    elif length.ndim == 1:
+        exact = _exact_matern(length)
+    else:
+        exact = np.apply_along_axis(_exact_matern, axis=0, arr=length)
+    assert jnp.isclose(approx, exact, rtol=1e-3).all()


 @pytest.mark.parametrize(

Inline review thread attached to the `elif length.ndim == 1` branch above:

- why do you need two conditions for …
- (this is just a question, no need to use the OR statement in view of readability)
- I was copying the code from …
- either way fine by me
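For the `else` branch above, a small illustration (a sketch, not PR code; the values are taken from the parametrized cases) of how `np.apply_along_axis` with `axis=0` treats a 2-D lengthscale: each column is passed to the exact-kernel helper as one per-dimension lengthscale vector, and the results are stacked along that axis.

# Illustration of the np.apply_along_axis pattern used above; variable values
# mirror the parametrized cases, everything else is a sketch.
import numpy as np
from sklearn.gaussian_process.kernels import Matern

x1 = np.array([[1.5, 1.25]])
x2 = np.array([[0.0, 0.0]])
nu = 3 / 2
length = np.array([[1.0, 0.5], [0.5, 1.0]])  # column j = lengthscales for case j


def _exact_matern(length):
    return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1)


exact = np.apply_along_axis(_exact_matern, axis=0, arr=length)
# Equivalent to evaluating the kernel once per column of `length`:
manual = np.column_stack(
    [_exact_matern(length[:, j]) for j in range(length.shape[1])]
)
assert np.allclose(exact, manual)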
Review thread on the ARRAY_TYPE definition:

- I wonder if this type can be used in other NumPyro modules.
- Using `ArrayImpl` is only an issue when a model gets compiled and the arrays turn into tracers. `isinstance(X, jax.Array)` will work for both jax arrays and tracers.
- Some details here: https://jax.readthedocs.io/en/latest/jax_array_migration.html. I believe this is best practice for typing jax arrays (as of last year), but I am not sure.
- I am definitely not an expert in type hints, so following the recommendation from the docs seems the safest path :)
- Should I mark this thread as resolved, as this seems to be in line with the recommendation?
- works for me!
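A small check of that point (illustrative only, not from this PR; the printed type names can vary across JAX versions): a concrete array is an ArrayImpl, the value seen inside a jitted function is a tracer, and both satisfy `isinstance(x, jax.Array)`.

# Illustrative only. Type names printed here depend on the JAX version.
import jax
import jax.numpy as jnp


def show(x):
    print(type(x).__name__, isinstance(x, jax.Array))
    return x * 1.0


show(jnp.ones(2))           # e.g. "ArrayImpl True"          (concrete array)
jax.jit(show)(jnp.ones(2))  # e.g. "DynamicJaxprTracer True"  (traced value)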