Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the set column NOT NULL operation #63

Merged
merged 9 commits into from
Aug 29, 2023
3 changes: 2 additions & 1 deletion examples/14_add_reviews_table.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
},
{
"name": "review",
"type": "text"
"type": "text",
"nullable": true
}
]
}
Expand Down
12 changes: 12 additions & 0 deletions examples/16_set_not_null.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"name": "16_set_not_null",
"operations": [
{
"set_not_null": {
"table": "reviews",
"column": "review",
"up": "product || ' is good'"
}
}
]
}
12 changes: 4 additions & 8 deletions pkg/migrations/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,10 @@ func (e IndexDoesNotExistError) Error() string {
return fmt.Sprintf("index %q does not exist", e.Name)
}

type NameRequiredError struct{}

func (e NameRequiredError) Error() string {
return "name is required"
type FieldRequiredError struct {
Name string
}

type UpSQLRequiredError struct{}

func (e UpSQLRequiredError) Error() string {
return "up SQL is required"
func (e FieldRequiredError) Error() string {
return fmt.Sprintf("field %q is required", e.Name)
}
20 changes: 10 additions & 10 deletions pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, state
}

if !o.Column.Nullable && o.Column.Default == nil {
if err := addNotNullConstraint(ctx, conn, o); err != nil {
if err := addNotNullConstraint(ctx, conn, o.Table, o.Column.Name, TemporaryName(o.Column.Name)); err != nil {
return fmt.Errorf("failed to add check constraint: %w", err)
}
}
Expand All @@ -44,7 +44,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, schemaName, state
if err != nil {
return fmt.Errorf("failed to create trigger: %w", err)
}
if err := backFill(ctx, conn, o); err != nil {
if err := backFill(ctx, conn, o.Table, TemporaryName(o.Column.Name)); err != nil {
return fmt.Errorf("failed to backfill column: %w", err)
}
}
Expand Down Expand Up @@ -153,23 +153,23 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
return err
}

func addNotNullConstraint(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
func addNotNullConstraint(ctx context.Context, conn *sql.DB, table, column, physicalColumn string) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(NotNullConstraintName(o.Column.Name)),
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
pq.QuoteIdentifier(table),
pq.QuoteIdentifier(NotNullConstraintName(column)),
pq.QuoteIdentifier(physicalColumn),
))
return err
}

func backFill(ctx context.Context, conn *sql.DB, o *OpAddColumn) error {
func backFill(ctx context.Context, conn *sql.DB, table, column string) error {
// touch rows without changing them in order to have the trigger fire
// and set the value using the `up` SQL.
// TODO: this should be done in batches in case of large tables.
_, err := conn.ExecContext(ctx, fmt.Sprintf("UPDATE %s SET %s = %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
pq.QuoteIdentifier(TemporaryName(o.Column.Name))))
pq.QuoteIdentifier(table),
pq.QuoteIdentifier(column),
pq.QuoteIdentifier(column)))

