Skip to content

Commit

Permalink
codify the stack-based nature of the guard
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jan 15, 2024
1 parent e1190ed commit 4b82950
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 66 deletions.
109 changes: 63 additions & 46 deletions src/recursion_guard.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ahash::AHashSet;
use std::hash::Hash;
use std::mem::MaybeUninit;

type RecursionKey = (
// Identifier for the input object, e.g. the id() of a Python dict
Expand All @@ -14,7 +14,7 @@ type RecursionKey = (
/// It's used in `validators/definition` to detect when a reference is reused within itself.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
ids: SmallContainer<RecursionKey>,
ids: RecursionStack,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u8,
Expand All @@ -33,10 +33,10 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi

impl RecursionGuard {
// insert a new value
// * return `None` if the array/set already had it in it
// * return `Some(index)` if the array didn't have it in it and it was inserted
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> Option<usize> {
self.ids.contains_or_insert((obj_id, node_id))
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
self.ids.insert((obj_id, node_id))
}

// see #143 this is used as a backup in case the identity check recursion guard fails
Expand Down Expand Up @@ -68,76 +68,93 @@ impl RecursionGuard {
self.depth = self.depth.saturating_sub(1);
}

pub fn remove(&mut self, obj_id: usize, node_id: usize, index: usize) {
self.ids.remove(&(obj_id, node_id), index);
pub fn remove(&mut self, obj_id: usize, node_id: usize) {
self.ids.remove(&(obj_id, node_id));
}
}

// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower
const ARRAY_SIZE: usize = 16;

#[derive(Debug, Clone)]
enum SmallContainer<T> {
Array([Option<T>; ARRAY_SIZE]),
Set(AHashSet<T>),
enum RecursionStack {
Array {
data: [MaybeUninit<RecursionKey>; ARRAY_SIZE],
len: usize,
},
Set(AHashSet<RecursionKey>),
}

impl<T: Copy> Default for SmallContainer<T> {
impl Default for RecursionStack {
fn default() -> Self {
Self::Array([None; ARRAY_SIZE])
Self::Array {
data: std::array::from_fn(|_| MaybeUninit::uninit()),
len: 0,
}
}
}

impl<T: Eq + Hash + Clone> SmallContainer<T> {
impl RecursionStack {
// insert a new value
// * return `None` if the array/set already had it in it
// * return `Some(index)` if the array didn't have it in it and it was inserted
pub fn contains_or_insert(&mut self, v: T) -> Option<usize> {
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, v: RecursionKey) -> bool {
match self {
Self::Array(array) => {
for (index, op_value) in array.iter_mut().enumerate() {
if let Some(existing) = op_value {
if existing == &v {
return None;
Self::Array { data, len } => {
if *len < ARRAY_SIZE {
for value in data.iter().take(*len) {
// Safety: reading values within bounds
if unsafe { value.assume_init() } == v {
return false;
}
} else {
*op_value = Some(v);
return Some(index);
}
}

// No array slots exist; convert to set
let mut set: AHashSet<T> = AHashSet::with_capacity(ARRAY_SIZE + 1);
for existing in array.iter_mut() {
set.insert(existing.take().unwrap());
data[*len].write(v);
*len += 1;
true
} else {
let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1);
for existing in data.iter() {
// Safety: the array is fully initialized
set.insert(unsafe { existing.assume_init() });
}
let inserted = set.insert(v);
*self = Self::Set(set);
inserted
}
set.insert(v);
*self = Self::Set(set);
// id doesn't matter here as we'll be removing from a set
Some(0)
}
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Self::Set(set) => {
if set.insert(v) {
// again id doesn't matter here as we'll be removing from a set
Some(0)
} else {
None
}
}
Self::Set(set) => set.insert(v),
}
}

pub fn remove(&mut self, v: &T, index: usize) {
pub fn remove(&mut self, v: &RecursionKey) {
match self {
Self::Array(array) => {
debug_assert!(array[index].as_ref() == Some(v), "remove did not match insert");
array[index] = None;
Self::Array { data, len } => {
*len = len.checked_sub(1).expect("remove from empty recursion guard");
// Safety: this is reading what was the back of the initialized array
let removed = unsafe { data.get_unchecked_mut(*len) };
assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert");
// this should compile away to a noop
unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) }
}
Self::Set(set) => {
set.remove(v);
}
}
}
}

impl Drop for RecursionStack {
    /// Drops every initialized slot of the inline array variant.
    ///
    /// `RecursionKey` does not implement `Drop`, so this is expected to compile away to a
    /// no-op; it is kept so that the stack stays correct should the key type ever gain a
    /// destructor. The `Set` variant needs no manual handling here.
    fn drop(&mut self) {
        match self {
            Self::Array { data, len } => {
                data.iter_mut().take(*len).for_each(|slot| {
                    // SAFETY: only the first `len` slots are visited, and those are the
                    // slots that have been written by `insert` and not yet removed.
                    unsafe { slot.assume_init_drop() }
                });
            }
            Self::Set(_) => {}
        }
    }
}
10 changes: 5 additions & 5 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,24 +345,24 @@ pub struct SerRecursionGuard {
}

impl SerRecursionGuard {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<(usize, usize)> {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
let id = value.as_ptr() as usize;
let mut guard = self.guard.borrow_mut();

if let Some(insert_index) = guard.contains_or_insert(id, def_ref_id) {
if guard.insert(id, def_ref_id) {
if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
Ok((id, insert_index))
Ok(id)
}
} else {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
}
}

pub fn pop(&self, id: usize, def_ref_id: usize, insert_index: usize) {
pub fn pop(&self, id: usize, def_ref_id: usize) {
let mut guard = self.guard.borrow_mut();
guard.decr_depth();
guard.remove(id, def_ref_id, insert_index);
guard.remove(id, def_ref_id);
}
}
14 changes: 7 additions & 7 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub(crate) fn infer_to_python_known(
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
Ok(id) => id,
Err(e) => {
return match extra.mode {
Expand Down Expand Up @@ -226,7 +226,7 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serialize_unknown(value).into_py(py)
Expand Down Expand Up @@ -284,15 +284,15 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
}
value.into_py(py)
}
_ => value.into_py(py),
},
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
Ok(value)
}

Expand Down Expand Up @@ -351,7 +351,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
Ok(v) => v,
Err(e) => {
return if extra.serialize_unknown {
Expand Down Expand Up @@ -534,7 +534,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serializer.serialize_str(&serialize_unknown(value))
Expand All @@ -548,7 +548,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
}
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
ser_result
}

Expand Down
8 changes: 4 additions & 4 deletions src/serializers/type_serializers/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ impl TypeSerializer for DefinitionRefSerializer {
) -> PyResult<PyObject> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let (value_id, insert_index) = extra.rec_guard.add(value, self.definition.id())?;
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id(), insert_index);
extra.rec_guard.pop(value_id, self.definition.id());
r
})
}
Expand All @@ -91,12 +91,12 @@ impl TypeSerializer for DefinitionRefSerializer {
) -> Result<S::Ok, S::Error> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let (value_id, insert_index) = extra
let value_id = extra
.rec_guard
.add(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id(), insert_index);
extra.rec_guard.pop(value_id, self.definition.id());
r
})
}
Expand Down
8 changes: 4 additions & 4 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = input.identity() {
if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) {
if state.recursion_guard.insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validator.validate(py, input, state);
state.recursion_guard.remove(id, self.definition.id(), insert_index);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
Expand All @@ -105,12 +105,12 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = obj.identity() {
if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) {
if state.recursion_guard.insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.definition.id(), insert_index);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
Expand Down

0 comments on commit 4b82950

Please sign in to comment.