From c45f27300df15370135039703b095740d5a03015 Mon Sep 17 00:00:00 2001 From: Bowen Date: Fri, 21 Feb 2025 18:49:53 +0800 Subject: [PATCH] feat: [#358] Add Update Delete methods for DB --- contracts/database/db/db.go | 2 + database/db/query.go | 113 ++++++++++++++++++++++++++++++++++- database/db/query_test.go | 57 +++++++++++++++++- database/db/utils.go | 3 + database/db/utils_test.go | 13 ++-- mocks/database/db/Query.go | 115 ++++++++++++++++++++++++++++++++++++ 6 files changed, 292 insertions(+), 11 deletions(-) diff --git a/contracts/database/db/db.go b/contracts/database/db/db.go index 67ddb5efe..aae36d907 100644 --- a/contracts/database/db/db.go +++ b/contracts/database/db/db.go @@ -8,8 +8,10 @@ type DB interface { type Query interface { First(dest any) error + Delete() (*Result, error) Get(dest any) error Insert(data any) (*Result, error) + Update(data any) (*Result, error) Where(query any, args ...any) Query } diff --git a/database/db/query.go b/database/db/query.go index d15a7e6d7..779fa3a05 100644 --- a/database/db/query.go +++ b/database/db/query.go @@ -28,6 +28,29 @@ func NewQuery(config database.Config, builder db.Builder, table string) *Query { } } +func (r *Query) Delete() (*db.Result, error) { + sql, args, err := r.buildDelete() + // TODO: use logger instead of println + fmt.Println(sql, args, err) + if err != nil { + return nil, err + } + + result, err := r.builder.Exec(sql, args...) + if err != nil { + return nil, err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, err + } + + return &db.Result{ + RowsAffected: rowsAffected, + }, nil +} + func (r *Query) First(dest any) error { sql, args, err := r.buildSelect() // TODO: use logger instead of println @@ -82,6 +105,34 @@ func (r *Query) Insert(data any) (*db.Result, error) { }, nil } +func (r *Query) Update(data any) (*db.Result, error) { + mapData, err := convertToMap(data) + if err != nil { + return nil, err + } + + sql, args, err := r.buildUpdate(mapData) + // TODO: use logger instead of println + fmt.Println(sql, args, err) + if err != nil { + return nil, err + } + + result, err := r.builder.Exec(sql, args...) + if err != nil { + return nil, err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, err + } + + return &db.Result{ + RowsAffected: rowsAffected, + }, nil +} + func (r *Query) Where(query any, args ...any) db.Query { q := NewQuery(r.config, r.builder, r.conditions.table) q.conditions = r.conditions @@ -93,6 +144,35 @@ func (r *Query) Where(query any, args ...any) db.Query { return q } +func (r *Query) buildDelete() (sql string, args []any, err error) { + if r.conditions.table == "" { + return "", nil, errors.DatabaseTableIsRequired + } + + builder := sq.Delete(r.conditions.table) + if r.config.PlaceholderFormat != nil { + builder = builder.PlaceholderFormat(r.config.PlaceholderFormat) + } + + for _, where := range r.conditions.where { + query, ok := where.query.(string) + if ok { + if !str.Of(query).Trim().Contains(" ", "?") { + if len(where.args) > 1 { + builder = builder.Where(sq.Eq{query: where.args}) + } else if len(where.args) == 1 { + builder = builder.Where(sq.Eq{query: where.args[0]}) + } + continue + } + } + + builder = builder.Where(where.query, where.args...) + } + + return builder.ToSql() +} + func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err error) { if r.conditions.table == "" { return "", nil, errors.DatabaseTableIsRequired @@ -104,8 +184,6 @@ func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err } first := data[0] - builder = builder.SetMap(first) - cols := make([]string, 0, len(first)) for col := range first { cols = append(cols, col) @@ -154,3 +232,34 @@ func (r *Query) buildSelect() (sql string, args []any, err error) { return builder.ToSql() } + +func (r *Query) buildUpdate(data map[string]any) (sql string, args []any, err error) { + if r.conditions.table == "" { + return "", nil, errors.DatabaseTableIsRequired + } + + builder := sq.Update(r.conditions.table) + if r.config.PlaceholderFormat != nil { + builder = builder.PlaceholderFormat(r.config.PlaceholderFormat) + } + + for _, where := range r.conditions.where { + query, ok := where.query.(string) + if ok { + if !str.Of(query).Trim().Contains(" ", "?") { + if len(where.args) > 1 { + builder = builder.Where(sq.Eq{query: where.args}) + } else if len(where.args) == 1 { + builder = builder.Where(sq.Eq{query: where.args[0]}) + } + continue + } + } + + builder = builder.Where(where.query, where.args...) + } + + builder = builder.SetMap(data) + + return builder.ToSql() +} diff --git a/database/db/query_test.go b/database/db/query_test.go index 54d71d5c9..7697f707a 100644 --- a/database/db/query_test.go +++ b/database/db/query_test.go @@ -14,9 +14,10 @@ import ( // TestUser is a test model type TestUser struct { - ID uint `db:"id"` - Name string `db:"-"` - Age int + ID uint `db:"id"` + Phone string `db:"phone"` + Name string `db:"-"` + Age int } type QueryTestSuite struct { @@ -34,6 +35,18 @@ func (s *QueryTestSuite) SetupTest() { s.query = NewQuery(database.Config{}, s.mockBuilder, "users") } +func (s *QueryTestSuite) TestDelete() { + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("DELETE FROM users WHERE name = ? AND id = ?", "John", 1).Return(mockResult, nil).Once() + + result, err := s.query.Where("name", "John").Where("id", 1).Delete() + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) +} + func (s *QueryTestSuite) TestFirst() { var user TestUser s.mockBuilder.EXPECT().Get(&user, "SELECT * FROM users WHERE name = ?", "John").Return(nil).Once() @@ -150,6 +163,44 @@ func (s *QueryTestSuite) TestInsert() { }) } +func (s *QueryTestSuite) TestUpdate() { + s.Run("single struct", func() { + user := TestUser{ + Phone: "1234567890", + Name: "John", + Age: 25, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("UPDATE users SET phone = ? WHERE name = ? AND id = ?", "1234567890", "John", 1).Return(mockResult, nil).Once() + + result, err := s.query.Where("name", "John").Where("id", 1).Update(user) + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) + + s.Run("single map", func() { + user := map[string]any{ + "phone": "1234567890", + "name": "John", + "age": 25, + } + + mockResult := &MockResult{} + mockResult.On("RowsAffected").Return(int64(1), nil) + s.mockBuilder.EXPECT().Exec("UPDATE users SET age = ?, name = ?, phone = ? WHERE name = ? AND id = ?", 25, "John", "1234567890", "John", 1).Return(mockResult, nil).Once() + + result, err := s.query.Where("name", "John").Where("id", 1).Update(user) + s.Nil(err) + s.Equal(int64(1), result.RowsAffected) + + mockResult.AssertExpectations(s.T()) + }) +} + func (s *QueryTestSuite) TestWhere() { s.Run("simple where condition", func() { var user TestUser diff --git a/database/db/utils.go b/database/db/utils.go index 9402dd9e5..35f33d18f 100644 --- a/database/db/utils.go +++ b/database/db/utils.go @@ -127,6 +127,9 @@ func convertToMap(data any) (map[string]any, error) { if fieldValue.Kind() == reflect.Ptr && !fieldValue.IsNil() { fieldValue = fieldValue.Elem() } + if fieldValue.IsZero() { + continue + } result[fieldName] = fieldValue.Interface() } return result, nil diff --git a/database/db/utils_test.go b/database/db/utils_test.go index 24d84175e..86c3c4587 100644 --- a/database/db/utils_test.go +++ b/database/db/utils_test.go @@ -7,6 +7,7 @@ import ( ) type Body struct { + Length int `db:"length"` Weight string `db:"weight"` Height int `db:"-"` Age uint @@ -31,29 +32,29 @@ func TestConvertToSliceMap(t *testing.T) { { data: []User{ {ID: 1, Name: "John", Email: "john@example.com", Body: Body{Weight: "100kg", Height: 180, Age: 25}}, - {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Weight: "90kg", Height: 170, Age: 20}}, + {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Length: 1, Weight: "90kg", Height: 170, Age: 20}}, }, want: []map[string]any{ {"id": 1, "weight": "100kg"}, - {"id": 2, "weight": "90kg"}, + {"id": 2, "length": 1, "weight": "90kg"}, }, }, { data: []*User{ {ID: 1, Name: "John", Email: "john@example.com", Body: Body{Weight: "100kg", Height: 180, Age: 25}}, - {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Weight: "90kg", Height: 170, Age: 20}}, + {ID: 2, Name: "Jane", Email: "jane@example.com", Body: Body{Length: 1, Weight: "90kg", Height: 170, Age: 20}}, }, want: []map[string]any{ {"id": 1, "weight": "100kg"}, - {"id": 2, "weight": "90kg"}, + {"id": 2, "length": 1, "weight": "90kg"}, }, }, { data: []Body{ {Weight: "100kg", Height: 180, Age: 25}, - {Weight: "90kg", Height: 170, Age: 20}, + {Length: 1, Weight: "90kg", Height: 170, Age: 20}, }, - want: []map[string]any{{"weight": "100kg"}, {"weight": "90kg"}}, + want: []map[string]any{{"weight": "100kg"}, {"length": 1, "weight": "90kg"}}, }, { data: Body{ diff --git a/mocks/database/db/Query.go b/mocks/database/db/Query.go index d32f62588..dc413a4df 100644 --- a/mocks/database/db/Query.go +++ b/mocks/database/db/Query.go @@ -20,6 +20,63 @@ func (_m *Query) EXPECT() *Query_Expecter { return &Query_Expecter{mock: &_m.Mock} } +// Delete provides a mock function with no fields +func (_m *Query) Delete() (*db.Result, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 *db.Result + var r1 error + if rf, ok := ret.Get(0).(func() (*db.Result, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *db.Result); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*db.Result) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Query_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +func (_e *Query_Expecter) Delete() *Query_Delete_Call { + return &Query_Delete_Call{Call: _e.mock.On("Delete")} +} + +func (_c *Query_Delete_Call) Run(run func()) *Query_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Query_Delete_Call) Return(_a0 *db.Result, _a1 error) *Query_Delete_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_Delete_Call) RunAndReturn(run func() (*db.Result, error)) *Query_Delete_Call { + _c.Call.Return(run) + return _c +} + // First provides a mock function with given fields: dest func (_m *Query) First(dest interface{}) error { ret := _m.Called(dest) @@ -170,6 +227,64 @@ func (_c *Query_Insert_Call) RunAndReturn(run func(interface{}) (*db.Result, err return _c } +// Update provides a mock function with given fields: data +func (_m *Query) Update(data interface{}) (*db.Result, error) { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *db.Result + var r1 error + if rf, ok := ret.Get(0).(func(interface{}) (*db.Result, error)); ok { + return rf(data) + } + if rf, ok := ret.Get(0).(func(interface{}) *db.Result); ok { + r0 = rf(data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*db.Result) + } + } + + if rf, ok := ret.Get(1).(func(interface{}) error); ok { + r1 = rf(data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type Query_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - data interface{} +func (_e *Query_Expecter) Update(data interface{}) *Query_Update_Call { + return &Query_Update_Call{Call: _e.mock.On("Update", data)} +} + +func (_c *Query_Update_Call) Run(run func(data interface{})) *Query_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *Query_Update_Call) Return(_a0 *db.Result, _a1 error) *Query_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Query_Update_Call) RunAndReturn(run func(interface{}) (*db.Result, error)) *Query_Update_Call { + _c.Call.Return(run) + return _c +} + // Where provides a mock function with given fields: query, args func (_m *Query) Where(query interface{}, args ...interface{}) db.Query { var _ca []interface{}