Skip to content

Commit

Permalink
Remove usage of dm-tree in asserts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 547811329
  • Loading branch information
ChexDev authored and ChexDev committed Jul 18, 2023
1 parent ee3da94 commit 979d532
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 126 deletions.
35 changes: 18 additions & 17 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import jax.numpy as jnp
import jax.test_util as jax_test
import numpy as np
import tree as dm_tree

Scalar = pytypes.Scalar
Array = pytypes.Array
Expand Down Expand Up @@ -1016,8 +1015,8 @@ def _assert_fn(path, leaf):
nonlocal errors
errors.append(f"`None` detected at '{_ai.format_tree_path(path)}'.")

for path, leaf in dm_tree.flatten_with_path(tree):
_assert_fn(path, leaf)
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
if errors:
raise AssertionError("\n".join(errors))

Expand Down Expand Up @@ -1047,8 +1046,8 @@ def _assert_fn(path, leaf):
errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' is not an "
f"ndarray (type={type(leaf)})."))

for path, leaf in dm_tree.flatten_with_path(tree):
_assert_fn(path, leaf)
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
if errors:
raise AssertionError("\n".join(errors))

Expand Down Expand Up @@ -1139,8 +1138,8 @@ def _assert_fn(path, leaf):
errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' has "
f"unexpected type: {type(leaf)}."))

for path, leaf in dm_tree.flatten_with_path(tree):
_assert_fn(path, leaf)
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
if errors:
raise AssertionError("\n".join(errors))

Expand Down Expand Up @@ -1204,8 +1203,8 @@ def _assert_fn(path, leaf):
errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' has "
f"unexpected type: {type(leaf)}."))

for path, leaf in dm_tree.flatten_with_path(tree):
_assert_fn(path, leaf)
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
if errors:
raise AssertionError("\n".join(errors))

Expand Down Expand Up @@ -1256,8 +1255,8 @@ def _assert_fn(path, leaf):
f"jax.Array (type={type(leaf)})."
)

for path, leaf in dm_tree.flatten_with_path(tree):
_assert_fn(path, leaf)
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
if errors:
raise AssertionError("\n".join(errors))

Expand Down Expand Up @@ -1308,8 +1307,8 @@ def _assert_fn(path, leaf):
(f"Tree leaf '{_ai.format_tree_path(path)}' has a shape prefix "
f"different from expected: {suffix} != {shape_prefix}."))

for path, leaf in dm_tree.flatten_with_path(tree):
_assert_fn(path, leaf)
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
if errors:
raise AssertionError("\n".join(errors))

Expand Down Expand Up @@ -1360,8 +1359,8 @@ def _assert_fn(path, leaf):
(f"Tree leaf '{_ai.format_tree_path(path)}' has a shape suffix "
f"different from expected: {suffix} != {shape_suffix}."))

for path, leaf in dm_tree.flatten_with_path(tree):
_assert_fn(path, leaf)
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
if errors:
raise AssertionError("\n".join(errors))

Expand Down Expand Up @@ -1453,8 +1452,10 @@ def wrapped_equality_comparator(leaf_1, leaf_2):
wrapped_equality_comparator, tree_error_msg_fn)

# Trees are guaranteed to have the same structure.
paths = [path for path, _ in dm_tree.flatten_with_path(trees[0])]
trees_leaves = [dm_tree.flatten(tree) for tree in trees]
paths = [
_ai.convert_jax_path_to_dm_path(path)
for path, _ in jax.tree_util.tree_flatten_with_path(trees[0])[0]]
trees_leaves = [jax.tree_util.tree_leaves(tree) for tree in trees]
for leaf_i, path in enumerate(paths):
cmp_fn(path, *[leaves[leaf_i] for leaves in trees_leaves])

Expand Down
46 changes: 41 additions & 5 deletions chex/_src/asserts_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@
import re
import threading
import traceback
from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
from typing import Any, Sequence, Union, Callable, Hashable, List, Optional, Set, Tuple, Type

from absl import logging
from chex._src import pytypes
import jax
from jax.experimental import checkify
import jax.numpy as jnp
import numpy as np
import tree as dm_tree

# Custom pytypes.
TLeaf = Any
Expand Down Expand Up @@ -412,7 +411,8 @@ def assert_trees_all_eq_comparator_jittable(
"`assert_trees_xxx([a, b])` instead of `assert_trees_xxx(a, b)`, or "
"forgot the `error_msg_fn` arg to `assert_trees_xxx`?")

def _tree_error_msg_fn(path: str, i_1: int, i_2: int):
def _tree_error_msg_fn(
path: Tuple[Union[int, str, Hashable]], i_1: int, i_2: int):
if path:
return (
f"Trees {i_1} and {i_2} differ in leaves '{path}':"
Expand All @@ -435,8 +435,10 @@ def _cmp_leaves(path, *leaves):
return verdict

# Trees are guaranteed to have the same structure.
paths = [path for path, _ in dm_tree.flatten_with_path(trees[0])]
trees_leaves = [dm_tree.flatten(tree) for tree in trees]
paths = [
convert_jax_path_to_dm_path(path)
for path, _ in jax.tree_util.tree_flatten_with_path(trees[0])[0]]
trees_leaves = [jax.tree_util.tree_leaves(tree) for tree in trees]

verdict = jnp.array(True)
for leaf_i, path in enumerate(paths):
Expand All @@ -445,3 +447,37 @@ def _cmp_leaves(path, *leaves):
)

return verdict


JaxKeyType = Union[
int,
str,
Hashable,
jax.tree_util.SequenceKey,
jax.tree_util.DictKey,
jax.tree_util.FlattenedIndexKey,
jax.tree_util.GetAttrKey,
]


def convert_jax_path_to_dm_path(
jax_tree_path: Sequence[JaxKeyType],
) -> Tuple[Union[int, str, Hashable]]:
"""Converts a path from jax.tree_util to one from dm-tree."""

# pytype:disable=attribute-error
def _convert_key_fn(key: JaxKeyType) -> Union[int, str, Hashable]:
if isinstance(key, (str, int)):
return key # int | str.
if isinstance(key, jax.tree_util.SequenceKey):
return key.idx # int.
if isinstance(key, jax.tree_util.DictKey):
return key.key # Hashable
if isinstance(key, jax.tree_util.FlattenedIndexKey):
return key.key # int.
if isinstance(key, jax.tree_util.GetAttrKey):
return key.name # str.
raise ValueError(f"Jax tree key '{key}' of type '{type(key)}' not valid.")
# pytype:enable=attribute-error

return tuple(_convert_key_fn(key) for key in jax_tree_path)
103 changes: 0 additions & 103 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,16 +933,6 @@ def test_assert_trees_all_equal_fail_values_close_but_not_equal(self):
with self.assertRaisesRegex(ValueError, error_msg):
asserts._assert_trees_all_equal_jittable(tree1, tree2)

def test_assert_trees_all_equal_nones(self):
tree = {'a': [jnp.zeros((1,))], 'b': None}
asserts.assert_trees_all_equal(tree, tree, ignore_nones=True)
err_regex = _get_err_regex('`None` detected')
with self.assertRaisesRegex(AssertionError, err_regex):
asserts.assert_trees_all_equal(tree, tree, ignore_nones=False)

with self.assertRaisesRegex(AssertionError, err_regex):
asserts._assert_trees_all_equal_jittable(tree, tree, ignore_nones=False)

def test_assert_trees_all_equal_strict_mode(self):
# See 'notes' section of
# https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_array_equal.html
Expand Down Expand Up @@ -989,15 +979,6 @@ def test_assert_trees_all_close_passes_values_close_but_not_equal(self):
self.assertTrue(
asserts._assert_trees_all_close_jittable(tree1, tree2, rtol=1e-6))

def test_assert_trees_all_close_nones(self):
tree = {'a': [jnp.zeros((1,))], 'b': None}
asserts.assert_trees_all_close(tree, tree, ignore_nones=True)
err_regex = _get_err_regex('`None` detected')
with self.assertRaisesRegex(AssertionError, err_regex):
asserts.assert_trees_all_close(tree, tree, ignore_nones=False)
with self.assertRaisesRegex(AssertionError, err_regex):
asserts._assert_trees_all_close_jittable(tree, tree, ignore_nones=False)

def test_assert_trees_all_close_bfloat16(self):
tree1 = {'a': jnp.asarray([0.8, 1.6], dtype=jnp.bfloat16)}
tree2 = {
Expand Down Expand Up @@ -1078,13 +1059,6 @@ def test_assert_trees_all_close_ulp_fails_values_gt_maxulp_apart(self):
with self.assertRaisesRegex(AssertionError, err_regex):
asserts.assert_trees_all_close_ulp(tree1, tree2, maxulp=1)

def test_assert_trees_all_close_ulp_nones(self):
tree = {'a': [jnp.zeros((1,))], 'b': None}
asserts.assert_trees_all_close_ulp(tree, tree, ignore_nones=True)
err_regex = _get_err_regex('`None` detected')
with self.assertRaisesRegex(AssertionError, err_regex):
asserts.assert_trees_all_close_ulp(tree, tree, ignore_nones=False)

def test_assert_trees_all_close_ulp_fails_bfloat16(self):
tree_f32 = (jnp.array([0.0]),)
tree_bf16 = (jnp.array([0.0], dtype=jnp.bfloat16),)
Expand All @@ -1095,13 +1069,6 @@ def test_assert_trees_all_close_ulp_fails_bfloat16(self):
with self.assertRaisesRegex(ValueError, err_regex): # pylint: disable=g-error-prone-assert-raises
asserts.assert_trees_all_close_ulp(tree_bf16, tree_f32)

def test_assert_trees_all_equal_shapes_nones(self):
tree = {'a': [jnp.zeros((1,))], 'b': None}
asserts.assert_trees_all_equal_shapes(tree, tree, ignore_nones=True)
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_trees_all_equal_shapes(tree, tree, ignore_nones=False)

def test_assert_tree_has_only_ndarrays(self):
# Check correct inputs.
asserts.assert_tree_has_only_ndarrays({'a': jnp.zeros(1), 'b': np.ones(3)})
Expand All @@ -1116,13 +1083,6 @@ def test_assert_tree_has_only_ndarrays(self):
_get_err_regex('\'b/1\' is not an ndarray')):
asserts.assert_tree_has_only_ndarrays({'a': jnp.zeros(101), 'b': [1, 2]})

