Skip to content

Commit

Permalink
Fix #145.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Sep 3, 2024
1 parent b51234c commit f4421ed
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
13 changes: 8 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Conn struct {
pending *Stmt
stmts []*Stmt
timer *time.Timer
busy func(int) bool
busy func(context.Context, int) bool
log func(xErrorCode, string)
collation func(*Conn, string)
wal func(*Conn, string, int) error
Expand Down Expand Up @@ -414,7 +414,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmo
// BusyHandler registers a callback to handle [BUSY] errors.
//
// https://sqlite.org/c3ref/busy_handler.html
func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error {
func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error {
var enable uint64
if cb != nil {
enable = 1
Expand All @@ -428,9 +428,12 @@ func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error {
}

func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil &&
(c.interrupt == nil || c.interrupt.Err() == nil) {
if c.busy(int(count)) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
interrupt := c.interrupt
if interrupt == nil {
interrupt = context.Background()
}
if interrupt.Err() == nil && c.busy(interrupt, int(count)) {
retry = 1
}
}
Expand Down
11 changes: 8 additions & 3 deletions tests/parallel/parallel_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"context"
"errors"
"io"
"log"
Expand Down Expand Up @@ -219,9 +220,13 @@ func testParallel(t testing.TB, name string, n int) {
}
defer db.Close()

err = db.BusyHandler(func(count int) (retry bool) {
time.Sleep(time.Millisecond)
return true
err = db.BusyHandler(func(ctx context.Context, count int) (retry bool) {
select {
case <-time.After(time.Millisecond):
return true
case <-ctx.Done():
return false
}
})
if err != nil {
return err
Expand Down

0 comments on commit f4421ed

Please sign in to comment.