From 99063ca33cda583a0a97767c93708fb1d6e717a1 Mon Sep 17 00:00:00 2001 From: advancedxy Date: Wed, 12 Jun 2024 00:43:22 +0900 Subject: [PATCH 01/14] chore: Reuse DFSSchema::datatype_is_logically_equal method (#10867) --- datafusion/common/src/dfschema.rs | 2 +- .../physical-expr/src/expressions/in_list.rs | 18 ++++-------------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 3686af90db17..0dab13d08731 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -666,7 +666,7 @@ impl DFSchema { /// than datatype_is_semantically_equal in that a Dictionary type is logically /// equal to a plain V type, but not semantically equal. Dictionary is also /// logically equal to Dictionary. - fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { + pub fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { // check nested fields match (dt1, dt2) { (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index a36ec9c8ebdc..53c790ff6b54 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -38,7 +38,9 @@ use datafusion_common::cast::{ as_boolean_array, as_generic_binary_array, as_string_array, }; use datafusion_common::hash_utils::HashValue; -use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue, +}; use datafusion_expr::ColumnarValue; use ahash::RandomState; @@ -416,18 +418,6 @@ impl PartialEq for InListExpr { } } -/// Checks if two types are logically equal, dictionary types are compared by their value types. -fn is_logically_eq(lhs: &DataType, rhs: &DataType) -> bool { - match (lhs, rhs) { - (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref().eq(v2.as_ref()) - } - (DataType::Dictionary(_, l), _) => l.as_ref().eq(rhs), - (_, DataType::Dictionary(_, r)) => lhs.eq(r.as_ref()), - _ => lhs.eq(rhs), - } -} - /// Creates a unary expression InList pub fn in_list( expr: Arc, @@ -439,7 +429,7 @@ pub fn in_list( let expr_data_type = expr.data_type(schema)?; for list_expr in list.iter() { let list_expr_data_type = list_expr.data_type(schema)?; - if !is_logically_eq(&expr_data_type, &list_expr_data_type) { + if !DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type) { return internal_err!( "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" ); From f554c9fdf10176b905cd3214765ca930a3ede261 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 11:45:23 -0400 Subject: [PATCH 02/14] Bump braces in /datafusion/wasmtest/datafusion-wasm-app (#10865) Bumps [braces](https://github.com/micromatch/braces) from 3.0.2 to 3.0.3. - [Changelog](https://github.com/micromatch/braces/blob/master/CHANGELOG.md) - [Commits](https://github.com/micromatch/braces/compare/3.0.2...3.0.3) --- updated-dependencies: - dependency-name: braces dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../datafusion-wasm-app/package-lock.json | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 7d324d074c9d..4cfe0b5a0cd2 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -804,12 +804,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -1632,9 +1632,9 @@ } }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "dependencies": { "to-regex-range": "^5.0.1" @@ -5041,12 +5041,12 @@ } }, "braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "requires": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" } }, "browserslist": { @@ -5655,9 +5655,9 @@ } }, "fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "requires": { "to-regex-range": "^5.0.1" From d84d75a23eb4fa8373fba8184844cc0f8f9103a8 Mon Sep 17 00:00:00 2001 From: Kirill Khramkov Date: Tue, 11 Jun 2024 22:34:51 +0400 Subject: [PATCH 03/14] Docs: Add `unnest` to SQL Reference (#10839) * Add unnest to SQL Reference * Add unnest docs for struct, add additional example for unnest array * unnest -> unnest (struct) * prettier --------- Co-authored-by: Andrew Lamb --- .../source/user-guide/sql/scalar_functions.md | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 625e0d95b57e..10c52bc5de9e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2085,6 +2085,7 @@ to_unixtime(expression[, ..., format_n]) - [string_to_array](#string_to_array) - [string_to_list](#string_to_list) - [trim_array](#trim_array) +- [unnest](#unnest) - [range](#range) ### `array_append` @@ -3346,6 +3347,48 @@ trim_array(array, n) Can be a constant, column, or function, and any combination of array operators. - **n**: Element to trim the array. +### `unnest` + +Transforms an array into rows. + +#### Arguments + +- **array**: Array expression to unnest. + Can be a constant, column, or function, and any combination of array operators. + +#### Examples + +``` +> select unnest(make_array(1, 2, 3, 4, 5)); ++------------------------------------------------------------------+ +| unnest(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5))) | ++------------------------------------------------------------------+ +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | ++------------------------------------------------------------------+ +``` + +``` +> select unnest(range(0, 10)); ++-----------------------------------+ +| unnest(range(Int64(0),Int64(10))) | ++-----------------------------------+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | ++-----------------------------------+ +``` + ### `range` Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or `SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);` @@ -3395,6 +3438,7 @@ are not allowed - [struct](#struct) - [named_struct](#named_struct) +- [unnest](#unnest-struct) ### `struct` @@ -3480,6 +3524,33 @@ select named_struct('field_a', a, 'field_b', b) from t; Can be a constant, column, or function, and any combination of arithmetic or string operators. +### `unnest (struct)` + +Unwraps struct fields into columns. + +#### Arguments + +- **struct**: Object expression to unnest. + Can be a constant, column, or function, and any combination of object operators. + +#### Examples + +``` +> select * from foo; ++---------------------+ +| column1 | ++---------------------+ +| {a: 5, b: a string} | ++---------------------+ + +> select unnest(column1) from foo; ++-----------------------+-----------------------+ +| unnest(foo.column1).a | unnest(foo.column1).b | ++-----------------------+-----------------------+ +| 5 | a string | ++-----------------------+-----------------------+ +``` + ## Hashing Functions - [digest](#digest) From 1b89da4455f68c3199c56a7c4a4298ce3120a714 Mon Sep 17 00:00:00 2001 From: Arttu Date: Tue, 11 Jun 2024 20:35:34 +0200 Subject: [PATCH 04/14] Support correct output column names and struct field names when consuming/producing Substrait (#10829) * produce flattened list of names including inner struct fields * add a (failing) test * rename output columns (incl. inner struct fields) according to the given list of names * fix a test * add column names project to the new TPC-H test and fix case (assert_eq gives nicer error messages than assert) --- .../substrait/src/logical_plan/consumer.rs | 133 +++++++++++++++++- .../substrait/src/logical_plan/producer.rs | 2 +- .../tests/cases/consumer_integration.rs | 17 +-- .../substrait/tests/cases/logical_plans.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 43 +++--- 5 files changed, 156 insertions(+), 41 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 8a483db8c4d6..648a281832e1 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,7 +17,7 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{ - DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, + DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; use datafusion::common::{ not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, @@ -29,12 +29,13 @@ use url::Url; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr, - LogicalPlan, Operator, ScalarUDF, Values, + aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, + EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF, + Values, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion::prelude::JoinType; @@ -225,6 +226,7 @@ pub async fn from_substrait_plan( None => not_impl_err!("Cannot parse empty extension"), }) .collect::>>()?; + // Parse relations match plan.relations.len() { 1 => { @@ -234,7 +236,29 @@ pub async fn from_substrait_plan( Ok(from_substrait_rel(ctx, rel, &function_extension).await?) }, plan_rel::RelType::Root(root) => { - Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?) + let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?; + if root.names.is_empty() { + // Backwards compatibility for plans missing names + return Ok(plan); + } + let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; + if renamed_schema.equivalent_names_and_types(plan.schema()) { + // Nothing to do if the schema is already equivalent + return Ok(plan); + } + + match plan { + // If the last node of the plan produces expressions, bake the renames into those expressions. + // This isn't necessary for correctness, but helps with roundtrip tests. + LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)), + LogicalPlan::Aggregate(a) => { + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?; + Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?)) + }, + // There are probably more plans where we could bake things in, can add them later as needed. + // Otherwise, add a new Project to handle the renaming. + _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?)) + } } }, None => plan_err!("Cannot parse plan relation: None") @@ -284,6 +308,105 @@ pub fn extract_projection( } } +fn rename_expressions( + exprs: impl IntoIterator, + input_schema: &DFSchema, + new_schema: DFSchemaRef, +) -> Result> { + exprs + .into_iter() + .zip(new_schema.fields()) + .map(|(old_expr, new_field)| { + if &old_expr.get_type(input_schema)? == new_field.data_type() { + // Alias column if needed + old_expr.alias_if_changed(new_field.name().into()) + } else { + // Use Cast to rename inner struct fields + alias column if needed + Expr::Cast(Cast::new( + Box::new(old_expr), + new_field.data_type().to_owned(), + )) + .alias_if_changed(new_field.name().into()) + } + }) + .collect() +} + +fn make_renamed_schema( + schema: &DFSchemaRef, + dfs_names: &Vec, +) -> Result { + fn rename_inner_fields( + dtype: &DataType, + dfs_names: &Vec, + name_idx: &mut usize, + ) -> Result { + match dtype { + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|f| { + let name = next_struct_field_name(0, dfs_names, name_idx)?; + Ok((**f).to_owned().with_name(name).with_data_type( + rename_inner_fields(f.data_type(), dfs_names, name_idx)?, + )) + }) + .collect::>()?; + Ok(DataType::Struct(fields)) + } + DataType::List(inner) => Ok(DataType::List(FieldRef::new( + (**inner).to_owned().with_data_type(rename_inner_fields( + inner.data_type(), + dfs_names, + name_idx, + )?), + ))), + DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new( + (**inner).to_owned().with_data_type(rename_inner_fields( + inner.data_type(), + dfs_names, + name_idx, + )?), + ))), + _ => Ok(dtype.to_owned()), + } + } + + let mut name_idx = 0; + + let (qualifiers, fields): (_, Vec) = schema + .iter() + .map(|(q, f)| { + let name = next_struct_field_name(0, dfs_names, &mut name_idx)?; + Ok(( + q.cloned(), + (**f) + .to_owned() + .with_name(name) + .with_data_type(rename_inner_fields( + f.data_type(), + dfs_names, + &mut name_idx, + )?), + )) + }) + .collect::>>()? + .into_iter() + .unzip(); + + if name_idx != dfs_names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + dfs_names.len()); + } + + Ok(Arc::new(DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(Schema::new(fields)), + )?)) +} + /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] pub async fn from_substrait_rel( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6c8be4aa9b12..88dc894eccd2 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -115,7 +115,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result<()> { roundtrip("SELECT * FROM data").await } +#[tokio::test] +async fn select_with_alias() -> Result<()> { + roundtrip("SELECT a AS aliased_a FROM data").await +} + #[tokio::test] async fn select_with_filter() -> Result<()> { roundtrip("SELECT * FROM data WHERE a > 1").await @@ -367,9 +372,9 @@ async fn implicit_cast() -> Result<()> { async fn aggregate_case() -> Result<()> { assert_expected_plan( "SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", - "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\ + "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\ \n TableScan: data projection=[a]", - false // NULL vs Int64(NULL) + true ) .await } @@ -589,32 +594,23 @@ async fn roundtrip_union_all() -> Result<()> { #[tokio::test] async fn simple_intersect() -> Result<()> { + // Substrait treats both COUNT(*) and COUNT(1) the same assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ \n Projection: \ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ \n TableScan: data2 projection=[a]", - false // COUNT(*) vs COUNT(Int64(1)) + true ) .await } #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { - assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Projection: \ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", - false // COUNT(*) vs COUNT(Int64(1)) - ) - .await + roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await } #[tokio::test] @@ -694,20 +690,14 @@ async fn all_type_literal() -> Result<()> { #[tokio::test] async fn roundtrip_literal_list() -> Result<()> { - assert_expected_plan( - "SELECT [[1,2,3], [], NULL, [NULL]] FROM data", - "Projection: List([[1, 2, 3], [], , []])\ - \n TableScan: data projection=[]", - false, // "List(..)" vs "make_array(..)" - ) - .await + roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await } #[tokio::test] async fn roundtrip_literal_struct() -> Result<()> { assert_expected_plan( "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data", - "Projection: Struct({c0:1,c1:true,c2:})\ + "Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\ \n TableScan: data projection=[]", false, // "Struct(..)" vs "struct(..)" ) @@ -980,12 +970,13 @@ async fn assert_expected_plan( println!("{proto:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(expected_plan_str, &plan2str); - if assert_schema { assert_eq!(plan.schema(), plan2.schema()); } + + let plan2str = format!("{plan2:?}"); + assert_eq!(expected_plan_str, &plan2str); + Ok(()) } From 0ec292f45404359356ab9125b1b0f5b21a135ab8 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen <83442793+MohamedAbdeen21@users.noreply.github.com> Date: Tue, 11 Jun 2024 21:37:10 +0300 Subject: [PATCH 05/14] Make Logical Plans more readable by removing extra aliases (#10832) * logical plan: remove unnecessary aliases * revert EnterMark * fix docs and benchmarks * revert id_array change * add alias counter * fix alias counter bug * fix slt test * fix benchmark results * revert alias/unalias changes * remove TODO * minor fix * fix benchmark --- .../optimizer/src/common_subexpr_eliminate.rs | 73 ++++++++++++++----- .../sqllogictest/test_files/group_by.slt | 18 ++--- .../sqllogictest/test_files/tpch/q1.slt.part | 4 +- 3 files changed, 65 insertions(+), 30 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6820ba04f0e9..3ed1309f1544 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -128,7 +128,7 @@ impl CommonSubexprEliminate { fn rewrite_exprs_list( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], + arrays_list: &[&[IdArray]], expr_stats: &ExprStats, common_exprs: &mut CommonExprs, ) -> Result>> { @@ -159,7 +159,7 @@ impl CommonSubexprEliminate { fn rewrite_expr( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], + arrays_list: &[&[IdArray]], input: &LogicalPlan, expr_stats: &ExprStats, config: &dyn OptimizerConfig, @@ -480,7 +480,7 @@ fn to_arrays( input_schema: DFSchemaRef, expr_stats: &mut ExprStats, expr_mask: ExprMask, -) -> Result>> { +) -> Result> { expr.iter() .map(|e| { let mut id_array = vec![]; @@ -739,7 +739,7 @@ fn expr_identifier(expr: &Expr, sub_expr_identifier: Identifier) -> Identifier { fn expr_to_identifier( expr: &Expr, expr_stats: &mut ExprStats, - id_array: &mut Vec<(usize, Identifier)>, + id_array: &mut IdArray, input_schema: DFSchemaRef, expr_mask: ExprMask, ) -> Result<()> { @@ -769,15 +769,28 @@ struct CommonSubexprRewriter<'a> { common_exprs: &'a mut CommonExprs, // preorder index, starts from 0. down_index: usize, + // how many aliases have we seen so far + alias_counter: usize, } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type Node = Expr; + fn f_up(&mut self, expr: Expr) -> Result> { + if matches!(expr, Expr::Alias(_)) { + self.alias_counter -= 1 + } + Ok(Transformed::no(expr)) + } + fn f_down(&mut self, expr: Expr) -> Result> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. + if matches!(expr, Expr::Alias(_)) { + self.alias_counter += 1; + } + if expr.short_circuits() || expr.is_volatile()? { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } @@ -801,15 +814,16 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { let expr_name = expr.display_name()?; self.common_exprs.insert(expr_id.clone(), expr); - // Alias this `Column` expr to it original "expr name", - // `projection_push_down` optimizer use "expr name" to eliminate useless - // projections. - // TODO: do we really need to alias here? - Ok(Transformed::new( - col(expr_id).alias(expr_name), - true, - TreeNodeRecursion::Jump, - )) + + // alias the expressions without an `Alias` ancestor node + let rewritten = if self.alias_counter > 0 { + col(expr_id) + } else { + self.alias_counter += 1; + col(expr_id).alias(expr_name) + }; + + Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)) } else { Ok(Transformed::no(expr)) } @@ -829,6 +843,7 @@ fn replace_common_expr( id_array, common_exprs, down_index: 0, + alias_counter: 0, }) .data() } @@ -962,6 +977,26 @@ mod test { Ok(()) } + #[test] + fn nested_aliases() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")), + col("a") + col("b"), + ])? + .build()?; + + let expected = "Projection: {test.a + test.b|{test.b}|{test.a}} - test.c AS alias1 * {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b, {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b\ + \n Projection: test.a + test.b AS {test.a + test.b|{test.b}|{test.a}}, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, &plan); + + Ok(()) + } + #[test] fn aggregate() -> Result<()> { let table_scan = test_table_scan()?; @@ -1006,7 +1041,7 @@ mod test { )? .build()?; - let expected = "Projection: {AVG(test.a)|{test.a}} AS AVG(test.a) AS col1, {AVG(test.a)|{test.a}} AS AVG(test.a) AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col4, {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\ + let expected = "Projection: {AVG(test.a)|{test.a}} AS col1, {AVG(test.a)|{test.a}} AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS col4, {my_agg(test.a)|{test.a}} AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\ \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, my_agg(test.c) AS {my_agg(test.c)}]]\ \n TableScan: test"; @@ -1042,7 +1077,7 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test"; + let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test"; assert_optimized_plan_eq(expected, &plan); @@ -1057,7 +1092,7 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\ + let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\ \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; @@ -1078,8 +1113,8 @@ mod test { )? .build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}]]\ + let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS my_agg(UInt32(1) + test.a)\ + \n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}]]\ \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; @@ -1126,7 +1161,7 @@ mod test { ])? .build()?; - let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS second\ + let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS second\ \n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 24a301d4a700..9e8a2450e0a5 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4538,7 +4538,7 @@ CREATE EXTERNAL TABLE timestamp_table ( c2 INT, ) STORED AS CSV -LOCATION 'test_files/scratch/group_by/timestamp_table' +LOCATION 'test_files/scratch/group_by/timestamp_table' OPTIONS ('format.has_header' 'true'); # Group By using date_trunc @@ -4611,7 +4611,7 @@ DROP TABLE timestamp_table; # Table with an int column and Dict column: statement ok -CREATE TABLE int8_dict AS VALUES +CREATE TABLE int8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int8, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int8, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int8, Utf8)')), @@ -4649,7 +4649,7 @@ DROP TABLE int8_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int16_dict AS VALUES +CREATE TABLE int16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int16, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int16, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int16, Utf8)')), @@ -4687,7 +4687,7 @@ DROP TABLE int16_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int32_dict AS VALUES +CREATE TABLE int32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int32, Utf8)')), @@ -4725,7 +4725,7 @@ DROP TABLE int32_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int64_dict AS VALUES +CREATE TABLE int64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int64, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int64, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int64, Utf8)')), @@ -4763,7 +4763,7 @@ DROP TABLE int64_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint8_dict AS VALUES +CREATE TABLE uint8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), @@ -4801,7 +4801,7 @@ DROP TABLE uint8_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint16_dict AS VALUES +CREATE TABLE uint16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), @@ -4839,7 +4839,7 @@ DROP TABLE uint16_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint32_dict AS VALUES +CREATE TABLE uint32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), @@ -4877,7 +4877,7 @@ DROP TABLE uint32_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint64_dict AS VALUES +CREATE TABLE uint64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), diff --git a/datafusion/sqllogictest/test_files/tpch/q1.slt.part b/datafusion/sqllogictest/test_files/tpch/q1.slt.part index 0583c6ef07a7..5e0930b99228 100644 --- a/datafusion/sqllogictest/test_files/tpch/q1.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q1.slt.part @@ -42,7 +42,7 @@ explain select logical_plan 01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST 02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order -03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]] +03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]] 04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS {lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus 05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02") 06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")] @@ -80,7 +80,7 @@ group by l_linestatus order by l_returnflag, - l_linestatus; + l_linestatus; ---- A F 3774200 5320753880.69 5054096266.6828 5256751331.449234 25.537587 36002.123829 0.050144 147790 N F 95257 133737795.84 127132372.6512 132286291.229445 25.300664 35521.326916 0.049394 3765 From c50f0dc6ef602bd7780bdfd18ef2905e8659ee96 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 11 Jun 2024 16:16:33 -0400 Subject: [PATCH 06/14] Minor: Improve `ListingTable` documentation (#10854) * Minor: Improve ListingTable documentation * Update datafusion/core/src/datasource/listing/table.rs Co-authored-by: Oleks V --------- Co-authored-by: Oleks V --- .../core/src/datasource/listing/table.rs | 49 +++++++++++++++---- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 746e4b8e3330..7f5e80c4988a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -547,20 +547,49 @@ impl ListingOptions { } } -/// Reads data from one or more files via an -/// [`ObjectStore`]. For example, from -/// local files or objects from AWS S3. Implements [`TableProvider`], -/// a DataFusion data source. +/// Reads data from one or more files as a single table. /// -/// # Features +/// Implements [`TableProvider`], a DataFusion data source. The files are read +/// using an [`ObjectStore`] instance, for example from local files or objects +/// from AWS S3. /// -/// 1. Merges schemas if the files have compatible but not identical schemas +/// For example, given the `table1` directory (or object store prefix) /// -/// 2. Hive-style partitioning support, where a path such as -/// `/files/date=1/1/2022/data.parquet` is injected as a `date` column. +/// ```text +/// table1 +/// ├── file1.parquet +/// └── file2.parquet +/// ``` +/// +/// A `ListingTable` would read the files `file1.parquet` and `file2.parquet` as +/// a single table, merging the schemas if the files have compatible but not +/// identical schemas. +/// +/// Given the `table2` directory (or object store prefix) +/// +/// ```text +/// table2 +/// ├── date=2024-06-01 +/// │ ├── file3.parquet +/// │ └── file4.parquet +/// └── date=2024-06-02 +/// └── file5.parquet +/// ``` +/// +/// A `ListingTable` would read the files `file3.parquet`, `file4.parquet`, and +/// `file5.parquet` as a single table, again merging schemas if necessary. +/// +/// Given the hive style partitioning structure (e.g,. directories named +/// `date=2024-06-01` and `date=2026-06-02`), `ListingTable` also adds a `date` +/// column when reading the table: +/// * The files in `table2/date=2024-06-01` will have the value `2024-06-01` +/// * The files in `table2/date=2024-06-02` will have the value `2024-06-02`. +/// +/// If the query has a predicate like `WHERE date = '2024-06-01'` +/// only the corresponding directory will be read. /// -/// 3. Projection pushdown for formats that support it such as such as -/// Parquet +/// `ListingTable` also supports filter and projection pushdown for formats that +/// support it as such as Parquet. /// /// # Example /// From 97ea05c0f60aa11a270420968d3fefc859c0d346 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Tue, 11 Jun 2024 22:19:55 -0400 Subject: [PATCH 07/14] Extending join fuzz tests to support join filtering (#10728) * Extending join fuzz tests to support join filtering --------- Co-authored-by: Oleks V --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 407 +++++++++++++----- 1 file changed, 296 insertions(+), 111 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 824f1eec4a85..8c2e24de56b9 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -22,6 +22,11 @@ use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_schema::Schema; + +use datafusion_common::ScalarValue; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::PhysicalExprRef; + use rand::Rng; use datafusion::common::JoinSide; @@ -40,92 +45,207 @@ use test_utils::stagger_batch_with_seed; #[tokio::test] async fn test_inner_join_1k() { - run_join_test( + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Inner, + None, + ) + .run_test() + .await +} + +fn less_than_10_join_filter(schema1: Arc, _schema2: Arc) -> JoinFilter { + let less_than_100 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::from(100))), + )) as _; + let column_indices = vec![ColumnIndex { + index: 0, + side: JoinSide::Left, + }]; + let intermediate_schema = + Schema::new(vec![schema1.field_with_name("a").unwrap().to_owned()]); + + JoinFilter::new(less_than_100, column_indices, intermediate_schema) +} + +#[tokio::test] +async fn test_inner_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Inner, + Some(Box::new(less_than_10_join_filter)), + ) + .run_test() + .await +} + +#[tokio::test] +async fn test_inner_join_1k_smjoin() { + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Inner, + None, ) + .run_test() .await } #[tokio::test] async fn test_left_join_1k() { - run_join_test( + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Left, + None, + ) + .run_test() + .await +} + +#[tokio::test] +async fn test_left_join_1k_filtered() { + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Left, + Some(Box::new(less_than_10_join_filter)), ) + .run_test() .await } #[tokio::test] async fn test_right_join_1k() { - run_join_test( + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Right, + None, + ) + .run_test() + .await +} +// Add support for Right filtered joins +#[ignore] +#[tokio::test] +async fn test_right_join_1k_filtered() { + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Right, + Some(Box::new(less_than_10_join_filter)), ) + .run_test() .await } #[tokio::test] async fn test_full_join_1k() { - run_join_test( + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Full, + None, ) + .run_test() + .await +} + +#[tokio::test] +async fn test_full_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Full, + Some(Box::new(less_than_10_join_filter)), + ) + .run_test() .await } #[tokio::test] async fn test_semi_join_1k() { - run_join_test( + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftSemi, + None, + ) + .run_test() + .await +} + +#[tokio::test] +async fn test_semi_join_1k_filtered() { + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftSemi, + Some(Box::new(less_than_10_join_filter)), ) + .run_test() .await } #[tokio::test] async fn test_anti_join_1k() { - run_join_test( + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftAnti, + None, + ) + .run_test() + .await +} + +// Test failed for now. https://github.com/apache/datafusion/issues/10872 +#[ignore] +#[tokio::test] +async fn test_anti_join_1k_filtered() { + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, + Some(Box::new(less_than_10_join_filter)), ) + .run_test() .await } -/// Perform sort-merge join and hash join on same input -/// and verify two outputs are equal -async fn run_join_test( +type JoinFilterBuilder = Box, Arc) -> JoinFilter>; + +struct JoinFuzzTestCase { + batch_sizes: &'static [usize], input1: Vec, input2: Vec, join_type: JoinType, -) { - let batch_sizes = [1, 2, 7, 49, 50, 51, 100]; - for batch_size in batch_sizes { - let session_config = SessionConfig::new().with_batch_size(batch_size); - let ctx = SessionContext::new_with_config(session_config); - let task_ctx = ctx.task_ctx(); - - let schema1 = input1[0].schema(); - let schema2 = input2[0].schema(); - let on_columns = vec![ - ( - Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _, - Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _, - ), - ( - Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _, - Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _, - ), - ]; + join_filter_builder: Option, +} - // Nested loop join uses filter for joining records - let column_indices = vec![ +impl JoinFuzzTestCase { + fn new( + input1: Vec, + input2: Vec, + join_type: JoinType, + join_filter_builder: Option, + ) -> Self { + Self { + batch_sizes: &[1, 2, 7, 49, 50, 51, 100], + input1, + input2, + join_type, + join_filter_builder, + } + } + + fn column_indices(&self) -> Vec { + vec![ ColumnIndex { index: 0, side: JoinSide::Left, @@ -142,120 +262,185 @@ async fn run_join_test( index: 1, side: JoinSide::Right, }, - ]; - let intermediate_schema = Schema::new(vec![ - schema1.field_with_name("a").unwrap().to_owned(), - schema1.field_with_name("b").unwrap().to_owned(), - schema2.field_with_name("a").unwrap().to_owned(), - schema2.field_with_name("b").unwrap().to_owned(), - ]); + ] + } - let equal_a = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Eq, - Arc::new(Column::new("a", 2)), - )) as _; - let equal_b = Arc::new(BinaryExpr::new( - Arc::new(Column::new("b", 1)), - Operator::Eq, - Arc::new(Column::new("b", 3)), - )) as _; - let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _; + fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); + vec![ + ( + Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _, + Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _, + ), + ( + Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _, + ), + ] + } - let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema); + fn intermediate_schema(&self) -> Schema { + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); + Schema::new(vec![ + schema1 + .field_with_name("a") + .unwrap() + .to_owned() + .with_nullable(true), + schema1 + .field_with_name("b") + .unwrap() + .to_owned() + .with_nullable(true), + schema2.field_with_name("a").unwrap().to_owned(), + schema2.field_with_name("b").unwrap().to_owned(), + ]) + } - // sort-merge join + fn left_right(&self) -> (Arc, Arc) { + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); let left = Arc::new( - MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), + MemoryExec::try_new(&[self.input1.clone()], schema1.clone(), None).unwrap(), ); let right = Arc::new( - MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), + MemoryExec::try_new(&[self.input2.clone()], schema2.clone(), None).unwrap(), ); - let smj = Arc::new( + (left, right) + } + + fn join_filter(&self) -> Option { + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); + self.join_filter_builder + .as_ref() + .map(|builder| builder(schema1, schema2)) + } + + fn sort_merge_join(&self) -> Arc { + let (left, right) = self.left_right(); + Arc::new( SortMergeJoinExec::try_new( left, right, - on_columns.clone(), - None, - join_type, + self.on_columns().clone(), + self.join_filter(), + self.join_type, vec![SortOptions::default(), SortOptions::default()], false, ) .unwrap(), - ); - let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); + ) + } - // hash join - let left = Arc::new( - MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), - ); - let right = Arc::new( - MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), - ); - let hj = Arc::new( + fn hash_join(&self) -> Arc { + let (left, right) = self.left_right(); + Arc::new( HashJoinExec::try_new( left, right, - on_columns.clone(), - None, - &join_type, + self.on_columns().clone(), + self.join_filter(), + &self.join_type, None, PartitionMode::Partitioned, false, ) .unwrap(), - ); - let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + ) + } - // nested loop join - let left = Arc::new( - MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), - ); - let right = Arc::new( - MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), - ); - let nlj = Arc::new( - NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type) + fn nested_loop_join(&self) -> Arc { + let (left, right) = self.left_right(); + // Nested loop join uses filter for joining records + let column_indices = self.column_indices(); + let intermediate_schema = self.intermediate_schema(); + + let equal_a = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Column::new("a", 2)), + )) as _; + let equal_b = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Eq, + Arc::new(Column::new("b", 3)), + )) as _; + let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _; + + let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema); + + Arc::new( + NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type) .unwrap(), - ); - let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); + ) + } - // compare - let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string(); - let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string(); - let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string(); + /// Perform sort-merge join and hash join on same input + /// and verify two outputs are equal + async fn run_test(&self) { + for batch_size in self.batch_sizes { + let session_config = SessionConfig::new().with_batch_size(*batch_size); + let ctx = SessionContext::new_with_config(session_config); + let task_ctx = ctx.task_ctx(); + let smj = self.sort_merge_join(); + let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); - let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect(); - smj_formatted_sorted.sort_unstable(); + let hj = self.hash_join(); + let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); - let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect(); - hj_formatted_sorted.sort_unstable(); + let nlj = self.nested_loop_join(); + let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); - let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect(); - nlj_formatted_sorted.sort_unstable(); + // compare + let smj_formatted = + pretty_format_batches(&smj_collected).unwrap().to_string(); + let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string(); + let nlj_formatted = + pretty_format_batches(&nlj_collected).unwrap().to_string(); - for (i, (smj_line, hj_line)) in smj_formatted_sorted - .iter() - .zip(&hj_formatted_sorted) - .enumerate() - { - assert_eq!( - (i, smj_line), - (i, hj_line), - "SortMergeJoinExec and HashJoinExec produced different results" - ); - } + let mut smj_formatted_sorted: Vec<&str> = + smj_formatted.trim().lines().collect(); + smj_formatted_sorted.sort_unstable(); + + let mut hj_formatted_sorted: Vec<&str> = + hj_formatted.trim().lines().collect(); + hj_formatted_sorted.sort_unstable(); + + let mut nlj_formatted_sorted: Vec<&str> = + nlj_formatted.trim().lines().collect(); + nlj_formatted_sorted.sort_unstable(); - for (i, (nlj_line, hj_line)) in nlj_formatted_sorted - .iter() - .zip(&hj_formatted_sorted) - .enumerate() - { assert_eq!( - (i, nlj_line), - (i, hj_line), - "NestedLoopJoinExec and HashJoinExec produced different results" + smj_formatted_sorted.len(), + hj_formatted_sorted.len(), + "SortMergeJoinExec and HashJoinExec produced different row counts" ); + for (i, (smj_line, hj_line)) in smj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!( + (i, smj_line), + (i, hj_line), + "SortMergeJoinExec and HashJoinExec produced different results" + ); + } + + for (i, (nlj_line, hj_line)) in nlj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!( + (i, nlj_line), + (i, hj_line), + "NestedLoopJoinExec and HashJoinExec produced different results" + ); + } } } } From 9b3b80510e7fa56933a4f8a7a82c26f136a9e182 Mon Sep 17 00:00:00 2001 From: RT_Enzyme <58059931+RTEnzyme@users.noreply.github.com> Date: Wed, 12 Jun 2024 16:34:52 +0800 Subject: [PATCH 08/14] replace and(.., not(...)) with and_not(..) (#10885) Co-authored-by: velosearch --- datafusion/physical-expr/src/expressions/case.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index c56229e07a63..08d8cd441334 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -26,7 +26,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, is_null, not, nullif, or, prep_null_mask_filter}; +use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; @@ -168,7 +168,7 @@ impl CaseExpr { } }; - remainder = and(&remainder, ¬(&when_match)?)?; + remainder = and_not(&remainder, &when_match)?; } if let Some(e) = &self.else_expr { @@ -241,7 +241,7 @@ impl CaseExpr { // Succeed tuples should be filtered out for short-circuit evaluation, // null values for the current when expr should be kept - remainder = and(&remainder, ¬(&when_value)?)?; + remainder = and_not(&remainder, &when_value)?; } if let Some(e) = &self.else_expr { From 7f6fc07577f882d39db72e44ebabe0442a7bf016 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Wed, 12 Jun 2024 09:53:51 -0400 Subject: [PATCH 09/14] Disabling test for semi join with filters (#10887) --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 8c2e24de56b9..7dbbfb25bf78 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -179,7 +179,8 @@ async fn test_semi_join_1k() { .run_test() .await } - +// See https://github.com/apache/datafusion/issues/10886 +#[ignore] #[tokio::test] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( From 73381fe35738ef2f5a06e9f55626f08855e8a852 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 12 Jun 2024 11:17:14 -0400 Subject: [PATCH 10/14] Minor: Update `min_statistics` and `max_statistics` to be helpers, update docs (#10866) --- .../physical_plan/parquet/statistics.rs | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index a4a919f20d0f..c0d36f1fc4d7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`]. +//! [`StatisticsConverter`] to convert statistics in parquet format to arrow [`ArrayRef`]. // TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 @@ -542,8 +542,11 @@ pub(crate) fn parquet_column<'a>( Some((parquet_idx, field)) } -/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] -pub(crate) fn min_statistics<'a, I: Iterator>>( +/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an +/// [`ArrayRef`] +/// +/// This is an internal helper -- see [`StatisticsConverter`] for public API +fn min_statistics<'a, I: Iterator>>( data_type: &DataType, iterator: I, ) -> Result { @@ -551,7 +554,9 @@ pub(crate) fn min_statistics<'a, I: Iterator>>( +/// +/// This is an internal helper -- see [`StatisticsConverter`] for public API +fn max_statistics<'a, I: Iterator>>( data_type: &DataType, iterator: I, ) -> Result { @@ -1425,9 +1430,10 @@ mod test { assert_eq!(idx, 2); let row_groups = metadata.row_groups(); - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let converter = + StatisticsConverter::try_new("int_col", &schema, parquet_schema).unwrap(); - let min = min_statistics(&DataType::Int32, iter.clone()).unwrap(); + let min = converter.row_group_mins(row_groups.iter()).unwrap(); assert_eq!( &min, &expected_min, @@ -1435,7 +1441,7 @@ mod test { DisplayStats(row_groups) ); - let max = max_statistics(&DataType::Int32, iter).unwrap(); + let max = converter.row_group_maxes(row_groups.iter()).unwrap(); assert_eq!( &max, &expected_max, @@ -1623,22 +1629,23 @@ mod test { continue; } - let (idx, f) = - parquet_column(parquet_schema, &schema, field.name()).unwrap(); - assert_eq!(f, field); + let converter = + StatisticsConverter::try_new(field.name(), &schema, parquet_schema) + .unwrap(); - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); - let min = min_statistics(f.data_type(), iter.clone()).unwrap(); + assert_eq!(converter.arrow_field, field.as_ref()); + + let mins = converter.row_group_mins(row_groups.iter()).unwrap(); assert_eq!( - &min, + &mins, &expected_min, "Min. Statistics\n\n{}\n\n", DisplayStats(row_groups) ); - let max = max_statistics(f.data_type(), iter).unwrap(); + let maxes = converter.row_group_maxes(row_groups.iter()).unwrap(); assert_eq!( - &max, + &maxes, &expected_max, "Max. Statistics\n\n{}\n\n", DisplayStats(row_groups) @@ -1705,7 +1712,7 @@ mod test { self } - /// Reads the specified parquet file and validates that the exepcted min/max + /// Reads the specified parquet file and validates that the expected min/max /// values for the specified columns are as expected. fn run(self) { let path = PathBuf::from(parquet_test_data()).join(self.file_name); @@ -1723,14 +1730,13 @@ mod test { expected_max, } = expected_column; - let (idx, field) = - parquet_column(parquet_schema, arrow_schema, name).unwrap(); - - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); - let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap(); + let converter = + StatisticsConverter::try_new(name, arrow_schema, parquet_schema) + .unwrap(); + let actual_min = converter.row_group_mins(row_groups.iter()).unwrap(); assert_eq!(&expected_min, &actual_min, "column {name}"); - let actual_max = max_statistics(field.data_type(), iter).unwrap(); + let actual_max = converter.row_group_maxes(row_groups.iter()).unwrap(); assert_eq!(&expected_max, &actual_max, "column {name}"); } } From 87d826703bfe05df292649adf6c30b2528c83ab2 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Wed, 12 Jun 2024 18:34:22 +0200 Subject: [PATCH 11/14] chore: remove interval test (#10888) --- .../core/tests/parquet/arrow_statistics.rs | 81 +------------------ datafusion/core/tests/parquet/mod.rs | 80 +----------------- 2 files changed, 2 insertions(+), 159 deletions(-) diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 0e23e6824027..2ea18d7cf823 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -30,8 +30,7 @@ use arrow::datatypes::{ use arrow_array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, @@ -1061,84 +1060,6 @@ async fn test_dates_64_diff_rg_sizes() { .run(); } -#[tokio::test] -#[should_panic] -// Currently this test `should_panic` since statistics for `Intervals` -// are not supported and `IntervalMonthDayNano` cannot be written -// to parquet yet. -// Refer to issue: https://github.com/apache/arrow-rs/issues/5847 -// and https://github.com/apache/arrow-rs/blob/master/parquet/src/arrow/arrow_writer/mod.rs#L747 -async fn test_interval_diff_rg_sizes() { - // This creates a parquet files of 3 columns: - // "year_month" --> IntervalYearMonthArray - // "day_time" --> IntervalDayTimeArray - // "month_day_nano" --> IntervalMonthDayNanoArray - // - // The file is created by 4 record batches (each has a null row) - // each has 5 rows but then will be split into 2 row groups with size 13, 7 - let reader = TestReader { - scenario: Scenario::Interval, - row_per_group: 13, - } - .build() - .await; - - // TODO: expected values need to be changed once issue is resolved - // expected_min: Arc::new(IntervalYearMonthArray::from(vec![ - // IntervalYearMonthType::make_value(1, 10), - // IntervalYearMonthType::make_value(4, 13), - // ])), - // expected_max: Arc::new(IntervalYearMonthArray::from(vec![ - // IntervalYearMonthType::make_value(6, 51), - // IntervalYearMonthType::make_value(8, 53), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalYearMonthArray::from(vec![None, None])), - expected_max: Arc::new(IntervalYearMonthArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "year_month", - } - .run(); - - // expected_min: Arc::new(IntervalDayTimeArray::from(vec![ - // IntervalDayTimeType::make_value(1, 10), - // IntervalDayTimeType::make_value(4, 13), - // ])), - // expected_max: Arc::new(IntervalDayTimeArray::from(vec![ - // IntervalDayTimeType::make_value(6, 51), - // IntervalDayTimeType::make_value(8, 53), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalDayTimeArray::from(vec![None, None])), - expected_max: Arc::new(IntervalDayTimeArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "day_time", - } - .run(); - - // expected_min: Arc::new(IntervalMonthDayNanoArray::from(vec![ - // IntervalMonthDayNanoType::make_value(1, 10, 100), - // IntervalMonthDayNanoType::make_value(4, 13, 103), - // ])), - // expected_max: Arc::new(IntervalMonthDayNanoArray::from(vec![ - // IntervalMonthDayNanoType::make_value(6, 51, 501), - // IntervalMonthDayNanoType::make_value(8, 53, 503), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalMonthDayNanoArray::from(vec![None, None])), - expected_max: Arc::new(IntervalMonthDayNanoArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "month_day_nano", - } - .run(); -} - #[tokio::test] async fn test_uint() { // This creates a parquet files of 4 columns named "u8", "u16", "u32", "u64" diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 5ab268beb92f..9546ab30c9e0 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -18,9 +18,7 @@ //! Parquet integration tests use crate::parquet::utils::MetricsFinder; use arrow::array::Decimal128Array; -use arrow::datatypes::{ - i256, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, -}; +use arrow::datatypes::i256; use arrow::{ array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, @@ -36,10 +34,6 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; -use arrow_array::{ - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, -}; -use arrow_schema::IntervalUnit; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{provider_as_source, TableProvider}, @@ -92,7 +86,6 @@ enum Scenario { Time32Millisecond, Time64Nanosecond, Time64Microsecond, - Interval, /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64 /// -MIN, -100, -1, 0, 1, 100, MAX NumericLimits, @@ -921,71 +914,6 @@ fn make_dict_batch() -> RecordBatch { .unwrap() } -fn make_interval_batch(offset: i32) -> RecordBatch { - let schema = Schema::new(vec![ - Field::new( - "year_month", - DataType::Interval(IntervalUnit::YearMonth), - true, - ), - Field::new("day_time", DataType::Interval(IntervalUnit::DayTime), true), - Field::new( - "month_day_nano", - DataType::Interval(IntervalUnit::MonthDayNano), - true, - ), - ]); - let schema = Arc::new(schema); - - let ym_arr = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1 + offset, 10 + offset)), - Some(IntervalYearMonthType::make_value(2 + offset, 20 + offset)), - Some(IntervalYearMonthType::make_value(3 + offset, 30 + offset)), - None, - Some(IntervalYearMonthType::make_value(5 + offset, 50 + offset)), - ]); - - let dt_arr = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1 + offset, 10 + offset)), - Some(IntervalDayTimeType::make_value(2 + offset, 20 + offset)), - Some(IntervalDayTimeType::make_value(3 + offset, 30 + offset)), - None, - Some(IntervalDayTimeType::make_value(5 + offset, 50 + offset)), - ]); - - // Not yet implemented, refer to: - // https://github.com/apache/arrow-rs/blob/master/parquet/src/arrow/arrow_writer/mod.rs#L747 - let mdn_arr = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value( - 1 + offset, - 10 + offset, - 100 + (offset as i64), - )), - Some(IntervalMonthDayNanoType::make_value( - 2 + offset, - 20 + offset, - 200 + (offset as i64), - )), - Some(IntervalMonthDayNanoType::make_value( - 3 + offset, - 30 + offset, - 300 + (offset as i64), - )), - None, - Some(IntervalMonthDayNanoType::make_value( - 5 + offset, - 50 + offset, - 500 + (offset as i64), - )), - ]); - - RecordBatch::try_new( - schema, - vec![Arc::new(ym_arr), Arc::new(dt_arr), Arc::new(mdn_arr)], - ) - .unwrap() -} - fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Boolean => { @@ -1407,12 +1335,6 @@ fn create_data_batch(scenario: Scenario) -> Vec { ]), ] } - Scenario::Interval => vec![ - make_interval_batch(0), - make_interval_batch(1), - make_interval_batch(2), - make_interval_batch(3), - ], } } From dfdda7cb04f7f9b640da4f297ce1a16b08f3bf7b Mon Sep 17 00:00:00 2001 From: Arttu Date: Wed, 12 Jun 2024 18:40:40 +0200 Subject: [PATCH 12/14] fix: Ignore nullability of list elements when consuming Substrait (#10874) * Ignore nullability of list elements when consuming Substrait DataFusion (= Arrow) is quite strict about nullability, specifically, when using e.g. LogicalPlan::Values, the given schema must match the given literals exactly - including nullability. This is non-trivial to do when converting schema and literals separately. The existing implementation for from_substrait_literal already creates lists that are always nullable (see ScalarValue::new_list => array_into_list_array). This reverts part of https://github.com/apache/datafusion/pull/10640 to align from_substrait_type with that behavior. This is the error I was hitting: ``` ArrowError(InvalidArgumentError("column types must match schema types, expected List(Field { name: \"item\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }) but found List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) at column index 0"), None) ``` * use `Field::new_list_field` in `array_into_(large_)list_array` just for consistency, to reduce the places where "item" is written out * add a test for non-nullable lists --- datafusion/common/src/utils/mod.rs | 14 ++-- .../substrait/src/logical_plan/consumer.rs | 4 +- .../substrait/src/logical_plan/producer.rs | 14 ++-- .../substrait/tests/cases/logical_plans.rs | 32 +++++++-- .../non_nullable_lists.substrait.json | 71 +++++++++++++++++++ 5 files changed, 114 insertions(+), 21 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index ae444c2cb285..a0e4d1a76c03 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -354,7 +354,7 @@ pub fn longest_consecutive_prefix>( pub fn array_into_list_array(arr: ArrayRef) -> ListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); ListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), offsets, arr, None, @@ -366,7 +366,7 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray { pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); LargeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), offsets, arr, None, @@ -379,7 +379,7 @@ pub fn array_into_fixed_size_list_array( ) -> FixedSizeListArray { let list_size = list_size as i32; FixedSizeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), list_size, arr, None, @@ -420,7 +420,7 @@ pub fn arrays_into_list_array( let data_type = arr[0].data_type().to_owned(); let values = arr.iter().map(|x| x.as_ref()).collect::>(); Ok(ListArray::new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::from_lengths(lens), arrow::compute::concat(values.as_slice())?, None, @@ -435,7 +435,7 @@ pub fn arrays_into_list_array( /// use datafusion_common::utils::base_type; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// assert_eq!(base_type(&data_type), DataType::Int32); /// /// let data_type = DataType::Int32; @@ -458,10 +458,10 @@ pub fn base_type(data_type: &DataType) -> DataType { /// use datafusion_common::utils::coerced_type_with_base_type_only; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); -/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 648a281832e1..3f9a895d951c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1395,7 +1395,9 @@ fn from_substrait_type( })?; let field = Arc::new(Field::new_list_field( from_substrait_type(inner_type, dfs_names, name_idx)?, - is_substrait_type_nullable(inner_type)?, + // We ignore Substrait's nullability here to match to_substrait_literal + // which always creates nullable lists + true, )); match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 88dc894eccd2..c0469d333164 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2309,14 +2309,12 @@ mod test { round_trip_type(DataType::Decimal128(10, 2))?; round_trip_type(DataType::Decimal256(30, 2))?; - for nullable in [true, false] { - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, nullable).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, nullable).into(), - ))?; - } + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; round_trip_type(DataType::Struct( vec![ diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 994a932c30e0..94572e098b2c 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -20,6 +20,7 @@ #[cfg(test)] mod tests { use datafusion::common::Result; + use datafusion::dataframe::DataFrame; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; use std::fs::File; @@ -38,11 +39,7 @@ mod tests { // File generated with substrait-java's Isthmus: // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" - let path = "tests/testdata/select_not_bool.substrait.json"; - let proto = serde_json::from_reader::<_, Plan>(BufReader::new( - File::open(path).expect("file not found"), - )) - .expect("failed to parse json"); + let proto = read_json("tests/testdata/select_not_bool.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; @@ -54,6 +51,31 @@ mod tests { Ok(()) } + #[tokio::test] + async fn non_nullable_lists() -> Result<()> { + // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. + // That's because implementing the non-nullability consistently is non-trivial. + // This test confirms that reading a plan with non-nullable lists works as expected. + let ctx = create_context().await?; + let proto = read_json("tests/testdata/non_nullable_lists.substrait.json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + + assert_eq!(format!("{:?}", &plan), "Values: (List([1, 2]))"); + + // Need to trigger execution to ensure that Arrow has validated the plan + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } + + fn read_json(path: &str) -> Plan { + serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json") + } + async fn create_context() -> datafusion::common::Result { let ctx = SessionContext::new(); ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new()) diff --git a/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json new file mode 100644 index 000000000000..e1c5574f8bec --- /dev/null +++ b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json @@ -0,0 +1,71 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "col" + ], + "struct": { + "types": [ + { + "list": { + "type": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { + "fields": [ + { + "list": { + "values": [ + { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + }, + { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + ] + }, + "nullable": false, + "typeVariationReference": 0 + } + ] + } + ] + } + } + }, + "names": [ + "col" + ] + } + } + ], + "expectedTypeUrls": [] +} From 908a3a1d2feea1b1ae8c6220dcdb9e8264dd27ad Mon Sep 17 00:00:00 2001 From: Oleks V Date: Wed, 12 Jun 2024 14:46:50 -0700 Subject: [PATCH 13/14] Minor: SMJ fuzz tests fix for rowcounts (#10891) * Fix: Sort Merge Join crashes on TPCH Q21 * Fix LeftAnti SMJ join when the join filter is set * rm dbg * Minor: Fix fuzz testing row counts --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 7dbbfb25bf78..a893e780581f 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -55,7 +55,7 @@ async fn test_inner_join_1k() { .await } -fn less_than_10_join_filter(schema1: Arc, _schema2: Arc) -> JoinFilter { +fn less_than_100_join_filter(schema1: Arc, _schema2: Arc) -> JoinFilter { let less_than_100 = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Lt, @@ -77,7 +77,7 @@ async fn test_inner_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Inner, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -113,7 +113,7 @@ async fn test_left_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Left, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -138,7 +138,7 @@ async fn test_right_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Right, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -162,7 +162,7 @@ async fn test_full_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Full, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -179,15 +179,14 @@ async fn test_semi_join_1k() { .run_test() .await } -// See https://github.com/apache/datafusion/issues/10886 -#[ignore] + #[tokio::test] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftSemi, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -213,7 +212,7 @@ async fn test_anti_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -392,6 +391,15 @@ impl JoinFuzzTestCase { let hj = self.hash_join(); let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + // Get actual row counts(without formatting overhead) for HJ and SMJ + let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + + assert_eq!( + hj_rows, smj_rows, + "SortMergeJoinExec and HashJoinExec produced different row counts" + ); + let nlj = self.nested_loop_join(); let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); @@ -414,21 +422,20 @@ impl JoinFuzzTestCase { nlj_formatted.trim().lines().collect(); nlj_formatted_sorted.sort_unstable(); - assert_eq!( - smj_formatted_sorted.len(), - hj_formatted_sorted.len(), - "SortMergeJoinExec and HashJoinExec produced different row counts" - ); - for (i, (smj_line, hj_line)) in smj_formatted_sorted - .iter() - .zip(&hj_formatted_sorted) - .enumerate() - { - assert_eq!( - (i, smj_line), - (i, hj_line), - "SortMergeJoinExec and HashJoinExec produced different results" - ); + // row level compare if any of joins returns the result + // the reason is different formatting when there is no rows + if smj_rows > 0 || hj_rows > 0 { + for (i, (smj_line, hj_line)) in smj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!( + (i, smj_line), + (i, hj_line), + "SortMergeJoinExec and HashJoinExec produced different results" + ); + } } for (i, (nlj_line, hj_line)) in nlj_formatted_sorted From 8f718dd3ce291c9f5688144ca6c9d7d854dc4b0b Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 13 Jun 2024 07:54:39 +0800 Subject: [PATCH 14/14] Move `Count` to `functions-aggregate`, update MSRV to rust 1.75 (#10484) * mv accumulate indices Signed-off-by: jayzhan211 * complete udaf Signed-off-by: jayzhan211 * register Signed-off-by: jayzhan211 * fix expr Signed-off-by: jayzhan211 * filter distinct count Signed-off-by: jayzhan211 * todo: need to move count distinct too Signed-off-by: jayzhan211 * move code around Signed-off-by: jayzhan211 * move distinct to aggr-crate Signed-off-by: jayzhan211 * replace Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * fix function name and physical expr Signed-off-by: jayzhan211 * fix physical optimizer Signed-off-by: jayzhan211 * fix all slt Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix with args Signed-off-by: jayzhan211 * add label Signed-off-by: jayzhan211 * revert builtin related code back Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * fix substrait Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * fmy Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * fix udaf macro for distinct but not apply Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix count distinct and use workspace Signed-off-by: jayzhan211 * add reverse Signed-off-by: jayzhan211 * remove old code Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * use macro Signed-off-by: jayzhan211 * expr builder Signed-off-by: jayzhan211 * introduce expr builder Signed-off-by: jayzhan211 * add example Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * clean agg sta Signed-off-by: jayzhan211 * combine agg Signed-off-by: jayzhan211 * limit distinct and fmt Signed-off-by: jayzhan211 * cleanup name Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fix window Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix merged Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * fix rebase Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * use std Signed-off-by: jayzhan211 * update mrsv Signed-off-by: jayzhan211 * upd msrv Signed-off-by: jayzhan211 * revert test Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * downgrade to 1.75 Signed-off-by: jayzhan211 * 1.76 Signed-off-by: jayzhan211 * ahas Signed-off-by: jayzhan211 * revert to 1.75 Signed-off-by: jayzhan211 * rm count Signed-off-by: jayzhan211 * fix merge Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * rm sum in test_no_duplicate_name Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- Cargo.toml | 4 +- datafusion-cli/Cargo.lock | 2 + datafusion-cli/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/core/src/dataframe/mod.rs | 13 +- .../aggregate_statistics.rs | 79 +- .../combine_partial_final_agg.rs | 47 +- .../limited_distinct_aggregation.rs | 16 +- .../core/src/physical_optimizer/test_utils.rs | 5 +- datafusion/core/src/physical_planner.rs | 1 - .../provider_filter_pushdown.rs | 1 + datafusion/core/tests/dataframe/mod.rs | 11 +- .../core/tests/fuzz_cases/window_fuzz.rs | 5 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_fn.rs | 2 + datafusion/functions-aggregate/src/count.rs | 562 ++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 8 +- datafusion/optimizer/src/decorrelate.rs | 10 +- .../src/single_distinct_to_groupby.rs | 3 +- datafusion/physical-expr-common/Cargo.toml | 2 + .../src/aggregate/count_distinct/bytes.rs | 6 +- .../src/aggregate/count_distinct/mod.rs | 23 + .../src/aggregate/count_distinct/native.rs | 23 +- .../physical-expr-common/src/aggregate/mod.rs | 1 + .../src/binary_map.rs | 21 +- datafusion/physical-expr-common/src/lib.rs | 1 + .../physical-expr/src/aggregate/build_in.rs | 92 +-- .../physical-expr/src/aggregate/count.rs | 348 --------- .../src/aggregate/count_distinct/mod.rs | 718 ------------------ .../src/aggregate/groups_accumulator/mod.rs | 2 +- datafusion/physical-expr/src/aggregate/mod.rs | 2 - .../physical-expr/src/expressions/mod.rs | 2 - datafusion/physical-expr/src/lib.rs | 4 +- .../src/aggregates/group_values/bytes.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 19 +- .../src/windows/bounded_window_agg_exec.rs | 7 +- datafusion/physical-plan/src/windows/mod.rs | 4 +- datafusion/proto-common/Cargo.toml | 2 +- datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 17 + datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/logical_plan/from_proto.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../proto/src/physical_plan/to_proto.rs | 15 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + .../tests/cases/roundtrip_physical_plan.rs | 33 +- datafusion/sqllogictest/test_files/errors.slt | 4 +- datafusion/substrait/Cargo.toml | 2 +- .../substrait/src/logical_plan/consumer.rs | 12 +- 52 files changed, 822 insertions(+), 1329 deletions(-) create mode 100644 datafusion/functions-aggregate/src/count.rs rename datafusion/{physical-expr => physical-expr-common}/src/aggregate/count_distinct/bytes.rs (93%) create mode 100644 datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs rename datafusion/{physical-expr => physical-expr-common}/src/aggregate/count_distinct/native.rs (93%) rename datafusion/{physical-expr => physical-expr-common}/src/binary_map.rs (98%) delete mode 100644 datafusion/physical-expr/src/aggregate/count.rs delete mode 100644 datafusion/physical-expr/src/aggregate/count_distinct/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 65ef191d7421..aa1ba1f214d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.73" +rust-version = "1.75" version = "39.0.0" [workspace.dependencies] @@ -107,7 +107,7 @@ doc-comment = "0.3" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } -hashbrown = { version = "0.14", features = ["raw"] } +hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.0.0" itertools = "0.12" log = "^0.4" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 932f44d98486..c5b34df4f1cf 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1376,9 +1376,11 @@ dependencies = [ name = "datafusion-physical-expr-common" version = "39.0.0" dependencies = [ + "ahash", "arrow", "datafusion-common", "datafusion-expr", + "hashbrown 0.14.5", "rand", ] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 5e393246b958..8f4b3cd81f36 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" readme = "README.md" [dependencies] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 7533e2cff198..45617d88dc0c 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.73" +rust-version = "1.75" [lints] workspace = true diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 06a85d303687..950cb7ddb2d3 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,12 +50,11 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, + avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::sum; -use datafusion_functions_aggregate::expr_fn::{median, stddev}; +use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum}; use async_trait::async_trait; @@ -854,10 +853,7 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate( - vec![], - vec![datafusion_expr::count(Expr::Literal(COUNT_STAR_EXPANSION))], - )? + .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? .collect() .await?; let len = *rows @@ -1594,9 +1590,10 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, count_distinct, create_udf, expr, lit, BuiltInWindowFunction, + array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::expr_fn::count_distinct; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 05f05d95b8db..eeacc48b85db 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -170,38 +170,6 @@ fn take_optimizable_column_and_table_count( } } } - // TODO: Remove this after revmoing Builtin Count - else if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( - &stats.num_rows, - agg_expr.as_any().downcast_ref::(), - ) { - // TODO implementing Eq on PhysicalExpr would help a lot here - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - casted_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - casted_expr.name().to_owned(), - )); - } - } - } - } None } @@ -307,13 +275,12 @@ fn take_optimizable_max( #[cfg(test)] pub(crate) mod tests { - use super::*; + use crate::logical_expr::Operator; use crate::physical_plan::aggregates::PhysicalGroupBy; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; - use crate::physical_plan::expressions::Count; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::memory::MemoryExec; use crate::prelude::SessionContext; @@ -322,8 +289,10 @@ pub(crate) mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_int64_array; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::cast; use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use datafusion_physical_plan::aggregates::AggregateMode; /// Mock data using a MemoryExec which has an exact count statistic @@ -414,13 +383,19 @@ pub(crate) mod tests { Self::ColumnA(schema.clone()) } - /// Return appropriate expr depending if COUNT is for col or table (*) - pub(crate) fn count_expr(&self) -> Arc { - Arc::new(Count::new( - self.column(), + // Return appropriate expr depending if COUNT is for col or table (*) + pub(crate) fn count_expr(&self, schema: &Schema) -> Arc { + create_aggregate_expr( + &count_udaf(), + &[self.column()], + &[], + &[], + schema, self.column_name(), - DataType::Int64, - )) + false, + false, + ) + .unwrap() } /// what argument would this aggregate need in the plan? @@ -458,7 +433,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -467,7 +442,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -488,7 +463,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -497,7 +472,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -517,7 +492,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -529,7 +504,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -549,7 +524,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -561,7 +536,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -592,7 +567,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], filter, Arc::clone(&schema), @@ -601,7 +576,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -637,7 +612,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], filter, Arc::clone(&schema), @@ -646,7 +621,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 3ad61e52c82e..38b92959e841 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -206,8 +206,9 @@ mod tests { use crate::physical_plan::{displayable, Partitioning}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; - use datafusion_physical_expr::expressions::{col, Count}; + use datafusion_physical_expr::expressions::col; use datafusion_physical_plan::udaf::create_aggregate_expr; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected @@ -303,15 +304,31 @@ mod tests { ) } + // Return appropriate expr depending if COUNT is for col or table (*) + fn count_expr( + expr: Arc, + name: &str, + schema: &Schema, + ) -> Arc { + create_aggregate_expr( + &count_udaf(), + &[expr], + &[], + &[], + schema, + name, + false, + false, + ) + .unwrap() + } + #[test] fn aggregations_not_combined() -> Result<()> { let schema = schema(); - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + let plan = final_aggregate_exec( repartition_exec(partial_aggregate_exec( parquet_exec(&schema), @@ -330,16 +347,8 @@ mod tests { ]; assert_optimized!(expected, plan); - let aggr_expr1 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; - let aggr_expr2 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(2)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( @@ -365,11 +374,7 @@ mod tests { #[test] fn aggregations_combined() -> Result<()> { let schema = schema(); - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 1274fbe50a5f..f9d5a4c186ee 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -517,10 +517,10 @@ mod tests { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![agg.count_expr(&schema)], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -554,10 +554,10 @@ mod tests { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![agg.count_expr(&schema)], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5895c39a5f87..154e77cd23ae 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -43,7 +43,8 @@ use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_plan::displayable; @@ -240,7 +241,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 79033643cf37..4f9187595018 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2181,7 +2181,6 @@ impl DefaultPhysicalPlanner { expr: &[Expr], ) -> Result> { let input_schema = input.as_ref().schema(); - let physical_exprs = expr .iter() .map(|e| { diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 8c9cffcf08d1..068383b20031 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -35,6 +35,7 @@ use datafusion::scalar::ScalarValue; use datafusion_common::cast::as_primitive_array; use datafusion_common::{internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; +use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index befd98d04302..fa364c5f2a65 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -31,6 +31,7 @@ use arrow::{ }; use arrow_array::Float32Array; use arrow_schema::ArrowError; +use datafusion_functions_aggregate::count::count_udaf; use object_store::local::LocalFileSystem; use std::fs; use std::sync::Arc; @@ -51,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr, ExprSchemable, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, + placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::sum; +use datafusion_functions_aggregate::expr_fn::{count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -178,7 +179,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index b85f6376c3f2..4358691ee5a5 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -38,6 +38,7 @@ use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -165,7 +166,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), // its name "COUNT", // window function argument @@ -350,7 +351,7 @@ fn get_random_function( window_fn_map.insert( "count", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![arg.clone()], ), ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 98ab8ec251f4..57f5414c13bd 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1861,6 +1861,7 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, &fun.to_string(), false, args)?; + if let Some(nt) = null_treatment { w.write_str(" ")?; write!(w, "{}", nt)?; @@ -1885,7 +1886,6 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, func_def.name(), *distinct, args)?; - if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 420312050870..1fafc63e9665 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -193,6 +193,7 @@ pub fn avg(expr: Expr) -> Expr { } /// Create an expression to represent the count() aggregate function +// TODO: Remove this and use `expr_fn::count` instead pub fn count(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, @@ -250,6 +251,7 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { } /// Create an expression to represent the count(distinct) aggregate function +// TODO: Remove this and use `expr_fn::count_distinct` instead pub fn count_distinct(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs new file mode 100644 index 000000000000..cfd56619537b --- /dev/null +++ b/datafusion/functions-aggregate/src/count.rs @@ -0,0 +1,562 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ahash::RandomState; +use std::collections::HashSet; +use std::ops::BitAnd; +use std::{fmt::Debug, sync::Arc}; + +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{ + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, +}; + +use arrow::{ + array::{Array, BooleanArray, Int64Array, PrimitiveArray}, + buffer::BooleanBuffer, +}; +use datafusion_common::{ + downcast_value, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, + EmitTo, GroupsAccumulator, Signature, Volatility, +}; +use datafusion_expr::{Expr, ReversedUDAF}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_physical_expr_common::{ + aggregate::count_distinct::{ + BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, + PrimitiveDistinctCountAccumulator, + }, + binary_map::OutputType, +}; + +make_udaf_expr_and_func!( + Count, + count, + expr, + "Count the number of non-null values in the column", + count_udaf +); + +pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new_udf( + count_udaf(), + vec![expr], + true, + None, + None, + None, + ), + ) +} + +pub struct Count { + signature: Signature, + aliases: Vec, +} + +impl Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "COUNT" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + Ok(vec![Field::new_list( + format_state_name(args.name, "count distinct"), + Field::new("item", args.input_type.clone(), true), + false, + )]) + } else { + Ok(vec![Field::new( + format_state_name(args.name, "count"), + DataType::Int64, + true, + )]) + } + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if !acc_args.is_distinct { + return Ok(Box::new(CountAccumulator::new())); + } + + let data_type = acc_args.input_type; + Ok(match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt8 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal128Type, + >::new(data_type)), + DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal256Type, + >::new(data_type)), + + DataType::Date32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Date64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float32 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float64 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + + DataType::Utf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::LargeUtf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + + // Use the generic accumulator based on `ScalarValue` for all other types + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: data_type.clone(), + }), + }) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + // groups accumulator only supports `COUNT(c1)`, not + // `COUNT(c1, c2)`, etc + if args.is_distinct { + return false; + } + args.args_num == 1 + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + // instantiate specialized accumulator + Ok(Box::new(CountGroupsAccumulator::new())) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +#[derive(Debug)] +struct CountAccumulator { + count: i64, +} + +impl CountAccumulator { + /// new count accumulator + pub fn new() -> Self { + Self { count: 0 } + } +} + +impl Accumulator for CountAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Int64(Some(self.count))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], Int64Array); + let delta = &arrow::compute::sum(counts); + if let Some(d) = delta { + self.count += *d; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.count))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +/// An accumulator to compute the counts of [`PrimitiveArray`]. +/// Stores values as native types, and does overflow checking +/// +/// Unlike most other accumulators, COUNT never produces NULLs. If no +/// non-null values are seen in any group the output is 0. Thus, this +/// accumulator has no additional null or seen filter tracking. +#[derive(Debug)] +struct CountGroupsAccumulator { + /// Count per group. + /// + /// Note this is an i64 and not a u64 (or usize) because the + /// output type of count is `DataType::Int64`. Thus by using `i64` + /// for the counts, the output [`Int64Array`] can be created + /// without copy. + counts: Vec, +} + +impl CountGroupsAccumulator { + pub fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = &values[0]; + + // Add one to each group's counter for each non null, non + // filtered value + self.counts.resize(total_num_groups, 0); + accumulate_indices( + group_indices, + values.logical_nulls().as_ref(), + opt_filter, + |group_index| { + self.counts[group_index] += 1; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values[0].as_primitive::(); + + // intermediate counts are always created as non null + assert_eq!(partial_counts.null_count(), 0); + let partial_counts = partial_counts.values(); + + // Adds the counts with the partial counts + self.counts.resize(total_num_groups, 0); + match opt_filter { + Some(filter) => filter + .iter() + .zip(group_indices.iter()) + .zip(partial_counts.iter()) + .for_each(|((filter_value, &group_index), partial_count)| { + if let Some(true) = filter_value { + self.counts[group_index] += partial_count; + } + }), + None => group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ), + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + + // Count is always non null (null inputs just don't contribute to the overall values) + let nulls = None; + let array = PrimitiveArray::::new(counts.into(), nulls); + + Ok(Arc::new(array)) + } + + // return arrays for counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let counts = emit_to.take_needed(&mut self.counts); + let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls + Ok(vec![Arc::new(counts) as ArrayRef]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + } +} + +/// count null values for multiple columns +/// for each row if one column value is null, then null_count + 1 +fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { + if values.len() > 1 { + let result_bool_buf: Option = values + .iter() + .map(|a| a.logical_nulls()) + .fold(None, |acc, b| match (acc, b) { + (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), + (Some(acc), None) => Some(acc), + (None, Some(b)) => Some(b.into_inner()), + _ => None, + }); + result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) + } else { + values[0] + .logical_nulls() + .map_or(0, |nulls| nulls.null_count()) + } +} + +/// General purpose distinct accumulator that works for any DataType by using +/// [`ScalarValue`]. +/// +/// It stores intermediate results as a `ListArray` +/// +/// Note that many types have specialized accumulators that are (much) +/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and +/// [`BytesDistinctCountAccumulator`] +#[derive(Debug)] +struct DistinctCountAccumulator { + values: HashSet, + state_data_type: DataType, +} + +impl DistinctCountAccumulator { + // calculating the size for fixed length values, taking first batch size * + // number of batches This method is faster than .full_size(), however it is + // not suitable for variable length values like strings or complex types + fn fixed_size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .next() + .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .unwrap_or(0) + + std::mem::size_of::() + } + + // calculates the size as accurately as possible. Note that calling this + // method is expensive + fn full_size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .sum::() + + std::mem::size_of::() + } +} + +impl Accumulator for DistinctCountAccumulator { + /// Returns the distinct values seen so far as (one element) ListArray. + fn state(&mut self) -> Result> { + let scalars = self.values.iter().cloned().collect::>(); + let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = &values[0]; + if arr.data_type() == &DataType::Null { + return Ok(()); + } + + (0..arr.len()).try_for_each(|index| { + if !arr.is_null(index) { + let scalar = ScalarValue::try_from_array(arr, index)?; + self.values.insert(scalar); + } + Ok(()) + }) + } + + /// Merges multiple sets of distinct values into the current set. + /// + /// The input to this function is a `ListArray` with **multiple** rows, + /// where each row contains the values from a partial aggregate's phase (e.g. + /// the result of calling `Self::state` on multiple accumulators). + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!(states.len(), 1, "array_agg states must be singleton!"); + let array = &states[0]; + let list_array = array.as_list::(); + for inner_array in list_array.iter() { + let Some(inner_array) = inner_array else { + return internal_err!( + "Intermediate results of COUNT DISTINCT should always be non null" + ); + }; + self.update_batch(&[inner_array])?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + match &self.state_data_type { + DataType::Boolean | DataType::Null => self.fixed_size(), + d if d.is_primitive() => self.fixed_size(), + _ => self.full_size(), + } + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 2d062cf2cb9b..56fc1305bb59 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -56,6 +56,7 @@ pub mod macros; pub mod approx_distinct; +pub mod count; pub mod covariance; pub mod first_last; pub mod hyperloglog; @@ -77,6 +78,8 @@ use std::sync::Arc; pub mod expr_fn { pub use super::approx_distinct; pub use super::approx_median::approx_median; + pub use super::count::count; + pub use super::count::count_distinct; pub use super::covariance::covar_pop; pub use super::covariance::covar_samp; pub use super::first_last::first_value; @@ -98,6 +101,7 @@ pub fn all_default_aggregate_functions() -> Vec> { sum::sum_udaf(), covariance::covar_pop_udaf(), median::median_udaf(), + count::count_udaf(), variance::var_samp_udaf(), variance::var_pop_udaf(), stddev::stddev_udaf(), @@ -133,8 +137,8 @@ mod tests { let mut names = HashSet::new(); for func in all_default_aggregate_functions() { // TODO: remove this - // sum is in intermidiate migration state, skip this - if func.name().to_lowercase() == "sum" { + // These functions are in intermidiate migration state, skip them + if func.name().to_lowercase() == "count" { continue; } assert!( diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b55b1a7f8f2d..e14ee763a3c0 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -441,8 +441,14 @@ fn agg_exprs_evaluation_result_on_empty_batch( Transformed::yes(Expr::Literal(ScalarValue::Null)) } } - AggregateFunctionDefinition::UDF { .. } => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + AggregateFunctionDefinition::UDF(fun) => { + if fun.name() == "COUNT" { + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::yes(Expr::Literal(ScalarValue::Null)) + } } }, _ => Transformed::no(expr), diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 32b6703bcae5..e738209eb4fd 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -361,8 +361,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr; - use datafusion_expr::expr::GroupingSet; + use datafusion_expr::expr::{self, GroupingSet}; use datafusion_expr::test::function_stub::{sum, sum_udaf}; use datafusion_expr::{ count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index 637b8775112e..3ef2d5345533 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -36,7 +36,9 @@ name = "datafusion_physical_expr_common" path = "src/lib.rs" [dependencies] +ahash = { workspace = true } arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +hashbrown = { workspace = true } rand = { workspace = true } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs similarity index 93% rename from datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs rename to datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs index 2ed9b002c841..5c888ca66caa 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs @@ -18,7 +18,7 @@ //! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values use crate::binary_map::{ArrowBytesSet, OutputType}; -use arrow_array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; use datafusion_common::ScalarValue; @@ -35,10 +35,10 @@ use std::sync::Arc; /// [`BinaryArray`]: arrow::array::BinaryArray /// [`LargeBinaryArray`]: arrow::array::LargeBinaryArray #[derive(Debug)] -pub(super) struct BytesDistinctCountAccumulator(ArrowBytesSet); +pub struct BytesDistinctCountAccumulator(ArrowBytesSet); impl BytesDistinctCountAccumulator { - pub(super) fn new(output_type: OutputType) -> Self { + pub fn new(output_type: OutputType) -> Self { Self(ArrowBytesSet::new(output_type)) } } diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs new file mode 100644 index 000000000000..f216406d0dd7 --- /dev/null +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bytes; +mod native; + +pub use bytes::BytesDistinctCountAccumulator; +pub use native::FloatDistinctCountAccumulator; +pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs similarity index 93% rename from datafusion/physical-expr/src/aggregate/count_distinct/native.rs rename to datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs index 0e7483d4a1cd..72b83676e81d 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs @@ -26,10 +26,10 @@ use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; +use arrow::array::types::ArrowPrimitiveType; use arrow::array::ArrayRef; -use arrow_array::types::ArrowPrimitiveType; -use arrow_array::PrimitiveArray; -use arrow_schema::DataType; +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::utils::array_into_list_array; @@ -40,7 +40,7 @@ use datafusion_expr::Accumulator; use crate::aggregate::utils::Hashable; #[derive(Debug)] -pub(super) struct PrimitiveDistinctCountAccumulator +pub struct PrimitiveDistinctCountAccumulator where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, @@ -54,7 +54,7 @@ where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, { - pub(super) fn new(data_type: &DataType) -> Self { + pub fn new(data_type: &DataType) -> Self { Self { values: HashSet::default(), data_type: data_type.clone(), @@ -125,7 +125,7 @@ where } #[derive(Debug)] -pub(super) struct FloatDistinctCountAccumulator +pub struct FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { @@ -136,13 +136,22 @@ impl FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { - pub(super) fn new() -> Self { + pub fn new() -> Self { Self { values: HashSet::default(), } } } +impl Default for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn default() -> Self { + Self::new() + } +} + impl Accumulator for FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send + Debug, diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index ec02df57b82d..21884f840dbd 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod count_distinct; pub mod groups_accumulator; pub mod stats; pub mod tdigest; diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs similarity index 98% rename from datafusion/physical-expr/src/binary_map.rs rename to datafusion/physical-expr-common/src/binary_map.rs index 0923fcdaeb91..6d5ba737a1df 100644 --- a/datafusion/physical-expr/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -19,17 +19,16 @@ //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray. use ahash::RandomState; -use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; -use arrow_array::{ - Array, ArrayRef, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, +use arrow::array::cast::AsArray; +use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; +use arrow::array::{ + Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }; -use arrow_buffer::{ - BooleanBufferBuilder, BufferBuilder, NullBuffer, OffsetBuffer, ScalarBuffer, -}; -use arrow_schema::DataType; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::DataType; use datafusion_common::hash_utils::create_hashes; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; use std::mem; @@ -605,8 +604,8 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::{BinaryArray, LargeBinaryArray, StringArray}; - use hashbrown::HashMap; + use arrow::array::{BinaryArray, LargeBinaryArray, StringArray}; + use std::collections::HashMap; #[test] fn string_set_empty() { diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index f335958698ab..0ddb84141a07 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -16,6 +16,7 @@ // under the License. pub mod aggregate; +pub mod binary_map; pub mod expressions; pub mod physical_expr; pub mod sort_expr; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ac24dd2e7603..aee7bca3b88f 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,12 +30,13 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; +use datafusion_expr::AggregateFunction; + use crate::aggregate::average::Avg; use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::AggregateFunction; /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. pub fn create_aggregate_expr( @@ -60,14 +61,9 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Count, false) => Arc::new( - expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type), - ), - (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - data_type, - input_phy_exprs[0].clone(), - name, - )), + (AggregateFunction::Count, _) => { + return internal_err!("Builtin Count will be removed"); + } (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, @@ -320,7 +316,7 @@ mod tests { use super::*; use crate::expressions::{ try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, - BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min, + BoolOr, DistinctArrayAgg, Max, Min, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -328,8 +324,8 @@ mod tests { use datafusion_expr::{type_coercion, Signature}; #[test] - fn test_count_arragg_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Count, AggregateFunction::ArrayAgg]; + fn test_approx_expr() -> Result<()> { + let funcs = vec![AggregateFunction::ArrayAgg]; let data_types = vec![ DataType::UInt32, DataType::Int32, @@ -352,29 +348,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_agg_phy_exprs.field().unwrap() + ); + } let result_distinct = create_physical_agg_expr_for_test( &fun, @@ -383,29 +368,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_distinct.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_distinct.as_any().is::()); + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_agg_phy_exprs.field().unwrap() + ); + } } } Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs deleted file mode 100644 index aad18a82ab87..000000000000 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ /dev/null @@ -1,348 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::fmt::Debug; -use std::ops::BitAnd; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, Int64Array}; -use arrow::compute; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_array::cast::AsArray; -use arrow_array::types::Int64Type; -use arrow_array::PrimitiveArray; -use arrow_buffer::BooleanBuffer; -use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; - -use crate::expressions::format_state_name; - -use super::groups_accumulator::accumulate::accumulate_indices; - -/// COUNT aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug, Clone)] -pub struct Count { - name: String, - data_type: DataType, - nullable: bool, - /// Input exprs - /// - /// For `COUNT(c1)` this is `[c1]` - /// For `COUNT(c1, c2)` this is `[c1, c2]` - exprs: Vec>, -} - -impl Count { - /// Create a new COUNT aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs: vec![expr], - data_type, - nullable: true, - } - } - - pub fn new_with_multiple_exprs( - exprs: Vec>, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs, - data_type, - nullable: true, - } - } -} - -/// An accumulator to compute the counts of [`PrimitiveArray`]. -/// Stores values as native types, and does overflow checking -/// -/// Unlike most other accumulators, COUNT never produces NULLs. If no -/// non-null values are seen in any group the output is 0. Thus, this -/// accumulator has no additional null or seen filter tracking. -#[derive(Debug)] -struct CountGroupsAccumulator { - /// Count per group. - /// - /// Note this is an i64 and not a u64 (or usize) because the - /// output type of count is `DataType::Int64`. Thus by using `i64` - /// for the counts, the output [`Int64Array`] can be created - /// without copy. - counts: Vec, -} - -impl CountGroupsAccumulator { - pub fn new() -> Self { - Self { counts: vec![] } - } -} - -impl GroupsAccumulator for CountGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &values[0]; - - // Add one to each group's counter for each non null, non - // filtered value - self.counts.resize(total_num_groups, 0); - accumulate_indices( - group_indices, - values.logical_nulls().as_ref(), - opt_filter, - |group_index| { - self.counts[group_index] += 1; - }, - ); - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "one argument to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - - // intermediate counts are always created as non null - assert_eq!(partial_counts.null_count(), 0); - let partial_counts = partial_counts.values(); - - // Adds the counts with the partial counts - self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - - // Count is always non null (null inputs just don't contribute to the overall values) - let nulls = None; - let array = PrimitiveArray::::new(counts.into(), nulls); - - Ok(Arc::new(array)) - } - - // return arrays for counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls - Ok(vec![Arc::new(counts) as ArrayRef]) - } - - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - } -} - -/// count null values for multiple columns -/// for each row if one column value is null, then null_count + 1 -fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { - if values.len() > 1 { - let result_bool_buf: Option = values - .iter() - .map(|a| a.logical_nulls()) - .fold(None, |acc, b| match (acc, b) { - (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), - (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.into_inner()), - _ => None, - }); - result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) - } else { - values[0] - .logical_nulls() - .map_or(0, |nulls| nulls.null_count()) - } -} - -impl AggregateExpr for Count { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "count"), - DataType::Int64, - true, - )]) - } - - fn expressions(&self) -> Vec> { - self.exprs.clone() - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - // groups accumulator only supports `COUNT(c1)`, not - // `COUNT(c1, c2)`, etc - self.exprs.len() == 1 - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn create_groups_accumulator(&self) -> Result> { - // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) - } - - fn with_new_expressions( - &self, - args: Vec>, - order_by_exprs: Vec>, - ) -> Option> { - debug_assert_eq!(self.exprs.len(), args.len()); - debug_assert!(order_by_exprs.is_empty()); - Some(Arc::new(Count { - name: self.name.clone(), - data_type: self.data_type.clone(), - nullable: self.nullable, - exprs: args, - })) - } -} - -impl PartialEq for Count { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.exprs.len() == x.exprs.len() - && self - .exprs - .iter() - .zip(x.exprs.iter()) - .all(|(expr1, expr2)| expr1.eq(expr2)) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct CountAccumulator { - count: i64, -} - -impl CountAccumulator { - /// new count accumulator - pub fn new() -> Self { - Self { count: 0 } - } -} - -impl Accumulator for CountAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(self.count))]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], Int64Array); - let delta = &compute::sum(counts); - if let Some(d) = delta { - self.count += *d; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.count))) - } - - fn supports_retract_batch(&self) -> bool { - true - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs deleted file mode 100644 index 52f1c5c0f9a0..000000000000 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ /dev/null @@ -1,718 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod bytes; -mod native; - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use ahash::RandomState; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field, TimeUnit}; -use arrow_array::cast::AsArray; -use arrow_array::types::{ - Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; - -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator; -use crate::aggregate::count_distinct::native::{ - FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, -}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::binary_map::OutputType; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// Expression for a `COUNT(DISTINCT)` aggregation. -#[derive(Debug)] -pub struct DistinctCount { - /// Column name - name: String, - /// The DataType used to hold the state for each input - state_data_type: DataType, - /// The input arguments - expr: Arc, -} - -impl DistinctCount { - /// Create a new COUNT(DISTINCT) aggregate function. - pub fn new( - input_data_type: DataType, - expr: Arc, - name: impl Into, - ) -> Self { - Self { - name: name.into(), - state_data_type: input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctCount { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, true)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "count distinct"), - Field::new("item", self.state_data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - use DataType::*; - use TimeUnit::*; - - let data_type = &self.state_data_type; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - Date32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Date64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32MillisecondType, - >::new(data_type)), - Time32(Second) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32SecondType, - >::new(data_type)), - Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64MicrosecondType, - >::new(data_type)), - Time64(Nanosecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64NanosecondType, - >::new(data_type)), - Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMicrosecondType, - >::new(data_type)), - Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMillisecondType, - >::new(data_type)), - Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampNanosecondType, - >::new(data_type)), - Timestamp(Second, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampSecondType, - >::new(data_type)), - - Float16 => Box::new(FloatDistinctCountAccumulator::::new()), - Float32 => Box::new(FloatDistinctCountAccumulator::::new()), - Float64 => Box::new(FloatDistinctCountAccumulator::::new()), - - Utf8 => Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)), - LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - }), - }) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctCount { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.state_data_type == x.state_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. -/// -/// It stores intermediate results as a `ListArray` -/// -/// Note that many types have specialized accumulators that are (much) -/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and -/// [`BytesDistinctCountAccumulator`] -#[derive(Debug)] -struct DistinctCountAccumulator { - values: HashSet, - state_data_type: DataType, -} - -impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * - // number of batches This method is faster than .full_size(), however it is - // not suitable for variable length values like strings or complex types - fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .unwrap_or(0) - + std::mem::size_of::() - } - - // calculates the size as accurately as possible. Note that calling this - // method is expensive - fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .sum::() - + std::mem::size_of::() - } -} - -impl Accumulator for DistinctCountAccumulator { - /// Returns the distinct values seen so far as (one element) ListArray. - fn state(&mut self) -> Result> { - let scalars = self.values.iter().cloned().collect::>(); - let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let arr = &values[0]; - if arr.data_type() == &DataType::Null { - return Ok(()); - } - - (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.insert(scalar); - } - Ok(()) - }) - } - - /// Merges multiple sets of distinct values into the current set. - /// - /// The input to this function is a `ListArray` with **multiple** rows, - /// where each row contains the values from a partial aggregate's phase (e.g. - /// the result of calling `Self::state` on multiple accumulators). - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let array = &states[0]; - let list_array = array.as_list::(); - for inner_array in list_array.iter() { - let Some(inner_array) = inner_array else { - return internal_err!( - "Intermediate results of COUNT DISTINCT should always be non null" - ); - }; - self.update_batch(&[inner_array])?; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.values.len() as i64))) - } - - fn size(&self) -> usize { - match &self.state_data_type { - DataType::Boolean | DataType::Null => self.fixed_size(), - d if d.is_primitive() => self.fixed_size(), - _ => self.full_size(), - } - } -} - -#[cfg(test)] -mod tests { - use arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }; - use arrow_array::Decimal256Array; - use arrow_buffer::i256; - - use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - - use crate::expressions::NoOp; - - use super::*; - - macro_rules! state_to_vec_primitive { - ($LIST:expr, $DATA_TYPE:ident) => {{ - let arr = ScalarValue::raw_data($LIST).unwrap(); - let list_arr = as_list_array(&arr).unwrap(); - let arr = list_arr.values(); - let arr = as_primitive_array::<$DATA_TYPE>(arr)?; - arr.values().iter().cloned().collect::>() - }}; - } - - macro_rules! test_count_distinct_update_batch_numeric { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(1), - Some(1), - None, - Some(3), - Some(2), - None, - Some(2), - Some(3), - Some(1), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![1, 2, 3]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - fn state_to_vec_bool(sv: &ScalarValue) -> Result> { - let arr = ScalarValue::raw_data(sv)?; - let list_arr = as_list_array(&arr)?; - let arr = list_arr.values(); - let bool_arr = as_boolean_array(arr)?; - Ok(bool_arr.iter().flatten().collect()) - } - - fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - arrays[0].data_type().clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - accum.update_batch(arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - fn run_update( - data_types: &[DataType], - rows: &[Vec], - ) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - data_types[0].clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - - let cols = (0..rows[0].len()) - .map(|i| { - rows.iter() - .map(|inner| inner[i].clone()) - .collect::>() - }) - .collect::>(); - - let arrays: Vec = cols - .iter() - .map(|c| ScalarValue::iter_to_array(c.clone())) - .collect::>>()?; - - accum.update_batch(&arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - // Used trait to create associated constant for f32 and f64 - trait SubNormal: 'static { - const SUBNORMAL: Self; - } - - impl SubNormal for f64 { - const SUBNORMAL: Self = 1.0e-308_f64; - } - - impl SubNormal for f32 { - const SUBNORMAL: Self = 1.0e-38_f32; - } - - macro_rules! test_count_distinct_update_batch_floating_point { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(<$PRIM_TYPE>::INFINITY), - Some(<$PRIM_TYPE>::NAN), - Some(1.0), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(<$PRIM_TYPE>::INFINITY), - None, - Some(3.0), - Some(-4.5), - Some(2.0), - None, - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(1.0), - Some(<$PRIM_TYPE>::NAN), - Some(<$PRIM_TYPE>::NEG_INFINITY), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - - dbg!(&state_vec); - state_vec.sort_by(|a, b| match (a, b) { - (lhs, rhs) => lhs.total_cmp(rhs), - }); - - let nan_idx = state_vec.len() - 1; - assert_eq!(states.len(), 1); - assert_eq!( - &state_vec[..nan_idx], - vec![ - <$PRIM_TYPE>::NEG_INFINITY, - -4.5, - <$PRIM_TYPE as SubNormal>::SUBNORMAL, - 1.0, - 2.0, - 3.0, - <$PRIM_TYPE>::INFINITY - ] - ); - assert!(state_vec[nan_idx].is_nan()); - assert_eq!(result, ScalarValue::Int64(Some(8))); - - Ok(()) - }}; - } - - macro_rules! test_count_distinct_update_batch_bigint { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(i256::from(1)), - Some(i256::from(1)), - None, - Some(i256::from(3)), - Some(i256::from(2)), - None, - Some(i256::from(2)), - Some(i256::from(3)), - Some(i256::from(1)), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - #[test] - fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) - } - - #[test] - fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) - } - - #[test] - fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) - } - - #[test] - fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) - } - - #[test] - fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) - } - - #[test] - fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) - } - - #[test] - fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) - } - - #[test] - fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) - } - - #[test] - fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) - } - - #[test] - fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) - } - - #[test] - fn count_distinct_update_batch_i256() -> Result<()> { - test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) - } - - #[test] - fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { - let arrays = vec![Arc::new(data) as ArrayRef]; - let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec_bool(&states[0])?; - state_vec.sort(); - - let count = match result { - ScalarValue::Int64(c) => c.ok_or_else(|| { - DataFusionError::Internal("Found None count".to_string()) - }), - scalar => { - internal_err!("Found non int64 scalar value from count: {scalar}") - } - }?; - Ok((state_vec, count)) - }; - - let zero_count_values = BooleanArray::from(Vec::::new()); - - let one_count_values = BooleanArray::from(vec![false, false]); - let one_count_values_with_null = - BooleanArray::from(vec![Some(true), Some(true), None, None]); - - let two_count_values = BooleanArray::from(vec![true, false, true, false, true]); - let two_count_values_with_null = BooleanArray::from(vec![ - Some(true), - Some(false), - None, - None, - Some(true), - Some(false), - ]); - - assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); - assert_eq!(get_count(one_count_values)?, (vec![false], 1)); - assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); - assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); - assert_eq!( - get_count(two_count_values_with_null)?, - (vec![false, true], 2) - ); - Ok(()) - } - - #[test] - fn count_distinct_update_batch_all_nulls() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from( - vec![None, None, None, None] as Vec> - )) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update_batch_empty() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - Ok(()) - } - - #[test] - fn count_distinct_update_with_nulls() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(None)], - vec![ScalarValue::Int32(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(None)], - vec![ScalarValue::UInt64(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index 65227b727be7..a6946e739c97 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -20,7 +20,7 @@ pub use adapter::GroupsAccumulatorAdapter; // Backward compatibility pub(crate) mod accumulate { - pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices, NullState}; + pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; } pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 7a6c5f9d0e24..01105c8559c9 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -26,8 +26,6 @@ pub(crate) mod average; pub(crate) mod bit_and_or_xor; pub(crate) mod bool_and_or; pub(crate) mod correlation; -pub(crate) mod count; -pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod grouping; pub(crate) mod nth_value; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index a96d02173018..123ada6d7c86 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -47,8 +47,6 @@ pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; -pub use crate::aggregate::count::Count; -pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 72f5f2d50cb8..b764e81a95d1 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,7 +17,9 @@ pub mod aggregate; pub mod analysis; -pub mod binary_map; +pub mod binary_map { + pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; +} pub mod equivalence; pub mod expressions; pub mod functions; diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index d073c8995a9b..f789af8b8a02 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -18,7 +18,7 @@ use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; -use datafusion_physical_expr::binary_map::{ArrowBytesMap, OutputType}; +use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 79abbdb52ca2..b6fc70be7cbc 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1194,12 +1194,14 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::{ - lit, Count, FirstValue, LastValue, OrderSensitiveArrayAgg, + lit, FirstValue, LastValue, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1334,11 +1336,16 @@ mod tests { ], }; - let aggregates: Vec> = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - ))]; + let aggregates = vec![create_aggregate_expr( + &count_udaf(), + &[lit(1i8)], + &[], + &[], + &input_schema, + "COUNT(1)", + false, + false, + )?]; let task_ctx = if spill { new_spill_ctx(4, 1000) diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 48f1bee59bbf..56d780e51394 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1194,9 +1194,9 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; use datafusion_physical_expr::window::{ BuiltInWindowExpr, BuiltInWindowFunctionExpr, @@ -1298,8 +1298,7 @@ mod tests { order_by: &str, ) -> Result> { let schema = input.schema(); - let window_fn = - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count); + let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 9b392d941ef4..63ce473fc57e 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -597,7 +597,6 @@ pub fn get_window_mode( #[cfg(test)] mod tests { use super::*; - use crate::aggregates::AggregateFunction; use crate::collect; use crate::expressions::col; use crate::streaming::StreamingTableExec; @@ -607,6 +606,7 @@ mod tests { use arrow::compute::SortOptions; use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; use futures::FutureExt; use InputOrderMode::{Linear, PartiallySorted, Sorted}; @@ -749,7 +749,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 97568fb5f678..66ce7cbd838f 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 49884c48b3cc..9f8f03de6dc9 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 358ba7e3eb94..b1897aa58e7d 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index b6993f6c040b..eabaf7ba8e14 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index b401ff8810db..2bb3ec793d7f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -520,6 +520,7 @@ message AggregateExprNode { message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + bool distinct = 5; LogicalExprNode filter = 3; repeated LogicalExprNode order_by = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d6632c77d8da..59b7861a6ef1 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -886,6 +886,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { len += 1; } + if self.distinct { + len += 1; + } if self.filter.is_some() { len += 1; } @@ -899,6 +902,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; } @@ -918,6 +924,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "fun_name", "funName", "args", + "distinct", "filter", "order_by", "orderBy", @@ -927,6 +934,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { enum GeneratedField { FunName, Args, + Distinct, Filter, OrderBy, } @@ -952,6 +960,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { match value { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), + "distinct" => Ok(GeneratedField::Distinct), "filter" => Ok(GeneratedField::Filter), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -975,6 +984,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { { let mut fun_name__ = None; let mut args__ = None; + let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; while let Some(k) = map_.next_key()? { @@ -991,6 +1001,12 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { } args__ = Some(map_.next_value()?); } + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + distinct__ = Some(map_.next_value()?); + } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); @@ -1008,6 +1024,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { Ok(AggregateUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), + distinct: distinct__.unwrap_or_default(), filter: filter__, order_by: order_by__.unwrap_or_default(), }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0aca5ef1ffb8..0861c287fcfa 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -767,6 +767,8 @@ pub struct AggregateUdfExprNode { pub fun_name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "5")] + pub distinct: bool, #[prost(message, optional, boxed, tag = "3")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "4")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3ad5973380ed..2ad40d883fe6 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -642,7 +642,7 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, parse_exprs(&pb.args, registry, codec)?, - false, + pb.distinct, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, None, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d42470f198e3..6a275ed7a1b8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -456,6 +456,7 @@ pub fn serialize_expr( protobuf::AggregateUdfExprNode { fun_name: fun.name().to_string(), args: serialize_exprs(args, codec)?, + distinct: *distinct, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), None => None, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5258bdd11d86..e25447b023d8 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,10 +25,10 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, - RowNumber, StringAgg, TryCastExpr, WindowShift, + CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, + OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, StringAgg, + TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -240,12 +240,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); let mut distinct = false; - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::BitAnd diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 699697dd2f2c..d9736da69d42 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -649,6 +649,8 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + count(lit(1)), + count_distinct(lit(1)), first_value(lit(1), None), first_value(lit(1), Some(vec![lit(2).sort(true, true)])), covar_samp(lit(1.5), lit(2.2)), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9cf686dbd3d6..e517482f1db0 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -38,7 +38,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg}; +use datafusion::physical_expr::expressions::{Max, NthValueAgg}; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -47,8 +47,8 @@ use datafusion::physical_plan::aggregates::{ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ - binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, - NotExpr, NthValue, PhysicalSortExpr, StringAgg, + binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, NthValue, + PhysicalSortExpr, StringAgg, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::DataSinkExec; @@ -806,7 +806,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))], + vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], vec![None], window, schema.clone(), @@ -818,31 +818,6 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Ok(()) } -#[test] -fn roundtrip_distinct_count() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let aggregates: Vec> = vec![Arc::new(DistinctCount::new( - DataType::Int64, - col("b", &schema)?, - "COUNT(DISTINCT b)".to_string(), - ))]; - - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "unused".to_string())]; - - roundtrip_test(Arc::new(AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new_single(groups), - aggregates.clone(), - vec![None], - Arc::new(EmptyExec::new(schema.clone())), - schema, - )?)) -} - #[test] fn roundtrip_like() -> Result<()> { let schema = Schema::new(vec![ diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index e930af107f77..c7b9808c249d 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -46,7 +46,7 @@ statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c SELECT CAST(c1 AS INT) FROM aggregate_test_100 # aggregation_with_bad_arguments -statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +query error SELECT COUNT(DISTINCT) FROM aggregate_test_100 # query_cte_incorrect @@ -104,7 +104,7 @@ SELECT power(1, 2, 3); # # AggregateFunction with wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tCOUNT\(Any, \.\., Any\) +query error select count(); # AggregateFunction with wrong number of arguments diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index ee96ffa67044..d934dba4cfea 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" [lints] workspace = true diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 3f9a895d951c..93f197885c0a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -982,18 +982,16 @@ pub async fn from_substrait_agg_func( let function_name = substrait_fun_name((**function_name).as_str()); // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { + // deal with situation that count(*) got no arguments + if fun.name() == "COUNT" && args.is_empty() { + args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); + } + Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { - match &fun { - // deal with situation that count(*) got no arguments - aggregate_function::AggregateFunction::Count if args.is_empty() => { - args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); - } - _ => {} - } Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None), )))