From 9759214e9babbe14b836ecabe51b1ed55740505f Mon Sep 17 00:00:00 2001 From: Bowen Date: Mon, 3 Mar 2025 23:51:46 +0800 Subject: [PATCH] feat: [#358] Implement Transaction method --- contracts/database/db/db.go | 6 +- database/db/db.go | 63 +++++++++++++--- database/db/db_test.go | 2 +- database/db/query.go | 36 ++++----- database/db/query_test.go | 4 +- database/db/utils.go | 10 +++ errors/list.go | 1 + mocks/database/db/DB.go | 147 ++++++++++++++++++++++++++++++++++++ tests/db_test.go | 40 ++++++++++ tests/query.go | 2 +- 10 files changed, 275 insertions(+), 36 deletions(-) diff --git a/contracts/database/db/db.go b/contracts/database/db/db.go index e8a2c010e..9289a09b6 100644 --- a/contracts/database/db/db.go +++ b/contracts/database/db/db.go @@ -6,15 +6,16 @@ import ( ) type DB interface { - // BeginTransaction() Query + BeginTransaction() (DB, error) + Commit() error Connection(name string) DB + Rollback() error Table(name string) Query // Transaction(txFunc func(tx Query) error) error WithContext(ctx context.Context) DB } type Query interface { - // commit // Count Retrieve the "count" result of the query. Count() (int64, error) // Chunk Execute a callback over a given chunk size. @@ -93,7 +94,6 @@ type Query interface { OrWhereRaw(raw string, args []any) Query // Pluck Get a collection instance containing the values of a given column. Pluck(column string, dest any) error - // rollBack // RightJoin(table string, on any, args ...any) Query // Select Set the columns to be selected. Select(columns ...string) Query diff --git a/database/db/db.go b/database/db/db.go index a7fb31dda..2195b515c 100644 --- a/database/db/db.go +++ b/database/db/db.go @@ -7,24 +7,28 @@ import ( "github.com/jmoiron/sqlx" "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/contracts/database/db" + contractsdb "github.com/goravel/framework/contracts/database/db" contractsdriver "github.com/goravel/framework/contracts/database/driver" + contractslogger "github.com/goravel/framework/contracts/database/logger" "github.com/goravel/framework/contracts/log" "github.com/goravel/framework/database/logger" "github.com/goravel/framework/errors" ) type DB struct { - builder db.Builder config config.Config ctx context.Context + db *sqlx.DB driver contractsdriver.Driver log log.Log - queries map[string]db.DB + logger contractslogger.Logger + queries map[string]contractsdb.DB + tx *sqlx.Tx + txLogs *[]TxLog } -func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, builder db.Builder) *DB { - return &DB{ctx: ctx, config: config, driver: driver, log: log, builder: builder, queries: make(map[string]db.DB)} +func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, db *sqlx.DB, tx *sqlx.Tx, txLogs *[]TxLog) *DB { + return &DB{ctx: ctx, config: config, driver: driver, log: log, logger: logger.NewLogger(config, log), db: db, queries: make(map[string]contractsdb.DB), tx: tx, txLogs: txLogs} } func BuildDB(ctx context.Context, config config.Config, log log.Log, connection string) (*DB, error) { @@ -43,10 +47,35 @@ func BuildDB(ctx context.Context, config config.Config, log log.Log, connection return nil, err } - return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver)), nil + return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver), nil, nil), nil } -func (r *DB) Connection(name string) db.DB { +func (r *DB) BeginTransaction() (contractsdb.DB, error) { + tx, err := r.db.Beginx() + if err != nil { + return nil, err + } + + return NewDB(r.ctx, r.config, r.driver, r.log, nil, tx, &[]TxLog{}), nil +} + +func (r *DB) Commit() error { + if r.tx == nil { + return errors.DatabaseTransactionNotStarted + } + + if err := r.tx.Commit(); err != nil { + return err + } + + for _, log := range *r.txLogs { + r.logger.Trace(log.ctx, log.begin, log.sql, log.rowsAffected, log.err) + } + + return nil +} + +func (r *DB) Connection(name string) contractsdb.DB { if name == "" { name = r.config.GetString("database.default") } @@ -64,10 +93,22 @@ func (r *DB) Connection(name string) db.DB { return r.queries[name] } -func (r *DB) Table(name string) db.Query { - return NewQuery(r.ctx, r.driver, r.builder, logger.NewLogger(r.config, r.log), name) +func (r *DB) Rollback() error { + if r.tx == nil { + return errors.DatabaseTransactionNotStarted + } + + return r.tx.Rollback() +} + +func (r *DB) Table(name string) contractsdb.Query { + if r.tx != nil { + return NewQuery(r.ctx, r.driver, r.tx, r.logger, name, r.txLogs) + } + + return NewQuery(r.ctx, r.driver, r.db, r.logger, name, nil) } -func (r *DB) WithContext(ctx context.Context) db.DB { - return NewDB(ctx, r.config, r.driver, r.log, r.builder) +func (r *DB) WithContext(ctx context.Context) contractsdb.DB { + return NewDB(ctx, r.config, r.driver, r.log, r.db, r.tx, r.txLogs) } diff --git a/database/db/db_test.go b/database/db/db_test.go index ea2e04657..e1192801c 100644 --- a/database/db/db_test.go +++ b/database/db/db_test.go @@ -142,7 +142,7 @@ func TestConnection(t *testing.T) { mockDriver = mocksdriver.NewDriver(t) mockLog = mockslog.NewLog(t) - db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil) + db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil, nil, nil) test.setup(db) if test.expectedPanic { diff --git a/database/db/query.go b/database/db/query.go index 6390957f5..948038383 100644 --- a/database/db/query.go +++ b/database/db/query.go @@ -27,10 +27,10 @@ type Query struct { err error driver driver.Driver logger logger.Logger - single bool + txLogs *[]TxLog } -func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string) *Query { +func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string, txLogs *[]TxLog) *Query { return &Query{ builder: builder, conditions: Conditions{ @@ -39,16 +39,10 @@ func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, log ctx: ctx, driver: driver, logger: logger, + txLogs: txLogs, } } -func NewSingleQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string) *Query { - query := NewQuery(ctx, driver, builder, logger, table) - query.single = true - - return query -} - func (r *Query) Count() (int64, error) { r.conditions.Selects = []string{"COUNT(*)"} @@ -768,10 +762,10 @@ func (r *Query) buildWhere(where Where) (any, []any, error) { } } return query, where.args, nil - case func(db.Query): + case func(db.Query) db.Query: // Handle nested conditions by creating a new query and applying the callback - nestedQuery := NewSingleQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table) - query(nestedQuery) + nestedQuery := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table, r.txLogs) + nestedQuery = query(nestedQuery).(*Query) // Build the nested conditions sqlizer, err := r.buildWheres(nestedQuery.conditions.Where) @@ -834,11 +828,7 @@ func (r *Query) buildWheres(wheres []Where) (sq.Sqlizer, error) { } func (r *Query) clone() *Query { - if r.single { - return r - } - - query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table) + query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table, r.txLogs) query.conditions = r.conditions query.err = r.err @@ -867,5 +857,15 @@ func (r *Query) toSqlizer(query any, args []any) (sq.Sqlizer, error) { } func (r *Query) trace(sql string, args []any, rowsAffected int64, err error) { - r.logger.Trace(r.ctx, carbon.Now(), r.driver.Explain(sql, args...), rowsAffected, err) + if r.txLogs != nil { + *r.txLogs = append(*r.txLogs, TxLog{ + ctx: r.ctx, + begin: carbon.Now(), + sql: r.driver.Explain(sql, args...), + rowsAffected: rowsAffected, + err: err, + }) + } else { + r.logger.Trace(r.ctx, carbon.Now(), r.driver.Explain(sql, args...), rowsAffected, err) + } } diff --git a/database/db/query_test.go b/database/db/query_test.go index 73b418eb3..a8b54e95d 100644 --- a/database/db/query_test.go +++ b/database/db/query_test.go @@ -48,7 +48,7 @@ func (s *QueryTestSuite) SetupTest() { s.now = carbon.Now() carbon.SetTestNow(s.now) - s.query = NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "users") + s.query = NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "users", nil) } func (s *QueryTestSuite) TestCount() { @@ -1031,7 +1031,7 @@ func (s *QueryTestSuite) TestWhereExists() { s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE (name = \"John\" AND EXISTS (SELECT * FROM agents WHERE age = 25))", int64(0), nil).Return().Once() err := s.query.Where("name", "John").WhereExists(func() db.Query { - return NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "agents").Where("age", 25) + return NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "agents", nil).Where("age", 25) }).Get(&users) s.Nil(err) diff --git a/database/db/utils.go b/database/db/utils.go index d04c134a9..15bededc7 100644 --- a/database/db/utils.go +++ b/database/db/utils.go @@ -1,12 +1,22 @@ package db import ( + "context" "reflect" "strings" "github.com/goravel/framework/errors" + "github.com/goravel/framework/support/carbon" ) +type TxLog struct { + ctx context.Context + begin carbon.Carbon + sql string + rowsAffected int64 + err error +} + func convertToSliceMap(data any) ([]map[string]any, error) { if data == nil { return nil, nil diff --git a/errors/list.go b/errors/list.go index a85d65522..e1087f8ee 100644 --- a/errors/list.go +++ b/errors/list.go @@ -52,6 +52,7 @@ var ( DatabaseFailToRunSeeder = New("fail to run seeder: %v") DatabaseUnsupportedType = New("unsupported type: %s, expected %s") DatabaseInvalidArgumentNumber = New("invalid argument number: %s, expected %s") + DatabaseTransactionNotStarted = New("transaction not started") DockerUnknownContainerType = New("unknown container type") DockerInsufficientDatabaseContainers = New("the number of database container is not enough, expect: %d, got: %d") diff --git a/mocks/database/db/DB.go b/mocks/database/db/DB.go index 15f348fd4..729999b8e 100644 --- a/mocks/database/db/DB.go +++ b/mocks/database/db/DB.go @@ -22,6 +22,108 @@ func (_m *DB) EXPECT() *DB_Expecter { return &DB_Expecter{mock: &_m.Mock} } +// BeginTransaction provides a mock function with no fields +func (_m *DB) BeginTransaction() (db.DB, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for BeginTransaction") + } + + var r0 db.DB + var r1 error + if rf, ok := ret.Get(0).(func() (db.DB, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() db.DB); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(db.DB) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DB_BeginTransaction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeginTransaction' +type DB_BeginTransaction_Call struct { + *mock.Call +} + +// BeginTransaction is a helper method to define mock.On call +func (_e *DB_Expecter) BeginTransaction() *DB_BeginTransaction_Call { + return &DB_BeginTransaction_Call{Call: _e.mock.On("BeginTransaction")} +} + +func (_c *DB_BeginTransaction_Call) Run(run func()) *DB_BeginTransaction_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *DB_BeginTransaction_Call) Return(_a0 db.DB, _a1 error) *DB_BeginTransaction_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *DB_BeginTransaction_Call) RunAndReturn(run func() (db.DB, error)) *DB_BeginTransaction_Call { + _c.Call.Return(run) + return _c +} + +// Commit provides a mock function with no fields +func (_m *DB) Commit() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Commit") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DB_Commit_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Commit' +type DB_Commit_Call struct { + *mock.Call +} + +// Commit is a helper method to define mock.On call +func (_e *DB_Expecter) Commit() *DB_Commit_Call { + return &DB_Commit_Call{Call: _e.mock.On("Commit")} +} + +func (_c *DB_Commit_Call) Run(run func()) *DB_Commit_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *DB_Commit_Call) Return(_a0 error) *DB_Commit_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *DB_Commit_Call) RunAndReturn(run func() error) *DB_Commit_Call { + _c.Call.Return(run) + return _c +} + // Connection provides a mock function with given fields: name func (_m *DB) Connection(name string) db.DB { ret := _m.Called(name) @@ -70,6 +172,51 @@ func (_c *DB_Connection_Call) RunAndReturn(run func(string) db.DB) *DB_Connectio return _c } +// Rollback provides a mock function with no fields +func (_m *DB) Rollback() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Rollback") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DB_Rollback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Rollback' +type DB_Rollback_Call struct { + *mock.Call +} + +// Rollback is a helper method to define mock.On call +func (_e *DB_Expecter) Rollback() *DB_Rollback_Call { + return &DB_Rollback_Call{Call: _e.mock.On("Rollback")} +} + +func (_c *DB_Rollback_Call) Run(run func()) *DB_Rollback_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *DB_Rollback_Call) Return(_a0 error) *DB_Rollback_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *DB_Rollback_Call) RunAndReturn(run func() error) *DB_Rollback_Call { + _c.Call.Return(run) + return _c +} + // Table provides a mock function with given fields: name func (_m *DB) Table(name string) db.Query { ret := _m.Called(name) diff --git a/tests/db_test.go b/tests/db_test.go index 7c746b678..8be567d28 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -3,6 +3,7 @@ package tests import ( + "database/sql" "testing" "github.com/goravel/framework/contracts/database/db" @@ -409,6 +410,45 @@ func (s *DBTestSuite) TestPluck() { } } +func (s *DBTestSuite) TestTransaction() { + for driver, query := range s.queries { + s.Run(driver, func() { + tx, err := query.DB().BeginTransaction() + s.NoError(err) + s.NotNil(tx) + + result, err := tx.Table("products").Insert(Product{Name: "transaction product"}) + s.NoError(err) + s.Equal(int64(1), result.RowsAffected) + + s.NoError(tx.Commit()) + + var product Product + err = query.DB().Table("products").Where("name", "transaction product").First(&product) + s.NoError(err) + s.Equal("transaction product", product.Name) + + tx, err = query.DB().BeginTransaction() + s.NoError(err) + s.NotNil(tx) + + result, err = tx.Table("products").Where("name", "transaction product").Update("name", "transaction product updated") + s.NoError(err) + s.Equal(int64(1), result.RowsAffected) + s.NoError(tx.Rollback()) + + var product1 Product + err = query.DB().Table("products").Where("name", "transaction product").First(&product1) + s.NoError(err) + s.Equal("transaction product", product1.Name) + + var product2 Product + err = query.DB().Table("products").Where("name", "transaction product updated").FirstOrFail(&product2) + s.Equal(sql.ErrNoRows, err) + }) + } +} + func (s *DBTestSuite) TestUpdate_Delete() { for driver, query := range s.queries { s.Run(driver, func() { diff --git a/tests/query.go b/tests/query.go index 43228aa7f..cde1ce6a7 100644 --- a/tests/query.go +++ b/tests/query.go @@ -48,7 +48,7 @@ func NewTestQuery(ctx context.Context, driver contractsdriver.Driver, config con testQuery := &TestQuery{ config: config, - db: databasedb.NewDB(ctx, config, driver, utils.NewTestLog(), sqlx.NewDb(db, driver.Config().Driver)), + db: databasedb.NewDB(ctx, config, driver, utils.NewTestLog(), sqlx.NewDb(db, driver.Config().Driver), nil, nil), driver: driver, query: gorm.NewQuery(ctx, config, driver.Config(), query, gormQuery, utils.NewTestLog(), nil, nil), }