Skip to content

Commit

Permalink
deer: introduce StructVisitor (#2437)
Browse files Browse the repository at this point in the history
* feat: create struct visitor

* feat: add mandatory `visit_null` and `visit_none`

* feat: implement `deserialize_struct` for desert

* feat: implement `deserialize_struct` for json

* feat: implement manual implementation of `Deserialize` for struct

* feat: outline ok

* fix: value deserializer

* test: `StructVisitor`

* feat: remove `visit_null`/`visit_none` from `StructVisitor`

* fix: json `Deserializer`

* test: remove indirection in struct code

* chore: comments

* chore: comments (II)

* feat: use `NoneDeserializer` for values

* fix: drive-by: generalize skip logic, fix regression, skip value if key is wrong

* fix: clippy

* fix: clippy

* fix: rustfmt

* fix: remove reflection of removed tokens
  • Loading branch information
indietyp authored Apr 28, 2023
1 parent 826b014 commit f168bc9
Show file tree
Hide file tree
Showing 15 changed files with 876 additions and 84 deletions.
31 changes: 2 additions & 29 deletions libs/deer/desert/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use error_stack::{Report, Result, ResultExt};

use crate::{
deserializer::{Deserializer, DeserializerNone},
skip::skip_tokens,
token::Token,
};

Expand All @@ -29,31 +30,6 @@ impl<'a, 'b, 'de> ArrayAccess<'a, 'b, 'de> {
remaining: None,
}
}

fn scan_end(&self) -> Option<usize> {
let mut objects: usize = 0;
let mut arrays: usize = 0;

let mut n = 0;

loop {
let token = self.deserializer.peek_n(n)?;

match token {
Token::Array { .. } => arrays += 1,
Token::ArrayEnd if arrays == 0 && objects == 0 => {
// we're at the outer layer, meaning we can know where we end
return Some(n);
}
Token::ArrayEnd => arrays = arrays.saturating_sub(1),
Token::Object { .. } => objects += 1,
Token::ObjectEnd => objects = objects.saturating_sub(1),
_ => {}
}

n += 1;
}
}
}

impl<'de> deer::ArrayAccess<'de> for ArrayAccess<'_, '_, 'de> {
Expand Down Expand Up @@ -133,10 +109,7 @@ impl<'de> deer::ArrayAccess<'de> for ArrayAccess<'_, '_, 'de> {
};

// bump until the very end, which ensures that deserialize calls after this might succeed!
let bump = self
.scan_end()
.map_or_else(|| self.deserializer.tape().remaining(), |index| index + 1);
self.deserializer.tape_mut().bump_n(bump);
skip_tokens(self.deserializer, &Token::Array { length: None });

if let Some(remaining) = self.remaining {
if remaining > 0 {
Expand Down
37 changes: 31 additions & 6 deletions libs/deer/desert/src/deserializer.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use alloc::borrow::ToOwned;

use deer::{
error::{DeserializerError, TypeError, Variant},
error::{DeserializerError, ExpectedType, MissingError, ReceivedType, TypeError, Variant},
value::NoneDeserializer,
Context, EnumVisitor, OptionalVisitor, Visitor,
Context, EnumVisitor, OptionalVisitor, StructVisitor, Visitor,
};
use error_stack::{Report, Result, ResultExt};

Expand Down Expand Up @@ -134,6 +134,26 @@ impl<'a, 'de> deer::Deserializer<'de> for &mut Deserializer<'a, 'de> {

Ok(value)
}

fn deserialize_struct<V>(self, visitor: V) -> Result<V::Value, DeserializerError>
where
V: StructVisitor<'de>,
{
let token = self.next();

match token {
Token::Array { length } => visitor
.visit_array(ArrayAccess::new(self, length))
.change_context(DeserializerError),
Token::Object { length } => visitor
.visit_object(ObjectAccess::new(self, length))
.change_context(DeserializerError),
other => Err(Report::new(TypeError.into_error())
.attach(ExpectedType::new(visitor.expecting()))
.attach(ReceivedType::new(other.schema()))
.change_context(DeserializerError)),
}
}
}

impl<'a, 'de> Deserializer<'a, 'de> {
Expand All @@ -157,10 +177,6 @@ impl<'a, 'de> Deserializer<'a, 'de> {
self.tape.next().expect("should have token to deserialize")
}

pub(crate) const fn tape(&self) -> &Tape<'de> {
&self.tape
}

pub(crate) fn tape_mut(&mut self) -> &mut Tape<'de> {
&mut self.tape
}
Expand Down Expand Up @@ -224,4 +240,13 @@ impl<'de> deer::Deserializer<'de> for DeserializerNone<'_> {
.visit_value(discriminant, self)
.change_context(DeserializerError)
}

