Skip to content

Commit

Permalink
Add support for composite keys in create table statements (#413)
Browse files Browse the repository at this point in the history
This PR adds support for setting a composite key for a table.
From now on it is possible to set `pk` to `true` in multiple columns in
`create_table`.

The create table statement is translated to the following format:

```sql
CREATE TABLE my_table (
    id SERIAL,
    code VARCHAR(255),
    count INTEGER,
    PRIMARY KEY (id, code)
);
```
  • Loading branch information
kvch authored Oct 18, 2024
1 parent 515dd54 commit 12ae369
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 38 deletions.
22 changes: 22 additions & 0 deletions examples/01_create_tables.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@
}
]
}
},
{
"create_table": {
"name": "sellers",
"columns": [
{
"name": "name",
"type": "varchar(255)",
"pk": true
},
{
"name": "zip",
"type": "integer",
"pk": true
},
{
"name": "description",
"type": "varchar(255)",
"nullable": true
}
]
}
}
]
}
51 changes: 50 additions & 1 deletion pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table,
o.Column.Check = nil

o.Column.Name = TemporaryName(o.Column.Name)
colSQL, err := ColumnToSQL(o.Column, tr)
columnWriter := ColumnSQLWriter{WithPK: true, Transformer: tr}
colSQL, err := columnWriter.Write(o.Column)
if err != nil {
return err
}
Expand Down Expand Up @@ -243,3 +244,51 @@ func NotNullConstraintName(columnName string) string {
func IsNotNullConstraintName(name string) bool {
return strings.HasPrefix(name, "_pgroll_check_not_null_")
}

// ColumnSQLWriter writes a column to SQL
// It can optionally include the primary key constraint
// When creating a table, the primary key constraint is not added to the column definition
type ColumnSQLWriter struct {
WithPK bool
Transformer SQLTransformer
}

func (w ColumnSQLWriter) Write(col Column) (string, error) {
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)

if w.WithPK && col.IsPrimaryKey() {
sql += " PRIMARY KEY"
}

if col.IsUnique() {
sql += " UNIQUE"
}
if !col.IsNullable() {
sql += " NOT NULL"
}
if col.Default != nil {
d, err := w.Transformer.TransformSQL(*col.Default)
if err != nil {
return "", err
}
sql += fmt.Sprintf(" DEFAULT %s", d)
}
if col.References != nil {
onDelete := "NO ACTION"
if col.References.OnDelete != "" {
onDelete = strings.ToUpper(string(col.References.OnDelete))
}

sql += fmt.Sprintf(" CONSTRAINT %s REFERENCES %s(%s) ON DELETE %s",
pq.QuoteIdentifier(col.References.Name),
pq.QuoteIdentifier(col.References.Table),
pq.QuoteIdentifier(col.References.Column),
onDelete)
}
if col.Check != nil {
sql += fmt.Sprintf(" CONSTRAINT %s CHECK (%s)",
pq.QuoteIdentifier(col.Check.Name),
col.Check.Constraint)
}
return sql, nil
}
29 changes: 29 additions & 0 deletions pkg/migrations/op_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ func ColumnMustNotHaveComment(t *testing.T, db *sql.DB, schema, table, column st
}
}

func ColumnMustBePK(t *testing.T, db *sql.DB, schema, table, column string) {
t.Helper()
if !columnMustBePK(t, db, schema, table, column) {
t.Fatalf("Expected column %q to be primary key", column)
}
}

