diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go index f2b421254f674..1e84d492e2b15 100644 --- a/pkg/ddl/index.go +++ b/pkg/ddl/index.go @@ -2118,6 +2118,8 @@ func (w *worker) executeDistTask(reorgInfo *reorgInfo) error { job := reorgInfo.Job workerCntLimit := int(variable.GetDDLReorgWorkerCounter()) + // we're using cpu count of current node, not of framework managed nodes, + // but it seems more intuitive. concurrency := min(workerCntLimit, cpu.GetCPUCount()) logutil.BgLogger().Info("adjusted add-index task concurrency", zap.Int("worker-cnt", workerCntLimit), zap.Int("task-concurrency", concurrency), diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index 567f9719c3a0a..5c24654bcec26 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -210,10 +210,10 @@ func (mr *MockTaskManagerMockRecorder) GCSubtasks(arg0 any) *gomock.Call { } // GetAllNodes mocks base method. -func (m *MockTaskManager) GetAllNodes(arg0 context.Context) ([]string, error) { +func (m *MockTaskManager) GetAllNodes(arg0 context.Context) ([]proto.ManagedNode, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAllNodes", arg0) - ret0, _ := ret[0].([]string) + ret0, _ := ret[0].([]proto.ManagedNode) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -225,10 +225,10 @@ func (mr *MockTaskManagerMockRecorder) GetAllNodes(arg0 any) *gomock.Call { } // GetManagedNodes mocks base method. -func (m *MockTaskManager) GetManagedNodes(arg0 context.Context) ([]string, error) { +func (m *MockTaskManager) GetManagedNodes(arg0 context.Context) ([]proto.ManagedNode, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetManagedNodes", arg0) - ret0, _ := ret[0].([]string) + ret0, _ := ret[0].([]proto.ManagedNode) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/disttask/framework/proto/BUILD.bazel b/pkg/disttask/framework/proto/BUILD.bazel index 5b38a5bf766ad..8359f0783ea8d 100644 --- a/pkg/disttask/framework/proto/BUILD.bazel +++ b/pkg/disttask/framework/proto/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "proto", srcs = [ + "node.go", "subtask.go", "task.go", ], diff --git a/pkg/disttask/framework/proto/node.go b/pkg/disttask/framework/proto/node.go new file mode 100644 index 0000000000000..aabe61479d464 --- /dev/null +++ b/pkg/disttask/framework/proto/node.go @@ -0,0 +1,25 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proto + +// ManagedNode is a TiDB node that is managed by the framework. +type ManagedNode struct { + // ID see GenerateExecID, it's named as host in the meta table. 
+ ID string + // Role of the node, either "" or "background" + // all managed node should have the same role + Role string + CPUCount int +} diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index 53b68f930335d..dd50741076b86 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -50,7 +50,7 @@ go_test( embed = [":scheduler"], flaky = True, race = "off", - shard_count = 24, + shard_count = 25, deps = [ "//pkg/config", "//pkg/disttask/framework/mock", @@ -63,6 +63,7 @@ go_test( "//pkg/sessionctx", "//pkg/testkit", "//pkg/testkit/testsetup", + "//pkg/util/cpu", "//pkg/util/disttask", "//pkg/util/logutil", "//pkg/util/sqlexec", diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index 0dbc5d9cd4a94..d0c792d1a10dc 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -33,7 +33,7 @@ type TaskManager interface { GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error) GCSubtasks(ctx context.Context) error - GetAllNodes(ctx context.Context) ([]string, error) + GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) DeleteDeadNodes(ctx context.Context, nodes []string) error TransferTasks2History(ctx context.Context, tasks []*proto.Task) error CancelTask(ctx context.Context, taskID int64) error @@ -68,7 +68,7 @@ type TaskManager interface { // to execute tasks. If there are any nodes with background role, we use them, // else we use nodes without role. // returned nodes are sorted by node id(host:port). 
- GetManagedNodes(ctx context.Context) ([]string, error) + GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) diff --git a/pkg/disttask/framework/scheduler/nodes.go b/pkg/disttask/framework/scheduler/nodes.go index 47bea22172e52..d87cbed69ad3c 100644 --- a/pkg/disttask/framework/scheduler/nodes.go +++ b/pkg/disttask/framework/scheduler/nodes.go @@ -91,9 +91,9 @@ func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManage } deadNodes := make([]string, 0) - for _, nodeID := range oldNodes { - if _, ok := currLiveNodes[nodeID]; !ok { - deadNodes = append(deadNodes, nodeID) + for _, node := range oldNodes { + if _, ok := currLiveNodes[node.ID]; !ok { + deadNodes = append(deadNodes, node.ID) } } if len(deadNodes) == 0 { @@ -110,7 +110,7 @@ func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManage nm.prevLiveNodes = currLiveNodes } -func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr TaskManager) { +func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr TaskManager, slotMgr *slotManager) { ticker := time.NewTicker(nodesCheckInterval) defer ticker.Stop() for { @@ -118,22 +118,28 @@ func (nm *NodeManager) refreshManagedNodesLoop(ctx context.Context, taskMgr Task case <-ctx.Done(): return case <-ticker.C: - nm.refreshManagedNodes(ctx, taskMgr) + nm.refreshManagedNodes(ctx, taskMgr, slotMgr) } } } // refreshManagedNodes maintains the nodes managed by the framework. -func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskManager) { +func (nm *NodeManager) refreshManagedNodes(ctx context.Context, taskMgr TaskManager, slotMgr *slotManager) { newNodes, err := taskMgr.GetManagedNodes(ctx) if err != nil { logutil.BgLogger().Warn("get managed nodes met error", log.ShortError(err)) return } - if newNodes == nil { - newNodes = []string{} + nodeIDs := make([]string, 0, len(newNodes)) + var cpuCount int + for _, node := range newNodes { + nodeIDs = append(nodeIDs, node.ID) + if node.CPUCount > 0 { + cpuCount = node.CPUCount + } } - nm.managedNodes.Store(&newNodes) + slotMgr.updateCapacity(cpuCount) + nm.managedNodes.Store(&nodeIDs) } // GetManagedNodes returns the nodes managed by the framework. 
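The refreshManagedNodes change above folds the managed-node list into a single slot capacity: it takes a positive CPUCount reported by any managed node (the framework assumes managed nodes are homogeneous) and leaves the previous capacity untouched when no node reports one, since updateCapacity ignores non-positive values. A minimal, self-contained sketch of that derivation follows; capacityOf and the local ManagedNode type are hypothetical stand-ins for illustration, not code from this patch.

package main

import "fmt"

// ManagedNode mirrors the fields used by the scheduler in this patch.
type ManagedNode struct {
	ID       string
	Role     string
	CPUCount int
}

// capacityOf returns a positive CPU count reported by any node, or 0 when
// none is available (in which case the caller keeps its previous capacity).
// Managed nodes are assumed homogeneous, so any positive value is acceptable.
func capacityOf(nodes []ManagedNode) int {
	for _, n := range nodes {
		if n.CPUCount > 0 {
			return n.CPUCount
		}
	}
	return 0
}

func main() {
	nodes := []ManagedNode{
		{ID: ":4000", Role: "background", CPUCount: 0},  // not yet reporting
		{ID: ":4001", Role: "background", CPUCount: 16}, // reports 16 cores
	}
	fmt.Println(capacityOf(nodes)) // 16
}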
diff --git a/pkg/disttask/framework/scheduler/nodes_test.go b/pkg/disttask/framework/scheduler/nodes_test.go index e6ff715773de3..04f4a52db0b1a 100644 --- a/pkg/disttask/framework/scheduler/nodes_test.go +++ b/pkg/disttask/framework/scheduler/nodes_test.go @@ -21,7 +21,9 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/disttask/framework/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/util/cpu" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -47,7 +49,7 @@ func TestMaintainLiveNodes(t *testing.T) { require.Empty(t, nodeMgr.prevLiveNodes) require.True(t, ctrl.Satisfied()) // no change - mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000"}, nil) + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}}, nil) nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) require.True(t, ctrl.Satisfied()) @@ -63,13 +65,13 @@ func TestMaintainLiveNodes(t *testing.T) { } // fail on clean - mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil) + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}, {ID: ":4001"}, {ID: ":4002"}}, nil) mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(errors.New("mock error")) nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) require.True(t, ctrl.Satisfied()) // remove 1 node - mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil) + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}, {ID: ":4001"}, {ID: ":4002"}}, nil) mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(nil) nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) require.Equal(t, map[string]struct{}{":4000": {}, ":4001": {}}, nodeMgr.prevLiveNodes) @@ -84,7 +86,7 @@ func TestMaintainLiveNodes(t *testing.T) { {Port: 4000}, } - mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]string{":4000", ":4001", ":4002"}, nil) + mockTaskMgr.EXPECT().GetAllNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: ":4000"}, {ID: ":4001"}, {ID: ":4002"}}, nil) mockTaskMgr.EXPECT().DeleteDeadNodes(gomock.Any(), gomock.Any()).Return(nil) nodeMgr.maintainLiveNodes(ctx, mockTaskMgr) require.Equal(t, map[string]struct{}{":4000": {}}, nodeMgr.prevLiveNodes) @@ -102,18 +104,25 @@ func TestMaintainManagedNodes(t *testing.T) { mockTaskMgr := mock.NewMockTaskManager(ctrl) nodeMgr := newNodeManager() + slotMgr := newSlotManager() mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, errors.New("mock error")) - nodeMgr.refreshManagedNodes(ctx, mockTaskMgr) + nodeMgr.refreshManagedNodes(ctx, mockTaskMgr, slotMgr) + require.Equal(t, cpu.GetCPUCount(), int(slotMgr.capacity.Load())) require.Empty(t, nodeMgr.getManagedNodes()) require.True(t, ctrl.Satisfied()) - mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{":4000", ":4001"}, nil) - nodeMgr.refreshManagedNodes(ctx, mockTaskMgr) + mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{ + {ID: ":4000", CPUCount: 100}, + {ID: ":4001", CPUCount: 100}, + }, nil) + nodeMgr.refreshManagedNodes(ctx, mockTaskMgr, slotMgr) require.Equal(t, []string{":4000", ":4001"}, nodeMgr.getManagedNodes()) + 
require.Equal(t, 100, int(slotMgr.capacity.Load())) require.True(t, ctrl.Satisfied()) mockTaskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, nil) - nodeMgr.refreshManagedNodes(ctx, mockTaskMgr) + nodeMgr.refreshManagedNodes(ctx, mockTaskMgr, slotMgr) require.NotNil(t, nodeMgr.getManagedNodes()) require.Empty(t, nodeMgr.getManagedNodes()) + require.Equal(t, 100, int(slotMgr.capacity.Load())) require.True(t, ctrl.Satisfied()) } diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index a8baac9fd1a0a..9e50a5c8875e4 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -123,7 +123,7 @@ func (sm *Manager) Start() { failpoint.Return() }) // init cached managed nodes - sm.nodeMgr.refreshManagedNodes(sm.ctx, sm.taskMgr) + sm.nodeMgr.refreshManagedNodes(sm.ctx, sm.taskMgr, sm.slotMgr) sm.wg.Run(sm.scheduleTaskLoop) sm.wg.Run(sm.gcSubtaskHistoryTableLoop) @@ -132,7 +132,7 @@ func (sm *Manager) Start() { sm.nodeMgr.maintainLiveNodesLoop(sm.ctx, sm.taskMgr) }) sm.wg.Run(func() { - sm.nodeMgr.refreshManagedNodesLoop(sm.ctx, sm.taskMgr) + sm.nodeMgr.refreshManagedNodesLoop(sm.ctx, sm.taskMgr, sm.slotMgr) }) sm.initialized = true } diff --git a/pkg/disttask/framework/scheduler/slots.go b/pkg/disttask/framework/scheduler/slots.go index 01de011de354b..528dd45ad48d8 100644 --- a/pkg/disttask/framework/scheduler/slots.go +++ b/pkg/disttask/framework/scheduler/slots.go @@ -18,9 +18,12 @@ import ( "context" "slices" "sync" + "sync/atomic" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/util/cpu" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" ) type taskStripes struct { @@ -47,10 +50,7 @@ type taskStripes struct { // quota to subtask, but subtask can determine what to conform. type slotManager struct { // Capacity is the total number of slots and stripes. - // TODO: we assume that all nodes managed by dist framework are isomorphic, - // but dist owner might run on normal node where the capacity might not be - // able to run any task. - capacity int + capacity atomic.Int32 mu sync.RWMutex // represents the number of stripes reserved by task, when we reserve by the @@ -75,12 +75,16 @@ type slotManager struct { // newSlotManager creates a new slotManager. func newSlotManager() *slotManager { - return &slotManager{ - capacity: cpu.GetCPUCount(), + s := &slotManager{ task2Index: make(map[int64]int), reservedSlots: make(map[string]int), usedSlots: make(map[string]int), } + // this node might not be the managed node of the framework, but we initialize + // capacity with the cpu count of this node, it will be updated when node + // manager starts. + s.updateCapacity(cpu.GetCPUCount()) + return s } // Update updates the used slots on each node. @@ -96,7 +100,7 @@ func (sm *slotManager) update(ctx context.Context, taskMgr TaskManager) error { } newUsedSlots := make(map[string]int, len(nodes)) for _, node := range nodes { - newUsedSlots[node] = slotsOnNodes[node] + newUsedSlots[node.ID] = slotsOnNodes[node.ID] } sm.mu.Lock() defer sm.mu.Unlock() @@ -111,6 +115,7 @@ func (sm *slotManager) update(ctx context.Context, taskMgr TaskManager) error { // are enough resources, or return true on resource shortage when some task // scheduled subtasks. 
func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) { + capacity := int(sm.capacity.Load()) sm.mu.RLock() defer sm.mu.RUnlock() if len(sm.usedSlots) == 0 { @@ -125,12 +130,12 @@ func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) { } reservedForHigherPriority += s.stripes } - if task.Concurrency+reservedForHigherPriority <= sm.capacity { + if task.Concurrency+reservedForHigherPriority <= capacity { return "", true } for id, count := range sm.usedSlots { - if count+sm.reservedSlots[id]+task.Concurrency <= sm.capacity { + if count+sm.reservedSlots[id]+task.Concurrency <= capacity { return id, true } } @@ -178,3 +183,16 @@ func (sm *slotManager) unReserve(task *proto.Task, execID string) { } } } + +func (sm *slotManager) updateCapacity(cpuCount int) { + old := sm.capacity.Load() + if cpuCount > 0 && cpuCount != int(old) { + sm.capacity.Store(int32(cpuCount)) + if old == 0 { + logutil.BgLogger().Info("initialize slot capacity", zap.Int("capacity", cpuCount)) + } else { + logutil.BgLogger().Info("update slot capacity", + zap.Int("old", int(old)), zap.Int("new", cpuCount)) + } + } +} diff --git a/pkg/disttask/framework/scheduler/slots_test.go b/pkg/disttask/framework/scheduler/slots_test.go index 8363c7df381a9..04b72c6fdac9b 100644 --- a/pkg/disttask/framework/scheduler/slots_test.go +++ b/pkg/disttask/framework/scheduler/slots_test.go @@ -28,7 +28,7 @@ import ( func TestSlotManagerReserve(t *testing.T) { sm := newSlotManager() - sm.capacity = 16 + sm.updateCapacity(16) // no node _, ok := sm.canReserve(&proto.Task{Concurrency: 1}) require.False(t, ok) @@ -181,13 +181,13 @@ func TestSlotManagerUpdate(t *testing.T) { defer ctrl.Finish() taskMgr := mock.NewMockTaskManager(ctrl) - taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{"tidb-1", "tidb-2", "tidb-3"}, nil) + taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: "tidb-1"}, {ID: "tidb-2"}, {ID: "tidb-3"}}, nil) taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(map[string]int{ "tidb-1": 12, "tidb-2": 8, }, nil) sm := newSlotManager() - sm.capacity = 16 + sm.updateCapacity(16) require.Empty(t, sm.usedSlots) require.Empty(t, sm.reservedSlots) require.NoError(t, sm.update(context.Background(), taskMgr)) @@ -198,7 +198,7 @@ func TestSlotManagerUpdate(t *testing.T) { "tidb-3": 0, }, sm.usedSlots) // some node scaled in, should be reflected - taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{"tidb-1"}, nil) + taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: "tidb-1"}}, nil) taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(map[string]int{ "tidb-1": 12, "tidb-2": 8, @@ -215,7 +215,7 @@ func TestSlotManagerUpdate(t *testing.T) { require.Equal(t, map[string]int{ "tidb-1": 12, }, sm.usedSlots) - taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{"tidb-1"}, nil) + taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]proto.ManagedNode{{ID: "tidb-1"}}, nil) taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, errors.New("mock err")) require.ErrorContains(t, sm.update(context.Background(), taskMgr), "mock err") require.Empty(t, sm.reservedSlots) @@ -223,3 +223,13 @@ func TestSlotManagerUpdate(t *testing.T) { "tidb-1": 12, }, sm.usedSlots) } + +func TestSlotManagerUpdateCapacity(t *testing.T) { + sm := newSlotManager() + sm.updateCapacity(16) + require.Equal(t, 16, int(sm.capacity.Load())) + sm.updateCapacity(32) + require.Equal(t, 32, int(sm.capacity.Load())) + sm.updateCapacity(0) + 
require.Equal(t, 32, int(sm.capacity.Load())) +} diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 4a7fd27039ca3..5f0d56d3b45f2 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -50,7 +50,9 @@ func checkTaskStateStep(t *testing.T, task *proto.Task, state proto.TaskState, s } func TestTaskTable(t *testing.T) { - gm, ctx := testutil.InitTableTest(t) + _, gm, ctx := testutil.InitTableTest(t) + + require.NoError(t, gm.StartManager(ctx, ":4000", "")) _, err := gm.CreateTask(ctx, "key1", "test", 999, []byte("test")) require.ErrorContains(t, err, "task concurrency(999) larger than cpu count") @@ -184,24 +186,10 @@ func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, s } func TestSwitchTaskStep(t *testing.T) { - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) - }() - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)")) - t.Cleanup(func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) - }) - store := testkit.CreateMockStore(t) + store, tm, ctx := testutil.InitTableTest(t) tk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return tk.Session(), nil - }, 1, 1, time.Second) - tm := GetTaskManager(t, pool) - defer pool.Close() - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "table_test") + require.NoError(t, tm.StartManager(ctx, ":4000", "")) taskID, err := tm.CreateTask(ctx, "key1", "test", 4, []byte("test")) require.NoError(t, err) task, err := tm.GetTaskByID(ctx, taskID) @@ -245,25 +233,10 @@ func TestSwitchTaskStep(t *testing.T) { } func TestSwitchTaskStepInBatch(t *testing.T) { - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) - }() - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)")) - t.Cleanup(func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) - }) - - store := testkit.CreateMockStore(t) + store, tm, ctx := testutil.InitTableTest(t) tk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return tk.Session(), nil - }, 1, 1, time.Second) - tm := GetTaskManager(t, pool) - defer pool.Close() - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "table_test") + require.NoError(t, tm.StartManager(ctx, ":4000", "")) // normal flow prepare := func(taskKey string) (*proto.Task, []*proto.Subtask) { taskID, err := tm.CreateTask(ctx, taskKey, "test", 4, []byte("test")) @@ -336,8 +309,9 @@ func TestSwitchTaskStepInBatch(t *testing.T) { } func TestGetTopUnfinishedTasks(t *testing.T) { - gm, ctx := testutil.InitTableTest(t) + _, gm, ctx := testutil.InitTableTest(t) + require.NoError(t, gm.StartManager(ctx, ":4000", "")) taskStates := []proto.TaskState{ proto.TaskStateSucceed, proto.TaskStatePending, @@ -401,7 +375,7 @@ func TestGetTopUnfinishedTasks(t *testing.T) { } func TestGetUsedSlotsOnNodes(t *testing.T) { - sm, ctx := testutil.InitTableTest(t) + _, sm, ctx := 
testutil.InitTableTest(t) testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-1", []byte(""), proto.TaskStateRunning, "test", 12) testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-2", []byte(""), proto.TaskStatePending, "test", 12) @@ -418,8 +392,9 @@ func TestGetUsedSlotsOnNodes(t *testing.T) { } func TestSubTaskTable(t *testing.T) { - sm, ctx := testutil.InitTableTest(t) + _, sm, ctx := testutil.InitTableTest(t) timeBeforeCreate := time.Unix(time.Now().Unix(), 0) + require.NoError(t, sm.StartManager(ctx, ":4000", "")) id, err := sm.CreateTask(ctx, "key1", "test", 4, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) @@ -648,7 +623,8 @@ func TestSubTaskTable(t *testing.T) { } func TestBothTaskAndSubTaskTable(t *testing.T) { - sm, ctx := testutil.InitTableTest(t) + _, sm, ctx := testutil.InitTableTest(t) + require.NoError(t, sm.StartManager(ctx, ":4000", "")) id, err := sm.CreateTask(ctx, "key1", "test", 4, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) @@ -761,42 +737,75 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { } func TestDistFrameworkMeta(t *testing.T) { - // to avoid inserted nodes be cleaned by scheduler - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) - }() - sm, ctx := testutil.InitTableTest(t) + _, sm, ctx := testutil.InitTableTest(t) + + // when no node + _, err := storage.GetCPUCountOfManagedNodes(ctx, sm) + require.ErrorContains(t, err, "no managed nodes") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(0)")) + require.NoError(t, sm.StartManager(ctx, ":4000", "background")) + cpuCount, err := storage.GetCPUCountOfManagedNodes(ctx, sm) + require.NoError(t, err) + require.Equal(t, 0, cpuCount) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(100)")) require.NoError(t, sm.StartManager(ctx, ":4000", "background")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)")) require.NoError(t, sm.StartManager(ctx, ":4001", "")) require.NoError(t, sm.StartManager(ctx, ":4002", "background")) - // won't be replaced by below one + nodes, err := sm.GetAllNodes(ctx) + require.NoError(t, err) + require.Equal(t, []proto.ManagedNode{ + {ID: ":4000", Role: "background", CPUCount: 100}, + {ID: ":4001", Role: "", CPUCount: 8}, + {ID: ":4002", Role: "background", CPUCount: 8}, + }, nodes) + + // won't be replaced by below one, but cpu count will be updated + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(100)")) require.NoError(t, sm.StartManager(ctx, ":4002", "")) require.NoError(t, sm.StartManager(ctx, ":4003", "background")) - nodes, err := sm.GetAllNodes(ctx) + nodes, err = sm.GetAllNodes(ctx) + require.NoError(t, err) + require.Equal(t, []proto.ManagedNode{ + {ID: ":4000", Role: "background", CPUCount: 100}, + {ID: ":4001", Role: "", CPUCount: 8}, + {ID: ":4002", Role: "background", CPUCount: 100}, + {ID: ":4003", Role: "background", CPUCount: 100}, + }, nodes) + cpuCount, err = storage.GetCPUCountOfManagedNodes(ctx, sm) require.NoError(t, err) - require.Equal(t, []string{":4000", ":4001", ":4002", ":4003"}, nodes) + require.Equal(t, 100, cpuCount) require.NoError(t, sm.DeleteDeadNodes(ctx, []string{":4000"})) nodes, err = sm.GetManagedNodes(ctx) 
require.NoError(t, err) - require.Equal(t, []string{":4002", ":4003"}, nodes) + require.Equal(t, []proto.ManagedNode{ + {ID: ":4002", Role: "background", CPUCount: 100}, + {ID: ":4003", Role: "background", CPUCount: 100}, + }, nodes) require.NoError(t, sm.DeleteDeadNodes(ctx, []string{":4003"})) nodes, err = sm.GetManagedNodes(ctx) require.NoError(t, err) - require.Equal(t, []string{":4002"}, nodes) + require.Equal(t, []proto.ManagedNode{ + {ID: ":4002", Role: "background", CPUCount: 100}, + }, nodes) require.NoError(t, sm.DeleteDeadNodes(ctx, []string{":4002"})) nodes, err = sm.GetManagedNodes(ctx) require.NoError(t, err) - require.Equal(t, []string{":4001"}, nodes) + require.Equal(t, []proto.ManagedNode{ + {ID: ":4001", Role: "", CPUCount: 8}, + }, nodes) + cpuCount, err = storage.GetCPUCountOfManagedNodes(ctx, sm) + require.NoError(t, err) + require.Equal(t, 8, cpuCount) } func TestSubtaskHistoryTable(t *testing.T) { - sm, ctx := testutil.InitTableTest(t) + _, sm, ctx := testutil.InitTableTest(t) const ( taskID = 1 @@ -861,8 +870,9 @@ func TestSubtaskHistoryTable(t *testing.T) { } func TestTaskHistoryTable(t *testing.T) { - gm, ctx := testutil.InitTableTest(t) + _, gm, ctx := testutil.InitTableTest(t) + require.NoError(t, gm.StartManager(ctx, ":4000", "")) _, err := gm.CreateTask(ctx, "1", proto.TaskTypeExample, 1, nil) require.NoError(t, err) taskID, err := gm.CreateTask(ctx, "2", proto.TaskTypeExample, 1, nil) @@ -903,7 +913,7 @@ func TestTaskHistoryTable(t *testing.T) { } func TestPauseAndResume(t *testing.T) { - sm, ctx := testutil.InitTableTest(t) + _, sm, ctx := testutil.InitTableTest(t) testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) @@ -920,7 +930,7 @@ func TestPauseAndResume(t *testing.T) { require.Equal(t, int64(3), cnt) // 2.1 pause 2 subtasks. - sm.UpdateSubtaskStateAndError(ctx, "tidb1", 1, proto.SubtaskStateSucceed, nil) + require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", 1, proto.SubtaskStateSucceed, nil)) require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) cnt, err = sm.GetSubtaskInStatesCnt(ctx, 1, proto.SubtaskStatePaused) require.NoError(t, err) diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 8c3bfbe1ccb3f..f9b7fa141c082 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -221,14 +221,13 @@ func (stm *TaskManager) CreateTask(ctx context.Context, key string, tp proto.Tas } // CreateTaskWithSession adds a new task to task table with session. -func (*TaskManager) CreateTaskWithSession(ctx context.Context, se sessionctx.Context, key string, tp proto.TaskType, concurrency int, meta []byte) (taskID int64, err error) { - cpuCount := cpu.GetCPUCount() +func (stm *TaskManager) CreateTaskWithSession(ctx context.Context, se sessionctx.Context, key string, tp proto.TaskType, concurrency int, meta []byte) (taskID int64, err error) { + cpuCount, err := stm.getCPUCountOfManagedNodes(ctx, se) + if err != nil { + return 0, err + } if concurrency > cpuCount { - // current resource control cannot schedule tasks with concurrency larger - // than cpu count - // TODO: if we are submitting a task on a node that is not managed by - // disttask framework, the checked cpu-count might not right. 
- return 0, errors.Errorf("task concurrency(%d) larger than cpu count(%d)", concurrency, cpuCount) + return 0, errors.Errorf("task concurrency(%d) larger than cpu count(%d) of managed node", concurrency, cpuCount) } _, err = sqlexec.ExecSQL(ctx, se, ` insert into mysql.tidb_global_task(`+InsertTaskColumns+`) @@ -691,10 +690,21 @@ func (stm *TaskManager) StartSubtask(ctx context.Context, subtaskID int64, execI } // StartManager insert the manager information into dist_framework_meta. -func (stm *TaskManager) StartManager(ctx context.Context, execID string, role string) error { - _, err := stm.executeSQLWithNewSession(ctx, `insert into mysql.dist_framework_meta(host, role, keyspace_id) - SELECT %?, %?,-1 - WHERE NOT EXISTS (SELECT 1 FROM mysql.dist_framework_meta WHERE host = %?)`, execID, role, execID) +func (stm *TaskManager) StartManager(ctx context.Context, tidbID string, role string) error { + return stm.WithNewSession(func(se sessionctx.Context) error { + return stm.StartManagerSession(ctx, se, tidbID, role) + }) +} + +// StartManagerSession insert the manager information into dist_framework_meta. +// if the record exists, update the cpu_count. +func (*TaskManager) StartManagerSession(ctx context.Context, se sessionctx.Context, execID string, role string) error { + cpuCount := cpu.GetCPUCount() + _, err := sqlexec.ExecSQL(ctx, se, ` + insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id) + values (%?, %?, %?, -1) + on duplicate key update cpu_count = %?`, + execID, role, cpuCount, cpuCount) return err } @@ -1247,36 +1257,76 @@ func (stm *TaskManager) TransferTasks2History(ctx context.Context, tasks []*prot } // GetManagedNodes implements scheduler.TaskManager interface. -func (stm *TaskManager) GetManagedNodes(ctx context.Context) ([]string, error) { - rs, err := stm.executeSQLWithNewSession(ctx, ` - select host, role - from mysql.dist_framework_meta - where role = 'background' or role = '' - order by host`) +func (stm *TaskManager) GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) { + var nodes []proto.ManagedNode + err := stm.WithNewSession(func(se sessionctx.Context) error { + var err2 error + nodes, err2 = stm.getManagedNodesWithSession(ctx, se) + return err2 + }) + return nodes, err +} + +func (stm *TaskManager) getManagedNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) { + nodes, err := stm.getAllNodesWithSession(ctx, se) if err != nil { return nil, err } - nodes := make(map[string][]string, 2) - for _, r := range rs { - role := r.GetString(1) - nodes[role] = append(nodes[role], r.GetString(0)) + nodeMap := make(map[string][]proto.ManagedNode, 2) + for _, node := range nodes { + nodeMap[node.Role] = append(nodeMap[node.Role], node) } - if len(nodes["background"]) == 0 { - return nodes[""], nil + if len(nodeMap["background"]) == 0 { + return nodeMap[""], nil } - return nodes["background"], nil + return nodeMap["background"], nil } // GetAllNodes gets nodes in dist_framework_meta. 
-func (stm *TaskManager) GetAllNodes(ctx context.Context) ([]string, error) { - rs, err := stm.executeSQLWithNewSession(ctx, - "select host from mysql.dist_framework_meta") +func (stm *TaskManager) GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) { + var nodes []proto.ManagedNode + err := stm.WithNewSession(func(se sessionctx.Context) error { + var err2 error + nodes, err2 = stm.getAllNodesWithSession(ctx, se) + return err2 + }) + return nodes, err +} + +func (*TaskManager) getAllNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) { + rs, err := sqlexec.ExecSQL(ctx, se, ` + select host, role, cpu_count + from mysql.dist_framework_meta + order by host`) if err != nil { return nil, err } - nodes := make([]string, 0, len(rs)) + nodes := make([]proto.ManagedNode, 0, len(rs)) for _, r := range rs { - nodes = append(nodes, r.GetString(0)) + nodes = append(nodes, proto.ManagedNode{ + ID: r.GetString(0), + Role: r.GetString(1), + CPUCount: int(r.GetInt64(2)), + }) } return nodes, nil } + +// getCPUCountOfManagedNodes gets the cpu count of managed nodes. +func (stm *TaskManager) getCPUCountOfManagedNodes(ctx context.Context, se sessionctx.Context) (int, error) { + nodes, err := stm.getManagedNodesWithSession(ctx, se) + if err != nil { + return 0, err + } + if len(nodes) == 0 { + return 0, errors.New("no managed nodes") + } + var cpuCount int + for _, n := range nodes { + if n.CPUCount > 0 { + cpuCount = n.CPUCount + break + } + } + return cpuCount, nil +} diff --git a/pkg/disttask/framework/storage/task_table_test.go b/pkg/disttask/framework/storage/task_table_test.go index dec57601a2af4..daccd37c0829b 100644 --- a/pkg/disttask/framework/storage/task_table_test.go +++ b/pkg/disttask/framework/storage/task_table_test.go @@ -15,11 +15,13 @@ package storage import ( + "context" "testing" "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/testkit/testsetup" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -37,6 +39,16 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m, opts...) 
} +func GetCPUCountOfManagedNodes(ctx context.Context, taskMgr *TaskManager) (int, error) { + var cnt int + err := taskMgr.WithNewSession(func(se sessionctx.Context) error { + var err2 error + cnt, err2 = taskMgr.getCPUCountOfManagedNodes(ctx, se) + return err2 + }) + return cnt, err +} + func TestSplitSubtasks(t *testing.T) { tm := &TaskManager{} subtasks := make([]*proto.Subtask, 0, 10) diff --git a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go index 25bbf96cdac11..2581f21e43b65 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go @@ -96,6 +96,7 @@ func TestTaskExecutorBasic(t *testing.T) { } return nil }) + require.NoError(t, mgr.StartManager(ctx, ":4000", "")) for i := 0; i < 10; i++ { runOneTask(ctx, t, mgr, "key"+strconv.Itoa(i), i) } diff --git a/pkg/disttask/framework/testutil/BUILD.bazel b/pkg/disttask/framework/testutil/BUILD.bazel index 9a01022874cf3..2f8cf34e03341 100644 --- a/pkg/disttask/framework/testutil/BUILD.bazel +++ b/pkg/disttask/framework/testutil/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "//pkg/disttask/framework/scheduler/mock", "//pkg/disttask/framework/storage", "//pkg/disttask/framework/taskexecutor", + "//pkg/kv", "//pkg/sessionctx", "//pkg/testkit", "//pkg/util/sqlexec", diff --git a/pkg/disttask/framework/testutil/context.go b/pkg/disttask/framework/testutil/context.go index 1ef6ffebfa2b6..34b38ba27c80c 100644 --- a/pkg/disttask/framework/testutil/context.go +++ b/pkg/disttask/framework/testutil/context.go @@ -20,9 +20,11 @@ import ( "sync" "sync/atomic" "testing" + "time" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" @@ -51,10 +53,19 @@ func InitTestContext(t *testing.T, nodeNum int) (context.Context, *gomock.Contro require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) }) + executionContext := testkit.NewDistExecutionContext(t, nodeNum) + // wait until some node is registered. + require.Eventually(t, func() bool { + taskMgr, err := storage.GetTaskManager() + require.NoError(t, err) + nodes, err := taskMgr.GetAllNodes(ctx) + require.NoError(t, err) + return len(nodes) > 0 + }, 5*time.Second, 100*time.Millisecond) testCtx := &TestContext{ subtasksHasRun: make(map[string]map[int64]struct{}), } - return ctx, ctrl, testCtx, testkit.NewDistExecutionContext(t, nodeNum) + return ctx, ctrl, testCtx, executionContext } // CollectSubtask collects subtask info diff --git a/pkg/disttask/framework/testutil/table_util.go b/pkg/disttask/framework/testutil/table_util.go index 00078765724e9..aaec5232ab6c1 100644 --- a/pkg/disttask/framework/testutil/table_util.go +++ b/pkg/disttask/framework/testutil/table_util.go @@ -22,32 +22,38 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" ) // InitTableTest inits needed components for table_test. -func InitTableTest(t *testing.T) (*storage.TaskManager, context.Context) { - pool := getResourcePool(t) +// it disables disttask and mock cpu count to 8. 
+func InitTableTest(t *testing.T) (kv.Storage, *storage.TaskManager, context.Context) { + store, pool := getResourcePool(t) ctx := context.Background() ctx = util.WithInternalSourceType(ctx, "table_test") require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)")) t.Cleanup(func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) }) - return getTaskManager(t, pool), ctx + return store, getTaskManager(t, pool), ctx } // InitTableTestWithCancel inits needed components with context.CancelFunc for table_test. func InitTableTestWithCancel(t *testing.T) (*storage.TaskManager, context.Context, context.CancelFunc) { - pool := getResourcePool(t) + _, pool := getResourcePool(t) ctx, cancel := context.WithCancel(context.Background()) ctx = util.WithInternalSourceType(ctx, "table_test") return getTaskManager(t, pool), ctx, cancel } -func getResourcePool(t *testing.T) *pools.ResourcePool { +func getResourcePool(t *testing.T) (kv.Storage, *pools.ResourcePool) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) + }() store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) pool := pools.NewResourcePool(func() (pools.Resource, error) { @@ -57,7 +63,7 @@ func getResourcePool(t *testing.T) *pools.ResourcePool { t.Cleanup(func() { pool.Close() }) - return pool + return store, pool } func getTaskManager(t *testing.T, pool *pools.ResourcePool) *storage.TaskManager { diff --git a/pkg/executor/importer/import.go b/pkg/executor/importer/import.go index 86c15046f8af2..df41fc850b91c 100644 --- a/pkg/executor/importer/import.go +++ b/pkg/executor/importer/import.go @@ -500,6 +500,8 @@ func (e *LoadDataController) checkFieldParams() error { } func (p *Plan) initDefaultOptions() { + // we're using cpu count of current node, not of framework managed nodes, + // but it seems more intuitive. 
threadCnt := cpu.GetCPUCount() threadCnt = int(math.Max(1, float64(threadCnt)*0.5)) diff --git a/pkg/executor/set.go b/pkg/executor/set.go index 4d80c4695adef..9272b863b10f8 100644 --- a/pkg/executor/set.go +++ b/pkg/executor/set.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/executor/internal/exec" "github.com/pingcap/tidb/pkg/expression" @@ -39,7 +40,6 @@ import ( "github.com/pingcap/tidb/pkg/util/gcutil" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/sem" - "github.com/pingcap/tidb/pkg/util/sqlexec" "go.uber.org/zap" ) @@ -172,8 +172,11 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres dom := domain.GetDomain(e.Ctx()) config.GetGlobalConfig().Instance.TiDBServiceScope = valStr serverID := disttaskutil.GenerateSubtaskExecID(ctx, dom.DDL().GetID()) - _, err = e.Ctx().(sqlexec.SQLExecutor).ExecuteInternal(ctx, - `replace into mysql.dist_framework_meta values(%?, %?, DEFAULT)`, serverID, valStr) + taskMgr, err := storage.GetTaskManager() + if err != nil { + return err + } + return taskMgr.StartManagerSession(ctx, e.Ctx(), serverID, valStr) } return err } diff --git a/pkg/session/bootstrap.go b/pkg/session/bootstrap.go index 8fac677aa1423..10950a4e55a23 100644 --- a/pkg/session/bootstrap.go +++ b/pkg/session/bootstrap.go @@ -612,7 +612,9 @@ const ( CreateDistFrameworkMeta = `CREATE TABLE IF NOT EXISTS mysql.dist_framework_meta ( host VARCHAR(100) NOT NULL PRIMARY KEY, role VARCHAR(64), - keyspace_id bigint(8) NOT NULL DEFAULT -1);` + cpu_count int default 0, + keyspace_id bigint(8) NOT NULL DEFAULT -1 + );` // CreateLoadDataJobs is a table that LOAD DATA uses CreateLoadDataJobs = `CREATE TABLE IF NOT EXISTS mysql.load_data_jobs ( @@ -1036,6 +1038,7 @@ const ( // add priority/create_time/end_time to `mysql.tidb_global_task`/`mysql.tidb_global_task_history` // add concurrency/create_time/end_time/digest to `mysql.tidb_background_subtask`/`mysql.tidb_background_subtask_history` // add idx_exec_id(exec_id), uk_digest to `mysql.tidb_background_subtask` + // add cpu_count to mysql.dist_framework_meta version180 = 180 // version 181 @@ -2938,6 +2941,8 @@ func upgradeToVer180(s sessiontypes.Session, ver int64) { doReentrantDDL(s, "ALTER TABLE mysql.tidb_background_subtask ADD INDEX idx_exec_id(exec_id)", dbterror.ErrDupKeyName) doReentrantDDL(s, "ALTER TABLE mysql.tidb_background_subtask ADD UNIQUE INDEX uk_task_key_step_ordinal(task_key, step, ordinal)", dbterror.ErrDupKeyName) + + doReentrantDDL(s, "ALTER TABLE mysql.dist_framework_meta ADD COLUMN `cpu_count` INT DEFAULT 0 AFTER `role`", infoschema.ErrColumnExists) } func upgradeToVer181(s sessiontypes.Session, ver int64) { diff --git a/pkg/util/disttask/idservice.go b/pkg/util/disttask/idservice.go index 239760a7d89ca..fea4fb13ed0a5 100644 --- a/pkg/util/disttask/idservice.go +++ b/pkg/util/disttask/idservice.go @@ -45,16 +45,6 @@ func FindServerInfo(serverInfos []*infosync.ServerInfo, schedulerID string) int return -1 } -// MatchSchedulerID will find schedulerID in taskNodes. -func MatchSchedulerID(taskNodes []string, schedulerID string) bool { - for _, nodeID := range taskNodes { - if schedulerID == nodeID { - return true - } - } - return false -} - // GenerateSubtaskExecID generates the subTask execID. 
func GenerateSubtaskExecID(ctx context.Context, id string) string { serverInfos, err := infosync.GetAllServerInfo(ctx)
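The storage-side changes register each node's CPU count in mysql.dist_framework_meta: StartManagerSession upserts the row with `on duplicate key update cpu_count`, so re-registering a host refreshes cpu_count but keeps its existing role, which is what TestDistFrameworkMeta checks for ":4002". A small sketch of those semantics under that reading; nodeMeta, metaTable and startManager are hypothetical names for illustration, not the real TaskManager API.

package main

import "fmt"

// nodeMeta holds the per-host fields of dist_framework_meta relevant here.
type nodeMeta struct {
	role     string
	cpuCount int
}

// metaTable stands in for mysql.dist_framework_meta, keyed by host.
type metaTable map[string]nodeMeta

// startManager mimics the upsert in StartManagerSession: on a duplicate host
// only cpu_count is refreshed, the previously stored role is kept.
func (m metaTable) startManager(host, role string, cpuCount int) {
	if old, ok := m[host]; ok {
		old.cpuCount = cpuCount
		m[host] = old
		return
	}
	m[host] = nodeMeta{role: role, cpuCount: cpuCount}
}

func main() {
	m := metaTable{}
	m.startManager(":4002", "background", 8)
	m.startManager(":4002", "", 100) // role stays "background", cpu_count becomes 100
	fmt.Printf("%+v\n", m[":4002"])  // {role:background cpuCount:100}
}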