From 7b5af46126baf4454707e1a7c7690b169368beba Mon Sep 17 00:00:00 2001
From: Ti Chi Robot <ti-community-prow-bot@tidb.io>
Date: Wed, 24 Apr 2024 20:24:12 +0800
Subject: [PATCH] statistics: support global singleflight for sync load
 (#52796) (#52870)

close pingcap/tidb#52797
---
 .../core/casetest/planstats/main_test.go      |   1 +
 pkg/sessionctx/stmtctx/BUILD.bazel            |   1 +
 pkg/sessionctx/stmtctx/stmtctx.go             |   3 +-
 pkg/statistics/handle/syncload/BUILD.bazel    |   2 +
 .../handle/syncload/stats_syncload.go         | 105 +++++++++---------
 .../handle/syncload/stats_syncload_test.go    |  55 +++++----
 pkg/statistics/handle/types/BUILD.bazel       |   1 -
 pkg/statistics/handle/types/interfaces.go     |   2 -
 8 files changed, 91 insertions(+), 79 deletions(-)

diff --git a/pkg/planner/core/casetest/planstats/main_test.go b/pkg/planner/core/casetest/planstats/main_test.go
index f53fa7fc26c58..a1289ccab2a6c 100644
--- a/pkg/planner/core/casetest/planstats/main_test.go
+++ b/pkg/planner/core/casetest/planstats/main_test.go
@@ -40,6 +40,7 @@ func TestMain(m *testing.M) {
 		goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
 		goleak.IgnoreTopFunction("github.com/tikv/client-go/v2/txnkv/transaction.keepAlive"),
 		goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
+		goleak.IgnoreTopFunction("github.com/pingcap/tidb/pkg/statistics/handle/syncload.(*statsSyncLoad).SendLoadRequests.func1"), // For TestPlanStatsLoadTimeout
 	}
 
 	callback := func(i int) int {
diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel
index 894dc72b81e30..e838b69507558 100644
--- a/pkg/sessionctx/stmtctx/BUILD.bazel
+++ b/pkg/sessionctx/stmtctx/BUILD.bazel
@@ -30,6 +30,7 @@ go_library(
         "@com_github_pingcap_errors//:errors",
         "@com_github_tikv_client_go_v2//tikvrpc",
         "@org_golang_x_exp//maps",
+        "@org_golang_x_sync//singleflight",
         "@org_uber_go_atomic//:atomic",
     ],
 )
diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go
index da6d1d7217a68..35ed290715ebe 100644
--- a/pkg/sessionctx/stmtctx/stmtctx.go
+++ b/pkg/sessionctx/stmtctx/stmtctx.go
@@ -52,6 +52,7 @@ import (
 	"github.com/tikv/client-go/v2/tikvrpc"
 	atomic2 "go.uber.org/atomic"
 	"golang.org/x/exp/maps"
+	"golang.org/x/sync/singleflight"
 )
 
 const (
@@ -369,7 +370,7 @@ type StatementContext struct {
 		// NeededItems stores the columns/indices whose stats are needed for planner.
 		NeededItems []model.StatsLoadItem
 		// ResultCh to receive stats loading results
-		ResultCh chan StatsLoadResult
+		ResultCh []<-chan singleflight.Result
 		// LoadStartTime is to record the load start time to calculate latency
 		LoadStartTime time.Time
 	}
diff --git a/pkg/statistics/handle/syncload/BUILD.bazel b/pkg/statistics/handle/syncload/BUILD.bazel
index ed6e310786a2a..3be7fe67caa52 100644
--- a/pkg/statistics/handle/syncload/BUILD.bazel
+++ b/pkg/statistics/handle/syncload/BUILD.bazel
@@ -17,9 +17,11 @@ go_library(
         "//pkg/statistics/handle/types",
         "//pkg/types",
         "//pkg/util",
+        "//pkg/util/intest",
         "//pkg/util/logutil",
         "@com_github_pingcap_errors//:errors",
         "@com_github_pingcap_failpoint//:failpoint",
+        "@org_golang_x_sync//singleflight",
         "@org_uber_go_zap//:zap",
     ],
 )
diff --git a/pkg/statistics/handle/syncload/stats_syncload.go b/pkg/statistics/handle/syncload/stats_syncload.go
index 0ae6161a2cf8c..b0bd43166f3ce 100644
--- a/pkg/statistics/handle/syncload/stats_syncload.go
+++ b/pkg/statistics/handle/syncload/stats_syncload.go
@@ -32,8 +32,10 @@ import (
 	statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types"
 	"github.com/pingcap/tidb/pkg/types"
 	"github.com/pingcap/tidb/pkg/util"
+	"github.com/pingcap/tidb/pkg/util/intest"
 	"github.com/pingcap/tidb/pkg/util/logutil"
 	"go.uber.org/zap"
+	"golang.org/x/sync/singleflight"
 )
 
 // RetryCount is the max retry count for a sync load task.
@@ -44,6 +46,8 @@ type statsSyncLoad struct {
 	StatsLoad   statstypes.StatsLoad
 }
 
+var globalStatsSyncLoadSingleFlight singleflight.Group
+
 // NewStatsSyncLoad creates a new StatsSyncLoad.
 func NewStatsSyncLoad(statsHandle statstypes.StatsHandle) statstypes.StatsSyncLoad {
 	s := &statsSyncLoad{statsHandle: statsHandle}
@@ -78,25 +82,27 @@ func (s *statsSyncLoad) SendLoadRequests(sc *stmtctx.StatementContext, neededHis
 	}
 	sc.StatsLoad.Timeout = timeout
 	sc.StatsLoad.NeededItems = remainedItems
-	sc.StatsLoad.ResultCh = make(chan stmtctx.StatsLoadResult, len(remainedItems))
-	tasks := make([]*statstypes.NeededItemTask, 0)
+	sc.StatsLoad.ResultCh = make([]<-chan singleflight.Result, 0, len(remainedItems))
 	for _, item := range remainedItems {
-		task := &statstypes.NeededItemTask{
-			Item:      item,
-			ToTimeout: time.Now().Local().Add(timeout),
-			ResultCh:  sc.StatsLoad.ResultCh,
-		}
-		tasks = append(tasks, task)
-	}
-	timer := time.NewTimer(timeout)
-	defer timer.Stop()
-	for _, task := range tasks {
-		select {
-		case s.StatsLoad.NeededItemsCh <- task:
-			continue
-		case <-timer.C:
-			return errors.New("sync load stats channel is full and timeout sending task to channel")
-		}
+		localItem := item
+		resultCh := globalStatsSyncLoadSingleFlight.DoChan(localItem.Key(), func() (any, error) {
+			timer := time.NewTimer(timeout)
+			defer timer.Stop()
+			task := &statstypes.NeededItemTask{
+				Item:      localItem,
+				ToTimeout: time.Now().Local().Add(timeout),
+				ResultCh:  make(chan stmtctx.StatsLoadResult, 1),
+			}
+			select {
+			case s.StatsLoad.NeededItemsCh <- task:
+				result, ok := <-task.ResultCh
+				intest.Assert(ok, "task.ResultCh cannot be closed")
+				return result, nil
+			case <-timer.C:
+				return nil, errors.New("sync load stats channel is full and timeout sending task to channel")
+			}
+		})
+		sc.StatsLoad.ResultCh = append(sc.StatsLoad.ResultCh, resultCh)
 	}
 	sc.StatsLoad.LoadStartTime = time.Now()
 	return nil
@@ -122,25 +128,34 @@ func (*statsSyncLoad) SyncWaitStatsLoad(sc *stmtctx.StatementContext) error {
 	metrics.SyncLoadCounter.Inc()
 	timer := time.NewTimer(sc.StatsLoad.Timeout)
 	defer timer.Stop()
-	for {
+	for _, resultCh := range sc.StatsLoad.ResultCh {
 		select {
-		case result, ok := <-sc.StatsLoad.ResultCh:
+		case result, ok := <-resultCh:
 			if !ok {
 				return errors.New("sync load stats channel closed unexpectedly")
 			}
-			if result.HasError() {
-				errorMsgs = append(errorMsgs, result.ErrorMsg())
-			}
-			delete(resultCheckMap, result.Item)
-			if len(resultCheckMap) == 0 {
-				metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds()))
-				return nil
+			// result.Err is produced by statsSyncLoad.SendLoadRequests while creating the task and sending it to a worker;
+			// it is not an error from loading the stats.
+			if result.Err != nil {
+				errorMsgs = append(errorMsgs, result.Err.Error())
+			} else {
+				val := result.Val.(stmtctx.StatsLoadResult)
+				// This error, in contrast, comes from loading the stats.
+				if val.HasError() {
+					errorMsgs = append(errorMsgs, val.ErrorMsg())
+				}
+				delete(resultCheckMap, val.Item)
 			}
 		case <-timer.C:
 			metrics.SyncLoadTimeoutCounter.Inc()
 			return errors.New("sync load stats timeout")
 		}
 	}
+	if len(resultCheckMap) == 0 {
+		metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds()))
+		return nil
+	}
+	return nil
 }
 
 // removeHistLoadedColumns removed having-hist columns based on neededColumns and statsCache.
