From 5a2091bdcb590f0b963f4fb429d0ca78e4792bfb Mon Sep 17 00:00:00 2001 From: Egon Elbre Date: Fri, 3 May 2024 12:49:35 +0300 Subject: [PATCH] fix(spanner): wait for things to complete As a good rule of thumb every single goroutine needs to be waited to be completed, otherwise it's easy to introduce issues where the created goroutine outlives the parent and accesses variables or services that have been shutdown. --- spanner/sessionclient.go | 20 ++++++++++++++++---- spanner/spannertest/inmem.go | 2 +- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index ac5a37b34cd7..042888f9c98c 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -85,6 +85,7 @@ type sessionConsumer interface { // will ensure that the sessions that are created are evenly distributed over // all available channels. type sessionClient struct { + waitWorkers sync.WaitGroup mu sync.Mutex closed bool disableRouteToLeader bool @@ -120,10 +121,17 @@ func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, } func (sc *sessionClient) close() error { - sc.mu.Lock() - defer sc.mu.Unlock() - sc.closed = true - return sc.connPool.Close() + defer sc.waitWorkers.Wait() + + var err error + func() { + sc.mu.Lock() + defer sc.mu.Unlock() + + sc.closed = true + err = sc.connPool.Close() + }() + return err } // createSession creates one session for the database of the sessionClient. The @@ -231,6 +239,7 @@ func (sc *sessionClient) batchCreateSessions(createSessionCount int32, distribut createCountForChannel += remainder } if createCountForChannel > 0 { + sc.waitWorkers.Add(1) go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer) numBeingCreated += createCountForChannel } @@ -241,11 +250,14 @@ func (sc *sessionClient) batchCreateSessions(createSessionCount int32, distribut // executeBatchCreateSessions executes the gRPC call for creating a batch of // sessions. func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) { + defer sc.waitWorkers.Done() + ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout) defer cancel() ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions") defer func() { trace.EndSpan(ctx, nil) }() trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount) + remainingCreateCount := createCount for { sc.mu.Lock() diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go index e73f99ca9059..e9d26d259769 100644 --- a/spanner/spannertest/inmem.go +++ b/spanner/spannertest/inmem.go @@ -191,7 +191,7 @@ func NewServer(laddr string) (*Server, error) { s := &Server{ Addr: l.Addr().String(), l: l, - srv: grpc.NewServer(), + srv: grpc.NewServer(grpc.WaitForHandlers(true)), s: &server{ logf: func(format string, args ...interface{}) { log.Printf("spannertest.inmem: "+format, args...)