diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index ac5a37b34cd7..042888f9c98c 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -85,6 +85,7 @@ type sessionConsumer interface { // will ensure that the sessions that are created are evenly distributed over // all available channels. type sessionClient struct { + waitWorkers sync.WaitGroup mu sync.Mutex closed bool disableRouteToLeader bool @@ -120,10 +121,17 @@ func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, } func (sc *sessionClient) close() error { - sc.mu.Lock() - defer sc.mu.Unlock() - sc.closed = true - return sc.connPool.Close() + defer sc.waitWorkers.Wait() + + var err error + func() { + sc.mu.Lock() + defer sc.mu.Unlock() + + sc.closed = true + err = sc.connPool.Close() + }() + return err } // createSession creates one session for the database of the sessionClient. The @@ -231,6 +239,7 @@ func (sc *sessionClient) batchCreateSessions(createSessionCount int32, distribut createCountForChannel += remainder } if createCountForChannel > 0 { + sc.waitWorkers.Add(1) go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer) numBeingCreated += createCountForChannel } @@ -241,11 +250,14 @@ func (sc *sessionClient) batchCreateSessions(createSessionCount int32, distribut // executeBatchCreateSessions executes the gRPC call for creating a batch of // sessions. func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) { + defer sc.waitWorkers.Done() + ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout) defer cancel() ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions") defer func() { trace.EndSpan(ctx, nil) }() trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount) + remainingCreateCount := createCount for { sc.mu.Lock() diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go index e73f99ca9059..e9d26d259769 100644 --- a/spanner/spannertest/inmem.go +++ b/spanner/spannertest/inmem.go @@ -191,7 +191,7 @@ func NewServer(laddr string) (*Server, error) { s := &Server{ Addr: l.Addr().String(), l: l, - srv: grpc.NewServer(), + srv: grpc.NewServer(grpc.WaitForHandlers(true)), s: &server{ logf: func(format string, args ...interface{}) { log.Printf("spannertest.inmem: "+format, args...)