diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 5215b4d9bc..7f045cc44d 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -45,6 +45,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CancellationException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -53,6 +54,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; +import javax.annotation.Nullable; import org.threeten.bp.Duration; /** @@ -517,6 +519,7 @@ public ClientCall newCall( /** ClientCall wrapper that makes sure to decrement the outstanding RPC count on completion. */ static class ReleasingClientCall extends SimpleForwardingClientCall { + @Nullable private CancellationException cancellationException; final Entry entry; public ReleasingClientCall(ClientCall delegate, Entry entry) { @@ -526,6 +529,9 @@ public ReleasingClientCall(ClientCall delegate, Entry entry) { @Override public void start(Listener responseListener, Metadata headers) { + if (cancellationException != null) { + throw new IllegalStateException("Call is already cancelled", cancellationException); + } try { super.start( new SimpleForwardingClientCallListener(responseListener) { @@ -542,7 +548,14 @@ public void onClose(Status status, Metadata trailers) { } catch (Exception e) { // In case start failed, make sure to release entry.release(); + throw e; } } + + @Override + public void cancel(@Nullable String message, @Nullable Throwable cause) { + this.cancellationException = new CancellationException(message); + super.cancel(message, cause); + } } } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 1bf4726532..25a87c0908 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -29,11 +29,17 @@ */ package com.google.api.gax.grpc; +import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE; import static com.google.common.truth.Truth.assertThat; import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeMethodDescriptor; import com.google.api.gax.grpc.testing.FakeServiceGrpc; +import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.ResponseObserver; +import com.google.api.gax.rpc.ServerStreamingCallSettings; +import com.google.api.gax.rpc.ServerStreamingCallable; +import com.google.api.gax.rpc.StreamController; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.type.Color; @@ -49,12 +55,14 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -595,4 +603,50 @@ public void removedActiveChannelsAreShutdown() throws Exception { // Now the channel should be closed Mockito.verify(channels.get(1), Mockito.times(1)).shutdown(); } + + @Test + public void testReleasingClientCallCancelEarly() throws IOException { + ClientCall mockClientCall = Mockito.mock(ClientCall.class); + Mockito.doAnswer(invocation -> null).when(mockClientCall).cancel(Mockito.any(), Mockito.any()); + ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class); + Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall); + ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1); + ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel)); + ChannelPool channelPool = ChannelPool.create(channelPoolSettings, factory); + ClientContext context = + ClientContext.newBuilder() + .setTransportChannel(GrpcTransportChannel.create(channelPool)) + .setDefaultCallContext(GrpcCallContext.of(channelPool, CallOptions.DEFAULT)) + .build(); + ServerStreamingCallSettings settings = + ServerStreamingCallSettings.newBuilder().build(); + ServerStreamingCallable streamingCallable = + GrpcCallableFactory.createServerStreamingCallable( + GrpcCallSettings.create(METHOD_SERVER_STREAMING_RECOGNIZE), settings, context); + Color request = Color.newBuilder().setRed(0.5f).build(); + + IllegalStateException e = + Assert.assertThrows( + IllegalStateException.class, + () -> + streamingCallable.call( + request, + new ResponseObserver() { + @Override + public void onStart(StreamController controller) { + controller.cancel(); + } + + @Override + public void onResponse(Object response) {} + + @Override + public void onError(Throwable t) {} + + @Override + public void onComplete() {} + })); + assertThat(e.getCause()).isInstanceOf(CancellationException.class); + assertThat(e.getMessage()).isEqualTo("Call is already cancelled"); + } }