Skip to content

Commit

Permalink
Rewrite column DEFAULTs using the SQL transformer (#332)
Browse files Browse the repository at this point in the history
Use the SQL transformer added in #329 to rewrite or reject column
`DEFAULT` values.

Column `DEFAULT` values are user-supplied SQL expressions that may need
to be restricted or rewritten in some environments.
  • Loading branch information
andrew-farries authored Mar 28, 2024
1 parent 130f451 commit 6334cb4
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 12 deletions.
14 changes: 10 additions & 4 deletions pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var _ Operation = (*OpAddColumn)(nil)
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)

if err := addColumn(ctx, conn, *o, table); err != nil {
if err := addColumn(ctx, conn, *o, table, tr); err != nil {
return nil, fmt.Errorf("failed to start add column operation: %w", err)
}

Expand Down Expand Up @@ -182,7 +182,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
return nil
}

func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table) error {
func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error {
// don't add non-nullable columns with no default directly
// they are handled by:
// - adding the column as nullable
Expand All @@ -203,10 +203,16 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table
o.Column.Check = nil

o.Column.Name = TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s",
colSQL, err := ColumnToSQL(o.Column, tr)
if err != nil {
return err
}

_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s",
pq.QuoteIdentifier(t.Name),
ColumnToSQL(o.Column),
colSQL,
))

return err
}

Expand Down
101 changes: 101 additions & 0 deletions pkg/migrations/op_add_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"

"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/testutils"
)

Expand Down Expand Up @@ -1289,3 +1290,103 @@ func TestAddColumnWithComment(t *testing.T) {
},
}})
}

func TestAddColumnDefaultTransformation(t *testing.T) {
t.Parallel()

sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{
"'default value 1'": "'rewritten'",
"'default value 2'": testutils.MockSQLTransformerError,
})

ExecuteTests(t, TestCases{
{
name: "column default is rewritten by the SQL transformer",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
},
},
},
},
{
Name: "02_add_column",
Operations: migrations.Operations{
&migrations.OpAddColumn{
Table: "users",
Column: migrations.Column{
Name: "name",
Type: "text",
Default: ptr("'default value 1'"),
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Insert some data into the table
MustInsert(t, db, schema, "02_add_column", "users", map[string]string{
"id": "1",
})

// Ensure the row has the rewritten default value.
rows := MustSelect(t, db, schema, "02_add_column", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "rewritten"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Ensure the row has the rewritten default value.
rows := MustSelect(t, db, schema, "02_add_column", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "rewritten"},
}, rows)
},
},
{
name: "operation fails when the SQL transformer returns an error",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
},
},
},
},
{
Name: "02_add_column",
Operations: migrations.Operations{
&migrations.OpAddColumn{
Table: "users",
Column: migrations.Column{
Name: "name",
Type: "text",
Default: ptr("'default value 2'"),
},
},
},
},
},
wantStartErr: testutils.ErrMockSQLTransformer,
},
}, roll.WithSQLTransformer(sqlTransformer))
}
31 changes: 23 additions & 8 deletions pkg/migrations/op_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@ import (
var _ Operation = (*OpCreateTable)(nil)

func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// Generate SQL for the columns in the table
columnsSQL, err := columnsToSQL(o.Columns, tr)
if err != nil {
return nil, fmt.Errorf("failed to create columns SQL: %w", err)
}

// Create the table under a temporary name
tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)",
_, err = conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)",
pq.QuoteIdentifier(tempName),
columnsToSQL(o.Columns)))
columnsSQL))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -104,18 +111,22 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error {
return nil
}

func columnsToSQL(cols []Column) string {
func columnsToSQL(cols []Column, tr SQLTransformer) (string, error) {
var sql string
for i, col := range cols {
if i > 0 {
sql += ", "
}
sql += ColumnToSQL(col)
colSQL, err := ColumnToSQL(col, tr)
if err != nil {
return "", err
}
sql += colSQL
}
return sql
return sql, nil
}

func ColumnToSQL(col Column) string {
func ColumnToSQL(col Column, tr SQLTransformer) (string, error) {
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)

if col.IsPrimaryKey() {
Expand All @@ -128,7 +139,11 @@ func ColumnToSQL(col Column) string {
sql += " NOT NULL"
}
if col.Default != nil {
sql += fmt.Sprintf(" DEFAULT %s", *col.Default)
d, err := tr.TransformSQL(*col.Default)
if err != nil {
return "", err
}
sql += fmt.Sprintf(" DEFAULT %s", d)
}
if col.References != nil {
onDelete := "NO ACTION"
Expand All @@ -147,5 +162,5 @@ func ColumnToSQL(col Column) string {
pq.QuoteIdentifier(col.Check.Name),
col.Check.Constraint)
}
return sql
return sql, nil
}
90 changes: 90 additions & 0 deletions pkg/migrations/op_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
"github.com/xataio/pgroll/pkg/testutils"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -466,3 +467,92 @@ func TestCreateTableValidation(t *testing.T) {
},
})
}

func TestCreateTableColumnDefaultTransformation(t *testing.T) {
t.Parallel()

sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{
"'default value 1'": "'rewritten'",
"'default value 2'": testutils.MockSQLTransformerError,
})

ExecuteTests(t, TestCases{
{
name: "column default is rewritten by the SQL transformer",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "name",
Type: "text",
Default: ptr("'default value 1'"),
},
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Insert some data into the table
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
"id": "1",
})

// Ensure the row has the rewritten default value.
rows := MustSelect(t, db, schema, "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "rewritten"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Insert some data into the table
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
"id": "1",
})

// Ensure the row has the rewritten default value.
rows := MustSelect(t, db, schema, "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "rewritten"},
}, rows)
},
},
{
name: "create table fails when the SQL transformer returns an error",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "name",
Type: "text",
Default: ptr("'default value 2'"),
},
},
},
},
},
},
wantStartErr: testutils.ErrMockSQLTransformer,
},
}, roll.WithSQLTransformer(sqlTransformer))
}

0 comments on commit 6334cb4

Please sign in to comment.