Skip to content

Commit

Permalink
Merge pull request #632 from kmpm/fix/CompositeIn
Browse files Browse the repository at this point in the history
Fix/composite in
  • Loading branch information
vmihailenco authored Aug 8, 2022
2 parents bf04f2a + 91348c5 commit efe03f4
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 2 deletions.
1 change: 1 addition & 0 deletions dialect/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ const (
UpdateFromTable
MSSavepoint
GeneratedIdentity
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
)
3 changes: 2 additions & 1 deletion dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ func New() *Dialect {
feature.TableNotExists |
feature.InsertOnConflict |
feature.SelectExists |
feature.GeneratedIdentity
feature.GeneratedIdentity |
feature.CompositeIn
return d
}

Expand Down
3 changes: 2 additions & 1 deletion dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func New() *Dialect {
feature.DeleteTableAlias |
feature.InsertOnConflict |
feature.TableNotExists |
feature.SelectExists
feature.SelectExists |
feature.CompositeIn
return d
}

Expand Down
29 changes: 29 additions & 0 deletions internal/dbtest/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func TestORM(t *testing.T) {
{testRelationExcludeAll},
{testM2MRelationExcludeColumn},
{testRelationBelongsToSelf},
{testCompositeHasMany},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -429,6 +430,18 @@ func testM2MRelationExcludeColumn(t *testing.T, db *bun.DB) {
require.NoError(t, err)
}

func testCompositeHasMany(t *testing.T, db *bun.DB) {
department := new(Department)
err := db.NewSelect().
Model(department).
Where("company_no=? AND no=?", "company one", "hr").
Relation("Employees").
Scan(ctx)
require.NoError(t, err)
require.Equal(t, "hr", department.No)
require.Equal(t, 2, len(department.Employees))
}

type Genre struct {
ID int `bun:",pk"`
Name string
Expand Down Expand Up @@ -530,6 +543,20 @@ type Comment struct {
Text string
}

type Department struct {
bun.BaseModel `bun:"alias:d"`
CompanyNo string `bun:",pk"`
No string `bun:",pk"`
Employees []Employee `bun:"rel:has-many,join:company_no=company_no,join:no=department_no"`
}

type Employee struct {
bun.BaseModel `bun:"alias:p"`
CompanyNo string `bun:",pk"`
DepartmentNo string `bun:",pk"`
Name string `bun:",pk"`
}

func createTestSchema(t *testing.T, db *bun.DB) {
_ = db.Table(reflect.TypeOf((*BookGenre)(nil)).Elem())

Expand All @@ -541,6 +568,8 @@ func createTestSchema(t *testing.T, db *bun.DB) {
(*BookGenre)(nil),
(*Translation)(nil),
(*Comment)(nil),
(*Department)(nil),
(*Employee)(nil),
}
for _, model := range models {
_, err := db.NewDropTable().Model(model).IfExists().Exec(ctx)
Expand Down
22 changes: 22 additions & 0 deletions internal/dbtest/testdata/fixture.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,25 @@
- trackable_id: 1000
trackable_type: translation
text: comment3

- model: Department
rows:
- company_no: company one
no: accounting
- company_no: company one
no: 'hr'

- model: Employee
rows:
- company_no: company one
department_no: accounting
name: 'adam'
- company_no: company one
department_no: accounting
name: 'bravo'
- company_no: company one
department_no: hr
name: 'charlie'
- company_no: company one
department_no: hr
name: 'foxtrot'
84 changes: 84 additions & 0 deletions relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"reflect"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
Expand Down Expand Up @@ -60,6 +61,14 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
q = q.Model(hasManyModel)

var where []byte

if q.db.dialect.Features().Has(feature.CompositeIn) {
return j.manyQueryCompositeIn(where, q)
}
return j.manyQueryMulti(where, q)
}

func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery {
if len(j.Relation.JoinFields) > 1 {
where = append(where, '(')
}
Expand Down Expand Up @@ -88,6 +97,29 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
return q
}

func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery {
where = appendMultiValues(
q.db.Formatter(),
where,
j.JoinModel.rootValue(),
j.JoinModel.parentIndex(),
j.Relation.BaseFields,
j.Relation.JoinFields,
j.JoinModel.Table().SQLAlias,
)

q = q.Where(internal.String(where))

if j.Relation.PolymorphicField != nil {
q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
}

j.applyTo(q)
q = q.Apply(j.hasManyColumns)

return q
}

func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery {
b := make([]byte, 0, 32)

Expand Down Expand Up @@ -312,3 +344,55 @@ func appendChildValues(
}
return b
}

// appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID
// but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions.
func appendMultiValues(
fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe,
) []byte {
// This is based on a mix of appendChildValues and query_base.appendColumns

// These should never missmatch in length but nice to know if it does
if len(joinFields) != len(baseFields) {
panic("not reached")
}

// walk the relations
b = append(b, '(')
seen := make(map[string]struct{})
walk(v, index, func(v reflect.Value) {
start := len(b)
for i, f := range baseFields {
if i > 0 {
b = append(b, " AND "...)
}
if len(baseFields) > 1 {
b = append(b, '(')
}
// Field name
b = append(b, joinTable...)
b = append(b, '.')
b = append(b, []byte(joinFields[i].SQLName)...)

// Equals value
b = append(b, '=')
b = f.AppendValue(fmter, b, v)
if len(baseFields) > 1 {
b = append(b, ')')
}
}

b = append(b, ") OR ("...)

if _, ok := seen[string(b[start:])]; ok {
b = b[:start]
} else {
seen[string(b[start:])] = struct{}{}
}
})
if len(seen) > 0 {
b = b[:len(b)-6] // trim ") OR ("
}
b = append(b, ')')
return b
}

0 comments on commit efe03f4

Please sign in to comment.