Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [#358] Add Insert First methods for DB #888

Merged
merged 15 commits into from
Feb 19, 2025
4 changes: 2 additions & 2 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ var testUserGuard = "user"
type User struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string
CreatedAt carbon.DateTime `gorm:"autoCreateTime;column:created_at" json:"created_at"`
UpdatedAt carbon.DateTime `gorm:"autoUpdateTime;column:updated_at" json:"updated_at"`
CreatedAt *carbon.DateTime `gorm:"autoCreateTime;column:created_at" json:"created_at"`
UpdatedAt *carbon.DateTime `gorm:"autoUpdateTime;column:updated_at" json:"updated_at"`
}

type Context struct {
Expand Down
16 changes: 15 additions & 1 deletion contracts/database/db/db.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
package db

import "database/sql"

type DB interface {
Table(name string) Query
}

type Query interface {
Where(query any, args ...any) Query
First(dest any) error
Get(dest any) error
Insert(data any) (*Result, error)
Where(query any, args ...any) Query
}

type Result struct {
RowsAffected int64
}

type Builder interface {
Exec(query string, args ...any) (sql.Result, error)
Get(dest any, query string, args ...any) error
Select(dest any, query string, args ...any) error
}
10 changes: 5 additions & 5 deletions database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ import (
)

type DB struct {
config database.Config
instance *sqlx.DB
builder db.Builder
config database.Config
}

func NewDB(config database.Config, instance *sqlx.DB) db.DB {
return &DB{config: config, instance: instance}
func NewDB(config database.Config, builder db.Builder) db.DB {
return &DB{config: config, builder: builder}
}

func BuildDB(config config.Config, connection string) (db.DB, error) {
Expand All @@ -41,5 +41,5 @@ func BuildDB(config config.Config, connection string) (db.DB, error) {
}

func (r *DB) Table(name string) db.Query {
return NewQuery(r.config, r.instance, name)
return NewQuery(r.config, r.builder, name)
}
67 changes: 67 additions & 0 deletions database/db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package db

import (
"database/sql"
"testing"

"github.com/stretchr/testify/assert"

"github.com/goravel/framework/contracts/database"
contractsdriver "github.com/goravel/framework/contracts/database/driver"
"github.com/goravel/framework/errors"
mocksconfig "github.com/goravel/framework/mocks/config"
mocksdriver "github.com/goravel/framework/mocks/database/driver"
)

func TestBuildDB(t *testing.T) {
var (
mockConfig *mocksconfig.Config
mockDriver *mocksdriver.Driver
)

tests := []struct {
name string
connection string
setup func()
expectedError error
}{
{
name: "Success",
connection: "mysql",
setup: func() {
driverCallback := func() (contractsdriver.Driver, error) {
return mockDriver, nil
}
mockConfig.On("Get", "database.connections.mysql.via").Return(driverCallback)
mockDriver.On("DB").Return(&sql.DB{}, nil)
mockDriver.On("Config").Return(database.Config{Driver: "mysql"})
},
expectedError: nil,
},
{
name: "Config Not Found",
connection: "invalid",
setup: func() {
mockConfig.On("Get", "database.connections.invalid.via").Return(nil)
},
expectedError: errors.DatabaseConfigNotFound,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
mockConfig = mocksconfig.NewConfig(t)
mockDriver = mocksdriver.NewDriver(t)
test.setup()

db, err := BuildDB(mockConfig, test.connection)
if test.expectedError != nil {
assert.Equal(t, test.expectedError, err)
assert.Nil(t, db)
} else {
assert.NoError(t, err)
assert.NotNil(t, db)
}
})
}
}
108 changes: 96 additions & 12 deletions database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,41 @@ package db

import (
"fmt"
"sort"

sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"

"github.com/goravel/framework/contracts/database"
"github.com/goravel/framework/contracts/database/db"
"github.com/goravel/framework/errors"
"github.com/goravel/framework/support/str"
)

type Query struct {
builder db.Builder
conditions Conditions
config database.Config
instance *sqlx.DB
}

func NewQuery(config database.Config, instance *sqlx.DB, table string) *Query {
func NewQuery(config database.Config, builder db.Builder, table string) *Query {
return &Query{
builder: builder,
conditions: Conditions{
table: table,
},
config: config,
instance: instance,
config: config,
}
}

func (r *Query) Where(query any, args ...any) db.Query {
r.conditions.where = append(r.conditions.where, Where{
query: query,
args: args,
})
func (r *Query) First(dest any) error {
sql, args, err := r.buildSelect()
// TODO: use logger instead of println
fmt.Println(sql, args, err)
if err != nil {
return err
}

return r
return r.builder.Get(dest, sql, args...)
}

func (r *Query) Get(dest any) error {
Expand All @@ -44,7 +47,76 @@ func (r *Query) Get(dest any) error {
return err
}

return r.instance.Select(dest, sql, args...)
return r.builder.Select(dest, sql, args...)
}

func (r *Query) Insert(data any) (*db.Result, error) {
mapData, err := convertToSliceMap(data)
if err != nil {
return nil, err
}

sql, args, err := r.buildInsert(mapData)
if err != nil {
return nil, err
}
// TODO: use logger instead of println
fmt.Println(sql, args, err)
result, err := r.builder.Exec(sql, args...)
if err != nil {
return nil, err
}

rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, err
}

return &db.Result{
RowsAffected: rowsAffected,
}, nil
}

func (r *Query) Where(query any, args ...any) db.Query {
q := NewQuery(r.config, r.builder, r.conditions.table)
q.conditions = r.conditions
q.conditions.where = append(r.conditions.where, Where{
query: query,
args: args,
})

return q
}

func (r *Query) buildInsert(data []map[string]any) (sql string, args []any, err error) {
if r.conditions.table == "" {
return "", nil, errors.DatabaseTableIsRequired
}

builder := sq.Insert(r.conditions.table)
if r.config.PlaceholderFormat != nil {
builder = builder.PlaceholderFormat(r.config.PlaceholderFormat)
}

first := data[0]
builder = builder.SetMap(first)

cols := make([]string, 0, len(first))
for col := range first {
cols = append(cols, col)
}
sort.Strings(cols)
builder = builder.Columns(cols...)

for _, row := range data {
vals := make([]any, 0, len(first))
for _, col := range cols {
vals = append(vals, row[col])
}
builder = builder.Values(vals...)
}

return builder.ToSql()
}

func (r *Query) buildSelect() (sql string, args []any, err error) {
Expand All @@ -60,6 +132,18 @@ func (r *Query) buildSelect() (sql string, args []any, err error) {
builder = builder.From(r.conditions.table)

for _, where := range r.conditions.where {
query, ok := where.query.(string)
if ok {
if !str.Of(query).Trim().Contains(" ", "?") {
if len(where.args) > 1 {
builder = builder.Where(sq.Eq{query: where.args})
} else {
builder = builder.Where(sq.Eq{query: where.args[0]})
}
continue
}
}

builder = builder.Where(where.query, where.args...)
}

Expand Down
Loading
Loading