fn deserialize_struct<V>(self, visitor: V) -> Result<V::Value, DeserializerError>
where
V: StructVisitor<'de>,
{
Err(Report::new(MissingError.into_error())
.attach(ExpectedType::new(visitor.expecting()))
.change_context(DeserializerError))
}
}
1 change: 1 addition & 0 deletions libs/deer/desert/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod assert;
mod deserializer;
pub mod error;
pub(crate) mod object;
mod skip;
pub(crate) mod tape;
mod token;

Expand Down
38 changes: 8 additions & 30 deletions libs/deer/desert/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use error_stack::{Report, Result, ResultExt};

use crate::{
deserializer::{Deserializer, DeserializerNone},
skip::skip_tokens,
token::Token,
};

Expand All @@ -29,31 +30,6 @@ impl<'a, 'b, 'de: 'a> ObjectAccess<'a, 'b, 'de> {
consumed: 0,
}
}

fn scan_end(&self) -> Option<usize> {
let mut objects: usize = 0;
let mut arrays: usize = 0;

let mut n = 0;

loop {
let token = self.deserializer.peek_n(n)?;

match token {
Token::Array { .. } => arrays += 1,
Token::ArrayEnd => arrays = arrays.saturating_sub(1),
Token::Object { .. } => objects += 1,
Token::ObjectEnd if arrays == 0 && objects == 0 => {
// we're at the outer layer, meaning we can know where we end
return Some(n);
}
Token::ObjectEnd => objects = objects.saturating_sub(1),
_ => {}
}

n += 1;
}
}
}

impl<'de> deer::ObjectAccess<'de> for ObjectAccess<'_, '_, 'de> {
Expand Down Expand Up @@ -115,6 +91,12 @@ impl<'de> deer::ObjectAccess<'de> for ObjectAccess<'_, '_, 'de> {
} else {
let key = access.visit_key(&mut *self.deserializer);

if key.is_err() {
// the key is an error, we need to swallow the value
let next = self.deserializer.next();
skip_tokens(self.deserializer, &next);
}

key.and_then(|key| access.visit_value(key, &mut *self.deserializer))
};

Expand All @@ -141,11 +123,7 @@ impl<'de> deer::ObjectAccess<'de> for ObjectAccess<'_, '_, 'de> {
};

// bump until the very end, which ensures that deserialize calls after this might succeed!
let bump = self
.scan_end()
.unwrap_or_else(|| self.deserializer.tape().remaining());

self.deserializer.tape_mut().bump_n(bump + 1);
skip_tokens(self.deserializer, &Token::Object { length: None });

