Icechunk store (#633)
* Add option to return stores from the apply_blockwise function to support icechunk

* Add store_icechunk

* Add icechunk optional dependency

* Run icechunk tests in CI

* Typing improvements

* Don't include icechunk in coverage

* Update to Icechunk 0.1.0-alpha.10 and Zarr 3.0.0

* Move icechunk CI tests to their own workflow

* Fix mypy
tomwhite authored Jan 17, 2025
1 parent 5f75ba2 commit c487014
Showing 8 changed files with 236 additions and 4 deletions.
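
In brief, the new store_icechunk API computes one or more Cubed arrays and writes them into Zarr arrays backed by an Icechunk session, which the caller then commits. A minimal end-to-end sketch, assembled from cubed/tests/test_icechunk.py in this commit (the local path is illustrative):

import zarr

import cubed.array_api as xp
from icechunk import Repository, Storage

from cubed.icechunk import store_icechunk

# create an Icechunk repository backed by a local directory
storage = Storage.new_local_filesystem("/tmp/example-icechunk")
repo = Repository.create(storage=storage)
session = repo.writable_session("main")

# create the target Zarr array inside the Icechunk store
group = zarr.group(store=session.store, overwrite=True)
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
target = group.create_array("a", shape=a.shape, dtype=a.dtype, chunks=a.chunksize)

# compute all chunks, write them to the store, and commit the merged session
store_icechunk(session, sources=a, targets=target)
session.commit("write a")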
1 change: 1 addition & 0 deletions .coveragerc
@@ -3,6 +3,7 @@ omit =
*/tests/*
cubed/array_api/*
cubed/diagnostics/memray.py
cubed/icechunk.py
cubed/runtime/executors/beam.py
cubed/runtime/executors/coiled.py
cubed/runtime/executors/dask*.py
42 changes: 42 additions & 0 deletions .github/workflows/icechunk-tests.yml
@@ -0,0 +1,42 @@
name: Icechunk tests

on:
push:
branches:
- "main"
pull_request:
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.12"]

steps:
- name: Checkout source
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
architecture: x64

- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install -e '.[test]' 'icechunk'
- name: Run tests
run: |
pytest -v -k icechunk
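
The workflow can be reproduced locally with python -m pip install -e '.[test]' icechunk followed by pytest -v -k icechunk; the -k filter selects the Icechunk tests by name, so the rest of the suite is untouched.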
82 changes: 82 additions & 0 deletions cubed/icechunk.py
@@ -0,0 +1,82 @@
from typing import TYPE_CHECKING, Any, List, Sequence, Union

import zarr
from icechunk import Session

from cubed import compute
from cubed.core.array import CoreArray
from cubed.core.ops import blockwise
from cubed.runtime.types import Callback

if TYPE_CHECKING:
from cubed.array_api.array_object import Array


def store_icechunk(
session: Session,
*,
sources: Union["Array", Sequence["Array"]],
targets: List[zarr.Array],
executor=None,
**kwargs: Any,
) -> None:
if isinstance(sources, CoreArray):
sources = [sources]
targets = [targets] # type: ignore

if any(not isinstance(s, CoreArray) for s in sources):
raise ValueError("All sources must be cubed array objects")

if len(sources) != len(targets):
raise ValueError(
f"Different number of sources ({len(sources)}) and targets ({len(targets)})"
)

arrays = []
for source, target in zip(sources, targets):
identity = lambda a: a
ind = tuple(range(source.ndim))
array = blockwise(
identity,
ind,
source,
ind,
dtype=source.dtype,
align_arrays=False,
target_store=target,
return_writes_stores=True,
)
arrays.append(array)

# use a callback to merge icechunk sessions
store_callback = IcechunkStoreCallback()
# add to other callbacks the user may have set
callbacks = kwargs.pop("callbacks", [])
callbacks = [store_callback] + list(callbacks)

compute(
*arrays,
executor=executor,
_return_in_memory_array=False,
callbacks=callbacks,
**kwargs,
)

# merge back into the session passed into this function
merged_session = store_callback.session
session.merge(merged_session)


class IcechunkStoreCallback(Callback):
def on_compute_start(self, event):
self.session = None

def on_task_end(self, event):
result = event.result
if result is None:
return
for store in result:
if self.session is None:
self.session = store.session
else:
self.session.merge(store.session)
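
To recap the mechanics: with return_writes_stores=True each task returns the stores it wrote through, on_task_end merges the per-task sessions as results arrive, and store_icechunk merges the accumulated session back into the caller's, so a single session.commit() records every chunk write. Once committed, the data can be read back lazily with cubed.from_zarr, following test_from_zarr_icechunk below; a sketch continuing from the example above (requires import cubed):

# reopen the repository read-only and load the array lazily
repo = Repository.open(storage)
session = repo.readonly_session(branch="main")
arr = cubed.from_zarr(session.store, path="a")
result = arr.compute()  # materializes the data written earlier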
15 changes: 14 additions & 1 deletion cubed/primitive/blockwise.py
@@ -72,9 +72,12 @@ class BlockwiseSpec:
iterable_input_blocks: Tuple[bool, ...]
reads_map: Dict[str, CubedArrayProxy]
writes_list: List[CubedArrayProxy]
return_writes_stores: bool = False


def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None:
def apply_blockwise(
out_coords: List[int], *, config: BlockwiseSpec
) -> Optional[List[T_Store]]:
"""Stage function for blockwise."""
# lithops needs params to be lists not tuples, so convert back
out_coords_tuple = tuple(out_coords)
@@ -100,6 +103,10 @@ def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None:
result = backend_array_to_numpy_array(result)
config.writes_list[i].open()[out_chunk_key] = result

if config.return_writes_stores:
return [write_proxy.open().store for write_proxy in config.writes_list]
return None


def get_results_in_different_scope(out_coords: List[int], *, config: BlockwiseSpec):
# wrap function call in a function so that args go out of scope (and free memory) as soon as results are returned
@@ -267,6 +274,7 @@ def general_blockwise(
function_nargs: Optional[int] = None,
num_input_blocks: Optional[Tuple[int, ...]] = None,
iterable_input_blocks: Optional[Tuple[bool, ...]] = None,
return_writes_stores: bool = False,
**kwargs,
) -> PrimitiveOperation:
"""A more general form of ``blockwise`` that uses a function to specify the block
@@ -365,6 +373,7 @@
iterable_input_blocks,
read_proxies,
write_proxies,
return_writes_stores,
)

# calculate projected memory
@@ -534,6 +543,7 @@ def fused_func(*args):
function_nargs = pipeline1.config.function_nargs
read_proxies = pipeline1.config.reads_map
write_proxies = pipeline2.config.writes_list
return_writes_stores = pipeline2.config.return_writes_stores
num_input_blocks = tuple(
n * pipeline2.config.num_input_blocks[0]
for n in pipeline1.config.num_input_blocks
@@ -547,6 +557,7 @@
iterable_input_blocks,
read_proxies,
write_proxies,
return_writes_stores,
)

source_array_names = primitive_op1.source_array_names
@@ -679,6 +690,7 @@ def fuse_blockwise_specs(
for bws in predecessor_bw_specs:
read_proxies.update(bws.reads_map)
write_proxies = bw_spec.writes_list
return_writes_stores = bw_spec.return_writes_stores
return BlockwiseSpec(
fused_key_func,
fused_func,
@@ -687,6 +699,7 @@
fused_iterable_input_blocks,
read_proxies,
write_proxies,
return_writes_stores,
)


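Two details above keep this working under optimization: a fused pipeline takes return_writes_stores from the second (writing) pipeline, and fuse_blockwise_specs copies it from the final spec, so operation fusion does not drop the stores the Icechunk callback needs.
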
6 changes: 3 additions & 3 deletions cubed/runtime/types.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from typing import Any, Iterable, Optional
from typing import Any, Callable, Iterable, Optional

from networkx import MultiDiGraph

from cubed.vendor.rechunker.types import Config, StageFunction
from cubed.vendor.rechunker.types import Config


class DagExecutor:
@@ -22,7 +22,7 @@ def execute_dag(self, dag: MultiDiGraph, **kwargs) -> None:
class CubedPipeline:
"""Generalisation of rechunker ``Pipeline`` with extra attributes."""

function: StageFunction
function: Callable[..., Any]
name: str
mappable: Iterable
config: Config
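This widening follows from the blockwise change: apply_blockwise may now return a list of stores, which the narrower StageFunction type from the vendored rechunker cannot express, so CubedPipeline.function becomes a general Callable[..., Any].
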
91 changes: 91 additions & 0 deletions cubed/tests/test_icechunk.py
@@ -0,0 +1,91 @@
from typing import Iterable

import numpy as np
import pytest
import zarr
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp
import cubed.random
from cubed.tests.utils import MAIN_EXECUTORS

icechunk = pytest.importorskip("icechunk")

from icechunk import Repository, Storage

from cubed.icechunk import store_icechunk


@pytest.fixture(
scope="module",
params=MAIN_EXECUTORS,
ids=[executor.name for executor in MAIN_EXECUTORS],
)
def executor(request):
return request.param


@pytest.fixture(scope="function")
def icechunk_storage(tmpdir) -> "Storage":
return Storage.new_local_filesystem(str(tmpdir))


def create_icechunk(a, icechunk_storage, /, *, dtype=None, chunks=None):
# from dask.asarray
if not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
a = np.asarray(a, dtype=dtype)
if dtype is None:
dtype = a.dtype

repo = Repository.create(storage=icechunk_storage)
session = repo.writable_session("main")
store = session.store

group = zarr.group(store=store, overwrite=True)
arr = group.create_array("a", shape=a.shape, dtype=dtype, chunks=chunks)

arr[...] = a

session.commit("commit 1")


def test_from_zarr_icechunk(icechunk_storage, executor):
create_icechunk(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
icechunk_storage,
chunks=(2, 2),
)

repo = Repository.open(icechunk_storage)
session = repo.readonly_session(branch="main")
store = session.store

a = cubed.from_zarr(store, path="a")
assert_array_equal(
a.compute(executor=executor), np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
)


def test_store_icechunk(icechunk_storage, executor):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))

repo = Repository.create(storage=icechunk_storage)
session = repo.writable_session("main")
store = session.store

group = zarr.group(store=store, overwrite=True)
target = group.create_array("a", shape=a.shape, dtype=a.dtype, chunks=a.chunksize)
store_icechunk(session, sources=a, targets=target, executor=executor)
session.commit("commit 1")

# reopen store and check contents of array
repo = Repository.open(icechunk_storage)
session = repo.readonly_session(branch="main")
store = session.store

group = zarr.open_group(store=store, mode="r")
assert_array_equal(
cubed.from_array(group["a"])[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
)
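
Because the module calls pytest.importorskip("icechunk") at import time, these tests are skipped when the optional dependency is not installed, leaving the main suite unaffected; in CI they run in the dedicated workflow above.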
1 change: 1 addition & 0 deletions pyproject.toml
@@ -53,6 +53,7 @@ diagnostics = [
beam = ["apache-beam", "gcsfs"]
dask = ["dask < 2024.12.0"]
dask-distributed = ["distributed < 2024.12.0"]
icechunk = ["icechunk"]
lithops = ["lithops[aws] >= 2.7.0"]
lithops-aws = [
"cubed[diagnostics]",
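With this extra in place, the integration can be installed with pip install 'cubed[icechunk]'.
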
2 changes: 2 additions & 0 deletions setup.cfg
@@ -36,6 +36,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-fsspec.*]
ignore_missing_imports = True
[mypy-icechunk.*]
ignore_missing_imports = True
[mypy-lithops.*]
ignore_missing_imports = True
[mypy-IPython.*]