Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [#358] Implement Transaction methods #933

Merged
merged 4 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions contracts/database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
)

type DB interface {
// BeginTransaction() Query
BeginTransaction() (DB, error)
Commit() error
Connection(name string) DB
Rollback() error
Table(name string) Query
// Transaction(txFunc func(tx Query) error) error
Transaction(txFunc func(tx DB) error) error
WithContext(ctx context.Context) DB
}

type Query interface {
// commit
// Count Retrieve the "count" result of the query.
Count() (int64, error)
// Chunk Execute a callback over a given chunk size.
Expand Down Expand Up @@ -93,7 +94,6 @@ type Query interface {
OrWhereRaw(raw string, args []any) Query
// Pluck Get a collection instance containing the values of a given column.
Pluck(column string, dest any) error
// rollBack
// RightJoin(table string, on any, args ...any) Query
// Select Set the columns to be selected.
Select(columns ...string) Query
Expand All @@ -104,7 +104,6 @@ type Query interface {
// ToRawSql
// Update records in the database.
Update(column any, value ...any) (*Result, error)
// updateOrInsert
// Value(column string, dest any) error
// When executes the callback if the condition is true.
When(condition bool, callback func(query Query) Query) Query
Expand Down
81 changes: 70 additions & 11 deletions database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,28 @@ import (
"github.com/jmoiron/sqlx"

"github.com/goravel/framework/contracts/config"
"github.com/goravel/framework/contracts/database/db"
contractsdb "github.com/goravel/framework/contracts/database/db"
contractsdriver "github.com/goravel/framework/contracts/database/driver"
contractslogger "github.com/goravel/framework/contracts/database/logger"
"github.com/goravel/framework/contracts/log"
"github.com/goravel/framework/database/logger"
"github.com/goravel/framework/errors"
)

type DB struct {
builder db.Builder
config config.Config
ctx context.Context
db *sqlx.DB
driver contractsdriver.Driver
log log.Log
queries map[string]db.DB
logger contractslogger.Logger
queries map[string]contractsdb.DB
tx *sqlx.Tx
txLogs *[]TxLog
}

func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, builder db.Builder) *DB {
return &DB{ctx: ctx, config: config, driver: driver, log: log, builder: builder, queries: make(map[string]db.DB)}
func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, db *sqlx.DB, tx *sqlx.Tx, txLogs *[]TxLog) *DB {
return &DB{ctx: ctx, config: config, driver: driver, log: log, logger: logger.NewLogger(config, log), db: db, queries: make(map[string]contractsdb.DB), tx: tx, txLogs: txLogs}
}

func BuildDB(ctx context.Context, config config.Config, log log.Log, connection string) (*DB, error) {
Expand All @@ -43,10 +47,35 @@ func BuildDB(ctx context.Context, config config.Config, log log.Log, connection
return nil, err
}

return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver)), nil
return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver), nil, nil), nil
}

