
Original variable encodings are retained #471

Merged
28 changes: 20 additions & 8 deletions pangeo_forge_recipes/aggregation.py

@@ -23,11 +23,17 @@ class XarraySchema(TypedDict):


 def dataset_to_schema(ds: xr.Dataset) -> XarraySchema:
-    """Convert the output of `dataset.to_dict(data=False)` to a schema
+    """Convert the output of `dataset.to_dict(data=False, encoding=True)` to a schema
     (Basically just adds chunks, which is not part of the Xarray output).
     """

-    d = ds.to_dict(data=False)
+    # Remove redundant encoding options
+    for v in ds.variables:
+        for option in ["_FillValue", "source"]:
Contributor:
Could you explain the rationale for special casing these two options?

Contributor Author:
I excluded these as they were causing certain test failures where expected schemas were compared with actual ones, e.g. in any combiner tests using has_correct_schema():

https://github.com/pangeo-forge/pangeo-forge-recipes/blob/beam-refactor/tests/test_combiners.py#L98-L102

def has_correct_schema(expected_schema):
    def _check_results(actual):
        assert len(actual) == 1
        schema = actual[0]
        assert schema == expected_schema

The source will be unique to each original source data product, and the _FillValue appeared to be added automatically (I can't recall the specific issue with the latter though).

Contributor Author:
I checked the latter again, and when _FillValue is retained, it's being set to nan in expected_schema (as generated by the original ds.to_dict(data=False, encoding=True)), but only for the lat and lon coords. However, the actual schema doesn't contain _FillValue for lat/lon, and the assert fails.
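
As a minimal sketch of that failure mode (assuming float coords pick up a NaN fill value during decoding of the source file; the nan-injection line below is illustrative, not part of this PR):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"foo": (("lat", "lon"), np.zeros((2, 2)))},
    coords={"lat": [0.0, 1.0], "lon": [10.0, 20.0]},
)
# Mimic what decoding a source file can do: a float coord picks up a fill value
ds.lat.encoding["_FillValue"] = np.nan  # illustrative only

expected = ds.to_dict(data=False, encoding=True)
print(expected["coords"]["lat"]["encoding"])  # {'_FillValue': nan}
# A schema built from a dataset without that entry won't compare equal,
# hence stripping _FillValue before the comparison.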

+            # TODO: should be okay to remove _FillValue?
+            if option in ds[v].encoding:
+                del ds[v].encoding[option]
+    d = ds.to_dict(data=False, encoding=True)
Contributor:

I think that when I first started working on this, this option didn't even exist yet! See pydata/xarray#6634

Nice when things come together. 😄
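
For reference, a small sketch of what that option surfaces (the dataset and values here are made up for illustration):

import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2010-01-01", periods=3)
ds = xr.Dataset({"foo": ("time", np.arange(3.0))}, coords={"time": time})
ds.time.encoding = {"units": "days since 2010-01-01", "calendar": "proleptic_gregorian"}

d = ds.to_dict(data=False, encoding=True)
# Each variable entry now carries its encoding alongside dims/attrs/dtype/shape
print(d["coords"]["time"]["encoding"])
# {'units': 'days since 2010-01-01', 'calendar': 'proleptic_gregorian'}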

     return XarraySchema(
         attrs=d.get("attrs"),
         coords=d.get("coords"),
@@ -164,6 +170,8 @@ def _combine_vars(v1, v2, concat_dim, allow_both=False):
             raise DatasetCombineError(f"Can't merge datasets with the same variable {vname}")
         attrs = _combine_attrs(v1[vname]["attrs"], v2[vname]["attrs"])
         dtype = _combine_dtype(v1[vname]["dtype"], v2[vname]["dtype"])
+        # Can combine encoding using the same approach as attrs
+        encoding = _combine_attrs(v1[vname]["encoding"], v2[vname]["encoding"])
Comment on lines +173 to +174
Contributor:

Brilliant!
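
For readers skimming the diff: reusing _combine_attrs means the two encoding dicts are merged key by key, with conflicting values rejected. A sketch of that assumed contract (illustrative only, not the actual implementation):

def combine_dicts(a1, a2):
    # Assumed behavior of _combine_attrs: union of keys, but refuse to merge
    # when the same key carries different values in the two inputs.
    for key in set(a1) & set(a2):
        if a1[key] != a2[key]:
            raise ValueError(f"Conflicting values for {key!r}")
    return {**a1, **a2}

# Two fragments with the same time units merge cleanly:
merged = combine_dicts(
    {"units": "days since 2010-01-01"},
    {"calendar": "proleptic_gregorian", "units": "days since 2010-01-01"},
)
assert merged == {"units": "days since 2010-01-01", "calendar": "proleptic_gregorian"}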

         (d1, s1), (d2, s2) = (
             (v1[vname]["dims"], v1[vname]["shape"]),
             (v2[vname]["dims"], v2[vname]["shape"]),

@@ -182,7 +190,14 @@ def _combine_vars(v1, v2, concat_dim, allow_both=False):
                 )
             else:
                 shape.append(l1)
-        new_vars[vname] = {"dims": dims, "attrs": attrs, "dtype": dtype, "shape": tuple(shape)}
+        new_vars[vname] = {
+            "dims": dims,
+            "attrs": attrs,
+            "dtype": dtype,
+            "shape": tuple(shape),
+            "encoding": encoding,
+        }

     return new_vars


@@ -195,13 +210,10 @@ def _to_variable(template, target_chunks):
     chunks = tuple(target_chunks[dim] for dim in dims)
     # we pick zeros as the safest value to initialize empty data with
     # will only be used for dimension coordinates
-    # WARNING: there are lots of edge cases around time!
-    # Xarray will pick a time encoding for the dataset (e.g. "days since 1970-01-01")
-    # and this may not be compatible with the actual values in the time coordinate
-    # (which we don't know yet)
     data = dsa.zeros(shape=shape, chunks=chunks, dtype=dtype)
-    # TODO: add more encoding
-    encoding = {"chunks": chunks}
+    encoding = template.get("encoding", {})
+    encoding["chunks"] = chunks
     return xr.Variable(dims=dims, data=data, attrs=template["attrs"], encoding=encoding)
8 changes: 8 additions & 0 deletions tests/data_generation.py
@@ -34,4 +34,12 @@ def make_ds(nt=10, non_dim_coords=False):
         coords=coords,
         attrs={"conventions": "CF 1.6"},
     )
+
+    # Add time coord encoding
+    # Omit "%H:%M:%S" as it will be dropped when the time is 0:0:0
+    ds.time.encoding = {
+        "units": f"days since {time[0].strftime('%Y-%m-%d')}",
+        "calendar": "proleptic_gregorian",
+    }
+
     return ds
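
A hedged usage sketch of what this fixture change enables (the import path and store name are illustrative): the explicit time units survive a Zarr round-trip, so tests can compare them directly.

import xarray as xr
from data_generation import make_ds  # import path is illustrative

ds = make_ds(nt=3)
ds.to_zarr("example.zarr", mode="w")  # illustrative store path
rt = xr.open_dataset("example.zarr", engine="zarr")
assert rt.time.encoding["units"] == ds.time.encoding["units"]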
2 changes: 2 additions & 0 deletions tests/test_aggregation.py
@@ -31,6 +31,8 @@ def test_schema_to_template_ds(specified_chunks):
             chunksize = var.chunksizes[dim]
             expected_chunksize = _expected_chunks(size, specified_chunks.get(dim, None))
             assert chunksize == expected_chunksize
+    # Confirm original time units have been preserved
+    assert ds.time.encoding.get("units") == dst.time.encoding.get("units")
     schema2 = dataset_to_schema(dst)
     assert schema == schema2

5 changes: 4 additions & 1 deletion tests/test_writers.py
@@ -17,7 +17,7 @@ def temp_store(tmp_path):
 def test_store_dataset_fragment(temp_store):

     ds = make_ds(non_dim_coords=True)
-    schema = ds.to_dict(data=False)
+    schema = ds.to_dict(data=False, encoding=True)
     schema["chunks"] = {}

     ds.to_zarr(temp_store)
@@ -138,3 +138,6 @@ def test_store_dataset_fragment(temp_store):
     ds_target = xr.open_dataset(temp_store, engine="zarr").load()

     xr.testing.assert_identical(ds, ds_target)
+    # assert_identical() doesn't check encoding
+    # Checking the original time encoding units should be sufficient
+    assert ds.time.encoding.get("units") == ds_target.time.encoding.get("units")
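
To see why the extra assert is needed, a minimal demonstration (the dataset and variable names are made up):

import numpy as np
import xarray as xr

a = xr.Dataset({"x": ("t", np.arange(3))})
b = a.copy(deep=True)
b.x.encoding["dtype"] = "int16"  # encodings now differ

xr.testing.assert_identical(a, b)  # still passes: dims, values, attrs all match
assert a.x.encoding != b.x.encoding  # the difference assert_identical misses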