diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 6133c239873b..10c1319b903b 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -32,151 +32,22 @@ mod tests { use std::io::BufReader; use substrait::proto::Plan; - async fn register_csv( - ctx: &SessionContext, - table_name: &str, - file_path: &str, - ) -> Result<()> { - ctx.register_csv(table_name, file_path, CsvReadOptions::default()) - .await - } - - async fn create_context_tpch1() -> Result { - let ctx = SessionContext::new(); - register_csv( - &ctx, - "FILENAME_PLACEHOLDER_0", - "tests/testdata/tpch/lineitem.csv", - ) - .await?; - Ok(ctx) - } - - async fn create_context_tpch2() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), - ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"), - ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"), - ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"), - ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch3() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch4() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch5() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"), - ("NATION", "tests/testdata/tpch/nation.csv"), - ("REGION", "tests/testdata/tpch/region.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch6() -> Result { - let ctx = SessionContext::new(); - - let registrations = - vec![("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/lineitem.csv")]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - // missing context for query 7,8,9 - - async fn create_context_tpch10() -> Result { + async fn create_context(files: Vec<(&str, &str)>) -> Result { let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; + for (table_name, file_path) in files { + ctx.register_csv(table_name, file_path, CsvReadOptions::default()) + .await?; } - - Ok(ctx) - } - - async fn create_context_tpch11() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"), - ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - Ok(ctx) } #[tokio::test] async fn tpch_test_1() -> Result<()> { - let ctx = create_context_tpch1().await?; + let ctx = create_context(vec![( + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + )]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_1.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -200,7 +71,18 @@ mod tests { #[tokio::test] async fn tpch_test_2() -> Result<()> { - let ctx = create_context_tpch2().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"), + ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_2.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -242,7 +124,12 @@ mod tests { #[tokio::test] async fn tpch_test_3() -> Result<()> { - let ctx = create_context_tpch3().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_3.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -267,7 +154,11 @@ mod tests { #[tokio::test] async fn tpch_test_4() -> Result<()> { - let ctx = create_context_tpch4().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_4.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -289,7 +180,15 @@ mod tests { #[tokio::test] async fn tpch_test_5() -> Result<()> { - let ctx = create_context_tpch5().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"), + ("NATION", "tests/testdata/tpch/nation.csv"), + ("REGION", "tests/testdata/tpch/region.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_5.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -319,7 +218,11 @@ mod tests { #[tokio::test] async fn tpch_test_6() -> Result<()> { - let ctx = create_context_tpch6().await?; + let ctx = create_context(vec![( + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + )]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_6.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -338,7 +241,13 @@ mod tests { // TODO: missing plan 7, 8, 9 #[tokio::test] async fn tpch_test_10() -> Result<()> { - let ctx = create_context_tpch10().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_10.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -365,7 +274,15 @@ mod tests { #[tokio::test] async fn tpch_test_11() -> Result<()> { - let ctx = create_context_tpch11().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_11.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"),