Skip to content

Commit

Permalink
feat: [#358] Implement Transaction method
Browse files Browse the repository at this point in the history
  • Loading branch information
hwbrzzl committed Mar 3, 2025
1 parent 279fdf5 commit 9759214
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 36 deletions.
6 changes: 3 additions & 3 deletions contracts/database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
)

type DB interface {
// BeginTransaction() Query
BeginTransaction() (DB, error)
Commit() error
Connection(name string) DB
Rollback() error
Table(name string) Query
// Transaction(txFunc func(tx Query) error) error
WithContext(ctx context.Context) DB
}

type Query interface {
// commit
// Count Retrieve the "count" result of the query.
Count() (int64, error)
// Chunk Execute a callback over a given chunk size.
Expand Down Expand Up @@ -93,7 +94,6 @@ type Query interface {
OrWhereRaw(raw string, args []any) Query
// Pluck Get a collection instance containing the values of a given column.
Pluck(column string, dest any) error
// rollBack
// RightJoin(table string, on any, args ...any) Query
// Select Set the columns to be selected.
Select(columns ...string) Query
Expand Down
63 changes: 52 additions & 11 deletions database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,28 @@ import (
"github.com/jmoiron/sqlx"

"github.com/goravel/framework/contracts/config"
"github.com/goravel/framework/contracts/database/db"
contractsdb "github.com/goravel/framework/contracts/database/db"
contractsdriver "github.com/goravel/framework/contracts/database/driver"
contractslogger "github.com/goravel/framework/contracts/database/logger"
"github.com/goravel/framework/contracts/log"
"github.com/goravel/framework/database/logger"
"github.com/goravel/framework/errors"
)

type DB struct {
builder db.Builder
config config.Config
ctx context.Context
db *sqlx.DB
driver contractsdriver.Driver
log log.Log
queries map[string]db.DB
logger contractslogger.Logger
queries map[string]contractsdb.DB
tx *sqlx.Tx
txLogs *[]TxLog
}

func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, builder db.Builder) *DB {
return &DB{ctx: ctx, config: config, driver: driver, log: log, builder: builder, queries: make(map[string]db.DB)}
func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, db *sqlx.DB, tx *sqlx.Tx, txLogs *[]TxLog) *DB {
return &DB{ctx: ctx, config: config, driver: driver, log: log, logger: logger.NewLogger(config, log), db: db, queries: make(map[string]contractsdb.DB), tx: tx, txLogs: txLogs}
}

func BuildDB(ctx context.Context, config config.Config, log log.Log, connection string) (*DB, error) {
Expand All @@ -43,10 +47,35 @@ func BuildDB(ctx context.Context, config config.Config, log log.Log, connection
return nil, err
}

return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver)), nil
return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver), nil, nil), nil
}

