Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [#358] Implement LockForUpdate, SharedLock, Cursor, InRandomOrder methods #946

Merged
merged 8 commits into from
Mar 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions contracts/database/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package db
import (
"context"
"database/sql"

"github.com/jmoiron/sqlx"
)

type DB interface {
Expand All @@ -25,17 +27,18 @@ type DB interface {
type Query interface {
// Count Retrieve the "count" result of the query.
Count() (int64, error)
// Chunk Execute a callback over a given chunk size.
// Chunk(size int, callback func(dest []any) error) error
// CrossJoin specifying CROSS JOIN conditions for the query.
CrossJoin(query string, args ...any) Query
// Cursor returns a cursor, use scan to iterate over the returned rows.
Cursor() (chan Row, error)
// Decrement the given column's values by the given amounts.
Decrement(column string, value ...uint64) error
// Delete records from the database.
Delete() (*Result, error)
// DoesntExist Determine if no rows exist for the current query.
DoesntExist() (bool, error)
// Distinct Force the query to only return distinct results.
Distinct() Query
// Delete records from the database.
Delete() (*Result, error)
// Each(callback func(rows []any) error) error
// Exists Determine if any rows exist for the current query.
Exists() (bool, error)
// Find Execute a query for a single record by ID.
Expand All @@ -46,8 +49,6 @@ type Query interface {
FirstOr(dest any, callback func() error) error
// FirstOrFail finds the first record that matches the given conditions or throws an error.
FirstOrFail(dest any) error
// Decrement the given column's values by the given amounts.
Decrement(column string, value ...uint64) error
// Get Retrieve all rows from the database.
Get(dest any) error
// GroupBy specifies the group method on the query.
Expand All @@ -56,7 +57,8 @@ type Query interface {
Having(query any, args ...any) Query
// Increment a column's value by a given amount.
Increment(column string, value ...uint64) error
// inRandomOrder
// InRandomOrder Add an "in random order" clause to the query.
InRandomOrder() Query
// Insert a new record into the database.
Insert(data any) (*Result, error)
// InsertGetId returns the ID of the inserted row, only supported by MySQL and Sqlite
Expand All @@ -69,7 +71,8 @@ type Query interface {
LeftJoin(query string, args ...any) Query
// Limit Add a limit to the query.
Limit(limit uint64) Query
// lockForUpdate
// LockForUpdate Add a lock for update to the query.
LockForUpdate() Query
// Offset Add an "offset" clause to the query.
Offset(offset uint64) Query
// OrderBy Add an "order by" clause to the query.
Expand Down Expand Up @@ -108,7 +111,8 @@ type Query interface {
RightJoin(query string, args ...any) Query
// Select Set the columns to be selected.
Select(columns ...string) Query
// sharedLock
// SharedLock Add a shared lock to the query.
SharedLock() Query
// ToSql Get the SQL representation of the query.
ToSql() ToSql
// ToRawSql Get the raw SQL representation of the query with embedded bindings.
Expand Down Expand Up @@ -154,7 +158,7 @@ type Result struct {
type Builder interface {
Exec(query string, args ...any) (sql.Result, error)
Get(dest any, query string, args ...any) error
Query(query string, args ...any) (*sql.Rows, error)
Queryx(query string, args ...any) (*sqlx.Rows, error)
Select(dest any, query string, args ...any) error
}

Expand All @@ -167,3 +171,7 @@ type ToSql interface {
Pluck(column string, dest any) string
Update(column any, value ...any) string
}

type Row interface {
Scan(value any) error
}
29 changes: 16 additions & 13 deletions contracts/database/driver/conditions.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
package driver

type Conditions struct {
CrossJoin []Join
Distinct *bool
GroupBy []string
Having *Having
Join []Join
LeftJoin []Join
Limit *uint64
Offset *uint64
OrderBy []string
RightJoin []Join
Selects []string
Table string
Where []Where
CrossJoin []Join
Distinct *bool
GroupBy []string
Having *Having
Join []Join
InRandomOrder *bool
LeftJoin []Join
LockForUpdate *bool
Limit *uint64
Offset *uint64
OrderBy []string
RightJoin []Join
Selects []string
SharedLock *bool
Table string
Where []Where
}

type Having struct {
Expand Down
2 changes: 1 addition & 1 deletion contracts/database/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Driver interface {
// Explain generates an SQL string with given parameters.
Explain(sql string, args ...any) string
// Gorm returns the Gorm database connection.
Gorm() (*gorm.DB, GormQuery, error)
Gorm() (*gorm.DB, error)
// Grammar returns the database grammar.
Grammar() Grammar
// Processor returns the database processor.
Expand Down
31 changes: 28 additions & 3 deletions contracts/database/driver/grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@ package driver

import (
sq "github.com/Masterminds/squirrel"
"gorm.io/gorm/clause"
)

type Grammar interface {
SchemaGrammar
GormGrammar
DBGrammar
}

type SchemaGrammar interface {
// CompileAdd Compile an add column command.
CompileAdd(blueprint Blueprint, command *Command) string
// CompileChange Compile a change column command.
Expand Down Expand Up @@ -121,16 +128,34 @@ type Grammar interface {
TypeString(column ColumnDefinition) string
}

type GormGrammar interface {
// CompileLockForUpdateForGorm Compile the lock for update for gorm.
CompileLockForUpdateForGorm() clause.Expression
// CompileRandomOrderForGorm Compile the random order for gorm.
CompileRandomOrderForGorm() string
// CompileSharedLockForGorm Compile the shared lock for gorm.
CompileSharedLockForGorm() clause.Expression
}

type DBGrammar interface {
// CompileLockForUpdate Compile the lock for update.
CompileLockForUpdate(builder sq.SelectBuilder, conditions *Conditions) sq.SelectBuilder
// CompileInRandomOrder Compile the random order.
CompileInRandomOrder(builder sq.SelectBuilder, conditions *Conditions) sq.SelectBuilder
// CompileSharedLock Compile the shared lock.
CompileSharedLock(builder sq.SelectBuilder, conditions *Conditions) sq.SelectBuilder
}

type CompileOffsetGrammar interface {
CompileOffset(builder sq.SelectBuilder, conditions Conditions) sq.SelectBuilder
CompileOffset(builder sq.SelectBuilder, conditions *Conditions) sq.SelectBuilder
}

type CompileOrderByGrammar interface {
CompileOrderBy(builder sq.SelectBuilder, conditions Conditions) sq.SelectBuilder
CompileOrderBy(builder sq.SelectBuilder, conditions *Conditions) sq.SelectBuilder
}

type CompileLimitGrammar interface {
CompileLimit(builder sq.SelectBuilder, conditions Conditions) sq.SelectBuilder
CompileLimit(builder sq.SelectBuilder, conditions *Conditions) sq.SelectBuilder
}

type Schema interface {
Expand Down
7 changes: 1 addition & 6 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type Query interface {
// Create inserts new record into the database.
Create(value any) error
// Cursor returns a cursor, use scan to iterate over the returned rows.
Cursor() (chan Cursor, error)
Cursor() (chan db.Row, error)
// DB gets the underlying database connection.
DB() (*sql.DB, error)
// Delete deletes records matching given conditions, if the conditions are empty will delete all records.
Expand Down Expand Up @@ -216,11 +216,6 @@ type ConnectionModel interface {
Connection() string
}

type Cursor interface {
// Scan scans the current row into the given destination.
Scan(value any) error
}

type ToSql interface {
Count() string
Create(value any) string
Expand Down
100 changes: 79 additions & 21 deletions database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/goravel/framework/contracts/database"
"github.com/goravel/framework/contracts/database/db"
"github.com/goravel/framework/contracts/database/driver"
contractsdriver "github.com/goravel/framework/contracts/database/driver"
"github.com/goravel/framework/contracts/database/logger"
"github.com/goravel/framework/errors"
Expand Down Expand Up @@ -77,14 +76,41 @@ func (r *Query) CrossJoin(query string, args ...any) db.Query {
return q
}

// func (r *Query) Chunk(size int, callback func(dest []any) error) error {
// sql, args, err := r.buildSelect()
// if err != nil {
// return err
// }
func (r *Query) Cursor() (chan db.Row, error) {
sql, args, err := r.buildSelect()
if err != nil {
return nil, err
}

rows, err := r.builder.Queryx(sql, args...)
if err != nil {
r.trace(sql, args, -1, err)

return nil, err
}

ch := make(chan db.Row)
go func() {
defer rows.Close()
defer close(ch)

var count int64
for rows.Next() {
row := make(map[string]any)
if err := rows.MapScan(row); err != nil {
r.trace(sql, args, -1, err)
return
}

// return nil
// }
ch <- NewRow(row)
count++
}

r.trace(sql, args, count, nil)
}()

return ch, nil
}

func (r *Query) Decrement(column string, value ...uint64) error {
v := uint64(1)
Expand Down Expand Up @@ -307,6 +333,13 @@ func (r *Query) Increment(column string, value ...uint64) error {
return nil
}

func (r *Query) InRandomOrder() db.Query {
q := r.clone()
q.conditions.InRandomOrder = convert.Pointer(true)

return q
}

func (r *Query) Insert(data any) (*db.Result, error) {
mapData, err := convertToSliceMap(data)
if err != nil {
Expand Down Expand Up @@ -371,13 +404,6 @@ func (r *Query) InsertGetId(data any) (int64, error) {
return id, nil
}

func (r *Query) Limit(limit uint64) db.Query {
q := r.clone()
q.conditions.Limit = &limit

return q
}

func (r *Query) Latest(dest any, column ...string) error {
col := "created_at"
if len(column) > 0 {
Expand All @@ -397,6 +423,20 @@ func (r *Query) LeftJoin(query string, args ...any) db.Query {
return q
}

func (r *Query) Limit(limit uint64) db.Query {
q := r.clone()
q.conditions.Limit = &limit

return q
}

func (r *Query) LockForUpdate() db.Query {
q := r.clone()
q.conditions.LockForUpdate = convert.Pointer(true)

return q
}

func (r *Query) Offset(offset uint64) db.Query {
q := r.clone()
q.conditions.Offset = &offset
Expand Down Expand Up @@ -533,6 +573,13 @@ func (r *Query) Select(columns ...string) db.Query {
return q
}

func (r *Query) SharedLock() db.Query {
q := r.clone()
q.conditions.SharedLock = convert.Pointer(true)

return q
}

func (r *Query) ToSql() db.ToSql {
q := r.clone()
return NewToSql(q, false)
Expand Down Expand Up @@ -795,6 +842,10 @@ func (r *Query) buildSelect() (sql string, args []any, err error) {

builder = builder.Where(sqlizer)

if r.conditions.InRandomOrder != nil && *r.conditions.InRandomOrder {
builder = r.grammar.CompileInRandomOrder(builder, &r.conditions)
}

if len(r.conditions.GroupBy) > 0 {
builder = builder.GroupBy(r.conditions.GroupBy...)

Expand All @@ -803,33 +854,40 @@ func (r *Query) buildSelect() (sql string, args []any, err error) {
}
}

compileOrderByGrammar, ok := r.grammar.(driver.CompileOrderByGrammar)
compileOrderByGrammar, ok := r.grammar.(contractsdriver.CompileOrderByGrammar)
if ok {
builder = compileOrderByGrammar.CompileOrderBy(builder, r.conditions)
builder = compileOrderByGrammar.CompileOrderBy(builder, &r.conditions)
} else {
if len(r.conditions.OrderBy) > 0 {
builder = builder.OrderBy(r.conditions.OrderBy...)
}
}

compileOffsetGrammar, ok := r.grammar.(driver.CompileOffsetGrammar)
compileOffsetGrammar, ok := r.grammar.(contractsdriver.CompileOffsetGrammar)
if ok {
builder = compileOffsetGrammar.CompileOffset(builder, r.conditions)
builder = compileOffsetGrammar.CompileOffset(builder, &r.conditions)
} else {
if r.conditions.Offset != nil {
builder = builder.Offset(*r.conditions.Offset)
}
}

compileLimitGrammar, ok := r.grammar.(driver.CompileLimitGrammar)
compileLimitGrammar, ok := r.grammar.(contractsdriver.CompileLimitGrammar)
if ok {
builder = compileLimitGrammar.CompileLimit(builder, r.conditions)
builder = compileLimitGrammar.CompileLimit(builder, &r.conditions)
} else {
if r.conditions.Limit != nil {
builder = builder.Limit(*r.conditions.Limit)
}
}

if r.conditions.LockForUpdate != nil && *r.conditions.LockForUpdate {
builder = r.grammar.CompileLockForUpdate(builder, &r.conditions)
}
if r.conditions.SharedLock != nil && *r.conditions.SharedLock {
builder = r.grammar.CompileSharedLock(builder, &r.conditions)
}

return builder.ToSql()
}

Expand Down
Loading
Loading