Skip to content

Commit

Permalink
Make HostHeaderHttpRequesterFilter public (#2212)
Browse files Browse the repository at this point in the history
Motivation:

Users may have a need to set the `Host` header with a different value
than they used to create a client. For these cases, it's convenient to
reuse existing filter instead of writing their own.

Modifications:

- Move `HostHeaderHttpRequesterFilter` from `http-netty` to `http-utils`
and make it `public`;
- Disable
`DefaultSingleAddressHttpClientBuilder.addHostHeaderFallbackFilter` when
`HostHeaderHttpRequesterFilter` is added manually;
- Move `NetUtils` from `http-api` to `utils-internal` module to share
with other modules outside `http-api`, rename it to `NetworkUtils`;
- Add `NetworkUtilsTests`;
- Replace all usages of `io.netty.util.NetUtil.isValidIpV*Address` with
`io.servicetalk.utils.internal.NetworkUtils.isValidIpV*Address`;

Result:

Users can use `HostHeaderHttpRequesterFilter` to set the `Host` header
and move it anywhere in the filter chain.
  • Loading branch information
idelpivnitskiy authored May 20, 2022
1 parent c0b80c0 commit 49e3525
Show file tree
Hide file tree
Showing 10 changed files with 498 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@
import static io.servicetalk.http.api.HttpHeaderNames.TRANSFER_ENCODING;
import static io.servicetalk.http.api.HttpHeaderNames.VARY;
import static io.servicetalk.http.api.HttpHeaderValues.CHUNKED;
import static io.servicetalk.http.api.NetUtils.isValidIpV4Address;
import static io.servicetalk.http.api.NetUtils.isValidIpV6Address;
import static io.servicetalk.http.api.UriUtils.TCHAR_HMASK;
import static io.servicetalk.http.api.UriUtils.TCHAR_LMASK;
import static io.servicetalk.http.api.UriUtils.isBitSet;
import static io.servicetalk.utils.internal.CharsetUtils.standardCharsets;
import static io.servicetalk.utils.internal.NetworkUtils.isValidIpV4Address;
import static io.servicetalk.utils.internal.NetworkUtils.isValidIpV6Address;
import static java.lang.Math.min;
import static java.lang.System.lineSeparator;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ SingleAddressHttpClientBuilder<U, R> enableWireLogging(String loggerName,
SingleAddressHttpClientBuilder<U, R> protocols(HttpProtocolConfig... protocols);

/**
* Configures automatically setting {@code Host} headers by inferring from the address or {@link HttpMetaData}.
* Configures automatically setting {@code Host} headers by inferring from the address.
* <p>
* When {@code false} is passed, this setting disables the default filter such that no {@code Host} header will be
* manipulated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import io.servicetalk.http.api.StreamingHttpConnectionFilterFactory;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpRequestResponseFactory;
import io.servicetalk.http.utils.HostHeaderHttpRequesterFilter;
import io.servicetalk.logging.api.LogLevel;
import io.servicetalk.transport.api.ClientSslConfig;
import io.servicetalk.transport.api.ExecutionStrategy;
Expand Down Expand Up @@ -448,13 +449,22 @@ public DefaultSingleAddressHttpClientBuilder<U, R> appendConnectionFilter(
requireNonNull(factory);
connectionFilterFactory = appendConnectionFilter(connectionFilterFactory, factory);
strategyComputation.add(factory);
ifHostHeaderHttpRequesterFilter(factory);
return this;
}

@Override
public DefaultSingleAddressHttpClientBuilder<U, R> appendConnectionFilter(
final Predicate<StreamingHttpRequest> predicate, final StreamingHttpConnectionFilterFactory factory) {
return appendConnectionFilter(toConditionalConnectionFilterFactory(predicate, factory));
appendConnectionFilter(toConditionalConnectionFilterFactory(predicate, factory));
ifHostHeaderHttpRequesterFilter(factory);
return this;
}

private void ifHostHeaderHttpRequesterFilter(final Object filter) {
if (filter instanceof HostHeaderHttpRequesterFilter) {
addHostHeaderFallbackFilter = false;
}
}

// Use another method to keep final references and avoid StackOverflowError
Expand Down Expand Up @@ -491,7 +501,9 @@ public DefaultSingleAddressHttpClientBuilder<U, R> appendClientFilter(
ensureSingleRetryFilter();
retryingHttpRequesterFilter = (RetryingHttpRequesterFilter) factory;
}
return appendClientFilter(toConditionalClientFilterFactory(predicate, factory));
appendClientFilter(toConditionalClientFilterFactory(predicate, factory));
ifHostHeaderHttpRequesterFilter(factory);
return this;
}

private void ensureSingleRetryFilter() {
Expand All @@ -518,6 +530,7 @@ public DefaultSingleAddressHttpClientBuilder<U, R> appendClientFilter(
}
clientFilterFactory = appendFilter(clientFilterFactory, factory);
strategyComputation.add(factory);
ifHostHeaderHttpRequesterFilter(factory);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import java.util.List;
import javax.annotation.Nullable;

import static io.netty.util.NetUtil.isValidIpV4Address;
import static io.netty.util.NetUtil.isValidIpV6Address;
import static io.servicetalk.http.netty.HttpServerConfig.httpAlpnProtocols;
import static io.servicetalk.utils.internal.NetworkUtils.isValidIpV4Address;
import static io.servicetalk.utils.internal.NetworkUtils.isValidIpV6Address;

final class HttpClientConfig {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.servicetalk.http.api.HttpRequest;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.api.ReservedBlockingHttpConnection;
import io.servicetalk.http.utils.HostHeaderHttpRequesterFilter;
import io.servicetalk.transport.api.ServerContext;

import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -128,7 +129,6 @@ void clientBuilderAppendClientFilter(HttpVersionConfig httpVersionConfig) throws
try (ServerContext context = buildServer();
BlockingHttpClient client = forSingleAddress(serverHostAndPort(context))
.protocols(httpVersionConfig.config())
.hostHeaderFallback(false) // turn off the default
.appendClientFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking()) {
assertResponse(client, null, "foo.bar:-1");
Expand All @@ -142,7 +142,6 @@ void clientBuilderAppendConnectionFilter(HttpVersionConfig httpVersionConfig) th
try (ServerContext context = buildServer();
BlockingHttpClient client = forSingleAddress(serverHostAndPort(context))
.protocols(httpVersionConfig.config())
.hostHeaderFallback(false) // turn off the default
.appendConnectionFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking()) {
assertResponse(client, null, "foo.bar:-1");
Expand All @@ -156,7 +155,6 @@ void reserveConnection(HttpVersionConfig httpVersionConfig) throws Exception {
try (ServerContext context = buildServer();
BlockingHttpClient client = HttpClients.forResolvedAddress(serverHostAndPort(context))
.protocols(httpVersionConfig.config())
.hostHeaderFallback(false) // turn off the default
.appendConnectionFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking();
ReservedBlockingHttpConnection conn = client.reserveConnection(client.get("/"))) {
Expand All @@ -172,7 +170,6 @@ void clientBuilderAppendClientFilterExplicitHostHeader(HttpVersionConfig httpVer
try (ServerContext context = buildServer();
BlockingHttpClient client = forSingleAddress(serverHostAndPort(context))
.protocols(httpVersionConfig.config())
.hostHeaderFallback(false) // turn off the default
.appendClientFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking()) {
assertResponse(client, "bar.only:-1", "bar.only:-1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.servicetalk.http.api.HttpRequest;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.netty.ConditionalFilterFactory.FilterFactory;
import io.servicetalk.http.utils.HostHeaderHttpRequesterFilter;
import io.servicetalk.http.utils.RedirectingHttpRequesterFilter;
import io.servicetalk.transport.api.HostAndPort;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2018-2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018-2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,14 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.netty;
package io.servicetalk.http.utils;

import io.servicetalk.concurrent.api.Single;
import io.servicetalk.http.api.FilterableStreamingHttpClient;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpExecutionStrategies;
import io.servicetalk.http.api.HttpExecutionStrategy;
import io.servicetalk.http.api.HttpHeaderNames;
import io.servicetalk.http.api.HttpRequestMetaData;
import io.servicetalk.http.api.StreamingHttpClientFilter;
import io.servicetalk.http.api.StreamingHttpClientFilterFactory;
import io.servicetalk.http.api.StreamingHttpConnectionFilter;
Expand All @@ -29,24 +30,26 @@
import io.servicetalk.http.api.StreamingHttpRequester;
import io.servicetalk.http.api.StreamingHttpResponse;

import static io.netty.util.NetUtil.isValidIpV6Address;
import static io.servicetalk.buffer.api.CharSequences.newAsciiString;
import static io.servicetalk.concurrent.api.Single.defer;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_0;
import static io.servicetalk.utils.internal.NetworkUtils.isValidIpV6Address;

/**
* A filter which will apply a fallback value for the {@link HttpHeaderNames#HOST} header if one is not present.
* A filter which will set a {@link HttpHeaderNames#HOST} header with the fallback value if the header is not already
* present in {@link HttpRequestMetaData}.
*/
final class HostHeaderHttpRequesterFilter implements StreamingHttpClientFilterFactory,
StreamingHttpConnectionFilterFactory {
public final class HostHeaderHttpRequesterFilter implements StreamingHttpClientFilterFactory,
StreamingHttpConnectionFilterFactory {
private final CharSequence fallbackHost;

/**
* Create a new instance.
*
* @param fallbackHost The address to use as a fallback if a {@link HttpHeaderNames#HOST} header is not present.
*/
HostHeaderHttpRequesterFilter(CharSequence fallbackHost) {
public HostHeaderHttpRequesterFilter(CharSequence fallbackHost) {
this.fallbackHost = newAsciiString(isValidIpV6Address(fallbackHost) && fallbackHost.charAt(0) != '[' ?
"[" + fallbackHost + "]" : fallbackHost.toString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import io.netty.incubator.channel.uring.IOUringDatagramChannel;
import io.netty.incubator.channel.uring.IOUringServerSocketChannel;
import io.netty.incubator.channel.uring.IOUringSocketChannel;
import io.netty.util.NetUtil;

import java.io.Closeable;
import java.io.IOException;
Expand All @@ -51,6 +50,7 @@
import javax.annotation.Nullable;

import static io.netty.util.NetUtil.createByteArrayFromIpAddressString;
import static io.netty.util.NetUtil.toAddressString;
import static io.servicetalk.transport.netty.internal.NativeTransportUtils.useEpoll;
import static io.servicetalk.transport.netty.internal.NativeTransportUtils.useIoUring;
import static io.servicetalk.transport.netty.internal.NativeTransportUtils.useKQueue;
Expand Down Expand Up @@ -203,9 +203,9 @@ public static String formatCanonicalAddress(SocketAddress address) {
if (inetAddress == null) {
return address.toString();
} else if (inetAddress instanceof Inet6Address) {
return '[' + NetUtil.toAddressString(inetAddress) + "]:" + inetSocketAddress.getPort();
return '[' + toAddressString(inetAddress) + "]:" + inetSocketAddress.getPort();
} else {
return NetUtil.toAddressString(inetAddress) + ':' + inetSocketAddress.getPort();
return toAddressString(inetAddress) + ':' + inetSocketAddress.getPort();
}
}
return address.toString();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2018, 2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018, 2021-2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,29 +28,37 @@
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.servicetalk.http.api;
package io.servicetalk.utils.internal;

import static io.servicetalk.buffer.api.CharSequences.indexOf;

final class NetUtils {
/**
* Network-related utilities.
* <p>
* This class borrowed some of its methods from
* <a href="https://github.com/netty/netty/blob/4.1/common/src/main/java/io/netty/util/NetUtil.java">NetUtil</a> class
* which was part of Netty.
*/
public final class NetworkUtils {

private NetUtils() {
private NetworkUtils() {
// no instances
}

/**
* Takes a string and parses it to see if it is a valid IPV4 address.
*
* @param ip the IP-address to validate
* @return true, if the string represents an IPV4 address in dotted notation, false otherwise.
*/
static boolean isValidIpV4Address(final CharSequence ip) {
public static boolean isValidIpV4Address(final CharSequence ip) {
return isValidIpV4Address(ip, 0, ip.length());
}

private static boolean isValidIpV4Address(final CharSequence ip, int from, int toExclusive) {
int len = toExclusive - from;
int i;
return len <= 15 && len > 7 &&
return len <= 15 && len >= 7 &&
(i = indexOf(ip, '.', from + 1)) > 0 && isValidIpV4Word(ip, from, i) &&
(i = indexOf(ip, '.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) &&
(i = indexOf(ip, '.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) &&
Expand All @@ -60,9 +68,10 @@ private static boolean isValidIpV4Address(final CharSequence ip, int from, int t
/**
* Takes a string and parses it to see if it is a valid IPV6 address.
*
* @param ip the IP-address to validate
* @return true, if the string represents an IPV6 address
*/
static boolean isValidIpV6Address(final CharSequence ip) {
public static boolean isValidIpV6Address(final CharSequence ip) {
int end = ip.length();
if (end < 2) {
return false;
Expand Down
Loading

0 comments on commit 49e3525

Please sign in to comment.