-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
First and Last Accumulators should update with state row excluding is_set flag #7565
Changes from all commits
9128f50
ecbf7eb
6c5c357
e4b1285
10f70f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -165,8 +165,6 @@ struct FirstValueAccumulator { | |
orderings: Vec<ScalarValue>, | ||
// Stores the applicable ordering requirement. | ||
ordering_req: LexOrdering, | ||
// Whether merge_batch() is called before | ||
is_merge_called: bool, | ||
} | ||
|
||
impl FirstValueAccumulator { | ||
|
@@ -185,7 +183,6 @@ impl FirstValueAccumulator { | |
is_set: false, | ||
orderings, | ||
ordering_req, | ||
is_merge_called: false, | ||
}) | ||
} | ||
|
||
|
@@ -201,9 +198,7 @@ impl Accumulator for FirstValueAccumulator { | |
fn state(&self) -> Result<Vec<ScalarValue>> { | ||
let mut result = vec![self.first.clone()]; | ||
result.extend(self.orderings.iter().cloned()); | ||
if !self.is_merge_called { | ||
result.push(ScalarValue::Boolean(Some(self.is_set))); | ||
} | ||
result.push(ScalarValue::Boolean(Some(self.is_set))); | ||
Ok(result) | ||
} | ||
|
||
|
@@ -218,7 +213,6 @@ impl Accumulator for FirstValueAccumulator { | |
} | ||
|
||
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { | ||
self.is_merge_called = true; | ||
// FIRST_VALUE(first1, first2, first3, ...) | ||
// last index contains is_set flag. | ||
let is_set_idx = states.len() - 1; | ||
|
@@ -237,13 +231,17 @@ impl Accumulator for FirstValueAccumulator { | |
}; | ||
if !ordered_states[0].is_empty() { | ||
let first_row = get_row_at_idx(&ordered_states, 0)?; | ||
let first_ordering = &first_row[1..]; | ||
// When collecting orderings, we exclude the is_set flag from the state. | ||
let first_ordering = &first_row[1..is_set_idx]; | ||
let sort_options = get_sort_options(&self.ordering_req); | ||
// Either there is no existing value, or there is an earlier version in new data. | ||
if !self.is_set | ||
|| compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am wondering if we need to do anything about […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is not shorter. Actually, this fix makes them the same length. Previously […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually I found the only time […] |
||
{ | ||
self.update_with_new_row(&first_row); | ||
// Update with first value in the state. Note that we should exclude the | ||
// is_set flag from the state. Otherwise, we will end up with a state | ||
// containing two is_set flags. | ||
self.update_with_new_row(&first_row[0..is_set_idx]); | ||
} | ||
} | ||
Ok(()) | ||
|
@@ -390,8 +388,6 @@ struct LastValueAccumulator { | |
orderings: Vec<ScalarValue>, | ||
// Stores the applicable ordering requirement. | ||
ordering_req: LexOrdering, | ||
// Whether merge_batch() is called before | ||
is_merge_called: bool, | ||
} | ||
|
||
impl LastValueAccumulator { | ||
|
@@ -410,7 +406,6 @@ impl LastValueAccumulator { | |
is_set: false, | ||
orderings, | ||
ordering_req, | ||
is_merge_called: false, | ||
}) | ||
} | ||
|
||
|
@@ -426,9 +421,7 @@ impl Accumulator for LastValueAccumulator { | |
fn state(&self) -> Result<Vec<ScalarValue>> { | ||
let mut result = vec![self.last.clone()]; | ||
result.extend(self.orderings.clone()); | ||
if !self.is_merge_called { | ||
result.push(ScalarValue::Boolean(Some(self.is_set))); | ||
} | ||
result.push(ScalarValue::Boolean(Some(self.is_set))); | ||
Ok(result) | ||
} | ||
|
||
|
@@ -442,7 +435,6 @@ impl Accumulator for LastValueAccumulator { | |
} | ||
|
||
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { | ||
self.is_merge_called = true; | ||
// LAST_VALUE(last1, last2, last3, ...) | ||
// last index contains is_set flag. | ||
let is_set_idx = states.len() - 1; | ||
|
@@ -463,14 +455,18 @@ impl Accumulator for LastValueAccumulator { | |
if !ordered_states[0].is_empty() { | ||
let last_idx = ordered_states[0].len() - 1; | ||
let last_row = get_row_at_idx(&ordered_states, last_idx)?; | ||
let last_ordering = &last_row[1..]; | ||
// When collecting orderings, we exclude the is_set flag from the state. | ||
let last_ordering = &last_row[1..is_set_idx]; | ||
let sort_options = get_sort_options(&self.ordering_req); | ||
// Either there is no existing value, or there is a newer (latest) | ||
// version in the new data: | ||
if !self.is_set | ||
|| compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt() | ||
{ | ||
self.update_with_new_row(&last_row); | ||
// Update with last value in the state. Note that we should exclude the | ||
// is_set flag from the state. Otherwise, we will end up with a state | ||
// containing two is_set flags. | ||
self.update_with_new_row(&last_row[0..is_set_idx]); | ||
} | ||
} | ||
Ok(()) | ||
|
@@ -531,6 +527,7 @@ mod tests { | |
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_expr::Accumulator; | ||
|
||
use arrow::compute::concat; | ||
use std::sync::Arc; | ||
|
||
#[test] | ||
|
@@ -562,4 +559,72 @@ mod tests { | |
assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12))); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_first_last_state_after_merge() -> Result<()> { | ||
let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; | ||
// create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12 | ||
let arrs = ranges | ||
.into_iter() | ||
.map(|(start, end)| { | ||
Arc::new((start..end).collect::<Int64Array>()) as ArrayRef | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
// FirstValueAccumulator | ||
let mut first_accumulator = | ||
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; | ||
|
||
first_accumulator.update_batch(&[arrs[0].clone()])?; | ||
let state1 = first_accumulator.state()?; | ||
|
||
let mut first_accumulator = | ||
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; | ||
first_accumulator.update_batch(&[arrs[1].clone()])?; | ||
let state2 = first_accumulator.state()?; | ||
|
||
assert_eq!(state1.len(), state2.len()); | ||
|
||
let mut states = vec![]; | ||
|
||
for idx in 0..state1.len() { | ||
states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); | ||
} | ||
|
||
let mut first_accumulator = | ||
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; | ||
first_accumulator.merge_batch(&states)?; | ||
|
||
let merged_state = first_accumulator.state()?; | ||
assert_eq!(merged_state.len(), state1.len()); | ||
|
||
// LastValueAccumulator | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can probably factor out some common code here, but it is not a blocker. |
||
let mut last_accumulator = | ||
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; | ||
|
||
last_accumulator.update_batch(&[arrs[0].clone()])?; | ||
let state1 = last_accumulator.state()?; | ||
|
||
let mut last_accumulator = | ||
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; | ||
last_accumulator.update_batch(&[arrs[1].clone()])?; | ||
let state2 = last_accumulator.state()?; | ||
|
||
assert_eq!(state1.len(), state2.len()); | ||
|
||
let mut states = vec![]; | ||
|
||
for idx in 0..state1.len() { | ||
states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); | ||
} | ||
|
||
let mut last_accumulator = | ||
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; | ||
last_accumulator.merge_batch(&states)?; | ||
|
||
let merged_state = last_accumulator.state()?; | ||
assert_eq!(merged_state.len(), state1.len()); | ||
|
||
Ok(()) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`is_set_idx` is an index for `states`, but does `first_row` have the same index? Can they be different because `filtered_states` may have filtered rows?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please never mind, this should be the same
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't `filtered_states` come from `states`? Why would they have different indices? They are simply partial aggregation inputs. Do you think there are different rows in partial aggregation inputs?