diff --git a/executor/builder.go b/executor/builder.go index 7da5f7ab8ce55..88d35167352af 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1588,6 +1588,7 @@ func (b *executorBuilder) buildUnionAll(v *plannercore.PhysicalUnionAll) Executo } e := &UnionExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID(), childExecs...), + concurrency: b.ctx.GetSessionVars().Concurrency.UnionConcurrency, } return e } diff --git a/executor/executor.go b/executor/executor.go index 9e738f356b17d..c30fb18846637 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1405,6 +1405,8 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.Chunk) error { // +-------------+ type UnionExec struct { baseExecutor + concurrency int + childIDChan chan int stopFetchData atomic.Value @@ -1412,9 +1414,9 @@ type UnionExec struct { resourcePools []chan *chunk.Chunk resultPool chan *unionWorkerResult - childrenResults []*chunk.Chunk - wg sync.WaitGroup - initialized bool + results []*chunk.Chunk + wg sync.WaitGroup + initialized bool } // unionWorkerResult stores the result for a union worker. @@ -1437,9 +1439,6 @@ func (e *UnionExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } - for _, child := range e.children { - e.childrenResults = append(e.childrenResults, newFirstChunk(child)) - } e.stopFetchData.Store(false) e.initialized = false e.finished = make(chan struct{}) @@ -1447,22 +1446,33 @@ func (e *UnionExec) Open(ctx context.Context) error { } func (e *UnionExec) initialize(ctx context.Context) { - e.resultPool = make(chan *unionWorkerResult, len(e.children)) - e.resourcePools = make([]chan *chunk.Chunk, len(e.children)) - for i := range e.children { + if e.concurrency > len(e.children) { + e.concurrency = len(e.children) + } + for i := 0; i < e.concurrency; i++ { + e.results = append(e.results, newFirstChunk(e.children[0])) + } + e.resultPool = make(chan *unionWorkerResult, e.concurrency) + e.resourcePools = make([]chan *chunk.Chunk, e.concurrency) + e.childIDChan = make(chan int, len(e.children)) + for i := 0; i < e.concurrency; i++ { e.resourcePools[i] = make(chan *chunk.Chunk, 1) - e.resourcePools[i] <- e.childrenResults[i] + e.resourcePools[i] <- e.results[i] e.wg.Add(1) go e.resultPuller(ctx, i) } + for i := 0; i < len(e.children); i++ { + e.childIDChan <- i + } + close(e.childIDChan) go e.waitAllFinished() } -func (e *UnionExec) resultPuller(ctx context.Context, childID int) { +func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { result := &unionWorkerResult{ err: nil, chk: nil, - src: e.resourcePools[childID], + src: e.resourcePools[workerID], } defer func() { if r := recover(); r != nil { @@ -1476,23 +1486,26 @@ func (e *UnionExec) resultPuller(ctx context.Context, childID int) { } e.wg.Done() }() - for { - if e.stopFetchData.Load().(bool) { - return - } - select { - case <-e.finished: - return - case result.chk = <-e.resourcePools[childID]: - } - result.err = Next(ctx, e.children[childID], result.chk) - if result.err == nil && result.chk.NumRows() == 0 { - return - } - e.resultPool <- result - if result.err != nil { - e.stopFetchData.Store(true) - return + for childID := range e.childIDChan { + for { + if e.stopFetchData.Load().(bool) { + return + } + select { + case <-e.finished: + return + case result.chk = <-e.resourcePools[workerID]: + } + result.err = Next(ctx, e.children[childID], result.chk) + if result.err == nil && result.chk.NumRows() == 0 { + e.resourcePools[workerID] <- result.chk + break + } + e.resultPool <- result + if result.err != nil { + e.stopFetchData.Store(true) + return + } } } } @@ -1522,12 +1535,16 @@ func (e *UnionExec) Close() error { if e.finished != nil { close(e.finished) } - e.childrenResults = nil + e.results = nil if e.resultPool != nil { for range e.resultPool { } } e.resourcePools = nil + if e.childIDChan != nil { + for range e.childIDChan { + } + } return e.baseExecutor.Close() } diff --git a/executor/executor_test.go b/executor/executor_test.go index 12e6507e447ec..7708e1f4f4b79 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1342,6 +1342,18 @@ func (s *testSuiteP2) TestUnion(c *C) { tk.MustQuery("select count(distinct a), sum(distinct a), avg(distinct a) from (select a from t union all select b from t) tmp;").Check(testkit.Rows("1 1.000 1.0000000")) } +func (s *testSuite2) TestUnionLimit(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists union_limit") + tk.MustExec("create table union_limit (id int) partition by hash(id) partitions 30") + for i := 0; i < 60; i++ { + tk.MustExec(fmt.Sprintf("insert into union_limit values (%d)", i)) + } + // Cover the code for worker count limit in the union executor. + tk.MustQuery("select * from union_limit limit 10") +} + func (s *testSuiteP1) TestNeighbouringProj(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/session/session.go b/session/session.go index 43d53a054c453..e0991e7921ef4 100644 --- a/session/session.go +++ b/session/session.go @@ -1934,6 +1934,7 @@ var builtinGlobalVariable = []string{ variable.TiDBHashAggPartialConcurrency, variable.TiDBHashAggFinalConcurrency, variable.TiDBWindowConcurrency, + variable.TiDBUnionConcurrency, variable.TiDBBackoffLockFast, variable.TiDBBackOffWeight, variable.TiDBConstraintCheckInPlace, diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 7c01644460f4d..a3d9f1901d33b 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -743,6 +743,7 @@ func NewSessionVars() *SessionVars { HashAggPartialConcurrency: DefTiDBHashAggPartialConcurrency, HashAggFinalConcurrency: DefTiDBHashAggFinalConcurrency, WindowConcurrency: DefTiDBWindowConcurrency, + UnionConcurrency: DefTiDBUnionConcurrency, } vars.MemQuota = MemQuota{ MemQuotaQuery: config.GetGlobalConfig().MemQuotaQuery, @@ -1160,6 +1161,8 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { s.HashAggFinalConcurrency = tidbOptPositiveInt32(val, DefTiDBHashAggFinalConcurrency) case TiDBWindowConcurrency: s.WindowConcurrency = tidbOptPositiveInt32(val, DefTiDBWindowConcurrency) + case TiDBUnionConcurrency: + s.UnionConcurrency = tidbOptPositiveInt32(val, DefTiDBUnionConcurrency) case TiDBDistSQLScanConcurrency: s.DistSQLScanConcurrency = tidbOptPositiveInt32(val, DefDistSQLScanConcurrency) case TiDBIndexSerialScanConcurrency: @@ -1459,6 +1462,9 @@ type Concurrency struct { // IndexSerialScanConcurrency is the number of concurrent index serial scan worker. IndexSerialScanConcurrency int + + // UnionConcurrency is the number of concurrent union worker. + UnionConcurrency int } // MemQuota defines memory quota values. diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 11ece6ebfbba3..74e7e693e24f5 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -666,6 +666,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, TiDBHashAggPartialConcurrency, strconv.Itoa(DefTiDBHashAggPartialConcurrency)}, {ScopeGlobal | ScopeSession, TiDBHashAggFinalConcurrency, strconv.Itoa(DefTiDBHashAggFinalConcurrency)}, {ScopeGlobal | ScopeSession, TiDBWindowConcurrency, strconv.Itoa(DefTiDBWindowConcurrency)}, + {ScopeGlobal | ScopeSession, TiDBUnionConcurrency, strconv.Itoa(DefTiDBUnionConcurrency)}, {ScopeGlobal | ScopeSession, TiDBBackoffLockFast, strconv.Itoa(kv.DefBackoffLockFast)}, {ScopeGlobal | ScopeSession, TiDBBackOffWeight, strconv.Itoa(kv.DefBackOffWeight)}, {ScopeGlobal | ScopeSession, TiDBRetryLimit, strconv.Itoa(DefTiDBRetryLimit)}, diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index f9126b377fddb..de2941b94199c 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -285,6 +285,9 @@ const ( // tidb_window_concurrency is used for window parallel executor. TiDBWindowConcurrency = "tidb_window_concurrency" + // tidb_union_concurrency is used for union executor. + TiDBUnionConcurrency = "tidb_union_concurrency" + // tidb_backoff_lock_fast is used for tikv backoff base time in milliseconds. TiDBBackoffLockFast = "tidb_backoff_lock_fast" @@ -480,6 +483,7 @@ const ( DefTiDBHashAggPartialConcurrency = 4 DefTiDBHashAggFinalConcurrency = 4 DefTiDBWindowConcurrency = 4 + DefTiDBUnionConcurrency = 4 DefTiDBForcePriority = mysql.NoPriority DefTiDBUseRadixJoin = false DefEnableWindowFunction = true diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index e9214009fa7af..8134a3f7e4bc5 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -517,6 +517,7 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string, scope Sc TiDBHashAggPartialConcurrency, TiDBHashAggFinalConcurrency, TiDBWindowConcurrency, + TiDBUnionConcurrency, TiDBDistSQLScanConcurrency, TiDBIndexSerialScanConcurrency, TiDBDDLReorgWorkerCount, TiDBBackoffLockFast, TiDBBackOffWeight, diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 2444b288bd83a..77286d2bd5d97 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -70,6 +70,7 @@ func (s *testVarsutilSuite) TestNewSessionVars(c *C) { c.Assert(vars.HashAggPartialConcurrency, Equals, DefTiDBHashAggPartialConcurrency) c.Assert(vars.HashAggFinalConcurrency, Equals, DefTiDBHashAggFinalConcurrency) c.Assert(vars.WindowConcurrency, Equals, DefTiDBWindowConcurrency) + c.Assert(vars.UnionConcurrency, Equals, DefTiDBUnionConcurrency) c.Assert(vars.DistSQLScanConcurrency, Equals, DefDistSQLScanConcurrency) c.Assert(vars.MaxChunkSize, Equals, DefMaxChunkSize) c.Assert(vars.DMLBatchSize, Equals, DefDMLBatchSize)