[SPARK-6237][NETWORK] Network-layer changes to allow stream upload. #21346

Closed · wants to merge 21 commits (showing changes from 17 commits)
org/apache/spark/network/client/StreamCallbackWithID.java (new file)
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.client;

public interface StreamCallbackWithID extends StreamCallback {
  String getID();
}
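For orientation (not part of this diff): the new interface lets a stream receiver identify which stream it is handling. Below is a minimal sketch of an implementation; the class name FileStreamCallback and its file-spooling behavior are hypothetical.

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.file.Files;
import java.nio.file.Path;

import org.apache.spark.network.client.StreamCallbackWithID;

/** Hypothetical callback that spools an uploaded stream to a local file. */
class FileStreamCallback implements StreamCallbackWithID {
  private final String streamId;
  private final WritableByteChannel out;

  FileStreamCallback(String streamId, Path target) throws IOException {
    this.streamId = streamId;
    this.out = Channels.newChannel(Files.newOutputStream(target));
  }

  @Override
  public String getID() {
    return streamId;
  }

  @Override
  public void onData(String streamId, ByteBuffer buf) throws IOException {
    while (buf.hasRemaining()) {
      out.write(buf);  // invoked repeatedly as frames arrive off the wire
    }
  }

  @Override
  public void onComplete(String streamId) throws IOException {
    out.close();  // the full advertised byte count has been received
  }

  @Override
  public void onFailure(String streamId, Throwable cause) throws IOException {
    out.close();  // real code would also delete the partial file
  }
}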
org/apache/spark/network/client/StreamInterceptor.java
@@ -22,22 +22,24 @@

 import io.netty.buffer.ByteBuf;
 
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.server.MessageHandler;
 import org.apache.spark.network.util.TransportFrameDecoder;
 
 /**
  * An interceptor that is registered with the frame decoder to feed stream data to a
  * callback.
  */
-class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+public class StreamInterceptor<T extends Message> implements TransportFrameDecoder.Interceptor {
 
-  private final TransportResponseHandler handler;
+  private final MessageHandler<T> handler;
   private final String streamId;
   private final long byteCount;
   private final StreamCallback callback;
   private long bytesRead;
 
-  StreamInterceptor(
-      TransportResponseHandler handler,
+  public StreamInterceptor(
+      MessageHandler<T> handler,
       String streamId,
       long byteCount,
       StreamCallback callback) {
@@ -50,16 +52,24 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {

   @Override
   public void exceptionCaught(Throwable cause) throws Exception {
-    handler.deactivateStream();
+    deactivateStream();
     callback.onFailure(streamId, cause);
   }
 
   @Override
   public void channelInactive() throws Exception {
-    handler.deactivateStream();
+    deactivateStream();
     callback.onFailure(streamId, new ClosedChannelException());
   }
 
+  private void deactivateStream() {
+    if (handler instanceof TransportResponseHandler) {
+      // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches
+      // (there is no extra cleanup that needs to happen)
+      ((TransportResponseHandler) handler).deactivateStream();
+    }
+  }

Contributor: Why don't we need to do this for TransportRequestHandler?

Contributor (author): The only purpose of TransportResponseHandler.deactivateStream() is to include the stream request in the count for numOutstandingRequests (it's not doing any critical cleanup). I will include a comment here explaining that.

   @Override
   public boolean handle(ByteBuf buf) throws Exception {
     int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
@@ -72,10 +82,10 @@ public boolean handle(ByteBuf buf) throws Exception {
       RuntimeException re = new IllegalStateException(String.format(
           "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
       callback.onFailure(streamId, re);
-      handler.deactivateStream();
+      deactivateStream();
       throw re;
     } else if (bytesRead == byteCount) {
-      handler.deactivateStream();
+      deactivateStream();
       callback.onComplete(streamId);
     }

org/apache/spark/network/client/TransportClient.java
@@ -32,15 +32,15 @@
 import com.google.common.base.Throwables;
 import com.google.common.util.concurrent.SettableFuture;
 import io.netty.channel.Channel;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.protocol.ChunkFetchRequest;
-import org.apache.spark.network.protocol.OneWayMessage;
-import org.apache.spark.network.protocol.RpcRequest;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.*;
 
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;

/**
@@ -141,26 +141,14 @@ public void fetchChunk(
     StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
     handler.addFetchRequest(streamChunkId, callback);
 
-    channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> {
-      if (future.isSuccess()) {
-        long timeTaken = System.currentTimeMillis() - startTime;
-        if (logger.isTraceEnabled()) {
-          logger.trace("Sending request {} to {} took {} ms", streamChunkId,
-            getRemoteAddress(channel), timeTaken);
-        }
-      } else {
-        String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
-          getRemoteAddress(channel), future.cause());
-        logger.error(errorMsg, future.cause());
-        handler.removeFetchRequest(streamChunkId);
-        channel.close();
-        try {
-          callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
-        } catch (Exception e) {
-          logger.error("Uncaught exception in RPC response callback handler!", e);
-        }
-      }
-    });
+    channel.writeAndFlush(new ChunkFetchRequest(streamChunkId))
+      .addListener( new StdChannelListener(startTime, streamChunkId) {
+        @Override
+        void handleFailure(String errorMsg, Throwable cause) {
+          handler.removeFetchRequest(streamChunkId);
+          callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
+        }
+      });
   }

Contributor: Are the changes to these .addListener() calls primarily cleanup / refactoring? Is the intent to reduce the amount of new duplicate code which would otherwise be added to uploadStream in this file?

Contributor (author): Yes, exactly. Marcelo asked for this refactoring in his review -- there was already a ton of copy-paste, and instead of adding more it made sense to refactor. There shouldn't be any behavior change (there are minor changes that shouldn't matter: channel.close() now happens before the more specific cleanup operations whereas it was in the middle previously, and the try block encompasses a bit more than before).

Contributor: Thanks for explaining. I guess the re-ordering of channel.close() and the handler operations is safe because the handler doesn't hold references to the channel / otherwise does not interact with it (and doesn't hold references to objects tied to the channel lifecycle, like buffers)?

Contributor (on ".addListener( new StdChannelListener"): nit: no space after (

/**
@@ -180,25 +168,13 @@ public void stream(String streamId, StreamCallback callback) {
     // when responses arrive.
     synchronized (this) {
       handler.addStreamCallback(streamId, callback);
-      channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> {
-        if (future.isSuccess()) {
-          long timeTaken = System.currentTimeMillis() - startTime;
-          if (logger.isTraceEnabled()) {
-            logger.trace("Sending request for {} to {} took {} ms", streamId,
-              getRemoteAddress(channel), timeTaken);
-          }
-        } else {
-          String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
-            getRemoteAddress(channel), future.cause());
-          logger.error(errorMsg, future.cause());
-          channel.close();
-          try {
-            callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
-          } catch (Exception e) {
-            logger.error("Uncaught exception in RPC response callback handler!", e);
-          }
-        }
-      });
+      channel.writeAndFlush(new StreamRequest(streamId))
+        .addListener(new StdChannelListener(startTime, streamId) {
+          @Override
+          void handleFailure(String errorMsg, Throwable cause) throws Exception {
+            callback.onFailure(streamId, new IOException(errorMsg, cause));
+          }
+        });
     }
   }

@@ -216,34 +192,99 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
logger.trace("Sending RPC to {}", getRemoteAddress(channel));
}

long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
long requestId = requestId();
handler.addRpcRequest(requestId, callback);

channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
.addListener(future -> {
if (future.isSuccess()) {
long timeTaken = System.currentTimeMillis() - startTime;
if (logger.isTraceEnabled()) {
logger.trace("Sending request {} to {} took {} ms", requestId,
getRemoteAddress(channel), timeTaken);
}
} else {
String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
getRemoteAddress(channel), future.cause());
logger.error(errorMsg, future.cause());
handler.removeRpcRequest(requestId);
channel.close();
try {
callback.onFailure(new IOException(errorMsg, future.cause()));
} catch (Exception e) {
logger.error("Uncaught exception in RPC response callback handler!", e);
}
}
});
.addListener(new RpcChannelListener(startTime, requestId, callback));

return requestId;
}

+  /**
+   * Send data to the remote end as a stream. This differs from stream() in that this is a request
+   * to *send* data to the remote end, not to receive it from the remote.
+   *
+   * @param meta meta data associated with the stream, which will be read completely on the
+   *   receiving end before the stream itself.
+   * @param data this will be streamed to the remote end to allow for transferring large amounts
+   *   of data without reading into memory.
+   * @param callback handles the reply -- onSuccess will only be called when both message and data
+   *   are received successfully.
+   */

Contributor (on the first javadoc sentence): I know you're in the "2 spaces after period camp", but that's 3.
+  public long uploadStream(
+      ManagedBuffer meta,
+      ManagedBuffer data,
+      RpcResponseCallback callback) {
+    long startTime = System.currentTimeMillis();
+    if (logger.isTraceEnabled()) {
+      logger.trace("Sending RPC to {}", getRemoteAddress(channel));
+    }
+
+    long requestId = requestId();
+    handler.addRpcRequest(requestId, callback);
+
+    channel.writeAndFlush(new UploadStream(requestId, meta, data))
+      .addListener(new RpcChannelListener(startTime, requestId, callback));
+
+    return requestId;
+  }

Contributor (on the startTime line): Seems like it should be easy to move this to StdChannelListener's constructor. Looks pretty similar in all methods.

Contributor (author): I didn't do that originally as I figured you wanted the startTime to be before writeAndFlush, but I can work around that too.
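To make the new API concrete, a hedged usage sketch from the caller's side (not part of this diff): UploadExample, uploadBlock, and the "block_1" metadata are hypothetical, and client is assumed to be an already-connected TransportClient.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

class UploadExample {
  /** Streams blockData to the remote end; the small metadata frame names the block. */
  static long uploadBlock(TransportClient client, ManagedBuffer blockData) {
    ByteBuffer meta = ByteBuffer.wrap("block_1".getBytes(StandardCharsets.UTF_8));
    return client.uploadStream(new NioManagedBuffer(meta), blockData,
      new RpcResponseCallback() {
        @Override
        public void onSuccess(ByteBuffer response) {
          // per the javadoc above: called only after both meta and data are received
        }

        @Override
        public void onFailure(Throwable e) {
          // the send failed, or the remote handler reported a failure
        }
      });
  }
}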

+  private class StdChannelListener
+      implements GenericFutureListener<Future<? super Void>> {
+    final long startTime;
+    final Object requestId;
+
+    StdChannelListener(long startTime, Object requestId) {
+      this.startTime = startTime;
+      this.requestId = requestId;
+    }
+
+    @Override
+    public void operationComplete(Future future) throws Exception {
+      if (future.isSuccess()) {
+        if (logger.isTraceEnabled()) {
+          long timeTaken = System.currentTimeMillis() - startTime;
+          logger.trace("Sending request {} to {} took {} ms", requestId,
+            getRemoteAddress(channel), timeTaken);
+        }
+      } else {
+        String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
+          getRemoteAddress(channel), future.cause());
+        logger.error(errorMsg, future.cause());
+        channel.close();
+        try {
+          handleFailure(errorMsg, future.cause());
+        } catch (Exception e) {
+          logger.error("Uncaught exception in RPC response callback handler!", e);
+        }
+      }
+    }
+
+    void handleFailure(String errorMsg, Throwable cause) throws Exception {}
+  }

Contributor (on the class declaration): I personally try to keep nested classes at the bottom of the enclosing class, but up to you.

+  private static long requestId() {
+    return Math.abs(UUID.randomUUID().getLeastSignificantBits());
+  }
+
+  private class RpcChannelListener extends StdChannelListener {
+    final long rpcRequestId;
+    final RpcResponseCallback callback;
+
+    RpcChannelListener(long startTime, long rpcRequestId, RpcResponseCallback callback) {
+      super(startTime, "RPC " + rpcRequestId);
+      this.rpcRequestId = rpcRequestId;
+      this.callback = callback;
+    }
+
+    @Override
+    void handleFailure(String errorMsg, Throwable cause) {
+      handler.removeRpcRequest(rpcRequestId);
+      callback.onFailure(new IOException(errorMsg, cause));
+    }
+  }

   /**
    * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
    * a specified timeout for a response.
org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -29,6 +29,7 @@
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallbackWithID;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.sasl.SecretKeyHolder;
 import org.apache.spark.network.sasl.SaslRpcHandler;
@@ -149,6 +150,14 @@ public void receive(TransportClient client, ByteBuffer message) {
     delegate.receive(client, message);
   }
 
+  @Override
+  public StreamCallbackWithID receiveStream(
+      TransportClient client,
+      ByteBuffer message,
+      RpcResponseCallback callback) {
+    return delegate.receiveStream(client, message, callback);
+  }
+
   @Override
   public StreamManager getStreamManager() {
     return delegate.getStreamManager();
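For context (not part of this diff): the delegation above forwards to the new server-side RpcHandler.receiveStream hook. A hedged sketch of a handler implementing that hook follows; UploadRpcHandler is hypothetical, and FileStreamCallback refers to the illustrative callback sketched earlier.

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;

/** Hypothetical handler that spools each uploaded stream to a file named by its metadata. */
class UploadRpcHandler extends RpcHandler {
  private final StreamManager streamManager = new OneForOneStreamManager();

  @Override
  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
    callback.onSuccess(ByteBuffer.allocate(0));  // ordinary RPCs: just ack
  }

  @Override
  public StreamCallbackWithID receiveStream(
      TransportClient client,
      ByteBuffer meta,
      RpcResponseCallback callback) {
    // The metadata frame is fully materialized before any stream data arrives;
    // in this sketch it simply carries the stream's name.
    String id = StandardCharsets.UTF_8.decode(meta).toString();
    try {
      return new FileStreamCallback(id, Paths.get("/tmp", id));
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  @Override
  public StreamManager getStreamManager() {
    return streamManager;
  }
}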
org/apache/spark/network/protocol/Message.java
@@ -37,7 +37,7 @@ enum Type implements Encodable {
     ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
     RpcRequest(3), RpcResponse(4), RpcFailure(5),
     StreamRequest(6), StreamResponse(7), StreamFailure(8),
-    OneWayMessage(9), User(-1);
+    OneWayMessage(9), UploadStream(10), User(-1);
 
     private final byte id;
 
@@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) {
       case 7: return StreamResponse;
       case 8: return StreamFailure;
       case 9: return OneWayMessage;
+      case 10: return UploadStream;
       case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
       default: throw new IllegalArgumentException("Unknown message type: " + id);
     }
org/apache/spark/network/protocol/MessageDecoder.java
@@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) {
       case StreamFailure:
         return StreamFailure.decode(in);
 
+      case UploadStream:
+        return UploadStream.decode(in);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + msgType);
     }
org/apache/spark/network/protocol/StreamResponse.java
@@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) {

   @Override
   public int hashCode() {
-    return Objects.hashCode(byteCount, streamId, body());
+    return Objects.hashCode(byteCount, streamId);
   }
 
   @Override

(Presumably this drops body() so that hashCode stays consistent with equals, which compares only byteCount and streamId.)