From 5683b1054f13654c2c5b58fc5856b84b2b8a8f8d Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 2 May 2018 09:55:15 -0500 Subject: [PATCH 01/18] [SPARK-6237][NETWORK] Network-layer changes to allow stream upload. These changes allow an RPCHandler to receive an upload as a stream of data, without having to buffer the entire message in the FrameDecoder. The primary use case is for replicating large blocks. Added unit tests. --- .../network/client/StreamInterceptor.java | 25 +- .../spark/network/client/TransportClient.java | 56 ++++- .../spark/network/crypto/AuthRpcHandler.java | 11 +- .../spark/network/protocol/Message.java | 3 +- .../network/protocol/MessageDecoder.java | 3 + .../spark/network/protocol/UploadStream.java | 107 ++++++++ .../spark/network/sasl/SaslRpcHandler.java | 9 +- .../spark/network/server/NoOpRpcHandler.java | 6 +- .../spark/network/server/RpcHandler.java | 15 +- .../spark/network/server/StreamData.java | 99 ++++++++ .../server/TransportRequestHandler.java | 61 +++-- .../network/ChunkFetchIntegrationSuite.java | 2 + .../RequestTimeoutIntegrationSuite.java | 4 + .../spark/network/RpcIntegrationSuite.java | 231 ++++++++++++++++-- .../org/apache/spark/network/StreamSuite.java | 93 +++---- .../spark/network/StreamTestHelper.java | 101 ++++++++ .../network/crypto/AuthIntegrationSuite.java | 6 +- .../spark/network/sasl/SparkSaslSuite.java | 8 +- .../shuffle/ExternalShuffleBlockHandler.java | 8 +- .../network/sasl/SaslIntegrationSuite.java | 12 +- .../ExternalShuffleBlockHandlerSuite.java | 8 +- .../network/netty/NettyBlockRpcServer.scala | 3 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 32 ++- .../spark/rpc/netty/NettyRpcEnvSuite.scala | 12 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 3 +- 26 files changed, 764 insertions(+), 156 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index b0e85bae7c309..e973a99323ee1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -22,22 +22,25 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.server.MessageHandler; +import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.util.TransportFrameDecoder; /** * An interceptor that is registered with the frame decoder to feed stream data to a * callback. 
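 * The interceptor reads data straight out of the channel until {@code byteCount} bytes have
 * been passed to the callback, completing or failing the stream as appropriate.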
 */
-class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+public class StreamInterceptor implements TransportFrameDecoder.Interceptor {
 
-  private final TransportResponseHandler handler;
+  private final MessageHandler handler;
   private final String streamId;
   private final long byteCount;
   private final StreamCallback callback;
   private long bytesRead;
 
-  StreamInterceptor(
-      TransportResponseHandler handler,
+  public StreamInterceptor(
+      MessageHandler handler,
       String streamId,
       long byteCount,
       StreamCallback callback) {
@@ -50,16 +53,22 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
 
   @Override
   public void exceptionCaught(Throwable cause) throws Exception {
-    handler.deactivateStream();
+    deactivateStream();
     callback.onFailure(streamId, cause);
   }
 
   @Override
   public void channelInactive() throws Exception {
-    handler.deactivateStream();
+    deactivateStream();
     callback.onFailure(streamId, new ClosedChannelException());
   }
 
+  private void deactivateStream() {
+    if (handler instanceof TransportResponseHandler) {
+      ((TransportResponseHandler) handler).deactivateStream();
+    }
+  }
+
   @Override
   public boolean handle(ByteBuf buf) throws Exception {
     int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
@@ -72,10 +81,10 @@ public boolean handle(ByteBuf buf) throws Exception {
       RuntimeException re = new IllegalStateException(String.format(
         "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
       callback.onFailure(streamId, re);
-      handler.deactivateStream();
+      deactivateStream();
       throw re;
     } else if (bytesRead == byteCount) {
-      handler.deactivateStream();
+      deactivateStream();
       callback.onComplete(streamId);
     }
 
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 8f354ad78bbaa..b54fa8bae32bb 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -32,15 +32,13 @@
 import com.google.common.base.Throwables;
 import com.google.common.util.concurrent.SettableFuture;
 import io.netty.channel.Channel;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.protocol.*;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.protocol.ChunkFetchRequest;
-import org.apache.spark.network.protocol.OneWayMessage;
-import org.apache.spark.network.protocol.RpcRequest;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.StreamRequest;
+
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
 
 /**
@@ -244,6 +242,54 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
     return requestId;
   }
 
+  /**
+   * Send data to the remote end as a stream. This differs from stream() in that this is a
+   * request to *send* data to the remote end, not to receive it from the remote end.
+   *
+   * @param meta metadata associated with the stream, which will be read completely on the
+   *             receiving end before the stream itself.
+   * @param data this will be streamed to the remote end to allow for transferring large amounts
+   *             of data without reading it all into memory.
+   * @param callback handles the reply -- onSuccess will only be called when both message and
+   *                 data are received successfully.
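+   *
+   * A caller might use it roughly like this (a sketch only; {@code conf}, {@code file} and the
+   * callback are assumed to exist, and both buffer types appear elsewhere in this patch):
+   * <pre>
+   * ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes("block_1"));
+   * ManagedBuffer data = new FileSegmentManagedBuffer(conf, file, 0, file.length());
+   * long requestId = client.uploadStream(meta, data, callback);
+   * </pre>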
+ */ + public long uploadStream( + ManagedBuffer meta, + ManagedBuffer data, + RpcResponseCallback callback) { + long startTime = System.currentTimeMillis(); + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } + + long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + handler.addRpcRequest(requestId, callback); + + channel.writeAndFlush(new UploadStream(requestId, meta, data)) + .addListener(future -> { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + if (logger.isTraceEnabled()) { + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeRpcRequest(requestId); + channel.close(); + try { + callback.onFailure(new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + }); + + return requestId; + } + /** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081bf..2c44e9fa71ee1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -33,6 +33,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.TransportConf; @@ -80,9 +81,13 @@ class AuthRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + StreamData streamData, + RpcResponseCallback callback) { if (doDelegate) { - delegate.receive(client, message, callback); + delegate.receive(client, message, streamData, callback); return; } @@ -100,7 +105,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder); message.position(position); message.limit(limit); - delegate.receive(client, message, callback); + delegate.receive(client, message, streamData, callback); doDelegate = true; } else { LOG.debug("Unexpected challenge message from client {}, closing channel.", diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..0ccd70c03aba8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,7 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - 
OneWayMessage(9), User(-1); + OneWayMessage(9), UploadStream(10), User(-1); private final byte id; @@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case 10: return UploadStream; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..bf80aed0afe10 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case StreamFailure: return StreamFailure.decode(in); + case UploadStream: + return UploadStream.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java new file mode 100644 index 0000000000000..a175f90cbe56a --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * An RPC with data that is sent outside of the frame, so it can be read in a stream. + */ +public final class UploadStream extends AbstractMessage implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + public final ManagedBuffer meta; + public final long bodyByteCount; + + public UploadStream(long requestId, ManagedBuffer meta, ManagedBuffer body) { + super(body, false); // body is *not* included in the frame + this.requestId = requestId; + this.meta = meta; + bodyByteCount = body.size(); + } + + // this version is called when decoding the bytes on the receiving end. The body is handled + // separately. 
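+  // On the wire the encoded frame is: requestId (8 bytes) | meta size (4 bytes) | meta bytes |
+  // bodyByteCount (8 bytes). The body itself is *not* part of the frame; it is streamed
+  // afterwards and handed to a StreamInterceptor on the receiving side.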
+ private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { + super(null, false); + this.requestId = requestId; + this.meta = meta; + this.bodyByteCount = bodyByteCount; + } + + @Override + public Type type() { return Type.UploadStream; } + + @Override + public int encodedLength() { + // the requestId, meta size, meta and bodyByteCount (body is not included) + return 8 + 4 + ((int) meta.size()) + 8; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + try { + ByteBuffer metaBuf = meta.nioByteBuffer(); + buf.writeInt(metaBuf.remaining()); + buf.writeBytes(metaBuf); + } catch (IOException io) { + throw new RuntimeException(io); + } + buf.writeLong(bodyByteCount); + } + + public static UploadStream decode(ByteBuf buf) { + long requestId = buf.readLong(); + int metaSize = buf.readInt(); + ManagedBuffer meta = new NettyManagedBuffer(buf.readRetainedSlice(metaSize)); + long bodyByteCount = buf.readLong(); + // This is called by the frame decoder, so the data is still null. We need a StreamInterceptor + // to read the data. + return new UploadStream(requestId, meta, bodyByteCount); + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, body()); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UploadStream) { + UploadStream o = (UploadStream) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318add..37887eba2f7e0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -30,6 +30,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -76,10 +77,14 @@ public SaslRpcHandler( } @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + StreamData streamData, + RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. 
-      delegate.receive(client, message, callback);
+      delegate.receive(client, message, streamData, callback);
       return;
     }
     if (saslServer == null || !saslServer.isComplete()) {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 6ed61da5c7eff..d25a5a914696d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -31,7 +31,11 @@ public NoOpRpcHandler() {
 }
 
   @Override
-  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+  public void receive(
+      TransportClient client,
+      ByteBuffer message,
+      StreamData streamData,
+      RpcResponseCallback callback) {
     throw new UnsupportedOperationException("Cannot handle messages");
   }
 
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 8f7554e2e07d5..00d60313f4048 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -38,15 +38,24 @@ public abstract class RpcHandler {
    *
    * This method will not be called in parallel for a single TransportClient (i.e., channel).
    *
+   * The RPC *might* include a data stream in streamData (e.g. for uploading a large amount of
+   * data which should not be buffered in memory here). Any errors while handling the streamData
+   * will fail this entire connection -- all other in-flight RPCs will fail as well.
+   * If streamData is not null, you *must* call streamData.registerStreamCallback before this
+   * method returns.
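+   *
+   * A streaming receive will typically look something like the sketch below (the stream id and
+   * the callback body are illustrative, not part of this patch):
+   * <pre>
+   * public void receive(TransportClient client, ByteBuffer message, StreamData streamData,
+   *     RpcResponseCallback callback) {
+   *   if (streamData != null) {
+   *     // must happen before receive() returns
+   *     streamData.registerStreamCallback("upload", new StreamCallback() { ... });
+   *   } else {
+   *     // handle a plain rpc message
+   *   }
+   * }
+   * </pre>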
+   *
    * @param client A channel client which enables the handler to make requests back to the sender
    *               of this RPC. This will always be the exact same object for a particular channel.
    * @param message The serialized bytes of the RPC.
+   * @param streamData StreamData if there is data which is meant to be read via a StreamCallback;
+   *                   otherwise it is null.
    * @param callback Callback which should be invoked exactly once upon success or failure of the
    *                 RPC.
    */
   public abstract void receive(
       TransportClient client,
       ByteBuffer message,
+      StreamData streamData,
       RpcResponseCallback callback);
 
   /**
@@ -57,15 +66,15 @@ public abstract void receive(
   /**
    * Receives an RPC message that does not expect a reply. The default implementation will
-   * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if
-   * any of the callback methods are called.
+   * call "{@link #receive(TransportClient, ByteBuffer, StreamData, RpcResponseCallback)}" and log a
+   * warning if any of the callback methods are called.
    *
    * @param client A channel client which enables the handler to make requests back to the sender
    *               of this RPC. This will always be the exact same object for a particular channel.
    * @param message The serialized bytes of the RPC.
    */
   public void receive(TransportClient client, ByteBuffer message) {
-    receive(client, message, ONE_WAY_CALLBACK);
+    receive(client, message, null, ONE_WAY_CALLBACK);
   }
 
   /**
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java
new file mode 100644
index 0000000000000..6e9576bd443a5
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
+import org.apache.spark.network.client.StreamInterceptor;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * A holder for streamed data sent along with an RPC message.
+ */
+public class StreamData {
+
+  private final TransportRequestHandler handler;
+  private final TransportFrameDecoder frameDecoder;
+  private final RpcResponseCallback rpcCallback;
+  private final ByteBuffer meta;
+  private final long streamByteCount;
+  private boolean hasCallback = false;
+
+  public StreamData(
+      TransportRequestHandler handler,
+      TransportFrameDecoder frameDecoder,
+      RpcResponseCallback rpcCallback,
+      ByteBuffer meta,
+      long streamByteCount) {
+    this.handler = handler;
+    this.frameDecoder = frameDecoder;
+    this.rpcCallback = rpcCallback;
+    this.meta = meta;
+    this.streamByteCount = streamByteCount;
+  }
+
+  public boolean hasCallback() {
+    return hasCallback;
+  }
+
+  /**
+   * Register a callback to receive the streaming data.
+   *
+   * If an exception is thrown from the callback, it will be propagated back to the sender as an
+   * RPC failure.
+   *
+   * @param streamId id which is passed to each method of the callback.
+   * @param callback the callback which will receive the stream data.
+   */
+  public void registerStreamCallback(String streamId, StreamCallback callback) throws IOException {
+    if (hasCallback) {
+      throw new IllegalStateException("Cannot register more than one stream callback");
+    }
+    hasCallback = true;
+    // the passed callback handles the actual data, but we need to also make sure we respond to
+    // the original rpc request.
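+    // The wrapper below forwards onData to the registered callback; onComplete additionally
+    // answers the original rpc with an empty success response, and a stream failure is turned
+    // into an rpc failure so the sender always hears back.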
+ StreamCallback wrappedCallback = new StreamCallback() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + callback.onData(streamId, buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + callback.onComplete(streamId); + rpcCallback.onSuccess(ByteBuffer.allocate(0)); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + rpcCallback.onFailure(new IOException("Destination failed while reading stream", cause)); + callback.onFailure(streamId, cause); + } + }; + if (streamByteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(handler, streamId, streamByteCount, + wrappedCallback); + frameDecoder.setInterceptor(interceptor); + } else { + wrappedCallback.onComplete(streamId); + } + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e94453578e6b0..4e20264e4412b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -23,6 +23,9 @@ import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportFrameDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,18 +33,7 @@ import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -113,6 +105,8 @@ public void handle(RequestMessage request) { processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); + } else if (request instanceof UploadStream) { + processStreamUpload((UploadStream) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -184,7 +178,7 @@ private void processStreamRequest(final StreamRequest req) { private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { + RpcResponseCallback callback = new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); @@ -194,7 +188,8 @@ public void onSuccess(ByteBuffer response) { public void onFailure(Throwable e) { respond(new RpcFailure(req.requestId, 
           Throwables.getStackTraceAsString(e)));
         }
-      });
+      };
+      rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), null, callback);
     } catch (Exception e) {
       logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
       respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
@@ -203,6 +198,44 @@ public void onFailure(Throwable e) {
     }
   }
 
+  /**
+   * Handle a request from the client to upload a stream of data.
+   */
+  private void processStreamUpload(final UploadStream req) {
+    assert (req.body() == null);
+    try {
+      RpcResponseCallback callback = new RpcResponseCallback() {
+        @Override
+        public void onSuccess(ByteBuffer response) {
+          respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
+        }
+
+        @Override
+        public void onFailure(Throwable e) {
+          respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+        }
+      };
+      TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+        channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+      ByteBuffer meta = req.meta.nioByteBuffer();
+      StreamData streamData = new StreamData(TransportRequestHandler.this, frameDecoder,
+        callback, meta, req.bodyByteCount);
+      rpcHandler.receive(reverseClient, meta, streamData, callback);
+      if (!streamData.hasCallback()) {
+        throw new RuntimeException("Destination did not register stream handler");
+      }
+    } catch (Exception e) {
+      logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
+      respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+      // We choose to totally fail the channel, rather than trying to recover as we do in other
+      // cases. We don't know how many bytes of the stream the client has already sent; it's not
+      // worth trying to recover.
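+      // Firing the exception on the pipeline reaches TransportChannelHandler#exceptionCaught,
+      // which closes the channel, so any other in-flight requests on it fail fast.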
+ channel.pipeline().fireExceptionCaught(e); + } finally { + req.meta.release(); + } + } + private void processOneWayMessage(OneWayMessage req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 824482af08dd4..708200f8a2d58 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -32,6 +32,7 @@ import com.google.common.collect.Sets; import com.google.common.io.Closeables; +import org.apache.spark.network.server.StreamData; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -108,6 +109,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { public void receive( TransportClient client, ByteBuffer message, + StreamData data, RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index c0724e018263f..47dcf75d620c4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -25,6 +25,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.MapConfigProvider; @@ -91,6 +92,7 @@ public void timeoutInactiveRequests() throws Exception { public void receive( TransportClient client, ByteBuffer message, + StreamData streamData, RpcResponseCallback callback) { try { semaphore.acquire(); @@ -138,6 +140,7 @@ public void timeoutCleanlyClosesClient() throws Exception { public void receive( TransportClient client, ByteBuffer message, + StreamData streamData, RpcResponseCallback callback) { try { semaphore.acquire(); @@ -194,6 +197,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { public void receive( TransportClient client, ByteBuffer message, + StreamData streamData, RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8ff737b129641..a0321f6424d47 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,19 +17,23 @@ package org.apache.spark.network; +import java.io.*; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import 
com.google.common.collect.Sets; +import com.google.common.io.Files; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.server.*; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; @@ -37,37 +41,44 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { + static TransportConf conf; static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; static List oneWayMsgs; + static StreamTestHelper testData; + + static ConcurrentHashMap streamCallbacks = + new ConcurrentHashMap<>(); @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + testData = new StreamTestHelper(); rpcHandler = new RpcHandler() { @Override public void receive( TransportClient client, ByteBuffer message, + StreamData streamData, RpcResponseCallback callback) { String msg = JavaUtils.bytesToString(message); - String[] parts = msg.split("/"); - if (parts[0].equals("hello")) { - callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); - } else if (parts[0].equals("return error")) { - callback.onFailure(new RuntimeException("Returned: " + parts[1])); - } else if (parts[0].equals("throw error")) { - throw new RuntimeException("Thrown: " + parts[1]); + if (streamData != null) { + receiveStream(msg, streamData); + } else { + String[] parts = msg.split("/"); + if (parts[0].equals("hello")) { + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); + } else if (parts[0].equals("return error")) { + callback.onFailure(new RuntimeException("Returned: " + parts[1])); + } else if (parts[0].equals("throw error")) { + throw new RuntimeException("Thrown: " + parts[1]); + } } } @@ -85,10 +96,52 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } + private static void receiveStream(String msg, StreamData streamData) { + try { + if (msg.startsWith("fail/")) { + String[] parts = msg.split("/"); + switch(parts[1]) { + case "no callback": + // don't register anything here, check the rpc error response is appropriate + break; + case "exception": + StreamCallback callback = new StreamCallback() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + throw new IOException("failed to read stream data!"); + } + + @Override + public void onComplete(String streamId) throws IOException { + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + }; + streamData.registerStreamCallback(msg, callback); + break; + case "multiple": + VerifyingStreamCallback streamCallback = new 
VerifyingStreamCallback(msg); + streamData.registerStreamCallback(msg, streamCallback); + streamData.registerStreamCallback(msg, streamCallback); + break; + } + } else { + VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); + streamData.registerStreamCallback(msg, streamCallback); + streamCallbacks.put(msg, streamCallback); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @AfterClass public static void tearDown() { server.close(); clientFactory.close(); + testData.cleanup(); } static class RpcResult { @@ -130,6 +183,50 @@ public void onFailure(Throwable e) { return res; } + private RpcResult sendRpcWithStream(String... streams) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer message) { + String response = JavaUtils.bytesToString(message); + res.successMessages.add(response); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + }; + + for (String stream: streams) { + int idx = stream.lastIndexOf('/'); + ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); + String streamName = (idx == -1) ? stream : stream.substring(idx + 1); + ManagedBuffer data = testData.openStream(conf, streamName); + client.uploadStream(meta, data, callback); + } + streamCallbacks.values().forEach(streamCallback -> { + try { + streamCallback.waitForCompletionAndVerify(TimeUnit.SECONDS.toMillis(5)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + + if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); @@ -193,8 +290,52 @@ public void sendOneWayMessage() throws Exception { } } + @Test + public void sendRpcWithStreamOneAtATime() throws Exception { + for (String stream: StreamTestHelper.STREAMS) { + RpcResult res = sendRpcWithStream(stream); + assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty()); + assertEquals(Sets.newHashSet(stream), res.successMessages); + } + } + + @Test + public void sendRpcWithStreamConcurrently() throws Exception { + String[] streams = new String[10]; + for (int i = 0; i < 10; i++) { + streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]; + } + RpcResult res = sendRpcWithStream(streams); + assertEquals(res.successMessages, Sets.newHashSet(StreamTestHelper.STREAMS)); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void sendRpcWithStreamFailures() throws Exception { + // when there is a failure reading stream data, we don't try to keep the channel usable, + // just send back a decent error msg. 
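+    // Three failure modes are exercised below: the handler never registers a stream callback,
+    // the handler tries to register a second callback, and the callback throws from onData.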
+    RpcResult noCallbackResult = sendRpcWithStream("fail/no callback/smallBuffer", "smallBuffer");
+    assertTrue("unexpected success: " + noCallbackResult.successMessages,
+      noCallbackResult.successMessages.isEmpty());
+    assertErrorsContain(noCallbackResult.errorMessages,
+      Sets.newHashSet("Destination did not register stream handler", "closed"));
+
+    RpcResult multiCallbackResult = sendRpcWithStream("fail/multiple/smallBuffer", "smallBuffer");
+    assertTrue("unexpected success: " + multiCallbackResult.successMessages,
+      multiCallbackResult.successMessages.isEmpty());
+    assertErrorsContain(multiCallbackResult.errorMessages,
+      Sets.newHashSet("Cannot register more than one stream callback", "closed"));
+
+    RpcResult exceptionInCallbackResult = sendRpcWithStream("fail/exception/file", "smallBuffer");
+    assertTrue("unexpected success: " + exceptionInCallbackResult.successMessages,
+      exceptionInCallbackResult.successMessages.isEmpty());
+    assertErrorsContain(exceptionInCallbackResult.errorMessages,
+      Sets.newHashSet("Destination failed while reading stream", "Connection reset"));
+  }
+
   private void assertErrorsContain(Set<String> errors, Set<String> contains) {
-    assertEquals(contains.size(), errors.size());
+    assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + " errors: "
+      + errors, contains.size(), errors.size());
 
     Set<String> remainingErrors = Sets.newHashSet(errors);
     for (String contain : contains) {
@@ -212,4 +353,56 @@ private void assertErrorsContain(Set<String> errors, Set<String> contains) {
 
     assertTrue(remainingErrors.isEmpty());
   }
+
+  private static class VerifyingStreamCallback implements StreamCallback {
+    final String streamId;
+    final StreamSuite.TestCallback helper;
+    final OutputStream out;
+    final File outFile;
+
+    VerifyingStreamCallback(String streamId) throws IOException {
+      if (streamId.equals("file")) {
+        outFile = File.createTempFile("data", ".tmp", testData.tempDir);
+        out = new FileOutputStream(outFile);
+      } else {
+        out = new ByteArrayOutputStream();
+        outFile = null;
+      }
+      this.streamId = streamId;
+      helper = new StreamSuite.TestCallback(out);
+    }
+
+    void waitForCompletionAndVerify(long timeoutMs) throws IOException {
+      helper.waitForCompletion(timeoutMs);
+      if (streamId.equals("file")) {
+        assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile));
+      } else {
+        byte[] result = ((ByteArrayOutputStream) out).toByteArray();
+        ByteBuffer srcBuffer = testData.srcBuffer(streamId);
+        ByteBuffer base;
+        synchronized (srcBuffer) {
+          base = srcBuffer.duplicate();
+        }
+        byte[] expected = new byte[base.remaining()];
+        base.get(expected);
+        assertEquals(expected.length, result.length);
+        assertTrue("buffers don't match", Arrays.equals(expected, result));
+      }
+    }
+
+    @Override
+    public void onData(String streamId, ByteBuffer buf) throws IOException {
+      helper.onData(streamId, buf);
+    }
+
+    @Override
+    public void onComplete(String streamId) throws IOException {
+      helper.onComplete(streamId);
+    }
+
+    @Override
+    public void onFailure(String streamId, Throwable cause) throws IOException {
+      helper.onFailure(streamId, cause);
+    }
+  }
 }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
index f253a07e64be1..d3d990b3be403 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -32,6 +32,7 @@
 import java.util.concurrent.TimeUnit;
 
 import
com.google.common.io.Files; +import org.apache.spark.network.server.StreamData; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -51,16 +52,11 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + private static final String[] STREAMS = StreamTestHelper.STREAMS; + private static StreamTestHelper testData; private static TransportServer server; private static TransportClientFactory clientFactory; - private static File testFile; - private static File tempDir; - - private static ByteBuffer emptyBuffer; - private static ByteBuffer smallBuffer; - private static ByteBuffer largeBuffer; private static ByteBuffer createBuffer(int bufSize) { ByteBuffer buf = ByteBuffer.allocate(bufSize); @@ -73,23 +69,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { - tempDir = Files.createTempDir(); - emptyBuffer = createBuffer(0); - smallBuffer = createBuffer(100); - largeBuffer = createBuffer(100000); - - testFile = File.createTempFile("stream-test-file", "txt", tempDir); - FileOutputStream fp = new FileOutputStream(testFile); - try { - Random rnd = new Random(); - for (int i = 0; i < 512; i++) { - byte[] fileContent = new byte[1024]; - rnd.nextBytes(fileContent); - fp.write(fileContent); - } - } finally { - fp.close(); - } + testData = new StreamTestHelper(); final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @@ -100,18 +80,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { @Override public ManagedBuffer openStream(String streamId) { - switch (streamId) { - case "largeBuffer": - return new NioManagedBuffer(largeBuffer); - case "smallBuffer": - return new NioManagedBuffer(smallBuffer); - case "emptyBuffer": - return new NioManagedBuffer(emptyBuffer); - case "file": - return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); - default: - throw new IllegalArgumentException("Invalid stream: " + streamId); - } + return testData.openStream(conf, streamId); } }; RpcHandler handler = new RpcHandler() { @@ -119,6 +88,7 @@ public ManagedBuffer openStream(String streamId) { public void receive( TransportClient client, ByteBuffer message, + StreamData streamData, RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -137,12 +107,7 @@ public StreamManager getStreamManager() { public static void tearDown() { server.close(); clientFactory.close(); - if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); - } - tempDir.delete(); - } + testData.cleanup(); } @Test @@ -234,21 +199,21 @@ public void run() { case "largeBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = largeBuffer; + srcBuffer = testData.largeBuffer; break; case "smallBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = smallBuffer; + srcBuffer = testData.smallBuffer; break; case "file": - outFile = File.createTempFile("data", ".tmp", tempDir); + outFile = File.createTempFile("data", ".tmp", testData.tempDir); out = new FileOutputStream(outFile); break; case "emptyBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = emptyBuffer; + srcBuffer = testData.emptyBuffer; break; default: throw new IllegalArgumentException(streamId); @@ -256,10 +221,10 @@ public void run() { TestCallback callback = new TestCallback(out); 
client.stream(streamId, callback); - waitForCompletion(callback); + callback.waitForCompletion(timeoutMs); if (srcBuffer == null) { - assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); } else { ByteBuffer base; synchronized (srcBuffer) { @@ -293,22 +258,11 @@ public void check() throws Throwable { } } - private void waitForCompletion(TestCallback callback) throws Exception { - long now = System.currentTimeMillis(); - long deadline = now + timeoutMs; - synchronized (callback) { - while (!callback.completed && now < deadline) { - callback.wait(deadline - now); - now = System.currentTimeMillis(); - } - } - assertTrue("Timed out waiting for stream.", callback.completed); - assertNull(callback.error); - } + } - private static class TestCallback implements StreamCallback { + static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; @@ -344,6 +298,23 @@ public void onFailure(String streamId, Throwable cause) { } } + void waitForCompletion(long timeoutMs) { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (this) { + while (!completed && now < deadline) { + try { + wait(deadline - now); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", completed); + assertNull(error); + } + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java new file mode 100644 index 0000000000000..63c16e27c6701 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import com.google.common.io.Files; +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.TransportConf; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.ByteBuffer; +import java.util.Random; + +class StreamTestHelper { + static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + + final File testFile; + File tempDir; + + ByteBuffer emptyBuffer; + ByteBuffer smallBuffer; + ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + StreamTestHelper() throws Exception { + tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + } + + public ByteBuffer srcBuffer(String name) { + switch (name) { + case "largeBuffer": + return largeBuffer; + case "smallBuffer": + return smallBuffer; + case "emptyBuffer": + return emptyBuffer; + default: + throw new IllegalArgumentException("Invalid stream: " + name); + } + } + + public ManagedBuffer openStream(TransportConf conf, String streamId) { + switch (streamId) { + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + return new NioManagedBuffer(srcBuffer(streamId)); + } + } + + + void cleanup() { + if (tempDir != null) { + for (File f : tempDir.listFiles()) { + f.delete(); + } + tempDir.delete(); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 8751944a1c2a3..6c57fdc9694d2 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap; import io.netty.channel.Channel; +import org.apache.spark.network.server.*; import org.junit.After; import org.junit.Test; import static org.junit.Assert.*; @@ -37,10 +38,6 @@ import org.apache.spark.network.sasl.SaslRpcHandler; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -144,6 +141,7 @@ private class AuthTestCtx { public void receive( TransportClient client, ByteBuffer message, + StreamData streamData, RpcResponseCallback callback) { assertEquals("Ping", JavaUtils.bytesToString(message)); 
callback.onSuccess(JavaUtils.stringToBytes("Pong")); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 6f15718bd8705..4e72fdf2b92e0 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -44,6 +44,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; +import org.apache.spark.network.server.*; import org.junit.Test; import org.apache.spark.network.TestUtils; @@ -54,10 +55,6 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -145,7 +142,8 @@ private static void testBasicSasl(boolean encrypt) throws Throwable { return null; }) .when(rpcHandler) - .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(StreamData.class), + any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index fc7bba41185f0..4f3c174b8f646 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -38,6 +38,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.apache.spark.network.shuffle.protocol.*; @@ -76,7 +77,12 @@ public ExternalShuffleBlockHandler( } @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + StreamData streamData, + RpcResponseCallback callback) { + assert(streamData == null); BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 02e6eb3a4467e..69c3889abeac8 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -39,11 
+39,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.server.*; import org.apache.spark.network.shuffle.BlockFetchingListener; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver; @@ -264,7 +260,11 @@ public void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + StreamData streamData, + RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7846b71d5a8b1..7f0279558221e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -63,7 +63,7 @@ public void testRegisterExecutor() { ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); - handler.receive(client, registerMessage, callback); + handler.receive(client, registerMessage, null, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); @@ -88,7 +88,7 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }) .toByteBuffer(); - handler.receive(client, openBlocks, callback); + handler.receive(client, openBlocks, null, callback); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); @@ -129,7 +129,7 @@ public void testBadMessages() { ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { - handler.receive(client, unserializableMsg, callback); + handler.receive(client, unserializableMsg, null, callback); fail("Should have thrown"); } catch (Exception e) { // pass @@ -138,7 +138,7 @@ public void testBadMessages() { ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); try { - handler.receive(client, unexpectedMsg, callback); + handler.receive(client, unexpectedMsg, null, callback); fail("Should have thrown"); } catch (UnsupportedOperationException e) { // pass diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index eb4cf94164fd4..1397ae24db15a 100644 --- 
a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.NioManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} -import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} +import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamData, StreamManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -50,6 +50,7 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, rpcMessage: ByteBuffer, + streamData: StreamData, responseContext: RpcResponseCallback): Unit = { val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a2936d6ad539c..1c36c7a620c49 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -520,12 +520,12 @@ private[netty] class NettyRpcEndpointRef( override def name: String = endpointAddress.name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout) + nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message, null), timeout) } override def send(message: Any): Unit = { require(message != null, "Message is null") - nettyEnv.send(new RequestMessage(nettyEnv.address, this, message)) + nettyEnv.send(new RequestMessage(nettyEnv.address, this, message, null)) } override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})" @@ -546,11 +546,14 @@ private[netty] class NettyRpcEndpointRef( * `NettyRpcEnv`. * @param receiver the receiver of this message. * @param content the message content. + * @param streamData optional stream of data. May be null. If present, + * streamData.registerStreamCallback *must* be called. */ private[netty] class RequestMessage( val senderAddress: RpcAddress, val receiver: NettyRpcEndpointRef, - val content: Any) { + val content: Any, + val streamData: StreamData) { /** Manually serialize [[RequestMessage]] to minimize the size. */ def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = { @@ -596,7 +599,11 @@ private[netty] object RequestMessage { } } - def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = { + def apply( + nettyEnv: NettyRpcEnv, + client: TransportClient, + bytes: ByteBuffer, + streamData: StreamData): RequestMessage = { val bis = new ByteBufferInputStream(bytes) val in = new DataInputStream(bis) try { @@ -608,7 +615,8 @@ private[netty] object RequestMessage { senderAddress, ref, // The remaining bytes in `bytes` are the message content. 
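(Aside on the contract documented above: whenever streamData is non-null, the receiving RpcHandler must register exactly one StreamCallback before receive() returns, or the server fails the channel. A minimal Java sketch of a conforming handler -- illustrative only, not part of this patch; the class name and stream id are made up:)

    import java.io.IOException;
    import java.nio.ByteBuffer;

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.StreamCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.OneForOneStreamManager;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.StreamData;
    import org.apache.spark.network.server.StreamManager;

    public class SketchUploadHandler extends RpcHandler {
      @Override
      public void receive(
          TransportClient client,
          ByteBuffer message,
          StreamData streamData,
          RpcResponseCallback callback) {
        if (streamData == null) {
          // a plain rpc with no attached stream
          callback.onSuccess(ByteBuffer.allocate(0));
          return;
        }
        try {
          // must happen before receive() returns, or the server fails the channel
          streamData.registerStreamCallback("sketch-stream", new StreamCallback() {
            @Override
            public void onData(String streamId, ByteBuffer buf) throws IOException {
              // consume buf incrementally, e.g. append it to a file
            }

            @Override
            public void onComplete(String streamId) throws IOException {
              // finalize; StreamData then answers the rpc with an empty success
            }

            @Override
            public void onFailure(String streamId, Throwable cause) throws IOException {
              // discard any partially received data
            }
          });
        } catch (IOException ioe) {
          callback.onFailure(ioe);
        }
      }

      @Override
      public StreamManager getStreamManager() {
        return new OneForOneStreamManager();
      }
    }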
- nettyEnv.deserialize(client, bytes)) + nettyEnv.deserialize(client, bytes), + streamData) } finally { in.close() } @@ -643,26 +651,30 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, message: ByteBuffer, + streamData: StreamData, callback: RpcResponseCallback): Unit = { - val messageToDispatch = internalReceive(client, message) + val messageToDispatch = internalReceive(client, message, streamData) dispatcher.postRemoteMessage(messageToDispatch, callback) } override def receive( client: TransportClient, message: ByteBuffer): Unit = { - val messageToDispatch = internalReceive(client, message) + val messageToDispatch = internalReceive(client, message, null) dispatcher.postOneWayMessage(messageToDispatch) } - private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { + private def internalReceive( + client: TransportClient, + message: ByteBuffer, + streamData: StreamData): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostString, addr.getPort) - val requestMessage = RequestMessage(nettyEnv, client, message) + val requestMessage = RequestMessage(nettyEnv, client, message, streamData) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. - new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, streamData) } else { // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for // the listening address diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index f9481f875d439..46dccc2c88ee4 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -69,19 +69,19 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar { val receiverAddress = RpcEndpointAddress("localhost", 54321, "test") val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv) - val msg = new RequestMessage(senderAddress, receiver, "foo") + val msg = new RequestMessage(senderAddress, receiver, "foo", null) assertRequestMessageEquals( msg, - RequestMessage(nettyEnv, client, msg.serialize(nettyEnv))) + RequestMessage(nettyEnv, client, msg.serialize(nettyEnv), null)) - val msg2 = new RequestMessage(null, receiver, "foo") + val msg2 = new RequestMessage(null, receiver, "foo", null) assertRequestMessageEquals( msg2, - RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv))) + RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv), null)) - val msg3 = new RequestMessage(senderAddress, receiver, null) + val msg3 = new RequestMessage(senderAddress, receiver, null, null) assertRequestMessageEquals( msg3, - RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv))) + RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv), null)) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index a71d8726e7066..138e84ae7616a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -34,7 +34,7 @@ class 
NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) - .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null)) + .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..f334e0a67ef8f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.network.{BlockDataManager, BlockTransferService, Transpo import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} -import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} +import org.apache.spark.network.server.{NoOpRpcHandler, StreamData, TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.network.util.TransportConf @@ -1341,6 +1341,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def receive( client: TransportClient, message: ByteBuffer, + streamData: StreamData, callback: RpcResponseCallback): Unit = { val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) msgObj match { From 3098b9cd9ffc29517b446bb660fe5be9f0031cc1 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 16 May 2018 14:06:39 -0500 Subject: [PATCH 02/18] fix --- project/MimaExcludes.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4f6d5ff898681..eeb097ef153ad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), From 2fef75f18a115db542afe96d49b8cbe9ed534d53 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 24 May 2018 13:23:28 -0500 Subject: [PATCH 03/18] fixes --- .../spark/network/RpcIntegrationSuite.java | 39 ++++++++++++------- .../spark/network/sasl/SparkSaslSuite.java | 2 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index a0321f6424d47..0af11ae3452c1 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ 
b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -189,27 +189,13 @@ private RpcResult sendRpcWithStream(String... streams) throws Exception { RpcResult res = new RpcResult(); res.successMessages = Collections.synchronizedSet(new HashSet<String>()); res.errorMessages = Collections.synchronizedSet(new HashSet<String>()); - RpcResponseCallback callback = new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer message) { - String response = JavaUtils.bytesToString(message); - res.successMessages.add(response); - sem.release(); - } - - @Override - public void onFailure(Throwable e) { - res.errorMessages.add(e.getMessage()); - sem.release(); - } - }; for (String stream: streams) { int idx = stream.lastIndexOf('/'); ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); String streamName = (idx == -1) ? stream : stream.substring(idx + 1); ManagedBuffer data = testData.openStream(conf, streamName); - client.uploadStream(meta, data, callback); + client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem)); } streamCallbacks.values().forEach(streamCallback -> { try { @@ -227,6 +213,29 @@ public void onFailure(Throwable e) { return res; } + private static class RpcStreamCallback implements RpcResponseCallback { + final String streamId; + final RpcResult res; + final Semaphore sem; + RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) { + this.streamId = streamId; + this.res = res; + this.sem = sem; + } + + @Override + public void onSuccess(ByteBuffer message) { + res.successMessages.add(streamId); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + } + @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 4e72fdf2b92e0..889b29564cc8b 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -136,7 +136,7 @@ private static void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(invocation -> { ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; - RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[3]; assertEquals("Ping", JavaUtils.bytesToString(message)); cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; From 54533c2c882a399d5708da8dfe6a518dd6132844 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 24 May 2018 13:42:15 -0500 Subject: [PATCH 04/18] java style checks --- .../org/apache/spark/network/client/StreamInterceptor.java | 1 - .../main/java/org/apache/spark/network/server/StreamData.java | 2 -- .../apache/spark/network/server/TransportRequestHandler.java | 1 - .../java/org/apache/spark/network/RpcIntegrationSuite.java | 2 -- .../src/test/java/org/apache/spark/network/StreamSuite.java | 3 --- 5 files changed, 9 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index e973a99323ee1..10f970da6d9d5 100644 ---
a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -24,7 +24,6 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.server.MessageHandler; -import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.util.TransportFrameDecoder; /** diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java index 6e9576bd443a5..86f8c35e5e5c1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java @@ -17,11 +17,9 @@ package org.apache.spark.network.server; -import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.StreamInterceptor; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportFrameDecoder; import java.io.IOException; diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 4e20264e4412b..46d53b1baf7e6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -24,7 +24,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import org.apache.spark.network.protocol.*; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportFrameDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 0af11ae3452c1..08afa90d74465 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -27,13 +27,11 @@ import com.google.common.collect.Sets; import com.google.common.io.Files; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.server.*; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index d3d990b3be403..e5863d0708a3e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -26,7 +26,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Random; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import 
java.util.concurrent.TimeUnit; @@ -38,9 +37,7 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; From 32f4f94e3cde50015a8ea478969636fca708cf82 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 24 May 2018 21:26:30 -0500 Subject: [PATCH 05/18] don't care whether error is closed or reset --- .../spark/network/RpcIntegrationSuite.java | 55 ++++++++++++++----- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 08afa90d74465..a084ae6196307 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -26,6 +26,8 @@ import com.google.common.collect.Sets; import com.google.common.io.Files; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.StreamCallback; @@ -322,29 +324,51 @@ public void sendRpcWithStreamFailures() throws Exception { // when there is a failure reading stream data, we don't try to keep the channel usable, // just send back a decent error msg. RpcResult noCallbackResult = sendRpcWithStream("fail/no callback/smallBuffer", "smallBuffer"); - assertTrue("unexpected success: " + noCallbackResult.successMessages, - noCallbackResult.successMessages.isEmpty()); - assertErrorsContain(noCallbackResult.errorMessages, - Sets.newHashSet("Destination did not register stream handler", "closed")); + assertErrorAndClosed(noCallbackResult, "Destination did not register stream handler"); + RpcResult multiCallbackResult = sendRpcWithStream("fail/multiple/smallBuffer", "smallBuffer"); - assertTrue("unexpected success: " + multiCallbackResult.successMessages, - multiCallbackResult.successMessages.isEmpty()); - assertErrorsContain(multiCallbackResult.errorMessages, - Sets.newHashSet("Cannot register more than one stream callback", "closed")); + assertErrorAndClosed(multiCallbackResult, "Cannot register more than one stream callback"); RpcResult exceptionInCallbackResult = sendRpcWithStream("fail/exception/file", "smallBuffer"); - assertTrue("unexpected success: " + exceptionInCallbackResult.successMessages, - exceptionInCallbackResult.successMessages.isEmpty()); - assertErrorsContain(exceptionInCallbackResult.errorMessages, - Sets.newHashSet("Destination failed while reading stream", "Connection reset")); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); } private void assertErrorsContain(Set<String> errors, Set<String> contains) { assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + " errors: " + errors, contains.size(), errors.size()); + Pair<Set<String>, List<String>> r = checkErrorsContain(errors, contains); + assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors, + r.getRight().isEmpty()); + + assertTrue(r.getLeft().isEmpty()); + } + + private void
assertErrorAndClosed(RpcResult result, String expectedError) { assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); + // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + Set<String> errors = result.errorMessages; + assertEquals("Expected 2 errors, got " + errors.size() + " errors: " + + errors, 2, errors.size()); + + Set<String> containsAndClosed = Sets.newHashSet(expectedError); + containsAndClosed.add("closed"); + containsAndClosed.add("Connection reset"); + + Pair<Set<String>, List<String>> r = checkErrorsContain(errors, containsAndClosed); + + List<String> errorsNotFound = r.getRight(); + assertEquals(1, errorsNotFound.size()); + String err = errorsNotFound.get(0); + assertTrue(err.equals("closed") || err.equals("Connection reset")); + + assertTrue(r.getLeft().isEmpty()); + } + + private Pair<Set<String>, List<String>> checkErrorsContain(Set<String> errors, Set<String> contains) { Set<String> remainingErrors = Sets.newHashSet(errors); + List<String> notFound = new LinkedList<>(); for (String contain : contains) { Iterator<String> it = remainingErrors.iterator(); boolean foundMatch = false; @@ -355,10 +379,11 @@ private void assertErrorsContain(Set<String> errors, Set<String> contains) { break; } } - assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + if (!foundMatch) { + notFound.add(contain); + } } - - assertTrue(remainingErrors.isEmpty()); + return new ImmutablePair<>(remainingErrors, notFound); } private static class VerifyingStreamCallback implements StreamCallback { From 331124b125db6b59009e12249542f667a227226e Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 24 May 2018 22:47:37 -0500 Subject: [PATCH 06/18] style --- .../apache/spark/network/RpcIntegrationSuite.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index a084ae6196307..5e28fa70a69f3 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -338,7 +338,7 @@ private void assertErrorsContain(Set<String> errors, Set<String> contains) { assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + " errors: " + errors, contains.size(), errors.size()); - Pair<Set<String>, List<String>> r = checkErrorsContain(errors, contains); + Pair<Set<String>, Set<String>> r = checkErrorsContain(errors, contains); assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors, r.getRight().isEmpty()); @@ -356,19 +356,21 @@ private void assertErrorAndClosed(RpcResult result, String expectedError) { - Pair<Set<String>, List<String>> r = checkErrorsContain(errors, containsAndClosed); + Pair<Set<String>, Set<String>> r = checkErrorsContain(errors, containsAndClosed); - List<String> errorsNotFound = r.getRight(); + Set<String> errorsNotFound = r.getRight(); assertEquals(1, errorsNotFound.size()); - String err = errorsNotFound.get(0); + String err = errorsNotFound.iterator().next(); assertTrue(err.equals("closed") || err.equals("Connection reset")); assertTrue(r.getLeft().isEmpty()); } - private Pair<Set<String>, List<String>> checkErrorsContain(Set<String> errors, Set<String> contains) { + private Pair<Set<String>, Set<String>> checkErrorsContain( + Set<String> errors, + Set<String> contains) { Set<String> remainingErrors = Sets.newHashSet(errors); - List<String> notFound = new LinkedList<>(); + Set<String> notFound = Sets.newHashSet(); for (String contain : contains) { Iterator<String> it =
remainingErrors.iterator(); boolean foundMatch = false; From f4d9123be67ee2421436af741289134013a9760f Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sat, 26 May 2018 16:51:10 -0500 Subject: [PATCH 07/18] review feedback --- .../spark/network/client/TransportClient.java | 150 +++++++++--------- .../spark/network/protocol/UploadStream.java | 2 +- .../spark/network/server/RpcHandler.java | 2 +- .../spark/network/server/StreamData.java | 7 +- .../server/TransportRequestHandler.java | 4 +- .../network/ChunkFetchIntegrationSuite.java | 2 +- .../spark/network/RpcIntegrationSuite.java | 15 +- .../org/apache/spark/network/StreamSuite.java | 2 +- .../spark/network/StreamTestHelper.java | 11 +- .../network/crypto/AuthIntegrationSuite.java | 2 +- .../spark/network/sasl/SparkSaslSuite.java | 2 +- 11 files changed, 96 insertions(+), 103 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index b54fa8bae32bb..ba9e4962254c0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -32,12 +32,14 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.protocol.*; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.protocol.*; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -139,26 +141,14 @@ public void fetchChunk( StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); handler.addFetchRequest(streamChunkId, callback); - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", streamChunkId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)) + .addListener( new StdChannelListener(startTime, streamChunkId) { + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeFetchRequest(streamChunkId); + callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); } - } - }); + }); } /** @@ -178,25 +168,13 @@ public void stream(String streamId, StreamCallback callback) { // when responses arrive. 
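// Illustrative aside, not part of this patch: with the shared listener classes
// introduced below, every outbound request type funnels failures through one
// code path. From the caller's side, the upload API added by this series is
// used roughly as follows (the meta/data contents and largeBytes here are
// hypothetical placeholders):
//
//   ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes("block-meta"));
//   ManagedBuffer data = new NioManagedBuffer(ByteBuffer.wrap(largeBytes));
//   long requestId = client.uploadStream(meta, data, new RpcResponseCallback() {
//     @Override public void onSuccess(ByteBuffer response) { /* meta + stream fully received */ }
//     @Override public void onFailure(Throwable e) { /* send failed or remote error */ }
//   });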
synchronized (this) { handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request for {} to {} took {} ms", streamId, - getRemoteAddress(channel), timeTaken); + channel.writeAndFlush(new StreamRequest(streamId)) + .addListener(new StdChannelListener(startTime, streamId) { + @Override + void handleFailure(String errorMsg, Throwable cause) throws Exception { + callback.onFailure(streamId, new IOException(errorMsg, cause)); } - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + }); } } @@ -218,26 +196,7 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { handler.addRpcRequest(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) - .addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + .addListener(new RpcChannelListener(startTime, requestId, callback)); return requestId; } @@ -266,30 +225,63 @@ public long uploadStream( handler.addRpcRequest(requestId, callback); channel.writeAndFlush(new UploadStream(requestId, meta, data)) - .addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + .addListener(new RpcChannelListener(startTime, requestId, callback)); return requestId; } + private class StdChannelListener + implements GenericFutureListener<Future<? super Void>> { + final long startTime; + final Object requestId; + + StdChannelListener(long startTime, Object requestId) { + this.startTime = startTime; + this.requestId = requestId; + } + + @Override + public void operationComplete(Future<? super Void> future) throws Exception { + if (future.isSuccess()) { + if (logger.isTraceEnabled()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String
errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + handleFailure(errorMsg, future.cause()); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + + void handleFailure(String errorMsg, Throwable cause) throws Exception {} + } + + class RpcChannelListener extends StdChannelListener { + final long rpcRequestId; + final RpcResponseCallback callback; + + RpcChannelListener(long startTime, long rpcRequestId, RpcResponseCallback callback) { + super(startTime, "RPC " + rpcRequestId); + this.rpcRequestId = rpcRequestId; + this.callback = callback; + } + + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeRpcRequest(rpcRequestId); + callback.onFailure(new IOException(errorMsg, cause)); + } + } + + /** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java index a175f90cbe56a..3f9c1612126b7 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -27,7 +27,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer; /** - * An RPC with data that is sent outside of the frame, so it can be read in a stream. + * An RPC with data that is sent outside of the frame, so it can be read as a stream. */ public final class UploadStream extends AbstractMessage implements RequestMessage { /** Used to link an RPC request with its response. */ diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 00d60313f4048..32e31fd2cd0d0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -38,7 +38,7 @@ public abstract class RpcHandler { * * This method will not be called in parallel for a single TransportClient (i.e., channel). * - * The rpc *might* included a data stream in streamData(eg. for uploading a large + * The rpc *might* include a data stream in streamData (e.g. for uploading a large * amount of data which should not be buffered in memory here). Any errors while handling the * streamData will lead to failing this entire connection -- all other in-flight rpcs will fail.
* If stream data is not null, you *must* call streamData.registerStreamCallback diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java index 86f8c35e5e5c1..539cb492640c5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java @@ -17,14 +17,14 @@ package org.apache.spark.network.server; +import java.io.IOException; +import java.nio.ByteBuffer; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.StreamInterceptor; import org.apache.spark.network.util.TransportFrameDecoder; -import java.io.IOException; -import java.nio.ByteBuffer; - /** * A holder for streamed data sent along with an RPC message. */ @@ -59,7 +59,6 @@ public boolean hasCallback() { * * If an exception is thrown from the callback, it will be propagated back to the sender as an rpc * failure. - * @param callback */ public void registerStreamCallback(String streamId, StreamCallback callback) throws IOException { if (hasCallback) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 46d53b1baf7e6..30215c8590d27 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -23,8 +23,6 @@ import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; -import org.apache.spark.network.protocol.*; -import org.apache.spark.network.util.TransportFrameDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +30,8 @@ import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.util.TransportFrameDecoder; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 708200f8a2d58..2bbe07036e392 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -32,7 +32,6 @@ import com.google.common.collect.Sets; import com.google.common.io.Closeables; -import org.apache.spark.network.server.StreamData; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -47,6 +46,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.MapConfigProvider; diff --git
a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 5e28fa70a69f3..aca568fcc62bd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -28,19 +28,19 @@ import com.google.common.io.Files; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.server.*; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -100,7 +100,7 @@ private static void receiveStream(String msg, StreamData streamData) { try { if (msg.startsWith("fail/")) { String[] parts = msg.split("/"); - switch(parts[1]) { + switch (parts[1]) { case "no callback": // don't register anything here, check the rpc error response is appropriate break; @@ -190,7 +190,7 @@ private RpcResult sendRpcWithStream(String... streams) throws Exception { res.successMessages = Collections.synchronizedSet(new HashSet<String>()); res.errorMessages = Collections.synchronizedSet(new HashSet<String>()); - for (String stream: streams) { + for (String stream : streams) { int idx = stream.lastIndexOf('/'); ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); String streamName = (idx == -1) ? stream : stream.substring(idx + 1); @@ -217,6 +217,7 @@ private static class RpcStreamCallback implements RpcResponseCallback { final String streamId; final RpcResult res; final Semaphore sem; + RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) { this.streamId = streamId; this.res = res; @@ -301,7 +302,7 @@ public void sendOneWayMessage() throws Exception { @Test public void sendRpcWithStreamOneAtATime() throws Exception { - for (String stream: StreamTestHelper.STREAMS) { + for (String stream : StreamTestHelper.STREAMS) { RpcResult res = sendRpcWithStream(stream); assertTrue("there were error messages!"
+ res.errorMessages, res.errorMessages.isEmpty()); assertEquals(Sets.newHashSet(stream), res.successMessages); diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index e5863d0708a3e..f6bd006ef8f56 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -31,7 +31,6 @@ import java.util.concurrent.TimeUnit; import com.google.common.io.Files; -import org.apache.spark.network.server.StreamData; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -43,6 +42,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.MapConfigProvider; diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java index 63c16e27c6701..5d581b80afd9e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -17,17 +17,18 @@ package org.apache.spark.network; +import java.io.File; +import java.io.FileOutputStream; +import java.nio.ByteBuffer; +import java.util.Random; + import com.google.common.io.Files; + import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.util.TransportConf; -import java.io.File; -import java.io.FileOutputStream; -import java.nio.ByteBuffer; -import java.util.Random; - class StreamTestHelper { static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 6c57fdc9694d2..c2cfe1d997586 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -24,7 +24,6 @@ import com.google.common.collect.ImmutableMap; import io.netty.channel.Channel; -import org.apache.spark.network.server.*; import org.junit.After; import org.junit.Test; import static org.junit.Assert.*; @@ -38,6 +37,7 @@ import org.apache.spark.network.sasl.SaslRpcHandler; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 889b29564cc8b..24532dd152eb3 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java 
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -44,7 +44,6 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; -import org.apache.spark.network.server.*; import org.junit.Test; import org.apache.spark.network.TestUtils; @@ -55,6 +54,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; From 7bd1b43c81a3cdd7b88cf64994cfe8f2b3c5fdf8 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sat, 26 May 2018 22:45:09 -0500 Subject: [PATCH 08/18] private --- .../java/org/apache/spark/network/client/TransportClient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index ba9e4962254c0..9daa64a66a9b9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -264,7 +264,7 @@ public void operationComplete(Future future) throws Exception { void handleFailure(String errorMsg, Throwable cause) throws Exception {} } - class RpcChannelListener extends StdChannelListener { + private class RpcChannelListener extends StdChannelListener { final long rpcRequestId; final RpcResponseCallback callback; From 83c3271d2f45bbef18d865bddbc6807e9fbd2503 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 30 May 2018 22:58:38 -0500 Subject: [PATCH 09/18] review feedback --- .../org/apache/spark/network/client/StreamInterceptor.java | 2 ++ .../java/org/apache/spark/network/protocol/StreamResponse.java | 2 +- .../java/org/apache/spark/network/protocol/UploadStream.java | 2 +- .../main/java/org/apache/spark/network/server/StreamData.java | 3 --- .../apache/spark/network/server/TransportRequestHandler.java | 2 +- 5 files changed, 5 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 10f970da6d9d5..f3eb744ff7345 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -64,6 +64,8 @@ public void channelInactive() throws Exception { private void deactivateStream() { if (handler instanceof TransportResponseHandler) { + // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches + // (there is no extra cleanup that needs to happen) ((TransportResponseHandler) handler).deactivateStream(); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 87e212f3e157b..50b811604b84b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ 
b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId, body()); + return Objects.hashCode(byteCount, streamId); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java index 3f9c1612126b7..fa1d26e76b852 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -85,7 +85,7 @@ public static UploadStream decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(requestId, body()); + return Long.hashCode(requestId); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java index 539cb492640c5..add70825b1317 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java @@ -33,7 +33,6 @@ public class StreamData { private final TransportRequestHandler handler; private final TransportFrameDecoder frameDecoder; private final RpcResponseCallback rpcCallback; - private final ByteBuffer meta; private final long streamByteCount; private boolean hasCallback = false; @@ -41,12 +40,10 @@ public StreamData( TransportRequestHandler handler, TransportFrameDecoder frameDecoder, RpcResponseCallback rpcCallback, - ByteBuffer meta, long streamByteCount) { this.handler = handler; this.frameDecoder = frameDecoder; this.rpcCallback = rpcCallback; - this.meta = meta; this.streamByteCount = streamByteCount; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 30215c8590d27..e26b593d88906 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -218,7 +218,7 @@ public void onFailure(Throwable e) { channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); ByteBuffer meta = req.meta.nioByteBuffer(); StreamData streamData = new StreamData(TransportRequestHandler.this, frameDecoder, - callback, meta, req.bodyByteCount); + callback, req.bodyByteCount); rpcHandler.receive(reverseClient, meta, streamData, callback); if (!streamData.hasCallback()) { throw new RuntimeException("Destination did not register stream handler"); From 6c086c51873c72fa0cf9f373afd069ac63de3b75 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 4 Jun 2018 08:54:28 -0700 Subject: [PATCH 10/18] don't fail channel for failures during onComplete --- .../spark/network/server/RpcHandler.java | 10 ++++-- .../spark/network/server/StreamData.java | 15 +++++++-- .../spark/network/RpcIntegrationSuite.java | 33 ++++++++++++++++--- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 32e31fd2cd0d0..b1d4e746179fd
100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -39,9 +39,13 @@ public abstract class RpcHandler { * This method will not be called in parallel for a single TransportClient (i.e., channel). * * The rpc *might* include a data stream in streamData (e.g. for uploading a large - * amount of data which should not be buffered in memory here). Any errors while handling the - * streamData will lead to failing this entire connection -- all other in-flight rpcs will fail. - * If stream data is not null, you *must* call streamData.registerStreamCallback + * amount of data which should not be buffered in memory here). An error while reading data from + * the stream ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) + * will fail the entire channel. A failure in "post-processing" the stream in + * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an + * rpcFailure, but the channel will remain active. + * + * If streamData is not null, you *must* call streamData.registerStreamCallback * before this method returns. * * @param client A channel client which enables the handler to make requests back to the sender diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java index add70825b1317..6bd9514dc70bc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java @@ -55,7 +55,9 @@ public boolean hasCallback() { * Register callback to receive the streaming data. * * If an exception is thrown from the callback, it will be propagated back to the sender as an rpc - * failure. + * failure. A failure during onData will fail the entire channel (as we don't know + * what to do with the other data on the channel). A failure during onComplete will
*/ public void registerStreamCallback(String streamId, StreamCallback callback) throws IOException { if (hasCallback) { @@ -72,8 +74,15 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { - callback.onComplete(streamId); - rpcCallback.onSuccess(ByteBuffer.allocate(0)); + try { + callback.onComplete(streamId); + rpcCallback.onSuccess(ByteBuffer.allocate(0)); + } catch (Exception ex) { + IOException ioExc = new IOException("Failure post-processing complete stream; failing " + + "this rpc and leaving channel active"); + rpcCallback.onFailure(ioExc); + callback.onFailure(streamId, ioExc); + } } @Override diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index aca568fcc62bd..a6da658cc6cd6 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -104,8 +104,8 @@ private static void receiveStream(String msg, StreamData streamData) { case "no callback": // don't register anything here, check the rpc error response is appropriate break; - case "exception": - StreamCallback callback = new StreamCallback() { + case "exception-ondata": + StreamCallback onDataExcCallback = new StreamCallback() { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { throw new IOException("failed to read stream data!"); @@ -119,7 +119,24 @@ public void onComplete(String streamId) throws IOException { public void onFailure(String streamId, Throwable cause) throws IOException { } }; - streamData.registerStreamCallback(msg, callback); + streamData.registerStreamCallback(msg, onDataExcCallback); + break; + case "exception-oncomplete": + StreamCallback onCompleteExcCallback = new StreamCallback() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + } + + @Override + public void onComplete(String streamId) throws IOException { + throw new IOException("exception in onComplete"); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + }; + streamData.registerStreamCallback(msg, onCompleteExcCallback); break; case "multiple": VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); @@ -331,8 +348,16 @@ public void sendRpcWithStreamFailures() throws Exception { RpcResult multiCallbackResult = sendRpcWithStream("fail/multiple/smallBuffer", "smallBuffer"); assertErrorAndClosed(multiCallbackResult, "Cannot register more than one stream callback"); - RpcResult exceptionInCallbackResult = sendRpcWithStream("fail/exception/file", "smallBuffer"); + RpcResult exceptionInCallbackResult = + sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + // OTOH, if there is a failure during onComplete, the channel should still be fine + RpcResult exceptionInOnComplete = + sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer"); + assertErrorsContain(exceptionInOnComplete.errorMessages, + Sets.newHashSet("Failure post-processing")); + assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages); } private void assertErrorsContain(Set errors, Set contains) { From d357885eea86461ae29a9df66dc599406762fe0a Mon 
Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 12 Jun 2018 21:49:24 -0500 Subject: [PATCH 11/18] wip refactoring --- .../spark/network/crypto/AuthRpcHandler.java | 15 ++- .../spark/network/sasl/SaslRpcHandler.java | 13 ++- .../spark/network/server/NoOpRpcHandler.java | 1 - .../spark/network/server/RpcHandler.java | 51 ++++++--- .../spark/network/server/StreamData.java | 102 ------------------ .../server/TransportRequestHandler.java | 45 ++++++-- .../network/ChunkFetchIntegrationSuite.java | 2 - .../RequestTimeoutIntegrationSuite.java | 4 - .../spark/network/RpcIntegrationSuite.java | 47 ++++---- .../org/apache/spark/network/StreamSuite.java | 2 - .../network/crypto/AuthIntegrationSuite.java | 1 - .../spark/network/sasl/SparkSaslSuite.java | 3 +- .../shuffle/ExternalShuffleBlockHandler.java | 3 - 13 files changed, 118 insertions(+), 171 deletions(-) delete mode 100644 common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 2c44e9fa71ee1..7195414e0321f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,11 +29,11 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.TransportConf; @@ -84,10 +84,9 @@ class AuthRpcHandler extends RpcHandler { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { if (doDelegate) { - delegate.receive(client, message, streamData, callback); + delegate.receive(client, message, callback); return; } @@ -105,7 +104,7 @@ public void receive( delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder); message.position(position); message.limit(limit); - delegate.receive(client, message, streamData, callback); + delegate.receive(client, message, callback); doDelegate = true; } else { LOG.debug("Unexpected challenge message from client {}, closing channel.", @@ -154,6 +153,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallback receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 37887eba2f7e0..9e77e9daaff31 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,9 +28,9 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; 
+import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -80,11 +80,10 @@ public SaslRpcHandler( public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. - delegate.receive(client, message, streamData, callback); + delegate.receive(client, message, callback); return; } if (saslServer == null || !saslServer.isComplete()) { @@ -137,6 +136,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallback receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index d25a5a914696d..c11c40377603a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -34,7 +34,6 @@ public NoOpRpcHandler() { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index b1d4e746179fd..01737f8622af4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; /** @@ -36,32 +37,50 @@ public abstract class RpcHandler { * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * - * This method will not be called in parallel for a single TransportClient (i.e., channel). - * - * The rpc *might* included a data stream in streamData (eg. for uploading a large - * amount of data which should not be buffered in memory here). An error while reading data from - * the stream ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) - * will fail the entire channel. A failure in "post-processing" the stream in - * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an - * rpcFailure, but the channel will remain active. - * - * If streamData is not null, you *must* call streamData.registerStreamCallback - * before this method returns. + * Neither this method nor #receiveStream will not be called in parallel for a single + * TransportClient (i.e., channel). 
* * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. - * @param streamData StreamData if there is data which is meant to be read via a StreamCallback; - * otherwise it is null. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ public abstract void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback); + /** + * Receive a single RPC message which includes data that is to be received as a stream. Any + * exception thrown while in this method will be sent back to the client in string form as a + * standard RPC failure. + * + * Neither this method nor #receive will not be called in parallel for a single TransportClient + * (i.e., channel). + * + * An error while reading data from the stream + * ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) + * will fail the entire channel. A failure in "post-processing" the stream in + * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an + * rpcFailure, but the channel will remain active. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param messageHeader The serialized bytes of the header portion of the RPC. This is meant + * to be relatively small, and will be buffered entirely in memory, to + * facilitate how the streaming portion should be received. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + * @return a StreamCallback for handling the accompanying streaming data + */ + public StreamCallback receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. @@ -70,7 +89,7 @@ public abstract void receive( /** * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link #receive(TransportClient, ByteBuffer, StreamData, RpcResponseCallback)}" and log a + * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a * warning if any of the callback methods are called. * * @param client A channel client which enables the handler to make requests back to the sender * @@ -78,7 +97,7 @@ public abstract void receive( * @param message The serialized bytes of the RPC. */ public void receive(TransportClient client, ByteBuffer message) { - receive(client, message, null, ONE_WAY_CALLBACK); + receive(client, message, ONE_WAY_CALLBACK); } /** diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java deleted file mode 100644 index 6bd9514dc70bc..0000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamData.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.StreamInterceptor; -import org.apache.spark.network.util.TransportFrameDecoder; - -/** - * A holder for streamed data sent along with an RPC message. - */ -public class StreamData { - - private final TransportRequestHandler handler; - private final TransportFrameDecoder frameDecoder; - private final RpcResponseCallback rpcCallback; - private final long streamByteCount; - private boolean hasCallback = false; - - public StreamData( - TransportRequestHandler handler, - TransportFrameDecoder frameDecoder, - RpcResponseCallback rpcCallback, - long streamByteCount) { - this.handler = handler; - this.frameDecoder = frameDecoder; - this.rpcCallback = rpcCallback; - this.streamByteCount = streamByteCount; - } - - public boolean hasCallback() { - return hasCallback; - } - - /** - * Register callback to receive the streaming data. - * - * If an exception is thrown from the callback, it will be propogated back to the sender as an rpc - * failure. A failure during onData will fail the entire channel (as we don't know - * what to do with the other data on the channel). A failure during onComplete will - * result in an rpc failure, but the channel will remain active. - */ - public void registerStreamCallback(String streamId, StreamCallback callback) throws IOException { - if (hasCallback) { - throw new IllegalStateException("Cannot register more than one stream callback"); - } - hasCallback = true; - // the passed callback handles the actual data, but we need to also make sure we respond to the - // original rpc request. 
- StreamCallback wrappedCallback = new StreamCallback() { - @Override - public void onData(String streamId, ByteBuffer buf) throws IOException { - callback.onData(streamId, buf); - } - - @Override - public void onComplete(String streamId) throws IOException { - try { - callback.onComplete(streamId); - rpcCallback.onSuccess(ByteBuffer.allocate(0)); - } catch (Exception ex) { - IOException ioExc = new IOException("Failure post-processing complete stream; failing " + - "this rpc and leaving channel active"); - rpcCallback.onFailure(ioExc); - callback.onFailure(streamId, ioExc); - } - } - - @Override - public void onFailure(String streamId, Throwable cause) throws IOException { - rpcCallback.onFailure(new IOException("Destination failed while reading stream", cause)); - callback.onFailure(streamId, cause); - } - }; - if (streamByteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(handler, streamId, streamByteCount, - wrappedCallback); - frameDecoder.setInterceptor(interceptor); - } else { - wrappedCallback.onComplete(streamId); - } - } -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e26b593d88906..1b436a4f88d43 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -29,6 +30,8 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.StreamInterceptor; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.*; import org.apache.spark.network.util.TransportFrameDecoder; @@ -188,7 +191,7 @@ public void onFailure(Throwable e) { respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); } }; - rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), null, callback); + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), callback); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); @@ -217,11 +220,41 @@ public void onFailure(Throwable e) { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); ByteBuffer meta = req.meta.nioByteBuffer(); - StreamData streamData = new StreamData(TransportRequestHandler.this, frameDecoder, - callback, req.bodyByteCount); - rpcHandler.receive(reverseClient, meta, streamData, callback); - if (!streamData.hasCallback()) { - throw new RuntimeException("Destination did not register stream handler"); + // TODO streamId? 
+ String streamId = null; + StreamCallback streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + // TODO do something with the streamHandler + StreamCallback wrappedCallback = new StreamCallback() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + streamHandler.onData(streamId, buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + try { + streamHandler.onComplete(streamId); + callback.onSuccess(ByteBuffer.allocate(0)); + } catch (Exception ex) { + IOException ioExc = new IOException("Failure post-processing complete stream; failing " + + "this rpc and leaving channel active"); + callback.onFailure(ioExc); + streamHandler.onFailure(streamId, ioExc); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + callback.onFailure(new IOException("Destination failed while reading stream", cause)); + streamHandler.onFailure(streamId, cause); + } + }; + if (req.bodyByteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, streamId, req.bodyByteCount, + wrappedCallback); + frameDecoder.setInterceptor(interceptor); + } else { + wrappedCallback.onComplete(streamId); } } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 2bbe07036e392..824482af08dd4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -46,7 +46,6 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.MapConfigProvider; @@ -109,7 +108,6 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { public void receive( TransportClient client, ByteBuffer message, - StreamData data, RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 47dcf75d620c4..c0724e018263f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -25,7 +25,6 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.MapConfigProvider; @@ -92,7 +91,6 @@ public void timeoutInactiveRequests() throws Exception { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { try { semaphore.acquire(); @@ -140,7 +138,6 @@ public 
void timeoutCleanlyClosesClient() throws Exception { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { try { semaphore.acquire(); @@ -197,7 +194,6 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index a6da658cc6cd6..1d7c31b157698 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -65,23 +65,26 @@ public static void setUp() throws Exception { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { String msg = JavaUtils.bytesToString(message); - if (streamData != null) { - receiveStream(msg, streamData); - } else { - String[] parts = msg.split("/"); - if (parts[0].equals("hello")) { - callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); - } else if (parts[0].equals("return error")) { - callback.onFailure(new RuntimeException("Returned: " + parts[1])); - } else if (parts[0].equals("throw error")) { - throw new RuntimeException("Thrown: " + parts[1]); - } + String[] parts = msg.split("/"); + if (parts[0].equals("hello")) { + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); + } else if (parts[0].equals("return error")) { + callback.onFailure(new RuntimeException("Returned: " + parts[1])); + } else if (parts[0].equals("throw error")) { + throw new RuntimeException("Thrown: " + parts[1]); } } + @Override + public StreamCallback receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + return null; + } + @Override public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs.add(JavaUtils.bytesToString(message)); @@ -96,7 +99,7 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } - private static void receiveStream(String msg, StreamData streamData) { + private static StreamCallback receiveStream(String msg) { try { if (msg.startsWith("fail/")) { String[] parts = msg.split("/"); @@ -105,7 +108,7 @@ private static void receiveStream(String msg, StreamData streamData) { // don't register anything here, check the rpc error response is appropriate break; case "exception-ondata": - StreamCallback onDataExcCallback = new StreamCallback() { + return new StreamCallback() { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { throw new IOException("failed to read stream data!"); @@ -119,10 +122,8 @@ public void onComplete(String streamId) throws IOException { public void onFailure(String streamId, Throwable cause) throws IOException { } }; - streamData.registerStreamCallback(msg, onDataExcCallback); - break; case "exception-oncomplete": - StreamCallback onCompleteExcCallback = new StreamCallback() { + return new StreamCallback() { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { } @@ -136,22 +137,18 @@ public void onComplete(String streamId) throws IOException { public void onFailure(String streamId, Throwable cause) throws IOException { } }; - 
streamData.registerStreamCallback(msg, onCompleteExcCallback); - break; - case "multiple": - VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); - streamData.registerStreamCallback(msg, streamCallback); - streamData.registerStreamCallback(msg, streamCallback); - break; + default: + throw new IllegalArgumentException("unexpected msg: " + msg); } } else { VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); - streamData.registerStreamCallback(msg, streamCallback); streamCallbacks.put(msg, streamCallback); + return streamCallback; } } catch (IOException e) { throw new RuntimeException(e); } + throw new RuntimeException("unreachable"); } @AfterClass diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f6bd006ef8f56..1241046b5e4fb 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -42,7 +42,6 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.MapConfigProvider; @@ -85,7 +84,6 @@ public ManagedBuffer openStream(String streamId) { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index c2cfe1d997586..72cbe146f522e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -141,7 +141,6 @@ private class AuthTestCtx { public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { assertEquals("Ping", JavaUtils.bytesToString(message)); callback.onSuccess(JavaUtils.stringToBytes("Pong")); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 24532dd152eb3..130396b09d6e8 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -142,8 +142,7 @@ private static void testBasicSasl(boolean encrypt) throws Throwable { return null; }) .when(rpcHandler) - .receive(any(TransportClient.class), any(ByteBuffer.class), any(StreamData.class), - any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 4f3c174b8f646..d8217fe47702d 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -38,7 +38,6 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamData; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.apache.spark.network.shuffle.protocol.*; @@ -80,9 +79,7 @@ public ExternalShuffleBlockHandler( public void receive( TransportClient client, ByteBuffer message, - StreamData streamData, RpcResponseCallback callback) { - assert(streamData == null); BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } From 93a5adf8cd0b908ee9b0cc1392a3cfebcb4207d8 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 13 Jun 2018 12:03:02 -0500 Subject: [PATCH 12/18] fixes --- .../network/client/StreamCallbackWithID.java | 22 +++++++++ .../spark/network/crypto/AuthRpcHandler.java | 4 +- .../spark/network/sasl/SaslRpcHandler.java | 4 +- .../spark/network/server/RpcHandler.java | 4 +- .../server/TransportRequestHandler.java | 23 +++++----- .../spark/network/RpcIntegrationSuite.java | 45 ++++++++++--------- .../spark/network/sasl/SparkSaslSuite.java | 2 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 23 ++++------ 8 files changed, 71 insertions(+), 56 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java new file mode 100644 index 0000000000000..ae41618002b8e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.client; + +public interface StreamCallbackWithID extends StreamCallback { + public String getID(); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 7195414e0321f..435e84e0e6bd3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,7 +29,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; @@ -154,7 +154,7 @@ public void receive(TransportClient client, ByteBuffer message) { } @Override - public StreamCallback receiveStream( + public StreamCallbackWithID receiveStream( TransportClient client, ByteBuffer message, RpcResponseCallback callback) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 9e77e9daaff31..722ebc4c0c7ca 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,7 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -137,7 +137,7 @@ public void receive(TransportClient client, ByteBuffer message) { } @Override - public StreamCallback receiveStream( + public StreamCallbackWithID receiveStream( TransportClient client, ByteBuffer message, RpcResponseCallback callback) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 01737f8622af4..a516baedde20c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,7 +23,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; /** @@ -74,7 +74,7 @@ public abstract void receive( * RPC. 
* @return a StreamCallback for handling the accompanying streaming data */ - public StreamCallback receiveStream( + public StreamCallbackWithID receiveStream( TransportClient client, ByteBuffer messageHeader, RpcResponseCallback callback) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 1b436a4f88d43..0e0b1fca4d70f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -24,15 +24,12 @@ import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; +import org.apache.spark.network.client.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.StreamInterceptor; -import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.*; import org.apache.spark.network.util.TransportFrameDecoder; @@ -220,11 +217,8 @@ public void onFailure(Throwable e) { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); ByteBuffer meta = req.meta.nioByteBuffer(); - // TODO streamId? - String streamId = null; - StreamCallback streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); - // TODO do something with the streamHandler - StreamCallback wrappedCallback = new StreamCallback() { + StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { streamHandler.onData(streamId, buf); @@ -248,13 +242,18 @@ public void onFailure(String streamId, Throwable cause) throws IOException { callback.onFailure(new IOException("Destination failed while reading stream", cause)); streamHandler.onFailure(streamId, cause); } + + @Override + public String getID() { + return streamHandler.getID(); + } }; if (req.bodyByteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, streamId, req.bodyByteCount, - wrappedCallback); + StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), + req.bodyByteCount, wrappedCallback); frameDecoder.setInterceptor(interceptor); } else { - wrappedCallback.onComplete(streamId); + wrappedCallback.onComplete(wrappedCallback.getID()); } } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 1d7c31b157698..69bbaafa187b1 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -36,10 +36,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import 
org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.client.*; import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -78,11 +75,11 @@ public void receive( } @Override - public StreamCallback receiveStream( + public StreamCallbackWithID receiveStream( TransportClient client, ByteBuffer messageHeader, RpcResponseCallback callback) { - return null; + return receiveStreamHelper(JavaUtils.bytesToString(messageHeader)); } @Override @@ -99,16 +96,13 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } - private static StreamCallback receiveStream(String msg) { + private static StreamCallbackWithID receiveStreamHelper(String msg) { try { if (msg.startsWith("fail/")) { String[] parts = msg.split("/"); switch (parts[1]) { - case "no callback": - // don't register anything here, check the rpc error response is appropriate - break; case "exception-ondata": - return new StreamCallback() { + return new StreamCallbackWithID() { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { throw new IOException("failed to read stream data!"); @@ -121,9 +115,14 @@ public void onComplete(String streamId) throws IOException { @Override public void onFailure(String streamId, Throwable cause) throws IOException { } + + @Override + public String getID() { + return msg; + } }; case "exception-oncomplete": - return new StreamCallback() { + return new StreamCallbackWithID() { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { } @@ -136,6 +135,11 @@ public void onComplete(String streamId) throws IOException { @Override public void onFailure(String streamId, Throwable cause) throws IOException { } + + @Override + public String getID() { + return msg; + } }; default: throw new IllegalArgumentException("unexpected msg: " + msg); @@ -148,7 +152,6 @@ public void onFailure(String streamId, Throwable cause) throws IOException { } catch (IOException e) { throw new RuntimeException(e); } - throw new RuntimeException("unreachable"); } @AfterClass @@ -330,7 +333,7 @@ public void sendRpcWithStreamConcurrently() throws Exception { streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]; } RpcResult res = sendRpcWithStream(streams); - assertEquals(res.successMessages, Sets.newHashSet(StreamTestHelper.STREAMS)); + assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages); assertTrue(res.errorMessages.isEmpty()); } @@ -338,13 +341,6 @@ public void sendRpcWithStreamConcurrently() throws Exception { public void sendRpcWithStreamFailures() throws Exception { // when there is a failure reading stream data, we don't try to keep the channel usable, // just send back a decent error msg. 
- RpcResult noCallbackResult = sendRpcWithStream("fail/no callback/smallBuffer", "smallBuffer"); - assertErrorAndClosed(noCallbackResult, "Destination did not register stream handler"); - - - RpcResult multiCallbackResult = sendRpcWithStream("fail/multiple/smallBuffer", "smallBuffer"); - assertErrorAndClosed(multiCallbackResult, "Cannot register more than one stream callback"); - RpcResult exceptionInCallbackResult = sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); @@ -411,7 +407,7 @@ private Pair<Set<String>, Set<String>> checkErrorsContain( return new ImmutablePair<>(remainingErrors, notFound); } - private static class VerifyingStreamCallback implements StreamCallback { + private static class VerifyingStreamCallback implements StreamCallbackWithID { final String streamId; final StreamSuite.TestCallback helper; final OutputStream out; @@ -461,5 +457,10 @@ public void onComplete(String streamId) throws IOException { public void onFailure(String streamId, Throwable cause) throws IOException { helper.onFailure(streamId, cause); } + + @Override + public String getID() { + return streamId; + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 130396b09d6e8..5f01ac1ea2768 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -136,7 +136,7 @@ private static void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(invocation -> { ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; - RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[3]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; assertEquals("Ping", JavaUtils.bytesToString(message)); cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 1c36c7a620c49..9a5d041da1ca6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -546,14 +546,11 @@ private[netty] class NettyRpcEndpointRef( * `NettyRpcEnv`. * @param receiver the receiver of this message. * @param content the message content. - * @param streamData optional stream of data. May be null. If present, - * streamData.registerStreamCallback *must* be called. */ private[netty] class RequestMessage( val senderAddress: RpcAddress, val receiver: NettyRpcEndpointRef, - val content: Any, - val streamData: StreamData) { + val content: Any) { /** Manually serialize [[RequestMessage]] to minimize the size. */ def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = { @@ -602,8 +599,7 @@ private[netty] object RequestMessage { def apply( nettyEnv: NettyRpcEnv, client: TransportClient, - bytes: ByteBuffer, - streamData: StreamData): RequestMessage = { + bytes: ByteBuffer): RequestMessage = { val bis = new ByteBufferInputStream(bytes) val in = new DataInputStream(bis) try { @@ -615,8 +611,7 @@ private[netty] object RequestMessage { senderAddress, ref, // The remaining bytes in `bytes` are the message content.
- nettyEnv.deserialize(client, bytes), - streamData) + nettyEnv.deserialize(client, bytes)) } finally { in.close() } @@ -651,30 +646,28 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, message: ByteBuffer, - streamData: StreamData, callback: RpcResponseCallback): Unit = { - val messageToDispatch = internalReceive(client, message, streamData) + val messageToDispatch = internalReceive(client, message) dispatcher.postRemoteMessage(messageToDispatch, callback) } override def receive( client: TransportClient, message: ByteBuffer): Unit = { - val messageToDispatch = internalReceive(client, message, null) + val messageToDispatch = internalReceive(client, message) dispatcher.postOneWayMessage(messageToDispatch) } private def internalReceive( client: TransportClient, - message: ByteBuffer, - streamData: StreamData): RequestMessage = { + message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostString, addr.getPort) - val requestMessage = RequestMessage(nettyEnv, client, message, streamData) + val requestMessage = RequestMessage(nettyEnv, client, message) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. - new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, streamData) + new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) } else { // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for // the listening address From cf991a9d1923d6639315f5b2a75642a316d394e9 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 13 Jun 2018 12:08:05 -0500 Subject: [PATCH 13/18] cleanup --- .../spark/network/client/TransportClient.java | 1 - .../apache/spark/network/crypto/AuthRpcHandler.java | 5 +---- .../apache/spark/network/sasl/SaslRpcHandler.java | 5 +---- .../apache/spark/network/server/NoOpRpcHandler.java | 5 +---- .../org/apache/spark/network/server/RpcHandler.java | 8 ++++---- .../network/server/TransportRequestHandler.java | 8 ++++---- .../spark/network/crypto/AuthIntegrationSuite.java | 5 ++++- .../apache/spark/network/sasl/SparkSaslSuite.java | 5 ++++- .../shuffle/ExternalShuffleBlockHandler.java | 5 +---- .../spark/network/sasl/SaslIntegrationSuite.java | 12 ++++++------ .../shuffle/ExternalShuffleBlockHandlerSuite.java | 8 ++++---- .../spark/network/netty/NettyBlockRpcServer.scala | 3 +-- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 13 ++++--------- .../apache/spark/rpc/netty/NettyRpcEnvSuite.scala | 12 ++++++------ .../spark/rpc/netty/NettyRpcHandlerSuite.scala | 2 +- .../apache/spark/storage/BlockManagerSuite.scala | 3 +-- 16 files changed, 43 insertions(+), 57 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 9daa64a66a9b9..b22c5298cfdea 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -281,7 +281,6 @@ void handleFailure(String errorMsg, Throwable cause) { } } - /** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. 
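As an aside, here is a minimal sketch of a server-side consumer of the receiveStream API that the preceding patches introduce. It is illustrative only: the UploadHandler name, the header format, and the temp-file destination are assumptions for demonstration, not code from this series.

// Hypothetical RpcHandler that spools an uploaded stream to a temporary file.
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;

public class UploadHandler extends RpcHandler {
  private final StreamManager streamManager = new OneForOneStreamManager();

  @Override
  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
    // Plain RPCs, with no attached stream, still arrive here.
    callback.onSuccess(ByteBuffer.allocate(0));
  }

  @Override
  public StreamCallbackWithID receiveStream(
      TransportClient client,
      ByteBuffer messageHeader,
      RpcResponseCallback callback) {
    // The header was buffered in full before this call; here it just names the stream.
    final String streamId = JavaUtils.bytesToString(messageHeader);
    final FileOutputStream out = openTempFile();
    return new StreamCallbackWithID() {
      @Override
      public void onData(String streamId, ByteBuffer buf) throws IOException {
        // Called as frames arrive; an exception here fails the whole channel.
        out.getChannel().write(buf);
      }

      @Override
      public void onComplete(String streamId) throws IOException {
        // An exception here fails only this RPC; the channel stays active. The
        // transport layer responds to the original RPC itself, so the handler
        // does not invoke the RpcResponseCallback directly.
        out.close();
      }

      @Override
      public void onFailure(String streamId, Throwable cause) throws IOException {
        out.close();
      }

      @Override
      public String getID() {
        return streamId;
      }
    };
  }

  @Override
  public StreamManager getStreamManager() {
    return streamManager;
  }

  private static FileOutputStream openTempFile() {
    try {
      return new FileOutputStream(File.createTempFile("upload-", ".bin"));
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }
}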
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 435e84e0e6bd3..fb44dbbb0953b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -81,10 +81,7 @@ class AuthRpcHandler extends RpcHandler { } @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (doDelegate) { delegate.receive(client, message, callback); return; diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 722ebc4c0c7ca..355a3def8cc22 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -77,10 +77,7 @@ public SaslRpcHandler( } @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index c11c40377603a..6ed61da5c7eff 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -31,10 +31,7 @@ public NoOpRpcHandler() { } @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index a516baedde20c..38569baf82bce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -37,7 +37,7 @@ public abstract class RpcHandler { * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * - * Neither this method nor #receiveStream will not be called in parallel for a single + * Neither this method nor #receiveStream will be called in parallel for a single * TransportClient (i.e., channel). * * @param client A channel client which enables the handler to make requests back to the sender @@ -56,7 +56,7 @@ public abstract void receive( * exception thrown while in this method will be sent back to the client in string form as a * standard RPC failure. 
* - * Neither this method nor #receive will not be called in parallel for a single TransportClient + * Neither this method nor #receive will be called in parallel for a single TransportClient * (i.e., channel). * * An error while reading data from the stream @@ -89,8 +89,8 @@ public StreamCallbackWithID receiveStream( /** * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a - * warning if any of the callback methods are called. + * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if + * any of the callback methods are called. * * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. This will always be the exact same object for a particular channel. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 0e0b1fca4d70f..d16dcb6cc5d41 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -24,12 +24,12 @@ import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; -import org.apache.spark.network.client.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.*; import org.apache.spark.network.protocol.*; import org.apache.spark.network.util.TransportFrameDecoder; @@ -43,6 +43,7 @@ * The messages should have been processed by the pipeline setup by {@link TransportServer}. */ public class TransportRequestHandler extends MessageHandler<RequestMessage> { + private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); /** The Netty channel that this handler is associated with.
*/ @@ -177,7 +178,7 @@ private void processStreamRequest(final StreamRequest req) { private void processRpcRequest(final RpcRequest req) { try { - RpcResponseCallback callback = new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); @@ -187,8 +188,7 @@ public void onSuccess(ByteBuffer response) { public void onFailure(Throwable e) { respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); } - }; - rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), callback); + }); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 72cbe146f522e..8751944a1c2a3 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -37,7 +37,10 @@ import org.apache.spark.network.sasl.SaslRpcHandler; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.server.*; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 5f01ac1ea2768..6f15718bd8705 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -54,7 +54,10 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.server.*; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 9c213bb4d6302..098fa7974b87b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -76,10 +76,7 @@ public ExternalShuffleBlockHandler( } @Override - public 
void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 69c3889abeac8..02e6eb3a4467e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -39,7 +39,11 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.*; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.BlockFetchingListener; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver; @@ -260,11 +264,7 @@ public void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive( - TransportClient client, - ByteBuffer message, - StreamData streamData, - RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7f0279558221e..7846b71d5a8b1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -63,7 +63,7 @@ public void testRegisterExecutor() { ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); - handler.receive(client, registerMessage, null, callback); + handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); @@ -88,7 +88,7 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }) .toByteBuffer(); - handler.receive(client, openBlocks, null, callback); + handler.receive(client, openBlocks, callback); ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); @@ -129,7 +129,7 @@ public void testBadMessages() { ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { - handler.receive(client, unserializableMsg,
null, callback); + handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); } catch (Exception e) { // pass @@ -138,7 +138,7 @@ public void testBadMessages() { ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); try { - handler.receive(client, unexpectedMsg, null, callback); + handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); } catch (UnsupportedOperationException e) { // pass diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 1397ae24db15a..eb4cf94164fd4 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.NioManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} -import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamData, StreamManager} +import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -50,7 +50,6 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, rpcMessage: ByteBuffer, - streamData: StreamData, responseContext: RpcResponseCallback): Unit = { val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 9a5d041da1ca6..a2936d6ad539c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -520,12 +520,12 @@ private[netty] class NettyRpcEndpointRef( override def name: String = endpointAddress.name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message, null), timeout) + nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout) } override def send(message: Any): Unit = { require(message != null, "Message is null") - nettyEnv.send(new RequestMessage(nettyEnv.address, this, message, null)) + nettyEnv.send(new RequestMessage(nettyEnv.address, this, message)) } override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})" @@ -596,10 +596,7 @@ private[netty] object RequestMessage { } } - def apply( - nettyEnv: NettyRpcEnv, - client: TransportClient, - bytes: ByteBuffer): RequestMessage = { + def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = { val bis = new ByteBufferInputStream(bytes) val in = new DataInputStream(bis) try { @@ -658,9 +655,7 @@ private[netty] class NettyRpcHandler( dispatcher.postOneWayMessage(messageToDispatch) } - private def internalReceive( - client: TransportClient, - message: ByteBuffer): RequestMessage = { + private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != 
null) val clientAddr = RpcAddress(addr.getHostString, addr.getPort) diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index 46dccc2c88ee4..f9481f875d439 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -69,19 +69,19 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar { val receiverAddress = RpcEndpointAddress("localhost", 54321, "test") val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv) - val msg = new RequestMessage(senderAddress, receiver, "foo", null) + val msg = new RequestMessage(senderAddress, receiver, "foo") assertRequestMessageEquals( msg, - RequestMessage(nettyEnv, client, msg.serialize(nettyEnv), null)) + RequestMessage(nettyEnv, client, msg.serialize(nettyEnv))) - val msg2 = new RequestMessage(null, receiver, "foo", null) + val msg2 = new RequestMessage(null, receiver, "foo") assertRequestMessageEquals( msg2, - RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv), null)) + RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv))) - val msg3 = new RequestMessage(senderAddress, receiver, null, null) + val msg3 = new RequestMessage(senderAddress, receiver, null) assertRequestMessageEquals( msg3, - RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv), null)) + RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv))) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 138e84ae7616a..a71d8726e7066 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -34,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) - .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null, null)) + .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f334e0a67ef8f..b19d8ebf72c61 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.network.{BlockDataManager, BlockTransferService, Transpo import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} -import org.apache.spark.network.server.{NoOpRpcHandler, StreamData, TransportServer, TransportServerBootstrap} +import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.network.util.TransportConf @@ -1341,7 +1341,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with 
BeforeAndAfterE override def receive( client: TransportClient, message: ByteBuffer, - streamData: StreamData, callback: RpcResponseCallback): Unit = { val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) msgObj match { From 8a18da511a8053fc4dcf6529f49333d71bd6277d Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 13 Jun 2018 14:44:48 -0500 Subject: [PATCH 14/18] fix --- .../org/apache/spark/network/client/StreamCallbackWithID.java | 2 +- .../apache/spark/network/server/TransportRequestHandler.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java index ae41618002b8e..bd173b653e33e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java @@ -18,5 +18,5 @@ package org.apache.spark.network.client; public interface StreamCallbackWithID extends StreamCallback { - public String getID(); + String getID(); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index d16dcb6cc5d41..7935f95247d8c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -230,8 +230,8 @@ public void onComplete(String streamId) throws IOException { streamHandler.onComplete(streamId); callback.onSuccess(ByteBuffer.allocate(0)); } catch (Exception ex) { - IOException ioExc = new IOException("Failure post-processing complete stream; failing " + - "this rpc and leaving channel active"); + IOException ioExc = new IOException("Failure post-processing complete stream;" + + " failing this rpc and leaving channel active"); callback.onFailure(ioExc); streamHandler.onFailure(streamId, ioExc); } From 1a222aa77d2a31fd3b3ffe21edfc69ab99e80806 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 14 Jun 2018 10:19:08 -0500 Subject: [PATCH 15/18] factor out requestId() --- .../org/apache/spark/network/client/TransportClient.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index b22c5298cfdea..fa882ca37f903 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -192,7 +192,7 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) @@ -221,7 +221,7 @@ public long uploadStream( logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); 
channel.writeAndFlush(new UploadStream(requestId, meta, data)) @@ -264,6 +264,10 @@ public void operationComplete(Future future) throws Exception { void handleFailure(String errorMsg, Throwable cause) throws Exception {} } + private static long requestId() { + return Math.abs(UUID.randomUUID().getLeastSignificantBits()); + } + private class RpcChannelListener extends StdChannelListener { final long rpcRequestId; final RpcResponseCallback callback; From ea4a1f5325495a2b611f67e01e8a86953361b1aa Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 15 Jun 2018 11:27:28 -0500 Subject: [PATCH 16/18] review feedback --- .../spark/network/client/TransportClient.java | 154 +++++++++--------- .../server/TransportRequestHandler.java | 2 +- .../spark/network/RpcIntegrationSuite.java | 9 +- .../spark/network/StreamTestHelper.java | 18 +- 4 files changed, 91 insertions(+), 92 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index fa882ca37f903..325225dc0ea2c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -133,22 +133,21 @@ public void fetchChunk( long streamId, int chunkIndex, ChunkReceivedCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + StdChannelListener listener = new StdChannelListener(streamChunkId) { + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeFetchRequest(streamChunkId); + callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); + } + }; handler.addFetchRequest(streamChunkId, callback); - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)) - .addListener( new StdChannelListener(startTime, streamChunkId) { - @Override - void handleFailure(String errorMsg, Throwable cause) { - handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); - } - }); + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener); } /** @@ -158,7 +157,12 @@ void handleFailure(String errorMsg, Throwable cause) { * @param callback Object to call with the stream data. */ public void stream(String streamId, StreamCallback callback) { - long startTime = System.currentTimeMillis(); + StdChannelListener listener = new StdChannelListener(streamId) { + @Override + void handleFailure(String errorMsg, Throwable cause) throws Exception { + callback.onFailure(streamId, new IOException(errorMsg, cause)); + } + }; if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); } @@ -168,13 +172,7 @@ public void stream(String streamId, StreamCallback callback) { // when responses arrive. 
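
Aside — the hunks above converge on one pattern: mint the request id with the shared requestId() helper, register the response callback, and construct the channel listener before the write is flushed, so the failure path is already wired up even if the flush fails immediately. A minimal sketch of that ordering, under hypothetical names (RequestTracker, track, and fail are illustrative only, not classes in this patch):

    import java.util.UUID;
    import java.util.concurrent.ConcurrentHashMap;

    class RequestTracker {
      private final ConcurrentHashMap<Long, Runnable> outstanding = new ConcurrentHashMap<>();

      // Same scheme the patch factors into requestId(): a random long made
      // non-negative with Math.abs (ignoring the Long.MIN_VALUE corner case).
      static long requestId() {
        return Math.abs(UUID.randomUUID().getLeastSignificantBits());
      }

      // Register the failure callback before any write is started, so an
      // immediate flush failure still finds its handler.
      long track(Runnable onFailure) {
        long id = requestId();
        outstanding.put(id, onFailure);
        return id;
      }

      // Invoked from the channel listener when the write fails.
      void fail(long id) {
        Runnable cb = outstanding.remove(id);
        if (cb != null) {
          cb.run();
        }
      }
    }

A caller would do roughly: long id = tracker.track(() -> callback.onFailure(...)); channel.writeAndFlush(msg).addListener(f -> { if (!f.isSuccess()) tracker.fail(id); });
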
synchronized (this) { handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)) - .addListener(new StdChannelListener(startTime, streamId) { - @Override - void handleFailure(String errorMsg, Throwable cause) throws Exception { - callback.onFailure(streamId, new IOException(errorMsg, cause)); - } - }); + channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener); } } @@ -187,7 +185,6 @@ void handleFailure(String errorMsg, Throwable cause) throws Exception { * @return The RPC's id. */ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } @@ -195,14 +192,15 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { long requestId = requestId(); handler.addRpcRequest(requestId, callback); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) - .addListener(new RpcChannelListener(startTime, requestId, callback)); + .addListener(listener); return requestId; } /** - * Send data to the remote end as a stream. This differs from stream() in that this is a request + * Send data to the remote end as a stream. This differs from stream() in that this is a request * to *send* data to the remote end, not to receive it from the remote. * * @param meta meta data associated with the stream, which will be read completely on the @@ -216,7 +214,6 @@ public long uploadStream( ManagedBuffer meta, ManagedBuffer data, RpcResponseCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } @@ -224,67 +221,12 @@ public long uploadStream( long requestId = requestId(); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new UploadStream(requestId, meta, data)) - .addListener(new RpcChannelListener(startTime, requestId, callback)); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener); return requestId; } - private class StdChannelListener - implements GenericFutureListener> { - final long startTime; - final Object requestId; - - StdChannelListener(long startTime, Object requestId) { - this.startTime = startTime; - this.requestId = requestId; - } - - @Override - public void operationComplete(Future future) throws Exception { - if (future.isSuccess()) { - if (logger.isTraceEnabled()) { - long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - handleFailure(errorMsg, future.cause()); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - } - - void handleFailure(String errorMsg, Throwable cause) throws Exception {} - } - - private static long requestId() { - return Math.abs(UUID.randomUUID().getLeastSignificantBits()); - } - - private class RpcChannelListener extends StdChannelListener { - final long rpcRequestId; - final RpcResponseCallback callback; - - RpcChannelListener(long startTime, long 
rpcRequestId, RpcResponseCallback callback) { - super(startTime, "RPC " + rpcRequestId); - this.rpcRequestId = rpcRequestId; - this.callback = callback; - } - - @Override - void handleFailure(String errorMsg, Throwable cause) { - handler.removeRpcRequest(rpcRequestId); - callback.onFailure(new IOException(errorMsg, cause)); - } - } - /** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. @@ -360,4 +302,60 @@ public String toString() { .add("isActive", isActive()) .toString(); } + + private static long requestId() { + return Math.abs(UUID.randomUUID().getLeastSignificantBits()); + } + + private class StdChannelListener + implements GenericFutureListener> { + final long startTime; + final Object requestId; + + StdChannelListener(Object requestId) { + this.startTime = System.currentTimeMillis(); + this.requestId = requestId; + } + + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + if (logger.isTraceEnabled()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + handleFailure(errorMsg, future.cause()); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + + void handleFailure(String errorMsg, Throwable cause) throws Exception {} + } + + private class RpcChannelListener extends StdChannelListener { + final long rpcRequestId; + final RpcResponseCallback callback; + + RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { + super("RPC " + rpcRequestId); + this.rpcRequestId = rpcRequestId; + this.callback = callback; + } + + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeRpcRequest(rpcRequestId); + callback.onFailure(new IOException(errorMsg, cause)); + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 7935f95247d8c..2a656fd840ca4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -260,7 +260,7 @@ public String getID() { respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); // We choose to totally fail the channel, rather than trying to recover as we do in other // cases. We don't know how many bytes of the stream the client has already sent for the - // stream, its not worth trying to recover. + // stream, it's not worth trying to recover. 
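
The comment fixed above is worth dwelling on: when an UploadStream request fails server-side, the client may already have written an unknown number of stream bytes into the connection, and the frame decoder would misread any leftover bytes as the start of the next message. Hence the handler responds with an RpcFailure and then deliberately kills the channel. A hedged sketch of that policy in isolation — StreamFailurePolicy and failChannel are hypothetical names; only the fireExceptionCaught call mirrors the patch:

    import io.netty.channel.Channel;

    final class StreamFailurePolicy {
      private StreamFailurePolicy() {}

      // After a stream upload goes wrong mid-flight, resynchronizing the
      // decoder on unaccounted-for bytes is not feasible; propagating the
      // exception through the pipeline lets the transport's exception
      // handler close the connection cleanly.
      static void failChannel(Channel channel, Exception cause) {
        channel.pipeline().fireExceptionCaught(cause);
      }
    }
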
channel.pipeline().fireExceptionCaught(e); } finally { req.meta.release(); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 69bbaafa187b1..fa9ffc009540b 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -214,6 +214,10 @@ private RpcResult sendRpcWithStream(String... streams) throws Exception { ManagedBuffer data = testData.openStream(conf, streamName); client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem)); } + + if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } streamCallbacks.values().forEach(streamCallback -> { try { streamCallback.waitForCompletionAndVerify(TimeUnit.SECONDS.toMillis(5)); @@ -221,11 +225,6 @@ private RpcResult sendRpcWithStream(String... streams) throws Exception { throw new RuntimeException(e); } }); - - - if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); - } client.close(); return res; } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java index 5d581b80afd9e..0f5c82c9e9b1f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.FileOutputStream; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.Random; @@ -27,17 +28,18 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; class StreamTestHelper { static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; final File testFile; - File tempDir; + final File tempDir; - ByteBuffer emptyBuffer; - ByteBuffer smallBuffer; - ByteBuffer largeBuffer; + final ByteBuffer emptyBuffer; + final ByteBuffer smallBuffer; + final ByteBuffer largeBuffer; private static ByteBuffer createBuffer(int bufSize) { ByteBuffer buf = ByteBuffer.allocate(bufSize); @@ -90,13 +92,13 @@ public ManagedBuffer openStream(TransportConf conf, String streamId) { } } - void cleanup() { if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); + try { + JavaUtils.deleteRecursively(tempDir); + } catch (IOException io) { + throw new RuntimeException(io); } - tempDir.delete(); } } } From fd62f615369e287d6deb707d6b0bfa11cfead2fe Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 19 Jun 2018 15:11:36 -0500 Subject: [PATCH 17/18] review feedback --- .../network/server/TransportRequestHandler.java | 3 +++ .../apache/spark/network/RpcIntegrationSuite.java | 12 ++++++++---- .../java/org/apache/spark/network/StreamSuite.java | 4 ---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 2a656fd840ca4..e1d7b2dbff60f 100644 --- 
a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -218,6 +218,9 @@ public void onFailure(Throwable e) { channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); ByteBuffer meta = req.meta.nioByteBuffer(); StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + if (streamHandler == null) { + throw new NullPointerException("rpcHandler returned a null streamHandler"); + } StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index fa9ffc009540b..4fa3fc8bc1870 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -141,6 +141,8 @@ public String getID() { return msg; } }; + case "null": + return null; default: throw new IllegalArgumentException("unexpected msg: " + msg); } @@ -220,7 +222,7 @@ private RpcResult sendRpcWithStream(String... streams) throws Exception { } streamCallbacks.values().forEach(streamCallback -> { try { - streamCallback.waitForCompletionAndVerify(TimeUnit.SECONDS.toMillis(5)); + streamCallback.verify(); } catch (IOException e) { throw new RuntimeException(e); } @@ -344,6 +346,10 @@ public void sendRpcWithStreamFailures() throws Exception { sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + RpcResult nullStreamHandler = + sendRpcWithStream("fail/null/smallBuffer", "smallBuffer"); + assertErrorAndClosed(nullStreamHandler, "Destination failed while reading stream"); + // OTOH, if there is a failure during onComplete, the channel should still be fine RpcResult exceptionInOnComplete = sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer"); @@ -423,8 +429,7 @@ private static class VerifyingStreamCallback implements StreamCallbackWithID { helper = new StreamSuite.TestCallback(out); } - void waitForCompletionAndVerify(long timeoutMs) throws IOException { - helper.waitForCompletion(timeoutMs); + void verify() throws IOException { if (streamId.equals("file")) { assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); } else { @@ -438,7 +443,6 @@ void waitForCompletionAndVerify(long timeoutMs) throws IOException { base.get(expected); assertEquals(expected.length, result.length); assertTrue("buffers don't match", Arrays.equals(expected, result)); - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index 1241046b5e4fb..f3050cb79cdfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -252,9 +252,6 @@ public void check() throws Throwable { throw error; } } - - - } static class TestCallback implements StreamCallback { @@ -309,7 +306,6 @@ void waitForCompletion(long timeoutMs) { assertTrue("Timed out waiting for stream.", completed); assertNull(error); } - } } From
cd11abc3261d6f37731aa4574705119e0ac57a93 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 26 Jun 2018 12:38:20 -0500 Subject: [PATCH 18/18] review feedback --- .../test/java/org/apache/spark/network/RpcIntegrationSuite.java | 1 + 1 file changed, 1 insertion(+) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 4fa3fc8bc1870..1f4d75c7e2ec5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -417,6 +417,7 @@ private static class VerifyingStreamCallback implements StreamCallbackWithID { final StreamSuite.TestCallback helper; final OutputStream out; final File outFile; + VerifyingStreamCallback(String streamId) throws IOException { if (streamId.equals("file")) { outFile = File.createTempFile("data", ".tmp", testData.tempDir);
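
Aside — the defensive check added in PATCH 17 deserves a standalone illustration: a null return from receiveStream() must fail fast, before the transport commits to reading stream bytes that would then have no consumer. A sketch under assumed, hypothetical names (StreamSink, GuardedStreams, require — none of these appear in the patch):

    import java.nio.ByteBuffer;

    interface StreamSink {
      void onData(ByteBuffer buf);
    }

    final class GuardedStreams {
      private GuardedStreams() {}

      // Raise the error while the request can still be answered with an
      // RpcFailure; discovering the null later, mid-stream, would leave the
      // channel unrecoverable (see the earlier comment on failing the channel).
      static StreamSink require(StreamSink fromHandler) {
        if (fromHandler == null) {
          throw new NullPointerException("handler returned a null stream sink");
        }
        return fromHandler;
      }
    }
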