Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay sending the server preface when SNI is configured #3484

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.handler.ssl.SniCompletionEvent;
import io.netty.handler.ssl.SslHandler;
import reactor.netty.ReactorNetty;
import reactor.netty.observability.ReactorNettyHandlerContext;
Expand Down Expand Up @@ -238,6 +239,8 @@ static final class TlsMetricsHandler extends Observation.Context
final MicrometerChannelMetricsRecorder recorder;
final SocketAddress remoteAddress;
final String type;

boolean listenerAdded;
Observation observation;

// remote address and status are not known beforehand
Expand All @@ -257,31 +260,7 @@ static final class TlsMetricsHandler extends Observation.Context
@Override
@SuppressWarnings("try")
public void channelActive(ChannelHandlerContext ctx) {
SocketAddress rAddr = remoteAddress != null ? remoteAddress : ctx.channel().remoteAddress();
if (rAddr instanceof InetSocketAddress) {
InetSocketAddress address = (InetSocketAddress) rAddr;
this.netPeerName = address.getHostString();
this.netPeerPort = address.getPort() + "";
}
else {
this.netPeerName = rAddr.toString();
this.netPeerPort = "";
}
observation = Observation.createNotStarted(recorder.name() + TLS_HANDSHAKE_TIME, this, OBSERVATION_REGISTRY);
parentContextView = updateChannelContext(ctx.channel(), observation);
observation.start();
ctx.pipeline()
.get(SslHandler.class)
.handshakeFuture()
.addListener(f -> {
ctx.pipeline().remove(this);
status = f.isSuccess() ? SUCCESS : ERROR;
observation.stop();

ReactorNetty.setChannelContext(ctx.channel(), parentContextView);
parentContextView = null;
});

addListener(ctx);
ctx.fireChannelActive();
}

Expand Down Expand Up @@ -361,12 +340,46 @@ public void handlerRemoved(ChannelHandlerContext ctx) {

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
addListener(ctx);
}
ctx.fireUserEventTriggered(evt);
}

@Override
public Timer getTimer() {
return recorder.getTlsHandshakeTimer(getName(), netPeerName + ':' + netPeerPort, proxyAddress == null ? NA : proxyAddress, status);
}

private void addListener(ChannelHandlerContext ctx) {
if (!listenerAdded) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (sslHandler != null) {
listenerAdded = true;
SocketAddress rAddr = remoteAddress != null ? remoteAddress : ctx.channel().remoteAddress();
if (rAddr instanceof InetSocketAddress) {
InetSocketAddress address = (InetSocketAddress) rAddr;
this.netPeerName = address.getHostString();
this.netPeerPort = address.getPort() + "";
}
else {
this.netPeerName = rAddr.toString();
this.netPeerPort = "";
}
observation = Observation.createNotStarted(recorder.name() + TLS_HANDSHAKE_TIME, this, OBSERVATION_REGISTRY);
parentContextView = updateChannelContext(ctx.channel(), observation);
observation.start();
sslHandler.handshakeFuture()
.addListener(f -> {
ctx.pipeline().remove(this);
status = f.isSuccess() ? SUCCESS : ERROR;
observation.stop();

ReactorNetty.setChannelContext(ctx.channel(), parentContextView);
parentContextView = null;
});
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import io.netty.handler.codec.http2.Http2StreamFrameToHttpObjectCodec;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.AbstractSniHandler;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
Expand Down Expand Up @@ -1144,9 +1145,14 @@ static final class H2OrHttp11Codec extends ApplicationProtocolNegotiationHandler
final ChannelOperations.OnSetup opsFactory;
final Duration readTimeout;
final Duration requestTimeout;
final boolean supportOnlyHttp2;
final Function<String, String> uriTagValue;

H2OrHttp11Codec(HttpServerChannelInitializer initializer, ConnectionObserver listener) {
this(initializer, listener, false);
}

H2OrHttp11Codec(HttpServerChannelInitializer initializer, ConnectionObserver listener, boolean supportOnlyHttp2) {
super(ApplicationProtocolNames.HTTP_1_1);
this.accessLogEnabled = initializer.accessLogEnabled;
this.accessLog = initializer.accessLog;
Expand All @@ -1169,6 +1175,7 @@ static final class H2OrHttp11Codec extends ApplicationProtocolNegotiationHandler
this.opsFactory = initializer.opsFactory;
this.readTimeout = initializer.readTimeout;
this.requestTimeout = initializer.requestTimeout;
this.supportOnlyHttp2 = supportOnlyHttp2;
this.uriTagValue = initializer.uriTagValue;
}

Expand All @@ -1188,7 +1195,7 @@ protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
return;
}

if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) {
if (!supportOnlyHttp2 && ApplicationProtocolNames.HTTP_1_1.equals(protocol)) {
configureHttp11Pipeline(p, accessLogEnabled, accessLog, compressPredicate, cookieDecoder, cookieEncoder, true,
decoder, formDecoderProvider, forwardedHeaderHandler, httpMessageLogFactory, idleTimeout, listener,
mapHandle, maxKeepAliveRequests, methodTagValue, metricsRecorder, minCompressionSize, readTimeout, requestTimeout, uriTagValue);
Expand Down Expand Up @@ -1308,29 +1315,38 @@ else if ((protocols & h11) == h11) {
uriTagValue);
}
else if ((protocols & h2) == h2) {
configureH2Pipeline(
channel.pipeline(),
accessLogEnabled,
accessLog,
compressPredicate(compressPredicate, minCompressionSize),
cookieDecoder,
cookieEncoder,
enableGracefulShutdown,
formDecoderProvider,
forwardedHeaderHandler,
http2SettingsSpec,
httpMessageLogFactory,
idleTimeout,
observer,
mapHandle,
methodTagValue,
metricsRecorder,
minCompressionSize,
opsFactory,
readTimeout,
requestTimeout,
uriTagValue,
decoder.validateHeaders());
ChannelHandler sslHandler = channel.pipeline().get(NettyPipeline.SslHandler);
if (sslHandler instanceof AbstractSniHandler) {
channel.pipeline()
.addBefore(NettyPipeline.ReactiveBridge,
NettyPipeline.H2OrHttp11Codec,
new H2OrHttp11Codec(this, observer, true));
}
else {
configureH2Pipeline(
channel.pipeline(),
accessLogEnabled,
accessLog,
compressPredicate(compressPredicate, minCompressionSize),
cookieDecoder,
cookieEncoder,
enableGracefulShutdown,
formDecoderProvider,
forwardedHeaderHandler,
http2SettingsSpec,
httpMessageLogFactory,
idleTimeout,
observer,
mapHandle,
methodTagValue,
metricsRecorder,
minCompressionSize,
opsFactory,
readTimeout,
requestTimeout,
uriTagValue,
decoder.validateHeaders());
}
}
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2142,12 +2142,18 @@ void testHang() {
}

@Test
void testSniSupport() throws Exception {
void testSniSupportHttp11() throws Exception {
doTestSniSupport(Function.identity(), Function.identity());
}

@ParameterizedTest
@MethodSource("h2CompatibleCombinations")
void testSniSupportHttp2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception {
doTestSniSupport(server -> server.protocol(serverProtocols), client -> client.protocol(clientProtocols));
}

@Test
void testIssue3022() throws Exception {
void testIssue3022Http11() throws Exception {
TestHttpClientMetricsRecorder clientMetricsRecorder = new TestHttpClientMetricsRecorder();
TestHttpServerMetricsRecorder serverMetricsRecorder = new TestHttpServerMetricsRecorder();
doTestSniSupport(server -> server.metrics(true, () -> serverMetricsRecorder, Function.identity()),
Expand All @@ -2156,38 +2162,66 @@ void testIssue3022() throws Exception {
assertThat(serverMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO);
}

@ParameterizedTest
@MethodSource("h2CompatibleCombinations")
void testIssue3022Http2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception {
TestHttpClientMetricsRecorder clientMetricsRecorder = new TestHttpClientMetricsRecorder();
TestHttpServerMetricsRecorder serverMetricsRecorder = new TestHttpServerMetricsRecorder();
doTestSniSupport(server -> server.protocol(serverProtocols).metrics(true, () -> serverMetricsRecorder, Function.identity()),
client -> client.protocol(clientProtocols).metrics(true, () -> clientMetricsRecorder, Function.identity()));
assertThat(clientMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO);
assertThat(serverMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO);
}

@Test
void testIssue3473Http11() throws Exception {
doTestSniSupport(server -> server.metrics(true, Function.identity()),
client -> client.metrics(true, Function.identity()));
}

@ParameterizedTest
@MethodSource("h2CompatibleCombinations")
void testIssue3473Http2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception {
doTestSniSupport(server -> server.protocol(serverProtocols).metrics(true, Function.identity()),
client -> client.protocol(clientProtocols).metrics(true, Function.identity()));
}

private void doTestSniSupport(Function<HttpServer, HttpServer> serverCustomizer,
Function<HttpClient, HttpClient> clientCustomizer) throws Exception {
SelfSignedCertificate defaultCert = new SelfSignedCertificate("default");
Http11SslContextSpec defaultSslContextBuilder =
Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey());

SelfSignedCertificate testCert = new SelfSignedCertificate("test.com");
Http11SslContextSpec testSslContextBuilder =
Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey());

Http11SslContextSpec clientSslContextBuilder =
Http11SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

AtomicReference<String> hostname = new AtomicReference<>();
HttpServer server = serverCustomizer.apply(createServer());

boolean isH2 = (server.configuration()._protocols & HttpServerConfig.h2) == HttpServerConfig.h2;
SslProvider.ProtocolSslContextSpec defaultSslContextBuilder = isH2 ?
Http2SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()) :
Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey());
SslProvider.ProtocolSslContextSpec testSslContextBuilder = isH2 ?
Http2SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()) :
Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey());