@@ -230,33 +245,17 @@ func (s *statsSyncLoad) HandleOneTask(sctx sessionctx.Context, lastTask *statsty
 		task = lastTask
 	}
 	result := stmtctx.StatsLoadResult{Item: task.Item.TableItemID}
-	resultChan := s.StatsLoad.Singleflight.DoChan(task.Item.Key(), func() (any, error) {
-		err := s.handleOneItemTask(task)
-		return nil, err
-	})
-	timeout := time.Until(task.ToTimeout)
-	select {
-	case sr := <-resultChan:
-		// sr.Val is always nil.
-		if sr.Err == nil {
-			task.ResultCh <- result
-			return nil, nil
-		}
-		if !isVaildForRetry(task) {
-			result.Error = sr.Err
-			task.ResultCh <- result
-			return nil, nil
-		}
-		return task, sr.Err
-	case <-time.After(timeout):
-		if !isVaildForRetry(task) {
-			result.Error = errors.New("stats loading timeout")
-			task.ResultCh <- result
-			return nil, nil
-		}
-		task.ToTimeout.Add(time.Duration(sctx.GetSessionVars().StatsLoadSyncWait.Load()) * time.Microsecond)
-		return task, nil
+	err = s.handleOneItemTask(task)
+	if err == nil {
+		task.ResultCh <- result
+		return nil, nil
+	}
+	if !isVaildForRetry(task) {
+		result.Error = err
+		task.ResultCh <- result
+		return nil, nil
 	}
+	return task, err
 }
 
 func isVaildForRetry(task *statstypes.NeededItemTask) bool {
diff --git a/pkg/statistics/handle/syncload/stats_syncload_test.go b/pkg/statistics/handle/syncload/stats_syncload_test.go
index 4b38387430c49..8a8929d9d93e5 100644
--- a/pkg/statistics/handle/syncload/stats_syncload_test.go
+++ b/pkg/statistics/handle/syncload/stats_syncload_test.go
@@ -208,13 +208,23 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) {
 		task1, err1 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, exitCh)
 		require.Error(t, err1)
 		require.NotNil(t, task1)
