From e631ba8a9bb93c1afdf378eef759db09a5594e4d Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Mon, 2 Dec 2024 22:05:17 +0800 Subject: [PATCH] dxf: handle modifying task concurrency in scheduler (#57673) ref pingcap/tidb#57497 --- .../framework/integrationtests/BUILD.bazel | 3 +- .../framework/integrationtests/modify_test.go | 224 ++++++++++++++++++ pkg/disttask/framework/mock/scheduler_mock.go | 14 ++ pkg/disttask/framework/proto/modify.go | 12 + pkg/disttask/framework/proto/task.go | 1 + pkg/disttask/framework/scheduler/interface.go | 6 +- pkg/disttask/framework/scheduler/scheduler.go | 29 ++- .../framework/scheduler/scheduler_manager.go | 1 + .../scheduler/scheduler_nokit_test.go | 26 ++ pkg/disttask/framework/storage/task_state.go | 35 +++ .../framework/storage/task_state_test.go | 103 +++++++- pkg/disttask/framework/storage/task_table.go | 3 +- pkg/disttask/importinto/mock/import_mock.go | 2 +- .../mock/restricted_sql_executor_mock.go | 2 +- 14 files changed, 450 insertions(+), 11 deletions(-) create mode 100644 pkg/disttask/framework/integrationtests/modify_test.go diff --git a/pkg/disttask/framework/integrationtests/BUILD.bazel b/pkg/disttask/framework/integrationtests/BUILD.bazel index 7e51f90b9f532..3c1aaaaa657bb 100644 --- a/pkg/disttask/framework/integrationtests/BUILD.bazel +++ b/pkg/disttask/framework/integrationtests/BUILD.bazel @@ -12,11 +12,12 @@ go_test( "framework_scope_test.go", "framework_test.go", "main_test.go", + "modify_test.go", "resource_control_test.go", ], flaky = True, race = "off", - shard_count = 22, + shard_count = 23, deps = [ "//pkg/config", "//pkg/ddl", diff --git a/pkg/disttask/framework/integrationtests/modify_test.go b/pkg/disttask/framework/integrationtests/modify_test.go new file mode 100644 index 0000000000000..32944d0821a8b --- /dev/null +++ b/pkg/disttask/framework/integrationtests/modify_test.go @@ -0,0 +1,224 @@ +// Copyright 2024 PingCAP, Inc. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integrationtests
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/pingcap/tidb/pkg/disttask/framework/handle"
+	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
+	"github.com/pingcap/tidb/pkg/disttask/framework/testutil"
+	"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
+	"github.com/stretchr/testify/require"
+)
+
+func TestModifyTaskConcurrency(t *testing.T) {
+	c := testutil.NewTestDXFContext(t, 1, 16, true)
+	schedulerExt := testutil.GetMockSchedulerExt(c.MockCtrl, testutil.SchedulerInfo{
+		AllErrorRetryable: true,
+		StepInfos: []testutil.StepInfo{
+			{Step: proto.StepOne, SubtaskCnt: 1},
+			{Step: proto.StepTwo, SubtaskCnt: 1},
+		},
+	})
+	subtaskCh := make(chan struct{})
+	registerExampleTask(t, c.MockCtrl, schedulerExt, c.TestContext,
+		func(ctx context.Context, subtask *proto.Subtask) error {
+			select {
+			case <-subtaskCh:
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+			return nil
+		},
+	)
+
+	t.Run("modify pending task concurrency", func(t *testing.T) {
+		var once sync.Once
+		modifySyncCh := make(chan struct{})
+		var theTask *proto.Task
+		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
+			once.Do(func() {
+				task, err := handle.SubmitTask(c.Ctx, "k1", proto.TaskTypeExample, 3, "", nil)
+				require.NoError(t, err)
+				require.Equal(t, 3, task.Concurrency)
+				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
+					PrevState: proto.TaskStatePending,
+					Modifications: []proto.Modification{
+						{Type: proto.ModifyConcurrency, To: 7},
+					},
+				}))
+				theTask = task
+				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
+				require.NoError(t, err)
+				require.Equal(t, proto.TaskStateModifying, gotTask.State)
+				require.Equal(t, 3, gotTask.Concurrency)
+				<-modifySyncCh
+			})
+		})
+		modifySyncCh <- struct{}{}
+		// finish subtasks
+		subtaskCh <- struct{}{}
+		subtaskCh <- struct{}{}
+		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
+		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
+		checkSubtaskConcurrency(t, c, theTask.ID, map[proto.Step]int{
+			proto.StepOne: 7,
+			proto.StepTwo: 7,
+		})
+	})
+
+	t.Run("modify running task concurrency at step two", func(t *testing.T) {
+		var once sync.Once
+		modifySyncCh := make(chan struct{})
+		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeRefreshTask", func(task *proto.Task) {
+			if task.State != proto.TaskStateRunning || task.Step != proto.StepTwo {
+				return // modify only once the task is running at step two
+			}
+			once.Do(func() {
+				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
+					PrevState: proto.TaskStateRunning,
+					Modifications: []proto.Modification{
+						{Type: proto.ModifyConcurrency, To: 7},
+					},
+				}))
+				<-modifySyncCh
+			})
+		})
+		task, err := handle.SubmitTask(c.Ctx, "k2", proto.TaskTypeExample, 3, "", nil)
+		require.NoError(t, err)
+		require.Equal(t, 3, task.Concurrency)
+		// finish StepOne
+		subtaskCh <- struct{}{}
+		// wait task move to 'modifying' state
+		modifySyncCh <- struct{}{}
+		// wait task move back to 'running' state
+		require.Eventually(t, func() bool {
+			gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, task.ID)
+			require.NoError(t, err2)
+			return gotTask.State == proto.TaskStateRunning
+		}, 10*time.Second, 100*time.Millisecond)
+		// finish StepTwo
+		subtaskCh <- struct{}{}
+		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
+		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
+		checkSubtaskConcurrency(t, c, task.ID, map[proto.Step]int{
+			proto.StepOne: 3,
+			proto.StepTwo: 7,
+		})
+	})
+
+	t.Run("modify paused task concurrency", func(t *testing.T) {
+		var once sync.Once
+		syncCh := make(chan struct{})
+		var theTask *proto.Task
+		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
+			once.Do(func() {
+				task, err := handle.SubmitTask(c.Ctx, "k3", proto.TaskTypeExample, 3, "", nil)
+				require.NoError(t, err)
+				require.Equal(t, 3, task.Concurrency)
+				found, err := c.TaskMgr.PauseTask(c.Ctx, task.Key)
+				require.NoError(t, err)
+				require.True(t, found)
+				theTask = task
+				<-syncCh
+			})
+		})
+		syncCh <- struct{}{}
+		taskBase := testutil.WaitTaskDoneOrPaused(c.Ctx, t, theTask.Key)
+		require.Equal(t, proto.TaskStatePaused, taskBase.State)
+		require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, theTask.ID, &proto.ModifyParam{
+			PrevState: proto.TaskStatePaused,
+			Modifications: []proto.Modification{
+				{Type: proto.ModifyConcurrency, To: 7},
+			},
+		}))
+		taskBase = testutil.WaitTaskDoneOrPaused(c.Ctx, t, theTask.Key)
+		require.Equal(t, proto.TaskStatePaused, taskBase.State)
+		found, err := c.TaskMgr.ResumeTask(c.Ctx, theTask.Key)
+		require.NoError(t, err)
+		require.True(t, found)
+		// finish subtasks
+		subtaskCh <- struct{}{}
+		subtaskCh <- struct{}{}
+		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
+		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
+		checkSubtaskConcurrency(t, c, theTask.ID, map[proto.Step]int{
+			proto.StepOne: 7,
+			proto.StepTwo: 7,
+		})
+	})
+
+	t.Run("modify pending task concurrency, but other owner already done it", func(t *testing.T) {
+		var once sync.Once
+		modifySyncCh := make(chan struct{})
+		var theTask *proto.Task
+		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
+			once.Do(func() {
+				task, err := handle.SubmitTask(c.Ctx, "k4", proto.TaskTypeExample, 3, "", nil)
+				require.NoError(t, err)
+				require.Equal(t, 3, task.Concurrency)
+				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
+					PrevState: proto.TaskStatePending,
+					Modifications: []proto.Modification{
+						{Type: proto.ModifyConcurrency, To: 7},
+					},
+				}))
+				theTask = task
+				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
+				require.NoError(t, err)
+				require.Equal(t, proto.TaskStateModifying, gotTask.State)
+				require.Equal(t, 3, gotTask.Concurrency)
+			})
+		})
+		var onceForRefresh sync.Once
+		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/afterRefreshTask",
+			func(task *proto.Task) {
+				onceForRefresh.Do(func() {
+					require.Equal(t, proto.TaskStateModifying, task.State)
+					taskClone := *task
+					taskClone.Concurrency = 7
+					require.NoError(t, c.TaskMgr.ModifiedTask(c.Ctx, &taskClone))
+					gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, task.ID)
+					require.NoError(t, err)
+					require.Equal(t, proto.TaskStatePending, gotTask.State)
+					<-modifySyncCh
+				})
+			},
+		)
+		modifySyncCh <- struct{}{}
+		// finish subtasks
+		subtaskCh <- struct{}{}
+		subtaskCh <- struct{}{}
+		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
+		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
+		checkSubtaskConcurrency(t, c, theTask.ID, map[proto.Step]int{
+			proto.StepOne: 7,
+			proto.StepTwo: 7,
+		})
+	})
+}
+
+func checkSubtaskConcurrency(t *testing.T, c *testutil.TestDXFContext, taskID int64, expectedStepCon map[proto.Step]int) {
+	for step, con := range expectedStepCon {
+		subtasks, err := c.TaskMgr.GetSubtasksWithHistory(c.Ctx, taskID, step)
+		require.NoError(t, err)
+		require.Len(t, subtasks, 1)
+		require.Equal(t, con, subtasks[0].Concurrency)
+	}
+}
diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go
index 313a87ff121b7..5d614260d96e2 100644
--- a/pkg/disttask/framework/mock/scheduler_mock.go
+++
b/pkg/disttask/framework/mock/scheduler_mock.go @@ -479,6 +479,20 @@ func (mr *MockTaskManagerMockRecorder) GetUsedSlotsOnNodes(arg0 any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsedSlotsOnNodes", reflect.TypeOf((*MockTaskManager)(nil).GetUsedSlotsOnNodes), arg0) } +// ModifiedTask mocks base method. +func (m *MockTaskManager) ModifiedTask(arg0 context.Context, arg1 *proto.Task) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ModifiedTask", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ModifiedTask indicates an expected call of ModifiedTask. +func (mr *MockTaskManagerMockRecorder) ModifiedTask(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModifiedTask", reflect.TypeOf((*MockTaskManager)(nil).ModifiedTask), arg0, arg1) +} + // PauseTask mocks base method. func (m *MockTaskManager) PauseTask(arg0 context.Context, arg1 string) (bool, error) { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/proto/modify.go b/pkg/disttask/framework/proto/modify.go index 3b26b3c30e71a..d81ab62d9f32c 100644 --- a/pkg/disttask/framework/proto/modify.go +++ b/pkg/disttask/framework/proto/modify.go @@ -14,6 +14,8 @@ package proto +import "fmt" + // ModificationType is the type of task modification. type ModificationType string @@ -33,8 +35,18 @@ type ModifyParam struct { Modifications []Modification `json:"modifications"` } +// String implements fmt.Stringer interface. +func (p *ModifyParam) String() string { + return fmt.Sprintf("{prev_state: %s, modifications: %v}", p.PrevState, p.Modifications) +} + // Modification is one modification for task. type Modification struct { Type ModificationType `json:"type"` To int64 `json:"to"` } + +// String implements fmt.Stringer interface. 
+func (m Modification) String() string { + return fmt.Sprintf("{type: %s, to: %d}", m.Type, m.To) +} diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index 9a65e4c52b983..b3ca2c7d43427 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -185,6 +185,7 @@ type Task struct { // changed in below case, and framework will update the task meta in the storage. // - task switches to next step in Scheduler.OnNextSubtasksBatch // - on task cleanup, we might do some redaction on the meta. + // - on task 'modifying', params inside the meta can be changed. Meta []byte Error error ModifyParam ModifyParam diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index 23a3c1f1d9758..f48052f20caf3 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -49,10 +49,14 @@ type TaskManager interface { RevertedTask(ctx context.Context, taskID int64) error // PauseTask updated task state to pausing. PauseTask(ctx context.Context, taskKey string) (bool, error) - // PausedTask updated task state to paused. + // PausedTask updated task state to 'paused'. PausedTask(ctx context.Context, taskID int64) error // ResumedTask updated task state from resuming to running. ResumedTask(ctx context.Context, taskID int64) error + // ModifiedTask tries to update task concurrency and meta, and update state + // back to prev-state, if success, it will also update concurrency of all + // active subtasks. + ModifiedTask(ctx context.Context, task *proto.Task) error // SucceedTask updates a task to success state. 
SucceedTask(ctx context.Context, taskID int64) error // SwitchTaskStep switches the task to the next step and add subtasks in one diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index bbfd52bd827c1..fc09a48dabcac 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -402,9 +402,32 @@ func (s *BaseScheduler) onRunning() error { // onModifying is called when task is in modifying state. // the first return value indicates whether the scheduler should be recreated. -func (*BaseScheduler) onModifying() (bool, error) { - // TODO: implement me - panic("implement me") +func (s *BaseScheduler) onModifying() (bool, error) { + task := s.getTaskClone() + s.logger.Info("on modifying state", zap.Stringer("param", &task.ModifyParam)) + recreateScheduler := false + for _, m := range task.ModifyParam.Modifications { + if m.Type == proto.ModifyConcurrency { + if task.Concurrency == int(m.To) { + // shouldn't happen normally. + s.logger.Info("task concurrency not changed, skip", zap.Int("concurrency", task.Concurrency)) + continue + } + s.logger.Info("modify task concurrency", zap.Int("from", task.Concurrency), zap.Int64("to", m.To)) + recreateScheduler = true + task.Concurrency = int(m.To) + } else { + // will implement other modification types later. 
+ s.logger.Warn("unsupported modification type", zap.Stringer("type", m.Type)) + } + } + if err := s.taskMgr.ModifiedTask(s.ctx, task); err != nil { + return false, errors.Trace(err) + } + task.State = task.ModifyParam.PrevState + task.ModifyParam = proto.ModifyParam{} + s.task.Store(task) + return recreateScheduler, nil } func (s *BaseScheduler) onFinished() { diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index 3c5f43e060cde..a87f01e03f5a6 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -213,6 +213,7 @@ func (sm *Manager) scheduleTaskLoop() { continue } + failpoint.InjectCall("beforeGetSchedulableTasks") schedulableTasks, err := sm.getSchedulableTasks() if err != nil { continue diff --git a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go index 141331a1b583d..181397bf501ba 100644 --- a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go @@ -489,4 +489,30 @@ func TestSchedulerMaintainTaskFields(t *testing.T) { require.Equal(t, *scheduler.getTaskClone(), tmpTask) require.True(t, ctrl.Satisfied()) }) + + t.Run("test on modifying", func(t *testing.T) { + taskBefore := schTask + taskBefore.State = proto.TaskStateModifying + taskBefore.ModifyParam = proto.ModifyParam{ + PrevState: proto.TaskStateRunning, + Modifications: []proto.Modification{ + {Type: proto.ModifyConcurrency, To: 123}, + }, + } + scheduler.task.Store(&taskBefore) + taskMgr.EXPECT().ModifiedTask(gomock.Any(), gomock.Any()).Return(fmt.Errorf("modify err")) + recreateScheduler, err := scheduler.onModifying() + require.ErrorContains(t, err, "modify err") + require.False(t, recreateScheduler) + + taskMgr.EXPECT().ModifiedTask(gomock.Any(), gomock.Any()).Return(nil) + recreateScheduler, err = scheduler.onModifying() + 
require.NoError(t, err) + require.True(t, recreateScheduler) + expectedTask := taskBefore + expectedTask.Concurrency = 123 + expectedTask.State = proto.TaskStateRunning + expectedTask.ModifyParam = proto.ModifyParam{} + require.Equal(t, *scheduler.GetTask(), expectedTask) + }) } diff --git a/pkg/disttask/framework/storage/task_state.go b/pkg/disttask/framework/storage/task_state.go index 00723b6c01d19..8e4c472791c14 100644 --- a/pkg/disttask/framework/storage/task_state.go +++ b/pkg/disttask/framework/storage/task_state.go @@ -199,6 +199,41 @@ func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param }) } +// ModifiedTask implements the scheduler.TaskManager interface. +func (mgr *TaskManager) ModifiedTask(ctx context.Context, task *proto.Task) error { + prevState := task.ModifyParam.PrevState + return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + failpoint.InjectCall("beforeModifiedTask") + _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + update mysql.tidb_global_task + set state = %?, + concurrency = %?, + meta = %?, + modify_params = null, + state_update_time = CURRENT_TIMESTAMP() + where id = %? and state = %?`, + prevState, task.Concurrency, task.Meta, task.ID, proto.TaskStateModifying, + ) + if err != nil { + return err + } + if se.GetSessionVars().StmtCtx.AffectedRows() == 0 { + // might be handled by other owner nodes, skip. + return nil + } + // subtask in final state are not changed. + // subtask might have different concurrency later, see TaskExecInfo, we + // need to handle it too, but ok for now. + _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + update mysql.tidb_background_subtask + set concurrency = %?, state_update_time = unix_timestamp() + where task_key = %? and state in (%?, %?, %?)`, + task.Concurrency, task.ID, + proto.SubtaskStatePending, proto.SubtaskStateRunning, proto.SubtaskStatePaused) + return err + }) +} + // SucceedTask update task state from running to succeed. 
func (mgr *TaskManager) SucceedTask(ctx context.Context, taskID int64) error { return mgr.WithNewSession(func(se sessionctx.Context) error { diff --git a/pkg/disttask/framework/storage/task_state_test.go b/pkg/disttask/framework/storage/task_state_test.go index e4d8568db321e..0a515675d95ff 100644 --- a/pkg/disttask/framework/storage/task_state_test.go +++ b/pkg/disttask/framework/storage/task_state_test.go @@ -15,7 +15,10 @@ package storage_test import ( + "cmp" + "context" "errors" + "slices" "sync/atomic" "testing" @@ -26,6 +29,7 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/testkit/testfailpoint" tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/sqlexec" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" ) @@ -161,9 +165,14 @@ func TestModifyTask(t *testing.T) { task, err := gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStatePending, task.State) + subtasks := make([]*proto.Subtask, 0, 4) + for i := 0; i < 4; i++ { + subtasks = append(subtasks, proto.NewSubtask(proto.StepOne, task.ID, task.Type, + ":4000", task.Concurrency, proto.EmptyMeta, i+1)) + } wg.Run(func() { ch <- struct{}{} - require.NoError(t, gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil)) + require.NoError(t, gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, subtasks)) ch <- struct{}{} }) require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ @@ -175,11 +184,99 @@ func TestModifyTask(t *testing.T) { task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStateRunning, task.State) - require.NoError(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ + param := proto.ModifyParam{ PrevState: proto.TaskStateRunning, + Modifications: []proto.Modification{ + {Type: proto.ModifyConcurrency, To: 2}, + }, + } + require.NoError(t, gm.ModifyTaskByID(ctx, id, ¶m)) + task, err = gm.GetTaskByID(ctx, id) + 
require.NoError(t, err) + require.Equal(t, proto.TaskStateModifying, task.State) + require.Equal(t, param, task.ModifyParam) + + // modified + gotSubtasks, err := gm.GetSubtasksWithHistory(ctx, task.ID, proto.StepOne) + require.NoError(t, err) + slices.SortFunc(gotSubtasks, func(i, j *proto.Subtask) int { + return cmp.Compare(i.Ordinal, j.Ordinal) + }) + require.Len(t, gotSubtasks, len(subtasks)) + require.NoError(t, gm.FinishSubtask(ctx, gotSubtasks[0].ExecID, gotSubtasks[0].ID, nil)) + require.NoError(t, gm.StartSubtask(ctx, gotSubtasks[1].ID, gotSubtasks[1].ExecID)) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), `update mysql.tidb_background_subtask set state='paused' where id=%?`, + gotSubtasks[2].ID) + return err })) + task.Concurrency = 2 + task.Meta = []byte("modified") + require.NoError(t, gm.ModifiedTask(ctx, task)) + checkTaskAfterModify(ctx, t, gm, task.ID, + 2, []byte("modified"), []int{4, 2, 2, 2}, + ) + + // task state changed before move to 'modified' + param = proto.ModifyParam{ + PrevState: proto.TaskStateRunning, + Modifications: []proto.Modification{ + {Type: proto.ModifyConcurrency, To: 3}, + }, + } + require.NoError(t, gm.ModifyTaskByID(ctx, id, ¶m)) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStateModifying, task.State) - require.Equal(t, proto.TaskStateRunning, task.ModifyParam.PrevState) + require.Equal(t, param, task.ModifyParam) + ch = make(chan struct{}) + wg = tidbutil.WaitGroupWrapper{} + var called bool + testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/storage/beforeModifiedTask", func() { + if called { + return + } + called = true + <-ch + <-ch + }) + taskClone := *task + wg.Run(func() { + ch <- struct{}{} + // NOTE: this will NOT happen in real case, because the task can NOT move + // to 'modifying' again to change modify params. 
+ // here just to show that if another client finishes modifying, our modify + // will skip silently. + taskClone.Concurrency = 5 + taskClone.Meta = []byte("modified-other") + require.NoError(t, gm.ModifiedTask(ctx, &taskClone)) + ch <- struct{}{} + }) + task.Concurrency = 3 + task.Meta = []byte("modified2") + require.NoError(t, gm.ModifiedTask(ctx, task)) + wg.Wait() + checkTaskAfterModify(ctx, t, gm, task.ID, + 5, []byte("modified-other"), []int{4, 5, 5, 5}, + ) +} + +func checkTaskAfterModify( + ctx context.Context, t *testing.T, gm *storage.TaskManager, taskID int64, + expectConcurrency int, expectedMeta []byte, expectedSTConcurrencies []int) { + task, err := gm.GetTaskByID(ctx, taskID) + require.NoError(t, err) + require.Equal(t, proto.TaskStateRunning, task.State) + require.Equal(t, expectConcurrency, task.Concurrency) + require.Equal(t, expectedMeta, task.Meta) + require.Equal(t, proto.ModifyParam{}, task.ModifyParam) + gotSubtasks, err := gm.GetSubtasksWithHistory(ctx, task.ID, proto.StepOne) + require.NoError(t, err) + require.Len(t, gotSubtasks, len(expectedSTConcurrencies)) + slices.SortFunc(gotSubtasks, func(i, j *proto.Subtask) int { + return cmp.Compare(i.Ordinal, j.Ordinal) + }) + for i, expected := range expectedSTConcurrencies { + require.Equal(t, expected, gotSubtasks[i].Concurrency) + } } diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index f3e3dd86ee969..28d6703024fbe 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -81,7 +81,8 @@ type TaskExecInfo struct { *proto.TaskBase // SubtaskConcurrency is the concurrency of subtask in current task step. // TODO: will be used when support subtask have smaller concurrency than task, - // TODO: such as post-process of import-into. + // TODO: such as post-process of import-into. Also remember the 'modifying' state + // also update subtask concurrency. 
// TODO: we might need create one task executor for each step in this case, to alloc // TODO: minimal resource SubtaskConcurrency int diff --git a/pkg/disttask/importinto/mock/import_mock.go b/pkg/disttask/importinto/mock/import_mock.go index 7be883ce4e346..28eec32a4c712 100644 --- a/pkg/disttask/importinto/mock/import_mock.go +++ b/pkg/disttask/importinto/mock/import_mock.go @@ -13,7 +13,7 @@ import ( context "context" reflect "reflect" - "github.com/pingcap/tidb/pkg/lightning/backend" + backend "github.com/pingcap/tidb/pkg/lightning/backend" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/util/sqlexec/mock/restricted_sql_executor_mock.go b/pkg/util/sqlexec/mock/restricted_sql_executor_mock.go index 0ee7f88c8b64b..cb4e5e244aab7 100644 --- a/pkg/util/sqlexec/mock/restricted_sql_executor_mock.go +++ b/pkg/util/sqlexec/mock/restricted_sql_executor_mock.go @@ -14,7 +14,7 @@ import ( reflect "reflect" ast "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/planner/core/resolve" + resolve "github.com/pingcap/tidb/pkg/planner/core/resolve" chunk "github.com/pingcap/tidb/pkg/util/chunk" sqlexec "github.com/pingcap/tidb/pkg/util/sqlexec" gomock "go.uber.org/mock/gomock"