From 8f718dd3ce291c9f5688144ca6c9d7d854dc4b0b Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 13 Jun 2024 07:54:39 +0800 Subject: [PATCH] Move `Count` to `functions-aggregate`, update MSRV to rust 1.75 (#10484) * mv accumulate indices Signed-off-by: jayzhan211 * complete udaf Signed-off-by: jayzhan211 * register Signed-off-by: jayzhan211 * fix expr Signed-off-by: jayzhan211 * filter distinct count Signed-off-by: jayzhan211 * todo: need to move count distinct too Signed-off-by: jayzhan211 * move code around Signed-off-by: jayzhan211 * move distinct to aggr-crate Signed-off-by: jayzhan211 * replace Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * fix function name and physical expr Signed-off-by: jayzhan211 * fix physical optimizer Signed-off-by: jayzhan211 * fix all slt Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix with args Signed-off-by: jayzhan211 * add label Signed-off-by: jayzhan211 * revert builtin related code back Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * fix substrait Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * fmy Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * fix udaf macro for distinct but not apply Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix count distinct and use workspace Signed-off-by: jayzhan211 * add reverse Signed-off-by: jayzhan211 * remove old code Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * use macro Signed-off-by: jayzhan211 * expr builder Signed-off-by: jayzhan211 * introduce expr builder Signed-off-by: jayzhan211 * add example Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * clean agg sta Signed-off-by: jayzhan211 * combine agg Signed-off-by: jayzhan211 * limit distinct and fmt Signed-off-by: jayzhan211 * cleanup name Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fix window Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix merged Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * fix rebase Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * use std Signed-off-by: jayzhan211 * update mrsv Signed-off-by: jayzhan211 * upd msrv Signed-off-by: jayzhan211 * revert test Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * downgrade to 1.75 Signed-off-by: jayzhan211 * 1.76 Signed-off-by: jayzhan211 * ahas Signed-off-by: jayzhan211 * revert to 1.75 Signed-off-by: jayzhan211 * rm count Signed-off-by: jayzhan211 * fix merge Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * rm sum in test_no_duplicate_name Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- Cargo.toml | 4 +- datafusion-cli/Cargo.lock | 2 + datafusion-cli/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/core/src/dataframe/mod.rs | 13 +- .../aggregate_statistics.rs | 79 +- .../combine_partial_final_agg.rs | 47 +- .../limited_distinct_aggregation.rs | 16 +- .../core/src/physical_optimizer/test_utils.rs | 5 +- datafusion/core/src/physical_planner.rs | 1 - .../provider_filter_pushdown.rs | 1 + datafusion/core/tests/dataframe/mod.rs | 11 +- .../core/tests/fuzz_cases/window_fuzz.rs | 5 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_fn.rs | 2 + datafusion/functions-aggregate/src/count.rs | 562 ++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 8 +- datafusion/optimizer/src/decorrelate.rs | 10 +- .../src/single_distinct_to_groupby.rs | 3 +- datafusion/physical-expr-common/Cargo.toml | 2 + .../src/aggregate/count_distinct/bytes.rs | 6 +- .../src/aggregate/count_distinct/mod.rs | 23 + .../src/aggregate/count_distinct/native.rs | 23 +- .../physical-expr-common/src/aggregate/mod.rs | 1 + .../src/binary_map.rs | 21 +- datafusion/physical-expr-common/src/lib.rs | 1 + .../physical-expr/src/aggregate/build_in.rs | 92 +-- .../physical-expr/src/aggregate/count.rs | 348 --------- .../src/aggregate/count_distinct/mod.rs | 718 ------------------ .../src/aggregate/groups_accumulator/mod.rs | 2 +- datafusion/physical-expr/src/aggregate/mod.rs | 2 - .../physical-expr/src/expressions/mod.rs | 2 - datafusion/physical-expr/src/lib.rs | 4 +- .../src/aggregates/group_values/bytes.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 19 +- .../src/windows/bounded_window_agg_exec.rs | 7 +- datafusion/physical-plan/src/windows/mod.rs | 4 +- datafusion/proto-common/Cargo.toml | 2 +- datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 17 + datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/logical_plan/from_proto.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../proto/src/physical_plan/to_proto.rs | 15 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + .../tests/cases/roundtrip_physical_plan.rs | 33 +- datafusion/sqllogictest/test_files/errors.slt | 4 +- datafusion/substrait/Cargo.toml | 2 +- .../substrait/src/logical_plan/consumer.rs | 12 +- 52 files changed, 822 insertions(+), 1329 deletions(-) create mode 100644 datafusion/functions-aggregate/src/count.rs rename datafusion/{physical-expr => physical-expr-common}/src/aggregate/count_distinct/bytes.rs (93%) create mode 100644 datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs rename datafusion/{physical-expr => physical-expr-common}/src/aggregate/count_distinct/native.rs (93%) rename datafusion/{physical-expr => physical-expr-common}/src/binary_map.rs (98%) delete mode 100644 datafusion/physical-expr/src/aggregate/count.rs delete mode 100644 datafusion/physical-expr/src/aggregate/count_distinct/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 65ef191d7421..aa1ba1f214d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.73" +rust-version = "1.75" version = "39.0.0" [workspace.dependencies] @@ -107,7 +107,7 @@ doc-comment = "0.3" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } -hashbrown = { version = "0.14", features = ["raw"] } +hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.0.0" itertools = "0.12" log = "^0.4" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 932f44d98486..c5b34df4f1cf 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1376,9 +1376,11 @@ dependencies = [ name = "datafusion-physical-expr-common" version = "39.0.0" dependencies = [ + "ahash", "arrow", "datafusion-common", "datafusion-expr", + "hashbrown 0.14.5", "rand", ] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 5e393246b958..8f4b3cd81f36 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" readme = "README.md" [dependencies] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 7533e2cff198..45617d88dc0c 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.73" +rust-version = "1.75" [lints] workspace = true diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 06a85d303687..950cb7ddb2d3 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,12 +50,11 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, + avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::sum; -use datafusion_functions_aggregate::expr_fn::{median, stddev}; +use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum}; use async_trait::async_trait; @@ -854,10 +853,7 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate( - vec![], - vec![datafusion_expr::count(Expr::Literal(COUNT_STAR_EXPANSION))], - )? + .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? .collect() .await?; let len = *rows @@ -1594,9 +1590,10 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, count_distinct, create_udf, expr, lit, BuiltInWindowFunction, + array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::expr_fn::count_distinct; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 05f05d95b8db..eeacc48b85db 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -170,38 +170,6 @@ fn take_optimizable_column_and_table_count( } } } - // TODO: Remove this after revmoing Builtin Count - else if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( - &stats.num_rows, - agg_expr.as_any().downcast_ref::(), - ) { - // TODO implementing Eq on PhysicalExpr would help a lot here - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - casted_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - casted_expr.name().to_owned(), - )); - } - } - } - } None } @@ -307,13 +275,12 @@ fn take_optimizable_max( #[cfg(test)] pub(crate) mod tests { - use super::*; + use crate::logical_expr::Operator; use crate::physical_plan::aggregates::PhysicalGroupBy; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; - use crate::physical_plan::expressions::Count; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::memory::MemoryExec; use crate::prelude::SessionContext; @@ -322,8 +289,10 @@ pub(crate) mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_int64_array; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::cast; use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use datafusion_physical_plan::aggregates::AggregateMode; /// Mock data using a MemoryExec which has an exact count statistic @@ -414,13 +383,19 @@ pub(crate) mod tests { Self::ColumnA(schema.clone()) } - /// Return appropriate expr depending if COUNT is for col or table (*) - pub(crate) fn count_expr(&self) -> Arc { - Arc::new(Count::new( - self.column(), + // Return appropriate expr depending if COUNT is for col or table (*) + pub(crate) fn count_expr(&self, schema: &Schema) -> Arc { + create_aggregate_expr( + &count_udaf(), + &[self.column()], + &[], + &[], + schema, self.column_name(), - DataType::Int64, - )) + false, + false, + ) + .unwrap() } /// what argument would this aggregate need in the plan? @@ -458,7 +433,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -467,7 +442,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -488,7 +463,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -497,7 +472,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -517,7 +492,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -529,7 +504,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -549,7 +524,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -561,7 +536,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -592,7 +567,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], filter, Arc::clone(&schema), @@ -601,7 +576,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -637,7 +612,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], filter, Arc::clone(&schema), @@ -646,7 +621,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 3ad61e52c82e..38b92959e841 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -206,8 +206,9 @@ mod tests { use crate::physical_plan::{displayable, Partitioning}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; - use datafusion_physical_expr::expressions::{col, Count}; + use datafusion_physical_expr::expressions::col; use datafusion_physical_plan::udaf::create_aggregate_expr; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected @@ -303,15 +304,31 @@ mod tests { ) } + // Return appropriate expr depending if COUNT is for col or table (*) + fn count_expr( + expr: Arc, + name: &str, + schema: &Schema, + ) -> Arc { + create_aggregate_expr( + &count_udaf(), + &[expr], + &[], + &[], + schema, + name, + false, + false, + ) + .unwrap() + } + #[test] fn aggregations_not_combined() -> Result<()> { let schema = schema(); - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + let plan = final_aggregate_exec( repartition_exec(partial_aggregate_exec( parquet_exec(&schema), @@ -330,16 +347,8 @@ mod tests { ]; assert_optimized!(expected, plan); - let aggr_expr1 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; - let aggr_expr2 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(2)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( @@ -365,11 +374,7 @@ mod tests { #[test] fn aggregations_combined() -> Result<()> { let schema = schema(); - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 1274fbe50a5f..f9d5a4c186ee 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -517,10 +517,10 @@ mod tests { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![agg.count_expr(&schema)], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -554,10 +554,10 @@ mod tests { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![agg.count_expr(&schema)], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5895c39a5f87..154e77cd23ae 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -43,7 +43,8 @@ use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_plan::displayable; @@ -240,7 +241,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 79033643cf37..4f9187595018 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2181,7 +2181,6 @@ impl DefaultPhysicalPlanner { expr: &[Expr], ) -> Result> { let input_schema = input.as_ref().schema(); - let physical_exprs = expr .iter() .map(|e| { diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 8c9cffcf08d1..068383b20031 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -35,6 +35,7 @@ use datafusion::scalar::ScalarValue; use datafusion_common::cast::as_primitive_array; use datafusion_common::{internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; +use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index befd98d04302..fa364c5f2a65 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -31,6 +31,7 @@ use arrow::{ }; use arrow_array::Float32Array; use arrow_schema::ArrowError; +use datafusion_functions_aggregate::count::count_udaf; use object_store::local::LocalFileSystem; use std::fs; use std::sync::Arc; @@ -51,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr, ExprSchemable, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, + placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::sum; +use datafusion_functions_aggregate::expr_fn::{count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -178,7 +179,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index b85f6376c3f2..4358691ee5a5 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -38,6 +38,7 @@ use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -165,7 +166,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), // its name "COUNT", // window function argument @@ -350,7 +351,7 @@ fn get_random_function( window_fn_map.insert( "count", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![arg.clone()], ), ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 98ab8ec251f4..57f5414c13bd 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1861,6 +1861,7 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, &fun.to_string(), false, args)?; + if let Some(nt) = null_treatment { w.write_str(" ")?; write!(w, "{}", nt)?; @@ -1885,7 +1886,6 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, func_def.name(), *distinct, args)?; - if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 420312050870..1fafc63e9665 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -193,6 +193,7 @@ pub fn avg(expr: Expr) -> Expr { } /// Create an expression to represent the count() aggregate function +// TODO: Remove this and use `expr_fn::count` instead pub fn count(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, @@ -250,6 +251,7 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { } /// Create an expression to represent the count(distinct) aggregate function +// TODO: Remove this and use `expr_fn::count_distinct` instead pub fn count_distinct(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs new file mode 100644 index 000000000000..cfd56619537b --- /dev/null +++ b/datafusion/functions-aggregate/src/count.rs @@ -0,0 +1,562 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ahash::RandomState; +use std::collections::HashSet; +use std::ops::BitAnd; +use std::{fmt::Debug, sync::Arc}; + +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{ + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, +}; + +use arrow::{ + array::{Array, BooleanArray, Int64Array, PrimitiveArray}, + buffer::BooleanBuffer, +}; +use datafusion_common::{ + downcast_value, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, + EmitTo, GroupsAccumulator, Signature, Volatility, +}; +use datafusion_expr::{Expr, ReversedUDAF}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_physical_expr_common::{ + aggregate::count_distinct::{ + BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, + PrimitiveDistinctCountAccumulator, + }, + binary_map::OutputType, +}; + +make_udaf_expr_and_func!( + Count, + count, + expr, + "Count the number of non-null values in the column", + count_udaf +); + +pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new_udf( + count_udaf(), + vec![expr], + true, + None, + None, + None, + ), + ) +} + +pub struct Count { + signature: Signature, + aliases: Vec, +} + +impl Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "COUNT" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + Ok(vec![Field::new_list( + format_state_name(args.name, "count distinct"), + Field::new("item", args.input_type.clone(), true), + false, + )]) + } else { + Ok(vec![Field::new( + format_state_name(args.name, "count"), + DataType::Int64, + true, + )]) + } + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if !acc_args.is_distinct { + return Ok(Box::new(CountAccumulator::new())); + } + + let data_type = acc_args.input_type; + Ok(match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt8 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal128Type, + >::new(data_type)), + DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal256Type, + >::new(data_type)), + + DataType::Date32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Date64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float32 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float64 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + + DataType::Utf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::LargeUtf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + + // Use the generic accumulator based on `ScalarValue` for all other types + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: data_type.clone(), + }), + }) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + // groups accumulator only supports `COUNT(c1)`, not + // `COUNT(c1, c2)`, etc + if args.is_distinct { + return false; + } + args.args_num == 1 + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + // instantiate specialized accumulator + Ok(Box::new(CountGroupsAccumulator::new())) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +#[derive(Debug)] +struct CountAccumulator { + count: i64, +} + +impl CountAccumulator { + /// new count accumulator + pub fn new() -> Self { + Self { count: 0 } + } +} + +impl Accumulator for CountAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Int64(Some(self.count))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], Int64Array); + let delta = &arrow::compute::sum(counts); + if let Some(d) = delta { + self.count += *d; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.count))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +/// An accumulator to compute the counts of [`PrimitiveArray`]. +/// Stores values as native types, and does overflow checking +/// +/// Unlike most other accumulators, COUNT never produces NULLs. If no +/// non-null values are seen in any group the output is 0. Thus, this +/// accumulator has no additional null or seen filter tracking. +#[derive(Debug)] +struct CountGroupsAccumulator { + /// Count per group. + /// + /// Note this is an i64 and not a u64 (or usize) because the + /// output type of count is `DataType::Int64`. Thus by using `i64` + /// for the counts, the output [`Int64Array`] can be created + /// without copy. + counts: Vec, +} + +impl CountGroupsAccumulator { + pub fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = &values[0]; + + // Add one to each group's counter for each non null, non + // filtered value + self.counts.resize(total_num_groups, 0); + accumulate_indices( + group_indices, + values.logical_nulls().as_ref(), + opt_filter, + |group_index| { + self.counts[group_index] += 1; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values[0].as_primitive::(); + + // intermediate counts are always created as non null + assert_eq!(partial_counts.null_count(), 0); + let partial_counts = partial_counts.values(); + + // Adds the counts with the partial counts + self.counts.resize(total_num_groups, 0); + match opt_filter { + Some(filter) => filter + .iter() + .zip(group_indices.iter()) + .zip(partial_counts.iter()) + .for_each(|((filter_value, &group_index), partial_count)| { + if let Some(true) = filter_value { + self.counts[group_index] += partial_count; + } + }), + None => group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ), + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + + // Count is always non null (null inputs just don't contribute to the overall values) + let nulls = None; + let array = PrimitiveArray::::new(counts.into(), nulls); + + Ok(Arc::new(array)) + } + + // return arrays for counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let counts = emit_to.take_needed(&mut self.counts); + let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls + Ok(vec![Arc::new(counts) as ArrayRef]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + } +} + +/// count null values for multiple columns +/// for each row if one column value is null, then null_count + 1 +fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { + if values.len() > 1 { + let result_bool_buf: Option = values + .iter() + .map(|a| a.logical_nulls()) + .fold(None, |acc, b| match (acc, b) { + (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), + (Some(acc), None) => Some(acc), + (None, Some(b)) => Some(b.into_inner()), + _ => None, + }); + result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) + } else { + values[0] + .logical_nulls() + .map_or(0, |nulls| nulls.null_count()) + } +} + +/// General purpose distinct accumulator that works for any DataType by using +/// [`ScalarValue`]. +/// +/// It stores intermediate results as a `ListArray` +/// +/// Note that many types have specialized accumulators that are (much) +/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and +/// [`BytesDistinctCountAccumulator`] +#[derive(Debug)] +struct DistinctCountAccumulator { + values: HashSet, + state_data_type: DataType, +} + +impl DistinctCountAccumulator { + // calculating the size for fixed length values, taking first batch size * + // number of batches This method is faster than .full_size(), however it is + // not suitable for variable length values like strings or complex types + fn fixed_size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .next() + .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .unwrap_or(0) + + std::mem::size_of::() + } + + // calculates the size as accurately as possible. Note that calling this + // method is expensive + fn full_size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .sum::() + + std::mem::size_of::() + } +} + +impl Accumulator for DistinctCountAccumulator { + /// Returns the distinct values seen so far as (one element) ListArray. + fn state(&mut self) -> Result> { + let scalars = self.values.iter().cloned().collect::>(); + let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = &values[0]; + if arr.data_type() == &DataType::Null { + return Ok(()); + } + + (0..arr.len()).try_for_each(|index| { + if !arr.is_null(index) { + let scalar = ScalarValue::try_from_array(arr, index)?; + self.values.insert(scalar); + } + Ok(()) + }) + } + + /// Merges multiple sets of distinct values into the current set. + /// + /// The input to this function is a `ListArray` with **multiple** rows, + /// where each row contains the values from a partial aggregate's phase (e.g. + /// the result of calling `Self::state` on multiple accumulators). + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!(states.len(), 1, "array_agg states must be singleton!"); + let array = &states[0]; + let list_array = array.as_list::(); + for inner_array in list_array.iter() { + let Some(inner_array) = inner_array else { + return internal_err!( + "Intermediate results of COUNT DISTINCT should always be non null" + ); + }; + self.update_batch(&[inner_array])?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + match &self.state_data_type { + DataType::Boolean | DataType::Null => self.fixed_size(), + d if d.is_primitive() => self.fixed_size(), + _ => self.full_size(), + } + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 2d062cf2cb9b..56fc1305bb59 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -56,6 +56,7 @@ pub mod macros; pub mod approx_distinct; +pub mod count; pub mod covariance; pub mod first_last; pub mod hyperloglog; @@ -77,6 +78,8 @@ use std::sync::Arc; pub mod expr_fn { pub use super::approx_distinct; pub use super::approx_median::approx_median; + pub use super::count::count; + pub use super::count::count_distinct; pub use super::covariance::covar_pop; pub use super::covariance::covar_samp; pub use super::first_last::first_value; @@ -98,6 +101,7 @@ pub fn all_default_aggregate_functions() -> Vec> { sum::sum_udaf(), covariance::covar_pop_udaf(), median::median_udaf(), + count::count_udaf(), variance::var_samp_udaf(), variance::var_pop_udaf(), stddev::stddev_udaf(), @@ -133,8 +137,8 @@ mod tests { let mut names = HashSet::new(); for func in all_default_aggregate_functions() { // TODO: remove this - // sum is in intermidiate migration state, skip this - if func.name().to_lowercase() == "sum" { + // These functions are in intermidiate migration state, skip them + if func.name().to_lowercase() == "count" { continue; } assert!( diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b55b1a7f8f2d..e14ee763a3c0 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -441,8 +441,14 @@ fn agg_exprs_evaluation_result_on_empty_batch( Transformed::yes(Expr::Literal(ScalarValue::Null)) } } - AggregateFunctionDefinition::UDF { .. } => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + AggregateFunctionDefinition::UDF(fun) => { + if fun.name() == "COUNT" { + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::yes(Expr::Literal(ScalarValue::Null)) + } } }, _ => Transformed::no(expr), diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 32b6703bcae5..e738209eb4fd 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -361,8 +361,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr; - use datafusion_expr::expr::GroupingSet; + use datafusion_expr::expr::{self, GroupingSet}; use datafusion_expr::test::function_stub::{sum, sum_udaf}; use datafusion_expr::{ count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index 637b8775112e..3ef2d5345533 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -36,7 +36,9 @@ name = "datafusion_physical_expr_common" path = "src/lib.rs" [dependencies] +ahash = { workspace = true } arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +hashbrown = { workspace = true } rand = { workspace = true } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs similarity index 93% rename from datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs rename to datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs index 2ed9b002c841..5c888ca66caa 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs @@ -18,7 +18,7 @@ //! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values use crate::binary_map::{ArrowBytesSet, OutputType}; -use arrow_array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; use datafusion_common::ScalarValue; @@ -35,10 +35,10 @@ use std::sync::Arc; /// [`BinaryArray`]: arrow::array::BinaryArray /// [`LargeBinaryArray`]: arrow::array::LargeBinaryArray #[derive(Debug)] -pub(super) struct BytesDistinctCountAccumulator(ArrowBytesSet); +pub struct BytesDistinctCountAccumulator(ArrowBytesSet); impl BytesDistinctCountAccumulator { - pub(super) fn new(output_type: OutputType) -> Self { + pub fn new(output_type: OutputType) -> Self { Self(ArrowBytesSet::new(output_type)) } } diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs new file mode 100644 index 000000000000..f216406d0dd7 --- /dev/null +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bytes; +mod native; + +pub use bytes::BytesDistinctCountAccumulator; +pub use native::FloatDistinctCountAccumulator; +pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs similarity index 93% rename from datafusion/physical-expr/src/aggregate/count_distinct/native.rs rename to datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs index 0e7483d4a1cd..72b83676e81d 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs @@ -26,10 +26,10 @@ use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; +use arrow::array::types::ArrowPrimitiveType; use arrow::array::ArrayRef; -use arrow_array::types::ArrowPrimitiveType; -use arrow_array::PrimitiveArray; -use arrow_schema::DataType; +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::utils::array_into_list_array; @@ -40,7 +40,7 @@ use datafusion_expr::Accumulator; use crate::aggregate::utils::Hashable; #[derive(Debug)] -pub(super) struct PrimitiveDistinctCountAccumulator +pub struct PrimitiveDistinctCountAccumulator where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, @@ -54,7 +54,7 @@ where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, { - pub(super) fn new(data_type: &DataType) -> Self { + pub fn new(data_type: &DataType) -> Self { Self { values: HashSet::default(), data_type: data_type.clone(), @@ -125,7 +125,7 @@ where } #[derive(Debug)] -pub(super) struct FloatDistinctCountAccumulator +pub struct FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { @@ -136,13 +136,22 @@ impl FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { - pub(super) fn new() -> Self { + pub fn new() -> Self { Self { values: HashSet::default(), } } } +impl Default for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn default() -> Self { + Self::new() + } +} + impl Accumulator for FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send + Debug, diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index ec02df57b82d..21884f840dbd 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod count_distinct; pub mod groups_accumulator; pub mod stats; pub mod tdigest; diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs similarity index 98% rename from datafusion/physical-expr/src/binary_map.rs rename to datafusion/physical-expr-common/src/binary_map.rs index 0923fcdaeb91..6d5ba737a1df 100644 --- a/datafusion/physical-expr/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -19,17 +19,16 @@ //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray. use ahash::RandomState; -use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; -use arrow_array::{ - Array, ArrayRef, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, +use arrow::array::cast::AsArray; +use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; +use arrow::array::{ + Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }; -use arrow_buffer::{ - BooleanBufferBuilder, BufferBuilder, NullBuffer, OffsetBuffer, ScalarBuffer, -}; -use arrow_schema::DataType; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::DataType; use datafusion_common::hash_utils::create_hashes; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; use std::mem; @@ -605,8 +604,8 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::{BinaryArray, LargeBinaryArray, StringArray}; - use hashbrown::HashMap; + use arrow::array::{BinaryArray, LargeBinaryArray, StringArray}; + use std::collections::HashMap; #[test] fn string_set_empty() { diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index f335958698ab..0ddb84141a07 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -16,6 +16,7 @@ // under the License. pub mod aggregate; +pub mod binary_map; pub mod expressions; pub mod physical_expr; pub mod sort_expr; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ac24dd2e7603..aee7bca3b88f 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,12 +30,13 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; +use datafusion_expr::AggregateFunction; + use crate::aggregate::average::Avg; use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::AggregateFunction; /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. pub fn create_aggregate_expr( @@ -60,14 +61,9 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Count, false) => Arc::new( - expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type), - ), - (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - data_type, - input_phy_exprs[0].clone(), - name, - )), + (AggregateFunction::Count, _) => { + return internal_err!("Builtin Count will be removed"); + } (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, @@ -320,7 +316,7 @@ mod tests { use super::*; use crate::expressions::{ try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, - BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min, + BoolOr, DistinctArrayAgg, Max, Min, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -328,8 +324,8 @@ mod tests { use datafusion_expr::{type_coercion, Signature}; #[test] - fn test_count_arragg_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Count, AggregateFunction::ArrayAgg]; + fn test_approx_expr() -> Result<()> { + let funcs = vec![AggregateFunction::ArrayAgg]; let data_types = vec![ DataType::UInt32, DataType::Int32, @@ -352,29 +348,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_agg_phy_exprs.field().unwrap() + ); + } let result_distinct = create_physical_agg_expr_for_test( &fun, @@ -383,29 +368,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_distinct.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_distinct.as_any().is::()); + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_agg_phy_exprs.field().unwrap() + ); + } } } Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs deleted file mode 100644 index aad18a82ab87..000000000000 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ /dev/null @@ -1,348 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::fmt::Debug; -use std::ops::BitAnd; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, Int64Array}; -use arrow::compute; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_array::cast::AsArray; -use arrow_array::types::Int64Type; -use arrow_array::PrimitiveArray; -use arrow_buffer::BooleanBuffer; -use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; - -use crate::expressions::format_state_name; - -use super::groups_accumulator::accumulate::accumulate_indices; - -/// COUNT aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug, Clone)] -pub struct Count { - name: String, - data_type: DataType, - nullable: bool, - /// Input exprs - /// - /// For `COUNT(c1)` this is `[c1]` - /// For `COUNT(c1, c2)` this is `[c1, c2]` - exprs: Vec>, -} - -impl Count { - /// Create a new COUNT aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs: vec![expr], - data_type, - nullable: true, - } - } - - pub fn new_with_multiple_exprs( - exprs: Vec>, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs, - data_type, - nullable: true, - } - } -} - -/// An accumulator to compute the counts of [`PrimitiveArray`]. -/// Stores values as native types, and does overflow checking -/// -/// Unlike most other accumulators, COUNT never produces NULLs. If no -/// non-null values are seen in any group the output is 0. Thus, this -/// accumulator has no additional null or seen filter tracking. -#[derive(Debug)] -struct CountGroupsAccumulator { - /// Count per group. - /// - /// Note this is an i64 and not a u64 (or usize) because the - /// output type of count is `DataType::Int64`. Thus by using `i64` - /// for the counts, the output [`Int64Array`] can be created - /// without copy. - counts: Vec, -} - -impl CountGroupsAccumulator { - pub fn new() -> Self { - Self { counts: vec![] } - } -} - -impl GroupsAccumulator for CountGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &values[0]; - - // Add one to each group's counter for each non null, non - // filtered value - self.counts.resize(total_num_groups, 0); - accumulate_indices( - group_indices, - values.logical_nulls().as_ref(), - opt_filter, - |group_index| { - self.counts[group_index] += 1; - }, - ); - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "one argument to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - - // intermediate counts are always created as non null - assert_eq!(partial_counts.null_count(), 0); - let partial_counts = partial_counts.values(); - - // Adds the counts with the partial counts - self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - - // Count is always non null (null inputs just don't contribute to the overall values) - let nulls = None; - let array = PrimitiveArray::::new(counts.into(), nulls); - - Ok(Arc::new(array)) - } - - // return arrays for counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls - Ok(vec![Arc::new(counts) as ArrayRef]) - } - - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - } -} - -/// count null values for multiple columns -/// for each row if one column value is null, then null_count + 1 -fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { - if values.len() > 1 { - let result_bool_buf: Option = values - .iter() - .map(|a| a.logical_nulls()) - .fold(None, |acc, b| match (acc, b) { - (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), - (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.into_inner()), - _ => None, - }); - result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) - } else { - values[0] - .logical_nulls() - .map_or(0, |nulls| nulls.null_count()) - } -} - -impl AggregateExpr for Count { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "count"), - DataType::Int64, - true, - )]) - } - - fn expressions(&self) -> Vec> { - self.exprs.clone() - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - // groups accumulator only supports `COUNT(c1)`, not - // `COUNT(c1, c2)`, etc - self.exprs.len() == 1 - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn create_groups_accumulator(&self) -> Result> { - // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) - } - - fn with_new_expressions( - &self, - args: Vec>, - order_by_exprs: Vec>, - ) -> Option> { - debug_assert_eq!(self.exprs.len(), args.len()); - debug_assert!(order_by_exprs.is_empty()); - Some(Arc::new(Count { - name: self.name.clone(), - data_type: self.data_type.clone(), - nullable: self.nullable, - exprs: args, - })) - } -} - -impl PartialEq for Count { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.exprs.len() == x.exprs.len() - && self - .exprs - .iter() - .zip(x.exprs.iter()) - .all(|(expr1, expr2)| expr1.eq(expr2)) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct CountAccumulator { - count: i64, -} - -impl CountAccumulator { - /// new count accumulator - pub fn new() -> Self { - Self { count: 0 } - } -} - -impl Accumulator for CountAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(self.count))]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], Int64Array); - let delta = &compute::sum(counts); - if let Some(d) = delta { - self.count += *d; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.count))) - } - - fn supports_retract_batch(&self) -> bool { - true - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs deleted file mode 100644 index 52f1c5c0f9a0..000000000000 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ /dev/null @@ -1,718 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod bytes; -mod native; - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use ahash::RandomState; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field, TimeUnit}; -use arrow_array::cast::AsArray; -use arrow_array::types::{ - Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; - -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator; -use crate::aggregate::count_distinct::native::{ - FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, -}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::binary_map::OutputType; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// Expression for a `COUNT(DISTINCT)` aggregation. -#[derive(Debug)] -pub struct DistinctCount { - /// Column name - name: String, - /// The DataType used to hold the state for each input - state_data_type: DataType, - /// The input arguments - expr: Arc, -} - -impl DistinctCount { - /// Create a new COUNT(DISTINCT) aggregate function. - pub fn new( - input_data_type: DataType, - expr: Arc, - name: impl Into, - ) -> Self { - Self { - name: name.into(), - state_data_type: input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctCount { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, true)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "count distinct"), - Field::new("item", self.state_data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - use DataType::*; - use TimeUnit::*; - - let data_type = &self.state_data_type; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - Date32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Date64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32MillisecondType, - >::new(data_type)), - Time32(Second) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32SecondType, - >::new(data_type)), - Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64MicrosecondType, - >::new(data_type)), - Time64(Nanosecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64NanosecondType, - >::new(data_type)), - Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMicrosecondType, - >::new(data_type)), - Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMillisecondType, - >::new(data_type)), - Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampNanosecondType, - >::new(data_type)), - Timestamp(Second, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampSecondType, - >::new(data_type)), - - Float16 => Box::new(FloatDistinctCountAccumulator::::new()), - Float32 => Box::new(FloatDistinctCountAccumulator::::new()), - Float64 => Box::new(FloatDistinctCountAccumulator::::new()), - - Utf8 => Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)), - LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - }), - }) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctCount { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.state_data_type == x.state_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. -/// -/// It stores intermediate results as a `ListArray` -/// -/// Note that many types have specialized accumulators that are (much) -/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and -/// [`BytesDistinctCountAccumulator`] -#[derive(Debug)] -struct DistinctCountAccumulator { - values: HashSet, - state_data_type: DataType, -} - -impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * - // number of batches This method is faster than .full_size(), however it is - // not suitable for variable length values like strings or complex types - fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .unwrap_or(0) - + std::mem::size_of::() - } - - // calculates the size as accurately as possible. Note that calling this - // method is expensive - fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .sum::() - + std::mem::size_of::() - } -} - -impl Accumulator for DistinctCountAccumulator { - /// Returns the distinct values seen so far as (one element) ListArray. - fn state(&mut self) -> Result> { - let scalars = self.values.iter().cloned().collect::>(); - let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let arr = &values[0]; - if arr.data_type() == &DataType::Null { - return Ok(()); - } - - (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.insert(scalar); - } - Ok(()) - }) - } - - /// Merges multiple sets of distinct values into the current set. - /// - /// The input to this function is a `ListArray` with **multiple** rows, - /// where each row contains the values from a partial aggregate's phase (e.g. - /// the result of calling `Self::state` on multiple accumulators). - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let array = &states[0]; - let list_array = array.as_list::(); - for inner_array in list_array.iter() { - let Some(inner_array) = inner_array else { - return internal_err!( - "Intermediate results of COUNT DISTINCT should always be non null" - ); - }; - self.update_batch(&[inner_array])?; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.values.len() as i64))) - } - - fn size(&self) -> usize { - match &self.state_data_type { - DataType::Boolean | DataType::Null => self.fixed_size(), - d if d.is_primitive() => self.fixed_size(), - _ => self.full_size(), - } - } -} - -#[cfg(test)] -mod tests { - use arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }; - use arrow_array::Decimal256Array; - use arrow_buffer::i256; - - use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - - use crate::expressions::NoOp; - - use super::*; - - macro_rules! state_to_vec_primitive { - ($LIST:expr, $DATA_TYPE:ident) => {{ - let arr = ScalarValue::raw_data($LIST).unwrap(); - let list_arr = as_list_array(&arr).unwrap(); - let arr = list_arr.values(); - let arr = as_primitive_array::<$DATA_TYPE>(arr)?; - arr.values().iter().cloned().collect::>() - }}; - } - - macro_rules! test_count_distinct_update_batch_numeric { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(1), - Some(1), - None, - Some(3), - Some(2), - None, - Some(2), - Some(3), - Some(1), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![1, 2, 3]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - fn state_to_vec_bool(sv: &ScalarValue) -> Result> { - let arr = ScalarValue::raw_data(sv)?; - let list_arr = as_list_array(&arr)?; - let arr = list_arr.values(); - let bool_arr = as_boolean_array(arr)?; - Ok(bool_arr.iter().flatten().collect()) - } - - fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - arrays[0].data_type().clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - accum.update_batch(arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - fn run_update( - data_types: &[DataType], - rows: &[Vec], - ) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - data_types[0].clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - - let cols = (0..rows[0].len()) - .map(|i| { - rows.iter() - .map(|inner| inner[i].clone()) - .collect::>() - }) - .collect::>(); - - let arrays: Vec = cols - .iter() - .map(|c| ScalarValue::iter_to_array(c.clone())) - .collect::>>()?; - - accum.update_batch(&arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - // Used trait to create associated constant for f32 and f64 - trait SubNormal: 'static { - const SUBNORMAL: Self; - } - - impl SubNormal for f64 { - const SUBNORMAL: Self = 1.0e-308_f64; - } - - impl SubNormal for f32 { - const SUBNORMAL: Self = 1.0e-38_f32; - } - - macro_rules! test_count_distinct_update_batch_floating_point { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(<$PRIM_TYPE>::INFINITY), - Some(<$PRIM_TYPE>::NAN), - Some(1.0), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(<$PRIM_TYPE>::INFINITY), - None, - Some(3.0), - Some(-4.5), - Some(2.0), - None, - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(1.0), - Some(<$PRIM_TYPE>::NAN), - Some(<$PRIM_TYPE>::NEG_INFINITY), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - - dbg!(&state_vec); - state_vec.sort_by(|a, b| match (a, b) { - (lhs, rhs) => lhs.total_cmp(rhs), - }); - - let nan_idx = state_vec.len() - 1; - assert_eq!(states.len(), 1); - assert_eq!( - &state_vec[..nan_idx], - vec![ - <$PRIM_TYPE>::NEG_INFINITY, - -4.5, - <$PRIM_TYPE as SubNormal>::SUBNORMAL, - 1.0, - 2.0, - 3.0, - <$PRIM_TYPE>::INFINITY - ] - ); - assert!(state_vec[nan_idx].is_nan()); - assert_eq!(result, ScalarValue::Int64(Some(8))); - - Ok(()) - }}; - } - - macro_rules! test_count_distinct_update_batch_bigint { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(i256::from(1)), - Some(i256::from(1)), - None, - Some(i256::from(3)), - Some(i256::from(2)), - None, - Some(i256::from(2)), - Some(i256::from(3)), - Some(i256::from(1)), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - #[test] - fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) - } - - #[test] - fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) - } - - #[test] - fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) - } - - #[test] - fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) - } - - #[test] - fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) - } - - #[test] - fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) - } - - #[test] - fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) - } - - #[test] - fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) - } - - #[test] - fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) - } - - #[test] - fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) - } - - #[test] - fn count_distinct_update_batch_i256() -> Result<()> { - test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) - } - - #[test] - fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { - let arrays = vec![Arc::new(data) as ArrayRef]; - let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec_bool(&states[0])?; - state_vec.sort(); - - let count = match result { - ScalarValue::Int64(c) => c.ok_or_else(|| { - DataFusionError::Internal("Found None count".to_string()) - }), - scalar => { - internal_err!("Found non int64 scalar value from count: {scalar}") - } - }?; - Ok((state_vec, count)) - }; - - let zero_count_values = BooleanArray::from(Vec::::new()); - - let one_count_values = BooleanArray::from(vec![false, false]); - let one_count_values_with_null = - BooleanArray::from(vec![Some(true), Some(true), None, None]); - - let two_count_values = BooleanArray::from(vec![true, false, true, false, true]); - let two_count_values_with_null = BooleanArray::from(vec![ - Some(true), - Some(false), - None, - None, - Some(true), - Some(false), - ]); - - assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); - assert_eq!(get_count(one_count_values)?, (vec![false], 1)); - assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); - assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); - assert_eq!( - get_count(two_count_values_with_null)?, - (vec![false, true], 2) - ); - Ok(()) - } - - #[test] - fn count_distinct_update_batch_all_nulls() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from( - vec![None, None, None, None] as Vec> - )) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update_batch_empty() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - Ok(()) - } - - #[test] - fn count_distinct_update_with_nulls() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(None)], - vec![ScalarValue::Int32(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(None)], - vec![ScalarValue::UInt64(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index 65227b727be7..a6946e739c97 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -20,7 +20,7 @@ pub use adapter::GroupsAccumulatorAdapter; // Backward compatibility pub(crate) mod accumulate { - pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices, NullState}; + pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; } pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 7a6c5f9d0e24..01105c8559c9 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -26,8 +26,6 @@ pub(crate) mod average; pub(crate) mod bit_and_or_xor; pub(crate) mod bool_and_or; pub(crate) mod correlation; -pub(crate) mod count; -pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod grouping; pub(crate) mod nth_value; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index a96d02173018..123ada6d7c86 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -47,8 +47,6 @@ pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; -pub use crate::aggregate::count::Count; -pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 72f5f2d50cb8..b764e81a95d1 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,7 +17,9 @@ pub mod aggregate; pub mod analysis; -pub mod binary_map; +pub mod binary_map { + pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; +} pub mod equivalence; pub mod expressions; pub mod functions; diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index d073c8995a9b..f789af8b8a02 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -18,7 +18,7 @@ use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; -use datafusion_physical_expr::binary_map::{ArrowBytesMap, OutputType}; +use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 79abbdb52ca2..b6fc70be7cbc 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1194,12 +1194,14 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::{ - lit, Count, FirstValue, LastValue, OrderSensitiveArrayAgg, + lit, FirstValue, LastValue, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1334,11 +1336,16 @@ mod tests { ], }; - let aggregates: Vec> = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - ))]; + let aggregates = vec![create_aggregate_expr( + &count_udaf(), + &[lit(1i8)], + &[], + &[], + &input_schema, + "COUNT(1)", + false, + false, + )?]; let task_ctx = if spill { new_spill_ctx(4, 1000) diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 48f1bee59bbf..56d780e51394 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1194,9 +1194,9 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; use datafusion_physical_expr::window::{ BuiltInWindowExpr, BuiltInWindowFunctionExpr, @@ -1298,8 +1298,7 @@ mod tests { order_by: &str, ) -> Result> { let schema = input.schema(); - let window_fn = - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count); + let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 9b392d941ef4..63ce473fc57e 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -597,7 +597,6 @@ pub fn get_window_mode( #[cfg(test)] mod tests { use super::*; - use crate::aggregates::AggregateFunction; use crate::collect; use crate::expressions::col; use crate::streaming::StreamingTableExec; @@ -607,6 +606,7 @@ mod tests { use arrow::compute::SortOptions; use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; use futures::FutureExt; use InputOrderMode::{Linear, PartiallySorted, Sorted}; @@ -749,7 +749,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 97568fb5f678..66ce7cbd838f 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 49884c48b3cc..9f8f03de6dc9 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 358ba7e3eb94..b1897aa58e7d 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index b6993f6c040b..eabaf7ba8e14 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index b401ff8810db..2bb3ec793d7f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -520,6 +520,7 @@ message AggregateExprNode { message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + bool distinct = 5; LogicalExprNode filter = 3; repeated LogicalExprNode order_by = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d6632c77d8da..59b7861a6ef1 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -886,6 +886,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { len += 1; } + if self.distinct { + len += 1; + } if self.filter.is_some() { len += 1; } @@ -899,6 +902,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; } @@ -918,6 +924,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "fun_name", "funName", "args", + "distinct", "filter", "order_by", "orderBy", @@ -927,6 +934,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { enum GeneratedField { FunName, Args, + Distinct, Filter, OrderBy, } @@ -952,6 +960,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { match value { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), + "distinct" => Ok(GeneratedField::Distinct), "filter" => Ok(GeneratedField::Filter), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -975,6 +984,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { { let mut fun_name__ = None; let mut args__ = None; + let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; while let Some(k) = map_.next_key()? { @@ -991,6 +1001,12 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { } args__ = Some(map_.next_value()?); } + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + distinct__ = Some(map_.next_value()?); + } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); @@ -1008,6 +1024,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { Ok(AggregateUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), + distinct: distinct__.unwrap_or_default(), filter: filter__, order_by: order_by__.unwrap_or_default(), }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0aca5ef1ffb8..0861c287fcfa 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -767,6 +767,8 @@ pub struct AggregateUdfExprNode { pub fun_name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "5")] + pub distinct: bool, #[prost(message, optional, boxed, tag = "3")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "4")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3ad5973380ed..2ad40d883fe6 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -642,7 +642,7 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, parse_exprs(&pb.args, registry, codec)?, - false, + pb.distinct, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, None, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d42470f198e3..6a275ed7a1b8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -456,6 +456,7 @@ pub fn serialize_expr( protobuf::AggregateUdfExprNode { fun_name: fun.name().to_string(), args: serialize_exprs(args, codec)?, + distinct: *distinct, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), None => None, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5258bdd11d86..e25447b023d8 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,10 +25,10 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, - RowNumber, StringAgg, TryCastExpr, WindowShift, + CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, + OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, StringAgg, + TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -240,12 +240,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); let mut distinct = false; - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::BitAnd diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 699697dd2f2c..d9736da69d42 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -649,6 +649,8 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + count(lit(1)), + count_distinct(lit(1)), first_value(lit(1), None), first_value(lit(1), Some(vec![lit(2).sort(true, true)])), covar_samp(lit(1.5), lit(2.2)), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9cf686dbd3d6..e517482f1db0 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -38,7 +38,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg}; +use datafusion::physical_expr::expressions::{Max, NthValueAgg}; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -47,8 +47,8 @@ use datafusion::physical_plan::aggregates::{ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ - binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, - NotExpr, NthValue, PhysicalSortExpr, StringAgg, + binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, NthValue, + PhysicalSortExpr, StringAgg, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::DataSinkExec; @@ -806,7 +806,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))], + vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], vec![None], window, schema.clone(), @@ -818,31 +818,6 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Ok(()) } -#[test] -fn roundtrip_distinct_count() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let aggregates: Vec> = vec![Arc::new(DistinctCount::new( - DataType::Int64, - col("b", &schema)?, - "COUNT(DISTINCT b)".to_string(), - ))]; - - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "unused".to_string())]; - - roundtrip_test(Arc::new(AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new_single(groups), - aggregates.clone(), - vec![None], - Arc::new(EmptyExec::new(schema.clone())), - schema, - )?)) -} - #[test] fn roundtrip_like() -> Result<()> { let schema = Schema::new(vec![ diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index e930af107f77..c7b9808c249d 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -46,7 +46,7 @@ statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c SELECT CAST(c1 AS INT) FROM aggregate_test_100 # aggregation_with_bad_arguments -statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +query error SELECT COUNT(DISTINCT) FROM aggregate_test_100 # query_cte_incorrect @@ -104,7 +104,7 @@ SELECT power(1, 2, 3); # # AggregateFunction with wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tCOUNT\(Any, \.\., Any\) +query error select count(); # AggregateFunction with wrong number of arguments diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index ee96ffa67044..d934dba4cfea 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" [lints] workspace = true diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 3f9a895d951c..93f197885c0a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -982,18 +982,16 @@ pub async fn from_substrait_agg_func( let function_name = substrait_fun_name((**function_name).as_str()); // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { + // deal with situation that count(*) got no arguments + if fun.name() == "COUNT" && args.is_empty() { + args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); + } + Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { - match &fun { - // deal with situation that count(*) got no arguments - aggregate_function::AggregateFunction::Count if args.is_empty() => { - args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); - } - _ => {} - } Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None), )))