Skip to content

Commit

Permalink
planner: fix wrong result when pushing Agg down through Union in MPP …
Browse files Browse the repository at this point in the history
…plans (#46310) (#46515)

close #45850
  • Loading branch information
ti-chi-bot authored Oct 31, 2023
1 parent f2faaa4 commit 789564b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 10 deletions.
42 changes: 42 additions & 0 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,48 @@ func TestAggPushDownCountStar(t *testing.T) {
tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5"))
}

func TestAggPushDownUnionAndMPP(t *testing.T) {
store, clean := testkit.CreateMockStore(t, withMockTiFlash(2))
defer clean()
tk := testkit.NewTestKit(t, store)

tk.MustExec("use test")
tk.MustExec("create table t (a int, b int)")
tk.MustExec("alter table t set tiflash replica 1")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("set @@tidb_allow_mpp=1;")
tk.MustExec("set @@tidb_enforce_mpp=1;")
tk.MustExec("set @@tidb_opt_agg_push_down=1")
tb := external.GetTableByName(t, tk, "test", "t")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)

tk.MustExec("create table c(c_id int)")
tk.MustExec("create table o(o_id int, c_id int)")
tk.MustExec("insert into c values(1),(1),(1),(1)")
tk.MustExec("insert into o values(1,1),(1,1),(1,2)")
tk.MustExec("alter table c set tiflash replica 1")
tk.MustExec("alter table o set tiflash replica 1")
tb = external.GetTableByName(t, tk, "test", "c")
err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("alter table o set tiflash replica 1")
tb = external.GetTableByName(t, tk, "test", "o")
err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)

tk.MustQuery("select a, count(*) from (select a, b from t " +
"union all " +
"select a, b from t" +
") t group by a order by a limit 10;").Check(testkit.Rows("1 10"))

tk.MustQuery("select o.o_id, count(*) from c, o where c.c_id=o.o_id group by o.o_id").Check(testkit.Rows("1 12"))
}

func TestGroupStreamAggOnTiFlash(t *testing.T) {
store, clean := testkit.CreateMockStore(t, withMockTiFlash(2))
defer clean()
Expand Down
9 changes: 8 additions & 1 deletion planner/core/enforce_mpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,20 @@ func TestMPP2PhaseAggPushDown(t *testing.T) {
tk.MustExec("create table c(c_id bigint)")
tk.MustExec("create table o(o_id bigint, c_id bigint not null)")

tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" {
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" || tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
Expand Down
13 changes: 13 additions & 0 deletions planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -2687,6 +2687,18 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
// Is this aggregate a final stage aggregate?
// Final agg can't be split into multi-stage aggregate
hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode
// count final agg should become sum for MPP execution path.
// In the traditional case, TiDB take up the final agg role and push partial agg to TiKV,
// while TiDB can tell the partialMode and do the sum computation rather than counting but MPP doesn't
finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) {
for i, agg := range aggFuncs {
if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount {
oldFt := agg.RetTp
aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false)
aggFuncs[i].RetTp = oldFt
}
}
}

if len(la.GroupByItems) > 0 {
partitionCols := la.GetPotentialPartitionKeys()
Expand All @@ -2710,6 +2722,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
agg.SetSchema(la.schema.Clone())
agg.MppRunMode = Mpp1Phase
finalAggAdjust(agg.AggFuncs)
hashAggs = append(hashAggs, agg)
}

Expand Down
3 changes: 2 additions & 1 deletion planner/core/testdata/enforce_mpp_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@
"set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;",
"EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate",
"EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column"
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
"EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10"
]
}
]
40 changes: 32 additions & 8 deletions planner/core/testdata/enforce_mpp_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -676,11 +676,11 @@
{
"SQL": "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
"Plan": [
"TableReader_78 8000.00 root data:ExchangeSender_77",
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
"TableReader_84 8000.00 root data:ExchangeSender_83",
"└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6",
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.o_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id",
" └─Projection_77 8000.00 mpp[tiflash] Column#6, test.o.o_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:sum(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id",
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.o_id, collate: binary]",
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
Expand All @@ -699,11 +699,11 @@
{
"SQL": "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
"Plan": [
"TableReader_78 8000.00 root data:ExchangeSender_77",
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
"TableReader_84 8000.00 root data:ExchangeSender_83",
"└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6",
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.c_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id",
" └─Projection_77 8000.00 mpp[tiflash] Column#6, test.o.c_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id",
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]",
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
Expand All @@ -718,6 +718,30 @@
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo"
],
"Warn": null
},
{
"SQL": "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10",
"Plan": [
"Projection 10.00 root Column#7, Column#9",
"└─TopN 10.00 root Column#7, offset:0, count:10",
" └─TableReader 10.00 root data:ExchangeSender",
" └─ExchangeSender 10.00 mpp[tiflash] ExchangeType: PassThrough",
" └─TopN 10.00 mpp[tiflash] Column#7, offset:0, count:10",
" └─Projection 16000.00 mpp[tiflash] Column#9, Column#7",
" └─HashAgg 16000.00 mpp[tiflash] group by:Column#7, funcs:sum(Column#10)->Column#9, funcs:firstrow(Column#11)->Column#7",
" └─ExchangeReceiver 16000.00 mpp[tiflash] ",
" └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: Column#7, collate: binary]",
" └─Union 16000.00 mpp[tiflash] ",
" ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
" │ └─ExchangeReceiver 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary]",
" │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo",
" └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
" └─ExchangeReceiver 10000.00 mpp[tiflash] ",
" └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary]",
" └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo"
],
"Warn": null
}
]
}
Expand Down

0 comments on commit 789564b

Please sign in to comment.