From f8e22f6cbba10fc262e87b4d06d5c1289d877503 Mon Sep 17 00:00:00 2001 From: rahul2393 Date: Mon, 17 Apr 2023 11:08:39 +0530 Subject: [PATCH] fix(spanner): context timeout should be wrapped correctly (#7744) * fix(spanner): context timeout should be wrapped correctly * add test --- spanner/client.go | 3 ++- spanner/client_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/spanner/client.go b/spanner/client.go index d7af7b901205..cb1a11b9a6f7 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -533,7 +533,8 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea } if t.shouldExplicitBegin(attempt) { if err = t.begin(ctx); err != nil { - return spannerErrorf(codes.Internal, "error while BeginTransaction during retrying a ReadWrite transaction: %v", err) + trace.TracePrintf(ctx, nil, "Error while BeginTransaction during retrying a ReadWrite transaction: %v", ToSpannerError(err)) + return ToSpannerError(err) } } else { t = &ReadWriteTransaction{ diff --git a/spanner/client_test.go b/spanner/client_test.go index 6e96cadddbeb..64551b9ddade 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -1838,6 +1838,41 @@ func TestClient_ReadWriteTransaction_MultipleReadsWithoutNext(t *testing.T) { } } +func TestClient_ReadWriteTransaction_WithCancelledContext(t *testing.T) { + t.Parallel() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(2), + Err: status.Errorf(codes.Internal, "stream terminated by RST_STREAM"), + }, + ) + ctx, cancel := context.WithCancel(context.Background()) + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + iter := tx.Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}) + if _, err := iter.Next(); err != nil { + return err + } + return nil + }) + if err != nil { + panic(err) + } + cancel() + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + iter := tx.Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}) + if _, err := iter.Next(); err != nil { + return err + } + return nil + }) + if status.Code(err) != codes.Canceled { + t.Fatal(err) + } +} + func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error { return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts) }