diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa83cf5..c8b10ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,7 +30,7 @@ jobs: - python-version: "3.9" os: "ubuntu-latest" jax-version: "0.4.27" # Keep this in sync with version in requirements.txt - - python-version: "3.12" + - python-version: "3.13" os: "ubuntu-latest" jax-version: "nightly" diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index 3003a9e..f981c8a 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -21,6 +21,7 @@ import pickle import sys from typing import Any, Generic, Mapping, TypeVar +import unittest from absl.testing import absltest from absl.testing import parameterized @@ -30,7 +31,12 @@ import cloudpickle import jax import numpy as np -import tree + +# dm-tree is not compatible with Python 3.13. +try: + import tree # pylint:disable=g-import-not-at-top +except ImportError: + tree = None chex_dataclass = dataclass.dataclass mappable_dataclass = dataclass.mappable_dataclass @@ -209,7 +215,9 @@ def _init_testdata(self, test_type): self.dcls_tree_size = 18 self.dcls_tree_size_no_dicts = 14 + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testFlattenAndUnflatten(self, test_type): + assert tree is not None self._init_testdata(test_type) self.assertEqual(self.dcls_flattened, tree.flatten(self.dcls_with_map)) @@ -223,21 +231,27 @@ def testFlattenAndUnflatten(self, test_type): self.assertEqual(dataclass_in_seq, tree.unflatten_as(dataclass_in_seq, dataclass_in_seq_flat)) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testFlattenUpTo(self, test_type): + assert tree is not None self._init_testdata(test_type) structure = copy.copy(self.dcls_with_map) structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map' self.assertEqual(self.dcls_flattened_up_to, tree.flatten_up_to(structure, self.dcls_with_map)) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testFlattenWithPath(self, test_type): + assert tree is not None self._init_testdata(test_type) self.assertEqual( tree.flatten_with_path(self.dcls_with_map), self.dcls_flattened_with_path) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testFlattenWithPathUpTo(self, test_type): + assert tree is not None self._init_testdata(test_type) structure = copy.copy(self.dcls_with_map) structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map' @@ -245,7 +259,9 @@ def testFlattenWithPathUpTo(self, test_type): tree.flatten_with_path_up_to(structure, self.dcls_with_map), self.dcls_flattened_with_path_up_to) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testMapStructure(self, test_type): + assert tree is not None self._init_testdata(test_type) add_one_to_ints_fn = lambda x: x + 1 if isinstance(x, int) else x @@ -256,7 +272,9 @@ def testMapStructure(self, test_type): self.dcls_with_map_inc_ints.k_int * 10) self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testMapStructureUpTo(self, test_type): + assert tree is not None self._init_testdata(test_type) structure = copy.copy(self.dcls_with_map) @@ -273,7 +291,9 @@ def testMapStructureUpTo(self, test_type): self.dcls_with_map_inc_ints.k_int * 10) self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testMapStructureWithPath(self, test_type): + assert tree is not None self._init_testdata(test_type) add_one_to_ints_fn = lambda path, x: x + 1 if isinstance(x, int) else x @@ -285,7 +305,9 @@ def testMapStructureWithPath(self, test_type): self.dcls_with_map_inc_ints.k_int * 10) self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testMapStructureWithPathUpTo(self, test_type): + assert tree is not None self._init_testdata(test_type) structure = copy.copy(self.dcls_with_map) @@ -303,7 +325,9 @@ def testMapStructureWithPathUpTo(self, test_type): self.dcls_with_map_inc_ints.k_int * 10) self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) + @unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13') def testTraverse(self, test_type): + assert tree is not None self._init_testdata(test_type) visited = [] @@ -358,7 +382,7 @@ def test_tree_flatten_with_keys(self): ) leaves = [l for _, l in keys_and_leaves] new_obj = treedef.unflatten(leaves) - self.assertEqual(new_obj, obj) + asserts.assert_trees_all_equal(new_obj, obj) def test_tree_map_with_keys(self): obj = dummy_dataclass() diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index c5923a7..035bfac 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,2 +1,3 @@ cloudpickle==2.2.0 -dm-tree>=0.1.5 +# dm-tree is not compatible with Python 3.13. +dm-tree>=0.1.5; python_version < "3.13"