Skip to content

Commit

Permalink
feat: add int support (#260)
Browse files Browse the repository at this point in the history
* fix: don't panic in convertParam on nil pointer

* feat: add *int support

Fixes #257
  • Loading branch information
egonelbre authored Jun 17, 2024
1 parent b9c0b4d commit ec15462
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 9 deletions.
2 changes: 2 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,8 @@ func checkIsValidType(v driver.Value) bool {
case *[]uint:
case int:
case []int:
case *int:
case *[]int:
case int64:
case []int64:
case spanner.NullInt64:
Expand Down
45 changes: 36 additions & 9 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,25 +92,52 @@ func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement,
}

func convertParam(v driver.Value) driver.Value {
switch v.(type) {
switch v := v.(type) {
default:
return v
case int:
return int64(v)
case []int:
res := make([]int64, len(v))
for i, val := range v {
res[i] = int64(val)
}
return res
case uint:
return int64(v.(uint))
return int64(v)
case []uint:
vu := v.([]uint)
res := make([]int64, len(vu))
for i, val := range vu {
res := make([]int64, len(v))
for i, val := range v {
res[i] = int64(val)
}
return res
case *int:
if v == nil {
return (*int64)(nil)
}
vi := int64(*v)
return &vi
case *[]int:
if v == nil {
return (*[]int64)(nil)
}
res := make([]int64, len(*v))
for i, val := range *v {
res[i] = int64(val)
}
return &res
case *uint:
vi := int64(*v.(*uint))
if v == nil {
return (*int64)(nil)
}
vi := int64(*v)
return &vi
case *[]uint:
vu := v.(*[]uint)
res := make([]int64, len(*vu))
for i, val := range *vu {
if v == nil {
return (*[]int64)(nil)
}
res := make([]int64, len(*v))
for i, val := range *v {
res[i] = int64(val)
}
return &res
Expand Down
44 changes: 44 additions & 0 deletions stmt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package spannerdriver

import (
"database/sql/driver"
"reflect"
"testing"
)

func TestConvertParam(t *testing.T) {
check := func(in, want driver.Value) {
got := convertParam(in)
if !reflect.DeepEqual(got, want) {
t.Errorf("in:%#v want:%#v got:%#v", in, want, got)
}
}

check(uint(197), int64(197))
check(pointerTo[uint](197), pointerTo[int64](197))
check((*uint)(nil), (*int64)(nil))

check([]uint{197}, []int64{197})
check(pointerTo[[]uint]([]uint{197}), pointerTo[[]int64]([]int64{197}))
check((*[]uint)(nil), (*[]int64)(nil))

check([]int{197}, []int64{197})
check(pointerTo[[]int]([]int{197}), pointerTo[[]int64]([]int64{197}))
check((*[]int)(nil), (*[]int64)(nil))
}

func pointerTo[T any](v T) *T { return &v }

0 comments on commit ec15462

Please sign in to comment.