Skip to content

Commit

Permalink
Require opt-in to pickling writable Sessions and Stores.
Browse files Browse the repository at this point in the history
This is uglier than it used to be because we cannot deserialize to a
read-only Session or Store by default anymore.
Well, we could, but it would take some extra work.

Closes #478
xref #185 (comment)
  • Loading branch information
dcherian committed Jan 31, 2025
1 parent 0ea64bf commit 0d96c8b
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 24 deletions.
31 changes: 28 additions & 3 deletions icechunk-python/python/icechunk/session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import contextlib
from collections.abc import Generator
from typing import Self

from icechunk import (
Expand Down Expand Up @@ -86,26 +88,49 @@ class Session:
"""A session object that allows for reading and writing data from an Icechunk repository."""

_session: PySession
_allow_pickling: bool

def __init__(self, session: PySession):
def __init__(self, session: PySession, _allow_pickling: bool = False):
self._session = session
self._allow_distributed_write = False
self._allow_pickling = _allow_pickling

def __eq__(self, value: object) -> bool:
    """Compare two sessions by their underlying ``_session`` objects.

    Anything that is not a ``Session`` compares unequal.
    """
    return isinstance(value, Session) and self._session == value._session

def __getstate__(self) -> object:
    """Return the picklable state of this session.

    The wrapped session is serialized with ``self._session.as_bytes()`` and
    stored together with the current ``_allow_pickling`` flag.

    Raises
    ------
    ValueError
        If the session is writable (``read_only`` is false) and pickling
        has not been opted into via ``_allow_pickling``.
    """
    # Read-only sessions may always be pickled; writable ones need an opt-in.
    if not self._allow_pickling and not self.read_only:
        raise ValueError(
            "You must opt-in to pickle writable sessions in a distributed context "
            "using the `Session.allow_pickling` context manager. "
            # TODO: link to docs
            "If you are using xarray's `Dataset.to_zarr` method, please use "
            "`icechunk.xarray.to_icechunk` instead."
        )
    state = {
        "_session": self._session.as_bytes(),
        "_allow_pickling": self._allow_pickling,
    }
    return state

def __setstate__(self, state: object) -> None:
    """Restore a session from the mapping produced by ``__getstate__``.

    The ``"_session"`` entry is passed to ``PySession.from_bytes`` and the
    ``"_allow_pickling"`` entry is copied back onto the instance.

    Raises
    ------
    ValueError
        If ``state`` is not a ``dict``.
    """
    if isinstance(state, dict):
        self._session = PySession.from_bytes(state["_session"])
        self._allow_pickling = state["_allow_pickling"]
        return
    raise ValueError("Invalid state")

@contextlib.contextmanager
def allow_pickling(self) -> Generator[None, None, None]:
    """
    Context manager that opts in to pickling this session while it is writable.

    Inside the ``with`` block ``_allow_pickling`` is ``True``. On exit
    (including when the block raises) the flag is restored to the value it
    had on entry, so nested uses and sessions that were already opted in
    are left in a consistent state.
    """
    # Remember the prior value so leaving an inner context does not switch
    # pickling off for an enclosing one.
    previous = self._allow_pickling
    self._allow_pickling = True
    try:
        yield
    finally:
        self._allow_pickling = previous

@property
def read_only(self) -> bool:
Expand Down Expand Up @@ -171,7 +196,7 @@ def store(self) -> IcechunkStore:
IcechunkStore
A zarr Store object for reading and writing data from the repository.
"""
return IcechunkStore(self._session.store)
return IcechunkStore(self._session.store, self._allow_pickling)

def all_virtual_chunk_locations(self) -> list[str]:
"""
Expand Down
9 changes: 8 additions & 1 deletion icechunk-python/python/icechunk/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ def _byte_request_to_tuple(

class IcechunkStore(Store, SyncMixin):
_store: PyStore
_allow_pickling: bool

def __init__(
self,
store: PyStore,
allow_pickling: bool,
*args: Any,
**kwargs: Any,
):
Expand All @@ -52,6 +54,7 @@ def __init__(
)
self._store = store
self._is_open = True
self._allow_pickling = allow_pickling

def __eq__(self, value: object) -> bool:
if not isinstance(value, IcechunkStore):
Expand All @@ -60,6 +63,10 @@ def __eq__(self, value: object) -> bool:

def __getstate__(self) -> object:
    """Return the picklable state of this store.

    The instance ``__dict__`` is copied and its ``"_store"`` entry is
    replaced by the bytes from ``self._store.as_bytes()``.

    Raises
    ------
    ValueError
        If the underlying store is writable and ``_allow_pickling`` is false.
    """
    # Pickling is permitted when opted in, or when the store is read-only.
    if not (self._allow_pickling or self._store.read_only):
        raise ValueError(
            "You must opt in to pickling this *writable* store by using `Session.allow_pickling` context manager"
        )
    # we serialize the Rust store as bytes
    return {**self.__dict__, "_store": self._store.as_bytes()}
Expand All @@ -75,7 +82,7 @@ def __setstate__(self, state: Any) -> None:
def session(self) -> "Session":
from icechunk import Session

return Session(self._store.session)
return Session(self._store.session, self._allow_pickling)

async def clear(self) -> None:
"""Clear the store.
Expand Down
13 changes: 7 additions & 6 deletions icechunk-python/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
def test_distributed() -> None:
with distributed.Client(): # type: ignore [no-untyped-call]
ds = create_test_data().chunk(dim1=3, dim2=4)
with roundtrip(ds) as actual:
with roundtrip(ds, allow_pickling=True) as actual:
assert_identical(actual, ds)

# with pytest.raises(ValueError, match="Session cannot be serialized"):
# with roundtrip(ds, allow_distributed_write=False) as actual:
# pass
# FIXME: this should be nicer! this TypeError is from distributed
with pytest.raises(TypeError):
with roundtrip(ds, allow_pickling=False) as actual:
pass


def test_threaded() -> None:
with dask.config.set(scheduler="threads"):
ds = create_test_data().chunk(dim1=3, dim2=4)
with roundtrip(ds) as actual:
assert_identical(actual, ds)
# with roundtrip(ds, allow_distributed_write=False) as actual:
# assert_identical(actual, ds)
with roundtrip(ds, allow_pickling=False) as actual:
assert_identical(actual, ds)
14 changes: 9 additions & 5 deletions icechunk-python/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ def test_pickle_read_only(tmp_repo: Repository) -> None:
tmp_store = tmp_session.store

assert tmp_store._read_only is False
assert tmp_session.read_only is False

roundtripped = pickle.loads(pickle.dumps(tmp_store))
assert roundtripped._read_only is False
with pytest.raises(ValueError, match="You must opt in"):
roundtripped = pickle.loads(pickle.dumps(tmp_store))

# with tmp_store.preserve_read_only():
# roundtripped = pickle.loads(pickle.dumps(tmp_store))
# assert roundtripped._read_only is False
with tmp_session.allow_pickling():
roundtripped = pickle.loads(pickle.dumps(tmp_session.store))
assert roundtripped._read_only is False

roundtripped = pickle.loads(pickle.dumps(tmp_session))
assert roundtripped.read_only is False

assert tmp_store._read_only is False

Expand Down
21 changes: 12 additions & 9 deletions icechunk-python/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ def create_test_data(


@contextlib.contextmanager
def roundtrip(data: xr.Dataset) -> Generator[xr.Dataset, None, None]:
def roundtrip(
data: xr.Dataset, allow_pickling: bool = False
) -> Generator[xr.Dataset, None, None]:
with tempfile.TemporaryDirectory() as tmpdir:
repo = Repository.create(local_filesystem_storage(tmpdir))
session = repo.writable_session("main")
to_icechunk(data, store=session.store, mode="w")

# if allow_distributed_write:
# with session.allow_distributed_write():
# to_icechunk(data, store=session.store, mode="w")
# else:
# to_icechunk(data, store=session.store, mode="w")
if allow_pickling:
with session.allow_pickling():
to_icechunk(data, store=session.store, mode="w")
with xr.open_zarr(session.store, consolidated=False) as ds:
yield ds

with xr.open_zarr(session.store, consolidated=False) as ds:
yield ds
else:
to_icechunk(data, store=session.store, mode="w")
with xr.open_zarr(session.store, consolidated=False) as ds:
yield ds


def test_xarray_to_icechunk() -> None:
Expand Down

0 comments on commit 0d96c8b

Please sign in to comment.