if let Some(remaining) = self.remaining {
if remaining > 0 {
Expand Down
45 changes: 45 additions & 0 deletions libs/deer/desert/src/skip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::{deserializer::Deserializer, Token};

fn scan_object(deserializer: &Deserializer, stop: &Token) -> usize {
let mut objects: usize = 0;
let mut arrays: usize = 0;

let mut n = 0;

loop {
let Some(token) = deserializer.peek_n(n) else {
// we're at the end
return n;
};

if token == *stop && arrays == 0 && objects == 0 {
// we're at the outer layer, meaning we can know where we end
// need to increment by one as we want to also skip the ObjectEnd
return n + 1;
}

match token {
Token::Array { .. } => arrays += 1,
Token::ArrayEnd => arrays = arrays.saturating_sub(1),
Token::Object { .. } => objects += 1,
Token::ObjectEnd => objects = objects.saturating_sub(1),
_ => {}
}

n += 1;
}
}

/// Skips all tokens required for the start token, be aware that the start token should no longer be
/// on the tape
pub(crate) fn skip_tokens(deserializer: &mut Deserializer, start: &Token) {
let n = match start {
Token::Array { .. } => scan_object(&*deserializer, &Token::ArrayEnd),
Token::Object { .. } => scan_object(&*deserializer, &Token::ObjectEnd),
_ => 0,
};

if n > 0 {
deserializer.tape_mut().bump_n(n);
}
}
40 changes: 39 additions & 1 deletion libs/deer/desert/src/token.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::fmt::{Debug, Display, Formatter};

use deer::Number;
use deer::{Deserialize, Document, Number, Reflection, Schema};

// TODO: test
// TODO: this should be `Copy`, but `Number` has no &'static constructor
Expand Down Expand Up @@ -181,3 +181,41 @@ impl Display for Token {
Debug::fmt(self, f)
}
}

struct AnyArray;

impl Reflection for AnyArray {
fn schema(_: &mut Document) -> Schema {
Schema::new("array")
}
}

struct AnyObject;

impl Reflection for AnyObject {
fn schema(_: &mut Document) -> Schema {
Schema::new("object")
}
}

impl Token {
pub(crate) fn schema(&self) -> Document {
match self {
Self::Bool(_) => Document::new::<bool>(),
Self::Number(_) => Document::new::<Number>(),
Self::U128(_) => Document::new::<u128>(),
Self::I128(_) => Document::new::<i128>(),
Self::Char(_) => Document::new::<char>(),
Self::Str(_) | Self::BorrowedStr(_) | Self::String(_) => Document::new::<str>(),
Self::Bytes(_) | Self::BorrowedBytes(_) | Self::BytesBuf(_) => Document::new::<[u8]>(),
Self::Array { .. } | Self::ArrayEnd => Document::new::<AnyArray>(),
Self::Object { .. } | Self::ObjectEnd => Document::new::<AnyObject>(),
Self::Null => Document::new::<<() as Deserialize>::Reflection>(),
}
}
}

// TODO: maybe number
// TODO: IdentifierVisitor (u8, u64, str, borrowed_str, string,
// bytes, bytes_buf, borrowed_bytes)
// TODO: test
27 changes: 23 additions & 4 deletions libs/deer/json/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ use std::any::Demand;
use deer::{
error::{
ArrayAccessError, ArrayLengthError, BoundedContractViolationError, DeserializeError,
DeserializerError, ExpectedLength, ExpectedType, ObjectAccessError, ObjectItemsExtraError,
ObjectLengthError, ReceivedKey, ReceivedLength, ReceivedType, ReceivedValue, TypeError,
ValueError, Variant,
DeserializerError, ExpectedLength, ExpectedType, MissingError, ObjectAccessError,
ObjectItemsExtraError, ObjectLengthError, ReceivedKey, ReceivedLength, ReceivedType,
ReceivedValue, TypeError, ValueError, Variant,
},
value::NoneDeserializer,
Context, Deserialize, DeserializeOwned, Document, EnumVisitor, FieldVisitor, OptionalVisitor,
Reflection, Schema, Visitor,
Reflection, Schema, StructVisitor, Visitor,
};
use error_stack::{IntoReport, Report, Result, ResultExt};
use serde_json::{Map, Value};
Expand Down Expand Up @@ -427,6 +427,25 @@ impl<'a, 'de> deer::Deserializer<'de> for Deserializer<'a> {
.change_context(DeserializerError)
}
}

fn deserialize_struct<V>(self, visitor: V) -> Result<V::Value, DeserializerError>
where
V: StructVisitor<'de>,
{
match self.value {
None => Err(Report::new(MissingError.into_error())
.attach(ExpectedType::new(visitor.expecting()))
.change_context(DeserializerError)),
Some(Value::Object(object)) => visitor
.visit_object(ObjectAccess::new(object, self.context))
.change_context(DeserializerError),
// we do not allow arrays as struct, only objects are allowed for structs
Some(value) => Err(Report::new(TypeError.into_error())
.attach(ExpectedType::new(visitor.expecting()))
.attach(ReceivedType::new(into_document(&value)))
.change_context(DeserializerError)),
}
}
}

#[must_use]
Expand Down
34 changes: 34 additions & 0 deletions libs/deer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,36 @@ pub trait OptionalVisitor<'de>: Sized {
}
}

#[allow(unused_variables)]
pub trait StructVisitor<'de>: Sized {
type Value;

fn expecting(&self) -> Document;

// visit_none and visit_null are not implemented, as they can be used more expressively using
// `OptionalVisitor`

fn visit_array<A>(self, array: A) -> Result<Self::Value, VisitorError>
where
A: ArrayAccess<'de>,
{
Err(Report::new(TypeError.into_error())
.attach(ReceivedType::new(visitor::ArraySchema::document()))
.attach(ExpectedType::new(self.expecting()))
.change_context(VisitorError))
}

fn visit_object<A>(self, object: A) -> Result<Self::Value, VisitorError>
where
A: ObjectAccess<'de>,
{
Err(Report::new(TypeError.into_error())
.attach(ReceivedType::new(visitor::ObjectSchema::document()))
.attach(ExpectedType::new(self.expecting()))
.change_context(VisitorError))
}
}

// internal visitor, which is used during the default implementation of the `deserialize_i*` and
// `deserialize_u*` methods.
struct NumberVisitor<T: Reflection>(PhantomData<fn() -> *const T>);
Expand Down Expand Up @@ -602,6 +632,10 @@ pub trait Deserializer<'de>: Sized {
where
V: EnumVisitor<'de>;

fn deserialize_struct<V>(self, visitor: V) -> Result<V::Value, DeserializerError>
where
V: StructVisitor<'de>;

derive_from_number![
deserialize_i8(to_i8: i8) -> visit_i8,
deserialize_i16(to_i16: i16) -> visit_i16,
Expand Down
Loading

0 comments on commit f168bc9

Please sign in to comment.