Skip to content

Commit

Permalink
fix: support reusing named parameters (#240)
Browse files Browse the repository at this point in the history
Named parameters that occurred multiple times in the SQL string
had to be added to the list of arguments multiple times, even if
these were given as named arguments. This fix allows queries that
use named parameters to accept only one occurence of the named
parameter.
This also clarifies the use of named parameters and positional
parameters, and that mixing named parameters with positional
arguments is not a good idea.

Based on the issue reported in #237
  • Loading branch information
olavloite authored May 30, 2024
1 parent ad95d85 commit c7140a2
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 9 deletions.
23 changes: 19 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if err != nil {
}

// Print tweets with more than 500 likes.
rows, err := db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > @likes", 500)
rows, err := db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > @likes", sql.Named("likes", 500))
if err != nil {
log.Fatal(err)
}
Expand All @@ -34,19 +34,34 @@ for rows.Next() {

## Statements

Statements support follows the official [Google Cloud Spanner Go](https://pkg.go.dev/cloud.google.com/go/spanner) client style arguments as well as positional paramaters.
Statements support follows the official [Google Cloud Spanner Go](https://pkg.go.dev/cloud.google.com/go/spanner) client
style arguments as well as positional parameters. It is highly recommended to use either positional parameters in
combination with positional arguments, __or__ named parameters in combination with named arguments.

### Using positional patameter
### Using positional parameters with positional arguments

```go
db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > ?", 500)

db.ExecContext(ctx, "INSERT INTO tweets (id, text, rts) VALUES (?, ?, ?)", id, text, 10000)
```

### Using named patameter
### Using named parameters with named arguments

```go
db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", sql.Named("id", 14544498215374))

db.ExecContext(ctx, "INSERT INTO tweets (id, text, rts) VALUES (@id, @text, @rts)",
sql.Named("id", id), sql.Named("text", text), sql.Named("rts", 10000))
```

### Using named parameters with positional arguments (not recommended)
Named parameters can also be used in combination with positional arguments,
but this is __not recommended__, as the behavior can be hard to predict if
the same named query parameter is used in multiple places in the statement.

```go
// Possible, but not recommended.
db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374)
```

Expand Down
120 changes: 120 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,126 @@ func TestDmlInAutocommit(t *testing.T) {
}
}

func TestQueryWithDuplicateNamedParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, sql.Named("name", "foo"), sql.Named("name", "bar"))
if err != nil {
t.Fatal(err)
}
// Verify that 'bar' is used for both instances of the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "bar"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestQueryWithReusedNamedParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, sql.Named("name", "foo"))
if err != nil {
t.Fatal(err)
}
// Verify that 'foo' is used for both instances of the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "foo"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestQueryWithReusedPositionalParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, "foo", "bar")
if err != nil {
t.Fatal(err)
}
// Verify that 'bar' is used for both instances of the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "bar"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestQueryWithMissingPositionalParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, "foo")
if err != nil {
t.Fatal(err)
}
// Verify that 'foo' is used for the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "foo"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestDdlInAutocommit(t *testing.T) {
t.Parallel()

Expand Down
18 changes: 13 additions & 5 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,24 @@ func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement,
if err != nil {
return spanner.Statement{}, err
}
if len(names) != len(args) {
return spanner.Statement{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "got %v argument values, but found %v parameters in the sql string", len(args), len(names)))
}
//if !hasNamedParams && len(names) != len(args) {
// return spanner.Statement{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "got %v argument values, but found %v parameters in the sql string", len(args), len(names)))
//}
ss := spanner.NewStatement(q)
for i, v := range args {
name := args[i].Name
if name == "" {
if name == "" && len(names) > i {
name = names[i]
}
ss.Params[name] = convertParam(v.Value)
if name != "" {
ss.Params[name] = convertParam(v.Value)
}
}
// Verify that all parameters have a value.
for _, name := range names {
if _, ok := ss.Params[name]; !ok {
return spanner.Statement{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "missing value for query parameter %v", name))
}
}
return ss, nil
}
Expand Down

0 comments on commit c7140a2

Please sign in to comment.