Skip to content

Commit

Permalink
feat: add support for ON UPDATE and ON DELETE rules on belongs-to rel…
Browse files Browse the repository at this point in the history
…ationships from struct tags (#533)

* feat: add ON UPDATE and ON DELETE rules to belongs-to struct tag

Co-authored-by: Francesco Cartier <[email protected]>
  • Loading branch information
antipopp and Francesco Cartier authored Jun 8, 2022
1 parent 7b168ea commit a327b2a
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 3 deletions.
89 changes: 89 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ func TestDB(t *testing.T) {
{testJSONValuer},
{testSelectBool},
{testFKViolation},
{testWithForeignKeysAndRules},
{testWithForeignKeys},
{testInterfaceAny},
{testInterfaceJSON},
Expand Down Expand Up @@ -869,6 +870,94 @@ func testFKViolation(t *testing.T, db *bun.DB) {
require.Equal(t, 0, n)
}

func testWithForeignKeysAndRules(t *testing.T, db *bun.DB) {
type User struct {
ID int `bun:",pk"`
Type string `bun:",pk"`
Name string
}
type Deck struct {
ID int `bun:",pk"`
UserID int
UserType string
User *User `bun:"rel:belongs-to,join:user_id=id,join:user_type=type,on_update:cascade,on_delete:set null"`
}

if db.Dialect().Name() == dialect.SQLite {
_, err := db.Exec("PRAGMA foreign_keys = ON;")
require.NoError(t, err)
}

for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} {
_, err := db.NewDropTable().Model(model).IfExists().Exec(ctx)
require.NoError(t, err)
}

_, err := db.NewCreateTable().
Model((*User)(nil)).
IfNotExists().
Exec(ctx)
require.NoError(t, err)

_, err = db.NewCreateTable().
Model((*Deck)(nil)).
IfNotExists().
WithForeignKeys().
Exec(ctx)
require.NoError(t, err)

// Empty deck should violate FK constraint.
_, err = db.NewInsert().Model(new(Deck)).Exec(ctx)
require.Error(t, err)

// Create a deck that violates the user_id FK contraint
deck := &Deck{UserID: 42}

_, err = db.NewInsert().Model(deck).Exec(ctx)
require.Error(t, err)

decks := []*Deck{deck}
_, err = db.NewInsert().Model(&decks).Exec(ctx)
require.Error(t, err)

n, err := db.NewSelect().Model((*Deck)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)

_, err = db.NewInsert().Model(&User{ID: 1, Type: "admin", Name: "root"}).Exec(ctx)
require.NoError(t, err)
res, err := db.NewInsert().Model(&Deck{UserID: 1, UserType: "admin"}).Exec(ctx)
require.NoError(t, err)

affected, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), affected)

// Update User ID and check for FK update
res, err = db.NewUpdate().Model(&User{}).Where("id = ?", 1).Where("type = ?", "admin").Set("id = ?", 2).Exec(ctx)
require.NoError(t, err)

affected, err = res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), affected)

n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 1").Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)

n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 2").Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, n)

// Delete user and check for FK delete
_, err = db.NewDelete().Model(&User{}).Where("id = ?", 2).Exec(ctx)
require.NoError(t, err)

n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 2").Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)
}

func testWithForeignKeys(t *testing.T, db *bun.DB) {
type User struct {
ID int `bun:",pk,autoincrement"`
Expand Down
9 changes: 6 additions & 3 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,16 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery {
func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery {
for _, relation := range q.tableModel.Table().Relations {
if relation.Type == schema.ManyToManyRelation ||
relation.Type == schema.HasManyRelation {
relation.Type == schema.HasManyRelation {
continue
}
q = q.ForeignKey("(?) REFERENCES ? (?)",
}

q = q.ForeignKey("(?) REFERENCES ? (?) ? ?",
Safe(appendColumns(nil, "", relation.BaseFields)),
relation.JoinTable.SQLName,
Safe(appendColumns(nil, "", relation.JoinFields)),
Safe(relation.OnUpdate),
Safe(relation.OnDelete),
)
}
return q
Expand Down
2 changes: 2 additions & 0 deletions schema/relation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type Relation struct {
JoinTable *Table
BaseFields []*Field
JoinFields []*Field
OnUpdate string
OnDelete string

PolymorphicField *Field
PolymorphicValue string
Expand Down
42 changes: 42 additions & 0 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,35 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
JoinTable: joinTable,
}

rel.OnUpdate = "ON UPDATE NO ACTION"
if onUpdate, ok := field.Tag.Options["on_update"]; ok {
if len(onUpdate) > 1 {
panic(fmt.Errorf("bun: %s belongs-to %s: on_update option must be a single field", t.TypeName, field.GoName))
}

rule := strings.ToUpper(onUpdate[0])
if !isKnownFKRule(rule) {
internal.Warn.Printf("bun: %s belongs-to %s: unknown on_update rule %s", t.TypeName, field.GoName, rule)
}

s := fmt.Sprintf("ON UPDATE %s", rule)
rel.OnUpdate = s
}

rel.OnDelete = "ON DELETE NO ACTION"
if onDelete, ok := field.Tag.Options["on_delete"]; ok {
if len(onDelete) > 1 {
panic(fmt.Errorf("bun: %s belongs-to %s: on_delete option must be a single field", t.TypeName, field.GoName))
}

rule := strings.ToUpper(onDelete[0])
if !isKnownFKRule(rule) {
internal.Warn.Printf("bun: %s belongs-to %s: unknown on_delete rule %s", t.TypeName, field.GoName, rule)
}
s := fmt.Sprintf("ON DELETE %s", rule)
rel.OnDelete = s
}

if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
Expand Down Expand Up @@ -859,13 +888,26 @@ func isKnownFieldOption(name string) bool {
"autoincrement",
"rel",
"join",
"on_update",
"on_delete",
"m2m",
"polymorphic":
return true
}
return false
}

func isKnownFKRule(name string) bool {
switch name {
case "CASCADE",
"RESTRICT",
"SET NULL",
"SET DEFAULT":
return true
}
return false
}

func removeField(fields []*Field, field *Field) []*Field {
for i, f := range fields {
if f == field {
Expand Down

0 comments on commit a327b2a

Please sign in to comment.