return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/migrations/op_create_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (o *OpCreateIndex) Rollback(ctx context.Context, conn *sql.DB) error {

func (o *OpCreateIndex) Validate(ctx context.Context, s *schema.Schema) error {
if o.Name == "" {
return NameRequiredError{}
return FieldRequiredError{Name: "name"}
}

table := s.GetTable(o.Table)
Expand Down
157 changes: 155 additions & 2 deletions pkg/migrations/op_set_notnull.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package migrations
import (
"context"
"database/sql"
"fmt"

"github.com/lib/pq"
"github.com/xataio/pg-roll/pkg/schema"
)

Expand All @@ -16,15 +18,149 @@ type OpSetNotNull struct {
var _ Operation = (*OpSetNotNull)(nil)

// Start begins the set-NOT-NULL migration for o.Column on o.Table.
//
// In order, it:
//  1. adds a duplicate of the column under its temporary name,
//  2. adds a NOT VALID "IS NOT NULL" check constraint on that duplicate,
//  3. installs an "up" trigger that fills the duplicate from the original,
//     using the `up` SQL where the original value is NULL,
//  4. backfills existing rows by touching them so the up trigger fires,
//  5. installs a "down" trigger that copies the duplicate back to the original,
//  6. records the duplicate in the in-memory schema s under the name o.Column.
//
// It returns the first error encountered, wrapped with the failing step.
// No cleanup is attempted here on failure; Rollback removes the duplicate
// column and both trigger functions.
func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, schemaName string, stateSchema string, s *schema.Schema) error {
// NOTE(review): table, column and o.Up are dereferenced below without nil
// checks — presumably Validate has already guaranteed they exist; confirm.
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)

// Create a copy of the column on the underlying table. The copy is named
// TemporaryName(o.Column).
if err := duplicateColumn(ctx, conn, table, *column); err != nil {
return fmt.Errorf("failed to duplicate column: %w", err)
}

// Add an unchecked (NOT VALID) NOT NULL check constraint to the new column.
// The constraint name is derived from the logical column name (o.Column),
// while the CHECK expression targets the temporary physical column.
if err := addNotNullConstraint(ctx, conn, o.Table, o.Column, TemporaryName(o.Column)); err != nil {
return fmt.Errorf("failed to add not null constraint: %w", err)
}

// Add a trigger to copy values from the old column to the new, rewriting NULL values using the `up` SQL.
// NOTE(review): presumably createTrigger applies SQL when TestExpr holds and
// runs ElseExpr otherwise — confirm against createTrigger.
err := createTrigger(ctx, conn, s, triggerConfig{
Direction: TriggerDirectionUp,
SchemaName: schemaName,
StateSchema: stateSchema,
Table: o.Table,
Column: o.Column,
PhysicalColumn: TemporaryName(o.Column),
SQL: *o.Up,
TestExpr: fmt.Sprintf("NEW.%s IS NULL", pq.QuoteIdentifier(o.Column)),
// NOTE(review): the next two lines are review-UI residue from the rendered
// diff, not Go code.
andrew-farries marked this conversation as resolved.
Show resolved Hide resolved
ElseExpr: fmt.Sprintf("NEW.%s = NEW.%s;",
pq.QuoteIdentifier(TemporaryName(o.Column)),
pq.QuoteIdentifier(o.Column)),
})
if err != nil {
return fmt.Errorf("failed to create up trigger: %w", err)
}

// Backfill the new column with values from the old column. backFill issues a
// no-op UPDATE on the new column so that the up trigger fires for every row.
if err := backFill(ctx, conn, o.Table, TemporaryName(o.Column)); err != nil {
return fmt.Errorf("failed to backfill column: %w", err)
}

// Add a trigger to copy values from the new column to the old. Here the
// logical/physical roles are swapped relative to the up trigger: the trigger
// is keyed on the temporary column and writes to the original column.
err = createTrigger(ctx, conn, s, triggerConfig{
Direction: TriggerDirectionDown,
SchemaName: schemaName,
StateSchema: stateSchema,
Table: o.Table,
Column: TemporaryName(o.Column),
PhysicalColumn: o.Column,
SQL: fmt.Sprintf("NEW.%s", pq.QuoteIdentifier(TemporaryName(o.Column))),
})
if err != nil {
return fmt.Errorf("failed to create down trigger: %w", err)
}

// Register the new column in the in-memory schema under the logical name
// o.Column, with the temporary name as its Name. Only Name is populated.
// NOTE(review): presumably this makes later consumers of s resolve o.Column
// to the new physical column — confirm against schema.Table.AddColumn.
table.AddColumn(o.Column, schema.Column{
Name: TemporaryName(o.Column),
})

return nil
}

// Complete finalises the set-NOT-NULL migration started by Start.
//
// In order, it validates the check constraint on the new (temporary) column,
// sets that column NOT NULL, drops the now-redundant check constraint, drops
// the old column, removes the up and down trigger functions (CASCADE, so
// their triggers go with them), and finally renames the new column to the
// original column name.
//
// The statements are executed one by one on conn, not inside an explicit
// transaction; the first failing statement's error is returned unwrapped and
// the remaining steps are skipped.
func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB) error {
// Validate the NOT NULL check constraint. Start created it as NOT VALID on the
// new (temporary) column; its name is derived from the old column's name.
// NOTE(review): the next two lines are review-UI residue from the rendered
// diff, not Go code.
andrew-farries marked this conversation as resolved.
Show resolved Hide resolved
// The constraint must be valid because:
// * Existing NULL values in the old column were rewritten using the `up` SQL during backfill.
// * New NULL values written to the old column during the migration period were also rewritten using `up` SQL.
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(NotNullConstraintName(o.Column))))
if err != nil {
return err
}

