From f4421edc01b561378cf5ea2cf4fa0349f4d01073 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 3 Sep 2024 17:31:15 +0100 Subject: [PATCH] Fix #145. --- conn.go | 13 ++++++++----- tests/parallel/parallel_test.go | 11 ++++++++--- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/conn.go b/conn.go index 79b1dba9..85a8899a 100644 --- a/conn.go +++ b/conn.go @@ -24,7 +24,7 @@ type Conn struct { pending *Stmt stmts []*Stmt timer *time.Timer - busy func(int) bool + busy func(context.Context, int) bool log func(xErrorCode, string) collation func(*Conn, string) wal func(*Conn, string, int) error @@ -414,7 +414,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmo // BusyHandler registers a callback to handle [BUSY] errors. // // https://sqlite.org/c3ref/busy_handler.html -func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error { +func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error { var enable uint64 if cb != nil { enable = 1 @@ -428,9 +428,12 @@ func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error { } func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) { - if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil && - (c.interrupt == nil || c.interrupt.Err() == nil) { - if c.busy(int(count)) { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { + interrupt := c.interrupt + if interrupt == nil { + interrupt = context.Background() + } + if interrupt.Err() == nil && c.busy(interrupt, int(count)) { retry = 1 } } diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index 9045b9d4..d9f42dda 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -1,6 +1,7 @@ package tests import ( + "context" "errors" "io" "log" @@ -219,9 +220,13 @@ func testParallel(t testing.TB, name string, n int) { } defer db.Close() - err = db.BusyHandler(func(count int) (retry bool) { - time.Sleep(time.Millisecond) - return true + err = db.BusyHandler(func(ctx context.Context, count int) (retry bool) { + select { + case <-time.After(time.Millisecond): + return true + case <-ctx.Done(): + return false + } }) if err != nil { return err