func TableMustHaveComment(t *testing.T, db *sql.DB, schema, table, expectedComment string) {
t.Helper()
if !tableHasComment(t, db, schema, table, expectedComment) {
Expand Down Expand Up @@ -526,6 +533,28 @@ func columnHasComment(t *testing.T, db *sql.DB, schema, table, column string, ex
return actualComment != nil && *expectedComment == *actualComment
}

func columnMustBePK(t *testing.T, db *sql.DB, schema, table, column string) bool {
t.Helper()

var exists bool
err := db.QueryRow(fmt.Sprintf(`
SELECT EXISTS (
SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid
AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = %[1]s::regclass AND i.indisprimary AND a.attname = %[2]s
)`,
pq.QuoteLiteral(fmt.Sprintf("%s.%s", schema, table)),
pq.QuoteLiteral(column)),
).Scan(&exists)
if err != nil {
t.Fatal(err)
}

return exists
}

func tableHasComment(t *testing.T, db *sql.DB, schema, table, expectedComment string) bool {
t.Helper()

Expand Down
44 changes: 7 additions & 37 deletions pkg/migrations/op_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,55 +113,25 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error {

func columnsToSQL(cols []Column, tr SQLTransformer) (string, error) {
var sql string
var primaryKeys []string
columnWriter := ColumnSQLWriter{WithPK: false, Transformer: tr}
for i, col := range cols {
if i > 0 {
sql += ", "
}
colSQL, err := ColumnToSQL(col, tr)
colSQL, err := columnWriter.Write(col)
if err != nil {
return "", err
}
sql += colSQL
}
return sql, nil
}

// ColumnToSQL generates the SQL for a column definition.
func ColumnToSQL(col Column, tr SQLTransformer) (string, error) {
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)

if col.IsPrimaryKey() {
sql += " PRIMARY KEY"
}
if col.IsUnique() {
sql += " UNIQUE"
}
if !col.IsNullable() {
sql += " NOT NULL"
}
if col.Default != nil {
d, err := tr.TransformSQL(*col.Default)
if err != nil {
return "", err
if col.IsPrimaryKey() {
primaryKeys = append(primaryKeys, pq.QuoteIdentifier(col.Name))
}
sql += fmt.Sprintf(" DEFAULT %s", d)
}
if col.References != nil {
onDelete := "NO ACTION"
if col.References.OnDelete != "" {
onDelete = strings.ToUpper(string(col.References.OnDelete))
}

sql += fmt.Sprintf(" CONSTRAINT %s REFERENCES %s(%s) ON DELETE %s",
pq.QuoteIdentifier(col.References.Name),
pq.QuoteIdentifier(col.References.Table),
pq.QuoteIdentifier(col.References.Column),
onDelete)
}
if col.Check != nil {
sql += fmt.Sprintf(" CONSTRAINT %s CHECK (%s)",
pq.QuoteIdentifier(col.Check.Name),
col.Check.Constraint)
if len(primaryKeys) > 0 {
sql += fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ", "))
}
return sql, nil
}
83 changes: 83 additions & 0 deletions pkg/migrations/op_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,89 @@ func TestCreateTable(t *testing.T) {
}, rows)
},
},
{
name: "create table with composite key",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "rand",
Type: "varchar(255)",
Pk: ptr(true),
},
{
Name: "name",
Type: "varchar(255)",
Unique: ptr(true),
},
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// The new view exists in the new version schema.
ViewMustExist(t, db, schema, "01_create_table", "users")

// Data can be inserted into the new view.
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
"rand": "123",
"name": "Alice",
})
// New record with same keys cannot be inserted.
MustNotInsert(t, db, schema, "01_create_table", "users", map[string]string{
"id": "1",
"rand": "123",
"name": "Malice",
}, testutils.UniqueViolationErrorCode)

// Data can be retrieved from the new view.
rows := MustSelect(t, db, schema, "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "rand": "123", "name": "Alice"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The underlying table has been dropped.
TableMustNotExist(t, db, schema, "users")
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// The view still exists
ViewMustExist(t, db, schema, "01_create_table", "users")

// The columns are still primary keys.
ColumnMustBePK(t, db, schema, "users", "id")
ColumnMustBePK(t, db, schema, "users", "rand")

// Data can be inserted into the new view.
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
"rand": "123",
"name": "Alice",
})

// New record with same keys cannot be inserted.
MustNotInsert(t, db, schema, "01_create_table", "users", map[string]string{
"id": "1",
"rand": "123",
"name": "Malice",
}, testutils.UniqueViolationErrorCode)

// Data can be retrieved from the new view.
rows := MustSelect(t, db, schema, "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "rand": "123", "name": "Alice"},
}, rows)
},
},
{
name: "create table with foreign key with default ON DELETE NO ACTION",
migrations: []migrations.Migration{
Expand Down

0 comments on commit 12ae369

Please sign in to comment.