diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c87cb61..fa83cf5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,14 +23,14 @@ jobs: strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] os: [ubuntu-latest] jax-version: ["newest"] include: - python-version: "3.9" os: "ubuntu-latest" jax-version: "0.4.27" # Keep this in sync with version in requirements.txt - - python-version: "3.11" + - python-version: "3.12" os: "ubuntu-latest" jax-version: "nightly" diff --git a/chex/_src/asserts_internal.py b/chex/_src/asserts_internal.py index dc8cd69..8f93caa 100644 --- a/chex/_src/asserts_internal.py +++ b/chex/_src/asserts_internal.py @@ -24,11 +24,12 @@ import collections import collections.abc +from collections.abc import Hashable import functools import re import threading import traceback -from typing import Any, Sequence, Union, Callable, Hashable, List, Optional, Set, Tuple, Type +from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type from absl import logging from chex._src import pytypes @@ -299,7 +300,7 @@ def format_tree_path(path: Sequence[Any]) -> str: def format_shape_matcher(shape: TShapeMatcher) -> str: - return f"({', '.join('...' if d is Ellipsis else str(d) for d in shape)})" + return f"({', '.join('...' if d is Ellipsis else str(d) for d in shape)})" # pylint: disable=inconsistent-quotes def num_devices_available(devtype: str, backend: Optional[str] = None) -> int: diff --git a/chex/_src/dimensions.py b/chex/_src/dimensions.py index 94d7d2b..695c5be 100644 --- a/chex/_src/dimensions.py +++ b/chex/_src/dimensions.py @@ -14,9 +14,10 @@ # ============================================================================== """Utilities to hold expected dimension sizes.""" +from collections.abc import Sized import math import re -from typing import Any, Collection, Dict, Optional, Sized, Tuple +from typing import Any, Collection, Dict, Optional, Tuple Shape = Tuple[Optional[int], ...]