diff --git a/contracts/database/db/db.go b/contracts/database/db/db.go index b7bfe86a0..67ddb5efe 100644 --- a/contracts/database/db/db.go +++ b/contracts/database/db/db.go @@ -1,10 +1,24 @@ package db +import "database/sql" + type DB interface { Table(name string) Query } type Query interface { - Where(query any, args ...any) Query + First(dest any) error Get(dest any) error + Insert(data any) (*Result, error) + Where(query any, args ...any) Query +} + +type Result struct { + RowsAffected int64 +} + +type Builder interface { + Exec(query string, args ...any) (sql.Result, error) + Get(dest any, query string, args ...any) error + Select(dest any, query string, args ...any) error } diff --git a/database/db/db.go b/database/db/db.go index 503fe6b7f..31c5aeb85 100644 --- a/database/db/db.go +++ b/database/db/db.go @@ -13,12 +13,12 @@ import ( ) type DB struct { - config database.Config - instance *sqlx.DB + builder db.Builder + config database.Config } -func NewDB(config database.Config, instance *sqlx.DB) db.DB { - return &DB{config: config, instance: instance} +func NewDB(config database.Config, builder db.Builder) db.DB { + return &DB{config: config, builder: builder} } func BuildDB(config config.Config, connection string) (db.DB, error) { @@ -41,5 +41,5 @@ func BuildDB(config config.Config, connection string) (db.DB, error) { } func (r *DB) Table(name string) db.Query { - return NewQuery(r.config, r.instance, name) + return NewQuery(r.config, r.builder, name) } diff --git a/database/db/db_test.go b/database/db/db_test.go new file mode 100644 index 000000000..24541d6d7 --- /dev/null +++ b/database/db/db_test.go @@ -0,0 +1,67 @@ +package db + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/goravel/framework/contracts/database" + contractsdriver "github.com/goravel/framework/contracts/database/driver" + "github.com/goravel/framework/errors" + mocksconfig "github.com/goravel/framework/mocks/config" + mocksdriver "github.com/goravel/framework/mocks/database/driver" +) + +func TestBuildDB(t *testing.T) { + var ( + mockConfig *mocksconfig.Config + mockDriver *mocksdriver.Driver + ) + + tests := []struct { + name string + connection string + setup func() + expectedError error + }{ + { + name: "Success", + connection: "mysql", + setup: func() { + driverCallback := func() (contractsdriver.Driver, error) { + return mockDriver, nil + } + mockConfig.On("Get", "database.connections.mysql.via").Return(driverCallback) + mockDriver.On("DB").Return(&sql.DB{}, nil) + mockDriver.On("Config").Return(database.Config{Driver: "mysql"}) + }, + expectedError: nil, + }, + { + name: "Config Not Found", + connection: "invalid", + setup: func() { + mockConfig.On("Get", "database.connections.invalid.via").Return(nil) + }, + expectedError: errors.DatabaseConfigNotFound, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mockConfig = mocksconfig.NewConfig(t) + mockDriver = mocksdriver.NewDriver(t) + test.setup() + + db, err := BuildDB(mockConfig, test.connection) + if test.expectedError != nil { + assert.Equal(t, test.expectedError, err) + assert.Nil(t, db) + } else { + assert.NoError(t, err) + assert.NotNil(t, db) + } + }) + } +} diff --git a/database/db/query.go b/database/db/query.go index b419bea12..d15a7e6d7 100644 --- a/database/db/query.go +++ b/database/db/query.go @@ -2,38 +2,41 @@ package db import ( "fmt" + "sort" sq "github.com/Masterminds/squirrel" - "github.com/jmoiron/sqlx" "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/db" "github.com/goravel/framework/errors" + "github.com/goravel/framework/support/str" ) type Query struct { + builder db.Builder conditions Conditions config database.Config - instance *sqlx.DB } -func NewQuery(config database.Config, instance *sqlx.DB, table string) *Query { +func NewQuery(config database.Config, builder db.Builder, table string) *Query { return &Query{ + builder: builder, conditions: Conditions{ table: table, }, - config: config, - instance: instance, + config: config, } } -func (r *Query) Where(query any, args ...any) db.Query { - r.conditions.where = append(r.conditions.where, Where{ - query: query, - args: args, - }) +func (r *Query) First(dest any) error { + sql, args, err := r.buildSelect() + // TODO: use logger instead of println + fmt.Println(sql, args, err) + if err != nil { + return err + } - return r + return r.builder.Get(dest, sql, args...) } func (r *Query) Get(dest any) error { @@ -44,7 +47,81 @@ func (r *Query) Get(dest any) error { return err } - return r.instance.Select(dest, sql, args...) + return r.builder.Select(dest, sql, args...) +} + +func (r *Query) Insert(data any) (*db.Result, error) { + mapData, err := convertToSliceMap(data) + if err != nil { + return nil, err + } + if len(mapData) == 0 { + return &db.Result{ + RowsAffected: 0, + }, nil + } + + sql, args, err := r.buildInsert(mapData) + if err != nil { + return nil, err + } + // TODO: use logger instead of println + fmt.Println(sql, args, err) + result, err := r.builder.Exec(sql, args...) + if err != nil { + return nil, err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, err + } + + return &db.Result{ + RowsAffected: rowsAffected, + }, nil +} + +func (r *Query) Where(query any, args ...any) db.Query { + q := NewQuery(r.config, r.builder, r.conditions.table) + q.conditions = r.conditions + q.conditions.where = append(r.conditions.where, Where{ + query: query, + args: args, + }) + + return q +} + +func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err error) { + if r.conditions.table == "" { + return "", nil, errors.DatabaseTableIsRequired + } + + builder := sq.Insert(r.conditions.table) + if r.config.PlaceholderFormat != nil { + builder = builder.PlaceholderFormat(r.config.PlaceholderFormat) + } + + first := data[0] + builder = builder.SetMap(first) + + cols := make([]string, 0, len(first)) + for col := range first { + cols = append(cols, col) + } + sort.Strings(cols) + builder = builder.Columns(cols...) + + for _, row := range data { + vals := make([]any, 0, len(first)) + for _, col := range cols { + vals = append(vals, row[col]) + } + builder = builder.Values(vals...) + } + + return builder.ToSql() } func (r *Query) buildSelect() (sql string, args []any, err error) { @@ -60,6 +137,18 @@ func (r *Query) buildSelect() (sql string, args []any, err error) { builder = builder.From(r.conditions.table) for _, where := range r.conditions.where { + query, ok := where.query.(string) + if ok { + if !str.Of(query).Trim().Contains(" ", "?") { + if len(where.args) > 1 { + builder = builder.Where(sq.Eq{query: where.args}) + } else if len(where.args) == 1 { + builder = builder.Where(sq.Eq{query: where.args[0]}) + } + continue + } + } + builder = builder.Where(where.query, where.args...) } diff --git a/database/db/query_test.go b/database/db/query_test.go new file mode 100644 index 000000000..54d71d5c9 --- /dev/null +++ b/database/db/query_test.go @@ -0,0 +1,192 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/errors" + mocksdb "github.com/goravel/framework/mocks/database/db" +) + +// TestUser is a test model +type TestUser struct { + ID uint `db:"id"` + Name string `db:"-"` + Age int +} + +type QueryTestSuite struct { + suite.Suite + mockBuilder *mocksdb.Builder + query *Query +} + +func TestQueryTestSuite(t *testing.T) { + suite.Run(t, &QueryTestSuite{}) +} + +func (s *QueryTestSuite) SetupTest() { + s.mockBuilder = mocksdb.NewBuilder(s.T()) + s.query = NewQuery(database.Config{}, s.mockBuilder, "users") +} + +func (s *QueryTestSuite) TestFirst() { + var user TestUser + s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE name = ?", "John").Return(nil).Once() + + err := s.query.Where("name", "John").First(&user) + s.Nil(err) +} + +func (s *QueryTestSuite) TestGet() { + var users []TestUser + s.mockBuilder.EXPECT().Select(&users, "SELECT * FROM users WHERE age = ?", 25).Return(nil).Once() + + err := s.query.Where("age", 25).Get(&users) + s.Nil(err) + s.mockBuilder.AssertExpectations(s.T()) +} + +func (s *QueryTestSuite) TestInsert() { + s.Run("empty", func() { + result, err := s.query.Insert(nil) + s.Nil(err) + s.Equal(int64(0), result.RowsAffected) + }) + + s.Run("single struct", func() { + user := TestUser{ + ID: 1, + Name: "John", + Age: 25, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("INSERT INTO users (id) VALUES (?)", uint(1)).Return(mockResult, nil).Once() + + result, err := s.query.Insert(user) + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) + + s.Run("multiple structs", func() { + users := []TestUser{ + {ID: 1, Name: "John", Age: 25}, + {ID: 2, Name: "Jane", Age: 30}, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(2), nil) + s.mockBuilder.EXPECT().Exec("INSERT INTO users (id) VALUES (?),(?)", uint(1), uint(2)).Return(mockResult, nil).Once() + + result, err := s.query.Insert(users) + s.Nil(err) + s.Equal(int64(2), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) + + s.Run("single map", func() { + user := map[string]any{ + "id": 1, + "name": "John", + "age": 25, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("INSERT INTO users (age,id,name) VALUES (?,?,?)", 25, 1, "John").Return(mockResult, nil).Once() + + result, err := s.query.Insert(user) + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) + + s.Run("multiple maps", func() { + users := []map[string]any{ + {"id": 1, "name": "John", "age": 25}, + {"id": 2, "name": "Jane", "age": 30}, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(2), nil) + s.mockBuilder.EXPECT().Exec("INSERT INTO users (age,id,name) VALUES (?,?,?),(?,?,?)", 25, 1, "John", 30, 2, "Jane").Return(mockResult, nil).Once() + + result, err := s.query.Insert(users) + s.Nil(err) + s.Equal(int64(2), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) + + s.Run("unknown type", func() { + user := "unknown" + + _, err := s.query.Insert(user) + s.Equal(errors.DatabaseUnsupportedType.Args("string", "struct, []struct, map[string]any, []map[string]any").SetModule("DB"), err) + }) + + s.Run("failed to exec", func() { + user := TestUser{ + ID: 1, + Name: "John", + Age: 25, + } + + s.mockBuilder.EXPECT().Exec("INSERT INTO users (id) VALUES (?)", uint(1)).Return(nil, assert.AnError).Once() + + result, err := s.query.Insert(user) + s.Nil(result) + s.Equal(assert.AnError, err) + }) +} + +func (s *QueryTestSuite) TestWhere() { + s.Run("simple where condition", func() { + var user TestUser + s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE name = ?", "John").Return(nil).Once() + + err := s.query.Where("name", "John").First(&user) + s.Nil(err) + }) + + s.Run("where with multiple arguments", func() { + var users []TestUser + s.mockBuilder.EXPECT().Select(&users, "SELECT * FROM users WHERE age IN (?,?)", 25, 30).Return(nil).Once() + + err := s.query.Where("age", []int{25, 30}).Get(&users) + s.Nil(err) + }) + + s.Run("where with raw query", func() { + var users []TestUser + s.mockBuilder.EXPECT().Select(&users, "SELECT * FROM users WHERE age > ?", 18).Return(nil).Once() + + err := s.query.Where("age > ?", 18).Get(&users) + s.Nil(err) + }) +} + +// MockResult implements sql.Result interface for testing +type MockResult struct { + mock.Mock +} + +func (m *MockResult) LastInsertId() (int64, error) { + arguments := m.Called() + return arguments.Get(0).(int64), arguments.Error(1) +} + +func (m *MockResult) RowsAffected() (int64, error) { + arguments := m.Called() + return arguments.Get(0).(int64), arguments.Error(1) +} diff --git a/database/db/utils.go b/database/db/utils.go new file mode 100644 index 000000000..9402dd9e5 --- /dev/null +++ b/database/db/utils.go @@ -0,0 +1,133 @@ +package db + +import ( + "reflect" + "strings" + + "github.com/goravel/framework/errors" +) + +func convertToSliceMap(data any) ([]map[string]any, error) { + if data == nil { + return nil, nil + } + + if maps, ok := data.([]map[string]any); ok { + return maps, nil + } + + val := reflect.ValueOf(data) + typ := val.Type() + + // Handle pointer + if typ.Kind() == reflect.Ptr { + if val.IsNil() { + return nil, nil + } + val = val.Elem() + typ = val.Type() + } + + // Handle slice + if typ.Kind() == reflect.Slice { + length := val.Len() + if length == 0 { + return []map[string]any{}, nil + } + + result := make([]map[string]any, length) + for i := 0; i < length; i++ { + elem := val.Index(i) + m, err := convertToMap(elem.Interface()) + if err != nil { + return nil, err + } + if m != nil { + result[i] = m + } + } + return result, nil + } + + // Handle single value (struct or map) + m, err := convertToMap(data) + if err != nil { + return nil, err + } + if m != nil { + return []map[string]any{m}, nil + } + return nil, nil +} + +func convertToMap(data any) (map[string]any, error) { + if data == nil { + return nil, nil + } + + if m, ok := data.(map[string]any); ok { + return m, nil + } + + val := reflect.ValueOf(data) + typ := val.Type() + + // Handle pointer + if typ.Kind() == reflect.Ptr { + if val.IsNil() { + return nil, nil + } + val = val.Elem() + typ = val.Type() + } + + if typ.Kind() != reflect.Struct { + return nil, errors.DatabaseUnsupportedType.Args(typ.String(), "struct, []struct, map[string]any, []map[string]any").SetModule("DB") + } + + // Handle struct + result := make(map[string]any) + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + if !field.IsExported() { + continue + } + + // Handle embedded struct + if field.Anonymous { + fieldValue := val.Field(i) + if fieldValue.Kind() == reflect.Ptr && !fieldValue.IsNil() { + fieldValue = fieldValue.Elem() + } + if fieldValue.Kind() == reflect.Struct { + embedded, err := convertToMap(fieldValue.Interface()) + if err != nil { + return nil, err + } + for k, v := range embedded { + result[k] = v + } + } + continue + } + + // Get field name from db tag or use field name + tag := field.Tag.Get("db") + if tag == "" || tag == "-" { + continue + } + var fieldName string + if comma := strings.Index(tag, ","); comma != -1 { + fieldName = tag[:comma] + } else { + fieldName = tag + } + + fieldValue := val.Field(i) + if fieldValue.Kind() == reflect.Ptr && !fieldValue.IsNil() { + fieldValue = fieldValue.Elem() + } + result[fieldName] = fieldValue.Interface() + } + return result, nil +} diff --git a/database/db/utils_test.go b/database/db/utils_test.go new file mode 100644 index 000000000..24d84175e --- /dev/null +++ b/database/db/utils_test.go @@ -0,0 +1,95 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Body struct { + Weight string `db:"weight"` + Height int `db:"-"` + Age uint +} + +type User struct { + ID int `db:"id"` + Name string `db:"-"` + Email string + Body +} + +func TestConvertToSliceMap(t *testing.T) { + tests := []struct { + data any + want []map[string]any + }{ + { + data: nil, + want: nil, + }, + { + data: []User{ + {ID: 1, Name: "John", Email: "john@example.com", Body: Body{Weight: "100kg", Height: 180, Age: 25}}, + {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Weight: "90kg", Height: 170, Age: 20}}, + }, + want: []map[string]any{ + {"id": 1, "weight": "100kg"}, + {"id": 2, "weight": "90kg"}, + }, + }, + { + data: []*User{ + {ID: 1, Name: "John", Email: "john@example.com", Body: Body{Weight: "100kg", Height: 180, Age: 25}}, + {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Weight: "90kg", Height: 170, Age: 20}}, + }, + want: []map[string]any{ + {"id": 1, "weight": "100kg"}, + {"id": 2, "weight": "90kg"}, + }, + }, + { + data: []Body{ + {Weight: "100kg", Height: 180, Age: 25}, + {Weight: "90kg", Height: 170, Age: 20}, + }, + want: []map[string]any{{"weight": "100kg"}, {"weight": "90kg"}}, + }, + { + data: Body{ + Weight: "100kg", + Height: 180, + Age: 25, + }, + want: []map[string]any{{"weight": "100kg"}}, + }, + { + data: &Body{ + Weight: "100kg", + Height: 180, + Age: 25, + }, + want: []map[string]any{{"weight": "100kg"}}, + }, + { + data: map[string]any{ + "weight": "100kg", + "Age": 25, + }, + want: []map[string]any{{"weight": "100kg", "Age": 25}}, + }, + { + data: []map[string]any{ + {"weight": "100kg", "Age": 25}, + {"weight": "90kg", "Age": 20}, + }, + want: []map[string]any{{"weight": "100kg", "Age": 25}, {"weight": "90kg", "Age": 20}}, + }, + } + + for _, test := range tests { + sliceMap, err := convertToSliceMap(test.data) + assert.NoError(t, err) + assert.Equal(t, test.want, sliceMap) + } +} diff --git a/errors/list.go b/errors/list.go index 1515e7ea5..72828b547 100644 --- a/errors/list.go +++ b/errors/list.go @@ -50,6 +50,7 @@ var ( DatabaseForceIsRequiredInProduction = New("application in production use --force to run this command") DatabaseSeederNotFound = New("not found %s seeder") DatabaseFailToRunSeeder = New("fail to run seeder: %v") + DatabaseUnsupportedType = New("unsupported type: %s, expected %s") DockerUnknownContainerType = New("unknown container type") DockerInsufficientDatabaseContainers = New("the number of database container is not enough, expect: %d, got: %d") diff --git a/go.mod b/go.mod index 519a5a8cb..0cc74c217 100644 --- a/go.mod +++ b/go.mod @@ -92,7 +92,7 @@ require ( github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 go.opentelemetry.io/otel v1.33.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/net v0.34.0 // indirect + golang.org/x/net v0.35.0 // indirect golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/term v0.29.0 // indirect diff --git a/go.sum b/go.sum index 074a72eba..74a964998 100644 --- a/go.sum +++ b/go.sum @@ -251,8 +251,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/mocks/database/db/Builder.go b/mocks/database/db/Builder.go new file mode 100644 index 000000000..f348b118c --- /dev/null +++ b/mocks/database/db/Builder.go @@ -0,0 +1,221 @@ +// Code generated by mockery. DO NOT EDIT. + +package db + +import ( + sql "database/sql" + + mock "github.com/stretchr/testify/mock" +) + +// Builder is an autogenerated mock type for the Builder type +type Builder struct { + mock.Mock +} + +type Builder_Expecter struct { + mock *mock.Mock +} + +func (_m *Builder) EXPECT() *Builder_Expecter { + return &Builder_Expecter{mock: &_m.Mock} +} + +// Exec provides a mock function with given fields: query, args +func (_m *Builder) Exec(query string, args ...interface{}) (sql.Result, error) { + var _ca []interface{} + _ca = append(_ca, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Exec") + } + + var r0 sql.Result + var r1 error + if rf, ok := ret.Get(0).(func(string, ...interface{}) (sql.Result, error)); ok { + return rf(query, args...) + } + if rf, ok := ret.Get(0).(func(string, ...interface{}) sql.Result); ok { + r0 = rf(query, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(sql.Result) + } + } + + if rf, ok := ret.Get(1).(func(string, ...interface{}) error); ok { + r1 = rf(query, args...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Builder_Exec_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Exec' +type Builder_Exec_Call struct { + *mock.Call +} + +// Exec is a helper method to define mock.On call +// - query string +// - args ...interface{} +func (_e *Builder_Expecter) Exec(query interface{}, args ...interface{}) *Builder_Exec_Call { + return &Builder_Exec_Call{Call: _e.mock.On("Exec", + append([]interface{}{query}, args...)...)} +} + +func (_c *Builder_Exec_Call) Run(run func(query string, args ...interface{})) *Builder_Exec_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *Builder_Exec_Call) Return(_a0 sql.Result, _a1 error) *Builder_Exec_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Builder_Exec_Call) RunAndReturn(run func(string, ...interface{}) (sql.Result, error)) *Builder_Exec_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: dest, query, args +func (_m *Builder) Get(dest interface{}, query string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, dest, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, string, ...interface{}) error); ok { + r0 = rf(dest, query, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Builder_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type Builder_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - dest interface{} +// - query string +// - args ...interface{} +func (_e *Builder_Expecter) Get(dest interface{}, query interface{}, args ...interface{}) *Builder_Get_Call { + return &Builder_Get_Call{Call: _e.mock.On("Get", + append([]interface{}{dest, query}, args...)...)} +} + +func (_c *Builder_Get_Call) Run(run func(dest interface{}, query string, args ...interface{})) *Builder_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(interface{}), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *Builder_Get_Call) Return(_a0 error) *Builder_Get_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Builder_Get_Call) RunAndReturn(run func(interface{}, string, ...interface{}) error) *Builder_Get_Call { + _c.Call.Return(run) + return _c +} + +// Select provides a mock function with given fields: dest, query, args +func (_m *Builder) Select(dest interface{}, query string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, dest, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Select") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, string, ...interface{}) error); ok { + r0 = rf(dest, query, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Builder_Select_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Select' +type Builder_Select_Call struct { + *mock.Call +} + +// Select is a helper method to define mock.On call +// - dest interface{} +// - query string +// - args ...interface{} +func (_e *Builder_Expecter) Select(dest interface{}, query interface{}, args ...interface{}) *Builder_Select_Call { + return &Builder_Select_Call{Call: _e.mock.On("Select", + append([]interface{}{dest, query}, args...)...)} +} + +func (_c *Builder_Select_Call) Run(run func(dest interface{}, query string, args ...interface{})) *Builder_Select_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(interface{}), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *Builder_Select_Call) Return(_a0 error) *Builder_Select_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Builder_Select_Call) RunAndReturn(run func(interface{}, string, ...interface{}) error) *Builder_Select_Call { + _c.Call.Return(run) + return _c +} + +// NewBuilder creates a new instance of Builder. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewBuilder(t interface { + mock.TestingT + Cleanup(func()) +}) *Builder { + mock := &Builder{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/database/db/Query.go b/mocks/database/db/Query.go index 106a79033..d32f62588 100644 --- a/mocks/database/db/Query.go +++ b/mocks/database/db/Query.go @@ -20,6 +20,52 @@ func (_m *Query) EXPECT() *Query_Expecter { return &Query_Expecter{mock: &_m.Mock} } +// First provides a mock function with given fields: dest +func (_m *Query) First(dest interface{}) error { + ret := _m.Called(dest) + + if len(ret) == 0 { + panic("no return value specified for First") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(dest) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Query_First_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'First' +type Query_First_Call struct { + *mock.Call +} + +// First is a helper method to define mock.On call +// - dest interface{} +func (_e *Query_Expecter) First(dest interface{}) *Query_First_Call { + return &Query_First_Call{Call: _e.mock.On("First", dest)} +} + +func (_c *Query_First_Call) Run(run func(dest interface{})) *Query_First_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *Query_First_Call) Return(_a0 error) *Query_First_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Query_First_Call) RunAndReturn(run func(interface{}) error) *Query_First_Call { + _c.Call.Return(run) + return _c +} + // Get provides a mock function with given fields: dest func (_m *Query) Get(dest interface{}) error { ret := _m.Called(dest) @@ -66,6 +112,64 @@ func (_c *Query_Get_Call) RunAndReturn(run func(interface{}) error) *Query_Get_C return _c } +// Insert provides a mock function with given fields: data +func (_m *Query) Insert(data interface{}) (*db.Result, error) { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for Insert") + } + + var r0 *db.Result + var r1 error + if rf, ok := ret.Get(0).(func(interface{}) (*db.Result, error)); ok { + return rf(data) + } + if rf, ok := ret.Get(0).(func(interface{}) *db.Result); ok { + r0 = rf(data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*db.Result) + } + } + + if rf, ok := ret.Get(1).(func(interface{}) error); ok { + r1 = rf(data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_Insert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Insert' +type Query_Insert_Call struct { + *mock.Call +} + +// Insert is a helper method to define mock.On call +// - data interface{} +func (_e *Query_Expecter) Insert(data interface{}) *Query_Insert_Call { + return &Query_Insert_Call{Call: _e.mock.On("Insert", data)} +} + +func (_c *Query_Insert_Call) Run(run func(data interface{})) *Query_Insert_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *Query_Insert_Call) Return(_a0 *db.Result, _a1 error) *Query_Insert_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_Insert_Call) RunAndReturn(run func(interface{}) (*db.Result, error)) *Query_Insert_Call { + _c.Call.Return(run) + return _c +} + // Where provides a mock function with given fields: query, args func (_m *Query) Where(query interface{}, args ...interface{}) db.Query { var _ca []interface{} diff --git a/tests/db_test.go b/tests/db_test.go index 30cc24992..5a43eb91c 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -3,6 +3,7 @@ package tests import ( "testing" + "github.com/goravel/framework/support/carbon" "github.com/goravel/sqlite" "github.com/stretchr/testify/suite" ) @@ -22,7 +23,7 @@ func TestDBTestSuite(t *testing.T) { func (s *DBTestSuite) SetupSuite() { s.queries = NewTestQueryBuilder().All("", false) for _, query := range s.queries { - query.CreateTable(TestTableUsers) + query.CreateTable(TestTableProducts) } } @@ -34,12 +35,139 @@ func (s *DBTestSuite) TearDownSuite() { } } +func (s *DBTestSuite) TestInsert_First_Get() { + for driver, query := range s.queries { + now := carbon.NewDateTime(carbon.FromDateTime(2025, 1, 2, 3, 4, 5)) + + s.Run(driver, func() { + s.Run("single struct", func() { + result, err := query.DB().Table("products").Insert(Product{ + Name: "single struct", + Model: Model{ + Timestamps: Timestamps{ + CreatedAt: now, + UpdatedAt: now, + }, + }, + }) + + s.NoError(err) + s.Equal(int64(1), result.RowsAffected) + + var product Product + err = query.DB().Table("products").Where("name", "single struct").Where("deleted_at", nil).First(&product) + s.NoError(err) + s.True(product.ID > 0) + s.Equal("single struct", product.Name) + s.Equal(now, product.CreatedAt) + s.Equal(now, product.UpdatedAt) + s.False(product.DeletedAt.Valid) + }) + + s.Run("multiple structs", func() { + result, err := query.DB().Table("products").Insert([]Product{ + { + Name: "multiple structs1", + Model: Model{ + Timestamps: Timestamps{ + CreatedAt: now, + UpdatedAt: now, + }, + }, + }, + { + Name: "multiple structs2", + }, + }) + s.NoError(err) + s.Equal(int64(2), result.RowsAffected) + + var products []Product + err = query.DB().Table("products").Where("name", []string{"multiple structs1", "multiple structs2"}).Where("deleted_at", nil).Get(&products) + s.NoError(err) + s.Equal(2, len(products)) + s.Equal("multiple structs1", products[0].Name) + s.Equal("multiple structs2", products[1].Name) + }) + + s.Run("single map", func() { + result, err := query.DB().Table("products").Insert(map[string]any{ + "name": "single map", + "created_at": now, + "updated_at": &now, + }) + s.NoError(err) + s.Equal(int64(1), result.RowsAffected) + + var product Product + err = query.DB().Table("products").Where("name", "single map").Where("deleted_at", nil).First(&product) + s.NoError(err) + s.Equal("single map", product.Name) + s.Equal(now, product.CreatedAt) + s.Equal(now, product.UpdatedAt) + s.False(product.DeletedAt.Valid) + }) + + s.Run("multiple map", func() { + result, err := query.DB().Table("products").Insert([]map[string]any{ + { + "name": "multiple map1", + "created_at": now, + "updated_at": &now, + }, + { + "name": "multiple map2", + }, + }) + s.NoError(err) + s.Equal(int64(2), result.RowsAffected) + + var products []Product + err = query.DB().Table("products").Where("name", []string{"multiple map1", "multiple map2"}).Where("deleted_at", nil).Get(&products) + s.NoError(err) + s.Equal(2, len(products)) + s.Equal("multiple map1", products[0].Name) + s.Equal("multiple map2", products[1].Name) + }) + }) + } +} + func (s *DBTestSuite) TestWhere() { for driver, query := range s.queries { s.Run(driver, func() { - var user []User - err := query.DB().Table("users").Where("name = ?", "count_user").Get(&user) - s.NoError(err) + now := carbon.NewDateTime(carbon.FromDateTime(2025, 1, 2, 3, 4, 5)) + query.DB().Table("products").Insert(Product{ + Name: "where model", + Model: Model{ + Timestamps: Timestamps{ + CreatedAt: now, + UpdatedAt: now, + }, + }, + }) + + s.Run("simple where condition", func() { + var product Product + err := query.DB().Table("products").Where("name", "where model").First(&product) + s.NoError(err) + s.Equal("where model", product.Name) + }) + + s.Run("where with multiple arguments", func() { + var products []Product + err := query.DB().Table("products").Where("name", []string{"where model", "where model1"}).Get(&products) + s.NoError(err) + s.Equal(1, len(products)) + s.Equal("where model", products[0].Name) + }) + + s.Run("where with raw query", func() { + var product Product + err := query.DB().Table("products").Where("name = ?", "where model").First(&product) + s.NoError(err) + s.Equal("where model", product.Name) + }) }) } } diff --git a/tests/go.mod b/tests/go.mod index 0868dcb0b..c64693cdd 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -65,7 +65,7 @@ require ( github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/crypto v0.33.0 // indirect golang.org/x/exp v0.0.0-20250215185904-eff6e970281f // indirect - golang.org/x/net v0.34.0 // indirect + golang.org/x/net v0.35.0 // indirect golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/term v0.29.0 // indirect diff --git a/tests/go.sum b/tests/go.sum index c800be7c4..2cd59a462 100644 --- a/tests/go.sum +++ b/tests/go.sum @@ -267,8 +267,8 @@ golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/tests/models.go b/tests/models.go index 7a529cc10..6431a3394 100644 --- a/tests/models.go +++ b/tests/models.go @@ -446,7 +446,7 @@ type Phone struct { type Product struct { Model SoftDeletes - Name string + Name string `db:"name"` } func (r *Product) Connection() string {