Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle cancel in ReleasingClientCall and rethrow the exception in start #1221

Merged
merged 6 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand All @@ -53,6 +54,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import org.threeten.bp.Duration;

/**
Expand Down Expand Up @@ -517,6 +519,7 @@ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(

/** ClientCall wrapper that makes sure to decrement the outstanding RPC count on completion. */
static class ReleasingClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
private CancellationException cancellationException;
mutianf marked this conversation as resolved.
Show resolved Hide resolved
final Entry entry;

public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, Entry entry) {
Expand All @@ -526,6 +529,9 @@ public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, Entry entry) {

@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
if (cancellationException != null) {
throw new IllegalStateException("Call is already cancelled", cancellationException);
}
try {
super.start(
new SimpleForwardingClientCallListener<RespT>(responseListener) {
Expand All @@ -542,7 +548,14 @@ public void onClose(Status status, Metadata trailers) {
} catch (Exception e) {
// In case start failed, make sure to release
entry.release();
throw e;
burkedavison marked this conversation as resolved.
Show resolved Hide resolved
}
}

@Override
public void cancel(@Nullable String message, @Nullable Throwable cause) {
this.cancellationException = new CancellationException(message);
super.cancel(message, cause);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ public void run() {}

@Override
public void cancel() {
cancellationException = new CancellationException("User cancelled stream");
clientCall.cancel(null, cancellationException);
String message = "User cancelled stream";
cancellationException = new CancellationException(message);
clientCall.cancel(message, cancellationException);
mutianf marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@
*/
package com.google.api.gax.grpc;

import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE;
import static com.google.common.truth.Truth.assertThat;

import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.type.Color;
Expand All @@ -49,12 +55,14 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -595,4 +603,50 @@ public void removedActiveChannelsAreShutdown() throws Exception {
// Now the channel should be closed
Mockito.verify(channels.get(1), Mockito.times(1)).shutdown();
}

@Test
public void testReleasingClientCallCancelEarly() throws IOException {
ClientCall mockClientCall = Mockito.mock(ClientCall.class);
Mockito.doAnswer(invocation -> null).when(mockClientCall).cancel(Mockito.any(), Mockito.any());
ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
ChannelPool channelPool = ChannelPool.create(channelPoolSettings, factory);
ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(channelPool))
.setDefaultCallContext(GrpcCallContext.of(channelPool, CallOptions.DEFAULT))
.build();
ServerStreamingCallSettings settings =
ServerStreamingCallSettings.<Color, Money>newBuilder().build();
ServerStreamingCallable streamingCallable =
GrpcCallableFactory.createServerStreamingCallable(
GrpcCallSettings.create(METHOD_SERVER_STREAMING_RECOGNIZE), settings, context);
Color request = Color.newBuilder().setRed(0.5f).build();
try {
mutianf marked this conversation as resolved.
Show resolved Hide resolved
streamingCallable.call(
request,
new ResponseObserver() {
@Override
public void onStart(StreamController controller) {
controller.cancel();
}

@Override
public void onResponse(Object response) {}

@Override
public void onError(Throwable t) {}

@Override
public void onComplete() {}
});
Assert.fail("Request should be cancelled");
} catch (Exception e) {
assertThat(e).isInstanceOf(IllegalStateException.class);
assertThat(e.getCause()).isInstanceOf(CancellationException.class);
assertThat(e.getMessage()).isEqualTo("Call is already cancelled");
}
}
}