diff --git a/ddlmod.go b/ddlmod.go index 39cc13a..50e9655 100644 --- a/ddlmod.go +++ b/ddlmod.go @@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) { return &result, nil } +func (d *ddl) clone() *ddl { + copied := new(ddl) + *copied = *d + + copied.fields = make([]string, len(d.fields)) + copy(copied.fields, d.fields) + copied.columns = make([]migrator.ColumnType, len(d.columns)) + copy(copied.columns, d.columns) + + return copied +} + func (d *ddl) compile() string { if len(d.fields) == 0 { return d.head @@ -183,6 +195,21 @@ func (d *ddl) compile() string { return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ",")) } +func (d *ddl) renameTable(dst, src string) error { + tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*") + if err != nil { + return err + } + + replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst)) + if replaced == d.head { + return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head) + } + + d.head = replaced + return nil +} + func (d *ddl) addConstraint(name string, sql string) { reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") @@ -240,3 +267,30 @@ func (d *ddl) getColumns() []string { } return res } + +func (d *ddl) alterColumn(name, sql string) bool { + reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields[i] = sql + return false + } + } + + d.fields = append(d.fields, sql) + return true +} + +func (d *ddl) removeColumn(name string) bool { + reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields = append(d.fields[:i], d.fields[i+1:]...) + return true + } + } + + return false +} diff --git a/migrator.go b/migrator.go index fd2eeb4..1b85d6d 100644 --- a/migrator.go +++ b/migrator.go @@ -3,7 +3,6 @@ package sqlite import ( "database/sql" "fmt" - "regexp" "strings" "gorm.io/gorm" @@ -78,23 +77,16 @@ func (m Migrator) HasColumn(value interface{}, name string) bool { func (m Migrator) AlterColumn(value interface{}, name string) error { return m.RunWithoutForeignKey(func() error { - return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { if field := stmt.Schema.LookUpField(name); field != nil { - // lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)` - reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)") - if err != nil { - return "", nil, err + if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) { + return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile()) } - createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName)) - - if createSQL == rawDDL { - return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL) - } - - return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil + return ddl, []interface{}{m.FullDataTypeOf(field)}, nil } - return "", nil, fmt.Errorf("failed to alter field with name %v", name) + + return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name) }) }) } @@ -149,19 +141,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { } func (m Migrator) DropColumn(value interface{}, name string) error { - return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { if field := stmt.Schema.LookUpField(name); field != nil { name = field.DBName } - reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,") - if err != nil { - return "", nil, err - } - - createSQL := reg.ReplaceAllString(rawDDL, "") - - return createSQL, nil, nil + ddl.removeColumn(name) + return ddl, nil, nil }) } @@ -170,7 +156,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) return m.recreateTable(value, &table, - func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { var ( constraintName string constraintSql string @@ -185,17 +171,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { constraintSql = "CONSTRAINT ? CHECK (?)" constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} } else { - return "", nil, nil + return nil, nil, nil } - createDDL, err := parseDDL(rawDDL) - if err != nil { - return "", nil, err - } - createDDL.addConstraint(constraintName, constraintSql) - createSQL := createDDL.compile() - - return createSQL, constraintValues, nil + ddl.addConstraint(constraintName, constraintSql) + return ddl, constraintValues, nil }) }) } @@ -210,15 +190,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { } return m.recreateTable(value, &table, - func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { - createDDL, err := parseDDL(rawDDL) - if err != nil { - return "", nil, err - } - createDDL.removeConstraint(name) - createSQL := createDDL.compile() - - return createSQL, nil, nil + func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { + ddl.removeConstraint(name) + return ddl, nil, nil }) }) } @@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) { return createSQL, nil } -func (m Migrator) recreateTable(value interface{}, tablePtr *string, - getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error { +func (m Migrator) recreateTable( + value interface{}, tablePtr *string, + getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error), +) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { table := stmt.Table if tablePtr != nil { @@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string, return err } - newTableName := table + "__temp" - - createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt) + originDDL, err := parseDDL(rawDDL) if err != nil { return err } - if createSQL == "" { - return nil - } - tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*") + createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt) if err != nil { return err } - createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + if createDDL == nil { + return nil + } - createDDL, err := parseDDL(createSQL) - if err != nil { + newTableName := table + "__temp" + if err := createDDL.renameTable(newTableName, table); err != nil { return err } + columns := createDDL.getColumns() + createSQL := createDDL.compile() return m.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil { diff --git a/sqlite.go b/sqlite.go index 4f0da2e..abcb3ae 100644 --- a/sqlite.go +++ b/sqlite.go @@ -198,7 +198,8 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Bool: return "numeric" case schema.Int, schema.Uint: - if field.AutoIncrement && !field.PrimaryKey { + if field.AutoIncrement { + // doesn't check `PrimaryKey`, to keep backward compatibility // https://www.sqlite.org/autoinc.html return "integer PRIMARY KEY AUTOINCREMENT" } else {