From c9c1457cad9c2e230ab7f149ffa2925d119b3667 Mon Sep 17 00:00:00 2001 From: Brendan Cooley Date: Sun, 9 Jun 2024 11:53:05 -0400 Subject: [PATCH] test args from jnp.array -> np.array --- numpyro/contrib/hsgp/laplacian.py | 3 +- numpyro/contrib/hsgp/spectral_densities.py | 1 + test/contrib/hsgp/test_approximation.py | 51 ++++++++++---------- test/contrib/hsgp/test_laplacian.py | 17 ++++--- test/contrib/hsgp/test_spectral_densities.py | 13 ++--- 5 files changed, 45 insertions(+), 40 deletions(-) diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py index 99fcc2c78..3329c436d 100644 --- a/numpyro/contrib/hsgp/laplacian.py +++ b/numpyro/contrib/hsgp/laplacian.py @@ -8,6 +8,7 @@ from __future__ import annotations from jaxlib.xla_extension import ArrayImpl +import numpy as np import jax import jax.numpy as jnp @@ -210,7 +211,7 @@ def _convert_ell( "The length of ell must be equal to the dimension of the space." ) ell_ = jnp.array(ell)[..., None] # dim x 1 array - elif isinstance(ell, jax.Array): + elif isinstance(ell, jax.Array | np.ndarray): ell_ = ell if ell_.shape != (dim, 1): raise ValueError("ell must be a scalar or a list of length `dim`.") diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index 0af858540..0d4d3db3d 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -166,6 +166,7 @@ def modified_bessel_first_kind(v, z): ) from e v = jnp.asarray(v, dtype=float) + z = jnp.asarray(z, dtype=float) return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 718e7a093..a941652b1 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -7,6 +7,7 @@ from operator import mul from typing import Literal +import numpy as np import pytest from sklearn.gaussian_process.kernels import RBF, ExpSineSquared, Matern @@ -75,11 +76,11 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]: @pytest.mark.parametrize( argnames="x1, x2, length, ell", argvalues=[ - (jnp.array([[1.0]]), jnp.array([[0.0]]), jnp.array([1.0]), 5.0), + (np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0), ( - jnp.array([[1.5, 1.25]]), - jnp.array([[0.0, 0.0]]), - jnp.array([1.0]), + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + np.array([1.0]), 5.0, ), ], @@ -111,20 +112,20 @@ def test_kernel_approx_squared_exponential( @pytest.mark.parametrize( argnames="x1, x2, nu, length, ell", argvalues=[ - (jnp.array([[1.0]]), jnp.array([[0.0]]), 3 / 2, jnp.array([1.0]), 5.0), - (jnp.array([[1.0]]), jnp.array([[0.0]]), 5 / 2, jnp.array([1.0]), 5.0), + (np.array([[1.0]]), np.array([[0.0]]), 3 / 2, np.array([1.0]), 5.0), + (np.array([[1.0]]), np.array([[0.0]]), 5 / 2, np.array([1.0]), 5.0), ( - jnp.array([[1.5, 1.25]]), - jnp.array([[0.0, 0.0]]), + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), 3 / 2, - jnp.array([1.0]), + np.array([1.0]), 5.0, ), ( - jnp.array([[1.5, 1.25]]), - jnp.array([[0.0, 0.0]]), + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), 5 / 2, - jnp.array([1.0]), + np.array([1.0]), 5.0, ), ], @@ -160,8 +161,8 @@ def test_kernel_approx_squared_matern( @pytest.mark.parametrize( argnames="x1, x2, w0, length", argvalues=[ - (jnp.array([1.0]), jnp.array([0.0]), 1.0, 1.0), - (jnp.array([1.0]), jnp.array([0.0]), 1.5, 1.0), + (np.array([1.0]), np.array([0.0]), 1.0, 1.0), + (np.array([1.0]), np.array([0.0]), 1.5, 1.0), ], ids=[ "1d,w0=1.0", @@ -199,10 +200,10 @@ def test_kernel_approx_periodic( @pytest.mark.parametrize( argnames="x, alpha, length, ell, m, non_centered", argvalues=[ - (jnp.linspace(0, 1, 10), 1.0, 0.2, 12, 10, True), - (jnp.linspace(0, 1, 10), 1.0, 0.2, 12, 10, False), - (jnp.linspace(0, 10, 100), 3.0, 0.5, 120, 100, True), - (jnp.linspace(jnp.zeros(2), jnp.ones(2), 10), 1.0, 0.2, 12, [3, 3], True), + (np.linspace(0, 1, 10), 1.0, 0.2, 12, 10, True), + (np.linspace(0, 1, 10), 1.0, 0.2, 12, 10, False), + (np.linspace(0, 10, 100), 3.0, 0.5, 120, 100, True), + (np.linspace(np.zeros(2), np.ones(2), 10), 1.0, 0.2, 12, [3, 3], True), ], ids=["non_centered", "centered", "non_centered-large-domain", "non_centered-2d"], ) @@ -242,11 +243,11 @@ def model(x, alpha, length, ell, m, non_centered): @pytest.mark.parametrize( argnames="x, nu, alpha, length, ell, m, non_centered", argvalues=[ - (jnp.linspace(0, 1, 10), 3 / 2, 1.0, 0.2, 12, 10, True), - (jnp.linspace(0, 1, 10), 5 / 2, 1.0, 0.2, 12, 10, False), - (jnp.linspace(0, 10, 100), 7 / 2, 3.0, 0.5, 120, 100, True), + (np.linspace(0, 1, 10), 3 / 2, 1.0, 0.2, 12, 10, True), + (np.linspace(0, 1, 10), 5 / 2, 1.0, 0.2, 12, 10, False), + (np.linspace(0, 10, 100), 7 / 2, 3.0, 0.5, 120, 100, True), ( - jnp.linspace(jnp.zeros(2), jnp.ones(2), 10), + np.linspace(np.zeros(2), np.ones(2), 10), 3 / 2, 1.0, 0.2, @@ -420,9 +421,9 @@ def model(x, nu, ell, m, non_centered, y=None): @pytest.mark.parametrize( argnames="w0, m", argvalues=[ - (2 * jnp.pi / 7, 2), - (2 * jnp.pi / 10, 3), - (2 * jnp.pi / 5, 10), + (2 * np.pi / 7, 2), + (2 * np.pi / 10, 3), + (2 * np.pi / 5, 10), ], ids=["m=2", "m=3", "m=10"], ) diff --git a/test/contrib/hsgp/test_laplacian.py b/test/contrib/hsgp/test_laplacian.py index f7b79d295..2749a2c7d 100644 --- a/test/contrib/hsgp/test_laplacian.py +++ b/test/contrib/hsgp/test_laplacian.py @@ -6,6 +6,7 @@ from functools import reduce from operator import mul +import numpy as np import pytest from jax._src.array import ArrayImpl @@ -96,13 +97,13 @@ def test_sqrt_eigenvalues(ell: float | int, m: int | list[int], dim: int): @pytest.mark.parametrize( argnames="x, ell, m", argvalues=[ - (jnp.linspace(0, 1, 10), 1, 1), - (jnp.linspace(-1, 1, 10), 1, 21), - (jnp.linspace(-2, -1, 10), 2, 10), - (jnp.linspace(0, 100, 500), 120, 100), - (jnp.linspace(jnp.zeros(3), jnp.ones(3), 10), 2, [2, 2, 3]), + (np.linspace(0, 1, 10), 1, 1), + (np.linspace(-1, 1, 10), 1, 21), + (np.linspace(-2, -1, 10), 2, 10), + (np.linspace(0, 100, 500), 120, 100), + (np.linspace(np.zeros(3), np.ones(3), 10), 2, [2, 2, 3]), ( - jnp.linspace(jnp.zeros(3), jnp.ones(3), 100).reshape((10, 10, 3)), + np.linspace(np.zeros(3), np.ones(3), 100).reshape((10, 10, 3)), 2, [2, 2, 3], ), @@ -129,8 +130,8 @@ def test_eigenfunctions(x: ArrayImpl, ell: float | int, m: int | list[int]): (1, 1, False), (1, 2, False), ([1, 1], 2, False), - (jnp.array([1, 1])[..., None], 2, False), - (jnp.array([1, 1]), 2, True), + (np.array([1, 1])[..., None], 2, False), + (np.array([1, 1]), 2, True), ([1, 1], 1, True), ], ids=[ diff --git a/test/contrib/hsgp/test_spectral_densities.py b/test/contrib/hsgp/test_spectral_densities.py index 4794015e7..c51d9caf5 100644 --- a/test/contrib/hsgp/test_spectral_densities.py +++ b/test/contrib/hsgp/test_spectral_densities.py @@ -4,6 +4,7 @@ from functools import reduce from operator import mul +import numpy as np import pytest import jax.numpy as jnp @@ -22,8 +23,8 @@ argnames="dim, w, alpha, length", argvalues=[ (1, 0.1, 1.0, 0.2), - (2, jnp.array([0.1, 0.2]), 1.0, 0.2), - (3, jnp.array([0.1, 0.2, 0.3]), 1.0, 5.0), + (2, np.array([0.1, 0.2]), 1.0, 0.2), + (3, np.array([0.1, 0.2, 0.3]), 1.0, 5.0), ], ids=["dim=1", "dim=2", "dim=3"], ) @@ -39,8 +40,8 @@ def test_spectral_density_squared_exponential(dim, w, alpha, length): argnames="dim, nu, w, alpha, length", argvalues=[ (1, 3 / 2, 0.1, 1.0, 0.2), - (2, 5 / 2, jnp.array([0.1, 0.2]), 1.0, 0.2), - (3, 5 / 2, jnp.array([0.1, 0.2, 0.3]), 1.0, 5.0), + (2, 5 / 2, np.array([0.1, 0.2]), 1.0, 0.2), + (3, 5 / 2, np.array([0.1, 0.2, 0.3]), 1.0, 5.0), ], ids=["dim=1", "dim=2", "dim=3"], ) @@ -113,8 +114,8 @@ def test_modified_bessel_first_kind_one_dim(v, z): @pytest.mark.parametrize( argnames="v, z", argvalues=[ - (jnp.linspace(0.1, 1.0, 10), jnp.array([0.1])), - (jnp.linspace(0.1, 1.0, 10), jnp.linspace(0.1, 1.0, 10)), + (np.linspace(0.1, 1.0, 10), np.array([0.1])), + (np.linspace(0.1, 1.0, 10), np.linspace(0.1, 1.0, 10)), ], ids=["z=0.1", "z=0.2"], )