From f975ad500f98adc4c5ee3bb7fc14c9238198583c Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Mon, 27 Mar 2023 01:57:59 +0800 Subject: [PATCH 1/8] disttask: add dispatcher for load data --- br/pkg/lightning/config/const.go | 8 +- ddl/disttask_flow.go | 4 +- ddl/disttask_flow_test.go | 13 +- disttask/framework/dispatcher/dispatcher.go | 4 +- .../framework/dispatcher/dispatcher_test.go | 4 +- disttask/framework/dispatcher/register.go | 6 +- disttask/loaddata/BUILD.bazel | 14 +- disttask/loaddata/dispatcher.go | 107 ++++++++++ disttask/loaddata/proto.go | 25 ++- disttask/loaddata/wrapper.go | 97 +++++++++ disttask/loaddata/wrapper_test.go | 186 ++++++++++++++++++ 11 files changed, 436 insertions(+), 32 deletions(-) create mode 100644 disttask/loaddata/dispatcher.go create mode 100644 disttask/loaddata/wrapper.go create mode 100644 disttask/loaddata/wrapper_test.go diff --git a/br/pkg/lightning/config/const.go b/br/pkg/lightning/config/const.go index 3fc2a63c2ac0d..82a7052cc4fd1 100644 --- a/br/pkg/lightning/config/const.go +++ b/br/pkg/lightning/config/const.go @@ -25,7 +25,6 @@ import ( const ( // mydumper ReadBlockSize ByteSize = 64 * units.KiB - MaxRegionSize ByteSize = 256 * units.MiB // See: https://github.com/tikv/tikv/blob/e030a0aae9622f3774df89c62f21b2171a72a69e/etc/config-template.toml#L360 // lower the max-key-count to avoid tikv trigger region auto split SplitRegionSize ByteSize = 96 * units.MiB @@ -33,8 +32,6 @@ const ( MaxSplitRegionSizeRatio int = 10 defaultMaxAllowedPacket = 64 * units.MiB - - DefaultBatchSize ByteSize = 100 * units.GiB ) var ( @@ -44,5 +41,8 @@ var ( PermitWithoutStream: false, }) // BufferSizeScale is the factor of block buffer size - BufferSizeScale = int64(5) + BufferSizeScale = int64(5) + DefaultBatchSize ByteSize = 100 * units.GiB + // mydumper + MaxRegionSize ByteSize = 256 * units.MiB ) diff --git a/ddl/disttask_flow.go b/ddl/disttask_flow.go index 2ac957a2dd7d3..3912787becb48 100644 --- a/ddl/disttask_flow.go +++ b/ddl/disttask_flow.go @@ -57,7 +57,7 @@ func NewLitBackfillFlowHandle(getDDL func() DDL) dispatcher.TaskFlowHandle { } // ProcessNormalFlow processes the normal flow. -func (h *litBackfillFlowHandle) ProcessNormalFlow(_ dispatcher.Dispatch, gTask *proto.Task) (metas [][]byte, err error) { +func (h *litBackfillFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Dispatch, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State != proto.TaskStatePending { // This flow has only one step, finish task when it is not pending return nil, nil @@ -108,7 +108,7 @@ func (h *litBackfillFlowHandle) ProcessNormalFlow(_ dispatcher.Dispatch, gTask * return subTaskMetas, nil } -func (*litBackfillFlowHandle) ProcessErrFlow(_ dispatcher.Dispatch, _ *proto.Task, _ string) (meta []byte, err error) { +func (*litBackfillFlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Dispatch, _ *proto.Task, _ string) (meta []byte, err error) { // We do not need extra meta info when rolling back return nil, nil } diff --git a/ddl/disttask_flow_test.go b/ddl/disttask_flow_test.go index 145a1912af605..6dbd8f5208469 100644 --- a/ddl/disttask_flow_test.go +++ b/ddl/disttask_flow_test.go @@ -15,6 +15,7 @@ package ddl_test import ( + "context" "encoding/json" "testing" "time" @@ -48,7 +49,7 @@ func TestBackfillFlowHandle(t *testing.T) { tbl, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("tp1")) require.NoError(t, err) tblInfo := tbl.Meta() - metas, err := handler.ProcessNormalFlow(nil, gTask) + metas, err := handler.ProcessNormalFlow(context.Background(), nil, gTask) require.NoError(t, err) require.Equal(t, proto.StepOne, gTask.Step) require.Equal(t, len(tblInfo.Partition.Definitions), len(metas)) @@ -60,18 +61,18 @@ func TestBackfillFlowHandle(t *testing.T) { // test partition table ProcessNormalFlow after step1 finished gTask.State = proto.TaskStateRunning - metas, err = handler.ProcessNormalFlow(nil, gTask) + metas, err = handler.ProcessNormalFlow(context.Background(), nil, gTask) require.NoError(t, err) require.Equal(t, 0, len(metas)) // test partition table ProcessErrFlow - errMeta, err := handler.ProcessErrFlow(nil, gTask, "mockErr") + errMeta, err := handler.ProcessErrFlow(context.Background(), nil, gTask, "mockErr") require.NoError(t, err) require.Nil(t, errMeta) // test merging index gTask = createAddIndexGlobalTask(t, dom, "test", "tp1", ddl.FlowHandleLitMergeType) - metas, err = handler.ProcessNormalFlow(nil, gTask) + metas, err = handler.ProcessNormalFlow(context.Background(), nil, gTask) require.NoError(t, err) require.Equal(t, proto.StepOne, gTask.Step) require.Equal(t, len(tblInfo.Partition.Definitions), len(metas)) @@ -81,14 +82,14 @@ func TestBackfillFlowHandle(t *testing.T) { require.Equal(t, par.ID, subTask.PhysicalTableID) } - errMeta, err = handler.ProcessErrFlow(nil, gTask, "mockErr") + errMeta, err = handler.ProcessErrFlow(context.Background(), nil, gTask, "mockErr") require.NoError(t, err) require.Nil(t, errMeta) // test normal table not supported yet tk.MustExec("create table t1(id int primary key, v int)") gTask = createAddIndexGlobalTask(t, dom, "test", "t1", ddl.FlowHandleLitBackfillType) - _, err = handler.ProcessNormalFlow(nil, gTask) + _, err = handler.ProcessNormalFlow(context.Background(), nil, gTask) require.EqualError(t, err, "Non-partition table not supported yet") } diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 39d81aef1b09a..66cbfb973508f 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -246,7 +246,7 @@ func (d *dispatcher) updateTaskRevertInfo(gTask *proto.Task) { func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr string) error { // TODO: Maybe it gets GetTaskFlowHandle fails when rolling upgrades. - meta, err := GetTaskFlowHandle(gTask.Type).ProcessErrFlow(d, gTask, receiveErr) + meta, err := GetTaskFlowHandle(gTask.Type).ProcessErrFlow(d.ctx, d, gTask, receiveErr) if err != nil { logutil.BgLogger().Warn("handle error failed", zap.Error(err)) return err @@ -292,7 +292,7 @@ func (d *dispatcher) processNormalFlow(gTask *proto.Task) (err error) { d.updateTaskRevertInfo(gTask) return errors.Errorf("%s type handle doesn't register", gTask.Type) } - metas, err := handle.ProcessNormalFlow(d, gTask) + metas, err := handle.ProcessNormalFlow(d.ctx, d, gTask) if err != nil { logutil.BgLogger().Warn("gen dist-plan failed", zap.Error(err)) return err diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index 6ec994d950c70..2ba049c5f949b 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -191,7 +191,7 @@ const taskTypeExample = "task_example" type NumberExampleHandle struct { } -func (n NumberExampleHandle) ProcessNormalFlow(_ dispatcher.Dispatch, gTask *proto.Task) (metas [][]byte, err error) { +func (n NumberExampleHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Dispatch, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State == proto.TaskStatePending { gTask.Step = proto.StepInit } @@ -211,7 +211,7 @@ func (n NumberExampleHandle) ProcessNormalFlow(_ dispatcher.Dispatch, gTask *pro return metas, nil } -func (n NumberExampleHandle) ProcessErrFlow(_ dispatcher.Dispatch, _ *proto.Task, _ string) (meta []byte, err error) { +func (n NumberExampleHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Dispatch, _ *proto.Task, _ string) (meta []byte, err error) { // Don't handle not. return nil, nil } diff --git a/disttask/framework/dispatcher/register.go b/disttask/framework/dispatcher/register.go index d7ab51e8859ed..daffa6baa6657 100644 --- a/disttask/framework/dispatcher/register.go +++ b/disttask/framework/dispatcher/register.go @@ -15,14 +15,16 @@ package dispatcher import ( + "context" + "github.com/pingcap/tidb/disttask/framework/proto" "github.com/pingcap/tidb/util/syncutil" ) // TaskFlowHandle is used to control the process operations for each global task. type TaskFlowHandle interface { - ProcessNormalFlow(d Dispatch, gTask *proto.Task) (metas [][]byte, err error) - ProcessErrFlow(d Dispatch, gTask *proto.Task, receive string) (meta []byte, err error) + ProcessNormalFlow(ctx context.Context, d Dispatch, gTask *proto.Task) (metas [][]byte, err error) + ProcessErrFlow(ctx context.Context, d Dispatch, gTask *proto.Task, receive string) (meta []byte, err error) } var taskFlowHandleMap struct { diff --git a/disttask/loaddata/BUILD.bazel b/disttask/loaddata/BUILD.bazel index b064d040a1416..bdd5c65dbd686 100644 --- a/disttask/loaddata/BUILD.bazel +++ b/disttask/loaddata/BUILD.bazel @@ -2,14 +2,26 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "loaddata", - srcs = ["proto.go"], + srcs = [ + "dispatcher.go", + "proto.go", + "wrapper.go", + ], importpath = "github.com/pingcap/tidb/disttask/loaddata", visibility = ["//visibility:public"], deps = [ "//br/pkg/lightning/backend", "//br/pkg/lightning/config", "//br/pkg/lightning/mydump", + "//br/pkg/storage", + "//disttask/framework/dispatcher", + "//disttask/framework/proto", + "//executor/importer", "//parser/model", "//parser/mysql", + "//util/intest", + "//util/logutil", + "@com_github_pingcap_errors//:errors", + "@org_uber_go_zap//:zap", ], ) diff --git a/disttask/loaddata/dispatcher.go b/disttask/loaddata/dispatcher.go new file mode 100644 index 0000000000000..64ce0a6504098 --- /dev/null +++ b/disttask/loaddata/dispatcher.go @@ -0,0 +1,107 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loaddata + +import ( + "context" + "encoding/json" + + "github.com/pingcap/tidb/disttask/framework/dispatcher" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" +) + +// Dispatcher is the dispatcher for load data. +type Dispatcher struct{} + +// ProcessNormalFlow implements dispatcher.TaskFlowHandle interface. +func (*Dispatcher) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.Dispatch, gTask *proto.Task) ([][]byte, error) { + taskMeta := &TaskMeta{} + err := json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return nil, err + } + logutil.BgLogger().Info("process normal flow", zap.Any("task_meta", taskMeta), zap.Any("step", gTask.Step)) + + switch gTask.Step { + case Import: + gTask.State = proto.TaskStateSucceed + return nil, nil + default: + } + + instances, err := dispatch.GetTaskAllInstances(ctx, gTask.ID) + if err != nil { + return nil, err + } + subtaskMetas, err := generateSubtaskMetas(ctx, taskMeta, len(instances)) + if err != nil { + return nil, err + } + logutil.BgLogger().Info("generate subtasks", zap.Any("subtask_metas", subtaskMetas)) + metaBytes := make([][]byte, 0, len(taskMeta.FileInfos)) + for _, subtaskMeta := range subtaskMetas { + bs, err := json.Marshal(subtaskMeta) + if err != nil { + return nil, err + } + metaBytes = append(metaBytes, bs) + } + gTask.Step = Import + return metaBytes, nil +} + +// ProcessErrFlow implements dispatcher.ProcessErrFlow interface. +func (*Dispatcher) ProcessErrFlow(_ context.Context, _ dispatcher.Dispatch, _ *proto.Task, errMsg string) ([]byte, error) { + logutil.BgLogger().Info("process error flow", zap.String("error message", errMsg)) + return nil, nil +} + +func generateSubtaskMetas(ctx context.Context, task *TaskMeta, concurrency int) ([]*SubtaskMeta, error) { + tableRegions, err := makeTableRegions(ctx, task, concurrency) + if err != nil { + return nil, err + } + + engineMap := make(map[int32]int) + subtasks := make([]*SubtaskMeta, 0) + for _, region := range tableRegions { + idx, ok := engineMap[region.EngineID] + if !ok { + idx = len(subtasks) + engineMap[region.EngineID] = idx + subtasks = append(subtasks, &SubtaskMeta{ + Table: task.Table, + Format: task.Format, + Dir: task.Dir, + }) + } + subtask := subtasks[idx] + subtask.Chunks = append(subtask.Chunks, Chunk{ + Path: region.FileMeta.Path, + Offset: region.Chunk.Offset, + EndOffset: region.Chunk.EndOffset, + RealOffset: region.Chunk.RealOffset, + PrevRowIDMax: region.Chunk.PrevRowIDMax, + RowIDMax: region.Chunk.RowIDMax, + }) + } + return subtasks, nil +} + +func init() { + dispatcher.RegisterTaskFlowHandle(proto.LoadData, &Dispatcher{}) +} diff --git a/disttask/loaddata/proto.go b/disttask/loaddata/proto.go index ae920b35dd9b7..3e9e68487968f 100644 --- a/disttask/loaddata/proto.go +++ b/disttask/loaddata/proto.go @@ -24,29 +24,29 @@ import ( // TaskStep of LoadData. const ( - ReadSortImport = 1 + Import = 1 ) -// Task is the task of LoadData. -type Task struct { +// TaskMeta is the task of LoadData. +type TaskMeta struct { Table Table Format Format Dir string FileInfos []FileInfo } -// Subtask is the subtask of LoadData. +// SubtaskMeta is the subtask of LoadData. // Dispatcher will split the task into subtasks(FileInfos -> Chunks) -type Subtask struct { +type SubtaskMeta struct { Table Table Format Format Dir string Chunks []Chunk } -// MinimalTask is the minimal task of LoadData. +// MinimalTaskMeta is the minimal task of LoadData. // Scheduler will split the subtask into minimal tasks(Chunks -> Chunk) -type MinimalTask struct { +type MinimalTaskMeta struct { Table Table Format Format Dir string @@ -55,13 +55,14 @@ type MinimalTask struct { } // IsMinimalTask implements the MinimalTask interface. -func (MinimalTask) IsMinimalTask() {} +func (MinimalTaskMeta) IsMinimalTask() {} // Table records the table information. type Table struct { DBName string Info *model.TableInfo TargetColumns []string + IsRowOrdered bool } // Format records the format information. @@ -77,15 +78,13 @@ type Format struct { // CSV records the CSV format information. type CSV struct { - Config *config.CSVConfig - LoadDataReadBlockSize int64 - Strict bool + Config config.CSVConfig + Strict bool } // SQLDump records the SQL dump format information. type SQLDump struct { - SQLMode mysql.SQLMode - LoadDataReadBlockSize int64 + SQLMode mysql.SQLMode } // Parquet records the Parquet format information. diff --git a/disttask/loaddata/wrapper.go b/disttask/loaddata/wrapper.go new file mode 100644 index 0000000000000..4aaaef5949fff --- /dev/null +++ b/disttask/loaddata/wrapper.go @@ -0,0 +1,97 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loaddata + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/lightning/mydump" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/util/intest" +) + +func makeTableRegions(ctx context.Context, task *TaskMeta, concurrency int) ([]*mydump.TableRegion, error) { + if concurrency <= 0 { + return nil, errors.Errorf("concurrency must be greater than 0, but got %d", concurrency) + } + + b, err := storage.ParseBackend(task.Dir, nil) + if err != nil { + return nil, err + } + + opt := &storage.ExternalStorageOptions{} + if intest.InTest { + opt.NoCredentials = true + } + store, err := storage.New(ctx, b, opt) + if err != nil { + return nil, err + } + + meta := &mydump.MDTableMeta{ + DB: task.Table.DBName, + Name: task.Table.Info.Name.String(), + IsRowOrdered: task.Table.IsRowOrdered, + } + + sourceType, err := transformSourceType(task.Format.Type) + if err != nil { + return nil, err + } + for _, file := range task.FileInfos { + meta.DataFiles = append(meta.DataFiles, mydump.FileInfo{ + FileMeta: mydump.SourceFileMeta{ + Path: file.Path, + Type: sourceType, + FileSize: file.Size, + RealSize: file.RealSize, + Compression: task.Format.Compression, + }, + }) + } + cfg := &config.Config{ + App: config.Lightning{ + RegionConcurrency: concurrency, + TableConcurrency: concurrency, + }, + Mydumper: config.MydumperRuntime{ + CSV: task.Format.CSV.Config, + StrictFormat: task.Format.CSV.Strict, + MaxRegionSize: config.MaxRegionSize, + ReadBlockSize: config.ReadBlockSize, + // uniform distribution + BatchImportRatio: 0, + }, + } + + return mydump.MakeTableRegions(ctx, meta, len(task.Table.TargetColumns), cfg, nil, store) +} + +func transformSourceType(tp string) (mydump.SourceType, error) { + switch tp { + case importer.LoadDataFormatParquet: + return mydump.SourceTypeParquet, nil + case importer.LoadDataFormatDelimitedData: + return mydump.SourceTypeCSV, nil + case importer.LoadDataFormatSQLDump: + return mydump.SourceTypeSQL, nil + default: + return mydump.SourceTypeIgnore, errors.Errorf("unknown source type: %s", tp) + } +} diff --git a/disttask/loaddata/wrapper_test.go b/disttask/loaddata/wrapper_test.go new file mode 100644 index 0000000000000..ce20acebd43ad --- /dev/null +++ b/disttask/loaddata/wrapper_test.go @@ -0,0 +1,186 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loaddata + +import ( + "context" + "os" + "path/filepath" + + //"os" + //"path/filepath" + "testing" + + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/lightning/mydump" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/model" + "github.com/stretchr/testify/require" +) + +func TestTransformSourceType(t *testing.T) { + testCases := []struct { + tp string + expected mydump.SourceType + }{ + { + tp: importer.LoadDataFormatParquet, + expected: mydump.SourceTypeParquet, + }, + { + tp: importer.LoadDataFormatSQLDump, + expected: mydump.SourceTypeSQL, + }, + { + tp: importer.LoadDataFormatDelimitedData, + expected: mydump.SourceTypeCSV, + }, + } + for _, tc := range testCases { + expected, err := transformSourceType(tc.tp) + require.NoError(t, err) + require.Equal(t, tc.expected, expected) + } + expected, err := transformSourceType("unknown") + require.EqualError(t, err, "unknown source type: unknown") + require.Equal(t, mydump.SourceTypeIgnore, expected) +} + +func TestMakeTableRegions(t *testing.T) { + regions, err := makeTableRegions(context.Background(), &TaskMeta{}, 0) + require.EqualError(t, err, "concurrency must be greater than 0, but got 0") + require.Nil(t, regions) + + regions, err = makeTableRegions(context.Background(), &TaskMeta{}, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "empty store is not allowed") + require.Nil(t, regions) + + task := &TaskMeta{ + Table: Table{ + Info: &model.TableInfo{ + Name: model.NewCIStr("test"), + }, + }, + Dir: "dir", + } + regions, err = makeTableRegions(context.Background(), task, 1) + require.EqualError(t, err, "unknown source type: ") + require.Nil(t, regions) + + // parquet + dir := "../../br/pkg/lightning/mydump/parquet" + filename := "000000_0.parquet" + task = &TaskMeta{ + Table: Table{ + Info: &model.TableInfo{ + Name: model.NewCIStr("test"), + }, + TargetColumns: []string{"a", "b"}, + }, + Format: Format{ + Type: importer.LoadDataFormatParquet, + }, + Dir: dir, + FileInfos: []FileInfo{ + { + Path: filename, + }, + }, + } + regions, err = makeTableRegions(context.Background(), task, 1) + require.NoError(t, err) + require.Len(t, regions, 1) + require.Equal(t, regions[0].EngineID, int32(0)) + require.Equal(t, regions[0].Chunk.Offset, int64(0)) + require.Equal(t, regions[0].Chunk.EndOffset, int64(5)) + require.Equal(t, regions[0].Chunk.PrevRowIDMax, int64(0)) + require.Equal(t, regions[0].Chunk.RowIDMax, int64(5)) + + // large csv + originRegionSize := config.MaxRegionSize + config.MaxRegionSize = 5 + originBatchSize := config.DefaultBatchSize + config.DefaultBatchSize = 12 + defer func() { + config.MaxRegionSize = originRegionSize + config.DefaultBatchSize = originBatchSize + }() + dir = "../../br/pkg/lightning/mydump/csv" + filename = "split_large_file.csv" + dataFileInfo, err := os.Stat(filepath.Join(dir, filename)) + require.NoError(t, err) + task = &TaskMeta{ + Table: Table{ + Info: &model.TableInfo{ + Name: model.NewCIStr("test"), + }, + TargetColumns: []string{"a", "b", "c"}, + }, + Format: Format{ + Type: importer.LoadDataFormatDelimitedData, + CSV: CSV{ + Config: config.CSVConfig{ + Separator: ",", + Delimiter: "", + Header: true, + HeaderSchemaMatch: true, + TrimLastSep: false, + NotNull: false, + Null: []string{"NULL"}, + EscapedBy: `\`, + }, + Strict: true, + }, + }, + Dir: dir, + FileInfos: []FileInfo{ + { + Path: filename, + Size: dataFileInfo.Size(), + RealSize: dataFileInfo.Size(), + }, + }, + } + regions, err = makeTableRegions(context.Background(), task, 1) + require.NoError(t, err) + require.Len(t, regions, 4) + chunks := []Chunk{{Offset: 6, EndOffset: 12}, {Offset: 12, EndOffset: 18}, {Offset: 18, EndOffset: 24}, {Offset: 24, EndOffset: 30}} + for i, region := range regions { + require.Equal(t, region.EngineID, int32(i/2)) + require.Equal(t, region.Chunk.Offset, chunks[i].Offset) + require.Equal(t, region.Chunk.EndOffset, chunks[i].EndOffset) + require.Equal(t, region.Chunk.RealOffset, int64(0)) + require.Equal(t, region.Chunk.PrevRowIDMax, int64(i)) + require.Equal(t, region.Chunk.RowIDMax, int64(i+1)) + } + + // compression + filename = "split_large_file.csv.zst" + dataFileInfo, err = os.Stat(filepath.Join(dir, filename)) + require.NoError(t, err) + task.FileInfos[0].Path = filename + task.FileInfos[0].Size = dataFileInfo.Size() + task.Format.Compression = mydump.CompressionZStd + regions, err = makeTableRegions(context.Background(), task, 1) + require.NoError(t, err) + require.Len(t, regions, 1) + require.Equal(t, regions[0].EngineID, int32(0)) + require.Equal(t, regions[0].Chunk.Offset, int64(0)) + require.Equal(t, regions[0].Chunk.EndOffset, mydump.TableFileSizeINF) + require.Equal(t, regions[0].Chunk.RealOffset, int64(0)) + require.Equal(t, regions[0].Chunk.PrevRowIDMax, int64(0)) + require.Equal(t, regions[0].Chunk.RowIDMax, int64(50)) +} From 5ad5939f6f15d24adda66b18cd32f22a6de6b754 Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Mon, 27 Mar 2023 02:06:55 +0800 Subject: [PATCH 2/8] update --- disttask/loaddata/BUILD.bazel | 15 ++++++++++++++- disttask/loaddata/wrapper_test.go | 3 --- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/disttask/loaddata/BUILD.bazel b/disttask/loaddata/BUILD.bazel index bdd5c65dbd686..64536b739f278 100644 --- a/disttask/loaddata/BUILD.bazel +++ b/disttask/loaddata/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "loaddata", @@ -25,3 +25,16 @@ go_library( "@org_uber_go_zap//:zap", ], ) + +go_test( + name = "loaddata_test", + srcs = ["wrapper_test.go"], + embed = [":loaddata"], + deps = [ + "//br/pkg/lightning/config", + "//br/pkg/lightning/mydump", + "//executor/importer", + "//parser/model", + "@com_github_stretchr_testify//require", + ], +) diff --git a/disttask/loaddata/wrapper_test.go b/disttask/loaddata/wrapper_test.go index ce20acebd43ad..115c6a5d7f1f8 100644 --- a/disttask/loaddata/wrapper_test.go +++ b/disttask/loaddata/wrapper_test.go @@ -18,9 +18,6 @@ import ( "context" "os" "path/filepath" - - //"os" - //"path/filepath" "testing" "github.com/pingcap/tidb/br/pkg/lightning/config" From d007732728d9d1f7662427b72b8e6db00b4b4ca6 Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Mon, 27 Mar 2023 13:26:31 +0800 Subject: [PATCH 3/8] update bazel --- disttask/loaddata/BUILD.bazel | 2 ++ 1 file changed, 2 insertions(+) diff --git a/disttask/loaddata/BUILD.bazel b/disttask/loaddata/BUILD.bazel index 64536b739f278..f3d17483248f3 100644 --- a/disttask/loaddata/BUILD.bazel +++ b/disttask/loaddata/BUILD.bazel @@ -28,8 +28,10 @@ go_library( go_test( name = "loaddata_test", + timeout = "short", srcs = ["wrapper_test.go"], embed = [":loaddata"], + flaky = True, deps = [ "//br/pkg/lightning/config", "//br/pkg/lightning/mydump", From 34665c41abf8f0079cbe88c3588dfc004a003278 Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Mon, 27 Mar 2023 15:18:47 +0800 Subject: [PATCH 4/8] add ut --- ddl/disttask_flow.go | 4 +- disttask/framework/dispatcher/BUILD.bazel | 2 + disttask/framework/dispatcher/dispatcher.go | 6 + .../framework/dispatcher/dispatcher_mock.go | 35 ++++++ .../framework/dispatcher/dispatcher_test.go | 4 +- disttask/framework/dispatcher/register.go | 4 +- disttask/loaddata/BUILD.bazel | 8 +- disttask/loaddata/dispatcher.go | 10 +- disttask/loaddata/dispatcher_test.go | 119 ++++++++++++++++++ disttask/loaddata/proto.go | 2 +- disttask/loaddata/wrapper_test.go | 2 +- 11 files changed, 182 insertions(+), 14 deletions(-) create mode 100644 disttask/framework/dispatcher/dispatcher_mock.go create mode 100644 disttask/loaddata/dispatcher_test.go diff --git a/ddl/disttask_flow.go b/ddl/disttask_flow.go index 3912787becb48..57dd66e11cbcb 100644 --- a/ddl/disttask_flow.go +++ b/ddl/disttask_flow.go @@ -57,7 +57,7 @@ func NewLitBackfillFlowHandle(getDDL func() DDL) dispatcher.TaskFlowHandle { } // ProcessNormalFlow processes the normal flow. -func (h *litBackfillFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Dispatch, gTask *proto.Task) (metas [][]byte, err error) { +func (h *litBackfillFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Handle, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State != proto.TaskStatePending { // This flow has only one step, finish task when it is not pending return nil, nil @@ -108,7 +108,7 @@ func (h *litBackfillFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatche return subTaskMetas, nil } -func (*litBackfillFlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Dispatch, _ *proto.Task, _ string) (meta []byte, err error) { +func (*litBackfillFlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Handle, _ *proto.Task, _ string) (meta []byte, err error) { // We do not need extra meta info when rolling back return nil, nil } diff --git a/disttask/framework/dispatcher/BUILD.bazel b/disttask/framework/dispatcher/BUILD.bazel index f0bc48257e597..85e318c18ef34 100644 --- a/disttask/framework/dispatcher/BUILD.bazel +++ b/disttask/framework/dispatcher/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "dispatcher", srcs = [ "dispatcher.go", + "dispatcher_mock.go", "register.go", ], importpath = "github.com/pingcap/tidb/disttask/framework/dispatcher", @@ -16,6 +17,7 @@ go_library( "//util/logutil", "//util/syncutil", "@com_github_pingcap_errors//:errors", + "@com_github_stretchr_testify//mock", "@org_uber_go_zap//:zap", ], ) diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 66cbfb973508f..167db4c1946a4 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -50,6 +50,12 @@ type Dispatch interface { Stop() } +// Handle provides the interface for operations needed by task flow handles. +type Handle interface { + // GetTaskAllInstances gets handles the task's all available instances. + GetTaskAllInstances(ctx context.Context, gTaskID int64) ([]string, error) +} + func (d *dispatcher) getRunningGlobalTasks() map[int64]*proto.Task { d.runningGlobalTasks.RLock() defer d.runningGlobalTasks.RUnlock() diff --git a/disttask/framework/dispatcher/dispatcher_mock.go b/disttask/framework/dispatcher/dispatcher_mock.go new file mode 100644 index 0000000000000..f8b7288ee4adf --- /dev/null +++ b/disttask/framework/dispatcher/dispatcher_mock.go @@ -0,0 +1,35 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dispatcher + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +// MockHandle is used to mock the Handle. +type MockHandle struct { + mock.Mock +} + +// GetTaskAllInstances implements the Handle.GetTaskAllInstances interface. +func (m *MockHandle) GetTaskAllInstances(ctx context.Context, gTaskID int64) ([]string, error) { + args := m.Called(ctx, gTaskID) + if args.Error(1) != nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), nil +} diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index 2ba049c5f949b..42fce595dd12a 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -191,7 +191,7 @@ const taskTypeExample = "task_example" type NumberExampleHandle struct { } -func (n NumberExampleHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Dispatch, gTask *proto.Task) (metas [][]byte, err error) { +func (n NumberExampleHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Handle, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State == proto.TaskStatePending { gTask.Step = proto.StepInit } @@ -211,7 +211,7 @@ func (n NumberExampleHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.D return metas, nil } -func (n NumberExampleHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Dispatch, _ *proto.Task, _ string) (meta []byte, err error) { +func (n NumberExampleHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Handle, _ *proto.Task, _ string) (meta []byte, err error) { // Don't handle not. return nil, nil } diff --git a/disttask/framework/dispatcher/register.go b/disttask/framework/dispatcher/register.go index daffa6baa6657..42e6d809d1d53 100644 --- a/disttask/framework/dispatcher/register.go +++ b/disttask/framework/dispatcher/register.go @@ -23,8 +23,8 @@ import ( // TaskFlowHandle is used to control the process operations for each global task. type TaskFlowHandle interface { - ProcessNormalFlow(ctx context.Context, d Dispatch, gTask *proto.Task) (metas [][]byte, err error) - ProcessErrFlow(ctx context.Context, d Dispatch, gTask *proto.Task, receive string) (meta []byte, err error) + ProcessNormalFlow(ctx context.Context, h Handle, gTask *proto.Task) (metas [][]byte, err error) + ProcessErrFlow(ctx context.Context, h Handle, gTask *proto.Task, receive string) (meta []byte, err error) } var taskFlowHandleMap struct { diff --git a/disttask/loaddata/BUILD.bazel b/disttask/loaddata/BUILD.bazel index f3d17483248f3..7ea322a01e49a 100644 --- a/disttask/loaddata/BUILD.bazel +++ b/disttask/loaddata/BUILD.bazel @@ -29,14 +29,20 @@ go_library( go_test( name = "loaddata_test", timeout = "short", - srcs = ["wrapper_test.go"], + srcs = [ + "dispatcher_test.go", + "wrapper_test.go", + ], embed = [":loaddata"], flaky = True, deps = [ "//br/pkg/lightning/config", "//br/pkg/lightning/mydump", + "//disttask/framework/dispatcher", + "//disttask/framework/proto", "//executor/importer", "//parser/model", + "@com_github_stretchr_testify//mock", "@com_github_stretchr_testify//require", ], ) diff --git a/disttask/loaddata/dispatcher.go b/disttask/loaddata/dispatcher.go index 64ce0a6504098..0fee94d68a937 100644 --- a/disttask/loaddata/dispatcher.go +++ b/disttask/loaddata/dispatcher.go @@ -24,11 +24,11 @@ import ( "go.uber.org/zap" ) -// Dispatcher is the dispatcher for load data. -type Dispatcher struct{} +// FlowHandle is the dispatcher for load data. +type FlowHandle struct{} // ProcessNormalFlow implements dispatcher.TaskFlowHandle interface. -func (*Dispatcher) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.Dispatch, gTask *proto.Task) ([][]byte, error) { +func (*FlowHandle) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.Handle, gTask *proto.Task) ([][]byte, error) { taskMeta := &TaskMeta{} err := json.Unmarshal(gTask.Meta, taskMeta) if err != nil { @@ -65,7 +65,7 @@ func (*Dispatcher) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.Di } // ProcessErrFlow implements dispatcher.ProcessErrFlow interface. -func (*Dispatcher) ProcessErrFlow(_ context.Context, _ dispatcher.Dispatch, _ *proto.Task, errMsg string) ([]byte, error) { +func (*FlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Handle, _ *proto.Task, errMsg string) ([]byte, error) { logutil.BgLogger().Info("process error flow", zap.String("error message", errMsg)) return nil, nil } @@ -103,5 +103,5 @@ func generateSubtaskMetas(ctx context.Context, task *TaskMeta, concurrency int) } func init() { - dispatcher.RegisterTaskFlowHandle(proto.LoadData, &Dispatcher{}) + dispatcher.RegisterTaskFlowHandle(proto.LoadData, &FlowHandle{}) } diff --git a/disttask/loaddata/dispatcher_test.go b/disttask/loaddata/dispatcher_test.go new file mode 100644 index 0000000000000..e257d79bcdd89 --- /dev/null +++ b/disttask/loaddata/dispatcher_test.go @@ -0,0 +1,119 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loaddata + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/pingcap/tidb/disttask/framework/dispatcher" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/model" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestProcessNormalFlow(t *testing.T) { + flowHandle := &FlowHandle{} + mockDispatcherHandle := &dispatcher.MockHandle{} + + dir := t.TempDir() + path1 := "test1.csv" + path2 := "test2.csv" + content1 := []byte("1,1\r\n2,2\r\n3,3") + content2 := []byte("4,4\r\n5,5\r\n6,6") + require.NoError(t, os.WriteFile(filepath.Join(dir, path1), content1, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, path2), content2, 0o644)) + fileInfo1, err := os.Stat(filepath.Join(dir, path1)) + require.NoError(t, err) + fileInfo2, err := os.Stat(filepath.Join(dir, path2)) + require.NoError(t, err) + + taskMeta := TaskMeta{ + Table: Table{ + DBName: "db", + Info: &model.TableInfo{}, + TargetColumns: []string{"a", "b", "c"}, + IsRowOrdered: true, + }, + Format: Format{ + Type: importer.LoadDataFormatDelimitedData, + }, + Dir: dir, + FileInfos: []FileInfo{ + { + Path: path1, + Size: fileInfo1.Size(), + RealSize: fileInfo1.Size(), + }, + { + Path: path2, + Size: fileInfo1.Size(), + RealSize: fileInfo2.Size(), + }, + }, + } + bs, err := json.Marshal(taskMeta) + require.NoError(t, err) + task := &proto.Task{ + Meta: bs, + } + + mockDispatcherHandle.On("GetTaskAllInstances", mock.Anything, mock.Anything).Return([]string{"tidb1", "tidb2"}, nil).Once() + subtaskMetas, err := flowHandle.ProcessNormalFlow(context.Background(), mockDispatcherHandle, task) + require.NoError(t, err) + require.Equal(t, task.Step, Import) + require.Len(t, subtaskMetas, 1) + subtaskMeta := &SubtaskMeta{} + require.NoError(t, json.Unmarshal(subtaskMetas[0], subtaskMeta)) + require.Equal(t, subtaskMeta.Table, taskMeta.Table) + require.Equal(t, subtaskMeta.Format, taskMeta.Format) + require.Equal(t, subtaskMeta.Dir, dir) + require.Len(t, subtaskMeta.Chunks, 2) + require.Equal(t, subtaskMeta.Chunks[0], Chunk{ + Path: path1, + Offset: 0, + EndOffset: 13, + RealOffset: 0, + PrevRowIDMax: 0, + RowIDMax: 4, + }) + require.Equal(t, subtaskMeta.Chunks[1], Chunk{ + Path: path2, + Offset: 0, + EndOffset: 13, + RealOffset: 0, + PrevRowIDMax: 4, + RowIDMax: 8, + }) + + subtaskMetas, err = flowHandle.ProcessNormalFlow(context.Background(), mockDispatcherHandle, task) + require.NoError(t, err) + require.Len(t, subtaskMetas, 0) + require.Equal(t, task.State, proto.TaskStateSucceed) +} + +func TestProcessErrFlow(t *testing.T) { + flowHandle := &FlowHandle{} + mockDispatcherHandle := &dispatcher.MockHandle{} + // add test if needed + bs, err := flowHandle.ProcessErrFlow(context.Background(), mockDispatcherHandle, &proto.Task{}, "") + require.NoError(t, err) + require.Nil(t, bs) +} diff --git a/disttask/loaddata/proto.go b/disttask/loaddata/proto.go index 3e9e68487968f..71af86f070ec1 100644 --- a/disttask/loaddata/proto.go +++ b/disttask/loaddata/proto.go @@ -24,7 +24,7 @@ import ( // TaskStep of LoadData. const ( - Import = 1 + Import int64 = 1 ) // TaskMeta is the task of LoadData. diff --git a/disttask/loaddata/wrapper_test.go b/disttask/loaddata/wrapper_test.go index 115c6a5d7f1f8..7b397a8b07ef1 100644 --- a/disttask/loaddata/wrapper_test.go +++ b/disttask/loaddata/wrapper_test.go @@ -71,7 +71,7 @@ func TestMakeTableRegions(t *testing.T) { Name: model.NewCIStr("test"), }, }, - Dir: "dir", + Dir: "/tmp/test", } regions, err = makeTableRegions(context.Background(), task, 1) require.EqualError(t, err, "unknown source type: ") From 816ec8cd34384f4536fa02f5cc11796ca32616cd Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Mon, 27 Mar 2023 15:48:34 +0800 Subject: [PATCH 5/8] fix ut --- disttask/loaddata/BUILD.bazel | 1 + disttask/loaddata/dispatcher_test.go | 8 ++++---- disttask/loaddata/testdata/000000_0.parquet | Bin 0 -> 434 bytes disttask/loaddata/testdata/split_large_file.csv | 5 +++++ .../loaddata/testdata/split_large_file.csv.zst | Bin 0 -> 43 bytes disttask/loaddata/wrapper_test.go | 3 +-- 6 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 disttask/loaddata/testdata/000000_0.parquet create mode 100644 disttask/loaddata/testdata/split_large_file.csv create mode 100644 disttask/loaddata/testdata/split_large_file.csv.zst diff --git a/disttask/loaddata/BUILD.bazel b/disttask/loaddata/BUILD.bazel index 7ea322a01e49a..847bfe90f0b89 100644 --- a/disttask/loaddata/BUILD.bazel +++ b/disttask/loaddata/BUILD.bazel @@ -33,6 +33,7 @@ go_test( "dispatcher_test.go", "wrapper_test.go", ], + data = glob(["testdata/**"]), embed = [":loaddata"], flaky = True, deps = [ diff --git a/disttask/loaddata/dispatcher_test.go b/disttask/loaddata/dispatcher_test.go index e257d79bcdd89..f124a0b9f457d 100644 --- a/disttask/loaddata/dispatcher_test.go +++ b/disttask/loaddata/dispatcher_test.go @@ -49,7 +49,7 @@ func TestProcessNormalFlow(t *testing.T) { Table: Table{ DBName: "db", Info: &model.TableInfo{}, - TargetColumns: []string{"a", "b", "c"}, + TargetColumns: []string{"a", "b"}, IsRowOrdered: true, }, Format: Format{ @@ -92,15 +92,15 @@ func TestProcessNormalFlow(t *testing.T) { EndOffset: 13, RealOffset: 0, PrevRowIDMax: 0, - RowIDMax: 4, + RowIDMax: 6, }) require.Equal(t, subtaskMeta.Chunks[1], Chunk{ Path: path2, Offset: 0, EndOffset: 13, RealOffset: 0, - PrevRowIDMax: 4, - RowIDMax: 8, + PrevRowIDMax: 6, + RowIDMax: 12, }) subtaskMetas, err = flowHandle.ProcessNormalFlow(context.Background(), mockDispatcherHandle, task) diff --git a/disttask/loaddata/testdata/000000_0.parquet b/disttask/loaddata/testdata/000000_0.parquet new file mode 100644 index 0000000000000000000000000000000000000000..ae8a5001bc2b31b67f1a3edab824c8e55ec78cfa GIT binary patch literal 434 zcmWG=3^EjD5H%4s(GleWGT1~pWF%Nv85kHOSQvq%7!Wfs0Wq^Yhzn)`X%-*`DH3H7 ztq`s70cr(_?D${sRhEHaXNM{SL(P%3bJ!RdKBNHIT84A{ycrl4Cjr?B(-J^z*;XJM z3Kq{(WdK?&CMpJU0}I3rJjfj@GM9yDlAPc(akMVC`&CW&dkqKFx0cq zGgQz>D$UGEQ7}nNN;FAHOf^qRGDtB=O-V{lGq*G{PE0j4HnvPNvPerYPBTwSPSa#~ L1`MqLU~B;ZE7Mb) literal 0 HcmV?d00001 diff --git a/disttask/loaddata/testdata/split_large_file.csv b/disttask/loaddata/testdata/split_large_file.csv new file mode 100644 index 0000000000000..7b6512d538f6a --- /dev/null +++ b/disttask/loaddata/testdata/split_large_file.csv @@ -0,0 +1,5 @@ +a,b,c +1,1,2 +2,2,1 +3,2,2 +4,2,2 diff --git a/disttask/loaddata/testdata/split_large_file.csv.zst b/disttask/loaddata/testdata/split_large_file.csv.zst new file mode 100644 index 0000000000000000000000000000000000000000..9609230bf04a5bb1a8584f6cf2d3a905d2820c4a GIT binary patch literal 43 ucmdPcs{dC-?jr+3qE3=dGMAx_p^g!kk&cm$A(t_bG~zM=lW*)MECvAhxC;UR literal 0 HcmV?d00001 diff --git a/disttask/loaddata/wrapper_test.go b/disttask/loaddata/wrapper_test.go index 7b397a8b07ef1..a8fdb173f9f0d 100644 --- a/disttask/loaddata/wrapper_test.go +++ b/disttask/loaddata/wrapper_test.go @@ -78,7 +78,7 @@ func TestMakeTableRegions(t *testing.T) { require.Nil(t, regions) // parquet - dir := "../../br/pkg/lightning/mydump/parquet" + dir := "testdata" filename := "000000_0.parquet" task = &TaskMeta{ Table: Table{ @@ -115,7 +115,6 @@ func TestMakeTableRegions(t *testing.T) { config.MaxRegionSize = originRegionSize config.DefaultBatchSize = originBatchSize }() - dir = "../../br/pkg/lightning/mydump/csv" filename = "split_large_file.csv" dataFileInfo, err := os.Stat(filepath.Join(dir, filename)) require.NoError(t, err) From 1deabfa3283d58c4978427811dd499d38b2415d5 Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Mon, 27 Mar 2023 16:06:58 +0800 Subject: [PATCH 6/8] fix ut --- .github/licenserc.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/licenserc.yml b/.github/licenserc.yml index aec59e4f66d57..b95cc4cd7feca 100644 --- a/.github/licenserc.yml +++ b/.github/licenserc.yml @@ -39,6 +39,9 @@ header: - "tidb-binlog/driver/example" - "tidb-binlog/proto/go-binlog/secondary_binlog.pb.go" - "**/*.sql" + - "**/*.csv" + - "**/*.parquet" + - "**/*.zst" - ".bazelversion" - "build/image/.ci_bazel" comment: on-failure From b7f372a6ac17fd3e76a35fd4e1831533026a9ffa Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Wed, 29 Mar 2023 11:42:25 +0800 Subject: [PATCH 7/8] address comment --- disttask/loaddata/dispatcher.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/disttask/loaddata/dispatcher.go b/disttask/loaddata/dispatcher.go index 0fee94d68a937..a72cd64f80f59 100644 --- a/disttask/loaddata/dispatcher.go +++ b/disttask/loaddata/dispatcher.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/disttask/framework/proto" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" + "golang.org/x/exp/maps" ) // FlowHandle is the dispatcher for load data. @@ -76,21 +77,18 @@ func generateSubtaskMetas(ctx context.Context, task *TaskMeta, concurrency int) return nil, err } - engineMap := make(map[int32]int) - subtasks := make([]*SubtaskMeta, 0) + subtaskMetaMap := make(map[int32]*SubtaskMeta) for _, region := range tableRegions { - idx, ok := engineMap[region.EngineID] + subtaskMeta, ok := subtaskMetaMap[region.EngineID] if !ok { - idx = len(subtasks) - engineMap[region.EngineID] = idx - subtasks = append(subtasks, &SubtaskMeta{ + subtaskMeta = &SubtaskMeta{ Table: task.Table, Format: task.Format, Dir: task.Dir, - }) + } + subtaskMetaMap[region.EngineID] = subtaskMeta } - subtask := subtasks[idx] - subtask.Chunks = append(subtask.Chunks, Chunk{ + subtaskMeta.Chunks = append(subtaskMeta.Chunks, Chunk{ Path: region.FileMeta.Path, Offset: region.Chunk.Offset, EndOffset: region.Chunk.EndOffset, @@ -99,7 +97,7 @@ func generateSubtaskMetas(ctx context.Context, task *TaskMeta, concurrency int) RowIDMax: region.Chunk.RowIDMax, }) } - return subtasks, nil + return maps.Values[map[int32]*SubtaskMeta](subtaskMetaMap), nil } func init() { From 4adce495d5adcaa4c90e2d02391b9b24c32b8ac6 Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Wed, 29 Mar 2023 13:22:48 +0800 Subject: [PATCH 8/8] address comment --- ddl/disttask_flow.go | 4 ++-- disttask/framework/dispatcher/dispatcher.go | 17 +++++++++-------- .../framework/dispatcher/dispatcher_mock.go | 4 ++-- .../framework/dispatcher/dispatcher_test.go | 12 ++++++------ disttask/framework/dispatcher/register.go | 4 ++-- disttask/loaddata/BUILD.bazel | 1 + disttask/loaddata/dispatcher.go | 8 ++++---- disttask/loaddata/dispatcher_test.go | 2 +- 8 files changed, 27 insertions(+), 25 deletions(-) diff --git a/ddl/disttask_flow.go b/ddl/disttask_flow.go index 57dd66e11cbcb..6c7b214b05f80 100644 --- a/ddl/disttask_flow.go +++ b/ddl/disttask_flow.go @@ -57,7 +57,7 @@ func NewLitBackfillFlowHandle(getDDL func() DDL) dispatcher.TaskFlowHandle { } // ProcessNormalFlow processes the normal flow. -func (h *litBackfillFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Handle, gTask *proto.Task) (metas [][]byte, err error) { +func (h *litBackfillFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State != proto.TaskStatePending { // This flow has only one step, finish task when it is not pending return nil, nil @@ -108,7 +108,7 @@ func (h *litBackfillFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatche return subTaskMetas, nil } -func (*litBackfillFlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Handle, _ *proto.Task, _ string) (meta []byte, err error) { +func (*litBackfillFlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ string) (meta []byte, err error) { // We do not need extra meta info when rolling back return nil, nil } diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 167db4c1946a4..6366359f1226e 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -44,16 +44,16 @@ const ( type Dispatch interface { // Start enables dispatching and monitoring mechanisms. Start() - // GetTaskAllInstances gets handles the task's all available instances. - GetTaskAllInstances(ctx context.Context, gTaskID int64) ([]string, error) + // GetAllSchedulerIDs gets handles the task's all available instances. + GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) // Stop stops the dispatcher. Stop() } -// Handle provides the interface for operations needed by task flow handles. -type Handle interface { - // GetTaskAllInstances gets handles the task's all available instances. - GetTaskAllInstances(ctx context.Context, gTaskID int64) ([]string, error) +// TaskHandle provides the interface for operations needed by task flow handles. +type TaskHandle interface { + // GetAllSchedulerIDs gets handles the task's all scheduler instances. + GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) } func (d *dispatcher) getRunningGlobalTasks() map[int64]*proto.Task { @@ -259,7 +259,7 @@ func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr string) error } // TODO: Consider using a new context. - instanceIDs, err := d.GetTaskAllInstances(d.ctx, gTask.ID) + instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, gTask.ID) if err != nil { logutil.BgLogger().Warn("get global task's all instances failed", zap.Error(err)) return err @@ -384,7 +384,8 @@ func GetEligibleInstance(ctx context.Context) (string, error) { return "", errors.New("not found instance") } -func (d *dispatcher) GetTaskAllInstances(ctx context.Context, gTaskID int64) ([]string, error) { +// GetAllSchedulerIDs gets all the scheduler IDs. +func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) { if len(MockTiDBIDs) != 0 { return MockTiDBIDs, nil } diff --git a/disttask/framework/dispatcher/dispatcher_mock.go b/disttask/framework/dispatcher/dispatcher_mock.go index f8b7288ee4adf..0b9b9ed06c4d1 100644 --- a/disttask/framework/dispatcher/dispatcher_mock.go +++ b/disttask/framework/dispatcher/dispatcher_mock.go @@ -25,8 +25,8 @@ type MockHandle struct { mock.Mock } -// GetTaskAllInstances implements the Handle.GetTaskAllInstances interface. -func (m *MockHandle) GetTaskAllInstances(ctx context.Context, gTaskID int64) ([]string, error) { +// GetAllSchedulerIDs implements the Handle.GetAllSchedulerIDs interface. +func (m *MockHandle) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) { args := m.Called(ctx, gTaskID) if args.Error(1) != nil { return nil, args.Error(1) diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index 42fce595dd12a..9fa8532c4beef 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -63,7 +63,7 @@ func TestGetInstance(t *testing.T) { instanceID, err := dispatcher.GetEligibleInstance(ctx) require.Lenf(t, instanceID, 0, "instanceID:%d", instanceID) require.EqualError(t, err, "not found instance") - instanceIDs, err := dsp.GetTaskAllInstances(ctx, 1) + instanceIDs, err := dsp.GetAllSchedulerIDs(ctx, 1) require.Lenf(t, instanceIDs, 0, "instanceID:%d", instanceID) require.NoError(t, err) @@ -85,7 +85,7 @@ func TestGetInstance(t *testing.T) { if instanceID != uuids[0] && instanceID != uuids[1] { require.FailNowf(t, "expected uuids:%d,%d, actual uuid:%d", uuids[0], uuids[1], instanceID) } - instanceIDs, err = dsp.GetTaskAllInstances(ctx, 1) + instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, 1) require.Lenf(t, instanceIDs, 0, "instanceID:%d", instanceID) require.NoError(t, err) @@ -99,7 +99,7 @@ func TestGetInstance(t *testing.T) { } err = subTaskMgr.AddNewTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) require.NoError(t, err) - instanceIDs, err = dsp.GetTaskAllInstances(ctx, gTaskID) + instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID) require.NoError(t, err) require.Equal(t, []string{uuids[1]}, instanceIDs) // server ids: uuid0, uuid1 @@ -111,7 +111,7 @@ func TestGetInstance(t *testing.T) { } err = subTaskMgr.AddNewTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) require.NoError(t, err) - instanceIDs, err = dsp.GetTaskAllInstances(ctx, gTaskID) + instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID) require.NoError(t, err) require.Len(t, instanceIDs, len(uuids)) require.ElementsMatch(t, instanceIDs, uuids) @@ -191,7 +191,7 @@ const taskTypeExample = "task_example" type NumberExampleHandle struct { } -func (n NumberExampleHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.Handle, gTask *proto.Task) (metas [][]byte, err error) { +func (n NumberExampleHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State == proto.TaskStatePending { gTask.Step = proto.StepInit } @@ -211,7 +211,7 @@ func (n NumberExampleHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.H return metas, nil } -func (n NumberExampleHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Handle, _ *proto.Task, _ string) (meta []byte, err error) { +func (n NumberExampleHandle) ProcessErrFlow(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ string) (meta []byte, err error) { // Don't handle not. return nil, nil } diff --git a/disttask/framework/dispatcher/register.go b/disttask/framework/dispatcher/register.go index 42e6d809d1d53..c6945e2130cd1 100644 --- a/disttask/framework/dispatcher/register.go +++ b/disttask/framework/dispatcher/register.go @@ -23,8 +23,8 @@ import ( // TaskFlowHandle is used to control the process operations for each global task. type TaskFlowHandle interface { - ProcessNormalFlow(ctx context.Context, h Handle, gTask *proto.Task) (metas [][]byte, err error) - ProcessErrFlow(ctx context.Context, h Handle, gTask *proto.Task, receive string) (meta []byte, err error) + ProcessNormalFlow(ctx context.Context, h TaskHandle, gTask *proto.Task) (metas [][]byte, err error) + ProcessErrFlow(ctx context.Context, h TaskHandle, gTask *proto.Task, receive string) (meta []byte, err error) } var taskFlowHandleMap struct { diff --git a/disttask/loaddata/BUILD.bazel b/disttask/loaddata/BUILD.bazel index 847bfe90f0b89..b2e342c22207a 100644 --- a/disttask/loaddata/BUILD.bazel +++ b/disttask/loaddata/BUILD.bazel @@ -22,6 +22,7 @@ go_library( "//util/intest", "//util/logutil", "@com_github_pingcap_errors//:errors", + "@org_golang_x_exp//maps", "@org_uber_go_zap//:zap", ], ) diff --git a/disttask/loaddata/dispatcher.go b/disttask/loaddata/dispatcher.go index a72cd64f80f59..0ee443645f147 100644 --- a/disttask/loaddata/dispatcher.go +++ b/disttask/loaddata/dispatcher.go @@ -29,7 +29,7 @@ import ( type FlowHandle struct{} // ProcessNormalFlow implements dispatcher.TaskFlowHandle interface. -func (*FlowHandle) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.Handle, gTask *proto.Task) ([][]byte, error) { +func (*FlowHandle) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) { taskMeta := &TaskMeta{} err := json.Unmarshal(gTask.Meta, taskMeta) if err != nil { @@ -44,11 +44,11 @@ func (*FlowHandle) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.Ha default: } - instances, err := dispatch.GetTaskAllInstances(ctx, gTask.ID) + schedulers, err := dispatch.GetAllSchedulerIDs(ctx, gTask.ID) if err != nil { return nil, err } - subtaskMetas, err := generateSubtaskMetas(ctx, taskMeta, len(instances)) + subtaskMetas, err := generateSubtaskMetas(ctx, taskMeta, len(schedulers)) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (*FlowHandle) ProcessNormalFlow(ctx context.Context, dispatch dispatcher.Ha } // ProcessErrFlow implements dispatcher.ProcessErrFlow interface. -func (*FlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.Handle, _ *proto.Task, errMsg string) ([]byte, error) { +func (*FlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, errMsg string) ([]byte, error) { logutil.BgLogger().Info("process error flow", zap.String("error message", errMsg)) return nil, nil } diff --git a/disttask/loaddata/dispatcher_test.go b/disttask/loaddata/dispatcher_test.go index f124a0b9f457d..613de95b2c935 100644 --- a/disttask/loaddata/dispatcher_test.go +++ b/disttask/loaddata/dispatcher_test.go @@ -75,7 +75,7 @@ func TestProcessNormalFlow(t *testing.T) { Meta: bs, } - mockDispatcherHandle.On("GetTaskAllInstances", mock.Anything, mock.Anything).Return([]string{"tidb1", "tidb2"}, nil).Once() + mockDispatcherHandle.On("GetAllSchedulerIDs", mock.Anything, mock.Anything).Return([]string{"tidb1", "tidb2"}, nil).Once() subtaskMetas, err := flowHandle.ProcessNormalFlow(context.Background(), mockDispatcherHandle, task) require.NoError(t, err) require.Equal(t, task.Step, Import)