Skip to content

Commit

Permalink
Add chex.Dimensions utility for readable shape asserts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 456255200
  • Loading branch information
KristianHolsheimer authored and ChexDev committed Jun 22, 2022
1 parent 6723dee commit 41eb7f6
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 1 deletion.
1 change: 1 addition & 0 deletions chex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from chex._src.dataclass import dataclass
from chex._src.dataclass import mappable_dataclass
from chex._src.dataclass import register_dataclass_type_with_jax_tree_util
from chex._src.dimensions import Dimensions
from chex._src.fake import fake_jit
from chex._src.fake import fake_pmap
from chex._src.fake import fake_pmap_and_jit
Expand Down
202 changes: 202 additions & 0 deletions chex/_src/dimensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to hold expected dimension sizes."""

import re
from typing import Any, Collection, Dict, Optional, Sized, Tuple


# A shape tuple with one entry per axis; ``None`` is a wildcard entry (see the
# wildcard examples in the ``Dimensions`` docstring).
Shape = Tuple[Optional[int], ...]


class Dimensions:
  """A lightweight utility that maps strings to shape tuples.

  The most basic usage is:

  .. code::

    >>> dims = chex.Dimensions(B=3, T=5, N=7)  # You can specify any letters.
    >>> dims['NBT']
    (7, 3, 5)

  This is useful when dealing with many differently shaped arrays. For instance,
  let's check the shape of this array:

  .. code::

    >>> x = jnp.array([[2, 0, 5, 6, 3],
    ...                [5, 4, 4, 3, 3],
    ...                [0, 0, 5, 2, 0]])
    >>> chex.assert_shape(x, dims['BT'])

  The dimension sizes can be accessed directly, e.g. :code:`dims.N == 7`. This
  can be useful in many applications. For instance, let's one-hot encode our
  array.

  .. code::

    >>> y = jax.nn.one_hot(x, dims.N)
    >>> chex.assert_shape(y, dims['BTN'])

  You can also store the shape of a given array in :code:`dims`, e.g.

  .. code::

    >>> z = jnp.array([[0, 6, 0, 2],
    ...                [4, 2, 2, 4]])
    >>> dims['XY'] = z.shape
    >>> dims
    Dimensions(B=3, N=7, T=5, X=2, Y=4)

  You can set a wildcard dimension, cf. :func:`chex.assert_shape`:

  .. code::

    >>> dims.W = None
    >>> dims['BTW']
    (3, 5, None)

  Or you can use the wildcard character `'*'` directly:

  .. code::

    >>> dims['BT*']
    (3, 5, None)

  Single digits are interpreted as literal integers. Note that this notation
  is limited to single-digit literals.

  .. code::

    >>> dims['BT123']
    (3, 5, 1, 2, 3)

  Support for single digits was mainly included to accommodate dummy axes
  introduced for consistent broadcasting. For instance, instead of using
  :func:`jnp.expand_dims <jax.numpy.expand_dims>` you could do the following:

  .. code::

    >>> w = y * x  # Cannot broadcast (3, 5, 7) with (3, 5)
    Traceback (most recent call last):
        ...
    ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5))
    >>> w = y * x.reshape(dims['BT1'])
    >>> chex.assert_shape(w, dims['BTN'])

  Sometimes you only care about some array dimensions but not all. You can use
  an underscore to ignore an axis, e.g.

  .. code::

    >>> chex.assert_rank(y, 3)
    >>> dims['__M'] = y.shape  # Skip the first two axes.

  Finally note that a single-character key returns a tuple of length one.

  .. code::

    >>> dims['M']
    (7,)
  """
  # Dimension sizes live in dynamically created instance attributes (e.g.
  # ``dims.B``). NOTE(review): this flag is presumably the marker that static
  # analysers such as pytype look for to permit such attributes -- confirm.
  HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, **dim_sizes):
    """Creates the mapping from keyword arguments such as ``B=3``.

    Args:
      **dim_sizes: Single-letter dimension names mapped to their sizes. A size
        is either integer-valued or ``None`` (wildcard).

    Raises:
      KeyError: If a name is not a single ASCII letter (or ``'_'``).
      TypeError: If a size is neither ``None`` nor integer-valued.
    """
    for dim, size in dim_sizes.items():
      self._setdim(dim, size)

  def __getitem__(self, key: str) -> Shape:
    """Returns the shape tuple spelled out by ``key``, one entry per character.

    Args:
      key: A string in which letters refer to stored dimensions, ``'*'`` yields
        ``None`` and a single digit yields that literal integer.

    Returns:
      A tuple with ``len(key)`` entries.

    Raises:
      TypeError: If ``key`` is not a string.
      KeyError: If a character does not refer to a stored dimension.
    """
    self._validate_key(key)
    return tuple(self._getdim(dim) for dim in key)

  def __setitem__(self, key: str, value: Collection[Optional[int]]):
    """Stores the sizes in ``value`` under the dimension names in ``key``.

    The update is all-or-nothing: if any name or size is rejected, none of the
    dimensions is modified.

    Args:
      key: A string of dimension names; ``'_'`` skips the corresponding axis.
      value: A sized iterable with exactly one size per character of ``key``.

    Raises:
      TypeError: If ``key`` is not a string, ``value`` is not sized, or a size
        is neither ``None`` nor integer-valued.
      ValueError: If ``key`` and ``value`` have different lengths.
      KeyError: If a name is not a single ASCII letter (or ``'_'``).
    """
    self._validate_key(key)
    self._validate_value(value)
    if len(key) != len(value):
      raise ValueError(
          f'key string {repr(key)} and shape {tuple(value)} '
          'have different lengths')
    # Check every name and convert every size *before* touching any attribute,
    # so that a rejected entry cannot leave a partially-updated object behind.
    pairs = [(dim, size) for dim, size in zip(key, value) if dim != '_']
    for dim, _ in pairs:
      self._validate_dim(dim)
    updates = [(dim, _optional_int(size)) for dim, size in pairs]
    for dim, size in updates:
      self._setdim(dim, size)

  def __delitem__(self, key: str):
    """Removes every dimension named in ``key``.

    The removal is all-or-nothing: if any name is rejected or not stored (or is
    listed twice), no dimension is removed.

    Args:
      key: A string of dimension names; ``'_'`` entries are ignored.

    Raises:
      TypeError: If ``key`` is not a string.
      KeyError: If a name is invalid, not stored, or repeated within ``key``.
    """
    self._validate_key(key)
    # First pass only checks; deletion happens once all names are known good.
    pending = []
    for dim in key:
      if dim == '_':  # Skip.
        continue
      self._validate_dim(dim)
      if dim in pending or dim not in vars(self):
        raise KeyError(dim)
      pending.append(dim)
    for dim in pending:
      self._deldim(dim)

  def __repr__(self) -> str:
    """Returns e.g. ``Dimensions(B=3, N=7)`` with names in sorted order."""
    args = ', '.join(f'{k}={v}' for k, v in sorted(self._asdict().items()))
    return f'{type(self).__name__}({args})'

  def _asdict(self) -> Dict[str, Optional[int]]:
    """Returns the stored dimensions as a ``{name: size}`` dictionary."""
    # Only single-letter instance attributes count as dimensions.
    return {k: v for k, v in self.__dict__.items()
            if re.fullmatch(r'[a-zA-Z]', k)}

  def _getdim(self, dim: str) -> Optional[int]:
    """Resolves one key character to a size, raising ``KeyError`` if unknown."""
    if dim == '*':
      return None
    if re.fullmatch(r'[0-9]', dim):
      return int(dim)
    try:
      return getattr(self, dim)
    except AttributeError as e:
      raise KeyError(dim) from e

  def _setdim(self, dim: str, size: Optional[int]):
    """Validates and stores a single dimension; ``'_'`` is a no-op."""
    if dim == '_':  # Skip.
      return
    self._validate_dim(dim)
    setattr(self, dim, _optional_int(size))

  def _deldim(self, dim: str):
    """Removes a single dimension; ``'_'`` is a no-op.

    Raises:
      KeyError: If ``dim`` is an invalid name or is not currently stored.
    """
    if dim == '_':  # Skip.
      return
    self._validate_dim(dim)
    try:
      delattr(self, dim)
    except AttributeError as e:
      raise KeyError(dim) from e

  def _validate_key(self, key: Any):
    """Raises ``TypeError`` unless ``key`` is a string."""
    if not isinstance(key, str):
      raise TypeError(f'key must be a string; got: {type(key).__name__}')

  def _validate_value(self, value: Any):
    """Raises ``TypeError`` unless ``value`` supports ``len()``."""
    if not isinstance(value, Sized):
      raise TypeError(
          'value must be sized, i.e. an object with a well-defined len(value); '
          f'got object of type: {type(value).__name__}')

  def _validate_dim(self, dim: Any):
    """Checks that ``dim`` is a single ASCII letter.

    Raises:
      TypeError: If ``dim`` is not a string.
      KeyError: If ``dim`` is a string other than a single ASCII letter.
    """
    if not isinstance(dim, str):
      raise TypeError(
          f'dimension name must be a string; got: {type(dim).__name__}')
    if not re.fullmatch(r'[a-zA-Z]', dim):
      raise KeyError(
          'dimension names may only contain letters (or \'_\' to skip); '
          f'got dimension name: {repr(dim)}')


def _optional_int(x: Any) -> Optional[int]:
if x is None:
return None
try:
i = int(x)
if x == i:
return i
except ValueError:
pass
raise TypeError(f'object cannot be interpreted as a python int: {repr(x)}')
96 changes: 96 additions & 0 deletions chex/_src/dimensions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``dimensions`` module."""

