diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 53579b763033..ab36afc410c3 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -170,7 +170,7 @@ impl ApplyExpr { // })? let out: ListChunked = POOL.install(|| iter.collect::<PolarsResult<_>>())?; - debug_assert_eq!(out.dtype(), &DataType::List(Box::new(dtype))); + debug_assert_eq!(&dtype, out.dtype()); out } else { diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index c4006de0c8ec..cba2c18764e3 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -239,7 +239,7 @@ fn create_physical_expr_inner( // TODO! Order by let group_by = create_physical_expressions_from_nodes( partition_by, - Context::Default, + Context::Aggregation, expr_arena, schema, state, @@ -473,10 +473,13 @@ fn create_physical_expr_inner( options, } => { let is_scalar = is_scalar_ae(expression, expr_arena); - let output_dtype = + let mut output_field = expr_arena .get(expression) .to_field(schema, Context::Default, expr_arena)?; + if let Context::Aggregation = ctxt { + output_field.dtype = DataType::List(Box::new(output_field.dtype)); + } let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) && matches!(options.collect_groups, ApplyOptions::GroupWise); @@ -501,7 +504,7 @@ fn create_physical_expr_inner( *options, state.allow_threading, schema.clone(), - output_dtype, + output_field, is_scalar, ))) }, @@ -509,13 +512,16 @@ fn create_physical_expr_inner( input, function, options, - .. 
} => { let is_scalar = is_scalar_ae(expression, expr_arena); - let output_field = + let mut output_field = expr_arena .get(expression) .to_field(schema, Context::Default, expr_arena)?; + if let Context::Aggregation = ctxt { + output_field.dtype = DataType::List(Box::new(output_field.dtype)); + } + let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) && matches!(options.collect_groups, ApplyOptions::GroupWise); // Will be reset in the function so get that here. @@ -565,9 +571,15 @@ fn create_physical_expr_inner( move |c: &mut [polars_core::frame::column::Column]| c[0].explode().map(Some), ) as Arc<dyn ColumnsUdf>); - let field = expr_arena - .get(expression) - .to_field(schema, ctxt, expr_arena)?; + let mut field = + expr_arena + .get(expression) + .to_field(schema, Context::Default, expr_arena)?; + + if let Context::Aggregation = ctxt { + field.dtype = DataType::List(Box::new(field.dtype)); + } + Ok(Arc::new(ApplyExpr::new( vec![input], function, diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 8cb4b8cc2387..d90ad5d76867 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -72,7 +72,6 @@ impl AExpr { }, Explode(expr) => { let field = arena.get(*expr).to_field_impl(schema, arena, nested)?; - *nested = nested.saturating_sub(1); if let List(inner) = field.dtype() { Ok(Field::new(field.name().clone(), *inner.clone())) diff --git a/py-polars/tests/unit/lazyframe/test_lazyframe.py b/py-polars/tests/unit/lazyframe/test_lazyframe.py index fbc9b5e1ae30..c18d086574b3 100644 --- a/py-polars/tests/unit/lazyframe/test_lazyframe.py +++ b/py-polars/tests/unit/lazyframe/test_lazyframe.py @@ -1424,3 +1424,11 @@ def test_lf_unnest() -> None: ] ) assert_frame_equal(lf.unnest("a", "b").collect(), expected) + + +def test_lf_schema_explode_in_agg_19562() -> None: + lf = pl.LazyFrame({"a": 1, "b": [[1]]}) + q = lf.group_by("a").agg(pl.col("b").explode()) 
+ + assert q.collect_schema() == {"a": pl.Int32, "b": pl.List(pl.Int64)} + assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "b": [[1]]}))