Skip to content

Commit

Permalink
Disable cache by default for DNS ServiceDiscoverer
Browse files Browse the repository at this point in the history
Motivation:

When we poll DNS in the background the cache is not important because
we always schedule the next query after cache expires. It only helps
with concurrent resolutions of the same address, which in practice
should be minimal if users properly reuse clients. Disadvantage of
having a cache is that if it gets poisoned with invalid or stale entries
there is no way to clear the cache. Cancelling the events publisher and
re-subscribing doesn't help because re-subscribe always hits the cache.
See apple#2514.

Modifications:

- Disable cache by default;
- Provide API to opt-in for caching (can be useful if users resolve per
new connection instead of polling in the background);
- Deprecate `minTTL`, remove `maxTTL` builder methods, introduce
`ttl(min, max, cache)` instead;
- Invoke `ttlCache.prepareForResolution(name)` only for scheduled queries,
keep it unchanged when cancel/re-subscribe to correctly offset ttl;
- Ignore empty list inside `MinTtlCache.get(...)`;
- Make `DefaultDnsClient` logging more consistent;
- Enhance testing;

Result:

Caching is disabled by default. Polling is driven by TTL. In case of
re-subscribe, we always send a new query. Users have API to configure
min/max polling intervals and caching.
  • Loading branch information
idelpivnitskiy committed Feb 22, 2023
1 parent 229e2d1 commit 96ffa97
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import io.netty.resolver.dns.DefaultDnsCache;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.resolver.dns.NoopDnsCache;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
Expand Down Expand Up @@ -128,7 +129,8 @@ final class DefaultDnsClient implements DnsClient {
private final String id;
private boolean closed;

DefaultDnsClient(final IoExecutor ioExecutor, final int minTTL, final long ttlJitterNanos,
DefaultDnsClient(final String id, final IoExecutor ioExecutor,
final int minTTL, final int maxTTL, final boolean cache, final long ttlJitterNanos,
final int srvConcurrency, final boolean inactiveEventsOnError,
final boolean completeOncePreferredResolved, final boolean srvFilterDuplicateEvents,
Duration srvHostNameRepeatInitialDelay, Duration srvHostNameRepeatJitter,
Expand All @@ -137,11 +139,7 @@ final class DefaultDnsClient implements DnsClient {
final DnsResolverAddressTypes dnsResolverAddressTypes,
@Nullable final DnsServerAddressStreamProvider dnsServerAddressStreamProvider,
@Nullable final DnsServiceDiscovererObserver observer,
final ServiceDiscovererEvent.Status missingRecordStatus,
final int maxTTL, final String id) {
if (srvConcurrency <= 0) {
throw new IllegalArgumentException("srvConcurrency: " + srvConcurrency + " (expected >0)");
}
final ServiceDiscovererEvent.Status missingRecordStatus) {
this.maxTTLNanos = SECONDS.toNanos(maxTTL);
this.srvConcurrency = srvConcurrency;
this.srvFilterDuplicateEvents = srvFilterDuplicateEvents;
Expand All @@ -151,8 +149,8 @@ final class DefaultDnsClient implements DnsClient {
// We must use nettyIoExecutor for the repeater for thread safety!
srvHostNameRepeater = repeatWithConstantBackoffDeltaJitter(
srvHostNameRepeatInitialDelay, srvHostNameRepeatJitter, nettyIoExecutor);
this.ttlCache = new MinTtlCache(new DefaultDnsCache(minTTL, maxTTL, minTTL), minTTL,
nettyIoExecutor);
this.ttlCache = new MinTtlCache(cache ? new DefaultDnsCache(minTTL, maxTTL, 0) : NoopDnsCache.INSTANCE,
minTTL, nettyIoExecutor);
this.ttlJitterNanos = ttlJitterNanos;
this.observer = observer;
this.missingRecordStatus = missingRecordStatus;
Expand Down Expand Up @@ -336,7 +334,7 @@ protected AbstractDnsSubscription newSubscription(
final Subscriber<? super List<ServiceDiscovererEvent<HostAndPort>>> subscriber) {
return new AbstractDnsSubscription(subscriber) {
@Override
protected Future<DnsAnswer<HostAndPort>> doDnsQuery() {
protected Future<DnsAnswer<HostAndPort>> doDnsQuery(final boolean scheduledQuery) {
Promise<DnsAnswer<HostAndPort>> promise = nettyIoExecutor.eventLoopGroup().next().newPromise();
resolver.resolveAll(new DefaultDnsQuestion(name, SRV))
.addListener((Future<? super List<DnsRecord>> completedFuture) -> {
Expand Down Expand Up @@ -367,9 +365,9 @@ protected Future<DnsAnswer<HostAndPort>> doDnsQuery() {
final int port = content.readUnsignedShort();
hostAndPorts.add(HostAndPort.of(decodeName(content), port));
}
LOGGER.trace("{} original result for {}: {}, minTTL: {} second(s).",
LOGGER.trace("{} original result for {} (size={}, TTL={}s): {}.",
DefaultDnsClient.this, SrvRecordPublisher.this,
completedFuture.getNow(), minTTLSeconds);
toRelease.size(), minTTLSeconds, toRelease);
dnsAnswer = new DnsAnswer<>(hostAndPorts, SECONDS.toNanos(minTTLSeconds));
} catch (Throwable cause2) {
promise.tryFailure(cause2);
Expand Down Expand Up @@ -410,8 +408,10 @@ protected AbstractDnsSubscription newSubscription(
final Subscriber<? super List<ServiceDiscovererEvent<InetAddress>>> subscriber) {
return new AbstractDnsSubscription(subscriber) {
@Override
protected Future<DnsAnswer<InetAddress>> doDnsQuery() {
ttlCache.prepareForResolution(name);
protected Future<DnsAnswer<InetAddress>> doDnsQuery(final boolean scheduledQuery) {
if (scheduledQuery) {
ttlCache.prepareForResolution(name);
}
Promise<DnsAnswer<InetAddress>> dnsAnswerPromise =
nettyIoExecutor.eventLoopGroup().next().newPromise();
resolver.resolveAll(name).addListener(completedFuture -> {
Expand All @@ -420,13 +420,14 @@ protected Future<DnsAnswer<InetAddress>> doDnsQuery() {
dnsAnswerPromise.tryFailure(cause);
} else {
final DnsAnswer<InetAddress> dnsAnswer;
@SuppressWarnings("unchecked")
final List<InetAddress> original = (List<InetAddress>) completedFuture.getNow();
final long minTTLSeconds = ttlCache.minTtl(name);
LOGGER.trace("{} original result for {}: {}, minTTL: {} second(s).",
LOGGER.trace("{} original result for {} (size={}, TTL={}s): {}.",
DefaultDnsClient.this, ARecordPublisher.this,
completedFuture.getNow(), minTTLSeconds);
original.size(), minTTLSeconds, original);
try {
dnsAnswer = new DnsAnswer<>(toAddresses(completedFuture),
SECONDS.toNanos(minTTLSeconds));
dnsAnswer = new DnsAnswer<>(toAddresses(original), SECONDS.toNanos(minTTLSeconds));
} catch (Throwable cause2) {
dnsAnswerPromise.tryFailure(cause2);
return;
Expand All @@ -442,9 +443,7 @@ protected Comparator<InetAddress> comparator() {
return INET_ADDRESS_COMPARATOR;
}

private List<InetAddress> toAddresses(Future<? super List<InetAddress>> completedFuture) {
@SuppressWarnings("unchecked")
final List<InetAddress> original = (List<InetAddress>) completedFuture.getNow();
private List<InetAddress> toAddresses(final List<InetAddress> original) {
if (addressTypes == IPV4_PREFERRED || addressTypes == IPV6_PREFERRED) {
// Filter out addresses to keep only preferred if both available.
int ipv4Cnt = 0;
Expand Down Expand Up @@ -580,9 +579,10 @@ abstract class AbstractDnsSubscription implements Subscription {
/**
* Performs DNS query.
*
* @param scheduledQuery indicates when query was scheduled
* @return a {@link Future} that will be notified when {@link DnsAnswer} is available
*/
protected abstract Future<DnsAnswer<T>> doDnsQuery();
protected abstract Future<DnsAnswer<T>> doDnsQuery(boolean scheduledQuery);

/**
* Returns a {@link Comparator} for the resolved address type.
Expand Down Expand Up @@ -630,29 +630,33 @@ private void request0(final long n) {
pendingRequests = addWithOverflowProtection(pendingRequests, n);
if (cancellableForQuery == null) {
if (ttlNanos < 0) {
doQuery0();
doQuery0(false);
} else {
final long durationNs =
nettyIoExecutor.currentTime(NANOSECONDS) - resolveDoneNoScheduleTime;
if (durationNs > ttlNanos) {
doQuery0();
doQuery0(false);
} else {
scheduleQuery0(ttlNanos - durationNs);
scheduleQuery0(ttlNanos - durationNs, ttlNanos);
}
}
}
}

private void doQuery0() {
private void executeScheduledQuery0() {
doQuery0(true);
}

private void doQuery0(final boolean scheduledQuery) {
assertInEventloop();

if (closed) {
// best effort check to cleanup state after close.
handleTerminalError0(new ClosedDnsServiceDiscovererException());
} else {
final DnsResolutionObserver resolutionObserver = newResolutionObserver();
LOGGER.trace("{} querying DNS for {}", DefaultDnsClient.this, AbstractDnsPublisher.this);
final Future<DnsAnswer<T>> addressFuture = doDnsQuery();
LOGGER.trace("{} querying DNS for {}.", DefaultDnsClient.this, AbstractDnsPublisher.this);
final Future<DnsAnswer<T>> addressFuture = doDnsQuery(scheduledQuery);
cancellableForQuery = () -> addressFuture.cancel(true);
if (addressFuture.isDone()) {
handleResolveDone0(addressFuture, resolutionObserver);
Expand Down Expand Up @@ -680,6 +684,7 @@ private DnsResolutionObserver newResolutionObserver() {

private void cancel0() {
assertInEventloop();
LOGGER.debug("{} subscription for {} is cancelled.", DefaultDnsClient.this, AbstractDnsPublisher.this);
Cancellable oldCancellable = cancellableForQuery;
cancellableForQuery = TERMINATED;
if (oldCancellable != null) {
Expand All @@ -696,17 +701,21 @@ private void cancelAndTerminate0(Throwable cause) {
}

private void scheduleQuery0(final long nanos) {
scheduleQuery0(nanos, nanos);
}

private void scheduleQuery0(final long nanos, final long originalTtlNanos) {
assertInEventloop();

final long delay = ThreadLocalRandom.current()
.nextLong(nanos, addWithOverflowProtection(nanos, ttlJitterNanos));
LOGGER.debug("{} scheduling DNS query for {} after {}ms, original TTL: {}ms.",
LOGGER.debug("{} scheduling DNS query for {} after {}ms (TTL={}s, jitter={}ms).",
DefaultDnsClient.this, AbstractDnsPublisher.this, NANOSECONDS.toMillis(delay),
NANOSECONDS.toMillis(nanos));
NANOSECONDS.toSeconds(originalTtlNanos), NANOSECONDS.toMillis(ttlJitterNanos));

// This value is coming from DNS TTL for which the unit is seconds and the minimum value we accept
// in the builder is 1 second.
cancellableForQuery = nettyIoExecutor.schedule(this::doQuery0, delay, NANOSECONDS);
cancellableForQuery = nettyIoExecutor.schedule(this::executeScheduledQuery0, delay, NANOSECONDS);
}

private void handleResolveDone0(final Future<DnsAnswer<T>> addressFuture,
Expand Down Expand Up @@ -746,18 +755,19 @@ private void handleResolveDone0(final Future<DnsAnswer<T>> addressFuture,
cancellableForQuery = null;
}
try {
LOGGER.debug("{} sending events for {} (size={}, ttl={}ms) {}.",
LOGGER.debug("{} sending events for {} (size={}, TTL={}s): {}.",
DefaultDnsClient.this, AbstractDnsPublisher.this, events.size(),
NANOSECONDS.toMillis(ttlNanos), events);
NANOSECONDS.toSeconds(ttlNanos), events);

subscriber.onNext(events);
} catch (final Throwable error) {
handleTerminalError0(error);
}
} else {
LOGGER.trace("{} resolution done but no changes for {} (size={}, ttl={}ms) {}.",
LOGGER.trace("{} resolution is complete but no changes detected for {} based on result " +
"(size={}, TTL={}s) {}.",
DefaultDnsClient.this, AbstractDnsPublisher.this, activeAddresses.size(),
NANOSECONDS.toMillis(ttlNanos), activeAddresses);
NANOSECONDS.toSeconds(ttlNanos), activeAddresses);

scheduleQuery0(ttlNanos);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public final class DefaultDnsServiceDiscovererBuilder implements DnsServiceDisco
private Duration queryTimeout;
private int minTTLSeconds = 10;
private int maxTTLSeconds = (int) TimeUnit.MINUTES.toSeconds(5);
private boolean cache;
private Duration ttlJitter = ofSeconds(4);
private int srvConcurrency = 2048;
private boolean inactiveEventsOnError;
Expand All @@ -85,10 +86,20 @@ public DefaultDnsServiceDiscovererBuilder() {
}

DefaultDnsServiceDiscovererBuilder(final String id) {
this.id = requireNonNull(id);
if (id.isEmpty()) {
throw new IllegalArgumentException("id can not be empty");
}
this.id = id;
}

@Override
/**
* The minimum allowed TTL. This will be the minimum poll interval.
*
* @param minTTLSeconds The minimum amount of time a cache entry will be considered valid (in seconds).
* @return {@code this}.
* @deprecated Use {@link #ttl(int, int, boolean)}.
*/
@Deprecated
public DefaultDnsServiceDiscovererBuilder minTTL(final int minTTLSeconds) {
if (minTTLSeconds <= 0) {
throw new IllegalArgumentException("minTTLSeconds: " + minTTLSeconds + " (expected > 0)");
Expand All @@ -98,11 +109,14 @@ public DefaultDnsServiceDiscovererBuilder minTTL(final int minTTLSeconds) {
}

@Override
public DefaultDnsServiceDiscovererBuilder maxTTL(final int maxTTLSeconds) {
if (minTTLSeconds <= 0) {
throw new IllegalArgumentException("maxTTLSeconds: " + maxTTLSeconds + " (expected > 0)");
public DefaultDnsServiceDiscovererBuilder ttl(final int minSeconds, final int maxSeconds, final boolean cache) {
if (minSeconds < 0 || maxSeconds < minSeconds) {
throw new IllegalArgumentException("minSeconds: " + minSeconds + ", maxSeconds: " + maxSeconds +
" (expected: 0 <= minSeconds <= maxSeconds)");
}
this.maxTTLSeconds = maxTTLSeconds;
this.minTTLSeconds = minSeconds;
this.maxTTLSeconds = maxSeconds;
this.cache = cache;
return this;
}

Expand Down Expand Up @@ -256,18 +270,12 @@ private static DnsClientFilterFactory appendFilter(@Nullable final DnsClientFilt
* @return a new instance of {@link DnsClient}.
*/
DnsClient build() {
if (minTTLSeconds > maxTTLSeconds) {
throw new IllegalArgumentException("minTTLSeconds (" + minTTLSeconds + ") must not be larger " +
"than maxTTLSeconds (" + maxTTLSeconds + ")");
}

final DnsClient rawClient = new DefaultDnsClient(
ioExecutor == null ? globalExecutionContext().ioExecutor() : ioExecutor, minTTLSeconds,
ttlJitter.toNanos(), srvConcurrency,
final DnsClient rawClient = new DefaultDnsClient(id,
ioExecutor == null ? globalExecutionContext().ioExecutor() : ioExecutor,
minTTLSeconds, maxTTLSeconds, cache, ttlJitter.toNanos(), srvConcurrency,
inactiveEventsOnError, completeOncePreferredResolved, srvFilterDuplicateEvents,
srvHostNameRepeatInitialDelay, srvHostNameRepeatJitter, maxUdpPayloadSize, ndots, optResourceEnabled,
queryTimeout, dnsResolverAddressTypes, dnsServerAddressStreamProvider, observer, missingRecordStatus,
maxTTLSeconds, id);
queryTimeout, dnsResolverAddressTypes, dnsServerAddressStreamProvider, observer, missingRecordStatus);
return filterFactory == null ? rawClient : filterFactory.create(rawClient);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,8 @@ protected final DnsServiceDiscovererBuilder delegate() {
}

@Override
public DnsServiceDiscovererBuilder minTTL(final int minTTLSeconds) {
delegate = delegate.minTTL(minTTLSeconds);
return this;
}

@Override
public DnsServiceDiscovererBuilder maxTTL(final int maxTTLSeconds) {
delegate = delegate.maxTTL(maxTTLSeconds);
public DnsServiceDiscovererBuilder ttl(final int minSeconds, final int maxSeconds, final boolean cache) {
delegate = delegate.ttl(minSeconds, maxSeconds, cache);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,29 @@
*/
public interface DnsServiceDiscovererBuilder {
/**
* The minimum allowed TTL. This will be the minimum poll interval.
*
* @param minTTLSeconds The minimum amount of time a cache entry will be considered valid (in seconds).
* @return {@code this}.
*/
DnsServiceDiscovererBuilder minTTL(int minTTLSeconds);

/**
* The maximum allowed TTL. This will be the maximum poll interval as well as the maximum dns cache value.
* Controls min/max TTL values that will affect polling interval and caching.
* <p>
* The created {@link ServiceDiscoverer} polls DNS server based on TTL value of the resolved records. Min/max values
* help to make sure polling stays within reasonable boundaries. The 3rd argument controls if the resolved records
* should be cached or not. Cache is helpful in scenarios when multiple concurrent resolutions are possible for the
* same address: either an application runs multiple client instances for the same remote address (not recommended)
* or clients perform DNS resolutions per new connection instead of background polling.
*
* @param maxTTLSeconds the maximum amount of time a cache entry will be considered valid (in seconds).
* @param minSeconds The minimum about of time the result will be considered valid (in seconds), must be greater
* than {@code 0}.
* @param maxSeconds The maximum about of time the result will be considered valid (in seconds), must be greater
* than {@code minSeconds}.
* @param cache If {@code true}, DNS responses will be cached locally for the specified time. Any concurrent
* resolutions for the same address will hit the cache if it's not expired. Otherwise, all resolutions will generate
* a new query for DNS server.
* @return {@code this}.
*/
DnsServiceDiscovererBuilder maxTTL(int maxTTLSeconds);
DnsServiceDiscovererBuilder ttl(int minSeconds, int maxSeconds, boolean cache);

/**
* The jitter to apply to schedule the next query after TTL.
* The jitter to apply for scheduling the next query after TTL to help spread out subsequent DNS queries.
* <p>
* The jitter value will be added on top of the TTL value returned from the DNS server to help spread out
* subsequent DNS queries.
* The jitter value will be added on top of the TTL value returned from the DNS server to avoid hitting the cache.
*
* @param ttlJitter The jitter to apply to schedule the next query after TTL.
* @return {@code this}.
Expand Down
Loading

0 comments on commit 96ffa97

Please sign in to comment.