Skip to content

Commit

Permalink
expression: propagate more filters in PropagateConstant (#7276)
Browse files Browse the repository at this point in the history
  • Loading branch information
bb7133 authored and zz-jason committed Aug 30, 2018
1 parent b4fdaf3 commit 7f37bad
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 73 deletions.
11 changes: 6 additions & 5 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,12 @@ Projection_11 10000.00 root 9_aux_0
├─TableReader_15 10000.00 root data:TableScan_14
│ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo
└─StreamAgg_20 1.00 root funcs:count(1)
└─IndexJoin_43 10000.00 root inner join, inner:TableReader_42, outer key:s.a, inner key:t1.a
├─TableReader_51 1.00 root data:TableScan_50
│ └─TableScan_50 1.00 cop table:s, range: decided by [eq(s.a, test.t.a)], keep order:false, stats:pseudo
└─TableReader_42 10.00 root data:TableScan_41
└─TableScan_41 10.00 cop table:t1, range: decided by [s.a], keep order:false, stats:pseudo
└─IndexJoin_44 10000.00 root inner join, inner:TableReader_43, outer key:s.a, inner key:t1.a
├─TableReader_52 1.00 root data:TableScan_51
│ └─TableScan_51 1.00 cop table:s, range: decided by [eq(s.a, test.t.a)], keep order:false, stats:pseudo
└─TableReader_43 8000.00 root data:Selection_42
└─Selection_42 8000.00 cop eq(t1.a, test.t.a)
└─TableScan_41 10.00 cop table:t1, range: decided by [s.a], keep order:false, stats:pseudo
explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t;
id count task operator info
Projection_11 10000.00 root 9_aux_0
Expand Down
182 changes: 120 additions & 62 deletions expression/constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,6 @@ import (
// MaxPropagateColsCnt means the max number of columns that can participate propagation.
var MaxPropagateColsCnt = 100

var eqFuncNameMap = map[string]bool{
ast.EQ: true,
}

// inEqFuncNameMap stores all the in-equal operators that can be propagated.
var inEqFuncNameMap = map[string]bool{
ast.LT: true,
ast.GT: true,
ast.LE: true,
ast.GE: true,
ast.NE: true,
}

type multiEqualSet struct {
parent []int
}
Expand Down Expand Up @@ -72,51 +59,12 @@ type propagateConstantSolver struct {
ctx sessionctx.Context
}

// propagateInEQ propagates all in-equal conditions.
// e.g. For expression a = b and b = c and c = d and c < 1 , we can get extra a < 1 and b < 1 and d < 1.
// We maintain a unionSet representing the equivalent for every two columns.
func (s *propagateConstantSolver) propagateInEQ() {
s.unionSet = &multiEqualSet{}
s.unionSet.init(len(s.columns))
for i := range s.conditions {
if fun, ok := s.conditions[i].(*ScalarFunction); ok && fun.FuncName.L == ast.EQ {
lCol, lOk := fun.GetArgs()[0].(*Column)
rCol, rOk := fun.GetArgs()[1].(*Column)
if lOk && rOk {
lID := s.getColID(lCol)
rID := s.getColID(rCol)
s.unionSet.addRelation(lID, rID)
}
}
}
condsLen := len(s.conditions)
for i := 0; i < condsLen; i++ {
cond := s.conditions[i]
col, con := s.validPropagateCond(cond, inEqFuncNameMap)
if col == nil {
continue
}
id := s.getColID(col)
for j := range s.columns {
if id != j && s.unionSet.findRoot(id) == s.unionSet.findRoot(j) {
funName := cond.(*ScalarFunction).FuncName.L
var newExpr Expression
if _, ok := cond.(*ScalarFunction).GetArgs()[0].(*Column); ok {
newExpr = NewFunctionInternal(s.ctx, funName, cond.GetType(), s.columns[j], con)
} else {
newExpr = NewFunctionInternal(s.ctx, funName, cond.GetType(), con, s.columns[j])
}
s.conditions = append(s.conditions, newExpr)
}
}
}
}

// propagateEQ propagates equal expression multiple times. An example runs as following:
// propagateConstantEQ propagates expressions like 'column = constant' by substituting the constant for column, the
// procedure repeats multiple times. An example runs as following:
// a = d & b * 2 = c & c = d + 2 & b = 1 & a = 4, we pick eq cond b = 1 and a = 4
// d = 4 & 2 = c & c = d + 2 & b = 1 & a = 4, we propagate b = 1 and a = 4 and pick eq cond c = 2 and d = 4
// d = 4 & 2 = c & false & b = 1 & a = 4, we propagate c = 2 and d = 4, and do constant folding: c = d + 2 will be folded as false.
func (s *propagateConstantSolver) propagateEQ() {
func (s *propagateConstantSolver) propagateConstantEQ() {
s.eqList = make([]*Constant, len(s.columns))
visited := make([]bool, len(s.conditions))
for i := 0; i < MaxPropagateColsCnt; i++ {
Expand All @@ -138,10 +86,72 @@ func (s *propagateConstantSolver) propagateEQ() {
}
}

// validPropagateCond checks if the cond is an expression like [column op constant] and op is in the funNameMap.
func (s *propagateConstantSolver) validPropagateCond(cond Expression, funNameMap map[string]bool) (*Column, *Constant) {
// propagateColumnEQ propagates expressions like 'column A = column B' by adding extra filters
// 'expression(..., column B, ...)' propagated from 'expression(..., column A, ...)' as long as:
//
// 1. The expression is deterministic
// 2. The expression doesn't have any side effect
//
// e.g. For expression a = b and b = c and c = d and c < 1 , we can get extra a < 1 and b < 1 and d < 1.
// However, for a = b and a < rand(), we cannot propagate a < rand() to b < rand() because rand() is non-deterministic
//
// This propagation may bring redundancies that we need to resolve later, for example:
// for a = b and a < 3 and b < 3, we get new a < 3 and b < 3, which are redundant
// for a = b and a < 3 and 3 > b, we get new b < 3 and 3 > a, which are redundant
// for a = b and a < 3 and b < 4, we get new a < 4 and b < 3 but should expect a < 3 and b < 3
// for a = b and a in (3) and b in (4), we get b in (3) and a in (4) but should expect 'false'
//
// TODO: remove redundancies later
//
// We maintain a unionSet representing the equivalent for every two columns.
func (s *propagateConstantSolver) propagateColumnEQ() {
visited := make([]bool, len(s.conditions))
s.unionSet = &multiEqualSet{}
s.unionSet.init(len(s.columns))
for i := range s.conditions {
if fun, ok := s.conditions[i].(*ScalarFunction); ok && fun.FuncName.L == ast.EQ {
lCol, lOk := fun.GetArgs()[0].(*Column)
rCol, rOk := fun.GetArgs()[1].(*Column)
if lOk && rOk {
lID := s.getColID(lCol)
rID := s.getColID(rCol)
s.unionSet.addRelation(lID, rID)
visited[i] = true
}
}
}

condsLen := len(s.conditions)
for i, coli := range s.columns {
for j := i + 1; j < len(s.columns); j++ {
// unionSet doesn't have iterate(), we use a two layer loop to iterate col_i = col_j relation
if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) {
continue
}
colj := s.columns[j]
for k := 0; k < condsLen; k++ {
if visited[k] {
// cond_k has been used to retrieve equality relation
continue
}
cond := s.conditions[k]
replaced, _, newExpr := s.tryToReplaceCond(coli, colj, cond)
if replaced {
s.conditions = append(s.conditions, newExpr)
}
replaced, _, newExpr = s.tryToReplaceCond(colj, coli, cond)
if replaced {
s.conditions = append(s.conditions, newExpr)
}
}
}
}
}

// validEqualCond checks if the cond is an expression like [column eq constant].
func (s *propagateConstantSolver) validEqualCond(cond Expression) (*Column, *Constant) {
if eq, ok := cond.(*ScalarFunction); ok {
if _, ok := funNameMap[eq.FuncName.L]; !ok {
if eq.FuncName.L != ast.EQ {
return nil, nil
}
if col, colOk := eq.GetArgs()[0].(*Column); colOk {
Expand All @@ -158,6 +168,54 @@ func (s *propagateConstantSolver) validPropagateCond(cond Expression, funNameMap
return nil, nil
}

// tryToReplaceCond aims to replace all occurrences of column 'src' and try to replace it with 'tgt' in 'cond'
// It returns
// bool: if a replacement happened
// bool: if 'cond' contains non-deterministic expression
// Expression: the replaced expression, or original 'cond' if the replacement didn't happen
//
// For example:
// for 'a, b, a < 3', it returns 'true, false, b < 3'
// for 'a, b, sin(a) + cos(a) = 5', it returns 'true, false, returns sin(b) + cos(b) = 5'
// for 'a, b, cast(a) < rand()', it returns 'false, true, cast(a) < rand()'
func (s *propagateConstantSolver) tryToReplaceCond(src *Column, tgt *Column, cond Expression) (bool, bool, Expression) {
sf, ok := cond.(*ScalarFunction)
if !ok {
return false, false, cond
}
replaced := false
var args []Expression
if _, ok := unFoldableFunctions[sf.FuncName.L]; ok {
return false, true, cond
}
for idx, expr := range sf.GetArgs() {
if src.Equal(nil, expr) {
replaced = true
if args == nil {
args = make([]Expression, len(sf.GetArgs()))
copy(args, sf.GetArgs())
}
args[idx] = tgt
} else {
subReplaced, isNonDeterminisitic, subExpr := s.tryToReplaceCond(src, tgt, expr)
if isNonDeterminisitic {
return false, true, cond
} else if subReplaced {
replaced = true
if args == nil {
args = make([]Expression, len(sf.GetArgs()))
copy(args, sf.GetArgs())
}
args[idx] = subExpr
}
}
}
if replaced {
return true, false, NewFunctionInternal(s.ctx, sf.FuncName.L, sf.GetType(), args...)
}
return false, false, cond
}

func (s *propagateConstantSolver) setConds2ConstFalse() {
s.conditions = []Expression{&Constant{
Value: types.NewDatum(false),
Expand All @@ -172,7 +230,7 @@ func (s *propagateConstantSolver) pickNewEQConds(visited []bool) (retMapper map[
if visited[i] {
continue
}
col, con := s.validPropagateCond(cond, eqFuncNameMap)
col, con := s.validEqualCond(cond)
// Then we check if this CNF item is a false constant. If so, we will set the whole condition to false.
var ok bool
if col == nil {
Expand Down Expand Up @@ -227,8 +285,8 @@ func (s *propagateConstantSolver) solve(conditions []Expression) []Expression {
log.Warnf("[const_propagation]Too many columns in a single CNF: the column count is %d, the max count is %d.", len(s.columns), MaxPropagateColsCnt)
return conditions
}
s.propagateEQ()
s.propagateInEQ()
s.propagateConstantEQ()
s.propagateColumnEQ()
for i, cond := range s.conditions {
if dnf, ok := cond.(*ScalarFunction); ok && dnf.FuncName.L == ast.LogicOr {
dnfItems := SplitDNFItems(cond)
Expand All @@ -255,7 +313,7 @@ func (s *propagateConstantSolver) insertCol(col *Column) {
}
}

// PropagateConstant propagate constant values of equality predicates and inequality predicates in a condition.
// PropagateConstant propagate constant values of deterministic predicates in a condition.
func PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression {
solver := &propagateConstantSolver{
colMapper: make(map[string]int),
Expand Down
36 changes: 36 additions & 0 deletions expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,42 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) {
},
result: "0",
},
{
conditions: []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
newFunction(ast.In, newColumn(0), newLonglong(1), newLonglong(2)),
newFunction(ast.In, newColumn(1), newLonglong(3), newLonglong(4)),
},
result: "eq(test.t.0, test.t.1), in(test.t.0, 1, 2), in(test.t.0, 3, 4), in(test.t.1, 1, 2), in(test.t.1, 3, 4)",
},
{
conditions: []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
newFunction(ast.EQ, newColumn(0), newFunction(ast.BitLength, newColumn(2))),
},
result: "eq(test.t.0, bit_length(cast(test.t.2))), eq(test.t.0, test.t.1), eq(test.t.1, bit_length(cast(test.t.2)))",
},
{
conditions: []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
newFunction(ast.LE, newFunction(ast.Mul, newColumn(0), newColumn(0)), newLonglong(50)),
},
result: "eq(test.t.0, test.t.1), le(mul(test.t.0, test.t.0), 50), le(mul(test.t.1, test.t.1), 50)",
},
{
conditions: []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
newFunction(ast.LE, newColumn(0), newFunction(ast.Plus, newColumn(1), newLonglong(1))),
},
result: "eq(test.t.0, test.t.1), le(test.t.0, plus(test.t.0, 1)), le(test.t.0, plus(test.t.1, 1)), le(test.t.1, plus(test.t.1, 1))",
},
{
conditions: []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
newFunction(ast.LE, newColumn(0), newFunction(ast.Rand)),
},
result: "eq(test.t.0, test.t.1), le(cast(test.t.0), rand())",
},
}
for _, tt := range tests {
ctx := mock.NewContext()
Expand Down
9 changes: 5 additions & 4 deletions plan/cbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var _ = Suite(&testAnalyzeSuite{})
type testAnalyzeSuite struct {
}

// CBOWithoutAnalyze tests the plan with stats that only have count info.
// TestCBOWithoutAnalyze tests the plan with stats that only have count info.
func (s *testAnalyzeSuite) TestCBOWithoutAnalyze(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
Expand Down Expand Up @@ -633,12 +633,13 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) {
" ├─TableReader_15 10.00 root data:TableScan_14",
" │ └─TableScan_14 10.00 cop table:t, range:[-inf,+inf], keep order:false",
" └─StreamAgg_20 1.00 root funcs:count(1)",
" └─HashRightJoin_22 1.00 root inner join, inner:TableReader_25, equal:[eq(s.a, t1.a)]",
" └─HashLeftJoin_21 1.00 root inner join, inner:TableReader_28, equal:[eq(s.a, t1.a)]",
" ├─TableReader_25 1.00 root data:Selection_24",
" │ └─Selection_24 1.00 cop eq(s.a, test.t.a)",
" │ └─TableScan_23 10.00 cop table:s, range:[-inf,+inf], keep order:false",
" └─TableReader_27 10.00 root data:TableScan_26",
" └─TableScan_26 10.00 cop table:t1, range:[-inf,+inf], keep order:false",
" └─TableReader_28 1.00 root data:Selection_27",
" └─Selection_27 1.00 cop eq(t1.a, test.t.a)",
" └─TableScan_26 10.00 cop table:t1, range:[-inf,+inf], keep order:false",
))
tk.MustQuery("explain select (select concat(t1.a, \",\", t1.b) from t t1 where t1.a=t.a and t1.c=t.c) from t").
Check(testkit.Rows(
Expand Down
2 changes: 1 addition & 1 deletion plan/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) {
},
{
sql: "select * from t t1, t t2 where t1.a = t2.b and t2.b > 0 and t1.a = t1.c and t1.d like 'abc' and t2.d = t1.d",
best: "Join{DataScan(t2)->DataScan(t1)->Sel([like(cast(t1.d), abc, 92)])}(t2.b,t1.a)(t2.d,t1.d)->Projection",
best: "Join{DataScan(t2)->Sel([like(cast(t2.d), abc, 92)])->DataScan(t1)->Sel([like(cast(t1.d), abc, 92)])}(t2.b,t1.a)(t2.d,t1.d)->Projection",
},
{
sql: "select * from t ta join t tb on ta.d = tb.d and ta.d > 1 where tb.a = 0",
Expand Down
2 changes: 1 addition & 1 deletion plan/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderSubquery(c *C) {
// Test Apply.
{
sql: "select t.c in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t",
best: "Apply{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(s.a,t1.a)->StreamAgg}->Projection",
best: "Apply{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([eq(t1.a, test.t.a)]))}(s.a,t1.a)->StreamAgg}->Projection",
},
{
sql: "select (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t",
Expand Down

0 comments on commit 7f37bad

Please sign in to comment.