# Check `None`.
tree_with_none = (np.array([2]), None)
asserts.assert_tree_has_only_ndarrays(tree_with_none, ignore_nones=True)
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_has_only_ndarrays(tree_with_none)

def test_assert_tree_is_on_host(self):
cpu = jax.devices('cpu')[0]

Expand Down Expand Up @@ -1192,13 +1152,6 @@ def test_assert_tree_is_on_host(self):
allow_sharded_arrays=True,
)

# Check `None`.
tree_with_none = (np.array([2]), None)
asserts.assert_tree_is_on_host(tree_with_none, ignore_nones=True)
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_is_on_host(tree_with_none)

def test_assert_tree_is_on_device(self):
# Check CPU platform.
cpu = jax.devices('cpu')[0]
Expand Down Expand Up @@ -1283,12 +1236,6 @@ def test_assert_tree_is_on_device(self):
asserts.assert_tree_is_on_device(
{'a': jax.device_put_replicated(np.zeros(1), (cpu,))}, device=cpu)

# Check `None`.
asserts.assert_tree_is_on_device((None,), ignore_nones=True)
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_is_on_device((None,))

def test_assert_tree_is_sharded(self):
np_tree = {'a': np.zeros(1), 'b': np.ones(3)}

Expand Down Expand Up @@ -1396,26 +1343,6 @@ def _format(*devs):
asserts.assert_tree_is_sharded({'a': jax.device_put(np.zeros(1), cpu)},
devices=(cpu,))

# Check `None`.
asserts.assert_tree_is_sharded((None,), devices=(cpu,), ignore_nones=True)
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_is_sharded((None,), devices=(cpu,))

def test_assert_tree_no_nones(self):
tree_ok = {'a': [jnp.zeros((1,))], 'b': 1}
asserts.assert_tree_no_nones(tree_ok)

tree_with_none = {'a': [jnp.zeros((1,))], 'b': None}
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_no_nones(tree_with_none)

# Check `None`.
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_no_nones(None)

def test_assert_trees_all_close_fails_different_structure(self):
self._assert_tree_structs_validation(asserts.assert_trees_all_close)

Expand Down Expand Up @@ -1513,19 +1440,6 @@ def test_leaf_shape_should_fail_wrong_length(self):
_get_err_regex(r'leaf \'x/y\' has a shape of length 2')):
asserts.assert_tree_shape_prefix(tree, [3, 2, 1])

def test_assert_tree_shape_prefix_none(self):
tree = {'x': np.zeros([3]), 'n': None}
asserts.assert_tree_shape_prefix(tree, (3,), ignore_nones=True)
asserts.assert_tree_shape_prefix(tree, [3], ignore_nones=True)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_prefix(tree, (3,), ignore_nones=False)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_prefix(tree, [3], ignore_nones=False)

@parameterized.named_parameters(
('scalars', ()),
('vectors', (1,)),
Expand Down Expand Up @@ -1576,19 +1490,6 @@ def test_assert_tree_shape_suffix_long_suffix(self):
AssertionError, _get_err_regex('which is smaller than the expected')):
asserts.assert_tree_shape_suffix(tree, [3, 4, 2, 1])

def test_assert_tree_shape_suffix_none(self):
tree = {'x': np.zeros([3]), 'n': None}
asserts.assert_tree_shape_suffix(tree, (3,), ignore_nones=True)
asserts.assert_tree_shape_suffix(tree, [3], ignore_nones=True)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_suffix(tree, (3,), ignore_nones=False)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_suffix(tree, [3], ignore_nones=False)

def test_assert_trees_all_equal_dtypes(self):
t_0 = {'x': np.zeros(3, dtype=np.int16), 'y': np.ones(2, dtype=np.float32)}
t_1 = {'x': np.zeros(5, dtype=np.uint16), 'y': np.ones(4, dtype=np.float32)}
Expand Down Expand Up @@ -1663,10 +1564,6 @@ def test_assert_trees_all_equal_none(self):
_get_err_regex('Trees 0 and 2 differ')):
asserts.assert_trees_all_equal_dtypes(t_0, t_1, t_2, ignore_nones=True)
asserts.assert_trees_all_equal_dtypes(t_0, t_1, ignore_nones=True)
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_trees_all_equal_dtypes(t_0, t_1, ignore_nones=False)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('trees 0 and 1 do not match')):
asserts.assert_trees_all_equal_dtypes(t_0, t_3, ignore_nones=True)
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
cloudpickle==2.2.0
dm-tree>=0.1.5
1 change: 0 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
absl-py>=0.9.0
typing_extensions>=4.2.0
dm-tree>=0.1.5
jax>=0.4.6
jaxlib>=0.1.37
numpy>=1.25.0
Expand Down

0 comments on commit 979d532

Please sign in to comment.