diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index 188deffd6365a..6f80aefaaf3f1 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -2174,8 +2174,9 @@ func (b *executorBuilder) buildTopN(v *plannercore.PhysicalTopN) exec.Executor { } executor_metrics.ExecutorCounterTopNExec.Inc() return &sortexec.TopNExec{ - SortExec: sortExec, - Limit: &plannercore.PhysicalLimit{Count: v.Count, Offset: v.Offset}, + SortExec: sortExec, + Limit: &plannercore.PhysicalLimit{Count: v.Count, Offset: v.Offset}, + Concurrency: b.ctx.GetSessionVars().Concurrency.ExecutorConcurrency, } } diff --git a/pkg/executor/executor_required_rows_test.go b/pkg/executor/executor_required_rows_test.go index ad222b6ddf920..274abc391eba2 100644 --- a/pkg/executor/executor_required_rows_test.go +++ b/pkg/executor/executor_required_rows_test.go @@ -393,8 +393,9 @@ func buildTopNExec(ctx sessionctx.Context, offset, count int, byItems []*util.By ExecSchema: src.Schema(), } return &sortexec.TopNExec{ - SortExec: sortExec, - Limit: &plannercore.PhysicalLimit{Count: uint64(count), Offset: uint64(offset)}, + SortExec: sortExec, + Limit: &plannercore.PhysicalLimit{Count: uint64(count), Offset: uint64(offset)}, + Concurrency: 5, } } diff --git a/pkg/executor/sortexec/BUILD.bazel b/pkg/executor/sortexec/BUILD.bazel index c545297a74fd7..82bbc64890c05 100644 --- a/pkg/executor/sortexec/BUILD.bazel +++ b/pkg/executor/sortexec/BUILD.bazel @@ -11,6 +11,9 @@ go_library( "sort_spill.go", "sort_util.go", "topn.go", + "topn_chunk_heap.go", + "topn_spill.go", + "topn_worker.go", ], importpath = "github.com/pingcap/tidb/pkg/executor/sortexec", visibility = ["//visibility:public"], @@ -39,7 +42,7 @@ go_test( timeout = "short", srcs = ["sort_test.go"], flaky = True, - shard_count = 10, + shard_count = 13, deps = [ "//pkg/config", "//pkg/sessionctx/variable", @@ -60,6 +63,7 @@ go_test( "sort_spill_test.go", "sort_test.go", "sortexec_pkg_test.go", + "topn_spill_test.go", ], embed = [":sortexec"], flaky = True, @@ -70,6 +74,7 @@ go_test( "//pkg/executor/internal/testutil", "//pkg/expression", "//pkg/parser/mysql", + "//pkg/planner/core", "//pkg/planner/util", "//pkg/sessionctx/variable", "//pkg/testkit", diff --git a/pkg/executor/sortexec/sort_spill_test.go b/pkg/executor/sortexec/sort_spill_test.go index cb8b00642c4f1..1c2cb46b02a0b 100644 --- a/pkg/executor/sortexec/sort_spill_test.go +++ b/pkg/executor/sortexec/sort_spill_test.go @@ -98,10 +98,17 @@ func (r *resultChecker) initRowPtrs() { } } -func (r *resultChecker) check(resultChunks []*chunk.Chunk) bool { +func (r *resultChecker) check(resultChunks []*chunk.Chunk, offset int64, count int64) bool { if r.rowPtrs == nil { r.initRowPtrs() sort.Slice(r.rowPtrs, r.keyColumnsLess) + if offset < 0 { + offset = 0 + } + if count < 0 { + count = (int64(len(r.rowPtrs)) - offset) + } + r.rowPtrs = r.rowPtrs[offset : offset+count] } cursor := 0 @@ -220,7 +227,7 @@ func executeSortExecutorAndManullyTriggerSpill(t *testing.T, exe *sortexec.SortE func checkCorrectness(schema *expression.Schema, exe *sortexec.SortExec, dataSource *testutil.MockDataSource, resultChunks []*chunk.Chunk) bool { keyColumns, keyCmpFuncs, byItemsDesc := exe.GetSortMetaForTest() checker := newResultChecker(schema, keyColumns, keyCmpFuncs, byItemsDesc, dataSource.GenData) - return checker.check(resultChunks) + return checker.check(resultChunks, -1, -1) } func onePartitionAndAllDataInMemoryCase(t *testing.T, ctx *mock.Context, sortCase *testutil.SortCase) { diff --git a/pkg/executor/sortexec/sort_util.go 
b/pkg/executor/sortexec/sort_util.go
index 1e63dbd3c1b27..59ef17f90da2c 100644
--- a/pkg/executor/sortexec/sort_util.go
+++ b/pkg/executor/sortexec/sort_util.go
@@ -46,10 +46,10 @@ type rowWithPartition struct {
 	partitionID int
 }
 
-func processPanicAndLog(errOutputChan chan rowWithError, r any) {
+func processPanicAndLog(errOutputChan chan<- rowWithError, r any) {
 	err := util.GetRecoverError(r)
 	errOutputChan <- rowWithError{err: err}
-	logutil.BgLogger().Error("parallel sort panicked", zap.Error(err), zap.Stack("stack"))
+	logutil.BgLogger().Error("executor panicked", zap.Error(err), zap.Stack("stack"))
 }
 
 // chunkWithMemoryUsage contains chunk and memory usage.
diff --git a/pkg/executor/sortexec/topn.go b/pkg/executor/sortexec/topn.go
index 6c6f074064359..146daa01ebd9b 100644
--- a/pkg/executor/sortexec/topn.go
+++ b/pkg/executor/sortexec/topn.go
@@ -17,12 +17,19 @@ package sortexec
 import (
 	"container/heap"
 	"context"
+	"math/rand"
 	"slices"
+	"sync"
 	"sync/atomic"
 
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/executor/internal/exec"
 	plannercore "github.com/pingcap/tidb/pkg/planner/core"
+	"github.com/pingcap/tidb/pkg/sessionctx/variable"
+	"github.com/pingcap/tidb/pkg/util"
+	"github.com/pingcap/tidb/pkg/util/channel"
 	"github.com/pingcap/tidb/pkg/util/chunk"
+	"github.com/pingcap/tidb/pkg/util/disk"
 	"github.com/pingcap/tidb/pkg/util/memory"
 )
 
@@ -30,134 +37,237 @@ import (
 // Instead of sorting all the rows fetched from the table, it keeps the Top-N elements only in a heap to reduce memory usage.
 type TopNExec struct {
 	SortExec
-	Limit      *plannercore.PhysicalLimit
-	totalLimit uint64
+	Limit *plannercore.PhysicalLimit
+
+	// It's useful when spill is triggered, so that the fetcher can know when workers finish their work.
+	fetcherAndWorkerSyncer *sync.WaitGroup
+	resultChannel          chan rowWithError
+	chunkChannel           chan *chunk.Chunk
+
+	finishCh chan struct{}
 
 	chkHeap *topNChunkHeap
-}
 
-// topNChunkHeap implements heap.Interface.
-type topNChunkHeap struct {
-	*TopNExec
+	spillHelper *topNSpillHelper
+	spillAction *topNSpillAction
 
-	// rowChunks is the chunks to store row values.
-	rowChunks *chunk.List
-	// rowPointer store the chunk index and row index for each row.
-	rowPtrs []chunk.RowPtr
+	// Normally, the heap stays in memory after it has been built.
+	// However, other executors may trigger a TopN spill after the heap is built,
+	// and inMemoryThenSpillFlag is set to true in that case.
+	inMemoryThenSpillFlag bool
 
-	Idx int
-}
+	// The TopN executor has two stages:
+	// 1. Building the heap: all received rows are inserted into the heap.
+	// 2. Updating the heap: only rows smaller than the heap top can be inserted, and the heap top is dropped.
+	//
+	// These variables are only used for tests.
+	isSpillTriggeredInStage1ForTest bool
+	isSpillTriggeredInStage2ForTest bool
 
-// Less implement heap.Interface, but since we mantains a max heap,
-// this function returns true if row i is greater than row j.
-func (h *topNChunkHeap) Less(i, j int) bool {
-	rowI := h.rowChunks.GetRow(h.rowPtrs[i])
-	rowJ := h.rowChunks.GetRow(h.rowPtrs[j])
-	return h.greaterRow(rowI, rowJ)
+	Concurrency int
 }
 
-func (h *topNChunkHeap) greaterRow(rowI, rowJ chunk.Row) bool {
-	for i, colIdx := range h.keyColumns {
-		cmpFunc := h.keyCmpFuncs[i]
-		cmp := cmpFunc(rowI, colIdx, rowJ, colIdx)
-		if h.ByItems[i].Desc {
-			cmp = -cmp
+// Open implements the Executor Open interface.
+func (e *TopNExec) Open(ctx context.Context) error { + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + + e.fetched = &atomic.Bool{} + e.fetched.Store(false) + e.chkHeap = &topNChunkHeap{memTracker: e.memTracker} + e.chkHeap.idx = 0 + + e.finishCh = make(chan struct{}, 1) + e.resultChannel = make(chan rowWithError, e.MaxChunkSize()) + e.chunkChannel = make(chan *chunk.Chunk, e.Concurrency) + e.inMemoryThenSpillFlag = false + e.isSpillTriggeredInStage1ForTest = false + e.isSpillTriggeredInStage2ForTest = false + + if variable.EnableTmpStorageOnOOM.Load() { + e.diskTracker = disk.NewTracker(e.ID(), -1) + diskTracker := e.Ctx().GetSessionVars().StmtCtx.DiskTracker + if diskTracker != nil { + e.diskTracker.AttachTo(diskTracker) } - if cmp > 0 { - return true - } else if cmp < 0 { - return false + e.fetcherAndWorkerSyncer = &sync.WaitGroup{} + + workers := make([]*topNWorker, e.Concurrency) + for i := range workers { + chkHeap := &topNChunkHeap{} + // Offset of heap in worker should be 0, as we need to spill all data + chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, 0, e.greaterRow, e.RetFieldTypes()) + workers[i] = newTopNWorker(i, e.chunkChannel, e.fetcherAndWorkerSyncer, e.resultChannel, e.finishCh, e, chkHeap, e.memTracker) } + + e.spillHelper = newTopNSpillerHelper( + e, + e.finishCh, + e.resultChannel, + e.memTracker, + e.diskTracker, + exec.RetTypes(e), + workers, + e.Concurrency, + ) + e.spillAction = &topNSpillAction{spillHelper: e.spillHelper} + e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.spillAction) } - return false -} -func (h *topNChunkHeap) Len() int { - return len(h.rowPtrs) + return exec.Open(ctx, e.Children(0)) } -func (*topNChunkHeap) Push(any) { - // Should never be called. -} +// Close implements the Executor Close interface. +func (e *TopNExec) Close() error { + // `e.finishCh == nil` means that `Open` is not called. + if e.finishCh == nil { + return exec.Close(e.Children(0)) + } -func (h *topNChunkHeap) Pop() any { - h.rowPtrs = h.rowPtrs[:len(h.rowPtrs)-1] - // We don't need the popped value, return nil to avoid memory allocation. - return nil -} + close(e.finishCh) + if e.fetched.CompareAndSwap(false, true) { + close(e.resultChannel) + return exec.Close(e.Children(0)) + } -func (h *topNChunkHeap) Swap(i, j int) { - h.rowPtrs[i], h.rowPtrs[j] = h.rowPtrs[j], h.rowPtrs[i] -} + // Wait for the finish of all tasks + channel.Clear(e.resultChannel) -func (e *TopNExec) keyColumnsCompare(i, j chunk.RowPtr) int { - rowI := e.chkHeap.rowChunks.GetRow(i) - rowJ := e.chkHeap.rowChunks.GetRow(j) - return e.compareRow(rowI, rowJ) -} + e.chkHeap = nil + e.spillAction = nil -func (e *TopNExec) initPointers() { - e.chkHeap.rowPtrs = make([]chunk.RowPtr, 0, e.chkHeap.rowChunks.Len()) - e.memTracker.Consume(int64(8 * e.chkHeap.rowChunks.Len())) - for chkIdx := 0; chkIdx < e.chkHeap.rowChunks.NumChunks(); chkIdx++ { - rowChk := e.chkHeap.rowChunks.GetChunk(chkIdx) - for rowIdx := 0; rowIdx < rowChk.NumRows(); rowIdx++ { - e.chkHeap.rowPtrs = append(e.chkHeap.rowPtrs, chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)}) - } + if e.spillHelper != nil { + e.spillHelper.close() + e.spillHelper = nil } -} -// Open implements the Executor Open interface. 
-func (e *TopNExec) Open(ctx context.Context) error { - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) + } - e.fetched = &atomic.Bool{} - e.fetched.Store(false) - e.chkHeap = &topNChunkHeap{TopNExec: e} - e.chkHeap.Idx = 0 + return exec.Close(e.Children(0)) +} - return exec.Open(ctx, e.Children(0)) +func (e *TopNExec) greaterRow(rowI, rowJ chunk.Row) bool { + for i, colIdx := range e.keyColumns { + cmpFunc := e.keyCmpFuncs[i] + cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) + if e.ByItems[i].Desc { + cmp = -cmp + } + if cmp > 0 { + return true + } else if cmp < 0 { + return false + } + } + return false } // Next implements the Executor Next interface. +// +// The following picture shows the procedure of topn when spill is triggered. +/* +Spill Stage: + ┌─────────┐ + │ Child │ + └────▲────┘ + │ + Fetch + │ + ┌───────┴───────┐ + │ Chunk Fetcher │ + └───────┬───────┘ + │ + │ + ▼ + Check Spill──────►Spill Triggered─────────►Spill + │ │ + ▼ │ + Spill Not Triggered │ + │ │ + ▼ │ + Push Chunk◄─────────────────────────────────┘ + │ + ▼ + ┌────────────────►Channel◄───────────────────┐ + │ ▲ │ + │ │ │ + Fetch Fetch Fetch + │ │ │ + ┌────┴───┐ ┌───┴────┐ ┌───┴────┐ + │ Worker │ │ Worker │ ...... │ Worker │ + └────┬───┘ └───┬────┘ └───┬────┘ + │ │ │ + │ │ │ + │ ▼ │ + └───────────► Multi-way Merge◄───────────────┘ + │ + │ + ▼ + Output + +Restore Stage: + ┌────────┐ ┌────────┐ ┌────────┐ + │ Heap │ │ Heap │ ...... │ Heap │ + └────┬───┘ └───┬────┘ └───┬────┘ + │ │ │ + │ │ │ + │ ▼ │ + └───────────► Multi-way Merge◄───────────────┘ + │ + │ + ▼ + Output + +*/ func (e *TopNExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.fetched.CompareAndSwap(false, true) { - e.totalLimit = e.Limit.Offset + e.Limit.Count - e.chkHeap.Idx = int(e.Limit.Offset) - err := e.loadChunksUntilTotalLimit(ctx) - if err != nil { - return err - } - err = e.executeTopN(ctx) + err := e.fetchChunks(ctx) if err != nil { return err } } - if e.chkHeap.Idx >= len(e.chkHeap.rowPtrs) { - return nil - } + if !req.IsFull() { - numToAppend := min(len(e.chkHeap.rowPtrs)-e.chkHeap.Idx, req.RequiredRows()-req.NumRows()) - rows := make([]chunk.Row, numToAppend) - for index := 0; index < numToAppend; index++ { - rows[index] = e.chkHeap.rowChunks.GetRow(e.chkHeap.rowPtrs[e.chkHeap.Idx]) - e.chkHeap.Idx++ + numToAppend := req.RequiredRows() - req.NumRows() + for i := 0; i < numToAppend; i++ { + row, ok := <-e.resultChannel + if !ok || row.err != nil { + return row.err + } + req.AppendRow(row.row) + } + } + return nil +} + +func (e *TopNExec) fetchChunks(ctx context.Context) error { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(e.resultChannel, r) + close(e.resultChannel) } - req.AppendRows(rows) + }() + + err := e.loadChunksUntilTotalLimit(ctx) + if err != nil { + close(e.resultChannel) + return err } + go e.executeTopN(ctx) return nil } func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error { - e.chkHeap.rowChunks = chunk.NewList(exec.RetTypes(e), e.InitCap(), e.MaxChunkSize()) - e.chkHeap.rowChunks.GetMemTracker().AttachTo(e.memTracker) - e.chkHeap.rowChunks.GetMemTracker().SetLabel(memory.LabelForRowChunks) - for uint64(e.chkHeap.rowChunks.Len()) < e.totalLimit { + e.initCompareFuncs() + e.buildKeyColumns() + e.chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, int(e.Limit.Offset), e.greaterRow, e.RetFieldTypes()) + for uint64(e.chkHeap.rowChunks.Len()) < 
e.chkHeap.totalLimit { srcChk := exec.TryNewCacheChunk(e.Children(0)) // adjust required rows by total limit - srcChk.SetRequiredRows(int(e.totalLimit-uint64(e.chkHeap.rowChunks.Len())), e.MaxChunkSize()) + srcChk.SetRequiredRows(int(e.chkHeap.totalLimit-uint64(e.chkHeap.rowChunks.Len())), e.MaxChunkSize()) err := exec.Next(ctx, e.Children(0), srcChk) if err != nil { return err @@ -166,76 +276,365 @@ func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error { break } e.chkHeap.rowChunks.Add(srcChk) + if e.spillHelper.isSpillNeeded() { + e.isSpillTriggeredInStage1ForTest = true + break + } + + injectTopNRandomFail(1) } - e.initPointers() - e.initCompareFuncs() - e.buildKeyColumns() + + e.chkHeap.initPtrs() return nil } const topNCompactionFactor = 4 -func (e *TopNExec) executeTopN(ctx context.Context) error { - heap.Init(e.chkHeap) - for uint64(len(e.chkHeap.rowPtrs)) > e.totalLimit { - // The number of rows we loaded may exceeds total limit, remove greatest rows by Pop. - heap.Pop(e.chkHeap) +func (e *TopNExec) executeTopNWhenNoSpillTriggered(ctx context.Context) error { + if e.spillHelper.isSpillNeeded() { + e.isSpillTriggeredInStage2ForTest = true + return nil } + childRowChk := exec.TryNewCacheChunk(e.Children(0)) for { + if e.spillHelper.isSpillNeeded() { + e.isSpillTriggeredInStage2ForTest = true + return nil + } + err := exec.Next(ctx, e.Children(0), childRowChk) if err != nil { return err } + if childRowChk.NumRows() == 0 { break } - err = e.processChildChk(childRowChk) - if err != nil { - return err - } + + e.chkHeap.processChk(childRowChk) + if e.chkHeap.rowChunks.Len() > len(e.chkHeap.rowPtrs)*topNCompactionFactor { - err = e.doCompaction(e.chkHeap) + err = e.chkHeap.doCompaction(e) if err != nil { return err } } + injectTopNRandomFail(10) + } + + slices.SortFunc(e.chkHeap.rowPtrs, e.chkHeap.keyColumnsCompare) + return nil +} + +func (e *TopNExec) spillRemainingRowsWhenNeeded() error { + if e.spillHelper.isSpillTriggered() { + return e.spillHelper.spill() + } + return nil +} + +func (e *TopNExec) checkSpillAndExecute() error { + if e.spillHelper.isSpillNeeded() { + // Wait for the stop of all workers + e.fetcherAndWorkerSyncer.Wait() + return e.spillHelper.spill() } - slices.SortFunc(e.chkHeap.rowPtrs, e.keyColumnsCompare) return nil } -func (e *TopNExec) processChildChk(childRowChk *chunk.Chunk) error { - for i := 0; i < childRowChk.NumRows(); i++ { - heapMaxPtr := e.chkHeap.rowPtrs[0] - var heapMax, next chunk.Row - heapMax = e.chkHeap.rowChunks.GetRow(heapMaxPtr) - next = childRowChk.GetRow(i) - if e.chkHeap.greaterRow(heapMax, next) { - // Evict heap max, keep the next row. 
-			e.chkHeap.rowPtrs[0] = e.chkHeap.rowChunks.AppendRow(childRowChk.GetRow(i))
-			heap.Fix(e.chkHeap, 0)
+func (e *TopNExec) fetchChunksFromChild(ctx context.Context) {
+	defer func() {
+		if r := recover(); r != nil {
+			processPanicAndLog(e.resultChannel, r)
+		}
+
+		e.fetcherAndWorkerSyncer.Wait()
+		err := e.spillRemainingRowsWhenNeeded()
+		if err != nil {
+			e.resultChannel <- rowWithError{err: err}
+		}
+
+		close(e.chunkChannel)
+	}()
+
+	for {
+		chk := exec.TryNewCacheChunk(e.Children(0))
+		err := exec.Next(ctx, e.Children(0), chk)
+		if err != nil {
+			e.resultChannel <- rowWithError{err: err}
+			return
+		}
+
+		rowCount := chk.NumRows()
+		if rowCount == 0 {
+			break
+		}
+
+		e.fetcherAndWorkerSyncer.Add(1)
+		select {
+		case <-e.finishCh:
+			e.fetcherAndWorkerSyncer.Done()
+			return
+		case e.chunkChannel <- chk:
+		}
+
+		injectTopNRandomFail(10)
+
+		err = e.checkSpillAndExecute()
+		if err != nil {
+			e.resultChannel <- rowWithError{err: err}
+			return
+		}
 	}
+}
+
+// Spill the heap that is held by the TopN executor
+func (e *TopNExec) spillTopNExecHeap() error {
+	e.spillHelper.setInSpilling()
+	defer e.spillHelper.cond.Broadcast()
+	defer e.spillHelper.setNotSpilled()
+
+	err := e.spillHelper.spillHeap(e.chkHeap)
+	if err != nil {
+		return err
+	}
 	return nil
 }
 
-// doCompaction rebuild the chunks and row pointers to release memory.
-// If we don't do compaction, in a extreme case like the child data is already ascending sorted
-// but we want descending top N, then we will keep all data in memory.
-// But if data is distributed randomly, this function will be called log(n) times.
-func (e *TopNExec) doCompaction(chkHeap *topNChunkHeap) error {
-	newRowChunks := chunk.NewList(exec.RetTypes(e), e.InitCap(), e.MaxChunkSize())
-	newRowPtrs := make([]chunk.RowPtr, 0, chkHeap.rowChunks.Len())
-	for _, rowPtr := range chkHeap.rowPtrs {
-		newRowPtr := newRowChunks.AppendRow(chkHeap.rowChunks.GetRow(rowPtr))
-		newRowPtrs = append(newRowPtrs, newRowPtr)
+func (e *TopNExec) executeTopNWhenSpillTriggered(ctx context.Context) error {
+	// idx needs to be set to 0, as we need to spill all data
+	e.chkHeap.idx = 0
+	err := e.spillTopNExecHeap()
+	if err != nil {
+		return err
 	}
-	newRowChunks.GetMemTracker().SetLabel(memory.LabelForRowChunks)
-	e.memTracker.ReplaceChild(chkHeap.rowChunks.GetMemTracker(), newRowChunks.GetMemTracker())
-	chkHeap.rowChunks = newRowChunks
-	e.memTracker.Consume(int64(8 * (len(newRowPtrs) - len(chkHeap.rowPtrs))))
-	chkHeap.rowPtrs = newRowPtrs
+
+	// Used to wait for the chunk fetcher to finish
+	fetcherWaiter := util.WaitGroupWrapper{}
+	// Used to wait for all workers to finish
+	workersWaiter := util.WaitGroupWrapper{}
+
+	// Fetch chunks from child and put chunks into chunkChannel
+	fetcherWaiter.Run(func() {
+		e.fetchChunksFromChild(ctx)
+	})
+
+	for i := range e.spillHelper.workers {
+		worker := e.spillHelper.workers[i]
+		workersWaiter.Run(func() {
+			worker.run()
+		})
+	}
+
+	fetcherWaiter.Wait()
+	workersWaiter.Wait()
 	return nil
 }
+
+func (e *TopNExec) executeTopN(ctx context.Context) {
+	defer func() {
+		if r := recover(); r != nil {
+			processPanicAndLog(e.resultChannel, r)
+		}
+
+		close(e.resultChannel)
+	}()
+
+	heap.Init(e.chkHeap)
+	for uint64(len(e.chkHeap.rowPtrs)) > e.chkHeap.totalLimit {
+		// The number of rows we loaded may exceed the total limit, so remove the greatest rows by Pop.
+		heap.Pop(e.chkHeap)
+	}
+
+	if err := e.executeTopNWhenNoSpillTriggered(ctx); err != nil {
+		e.resultChannel <- rowWithError{err: err}
+		return
+	}
+
+	if e.spillHelper.isSpillNeeded() {
+		if err := e.executeTopNWhenSpillTriggered(ctx); err != nil {
+			e.resultChannel <- rowWithError{err: err}
+			return
+		}
+	}
+
+	e.generateTopNResults()
+}
+
+// Returns true when spill is triggered
+func (e *TopNExec) generateTopNResultsWhenNoSpillTriggered() bool {
+	rowPtrNum := len(e.chkHeap.rowPtrs)
+	for ; e.chkHeap.idx < rowPtrNum; e.chkHeap.idx++ {
+		if e.chkHeap.idx%10 == 0 && e.spillHelper.isSpillNeeded() {
+			return true
+		}
+		e.resultChannel <- rowWithError{row: e.chkHeap.rowChunks.GetRow(e.chkHeap.rowPtrs[e.chkHeap.idx])}
+	}
+	return false
+}
+
+func (e *TopNExec) generateResultWithMultiWayMerge(offset int64, limit int64) error {
+	multiWayMerge := newMultiWayMerger(&diskSource{sortedRowsInDisk: e.spillHelper.sortedRowsInDisk}, e.lessRow)
+
+	err := multiWayMerge.init()
+	if err != nil {
+		return err
+	}
+
+	outputRowNum := int64(0)
+	for {
+		if outputRowNum >= limit {
+			return nil
+		}
+
+		row, err := multiWayMerge.next()
+		if err != nil {
+			return err
+		}
+
+		if row.IsEmpty() {
+			return nil
+		}
+
+		if outputRowNum >= offset {
+			select {
+			case <-e.finishCh:
+				return nil
+			case e.resultChannel <- rowWithError{row: row}:
+			}
+		}
+		outputRowNum++
+		injectParallelSortRandomFail(1)
+	}
+}
+
+// GenerateTopNResultsWhenSpillOnlyOnce generates results when spill is triggered only once.
+// It's a public function as we need to test it in unit tests.
+func (e *TopNExec) GenerateTopNResultsWhenSpillOnlyOnce() error {
+	inDisk := e.spillHelper.sortedRowsInDisk[0]
+	chunkNum := inDisk.NumChunks()
+	skippedRowNum := uint64(0)
+	offset := e.Limit.Offset
+	for i := 0; i < chunkNum; i++ {
+		chk, err := inDisk.GetChunk(i)
+		if err != nil {
+			return err
+		}
+
+		injectTopNRandomFail(10)
+
+		rowNum := chk.NumRows()
+		j := 0
+		if !e.inMemoryThenSpillFlag {
+			// When e.inMemoryThenSpillFlag == false, we need to manually set j
+			// because rows that should be ignored before offset have also been
+			// spilled to disk.
+			if skippedRowNum < offset {
+				rowNumNeedSkip := offset - skippedRowNum
+				if rowNum <= int(rowNumNeedSkip) {
+					// All rows in this chunk should be skipped
+					skippedRowNum += uint64(rowNum)
+					continue
+				}
+				j += int(rowNumNeedSkip)
+				skippedRowNum += rowNumNeedSkip
+			}
+		}
+
+		for ; j < rowNum; j++ {
+			select {
+			case <-e.finishCh:
+				return nil
+			case e.resultChannel <- rowWithError{row: chk.GetRow(j)}:
+			}
+		}
+	}
+	return nil
+}
+
+func (e *TopNExec) generateTopNResultsWhenSpillTriggered() error {
+	inDiskNum := len(e.spillHelper.sortedRowsInDisk)
+	if inDiskNum == 0 {
+		panic("inDiskNum can't be 0 when we generate result with spill triggered")
+	}
+
+	if inDiskNum == 1 {
+		return e.GenerateTopNResultsWhenSpillOnlyOnce()
+	}
+	return e.generateResultWithMultiWayMerge(int64(e.Limit.Offset), int64(e.Limit.Offset+e.Limit.Count))
+}
+
+func (e *TopNExec) generateTopNResults() {
+	if !e.spillHelper.isSpillTriggered() {
+		if !e.generateTopNResultsWhenNoSpillTriggered() {
+			return
+		}
+
+		err := e.spillTopNExecHeap()
+		if err != nil {
+			e.resultChannel <- rowWithError{err: err}
+		}
+
+		e.inMemoryThenSpillFlag = true
+	}
+
+	err := e.generateTopNResultsWhenSpillTriggered()
+	if err != nil {
+		e.resultChannel <- rowWithError{err: err}
+	}
+}
+
+// IsSpillTriggeredForTest shows if spill is triggered, used for test.
+func (e *TopNExec) IsSpillTriggeredForTest() bool {
+	return e.spillHelper.isSpillTriggered()
+}
+
+// GetIsSpillTriggeredInStage1ForTest shows if spill is triggered in stage 1, only used for test.
+func (e *TopNExec) GetIsSpillTriggeredInStage1ForTest() bool {
+	return e.isSpillTriggeredInStage1ForTest
+}
+
+// GetIsSpillTriggeredInStage2ForTest shows if spill is triggered in stage 2, only used for test.
+func (e *TopNExec) GetIsSpillTriggeredInStage2ForTest() bool {
+	return e.isSpillTriggeredInStage2ForTest
+}
+
+// GetInMemoryThenSpillFlagForTest shows if results are in memory before they are spilled, only used for test.
+func (e *TopNExec) GetInMemoryThenSpillFlagForTest() bool {
+	return e.inMemoryThenSpillFlag
+}
+
+func injectTopNRandomFail(triggerFactor int32) {
+	failpoint.Inject("TopNRandomFail", func(val failpoint.Value) {
+		if val.(bool) {
+			randNum := rand.Int31n(10000)
+			if randNum < triggerFactor {
+				panic("panic is triggered by random fail")
+			}
+		}
+	})
+}
+
+// InitTopNExecForTest initializes TopN executors, only for test.
+func InitTopNExecForTest(topnExec *TopNExec, offset uint64, sortedRowsInDisk *chunk.DataInDiskByChunks) {
+	topnExec.inMemoryThenSpillFlag = false
+	topnExec.finishCh = make(chan struct{}, 1)
+	topnExec.resultChannel = make(chan rowWithError, 10000)
+	topnExec.Limit.Offset = offset
+	topnExec.spillHelper = &topNSpillHelper{}
+	topnExec.spillHelper.sortedRowsInDisk = []*chunk.DataInDiskByChunks{sortedRowsInDisk}
+}
+
+// GetResultForTest gets result, only for test.
+func GetResultForTest(topnExec *TopNExec) []int64 {
+	close(topnExec.resultChannel)
+	result := make([]int64, 0, 100)
+	for {
+		row, ok := <-topnExec.resultChannel
+		if !ok {
+			return result
+		}
+		result = append(result, row.row.GetInt64(0))
+	}
+}
diff --git a/pkg/executor/sortexec/topn_chunk_heap.go b/pkg/executor/sortexec/topn_chunk_heap.go
new file mode 100644
index 0000000000000..df19763e4693a
--- /dev/null
+++ b/pkg/executor/sortexec/topn_chunk_heap.go
@@ -0,0 +1,155 @@
+// Copyright 2024 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sortexec
+
+import (
+	"container/heap"
+
+	"github.com/pingcap/tidb/pkg/executor/internal/exec"
+	"github.com/pingcap/tidb/pkg/types"
+	"github.com/pingcap/tidb/pkg/util/chunk"
+	"github.com/pingcap/tidb/pkg/util/memory"
+)
+
+// topNChunkHeap implements heap.Interface.
+type topNChunkHeap struct {
+	compareRow func(chunk.Row, chunk.Row) int
+	greaterRow func(chunk.Row, chunk.Row) bool
+
+	// rowChunks is the list of chunks that stores the row values.
+	rowChunks *chunk.List
+	// rowPtrs stores the chunk index and row index for each row.
+	rowPtrs []chunk.RowPtr
+
+	isInitialized bool
+	isRowPtrsInit bool
+
+	memTracker *memory.Tracker
+
+	totalLimit uint64
+	idx        int
+
+	fieldTypes []*types.FieldType
+}
+
+func (h *topNChunkHeap) init(topnExec *TopNExec, memTracker *memory.Tracker, totalLimit uint64, idx int, greaterRow func(chunk.Row, chunk.Row) bool, fieldTypes []*types.FieldType) {
+	h.memTracker = memTracker
+
+	h.rowChunks = chunk.NewList(exec.RetTypes(topnExec), topnExec.InitCap(), topnExec.MaxChunkSize())
+	h.rowChunks.GetMemTracker().AttachTo(h.memTracker)
+	h.rowChunks.GetMemTracker().SetLabel(memory.LabelForRowChunks)
+
+	h.compareRow = topnExec.compareRow
+	h.greaterRow = greaterRow
+
+	h.totalLimit = totalLimit
+	h.idx = idx
+	h.isInitialized = true
+
+	h.fieldTypes = fieldTypes
+}
+
+func (h *topNChunkHeap) initPtrs() {
+	h.memTracker.Consume(int64(chunk.RowPtrSize * h.rowChunks.Len()))
+	h.initPtrsImpl()
+}
+
+func (h *topNChunkHeap) initPtrsImpl() {
+	h.rowPtrs = make([]chunk.RowPtr, 0, h.rowChunks.Len())
+	for chkIdx := 0; chkIdx < h.rowChunks.NumChunks(); chkIdx++ {
+		rowChk := h.rowChunks.GetChunk(chkIdx)
+		for rowIdx := 0; rowIdx < rowChk.NumRows(); rowIdx++ {
+			h.rowPtrs = append(h.rowPtrs, chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)})
+		}
+	}
+	h.isRowPtrsInit = true
+}
+
+func (h *topNChunkHeap) clear() {
+	h.rowChunks.Clear()
+	h.memTracker.Consume(int64(-chunk.RowPtrSize * len(h.rowPtrs)))
+	h.rowPtrs = nil
+	h.isRowPtrsInit = false
+	h.isInitialized = false
+	h.idx = 0
+}
+
+func (h *topNChunkHeap) update(heapMaxRow chunk.Row, newRow chunk.Row) {
+	if h.greaterRow(heapMaxRow, newRow) {
+		// Evict heap max, keep the next row.
+		h.rowPtrs[0] = h.rowChunks.AppendRow(newRow)
+		heap.Fix(h, 0)
+	}
+}
+
+func (h *topNChunkHeap) processChk(chk *chunk.Chunk) {
+	for i := 0; i < chk.NumRows(); i++ {
+		heapMaxRow := h.rowChunks.GetRow(h.rowPtrs[0])
+		newRow := chk.GetRow(i)
+		h.update(heapMaxRow, newRow)
+	}
+}
+
+// doCompaction rebuilds the chunks and row pointers to release memory.
+// If we don't do compaction, in an extreme case where the child data is already sorted ascending
+// but we want a descending Top-N, we would keep all data in memory.
+// If data is distributed randomly, this function is called about log(n) times.
+func (h *topNChunkHeap) doCompaction(topnExec *TopNExec) error {
+	newRowChunks := chunk.NewList(exec.RetTypes(topnExec), topnExec.InitCap(), topnExec.MaxChunkSize())
+	newRowPtrs := make([]chunk.RowPtr, 0, h.rowChunks.Len())
+	for _, rowPtr := range h.rowPtrs {
+		newRowPtr := newRowChunks.AppendRow(h.rowChunks.GetRow(rowPtr))
+		newRowPtrs = append(newRowPtrs, newRowPtr)
+	}
+	newRowChunks.GetMemTracker().SetLabel(memory.LabelForRowChunks)
+	h.memTracker.ReplaceChild(h.rowChunks.GetMemTracker(), newRowChunks.GetMemTracker())
+	h.rowChunks = newRowChunks
+
+	h.memTracker.Consume(int64(chunk.RowPtrSize * (len(newRowPtrs) - len(h.rowPtrs))))
+	h.rowPtrs = newRowPtrs
+	return nil
+}
+
+func (h *topNChunkHeap) keyColumnsCompare(i, j chunk.RowPtr) int {
+	rowI := h.rowChunks.GetRow(i)
+	rowJ := h.rowChunks.GetRow(j)
+	return h.compareRow(rowI, rowJ)
+}
+
+// Less implements heap.Interface. Since we maintain a max heap,
+// this function returns true if row i is greater than row j.
+func (h *topNChunkHeap) Less(i, j int) bool { + rowI := h.rowChunks.GetRow(h.rowPtrs[i]) + rowJ := h.rowChunks.GetRow(h.rowPtrs[j]) + return h.greaterRow(rowI, rowJ) +} + +func (h *topNChunkHeap) Len() int { + return len(h.rowPtrs) +} + +func (*topNChunkHeap) Push(any) { + // Should never be called. +} + +func (h *topNChunkHeap) Pop() any { + h.rowPtrs = h.rowPtrs[:len(h.rowPtrs)-1] + // We don't need the popped value, return nil to avoid memory allocation. + return nil +} + +func (h *topNChunkHeap) Swap(i, j int) { + h.rowPtrs[i], h.rowPtrs[j] = h.rowPtrs[j], h.rowPtrs[i] +} diff --git a/pkg/executor/sortexec/topn_spill.go b/pkg/executor/sortexec/topn_spill.go new file mode 100644 index 0000000000000..b3cf5ff20c0e8 --- /dev/null +++ b/pkg/executor/sortexec/topn_spill.go @@ -0,0 +1,265 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortexec + +import ( + "slices" + "sync" + "sync/atomic" + + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +type topNSpillHelper struct { + cond *sync.Cond + spillStatus int + sortedRowsInDisk []*chunk.DataInDiskByChunks + + finishCh <-chan struct{} + errOutputChan chan<- rowWithError + + memTracker *memory.Tracker + diskTracker *disk.Tracker + + fieldTypes []*types.FieldType + tmpSpillChunksChan chan *chunk.Chunk + + workers []*topNWorker + + bytesConsumed atomic.Int64 + bytesLimit atomic.Int64 +} + +func newTopNSpillerHelper( + topn *TopNExec, + finishCh <-chan struct{}, + errOutputChan chan<- rowWithError, + memTracker *memory.Tracker, + diskTracker *disk.Tracker, + fieldTypes []*types.FieldType, + workers []*topNWorker, + concurrencyNum int, +) *topNSpillHelper { + lock := sync.Mutex{} + tmpSpillChunksChan := make(chan *chunk.Chunk, concurrencyNum) + for i := 0; i < len(workers); i++ { + tmpSpillChunksChan <- exec.TryNewCacheChunk(topn.Children(0)) + } + + return &topNSpillHelper{ + cond: sync.NewCond(&lock), + spillStatus: notSpilled, + sortedRowsInDisk: make([]*chunk.DataInDiskByChunks, 0), + finishCh: finishCh, + errOutputChan: errOutputChan, + memTracker: memTracker, + diskTracker: diskTracker, + fieldTypes: fieldTypes, + tmpSpillChunksChan: tmpSpillChunksChan, + workers: workers, + bytesConsumed: atomic.Int64{}, + bytesLimit: atomic.Int64{}, + } +} + +func (t *topNSpillHelper) close() { + for _, inDisk := range t.sortedRowsInDisk { + inDisk.Close() + } +} + +func (t *topNSpillHelper) isNotSpilledNoLock() bool { + return t.spillStatus == notSpilled +} + +func (t *topNSpillHelper) isInSpillingNoLock() bool { + return t.spillStatus == inSpilling +} + +func (t *topNSpillHelper) isSpillNeeded() bool { + t.cond.L.Lock() + defer t.cond.L.Unlock() + return t.spillStatus == needSpill +} + +func (t 
*topNSpillHelper) isSpillTriggered() bool { + t.cond.L.Lock() + defer t.cond.L.Unlock() + return len(t.sortedRowsInDisk) > 0 +} + +func (t *topNSpillHelper) setInSpilling() { + t.cond.L.Lock() + defer t.cond.L.Unlock() + t.spillStatus = inSpilling + logutil.BgLogger().Info(spillInfo, zap.Int64("consumed", t.bytesConsumed.Load()), zap.Int64("quota", t.bytesLimit.Load())) +} + +func (t *topNSpillHelper) setNotSpilled() { + t.cond.L.Lock() + defer t.cond.L.Unlock() + t.spillStatus = notSpilled +} + +func (t *topNSpillHelper) setNeedSpillNoLock() { + t.spillStatus = needSpill +} + +func (t *topNSpillHelper) addInDisk(inDisk *chunk.DataInDiskByChunks) { + t.cond.L.Lock() + defer t.cond.L.Unlock() + t.sortedRowsInDisk = append(t.sortedRowsInDisk, inDisk) +} + +func (*topNSpillHelper) spillTmpSpillChunk(inDisk *chunk.DataInDiskByChunks, tmpSpillChunk *chunk.Chunk) error { + err := inDisk.Add(tmpSpillChunk) + if err != nil { + return err + } + tmpSpillChunk.Reset() + return nil +} + +func (t *topNSpillHelper) spill() (err error) { + defer func() { + if r := recover(); r != nil { + err = util.GetRecoverError(r) + } + }() + + select { + case <-t.finishCh: + return nil + default: + } + + t.setInSpilling() + defer t.cond.Broadcast() + defer t.setNotSpilled() + + workerNum := len(t.workers) + errChan := make(chan error, workerNum) + workerWaiter := &sync.WaitGroup{} + workerWaiter.Add(workerNum) + for i := 0; i < workerNum; i++ { + go func(idx int) { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(t.errOutputChan, r) + } + workerWaiter.Done() + }() + + spillErr := t.spillHeap(t.workers[idx].chkHeap) + if spillErr != nil { + errChan <- spillErr + } + }(i) + } + + workerWaiter.Wait() + close(errChan) + + // Fetch only one error is enough + spillErr := <-errChan + if spillErr != nil { + return spillErr + } + return nil +} + +func (t *topNSpillHelper) spillHeap(chkHeap *topNChunkHeap) error { + if chkHeap.Len() <= 0 && chkHeap.rowChunks.Len() <= 0 { + return nil + } + + if !chkHeap.isRowPtrsInit { + // Do not consume memory here, as it will hang + chkHeap.initPtrsImpl() + } + slices.SortFunc(chkHeap.rowPtrs, chkHeap.keyColumnsCompare) + + tmpSpillChunk := <-t.tmpSpillChunksChan + tmpSpillChunk.Reset() + defer func() { + t.tmpSpillChunksChan <- tmpSpillChunk + }() + + inDisk := chunk.NewDataInDiskByChunks(t.fieldTypes) + inDisk.GetDiskTracker().AttachTo(t.diskTracker) + + rowPtrNum := chkHeap.Len() + for ; chkHeap.idx < rowPtrNum; chkHeap.idx++ { + if tmpSpillChunk.IsFull() { + err := t.spillTmpSpillChunk(inDisk, tmpSpillChunk) + if err != nil { + return err + } + } + tmpSpillChunk.AppendRow(chkHeap.rowChunks.GetRow(chkHeap.rowPtrs[chkHeap.idx])) + } + + // Spill remaining rows in tmpSpillChunk + if tmpSpillChunk.NumRows() > 0 { + err := t.spillTmpSpillChunk(inDisk, tmpSpillChunk) + if err != nil { + return err + } + } + + t.addInDisk(inDisk) + injectTopNRandomFail(200) + + chkHeap.clear() + return nil +} + +type topNSpillAction struct { + memory.BaseOOMAction + spillHelper *topNSpillHelper +} + +// GetPriority get the priority of the Action. 
+func (*topNSpillAction) GetPriority() int64 { + return memory.DefSpillPriority +} + +func (t *topNSpillAction) Action(tracker *memory.Tracker) { + t.spillHelper.cond.L.Lock() + defer t.spillHelper.cond.L.Unlock() + + for t.spillHelper.isInSpillingNoLock() { + t.spillHelper.cond.Wait() + } + + hasEnoughData := hasEnoughDataToSpill(t.spillHelper.memTracker, tracker) + if tracker.CheckExceed() && t.spillHelper.isNotSpilledNoLock() && hasEnoughData { + t.spillHelper.setNeedSpillNoLock() + t.spillHelper.bytesConsumed.Store(tracker.BytesConsumed()) + t.spillHelper.bytesLimit.Store(tracker.GetBytesLimit()) + return + } + + if tracker.CheckExceed() && !hasEnoughData { + t.GetFallback() + } +} diff --git a/pkg/executor/sortexec/topn_spill_test.go b/pkg/executor/sortexec/topn_spill_test.go new file mode 100644 index 0000000000000..3c8ddda4ac2c8 --- /dev/null +++ b/pkg/executor/sortexec/topn_spill_test.go @@ -0,0 +1,480 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortexec_test + +import ( + "context" + "math/rand" + "sync" + "testing" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/internal/testutil" + "github.com/pingcap/tidb/pkg/executor/sortexec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/stretchr/testify/require" +) + +var totalRowNum = 10000 +var noSpillCaseHardLimit = hardLimit2 +var spillCase1HardLimit = hardLimit1 +var spillCase2HardLimit = hardLimit1 +var inMemoryThenSpillHardLimit = hardLimit1 * 2 + +// Test is successful if there is no hang +func executeTopNInFailpoint(t *testing.T, exe *sortexec.TopNExec, hardLimit int64, tracker *memory.Tracker) { + tmpCtx := context.Background() + err := exe.Open(tmpCtx) + require.NoError(t, err) + + goRoutineWaiter := sync.WaitGroup{} + goRoutineWaiter.Add(1) + defer goRoutineWaiter.Wait() + + once := sync.Once{} + + go func() { + time.Sleep(time.Duration(rand.Int31n(300)) * time.Millisecond) + once.Do(func() { + exe.Close() + }) + goRoutineWaiter.Done() + }() + + chk := exec.NewFirstChunk(exe) + for i := 0; i >= 0; i++ { + err := exe.Next(tmpCtx, chk) + if err != nil { + once.Do(func() { + err = exe.Close() + require.Equal(t, nil, err) + }) + break + } + if chk.NumRows() == 0 { + break + } + + if i == 10 && hardLimit > 0 { + // Trigger the spill + tracker.Consume(hardLimit) + tracker.Consume(-hardLimit) + } + } + once.Do(func() { + err = exe.Close() + require.Equal(t, nil, err) + }) +} + +func initTopNNoSpillCaseParams( + ctx *mock.Context, + dataSource *testutil.MockDataSource, + topNCase *testutil.SortCase, + totalRowNum int, + 
count *uint64, + offset *uint64, + exe **sortexec.TopNExec, +) { + ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, noSpillCaseHardLimit) + ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker) + + *count = uint64(totalRowNum / 3) + *offset = uint64(totalRowNum / 10) + + if exe != nil { + *exe = buildTopNExec(topNCase, dataSource, *offset, *count) + } +} + +func initTopNSpillCase1Params( + ctx *mock.Context, + dataSource *testutil.MockDataSource, + topNCase *testutil.SortCase, + totalRowNum int, + count *uint64, + offset *uint64, + exe **sortexec.TopNExec, +) { + ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, spillCase1HardLimit) + ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker) + + *count = uint64(totalRowNum - totalRowNum/10) + *offset = uint64(totalRowNum / 10) + + if exe != nil { + *exe = buildTopNExec(topNCase, dataSource, *offset, *count) + } +} + +func initTopNSpillCase2Params( + ctx *mock.Context, + dataSource *testutil.MockDataSource, + topNCase *testutil.SortCase, + totalRowNum int, + count *uint64, + offset *uint64, + exe **sortexec.TopNExec, +) { + ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, spillCase2HardLimit) + ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker) + + *count = uint64(totalRowNum / 5) + *offset = *count / 5 + + if exe != nil { + *exe = buildTopNExec(topNCase, dataSource, *offset, *count) + } +} + +func initTopNInMemoryThenSpillParams( + ctx *mock.Context, + dataSource *testutil.MockDataSource, + topNCase *testutil.SortCase, + totalRowNum int, + count *uint64, + offset *uint64, + exe **sortexec.TopNExec, +) { + ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, inMemoryThenSpillHardLimit) + ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker) + + *count = uint64(totalRowNum / 5) + *offset = *count / 5 + + if exe != nil { + *exe = buildTopNExec(topNCase, dataSource, *offset, *count) + } +} + +func checkTopNCorrectness(schema *expression.Schema, exe *sortexec.TopNExec, dataSource *testutil.MockDataSource, resultChunks []*chunk.Chunk, offset uint64, count uint64) bool { + keyColumns, keyCmpFuncs, byItemsDesc := exe.GetSortMetaForTest() + checker := newResultChecker(schema, keyColumns, keyCmpFuncs, byItemsDesc, dataSource.GenData) + return checker.check(resultChunks, int64(offset), int64(count)) +} + +func buildTopNExec(sortCase *testutil.SortCase, dataSource *testutil.MockDataSource, offset uint64, count uint64) *sortexec.TopNExec { + dataSource.PrepareChunks() + sortExec := sortexec.SortExec{ + BaseExecutor: exec.NewBaseExecutor(sortCase.Ctx, dataSource.Schema(), 0, dataSource), + ByItems: make([]*plannerutil.ByItems, 0, len(sortCase.OrderByIdx)), + ExecSchema: dataSource.Schema(), + } + + for _, idx := range sortCase.OrderByIdx { + sortExec.ByItems = append(sortExec.ByItems, &plannerutil.ByItems{Expr: sortCase.Columns()[idx]}) + } + + topNexec := &sortexec.TopNExec{ + SortExec: sortExec, + Limit: &plannercore.PhysicalLimit{Offset: offset, Count: count}, + Concurrency: 5, + } + + return topNexec +} + +func executeTopNExecutor(t *testing.T, exe *sortexec.TopNExec) []*chunk.Chunk { + tmpCtx := context.Background() + err := exe.Open(tmpCtx) + require.NoError(t, err) + + resultChunks := make([]*chunk.Chunk, 0) + chk := exec.NewFirstChunk(exe) + for { + err = exe.Next(tmpCtx, chk) + require.NoError(t, err) + if chk.NumRows() == 0 { + break 
+		}
+		resultChunks = append(resultChunks, chk.CopyConstruct())
+	}
+	return resultChunks
+}
+
+func executeTopNAndManuallyTriggerSpill(t *testing.T, exe *sortexec.TopNExec, hardLimit int64, tracker *memory.Tracker) []*chunk.Chunk {
+	tmpCtx := context.Background()
+	err := exe.Open(tmpCtx)
+	require.NoError(t, err)
+
+	resultChunks := make([]*chunk.Chunk, 0)
+	chk := exec.NewFirstChunk(exe)
+	for i := 0; i >= 0; i++ {
+		err = exe.Next(tmpCtx, chk)
+		require.NoError(t, err)
+
+		if i == 10 {
+			// Trigger the spill
+			tracker.Consume(hardLimit)
+			tracker.Consume(-hardLimit)
+		}
+
+		if chk.NumRows() == 0 {
+			break
+		}
+		resultChunks = append(resultChunks, chk.CopyConstruct())
+	}
+	return resultChunks
+}
+
+// No spill will be triggered in this test
+func topNNoSpillCase(t *testing.T, exe *sortexec.TopNExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource, offset uint64, count uint64) {
+	if exe == nil {
+		exe = buildTopNExec(sortCase, dataSource, offset, count)
+	}
+	dataSource.PrepareChunks()
+	resultChunks := executeTopNExecutor(t, exe)
+
+	require.False(t, exe.IsSpillTriggeredForTest())
+
+	err := exe.Close()
+	require.NoError(t, err)
+
+	require.True(t, checkTopNCorrectness(schema, exe, dataSource, resultChunks, offset, count))
+}
+
+// The TopN executor has two stages:
+// 1. Building the heap: all received rows are inserted into the heap.
+// 2. Updating the heap: only rows smaller than the heap top can be inserted, and the heap top is dropped.
+//
+// Case 1 means that spill is triggered in stage 1
+func topNSpillCase1(t *testing.T, exe *sortexec.TopNExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource, offset uint64, count uint64) {
+	if exe == nil {
+		exe = buildTopNExec(sortCase, dataSource, offset, count)
+	}
+	dataSource.PrepareChunks()
+	resultChunks := executeTopNExecutor(t, exe)
+
+	require.True(t, exe.IsSpillTriggeredForTest())
+	require.True(t, exe.GetIsSpillTriggeredInStage1ForTest())
+	require.False(t, exe.GetInMemoryThenSpillFlagForTest())
+
+	err := exe.Close()
+	require.NoError(t, err)
+
+	require.True(t, checkTopNCorrectness(schema, exe, dataSource, resultChunks, offset, count))
+}
+
+// Case 2 means that spill is triggered in stage 2
+func topNSpillCase2(t *testing.T, exe *sortexec.TopNExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource, offset uint64, count uint64) {
+	if exe == nil {
+		exe = buildTopNExec(sortCase, dataSource, offset, count)
+	}
+	dataSource.PrepareChunks()
+	resultChunks := executeTopNExecutor(t, exe)
+
+	require.True(t, exe.IsSpillTriggeredForTest())
+	require.False(t, exe.GetIsSpillTriggeredInStage1ForTest())
+	require.True(t, exe.GetIsSpillTriggeredInStage2ForTest())
+	require.False(t, exe.GetInMemoryThenSpillFlagForTest())
+
+	err := exe.Close()
+	require.NoError(t, err)
+
+	require.True(t, checkTopNCorrectness(schema, exe, dataSource, resultChunks, offset, count))
+}
+
+// All sorted rows are first kept in memory; the spill is then triggered after some chunks have been fetched
+func topNInMemoryThenSpillCase(t *testing.T, ctx *mock.Context, exe *sortexec.TopNExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource, offset uint64, count uint64) {
+	if exe == nil {
+		exe = buildTopNExec(sortCase, dataSource, offset, count)
+	}
+	dataSource.PrepareChunks()
+	resultChunks := executeTopNAndManuallyTriggerSpill(t, exe, hardLimit1*2, 
ctx.GetSessionVars().StmtCtx.MemTracker) + + require.True(t, exe.IsSpillTriggeredForTest()) + require.False(t, exe.GetIsSpillTriggeredInStage1ForTest()) + require.False(t, exe.GetIsSpillTriggeredInStage2ForTest()) + require.True(t, exe.GetInMemoryThenSpillFlagForTest()) + + err := exe.Close() + require.NoError(t, err) + + require.True(t, checkTopNCorrectness(schema, exe, dataSource, resultChunks, offset, count)) +} + +func topNFailPointTest(t *testing.T, exe *sortexec.TopNExec, sortCase *testutil.SortCase, dataSource *testutil.MockDataSource, offset uint64, count uint64, hardLimit int64, tracker *memory.Tracker) { + if exe == nil { + exe = buildTopNExec(sortCase, dataSource, offset, count) + } + dataSource.PrepareChunks() + executeTopNInFailpoint(t, exe, hardLimit, tracker) +} + +const spilledChunkMaxSize = 32 + +func createAndInitDataInDiskByChunks(spilledRowNum uint64) *chunk.DataInDiskByChunks { + fieldType := types.FieldType{} + fieldType.SetType(mysql.TypeLonglong) + inDisk := chunk.NewDataInDiskByChunks([]*types.FieldType{&fieldType}) + var spilledChunk *chunk.Chunk + for i := uint64(0); i < spilledRowNum; i++ { + if i%spilledChunkMaxSize == 0 { + if spilledChunk != nil { + inDisk.Add(spilledChunk) + } + spilledChunk = chunk.NewChunkWithCapacity([]*types.FieldType{&fieldType}, spilledChunkMaxSize) + } + spilledChunk.AppendInt64(0, int64(i)) + } + inDisk.Add(spilledChunk) + return inDisk +} + +func testImpl(t *testing.T, topnExec *sortexec.TopNExec, inDisk *chunk.DataInDiskByChunks, totalRowNumInDisk uint64, offset uint64) { + sortexec.InitTopNExecForTest(topnExec, offset, inDisk) + topnExec.GenerateTopNResultsWhenSpillOnlyOnce() + result := sortexec.GetResultForTest(topnExec) + require.Equal(t, int(totalRowNumInDisk-offset), len(result)) + for i := range result { + require.Equal(t, int64(i+int(offset)), result[i]) + } +} + +func oneChunkInDiskCase(t *testing.T, topnExec *sortexec.TopNExec) { + rowNumInDisk := uint64(spilledChunkMaxSize) + inDisk := createAndInitDataInDiskByChunks(rowNumInDisk) + + testImpl(t, topnExec, inDisk, rowNumInDisk, 0) + testImpl(t, topnExec, inDisk, rowNumInDisk, uint64(spilledChunkMaxSize-15)) + testImpl(t, topnExec, inDisk, rowNumInDisk, rowNumInDisk-1) + testImpl(t, topnExec, inDisk, rowNumInDisk, rowNumInDisk) +} + +func severalChunksInDiskCase(t *testing.T, topnExec *sortexec.TopNExec) { + rowNumInDisk := uint64(spilledChunkMaxSize*3 + 10) + inDisk := createAndInitDataInDiskByChunks(rowNumInDisk) + + testImpl(t, topnExec, inDisk, rowNumInDisk, 0) + testImpl(t, topnExec, inDisk, rowNumInDisk, spilledChunkMaxSize-15) + testImpl(t, topnExec, inDisk, rowNumInDisk, spilledChunkMaxSize*2+10) + testImpl(t, topnExec, inDisk, rowNumInDisk, rowNumInDisk-1) + testImpl(t, topnExec, inDisk, rowNumInDisk, rowNumInDisk) +} + +func TestGenerateTopNResultsWhenSpillOnlyOnce(t *testing.T) { + topnExec := &sortexec.TopNExec{} + topnExec.Limit = &plannercore.PhysicalLimit{} + + oneChunkInDiskCase(t, topnExec) + severalChunksInDiskCase(t, topnExec) +} + +func TestTopNSpillDisk(t *testing.T) { + sortexec.SetSmallSpillChunkSizeForTest() + ctx := mock.NewContext() + topNCase := &testutil.SortCase{Rows: totalRowNum, OrderByIdx: []int{0, 1}, Ndvs: []int{0, 0}, Ctx: ctx} + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/sortexec/SlowSomeWorkers", `return(true)`)) + + ctx.GetSessionVars().InitChunkSize = 32 + ctx.GetSessionVars().MaxChunkSize = 32 + ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2) + 
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1) + ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker) + + offset := uint64(totalRowNum / 10) + count := uint64(totalRowNum / 3) + + var exe *sortexec.TopNExec + schema := expression.NewSchema(topNCase.Columns()...) + dataSource := buildDataSource(topNCase, schema) + initTopNNoSpillCaseParams(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 20; i++ { + topNNoSpillCase(t, nil, topNCase, schema, dataSource, 0, count) + topNNoSpillCase(t, exe, topNCase, schema, dataSource, offset, count) + } + + initTopNSpillCase1Params(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 20; i++ { + topNSpillCase1(t, nil, topNCase, schema, dataSource, 0, count) + topNSpillCase1(t, exe, topNCase, schema, dataSource, offset, count) + } + + initTopNSpillCase2Params(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 20; i++ { + topNSpillCase2(t, nil, topNCase, schema, dataSource, 0, count) + topNSpillCase2(t, exe, topNCase, schema, dataSource, offset, count) + } + + initTopNInMemoryThenSpillParams(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 20; i++ { + topNInMemoryThenSpillCase(t, ctx, nil, topNCase, schema, dataSource, 0, count) + topNInMemoryThenSpillCase(t, ctx, exe, topNCase, schema, dataSource, offset, count) + } + + failpoint.Disable("github.com/pingcap/tidb/pkg/executor/sortexec/SlowSomeWorkers") +} + +func TestTopNSpillDiskFailpoint(t *testing.T) { + sortexec.SetSmallSpillChunkSizeForTest() + ctx := mock.NewContext() + topNCase := &testutil.SortCase{Rows: totalRowNum, OrderByIdx: []int{0, 1}, Ndvs: []int{0, 0}, Ctx: ctx} + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/sortexec/SlowSomeWorkers", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/sortexec/TopNRandomFail", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/sortexec/ParallelSortRandomFail", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/chunk/ChunkInDiskError", `return(true)`)) + + ctx.GetSessionVars().InitChunkSize = 32 + ctx.GetSessionVars().MaxChunkSize = 32 + ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit1) + ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1) + ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker) + + offset := uint64(totalRowNum / 10) + count := uint64(totalRowNum / 3) + + var exe *sortexec.TopNExec + schema := expression.NewSchema(topNCase.Columns()...) 
+ dataSource := buildDataSource(topNCase, schema) + initTopNNoSpillCaseParams(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 10; i++ { + topNFailPointTest(t, nil, topNCase, dataSource, 0, count, 0, ctx.GetSessionVars().MemTracker) + topNFailPointTest(t, exe, topNCase, dataSource, offset, count, 0, ctx.GetSessionVars().MemTracker) + } + + initTopNSpillCase1Params(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 10; i++ { + topNFailPointTest(t, nil, topNCase, dataSource, 0, count, 0, ctx.GetSessionVars().MemTracker) + topNFailPointTest(t, exe, topNCase, dataSource, offset, count, 0, ctx.GetSessionVars().MemTracker) + } + + initTopNSpillCase2Params(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 10; i++ { + topNFailPointTest(t, nil, topNCase, dataSource, 0, count, 0, ctx.GetSessionVars().MemTracker) + topNFailPointTest(t, exe, topNCase, dataSource, offset, count, 0, ctx.GetSessionVars().MemTracker) + } + + initTopNInMemoryThenSpillParams(ctx, dataSource, topNCase, totalRowNum, &count, &offset, &exe) + for i := 0; i < 10; i++ { + topNFailPointTest(t, nil, topNCase, dataSource, 0, count, inMemoryThenSpillHardLimit, ctx.GetSessionVars().MemTracker) + topNFailPointTest(t, exe, topNCase, dataSource, offset, count, inMemoryThenSpillHardLimit, ctx.GetSessionVars().MemTracker) + } + + failpoint.Disable("github.com/pingcap/tidb/pkg/executor/sortexec/SlowSomeWorkers") + failpoint.Disable("github.com/pingcap/tidb/pkg/executor/sortexec/TopNRandomFail") + failpoint.Disable("github.com/pingcap/tidb/pkg/executor/sortexec/ParallelSortRandomFail") + failpoint.Disable("github.com/pingcap/tidb/pkg/util/chunk/ChunkInDiskError") +} diff --git a/pkg/executor/sortexec/topn_worker.go b/pkg/executor/sortexec/topn_worker.go new file mode 100644 index 0000000000000..b62844aba28ca --- /dev/null +++ b/pkg/executor/sortexec/topn_worker.go @@ -0,0 +1,127 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sortexec + +import ( + "container/heap" + "math/rand" + "sync" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +// topNWorker is used only when topn spill is triggered +type topNWorker struct { + workerIDForTest int + + chunkChannel <-chan *chunk.Chunk + fetcherAndWorkerSyncer *sync.WaitGroup + errOutputChan chan<- rowWithError + finishChan <-chan struct{} + + topn *TopNExec + chkHeap *topNChunkHeap + memTracker *memory.Tracker +} + +func newTopNWorker( + idForTest int, + chunkChannel <-chan *chunk.Chunk, + fetcherAndWorkerSyncer *sync.WaitGroup, + errOutputChan chan<- rowWithError, + finishChan <-chan struct{}, + topn *TopNExec, + chkHeap *topNChunkHeap, + memTracker *memory.Tracker) *topNWorker { + return &topNWorker{ + workerIDForTest: idForTest, + chunkChannel: chunkChannel, + fetcherAndWorkerSyncer: fetcherAndWorkerSyncer, + errOutputChan: errOutputChan, + finishChan: finishChan, + chkHeap: chkHeap, + topn: topn, + memTracker: memTracker, + } +} + +func (t *topNWorker) fetchChunksAndProcess() { + // Offset of heap in worker should be 0, as we need to spill all data + t.chkHeap.init(t.topn, t.memTracker, t.topn.Limit.Offset+t.topn.Limit.Count, 0, t.topn.greaterRow, t.topn.RetFieldTypes()) + for t.fetchChunksAndProcessImpl() { + } +} + +func (t *topNWorker) fetchChunksAndProcessImpl() bool { + select { + case <-t.finishChan: + return false + case chk, ok := <-t.chunkChannel: + if !ok { + return false + } + defer func() { + t.fetcherAndWorkerSyncer.Done() + }() + + t.injectFailPointForTopNWorker(3) + + if uint64(t.chkHeap.rowChunks.Len()) < t.chkHeap.totalLimit { + if !t.chkHeap.isInitialized { + t.chkHeap.init(t.topn, t.memTracker, t.topn.Limit.Offset+t.topn.Limit.Count, 0, t.topn.greaterRow, t.topn.RetFieldTypes()) + } + t.chkHeap.rowChunks.Add(chk) + } else { + if !t.chkHeap.isRowPtrsInit { + t.chkHeap.initPtrs() + heap.Init(t.chkHeap) + } + t.chkHeap.processChk(chk) + } + } + return true +} + +func (t *topNWorker) run() { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(t.errOutputChan, r) + } + + // Consume all chunks to avoid hang of fetcher + for range t.chunkChannel { + t.fetcherAndWorkerSyncer.Done() + } + }() + + t.fetchChunksAndProcess() +} + +func (t *topNWorker) injectFailPointForTopNWorker(triggerFactor int32) { + injectTopNRandomFail(triggerFactor) + failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { + if val.(bool) { + if t.workerIDForTest%2 == 0 { + randNum := rand.Int31n(10000) + if randNum < 10 { + time.Sleep(1 * time.Millisecond) + } + } + } + }) +} diff --git a/pkg/util/chunk/list.go b/pkg/util/chunk/list.go index f8246850ecd6b..32d5313e03407 100644 --- a/pkg/util/chunk/list.go +++ b/pkg/util/chunk/list.go @@ -15,6 +15,8 @@ package chunk import ( + "unsafe" + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/memory" @@ -33,6 +35,9 @@ type List struct { consumedIdx int // chunk index in "chunks", has been consumed. } +// RowPtrSize shows the size of RowPtr +const RowPtrSize = int(unsafe.Sizeof(RowPtr{})) + // RowPtr is used to get a row from a list. // It is only valid for the list that returns it. type RowPtr struct {
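
For reference, the two-stage procedure that the TopNExec comments above describe (stage 1 builds a max heap until it holds Offset+Count rows; stage 2 replaces the heap max whenever a smaller row arrives) can be illustrated by the following minimal, self-contained Go sketch over plain ints. It is illustrative only: maxHeap and topN are hypothetical names, and the patch itself works over chunk.RowPtr with row comparators, spill checks, and memory tracking.

package main

import (
	"container/heap"
	"fmt"
	"sort"
)

// maxHeap keeps the smallest totalLimit values seen so far; being a max
// heap, its root is the largest of them and thus the eviction candidate.
type maxHeap []int

func (h maxHeap) Len() int           { return len(h) }
func (h maxHeap) Less(i, j int) bool { return h[i] > h[j] } // max heap, as in topNChunkHeap.Less
func (h maxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *maxHeap) Push(x any)        { *h = append(*h, x.(int)) }
func (h *maxHeap) Pop() any {
	old := *h
	x := old[len(old)-1]
	*h = old[:len(old)-1]
	return x
}

// topN returns rows [offset, offset+count) of the ascending ordering.
func topN(rows []int, offset, count int) []int {
	totalLimit := offset + count
	h := &maxHeap{}
	for _, r := range rows {
		// Stage 1: building the heap, up to totalLimit rows.
		if h.Len() < totalLimit {
			heap.Push(h, r)
			continue
		}
		// Stage 2: updating the heap; only a row smaller than the
		// heap top is kept, and the old top is dropped in place.
		if r < (*h)[0] {
			(*h)[0] = r
			heap.Fix(h, 0)
		}
	}
	res := make([]int, h.Len())
	copy(res, *h)
	sort.Ints(res)
	if offset >= len(res) {
		return nil
	}
	return res[offset:]
}

func main() {
	fmt.Println(topN([]int{9, 4, 7, 1, 8, 3, 6, 2, 5, 0}, 2, 3)) // prints [2 3 4]
}

The max-heap ordering keeps the largest of the retained rows at the root, so each stage-2 replacement costs O(log n); this is the same property topNChunkHeap provides by having Less return true when row i is greater than row j.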