Skip to content

Commit

Permalink
Update Dask serialization to use Serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed May 8, 2020
1 parent 4b085dc commit e37b4a1
Showing 1 changed file with 10 additions and 38 deletions.
48 changes: 10 additions & 38 deletions python/cudf/cudf/comm/serialize.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,31 @@
import pickle

import cudf
import cudf.core.groupby.groupby

# all (de-)serializtion are attached to cudf Objects:
# Series/DataFrame/Index/Column/Buffer/etc
serializable_classes = (
cudf.CategoricalDtype,
cudf.DataFrame,
cudf.Index,
cudf.MultiIndex,
cudf.Series,
cudf.core.groupby.groupby.GroupBy,
cudf.core.groupby.groupby._Grouping,
cudf.core.column.column.ColumnBase,
cudf.core.buffer.Buffer,
)
import cudf # noqa: F401
from cudf.core.abc import Serializable

try:
from distributed.protocol import dask_deserialize, dask_serialize
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.utils import log_errors

@cuda_serialize.register(serializable_classes)
@cuda_serialize.register(Serializable)
def cuda_serialize_cudf_object(x):
with log_errors():
header, frames = x.serialize()
assert all((type(f) is cudf.core.buffer.Buffer) for f in frames)
header["lengths"] = [f.nbytes for f in frames]
return header, frames
return x.device_serialize()

# all (de-)serializtion are attached to cudf Objects:
# Series/DataFrame/Index/Column/Buffer/etc
@dask_serialize.register(serializable_classes)
@dask_serialize.register(Serializable)
def dask_serialize_cudf_object(x):
header, frames = cuda_serialize_cudf_object(x)
with log_errors():
frames = [f.to_host_array().data for f in frames]
return header, frames
return x.host_serialize()

@cuda_deserialize.register(serializable_classes)
@dask_deserialize.register(serializable_classes)
@cuda_deserialize.register(Serializable)
@dask_deserialize.register(Serializable)
def deserialize_cudf_object(header, frames):
with log_errors():
if header["serializer"] == "cuda":
for f in frames:
# some frames are empty -- meta/empty partitions/etc
if len(f) > 0:
assert hasattr(f, "__cuda_array_interface__")
return Serializable.device_deserialize(header, frames)
elif header["serializer"] == "dask":
frames = [memoryview(f) for f in frames]

cudf_typ = pickle.loads(header["type-serialized"])
cudf_obj = cudf_typ.deserialize(header, frames)
return cudf_obj
return Serializable.host_deserialize(header, frames)


except ImportError:
Expand Down

0 comments on commit e37b4a1

Please sign in to comment.