diff --git a/disttask/framework/mock/scheduler_mock.go b/disttask/framework/mock/scheduler_mock.go index ab9f805f3616e..cd885ef46af4d 100644 --- a/disttask/framework/mock/scheduler_mock.go +++ b/disttask/framework/mock/scheduler_mock.go @@ -153,17 +153,17 @@ func (mr *MockTaskTableMockRecorder) StartSubtask(arg0 interface{}) *gomock.Call } // UpdateErrorToSubtask mocks base method. -func (m *MockTaskTable) UpdateErrorToSubtask(arg0 string, arg1 error) error { +func (m *MockTaskTable) UpdateErrorToSubtask(arg0 string, arg1 int64, arg2 error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateErrorToSubtask", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateErrorToSubtask", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // UpdateErrorToSubtask indicates an expected call of UpdateErrorToSubtask. -func (mr *MockTaskTableMockRecorder) UpdateErrorToSubtask(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockTaskTableMockRecorder) UpdateErrorToSubtask(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateErrorToSubtask", reflect.TypeOf((*MockTaskTable)(nil).UpdateErrorToSubtask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateErrorToSubtask", reflect.TypeOf((*MockTaskTable)(nil).UpdateErrorToSubtask), arg0, arg1, arg2) } // UpdateSubtaskStateAndError mocks base method. diff --git a/disttask/framework/scheduler/interface.go b/disttask/framework/scheduler/interface.go index 2a4f79108bf83..6da317e6e3a4e 100644 --- a/disttask/framework/scheduler/interface.go +++ b/disttask/framework/scheduler/interface.go @@ -26,12 +26,12 @@ type TaskTable interface { GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) GetSubtaskInStates(instanceID string, taskID int64, step int64, states ...interface{}) (*proto.Subtask, error) - StartSubtask(id int64) error - UpdateSubtaskStateAndError(id int64, state string, err error) error - FinishSubtask(id int64, meta []byte) error + StartSubtask(subtaskID int64) error + UpdateSubtaskStateAndError(subtaskID int64, state string, err error) error + FinishSubtask(subtaskID int64, meta []byte) error HasSubtasksInStates(instanceID string, taskID int64, step int64, states ...interface{}) (bool, error) - UpdateErrorToSubtask(tidbID string, err error) error - IsSchedulerCanceled(taskID int64, execID string) (bool, error) + UpdateErrorToSubtask(instanceID string, taskID int64, err error) error + IsSchedulerCanceled(taskID int64, instanceID string) (bool, error) } // Pool defines the interface of a pool. diff --git a/disttask/framework/scheduler/scheduler.go b/disttask/framework/scheduler/scheduler.go index 796c93ab18f87..42276393b3330 100644 --- a/disttask/framework/scheduler/scheduler.go +++ b/disttask/framework/scheduler/scheduler.go @@ -103,7 +103,7 @@ func (s *InternalSchedulerImpl) Run(ctx context.Context, task *proto.Task) error if s.mu.handled { return err } - return s.taskTable.UpdateErrorToSubtask(s.id, err) + return s.taskTable.UpdateErrorToSubtask(s.id, task.ID, err) } func (s *InternalSchedulerImpl) run(ctx context.Context, task *proto.Task) error { @@ -477,8 +477,8 @@ func (s *InternalSchedulerImpl) startSubtask(id int64) { } } -func (s *InternalSchedulerImpl) updateSubtaskStateAndError(id int64, state string, subTaskErr error) { - err := s.taskTable.UpdateSubtaskStateAndError(id, state, subTaskErr) +func (s *InternalSchedulerImpl) updateSubtaskStateAndError(subtaskID int64, state string, subTaskErr error) { + err := s.taskTable.UpdateSubtaskStateAndError(subtaskID, state, subTaskErr) if err != nil { s.onError(err) } diff --git a/disttask/framework/scheduler/scheduler_test.go b/disttask/framework/scheduler/scheduler_test.go index 5664c40c6e19b..db4248e1d7319 100644 --- a/disttask/framework/scheduler/scheduler_test.go +++ b/disttask/framework/scheduler/scheduler_test.go @@ -79,10 +79,10 @@ func TestSchedulerRun(t *testing.T) { // UpdateErrorToSubtask won't return such errors, but since the error is not handled, // it's saved by UpdateErrorToSubtask. // here we use this to check the returned error of s.run. - forwardErrFn := func(_ string, err error) error { + forwardErrFn := func(_ string, _ int64, err error) error { return err } - mockSubtaskTable.EXPECT().UpdateErrorToSubtask(gomock.Any(), gomock.Any()).DoAndReturn(forwardErrFn).AnyTimes() + mockSubtaskTable.EXPECT().UpdateErrorToSubtask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(forwardErrFn).AnyTimes() err := scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp}) require.EqualError(t, err, schedulerRegisterErr.Error()) diff --git a/disttask/framework/storage/table_test.go b/disttask/framework/storage/table_test.go index 3e06bc56b71dc..9b539354f6d48 100644 --- a/disttask/framework/storage/table_test.go +++ b/disttask/framework/storage/table_test.go @@ -244,7 +244,7 @@ func TestSubTaskTable(t *testing.T) { // test UpdateErrorToSubtask do update start/update time err = sm.AddNewSubTask(3, proto.StepInit, "for_test", []byte("test"), proto.TaskTypeExample, false) require.NoError(t, err) - require.NoError(t, sm.UpdateErrorToSubtask("for_test", errors.New("fail"))) + require.NoError(t, sm.UpdateErrorToSubtask("for_test", 3, errors.New("fail"))) subtask, err = sm.GetSubtaskInStates("for_test", 3, proto.StepInit, proto.TaskStateFailed) require.NoError(t, err) require.Equal(t, proto.TaskStateFailed, subtask.State) diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go index f2771371363ec..97a029b78563d 100644 --- a/disttask/framework/storage/task_table.go +++ b/disttask/framework/storage/task_table.go @@ -338,14 +338,14 @@ func (stm *TaskManager) GetSubtaskInStates(tidbID string, taskID int64, step int } // UpdateErrorToSubtask updates the error to subtask. -func (stm *TaskManager) UpdateErrorToSubtask(tidbID string, err error) error { +func (stm *TaskManager) UpdateErrorToSubtask(tidbID string, taskID int64, err error) error { if err == nil { return nil } _, err1 := stm.executeSQLWithNewSession(stm.ctx, `update mysql.tidb_background_subtask set state = %?, error = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() - where exec_id = %? and state = %? limit 1;`, - proto.TaskStateFailed, serializeErr(err), tidbID, proto.TaskStatePending) + where exec_id = %? and task_key = %? and state = %? limit 1;`, + proto.TaskStateFailed, serializeErr(err), tidbID, taskID, proto.TaskStatePending) return err1 } @@ -469,11 +469,11 @@ func (stm *TaskManager) HasSubtasksInStates(tidbID string, taskID int64, step in } // StartSubtask updates the subtask state to running. -func (stm *TaskManager) StartSubtask(id int64) error { +func (stm *TaskManager) StartSubtask(subtaskID int64) error { _, err := stm.executeSQLWithNewSession(stm.ctx, `update mysql.tidb_background_subtask set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() where id = %?`, - proto.TaskStateRunning, id) + proto.TaskStateRunning, subtaskID) return err }