Skip to content

Commit

Permalink
feat(rust, python): allow expression as quantile input (#5751)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 8, 2022
1 parent 5c5cc58 commit 3b7c0e5
Show file tree
Hide file tree
Showing 20 changed files with 248 additions and 143 deletions.
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ pub enum AggExpr {
Count(Box<Expr>),
Quantile {
expr: Box<Expr>,
quantile: f64,
quantile: Box<Expr>,
interpol: QuantileInterpolOptions,
},
Sum(Box<Expr>),
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ pub fn median(name: &str) -> Expr {
}

/// Find a specific quantile of all the values in this Expression.
pub fn quantile(name: &str, quantile: f64, interpol: QuantileInterpolOptions) -> Expr {
pub fn quantile(name: &str, quantile: Expr, interpol: QuantileInterpolOptions) -> Expr {
col(name).quantile(quantile, interpol)
}

Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,10 @@ impl Expr {
}

/// Compute the quantile per group.
pub fn quantile(self, quantile: f64, interpol: QuantileInterpolOptions) -> Self {
pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self {
AggExpr::Quantile {
expr: Box::new(self),
quantile,
quantile: Box::new(quantile),
interpol,
}
.into()
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub enum AAggExpr {
List(Node),
Quantile {
expr: Node,
quantile: f64,
quantile: Node,
interpol: QuantileInterpolOptions,
},
Sum(Node),
Expand Down
9 changes: 5 additions & 4 deletions polars/polars-lazy/polars-plan/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
interpol,
} => AAggExpr::Quantile {
expr: to_aexpr(*expr, arena),
quantile,
quantile: to_aexpr(*quantile, arena),
interpol,
},
AggExpr::Sum(expr) => AAggExpr::Sum(to_aexpr(*expr, arena)),
Expand Down Expand Up @@ -540,10 +540,11 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
quantile,
interpol,
} => {
let exp = node_to_expr(expr, expr_arena);
let expr = node_to_expr(expr, expr_arena);
let quantile = node_to_expr(quantile, expr_arena);
AggExpr::Quantile {
expr: Box::new(exp),
quantile,
expr: Box::new(expr),
quantile: Box::new(quantile),
interpol,
}
.into()
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ impl LazyFrame {
/// .agg([
/// col("rain").min(),
/// col("rain").sum(),
/// col("rain").quantile(0.5, QuantileInterpolOptions::Nearest).alias("median_rain"),
/// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"),
/// ])
/// }
/// ```
Expand Down Expand Up @@ -982,7 +982,7 @@ impl LazyFrame {
}

/// Aggregate all the columns as their quantile values.
pub fn quantile(self, quantile: f64, interpol: QuantileInterpolOptions) -> LazyFrame {
pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> LazyFrame {
self.select_local(vec![col("*").quantile(quantile, interpol)])
}

Expand Down Expand Up @@ -1237,7 +1237,7 @@ impl LazyGroupBy {
/// .agg([
/// col("rain").min(),
/// col("rain").sum(),
/// col("rain").quantile(0.5, QuantileInterpolOptions::Nearest).alias("median_rain"),
/// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"),
/// ])
/// }
/// ```
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
//! .agg([
//! col("rain").min(),
//! col("rain").sum(),
//! col("rain").quantile(0.5, QuantileInterpolOptions::Nearest).alias("median_rain"),
//! col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"),
//! ])
//! .sort("date", Default::default())
//! .collect()
Expand Down
36 changes: 26 additions & 10 deletions polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,32 +457,46 @@ impl PartitionedAggregation for AggregationExpr {
}

pub struct AggQuantileExpr {
pub(crate) expr: Arc<dyn PhysicalExpr>,
pub(crate) quantile: f64,
pub(crate) input: Arc<dyn PhysicalExpr>,
pub(crate) quantile: Arc<dyn PhysicalExpr>,
pub(crate) interpol: QuantileInterpolOptions,
}

impl AggQuantileExpr {
pub fn new(
expr: Arc<dyn PhysicalExpr>,
quantile: f64,
input: Arc<dyn PhysicalExpr>,
quantile: Arc<dyn PhysicalExpr>,
interpol: QuantileInterpolOptions,
) -> Self {
Self {
expr,
input,
quantile,
interpol,
}
}

/// Evaluate the `quantile` expression against `df` and extract it as a single `f64`.
///
/// # Errors
///
/// Returns a `ComputeError` when the evaluated expression does not hold exactly
/// one value, and propagates any error raised while evaluating the expression
/// or while extracting the value as `f64`.
fn get_quantile(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<f64> {
    let quantile = self.quantile.evaluate(df, state)?;
    // Exactly one value is required. An empty result is rejected here as well,
    // so that `get(0)` below is never called on a series without elements.
    if quantile.len() != 1 {
        return Err(PolarsError::ComputeError(
            "Polars only supports computing a single quantile. \
            Make sure the 'quantile' expression input produces a single quantile."
                .into(),
        ));
    }
    quantile.get(0).try_extract::<f64>()
}
}

impl PhysicalExpr for AggQuantileExpr {
fn as_expression(&self) -> Option<&Expr> {
None
}

fn evaluate(&self, _df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Series> {
unimplemented!()
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Series> {
let input = self.input.evaluate(df, state)?;
let quantile = self.get_quantile(df, state)?;
input.quantile_as_series(quantile, self.interpol)
}
#[allow(clippy::ptr_arg)]
fn evaluate_on_groups<'a>(
Expand All @@ -491,23 +505,25 @@ impl PhysicalExpr for AggQuantileExpr {
groups: &'a GroupsProxy,
state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
let mut ac = self.expr.evaluate_on_groups(df, groups, state)?;
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
// don't change names by aggregations as is done in polars-core
let keep_name = ac.series().name().to_string();

let quantile = self.get_quantile(df, state)?;

// safety:
// groups are in bounds
let mut agg = unsafe {
ac.flat_naive()
.into_owned()
.agg_quantile(ac.groups(), self.quantile, self.interpol)
.agg_quantile(ac.groups(), quantile, self.interpol)
};
agg.rename(&keep_name);
Ok(AggregationContext::new(agg, Cow::Borrowed(groups), true))
}

fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.expr.to_field(input_schema)
self.input.to_field(input_schema)
}

fn is_valid_aggregation(&self) -> bool {
Expand Down
40 changes: 22 additions & 18 deletions polars/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,24 +452,28 @@ pub(crate) fn create_physical_expr(
} => {
// todo! add schema to get correct output type
let input = create_physical_expr(expr, ctxt, expr_arena, schema)?;
match ctxt {
Context::Aggregation => {
Ok(Arc::new(AggQuantileExpr::new(input, quantile, interpol)))
}
Context::Default => {
let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
let s = std::mem::take(&mut s[0]);
s.quantile_as_series(quantile, interpol)
})
as Arc<dyn SeriesUdf>);
Ok(Arc::new(ApplyExpr::new_minimal(
vec![input],
function,
node_to_expr(expression, expr_arena),
ApplyOptions::ApplyFlat,
)))
}
}
let quantile = create_physical_expr(quantile, ctxt, expr_arena, schema)?;
Ok(Arc::new(AggQuantileExpr::new(input, quantile, interpol)))
//
// match ctxt {
// Context::Aggregation => {
//
// Ok(Arc::new(AggQuantileExpr::new(input, quantile, interpol)))
// }
// Context::Default => {
// let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
// let s = std::mem::take(&mut s[0]);
// s.quantile_as_series(quantile, interpol)
// })
// as Arc<dyn SeriesUdf>);
// Ok(Arc::new(ApplyExpr::new_minimal(
// vec![input],
// function,
// node_to_expr(expression, expr_arena),
// ApplyOptions::ApplyFlat,
// )))
// }
// }
}
AAggExpr::AggGroups(expr) => {
if let Context::Default = ctxt {
Expand Down
30 changes: 0 additions & 30 deletions polars/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,6 @@ fn test_lazy_agg_scan() {
assert!(df.frame_equal_missing(&lf().collect().unwrap().mean()));
}

/// Checks that the lazy `min`, `median` and `quantile` frame aggregations give
/// the same frame as calling the corresponding method directly on `df`.
#[test]
fn test_lazy_df_aggregations() {
    let df = load_df();

    // min
    let lazy_min = df.clone().lazy().min().collect().unwrap();
    assert!(lazy_min.frame_equal_missing(&df.min()));

    // median
    let lazy_median = df.clone().lazy().median().collect().unwrap();
    assert!(lazy_median.frame_equal_missing(&df.median()));

    // 0.5 quantile with the default interpolation option
    let lazy_quantile = df
        .clone()
        .lazy()
        .quantile(0.5, QuantileInterpolOptions::default())
        .collect()
        .unwrap();
    let direct_quantile = df
        .quantile(0.5, QuantileInterpolOptions::default())
        .unwrap();
    assert!(lazy_quantile.frame_equal_missing(&direct_quantile));
}

#[test]
fn test_cumsum_agg_as_key() -> PolarsResult<()> {
let df = df![
Expand Down
36 changes: 0 additions & 36 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,42 +151,6 @@ fn test_lazy_pushdown_through_agg() {
assert_eq!(bar.get(0), AnyValue::Float64(1.3));
}

/// Groups a five-row frame by `date`, aggregates `rain` with min, sum and a
/// 0.5 quantile, sorts by `date`, and asserts on the resulting `min` column.
#[test]
#[cfg(feature = "temporal")]
fn test_lazy_agg() {
    let raw_dates = [
        "2020-08-21",
        "2020-08-21",
        "2020-08-22",
        "2020-08-23",
        "2020-08-22",
    ];
    let date = DateChunked::parse_from_str_slice("date", &raw_dates, "%Y-%m-%d").into_series();
    let temp = Series::new("temp", [20, 10, 7, 9, 1].as_ref());
    let rain = Series::new("rain", [0.2, 0.1, 0.3, 0.1, 0.01].as_ref());
    let df = DataFrame::new(vec![date, temp, rain]).unwrap();

    let aggregations = [
        col("rain").min().alias("min"),
        col("rain").sum().alias("sum"),
        col("rain")
            .quantile(0.5, QuantileInterpolOptions::default())
            .alias("median_rain"),
    ];
    let out = df
        .lazy()
        .groupby([col("date")])
        .agg(aggregations)
        .sort("date", Default::default())
        .collect()
        .unwrap();

    let min = out.column("min").unwrap();
    assert_eq!(min, &Series::new("min", [0.1f64, 0.01, 0.1]));
}

#[test]
fn test_lazy_shift() {
let df = get_df();
Expand Down
Loading

0 comments on commit 3b7c0e5

Please sign in to comment.