func (r *DB) Connection(name string) db.DB {
func (r *DB) BeginTransaction() (contractsdb.DB, error) {
tx, err := r.db.Beginx()
if err != nil {
return nil, err
}

return NewDB(r.ctx, r.config, r.driver, r.log, nil, tx, &[]TxLog{}), nil
}

func (r *DB) Commit() error {
if r.tx == nil {
return errors.DatabaseTransactionNotStarted
}

if err := r.tx.Commit(); err != nil {
return err
}

for _, log := range *r.txLogs {
r.logger.Trace(log.ctx, log.begin, log.sql, log.rowsAffected, log.err)
}

return nil
}

func (r *DB) Connection(name string) contractsdb.DB {
if name == "" {
name = r.config.GetString("database.default")
}
Expand All @@ -64,10 +93,40 @@ func (r *DB) Connection(name string) db.DB {
return r.queries[name]
}

func (r *DB) Table(name string) db.Query {
return NewQuery(r.ctx, r.driver, r.builder, logger.NewLogger(r.config, r.log), name)
func (r *DB) Rollback() error {
if r.tx == nil {
return errors.DatabaseTransactionNotStarted
}

return r.tx.Rollback()
}

func (r *DB) Table(name string) contractsdb.Query {
if r.tx != nil {
return NewQuery(r.ctx, r.driver, r.tx, r.logger, name, r.txLogs)
}

return NewQuery(r.ctx, r.driver, r.db, r.logger, name, nil)
}

func (r *DB) Transaction(callback func(tx contractsdb.DB) error) error {
tx, err := r.BeginTransaction()
if err != nil {
return err
}

err = callback(tx)
if err != nil {
if err := tx.Rollback(); err != nil {
return err
}

return err
}

return tx.Commit()
}

func (r *DB) WithContext(ctx context.Context) db.DB {
return NewDB(ctx, r.config, r.driver, r.log, r.builder)
func (r *DB) WithContext(ctx context.Context) contractsdb.DB {
return NewDB(ctx, r.config, r.driver, r.log, r.db, r.tx, r.txLogs)
}
2 changes: 1 addition & 1 deletion database/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestConnection(t *testing.T) {
mockDriver = mocksdriver.NewDriver(t)
mockLog = mockslog.NewLog(t)

db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil)
db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil, nil, nil)
test.setup(db)

if test.expectedPanic {
Expand Down
36 changes: 18 additions & 18 deletions database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ type Query struct {
err error
driver driver.Driver
logger logger.Logger
single bool
txLogs *[]TxLog
}

func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string) *Query {
func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string, txLogs *[]TxLog) *Query {
return &Query{
builder: builder,
conditions: Conditions{
Expand All @@ -39,16 +39,10 @@ func NewQuery(ctx context.Context, driver driver.Driver, builder db.Builder, log
ctx: ctx,
driver: driver,
logger: logger,
txLogs: txLogs,
}
}

func NewSingleQuery(ctx context.Context, driver driver.Driver, builder db.Builder, logger logger.Logger, table string) *Query {
query := NewQuery(ctx, driver, builder, logger, table)
query.single = true

return query
}

func (r *Query) Count() (int64, error) {
r.conditions.Selects = []string{"COUNT(*)"}

Expand Down Expand Up @@ -768,10 +762,10 @@ func (r *Query) buildWhere(where Where) (any, []any, error) {
}
}
return query, where.args, nil
case func(db.Query):
case func(db.Query) db.Query:
// Handle nested conditions by creating a new query and applying the callback
nestedQuery := NewSingleQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table)
query(nestedQuery)
nestedQuery := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table, r.txLogs)
nestedQuery = query(nestedQuery).(*Query)

// Build the nested conditions
sqlizer, err := r.buildWheres(nestedQuery.conditions.Where)
Expand Down Expand Up @@ -834,11 +828,7 @@ func (r *Query) buildWheres(wheres []Where) (sq.Sqlizer, error) {
}

func (r *Query) clone() *Query {
if r.single {
return r
}

query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table)
query := NewQuery(r.ctx, r.driver, r.builder, r.logger, r.conditions.Table, r.txLogs)
query.conditions = r.conditions
query.err = r.err

Expand Down Expand Up @@ -867,5 +857,15 @@ func (r *Query) toSqlizer(query any, args []any) (sq.Sqlizer, error) {
}

func (r *Query) trace(sql string, args []any, rowsAffected int64, err error) {
r.logger.Trace(r.ctx, carbon.Now(), r.driver.Explain(sql, args...), rowsAffected, err)
if r.txLogs != nil {
*r.txLogs = append(*r.txLogs, TxLog{
ctx: r.ctx,
begin: carbon.Now(),
sql: r.driver.Explain(sql, args...),
rowsAffected: rowsAffected,
err: err,
})
} else {
r.logger.Trace(r.ctx, carbon.Now(), r.driver.Explain(sql, args...), rowsAffected, err)
}
}
4 changes: 2 additions & 2 deletions database/db/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *QueryTestSuite) SetupTest() {
s.now = carbon.Now()
carbon.SetTestNow(s.now)

s.query = NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "users")
s.query = NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "users", nil)
}

func (s *QueryTestSuite) TestCount() {
Expand Down Expand Up @@ -1031,7 +1031,7 @@ func (s *QueryTestSuite) TestWhereExists() {
s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE (name = \"John\" AND EXISTS (SELECT * FROM agents WHERE age = 25))", int64(0), nil).Return().Once()

err := s.query.Where("name", "John").WhereExists(func() db.Query {
return NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "agents").Where("age", 25)
return NewQuery(s.ctx, s.mockDriver, s.mockBuilder, s.mockLogger, "agents", nil).Where("age", 25)
}).Get(&users)
s.Nil(err)

Expand Down
10 changes: 10 additions & 0 deletions database/db/utils.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
package db

import (
"context"
"reflect"
"strings"

"github.com/goravel/framework/errors"
"github.com/goravel/framework/support/carbon"
)

type TxLog struct {
ctx context.Context
begin carbon.Carbon
sql string
rowsAffected int64
err error
}

func convertToSliceMap(data any) ([]map[string]any, error) {
if data == nil {
return nil, nil
Expand Down
1 change: 1 addition & 0 deletions errors/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ var (
DatabaseFailToRunSeeder = New("fail to run seeder: %v")
DatabaseUnsupportedType = New("unsupported type: %s, expected %s")
DatabaseInvalidArgumentNumber = New("invalid argument number: %s, expected %s")
DatabaseTransactionNotStarted = New("transaction not started")

DockerUnknownContainerType = New("unknown container type")
DockerInsufficientDatabaseContainers = New("the number of database container is not enough, expect: %d, got: %d")
Expand Down
Loading
Loading