diff --git a/reactor-netty-core/src/main/java/reactor/netty/channel/MicrometerChannelMetricsHandler.java b/reactor-netty-core/src/main/java/reactor/netty/channel/MicrometerChannelMetricsHandler.java index ddc41fa33a..ac505b6aa7 100644 --- a/reactor-netty-core/src/main/java/reactor/netty/channel/MicrometerChannelMetricsHandler.java +++ b/reactor-netty-core/src/main/java/reactor/netty/channel/MicrometerChannelMetricsHandler.java @@ -23,6 +23,7 @@ import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelOutboundHandler; import io.netty.channel.ChannelPromise; +import io.netty.handler.ssl.SniCompletionEvent; import io.netty.handler.ssl.SslHandler; import reactor.netty.ReactorNetty; import reactor.netty.observability.ReactorNettyHandlerContext; @@ -238,6 +239,8 @@ static final class TlsMetricsHandler extends Observation.Context final MicrometerChannelMetricsRecorder recorder; final SocketAddress remoteAddress; final String type; + + boolean listenerAdded; Observation observation; // remote address and status are not known beforehand @@ -257,31 +260,7 @@ static final class TlsMetricsHandler extends Observation.Context @Override @SuppressWarnings("try") public void channelActive(ChannelHandlerContext ctx) { - SocketAddress rAddr = remoteAddress != null ? remoteAddress : ctx.channel().remoteAddress(); - if (rAddr instanceof InetSocketAddress) { - InetSocketAddress address = (InetSocketAddress) rAddr; - this.netPeerName = address.getHostString(); - this.netPeerPort = address.getPort() + ""; - } - else { - this.netPeerName = rAddr.toString(); - this.netPeerPort = ""; - } - observation = Observation.createNotStarted(recorder.name() + TLS_HANDSHAKE_TIME, this, OBSERVATION_REGISTRY); - parentContextView = updateChannelContext(ctx.channel(), observation); - observation.start(); - ctx.pipeline() - .get(SslHandler.class) - .handshakeFuture() - .addListener(f -> { - ctx.pipeline().remove(this); - status = f.isSuccess() ? SUCCESS : ERROR; - observation.stop(); - - ReactorNetty.setChannelContext(ctx.channel(), parentContextView); - parentContextView = null; - }); - + addListener(ctx); ctx.fireChannelActive(); } @@ -361,6 +340,9 @@ public void handlerRemoved(ChannelHandlerContext ctx) { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + addListener(ctx); + } ctx.fireUserEventTriggered(evt); } @@ -368,5 +350,36 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { public Timer getTimer() { return recorder.getTlsHandshakeTimer(getName(), netPeerName + ':' + netPeerPort, proxyAddress == null ? NA : proxyAddress, status); } + + private void addListener(ChannelHandlerContext ctx) { + if (!listenerAdded) { + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler != null) { + listenerAdded = true; + SocketAddress rAddr = remoteAddress != null ? remoteAddress : ctx.channel().remoteAddress(); + if (rAddr instanceof InetSocketAddress) { + InetSocketAddress address = (InetSocketAddress) rAddr; + this.netPeerName = address.getHostString(); + this.netPeerPort = address.getPort() + ""; + } + else { + this.netPeerName = rAddr.toString(); + this.netPeerPort = ""; + } + observation = Observation.createNotStarted(recorder.name() + TLS_HANDSHAKE_TIME, this, OBSERVATION_REGISTRY); + parentContextView = updateChannelContext(ctx.channel(), observation); + observation.start(); + sslHandler.handshakeFuture() + .addListener(f -> { + ctx.pipeline().remove(this); + status = f.isSuccess() ? SUCCESS : ERROR; + observation.stop(); + + ReactorNetty.setChannelContext(ctx.channel(), parentContextView); + parentContextView = null; + }); + } + } + } } } diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerConfig.java b/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerConfig.java index 11f59cef73..1db18cf1b7 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerConfig.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerConfig.java @@ -45,6 +45,7 @@ import io.netty.handler.codec.http2.Http2StreamFrameToHttpObjectCodec; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.AbstractSniHandler; import io.netty.handler.ssl.ApplicationProtocolNames; import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; import io.netty.handler.timeout.ReadTimeoutHandler; @@ -1144,9 +1145,14 @@ static final class H2OrHttp11Codec extends ApplicationProtocolNegotiationHandler final ChannelOperations.OnSetup opsFactory; final Duration readTimeout; final Duration requestTimeout; + final boolean supportOnlyHttp2; final Function uriTagValue; H2OrHttp11Codec(HttpServerChannelInitializer initializer, ConnectionObserver listener) { + this(initializer, listener, false); + } + + H2OrHttp11Codec(HttpServerChannelInitializer initializer, ConnectionObserver listener, boolean supportOnlyHttp2) { super(ApplicationProtocolNames.HTTP_1_1); this.accessLogEnabled = initializer.accessLogEnabled; this.accessLog = initializer.accessLog; @@ -1169,6 +1175,7 @@ static final class H2OrHttp11Codec extends ApplicationProtocolNegotiationHandler this.opsFactory = initializer.opsFactory; this.readTimeout = initializer.readTimeout; this.requestTimeout = initializer.requestTimeout; + this.supportOnlyHttp2 = supportOnlyHttp2; this.uriTagValue = initializer.uriTagValue; } @@ -1188,7 +1195,7 @@ protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { return; } - if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) { + if (!supportOnlyHttp2 && ApplicationProtocolNames.HTTP_1_1.equals(protocol)) { configureHttp11Pipeline(p, accessLogEnabled, accessLog, compressPredicate, cookieDecoder, cookieEncoder, true, decoder, formDecoderProvider, forwardedHeaderHandler, httpMessageLogFactory, idleTimeout, listener, mapHandle, maxKeepAliveRequests, methodTagValue, metricsRecorder, minCompressionSize, readTimeout, requestTimeout, uriTagValue); @@ -1308,29 +1315,38 @@ else if ((protocols & h11) == h11) { uriTagValue); } else if ((protocols & h2) == h2) { - configureH2Pipeline( - channel.pipeline(), - accessLogEnabled, - accessLog, - compressPredicate(compressPredicate, minCompressionSize), - cookieDecoder, - cookieEncoder, - enableGracefulShutdown, - formDecoderProvider, - forwardedHeaderHandler, - http2SettingsSpec, - httpMessageLogFactory, - idleTimeout, - observer, - mapHandle, - methodTagValue, - metricsRecorder, - minCompressionSize, - opsFactory, - readTimeout, - requestTimeout, - uriTagValue, - decoder.validateHeaders()); + ChannelHandler sslHandler = channel.pipeline().get(NettyPipeline.SslHandler); + if (sslHandler instanceof AbstractSniHandler) { + channel.pipeline() + .addBefore(NettyPipeline.ReactiveBridge, + NettyPipeline.H2OrHttp11Codec, + new H2OrHttp11Codec(this, observer, true)); + } + else { + configureH2Pipeline( + channel.pipeline(), + accessLogEnabled, + accessLog, + compressPredicate(compressPredicate, minCompressionSize), + cookieDecoder, + cookieEncoder, + enableGracefulShutdown, + formDecoderProvider, + forwardedHeaderHandler, + http2SettingsSpec, + httpMessageLogFactory, + idleTimeout, + observer, + mapHandle, + methodTagValue, + metricsRecorder, + minCompressionSize, + opsFactory, + readTimeout, + requestTimeout, + uriTagValue, + decoder.validateHeaders()); + } } } else { diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java index a391896fc7..d8b6717877 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java @@ -2142,12 +2142,18 @@ void testHang() { } @Test - void testSniSupport() throws Exception { + void testSniSupportHttp11() throws Exception { doTestSniSupport(Function.identity(), Function.identity()); } + @ParameterizedTest + @MethodSource("h2CompatibleCombinations") + void testSniSupportHttp2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception { + doTestSniSupport(server -> server.protocol(serverProtocols), client -> client.protocol(clientProtocols)); + } + @Test - void testIssue3022() throws Exception { + void testIssue3022Http11() throws Exception { TestHttpClientMetricsRecorder clientMetricsRecorder = new TestHttpClientMetricsRecorder(); TestHttpServerMetricsRecorder serverMetricsRecorder = new TestHttpServerMetricsRecorder(); doTestSniSupport(server -> server.metrics(true, () -> serverMetricsRecorder, Function.identity()), @@ -2156,38 +2162,66 @@ void testIssue3022() throws Exception { assertThat(serverMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO); } + @ParameterizedTest + @MethodSource("h2CompatibleCombinations") + void testIssue3022Http2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception { + TestHttpClientMetricsRecorder clientMetricsRecorder = new TestHttpClientMetricsRecorder(); + TestHttpServerMetricsRecorder serverMetricsRecorder = new TestHttpServerMetricsRecorder(); + doTestSniSupport(server -> server.protocol(serverProtocols).metrics(true, () -> serverMetricsRecorder, Function.identity()), + client -> client.protocol(clientProtocols).metrics(true, () -> clientMetricsRecorder, Function.identity())); + assertThat(clientMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO); + assertThat(serverMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO); + } + + @Test + void testIssue3473Http11() throws Exception { + doTestSniSupport(server -> server.metrics(true, Function.identity()), + client -> client.metrics(true, Function.identity())); + } + + @ParameterizedTest + @MethodSource("h2CompatibleCombinations") + void testIssue3473Http2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception { + doTestSniSupport(server -> server.protocol(serverProtocols).metrics(true, Function.identity()), + client -> client.protocol(clientProtocols).metrics(true, Function.identity())); + } + private void doTestSniSupport(Function serverCustomizer, Function clientCustomizer) throws Exception { SelfSignedCertificate defaultCert = new SelfSignedCertificate("default"); - Http11SslContextSpec defaultSslContextBuilder = - Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()); - SelfSignedCertificate testCert = new SelfSignedCertificate("test.com"); - Http11SslContextSpec testSslContextBuilder = - Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()); - - Http11SslContextSpec clientSslContextBuilder = - Http11SslContextSpec.forClient() - .configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)); AtomicReference hostname = new AtomicReference<>(); + HttpServer server = serverCustomizer.apply(createServer()); + + boolean isH2 = (server.configuration()._protocols & HttpServerConfig.h2) == HttpServerConfig.h2; + SslProvider.ProtocolSslContextSpec defaultSslContextBuilder = isH2 ? + Http2SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()) : + Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()); + SslProvider.ProtocolSslContextSpec testSslContextBuilder = isH2 ? + Http2SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()) : + Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()); + disposableServer = - serverCustomizer.apply(createServer()) - .secure(spec -> spec.sslContext(defaultSslContextBuilder) - .addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(testSslContextBuilder))) - .doOnChannelInit((obs, channel, remoteAddress) -> - channel.pipeline() - .addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - if (evt instanceof SniCompletionEvent) { - hostname.set(((SniCompletionEvent) evt).hostname()); - } - ctx.fireUserEventTriggered(evt); + server.secure(spec -> spec.sslContext(defaultSslContextBuilder) + .addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(testSslContextBuilder))) + .doOnChannelInit((obs, channel, remoteAddress) -> + channel.pipeline() + .addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + hostname.set(((SniCompletionEvent) evt).hostname()); } - })) - .handle((req, res) -> res.sendString(Mono.just("testSniSupport"))) - .bindNow(); + ctx.fireUserEventTriggered(evt); + } + })) + .handle((req, res) -> res.sendString(Mono.just("testSniSupport"))) + .bindNow(); + + SslProvider.ProtocolSslContextSpec clientSslContextBuilder = isH2 ? + Http2SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)) : + Http11SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)); clientCustomizer.apply(createClient(disposableServer::address)) .secure(spec -> spec.sslContext(clientSslContextBuilder) @@ -2203,40 +2237,55 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { } @Test - void testSniSupportAsyncMappings() throws Exception { - SelfSignedCertificate defaultCert = new SelfSignedCertificate("default"); - Http11SslContextSpec defaultSslContextBuilder = - Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()); + void testSniSupportAsyncMappingsHttp11() throws Exception { + doTestSniSupportAsyncMappings(Function.identity(), Function.identity()); + } + @ParameterizedTest + @MethodSource("h2CompatibleCombinations") + void testSniSupportAsyncMappingsHttp2(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols) throws Exception { + doTestSniSupportAsyncMappings(server -> server.protocol(serverProtocols), client -> client.protocol(clientProtocols)); + } + + private void doTestSniSupportAsyncMappings(Function serverCustomizer, + Function clientCustomizer) throws Exception { + SelfSignedCertificate defaultCert = new SelfSignedCertificate("default"); SelfSignedCertificate testCert = new SelfSignedCertificate("test.com"); - Http11SslContextSpec testSslContextBuilder = + + AtomicReference hostname = new AtomicReference<>(); + HttpServer server = serverCustomizer.apply(createServer()); + + boolean isH2 = (server.configuration()._protocols & HttpServerConfig.h2) == HttpServerConfig.h2; + SslProvider.ProtocolSslContextSpec defaultSslContextBuilder = isH2 ? + Http2SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()) : + Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()); + SslProvider.ProtocolSslContextSpec testSslContextBuilder = isH2 ? + Http2SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()) : Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()); SslProvider testSslProvider = SslProvider.builder().sslContext(testSslContextBuilder).build(); - Http11SslContextSpec clientSslContextBuilder = - Http11SslContextSpec.forClient() - .configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)); - - AtomicReference hostname = new AtomicReference<>(); disposableServer = - createServer() - .secure(spec -> spec.sslContext(defaultSslContextBuilder) + server.secure(spec -> spec.sslContext(defaultSslContextBuilder) .setSniAsyncMappings((input, promise) -> promise.setSuccess(testSslProvider))) - .doOnChannelInit((obs, channel, remoteAddress) -> - channel.pipeline() - .addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - if (evt instanceof SniCompletionEvent) { - hostname.set(((SniCompletionEvent) evt).hostname()); - } - ctx.fireUserEventTriggered(evt); - } - })) - .handle((req, res) -> res.sendString(Mono.just("testSniSupport"))) - .bindNow(); + .doOnChannelInit((obs, channel, remoteAddress) -> + channel.pipeline() + .addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + hostname.set(((SniCompletionEvent) evt).hostname()); + } + ctx.fireUserEventTriggered(evt); + } + })) + .handle((req, res) -> res.sendString(Mono.just("testSniSupport"))) + .bindNow(); - createClient(disposableServer::address) + SslProvider.ProtocolSslContextSpec clientSslContextBuilder = isH2 ? + Http2SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)) : + Http11SslContextSpec.forClient().configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)); + + clientCustomizer.apply(createClient(disposableServer::address)) .secure(spec -> spec.sslContext(clientSslContextBuilder) .serverNames(new SNIHostName("test.com"))) .get()