Skip to content

Commit

Permalink
fix(spanner/spannertest): send transaction id in result metadata (#7809)
Browse files Browse the repository at this point in the history
Fix a bug where ExecuteStreamingSQL wouldn't return result metadata if
there were no results, and include the transaction ID (if any).

Co-authored-by: rahul2393 <[email protected]>
  • Loading branch information
adg and rahul2393 authored Apr 21, 2023
1 parent e1e8ba9 commit e3bbd5f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
5 changes: 4 additions & 1 deletion spanner/spannertest/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ var commitTimestampSentinel = &struct{}{}
// transaction records information about a running transaction.
// This is not safe for concurrent use.
type transaction struct {
id string
// readOnly is whether this transaction was constructed
// for read-only use, and should yield errors if used
// to perform a mutation.
Expand All @@ -102,13 +103,15 @@ type transaction struct {

func (d *database) NewReadOnlyTransaction() *transaction {
return &transaction{
id: genRandomTransaction(),
readOnly: true,
}
}

func (d *database) NewTransaction() *transaction {
return &transaction{
d: d,
id: genRandomTransaction(),
d: d,
}
}

Expand Down
24 changes: 15 additions & 9 deletions spanner/spannertest/inmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span
}

func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) {
rsm, err := s.buildResultSetMetadata(ri)
rsm, err := s.buildResultSetMetadata(ri, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -678,11 +678,10 @@ func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) {
}

func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
rsm, err := s.buildResultSetMetadata(ri)
rsm, err := s.buildResultSetMetadata(ri, tx)
if err != nil {
return err
}

for {
row, err := ri.Next()
if err == io.EOF {
Expand Down Expand Up @@ -711,15 +710,23 @@ func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spa
// ResultSetMetadata is only set for the first PartialResultSet.
rsm = nil
}

if rsm != nil {
// If we didn't send any partial results, send the metadata
// which may contain an implicitly-opened transaction id.
return send(&spannerpb.PartialResultSet{
Metadata: rsm,
})
}
return nil
}

func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) {
func (s *server) buildResultSetMetadata(ri rowIter, tx *transaction) (*spannerpb.ResultSetMetadata, error) {
// Build the result set metadata.
rsm := &spannerpb.ResultSetMetadata{
RowType: &spannerpb.StructType{},
// TODO: transaction info?
}
if tx != nil {
rsm.Transaction = &spannerpb.Transaction{Id: []byte(tx.id)}
}
for _, ci := range ri.Cols() {
st, err := spannerTypeFromType(ci.Type)
Expand All @@ -745,15 +752,14 @@ func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTrans
return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session)
}

id := genRandomTransaction()
tx := s.db.NewTransaction()

sess.mu.Lock()
sess.lastUse = time.Now()
sess.transactions[id] = tx
sess.transactions[tx.id] = tx
sess.mu.Unlock()

tr := &spannerpb.Transaction{Id: []byte(id)}
tr := &spannerpb.Transaction{Id: []byte(tx.id)}

if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() {
// Return the last commit timestamp.
Expand Down
17 changes: 17 additions & 0 deletions spanner/spannertest/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,23 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
if failures > 0 {
t.Logf("%d queries failed", failures)
}

// Check that doing a query that matches no rows returns response
// metadata that contains the implicitly-opened transaction id.
if _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
stmt := spanner.NewStatement("SELECT * FROM Staff WHERE Name='missing'")
iter := tx.Query(ctx, stmt)
if _, err := iter.Next(); err != iterator.Done {
return fmt.Errorf("unexpected error: %w", err)
}
iter.Stop()
// If the transaction id isn't known to the client then a
// BufferWrite will fail (this is simply a direct way of
// checking this).
return tx.BufferWrite(nil)
}); err != nil {
t.Fatal(err)
}
}

func TestIntegration_GeneratedColumns(t *testing.T) {
Expand Down

0 comments on commit e3bbd5f

Please sign in to comment.