import doctest

from absl.testing import absltest
from absl.testing import parameterized
from chex._src import asserts
from chex._src import dimensions
import jax
import numpy as np


class _ChexModule:
  """Minimal stand-in for the ``chex`` namespace used by the doctests.

  Exposes only the names that the ``Dimensions`` docstring examples reference.
  """
  Dimensions = dimensions.Dimensions  # pylint: disable=invalid-name
  assert_rank = asserts.assert_rank
  assert_shape = asserts.assert_shape


class DimensionsTest(parameterized.TestCase):
  """Unit tests for ``dimensions.Dimensions``."""

  def test_docstring_examples(self):
    """Runs the class docstring as a doctest and fails on any mismatch."""
    # ``doctest.run_docstring_examples`` only prints failures and returns None,
    # so it can never fail a test. Drive the runner directly instead and assert
    # on its counters.
    finder = doctest.DocTestFinder(recurse=False)
    # Exception messages depend on the installed JAX version; only the
    # exception type is compared.
    runner = doctest.DocTestRunner(
        verbose=False, optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
    globs = {'chex': _ChexModule, 'jax': jax, 'jnp': jax.numpy}
    for test in finder.find(
        dimensions.Dimensions, name='Dimensions', globs=globs):
      runner.run(test)
    self.assertGreater(runner.tries, 0, 'no docstring examples were found')
    self.assertEqual(runner.failures, 0, 'docstring examples failed')

  @parameterized.named_parameters([
      ('scalar', '', (), ()),
      ('vector', 'a', (7,), (7,)),
      ('list', 'ab', [7, 11], (7, 11)),
      ('numpy_array', 'abc', np.array([7, 11, 13]), (7, 11, 13)),
      ('case_sensitive', 'aA', (7, 11), (7, 11)),
  ])
  def test_set_ok(self, k, v, shape):
    """Checks that valid assignments are stored and can be read back."""
    dims = dimensions.Dimensions(x=23, y=29)
    dims[k] = v
    asserts.assert_shape(np.empty((23, *shape, 29)), dims['x' + k + 'y'])

  def test_set_wildcard(self):
    """Checks that '_' skips axes on assignment and '*' is rejected."""
    dims = dimensions.Dimensions(x=23, y=29)
    dims['a_b__'] = (7, 11, 13, 17, 19)
    self.assertEqual(dims['xayb'], (23, 7, 29, 13))
    with self.assertRaisesRegex(KeyError, r'\*'):
      dims['ab*'] = (7, 11, 13)

  def test_get_wildcard(self):
    """Checks that '*' reads as ``None`` and '_' is rejected on lookup."""
    dims = dimensions.Dimensions(x=23, y=29)
    self.assertEqual(dims['x*y**'], (23, None, 29, None, None))
    asserts.assert_shape(np.empty((23, 1, 29, 2, 3)), dims['x*y**'])
    with self.assertRaisesRegex(KeyError, r'\_'):
      dims['xy_']  # pylint: disable=pointless-statement

  def test_get_literals(self):
    """Checks that single digits in a key are returned as literal sizes."""
    dims = dimensions.Dimensions(x=23, y=29)
    self.assertEqual(dims['x1y23'], (23, 1, 29, 2, 3))

  @parameterized.named_parameters([
      ('scalar', 'a', 7, TypeError, r'value must be sized'),
      ('iterator', 'a', (x for x in [7]), TypeError, r'value must be sized'),
      ('len_mismatch', 'ab', (7, 11, 13), ValueError, r'different length'),
      ('non_integer_size', 'a', (7.001,),
       TypeError, r'cannot be interpreted as a python int'),
      ('bad_key_type', 13, (7,), TypeError, r'key must be a string'),
      ('bad_key_string', '@%^#', (7, 11, 13, 17), KeyError, r'\@'),
  ])
  def test_set_exception(self, k, v, e, m):
    """Checks the exception type and message of invalid assignments."""
    dims = dimensions.Dimensions(x=23, y=29)
    with self.assertRaisesRegex(e, m):
      dims[k] = v

  @parameterized.named_parameters([
      ('bad_key_type', 13, TypeError, r'key must be a string'),
      ('bad_key_string', '@%^#', KeyError, r'\@'),
  ])
  def test_get_exception(self, k, e, m):
    """Checks the exception type and message of invalid lookups."""
    dims = dimensions.Dimensions(x=23, y=29)
    with self.assertRaisesRegex(e, m):
      dims[k]  # pylint: disable=pointless-statement


if __name__ == '__main__':
  # Run the tests via absl's test runner when executed as a script.
  absltest.main()
8 changes: 7 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Assertions
assert_trees_all_equal_shapes
assert_trees_all_equal_structs
assert_type
Dimensions
disable_asserts
enable_asserts
clear_trace_counter
Expand Down Expand Up @@ -87,7 +88,6 @@ Generic Assertions
.. autofunction:: assert_axis_dimension
.. autofunction:: assert_axis_dimension_comparator
.. autofunction:: assert_axis_dimension_gt

.. autofunction:: assert_axis_dimension_gteq
.. autofunction:: assert_axis_dimension_lt
.. autofunction:: assert_axis_dimension_lteq
Expand All @@ -111,6 +111,12 @@ Generic Assertions
.. autofunction:: assert_type


Shapes and Named Dimensions
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Dimensions


Utils
~~~~~

Expand Down

0 comments on commit 41eb7f6

Please sign in to comment.