diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index e853f4cad..6bc4cbcff 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -45,7 +45,7 @@ type Query interface { // Commit commits the changes in a transaction. Commit() error // Count retrieve the "count" result of the query. - Count(count *int64) error + Count() (int64, error) // Create inserts new record into the database. Create(value any) error // Cursor returns a cursor, use scan to iterate over the returned rows. @@ -61,7 +61,7 @@ type Query interface { // Exec executes raw sql Exec(sql string, values ...any) (*db.Result, error) // Exists returns true if matching records exist; otherwise, it returns false. - Exists(exists *bool) error + Exists() (bool, error) // Find finds records that match given conditions. Find(dest any, conds ...any) error // FindOrFail finds records that match given conditions or throws an error. @@ -108,11 +108,14 @@ type Query interface { // Omit specifies columns that should be omitted from the query. Omit(columns ...string) Query // Order specifies the order in which the results should be returned. + // DEPRECATED Use OrderByRaw instead. Order(value any) Query // OrderBy specifies the order should be ascending. OrderBy(column string, direction ...string) Query // OrderByDesc specifies the order should be descending. OrderByDesc(column string) Query + // OrderByRaw specifies the order should be raw. + OrderByRaw(raw string) Query // OrWhere add an "or where" clause to the query. OrWhere(query any, args ...any) Query // OrWhereBetween adds an "or where column between x and y" clause to the query. diff --git a/database/console/show_command.go b/database/console/show_command.go index 29bdedde9..893571732 100644 --- a/database/console/show_command.go +++ b/database/console/show_command.go @@ -135,12 +135,12 @@ func (r *ShowCommand) display(ctx console.Context, info databaseInfo) { ctx.TwoColumnDetail("Views", "Rows") for i := range info.Views { if !str.Of(info.Views[i].Name).StartsWith("pg_catalog", "information_schema", "spt_") { - var rows int64 - if err := r.schema.Orm().Query().Table(info.Views[i].Name).Count(&rows); err != nil { + count, err := r.schema.Orm().Query().Table(info.Views[i].Name).Count() + if err != nil { ctx.Error(err.Error()) return } - ctx.TwoColumnDetail(info.Views[i].Name, fmt.Sprintf("%d", rows)) + ctx.TwoColumnDetail(info.Views[i].Name, fmt.Sprintf("%d", count)) } } ctx.NewLine() diff --git a/database/console/show_command_test.go b/database/console/show_command_test.go index 249eab1f1..d09d08c27 100644 --- a/database/console/show_command_test.go +++ b/database/console/show_command_test.go @@ -140,8 +140,7 @@ func TestShowCommand(t *testing.T) { mockOrm.EXPECT().Query().Return(mockQuery).Once() mockQuery.EXPECT().Table("test").Return(mockQuery).Once() - var rows int64 - mockQuery.EXPECT().Count(&rows).Return(nil).Once() + mockQuery.EXPECT().Count().Return(int64(0), nil).Once() mockContext.EXPECT().NewLine().Times(4) for i := range successCaseExpected { mockContext.EXPECT().TwoColumnDetail(successCaseExpected[i][0], successCaseExpected[i][1]).Once() diff --git a/database/gorm/query.go b/database/gorm/query.go index e9d66f74d..da0a53b64 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -106,10 +106,17 @@ func (r *Query) Commit() error { return r.instance.Commit().Error } -func (r *Query) Count(count *int64) error { +func (r *Query) Count() (int64, error) { query := r.buildConditions() - return query.instance.Count(count).Error + var count int64 + + err := query.instance.Count(&count).Error + if err != nil { + return 0, err + } + + return count, nil } func (r *Query) Create(value any) error { @@ -220,10 +227,16 @@ func (r *Query) Exec(sql string, values ...any) (*contractsdb.Result, error) { }, result.Error } -func (r *Query) Exists(exists *bool) error { +func (r *Query) Exists() (bool, error) { query := r.buildConditions() - return query.instance.Select("1").Limit(1).Find(exists).Error + var exists bool + err := query.instance.Select("1").Limit(1).Find(&exists).Error + if err != nil { + return false, err + } + + return exists, nil } func (r *Query) Find(dest any, conds ...any) error { @@ -556,6 +569,7 @@ func (r *Query) Omit(columns ...string) contractsorm.Query { return r.setConditions(conditions) } +// DEPRECATED: Use OrderByRaw instead func (r *Query) Order(value any) contractsorm.Query { conditions := r.conditions conditions.order = append(r.conditions.order, value) @@ -570,11 +584,18 @@ func (r *Query) OrderBy(column string, direction ...string) contractsorm.Query { } else { orderDirection = "ASC" } - return r.Order(fmt.Sprintf("%s %s", column, orderDirection)) + return r.OrderByRaw(fmt.Sprintf("%s %s", column, orderDirection)) } func (r *Query) OrderByDesc(column string) contractsorm.Query { - return r.Order(fmt.Sprintf("%s DESC", column)) + return r.OrderByRaw(fmt.Sprintf("%s DESC", column)) +} + +func (r *Query) OrderByRaw(raw string) contractsorm.Query { + conditions := r.conditions + conditions.order = append(r.conditions.order, raw) + + return r.setConditions(conditions) } func (r *Query) Instance() *gormio.DB { @@ -582,7 +603,7 @@ func (r *Query) Instance() *gormio.DB { } func (r *Query) InRandomOrder() contractsorm.Query { - return r.Order(r.gormQuery.RandomOrder()) + return r.OrderByRaw(r.gormQuery.RandomOrder()) } func (r *Query) InTransaction() bool { @@ -613,13 +634,17 @@ func (r *Query) Paginate(page, limit int, dest any, total *int64) error { offset := (page - 1) * limit if total != nil { if query.conditions.table == nil && query.conditions.model == nil { - if err := query.Model(dest).Count(total); err != nil { + count, err := query.Model(dest).Count() + if err != nil { return err } + *total = count } else { - if err := query.Count(total); err != nil { + count, err := query.Count() + if err != nil { return err } + *total = count } } diff --git a/mocks/database/orm/Query.go b/mocks/database/orm/Query.go index 09ac028c1..4ae88abf9 100644 --- a/mocks/database/orm/Query.go +++ b/mocks/database/orm/Query.go @@ -174,22 +174,32 @@ func (_c *Query_Commit_Call) RunAndReturn(run func() error) *Query_Commit_Call { return _c } -// Count provides a mock function with given fields: count -func (_m *Query) Count(count *int64) error { - ret := _m.Called(count) +// Count provides a mock function with no fields +func (_m *Query) Count() (int64, error) { + ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Count") } - var r0 error - if rf, ok := ret.Get(0).(func(*int64) error); ok { - r0 = rf(count) + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func() (int64, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() } else { - r0 = ret.Error(0) + r0 = ret.Get(0).(int64) } - return r0 + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // Query_Count_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Count' @@ -198,24 +208,23 @@ type Query_Count_Call struct { } // Count is a helper method to define mock.On call -// - count *int64 -func (_e *Query_Expecter) Count(count interface{}) *Query_Count_Call { - return &Query_Count_Call{Call: _e.mock.On("Count", count)} +func (_e *Query_Expecter) Count() *Query_Count_Call { + return &Query_Count_Call{Call: _e.mock.On("Count")} } -func (_c *Query_Count_Call) Run(run func(count *int64)) *Query_Count_Call { +func (_c *Query_Count_Call) Run(run func()) *Query_Count_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*int64)) + run() }) return _c } -func (_c *Query_Count_Call) Return(_a0 error) *Query_Count_Call { - _c.Call.Return(_a0) +func (_c *Query_Count_Call) Return(_a0 int64, _a1 error) *Query_Count_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *Query_Count_Call) RunAndReturn(run func(*int64) error) *Query_Count_Call { +func (_c *Query_Count_Call) RunAndReturn(run func() (int64, error)) *Query_Count_Call { _c.Call.Return(run) return _c } @@ -618,22 +627,32 @@ func (_c *Query_Exec_Call) RunAndReturn(run func(string, ...interface{}) (*db.Re return _c } -// Exists provides a mock function with given fields: exists -func (_m *Query) Exists(exists *bool) error { - ret := _m.Called(exists) +// Exists provides a mock function with no fields +func (_m *Query) Exists() (bool, error) { + ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Exists") } - var r0 error - if rf, ok := ret.Get(0).(func(*bool) error); ok { - r0 = rf(exists) + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func() (bool, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() } else { - r0 = ret.Error(0) + r0 = ret.Get(0).(bool) } - return r0 + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // Query_Exists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Exists' @@ -642,24 +661,23 @@ type Query_Exists_Call struct { } // Exists is a helper method to define mock.On call -// - exists *bool -func (_e *Query_Expecter) Exists(exists interface{}) *Query_Exists_Call { - return &Query_Exists_Call{Call: _e.mock.On("Exists", exists)} +func (_e *Query_Expecter) Exists() *Query_Exists_Call { + return &Query_Exists_Call{Call: _e.mock.On("Exists")} } -func (_c *Query_Exists_Call) Run(run func(exists *bool)) *Query_Exists_Call { +func (_c *Query_Exists_Call) Run(run func()) *Query_Exists_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*bool)) + run() }) return _c } -func (_c *Query_Exists_Call) Return(_a0 error) *Query_Exists_Call { - _c.Call.Return(_a0) +func (_c *Query_Exists_Call) Return(_a0 bool, _a1 error) *Query_Exists_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *Query_Exists_Call) RunAndReturn(run func(*bool) error) *Query_Exists_Call { +func (_c *Query_Exists_Call) RunAndReturn(run func() (bool, error)) *Query_Exists_Call { _c.Call.Return(run) return _c } @@ -2235,6 +2253,54 @@ func (_c *Query_OrderByDesc_Call) RunAndReturn(run func(string) orm.Query) *Quer return _c } +// OrderByRaw provides a mock function with given fields: raw +func (_m *Query) OrderByRaw(raw string) orm.Query { + ret := _m.Called(raw) + + if len(ret) == 0 { + panic("no return value specified for OrderByRaw") + } + + var r0 orm.Query + if rf, ok := ret.Get(0).(func(string) orm.Query); ok { + r0 = rf(raw) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Query) + } + } + + return r0 +} + +// Query_OrderByRaw_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OrderByRaw' +type Query_OrderByRaw_Call struct { + *mock.Call +} + +// OrderByRaw is a helper method to define mock.On call +// - raw string +func (_e *Query_Expecter) OrderByRaw(raw interface{}) *Query_OrderByRaw_Call { + return &Query_OrderByRaw_Call{Call: _e.mock.On("OrderByRaw", raw)} +} + +func (_c *Query_OrderByRaw_Call) Run(run func(raw string)) *Query_OrderByRaw_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Query_OrderByRaw_Call) Return(_a0 orm.Query) *Query_OrderByRaw_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_OrderByRaw_Call) RunAndReturn(run func(string) orm.Query) *Query_OrderByRaw_Call { + _c.Call.Return(run) + return _c +} + // Paginate provides a mock function with given fields: page, limit, dest, total func (_m *Query) Paginate(page int, limit int, dest interface{}, total *int64) error { ret := _m.Called(page, limit, dest, total) diff --git a/tests/query_test.go b/tests/query_test.go index d9e89840d..5a3cec9cf 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -308,13 +308,13 @@ func (s *QueryTestSuite) TestCount() { s.Nil(query.Query().Create(&user1)) s.True(user1.ID > 0) - var count int64 - s.Nil(query.Query().Model(&User{}).Where("name = ?", "count_user").Count(&count)) + count, err := query.Query().Model(&User{}).Where("name = ?", "count_user").Count() + s.Nil(err) s.True(count > 0) - var count1 int64 - s.Nil(query.Query().Table("users").Where("name = ?", "count_user").Count(&count1)) - s.True(count1 > 0) + count, err = query.Query().Table("users").Where("name = ?", "count_user").Count() + s.Nil(err) + s.True(count > 0) }) } } @@ -429,8 +429,7 @@ func (s *QueryTestSuite) TestCreate() { s.Nil(query.Query().Create(&people)) s.True(people.ID > 0) - var count int64 - err := query.Query().Table("peoples").Where("body", "create_people").Count(&count) + count, err := query.Query().Table("peoples").Where("body", "create_people").Count() s.NoError(err) s.True(count == 0) @@ -444,7 +443,7 @@ func (s *QueryTestSuite) TestCreate() { s.Nil(query.Query().Where("body", "create_people1").First(&people1)) s.True(people1.ID > 0) - err = query.Query().Table("peoples").Where("body", "create_people1").Count(&count) + count, err = query.Query().Table("peoples").Where("body", "create_people1").Count() s.NoError(err) s.True(count == 0) }, @@ -745,8 +744,8 @@ func (s *QueryTestSuite) TestDelete() { s.Equal(int64(2), res.RowsAffected) s.Nil(err) - var count int64 - s.Nil(query.Query().Model(&User{}).Where("name", "delete_user").OrWhere("name", "delete_user1").Count(&count)) + count, err := query.Query().Model(&User{}).Where("name", "delete_user").OrWhere("name", "delete_user1").Count() + s.Nil(err) s.True(count == 0) }, }, @@ -1717,13 +1716,13 @@ func (s *QueryTestSuite) TestExists() { s.Nil(query.Query().Create(&user1)) s.True(user1.ID > 0) - var t bool - s.Nil(query.Query().Model(&User{}).Where("name = ?", "exists_user").Exists(&t)) - s.True(t) + exists, err := query.Query().Model(&User{}).Where("name = ?", "exists_user").Exists() + s.Nil(err) + s.True(exists) - var f bool - s.Nil(query.Query().Model(&User{}).Where("name = ?", "no_exists_user").Exists(&f)) - s.False(f) + exists, err = query.Query().Model(&User{}).Where("name = ?", "no_exists_user").Exists() + s.Nil(err) + s.False(exists) }) } } @@ -2155,7 +2154,7 @@ func (s *QueryTestSuite) TestOrder() { s.True(user1.ID > 0) var user2 []User - s.Nil(query.Query().Where("name = ?", "order_user").Order("id desc").Order("name asc").Get(&user2)) + s.Nil(query.Query().Where("name = ?", "order_user").OrderByRaw("id desc, name asc").Get(&user2)) s.True(len(user2) > 0) s.True(user2[0].ID > 0) }) @@ -2987,8 +2986,8 @@ func (s *QueryTestSuite) TestRestore() { s.Equal(int64(1), res.RowsAffected) s.NoError(err) - var count int64 - s.NoError(query.Query().Model(&User{}).Where("avatar", "restore_avatar").Count(&count)) + count, err := query.Query().Model(&User{}).Where("avatar", "restore_avatar").Count() + s.NoError(err) s.Equal(int64(4), count) }) } @@ -3180,8 +3179,7 @@ func (s *QueryTestSuite) TestUpdateOrCreate() { s.Nil(err) s.True(user3.ID > 0) - var count int64 - err = query.Query().Model(User{}).Where("name", "update_or_create_user").Count(&count) + count, err := query.Query().Model(User{}).Where("name", "update_or_create_user").Count() s.Nil(err) s.Equal(int64(1), count) }) diff --git a/tests/schema_test.go b/tests/schema_test.go index cf40efea5..a8675ae49 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -2341,8 +2341,7 @@ func (s *SchemaSuite) TestSql() { })) s.NoError(schema.Sql("insert into goravel_sql (name) values ('goravel');")) - var count int64 - err := testQuery.Query().Table("sql").Where("name", "goravel").Count(&count) + count, err := testQuery.Query().Table("sql").Where("name", "goravel").Count() s.NoError(err) s.Equal(int64(1), count)