From c3e8be9add2df6ac2f426b8ac35dd650f3dea9d4 Mon Sep 17 00:00:00 2001 From: okJiang <819421878@qq.com> Date: Mon, 14 Feb 2022 10:47:37 +0800 Subject: [PATCH] checker(dm): support concurrent check (#3975) close pingcap/tiflow#3974 --- dm/checker/check_test.go | 87 ++-- dm/checker/checker.go | 41 +- dm/pkg/checker/mysql_server.go | 9 +- dm/pkg/checker/privilege.go | 20 +- dm/pkg/checker/privilege_test.go | 25 +- dm/pkg/checker/table_structure.go | 523 ++++++++++++++----------- dm/pkg/checker/table_structure_test.go | 167 ++++---- dm/pkg/checker/utils.go | 20 +- dm/pkg/utils/db.go | 20 + dm/pkg/utils/db_test.go | 15 + 10 files changed, 527 insertions(+), 400 deletions(-) diff --git a/dm/checker/check_test.go b/dm/checker/check_test.go index a29f952ea15..45668d2354a 100644 --- a/dm/checker/check_test.go +++ b/dm/checker/check_test.go @@ -255,32 +255,33 @@ func (s *testCheckerSuite) TestTableSchemaChecking(c *tc.C) { } createTable1 := `CREATE TABLE %s ( - id int(11) DEFAULT NULL, - b int(11) DEFAULT NULL - ) ENGINE=InnoDB DEFAULT CHARSET=latin1` + id int(11) DEFAULT NULL, + b int(11) DEFAULT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=latin1` createTable2 := `CREATE TABLE %s ( - id int(11) DEFAULT NULL, - b int(11) DEFAULT NULL, - UNIQUE KEY id (id) - ) ENGINE=InnoDB DEFAULT CHARSET=latin1` + id int(11) DEFAULT NULL, + b int(11) DEFAULT NULL, + UNIQUE KEY id (id) + ) ENGINE=InnoDB DEFAULT CHARSET=latin1` mock := initMockDB(c) - mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "2")) mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) + mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb2))) - mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) msg, err := CheckSyncConfig(context.Background(), cfgs, common.DefaultErrorCnt, common.DefaultWarnCnt) - c.Assert(len(msg), tc.Equals, 0) c.Assert(err, tc.ErrorMatches, "(.|\n)*primary/unique key does not exist(.|\n)*") + c.Assert(len(msg), tc.Equals, 0) mock = initMockDB(c) + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "2")) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable2, tb1))) mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable2, tb2))) - mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) msg, err = CheckSyncConfig(context.Background(), cfgs, common.DefaultErrorCnt, common.DefaultWarnCnt) - c.Assert(msg, tc.Equals, CheckTaskSuccess) c.Assert(err, tc.IsNil) + c.Assert(msg, tc.Equals, CheckTaskSuccess) } func (s *testCheckerSuite) TestShardTableSchemaChecking(c *tc.C) { @@ -299,29 +300,37 @@ func (s *testCheckerSuite) TestShardTableSchemaChecking(c *tc.C) { } createTable1 := `CREATE TABLE %s ( - id int(11) DEFAULT NULL, - b int(11) DEFAULT NULL - ) ENGINE=InnoDB DEFAULT CHARSET=latin1` + id int(11) DEFAULT NULL, + b int(11) DEFAULT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=latin1` createTable2 := `CREATE TABLE %s ( - id int(11) DEFAULT NULL, - c int(11) DEFAULT NULL - ) ENGINE=InnoDB DEFAULT CHARSET=latin1` + id int(11) DEFAULT NULL, + c int(11) DEFAULT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=latin1` mock := initMockDB(c) mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "2")) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) + mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable2, tb2))) msg, err := CheckSyncConfig(context.Background(), cfgs, common.DefaultErrorCnt, common.DefaultWarnCnt) - c.Assert(len(msg), tc.Equals, 0) c.Assert(err, tc.ErrorMatches, "(.|\n)*different column definition(.|\n)*") + c.Assert(len(msg), tc.Equals, 0) mock = initMockDB(c) mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "2")) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) + mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb2))) msg, err = CheckSyncConfig(context.Background(), cfgs, common.DefaultErrorCnt, common.DefaultWarnCnt) - c.Assert(msg, tc.Equals, CheckTaskSuccess) c.Assert(err, tc.IsNil) + c.Assert(msg, tc.Equals, CheckTaskSuccess) } func (s *testCheckerSuite) TestShardAutoIncrementIDChecking(c *tc.C) { @@ -340,36 +349,42 @@ func (s *testCheckerSuite) TestShardAutoIncrementIDChecking(c *tc.C) { } createTable1 := `CREATE TABLE %s ( - id int(11) NOT NULL AUTO_INCREMENT, - b int(11) DEFAULT NULL, - PRIMARY KEY (id), - UNIQUE KEY u_b(b) - ) ENGINE=InnoDB DEFAULT CHARSET=latin1` + id int(11) NOT NULL AUTO_INCREMENT, + b int(11) DEFAULT NULL, + PRIMARY KEY (id), + UNIQUE KEY u_b(b) + ) ENGINE=InnoDB DEFAULT CHARSET=latin1` createTable2 := `CREATE TABLE %s ( - id int(11) NOT NULL, - b int(11) DEFAULT NULL, - INDEX (id), - UNIQUE KEY u_b(b) - ) ENGINE=InnoDB DEFAULT CHARSET=latin1` + id int(11) NOT NULL, + b int(11) DEFAULT NULL, + INDEX (id), + UNIQUE KEY u_b(b) + ) ENGINE=InnoDB DEFAULT CHARSET=latin1` mock := initMockDB(c) mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "2")) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) + mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable1, tb2))) msg, err := CheckSyncConfig(context.Background(), cfgs, common.DefaultErrorCnt, common.DefaultWarnCnt) - c.Assert(len(msg), tc.Equals, 0) - c.Assert(err, tc.ErrorMatches, "(.|\n)*instance table .* of sharding .* have auto-increment key(.|\n)*") + c.Assert(msg, tc.Matches, "(.|\n)*sourceID table .* of sharding .* have auto-increment key(.|\n)*") + c.Assert(err, tc.IsNil) - mock = conn.InitMockDB(c) - mock.ExpectQuery("SHOW DATABASES").WillReturnRows(sqlmock.NewRows([]string{"DATABASE"}).AddRow(schema)) - mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows(sqlmock.NewRows([]string{"Tables_in_" + schema, "Table_type"}).AddRow(tb1, "BASE TABLE").AddRow(tb2, "BASE TABLE")) + mock = initMockDB(c) mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable2, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "2")) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) + mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable2, tb1))) + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")) mock.ExpectQuery("SHOW CREATE TABLE .*").WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow(tb1, fmt.Sprintf(createTable2, tb2))) msg, err = CheckSyncConfig(context.Background(), cfgs, common.DefaultErrorCnt, common.DefaultWarnCnt) - c.Assert(msg, tc.Equals, CheckTaskSuccess) c.Assert(err, tc.IsNil) + c.Assert(msg, tc.Equals, CheckTaskSuccess) } func (s *testCheckerSuite) TestSameTargetTableDetection(c *tc.C) { diff --git a/dm/checker/checker.go b/dm/checker/checker.go index 267aebf9687..0ae60fea15a 100644 --- a/dm/checker/checker.go +++ b/dm/checker/checker.go @@ -119,9 +119,11 @@ func (c *Checker) Init(ctx context.Context) (err error) { rollbackHolder.Add(fr.FuncRollback{Name: "close-DBs", Fn: c.closeDBs}) c.tctx = tcontext.NewContext(ctx, log.With(zap.String("unit", "task check"))) - // target name => source => schema => [tables] - sharding := make(map[string]map[string]map[string][]string) + // targetTableID => source => [tables] + sharding := make(map[string]map[string][]*filter.Table) shardingCounter := make(map[string]int) + // sourceID => []table + checkTablesMap := make(map[string][]*filter.Table) dbs := make(map[string]*sql.DB) columnMapping := make(map[string]*column.Mapping) _, checkingShardID := c.checkingItems[config.ShardAutoIncrementIDChecking] @@ -206,29 +208,22 @@ func (c *Checker) Init(ctx context.Context) (err error) { return err } - // checkTables map schema => {table1, table2, ...} - checkTables := make(map[string][]string) + var checkTables []*filter.Table checkSchemas := make(map[string]struct{}, len(mapping)) - for name, tables := range mapping { + for targetTableID, tables := range mapping { + checkTables = append(checkTables, tables...) + if _, ok := sharding[targetTableID]; !ok { + sharding[targetTableID] = make(map[string][]*filter.Table) + } + sharding[targetTableID][instance.cfg.SourceID] = append(sharding[targetTableID][instance.cfg.SourceID], tables...) + shardingCounter[targetTableID] += len(tables) for _, table := range tables { - checkTables[table.Schema] = append(checkTables[table.Schema], table.Name) if _, ok := checkSchemas[table.Schema]; !ok { checkSchemas[table.Schema] = struct{}{} } - if _, ok := sharding[name]; !ok { - sharding[name] = make(map[string]map[string][]string) - } - if _, ok := sharding[name][instance.cfg.SourceID]; !ok { - sharding[name][instance.cfg.SourceID] = make(map[string][]string) - } - if _, ok := sharding[name][instance.cfg.SourceID][table.Schema]; !ok { - sharding[name][instance.cfg.SourceID][table.Schema] = make([]string, 0, 1) - } - - sharding[name][instance.cfg.SourceID][table.Schema] = append(sharding[name][instance.cfg.SourceID][table.Schema], table.Name) - shardingCounter[name]++ } } + checkTablesMap[instance.cfg.SourceID] = checkTables dbs[instance.cfg.SourceID] = instance.sourceDB.DB if _, ok := c.checkingItems[config.DumpPrivilegeChecking]; ok { exportCfg := export.DefaultConfig() @@ -241,9 +236,11 @@ func (c *Checker) Init(ctx context.Context) (err error) { if c.onlineDDL != nil { c.checkList = append(c.checkList, checker.NewOnlineDDLChecker(instance.sourceDB.DB, checkSchemas, c.onlineDDL, bw)) } - if checkSchema { - c.checkList = append(c.checkList, checker.NewTablesChecker(instance.sourceDB.DB, instance.sourceDBinfo, checkTables)) - } + } + + dumpThreads := c.instances[0].cfg.MydumperConfig.Threads + if checkSchema { + c.checkList = append(c.checkList, checker.NewTablesChecker(dbs, checkTablesMap, dumpThreads)) } if checkingShard { @@ -252,7 +249,7 @@ func (c *Checker) Init(ctx context.Context) (err error) { continue } - c.checkList = append(c.checkList, checker.NewShardingTablesChecker(name, dbs, shardingSet, columnMapping, checkingShardID)) + c.checkList = append(c.checkList, checker.NewShardingTablesChecker(name, dbs, shardingSet, columnMapping, checkingShardID, dumpThreads)) } } diff --git a/dm/pkg/checker/mysql_server.go b/dm/pkg/checker/mysql_server.go index 8e3346d70c3..302baa4b73d 100644 --- a/dm/pkg/checker/mysql_server.go +++ b/dm/pkg/checker/mysql_server.go @@ -124,13 +124,12 @@ func (pc *MySQLServerIDChecker) Check(ctx context.Context) *Result { serverID, err := dbutil.ShowServerID(ctx, pc.db) if err != nil { - if utils.OriginError(err) == sql.ErrNoRows { - result.Errors = append(result.Errors, NewError("server_id not set")) - result.Instruction = "please set server_id in your database" - } else { + if utils.OriginError(err) != sql.ErrNoRows { markCheckError(result, err) + return result } - + result.Errors = append(result.Errors, NewError("server_id not set")) + result.Instruction = "please set server_id in your database" return result } diff --git a/dm/pkg/checker/privilege.go b/dm/pkg/checker/privilege.go index 2dab51ff1be..dda6075f178 100644 --- a/dm/pkg/checker/privilege.go +++ b/dm/pkg/checker/privilege.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/pingcap/tidb-tools/pkg/dbutil" + "github.com/pingcap/tidb-tools/pkg/filter" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/mysql" @@ -41,12 +42,12 @@ var privNeedGlobal = map[mysql.PrivilegeType]struct{}{ type SourceDumpPrivilegeChecker struct { db *sql.DB dbinfo *dbutil.DBConfig - checkTables map[string][]string // map schema => {table1, table2, ...} + checkTables []*filter.Table consistency string } // NewSourceDumpPrivilegeChecker returns a RealChecker. -func NewSourceDumpPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig, checkTables map[string][]string, consistency string) RealChecker { +func NewSourceDumpPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig, checkTables []*filter.Table, consistency string) RealChecker { return &SourceDumpPrivilegeChecker{ db: db, dbinfo: dbinfo, @@ -320,9 +321,8 @@ func verifyPrivileges(result *Result, grants []string, lackPriv map[mysql.Privil return NewError(privileges) } -// checkTables map schema => {table1, table2, ...}. // lackPriv map privilege => schema => table. -func genExpectPriv(privileges map[mysql.PrivilegeType]struct{}, checkTables map[string][]string) map[mysql.PrivilegeType]map[string]map[string]struct{} { +func genExpectPriv(privileges map[mysql.PrivilegeType]struct{}, checkTables []*filter.Table) map[mysql.PrivilegeType]map[string]map[string]struct{} { lackPriv := make(map[mysql.PrivilegeType]map[string]map[string]struct{}, len(privileges)) for p := range privileges { if _, ok := privNeedGlobal[p]; ok { @@ -330,13 +330,11 @@ func genExpectPriv(privileges map[mysql.PrivilegeType]struct{}, checkTables map[ continue } lackPriv[p] = make(map[string]map[string]struct{}, len(checkTables)) - for schema, tables := range checkTables { - if _, ok := lackPriv[p][schema]; !ok { - lackPriv[p][schema] = make(map[string]struct{}, len(tables)) - } - for _, table := range tables { - lackPriv[p][schema][table] = struct{}{} + for _, table := range checkTables { + if _, ok := lackPriv[p][table.Schema]; !ok { + lackPriv[p][table.Schema] = make(map[string]struct{}) } + lackPriv[p][table.Schema][table.Name] = struct{}{} } if p == mysql.SelectPriv { if _, ok := lackPriv[p]["INFORMATION_SCHEMA"]; !ok { @@ -353,7 +351,7 @@ func genReplicPriv(replicationPrivileges map[mysql.PrivilegeType]struct{}) map[m return genExpectPriv(replicationPrivileges, nil) } -func genDumpPriv(dumpPrivileges map[mysql.PrivilegeType]struct{}, checkTables map[string][]string) map[mysql.PrivilegeType]map[string]map[string]struct{} { +func genDumpPriv(dumpPrivileges map[mysql.PrivilegeType]struct{}, checkTables []*filter.Table) map[mysql.PrivilegeType]map[string]map[string]struct{} { // due to dump privilege checker need check db/table level privilege // so we need know the check tables return genExpectPriv(dumpPrivileges, checkTables) diff --git a/dm/pkg/checker/privilege_test.go b/dm/pkg/checker/privilege_test.go index e38da710c93..927a7c4e231 100644 --- a/dm/pkg/checker/privilege_test.go +++ b/dm/pkg/checker/privilege_test.go @@ -17,6 +17,7 @@ import ( "testing" tc "github.com/pingcap/check" + "github.com/pingcap/tidb-tools/pkg/filter" "github.com/pingcap/tidb/parser/mysql" ) @@ -31,7 +32,7 @@ type testCheckSuite struct{} func (t *testCheckSuite) TestVerifyDumpPrivileges(c *tc.C) { cases := []struct { grants []string - checkTables map[string][]string + checkTables []*filter.Table dumpState State errMatch string }{ @@ -58,8 +59,8 @@ func (t *testCheckSuite) TestVerifyDumpPrivileges(c *tc.C) { "GRANT EXECUTE ON FUNCTION db1.anomaly_score TO 'user1'@'domain-or-ip-address1'", }, dumpState: StateFailure, - checkTables: map[string][]string{ - "db1": {"anomaly_score"}, + checkTables: []*filter.Table{ + {Schema: "db1", Name: "anomaly_score"}, }, // `db1`.`anomaly_score`; `INFORMATION_SCHEMA` // can't guarantee the order @@ -126,8 +127,8 @@ func (t *testCheckSuite) TestVerifyDumpPrivileges(c *tc.C) { "GRANT ALL PRIVILEGES ON `medz`.* TO `zhangsan`@`10.8.1.9` WITH GRANT OPTION", }, dumpState: StateFailure, - checkTables: map[string][]string{ - "medz": {"medz"}, + checkTables: []*filter.Table{ + {Schema: "medz", Name: "medz"}, }, errMatch: "lack of RELOAD privilege; ", }, @@ -137,8 +138,8 @@ func (t *testCheckSuite) TestVerifyDumpPrivileges(c *tc.C) { "GRANT ALL PRIVILEGES ON `INFORMATION_SCHEMA`.* TO `zhangsan`@`10.8.1.9` WITH GRANT OPTION", }, dumpState: StateFailure, - checkTables: map[string][]string{ - "medz": {"medz"}, + checkTables: []*filter.Table{ + {Schema: "medz", Name: "medz"}, }, errMatch: "lack of RELOAD privilege; ", }, @@ -149,8 +150,8 @@ func (t *testCheckSuite) TestVerifyDumpPrivileges(c *tc.C) { "GRANT SELECT ON `INFORMATION_SCHEMA`.* TO 'user'@'%'", }, dumpState: StateFailure, - checkTables: map[string][]string{ - "lance": {"t"}, + checkTables: []*filter.Table{ + {Schema: "lance", Name: "t"}, }, errMatch: "lack of Select privilege: {`lance`.`t`}; ", }, @@ -162,8 +163,8 @@ func (t *testCheckSuite) TestVerifyDumpPrivileges(c *tc.C) { "GRANT `r1`@`%`,`r2`@`%` TO `u1`@`localhost`", }, dumpState: StateSuccess, - checkTables: map[string][]string{ - "db1": {"t"}, + checkTables: []*filter.Table{ + {Schema: "db1", Name: "t"}, }, }, { @@ -195,7 +196,7 @@ func (t *testCheckSuite) TestVerifyDumpPrivileges(c *tc.C) { func (t *testCheckSuite) TestVerifyReplicationPrivileges(c *tc.C) { cases := []struct { grants []string - checkTables map[string][]string + checkTables []*filter.Table replicationState State errMatch string }{ diff --git a/dm/pkg/checker/table_structure.go b/dm/pkg/checker/table_structure.go index 48b7044078a..e6eb1363ff2 100644 --- a/dm/pkg/checker/table_structure.go +++ b/dm/pkg/checker/table_structure.go @@ -18,23 +18,41 @@ import ( "context" "database/sql" "fmt" + "math" "strings" + "sync" + "time" + + "go.uber.org/zap" + "golang.org/x/sync/errgroup" "github.com/pingcap/errors" column "github.com/pingcap/tidb-tools/pkg/column-mapping" "github.com/pingcap/tidb-tools/pkg/dbutil" + "github.com/pingcap/tidb-tools/pkg/filter" + "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/charset" - "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" + + "github.com/pingcap/tiflow/dm/pkg/log" + "github.com/pingcap/tiflow/dm/pkg/utils" ) -// AutoIncrementKeyChecking is an identification for auto increment key checking. -const AutoIncrementKeyChecking = "auto-increment key checking" +const ( + // AutoIncrementKeyChecking is an identification for auto increment key checking. + AutoIncrementKeyChecking = "auto-increment key checking" +) + +type checkItem struct { + table *filter.Table + sourceID string +} // hold information of incompatibility option. type incompatibilityOption struct { state State + tableID string instruction string errMessage string } @@ -58,19 +76,28 @@ func (o *incompatibilityOption) String() string { // In generally we need to check definitions of columns, constraints and table options. // Because of the early TiDB engineering design, we did not have a complete list of check items, which are all based on experience now. type TablesChecker struct { - db *sql.DB - dbinfo *dbutil.DBConfig - tables map[string][]string // schema => []table; if []table is empty, query tables from db - + dbs map[string]*sql.DB + tableMap map[string][]*filter.Table // sourceID => {[table1, table2, ...]} + reMu sync.Mutex + inCh chan *checkItem + optCh chan *incompatibilityOption + wg sync.WaitGroup + dumpThreads int } // NewTablesChecker returns a RealChecker. -func NewTablesChecker(db *sql.DB, dbinfo *dbutil.DBConfig, tables map[string][]string) RealChecker { - return &TablesChecker{ - db: db, - dbinfo: dbinfo, - tables: tables, +func NewTablesChecker(dbs map[string]*sql.DB, tableMap map[string][]*filter.Table, dumpThreads int) RealChecker { + if dumpThreads == 0 { + dumpThreads = 1 + } + c := &TablesChecker{ + dbs: dbs, + tableMap: tableMap, + dumpThreads: dumpThreads, } + c.inCh = make(chan *checkItem, dumpThreads) + c.optCh = make(chan *incompatibilityOption, dumpThreads) + return c } // Check implements RealChecker interface. @@ -79,111 +106,119 @@ func (c *TablesChecker) Check(ctx context.Context) *Result { Name: c.Name(), Desc: "check compatibility of table structure", State: StateSuccess, - Extra: fmt.Sprintf("address of db instance - %s:%d", c.dbinfo.Host, c.dbinfo.Port), } - var ( - err error - options = make(map[string][]*incompatibilityOption) - statements = make(map[string]string) - ) - for schema, tables := range c.tables { - if len(tables) == 0 { - tables, err = dbutil.GetTables(ctx, c.db, schema) - if err != nil { - markCheckError(r, err) - return r - } - } - - for _, table := range tables { - tableName := dbutil.TableName(schema, table) - statement, err := dbutil.GetCreateTableSQL(ctx, c.db, schema, table) - if err != nil { - // continue if table was deleted when checking - if isMySQLError(err, mysql.ErrNoSuchTable) { - continue - } - markCheckError(r, err) - return r - } + startTime := time.Now() + concurrency, err := getConcurrency(ctx, c.tableMap, c.dbs, c.dumpThreads) + if err != nil { + markCheckError(r, err) + return r + } + eg, checkCtx := errgroup.WithContext(ctx) + for i := 0; i < concurrency; i++ { + eg.Go(func() error { + return c.checkTable(checkCtx) + }) + } - opts := c.checkCreateSQL(ctx, statement) - if len(opts) > 0 { - options[tableName] = opts - statements[tableName] = statement - } - } + dispatchTableItem(checkCtx, c.tableMap, c.inCh) + c.wg.Add(1) + go c.handleOpts(ctx, r) + if err := eg.Wait(); err != nil { + c.reMu.Lock() + markCheckError(r, err) + c.reMu.Unlock() } + close(c.optCh) + c.wg.Wait() - for name, opts := range options { - if len(opts) == 0 { - continue - } - tableMsg := "table " + name + " " + log.L().Logger.Info("check table structure over", zap.Duration("spend time", time.Since(startTime))) + return r +} - for _, option := range opts { - switch option.state { +// Name implements RealChecker interface. +func (c *TablesChecker) Name() string { + return "table structure compatibility check" +} + +func (c *TablesChecker) handleOpts(ctx context.Context, r *Result) { + defer c.wg.Done() + for { + select { + case <-ctx.Done(): + return + case opt, ok := <-c.optCh: + if !ok { + return + } + tableMsg := "table " + opt.tableID + " " + c.reMu.Lock() + switch opt.state { case StateWarning: if r.State != StateFailure { r.State = StateWarning } - e := NewError(tableMsg + option.errMessage) + e := NewError(tableMsg + opt.errMessage) e.Severity = StateWarning - e.Instruction = option.instruction + e.Instruction = opt.instruction r.Errors = append(r.Errors, e) case StateFailure: r.State = StateFailure - e := NewError(tableMsg + option.errMessage) - e.Instruction = option.instruction + e := NewError(tableMsg + opt.errMessage) + e.Instruction = opt.instruction r.Errors = append(r.Errors, e) } + c.reMu.Unlock() } } - - return r } -// Name implements RealChecker interface. -func (c *TablesChecker) Name() string { - return "table structure compatibility check" -} - -func (c *TablesChecker) checkCreateSQL(ctx context.Context, statement string) []*incompatibilityOption { - parser2, err := dbutil.GetParserForDB(ctx, c.db) - if err != nil { - return []*incompatibilityOption{ - { - state: StateFailure, - errMessage: err.Error(), - }, - } - } +func (c *TablesChecker) checkTable(ctx context.Context) error { + var ( + sourceID string + p *parser.Parser + err error + ) + for { + select { + case <-ctx.Done(): + return context.Canceled + case checkItem, ok := <-c.inCh: + if !ok { + return nil + } + table := checkItem.table + if len(sourceID) == 0 || sourceID != checkItem.sourceID { + sourceID = checkItem.sourceID + p, err = dbutil.GetParserForDB(ctx, c.dbs[sourceID]) + if err != nil { + return err + } + } + db := c.dbs[checkItem.sourceID] + statement, err := dbutil.GetCreateTableSQL(ctx, db, table.Schema, table.Name) + if err != nil { + // continue if table was deleted when checking + if isMySQLError(err, mysql.ErrNoSuchTable) { + continue + } + return err + } - stmt, err := parser2.ParseOneStmt(statement, "", "") - if err != nil { - return []*incompatibilityOption{ - { - state: StateFailure, - errMessage: err.Error(), - }, + ctStmt, err := getCreateTableStmt(p, statement) + if err != nil { + return err + } + opts := c.checkAST(ctStmt) + for _, opt := range opts { + opt.tableID = table.String() + c.optCh <- opt + } } } - // Analyze ast - return c.checkAST(stmt) } -func (c *TablesChecker) checkAST(stmt ast.StmtNode) []*incompatibilityOption { - st, ok := stmt.(*ast.CreateTableStmt) - if !ok { - return []*incompatibilityOption{ - { - state: StateFailure, - errMessage: fmt.Sprintf("Expect CreateTableStmt but got %T", stmt), - }, - } - } - +func (c *TablesChecker) checkAST(st *ast.CreateTableStmt) []*incompatibilityOption { var options []*incompatibilityOption // check columns @@ -271,23 +306,35 @@ func (c *TablesChecker) checkTableOption(opt *ast.TableOption) *incompatibilityO // * check whether they have same column list // * check whether they have auto_increment key. type ShardingTablesChecker struct { - name string - + targetTableID string dbs map[string]*sql.DB - tables map[string]map[string][]string // instance => {schema: [table1, table2, ...]} + tableMap map[string][]*filter.Table // sourceID => {[table1, table2, ...]} mapping map[string]*column.Mapping checkAutoIncrementPrimaryKey bool + firstCreateTableStmtNode *ast.CreateTableStmt + firstTable *filter.Table + firstSourceID string + inCh chan *checkItem + reMu sync.Mutex + dumpThreads int } // NewShardingTablesChecker returns a RealChecker. -func NewShardingTablesChecker(name string, dbs map[string]*sql.DB, tables map[string]map[string][]string, mapping map[string]*column.Mapping, checkAutoIncrementPrimaryKey bool) RealChecker { - return &ShardingTablesChecker{ - name: name, +func NewShardingTablesChecker(targetTableID string, dbs map[string]*sql.DB, tableMap map[string][]*filter.Table, mapping map[string]*column.Mapping, checkAutoIncrementPrimaryKey bool, dumpThreads int) RealChecker { + if dumpThreads == 0 { + dumpThreads = 1 + } + c := &ShardingTablesChecker{ + targetTableID: targetTableID, dbs: dbs, - tables: tables, + tableMap: tableMap, mapping: mapping, checkAutoIncrementPrimaryKey: checkAutoIncrementPrimaryKey, + dumpThreads: dumpThreads, } + c.inCh = make(chan *checkItem, dumpThreads) + + return c } // Check implements RealChecker interface. @@ -296,168 +343,136 @@ func (c *ShardingTablesChecker) Check(ctx context.Context) *Result { Name: c.Name(), Desc: "check consistency of sharding table structures", State: StateSuccess, - Extra: fmt.Sprintf("sharding %s", c.name), + Extra: fmt.Sprintf("sharding %s,", c.targetTableID), } - var ( - stmtNode *ast.CreateTableStmt - firstTable string - firstInstance string - ) - - for instance, schemas := range c.tables { - db, ok := c.dbs[instance] - if !ok { - markCheckError(r, errors.NotFoundf("client for instance %s", instance)) - return r - } - - parser2, err := dbutil.GetParserForDB(ctx, db) - if err != nil { - markCheckError(r, err) - r.Extra = fmt.Sprintf("fail to get parser for instance %s on sharding %s", instance, c.name) - return r - } - - for schema, tables := range schemas { - for _, table := range tables { - statement, err := dbutil.GetCreateTableSQL(ctx, db, schema, table) - if err != nil { - // continue if table was deleted when checking - if isMySQLError(err, mysql.ErrNoSuchTable) { - continue - } - markCheckError(r, err) - r.Extra = fmt.Sprintf("instance %s on sharding %s", instance, c.name) - return r - } + startTime := time.Now() + log.L().Logger.Info("start to check sharding tables") - info, err := dbutil.GetTableInfoBySQL(statement, parser2) - if err != nil { - markCheckError(r, err) - r.Extra = fmt.Sprintf("instance %s on sharding %s", instance, c.name) - return r - } - stmt, err := parser2.ParseOneStmt(statement, "", "") - if err != nil { - markCheckError(r, errors.Annotatef(err, "statement %s", statement)) - r.Extra = fmt.Sprintf("instance %s on sharding %s", instance, c.name) - return r - } + for sourceID, tables := range c.tableMap { + c.firstSourceID = sourceID + c.firstTable = tables[0] + break + } + db, ok := c.dbs[c.firstSourceID] + if !ok { + markCheckError(r, errors.NotFoundf("client for sourceID %s", c.firstSourceID)) + return r + } - ctStmt, ok := stmt.(*ast.CreateTableStmt) - if !ok { - markCheckError(r, errors.Errorf("Expect CreateTableStmt but got %T", stmt)) - r.Extra = fmt.Sprintf("instance %s on sharding %s", instance, c.name) - return r - } + p, err := dbutil.GetParserForDB(ctx, db) + if err != nil { + r.Extra = fmt.Sprintf("fail to get parser for sourceID %s on sharding %s", c.firstSourceID, c.targetTableID) + markCheckError(r, err) + return r + } + r.Extra = fmt.Sprintf("sourceID %s on sharding %s", c.firstSourceID, c.targetTableID) + statement, err := dbutil.GetCreateTableSQL(ctx, db, c.firstTable.Schema, c.firstTable.Name) + if err != nil { + markCheckError(r, err) + return r + } - if c.checkAutoIncrementPrimaryKey { - passed := c.checkAutoIncrementKey(instance, schema, table, ctStmt, info, r) - if !passed { - return r - } - } + c.firstCreateTableStmtNode, err = getCreateTableStmt(p, statement) + if err != nil { + markCheckError(r, err) + return r + } - if stmtNode == nil { - stmtNode = ctStmt - firstTable = dbutil.TableName(schema, table) - firstInstance = instance - continue - } + concurrency, err := getConcurrency(ctx, c.tableMap, c.dbs, c.dumpThreads) + if err != nil { + markCheckError(r, err) + return r + } + eg, checkCtx := errgroup.WithContext(ctx) + for i := 0; i < concurrency; i++ { + eg.Go(func() error { + return c.checkShardingTable(checkCtx, r) + }) + } - checkErr := c.checkConsistency(stmtNode, ctStmt, firstTable, dbutil.TableName(schema, table), firstInstance, instance) - if checkErr != nil { - r.State = StateFailure - r.Errors = append(r.Errors, checkErr) - r.Extra = fmt.Sprintf("error on sharding %s", c.name) - r.Instruction = "please set same table structure for sharding tables" - return r - } - } - } + dispatchTableItem(checkCtx, c.tableMap, c.inCh) + if err := eg.Wait(); err != nil { + markCheckError(r, err) } + log.L().Logger.Info("check sharding table structure over", zap.Duration("spend time", time.Since(startTime))) return r } -func (c *ShardingTablesChecker) checkAutoIncrementKey(instance, schema, table string, ctStmt *ast.CreateTableStmt, info *model.TableInfo, r *Result) bool { - autoIncrementKeys := c.findAutoIncrementKey(ctStmt, info) - for columnName, isBigInt := range autoIncrementKeys { - hasMatchedRule := false - if cm, ok1 := c.mapping[instance]; ok1 { - ruleSet := cm.Selector.Match(schema, table) - for _, rule := range ruleSet { - r, ok2 := rule.(*column.Rule) - if !ok2 { +func (c *ShardingTablesChecker) checkShardingTable(ctx context.Context, r *Result) error { + var ( + sourceID string + p *parser.Parser + err error + ) + for { + select { + case <-ctx.Done(): + return nil + case checkItem, ok := <-c.inCh: + if !ok { + return nil + } + table := checkItem.table + if len(sourceID) == 0 || sourceID != checkItem.sourceID { + sourceID = checkItem.sourceID + p, err = dbutil.GetParserForDB(ctx, c.dbs[sourceID]) + if err != nil { + c.reMu.Lock() + r.Extra = fmt.Sprintf("fail to get parser for sourceID %s on sharding %s", sourceID, c.targetTableID) + c.reMu.Unlock() + return err + } + } + + statement, err := dbutil.GetCreateTableSQL(ctx, c.dbs[sourceID], table.Schema, table.Name) + if err != nil { + // continue if table was deleted when checking + if isMySQLError(err, mysql.ErrNoSuchTable) { continue } + return err + } - if r.Expression == column.PartitionID && r.TargetColumn == columnName { - hasMatchedRule = true - break + ctStmt, err := getCreateTableStmt(p, statement) + if err != nil { + return err + } + + if has := c.hasAutoIncrementKey(ctStmt); has { + c.reMu.Lock() + if r.State == StateSuccess { + r.State = StateWarning + r.Errors = append(r.Errors, NewError("sourceID %s table %v of sharding %s have auto-increment key, please make sure them don't conflict in target table!", sourceID, table, c.targetTableID)) + r.Instruction = "If happen conflict, please handle it by yourself. You can refer to https://docs.pingcap.com/tidb-data-migration/stable/shard-merge-best-practices/#handle-conflicts-between-primary-keys-or-unique-indexes-across-multiple-sharded-tables" + r.Extra = AutoIncrementKeyChecking } + c.reMu.Unlock() } - if hasMatchedRule && !isBigInt { + if checkErr := c.checkConsistency(ctStmt, table.String(), sourceID); checkErr != nil { + c.reMu.Lock() r.State = StateFailure - r.Errors = append(r.Errors, NewError("instance %s table `%s`.`%s` of sharding %s have auto-increment key %s and column mapping, but type of %s should be bigint", instance, schema, table, c.name, columnName, columnName)) - r.Instruction = "please set auto-increment key type to bigint" - r.Extra = AutoIncrementKeyChecking - return false + r.Errors = append(r.Errors, checkErr) + r.Extra = fmt.Sprintf("error on sharding %s", c.targetTableID) + r.Instruction = "please set same table structure for sharding tables" + c.reMu.Unlock() + return nil } } - - if !hasMatchedRule { - r.State = StateFailure - r.Errors = append(r.Errors, NewError("instance %s table `%s`.`%s` of sharding %s have auto-increment key %s and column mapping, but type of %s should be bigint", instance, schema, table, c.name, columnName, columnName)) - r.Instruction = "please handle it by yourself" - r.Extra = AutoIncrementKeyChecking - return false - } } - - return true } -func (c *ShardingTablesChecker) findAutoIncrementKey(stmt *ast.CreateTableStmt, info *model.TableInfo) map[string]bool { - autoIncrementKeys := make(map[string]bool) - autoIncrementCols := make(map[string]bool) - +func (c *ShardingTablesChecker) hasAutoIncrementKey(stmt *ast.CreateTableStmt) bool { for _, col := range stmt.Cols { - var ( - hasAutoIncrementOpt bool - isUnique bool - ) for _, opt := range col.Options { - switch opt.Tp { - case ast.ColumnOptionAutoIncrement: - hasAutoIncrementOpt = true - case ast.ColumnOptionPrimaryKey, ast.ColumnOptionUniqKey: - isUnique = true - } - } - - if hasAutoIncrementOpt { - if isUnique { - autoIncrementKeys[col.Name.Name.O] = col.Tp.Tp == mysql.TypeLonglong - } else { - autoIncrementCols[col.Name.Name.O] = col.Tp.Tp == mysql.TypeLonglong - } - } - } - - for _, index := range info.Indices { - if index.Unique || index.Primary { - if len(index.Columns) == 1 { - if isBigInt, ok := autoIncrementCols[index.Columns[0].Name.O]; ok { - autoIncrementKeys[index.Columns[0].Name.O] = isBigInt - } + if opt.Tp == ast.ColumnOptionAutoIncrement { + return true } } } - - return autoIncrementKeys + return false } type briefColumnInfo struct { @@ -490,8 +505,8 @@ func (cs briefColumnInfos) String() string { return strings.Join(colStrs, "\n") } -func (c *ShardingTablesChecker) checkConsistency(self, other *ast.CreateTableStmt, selfTable, otherTable, selfInstance, otherInstance string) *Error { - selfColumnList := getBriefColumnList(self) +func (c *ShardingTablesChecker) checkConsistency(other *ast.CreateTableStmt, otherTable, othersourceID string) *Error { + selfColumnList := getBriefColumnList(c.firstCreateTableStmtNode) otherColumnList := getBriefColumnList(other) if len(selfColumnList) != len(otherColumnList) { @@ -503,16 +518,16 @@ func (c *ShardingTablesChecker) checkConsistency(self, other *ast.CreateTableStm } return ret } - e.Self = fmt.Sprintf("instance %s table %s columns %v", selfInstance, selfTable, getColumnNames(selfColumnList)) - e.Other = fmt.Sprintf("instance %s table %s columns %v", otherInstance, otherTable, getColumnNames(otherColumnList)) + e.Self = fmt.Sprintf("sourceID %s table %v columns %v", c.firstSourceID, c.firstTable, getColumnNames(selfColumnList)) + e.Other = fmt.Sprintf("sourceID %s table %s columns %v", othersourceID, otherTable, getColumnNames(otherColumnList)) return e } for i := range selfColumnList { if *selfColumnList[i] != *otherColumnList[i] { e := NewError("different column definition") - e.Self = fmt.Sprintf("instance %s table %s column %s", selfInstance, selfTable, selfColumnList[i]) - e.Other = fmt.Sprintf("instance %s table %s column %s", otherInstance, otherTable, otherColumnList[i]) + e.Self = fmt.Sprintf("sourceID %s table %s column %s", c.firstSourceID, c.firstTable, selfColumnList[i]) + e.Other = fmt.Sprintf("sourceID %s table %s column %s", othersourceID, otherTable, otherColumnList[i]) return e } } @@ -546,5 +561,35 @@ func getBriefColumnList(stmt *ast.CreateTableStmt) briefColumnInfos { // Name implements Checker interface. func (c *ShardingTablesChecker) Name() string { - return fmt.Sprintf("sharding table %s consistency checking", c.name) + return fmt.Sprintf("sharding table %s consistency checking", c.targetTableID) +} + +func dispatchTableItem(ctx context.Context, tableMap map[string][]*filter.Table, inCh chan *checkItem) { + for sourceID, tables := range tableMap { + for _, table := range tables { + select { + case <-ctx.Done(): + log.L().Logger.Warn("ctx canceled before input tables completely") + return + case inCh <- &checkItem{table, sourceID}: + } + } + } + close(inCh) +} + +func getConcurrency(ctx context.Context, tableMap map[string][]*filter.Table, dbs map[string]*sql.DB, dumpThreads int) (int, error) { + concurrency := dumpThreads + for sourceID := range tableMap { + db, ok := dbs[sourceID] + if !ok { + return 0, errors.NotFoundf("client for sourceID %s", sourceID) + } + maxConnections, err := utils.GetMaxConnections(ctx, db) + if err != nil { + return 0, err + } + concurrency = int(math.Min(float64(concurrency), float64((maxConnections+1)/2))) + } + return concurrency, nil } diff --git a/dm/pkg/checker/table_structure_test.go b/dm/pkg/checker/table_structure_test.go index cdd48cad524..bb67ab1a6a8 100644 --- a/dm/pkg/checker/table_structure_test.go +++ b/dm/pkg/checker/table_structure_test.go @@ -16,12 +16,10 @@ package checker import ( "context" "database/sql" - "encoding/json" - "fmt" "github.com/DATA-DOG/go-sqlmock" tc "github.com/pingcap/check" - "github.com/pingcap/tidb-tools/pkg/dbutil" + "github.com/pingcap/tidb-tools/pkg/filter" ) func (t *testCheckSuite) TestShardingTablesChecker(c *tc.C) { @@ -29,23 +27,8 @@ func (t *testCheckSuite) TestShardingTablesChecker(c *tc.C) { c.Assert(err, tc.IsNil) ctx := context.Background() - printJSON := func(r *Result) { - rawResult, _ := json.MarshalIndent(r, "", "\t") - fmt.Println("\n" + string(rawResult)) - } - // 1. test a success check - - sqlModeRow := sqlmock.NewRows([]string{"Variable_name", "Value"}). - AddRow("sql_mode", "ANSI_QUOTES") - mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) - createTableRow := sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test-table-1", `CREATE TABLE "test-table-1" ( - "c" int(11) NOT NULL, - PRIMARY KEY ("c") -) ENGINE=InnoDB DEFAULT CHARSET=latin1`) - mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) - + mock = initShardingMock(mock) createTableRow2 := sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("test-table-2", `CREATE TABLE "test-table-2" ( "c" int(11) NOT NULL, @@ -55,26 +38,28 @@ func (t *testCheckSuite) TestShardingTablesChecker(c *tc.C) { checker := NewShardingTablesChecker("test-name", map[string]*sql.DB{"test-source": db}, - map[string]map[string][]string{"test-source": {"test-db": []string{"test-table-1", "test-table-2"}}}, + map[string][]*filter.Table{"test-source": { + {Schema: "test-db", Name: "test-table-1"}, + {Schema: "test-db", Name: "test-table-2"}, + }}, nil, - false) + false, + 1) result := checker.Check(ctx) - c.Assert(result.State, tc.Equals, StateSuccess) c.Assert(mock.ExpectationsWereMet(), tc.IsNil) - printJSON(result) // 2. check different column number - - sqlModeRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). - AddRow("sql_mode", "ANSI_QUOTES") - mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) - createTableRow = sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test-table-1", `CREATE TABLE "test-table-1" ( - "c" int(11) NOT NULL, - PRIMARY KEY ("c") -) ENGINE=InnoDB DEFAULT CHARSET=latin1`) - mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) + checker = NewShardingTablesChecker("test-name", + map[string]*sql.DB{"test-source": db}, + map[string][]*filter.Table{"test-source": { + {Schema: "test-db", Name: "test-table-1"}, + {Schema: "test-db", Name: "test-table-2"}, + }}, + nil, + false, + 1) + mock = initShardingMock(mock) createTableRow2 = sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("test-table-2", `CREATE TABLE "test-table-2" ( "c" int(11) NOT NULL, @@ -87,19 +72,18 @@ func (t *testCheckSuite) TestShardingTablesChecker(c *tc.C) { c.Assert(result.State, tc.Equals, StateFailure) c.Assert(result.Errors, tc.HasLen, 1) c.Assert(mock.ExpectationsWereMet(), tc.IsNil) - printJSON(result) // 3. check different column def - - sqlModeRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). - AddRow("sql_mode", "ANSI_QUOTES") - mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) - createTableRow = sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test-table-1", `CREATE TABLE "test-table-1" ( - "c" int(11) NOT NULL, - PRIMARY KEY ("c") -) ENGINE=InnoDB DEFAULT CHARSET=latin1`) - mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) + checker = NewShardingTablesChecker("test-name", + map[string]*sql.DB{"test-source": db}, + map[string][]*filter.Table{"test-source": { + {Schema: "test-db", Name: "test-table-1"}, + {Schema: "test-db", Name: "test-table-2"}, + }}, + nil, + false, + 1) + mock = initShardingMock(mock) createTableRow2 = sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("test-table-2", `CREATE TABLE "test-table-2" ( "c" varchar(20) NOT NULL, @@ -111,7 +95,6 @@ func (t *testCheckSuite) TestShardingTablesChecker(c *tc.C) { c.Assert(result.State, tc.Equals, StateFailure) c.Assert(result.Errors, tc.HasLen, 1) c.Assert(mock.ExpectationsWereMet(), tc.IsNil) - printJSON(result) } func (t *testCheckSuite) TestTablesChecker(c *tc.C) { @@ -119,67 +102,103 @@ func (t *testCheckSuite) TestTablesChecker(c *tc.C) { c.Assert(err, tc.IsNil) ctx := context.Background() - printJSON := func(r *Result) { - rawResult, _ := json.MarshalIndent(r, "", "\t") - fmt.Println("\n" + string(rawResult)) - } - // 1. test a success check - - createTableRow := sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test-table-1", `CREATE TABLE "test-table-1" ( - "c" int(11) NOT NULL, - PRIMARY KEY ("c") -) ENGINE=InnoDB DEFAULT CHARSET=latin1`) - mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) + maxConnectionsRow := sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("max_connections", "2") + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(maxConnectionsRow) sqlModeRow := sqlmock.NewRows([]string{"Variable_name", "Value"}). AddRow("sql_mode", "ANSI_QUOTES") mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) + createTableRow := sqlmock.NewRows([]string{"Table", "Create Table"}). + AddRow("test-table-1", `CREATE TABLE "test-table-1" ( + "c" int(11) NOT NULL, + PRIMARY KEY ("c") + ) ENGINE=InnoDB DEFAULT CHARSET=latin1`) + mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) - checker := NewTablesChecker(db, - &dbutil.DBConfig{}, - map[string][]string{"test-db": {"test-table-1"}}) + checker := NewTablesChecker( + map[string]*sql.DB{"test-source": db}, + map[string][]*filter.Table{"test-source": { + {Schema: "test-db", Name: "test-table-1"}, + }}, + 1) result := checker.Check(ctx) - c.Assert(result.State, tc.Equals, StateSuccess) c.Assert(mock.ExpectationsWereMet(), tc.IsNil) - printJSON(result) // 2. check many errors - + maxConnectionsRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("max_connections", "2") + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(maxConnectionsRow) + sqlModeRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("sql_mode", "ANSI_QUOTES") + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) createTableRow = sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("test-table-1", `CREATE TABLE "test-table-1" ( "c" int(11) NOT NULL, CONSTRAINT "fk" FOREIGN KEY ("c") REFERENCES "t" ("c") ) ENGINE=InnoDB DEFAULT CHARSET=latin1`) mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) - sqlModeRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). - AddRow("sql_mode", "ANSI_QUOTES") - mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) + checker = NewTablesChecker( + map[string]*sql.DB{"test-source": db}, + map[string][]*filter.Table{"test-source": { + {Schema: "test-db", Name: "test-table-1"}, + }}, + 1) result = checker.Check(ctx) - c.Assert(result.State, tc.Equals, StateFailure) c.Assert(result.Errors, tc.HasLen, 2) // no PK/UK + has FK c.Assert(mock.ExpectationsWereMet(), tc.IsNil) - printJSON(result) // 3. unsupported charset - + maxConnectionsRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("max_connections", "2") + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(maxConnectionsRow) + sqlModeRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("sql_mode", "ANSI_QUOTES") + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) createTableRow = sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("test-table-1", `CREATE TABLE "test-table-1" ( "c" int(11) NOT NULL, PRIMARY KEY ("c") ) ENGINE=InnoDB DEFAULT CHARSET=ucs2`) mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) - sqlModeRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). - AddRow("sql_mode", "ANSI_QUOTES") - mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) + checker = NewTablesChecker( + map[string]*sql.DB{"test-source": db}, + map[string][]*filter.Table{"test-source": { + {Schema: "test-db", Name: "test-table-1"}, + }}, + 1) result = checker.Check(ctx) - c.Assert(result.State, tc.Equals, StateFailure) c.Assert(result.Errors, tc.HasLen, 1) c.Assert(mock.ExpectationsWereMet(), tc.IsNil) - printJSON(result) +} + +func initShardingMock(mock sqlmock.Sqlmock) sqlmock.Sqlmock { + sqlModeRow := sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("sql_mode", "ANSI_QUOTES") + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) + createTableRow := sqlmock.NewRows([]string{"Table", "Create Table"}). + AddRow("test-table-1", `CREATE TABLE "test-table-1" ( +"c" int(11) NOT NULL, +PRIMARY KEY ("c") +) ENGINE=InnoDB DEFAULT CHARSET=latin1`) + mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) + + maxConnecionsRow := sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("max_connections", "2") + mock.ExpectQuery("SHOW VARIABLES LIKE 'max_connections'").WillReturnRows(maxConnecionsRow) + sqlModeRow = sqlmock.NewRows([]string{"Variable_name", "Value"}). + AddRow("sql_mode", "ANSI_QUOTES") + mock.ExpectQuery("SHOW VARIABLES LIKE 'sql_mode'").WillReturnRows(sqlModeRow) + createTableRow = sqlmock.NewRows([]string{"Table", "Create Table"}). + AddRow("test-table-1", `CREATE TABLE "test-table-1" ( +"c" int(11) NOT NULL, +PRIMARY KEY ("c") +) ENGINE=InnoDB DEFAULT CHARSET=latin1`) + mock.ExpectQuery("SHOW CREATE TABLE `test-db`.`test-table-1`").WillReturnRows(createTableRow) + return mock } diff --git a/dm/pkg/checker/utils.go b/dm/pkg/checker/utils.go index 7d53500feee..689e5627d63 100644 --- a/dm/pkg/checker/utils.go +++ b/dm/pkg/checker/utils.go @@ -23,6 +23,8 @@ import ( "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" "github.com/pingcap/tidb-tools/pkg/utils" + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" ) // MySQLVersion represents MySQL version number. @@ -136,7 +138,10 @@ func markCheckError(result *Result, err error) { } else { state = StateFailure } - result.State = state + // `StateWarning` can't cover `StateFailure`. + if result.State != StateFailure { + result.State = state + } result.Errors = append(result.Errors, &Error{Severity: state, ShortErr: err.Error()}) } } @@ -146,3 +151,16 @@ func isMySQLError(err error, code uint16) bool { e, ok := err.(*mysql.MySQLError) return ok && e.Number == code } + +func getCreateTableStmt(p *parser.Parser, statement string) (*ast.CreateTableStmt, error) { + stmt, err := p.ParseOneStmt(statement, "", "") + if err != nil { + return nil, errors.Annotatef(err, "statement %s", statement) + } + + ctStmt, ok := stmt.(*ast.CreateTableStmt) + if !ok { + return nil, errors.Errorf("Expect CreateTableStmt but got %T", stmt) + } + return ctStmt, nil +} diff --git a/dm/pkg/utils/db.go b/dm/pkg/utils/db.go index f8761c201f2..89bbbe1bc31 100644 --- a/dm/pkg/utils/db.go +++ b/dm/pkg/utils/db.go @@ -634,3 +634,23 @@ func GetTableCreateSQL(ctx context.Context, conn *sql.Conn, tableID string) (sql } return createStr, nil } + +// GetMaxConnections gets max_connections for sql.DB which is suitable for session variable max_connections. +func GetMaxConnections(ctx context.Context, db *sql.DB) (int, error) { + c, err := db.Conn(ctx) + if err != nil { + return 0, err + } + defer c.Close() + return GetMaxConnectionsForConn(ctx, c) +} + +// GetMaxConnectionsForConn gets max_connections for sql.Conn which is suitable for session variable max_connections. +func GetMaxConnectionsForConn(ctx context.Context, conn *sql.Conn) (int, error) { + maxConnectionsStr, err := GetSessionVariable(ctx, conn, "max_connections") + if err != nil { + return 0, err + } + maxConnections, err := strconv.ParseUint(maxConnectionsStr, 10, 32) + return int(maxConnections), err +} diff --git a/dm/pkg/utils/db_test.go b/dm/pkg/utils/db_test.go index 3687f1eaf82..38d345ac044 100644 --- a/dm/pkg/utils/db_test.go +++ b/dm/pkg/utils/db_test.go @@ -458,3 +458,18 @@ func (t *testDBSuite) TestAddGSetWithPurged(c *C) { c.Assert(originSet, DeepEquals, tc.originGSet) } } + +func (t *testDBSuite) TestGetMaxConnections(c *C) { + ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout) + defer cancel() + + db, mock, err := sqlmock.New() + c.Assert(err, IsNil) + + rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "151") + mock.ExpectQuery(`SHOW VARIABLES LIKE 'max_connections'`).WillReturnRows(rows) + maxConnections, err := GetMaxConnections(ctx, db) + c.Assert(err, IsNil) + c.Assert(maxConnections, Equals, 151) + c.Assert(mock.ExpectationsWereMet(), IsNil) +}