diff --git a/conftest.py b/conftest.py index b558abfd..ff09d4ac 100644 --- a/conftest.py +++ b/conftest.py @@ -53,6 +53,19 @@ def netcdf4_files(tmpdir): return filepath1, filepath2 +@pytest.fixture +def netcdf4_file_with_2d_coords(tmpdir): + # Set up example xarray dataset + ds = xr.tutorial.open_dataset("ROMS_example.nc") + + # Save it to disk as netCDF (in temporary directory) + filepath = f"{tmpdir}/ROMS_example.nc" + ds.to_netcdf(filepath, format="NETCDF4") + ds.close() + + return filepath + + @pytest.fixture def hdf5_empty(tmpdir): filepath = f"{tmpdir}/empty.nc" diff --git a/virtualizarr/backend.py b/virtualizarr/backend.py index 87c2aa2a..9a41a6a2 100644 --- a/virtualizarr/backend.py +++ b/virtualizarr/backend.py @@ -7,20 +7,27 @@ Any, Hashable, Optional, + TypeAlias, cast, ) import xarray as xr from xarray.backends import AbstractDataStore, BackendArray from xarray.coding.times import CFDatetimeCoder +from xarray.conventions import decode_cf_variables from xarray.core.indexes import Index, PandasIndex -from xarray.core.variable import IndexVariable +from xarray.core.variable import IndexVariable, Variable from virtualizarr.manifests import ManifestArray from virtualizarr.utils import _fsspec_openfile_from_filepath XArrayOpenT = str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore +T_Attrs = MutableMapping[Any, Any] +T_Variables = Mapping[Any, Variable] +# alias for (dims, data, attrs, encoding) +T_VariableExpanded: TypeAlias = tuple[Hashable, Any, dict[Any, Any], dict[Any, Any]] + class AutoName(Enum): # Recommended by official Python docs for auto naming: @@ -238,43 +245,56 @@ def open_virtual_dataset( vars = {**virtual_vars, **loadable_vars} - data_vars, coords = separate_coords(vars, indexes, coord_names) + decoded_vars, decoded_attrs, coord_names = determine_cf_coords(vars, ds_attrs) - vds = xr.Dataset( - data_vars, - coords=coords, - # indexes={}, # TODO should be added in a later version of xarray - attrs=ds_attrs, + vds = construct_virtual_dataset( + decoded_vars, indexes, decoded_attrs, coord_names ) - # TODO we should probably also use vds.set_close() to tell xarray how to close the file we opened - return vds -def separate_coords( +def determine_cf_coords( + variables: T_Variables, + attributes: T_Attrs, +) -> tuple[T_Variables, T_Attrs, set[Hashable]]: + """ + Determines which variables are coordinate variables according to CF conventions. + + Should not actually do any decoding of values in the variables, only inspect and possibly alter their metadata. + """ + new_vars, attrs, coord_names = decode_cf_variables( + variables=variables, + attributes=attributes, + concat_characters=False, + mask_and_scale=False, + decode_times=False, + decode_coords="all", + drop_variables=None, # should have already been dropped + use_cftime=False, # done separately, to only the loadable_vars + decode_timedelta=False, # done separately, to only the loadable_vars + ) + return new_vars, attrs, coord_names + + +def construct_virtual_dataset( vars: Mapping[str, xr.Variable], indexes: MutableMapping[str, Index], + attrs: T_Attrs, coord_names: Iterable[str] | None = None, -) -> tuple[dict[str, xr.Variable], xr.Coordinates]: +) -> xr.Dataset: """ - Try to generate a set of coordinates that won't cause xarray to automatically build a pandas.Index for the 1D coordinates. + Constructs the virtual dataset but without automatically building a pandas.Index for 1D coordinates. Currently requires this function as a workaround unless xarray PR #8124 is merged. Will also preserve any loaded variables and indexes it is passed. """ - if coord_names is None: - coord_names = [] - - # split data and coordinate variables (promote dimension coordinates) + coord_vars: dict[str, T_VariableExpanded | xr.Variable] = {} data_vars = {} - coord_vars: dict[ - str, tuple[Hashable, Any, dict[Any, Any], dict[Any, Any]] | xr.Variable - ] = {} for name, var in vars.items(): - if name in coord_names or var.dims == (name,): + if name in coord_names: # use workaround to avoid creating IndexVariables described here https://github.com/pydata/xarray/pull/8107#discussion_r1311214263 if len(var.dims) == 1: dim1d, *_ = var.dims @@ -293,4 +313,24 @@ def separate_coords( coords = xr.Coordinates(coord_vars, indexes=indexes) - return data_vars, coords + print(indexes) + + print(coords) + print(type(coords)) + + print(data_vars) + + print(list(type(var._data) for var in data_vars.values())) + print(list(type(var.data) for var in data_vars.values())) + + vds = xr.Dataset( + data_vars, + coords=coords, + # indexes={}, # TODO should be added in a later version of xarray + attrs=attrs, + ) + + # TODO we should probably also use vds.set_close() to tell xarray how to close the file we opened + # TODO see how it's done inside `xr.decode_cf` + + return vds diff --git a/virtualizarr/readers/kerchunk.py b/virtualizarr/readers/kerchunk.py index 4686ce94..6b8c20ab 100644 --- a/virtualizarr/readers/kerchunk.py +++ b/virtualizarr/readers/kerchunk.py @@ -7,7 +7,11 @@ from xarray.core.indexes import Index from xarray.core.variable import Variable -from virtualizarr.backend import FileType, separate_coords +from virtualizarr.backend import ( + FileType, + construct_virtual_dataset, + determine_cf_coords, +) from virtualizarr.manifests import ChunkManifest, ManifestArray from virtualizarr.types.kerchunk import ( KerchunkArrRefs, @@ -176,14 +180,10 @@ def dataset_from_kerchunk_refs( if indexes is None: indexes = {} - data_vars, coords = separate_coords(vars, indexes, coord_names) - vds = Dataset( - data_vars, - coords=coords, - # indexes={}, # TODO should be added in a later version of xarray - attrs=ds_attrs, - ) + decoded_vars, decoded_attrs, coord_names = determine_cf_coords(vars, ds_attrs) + + vds = construct_virtual_dataset(decoded_vars, indexes, decoded_attrs, coord_names) return vds diff --git a/virtualizarr/readers/zarr.py b/virtualizarr/readers/zarr.py index b841d5c3..dc5cc854 100644 --- a/virtualizarr/readers/zarr.py +++ b/virtualizarr/readers/zarr.py @@ -8,7 +8,7 @@ from xarray.core.indexes import Index from xarray.core.variable import Variable -from virtualizarr.backend import separate_coords +from virtualizarr.backend import construct_virtual_dataset, determine_cf_coords from virtualizarr.manifests import ChunkManifest, ManifestArray from virtualizarr.zarr import ZArray @@ -53,16 +53,11 @@ def open_virtual_dataset_from_v3_store( else: indexes = dict(**indexes) # for type hinting: to allow mutation - data_vars, coords = separate_coords(vars, indexes, coord_names) + decoded_vars, decoded_attrs, coord_names = determine_cf_coords(vars, attrs) - ds = Dataset( - data_vars, - coords=coords, - # indexes={}, # TODO should be added in a later version of xarray - attrs=ds_attrs, - ) + vds = construct_virtual_dataset(decoded_vars, indexes, decoded_attrs, coord_names) - return ds + return vds def attrs_from_zarr_group_json(filepath: Path) -> dict: diff --git a/virtualizarr/tests/test_backend.py b/virtualizarr/tests/test_backend.py index 3b0c0315..9f4c489c 100644 --- a/virtualizarr/tests/test_backend.py +++ b/virtualizarr/tests/test_backend.py @@ -121,6 +121,33 @@ def test_coordinate_variable_attrs_preserved(self, netcdf4_file): } +class TestDetermineCoords: + def test_determine_all_coords(self, netcdf4_file_with_2d_coords): + vds = open_virtual_dataset(netcdf4_file_with_2d_coords, indexes={}) + + expected_dimension_coords = ["ocean_time", "s_rho"] + expected_2d_coords = ["lon_rho", "lat_rho", "h"] + expected_1d_non_dimension_coords = ["Cs_r"] + expected_scalar_coords = ["hc", "Vtransform"] + expected_coords = ( + expected_dimension_coords + + expected_2d_coords + + expected_1d_non_dimension_coords + + expected_scalar_coords + ) + assert set(vds.coords) == set(expected_coords) + + # print(vds.attrs) + # assert False + + # TODO assert coord attributes have been altered + for coord_name in expected_coords: + print(vds[coord_name].attrs) + # assert vds[coord_name].attrs[''] + + # assert False + + @network @requires_s3fs class TestReadFromS3: diff --git a/virtualizarr/tests/test_xarray.py b/virtualizarr/tests/test_xarray.py index 9db6e3a2..252f9e97 100644 --- a/virtualizarr/tests/test_xarray.py +++ b/virtualizarr/tests/test_xarray.py @@ -7,33 +7,66 @@ from virtualizarr.zarr import ZArray -def test_wrapping(): - chunks = (5, 10) - shape = (5, 20) - dtype = np.dtype("int32") - zarray = ZArray( - chunks=chunks, - compressor={"id": "zlib", "level": 1}, - dtype=dtype, - fill_value=0.0, - filters=None, - order="C", - shape=shape, - zarr_format=2, - ) - - chunks_dict = { - "0.0": {"path": "foo.nc", "offset": 100, "length": 100}, - "0.1": {"path": "foo.nc", "offset": 200, "length": 100}, - } - manifest = ChunkManifest(entries=chunks_dict) - marr = ManifestArray(zarray=zarray, chunkmanifest=manifest) - ds = xr.Dataset({"a": (["x", "y"], marr)}) - - assert isinstance(ds["a"].data, ManifestArray) - assert ds["a"].shape == shape - assert ds["a"].dtype == dtype - assert ds["a"].chunks == chunks +class TestWrapping: + def test_wrapping(self): + chunks = (5, 10) + shape = (5, 20) + dtype = np.dtype("int32") + zarray = ZArray( + chunks=chunks, + compressor={"id": "zlib", "level": 1}, + dtype=dtype, + fill_value=0.0, + filters=None, + order="C", + shape=shape, + zarr_format=2, + ) + + chunks_dict = { + "0.0": {"path": "foo.nc", "offset": 100, "length": 100}, + "0.1": {"path": "foo.nc", "offset": 200, "length": 100}, + } + manifest = ChunkManifest(entries=chunks_dict) + marr = ManifestArray(zarray=zarray, chunkmanifest=manifest) + ds = xr.Dataset({"a": (["x", "y"], marr)}) + + assert isinstance(ds["a"].data, ManifestArray) + assert ds["a"].shape == shape + assert ds["a"].dtype == dtype + assert ds["a"].chunks == chunks + + def test_wrap_no_indexes(self): + chunks = (10,) + shape = (20,) + dtype = np.dtype("int32") + zarray = ZArray( + chunks=chunks, + compressor={"id": "zlib", "level": 1}, + dtype=dtype, + fill_value=0.0, + filters=None, + order="C", + shape=shape, + zarr_format=2, + ) + + chunks_dict = { + "0.0": {"path": "foo.nc", "offset": 100, "length": 100}, + "0.1": {"path": "foo.nc", "offset": 200, "length": 100}, + } + manifest = ChunkManifest(entries=chunks_dict) + marr = ManifestArray(zarray=zarray, chunkmanifest=manifest) + + coords = xr.Coordinates({"x": (["x"], marr)}, indexes={}) + ds = xr.Dataset(coords=coords) + + assert isinstance(ds["x"].data, ManifestArray) + assert ds["x"].shape == shape + assert ds["x"].dtype == dtype + assert ds["x"].chunks == chunks + assert "x" in ds.coords + assert ds.xindexes == {} class TestEquals: