From 08d0f314263df7e4a2ae90ed73fc83f1fffe49b2 Mon Sep 17 00:00:00 2001
From: AilinKid <314806019@qq.com>
Date: Tue, 21 May 2024 21:24:55 +0800
Subject: [PATCH 1/5] fix mpp final agg couldn't co-exist with non-final mode

Signed-off-by: AilinKid <314806019@qq.com>
---
 pkg/expression/aggregation/BUILD.bazel     |  1 +
 pkg/expression/aggregation/aggregation.go  | 16 +++++++
 pkg/expression/aggregation/explain.go      | 15 ++++++-
 pkg/planner/core/enforce_mpp_test.go       | 50 ++++++++++++++++++++++
 pkg/planner/core/exhaust_physical_plans.go | 23 +++++++++-
 5 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/pkg/expression/aggregation/BUILD.bazel b/pkg/expression/aggregation/BUILD.bazel
index 502d43d527cbc..d4d35efd5700a 100644
--- a/pkg/expression/aggregation/BUILD.bazel
+++ b/pkg/expression/aggregation/BUILD.bazel
@@ -50,6 +50,7 @@ go_library(
         "//pkg/util/mvmap",
         "//pkg/util/size",
         "@com_github_pingcap_errors//:errors",
+        "@com_github_pingcap_failpoint//:failpoint",
         "@com_github_pingcap_tipb//go-tipb",
     ],
 )
diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go
index 71567ed560534..c2ec7b6dfe0d6 100644
--- a/pkg/expression/aggregation/aggregation.go
+++ b/pkg/expression/aggregation/aggregation.go
@@ -134,6 +134,22 @@ const (
 	DedupMode
 )
 
+func (a AggFunctionMode) ToString() string {
+	switch a {
+	case CompleteMode:
+		return "complete"
+	case FinalMode:
+		return "final"
+	case Partial1Mode:
+		return "partial1"
+	case Partial2Mode:
+		return "partial2"
+	case DedupMode:
+		return "deduplicate"
+	}
+	return ""
+}
+
 type aggFunction struct {
 	*AggFuncDesc
 }
diff --git a/pkg/expression/aggregation/explain.go b/pkg/expression/aggregation/explain.go
index 29f88499e1bd1..23594c7123480 100644
--- a/pkg/expression/aggregation/explain.go
+++ b/pkg/expression/aggregation/explain.go
@@ -17,7 +17,7 @@ package aggregation
 import (
 	"bytes"
 	"fmt"
-
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/expression"
 	"github.com/pingcap/tidb/pkg/parser/ast"
 )
