Skip to content

Commit

Permalink
feat: support modifying primary keys
Browse files Browse the repository at this point in the history
- Retrieve PK information on the table level rather than
for each column individually.
- Refine dependencies between migrations
  • Loading branch information
bevzzz committed Oct 27, 2024
1 parent 694f873 commit a734629
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 39 deletions.
42 changes: 42 additions & 0 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,31 @@ func (m *migrator) Apply(ctx context.Context, changes ...interface{}) error {
return fmt.Errorf("apply changes: drop table %s: %w", change.FQN, err)
}
continue
case *migrate.ChangePrimaryKey:
// TODO: refactor!
b, err = m.dropConstraint(fmter, b, change.FQN, change.Old.Name)
if err != nil {
return fmt.Errorf("apply changes: %w", err)
}

query := internal.String(b)
log.Println("exec query: " + query)
if _, err = conn.ExecContext(ctx, query); err != nil {
return fmt.Errorf("apply changes: %w", err)
}

b = []byte{}
b, err = m.addPrimaryKey(fmter, b, change.FQN, change.New.Columns.Safe())
if err != nil {
return fmt.Errorf("apply changes: %w", err)
}

query = internal.String(b)
log.Println("exec query: " + query)
if _, err = conn.ExecContext(ctx, query); err != nil {
return fmt.Errorf("apply changes: %w", err)
}
continue
case *migrate.RenameTable:
b, err = m.renameTable(fmter, b, change)
case *migrate.RenameColumn:
Expand All @@ -62,6 +87,8 @@ func (m *migrator) Apply(ctx context.Context, changes ...interface{}) error {
b, err = m.addColumn(fmter, b, change)
case *migrate.DropColumn:
b, err = m.dropColumn(fmter, b, change)
case *migrate.AddPrimaryKey:
b, err = m.addPrimaryKey(fmter, b, change.FQN, change.PK.Columns.Safe())
case *migrate.AddForeignKey:
b, err = m.addForeignKey(fmter, b, change)
case *migrate.AddUniqueConstraint:
Expand All @@ -70,6 +97,8 @@ func (m *migrator) Apply(ctx context.Context, changes ...interface{}) error {
b, err = m.dropConstraint(fmter, b, change.FQN, change.Unique.Name)
case *migrate.DropConstraint:
b, err = m.dropConstraint(fmter, b, change.FQN(), change.ConstraintName)
case *migrate.DropPrimaryKey:
b, err = m.dropConstraint(fmter, b, change.FQN, change.PK.Name)
case *migrate.RenameConstraint:
b, err = m.renameConstraint(fmter, b, change)
case *migrate.ChangeColumnType:
Expand Down Expand Up @@ -155,6 +184,19 @@ func (m *migrator) renameConstraint(fmter schema.Formatter, b []byte, rename *mi
return b, nil
}

func (m *migrator) addPrimaryKey(fmter schema.Formatter, b []byte, fqn schema.FQN, columns schema.Safe) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
if b, err = fqn.AppendQuery(fmter, b); err != nil {
return b, err
}

b = append(b, " ADD PRIMARY KEY ("...)
b, _ = columns.AppendQuery(fmter, b)
b = append(b, ")"...)

return b, nil
}

func (m *migrator) addUnique(fmter schema.Formatter, b []byte, change *migrate.AddUniqueConstraint) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
if b, err = change.FQN.AppendQuery(fmter, b); err != nil {
Expand Down
47 changes: 37 additions & 10 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
SQLType: c.DataType,
VarcharLen: c.VarcharLen,
DefaultValue: def,
IsPK: c.IsPK,
IsNullable: c.IsNullable,
IsAutoIncrement: c.IsSerial,
IsIdentity: c.IsIdentity,
Expand All @@ -83,11 +82,20 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
})
}

var pk *sqlschema.PK
if len(table.PrimaryKey.Columns) > 0 {
pk = &sqlschema.PK{
Name: table.PrimaryKey.ConstraintName,
Columns: sqlschema.NewComposite(table.PrimaryKey.Columns...),
}
}

state.Tables = append(state.Tables, sqlschema.Table{
Schema: table.Schema,
Name: table.Name,
Columns: colDefs,
UniqueContraints: unique,
PK: pk,
})
}

Expand All @@ -101,8 +109,9 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
}

