Commit

working version
pcmoritz committed Aug 15, 2017
1 parent bd36c83 commit 8b2ffe6
Showing 3 changed files with 169 additions and 17 deletions.
5 changes: 4 additions & 1 deletion cpp/src/arrow/python/python_to_arrow.cc
@@ -43,7 +43,10 @@ Status CallCustomSerializationCallback(PyObject* elem, PyObject** serialized_obj
// must be decremented. This is done in SerializeDict in this file.
PyObject* result = PyObject_CallObject(pyarrow_serialize_callback, arglist);
Py_XDECREF(arglist);
if (!result || !PyDict_Check(result)) {
// TODO(pcm): Propagate Python error here if !result
return Status::TypeError("serialization callback must return a valid dictionary");
}
*serialized_object = result;
}
return Status::OK();
101 changes: 95 additions & 6 deletions python/pyarrow/serialization.pxi
@@ -20,6 +20,8 @@ from libcpp.vector cimport vector as c_vector
from cpython.ref cimport PyObject
from cython.operator cimport dereference as deref

import cloudpickle as pickle

from pyarrow.lib cimport Buffer, NativeFile, check_status, _RecordBatchFileWriter

cdef extern from "arrow/python/python_to_arrow.h":
@@ -30,6 +32,10 @@ cdef extern from "arrow/python/python_to_arrow.h":

cdef shared_ptr[CRecordBatch] MakeBatch(shared_ptr[CArray] data)

cdef extern PyObject *pyarrow_serialize_callback

cdef extern PyObject *pyarrow_deserialize_callback

cdef extern from "arrow/python/arrow_to_python.h":

cdef CStatus DeserializeList(shared_ptr[CArray] array, int32_t start_idx,
@@ -45,6 +51,81 @@ cdef class PythonObject:
def __cinit__(self):
pass

# Types with special serialization handlers
type_to_type_id = dict()
whitelisted_types = dict()
types_to_pickle = set()
custom_serializers = dict()
custom_deserializers = dict()

def register_type(type, type_id, pickle=False, custom_serializer=None, custom_deserializer=None):
"""Add type to the list of types we can serialize.
Args:
type (type): The type that we can serialize.
type_id: A string of bytes used to identify the type.
pickle (bool): True if the serialization should be done with pickle.
False if it should be done efficiently with Arrow.
custom_serializer: This argument is optional, but can be provided to
serialize objects of the class in a particular way.
custom_deserializer: This argument is optional, but can be provided to
deserialize objects of the class in a particular way.
"""
type_to_type_id[type] = type_id
whitelisted_types[type_id] = type
if pickle:
types_to_pickle.add(type_id)
if custom_serializer is not None:
custom_serializers[type_id] = custom_serializer
custom_deserializers[type_id] = custom_deserializer
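
For illustration, a minimal registration sketch; the Point class and its type id are invented for this example, and only register_type itself comes from this file:

# Hypothetical example (not part of this commit): register a plain Python class.
class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

# With no pickle flag and no custom serializer, instances of Point are
# serialized through their __dict__ by the default callback defined below.
register_type(Point, 20 * b"\x02")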

def serialization_callback(obj):
if type(obj) not in type_to_type_id:
raise "error"
type_id = type_to_type_id[type(obj)]
if type_id in types_to_pickle:
serialized_obj = {"data": pickle.dumps(obj), "pickle": True}
elif type_id in custom_serializers:
serialized_obj = {"data": custom_serializers[type_id](obj)}
else:
if hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
else:
raise "error"
return dict(serialized_obj, **{"_pytype_": type_id})

def deserialization_callback(serialized_obj):
type_id = serialized_obj["_pytype_"]

if "pickle" in serialized_obj:
# The object was pickled, so unpickle it.
obj = pickle.loads(serialized_obj["data"])
else:
assert type_id not in types_to_pickle
if type_id not in whitelisted_types:
raise "error"
type = whitelisted_types[type_id]
if type_id in custom_deserializers:
obj = custom_deserializers[type_id](serialized_obj["data"])
else:
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = type.__new__(type, *serialized_obj["_ray_getnewargs_"])
else:
obj = type.__new__(type)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
return obj
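
A rough sketch of the dictionaries these two callbacks exchange, using the hypothetical Point class from the example above (not part of this commit):

# Round trip for a __dict__-based type:
p = Point(1, 2)
serialized = serialization_callback(p)
# serialized == {"x": 1, "y": 2, "_pytype_": 20 * b"\x02"}
restored = deserialization_callback(serialized)
assert isinstance(restored, Point) and restored.__dict__ == {"x": 1, "y": 2}
# A type registered with pickle=True instead yields
# {"data": <cloudpickle bytes>, "pickle": True, "_pytype_": <type id>}.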

def set_serialization_callbacks(serialization_callback, deserialization_callback):
    global pyarrow_serialize_callback, pyarrow_deserialize_callback
    # TODO(pcm): Are refcounts correct here?
    pyarrow_serialize_callback = <PyObject*> serialization_callback
    pyarrow_deserialize_callback = <PyObject*> deserialization_callback

# Main entry point for serialization
def serialize_sequence(object value):
cdef int32_t recursion_depth = 0
@@ -57,18 +138,20 @@ def serialize_sequence(object value):
sequences.push_back(<PyObject*> value)
check_status(SerializeSequences(sequences, recursion_depth, &array, tensors))
result.batch = MakeBatch(array)
num_tensors = 0
for tensor in tensors:
check_status(NdarrayToTensor(c_default_memory_pool(), <object> tensor, &out))
result.tensors.push_back(out)
        num_tensors += 1
    return result, num_tensors

# Main entry point for deserialization
def deserialize_sequence(PythonObject value, object base):
cdef PyObject* result
check_status(DeserializeList(deref(value.batch).column(0), 0, deref(value.batch).num_rows(), <PyObject*> base, value.tensors, &result))
return <object> result

def write_python_object(PythonObject value, int32_t num_tensors, NativeFile sink):
cdef shared_ptr[OutputStream] stream
sink.write_handle(&stream)
cdef shared_ptr[CRecordBatchStreamWriter] writer
@@ -79,6 +162,9 @@ def write_python_object(PythonObject value, NativeFile sink):
cdef int64_t body_length

with nogil:
# write number of tensors
check_status(stream.get().Write(<uint8_t*> &num_tensors, sizeof(int32_t)))

check_status(CRecordBatchStreamWriter.Open(stream.get(), schema, &writer))
check_status(deref(writer).WriteRecordBatch(deref(batch)))
check_status(deref(writer).Close())
@@ -93,18 +179,21 @@ def read_python_object(NativeFile source):
cdef shared_ptr[CRecordBatchStreamReader] reader
cdef shared_ptr[CTensor] tensor
cdef int64_t offset
cdef int64_t bytes_read
cdef int32_t num_tensors

with nogil:
# read number of tensors
check_status(stream.get().Read(sizeof(int32_t), &bytes_read, <uint8_t*> &num_tensors))

check_status(CRecordBatchStreamReader.Open(<shared_ptr[InputStream]> stream, &reader))
check_status(reader.get().ReadNextRecordBatch(&result.batch))

check_status(deref(stream).Tell(&offset))

        for i in range(num_tensors):
            check_status(ReadTensor(offset, stream.get(), &tensor))
            result.tensors.push_back(tensor)
            check_status(deref(stream).Tell(&offset))

return result
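
As an informal illustration of the framing produced above: the stream starts with the raw tensor count, followed by the record batch stream and then the tensors. The helper below is hypothetical and not part of the commit:

import struct

def peek_num_tensors(payload):
    # payload: the serialized stream as a bytes object. The first four bytes
    # are the raw int32 tensor count written by write_python_object above
    # (native byte order); the record batch stream and the tensors follow.
    return struct.unpack("@i", payload[:4])[0]
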
80 changes: 70 additions & 10 deletions python/pyarrow/tests/test_serialization.py
@@ -20,25 +20,85 @@
from __future__ import print_function

import os
import string
import sys

import pyarrow as pa
import numpy as np
from numpy.testing import assert_equal

def serialization_callback(value):
if isinstance(value, np.ndarray):
return {"data": value.tolist(), "_pytype_": str(value.dtype.str)}
else:
return {"data": str(value), "_pytype_": "long"}

def deserialization_callback(value):
data = value["data"]
if value["_pytype_"] == "long":
return int(data)
else:
return np.array(data, dtype=np.dtype(value["_pytype_"]))

pa.lib.set_serialization_callbacks(serialization_callback, deserialization_callback)

def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str

def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))

pa.lib.register_type(np.ndarray, 20 * b"\x01", pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)

if sys.version_info >= (3, 0):
long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
else:
long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] # noqa: E501,F821

PRIMITIVE_OBJECTS = [
0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999,
[1 << 100, [1 << 100]], "a", string.printable, "\u262F",
u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True,
False, [], (), {}, np.int8(3), np.int32(4), np.int64(5),
np.uint8(3), np.uint32(4), np.uint64(5), np.float32(1.9),
np.float64(1.9), np.zeros([100, 100]),
np.random.normal(size=[100, 100]), np.array(["hi", 3]),
np.array(["hi", 3], dtype=object)] + long_extras

COMPLEX_OBJECTS = [
[[[[[[[[[[[[]]]]]]]]]]]],
{"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
{(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
(): {(): {}}}}}}}}}}}}},
((((((((((),),),),),),),),),),
{"a": {"b": {"c": {"d": {}}}}}]

def serialization_roundtrip(value, f):
f.seek(0)
serialized, num_tensors = pa.lib.serialize_sequence(value)
pa.lib.write_python_object(serialized, num_tensors, f)
f.seek(0)
res = pa.lib.read_python_object(f)
base = None
result = pa.lib.deserialize_sequence(res, base)
assert_equal(value, result)

# Create a large memory mapped file
SIZE = 100 * 1024 * 1024 # 100 MB
arr = np.random.randint(0, 256, size=SIZE).astype('u1')
data = arr.tobytes()[:SIZE]
path = os.path.join("/tmp/temp")
path = os.path.join("/tmp/pyarrow-temp-file")
with open(path, 'wb') as f:
f.write(data)

f = pa.memory_map(path, mode="w")

pa.lib.write_python_object(obj, f)

f = pa.memory_map(path, mode="r")
MEMORY_MAPPED_FILE = pa.memory_map(path, mode="r+")

res = pa.lib.read_python_object(f)
def test_primitive_serialization():
for obj in PRIMITIVE_OBJECTS:
serialization_roundtrip([obj], MEMORY_MAPPED_FILE)

pa.lib.deserialize_sequence(res, res)
def test_complex_serialization():
for obj in COMPLEX_OBJECTS:
serialization_roundtrip([obj], MEMORY_MAPPED_FILE)
