Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HttpClientWithTomcatTest flaky test. #2903

Merged
merged 4 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import org.apache.catalina.Context;
import org.apache.catalina.Wrapper;
import org.apache.catalina.startup.Tomcat;
import org.apache.coyote.AbstractProtocol;
import org.apache.coyote.ProtocolHandler;
import org.apache.coyote.http11.AbstractHttp11Protocol;

import javax.servlet.MultipartConfigElement;
import javax.servlet.ServletException;
Expand All @@ -37,7 +40,7 @@
public class TomcatServer {
static final String TOMCAT_BASE_DIR = "./build/tomcat";
public static final String TOO_LARGE = "Request payload too large";
public static final int PAYLOAD_MAX = 5000000;
public static final int PAYLOAD_MAX = 4096;

final Tomcat tomcat;

Expand All @@ -54,6 +57,24 @@ public TomcatServer(int port) {
this.tomcat.setBaseDir(baseDir.getAbsolutePath());
}

public int getMaxSwallowSize() {
ProtocolHandler protoHandler = tomcat.getConnector().getProtocolHandler();
if (!(protoHandler instanceof AbstractProtocol<?>)) {
throw new IllegalStateException("Connection protocol handler is not an instance of AbstractProtocol: " + protoHandler.getClass().getName());
}
AbstractHttp11Protocol<?> protocol = (AbstractHttp11Protocol<?>) protoHandler;
return protocol.getMaxSwallowSize();
}

public void setMaxSwallowSize(int bytes) {
ProtocolHandler protoHandler = tomcat.getConnector().getProtocolHandler();
if (!(protoHandler instanceof AbstractProtocol<?>)) {
throw new IllegalStateException("Connection protocol handler is not an instance of AbstractProtocol: " + protoHandler.getClass().getName());
}
AbstractHttp11Protocol<?> protocol = (AbstractHttp11Protocol<?>) protoHandler;
protocol.setMaxSwallowSize(bytes);
}

public int port() {
if (this.started) {
return this.tomcat.getConnector().getLocalPort();
Expand Down Expand Up @@ -174,14 +195,15 @@ static final class PayloadSizeServlet extends HttpServlet {
protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException {
InputStream in = req.getInputStream();
int count = 0;
byte[] buf = new byte[4096];
int n;

while ((n = in.read()) != -1) {
while ((n = in.read(buf, 0, buf.length)) != -1) {
count += n;
if (count >= PAYLOAD_MAX) {
// By default, Tomcat is configured with maxSwallowSize=2 MB (see https://tomcat.apache.org/tomcat-9.0-doc/config/http.html)
// This means that once the 400 bad request is sent, the client will still be able to continue writing (if it is currently writing)
// up to 2 MB. So, it is very likely that the client will be blocked and it will then be able to consume the 400 bad request and
// up to 2 MB. So, it is very likely that the client will be blocked, and it will then be able to consume the 400 bad request and
// close itself the connection.
sendResponse(resp, TOO_LARGE, HttpServletResponse.SC_BAD_REQUEST);
return;
Expand All @@ -193,7 +215,7 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws

private void sendResponse(HttpServletResponse resp, String message, int status) throws IOException {
resp.setStatus(status);
resp.setHeader("Transfer-Encoding", "chunked");
resp.setHeader("Content-Length", String.valueOf(message.length()));
resp.setHeader("Content-Type", "text/plain");
PrintWriter out = resp.getWriter();
out.print(message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
*/
package reactor.netty.http.client;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
Expand All @@ -29,12 +29,7 @@
import io.netty.handler.codec.http.multipart.HttpData;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.netty.ByteBufFlux;
Expand All @@ -60,8 +55,6 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
import static reactor.netty.http.client.HttpClientOperations.SendForm.DEFAULT_FACTORY;
Expand All @@ -71,7 +64,8 @@
*/
class HttpClientWithTomcatTest {
private static TomcatServer tomcat;
private static final byte[] PAYLOAD = String.join("", Collections.nCopies(TomcatServer.PAYLOAD_MAX + (1024 * 1024), "X"))
private static final int MAX_SWALLOW_SIZE = 1024 * 1024;
private static final byte[] PAYLOAD = String.join("", Collections.nCopies(TomcatServer.PAYLOAD_MAX + MAX_SWALLOW_SIZE + (1024 * 1024), "X"))
.getBytes(Charset.defaultCharset());

@BeforeAll
Expand Down Expand Up @@ -333,46 +327,46 @@ void contentHeader() {
fixed.dispose();
}

static Stream<Arguments> testIssue2825Args() {
Supplier<Publisher<ByteBuf>> postMono = () -> Mono.just(Unpooled.wrappedBuffer(PAYLOAD));
Supplier<Publisher<ByteBuf>> postFlux = () -> Flux.just(Unpooled.wrappedBuffer(PAYLOAD));

return Stream.of(
Arguments.of(Named.of("postMono", postMono), Named.of("bytes", PAYLOAD.length)),
Arguments.of(Named.of("postFlux", postFlux), Named.of("bytes", PAYLOAD.length))
);
}
@Test
void testIssue2825() {
int currentMaxSwallowSize = tomcat.getMaxSwallowSize();

try {
tomcat.setMaxSwallowSize(MAX_SWALLOW_SIZE);

AtomicReference<SocketAddress> serverAddress = new AtomicReference<>();
HttpClient client = HttpClient.create()
.port(getPort())
.wiretap(false)
.metrics(true, ClientMetricsRecorder::reset)
.option(ChannelOption.SO_SNDBUF, 4096)
.doOnConnected(conn -> serverAddress.set(conn.address()));

StepVerifier.create(client
.headers(hdr -> hdr.set("Content-Type", "text/plain"))
.post()
.uri("/payload-size")
.send(Mono.just(Unpooled.wrappedBuffer(PAYLOAD)))
.response((r, buf) -> buf.aggregate().asString().zipWith(Mono.just(r))))
.expectNextMatches(tuple -> TomcatServer.TOO_LARGE.equals(tuple.getT1())
&& tuple.getT2().status().equals(HttpResponseStatus.BAD_REQUEST))
.expectComplete()
.verify(Duration.ofSeconds(30));

assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeMethod).isEqualTo("POST");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime).isNotNull();
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime.isZero()).isFalse();
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeUri).isEqualTo("/payload-size");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeRemoteAddr).isEqualTo(serverAddress.get());

assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentRemoteAddr).isEqualTo(serverAddress.get());
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentUri).isEqualTo("/payload-size");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentBytes).isEqualTo(PAYLOAD.length);
}

@ParameterizedTest
@MethodSource("testIssue2825Args")
void testIssue2825(Supplier<Publisher<ByteBuf>> payload, long bytesToSend) {
AtomicReference<SocketAddress> serverAddress = new AtomicReference<>();
HttpClient client = HttpClient.create()
.port(getPort())
.wiretap(false)
.metrics(true, ClientMetricsRecorder::reset)
.doOnConnected(conn -> serverAddress.set(conn.address()));

StepVerifier.create(client
.headers(hdr -> hdr.set("Content-Type", "text/plain"))
.post()
.uri("/payload-size")
.send(payload.get())
.response((r, buf) -> buf.aggregate().asString().zipWith(Mono.just(r))))
.expectNextMatches(tuple -> TomcatServer.TOO_LARGE.equals(tuple.getT1())
&& tuple.getT2().status().equals(HttpResponseStatus.BAD_REQUEST))
.expectComplete()
.verify(Duration.ofSeconds(30));

assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeMethod).isEqualTo("POST");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime).isNotNull();
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime.isZero()).isFalse();
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeUri).isEqualTo("/payload-size");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeRemoteAddr).isEqualTo(serverAddress.get());

assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentRemoteAddr).isEqualTo(serverAddress.get());
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentUri).isEqualTo("/payload-size");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentBytes).isEqualTo(bytesToSend);
finally {
tomcat.setMaxSwallowSize(currentMaxSwallowSize);
}
}

private int getPort() {
Expand Down