Skip to content

Commit

Permalink
Fix flaky ResponseCancelTest (#2337)
Browse files Browse the repository at this point in the history
Motivation:

#2297 indicates that `ResponseCancelTest` sometimes fails. It happens
because we incorrectly request signals from `delayedClientTermination`
queue. After cancel we may or may not see a terminal event. In rare
cases, `StacklessClosedChannelException` is propagated to the subscriber
after cancel. The next request does not assume that there is a pending
`ClientTerminationSignal` in the queue and considers this exception as
failure for a new request.

Modifications:
- Introduce a `requestId` to associate `ClientTerminationSignal` with
a proper request;
- Discard signals for prior requests inside `resume` logic;
- Wrap `signal.err` with `AssertionError` to preserve a caller stack
trace;

Result:

Fixes #2297.
  • Loading branch information
idelpivnitskiy authored Sep 29, 2022
1 parent 438892b commit 906c4fc
Showing 1 changed file with 84 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.TestPublisher;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.context.api.ContextMap.Key;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpClient;
import io.servicetalk.http.api.HttpConnection;
import io.servicetalk.http.api.HttpExecutionStrategies;
import io.servicetalk.http.api.HttpRequest;
import io.servicetalk.http.api.HttpRequester;
import io.servicetalk.http.api.StreamingHttpConnection;
import io.servicetalk.http.api.StreamingHttpConnectionFilter;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
Expand All @@ -43,29 +46,36 @@

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable;
import static io.servicetalk.concurrent.api.Completable.completed;
import static io.servicetalk.concurrent.api.Processors.newSingleProcessor;
import static io.servicetalk.concurrent.api.Publisher.never;
import static io.servicetalk.concurrent.api.SourceAdapters.fromSource;
import static io.servicetalk.context.api.ContextMap.Key.newKey;
import static io.servicetalk.http.netty.HttpClients.forSingleAddress;
import static io.servicetalk.http.netty.HttpServers.forAddress;
import static io.servicetalk.logging.api.LogLevel.TRACE;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.util.Objects.requireNonNull;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasSize;

@Timeout(3)
class ResponseCancelTest {

@RegisterExtension
Expand All @@ -77,6 +87,10 @@ class ResponseCancelTest {
ExecutionContextExtension.cached("client-io", "client-executor")
.setClassLevel(true);

private static final Logger LOGGER = LoggerFactory.getLogger(ResponseCancelTest.class);
private static final Key<Integer> REQUEST_ID = newKey("REQUEST_ID", Integer.class);
private static final AtomicInteger REQUEST_ID_GENERATOR = new AtomicInteger();

private final BlockingQueue<Processor<StreamingHttpResponse, StreamingHttpResponse>> serverResponses;
private final BlockingQueue<Cancellable> delayedClientCancels;
private final BlockingQueue<ClientTerminationSignal> delayedClientTermination;
Expand Down Expand Up @@ -114,6 +128,8 @@ public Completable accept(final ConnectionContext context) {
.appendConnectionFilter(connection -> new StreamingHttpConnectionFilter(connection) {
@Override
public Single<StreamingHttpResponse> request(final StreamingHttpRequest request) {
final Integer requestId = request.context().get(REQUEST_ID);
assert requestId != null;
return delegate().request(request)
.liftSync(target -> new Subscriber<StreamingHttpResponse>() {
@Override
Expand All @@ -123,12 +139,13 @@ public void onSubscribe(final Cancellable cancellable) {

@Override
public void onSuccess(final StreamingHttpResponse result) {
delayedClientTermination.add(new ClientTerminationSignal(target, result));
delayedClientTermination.add(
new ClientTerminationSignal(requestId, target, result));
}

@Override
public void onError(final Throwable t) {
delayedClientTermination.add(new ClientTerminationSignal(target, t));
delayedClientTermination.add(new ClientTerminationSignal(requestId, target, t));
}
});
}
Expand Down Expand Up @@ -205,16 +222,22 @@ void connectionCancel() throws Throwable {
sendSecondRequestUsingClient();
}

@ParameterizedTest
@ParameterizedTest(name = "{displayName} [{index}] finishRequest={0}")
@ValueSource(booleans = {false, true})
void connectionCancelWaitingForPayloadBody(boolean finishRequest) throws Throwable {
HttpConnection connection = client.reserveConnection(client.get("/")).toFuture().get();
Cancellable cancellable = finishRequest ? sendRequest(connection, null) :
connection.asStreamingConnection().request(connection.asStreamingConnection().post("/")
.payloadBody(never())).flatMapPublisher(StreamingHttpResponse::payloadBody)
.collect(() -> connection.executionContext().bufferAllocator().newCompositeBuffer(),
CompositeBuffer::addBuffer)
.subscribe(__ -> { });
Cancellable cancellable;
if (finishRequest) {
cancellable = sendRequest(connection, null);
} else {
StreamingHttpConnection streamingConnection = connection.asStreamingConnection();
StreamingHttpRequest request = streamingConnection.post("/").payloadBody(never());
request.context().put(REQUEST_ID, REQUEST_ID_GENERATOR.incrementAndGet());
cancellable = streamingConnection.request(request).flatMapPublisher(StreamingHttpResponse::payloadBody)
.collect(() -> connection.executionContext().bufferAllocator().newCompositeBuffer(),
CompositeBuffer::addBuffer)
.subscribe(__ -> { });
}
// wait for server to receive request.
Processor<StreamingHttpResponse, StreamingHttpResponse> serverResp = serverResponses.take();
assertActiveConnectionsCount(1);
Expand All @@ -238,16 +261,26 @@ void connectionCancelWaitingForPayloadBody(boolean finishRequest) throws Throwab
private void sendSecondRequestUsingClient() throws Throwable {
assertActiveConnectionsCount(0);
// Validate client can still communicate with a server using a new connection.
int requestId = REQUEST_ID_GENERATOR.incrementAndGet();
CountDownLatch latch = new CountDownLatch(1);
sendRequest(client, latch);
sendRequest(client, requestId, latch);
serverResponses.take().onSuccess(client.asStreamingClient().httpResponseFactory().ok());
ClientTerminationSignal.resume(delayedClientTermination, latch);
ClientTerminationSignal.resume(delayedClientTermination, requestId, latch);
assertActiveConnectionsCount(1);
}

private static Cancellable sendRequest(final HttpRequester requester, @Nullable final CountDownLatch latch) {
return (latch == null ? requester.request(requester.get("/")) :
requester.request(requester.get("/"))
private static Cancellable sendRequest(HttpRequester requester,
@Nullable CountDownLatch latch) {
return sendRequest(requester, REQUEST_ID_GENERATOR.incrementAndGet(), latch);
}

private static Cancellable sendRequest(HttpRequester requester,
int requestId,
@Nullable CountDownLatch latch) {
HttpRequest request = requester.get("/");
request.context().put(REQUEST_ID, requestId);
return (latch == null ? requester.request(request) :
requester.request(request)
.afterOnSuccess(__ -> latch.countDown())
.afterOnError(__ -> latch.countDown())
).subscribe(__ -> { });
Expand Down Expand Up @@ -283,27 +316,31 @@ public Single<FilterableStreamingHttpConnection> newConnection(final InetSocketA
}

private static final class ClientTerminationSignal {
@SuppressWarnings("rawtypes")
private final Subscriber subscriber;
private final int requestId;
private final Subscriber<? super StreamingHttpResponse> subscriber;
@Nullable
private final Throwable err;
@Nullable
private final StreamingHttpResponse response;

ClientTerminationSignal(@SuppressWarnings("rawtypes") final Subscriber subscriber, final Throwable err) {
this.subscriber = subscriber;
this.err = err;
ClientTerminationSignal(int requestId,
Subscriber<? super StreamingHttpResponse> subscriber,
Throwable err) {
this.requestId = requestId;
this.subscriber = requireNonNull(subscriber);
this.err = requireNonNull(err);
response = null;
}

ClientTerminationSignal(@SuppressWarnings("rawtypes") final Subscriber subscriber,
final StreamingHttpResponse response) {
this.subscriber = subscriber;
ClientTerminationSignal(int requestId,
Subscriber<? super StreamingHttpResponse> subscriber,
StreamingHttpResponse response) {
this.requestId = requestId;
this.subscriber = requireNonNull(subscriber);
err = null;
this.response = response;
this.response = requireNonNull(response);
}

@SuppressWarnings("unchecked")
void resume() {
if (err != null) {
subscriber.onError(err);
Expand All @@ -312,17 +349,36 @@ void resume() {
}
}

@SuppressWarnings("unchecked")
static void resume(BlockingQueue<ClientTerminationSignal> signals,
final CountDownLatch latch) throws Throwable {
ClientTerminationSignal signal = signals.take();
int requestId,
CountDownLatch latch) throws Throwable {
ClientTerminationSignal signal;
do {
// In case of cancel, a terminal signal may or may not arrive to the subscriber. The requestId helps
// to make sure we discard optional signals of all previous requests and resuming only for the current
// request.
signal = signals.take();
if (signal.requestId != requestId) {
LOGGER.info("Skipped {} because looking for requestId={}", signal, requestId);
}
} while (signal.requestId != requestId);
if (signal.err != null) {
signal.subscriber.onError(signal.err);
throw signal.err;
throw new AssertionError("Response terminated with an error", signal.err);
} else {
signal.subscriber.onSuccess(signal.response);
}
latch.await();
}

@Override
public String toString() {
return getClass().getSimpleName() +
"{requestId=" + requestId +
", subscriber=" + subscriber +
", err=" + err +
", response=" + response +
'}';
}
}
}

0 comments on commit 906c4fc

Please sign in to comment.