+		for _, resultCh := range stmtCtx1.StatsLoad.ResultCh {
+			select {
+			case <-resultCh:
+				t.Logf("stmtCtx1.ResultCh should not get anything")
+				t.FailNow()
+			default:
+			}
+		}
+		for _, resultCh := range stmtCtx2.StatsLoad.ResultCh {
+			select {
+			case <-resultCh:
+				t.Logf("stmtCtx1.ResultCh should not get anything")
+				t.FailNow()
+			default:
+			}
+		}
 		select {
-		case <-stmtCtx1.StatsLoad.ResultCh:
-			t.Logf("stmtCtx1.ResultCh should not get anything")
-			t.FailNow()
-		case <-stmtCtx2.StatsLoad.ResultCh:
-			t.Logf("stmtCtx2.ResultCh should not get anything")
-			t.FailNow()
 		case <-task1.ResultCh:
 			t.Logf("task1.ResultCh should not get anything")
 			t.FailNow()
@@ -225,17 +235,18 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) {
 		task3, err3 := h.HandleOneTask(testKit.Session().(sessionctx.Context), task1, exitCh)
 		require.NoError(t, err3)
 		require.Nil(t, task3)
-
-		task, err3 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, exitCh)
-		require.NoError(t, err3)
-		require.Nil(t, task)
-
-		rs1, ok1 := <-stmtCtx1.StatsLoad.ResultCh
-		require.True(t, ok1)
-		require.Equal(t, neededColumns[0].TableItemID, rs1.Item)
-		rs2, ok2 := <-stmtCtx2.StatsLoad.ResultCh
-		require.True(t, ok2)
-		require.Equal(t, neededColumns[0].TableItemID, rs2.Item)
+		for _, resultCh := range stmtCtx1.StatsLoad.ResultCh {
+			rs1, ok1 := <-resultCh
+			require.True(t, rs1.Shared)
+			require.True(t, ok1)
+			require.Equal(t, neededColumns[0].TableItemID, rs1.Val.(stmtctx.StatsLoadResult).Item)
+		}
+		for _, resultCh := range stmtCtx2.StatsLoad.ResultCh {
+			rs1, ok1 := <-resultCh
+			require.True(t, rs1.Shared)
+			require.True(t, ok1)
+			require.Equal(t, neededColumns[0].TableItemID, rs1.Val.(stmtctx.StatsLoadResult).Item)
+		}
 
 		stat = h.GetTableStats(tableInfo)
 		hg := stat.Columns[tableInfo.Columns[2].ID].Histogram