@@ -25,7 +25,18 @@ import (
 // ExplainAggFunc generates explain information for a aggregation function.
 func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string {
 	var buffer bytes.Buffer
-	fmt.Fprintf(&buffer, "%s(", agg.Name)
+	showMode := false
+	failpoint.Inject("show-agg-mode", func(v failpoint.Value) {
+		if v.(bool) {
+			showMode = true
+		}
+	})
+	if showMode {
+		fmt.Fprintf(&buffer, "%s(%s,", agg.Name, agg.Mode.ToString())
+	} else {
+		fmt.Fprintf(&buffer, "%s(", agg.Name)
+	}
+
 	if agg.HasDistinct {
 		buffer.WriteString("distinct ")
 	}
diff --git a/pkg/planner/core/enforce_mpp_test.go b/pkg/planner/core/enforce_mpp_test.go
index f161f7b7bd7b0..6e8deb41b8197 100644
--- a/pkg/planner/core/enforce_mpp_test.go
+++ b/pkg/planner/core/enforce_mpp_test.go
@@ -19,12 +19,62 @@ import (
 	"strconv"
 	"testing"
 
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/domain"
 	"github.com/pingcap/tidb/pkg/parser/model"
 	"github.com/pingcap/tidb/pkg/testkit"
 	"github.com/stretchr/testify/require"
 )
 
+func TestMppAggShouldAlignFinalMode(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("create table t (" +
+		"  d date," +
+		"  v int," +
+		"  primary key(d, v)" +
+		") partition by range columns (d) (" +
+		"  partition p1 values less than ('2023-07-02')," +
+		"  partition p2 values less than ('2023-07-03')" +
+		");")
+	// Create virtual tiflash replica info.
+	dom := domain.GetDomain(tk.Session())
+	is := dom.InfoSchema()
+	db, exists := is.SchemaByName(model.NewCIStr("test"))
+	require.True(t, exists)
+	for _, tblInfo := range db.Tables {
+		if tblInfo.Name.L == "t" {
+			tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
+				Count:     1,
+				Available: true,
+			}
+		}
+	}
+	tk.MustExec(`set tidb_partition_prune_mode='static';`)
+	err := failpoint.Enable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode", "return(true)")
+	require.Nil(t, err)
+
+	tk.MustQuery("explain format='brief' select 1 from (" +
+		"  select /*+ read_from_storage(tiflash[t]) */ /*+ set_var(mpp_version=\"0\") */ sum(1)" +
+		"  from t where d BETWEEN '2023-07-01' and '2023-07-03' group by d" +
+		") total;").Check(testkit.Rows("Projection 400.00 root  1->Column#4",
+		"└─HashAgg 400.00 root  group by:test.t.d, funcs:count(complete,1)->Column#8",
+		"  └─PartitionUnion 400.00 root  ",
+		"    ├─Projection 200.00 root  test.t.d",
+		"    │ └─HashAgg 200.00 root  group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#12)->Column#9",
+		"    │   └─TableReader 200.00 root  data:HashAgg",
+		"    │     └─HashAgg 200.00 cop[tikv]  group by:test.t.d, funcs:count(partial1,1)->Column#12",
+		"    │       └─TableRangeScan 250.00 cop[tikv] table:t, partition:p1 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo",
+		"    └─Projection 200.00 root  test.t.d",
+		"      └─HashAgg 200.00 root  group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#16)->Column#10",
+		"        └─TableReader 200.00 root  data:HashAgg",
+		"          └─HashAgg 200.00 cop[tikv]  group by:test.t.d, funcs:count(partial1,1)->Column#16",
+		"            └─TableRangeScan 250.00 cop[tikv] table:t, partition:p2 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo"))
+
+	err = failpoint.Disable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode")
+	require.Nil(t, err)
+}
 func TestRowSizeInMPP(t *testing.T) {
 	store := testkit.CreateMockStore(t)
 	tk := testkit.NewTestKit(t, store)
diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go
index 0bd930d28a4b4..cd85371563409 100644
--- a/pkg/planner/core/exhaust_physical_plans.go
+++ b/pkg/planner/core/exhaust_physical_plans.go
@@ -3210,6 +3210,9 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 	if prop.MPPPartitionTp == property.BroadcastType {
 		return nil
 	}
+	if strings.HasPrefix(la.SCtx().GetSessionVars().StmtCtx.OriginalSQL, "explain select 1 from (") {
+		fmt.Println(1)
+	}
 
 	// Is this aggregate a final stage aggregate?
 	// Final agg can't be split into multi-stage aggregate
@@ -3226,6 +3229,18 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 			}
 		}
 	}
+	// ref: https://github.com/pingcap/tiflash/blob/3ebb102fba17dce3d990d824a9df93d93f1ab
+	// 766/dbms/src/Flash/Coprocessor/AggregationInterpreterHelper.cpp#L26
+	validMppAgg := func(mppAgg *PhysicalHashAgg) bool {
+		isFinalOrCompleteMode := true
+		for _, one := range mppAgg.AggFuncs {
+			if one.Mode == aggregation.FinalMode || one.Mode == aggregation.CompleteMode {
+				continue
+			}
+			isFinalOrCompleteMode = false
+		}
+		return isFinalOrCompleteMode
+	}
 
 	if len(la.GroupByItems) > 0 {
 		partitionCols := la.GetPotentialPartitionKeys()
@@ -3259,7 +3274,9 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 			agg.SetSchema(la.schema.Clone())
 			agg.MppRunMode = Mpp1Phase
 			finalAggAdjust(agg.AggFuncs)
-			hashAggs = append(hashAggs, agg)
+			if validMppAgg(agg) {
+				hashAggs = append(hashAggs, agg)
+			}
 		}
 
 		// Final agg can't be split into multi-stage aggregate, so exit early
