From 06bd44bee627521b3044a5b7a432f0e6a1d81bb7 Mon Sep 17 00:00:00 2001 From: Wenbo Han Date: Sun, 23 Feb 2025 10:55:30 +0800 Subject: [PATCH] feat: [#358] Add connection method for DB (#909) * feat: [#358] Add connection method for DB * add test --- contracts/database/db/db.go | 5 ++ database/db/db.go | 35 ++++++++++--- database/db/db_test.go | 96 ++++++++++++++++++++++++++++++++++-- database/service_provider.go | 2 +- mocks/database/db/DB.go | 48 ++++++++++++++++++ tests/db_test.go | 53 ++++++++++++++++++++ tests/query.go | 3 +- 7 files changed, 228 insertions(+), 14 deletions(-) diff --git a/contracts/database/db/db.go b/contracts/database/db/db.go index ca4f35803..82930f5b3 100644 --- a/contracts/database/db/db.go +++ b/contracts/database/db/db.go @@ -6,12 +6,16 @@ import ( ) type DB interface { + // BeginTransaction() Query + Connection(name string) DB Table(name string) Query + // Transaction(txFunc func(tx Query) error) error WithContext(ctx context.Context) DB } type Query interface { // Avg(column string) (any, error) + // commit // Count(dest *int64) error // Chunk(size int, callback func(rows []any) error) error // CrossJoin(table string, on any, args ...any) Query @@ -50,6 +54,7 @@ type Query interface { // OrWhereLike() // OrWhereNotLike // Pluck(column string, dest any) error + // rollBack // RightJoin(table string, on any, args ...any) Query // Select(dest any, columns ...string) error // SelectRaw(query string, args ...any) (any, error) diff --git a/database/db/db.go b/database/db/db.go index 0e948665b..a7fb31dda 100644 --- a/database/db/db.go +++ b/database/db/db.go @@ -9,7 +9,6 @@ import ( "github.com/goravel/framework/contracts/config" "github.com/goravel/framework/contracts/database/db" contractsdriver "github.com/goravel/framework/contracts/database/driver" - contractslogger "github.com/goravel/framework/contracts/database/logger" "github.com/goravel/framework/contracts/log" "github.com/goravel/framework/database/logger" "github.com/goravel/framework/errors" @@ -17,16 +16,18 @@ import ( type DB struct { builder db.Builder + config config.Config ctx context.Context driver contractsdriver.Driver - logger contractslogger.Logger + log log.Log + queries map[string]db.DB } -func NewDB(ctx context.Context, driver contractsdriver.Driver, logger contractslogger.Logger, builder db.Builder) db.DB { - return &DB{ctx: ctx, driver: driver, logger: logger, builder: builder} +func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, builder db.Builder) *DB { + return &DB{ctx: ctx, config: config, driver: driver, log: log, builder: builder, queries: make(map[string]db.DB)} } -func BuildDB(config config.Config, log log.Log, connection string) (db.DB, error) { +func BuildDB(ctx context.Context, config config.Config, log log.Log, connection string) (*DB, error) { driverCallback, exist := config.Get(fmt.Sprintf("database.connections.%s.via", connection)).(func() (contractsdriver.Driver, error)) if !exist { return nil, errors.DatabaseConfigNotFound @@ -42,13 +43,31 @@ func BuildDB(config config.Config, log log.Log, connection string) (db.DB, error return nil, err } - return NewDB(context.Background(), driver, logger.NewLogger(config, log), sqlx.NewDb(instance, driver.Config().Driver)), nil + return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver)), nil +} + +func (r *DB) Connection(name string) db.DB { + if name == "" { + name = r.config.GetString("database.default") + } + + if _, ok := r.queries[name]; !ok { + db, err := BuildDB(r.ctx, r.config, r.log, name) + if err != nil { + r.log.Panic(err.Error()) + return nil + } + r.queries[name] = db + db.queries = r.queries + } + + return r.queries[name] } func (r *DB) Table(name string) db.Query { - return NewQuery(r.ctx, r.driver, r.builder, r.logger, name) + return NewQuery(r.ctx, r.driver, r.builder, logger.NewLogger(r.config, r.log), name) } func (r *DB) WithContext(ctx context.Context) db.DB { - return NewDB(ctx, r.driver, r.logger, r.builder) + return NewDB(ctx, r.config, r.driver, r.log, r.builder) } diff --git a/database/db/db_test.go b/database/db/db_test.go index f83491e93..ea2e04657 100644 --- a/database/db/db_test.go +++ b/database/db/db_test.go @@ -1,16 +1,19 @@ package db import ( + "context" "database/sql" "testing" "github.com/stretchr/testify/assert" "github.com/goravel/framework/contracts/database" + contractsdb "github.com/goravel/framework/contracts/database/db" contractsdriver "github.com/goravel/framework/contracts/database/driver" "github.com/goravel/framework/errors" mocksconfig "github.com/goravel/framework/mocks/config" mocksdriver "github.com/goravel/framework/mocks/database/driver" + mockslog "github.com/goravel/framework/mocks/log" ) func TestBuildDB(t *testing.T) { @@ -35,8 +38,6 @@ func TestBuildDB(t *testing.T) { mockConfig.EXPECT().Get("database.connections.mysql.via").Return(driverCallback).Once() mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once() mockDriver.EXPECT().Config().Return(database.Config{Driver: "mysql"}).Once() - mockConfig.EXPECT().GetBool("app.debug").Return(false).Once() - mockConfig.EXPECT().GetInt("database.slow_threshold", 200).Return(200).Once() }, expectedError: nil, }, @@ -56,7 +57,7 @@ func TestBuildDB(t *testing.T) { mockDriver = mocksdriver.NewDriver(t) test.setup() - db, err := BuildDB(mockConfig, nil, test.connection) + db, err := BuildDB(context.Background(), mockConfig, nil, test.connection) if test.expectedError != nil { assert.Equal(t, test.expectedError, err) assert.Nil(t, db) @@ -67,3 +68,92 @@ func TestBuildDB(t *testing.T) { }) } } + +func TestConnection(t *testing.T) { + var ( + mockConfig *mocksconfig.Config + mockDriver *mocksdriver.Driver + mockLog *mockslog.Log + ) + + tests := []struct { + name string + connection string + setup func(*DB) + expectedPanic bool + }{ + { + name: "Success with empty connection name", + connection: "", + setup: func(db *DB) { + mockConfig.EXPECT().GetString("database.default").Return("mysql").Once() + driverCallback := func() (contractsdriver.Driver, error) { + return mockDriver, nil + } + mockConfig.EXPECT().Get("database.connections.mysql.via").Return(driverCallback).Once() + mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once() + mockDriver.EXPECT().Config().Return(database.Config{Driver: "mysql"}).Once() + }, + expectedPanic: false, + }, + { + name: "Success with specific connection", + connection: "postgres", + setup: func(db *DB) { + driverCallback := func() (contractsdriver.Driver, error) { + return mockDriver, nil + } + mockConfig.EXPECT().Get("database.connections.postgres.via").Return(driverCallback).Once() + mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once() + mockDriver.EXPECT().Config().Return(database.Config{Driver: "postgres"}).Once() + }, + expectedPanic: false, + }, + { + name: "Return cached connection", + connection: "mysql", + setup: func(db *DB) { + driverCallback := func() (contractsdriver.Driver, error) { + return mockDriver, nil + } + mockConfig.EXPECT().Get("database.connections.mysql.via").Return(driverCallback).Once() + mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once() + mockDriver.EXPECT().Config().Return(database.Config{Driver: "mysql"}).Once() + + cachedDB, _ := BuildDB(context.Background(), mockConfig, mockLog, "mysql") + db.queries = map[string]contractsdb.DB{"mysql": cachedDB} + }, + expectedPanic: false, + }, + { + name: "Panic on BuildDB error", + connection: "invalid", + setup: func(db *DB) { + mockConfig.EXPECT().Get("database.connections.invalid.via").Return(nil).Once() + mockLog.EXPECT().Panic(errors.DatabaseConfigNotFound.Error()).Once() + }, + expectedPanic: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mockConfig = mocksconfig.NewConfig(t) + mockDriver = mocksdriver.NewDriver(t) + mockLog = mockslog.NewLog(t) + + db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil) + test.setup(db) + + if test.expectedPanic { + assert.NotPanics(t, func() { + result := db.Connection(test.connection) + assert.Nil(t, result) + }) + } else { + result := db.Connection(test.connection) + assert.NotNil(t, result) + } + }) + } +} diff --git a/database/service_provider.go b/database/service_provider.go index 5509c1769..03ff55673 100644 --- a/database/service_provider.go +++ b/database/service_provider.go @@ -66,7 +66,7 @@ func (r *ServiceProvider) Register(app foundation.Application) { return nil, nil } - return db.BuildDB(config, log, connection) + return db.BuildDB(context.Background(), config, log, connection) }) app.Singleton(contracts.BindingSchema, func(app foundation.Application) (any, error) { diff --git a/mocks/database/db/DB.go b/mocks/database/db/DB.go index faaf9228e..15f348fd4 100644 --- a/mocks/database/db/DB.go +++ b/mocks/database/db/DB.go @@ -22,6 +22,54 @@ func (_m *DB) EXPECT() *DB_Expecter { return &DB_Expecter{mock: &_m.Mock} } +// Connection provides a mock function with given fields: name +func (_m *DB) Connection(name string) db.DB { + ret := _m.Called(name) + + if len(ret) == 0 { + panic("no return value specified for Connection") + } + + var r0 db.DB + if rf, ok := ret.Get(0).(func(string) db.DB); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(db.DB) + } + } + + return r0 +} + +// DB_Connection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Connection' +type DB_Connection_Call struct { + *mock.Call +} + +// Connection is a helper method to define mock.On call +// - name string +func (_e *DB_Expecter) Connection(name interface{}) *DB_Connection_Call { + return &DB_Connection_Call{Call: _e.mock.On("Connection", name)} +} + +func (_c *DB_Connection_Call) Run(run func(name string)) *DB_Connection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *DB_Connection_Call) Return(_a0 db.DB) *DB_Connection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *DB_Connection_Call) RunAndReturn(run func(string) db.DB) *DB_Connection_Call { + _c.Call.Return(run) + return _c +} + // Table provides a mock function with given fields: name func (_m *DB) Table(name string) db.Query { ret := _m.Called(name) diff --git a/tests/db_test.go b/tests/db_test.go index 4e44ee69a..5148ce5fb 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -7,6 +7,7 @@ import ( "github.com/goravel/framework/support/carbon" "github.com/goravel/sqlite" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -240,3 +241,55 @@ func (s *DBTestSuite) TestWhere() { }) } } + +func TestDB_Connection(t *testing.T) { + t.Parallel() + postgresTestQuery := NewTestQueryBuilder().Postgres("", false) + postgresTestQuery.CreateTable(TestTableProducts) + + sqliteTestQuery := NewTestQueryBuilder().Sqlite("", false) + sqliteTestQuery.CreateTable(TestTableProducts) + defer func() { + docker, err := sqliteTestQuery.Driver().Docker() + assert.NoError(t, err) + assert.NoError(t, docker.Shutdown()) + }() + + sqliteConnection := sqliteTestQuery.Driver().Config().Connection + mockDatabaseConfig(postgresTestQuery.MockConfig(), sqliteTestQuery.Driver().Config(), sqliteConnection, "", false) + + result, err := postgresTestQuery.DB().Table("products").Insert(Product{ + Name: "connection", + }) + + assert.NoError(t, err) + assert.Equal(t, int64(1), result.RowsAffected) + + var product Product + err = postgresTestQuery.DB().Table("products").Where("name", "connection").First(&product) + assert.NoError(t, err) + assert.True(t, product.ID > 0) + assert.Equal(t, "connection", product.Name) + + var product1 Product + err = postgresTestQuery.DB().Connection(sqliteConnection).Table("products").Where("name", "connection").First(&product1) + assert.NoError(t, err) + assert.True(t, product1.ID == 0) + + result, err = postgresTestQuery.DB().Connection(sqliteConnection).Table("products").Insert(Product{ + Name: "sqlite connection", + }) + assert.NoError(t, err) + assert.Equal(t, int64(1), result.RowsAffected) + + var product2 Product + err = postgresTestQuery.DB().Connection(sqliteConnection).Table("products").Where("name", "sqlite connection").First(&product2) + assert.NoError(t, err) + assert.True(t, product2.ID > 0) + assert.Equal(t, "sqlite connection", product2.Name) + + var product3 Product + err = postgresTestQuery.DB().Table("products").Where("name", "sqlite connection").First(&product3) + assert.NoError(t, err) + assert.True(t, product3.ID == 0) +} diff --git a/tests/query.go b/tests/query.go index 0f037732d..2250ccfa6 100644 --- a/tests/query.go +++ b/tests/query.go @@ -13,7 +13,6 @@ import ( contractsdocker "github.com/goravel/framework/contracts/testing/docker" databasedb "github.com/goravel/framework/database/db" "github.com/goravel/framework/database/gorm" - "github.com/goravel/framework/database/logger" mocksconfig "github.com/goravel/framework/mocks/config" "github.com/goravel/framework/support/docker" "github.com/goravel/framework/support/str" @@ -49,7 +48,7 @@ func NewTestQuery(ctx context.Context, driver contractsdriver.Driver, config con testQuery := &TestQuery{ config: config, - db: databasedb.NewDB(ctx, driver, logger.NewLogger(config, utils.NewTestLog()), sqlx.NewDb(db, driver.Config().Driver)), + db: databasedb.NewDB(ctx, config, driver, utils.NewTestLog(), sqlx.NewDb(db, driver.Config().Driver)), driver: driver, query: gorm.NewQuery(ctx, config, driver.Config(), query, gormQuery, utils.NewTestLog(), nil, nil), }