@@ -312,11 +323,11 @@ func TestRetry(t *testing.T) {
 	result, err1 := h.HandleOneTask(testKit.Session().(sessionctx.Context), task1, exitCh)
 	require.NoError(t, err1)
 	require.Nil(t, result)
-	select {
-	case <-task1.ResultCh:
-	default:
-		t.Logf("task1.ResultCh should get nothing")
-		t.FailNow()
+	for _, resultCh := range stmtCtx1.StatsLoad.ResultCh {
+		rs1, ok1 := <-resultCh
+		require.True(t, rs1.Shared)
+		require.True(t, ok1)
+		require.Error(t, rs1.Val.(stmtctx.StatsLoadResult).Error)
 	}
 	task1.Retry = 0
 	for i := 0; i < syncload.RetryCount*5; i++ {
diff --git a/pkg/statistics/handle/types/BUILD.bazel b/pkg/statistics/handle/types/BUILD.bazel
index 328d1a75b1159..df7a6ea2acfa1 100644
--- a/pkg/statistics/handle/types/BUILD.bazel
+++ b/pkg/statistics/handle/types/BUILD.bazel
@@ -17,6 +17,5 @@ go_library(
         "//pkg/types",
         "//pkg/util",
         "//pkg/util/sqlexec",
-        "@org_golang_x_sync//singleflight",
     ],
 )
diff --git a/pkg/statistics/handle/types/interfaces.go b/pkg/statistics/handle/types/interfaces.go
index 5c1b41d7fbd65..8726cd7d64a7c 100644
--- a/pkg/statistics/handle/types/interfaces.go
+++ b/pkg/statistics/handle/types/interfaces.go
@@ -30,7 +30,6 @@ import (
 	"github.com/pingcap/tidb/pkg/types"
 	"github.com/pingcap/tidb/pkg/util"
 	"github.com/pingcap/tidb/pkg/util/sqlexec"
-	"golang.org/x/sync/singleflight"
 )
 
 // StatsGC is used to GC unnecessary stats.
@@ -398,7 +397,6 @@ type NeededItemTask struct {
 type StatsLoad struct {
 	NeededItemsCh  chan *NeededItemTask
 	TimeoutItemsCh chan *NeededItemTask
-	Singleflight   singleflight.Group
 	sync.Mutex
 }