Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Window frame GROUPS mode support #4155

Merged
merged 8 commits into from
Nov 11, 2022
29 changes: 24 additions & 5 deletions datafusion/common/src/bisect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,41 @@ pub fn bisect<const SIDE: bool>(
target: &[ScalarValue],
sort_options: &[SortOptions],
) -> Result<usize> {
let mut low: usize = 0;
let mut high: usize = item_columns
let low: usize = 0;
let high: usize = item_columns
.get(0)
.ok_or_else(|| {
DataFusionError::Internal("Column array shouldn't be empty".to_string())
})?
.len();
let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
let cmp = compare(current, target, sort_options)?;
Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
};
find_bisect_point(item_columns, target, compare_fn, low, high)
}

/// This function searches for a tuple of target values among the given rows using the bisection algorithm.
/// The boolean-valued function `compare_fn` specifies whether we bisect on the left (with return value `false`),
/// or on the right (with return value `true`) as we compare the target value with the current value as we iteratively
/// bisect the input.
pub fn find_bisect_point<F>(
alamb marked this conversation as resolved.
Show resolved Hide resolved
item_columns: &[ArrayRef],
target: &[ScalarValue],
compare_fn: F,
mut low: usize,
mut high: usize,
) -> Result<usize>
where
F: Fn(&[ScalarValue], &[ScalarValue]) -> Result<bool>,
{
while low < high {
let mid = ((high - low) / 2) + low;
let val = item_columns
.iter()
.map(|arr| ScalarValue::try_from_array(arr, mid))
.collect::<Result<Vec<ScalarValue>>>()?;
let cmp = compare(&val, target, sort_options)?;
let flag = if SIDE { cmp.is_lt() } else { cmp.is_le() };
if flag {
if compare_fn(&val, target)? {
low = mid + 1;
} else {
high = mid;
Expand Down
8 changes: 1 addition & 7 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_expr::{WindowFrame, WindowFrameBound};
use datafusion_optimizer::utils::unalias;
use datafusion_physical_expr::expressions::Literal;
use datafusion_sql::utils::window_expr_common_partition_keys;
Expand Down Expand Up @@ -1457,12 +1457,6 @@ pub fn create_window_expr_with_name(
})
.collect::<Result<Vec<_>>>()?;
if let Some(ref window_frame) = window_frame {
if window_frame.units == WindowFrameUnits::Groups {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

return Err(DataFusionError::NotImplemented(
"Window frame definitions involving GROUPS are not supported yet"
.to_string(),
));
}
if !is_window_valid(window_frame) {
return Err(DataFusionError::Execution(format!(
"Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
Expand Down
167 changes: 157 additions & 10 deletions datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,24 +1189,171 @@ async fn window_frame_ranges_unbounded_preceding_err() -> Result<()> {
}

#[tokio::test]
async fn window_frame_groups_query() -> Result<()> {
async fn window_frame_groups_preceding_following_desc() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT
SUM(c4) OVER(ORDER BY c2 DESC GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING),
SUM(c3) OVER(ORDER BY c2 DESC GROUPS BETWEEN 10000 PRECEDING AND 10000 FOLLOWING),
COUNT(*) OVER(ORDER BY c2 DESC GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM aggregate_test_100
ORDER BY c9
LIMIT 5";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------------------+----------------------------+-----------------+",
"| SUM(aggregate_test_100.c4) | SUM(aggregate_test_100.c3) | COUNT(UInt8(1)) |",
"+----------------------------+----------------------------+-----------------+",
"| 52276 | 781 | 56 |",
"| 260620 | 781 | 63 |",
"| -28623 | 781 | 37 |",
"| 260620 | 781 | 63 |",
"| 260620 | 781 | 63 |",
"+----------------------------+----------------------------+-----------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn window_frame_groups_order_by_null_desc() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_null_cases_csv(&ctx).await?;
let sql = "SELECT
COUNT(c2) OVER (ORDER BY c1 DESC GROUPS BETWEEN 5 PRECEDING AND 3 FOLLOWING)
FROM null_cases
LIMIT 5";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------------+",
"| COUNT(null_cases.c2) |",
"+----------------------+",
"| 12 |",
"| 12 |",
"| 12 |",
"| 12 |",
"| 12 |",
"+----------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn window_frame_groups() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_null_cases_csv(&ctx).await?;
let sql = "SELECT
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as a,
SUM(c1) OVER (ORDER BY c3 DESC GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as b,
SUM(c1) OVER (ORDER BY c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as c,
SUM(c1) OVER (ORDER BY c3 DESC NULLS last GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as d,
SUM(c1) OVER (ORDER BY c3 DESC NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as e,
SUM(c1) OVER (ORDER BY c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as f,
SUM(c1) OVER (ORDER BY c3 GROUPS current row) as a1,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 9 PRECEDING AND 5 PRECEDING) as a2,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) as a3,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a4,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND current row) as a5,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as a6,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as a7,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 3 FOLLOWING AND UNBOUNDED FOLLOWING) as a8,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN current row AND UNBOUNDED FOLLOWING) as a9,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN current row AND 3 FOLLOWING) as a10,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 5 FOLLOWING AND 7 FOLLOWING) as a11,
SUM(c1) OVER (ORDER BY c3 DESC GROUPS current row) as a21,
SUM(c1) OVER (ORDER BY c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 5 PRECEDING) as a22,
SUM(c1) OVER (ORDER BY c3 DESC NULLS last GROUPS BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) as a23,
SUM(c1) OVER (ORDER BY c3 NULLS last GROUPS BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a24,
SUM(c1) OVER (ORDER BY c3 DESC NULLS first GROUPS BETWEEN UNBOUNDED PRECEDING AND current row) as a25
FROM null_cases
ORDER BY c3
LIMIT 10";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+------+------+------+------+-----+-----+-----+-----+------+-----+------+",
"| a | b | c | d | e | f | a1 | a2 | a3 | a4 | a5 | a6 | a7 | a8 | a9 | a10 | a11 | a21 | a22 | a23 | a24 | a25 |",
"+-----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+------+------+------+------+-----+-----+-----+-----+------+-----+------+",
"| 412 | 307 | 412 | 307 | 307 | 412 | | | | 412 | | 4627 | 4627 | 4531 | 4627 | 115 | 85 | | | 4487 | 412 | 4627 |",
"| 488 | 339 | 488 | 339 | 339 | 488 | 72 | | | 488 | 72 | 4627 | 4627 | 4512 | 4627 | 140 | 153 | 72 | | 4473 | 488 | 4627 |",
"| 543 | 412 | 543 | 412 | 412 | 543 | 24 | | | 543 | 96 | 4627 | 4627 | 4487 | 4555 | 82 | 122 | 24 | | 4442 | 543 | 4555 |",
"| 553 | 488 | 553 | 488 | 488 | 553 | 19 | | | 553 | 115 | 4627 | 4555 | 4473 | 4531 | 89 | 114 | 19 | | 4402 | 553 | 4531 |",
"| 553 | 543 | 553 | 543 | 543 | 553 | 25 | | | 553 | 140 | 4627 | 4531 | 4442 | 4512 | 110 | 105 | 25 | | 4320 | 553 | 4512 |",
"| 591 | 553 | 591 | 553 | 553 | 591 | 14 | | | 591 | 154 | 4627 | 4512 | 4402 | 4487 | 167 | 181 | 14 | | 4320 | 591 | 4487 |",
"| 651 | 553 | 651 | 553 | 553 | 651 | 31 | 72 | 72 | 651 | 185 | 4627 | 4487 | 4320 | 4473 | 153 | 204 | 31 | 72 | 4288 | 651 | 4473 |",
"| 662 | 591 | 662 | 591 | 591 | 662 | 40 | 96 | 96 | 662 | 225 | 4627 | 4473 | 4320 | 4442 | 154 | 141 | 40 | 96 | 4215 | 662 | 4442 |",
"| 697 | 651 | 697 | 651 | 651 | 697 | 82 | 115 | 115 | 697 | 307 | 4627 | 4442 | 4288 | 4402 | 187 | 65 | 82 | 115 | 4139 | 697 | 4402 |",
"| 758 | 662 | 758 | 662 | 662 | 758 | | 140 | 140 | 758 | 307 | 4627 | 4402 | 4215 | 4320 | 181 | 48 | | 140 | 4084 | 758 | 4320 |",
"+-----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+------+------+------+------+-----+-----+-----+-----+------+-----+------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn window_frame_groups_multiple_order_columns() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_null_cases_csv(&ctx).await?;
let sql = "SELECT
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as a,
SUM(c1) OVER (ORDER BY c2, c3 DESC GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as b,
SUM(c1) OVER (ORDER BY c2, c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as c,
SUM(c1) OVER (ORDER BY c2, c3 DESC NULLS last GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as d,
SUM(c1) OVER (ORDER BY c2, c3 DESC NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as e,
SUM(c1) OVER (ORDER BY c2, c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as f,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS current row) as a1,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 9 PRECEDING AND 5 PRECEDING) as a2,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) as a3,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a4,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND current row) as a5,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as a6,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as a7,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 3 FOLLOWING AND UNBOUNDED FOLLOWING) as a8,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN current row AND UNBOUNDED FOLLOWING) as a9,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN current row AND 3 FOLLOWING) as a10,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 5 FOLLOWING AND 7 FOLLOWING) as a11
FROM null_cases
ORDER BY c3
LIMIT 10";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------+-----+------+-----+-----+------+----+-----+------+------+------+------+------+------+------+-----+-----+",
"| a | b | c | d | e | f | a1 | a2 | a3 | a4 | a5 | a6 | a7 | a8 | a9 | a10 | a11 |",
"+------+-----+------+-----+-----+------+----+-----+------+------+------+------+------+------+------+-----+-----+",
"| 818 | 910 | 818 | 910 | 910 | 818 | | 249 | 249 | 818 | 432 | 4627 | 4234 | 4157 | 4195 | 98 | 82 |",
"| 537 | 979 | 537 | 979 | 979 | 537 | 72 | | | 537 | 210 | 4627 | 4569 | 4378 | 4489 | 169 | 55 |",
"| 811 | 838 | 811 | 838 | 838 | 811 | 24 | 221 | 3075 | 3665 | 3311 | 4627 | 1390 | 1276 | 1340 | 117 | 144 |",
"| 763 | 464 | 763 | 464 | 464 | 763 | 19 | 168 | 3572 | 4167 | 3684 | 4627 | 962 | 829 | 962 | 194 | 80 |",
"| 552 | 964 | 552 | 964 | 964 | 552 | 25 | | | 552 | 235 | 4627 | 4489 | 4320 | 4417 | 167 | 39 |",
"| 963 | 930 | 963 | 930 | 930 | 963 | 14 | 201 | 818 | 1580 | 1098 | 4627 | 3638 | 3455 | 3543 | 177 | 224 |",
"| 1113 | 814 | 1113 | 814 | 814 | 1113 | 31 | 415 | 2653 | 3351 | 2885 | 4627 | 1798 | 1694 | 1773 | 165 | 162 |",
"| 780 | 868 | 780 | 868 | 868 | 780 | 40 | 258 | 3143 | 3665 | 3351 | 4627 | 1340 | 1223 | 1316 | 117 | 102 |",
"| 740 | 466 | 740 | 466 | 466 | 740 | 82 | 164 | 3592 | 4168 | 3766 | 4627 | 962 | 768 | 943 | 244 | 122 |",
"| 772 | 832 | 772 | 832 | 832 | 772 | | 277 | 3189 | 3684 | 3351 | 4627 | 1316 | 1199 | 1276 | 119 | 64 |",
"+------+-----+------+-----+-----+------+----+-----+------+------+------+------+------+------+------+-----+-----+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn window_frame_groups_without_order_by() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
// execute the query
let df = ctx
.sql(
"SELECT
COUNT(c1) OVER (ORDER BY c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM aggregate_test_100;",
SUM(c4) OVER(PARTITION BY c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM aggregate_test_100
ORDER BY c9;",
)
.await?;
let results = df.collect().await;
assert!(results
.as_ref()
.err()
.unwrap()
.to_string()
.contains("Window frame definitions involving GROUPS are not supported yet"));
let err = df.collect().await.unwrap_err();
assert_contains!(
err.to_string(),
"Execution error: GROUPS mode requires an ORDER BY clause".to_owned()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

);
Ok(())
}

Expand Down
6 changes: 4 additions & 2 deletions datafusion/physical-expr/src/window/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ use datafusion_expr::WindowFrame;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};

use super::window_frame_state::WindowFrameContext;

/// A window expr that takes the form of an aggregate function
#[derive(Debug)]
pub struct AggregateWindowExpr {
Expand Down Expand Up @@ -114,13 +116,13 @@ impl WindowExpr for AggregateWindowExpr {
.map(|v| v.slice(partition_range.start, length))
.collect::<Vec<_>>();

let mut window_frame_ctx = WindowFrameContext::new(&window_frame);
let mut last_range: (usize, usize) = (0, 0);

// We iterate on each row to perform a running calculation.
// First, cur_range is calculated, then it is compared with last_range.
for i in 0..length {
let cur_range = self.calculate_range(
&window_frame,
let cur_range = window_frame_ctx.calculate_range(
&slice_order_bys,
&sort_options,
length,
Expand Down
5 changes: 3 additions & 2 deletions datafusion/physical-expr/src/window/built_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Physical exec for built-in window function expressions.

use super::window_frame_state::WindowFrameContext;
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
Expand Down Expand Up @@ -113,10 +114,10 @@ impl WindowExpr for BuiltInWindowExpr {
.iter()
.map(|v| v.slice(partition_range.start, length))
.collect::<Vec<_>>();
let mut window_frame_ctx = WindowFrameContext::new(&window_frame);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very nice encapsulation of the window frame calculation. Thank you

// We iterate on each row to calculate window frame range and and window function result
for idx in 0..length {
let range = self.calculate_range(
&window_frame,
let range = window_frame_ctx.calculate_range(
&slice_order_bys,
&sort_options,
num_rows,
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) mod partition_evaluator;
pub(crate) mod rank;
pub(crate) mod row_number;
mod window_expr;
mod window_frame_state;

pub use aggregate::AggregateWindowExpr;
pub use built_in::BuiltInWindowExpr;
Expand Down
Loading