Skip to content

Commit

Permalink
Add the Serializable abstract base class
Browse files Browse the repository at this point in the history
This provides a generic interface for handling serialization in Python.
Handles serialization with Dask through the `"cuda"` and `"dask"`
serializers. Also implements pickle serialization using `__reduce_ex__`.
For Python versions with support for pickle's protocol 5, the class also
supports out-of-band buffers for more efficient serialization.
Subclasses are responsible for implementing `serialize` and
`deserialize` to/from a Python `dict` with the `header` and a collection
of `frames`. The abstract base class handles all other serialization
using these two methods.
  • Loading branch information
jakirkham committed May 8, 2020
1 parent 31c65ec commit df37353
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions python/cudf/cudf/core/abc.py
Original file line number Diff line number Diff line change
@@ -1 +1,56 @@
# Copyright (c) 2020, NVIDIA CORPORATION.

import abc
import pickle
from abc import abstractmethod

import rmm


class Serializable(abc.ABC):
@abstractmethod
def serialize(self):
pass

@classmethod
@abstractmethod
def deserialize(cls, header, frames):
pass

def device_serialize(self):
header, frames = self.serialize()
assert all((type(f) is cudf.core.buffer.Buffer) for f in frames)
header["type-serialized"] = pickle.dumps(type(self))
header["lengths"] = [f.nbytes for f in frames]
return header, frames

@classmethod
def device_deserialize(cls, header, frames):
for f in frames:
# some frames are empty -- meta/empty partitions/etc
if len(f) > 0:
assert hasattr(f, "__cuda_array_interface__")

typ = pickle.loads(header["type-serialized"])
obj = typ.deserialize(header, frames)

return obj

def host_serialize(self):
header, frames = self.device_serialize()
frames = [f.to_host_array().data for f in frames]
return header, frames

@classmethod
def host_deserialize(cls, header, frames):
frames = [
rmm.DeviceBuffer.to_device(memoryview(f).cast("B")) for f in frames
]
obj = cls.device_deserialize(header, frames)
return obj

def __reduce_ex__(self, protocol):
header, frames = self.host_serialize()
if protocol >= 5:
frames = [pickle.PickleBuffer(f) for f in frames]
return self.host_deserialize, (header, frames)

0 comments on commit df37353

Please sign in to comment.