Skip to content

Commit

Permalink
test args from jnp.array -> np.array
Browse files Browse the repository at this point in the history
  • Loading branch information
brendancooley committed Jun 9, 2024
1 parent 4fa787a commit c9c1457
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 40 deletions.
3 changes: 2 additions & 1 deletion numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

from jaxlib.xla_extension import ArrayImpl
import numpy as np

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -210,7 +211,7 @@ def _convert_ell(
"The length of ell must be equal to the dimension of the space."
)
ell_ = jnp.array(ell)[..., None] # dim x 1 array
elif isinstance(ell, jax.Array):
elif isinstance(ell, jax.Array | np.ndarray):
ell_ = ell
if ell_.shape != (dim, 1):
raise ValueError("ell must be a scalar or a list of length `dim`.")
Expand Down
1 change: 1 addition & 0 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def modified_bessel_first_kind(v, z):
) from e

v = jnp.asarray(v, dtype=float)
z = jnp.asarray(z, dtype=float)
return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z)


Expand Down
51 changes: 26 additions & 25 deletions test/contrib/hsgp/test_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from operator import mul
from typing import Literal

import numpy as np
import pytest
from sklearn.gaussian_process.kernels import RBF, ExpSineSquared, Matern

Expand Down Expand Up @@ -75,11 +76,11 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]:
@pytest.mark.parametrize(
argnames="x1, x2, length, ell",
argvalues=[
(jnp.array([[1.0]]), jnp.array([[0.0]]), jnp.array([1.0]), 5.0),
(np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0),
(
jnp.array([[1.5, 1.25]]),
jnp.array([[0.0, 0.0]]),
jnp.array([1.0]),
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array([1.0]),
5.0,
),
],
Expand Down Expand Up @@ -111,20 +112,20 @@ def test_kernel_approx_squared_exponential(
@pytest.mark.parametrize(
argnames="x1, x2, nu, length, ell",
argvalues=[
(jnp.array([[1.0]]), jnp.array([[0.0]]), 3 / 2, jnp.array([1.0]), 5.0),
(jnp.array([[1.0]]), jnp.array([[0.0]]), 5 / 2, jnp.array([1.0]), 5.0),
(np.array([[1.0]]), np.array([[0.0]]), 3 / 2, np.array([1.0]), 5.0),
(np.array([[1.0]]), np.array([[0.0]]), 5 / 2, np.array([1.0]), 5.0),
(
jnp.array([[1.5, 1.25]]),
jnp.array([[0.0, 0.0]]),
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
3 / 2,
jnp.array([1.0]),
np.array([1.0]),
5.0,
),
(
jnp.array([[1.5, 1.25]]),
jnp.array([[0.0, 0.0]]),
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
5 / 2,
jnp.array([1.0]),
np.array([1.0]),
5.0,
),
],
Expand Down Expand Up @@ -160,8 +161,8 @@ def test_kernel_approx_squared_matern(
@pytest.mark.parametrize(
argnames="x1, x2, w0, length",
argvalues=[
(jnp.array([1.0]), jnp.array([0.0]), 1.0, 1.0),
(jnp.array([1.0]), jnp.array([0.0]), 1.5, 1.0),
(np.array([1.0]), np.array([0.0]), 1.0, 1.0),
(np.array([1.0]), np.array([0.0]), 1.5, 1.0),
],
ids=[
"1d,w0=1.0",
Expand Down Expand Up @@ -199,10 +200,10 @@ def test_kernel_approx_periodic(
@pytest.mark.parametrize(
argnames="x, alpha, length, ell, m, non_centered",
argvalues=[
(jnp.linspace(0, 1, 10), 1.0, 0.2, 12, 10, True),
(jnp.linspace(0, 1, 10), 1.0, 0.2, 12, 10, False),
(jnp.linspace(0, 10, 100), 3.0, 0.5, 120, 100, True),
(jnp.linspace(jnp.zeros(2), jnp.ones(2), 10), 1.0, 0.2, 12, [3, 3], True),
(np.linspace(0, 1, 10), 1.0, 0.2, 12, 10, True),
(np.linspace(0, 1, 10), 1.0, 0.2, 12, 10, False),
(np.linspace(0, 10, 100), 3.0, 0.5, 120, 100, True),
(np.linspace(np.zeros(2), np.ones(2), 10), 1.0, 0.2, 12, [3, 3], True),
],
ids=["non_centered", "centered", "non_centered-large-domain", "non_centered-2d"],
)
Expand Down Expand Up @@ -242,11 +243,11 @@ def model(x, alpha, length, ell, m, non_centered):
@pytest.mark.parametrize(
argnames="x, nu, alpha, length, ell, m, non_centered",
argvalues=[
(jnp.linspace(0, 1, 10), 3 / 2, 1.0, 0.2, 12, 10, True),
(jnp.linspace(0, 1, 10), 5 / 2, 1.0, 0.2, 12, 10, False),
(jnp.linspace(0, 10, 100), 7 / 2, 3.0, 0.5, 120, 100, True),
(np.linspace(0, 1, 10), 3 / 2, 1.0, 0.2, 12, 10, True),
(np.linspace(0, 1, 10), 5 / 2, 1.0, 0.2, 12, 10, False),
(np.linspace(0, 10, 100), 7 / 2, 3.0, 0.5, 120, 100, True),
(
jnp.linspace(jnp.zeros(2), jnp.ones(2), 10),
np.linspace(np.zeros(2), np.ones(2), 10),
3 / 2,
1.0,
0.2,
Expand Down Expand Up @@ -420,9 +421,9 @@ def model(x, nu, ell, m, non_centered, y=None):
@pytest.mark.parametrize(
argnames="w0, m",
argvalues=[
(2 * jnp.pi / 7, 2),
(2 * jnp.pi / 10, 3),
(2 * jnp.pi / 5, 10),
(2 * np.pi / 7, 2),
(2 * np.pi / 10, 3),
(2 * np.pi / 5, 10),
],
ids=["m=2", "m=3", "m=10"],
)
Expand Down
17 changes: 9 additions & 8 deletions test/contrib/hsgp/test_laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import reduce
from operator import mul

import numpy as np
import pytest

from jax._src.array import ArrayImpl
Expand Down Expand Up @@ -96,13 +97,13 @@ def test_sqrt_eigenvalues(ell: float | int, m: int | list[int], dim: int):
@pytest.mark.parametrize(
argnames="x, ell, m",
argvalues=[
(jnp.linspace(0, 1, 10), 1, 1),
(jnp.linspace(-1, 1, 10), 1, 21),
(jnp.linspace(-2, -1, 10), 2, 10),
(jnp.linspace(0, 100, 500), 120, 100),
(jnp.linspace(jnp.zeros(3), jnp.ones(3), 10), 2, [2, 2, 3]),
(np.linspace(0, 1, 10), 1, 1),
(np.linspace(-1, 1, 10), 1, 21),
(np.linspace(-2, -1, 10), 2, 10),
(np.linspace(0, 100, 500), 120, 100),
(np.linspace(np.zeros(3), np.ones(3), 10), 2, [2, 2, 3]),
(
jnp.linspace(jnp.zeros(3), jnp.ones(3), 100).reshape((10, 10, 3)),
np.linspace(np.zeros(3), np.ones(3), 100).reshape((10, 10, 3)),
2,
[2, 2, 3],
),
Expand All @@ -129,8 +130,8 @@ def test_eigenfunctions(x: ArrayImpl, ell: float | int, m: int | list[int]):
(1, 1, False),
(1, 2, False),
([1, 1], 2, False),
(jnp.array([1, 1])[..., None], 2, False),
(jnp.array([1, 1]), 2, True),
(np.array([1, 1])[..., None], 2, False),
(np.array([1, 1]), 2, True),
([1, 1], 1, True),
],
ids=[
Expand Down
13 changes: 7 additions & 6 deletions test/contrib/hsgp/test_spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import reduce
from operator import mul

import numpy as np
import pytest

import jax.numpy as jnp
Expand All @@ -22,8 +23,8 @@
argnames="dim, w, alpha, length",
argvalues=[
(1, 0.1, 1.0, 0.2),
(2, jnp.array([0.1, 0.2]), 1.0, 0.2),
(3, jnp.array([0.1, 0.2, 0.3]), 1.0, 5.0),
(2, np.array([0.1, 0.2]), 1.0, 0.2),
(3, np.array([0.1, 0.2, 0.3]), 1.0, 5.0),
],
ids=["dim=1", "dim=2", "dim=3"],
)
Expand All @@ -39,8 +40,8 @@ def test_spectral_density_squared_exponential(dim, w, alpha, length):
argnames="dim, nu, w, alpha, length",
argvalues=[
(1, 3 / 2, 0.1, 1.0, 0.2),
(2, 5 / 2, jnp.array([0.1, 0.2]), 1.0, 0.2),
(3, 5 / 2, jnp.array([0.1, 0.2, 0.3]), 1.0, 5.0),
(2, 5 / 2, np.array([0.1, 0.2]), 1.0, 0.2),
(3, 5 / 2, np.array([0.1, 0.2, 0.3]), 1.0, 5.0),
],
ids=["dim=1", "dim=2", "dim=3"],
)
Expand Down Expand Up @@ -113,8 +114,8 @@ def test_modified_bessel_first_kind_one_dim(v, z):
@pytest.mark.parametrize(
argnames="v, z",
argvalues=[
(jnp.linspace(0.1, 1.0, 10), jnp.array([0.1])),
(jnp.linspace(0.1, 1.0, 10), jnp.linspace(0.1, 1.0, 10)),
(np.linspace(0.1, 1.0, 10), np.array([0.1])),
(np.linspace(0.1, 1.0, 10), np.linspace(0.1, 1.0, 10)),
],
ids=["z=0.1", "z=0.2"],
)
Expand Down

0 comments on commit c9c1457

Please sign in to comment.