Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(plan_node): Use simple structures #8870

Merged
merged 14 commits into from
Mar 30, 2023
53 changes: 30 additions & 23 deletions src/frontend/src/optimizer/plan_node/batch_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::HashJoinNode;
use risingwave_pb::plan_common::JoinType;

use super::generic::GenericPlanRef;
use super::generic::{self, GenericPlanRef};
use super::{
EqJoinPredicate, ExprRewritable, LogicalJoin, PlanBase, PlanRef, PlanTreeNodeBinary, ToBatchPb,
EqJoinPredicate, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, ToBatchPb,
ToDistributedBatch,
};
use crate::expr::{Expr, ExprRewriter};
Expand All @@ -37,22 +37,26 @@ use crate::utils::ColIndexMappingRewriteExt;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchHashJoin {
pub base: PlanBase,
logical: LogicalJoin,
logical: generic::Join<PlanRef>,

/// The join condition must be equivalent to `logical.on`, but separated into equal and
/// non-equal parts to facilitate execution later
eq_join_predicate: EqJoinPredicate,
}

impl BatchHashJoin {
pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self {
let ctx = logical.base.ctx.clone();
pub fn new(
base: PlanBase,
logical: generic::Join<PlanRef>,
eq_join_predicate: EqJoinPredicate,
) -> Self {
let ctx = base.ctx.clone();
let dist = Self::derive_dist(
logical.left().distribution(),
logical.right().distribution(),
logical.left.distribution(),
logical.right.distribution(),
&logical,
);
let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any());
let base = PlanBase::new_batch(ctx, base.schema().clone(), dist, Order::any());

Self {
base,
Expand All @@ -64,13 +68,13 @@ impl BatchHashJoin {
pub(super) fn derive_dist(
left: &Distribution,
right: &Distribution,
logical: &LogicalJoin,
logical: &generic::Join<PlanRef>,
) -> Distribution {
match (left, right) {
(Distribution::Single, Distribution::Single) => Distribution::Single,
// We cannot derive the hash distribution from the side where an outer join can
// generate a NULL row.
(Distribution::HashShard(_), Distribution::HashShard(_)) => match logical.join_type() {
(Distribution::HashShard(_), Distribution::HashShard(_)) => match logical.join_type {
JoinType::Unspecified => unreachable!(),
JoinType::FullOuter => Distribution::SomeShard,
JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => {
Expand Down Expand Up @@ -103,7 +107,7 @@ impl fmt::Display for BatchHashJoin {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let verbose = self.base.ctx.is_explain_verbose();
let mut builder = f.debug_struct("BatchHashJoin");
builder.field("type", &self.logical.join_type());
builder.field("type", &self.logical.join_type);

let mut concat_schema = self.left().schema().fields.clone();
concat_schema.extend(self.right().schema().fields.clone());
Expand All @@ -119,7 +123,7 @@ impl fmt::Display for BatchHashJoin {
if verbose {
if self
.logical
.output_indices()
.output_indices
.iter()
.copied()
.eq(0..self.logical.internal_column_num())
Expand All @@ -129,7 +133,7 @@ impl fmt::Display for BatchHashJoin {
builder.field(
"output",
&IndicesDisplay {
indices: self.logical.output_indices(),
indices: &self.logical.output_indices,
input_schema: &concat_schema,
},
);
Expand All @@ -142,16 +146,20 @@ impl fmt::Display for BatchHashJoin {

impl PlanTreeNodeBinary for BatchHashJoin {
fn left(&self) -> PlanRef {
self.logical.left()
self.logical.left.clone()
}

fn right(&self) -> PlanRef {
self.logical.right()
self.logical.right.clone()
}

fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
let mut logical = self.logical.clone();
logical.left = left;
logical.right = right;
Self::new(
self.logical.clone_with_left_right(left, right),
self.base.clone_with_new_plan_id(),
logical,
self.eq_join_predicate.clone(),
)
}
Expand Down Expand Up @@ -221,7 +229,7 @@ impl ToDistributedBatch for BatchHashJoin {
impl ToBatchPb for BatchHashJoin {
fn to_batch_prost_body(&self) -> NodeBody {
NodeBody::HashJoin(HashJoinNode {
join_type: self.logical.join_type() as i32,
join_type: self.logical.join_type as i32,
left_key: self
.eq_join_predicate
.left_eq_indexes()
Expand All @@ -242,7 +250,7 @@ impl ToBatchPb for BatchHashJoin {
.map(|x| x.to_expr_proto()),
output_indices: self
.logical
.output_indices()
.output_indices
.iter()
.map(|&x| x as u32)
.collect(),
Expand All @@ -267,12 +275,11 @@ impl ExprRewritable for BatchHashJoin {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
let mut logical = self.logical.clone();
logical.rewrite_exprs(r);
Self::new(
self.logical
.rewrite_exprs(r)
.as_logical_join()
.unwrap()
.clone(),
self.base.clone_with_new_plan_id(),
logical,
self.eq_join_predicate.rewrite_exprs(r),
)
.into()
Expand Down
12 changes: 12 additions & 0 deletions src/frontend/src/optimizer/plan_node/generic/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,18 @@ impl<PlanRef: GenericPlanRef> Join<PlanRef> {
self.i2r_col_mapping().inverse()
}

/// Builds the column-index mapping from internal column indices to output column indices.
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
    let internal_len = self.internal_column_num();
    ColIndexMapping::with_remaining_columns(&self.output_indices, internal_len)
}

/// Builds the column-index mapping from output column indices to internal column indices.
pub fn o2i_col_mapping(&self) -> ColIndexMapping {
    // `output_indices` may contain duplicates (e.g. `[0, 0, 1]`), so it is used directly as
    // the mapping; going through `self.i2o_col_mapping().inverse()` would lose the first 0.
    let targets = self.output_indices.iter().copied().map(Some).collect();
    ColIndexMapping::new(targets)
}

pub fn add_which_join_key_to_pk(&self) -> EitherOrBoth<(), ()> {
match self.join_type {
JoinType::Inner => {
Expand Down
27 changes: 9 additions & 18 deletions src/frontend/src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,9 @@ impl LogicalJoin {
self.core.l2i_col_mapping()
}

/// Get the Mapping of columnIndex from right column index to internal column index.
pub fn r2i_col_mapping(&self) -> ColIndexMapping {
self.core.r2i_col_mapping()
}

/// get the Mapping of columnIndex from internal column index to output column index
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
ColIndexMapping::with_remaining_columns(self.output_indices(), self.internal_column_num())
}

/// get the Mapping of columnIndex from output column index to internal column index
pub fn o2i_col_mapping(&self) -> ColIndexMapping {
// If output_indices = [0, 0, 1], we should use it as `o2i_col_mapping` directly.
// If we use `self.i2o_col_mapping().inverse()`, we will lose the first 0.
ColIndexMapping::new(self.output_indices().iter().map(|x| Some(*x)).collect())
self.core.i2o_col_mapping()
}

/// Get a reference to the logical join's on.
Expand Down Expand Up @@ -570,7 +558,7 @@ impl PlanTreeNodeBinary for LogicalJoin {
join.internal_column_num(),
);

let old_o2i = self.o2i_col_mapping();
let old_o2i = self.core.o2i_col_mapping();

let old_o2l = old_o2i
.composite(&self.core.i2l_col_mapping())
Expand Down Expand Up @@ -812,7 +800,7 @@ impl PredicatePushdown for LogicalJoin {
let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());

// rewrite output col referencing indices as internal cols
let mut mapping = self.o2i_col_mapping();
let mut mapping = self.core.o2i_col_mapping();

predicate = predicate.rewrite_expr(&mut mapping);

Expand Down Expand Up @@ -959,7 +947,8 @@ impl LogicalJoin {
self.right().schema().len(),
);
let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
let hash_join = StreamHashJoin::new(logical_join, eq_cond).into();
let hash_join =
StreamHashJoin::new(logical_join.base, logical_join.core, eq_cond).into();
let logical_filter = LogicalFilter::new(hash_join, predicate.non_eq_cond());
let plan = StreamFilter::new(logical_filter).into();
if self.output_indices() != &default_indices {
Expand All @@ -975,7 +964,7 @@ impl LogicalJoin {
Ok(plan)
}
} else {
Ok(StreamHashJoin::new(logical_join, predicate).into())
Ok(StreamHashJoin::new(logical_join.base, logical_join.core, predicate).into())
}
}

Expand Down Expand Up @@ -1220,7 +1209,7 @@ impl LogicalJoin {
logical_join: LogicalJoin,
) -> Result<PlanRef> {
assert!(predicate.has_eq());
Ok(BatchHashJoin::new(logical_join, predicate).into())
Ok(BatchHashJoin::new(logical_join.base, logical_join.core, predicate).into())
}

pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
Expand Down Expand Up @@ -1410,9 +1399,11 @@ impl ToStream for LogicalJoin {
// ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information

let l2o = join_with_pk
.core
.l2i_col_mapping()
.composite(&join_with_pk.i2o_col_mapping());
let r2o = join_with_pk
.core
.r2i_col_mapping()
.composite(&join_with_pk.i2o_col_mapping());
let left_right_stream_keys = join_with_pk
Expand Down
57 changes: 32 additions & 25 deletions src/frontend/src/optimizer/plan_node/stream_delta_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use risingwave_pb::plan_common::JoinType;
use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::{ArrangementInfo, DeltaIndexJoinNode};

use super::generic::GenericPlanRef;
use super::{ExprRewritable, LogicalJoin, PlanBase, PlanRef, PlanTreeNodeBinary, StreamNode};
use super::generic::{self, GenericPlanRef};
use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamNode};
use crate::expr::{Expr, ExprRewriter};
use crate::optimizer::plan_node::stream::StreamPlanRef;
use crate::optimizer::plan_node::utils::IndicesDisplay;
Expand All @@ -35,19 +35,23 @@ use crate::utils::ColIndexMappingRewriteExt;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StreamDeltaJoin {
pub base: PlanBase,
logical: LogicalJoin,
logical: generic::Join<PlanRef>,

/// The join condition must be equivalent to `logical.on`, but separated into equal and
/// non-equal parts to facilitate execution later
eq_join_predicate: EqJoinPredicate,
}

impl StreamDeltaJoin {
pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self {
let ctx = logical.base.ctx.clone();
pub fn new(
base: PlanBase,
logical: generic::Join<PlanRef>,
eq_join_predicate: EqJoinPredicate,
) -> Self {
let ctx = base.ctx.clone();
// Inner join won't change the append-only behavior of the stream. The rest might.
let append_only = match logical.join_type() {
JoinType::Inner => logical.left().append_only() && logical.right().append_only(),
let append_only = match logical.join_type {
JoinType::Inner => logical.left.append_only() && logical.right.append_only(),
_ => todo!("delta join only supports inner join for now"),
};
if eq_join_predicate.has_non_eq() {
Expand All @@ -60,19 +64,19 @@ impl StreamDeltaJoin {
let watermark_columns = {
let from_left = logical
.l2i_col_mapping()
.rewrite_bitset(logical.left().watermark_columns());
.rewrite_bitset(logical.left.watermark_columns());
let from_right = logical
.r2i_col_mapping()
.rewrite_bitset(logical.right().watermark_columns());
.rewrite_bitset(logical.right.watermark_columns());
let watermark_columns = from_left.bitand(&from_right);
logical.i2o_col_mapping().rewrite_bitset(&watermark_columns)
};
// TODO: derive from input
let base = PlanBase::new_stream(
ctx,
logical.schema().clone(),
logical.base.logical_pk.to_vec(),
logical.functional_dependency().clone(),
base.schema().clone(),
base.logical_pk.to_vec(),
base.functional_dependency().clone(),
dist,
append_only,
watermark_columns,
Expand All @@ -95,7 +99,7 @@ impl fmt::Display for StreamDeltaJoin {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let verbose = self.base.ctx.is_explain_verbose();
let mut builder = f.debug_struct("StreamDeltaJoin");
builder.field("type", &self.logical.join_type());
builder.field("type", &self.logical.join_type);

let mut concat_schema = self.left().schema().fields.clone();
concat_schema.extend(self.right().schema().fields.clone());
Expand All @@ -111,7 +115,7 @@ impl fmt::Display for StreamDeltaJoin {
if verbose {
if self
.logical
.output_indices()
.output_indices
.iter()
.copied()
.eq(0..self.logical.internal_column_num())
Expand All @@ -121,7 +125,7 @@ impl fmt::Display for StreamDeltaJoin {
builder.field(
"output",
&IndicesDisplay {
indices: self.logical.output_indices(),
indices: &self.logical.output_indices,
input_schema: &concat_schema,
},
);
Expand All @@ -134,16 +138,20 @@ impl fmt::Display for StreamDeltaJoin {

impl PlanTreeNodeBinary for StreamDeltaJoin {
fn left(&self) -> PlanRef {
self.logical.left()
self.logical.left.clone()
}

fn right(&self) -> PlanRef {
self.logical.right()
self.logical.right.clone()
}

fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
let mut logical = self.logical.clone();
logical.left = left;
logical.right = right;
Self::new(
self.logical.clone_with_left_right(left, right),
self.base.clone_with_new_plan_id(),
logical,
self.eq_join_predicate.clone(),
)
}
Expand Down Expand Up @@ -172,7 +180,7 @@ impl StreamNode for StreamDeltaJoin {
// TODO: add a separate delta join node in proto, or move fragmenter to frontend so that we
// don't need an intermediate representation.
NodeBody::DeltaIndexJoin(DeltaIndexJoinNode {
join_type: self.logical.join_type() as i32,
join_type: self.logical.join_type as i32,
left_key: self
.eq_join_predicate
.left_eq_indexes()
Expand Down Expand Up @@ -216,7 +224,7 @@ impl StreamNode for StreamDeltaJoin {
}),
output_indices: self
.logical
.output_indices()
.output_indices
.iter()
.map(|&x| x as u32)
.collect(),
Expand All @@ -230,12 +238,11 @@ impl ExprRewritable for StreamDeltaJoin {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
let mut logical = self.logical.clone();
logical.rewrite_exprs(r);
Self::new(
self.logical
.rewrite_exprs(r)
.as_logical_join()
.unwrap()
.clone(),
self.base.clone_with_new_plan_id(),
logical,
self.eq_join_predicate.rewrite_exprs(r),
)
.into()
Expand Down
Loading