Skip to content

fix: support reusing named parameters #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if err != nil {
}

// Print tweets with more than 500 likes.
rows, err := db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > @likes", 500)
rows, err := db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > @likes", sql.Named("likes", 500))
if err != nil {
log.Fatal(err)
}
Expand All @@ -34,19 +34,34 @@ for rows.Next() {

## Statements

Statements support follows the official [Google Cloud Spanner Go](https://pkg.go.dev/cloud.google.com/go/spanner) client style arguments as well as positional paramaters.
Statements support follows the official [Google Cloud Spanner Go](https://pkg.go.dev/cloud.google.com/go/spanner) client
style arguments as well as positional parameters. It is highly recommended to use either positional parameters in
combination with positional arguments, __or__ named parameters in combination with named arguments.

### Using positional patameter
### Using positional parameters with positional arguments

```go
db.QueryContext(ctx, "SELECT id, text FROM tweets WHERE likes > ?", 500)

db.ExecContext(ctx, "INSERT INTO tweets (id, text, rts) VALUES (?, ?, ?)", id, text, 10000)
```

### Using named patameter
### Using named parameters with named arguments

```go
db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", sql.Named("id", 14544498215374))

db.ExecContext(ctx, "INSERT INTO tweets (id, text, rts) VALUES (@id, @text, @rts)",
sql.Named("id", id), sql.Named("text", text), sql.Named("rts", 10000))
```

### Using named parameters with positional arguments (not recommended)
Named parameters can also be used in combination with positional arguments,
but this is __not recommended__, as the behavior can be hard to predict if
the same named query parameter is used in multiple places in the statement.

```go
// Possible, but not recommended.
db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374)
```

Expand Down
120 changes: 120 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,126 @@ func TestDmlInAutocommit(t *testing.T) {
}
}

func TestQueryWithDuplicateNamedParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, sql.Named("name", "foo"), sql.Named("name", "bar"))
if err != nil {
t.Fatal(err)
}
// Verify that 'bar' is used for both instances of the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "bar"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestQueryWithReusedNamedParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, sql.Named("name", "foo"))
if err != nil {
t.Fatal(err)
}
// Verify that 'foo' is used for both instances of the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "foo"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestQueryWithReusedPositionalParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, "foo", "bar")
if err != nil {
t.Fatal(err)
}
// Verify that 'bar' is used for both instances of the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "bar"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestQueryWithMissingPositionalParameter(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@name, @name)"
server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
})
_, err := db.Exec(s, "foo")
if err != nil {
t.Fatal(err)
}
// Verify that 'foo' is used for the parameter @name.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if len(sqlRequests) != 1 {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", len(sqlRequests), 1)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.Params.Fields), 1; g != w {
t.Fatalf("params count mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := req.Params.Fields["name"].GetStringValue(), "foo"; g != w {
t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestDdlInAutocommit(t *testing.T) {
t.Parallel()

Expand Down
18 changes: 13 additions & 5 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,24 @@ func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement,
if err != nil {
return spanner.Statement{}, err
}
if len(names) != len(args) {
return spanner.Statement{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "got %v argument values, but found %v parameters in the sql string", len(args), len(names)))
}
//if !hasNamedParams && len(names) != len(args) {
// return spanner.Statement{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "got %v argument values, but found %v parameters in the sql string", len(args), len(names)))
//}
ss := spanner.NewStatement(q)
for i, v := range args {
name := args[i].Name
if name == "" {
if name == "" && len(names) > i {
name = names[i]
}
ss.Params[name] = convertParam(v.Value)
if name != "" {
ss.Params[name] = convertParam(v.Value)
}
}
// Verify that all parameters have a value.
for _, name := range names {
if _, ok := ss.Params[name]; !ok {
return spanner.Statement{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "missing value for query parameter %v", name))
}
}
return ss, nil
}
Expand Down
Loading