-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add chex.Dimensions utility for readable shape asserts.
PiperOrigin-RevId: 456255200
- Loading branch information
1 parent
6723dee
commit 41eb7f6
Showing
4 changed files
with
306 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Utilities to hold expected dimension sizes.""" | ||
|
||
import re | ||
from typing import Any, Collection, Dict, Optional, Sized, Tuple | ||
|
||
|
||
Shape = Tuple[Optional[int], ...] | ||
|
||
|
||
class Dimensions: | ||
"""A lightweight utility that maps strings to shape tuples. | ||
The most basic usage is: | ||
.. code:: | ||
>>> dims = chex.Dimensions(B=3, T=5, N=7) # You can specify any letters. | ||
>>> dims['NBT'] | ||
(7, 3, 5) | ||
This is useful when dealing with many differently shaped arrays. For instance, | ||
let's check the shape of this array: | ||
.. code:: | ||
>>> x = jnp.array([[2, 0, 5, 6, 3], | ||
... [5, 4, 4, 3, 3], | ||
... [0, 0, 5, 2, 0]]) | ||
>>> chex.assert_shape(x, dims['BT']) | ||
The dimension sizes can be gotten directly, e.g. :code:`dims.N == 7`. This can | ||
be useful in many applications. For instance, let's one-hot encode our array. | ||
.. code:: | ||
>>> y = jax.nn.one_hot(x, dims.N) | ||
>>> chex.assert_shape(y, dims['BTN']) | ||
You can also store the shape of a given array in :code:`dims`, e.g. | ||
.. code:: | ||
>>> z = jnp.array([[0, 6, 0, 2], | ||
... [4, 2, 2, 4]]) | ||
>>> dims['XY'] = z.shape | ||
>>> dims | ||
Dimensions(B=3, N=7, T=5, X=2, Y=4) | ||
You can set a wildcard dimension, cf. :func:`chex.assert_shape`: | ||
.. code:: | ||
>>> dims.W = None | ||
>>> dims['BTW'] | ||
(3, 5, None) | ||
Or you can use the wildcard character `'*'` directly: | ||
.. code:: | ||
>>> dims['BT*'] | ||
(3, 5, None) | ||
Single digits are interpreted as literal integers. Note that this notation | ||
is limited to single-digit literals. | ||
.. code:: | ||
>>> dims['BT123'] | ||
(3, 5, 1, 2, 3) | ||
Support for single digits was mainly included to accommodate dummy axes | ||
introduced for consistent broadcasting. For instance, instead of using | ||
:func:`jnp.expand_dims <jax.numpy.expand_dims>` you could do the following: | ||
.. code:: | ||
>>> w = y * x # Cannot broadcast (3, 5, 7) with (3, 5) | ||
Traceback (most recent call last): | ||
... | ||
ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5)) | ||
>>> w = y * x.reshape(dims['BT1']) | ||
>>> chex.assert_shape(w, dims['BTN']) | ||
Sometimes you only care about some array dimensions but not all. You can use | ||
an underscore to ignore an axis, e.g. | ||
.. code:: | ||
>>> chex.assert_rank(y, 3) | ||
>>> dims['__M'] = y.shape # Skip the first two axes. | ||
Finally note that a single-character key returns a tuple of length one. | ||
.. code:: | ||
>>> dims['M'] | ||
(7,) | ||
""" | ||
HAS_DYNAMIC_ATTRIBUTES = True | ||
|
||
def __init__(self, **dim_sizes): | ||
for dim, size in dim_sizes.items(): | ||
self._setdim(dim, size) | ||
|
||
def __getitem__(self, key: str) -> Shape: | ||
self._validate_key(key) | ||
return tuple(self._getdim(dim) for dim in key) | ||
|
||
def __setitem__(self, key: str, value: Collection[Optional[int]]): | ||
self._validate_key(key) | ||
self._validate_value(value) | ||
if len(key) != len(value): | ||
raise ValueError( | ||
f'key string {repr(key)} and shape {tuple(value)} ' | ||
'have different lengths') | ||
for dim, size in zip(key, value): | ||
self._setdim(dim, size) | ||
|
||
def __delitem__(self, key: str): | ||
self._validate_key(key) | ||
for dim in key: | ||
self._deldim(dim) | ||
|
||
def __repr__(self) -> str: | ||
args = ', '.join(f'{k}={v}' for k, v in sorted(self._asdict().items())) | ||
return f'{type(self).__name__}({args})' | ||
|
||
def _asdict(self) -> Dict[str, Optional[int]]: | ||
return {k: v for k, v in self.__dict__.items() | ||
if re.fullmatch(r'[a-zA-Z]', k)} | ||
|
||
def _getdim(self, dim: str) -> Optional[int]: | ||
if dim == '*': | ||
return None | ||
if re.fullmatch(r'[0-9]', dim): | ||
return int(dim) | ||
try: | ||
return getattr(self, dim) | ||
except AttributeError as e: | ||
raise KeyError(dim) from e | ||
|
||
def _setdim(self, dim: str, size: Optional[int]): | ||
if dim == '_': # Skip. | ||
return | ||
self._validate_dim(dim) | ||
setattr(self, dim, _optional_int(size)) | ||
|
||
def _deldim(self, dim: str): | ||
if dim == '_': # Skip. | ||
return | ||
self._validate_dim(dim) | ||
try: | ||
return delattr(self, dim) | ||
except AttributeError as e: | ||
raise KeyError(dim) from e | ||
|
||
def _validate_key(self, key: Any): | ||
if not isinstance(key, str): | ||
raise TypeError(f'key must be a string; got: {type(key).__name__}') | ||
|
||
def _validate_value(self, value: Any): | ||
if not isinstance(value, Sized): | ||
raise TypeError( | ||
'value must be sized, i.e. an object with a well-defined len(value); ' | ||
f'got object of type: {type(value).__name__}') | ||
|
||
def _validate_dim(self, dim: Any): | ||
if not isinstance(dim, str): | ||
raise TypeError( | ||
f'dimension name must be a string; got: {type(dim).__name__}') | ||
if not re.fullmatch(r'[a-zA-Z]', dim): | ||
raise KeyError( | ||
'dimension names may only be contain letters (or \'_\' to skip); ' | ||
f'got dimension name: {repr(dim)}') | ||
|
||
|
||
def _optional_int(x: Any) -> Optional[int]: | ||
if x is None: | ||
return None | ||
try: | ||
i = int(x) | ||
if x == i: | ||
return i | ||
except ValueError: | ||
pass | ||
raise TypeError(f'object cannot be interpreted as a python int: {repr(x)}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for ``dimensions`` module.""" | ||
|
||
import doctest | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
from chex._src import asserts | ||
from chex._src import dimensions | ||
import jax | ||
import numpy as np | ||
|
||
|
||
class _ChexModule: | ||
"""Mock module for providing minimal context to docstring tests.""" | ||
assert_shape = asserts.assert_shape | ||
assert_rank = asserts.assert_rank | ||
Dimensions = dimensions.Dimensions # pylint: disable=invalid-name | ||
|
||
|
||
class DimensionsTest(parameterized.TestCase): | ||
|
||
def test_docstring_examples(self): | ||
doctest.run_docstring_examples( | ||
dimensions.Dimensions, | ||
globs={'chex': _ChexModule, 'jax': jax, 'jnp': jax.numpy}) | ||
|
||
@parameterized.named_parameters([ | ||
('scalar', '', (), ()), | ||
('vector', 'a', (7,), (7,)), | ||
('list', 'ab', [7, 11], (7, 11)), | ||
('numpy_array', 'abc', np.array([7, 11, 13]), (7, 11, 13)), | ||
('case_sensitive', 'aA', (7, 11), (7, 11)), | ||
]) | ||
def test_set_ok(self, k, v, shape): | ||
dims = dimensions.Dimensions(x=23, y=29) | ||
dims[k] = v | ||
asserts.assert_shape(np.empty((23, *shape, 29)), dims['x' + k + 'y']) | ||
|
||
def test_set_wildcard(self): | ||
dims = dimensions.Dimensions(x=23, y=29) | ||
dims['a_b__'] = (7, 11, 13, 17, 19) | ||
self.assertEqual(dims['xayb'], (23, 7, 29, 13)) | ||
with self.assertRaisesRegex(KeyError, r'\*'): | ||
dims['ab*'] = (7, 11, 13) | ||
|
||
def test_get_wildcard(self): | ||
dims = dimensions.Dimensions(x=23, y=29) | ||
self.assertEqual(dims['x*y**'], (23, None, 29, None, None)) | ||
asserts.assert_shape(np.empty((23, 1, 29, 2, 3)), dims['x*y**']) | ||
with self.assertRaisesRegex(KeyError, r'\_'): | ||
dims['xy_'] # pylint: disable=pointless-statement | ||
|
||
def test_get_literals(self): | ||
dims = dimensions.Dimensions(x=23, y=29) | ||
self.assertEqual(dims['x1y23'], (23, 1, 29, 2, 3)) | ||
|
||
@parameterized.named_parameters([ | ||
('scalar', 'a', 7, TypeError, r'value must be sized'), | ||
('iterator', 'a', (x for x in [7]), TypeError, r'value must be sized'), | ||
('len_mismatch', 'ab', (7, 11, 13), ValueError, r'different length'), | ||
('non_integer_size', 'a', (7.001,), | ||
TypeError, r'cannot be interpreted as a python int'), | ||
('bad_key_type', 13, (7,), TypeError, r'key must be a string'), | ||
('bad_key_string', '@%^#', (7, 11, 13, 17), KeyError, r'\@'), | ||
]) | ||
def test_set_exception(self, k, v, e, m): | ||
dims = dimensions.Dimensions(x=23, y=29) | ||
with self.assertRaisesRegex(e, m): | ||
dims[k] = v | ||
|
||
@parameterized.named_parameters([ | ||
('bad_key_type', 13, TypeError, r'key must be a string'), | ||
('bad_key_string', '@%^#', KeyError, r'\@'), | ||
]) | ||
def test_get_exception(self, k, e, m): | ||
dims = dimensions.Dimensions(x=23, y=29) | ||
with self.assertRaisesRegex(e, m): | ||
dims[k] # pylint: disable=pointless-statement | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters