Skip to content

Commit

Permalink
sql/postgres: apply generate columns for new columns
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Apr 18, 2022
1 parent a8cb20b commit f28f82a
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 124 deletions.
19 changes: 10 additions & 9 deletions sql/mysql/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ func (s *state) modifySchema(modify *schema.ModifySchema) error {
// for creating a table in a schema.
func (s *state) addTable(add *schema.AddTable) error {
var (
errors []string
b = Build("CREATE TABLE")
errs []string
b = Build("CREATE TABLE")
)
if sqlx.Has(add.Extra, &schema.IfNotExists{}) {
b.P("IF NOT EXISTS")
Expand All @@ -212,7 +212,7 @@ func (s *state) addTable(add *schema.AddTable) error {
b.Wrap(func(b *sqlx.Builder) {
b.MapComma(add.T.Columns, func(i int, b *sqlx.Builder) {
if err := s.column(b, add.T, add.T.Columns[i]); err != nil {
errors = append(errors, err.Error())
errs = append(errs, err.Error())
}
})
if pk := add.T.PrimaryKey; pk != nil {
Expand All @@ -229,7 +229,7 @@ func (s *state) addTable(add *schema.AddTable) error {
if len(add.T.ForeignKeys) > 0 {
b.Comma()
if err := s.fks(b, add.T.ForeignKeys...); err != nil {
errors = append(errors, err.Error())
errs = append(errs, err.Error())
}
}
for _, attr := range add.T.Attrs {
Expand All @@ -239,8 +239,8 @@ func (s *state) addTable(add *schema.AddTable) error {
}
}
})
if len(errors) > 0 {
return fmt.Errorf("create table %q: %s", add.T.Name, strings.Join(errors, ", "))
if len(errs) > 0 {
return fmt.Errorf("create table %q: %s", add.T.Name, strings.Join(errs, ", "))
}
s.tableAttr(b, add, add.T.Attrs...)
s.append(&migrate.Change{
Expand Down Expand Up @@ -343,8 +343,9 @@ func (s *state) alterTable(t *schema.Table, changes []schema.Change) error {
return err
}
reverse = append(reverse, &schema.ModifyColumn{
From: change.To,
To: change.From,
From: change.To,
To: change.From,
Change: change.Change,
})
case *schema.DropColumn:
b.P("DROP COLUMN").Ident(change.C.Name)
Expand Down Expand Up @@ -402,7 +403,7 @@ func (s *state) alterTable(t *schema.Table, changes []schema.Change) error {
b.P("DROP CHECK").Ident(change.From.Name).Comma().P("ADD")
s.check(b, change.To)
default:
return errors.New("unknown check constraints change")
return errors.New("unknown check constraint change")
}
reverse = append(reverse, &schema.ModifyCheck{
From: change.To,
Expand Down
9 changes: 0 additions & 9 deletions sql/postgres/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,6 @@ func FormatType(t schema.Type) (string, error) {
return f, nil
}

// mustFormat calls to FormatType and panics in case of error.
func mustFormat(t schema.Type) string {
s, err := FormatType(t)
if err != nil {
panic(err)
}
return s
}

// ParseType returns the schema.Type value represented by the given raw type.
// The raw value is expected to follow the format in PostgreSQL information schema
// or as an input for the CREATE TABLE statement.
Expand Down
10 changes: 9 additions & 1 deletion sql/postgres/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,15 @@ func (d *diff) typeChanged(from, to *schema.Column) (bool, error) {
case *schema.BinaryType, *schema.BoolType, *schema.DecimalType, *schema.FloatType,
*schema.IntegerType, *schema.JSONType, *schema.SpatialType, *schema.StringType,
*schema.TimeType, *BitType, *NetworkType, *UserDefinedType:
changed = mustFormat(toT) != mustFormat(fromT)
t1, err := FormatType(toT)
if err != nil {
return false, err
}
t2, err := FormatType(fromT)
if err != nil {
return false, err
}
changed = t1 != t2
case *enumType:
toT := toT.(*schema.EnumType)
changed = fromT.T != toT.T || !sqlx.ValuesEqual(fromT.Values, toT.Values)
Expand Down
202 changes: 118 additions & 84 deletions sql/postgres/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package postgres

import (
"context"
"errors"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -121,14 +122,19 @@ func (s *state) addTable(ctx context.Context, add *schema.AddTable) error {
if err := s.addTypes(ctx, add.T.Columns...); err != nil {
return err
}
b := Build("CREATE TABLE")
var (
errs []string
b = Build("CREATE TABLE")
)
if sqlx.Has(add.Extra, &schema.IfNotExists{}) {
b.P("IF NOT EXISTS")
}
b.Table(add.T)
b.Wrap(func(b *sqlx.Builder) {
b.MapComma(add.T.Columns, func(i int, b *sqlx.Builder) {
s.column(b, add.T.Columns[i])
if err := s.column(b, add.T.Columns[i]); err != nil {
errs = append(errs, err.Error())
}
})
if pk := add.T.PrimaryKey; pk != nil {
b.Comma().P("PRIMARY KEY")
Expand All @@ -145,6 +151,9 @@ func (s *state) addTable(ctx context.Context, add *schema.AddTable) error {
}
}
})
if len(errs) > 0 {
return fmt.Errorf("create table %q: %s", add.T.Name, strings.Join(errs, ", "))
}
s.append(&migrate.Change{
Cmd: b.String(),
Source: add,
Expand Down Expand Up @@ -277,80 +286,90 @@ func (s *state) modifyTable(ctx context.Context, modify *schema.ModifyTable) err
// alterTable modifies the given table by executing on it a list of changes in one SQL statement.
func (s *state) alterTable(t *schema.Table, changes []schema.Change) error {
var (
errors []string
b = Build("ALTER TABLE").Table(t)
reverse = Build("")
reverse []schema.Change
reversible = true
)
b.MapComma(changes, func(i int, b *sqlx.Builder) {
switch change := changes[i].(type) {
case *schema.AddColumn:
b.P("ADD COLUMN")
s.column(b, change.C)
reverse.Comma().P("DROP COLUMN").Ident(change.C.Name)
case *schema.DropColumn:
b.P("DROP COLUMN").Ident(change.C.Name)
reversible = false
case *schema.ModifyColumn:
if err := s.alterColumn(b, change.Change, change.To); err != nil {
errors = append(errors, err.Error())
}
if err := s.alterColumn(reverse.Comma(), change.Change, change.From); err != nil {
errors = append(errors, err.Error())
}
case *schema.AddForeignKey:
b.P("ADD")
s.fks(b, change.F)
reverse.Comma().P("DROP CONSTRAINT").Ident(change.F.Symbol)
case *schema.DropForeignKey:
b.P("DROP CONSTRAINT").Ident(change.F.Symbol)
reverse.P("ADD")
s.fks(reverse, change.F)
case *schema.AddCheck:
check(b.P("ADD"), change.C)
// Reverse operation is supported if
// the constraint name is not generated.
if reversible = change.C.Name != ""; reversible {
reverse.Comma().P("DROP CONSTRAINT").Ident(change.C.Name)
}
case *schema.DropCheck:
b.P("DROP CONSTRAINT").Ident(change.C.Name)
check(reverse.Comma().P("ADD"), change.C)
case *schema.ModifyCheck:
switch {
case change.From.Name == "":
errors = append(errors, "cannot modify unnamed check constraint")
case change.From.Name != change.To.Name:
errors = append(errors, fmt.Sprintf("mismatch check constraint names: %q != %q", change.From.Name, change.To.Name))
case change.From.Expr != change.To.Expr,
sqlx.Has(change.From.Attrs, &NoInherit{}) && !sqlx.Has(change.To.Attrs, &NoInherit{}),
!sqlx.Has(change.From.Attrs, &NoInherit{}) && sqlx.Has(change.To.Attrs, &NoInherit{}):
b.P("DROP CONSTRAINT").Ident(change.From.Name).Comma().P("ADD")
check(b, change.To)
reverse.Comma().P("DROP CONSTRAINT").Ident(change.To.Name).Comma().P("ADD")
check(reverse, change.From)
default:
errors = append(errors, "unknown check constraints change")
build := func(changes []schema.Change) (string, error) {
b := Build("ALTER TABLE").Table(t)
err := b.MapCommaErr(changes, func(i int, b *sqlx.Builder) error {
switch change := changes[i].(type) {
case *schema.AddColumn:
b.P("ADD COLUMN")
if err := s.column(b, change.C); err != nil {
return err
}
reverse = append(reverse, &schema.DropColumn{C: change.C})
case *schema.DropColumn:
b.P("DROP COLUMN").Ident(change.C.Name)
reverse = append(reverse, &schema.AddColumn{C: change.C})
case *schema.ModifyColumn:
if err := s.alterColumn(b, change.Change, change.To); err != nil {
return err
}
reverse = append(reverse, &schema.ModifyColumn{
From: change.To,
To: change.From,
Change: change.Change,
})
case *schema.AddForeignKey:
b.P("ADD")
s.fks(b, change.F)
reverse = append(reverse, &schema.DropForeignKey{F: change.F})
case *schema.DropForeignKey:
b.P("DROP CONSTRAINT").Ident(change.F.Symbol)
reverse = append(reverse, &schema.AddForeignKey{F: change.F})
case *schema.AddCheck:
check(b.P("ADD"), change.C)
// Reverse operation is supported if
// the constraint name is not generated.
if reversible = reversible && change.C.Name != ""; reversible {
reverse = append(reverse, &schema.DropCheck{C: change.C})
}
case *schema.DropCheck:
b.P("DROP CONSTRAINT").Ident(change.C.Name)
reverse = append(reverse, &schema.AddCheck{C: change.C})
case *schema.ModifyCheck:
switch {
case change.From.Name == "":
return errors.New("cannot modify unnamed check constraint")
case change.From.Name != change.To.Name:
return fmt.Errorf("mismatch check constraint names: %q != %q", change.From.Name, change.To.Name)
case change.From.Expr != change.To.Expr,
sqlx.Has(change.From.Attrs, &NoInherit{}) && !sqlx.Has(change.To.Attrs, &NoInherit{}),
!sqlx.Has(change.From.Attrs, &NoInherit{}) && sqlx.Has(change.To.Attrs, &NoInherit{}):
b.P("DROP CONSTRAINT").Ident(change.From.Name).Comma().P("ADD")
check(b, change.To)
default:
return errors.New("unknown check constraint change")
}
reverse = append(reverse, &schema.ModifyCheck{
From: change.To,
To: change.From,
})
}
return nil
})
if err != nil {
return "", nil
}
})
if len(errors) > 0 {
return fmt.Errorf("alter table: %s", strings.Join(errors, ", "))
return b.String(), nil
}
cmd, err := build(changes)
if err != nil {
return fmt.Errorf("alter table %q: %v", t.Name, err)
}
change := &migrate.Change{
Cmd: b.String(),
Cmd: cmd,
Source: &schema.ModifyTable{
T: t,
Changes: changes,
},
Comment: fmt.Sprintf("Modify %q table", t.Name),
Comment: fmt.Sprintf("modify %q table", t.Name),
}
if reversible {
b := Build("ALTER TABLE").Table(t)
if _, err := b.ReadFrom(reverse); err != nil {
return fmt.Errorf("unexpected buffer read: %w", err)
if change.Reverse, err = build(reverse); err != nil {
return fmt.Errorf("reversd alter table %q: %v", t.Name, err)
}
change.Reverse = b.String()
}
s.append(change)
return nil
Expand Down Expand Up @@ -491,7 +510,7 @@ func (s *state) addIndexes(t *schema.Table, indexes ...*schema.Index) {
// Unlike MySQL, the DROP command is not attached to ALTER TABLE.
// Therefore, we print indexes with their qualified name, because
// the connection that executes the statements may not be attached
// to the this schema.
// to this schema.
if t.Schema != nil {
b.WriteByte(b.QuoteChar)
b.WriteString(t.Schema.Name)
Expand All @@ -505,8 +524,12 @@ func (s *state) addIndexes(t *schema.Table, indexes ...*schema.Index) {
}
}

func (s *state) column(b *sqlx.Builder, c *schema.Column) {
b.Ident(c.Name).P(mustFormat(c.Type.Type))
func (s *state) column(b *sqlx.Builder, c *schema.Column) error {
t, err := FormatType(c.Type.Type)
if err != nil {
return err
}
b.Ident(c.Name).P(t)
if !c.Type.Null {
b.P("NOT")
}
Expand All @@ -517,27 +540,34 @@ func (s *state) column(b *sqlx.Builder, c *schema.Column) {
case *schema.Comment:
case *schema.Collation:
b.P("COLLATE").Ident(a.V)
case *Identity:
case *Identity, *schema.GeneratedExpr:
// Handled below.
default:
panic(fmt.Sprintf("unexpected column attribute: %T", attr))
return fmt.Errorf("unexpected column attribute: %T", attr)
}
}
switch hasI, hasX := sqlx.Has(c.Attrs, &Identity{}), sqlx.Has(c.Attrs, &schema.GeneratedExpr{}); {
case hasI && hasX:
return fmt.Errorf("both identity and generation expression specified for column %q", c.Name)
case hasI:
id, _ := identity(c.Attrs)
b.P("GENERATED", id.Generation, "AS IDENTITY")
if id.Sequence.Start != defaultSeqStart || id.Sequence.Increment != defaultSeqIncrement {
b.Wrap(func(b *sqlx.Builder) {
if id.Sequence.Start != defaultSeqStart {
b.P("START WITH", strconv.FormatInt(id.Sequence.Start, 10))
}
if id.Sequence.Increment != defaultSeqIncrement {
b.P("INCREMENT BY", strconv.FormatInt(id.Sequence.Increment, 10))
}
})
}
case hasX:
x := &schema.GeneratedExpr{}
sqlx.Has(c.Attrs, x)
b.P("GENERATED ALWAYS AS", sqlx.MayWrap(x.Expr), "STORED")
}
id, ok := identity(c.Attrs)
if !ok {
return
}
b.P("GENERATED", id.Generation, "AS IDENTITY")
if id.Sequence.Start != defaultSeqStart || id.Sequence.Increment != defaultSeqIncrement {
b.Wrap(func(b *sqlx.Builder) {
if id.Sequence.Start != defaultSeqStart {
b.P("START WITH", strconv.FormatInt(id.Sequence.Start, 10))
}
if id.Sequence.Increment != defaultSeqIncrement {
b.P("INCREMENT BY", strconv.FormatInt(id.Sequence.Increment, 10))
}
})
}
return nil
}

// columnDefault writes the default value of column to the builder.
Expand All @@ -564,7 +594,11 @@ func (s *state) alterColumn(b *sqlx.Builder, k schema.ChangeKind, c *schema.Colu
b.P("ALTER COLUMN").Ident(c.Name)
switch {
case k.Is(schema.ChangeType):
b.P("TYPE").P(mustFormat(c.Type.Type))
t, err := FormatType(c.Type.Type)
if err != nil {
return err
}
b.P("TYPE", t)
if collate := (schema.Collation{}); sqlx.Has(c.Attrs, &collate) {
b.P("COLLATE", collate.V)
}
Expand Down
Loading

0 comments on commit f28f82a

Please sign in to comment.