diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go index 3328c229ffa8..6babaf93e04a 100644 --- a/spanner/spannertest/db.go +++ b/spanner/spannertest/db.go @@ -90,6 +90,7 @@ var commitTimestampSentinel = &struct{}{} // transaction records information about a running transaction. // This is not safe for concurrent use. type transaction struct { + id string // readOnly is whether this transaction was constructed // for read-only use, and should yield errors if used // to perform a mutation. @@ -102,13 +103,15 @@ type transaction struct { func (d *database) NewReadOnlyTransaction() *transaction { return &transaction{ + id: genRandomTransaction(), readOnly: true, } } func (d *database) NewTransaction() *transaction { return &transaction{ - d: d, + id: genRandomTransaction(), + d: d, } } diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go index 5ed6385b3bbd..3a91c030d784 100644 --- a/spanner/spannertest/inmem.go +++ b/spanner/spannertest/inmem.go @@ -649,7 +649,7 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span } func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) { - rsm, err := s.buildResultSetMetadata(ri) + rsm, err := s.buildResultSetMetadata(ri, nil) if err != nil { return nil, err } @@ -678,11 +678,10 @@ func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) { } func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { - rsm, err := s.buildResultSetMetadata(ri) + rsm, err := s.buildResultSetMetadata(ri, tx) if err != nil { return err } - for { row, err := ri.Next() if err == io.EOF { @@ -711,15 +710,23 @@ func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spa // ResultSetMetadata is only set for the first PartialResultSet. rsm = nil } - + if rsm != nil { + // If we didn't send any partial results, send the metadata + // which may contain an implicitly-opened transaction id. + return send(&spannerpb.PartialResultSet{ + Metadata: rsm, + }) + } return nil } -func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) { +func (s *server) buildResultSetMetadata(ri rowIter, tx *transaction) (*spannerpb.ResultSetMetadata, error) { // Build the result set metadata. rsm := &spannerpb.ResultSetMetadata{ RowType: &spannerpb.StructType{}, - // TODO: transaction info? + } + if tx != nil { + rsm.Transaction = &spannerpb.Transaction{Id: []byte(tx.id)} } for _, ci := range ri.Cols() { st, err := spannerTypeFromType(ci.Type) @@ -745,15 +752,14 @@ func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTrans return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session) } - id := genRandomTransaction() tx := s.db.NewTransaction() sess.mu.Lock() sess.lastUse = time.Now() - sess.transactions[id] = tx + sess.transactions[tx.id] = tx sess.mu.Unlock() - tr := &spannerpb.Transaction{Id: []byte(id)} + tr := &spannerpb.Transaction{Id: []byte(tx.id)} if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() { // Return the last commit timestamp. diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 90b89031a354..cfe75c171e97 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -1284,6 +1284,23 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { if failures > 0 { t.Logf("%d queries failed", failures) } + + // Check that doing a query that matches no rows returns response + // metadata that contains the implicitly-opened transaction id. + if _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + stmt := spanner.NewStatement("SELECT * FROM Staff WHERE Name='missing'") + iter := tx.Query(ctx, stmt) + if _, err := iter.Next(); err != iterator.Done { + return fmt.Errorf("unexpected error: %w", err) + } + iter.Stop() + // If the transaction id isn't known to the client then a + // BufferWrite will fail (this is simply a direct way of + // checking this). + return tx.BufferWrite(nil) + }); err != nil { + t.Fatal(err) + } } func TestIntegration_GeneratedColumns(t *testing.T) {