Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support chunks in open_groups and open_datatree #9660

Merged
merged 23 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
fe95b16
support chunking and default values in `open_groups`
keewis Oct 22, 2024
3bfbc3a
same for `open_datatree`
keewis Oct 22, 2024
f4abb01
use `group_subtrees` instead of `map_over_datasets`
keewis Oct 22, 2024
b0458aa
check that `chunks` on `open_datatree` works
keewis Oct 22, 2024
4dbd91e
specify the chunksizes when opening from disk
keewis Oct 23, 2024
11850fd
check that `open_groups` with chunks works, too
keewis Oct 23, 2024
a71f5e2
require dask for `test_open_groups_chunks`
TomNicholas Oct 23, 2024
6d3deed
protect variables from write operations
keewis Oct 23, 2024
7f770cf
copy over `_close` from the backend tree
keewis Oct 23, 2024
05efaf6
copy a lot of the docstring from `open_dataset`
keewis Oct 23, 2024
f9fee40
same for `open_groups`
keewis Oct 23, 2024
2e10bdc
Merge branch 'main' into open_datatree-dask
keewis Oct 23, 2024
a4e99c6
reuse `_protect_dataset_variables_inplace`
keewis Oct 23, 2024
3e8b80c
final missing `requires_dask`
keewis Oct 23, 2024
cf1a6b0
typing for the test utils
keewis Oct 24, 2024
114c4dc
type hints for `_protect_datatree_variables_inplace`
keewis Oct 24, 2024
9eac19d
type hints for `_protect_dataset_variables_inplace`
keewis Oct 24, 2024
446a53d
copy over the name of the backend tree
keewis Oct 24, 2024
5b36701
typo
keewis Oct 24, 2024
66616f7
swap the order of arguments to `assert_identical`
keewis Oct 24, 2024
843b2fc
try explicitly typing `data`
keewis Oct 24, 2024
8950841
typo
keewis Oct 24, 2024
4d93ada
use `Hashable` for variable names
keewis Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 173 additions & 3 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
)
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.datatree import DataTree
from xarray.core.indexes import Index
from xarray.core.treenode import group_subtrees
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
from xarray.core.utils import is_remote_uri
from xarray.namedarray.daskmanager import DaskManager
Expand Down Expand Up @@ -74,7 +76,6 @@
T_NetcdfTypes = Literal[
"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"
]
from xarray.core.datatree import DataTree

DATAARRAY_NAME = "__xarray_dataarray_name__"
DATAARRAY_VARIABLE = "__xarray_dataarray_variable__"
Expand Down Expand Up @@ -414,6 +415,56 @@ def _dataset_from_backend_dataset(
return ds


def _datatree_from_backend_datatree(
backend_tree,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
):
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
raise ValueError(
f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
)

# _protect_datatree_variables_inplace(backend_tree, cache)
keewis marked this conversation as resolved.
Show resolved Hide resolved
if chunks is None:
tree = backend_tree
else:
tree = DataTree.from_dict(
{
path: _chunk_ds(
node.dataset,
filename_or_obj,
engine,
chunks,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
)
for path, [node] in group_subtrees(backend_tree)
}
)
keewis marked this conversation as resolved.
Show resolved Hide resolved

# ds.set_close(backend_ds._close)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backend_tree should have been created using datatree_from_dict_with_io_cleanup, so one way to handle this could be just to copy over the _close attribute from every node of backend_tree?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The question is: do we even need that here? I copied this from `open_dataset`, where it is set explicitly, but since `datatree_from_dict_with_io_cleanup` already does this, we might be able to just remove it.

The only reason why I kept the commented-out line is to discuss whether the shift in paradigm (have the backend set _close vs. do it for all backends the same way) is intentional, and if we should do the same for open_dataset.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it would be nice to remove this; I'm just worried that mapping over each `.dataset` might not properly propagate `._close` (does it? should it?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not (I think), so I'm explicitly copying it over. So far that doesn't appear to cause anything to break.


# Ensure the source filename is always stored in the tree's encoding
if "source" not in tree.encoding:
path = getattr(filename_or_obj, "path", filename_or_obj)

if isinstance(path, str | os.PathLike):
tree.encoding["source"] = _normalize_path(path)
keewis marked this conversation as resolved.
Show resolved Hide resolved

return tree


def open_dataset(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
Expand Down Expand Up @@ -838,7 +889,22 @@ def open_dataarray(

def open_datatree(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
cache: bool | None = None,
decode_cf: bool | None = None,
mask_and_scale: bool | Mapping[str, bool] | None = None,
decode_times: bool | Mapping[str, bool] | None = None,
decode_timedelta: bool | Mapping[str, bool] | None = None,
use_cftime: bool | Mapping[str, bool] | None = None,
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> DataTree:
"""
Expand All @@ -856,17 +922,75 @@ def open_datatree(
-------
xarray.DataTree
"""
if cache is None:
cache = chunks is None

if backend_kwargs is not None:
kwargs.update(backend_kwargs)

if engine is None:
engine = plugins.guess_engine(filename_or_obj)

if from_array_kwargs is None:
from_array_kwargs = {}

backend = plugins.get_backend(engine)

return backend.open_datatree(filename_or_obj, **kwargs)
decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=(),
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)
overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)

backend_tree = backend.open_datatree(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)

tree = _datatree_from_backend_datatree(
backend_tree,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
**decoders,
**kwargs,
)

return tree


def open_groups(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
cache: bool | None = None,
decode_cf: bool | None = None,
mask_and_scale: bool | Mapping[str, bool] | None = None,
decode_times: bool | Mapping[str, bool] | None = None,
decode_timedelta: bool | Mapping[str, bool] | None = None,
use_cftime: bool | Mapping[str, bool] | None = None,
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> dict[str, Dataset]:
"""
Expand All @@ -893,12 +1017,58 @@ def open_groups(
open_datatree()
DataTree.from_dict()
"""
if cache is None:
cache = chunks is None

if backend_kwargs is not None:
kwargs.update(backend_kwargs)

if engine is None:
engine = plugins.guess_engine(filename_or_obj)

if from_array_kwargs is None:
from_array_kwargs = {}

backend = plugins.get_backend(engine)

return backend.open_groups_as_dict(filename_or_obj, **kwargs)
decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=(),
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)
overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)

backend_groups = backend.open_groups_as_dict(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)

groups = {
name: _dataset_from_backend_dataset(
backend_ds,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
**decoders,
**kwargs,
)
for name, backend_ds in backend_groups.items()
}

return groups


def open_mfdataset(
Expand Down
Loading
Loading