Skip to content

Commit

Permalink
Add flattening functionality to chex.Dimensions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617479139
  • Loading branch information
KristianHolsheimer authored and ChexDev committed Mar 20, 2024
1 parent 0855af7 commit bf80d37
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 7 deletions.
47 changes: 43 additions & 4 deletions chex/_src/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ class Dimensions:
>>> dims.size('BT') # Same as prod(dims['BT']).
15
Similarly, you can flatten axes together by wrapping them in parentheses:
.. code::
>>> dims['(BT)N']
(15, 7)
You can set a wildcard dimension, cf. :func:`chex.assert_shape`:
.. code::
Expand Down Expand Up @@ -118,7 +125,6 @@ class Dimensions:
>>> dims['M']
(7,)
"""
# Tell static type checker not to worry about attribute errors.
_HAS_DYNAMIC_ATTRIBUTES = True
Expand All @@ -129,15 +135,48 @@ def __init__(self, **dim_sizes) -> None:

def size(self, key: str) -> int:
  """Returns the flat size of a given named shape, i.e. prod(shape).

  Args:
    key: Named shape to look up via ``self[key]``, e.g. ``'BT'``.

  Returns:
    The product of all dimension sizes in the looked-up shape. An empty shape
    yields 1.

  Raises:
    ValueError: If any dimension in the shape is a wildcard (``None``) or is
      not strictly positive.
  """
  shape = self[key]
  # A flat size is only well defined when every dimension is a known, strictly
  # positive integer, so wildcards (None) and sizes <= 0 are rejected up front.
  if any(dim_size is None or dim_size <= 0 for dim_size in shape):
    raise ValueError(
        f"cannot take product of shape '{key}' = {shape}, "
        'because it contains non-positive sized dimensions'
    )
  return math.prod(shape)

def __getitem__(self, key: str) -> Shape:
  """Returns the shape described by a named-shape string.

  Every character of `key` outside parentheses is resolved to one dimension
  via `self._getdim`. Characters wrapped in a pair of parentheses are
  flattened into a single dimension whose size is `self.size` of the group,
  e.g. ``dims['(BT)N']`` has two entries.

  Args:
    key: Named shape, optionally containing non-nested parenthesised groups.

  Returns:
    A tuple with one entry per ungrouped character and one entry per
    parenthesised group, in the order they appear in `key`.

  Raises:
    ValueError: If parentheses are nested, unmatched or empty, or if a group
      cannot be flattened (raised by `self.size`).
  """
  self._validate_key(key)
  shape = []
  open_parentheses = False
  dims_to_flatten = ''
  for dim in key:
    # Signal to start accumulating `dims_to_flatten`.
    if dim == '(':
      if open_parentheses:
        raise ValueError(f"nested parentheses are unsupported; got: '{key}'")
      open_parentheses = True

    # Signal to collect accumulated `dims_to_flatten`.
    elif dim == ')':
      if not open_parentheses:
        raise ValueError(f"unmatched parentheses in named shape: '{key}'")
      if not dims_to_flatten:
        raise ValueError(f"found empty parentheses in named shape: '{key}'")
      shape.append(self.size(dims_to_flatten))
      # Reset so that a later group starts from a clean slate.
      open_parentheses = False
      dims_to_flatten = ''

    # Accumulate `dims_to_flatten`.
    elif open_parentheses:
      dims_to_flatten += dim

    # The typical (non-flattening) case.
    else:
      shape.append(self._getdim(dim))

  # A group that was opened but never closed, e.g. '(ab'.
  if open_parentheses:
    raise ValueError(f"unmatched parentheses in named shape: '{key}'")
  return tuple(shape)

def __setitem__(self, key: str, value: Collection[Optional[int]]) -> None:
self._validate_key(key)
Expand Down
40 changes: 37 additions & 3 deletions chex/_src/dimensions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_get_wildcard(self):
self.assertEqual(dims['x*y**'], (23, None, 29, None, None))
asserts.assert_shape(np.empty((23, 1, 29, 2, 3)), dims['x*y**'])
with self.assertRaisesRegex(KeyError, r'\_'):
dims['xy_'] # pylint: disable=pointless-statement
_ = dims['xy_']

def test_get_literals(self):
dims = dimensions.Dimensions(x=23, y=29)
Expand All @@ -89,7 +89,7 @@ def test_set_exception(self, k, v, e, m):
def test_get_exception(self, k, e, m):
dims = dimensions.Dimensions(x=23, y=29)
with self.assertRaisesRegex(e, m):
dims[k] # pylint: disable=pointless-statement
_ = dims[k]

@parameterized.named_parameters([
('scalar', '', (), 1),
Expand All @@ -102,12 +102,46 @@ def test_size_ok(self, names, shape, expected_size):
@parameterized.named_parameters([
    ('named', 'ab'),
    ('asterisk', 'a*'),
    ('zero', 'a0'),
    ('negative', 'ac'),
])
def test_size_fail_wildcard(self, names):
  """Checks that `size` raises for wildcard, zero and negative dimensions."""
  # `b` is a wildcard (None) and `c` is negative. The '*' and '0' characters
  # are not named here; presumably they resolve to a wildcard and a literal
  # zero respectively -- confirm against `Dimensions._getdim`.
  dims = dimensions.Dimensions(a=3, b=None, c=-1)
  with self.assertRaisesRegex(ValueError, r'cannot take product of shape'):
    dims.size(names)

@parameterized.named_parameters([
    ('trivial_start', '(a)bc', (3, 5, 7)),
    ('trivial_mid', 'a(b)c', (3, 5, 7)),
    ('trivial_end', 'ab(c)', (3, 5, 7)),
    ('start', '(ab)cd', (15, 7, 11)),
    ('mid', 'a(bc)d', (3, 35, 11)),
    ('end', 'ab(cd)', (3, 5, 77)),
    ('multiple', '(ab)(cd)', (15, 77)),
    ('all', '(abc)', (105,)),
])
def test_flatten_ok(self, named_shape, expected_shape):
  """Checks that parenthesised groups collapse into their product."""
  sizes = dict(a=3, b=5, c=7, d=11)
  flattened = dimensions.Dimensions(**sizes)[named_shape]
  self.assertEqual(flattened, expected_shape)

@parameterized.named_parameters([
    ('unmatched_open', '(ab', r'unmatched parentheses in named shape'),
    ('unmatched_closed', 'a)b', r'unmatched parentheses in named shape'),
    ('nested', '(a(bc))', r'nested parentheses are unsupported'),
    ('wildcard_named', 'a(bx)', r'cannot take product of shape'),
    ('wildcard_asterisk', '(a*)b', r'cannot take product of shape'),
    ('zero_sized_dim', '(a0)b', r'cannot take product of shape'),
    ('neg_sized_dim', '(ay)b', r'cannot take product of shape'),
    ('empty_start', '()ab', r'found empty parentheses in named shape'),
    ('empty_mid', 'a()b', r'found empty parentheses in named shape'),
    ('empty_end', 'ab()', r'found empty parentheses in named shape'),
    ('empty_solo', '()', r'found empty parentheses in named shape'),
])
def test_flatten_fail(self, named_shape, error_message):
  """Checks that malformed or non-flattenable groups raise ValueError."""
  # `x` is a wildcard and `y` is negative, so groups containing them cannot be
  # flattened into a product.
  named_dims = dimensions.Dimensions(a=3, b=5, x=None, y=-1)
  with self.assertRaisesRegex(ValueError, error_message):
    # The lookup itself is what must raise; its result is discarded.
    _ = named_dims[named_shape]


# Run the test suite only when this module is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 comments on commit bf80d37

Please sign in to comment.