Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor out utility functions from to_zarr #9695

Merged
merged 5 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 13 additions & 70 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
_normalize_path,
)
from xarray.backends.locks import _get_scheduler
from xarray.backends.zarr import _zarr_v3
from xarray.core import indexing
from xarray.core.combine import (
_infer_concat_order_from_positions,
Expand Down Expand Up @@ -2131,73 +2130,33 @@ def to_zarr(

See `Dataset.to_zarr` for full API docs.
"""
from xarray.backends.zarr import _choose_default_mode, _get_mappers

# validate Dataset keys, DataArray names
_validate_dataset_names(dataset)

# Load empty arrays to avoid bug saving zero length dimensions (Issue #5741)
# TODO: delete when min dask>=2023.12.1
# https://github.com/dask/dask/pull/10506
for v in dataset.variables.values():
if v.size == 0:
v.load()

# expand str and path-like arguments
store = _normalize_path(store)
chunk_store = _normalize_path(chunk_store)

kwargs = {}
if storage_options is None:
mapper = store
chunk_mapper = chunk_store
else:
if not isinstance(store, str):
raise ValueError(
f"store must be a string to use storage_options. Got {type(store)}"
)

if _zarr_v3():
kwargs["storage_options"] = storage_options
mapper = store
chunk_mapper = chunk_store
else:
from fsspec import get_mapper

mapper = get_mapper(store, **storage_options)
if chunk_store is not None:
chunk_mapper = get_mapper(chunk_store, **storage_options)
else:
chunk_mapper = chunk_store

if encoding is None:
encoding = {}

if mode is None:
if append_dim is not None:
mode = "a"
elif region is not None:
mode = "r+"
else:
mode = "w-"

if mode not in ["a", "a-"] and append_dim is not None:
raise ValueError("cannot set append_dim unless mode='a' or mode=None")

if mode not in ["a", "a-", "r+"] and region is not None:
raise ValueError(
"cannot set region unless mode='a', mode='a-', mode='r+' or mode=None"
)

if mode not in ["w", "w-", "a", "a-", "r+"]:
raise ValueError(
"The only supported options for mode are 'w', "
f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}"
)

# validate Dataset keys, DataArray names
_validate_dataset_names(dataset)
kwargs, mapper, chunk_mapper = _get_mappers(
storage_options=storage_options, store=store, chunk_store=chunk_store
)
mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region)

if mode == "r+":
already_consolidated = consolidated
consolidate_on_close = False
else:
already_consolidated = False
consolidate_on_close = consolidated or consolidated is None

zstore = backends.ZarrStore.open_group(
store=mapper,
mode=mode,
Expand All @@ -2209,30 +2168,14 @@ def to_zarr(
append_dim=append_dim,
write_region=region,
safe_chunks=safe_chunks,
stacklevel=4, # for Dataset.to_zarr()
zarr_version=zarr_version,
zarr_format=zarr_format,
write_empty=write_empty_chunks,
**kwargs,
)

if region is not None:
zstore._validate_and_autodetect_region(dataset)
# can't modify indexes with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
f"``region`` with to_zarr(), got {append_dim} in both"
)

if encoding and mode in ["a", "a-", "r+"]:
existing_var_names = set(zstore.zarr_group.array_keys())
for var_name in existing_var_names:
if var_name in encoding:
raise ValueError(
f"variable {var_name!r} already exists, but encoding was provided"
)
dataset = zstore._validate_and_autodetect_region(dataset)
zstore._validate_encoding(encoding)

writer = ArrayWriter()
# TODO: figure out how to properly handle unlimited_dims
Expand Down
101 changes: 84 additions & 17 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import json
import os
import struct
import warnings
from collections.abc import Iterable
from collections.abc import Hashable, Iterable, Mapping
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
Expand Down Expand Up @@ -46,6 +45,66 @@
from xarray.core.datatree import DataTree


def _get_mappers(*, storage_options, store, chunk_store):
# expand str and path-like arguments
store = _normalize_path(store)
chunk_store = _normalize_path(chunk_store)

kwargs = {}
if storage_options is None:
mapper = store
chunk_mapper = chunk_store
else:
if not isinstance(store, str):
raise ValueError(
f"store must be a string to use storage_options. Got {type(store)}"
)

if _zarr_v3():
kwargs["storage_options"] = storage_options
mapper = store
chunk_mapper = chunk_store
else:
from fsspec import get_mapper

mapper = get_mapper(store, **storage_options)
if chunk_store is not None:
chunk_mapper = get_mapper(chunk_store, **storage_options)
else:
chunk_mapper = chunk_store
return kwargs, mapper, chunk_mapper


def _choose_default_mode(
*,
mode: ZarrWriteModes | None,
append_dim: Hashable | None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None,
) -> ZarrWriteModes:
if mode is None:
if append_dim is not None:
mode = "a"
elif region is not None:
mode = "r+"
else:
mode = "w-"

if mode not in ["a", "a-"] and append_dim is not None:
raise ValueError("cannot set append_dim unless mode='a' or mode=None")

if mode not in ["a", "a-", "r+"] and region is not None:
raise ValueError(
"cannot set region unless mode='a', mode='a-', mode='r+' or mode=None"
)

if mode not in ["w", "w-", "a", "a-", "r+"]:
raise ValueError(
"The only supported options for mode are 'w', "
f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}"
)
return mode


def _zarr_v3() -> bool:
# TODO: switch to "3" once Zarr V3 is released
return module_available("zarr", minversion="2.99")
Expand Down Expand Up @@ -555,7 +614,6 @@ def open_store(
append_dim=None,
write_region=None,
safe_chunks=True,
stacklevel=2,
zarr_version=None,
zarr_format=None,
use_zarr_fill_value_as_mask=None,
Expand All @@ -575,7 +633,6 @@ def open_store(
consolidate_on_close=consolidate_on_close,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel,
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
zarr_format=zarr_format,
Expand Down Expand Up @@ -610,7 +667,6 @@ def open_group(
append_dim=None,
write_region=None,
safe_chunks=True,
stacklevel=2,
zarr_version=None,
zarr_format=None,
use_zarr_fill_value_as_mask=None,
Expand All @@ -630,7 +686,6 @@ def open_group(
consolidate_on_close=consolidate_on_close,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel,
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
zarr_format=zarr_format,
Expand Down Expand Up @@ -1092,7 +1147,10 @@ def _auto_detect_regions(self, ds, region):
region[dim] = slice(idxs[0], idxs[-1] + 1)
return region

def _validate_and_autodetect_region(self, ds) -> None:
def _validate_and_autodetect_region(self, ds: Dataset) -> Dataset:
if self._write_region is None:
return ds

region = self._write_region

if region == "auto":
Expand Down Expand Up @@ -1140,8 +1198,26 @@ def _validate_and_autodetect_region(self, ds) -> None:
f".drop_vars({non_matching_vars!r})"
)

if self._append_dim is not None and self._append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
f"``region`` with to_zarr(), got {self._append_dim} in both"
)

self._write_region = region

# can't modify indexes with region writes
return ds.drop_vars(ds.indexes)

def _validate_encoding(self, encoding) -> None:
if encoding and self._mode in ["a", "a-", "r+"]:
existing_var_names = set(self.zarr_group.array_keys())
for var_name in existing_var_names:
if var_name in encoding:
raise ValueError(
f"variable {var_name!r} already exists, but encoding was provided"
)


def open_zarr(
store,
Expand Down Expand Up @@ -1316,7 +1392,6 @@ def open_zarr(
"overwrite_encoded_chunks": overwrite_encoded_chunks,
"chunk_store": chunk_store,
"storage_options": storage_options,
"stacklevel": 4,
"zarr_version": zarr_version,
"zarr_format": zarr_format,
}
Expand Down Expand Up @@ -1385,7 +1460,6 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
consolidated=None,
chunk_store=None,
storage_options=None,
stacklevel=3,
zarr_version=None,
zarr_format=None,
store=None,
Expand All @@ -1403,7 +1477,6 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=None,
zarr_format=zarr_format,
Expand Down Expand Up @@ -1440,7 +1513,6 @@ def open_datatree(
consolidated=None,
chunk_store=None,
storage_options=None,
stacklevel=3,
zarr_version=None,
zarr_format=None,
**kwargs,
Expand All @@ -1461,7 +1533,6 @@ def open_datatree(
consolidated=consolidated,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel,
zarr_version=zarr_version,
zarr_format=zarr_format,
**kwargs,
Expand All @@ -1485,7 +1556,6 @@ def open_groups_as_dict(
consolidated=None,
chunk_store=None,
storage_options=None,
stacklevel=3,
zarr_version=None,
zarr_format=None,
**kwargs,
Expand All @@ -1509,7 +1579,6 @@ def open_groups_as_dict(
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
zarr_format=zarr_format,
)
Expand Down Expand Up @@ -1554,7 +1623,6 @@ def _get_open_params(
consolidate_on_close,
chunk_store,
storage_options,
stacklevel,
zarr_version,
use_zarr_fill_value_as_mask,
zarr_format,
Expand Down Expand Up @@ -1599,7 +1667,7 @@ def _get_open_params(
# ValueError in zarr-python 3.x, KeyError in 2.x.
try:
zarr_group = zarr.open_group(store, **open_kwargs)
warnings.warn(
emit_user_level_warning(
"Failed to open Zarr store with consolidated metadata, "
"but successfully read with non-consolidated metadata. "
"This is typically much slower for opening a dataset. "
Expand All @@ -1612,7 +1680,6 @@ def _get_open_params(
"error in this case instead of falling back to try "
"reading non-consolidated metadata.",
RuntimeWarning,
stacklevel=stacklevel,
)
except missing_exc as err:
raise FileNotFoundError(
Expand Down
Loading