Skip to content

Commit

Permalink
feat: [#358] Add connection method for DB (#909)
Browse files Browse the repository at this point in the history
* feat: [#358] Add connection method for DB

* add test
  • Loading branch information
hwbrzzl authored Feb 23, 2025
1 parent d607bf6 commit 06bd44b
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 14 deletions.
5 changes: 5 additions & 0 deletions contracts/database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ import (
)

type DB interface {
// BeginTransaction() Query
Connection(name string) DB
Table(name string) Query
// Transaction(txFunc func(tx Query) error) error
WithContext(ctx context.Context) DB
}

type Query interface {
// Avg(column string) (any, error)
// commit
// Count(dest *int64) error
// Chunk(size int, callback func(rows []any) error) error
// CrossJoin(table string, on any, args ...any) Query
Expand Down Expand Up @@ -50,6 +54,7 @@ type Query interface {
// OrWhereLike()
// OrWhereNotLike
// Pluck(column string, dest any) error
// rollBack
// RightJoin(table string, on any, args ...any) Query
// Select(dest any, columns ...string) error
// SelectRaw(query string, args ...any) (any, error)
Expand Down
35 changes: 27 additions & 8 deletions database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,25 @@ import (
"github.com/goravel/framework/contracts/config"
"github.com/goravel/framework/contracts/database/db"
contractsdriver "github.com/goravel/framework/contracts/database/driver"
contractslogger "github.com/goravel/framework/contracts/database/logger"
"github.com/goravel/framework/contracts/log"
"github.com/goravel/framework/database/logger"
"github.com/goravel/framework/errors"
)

type DB struct {
builder db.Builder
config config.Config
ctx context.Context
driver contractsdriver.Driver
logger contractslogger.Logger
log log.Log
queries map[string]db.DB
}

func NewDB(ctx context.Context, driver contractsdriver.Driver, logger contractslogger.Logger, builder db.Builder) db.DB {
return &DB{ctx: ctx, driver: driver, logger: logger, builder: builder}
func NewDB(ctx context.Context, config config.Config, driver contractsdriver.Driver, log log.Log, builder db.Builder) *DB {
return &DB{ctx: ctx, config: config, driver: driver, log: log, builder: builder, queries: make(map[string]db.DB)}
}

func BuildDB(config config.Config, log log.Log, connection string) (db.DB, error) {
func BuildDB(ctx context.Context, config config.Config, log log.Log, connection string) (*DB, error) {
driverCallback, exist := config.Get(fmt.Sprintf("database.connections.%s.via", connection)).(func() (contractsdriver.Driver, error))
if !exist {
return nil, errors.DatabaseConfigNotFound
Expand All @@ -42,13 +43,31 @@ func BuildDB(config config.Config, log log.Log, connection string) (db.DB, error
return nil, err
}

return NewDB(context.Background(), driver, logger.NewLogger(config, log), sqlx.NewDb(instance, driver.Config().Driver)), nil
return NewDB(ctx, config, driver, log, sqlx.NewDb(instance, driver.Config().Driver)), nil
}

func (r *DB) Connection(name string) db.DB {
if name == "" {
name = r.config.GetString("database.default")
}

if _, ok := r.queries[name]; !ok {
db, err := BuildDB(r.ctx, r.config, r.log, name)
if err != nil {
r.log.Panic(err.Error())
return nil
}
r.queries[name] = db
db.queries = r.queries
}

return r.queries[name]
}

func (r *DB) Table(name string) db.Query {
return NewQuery(r.ctx, r.driver, r.builder, r.logger, name)
return NewQuery(r.ctx, r.driver, r.builder, logger.NewLogger(r.config, r.log), name)
}

func (r *DB) WithContext(ctx context.Context) db.DB {
return NewDB(ctx, r.driver, r.logger, r.builder)
return NewDB(ctx, r.config, r.driver, r.log, r.builder)
}
96 changes: 93 additions & 3 deletions database/db/db_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package db

import (
"context"
"database/sql"
"testing"

"github.com/stretchr/testify/assert"

"github.com/goravel/framework/contracts/database"
contractsdb "github.com/goravel/framework/contracts/database/db"
contractsdriver "github.com/goravel/framework/contracts/database/driver"
"github.com/goravel/framework/errors"
mocksconfig "github.com/goravel/framework/mocks/config"
mocksdriver "github.com/goravel/framework/mocks/database/driver"
mockslog "github.com/goravel/framework/mocks/log"
)

func TestBuildDB(t *testing.T) {
Expand All @@ -35,8 +38,6 @@ func TestBuildDB(t *testing.T) {
mockConfig.EXPECT().Get("database.connections.mysql.via").Return(driverCallback).Once()
mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once()
mockDriver.EXPECT().Config().Return(database.Config{Driver: "mysql"}).Once()
mockConfig.EXPECT().GetBool("app.debug").Return(false).Once()
mockConfig.EXPECT().GetInt("database.slow_threshold", 200).Return(200).Once()
},
expectedError: nil,
},
Expand All @@ -56,7 +57,7 @@ func TestBuildDB(t *testing.T) {
mockDriver = mocksdriver.NewDriver(t)
test.setup()

db, err := BuildDB(mockConfig, nil, test.connection)
db, err := BuildDB(context.Background(), mockConfig, nil, test.connection)
if test.expectedError != nil {
assert.Equal(t, test.expectedError, err)
assert.Nil(t, db)
Expand All @@ -67,3 +68,92 @@ func TestBuildDB(t *testing.T) {
})
}
}

func TestConnection(t *testing.T) {
var (
mockConfig *mocksconfig.Config
mockDriver *mocksdriver.Driver
mockLog *mockslog.Log
)

tests := []struct {
name string
connection string
setup func(*DB)
expectedPanic bool
}{
{
name: "Success with empty connection name",
connection: "",
setup: func(db *DB) {
mockConfig.EXPECT().GetString("database.default").Return("mysql").Once()
driverCallback := func() (contractsdriver.Driver, error) {
return mockDriver, nil
}
mockConfig.EXPECT().Get("database.connections.mysql.via").Return(driverCallback).Once()
mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once()
mockDriver.EXPECT().Config().Return(database.Config{Driver: "mysql"}).Once()
},
expectedPanic: false,
},
{
name: "Success with specific connection",
connection: "postgres",
setup: func(db *DB) {
driverCallback := func() (contractsdriver.Driver, error) {
return mockDriver, nil
}
mockConfig.EXPECT().Get("database.connections.postgres.via").Return(driverCallback).Once()
mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once()
mockDriver.EXPECT().Config().Return(database.Config{Driver: "postgres"}).Once()
},
expectedPanic: false,
},
{
name: "Return cached connection",
connection: "mysql",
setup: func(db *DB) {
driverCallback := func() (contractsdriver.Driver, error) {
return mockDriver, nil
}
mockConfig.EXPECT().Get("database.connections.mysql.via").Return(driverCallback).Once()
mockDriver.EXPECT().DB().Return(&sql.DB{}, nil).Once()
mockDriver.EXPECT().Config().Return(database.Config{Driver: "mysql"}).Once()

cachedDB, _ := BuildDB(context.Background(), mockConfig, mockLog, "mysql")
db.queries = map[string]contractsdb.DB{"mysql": cachedDB}
},
expectedPanic: false,
},
{
name: "Panic on BuildDB error",
connection: "invalid",
setup: func(db *DB) {
mockConfig.EXPECT().Get("database.connections.invalid.via").Return(nil).Once()
mockLog.EXPECT().Panic(errors.DatabaseConfigNotFound.Error()).Once()
},
expectedPanic: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
mockConfig = mocksconfig.NewConfig(t)
mockDriver = mocksdriver.NewDriver(t)
mockLog = mockslog.NewLog(t)

db := NewDB(context.Background(), mockConfig, mockDriver, mockLog, nil)
test.setup(db)

if test.expectedPanic {
assert.NotPanics(t, func() {
result := db.Connection(test.connection)
assert.Nil(t, result)
})
} else {
result := db.Connection(test.connection)
assert.NotNil(t, result)
}
})
}
}
2 changes: 1 addition & 1 deletion database/service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (r *ServiceProvider) Register(app foundation.Application) {
return nil, nil
}

return db.BuildDB(config, log, connection)
return db.BuildDB(context.Background(), config, log, connection)
})

app.Singleton(contracts.BindingSchema, func(app foundation.Application) (any, error) {
Expand Down
48 changes: 48 additions & 0 deletions mocks/database/db/DB.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions tests/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/goravel/framework/support/carbon"
"github.com/goravel/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)

Expand Down Expand Up @@ -240,3 +241,55 @@ func (s *DBTestSuite) TestWhere() {
})
}
}

func TestDB_Connection(t *testing.T) {
t.Parallel()
postgresTestQuery := NewTestQueryBuilder().Postgres("", false)
postgresTestQuery.CreateTable(TestTableProducts)

sqliteTestQuery := NewTestQueryBuilder().Sqlite("", false)
sqliteTestQuery.CreateTable(TestTableProducts)
defer func() {
docker, err := sqliteTestQuery.Driver().Docker()
assert.NoError(t, err)
assert.NoError(t, docker.Shutdown())
}()

sqliteConnection := sqliteTestQuery.Driver().Config().Connection
mockDatabaseConfig(postgresTestQuery.MockConfig(), sqliteTestQuery.Driver().Config(), sqliteConnection, "", false)

result, err := postgresTestQuery.DB().Table("products").Insert(Product{
Name: "connection",
})

assert.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected)

var product Product
err = postgresTestQuery.DB().Table("products").Where("name", "connection").First(&product)
assert.NoError(t, err)
assert.True(t, product.ID > 0)
assert.Equal(t, "connection", product.Name)

var product1 Product
err = postgresTestQuery.DB().Connection(sqliteConnection).Table("products").Where("name", "connection").First(&product1)
assert.NoError(t, err)
assert.True(t, product1.ID == 0)

result, err = postgresTestQuery.DB().Connection(sqliteConnection).Table("products").Insert(Product{
Name: "sqlite connection",
})
assert.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected)

var product2 Product
err = postgresTestQuery.DB().Connection(sqliteConnection).Table("products").Where("name", "sqlite connection").First(&product2)
assert.NoError(t, err)
assert.True(t, product2.ID > 0)
assert.Equal(t, "sqlite connection", product2.Name)

var product3 Product
err = postgresTestQuery.DB().Table("products").Where("name", "sqlite connection").First(&product3)
assert.NoError(t, err)
assert.True(t, product3.ID == 0)
}
3 changes: 1 addition & 2 deletions tests/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
contractsdocker "github.com/goravel/framework/contracts/testing/docker"
databasedb "github.com/goravel/framework/database/db"
"github.com/goravel/framework/database/gorm"
"github.com/goravel/framework/database/logger"
mocksconfig "github.com/goravel/framework/mocks/config"
"github.com/goravel/framework/support/docker"
"github.com/goravel/framework/support/str"
Expand Down Expand Up @@ -49,7 +48,7 @@ func NewTestQuery(ctx context.Context, driver contractsdriver.Driver, config con

testQuery := &TestQuery{
config: config,
db: databasedb.NewDB(ctx, driver, logger.NewLogger(config, utils.NewTestLog()), sqlx.NewDb(db, driver.Config().Driver)),
db: databasedb.NewDB(ctx, config, driver, utils.NewTestLog(), sqlx.NewDb(db, driver.Config().Driver)),
driver: driver,
query: gorm.NewQuery(ctx, config, driver.Config(), query, gormQuery, utils.NewTestLog(), nil, nil),
}
Expand Down

0 comments on commit 06bd44b

Please sign in to comment.