diff --git a/statistics/handle/handle_hist.go b/statistics/handle/handle_hist.go index 0c56812f0647c..038e59dce12a0 100644 --- a/statistics/handle/handle_hist.go +++ b/statistics/handle/handle_hist.go @@ -37,6 +37,9 @@ import ( "golang.org/x/sync/singleflight" ) +// RetryCount is the max retry count for a sync load task. +const RetryCount = 3 + type statsWrapper struct { col *statistics.Column idx *statistics.Index @@ -56,6 +59,7 @@ type NeededItemTask struct { TableItemID model.TableItemID ToTimeout time.Time ResultCh chan stmtctx.StatsLoadResult + Retry int } // SendLoadRequests send neededColumns requests @@ -217,6 +221,9 @@ func (h *Handle) SubLoadWorker(ctx sessionctx.Context, exit chan struct{}, exitW } // HandleOneTask handles last task if not nil, else handle a new task from chan, and return current task if fail somewhere. +// - If the task is handled successfully, return nil, nil. +// - If the task is timeout, return the task and nil. The caller should retry the timeout task without sleep. +// - If the task is failed, return the task, error. The caller should retry the timeout task with sleep. func (h *Handle) HandleOneTask(lastTask *NeededItemTask, readerCtx *StatsReaderContext, ctx sqlexec.RestrictedSQLExecutor, exit chan struct{}) (task *NeededItemTask, err error) { defer func() { // recover for each task, worker keeps working @@ -236,28 +243,37 @@ func (h *Handle) HandleOneTask(lastTask *NeededItemTask, readerCtx *StatsReaderC } else { task = lastTask } + result := stmtctx.StatsLoadResult{Item: task.TableItemID} resultChan := h.StatsLoad.Singleflight.DoChan(task.TableItemID.Key(), func() (any, error) { - return h.handleOneItemTask(task, readerCtx, ctx) + err := h.handleOneItemTask(task, readerCtx, ctx) + return nil, err }) timeout := time.Until(task.ToTimeout) select { - case result := <-resultChan: - if result.Err == nil { - slr := result.Val.(*stmtctx.StatsLoadResult) - if slr.Error != nil { - return task, slr.Error - } - task.ResultCh <- *slr + case sr := <-resultChan: + // sr.Val is always nil. + if sr.Err == nil { + task.ResultCh <- result return nil, nil } - return task, result.Err + if !isVaildForRetry(task) { + result.Error = sr.Err + task.ResultCh <- result + return nil, nil + } + return task, sr.Err case <-time.After(timeout): task.ToTimeout.Add(time.Duration(h.mu.ctx.GetSessionVars().StatsLoadSyncWait.Load()) * time.Microsecond) return task, nil } } -func (h *Handle) handleOneItemTask(task *NeededItemTask, readerCtx *StatsReaderContext, ctx sqlexec.RestrictedSQLExecutor) (result *stmtctx.StatsLoadResult, err error) { +func isVaildForRetry(task *NeededItemTask) bool { + task.Retry++ + return task.Retry <= RetryCount +} + +func (h *Handle) handleOneItemTask(task *NeededItemTask, readerCtx *StatsReaderContext, ctx sqlexec.RestrictedSQLExecutor) (err error) { defer func() { // recover for each task, worker keeps working if r := recover(); r != nil { @@ -265,24 +281,23 @@ func (h *Handle) handleOneItemTask(task *NeededItemTask, readerCtx *StatsReaderC err = errors.Errorf("stats loading panicked: %v", r) } }() - result = &stmtctx.StatsLoadResult{Item: task.TableItemID} - item := result.Item + item := task.TableItemID oldCache := h.statsCache.Load().(statsCache) tbl, ok := oldCache.Get(item.TableID) if !ok { - return result, nil + return nil } wrapper := &statsWrapper{} if item.IsIndex { index, ok := tbl.Indices[item.ID] if !ok || index.IsFullLoad() { - return result, nil + return nil } wrapper.idx = index } else { col, ok := tbl.Columns[item.ID] if !ok || col.IsFullLoad() { - return result, nil + return nil } wrapper.col = col } @@ -292,8 +307,7 @@ func (h *Handle) handleOneItemTask(task *NeededItemTask, readerCtx *StatsReaderC needUpdate := false wrapper, err = h.readStatsForOneItem(item, wrapper, readerCtx.reader) if err != nil { - result.Error = err - return result, err + return err } if item.IsIndex { if wrapper.idx != nil { @@ -305,10 +319,10 @@ func (h *Handle) handleOneItemTask(task *NeededItemTask, readerCtx *StatsReaderC } } metrics.ReadStatsHistogram.Observe(float64(time.Since(t).Milliseconds())) - if needUpdate && h.updateCachedItem(item, wrapper.col, wrapper.idx) { - return result, nil + if needUpdate { + h.updateCachedItem(item, wrapper.col, wrapper.idx) } - return nil, nil + return nil } func (h *Handle) loadFreshStatsReader(readerCtx *StatsReaderContext, ctx sqlexec.RestrictedSQLExecutor) { @@ -493,12 +507,12 @@ func (h *Handle) updateCachedItem(item model.TableItemID, colHist *statistics.Co oldCache := h.statsCache.Load().(statsCache) tbl, ok := oldCache.Get(item.TableID) if !ok { - return true + return false } if !item.IsIndex && colHist != nil { c, ok := tbl.Columns[item.ID] if !ok || c.IsFullLoad() { - return true + return false } tbl = tbl.Copy() tbl.Columns[c.ID] = colHist diff --git a/statistics/handle/handle_hist_test.go b/statistics/handle/handle_hist_test.go index c4f30e6ef0de7..d39f110852ff4 100644 --- a/statistics/handle/handle_hist_test.go +++ b/statistics/handle/handle_hist_test.go @@ -209,6 +209,19 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { require.Error(t, err1) require.NotNil(t, task1) + select { + case <-stmtCtx1.StatsLoad.ResultCh: + t.Logf("stmtCtx1.ResultCh should not get anything") + t.FailNow() + case <-stmtCtx2.StatsLoad.ResultCh: + t.Logf("stmtCtx2.ResultCh should not get anything") + t.FailNow() + case <-task1.ResultCh: + t.Logf("task1.ResultCh should not get anything") + t.FailNow() + default: + } + require.NoError(t, failpoint.Disable(fp.failPath)) task3, err3 := h.HandleOneTask(task1, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) require.NoError(t, err3) @@ -231,3 +244,80 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { require.Greater(t, hg.Len()+topn.Num(), 0) } } + +func TestRetry(t *testing.T) { + originConfig := config.GetGlobalConfig() + newConfig := config.NewConfig() + newConfig.Performance.StatsLoadConcurrency = 0 // no worker to consume channel + config.StoreGlobalConfig(newConfig) + defer config.StoreGlobalConfig(originConfig) + store, dom := testkit.CreateMockStoreAndDomain(t) + + testKit := testkit.NewTestKit(t, store) + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t") + testKit.MustExec("set @@session.tidb_analyze_version=2") + testKit.MustExec("create table t(a int, b int, c int, primary key(a), key idx(b))") + testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3)") + + oriLease := dom.StatsHandle().Lease() + dom.StatsHandle().SetLease(1) + defer func() { + dom.StatsHandle().SetLease(oriLease) + }() + testKit.MustExec("analyze table t") + + is := dom.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + tableInfo := tbl.Meta() + + h := dom.StatsHandle() + + neededColumns := make([]model.TableItemID, 1) + neededColumns[0] = model.TableItemID{TableID: tableInfo.ID, ID: tableInfo.Columns[2].ID, IsIndex: false} + timeout := time.Nanosecond * mathutil.MaxInt + + // clear statsCache + h.Clear() + require.NoError(t, dom.StatsHandle().Update(is)) + + // no stats at beginning + stat := h.GetTableStats(tableInfo) + c, ok := stat.Columns[tableInfo.Columns[2].ID] + require.True(t, !ok || (c.Histogram.Len()+c.TopN.Num() == 0)) + + stmtCtx1 := &stmtctx.StatementContext{} + h.SendLoadRequests(stmtCtx1, neededColumns, timeout) + stmtCtx2 := &stmtctx.StatementContext{} + h.SendLoadRequests(stmtCtx2, neededColumns, timeout) + + exitCh := make(chan struct{}) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/statistics/handle/mockReadStatsForOneFail", "return(true)")) + var ( + task1 *handle.NeededItemTask + err1 error + ) + readerCtx := &handle.StatsReaderContext{} + for i := 0; i < handle.RetryCount; i++ { + task1, err1 = h.HandleOneTask(task1, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) + require.Error(t, err1) + require.NotNil(t, task1) + select { + case <-task1.ResultCh: + t.Logf("task1.ResultCh should not get nothing") + t.FailNow() + default: + } + } + result, err1 := h.HandleOneTask(task1, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) + require.NoError(t, err1) + require.Nil(t, result) + select { + case <-task1.ResultCh: + default: + t.Logf("task1.ResultCh should get nothing") + t.FailNow() + } + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/statistics/handle/mockReadStatsForOneFail")) +}