Skip to content

Commit

Permalink
Remove VLenUTF8 from filters to avoid double encoding error (pydata/xarray#3476)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed May 5, 2022
1 parent 6cf0ccb commit 8bb8ee6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
10 changes: 10 additions & 0 deletions sgkit/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, MutableMapping, Optional, Union

import fsspec
import numcodecs
import xarray as xr
from xarray import Dataset

Expand Down Expand Up @@ -38,6 +39,15 @@ def save_dataset(
for v in ds:
# Workaround for https://github.com/pydata/xarray/issues/4380
ds[v].encoding.pop("chunks", None)

# Remove VLenUTF8 from filters to avoid a double-encoding error; see https://github.com/pydata/xarray/issues/3476
filters = ds[v].encoding.get("filters", None)
var_len_str_codec = numcodecs.VLenUTF8()
if filters is not None and var_len_str_codec in filters:
filters = list(filters)
filters.remove(var_len_str_codec)
ds[v].encoding["filters"] = filters

ds.to_zarr(store, **kwargs)


Expand Down
11 changes: 10 additions & 1 deletion sgkit/tests/io/vcf/test_vcf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from numcodecs import Blosc, PackBits, VLenUTF8
from numpy.testing import assert_allclose, assert_array_equal

from sgkit import load_dataset
from sgkit import load_dataset, save_dataset
from sgkit.io.utils import FLOAT32_FILL, INT_FILL, INT_MISSING
from sgkit.io.vcf import (
MaxAltAllelesExceededWarning,
partition_into_regions,
vcf_to_zarr,
)
from sgkit.io.vcf.vcf_reader import zarr_array_sizes
from sgkit.tests.io.test_dataset import assert_identical

from .utils import path_for_test

Expand Down Expand Up @@ -95,6 +96,14 @@ def test_vcf_to_zarr__small_vcf(shared_datadir, is_path, tmp_path):
assert_array_equal(ds["call_genotype_mask"], call_genotype < 0)
assert_array_equal(ds["call_genotype_phased"], call_genotype_phased)

# save and load again to test https://github.com/pydata/xarray/issues/3476
with pytest.warns(xr.coding.variables.SerializationWarning):
path2 = tmp_path / "ds2.zarr"
if not is_path:
path2 = str(path2)
save_dataset(ds, path2)
assert_identical(ds, load_dataset(path2))


@pytest.mark.parametrize(
"is_path",
Expand Down

0 comments on commit 8bb8ee6

Please sign in to comment.