Skip to content

Commit

Permalink
Deserialization avoids pyo3 borrows, uses FFI for list and dict
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Dec 4, 2018
1 parent 4eed3fd commit ed18c84
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
45 changes: 25 additions & 20 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::IntoPyPointer;
use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
use smallvec::SmallVec;
use std::borrow::Cow;
Expand All @@ -14,11 +15,11 @@ pub fn deserialize(py: Python, data: &str) -> PyResult<PyObject> {
let seed = JsonValue::new(py);
let mut deserializer = serde_json::Deserializer::from_str(data);
match seed.deserialize(&mut deserializer) {
Ok(py_object) => {
Ok(py_ptr) => {
deserializer
.end()
.map_err(|e| JSONDecodeError::py_err((e.to_string(), "", 0)))?;
Ok(py_object)
Ok(unsafe { PyObject::from_owned_ptr(py, py_ptr) })
}
Err(e) => {
return Err(JSONDecodeError::py_err((e.to_string(), "", 0)));
Expand All @@ -38,7 +39,7 @@ impl<'a> JsonValue<'a> {
}

impl<'de, 'a> DeserializeSeed<'de> for JsonValue<'a> {
type Value = PyObject;
type Value = *mut pyo3::ffi::PyObject;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
Expand All @@ -49,81 +50,85 @@ impl<'de, 'a> DeserializeSeed<'de> for JsonValue<'a> {
}

impl<'de, 'a> Visitor<'de> for JsonValue<'a> {
type Value = PyObject;
type Value = *mut pyo3::ffi::PyObject;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("JSON")
}

fn visit_unit<E>(self) -> Result<Self::Value, E> {
Ok(self.py.None())
Ok(unsafe { pyo3::ffi::Py_None() })
}

fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_object(self.py))
Ok(value.into_object(self.py).into_ptr())
}

fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_object(self.py))
Ok(value.into_object(self.py).into_ptr())
}

fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_object(self.py))
Ok(value.into_object(self.py).into_ptr())
}

fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_object(self.py))
Ok(PyFloat::new(self.py, value).into_ptr())
}

fn visit_borrowed_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_object(self.py))
Ok(PyString::new(self.py, value).into_ptr())
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_object(self.py))
Ok(PyString::new(self.py, value).into_ptr())
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut elements: SmallVec<[PyObject; 8]> = SmallVec::new();
let mut elements: SmallVec<[*mut pyo3::ffi::PyObject; 8]> = SmallVec::new();
while let Some(elem) = seq.next_element_seed(self.clone())? {
elements.push(elem);
}
Ok(elements.as_slice().to_object(self.py))
let ptr = unsafe { pyo3::ffi::PyList_New(elements.len() as pyo3::ffi::Py_ssize_t) };
for (i, obj) in elements.iter().enumerate() {
unsafe { pyo3::ffi::PyList_SetItem(ptr, i as pyo3::ffi::Py_ssize_t, *obj) };
}
Ok(ptr)
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut elements: SmallVec<[(PyObject, PyObject); 8]> = SmallVec::new();
let dict_ptr = PyDict::new(self.py).into_ptr();
while let Some((key, value)) = map.next_entry_seed(PhantomData::<Cow<str>>, self.clone())? {
elements.push((key.to_object(self.py), value));
}
let dict = PyDict::new(self.py);
for (key, value) in elements.iter() {
dict.set_item(key, value).unwrap()
let _ = unsafe { pyo3::ffi::PyDict_SetItem(
dict_ptr,
PyString::new(self.py, &key).into_ptr(),
value,
) };
}
Ok(dict.into())
Ok(dict_ptr)
}
}
6 changes: 3 additions & 3 deletions src/typeref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ pub fn init_typerefs(py: Python) {
LIST_PTR = PyList::empty(py).as_ref().get_type_ptr();
TUPLE_PTR = PyTuple::empty(py).as_ref(py).get_type_ptr();
NONE_PTR = py.None().as_ref(py).get_type_ptr();
BOOL_PTR = true.to_object(py).as_ref(py).get_type_ptr();
INT_PTR = 1.to_object(py).as_ref(py).get_type_ptr();
FLOAT_PTR = 1.0.to_object(py).as_ref(py).get_type_ptr();
BOOL_PTR = true.into_object(py).as_ref(py).get_type_ptr();
INT_PTR = 1.into_object(py).as_ref(py).get_type_ptr();
FLOAT_PTR = 1.0.into_object(py).as_ref(py).get_type_ptr();
});
}
20 changes: 19 additions & 1 deletion test/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,33 @@ def test_bool(self):
self.assertEqual(orjson.dumps(obj), ref.encode('utf-8'))
self.assertEqual(orjson.loads(ref), obj)

def test_bool_array(self):
"""
bool array
"""
obj = [True] * 256
ref = ('[' + ('true,' * 255) + 'true]').encode('utf-8')
self.assertEqual(orjson.dumps(obj), ref)
self.assertEqual(orjson.loads(ref), obj)

def test_none(self):
"""
NoneType
null
"""
obj = None
ref = u'null'
self.assertEqual(orjson.dumps(obj), ref.encode('utf-8'))
self.assertEqual(orjson.loads(ref), obj)

def test_null_array(self):
"""
null array
"""
obj = [None] * 256
ref = ('[' + ('null,' * 255) + 'null]').encode('utf-8')
self.assertEqual(orjson.dumps(obj), ref)
self.assertEqual(orjson.loads(ref), obj)

def test_int_64(self):
"""
int 64-bit
Expand Down

0 comments on commit ed18c84

Please sign in to comment.