From 7bb1640a00fceca1e1075fe6544b9a4842ab2b26 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Tue, 22 Feb 2022 13:30:10 +0200 Subject: [PATCH] fix: check for nils when appeding driver.Value --- .github/workflows/build.yml | 4 ++-- internal/dbtest/db_test.go | 28 ++++++++++++++++++++++++++++ schema/append_value.go | 19 +++++++++++++++++-- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 83fb80360..3d95db7b7 100755 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -69,7 +69,7 @@ jobs: MSSQL_USER: sa MSSQL_PASSWORD: passWORD1 ports: - - 14339:1433 + - 1433:1433 options: >- --health-cmd="/opt/mssql-tools/bin/sqlcmd -S tcp:localhost,1433 -U sa -P passWORD1 -Q 'select 1' -b -o /dev/null" --health-interval=10s --health-timeout=5s --health-retries=5 @@ -90,4 +90,4 @@ jobs: MYSQL: user:pass@/test MYSQL5: user:pass@tcp(localhost:53306)/test MARIADB: user:pass@tcp(localhost:13306)/test - MSSQL2019: sqlserver://sa:passWORD1@localhost:14339?database=test + MSSQL2019: sqlserver://sa:passWORD1@localhost:1433?database=master diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 6b9439414..404e8cdf2 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -269,6 +269,7 @@ func TestDB(t *testing.T) { {testEmbedModelValue}, {testEmbedModelPointer}, {testJSONMarshaler}, + {testNilDriverValue}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -1284,3 +1285,30 @@ func testJSONMarshaler(t *testing.T, db *bun.DB) { require.NoError(t, err) require.Equal(t, "bar", m2.Field.Foo) } + +type DriverValue struct { + s string +} + +var _ driver.Valuer = (*DriverValue)(nil) + +func (v *DriverValue) Value() (driver.Value, error) { + return v.s, nil +} + +func testNilDriverValue(t *testing.T, db *bun.DB) { + type Model struct { + Value *DriverValue `bun:"type:varchar(100)"` + } + + ctx := context.Background() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + _, err = db.NewInsert().Model(&Model{}).Exec(ctx) + require.NoError(t, err) + + _, err = db.NewInsert().Model(&Model{Value: &DriverValue{s: "hello"}}).Exec(ctx) + require.NoError(t, err) +} diff --git a/schema/append_value.go b/schema/append_value.go index e6587cd6e..c805c745f 100644 --- a/schema/append_value.go +++ b/schema/append_value.go @@ -105,15 +105,21 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc { return appendJSONRawMessageValue } + kind := typ.Kind() + if typ.Implements(queryAppenderType) { + if kind == reflect.Ptr { + return nilAwareAppender(appendQueryAppenderValue) + } return appendQueryAppenderValue } if typ.Implements(driverValuerType) { + if kind == reflect.Ptr { + return nilAwareAppender(appendDriverValue) + } return appendDriverValue } - kind := typ.Kind() - if kind != reflect.Ptr { ptr := reflect.PtrTo(typ) if ptr.Implements(queryAppenderType) { @@ -156,6 +162,15 @@ func ifaceAppenderFunc(fmter Formatter, b []byte, v reflect.Value) []byte { return appender(fmter, b, elem) } +func nilAwareAppender(fn AppenderFunc) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + return fn(fmter, b, v) + } +} + func PtrAppender(fn AppenderFunc) AppenderFunc { return func(fmter Formatter, b []byte, v reflect.Value) []byte { if v.IsNil() {