diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java index 5ed4c63296..4b56608b2c 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java @@ -1,5 +1,5 @@ /* - * Copyright © 2020-2021 Apple Inc. and the ServiceTalk project authors + * Copyright © 2020-2022 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,68 +15,105 @@ */ package io.servicetalk.http.netty; +import io.servicetalk.buffer.api.Buffer; +import io.servicetalk.buffer.api.CompositeBuffer; import io.servicetalk.client.api.ConnectionFactory; import io.servicetalk.client.api.ConnectionFactoryFilter; import io.servicetalk.client.api.DelegatingConnectionFactory; import io.servicetalk.concurrent.Cancellable; import io.servicetalk.concurrent.SingleSource.Processor; import io.servicetalk.concurrent.SingleSource.Subscriber; +import io.servicetalk.concurrent.api.Completable; +import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.Single; +import io.servicetalk.concurrent.api.TestPublisher; import io.servicetalk.context.api.ContextMap; import io.servicetalk.http.api.FilterableStreamingHttpConnection; import io.servicetalk.http.api.HttpClient; +import io.servicetalk.http.api.HttpConnection; import io.servicetalk.http.api.HttpExecutionStrategies; -import io.servicetalk.http.api.HttpResponse; +import io.servicetalk.http.api.HttpRequester; import io.servicetalk.http.api.StreamingHttpConnectionFilter; import io.servicetalk.http.api.StreamingHttpRequest; import io.servicetalk.http.api.StreamingHttpResponse; +import io.servicetalk.transport.api.ConnectionContext; +import io.servicetalk.transport.api.DelegatingConnectionAcceptor; import io.servicetalk.transport.api.ServerContext; import io.servicetalk.transport.api.TransportObserver; +import io.servicetalk.transport.netty.internal.ExecutionContextExtension; import org.hamcrest.Matcher; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nullable; import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; +import static io.servicetalk.concurrent.api.Completable.completed; import static io.servicetalk.concurrent.api.Processors.newSingleProcessor; -import static io.servicetalk.concurrent.api.Single.defer; import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; import static io.servicetalk.http.netty.HttpClients.forSingleAddress; import static io.servicetalk.http.netty.HttpServers.forAddress; +import static io.servicetalk.logging.api.LogLevel.TRACE; import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; class ResponseCancelTest { - private final BlockingQueue> serverResponses; + @RegisterExtension + static final ExecutionContextExtension SERVER_CTX = + ExecutionContextExtension.cached("server-io", "server-executor") + .setClassLevel(true); + @RegisterExtension + static final ExecutionContextExtension CLIENT_CTX = + ExecutionContextExtension.cached("client-io", "client-executor") + .setClassLevel(true); + + private final BlockingQueue> serverResponses; private final BlockingQueue delayedClientCancels; private final BlockingQueue delayedClientTermination; private final ServerContext ctx; private final HttpClient client; - private final AtomicInteger connectionCount = new AtomicInteger(); + private final BlockingQueue clientConnectionsClosedStates = new LinkedBlockingQueue<>(); + private final BlockingQueue serverConnectionsClosedStates = new LinkedBlockingQueue<>(); ResponseCancelTest() throws Exception { serverResponses = new LinkedBlockingQueue<>(); delayedClientCancels = new LinkedBlockingQueue<>(); delayedClientTermination = new LinkedBlockingQueue<>(); ctx = forAddress(localAddress(0)) - .listenAndAwait((__, ___, factory) -> { - Processor resp = newSingleProcessor(); + .ioExecutor(SERVER_CTX.ioExecutor()) + .executor(SERVER_CTX.executor()) + .enableWireLogging("servicetalk-tests-wire-logger", TRACE, () -> true) + .appendConnectionAcceptorFilter(original -> new DelegatingConnectionAcceptor(original) { + @Override + public Completable accept(final ConnectionContext context) { + CountDownLatch serverConnectionClosed = new CountDownLatch(1); + serverConnectionsClosedStates.add(serverConnectionClosed); + context.onClose().whenFinally(serverConnectionClosed::countDown).subscribe(); + return completed(); + } + }) + .listenStreamingAndAwait((__, ___, factory) -> { + Processor resp = newSingleProcessor(); serverResponses.add(resp); return fromSource(resp); }); client = forSingleAddress(serverHostAndPort(ctx)) + .ioExecutor(CLIENT_CTX.ioExecutor()) + .executor(CLIENT_CTX.executor()) + .enableWireLogging("servicetalk-tests-wire-logger", TRACE, () -> true) .appendConnectionFilter(connection -> new StreamingHttpConnectionFilter(connection) { @Override public Single request(final StreamingHttpRequest request) { @@ -100,7 +137,7 @@ public void onError(final Throwable t) { } }) .appendConnectionFactoryFilter(ConnectionFactoryFilter.withStrategy( - original -> new CountingConnectionFactory(original, connectionCount), + original -> new CountingConnectionFactory(original, clientConnectionsClosedStates), HttpExecutionStrategies.offloadNone())) .build(); } @@ -112,12 +149,12 @@ void tearDown() throws Exception { } @Test - void cancel() throws Throwable { + void clientCancel() throws Throwable { CountDownLatch latch1 = new CountDownLatch(1); - Cancellable cancellable = sendRequest(latch1); + Cancellable cancellable = sendRequest(client, latch1); // wait for server to receive request. - serverResponses.take(); - assertThat("Unexpected connections count.", connectionCount.get(), is(1)); + Processor serverResp = serverResponses.take(); + assertActiveConnectionsCount(1); cancellable.cancel(); // wait for cancel to be observed but don't send cancel to the transport so that transport does not close the // connection which will then be ambiguous. @@ -126,23 +163,23 @@ void cancel() throws Throwable { // and hence fail the response. ClientTerminationSignal.resumeExpectFailure(delayedClientTermination, latch1, instanceOf(ClosedChannelException.class)); + clientConnectionsClosedStates.take().await(); + // Let the server write the response to fail the write and close the connection + serverResp.onSuccess(client.asStreamingClient().httpResponseFactory().ok()); + serverConnectionsClosedStates.take().await(); - CountDownLatch latch2 = new CountDownLatch(1); - sendRequest(latch2); - serverResponses.take().onSuccess(client.httpResponseFactory().ok()); - ClientTerminationSignal.resume(delayedClientTermination, latch2); - assertThat("Unexpected connections count.", connectionCount.get(), is(2)); + sendSecondRequest(); } @Test - void cancelAfterSuccessOnTransport() throws Throwable { + void clientCancelAfterSuccessOnTransport() throws Throwable { CountDownLatch latch1 = new CountDownLatch(1); - Cancellable cancellable = sendRequest(latch1); + Cancellable cancellable = sendRequest(client, latch1); // wait for server to receive request. - Processor serverResp = serverResponses.take(); - assertThat("Unexpected connections count.", connectionCount.get(), is(1)); + Processor serverResp = serverResponses.take(); + assertActiveConnectionsCount(1); - serverResp.onSuccess(client.httpResponseFactory().ok()); + serverResp.onSuccess(client.asStreamingClient().httpResponseFactory().ok()); cancellable.cancel(); // wait for cancel to be observed but don't send cancel to the transport so that transport does not close the // connection which will then be ambiguous. @@ -150,39 +187,105 @@ void cancelAfterSuccessOnTransport() throws Throwable { // As there is a race between completion and cancellation, we may get a success or failure, so just wait for // any termination. delayedClientTermination.take().resume(); + latch1.await(); + clientConnectionsClosedStates.take().await(); + serverConnectionsClosedStates.take().await(); + + sendSecondRequest(); + } + + @Test + void connectionCancel() throws Throwable { + HttpConnection connection = client.reserveConnection(client.get("/")).toFuture().get(); + Cancellable cancellable = sendRequest(connection, null); + // wait for server to receive request. + Processor serverResp = serverResponses.take(); + assertActiveConnectionsCount(1); + cancellable.cancel(); + // wait for cancel to be observed and propagate it to the transport to initiate connection closure. + delayedClientCancels.take().cancel(); + // Transport should close the connection, the response terminal signal is not guaranteed after cancellation. + clientConnectionsClosedStates.take().await(); + // Let the server write the response to fail the write and close the connection + serverResp.onSuccess(client.asStreamingClient().httpResponseFactory().ok()); + serverConnectionsClosedStates.take().await(); + + sendSecondRequest(); + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void connectionCancelWaitingForPayloadBody(boolean finishRequest) throws Throwable { + HttpConnection connection = client.reserveConnection(client.get("/")).toFuture().get(); + Cancellable cancellable = finishRequest ? sendRequest(connection, null) : + connection.asStreamingConnection().request(connection.asStreamingConnection().post("/") + .payloadBody(Publisher.never())).flatMapPublisher(StreamingHttpResponse::payloadBody) + .collect(() -> connection.executionContext().bufferAllocator().newCompositeBuffer(), + CompositeBuffer::addBuffer) + .subscribe(__ -> { }); + // wait for server to receive request. + Processor serverResp = serverResponses.take(); + assertActiveConnectionsCount(1); + + TestPublisher payload = new TestPublisher<>(); + serverResp.onSuccess(connection.asStreamingConnection().httpResponseFactory().ok() + .payloadBody(payload)); + // wait for response meta-data to be received. + delayedClientTermination.take().resume(); + // cancel payload body. + cancellable.cancel(); + // Transport should close the connection, the response terminal signal is not guaranteed after cancellation. + clientConnectionsClosedStates.take().await(); + // Finish server response to let server close the connection + payload.onComplete(); + serverConnectionsClosedStates.take().await(); + + sendSecondRequest(); + } + private void sendSecondRequest() throws Throwable { + assertActiveConnectionsCount(0); + // Validate client can still communicate with a server using a new connection. CountDownLatch latch2 = new CountDownLatch(1); - sendRequest(latch2); - serverResponses.take().onSuccess(client.httpResponseFactory().ok()); + sendRequest(client, latch2); + serverResponses.take().onSuccess(client.asStreamingClient().httpResponseFactory().ok()); ClientTerminationSignal.resume(delayedClientTermination, latch2); - assertThat("Unexpected connections count.", connectionCount.get(), is(2)); + assertActiveConnectionsCount(1); } - private Cancellable sendRequest(final CountDownLatch latch) { - return client.request(client.get("/")) - .afterOnSuccess(__ -> latch.countDown()) - .afterOnError(__ -> latch.countDown()) - .subscribe(__ -> { }); + private static Cancellable sendRequest(final HttpRequester requester, @Nullable final CountDownLatch latch) { + return (latch == null ? requester.request(requester.get("/")) : + requester.request(requester.get("/")) + .afterOnSuccess(__ -> latch.countDown()) + .afterOnError(__ -> latch.countDown()) + ).subscribe(__ -> { }); + } + + private void assertActiveConnectionsCount(int expected) { + assertThat("Unexpected client connections count.", clientConnectionsClosedStates, hasSize(expected)); + assertThat("Unexpected server connections count.", serverConnectionsClosedStates, hasSize(expected)); } private static class CountingConnectionFactory extends DelegatingConnectionFactory { - private final AtomicInteger connectionCount; + + private final BlockingQueue clientConnectionsClosedStates; CountingConnectionFactory( final ConnectionFactory delegate, - final AtomicInteger connectionCount) { + final BlockingQueue clientConnectionsClosedStates) { super(delegate); - this.connectionCount = connectionCount; + this.clientConnectionsClosedStates = clientConnectionsClosedStates; } @Override public Single newConnection(final InetSocketAddress inetSocketAddress, @Nullable final ContextMap context, @Nullable final TransportObserver observer) { - return defer(() -> { - connectionCount.incrementAndGet(); - return delegate().newConnection(inetSocketAddress, context, observer); + return delegate().newConnection(inetSocketAddress, context, observer).whenOnSuccess(c -> { + CountDownLatch clientConnectionClosed = new CountDownLatch(1); + clientConnectionsClosedStates.add(clientConnectionClosed); + c.onClose().whenFinally(clientConnectionClosed::countDown).subscribe(); }); } } diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelPublisher.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelPublisher.java index 2192c1c4a3..f7af095f80 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelPublisher.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelPublisher.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018, 2020 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018, 2020-2022 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -260,9 +260,12 @@ private void emitCatchError(@Nullable SubscriptionImpl target, Throwable cause, if (target != null) { emitError(target, cause); } else { - LOGGER.debug("caught unexpected exception, closing channel {}", channel, cause); + // This branch executes only when an error is originated by the current Subscriber: either an unexpected + // exception is thrown from Subscriber.onComplete() or cancellation. // If an incomplete subscriber is cancelled then close channel. A subscriber can cancel after getting - // complete, which should not close the channel. + // complete, which should not close the channel (won't reach this point, returns earlier). + // Use outbound/inbound closure instead of channel.close() to register CHANNEL_CLOSED_OUTBOUND event. + closeChannelOutbound(); closeChannelInbound(); } } @@ -282,6 +285,8 @@ private void emitComplete(SubscriptionImpl target) { try { target.associatedSub.onComplete(); } catch (Throwable cause) { + LOGGER.debug("Caught unexpected exception from Subscriber {}, closing channel {}", + target.associatedSub, channel, cause); emitCatchError(null, cause, false); } } @@ -308,13 +313,20 @@ private void cancel0(SubscriptionImpl forSubscription) { // If a cancel occurs with a valid subscription we need to clear any pending data and set a fatalError so that // any future Subscribers don't get partial data delivered from the queue. + // We don't need to terminate the subscriber because cancellation is originated by the subscriber, pass null. emitCatchError(null, StacklessClosedChannelException.newInstance(NettyChannelPublisher.class, "cancel"), true); } + // For cases when an error occurred in netty pipeline private void closeChannelInbound() { closeHandler.closeChannelInbound(channel); } + // For cases with an error occurred in subscriber or a result of cancellation + private void closeChannelOutbound() { + closeHandler.closeChannelOutbound(channel); + } + private void resetSubscription() { subscription = null; requestCount = 0;