Skip to content

Commit

Permalink
Adds optional serde support to datafusion-proto (#2892)
Browse files Browse the repository at this point in the history
* Add optional serde support to datafusion-proto (#2889)

* Add public methods for JSON serde (#64)

* Misc suggestions

* Update datafusion/proto/Cargo.toml

Co-authored-by: Raphael Taylor-Davies <[email protected]>

Co-authored-by: Raphael Taylor-Davies <[email protected]>

* Fixes

* Fixup Cargo.toml

* Format Cargo.toml

Co-authored-by: Brent Gardner <[email protected]>
  • Loading branch information
tustvold and avantgardnerio authored Jul 17, 2022
1 parent c528986 commit c67161b
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 34 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
CARGO_TARGET_DIR: "/github/home/target"
- name: Check Workspace builds with all features
run: |
cargo check --workspace --benches --features avro,jit,scheduler
cargo check --workspace --benches --features avro,jit,scheduler,json
env:
CARGO_HOME: "/github/home/.cargo"
CARGO_TARGET_DIR: "/github/home/target"
Expand Down Expand Up @@ -121,7 +121,7 @@ jobs:
run: |
export ARROW_TEST_DATA=$(pwd)/testing/data
export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data
cargo test --features avro,jit,scheduler
cargo test --features avro,jit,scheduler,json
# test datafusion-sql examples
cargo run --example sql
# test datafusion examples
Expand Down
12 changes: 9 additions & 3 deletions datafusion/proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repository = "https://github.com/apache/arrow-datafusion"
readme = "README.md"
authors = ["Apache Arrow <[email protected]>"]
license = "Apache-2.0"
keywords = [ "arrow", "query", "sql" ]
keywords = ["arrow", "query", "sql"]
edition = "2021"
rust-version = "1.58"

Expand All @@ -33,18 +33,24 @@ name = "datafusion_proto"
path = "src/lib.rs"

[features]
default = []
json = ["pbjson", "pbjson-build", "serde", "serde_json"]

[dependencies]
arrow = { version = "18.0.0" }
datafusion = { path = "../core", version = "10.0.0" }
datafusion-common = { path = "../common", version = "10.0.0" }
datafusion-expr = { path = "../expr", version = "10.0.0" }
pbjson = { version = "0.3", optional = true }
pbjson-types = { version = "0.3", optional = true }
prost = "0.10"

serde = { version = "1.0", optional = true }
serde_json = { version = "1.0", optional = true }

[dev-dependencies]
doc-comment = "0.3"
tokio = "1.18"

[build-dependencies]
tonic-build = { version = "0.7" }
pbjson-build = { version = "0.3", optional = true }
prost-build = { version = "0.7" }
38 changes: 35 additions & 3 deletions datafusion/proto/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,44 @@
// specific language governing permissions and limitations
// under the License.

type Error = Box<dyn std::error::Error>;
type Result<T, E = Error> = std::result::Result<T, E>;

fn main() -> Result<(), String> {
// for use in docker build where file changes can be wonky
println!("cargo:rerun-if-env-changed=FORCE_REBUILD");

println!("cargo:rerun-if-changed=proto/datafusion.proto");
tonic_build::configure()
.compile(&["proto/datafusion.proto"], &["proto"])

build()?;

Ok(())
}

#[cfg(feature = "json")]
fn build() -> Result<(), String> {
let descriptor_path = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap())
.join("proto_descriptor.bin");

prost_build::Config::new()
.file_descriptor_set_path(&descriptor_path)
.compile_well_known_types()
.extern_path(".google.protobuf", "::pbjson_types")
.compile_protos(&["proto/datafusion.proto"], &["proto"])
.map_err(|e| format!("protobuf compilation failed: {}", e))?;

let descriptor_set = std::fs::read(descriptor_path).unwrap();
pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)
.unwrap()
.build(&[".datafusion"])
.map_err(|e| format!("pbjson compilation failed: {}", e))?;

Ok(())
}

#[cfg(not(feature = "json"))]
fn build() -> Result<(), String> {
prost_build::Config::new()
.compile_protos(&["proto/datafusion.proto"], &["proto"])
.map_err(|e| format!("protobuf compilation failed: {}", e))
}
46 changes: 46 additions & 0 deletions datafusion/proto/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ pub fn logical_plan_to_bytes(plan: &LogicalPlan) -> Result<Bytes> {
logical_plan_to_bytes_with_extension_codec(plan, &extension_codec)
}

/// Serialize a LogicalPlan as json
#[cfg(feature = "json")]
pub fn logical_plan_to_json(plan: &LogicalPlan) -> Result<String> {
let extension_codec = DefaultExtensionCodec {};
let protobuf =
protobuf::LogicalPlanNode::try_from_logical_plan(plan, &extension_codec)
.map_err(|e| {
DataFusionError::Plan(format!("Error serializing plan: {}", e))
})?;
serde_json::to_string(&protobuf)
.map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {}", e)))
}

/// Serialize a LogicalPlan as bytes, using the provided extension codec
pub fn logical_plan_to_bytes_with_extension_codec(
plan: &LogicalPlan,
Expand All @@ -121,6 +134,14 @@ pub fn logical_plan_to_bytes_with_extension_codec(
Ok(buffer.into())
}

/// Deserialize a LogicalPlan from json
#[cfg(feature = "json")]
pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result<LogicalPlan> {
let back: protobuf::LogicalPlanNode = serde_json::from_str(json).unwrap();
let extension_codec = DefaultExtensionCodec {};
back.try_into_logical_plan(ctx, &extension_codec)
}

/// Deserialize a LogicalPlan from bytes
pub fn logical_plan_from_bytes(
bytes: &[u8],
Expand Down Expand Up @@ -183,6 +204,31 @@ mod test {
Expr::from_bytes(b"Leet").unwrap();
}

#[test]
#[cfg(feature = "json")]
fn plan_to_json() {
use datafusion_common::DFSchema;
use datafusion_expr::logical_plan::EmptyRelation;

let plan = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
});
let actual = logical_plan_to_json(&plan).unwrap();
let expected = r#"{"emptyRelation":{}}"#.to_string();
assert_eq!(actual, expected);
}

#[test]
#[cfg(feature = "json")]
fn json_to_plan() {
let input = r#"{"emptyRelation":{}}"#.to_string();
let ctx = SessionContext::new();
let actual = logical_plan_from_json(&input, &ctx).unwrap();
let result = matches!(actual, LogicalPlan::EmptyRelation(_));
assert!(result, "Should parse empty relation");
}

#[test]
fn udf_roundtrip_with_registry() {
let ctx = context_with_udf();
Expand Down
68 changes: 42 additions & 26 deletions datafusion/proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ use datafusion_common::DataFusionError;
#[allow(clippy::all)]
pub mod protobuf {
include!(concat!(env!("OUT_DIR"), "/datafusion.rs"));

#[cfg(feature = "json")]
include!(concat!(env!("OUT_DIR"), "/datafusion.serde.rs"));
}

pub mod bytes;
Expand Down Expand Up @@ -75,19 +78,32 @@ mod roundtrip_tests {
use std::fmt::Formatter;
use std::sync::Arc;

#[cfg(feature = "json")]
fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
let string = serde_json::to_string(proto).unwrap();
let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap();
assert_eq!(proto, &back);
}

#[cfg(not(feature = "json"))]
fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {}

// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test
// equality.
macro_rules! roundtrip_expr_test {
($initial_struct:ident, $ctx:ident) => {
let proto: protobuf::LogicalExprNode = (&$initial_struct).try_into().unwrap();
fn roundtrip_expr_test<T, E>(initial_struct: T, ctx: SessionContext)
where
for<'a> &'a T: TryInto<protobuf::LogicalExprNode, Error = E> + Debug,
E: Debug,
{
let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap();
let round_trip: Expr = parse_expr(&proto, &ctx).unwrap();

let round_trip: Expr = parse_expr(&proto, &$ctx).unwrap();
assert_eq!(
format!("{:?}", &initial_struct),
format!("{:?}", round_trip)
);

assert_eq!(
format!("{:?}", $initial_struct),
format!("{:?}", round_trip)
);
};
roundtrip_json_test(&proto);
}

fn new_box_field(name: &str, dt: DataType, nullable: bool) -> Box<Field> {
Expand Down Expand Up @@ -807,23 +823,23 @@ mod roundtrip_tests {
let test_expr = Expr::Not(Box::new(lit(1.0_f32)));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_is_null() {
let test_expr = Expr::IsNull(Box::new(col("id")));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_is_not_null() {
let test_expr = Expr::IsNotNull(Box::new(col("id")));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -836,7 +852,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -848,7 +864,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -859,7 +875,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -871,15 +887,15 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_negative() {
let test_expr = Expr::Negative(Box::new(lit(1.0_f32)));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -891,15 +907,15 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_wildcard() {
let test_expr = Expr::Wildcard;

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -909,7 +925,7 @@ mod roundtrip_tests {
args: vec![col("col")],
};
let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -921,7 +937,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand Down Expand Up @@ -975,7 +991,7 @@ mod roundtrip_tests {
let mut ctx = SessionContext::new();
ctx.register_udaf(dummy_agg);

roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -1000,7 +1016,7 @@ mod roundtrip_tests {
let mut ctx = SessionContext::new();
ctx.register_udf(udf);

roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -1012,22 +1028,22 @@ mod roundtrip_tests {
]));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_rollup() {
let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_cube() {
let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}
}

0 comments on commit c67161b

Please sign in to comment.