Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First and Last Accumulators should update with state row excluding is_set flag #7565

Merged
merged 5 commits into from
Sep 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 83 additions & 18 deletions datafusion/physical-expr/src/aggregate/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ struct FirstValueAccumulator {
orderings: Vec<ScalarValue>,
// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
// Whether merge_batch() is called before
is_merge_called: bool,
}

impl FirstValueAccumulator {
Expand All @@ -185,7 +183,6 @@ impl FirstValueAccumulator {
is_set: false,
orderings,
ordering_req,
is_merge_called: false,
})
}

Expand All @@ -201,9 +198,7 @@ impl Accumulator for FirstValueAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.first.clone()];
result.extend(self.orderings.iter().cloned());
if !self.is_merge_called {
result.push(ScalarValue::Boolean(Some(self.is_set)));
}
result.push(ScalarValue::Boolean(Some(self.is_set)));
Ok(result)
}

Expand All @@ -218,7 +213,6 @@ impl Accumulator for FirstValueAccumulator {
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.is_merge_called = true;
// FIRST_VALUE(first1, first2, first3, ...)
// last index contains is_set flag.
let is_set_idx = states.len() - 1;
Expand All @@ -237,13 +231,17 @@ impl Accumulator for FirstValueAccumulator {
};
if !ordered_states[0].is_empty() {
let first_row = get_row_at_idx(&ordered_states, 0)?;
let first_ordering = &first_row[1..];
// When collecting orderings, we exclude the is_set flag from the state.
let first_ordering = &first_row[1..is_set_idx];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_set_idx is an index for states but does first_row have the same index? Can they be different because filtered_states may have filtered rows?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please never mind, this should be the same

Copy link
Member Author

@viirya viirya Sep 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't filtered_states come from states? Why they have different index?? They are simply partial aggregation inputs. Do you think there are different rows in partial aggregation inputs?

let sort_options = get_sort_options(&self.ordering_req);
// Either there is no existing value, or there is an earlier version in new data.
if !self.is_set
|| compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if we need to do anything about compare_rows(first_ordering, &self.orderings, &sort_options) since now first_ordering is shorter?

Copy link
Member Author

@viirya viirya Sep 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not shorter. Actually, this fix makes them same length. Previously first_ordering has more one element (is_set) but it is not in orderings.

Copy link
Contributor

@kazuyukitanimura kazuyukitanimura Sep 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I found the only time self.orderings can be longer is right after the creation by try_new(). But at the time it is self.is_set = false, so it will not hit compare_rows. Then self.orderings is updated by first_ordering. So I agree that this should work.

{
self.update_with_new_row(&first_row);
// Update with first value in the state. Note that we should exclude the
// is_set flag from the state. Otherwise, we will end up with a state
// containing two is_set flags.
self.update_with_new_row(&first_row[0..is_set_idx]);
}
}
Ok(())
Expand Down Expand Up @@ -390,8 +388,6 @@ struct LastValueAccumulator {
orderings: Vec<ScalarValue>,
// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
// Whether merge_batch() is called before
is_merge_called: bool,
}

impl LastValueAccumulator {
Expand All @@ -410,7 +406,6 @@ impl LastValueAccumulator {
is_set: false,
orderings,
ordering_req,
is_merge_called: false,
})
}

Expand All @@ -426,9 +421,7 @@ impl Accumulator for LastValueAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.last.clone()];
result.extend(self.orderings.clone());
if !self.is_merge_called {
result.push(ScalarValue::Boolean(Some(self.is_set)));
}
result.push(ScalarValue::Boolean(Some(self.is_set)));
Ok(result)
}

Expand All @@ -442,7 +435,6 @@ impl Accumulator for LastValueAccumulator {
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.is_merge_called = true;
// LAST_VALUE(last1, last2, last3, ...)
// last index contains is_set flag.
let is_set_idx = states.len() - 1;
Expand All @@ -463,14 +455,18 @@ impl Accumulator for LastValueAccumulator {
if !ordered_states[0].is_empty() {
let last_idx = ordered_states[0].len() - 1;
let last_row = get_row_at_idx(&ordered_states, last_idx)?;
let last_ordering = &last_row[1..];
// When collecting orderings, we exclude the is_set flag from the state.
let last_ordering = &last_row[1..is_set_idx];
let sort_options = get_sort_options(&self.ordering_req);
// Either there is no existing value, or there is a newer (latest)
// version in the new data:
if !self.is_set
|| compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt()
{
self.update_with_new_row(&last_row);
// Update with last value in the state. Note that we should exclude the
// is_set flag from the state. Otherwise, we will end up with a state
// containing two is_set flags.
self.update_with_new_row(&last_row[0..is_set_idx]);
}
}
Ok(())
Expand Down Expand Up @@ -531,6 +527,7 @@ mod tests {
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

use arrow::compute::concat;
use std::sync::Arc;

#[test]
Expand Down Expand Up @@ -562,4 +559,72 @@ mod tests {
assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
Ok(())
}

#[test]
fn test_first_last_state_after_merge() -> Result<()> {
let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
// create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
let arrs = ranges
.into_iter()
.map(|(start, end)| {
Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
})
.collect::<Vec<_>>();

// FirstValueAccumulator
let mut first_accumulator =
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;

first_accumulator.update_batch(&[arrs[0].clone()])?;
let state1 = first_accumulator.state()?;

let mut first_accumulator =
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
first_accumulator.update_batch(&[arrs[1].clone()])?;
let state2 = first_accumulator.state()?;

assert_eq!(state1.len(), state2.len());

let mut states = vec![];

for idx in 0..state1.len() {
states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?);
}

let mut first_accumulator =
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
first_accumulator.merge_batch(&states)?;

let merged_state = first_accumulator.state()?;
assert_eq!(merged_state.len(), state1.len());

// LastValueAccumulator
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably can remove some common code here but not a blocker

let mut last_accumulator =
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;

last_accumulator.update_batch(&[arrs[0].clone()])?;
let state1 = last_accumulator.state()?;

let mut last_accumulator =
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
last_accumulator.update_batch(&[arrs[1].clone()])?;
let state2 = last_accumulator.state()?;

assert_eq!(state1.len(), state2.len());

let mut states = vec![];

for idx in 0..state1.len() {
states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?);
}

let mut last_accumulator =
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
last_accumulator.merge_batch(&states)?;

let merged_state = last_accumulator.state()?;
assert_eq!(merged_state.len(), state1.len());

Ok(())
}
}