From 7f37bad60ff75cc8e7a5c054155aa00eaad769a5 Mon Sep 17 00:00:00 2001 From: bb7133 Date: Fri, 31 Aug 2018 00:11:04 +0800 Subject: [PATCH] expression: propagate more filters in PropagateConstant (#7276) --- cmd/explaintest/r/explain_easy.result | 11 +- expression/constant_propagation.go | 182 +++++++++++++++++--------- expression/constant_test.go | 36 +++++ plan/cbo_test.go | 9 +- plan/logical_plan_test.go | 2 +- plan/physical_plan_test.go | 2 +- 6 files changed, 169 insertions(+), 73 deletions(-) diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index d08e79cde7ce6..ff0e9159abeaa 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -274,11 +274,12 @@ Projection_11 10000.00 root 9_aux_0 ├─TableReader_15 10000.00 root data:TableScan_14 │ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_20 1.00 root funcs:count(1) - └─IndexJoin_43 10000.00 root inner join, inner:TableReader_42, outer key:s.a, inner key:t1.a - ├─TableReader_51 1.00 root data:TableScan_50 - │ └─TableScan_50 1.00 cop table:s, range: decided by [eq(s.a, test.t.a)], keep order:false, stats:pseudo - └─TableReader_42 10.00 root data:TableScan_41 - └─TableScan_41 10.00 cop table:t1, range: decided by [s.a], keep order:false, stats:pseudo + └─IndexJoin_44 10000.00 root inner join, inner:TableReader_43, outer key:s.a, inner key:t1.a + ├─TableReader_52 1.00 root data:TableScan_51 + │ └─TableScan_51 1.00 cop table:s, range: decided by [eq(s.a, test.t.a)], keep order:false, stats:pseudo + └─TableReader_43 8000.00 root data:Selection_42 + └─Selection_42 8000.00 cop eq(t1.a, test.t.a) + └─TableScan_41 10.00 cop table:t1, range: decided by [s.a], keep order:false, stats:pseudo explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t; id count task operator info Projection_11 10000.00 root 9_aux_0 diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index 232462e026986..c0573b6712b0c 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -27,19 +27,6 @@ import ( // MaxPropagateColsCnt means the max number of columns that can participate propagation. var MaxPropagateColsCnt = 100 -var eqFuncNameMap = map[string]bool{ - ast.EQ: true, -} - -// inEqFuncNameMap stores all the in-equal operators that can be propagated. -var inEqFuncNameMap = map[string]bool{ - ast.LT: true, - ast.GT: true, - ast.LE: true, - ast.GE: true, - ast.NE: true, -} - type multiEqualSet struct { parent []int } @@ -72,51 +59,12 @@ type propagateConstantSolver struct { ctx sessionctx.Context } -// propagateInEQ propagates all in-equal conditions. -// e.g. For expression a = b and b = c and c = d and c < 1 , we can get extra a < 1 and b < 1 and d < 1. -// We maintain a unionSet representing the equivalent for every two columns. -func (s *propagateConstantSolver) propagateInEQ() { - s.unionSet = &multiEqualSet{} - s.unionSet.init(len(s.columns)) - for i := range s.conditions { - if fun, ok := s.conditions[i].(*ScalarFunction); ok && fun.FuncName.L == ast.EQ { - lCol, lOk := fun.GetArgs()[0].(*Column) - rCol, rOk := fun.GetArgs()[1].(*Column) - if lOk && rOk { - lID := s.getColID(lCol) - rID := s.getColID(rCol) - s.unionSet.addRelation(lID, rID) - } - } - } - condsLen := len(s.conditions) - for i := 0; i < condsLen; i++ { - cond := s.conditions[i] - col, con := s.validPropagateCond(cond, inEqFuncNameMap) - if col == nil { - continue - } - id := s.getColID(col) - for j := range s.columns { - if id != j && s.unionSet.findRoot(id) == s.unionSet.findRoot(j) { - funName := cond.(*ScalarFunction).FuncName.L - var newExpr Expression - if _, ok := cond.(*ScalarFunction).GetArgs()[0].(*Column); ok { - newExpr = NewFunctionInternal(s.ctx, funName, cond.GetType(), s.columns[j], con) - } else { - newExpr = NewFunctionInternal(s.ctx, funName, cond.GetType(), con, s.columns[j]) - } - s.conditions = append(s.conditions, newExpr) - } - } - } -} - -// propagateEQ propagates equal expression multiple times. An example runs as following: +// propagateConstantEQ propagates expressions like 'column = constant' by substituting the constant for column, the +// procedure repeats multiple times. An example runs as following: // a = d & b * 2 = c & c = d + 2 & b = 1 & a = 4, we pick eq cond b = 1 and a = 4 // d = 4 & 2 = c & c = d + 2 & b = 1 & a = 4, we propagate b = 1 and a = 4 and pick eq cond c = 2 and d = 4 // d = 4 & 2 = c & false & b = 1 & a = 4, we propagate c = 2 and d = 4, and do constant folding: c = d + 2 will be folded as false. -func (s *propagateConstantSolver) propagateEQ() { +func (s *propagateConstantSolver) propagateConstantEQ() { s.eqList = make([]*Constant, len(s.columns)) visited := make([]bool, len(s.conditions)) for i := 0; i < MaxPropagateColsCnt; i++ { @@ -138,10 +86,72 @@ func (s *propagateConstantSolver) propagateEQ() { } } -// validPropagateCond checks if the cond is an expression like [column op constant] and op is in the funNameMap. -func (s *propagateConstantSolver) validPropagateCond(cond Expression, funNameMap map[string]bool) (*Column, *Constant) { +// propagateColumnEQ propagates expressions like 'column A = column B' by adding extra filters +// 'expression(..., column B, ...)' propagated from 'expression(..., column A, ...)' as long as: +// +// 1. The expression is deterministic +// 2. The expression doesn't have any side effect +// +// e.g. For expression a = b and b = c and c = d and c < 1 , we can get extra a < 1 and b < 1 and d < 1. +// However, for a = b and a < rand(), we cannot propagate a < rand() to b < rand() because rand() is non-deterministic +// +// This propagation may bring redundancies that we need to resolve later, for example: +// for a = b and a < 3 and b < 3, we get new a < 3 and b < 3, which are redundant +// for a = b and a < 3 and 3 > b, we get new b < 3 and 3 > a, which are redundant +// for a = b and a < 3 and b < 4, we get new a < 4 and b < 3 but should expect a < 3 and b < 3 +// for a = b and a in (3) and b in (4), we get b in (3) and a in (4) but should expect 'false' +// +// TODO: remove redundancies later +// +// We maintain a unionSet representing the equivalent for every two columns. +func (s *propagateConstantSolver) propagateColumnEQ() { + visited := make([]bool, len(s.conditions)) + s.unionSet = &multiEqualSet{} + s.unionSet.init(len(s.columns)) + for i := range s.conditions { + if fun, ok := s.conditions[i].(*ScalarFunction); ok && fun.FuncName.L == ast.EQ { + lCol, lOk := fun.GetArgs()[0].(*Column) + rCol, rOk := fun.GetArgs()[1].(*Column) + if lOk && rOk { + lID := s.getColID(lCol) + rID := s.getColID(rCol) + s.unionSet.addRelation(lID, rID) + visited[i] = true + } + } + } + + condsLen := len(s.conditions) + for i, coli := range s.columns { + for j := i + 1; j < len(s.columns); j++ { + // unionSet doesn't have iterate(), we use a two layer loop to iterate col_i = col_j relation + if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) { + continue + } + colj := s.columns[j] + for k := 0; k < condsLen; k++ { + if visited[k] { + // cond_k has been used to retrieve equality relation + continue + } + cond := s.conditions[k] + replaced, _, newExpr := s.tryToReplaceCond(coli, colj, cond) + if replaced { + s.conditions = append(s.conditions, newExpr) + } + replaced, _, newExpr = s.tryToReplaceCond(colj, coli, cond) + if replaced { + s.conditions = append(s.conditions, newExpr) + } + } + } + } +} + +// validEqualCond checks if the cond is an expression like [column eq constant]. +func (s *propagateConstantSolver) validEqualCond(cond Expression) (*Column, *Constant) { if eq, ok := cond.(*ScalarFunction); ok { - if _, ok := funNameMap[eq.FuncName.L]; !ok { + if eq.FuncName.L != ast.EQ { return nil, nil } if col, colOk := eq.GetArgs()[0].(*Column); colOk { @@ -158,6 +168,54 @@ func (s *propagateConstantSolver) validPropagateCond(cond Expression, funNameMap return nil, nil } +// tryToReplaceCond aims to replace all occurrences of column 'src' and try to replace it with 'tgt' in 'cond' +// It returns +// bool: if a replacement happened +// bool: if 'cond' contains non-deterministic expression +// Expression: the replaced expression, or original 'cond' if the replacement didn't happen +// +// For example: +// for 'a, b, a < 3', it returns 'true, false, b < 3' +// for 'a, b, sin(a) + cos(a) = 5', it returns 'true, false, returns sin(b) + cos(b) = 5' +// for 'a, b, cast(a) < rand()', it returns 'false, true, cast(a) < rand()' +func (s *propagateConstantSolver) tryToReplaceCond(src *Column, tgt *Column, cond Expression) (bool, bool, Expression) { + sf, ok := cond.(*ScalarFunction) + if !ok { + return false, false, cond + } + replaced := false + var args []Expression + if _, ok := unFoldableFunctions[sf.FuncName.L]; ok { + return false, true, cond + } + for idx, expr := range sf.GetArgs() { + if src.Equal(nil, expr) { + replaced = true + if args == nil { + args = make([]Expression, len(sf.GetArgs())) + copy(args, sf.GetArgs()) + } + args[idx] = tgt + } else { + subReplaced, isNonDeterminisitic, subExpr := s.tryToReplaceCond(src, tgt, expr) + if isNonDeterminisitic { + return false, true, cond + } else if subReplaced { + replaced = true + if args == nil { + args = make([]Expression, len(sf.GetArgs())) + copy(args, sf.GetArgs()) + } + args[idx] = subExpr + } + } + } + if replaced { + return true, false, NewFunctionInternal(s.ctx, sf.FuncName.L, sf.GetType(), args...) + } + return false, false, cond +} + func (s *propagateConstantSolver) setConds2ConstFalse() { s.conditions = []Expression{&Constant{ Value: types.NewDatum(false), @@ -172,7 +230,7 @@ func (s *propagateConstantSolver) pickNewEQConds(visited []bool) (retMapper map[ if visited[i] { continue } - col, con := s.validPropagateCond(cond, eqFuncNameMap) + col, con := s.validEqualCond(cond) // Then we check if this CNF item is a false constant. If so, we will set the whole condition to false. var ok bool if col == nil { @@ -227,8 +285,8 @@ func (s *propagateConstantSolver) solve(conditions []Expression) []Expression { log.Warnf("[const_propagation]Too many columns in a single CNF: the column count is %d, the max count is %d.", len(s.columns), MaxPropagateColsCnt) return conditions } - s.propagateEQ() - s.propagateInEQ() + s.propagateConstantEQ() + s.propagateColumnEQ() for i, cond := range s.conditions { if dnf, ok := cond.(*ScalarFunction); ok && dnf.FuncName.L == ast.LogicOr { dnfItems := SplitDNFItems(cond) @@ -255,7 +313,7 @@ func (s *propagateConstantSolver) insertCol(col *Column) { } } -// PropagateConstant propagate constant values of equality predicates and inequality predicates in a condition. +// PropagateConstant propagate constant values of deterministic predicates in a condition. func PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { solver := &propagateConstantSolver{ colMapper: make(map[string]int), diff --git a/expression/constant_test.go b/expression/constant_test.go index c36c2d39e2411..cac21ca495b1b 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -113,6 +113,42 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { }, result: "0", }, + { + conditions: []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.In, newColumn(0), newLonglong(1), newLonglong(2)), + newFunction(ast.In, newColumn(1), newLonglong(3), newLonglong(4)), + }, + result: "eq(test.t.0, test.t.1), in(test.t.0, 1, 2), in(test.t.0, 3, 4), in(test.t.1, 1, 2), in(test.t.1, 3, 4)", + }, + { + conditions: []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.EQ, newColumn(0), newFunction(ast.BitLength, newColumn(2))), + }, + result: "eq(test.t.0, bit_length(cast(test.t.2))), eq(test.t.0, test.t.1), eq(test.t.1, bit_length(cast(test.t.2)))", + }, + { + conditions: []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.LE, newFunction(ast.Mul, newColumn(0), newColumn(0)), newLonglong(50)), + }, + result: "eq(test.t.0, test.t.1), le(mul(test.t.0, test.t.0), 50), le(mul(test.t.1, test.t.1), 50)", + }, + { + conditions: []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.LE, newColumn(0), newFunction(ast.Plus, newColumn(1), newLonglong(1))), + }, + result: "eq(test.t.0, test.t.1), le(test.t.0, plus(test.t.0, 1)), le(test.t.0, plus(test.t.1, 1)), le(test.t.1, plus(test.t.1, 1))", + }, + { + conditions: []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.LE, newColumn(0), newFunction(ast.Rand)), + }, + result: "eq(test.t.0, test.t.1), le(cast(test.t.0), rand())", + }, } for _, tt := range tests { ctx := mock.NewContext() diff --git a/plan/cbo_test.go b/plan/cbo_test.go index 832ac06ede355..5f90f40e36576 100644 --- a/plan/cbo_test.go +++ b/plan/cbo_test.go @@ -37,7 +37,7 @@ var _ = Suite(&testAnalyzeSuite{}) type testAnalyzeSuite struct { } -// CBOWithoutAnalyze tests the plan with stats that only have count info. +// TestCBOWithoutAnalyze tests the plan with stats that only have count info. func (s *testAnalyzeSuite) TestCBOWithoutAnalyze(c *C) { defer testleak.AfterTest(c)() store, dom, err := newStoreWithBootstrap() @@ -633,12 +633,13 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) { " ├─TableReader_15 10.00 root data:TableScan_14", " │ └─TableScan_14 10.00 cop table:t, range:[-inf,+inf], keep order:false", " └─StreamAgg_20 1.00 root funcs:count(1)", - " └─HashRightJoin_22 1.00 root inner join, inner:TableReader_25, equal:[eq(s.a, t1.a)]", + " └─HashLeftJoin_21 1.00 root inner join, inner:TableReader_28, equal:[eq(s.a, t1.a)]", " ├─TableReader_25 1.00 root data:Selection_24", " │ └─Selection_24 1.00 cop eq(s.a, test.t.a)", " │ └─TableScan_23 10.00 cop table:s, range:[-inf,+inf], keep order:false", - " └─TableReader_27 10.00 root data:TableScan_26", - " └─TableScan_26 10.00 cop table:t1, range:[-inf,+inf], keep order:false", + " └─TableReader_28 1.00 root data:Selection_27", + " └─Selection_27 1.00 cop eq(t1.a, test.t.a)", + " └─TableScan_26 10.00 cop table:t1, range:[-inf,+inf], keep order:false", )) tk.MustQuery("explain select (select concat(t1.a, \",\", t1.b) from t t1 where t1.a=t.a and t1.c=t.c) from t"). Check(testkit.Rows( diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go index eb80d298edde2..2f70772b0544f 100644 --- a/plan/logical_plan_test.go +++ b/plan/logical_plan_test.go @@ -315,7 +315,7 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) { }, { sql: "select * from t t1, t t2 where t1.a = t2.b and t2.b > 0 and t1.a = t1.c and t1.d like 'abc' and t2.d = t1.d", - best: "Join{DataScan(t2)->DataScan(t1)->Sel([like(cast(t1.d), abc, 92)])}(t2.b,t1.a)(t2.d,t1.d)->Projection", + best: "Join{DataScan(t2)->Sel([like(cast(t2.d), abc, 92)])->DataScan(t1)->Sel([like(cast(t1.d), abc, 92)])}(t2.b,t1.a)(t2.d,t1.d)->Projection", }, { sql: "select * from t ta join t tb on ta.d = tb.d and ta.d > 1 where tb.a = 0", diff --git a/plan/physical_plan_test.go b/plan/physical_plan_test.go index 8fc193f08b381..189ccc23fab72 100644 --- a/plan/physical_plan_test.go +++ b/plan/physical_plan_test.go @@ -467,7 +467,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderSubquery(c *C) { // Test Apply. { sql: "select t.c in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t", - best: "Apply{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(s.a,t1.a)->StreamAgg}->Projection", + best: "Apply{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([eq(t1.a, test.t.a)]))}(s.a,t1.a)->StreamAgg}->Projection", }, { sql: "select (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t",