Skip to content

Commit

Permalink
some updates on XuHuaiyu's comments:
Browse files Browse the repository at this point in the history
1. add propagation for EQ condition like 'a eq func'
2. refined some method/variable names
3. added more comments
4. added 1 more unit test

Signed-off-by: bb7133 <[email protected]>
  • Loading branch information
bb7133 committed Aug 29, 2018
1 parent 4f153e4 commit fd15734
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 57 deletions.
11 changes: 6 additions & 5 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,12 @@ Projection_11 10000.00 root 9_aux_0
├─TableReader_15 10000.00 root data:TableScan_14
│ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo
└─StreamAgg_20 1.00 root funcs:count(1)
└─IndexJoin_43 10000.00 root inner join, inner:TableReader_42, outer key:s.a, inner key:t1.a
├─TableReader_51 1.00 root data:TableScan_50
│ └─TableScan_50 1.00 cop table:s, range: decided by [eq(s.a, test.t.a)], keep order:false, stats:pseudo
└─TableReader_42 10.00 root data:TableScan_41
└─TableScan_41 10.00 cop table:t1, range: decided by [s.a], keep order:false, stats:pseudo
└─IndexJoin_44 10000.00 root inner join, inner:TableReader_43, outer key:s.a, inner key:t1.a
├─TableReader_52 1.00 root data:TableScan_51
│ └─TableScan_51 1.00 cop table:s, range: decided by [eq(s.a, test.t.a)], keep order:false, stats:pseudo
└─TableReader_43 8000.00 root data:Selection_42
└─Selection_42 8000.00 cop eq(t1.a, test.t.a)
└─TableScan_41 10.00 cop table:t1, range: decided by [s.a], keep order:false, stats:pseudo
explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t;
id count task operator info
Projection_11 10000.00 root 9_aux_0
Expand Down
119 changes: 72 additions & 47 deletions expression/constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ type propagateConstantSolver struct {
ctx sessionctx.Context
}