@@ -3274,7 +3291,9 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 		agg.SetSchema(la.schema.Clone())
 		agg.MppRunMode = Mpp2Phase
 		agg.MppPartitionCols = partitionCols
-		hashAggs = append(hashAggs, agg)
+		if validMppAgg(agg) {
+			hashAggs = append(hashAggs, agg)
+		}
 
 		// agg runs on TiDB with a partial agg on TiFlash if possible
 		if prop.TaskTp == property.RootTaskType {

From 76ecfab2fc69ffe28030b8c1a81d5d55b0f777d1 Mon Sep 17 00:00:00 2001
From: AilinKid <314806019@qq.com>
Date: Tue, 21 May 2024 21:26:37 +0800
Subject: [PATCH 2/5] aggregation: separate stdlib and third-party imports in explain.go

Signed-off-by: AilinKid <314806019@qq.com>
---
 pkg/expression/aggregation/explain.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pkg/expression/aggregation/explain.go b/pkg/expression/aggregation/explain.go
index 23594c7123480..d89fc08d88dc6 100644
--- a/pkg/expression/aggregation/explain.go
+++ b/pkg/expression/aggregation/explain.go
@@ -17,6 +17,7 @@ package aggregation
 import (
 	"bytes"
 	"fmt"
+
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/expression"
 	"github.com/pingcap/tidb/pkg/parser/ast"

From 239ebcf2225cafae9151e58a98322923995dc0b2 Mon Sep 17 00:00:00 2001
From: AilinKid <314806019@qq.com>
Date: Tue, 21 May 2024 21:27:11 +0800
Subject: [PATCH 3/5] planner: remove leftover debug print in tryToGetMppHashAggs

Signed-off-by: AilinKid <314806019@qq.com>
---
 pkg/planner/core/exhaust_physical_plans.go | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go
index cd85371563409..35728bef8736d 100644
--- a/pkg/planner/core/exhaust_physical_plans.go
+++ b/pkg/planner/core/exhaust_physical_plans.go
@@ -3210,9 +3210,6 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 	if prop.MPPPartitionTp == property.BroadcastType {
 		return nil
 	}
-	if strings.HasPrefix(la.SCtx().GetSessionVars().StmtCtx.OriginalSQL, "explain select 1 from (") {
-		fmt.Println(1)
-	}
 
 	// Is this aggregate a final stage aggregate?
 	// Final agg can't be split into multi-stage aggregate

From 5f80e9c83377f4e647c76d92465e4508ed64e3f2 Mon Sep 17 00:00:00 2001
From: AilinKid <314806019@qq.com>
Date: Wed, 22 May 2024 00:18:34 +0800
Subject: [PATCH 4/5] planner: reject only mpp aggs that mix final and non-final modes

Signed-off-by: AilinKid <314806019@qq.com>
---
 pkg/expression/aggregation/aggregation.go  |  1 +
 pkg/planner/core/exhaust_physical_plans.go | 16 ++++++++++------
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go
index c2ec7b6dfe0d6..4229d40dc9e76 100644
--- a/pkg/expression/aggregation/aggregation.go
+++ b/pkg/expression/aggregation/aggregation.go
@@ -134,6 +134,7 @@ const (
 	DedupMode
 )
 
+// ToString returns the lowercase name of the agg mode, or "" for an unknown mode.
 func (a AggFunctionMode) ToString() string {
 	switch a {
 	case CompleteMode:
diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go
index 35728bef8736d..c5f7e6fa6d12c 100644
--- a/pkg/planner/core/exhaust_physical_plans.go
+++ b/pkg/planner/core/exhaust_physical_plans.go
@@ -3229,14 +3229,18 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 	// ref: https://github.com/pingcap/tiflash/blob/3ebb102fba17dce3d990d824a9df93d93f1ab
 	// 766/dbms/src/Flash/Coprocessor/AggregationInterpreterHelper.cpp#L26
 	validMppAgg := func(mppAgg *PhysicalHashAgg) bool {
-		isFinalOrCompleteMode := true
-		for _, one := range mppAgg.AggFuncs {
-			if one.Mode == aggregation.FinalMode || one.Mode == aggregation.CompleteMode {
-				continue
+		isFinalAgg := true
+		if mppAgg.AggFuncs[0].Mode != aggregation.FinalMode && mppAgg.AggFuncs[0].Mode != aggregation.CompleteMode {
+			isFinalAgg = false
+		}
+		for _, one := range mppAgg.AggFuncs[1:] {
+			otherIsFinalAgg := one.Mode == aggregation.FinalMode || one.Mode == aggregation.CompleteMode
+			if isFinalAgg != otherIsFinalAgg {
+				// different agg mode detected in mpp side.
+				return false
 			}
-			isFinalOrCompleteMode = false
 		}
-		return isFinalOrCompleteMode
+		return true
 	}
 
 	if len(la.GroupByItems) > 0 {

From 90ed74a01ca12539b92a5282f0d014547728feaf Mon Sep 17 00:00:00 2001
From: AilinKid <314806019@qq.com>
Date: Wed, 17 Jul 2024 15:25:17 +0800
Subject: [PATCH 5/5] test: move TestMppAggShouldAlignFinalMode to tiflashtest using mock tiflash

Signed-off-by: AilinKid <314806019@qq.com>
---
 pkg/executor/test/tiflashtest/BUILD.bazel     |  2 +-
 pkg/executor/test/tiflashtest/tiflash_test.go | 44 ++++++++++++++++
 pkg/planner/core/enforce_mpp_test.go          | 50 -------------------
 3 files changed, 45 insertions(+), 51 deletions(-)

diff --git a/pkg/executor/test/tiflashtest/BUILD.bazel b/pkg/executor/test/tiflashtest/BUILD.bazel
index b60b56f408a9e..6f1da72d25ce1 100644
--- a/pkg/executor/test/tiflashtest/BUILD.bazel
+++ b/pkg/executor/test/tiflashtest/BUILD.bazel
@@ -9,7 +9,7 @@ go_test(
     ],
     flaky = True,
     race = "on",
-    shard_count = 42,
+    shard_count = 43,
     deps = [
         "//pkg/config",
         "//pkg/domain",
diff --git a/pkg/executor/test/tiflashtest/tiflash_test.go b/pkg/executor/test/tiflashtest/tiflash_test.go
index a58847cfebbed..da0bc8df5fcf6 100644
--- a/pkg/executor/test/tiflashtest/tiflash_test.go
+++ b/pkg/executor/test/tiflashtest/tiflash_test.go
@@ -1988,3 +1988,47 @@ func TestIssue50358(t *testing.T) {
 		tk.MustQuery("select 8 from t join t1").Check(testkit.Rows("8", "8"))
 	}
 }
+
+func TestMppAggShouldAlignFinalMode(t *testing.T) {
+	store := testkit.CreateMockStore(t, withMockTiFlash(1))
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("create table t (" +
+		"  d date," +
+		"  v int," +
+		"  primary key(d, v)" +
+		") partition by range columns (d) (" +
+		"  partition p1 values less than ('2023-07-02')," +
+		"  partition p2 values less than ('2023-07-03')" +
+		");")
+	tk.MustExec("alter table t set tiflash replica 1")
+	tb := external.GetTableByName(t, tk, "test", "t")
+	err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
+	require.NoError(t, err)
+	tk.MustExec(`set tidb_partition_prune_mode='static';`)
+	err = failpoint.Enable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode", "return(true)")
+	require.Nil(t, err)
+
+	tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
+	tk.MustQuery("explain format='brief' select 1 from (" +
+		"  select /*+ read_from_storage(tiflash[t]) */ sum(1)" +
+		"  from t where d BETWEEN '2023-07-01' and '2023-07-03' group by d" +
+		") total;").Check(testkit.Rows("Projection 400.00 root  1->Column#4",
+		"└─HashAgg 400.00 root  group by:test.t.d, funcs:count(complete,1)->Column#8",
+		"  └─PartitionUnion 400.00 root  ",
+		"    ├─Projection 200.00 root  test.t.d",
+		"    │ └─HashAgg 200.00 root  group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#12)->Column#9",
+		"    │   └─TableReader 200.00 root  MppVersion: 2, data:ExchangeSender",
+		"    │     └─ExchangeSender 200.00 mpp[tiflash]  ExchangeType: PassThrough",
+		"    │       └─HashAgg 200.00 mpp[tiflash]  group by:test.t.d, funcs:count(partial1,1)->Column#12",
+		"    │         └─TableRangeScan 250.00 mpp[tiflash] table:t, partition:p1 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo",
+		"    └─Projection 200.00 root  test.t.d",
+		"      └─HashAgg 200.00 root  group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#14)->Column#10",
+		"        └─TableReader 200.00 root  MppVersion: 2, data:ExchangeSender",
+		"          └─ExchangeSender 200.00 mpp[tiflash]  ExchangeType: PassThrough",
+		"            └─HashAgg 200.00 mpp[tiflash]  group by:test.t.d, funcs:count(partial1,1)->Column#14",
+		"              └─TableRangeScan 250.00 mpp[tiflash] table:t, partition:p2 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo"))
+
+	err = failpoint.Disable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode")
+	require.Nil(t, err)
+}
diff --git a/pkg/planner/core/enforce_mpp_test.go b/pkg/planner/core/enforce_mpp_test.go
index 6e8deb41b8197..f161f7b7bd7b0 100644
--- a/pkg/planner/core/enforce_mpp_test.go
+++ b/pkg/planner/core/enforce_mpp_test.go
@@ -19,62 +19,12 @@ import (
 	"strconv"
 	"testing"
 
-	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/domain"
 	"github.com/pingcap/tidb/pkg/parser/model"
 	"github.com/pingcap/tidb/pkg/testkit"
 	"github.com/stretchr/testify/require"
 )
 
-func TestMppAggShouldAlignFinalMode(t *testing.T) {
-	store := testkit.CreateMockStore(t)
-	tk := testkit.NewTestKit(t, store)
-	tk.MustExec("use test")
-	tk.MustExec("create table t (" +
-		"  d date," +
-		"  v int," +
-		"  primary key(d, v)" +
-		") partition by range columns (d) (" +
-		"  partition p1 values less than ('2023-07-02')," +
-		"  partition p2 values less than ('2023-07-03')" +
-		");")
-	// Create virtual tiflash replica info.
-	dom := domain.GetDomain(tk.Session())
-	is := dom.InfoSchema()
-	db, exists := is.SchemaByName(model.NewCIStr("test"))
-	require.True(t, exists)
-	for _, tblInfo := range db.Tables {
-		if tblInfo.Name.L == "t" {
-			tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
-				Count:     1,
-				Available: true,
-			}
-		}
-	}
-	tk.MustExec(`set tidb_partition_prune_mode='static';`)
-	err := failpoint.Enable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode", "return(true)")
-	require.Nil(t, err)
-
-	tk.MustQuery("explain format='brief' select 1 from (" +
-		"  select /*+ read_from_storage(tiflash[t]) */ /*+ set_var(mpp_version=\"0\") */ sum(1)" +
-		"  from t where d BETWEEN '2023-07-01' and '2023-07-03' group by d" +
-		") total;").Check(testkit.Rows("Projection 400.00 root  1->Column#4",
-		"└─HashAgg 400.00 root  group by:test.t.d, funcs:count(complete,1)->Column#8",
-		"  └─PartitionUnion 400.00 root  ",
-		"    ├─Projection 200.00 root  test.t.d",
-		"    │ └─HashAgg 200.00 root  group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#12)->Column#9",
-		"    │   └─TableReader 200.00 root  data:HashAgg",
-		"    │     └─HashAgg 200.00 cop[tikv]  group by:test.t.d, funcs:count(partial1,1)->Column#12",
-		"    │       └─TableRangeScan 250.00 cop[tikv] table:t, partition:p1 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo",
-		"    └─Projection 200.00 root  test.t.d",
-		"      └─HashAgg 200.00 root  group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#16)->Column#10",
-		"        └─TableReader 200.00 root  data:HashAgg",
-		"          └─HashAgg 200.00 cop[tikv]  group by:test.t.d, funcs:count(partial1,1)->Column#16",
-		"            └─TableRangeScan 250.00 cop[tikv] table:t, partition:p2 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo"))
-
-	err = failpoint.Disable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode")
-	require.Nil(t, err)
-}
 func TestRowSizeInMPP(t *testing.T) {
 	store := testkit.CreateMockStore(t)
 	tk := testkit.NewTestKit(t, store)