From fc473387a25a61147457fa733b86707dcd167739 Mon Sep 17 00:00:00 2001 From: Arenatlx <314806019@qq.com> Date: Fri, 19 Jul 2024 16:27:08 +0800 Subject: [PATCH] planner: classify logical join's logic into a seperate file for later pkg move. (#54741) ref pingcap/tidb#51664, ref pingcap/tidb#52714 --- pkg/planner/core/exhaust_physical_plans.go | 102 +- pkg/planner/core/explain.go | 22 - pkg/planner/core/logical_initialize.go | 6 - pkg/planner/core/logical_join.go | 944 +++++++++++++++++- pkg/planner/core/property_cols_prune.go | 25 - pkg/planner/core/rule_build_key_info.go | 46 - pkg/planner/core/rule_column_pruning.go | 58 -- pkg/planner/core/rule_constant_propagation.go | 118 --- pkg/planner/core/rule_decorrelate.go | 4 +- pkg/planner/core/rule_eliminate_projection.go | 16 - pkg/planner/core/rule_outer_to_inner_join.go | 55 - pkg/planner/core/rule_predicate_push_down.go | 239 ----- pkg/planner/core/rule_topn_push_down.go | 53 - pkg/planner/core/stats.go | 103 -- 14 files changed, 918 insertions(+), 873 deletions(-) diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index aae131ebfc9f0..17edbe3cac1db 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -2363,104 +2363,6 @@ func preferMppBCJ(p *LogicalJoin) bool { return checkChildFitBC(p.Children()[0]) || checkChildFitBC(p.Children()[1]) } -// ExhaustPhysicalPlans implements LogicalPlan interface -// it can generates hash join, index join and sort merge join. -// Firstly we check the hint, if hint is figured by user, we force to choose the corresponding physical plan. -// If the hint is not matched, it will get other candidates. -// If the hint is not figured, we will pick all candidates. -func (p *LogicalJoin) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { - if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { - indexJoins, _ := tryToGetIndexJoin(p, prop) - failpoint.Return(indexJoins, true, nil) - } - }) - - if !isJoinHintSupportedInMPPMode(p.PreferJoinType) { - if hasMPPJoinHints(p.PreferJoinType) { - // If there are MPP hints but has some conflicts join method hints, all the join hints are invalid. - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("The MPP join hints are in conflict, and you can only specify join method hints that are currently supported by MPP mode now") - p.PreferJoinType = 0 - } else { - // If there are no MPP hints but has some conflicts join method hints, the MPP mode will be blocked. - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because you have used hint to specify a join algorithm which is not supported by mpp now.") - if prop.IsFlashProp() { - return nil, false, nil - } - } - } - if prop.MPPPartitionTp == property.BroadcastType { - return nil, false, nil - } - joins := make([]base.PhysicalPlan, 0, 8) - canPushToTiFlash := p.CanPushToCop(kv.TiFlash) - if p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { - if (p.PreferJoinType & h.PreferShuffleJoin) > 0 { - if shuffleJoins := tryToGetMppHashJoin(p, prop, false); len(shuffleJoins) > 0 { - return shuffleJoins, true, nil - } - } - if (p.PreferJoinType & h.PreferBCJoin) > 0 { - if bcastJoins := tryToGetMppHashJoin(p, prop, true); len(bcastJoins) > 0 { - return bcastJoins, true, nil - } - } - if preferMppBCJ(p) { - mppJoins := tryToGetMppHashJoin(p, prop, true) - joins = append(joins, mppJoins...) - } else { - mppJoins := tryToGetMppHashJoin(p, prop, false) - joins = append(joins, mppJoins...) - } - } else { - hasMppHints := false - var errMsg string - if (p.PreferJoinType & h.PreferShuffleJoin) > 0 { - errMsg = "The join can not push down to the MPP side, the shuffle_join() hint is invalid" - hasMppHints = true - } - if (p.PreferJoinType & h.PreferBCJoin) > 0 { - errMsg = "The join can not push down to the MPP side, the broadcast_join() hint is invalid" - hasMppHints = true - } - if hasMppHints { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) - } - } - if prop.IsFlashProp() { - return joins, true, nil - } - - if !p.isNAAJ() { - // naaj refuse merge join and index join. - mergeJoins := GetMergeJoin(p, prop, p.Schema(), p.StatsInfo(), p.Children()[0].StatsInfo(), p.Children()[1].StatsInfo()) - if (p.PreferJoinType&h.PreferMergeJoin) > 0 && len(mergeJoins) > 0 { - return mergeJoins, true, nil - } - joins = append(joins, mergeJoins...) - - indexJoins, forced := tryToGetIndexJoin(p, prop) - if forced { - return indexJoins, true, nil - } - joins = append(joins, indexJoins...) - } - - hashJoins, forced := getHashJoins(p, prop) - if forced && len(hashJoins) > 0 { - return hashJoins, true, nil - } - joins = append(joins, hashJoins...) - - if p.PreferJoinType > 0 { - // If we reach here, it means we have a hint that doesn't work. - // It might be affected by the required property, so we enforce - // this property and try the hint again. - return joins, false, nil - } - return joins, true, nil -} - func canExprsInJoinPushdown(p *LogicalJoin, storeType kv.StoreType) bool { equalExprs := make([]expression.Expression, 0, len(p.EqualConditions)) for _, eqCondition := range p.EqualConditions { @@ -2558,7 +2460,7 @@ func tryToGetMppHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, useBCJ preferredBuildIndex = 1 } } else if p.JoinType.IsSemiJoin() { - if !useBCJ && !p.isNAAJ() && len(p.EqualConditions) > 0 && (p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin) { + if !useBCJ && !p.IsNAAJ() && len(p.EqualConditions) > 0 && (p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin) { // TiFlash only supports Non-null_aware non-cross semi/anti_semi join to use both sides as build side preferredBuildIndex = 1 // MPPOuterJoinFixedBuildSide default value is false @@ -2577,7 +2479,7 @@ func tryToGetMppHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, useBCJ // 1. it is a broadcast join(for broadcast join, it makes sense to use the broadcast side as the build side) // 2. or session variable MPPOuterJoinFixedBuildSide is set to true // 3. or nullAware/cross joins - if useBCJ || p.isNAAJ() || len(p.EqualConditions) == 0 || p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide { + if useBCJ || p.IsNAAJ() || len(p.EqualConditions) == 0 || p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide { if !p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide { // The hint has higher priority than variable. fixedBuildSide = true diff --git a/pkg/planner/core/explain.go b/pkg/planner/core/explain.go index 7a99f8680cbf7..600eed89ce436 100644 --- a/pkg/planner/core/explain.go +++ b/pkg/planner/core/explain.go @@ -880,28 +880,6 @@ func formatWindowFuncDescs(ctx expression.EvalContext, buffer *bytes.Buffer, des return buffer } -// ExplainInfo implements Plan interface. -func (p *LogicalJoin) ExplainInfo() string { - evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() - buffer := bytes.NewBufferString(p.JoinType.String()) - if len(p.EqualConditions) > 0 { - fmt.Fprintf(buffer, ", equal:%v", p.EqualConditions) - } - if len(p.LeftConditions) > 0 { - fmt.Fprintf(buffer, ", left cond:%s", - expression.SortedExplainExpressionList(evalCtx, p.LeftConditions)) - } - if len(p.RightConditions) > 0 { - fmt.Fprintf(buffer, ", right cond:%s", - expression.SortedExplainExpressionList(evalCtx, p.RightConditions)) - } - if len(p.OtherConditions) > 0 { - fmt.Fprintf(buffer, ", other cond:%s", - expression.SortedExplainExpressionList(evalCtx, p.OtherConditions)) - } - return buffer.String() -} - // ExplainInfo implements Plan interface. func (p *LogicalApply) ExplainInfo() string { return p.LogicalJoin.ExplainInfo() diff --git a/pkg/planner/core/logical_initialize.go b/pkg/planner/core/logical_initialize.go index 414081498f817..53d464450ee35 100644 --- a/pkg/planner/core/logical_initialize.go +++ b/pkg/planner/core/logical_initialize.go @@ -20,12 +20,6 @@ import ( "github.com/pingcap/tidb/pkg/util/plancodec" ) -// Init initializes LogicalJoin. -func (p LogicalJoin) Init(ctx base.PlanContext, offset int) *LogicalJoin { - p.BaseLogicalPlan = logicalop.NewBaseLogicalPlan(ctx, plancodec.TypeJoin, &p, offset) - return &p -} - // Init initializes DataSource. func (ds DataSource) Init(ctx base.PlanContext, offset int) *DataSource { ds.BaseLogicalPlan = logicalop.NewBaseLogicalPlan(ctx, plancodec.TypeDataSource, &ds, offset) diff --git a/pkg/planner/core/logical_join.go b/pkg/planner/core/logical_join.go index 2919fe1094309..050b3a250ca7c 100644 --- a/pkg/planner/core/logical_join.go +++ b/pkg/planner/core/logical_join.go @@ -15,16 +15,28 @@ package core import ( + "bytes" + "fmt" + "math" + + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cardinality" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/cost" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" "github.com/pingcap/tidb/pkg/planner/funcdep" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" + "github.com/pingcap/tidb/pkg/planner/util/utilfuncp" "github.com/pingcap/tidb/pkg/types" utilhint "github.com/pingcap/tidb/pkg/util/hint" "github.com/pingcap/tidb/pkg/util/intset" + "github.com/pingcap/tidb/pkg/util/plancodec" ) // JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, SemiJoin, AntiJoin. @@ -132,17 +144,586 @@ type LogicalJoin struct { EqualCondOutCnt float64 } -func (p *LogicalJoin) isNAAJ() bool { - return len(p.NAEQConditions) > 0 +// Init initializes LogicalJoin. +func (p LogicalJoin) Init(ctx base.PlanContext, offset int) *LogicalJoin { + p.BaseLogicalPlan = logicalop.NewBaseLogicalPlan(ctx, plancodec.TypeJoin, &p, offset) + return &p } -// Shallow shallow copies a LogicalJoin struct. -func (p *LogicalJoin) Shallow() *LogicalJoin { - join := *p - return join.Init(p.SCtx(), p.QueryBlockOffset()) +// *************************** start implementation of Plan interface *************************** + +// ExplainInfo implements Plan interface. +func (p *LogicalJoin) ExplainInfo() string { + evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() + buffer := bytes.NewBufferString(p.JoinType.String()) + if len(p.EqualConditions) > 0 { + fmt.Fprintf(buffer, ", equal:%v", p.EqualConditions) + } + if len(p.LeftConditions) > 0 { + fmt.Fprintf(buffer, ", left cond:%s", + expression.SortedExplainExpressionList(evalCtx, p.LeftConditions)) + } + if len(p.RightConditions) > 0 { + fmt.Fprintf(buffer, ", right cond:%s", + expression.SortedExplainExpressionList(evalCtx, p.RightConditions)) + } + if len(p.OtherConditions) > 0 { + fmt.Fprintf(buffer, ", other cond:%s", + expression.SortedExplainExpressionList(evalCtx, p.OtherConditions)) + } + return buffer.String() +} + +// ReplaceExprColumns implements base.LogicalPlan interface. +func (p *LogicalJoin) ReplaceExprColumns(replace map[string]*expression.Column) { + for _, equalExpr := range p.EqualConditions { + ResolveExprAndReplace(equalExpr, replace) + } + for _, leftExpr := range p.LeftConditions { + ResolveExprAndReplace(leftExpr, replace) + } + for _, rightExpr := range p.RightConditions { + ResolveExprAndReplace(rightExpr, replace) + } + for _, otherExpr := range p.OtherConditions { + ResolveExprAndReplace(otherExpr, replace) + } +} + +// *************************** end implementation of Plan interface *************************** + +// *************************** start implementation of logicalPlan interface *************************** + +// HashCode inherits the BaseLogicalPlan.LogicalPlan.<0th> implementation. + +// PredicatePushDown implements the base.LogicalPlan.<1st> interface. +func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (ret []expression.Expression, retPlan base.LogicalPlan) { + var equalCond []*expression.ScalarFunction + var leftPushCond, rightPushCond, otherCond, leftCond, rightCond []expression.Expression + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + predicates = p.outerJoinPropConst(predicates) + dual := Conds2TableDual(p, predicates) + if dual != nil { + appendTableDualTraceStep(p, dual, predicates, opt) + return ret, dual + } + // Handle where conditions + predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) + // Only derive left where condition, because right where condition cannot be pushed down + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true, false) + leftCond = leftPushCond + // Handle join conditions, only derive right join condition, because left join condition cannot be pushed down + _, derivedRightJoinCond := DeriveOtherConditions( + p, p.Children()[0].Schema(), p.Children()[1].Schema(), false, true) + rightCond = append(p.RightConditions, derivedRightJoinCond...) + p.RightConditions = nil + ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) + ret = append(ret, rightPushCond...) + case RightOuterJoin: + predicates = p.outerJoinPropConst(predicates) + dual := Conds2TableDual(p, predicates) + if dual != nil { + appendTableDualTraceStep(p, dual, predicates, opt) + return ret, dual + } + // Handle where conditions + predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) + // Only derive right where condition, because left where condition cannot be pushed down + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, false, true) + rightCond = rightPushCond + // Handle join conditions, only derive left join condition, because right join condition cannot be pushed down + derivedLeftJoinCond, _ := DeriveOtherConditions( + p, p.Children()[0].Schema(), p.Children()[1].Schema(), true, false) + leftCond = append(p.LeftConditions, derivedLeftJoinCond...) + p.LeftConditions = nil + ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) + ret = append(ret, leftPushCond...) + case SemiJoin, InnerJoin: + tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates)) + tempCond = append(tempCond, p.LeftConditions...) + tempCond = append(tempCond, p.RightConditions...) + tempCond = append(tempCond, expression.ScalarFuncs2Exprs(p.EqualConditions)...) + tempCond = append(tempCond, p.OtherConditions...) + tempCond = append(tempCond, predicates...) + tempCond = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), tempCond) + tempCond = expression.PropagateConstant(p.SCtx().GetExprCtx(), tempCond) + // Return table dual when filter is constant false or null. + dual := Conds2TableDual(p, tempCond) + if dual != nil { + appendTableDualTraceStep(p, dual, tempCond, opt) + return ret, dual + } + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(tempCond, true, true) + p.LeftConditions = nil + p.RightConditions = nil + p.EqualConditions = equalCond + p.OtherConditions = otherCond + leftCond = leftPushCond + rightCond = rightPushCond + case AntiSemiJoin: + predicates = expression.PropagateConstant(p.SCtx().GetExprCtx(), predicates) + // Return table dual when filter is constant false or null. + dual := Conds2TableDual(p, predicates) + if dual != nil { + appendTableDualTraceStep(p, dual, predicates, opt) + return ret, dual + } + // `predicates` should only contain left conditions or constant filters. + _, leftPushCond, rightPushCond, _ = p.extractOnCondition(predicates, true, true) + // Do not derive `is not null` for anti join, since it may cause wrong results. + // For example: + // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, + // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b` does not imply `t1.b is not null`, + // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`, + leftCond = leftPushCond + rightCond = append(p.RightConditions, rightPushCond...) + p.RightConditions = nil + } + leftCond = expression.RemoveDupExprs(leftCond) + rightCond = expression.RemoveDupExprs(rightCond) + leftRet, lCh := p.Children()[0].PredicatePushDown(leftCond, opt) + rightRet, rCh := p.Children()[1].PredicatePushDown(rightCond, opt) + utilfuncp.AddSelection(p, lCh, leftRet, 0, opt) + utilfuncp.AddSelection(p, rCh, rightRet, 1, opt) + p.updateEQCond() + buildKeyInfo(p) + return ret, p.Self() +} + +// PruneColumns implements the base.LogicalPlan.<2nd> interface. +func (p *LogicalJoin) PruneColumns(parentUsedCols []*expression.Column, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, error) { + leftCols, rightCols := p.extractUsedCols(parentUsedCols) + + var err error + p.Children()[0], err = p.Children()[0].PruneColumns(leftCols, opt) + if err != nil { + return nil, err + } + addConstOneForEmptyProjection(p.Children()[0]) + + p.Children()[1], err = p.Children()[1].PruneColumns(rightCols, opt) + if err != nil { + return nil, err + } + addConstOneForEmptyProjection(p.Children()[1]) + + p.mergeSchema() + if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + joinCol := p.Schema().Columns[len(p.Schema().Columns)-1] + parentUsedCols = append(parentUsedCols, joinCol) + } + p.InlineProjection(parentUsedCols, opt) + return p, nil +} + +// FindBestTask inherits the BaseLogicalPlan.LogicalPlan.<3rd> implementation. + +// BuildKeyInfo implements the base.LogicalPlan.<4th> interface. +func (p *LogicalJoin) BuildKeyInfo(selfSchema *expression.Schema, childSchema []*expression.Schema) { + p.LogicalSchemaProducer.BuildKeyInfo(selfSchema, childSchema) + switch p.JoinType { + case SemiJoin, LeftOuterSemiJoin, AntiSemiJoin, AntiLeftOuterSemiJoin: + selfSchema.Keys = childSchema[0].Clone().Keys + case InnerJoin, LeftOuterJoin, RightOuterJoin: + // If there is no equal conditions, then cartesian product can't be prevented and unique key information will destroy. + if len(p.EqualConditions) == 0 { + return + } + lOk := false + rOk := false + // Such as 'select * from t1 join t2 where t1.a = t2.a and t1.b = t2.b'. + // If one sides (a, b) is a unique key, then the unique key information is remained. + // But we don't consider this situation currently. + // Only key made by one column is considered now. + evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() + for _, expr := range p.EqualConditions { + ln := expr.GetArgs()[0].(*expression.Column) + rn := expr.GetArgs()[1].(*expression.Column) + for _, key := range childSchema[0].Keys { + if len(key) == 1 && key[0].Equal(evalCtx, ln) { + lOk = true + break + } + } + for _, key := range childSchema[1].Keys { + if len(key) == 1 && key[0].Equal(evalCtx, rn) { + rOk = true + break + } + } + } + // For inner join, if one side of one equal condition is unique key, + // another side's unique key information will all be reserved. + // If it's an outer join, NULL value will fill some position, which will destroy the unique key information. + if lOk && p.JoinType != LeftOuterJoin { + selfSchema.Keys = append(selfSchema.Keys, childSchema[1].Keys...) + } + if rOk && p.JoinType != RightOuterJoin { + selfSchema.Keys = append(selfSchema.Keys, childSchema[0].Keys...) + } + } +} + +// PushDownTopN implements the base.LogicalPlan.<5th> interface. +func (p *LogicalJoin) PushDownTopN(topNLogicalPlan base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { + var topN *LogicalTopN + if topNLogicalPlan != nil { + topN = topNLogicalPlan.(*LogicalTopN) + } + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + p.Children()[0] = p.pushDownTopNToChild(topN, 0, opt) + p.Children()[1] = p.Children()[1].PushDownTopN(nil, opt) + case RightOuterJoin: + p.Children()[1] = p.pushDownTopNToChild(topN, 1, opt) + p.Children()[0] = p.Children()[0].PushDownTopN(nil, opt) + default: + return p.BaseLogicalPlan.PushDownTopN(topN, opt) + } + + // The LogicalJoin may be also a LogicalApply. So we must use self to set parents. + if topN != nil { + return topN.AttachChild(p.Self(), opt) + } + return p.Self() +} + +// DeriveTopN inherits the BaseLogicalPlan.LogicalPlan.<6th> implementation. + +// PredicateSimplification inherits the BaseLogicalPlan.LogicalPlan.<7th> implementation. + +// ConstantPropagation implements the base.LogicalPlan.<8th> interface. +// about the logic of constant propagation in From List. +// Query: select * from t, (select a, b from s where s.a>1) tmp where tmp.a=t.a +// Origin logical plan: +/* + +----------------+ + | LogicalJoin | + +-------^--------+ + | + +-------------+--------------+ + | | ++-----+------+ +------+------+ +| Projection | | TableScan | ++-----^------+ +-------------+ + | + | ++-----+------+ +| Selection | +| s.a>1 | ++------------+ +*/ +// 1. 'PullUpConstantPredicates': Call this function until find selection and pull up the constant predicate layer by layer +// LogicalSelection: find the s.a>1 +// LogicalProjection: get the s.a>1 and pull up it, changed to tmp.a>1 +// 2. 'addCandidateSelection': Add selection above of LogicalJoin, +// put all predicates pulled up from the lower layer into the current new selection. +// LogicalSelection: tmp.a >1 +// +// Optimized plan: +/* + +----------------+ + | Selection | + | tmp.a>1 | + +-------^--------+ + | + +-------+--------+ + | LogicalJoin | + +-------^--------+ + | + +-------------+--------------+ + | | ++-----+------+ +------+------+ +| Projection | | TableScan | ++-----^------+ +-------------+ + | + | ++-----+------+ +| Selection | +| s.a>1 | ++------------+ +*/ +// Return nil if the root of plan has not been changed +// Return new root if the root of plan is changed to selection +func (p *LogicalJoin) ConstantPropagation(parentPlan base.LogicalPlan, currentChildIdx int, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { + // step1: get constant predicate from left or right according to the JoinType + var getConstantPredicateFromLeft bool + var getConstantPredicateFromRight bool + switch p.JoinType { + case LeftOuterJoin: + getConstantPredicateFromLeft = true + case RightOuterJoin: + getConstantPredicateFromRight = true + case InnerJoin: + getConstantPredicateFromLeft = true + getConstantPredicateFromRight = true + default: + return + } + var candidateConstantPredicates []expression.Expression + if getConstantPredicateFromLeft { + candidateConstantPredicates = p.Children()[0].PullUpConstantPredicates() + } + if getConstantPredicateFromRight { + candidateConstantPredicates = append(candidateConstantPredicates, p.Children()[1].PullUpConstantPredicates()...) + } + if len(candidateConstantPredicates) == 0 { + return + } + + // step2: add selection above of LogicalJoin + return addCandidateSelection(p, currentChildIdx, parentPlan, candidateConstantPredicates, opt) +} + +// PullUpConstantPredicates inherits the BaseLogicalPlan.LogicalPlan.<9th> implementation. + +// RecursiveDeriveStats inherits the BaseLogicalPlan.LogicalPlan.<10th> implementation. + +// DeriveStats implements the base.LogicalPlan.<11th> interface. +// If the type of join is SemiJoin, the selectivity of it will be same as selection's. +// If the type of join is LeftOuterSemiJoin, it will not add or remove any row. The last column is a boolean value, whose NDV should be two. +// If the type of join is inner/outer join, the output of join(s, t) should be N(s) * N(t) / (V(s.key) * V(t.key)) * Min(s.key, t.key). +// N(s) stands for the number of rows in relation s. V(s.key) means the NDV of join key in s. +// This is a quite simple strategy: We assume every bucket of relation which will participate join has the same number of rows, and apply cross join for +// every matched bucket. +func (p *LogicalJoin) DeriveStats(childStats []*property.StatsInfo, selfSchema *expression.Schema, childSchema []*expression.Schema, colGroups [][]*expression.Column) (*property.StatsInfo, error) { + if p.StatsInfo() != nil { + // Reload GroupNDVs since colGroups may have changed. + p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) + return p.StatsInfo(), nil + } + leftProfile, rightProfile := childStats[0], childStats[1] + leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() + p.EqualCondOutCnt = cardinality.EstimateFullJoinRowCount(p.SCtx(), + 0 == len(p.EqualConditions), + leftProfile, rightProfile, + leftJoinKeys, rightJoinKeys, + childSchema[0], childSchema[1], + nil, nil) + if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin { + p.SetStats(&property.StatsInfo{ + RowCount: leftProfile.RowCount * cost.SelectionFactor, + ColNDVs: make(map[int64]float64, len(leftProfile.ColNDVs)), + }) + for id, c := range leftProfile.ColNDVs { + p.StatsInfo().ColNDVs[id] = c * cost.SelectionFactor + } + return p.StatsInfo(), nil + } + if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + p.SetStats(&property.StatsInfo{ + RowCount: leftProfile.RowCount, + ColNDVs: make(map[int64]float64, selfSchema.Len()), + }) + for id, c := range leftProfile.ColNDVs { + p.StatsInfo().ColNDVs[id] = c + } + p.StatsInfo().ColNDVs[selfSchema.Columns[selfSchema.Len()-1].UniqueID] = 2.0 + p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) + return p.StatsInfo(), nil + } + count := p.EqualCondOutCnt + if p.JoinType == LeftOuterJoin { + count = math.Max(count, leftProfile.RowCount) + } else if p.JoinType == RightOuterJoin { + count = math.Max(count, rightProfile.RowCount) + } + colNDVs := make(map[int64]float64, selfSchema.Len()) + for id, c := range leftProfile.ColNDVs { + colNDVs[id] = math.Min(c, count) + } + for id, c := range rightProfile.ColNDVs { + colNDVs[id] = math.Min(c, count) + } + p.SetStats(&property.StatsInfo{ + RowCount: count, + ColNDVs: colNDVs, + }) + p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) + return p.StatsInfo(), nil +} + +// ExtractColGroups implements the base.LogicalPlan.<12th> interface. +func (p *LogicalJoin) ExtractColGroups(colGroups [][]*expression.Column) [][]*expression.Column { + leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() + extracted := make([][]*expression.Column, 0, 2+len(colGroups)) + if len(leftJoinKeys) > 1 && (p.JoinType == InnerJoin || p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin) { + extracted = append(extracted, expression.SortColumns(leftJoinKeys), expression.SortColumns(rightJoinKeys)) + } + var outerSchema *expression.Schema + if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + outerSchema = p.Children()[0].Schema() + } else if p.JoinType == RightOuterJoin { + outerSchema = p.Children()[1].Schema() + } + if len(colGroups) == 0 || outerSchema == nil { + return extracted + } + _, offsets := outerSchema.ExtractColGroups(colGroups) + if len(offsets) == 0 { + return extracted + } + for _, offset := range offsets { + extracted = append(extracted, colGroups[offset]) + } + return extracted +} + +// PreparePossibleProperties implements base.LogicalPlan.<13th> interface. +func (p *LogicalJoin) PreparePossibleProperties(_ *expression.Schema, childrenProperties ...[][]*expression.Column) [][]*expression.Column { + leftProperties := childrenProperties[0] + rightProperties := childrenProperties[1] + // TODO: We should consider properties propagation. + p.LeftProperties = leftProperties + p.RightProperties = rightProperties + if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin { + rightProperties = nil + } else if p.JoinType == RightOuterJoin { + leftProperties = nil + } + resultProperties := make([][]*expression.Column, len(leftProperties)+len(rightProperties)) + for i, cols := range leftProperties { + resultProperties[i] = make([]*expression.Column, len(cols)) + copy(resultProperties[i], cols) + } + leftLen := len(leftProperties) + for i, cols := range rightProperties { + resultProperties[leftLen+i] = make([]*expression.Column, len(cols)) + copy(resultProperties[leftLen+i], cols) + } + return resultProperties +} + +// ExhaustPhysicalPlans implements the base.LogicalPlan.<14th> interface. +// it can generates hash join, index join and sort merge join. +// Firstly we check the hint, if hint is figured by user, we force to choose the corresponding physical plan. +// If the hint is not matched, it will get other candidates. +// If the hint is not figured, we will pick all candidates. +func (p *LogicalJoin) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { + if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { + indexJoins, _ := tryToGetIndexJoin(p, prop) + failpoint.Return(indexJoins, true, nil) + } + }) + + if !isJoinHintSupportedInMPPMode(p.PreferJoinType) { + if hasMPPJoinHints(p.PreferJoinType) { + // If there are MPP hints but has some conflicts join method hints, all the join hints are invalid. + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("The MPP join hints are in conflict, and you can only specify join method hints that are currently supported by MPP mode now") + p.PreferJoinType = 0 + } else { + // If there are no MPP hints but has some conflicts join method hints, the MPP mode will be blocked. + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because you have used hint to specify a join algorithm which is not supported by mpp now.") + if prop.IsFlashProp() { + return nil, false, nil + } + } + } + if prop.MPPPartitionTp == property.BroadcastType { + return nil, false, nil + } + joins := make([]base.PhysicalPlan, 0, 8) + canPushToTiFlash := p.CanPushToCop(kv.TiFlash) + if p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { + if (p.PreferJoinType & utilhint.PreferShuffleJoin) > 0 { + if shuffleJoins := tryToGetMppHashJoin(p, prop, false); len(shuffleJoins) > 0 { + return shuffleJoins, true, nil + } + } + if (p.PreferJoinType & utilhint.PreferBCJoin) > 0 { + if bcastJoins := tryToGetMppHashJoin(p, prop, true); len(bcastJoins) > 0 { + return bcastJoins, true, nil + } + } + if preferMppBCJ(p) { + mppJoins := tryToGetMppHashJoin(p, prop, true) + joins = append(joins, mppJoins...) + } else { + mppJoins := tryToGetMppHashJoin(p, prop, false) + joins = append(joins, mppJoins...) + } + } else { + hasMppHints := false + var errMsg string + if (p.PreferJoinType & utilhint.PreferShuffleJoin) > 0 { + errMsg = "The join can not push down to the MPP side, the shuffle_join() hint is invalid" + hasMppHints = true + } + if (p.PreferJoinType & utilhint.PreferBCJoin) > 0 { + errMsg = "The join can not push down to the MPP side, the broadcast_join() hint is invalid" + hasMppHints = true + } + if hasMppHints { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) + } + } + if prop.IsFlashProp() { + return joins, true, nil + } + + if !p.IsNAAJ() { + // naaj refuse merge join and index join. + mergeJoins := GetMergeJoin(p, prop, p.Schema(), p.StatsInfo(), p.Children()[0].StatsInfo(), p.Children()[1].StatsInfo()) + if (p.PreferJoinType&utilhint.PreferMergeJoin) > 0 && len(mergeJoins) > 0 { + return mergeJoins, true, nil + } + joins = append(joins, mergeJoins...) + + indexJoins, forced := tryToGetIndexJoin(p, prop) + if forced { + return indexJoins, true, nil + } + joins = append(joins, indexJoins...) + } + + hashJoins, forced := getHashJoins(p, prop) + if forced && len(hashJoins) > 0 { + return hashJoins, true, nil + } + joins = append(joins, hashJoins...) + + if p.PreferJoinType > 0 { + // If we reach here, it means we have a hint that doesn't work. + // It might be affected by the required property, so we enforce + // this property and try the hint again. + return joins, false, nil + } + return joins, true, nil +} + +// ExtractCorrelatedCols implements the base.LogicalPlan.<15th> interface. +func (p *LogicalJoin) ExtractCorrelatedCols() []*expression.CorrelatedColumn { + corCols := make([]*expression.CorrelatedColumn, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) + for _, fun := range p.EqualConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + for _, fun := range p.LeftConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + for _, fun := range p.RightConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + for _, fun := range p.OtherConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + return corCols } -// ExtractFD implements the interface LogicalPlan. +// MaxOneRow inherits the BaseLogicalPlan.LogicalPlan.<16th> implementation. + +// Children inherits the BaseLogicalPlan.LogicalPlan.<17th> implementation. + +// SetChildren inherits the BaseLogicalPlan.LogicalPlan.<18th> implementation. + +// SetChild inherits the BaseLogicalPlan.LogicalPlan.<19th> implementation. + +// RollBackTaskMap inherits the BaseLogicalPlan.LogicalPlan.<20th> implementation. + +// CanPushToCop inherits the BaseLogicalPlan.LogicalPlan.<21st> implementation. + +// ExtractFD implements the base.LogicalPlan.<22th> interface. func (p *LogicalJoin) ExtractFD() *funcdep.FDSet { switch p.JoinType { case InnerJoin: @@ -156,6 +737,75 @@ func (p *LogicalJoin) ExtractFD() *funcdep.FDSet { } } +// GetBaseLogicalPlan inherits the BaseLogicalPlan.LogicalPlan.<23th> implementation. + +// ConvertOuterToInnerJoin implements base.LogicalPlan.<24th> interface. +func (p *LogicalJoin) ConvertOuterToInnerJoin(predicates []expression.Expression) base.LogicalPlan { + innerTable := p.Children()[0] + outerTable := p.Children()[1] + switchChild := false + + if p.JoinType == LeftOuterJoin { + innerTable, outerTable = outerTable, innerTable + switchChild = true + } + + // First, simplify this join + if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { + canBeSimplified := false + for _, expr := range predicates { + isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr) + if isOk { + canBeSimplified = true + break + } + } + if canBeSimplified { + p.JoinType = InnerJoin + } + } + + // Next simplify join children + + combinedCond := mergeOnClausePredicates(p, predicates) + if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { + innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) + outerTable = outerTable.ConvertOuterToInnerJoin(predicates) + } else if p.JoinType == InnerJoin || p.JoinType == SemiJoin { + innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) + outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) + } else if p.JoinType == AntiSemiJoin { + innerTable = innerTable.ConvertOuterToInnerJoin(predicates) + outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) + } else { + innerTable = innerTable.ConvertOuterToInnerJoin(predicates) + outerTable = outerTable.ConvertOuterToInnerJoin(predicates) + } + + if switchChild { + p.SetChild(0, outerTable) + p.SetChild(1, innerTable) + } else { + p.SetChild(0, innerTable) + p.SetChild(1, outerTable) + } + + return p +} + +// *************************** end implementation of logicalPlan interface *************************** + +// IsNAAJ checks if the join is a non-adjacent-join. +func (p *LogicalJoin) IsNAAJ() bool { + return len(p.NAEQConditions) > 0 +} + +// Shallow copies a LogicalJoin struct. +func (p *LogicalJoin) Shallow() *LogicalJoin { + join := *p + return join.Init(p.SCtx(), p.QueryBlockOffset()) +} + func (p *LogicalJoin) extractFDForSemiJoin(filtersFromApply []expression.Expression) *funcdep.FDSet { // 1: since semi join will keep the part or all rows of the outer table, it's outer FD can be saved. // 2: the un-projected column will be left for the upper layer projection or already be pruned from bottom up. @@ -344,8 +994,8 @@ func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*proper return } -// decorrelate eliminate the correlated column with if the col is in schema. -func (p *LogicalJoin) decorrelate(schema *expression.Schema) { +// Decorrelate eliminate the correlated column with if the col is in schema. +func (p *LogicalJoin) Decorrelate(schema *expression.Schema) { for i, cond := range p.LeftConditions { p.LeftConditions[i] = cond.Decorrelate(schema) } @@ -360,9 +1010,9 @@ func (p *LogicalJoin) decorrelate(schema *expression.Schema) { } } -// columnSubstituteAll is used in projection elimination in apply de-correlation. +// ColumnSubstituteAll is used in projection elimination in apply de-correlation. // Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged. -func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) { +func (p *LogicalJoin) ColumnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) { // make a copy of exprs for convenience of substitution (may change/partially change the expr tree) cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions)) cpRightConditions := make(expression.CNFExprs, len(p.RightConditions)) @@ -457,24 +1107,6 @@ func (p *LogicalJoin) AppendJoinConds(eq []*expression.ScalarFunction, left, rig p.OtherConditions = append(other, p.OtherConditions...) } -// ExtractCorrelatedCols implements LogicalPlan interface. -func (p *LogicalJoin) ExtractCorrelatedCols() []*expression.CorrelatedColumn { - corCols := make([]*expression.CorrelatedColumn, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) - for _, fun := range p.EqualConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - for _, fun := range p.LeftConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - for _, fun := range p.RightConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - for _, fun := range p.OtherConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - return corCols -} - // ExtractJoinKeys extract join keys as a schema for child with childIdx. func (p *LogicalJoin) ExtractJoinKeys(childIdx int) *expression.Schema { joinKeys := make([]*expression.Column, 0, len(p.EqualConditions)) @@ -484,7 +1116,117 @@ func (p *LogicalJoin) ExtractJoinKeys(childIdx int) *expression.Schema { return expression.NewSchema(joinKeys...) } -// PreferAny checks whether the join type prefers any of the join types specified in the joinFlags. +// extractUsedCols extracts all the needed columns. +func (p *LogicalJoin) extractUsedCols(parentUsedCols []*expression.Column) (leftCols []*expression.Column, rightCols []*expression.Column) { + for _, eqCond := range p.EqualConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(eqCond)...) + } + for _, leftCond := range p.LeftConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(leftCond)...) + } + for _, rightCond := range p.RightConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(rightCond)...) + } + for _, otherCond := range p.OtherConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(otherCond)...) + } + for _, naeqCond := range p.NAEQConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(naeqCond)...) + } + lChild := p.Children()[0] + rChild := p.Children()[1] + for _, col := range parentUsedCols { + if lChild.Schema().Contains(col) { + leftCols = append(leftCols, col) + } else if rChild.Schema().Contains(col) { + rightCols = append(rightCols, col) + } + } + return leftCols, rightCols +} + +// MergeSchema merge the schema of left and right child of join. +func (p *LogicalJoin) mergeSchema() { + p.SetSchema(buildLogicalJoinSchema(p.JoinType, p)) +} + +// pushDownTopNToChild will push a topN to one child of join. The idx stands for join child index. 0 is for left child. +func (p *LogicalJoin) pushDownTopNToChild(topN *LogicalTopN, idx int, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { + if topN == nil { + return p.Children()[idx].PushDownTopN(nil, opt) + } + + for _, by := range topN.ByItems { + cols := expression.ExtractColumns(by.Expr) + for _, col := range cols { + if !p.Children()[idx].Schema().Contains(col) { + return p.Children()[idx].PushDownTopN(nil, opt) + } + } + } + + newTopN := LogicalTopN{ + Count: topN.Count + topN.Offset, + ByItems: make([]*util.ByItems, len(topN.ByItems)), + PreferLimitToCop: topN.PreferLimitToCop, + }.Init(topN.SCtx(), topN.QueryBlockOffset()) + for i := range topN.ByItems { + newTopN.ByItems[i] = topN.ByItems[i].Clone() + } + appendTopNPushDownJoinTraceStep(p, newTopN, idx, opt) + return p.Children()[idx].PushDownTopN(newTopN, opt) +} + +// Add a new selection between parent plan and current plan with candidate predicates +/* ++-------------+ +-------------+ +| parentPlan | | parentPlan | ++-----^-------+ +-----^-------+ + | --addCandidateSelection---> | ++-----+-------+ +-----------+--------------+ +| currentPlan | | selection | ++-------------+ | candidate predicate | + +-----------^--------------+ + | + | + +----+--------+ + | currentPlan | + +-------------+ +*/ +// If the currentPlan at the top of query plan, return new root plan (selection) +// Else return nil +func addCandidateSelection(currentPlan base.LogicalPlan, currentChildIdx int, parentPlan base.LogicalPlan, + candidatePredicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { + // generate a new selection for candidatePredicates + selection := LogicalSelection{Conditions: candidatePredicates}.Init(currentPlan.SCtx(), currentPlan.QueryBlockOffset()) + // add selection above of p + if parentPlan == nil { + newRoot = selection + } else { + parentPlan.SetChild(currentChildIdx, selection) + } + selection.SetChildren(currentPlan) + appendAddSelectionTraceStep(parentPlan, currentPlan, selection, opt) + if parentPlan == nil { + return newRoot + } + return nil +} + +func (p *LogicalJoin) getGroupNDVs(colGroups [][]*expression.Column, childStats []*property.StatsInfo) []property.GroupNDV { + outerIdx := int(-1) + if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + outerIdx = 0 + } else if p.JoinType == RightOuterJoin { + outerIdx = 1 + } + if outerIdx >= 0 && len(colGroups) > 0 { + return childStats[outerIdx].GroupNDVs + } + return nil +} + +// PreferAny checks whether the join type is in the joinFlags. func (p *LogicalJoin) PreferAny(joinFlags ...uint) bool { for _, flag := range joinFlags { if p.PreferJoinType&flag > 0 { @@ -787,3 +1529,145 @@ func (p *LogicalJoin) SetPreferredJoinType() { p.PreferJoinType = 0 } } + +// updateEQCond will extract the arguments of a equal condition that connect two expressions. +func (p *LogicalJoin) updateEQCond() { + lChild, rChild := p.Children()[0], p.Children()[1] + var lKeys, rKeys []expression.Expression + var lNAKeys, rNAKeys []expression.Expression + // We need two steps here: + // step1: try best to extract normal EQ condition from OtherCondition to join EqualConditions. + for i := len(p.OtherConditions) - 1; i >= 0; i-- { + need2Remove := false + if eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction); ok && eqCond.FuncName.L == ast.EQ { + // If it is a column equal condition converted from `[not] in (subq)`, do not move it + // to EqualConditions, and keep it in OtherConditions. Reference comments in `extractOnCondition` + // for detailed reasons. + if expression.IsEQCondFromIn(eqCond) { + continue + } + lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] + if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { + lKeys = append(lKeys, lExpr) + rKeys = append(rKeys, rExpr) + need2Remove = true + } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { + lKeys = append(lKeys, rExpr) + rKeys = append(rKeys, lExpr) + need2Remove = true + } + } + if need2Remove { + p.OtherConditions = append(p.OtherConditions[:i], p.OtherConditions[i+1:]...) + } + } + // eg: explain select * from t1, t3 where t1.a+1 = t3.a; + // tidb only accept the join key in EqualCondition as a normal column (join OP take granted for that) + // so once we found the left and right children's schema can supply the all columns in complicated EQ condition that used by left/right key. + // we will add a layer of projection here to convert the complicated expression of EQ's left or right side to be a normal column. + adjustKeyForm := func(leftKeys, rightKeys []expression.Expression, isNA bool) { + if len(leftKeys) > 0 { + needLProj, needRProj := false, false + for i := range leftKeys { + _, lOk := leftKeys[i].(*expression.Column) + _, rOk := rightKeys[i].(*expression.Column) + needLProj = needLProj || !lOk + needRProj = needRProj || !rOk + } + + var lProj, rProj *LogicalProjection + if needLProj { + lProj = p.getProj(0) + } + if needRProj { + rProj = p.getProj(1) + } + for i := range leftKeys { + lKey, rKey := leftKeys[i], rightKeys[i] + if lProj != nil { + lKey = lProj.appendExpr(lKey) + } + if rProj != nil { + rKey = rProj.appendExpr(rKey) + } + eqCond := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) + if isNA { + p.NAEQConditions = append(p.NAEQConditions, eqCond.(*expression.ScalarFunction)) + } else { + p.EqualConditions = append(p.EqualConditions, eqCond.(*expression.ScalarFunction)) + } + } + } + } + adjustKeyForm(lKeys, rKeys, false) + + // Step2: when step1 is finished, then we can determine whether we need to extract NA-EQ from OtherCondition to NAEQConditions. + // when there are still no EqualConditions, let's try to be a NAAJ. + // todo: by now, when there is already a normal EQ condition, just keep NA-EQ as other-condition filters above it. + // eg: select * from stu where stu.name not in (select name from exam where exam.stu_id = stu.id); + // combination of and for join key is little complicated for now. + canBeNAAJ := (p.JoinType == AntiSemiJoin || p.JoinType == AntiLeftOuterSemiJoin) && len(p.EqualConditions) == 0 + if canBeNAAJ && p.SCtx().GetSessionVars().OptimizerEnableNAAJ { + var otherCond expression.CNFExprs + for i := 0; i < len(p.OtherConditions); i++ { + eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction) + if ok && eqCond.FuncName.L == ast.EQ && expression.IsEQCondFromIn(eqCond) { + // here must be a EQCondFromIn. + lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] + if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { + lNAKeys = append(lNAKeys, lExpr) + rNAKeys = append(rNAKeys, rExpr) + } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { + lNAKeys = append(lNAKeys, rExpr) + rNAKeys = append(rNAKeys, lExpr) + } + continue + } + otherCond = append(otherCond, p.OtherConditions[i]) + } + p.OtherConditions = otherCond + // here is for cases like: select (a+1, b*3) not in (select a,b from t2) from t1. + adjustKeyForm(lNAKeys, rNAKeys, true) + } +} + +func (p *LogicalJoin) getProj(idx int) *LogicalProjection { + child := p.Children()[idx] + proj, ok := child.(*LogicalProjection) + if ok { + return proj + } + proj = LogicalProjection{Exprs: make([]expression.Expression, 0, child.Schema().Len())}.Init(p.SCtx(), child.QueryBlockOffset()) + for _, col := range child.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + } + proj.SetSchema(child.Schema().Clone()) + proj.SetChildren(child) + p.Children()[idx] = proj + return proj +} + +// outerJoinPropConst propagates constant equal and column equal conditions over outer join. +func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []expression.Expression { + outerTable := p.Children()[0] + innerTable := p.Children()[1] + if p.JoinType == RightOuterJoin { + innerTable, outerTable = outerTable, innerTable + } + lenJoinConds := len(p.EqualConditions) + len(p.LeftConditions) + len(p.RightConditions) + len(p.OtherConditions) + joinConds := make([]expression.Expression, 0, lenJoinConds) + for _, equalCond := range p.EqualConditions { + joinConds = append(joinConds, equalCond) + } + joinConds = append(joinConds, p.LeftConditions...) + joinConds = append(joinConds, p.RightConditions...) + joinConds = append(joinConds, p.OtherConditions...) + p.EqualConditions = nil + p.LeftConditions = nil + p.RightConditions = nil + p.OtherConditions = nil + nullSensitive := p.JoinType == AntiLeftOuterSemiJoin || p.JoinType == LeftOuterSemiJoin + joinConds, predicates = expression.PropConstOverOuterJoin(p.SCtx().GetExprCtx(), joinConds, predicates, outerTable.Schema(), innerTable.Schema(), nullSensitive) + p.AttachOnConds(joinConds) + return predicates +} diff --git a/pkg/planner/core/property_cols_prune.go b/pkg/planner/core/property_cols_prune.go index a3e5c435c5596..6f991c306088b 100644 --- a/pkg/planner/core/property_cols_prune.go +++ b/pkg/planner/core/property_cols_prune.go @@ -67,28 +67,3 @@ func getPossiblePropertyFromByItems(items []*util.ByItems) []*expression.Column } return cols } - -// PreparePossibleProperties implements base.LogicalPlan PreparePossibleProperties interface. -func (p *LogicalJoin) PreparePossibleProperties(_ *expression.Schema, childrenProperties ...[][]*expression.Column) [][]*expression.Column { - leftProperties := childrenProperties[0] - rightProperties := childrenProperties[1] - // TODO: We should consider properties propagation. - p.LeftProperties = leftProperties - p.RightProperties = rightProperties - if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin { - rightProperties = nil - } else if p.JoinType == RightOuterJoin { - leftProperties = nil - } - resultProperties := make([][]*expression.Column, len(leftProperties)+len(rightProperties)) - for i, cols := range leftProperties { - resultProperties[i] = make([]*expression.Column, len(cols)) - copy(resultProperties[i], cols) - } - leftLen := len(leftProperties) - for i, cols := range rightProperties { - resultProperties[leftLen+i] = make([]*expression.Column, len(cols)) - copy(resultProperties[leftLen+i], cols) - } - return resultProperties -} diff --git a/pkg/planner/core/rule_build_key_info.go b/pkg/planner/core/rule_build_key_info.go index eda67f11e7173..63050ac676dc8 100644 --- a/pkg/planner/core/rule_build_key_info.go +++ b/pkg/planner/core/rule_build_key_info.go @@ -71,52 +71,6 @@ func checkMaxOneRowCond(eqColIDs map[int64]struct{}, childSchema *expression.Sch return false } -// BuildKeyInfo implements base.LogicalPlan BuildKeyInfo interface. -func (p *LogicalJoin) BuildKeyInfo(selfSchema *expression.Schema, childSchema []*expression.Schema) { - p.LogicalSchemaProducer.BuildKeyInfo(selfSchema, childSchema) - switch p.JoinType { - case SemiJoin, LeftOuterSemiJoin, AntiSemiJoin, AntiLeftOuterSemiJoin: - selfSchema.Keys = childSchema[0].Clone().Keys - case InnerJoin, LeftOuterJoin, RightOuterJoin: - // If there is no equal conditions, then cartesian product can't be prevented and unique key information will destroy. - if len(p.EqualConditions) == 0 { - return - } - lOk := false - rOk := false - // Such as 'select * from t1 join t2 where t1.a = t2.a and t1.b = t2.b'. - // If one sides (a, b) is a unique key, then the unique key information is remained. - // But we don't consider this situation currently. - // Only key made by one column is considered now. - evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() - for _, expr := range p.EqualConditions { - ln := expr.GetArgs()[0].(*expression.Column) - rn := expr.GetArgs()[1].(*expression.Column) - for _, key := range childSchema[0].Keys { - if len(key) == 1 && key[0].Equal(evalCtx, ln) { - lOk = true - break - } - } - for _, key := range childSchema[1].Keys { - if len(key) == 1 && key[0].Equal(evalCtx, rn) { - rOk = true - break - } - } - } - // For inner join, if one side of one equal condition is unique key, - // another side's unique key information will all be reserved. - // If it's an outer join, NULL value will fill some position, which will destroy the unique key information. - if lOk && p.JoinType != LeftOuterJoin { - selfSchema.Keys = append(selfSchema.Keys, childSchema[1].Keys...) - } - if rOk && p.JoinType != RightOuterJoin { - selfSchema.Keys = append(selfSchema.Keys, childSchema[0].Keys...) - } - } -} - // checkIndexCanBeKey checks whether an Index can be a Key in schema. func checkIndexCanBeKey(idx *model.IndexInfo, columns []*model.ColumnInfo, schema *expression.Schema) (uniqueKey, newKey expression.KeyInfo) { if !idx.Unique { diff --git a/pkg/planner/core/rule_column_pruning.go b/pkg/planner/core/rule_column_pruning.go index cd160586220da..187b19e0f6dc1 100644 --- a/pkg/planner/core/rule_column_pruning.go +++ b/pkg/planner/core/rule_column_pruning.go @@ -167,64 +167,6 @@ func (ds *DataSource) PruneColumns(parentUsedCols []*expression.Column, opt *opt return ds, nil } -func (p *LogicalJoin) extractUsedCols(parentUsedCols []*expression.Column) (leftCols []*expression.Column, rightCols []*expression.Column) { - for _, eqCond := range p.EqualConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(eqCond)...) - } - for _, leftCond := range p.LeftConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(leftCond)...) - } - for _, rightCond := range p.RightConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(rightCond)...) - } - for _, otherCond := range p.OtherConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(otherCond)...) - } - for _, naeqCond := range p.NAEQConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(naeqCond)...) - } - lChild := p.Children()[0] - rChild := p.Children()[1] - for _, col := range parentUsedCols { - if lChild.Schema().Contains(col) { - leftCols = append(leftCols, col) - } else if rChild.Schema().Contains(col) { - rightCols = append(rightCols, col) - } - } - return leftCols, rightCols -} - -func (p *LogicalJoin) mergeSchema() { - p.SetSchema(buildLogicalJoinSchema(p.JoinType, p)) -} - -// PruneColumns implements base.LogicalPlan interface. -func (p *LogicalJoin) PruneColumns(parentUsedCols []*expression.Column, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, error) { - leftCols, rightCols := p.extractUsedCols(parentUsedCols) - - var err error - p.Children()[0], err = p.Children()[0].PruneColumns(leftCols, opt) - if err != nil { - return nil, err - } - addConstOneForEmptyProjection(p.Children()[0]) - - p.Children()[1], err = p.Children()[1].PruneColumns(rightCols, opt) - if err != nil { - return nil, err - } - addConstOneForEmptyProjection(p.Children()[1]) - - p.mergeSchema() - if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - joinCol := p.Schema().Columns[len(p.Schema().Columns)-1] - parentUsedCols = append(parentUsedCols, joinCol) - } - p.InlineProjection(parentUsedCols, opt) - return p, nil -} - // PruneColumns implements base.LogicalPlan interface. func (la *LogicalApply) PruneColumns(parentUsedCols []*expression.Column, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, error) { leftCols, rightCols := la.extractUsedCols(parentUsedCols) diff --git a/pkg/planner/core/rule_constant_propagation.go b/pkg/planner/core/rule_constant_propagation.go index f559a5b86d288..437c085d801e9 100644 --- a/pkg/planner/core/rule_constant_propagation.go +++ b/pkg/planner/core/rule_constant_propagation.go @@ -86,88 +86,6 @@ func (*constantPropagationSolver) name() string { return "constant_propagation" } -// ConstantPropagation implemented the logic of constant propagation in From List -// Query: select * from t, (select a, b from s where s.a>1) tmp where tmp.a=t.a -// Origin logical plan: -/* - +----------------+ - | LogicalJoin | - +-------^--------+ - | - +-------------+--------------+ - | | -+-----+------+ +------+------+ -| Projection | | TableScan | -+-----^------+ +-------------+ - | - | -+-----+------+ -| Selection | -| s.a>1 | -+------------+ -*/ -// 1. 'PullUpConstantPredicates': Call this function until find selection and pull up the constant predicate layer by layer -// LogicalSelection: find the s.a>1 -// LogicalProjection: get the s.a>1 and pull up it, changed to tmp.a>1 -// 2. 'addCandidateSelection': Add selection above of LogicalJoin, -// put all predicates pulled up from the lower layer into the current new selection. -// LogicalSelection: tmp.a >1 -// -// Optimized plan: -/* - +----------------+ - | Selection | - | tmp.a>1 | - +-------^--------+ - | - +-------+--------+ - | LogicalJoin | - +-------^--------+ - | - +-------------+--------------+ - | | -+-----+------+ +------+------+ -| Projection | | TableScan | -+-----^------+ +-------------+ - | - | -+-----+------+ -| Selection | -| s.a>1 | -+------------+ -*/ -// Return nil if the root of plan has not been changed -// Return new root if the root of plan is changed to selection -func (logicalJoin *LogicalJoin) ConstantPropagation(parentPlan base.LogicalPlan, currentChildIdx int, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { - // step1: get constant predicate from left or right according to the JoinType - var getConstantPredicateFromLeft bool - var getConstantPredicateFromRight bool - switch logicalJoin.JoinType { - case LeftOuterJoin: - getConstantPredicateFromLeft = true - case RightOuterJoin: - getConstantPredicateFromRight = true - case InnerJoin: - getConstantPredicateFromLeft = true - getConstantPredicateFromRight = true - default: - return - } - var candidateConstantPredicates []expression.Expression - if getConstantPredicateFromLeft { - candidateConstantPredicates = logicalJoin.Children()[0].PullUpConstantPredicates() - } - if getConstantPredicateFromRight { - candidateConstantPredicates = append(candidateConstantPredicates, logicalJoin.Children()[1].PullUpConstantPredicates()...) - } - if len(candidateConstantPredicates) == 0 { - return - } - - // step2: add selection above of LogicalJoin - return addCandidateSelection(logicalJoin, currentChildIdx, parentPlan, candidateConstantPredicates, opt) -} - // validComparePredicate checks if the predicate is an expression like [column '>'|'>='|'<'|'<='|'=' constant]. // return param1: return true, if the predicate is a compare constant predicate. // return param2: return the column side of predicate. @@ -190,39 +108,3 @@ func validCompareConstantPredicate(ctx expression.EvalContext, candidatePredicat } return true } - -// Add a new selection between parent plan and current plan with candidate predicates -/* -+-------------+ +-------------+ -| parentPlan | | parentPlan | -+-----^-------+ +-----^-------+ - | --addCandidateSelection---> | -+-----+-------+ +-----------+--------------+ -| currentPlan | | selection | -+-------------+ | candidate predicate | - +-----------^--------------+ - | - | - +----+--------+ - | currentPlan | - +-------------+ -*/ -// If the currentPlan at the top of query plan, return new root plan (selection) -// Else return nil -func addCandidateSelection(currentPlan base.LogicalPlan, currentChildIdx int, parentPlan base.LogicalPlan, - candidatePredicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { - // generate a new selection for candidatePredicates - selection := LogicalSelection{Conditions: candidatePredicates}.Init(currentPlan.SCtx(), currentPlan.QueryBlockOffset()) - // add selection above of p - if parentPlan == nil { - newRoot = selection - } else { - parentPlan.SetChild(currentChildIdx, selection) - } - selection.SetChildren(currentPlan) - appendAddSelectionTraceStep(parentPlan, currentPlan, selection, opt) - if parentPlan == nil { - return newRoot - } - return nil -} diff --git a/pkg/planner/core/rule_decorrelate.go b/pkg/planner/core/rule_decorrelate.go index 627a05914749b..5fc9e569e3378 100644 --- a/pkg/planner/core/rule_decorrelate.go +++ b/pkg/planner/core/rule_decorrelate.go @@ -224,7 +224,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p base.LogicalPlan, op // upper OP (depend on column8) --> lower layer OP // | ^ // +-----------------------------+ // Fail: lower layer can't supply column8 anymore. - hasFail := apply.columnSubstituteAll(proj.Schema(), proj.Exprs) + hasFail := apply.ColumnSubstituteAll(proj.Schema(), proj.Exprs) if hasFail { goto NoOptimize } @@ -232,7 +232,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p base.LogicalPlan, op for i, expr := range proj.Exprs { proj.Exprs[i] = expr.Decorrelate(outerPlan.Schema()) } - apply.decorrelate(outerPlan.Schema()) + apply.Decorrelate(outerPlan.Schema()) innerPlan = proj.Children()[0] apply.SetChildren(outerPlan, innerPlan) diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go index 9c8a7dfcd79b0..695e60da4a632 100644 --- a/pkg/planner/core/rule_eliminate_projection.go +++ b/pkg/planner/core/rule_eliminate_projection.go @@ -253,22 +253,6 @@ func ReplaceColumnOfExpr(expr expression.Expression, proj *LogicalProjection, sc return expr } -// ReplaceExprColumns implements base.LogicalPlan interface. -func (p *LogicalJoin) ReplaceExprColumns(replace map[string]*expression.Column) { - for _, equalExpr := range p.EqualConditions { - ResolveExprAndReplace(equalExpr, replace) - } - for _, leftExpr := range p.LeftConditions { - ResolveExprAndReplace(leftExpr, replace) - } - for _, rightExpr := range p.RightConditions { - ResolveExprAndReplace(rightExpr, replace) - } - for _, otherExpr := range p.OtherConditions { - ResolveExprAndReplace(otherExpr, replace) - } -} - // ReplaceExprColumns implements base.LogicalPlan interface. func (la *LogicalApply) ReplaceExprColumns(replace map[string]*expression.Column) { la.LogicalJoin.ReplaceExprColumns(replace) diff --git a/pkg/planner/core/rule_outer_to_inner_join.go b/pkg/planner/core/rule_outer_to_inner_join.go index 997108b4804d3..433739752992c 100644 --- a/pkg/planner/core/rule_outer_to_inner_join.go +++ b/pkg/planner/core/rule_outer_to_inner_join.go @@ -19,7 +19,6 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" ) @@ -60,60 +59,6 @@ func (*convertOuterToInnerJoin) optimize(_ context.Context, p base.LogicalPlan, // Also, predicates involving aggregate expressions are not null filtering. IsNullReject always returns // false for those cases. -// ConvertOuterToInnerJoin implements base.LogicalPlan ConvertOuterToInnerJoin interface. -func (p *LogicalJoin) ConvertOuterToInnerJoin(predicates []expression.Expression) base.LogicalPlan { - innerTable := p.Children()[0] - outerTable := p.Children()[1] - switchChild := false - - if p.JoinType == LeftOuterJoin { - innerTable, outerTable = outerTable, innerTable - switchChild = true - } - - // First, simplify this join - if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { - canBeSimplified := false - for _, expr := range predicates { - isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr) - if isOk { - canBeSimplified = true - break - } - } - if canBeSimplified { - p.JoinType = InnerJoin - } - } - - // Next simplify join children - - combinedCond := mergeOnClausePredicates(p, predicates) - if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { - innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) - outerTable = outerTable.ConvertOuterToInnerJoin(predicates) - } else if p.JoinType == InnerJoin || p.JoinType == SemiJoin { - innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) - outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) - } else if p.JoinType == AntiSemiJoin { - innerTable = innerTable.ConvertOuterToInnerJoin(predicates) - outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) - } else { - innerTable = innerTable.ConvertOuterToInnerJoin(predicates) - outerTable = outerTable.ConvertOuterToInnerJoin(predicates) - } - - if switchChild { - p.SetChild(0, outerTable) - p.SetChild(1, innerTable) - } else { - p.SetChild(0, innerTable) - p.SetChild(1, outerTable) - } - - return p -} - func (*convertOuterToInnerJoin) name() string { return "convert_outer_to_inner_joins" } diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index c93388dc9f2ea..e656317968a5f 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -26,8 +26,6 @@ import ( "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" - "github.com/pingcap/tidb/pkg/planner/util/utilfuncp" - "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/ranger" "go.uber.org/zap" @@ -101,218 +99,6 @@ func (ds *DataSource) PredicatePushDown(predicates []expression.Expression, opt return predicates, ds } -// PredicatePushDown implements base.LogicalPlan PredicatePushDown interface. -func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (ret []expression.Expression, retPlan base.LogicalPlan) { - var equalCond []*expression.ScalarFunction - var leftPushCond, rightPushCond, otherCond, leftCond, rightCond []expression.Expression - switch p.JoinType { - case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: - predicates = p.outerJoinPropConst(predicates) - dual := Conds2TableDual(p, predicates) - if dual != nil { - appendTableDualTraceStep(p, dual, predicates, opt) - return ret, dual - } - // Handle where conditions - predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) - // Only derive left where condition, because right where condition cannot be pushed down - equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true, false) - leftCond = leftPushCond - // Handle join conditions, only derive right join condition, because left join condition cannot be pushed down - _, derivedRightJoinCond := DeriveOtherConditions( - p, p.Children()[0].Schema(), p.Children()[1].Schema(), false, true) - rightCond = append(p.RightConditions, derivedRightJoinCond...) - p.RightConditions = nil - ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) - ret = append(ret, rightPushCond...) - case RightOuterJoin: - predicates = p.outerJoinPropConst(predicates) - dual := Conds2TableDual(p, predicates) - if dual != nil { - appendTableDualTraceStep(p, dual, predicates, opt) - return ret, dual - } - // Handle where conditions - predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) - // Only derive right where condition, because left where condition cannot be pushed down - equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, false, true) - rightCond = rightPushCond - // Handle join conditions, only derive left join condition, because right join condition cannot be pushed down - derivedLeftJoinCond, _ := DeriveOtherConditions( - p, p.Children()[0].Schema(), p.Children()[1].Schema(), true, false) - leftCond = append(p.LeftConditions, derivedLeftJoinCond...) - p.LeftConditions = nil - ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) - ret = append(ret, leftPushCond...) - case SemiJoin, InnerJoin: - tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates)) - tempCond = append(tempCond, p.LeftConditions...) - tempCond = append(tempCond, p.RightConditions...) - tempCond = append(tempCond, expression.ScalarFuncs2Exprs(p.EqualConditions)...) - tempCond = append(tempCond, p.OtherConditions...) - tempCond = append(tempCond, predicates...) - tempCond = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), tempCond) - tempCond = expression.PropagateConstant(p.SCtx().GetExprCtx(), tempCond) - // Return table dual when filter is constant false or null. - dual := Conds2TableDual(p, tempCond) - if dual != nil { - appendTableDualTraceStep(p, dual, tempCond, opt) - return ret, dual - } - equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(tempCond, true, true) - p.LeftConditions = nil - p.RightConditions = nil - p.EqualConditions = equalCond - p.OtherConditions = otherCond - leftCond = leftPushCond - rightCond = rightPushCond - case AntiSemiJoin: - predicates = expression.PropagateConstant(p.SCtx().GetExprCtx(), predicates) - // Return table dual when filter is constant false or null. - dual := Conds2TableDual(p, predicates) - if dual != nil { - appendTableDualTraceStep(p, dual, predicates, opt) - return ret, dual - } - // `predicates` should only contain left conditions or constant filters. - _, leftPushCond, rightPushCond, _ = p.extractOnCondition(predicates, true, true) - // Do not derive `is not null` for anti join, since it may cause wrong results. - // For example: - // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, - // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b` does not imply `t1.b is not null`, - // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`, - leftCond = leftPushCond - rightCond = append(p.RightConditions, rightPushCond...) - p.RightConditions = nil - } - leftCond = expression.RemoveDupExprs(leftCond) - rightCond = expression.RemoveDupExprs(rightCond) - leftRet, lCh := p.Children()[0].PredicatePushDown(leftCond, opt) - rightRet, rCh := p.Children()[1].PredicatePushDown(rightCond, opt) - utilfuncp.AddSelection(p, lCh, leftRet, 0, opt) - utilfuncp.AddSelection(p, rCh, rightRet, 1, opt) - p.updateEQCond() - buildKeyInfo(p) - return ret, p.Self() -} - -// updateEQCond will extract the arguments of a equal condition that connect two expressions. -func (p *LogicalJoin) updateEQCond() { - lChild, rChild := p.Children()[0], p.Children()[1] - var lKeys, rKeys []expression.Expression - var lNAKeys, rNAKeys []expression.Expression - // We need two steps here: - // step1: try best to extract normal EQ condition from OtherCondition to join EqualConditions. - for i := len(p.OtherConditions) - 1; i >= 0; i-- { - need2Remove := false - if eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction); ok && eqCond.FuncName.L == ast.EQ { - // If it is a column equal condition converted from `[not] in (subq)`, do not move it - // to EqualConditions, and keep it in OtherConditions. Reference comments in `extractOnCondition` - // for detailed reasons. - if expression.IsEQCondFromIn(eqCond) { - continue - } - lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] - if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { - lKeys = append(lKeys, lExpr) - rKeys = append(rKeys, rExpr) - need2Remove = true - } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { - lKeys = append(lKeys, rExpr) - rKeys = append(rKeys, lExpr) - need2Remove = true - } - } - if need2Remove { - p.OtherConditions = append(p.OtherConditions[:i], p.OtherConditions[i+1:]...) - } - } - // eg: explain select * from t1, t3 where t1.a+1 = t3.a; - // tidb only accept the join key in EqualCondition as a normal column (join OP take granted for that) - // so once we found the left and right children's schema can supply the all columns in complicated EQ condition that used by left/right key. - // we will add a layer of projection here to convert the complicated expression of EQ's left or right side to be a normal column. - adjustKeyForm := func(leftKeys, rightKeys []expression.Expression, isNA bool) { - if len(leftKeys) > 0 { - needLProj, needRProj := false, false - for i := range leftKeys { - _, lOk := leftKeys[i].(*expression.Column) - _, rOk := rightKeys[i].(*expression.Column) - needLProj = needLProj || !lOk - needRProj = needRProj || !rOk - } - - var lProj, rProj *LogicalProjection - if needLProj { - lProj = p.getProj(0) - } - if needRProj { - rProj = p.getProj(1) - } - for i := range leftKeys { - lKey, rKey := leftKeys[i], rightKeys[i] - if lProj != nil { - lKey = lProj.appendExpr(lKey) - } - if rProj != nil { - rKey = rProj.appendExpr(rKey) - } - eqCond := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) - if isNA { - p.NAEQConditions = append(p.NAEQConditions, eqCond.(*expression.ScalarFunction)) - } else { - p.EqualConditions = append(p.EqualConditions, eqCond.(*expression.ScalarFunction)) - } - } - } - } - adjustKeyForm(lKeys, rKeys, false) - - // Step2: when step1 is finished, then we can determine whether we need to extract NA-EQ from OtherCondition to NAEQConditions. - // when there are still no EqualConditions, let's try to be a NAAJ. - // todo: by now, when there is already a normal EQ condition, just keep NA-EQ as other-condition filters above it. - // eg: select * from stu where stu.name not in (select name from exam where exam.stu_id = stu.id); - // combination of and for join key is little complicated for now. - canBeNAAJ := (p.JoinType == AntiSemiJoin || p.JoinType == AntiLeftOuterSemiJoin) && len(p.EqualConditions) == 0 - if canBeNAAJ && p.SCtx().GetSessionVars().OptimizerEnableNAAJ { - var otherCond expression.CNFExprs - for i := 0; i < len(p.OtherConditions); i++ { - eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction) - if ok && eqCond.FuncName.L == ast.EQ && expression.IsEQCondFromIn(eqCond) { - // here must be a EQCondFromIn. - lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] - if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { - lNAKeys = append(lNAKeys, lExpr) - rNAKeys = append(rNAKeys, rExpr) - } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { - lNAKeys = append(lNAKeys, rExpr) - rNAKeys = append(rNAKeys, lExpr) - } - continue - } - otherCond = append(otherCond, p.OtherConditions[i]) - } - p.OtherConditions = otherCond - // here is for cases like: select (a+1, b*3) not in (select a,b from t2) from t1. - adjustKeyForm(lNAKeys, rNAKeys, true) - } -} - -func (p *LogicalJoin) getProj(idx int) *LogicalProjection { - child := p.Children()[idx] - proj, ok := child.(*LogicalProjection) - if ok { - return proj - } - proj = LogicalProjection{Exprs: make([]expression.Expression, 0, child.Schema().Len())}.Init(p.SCtx(), child.QueryBlockOffset()) - for _, col := range child.Schema().Columns { - proj.Exprs = append(proj.Exprs, col) - } - proj.SetSchema(child.Schema().Clone()) - proj.SetChildren(child) - p.Children()[idx] = proj - return proj -} - // BreakDownPredicates breaks down predicates into two sets: canBePushed and cannotBePushed. It also maps columns to projection schema. func BreakDownPredicates(p *LogicalProjection, predicates []expression.Expression) ([]expression.Expression, []expression.Expression) { canBePushed := make([]expression.Expression, 0, len(predicates)) @@ -451,31 +237,6 @@ func DeleteTrueExprs(p base.LogicalPlan, conds []expression.Expression) []expres return newConds } -// outerJoinPropConst propagates constant equal and column equal conditions over outer join. -func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []expression.Expression { - outerTable := p.Children()[0] - innerTable := p.Children()[1] - if p.JoinType == RightOuterJoin { - innerTable, outerTable = outerTable, innerTable - } - lenJoinConds := len(p.EqualConditions) + len(p.LeftConditions) + len(p.RightConditions) + len(p.OtherConditions) - joinConds := make([]expression.Expression, 0, lenJoinConds) - for _, equalCond := range p.EqualConditions { - joinConds = append(joinConds, equalCond) - } - joinConds = append(joinConds, p.LeftConditions...) - joinConds = append(joinConds, p.RightConditions...) - joinConds = append(joinConds, p.OtherConditions...) - p.EqualConditions = nil - p.LeftConditions = nil - p.RightConditions = nil - p.OtherConditions = nil - nullSensitive := p.JoinType == AntiLeftOuterSemiJoin || p.JoinType == LeftOuterSemiJoin - joinConds, predicates = expression.PropConstOverOuterJoin(p.SCtx().GetExprCtx(), joinConds, predicates, outerTable.Schema(), innerTable.Schema(), nullSensitive) - p.AttachOnConds(joinConds) - return predicates -} - func (*ppdSolver) name() string { return "predicate_push_down" } diff --git a/pkg/planner/core/rule_topn_push_down.go b/pkg/planner/core/rule_topn_push_down.go index 48b3b535a3054..7d8fb4bcd0713 100644 --- a/pkg/planner/core/rule_topn_push_down.go +++ b/pkg/planner/core/rule_topn_push_down.go @@ -19,10 +19,8 @@ import ( "context" "fmt" - "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" ) @@ -53,57 +51,6 @@ func pushDownTopNForBaseLogicalPlan(lp base.LogicalPlan, topNLogicalPlan base.Lo return p } -// pushDownTopNToChild will push a topN to one child of join. The idx stands for join child index. 0 is for left child. -func (p *LogicalJoin) pushDownTopNToChild(topN *LogicalTopN, idx int, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { - if topN == nil { - return p.Children()[idx].PushDownTopN(nil, opt) - } - - for _, by := range topN.ByItems { - cols := expression.ExtractColumns(by.Expr) - for _, col := range cols { - if !p.Children()[idx].Schema().Contains(col) { - return p.Children()[idx].PushDownTopN(nil, opt) - } - } - } - - newTopN := LogicalTopN{ - Count: topN.Count + topN.Offset, - ByItems: make([]*util.ByItems, len(topN.ByItems)), - PreferLimitToCop: topN.PreferLimitToCop, - }.Init(topN.SCtx(), topN.QueryBlockOffset()) - for i := range topN.ByItems { - newTopN.ByItems[i] = topN.ByItems[i].Clone() - } - appendTopNPushDownJoinTraceStep(p, newTopN, idx, opt) - return p.Children()[idx].PushDownTopN(newTopN, opt) -} - -// PushDownTopN implements the LogicalPlan interface. -func (p *LogicalJoin) PushDownTopN(topNLogicalPlan base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { - var topN *LogicalTopN - if topNLogicalPlan != nil { - topN = topNLogicalPlan.(*LogicalTopN) - } - switch p.JoinType { - case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: - p.Children()[0] = p.pushDownTopNToChild(topN, 0, opt) - p.Children()[1] = p.Children()[1].PushDownTopN(nil, opt) - case RightOuterJoin: - p.Children()[1] = p.pushDownTopNToChild(topN, 1, opt) - p.Children()[0] = p.Children()[0].PushDownTopN(nil, opt) - default: - return p.BaseLogicalPlan.PushDownTopN(topN, opt) - } - - // The LogicalJoin may be also a LogicalApply. So we must use self to set parents. - if topN != nil { - return topN.AttachChild(p.Self(), opt) - } - return p.Self() -} - func (*pushDownTopNOptimizer) name() string { return "topn_push_down" } diff --git a/pkg/planner/core/stats.go b/pkg/planner/core/stats.go index fc1e4aaed4f00..4c6a0264f420c 100644 --- a/pkg/planner/core/stats.go +++ b/pkg/planner/core/stats.go @@ -443,109 +443,6 @@ func deriveLimitStats(childProfile *property.StatsInfo, limitCount float64) *pro return stats } -func (p *LogicalJoin) getGroupNDVs(colGroups [][]*expression.Column, childStats []*property.StatsInfo) []property.GroupNDV { - outerIdx := int(-1) - if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - outerIdx = 0 - } else if p.JoinType == RightOuterJoin { - outerIdx = 1 - } - if outerIdx >= 0 && len(colGroups) > 0 { - return childStats[outerIdx].GroupNDVs - } - return nil -} - -// DeriveStats implement LogicalPlan DeriveStats interface. -// If the type of join is SemiJoin, the selectivity of it will be same as selection's. -// If the type of join is LeftOuterSemiJoin, it will not add or remove any row. The last column is a boolean value, whose NDV should be two. -// If the type of join is inner/outer join, the output of join(s, t) should be N(s) * N(t) / (V(s.key) * V(t.key)) * Min(s.key, t.key). -// N(s) stands for the number of rows in relation s. V(s.key) means the NDV of join key in s. -// This is a quite simple strategy: We assume every bucket of relation which will participate join has the same number of rows, and apply cross join for -// every matched bucket. -func (p *LogicalJoin) DeriveStats(childStats []*property.StatsInfo, selfSchema *expression.Schema, childSchema []*expression.Schema, colGroups [][]*expression.Column) (*property.StatsInfo, error) { - if p.StatsInfo() != nil { - // Reload GroupNDVs since colGroups may have changed. - p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) - return p.StatsInfo(), nil - } - leftProfile, rightProfile := childStats[0], childStats[1] - leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() - p.EqualCondOutCnt = cardinality.EstimateFullJoinRowCount(p.SCtx(), - 0 == len(p.EqualConditions), - leftProfile, rightProfile, - leftJoinKeys, rightJoinKeys, - childSchema[0], childSchema[1], - nil, nil) - if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin { - p.SetStats(&property.StatsInfo{ - RowCount: leftProfile.RowCount * cost.SelectionFactor, - ColNDVs: make(map[int64]float64, len(leftProfile.ColNDVs)), - }) - for id, c := range leftProfile.ColNDVs { - p.StatsInfo().ColNDVs[id] = c * cost.SelectionFactor - } - return p.StatsInfo(), nil - } - if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - p.SetStats(&property.StatsInfo{ - RowCount: leftProfile.RowCount, - ColNDVs: make(map[int64]float64, selfSchema.Len()), - }) - for id, c := range leftProfile.ColNDVs { - p.StatsInfo().ColNDVs[id] = c - } - p.StatsInfo().ColNDVs[selfSchema.Columns[selfSchema.Len()-1].UniqueID] = 2.0 - p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) - return p.StatsInfo(), nil - } - count := p.EqualCondOutCnt - if p.JoinType == LeftOuterJoin { - count = math.Max(count, leftProfile.RowCount) - } else if p.JoinType == RightOuterJoin { - count = math.Max(count, rightProfile.RowCount) - } - colNDVs := make(map[int64]float64, selfSchema.Len()) - for id, c := range leftProfile.ColNDVs { - colNDVs[id] = math.Min(c, count) - } - for id, c := range rightProfile.ColNDVs { - colNDVs[id] = math.Min(c, count) - } - p.SetStats(&property.StatsInfo{ - RowCount: count, - ColNDVs: colNDVs, - }) - p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) - return p.StatsInfo(), nil -} - -// ExtractColGroups implements LogicalPlan ExtractColGroups interface. -func (p *LogicalJoin) ExtractColGroups(colGroups [][]*expression.Column) [][]*expression.Column { - leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() - extracted := make([][]*expression.Column, 0, 2+len(colGroups)) - if len(leftJoinKeys) > 1 && (p.JoinType == InnerJoin || p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin) { - extracted = append(extracted, expression.SortColumns(leftJoinKeys), expression.SortColumns(rightJoinKeys)) - } - var outerSchema *expression.Schema - if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - outerSchema = p.Children()[0].Schema() - } else if p.JoinType == RightOuterJoin { - outerSchema = p.Children()[1].Schema() - } - if len(colGroups) == 0 || outerSchema == nil { - return extracted - } - _, offsets := outerSchema.ExtractColGroups(colGroups) - if len(offsets) == 0 { - return extracted - } - for _, offset := range offsets { - extracted = append(extracted, colGroups[offset]) - } - return extracted -} - func (la *LogicalApply) getGroupNDVs(colGroups [][]*expression.Column, childStats []*property.StatsInfo) []property.GroupNDV { if len(colGroups) > 0 && (la.JoinType == LeftOuterSemiJoin || la.JoinType == AntiLeftOuterSemiJoin || la.JoinType == LeftOuterJoin) { return childStats[0].GroupNDVs