diff --git a/DEPS.bzl b/DEPS.bzl index 442749f514a5e..39b7c9e484d56 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -3582,8 +3582,8 @@ def go_deps(): name = "com_github_tikv_client_go_v2", build_file_proto_mode = "disable_global", importpath = "github.com/tikv/client-go/v2", - sum = "h1:RI6bs9TDIIJ96N0lR5uZoGO8QNot4qS/1l+Mobx0InM=", - version = "v2.0.5-0.20230110071533-f313ddf58d73", + sum = "h1:B2FNmPDaGirXpIOgQbqxiukIkT8eOT4tKEahqYE2ers=", + version = "v2.0.5-0.20230112062023-fe5b35c5f5dc", ) go_repository( name = "com_github_tikv_pd_client", diff --git a/executor/executor.go b/executor/executor.go index 81afc5620daca..c2fcdaa2d7887 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1393,6 +1393,9 @@ type LimitExec struct { // columnIdxsUsedByChild keep column indexes of child executor used for inline projection columnIdxsUsedByChild []int + + // Log the close time when opentracing is enabled. + span opentracing.Span } // Next implements the Executor Next interface. @@ -1470,13 +1473,29 @@ func (e *LimitExec) Open(ctx context.Context) error { e.childResult = tryNewCacheChunk(e.children[0]) e.cursor = 0 e.meetFirstBatch = e.begin == 0 + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + e.span = span + } return nil } // Close implements the Executor Close interface. func (e *LimitExec) Close() error { + start := time.Now() + e.childResult = nil - return e.baseExecutor.Close() + err := e.baseExecutor.Close() + + elapsed := time.Since(start) + if elapsed > time.Millisecond { + logutil.BgLogger().Info("limit executor close takes a long time", + zap.Duration("elapsed", elapsed)) + if e.span != nil { + span1 := e.span.Tracer().StartSpan("limitExec.Close", opentracing.ChildOf(e.span.Context()), opentracing.StartTime(start)) + defer span1.Finish() + } + } + return err } func (e *LimitExec) adjustRequiredRows(chk *chunk.Chunk) *chunk.Chunk { diff --git a/executor/executor_test.go b/executor/executor_test.go index c52dcfde48169..7e6a51799d778 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -5557,6 +5557,8 @@ func TestAdmin(t *testing.T) { })) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") tk.MustExec("drop table if exists admin_test") tk.MustExec("create table admin_test (c1 int, c2 int, c3 int default 1, index (c1))") tk.MustExec("insert admin_test (c1) values (1),(2),(NULL)") @@ -5681,7 +5683,7 @@ func TestAdmin(t *testing.T) { // check that the result set has no duplication defer wg.Done() for i := 0; i < 10; i++ { - result := tk.MustQuery(`admin show ddl job queries 20`) + result := tk2.MustQuery(`admin show ddl job queries 20`) rows := result.Rows() rowIDs := make(map[string]struct{}) for _, row := range rows { @@ -5712,7 +5714,7 @@ func TestAdmin(t *testing.T) { // check that the result set has no duplication defer wg2.Done() for i := 0; i < 10; i++ { - result := tk.MustQuery(`admin show ddl job queries limit 3 offset 2`) + result := tk2.MustQuery(`admin show ddl job queries limit 3 offset 2`) rows := result.Rows() rowIDs := make(map[string]struct{}) for _, row := range rows { diff --git a/go.mod b/go.mod index 65440ab9d18cb..2e6cbf201b043 100644 --- a/go.mod +++ b/go.mod @@ -90,7 +90,7 @@ require ( github.com/stretchr/testify v1.8.0 github.com/tdakkota/asciicheck v0.1.1 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 - github.com/tikv/client-go/v2 v2.0.5-0.20230110071533-f313ddf58d73 + github.com/tikv/client-go/v2 
v2.0.5-0.20230112062023-fe5b35c5f5dc github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07 github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 github.com/twmb/murmur3 v1.1.3 diff --git a/go.sum b/go.sum index e34b4c5935340..a75624ea224ee 100644 --- a/go.sum +++ b/go.sum @@ -936,8 +936,8 @@ github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJf github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfKggNGDuadAa0LElHrByyrz4JPZ9fFx6Gs7nx7ZZU= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= -github.com/tikv/client-go/v2 v2.0.5-0.20230110071533-f313ddf58d73 h1:RI6bs9TDIIJ96N0lR5uZoGO8QNot4qS/1l+Mobx0InM= -github.com/tikv/client-go/v2 v2.0.5-0.20230110071533-f313ddf58d73/go.mod h1:dO/2a/xi/EO3eVv9xN5G1VFtd/hythzgTeeCbW5SWuI= +github.com/tikv/client-go/v2 v2.0.5-0.20230112062023-fe5b35c5f5dc h1:B2FNmPDaGirXpIOgQbqxiukIkT8eOT4tKEahqYE2ers= +github.com/tikv/client-go/v2 v2.0.5-0.20230112062023-fe5b35c5f5dc/go.mod h1:dO/2a/xi/EO3eVv9xN5G1VFtd/hythzgTeeCbW5SWuI= github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07 h1:ckPpxKcl75mO2N6a4cJXiZH43hvcHPpqc9dh1TmH1nc= github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07/go.mod h1:CipBxPfxPUME+BImx9MUYXCnAVLS3VJUr3mnSJwh40A= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 h1:kl4KhGNsJIbDHS9/4U9yQo1UcPQM0kOMJHn29EoH/Ro= diff --git a/planner/core/encode.go b/planner/core/encode.go index 14931d4d1ef0a..bf02059dafdda 100644 --- a/planner/core/encode.go +++ b/planner/core/encode.go @@ -262,6 +262,9 @@ type planDigester struct { // NormalizeFlatPlan normalizes a FlatPhysicalPlan and generates plan digest. 
func NormalizeFlatPlan(flat *FlatPhysicalPlan) (normalized string, digest *parser.Digest) { + if flat == nil { + return "", parser.NewDigest(nil) + } selectPlan, selectPlanOffset := flat.Main.GetSelectPlan() if len(selectPlan) == 0 || !selectPlan[0].IsPhysicalPlan { return "", parser.NewDigest(nil) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 38ba8df4e37b9..2d73534fc2e1e 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2017,7 +2017,7 @@ func getUintFromNode(ctx sessionctx.Context, n ast.Node, mustInt64orUint64 bool) return 0, false, true } if mustInt64orUint64 { - if expected := checkParamTypeInt64orUint64(v); !expected { + if expected, _ := CheckParamTypeInt64orUint64(v); !expected { return 0, false, false } } @@ -2054,19 +2054,19 @@ func getUintFromNode(ctx sessionctx.Context, n ast.Node, mustInt64orUint64 bool) return 0, false, false } -// check param type for plan cache limit, only allow int64 and uint64 now +// CheckParamTypeInt64orUint64 check param type for plan cache limit, only allow int64 and uint64 now // eg: set @a = 1; -func checkParamTypeInt64orUint64(param *driver.ParamMarkerExpr) bool { +func CheckParamTypeInt64orUint64(param *driver.ParamMarkerExpr) (bool, uint64) { val := param.GetValue() switch v := val.(type) { case int64: if v >= 0 { - return true + return true, uint64(v) } case uint64: - return true + return true, v } - return false + return false, 0 } func extractLimitCountOffset(ctx sessionctx.Context, limit *ast.Limit) (count uint64, diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 9ea2fdf89c006..c38b8b6f9d39a 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -393,7 +393,6 @@ func postOptimize(ctx context.Context, sctx sessionctx.Context, plan PhysicalPla plan = eliminateUnionScanAndLock(sctx, plan) plan = enableParallelApply(sctx, plan) handleFineGrainedShuffle(ctx, sctx, plan) - checkPlanCacheable(sctx, plan) propagateProbeParents(plan, nil) countStarRewrite(plan) return plan, nil @@ -966,16 +965,6 @@ func setupFineGrainedShuffleInternal(ctx context.Context, sctx sessionctx.Contex } } -// checkPlanCacheable used to check whether a plan can be cached. Plans that -// meet the following characteristics cannot be cached: -// 1. Use the TiFlash engine. -// Todo: make more careful check here. -func checkPlanCacheable(sctx sessionctx.Context, plan PhysicalPlan) { - if sctx.GetSessionVars().StmtCtx.UseCache && useTiFlash(plan) { - sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.Errorf("skip plan-cache: TiFlash plan is un-cacheable")) - } -} - // propagateProbeParents doesn't affect the execution plan, it only sets the probeParents field of a PhysicalPlan. // It's for handling the inconsistency between row count in the statsInfo and the recorded actual row count. Please // see comments in PhysicalPlan for details. 
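
Reviewer note (not part of the patch): the LimitExec.Close hunk above and the copr.buildCopTasks hunk later in this diff share one tracing idiom: time the section first, and only when it crosses a threshold start a child span that is back-dated with opentracing.StartTime, so the slow close/build still appears in the trace without adding cost to the fast path. A minimal, self-contained sketch of that idiom follows; the function name, operation name, and threshold are illustrative, only the opentracing-go calls match the patch.

```go
// Sketch of the back-dated child-span idiom used in this patch.
package main

import (
	"context"
	"time"

	"github.com/opentracing/opentracing-go"
)

// slowSection runs fn and, if it took longer than threshold, records a
// child span under whatever span is carried in ctx, back-dated to start.
func slowSection(ctx context.Context, op string, threshold time.Duration, fn func() error) error {
	start := time.Now()
	err := fn()

	if elapsed := time.Since(start); elapsed > threshold {
		if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
			child := span.Tracer().StartSpan(op,
				opentracing.ChildOf(span.Context()),
				opentracing.StartTime(start)) // back-date to the real start
			child.Finish()                    // duration is roughly the measured elapsed time
		}
	}
	return err
}

func main() {
	tracer := opentracing.NoopTracer{}
	root := tracer.StartSpan("root")
	defer root.Finish()

	ctx := opentracing.ContextWithSpan(context.Background(), root)
	_ = slowSection(ctx, "limitExec.Close", time.Millisecond, func() error {
		time.Sleep(2 * time.Millisecond)
		return nil
	})
}
```

Because the span is created only after the elapsed check, queries that close quickly pay nothing beyond the SpanFromContext lookup.
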
diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go index 677b9afe4e29d..b39208ae7574d 100644 --- a/planner/core/plan_cache.go +++ b/planner/core/plan_cache.go @@ -158,27 +158,29 @@ func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context, } } - paramNum, paramTypes := parseParamTypes(sctx, params) + paramTypes := parseParamTypes(sctx, params) if stmtCtx.UseCache && stmtAst.CachedPlan != nil { // for point query plan if plan, names, ok, err := getCachedPointPlan(stmtAst, sessVars, stmtCtx); ok { return plan, names, err } } - + limitCountAndOffset, paramErr := ExtractLimitFromAst(stmt.PreparedAst.Stmt, sctx) + if paramErr != nil { + return nil, nil, paramErr + } if stmtCtx.UseCache { // for non-point plans if plan, names, ok, err := getCachedPlan(sctx, isNonPrepared, cacheKey, bindSQL, is, stmt, - paramTypes); err != nil || ok { + paramTypes, limitCountAndOffset); err != nil || ok { return plan, names, err } } - return generateNewPlan(ctx, sctx, isNonPrepared, is, stmt, cacheKey, latestSchemaVersion, paramNum, paramTypes, bindSQL) + return generateNewPlan(ctx, sctx, isNonPrepared, is, stmt, cacheKey, latestSchemaVersion, paramTypes, bindSQL, limitCountAndOffset) } // parseParamTypes get parameters' types in PREPARE statement -func parseParamTypes(sctx sessionctx.Context, params []expression.Expression) (paramNum int, paramTypes []*types.FieldType) { - paramNum = len(params) +func parseParamTypes(sctx sessionctx.Context, params []expression.Expression) (paramTypes []*types.FieldType) { for _, param := range params { if c, ok := param.(*expression.Constant); ok { // from binary protocol paramTypes = append(paramTypes, c.GetType()) @@ -221,12 +223,12 @@ func getCachedPointPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmt } func getCachedPlan(sctx sessionctx.Context, isNonPrepared bool, cacheKey kvcache.Key, bindSQL string, - is infoschema.InfoSchema, stmt *PlanCacheStmt, paramTypes []*types.FieldType) (Plan, + is infoschema.InfoSchema, stmt *PlanCacheStmt, paramTypes []*types.FieldType, limitParams []uint64) (Plan, []*types.FieldName, bool, error) { sessVars := sctx.GetSessionVars() stmtCtx := sessVars.StmtCtx - candidate, exist := sctx.GetPlanCache(isNonPrepared).Get(cacheKey, paramTypes) + candidate, exist := sctx.GetPlanCache(isNonPrepared).Get(cacheKey, paramTypes, limitParams) if !exist { return nil, nil, false, nil } @@ -264,8 +266,9 @@ func getCachedPlan(sctx sessionctx.Context, isNonPrepared bool, cacheKey kvcache // generateNewPlan call the optimizer to generate a new plan for current statement // and try to add it to cache -func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared bool, is infoschema.InfoSchema, stmt *PlanCacheStmt, cacheKey kvcache.Key, latestSchemaVersion int64, paramNum int, - paramTypes []*types.FieldType, bindSQL string) (Plan, []*types.FieldName, error) { +func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared bool, is infoschema.InfoSchema, + stmt *PlanCacheStmt, cacheKey kvcache.Key, latestSchemaVersion int64, paramTypes []*types.FieldType, + bindSQL string, limitParams []uint64) (Plan, []*types.FieldName, error) { stmtAst := stmt.PreparedAst sessVars := sctx.GetSessionVars() stmtCtx := sessVars.StmtCtx @@ -282,10 +285,10 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared return nil, nil, err } - // We only cache the tableDual plan when the number of parameters are zero. 
- if containTableDual(p) && paramNum > 0 { - stmtCtx.SetSkipPlanCache(errors.New("skip plan-cache: get a TableDual plan")) - } + // check whether this plan is cacheable. + checkPlanCacheability(sctx, p, len(paramTypes)) + + // put this plan into the plan cache. if stmtCtx.UseCache { // rebuild key to exclude kv.TiFlash when stmt is not read only if _, isolationReadContainTiFlash := sessVars.IsolationReadEngines[kv.TiFlash]; isolationReadContainTiFlash && !IsReadOnly(stmtAst.Stmt, sessVars) { @@ -296,16 +299,51 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared } sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} } - cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes) + cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes, limitParams) stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p) stmtCtx.SetPlan(p) stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) - sctx.GetPlanCache(isNonPrepared).Put(cacheKey, cached, paramTypes) + sctx.GetPlanCache(isNonPrepared).Put(cacheKey, cached, paramTypes, limitParams) } sessVars.FoundInPlanCache = false return p, names, err } +// checkPlanCacheability checks whether this plan is cacheable and set to skip plan cache if it's uncacheable. +func checkPlanCacheability(sctx sessionctx.Context, p Plan, paramNum int) { + stmtCtx := sctx.GetSessionVars().StmtCtx + var pp PhysicalPlan + switch x := p.(type) { + case *Insert: + pp = x.SelectPlan + case *Update: + pp = x.SelectPlan + case *Delete: + pp = x.SelectPlan + case PhysicalPlan: + pp = x + default: + stmtCtx.SetSkipPlanCache(errors.Errorf("skip plan-cache: unexpected un-cacheable plan %v", p.ExplainID().String())) + return + } + if pp == nil { // simple DML statements + return + } + + if useTiFlash(pp) { + stmtCtx.SetSkipPlanCache(errors.Errorf("skip plan-cache: TiFlash plan is un-cacheable")) + return + } + + // We only cache the tableDual plan when the number of parameters are zero. + if containTableDual(pp) && paramNum > 0 { + stmtCtx.SetSkipPlanCache(errors.New("skip plan-cache: get a TableDual plan")) + return + } + + // TODO: plans accessing MVIndex are un-cacheable +} + // RebuildPlan4CachedPlan will rebuild this plan under current user parameters. func RebuildPlan4CachedPlan(p Plan) error { sc := p.SCtx().GetSessionVars().StmtCtx @@ -675,17 +713,13 @@ func tryCachePointPlan(_ context.Context, sctx sessionctx.Context, return err } -func containTableDual(p Plan) bool { +func containTableDual(p PhysicalPlan) bool { _, isTableDual := p.(*PhysicalTableDual) if isTableDual { return true } - physicalPlan, ok := p.(PhysicalPlan) - if !ok { - return false - } childContainTableDual := false - for _, child := range physicalPlan.Children() { + for _, child := range p.Children() { childContainTableDual = childContainTableDual || containTableDual(child) } return childContainTableDual diff --git a/planner/core/plan_cache_lru.go b/planner/core/plan_cache_lru.go index 413dd37e8f5a2..20fa4c3f5c85c 100644 --- a/planner/core/plan_cache_lru.go +++ b/planner/core/plan_cache_lru.go @@ -53,7 +53,7 @@ type LRUPlanCache struct { lock sync.Mutex // pickFromBucket get one element from bucket. 
The LRUPlanCache can not work if it is nil - pickFromBucket func(map[*list.Element]struct{}, []*types.FieldType) (*list.Element, bool) + pickFromBucket func(map[*list.Element]struct{}, []*types.FieldType, []uint64) (*list.Element, bool) // onEvict will be called if any eviction happened, only for test use now onEvict func(kvcache.Key, kvcache.Value) @@ -68,7 +68,7 @@ type LRUPlanCache struct { // NewLRUPlanCache creates a PCLRUCache object, whose capacity is "capacity". // NOTE: "capacity" should be a positive value. func NewLRUPlanCache(capacity uint, guard float64, quota uint64, - pickFromBucket func(map[*list.Element]struct{}, []*types.FieldType) (*list.Element, bool), sctx sessionctx.Context) *LRUPlanCache { + pickFromBucket func(map[*list.Element]struct{}, []*types.FieldType, []uint64) (*list.Element, bool), sctx sessionctx.Context) *LRUPlanCache { if capacity < 1 { capacity = 100 logutil.BgLogger().Info("capacity of LRU cache is less than 1, will use default value(100) init cache") @@ -94,13 +94,13 @@ func strHashKey(key kvcache.Key, deepCopy bool) string { } // Get tries to find the corresponding value according to the given key. -func (l *LRUPlanCache) Get(key kvcache.Key, paramTypes []*types.FieldType) (value kvcache.Value, ok bool) { +func (l *LRUPlanCache) Get(key kvcache.Key, paramTypes []*types.FieldType, limitParams []uint64) (value kvcache.Value, ok bool) { l.lock.Lock() defer l.lock.Unlock() bucket, bucketExist := l.buckets[strHashKey(key, false)] if bucketExist { - if element, exist := l.pickFromBucket(bucket, paramTypes); exist { + if element, exist := l.pickFromBucket(bucket, paramTypes, limitParams); exist { l.lruList.MoveToFront(element) return element.Value.(*planCacheEntry).PlanValue, true } @@ -109,14 +109,14 @@ func (l *LRUPlanCache) Get(key kvcache.Key, paramTypes []*types.FieldType) (valu } // Put puts the (key, value) pair into the LRU Cache. 
-func (l *LRUPlanCache) Put(key kvcache.Key, value kvcache.Value, paramTypes []*types.FieldType) { +func (l *LRUPlanCache) Put(key kvcache.Key, value kvcache.Value, paramTypes []*types.FieldType, limitParams []uint64) { l.lock.Lock() defer l.lock.Unlock() hash := strHashKey(key, true) bucket, bucketExist := l.buckets[hash] if bucketExist { - if element, exist := l.pickFromBucket(bucket, paramTypes); exist { + if element, exist := l.pickFromBucket(bucket, paramTypes, limitParams); exist { l.updateInstanceMetric(&planCacheEntry{PlanKey: key, PlanValue: value}, element.Value.(*planCacheEntry)) element.Value.(*planCacheEntry).PlanValue = value l.lruList.MoveToFront(element) @@ -252,16 +252,36 @@ func (l *LRUPlanCache) memoryControl() { } // PickPlanFromBucket pick one plan from bucket -func PickPlanFromBucket(bucket map[*list.Element]struct{}, paramTypes []*types.FieldType) (*list.Element, bool) { +func PickPlanFromBucket(bucket map[*list.Element]struct{}, paramTypes []*types.FieldType, limitParams []uint64) (*list.Element, bool) { for k := range bucket { plan := k.Value.(*planCacheEntry).PlanValue.(*PlanCacheValue) - if plan.ParamTypes.CheckTypesCompatibility4PC(paramTypes) { + ok1 := plan.ParamTypes.CheckTypesCompatibility4PC(paramTypes) + if !ok1 { + continue + } + ok2 := checkUint64SliceIfEqual(plan.limitOffsetAndCount, limitParams) + if ok2 { return k, true } } return nil, false } +func checkUint64SliceIfEqual(a, b []uint64) bool { + if (a == nil && b != nil) || (a != nil && b == nil) { + return false + } + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + // updateInstanceMetric update the memory usage and plan num for show in grafana func (l *LRUPlanCache) updateInstanceMetric(in, out *planCacheEntry) { updateInstancePlanNum(in, out) diff --git a/planner/core/plan_cache_lru_test.go b/planner/core/plan_cache_lru_test.go index 74b6b2c92c3bb..f51480401ce62 100644 --- a/planner/core/plan_cache_lru_test.go +++ b/planner/core/plan_cache_lru_test.go @@ -65,14 +65,18 @@ func TestLRUPCPut(t *testing.T) { {types.NewFieldType(mysql.TypeFloat), types.NewFieldType(mysql.TypeLong)}, {types.NewFieldType(mysql.TypeFloat), types.NewFieldType(mysql.TypeInt24)}, } + limitParams := [][]uint64{ + {1}, {2}, {3}, {4}, {5}, + } // one key corresponding to multi values for i := 0; i < 5; i++ { keys[i] = &planCacheKey{database: strconv.FormatInt(int64(1), 10)} vals[i] = &PlanCacheValue{ - ParamTypes: pTypes[i], + ParamTypes: pTypes[i], + limitOffsetAndCount: limitParams[i], } - lru.Put(keys[i], vals[i], pTypes[i]) + lru.Put(keys[i], vals[i], pTypes[i], limitParams[i]) } require.Equal(t, lru.size, lru.capacity) require.Equal(t, uint(3), lru.size) @@ -103,7 +107,7 @@ func TestLRUPCPut(t *testing.T) { bucket, exist := lru.buckets[string(hack.String(keys[i].Hash()))] require.True(t, exist) - element, exist := lru.pickFromBucket(bucket, pTypes[i]) + element, exist := lru.pickFromBucket(bucket, pTypes[i], limitParams[i]) require.NotNil(t, element) require.True(t, exist) require.Equal(t, root, element) @@ -131,22 +135,25 @@ func TestLRUPCGet(t *testing.T) { {types.NewFieldType(mysql.TypeFloat), types.NewFieldType(mysql.TypeLong)}, {types.NewFieldType(mysql.TypeFloat), types.NewFieldType(mysql.TypeInt24)}, } + limitParams := [][]uint64{ + {1}, {2}, {3}, {4}, {5}, + } // 5 bucket for i := 0; i < 5; i++ { keys[i] = &planCacheKey{database: strconv.FormatInt(int64(i%4), 10)} - vals[i] = &PlanCacheValue{ParamTypes: pTypes[i]} - lru.Put(keys[i], 
vals[i], pTypes[i]) + vals[i] = &PlanCacheValue{ParamTypes: pTypes[i], limitOffsetAndCount: limitParams[i]} + lru.Put(keys[i], vals[i], pTypes[i], limitParams[i]) } // test for non-existent elements for i := 0; i < 2; i++ { - value, exists := lru.Get(keys[i], pTypes[i]) + value, exists := lru.Get(keys[i], pTypes[i], limitParams[i]) require.False(t, exists) require.Nil(t, value) } for i := 2; i < 5; i++ { - value, exists := lru.Get(keys[i], pTypes[i]) + value, exists := lru.Get(keys[i], pTypes[i], limitParams[i]) require.True(t, exists) require.NotNil(t, value) require.Equal(t, vals[i], value) @@ -175,23 +182,29 @@ func TestLRUPCDelete(t *testing.T) { {types.NewFieldType(mysql.TypeFloat), types.NewFieldType(mysql.TypeEnum)}, {types.NewFieldType(mysql.TypeFloat), types.NewFieldType(mysql.TypeDate)}, } + limitParams := [][]uint64{ + {1}, {2}, {3}, + } for i := 0; i < 3; i++ { keys[i] = &planCacheKey{database: strconv.FormatInt(int64(i), 10)} - vals[i] = &PlanCacheValue{ParamTypes: pTypes[i]} - lru.Put(keys[i], vals[i], pTypes[i]) + vals[i] = &PlanCacheValue{ + ParamTypes: pTypes[i], + limitOffsetAndCount: limitParams[i], + } + lru.Put(keys[i], vals[i], pTypes[i], []uint64{}) } require.Equal(t, 3, int(lru.size)) lru.Delete(keys[1]) - value, exists := lru.Get(keys[1], pTypes[1]) + value, exists := lru.Get(keys[1], pTypes[1], limitParams[1]) require.False(t, exists) require.Nil(t, value) require.Equal(t, 2, int(lru.size)) - _, exists = lru.Get(keys[0], pTypes[0]) + _, exists = lru.Get(keys[0], pTypes[0], limitParams[0]) require.True(t, exists) - _, exists = lru.Get(keys[2], pTypes[2]) + _, exists = lru.Get(keys[2], pTypes[2], limitParams[2]) require.True(t, exists) } @@ -207,14 +220,14 @@ func TestLRUPCDeleteAll(t *testing.T) { for i := 0; i < 3; i++ { keys[i] = &planCacheKey{database: strconv.FormatInt(int64(i), 10)} vals[i] = &PlanCacheValue{ParamTypes: pTypes[i]} - lru.Put(keys[i], vals[i], pTypes[i]) + lru.Put(keys[i], vals[i], pTypes[i], []uint64{}) } require.Equal(t, 3, int(lru.size)) lru.DeleteAll() for i := 0; i < 3; i++ { - value, exists := lru.Get(keys[i], pTypes[i]) + value, exists := lru.Get(keys[i], pTypes[i], []uint64{}) require.False(t, exists) require.Nil(t, value) require.Equal(t, 0, int(lru.size)) @@ -242,7 +255,7 @@ func TestLRUPCSetCapacity(t *testing.T) { for i := 0; i < 5; i++ { keys[i] = &planCacheKey{database: strconv.FormatInt(int64(1), 10)} vals[i] = &PlanCacheValue{ParamTypes: pTypes[i]} - lru.Put(keys[i], vals[i], pTypes[i]) + lru.Put(keys[i], vals[i], pTypes[i], []uint64{}) } require.Equal(t, lru.size, lru.capacity) require.Equal(t, uint(5), lru.size) @@ -292,7 +305,7 @@ func TestIssue37914(t *testing.T) { val := &PlanCacheValue{ParamTypes: pTypes} require.NotPanics(t, func() { - lru.Put(key, val, pTypes) + lru.Put(key, val, pTypes, []uint64{}) }) } @@ -313,7 +326,7 @@ func TestIssue38244(t *testing.T) { for i := 0; i < 5; i++ { keys[i] = &planCacheKey{database: strconv.FormatInt(int64(i), 10)} vals[i] = &PlanCacheValue{ParamTypes: pTypes[i]} - lru.Put(keys[i], vals[i], pTypes[i]) + lru.Put(keys[i], vals[i], pTypes[i], []uint64{}) } require.Equal(t, lru.size, lru.capacity) require.Equal(t, uint(3), lru.size) @@ -334,7 +347,7 @@ func TestLRUPlanCacheMemoryUsage(t *testing.T) { for i := 0; i < 3; i++ { k := randomPlanCacheKey() v := randomPlanCacheValue(pTypes) - lru.Put(k, v, pTypes) + lru.Put(k, v, pTypes, []uint64{}) res += k.MemoryUsage() + v.MemoryUsage() require.Equal(t, lru.MemoryUsage(), res) } @@ -342,7 +355,7 @@ func TestLRUPlanCacheMemoryUsage(t *testing.T) 
{ p := &PhysicalTableScan{} k := &planCacheKey{database: "3"} v := &PlanCacheValue{Plan: p} - lru.Put(k, v, pTypes) + lru.Put(k, v, pTypes, []uint64{}) res += k.MemoryUsage() + v.MemoryUsage() for kk, vv := range evict { res -= kk.(*planCacheKey).MemoryUsage() + vv.(*PlanCacheValue).MemoryUsage() diff --git a/planner/core/plan_cache_test.go b/planner/core/plan_cache_test.go index fe76180291edf..8acc28b7b0062 100644 --- a/planner/core/plan_cache_test.go +++ b/planner/core/plan_cache_test.go @@ -385,12 +385,6 @@ func TestPlanCacheDiagInfo(t *testing.T) { tk.MustExec("prepare stmt from 'select /*+ ignore_plan_cache() */ * from t'") tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: ignore plan cache by hint")) - tk.MustExec("prepare stmt from 'select * from t limit ?'") - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: query has 'limit ?' is un-cacheable")) - - tk.MustExec("prepare stmt from 'select * from t limit ?, 1'") - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: query has 'limit ?, 10' is un-cacheable")) - tk.MustExec("prepare stmt from 'select * from t order by ?'") tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: query has 'order by ?' is un-cacheable")) @@ -463,18 +457,49 @@ func TestIssue40225(t *testing.T) { tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1")) } -func TestUncacheableReason(t *testing.T) { +func TestPlanCacheWithLimit(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") - tk.MustExec("create table t (a int)") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b int)") + + testCases := []struct { + sql string + params []int + }{ + {"prepare stmt from 'select * from t limit ?'", []int{1}}, + {"prepare stmt from 'select * from t limit ?, ?'", []int{1, 2}}, + {"prepare stmt from 'delete from t order by a limit ?'", []int{1}}, + {"prepare stmt from 'insert into t select * from t order by a desc limit ?'", []int{1}}, + {"prepare stmt from 'insert into t select * from t order by a desc limit ?, ?'", []int{1, 2}}, + {"prepare stmt from 'update t set a = 1 limit ?'", []int{1}}, + {"prepare stmt from '(select * from t order by a limit ?) union (select * from t order by a desc limit ?)'", []int{1, 2}}, + {"prepare stmt from 'select * from t where a = ? limit ?, ?'", []int{1, 1, 1}}, + {"prepare stmt from 'select * from t where a in (?, ?) limit ?, ?'", []int{1, 2, 1, 1}}, + } + + for idx, testCase := range testCases { + tk.MustExec(testCase.sql) + var using []string + for i, p := range testCase.params { + tk.MustExec(fmt.Sprintf("set @a%d = %d", i, p)) + using = append(using, fmt.Sprintf("@a%d", i)) + } - tk.MustExec("prepare st from 'select * from t limit ?'") - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: query has 'limit ?' is un-cacheable")) + tk.MustExec("execute stmt using " + strings.Join(using, ", ")) + tk.MustExec("execute stmt using " + strings.Join(using, ", ")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) - tk.MustExec("set @a=1") - tk.MustQuery("execute st using @a").Check(testkit.Rows()) - tk.MustExec("prepare st from 'select * from t limit ?'") - // show the corresponding un-cacheable reason at execute-stage as well - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: query has 'limit ?' 
is un-cacheable")) + if idx < 6 { + tk.MustExec("set @a0 = 6") + tk.MustExec("execute stmt using " + strings.Join(using, ", ")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + } + } + + tk.MustExec("prepare stmt from 'select * from t limit ?'") + tk.MustExec("set @a = 10001") + tk.MustExec("execute stmt using @a") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: limit count more than 10000")) } diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go index 3fe4ee38bfe45..082637506e590 100644 --- a/planner/core/plan_cache_utils.go +++ b/planner/core/plan_cache_utils.go @@ -343,6 +343,9 @@ type PlanCacheValue struct { TblInfo2UnionScan map[*model.TableInfo]bool ParamTypes FieldSlice memoryUsage int64 + // limitOffsetAndCount stores all the offset and key parameters extract from limit statement + // only used for cache and pick plan with parameters + limitOffsetAndCount []uint64 } func (v *PlanCacheValue) varTypesUnchanged(txtVarTps []*types.FieldType) bool { @@ -390,7 +393,7 @@ func (v *PlanCacheValue) MemoryUsage() (sum int64) { // NewPlanCacheValue creates a SQLCacheValue. func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool, - paramTypes []*types.FieldType) *PlanCacheValue { + paramTypes []*types.FieldType, limitParams []uint64) *PlanCacheValue { dstMap := make(map[*model.TableInfo]bool) for k, v := range srcMap { dstMap[k] = v @@ -400,10 +403,11 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta userParamTypes[i] = tp.Clone() } return &PlanCacheValue{ - Plan: plan, - OutPutNames: names, - TblInfo2UnionScan: dstMap, - ParamTypes: userParamTypes, + Plan: plan, + OutPutNames: names, + TblInfo2UnionScan: dstMap, + ParamTypes: userParamTypes, + limitOffsetAndCount: limitParams, } } @@ -453,3 +457,69 @@ func GetPreparedStmt(stmt *ast.ExecuteStmt, vars *variable.SessionVars) (*PlanCa } return nil, ErrStmtNotFound } + +type limitExtractor struct { + cacheable bool // For safety considerations, check if limit count less than 10000 + offsetAndCount []uint64 + unCacheableReason string + paramTypeErr error +} + +// Enter implements Visitor interface. +func (checker *limitExtractor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + switch node := in.(type) { + case *ast.Limit: + if node.Count != nil { + if count, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { + typeExpected, val := CheckParamTypeInt64orUint64(count) + if typeExpected { + if val > 10000 { + checker.cacheable = false + checker.unCacheableReason = "limit count more than 10000" + return in, true + } + checker.offsetAndCount = append(checker.offsetAndCount, val) + } else { + checker.paramTypeErr = ErrWrongArguments.GenWithStackByArgs("LIMIT") + return in, true + } + } + } + if node.Offset != nil { + if offset, isParamMarker := node.Offset.(*driver.ParamMarkerExpr); isParamMarker { + typeExpected, val := CheckParamTypeInt64orUint64(offset) + if typeExpected { + checker.offsetAndCount = append(checker.offsetAndCount, val) + } else { + checker.paramTypeErr = ErrWrongArguments.GenWithStackByArgs("LIMIT") + return in, true + } + } + } + } + return in, false +} + +// Leave implements Visitor interface. 
+func (checker *limitExtractor) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, checker.cacheable +} + +// ExtractLimitFromAst extract limit offset and count from ast for plan cache key encode +func ExtractLimitFromAst(node ast.Node, sctx sessionctx.Context) ([]uint64, error) { + if node == nil { + return nil, nil + } + checker := limitExtractor{ + cacheable: true, + offsetAndCount: []uint64{}, + } + node.Accept(&checker) + if checker.paramTypeErr != nil { + return nil, checker.paramTypeErr + } + if sctx != nil && !checker.cacheable { + sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("skip plan-cache: " + checker.unCacheableReason)) + } + return checker.offsetAndCount, nil +} diff --git a/planner/core/plan_cacheable_checker.go b/planner/core/plan_cacheable_checker.go index 0074cff434221..2ff9e51823ee2 100644 --- a/planner/core/plan_cacheable_checker.go +++ b/planner/core/plan_cacheable_checker.go @@ -135,21 +135,22 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren return in, true } } - case *ast.Limit: - if node.Count != nil { - if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { - checker.cacheable = false - checker.reason = "query has 'limit ?' is un-cacheable" - return in, true - } - } - if node.Offset != nil { - if _, isParamMarker := node.Offset.(*driver.ParamMarkerExpr); isParamMarker { - checker.cacheable = false - checker.reason = "query has 'limit ?, 10' is un-cacheable" - return in, true - } - } + // todo: these comment is used to add switch in the later pr + //case *ast.Limit: + // if node.Count != nil { + // if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { + // checker.cacheable = false + // checker.reason = "query has 'limit ?' is un-cacheable" + // return in, true + // } + // } + // if node.Offset != nil { + // if _, isParamMarker := node.Offset.(*driver.ParamMarkerExpr); isParamMarker { + // checker.cacheable = false + // checker.reason = "query has 'limit ?, 10' is un-cacheable" + // return in, true + // } + // } case *ast.FrameBound: if _, ok := node.Expr.(*driver.ParamMarkerExpr); ok { checker.cacheable = false diff --git a/planner/core/plan_cacheable_checker_test.go b/planner/core/plan_cacheable_checker_test.go index e87a08592eb16..7d417e377888f 100644 --- a/planner/core/plan_cacheable_checker_test.go +++ b/planner/core/plan_cacheable_checker_test.go @@ -87,7 +87,7 @@ func TestCacheable(t *testing.T) { TableRefs: tableRefsClause, Limit: limitStmt, } - require.False(t, core.Cacheable(stmt, is)) + require.True(t, core.Cacheable(stmt, is)) limitStmt = &ast.Limit{ Offset: &driver.ParamMarkerExpr{}, @@ -96,7 +96,7 @@ func TestCacheable(t *testing.T) { TableRefs: tableRefsClause, Limit: limitStmt, } - require.False(t, core.Cacheable(stmt, is)) + require.True(t, core.Cacheable(stmt, is)) limitStmt = &ast.Limit{} stmt = &ast.DeleteStmt{ @@ -139,7 +139,7 @@ func TestCacheable(t *testing.T) { TableRefs: tableRefsClause, Limit: limitStmt, } - require.False(t, core.Cacheable(stmt, is)) + require.True(t, core.Cacheable(stmt, is)) limitStmt = &ast.Limit{ Offset: &driver.ParamMarkerExpr{}, @@ -148,7 +148,7 @@ func TestCacheable(t *testing.T) { TableRefs: tableRefsClause, Limit: limitStmt, } - require.False(t, core.Cacheable(stmt, is)) + require.True(t, core.Cacheable(stmt, is)) limitStmt = &ast.Limit{} stmt = &ast.UpdateStmt{ @@ -188,7 +188,7 @@ func TestCacheable(t *testing.T) { stmt = &ast.SelectStmt{ Limit: limitStmt, } - require.False(t, core.Cacheable(stmt, is)) + 
require.True(t, core.Cacheable(stmt, is)) limitStmt = &ast.Limit{ Offset: &driver.ParamMarkerExpr{}, @@ -196,7 +196,7 @@ func TestCacheable(t *testing.T) { stmt = &ast.SelectStmt{ Limit: limitStmt, } - require.False(t, core.Cacheable(stmt, is)) + require.True(t, core.Cacheable(stmt, is)) limitStmt = &ast.Limit{} stmt = &ast.SelectStmt{ diff --git a/resourcemanager/pooltask/BUILD.bazel b/resourcemanager/pooltask/BUILD.bazel index c9e37436562ee..151a0ddfdec02 100644 --- a/resourcemanager/pooltask/BUILD.bazel +++ b/resourcemanager/pooltask/BUILD.bazel @@ -1,8 +1,19 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "pooltask", - srcs = ["task.go"], + srcs = [ + "task.go", + "task_manager.go", + ], importpath = "github.com/pingcap/tidb/resourcemanager/pooltask", visibility = ["//visibility:public"], + deps = ["@org_uber_go_atomic//:atomic"], +) + +go_test( + name = "pooltask_test", + srcs = ["task_test.go"], + embed = [":pooltask"], + deps = ["@com_github_stretchr_testify//require"], ) diff --git a/resourcemanager/pooltask/task.go b/resourcemanager/pooltask/task.go index ef9b046c8ccba..e166e24f76b4c 100644 --- a/resourcemanager/pooltask/task.go +++ b/resourcemanager/pooltask/task.go @@ -16,6 +16,7 @@ package pooltask import ( "sync" + "sync/atomic" ) // Context is a interface that can be used to create a context. @@ -31,7 +32,16 @@ func (NilContext) GetContext() any { return nil } -// TaskBox is a box which contains all info about pooltask. +const ( + // PendingTask is a task waiting to start. + PendingTask int32 = iota + // RunningTask is a task running. + RunningTask + // StopTask is a stop task. + StopTask +) + +// TaskBox is a box which contains all info about pool task. type TaskBox[T any, U any, C any, CT any, TF Context[CT]] struct { constArgs C contextFunc TF @@ -39,10 +49,24 @@ type TaskBox[T any, U any, C any, CT any, TF Context[CT]] struct { task chan Task[T] resultCh chan U taskID uint64 + status atomic.Int32 // task manager is able to make this task stop, wait or running +} + +// GetStatus is to get the status of task. +func (t *TaskBox[T, U, C, CT, TF]) GetStatus() int32 { + return t.status.Load() +} + +// SetStatus is to set the status of task. +func (t *TaskBox[T, U, C, CT, TF]) SetStatus(s int32) { + t.status.Store(s) } // NewTaskBox is to create a task box for pool. func NewTaskBox[T any, U any, C any, CT any, TF Context[CT]](constArgs C, contextFunc TF, wg *sync.WaitGroup, taskCh chan Task[T], resultCh chan U, taskID uint64) TaskBox[T, U, C, CT, TF] { + // We still need to do some work after a TaskBox finishes. + // So we need to add 1 to waitgroup. After we finish the work, we need to call TaskBox.Finish() + wg.Add(1) return TaskBox[T, U, C, CT, TF]{ constArgs: constArgs, contextFunc: contextFunc, @@ -54,7 +78,7 @@ func NewTaskBox[T any, U any, C any, CT any, TF Context[CT]](constArgs C, contex } // TaskID is to get the task id. -func (t TaskBox[T, U, C, CT, TF]) TaskID() uint64 { +func (t *TaskBox[T, U, C, CT, TF]) TaskID() uint64 { return t.taskID } @@ -83,6 +107,11 @@ func (t *TaskBox[T, U, C, CT, TF]) Done() { t.wg.Done() } +// Finish is to set the TaskBox finish status. 
+func (t *TaskBox[T, U, C, CT, TF]) Finish() { + t.wg.Done() +} + // Clone is to copy the box func (t *TaskBox[T, U, C, CT, TF]) Clone() *TaskBox[T, U, C, CT, TF] { newBox := NewTaskBox[T, U, C, CT, TF](t.constArgs, t.contextFunc, t.wg, t.task, t.resultCh, t.taskID) @@ -92,6 +121,8 @@ func (t *TaskBox[T, U, C, CT, TF]) Clone() *TaskBox[T, U, C, CT, TF] { // GPool is a goroutine pool. type GPool[T any, U any, C any, CT any, TF Context[CT]] interface { Tune(size int) + DeleteTask(id uint64) + StopTask(id uint64) } // TaskController is a controller that can control or watch the pool. @@ -119,6 +150,12 @@ func (t *TaskController[T, U, C, CT, TF]) Wait() { <-t.close t.wg.Wait() close(t.resultCh) + t.pool.DeleteTask(t.taskID) +} + +// Stop is to send stop command to the task. But you still need to wait the task to stop. +func (t *TaskController[T, U, C, CT, TF]) Stop() { + t.pool.StopTask(t.TaskID()) } // TaskID is to get the task id. diff --git a/resourcemanager/pooltask/task_manager.go b/resourcemanager/pooltask/task_manager.go new file mode 100644 index 0000000000000..d08443dc7caf1 --- /dev/null +++ b/resourcemanager/pooltask/task_manager.go @@ -0,0 +1,146 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pooltask + +import ( + "container/list" + "sync" + "time" + + "go.uber.org/atomic" +) + +const shard int = 8 + +func getShardID(id uint64) uint64 { + return id % uint64(shard) +} + +type tContainer[T any, U any, C any, CT any, TF Context[CT]] struct { + task *TaskBox[T, U, C, CT, TF] +} + +type meta struct { + stats *list.List + createTS time.Time + origin int32 + running int32 +} + +func newStats(concurrency int32) *meta { + s := &meta{ + createTS: time.Now(), + stats: list.New(), + origin: concurrency, + } + return s +} + +func (m *meta) getOriginConcurrency() int32 { + return m.origin +} + +// TaskStatusContainer is a container that can control or watch the pool. +type TaskStatusContainer[T any, U any, C any, CT any, TF Context[CT]] struct { + stats map[uint64]*meta + rw sync.RWMutex +} + +// TaskManager is a manager that can control or watch the pool. +type TaskManager[T any, U any, C any, CT any, TF Context[CT]] struct { + task []TaskStatusContainer[T, U, C, CT, TF] + running atomic.Int32 + concurrency int32 +} + +// NewTaskManager create a new pooltask manager. +func NewTaskManager[T any, U any, C any, CT any, TF Context[CT]](c int32) TaskManager[T, U, C, CT, TF] { + task := make([]TaskStatusContainer[T, U, C, CT, TF], shard) + for i := 0; i < shard; i++ { + task[i] = TaskStatusContainer[T, U, C, CT, TF]{ + stats: make(map[uint64]*meta), + } + } + return TaskManager[T, U, C, CT, TF]{ + task: task, + concurrency: c, + } +} + +// RegisterTask register a task to the manager. +func (t *TaskManager[T, U, C, CT, TF]) RegisterTask(taskID uint64, concurrency int32) { + id := getShardID(taskID) + t.task[id].rw.Lock() + t.task[id].stats[taskID] = newStats(concurrency) + t.task[id].rw.Unlock() +} + +// DeleteTask delete a task from the manager. 
+func (t *TaskManager[T, U, C, CT, TF]) DeleteTask(taskID uint64) { + shardID := getShardID(taskID) + t.task[shardID].rw.Lock() + delete(t.task[shardID].stats, taskID) + t.task[shardID].rw.Unlock() +} + +// hasTask check if the task is in the manager. +func (t *TaskManager[T, U, C, CT, TF]) hasTask(taskID uint64) bool { + shardID := getShardID(taskID) + t.task[shardID].rw.Lock() + defer t.task[shardID].rw.Unlock() + _, ok := t.task[shardID].stats[taskID] + return ok +} + +// AddSubTask AddTask add a task to the manager. +func (t *TaskManager[T, U, C, CT, TF]) AddSubTask(taskID uint64, task *TaskBox[T, U, C, CT, TF]) { + shardID := getShardID(taskID) + tc := tContainer[T, U, C, CT, TF]{ + task: task, + } + t.running.Inc() + t.task[shardID].rw.Lock() + t.task[shardID].stats[taskID].stats.PushBack(tc) + t.task[shardID].stats[taskID].running++ // running job in this task + t.task[shardID].rw.Unlock() +} + +// ExitSubTask is to exit a task, and it will decrease the count of running pooltask. +func (t *TaskManager[T, U, C, CT, TF]) ExitSubTask(taskID uint64) { + shardID := getShardID(taskID) + t.running.Dec() // total running tasks + t.task[shardID].rw.Lock() + t.task[shardID].stats[taskID].running-- // running job in this task + t.task[shardID].rw.Unlock() +} + +// Running return the count of running job in this task. +func (t *TaskManager[T, U, C, CT, TF]) Running(taskID uint64) int32 { + shardID := getShardID(taskID) + t.task[shardID].rw.Lock() + defer t.task[shardID].rw.Unlock() + return t.task[shardID].stats[taskID].running +} + +// StopTask is to stop a task by TaskID. +func (t *TaskManager[T, U, C, CT, TF]) StopTask(taskID uint64) { + shardID := getShardID(taskID) + t.task[shardID].rw.Lock() + defer t.task[shardID].rw.Unlock() + l := t.task[shardID].stats[taskID].stats + for e := l.Front(); e != nil; e = e.Next() { + e.Value.(tContainer[T, U, C, CT, TF]).task.SetStatus(StopTask) + } +} diff --git a/resourcemanager/pooltask/task_test.go b/resourcemanager/pooltask/task_test.go new file mode 100644 index 0000000000000..b4f189fb14525 --- /dev/null +++ b/resourcemanager/pooltask/task_test.go @@ -0,0 +1,40 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pooltask + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTaskManager(t *testing.T) { + size := 32 + taskConcurrency := 8 + tm := NewTaskManager[int, int, int, any, NilContext](int32(size)) + tm.RegisterTask(1, int32(taskConcurrency)) + for i := 0; i < taskConcurrency; i++ { + tid := NewTaskBox[int, int, int, any, NilContext](1, NilContext{}, &sync.WaitGroup{}, make(chan Task[int]), make(chan int), 1) + tm.AddSubTask(1, &tid) + } + for i := 0; i < taskConcurrency; i++ { + tm.ExitSubTask(1) + } + require.Equal(t, int32(0), tm.Running(1)) + require.True(t, tm.hasTask(1)) + tm.DeleteTask(1) + require.False(t, tm.hasTask(1)) +} diff --git a/resourcemanager/scheduler/cpu_scheduler.go b/resourcemanager/scheduler/cpu_scheduler.go index 217c5aecbf1dd..c84fcf36fb697 100644 --- a/resourcemanager/scheduler/cpu_scheduler.go +++ b/resourcemanager/scheduler/cpu_scheduler.go @@ -35,10 +35,10 @@ func (*CPUScheduler) Tune(_ util.Component, pool util.GorotinuePool) Command { return Hold } if cpu.GetCPUUsage() < 0.5 { - return Downclock + return Overclock } if cpu.GetCPUUsage() > 0.7 { - return Overclock + return Downclock } return Hold } diff --git a/sessionctx/context.go b/sessionctx/context.go index 0e38fbdaba3d5..0999b2396cae0 100644 --- a/sessionctx/context.go +++ b/sessionctx/context.go @@ -54,8 +54,8 @@ type SessionStatesHandler interface { // PlanCache is an interface for prepare and non-prepared plan cache type PlanCache interface { - Get(key kvcache.Key, paramTypes []*types.FieldType) (value kvcache.Value, ok bool) - Put(key kvcache.Key, value kvcache.Value, paramTypes []*types.FieldType) + Get(key kvcache.Key, paramTypes []*types.FieldType, limitParams []uint64) (value kvcache.Value, ok bool) + Put(key kvcache.Key, value kvcache.Value, paramTypes []*types.FieldType, limitParams []uint64) Delete(key kvcache.Key) DeleteAll() Size() int diff --git a/statistics/handle/BUILD.bazel b/statistics/handle/BUILD.bazel index 81dff92c5b143..d52847495d539 100644 --- a/statistics/handle/BUILD.bazel +++ b/statistics/handle/BUILD.bazel @@ -42,6 +42,7 @@ go_library( "//util/memory", "//util/ranger", "//util/sqlexec", + "//util/syncutil", "//util/timeutil", "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", @@ -72,6 +73,7 @@ go_test( ], embed = [":handle"], flaky = True, + race = "on", shard_count = 50, deps = [ "//config", diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 79a5382779208..fc4f86dc54fb8 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -48,6 +48,7 @@ import ( "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/syncutil" "github.com/prometheus/client_golang/prometheus" "github.com/tikv/client-go/v2/oracle" atomic2 "go.uber.org/atomic" @@ -70,7 +71,7 @@ type Handle struct { initStatsCtx sessionctx.Context mu struct { - sync.RWMutex + syncutil.RWMutex ctx sessionctx.Context // rateMap contains the error rate delta from feedback. 
rateMap errorRateDeltaMap @@ -361,8 +362,15 @@ func (h *Handle) RemoveLockedTables(tids []int64, pids []int64, tables []*ast.Ta return "", err } -// IsTableLocked check whether table is locked in handle +// IsTableLocked check whether table is locked in handle with Handle.Mutex func (h *Handle) IsTableLocked(tableID int64) bool { + h.mu.RLock() + defer h.mu.RUnlock() + return h.isTableLocked(tableID) +} + +// IsTableLocked check whether table is locked in handle without Handle.Mutex +func (h *Handle) isTableLocked(tableID int64) bool { return isTableLocked(h.tableLocked, tableID) } diff --git a/statistics/handle/update.go b/statistics/handle/update.go index 68aa9cebbcf05..e245a3ea0bca5 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -549,7 +549,8 @@ func (h *Handle) dumpTableStatCountToKV(id int64, delta variable.TableDelta) (up startTS := txn.StartTS() updateStatsMeta := func(id int64) error { var err error - if h.IsTableLocked(id) { + // This lock is already locked on it so it use isTableLocked without lock. + if h.isTableLocked(id) { if delta.Delta < 0 { _, err = exec.ExecuteInternal(ctx, "update mysql.stats_table_locked set version = %?, count = count - %?, modify_count = modify_count + %? where table_id = %? and count >= %?", startTS, -delta.Delta, delta.Count, id, -delta.Delta) } else { diff --git a/store/copr/BUILD.bazel b/store/copr/BUILD.bazel index 9ea8467d01dfa..a7cdd81453fd7 100644 --- a/store/copr/BUILD.bazel +++ b/store/copr/BUILD.bazel @@ -34,6 +34,7 @@ go_library( "//util/trxevents", "@com_github_dgraph_io_ristretto//:ristretto", "@com_github_gogo_protobuf//proto", + "@com_github_opentracing_opentracing_go//:opentracing-go", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/coprocessor", diff --git a/store/copr/coprocessor.go b/store/copr/coprocessor.go index 08f0f055a31e3..f446912ef33af 100644 --- a/store/copr/coprocessor.go +++ b/store/copr/coprocessor.go @@ -26,6 +26,7 @@ import ( "unsafe" "github.com/gogo/protobuf/proto" + "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/coprocessor" @@ -384,12 +385,20 @@ func buildCopTasks(bo *Backoffer, cache *RegionCache, ranges *KeyRanges, req *kv builder.reverse() } tasks := builder.build() - if elapsed := time.Since(start); elapsed > time.Millisecond*500 { + elapsed := time.Since(start) + if elapsed > time.Millisecond*500 { logutil.BgLogger().Warn("buildCopTasks takes too much time", zap.Duration("elapsed", elapsed), zap.Int("range len", rangesLen), zap.Int("task len", len(tasks))) } + if elapsed > time.Millisecond { + ctx := bo.GetCtx() + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("copr.buildCopTasks", opentracing.ChildOf(span.Context()), opentracing.StartTime(start)) + defer span1.Finish() + } + } metrics.TxnRegionsNumHistogramWithCoprocessor.Observe(float64(builder.regionNum())) return tasks, nil } diff --git a/util/gpool/spmc/BUILD.bazel b/util/gpool/spmc/BUILD.bazel index 1c951a219fb20..db4d724052666 100644 --- a/util/gpool/spmc/BUILD.bazel +++ b/util/gpool/spmc/BUILD.bazel @@ -32,6 +32,7 @@ go_test( "worker_loop_queue_test.go", ], embed = [":spmc"], + flaky = True, race = "on", deps = [ "//resourcemanager/pooltask", diff --git a/util/gpool/spmc/spmcpool.go b/util/gpool/spmc/spmcpool.go index b8cecb289c0e5..abef899961657 100644 --- a/util/gpool/spmc/spmcpool.go 
+++ b/util/gpool/spmc/spmcpool.go @@ -44,6 +44,7 @@ type Pool[T any, U any, C any, CT any, TF pooltask.Context[CT]] struct { lock sync.Locker cond *sync.Cond taskCh chan *pooltask.TaskBox[T, U, C, CT, TF] + taskManager pooltask.TaskManager[T, U, C, CT, TF] options *Options stopCh chan struct{} consumerFunc func(T, C, CT) U @@ -63,11 +64,12 @@ func NewSPMCPool[T any, U any, C any, CT any, TF pooltask.Context[CT]](name stri opts.ExpiryDuration = gpool.DefaultCleanIntervalTime } result := &Pool[T, U, C, CT, TF]{ - BasePool: gpool.NewBasePool(), - taskCh: make(chan *pooltask.TaskBox[T, U, C, CT, TF], 128), - stopCh: make(chan struct{}), - lock: gpool.NewSpinLock(), - options: opts, + BasePool: gpool.NewBasePool(), + taskCh: make(chan *pooltask.TaskBox[T, U, C, CT, TF], 128), + stopCh: make(chan struct{}), + lock: gpool.NewSpinLock(), + taskManager: pooltask.NewTaskManager[T, U, C, CT, TF](size), + options: opts, } result.SetName(name) result.state.Store(int32(gpool.OPENED)) @@ -247,6 +249,7 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error), closeCh := make(chan struct{}) inputCh := make(chan pooltask.Task[T], opt.TaskChanLen) tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result) + p.taskManager.RegisterTask(taskID, int32(opt.Concurrency)) for i := 0; i < opt.Concurrency; i++ { err := p.run() if err == gpool.ErrPoolClosed { @@ -254,6 +257,7 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error), } taskBox := pooltask.NewTaskBox[T, U, C, CT, TF](constArg, contextFn, &wg, inputCh, result, taskID) p.addWaitingTask() + p.taskManager.AddSubTask(taskID, &taskBox) p.taskCh <- &taskBox } go func() { @@ -294,6 +298,7 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg result := make(chan U, opt.ResultChanLen) closeCh := make(chan struct{}) inputCh := make(chan pooltask.Task[T], opt.TaskChanLen) + p.taskManager.RegisterTask(taskID, int32(opt.Concurrency)) tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result) for i := 0; i < opt.Concurrency; i++ { err := p.run() @@ -302,6 +307,7 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg } p.addWaitingTask() taskBox := pooltask.NewTaskBox[T, U, C, CT, TF](constArg, contextFn, &wg, inputCh, result, taskID) + p.taskManager.AddSubTask(taskID, &taskBox) p.taskCh <- &taskBox } go func() { @@ -426,3 +432,20 @@ func (p *Pool[T, U, C, CT, TF]) revertWorker(worker *goWorker[T, U, C, CT, TF]) p.lock.Unlock() return true } + +// DeleteTask is to delete task. +// Please don't use it manually. +func (p *Pool[T, U, C, CT, TF]) DeleteTask(id uint64) { + p.taskManager.DeleteTask(id) +} + +// StopTask is to stop task by id +// Please don't use it manually. +func (p *Pool[T, U, C, CT, TF]) StopTask(id uint64) { + p.taskManager.StopTask(id) +} + +// ExitSubTask is to reduce the number of subtasks. 
+func (p *Pool[T, U, C, CT, TF]) ExitSubTask(id uint64) { + p.taskManager.ExitSubTask(id) +} diff --git a/util/gpool/spmc/spmcpool_test.go b/util/gpool/spmc/spmcpool_test.go index cef958e52ad0f..1106e7bd98f69 100644 --- a/util/gpool/spmc/spmcpool_test.go +++ b/util/gpool/spmc/spmcpool_test.go @@ -76,6 +76,49 @@ func TestPool(t *testing.T) { pool.ReleaseAndWait() } +func TestStopPool(t *testing.T) { + type ConstArgs struct { + a int + } + myArgs := ConstArgs{a: 10} + // init the pool + // input type, output type, constArgs type + pool, err := NewSPMCPool[int, int, ConstArgs, any, pooltask.NilContext]("TestPool", 10, rmutil.UNKNOWN) + require.NoError(t, err) + pool.SetConsumerFunc(func(task int, constArgs ConstArgs, ctx any) int { + return task + constArgs.a + }) + + exit := make(chan struct{}) + + pfunc := func() (int, error) { + select { + case <-exit: + return 0, gpool.ErrProducerClosed + default: + return 1, nil + } + } + // add new task + resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4)) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for result := range resultCh { + require.Greater(t, result, 10) + } + }() + // Waiting task finishing + control.Stop() + close(exit) + control.Wait() + wg.Wait() + // close pool + pool.ReleaseAndWait() +} + func TestPoolWithEnoughCapacity(t *testing.T) { const ( RunTimes = 1000 diff --git a/util/gpool/spmc/worker.go b/util/gpool/spmc/worker.go index 32ff56a790dbd..b8e22376bb79a 100644 --- a/util/gpool/spmc/worker.go +++ b/util/gpool/spmc/worker.go @@ -58,14 +58,23 @@ func (w *goWorker[T, U, C, CT, TF]) run() { if f == nil { return } + if f.GetStatus() == pooltask.PendingTask { + f.SetStatus(pooltask.RunningTask) + } w.pool.subWaitingTask() ctx := f.GetContextFunc().GetContext() if f.GetResultCh() != nil { for t := range f.GetTaskCh() { + if f.GetStatus() == pooltask.StopTask { + f.Done() + continue + } f.GetResultCh() <- w.pool.consumerFunc(t.Task, f.ConstArgs(), ctx) f.Done() } + w.pool.ExitSubTask(f.TaskID()) } + f.Finish() if ok := w.pool.revertWorker(w); !ok { return }