Skip to content

Commit

Permalink
Migrate away from using JaxTestCase in tests
Browse files Browse the repository at this point in the history
Why? JaxTestCase is deprecated for use outside the JAX project as of version 0.3.1; see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-1-feb-18-2022

PiperOrigin-RevId: 435166865
  • Loading branch information
Jake VanderPlas authored and PIXDev committed Mar 16, 2022
1 parent 37a2ab1 commit 28adfae
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 37 deletions.
6 changes: 3 additions & 3 deletions dm_pix/_src/augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from absl.testing import parameterized
from dm_pix._src import augment
import jax
import jax.test_util as jtu
import numpy as np
import tensorflow as tf

Expand All @@ -31,7 +30,7 @@
_KERNEL_SIZE = _IMG_SHAPE[0] / 10.


class _ImageAugmentationTest(jtu.JaxTestCase, parameterized.TestCase):
class _ImageAugmentationTest(parameterized.TestCase):
"""Runs tests for the various augments with the correct arguments."""

def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range):
Expand All @@ -43,7 +42,8 @@ def _test_fn(self, images_list, jax_fn, tf_fn):
def assertAllCloseTolerant(self, x, y):
# Increase tolerance on TPU due to lower precision.
tol = 1e-2 if jax.local_devices()[0].platform == "tpu" else 1e-4
super().assertAllClose(x, y, rtol=tol, atol=tol)
np.testing.assert_allclose(x, y, rtol=tol, atol=tol)
self.assertEqual(x.dtype, y.dtype)

@parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE),
("out_of_range", _RAND_FLOATS_OUT_OF_RANGE))
Expand Down
25 changes: 14 additions & 11 deletions dm_pix/_src/color_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from dm_pix._src import color_conversion
import jax
import jax.numpy as jnp
import jax.test_util as jtu
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -57,7 +56,6 @@ def generate_test_images(

class ColorConversionTest(
chex.TestCase,
jtu.JaxTestCase,
parameterized.TestCase,
):

Expand All @@ -84,7 +82,7 @@ def test_hsv_to_rgb(self, test_images, channel_last):
rgb_jax = hsv_to_rgb(hsv)
if not channel_last:
rgb_jax = rgb_jax.swapaxes(-1, -3)
self.assertAllClose(rgb_jax, rgb_tf, rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(rgb_jax, rgb_tf, rtol=1e-3, atol=1e-3)

@chex.all_variants
@parameterized.product(
Expand All @@ -108,7 +106,7 @@ def test_rgb_to_hsv(self, test_images, channel_last):
hsv_jax = rgb_to_hsv(rgb)
if not channel_last:
hsv_jax = hsv_jax.swapaxes(-1, -3)
self.assertAllClose(hsv_jax, hsv_tf, rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(hsv_jax, hsv_tf, rtol=1e-3, atol=1e-3)

@chex.all_variants
def test_vmap_roundtrip(self):
Expand All @@ -118,14 +116,14 @@ def test_vmap_roundtrip(self):
hsv_to_rgb = self.variant(jax.vmap(color_conversion.hsv_to_rgb))
hsv = rgb_to_hsv(rgb_init)
rgb_final = hsv_to_rgb(hsv)
self.assertAllClose(rgb_init, rgb_final, rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(rgb_init, rgb_final, rtol=1e-3, atol=1e-3)

def test_jit_roundtrip(self):
images = generate_test_images(*TestImages.RAND_FLOATS_IN_RANGE.value)
rgb_init = np.stack(images, axis=0)
hsv = jax.jit(color_conversion.rgb_to_hsv)(rgb_init)
rgb_final = jax.jit(color_conversion.hsv_to_rgb)(hsv)
self.assertAllClose(rgb_init, rgb_final, rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(rgb_init, rgb_final, rtol=1e-3, atol=1e-3)

@chex.all_variants
@parameterized.named_parameters(
Expand Down Expand Up @@ -158,7 +156,8 @@ def test_rgb_to_hsl_golden(self, minval, maxval):
image_rgb = np.reshape(image_rgb, _IMG_SHAPE)
hsl_true = np.reshape(hsl_true, _IMG_SHAPE)
rgb_to_hsl = self.variant(color_conversion.rgb_to_hsl)
self.assertAllClose(rgb_to_hsl(image_rgb), hsl_true)
np.testing.assert_allclose(
rgb_to_hsl(image_rgb), hsl_true, atol=1E-5, rtol=1E-5)

@chex.all_variants
@parameterized.named_parameters(
Expand Down Expand Up @@ -202,7 +201,8 @@ def test_hsl_to_rgb_golden(self):
rgb_true = np.reshape(rgb_true, _IMG_SHAPE)
image_hsl = np.reshape(image_hsl, _IMG_SHAPE)
hsl_to_rgb = self.variant(color_conversion.hsl_to_rgb)
self.assertAllClose(hsl_to_rgb(image_hsl), rgb_true)
np.testing.assert_allclose(
hsl_to_rgb(image_hsl), rgb_true, atol=1E-5, rtol=1E-5)

@chex.all_variants
def test_hsl_rgb_roundtrip(self):
Expand All @@ -217,7 +217,8 @@ def test_hsl_rgb_roundtrip(self):

rgb_to_hsl = self.variant(color_conversion.rgb_to_hsl)
hsl_to_rgb = self.variant(color_conversion.hsl_to_rgb)
self.assertAllClose(image_rgb, hsl_to_rgb(rgb_to_hsl(image_rgb)))
np.testing.assert_allclose(
image_rgb, hsl_to_rgb(rgb_to_hsl(image_rgb)), atol=1E-5, rtol=1E-5)

@chex.all_variants
@parameterized.product(
Expand Down Expand Up @@ -246,9 +247,11 @@ def test_grayscale(self, test_images, keep_dims, channel_last):
grayscale_jax = grayscale_jax.swapaxes(-1, -3)
if keep_dims:
for i in range(_IMG_SHAPE[-1]):
self.assertAllClose(grayscale_jax[..., [i]], grayscale_tf)
np.testing.assert_allclose(
grayscale_jax[..., [i]], grayscale_tf, atol=1E-5, rtol=1E-5)
else:
self.assertAllClose(grayscale_jax, grayscale_tf)
np.testing.assert_allclose(
grayscale_jax, grayscale_tf, atol=1E-5, rtol=1E-5)


if __name__ == "__main__":
Expand Down
7 changes: 3 additions & 4 deletions dm_pix/_src/depth_and_space_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
from absl.testing import parameterized
import chex
from dm_pix._src import depth_and_space
import jax.test_util as jtu
import numpy as np
import tensorflow as tf


class DepthAndSpaceTest(chex.TestCase, jtu.JaxTestCase, parameterized.TestCase):
class DepthAndSpaceTest(chex.TestCase, parameterized.TestCase):

@chex.all_variants
@parameterized.parameters(([1, 1, 1, 9], 3), ([2, 2, 2, 8], 2))
Expand All @@ -32,7 +31,7 @@ def test_depth_to_space(self, input_shape, block_size):
inputs = np.reshape(inputs, input_shape)
output_tf = tf.nn.depth_to_space(inputs, block_size).numpy()
output_jax = depth_to_space_fn(inputs, block_size)
self.assertArraysEqual(output_tf, output_jax)
np.testing.assert_array_equal(output_tf, output_jax)

@chex.all_variants
@parameterized.parameters(([1, 3, 3, 1], 3), ([2, 4, 4, 2], 2))
Expand All @@ -43,7 +42,7 @@ def test_space_to_depth(self, input_shape, block_size):
inputs = np.reshape(inputs, input_shape)
output_tf = tf.nn.space_to_depth(inputs, block_size).numpy()
output_jax = space_to_depth_fn(inputs, block_size)
self.assertArraysEqual(output_tf, output_jax)
np.testing.assert_array_equal(output_tf, output_jax)


if __name__ == "__main__":
Expand Down
17 changes: 9 additions & 8 deletions dm_pix/_src/interpolation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import chex
from dm_pix._src import interpolation
import jax.numpy as jnp
import jax.test_util as jtu
import numpy as np

_SHAPE_COORDS = ((1, 1), (1, 3), (3, 2), (4, 4), (4, 1, 4), (4, 2, 2))

Expand Down Expand Up @@ -121,21 +121,22 @@ def _prepare_expected(shape_coordinates: Sequence[int]) -> jnp.ndarray:
return out


class InterpolationTest(chex.TestCase, jtu.JaxTestCase, parameterized.TestCase):
class InterpolationTest(chex.TestCase, parameterized.TestCase):

@chex.all_variants
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{shape}_coords", shape_coordinates=shape)
for shape in _SHAPE_COORDS))
@parameterized.named_parameters([
dict(testcase_name=f"_{shape}_coords", shape_coordinates=shape)
for shape in _SHAPE_COORDS
])
def test_flat_nd_linear_interpolate(self, shape_coordinates):
volume, coords = _prepare_inputs(shape_coordinates)
expected = _prepare_expected(shape_coordinates)

flat_nd_linear_interpolate = self.variant(
interpolation.flat_nd_linear_interpolate)
self.assertAllClose(flat_nd_linear_interpolate(volume, coords), expected)
self.assertAllClose(
np.testing.assert_allclose(
flat_nd_linear_interpolate(volume, coords), expected)
np.testing.assert_allclose(
flat_nd_linear_interpolate(
volume.flatten(), coords, unflattened_vol_shape=volume.shape),
expected)
Expand Down
15 changes: 7 additions & 8 deletions dm_pix/_src/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
import chex
from dm_pix._src import metrics
import jax
import jax.test_util as jtu
import numpy as np
import tensorflow as tf


class MSETest(chex.TestCase, jtu.JaxTestCase, absltest.TestCase):
class MSETest(chex.TestCase, absltest.TestCase):

def setUp(self):
super().setUp()
Expand All @@ -48,16 +47,16 @@ def test_psnr_match(self):
psnr = self.variant(metrics.psnr)
values_jax = psnr(self._img1, self._img2)
values_tf = tf.image.psnr(self._img1, self._img2, max_val=1.).numpy()
self.assertAllClose(values_jax, values_tf, rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(values_jax, values_tf, rtol=1e-3, atol=1e-3)

@chex.all_variants
def test_simse_invariance(self):
simse = self.variant(metrics.simse)
simse_jax = simse(self._img1, self._img1 * 2.0)
self.assertAllClose(simse_jax, np.zeros(4), rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(simse_jax, np.zeros(4), rtol=1e-6, atol=1e-6)


class SSIMTests(chex.TestCase, jtu.JaxTestCase, absltest.TestCase):
class SSIMTests(chex.TestCase, absltest.TestCase):

@chex.all_variants
def test_ssim_golden(self):
Expand Down Expand Up @@ -99,9 +98,9 @@ def test_ssim_golden(self):
))
ssim = ssim_fn(img0, img1)
if not return_map:
self.assertAllClose(ssim, ssim_gt, atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(ssim, ssim_gt, atol=1e-5, rtol=1e-5)
else:
self.assertAllClose(
np.testing.assert_allclose(
np.mean(ssim, list(range(-3, 0))),
ssim_gt,
atol=1e-5,
Expand All @@ -126,7 +125,7 @@ def test_ssim_lowerbound(self):
k2=eps,
))
ssim = ssim_fn(img, -img)
self.assertAllClose(ssim, -np.ones_like(ssim))
np.testing.assert_allclose(ssim, -np.ones_like(ssim), atol=1E-5, rtol=1E-5)


if __name__ == "__main__":
Expand Down
5 changes: 2 additions & 3 deletions dm_pix/_src/patch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from absl.testing import parameterized
import chex
from dm_pix._src import patch
import jax.test_util as jtu
import numpy as np
import tensorflow as tf

Expand All @@ -29,7 +28,7 @@ def _create_test_images(shape):
return np.reshape(images, shape)


class PatchTest(chex.TestCase, jtu.JaxTestCase, parameterized.TestCase):
class PatchTest(chex.TestCase, parameterized.TestCase):

@chex.all_variants
@parameterized.named_parameters(
Expand Down Expand Up @@ -60,7 +59,7 @@ def test_extract_patches(self, padding):
rates=rates,
padding=padding,
)
self.assertArraysEqual(jax_patches, tf_patches.numpy())
np.testing.assert_array_equal(jax_patches, tf_patches.numpy())

@chex.all_variants
@parameterized.product(
Expand Down

0 comments on commit 28adfae

Please sign in to comment.