type InformationSchemaTable struct {
Schema string `bun:"table_schema,pk"`
Name string `bun:"table_name,pk"`
Schema string `bun:"table_schema,pk"`
Name string `bun:"table_name,pk"`
PrimaryKey PrimaryKey `bun:"embed:primary_key_"`

Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"`
}
Expand All @@ -117,7 +126,6 @@ type InformationSchemaColumn struct {
ArrayDims int `bun:"array_dims"`
Default string `bun:"default"`
IsDefaultLiteral bool `bun:"default_is_literal_expr"`
IsPK bool `bun:"is_pk"`
IsIdentity bool `bun:"is_identity"`
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
Expand All @@ -135,18 +143,38 @@ type ForeignKey struct {
TargetColumns []string `bun:"target_columns,array"`
}

type PrimaryKey struct {
ConstraintName string `bun:"name"`
Columns []string `bun:"columns,array"`
}

const (
// sqlInspectTables retrieves all user-defined tables across all schemas.
// It excludes relations from Postgres's reserved "pg_" schemas and views from the "information_schema".
// Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results.
sqlInspectTables = `
SELECT "table_schema", "table_name"
FROM information_schema.tables
SELECT
"t".table_schema,
"t".table_name,
pk.name AS primary_key_name,
pk.columns AS primary_key_columns
FROM information_schema.tables "t"
LEFT JOIN (
SELECT i.indrelid, "idx".relname AS "name", ARRAY_AGG("a".attname) AS "columns"
FROM pg_index i
JOIN pg_attribute "a"
ON "a".attrelid = i.indrelid
AND "a".attnum = ANY("i".indkey)
AND i.indisprimary
JOIN pg_class "idx" ON i.indexrelid = "idx".oid
GROUP BY 1, 2
) pk
ON ("t".table_schema || '.' || "t".table_name)::regclass = pk.indrelid
WHERE table_type = 'BASE TABLE'
AND "table_schema" <> 'information_schema'
AND "table_schema" NOT LIKE 'pg_%'
AND "t".table_schema <> 'information_schema'
AND "t".table_schema NOT LIKE 'pg_%'
AND "table_name" NOT IN (?)
ORDER BY "table_schema", "table_name"
ORDER BY "t".table_schema, "t".table_name
`

// sqlInspectColumnsQuery retrieves column definitions for the specified table.
Expand All @@ -166,7 +194,6 @@ SELECT
ELSE "c".column_default
END AS "default",
"c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr,
'p' = ANY("c".constraint_type) AS is_pk,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
Expand Down
50 changes: 35 additions & 15 deletions internal/dbtest/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
Columns: map[string]sqlschema.Column{
"office_name": {
SQLType: sqltype.VarChar,
IsPK: true,
},
"publisher_id": {
SQLType: sqltype.VarChar,
Expand All @@ -111,30 +110,28 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
IsNullable: true,
},
},
PK: &sqlschema.PK{Columns: sqlschema.NewComposite("office_name")},
},
{
Schema: defaultSchema,
Name: "articles",
Columns: map[string]sqlschema.Column{
"isbn": {
SQLType: "bigint",
IsPK: true,
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: true,
DefaultValue: "",
},
"editor": {
SQLType: sqltype.VarChar,
IsPK: false,
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: false,
DefaultValue: "john doe",
},
"title": {
SQLType: sqltype.VarChar,
IsPK: false,
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: false,
Expand All @@ -143,23 +140,20 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
"locale": {
SQLType: sqltype.VarChar,
VarcharLen: 5,
IsPK: false,
IsNullable: true,
IsAutoIncrement: false,
IsIdentity: false,
DefaultValue: "en-GB",
},
"page_count": {
SQLType: "smallint",
IsPK: false,
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: false,
DefaultValue: "1",
},
"book_count": {
SQLType: "integer",
IsPK: false,
IsNullable: false,
IsAutoIncrement: true,
IsIdentity: false,
Expand All @@ -172,6 +166,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
SQLType: "bigint",
},
},
PK: &sqlschema.PK{Columns: sqlschema.NewComposite("isbn")},
UniqueContraints: []sqlschema.Unique{
{Columns: sqlschema.NewComposite("editor", "title")},
},
Expand All @@ -182,7 +177,6 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
Columns: map[string]sqlschema.Column{
"author_id": {
SQLType: "bigint",
IsPK: true,
IsIdentity: true,
},
"first_name": {
Expand All @@ -195,6 +189,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
SQLType: sqltype.VarChar,
},
},
PK: &sqlschema.PK{Columns: sqlschema.NewComposite("author_id")},
UniqueContraints: []sqlschema.Unique{
{Columns: sqlschema.NewComposite("first_name", "last_name")},
{Columns: sqlschema.NewComposite("email")},
Expand All @@ -206,21 +201,19 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
Columns: map[string]sqlschema.Column{
"publisher_id": {
SQLType: sqltype.VarChar,
IsPK: true,
},
"author_id": {
SQLType: "bigint",
IsPK: true,
},
},
PK: &sqlschema.PK{Columns: sqlschema.NewComposite("publisher_id", "author_id")},
},
{
Schema: defaultSchema,
Name: "publishers",
Columns: map[string]sqlschema.Column{
"publisher_id": {
SQLType: sqltype.VarChar,
IsPK: true,
DefaultValue: "gen_random_uuid()",
},
"publisher_name": {
Expand All @@ -232,6 +225,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
IsNullable: true,
},
},
PK: &sqlschema.PK{Columns: sqlschema.NewComposite("publisher_id")},
UniqueContraints: []sqlschema.Unique{
{Columns: sqlschema.NewComposite("publisher_id", "publisher_name")},
},
Expand Down Expand Up @@ -301,7 +295,7 @@ func mustCreateSchema(tb testing.TB, ctx context.Context, db *bun.DB, schema str
func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got []sqlschema.Table) {
tb.Helper()

require.Equal(tb, tableNames(want), tableNames(got), "different set of tables")
require.ElementsMatch(tb, tableNames(want), tableNames(got), "different set of tables")

// Now we are guaranteed to have the same tables.
for _, wt := range want {
Expand Down Expand Up @@ -345,15 +339,15 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w
}

if wantCol.IsNullable != gotCol.IsNullable {
errorf("isNullable:\n\t(+want)\t%s\n\t(-got)\t%s", wantCol.IsNullable, gotCol.IsNullable)
errorf("isNullable:\n\t(+want)\t%t\n\t(-got)\t%t", wantCol.IsNullable, gotCol.IsNullable)
}

if wantCol.IsAutoIncrement != gotCol.IsAutoIncrement {
errorf("IsAutoIncrement:\n\t(+want)\t%s\n\t(-got)\t%s", wantCol.IsAutoIncrement, gotCol.IsAutoIncrement)
errorf("IsAutoIncrement:\n\t(+want)\t%s\b\t(-got)\t%t", wantCol.IsAutoIncrement, gotCol.IsAutoIncrement)
}

if wantCol.IsIdentity != gotCol.IsIdentity {
errorf("IsIdentity:\n\t(+want)\t%s\n\t(-got)\t%s", wantCol.IsIdentity, gotCol.IsIdentity)
errorf("IsIdentity:\n\t(+want)\t%t\n\t(-got)\t%t", wantCol.IsIdentity, gotCol.IsIdentity)
}
}

Expand Down Expand Up @@ -381,6 +375,13 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w
func cmpConstraints(tb testing.TB, want, got sqlschema.Table) {
tb.Helper()

if want.PK != nil {
require.NotNilf(tb, got.PK, "table %q missing primary key, want: (%s)", want.Name, want.PK.Columns)
require.Equalf(tb, want.PK.Columns, got.PK.Columns, "table %q has wrong primary key", want.Name)
} else {
require.Nilf(tb, got.PK, "table %q shouldn't have a primary key", want.Name)
}

// Only keep columns included in each unique constraint for comparison.
stripNames := func(uniques []sqlschema.Unique) (res []string) {
for _, u := range uniques {
Expand Down Expand Up @@ -496,5 +497,24 @@ func TestSchemaInspector_Inspect(t *testing.T) {
require.Len(t, got.Tables, 1)
cmpConstraints(t, want, got.Tables[0])
})
t.Run("collects primary keys", func(t *testing.T) {
type Model struct {
ID string `bun:",pk"`
Email string `bun:",pk"`
Birthday time.Time `bun:",notnull"`
}

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewSchemaInspector(tables)
want := sqlschema.NewComposite("id", "email")

got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

require.Len(t, got.Tables, 1)
require.NotNilf(t, got.Tables[0].PK, "did not register primary key, want (%s)", want)
require.Equal(t, want, got.Tables[0].PK.Columns, "wrong primary key columns")
})
})
}
Loading

0 comments on commit a734629

Please sign in to comment.