Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-6237][NETWORK] Network-layer changes to allow stream upload. #21346

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.client;

/**
 * A {@link StreamCallback} that additionally exposes an identifier through {@link #getID()}.
 */
public interface StreamCallbackWithID extends StreamCallback {
// Returns the ID associated with this callback.
// NOTE(review): presumably the same stream ID that is passed to the StreamCallback
// methods -- confirm against the implementations.
String getID();
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,24 @@

import io.netty.buffer.ByteBuf;

import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.server.MessageHandler;
import org.apache.spark.network.util.TransportFrameDecoder;

/**
* An interceptor that is registered with the frame decoder to feed stream data to a
* callback.
*/
class StreamInterceptor implements TransportFrameDecoder.Interceptor {
public class StreamInterceptor<T extends Message> implements TransportFrameDecoder.Interceptor {

private final TransportResponseHandler handler;
private final MessageHandler<T> handler;
private final String streamId;
private final long byteCount;
private final StreamCallback callback;
private long bytesRead;

StreamInterceptor(
TransportResponseHandler handler,
public StreamInterceptor(
MessageHandler<T> handler,
String streamId,
long byteCount,
StreamCallback callback) {
Expand All @@ -50,16 +52,24 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {

// Error path of the interceptor: deactivates the stream, then reports the cause to the
// callback as a failure of streamId.
@Override
public void exceptionCaught(Throwable cause) throws Exception {
handler.deactivateStream();
deactivateStream();
callback.onFailure(streamId, cause);
}

// Channel-closed path of the interceptor: deactivates the stream, then fails the callback
// for streamId with a ClosedChannelException.
@Override
public void channelInactive() throws Exception {
handler.deactivateStream();
deactivateStream();
callback.onFailure(streamId, new ClosedChannelException());
}

private void deactivateStream() {
if (handler instanceof TransportResponseHandler) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we need to do this for TransportRequestHandler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only purpose of TransportResponseHandler.deactivateStream() is to include the stream request in the count for numOutstandingRequests (it's not doing any critical cleanup). I will include a comment here explaining that.

// we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches
// (there is no extra cleanup that needs to happen)
((TransportResponseHandler) handler).deactivateStream();
}
}

@Override
public boolean handle(ByteBuf buf) throws Exception {
int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
Expand All @@ -72,10 +82,10 @@ public boolean handle(ByteBuf buf) throws Exception {
RuntimeException re = new IllegalStateException(String.format(
"Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
callback.onFailure(streamId, re);
handler.deactivateStream();
deactivateStream();
throw re;
} else if (bytesRead == byteCount) {
handler.deactivateStream();
deactivateStream();
callback.onComplete(streamId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SettableFuture;
import io.netty.channel.Channel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.OneWayMessage;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamRequest;
import org.apache.spark.network.protocol.*;

import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;

/**
Expand Down Expand Up @@ -133,34 +133,21 @@ public void fetchChunk(
long streamId,
int chunkIndex,
ChunkReceivedCallback callback) {
long startTime = System.currentTimeMillis();
if (logger.isDebugEnabled()) {
logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
}

StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
handler.addFetchRequest(streamChunkId, callback);

channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> {
if (future.isSuccess()) {
long timeTaken = System.currentTimeMillis() - startTime;
if (logger.isTraceEnabled()) {
logger.trace("Sending request {} to {} took {} ms", streamChunkId,
getRemoteAddress(channel), timeTaken);
}
} else {
String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
getRemoteAddress(channel), future.cause());
logger.error(errorMsg, future.cause());
StdChannelListener listener = new StdChannelListener(streamChunkId) {
@Override
void handleFailure(String errorMsg, Throwable cause) {
handler.removeFetchRequest(streamChunkId);
channel.close();
try {
callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
} catch (Exception e) {
logger.error("Uncaught exception in RPC response callback handler!", e);
}
callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
}
});
};
handler.addFetchRequest(streamChunkId, callback);

channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
}

/**
Expand All @@ -170,7 +157,12 @@ public void fetchChunk(
* @param callback Object to call with the stream data.
*/
public void stream(String streamId, StreamCallback callback) {
long startTime = System.currentTimeMillis();
StdChannelListener listener = new StdChannelListener(streamId) {
@Override
void handleFailure(String errorMsg, Throwable cause) throws Exception {
callback.onFailure(streamId, new IOException(errorMsg, cause));
}
};
if (logger.isDebugEnabled()) {
logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel));
}
Expand All @@ -180,25 +172,7 @@ public void stream(String streamId, StreamCallback callback) {
// when responses arrive.
synchronized (this) {
handler.addStreamCallback(streamId, callback);
channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> {
if (future.isSuccess()) {
long timeTaken = System.currentTimeMillis() - startTime;
if (logger.isTraceEnabled()) {
logger.trace("Sending request for {} to {} took {} ms", streamId,
getRemoteAddress(channel), timeTaken);
}
} else {
String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
getRemoteAddress(channel), future.cause());
logger.error(errorMsg, future.cause());
channel.close();
try {
callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
} catch (Exception e) {
logger.error("Uncaught exception in RPC response callback handler!", e);
}
}
});
channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener);
}
}

Expand All @@ -211,35 +185,44 @@ public void stream(String streamId, StreamCallback callback) {
* @return The RPC's id.
*/
public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
long startTime = System.currentTimeMillis();
if (logger.isTraceEnabled()) {
logger.trace("Sending RPC to {}", getRemoteAddress(channel));
}

long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
long requestId = requestId();
// Register the callback with the response handler under the request id before the
// request is written to the channel.
handler.addRpcRequest(requestId, callback);

RpcChannelListener listener = new RpcChannelListener(requestId, callback);
channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
.addListener(future -> {
if (future.isSuccess()) {
long timeTaken = System.currentTimeMillis() - startTime;
if (logger.isTraceEnabled()) {
logger.trace("Sending request {} to {} took {} ms", requestId,
getRemoteAddress(channel), timeTaken);
}
} else {
String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
getRemoteAddress(channel), future.cause());
logger.error(errorMsg, future.cause());
// A failed send drops the pending request and closes the channel before the
// callback is told about the failure.
handler.removeRpcRequest(requestId);
channel.close();
try {
callback.onFailure(new IOException(errorMsg, future.cause()));
} catch (Exception e) {
logger.error("Uncaught exception in RPC response callback handler!", e);
}
}
});
.addListener(listener);

// The generated id is handed back to the caller.
return requestId;
}

/**
 * Streams data to the remote end. Unlike stream(), which asks the remote end for data, this
 * is a request to push data from this side to the remote side.
 *
 * @param meta metadata associated with the stream; the receiving end reads it completely
 *             before the stream itself.
 * @param data the payload, streamed to the remote end so that large amounts of data can be
 *             transferred without being read into memory.
 * @param callback handles the reply -- onSuccess is only called once both the metadata and
 *                 the data have been received successfully.
 * @return the request id generated for this upload.
 */
public long uploadStream(
    ManagedBuffer meta,
    ManagedBuffer data,
    RpcResponseCallback callback) {
  if (logger.isTraceEnabled()) {
    logger.trace("Sending RPC to {}", getRemoteAddress(channel));
  }

  final long uploadId = requestId();
  handler.addRpcRequest(uploadId, callback);

  // The listener is built before the message is written, so its timer covers the whole send.
  final RpcChannelListener sendListener = new RpcChannelListener(uploadId, callback);
  final UploadStream request = new UploadStream(uploadId, meta, data);
  channel.writeAndFlush(request).addListener(sendListener);

  return uploadId;
}
Expand Down Expand Up @@ -319,4 +302,60 @@ public String toString() {
.add("isActive", isActive())
.toString();
}

/**
 * Generates a random, non-negative request id from the low 64 bits of a random UUID.
 *
 * The sign bit is masked off instead of calling Math.abs(), because
 * Math.abs(Long.MIN_VALUE) overflows and is still negative.
 *
 * @return a random id in the range [0, Long.MAX_VALUE].
 */
private static long requestId() {
  return UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE;
}

/**
 * Listener attached to the write of an outgoing request. On success it trace-logs how long
 * the send took; on failure it logs the error, closes the channel and then invokes
 * {@link #handleFailure(String, Throwable)} so subclasses can notify their callback.
 */
private class StdChannelListener
    implements GenericFutureListener<Future<? super Void>> {
  final long startTime;
  final Object requestId;

  /**
   * @param requestId identifies the request in log and error messages; the send timer
   *                  starts when the listener is constructed.
   */
  StdChannelListener(Object requestId) {
    this.startTime = System.currentTimeMillis();
    this.requestId = requestId;
  }

  @Override
  public void operationComplete(Future<? super Void> future) throws Exception {
    if (future.isSuccess()) {
      if (logger.isTraceEnabled()) {
        long timeTaken = System.currentTimeMillis() - startTime;
        logger.trace("Sending request {} to {} took {} ms", requestId,
            getRemoteAddress(channel), timeTaken);
      }
    } else {
      // Say "request", not "RPC": this listener is also used for chunk fetches and stream
      // requests, and RpcChannelListener already prefixes its id with "RPC ".
      String errorMsg = String.format("Failed to send request %s to %s: %s", requestId,
          getRemoteAddress(channel), future.cause());
      logger.error(errorMsg, future.cause());
      channel.close();
      try {
        handleFailure(errorMsg, future.cause());
      } catch (Exception e) {
        // A misbehaving callback must not escape into the channel's listener machinery.
        logger.error("Uncaught exception in RPC response callback handler!", e);
      }
    }
  }

  /**
   * Hook invoked after a failed send has been logged and the channel closed. Does nothing
   * by default.
   *
   * @param errorMsg the formatted error message that was logged.
   * @param cause the cause reported by the failed future.
   * @throws Exception any exception is caught and logged by operationComplete.
   */
  void handleFailure(String errorMsg, Throwable cause) throws Exception {}
}

/**
 * Channel listener for requests that are answered through an RpcResponseCallback. When the
 * send fails, the pending request is removed from the response handler and the callback is
 * failed with an IOException wrapping the cause.
 */
private class RpcChannelListener extends StdChannelListener {
  final long rpcRequestId;
  final RpcResponseCallback callback;

  /**
   * @param id the numeric request id; it is reported as "RPC <id>" by the parent listener.
   * @param responseCallback the callback to fail if the request cannot be sent.
   */
  RpcChannelListener(long id, RpcResponseCallback responseCallback) {
    super("RPC " + id);
    rpcRequestId = id;
    callback = responseCallback;
  }

  @Override
  void handleFailure(String errorMsg, Throwable cause) {
    // Drop the outstanding request first, then tell the caller why it failed.
    handler.removeRpcRequest(rpcRequestId);
    IOException sendFailure = new IOException(errorMsg, cause);
    callback.onFailure(sendFailure);
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SaslRpcHandler;
Expand Down Expand Up @@ -149,6 +150,14 @@ public void receive(TransportClient client, ByteBuffer message) {
delegate.receive(client, message);
}

/** Forwards the stream request unchanged to the wrapped delegate and returns its callback. */
@Override
public StreamCallbackWithID receiveStream(
    TransportClient client,
    ByteBuffer message,
    RpcResponseCallback callback) {
  final StreamCallbackWithID delegateCallback =
      delegate.receiveStream(client, message, callback);
  return delegateCallback;
}

@Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
RpcRequest(3), RpcResponse(4), RpcFailure(5),
StreamRequest(6), StreamResponse(7), StreamFailure(8),
OneWayMessage(9), User(-1);
OneWayMessage(9), UploadStream(10), User(-1);

private final byte id;

Expand Down Expand Up @@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) {
case 7: return StreamResponse;
case 8: return StreamFailure;
case 9: return OneWayMessage;
case 10: return UploadStream;
case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) {
case StreamFailure:
return StreamFailure.decode(in);

case UploadStream:
return UploadStream.decode(in);

default:
throw new IllegalArgumentException("Unexpected message type: " + msgType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) {

// Hash code of this response, built with Objects.hashCode over its fields.
// NOTE(review): the two return statements below are the before/after sides of a diff;
// the second one no longer includes body() in the hash.
@Override
public int hashCode() {
return Objects.hashCode(byteCount, streamId, body());
return Objects.hashCode(byteCount, streamId);
}

@Override
Expand Down
Loading