diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py
index 0d5631a0c..567a98638 100644
--- a/src/anndata/_io/specs/methods.py
+++ b/src/anndata/_io/specs/methods.py
@@ -44,7 +44,7 @@ from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Callable, Iterator
     from os import PathLike
     from typing import Any, Literal
 
     from numpy import typing as npt
@@ -375,13 +375,12 @@ def write_list(
 # It's in the `AnnData.concatenate` docstring, but should we keep it?
 @_REGISTRY.register_write(H5Group, views.ArrayView, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(H5Group, np.ndarray, IOSpec("array", "0.2.0"))
-@_REGISTRY.register_write(H5Group, h5py.Dataset, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(H5Group, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, views.ArrayView, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, np.ndarray, IOSpec("array", "0.2.0"))
-@_REGISTRY.register_write(ZarrGroup, h5py.Dataset, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
+@_REGISTRY.register_write(ZarrGroup, H5Array, IOSpec("array", "0.2.0"))
 @zero_dim_array_as_scalar
 def write_basic(
     f: GroupStorageType,
@@ -395,6 +394,51 @@ def write_basic(
     f.create_dataset(k, data=elem, **dataset_kwargs)
 
 
+def _iter_chunks_for_copy(
+    elem: ArrayStorageType, dest: ArrayStorageType
+) -> Iterator[slice | tuple[slice, ...]]:
+    """
+    Returns an iterator of slices (or tuples of slices) for copying chunks from `elem` to `dest`.
+
+    * If `dest` has chunks, it will return the chunks of `dest`.
+    * If `dest` is not chunked, we write it in ~100MB chunks or 1000 rows, whichever is larger.
+    """
+    if dest.chunks and hasattr(dest, "iter_chunks"):
+        return dest.iter_chunks()
+    else:
+        itemsize = elem.dtype.itemsize
+        shape = elem.shape
+        # Number of elements to write
+        entry_chunk_size = 100 * 1024 * 1024 // itemsize
+        # Number of rows that works out to (budget divided by the row width, not the row count)
+        n_rows = max(entry_chunk_size // max(int(np.prod(shape[1:])), 1), 1000)
+        return (slice(i, min(i + n_rows, shape[0])) for i in range(0, shape[0], n_rows))
+
+
+@_REGISTRY.register_write(H5Group, H5Array, IOSpec("array", "0.2.0"))
+@_REGISTRY.register_write(H5Group, ZarrArray, IOSpec("array", "0.2.0"))
+def write_chunked_dense_array_to_group(
+    f: GroupStorageType,
+    k: str,
+    elem: ArrayStorageType,
+    *,
+    _writer: Writer,
+    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
+):
+    """Write to an h5py.Dataset in chunks.
+
+    `h5py.Group.create_dataset(..., data: h5py.Dataset)` will load all of `data` into memory
+    before writing. Instead, we will write in chunks to avoid this. We don't need to do this for
+    zarr since zarr handles this automatically.
+    """
+    dtype = dataset_kwargs.get("dtype", elem.dtype)
+    kwargs = {**dataset_kwargs, "dtype": dtype}
+    dest = f.create_dataset(k, shape=elem.shape, **kwargs)
+
+    for chunk in _iter_chunks_for_copy(elem, dest):
+        dest[chunk] = elem[chunk]
+
+
 _REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))(
     _to_cpu_mem_wrapper(write_basic)
 )
@@ -605,10 +649,14 @@
     # Allow resizing for hdf5
     if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
         dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)
-
-    g.create_dataset("data", data=value.data, **dataset_kwargs)
-    g.create_dataset("indices", data=value.indices, **dataset_kwargs)
-    g.create_dataset("indptr", data=value.indptr, dtype=indptr_dtype, **dataset_kwargs)
+    _writer.write_elem(g, "data", value.data, dataset_kwargs=dataset_kwargs)
+    _writer.write_elem(g, "indices", value.indices, dataset_kwargs=dataset_kwargs)
+    _writer.write_elem(
+        g,
+        "indptr",
+        value.indptr,
+        dataset_kwargs={"dtype": indptr_dtype, **dataset_kwargs},
+    )
 
 
 write_csr = partial(write_sparse_compressed, fmt="csr")
diff --git a/tests/test_io_dispatched.py b/tests/test_io_dispatched.py
index 0bbbf285a..c0d60c18d 100644
--- a/tests/test_io_dispatched.py
+++ b/tests/test_io_dispatched.py
@@ -175,7 +175,5 @@ def zarr_reader(func, elem_name: str, elem, iospec):
     write_dispatched(f, "/", adata, callback=zarr_writer)
     _ = read_dispatched(f, zarr_reader)
 
-    assert h5ad_write_keys == zarr_write_keys
-    assert h5ad_read_keys == zarr_read_keys
-
-    assert sorted(h5ad_write_keys) == sorted(h5ad_read_keys)
+    assert sorted(h5ad_write_keys) == sorted(zarr_write_keys)
+    assert sorted(h5ad_read_keys) == sorted(zarr_read_keys)
diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py
index 0f5bfb883..884ba9f8f 100644
--- a/tests/test_io_elementwise.py
+++ b/tests/test_io_elementwise.py
@@ -188,6 +188,18 @@ def create_sparse_store(
         pytest.param(
             pd.array([True, False, True, True]), "nullable-boolean", id="pd_arr_bool"
         ),
+        pytest.param(
+            zarr.ones((100, 100), chunks=(10, 10)),
+            "array",
+            id="zarr_dense_array",
+        ),
+        pytest.param(
+            create_dense_store(
+                h5py.File("test1.h5", mode="w", driver="core", backing_store=False)
+            )["X"],
+            "array",
+            id="h5_dense_array",
+        ),
         # pytest.param(bytes, b"some bytes", "bytes", id="py_bytes"),  # Does not work for zarr
         # TODO consider how specific encodings should be. Should we be fully describing the written type?
         # Currently the info we add is: "what you wouldn't be able to figure out yourself"