// propagateEQ propagates equal expression multiple times. An example runs as following:
// propagateConstantEQ propagates expressions like 'column = constant' by substituting the constant for the column;
// the procedure repeats multiple times. An example runs as follows:
// a = d & b * 2 = c & c = d + 2 & b = 1 & a = 4, we pick eq cond b = 1 and a = 4
// d = 4 & 2 = c & c = d + 2 & b = 1 & a = 4, we propagate b = 1 and a = 4 and pick eq cond c = 2 and d = 4
// d = 4 & 2 = c & false & b = 1 & a = 4, we propagate c = 2 and d = 4, and do constant folding: c = d + 2 will be folded as false.
func (s *propagateConstantSolver) propagateEQ() {
func (s *propagateConstantSolver) propagateConstantEQ() {
s.eqList = make([]*Constant, len(s.columns))
visited := make([]bool, len(s.conditions))
for i := 0; i < MaxPropagateColsCnt; i++ {
Expand All @@ -85,11 +86,26 @@ func (s *propagateConstantSolver) propagateEQ() {
}
}

// propagateOthers propagates all deterministic conditions.
// propagateColumnEQ propagates expressions like 'column A = column B' by adding extra filters
// 'expression(..., column B, ...)' propagated from 'expression(..., column A, ...)' as long as:
//
// 1. The expression is deterministic
// 2. The expression doesn't have any side effect
//
// e.g. for the expression a = b and b = c and c = d and c < 1, we can get the extra filters a < 1, b < 1 and d < 1.
// However, for a = b and a < rand(), we cannot propagate a < rand() to b < rand() because rand() is non-deterministic.
//
// This propagation may bring redundancies that we need to resolve later, for example:
// for a = b and a < 3 and b < 3, we get new a < 3 and b < 3, which are redundant
// for a = b and a < 3 and 3 > b, we get new b < 3 and 3 > a, which are redundant
// for a = b and a < 3 and b < 4, we get new a < 4 and b < 3 but should expect a < 3 and b < 3
// for a = b and a in (3) and b in (4), we get b in (3) and a in (4) but should expect 'false'
//
// TODO: remove redundancies later
//
// We maintain a unionSet representing the equivalence relation between every pair of columns.
func (s *propagateConstantSolver) propagateOthers() {
func (s *propagateConstantSolver) propagateColumnEQ() {
visited := make([]bool, len(s.conditions))
s.unionSet = &multiEqualSet{}
s.unionSet.init(len(s.columns))
for i := range s.conditions {
Expand All @@ -100,24 +116,30 @@ func (s *propagateConstantSolver) propagateOthers() {
lID := s.getColID(lCol)
rID := s.getColID(rCol)
s.unionSet.addRelation(lID, rID)
visited[i] = true
}
}
}

condsLen := len(s.conditions)
for j, colj := range s.columns {
for k := j + 1; k < len(s.columns); k++ {
if s.unionSet.findRoot(j) != s.unionSet.findRoot(k) {
for i, coli := range s.columns {
for j := i + 1; j < len(s.columns); j++ {
// unionSet doesn't have iterate(), we use a two layer loop to iterate col_i = col_j relation
if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) {
continue
}
colk := s.columns[k]
for i := 0; i < condsLen; i++ {
cond := s.conditions[i]
replaced, _, newExpr := s.tryToReplaceCond(colj, colk, cond)
colj := s.columns[j]
for k := 0; k < condsLen; k++ {
if visited[k] {
// cond_k has been used to retrieve equality relation
continue
}
cond := s.conditions[k]
replaced, _, newExpr := s.tryToReplaceCond(coli, colj, cond)
if replaced {
s.conditions = append(s.conditions, newExpr)
}
replaced, _, newExpr = s.tryToReplaceCond(colk, colj, cond)
replaced, _, newExpr = s.tryToReplaceCond(colj, coli, cond)
if replaced {
s.conditions = append(s.conditions, newExpr)
}
Expand All @@ -126,8 +148,8 @@ func (s *propagateConstantSolver) propagateOthers() {
}
}

// validPropagateCond checks if the cond is an expression like [column op constant] and op is in the funNameMap.
func (s *propagateConstantSolver) validPropagateCond(cond Expression) (*Column, *Constant) {
// validEqualCond checks if the cond is an expression like [column eq constant].
func (s *propagateConstantSolver) validEqualCond(cond Expression) (*Column, *Constant) {
if eq, ok := cond.(*ScalarFunction); ok {
if eq.FuncName.L != ast.EQ {
return nil, nil
Expand All @@ -146,47 +168,50 @@ func (s *propagateConstantSolver) validPropagateCond(cond Expression) (*Column,
return nil, nil
}

// tryToReplaceCond aims to replace all occurances of column 'src' and try to replace it with 'tgt' in 'cond'
// tryToReplaceCond aims to replace all occurrences of column 'src' and try to replace it with 'tgt' in 'cond'
// It returns
// bool: if a replacement happened
// bool: if 'cond' contains a non-deterministic expression
// Expression: the replaced expression, or original 'cond' if the replacement is not happened
// Expression: the replaced expression, or original 'cond' if the replacement didn't happen
//
// For example:
// for 'a, b, a < 3', it returns 'true, false, b < 3'
// for 'a, b, sin(a) + cos(a) = 5', it returns 'true, false, sin(b) + cos(b) = 5'
// for 'a, b, cast(a) < rand()', it returns 'false, true, cast(a) < rand()'
func (s *propagateConstantSolver) tryToReplaceCond(src *Column, tgt *Column, cond Expression) (bool, bool, Expression) {
if sf, ok := cond.(*ScalarFunction); ok {
replaced := false
var args []Expression
if _, ok := unFoldableFunctions[sf.FuncName.L]; ok {
return false, true, cond
}
// Equality is handled in propagateEQ already
if sf.FuncName.L == ast.EQ {
return false, false, cond
}
for idx, expr := range sf.GetArgs() {
if src.Equal(nil, expr) {
sf, ok := cond.(*ScalarFunction)
if !ok {
return false, false, cond
}
replaced := false
var args []Expression
if _, ok := unFoldableFunctions[sf.FuncName.L]; ok {
return false, true, cond
}
for idx, expr := range sf.GetArgs() {
if src.Equal(nil, expr) {
replaced = true
if args == nil {
args = make([]Expression, len(sf.GetArgs()))
copy(args, sf.GetArgs())
}
args[idx] = tgt
} else {
subReplaced, isNonDeterminisitic, subExpr := s.tryToReplaceCond(src, tgt, expr)
if isNonDeterminisitic {
return false, true, cond
} else if subReplaced {
replaced = true
if args == nil {
args = make([]Expression, len(sf.GetArgs()))
copy(args, sf.GetArgs())
}
args[idx] = tgt
} else {
subReplaced, isNonDeterminisitic, subExpr := s.tryToReplaceCond(src, tgt, expr)
if isNonDeterminisitic {
return false, true, cond
} else if subReplaced {
replaced = true
if args == nil {
args = make([]Expression, len(sf.GetArgs()))
copy(args, sf.GetArgs())
}
args[idx] = subExpr
}
args[idx] = subExpr
}
}
if replaced {
return true, false, NewFunctionInternal(s.ctx, sf.FuncName.L, sf.GetType(), args...)
}
}
if replaced {
return true, false, NewFunctionInternal(s.ctx, sf.FuncName.L, sf.GetType(), args...)
}
return false, false, cond
}
Expand All @@ -205,7 +230,7 @@ func (s *propagateConstantSolver) pickNewEQConds(visited []bool) (retMapper map[
if visited[i] {
continue
}
col, con := s.validPropagateCond(cond)
col, con := s.validEqualCond(cond)
// Then we check if this CNF item is a false constant. If so, we will set the whole condition to false.
var ok bool
if col == nil {
Expand Down Expand Up @@ -260,8 +285,8 @@ func (s *propagateConstantSolver) solve(conditions []Expression) []Expression {
log.Warnf("[const_propagation]Too many columns in a single CNF: the column count is %d, the max count is %d.", len(s.columns), MaxPropagateColsCnt)
return conditions
}
s.propagateEQ()
s.propagateOthers()
s.propagateConstantEQ()
s.propagateColumnEQ()
for i, cond := range s.conditions {
if dnf, ok := cond.(*ScalarFunction); ok && dnf.FuncName.L == ast.LogicOr {
dnfItems := SplitDNFItems(cond)
Expand Down
7 changes: 7 additions & 0 deletions expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) {
},
result: "eq(test.t.0, test.t.1), in(test.t.0, 1, 2), in(test.t.0, 3, 4), in(test.t.1, 1, 2), in(test.t.1, 3, 4)",
},
{
conditions: []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
newFunction(ast.EQ, newColumn(0), newFunction(ast.BitLength, newColumn(2))),
},
result: "eq(test.t.0, bit_length(cast(test.t.2))), eq(test.t.0, test.t.1), eq(test.t.1, bit_length(cast(test.t.2)))",
},
{
conditions: []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
Expand Down
9 changes: 5 additions & 4 deletions plan/cbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var _ = Suite(&testAnalyzeSuite{})
type testAnalyzeSuite struct {
}

// CBOWithoutAnalyze tests the plan with stats that only have count info.
// TestCBOWithoutAnalyze tests the plan with stats that only have count info.
func (s *testAnalyzeSuite) TestCBOWithoutAnalyze(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
Expand Down Expand Up @@ -633,12 +633,13 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) {
" ├─TableReader_15 10.00 root data:TableScan_14",
" │ └─TableScan_14 10.00 cop table:t, range:[-inf,+inf], keep order:false",
" └─StreamAgg_20 1.00 root funcs:count(1)",
" └─HashRightJoin_22 1.00 root inner join, inner:TableReader_25, equal:[eq(s.a, t1.a)]",
" └─HashLeftJoin_21 1.00 root inner join, inner:TableReader_28, equal:[eq(s.a, t1.a)]",
" ├─TableReader_25 1.00 root data:Selection_24",
" │ └─Selection_24 1.00 cop eq(s.a, test.t.a)",
" │ └─TableScan_23 10.00 cop table:s, range:[-inf,+inf], keep order:false",
" └─TableReader_27 10.00 root data:TableScan_26",
" └─TableScan_26 10.00 cop table:t1, range:[-inf,+inf], keep order:false",
" └─TableReader_28 1.00 root data:Selection_27",
" └─Selection_27 1.00 cop eq(t1.a, test.t.a)",
" └─TableScan_26 10.00 cop table:t1, range:[-inf,+inf], keep order:false",
))
tk.MustQuery("explain select (select concat(t1.a, \",\", t1.b) from t t1 where t1.a=t.a and t1.c=t.c) from t").
Check(testkit.Rows(
Expand Down
2 changes: 1 addition & 1 deletion plan/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderSubquery(c *C) {
// Test Apply.
{
sql: "select t.c in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t",
best: "Apply{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(s.a,t1.a)->StreamAgg}->Projection",
best: "Apply{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([eq(t1.a, test.t.a)]))}(s.a,t1.a)->StreamAgg}->Projection",
},
{
sql: "select (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t",
Expand Down

0 comments on commit fd15734

Please sign in to comment.