From 5aef53a8884038f3c9f06e6dddb9372196253378 Mon Sep 17 00:00:00 2001
From: Chi Wang <chiwang@google.com>
Date: Tue, 14 Dec 2021 16:18:26 +0100
Subject: [PATCH] Remote: Don't blocking-get when acquiring gRPC connections.
 (#14420)

With the recent change to limit the maximum number of gRPC connections by default, acquiring a connection could suspend a thread if no connection is available.

gRPC calls are scheduled on a dedicated background thread pool. Workers in the thread pool are responsible for acquiring a connection before starting the RPC call.

There could be a race condition in which a worker thread handles some gRPC calls and then switches to a new call that needs to acquire a new connection. If the number of connections has reached the maximum, the worker thread is suspended and never gets a chance to switch back to the previous calls. The connections held by those previous calls are therefore never released.

This PR changes the code to not use a blocking get when acquiring gRPC connections.

Fixes #14363.

Closes #14416.

PiperOrigin-RevId: 416282883
---
 .../google/devtools/build/lib/remote/BUILD    |   3 +-
 .../build/lib/remote/ByteStreamUploader.java  |  30 ++--
 .../ExperimentalGrpcRemoteExecutor.java       |  44 ++++--
 .../build/lib/remote/GrpcCacheClient.java     |  63 ++++++---
 .../build/lib/remote/GrpcRemoteExecutor.java  |  15 +-
 .../lib/remote/ReferenceCountedChannel.java   | 129 +++++-------------
 .../build/lib/remote/RemoteModule.java        |  11 +-
 .../lib/remote/RemoteServerCapabilities.java  |   9 +-
 .../build/lib/remote/UploadManifest.java      |   7 +-
 .../downloader/GrpcRemoteDownloader.java      |  11 +-
 .../build/lib/remote/util/RxFutures.java      |  32 ++---
 11 files changed, 188 insertions(+), 166 deletions(-)

diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD
index 1eaa3fdf618efe..a5745bf9b616b8 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD
@@ -138,9 +138,10 @@ java_library(
     ],
     deps = [
         "//src/main/java/com/google/devtools/build/lib/remote/grpc",
+        "//src/main/java/com/google/devtools/build/lib/remote/util",
         "//third_party:guava",
-        "//third_party:jsr305",
         "//third_party:netty",
+        "//third_party:rxjava3",
         "//third_party/grpc:grpc-jar",
     ],
 )
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
index c488f14f397d07..cc31b5bf070705 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
@@ -24,6 +24,7 @@
 import com.google.bytestream.ByteStreamGrpc;
 import com.google.bytestream.ByteStreamGrpc.ByteStreamFutureStub;
 import com.google.bytestream.ByteStreamProto.QueryWriteStatusRequest;
+import com.google.bytestream.ByteStreamProto.QueryWriteStatusResponse;
 import com.google.bytestream.ByteStreamProto.WriteRequest;
 import com.google.bytestream.ByteStreamProto.WriteResponse;
 import com.google.common.annotations.VisibleForTesting;
