Skip to content

Commit

Permalink
finish
Browse files Browse the repository at this point in the history
  • Loading branch information
hwbrzzl committed Feb 18, 2025
1 parent fcb9e89 commit d01662c
Show file tree
Hide file tree
Showing 12 changed files with 979 additions and 128 deletions.
17 changes: 14 additions & 3 deletions contracts/database/db/db.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
package db

import "database/sql"

type DB interface {
Table(name string) Query
}

type Query interface {
Delete() error
First(dest any) error
Get(dest any) error
Insert() error
Update() error
Insert(data any) (*Result, error)
Where(query any, args ...any) Query
}

type Result struct {
RowsAffected int64
}

type Builder interface {
Exec(query string, args ...any) (sql.Result, error)
Get(dest any, query string, args ...any) error
Select(dest any, query string, args ...any) error
}
10 changes: 5 additions & 5 deletions database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ import (
)

type DB struct {
config database.Config
instance *sqlx.DB
builder db.Builder
config database.Config
}

func NewDB(config database.Config, instance *sqlx.DB) db.DB {
return &DB{config: config, instance: instance}
func NewDB(config database.Config, builder db.Builder) db.DB {
return &DB{config: config, builder: builder}
}

func BuildDB(config config.Config, connection string) (db.DB, error) {
Expand All @@ -41,5 +41,5 @@ func BuildDB(config config.Config, connection string) (db.DB, error) {
}

func (r *DB) Table(name string) db.Query {
return NewQuery(r.config, r.instance, name)
return NewQuery(r.config, r.builder, name)
}
67 changes: 67 additions & 0 deletions database/db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package db

import (
"database/sql"
"testing"

"github.com/stretchr/testify/assert"

"github.com/goravel/framework/contracts/database"
contractsdriver "github.com/goravel/framework/contracts/database/driver"
"github.com/goravel/framework/errors"
mocksconfig "github.com/goravel/framework/mocks/config"
mocksdriver "github.com/goravel/framework/mocks/database/driver"
)

func TestBuildDB(t *testing.T) {
var (
mockConfig *mocksconfig.Config
mockDriver *mocksdriver.Driver
)

tests := []struct {
name string
connection string
setup func()
expectedError error
}{
{
name: "Success",
connection: "mysql",
setup: func() {
driverCallback := func() (contractsdriver.Driver, error) {
return mockDriver, nil
}
mockConfig.On("Get", "database.connections.mysql.via").Return(driverCallback)
mockDriver.On("DB").Return(&sql.DB{}, nil)
mockDriver.On("Config").Return(database.Config{Driver: "mysql"})
},
expectedError: nil,
},
{
name: "Config Not Found",
connection: "invalid",
setup: func() {
mockConfig.On("Get", "database.connections.invalid.via").Return(nil)
},
expectedError: errors.DatabaseConfigNotFound,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
mockConfig = mocksconfig.NewConfig(t)
mockDriver = mocksdriver.NewDriver(t)
test.setup()

db, err := BuildDB(mockConfig, test.connection)
if test.expectedError != nil {
assert.Equal(t, test.expectedError, err)
assert.Nil(t, db)
} else {
assert.NoError(t, err)
assert.NotNil(t, db)
}
})
}
}
112 changes: 76 additions & 36 deletions database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,41 @@ package db

import (
"fmt"
"sort"

sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"

"github.com/goravel/framework/contracts/database"
"github.com/goravel/framework/contracts/database/db"
"github.com/goravel/framework/errors"
"github.com/goravel/framework/support/str"
)

type Query struct {
builder db.Builder
conditions Conditions
config database.Config
instance *sqlx.DB
}

func NewQuery(config database.Config, instance *sqlx.DB, table string) *Query {
func NewQuery(config database.Config, builder db.Builder, table string) *Query {
return &Query{
builder: builder,
conditions: Conditions{
table: table,
},
config: config,
instance: instance,
config: config,
}
}

func (r *Query) Delete() error {
return nil
}

func (r *Query) Where(query any, args ...any) db.Query {
r.conditions.where = append(r.conditions.where, Where{
query: query,
args: args,
})
func (r *Query) First(dest any) error {
sql, args, err := r.buildSelect()
// TODO: use logger instead of println
fmt.Println(sql, args, err)
if err != nil {
return err
}

return r
return r.builder.Get(dest, sql, args...)
}

func (r *Query) Get(dest any) error {
Expand All @@ -48,50 +47,79 @@ func (r *Query) Get(dest any) error {
return err
}

return r.instance.Select(dest, sql, args...)
return r.builder.Select(dest, sql, args...)
}

func (r *Query) Insert() error {
return nil
}

func (r *Query) Update() error {
return nil
}
func (r *Query) Insert(data any) (*db.Result, error) {
mapData, err := convertToSliceMap(data)
if err != nil {
return nil, err
}

func (r *Query) buildInsert() (sql string, args []any, err error) {
if r.conditions.table == "" {
return "", nil, errors.DatabaseTableIsRequired
sql, args, err := r.buildInsert(mapData)
if err != nil {
return nil, err
}
// TODO: use logger instead of println
fmt.Println(sql, args, err)
result, err := r.builder.Exec(sql, args...)
if err != nil {
return nil, err
}

builder := sq.Insert(r.conditions.table)
if r.config.PlaceholderFormat != nil {
builder = builder.PlaceholderFormat(r.config.PlaceholderFormat)
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, err
}

return builder.ToSql()
return &db.Result{
RowsAffected: rowsAffected,
}, nil
}

func (r *Query) buildSelect() (sql string, args []any, err error) {
func (r *Query) Where(query any, args ...any) db.Query {
q := NewQuery(r.config, r.builder, r.conditions.table)
q.conditions = r.conditions
q.conditions.where = append(r.conditions.where, Where{
query: query,
args: args,
})

return q
}

func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err error) {
if r.conditions.table == "" {
return "", nil, errors.DatabaseTableIsRequired
}

builder := sq.Select("*")
builder := sq.Insert(r.conditions.table)
if r.config.PlaceholderFormat != nil {
builder = builder.PlaceholderFormat(r.config.PlaceholderFormat)
}

builder = builder.From(r.conditions.table)
first := data[0]
builder = builder.SetMap(first)

for _, where := range r.conditions.where {
builder = builder.Where(where.query, where.args...)
cols := make([]string, 0, len(first))
for col := range first {
cols = append(cols, col)
}
sort.Strings(cols)
builder = builder.Columns(cols...)

for _, row := range data {
vals := make([]any, 0, len(first))
for _, col := range cols {
vals = append(vals, row[col])
}
builder = builder.Values(vals...)
}

return builder.ToSql()
}

func (r *Query) buildUpdate() (sql string, args []any, err error) {
func (r *Query) buildSelect() (sql string, args []any, err error) {
if r.conditions.table == "" {
return "", nil, errors.DatabaseTableIsRequired
}
Expand All @@ -104,6 +132,18 @@ func (r *Query) buildUpdate() (sql string, args []any, err error) {
builder = builder.From(r.conditions.table)

for _, where := range r.conditions.where {
query, ok := where.query.(string)
if ok {
if !str.Of(query).Trim().Contains(" ", "?") {
if len(where.args) > 1 {
builder = builder.Where(sq.Eq{query: where.args})
} else {
builder = builder.Where(sq.Eq{query: where.args[0]})
}
continue
}
}

builder = builder.Where(where.query, where.args...)
}

Expand Down
Loading

0 comments on commit d01662c

Please sign in to comment.