diff --git a/sql/mysql/driver.go b/sql/mysql/driver.go index 140351cf8a8..fd0452d1cbf 100644 --- a/sql/mysql/driver.go +++ b/sql/mysql/driver.go @@ -142,13 +142,17 @@ func (d *conn) gteV(w string) bool { return d.compareV(w) >= 0 } // ltV reports if the connection version is < w. func (d *conn) ltV(w string) bool { return d.compareV(w) == -1 } -// MySQL standard unescape field function from its codebase: -// https://github.com/mysql/mysql-server/blob/8.0/sql/dd/impl/utils.cc +// unescape strings with backslashes returned +// for SQL expressions from information schema. func unescape(s string) string { var b strings.Builder - for i, c := range s { - if c != '\\' || i+1 < len(s) && s[i+1] != '\\' && s[i+1] != '=' && s[i+1] != ';' { - b.WriteRune(c) + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case c != '\\' || i == len(s)-1: + b.WriteByte(c) + case s[i+1] == '\'', s[i+1] == '\\': + b.WriteByte(s[i+1]) + i++ } } return b.String() diff --git a/sql/mysql/inspect.go b/sql/mysql/inspect.go index d3ff9558b3e..6b34ff7d218 100644 --- a/sql/mysql/inspect.go +++ b/sql/mysql/inspect.go @@ -335,12 +335,7 @@ func (i *inspect) addIndexes(s *schema.Schema, rows *sql.Rows) error { part := &schema.IndexPart{SeqNo: seqno, Desc: desc.Bool} switch { case sqlx.ValidString(expr): - part.X = &schema.RawExpr{X: expr.String} - // Functional indexes may need to be extracted from 'SHOW CREATE', - // because INFORMATION_SCHEMA returns them escaped and they cannot - // be inlined this way. - s := putShow(t) - s.indexes[idx] = append(s.indexes[idx], len(idx.Parts)) + part.X = &schema.RawExpr{X: unescape(expr.String)} case sqlx.ValidString(column): part.C, ok = t.Column(column.String) if !ok { @@ -407,6 +402,7 @@ func (i *inspect) checks(ctx context.Context, s *schema.Schema) error { Expr: unescape(clause.String), } if i.mariadb() { + check.Expr = clause.String // In MariaDB, JSON is an alias to LONGTEXT. For versions >= 10.4.3, the CHARSET and COLLATE set to utf8mb4 // and a CHECK constraint is automatically created for the column as well (i.e. JSON_VALID(``)). However, // we expect tools like Atlas and Ent to manually add this CHECK for older versions of MariaDB. @@ -425,10 +421,6 @@ func (i *inspect) checks(ctx context.Context, s *schema.Schema) error { check.Attrs = append(check.Attrs, &Enforced{V: false}) } t.Attrs = append(t.Attrs, check) - // CHECK constraints need to be extracted from 'SHOW CREATE', because - // INFORMATION_SCHEMA returns them escaped and they cannot be used on - // 'CREATE' this way. - putShow(t).checks = true } return rows.Err() } @@ -475,7 +467,6 @@ func (i *inspect) showCreate(ctx context.Context, s *schema.Schema) error { if err := i.setAutoInc(st, t); err != nil { return err } - // TODO(a8m): setChecks, setIndexExpr from CREATE statement. } return nil } @@ -797,10 +788,6 @@ type ( schema.Attr // AUTO_INCREMENT value to due missing value in information_schema. auto *AutoIncrement - // checks expressions formatted differently from exist in 'SHOW CREATE'. - checks bool - // indexes that contain expressions. - indexes map[*schema.Index][]int } ) @@ -810,7 +797,7 @@ func putShow(t *schema.Table) *showTable { return s } } - s := &showTable{indexes: make(map[*schema.Index][]int)} + s := &showTable{} t.Attrs = append(t.Attrs, s) return s } diff --git a/sql/mysql/inspect_test.go b/sql/mysql/inspect_test.go index 342bd79a06b..d3308f7b3bb 100644 --- a/sql/mysql/inspect_test.go +++ b/sql/mysql/inspect_test.go @@ -178,15 +178,8 @@ func TestDriver_InspectTable(t *testing.T) { | table | CONSTRAINT_NAME | CHECK_CLAUSE | ENFORCED | +--------+------------------+-------------------------------------------+------------+ | users | jsonc | json_valid(` + "`jsonc`" + `) | YES | +| users | users_chk_1 | longtext <> '\'\'""' | YES | +--------+------------------+-------------------------------------------+------------+ -`)) - m.ExpectQuery(sqltest.Escape("SHOW CREATE TABLE `public`.`users`")). - WillReturnRows(sqltest.Rows(` -+-------+---------------------------------------------------------------------------------------------------------------------------------------------+ -| Table | Create Table | -+-------+---------------------------------------------------------------------------------------------------------------------------------------------+ -| users | CREATE TABLE users (id bigint NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=55834574848 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin | -+-------+---------------------------------------------------------------------------------------------------------------------------------------------+ `)) }, expect: func(require *require.Assertions, t *schema.Table, err error) { @@ -198,6 +191,10 @@ func TestDriver_InspectTable(t *testing.T) { {Name: "longtext", Type: &schema.ColumnType{Raw: "longtext", Type: &schema.StringType{T: "longtext"}}}, {Name: "jsonc", Type: &schema.ColumnType{Raw: "json", Type: &schema.JSONType{T: "json"}}}, }, t.Columns) + require.EqualValues([]schema.Attr{ + &schema.Check{Name: "jsonc", Expr: "json_valid(`jsonc`)"}, + &schema.Check{Name: "users_chk_1", Expr: `longtext <> '\'\'""'`}, + }, t.Attrs) }, }, { @@ -605,7 +602,11 @@ func TestDriver_InspectTable(t *testing.T) { +-------------------+-------------------+-------------------------------------------+------------+ | TABLE_NAME | CONSTRAINT_NAME | CHECK_CLAUSE | ENFORCED | +-------------------+-------------------+-------------------------------------------+------------+ -| users | users_chk_1 | (` + "`c6`" + ` <>_latin1\'foo\\\'s\') | YES | +| users | users_chk_1 | (` + "`c6`" + ` <> _latin1\'foo\\\'s\') | YES | +| users | users_chk_2 | (c1 <> _latin1\'dev/atlas\') | YES | +| users | users_chk_3 | (c1 <> _latin1\'a\\\'b""\') | YES | +| users | users_chk_4 | (c1 <> in (_latin1\'usa\',_latin1\'uk\')) | YES | +| users | users_chk_5 | (c1 <> _latin1\'\\\\\\\\\\\'\\\'\') | YES | +-------------------+-------------------+-------------------------------------------+------------+ `)) m.ExpectQuery(sqltest.Escape("SHOW CREATE TABLE `public`.`users`")). @@ -626,7 +627,13 @@ func TestDriver_InspectTable(t *testing.T) { {Name: "c1", Type: &schema.ColumnType{Raw: "int", Type: &schema.IntegerType{T: "int"}}}, } require.EqualValues(columns, t.Columns) - require.EqualValues([]schema.Attr{&schema.Check{Name: "users_chk_1", Expr: "(`c6` <>_latin1\\'foo\\'s\\')"}, &CreateStmt{S: "CREATE TABLE users()"}}, t.Attrs) + require.EqualValues([]schema.Attr{ + &schema.Check{Name: "users_chk_1", Expr: "(`c6` <> _latin1'foo\\'s')"}, + &schema.Check{Name: "users_chk_2", Expr: "(c1 <> _latin1'dev/atlas')"}, + &schema.Check{Name: "users_chk_3", Expr: `(c1 <> _latin1'a\'b""')`}, + &schema.Check{Name: "users_chk_4", Expr: `(c1 <> in (_latin1'usa',_latin1'uk'))`}, + &schema.Check{Name: "users_chk_5", Expr: `(c1 <> _latin1'\\\\\'\'')`}, + }, t.Attrs) }, }, }