@@ -374,7 +375,7 @@ public ReferenceCounted touch(Object o) {
   private static class AsyncUpload {
 
     private final RemoteActionExecutionContext context;
-    private final Channel channel;
+    private final ReferenceCountedChannel channel;
     private final CallCredentialsProvider callCredentialsProvider;
     private final long callTimeoutSecs;
     private final Retrier retrier;
@@ -385,7 +386,7 @@ private static class AsyncUpload {
 
     AsyncUpload(
         RemoteActionExecutionContext context,
-        Channel channel,
+        ReferenceCountedChannel channel,
         CallCredentialsProvider callCredentialsProvider,
         long callTimeoutSecs,
         Retrier retrier,
@@ -452,7 +453,7 @@ ListenableFuture<Void> start() {
           MoreExecutors.directExecutor());
     }
 
-    private ByteStreamFutureStub bsFutureStub() {
+    private ByteStreamFutureStub bsFutureStub(Channel channel) {
       return ByteStreamGrpc.newFutureStub(channel)
           .withInterceptors(
               TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()))
@@ -463,7 +464,10 @@ private ByteStreamFutureStub bsFutureStub() {
     private ListenableFuture<Void> callAndQueryOnFailure(
         AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
       return Futures.catchingAsync(
-          call(committedOffset),
+          Futures.transform(
+              channel.withChannelFuture(channel -> call(committedOffset, channel)),
+              written -> null,
+              MoreExecutors.directExecutor()),
           Exception.class,
           (e) -> guardQueryWithSuppression(e, committedOffset, progressiveBackoff),
           MoreExecutors.directExecutor());
@@ -500,10 +504,14 @@ private ListenableFuture<Void> query(
         AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
       ListenableFuture<Long> committedSizeFuture =
           Futures.transform(
-              bsFutureStub()
-                  .queryWriteStatus(
-                      QueryWriteStatusRequest.newBuilder().setResourceName(resourceName).build()),
-              (response) -> response.getCommittedSize(),
+              channel.withChannelFuture(
+                  channel ->
+                      bsFutureStub(channel)
+                          .queryWriteStatus(
+                              QueryWriteStatusRequest.newBuilder()
+                                  .setResourceName(resourceName)
+                                  .build())),
+              QueryWriteStatusResponse::getCommittedSize,
               MoreExecutors.directExecutor());
       ListenableFuture<Long> guardedCommittedSizeFuture =
           Futures.catchingAsync(
@@ -533,14 +541,14 @@ private ListenableFuture<Void> query(
           MoreExecutors.directExecutor());
     }
 
-    private ListenableFuture<Void> call(AtomicLong committedOffset) {
+    private ListenableFuture<Long> call(AtomicLong committedOffset, Channel channel) {
       CallOptions callOptions =
           CallOptions.DEFAULT
               .withCallCredentials(callCredentialsProvider.getCallCredentials())
               .withDeadlineAfter(callTimeoutSecs, SECONDS);
       call = channel.newCall(ByteStreamGrpc.getWriteMethod(), callOptions);
 
-      SettableFuture<Void> uploadResult = SettableFuture.create();
+      SettableFuture<Long> uploadResult = SettableFuture.create();
       ClientCall.Listener<WriteResponse> callListener =
           new ClientCall.Listener<WriteResponse>() {
 
@@ -568,7 +576,7 @@ public void onMessage(WriteResponse response) {
             @Override
             public void onClose(Status status, Metadata trailers) {
               if (status.isOk()) {
-                uploadResult.set(null);
+                uploadResult.set(committedOffset.get());
               } else {
                 uploadResult.setException(status.asRuntimeException());
               }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java
index 41f5306624d29e..d50a77cd1c32b2 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java
@@ -35,12 +35,13 @@
 import com.google.longrunning.Operation;
 import com.google.longrunning.Operation.ResultCase;
 import com.google.rpc.Status;
+import io.grpc.Channel;
 import io.grpc.Status.Code;
 import io.grpc.StatusRuntimeException;
+import io.reactivex.rxjava3.functions.Function;
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Supplier;
 import javax.annotation.Nullable;
 
 /**
@@ -73,7 +74,7 @@ public ExperimentalGrpcRemoteExecutor(
     this.retrier = retrier;
   }
 
-  private ExecutionBlockingStub executionBlockingStub(RequestMetadata metadata) {
+  private ExecutionBlockingStub executionBlockingStub(RequestMetadata metadata, Channel channel) {
     return ExecutionGrpc.newBlockingStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata))
         .withCallCredentials(callCredentialsProvider.getCallCredentials())
@@ -90,7 +91,8 @@ private static class Execution {
     // Count retry times for WaitExecution() calls and is reset when we receive any response from
     // the server that is not an error.
     private final ProgressiveBackoff waitExecutionBackoff;
-    private final Supplier<ExecutionBlockingStub> executionBlockingStubSupplier;
+    private final Function<ExecuteRequest, Iterator<Operation>> executeFunction;
+    private final Function<WaitExecutionRequest, Iterator<Operation>> waitExecutionFunction;
 
     // Last response (without error) we received from server.
     private Operation lastOperation;
@@ -100,14 +102,16 @@ private static class Execution {
         OperationObserver observer,
         RemoteRetrier retrier,
         CallCredentialsProvider callCredentialsProvider,
-        Supplier<ExecutionBlockingStub> executionBlockingStubSupplier) {
+        Function<ExecuteRequest, Iterator<Operation>> executeFunction,
+        Function<WaitExecutionRequest, Iterator<Operation>> waitExecutionFunction) {
       this.request = request;
       this.observer = observer;
       this.retrier = retrier;
       this.callCredentialsProvider = callCredentialsProvider;
       this.executeBackoff = this.retrier.newBackoff();
       this.waitExecutionBackoff = new ProgressiveBackoff(this.retrier::newBackoff);
-      this.executionBlockingStubSupplier = executionBlockingStubSupplier;
+      this.executeFunction = executeFunction;
+      this.waitExecutionFunction = waitExecutionFunction;
     }
 
     ExecuteResponse start() throws IOException, InterruptedException {
@@ -168,9 +172,9 @@ ExecuteResponse execute() throws IOException {
       Preconditions.checkState(lastOperation == null);
 
       try {
-        Iterator<Operation> operationStream = executionBlockingStubSupplier.get().execute(request);
+        Iterator<Operation> operationStream = executeFunction.apply(request);
         return handleOperationStream(operationStream);
-      } catch (StatusRuntimeException e) {
+      } catch (Throwable e) {
         // If lastOperation is not null, we know the execution request is accepted by the server. In
         // this case, we will fallback to WaitExecution() loop when the stream is broken.
         if (lastOperation != null) {
@@ -188,17 +192,20 @@ ExecuteResponse waitExecution() throws IOException {
       WaitExecutionRequest request =
           WaitExecutionRequest.newBuilder().setName(lastOperation.getName()).build();
       try {
-        Iterator<Operation> operationStream =
-            executionBlockingStubSupplier.get().waitExecution(request);
+        Iterator<Operation> operationStream = waitExecutionFunction.apply(request);
         return handleOperationStream(operationStream);
-      } catch (StatusRuntimeException e) {
+      } catch (Throwable e) {
         // A NOT_FOUND error means Operation was lost on the server, retry Execute().
         //
         // However, we only retry Execute() if executeBackoff should retry. Also increase the retry
         // counter at the same time (done by nextDelayMillis()).
-        if (e.getStatus().getCode() == Code.NOT_FOUND && executeBackoff.nextDelayMillis(e) >= 0) {
-          lastOperation = null;
-          return null;
+        if (e instanceof StatusRuntimeException) {
+          StatusRuntimeException sre = (StatusRuntimeException) e;
+          if (sre.getStatus().getCode() == Code.NOT_FOUND
+              && executeBackoff.nextDelayMillis(sre) >= 0) {
+            lastOperation = null;
+            return null;
+          }
         }
         throw new IOException(e);
       }
@@ -321,7 +328,16 @@ public ExecuteResponse executeRemotely(
             observer,
             retrier,
             callCredentialsProvider,
-            () -> this.executionBlockingStub(context.getRequestMetadata()));
+            (req) ->
+                channel.withChannelBlocking(
+                    channel ->
+                        this.executionBlockingStub(context.getRequestMetadata(), channel)
+                            .execute(req)),
+            (req) ->
+                channel.withChannelBlocking(
+                    channel ->
+                        this.executionBlockingStub(context.getRequestMetadata(), channel)
+                            .waitExecution(req)));
     return execution.start();
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index e35d4c6f32ce94..717504be39f401 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -56,6 +56,7 @@
 import com.google.devtools.build.lib.remote.zstd.ZstdDecompressingOutputStream;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.protobuf.ByteString;
+import io.grpc.Channel;
 import io.grpc.Status;
 import io.grpc.Status.Code;
 import io.grpc.StatusRuntimeException;
@@ -122,7 +123,8 @@ private int computeMaxMissingBlobsDigestsPerMessage() {
     return (options.maxOutboundMessageSize - overhead) / digestSize;
   }
 
-  private ContentAddressableStorageFutureStub casFutureStub(RemoteActionExecutionContext context) {
+  private ContentAddressableStorageFutureStub casFutureStub(
+      RemoteActionExecutionContext context, Channel channel) {
     return ContentAddressableStorageGrpc.newFutureStub(channel)
         .withInterceptors(
             TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()),
@@ -131,7 +133,7 @@ private ContentAddressableStorageFutureStub casFutureStub(RemoteActionExecutionC
         .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
   }
 
-  private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context) {
+  private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context, Channel channel) {
     return ByteStreamGrpc.newStub(channel)
         .withInterceptors(
             TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()),
@@ -140,7 +142,8 @@ private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context) {
         .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
   }
 
-  private ActionCacheFutureStub acFutureStub(RemoteActionExecutionContext context) {
+  private ActionCacheFutureStub acFutureStub(
+      RemoteActionExecutionContext context, Channel channel) {
     return ActionCacheGrpc.newFutureStub(channel)
         .withInterceptors(
             TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()),
@@ -222,7 +225,11 @@ public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(
   private ListenableFuture<FindMissingBlobsResponse> getMissingDigests(
       RemoteActionExecutionContext context, FindMissingBlobsRequest request) {
     return Utils.refreshIfUnauthenticatedAsync(
-        () -> retrier.executeAsync(() -> casFutureStub(context).findMissingBlobs(request)),
+        () ->
+            retrier.executeAsync(
+                () ->
+                    channel.withChannelFuture(
+                        channel -> casFutureStub(context, channel).findMissingBlobs(request))),
         callCredentialsProvider);
   }
 
@@ -254,7 +261,10 @@ public ListenableFuture<CachedActionResult> downloadActionResult(
     return Utils.refreshIfUnauthenticatedAsync(
         () ->
             retrier.executeAsync(
-                () -> handleStatus(acFutureStub(context).getActionResult(request))),
+                () ->
+                    handleStatus(
+                        channel.withChannelFuture(
+                            channel -> acFutureStub(context, channel).getActionResult(request)))),
         callCredentialsProvider);
   }
 
@@ -267,13 +277,15 @@ public ListenableFuture<Void> uploadActionResult(
                 retrier.executeAsync(
                     () ->
                         Futures.catchingAsync(
-                            acFutureStub(context)
-                                .updateActionResult(
-                                    UpdateActionResultRequest.newBuilder()
-                                        .setInstanceName(options.remoteInstanceName)
-                                        .setActionDigest(actionKey.getDigest())
-                                        .setActionResult(actionResult)
-                                        .build()),
+                            channel.withChannelFuture(
+                                channel ->
+                                    acFutureStub(context, channel)
+                                        .updateActionResult(
+                                            UpdateActionResultRequest.newBuilder()
+                                                .setInstanceName(options.remoteInstanceName)
+                                                .setActionDigest(actionKey.getDigest())
+                                                .setActionResult(actionResult)
+                                                .build())),
                             StatusRuntimeException.class,
                             (sre) -> Futures.immediateFailedFuture(new IOException(sre)),
                             MoreExecutors.directExecutor())),
@@ -317,18 +329,26 @@ private ListenableFuture<Void> downloadBlob(
       @Nullable Supplier<Digest> digestSupplier) {
     AtomicLong offset = new AtomicLong(0);
     ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
-    ListenableFuture<Void> downloadFuture =
+    ListenableFuture<Long> downloadFuture =
         Utils.refreshIfUnauthenticatedAsync(
             () ->
                 retrier.executeAsync(
                     () ->
-                        requestRead(
-                            context, offset, progressiveBackoff, digest, out, digestSupplier),
+                        channel.withChannelFuture(
+                            channel ->
+                                requestRead(
+                                    context,
+                                    offset,
+                                    progressiveBackoff,
+                                    digest,
+                                    out,
+                                    digestSupplier,
+                                    channel)),
                     progressiveBackoff),
             callCredentialsProvider);
 
     return Futures.catchingAsync(
-        downloadFuture,
+        Futures.transform(downloadFuture, bytesWritten -> null, MoreExecutors.directExecutor()),
         StatusRuntimeException.class,
         (e) -> Futures.immediateFailedFuture(new IOException(e)),
         MoreExecutors.directExecutor());
@@ -343,17 +363,18 @@ public static String getResourceName(String instanceName, Digest digest, boolean
     return resourceName + DigestUtil.toString(digest);
   }
 
-  private ListenableFuture<Void> requestRead(
+  private ListenableFuture<Long> requestRead(
       RemoteActionExecutionContext context,
       AtomicLong offset,
       ProgressiveBackoff progressiveBackoff,
       Digest digest,
       CountingOutputStream out,
-      @Nullable Supplier<Digest> digestSupplier) {
+      @Nullable Supplier<Digest> digestSupplier,
+      Channel channel) {
     String resourceName =
         getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
-    SettableFuture<Void> future = SettableFuture.create();
-    bsAsyncStub(context)
+    SettableFuture<Long> future = SettableFuture.create();
+    bsAsyncStub(context, channel)
         .read(
             ReadRequest.newBuilder()
                 .setResourceName(resourceName)
@@ -400,7 +421,7 @@ public void onCompleted() {
                     Utils.verifyBlobContents(digest, digestSupplier.get());
                   }
                   out.flush();
-                  future.set(null);
+                  future.set(offset.get());
                 } catch (IOException e) {
                   future.setException(e);
                 } catch (RuntimeException e) {
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
index 0b8c3fa312585e..df3872ebfaeb46 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
@@ -30,6 +30,7 @@
 import com.google.devtools.build.lib.remote.util.Utils;
 import com.google.longrunning.Operation;
 import com.google.rpc.Status;
+import io.grpc.Channel;
 import io.grpc.Status.Code;
 import io.grpc.StatusRuntimeException;
 import java.io.IOException;
@@ -57,7 +58,7 @@ public GrpcRemoteExecutor(
     this.retrier = retrier;
   }
 
-  private ExecutionBlockingStub execBlockingStub(RequestMetadata metadata) {
+  private ExecutionBlockingStub execBlockingStub(RequestMetadata metadata, Channel channel) {
     return ExecutionGrpc.newBlockingStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata))
         .withCallCredentials(callCredentialsProvider.getCallCredentials());
@@ -152,9 +153,17 @@ public ExecuteResponse executeRemotely(
                             WaitExecutionRequest.newBuilder()
                                 .setName(operation.get().getName())
                                 .build();
-                        replies = execBlockingStub(context.getRequestMetadata()).waitExecution(wr);
+                        replies =
+                            channel.withChannelBlocking(
+                                channel ->
+                                    execBlockingStub(context.getRequestMetadata(), channel)
+                                        .waitExecution(wr));
                       } else {
-                        replies = execBlockingStub(context.getRequestMetadata()).execute(request);
+                        replies =
+                            channel.withChannelBlocking(
+                                channel ->
+                                    execBlockingStub(context.getRequestMetadata(), channel)
+                                        .execute(request));
                       }
                       try {
                         while (replies.hasNext()) {
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java
index 36df5e77a1a78f..ee67160c372d54 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java
@@ -13,26 +13,23 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote;
 
-import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Throwables.throwIfInstanceOf;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
-import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.ListenableFuture;
 import com.google.devtools.build.lib.remote.grpc.ChannelConnectionFactory;
 import com.google.devtools.build.lib.remote.grpc.ChannelConnectionFactory.ChannelConnection;
 import com.google.devtools.build.lib.remote.grpc.DynamicConnectionPool;
 import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
-import io.grpc.CallOptions;
+import com.google.devtools.build.lib.remote.util.RxFutures;
 import io.grpc.Channel;
-import io.grpc.ClientCall;
-import io.grpc.ForwardingClientCall;
-import io.grpc.ForwardingClientCallListener;
-import io.grpc.Metadata;
-import io.grpc.MethodDescriptor;
-import io.grpc.Status;
 import io.netty.util.AbstractReferenceCounted;
 import io.netty.util.ReferenceCounted;
+import io.reactivex.rxjava3.annotations.CheckReturnValue;
+import io.reactivex.rxjava3.core.Single;
+import io.reactivex.rxjava3.core.SingleSource;
+import io.reactivex.rxjava3.functions.Function;
 import java.io.IOException;
-import java.util.concurrent.atomic.AtomicReference;
-import javax.annotation.Nullable;
 
 /**
  * A wrapper around a {@link DynamicConnectionPool} exposing {@link Channel} and a reference count.
@@ -41,7 +38,7 @@
  *
  * <p>See {@link ReferenceCounted} for more information about reference counting.
  */
-public class ReferenceCountedChannel extends Channel implements ReferenceCounted {
+public class ReferenceCountedChannel implements ReferenceCounted {
   private final DynamicConnectionPool dynamicConnectionPool;
   private final AbstractReferenceCounted referenceCounted =
       new AbstractReferenceCounted() {
@@ -59,7 +56,6 @@ public ReferenceCounted touch(Object o) {
           return this;
         }
       };
-  private final AtomicReference<String> authorityRef = new AtomicReference<>();
 
   public ReferenceCountedChannel(ChannelConnectionFactory connectionFactory) {
     this(connectionFactory, /*maxConnections=*/ 0);
@@ -75,93 +71,42 @@ public boolean isShutdown() {
     return dynamicConnectionPool.isClosed();
   }
 
-  /** A {@link ClientCall} which call {@link SharedConnection#close()} after the RPC is closed. */
-  static class ConnectionCleanupCall<ReqT, RespT>
-      extends ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT> {
-    private final SharedConnection connection;
-
-    protected ConnectionCleanupCall(ClientCall<ReqT, RespT> delegate, SharedConnection connection) {
-      super(delegate);
-      this.connection = connection;
-    }
-
-    @Override
-    public void start(Listener<RespT> responseListener, Metadata headers) {
-      super.start(
-          new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(
-              responseListener) {
-            @Override
-            public void onClose(Status status, Metadata trailers) {
-              try {
-                connection.close();
-              } catch (IOException e) {
-                throw new AssertionError(e.getMessage(), e);
-              } finally {
-                super.onClose(status, trailers);
-              }
-            }
-          },
-          headers);
-    }
-  }
-
-  private static class CloseOnStartClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {
-    private final Status status;
-
-    CloseOnStartClientCall(Status status) {
-      this.status = status;
-    }
-
-    @Override
-    public void start(Listener<RespT> responseListener, Metadata headers) {
-      responseListener.onClose(status, new Metadata());
-    }
-
-    @Override
-    public void request(int numMessages) {}
-
-    @Override
-    public void cancel(@Nullable String message, @Nullable Throwable cause) {}
-
-    @Override
-    public void halfClose() {}
-
-    @Override
-    public void sendMessage(ReqT message) {}
+  @CheckReturnValue
+  public <T> ListenableFuture<T> withChannelFuture(
+      Function<Channel, ? extends ListenableFuture<T>> source) {
+    return RxFutures.toListenableFuture(
+        withChannel(channel -> RxFutures.toSingle(() -> source.apply(channel), directExecutor())));
   }
 
-  private SharedConnection acquireSharedConnection() throws IOException, InterruptedException {
+  public <T> T withChannelBlocking(Function<Channel, T> source)
+      throws IOException, InterruptedException {
     try {
-      SharedConnection sharedConnection = dynamicConnectionPool.create().blockingGet();
-      ChannelConnection connection = (ChannelConnection) sharedConnection.getUnderlyingConnection();
-      authorityRef.compareAndSet(null, connection.getChannel().authority());
-      return sharedConnection;
+      return withChannel(channel -> Single.just(source.apply(channel))).blockingGet();
     } catch (RuntimeException e) {
-      Throwables.throwIfInstanceOf(e.getCause(), IOException.class);
-      Throwables.throwIfInstanceOf(e.getCause(), InterruptedException.class);
+      Throwable cause = e.getCause();
+      if (cause != null) {
+        throwIfInstanceOf(cause, IOException.class);
+        throwIfInstanceOf(cause, InterruptedException.class);
+      }
       throw e;
     }
   }
 
-  @Override
-  public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
-      MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
-    try {
-      SharedConnection sharedConnection = acquireSharedConnection();
-      return new ConnectionCleanupCall<>(
-          sharedConnection.call(methodDescriptor, callOptions), sharedConnection);
-    } catch (IOException e) {
-      return new CloseOnStartClientCall<>(Status.UNKNOWN.withCause(e));
-    } catch (InterruptedException e) {
-      return new CloseOnStartClientCall<>(Status.CANCELLED.withCause(e));
-    }
-  }
-
-  @Override
-  public String authority() {
-    String authority = authorityRef.get();
-    checkNotNull(authority, "create a connection first to get the authority");
-    return authority;
+  @CheckReturnValue
+  public <T> Single<T> withChannel(Function<Channel, ? extends SingleSource<? extends T>> source) {
+    return dynamicConnectionPool
+        .create()
+        .flatMap(
+            sharedConnection ->
+                Single.using(
+                    () -> sharedConnection,
+                    conn -> {
+                      ChannelConnection connection =
+                          (ChannelConnection) sharedConnection.getUnderlyingConnection();
+                      Channel channel = connection.getChannel();
+                      return source.apply(channel);
+                    },
+                    SharedConnection::close));
   }
 
   @Override
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
index 252abc94622141..df4e38cf5df2de 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
@@ -96,6 +96,7 @@
 import com.google.devtools.common.options.OptionsBase;
 import com.google.devtools.common.options.OptionsParsingResult;
 import io.grpc.CallCredentials;
+import io.grpc.Channel;
 import io.grpc.ClientInterceptor;
 import io.grpc.ManagedChannel;
 import io.reactivex.rxjava3.plugins.RxJavaPlugins;
@@ -516,7 +517,15 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
 
     String remoteBytestreamUriPrefix = remoteOptions.remoteBytestreamUriPrefix;
     if (Strings.isNullOrEmpty(remoteBytestreamUriPrefix)) {
-      remoteBytestreamUriPrefix = cacheChannel.authority();
+      try {
+        remoteBytestreamUriPrefix = cacheChannel.withChannelBlocking(Channel::authority);
+      } catch (IOException e) {
+        handleInitFailure(env, e, Code.CACHE_INIT_FAILURE);
+        return;
+      } catch (InterruptedException e) {
+        handleInitFailure(env, new IOException(e), Code.CACHE_INIT_FAILURE);
+        return;
+      }
       if (!Strings.isNullOrEmpty(remoteOptions.remoteInstanceName)) {
         remoteBytestreamUriPrefix += "/" + remoteOptions.remoteInstanceName;
       }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java
index 6eb03ceb559b87..6d486480e30b5d 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java
@@ -31,6 +31,7 @@
 import com.google.devtools.build.lib.remote.options.RemoteOptions;
 import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
 import io.grpc.CallCredentials;
+import io.grpc.Channel;
 import io.grpc.StatusRuntimeException;
 import java.io.IOException;
 import java.util.List;
@@ -59,7 +60,8 @@ public RemoteServerCapabilities(
     this.retrier = retrier;
   }
 
-  private CapabilitiesBlockingStub capabilitiesBlockingStub(RemoteActionExecutionContext context) {
+  private CapabilitiesBlockingStub capabilitiesBlockingStub(
+      RemoteActionExecutionContext context, Channel channel) {
     return CapabilitiesGrpc.newBlockingStub(channel)
         .withInterceptors(
             TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()))
@@ -77,7 +79,10 @@ public ServerCapabilities get(String buildRequestId, String commandId)
           instanceName == null
               ? GetCapabilitiesRequest.getDefaultInstance()
               : GetCapabilitiesRequest.newBuilder().setInstanceName(instanceName).build();
-      return retrier.execute(() -> capabilitiesBlockingStub(context).getCapabilities(request));
+      return retrier.execute(
+          () ->
+              channel.withChannelBlocking(
+                  channel -> capabilitiesBlockingStub(context, channel).getCapabilities(request)));
     } catch (StatusRuntimeException e) {
       if (e.getCause() instanceof IOException) {
         throw (IOException) e.getCause();
diff --git a/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java b/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java
index 5dbbb0721c1dd2..b9b391227306d4 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java
@@ -354,8 +354,11 @@ public ActionResult upload(
     try {
       return uploadAsync(context, remoteCache, reporter).blockingGet();
     } catch (RuntimeException e) {
-      throwIfInstanceOf(e.getCause(), InterruptedException.class);
-      throwIfInstanceOf(e.getCause(), IOException.class);
+      Throwable cause = e.getCause();
+      if (cause != null) {
+        throwIfInstanceOf(cause, InterruptedException.class);
+        throwIfInstanceOf(cause, IOException.class);
+      }
       throw e;
     }
   }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java
index a0bc56b0b12d6e..c3456eb687968c 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java
@@ -38,6 +38,7 @@
 import com.google.gson.Gson;
 import com.google.gson.JsonObject;
 import io.grpc.CallCredentials;
+import io.grpc.Channel;
 import io.grpc.StatusRuntimeException;
 import java.io.IOException;
 import java.io.OutputStream;
@@ -122,7 +123,12 @@ public void download(
         newFetchBlobRequest(options.remoteInstanceName, urls, authHeaders, checksum, canonicalId);
     try {
       FetchBlobResponse response =
-          retrier.execute(() -> fetchBlockingStub(remoteActionExecutionContext).fetchBlob(request));
+          retrier.execute(
+              () ->
+                  channel.withChannelBlocking(
+                      channel ->
+                          fetchBlockingStub(remoteActionExecutionContext, channel)
+                              .fetchBlob(request)));
       final Digest blobDigest = response.getBlobDigest();
 
       retrier.execute(
@@ -172,7 +178,8 @@ static FetchBlobRequest newFetchBlobRequest(
     return requestBuilder.build();
   }
 
-  private FetchBlockingStub fetchBlockingStub(RemoteActionExecutionContext context) {
+  private FetchBlockingStub fetchBlockingStub(
+      RemoteActionExecutionContext context, Channel channel) {
     return FetchGrpc.newBlockingStub(channel)
         .withInterceptors(
             TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()))
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java b/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java
index 7eb07d4d95e05d..d86cfd8bcfdd8a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java
@@ -13,7 +13,6 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote.util;
 
-import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.util.concurrent.AbstractFuture;
@@ -31,7 +30,7 @@
 import io.reactivex.rxjava3.core.SingleOnSubscribe;
 import io.reactivex.rxjava3.disposables.Disposable;
 import io.reactivex.rxjava3.exceptions.Exceptions;
-import java.util.concurrent.Callable;
+import io.reactivex.rxjava3.functions.Supplier;
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -48,7 +47,7 @@ private RxFutures() {}
    * completed.
    *
    * <p>A {@link ListenableFuture} represents some computation that is already in progress. We use
-   * {@link Callable} here to defer the execution of the thing that produces ListenableFuture until
+   * {@link Supplier} here to defer the execution of the thing that produces ListenableFuture until
-   * there is subscriber.
+   * there is a subscriber.
    *
    * <p>Errors are also propagated except for certain "fatal" exceptions defined by rxjava. Multiple
@@ -57,19 +56,19 @@ private RxFutures() {}
    * <p>Disposes the Completable to cancel the underlying ListenableFuture.
    */
   public static Completable toCompletable(
-      Callable<ListenableFuture<Void>> callable, Executor executor) {
-    return Completable.create(new OnceCompletableOnSubscribe(callable, executor));
+      Supplier<ListenableFuture<Void>> supplier, Executor executor) {
+    return Completable.create(new OnceCompletableOnSubscribe(supplier, executor));
   }
 
   private static class OnceCompletableOnSubscribe implements CompletableOnSubscribe {
     private final AtomicBoolean subscribed = new AtomicBoolean(false);
 
-    private final Callable<ListenableFuture<Void>> callable;
+    private final Supplier<ListenableFuture<Void>> supplier;
     private final Executor executor;
 
     private OnceCompletableOnSubscribe(
-        Callable<ListenableFuture<Void>> callable, Executor executor) {
-      this.callable = callable;
+        Supplier<ListenableFuture<Void>> supplier, Executor executor) {
+      this.supplier = supplier;
       this.executor = executor;
     }
 
@@ -77,7 +76,7 @@ private OnceCompletableOnSubscribe(
     public void subscribe(@NonNull CompletableEmitter emitter) throws Throwable {
       try {
         checkState(!subscribed.getAndSet(true), "This completable cannot be subscribed to twice");
-        ListenableFuture<Void> future = callable.call();
+        ListenableFuture<Void> future = supplier.get();
         Futures.addCallback(
             future,
             new FutureCallback<Void>() {
@@ -120,7 +119,7 @@ public void onFailure(Throwable throwable) {
    * completed.
    *
    * <p>A {@link ListenableFuture} represents some computation that is already in progress. We use
-   * {@link Callable} here to defer the execution of the thing that produces ListenableFuture until
+   * {@link Supplier} here to defer the execution of the thing that produces ListenableFuture until
-   * there is subscriber.
+   * there is a subscriber.
    *
    * <p>Errors are also propagated except for certain "fatal" exceptions defined by rxjava. Multiple
@@ -128,18 +127,18 @@ public void onFailure(Throwable throwable) {
    *
    * <p>Disposes the Single to cancel the underlying ListenableFuture.
    */
-  public static <T> Single<T> toSingle(Callable<ListenableFuture<T>> callable, Executor executor) {
-    return Single.create(new OnceSingleOnSubscribe<>(callable, executor));
+  public static <T> Single<T> toSingle(Supplier<ListenableFuture<T>> supplier, Executor executor) {
+    return Single.create(new OnceSingleOnSubscribe<>(supplier, executor));
   }
 
   private static class OnceSingleOnSubscribe<T> implements SingleOnSubscribe<T> {
     private final AtomicBoolean subscribed = new AtomicBoolean(false);
 
-    private final Callable<ListenableFuture<T>> callable;
+    private final Supplier<ListenableFuture<T>> supplier;
     private final Executor executor;
 
-    private OnceSingleOnSubscribe(Callable<ListenableFuture<T>> callable, Executor executor) {
-      this.callable = callable;
+    private OnceSingleOnSubscribe(Supplier<ListenableFuture<T>> supplier, Executor executor) {
+      this.supplier = supplier;
       this.executor = executor;
     }
 
@@ -147,13 +146,12 @@ private OnceSingleOnSubscribe(Callable<ListenableFuture<T>> callable, Executor e
     public void subscribe(@NonNull SingleEmitter<T> emitter) throws Throwable {
       try {
         checkState(!subscribed.getAndSet(true), "This single cannot be subscribed to twice");
-        ListenableFuture<T> future = callable.call();
+        ListenableFuture<T> future = supplier.get();
         Futures.addCallback(
             future,
             new FutureCallback<T>() {
               @Override
               public void onSuccess(@Nullable T t) {
-                checkNotNull(t, "value in future onSuccess callback is null");
                 emitter.onSuccess(t);
               }