Skip to content

Commit

Permalink
Merge branch 'master' into distinguish_unique
Browse files Browse the repository at this point in the history
# Conflicts:
#	ddlmod.go
#	ddlmod_test.go
#	migrator.go
  • Loading branch information
black-06 committed Feb 4, 2024
2 parents 00be1b3 + e64f7a5 commit 535c55d
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 93 deletions.
58 changes: 56 additions & 2 deletions ddlmod.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
var (
sqliteSeparator = "`|\"|'|\t"
uniqueRegexp = regexp.MustCompile(fmt.Sprintf(`^CONSTRAINT [%v]?[\w-]+[%v]? UNIQUE (.*)$`, sqliteSeparator, sqliteSeparator))
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]?(?s:.*?)ON (.*)$`, sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
Expand Down Expand Up @@ -138,7 +138,7 @@ func parseDDL(strs ...string) (*ddl, error) {
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}

Expand Down Expand Up @@ -181,6 +181,18 @@ func parseDDL(strs ...string) (*ddl, error) {
return &result, nil
}

func (d *ddl) clone() *ddl {
copied := new(ddl)
*copied = *d

copied.fields = make([]string, len(d.fields))
copy(copied.fields, d.fields)
copied.columns = make([]migrator.ColumnType, len(d.columns))
copy(copied.columns, d.columns)

return copied
}

func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
Expand All @@ -189,6 +201,21 @@ func (d *ddl) compile() string {
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}

func (d *ddl) renameTable(dst, src string) error {
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
if err != nil {
return err
}

replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
if replaced == d.head {
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
}

d.head = replaced
return nil
}

func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")

Expand Down Expand Up @@ -246,3 +273,30 @@ func (d *ddl) getColumns() []string {
}
return res
}

func (d *ddl) alterColumn(name, sql string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")

for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return false
}
}

d.fields = append(d.fields, sql)
return true
}

func (d *ddl) removeColumn(name string) bool {
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")

for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}

return false
}
46 changes: 41 additions & 5 deletions ddlmod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func TestParseDDL(t *testing.T) {
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
}},
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
Expand All @@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) {
{"with_special_characters", []string{
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{
Expand Down Expand Up @@ -105,6 +105,42 @@ func TestParseDDL(t *testing.T) {
NullableValue: sql.NullBool{Bool: false, Valid: true},
}},
},
{
"non-unique index",
[]string{
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
{
"index with \n from .schema sqlite",
[]string{
"CREATE TABLE `test-d` (`field` integer NOT NULL)",
"CREATE INDEX `idx_uq`\n ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
}

for _, p := range params {
Expand All @@ -130,7 +166,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
Expand All @@ -139,7 +175,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
Expand Down
11 changes: 7 additions & 4 deletions error_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import (
"gorm.io/gorm"
)

var errCodes = map[string]int{
"uniqueConstraint": 2067,
// The error codes to map sqlite errors to gorm errors, here is a reference about error codes for sqlite https://www.sqlite.org/rescode.html.
var errCodes = map[int]error{
1555: gorm.ErrDuplicatedKey,
2067: gorm.ErrDuplicatedKey,
787: gorm.ErrForeignKeyViolated,
}

type ErrMessage struct {
Expand All @@ -30,8 +33,8 @@ func (dialector Dialector) Translate(err error) error {
return err
}

if errMsg.ExtendedCode == errCodes["uniqueConstraint"] {
return gorm.ErrDuplicatedKey
if translatedErr, found := errCodes[errMsg.ExtendedCode]; found {
return translatedErr
}
return err
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module gorm.io/driver/sqlite
go 1.20

require (
github.com/mattn/go-sqlite3 v1.14.16
gorm.io/gorm v1.25.0
github.com/mattn/go-sqlite3 v1.14.17
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55
)

require (
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55 h1:sC1Xj4TYrLqg1n3AN10w871An7wJM0gzgcm8jkIkECQ=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
98 changes: 33 additions & 65 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package sqlite
import (
"database/sql"
"fmt"
"regexp"
"strings"

"gorm.io/gorm"
Expand Down Expand Up @@ -78,33 +77,16 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {

func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
}
for i, f := range createDDL.fields {
if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 && matches[1] == field.DBName {
createDDL.fields[i] = fmt.Sprintf("`%v` ?", field.DBName)
sqlArgs = []interface{}{m.FullDataTypeOf(field)}
// table created by old version might look like `CREATE TABLE ? (? varchar(10) UNIQUE)`.
// FullDataTypeOf doesn't contain UNIQUE, so we need to add unique constraint.
if strings.Contains(strings.ToUpper(matches[3]), " UNIQUE") {
uniName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
uni, _ := m.GuessConstraintInterfaceAndTable(stmt, uniName)
if uni != nil {
uniSQL, uniArgs := uni.Build()
createDDL.addConstraint(uniName, uniSQL)
sqlArgs = append(sqlArgs, uniArgs...)
}
}
break
}
}
return createDDL.compile(), sqlArgs, nil

return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)

return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
})
})
}
Expand Down Expand Up @@ -159,19 +141,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
}

func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}

reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}

createSQL := reg.ReplaceAllString(rawDDL, "")

return createSQL, nil, nil
ddl.removeColumn(name)
return ddl, nil, nil
})
}

Expand All @@ -180,7 +156,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)

return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
var (
constraintName string
constraintSql string
Expand All @@ -191,17 +167,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
constraintName = constraint.GetName()
constraintSql, constraintValues = constraint.Build()
} else {
return "", nil, nil
}

createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
return nil, nil, nil
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()

return createSQL, constraintValues, nil
ddl.addConstraint(constraintName, constraintSql)
return ddl, constraintValues, nil
})
})
}
Expand All @@ -214,15 +184,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
}

return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()

return createSQL, nil, nil
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
ddl.removeConstraint(name)
return ddl, nil, nil
})
})
}
Expand Down Expand Up @@ -324,6 +288,9 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
if err := m.DropIndex(value, oldName); err != nil {
return err
}
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)
Expand Down Expand Up @@ -392,8 +359,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
return createSQL, nil
}

func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
func (m Migrator) recreateTable(
value interface{}, tablePtr *string,
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
Expand All @@ -405,27 +374,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
return err
}

newTableName := table + "__temp"

createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
originDDL, err := parseDDL(rawDDL)
if err != nil {
return err
}
if createSQL == "" {
return nil
}

tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
if createDDL == nil {
return nil
}

createDDL, err := parseDDL(createSQL)
if err != nil {
newTableName := table + "__temp"
if err := createDDL.renameTable(newTableName, table); err != nil {
return err
}

columns := createDDL.getColumns()
createSQL := createDDL.compile()

return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
Expand Down
Loading

0 comments on commit 535c55d

Please sign in to comment.