Skip to content

Commit

Permalink
feat: [#358] Add OrWhere method for DB (#910)
Browse files Browse the repository at this point in the history
* feat: [#358] Add OrWhere method for DB

* chore: update mocks

* test

* fix lint

* implement OrWhere

* fix test

* fix test

---------

Co-authored-by: hwbrzzl <[email protected]>
  • Loading branch information
hwbrzzl and hwbrzzl authored Feb 27, 2025
1 parent 027f8a0 commit e26aa18
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 96 deletions.
2 changes: 1 addition & 1 deletion contracts/database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type Query interface {
// OrderBy(column string) Query
// orderByDesc
// OrderByRaw(query string, args ...any) Query
// OrWhere(query any, args ...any) Query
OrWhere(query any, args ...any) Query
// OrWhereLike()
// OrWhereNotLike
// Pluck(column string, dest any) error
Expand Down
2 changes: 1 addition & 1 deletion database/db/conditions.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ type Conditions struct {
type Where struct {
query any
args []any
// or bool
or bool
}
138 changes: 94 additions & 44 deletions database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"context"
databasesql "database/sql"
"fmt"
"reflect"
"sort"

Expand Down Expand Up @@ -147,6 +148,18 @@ func (r *Query) Insert(data any) (*db.Result, error) {
}, nil
}

func (r *Query) OrWhere(query any, args ...any) db.Query {
q := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.table)
q.conditions = r.conditions
q.conditions.where = append(r.conditions.where, Where{
query: query,
args: args,
or: true,
})

return q
}

func (r *Query) Update(data any) (*db.Result, error) {
mapData, err := convertToMap(data)
if err != nil {
Expand Down Expand Up @@ -198,23 +211,12 @@ func (r *Query) buildDelete() (sql string, args []any, err error) {
builder = builder.PlaceholderFormat(placeholderFormat)
}

for _, where := range r.conditions.where {
query, ok := where.query.(string)
if ok {
if !str.Of(query).Trim().Contains(" ", "?") {
if len(where.args) > 1 {
builder = builder.Where(sq.Eq{query: where.args})
} else if len(where.args) == 1 {
builder = builder.Where(sq.Eq{query: where.args[0]})
}
continue
}
}

builder = builder.Where(where.query, where.args...)
sqlizer, err := r.buildWheres(r.conditions.where)
if err != nil {
return "", nil, err
}

return builder.ToSql()
return builder.Where(sqlizer).ToSql()
}

func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err error) {
Expand Down Expand Up @@ -257,24 +259,12 @@ func (r *Query) buildSelect() (sql string, args []any, err error) {
}

builder = builder.From(r.conditions.table)

for _, where := range r.conditions.where {
query, ok := where.query.(string)
if ok {
if !str.Of(query).Trim().Contains(" ", "?") {
if len(where.args) > 1 {
builder = builder.Where(sq.Eq{query: where.args})
} else if len(where.args) == 1 {
builder = builder.Where(sq.Eq{query: where.args[0]})
}
continue
}
}

builder = builder.Where(where.query, where.args...)
sqlizer, err := r.buildWheres(r.conditions.where)
if err != nil {
return "", nil, err
}

return builder.ToSql()
return builder.Where(sqlizer).ToSql()
}

func (r *Query) buildUpdate(data map[string]any) (sql string, args []any, err error) {
Expand All @@ -287,25 +277,72 @@ func (r *Query) buildUpdate(data map[string]any) (sql string, args []any, err er
builder = builder.PlaceholderFormat(placeholderFormat)
}

for _, where := range r.conditions.where {
query, ok := where.query.(string)
if ok {
if !str.Of(query).Trim().Contains(" ", "?") {
if len(where.args) > 1 {
builder = builder.Where(sq.Eq{query: where.args})
} else if len(where.args) == 1 {
builder = builder.Where(sq.Eq{query: where.args[0]})
}
continue
sqlizer, err := r.buildWheres(r.conditions.where)
if err != nil {
return "", nil, err
}

return builder.Where(sqlizer).SetMap(data).ToSql()
}

func (r *Query) buildWhere(where Where) (any, []any) {
query, ok := where.query.(string)
if ok {
if !str.Of(query).Trim().Contains(" ", "?") {
if len(where.args) > 1 {
return sq.Eq{query: where.args}, nil
} else if len(where.args) == 1 {
return sq.Eq{query: where.args[0]}, nil
}
}
}

return where.query, where.args
}

builder = builder.Where(where.query, where.args...)
func (r *Query) buildWheres(wheres []Where) (sq.Sqlizer, error) {
if len(wheres) == 0 {
return nil, nil
}

builder = builder.SetMap(data)
var sqlizers []sq.Sqlizer
for _, where := range wheres {
query, args := r.buildWhere(where)

return builder.ToSql()
sqlizer, err := r.toSqlizer(query, args)
if err != nil {
return nil, err
}

if where.or && len(sqlizers) > 0 {
// If it's an OR condition and we have previous conditions,
// wrap the previous conditions in an AND and create an OR condition
if len(sqlizers) == 1 {
sqlizers = []sq.Sqlizer{
sq.Or{
sqlizers[0],
sqlizer,
},
}
} else {
sqlizers = []sq.Sqlizer{
sq.Or{
sq.And(sqlizers),
sqlizer,
},
}
}
} else {
// For regular WHERE conditions or the first condition
sqlizers = append(sqlizers, sqlizer)
}
}

if len(sqlizers) == 1 {
return sqlizers[0], nil
}

return sq.And(sqlizers), nil
}

func (r *Query) placeholderFormat() database.PlaceholderFormat {
Expand All @@ -319,3 +356,16 @@ func (r *Query) placeholderFormat() database.PlaceholderFormat {
func (r *Query) trace(sql string, args []any, rowsAffected int64, err error) {
r.logger.Trace(r.ctx, carbon.Now(), r.driver.Explain(sql, args...), rowsAffected, err)
}

func (r *Query) toSqlizer(query any, args []any) (sq.Sqlizer, error) {
switch q := query.(type) {
case map[string]any:
return sq.Eq(q), nil
case string:
return sq.Expr(q, args...), nil
case sq.Sqlizer:
return q, nil
default:
return nil, errors.DatabaseUnsupportedType.Args(fmt.Sprintf("%T", query), "string-keyed map or string or squirrel.Sqlizer")
}
}
87 changes: 59 additions & 28 deletions database/db/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ func (s *QueryTestSuite) TestDelete() {
mockResult.On("RowsAffected").Return(int64(1), nil)

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return("DELETE FROM users WHERE name = \"John\" AND id = 1").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "DELETE FROM users WHERE name = \"John\" AND id = 1", int64(1), nil).Return().Once()
s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE (name = ? AND id = ?)", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("DELETE FROM users WHERE (name = ? AND id = ?)", "John", 1).Return("DELETE FROM users WHERE (name = \"John\" AND id = 1)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "DELETE FROM users WHERE (name = \"John\" AND id = 1)", int64(1), nil).Return().Once()

result, err := s.query.Where("name", "John").Where("id", 1).Delete()
s.Nil(err)
Expand All @@ -69,9 +69,9 @@ func (s *QueryTestSuite) TestDelete() {

s.Run("failed to exec", func() {
s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return(nil, assert.AnError).Once()
s.mockDriver.EXPECT().Explain("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return("DELETE FROM users WHERE name = \"John\" AND id = 1").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "DELETE FROM users WHERE name = \"John\" AND id = 1", int64(-1), assert.AnError).Return().Once()
s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE (name = ? AND id = ?)", "John", 1).Return(nil, assert.AnError).Once()
s.mockDriver.EXPECT().Explain("DELETE FROM users WHERE (name = ? AND id = ?)", "John", 1).Return("DELETE FROM users WHERE (name = \"John\" AND id = 1)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "DELETE FROM users WHERE (name = \"John\" AND id = 1)", int64(-1), assert.AnError).Return().Once()

_, err := s.query.Where("name", "John").Where("id", 1).Delete()
s.Equal(assert.AnError, err)
Expand All @@ -82,9 +82,9 @@ func (s *QueryTestSuite) TestDelete() {
mockResult.On("RowsAffected").Return(int64(0), assert.AnError).Once()

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return("DELETE FROM users WHERE name = \"John\" AND id = 1").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "DELETE FROM users WHERE name = \"John\" AND id = 1", int64(-1), assert.AnError).Return().Once()
s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE (name = ? AND id = ?)", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("DELETE FROM users WHERE (name = ? AND id = ?)", "John", 1).Return("DELETE FROM users WHERE (name = \"John\" AND id = 1)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "DELETE FROM users WHERE (name = \"John\" AND id = 1)", int64(-1), assert.AnError).Return().Once()

_, err := s.query.Where("name", "John").Where("id", 1).Delete()
s.Equal(assert.AnError, err)
Expand Down Expand Up @@ -292,9 +292,9 @@ func (s *QueryTestSuite) TestUpdate() {
mockResult.On("RowsAffected").Return(int64(1), nil)

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return("UPDATE users SET phone = \"1234567890\" WHERE name = \"John\" AND id = 1").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET phone = \"1234567890\" WHERE name = \"John\" AND id = 1", int64(1), nil).Return().Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE (name = ? AND id = ?)", "1234567890", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET phone = ? WHERE (name = ? AND id = ?)", "1234567890", "John", 1).Return("UPDATE users SET phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)", int64(1), nil).Return().Once()

result, err := s.query.Where("name", "John").Where("id", 1).Update(user)
s.Nil(err)
Expand All @@ -314,9 +314,9 @@ func (s *QueryTestSuite) TestUpdate() {
mockResult.On("RowsAffected").Return(int64(1), nil)

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET age = ?, name = ?, phone = ? WHERE name = ? AND id = ?", 25, "John", "1234567890", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET age = ?, name = ?, phone = ? WHERE name = ? AND id = ?", 25, "John", "1234567890", "John", 1).Return("UPDATE users SET age = 25, name = \"John\", phone = \"1234567890\" WHERE name = \"John\" AND id = 1").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET age = 25, name = \"John\", phone = \"1234567890\" WHERE name = \"John\" AND id = 1", int64(1), nil).Return().Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET age = ?, name = ?, phone = ? WHERE (name = ? AND id = ?)", 25, "John", "1234567890", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET age = ?, name = ?, phone = ? WHERE (name = ? AND id = ?)", 25, "John", "1234567890", "John", 1).Return("UPDATE users SET age = 25, name = \"John\", phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET age = 25, name = \"John\", phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)", int64(1), nil).Return().Once()

result, err := s.query.Where("name", "John").Where("id", 1).Update(user)
s.Nil(err)
Expand All @@ -333,9 +333,9 @@ func (s *QueryTestSuite) TestUpdate() {
}

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return(nil, assert.AnError).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return("UPDATE users SET phone = \"1234567890\" WHERE name = \"John\" AND id = 1").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET phone = \"1234567890\" WHERE name = \"John\" AND id = 1", int64(-1), assert.AnError).Return().Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE (name = ? AND id = ?)", "1234567890", "John", 1).Return(nil, assert.AnError).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET phone = ? WHERE (name = ? AND id = ?)", "1234567890", "John", 1).Return("UPDATE users SET phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)", int64(-1), assert.AnError).Return().Once()

result, err := s.query.Where("name", "John").Where("id", 1).Update(user)
s.Nil(result)
Expand All @@ -353,9 +353,9 @@ func (s *QueryTestSuite) TestUpdate() {
mockResult.On("RowsAffected").Return(int64(0), assert.AnError).Once()

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return("UPDATE users SET phone = \"1234567890\" WHERE name = \"John\" AND id = 1").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET phone = \"1234567890\" WHERE name = \"John\" AND id = 1", int64(-1), assert.AnError).Return().Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE (name = ? AND id = ?)", "1234567890", "John", 1).Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET phone = ? WHERE (name = ? AND id = ?)", "1234567890", "John", 1).Return("UPDATE users SET phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET phone = \"1234567890\" WHERE (name = \"John\" AND id = 1)", int64(-1), assert.AnError).Return().Once()

result, err := s.query.Where("name", "John").Where("id", 1).Update(user)
s.Nil(result)
Expand All @@ -367,19 +367,19 @@ func (s *QueryTestSuite) TestWhere() {
now := carbon.Now()
carbon.SetTestNow(now)

s.Run("simple where condition", func() {
s.Run("simple condition", func() {
var user TestUser

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE name = ?", "John").Return(nil).Once()
s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE name = ?", "John").Return("SELECT * FROM users WHERE name = \"John\"").Once()
s.mockLogger.EXPECT().Trace(s.ctx, now, "SELECT * FROM users WHERE name = \"John\"", int64(1), nil).Return().Once()
s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE (name = ? AND age = ?)", "John", 25).Return(nil).Once()
s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE (name = ? AND age = ?)", "John", 25).Return("SELECT * FROM users WHERE (name = \"John\" AND age = 25)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, now, "SELECT * FROM users WHERE (name = \"John\" AND age = 25)", int64(1), nil).Return().Once()

err := s.query.Where("name", "John").First(&user)
err := s.query.Where("name", "John").Where("age", 25).First(&user)
s.Nil(err)
})

s.Run("where with multiple arguments", func() {
s.Run("multiple arguments", func() {
var users []TestUser

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
Expand All @@ -391,7 +391,7 @@ func (s *QueryTestSuite) TestWhere() {
s.Nil(err)
})

s.Run("where with raw query", func() {
s.Run("raw query", func() {
var users []TestUser

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
Expand All @@ -402,6 +402,37 @@ func (s *QueryTestSuite) TestWhere() {
err := s.query.Where("age > ?", 18).Get(&users)
s.Nil(err)
})

// s.Run("nested condition", func() {
// var users []TestUser

// s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
// s.mockBuilder.EXPECT().Select(&users, "SELECT * FROM users WHERE age IN (?,?) AND name = ?", 25, 30, "John").Return(nil).Once()
// s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE age IN (?,?) AND name = ?", 25, 30, "John").Return("SELECT * FROM users WHERE age IN (25,30) AND name = \"John\"").Once()
// s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE age IN (25,30) AND name = \"John\"", int64(0), nil).Return().Once()

// err := s.query.Where(func(query db.Query) {
// query.Where("age", []int{25, 30}).Where("name", "John")
// }).Get(&users)
// s.Nil(err)
// })
}

func (s *QueryTestSuite) TestOrWhere() {
now := carbon.Now()
carbon.SetTestNow(now)

s.Run("simple condition", func() {
var user TestUser

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE (((name = ? AND age = ?) OR age = ?) OR name = ?)", "John", 25, 30, "Jane").Return(nil).Once()
s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE (((name = ? AND age = ?) OR age = ?) OR name = ?)", "John", 25, 30, "Jane").Return("SELECT * FROM users WHERE (((name = \"John\" AND age = 25) OR age = 30) OR name = \"Jane\")").Once()
s.mockLogger.EXPECT().Trace(s.ctx, now, "SELECT * FROM users WHERE (((name = \"John\" AND age = 25) OR age = 30) OR name = \"Jane\")", int64(1), nil).Return().Once()

err := s.query.Where("name", "John").Where("age", 25).OrWhere("age", 30).OrWhere("name", "Jane").First(&user)
s.Nil(err)
})
}

// MockResult implements sql.Result interface for testing
Expand Down
3 changes: 2 additions & 1 deletion event/console/listener_make_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ func (r *ListenerMakeCommand) Handle(ctx console.Context) error {
}

if err := file.PutContent(m.GetFilePath(), r.populateStub(r.getStub(), m.GetPackageName(), m.GetStructName())); err != nil {
return err
ctx.Error(err.Error())
return nil
}

ctx.Success("Listener created successfully")
Expand Down
Loading

0 comments on commit e26aa18

Please sign in to comment.