Skip to content

Commit

Permalink
Add EnforceSequentialModeRequesterFilter for http-utils module (#…
Browse files Browse the repository at this point in the history
…1752)

Motivation:

ServiceTalk transport is full-duplex, meaning that a
`StreamingHttpRequester` or `BlockingStreamingHttpRequester`  can read
the response before or while it still sends a request payload body.
In some scenarios and for backward compatibility with legacy HTTP
clients, users may have expectations of sequential execution of the
request and response. This filter helps to enforce that behavior.

Modifications:

- Add `EnforceSequentialModeRequesterFilter` that makes sure the
response is not returned to the caller until after the request is
complete;

Result:

Users can enforce request-response processing.
  • Loading branch information
idelpivnitskiy committed Aug 30, 2021
1 parent 708a0c7 commit 25191f3
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.servicetalk.http.api.HttpServerBuilder;
import io.servicetalk.http.api.SingleAddressHttpClientBuilder;
import io.servicetalk.http.api.StreamingHttpClient;
import io.servicetalk.http.api.StreamingHttpClientFilterFactory;
import io.servicetalk.http.api.StreamingHttpConnection;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
Expand Down Expand Up @@ -126,6 +127,8 @@ enum ExecutorSupplier {
private StreamingHttpServiceFilterFactory serviceFilterFactory;
@Nullable
private ConnectionFactoryFilter<InetSocketAddress, FilterableStreamingHttpConnection> connectionFactoryFilter;
@Nullable
private StreamingHttpClientFilterFactory clientFilterFactory;
private HttpProtocolConfig protocol = h1Default();
private TransportObserver clientTransportObserver = NoopTransportObserver.INSTANCE;
private TransportObserver serverTransportObserver = NoopTransportObserver.INSTANCE;
Expand Down Expand Up @@ -185,6 +188,9 @@ private void startServer() throws Exception {
clientBuilder.appendConnectionFactoryFilter(
new TransportObserverConnectionFactoryFilter<>(clientTransportObserver));
}
if (clientFilterFactory != null) {
clientBuilder.appendClientFilter(clientFilterFactory);
}
httpClient = clientBuilder.ioExecutor(clientIoExecutor)
.executionStrategy(defaultStrategy(clientExecutor))
.protocols(protocol)
Expand Down Expand Up @@ -221,6 +227,10 @@ void connectionFactoryFilter(
this.connectionFactoryFilter = connectionFactoryFilter;
}

void clientFilterFactory(StreamingHttpClientFilterFactory clientFilterFactory) {
this.clientFilterFactory = clientFilterFactory;
}

@AfterEach
void stopServer() throws Exception {
newCompositeCloseable().appendAll(httpConnection, httpClient, clientExecutor, serverContext, serverExecutor)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright © 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.netty;

import io.servicetalk.http.api.StreamingHttpConnection;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.utils.EnforceSequentialModeRequesterFilter;

import org.junit.jupiter.api.Test;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;

import static io.servicetalk.concurrent.api.Publisher.fromInputStream;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.netty.AbstractNettyHttpServerTest.ExecutorSupplier.CACHED;
import static io.servicetalk.http.netty.AbstractNettyHttpServerTest.ExecutorSupplier.CACHED_SERVER;
import static io.servicetalk.http.netty.TestServiceStreaming.SVC_ECHO;
import static io.servicetalk.utils.internal.PlatformDependent.throwException;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;

class FullDuplexAndSequentialModeTest extends AbstractNettyHttpServerTest {

private static final int CHUNK_SIZE = 1024;
private static final int SIZE = 2 * CHUNK_SIZE;

@Test
void defaultFullDuplex() throws Exception {
setUp(CACHED, CACHED_SERVER);

StreamingHttpConnection connection = streamingHttpConnection();
CountDownLatch continueRequest = new CountDownLatch(1);
StreamingHttpResponse response;
try (InputStream payload = payload()) {
response = stallingSendRequest(connection, continueRequest, payload).get();
// response meta-data received before request completes
assertResponse(response, HTTP_1_1, OK);
}
continueRequest.countDown();

ExecutionException e = assertThrows(ExecutionException.class, () -> response.payloadBody().toFuture().get());
assertThat(e.getCause(), instanceOf(IOException.class));
assertThat(e.getCause().getMessage(), containsString("Stream closed"));
}

@Test
void deferResponseUntilAfterRequestSent() throws Exception {
clientFilterFactory(EnforceSequentialModeRequesterFilter.INSTANCE);
setUp(CACHED, CACHED_SERVER);

StreamingHttpConnection connection = streamingHttpConnection();
CountDownLatch continueRequest = new CountDownLatch(1);
try (InputStream payload = payload()) {
Future<StreamingHttpResponse> responseFuture = stallingSendRequest(connection, continueRequest, payload);
// Delay completion of the request payload body:
Thread.sleep(100);
assertThat(responseFuture.isDone(), is(false)); // response meta-data completes only after request is sent
continueRequest.countDown();
assertResponse(responseFuture.get(), HTTP_1_1, OK, SIZE);
}
}

private static InputStream payload() {
byte[] array = new byte[SIZE];
ThreadLocalRandom.current().nextBytes(array);
return new BufferedInputStream(new ByteArrayInputStream(array));
}

private static Future<StreamingHttpResponse> stallingSendRequest(StreamingHttpConnection connection,
CountDownLatch continueRequest,
InputStream payload) {
return connection.request(connection.post(SVC_ECHO).payloadBody(fromInputStream(payload, CHUNK_SIZE)
.map(chunk -> {
try {
continueRequest.await(); // wait until the InputStream is closed
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throwException(ie);
}
return connection.executionContext().bufferAllocator().wrap(chunk);
}))).toFuture();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright © 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.utils;

import io.servicetalk.concurrent.CompletableSource;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.http.api.BlockingStreamingHttpRequester;
import io.servicetalk.http.api.FilterableStreamingHttpClient;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpExecutionStrategy;
import io.servicetalk.http.api.HttpExecutionStrategyInfluencer;
import io.servicetalk.http.api.StreamingHttpClientFilter;
import io.servicetalk.http.api.StreamingHttpClientFilterFactory;
import io.servicetalk.http.api.StreamingHttpConnectionFilter;
import io.servicetalk.http.api.StreamingHttpConnectionFilterFactory;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpRequester;
import io.servicetalk.http.api.StreamingHttpResponse;

import static io.servicetalk.concurrent.api.Processors.newCompletableProcessor;
import static io.servicetalk.concurrent.api.SourceAdapters.fromSource;

/**
* Enforces sequential behavior of the client, deferring return of the response until after the request payload body is
* sent.
* <p>
* ServiceTalk transport is full-duplex, meaning that a {@link StreamingHttpRequester} or
* {@link BlockingStreamingHttpRequester} can read the response before or while it is still sending a request payload
* body. In some scenarios, and for backward compatibility with legacy HTTP clients, users may have expectations of a
* sequential execution of the request and response. This filter helps to enforce that behavior.
*/
public final class EnforceSequentialModeRequesterFilter implements StreamingHttpClientFilterFactory,
StreamingHttpConnectionFilterFactory,
HttpExecutionStrategyInfluencer {

/**
* Singleton instance of {@link EnforceSequentialModeRequesterFilter}.
*/
public static final EnforceSequentialModeRequesterFilter INSTANCE = new EnforceSequentialModeRequesterFilter();

private EnforceSequentialModeRequesterFilter() {
// Singleton
}

private static Single<StreamingHttpResponse> request(final StreamingHttpRequester delegate,
final HttpExecutionStrategy strategy,
final StreamingHttpRequest request) {
return Single.defer(() -> {
CompletableSource.Processor requestSent = newCompletableProcessor();
StreamingHttpRequest r = request.transformMessageBody(messageBody -> messageBody
.whenFinally(requestSent::onComplete));
return fromSource(requestSent).merge(delegate.request(strategy, r).toPublisher()).firstOrError()
.subscribeShareContext();
});
}

@Override
public StreamingHttpClientFilter create(final FilterableStreamingHttpClient client) {
return new StreamingHttpClientFilter(client) {
@Override
protected Single<StreamingHttpResponse> request(final StreamingHttpRequester delegate,
final HttpExecutionStrategy strategy,
final StreamingHttpRequest request) {
return EnforceSequentialModeRequesterFilter.request(delegate, strategy, request);
}
};
}

@Override
public StreamingHttpConnectionFilter create(final FilterableStreamingHttpConnection connection) {
return new StreamingHttpConnectionFilter(connection) {
@Override
public Single<StreamingHttpResponse> request(final HttpExecutionStrategy strategy,
final StreamingHttpRequest request) {
return EnforceSequentialModeRequesterFilter.request(delegate(), strategy, request);
}
};
}

@Override
public HttpExecutionStrategy influenceStrategy(final HttpExecutionStrategy strategy) {
return strategy; // No influence since we do not block
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright © 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.utils;

import io.servicetalk.buffer.api.Buffer;
import io.servicetalk.concurrent.api.TestPublisher;
import io.servicetalk.concurrent.api.test.StepVerifiers;
import io.servicetalk.http.api.DefaultHttpHeadersFactory;
import io.servicetalk.http.api.DefaultStreamingHttpRequestResponseFactory;
import io.servicetalk.http.api.FilterableStreamingHttpClient;
import io.servicetalk.http.api.StreamingHttpClientFilter;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpRequestResponseFactory;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.time.Duration;
import java.util.concurrent.TimeoutException;

import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR;
import static io.servicetalk.concurrent.api.Publisher.never;
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.api.HttpExecutionStrategies.defaultStrategy;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class EnforceSequentialModeRequesterFilterTest {

private static final StreamingHttpRequestResponseFactory REQ_RES_FACTORY =
new DefaultStreamingHttpRequestResponseFactory(DEFAULT_ALLOCATOR, DefaultHttpHeadersFactory.INSTANCE,
HTTP_1_1);

private final FilterableStreamingHttpClient client = mock(FilterableStreamingHttpClient.class);

@BeforeEach
void setUp() {
when(client.request(any(), any())).thenAnswer(invocation -> {
// Simulate consumption of the request payload body:
StreamingHttpRequest request = invocation.getArgument(1);
request.payloadBody().forEach(__ -> { /* noop */ });
return succeeded(REQ_RES_FACTORY.ok());
});
}

@Test
void responseCompletesAfterRequestPayloadBodyCompletes() {
TestPublisher<Buffer> payloadBody = new TestPublisher<>();
StreamingHttpRequest request = REQ_RES_FACTORY.post("/").payloadBody(payloadBody);

FilterableStreamingHttpClient client = EnforceSequentialModeRequesterFilter.INSTANCE.create(this.client);
StepVerifiers.create(client.request(defaultStrategy(), request))
.expectCancellable()
.then(payloadBody::onComplete)
.expectSuccess()
.verify();
}

@Test
void responseNeverCompletesIfRequestPayloadBodyNeverCompletes() {
StreamingHttpRequest request = REQ_RES_FACTORY.post("/").payloadBody(never());

StreamingHttpClientFilter client = EnforceSequentialModeRequesterFilter.INSTANCE.create(this.client);
AssertionError e = assertThrows(AssertionError.class,
() -> StepVerifiers.create(client.request(defaultStrategy(), request))
.expectCancellable()
.expectSuccess()
.verify(Duration.ofMillis(100)));
assertThat(e.getCause(), instanceOf(TimeoutException.class));
}

@Test
void withoutFilterResponseCompletesIndependently() {
StreamingHttpRequest request = REQ_RES_FACTORY.post("/").payloadBody(never());

StepVerifiers.create(this.client.request(defaultStrategy(), request))
.expectCancellable()
.expectSuccess()
.verify();
}
}

0 comments on commit 25191f3

Please sign in to comment.