Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for check_array argument of concatenate_arrays #312

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions kerchunk/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,32 +571,66 @@ def concatenate_arrays(
else:
path = "/".join(path.rstrip(".").rstrip("/").split(".")) + "/"

def _replace(l: list, i: int, v) -> list:
l = l.copy()
l[i] = v
return l

n_files = len(files)

chunks_offset = 0
for i, fn in enumerate(files):
fs = fsspec.filesystem("reference", fo=fn, **(storage_options or {}))
zdata = ujson.load(fs.open(f"{path}.zarray"))
zarray = ujson.load(fs.open(f"{path}.zarray"))
shape = zarray["shape"]
chunks = zarray["chunks"]
n_chunks, rem = divmod(shape[axis], chunks[axis])
n_chunks += rem > 0

if i == 0:
shape = zdata["shape"]
chunks = zdata["chunks"]
chunks_per_file = int(shape[axis] / chunks[axis])
shape[axis] *= len(files)
zdata["shape"] = shape
out[f"{path}.zarray"] = ujson.dumps(zdata)
base_shape = _replace(shape, axis, None)
base_chunks = chunks
# result_* are modified in-place
result_zarray = zarray
result_shape = shape
for name in [".zgroup", ".zattrs", f"{path}.zattrs"]:
if name in fs.references:
out[name] = fs.references[name]
else:
result_shape[axis] += shape[axis]

# Safety checks
if check_arrays:
if shape != zdata["shape"]:
raise ValueError(f"Incompatible array shapes at {fn}")
if chunks != zdata["chunks"]:
raise ValueError(f"Incompatible array chunks at {fn}")
if _replace(shape, axis, None) != base_shape:
expected_shape = (
f"[{', '.join(map(str, _replace(base_shape, axis, '*')))}]"
)
raise ValueError(
f"Incompatible array shape at index {i}. Expected {expected_shape}, got {shape}."
)
if chunks != base_chunks:
raise ValueError(
f"Incompatible array chunks at index {i}. Expected {base_chunks}, got {chunks}."
)
if i < (n_files - 1) and rem != 0:
raise ValueError(
f"Array at index {i} has irregular chunking at its boundary. "
"This is only allowed for the final array."
)

# Referencing the offset chunks
for key in fs.find(""):
if key.startswith(f"{path}.z") or not key.startswith(path):
continue
parts = key.lstrip(path).split(key_seperator)
parts[axis] = str(int(parts[axis]) + i * chunks_per_file)
parts[axis] = str(int(parts[axis]) + chunks_offset)
key2 = path + key_seperator.join(parts)
out[key2] = fs.references[key]

chunks_offset += n_chunks

out[f"{path}.zarray"] = ujson.dumps(result_zarray)

return consolidate(out)


Expand Down
90 changes: 72 additions & 18 deletions kerchunk/tests/test_combine_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,61 @@
import kerchunk.zarr


def test_success(tmpdir):
fn1 = f"{tmpdir}/out1.zarr"
fn2 = f"{tmpdir}/out2.zarr"
x1 = np.arange(10)
x2 = np.arange(10, 20)
g = zarr.open(fn1)
g.create_dataset("x", data=x1, chunks=(2,))
g = zarr.open(fn2)
g.create_dataset("x", data=x2, chunks=(2,))

ref1 = kerchunk.zarr.single_zarr(fn1, inline=0)
ref2 = kerchunk.zarr.single_zarr(fn2, inline=0)
@pytest.mark.parametrize(
"arrays,chunks,axis",
[
(
[np.arange(10), np.arange(10, 20)],
(2,),
0,
),
(
[np.arange(12), np.arange(12, 36), np.arange(36, 42)],
(6,),
0,
),
(
# Terminal chunk does not need to be filled
[np.arange(5), np.arange(5, 10), np.arange(10, 17)],
(5,),
0,
),
(
[
np.broadcast_to(np.arange(6), (10, 6)),
np.broadcast_to(np.arange(7, 10), (10, 3)),
],
(10, 3),
1,
),
(
[
np.broadcast_to(np.arange(6), (10, 6)).T,
np.broadcast_to(np.arange(7, 10), (10, 3)).T,
],
(3, 10),
0,
),
],
)
def test_success(tmpdir, arrays, chunks, axis):
fns = []
refs = []
for i, x in enumerate(arrays):
fn = f"{tmpdir}/out{i}.zarr"
g = zarr.open(fn)
g.create_dataset("x", data=x, chunks=chunks)
fns.append(fn)
ref = kerchunk.zarr.single_zarr(fn, inline=0)
refs.append(ref)

out = kerchunk.combine.concatenate_arrays([ref1, ref2], path="x")
out = kerchunk.combine.concatenate_arrays(
refs, axis=axis, path="x", check_arrays=True
)

mapper = fsspec.get_mapper("reference://", fo=out)
g = zarr.open(mapper)
assert (g.x[:] == np.concatenate([x1, x2])).all()
assert (g.x[:] == np.concatenate(arrays, axis=axis)).all()


def test_fail_chunks(tmpdir):
Expand All @@ -41,15 +78,15 @@ def test_fail_chunks(tmpdir):
ref1 = kerchunk.zarr.single_zarr(fn1, inline=0)
ref2 = kerchunk.zarr.single_zarr(fn2, inline=0)

with pytest.raises(ValueError):
with pytest.raises(ValueError, match=r"Incompatible array chunks at index 1.*"):
kerchunk.combine.concatenate_arrays([ref1, ref2], path="x", check_arrays=True)


def test_fail_shape(tmpdir):
fn1 = f"{tmpdir}/out1.zarr"
fn2 = f"{tmpdir}/out2.zarr"
x1 = np.arange(10).reshape(5, 2)
x2 = np.arange(10, 20)
x1 = np.arange(12).reshape(6, 2)
x2 = np.arange(12, 24)
g = zarr.open(fn1)
g.create_dataset("x", data=x1, chunks=(2,))
g = zarr.open(fn2)
Expand All @@ -58,5 +95,22 @@ def test_fail_shape(tmpdir):
ref1 = kerchunk.zarr.single_zarr(fn1, inline=0)
ref2 = kerchunk.zarr.single_zarr(fn2, inline=0)

with pytest.raises(ValueError):
with pytest.raises(ValueError, match=r"Incompatible array shape at index 1.*"):
kerchunk.combine.concatenate_arrays([ref1, ref2], path="x", check_arrays=True)


def test_fail_irregular_chunk_boundaries(tmpdir):
    """A non-final array with a partial trailing chunk must be rejected.

    The first array has 10 elements stored in chunks of 4, so its last chunk
    is only partially filled. Because it is not the final array in the
    concatenation, ``concatenate_arrays(..., check_arrays=True)`` is expected
    to raise ``ValueError`` naming index 0.
    """
    fn1 = f"{tmpdir}/out1.zarr"
    fn2 = f"{tmpdir}/out2.zarr"
    # 10 is not a multiple of the chunk size (4): irregular boundary in file 0.
    x1 = np.arange(10)
    x2 = np.arange(10, 24)
    g = zarr.open(fn1)
    g.create_dataset("x", data=x1, chunks=(4,))
    g = zarr.open(fn2)
    g.create_dataset("x", data=x2, chunks=(4,))

    ref1 = kerchunk.zarr.single_zarr(fn1, inline=0)
    ref2 = kerchunk.zarr.single_zarr(fn2, inline=0)

    with pytest.raises(ValueError, match=r"Array at index 0 has irregular chunking.*"):
        kerchunk.combine.concatenate_arrays([ref1, ref2], path="x", check_arrays=True)