diff --git a/contracts/database/db/db.go b/contracts/database/db/db.go index 18911918e..e8a2c010e 100644 --- a/contracts/database/db/db.go +++ b/contracts/database/db/db.go @@ -14,80 +14,125 @@ type DB interface { } type Query interface { - // Avg(column string) (any, error) // commit - // Count(dest *int64) error - // Chunk(size int, callback func(rows []any) error) error + // Count Retrieve the "count" result of the query. + Count() (int64, error) + // Chunk Execute a callback over a given chunk size. + // Chunk(size int, callback func(dest []any) error) error // CrossJoin(table string, on any, args ...any) Query - // DoesntExist() (bool, error) - // Distinct() Query + // DoesntExist Determine if no rows exist for the current query. + DoesntExist() (bool, error) + // Distinct Force the query to only return distinct results. + Distinct() Query // dump // dumpRawSql + // Delete records from the database. Delete() (*Result, error) // Each(callback func(rows []any) error) error + // Exists Determine if any rows exist for the current query. Exists() (bool, error) + // Find Execute a query for a single record by ID. Find(dest any, conds ...any) error + // First finds record that match given conditions. First(dest any) error - // FirstOr + // FirstOr finds the first record that matches the given conditions or execute the callback and return its result if no record is found. + FirstOr(dest any, callback func() error) error + // FirstOrFail finds the first record that matches the given conditions or throws an error. FirstOrFail(dest any) error - // decrement + // Decrement the given column's values by the given amounts. + Decrement(column string, value ...uint64) error + // Get Retrieve all rows from the database. Get(dest any) error // GroupBy(column string) Query // GroupByRaw(query string, args ...any) Query // having // HavingRaw(query any, args ...any) Query - // increment + // Increment a column's value by a given amount. + Increment(column string, value ...uint64) error // inRandomOrder + // Insert a new record into the database. Insert(data any) (*Result, error) - // incrementEach - // insertGetId + // InsertGetId returns the ID of the inserted row, only supported by MySQL and Sqlite + InsertGetId(data any) (int64, error) // Join(table string, on any, args ...any) Query - // latest + // Latest Retrieve the latest record from the database. + Latest(dest any, column ...string) error // LeftJoin(table string, on any, args ...any) Query - // limit + // Limit(limit uint64) Query // lockForUpdate - // Max(column string) (any, error) // offset + // OrderBy Add an "order by" clause to the query. OrderBy(column string) Query + // OrderByDesc Add a descending "order by" clause to the query. OrderByDesc(column string) Query + // OrderByRaw Add a raw "order by" clause to the query. OrderByRaw(raw string) Query + // OrWhere add an "or where" clause to the query. OrWhere(query any, args ...any) Query + // OrWhereBetween adds an "or where column between x and y" clause to the query. OrWhereBetween(column string, x, y any) Query + // OrWhereColumn adds an "or where column" clause to the query. OrWhereColumn(column1 string, column2 ...string) Query + // OrWhereIn adds an "or where column in" clause to the query. OrWhereIn(column string, args []any) Query + // OrWhereLike adds an "or where column like" clause to the query. OrWhereLike(column string, value string) Query + // OrWhereNot adds an "or where not" clause to the query. OrWhereNot(query any, args ...any) Query + // OrWhereNotBetween adds an "or where column not between x and y" clause to the query. OrWhereNotBetween(column string, x, y any) Query + // OrWhereNotIn adds an "or where column not in" clause to the query. OrWhereNotIn(column string, args []any) Query + // OrWhereNotLike adds an "or where column not like" clause to the query. OrWhereNotLike(column string, value string) Query + // OrWhereNotNull adds an "or where column is not null" clause to the query. OrWhereNotNull(column string) Query + // OrWhereNull adds an "or where column is null" clause to the query. OrWhereNull(column string) Query + // OrWhereRaw adds a raw "or where" clause to the query. OrWhereRaw(raw string, args []any) Query - // Pluck(column string, dest any) error + // Pluck Get a collection instance containing the values of a given column. + Pluck(column string, dest any) error // rollBack // RightJoin(table string, on any, args ...any) Query + // Select Set the columns to be selected. Select(columns ...string) Query // sharedLock // skip // take // ToSql // ToRawSql + // Update records in the database. Update(column any, value ...any) (*Result, error) // updateOrInsert // Value(column string, dest any) error - // when + // When executes the callback if the condition is true. + When(condition bool, callback func(query Query) Query) Query + // Where Add a basic where clause to the query. Where(query any, args ...any) Query + // WhereBetween Add a where between statement to the query. WhereBetween(column string, x, y any) Query + // WhereColumn Add a "where" clause comparing two columns to the query. WhereColumn(column1 string, column2 ...string) Query + // WhereExists Add an exists clause to the query. WhereExists(func() Query) Query + // WhereIn Add a "where in" clause to the query. WhereIn(column string, args []any) Query + // WhereLike Add a "where like" clause to the query. WhereLike(column string, value string) Query + // WhereNot Add a basic "where not" clause to the query. WhereNot(query any, args ...any) Query + // WhereNotBetween Add a where not between statement to the query. WhereNotBetween(column string, x, y any) Query + // WhereNotIn Add a "where not in" clause to the query. WhereNotIn(column string, args []any) Query + // WhereNotLike Add a "where not like" clause to the query. WhereNotLike(column string, value string) Query + // WhereNotNull Add a "where not null" clause to the query. WhereNotNull(column string) Query + // WhereNull Add a "where null" clause to the query. WhereNull(column string) Query + // WhereRaw Add a raw where clause to the query. WhereRaw(raw string, args []any) Query } @@ -98,6 +143,6 @@ type Result struct { type Builder interface { Exec(query string, args ...any) (sql.Result, error) Get(dest any, query string, args ...any) error - // Query(query string, args ...any) (*sql.Rows, error) + Query(query string, args ...any) (*sql.Rows, error) Select(dest any, query string, args ...any) error } diff --git a/database/db/conditions.go b/database/db/conditions.go index 86355fdc8..7f8520d9a 100644 --- a/database/db/conditions.go +++ b/database/db/conditions.go @@ -1,10 +1,12 @@ package db type Conditions struct { - table string - where []Where - orderBy []string - selects []string + Distinct *bool + Limit *uint64 + OrderBy []string + Selects []string + Table string + Where []Where } type Where struct { diff --git a/database/db/query.go b/database/db/query.go index 4fd148e2a..6390957f5 100644 --- a/database/db/query.go +++ b/database/db/query.go @@ -16,6 +16,7 @@ import ( "github.com/goravel/framework/contracts/database/logger" "github.com/goravel/framework/errors" "github.com/goravel/framework/support/carbon" + "github.com/goravel/framework/support/convert" "github.com/goravel/framework/support/str" ) @@ -33,10 +34,10 @@ func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, log return &Query{ builder: builder, conditions: Conditions{ - table: table, + Table: table, }, - driver: driver, ctx: ctx, + driver: driver, logger: logger, } } @@ -48,6 +49,50 @@ func NewSingleQuery(ctx context.Context, driver driver.Driver, builder db.Builde return query } +func (r *Query) Count() (int64, error) { + r.conditions.Selects = []string{"COUNT(*)"} + + sql, args, err := r.buildSelect() + if err != nil { + return 0, err + } + + var count int64 + err = r.builder.Get(&count, sql, args...) + if err != nil { + r.trace(sql, args, -1, err) + + return 0, err + } + + r.trace(sql, args, -1, nil) + + return count, nil +} + +// func (r *Query) Chunk(size int, callback func(dest []any) error) error { +// sql, args, err := r.buildSelect() +// if err != nil { +// return err +// } + +// return nil +// } + +func (r *Query) Decrement(column string, value ...uint64) error { + v := uint64(1) + if len(value) > 0 { + v = value[0] + } + + _, err := r.Update(column, sq.Expr(fmt.Sprintf("%s - ?", column), v)) + if err != nil { + return err + } + + return nil +} + func (r *Query) Delete() (*db.Result, error) { sql, args, err := r.buildDelete() if err != nil { @@ -73,24 +118,28 @@ func (r *Query) Delete() (*db.Result, error) { }, nil } -func (r *Query) Exists() (bool, error) { - r.conditions.selects = []string{"COUNT(*)"} +func (r *Query) Distinct() db.Query { + q := r.clone() + q.conditions.Distinct = convert.Pointer(true) - sql, args, err := r.buildSelect() + return q +} + +func (r *Query) DoesntExist() (bool, error) { + count, err := r.Count() if err != nil { return false, err } - var count int64 - err = r.builder.Get(&count, sql, args...) - if err != nil { - r.trace(sql, args, -1, err) + return count == 0, nil +} +func (r *Query) Exists() (bool, error) { + count, err := r.Count() + if err != nil { return false, err } - r.trace(sql, args, -1, nil) - return count > 0, nil } @@ -137,6 +186,30 @@ func (r *Query) First(dest any) error { return nil } +func (r *Query) FirstOr(dest any, callback func() error) error { + sql, args, err := r.buildSelect() + if err != nil { + return err + } + + err = r.builder.Get(dest, sql, args...) + if err != nil { + if errors.Is(err, databasesql.ErrNoRows) { + r.trace(sql, args, 0, nil) + + return callback() + } + + r.trace(sql, args, -1, err) + + return err + } + + r.trace(sql, args, 1, nil) + + return nil +} + func (r *Query) FirstOrFail(dest any) error { sql, args, err := r.buildSelect() if err != nil { @@ -182,6 +255,20 @@ func (r *Query) Get(dest any) error { return nil } +func (r *Query) Increment(column string, value ...uint64) error { + v := uint64(1) + if len(value) > 0 { + v = value[0] + } + + _, err := r.Update(column, sq.Expr(fmt.Sprintf("%s + ?", column), v)) + if err != nil { + return err + } + + return nil +} + func (r *Query) Insert(data any) (*db.Result, error) { mapData, err := convertToSliceMap(data) if err != nil { @@ -217,30 +304,77 @@ func (r *Query) Insert(data any) (*db.Result, error) { }, nil } +func (r *Query) InsertGetId(data any) (int64, error) { + mapData, err := convertToMap(data) + if err != nil { + return 0, err + } + if len(mapData) == 0 { + return 0, errors.DatabaseUnsupportedType.Args("nil", "struct, map[string]any").SetModule("DB") + } + + sql, args, err := r.buildInsert([]map[string]any{mapData}) + if err != nil { + return 0, err + } + + result, err := r.builder.Exec(sql, args...) + if err != nil { + r.trace(sql, args, -1, err) + return 0, err + } + + id, err := result.LastInsertId() + if err != nil { + r.trace(sql, args, -1, err) + return 0, err + } + + r.trace(sql, args, id, nil) + + return id, nil +} + +// func (r *Query) Limit(limit uint64) db.Query { +// q := r.clone() +// q.conditions.Limit = &limit + +// return q +// } + +func (r *Query) Latest(dest any, column ...string) error { + col := "created_at" + if len(column) > 0 { + col = column[0] + } + + return r.OrderByDesc(col).First(dest) +} + func (r *Query) OrderBy(column string) db.Query { q := r.clone() - q.conditions.orderBy = append(q.conditions.orderBy, column+" ASC") + q.conditions.OrderBy = append(q.conditions.OrderBy, column+" ASC") return q } func (r *Query) OrderByDesc(column string) db.Query { q := r.clone() - q.conditions.orderBy = append(q.conditions.orderBy, column+" DESC") + q.conditions.OrderBy = append(q.conditions.OrderBy, column+" DESC") return q } func (r *Query) OrderByRaw(raw string) db.Query { q := r.clone() - q.conditions.orderBy = append(q.conditions.orderBy, raw) + q.conditions.OrderBy = append(q.conditions.OrderBy, raw) return q } func (r *Query) OrWhere(query any, args ...any) db.Query { q := r.clone() - q.conditions.where = append(q.conditions.where, Where{ + q.conditions.Where = append(q.conditions.Where, Where{ query: query, args: args, or: true, @@ -323,9 +457,15 @@ func (r *Query) OrWhereRaw(raw string, args []any) db.Query { return r.OrWhere(sq.Expr(raw, args...)) } +func (r *Query) Pluck(column string, dest any) error { + r.conditions.Selects = []string{column} + + return r.Get(dest) +} + func (r *Query) Select(columns ...string) db.Query { q := r.clone() - q.conditions.selects = append(q.conditions.selects, columns...) + q.conditions.Selects = append(q.conditions.Selects, columns...) return q } @@ -369,9 +509,43 @@ func (r *Query) Update(column any, value ...any) (*db.Result, error) { }, nil } +// func (r *Query) Value(column string, dest any) error { +// r.conditions.Selects = []string{column} +// r.conditions.Limit = convert.Pointer(uint64(1)) + +// sql, args, err := r.buildSelect() +// if err != nil { +// return err +// } + +// err = r.builder.Get(dest, sql, args...) +// if err != nil { +// if errors.Is(err, databasesql.ErrNoRows) { +// r.trace(sql, args, 0, nil) +// return nil +// } + +// r.trace(sql, args, -1, err) + +// return err +// } + +// r.trace(sql, args, -1, nil) + +// return nil +// } + +func (r *Query) When(condition bool, callback func(query db.Query) db.Query) db.Query { + if condition { + return callback(r) + } + + return r +} + func (r *Query) Where(query any, args ...any) db.Query { q := r.clone() - q.conditions.where = append(q.conditions.where, Where{ + q.conditions.Where = append(q.conditions.Where, Where{ query: query, args: args, }) @@ -471,16 +645,16 @@ func (r *Query) buildDelete() (sql string, args []any, err error) { return "", nil, r.err } - if r.conditions.table == "" { + if r.conditions.Table == "" { return "", nil, errors.DatabaseTableIsRequired } - builder := sq.Delete(r.conditions.table) + builder := sq.Delete(r.conditions.Table) if placeholderFormat := r.placeholderFormat(); placeholderFormat != nil { builder = builder.PlaceholderFormat(placeholderFormat) } - sqlizer, err := r.buildWheres(r.conditions.where) + sqlizer, err := r.buildWheres(r.conditions.Where) if err != nil { return "", nil, err } @@ -493,11 +667,11 @@ func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err return "", nil, r.err } - if r.conditions.table == "" { + if r.conditions.Table == "" { return "", nil, errors.DatabaseTableIsRequired } - builder := sq.Insert(r.conditions.table) + builder := sq.Insert(r.conditions.Table) if placeholderFormat := r.placeholderFormat(); placeholderFormat != nil { builder = builder.PlaceholderFormat(placeholderFormat) } @@ -526,28 +700,37 @@ func (r *Query) buildSelect() (sql string, args []any, err error) { return "", nil, r.err } - if r.conditions.table == "" { + if r.conditions.Table == "" { return "", nil, errors.DatabaseTableIsRequired } selects := "*" - if len(r.conditions.selects) > 0 { - selects = strings.Join(r.conditions.selects, ", ") + if len(r.conditions.Selects) > 0 { + selects = strings.Join(r.conditions.Selects, ", ") } builder := sq.Select(selects) + + if r.conditions.Distinct != nil && *r.conditions.Distinct { + builder = builder.Distinct() + } + if placeholderFormat := r.placeholderFormat(); placeholderFormat != nil { builder = builder.PlaceholderFormat(placeholderFormat) } - builder = builder.From(r.conditions.table) - sqlizer, err := r.buildWheres(r.conditions.where) + builder = builder.From(r.conditions.Table) + sqlizer, err := r.buildWheres(r.conditions.Where) if err != nil { return "", nil, err } builder = builder.Where(sqlizer) - builder = builder.OrderBy(r.conditions.orderBy...) + builder = builder.OrderBy(r.conditions.OrderBy...) + + if r.conditions.Limit != nil { + builder = builder.Limit(*r.conditions.Limit) + } return builder.ToSql() } @@ -557,16 +740,16 @@ func (r *Query) buildUpdate(data map[string]any) (sql string, args []any, err er return "", nil, r.err } - if r.conditions.table == "" { + if r.conditions.Table == "" { return "", nil, errors.DatabaseTableIsRequired } - builder := sq.Update(r.conditions.table) + builder := sq.Update(r.conditions.Table) if placeholderFormat := r.placeholderFormat(); placeholderFormat != nil { builder = builder.PlaceholderFormat(placeholderFormat) } - sqlizer, err := r.buildWheres(r.conditions.where) + sqlizer, err := r.buildWheres(r.conditions.Where) if err != nil { return "", nil, err } @@ -587,11 +770,11 @@ func (r *Query) buildWhere(where Where) (any, []any, error) { return query, where.args, nil case func(db.Query): // Handle nested conditions by creating a new query and applying the callback - nestedQuery := NewSingleQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.table) + nestedQuery := NewSingleQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table) query(nestedQuery) // Build the nested conditions - sqlizer, err := r.buildWheres(nestedQuery.conditions.where) + sqlizer, err := r.buildWheres(nestedQuery.conditions.Where) if err != nil { return nil, nil, err } @@ -655,7 +838,7 @@ func (r *Query) clone() *Query { return r } - query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.table) + query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table) query.conditions = r.conditions query.err = r.err diff --git a/database/db/query_test.go b/database/db/query_test.go index 7af3aedad..73b418eb3 100644 --- a/database/db/query_test.go +++ b/database/db/query_test.go @@ -51,6 +51,37 @@ func (s *QueryTestSuite) SetupTest() { s.query = NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "users") } +func (s *QueryTestSuite) TestCount() { + var count int64 + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Get(&count, "SELECT COUNT(*) FROM users WHERE name = ?", "John").Run(func(dest any, query string, args ...any) { + destCount := dest.(*int64) + *destCount = 1 + }).Return(nil).Once() + s.mockDriver.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + count, err := s.query.Where("name", "John").Count() + s.NoError(err) + s.Equal(int64(1), count) +} + +func (s *QueryTestSuite) TestDecrement() { + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Exec("UPDATE users SET age = age - ? WHERE name = ?", uint64(1), "John").Return(mockResult, nil).Once() + s.mockDriver.EXPECT().Explain("UPDATE users SET age = age - ? WHERE name = ?", uint64(1), "John").Return("UPDATE users SET age = age - 1 WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET age = age - 1 WHERE name = \"John\"", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").Decrement("age") + s.NoError(err) + + mockResult.AssertExpectations(s.T()) +} + func (s *QueryTestSuite) TestDelete() { s.Run("success", func() { mockResult := &MockResult{} @@ -92,6 +123,18 @@ func (s *QueryTestSuite) TestDelete() { }) } +func (s *QueryTestSuite) TestDistinct() { + var users TestUser + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Get(&users, "SELECT DISTINCT * FROM users WHERE name = ?", "John").Return(nil).Once() + s.mockDriver.EXPECT().Explain("SELECT DISTINCT * FROM users WHERE name = ?", "John").Return("SELECT DISTINCT * FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT DISTINCT * FROM users WHERE name = \"John\"", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").Distinct().First(&users) + s.NoError(err) +} + func (s *QueryTestSuite) TestExists() { var count int64 @@ -200,6 +243,21 @@ func (s *QueryTestSuite) TestFirst() { }) } +func (s *QueryTestSuite) TestFirstOr() { + var user TestUser + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE name = ?", "John").Return(databasesql.ErrNoRows).Once() + s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE name = ?", "John").Return("SELECT * FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE name = \"John\"", int64(0), nil).Return().Once() + + err := s.query.Where("name", "John").FirstOr(&user, func() error { + return errors.New("no rows") + }) + + s.Equal(errors.New("no rows"), err) +} + func (s *QueryTestSuite) TestFirstOrFail() { s.Run("success", func() { var user TestUser @@ -271,6 +329,21 @@ func (s *QueryTestSuite) TestGet() { }) } +func (s *QueryTestSuite) TestIncrement() { + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Exec("UPDATE users SET age = age + ? WHERE name = ?", uint64(1), "John").Return(mockResult, nil).Once() + s.mockDriver.EXPECT().Explain("UPDATE users SET age = age + ? WHERE name = ?", uint64(1), "John").Return("UPDATE users SET age = age + 1 WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "UPDATE users SET age = age + 1 WHERE name = \"John\"", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").Increment("age") + s.NoError(err) + + mockResult.AssertExpectations(s.T()) +} + func (s *QueryTestSuite) TestInsert() { s.Run("empty", func() { result, err := s.query.Insert(nil) @@ -389,6 +462,90 @@ func (s *QueryTestSuite) TestInsert() { }) } +func (s *QueryTestSuite) TestInsertGetId() { + s.Run("empty", func() { + id, err := s.query.InsertGetId(nil) + s.Equal(errors.DatabaseUnsupportedType.Args("nil", "struct, map[string]any").SetModule("DB"), err) + s.Equal(int64(0), id) + }) + + s.Run("success", func() { + user := map[string]any{ + "name": "John", + "age": 25, + } + + mockResult := &MockResult{} + mockResult.On("LastInsertId").Return(int64(1), nil) + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Exec("INSERT INTO users (age,name) VALUES (?,?)", 25, "John").Return(mockResult, nil).Once() + s.mockDriver.EXPECT().Explain("INSERT INTO users (age,name) VALUES (?,?)", 25, "John").Return("INSERT INTO users (age,name) VALUES (25,\"John\")").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "INSERT INTO users (age,name) VALUES (25,\"John\")", int64(1), nil).Return().Once() + + id, err := s.query.InsertGetId(user) + s.Nil(err) + s.Equal(int64(1), id) + + mockResult.AssertExpectations(s.T()) + }) + + s.Run("failed to exec", func() { + user := TestUser{ + ID: 1, + Name: "John", + Age: 25, + } + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Exec("INSERT INTO users (id) VALUES (?)", uint(1)).Return(nil, assert.AnError).Once() + s.mockDriver.EXPECT().Explain("INSERT INTO users (id) VALUES (?)", uint(1)).Return("INSERT INTO users (id) VALUES (1)").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "INSERT INTO users (id) VALUES (1)", int64(-1), assert.AnError).Return().Once() + + result, err := s.query.Insert(user) + s.Nil(result) + s.Equal(assert.AnError, err) + }) +} + +// func (s *QueryTestSuite) TestLimit() { +// var users []TestUser + +// s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() +// s.mockBuilder.EXPECT().Select(&users, "SELECT * FROM users WHERE age = ? LIMIT 1", 25).Return(nil).Once() +// s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE age = ? LIMIT 1", 25).Return("SELECT * FROM users WHERE age = 25 LIMIT 1").Once() +// s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE age = 25 LIMIT 1", int64(0), nil).Return().Once() + +// err := s.query.Where("age", 25).Limit(1).Get(&users) +// s.Nil(err) +// } + +func (s *QueryTestSuite) TestLatest() { + s.Run("default column", func() { + var user TestUser + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE age = ? ORDER BY created_at DESC", 25).Return(nil).Once() + s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE age = ? ORDER BY created_at DESC", 25).Return("SELECT * FROM users WHERE age = 25 ORDER BY created_at DESC").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE age = 25 ORDER BY created_at DESC", int64(1), nil).Return().Once() + + err := s.query.Where("age", 25).Latest(&user) + s.Nil(err) + }) + + s.Run("custom column", func() { + var user TestUser + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE age = ? ORDER BY name DESC", 25).Return(nil).Once() + s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE age = ? ORDER BY name DESC", 25).Return("SELECT * FROM users WHERE age = 25 ORDER BY name DESC").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE age = 25 ORDER BY name DESC", int64(1), nil).Return().Once() + + err := s.query.Where("age", 25).Latest(&user, "name") + s.Nil(err) + }) +} + func (s *QueryTestSuite) TestOrderBy() { var users []TestUser @@ -600,6 +757,22 @@ func (s *QueryTestSuite) TestOrWhereRaw() { s.Nil(err) } +func (s *QueryTestSuite) TestPluck() { + var names []string + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Select(&names, "SELECT name FROM users WHERE name = ?", "John").Run(func(dest any, query string, args ...any) { + destNames := dest.(*[]string) + *destNames = []string{"John"} + }).Return(nil).Once() + s.mockDriver.EXPECT().Explain("SELECT name FROM users WHERE name = ?", "John").Return("SELECT name FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT name FROM users WHERE name = \"John\"", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").Pluck("name", &names) + s.NoError(err) + s.Equal([]string{"John"}, names) +} + func (s *QueryTestSuite) TestSelect() { var users []TestUser @@ -716,6 +889,52 @@ func (s *QueryTestSuite) TestUpdate() { }) } +// func (s *QueryTestSuite) TestValue() { +// var name string + +// s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() +// s.mockBuilder.EXPECT().Get(&name, "SELECT name FROM users WHERE name = ? LIMIT 1", "John").Run(func(dest any, query string, args ...any) { +// destName := dest.(*string) +// *destName = "John" +// }).Return(nil).Once() +// s.mockDriver.EXPECT().Explain("SELECT name FROM users WHERE name = ? LIMIT 1", "John").Return("SELECT name FROM users WHERE name = \"John\" LIMIT 1").Once() +// s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT name FROM users WHERE name = \"John\" LIMIT 1", int64(-1), nil).Return().Once() + +// err := s.query.Where("name", "John").Value("name", &name) +// s.NoError(err) +// s.Equal("John", name) +// } + +func (s *QueryTestSuite) TestWhen() { + s.Run("when condition is true", func() { + var user TestUser + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE (name = ? AND age = ?)", "John", 25).Return(nil).Once() + s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE (name = ? AND age = ?)", "John", 25).Return("SELECT * FROM users WHERE (name = \"John\" AND age = 25)").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE (name = \"John\" AND age = 25)", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").When(true, func(query db.Query) db.Query { + return query.Where("age", 25) + }).First(&user) + s.Nil(err) + }) + + s.Run("when condition is false", func() { + var user TestUser + + s.mockDriver.EXPECT().Config().Return(database.Config{}).Once() + s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE name = ?", "John").Return(nil).Once() + s.mockDriver.EXPECT().Explain("SELECT * FROM users WHERE name = ?", "John").Return("SELECT * FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE name = \"John\"", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").When(false, func(query db.Query) db.Query { + return query.Where("age", 25) + }).First(&user) + s.Nil(err) + }) +} + func (s *QueryTestSuite) TestWhere() { s.Run("simple condition", func() { var user TestUser diff --git a/database/db/utils.go b/database/db/utils.go index 35f33d18f..d04c134a9 100644 --- a/database/db/utils.go +++ b/database/db/utils.go @@ -40,7 +40,7 @@ func convertToSliceMap(data any) ([]map[string]any, error) { elem := val.Index(i) m, err := convertToMap(elem.Interface()) if err != nil { - return nil, err + return nil, errors.DatabaseUnsupportedType.Args(typ.String(), "struct, []struct, map[string]any, []map[string]any").SetModule("DB") } if m != nil { result[i] = m @@ -52,7 +52,7 @@ func convertToSliceMap(data any) ([]map[string]any, error) { // Handle single value (struct or map) m, err := convertToMap(data) if err != nil { - return nil, err + return nil, errors.DatabaseUnsupportedType.Args(typ.String(), "struct, []struct, map[string]any, []map[string]any").SetModule("DB") } if m != nil { return []map[string]any{m}, nil @@ -82,7 +82,7 @@ func convertToMap(data any) (map[string]any, error) { } if typ.Kind() != reflect.Struct { - return nil, errors.DatabaseUnsupportedType.Args(typ.String(), "struct, []struct, map[string]any, []map[string]any").SetModule("DB") + return nil, errors.DatabaseUnsupportedType.Args(typ.String(), "struct, map[string]any").SetModule("DB") } // Handle struct diff --git a/mocks/database/db/Builder.go b/mocks/database/db/Builder.go index f348b118c..87187555e 100644 --- a/mocks/database/db/Builder.go +++ b/mocks/database/db/Builder.go @@ -148,6 +148,75 @@ func (_c *Builder_Get_Call) RunAndReturn(run func(interface{}, string, ...interf return _c } +// Query provides a mock function with given fields: query, args +func (_m *Builder) Query(query string, args ...interface{}) (*sql.Rows, error) { + var _ca []interface{} + _ca = append(_ca, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Query") + } + + var r0 *sql.Rows + var r1 error + if rf, ok := ret.Get(0).(func(string, ...interface{}) (*sql.Rows, error)); ok { + return rf(query, args...) + } + if rf, ok := ret.Get(0).(func(string, ...interface{}) *sql.Rows); ok { + r0 = rf(query, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sql.Rows) + } + } + + if rf, ok := ret.Get(1).(func(string, ...interface{}) error); ok { + r1 = rf(query, args...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Builder_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query' +type Builder_Query_Call struct { + *mock.Call +} + +// Query is a helper method to define mock.On call +// - query string +// - args ...interface{} +func (_e *Builder_Expecter) Query(query interface{}, args ...interface{}) *Builder_Query_Call { + return &Builder_Query_Call{Call: _e.mock.On("Query", + append([]interface{}{query}, args...)...)} +} + +func (_c *Builder_Query_Call) Run(run func(query string, args ...interface{})) *Builder_Query_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *Builder_Query_Call) Return(_a0 *sql.Rows, _a1 error) *Builder_Query_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Builder_Query_Call) RunAndReturn(run func(string, ...interface{}) (*sql.Rows, error)) *Builder_Query_Call { + _c.Call.Return(run) + return _c +} + // Select provides a mock function with given fields: dest, query, args func (_m *Builder) Select(dest interface{}, query string, args ...interface{}) error { var _ca []interface{} diff --git a/mocks/database/db/Query.go b/mocks/database/db/Query.go index 4dc9ae023..cc27de31c 100644 --- a/mocks/database/db/Query.go +++ b/mocks/database/db/Query.go @@ -20,6 +20,122 @@ func (_m *Query) EXPECT() *Query_Expecter { return &Query_Expecter{mock: &_m.Mock} } +// Count provides a mock function with no fields +func (_m *Query) Count() (int64, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Count") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func() (int64, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_Count_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Count' +type Query_Count_Call struct { + *mock.Call +} + +// Count is a helper method to define mock.On call +func (_e *Query_Expecter) Count() *Query_Count_Call { + return &Query_Count_Call{Call: _e.mock.On("Count")} +} + +func (_c *Query_Count_Call) Run(run func()) *Query_Count_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Query_Count_Call) Return(_a0 int64, _a1 error) *Query_Count_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_Count_Call) RunAndReturn(run func() (int64, error)) *Query_Count_Call { + _c.Call.Return(run) + return _c +} + +// Decrement provides a mock function with given fields: column, value +func (_m *Query) Decrement(column string, value ...uint64) error { + _va := make([]interface{}, len(value)) + for _i := range value { + _va[_i] = value[_i] + } + var _ca []interface{} + _ca = append(_ca, column) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Decrement") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, ...uint64) error); ok { + r0 = rf(column, value...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Query_Decrement_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Decrement' +type Query_Decrement_Call struct { + *mock.Call +} + +// Decrement is a helper method to define mock.On call +// - column string +// - value ...uint64 +func (_e *Query_Expecter) Decrement(column interface{}, value ...interface{}) *Query_Decrement_Call { + return &Query_Decrement_Call{Call: _e.mock.On("Decrement", + append([]interface{}{column}, value...)...)} +} + +func (_c *Query_Decrement_Call) Run(run func(column string, value ...uint64)) *Query_Decrement_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]uint64, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(uint64) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *Query_Decrement_Call) Return(_a0 error) *Query_Decrement_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_Decrement_Call) RunAndReturn(run func(string, ...uint64) error) *Query_Decrement_Call { + _c.Call.Return(run) + return _c +} + // Delete provides a mock function with no fields func (_m *Query) Delete() (*db.Result, error) { ret := _m.Called() @@ -77,6 +193,108 @@ func (_c *Query_Delete_Call) RunAndReturn(run func() (*db.Result, error)) *Query return _c } +// Distinct provides a mock function with no fields +func (_m *Query) Distinct() db.Query { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Distinct") + } + + var r0 db.Query + if rf, ok := ret.Get(0).(func() db.Query); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(db.Query) + } + } + + return r0 +} + +// Query_Distinct_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Distinct' +type Query_Distinct_Call struct { + *mock.Call +} + +// Distinct is a helper method to define mock.On call +func (_e *Query_Expecter) Distinct() *Query_Distinct_Call { + return &Query_Distinct_Call{Call: _e.mock.On("Distinct")} +} + +func (_c *Query_Distinct_Call) Run(run func()) *Query_Distinct_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Query_Distinct_Call) Return(_a0 db.Query) *Query_Distinct_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_Distinct_Call) RunAndReturn(run func() db.Query) *Query_Distinct_Call { + _c.Call.Return(run) + return _c +} + +// DoesntExist provides a mock function with no fields +func (_m *Query) DoesntExist() (bool, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for DoesntExist") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func() (bool, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_DoesntExist_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DoesntExist' +type Query_DoesntExist_Call struct { + *mock.Call +} + +// DoesntExist is a helper method to define mock.On call +func (_e *Query_Expecter) DoesntExist() *Query_DoesntExist_Call { + return &Query_DoesntExist_Call{Call: _e.mock.On("DoesntExist")} +} + +func (_c *Query_DoesntExist_Call) Run(run func()) *Query_DoesntExist_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Query_DoesntExist_Call) Return(_a0 bool, _a1 error) *Query_DoesntExist_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_DoesntExist_Call) RunAndReturn(run func() (bool, error)) *Query_DoesntExist_Call { + _c.Call.Return(run) + return _c +} + // Exists provides a mock function with no fields func (_m *Query) Exists() (bool, error) { ret := _m.Called() @@ -235,6 +453,53 @@ func (_c *Query_First_Call) RunAndReturn(run func(interface{}) error) *Query_Fir return _c } +// FirstOr provides a mock function with given fields: dest, callback +func (_m *Query) FirstOr(dest interface{}, callback func() error) error { + ret := _m.Called(dest, callback) + + if len(ret) == 0 { + panic("no return value specified for FirstOr") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, func() error) error); ok { + r0 = rf(dest, callback) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Query_FirstOr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FirstOr' +type Query_FirstOr_Call struct { + *mock.Call +} + +// FirstOr is a helper method to define mock.On call +// - dest interface{} +// - callback func() error +func (_e *Query_Expecter) FirstOr(dest interface{}, callback interface{}) *Query_FirstOr_Call { + return &Query_FirstOr_Call{Call: _e.mock.On("FirstOr", dest, callback)} +} + +func (_c *Query_FirstOr_Call) Run(run func(dest interface{}, callback func() error)) *Query_FirstOr_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{}), args[1].(func() error)) + }) + return _c +} + +func (_c *Query_FirstOr_Call) Return(_a0 error) *Query_FirstOr_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_FirstOr_Call) RunAndReturn(run func(interface{}, func() error) error) *Query_FirstOr_Call { + _c.Call.Return(run) + return _c +} + // FirstOrFail provides a mock function with given fields: dest func (_m *Query) FirstOrFail(dest interface{}) error { ret := _m.Called(dest) @@ -327,6 +592,67 @@ func (_c *Query_Get_Call) RunAndReturn(run func(interface{}) error) *Query_Get_C return _c } +// Increment provides a mock function with given fields: column, value +func (_m *Query) Increment(column string, value ...uint64) error { + _va := make([]interface{}, len(value)) + for _i := range value { + _va[_i] = value[_i] + } + var _ca []interface{} + _ca = append(_ca, column) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Increment") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, ...uint64) error); ok { + r0 = rf(column, value...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Query_Increment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Increment' +type Query_Increment_Call struct { + *mock.Call +} + +// Increment is a helper method to define mock.On call +// - column string +// - value ...uint64 +func (_e *Query_Expecter) Increment(column interface{}, value ...interface{}) *Query_Increment_Call { + return &Query_Increment_Call{Call: _e.mock.On("Increment", + append([]interface{}{column}, value...)...)} +} + +func (_c *Query_Increment_Call) Run(run func(column string, value ...uint64)) *Query_Increment_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]uint64, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(uint64) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *Query_Increment_Call) Return(_a0 error) *Query_Increment_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_Increment_Call) RunAndReturn(run func(string, ...uint64) error) *Query_Increment_Call { + _c.Call.Return(run) + return _c +} + // Insert provides a mock function with given fields: data func (_m *Query) Insert(data interface{}) (*db.Result, error) { ret := _m.Called(data) @@ -385,6 +711,123 @@ func (_c *Query_Insert_Call) RunAndReturn(run func(interface{}) (*db.Result, err return _c } +// InsertGetId provides a mock function with given fields: data +func (_m *Query) InsertGetId(data interface{}) (int64, error) { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for InsertGetId") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(interface{}) (int64, error)); ok { + return rf(data) + } + if rf, ok := ret.Get(0).(func(interface{}) int64); ok { + r0 = rf(data) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(interface{}) error); ok { + r1 = rf(data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_InsertGetId_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InsertGetId' +type Query_InsertGetId_Call struct { + *mock.Call +} + +// InsertGetId is a helper method to define mock.On call +// - data interface{} +func (_e *Query_Expecter) InsertGetId(data interface{}) *Query_InsertGetId_Call { + return &Query_InsertGetId_Call{Call: _e.mock.On("InsertGetId", data)} +} + +func (_c *Query_InsertGetId_Call) Run(run func(data interface{})) *Query_InsertGetId_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *Query_InsertGetId_Call) Return(_a0 int64, _a1 error) *Query_InsertGetId_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_InsertGetId_Call) RunAndReturn(run func(interface{}) (int64, error)) *Query_InsertGetId_Call { + _c.Call.Return(run) + return _c +} + +// Latest provides a mock function with given fields: dest, column +func (_m *Query) Latest(dest interface{}, column ...string) error { + _va := make([]interface{}, len(column)) + for _i := range column { + _va[_i] = column[_i] + } + var _ca []interface{} + _ca = append(_ca, dest) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Latest") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, ...string) error); ok { + r0 = rf(dest, column...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Query_Latest_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Latest' +type Query_Latest_Call struct { + *mock.Call +} + +// Latest is a helper method to define mock.On call +// - dest interface{} +// - column ...string +func (_e *Query_Expecter) Latest(dest interface{}, column ...interface{}) *Query_Latest_Call { + return &Query_Latest_Call{Call: _e.mock.On("Latest", + append([]interface{}{dest}, column...)...)} +} + +func (_c *Query_Latest_Call) Run(run func(dest interface{}, column ...string)) *Query_Latest_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]string, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(string) + } + } + run(args[0].(interface{}), variadicArgs...) + }) + return _c +} + +func (_c *Query_Latest_Call) Return(_a0 error) *Query_Latest_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_Latest_Call) RunAndReturn(run func(interface{}, ...string) error) *Query_Latest_Call { + _c.Call.Return(run) + return _c +} + // OrWhere provides a mock function with given fields: query, args func (_m *Query) OrWhere(query interface{}, args ...interface{}) db.Query { var _ca []interface{} @@ -1151,6 +1594,53 @@ func (_c *Query_OrderByRaw_Call) RunAndReturn(run func(string) db.Query) *Query_ return _c } +// Pluck provides a mock function with given fields: column, dest +func (_m *Query) Pluck(column string, dest interface{}) error { + ret := _m.Called(column, dest) + + if len(ret) == 0 { + panic("no return value specified for Pluck") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, interface{}) error); ok { + r0 = rf(column, dest) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Query_Pluck_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Pluck' +type Query_Pluck_Call struct { + *mock.Call +} + +// Pluck is a helper method to define mock.On call +// - column string +// - dest interface{} +func (_e *Query_Expecter) Pluck(column interface{}, dest interface{}) *Query_Pluck_Call { + return &Query_Pluck_Call{Call: _e.mock.On("Pluck", column, dest)} +} + +func (_c *Query_Pluck_Call) Run(run func(column string, dest interface{})) *Query_Pluck_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(interface{})) + }) + return _c +} + +func (_c *Query_Pluck_Call) Return(_a0 error) *Query_Pluck_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_Pluck_Call) RunAndReturn(run func(string, interface{}) error) *Query_Pluck_Call { + _c.Call.Return(run) + return _c +} + // Select provides a mock function with given fields: columns func (_m *Query) Select(columns ...string) db.Query { _va := make([]interface{}, len(columns)) @@ -1281,6 +1771,55 @@ func (_c *Query_Update_Call) RunAndReturn(run func(interface{}, ...interface{}) return _c } +// When provides a mock function with given fields: condition, callback +func (_m *Query) When(condition bool, callback func(db.Query) db.Query) db.Query { + ret := _m.Called(condition, callback) + + if len(ret) == 0 { + panic("no return value specified for When") + } + + var r0 db.Query + if rf, ok := ret.Get(0).(func(bool, func(db.Query) db.Query) db.Query); ok { + r0 = rf(condition, callback) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(db.Query) + } + } + + return r0 +} + +// Query_When_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'When' +type Query_When_Call struct { + *mock.Call +} + +// When is a helper method to define mock.On call +// - condition bool +// - callback func(db.Query) db.Query +func (_e *Query_Expecter) When(condition interface{}, callback interface{}) *Query_When_Call { + return &Query_When_Call{Call: _e.mock.On("When", condition, callback)} +} + +func (_c *Query_When_Call) Run(run func(condition bool, callback func(db.Query) db.Query)) *Query_When_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(bool), args[1].(func(db.Query) db.Query)) + }) + return _c +} + +func (_c *Query_When_Call) Return(_a0 db.Query) *Query_When_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_When_Call) RunAndReturn(run func(bool, func(db.Query) db.Query) db.Query) *Query_When_Call { + _c.Call.Return(run) + return _c +} + // Where provides a mock function with given fields: query, args func (_m *Query) Where(query interface{}, args ...interface{}) db.Query { var _ca []interface{} diff --git a/tests/db_test.go b/tests/db_test.go index f007905f4..7c746b678 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -9,7 +9,9 @@ import ( "github.com/goravel/framework/errors" "github.com/goravel/framework/support/carbon" "github.com/goravel/framework/support/convert" + "github.com/goravel/postgres" "github.com/goravel/sqlite" + "github.com/goravel/sqlserver" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -46,6 +48,65 @@ func (s *DBTestSuite) TearDownSuite() { } } +func (s *DBTestSuite) TestCount() { + for driver, query := range s.queries { + s.Run(driver, func() { + query.DB().Table("products").Insert([]Product{ + {Name: "count_product1"}, + {Name: "count_product2"}, + }) + count, err := query.DB().Table("products").Count() + s.NoError(err) + s.Equal(int64(2), count) + }) + } +} + +func (s *DBTestSuite) TestDecrement() { + for driver, query := range s.queries { + s.Run(driver, func() { + query.DB().Table("products").Insert(Product{Name: "decrement_product", Weight: convert.Pointer(100)}) + + s.Run("decrement", func() { + err := query.DB().Table("products").Where("name", "decrement_product").Decrement("weight", 1) + s.NoError(err) + + var product Product + err = query.DB().Table("products").Where("name", "decrement_product").First(&product) + s.NoError(err) + s.Equal(99, *product.Weight) + }) + + s.Run("decrement with number", func() { + err := query.DB().Table("products").Where("name", "decrement_product").Decrement("weight", 5) + s.NoError(err) + + var product Product + err = query.DB().Table("products").Where("name", "decrement_product").First(&product) + s.NoError(err) + s.Equal(94, *product.Weight) + }) + }) + } +} + +func (s *DBTestSuite) TestDistinct() { + for driver, query := range s.queries { + s.Run(driver, func() { + query.DB().Table("products").Insert([]Product{ + {Name: "distinct_product"}, + {Name: "distinct_product"}, + }) + + var products []Product + err := query.DB().Table("products").Distinct().Select("name").Get(&products) + s.NoError(err) + s.Equal(1, len(products)) + s.Equal("distinct_product", products[0].Name) + }) + } +} + func (s *DBTestSuite) TestExists() { for driver, query := range s.queries { s.Run(driver, func() { @@ -62,6 +123,34 @@ func (s *DBTestSuite) TestExists() { } } +func (s *DBTestSuite) TestIncrement() { + for driver, query := range s.queries { + s.Run(driver, func() { + query.DB().Table("products").Insert(Product{Name: "increment_product", Weight: convert.Pointer(100)}) + + s.Run("increment", func() { + err := query.DB().Table("products").Where("name", "increment_product").Increment("weight", 1) + s.NoError(err) + + var product Product + err = query.DB().Table("products").Where("name", "increment_product").First(&product) + s.NoError(err) + s.Equal(101, *product.Weight) + }) + + s.Run("increment with number", func() { + err := query.DB().Table("products").Where("name", "increment_product").Increment("weight", 5) + s.NoError(err) + + var product Product + err = query.DB().Table("products").Where("name", "increment_product").First(&product) + s.NoError(err) + s.Equal(106, *product.Weight) + }) + }) + } +} + func (s *DBTestSuite) TestInsert_First_Get() { for driver, query := range s.queries { s.Run(driver, func() { @@ -158,6 +247,29 @@ func (s *DBTestSuite) TestInsert_First_Get() { } } +func (s *DBTestSuite) TestInsertGetId() { + for driver, query := range s.queries { + s.Run(driver, func() { + id, err := query.DB().Table("products").InsertGetId(Product{ + Name: "insert get id", + }) + + if driver == sqlserver.Name || driver == postgres.Name { + s.Error(err) + s.Equal(int64(0), id) + } else { + s.NoError(err) + s.True(id > 0) + + var product Product + err = query.DB().Table("products").Where("id", id).First(&product) + s.NoError(err) + s.Equal("insert get id", product.Name) + } + }) + } +} + func (s *DBTestSuite) TestOrWhere() { for driver, query := range s.queries { s.Run(driver, func() { @@ -280,6 +392,23 @@ func (s *DBTestSuite) TestOrWhereNot() { } } +func (s *DBTestSuite) TestPluck() { + for driver, query := range s.queries { + s.Run(driver, func() { + query.DB().Table("products").Insert([]Product{ + {Name: "pluck_product1"}, + {Name: "pluck_product2"}, + }) + + var names []string + err := query.DB().Table("products").WhereLike("name", "pluck_product%").Pluck("name", &names) + + s.NoError(err) + s.Equal([]string{"pluck_product1", "pluck_product2"}, names) + }) + } +} + func (s *DBTestSuite) TestUpdate_Delete() { for driver, query := range s.queries { s.Run(driver, func() { @@ -345,6 +474,20 @@ func (s *DBTestSuite) TestUpdate_Delete() { } } +// func (s *DBTestSuite) TestValue() { +// for driver, query := range s.queries { +// s.Run(driver, func() { +// query.DB().Table("products").Insert(Product{Name: "value_product"}) + +// var name string +// err := query.DB().Table("products").Where("name", "value_product").Value("name", &name) + +// s.NoError(err) +// s.Equal("value_product", name) +// }) +// } +// } + func (s *DBTestSuite) TestWhere() { for driver, query := range s.queries { s.Run(driver, func() {