disposableServer =
serverCustomizer.apply(createServer())
.secure(spec -> spec.sslContext(defaultSslContextBuilder)
.addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(testSslContextBuilder)))
.doOnChannelInit((obs, channel, remoteAddress) ->
channel.pipeline()
.addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
hostname.set(((SniCompletionEvent) evt).hostname());
}
ctx.fireUserEventTriggered(evt);
server.secure(spec -> spec.sslContext(defaultSslContextBuilder)
.addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(testSslContextBuilder)))
.doOnChannelInit((obs, channel, remoteAddress) ->
channel.pipeline()
.addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
hostname.set(((SniCompletionEvent) evt).hostname());
}
}))
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();
ctx.fireUserEventTriggered(evt);
}
}))
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();

SslProvider.ProtocolSslContextSpec clientSslContextBuilder = isH2 ?
Http2SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)) :
Http11SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

clientCustomizer.apply(createClient(disposableServer::address))
.secure(spec -> spec.sslContext(clientSslContextBuilder)
Expand All @@ -2203,40 +2237,55 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
}

@Test
void testSniSupportAsyncMappings() throws Exception {
SelfSignedCertificate defaultCert = new SelfSignedCertificate("default");
Http11SslContextSpec defaultSslContextBuilder =
Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey());
void testSniSupportAsyncMappingsHttp11() throws Exception {
doTestSniSupportAsyncMappings(Function.identity(), Function.identity());
}

@ParameterizedTest
@MethodSource("h2CompatibleCombinations")
void testSniSupportAsyncMappingsHttp2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception {
doTestSniSupportAsyncMappings(server -> server.protocol(serverProtocols), client -> client.protocol(clientProtocols));
}

private void doTestSniSupportAsyncMappings(Function<HttpServer, HttpServer> serverCustomizer,
Function<HttpClient, HttpClient> clientCustomizer) throws Exception {
SelfSignedCertificate defaultCert = new SelfSignedCertificate("default");
SelfSignedCertificate testCert = new SelfSignedCertificate("test.com");
Http11SslContextSpec testSslContextBuilder =

AtomicReference<String> hostname = new AtomicReference<>();
HttpServer server = serverCustomizer.apply(createServer());

boolean isH2 = (server.configuration()._protocols & HttpServerConfig.h2) == HttpServerConfig.h2;
SslProvider.ProtocolSslContextSpec defaultSslContextBuilder = isH2 ?
Http2SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()) :
Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey());
SslProvider.ProtocolSslContextSpec testSslContextBuilder = isH2 ?
Http2SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()) :
Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey());
SslProvider testSslProvider = SslProvider.builder().sslContext(testSslContextBuilder).build();

Http11SslContextSpec clientSslContextBuilder =
Http11SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

AtomicReference<String> hostname = new AtomicReference<>();
disposableServer =
createServer()
.secure(spec -> spec.sslContext(defaultSslContextBuilder)
server.secure(spec -> spec.sslContext(defaultSslContextBuilder)
.setSniAsyncMappings((input, promise) -> promise.setSuccess(testSslProvider)))
.doOnChannelInit((obs, channel, remoteAddress) ->
channel.pipeline()
.addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
hostname.set(((SniCompletionEvent) evt).hostname());
}
ctx.fireUserEventTriggered(evt);
}
}))
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();
.doOnChannelInit((obs, channel, remoteAddress) ->
channel.pipeline()
.addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
hostname.set(((SniCompletionEvent) evt).hostname());
}
ctx.fireUserEventTriggered(evt);
}
}))
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();

createClient(disposableServer::address)
SslProvider.ProtocolSslContextSpec clientSslContextBuilder = isH2 ?
Http2SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)) :
Http11SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

clientCustomizer.apply(createClient(disposableServer::address))
.secure(spec -> spec.sslContext(clientSslContextBuilder)
.serverNames(new SNIHostName("test.com")))
.get()
Expand Down