Skip to content

Commit

Permalink
feat: [#540] Remove contracts/database/config.go::Driver (#823)
Browse files Browse the repository at this point in the history
* update tests mod

* feat: [#540] Remove config.go::Driver

* add test

* rename

* rename

* fix test

* remove files

* chore: update mocks

* fix test

* fix test

* remove hint

* fix test

---------

Co-authored-by: hwbrzzl <[email protected]>
  • Loading branch information
hwbrzzl and hwbrzzl authored Jan 21, 2025
1 parent cdd2b8b commit 29c6fd4
Show file tree
Hide file tree
Showing 58 changed files with 568 additions and 5,565 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/check_pr_title.yml.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: Check PR Title
on:
pull_request:
jobs:
check:
uses: goravel/.github/.github/workflows/check_pr_title.yml@master
secrets: inherit

11 changes: 0 additions & 11 deletions .github/workflows/pr-check-title.yml

This file was deleted.

22 changes: 22 additions & 0 deletions .github/workflows/test_external.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: Test External
on:
push:
branches:
- master
pull_request:
jobs:
ubuntu:
strategy:
matrix:
go: [ "1.22", "1.23" ]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- name: Install dependencies
run: go mod tidy
- name: Run tests
run: cd tests && go test -timeout 1h ./...

9 changes: 8 additions & 1 deletion contracts/database/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package driver

import (
"gorm.io/gorm"
"gorm.io/gorm/clause"

"github.com/goravel/framework/contracts/database"
"github.com/goravel/framework/contracts/database/orm"
Expand All @@ -12,8 +13,14 @@ import (
type Driver interface {
Config() database.Config
Docker() (testing.DatabaseDriver, error)
Gorm() (*gorm.DB, error)
Gorm() (*gorm.DB, GormQuery, error)
Grammar() schema.Grammar
Processor() schema.Processor
Schema(orm.Orm) schema.DriverSchema
}

type GormQuery interface {
LockForUpdate() clause.Expression
RandomOrder() string
SharedLock() clause.Expression
}
4 changes: 1 addition & 3 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package orm
import (
"context"
"database/sql"

"github.com/goravel/framework/contracts/database"
)

type Orm interface {
Expand Down Expand Up @@ -52,7 +50,7 @@ type Query interface {
// Distinct specifies distinct fields to query.
Distinct(args ...any) Query
// Driver gets the driver for the query.
Driver() database.Driver
Driver() string
// Exec executes raw sql
Exec(sql string, values ...any) (*Result, error)
// Exists returns true if matching records exist; otherwise, it returns false.
Expand Down
15 changes: 7 additions & 8 deletions database/console/show_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,35 +106,34 @@ func (r *ShowCommand) Handle(ctx console.Context) error {

func (r *ShowCommand) getDataBaseInfo() (name, version, openConnections string, err error) {
var (
drivers = map[database.Driver]struct {
drivers = map[string]struct {
name string
versionQuery string
openConnectionsQuery string
}{
database.DriverSqlite: {
database.DriverSqlite.String(): {
name: "SQLite",
versionQuery: "SELECT sqlite_version() AS value;",
},
database.DriverMysql: {
database.DriverMysql.String(): {
name: "MySQL",
versionQuery: "SELECT VERSION() AS value;",
openConnectionsQuery: "SHOW status WHERE variable_name = 'threads_connected';",
},
database.DriverPostgres: {
database.DriverPostgres.String(): {
name: "PostgresSQL",
versionQuery: "SELECT current_setting('server_version') AS value;",
openConnectionsQuery: "SELECT COUNT(*) AS value FROM pg_stat_activity;",
},
database.DriverSqlserver: {
database.DriverSqlserver.String(): {
name: "SQL Server",
versionQuery: "SELECT SERVERPROPERTY('productversion') AS value;",
openConnectionsQuery: "SELECT COUNT(*) Value FROM sys.dm_exec_sessions WHERE status = 'running';",
},
}
)
name = string(r.schema.Orm().Query().Driver())
if driver, ok := drivers[r.schema.Orm().Query().Driver()]; ok {
name = driver.name
name = r.schema.Orm().Query().Driver()
if driver, ok := drivers[name]; ok {
var versionResult queryResult
if err = r.schema.Orm().Query().Raw(driver.versionQuery).Scan(&versionResult); err == nil {
version = versionResult.Value
Expand Down
19 changes: 9 additions & 10 deletions database/console/show_command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/stretchr/testify/assert"

"github.com/goravel/framework/contracts/database"
"github.com/goravel/framework/contracts/database/schema"
mocksconfig "github.com/goravel/framework/mocks/config"
mocksconsole "github.com/goravel/framework/mocks/console"
Expand Down Expand Up @@ -65,9 +64,9 @@ func TestShowCommand(t *testing.T) {
mockConfig.EXPECT().GetString("database.connections.test.host").Return("host").Once()
mockConfig.EXPECT().GetString("database.connections.test.port").Return("port").Once()
mockConfig.EXPECT().GetString("database.connections.test.username").Return("username").Once()
mockQuery.EXPECT().Driver().Return(database.DriverMysql).Twice()
mockOrm.EXPECT().Query().Return(mockQuery).Times(4)
mockSchema.EXPECT().Orm().Return(mockOrm).Times(4)
mockOrm.EXPECT().Query().Return(mockQuery).Times(3)
mockSchema.EXPECT().Orm().Return(mockOrm).Times(3)
mockQuery.EXPECT().Driver().Return("mysql").Once()
mockQuery.EXPECT().Raw("SELECT VERSION() AS value;").Return(mockQuery).Once()
mockQuery.EXPECT().Raw("SHOW status WHERE variable_name = 'threads_connected';").Return(mockQuery).Once()
mockQuery.EXPECT().Scan(&queryResult{}).Return(nil).Twice()
Expand All @@ -86,9 +85,9 @@ func TestShowCommand(t *testing.T) {
mockConfig.EXPECT().GetString("database.connections.test.host").Return("host").Once()
mockConfig.EXPECT().GetString("database.connections.test.port").Return("port").Once()
mockConfig.EXPECT().GetString("database.connections.test.username").Return("username").Once()
mockQuery.EXPECT().Driver().Return(database.DriverMysql).Twice()
mockOrm.EXPECT().Query().Return(mockQuery).Times(4)
mockSchema.EXPECT().Orm().Return(mockOrm).Times(4)
mockOrm.EXPECT().Query().Return(mockQuery).Times(3)
mockSchema.EXPECT().Orm().Return(mockOrm).Times(3)
mockQuery.EXPECT().Driver().Return("mysql").Once()
mockQuery.EXPECT().Raw("SELECT VERSION() AS value;").Return(mockQuery).Once()
mockQuery.EXPECT().Raw("SHOW status WHERE variable_name = 'threads_connected';").Return(mockQuery).Once()
mockQuery.EXPECT().Scan(&queryResult{}).Return(nil).Twice()
Expand All @@ -109,9 +108,9 @@ func TestShowCommand(t *testing.T) {
mockConfig.EXPECT().GetString("database.connections.test.host").Return("host").Once()
mockConfig.EXPECT().GetString("database.connections.test.port").Return("port").Once()
mockConfig.EXPECT().GetString("database.connections.test.username").Return("username").Once()
mockQuery.EXPECT().Driver().Return(database.DriverMysql).Twice()
mockOrm.EXPECT().Query().Return(mockQuery).Times(5)
mockSchema.EXPECT().Orm().Return(mockOrm).Times(5)
mockOrm.EXPECT().Query().Return(mockQuery).Times(4)
mockSchema.EXPECT().Orm().Return(mockOrm).Times(4)
mockQuery.EXPECT().Driver().Return("mysql").Once()
mockQuery.EXPECT().Raw("SELECT VERSION() AS value;").Return(mockQuery).Once()
mockQuery.EXPECT().Raw("SHOW status WHERE variable_name = 'threads_connected';").Return(mockQuery).Once()
mockQuery.EXPECT().Scan(&queryResult{}).Run(func(dest interface{}) {
Expand Down
2 changes: 1 addition & 1 deletion database/gorm/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (e *Event) IsDirty(columns ...string) bool {
}

func (e *Event) Query() orm.Query {
return NewQuery(e.query.ctx, e.query.config, e.query.dbConfig, e.query.instance.Session(&gorm.Session{NewDB: true}), e.query.log, e.query.modelToObserver, nil)
return NewQuery(e.query.ctx, e.query.config, e.query.dbConfig, e.query.instance.Session(&gorm.Session{NewDB: true}), e.query.gormQuery, e.query.log, e.query.modelToObserver, nil)
}

func (e *Event) SetAttribute(key string, value any) {
Expand Down
58 changes: 0 additions & 58 deletions database/gorm/hints/with_hint.go

This file was deleted.

47 changes: 16 additions & 31 deletions database/gorm/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/goravel/framework/contracts/database/driver"
contractsorm "github.com/goravel/framework/contracts/database/orm"
"github.com/goravel/framework/contracts/log"
"github.com/goravel/framework/database/gorm/hints"
"github.com/goravel/framework/errors"
"github.com/goravel/framework/support/database"
)
Expand All @@ -29,6 +28,7 @@ type Query struct {
ctx context.Context
dbConfig contractsdatabase.Config
instance *gormio.DB
gormQuery driver.GormQuery
log log.Log
modelToObserver []contractsorm.ModelToObserver
mutex sync.Mutex
Expand All @@ -40,6 +40,7 @@ func NewQuery(
config config.Config,
dbConfig contractsdatabase.Config,
db *gormio.DB,
gormQuery driver.GormQuery,
log log.Log,
modelToObserver []contractsorm.ModelToObserver,
conditions *Conditions,
Expand All @@ -49,6 +50,7 @@ func NewQuery(
ctx: ctx,
dbConfig: dbConfig,
instance: db,
gormQuery: gormQuery,
log: log,
modelToObserver: modelToObserver,
queries: make(map[string]*Query),
Expand All @@ -72,12 +74,12 @@ func BuildQuery(ctx context.Context, config config.Config, connection string, lo
return nil, err
}

gorm, err := driver.Gorm()
gorm, gormQuery, err := driver.Gorm()
if err != nil {
return nil, err
}

return NewQuery(ctx, config, driver.Config(), gorm, log, modelToObserver, nil), nil
return NewQuery(ctx, config, driver.Config(), gorm, gormQuery, log, modelToObserver, nil), nil
}

func (r *Query) Association(association string) contractsorm.Association {
Expand Down Expand Up @@ -204,8 +206,8 @@ func (r *Query) Distinct(args ...any) contractsorm.Query {
return r.setConditions(conditions)
}

func (r *Query) Driver() contractsdatabase.Driver {
return contractsdatabase.Driver(r.instance.Dialector.Name())
func (r *Query) Driver() string {
return r.dbConfig.Driver
}

func (r *Query) Exec(sql string, values ...any) (*contractsorm.Result, error) {
Expand Down Expand Up @@ -579,18 +581,7 @@ func (r *Query) Instance() *gormio.DB {
}

func (r *Query) InRandomOrder() contractsorm.Query {
order := ""
switch r.Driver() {
case contractsdatabase.DriverMysql:
order = "RAND()"
case contractsdatabase.DriverSqlserver:
order = "NEWID()"
case contractsdatabase.DriverPostgres:
order = "RANDOM()"
case contractsdatabase.DriverSqlite:
order = "RANDOM()"
}
return r.Order(order)
return r.Order(r.gormQuery.RandomOrder())
}

func (r *Query) InTransaction() bool {
Expand Down Expand Up @@ -1050,12 +1041,9 @@ func (r *Query) buildLockForUpdate(db *gormio.DB) *gormio.DB {
return db
}

driver := r.instance.Name()
// TODO To Check if the hardcoded driver names can be optimized
if driver == "mysql" || driver == "postgres" {
return db.Clauses(clause.Locking{Strength: "UPDATE"})
} else if driver == "sqlserver" {
return db.Clauses(hints.With("rowlock", "updlock", "holdlock"))
lockForUpdate := r.gormQuery.LockForUpdate()
if lockForUpdate != nil {
return db.Clauses(lockForUpdate)
}

r.conditions.lockForUpdate = false
Expand Down Expand Up @@ -1155,12 +1143,9 @@ func (r *Query) buildSharedLock(db *gormio.DB) *gormio.DB {
return db
}

driver := r.instance.Name()
// TODO To Check if the hardcoded driver names can be optimized
if driver == "mysql" || driver == "postgres" {
return db.Clauses(clause.Locking{Strength: "SHARE"})
} else if driver == "sqlserver" {
return db.Clauses(hints.With("rowlock", "holdlock"))
sharedLock := r.gormQuery.SharedLock()
if sharedLock != nil {
return db.Clauses(sharedLock)
}

r.conditions.sharedLock = false
Expand Down Expand Up @@ -1208,7 +1193,7 @@ func (r *Query) buildWith(db *gormio.DB) *gormio.DB {
if arg, ok := item.args[0].(func(contractsorm.Query) contractsorm.Query); ok {
newArgs := []any{
func(tx *gormio.DB) *gormio.DB {
queryImpl := NewQuery(r.ctx, r.config, r.dbConfig, tx, r.log, r.modelToObserver, nil)
queryImpl := NewQuery(r.ctx, r.config, r.dbConfig, tx, r.gormQuery, r.log, r.modelToObserver, nil)
query := arg(queryImpl)
queryImpl = query.(*Query)
queryImpl = queryImpl.buildConditions()
Expand Down Expand Up @@ -1361,7 +1346,7 @@ func (r *Query) getObserver(dest any) contractsorm.Observer {
}

func (r *Query) new(db *gormio.DB) *Query {
return NewQuery(r.ctx, r.config, r.dbConfig, db, r.log, r.modelToObserver, &r.conditions)
return NewQuery(r.ctx, r.config, r.dbConfig, db, r.gormQuery, r.log, r.modelToObserver, &r.conditions)
}

func (r *Query) omitCreate(value any) error {
Expand Down
Loading

0 comments on commit 29c6fd4

Please sign in to comment.