From 87baf8649f3a4dacad898be0dc1fc06b3bb76c5f Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Sat, 11 Jan 2025 09:29:18 +0100 Subject: [PATCH] feat: Allow more group_by agg expressions in the new streaming engine (#20663) Co-authored-by: ritchie --- crates/polars-mem-engine/src/planner/lp.rs | 6 +- .../src/physical_plan/lower_expr.rs | 34 +- .../src/physical_plan/lower_group_by.rs | 382 ++++++++++++++++++ .../src/physical_plan/lower_ir.rs | 87 +--- crates/polars-stream/src/physical_plan/mod.rs | 1 + py-polars/tests/unit/dataframe/test_df.py | 2 + 6 files changed, 434 insertions(+), 78 deletions(-) create mode 100644 crates/polars-stream/src/physical_plan/lower_group_by.rs diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 3158651aaa15..6981cc1d255f 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -25,7 +25,11 @@ fn partitionable_gb( // complex expressions in the group_by itself are also not partitionable // in this case anything more than col("foo") for key in keys { - if (expr_arena).iter(key.node()).count() > 1 { + if (expr_arena).iter(key.node()).count() > 1 + || has_aexpr(key.node(), expr_arena, |ae| { + matches!(ae, AExpr::Literal(LiteralValue::Series(_))) + }) + { return false; } } diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 3d9f87b14487..bb954021149f 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -19,9 +19,9 @@ use slotmap::SlotMap; use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream}; -type IRNodeKey = Node; +type ExprNodeKey = Node; -fn unique_column_name() -> PlSmallStr { +pub fn unique_column_name() -> PlSmallStr { static COUNTER: AtomicU64 = AtomicU64::new(0); let idx = COUNTER.fetch_add(1, Ordering::Relaxed); format_pl_smallstr!("__POLARS_STMP_{idx}") @@ -48,7 +48,7 @@ struct LowerExprContext<'a> { } pub(crate) fn is_elementwise_rec_cached( - expr_key: IRNodeKey, + expr_key: ExprNodeKey, arena: &Arena, cache: &mut ExprCache, ) -> bool { @@ -97,10 +97,10 @@ pub(crate) fn is_elementwise_rec_cached( } #[recursive::recursive] -fn is_input_independent_rec( - expr_key: IRNodeKey, +pub fn is_input_independent_rec( + expr_key: ExprNodeKey, arena: &Arena, - cache: &mut PlHashMap, + cache: &mut PlHashMap, ) -> bool { if let Some(ret) = cache.get(&expr_key) { return *ret; @@ -207,7 +207,15 @@ fn is_input_independent_rec( ret } -fn is_input_independent(expr_key: IRNodeKey, ctx: &mut LowerExprContext) -> bool { +pub fn is_input_independent( + expr_key: ExprNodeKey, + expr_arena: &Arena, + cache: &mut ExprCache, +) -> bool { + is_input_independent_rec(expr_key, expr_arena, &mut cache.is_input_independent) +} + +fn is_input_independent_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool { is_input_independent_rec( expr_key, ctx.expr_arena, @@ -359,7 +367,7 @@ fn lower_exprs_with_ctx( ) -> PolarsResult<(PhysStream, Vec)> { // We have to catch this case separately, in case all the input independent expressions are elementwise. // TODO: we shouldn't always do this when recursing, e.g. pl.col.a.sum() + 1 will still hit this in the recursion. - if exprs.iter().all(|e| is_input_independent(*e, ctx)) { + if exprs.iter().all(|e| is_input_independent_ctx(*e, ctx)) { let expr_irs = exprs .iter() .map(|e| ExprIR::new(*e, OutputName::Alias(unique_column_name()))) @@ -384,7 +392,7 @@ fn lower_exprs_with_ctx( for expr in exprs.iter().copied() { if is_elementwise_rec_cached(expr, ctx.expr_arena, ctx.cache) { - if !is_input_independent(expr, ctx) { + if !is_input_independent_ctx(expr, ctx) { input_streams.insert(input); } transformed_exprs.push(expr); @@ -679,7 +687,10 @@ fn build_select_stream_with_ctx( exprs: &[ExprIR], ctx: &mut LowerExprContext, ) -> PolarsResult { - if exprs.iter().all(|e| is_input_independent(e.node(), ctx)) { + if exprs + .iter() + .all(|e| is_input_independent_ctx(e.node(), ctx)) + { return Ok(PhysStream::first(build_input_independent_node_with_ctx( exprs, ctx, )?)); @@ -696,8 +707,7 @@ fn build_select_stream_with_ctx( if let Some(columns) = all_simple_columns { let input_schema = ctx.phys_sm[input.node].output_schema.clone(); - if !cfg!(debug_assertions) - && input_schema.len() == columns.len() + if input_schema.len() == columns.len() && input_schema.iter_names().zip(&columns).all(|(l, r)| l == r) { // Input node already has the correct schema, just pass through. diff --git a/crates/polars-stream/src/physical_plan/lower_group_by.rs b/crates/polars-stream/src/physical_plan/lower_group_by.rs new file mode 100644 index 000000000000..28e358bb63bf --- /dev/null +++ b/crates/polars-stream/src/physical_plan/lower_group_by.rs @@ -0,0 +1,382 @@ +use std::sync::Arc; + +use parking_lot::Mutex; +use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap}; +use polars_core::schema::Schema; +use polars_error::{polars_err, PolarsResult}; +use polars_expr::state::ExecutionState; +use polars_mem_engine::create_physical_plan; +use polars_plan::plans::expr_ir::{ExprIR, OutputName}; +use polars_plan::plans::{AExpr, ArenaExprIter, DataFrameUdf, IRAggExpr, IR}; +use polars_plan::prelude::GroupbyOptions; +use polars_utils::arena::{Arena, Node}; +use polars_utils::itertools::Itertools; +use polars_utils::pl_str::PlSmallStr; +use slotmap::SlotMap; + +use super::lower_expr::lower_exprs; +use super::{ExprCache, PhysNode, PhysNodeKey, PhysNodeKind, PhysStream}; +use crate::physical_plan::lower_expr::{ + build_select_stream, compute_output_schema, is_input_independent, unique_column_name, +}; +use crate::utils::late_materialized_df::LateMaterializedDataFrame; + +#[allow(clippy::too_many_arguments)] +fn build_group_by_fallback( + input: PhysStream, + keys: &[ExprIR], + aggs: &[ExprIR], + output_schema: Arc, + maintain_order: bool, + options: Arc, + apply: Option>, + expr_arena: &mut Arena, + phys_sm: &mut SlotMap, +) -> PolarsResult { + let input_schema = phys_sm[input.node].output_schema.clone(); + let lmdf = Arc::new(LateMaterializedDataFrame::default()); + let mut lp_arena = Arena::default(); + let input_lp_node = lp_arena.add(lmdf.clone().as_ir_node(input_schema.clone())); + let group_by_lp_node = lp_arena.add(IR::GroupBy { + input: input_lp_node, + keys: keys.to_vec(), + aggs: aggs.to_vec(), + schema: output_schema.clone(), + maintain_order, + options, + apply, + }); + let executor = Mutex::new(create_physical_plan( + group_by_lp_node, + &mut lp_arena, + expr_arena, + )?); + + let group_by_node = PhysNode { + output_schema, + kind: PhysNodeKind::InMemoryMap { + input, + map: Arc::new(move |df| { + lmdf.set_materialized_dataframe(df); + let mut state = ExecutionState::new(); + executor.lock().execute(&mut state) + }), + }, + }; + + Ok(PhysStream::first(phys_sm.insert(group_by_node))) +} + +/// Tries to lower an expression as a 'elementwise scalar agg expression'. +/// +/// Such an expression is defined as the elementwise combination of scalar +/// aggregations of elementwise combinations of the input columns / scalar literals. +fn try_lower_elementwise_scalar_agg_expr( + expr: Node, + inside_agg: bool, + outer_name: Option, + expr_arena: &mut Arena, + agg_exprs: &mut Vec, + trans_input_cols: &PlHashMap, +) -> Option { + // Helper macro to simplify recursive calls. + macro_rules! lower_rec { + ($input:expr, $inside_agg:expr) => { + try_lower_elementwise_scalar_agg_expr( + $input, + $inside_agg, + None, + expr_arena, + agg_exprs, + trans_input_cols, + ) + }; + } + + match expr_arena.get(expr) { + AExpr::Alias(..) => unreachable!("alias found in physical plan"), + + AExpr::Column(c) => { + if inside_agg { + Some(trans_input_cols[c]) + } else { + // Implicit implode not yet supported. + None + } + }, + + AExpr::Literal(lit) => { + if lit.is_scalar() { + Some(expr) + } else { + None + } + }, + + AExpr::Slice { .. } + | AExpr::Window { .. } + | AExpr::Sort { .. } + | AExpr::SortBy { .. } + | AExpr::Gather { .. } => None, + + // Explode and filter are row-separable and should thus in theory work + // in a streaming fashion but they change the length of the input which + // means the same filter/explode should also be applied to the key + // column, which is not (yet) supported. + AExpr::Explode(_) | AExpr::Filter { .. } => None, + + AExpr::BinaryExpr { left, op, right } => { + let (left, op, right) = (*left, *op, *right); + let left = lower_rec!(left, inside_agg)?; + let right = lower_rec!(right, inside_agg)?; + Some(expr_arena.add(AExpr::BinaryExpr { left, op, right })) + }, + + AExpr::Ternary { + predicate, + truthy, + falsy, + } => { + let (predicate, truthy, falsy) = (*predicate, *truthy, *falsy); + let predicate = lower_rec!(predicate, inside_agg)?; + let truthy = lower_rec!(truthy, inside_agg)?; + let falsy = lower_rec!(falsy, inside_agg)?; + Some(expr_arena.add(AExpr::Ternary { + predicate, + truthy, + falsy, + })) + }, + + node @ AExpr::Function { input, options, .. } + | node @ AExpr::AnonymousFunction { input, options, .. } + if options.is_elementwise() => + { + let node = node.clone(); + let input = input.clone(); + let new_inputs = input + .into_iter() + .map(|i| lower_rec!(i.node(), inside_agg)) + .collect::>>()?; + Some(expr_arena.add(node.replace_inputs(&new_inputs))) + }, + + AExpr::Function { .. } | AExpr::AnonymousFunction { .. } => None, + + AExpr::Cast { + expr, + dtype, + options, + } => { + let (expr, dtype, options) = (*expr, dtype.clone(), *options); + let expr = lower_rec!(expr, inside_agg)?; + Some(expr_arena.add(AExpr::Cast { + expr, + dtype, + options, + })) + }, + + AExpr::Agg(agg) => { + let orig_agg = agg.clone(); + match agg { + IRAggExpr::Min { input, .. } + | IRAggExpr::Max { input, .. } + | IRAggExpr::Mean(input) + | IRAggExpr::Sum(input) + | IRAggExpr::Var(input, ..) + | IRAggExpr::Std(input, ..) => { + // Nested aggregates not supported. + if inside_agg { + return None; + } + // Lower and replace input. + let trans_input = lower_rec!(*input, true)?; + let mut trans_agg = orig_agg; + trans_agg.set_input(trans_input); + let trans_agg_node = expr_arena.add(AExpr::Agg(trans_agg)); + + // Add to aggregation expressions and replace with a reference to its output. + + let agg_expr = if let Some(name) = outer_name { + ExprIR::new(trans_agg_node, OutputName::Alias(name)) + } else { + ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name())) + }; + let result_node = expr_arena.add(AExpr::Column(agg_expr.output_name().clone())); + agg_exprs.push(agg_expr); + Some(result_node) + }, + IRAggExpr::Median(..) + | IRAggExpr::NUnique(..) + | IRAggExpr::First(..) + | IRAggExpr::Last(..) + | IRAggExpr::Implode(..) + | IRAggExpr::Quantile { .. } + | IRAggExpr::Count(..) + | IRAggExpr::AggGroups(..) => None, // TODO: allow all aggregates, + } + }, + AExpr::Len => { + let agg_expr = if let Some(name) = outer_name { + ExprIR::new(expr, OutputName::Alias(name)) + } else { + ExprIR::new(expr, OutputName::Alias(unique_column_name())) + }; + let result_node = expr_arena.add(AExpr::Column(agg_expr.output_name().clone())); + agg_exprs.push(agg_expr); + Some(result_node) + }, + } +} + +#[allow(clippy::too_many_arguments)] +fn try_build_streaming_group_by( + input: PhysStream, + keys: &[ExprIR], + aggs: &[ExprIR], + maintain_order: bool, + options: Arc, + apply: Option>, + expr_arena: &mut Arena, + phys_sm: &mut SlotMap, + expr_cache: &mut ExprCache, +) -> Option> { + if apply.is_some() || maintain_order { + return None; // TODO + } + + #[cfg(feature = "dynamic_group_by")] + if options.dynamic.is_some() || options.rolling.is_some() { + return None; // TODO + } + + if keys.is_empty() { + return Some(Err( + polars_err!(ComputeError: "at least one key is required in a group_by operation"), + )); + } + + let all_independent = keys + .iter() + .chain(aggs.iter()) + .all(|expr| is_input_independent(expr.node(), expr_arena, expr_cache)); + if all_independent { + return None; + } + + // We must lower the keys together with the input to the aggregations. + let mut input_columns = PlIndexMap::new(); + for agg in aggs { + for (node, expr) in (&*expr_arena).iter(agg.node()) { + if let AExpr::Column(c) = expr { + input_columns.insert(c.clone(), node); + } + } + } + + let mut pre_lower_exprs = keys.to_vec(); + for (col, node) in input_columns.iter() { + pre_lower_exprs.push(ExprIR::new(*node, OutputName::ColumnLhs(col.clone()))); + } + let Ok((trans_input, trans_exprs)) = + lower_exprs(input, &pre_lower_exprs, expr_arena, phys_sm, expr_cache) + else { + return None; + }; + let trans_keys = trans_exprs[..keys.len()].to_vec(); + let trans_input_cols: PlHashMap<_, _> = trans_exprs[keys.len()..] + .iter() + .zip(input_columns.into_keys()) + .map(|(expr, col)| (col, expr.node())) + .collect(); + + // We must now lower each (presumed) scalar aggregate expression while + // substituting the translated input columns and extracting the aggregate + // expressions. + let mut trans_agg_exprs = Vec::new(); + let mut trans_output_exprs = keys + .iter() + .map(|key| { + let key_node = expr_arena.add(AExpr::Column(key.output_name().clone())); + ExprIR::from_node(key_node, expr_arena) + }) + .collect_vec(); + for agg in aggs { + let trans_node = try_lower_elementwise_scalar_agg_expr( + agg.node(), + false, + Some(agg.output_name().clone()), + expr_arena, + &mut trans_agg_exprs, + &trans_input_cols, + )?; + trans_output_exprs.push(ExprIR::new(trans_node, agg.output_name_inner().clone())); + } + + let input_schema = &phys_sm[trans_input.node].output_schema; + let group_by_output_schema = compute_output_schema( + input_schema, + &[trans_keys.clone(), trans_agg_exprs.clone()].concat(), + expr_arena, + ) + .unwrap(); + let agg_node = phys_sm.insert(PhysNode::new( + group_by_output_schema, + PhysNodeKind::GroupBy { + input: trans_input, + key: trans_keys, + aggs: trans_agg_exprs, + }, + )); + + let post_select = build_select_stream( + PhysStream::first(agg_node), + &trans_output_exprs, + expr_arena, + phys_sm, + expr_cache, + ); + Some(post_select) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_group_by_stream( + input: PhysStream, + keys: &[ExprIR], + aggs: &[ExprIR], + output_schema: Arc, + maintain_order: bool, + options: Arc, + apply: Option>, + expr_arena: &mut Arena, + phys_sm: &mut SlotMap, + expr_cache: &mut ExprCache, +) -> PolarsResult { + let streaming = try_build_streaming_group_by( + input, + keys, + aggs, + maintain_order, + options.clone(), + apply.clone(), + expr_arena, + phys_sm, + expr_cache, + ); + if let Some(stream) = streaming { + stream + } else { + build_group_by_fallback( + input, + keys, + aggs, + output_schema, + maintain_order, + options, + apply, + expr_arena, + phys_sm, + ) + } +} diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 649ed49453b1..1c79abd29670 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use polars_core::frame::DataFrame; use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap}; use polars_core::schema::Schema; -use polars_error::{polars_ensure, PolarsResult}; +use polars_error::PolarsResult; use polars_io::RowIndex; use polars_plan::plans::expr_ir::{ExprIR, OutputName}; -use polars_plan::plans::{AExpr, FileScan, FunctionIR, IRAggExpr, IR}; +use polars_plan::plans::{AExpr, FileScan, FunctionIR, IR}; use polars_plan::prelude::{FileType, SinkType}; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; @@ -16,6 +16,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream}; use crate::physical_plan::lower_expr::{ build_select_stream, is_elementwise_rec_cached, lower_exprs, ExprCache, }; +use crate::physical_plan::lower_group_by::build_group_by_stream; /// Creates a new PhysStream which outputs a slice of the input stream. fn build_slice_stream( @@ -451,76 +452,32 @@ pub fn lower_ir( input, keys, aggs, - schema: _, + schema: output_schema, apply, maintain_order, options, } => { - if apply.is_some() || *maintain_order { - todo!() - } - - #[cfg(feature = "dynamic_group_by")] - if options.dynamic.is_some() || options.rolling.is_some() { - todo!() - } - - let key = keys.clone(); - let mut aggs = aggs.clone(); + let input = *input; + let keys = keys.clone(); + let aggs = aggs.clone(); + let output_schema = output_schema.clone(); + let apply = apply.clone(); + let maintain_order = *maintain_order; let options = options.clone(); - polars_ensure!(!keys.is_empty(), ComputeError: "at least one key is required in a group_by operation"); - - // TODO: allow all aggregates. - let mut input_exprs = key.clone(); - for agg in &aggs { - match expr_arena.get(agg.node()) { - AExpr::Agg(expr) => match expr { - IRAggExpr::Min { input, .. } - | IRAggExpr::Max { input, .. } - | IRAggExpr::Mean(input) - | IRAggExpr::Sum(input) - | IRAggExpr::Var(input, ..) - | IRAggExpr::Std(input, ..) => { - if is_elementwise_rec_cached(*input, expr_arena, expr_cache) { - input_exprs.push(ExprIR::from_node(*input, expr_arena)); - } else { - todo!() - } - }, - _ => todo!(), - }, - AExpr::Len => input_exprs.push(key[0].clone()), // Hack, use the first key column for the length. - _ => todo!(), - } - } - - let phys_input = lower_ir!(*input)?; - let (trans_input, trans_exprs) = - lower_exprs(phys_input, &input_exprs, expr_arena, phys_sm, expr_cache)?; - let trans_key = trans_exprs[..key.len()].to_vec(); - let trans_aggs = aggs - .iter_mut() - .zip(trans_exprs.iter().skip(key.len())) - .map(|(agg, trans_expr)| { - let old_expr = expr_arena.get(agg.node()).clone(); - let new_expr = old_expr.replace_inputs(&[trans_expr.node()]); - ExprIR::new(expr_arena.add(new_expr), agg.output_name_inner().clone()) - }) - .collect(); - - let node = phys_sm.insert(PhysNode::new( + let phys_input = lower_ir!(input)?; + let mut stream = build_group_by_stream( + phys_input, + &keys, + &aggs, output_schema, - PhysNodeKind::GroupBy { - input: trans_input, - key: trans_key, - aggs: trans_aggs, - }, - )); - - // TODO: actually limit number of groups instead of computing full - // result and then slicing. - let mut stream = PhysStream::first(node); + maintain_order, + options.clone(), + apply, + expr_arena, + phys_sm, + expr_cache, + )?; if let Some((offset, len)) = options.slice { stream = build_slice_stream(stream, offset, len, phys_sm); } diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index f311de368d07..87acf2c3a726 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -13,6 +13,7 @@ use polars_plan::prelude::expr_ir::ExprIR; mod fmt; mod lower_expr; +mod lower_group_by; mod lower_ir; mod to_graph; diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 3649d60a629a..77c4f941bd43 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1802,6 +1802,8 @@ def test_filter_with_all_expansion() -> None: assert out.shape == (2, 3) +# TODO: investigate this discrepancy in auto streaming +@pytest.mark.may_fail_auto_streaming def test_extension() -> None: class Foo: def __init__(self, value: Any) -> None: