diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index 3b49189f067d2..a88f4d957d5a6 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -292,10 +292,7 @@ func (s *propConstSolver) solve(conditions []Expression) []Expression { // PropagateConstant propagate constant values of deterministic predicates in a condition. func PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { - solver := &propConstSolver{} - solver.colMapper = make(map[int64]int) - solver.ctx = ctx - return solver.solve(conditions) + return newPropConstSolver().PropagateConstant(ctx, conditions) } type propOuterJoinConstSolver struct { @@ -551,3 +548,21 @@ func PropConstOverOuterJoin(ctx sessionctx.Context, joinConds, filterConds []Exp solver.ctx = ctx return solver.solve(joinConds, filterConds) } + +// PropagateConstantSolver is a constant propagate solver. +type PropagateConstantSolver interface { + PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression +} + +// newPropConstSolver returns a PropagateConstantSolver. +func newPropConstSolver() PropagateConstantSolver { + solver := &propConstSolver{} + solver.colMapper = make(map[int64]int) + return solver +} + +// PropagateConstant propagate constant values of deterministic predicates in a condition. +func (s *propConstSolver) PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { + s.ctx = ctx + return s.solve(conditions) +} diff --git a/expression/constant_test.go b/expression/constant_test.go index e864c6e3a4fbb..f07047a5f9bbe 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -56,10 +56,12 @@ func newFunction(funcName string, args ...Expression) Expression { func (*testExpressionSuite) TestConstantPropagation(c *C) { defer testleak.AfterTest(c)() tests := []struct { + solver []PropagateConstantSolver conditions []Expression result string }{ { + solver: []PropagateConstantSolver{newPropConstSolver(), pgSolver2{}}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.EQ, newColumn(1), newColumn(2)), @@ -70,6 +72,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "1, eq(test.t.0, 1), eq(test.t.1, 1), eq(test.t.2, 1), eq(test.t.3, 1)", }, { + solver: []PropagateConstantSolver{newPropConstSolver(), pgSolver2{}}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.EQ, newColumn(1), newLonglong(1)), @@ -78,6 +81,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, 1), eq(test.t.1, 1), ne(test.t.2, 2)", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.EQ, newColumn(1), newLonglong(1)), @@ -89,6 +93,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, 1), eq(test.t.1, 1), eq(test.t.2, test.t.3), ge(test.t.2, 2), ge(test.t.3, 2), ne(test.t.2, 4), ne(test.t.2, 5), ne(test.t.3, 4), ne(test.t.3, 5)", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.EQ, newColumn(0), newColumn(2)), @@ -97,6 +102,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, test.t.1), eq(test.t.0, test.t.2), ge(test.t.0, 0), ge(test.t.1, 0), ge(test.t.2, 0)", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.GT, newColumn(0), newLonglong(2)), @@ -107,6 +113,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, test.t.1), gt(2, test.t.0), gt(2, test.t.1), gt(test.t.0, 2), gt(test.t.0, 3), gt(test.t.1, 2), gt(test.t.1, 3), lt(test.t.0, 1), lt(test.t.1, 1)", }, { + solver: []PropagateConstantSolver{newPropConstSolver(), pgSolver2{}}, conditions: []Expression{ newFunction(ast.EQ, newLonglong(1), newColumn(0)), newLonglong(0), @@ -114,6 +121,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "0", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.In, newColumn(0), newLonglong(1), newLonglong(2)), @@ -122,6 +130,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, test.t.1), in(test.t.0, 1, 2), in(test.t.0, 3, 4), in(test.t.1, 1, 2), in(test.t.1, 3, 4)", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.EQ, newColumn(0), newFunction(ast.BitLength, newColumn(2))), @@ -129,6 +138,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, bit_length(cast(test.t.2))), eq(test.t.0, test.t.1), eq(test.t.1, bit_length(cast(test.t.2)))", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.LE, newFunction(ast.Mul, newColumn(0), newColumn(0)), newLonglong(50)), @@ -136,6 +146,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, test.t.1), le(mul(test.t.0, test.t.0), 50), le(mul(test.t.1, test.t.1), 50)", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.LE, newColumn(0), newFunction(ast.Plus, newColumn(1), newLonglong(1))), @@ -143,6 +154,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { result: "eq(test.t.0, test.t.1), le(test.t.0, plus(test.t.0, 1)), le(test.t.0, plus(test.t.1, 1)), le(test.t.1, plus(test.t.1, 1))", }, { + solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), newFunction(ast.LE, newColumn(0), newFunction(ast.Rand)), @@ -151,18 +163,20 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { }, } for _, tt := range tests { - ctx := mock.NewContext() - conds := make([]Expression, 0, len(tt.conditions)) - for _, cd := range tt.conditions { - conds = append(conds, FoldConstant(cd)) + for _, solver := range tt.solver { + ctx := mock.NewContext() + conds := make([]Expression, 0, len(tt.conditions)) + for _, cd := range tt.conditions { + conds = append(conds, FoldConstant(cd)) + } + newConds := solver.PropagateConstant(ctx, conds) + var result []string + for _, v := range newConds { + result = append(result, v.String()) + } + sort.Strings(result) + c.Assert(strings.Join(result, ", "), Equals, tt.result, Commentf("different for expr %s", tt.conditions)) } - newConds := PropagateConstant(ctx, conds) - var result []string - for _, v := range newConds { - result = append(result, v.String()) - } - sort.Strings(result) - c.Assert(strings.Join(result, ", "), Equals, tt.result, Commentf("different for expr %s", tt.conditions)) } } diff --git a/expression/constraint_propagation.go b/expression/constraint_propagation.go new file mode 100644 index 0000000000000..17454df5a35e3 --- /dev/null +++ b/expression/constraint_propagation.go @@ -0,0 +1,166 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "bytes" + + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + log "github.com/sirupsen/logrus" +) + +// exprSet is a Set container for expressions, each expression in it is unique. +// `tombstone` is deleted mark, if tombstone[i] is true, data[i] is invalid. +// `index` use expr.HashCode() as key, to implement the unique property. +type exprSet struct { + data []Expression + tombstone []bool + exists map[string]struct{} + constfalse bool +} + +func (s *exprSet) Append(e Expression) bool { + if _, ok := s.exists[string(e.HashCode(nil))]; ok { + return false + } + + s.data = append(s.data, e) + s.tombstone = append(s.tombstone, false) + s.exists[string(e.HashCode(nil))] = struct{}{} + return true +} + +// Slice returns the valid expressions in the exprSet, this function has side effect. +func (s *exprSet) Slice() []Expression { + if s.constfalse { + return []Expression{&Constant{ + Value: types.NewDatum(false), + RetType: types.NewFieldType(mysql.TypeTiny), + }} + } + + idx := 0 + for i := 0; i < len(s.data); i++ { + if !s.tombstone[i] { + s.data[idx] = s.data[i] + idx++ + } + } + return s.data[:idx] +} + +func (s *exprSet) SetConstFalse() { + s.constfalse = true +} + +func newExprSet(conditions []Expression) *exprSet { + var exprs exprSet + exprs.data = make([]Expression, 0, len(conditions)) + exprs.tombstone = make([]bool, 0, len(conditions)) + exprs.exists = make(map[string]struct{}, len(conditions)) + for _, v := range conditions { + exprs.Append(v) + } + return &exprs +} + +type pgSolver2 struct{} + +// PropagateConstant propagate constant values of deterministic predicates in a condition. +func (s pgSolver2) PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { + exprs := newExprSet(conditions) + s.fixPoint(ctx, exprs) + return exprs.Slice() +} + +// fixPoint is the core of the constant propagation algorithm. +// It will iterate the expression set over and over again, pick two expressions, +// apply one to another. +// If new conditions can be infered, they will be append into the expression set. +// Until no more conditions can be infered from the set, the algorithm finish. +func (s pgSolver2) fixPoint(ctx sessionctx.Context, exprs *exprSet) { + for { + saveLen := len(exprs.data) + iterOnce(ctx, exprs) + if saveLen == len(exprs.data) { + break + } + } + return +} + +// iterOnce picks two expressions from the set, try to propagate new conditions from them. +func iterOnce(ctx sessionctx.Context, exprs *exprSet) { + for i := 0; i < len(exprs.data); i++ { + if exprs.tombstone[i] { + continue + } + for j := 0; j < len(exprs.data); j++ { + if exprs.tombstone[j] { + continue + } + if i == j { + continue + } + solve(ctx, i, j, exprs) + } + } +} + +// solve uses exprs[i] exprs[j] to propagate new conditions. +func solve(ctx sessionctx.Context, i, j int, exprs *exprSet) { + for _, rule := range rules { + rule(ctx, i, j, exprs) + } +} + +type constantPropagateRule func(ctx sessionctx.Context, i, j int, exprs *exprSet) + +var rules = []constantPropagateRule{ + ruleConstantFalse, + ruleColumnEQConst, +} + +// ruleConstantFalse propagates from CNF condition that false plus anything returns false. +// false, a = 1, b = c ... => false +func ruleConstantFalse(ctx sessionctx.Context, i, j int, exprs *exprSet) { + cond := exprs.data[i] + if cons, ok := cond.(*Constant); ok { + v, isNull, err := cons.EvalInt(ctx, chunk.Row{}) + if err != nil { + log.Error(err) + return + } + if !isNull && v == 0 { + exprs.SetConstFalse() + } + } +} + +// ruleColumnEQConst propagates the "column = const" condition. +// "a = 3, b = a, c = a, d = b" => "a = 3, b = 3, c = 3, d = 3" +func ruleColumnEQConst(ctx sessionctx.Context, i, j int, exprs *exprSet) { + col, cons := validEqualCond(exprs.data[i]) + if col != nil { + expr := ColumnSubstitute(exprs.data[j], NewSchema(col), []Expression{cons}) + stmtctx := ctx.GetSessionVars().StmtCtx + if bytes.Compare(expr.HashCode(stmtctx), exprs.data[j].HashCode(stmtctx)) != 0 { + exprs.Append(expr) + exprs.tombstone[j] = true + } + } +}