diff --git a/contracts/database/db/db.go b/contracts/database/db/db.go index 67ddb5efe..aae36d907 100644 --- a/contracts/database/db/db.go +++ b/contracts/database/db/db.go @@ -8,8 +8,10 @@ type DB interface { type Query interface { First(dest any) error + Delete() (*Result, error) Get(dest any) error Insert(data any) (*Result, error) + Update(data any) (*Result, error) Where(query any, args ...any) Query } diff --git a/database/db/query.go b/database/db/query.go index d15a7e6d7..779fa3a05 100644 --- a/database/db/query.go +++ b/database/db/query.go @@ -28,6 +28,29 @@ func NewQuery(config database.Config, builder db.Builder, table string) *Query { } } +func (r *Query) Delete() (*db.Result, error) { + sql, args, err := r.buildDelete() + // TODO: use logger instead of println + fmt.Println(sql, args, err) + if err != nil { + return nil, err + } + + result, err := r.builder.Exec(sql, args...) + if err != nil { + return nil, err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, err + } + + return &db.Result{ + RowsAffected: rowsAffected, + }, nil +} + func (r *Query) First(dest any) error { sql, args, err := r.buildSelect() // TODO: use logger instead of println @@ -82,6 +105,34 @@ func (r *Query) Insert(data any) (*db.Result, error) { }, nil } +func (r *Query) Update(data any) (*db.Result, error) { + mapData, err := convertToMap(data) + if err != nil { + return nil, err + } + + sql, args, err := r.buildUpdate(mapData) + // TODO: use logger instead of println + fmt.Println(sql, args, err) + if err != nil { + return nil, err + } + + result, err := r.builder.Exec(sql, args...) + if err != nil { + return nil, err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, err + } + + return &db.Result{ + RowsAffected: rowsAffected, + }, nil +} + func (r *Query) Where(query any, args ...any) db.Query { q := NewQuery(r.config, r.builder, r.conditions.table) q.conditions = r.conditions @@ -93,6 +144,35 @@ func (r *Query) Where(query any, args ...any) db.Query { return q } +func (r *Query) buildDelete() (sql string, args []any, err error) { + if r.conditions.table == "" { + return "", nil, errors.DatabaseTableIsRequired + } + + builder := sq.Delete(r.conditions.table) + if r.config.PlaceholderFormat != nil { + builder = builder.PlaceholderFormat(r.config.PlaceholderFormat) + } + + for _, where := range r.conditions.where { + query, ok := where.query.(string) + if ok { + if !str.Of(query).Trim().Contains(" ", "?") { + if len(where.args) > 1 { + builder = builder.Where(sq.Eq{query: where.args}) + } else if len(where.args) == 1 { + builder = builder.Where(sq.Eq{query: where.args[0]}) + } + continue + } + } + + builder = builder.Where(where.query, where.args...) + } + + return builder.ToSql() +} + func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err error) { if r.conditions.table == "" { return "", nil, errors.DatabaseTableIsRequired @@ -104,8 +184,6 @@ func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err } first := data[0] - builder = builder.SetMap(first) - cols := make([]string, 0, len(first)) for col := range first { cols = append(cols, col) @@ -154,3 +232,34 @@ func (r *Query) buildSelect() (sql string, args []any, err error) { return builder.ToSql() } + +func (r *Query) buildUpdate(data map[string]any) (sql string, args []any, err error) { + if r.conditions.table == "" { + return "", nil, errors.DatabaseTableIsRequired + } + + builder := sq.Update(r.conditions.table) + if r.config.PlaceholderFormat != nil { + builder = builder.PlaceholderFormat(r.config.PlaceholderFormat) + } + + for _, where := range r.conditions.where { + query, ok := where.query.(string) + if ok { + if !str.Of(query).Trim().Contains(" ", "?") { + if len(where.args) > 1 { + builder = builder.Where(sq.Eq{query: where.args}) + } else if len(where.args) == 1 { + builder = builder.Where(sq.Eq{query: where.args[0]}) + } + continue + } + } + + builder = builder.Where(where.query, where.args...) + } + + builder = builder.SetMap(data) + + return builder.ToSql() +} diff --git a/database/db/query_test.go b/database/db/query_test.go index 54d71d5c9..7697f707a 100644 --- a/database/db/query_test.go +++ b/database/db/query_test.go @@ -14,9 +14,10 @@ import ( // TestUser is a test model type TestUser struct { - ID uint `db:"id"` - Name string `db:"-"` - Age int + ID uint `db:"id"` + Phone string `db:"phone"` + Name string `db:"-"` + Age int } type QueryTestSuite struct { @@ -34,6 +35,18 @@ func (s *QueryTestSuite) SetupTest() { s.query = NewQuery(database.Config{}, s.mockBuilder, "users") } +func (s *QueryTestSuite) TestDelete() { + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return(mockResult, nil).Once() + + result, err := s.query.Where("name", "John").Where("id", 1).Delete() + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) +} + func (s *QueryTestSuite) TestFirst() { var user TestUser s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE name = ?", "John").Return(nil).Once() @@ -150,6 +163,44 @@ func (s *QueryTestSuite) TestInsert() { }) } +func (s *QueryTestSuite) TestUpdate() { + s.Run("single struct", func() { + user := TestUser{ + Phone: "1234567890", + Name: "John", + Age: 25, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return(mockResult, nil).Once() + + result, err := s.query.Where("name", "John").Where("id", 1).Update(user) + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) + + s.Run("single map", func() { + user := map[string]any{ + "phone": "1234567890", + "name": "John", + "age": 25, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("UPDATE users SET age = ?, name = ?, phone = ? WHERE name = ? AND id = ?", 25, "John", "1234567890", "John", 1).Return(mockResult, nil).Once() + + result, err := s.query.Where("name", "John").Where("id", 1).Update(user) + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) +} + func (s *QueryTestSuite) TestWhere() { s.Run("simple where condition", func() { var user TestUser diff --git a/database/db/utils.go b/database/db/utils.go index 9402dd9e5..35f33d18f 100644 --- a/database/db/utils.go +++ b/database/db/utils.go @@ -127,6 +127,9 @@ func convertToMap(data any) (map[string]any, error) { if fieldValue.Kind() == reflect.Ptr && !fieldValue.IsNil() { fieldValue = fieldValue.Elem() } + if fieldValue.IsZero() { + continue + } result[fieldName] = fieldValue.Interface() } return result, nil diff --git a/database/db/utils_test.go b/database/db/utils_test.go index 24d84175e..86c3c4587 100644 --- a/database/db/utils_test.go +++ b/database/db/utils_test.go @@ -7,6 +7,7 @@ import ( ) type Body struct { + Length int `db:"length"` Weight string `db:"weight"` Height int `db:"-"` Age uint @@ -31,29 +32,29 @@ func TestConvertToSliceMap(t *testing.T) { { data: []User{ {ID: 1, Name: "John", Email: "john@example.com", Body: Body{Weight: "100kg", Height: 180, Age: 25}}, - {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Weight: "90kg", Height: 170, Age: 20}}, + {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Length: 1, Weight: "90kg", Height: 170, Age: 20}}, }, want: []map[string]any{ {"id": 1, "weight": "100kg"}, - {"id": 2, "weight": "90kg"}, + {"id": 2, "length": 1, "weight": "90kg"}, }, }, { data: []*User{ {ID: 1, Name: "John", Email: "john@example.com", Body: Body{Weight: "100kg", Height: 180, Age: 25}}, - {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Weight: "90kg", Height: 170, Age: 20}}, + {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Length: 1, Weight: "90kg", Height: 170, Age: 20}}, }, want: []map[string]any{ {"id": 1, "weight": "100kg"}, - {"id": 2, "weight": "90kg"}, + {"id": 2, "length": 1, "weight": "90kg"}, }, }, { data: []Body{ {Weight: "100kg", Height: 180, Age: 25}, - {Weight: "90kg", Height: 170, Age: 20}, + {Length: 1, Weight: "90kg", Height: 170, Age: 20}, }, - want: []map[string]any{{"weight": "100kg"}, {"weight": "90kg"}}, + want: []map[string]any{{"weight": "100kg"}, {"length": 1, "weight": "90kg"}}, }, { data: Body{ diff --git a/mocks/database/db/Query.go b/mocks/database/db/Query.go index d32f62588..dc413a4df 100644 --- a/mocks/database/db/Query.go +++ b/mocks/database/db/Query.go @@ -20,6 +20,63 @@ func (_m *Query) EXPECT() *Query_Expecter { return &Query_Expecter{mock: &_m.Mock} } +// Delete provides a mock function with no fields +func (_m *Query) Delete() (*db.Result, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 *db.Result + var r1 error + if rf, ok := ret.Get(0).(func() (*db.Result, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *db.Result); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*db.Result) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Query_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +func (_e *Query_Expecter) Delete() *Query_Delete_Call { + return &Query_Delete_Call{Call: _e.mock.On("Delete")} +} + +func (_c *Query_Delete_Call) Run(run func()) *Query_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Query_Delete_Call) Return(_a0 *db.Result, _a1 error) *Query_Delete_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_Delete_Call) RunAndReturn(run func() (*db.Result, error)) *Query_Delete_Call { + _c.Call.Return(run) + return _c +} + // First provides a mock function with given fields: dest func (_m *Query) First(dest interface{}) error { ret := _m.Called(dest) @@ -170,6 +227,64 @@ func (_c *Query_Insert_Call) RunAndReturn(run func(interface{}) (*db.Result, err return _c } +// Update provides a mock function with given fields: data +func (_m *Query) Update(data interface{}) (*db.Result, error) { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *db.Result + var r1 error + if rf, ok := ret.Get(0).(func(interface{}) (*db.Result, error)); ok { + return rf(data) + } + if rf, ok := ret.Get(0).(func(interface{}) *db.Result); ok { + r0 = rf(data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*db.Result) + } + } + + if rf, ok := ret.Get(1).(func(interface{}) error); ok { + r1 = rf(data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type Query_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - data interface{} +func (_e *Query_Expecter) Update(data interface{}) *Query_Update_Call { + return &Query_Update_Call{Call: _e.mock.On("Update", data)} +} + +func (_c *Query_Update_Call) Run(run func(data interface{})) *Query_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *Query_Update_Call) Return(_a0 *db.Result, _a1 error) *Query_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_Update_Call) RunAndReturn(run func(interface{}) (*db.Result, error)) *Query_Update_Call { + _c.Call.Return(run) + return _c +} + // Where provides a mock function with given fields: query, args func (_m *Query) Where(query interface{}, args ...interface{}) db.Query { var _ca []interface{} diff --git a/tests/db_test.go b/tests/db_test.go index 5a43eb91c..4e44ee69a 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -1,3 +1,5 @@ +//go:debug x509negativeserial=1 + package tests import ( @@ -36,9 +38,9 @@ func (s *DBTestSuite) TearDownSuite() { } func (s *DBTestSuite) TestInsert_First_Get() { - for driver, query := range s.queries { - now := carbon.NewDateTime(carbon.FromDateTime(2025, 1, 2, 3, 4, 5)) + now := carbon.NewDateTime(carbon.FromDateTime(2025, 1, 2, 3, 4, 5)) + for driver, query := range s.queries { s.Run(driver, func() { s.Run("single struct", func() { result, err := query.DB().Table("products").Insert(Product{ @@ -133,6 +135,73 @@ func (s *DBTestSuite) TestInsert_First_Get() { } } +func (s *DBTestSuite) TestUpdate_Delete() { + now := carbon.NewDateTime(carbon.FromDateTime(2025, 1, 2, 3, 4, 5)) + + for driver, query := range s.queries { + s.Run(driver, func() { + result, err := query.DB().Table("products").Insert([]Product{ + { + Name: "update structs1", + Model: Model{ + Timestamps: Timestamps{ + CreatedAt: now, + UpdatedAt: now, + }, + }, + }, + { + Name: "update structs2", + }, + }) + s.NoError(err) + s.Equal(int64(2), result.RowsAffected) + + // Create success + var products1 []Product + err = query.DB().Table("products").Where("name", []string{"update structs1", "update structs2"}).Where("deleted_at", nil).Get(&products1) + s.NoError(err) + s.Equal(2, len(products1)) + s.Equal("update structs1", products1[0].Name) + s.Equal("update structs2", products1[1].Name) + + // Update success via map + result, err = query.DB().Table("products").Where("name", "update structs1").Update(map[string]any{ + "name": "update structs1 updated", + }) + s.NoError(err) + s.Equal(int64(1), result.RowsAffected) + + var product1 Product + err = query.DB().Table("products").Where("name", "update structs1 updated").Where("deleted_at", nil).First(&product1) + s.NoError(err) + s.Equal("update structs1 updated", product1.Name) + + // Update success via struct + result, err = query.DB().Table("products").Where("name", "update structs2").Update(Product{ + Name: "update structs2 updated", + }) + s.NoError(err) + s.Equal(int64(1), result.RowsAffected) + + var product2 Product + err = query.DB().Table("products").Where("name", "update structs2 updated").Where("deleted_at", nil).First(&product2) + s.NoError(err) + s.Equal("update structs2 updated", product2.Name) + + // Delete success + result, err = query.DB().Table("products").Where("name like ?", "update structs%").Delete() + s.NoError(err) + s.Equal(int64(2), result.RowsAffected) + + var products2 []Product + err = query.DB().Table("products").Where("name", []string{"update structs1 updated", "update structs2 updated"}).Where("deleted_at", nil).Get(&products2) + s.NoError(err) + s.Equal(0, len(products2)) + }) + } +} + func (s *DBTestSuite) TestWhere() { for driver, query := range s.queries { s.Run(driver, func() { diff --git a/tests/go.mod b/tests/go.mod index b21658796..a220faa64 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -64,7 +64,7 @@ require ( github.com/stretchr/objx v0.5.2 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/crypto v0.33.0 // indirect - golang.org/x/exp v0.0.0-20250215185904-eff6e970281f // indirect + golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/net v0.35.0 // indirect golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect diff --git a/tests/go.sum b/tests/go.sum index 63f4a1f4b..976586c6c 100644 --- a/tests/go.sum +++ b/tests/go.sum @@ -248,8 +248,8 @@ golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/exp v0.0.0-20250215185904-eff6e970281f h1:oFMYAjX0867ZD2jcNiLBrI9BdpmEkvPyi5YrBGXbamg= -golang.org/x/exp v0.0.0-20250215185904-eff6e970281f/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=