From 3ccccfb4c9c683a80e2cce810ac652616579e51c Mon Sep 17 00:00:00 2001 From: Phil Eaton Date: Sun, 29 May 2022 21:06:43 -0400 Subject: [PATCH] Support returning any from callbacks (#1046) Support returning any from callbacks --- callback.go | 19 ++++++++++++++++ callback_test.go | 12 ++++++++++ sqlite3_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/callback.go b/callback.go index b020fe37..d3056910 100644 --- a/callback.go +++ b/callback.go @@ -353,6 +353,20 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error { return nil } +func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error { + if v.IsNil() { + C.sqlite3_result_null(ctx) + return nil + } + + cb, err := callbackRet(v.Elem().Type()) + if err != nil { + return err + } + + return cb(ctx, v.Elem()) +} + func callbackRet(typ reflect.Type) (callbackRetConverter, error) { switch typ.Kind() { case reflect.Interface: @@ -360,6 +374,11 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) { if typ.Implements(errorInterface) { return callbackRetNil, nil } + + if typ.NumMethod() == 0 { + return callbackRetGeneric, nil + } + fallthrough case reflect.Slice: if typ.Elem().Kind() != reflect.Uint8 { diff --git a/callback_test.go b/callback_test.go index 714ed607..b09122ae 100644 --- a/callback_test.go +++ b/callback_test.go @@ -102,3 +102,15 @@ func TestCallbackConverters(t *testing.T) { } } } + +func TestCallbackReturnAny(t *testing.T) { + udf := func() interface{} { + return 1 + } + + typ := reflect.TypeOf(udf) + _, err := callbackRet(typ.Out(0)) + if err != nil { + t.Errorf("Expected valid callback for any return type, got: %s", err) + } +} diff --git a/sqlite3_test.go b/sqlite3_test.go index c86aba4b..9ee87e7e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1449,6 +1449,63 @@ func TestAggregatorRegistration(t *testing.T) { } } +type mode struct { + counts map[interface{}]int + top interface{} + topCount int +} + +func newMode() *mode { + return &mode{ + counts: map[interface{}]int{}, + } +} + +func (m *mode) Step(x interface{}) { + m.counts[x]++ + c := m.counts[x] + if c > m.topCount { + m.top = x + m.topCount = c + } +} + +func (m *mode) Done() interface{} { + return m.top +} + +func TestAggregatorRegistration_GenericReturn(t *testing.T) { + sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return conn.RegisterAggregator("mode", newMode, true) + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + var mode int + err = db.QueryRow("select mode(profits) from foo").Scan(&mode) + if err != nil { + t.Fatal("MODE query error:", err) + } + + if mode != 20 { + t.Fatal("Got incorrect mode. Wanted 20, got: ", mode) + } +} + func rot13(r rune) rune { switch { case r >= 'A' && r <= 'Z':