// Use the validated constraint to add `NOT NULL` to the new column.
// NOTE(review): presumably the validated check lets PostgreSQL (12+) skip a
// full-table scan for SET NOT NULL — confirm for the supported versions.
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ALTER COLUMN %s SET NOT NULL",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column))))
if err != nil {
return err
}

// Drop the NOT NULL check constraint; the column-level NOT NULL set above
// supersedes it.
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(NotNullConstraintName(o.Column))))
if err != nil {
return err
}

// Drop the old column
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column)))
if err != nil {
return err
}

// Remove the up function and trigger (keyed on the original column name).
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
if err != nil {
return err
}

// Remove the down function and trigger (keyed on the temporary column name).
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column)))))
if err != nil {
return err
}

// Rename the new column to the old column name. This must come after the old
// column has been dropped, otherwise the name would still be taken.
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column)),
pq.QuoteIdentifier(o.Column)))
if err != nil {
return err
}

return nil
}

// Rollback undoes the work performed by Start.
//
// It drops the temporary (new) column and then removes the up and down
// trigger functions; CASCADE removes the triggers that depend on them. Every
// statement uses IF EXISTS on the object it drops, so Rollback can be run
// after a partially completed Start.
//
// The NOT NULL check constraint is not dropped explicitly.
// NOTE(review): it is defined on the temporary column, so PostgreSQL is
// expected to remove it together with that column — confirm.
//
// The first failing statement's error is returned unwrapped and the
// remaining steps are skipped.
func (o *OpSetNotNull) Rollback(ctx context.Context, conn *sql.DB) error {
	// Drop the new column
	_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
		pq.QuoteIdentifier(o.Table),
		pq.QuoteIdentifier(TemporaryName(o.Column)),
	))
	if err != nil {
		return err
	}

	// Remove the up function and trigger
	_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
		pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)),
	))
	if err != nil {
		return err
	}

	// Remove the down function and trigger
	_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
		pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column))),
	))

	return err
}

func (o *OpSetNotNull) Validate(ctx context.Context, s *schema.Schema) error {
Expand All @@ -38,7 +174,24 @@ func (o *OpSetNotNull) Validate(ctx context.Context, s *schema.Schema) error {
}

if o.Up == nil {
return UpSQLRequiredError{}
return FieldRequiredError{Name: "up"}
}
return nil
}

// duplicateColumn adds a copy of column to table under the column's
// temporary name. The new column's definition is whatever schemaColumnToSQL
// renders for it. The error from executing the ALTER TABLE statement is
// returned as is.
func duplicateColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column schema.Column) error {
	// column is received by value, so renaming the copy leaves the caller's
	// schema.Column untouched.
	duplicate := column
	duplicate.Name = TemporaryName(column.Name)

	stmt := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s",
		pq.QuoteIdentifier(table.Name),
		schemaColumnToSQL(duplicate))

	if _, err := conn.ExecContext(ctx, stmt); err != nil {
		return err
	}
	return nil
}

// TODO: This function needs to be able to duplicate a column more precisely
// including constraints, indexes, defaults, etc.
Comment on lines +193 to +194
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very good point! Ideally we will have all the info needed for that in the schema object

// schemaColumnToSQL renders c as a column definition of the form
// `<quoted name> <type>`. Only the name and type are emitted; nothing else
// from c is carried over.
func schemaColumnToSQL(c schema.Column) string {
	quotedName := pq.QuoteIdentifier(c.Name)
	return fmt.Sprintf("%s %s", quotedName, c.Type)
}
Loading