func (r *DB) Connection(name string) db.DB {
func (r *DB) BeginTransaction() (contractsdb.DB, error) {
tx, err := r.db.Beginx()
if err != nil {
return nil, err
}

return NewDB(r.ctx, r.config, r.driver, r.log, nil, tx, &[]TxLog{}), nil
}

func (r *DB) Commit() error {
if r.tx == nil {
return errors.DatabaseTransactionNotStarted
}

if err := r.tx.Commit(); err != nil {
return err
}

for _, log := range *r.txLogs {
r.logger.Trace(log.ctx, log.begin, log.sql, log.rowsAffected, log.err)
}

return nil
}

func (r *DB) Connection(name string) contractsdb.DB {
if name == "" {
name = r.config.GetString("database.default")
}
Expand All @@ -64,10 +93,22 @@ func (r *DB) Connection(name string) db.DB {
return r.queries[name]
}

func (r *DB) Table(name string) db.Query {
return NewQuery(r.ctx, r.driver, r.builder, logger.NewLogger(r.config, r.log), name)
func (r *DB) Rollback() error {
if r.tx == nil {
return errors.DatabaseTransactionNotStarted
}

return r.tx.Rollback()
}

func (r *DB) Table(name string) contractsdb.Query {
if r.tx != nil {
return NewQuery(r.ctx, r.driver, r.tx, r.logger, name, r.txLogs)
}

return NewQuery(r.ctx, r.driver, r.db, r.logger, name, nil)
}

func (r *DB) WithContext(ctx context.Context) db.DB {
return NewDB(ctx, r.config, r.driver, r.log, r.builder)
func (r *DB) WithContext(ctx context.Context) contractsdb.DB {
return NewDB(ctx, r.config, r.driver, r.log, r.db, r.tx, r.txLogs)
}
2 changes: 1 addition & 1 deletion database/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestConnection(t *testing.T) {
mockDriver = mocksdriver.NewDriver(t)
mockLog = mockslog.NewLog(t)

db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil)
db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil, nil, nil)
test.setup(db)

if test.expectedPanic {
Expand Down
36 changes: 18 additions & 18 deletions database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ type Query struct {
err error
driver driver.Driver
logger logger.Logger
single bool
txLogs *[]TxLog
}

func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string) *Query {
func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string, txLogs *[]TxLog) *Query {
return &Query{
builder: builder,
conditions: Conditions{
Expand All @@ -39,16 +39,10 @@ func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, log
ctx: ctx,
driver: driver,
logger: logger,
txLogs: txLogs,
}
}

func NewSingleQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string) *Query {
query := NewQuery(ctx, driver, builder, logger, table)
query.single = true

return query
}

func (r *Query) Count() (int64, error) {
r.conditions.Selects = []string{"COUNT(*)"}

Expand Down Expand Up @@ -768,10 +762,10 @@ func (r *Query) buildWhere(where Where) (any, []any, error) {
}
}
return query, where.args, nil
case func(db.Query):
case func(db.Query) db.Query:
// Handle nested conditions by creating a new query and applying the callback
nestedQuery := NewSingleQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table)
query(nestedQuery)
nestedQuery := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table, r.txLogs)
nestedQuery = query(nestedQuery).(*Query)

// Build the nested conditions
sqlizer, err := r.buildWheres(nestedQuery.conditions.Where)
Expand Down Expand Up @@ -834,11 +828,7 @@ func (r *Query) buildWheres(wheres []Where) (sq.Sqlizer, error) {
}

func (r *Query) clone() *Query {
if r.single {
return r
}

query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table)
query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table, r.txLogs)
query.conditions = r.conditions
query.err = r.err

Expand Down Expand Up @@ -867,5 +857,15 @@ func (r *Query) toSqlizer(query any, args []any) (sq.Sqlizer, error) {
}

func (r *Query) trace(sql string, args []any, rowsAffected int64, err error) {
r.logger.Trace(r.ctx, carbon.Now(), r.driver.Explain(sql, args...), rowsAffected, err)
if r.txLogs != nil {
*r.txLogs = append(*r.txLogs, TxLog{
ctx: r.ctx,
begin: carbon.Now(),
sql: r.driver.Explain(sql, args...),
rowsAffected: rowsAffected,
err: err,
})
} else {
r.logger.Trace(r.ctx, carbon.Now(), r.driver.Explain(sql, args...), rowsAffected, err)
}
}
4 changes: 2 additions & 2 deletions database/db/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *QueryTestSuite) SetupTest() {
s.now = carbon.Now()
carbon.SetTestNow(s.now)

s.query = NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "users")
s.query = NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "users", nil)
}

func (s *QueryTestSuite) TestCount() {
Expand Down Expand Up @@ -1031,7 +1031,7 @@ func (s *QueryTestSuite) TestWhereExists() {
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE (name = \"John\" AND EXISTS (SELECT * FROM agents WHERE age = 25))", int64(0), nil).Return().Once()

err := s.query.Where("name", "John").WhereExists(func() db.Query {
return NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "agents").Where("age", 25)
return NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "agents", nil).Where("age", 25)
}).Get(&users)
s.Nil(err)

Expand Down
10 changes: 10 additions & 0 deletions database/db/utils.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
package db

import (
"context"
"reflect"
"strings"

"github.com/goravel/framework/errors"
"github.com/goravel/framework/support/carbon"
)

type TxLog struct {
ctx context.Context
begin carbon.Carbon
sql string
rowsAffected int64
err error
}

func convertToSliceMap(data any) ([]map[string]any, error) {
if data == nil {
return nil, nil
Expand Down
1 change: 1 addition & 0 deletions errors/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ var (
DatabaseFailToRunSeeder = New("fail to run seeder: %v")
DatabaseUnsupportedType = New("unsupported type: %s, expected %s")
DatabaseInvalidArgumentNumber = New("invalid argument number: %s, expected %s")
DatabaseTransactionNotStarted = New("transaction not started")

DockerUnknownContainerType = New("unknown container type")
DockerInsufficientDatabaseContainers = New("the number of database container is not enough, expect: %d, got: %d")
Expand Down
Loading

0 comments on commit 9759214

Please sign in to comment.