Skip to content

Commit

Permalink
Simplify stream agg
Browse files Browse the repository at this point in the history
  • Loading branch information
ice1000 committed Apr 2, 2023
1 parent d63532f commit cb7c382
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 89 deletions.
16 changes: 16 additions & 0 deletions src/frontend/src/optimizer/plan_node/generic/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,22 @@ pub struct MaterializedInputState {
}

impl<PlanRef: stream::StreamPlanRef> Agg<PlanRef> {
pub fn infer_tables(
&self,
me: &impl stream::StreamPlanRef,
vnode_col_idx: Option<usize>,
) -> (
TableCatalog,
Vec<AggCallState>,
HashMap<usize, TableCatalog>,
) {
(
self.infer_result_table(me, vnode_col_idx),
self.infer_stream_agg_state(me, vnode_col_idx),
self.infer_distinct_dedup_tables(me, vnode_col_idx),
)
}

/// Infer `AggCallState`s for streaming agg.
pub fn infer_stream_agg_state(
&self,
Expand Down
46 changes: 22 additions & 24 deletions src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl LogicalAgg {
let local_agg = StreamLocalSimpleAgg::new(self.clone_with_input(stream_input));
let exchange =
RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
let global_agg = new_stream_global_simple_agg(LogicalAgg::new(
let global_agg = new_stream_global_simple_agg(generic::Agg::new(
self.agg_calls()
.iter()
.enumerate()
Expand Down Expand Up @@ -129,7 +129,7 @@ impl LogicalAgg {
local_group_key.push(vnode_col_idx);
let n_local_group_key = local_group_key.len();
let local_agg = new_stream_hash_agg(
LogicalAgg::new(self.agg_calls().to_vec(), local_group_key, project.into()),
generic::Agg::new(self.agg_calls().to_vec(), local_group_key, project.into()),
Some(vnode_col_idx),
);
// Global group key excludes vnode.
Expand All @@ -144,7 +144,7 @@ impl LogicalAgg {
if self.group_key().is_empty() {
let exchange =
RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
let global_agg = new_stream_global_simple_agg(LogicalAgg::new(
let global_agg = new_stream_global_simple_agg(generic::Agg::new(
self.agg_calls()
.iter()
.enumerate()
Expand All @@ -162,7 +162,7 @@ impl LogicalAgg {
// Local phase should have reordered the group keys into their required order.
// we can just follow it.
let global_agg = new_stream_hash_agg(
LogicalAgg::new(
generic::Agg::new(
self.agg_calls()
.iter()
.enumerate()
Expand All @@ -181,21 +181,18 @@ impl LogicalAgg {
}

fn gen_single_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
Ok(new_stream_global_simple_agg(self.clone_with_input(
RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?,
))
.into())
let mut logical = self.core.clone();
let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?;
logical.input = input;
Ok(new_stream_global_simple_agg(logical).into())
}

fn gen_shuffle_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
Ok(new_stream_hash_agg(
self.clone_with_input(
RequiredDist::shard_by_key(stream_input.schema().len(), self.group_key())
.enforce_if_not_satisfies(stream_input, &Order::any())?,
),
None,
)
.into())
let input = RequiredDist::shard_by_key(stream_input.schema().len(), self.group_key())
.enforce_if_not_satisfies(stream_input, &Order::any())?;
let mut logical = self.core.clone();
logical.input = input;
Ok(new_stream_hash_agg(logical, None).into())
}

/// See if all stream aggregation calls have a stateless local agg counterpart.
Expand Down Expand Up @@ -1164,32 +1161,33 @@ impl ToBatch for LogicalAgg {
}
}

fn find_or_append_row_count(mut logical: LogicalAgg) -> (LogicalAgg, usize) {
fn find_or_append_row_count(mut logical: generic::Agg<PlanRef>) -> (generic::Agg<PlanRef>, usize) {
// `HashAgg`/`GlobalSimpleAgg` executors require a `count(*)` to correctly build changes, so
// append a `count(*)` if not exists.
let count_star = PlanAggCall::count_star();
let row_count_idx = if let Some((idx, _)) = logical
.agg_calls()
.agg_calls
.iter()
.find_position(|&c| c == &count_star)
{
idx
} else {
let (mut agg_calls, group_key, input) = logical.decompose();
let idx = agg_calls.len();
agg_calls.push(count_star);
logical = LogicalAgg::new(agg_calls, group_key, input);
let idx = logical.agg_calls.len();
logical.agg_calls.push(count_star);
idx
};
(logical, row_count_idx)
}

fn new_stream_global_simple_agg(logical: LogicalAgg) -> StreamGlobalSimpleAgg {
fn new_stream_global_simple_agg(logical: generic::Agg<PlanRef>) -> StreamGlobalSimpleAgg {
let (logical, row_count_idx) = find_or_append_row_count(logical);
StreamGlobalSimpleAgg::new(logical, row_count_idx)
}

fn new_stream_hash_agg(logical: LogicalAgg, vnode_col_idx: Option<usize>) -> StreamHashAgg {
fn new_stream_hash_agg(
logical: generic::Agg<PlanRef>,
vnode_col_idx: Option<usize>,
) -> StreamHashAgg {
let (logical, row_count_idx) = find_or_append_row_count(logical);
StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
}
Expand Down
53 changes: 24 additions & 29 deletions src/frontend/src/optimizer/plan_node/stream_global_simple_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,30 @@ use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_pb::stream_plan::stream_node::PbNodeBody;

use super::generic::PlanAggCall;
use super::{ExprRewritable, LogicalAgg, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode};
use super::generic::{self, PlanAggCall};
use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode};
use crate::expr::ExprRewriter;
use crate::optimizer::property::Distribution;
use crate::stream_fragmenter::BuildFragmentGraphState;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StreamGlobalSimpleAgg {
pub base: PlanBase,
logical: LogicalAgg,
logical: generic::Agg<PlanRef>,

/// The index of `count(*)` in `agg_calls`.
row_count_idx: usize,
}

impl StreamGlobalSimpleAgg {
pub fn new(logical: LogicalAgg, row_count_idx: usize) -> Self {
assert_eq!(
logical.agg_calls()[row_count_idx],
PlanAggCall::count_star()
);

let ctx = logical.base.ctx.clone();
let pk_indices = logical.base.logical_pk.to_vec();
let schema = logical.schema().clone();
let input = logical.input();
pub fn new(logical: generic::Agg<PlanRef>, row_count_idx: usize) -> Self {
assert_eq!(logical.agg_calls[row_count_idx], PlanAggCall::count_star());

let base = PlanBase::new_logical_with_core(&logical);
let ctx = base.ctx;
let pk_indices = base.logical_pk;
let schema = base.schema;
let input = logical.input.clone();
let input_dist = input.distribution();
let dist = match input_dist {
Distribution::Single => Distribution::Single,
Expand All @@ -59,7 +57,7 @@ impl StreamGlobalSimpleAgg {
ctx,
schema,
pk_indices,
logical.functional_dependency().clone(),
base.functional_dependency,
dist,
false,
watermark_columns,
Expand All @@ -72,7 +70,7 @@ impl StreamGlobalSimpleAgg {
}

pub fn agg_calls(&self) -> &[PlanAggCall] {
self.logical.agg_calls()
&self.logical.agg_calls
}
}

Expand All @@ -91,21 +89,24 @@ impl fmt::Display for StreamGlobalSimpleAgg {

impl PlanTreeNodeUnary for StreamGlobalSimpleAgg {
fn input(&self) -> PlanRef {
self.logical.input()
self.logical.input.clone()
}

fn clone_with_input(&self, input: PlanRef) -> Self {
Self::new(self.logical.clone_with_input(input), self.row_count_idx)
let logical = generic::Agg {
input,
..self.logical.clone()
};
Self::new(logical, self.row_count_idx)
}
}
impl_plan_tree_node_for_unary! { StreamGlobalSimpleAgg }

impl StreamNode for StreamGlobalSimpleAgg {
fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody {
use risingwave_pb::stream_plan::*;
let result_table = self.logical.infer_result_table(None);
let agg_states = self.logical.infer_stream_agg_state(None);
let distinct_dedup_tables = self.logical.infer_distinct_dedup_tables(None);
let (result_table, agg_states, distinct_dedup_tables) =
self.logical.infer_tables(&self.base, None);

PbNodeBody::GlobalSimpleAgg(SimpleAggNode {
agg_calls: self
Expand Down Expand Up @@ -153,14 +154,8 @@ impl ExprRewritable for StreamGlobalSimpleAgg {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
Self::new(
self.logical
.rewrite_exprs(r)
.as_logical_agg()
.unwrap()
.clone(),
self.row_count_idx,
)
.into()
let mut logical = self.logical.clone();
logical.rewrite_exprs(r);
Self::new(logical, self.row_count_idx).into()
}
}
66 changes: 30 additions & 36 deletions src/frontend/src/optimizer/plan_node/stream_hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use itertools::Itertools;
use risingwave_common::catalog::FieldDisplay;
use risingwave_pb::stream_plan::stream_node::PbNodeBody;

use super::generic::PlanAggCall;
use super::{ExprRewritable, LogicalAgg, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode};
use super::generic::{self, PlanAggCall};
use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode};
use crate::expr::ExprRewriter;
use crate::optimizer::property::Distribution;
use crate::stream_fragmenter::BuildFragmentGraphState;
Expand All @@ -29,7 +29,7 @@ use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StreamHashAgg {
pub base: PlanBase,
logical: LogicalAgg,
logical: generic::Agg<PlanRef>,

/// An optional column index which is the vnode of each row computed by the input's consistent
/// hash distribution.
Expand All @@ -40,16 +40,18 @@ pub struct StreamHashAgg {
}

impl StreamHashAgg {
pub fn new(logical: LogicalAgg, vnode_col_idx: Option<usize>, row_count_idx: usize) -> Self {
assert_eq!(
logical.agg_calls()[row_count_idx],
PlanAggCall::count_star()
);

let ctx = logical.base.ctx.clone();
let pk_indices = logical.base.logical_pk.to_vec();
let schema = logical.schema().clone();
let input = logical.input();
pub fn new(
logical: generic::Agg<PlanRef>,
vnode_col_idx: Option<usize>,
row_count_idx: usize,
) -> Self {
assert_eq!(logical.agg_calls[row_count_idx], PlanAggCall::count_star());

let base = PlanBase::new_logical_with_core(&logical);
let ctx = base.ctx;
let pk_indices = base.logical_pk;
let schema = base.schema;
let input = logical.input.clone();
let input_dist = input.distribution();
let dist = match input_dist {
Distribution::HashShard(_) | Distribution::UpstreamHashShard(_, _) => logical
Expand All @@ -60,7 +62,7 @@ impl StreamHashAgg {

let mut watermark_columns = FixedBitSet::with_capacity(schema.len());
// Watermark column(s) must be in group key.
for (idx, input_idx) in logical.group_key().iter().enumerate() {
for (idx, input_idx) in logical.group_key.iter().enumerate() {
if input.watermark_columns().contains(*input_idx) {
watermark_columns.insert(idx);
}
Expand All @@ -71,7 +73,7 @@ impl StreamHashAgg {
ctx,
schema,
pk_indices,
logical.functional_dependency().clone(),
base.functional_dependency,
dist,
false,
watermark_columns,
Expand All @@ -85,11 +87,11 @@ impl StreamHashAgg {
}

pub fn agg_calls(&self) -> &[PlanAggCall] {
self.logical.agg_calls()
&self.logical.agg_calls
}

pub fn group_key(&self) -> &[usize] {
self.logical.group_key()
&self.logical.group_key
}

pub(crate) fn i2o_col_mapping(&self) -> ColIndexMapping {
Expand Down Expand Up @@ -124,25 +126,24 @@ impl fmt::Display for StreamHashAgg {

impl PlanTreeNodeUnary for StreamHashAgg {
fn input(&self) -> PlanRef {
self.logical.input()
self.logical.input.clone()
}

fn clone_with_input(&self, input: PlanRef) -> Self {
Self::new(
self.logical.clone_with_input(input),
self.vnode_col_idx,
self.row_count_idx,
)
let logical = generic::Agg {
input,
..self.logical.clone()
};
Self::new(logical, self.vnode_col_idx, self.row_count_idx)
}
}
impl_plan_tree_node_for_unary! { StreamHashAgg }

impl StreamNode for StreamHashAgg {
fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody {
use risingwave_pb::stream_plan::*;
let result_table = self.logical.infer_result_table(self.vnode_col_idx);
let agg_states = self.logical.infer_stream_agg_state(self.vnode_col_idx);
let distinct_dedup_tables = self.logical.infer_distinct_dedup_tables(self.vnode_col_idx);
let (result_table, agg_states, distinct_dedup_tables) =
self.logical.infer_tables(&self.base, self.vnode_col_idx);

PbNodeBody::HashAgg(HashAggNode {
group_key: self.group_key().iter().map(|idx| *idx as u32).collect(),
Expand Down Expand Up @@ -185,15 +186,8 @@ impl ExprRewritable for StreamHashAgg {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
Self::new(
self.logical
.rewrite_exprs(r)
.as_logical_agg()
.unwrap()
.clone(),
self.vnode_col_idx,
self.row_count_idx,
)
.into()
let mut logical = self.logical.clone();
logical.rewrite_exprs(r);
Self::new(logical, self.vnode_col_idx, self.row_count_idx).into()
}
}

0 comments on commit cb7c382

Please sign in to comment.