Skip to content

Commit

Permalink
Chex: add Python 3.13 CI tests
Browse files Browse the repository at this point in the history
This requires making dm-tree a testing requirement only for Python < 3.13, because there is no py313-compatible dm-tree release.

PiperOrigin-RevId: 716138572
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Jan 16, 2025
1 parent 2fe940a commit 512d26e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- python-version: "3.9"
os: "ubuntu-latest"
jax-version: "0.4.27" # Keep this in sync with version in requirements.txt
- python-version: "3.12"
- python-version: "3.13"
os: "ubuntu-latest"
jax-version: "nightly"

Expand Down
28 changes: 26 additions & 2 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pickle
import sys
from typing import Any, Generic, Mapping, TypeVar
import unittest

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -30,7 +31,12 @@
import cloudpickle
import jax
import numpy as np
import tree

# dm-tree is not compatible with Python 3.13.
try:
import tree # pylint:disable=g-import-not-at-top
except ImportError:
tree = None

chex_dataclass = dataclass.dataclass
mappable_dataclass = dataclass.mappable_dataclass
Expand Down Expand Up @@ -209,7 +215,9 @@ def _init_testdata(self, test_type):
self.dcls_tree_size = 18
self.dcls_tree_size_no_dicts = 14

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testFlattenAndUnflatten(self, test_type):
assert tree is not None
self._init_testdata(test_type)

self.assertEqual(self.dcls_flattened, tree.flatten(self.dcls_with_map))
Expand All @@ -223,29 +231,37 @@ def testFlattenAndUnflatten(self, test_type):
self.assertEqual(dataclass_in_seq,
tree.unflatten_as(dataclass_in_seq, dataclass_in_seq_flat))

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testFlattenUpTo(self, test_type):
assert tree is not None
self._init_testdata(test_type)
structure = copy.copy(self.dcls_with_map)
structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map'
self.assertEqual(self.dcls_flattened_up_to,
tree.flatten_up_to(structure, self.dcls_with_map))

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testFlattenWithPath(self, test_type):
assert tree is not None
self._init_testdata(test_type)

self.assertEqual(
tree.flatten_with_path(self.dcls_with_map),
self.dcls_flattened_with_path)

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testFlattenWithPathUpTo(self, test_type):
assert tree is not None
self._init_testdata(test_type)
structure = copy.copy(self.dcls_with_map)
structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map'
self.assertEqual(
tree.flatten_with_path_up_to(structure, self.dcls_with_map),
self.dcls_flattened_with_path_up_to)

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testMapStructure(self, test_type):
assert tree is not None
self._init_testdata(test_type)

add_one_to_ints_fn = lambda x: x + 1 if isinstance(x, int) else x
Expand All @@ -256,7 +272,9 @@ def testMapStructure(self, test_type):
self.dcls_with_map_inc_ints.k_int * 10)
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testMapStructureUpTo(self, test_type):
assert tree is not None
self._init_testdata(test_type)

structure = copy.copy(self.dcls_with_map)
Expand All @@ -273,7 +291,9 @@ def testMapStructureUpTo(self, test_type):
self.dcls_with_map_inc_ints.k_int * 10)
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testMapStructureWithPath(self, test_type):
assert tree is not None
self._init_testdata(test_type)

add_one_to_ints_fn = lambda path, x: x + 1 if isinstance(x, int) else x
Expand All @@ -285,7 +305,9 @@ def testMapStructureWithPath(self, test_type):
self.dcls_with_map_inc_ints.k_int * 10)
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testMapStructureWithPathUpTo(self, test_type):
assert tree is not None
self._init_testdata(test_type)

structure = copy.copy(self.dcls_with_map)
Expand All @@ -303,7 +325,9 @@ def testMapStructureWithPathUpTo(self, test_type):
self.dcls_with_map_inc_ints.k_int * 10)
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)

@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
def testTraverse(self, test_type):
assert tree is not None
self._init_testdata(test_type)

visited = []
Expand Down Expand Up @@ -358,7 +382,7 @@ def test_tree_flatten_with_keys(self):
)
leaves = [l for _, l in keys_and_leaves]
new_obj = treedef.unflatten(leaves)
self.assertEqual(new_obj, obj)
asserts.assert_trees_all_equal(new_obj, obj)

def test_tree_map_with_keys(self):
obj = dummy_dataclass()
Expand Down
3 changes: 2 additions & 1 deletion requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
cloudpickle==2.2.0
dm-tree>=0.1.5
# dm-tree is not compatible with Python 3.13.
dm-tree>=0.1.5; python_version < "3.13"

0 comments on commit 512d26e

Please sign in to comment.