Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive ellipsis for serde_pyo3 #1589

Merged
merged 1 commit into from
Aug 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 151 additions & 1 deletion bindings/python/src/utils/serde_pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ use serde::de::value::Error;
use serde::{ser, Serialize};
type Result<T> = ::std::result::Result<T, Error>;

const MAX_DEPTH: usize = 5;

pub struct Serializer {
// This string starts empty and JSON is appended as values are serialized.
output: String,
level: usize,
}

// By convention, the public API of a Serde serializer is one or more `to_abc`
Expand All @@ -18,6 +21,7 @@ where
{
let mut serializer = Serializer {
output: String::new(),
level: 0,
};
value.serialize(&mut serializer)?;
Ok(serializer.output)
Expand Down Expand Up @@ -51,6 +55,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// of the primitive types of the data model and map it to JSON by appending
// into the output string.
fn serialize_bool(self, v: bool) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += if v { "True" } else { "False" };
Ok(())
}
Expand All @@ -74,6 +83,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// Not particularly efficient but this is example code anyway. A more
// performant approach would be to use the `itoa` crate.
fn serialize_i64(self, v: i64) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += &v.to_string();
Ok(())
}
Expand All @@ -91,6 +105,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
}

fn serialize_u64(self, v: u64) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += &v.to_string();
Ok(())
}
Expand All @@ -100,6 +119,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
}

fn serialize_f64(self, v: f64) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += &v.to_string();
Ok(())
}
Expand All @@ -114,6 +138,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// get the idea. For example it would emit invalid JSON if the input string
// contains a '"' character.
fn serialize_str(self, v: &str) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "\"";
self.output += v;
self.output += "\"";
Expand Down Expand Up @@ -152,6 +181,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// In Serde, unit means an anonymous value containing no data. Map this to
// JSON as `null`.
fn serialize_unit(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "None";
Ok(())
}
Expand All @@ -173,6 +207,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
_variant_index: u32,
variant: &'static str,
) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
// self.serialize_str(variant)
self.output += variant;
Ok(())
Expand Down Expand Up @@ -202,6 +241,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
// variant.serialize(&mut *self)?;
self.output += variant;
self.output += "(";
Expand All @@ -221,6 +265,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// explicitly in the serialized form. Some serializers may only be able to
// support sequences for which the length is known up front.
fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
self.output += "[";
Ok(self)
}
Expand All @@ -230,6 +279,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// means that the corresponding `Deserialize implementation will know the
// length without needing to look at the serialized data.
fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
self.output += "(";
Ok(self)
}
Expand All @@ -252,6 +306,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
// variant.serialize(&mut *self)?;
self.output += variant;
self.output += "(";
Expand All @@ -260,6 +319,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {

// Maps are represented in JSON as `{ K: V, K: V, ... }`.
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
println!("Serialize map");
self.output += "{";
Ok(self)
Expand All @@ -271,6 +335,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
// Deserialize implementation is required to know what the keys are without
// looking at the serialized data.
fn serialize_struct(self, name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
// self.serialize_map(Some(len))
// name.serialize(&mut *self)?;
if let Some(stripped) = name.strip_suffix("Helper") {
Expand All @@ -291,6 +360,11 @@ impl<'a> ser::Serializer for &'a mut Serializer {
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(self);
}
// variant.serialize(&mut *self)?;
self.output += variant;
self.output += "(";
Expand All @@ -316,6 +390,11 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('[') {
self.output += ", ";
}
Expand All @@ -324,6 +403,11 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer {

// Close the sequence.
fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "]";
Ok(())
}
Expand All @@ -338,13 +422,23 @@ impl<'a> ser::SerializeTuple for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand All @@ -359,13 +453,23 @@ impl<'a> ser::SerializeTupleStruct for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand All @@ -388,13 +492,23 @@ impl<'a> ser::SerializeTupleVariant for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand Down Expand Up @@ -424,6 +538,11 @@ impl<'a> ser::SerializeMap for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('{') {
self.output += ", ";
}
Expand All @@ -437,11 +556,21 @@ impl<'a> ser::SerializeMap for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ":";
value.serialize(&mut **self)
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += "}";
Ok(())
}
Expand All @@ -457,6 +586,11 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
Expand All @@ -471,6 +605,11 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer {
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand All @@ -486,6 +625,11 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer {
where
T: ?Sized + Serialize,
{
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
if !self.output.ends_with('(') {
self.output += ", ";
}
Expand All @@ -496,6 +640,11 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer {
}

fn end(self) -> Result<()> {
self.level += 1;
if self.level > MAX_DEPTH {
self.output += "...";
return Ok(());
}
self.output += ")";
Ok(())
}
Expand Down Expand Up @@ -525,7 +674,7 @@ fn test_struct() {
let expected = r#"Test(int=1, seq=["a", "b"])"#;
assert_eq!(to_string(&test).unwrap(), expected);
}

/*
#[test]
fn test_enum() {
#[derive(Serialize)]
Expand Down Expand Up @@ -640,3 +789,4 @@ fn test_flatten() {
let expected = r#"C(a=True, b=1, d=2)"#;
assert_eq!(to_string(&u).unwrap(), expected);
}
*/
Loading