Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix wrong behavior for = all() (#52801) #53256

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/explaintest/r/select.result
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ explain format = 'brief' select a = all (select a from t t2) from t t1;
id estRows task access object operator info
Projection 10000.00 root or(and(and(le(Column#11, 1), eq(test.t.a, Column#10)), if(ne(Column#12, 0), <nil>, 1)), or(eq(Column#13, 0), if(isnull(test.t.a), <nil>, 0)))->Column#14
└─HashJoin 10000.00 root CARTESIAN inner join
├─StreamAgg(Build) 1.00 root funcs:firstrow(Column#16)->Column#10, funcs:count(distinct Column#17)->Column#11, funcs:sum(Column#18)->Column#12, funcs:count(1)->Column#13
├─StreamAgg(Build) 1.00 root funcs:max(Column#16)->Column#10, funcs:count(distinct Column#17)->Column#11, funcs:sum(Column#18)->Column#12, funcs:count(1)->Column#13
│ └─Projection 10000.00 root test.t.a, test.t.a, cast(isnull(test.t.a), decimal(20,0) BINARY)->Column#18
│ └─TableReader 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
Expand Down
22 changes: 22 additions & 0 deletions planner/core/casetest/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,25 @@ func TestJSONPlanInExplain(t *testing.T) {
}
}
}

func TestHandleEQAll(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("CREATE TABLE t1 (c1 int, c2 int, UNIQUE i1 (c1, c2));")
tk.MustExec("INSERT INTO t1 VALUES (7, null),(5,1);")
tk.MustQuery("SELECT c1 FROM t1 WHERE ('m' = ALL (SELECT /*+ IGNORE_INDEX(t1, i1) */ c2 FROM t1)) IS NOT UNKNOWN; ").Check(testkit.Rows("5", "7"))
tk.MustQuery("SELECT c1 FROM t1 WHERE ('m' = ALL (SELECT /*+ use_INDEX(t1, i1) */ c2 FROM t1)) IS NOT UNKNOWN; ").Check(testkit.Rows("5", "7"))
tk.MustQuery("select (null = ALL (SELECT /*+ NO_INDEX() */ c2 FROM t1)) IS NOT UNKNOWN").Check(testkit.Rows("0"))
tk.MustExec("CREATE TABLE t2 (c1 int, c2 int, UNIQUE i1 (c1, c2));")
tk.MustExec("INSERT INTO t2 VALUES (7, null),(5,null);")
tk.MustQuery("select (null = ALL (SELECT /*+ NO_INDEX() */ c2 FROM t2)) IS NOT UNKNOWN").Check(testkit.Rows("0"))
tk.MustQuery("SELECT c1 FROM t2 WHERE ('m' = ALL (SELECT /*+ IGNORE_INDEX(t2, i1) */ c2 FROM t2)) IS NOT UNKNOWN; ").Check(testkit.Rows())
tk.MustQuery("SELECT c1 FROM t2 WHERE ('m' = ALL (SELECT /*+ use_INDEX(t2, i1) */ c2 FROM t2)) IS NOT UNKNOWN; ").Check(testkit.Rows())
tk.MustExec("truncate table t2")
tk.MustExec("INSERT INTO t2 VALUES (7, null),(7,null);")
tk.MustQuery("select c1 from t2 where (c1 = all (select /*+ IGNORE_INDEX(t2, i1) */ c1 from t2))").Check(testkit.Rows("7", "7"))
tk.MustQuery("select c1 from t2 where (c1 = all (select /*+ use_INDEX(t2, i1) */ c1 from t2))").Check(testkit.Rows("7", "7"))
tk.MustQuery("select c2 from t2 where (c2 = all (select /*+ IGNORE_INDEX(t2, i1) */ c2 from t2))").Check(testkit.Rows())
tk.MustQuery("select c2 from t2 where (c2 = all (select /*+ use_INDEX(t2, i1) */ c2 from t2))").Check(testkit.Rows())
}
26 changes: 9 additions & 17 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,9 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
// handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to
// t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]).
func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np LogicalPlan, markNoDecorrelate bool) {
firstRowFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false)
// If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id == s.id.
// So use function max to filter NULL.
maxFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncMax, []expression.Expression{rexpr}, false)
if err != nil {
er.err = err
return
Expand All @@ -789,37 +791,27 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
return
}
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc},
}.Init(er.sctx, er.b.getSelectOffset())
if hint := er.b.TableHints(); hint != nil {
plan4Agg.aggHints = hint.aggHints
}
plan4Agg.SetChildren(np)
plan4Agg.names = append(plan4Agg.names, types.EmptyName)

// Currently, firstrow agg function is treated like the exact representation of aggregate group key,
// so the data type is the same with group key, even if the group key is not null.
// However, the return type of firstrow should be nullable, we clear the null flag here instead of
// during invoking NewAggFuncDesc, in order to keep compatibility with the existing presumption
// that the return type firstrow does not change nullability, whatsoever.
// Cloning it because the return type is the same object with argument's data type.
newRetTp := firstRowFunc.RetTp.Clone()
newRetTp.DelFlag(mysql.NotNullFlag)
firstRowFunc.RetTp = newRetTp

firstRowResultCol := &expression.Column{
maxResultCol := &expression.Column{
UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
RetType: firstRowFunc.RetTp,
RetType: maxFunc.RetTp,
}
firstRowResultCol.SetCoercibility(rexpr.Coercibility())
maxResultCol.SetCoercibility(rexpr.Coercibility())
plan4Agg.names = append(plan4Agg.names, types.EmptyName)
count := &expression.Column{
UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
RetType: countFunc.RetTp,
}
plan4Agg.SetSchema(expression.NewSchema(firstRowResultCol, count))
plan4Agg.SetSchema(expression.NewSchema(maxResultCol, count))
leFunc := expression.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.NewOne())
eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol)
eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, maxResultCol)
cond := expression.ComposeCNFCondition(er.sctx, leFunc, eqCond)
er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true, markNoDecorrelate)
}
Expand Down