diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs
index 540587fc2150..99c3adb47cff 100644
--- a/crates/polars-core/src/datatypes/any_value.rs
+++ b/crates/polars-core/src/datatypes/any_value.rs
@@ -9,40 +9,6 @@ use polars_utils::sync::SyncPtr;
 use polars_utils::total_ord::ToTotalOrd;
 use polars_utils::unwrap::UnwrapUncheckedRelease;
 
-#[derive(Clone)]
-pub struct Scalar {
-    dtype: DataType,
-    value: AnyValue<'static>,
-}
-
-impl Scalar {
-    pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self {
-        Self { dtype, value }
-    }
-
-    pub fn value(&self) -> &AnyValue<'static> {
-        &self.value
-    }
-
-    pub fn as_any_value(&self) -> AnyValue {
-        self.value
-            .strict_cast(&self.dtype)
-            .unwrap_or_else(|| self.value.clone())
-    }
-
-    pub fn into_series(self, name: &str) -> Series {
-        Series::from_any_values_and_dtype(name, &[self.as_any_value()], &self.dtype, true).unwrap()
-    }
-
-    pub fn dtype(&self) -> &DataType {
-        &self.dtype
-    }
-
-    pub fn update(&mut self, value: AnyValue<'static>) {
-        self.value = value;
-    }
-}
-
 use super::*;
 #[cfg(feature = "dtype-struct")]
 use crate::prelude::any_value::arr_to_any_value;
@@ -854,8 +820,8 @@ impl<'a> AnyValue<'a> {
     pub fn add(&self, rhs: &AnyValue) -> AnyValue<'static> {
         use AnyValue::*;
         match (self, rhs) {
-            (Null, _) => Null,
-            (_, Null) => Null,
+            (Null, r) => r.clone().into_static().unwrap(),
+            (l, Null) => l.clone().into_static().unwrap(),
             (Int32(l), Int32(r)) => Int32(l + r),
             (Int64(l), Int64(r)) => Int64(l + r),
             (UInt32(l), UInt32(r)) => UInt32(l + r),
diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs
index c5f4316b37b6..117f462619dc 100644
--- a/crates/polars-core/src/lib.rs
+++ b/crates/polars-core/src/lib.rs
@@ -19,6 +19,7 @@ mod named_from;
 pub mod prelude;
 #[cfg(feature = "random")]
 pub mod random;
+pub mod scalar;
 pub mod schema;
 #[cfg(feature = "serde")]
 pub mod serde;
diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs
index c906c961ae8b..a2a865f7e63c 100644
--- a/crates/polars-core/src/prelude.rs
+++ b/crates/polars-core/src/prelude.rs
@@ -47,6 +47,7 @@ pub use crate::frame::group_by::*;
 pub use crate::frame::{DataFrame, UniqueKeepStrategy};
 pub use crate::hashing::VecHash;
 pub use crate::named_from::{NamedFrom, NamedFromOwned};
+pub use crate::scalar::Scalar;
 pub use crate::schema::*;
 #[cfg(feature = "checked_arithmetic")]
 pub use crate::series::arithmetic::checked::NumOpsDispatchChecked;
diff --git a/crates/polars-core/src/scalar/mod.rs b/crates/polars-core/src/scalar/mod.rs
new file mode 100644
index 000000000000..07ed78b0863f
--- /dev/null
+++ b/crates/polars-core/src/scalar/mod.rs
@@ -0,0 +1,38 @@
+pub mod reduce;
+
+use crate::datatypes::{AnyValue, DataType};
+use crate::prelude::Series;
+
+#[derive(Clone)]
+pub struct Scalar {
+    dtype: DataType,
+    value: AnyValue<'static>,
+}
+
+impl Scalar {
+    pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self {
+        Self { dtype, value }
+    }
+
+    pub fn value(&self) -> &AnyValue<'static> {
+        &self.value
+    }
+
+    pub fn as_any_value(&self) -> AnyValue {
+        self.value
+            .strict_cast(&self.dtype)
+            .unwrap_or_else(|| self.value.clone())
+    }
+
+    pub fn into_series(self, name: &str) -> Series {
+        Series::from_any_values_and_dtype(name, &[self.as_any_value()], &self.dtype, true).unwrap()
+    }
+
+    pub fn dtype(&self) -> &DataType {
+        &self.dtype
+    }
+
+    pub fn update(&mut self, value: AnyValue<'static>) {
+        self.value = value;
+    }
+}
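The relocated `Scalar` keeps the API it had in `any_value.rs`. Below is a minimal usage sketch of the new `polars_core::scalar` module, assuming the prelude re-export added above; it is not part of this diff and the names are illustrative.

    use polars_core::prelude::*;

    fn scalar_demo() {
        // A Scalar pairs a DataType with a single owned AnyValue.
        let sc = Scalar::new(DataType::Int64, AnyValue::Int64(42));
        assert_eq!(sc.dtype(), &DataType::Int64);

        // Materialize it as a one-row Series; the value is cast to the scalar's dtype.
        let s: Series = sc.into_series("answer");
        assert_eq!(s.len(), 1);
    }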
diff --git a/crates/polars-core/src/scalar/reduce.rs b/crates/polars-core/src/scalar/reduce.rs
new file mode 100644
index 000000000000..078dc0411fa4
--- /dev/null
+++ b/crates/polars-core/src/scalar/reduce.rs
@@ -0,0 +1,37 @@
+use crate::datatypes::{AnyValue, TimeUnit};
+#[cfg(feature = "dtype-date")]
+use crate::prelude::MS_IN_DAY;
+use crate::prelude::{DataType, Scalar};
+
+pub fn mean_reduce(value: Option<f64>, dtype: DataType) -> Scalar {
+    match dtype {
+        DataType::Float32 => {
+            let val = value.map(|m| m as f32);
+            Scalar::new(dtype, val.into())
+        },
+        dt if dt.is_numeric() || dt.is_decimal() || dt.is_bool() => {
+            Scalar::new(DataType::Float64, value.into())
+        },
+        #[cfg(feature = "dtype-date")]
+        DataType::Date => {
+            let val = value.map(|v| (v * MS_IN_DAY as f64) as i64);
+            Scalar::new(DataType::Datetime(TimeUnit::Milliseconds, None), val.into())
+        },
+        #[cfg(feature = "dtype-datetime")]
+        dt @ DataType::Datetime(_, _) => {
+            let val = value.map(|v| v as i64);
+            Scalar::new(dt, val.into())
+        },
+        #[cfg(feature = "dtype-duration")]
+        dt @ DataType::Duration(_) => {
+            let val = value.map(|v| v as i64);
+            Scalar::new(dt, val.into())
+        },
+        #[cfg(feature = "dtype-time")]
+        dt @ DataType::Time => {
+            let val = value.map(|v| v as i64);
+            Scalar::new(dt, val.into())
+        },
+        dt => Scalar::new(dt, AnyValue::Null),
+    }
+}
diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs
index dacfa88a0e8f..cf8a8ce11b24 100644
--- a/crates/polars-core/src/series/mod.rs
+++ b/crates/polars-core/src/series/mod.rs
@@ -808,41 +808,7 @@ impl Series {
     }
 
     pub fn mean_reduce(&self) -> Scalar {
-        match self.dtype() {
-            DataType::Float32 => {
-                let val = self.mean().map(|m| m as f32);
-                Scalar::new(self.dtype().clone(), val.into())
-            },
-            dt if dt.is_numeric() || dt.is_decimal() || dt.is_bool() => {
-                let val = self.mean();
-                Scalar::new(DataType::Float64, val.into())
-            },
-            #[cfg(feature = "dtype-date")]
-            DataType::Date => {
-                let val = self.mean().map(|v| (v * MS_IN_DAY as f64) as i64);
-                let av: AnyValue = val.into();
-                Scalar::new(DataType::Datetime(TimeUnit::Milliseconds, None), av)
-            },
-            #[cfg(feature = "dtype-datetime")]
-            dt @ DataType::Datetime(_, _) => {
-                let val = self.mean().map(|v| v as i64);
-                let av: AnyValue = val.into();
-                Scalar::new(dt.clone(), av)
-            },
-            #[cfg(feature = "dtype-duration")]
-            dt @ DataType::Duration(_) => {
-                let val = self.mean().map(|v| v as i64);
-                let av: AnyValue = val.into();
-                Scalar::new(dt.clone(), av)
-            },
-            #[cfg(feature = "dtype-time")]
-            dt @ DataType::Time => {
-                let val = self.mean().map(|v| v as i64);
-                let av: AnyValue = val.into();
-                Scalar::new(dt.clone(), av)
-            },
-            dt => Scalar::new(dt.clone(), AnyValue::Null),
-        }
+        crate::scalar::reduce::mean_reduce(self.mean(), self.dtype().clone())
     }
 
     /// Compute the unique elements, but maintain order. This requires more work
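With the helper extracted, `Series::mean_reduce` is a thin wrapper around `scalar::reduce::mean_reduce`, so the dtype promotion rules live in one place: integer and boolean inputs come back as a `Float64` scalar, while `Date` means are promoted to millisecond `Datetime`. A small sketch of this behavior (not part of the diff; values are illustrative):

    use polars_core::prelude::*;

    fn mean_demo() {
        let ints = Series::new("a", &[1i64, 2, 4]);
        // Integer input is widened: the result is a Float64 scalar holding 7/3.
        let mean: Scalar = ints.mean_reduce();
        assert_eq!(mean.dtype(), &DataType::Float64);

        // The helper can also be called directly with a precomputed mean.
        let sc = polars_core::scalar::reduce::mean_reduce(Some(2.5), DataType::Int32);
        assert_eq!(sc.dtype(), &DataType::Float64);
    }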
diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs
index 4ec62a1c9148..9981e47f1451 100644
--- a/crates/polars-expr/src/lib.rs
+++ b/crates/polars-expr/src/lib.rs
@@ -1,6 +1,7 @@
 mod expressions;
 pub mod planner;
 pub mod prelude;
+pub mod reduce;
 pub mod state;
 
 pub use crate::planner::{create_physical_expr, ExpressionConversionState};
diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs
new file mode 100644
index 000000000000..f5a33aca1a0b
--- /dev/null
+++ b/crates/polars-expr/src/reduce/convert.rs
@@ -0,0 +1,82 @@
+use polars_core::error::feature_gated;
+use polars_plan::prelude::*;
+use polars_utils::arena::{Arena, Node};
+
+use super::extrema::*;
+use super::sum::SumReduce;
+use super::*;
+use crate::reduce::mean::MeanReduce;
+
+pub fn can_convert_into_reduction(node: Node, expr_arena: &Arena<AExpr>) -> bool {
+    match expr_arena.get(node) {
+        AExpr::Agg(agg) => matches!(
+            agg,
+            IRAggExpr::Min { .. }
+                | IRAggExpr::Max { .. }
+                | IRAggExpr::Mean { .. }
+                | IRAggExpr::Sum(_)
+        ),
+        _ => false,
+    }
+}
+
+pub fn into_reduction(
+    node: Node,
+    expr_arena: &Arena<AExpr>,
+    schema: &Schema,
+) -> PolarsResult<Option<(Box<dyn Reduction>, Node)>> {
+    let e = expr_arena.get(node);
+    let field = e.to_field(schema, Context::Default, expr_arena)?;
+    let out = match expr_arena.get(node) {
+        AExpr::Agg(agg) => match agg {
+            IRAggExpr::Sum(node) => (
+                Box::new(SumReduce::new(field.dtype.clone())) as Box<dyn Reduction>,
+                *node,
+            ),
+            IRAggExpr::Min {
+                propagate_nans,
+                input,
+            } => {
+                if *propagate_nans && field.dtype.is_float() {
+                    feature_gated!("propagate_nans", {
+                        let out: Box<dyn Reduction> = match field.dtype {
+                            DataType::Float32 => Box::new(MinNanReduce::<Float32Type>::new()),
+                            DataType::Float64 => Box::new(MinNanReduce::<Float64Type>::new()),
+                            _ => unreachable!(),
+                        };
+                        (out, *input)
+                    })
+                } else {
+                    (
+                        Box::new(MinReduce::new(field.dtype.clone())) as Box<dyn Reduction>,
+                        *input,
+                    )
+                }
+            },
+            IRAggExpr::Max {
+                propagate_nans,
+                input,
+            } => {
+                if *propagate_nans && field.dtype.is_float() {
+                    feature_gated!("propagate_nans", {
+                        let out: Box<dyn Reduction> = match field.dtype {
+                            DataType::Float32 => Box::new(MaxNanReduce::<Float32Type>::new()),
+                            DataType::Float64 => Box::new(MaxNanReduce::<Float64Type>::new()),
+                            _ => unreachable!(),
+                        };
+                        (out, *input)
+                    })
+                } else {
+                    (Box::new(MaxReduce::new(field.dtype.clone())) as _, *input)
+                }
+            },
+            IRAggExpr::Mean(input) => {
+                let out: Box<dyn Reduction> = Box::new(MeanReduce::new(field.dtype.clone()));
+                (out, *input)
+            },
+            _ => return Ok(None),
+        },
+        _ => return Ok(None),
+    };
+    Ok(Some(out))
+}
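`can_convert_into_reduction` accepts only the four plain aggregations handled above, so a `Select` lowers to a streaming reduction only when every output expression is one of them. A hedged sketch of such a query through the lazy API (not part of this diff; assumes the umbrella polars crate with the lazy feature, and the column names are placeholders):

    use polars::prelude::*;

    fn reducible(lf: LazyFrame) -> PolarsResult<DataFrame> {
        // Every output expression is a bare sum/min/mean, so into_reduction
        // can build one Reduction per output column.
        lf.select([col("a").sum(), col("b").min(), col("c").mean()])
            .collect()
    }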
diff --git a/crates/polars-expr/src/reduce/extrema.rs b/crates/polars-expr/src/reduce/extrema.rs
new file mode 100644
index 000000000000..5eee559e1588
--- /dev/null
+++ b/crates/polars-expr/src/reduce/extrema.rs
@@ -0,0 +1,249 @@
+#[cfg(feature = "propagate_nans")]
+use polars_core::datatypes::PolarsFloatType;
+#[cfg(feature = "propagate_nans")]
+use polars_ops::prelude::nan_propagating_aggregate;
+#[cfg(feature = "propagate_nans")]
+use polars_utils::min_max::MinMax;
+
+use super::*;
+
+#[derive(Clone)]
+pub(super) struct MinReduce {
+    dtype: DataType,
+    value: Option<Scalar>,
+}
+
+impl MinReduce {
+    pub(super) fn new(dtype: DataType) -> Self {
+        Self { dtype, value: None }
+    }
+
+    fn update_impl(&mut self, other: &AnyValue<'static>) {
+        if let Some(value) = &mut self.value {
+            if other < value.value() {
+                value.update(other.clone());
+            }
+        } else {
+            self.value = Some(Scalar::new(self.dtype.clone(), other.clone()))
+        }
+    }
+}
+
+impl Reduction for MinReduce {
+    fn init_dyn(&self) -> Box<dyn Reduction> {
+        Box::new(Self::new(self.dtype.clone()))
+    }
+
+    fn reset(&mut self) {
+        *self = Self::new(self.dtype.clone());
+    }
+
+    fn update(&mut self, batch: &Series) -> PolarsResult<()> {
+        let sc = batch.min_reduce()?;
+        self.update_impl(sc.value());
+        Ok(())
+    }
+
+    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
+        let other = other.as_any().downcast_ref::<Self>().unwrap();
+        if let Some(value) = &other.value {
+            self.update_impl(value.value());
+        }
+        Ok(())
+    }
+
+    fn finalize(&mut self) -> PolarsResult<Scalar> {
+        if let Some(value) = self.value.take() {
+            Ok(value)
+        } else {
+            Ok(Scalar::new(self.dtype.clone(), AnyValue::Null))
+        }
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
+#[derive(Clone)]
+pub(super) struct MaxReduce {
+    dtype: DataType,
+    value: Option<Scalar>,
+}
+
+impl MaxReduce {
+    pub(super) fn new(dtype: DataType) -> Self {
+        Self { dtype, value: None }
+    }
+    fn update_impl(&mut self, other: &AnyValue<'static>) {
+        if let Some(value) = &mut self.value {
+            if other > value.value() {
+                value.update(other.clone());
+            }
+        } else {
+            self.value = Some(Scalar::new(self.dtype.clone(), other.clone()))
+        }
+    }
+}
+
+impl Reduction for MaxReduce {
+    fn init_dyn(&self) -> Box<dyn Reduction> {
+        Box::new(Self::new(self.dtype.clone()))
+    }
+    fn reset(&mut self) {
+        *self = Self::new(self.dtype.clone());
+    }
+
+    fn update(&mut self, batch: &Series) -> PolarsResult<()> {
+        let sc = batch.max_reduce()?;
+        self.update_impl(sc.value());
+        Ok(())
+    }
+
+    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
+        let other = other.as_any().downcast_ref::<Self>().unwrap();
+
+        if let Some(value) = &other.value {
+            self.update_impl(value.value());
+        }
+        Ok(())
+    }
+
+    fn finalize(&mut self) -> PolarsResult<Scalar> {
+        if let Some(value) = self.value.take() {
+            Ok(value)
+        } else {
+            Ok(Scalar::new(self.dtype.clone(), AnyValue::Null))
+        }
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
+
+#[cfg(feature = "propagate_nans")]
+#[derive(Clone)]
+pub(super) struct MaxNanReduce<T: PolarsFloatType> {
+    value: Option<T::Native>,
+}
+
+#[cfg(feature = "propagate_nans")]
+impl<T: PolarsFloatType> MaxNanReduce<T>
+where
+    T::Native: MinMax,
+{
+    pub(super) fn new() -> Self {
+        Self { value: None }
+    }
+    fn update_impl(&mut self, other: T::Native) {
+        if let Some(value) = self.value {
+            self.value = Some(MinMax::max_propagate_nan(value, other));
+        } else {
+            self.value = Some(other);
+        }
+    }
+}
+
+#[cfg(feature = "propagate_nans")]
+impl<T: PolarsFloatType> Reduction for MaxNanReduce<T>
+where
+    T::Native: MinMax,
+{
+    fn init_dyn(&self) -> Box<dyn Reduction> {
+        Box::new(Self::new())
+    }
+    fn reset(&mut self) {
+        self.value = None;
+    }
+
+    fn update(&mut self, batch: &Series) -> PolarsResult<()> {
+        if let Some(v) = nan_propagating_aggregate::ca_nan_agg(
+            batch.unpack::<T>().unwrap(),
+            MinMax::max_propagate_nan,
+        ) {
+            self.update_impl(v)
+        }
+        Ok(())
+    }
+
+    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
+        let other = other.as_any().downcast_ref::<Self>().unwrap();
+
+        if let Some(value) = &other.value {
+            self.update_impl(*value);
+        }
+        Ok(())
+    }
+
+    fn finalize(&mut self) -> PolarsResult<Scalar> {
+        let av = AnyValue::from(self.value);
+        Ok(Scalar::new(T::get_dtype(), av))
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
+#[cfg(feature = "propagate_nans")]
+#[derive(Clone)]
+pub(super) struct MinNanReduce<T: PolarsFloatType> {
+    value: Option<T::Native>,
+}
+
+#[cfg(feature = "propagate_nans")]
+impl<T: PolarsFloatType> crate::reduce::extrema::MinNanReduce<T>
+where
+    T::Native: MinMax,
+{
+    pub(super) fn new() -> Self {
+        Self { value: None }
+    }
+    fn update_impl(&mut self, other: T::Native) {
+        if let Some(value) = self.value {
+            self.value = Some(MinMax::min_propagate_nan(value, other));
+        } else {
+            self.value = Some(other);
+        }
+    }
+}
+
+#[cfg(feature = "propagate_nans")]
+impl<T: PolarsFloatType> Reduction for crate::reduce::extrema::MinNanReduce<T>
+where
+    T::Native: MinMax,
+{
+    fn init_dyn(&self) -> Box<dyn Reduction> {
+        Box::new(Self::new())
+    }
+    fn reset(&mut self) {
+        self.value = None;
+    }
+
+    fn update(&mut self, batch: &Series) -> PolarsResult<()> {
+        if let Some(v) = nan_propagating_aggregate::ca_nan_agg(
+            batch.unpack::<T>().unwrap(),
+            MinMax::min_propagate_nan,
+        ) {
+            self.update_impl(v)
+        }
+        Ok(())
+    }
+
+    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
+        let other = other.as_any().downcast_ref::<Self>().unwrap();
+
+        if let Some(value) = &other.value {
+            self.update_impl(*value);
+        }
+        Ok(())
+    }
+
+    fn finalize(&mut self) -> PolarsResult<Scalar> {
+        let av = AnyValue::from(self.value);
+        Ok(Scalar::new(T::get_dtype(), av))
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
diff --git a/crates/polars-expr/src/reduce/mean.rs b/crates/polars-expr/src/reduce/mean.rs
index 6cd35e7bf0bc..0d06974d956b 100644
--- a/crates/polars-expr/src/reduce/mean.rs
+++ b/crates/polars-expr/src/reduce/mean.rs
@@ -1,45 +1,69 @@
-use polars_core::prelude::{AnyValue, DataType};
 use polars_core::utils::Container;
+
 use super::*;
 
+#[derive(Clone)]
 pub struct MeanReduce {
-    value: Scalar,
+    value: Option<f64>,
     len: u64,
+    dtype: DataType,
 }
 
 impl MeanReduce {
     pub(crate) fn new(dtype: DataType) -> Self {
-        let value = Scalar::new(dtype, AnyValue::Null);
-        Self { value, len: 0 }
+        let value = None;
+        Self {
+            value,
+            len: 0,
+            dtype,
+        }
     }
 
-    fn update_impl(&mut self, value: &AnyValue<'static>) {
-        self.value.update(self.value.value().add(value))
+    fn update_impl(&mut self, value: &AnyValue<'static>, len: u64) {
+        let value = value.extract::<f64>().expect("phys numeric");
+        if let Some(acc) = &mut self.value {
+            *acc += value;
+            self.len += len;
+        } else {
+            self.value = Some(value);
+            self.len = len;
+        }
     }
 }
 
 impl Reduction for MeanReduce {
-    fn init(&mut self) {
-        let av = AnyValue::zero(self.value.dtype());
-        self.value.update(av);
+    fn init_dyn(&self) -> Box<dyn Reduction> {
+        Box::new(Self::new(self.dtype.clone()))
+    }
+    fn reset(&mut self) {
+        self.value = None;
+        self.len = 0;
     }
 
     fn update(&mut self, batch: &Series) -> PolarsResult<()> {
         let sc = batch.sum_reduce()?;
-        self.update_impl(sc.value());
-        self.len += batch.len() as u64;
+        self.update_impl(sc.value(), batch.len() as u64);
         Ok(())
     }
 
     fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
         let other = other.as_any().downcast_ref::<Self>().unwrap();
-        self.update_impl(&other.value.value());
+
+        match (self.value, other.value) {
+            (Some(l), Some(r)) => self.value = Some(l + r),
+            (None, Some(r)) => self.value = Some(r),
+            (Some(l), None) => self.value = Some(l),
+            (None, None) => self.value = None,
+        }
         self.len += other.len;
         Ok(())
     }
 
     fn finalize(&mut self) -> PolarsResult<Scalar> {
-        Ok(self.value.clone())
+        Ok(polars_core::scalar::reduce::mean_reduce(
+            self.value.map(|v| v / self.len as f64),
+            self.dtype.clone(),
+        ))
     }
 
     fn as_any(&self) -> &dyn Any {
diff --git a/crates/polars-expr/src/reduce/mod.rs b/crates/polars-expr/src/reduce/mod.rs
new file mode 100644
index 000000000000..bb51ba5c8a8d
--- /dev/null
+++ b/crates/polars-expr/src/reduce/mod.rs
@@ -0,0 +1,33 @@
+mod convert;
+mod extrema;
+mod mean;
+mod sum;
+
+use std::any::Any;
+
+pub use convert::{can_convert_into_reduction, into_reduction};
+use polars_core::prelude::*;
+
+#[allow(dead_code)]
+pub trait Reduction: Any + Send {
+    // Creates a fresh reduction.
+    fn init_dyn(&self) -> Box<dyn Reduction>;
+
+    // Resets this reduction to the fresh initial state.
+    fn reset(&mut self);
+
+    fn update(&mut self, batch: &Series) -> PolarsResult<()>;
+
+    /// # Safety
+    /// Implementations may elide bound checks.
+    unsafe fn update_gathered(&mut self, batch: &Series, idx: &[IdxSize]) -> PolarsResult<()> {
+        let batch = batch.take_unchecked_from_slice(idx);
+        self.update(&batch)
+    }
+
+    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()>;
+
+    fn finalize(&mut self) -> PolarsResult<Scalar>;
+
+    fn as_any(&self) -> &dyn Any;
+}
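To make the intended lifecycle of this trait concrete: a prototype state is cloned per partition with `init_dyn`, each clone folds its own batches with `update`, and the partial states are merged with `combine` before `finalize` yields a `Scalar`. A sequential sketch follows (`reduce_partitioned` is an illustrative helper, not part of this diff; the streaming node later in the PR applies the same pattern across worker tasks):

    use polars_core::prelude::*;
    use polars_expr::reduce::Reduction;

    fn reduce_partitioned(
        proto: &dyn Reduction,
        partitions: &[Vec<Series>],
    ) -> PolarsResult<Scalar> {
        // One fresh state per partition, each folding its own batches.
        let mut states = partitions
            .iter()
            .map(|chunks| -> PolarsResult<Box<dyn Reduction>> {
                let mut state = proto.init_dyn();
                for chunk in chunks {
                    state.update(chunk)?;
                }
                Ok(state)
            })
            .collect::<PolarsResult<Vec<_>>>()?;

        // Merge the partial states pairwise, then produce the final Scalar.
        let mut acc = states.pop().expect("at least one partition");
        for state in &states {
            acc.combine(state.as_ref())?;
        }
        acc.finalize()
    }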
diff --git a/crates/polars-plan/src/reduce/sum.rs b/crates/polars-expr/src/reduce/sum.rs
similarity index 83%
rename from crates/polars-plan/src/reduce/sum.rs
rename to crates/polars-expr/src/reduce/sum.rs
index a3bbafaed88e..9e1e0e4600e4 100644
--- a/crates/polars-plan/src/reduce/sum.rs
+++ b/crates/polars-expr/src/reduce/sum.rs
@@ -2,6 +2,7 @@ use polars_core::prelude::{AnyValue, DataType};
 
 use super::*;
 
+#[derive(Clone)]
 pub struct SumReduce {
     value: Scalar,
 }
@@ -18,7 +19,10 @@ impl SumReduce {
 }
 
 impl Reduction for SumReduce {
-    fn init(&mut self) {
+    fn init_dyn(&self) -> Box<dyn Reduction> {
+        Box::new(Self::new(self.value.dtype().clone()))
+    }
+    fn reset(&mut self) {
        let av = AnyValue::zero(self.value.dtype());
        self.value.update(av);
     }
@@ -31,7 +35,7 @@ impl Reduction for SumReduce {
 
     fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
         let other = other.as_any().downcast_ref::<Self>().unwrap();
-        self.update_impl(&other.value.value());
+        self.update_impl(other.value.value());
         Ok(())
     }
 
diff --git a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
index fc996ceb5184..6c811ccbbf0f 100644
--- a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
+++ b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
@@ -11,7 +11,7 @@ use polars_core::frame::group_by::aggregations::{
 use polars_core::prelude::*;
 use polars_utils::min_max::MinMax;
 
-fn ca_nan_agg<T, Agg>(ca: &ChunkedArray<T>, min_or_max_fn: Agg) -> Option<T::Native>
+pub fn ca_nan_agg<T, Agg>(ca: &ChunkedArray<T>, min_or_max_fn: Agg) -> Option<T::Native>
 where
     T: PolarsFloatType,
     Agg: Fn(T::Native, T::Native) -> T::Native + Copy,
diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs
index 2f569330c24c..37254c4da3aa 100644
--- a/crates/polars-plan/src/dsl/expr.rs
+++ b/crates/polars-plan/src/dsl/expr.rs
@@ -393,7 +393,7 @@ impl Operator {
         )
     }
 
-    pub(crate) fn is_arithmetic(&self) -> bool {
+    pub fn is_arithmetic(&self) -> bool {
         !(self.is_comparison())
     }
 }
diff --git a/crates/polars-plan/src/lib.rs b/crates/polars-plan/src/lib.rs
index a3b1b42808ab..3c495d30b9be 100644
--- a/crates/polars-plan/src/lib.rs
+++ b/crates/polars-plan/src/lib.rs
@@ -4,14 +4,12 @@
 
 extern crate core;
 
+#[cfg(feature = "polars_cloud")]
+pub mod client;
 pub mod constants;
 pub mod dsl;
 pub mod frame;
 pub mod global;
 pub mod plans;
 pub mod prelude;
-// Activate later
-// mod reduce;
-#[cfg(feature = "polars_cloud")]
-pub mod client;
 pub mod utils;
diff --git a/crates/polars-plan/src/reduce/convert.rs b/crates/polars-plan/src/reduce/convert.rs
deleted file mode 100644
index 03484152709b..000000000000
--- a/crates/polars-plan/src/reduce/convert.rs
+++ /dev/null
@@ -1,44 +0,0 @@
-use polars_core::datatypes::Field;
-use polars_utils::arena::{Arena, Node};
-
-use super::*;
-use crate::prelude::{AExpr, IRAggExpr};
-use crate::reduce::sum::SumReduce;
-
-
-struct ReductionImpl {
-    reduce: Box<dyn Reduction>,
-    prepare: Node
-}
-
-impl ReductionImpl {
-    fn new(reduce: Box<dyn Reduction>, prepare: Node) -> Self {
-        ReductionImpl {
-            reduce,
-            prepare
-        }
-
-    }
-
-}
-
-pub fn into_reduction(
-    node: Node,
-    expr_arena: Arena<AExpr>,
-    field: &Field,
-) -> ReductionImpl {
-    match expr_arena.get(node) {
-        AExpr::Agg(agg) => match agg {
-            IRAggExpr::Sum(node) => {
-                ReductionImpl::new(
-                    Box::new(SumReduce::new(field.dtype.clone())),
-                    *node
-                )
-            },
-            _ => todo!(),
-        },
-        _ => {
-            todo!()
-        },
-    }
-}
diff --git a/crates/polars-plan/src/reduce/extrema.rs b/crates/polars-plan/src/reduce/extrema.rs
deleted file mode 100644
index 27ef0d0fb0bb..000000000000
--- a/crates/polars-plan/src/reduce/extrema.rs
+++ /dev/null
@@ -1,80 +0,0 @@
-use polars_core::prelude::AnyValue;
-
-use super::*;
-
-struct MinReduce {
-    value: Scalar,
-}
-
-impl MinReduce {
-    fn update_impl(&mut self, value: &AnyValue<'static>) {
-        if value < self.value.value() {
-            self.value.update(value.clone());
-        }
-    }
-}
-
-impl Reduction for MinReduce {
-    fn init(&mut self) {
-        let av = AnyValue::zero(self.value.dtype());
-        self.value.update(av);
-    }
-
-    fn update(&mut self, batch: &Series) -> PolarsResult<()> {
-        let sc = batch.min_reduce()?;
-        self.update_impl(sc.value());
-        Ok(())
-    }
-
-    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
-        let other = other.as_any().downcast_ref::<Self>().unwrap();
-        self.update_impl(&other.value.value());
-        Ok(())
-    }
-
-    fn finalize(&mut self) -> PolarsResult<Scalar> {
-        Ok(self.value.clone())
-    }
-
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-}
-struct MaxReduce {
-    value: Scalar,
-}
-
-impl MaxReduce {
-    fn update_impl(&mut self, value: &AnyValue<'static>) {
-        if value > self.value.value() {
-            self.value.update(value.clone());
-        }
-    }
-}
-
-impl Reduction for MaxReduce {
-    fn init(&mut self) {
-        let av = AnyValue::zero(self.value.dtype());
-        self.value.update(av);
-    }
-
-    fn update(&mut self, batch: &Series) -> PolarsResult<()> {
-        let sc = batch.max_reduce()?;
-        self.update_impl(sc.value());
-        Ok(())
-    }
-
-    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> {
-        let other = other.as_any().downcast_ref::<Self>().unwrap();
-        self.update_impl(&other.value.value());
-        Ok(())
-    }
-
-    fn finalize(&mut self) -> PolarsResult<Scalar> {
-        Ok(self.value.clone())
-    }
-
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-}
diff --git a/crates/polars-plan/src/reduce/mod.rs b/crates/polars-plan/src/reduce/mod.rs
deleted file mode 100644
index 7b4c0f3c8877..000000000000
--- a/crates/polars-plan/src/reduce/mod.rs
+++ /dev/null
@@ -1,22 +0,0 @@
-mod convert;
-mod extrema;
-mod sum;
-
-use std::any::Any;
-
-use arrow::legacy::error::PolarsResult;
-use polars_core::datatypes::Scalar;
-use polars_core::prelude::Series;
-
-#[allow(dead_code)]
-trait Reduction: Any {
-    fn init(&mut self);
-
-    fn update(&mut self, batch: &Series) -> PolarsResult<()>;
-
-    fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()>;
-
-    fn finalize(&mut self) -> PolarsResult<Scalar>;
-
-    fn as_any(&self) -> &dyn Any;
-}
diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs
index 9795bf6d4c11..723c7a222d4a 100644
--- a/crates/polars-stream/src/nodes/mod.rs
+++ b/crates/polars-stream/src/nodes/mod.rs
@@ -4,6 +4,7 @@ pub mod in_memory_sink;
 pub mod in_memory_source;
 pub mod map;
 pub mod ordered_union;
+pub mod reduce;
 pub mod select;
 pub mod simple_projection;
 pub mod streaming_slice;
diff --git a/crates/polars-stream/src/nodes/reduce.rs b/crates/polars-stream/src/nodes/reduce.rs
new file mode 100644
index 000000000000..ba8ccba3e51e
--- /dev/null
+++ b/crates/polars-stream/src/nodes/reduce.rs
@@ -0,0 +1,206 @@
+use std::sync::Arc;
+
+use parking_lot::Mutex;
+use polars_core::frame::DataFrame;
+use polars_core::schema::SchemaRef;
+use polars_error::PolarsResult;
+use polars_expr::prelude::{ExecutionState, PhysicalExpr};
+use polars_expr::reduce::Reduction;
+
+use super::compute_node_prelude::*;
+use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
+use crate::graph::PortState;
+use crate::morsel::{Morsel, MorselSeq};
+use crate::nodes::ComputeNode;
+
+// All reductions in a single operation.
+// `select(sum, min)` -> vec![sum, min]
+type ReduceSet = Vec<Box<dyn Reduction>>;
+
+enum ReduceState {
+    Sink {
+        inputs: Vec<Arc<dyn PhysicalExpr>>,
+        reductions: Arc<Mutex<Vec<Box<dyn Reduction>>>>,
+    },
+    Source(Mutex<Option<DataFrame>>),
+    Done,
+}
+
+pub struct ReduceNode {
+    // Reductions that are ready to finalize
+    full: Arc<Mutex<Vec<ReduceSet>>>,
+    state: ReduceState,
+    output_schema: SchemaRef,
+}
+
+impl ReduceNode {
+    pub fn new(
+        inputs: Vec<Arc<dyn PhysicalExpr>>,
+        reductions: ReduceSet,
+        output_schema: SchemaRef,
+    ) -> Self {
+        Self {
+            state: ReduceState::Sink {
+                inputs,
+                reductions: Arc::new(Mutex::new(reductions)),
+            },
+            output_schema,
+            full: Default::default(),
+        }
+    }
+
+    fn spawn_sink<'env, 's>(
+        &'env self,
+        scope: &'s TaskScope<'s, 'env>,
+        _pipeline: usize,
+        recv: &mut [Option<Receiver<Morsel>>],
+        send: &mut [Option<Sender<Morsel>>],
+        state: &'s ExecutionState,
+    ) -> JoinHandle<PolarsResult<()>> {
+        assert!(send.len() == 1 && recv.len() == 1);
+        let ReduceState::Sink {
+            inputs, reductions, ..
+        } = &self.state
+        else {
+            unreachable!()
+        };
+
+        let mut recv = recv[0].take().unwrap();
+        let full = self.full.clone();
+
+        scope.spawn_task(TaskPriority::High, async move {
+            let mut reductions = reductions
+                .lock()
+                .iter()
+                .map(|d| d.init_dyn())
+                .collect::<Vec<_>>();
+
+            while let Ok(morsel) = recv.recv().await {
+                let df = morsel.into_df();
+
+                for (i, input) in inputs.iter().enumerate() {
+                    let reduction_input = input.evaluate(&df, state)?;
+                    reductions[i].update(&reduction_input.to_physical_repr())?;
+                }
+            }
+
+            full.lock().push(reductions);
+
+            Ok(())
+        })
+    }
+
+    fn spawn_source<'env, 's>(
+        &'env self,
+        scope: &'s TaskScope<'s, 'env>,
+        _pipeline: usize,
+        recv: &mut [Option<Receiver<Morsel>>],
+        send: &mut [Option<Sender<Morsel>>],
+        _state: &'s ExecutionState,
+    ) -> JoinHandle<PolarsResult<()>> {
+        assert!(send.len() == 1 && recv.len() == 1);
+        let ReduceState::Source(df) = &self.state else {
+            unreachable!()
+        };
+        let mut send = send[0].take().unwrap();
+
+        scope.spawn_task(TaskPriority::High, async move {
+            let Some(df) = df.lock().take() else {
+                return Ok(());
+            };
+            let morsel = Morsel::new(df, MorselSeq::new(0));
+            let _ = send.send(morsel).await;
+            Ok(())
+        })
+    }
+}
+
+impl ComputeNode for ReduceNode {
+    fn name(&self) -> &str {
+        "reduce"
+    }
+
+    fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) {
+        assert!(recv.len() == 1 && send.len() == 1);
+
+        // State transitions
+        // If the output doesn't want any more data, transition to being done.
+        if send[0] == PortState::Done && !matches!(&self.state, ReduceState::Done) {
+            self.state = ReduceState::Done;
+        }
+
+        match self.state {
+            // Input is done, we can combine the reductions into a single scalar.
+            ReduceState::Sink { .. } if matches!(recv[0], PortState::Done) => {
+                let reductions = std::mem::take(&mut *self.full.lock());
+
+                let reductions = reductions
+                    .into_iter()
+                    .map(PolarsResult::Ok)
+                    .reduce(|a, b| {
+                        let mut a = a?;
+                        let mut b = b?;
+                        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+                            a.combine(b.as_ref())?
+                        }
+                        Ok(a)
+                    })
+                    .expect("expected at least 1 thread running this node");
+
+                // TODO! make `update_state` fallible.
+                let reductions = reductions.unwrap();
+
+                let columns = reductions
+                    .into_iter()
+                    .zip(self.output_schema.iter_fields())
+                    .map(|(mut r, field)| {
+                        r.finalize().map(|scalar| {
+                            scalar.into_series(&field.name).cast(&field.dtype).unwrap()
+                        })
+                    })
+                    .collect::<PolarsResult<Vec<_>>>()
+                    .unwrap();
+                let out = unsafe { DataFrame::new_no_checks(columns) };
+
+                self.state = ReduceState::Source(Mutex::new(Some(out)));
+            },
+            // We have fed the source, we are done.
+            ReduceState::Source(ref df) if df.lock().is_none() => {
+                self.state = ReduceState::Done;
+            },
+            // Nothing to change.
+            ReduceState::Done | ReduceState::Sink { .. } | ReduceState::Source(_) => {},
+        }
+
+        // Communicate state
+        match &self.state {
+            ReduceState::Sink { .. } => {
+                send[0] = PortState::Blocked;
+                recv[0] = PortState::Ready;
+            },
+            ReduceState::Source(..) => {
+                recv[0] = PortState::Done;
+                send[0] = PortState::Ready;
+            },
+            ReduceState::Done => {
+                recv[0] = PortState::Done;
+                send[0] = PortState::Done;
+            },
+        }
+    }
+
+    fn spawn<'env, 's>(
+        &'env self,
+        scope: &'s TaskScope<'s, 'env>,
+        pipeline: usize,
+        recv: &mut [Option<Receiver<Morsel>>],
+        send: &mut [Option<Sender<Morsel>>],
+        state: &'s ExecutionState,
+    ) -> JoinHandle<PolarsResult<()>> {
+        match self.state {
+            ReduceState::Sink { .. } => self.spawn_sink(scope, pipeline, recv, send, state),
+            ReduceState::Source(..) => self.spawn_source(scope, pipeline, recv, send, state),
+            ReduceState::Done => unreachable!(),
+        }
+    }
+}
diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs
index 9cce563ab779..615d5d51f8de 100644
--- a/crates/polars-stream/src/physical_plan/lower_ir.rs
+++ b/crates/polars-stream/src/physical_plan/lower_ir.rs
@@ -1,6 +1,7 @@
 use std::sync::Arc;
 
 use polars_error::PolarsResult;
+use polars_expr::reduce::can_convert_into_reduction;
 use polars_plan::plans::{AExpr, Context, IR};
 use polars_plan::prelude::SinkType;
 use polars_utils::arena::{Arena, Node};
@@ -44,6 +45,29 @@ pub fn lower_ir(
                 extend_original: false,
             }))
         },
+        // TODO: split reductions and streamable selections. E.g. sum(a) + sum(b) should be split
+        // into Select(a + b) -> Reduce(sum(a), sum(b))
+        IR::Select {
+            input,
+            expr,
+            schema: output_schema,
+            ..
+        } if expr
+            .iter()
+            .all(|e| can_convert_into_reduction(e.node(), expr_arena)) =>
+        {
+            let exprs = expr.clone();
+            let input_ir_node = ir_arena.get(*input);
+            let input_schema = input_ir_node.schema(ir_arena).into_owned();
+            let output_schema = output_schema.clone();
+            let input_node = lower_ir(*input, ir_arena, expr_arena, phys_sm)?;
+            Ok(phys_sm.insert(PhysNode::Reduce {
+                input: input_node,
+                exprs,
+                input_schema,
+                output_schema,
+            }))
+        },
         // TODO: split partially streamable selections to avoid fallback as much as possible.
         IR::HStack {
diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs
index cf3719e8e680..aef0c1d7cdeb 100644
--- a/crates/polars-stream/src/physical_plan/mod.rs
+++ b/crates/polars-stream/src/physical_plan/mod.rs
@@ -33,7 +33,12 @@ pub enum PhysNode {
         extend_original: bool,
         output_schema: Arc<Schema>,
     },
-
+    Reduce {
+        input: PhysNodeKey,
+        exprs: Vec<ExprIR>,
+        input_schema: Arc<Schema>,
+        output_schema: Arc<Schema>,
+    },
     StreamingSlice {
         input: PhysNodeKey,
         offset: usize,
diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs
index cc9a8227c5cb..ad8134110415 100644
--- a/crates/polars-stream/src/physical_plan/to_graph.rs
+++ b/crates/polars-stream/src/physical_plan/to_graph.rs
@@ -3,8 +3,10 @@ use std::sync::Arc;
 use parking_lot::Mutex;
 use polars_error::PolarsResult;
 use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState};
+use polars_expr::reduce::into_reduction;
 use polars_expr::state::ExecutionState;
 use polars_mem_engine::create_physical_plan;
+use polars_plan::plans::expr_ir::ExprIR;
 use polars_plan::plans::{AExpr, Context, IR};
 use polars_utils::arena::Arena;
 use recursive::recursive;
@@ -115,7 +117,39 @@ fn to_graph_rec<'a>(
                 [input_key],
             )
         },
+        Reduce {
+            input,
+            exprs,
+            input_schema,
+            output_schema,
+        } => {
+            let input_key = to_graph_rec(*input, ctx)?;
+
+            let mut reductions = Vec::with_capacity(exprs.len());
+            let mut inputs = Vec::with_capacity(reductions.len());
+
+            for e in exprs {
+                let (red, input_node) =
+                    into_reduction(e.node(), ctx.expr_arena, input_schema.as_ref())?
+                        .expect("invariant");
+                reductions.push(red);
+                let input_phys = create_physical_expr(
+                    &ExprIR::from_node(input_node, ctx.expr_arena),
+                    Context::Default,
+                    ctx.expr_arena,
+                    None,
+                    &mut ctx.expr_conversion_state,
+                )?;
+
+                inputs.push(input_phys)
+            }
+
+            ctx.graph.add_node(
+                nodes::reduce::ReduceNode::new(inputs, reductions, output_schema.clone()),
+                [input_key],
+            )
+        },
         SimpleProjection { schema, input } => {
             let input_key = to_graph_rec(*input, ctx)?;
             ctx.graph.add_node(