Skip to content

Commit

Permalink
feat: [#358] Optimize the difference between db and orm (#930)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwbrzzl authored Mar 2, 2025
1 parent 61702e0 commit 62ae533
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 132 deletions.
9 changes: 5 additions & 4 deletions contracts/database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ type Query interface {
// Distinct() Query
// dump
// dumpRawSql
Delete() (*Result, error)
// Each(callback func(rows []any) error) error
Exists() (bool, error)
Find(dest any, conds ...any) error
First(dest any) error
// FirstOr
FirstOrFail(dest any) error
// decrement
Delete() (*Result, error)
Get(dest any) error
// GroupBy(column string) Query
// GroupByRaw(query string, args ...any) Query
Expand All @@ -51,12 +52,12 @@ type Query interface {
OrderByDesc(column string) Query
OrderByRaw(raw string) Query
OrWhere(query any, args ...any) Query
OrWhereBetween(column string, args []any) Query
OrWhereBetween(column string, x, y any) Query
OrWhereColumn(column1 string, column2 ...string) Query
OrWhereIn(column string, args []any) Query
OrWhereLike(column string, value string) Query
OrWhereNot(query any, args ...any) Query
OrWhereNotBetween(column string, args []any) Query
OrWhereNotBetween(column string, x, y any) Query
OrWhereNotIn(column string, args []any) Query
OrWhereNotLike(column string, value string) Query
OrWhereNotNull(column string) Query
Expand All @@ -71,7 +72,7 @@ type Query interface {
// take
// ToSql
// ToRawSql
Update(data any) (*Result, error)
Update(column any, value ...any) (*Result, error)
// updateOrInsert
// Value(column string, dest any) error
// when
Expand Down
23 changes: 10 additions & 13 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"

"github.com/goravel/framework/contracts/database"
"github.com/goravel/framework/contracts/database/db"
)

type Orm interface {
Expand Down Expand Up @@ -52,13 +53,13 @@ type Query interface {
// DB gets the underlying database connection.
DB() (*sql.DB, error)
// Delete deletes records matching given conditions, if the conditions are empty will delete all records.
Delete(value ...any) (*Result, error)
Delete(value ...any) (*db.Result, error)
// Distinct specifies distinct fields to query.
Distinct(args ...any) Query
// Driver gets the driver for the query.
Driver() string
// Exec executes raw sql
Exec(sql string, values ...any) (*Result, error)
Exec(sql string, values ...any) (*db.Result, error)
// Exists returns true if matching records exist; otherwise, it returns false.
Exists(exists *bool) error
// Find finds records that match given conditions.
Expand All @@ -79,7 +80,7 @@ type Query interface {
// return a new instance of the model initialized with those attributes.
FirstOrNew(dest any, attributes any, values ...any) error
// ForceDelete forces delete records matching given conditions.
ForceDelete(value ...any) (*Result, error)
ForceDelete(value ...any) (*db.Result, error)
// Get retrieves all rows from the database.
Get(dest any) error
// Group specifies the group method on the query.
Expand Down Expand Up @@ -114,14 +115,14 @@ type Query interface {
OrderByDesc(column string) Query
// OrWhere add an "or where" clause to the query.
OrWhere(query any, args ...any) Query
// OrWhereIn adds an "or where column in" clause to the query.
OrWhereIn(column string, values []any) Query
// OrWhereNotIn adds an "or where column not in" clause to the query.
OrWhereNotIn(column string, values []any) Query
// OrWhereBetween adds an "or where column between x and y" clause to the query.
OrWhereBetween(column string, x, y any) Query
// OrWhereIn adds an "or where column in" clause to the query.
OrWhereIn(column string, values []any) Query
// OrWhereNotBetween adds an "or where column not between x and y" clause to the query.
OrWhereNotBetween(column string, x, y any) Query
// OrWhereNotIn adds an "or where column not in" clause to the query.
OrWhereNotIn(column string, values []any) Query
// OrWhereNull adds a "or where column is null" clause to the query.
OrWhereNull(column string) Query
// Paginate the given query into a simple paginator.
Expand All @@ -131,7 +132,7 @@ type Query interface {
// Raw creates a raw query.
Raw(sql string, values ...any) Query
// Restore restores a soft deleted model.
Restore(model ...any) (*Result, error)
Restore(model ...any) (*db.Result, error)
// Rollback rolls back the changes in a transaction.
Rollback() error
// Save updates value in a database
Expand All @@ -155,7 +156,7 @@ type Query interface {
// ToRawSql returns the query as a raw SQL string.
ToRawSql() ToSql
// Update updates records with the given column and values
Update(column any, value ...any) (*Result, error)
Update(column any, value ...any) (*db.Result, error)
// UpdateOrCreate finds the first record that matches the given attributes
// or create a new one with those attributes if none was found.
UpdateOrCreate(dest any, attributes any, values any) error
Expand Down Expand Up @@ -214,10 +215,6 @@ type Cursor interface {
Scan(value any) error
}

type Result struct {
RowsAffected int64
}

type ToSql interface {
Count() string
Create(value any) string
Expand Down
39 changes: 15 additions & 24 deletions database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ func (r *Query) Exists() (bool, error) {
}

func (r *Query) Find(dest any, conds ...any) error {
if r.err != nil {
return r.err
}

var q db.Query
if len(conds) > 2 {
return errors.DatabaseInvalidArgumentNumber.Args(len(conds), "1 or 2")
Expand Down Expand Up @@ -142,10 +138,6 @@ func (r *Query) First(dest any) error {
}

func (r *Query) FirstOrFail(dest any) error {
if r.err != nil {
return r.err
}

sql, args, err := r.buildSelect()
if err != nil {
return err
Expand Down Expand Up @@ -257,13 +249,8 @@ func (r *Query) OrWhere(query any, args ...any) db.Query {
return q
}

func (r *Query) OrWhereBetween(column string, args []any) db.Query {
if len(args) != 2 {
r.err = errors.DatabaseInvalidArgumentNumber.Args(len(args), "2")
return r
}

return r.OrWhere(sq.Expr(fmt.Sprintf("%s BETWEEN ? AND ?", column), args...))
func (r *Query) OrWhereBetween(column string, x, y any) db.Query {
return r.OrWhere(sq.Expr(fmt.Sprintf("%s BETWEEN ? AND ?", column), x, y))
}

func (r *Query) OrWhereColumn(column1 string, column2 ...string) db.Query {
Expand Down Expand Up @@ -312,13 +299,8 @@ func (r *Query) OrWhereNot(query any, args ...any) db.Query {
return r.OrWhere(sq.Expr(fmt.Sprintf("NOT (%s)", sql), args...))
}

func (r *Query) OrWhereNotBetween(column string, args []any) db.Query {
if len(args) != 2 {
r.err = errors.DatabaseInvalidArgumentNumber.Args(len(args), "2")
return r
}

return r.OrWhere(sq.Expr(fmt.Sprintf("%s NOT BETWEEN ? AND ?", column), args...))
func (r *Query) OrWhereNotBetween(column string, x, y any) db.Query {
return r.OrWhere(sq.Expr(fmt.Sprintf("%s NOT BETWEEN ? AND ?", column), x, y))
}

func (r *Query) OrWhereNotIn(column string, args []any) db.Query {
Expand Down Expand Up @@ -348,8 +330,17 @@ func (r *Query) Select(columns ...string) db.Query {
return q
}

func (r *Query) Update(data any) (*db.Result, error) {
mapData, err := convertToMap(data)
func (r *Query) Update(column any, value ...any) (*db.Result, error) {
columnStr, ok := column.(string)
if ok {
if len(value) != 1 {
return nil, errors.DatabaseInvalidArgumentNumber.Args(len(value), "1")
}

return r.Update(map[string]any{columnStr: value[0]})
}

mapData, err := convertToMap(column)
if err != nil {
return nil, err
}
Expand Down
25 changes: 23 additions & 2 deletions database/db/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ func (s *QueryTestSuite) TestOrWhereBetween() {
s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE (name = ? OR age BETWEEN ? AND ?)", "John", 18, 30).Return("SELECT * FROM users WHERE (name = \"John\" OR age BETWEEN 18 AND 30)").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE (name = \"John\" OR age BETWEEN 18 AND 30)", int64(0), nil).Return().Once()

err := s.query.Where("name", "John").OrWhereBetween("age", []any{18, 30}).Get(&users)
err := s.query.Where("name", "John").OrWhereBetween("age", 18, 30).Get(&users)
s.Nil(err)
}

Expand Down Expand Up @@ -536,7 +536,7 @@ func (s *QueryTestSuite) TestOrWhereNotBetween() {
s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE (name = ? OR age NOT BETWEEN ? AND ?)", "John", 18, 30).Return("SELECT * FROM users WHERE (name = \"John\" OR age NOT BETWEEN 18 AND 30)")
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE (name = \"John\" OR age NOT BETWEEN 18 AND 30)", int64(0), nil).Return().Once()

err := s.query.Where("name", "John").OrWhereNotBetween("age", []any{18, 30}).Get(&users)
err := s.query.Where("name", "John").OrWhereNotBetween("age", 18, 30).Get(&users)
s.Nil(err)
}

Expand Down Expand Up @@ -657,6 +657,27 @@ func (s *QueryTestSuite) TestUpdate() {
mockResult.AssertExpectations(s.T())
})

s.Run("single column", func() {
mockResult := &MockResult{}
mockResult.On("RowsAffected").Return(int64(1), nil)

s.mockDriver.EXPECT().Config().Return(database.Config{}).Once()
s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE name = ?", "1234567890", "John").Return(mockResult, nil).Once()
s.mockDriver.EXPECT().Explain("UPDATE users SET phone = ? WHERE name = ?", "1234567890", "John").Return("UPDATE users SET phone = \"1234567890\" WHERE name = \"John\"").Once()
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET phone = \"1234567890\" WHERE name = \"John\"", int64(1), nil).Return().Once()

result, err := s.query.Where("name", "John").Update("phone", "1234567890")
s.Nil(err)
s.Equal(int64(1), result.RowsAffected)

mockResult.AssertExpectations(s.T())
})

s.Run("failed to update single column with wrong number of arguments", func() {
_, err := s.query.Where("name", "John").Update("phone", "1234567890", "1234567890")
s.Equal(errors.DatabaseInvalidArgumentNumber.Args(2, "1"), err)
})

s.Run("failed to exec", func() {
user := TestUser{
Phone: "1234567890",
Expand Down
31 changes: 16 additions & 15 deletions database/gorm/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/goravel/framework/contracts/config"
contractsdatabase "github.com/goravel/framework/contracts/database"
contractsdb "github.com/goravel/framework/contracts/database/db"
"github.com/goravel/framework/contracts/database/driver"
contractsorm "github.com/goravel/framework/contracts/database/orm"
"github.com/goravel/framework/contracts/log"
Expand Down Expand Up @@ -165,7 +166,7 @@ func (r *Query) DB() (*sql.DB, error) {
return r.instance.DB()
}

func (r *Query) Delete(dest ...any) (*contractsorm.Result, error) {
func (r *Query) Delete(dest ...any) (*contractsdb.Result, error) {
var (
realDest any
err error
Expand Down Expand Up @@ -194,7 +195,7 @@ func (r *Query) Delete(dest ...any) (*contractsorm.Result, error) {
return nil, err
}

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: res.RowsAffected,
}, nil
}
Expand All @@ -210,11 +211,11 @@ func (r *Query) Driver() string {
return r.dbConfig.Driver
}

func (r *Query) Exec(sql string, values ...any) (*contractsorm.Result, error) {
func (r *Query) Exec(sql string, values ...any) (*contractsdb.Result, error) {
query := r.buildConditions()
result := query.instance.Exec(sql, values...)

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: result.RowsAffected,
}, result.Error
}
Expand Down Expand Up @@ -379,7 +380,7 @@ func (r *Query) FirstOrNew(dest any, attributes any, values ...any) error {
return nil
}

func (r *Query) ForceDelete(dest ...any) (*contractsorm.Result, error) {
func (r *Query) ForceDelete(dest ...any) (*contractsdb.Result, error) {
var (
realDest any
err error
Expand Down Expand Up @@ -410,7 +411,7 @@ func (r *Query) ForceDelete(dest ...any) (*contractsorm.Result, error) {
}
}

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: res.RowsAffected,
}, res.Error
}
Expand Down Expand Up @@ -635,7 +636,7 @@ func (r *Query) Raw(sql string, values ...any) contractsorm.Query {
return r.new(r.instance.Raw(sql, values...))
}

func (r *Query) Restore(model ...any) (*contractsorm.Result, error) {
func (r *Query) Restore(model ...any) (*contractsdb.Result, error) {
var (
realModel any
err error
Expand Down Expand Up @@ -679,7 +680,7 @@ func (r *Query) Restore(model ...any) (*contractsorm.Result, error) {
return nil, err
}

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: res.RowsAffected,
}, res.Error
}
Expand Down Expand Up @@ -814,7 +815,7 @@ func (r *Query) ToRawSql() contractsorm.ToSql {
return NewToSql(r.setConditions(r.conditions), r.log, true)
}

func (r *Query) Update(column any, value ...any) (*contractsorm.Result, error) {
func (r *Query) Update(column any, value ...any) (*contractsdb.Result, error) {
query := r.buildConditions()

if _, ok := column.(string); !ok && len(value) > 0 {
Expand Down Expand Up @@ -1514,7 +1515,7 @@ func (r *Query) updated(dest any) error {
return r.event(contractsorm.EventUpdated, r.instance.Statement.Model, dest)
}

func (r *Query) update(values any) (*contractsorm.Result, error) {
func (r *Query) update(values any) (*contractsdb.Result, error) {
if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 {
return nil, errors.OrmQuerySelectAndOmitsConflict
}
Expand All @@ -1523,15 +1524,15 @@ func (r *Query) update(values any) (*contractsorm.Result, error) {
for _, val := range r.instance.Statement.Selects {
if val == Associations {
result := r.instance.Session(&gormio.Session{FullSaveAssociations: true}).Updates(values)
return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: result.RowsAffected,
}, result.Error
}
}

result := r.instance.Updates(values)

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: result.RowsAffected,
}, result.Error
}
Expand All @@ -1541,20 +1542,20 @@ func (r *Query) update(values any) (*contractsorm.Result, error) {
if val == Associations {
result := r.instance.Omit(Associations).Updates(values)

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: result.RowsAffected,
}, result.Error
}
}
result := r.instance.Updates(values)

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: result.RowsAffected,
}, result.Error
}
result := r.instance.Omit(Associations).Updates(values)

return &contractsorm.Result{
return &contractsdb.Result{
RowsAffected: result.RowsAffected,
}, result.Error
}
Expand Down
Loading

0 comments on commit 62ae533

Please sign in to comment.