Skip to content

Commit

Permalink
feat: Allow more group_by agg expressions in the new streaming engine (
Browse files Browse the repository at this point in the history
…pola-rs#20663)

Co-authored-by: ritchie <[email protected]>
  • Loading branch information
orlp and ritchie46 authored Jan 11, 2025
1 parent 94b4087 commit 87baf86
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 78 deletions.
6 changes: 5 additions & 1 deletion crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ fn partitionable_gb(
// complex expressions in the group_by itself are also not partitionable
// in this case anything more than col("foo")
for key in keys {
if (expr_arena).iter(key.node()).count() > 1 {
if (expr_arena).iter(key.node()).count() > 1
|| has_aexpr(key.node(), expr_arena, |ae| {
matches!(ae, AExpr::Literal(LiteralValue::Series(_)))
})
{
return false;
}
}
Expand Down
34 changes: 22 additions & 12 deletions crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use slotmap::SlotMap;

use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream};

type IRNodeKey = Node;
type ExprNodeKey = Node;

fn unique_column_name() -> PlSmallStr {
pub fn unique_column_name() -> PlSmallStr {
static COUNTER: AtomicU64 = AtomicU64::new(0);
let idx = COUNTER.fetch_add(1, Ordering::Relaxed);
format_pl_smallstr!("__POLARS_STMP_{idx}")
Expand All @@ -48,7 +48,7 @@ struct LowerExprContext<'a> {
}

pub(crate) fn is_elementwise_rec_cached(
expr_key: IRNodeKey,
expr_key: ExprNodeKey,
arena: &Arena<AExpr>,
cache: &mut ExprCache,
) -> bool {
Expand Down Expand Up @@ -97,10 +97,10 @@ pub(crate) fn is_elementwise_rec_cached(
}

#[recursive::recursive]
fn is_input_independent_rec(
expr_key: IRNodeKey,
pub fn is_input_independent_rec(
expr_key: ExprNodeKey,
arena: &Arena<AExpr>,
cache: &mut PlHashMap<IRNodeKey, bool>,
cache: &mut PlHashMap<ExprNodeKey, bool>,
) -> bool {
if let Some(ret) = cache.get(&expr_key) {
return *ret;
Expand Down Expand Up @@ -207,7 +207,15 @@ fn is_input_independent_rec(
ret
}

fn is_input_independent(expr_key: IRNodeKey, ctx: &mut LowerExprContext) -> bool {
pub fn is_input_independent(
expr_key: ExprNodeKey,
expr_arena: &Arena<AExpr>,
cache: &mut ExprCache,
) -> bool {
is_input_independent_rec(expr_key, expr_arena, &mut cache.is_input_independent)
}

fn is_input_independent_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool {
is_input_independent_rec(
expr_key,
ctx.expr_arena,
Expand Down Expand Up @@ -359,7 +367,7 @@ fn lower_exprs_with_ctx(
) -> PolarsResult<(PhysStream, Vec<Node>)> {
// We have to catch this case separately, in case all the input independent expressions are elementwise.
// TODO: we shouldn't always do this when recursing, e.g. pl.col.a.sum() + 1 will still hit this in the recursion.
if exprs.iter().all(|e| is_input_independent(*e, ctx)) {
if exprs.iter().all(|e| is_input_independent_ctx(*e, ctx)) {
let expr_irs = exprs
.iter()
.map(|e| ExprIR::new(*e, OutputName::Alias(unique_column_name())))
Expand All @@ -384,7 +392,7 @@ fn lower_exprs_with_ctx(

for expr in exprs.iter().copied() {
if is_elementwise_rec_cached(expr, ctx.expr_arena, ctx.cache) {
if !is_input_independent(expr, ctx) {
if !is_input_independent_ctx(expr, ctx) {
input_streams.insert(input);
}
transformed_exprs.push(expr);
Expand Down Expand Up @@ -679,7 +687,10 @@ fn build_select_stream_with_ctx(
exprs: &[ExprIR],
ctx: &mut LowerExprContext,
) -> PolarsResult<PhysStream> {
if exprs.iter().all(|e| is_input_independent(e.node(), ctx)) {
if exprs
.iter()
.all(|e| is_input_independent_ctx(e.node(), ctx))
{
return Ok(PhysStream::first(build_input_independent_node_with_ctx(
exprs, ctx,
)?));
Expand All @@ -696,8 +707,7 @@ fn build_select_stream_with_ctx(

if let Some(columns) = all_simple_columns {
let input_schema = ctx.phys_sm[input.node].output_schema.clone();
if !cfg!(debug_assertions)
&& input_schema.len() == columns.len()
if input_schema.len() == columns.len()
&& input_schema.iter_names().zip(&columns).all(|(l, r)| l == r)
{
// Input node already has the correct schema, just pass through.
Expand Down
Loading

0 comments on commit 87baf86

Please sign in to comment.