diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index f0a89c9d9116c..3fe69b1bd8851 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -22,6 +22,7 @@ import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -106,6 +107,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) { .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) .addLast("decoder", decoder) + .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. .addLast("handler", channelHandler); @@ -126,7 +128,8 @@ private TransportChannelHandler createChannelHandler(Channel channel) { TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); - return new TransportChannelHandler(client, responseHandler, requestHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler, + conf.connectionTimeoutMs()); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2044afb0d85db..94fc21af5e606 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -20,8 +20,8 @@ import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; -import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; + /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ + private final AtomicLong timeOfLastRequestNs; + public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); + this.timeOfLastRequestNs = new AtomicLong(0); } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingFetches.put(streamChunkId, callback); } @@ -65,6 +70,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingRpcs.put(requestId, callback); } @@ -161,8 +167,12 @@ public void handle(ResponseMessage message) { } /** Returns total number of outstanding requests (fetch requests + rpcs) */ - @VisibleForTesting public int numOutstandingRequests() { return outstandingFetches.size() + outstandingRpcs.size(); } + + /** Returns the time in nanoseconds of when the last request was sent out. */ + public long getTimeOfLastRequestNs() { + return timeOfLastRequestNs.get(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index e491367fa4528..8e0ee709e38e3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -19,6 +19,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,6 +42,11 @@ * Client. * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. + * + * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. + * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic + * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not + * timeout if the client is continuously sending but getting no responses, for simplicity. */ public class TransportChannelHandler extends SimpleChannelInboundHandler { private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); @@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java new file mode 100644 index 0000000000000..668d2356b955d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import com.google.common.collect.Maps; + +import java.util.Map; +import java.util.NoSuchElementException; + +/** ConfigProvider based on a Map (copied in the constructor). */ +public class MapConfigProvider extends ConfigProvider { + private final Map config; + + public MapConfigProvider(Map config) { + this.config = Maps.newHashMap(config); + } + + @Override + public String get(String name) { + String value = config.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index dabd6261d2aa0..26c6399ce7dbc 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -98,7 +98,7 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } - /** Returns the remote address on the channel or "<remote address>" if none exists. */ + /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ public static String getRemoteAddress(Channel channel) { if (channel != null && channel.remoteAddress() != null) { return channel.remoteAddress().toString(); diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java new file mode 100644 index 0000000000000..84ebb337e6d54 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import com.google.common.collect.Maps; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.*; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Suite which ensures that requests that go without a response for the network timeout period are + * failed, and the connection closed. + * + * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests, + * to ensure stability in different test environments. + */ +public class RequestTimeoutIntegrationSuite { + + private TransportServer server; + private TransportClientFactory clientFactory; + + private StreamManager defaultManager; + private TransportConf conf; + + // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. + private final int FOREVER = 60 * 1000; + + @Before + public void setUp() throws Exception { + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.connectionTimeout", "2s"); + conf = new TransportConf(new MapConfigProvider(configMap)); + + defaultManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + }; + } + + @After + public void tearDown() { + if (server != null) { + server.close(); + } + if (clientFactory != null) { + clientFactory.close(); + } + } + + // Basic suite: First request completes quickly, and second waits for longer than network timeout. + @Test + public void timeoutInactiveRequests() throws Exception { + final Semaphore semaphore = new Semaphore(1); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // First completes quickly (semaphore starts at 1). + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.success.length == response.length); + } + + // Second times out after 2 seconds, with slack. Must be IOException. + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client.sendRpc(new byte[0], callback1); + callback1.wait(4 * 1000); + assert (callback1.failure != null); + assert (callback1.failure instanceof IOException); + } + semaphore.release(); + } + + // A timeout will cause the connection to be closed, invalidating the current TransportClient. + // It should be the case that requesting a client from the factory produces a new, valid one. + @Test + public void timeoutCleanlyClosesClient() throws Exception { + final Semaphore semaphore = new Semaphore(0); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + + // First request should eventually fail. + TransportClient client0 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client0.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.failure instanceof IOException); + assert (!client0.isActive()); + } + + // Increment the semaphore and the second request should succeed quickly. + semaphore.release(2); + TransportClient client1 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client1.sendRpc(new byte[0], callback1); + callback1.wait(FOREVER); + assert (callback1.success.length == response.length); + assert (callback1.failure == null); + } + } + + // The timeout is relative to the LAST request sent, which is kinda weird, but still. + // This test also makes sure the timeout works for Fetch requests as well as RPCs. + @Test + public void furtherRequestsDelay() throws Exception { + final byte[] response = new byte[16]; + final StreamManager manager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); + return new NioManagedBuffer(ByteBuffer.wrap(response)); + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return manager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // Send one request, which will eventually fail. + TestCallback callback0 = new TestCallback(); + client.fetchChunk(0, 0, callback0); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + // Send a second request before the first has failed. + TestCallback callback1 = new TestCallback(); + client.fetchChunk(0, 1, callback1); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + synchronized (callback0) { + // not complete yet, but should complete soon + assert (callback0.success == null && callback0.failure == null); + callback0.wait(2 * 1000); + assert (callback0.failure instanceof IOException); + } + + synchronized (callback1) { + // failed at same time as previous + assert (callback0.failure instanceof IOException); + } + } + + /** + * Callback which sets 'success' or 'failure' on completion. + * Additionally notifies all waiters on this callback when invoked. + */ + class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { + + byte[] success; + Throwable failure; + + @Override + public void onSuccess(byte[] response) { + synchronized(this) { + success = response; + this.notifyAll(); + } + } + + @Override + public void onFailure(Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + synchronized(this) { + try { + success = buffer.nioByteBuffer().array(); + this.notifyAll(); + } catch (IOException e) { + // weird + } + } + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 416dc1b969fa4..35de5e57ccb98 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -20,10 +20,11 @@ import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.NoSuchElementException; +import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import com.google.common.collect.Maps; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -36,9 +37,9 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -70,16 +71,10 @@ public void tearDown() { */ private void testClientReuse(final int maxConnections, boolean concurrent) throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { - @Override - public String get(String name) { - if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) { - return Integer.toString(maxConnections); - } else { - throw new NoSuchElementException(); - } - } - }); + + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); + TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler);