From 4eb886f37932a64e58f2215d4740ba326f6855ea Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 5 Jun 2024 08:52:23 +0100 Subject: [PATCH] Fix `unify_chunks` to return regular chunks in all cases. (#470) --- cubed/core/ops.py | 36 ++++++++++- cubed/tests/test_array_api.py | 9 ++- cubed/tests/test_ops.py | 108 ++++++++++++++++++++++++++++++++ cubed/vendor/dask/array/core.py | 68 -------------------- 4 files changed, 149 insertions(+), 72 deletions(-) create mode 100644 cubed/tests/test_ops.py diff --git a/cubed/core/ops.py b/cubed/core/ops.py index ae7deab9..e5da5d5f 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -10,7 +10,7 @@ import numpy as np import zarr -from tlz import concat, partition +from tlz import concat, first, partition from toolz import accumulate, map from zarr.indexing import ( IntDimIndexer, @@ -37,7 +37,7 @@ offset_to_block_id, to_chunksize, ) -from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks +from cubed.vendor.dask.array.core import normalize_chunks from cubed.vendor.dask.array.utils import validate_axis from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product from cubed.vendor.dask.utils import has_keyword @@ -1383,7 +1383,9 @@ def unify_chunks(*args: "Array", **kwargs): else: nameinds.append((a, ind)) - chunkss = broadcast_dimensions(nameinds, blockdim_dict, consolidate=common_blockdim) + chunkss = broadcast_dimensions( + nameinds, blockdim_dict, consolidate=smallest_blockdim + ) arrays = [] for a, i in arginds: @@ -1400,8 +1402,36 @@ def unify_chunks(*args: "Array", **kwargs): ) if chunks != a.chunks and all(a.chunks): # this will raise if chunks are not regular + # but this should never happen with smallest_blockdim chunksize = to_chunksize(chunks) arrays.append(rechunk(a, chunksize)) else: arrays.append(a) return chunkss, arrays + + +def smallest_blockdim(blockdims): + """Find the smallest block dimensions from the list of block dimensions + + Unlike Dask's common_blockdim, this returns regular chunks (assuming + regular chunks are passed in). + """ + if not any(blockdims): + return () + non_trivial_dims = {d for d in blockdims if len(d) > 1} + if len(non_trivial_dims) == 1: + return first(non_trivial_dims) + if len(non_trivial_dims) == 0: + return max(blockdims, key=first) + + if len(set(map(sum, non_trivial_dims))) > 1: + raise ValueError("Chunks do not add up to same value", blockdims) + + # find dims with the smallest first chunk + m = -1 + out = None + for ntd in non_trivial_dims: + if m == -1 or ntd[0] < m: + m = ntd[0] + out = ntd + return out diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index e54c4843..49294671 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -169,7 +169,7 @@ def test_add(spec, any_executor): ) -def test_add_with_broadcast(spec, executor): +def test_add_different_chunks(spec, executor): a = xp.ones((10, 10), chunks=(10, 2), spec=spec) b = xp.ones((10, 10), chunks=(2, 10), spec=spec) c = xp.add(a, b) @@ -178,6 +178,13 @@ def test_add_with_broadcast(spec, executor): ) +def test_add_different_chunks_fail(spec, executor): + a = xp.ones((10,), chunks=(3,), spec=spec) + b = xp.ones((10,), chunks=(5,), spec=spec) + c = xp.add(a, b) + assert_array_equal(c.compute(executor=executor), np.ones((10,)) + np.ones((10,))) + + def test_equal(spec): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) b = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) diff --git a/cubed/tests/test_ops.py b/cubed/tests/test_ops.py new file mode 100644 index 00000000..ae51934b --- /dev/null +++ b/cubed/tests/test_ops.py @@ -0,0 +1,108 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import cubed.array_api as xp +from cubed.core.ops import smallest_blockdim, unify_chunks +from cubed.tests.utils import TaskCounter + + +def test_smallest_blockdim(): + assert smallest_blockdim([]) == () + assert smallest_blockdim([(5,), (5,)]) == (5,) + assert smallest_blockdim([(5,), (3, 2)]) == (3, 2) + assert smallest_blockdim([(5, 5), (3, 3, 3, 1)]) == (3, 3, 3, 1) + assert smallest_blockdim([(2, 1), (2, 1)]) == (2, 1) + assert smallest_blockdim([(2, 2, 1), (3, 2), (2, 2, 1)]) == (2, 2, 1) + + with pytest.raises(ValueError, match="Chunks do not add up to same value"): + smallest_blockdim([(2, 1), (2, 2)]) + + +@pytest.mark.parametrize( + "chunks_a, chunks_b, expected_chunksize", + [ + ((2,), (4,), (2,)), + ((4,), (2,), (2,)), + ((6,), (10,), (6,)), + ((10,), (10,), (10,)), + ((5,), (10,), (5,)), + ((3,), (5,), (3,)), + ((5,), (3,), (3,)), + ], +) +def test_unify_chunks_elemwise(chunks_a, chunks_b, expected_chunksize): + a = xp.ones((10,), chunks=chunks_a) + b = xp.ones((10,), chunks=chunks_b) + + _, arrays = unify_chunks(a, "i", b, "i") + for arr in arrays: + assert arr.chunksize == expected_chunksize + + c = xp.add(a, b) + assert_array_equal(c.compute(), np.ones((10,)) + np.ones((10,))) + + +@pytest.mark.parametrize( + "chunks_a, chunks_b, expected_chunksize", + [ + ((2, 2), (4, 4), (2, 2)), + ((2, 4), (4, 2), (2, 2)), + ((4, 2), (2, 4), (2, 2)), + ((3, 5), (5, 3), (3, 3)), + ((3, 10), (10, 3), (3, 3)), + ], +) +def test_unify_chunks_elemwise_2d(chunks_a, chunks_b, expected_chunksize): + a = xp.ones((10, 10), chunks=chunks_a) + b = xp.ones((10, 10), chunks=chunks_b) + + _, arrays = unify_chunks(a, "ij", b, "ij") + for arr in arrays: + assert arr.chunksize == expected_chunksize + + c = xp.add(a, b) + assert_array_equal(c.compute(), np.ones((10, 10)) + np.ones((10, 10))) + + +@pytest.mark.parametrize( + "chunks_a, chunks_b, expected_chunksize", + [ + ((2, 2), (4, 4), (2, 2)), + ((2, 4), (2, 4), (2, 2)), + ((4, 2), (4, 2), (2, 2)), + ((3, 5), (3, 5), (3, 3)), + ((3, 10), (3, 10), (3, 3)), + ], +) +def test_unify_chunks_blockwise_2d(chunks_a, chunks_b, expected_chunksize): + a = xp.ones((10, 10), chunks=chunks_a) + b = xp.ones((10, 10), chunks=chunks_b) + + _, arrays = unify_chunks(a, "ij", b, "ji") + for arr in arrays: + assert arr.chunksize == expected_chunksize + + c = xp.matmul(a, b) + assert_array_equal(c.compute(), np.matmul(np.ones((10, 10)), np.ones((10, 10)))) + + +def test_unify_chunks_broadcast_scalar(): + a = xp.ones((10,), chunks=(3,)) + b = a + 1 + assert_array_equal(b.compute(), np.ones((10,)) + 1) + + +def test_unify_chunks_broadcast_2d(): + a = xp.ones((10, 10), chunks=(3, 3)) + b = xp.ones((10,), chunks=(5,)) + c = xp.add(a, b) + + # the following checks that b is rechunked *before* broadcasting, to avoid materializing the full (broadcasted) array + task_counter = TaskCounter() + res = c.compute(callbacks=[task_counter]) + num_created_arrays = 2 # b rechunked, c + # 1 task for rechunk of b, 16 for addition operation + assert task_counter.value == num_created_arrays + 1 + 16 + + assert_array_equal(res, np.ones((10, 10)) + np.ones((10,))) diff --git a/cubed/vendor/dask/array/core.py b/cubed/vendor/dask/array/core.py index 280e3a28..f92c8dec 100644 --- a/cubed/vendor/dask/array/core.py +++ b/cubed/vendor/dask/array/core.py @@ -462,71 +462,3 @@ def _check_regular_chunks(chunkset): if chunks[-1] > chunks[0]: return False return True - - -def common_blockdim(blockdims): - """Find the common block dimensions from the list of block dimensions - - Currently only implements the simplest possible heuristic: the common - block-dimension is the only one that does not span fully span a dimension. - This is a conservative choice that allows us to avoid potentially very - expensive rechunking. - - Assumes that each element of the input block dimensions has all the same - sum (i.e., that they correspond to dimensions of the same size). - - Examples - -------- - >>> common_blockdim([(3,), (2, 1)]) - (2, 1) - >>> common_blockdim([(1, 2), (2, 1)]) - (1, 1, 1) - >>> common_blockdim([(2, 2), (3, 1)]) # doctest: +SKIP - Traceback (most recent call last): - ... - ValueError: Chunks do not align - """ - if not any(blockdims): - return () - non_trivial_dims = {d for d in blockdims if len(d) > 1} - if len(non_trivial_dims) == 1: - return first(non_trivial_dims) - if len(non_trivial_dims) == 0: - return max(blockdims, key=first) - - if np.isnan(sum(map(sum, blockdims))): - raise ValueError( - "Arrays' chunk sizes (%s) are unknown.\n\n" - "A possible solution:\n" - " x.compute_chunk_sizes()" % blockdims - ) - - if len(set(map(sum, non_trivial_dims))) > 1: - raise ValueError("Chunks do not add up to same value", blockdims) - - # We have multiple non-trivial chunks on this axis - # e.g. (5, 2) and (4, 3) - - # We create a single chunk tuple with the same total length - # that evenly divides both, e.g. (4, 1, 2) - - # To accomplish this we walk down all chunk tuples together, finding the - # smallest element, adding it to the output, and subtracting it from all - # other elements and remove the element itself. We stop once we have - # burned through all of the chunk tuples. - # For efficiency's sake we reverse the lists so that we can pop off the end - rchunks = [list(ntd)[::-1] for ntd in non_trivial_dims] - total = sum(first(non_trivial_dims)) - i = 0 - - out = [] - while i < total: - m = min(c[-1] for c in rchunks) - out.append(m) - for c in rchunks: - c[-1] -= m - if c[-1] == 0: - c.pop() - i += m - - return tuple(out)