From 3b374994cbf1900412b68748fc2eb37a90df77d3 Mon Sep 17 00:00:00 2001
From: "haiming@lccomputing.com"
Date: Mon, 27 Jun 2022 13:54:52 +0800
Subject: [PATCH 1/4] fix worker pushServerPort/replicateServerPort conflict

---
 common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala | 4 ++++
 .../com/aliyun/emr/rss/service/deploy/worker/Worker.scala | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala b/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala
index 663b96dcf06..d6afecd815d 100644
--- a/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala
+++ b/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala
@@ -574,6 +574,10 @@ object RssConf extends Logging {
     conf.getInt("rss.fetchserver.port", 0)
   }
 
+  def replicateServerPort(conf: RssConf): Int = {
+    conf.getInt("rss.replicateserver.port", 0)
+  }
+
   def registerWorkerTimeoutMs(conf: RssConf): Long = {
     conf.getTimeAsMs("rss.register.worker.timeout", "180s")
   }

diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala
index c0a58eafb9f..b74a05fe0c1 100644
--- a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala
+++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala
@@ -112,7 +112,7 @@ private[deploy] class Worker(
       new TransportContext(transportConf, rpcHandler, closeIdleConnections,
         workerSource, replicateLimiter)
     val serverBootstraps = new jArrayList[TransportServerBootstrap]()
-    transportContext.createServer(RssConf.pushServerPort(conf), serverBootstraps)
+    transportContext.createServer(RssConf.replicateServerPort(conf), serverBootstraps)
   }
 
   private val fetchServer = {

From d268a8577d1092f021148fbffcd985cb423d3555 Mon Sep 17 00:00:00 2001
From: "haiming@lccomputing.com"
Date: Thu, 30 Jun 2022 18:03:22 +0800
Subject: [PATCH 2/4] fix worker rss.master.address configuration item
 overwritten by worker default master address

---
 .../com/aliyun/emr/rss/service/deploy/worker/Worker.scala | 4 +++-
 .../emr/rss/service/deploy/worker/WorkerArguments.scala | 5 +----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala
index 88c8a61a6a1..c86296729d8 100644
--- a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala
+++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala
@@ -979,7 +979,9 @@ private[deploy] object Worker extends Logging {
     // much as possible. Therefore, if the user manually specifies the address of the Master when
     // starting the Worker, we should set it in the parameters and automatically calculate what the
     // address of the Master should be used in the end.
- conf.set("rss.master.address", RpcAddress.fromRssURL(workerArgs.master).toString) + if (workerArgs.master != null) { + conf.set("rss.master.address", RpcAddress.fromRssURL(workerArgs.master).toString) + } val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, WorkerSource.ServletPath) diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala index 188d1301d09..5a1a71f739a 100644 --- a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala +++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala @@ -28,7 +28,7 @@ class WorkerArguments(args: Array[String], conf: RssConf) { var port = 0 // var master: String = null // for local testing. - var master: String = s"rss://$host:9097" + var master: String = null var propertiesFile: String = null parse(args.toList) @@ -58,9 +58,6 @@ class WorkerArguments(args: Array[String], conf: RssConf) { parse(tail) case Nil => - if (master == null) { // No positional argument was given - printUsageAndExit(1) - } case _ => printUsageAndExit(1) From 20175a4d06d9f450a105600a5c6c89090a163343 Mon Sep 17 00:00:00 2001 From: "haiming@lccomputing.com" Date: Tue, 12 Jul 2022 13:33:25 +0800 Subject: [PATCH 3/4] fetch upstream --- .github/ISSUE_TEMPLATE/refactor_request.md | 16 + README.md | 2 +- .../emr/rss/client/ShuffleClientImpl.java | 32 +- .../rss/client/read/RetryingChunkClient.java | 34 +- .../emr/rss/client/read/RssInputStream.java | 4 +- .../rss/client/write/LifecycleManager.scala | 317 +++--- .../read/RetryingChunkClientSuiteJ.java | 8 +- .../rss/common/network/TransportContext.java | 31 +- .../common/network/client/StreamCallback.java | 40 - .../network/client/StreamCallbackWithID.java | 22 - .../network/client/StreamInterceptor.java | 95 -- .../network/client/TransportClient.java | 126 +-- .../client/TransportClientFactory.java | 4 +- .../client/TransportResponseHandler.java | 101 +- .../network/protocol/AbstractMessage.java | 24 + .../network/protocol/ChunkFetchFailure.java | 20 +- .../network/protocol/ChunkFetchRequest.java | 18 +- .../network/protocol/ChunkFetchSuccess.java | 22 +- .../rss/common/network/protocol/Encoders.java | 19 - .../rss/common/network/protocol/Message.java | 11 +- .../network/protocol/MessageDecoder.java | 12 - .../network/protocol/OneWayMessage.java | 4 +- .../common/network/protocol/OpenStream.java | 103 ++ .../common/network/protocol/RpcRequest.java | 4 +- ...reamChunkId.java => StreamChunkSlice.java} | 40 +- .../network/protocol/StreamFailure.java | 77 -- .../{StreamRequest.java => StreamHandle.java} | 36 +- .../network/protocol/StreamResponse.java | 91 -- .../common/network/protocol/UploadStream.java | 107 -- ...pcHandler.java => BaseMessageHandler.java} | 36 +- ...rIterator.java => FileManagedBuffers.java} | 46 +- .../server/OneForOneStreamManager.java | 49 +- .../rss/common/network/server/RpcHandler.java | 153 --- .../common/network/server/StreamManager.java | 4 +- .../server/TransportRequestHandler.java | 310 +----- .../network/server/TransportServer.java | 22 +- .../server/TransportServerBootstrap.java | 4 +- .../util/ByteArrayReadableChannel.java | 64 -- .../util/ByteArrayWritableChannel.java | 69 -- .../rss/common/network/util/CryptoUtils.java | 49 - .../common/network/util/LevelDBProvider.java | 152 --- .../common/network/util/TransportConf.java | 100 -- 
 .../common/protocol/PartitionLocation.java | 4 +
 .../common/protocol/message/StatusCode.java | 5 +-
 .../com/aliyun/emr/rss/common/RssConf.scala | 4 +
 .../protocol/message/ControlMessages.scala | 94 +-
 .../emr/rss/common/rpc/RpcEndpointRef.scala | 2 -
 .../aliyun/emr/rss/common/rpc/RpcEnv.scala | 15 -
 .../emr/rss/common/rpc/netty/Dispatcher.scala | 2 +-
 .../rss/common/rpc/netty/NettyRpcEnv.scala | 191 +---
 .../common/rpc/netty/NettyStreamManager.scala | 2 +-
 .../emr/rss/common/rpc/netty/Outbox.scala | 2 +-
 .../emr/rss/common/util/MemoryParam.scala | 32 -
 .../network/ChunkFetchIntegrationSuiteJ.java | 25 +-
 .../rss/common/network/ProtocolSuiteJ.java | 111 --
 .../RequestTimeoutIntegrationSuiteJ.java | 53 +-
 .../common/network/RpcIntegrationSuiteJ.java | 268 +----
 .../emr/rss/common/network/StreamSuiteJ.java | 306 ------
 .../network/TransportClientFactorySuiteJ.java | 13 +-
 .../TransportRequestHandlerSuiteJ.java | 154 ---
 .../TransportResponseHandlerSuiteJ.java | 72 +-
 .../protocol/MessageWithHeaderSuiteJ.java | 184 ----
 .../server/OneForOneStreamManagerSuiteJ.java | 56 +-
 .../network/util/CryptoUtilsSuiteJ.java | 51 -
 .../emr/rss/common/meta/WorkerInfoSuite.scala | 6 +-
 .../rss/service/deploy/master/Master.scala | 14 +-
 .../deploy/master/MasterArguments.scala | 2 +-
 .../deploy/worker/ChunkFetchRpcHandler.java | 125 ---
 .../service/deploy/worker/FlushBuffer.java | 105 --
 .../service/deploy/worker/MinimalByteBuf.java | 950 ------------------
 .../deploy/worker/OpenStreamHandler.java | 25 -
 .../deploy/worker/PushDataHandler.java | 27 -
 .../deploy/worker/PushDataRpcHandler.java | 84 --
 .../service/deploy/worker/Registerable.java | 22 -
 .../service/deploy/worker/Controller.scala | 394 ++++++++
 .../service/deploy/worker/FetchHandler.scala | 164 +++
 .../deploy/worker/PushDataHandler.scala | 406 ++++++++
 .../rss/service/deploy/worker/Worker.scala | 900 +++-------------
 .../deploy/worker/WorkerArguments.scala | 2 +-
 .../deploy/worker/FileWriterSuiteJ.java | 70 +-
 .../service/deploy/MiniClusterFeature.scala | 17 +-
 .../rss/service/deploy/SparkTestBase.scala | 6 +-
 82 files changed, 1924 insertions(+), 5519 deletions(-)
 create mode 100644 .github/ISSUE_TEMPLATE/refactor_request.md
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallback.java
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallbackWithID.java
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamInterceptor.java
 create mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OpenStream.java
 rename common/src/main/java/com/aliyun/emr/rss/common/network/protocol/{StreamChunkId.java => StreamChunkSlice.java} (62%)
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamFailure.java
 rename common/src/main/java/com/aliyun/emr/rss/common/network/protocol/{StreamRequest.java => StreamHandle.java} (62%)
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamResponse.java
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/protocol/UploadStream.java
 rename common/src/main/java/com/aliyun/emr/rss/common/network/server/{NoOpRpcHandler.java => BaseMessageHandler.java} (53%)
 rename common/src/main/java/com/aliyun/emr/rss/common/network/server/{ManagedBufferIterator.java => FileManagedBuffers.java} (68%)
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/server/RpcHandler.java
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayReadableChannel.java
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayWritableChannel.java
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/util/CryptoUtils.java
 delete mode 100644 common/src/main/java/com/aliyun/emr/rss/common/network/util/LevelDBProvider.java
 delete mode 100644 common/src/main/scala/com/aliyun/emr/rss/common/util/MemoryParam.scala
 delete mode 100644 common/src/test/java/com/aliyun/emr/rss/common/network/ProtocolSuiteJ.java
 delete mode 100644 common/src/test/java/com/aliyun/emr/rss/common/network/StreamSuiteJ.java
 delete mode 100644 common/src/test/java/com/aliyun/emr/rss/common/network/TransportRequestHandlerSuiteJ.java
 delete mode 100644 common/src/test/java/com/aliyun/emr/rss/common/network/protocol/MessageWithHeaderSuiteJ.java
 delete mode 100644 common/src/test/java/com/aliyun/emr/rss/common/network/util/CryptoUtilsSuiteJ.java
 delete mode 100644 server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/ChunkFetchRpcHandler.java
 delete mode 100644 server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/FlushBuffer.java
 delete mode 100644 server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/MinimalByteBuf.java
 delete mode 100644 server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/OpenStreamHandler.java
 delete mode 100644 server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.java
 delete mode 100644 server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataRpcHandler.java
 delete mode 100644 server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/Registerable.java
 create mode 100644 server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Controller.scala
 create mode 100644 server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/FetchHandler.scala
 create mode 100644 server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.scala

diff --git a/.github/ISSUE_TEMPLATE/refactor_request.md b/.github/ISSUE_TEMPLATE/refactor_request.md
new file mode 100644
index 00000000000..9cc01cc616a
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/refactor_request.md
@@ -0,0 +1,16 @@
+---
+name: Refactor request
+about: Suggest an idea for this project
+title: '[REFACTOR] title'
+labels: feature
+assignees: ''
+
+---
+
+### Describe the intention you'd like
+A clear and concise description of what you want to do.
+
+
+/cc @who-need-to-know
+
+/assign @who-can-help-you

diff --git a/README.md b/README.md
index 2badec14612..b08bc617be1 100644
--- a/README.md
+++ b/README.md
@@ -181,7 +181,7 @@ spark.shuffle.manager org.apache.spark.shuffle.rss.RssShuffleManager
 spark.serializer org.apache.spark.serializer.KryoSerializer
 
 # if you are running HA cluster ,set spark.rss.master.address to any RSS master
-spark.rss.master.address rss-master-host:9097
+spark.rss.master.address rss-master-host:rss-master-port
 spark.shuffle.service.enabled false
 
 # optional:hash,sort

diff --git a/client/src/main/java/com/aliyun/emr/rss/client/ShuffleClientImpl.java b/client/src/main/java/com/aliyun/emr/rss/client/ShuffleClientImpl.java
index 92c2368fedb..e1afa6af904 100644
--- a/client/src/main/java/com/aliyun/emr/rss/client/ShuffleClientImpl.java
+++ b/client/src/main/java/com/aliyun/emr/rss/client/ShuffleClientImpl.java
@@ -50,7 +50,7 @@
 import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
 import com.aliyun.emr.rss.common.network.protocol.PushData;
 import com.aliyun.emr.rss.common.network.protocol.PushMergedData;
-import com.aliyun.emr.rss.common.network.server.NoOpRpcHandler;
+import com.aliyun.emr.rss.common.network.server.BaseMessageHandler;
 import com.aliyun.emr.rss.common.network.util.TransportConf;
 import com.aliyun.emr.rss.common.protocol.PartitionLocation;
 import com.aliyun.emr.rss.common.protocol.RpcNameConstants;
@@ -140,7 +140,7 @@ public ShuffleClientImpl(RssConf conf) {
       TransportModuleConstants.DATA_MODULE, conf.getInt("rss.data.io.threads", 8));
     TransportContext context =
-      new TransportContext(dataTransportConf, new NoOpRpcHandler(), true);
+      new TransportContext(dataTransportConf, new BaseMessageHandler(), true);
     List bootstraps = Lists.newArrayList();
     dataClientFactory = context.createClientFactory(bootstraps);
@@ -257,6 +257,12 @@ private ConcurrentHashMap registerShuffle(
           result.put(partitionLoc.getReduceId(), partitionLoc);
         }
         return result;
+      } else if (response.status().equals(StatusCode.SlotNotAvailable)) {
+        logger.warn("LifecycleManager request slots return {}, retry again," +
+          " remain retry times {}", StatusCode.SlotNotAvailable.toString(), numRetries - 1);
+      } else {
+        logger.error("LifecycleManager request slots return {}, retry again," +
+          " remain retry times {}", StatusCode.Failed.toString(), numRetries - 1);
       }
     } catch (Exception e) {
       logger.error("Exception raised while registering shuffle {} with {} mapper and" +
@@ -832,6 +838,7 @@ public RssInputStream readPartition(String applicationId, int shuffleId, int red
   @Override
   public RssInputStream readPartition(String applicationId, int shuffleId, int reduceId,
       int attemptNumber, int startMapIndex, int endMapIndex) throws IOException {
+    String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
     ReduceFileGroups fileGroups = reduceFileGroupsMap.computeIfAbsent(shuffleId, (id) -> {
       try {
         if (driverRssMetaService == null) {
@@ -846,11 +853,17 @@ public RssInputStream readPartition(String applicationId, int shuffleId, int red
         GetReducerFileGroupResponse response =
           driverRssMetaService.askSync(getReducerFileGroup, classTag);
 
-        if (response != null && response.status() == StatusCode.Success) {
+        if (response.status() == StatusCode.Success) {
           return new ReduceFileGroups(response.fileGroup(), response.attempts());
+        } else if (response.status() == StatusCode.StageEndTimeOut) {
+          logger.warn("Request {} return {} for {}",
+            getReducerFileGroup, StatusCode.StageEndTimeOut.toString(), shuffleKey);
+        } else if (response.status() == StatusCode.ShuffleDataLost) {
+          logger.warn("Request {} return {} for {}",
+            getReducerFileGroup, StatusCode.ShuffleDataLost.toString(), shuffleKey);
         }
       } catch (Exception e) {
-        logger.warn("Exception raised while getting reduce file groups.", e);
+        logger.error("Exception raised while calling GetReducerFileGroup for " + shuffleKey + ".", e);
       }
       return null;
     });
@@ -859,15 +872,14 @@ public RssInputStream readPartition(String applicationId, int shuffleId, int red
       String msg = "Shuffle data lost for shuffle " + shuffleId + " reduce " + reduceId + "!";
       logger.error(msg);
       throw new IOException(msg);
-    }
-    if (fileGroups.partitionGroups == null) {
+    } else if (fileGroups.partitionGroups.length == 0) {
       logger.warn("Shuffle data is empty for shuffle {} reduce {}.", shuffleId, reduceId);
       return RssInputStream.empty();
+    } else {
+      return RssInputStream.create(conf, dataClientFactory, shuffleKey,
+        fileGroups.partitionGroups[reduceId], fileGroups.mapAttempts, attemptNumber,
+        startMapIndex, endMapIndex);
     }
-    String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
-    return RssInputStream.create(conf, dataClientFactory, shuffleKey,
-      fileGroups.partitionGroups[reduceId], fileGroups.mapAttempts, attemptNumber,
-      startMapIndex, endMapIndex);
   }
 
   @Override

diff --git a/client/src/main/java/com/aliyun/emr/rss/client/read/RetryingChunkClient.java b/client/src/main/java/com/aliyun/emr/rss/client/read/RetryingChunkClient.java
index 4f6728609a5..f1f137f0c88 100644
--- a/client/src/main/java/com/aliyun/emr/rss/client/read/RetryingChunkClient.java
+++ b/client/src/main/java/com/aliyun/emr/rss/client/read/RetryingChunkClient.java
@@ -19,7 +19,6 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ExecutorService;
@@ -37,6 +36,9 @@
 import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
 import com.aliyun.emr.rss.common.network.client.TransportClient;
 import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
+import com.aliyun.emr.rss.common.network.protocol.AbstractMessage;
+import com.aliyun.emr.rss.common.network.protocol.OpenStream;
+import com.aliyun.emr.rss.common.network.protocol.StreamHandle;
 import com.aliyun.emr.rss.common.network.util.NettyUtils;
 import com.aliyun.emr.rss.common.network.util.TransportConf;
 import com.aliyun.emr.rss.common.protocol.PartitionLocation;
@@ -238,8 +240,7 @@ class Replica {
   private final PartitionLocation location;
   private final TransportClientFactory clientFactory;
 
-  private long streamId;
-  private int numChunks;
+  private StreamHandle streamHandle;
   private TransportClient client;
   private int startMapIndex;
   private int endMapIndex;
@@ -272,20 +273,20 @@ public synchronized TransportClient getOrOpenStream()
     if (client == null || !client.isActive()) {
       client = clientFactory.createClient(location.getHost(), location.getFetchPort());
 
-      ByteBuffer openMessage = createOpenMessage();
-      ByteBuffer response = client.sendRpcSync(openMessage, timeoutMs);
-      streamId = response.getLong();
-      numChunks = response.getInt();
+      OpenStream openBlocks = new OpenStream(shuffleKey, location.getFileName(),
+        startMapIndex, endMapIndex);
+      ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs);
+      streamHandle = (StreamHandle) AbstractMessage.fromByteBuffer(response);
     }
     return client;
   }
 
   public long getStreamId() {
-    return streamId;
+    return streamHandle.streamId;
   }
 
   public int getNumChunks() {
-    return numChunks;
+    return streamHandle.numChunks;
   }
 
   @Override
@@ -293,21 +294,6 @@ public String toString() {
     return location.getHost() + ":" + location.getFetchPort();
   }
 
-  private ByteBuffer createOpenMessage() {
-    byte[] shuffleKeyBytes = shuffleKey.getBytes(StandardCharsets.UTF_8);
-    byte[] fileNameBytes = location.getFileName().getBytes(StandardCharsets.UTF_8);
-    ByteBuffer openMessage = ByteBuffer.allocate(
-      4 + shuffleKeyBytes.length + 4 + fileNameBytes.length + 4 + 4);
-    openMessage.putInt(shuffleKeyBytes.length);
-    openMessage.put(shuffleKeyBytes);
-    openMessage.putInt(fileNameBytes.length);
-    openMessage.put(fileNameBytes);
-    openMessage.putInt(startMapIndex);
-    openMessage.putInt(endMapIndex);
-    openMessage.flip();
-    return openMessage;
-  }
-
   @VisibleForTesting
   PartitionLocation getLocation() {
     return location;

diff --git a/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java b/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java
index 46a316e3661..612630ac381 100644
--- a/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java
+++ b/client/src/main/java/com/aliyun/emr/rss/client/read/RssInputStream.java
@@ -159,14 +159,14 @@ private void moveToNextReader() throws IOException {
       }
       currentReader = createReader(locations[fileIndex]);
-      logger.info("Moved to next partition {},startMapIndex {} endMapIndex {} , {}/{} read , " +
+      logger.debug("Moved to next partition {},startMapIndex {} endMapIndex {} , {}/{} read , " +
         "get chunks size {}", locations[fileIndex], startMapIndex, endMapIndex,
         fileIndex, locations.length, currentReader.numChunks);
       while (currentReader.numChunks < 1 && fileIndex < locations.length - 1) {
         fileIndex++;
         currentReader.close();
         currentReader = createReader(locations[fileIndex]);
-        logger.info("Moved to next partition {},startMapIndex {} endMapIndex {} , {}/{} read , " +
+        logger.debug("Moved to next partition {},startMapIndex {} endMapIndex {} , {}/{} read , " +
           "get chunks size {}", locations[fileIndex], startMapIndex, endMapIndex,
           fileIndex, locations.length, currentReader.numChunks);
       }

diff --git a/client/src/main/scala/com/aliyun/emr/rss/client/write/LifecycleManager.scala b/client/src/main/scala/com/aliyun/emr/rss/client/write/LifecycleManager.scala
index fb6008f5e8b..abc18c09485 100644
--- a/client/src/main/scala/com/aliyun/emr/rss/client/write/LifecycleManager.scala
+++ b/client/src/main/scala/com/aliyun/emr/rss/client/write/LifecycleManager.scala
@@ -18,6 +18,7 @@
 package com.aliyun.emr.rss.client.write
 
 import java.util
+import java.util.{List => JList}
 import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit}
 
 import scala.collection.JavaConverters._
@@ -36,7 +37,6 @@ import com.aliyun.emr.rss.common.protocol.RpcNameConstants.WORKER_EP
 import com.aliyun.emr.rss.common.protocol.message.ControlMessages._
 import com.aliyun.emr.rss.common.protocol.message.StatusCode
 import com.aliyun.emr.rss.common.rpc._
-import com.aliyun.emr.rss.common.rpc.netty.{NettyRpcEndpointRef, NettyRpcEnv}
 import com.aliyun.emr.rss.common.util.{ThreadUtils, Utils}
 
 class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint with Logging {
@@ -114,7 +114,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
         require(rssHARetryClient != null, "When sending a heartbeat, client shouldn't be null.")
         val appHeartbeat = HeartBeatFromApplication(appId, ZERO_UUID)
         rssHARetryClient.send(appHeartbeat)
-        logInfo("Successfully send app heartbeat.")
+        logDebug("Successfully send app heartbeat.")
       } catch {
         case it: InterruptedException =>
           logWarning("Interrupted while sending app heartbeat.")
@@ -174,10 +174,10 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     case msg: GetBlacklist =>
       handleGetBlacklist(msg)
     case StageEnd(applicationId, shuffleId) =>
-      logInfo(s"Received StageEnd request, ${Utils.makeShuffleKey(applicationId, shuffleId)}.")
+      logDebug(s"Received StageEnd request, ${Utils.makeShuffleKey(applicationId, shuffleId)}.")
       handleStageEnd(null, applicationId, shuffleId)
     case UnregisterShuffle(applicationId, shuffleId, _) =>
-      logInfo(s"Received UnregisterShuffle request," +
+      logDebug(s"Received UnregisterShuffle request," +
         s"${Utils.makeShuffleKey(applicationId, shuffleId)}.")
       handleUnregisterShuffle(null, applicationId, shuffleId)
   }
@@ -190,19 +190,19 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
         numPartitions)
 
     case Revive(applicationId, shuffleId, mapId, attemptId, reduceId, epoch,
      oldPartition, cause) =>
-      logDebug(s"Received Revive request, " +
+      logTrace(s"Received Revive request, " +
        s"$applicationId, $shuffleId, $mapId, $attemptId, ,$reduceId," +
        s" $epoch, $oldPartition, $cause.")
      handleRevive(context, applicationId, shuffleId, mapId, attemptId,
        reduceId, epoch, oldPartition, cause)
 
    case PartitionSplit(applicationId, shuffleId, reduceId, epoch, oldPartition) =>
-      logDebug(s"Received split request, " +
+      logTrace(s"Received split request, " +
        s"$applicationId, $shuffleId, $reduceId, $epoch, $oldPartition")
      handlePartitionSplitRequest(context, applicationId, shuffleId, reduceId, epoch, oldPartition)
 
    case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers) =>
-      logDebug(s"Received MapperEnd request, " +
+      logTrace(s"Received MapperEnd request, " +
        s"${Utils.makeMapKey(applicationId, shuffleId, mapId, attemptId)}.")
      handleMapperEnd(context, applicationId, shuffleId, mapId, attemptId, numMappers)
 
@@ -212,7 +212,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
      handleGetReducerFileGroup(context, shuffleId)
 
    case StageEnd(applicationId, shuffleId) =>
-      logInfo(s"Received StageEnd request, ${Utils.makeShuffleKey(applicationId, shuffleId)}.")
+      logDebug(s"Received StageEnd request, ${Utils.makeShuffleKey(applicationId, shuffleId)}.")
      handleStageEnd(context, applicationId, shuffleId)
  }
@@ -230,7 +230,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     // If do, just register and return
     registerShuffleRequest.synchronized {
       if (registerShuffleRequest.containsKey(shuffleId)) {
-        logInfo("[handleRegisterShuffle] request for same shuffleKey exists, just register")
+        logDebug("[handleRegisterShuffle] request for same shuffleKey exists, just register")
         registerShuffleRequest.get(shuffleId).add(context)
         return
       } else {
@@ -243,7 +243,6 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
           .filter(_.getEpoch == 0)
           .toList
           .asJava
-        logDebug(s"Shuffle $shuffleId already registered, just return.")
         if (initialLocs.size != numPartitions) {
           logWarning(s"Shuffle $shuffleId location size ${initialLocs.size} not equal to " +
             s"numPartitions: $numPartitions!")
@@ -262,22 +261,32 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     val reduceIdList = new util.ArrayList[Integer]
     (0 until numPartitions).foreach(x => reduceIdList.add(new Integer(x)))
     val res = requestSlotsWithRetry(applicationId, shuffleId, reduceIdList)
-    if (res.status != StatusCode.Success) {
-      logError(s"OfferSlots for $shuffleId failed!")
+
+    def reply(response: RegisterShuffleResponse): Unit = {
       registerShuffleRequest.synchronized {
         val set = registerShuffleRequest.get(shuffleId)
-        set.asScala.foreach { context =>
-          context.reply(RegisterShuffleResponse(StatusCode.SlotNotAvailable, null))
-        }
+        set.asScala.foreach(_.reply(response))
         registerShuffleRequest.remove(shuffleId)
       }
-      return
-    } else {
-      logInfo(s"OfferSlots for ${Utils.makeShuffleKey(applicationId, shuffleId)} Success!")
-      logDebug(s" Slots Info: ${res.workerResource}")
     }
-    // reserve buffers
+
+    res.status match {
+      case StatusCode.Failed =>
+        logError(s"OfferSlots RPC request failed for $shuffleId!")
+        reply(RegisterShuffleResponse(StatusCode.Failed, List.empty.asJava))
+        return
+      case StatusCode.SlotNotAvailable =>
+        logError(s"OfferSlots for $shuffleId failed!")
+        reply(RegisterShuffleResponse(StatusCode.SlotNotAvailable, List.empty.asJava))
+        return
+      case StatusCode.Success =>
+        logInfo(s"OfferSlots for ${Utils.makeShuffleKey(applicationId, shuffleId)} Success!")
+        logDebug(s" Slots Info: ${res.workerResource}")
+      case _ => // won't happen
+    }
+
+    // Reserve slots for each PartitionLocation. When response status is SUCCESS, WorkerResource
+    // won't be empty since master will reply SlotNotAvailable status when the reserved slots
+    // are empty.
     val slots = res.workerResource
     val candidatesWorkers = new util.HashSet(slots.keySet())
     val connectFailedWorkers = new util.ArrayList[WorkerInfo]()
@@ -286,9 +295,6 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
       try {
         workerInfo.endpoint = rpcEnv.setupEndpointRef(
           RpcAddress.apply(workerInfo.host, workerInfo.rpcPort), WORKER_EP)
-        workerInfo.endpoint.asInstanceOf[NettyRpcEndpointRef].client =
-          rpcEnv.asInstanceOf[NettyRpcEnv].clientFactory.createClient(workerInfo.host,
-            workerInfo.rpcPort)
       } catch {
         case t: Throwable =>
           logError(s"Init rpc client for $workerInfo failed", t)
@@ -306,13 +312,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     // reserve buffers failed, clear allocated resources
     if (!reserveSlotsSuccess) {
       logError(s"reserve buffer for $shuffleId failed, reply to all.")
-      registerShuffleRequest.synchronized {
-        val set = registerShuffleRequest.get(shuffleId)
-        set.asScala.foreach { context =>
-          context.reply(RegisterShuffleResponse(StatusCode.ReserveSlotFailed, null))
-        }
-        registerShuffleRequest.remove(shuffleId)
-      }
+      reply(RegisterShuffleResponse(StatusCode.ReserveSlotFailed, List.empty.asJava))
       // tell Master to release slots
       requestReleaseSlots(rssHARetryClient, ReleaseSlots(applicationId, shuffleId,
         new util.ArrayList[String](), new util.ArrayList[Integer]()))
@@ -352,19 +352,13 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     reducerFileGroupsMap.put(shuffleId, new Array[Array[PartitionLocation]](numPartitions))
 
     logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
-    registerShuffleRequest.synchronized {
-      val set = registerShuffleRequest.get(shuffleId)
-      set.asScala.foreach { context =>
-        context.reply(RegisterShuffleResponse(StatusCode.Success, locations.asJava))
-      }
-      registerShuffleRequest.remove(shuffleId)
-    }
+    reply(RegisterShuffleResponse(StatusCode.Success, locations.asJava))
   }
 
   def blacklistPartition(oldPartition: PartitionLocation, cause: StatusCode): Unit = {
     // only blacklist if cause is PushDataFailMain
     val failedWorker = new util.ArrayList[WorkerInfo]()
-    if (cause == StatusCode.PushDataFailMain) {
+    if (cause == StatusCode.PushDataFailMain && oldPartition != null) {
       failedWorker.add(oldPartition.getWorker)
     }
     if (!failedWorker.isEmpty) {
@@ -401,7 +395,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     shuffleReviving.synchronized {
       if (shuffleReviving.containsKey(reduceId)) {
         shuffleReviving.get(reduceId).add(context)
-        logInfo(s"For $shuffleId, same partition $reduceId-$oldEpoch is reviving," +
+        logTrace(s"For $shuffleId, same partition $reduceId-$oldEpoch is reviving," +
           s"register context.")
         return
       } else {
@@ -409,7 +403,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
         val latestLoc = getLatestPartition(shuffleId, reduceId, oldEpoch)
         if (latestLoc != null) {
           context.reply(ChangeLocationResponse(StatusCode.Success, latestLoc))
-          logInfo(s"New partition found, old partition $reduceId-$oldEpoch return it." +
+          logDebug(s"New partition found, old partition $reduceId-$oldEpoch return it." +
             s" shuffleId: $shuffleId $latestLoc")
           return
         }
@@ -423,7 +417,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     logWarning(s"Do Revive for shuffle ${
       Utils.makeShuffleKey(applicationId, shuffleId)}, oldPartition: $oldPartition, cause: $cause")
     blacklistPartition(oldPartition, cause)
-    handleChangePartitionLocation(shuffleReviving, applicationId, shuffleId, reduceId,
+    handleChangePartitionLocation(shuffleReviving, applicationId, shuffleId, reduceId, oldEpoch,
       oldPartition)
   }
@@ -434,11 +428,18 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
   }
 
   private def handleChangePartitionLocation(
-    contexts: ConcurrentHashMap[Integer, util.Set[RpcCallContext]],
-    applicationId: String, shuffleId: Int, reduceId: Int, oldPartition: PartitionLocation): Unit = {
+      contexts: ConcurrentHashMap[Integer, util.Set[RpcCallContext]],
+      applicationId: String,
+      shuffleId: Int,
+      reduceId: Int,
+      oldEpochId: Int,
+      oldPartition: PartitionLocation): Unit = {
     val candidates = workersNotBlacklisted(shuffleId)
-    val slots = reallocateSlotsFromCandidates(
-      List(oldPartition), candidates)
+    val slots = if (oldPartition != null) {
+      reallocateSlotsFromCandidates(List(oldPartition), candidates)
+    } else {
+      reallocateForNonExistPartitionLocation(reduceId, oldEpochId, candidates)
+    }
     if (slots == null) {
       logError("[Update partition] failed for slot not available.")
       contexts.synchronized {
@@ -468,7 +469,6 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
       slaves.get(0).getPeer
     }
 
-    logDebug(s"[Update partition] success for $shuffleId $location.")
     contexts.synchronized {
       contexts.remove(reduceId)
     }.asScala.foreach(_.reply(ChangeLocationResponse(StatusCode.Success, location)))
@@ -499,14 +499,11 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     shuffleSplitting.synchronized {
       if (shuffleSplitting.containsKey(reduceId)) {
         shuffleSplitting.get(reduceId).add(context)
-        logDebug(s"For $shuffleId, same $reduceId-$oldEpoch is splitting, register context")
         return
       } else {
         val latestLoc = getLatestPartition(shuffleId, reduceId, oldEpoch)
         if (latestLoc != null) {
           context.reply(ChangeLocationResponse(StatusCode.Success, latestLoc))
-          logDebug(s"Split request found new partition, old partition $reduceId-$oldEpoch" +
-            s" return it. shuffleId: $shuffleId $latestLoc")
           return
         }
         val set = new util.HashSet[RpcCallContext]()
@@ -515,10 +512,10 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
       }
     }
 
-    logDebug(s"Relocate partition for shuffle split ${Utils.makeShuffleKey(applicationId,
+    logDebug(s"Relocate partition for shuffle split ${Utils.makeShuffleKey(applicationId,
       shuffleId)}, oldPartition: $oldPartition")
 
-    handleChangePartitionLocation(shuffleSplitting, applicationId, shuffleId, reduceId,
+    handleChangePartitionLocation(shuffleSplitting, applicationId, shuffleId, reduceId, oldEpoch,
       oldPartition)
   }
@@ -535,7 +532,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     var attempts = shuffleMapperAttempts.get(shuffleId)
     // it would happen when task with no shuffle data called MapperEnd first
     if (attempts == null) {
-      logInfo(s"[handleMapperEnd] $shuffleId not registered, create one.")
+      logDebug(s"[handleMapperEnd] $shuffleId not registered, create one.")
       attempts = new Array[Int](numMappers)
       0 until numMappers foreach (ind => attempts(ind) = -1)
       shuffleMapperAttempts.put(shuffleId, attempts)
@@ -566,8 +563,8 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
   }
 
   private def handleGetReducerFileGroup(
-    context: RpcCallContext,
-    shuffleId: Int): Unit = {
+      context: RpcCallContext,
+      shuffleId: Int): Unit = {
     logDebug(s"Wait for StageEnd, $shuffleId.")
     var timeout = RssConf.stageEndTimeout(conf)
     val delta = 50
@@ -575,29 +572,29 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
       Thread.sleep(50)
       if (timeout <= 0) {
         logError(s"StageEnd Timeout! $shuffleId.")
-        context.reply(GetReducerFileGroupResponse(StatusCode.Failed, null, null))
+        context.reply(
+          GetReducerFileGroupResponse(StatusCode.StageEndTimeOut, Array.empty, Array.empty))
         return
       }
       timeout = timeout - delta
     }
 
-    logDebug(s"Start getting reduce file group, $shuffleId.")
     if (dataLostShuffleSet.contains(shuffleId)) {
-      context.reply(GetReducerFileGroupResponse(StatusCode.Failed, null, null))
+      context.reply(
+        GetReducerFileGroupResponse(StatusCode.ShuffleDataLost, Array.empty, Array.empty))
     } else {
-      val shuffleFileGroup = reducerFileGroupsMap.get(shuffleId)
       context.reply(GetReducerFileGroupResponse(
         StatusCode.Success,
-        shuffleFileGroup,
-        shuffleMapperAttempts.get(shuffleId)
+        reducerFileGroupsMap.getOrDefault(shuffleId, Array.empty),
+        shuffleMapperAttempts.getOrDefault(shuffleId, Array.empty)
       ))
     }
   }
 
   private def handleStageEnd(
-    context: RpcCallContext,
-    applicationId: String,
-    shuffleId: Int): Unit = {
+      context: RpcCallContext,
+      applicationId: String,
+      shuffleId: Int): Unit = {
     // check whether shuffle has registered
     if (!registeredShuffle.contains(shuffleId)) {
       logInfo(s"[handleStageEnd]" +
@@ -646,29 +643,26 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
         val masterIds = masterParts.asScala.map(_.getUniqueId).asJava
         val slaveIds = slaveParts.asScala.map(_.getUniqueId).asJava
 
-        val res = requestCommitFiles(worker.endpoint,
-          CommitFiles(applicationId, shuffleId, masterIds, slaveIds,
-            shuffleMapperAttempts.get(shuffleId)))
-
-        if (res.status != StatusCode.Success) {
-          commitFilesFailedWorkers.add(worker)
+        val commitFiles = CommitFiles(applicationId, shuffleId, masterIds,
+          slaveIds, shuffleMapperAttempts.get(shuffleId))
+        val res = requestCommitFiles(worker.endpoint, commitFiles)
+
+        res.status match {
+          case StatusCode.Success => // do nothing
+          case StatusCode.PartialSuccess | StatusCode.ShuffleNotRegistered | StatusCode.Failed =>
+            logDebug(s"Request $commitFiles return ${res.status} for " +
+              s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
+            commitFilesFailedWorkers.add(worker)
+          case _ => // won't happen
         }
 
         // record committed partitionIds
-        if (res.committedMasterIds != null) {
-          committedMasterIds.addAll(res.committedMasterIds)
-        }
-        if (res.committedSlaveIds != null) {
-          committedSlaveIds.addAll(res.committedSlaveIds)
-        }
+        committedMasterIds.addAll(res.committedMasterIds)
+        committedSlaveIds.addAll(res.committedSlaveIds)
 
         // record failed partitions
-        if (res.failedMasterIds != null) {
-          failedMasterIds.addAll(res.failedMasterIds)
-        }
-        if (res.failedSlaveIds != null) {
-          failedSlaveIds.addAll(res.failedSlaveIds)
-        }
+        failedMasterIds.addAll(res.failedMasterIds)
+        failedSlaveIds.addAll(res.failedSlaveIds)
       }
     }
@@ -836,13 +830,6 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
           }
         })
       }
-      // remove failed slot from total slots, close transport client
-      if (workerInfo.endpoint != null) {
-        val transportClient = workerInfo.endpoint.asInstanceOf[NettyRpcEndpointRef].client
-        if (null != transportClient && transportClient.isActive) {
-          transportClient.close()
-        }
-      }
     })
 
     val newMapFunc = new util.function.Function[WorkerInfo,
@@ -969,80 +956,104 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     retryReserveSlotsSuccess
   }
 
-  def reallocateSlotsFromCandidates(oldPartitions: List[PartitionLocation],
-    candidates: List[WorkerInfo]): WorkerResource = {
+  private def newMapFunc =
+    new util.function.Function[WorkerInfo, (JList[PartitionLocation], JList[PartitionLocation])] {
+      override def apply(w: WorkerInfo): (JList[PartitionLocation], JList[PartitionLocation]) =
+        (new util.LinkedList[PartitionLocation](), new util.LinkedList[PartitionLocation]())
+    }
+
+  private def allocateFromCandidates(
+      reduceId: Int,
+      oldEpochId: Int,
+      candidates: List[WorkerInfo],
+      slots: WorkerResource): Unit = {
+    val masterIndex = Random.nextInt(candidates.size)
+    val masterLocation = new PartitionLocation(
+      reduceId,
+      oldEpochId + 1,
+      candidates(masterIndex).host,
+      candidates(masterIndex).rpcPort,
+      candidates(masterIndex).pushPort,
+      candidates(masterIndex).fetchPort,
+      candidates(masterIndex).replicatePort,
+      PartitionLocation.Mode.Master)
+
+    if (ShouldReplicate) {
+      val slaveIndex = (masterIndex + 1) % candidates.size
+      val slaveLocation = new PartitionLocation(
+        reduceId,
+        oldEpochId + 1,
+        candidates(slaveIndex).host,
+        candidates(slaveIndex).rpcPort,
+        candidates(slaveIndex).pushPort,
+        candidates(slaveIndex).fetchPort,
+        candidates(slaveIndex).replicatePort,
+        PartitionLocation.Mode.Slave,
+        masterLocation
+      )
+      masterLocation.setPeer(slaveLocation)
+      val masterAndSlavePairs = slots.computeIfAbsent(candidates(slaveIndex), newMapFunc)
+      masterAndSlavePairs._2.add(slaveLocation)
+    }
+
+    val masterAndSlavePairs = slots.computeIfAbsent(candidates(masterIndex), newMapFunc)
+    masterAndSlavePairs._1.add(masterLocation)
+  }
+
+  def reallocateForNonExistPartitionLocation(
+      reduceId: Int,
+      oldEpochId: Int,
+      candidates: List[WorkerInfo]): WorkerResource = {
     if (candidates.size < 1 || (ShouldReplicate && candidates.size < 2)) {
       logError("Not enough candidates for revive")
       return null
     }
+    val slots = new WorkerResource()
+    allocateFromCandidates(reduceId, oldEpochId, candidates, slots)
+    slots
+  }
 
-    val newMapFunc =
-      new util.function.Function[WorkerInfo,
-        (util.List[PartitionLocation], util.List[PartitionLocation])] {
-        override def apply(w: WorkerInfo):
-          (util.List[PartitionLocation], util.List[PartitionLocation]) =
-          (new util.LinkedList[PartitionLocation](), new util.LinkedList[PartitionLocation]())
-      }
+  def reallocateSlotsFromCandidates(
+      oldPartitions: List[PartitionLocation],
+      candidates: List[WorkerInfo]): WorkerResource = {
+    if (candidates.size < 1 || (ShouldReplicate && candidates.size < 2)) {
+      logError("Not enough candidates for revive")
+      return null
+    }
 
     val slots = new WorkerResource()
-    oldPartitions.foreach(partitionLocation => {
-      val masterIndex = Random.nextInt(candidates.size)
-      val masterLocation = new PartitionLocation(
-        partitionLocation.getReduceId,
-        partitionLocation.getEpoch + 1,
-        candidates(masterIndex).host,
-        candidates(masterIndex).rpcPort,
-        candidates(masterIndex).pushPort,
-        candidates(masterIndex).fetchPort,
-        candidates(masterIndex).replicatePort,
-        PartitionLocation.Mode.Master)
-
-      if (ShouldReplicate) {
-        val slaveIndex = (masterIndex + 1) % candidates.size
-        val slaveLocation = new PartitionLocation(
-          partitionLocation.getReduceId,
-          partitionLocation.getEpoch + 1,
-          candidates(slaveIndex).host,
-          candidates(slaveIndex).rpcPort,
-          candidates(slaveIndex).pushPort,
-          candidates(slaveIndex).fetchPort,
-          candidates(slaveIndex).replicatePort,
-          PartitionLocation.Mode.Slave,
-          masterLocation
-        )
-        masterLocation.setPeer(slaveLocation)
-        val masterAndSlavePairs = slots.computeIfAbsent(candidates(slaveIndex), newMapFunc)
-        masterAndSlavePairs._2.add(slaveLocation)
-      }
-
-      val masterAndSlavePairs = slots.computeIfAbsent(candidates(masterIndex), newMapFunc)
-      masterAndSlavePairs._1.add(masterLocation)
-    })
+    oldPartitions.foreach { partition =>
+      allocateFromCandidates(partition.getReduceId, partition.getEpoch, candidates, slots)
+    }
     slots
   }
 
-  def destroyBuffersWithRetry(applicationId: String, shuffleId: Int,
-    worker: WorkerResource): (util.List[String], util.List[String]) = {
+  def destroyBuffersWithRetry(
+      applicationId: String,
+      shuffleId: Int,
+      worker: WorkerResource): (util.List[String], util.List[String]) = {
     val failedMasters = new util.LinkedList[String]()
     val failedSlaves = new util.LinkedList[String]()
     val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
-    worker.asScala.foreach(entry => {
-      var res = requestDestroy(entry._1.endpoint,
-        Destroy(shuffleKey, entry._2._1.asScala.map(_.getUniqueId).asJava,
-          entry._2._2.asScala.map(_.getUniqueId).asJava))
+    worker.asScala.foreach { case (workerInfo, (masterLocation, slaveLocation)) =>
+      val destroy = Destroy(shuffleKey,
+        masterLocation.asScala.map(_.getUniqueId).asJava,
+        slaveLocation.asScala.map(_.getUniqueId).asJava)
+      var res = requestDestroy(workerInfo.endpoint, destroy)
       if (res.status != StatusCode.Success) {
-        res = requestDestroy(entry._1.endpoint,
+        logDebug(s"Request $destroy return ${res.status} for " +
+          s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
+        res = requestDestroy(workerInfo.endpoint,
           Destroy(shuffleKey, res.failedMasters, res.failedSlaves))
       }
-      if (null != res.failedMasters) failedMasters.addAll(res.failedMasters)
-      if (null != res.failedSlaves) failedSlaves.addAll(res.failedSlaves)
-    })
+      failedMasters.addAll(res.failedMasters)
+      failedSlaves.addAll(res.failedSlaves)
+    }
     (failedMasters, failedSlaves)
   }
 
   private def removeExpiredShuffle(): Unit = {
-    logInfo("Check for expired shuffle.")
     val currentTime = System.currentTimeMillis()
     val keys = unregisterShuffleTime.keys().asScala.toList
     keys.foreach { key =>
@@ -1071,12 +1082,8 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
       logInfo(s"Received Blacklist from Master, blacklist: ${res.blacklist} " +
         s"unkown workers: ${res.unknownWorkers}")
       blacklist.clear()
-      if (res.blacklist != null) {
-        blacklist.addAll(res.blacklist)
-      }
-      if (res.unknownWorkers != null) {
-        blacklist.addAll(res.unknownWorkers)
-      }
+      blacklist.addAll(res.blacklist)
+      blacklist.addAll(res.unknownWorkers)
     }
   }
@@ -1102,7 +1109,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     } catch {
       case e: Exception =>
         logError(s"AskSync RegisterShuffle for $shuffleKey failed.", e)
-        RequestSlotsResponse(StatusCode.Failed, null)
+        RequestSlotsResponse(StatusCode.Failed, new WorkerResource())
     }
   }
@@ -1130,13 +1137,15 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
   }
 
   private def requestCommitFiles(
-    endpoint: RpcEndpointRef, message: CommitFiles): CommitFilesResponse = {
+      endpoint: RpcEndpointRef,
+      message: CommitFiles): CommitFilesResponse = {
     try {
       endpoint.askSync[CommitFilesResponse](message)
     } catch {
       case e: Exception =>
         logError(s"AskSync CommitFiles for ${message.shuffleId} failed.", e)
-        CommitFilesResponse(StatusCode.Failed, null, null, message.masterIds, message.slaveIds)
+        CommitFilesResponse(StatusCode.Failed, List.empty.asJava, List.empty.asJava,
+          message.masterIds, message.slaveIds)
     }
   }
@@ -1163,14 +1172,15 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     }
   }
 
-  private def requestGetBlacklist(rssHARetryClient: RssHARetryClient,
-    msg: GetBlacklist): GetBlacklistResponse = {
+  private def requestGetBlacklist(
+      rssHARetryClient: RssHARetryClient,
+      msg: GetBlacklist): GetBlacklistResponse = {
     try {
       rssHARetryClient.askSync[GetBlacklistResponse](msg, classOf[GetBlacklistResponse])
     } catch {
      case e: Exception =>
        logError(s"AskSync GetBlacklist failed.", e)
-        GetBlacklistResponse(StatusCode.Failed, null, null)
+        GetBlacklistResponse(StatusCode.Failed, List.empty.asJava, List.empty.asJava)
    }
  }
@@ -1185,7 +1195,6 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
   }
 
   def isClusterOverload(numPartitions: Int = 0): Boolean = {
-    logInfo(s"Ask Sync Cluster Load Status")
     try {
       rssHARetryClient.askSync[GetClusterLoadStatusResponse](GetClusterLoadStatus(numPartitions),
         classOf[GetClusterLoadStatusResponse]).isOverload

diff --git a/client/src/test/java/com/aliyun/emr/rss/client/read/RetryingChunkClientSuiteJ.java b/client/src/test/java/com/aliyun/emr/rss/client/read/RetryingChunkClientSuiteJ.java
index b0ec092463b..1a41e6e8ee7 100644
--- a/client/src/test/java/com/aliyun/emr/rss/client/read/RetryingChunkClientSuiteJ.java
+++ b/client/src/test/java/com/aliyun/emr/rss/client/read/RetryingChunkClientSuiteJ.java
@@ -56,6 +56,7 @@
 import com.aliyun.emr.rss.common.network.client.TransportClient;
 import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
 import com.aliyun.emr.rss.common.network.client.TransportResponseHandler;
+import com.aliyun.emr.rss.common.network.protocol.StreamHandle;
 import com.aliyun.emr.rss.common.protocol.PartitionLocation;
 import com.aliyun.emr.rss.common.util.ThreadUtils;
@@ -380,11 +381,8 @@ public void fetchChunk(long streamId, int chunkId, ChunkReceivedCallback callbac
 
     @Override
     public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
-      ByteBuffer buffer = ByteBuffer.allocate(8 + 4);
-      buffer.putLong(streamId);
-      buffer.putInt(numChunks);
-      buffer.flip();
-      return buffer;
+      StreamHandle handle = new StreamHandle(streamId, numChunks);
+      return handle.toByteBuffer();
     }
 
     @Override

diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/TransportContext.java b/common/src/main/java/com/aliyun/emr/rss/common/network/TransportContext.java
index 98a93f14562..b7d6297952a 100644
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/TransportContext.java
+++ b/common/src/main/java/com/aliyun/emr/rss/common/network/TransportContext.java
@@ -57,7 +57,7 @@ public class TransportContext {
   private static final Logger logger = LoggerFactory.getLogger(TransportContext.class);
 
   private final TransportConf conf;
-  private final RpcHandler rpcHandler;
+  private final BaseMessageHandler handler;
   private final boolean closeIdleConnections;
 
   /**
@@ -80,12 +80,12 @@ public class TransportContext {
   public TransportContext(
       TransportConf conf,
-      RpcHandler rpcHandler,
+      BaseMessageHandler handler,
       boolean closeIdleConnections,
       AbstractSource source,
       ChannelHandler channelHandler) {
     this.conf = conf;
-    this.rpcHandler = rpcHandler;
+    this.handler = handler;
     this.closeIdleConnections = closeIdleConnections;
     this.source = source;
     this.channelHandler = channelHandler;
@@ -93,21 +93,21 @@ public TransportContext(
 
   public TransportContext(
       TransportConf conf,
-      RpcHandler rpcHandler,
+      BaseMessageHandler handler,
       boolean closeIdleConnections,
       AbstractSource source) {
-    this(conf, rpcHandler, closeIdleConnections, source, null);
+    this(conf, handler, closeIdleConnections, source, null);
   }
 
-  public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
-    this(conf, rpcHandler, false, null, null);
+  public TransportContext(TransportConf conf, BaseMessageHandler handler) {
+    this(conf, handler, false, null, null);
   }
 
   public TransportContext(
       TransportConf conf,
-      RpcHandler rpcHandler,
+      BaseMessageHandler handler,
       boolean closeIdleConnections) {
-    this(conf, rpcHandler, closeIdleConnections, null, null);
+    this(conf, handler, closeIdleConnections, null, null);
   }
 
   /**
@@ -125,13 +125,13 @@ public TransportClientFactory createClientFactory() {
 
   /** Create a server which will attempt to bind to a specific port. */
   public TransportServer createServer(int port, List bootstraps) {
-    return new TransportServer(this, null, port, rpcHandler, bootstraps, source);
+    return new TransportServer(this, null, port, handler, bootstraps, source);
   }
 
   /** Create a server which will attempt to bind to a specific host and port. */
   public TransportServer createServer(
       String host, int port, List bootstraps) {
-    return new TransportServer(this, host, port, rpcHandler, bootstraps);
+    return new TransportServer(this, host, port, handler, bootstraps);
   }
 
   public TransportServer createServer() {
@@ -139,7 +139,7 @@ public TransportServer createServer() {
   }
 
   public TransportChannelHandler initializePipeline(SocketChannel channel) {
-    return initializePipeline(channel, rpcHandler);
+    return initializePipeline(channel, handler);
   }
 
   /**
@@ -156,7 +156,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      BaseMessageHandler channelRpcHandler) {
     try {
       if (channelHandler != null) {
         channel.pipeline()
@@ -183,11 +183,12 @@
    * ResponseMessages. The channel is expected to have been successfully created, though certain
    * properties (such as the remoteAddress()) may not be available yet.
    */
-  private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
+  private TransportChannelHandler createChannelHandler(
+      Channel channel, BaseMessageHandler handler) {
     TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
     TransportClient client = new TransportClient(channel, responseHandler);
     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
-      rpcHandler, conf.maxChunksBeingTransferred(), source);
+      handler);
     return new TransportChannelHandler(client, responseHandler, requestHandler,
       conf.connectionTimeoutMs(), closeIdleConnections);
   }

diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallback.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallback.java
deleted file mode 100644
index d9197e80a01..00000000000
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallback.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.aliyun.emr.rss.common.network.client;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-/**
- * Callback for streaming data. Stream data will be offered to the
- * {@link #onData(String, ByteBuffer)} method as it arrives. Once all the stream data is received,
- * {@link #onComplete(String)} will be called.
- *
- * The network library guarantees that a single thread will call these methods at a time, but
- * different call may be made by different threads.
- */
-public interface StreamCallback {
-  /** Called upon receipt of stream data. */
-  void onData(String streamId, ByteBuffer buf) throws IOException;
-
-  /** Called when all data from the stream has been received. */
-  void onComplete(String streamId) throws IOException;
-
-  /** Called if there's an error reading data from the stream. */
-  void onFailure(String streamId, Throwable cause) throws IOException;
-}

diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallbackWithID.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallbackWithID.java
deleted file mode 100644
index 523c0207003..00000000000
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamCallbackWithID.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.aliyun.emr.rss.common.network.client;
-
-public interface StreamCallbackWithID extends StreamCallback {
-  String getID();
-}

diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamInterceptor.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamInterceptor.java
deleted file mode 100644
index d0641e35360..00000000000
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/StreamInterceptor.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.aliyun.emr.rss.common.network.client;
-
-import java.nio.ByteBuffer;
-import java.nio.channels.ClosedChannelException;
-
-import io.netty.buffer.ByteBuf;
-
-import com.aliyun.emr.rss.common.network.protocol.Message;
-import com.aliyun.emr.rss.common.network.server.MessageHandler;
-import com.aliyun.emr.rss.common.network.util.TransportFrameDecoder;
-
-/**
- * An interceptor that is registered with the frame decoder to feed stream data to a
- * callback.
- */
-public class StreamInterceptor implements TransportFrameDecoder.Interceptor {
-
-  private final MessageHandler handler;
-  private final String streamId;
-  private final long byteCount;
-  private final StreamCallback callback;
-  private long bytesRead;
-
-  public StreamInterceptor(
-      MessageHandler handler,
-      String streamId,
-      long byteCount,
-      StreamCallback callback) {
-    this.handler = handler;
-    this.streamId = streamId;
-    this.byteCount = byteCount;
-    this.callback = callback;
-    this.bytesRead = 0;
-  }
-
-  @Override
-  public void exceptionCaught(Throwable cause) throws Exception {
-    deactivateStream();
-    callback.onFailure(streamId, cause);
-  }
-
-  @Override
-  public void channelInactive() throws Exception {
-    deactivateStream();
-    callback.onFailure(streamId, new ClosedChannelException());
-  }
-
-  private void deactivateStream() {
-    if (handler instanceof TransportResponseHandler) {
-      // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches
-      // (there is no extra cleanup that needs to happen)
-      ((TransportResponseHandler) handler).deactivateStream();
-    }
-  }
-
-  @Override
-  public boolean handle(ByteBuf buf) throws Exception {
-    int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
-    ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer();
-
-    int available = nioBuffer.remaining();
-    callback.onData(streamId, nioBuffer);
-    bytesRead += available;
-    if (bytesRead > byteCount) {
-      RuntimeException re = new IllegalStateException(String.format(
-        "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
-      callback.onFailure(streamId, re);
-      deactivateStream();
-      throw re;
-    } else if (bytesRead == byteCount) {
-      deactivateStream();
-      callback.onComplete(streamId);
-    }
-
-    return bytesRead != byteCount;
-  }
-
-}

diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java
index d951665b2db..87beea110da 100644
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java
+++ b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClient.java
@@ -21,7 +21,6 @@
 import java.io.IOException;
 import java.net.SocketAddress;
 import java.nio.ByteBuffer;
-import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
@@ -39,7 +38,6 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
 import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer;
 import com.aliyun.emr.rss.common.network.protocol.*;
 import com.aliyun.emr.rss.common.network.util.NettyUtils;
@@ -115,6 +113,13 @@ public void setClientId(String id) {
     this.clientId = id;
   }
 
+  public void fetchChunk(
+      long streamId,
+      int chunkIndex,
+      ChunkReceivedCallback callback) {
+    fetchChunk(streamId, chunkIndex, 0, Integer.MAX_VALUE, callback);
+  }
+
   /**
    * Requests a single chunk from the remote side, from the pre-negotiated streamId.
    *
@@ -128,55 +133,32 @@ public void setClientId(String id) {
    * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
    *                 be agreed upon by client and server beforehand.
    * @param chunkIndex 0-based index of the chunk to fetch
+   * @param offset offset from the beginning of the chunk to fetch
+   * @param len size to fetch
    * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
    */
   public void fetchChunk(
       long streamId,
       int chunkIndex,
+      int offset,
+      int len,
       ChunkReceivedCallback callback) {
     if (logger.isDebugEnabled()) {
       logger.debug("Sending fetch chunk request {} to {}.",
         chunkIndex, NettyUtils.getRemoteAddress(channel));
     }
 
-    StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
-    StdChannelListener listener = new StdChannelListener(streamChunkId) {
+    StreamChunkSlice streamChunkSlice = new StreamChunkSlice(streamId, chunkIndex, offset, len);
+    StdChannelListener listener = new StdChannelListener(streamChunkSlice) {
       @Override
       protected void handleFailure(String errorMsg, Throwable cause) {
-        handler.removeFetchRequest(streamChunkId);
+        handler.removeFetchRequest(streamChunkSlice);
         callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
       }
     };
-    handler.addFetchRequest(streamChunkId, callback);
+    handler.addFetchRequest(streamChunkSlice, callback);
 
-    channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
-  }
-
-  /**
-   * Request to stream the data with the given stream ID from the remote end.
-   *
-   * @param streamId The stream to fetch.
-   * @param callback Object to call with the stream data.
-   */
-  public void stream(String streamId, StreamCallback callback) {
-    StdChannelListener listener = new StdChannelListener(streamId) {
-      @Override
-      protected void handleFailure(String errorMsg, Throwable cause) throws Exception {
-        callback.onFailure(streamId, new IOException(errorMsg, cause));
-      }
-    };
-    if (logger.isDebugEnabled()) {
-      logger.debug("Sending stream request for {} to {}.",
-        streamId, NettyUtils.getRemoteAddress(channel));
-    }
-
-    // Need to synchronize here so that the callback is added to the queue and the RPC is
-    // written to the socket atomically, so that callbacks are called in the right order
-    // when responses arrive.
- synchronized (this) { - handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener); - } + channel.writeAndFlush(new ChunkFetchRequest(streamChunkSlice)).addListener(listener); } /** @@ -192,7 +174,7 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { logger.trace("Sending RPC to {}", NettyUtils.getRemoteAddress(channel)); } - long requestId = dataRequestId(); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); RpcChannelListener listener = new RpcChannelListener(requestId, callback); @@ -207,7 +189,7 @@ public ChannelFuture pushData(PushData pushData, RpcResponseCallback callback) { logger.trace("Pushing data to {}", NettyUtils.getRemoteAddress(channel)); } - long requestId = dataRequestId(); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); pushData.requestId = requestId; @@ -221,7 +203,7 @@ public ChannelFuture pushMergedData(PushMergedData pushMergedData, RpcResponseCa logger.trace("Pushing merged data to {}", NettyUtils.getRemoteAddress(channel)); } - long requestId = dataRequestId(); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); pushMergedData.requestId = requestId; @@ -230,61 +212,6 @@ public ChannelFuture pushMergedData(PushMergedData pushMergedData, RpcResponseCa return channel.writeAndFlush(pushMergedData).addListener(listener); } - public ByteBuffer pushMergedDataSync(PushMergedData pushMergedData, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); - - pushMergedData(pushMergedData, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - ByteBuffer copy = ByteBuffer.allocate(response.remaining()); - copy.put(response); - copy.flip(); - result.set(copy); - } - - @Override - public void onFailure(Throwable e) { - result.setException(e); - } - }); - - try { - return result.get(timeoutMs, TimeUnit.MILLISECONDS); - } catch (ExecutionException e) { - throw Throwables.propagate(e.getCause()); - } catch (Exception e) { - throw Throwables.propagate(e); - } - } - - /** - * Send data to the remote end as a stream. This differs from stream() in that this is a request - * to *send* data to the remote end, not to receive it from the remote. - * - * @param meta meta data associated with the stream, which will be read completely on the - * receiving end before the stream itself. - * @param data this will be streamed to the remote end to allow for transferring large amounts - * of data without reading into memory. - * @param callback handles the reply -- onSuccess will only be called when both message and data - * are received successfully. - */ - public long uploadStream( - ManagedBuffer meta, - ManagedBuffer data, - RpcResponseCallback callback) { - if (logger.isTraceEnabled()) { - logger.trace("Sending RPC to {}", NettyUtils.getRemoteAddress(channel)); - } - - long requestId = requestId(); - handler.addRpcRequest(requestId, callback); - - RpcChannelListener listener = new RpcChannelListener(requestId, callback); - channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener); - - return requestId; - } - /** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. 
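For illustration, a minimal sketch of a caller using the new slice-based fetchChunk overload added in this hunk; the stream id, chunk index, and byte range are hypothetical:

    import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
    import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
    import com.aliyun.emr.rss.common.network.client.TransportClient;

    public class FetchSliceExample {
      // Fetch bytes [1024, 1024 + 65536) of chunk 3 from pre-negotiated stream 42.
      static void fetchSlice(TransportClient client) {
        client.fetchChunk(42L, 3, 1024, 65536, new ChunkReceivedCallback() {
          @Override
          public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
            // Read the slice here. The response handler releases the buffer after
            // this callback returns, so retain() it if it must outlive this call.
          }

          @Override
          public void onFailure(int chunkIndex, Throwable e) {
            // The failed request has already been removed from the outstanding map.
          }
        });
      }
    }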
@@ -361,12 +288,8 @@ public String toString() { .toString(); } - public static long requestId() { - return Math.abs(UUID.randomUUID().getLeastSignificantBits()); - } - private static final AtomicLong counter = new AtomicLong(); - public static long dataRequestId() { + public static long requestId() { return counter.getAndIncrement(); } @@ -394,8 +317,8 @@ public void operationComplete(Future future) throws Exception { NettyUtils.getRemoteAddress(channel), timeTaken); } } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - NettyUtils.getRemoteAddress(channel), future.cause()); + String errorMsg = String.format("Failed to send RPC %s to %s: %s, channel will be closed", + requestId, NettyUtils.getRemoteAddress(channel), future.cause()); logger.warn(errorMsg); channel.close(); try { @@ -406,7 +329,9 @@ public void operationComplete(Future future) throws Exception { } } - protected void handleFailure(String errorMsg, Throwable cause) throws Exception {} + protected void handleFailure(String errorMsg, Throwable cause) { + logger.error("Error encountered " + errorMsg, cause); + } } private class RpcChannelListener extends StdChannelListener { @@ -425,5 +350,4 @@ protected void handleFailure(String errorMsg, Throwable cause) { callback.onFailure(new IOException(errorMsg, cause)); } } - } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClientFactory.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClientFactory.java index 6b06f8c7584..575d9e85744 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClientFactory.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportClientFactory.java @@ -119,7 +119,7 @@ public TransportClientFactory( * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort, int reduceId) + public TransportClient createClient(String remoteHost, int remotePort, int partitionId) throws IOException, InterruptedException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. @@ -135,7 +135,7 @@ public TransportClient createClient(String remoteHost, int remotePort, int reduc } int clientIndex = - reduceId < 0 ? rand.nextInt(numConnectionsPerPeer) : reduceId % numConnectionsPerPeer; + partitionId < 0 ? 
rand.nextInt(numConnectionsPerPeer) : partitionId % numConnectionsPerPeer; TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java index 328d8c79666..7931e4014f1 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java @@ -19,22 +19,16 @@ import java.io.IOException; import java.util.Map; -import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; -import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.aliyun.emr.rss.common.network.protocol.*; import com.aliyun.emr.rss.common.network.server.MessageHandler; import com.aliyun.emr.rss.common.network.util.NettyUtils; -import com.aliyun.emr.rss.common.network.util.TransportFrameDecoder; /** * Handler that processes server responses, in response to requests issued from a @@ -47,13 +41,10 @@ public class TransportResponseHandler extends MessageHandler { private final Channel channel; - private final Map outstandingFetches; + private final Map outstandingFetches; private final Map outstandingRpcs; - private final Queue> streamCallbacks; - private volatile boolean streamActive; - /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -61,17 +52,16 @@ public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap<>(); this.outstandingRpcs = new ConcurrentHashMap<>(); - this.streamCallbacks = new ConcurrentLinkedQueue<>(); this.timeOfLastRequestNs = new AtomicLong(0); } - public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + public void addFetchRequest(StreamChunkSlice streamChunkSlice, ChunkReceivedCallback callback) { updateTimeOfLastRequest(); - outstandingFetches.put(streamChunkId, callback); + outstandingFetches.put(streamChunkSlice, callback); } - public void removeFetchRequest(StreamChunkId streamChunkId) { - outstandingFetches.remove(streamChunkId); + public void removeFetchRequest(StreamChunkSlice streamChunkSlice) { + outstandingFetches.remove(streamChunkSlice); } public void addRpcRequest(long requestId, RpcResponseCallback callback) { @@ -83,22 +73,12 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } - public void addStreamCallback(String streamId, StreamCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(ImmutablePair.of(streamId, callback)); - } - - @VisibleForTesting - public void deactivateStream() { - streamActive = false; - } - /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. 
*/ private void failOutstandingRequests(Throwable cause) { - for (Map.Entry entry : outstandingFetches.entrySet()) { + for (Map.Entry entry : outstandingFetches.entrySet()) { try { entry.getValue().onFailure(entry.getKey().chunkIndex, cause); } catch (Exception e) { @@ -112,18 +92,10 @@ private void failOutstandingRequests(Throwable cause) { logger.warn("RpcResponseCallback.onFailure throws exception", e); } } - for (Pair entry : streamCallbacks) { - try { - entry.getValue().onFailure(entry.getKey(), cause); - } catch (Exception e) { - logger.warn("StreamCallback.onFailure throws exception", e); - } - } // It's OK if new fetches appear, as they will fail immediately. outstandingFetches.clear(); outstandingRpcs.clear(); - streamCallbacks.clear(); } @Override @@ -154,26 +126,26 @@ public void exceptionCaught(Throwable cause) { public void handle(ResponseMessage message) throws Exception { if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; - ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkSlice); if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", - resp.streamChunkId, NettyUtils.getRemoteAddress(channel)); + resp.streamChunkSlice, NettyUtils.getRemoteAddress(channel)); resp.body().release(); } else { - outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + outstandingFetches.remove(resp.streamChunkSlice); + listener.onSuccess(resp.streamChunkSlice.chunkIndex, resp.body()); resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; - ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkSlice); if (listener == null) { logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", - resp.streamChunkId, NettyUtils.getRemoteAddress(channel), resp.errorString); + resp.streamChunkSlice, NettyUtils.getRemoteAddress(channel), resp.errorString); } else { - outstandingFetches.remove(resp.streamChunkId); - listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( - "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString)); + outstandingFetches.remove(resp.streamChunkSlice); + listener.onFailure(resp.streamChunkSlice.chunkIndex, new ChunkFetchFailureException( + "Failure while fetching " + resp.streamChunkSlice + ": " + resp.errorString)); } } else if (message instanceof RpcResponse) { RpcResponse resp = (RpcResponse) message; @@ -199,46 +171,6 @@ public void handle(ResponseMessage message) throws Exception { outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); } - } else if (message instanceof StreamResponse) { - StreamResponse resp = (StreamResponse) message; - Pair entry = streamCallbacks.poll(); - if (entry != null) { - StreamCallback callback = entry.getValue(); - if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor<>( - this, resp.streamId, resp.byteCount, callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - streamActive = true; - } catch (Exception e) { - logger.error("Error installing stream 
handler.", e); - deactivateStream(); - } - } else { - try { - callback.onComplete(resp.streamId); - } catch (Exception e) { - logger.warn("Error in stream handler onComplete().", e); - } - } - } else { - logger.error("Could not find callback for StreamResponse."); - } - } else if (message instanceof StreamFailure) { - StreamFailure resp = (StreamFailure) message; - Pair entry = streamCallbacks.poll(); - if (entry != null) { - StreamCallback callback = entry.getValue(); - try { - callback.onFailure(resp.streamId, new RuntimeException(resp.error)); - } catch (IOException ioe) { - logger.warn("Error in stream failure handler.", ioe); - } - } else { - logger.warn("Stream failure with unknown callback: {}", resp.error); - } } else { throw new IllegalStateException("Unknown response type: " + message.type()); } @@ -246,8 +178,7 @@ public void handle(ResponseMessage message) throws Exception { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + - (streamActive ? 1 : 0); + return outstandingFetches.size() + outstandingRpcs.size(); } /** Returns the time in nanoseconds of when the last request was sent out. */ diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/AbstractMessage.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/AbstractMessage.java index f429cd307f6..5b936287472 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/AbstractMessage.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/AbstractMessage.java @@ -17,7 +17,11 @@ package com.aliyun.emr.rss.common.network.protocol; +import java.nio.ByteBuffer; + import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; @@ -51,4 +55,24 @@ protected boolean equals(AbstractMessage other) { return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); } + public ByteBuffer toByteBuffer() { + // Allow room for encoded message, plus the type byte + ByteBuf buf = Unpooled.buffer(encodedLength() + 1); + buf.writeByte(type().id()); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.nioBuffer(); + } + + public static AbstractMessage fromByteBuffer(ByteBuffer msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + Type type = Type.decode(buf); + switch (type) { + case OpenStream: + return OpenStream.decode(buf); + case StreamHandle: + return StreamHandle.decode(buf); + default: throw new IllegalArgumentException("Unknown message type: " + type); + } + } } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchFailure.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchFailure.java index 270829feb8f..cea8ac93ed4 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchFailure.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchFailure.java @@ -24,11 +24,11 @@ * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. 
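The toByteBuffer()/fromByteBuffer() helpers added to AbstractMessage above lend themselves to a quick round-trip check; a sketch, with a placeholder shuffle key and file name (OpenStream itself is introduced later in this patch):

    import java.nio.ByteBuffer;

    import com.aliyun.emr.rss.common.network.protocol.AbstractMessage;
    import com.aliyun.emr.rss.common.network.protocol.OpenStream;

    public class MessageRoundTripExample {
      public static void main(String[] args) {
        OpenStream open = new OpenStream("app-1-attempt-0", "shuffle-file-0", 0, 5);
        // toByteBuffer() prepends the one-byte type id, then the encoded fields.
        ByteBuffer wire = open.toByteBuffer();
        // fromByteBuffer() reads the type id back and dispatches to OpenStream.decode.
        OpenStream decoded = (OpenStream) AbstractMessage.fromByteBuffer(wire);
        assert open.equals(decoded);
      }
    }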
*/ public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { - public final StreamChunkId streamChunkId; + public final StreamChunkSlice streamChunkSlice; public final String errorString; - public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { - this.streamChunkId = streamChunkId; + public ChunkFetchFailure(StreamChunkSlice streamChunkSlice, String errorString) { + this.streamChunkSlice = streamChunkSlice; this.errorString = errorString; } @@ -37,31 +37,31 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { @Override public int encodedLength() { - return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString); + return streamChunkSlice.encodedLength() + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { - streamChunkId.encode(buf); + streamChunkSlice.encode(buf); Encoders.Strings.encode(buf, errorString); } public static ChunkFetchFailure decode(ByteBuf buf) { - StreamChunkId streamChunkId = StreamChunkId.decode(buf); + StreamChunkSlice streamChunkSlice = StreamChunkSlice.decode(buf); String errorString = Encoders.Strings.decode(buf); - return new ChunkFetchFailure(streamChunkId, errorString); + return new ChunkFetchFailure(streamChunkSlice, errorString); } @Override public int hashCode() { - return Objects.hashCode(streamChunkId, errorString); + return Objects.hashCode(streamChunkSlice, errorString); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchFailure) { ChunkFetchFailure o = (ChunkFetchFailure) other; - return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); + return streamChunkSlice.equals(o.streamChunkSlice) && errorString.equals(o.errorString); } return false; } @@ -69,7 +69,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("streamChunkId", streamChunkId) + .add("streamChunkId", streamChunkSlice) .add("errorString", errorString) .toString(); } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchRequest.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchRequest.java index 91024301f51..7ad3f8c8044 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchRequest.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchRequest.java @@ -25,10 +25,10 @@ * {@link ResponseMessage} (either success or failure). 
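Every ChunkFetch* message now carries a StreamChunkSlice rather than a StreamChunkId (the renamed class appears further down in this patch); its fixed 20-byte encoding round-trips as sketched below, with illustrative values:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    import com.aliyun.emr.rss.common.network.protocol.StreamChunkSlice;

    public class SliceEncodeExample {
      public static void main(String[] args) {
        StreamChunkSlice slice = new StreamChunkSlice(42L, 3, 1024, 65536);
        ByteBuf buf = Unpooled.buffer(slice.encodedLength()); // 20 bytes: one long, three ints
        slice.encode(buf);
        StreamChunkSlice decoded = StreamChunkSlice.decode(buf);
        assert slice.equals(decoded);
      }
    }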
*/ public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { - public final StreamChunkId streamChunkId; + public final StreamChunkSlice streamChunkSlice; - public ChunkFetchRequest(StreamChunkId streamChunkId) { - this.streamChunkId = streamChunkId; + public ChunkFetchRequest(StreamChunkSlice streamChunkSlice) { + this.streamChunkSlice = streamChunkSlice; } @Override @@ -36,28 +36,28 @@ public ChunkFetchRequest(StreamChunkId streamChunkId) { @Override public int encodedLength() { - return streamChunkId.encodedLength(); + return streamChunkSlice.encodedLength(); } @Override public void encode(ByteBuf buf) { - streamChunkId.encode(buf); + streamChunkSlice.encode(buf); } public static ChunkFetchRequest decode(ByteBuf buf) { - return new ChunkFetchRequest(StreamChunkId.decode(buf)); + return new ChunkFetchRequest(StreamChunkSlice.decode(buf)); } @Override public int hashCode() { - return streamChunkId.hashCode(); + return streamChunkSlice.hashCode(); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchRequest) { ChunkFetchRequest o = (ChunkFetchRequest) other; - return streamChunkId.equals(o.streamChunkId); + return streamChunkSlice.equals(o.streamChunkSlice); } return false; } @@ -65,7 +65,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("streamChunkId", streamChunkId) + .add("streamChunkId", streamChunkSlice) .toString(); } } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchSuccess.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchSuccess.java index 9ac06edc204..bd0b0b6f260 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchSuccess.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/ChunkFetchSuccess.java @@ -31,11 +31,11 @@ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ public final class ChunkFetchSuccess extends AbstractResponseMessage { - public final StreamChunkId streamChunkId; + public final StreamChunkSlice streamChunkSlice; - public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + public ChunkFetchSuccess(StreamChunkSlice streamChunkSlice, ManagedBuffer buffer) { super(buffer, true); - this.streamChunkId = streamChunkId; + this.streamChunkSlice = streamChunkSlice; } @Override @@ -43,38 +43,38 @@ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { @Override public int encodedLength() { - return streamChunkId.encodedLength(); + return streamChunkSlice.encodedLength(); } /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ @Override public void encode(ByteBuf buf) { - streamChunkId.encode(buf); + streamChunkSlice.encode(buf); } @Override public ResponseMessage createFailureResponse(String error) { - return new ChunkFetchFailure(streamChunkId, error); + return new ChunkFetchFailure(streamChunkSlice, error); } /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ public static ChunkFetchSuccess decode(ByteBuf buf) { - StreamChunkId streamChunkId = StreamChunkId.decode(buf); + StreamChunkSlice streamChunkSlice = StreamChunkSlice.decode(buf); buf.retain(); NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); - return new ChunkFetchSuccess(streamChunkId, managedBuf); + return new ChunkFetchSuccess(streamChunkSlice, managedBuf); } @Override public int hashCode() { - return Objects.hashCode(streamChunkId, body()); + return Objects.hashCode(streamChunkSlice, body()); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && super.equals(o); + return streamChunkSlice.equals(o.streamChunkSlice) && super.equals(o); } return false; } @@ -82,7 +82,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("streamChunkId", streamChunkId) + .add("streamChunkId", streamChunkSlice) .add("buffer", body()) .toString(); } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Encoders.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Encoders.java index ca8e29e7cc1..fde272bc536 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Encoders.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Encoders.java @@ -44,25 +44,6 @@ public static String decode(ByteBuf buf) { } } - /** Byte arrays are encoded with their length followed by bytes. */ - public static class ByteArrays { - public static int encodedLength(byte[] arr) { - return 4 + arr.length; - } - - public static void encode(ByteBuf buf, byte[] arr) { - buf.writeInt(arr.length); - buf.writeBytes(arr); - } - - public static byte[] decode(ByteBuf buf) { - int length = buf.readInt(); - byte[] bytes = new byte[length]; - buf.readBytes(bytes); - return bytes; - } - } - /** Int arrays are encoded with their length followed by ints. */ public static class IntArrays { public static int encodedLength(int[] ints) { diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Message.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Message.java index b69a309fd8b..d391d9dacfb 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Message.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/Message.java @@ -35,9 +35,8 @@ public interface Message extends Encodable { /** Preceding every serialized Message is its type, which allows us to deserialize it. 
*/ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), - RpcRequest(3), RpcResponse(4), RpcFailure(5), - StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), UploadStream(10), PushData(11), PushMergedData(12), User(-1); + RpcRequest(3), RpcResponse(4), RpcFailure(5), OpenStream(6), StreamHandle(7), + OneWayMessage(9), PushData(11), PushMergedData(12); private final byte id; @@ -61,11 +60,9 @@ public static Type decode(ByteBuf buf) { case 3: return RpcRequest; case 4: return RpcResponse; case 5: return RpcFailure; - case 6: return StreamRequest; - case 7: return StreamResponse; - case 8: return StreamFailure; + case 6: return OpenStream; + case 7: return StreamHandle; case 9: return OneWayMessage; - case 10: return UploadStream; case 11: return PushData; case 12: return PushMergedData; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/MessageDecoder.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/MessageDecoder.java index 31b5643365e..53acd7b45d7 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/MessageDecoder.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/MessageDecoder.java @@ -71,18 +71,6 @@ private Message decode(Message.Type msgType, ByteBuf in) { case OneWayMessage: return OneWayMessage.decode(in); - case StreamRequest: - return StreamRequest.decode(in); - - case StreamResponse: - return StreamResponse.decode(in); - - case StreamFailure: - return StreamFailure.decode(in); - - case UploadStream: - return UploadStream.decode(in); - case PushData: return PushData.decode(in); diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OneWayMessage.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OneWayMessage.java index d04d19672d5..ff4bcff689e 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OneWayMessage.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OneWayMessage.java @@ -22,11 +22,11 @@ import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; import com.aliyun.emr.rss.common.network.buffer.NettyManagedBuffer; -import com.aliyun.emr.rss.common.network.server.RpcHandler; +import com.aliyun.emr.rss.common.network.server.BaseMessageHandler; /** * A RPC that does not expect a reply, which is handled by a remote - * {@link RpcHandler}. + * {@link BaseMessageHandler}. */ public final class OneWayMessage extends AbstractMessage implements RequestMessage { diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OpenStream.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OpenStream.java new file mode 100644 index 00000000000..08f8f318236 --- /dev/null +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/OpenStream.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.aliyun.emr.rss.common.network.protocol; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Request to read a set of blocks. Returns {@link StreamHandle}. */ +public final class OpenStream extends AbstractMessage { + public byte[] shuffleKey; + public byte[] fileName; + public int startMapIndex; + public int endMapIndex; + + public OpenStream(String shuffleKey, String fileName, int startMapIndex, int endMapIndex) { + this(shuffleKey.getBytes(StandardCharsets.UTF_8), + fileName.getBytes(StandardCharsets.UTF_8), + startMapIndex, endMapIndex); + } + + public OpenStream(byte[] shuffleKey, byte[] fileName, int startMapIndex, int endMapIndex) { + this.shuffleKey = shuffleKey; + this.fileName = fileName; + this.startMapIndex = startMapIndex; + this.endMapIndex = endMapIndex; + } + + @Override + public Type type() { return Type.OpenStream; } + + @Override + public int encodedLength() { + return 4 + shuffleKey.length + + 4 + fileName.length + + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeInt(shuffleKey.length); + buf.writeBytes(shuffleKey); + buf.writeInt(fileName.length); + buf.writeBytes(fileName); + buf.writeInt(startMapIndex); + buf.writeInt(endMapIndex); + } + + public static OpenStream decode(ByteBuf buf) { + int shuffleKeySize = buf.readInt(); + byte[] shuffleKey = new byte[shuffleKeySize]; + buf.readBytes(shuffleKey); + int fileNameSize = buf.readInt(); + byte[] fileName = new byte[fileNameSize]; + buf.readBytes(fileName); + return new OpenStream(shuffleKey, fileName, buf.readInt(), buf.readInt()); + } + + @Override + public int hashCode() { + return Objects.hashCode(shuffleKey, fileName, startMapIndex, endMapIndex); + } + + @Override + public boolean equals(Object other) { + if (other instanceof OpenStream) { + OpenStream o = (OpenStream) other; + return startMapIndex == o.startMapIndex && + endMapIndex == o.endMapIndex && + Arrays.equals(shuffleKey, o.shuffleKey) && + Arrays.equals(fileName, o.fileName); + + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("shuffleKey", new String(shuffleKey, StandardCharsets.UTF_8)) + .add("fileName", new String(fileName, StandardCharsets.UTF_8)) + .add("startMapIndex", startMapIndex) + .add("endMapIndex", endMapIndex) + .toString(); + } +} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/RpcRequest.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/RpcRequest.java index 14fbc54774d..3b36ecd5a2a 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/RpcRequest.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/RpcRequest.java @@ -22,10 +22,10 @@ import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; import com.aliyun.emr.rss.common.network.buffer.NettyManagedBuffer; -import com.aliyun.emr.rss.common.network.server.RpcHandler; +import com.aliyun.emr.rss.common.network.server.BaseMessageHandler; /** - * A generic RPC which is handled by 
a remote {@link RpcHandler}. + * A generic RPC which is handled by a remote {@link BaseMessageHandler}. * This will correspond to a single * {@link ResponseMessage} (either success or failure). */ diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamChunkId.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamChunkSlice.java similarity index 62% rename from common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamChunkId.java rename to common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamChunkSlice.java index 9299951856a..3c4b4e1a002 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamChunkId.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamChunkSlice.java @@ -23,42 +23,60 @@ /** * Encapsulates a request for a particular chunk of a stream. */ -public final class StreamChunkId implements Encodable { +public final class StreamChunkSlice implements Encodable { public final long streamId; public final int chunkIndex; + /** offset from the beginning of the chunk */ + public final int offset; + /** size to read */ + public final int len; - public StreamChunkId(long streamId, int chunkIndex) { + public StreamChunkSlice(long streamId, int chunkIndex) { this.streamId = streamId; this.chunkIndex = chunkIndex; + this.offset = 0; + this.len = Integer.MAX_VALUE; + } + + public StreamChunkSlice(long streamId, int chunkIndex, int offset, int len) { + this.streamId = streamId; + this.chunkIndex = chunkIndex; + this.offset = offset; + this.len = len; } @Override public int encodedLength() { - return 8 + 4; + return 20; } public void encode(ByteBuf buffer) { buffer.writeLong(streamId); buffer.writeInt(chunkIndex); + buffer.writeInt(offset); + buffer.writeInt(len); } - public static StreamChunkId decode(ByteBuf buffer) { - assert buffer.readableBytes() >= 8 + 4; + public static StreamChunkSlice decode(ByteBuf buffer) { + assert buffer.readableBytes() >= 20; long streamId = buffer.readLong(); int chunkIndex = buffer.readInt(); - return new StreamChunkId(streamId, chunkIndex); + int offset = buffer.readInt(); + int len = buffer.readInt(); + return new StreamChunkSlice(streamId, chunkIndex, offset, len); } @Override public int hashCode() { - return Objects.hashCode(streamId, chunkIndex); + return Objects.hashCode(streamId, chunkIndex, offset, len); } @Override public boolean equals(Object other) { - if (other instanceof StreamChunkId) { - StreamChunkId o = (StreamChunkId) other; - return streamId == o.streamId && chunkIndex == o.chunkIndex; + if (other instanceof StreamChunkSlice) { + StreamChunkSlice o = (StreamChunkSlice) other; + return streamId == o.streamId && chunkIndex == o.chunkIndex && + offset == o.offset && len == o.len; } return false; } @@ -68,6 +86,8 @@ public String toString() { return Objects.toStringHelper(this) .add("streamId", streamId) .add("chunkIndex", chunkIndex) + .add("offset", offset) + .add("len", len) .toString(); } } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamFailure.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamFailure.java deleted file mode 100644 index e54e05158bf..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamFailure.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
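Putting the new messages together, a hedged sketch of the intended flow: the client opens a stream with OpenStream, receives a StreamHandle, then fetches each chunk as one maximal slice. The shuffle key, file name, map-index range, and timeout are placeholders, and sendRpcSync is the synchronous send variant whose javadoc appears earlier in this diff:

    import java.nio.ByteBuffer;

    import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
    import com.aliyun.emr.rss.common.network.client.TransportClient;
    import com.aliyun.emr.rss.common.network.protocol.AbstractMessage;
    import com.aliyun.emr.rss.common.network.protocol.OpenStream;
    import com.aliyun.emr.rss.common.network.protocol.StreamHandle;

    public class OpenStreamFlowExample {
      static void readStream(TransportClient client, ChunkReceivedCallback callback) {
        ByteBuffer openMsg = new OpenStream("shuffleKey", "fileName", 0, 10).toByteBuffer();
        // The server registers the file's chunks and answers with a StreamHandle.
        StreamHandle handle =
            (StreamHandle) AbstractMessage.fromByteBuffer(client.sendRpcSync(openMsg, 30000));
        for (int i = 0; i < handle.numChunks; i++) {
          // offset 0 with len Integer.MAX_VALUE reads the whole chunk
          client.fetchChunk(handle.streamId, i, 0, Integer.MAX_VALUE, callback);
        }
      }
    }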
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -/** - * Message indicating an error when transferring a stream. - */ -public final class StreamFailure extends AbstractMessage implements ResponseMessage { - public final String streamId; - public final String error; - - public StreamFailure(String streamId, String error) { - this.streamId = streamId; - this.error = error; - } - - @Override - public Type type() { return Type.StreamFailure; } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, streamId); - Encoders.Strings.encode(buf, error); - } - - public static StreamFailure decode(ByteBuf buf) { - String streamId = Encoders.Strings.decode(buf); - String error = Encoders.Strings.decode(buf); - return new StreamFailure(streamId, error); - } - - @Override - public int hashCode() { - return Objects.hashCode(streamId, error); - } - - @Override - public boolean equals(Object other) { - if (other instanceof StreamFailure) { - StreamFailure o = (StreamFailure) other; - return streamId.equals(o.streamId) && error.equals(o.error); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamId", streamId) - .add("error", error) - .toString(); - } - -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamRequest.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamHandle.java similarity index 62% rename from common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamRequest.java rename to common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamHandle.java index 240c26d0bd8..2431b436c23 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamRequest.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamHandle.java @@ -21,46 +21,46 @@ import io.netty.buffer.ByteBuf; /** - * Request to stream data from the remote end. - *
- * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before - * the data can be streamed. + * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" + * message. */ -public final class StreamRequest extends AbstractMessage implements RequestMessage { - public final String streamId; +public final class StreamHandle extends AbstractMessage { + public final long streamId; + public final int numChunks; - public StreamRequest(String streamId) { + public StreamHandle(long streamId, int numChunks) { this.streamId = streamId; + this.numChunks = numChunks; } @Override - public Type type() { return Type.StreamRequest; } + public Type type() { return Type.StreamHandle; } @Override public int encodedLength() { - return Encoders.Strings.encodedLength(streamId); + return 8 + 4; } @Override public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, streamId); + buf.writeLong(streamId); + buf.writeInt(numChunks); } - public static StreamRequest decode(ByteBuf buf) { - String streamId = Encoders.Strings.decode(buf); - return new StreamRequest(streamId); + public static StreamHandle decode(ByteBuf buf) { + return new StreamHandle(buf.readLong(), buf.readInt()); } @Override public int hashCode() { - return Objects.hashCode(streamId); + return Objects.hashCode(streamId, numChunks); } @Override public boolean equals(Object other) { - if (other instanceof StreamRequest) { - StreamRequest o = (StreamRequest) other; - return streamId.equals(o.streamId); + if (other instanceof StreamHandle) { + StreamHandle o = (StreamHandle) other; + return streamId == o.streamId && numChunks == o.numChunks; } return false; } @@ -69,7 +69,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamId", streamId) + .add("numChunks", numChunks) .toString(); } - } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamResponse.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamResponse.java deleted file mode 100644 index 25c9f8612c8..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/StreamResponse.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; - -/** - * Response to {@link StreamRequest} when the stream has been successfully opened. - *
- * Note the message itself does not contain the stream data. That is written separately by the - * sender. The receiver is expected to set a temporary channel handler that will consume the - * number of bytes this message says the stream has. - */ -public final class StreamResponse extends AbstractResponseMessage { - public final String streamId; - public final long byteCount; - - public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { - super(buffer, false); - this.streamId = streamId; - this.byteCount = byteCount; - } - - @Override - public Type type() { return Type.StreamResponse; } - - @Override - public int encodedLength() { - return 8 + Encoders.Strings.encodedLength(streamId); - } - - /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, streamId); - buf.writeLong(byteCount); - } - - @Override - public ResponseMessage createFailureResponse(String error) { - return new StreamFailure(streamId, error); - } - - public static StreamResponse decode(ByteBuf buf) { - String streamId = Encoders.Strings.decode(buf); - long byteCount = buf.readLong(); - return new StreamResponse(streamId, byteCount, null); - } - - @Override - public int hashCode() { - return Objects.hashCode(byteCount, streamId); - } - - @Override - public boolean equals(Object other) { - if (other instanceof StreamResponse) { - StreamResponse o = (StreamResponse) other; - return byteCount == o.byteCount && streamId.equals(o.streamId); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamId", streamId) - .add("byteCount", byteCount) - .add("body", body()) - .toString(); - } - -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/UploadStream.java b/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/UploadStream.java deleted file mode 100644 index 5cb8e1daec8..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/protocol/UploadStream.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.protocol; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; -import com.aliyun.emr.rss.common.network.buffer.NettyManagedBuffer; - -/** - * An RPC with data that is sent outside of the frame, so it can be read as a stream. - */ -public final class UploadStream extends AbstractMessage implements RequestMessage { - /** Used to link an RPC request with its response. 
*/ - public final long requestId; - public final ManagedBuffer meta; - public final long bodyByteCount; - - public UploadStream(long requestId, ManagedBuffer meta, ManagedBuffer body) { - super(body, false); // body is *not* included in the frame - this.requestId = requestId; - this.meta = meta; - bodyByteCount = body.size(); - } - - // this version is called when decoding the bytes on the receiving end. The body is handled - // separately. - private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { - super(null, false); - this.requestId = requestId; - this.meta = meta; - this.bodyByteCount = bodyByteCount; - } - - @Override - public Type type() { return Type.UploadStream; } - - @Override - public int encodedLength() { - // the requestId, meta size, meta and bodyByteCount (body is not included) - return 8 + 4 + ((int) meta.size()) + 8; - } - - @Override - public void encode(ByteBuf buf) { - buf.writeLong(requestId); - try { - ByteBuffer metaBuf = meta.nioByteBuffer(); - buf.writeInt(metaBuf.remaining()); - buf.writeBytes(metaBuf); - } catch (IOException io) { - throw new RuntimeException(io); - } - buf.writeLong(bodyByteCount); - } - - public static UploadStream decode(ByteBuf buf) { - long requestId = buf.readLong(); - int metaSize = buf.readInt(); - ManagedBuffer meta = new NettyManagedBuffer(buf.readRetainedSlice(metaSize)); - long bodyByteCount = buf.readLong(); - // This is called by the frame decoder, so the data is still null. We need a StreamInterceptor - // to read the data. - return new UploadStream(requestId, meta, bodyByteCount); - } - - @Override - public int hashCode() { - return Long.hashCode(requestId); - } - - @Override - public boolean equals(Object other) { - if (other instanceof UploadStream) { - UploadStream o = (UploadStream) other; - return requestId == o.requestId && super.equals(o); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("requestId", requestId) - .add("body", body()) - .toString(); - } -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/NoOpRpcHandler.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/BaseMessageHandler.java similarity index 53% rename from common/src/main/java/com/aliyun/emr/rss/common/network/server/NoOpRpcHandler.java rename to common/src/main/java/com/aliyun/emr/rss/common/network/server/BaseMessageHandler.java index d35508f16ea..8ad95a00680 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/NoOpRpcHandler.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/server/BaseMessageHandler.java @@ -17,24 +17,34 @@ package com.aliyun.emr.rss.common.network.server; -import java.nio.ByteBuffer; - -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; import com.aliyun.emr.rss.common.network.client.TransportClient; +import com.aliyun.emr.rss.common.network.protocol.RequestMessage; -/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ -public class NoOpRpcHandler extends RpcHandler { - private final StreamManager streamManager; +/** + * Handler for sendRPC() messages sent by {@link TransportClient}s. 
+ */ +public class BaseMessageHandler { - public NoOpRpcHandler() { - streamManager = new OneForOneStreamManager(); + public void receive( + TransportClient client, + RequestMessage msg) { + throw new UnsupportedOperationException(); } - @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - throw new UnsupportedOperationException("Cannot handle messages"); + public boolean checkRegistered() { + throw new UnsupportedOperationException(); } - @Override - public StreamManager getStreamManager() { return streamManager; } + /** + * Invoked when the channel associated with the given client is active. + */ + public void channelActive(TransportClient client) { } + + /** + * Invoked when the channel associated with the given client is inactive. + * No further requests will come from this client. + */ + public void channelInactive(TransportClient client) { } + + public void exceptionCaught(Throwable cause, TransportClient client) { } } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/ManagedBufferIterator.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/FileManagedBuffers.java similarity index 68% rename from common/src/main/java/com/aliyun/emr/rss/common/network/server/ManagedBufferIterator.java rename to common/src/main/java/com/aliyun/emr/rss/common/network/server/FileManagedBuffers.java index 87aa15ee0f5..f6e7f09a3f8 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/ManagedBufferIterator.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/server/FileManagedBuffers.java @@ -20,13 +20,12 @@ import java.io.File; import java.io.IOException; import java.util.BitSet; -import java.util.Iterator; import com.aliyun.emr.rss.common.network.buffer.FileSegmentManagedBuffer; import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; import com.aliyun.emr.rss.common.network.util.TransportConf; -public class ManagedBufferIterator implements Iterator { +public class FileManagedBuffers { private final File file; private final long[] offsets; private final int numChunks; @@ -34,9 +33,9 @@ public class ManagedBufferIterator implements Iterator { private final BitSet chunkTracker; private final TransportConf conf; - private int index = 0; + private volatile boolean fullyRead = false; - public ManagedBufferIterator(FileInfo fileInfo, TransportConf conf) throws IOException { + public FileManagedBuffers(FileInfo fileInfo, TransportConf conf) throws IOException { file = fileInfo.file; numChunks = fileInfo.numChunks; if (numChunks > 0) { @@ -53,11 +52,8 @@ public ManagedBufferIterator(FileInfo fileInfo, TransportConf conf) throws IOExc this.conf = conf; } - @Override - public boolean hasNext() { - synchronized (chunkTracker) { - return chunkTracker.cardinality() < numChunks; - } + public int numChunks() { + return numChunks; } public boolean hasAlreadyRead(int chunkIndex) { @@ -66,23 +62,27 @@ public boolean hasAlreadyRead(int chunkIndex) { } } - @Override - public ManagedBuffer next() { - // This method is only used to clear the Managed Buffer when streamManager.connectionTerminated - // is called. 
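// A worked example of the slice arithmetic in chunk() below; the values are
// illustrative. With offsets = {0, 100, 250}, the call chunk(1, 30, 1000)
// resolves to the file range starting at offsets[1] + 30 = 130 with length
// min((250 - 100) - 30, 1000) = 120, i.e. the tail of chunk 1. Because
// offset + len reaches past the chunk's end, the chunk is marked read in
// chunkTracker; once the final chunk is consumed this way, fullyRead is set
// and the stream manager knows it may drop the stream.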
+ public ManagedBuffer chunk(int chunkIndex, int offset, int len) { synchronized (chunkTracker) { - index = chunkTracker.nextClearBit(index); + chunkTracker.set(chunkIndex, true); } - assert index < numChunks; - return chunk(index); + // offset of the beginning of the chunk in the file + final long chunkOffset = offsets[chunkIndex]; + final long chunkLength = offsets[chunkIndex + 1] - chunkOffset; + assert offset < chunkLength; + long length = Math.min(chunkLength - offset, len); + if (len + offset >= chunkLength) { + synchronized (chunkTracker) { + chunkTracker.set(chunkIndex); + } + if (chunkIndex == numChunks - 1) { + fullyRead = true; + } + } + return new FileSegmentManagedBuffer(conf, file, chunkOffset + offset, length); } - public ManagedBuffer chunk(int chunkIndex) { - synchronized (chunkTracker) { - chunkTracker.set(chunkIndex, true); - } - final long offset = offsets[chunkIndex]; - final long length = offsets[chunkIndex + 1] - offset; - return new FileSegmentManagedBuffer(conf, file, offset, length); + public boolean isFullyRead() { + return fullyRead; } } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManager.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManager.java index 4d61648f60a..4b3d6a46749 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManager.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManager.java @@ -17,7 +17,6 @@ package com.aliyun.emr.rss.common.network.server; -import java.util.Iterator; import java.util.Map; import java.util.Random; import java.util.concurrent.ConcurrentHashMap; @@ -47,7 +46,7 @@ public class OneForOneStreamManager extends StreamManager { /** State of a single stream. */ protected static class StreamState { final String appId; - final Iterator buffers; + final FileManagedBuffers buffers; // The channel associated to the stream final Channel associatedChannel; @@ -59,7 +58,7 @@ protected static class StreamState { // Used to keep track of the number of chunks being transferred and not finished yet. 
volatile long chunksBeingTransferred = 0L; - StreamState(String appId, Iterator buffers, Channel channel) { + StreamState(String appId, FileManagedBuffers buffers, Channel channel) { this.appId = appId; this.buffers = Preconditions.checkNotNull(buffers); this.associatedChannel = channel; @@ -74,24 +73,24 @@ public OneForOneStreamManager() { } @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { + public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) { StreamState state = streams.get(streamId); if (state == null) { throw new IllegalStateException(String.format( "Stream %s for chunk %s is not registered(Maybe removed).", streamId, chunkIndex)); - } else if (!state.buffers.hasNext()) { + } else if (chunkIndex >= state.buffers.numChunks()) { throw new IllegalStateException(String.format( "Requested chunk index beyond end %s", chunkIndex)); } - ManagedBufferIterator iterator = (ManagedBufferIterator) state.buffers; - if (iterator.hasAlreadyRead(chunkIndex)) { + FileManagedBuffers buffers = state.buffers; + if (buffers.hasAlreadyRead(chunkIndex)) { throw new IllegalStateException(String.format( "Chunk %s for stream %s has already been read.", chunkIndex, streamId)); } - ManagedBuffer nextChunk = iterator.chunk(chunkIndex); + ManagedBuffer nextChunk = buffers.chunk(chunkIndex, offset, len); - if (!state.buffers.hasNext()) { + if (state.buffers.isFullyRead()) { // Normally, when all chunks are returned to the client, the stream should be removed here. // But if there is a switch on the client side, it will not go here at this time, so we need // to remove the stream when the connection is terminated, and release the unused buffer. @@ -102,13 +101,6 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { return nextChunk; } - @Override - public ManagedBuffer openStream(String streamChunkId) { - Pair streamChunkIdPair = parseStreamChunkId(streamChunkId); - logger.debug("StreamManager open stream {}", streamChunkId); - return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight()); - } - public static String genStreamChunkId(long streamId, int chunkId) { return String.format("%d_%d", streamId, chunkId); } @@ -126,36 +118,13 @@ public static Pair parseStreamChunkId(String streamChunkId) { @Override public void connectionTerminated(Channel channel) { - // SPARK-30246 - RuntimeException failedToReleaseBufferException = null; - // Close all streams which have been associated with the channel. for (Map.Entry entry: streams.entrySet()) { StreamState state = entry.getValue(); if (state.associatedChannel == channel) { streams.remove(entry.getKey()); - - try { - // Release all remaining buffers. - while (state.buffers.hasNext()) { - ManagedBuffer buffer = state.buffers.next(); - if (buffer != null) { - buffer.release(); - } - } - } catch (RuntimeException e) { - if (failedToReleaseBufferException == null) { - failedToReleaseBufferException = e; - } else { - logger.error("Exception trying to release remaining StreamState buffers", e); - } - } } } - - if (failedToReleaseBufferException != null) { - throw failedToReleaseBufferException; - } } @Override @@ -222,7 +191,7 @@ public long chunksBeingTransferred() { * to be the only reader of the stream. Once the connection is closed, the stream will never * be used again, enabling cleanup by `connectionTerminated`. 
*/ - public long registerStream(String appId, Iterator buffers, Channel channel) { + public long registerStream(String appId, FileManagedBuffers buffers, Channel channel) { long myStreamId = nextStreamId.getAndIncrement(); streams.put(myStreamId, new StreamState(appId, buffers, channel)); return myStreamId; diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/RpcHandler.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/RpcHandler.java deleted file mode 100644 index 6d46334d374..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/RpcHandler.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.server; - -import java.nio.ByteBuffer; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; -import com.aliyun.emr.rss.common.network.client.StreamCallback; -import com.aliyun.emr.rss.common.network.client.StreamCallbackWithID; -import com.aliyun.emr.rss.common.network.client.TransportClient; -import com.aliyun.emr.rss.common.network.protocol.PushData; -import com.aliyun.emr.rss.common.network.protocol.PushMergedData; - -/** - * Handler for sendRPC() messages sent by {@link TransportClient}s. - */ -public abstract class RpcHandler { - - private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); - - /** - * Receive a single RPC message. Any exception thrown while in this method will be sent back to - * the client in string form as a standard RPC failure. - * - * Neither this method nor #receiveStream will be called in parallel for a single - * TransportClient (i.e., channel). - * - * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. This will always be the exact same object for a particular channel. - * @param message The serialized bytes of the RPC. - * @param callback Callback which should be invoked exactly once upon success or failure of the - * RPC. - */ - public abstract void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback); - - /** - * Receive a single RPC message which includes data that is to be received as a stream. Any - * exception thrown while in this method will be sent back to the client in string form as a - * standard RPC failure. - * - * Neither this method nor #receive will be called in parallel for a single TransportClient - * (i.e., channel). - * - * An error while reading data from the stream - * ({@link StreamCallback#onData(String, ByteBuffer)}) - * will fail the entire channel. 
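
To see the registerStream signature above in context, here is a hypothetical server-side helper, not code from the patch: the class and method names are invented, and FileInfo's package is assumed to match FileManagedBuffers. Only the constructor, registerStream, and numChunks calls are taken from the diff itself.

    import java.io.IOException;
    import io.netty.channel.Channel;
    import com.aliyun.emr.rss.common.network.server.FileInfo;            // assumed package
    import com.aliyun.emr.rss.common.network.server.FileManagedBuffers;
    import com.aliyun.emr.rss.common.network.server.OneForOneStreamManager;
    import com.aliyun.emr.rss.common.network.util.TransportConf;

    // Hypothetical helper: serve an open-stream request by wrapping the shuffle
    // file's chunks in FileManagedBuffers and registering them for this channel.
    public class OpenStreamSketch {
      public static long openStream(OneForOneStreamManager streamManager, String appId,
          FileInfo fileInfo, TransportConf conf, Channel channel) throws IOException {
        FileManagedBuffers buffers = new FileManagedBuffers(fileInfo, conf);
        long streamId = streamManager.registerStream(appId, buffers, channel);
        // The caller would reply with streamId and buffers.numChunks() (the
        // StreamHandle shape) so the client can start fetching chunk slices.
        return streamId;
      }
    }
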
A failure in "post-processing" the stream in - * {@link StreamCallback#onComplete(String)} will result in an - * rpcFailure, but the channel will remain active. - * - * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. This will always be the exact same object for a particular channel. - * @param messageHeader The serialized bytes of the header portion of the RPC. This is in meant - * to be relatively small, and will be buffered entirely in memory, to - * facilitate how the streaming portion should be received. - * @param callback Callback which should be invoked exactly once upon success or failure of the - * RPC. - * @return a StreamCallback for handling the accompanying streaming data - */ - public StreamCallbackWithID receiveStream( - TransportClient client, - ByteBuffer messageHeader, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - public void receivePushData( - TransportClient client, - PushData pushData, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - public void receivePushMergedData( - TransportClient client, - PushMergedData pushMergedData, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - public boolean checkRegistered() { - return true; - } - - /** - * Returns the StreamManager which contains the state about which streams are currently being - * fetched by a TransportClient. - */ - public abstract StreamManager getStreamManager(); - - /** - * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if - * any of the callback methods are called. - * - * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. This will always be the exact same object for a particular channel. - * @param message The serialized bytes of the RPC. - */ - public void receive(TransportClient client, ByteBuffer message) { - receive(client, message, ONE_WAY_CALLBACK); - } - - /** - * Invoked when the channel associated with the given client is active. - */ - public void channelActive(TransportClient client) { } - - /** - * Invoked when the channel associated with the given client is inactive. - * No further requests will come from this client. - */ - public void channelInactive(TransportClient client) { } - - public void exceptionCaught(Throwable cause, TransportClient client) { } - - private static class OneWayRpcCallback implements RpcResponseCallback { - - private static final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); - - @Override - public void onSuccess(ByteBuffer response) { - logger.warn("Response provided for one-way RPC."); - } - - @Override - public void onFailure(Throwable e) { - logger.error("Error response provided for one-way RPC.", e); - } - - } - -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/StreamManager.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/StreamManager.java index 078fc3c8171..a765ffa2b41 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/StreamManager.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/server/StreamManager.java @@ -44,14 +44,14 @@ public abstract class StreamManager { * @param streamId id of a stream that has been previously registered with the StreamManager. 
* @param chunkIndex 0-indexed chunk of the stream that's requested */ - public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + public abstract ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len); /** * Called in response to a stream() request. The returned data is streamed to the client * through a single TCP connection. * * Note the streamId argument is not related to the similarly named argument in the - * {@link #getChunk(long, int)} method. + * {@link #getChunk(long, int, int, int)} method. * * @param streamId id of a stream that has been previously registered with the StreamManager. * @return A managed buffer for the stream, or null if the stream was not found. diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportRequestHandler.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportRequestHandler.java index 6f9662bfc7f..b42b08bd345 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportRequestHandler.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportRequestHandler.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.net.SocketAddress; -import java.nio.ByteBuffer; import com.google.common.base.Throwables; import io.netty.channel.Channel; @@ -27,18 +26,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.aliyun.emr.rss.common.metrics.source.AbstractSource; -import com.aliyun.emr.rss.common.metrics.source.NetWorkSource; -import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; -import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer; -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; -import com.aliyun.emr.rss.common.network.client.StreamCallbackWithID; -import com.aliyun.emr.rss.common.network.client.StreamInterceptor; import com.aliyun.emr.rss.common.network.client.TransportClient; import com.aliyun.emr.rss.common.network.protocol.*; -import com.aliyun.emr.rss.common.network.util.JavaUtils; -import com.aliyun.emr.rss.common.network.util.NettyUtils; -import com.aliyun.emr.rss.common.network.util.TransportFrameDecoder; /** * A handler that processes requests from clients and writes chunk data back. Each handler is @@ -58,98 +47,46 @@ public class TransportRequestHandler extends MessageHandler { private final TransportClient reverseClient; /** Handles all RPC messages. */ - private final RpcHandler rpcHandler; - - /** Returns each chunk part of a stream. */ - private final StreamManager streamManager; - - /** The max number of chunks being transferred and not finished yet. 
*/ - private final long maxChunksBeingTransferred; - - private AbstractSource source = null; - - public TransportRequestHandler( - Channel channel, - TransportClient reverseClient, - RpcHandler rpcHandler, - Long maxChunksBeingTransferred, - AbstractSource source){ - this(channel, reverseClient, rpcHandler, maxChunksBeingTransferred); - this.source = source; - } + private final BaseMessageHandler msgHandler; public TransportRequestHandler( Channel channel, TransportClient reverseClient, - RpcHandler rpcHandler, - Long maxChunksBeingTransferred) { + BaseMessageHandler msgHandler) { this.channel = channel; this.reverseClient = reverseClient; - this.rpcHandler = rpcHandler; - this.streamManager = rpcHandler.getStreamManager(); - this.maxChunksBeingTransferred = maxChunksBeingTransferred; + this.msgHandler = msgHandler; } @Override public void exceptionCaught(Throwable cause) { - rpcHandler.exceptionCaught(cause, reverseClient); + msgHandler.exceptionCaught(cause, reverseClient); } @Override public void channelActive() { - rpcHandler.channelActive(reverseClient); + msgHandler.channelActive(reverseClient); } @Override public void channelInactive() { - if (streamManager != null) { - try { - streamManager.connectionTerminated(channel); - } catch (RuntimeException e) { - logger.error("StreamManager connectionTerminated() callback failed.", e); - } - } - rpcHandler.channelInactive(reverseClient); + msgHandler.channelInactive(reverseClient); } @Override public void handle(RequestMessage request) { - if (request instanceof ChunkFetchRequest) { - if (checkRegistered(request)) { - processFetchRequest((ChunkFetchRequest) request); - } - } else if (request instanceof RpcRequest) { - if (checkRegistered(request)) { - processRpcRequest((RpcRequest) request); - } - } else if (request instanceof OneWayMessage) { - if (checkRegistered(request)) { - processOneWayMessage((OneWayMessage) request); - } - } else if (request instanceof StreamRequest) { - processStreamRequest((StreamRequest) request); - } else if (request instanceof UploadStream) { - processStreamUpload((UploadStream) request); - } else if (request instanceof PushData) { - if (checkRegistered(request)) { - processPushData((PushData) request); - } - } else if (request instanceof PushMergedData) { - if (checkRegistered(request)) { - processPushMergedData((PushMergedData) request); - } - } else { - throw new IllegalArgumentException("Unknown request type: " + request); + if (checkRegistered(request)) { + msgHandler.receive(reverseClient, request); } } private boolean checkRegistered(RequestMessage req) { - if (!rpcHandler.checkRegistered()) { + if (!msgHandler.checkRegistered()) { IOException e = new IOException("Worker Not Registered!"); if (req instanceof RpcRequest) { respond(new RpcFailure(((RpcRequest)req).requestId, Throwables.getStackTraceAsString(e))); } else if (req instanceof ChunkFetchRequest) { - respond(new ChunkFetchFailure(((ChunkFetchRequest)req).streamChunkId, + respond(new ChunkFetchFailure(((ChunkFetchRequest)req).streamChunkSlice, Throwables.getStackTraceAsString(e))); } else if (req instanceof OneWayMessage) { logger.warn("Ignore OneWayMessage since worker is not registered!"); @@ -159,232 +96,6 @@ private boolean checkRegistered(RequestMessage req) { return true; } - private void processFetchRequest(final ChunkFetchRequest req) { - if (source != null) { - source.startTimer(NetWorkSource.FetchChunkTime(), req.toString()); - } - if (logger.isTraceEnabled()) { - logger.trace("Received req from {} to fetch block {}", 
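
With handle() reduced to the checkRegistered gate plus a single msgHandler.receive call, all per-type dispatch moves into the concrete BaseMessageHandler (the process* methods deleted below). A minimal subclass, as a sketch; the class name is invented and logging stands in for real handling:

    import com.aliyun.emr.rss.common.network.client.TransportClient;
    import com.aliyun.emr.rss.common.network.protocol.RequestMessage;
    import com.aliyun.emr.rss.common.network.server.BaseMessageHandler;

    // Sketch of a concrete handler; a real one would branch on the message type
    // (RpcRequest, ChunkFetchRequest, PushData, ...) and respond on the channel.
    public class LoggingMessageHandler extends BaseMessageHandler {
      @Override
      public void receive(TransportClient client, RequestMessage msg) {
        System.out.println("Received " + msg.getClass().getSimpleName()
            + " from " + client.getChannel().remoteAddress());
      }

      @Override
      public boolean checkRegistered() {
        return true; // or gate on actual worker registration state
      }
    }
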
NettyUtils.getRemoteAddress(channel), - req.streamChunkId); - } - long chunksBeingTransferred = streamManager.chunksBeingTransferred(); - if (chunksBeingTransferred >= maxChunksBeingTransferred) { - logger.warn("The number of chunks being transferred {} is above {}, close the connection.", - chunksBeingTransferred, maxChunksBeingTransferred); - channel.close(); - if (source != null) { - source.stopTimer(NetWorkSource.FetchChunkTime(), req.toString()); - } - return; - } - ManagedBuffer buf; - try { - streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); - buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); - } catch (Exception e) { - logger.error(String.format("Error opening block %s for request from %s", - req.streamChunkId, NettyUtils.getRemoteAddress(channel)), e); - respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); - if (source != null) { - source.stopTimer(NetWorkSource.FetchChunkTime(), req.toString()); - } - return; - } - - streamManager.chunkBeingSent(req.streamChunkId.streamId); - respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> { - streamManager.chunkSent(req.streamChunkId.streamId); - if (source != null) { - source.stopTimer(NetWorkSource.FetchChunkTime(), req.toString()); - } - }); - } - - private void processStreamRequest(final StreamRequest req) { - if (logger.isTraceEnabled()) { - logger.trace("Received req from {} to fetch stream {}", NettyUtils.getRemoteAddress(channel), - req.streamId); - } - - long chunksBeingTransferred = streamManager.chunksBeingTransferred(); - if (chunksBeingTransferred >= maxChunksBeingTransferred) { - logger.warn("The number of chunks being transferred {} is above {}, close the connection.", - chunksBeingTransferred, maxChunksBeingTransferred); - channel.close(); - return; - } - ManagedBuffer buf; - try { - buf = streamManager.openStream(req.streamId); - } catch (Exception e) { - logger.error(String.format("Error opening stream %s for request from %s", - req.streamId, NettyUtils.getRemoteAddress(channel)), e); - respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e))); - return; - } - - if (buf != null) { - streamManager.streamBeingSent(req.streamId); - respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { - streamManager.streamSent(req.streamId); - }); - } else { - respond(new StreamFailure(req.streamId, String.format( - "Stream '%s' was not found.", req.streamId))); - } - } - - private void processRpcRequest(final RpcRequest req) { - try { - rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); - } - - @Override - public void onFailure(Throwable e) { - respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } - }); - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); - respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } finally { - req.body().release(); - } - } - - /** - * Handle a request from the client to upload a stream of data. 
- */ - private void processStreamUpload(final UploadStream req) { - assert (req.body() == null); - try { - RpcResponseCallback callback = new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); - } - - @Override - public void onFailure(Throwable e) { - respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } - }; - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - ByteBuffer meta = req.meta.nioByteBuffer(); - StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); - if (streamHandler == null) { - throw new NullPointerException("rpcHandler returned a null streamHandler"); - } - StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { - @Override - public void onData(String streamId, ByteBuffer buf) throws IOException { - streamHandler.onData(streamId, buf); - } - - @Override - public void onComplete(String streamId) throws IOException { - try { - streamHandler.onComplete(streamId); - callback.onSuccess(ByteBuffer.allocate(0)); - } catch (Exception ex) { - IOException ioExc = new IOException("Failure post-processing complete stream;" + - " failing this rpc and leaving channel active", ex); - callback.onFailure(ioExc); - streamHandler.onFailure(streamId, ioExc); - } - } - - @Override - public void onFailure(String streamId, Throwable cause) throws IOException { - callback.onFailure(new IOException("Destination failed while reading stream", cause)); - streamHandler.onFailure(streamId, cause); - } - - @Override - public String getID() { - return streamHandler.getID(); - } - }; - if (req.bodyByteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor<>( - this, wrappedCallback.getID(), req.bodyByteCount, wrappedCallback); - frameDecoder.setInterceptor(interceptor); - } else { - wrappedCallback.onComplete(wrappedCallback.getID()); - } - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); - respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - // We choose to totally fail the channel, rather than trying to recover as we do in other - // cases. We don't know how many bytes of the stream the client has already sent for the - // stream, it's not worth trying to recover. - channel.pipeline().fireExceptionCaught(e); - } finally { - req.meta.release(); - } - } - - private void processPushData(PushData req) { - try { - rpcHandler.receivePushData(reverseClient, req, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); - } - - @Override - public void onFailure(Throwable e) { - logger.error("[processPushData] Process pushData onFailure! 
ShuffleKey: " - + req.shuffleKey + ", partitionUniqueId: " + req.partitionUniqueId, e); - respond(new RpcFailure(req.requestId, e.getMessage())); - } - }); - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on PushData " + req, e); - channel.writeAndFlush(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } finally { - req.body().release(); - } - } - - private void processPushMergedData(PushMergedData req) { - try { - rpcHandler.receivePushMergedData(reverseClient, req, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); - } - - @Override - public void onFailure(Throwable e) { - logger.error("[processPushMergedData] Process PushMergedData onFailure! ShuffleKey: " + - req.shuffleKey + - ", partitionUniqueId: " + JavaUtils.mkString(req.partitionUniqueIds, ","), e); - respond(new RpcFailure(req.requestId, e.getMessage())); - } - }); - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on PushMergedData " + req, e); - channel.writeAndFlush(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } finally { - req.body().release(); - } - } - - private void processOneWayMessage(OneWayMessage req) { - try { - rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); - } finally { - req.body().release(); - } - } - /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. @@ -397,7 +108,6 @@ private ChannelFuture respond(Encodable result) { } else { logger.warn(String.format("Fail to sending result %s to %s; closing connection", result, remoteAddress), future.cause()); - channel.close(); } }); } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServer.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServer.java index 0ba8c6b74d3..031055ee24e 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServer.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServer.java @@ -47,27 +47,24 @@ public class TransportServer implements Closeable { private final TransportContext context; private final TransportConf conf; - private final RpcHandler appRpcHandler; + private final BaseMessageHandler appRpcHandler; private final List bootstraps; private ServerBootstrap bootstrap; private ChannelFuture channelFuture; private int port = -1; - private NettyMemoryMetrics nettyMetric; - private AbstractSource source; public TransportServer( TransportContext context, String hostToBind, int portToBind, - RpcHandler appRpcHandler, + BaseMessageHandler appRpcHandler, List bootstraps, AbstractSource source) { this.context = context; this.conf = context.getConf(); this.appRpcHandler = appRpcHandler; this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); - this.source = source; boolean shouldClose = true; try { @@ -88,7 +85,7 @@ public TransportServer( TransportContext context, String hostToBind, int portToBind, - RpcHandler appRpcHandler, + BaseMessageHandler appRpcHandler, List bootstraps) { this(context, hostToBind, portToBind, appRpcHandler, bootstraps, null); } @@ -120,9 +117,6 @@ private void init(String hostToBind, int portToBind) { .childOption(ChannelOption.SO_KEEPALIVE, 
true) .childOption(ChannelOption.ALLOCATOR, allocator); - this.nettyMetric = new NettyMemoryMetrics( - allocator, conf.getModuleName() + "-server", conf, source); - if (conf.backLog() > 0) { bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); } @@ -150,19 +144,15 @@ protected void initializeChannel(ServerBootstrap bootstrap) { bootstrap.childHandler(new ChannelInitializer() { @Override protected void initChannel(SocketChannel ch) { - RpcHandler rpcHandler = appRpcHandler; + BaseMessageHandler handler = appRpcHandler; for (TransportServerBootstrap bootstrap : bootstraps) { - rpcHandler = bootstrap.doBootstrap(ch, rpcHandler); + handler = bootstrap.doBootstrap(ch, handler); } - context.initializePipeline(ch, rpcHandler); + context.initializePipeline(ch, handler); } }); } - public NettyMemoryMetrics getNettyMetric() { - return nettyMetric; - } - @Override public void close() { if (channelFuture != null) { diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServerBootstrap.java b/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServerBootstrap.java index f0aa7ad05b7..fb746a37e76 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServerBootstrap.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/server/TransportServerBootstrap.java @@ -29,8 +29,8 @@ public interface TransportServerBootstrap { * Customizes the channel to include new features, if needed. * * @param channel The connected channel opened by the client. - * @param rpcHandler The RPC handler for the server. + * @param handler The RPC handler for the server. * @return The RPC handler to use for the channel. */ - RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler); + BaseMessageHandler doBootstrap(Channel channel, BaseMessageHandler handler); } diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayReadableChannel.java b/common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayReadableChannel.java deleted file mode 100644 index f96f588b139..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayReadableChannel.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
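
Returning to the TransportServerBootstrap change above: a bootstrap now receives and returns a BaseMessageHandler, so per-channel decoration still works under the new API. A pass-through implementation, as a sketch (the class name is invented):

    import io.netty.channel.Channel;
    import com.aliyun.emr.rss.common.network.server.BaseMessageHandler;
    import com.aliyun.emr.rss.common.network.server.TransportServerBootstrap;

    // Hypothetical no-op bootstrap: note the peer and hand the handler back unchanged.
    // A real bootstrap (e.g. authentication) could wrap or replace the handler here.
    public class PassThroughBootstrap implements TransportServerBootstrap {
      @Override
      public BaseMessageHandler doBootstrap(Channel channel, BaseMessageHandler handler) {
        System.out.println("Bootstrapping channel from " + channel.remoteAddress());
        return handler;
      }
    }
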
- */ - -package com.aliyun.emr.rss.common.network.util; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.ReadableByteChannel; - -import io.netty.buffer.ByteBuf; - -public class ByteArrayReadableChannel implements ReadableByteChannel { - private ByteBuf data; - private boolean closed; - - public int readableBytes() { - return data.readableBytes(); - } - - public void feedData(ByteBuf buf) throws ClosedChannelException { - if (closed) { - throw new ClosedChannelException(); - } - data = buf; - } - - @Override - public int read(ByteBuffer dst) throws IOException { - int totalRead = 0; - while (data.readableBytes() > 0 && dst.remaining() > 0) { - int bytesToRead = Math.min(data.readableBytes(), dst.remaining()); - dst.put(data.readSlice(bytesToRead).nioBuffer()); - totalRead += bytesToRead; - } - - return totalRead; - } - - @Override - public void close() throws IOException { - closed = true; - } - - @Override - public boolean isOpen() { - return !closed; - } - -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayWritableChannel.java b/common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayWritableChannel.java deleted file mode 100644 index 018a403f6a8..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/util/ByteArrayWritableChannel.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.util; - -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; - -/** - * A writable channel that stores the written data in a byte array in memory. - */ -public class ByteArrayWritableChannel implements WritableByteChannel { - - private final byte[] data; - private int offset; - - public ByteArrayWritableChannel(int size) { - this.data = new byte[size]; - } - - public byte[] getData() { - return data; - } - - public int length() { - return offset; - } - - /** Resets the channel so that writing to it will overwrite the existing buffer. */ - public void reset() { - offset = 0; - } - - /** - * Reads from the given buffer into the internal byte array. 
- */ - @Override - public int write(ByteBuffer src) { - int toTransfer = Math.min(src.remaining(), data.length - offset); - src.get(data, offset, toTransfer); - offset += toTransfer; - return toTransfer; - } - - @Override - public void close() { - - } - - @Override - public boolean isOpen() { - return true; - } - -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/util/CryptoUtils.java b/common/src/main/java/com/aliyun/emr/rss/common/network/util/CryptoUtils.java deleted file mode 100644 index 3b325ea656c..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/util/CryptoUtils.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.util; - -import java.util.Map; -import java.util.Properties; - -/** - * Utility methods related to the commons-crypto library. - */ -public class CryptoUtils { - - // The prefix for the configurations passing to Apache Commons Crypto library. - public static final String COMMONS_CRYPTO_CONFIG_PREFIX = "commons.crypto."; - - /** - * Extract the commons-crypto configuration embedded in a list of config values. - * - * @param prefix Prefix in the given configuration that identifies the commons-crypto configs. - * @param conf List of configuration values. - */ - public static Properties toCryptoConf(String prefix, Iterable> conf) { - Properties props = new Properties(); - for (Map.Entry e : conf) { - String key = e.getKey(); - if (key.startsWith(prefix)) { - props.setProperty(COMMONS_CRYPTO_CONFIG_PREFIX + key.substring(prefix.length()), - e.getValue()); - } - } - return props; - } - -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/util/LevelDBProvider.java b/common/src/main/java/com/aliyun/emr/rss/common/network/util/LevelDBProvider.java deleted file mode 100644 index 10be8502f5e..00000000000 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/util/LevelDBProvider.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.util; - -import java.io.File; -import java.io.IOException; -import java.nio.charset.StandardCharsets; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.fusesource.leveldbjni.JniDBFactory; -import org.fusesource.leveldbjni.internal.NativeDB; -import org.iq80.leveldb.DB; -import org.iq80.leveldb.Options; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * LevelDB utility class available in the com.aliyun.emr.rss.common.network package. - */ -public class LevelDBProvider { - private static final Logger logger = LoggerFactory.getLogger(LevelDBProvider.class); - - public static DB initLevelDB(File dbFile, StoreVersion version, ObjectMapper mapper) throws - IOException { - DB tmpDb = null; - if (dbFile != null) { - Options options = new Options(); - options.createIfMissing(false); - options.logger(new LevelDBLogger()); - try { - tmpDb = JniDBFactory.factory.open(dbFile, options); - } catch (NativeDB.DBException e) { - if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { - logger.info("Creating state database at " + dbFile); - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(dbFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - } else { - // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new - // one, so we can keep processing new apps - logger.error("error opening leveldb file {}. Creating new file, will not be able to " + - "recover state for existing applications", dbFile, e); - if (dbFile.isDirectory()) { - for (File f : dbFile.listFiles()) { - if (!f.delete()) { - logger.warn("error deleting {}", f.getPath()); - } - } - } - if (!dbFile.delete()) { - logger.warn("error deleting {}", dbFile.getPath()); - } - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(dbFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - - } - } - // if there is a version mismatch, we throw an exception, which means the service is unusable - checkVersion(tmpDb, version, mapper); - } - return tmpDb; - } - - private static class LevelDBLogger implements org.iq80.leveldb.Logger { - private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); - - @Override - public void log(String message) { - LOG.info(message); - } - } - - /** - * Simple major.minor versioning scheme. Any incompatible changes should be across major - * versions. Minor version differences are allowed -- meaning we should be able to read - * dbs that are either earlier *or* later on the minor version. 
- */ - public static void checkVersion(DB db, StoreVersion newversion, ObjectMapper mapper) throws - IOException { - byte[] bytes = db.get(StoreVersion.KEY); - if (bytes == null) { - storeVersion(db, newversion, mapper); - } else { - StoreVersion version = mapper.readValue(bytes, StoreVersion.class); - if (version.major != newversion.major) { - throw new IOException("cannot read state DB with version " + version + ", incompatible " + - "with current version " + newversion); - } - storeVersion(db, newversion, mapper); - } - } - - public static void storeVersion(DB db, StoreVersion version, ObjectMapper mapper) - throws IOException { - db.put(StoreVersion.KEY, mapper.writeValueAsBytes(version)); - } - - public static class StoreVersion { - - static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8); - - public final int major; - public final int minor; - - @JsonCreator - public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { - this.major = major; - this.minor = minor; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - StoreVersion that = (StoreVersion) o; - - return major == that.major && minor == that.minor; - } - - @Override - public int hashCode() { - int result = major; - result = 31 * result + minor; - return result; - } - } -} diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/util/TransportConf.java b/common/src/main/java/com/aliyun/emr/rss/common/network/util/TransportConf.java index 7e7adf64eb6..d049faede9e 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/network/util/TransportConf.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/network/util/TransportConf.java @@ -18,7 +18,6 @@ package com.aliyun.emr.rss.common.network.util; import java.util.Locale; -import java.util.Properties; import com.google.common.primitives.Ints; @@ -183,105 +182,6 @@ public boolean verboseMetrics() { return conf.getBoolean(RSS_NETWORK_VERBOSE_METRICS, false); } - /** - * Maximum number of retries when binding to a port before giving up. - */ - public int portMaxRetries() { - return conf.getInt("rss.port.maxRetries", 16); - } - - /** - * Enables strong encryption. Also enables the new auth protocol, used to negotiate keys. - */ - public boolean encryptionEnabled() { - return conf.getBoolean("rss.network.crypto.enabled", false); - } - - /** - * The cipher transformation to use for encrypting session data. - */ - public String cipherTransformation() { - return conf.get("rss.network.crypto.cipher", "AES/CTR/NoPadding"); - } - - /** - * The key generation algorithm. This should be an algorithm that accepts a "PBEKeySpec" - * as input. The default value (PBKDF2WithHmacSHA1) is available in Java 7. - */ - public String keyFactoryAlgorithm() { - return conf.get("rss.network.crypto.keyFactoryAlgorithm", "PBKDF2WithHmacSHA1"); - } - - /** - * How many iterations to run when generating keys. - * - * See some discussion about this at: http://security.stackexchange.com/q/3959 - * The default value was picked for speed, since it assumes that the secret has good entropy - * (128 bits by default), which is not generally the case with user passwords. - */ - public int keyFactoryIterations() { - return conf.getInt("rss.network.crypto.keyFactoryIterations", 1024); - } - - /** - * Encryption key length, in bits. 
- */ - public int encryptionKeyLength() { - return conf.getInt("rss.network.crypto.keyLength", 128); - } - - /** - * Initial vector length, in bytes. - */ - public int ivLength() { - return conf.getInt("rss.network.crypto.ivLength", 16); - } - - /** - * The algorithm for generated secret keys. Nobody should really need to change this, - * but configurable just in case. - */ - public String keyAlgorithm() { - return conf.get("rss.network.crypto.keyAlgorithm", "AES"); - } - - /** - * Whether to fall back to SASL if the new auth protocol fails. Enabled by default for - * backwards compatibility. - */ - public boolean saslFallback() { - return conf.getBoolean("rss.network.crypto.saslFallback", true); - } - - /** - * Whether to enable SASL-based encryption when authenticating using SASL. - */ - public boolean saslEncryption() { - return conf.getBoolean("rss.authenticate.enableSaslEncryption", false); - } - - /** - * Maximum number of bytes to be encrypted at a time when SASL encryption is used. - */ - public int maxSaslEncryptedBlockSize() { - return Ints.checkedCast(JavaUtils.byteStringAsBytes( - conf.get("rss.network.sasl.maxEncryptedBlockSize", "64k"))); - } - - /** - * Whether the server should enforce encryption on SASL-authenticated connections. - */ - public boolean saslServerAlwaysEncrypt() { - return conf.getBoolean("rss.network.sasl.serverAlwaysEncrypt", false); - } - - /** - * The commons-crypto configuration for the module. - */ - public Properties cryptoConf() { - return CryptoUtils.toCryptoConf("rss.network.crypto.config.", conf.getAll()); - } - /** * The max number of chunks allowed to be transferred at the same time on shuffle service. * Note that new incoming connections will be closed when the max number is hit. The client will diff --git a/common/src/main/java/com/aliyun/emr/rss/common/protocol/PartitionLocation.java b/common/src/main/java/com/aliyun/emr/rss/common/protocol/PartitionLocation.java index 62a976a0f84..e7a4570cfeb 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/protocol/PartitionLocation.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/protocol/PartitionLocation.java @@ -39,6 +39,10 @@ public enum StorageHint { NON_EXISTS, MEMORY, HDD, SDD, HDFS, OSS } + public enum Type { + REDUCE_PARTITION, MAP_PARTITION, MAPGROUP_REDUCE_PARTITION + } + public static PartitionLocation.Mode getMode(byte mode) { if (mode == 0) { return Mode.Master; diff --git a/common/src/main/java/com/aliyun/emr/rss/common/protocol/message/StatusCode.java b/common/src/main/java/com/aliyun/emr/rss/common/protocol/message/StatusCode.java index 2c2b575a448..fcdebee6b14 100644 --- a/common/src/main/java/com/aliyun/emr/rss/common/protocol/message/StatusCode.java +++ b/common/src/main/java/com/aliyun/emr/rss/common/protocol/message/StatusCode.java @@ -46,7 +46,10 @@ public enum StatusCode { PushDataFailPartitionNotFound(20), HardSplit(21), - SoftSplit(22); + SoftSplit(22), + + StageEndTimeOut(23), + ShuffleDataLost(24); private final byte value; diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala b/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala index d1f1adb4b11..a3cb0f2123c 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala @@ -632,6 +632,10 @@ object RssConf extends Logging { conf.getInt("rss.worker.prometheus.metric.port", 9096) } + def workerRPCPort(conf: RssConf): Int = { + conf.getInt("rss.worker.rpc.port", 0) + } + def 
clusterLoadFallbackEnabled(conf: RssConf): Boolean = { conf.getBoolean("rss.clusterLoad.fallback.enabled", defaultValue = true) } diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/protocol/message/ControlMessages.scala b/common/src/main/scala/com/aliyun/emr/rss/common/protocol/message/ControlMessages.scala index 015d40bb6ac..a64f051554b 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/protocol/message/ControlMessages.scala @@ -88,9 +88,9 @@ sealed trait Message extends Serializable{ case RegisterShuffleResponse(status, partitionLocations) => val builder = TransportMessages.PbRegisterShuffleResponse.newBuilder() .setStatus(status.getValue) - if (partitionLocations != null) { + if (!partitionLocations.isEmpty) { builder.addAllPartitionLocations(partitionLocations.iterator().asScala - .map(PartitionLocation.toPbPartitionLocation(_)).toList.asJava) + .map(PartitionLocation.toPbPartitionLocation).toList.asJava) } val payload = builder.build().toByteArray new TransportMessage(TransportMessages.MessageType.REGISTER_SHUFFLE_RESPONSE, payload) @@ -125,7 +125,7 @@ sealed trait Message extends Serializable{ case RequestSlotsResponse(status, workerResource) => val builder = TransportMessages.PbRequestSlotsResponse.newBuilder() .setStatus(status.getValue) - if (workerResource != null) { + if (!workerResource.isEmpty) { builder.putAllWorkerResource( Utils.convertWorkerResourceToPbWorkerResource(workerResource)) } @@ -182,14 +182,12 @@ sealed trait Message extends Serializable{ case GetReducerFileGroupResponse(status, fileGroup, attempts) => val builder = TransportMessages.PbGetReducerFileGroupResponse.newBuilder() .setStatus(status.getValue) - if (fileGroup != null) { - builder.addAllFileGroup(fileGroup.map(arr => PbFileGroup.newBuilder() - .addAllLocaltions(arr.map(PartitionLocation.toPbPartitionLocation(_)).toIterable.asJava) - .build()).toIterable.asJava) - } - if (attempts != null) { - builder.addAllAttempts(attempts.map(new Integer(_)).toIterable.asJava) - } + builder.addAllFileGroup(fileGroup.map { arr => + PbFileGroup.newBuilder() + .addAllLocaltions(arr.map(PartitionLocation.toPbPartitionLocation).toIterable.asJava) + .build() + }.toIterable.asJava) + builder.addAllAttempts(attempts.map(new Integer(_)).toIterable.asJava) val payload = builder.build().toByteArray new TransportMessage(TransportMessages.MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload) @@ -255,7 +253,7 @@ sealed trait Message extends Serializable{ case GetBlacklist(localBlacklist) => val payload = TransportMessages.PbGetBlacklist.newBuilder() - .addAllLocalBlackList(localBlacklist.asScala.map(WorkerInfo.toPbWorkerInfo(_)) + .addAllLocalBlackList(localBlacklist.asScala.map(WorkerInfo.toPbWorkerInfo) .toList.asJava) .build().toByteArray new TransportMessage(TransportMessages.MessageType.GET_BLACKLIST, payload) @@ -263,13 +261,9 @@ sealed trait Message extends Serializable{ case GetBlacklistResponse(statusCode, blacklist, unknownWorkers) => val builder = TransportMessages.PbGetBlacklistResponse.newBuilder() .setStatus(statusCode.getValue) - if (blacklist != null) { - builder.addAllBlacklist(blacklist.asScala.map(WorkerInfo.toPbWorkerInfo(_)).toList.asJava) - } - if (unknownWorkers != null) { - builder.addAllUnknownWorkers(unknownWorkers.asScala - .map(WorkerInfo.toPbWorkerInfo(_)).toList.asJava) - } + builder.addAllBlacklist(blacklist.asScala.map(WorkerInfo.toPbWorkerInfo).toList.asJava) + 
builder.addAllUnknownWorkers( + unknownWorkers.asScala.map(WorkerInfo.toPbWorkerInfo).toList.asJava) val payload = builder.build().toByteArray new TransportMessage(TransportMessages.MessageType.GET_BLACKLIST_RESPONSE, payload) @@ -338,18 +332,10 @@ sealed trait Message extends Serializable{ failedMasterIds, failedSlaveIds) => val builder = TransportMessages.PbCommitFilesResponse.newBuilder() .setStatus(status.getValue) - if (committedMasterIds != null) { - builder.addAllCommittedMasterIds(committedMasterIds) - } - if (committedSlaveIds != null) { - builder.addAllCommittedSlaveIds(committedSlaveIds) - } - if (failedMasterIds != null) { - builder.addAllFailedMasterIds(failedMasterIds) - } - if (failedSlaveIds != null) { - builder.addAllFailedSlaveIds(failedSlaveIds) - } + builder.addAllCommittedMasterIds(committedMasterIds) + builder.addAllCommittedSlaveIds(committedSlaveIds) + builder.addAllFailedMasterIds(failedMasterIds) + builder.addAllFailedSlaveIds(failedSlaveIds) val payload = builder.build().toByteArray new TransportMessage(TransportMessages.MessageType.COMMIT_FILES_RESPONSE, payload) @@ -364,12 +350,8 @@ sealed trait Message extends Serializable{ case DestroyResponse(status, failedMasters, failedSlaves) => val builder = TransportMessages.PbDestroyResponse.newBuilder() .setStatus(status.getValue) - if (failedMasters != null) { - builder.addAllFailedMasters(failedMasters) - } - if (failedSlaves != null) { - builder.addAllFailedSlaves(failedSlaves) - } + builder.addAllFailedMasters(failedMasters) + builder.addAllFailedSlaves(failedSlaves) val payload = builder.build().toByteArray new TransportMessage(TransportMessages.MessageType.DESTROY_RESPONSE, payload) @@ -695,7 +677,7 @@ object ControlMessages extends Logging{ val partitionLocations = new util.ArrayList[PartitionLocation]() if (pbRegisterShuffleResponse.getPartitionLocationsCount > 0) { partitionLocations.addAll(pbRegisterShuffleResponse.getPartitionLocationsList - .asScala.map(PartitionLocation.fromPbPartitionLocation(_)).toList.asJava) + .asScala.map(PartitionLocation.fromPbPartitionLocation).toList.asJava) } RegisterShuffleResponse(Utils.toStatusCode(pbRegisterShuffleResponse.getStatus), partitionLocations) @@ -719,12 +701,9 @@ object ControlMessages extends Logging{ case REQUEST_SLOTS_RESPONSE => val pbRequestSlotsResponse = PbRequestSlotsResponse.parseFrom(message.getPayload) - val workerResource = if (pbRequestSlotsResponse.getWorkerResourceCount > 0) { - Utils.convertPbWorkerResourceToWorkerResource(pbRequestSlotsResponse.getWorkerResourceMap) - } else { - null - } - RequestSlotsResponse(Utils.toStatusCode(pbRequestSlotsResponse.getStatus), workerResource) + RequestSlotsResponse(Utils.toStatusCode(pbRequestSlotsResponse.getStatus), + Utils.convertPbWorkerResourceToWorkerResource( + pbRequestSlotsResponse.getWorkerResourceMap)) case REVIVE => val pbRevive = PbRevive.parseFrom(message.getPayload) @@ -761,14 +740,10 @@ object ControlMessages extends Logging{ case GET_REDUCER_FILE_GROUP_RESPONSE => val pbGetReducerFileGroupResponse = PbGetReducerFileGroupResponse .parseFrom(message.getPayload) - val fileGroup = if (pbGetReducerFileGroupResponse.getFileGroupCount > 0) { - pbGetReducerFileGroupResponse.getFileGroupList.asScala - .map(fg => fg.getLocaltionsList.asScala - .map(PartitionLocation.fromPbPartitionLocation(_)).toArray).toArray - } else null - val attempts = if (pbGetReducerFileGroupResponse.getAttemptsCount > 0) { - pbGetReducerFileGroupResponse.getAttemptsList().asScala.map(Int.unbox(_)).toArray - } else null + val 
fileGroup = pbGetReducerFileGroupResponse.getFileGroupList.asScala.map { fg => + fg.getLocaltionsList.asScala.map(PartitionLocation.fromPbPartitionLocation).toArray + }.toArray + val attempts = pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray GetReducerFileGroupResponse(Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus), fileGroup, attempts) @@ -797,20 +772,15 @@ object ControlMessages extends Logging{ case GET_BLACKLIST => val pbGetBlacklist = PbGetBlacklist.parseFrom(message.getPayload) GetBlacklist(new util.ArrayList[WorkerInfo](pbGetBlacklist.getLocalBlackListList.asScala - .map(WorkerInfo.fromPbWorkerInfo(_)).toList.asJava)) + .map(WorkerInfo.fromPbWorkerInfo).toList.asJava)) case GET_BLACKLIST_RESPONSE => val pbGetBlacklistResponse = PbGetBlacklistResponse.parseFrom(message.getPayload) - val blacklist = if (pbGetBlacklistResponse.getBlacklistCount > 0) { + GetBlacklistResponse(Utils.toStatusCode(pbGetBlacklistResponse.getStatus), pbGetBlacklistResponse.getBlacklistList.asScala - .map(WorkerInfo.fromPbWorkerInfo(_)).toList.asJava - } else null - val unkownList = if (pbGetBlacklistResponse.getUnknownWorkersCount > 0) { + .map(WorkerInfo.fromPbWorkerInfo).toList.asJava, pbGetBlacklistResponse.getUnknownWorkersList.asScala - .map(WorkerInfo.fromPbWorkerInfo(_)).toList.asJava - } else null - GetBlacklistResponse(Utils.toStatusCode(pbGetBlacklistResponse.getStatus), - blacklist, unkownList) + .map(WorkerInfo.fromPbWorkerInfo).toList.asJava) case GET_CLUSTER_LOAD_STATUS => val pbGetClusterLoadStats = PbGetClusterLoadStatus.parseFrom(message.getPayload) @@ -856,7 +826,7 @@ object ControlMessages extends Logging{ val pbCommitFiles = PbCommitFiles.parseFrom(message.getPayload) CommitFiles(pbCommitFiles.getApplicationId, pbCommitFiles.getShuffleId, pbCommitFiles.getMasterIdsList, pbCommitFiles.getSlaveIdsList, - pbCommitFiles.getMapAttemptsList.asScala.map(Int.unbox(_)).toArray) + pbCommitFiles.getMapAttemptsList.asScala.map(_.toInt).toArray) case COMMIT_FILES_RESPONSE => val pbCommitFilesResponse = PbCommitFilesResponse.parseFrom(message.getPayload) diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEndpointRef.scala b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEndpointRef.scala index 2ebcf21d1a5..6a10dc3e6e9 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEndpointRef.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEndpointRef.scala @@ -30,8 +30,6 @@ import com.aliyun.emr.rss.common.util.RpcUtils abstract class RpcEndpointRef(conf: RssConf) extends Serializable with Logging { - private[this] val maxRetries = RpcUtils.numRetries(conf) - private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) /** diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEnv.scala b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEnv.scala index 831530fac1e..3bf0a770bf7 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEnv.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/RpcEnv.scala @@ -124,21 +124,6 @@ abstract class RpcEnv(conf: RssConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T - - /** - * Return the instance of the file server used to serve files. This may be `null` if the - * RpcEnv is not operating in server mode. 
- */ - def fileServer: RpcEnvFileServer - - /** - * Open a channel to download a file from the given URI. If the URIs returned by the - * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to - * retrieve the files. - * - * @param uri URI with location of the file. - */ - def openChannel(uri: String): ReadableByteChannel } /** diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Dispatcher.scala b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Dispatcher.scala index da9269f8e77..f0811aec680 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Dispatcher.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Dispatcher.scala @@ -196,7 +196,7 @@ private[rss] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extend /** Thread pool used for dispatching messages. */ private val threadpool: ThreadPoolExecutor = { val availableCores = - if (numUsableCores > 0) numUsableCores else Math.max(4, + if (numUsableCores > 0) numUsableCores else Math.max(16, Runtime.getRuntime.availableProcessors()) val numThreads = nettyEnv.conf.getInt("rss.rpc.dispatcher.numThreads", availableCores) val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyRpcEnv.scala b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyRpcEnv.scala index 876ed98abb2..2636dd5288c 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyRpcEnv.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyRpcEnv.scala @@ -18,25 +18,27 @@ package com.aliyun.emr.rss.common.rpc.netty import java.io._ -import java.net.{InetSocketAddress, URI} +import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success, Try} +import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal +import com.google.common.base.Throwables + import com.aliyun.emr.rss.common.RssConf import com.aliyun.emr.rss.common.internal.Logging import com.aliyun.emr.rss.common.network.TransportContext +import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer import com.aliyun.emr.rss.common.network.client._ +import com.aliyun.emr.rss.common.network.protocol.{OneWayMessage => NOneWayMessage, RequestMessage => NRequestMessage, RpcFailure => NRpcFailure, RpcRequest, RpcResponse} import com.aliyun.emr.rss.common.network.server._ import com.aliyun.emr.rss.common.protocol.{RpcNameConstants, TransportModuleConstants} -import com.aliyun.emr.rss.common.protocol.message.Message import com.aliyun.emr.rss.common.rpc._ import com.aliyun.emr.rss.common.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream} import com.aliyun.emr.rss.common.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils} @@ -56,10 +58,9 @@ class NettyRpcEnv( private var worker: RpcEndpoint = null - private val streamManager = new NettyStreamManager(this) private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this, streamManager)) + new NettyRpcHandler(dispatcher, this)) private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { 
java.util.Collections.emptyList[TransportClientBootstrap] @@ -67,16 +68,6 @@ class NettyRpcEnv( val clientFactory = transportContext.createClientFactory(createClientBootstraps()) - /** - * A separate client factory for file downloads. This avoids using the same RPC handler as - * the main RPC context, so that events caused by these clients are kept isolated from the - * main RPC traffic. - * - * It also allows for different configuration of certain properties, such as the number of - * connections per peer. - */ - @volatile private var fileDownloadFactory: TransportClientFactory = _ - private val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") @@ -154,7 +145,7 @@ class NettyRpcEnv( } private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { - if (receiver.client != null) { + if (receiver.client != null && receiver.client.isActive) { message.sendWith(receiver.client) } else { require(receiver.address != null, @@ -315,9 +306,6 @@ class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } - if (fileDownloadFactory != null) { - fileDownloadFactory.close() - } } override def deserialize[T](deserializationAction: () => T): T = { @@ -325,111 +313,6 @@ class NettyRpcEnv( deserializationAction() } } - - override def fileServer: RpcEnvFileServer = streamManager - - override def openChannel(uri: String): ReadableByteChannel = { - val parsedUri = new URI(uri) - require(parsedUri.getHost() != null, "Host name must be defined.") - require(parsedUri.getPort() > 0, "Port must be defined.") - require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") - - val pipe = Pipe.open() - val source = new FileDownloadChannel(pipe.source()) - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) - val callback = new FileDownloadCallback(pipe.sink(), source, client) - client.stream(parsedUri.getPath(), callback) - })(catchBlock = { - pipe.sink().close() - source.close() - }) - - source - } - - private def downloadClient(host: String, port: Int): TransportClient = { - if (fileDownloadFactory == null) synchronized { - if (fileDownloadFactory == null) { - val module = TransportModuleConstants.FILE_MODULE - val prefix = "spark.rpc.io." - val clone = conf.clone() - - // Copy any RPC configuration that is not overridden in the spark.files namespace. - conf.getAll.foreach { case (key, value) => - if (key.startsWith(prefix)) { - val opt = key.substring(prefix.length()) - clone.setIfMissing(s"spark.$module.io.$opt", value) - } - } - - val ioThreads = clone.getInt("spark.files.io.threads", 1) - val downloadConf = Utils.fromRssConf(clone, module, ioThreads) - val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) - fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) - } - } - fileDownloadFactory.createClient(host, port) - } - - private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel { - - @volatile private var error: Throwable = _ - - def setError(e: Throwable): Unit = { - // This setError callback is invoked by internal RPC threads in order to propagate remote - // exceptions to application-level threads which are reading from this channel. When an - // RPC error occurs, the RPC system will call setError() and then will close the - // Pipe.SinkChannel corresponding to the other end of the `source` pipe. 
Closing of the pipe - // sink will cause `source.read()` operations to return EOF, unblocking the application-level - // reading thread. Thus there is no need to actually call `source.close()` here in the - // onError() callback and, in fact, calling it here would be dangerous because the close() - // would be asynchronous with respect to the read() call and could trigger race-conditions - // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic. - error = e - } - - override def read(dst: ByteBuffer): Int = { - Try(source.read(dst)) match { - // See the documentation above in setError(): if an RPC error has occurred then setError() - // will be called to propagate the RPC error and then `source`'s corresponding - // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate - // the remote RPC exception (and not any exceptions triggered by the pipe close, such as - // ChannelClosedException), hence this `error != null` check: - case _ if error != null => throw error - case Success(bytesRead) => bytesRead - case Failure(readErr) => throw readErr - } - } - - override def close(): Unit = source.close() - - override def isOpen(): Boolean = source.isOpen() - - } - - private class FileDownloadCallback( - sink: WritableByteChannel, - source: FileDownloadChannel, - client: TransportClient) extends StreamCallback { - - override def onData(streamId: String, buf: ByteBuffer): Unit = { - while (buf.remaining() > 0) { - sink.write(buf) - } - } - - override def onComplete(streamId: String): Unit = { - sink.close() - } - - override def onFailure(streamId: String, cause: Throwable): Unit = { - logDebug(s"Error downloading stream $streamId.", cause) - source.setError(cause) - sink.close() - } - - } } private[rss] object NettyRpcEnv extends Logging { @@ -637,25 +520,59 @@ private[rss] case class RpcFailure(e: Throwable) */ private[rss] class NettyRpcHandler( dispatcher: Dispatcher, - nettyEnv: NettyRpcEnv, - streamManager: StreamManager) extends RpcHandler with Logging { + nettyEnv: NettyRpcEnv) extends BaseMessageHandler with Logging { // A variable to track the remote RpcEnv addresses of all clients private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() override def receive( client: TransportClient, - message: ByteBuffer, - callback: RpcResponseCallback): Unit = { - val messageToDispatch = internalReceive(client, message) - dispatcher.postRemoteMessage(messageToDispatch, callback) + requestMessage: NRequestMessage): Unit = { + requestMessage match { + case r: RpcRequest => + processRpc(client, r) + case r: NOneWayMessage => + processOnewayMessage(client, r) + } } - override def receive( - client: TransportClient, - message: ByteBuffer): Unit = { - val messageToDispatch = internalReceive(client, message) - dispatcher.postOneWayMessage(messageToDispatch) + private def processRpc(client: TransportClient, r: RpcRequest): Unit = { + val callback = new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + client.getChannel.writeAndFlush(new RpcResponse(r.requestId, + new NioManagedBuffer(response))) + } + + override def onFailure(e: Throwable): Unit = { + client.getChannel.writeAndFlush(new NRpcFailure(r.requestId, + Throwables.getStackTraceAsString(e))) + } + } + try { + val message = r.body().nioByteBuffer() + val messageToDispatch = internalReceive(client, message) + dispatcher.postRemoteMessage(messageToDispatch, callback) + } catch { + case e: Exception => + logError("Error while invoking 
RpcHandler#receive() on RPC id " + r.requestId, e) + client.getChannel.writeAndFlush(new NRpcFailure(r.requestId, + Throwables.getStackTraceAsString(e))) + } finally { + r.body().release() + } + } + + private def processOnewayMessage(client: TransportClient, r: NOneWayMessage): Unit = { + try { + val message = r.body().nioByteBuffer() + val messageToDispatch = internalReceive(client, message) + dispatcher.postOneWayMessage(messageToDispatch) + } catch { + case e: Exception => + logError("Error while invoking RpcHandler#receive() for one-way message.", e) + } finally { + r.body().release() + } } private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { @@ -681,8 +598,6 @@ private[rss] class NettyRpcHandler( nettyEnv.checkRegistered() } - override def getStreamManager: StreamManager = streamManager - override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyStreamManager.scala b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyStreamManager.scala index 3c2c8e5cf3a..2a2013ee050 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyStreamManager.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/NettyStreamManager.scala @@ -32,7 +32,7 @@ private[rss] class NettyStreamManager(rpcEnv: NettyRpcEnv) private val jars = new ConcurrentHashMap[String, File]() private val dirs = new ConcurrentHashMap[String, File]() - override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + override def getChunk(streamId: Long, chunkIndex: Int, offset: Int, len: Int): ManagedBuffer = { throw new UnsupportedOperationException() } diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Outbox.scala b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Outbox.scala index 9f62c824e49..f3c0a569b4d 100644 --- a/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Outbox.scala +++ b/common/src/main/scala/com/aliyun/emr/rss/common/rpc/netty/Outbox.scala @@ -145,7 +145,7 @@ private[rss] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // We are connecting to the remote address, so just exit return } - if (client == null) { + if (client == null || !client.isActive) { // There is no connect task but client is null, so we need to launch the connect task. launchConnectTask() return diff --git a/common/src/main/scala/com/aliyun/emr/rss/common/util/MemoryParam.scala b/common/src/main/scala/com/aliyun/emr/rss/common/util/MemoryParam.scala deleted file mode 100644 index aaeacf50fc5..00000000000 --- a/common/src/main/scala/com/aliyun/emr/rss/common/util/MemoryParam.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.util - -/** - * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing - * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. - */ -private[rss] object MemoryParam { - def unapply(str: String): Option[Long] = { - try { - Some(Utils.byteStringAsBytes(str)) - } catch { - case e: NumberFormatException => None - } - } -} diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java index 1dec024601e..24994a3a046 100644 --- a/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java +++ b/common/src/test/java/com/aliyun/emr/rss/common/network/ChunkFetchIntegrationSuiteJ.java @@ -35,10 +35,13 @@ import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer; import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback; -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; import com.aliyun.emr.rss.common.network.client.TransportClient; import com.aliyun.emr.rss.common.network.client.TransportClientFactory; -import com.aliyun.emr.rss.common.network.server.RpcHandler; +import com.aliyun.emr.rss.common.network.protocol.ChunkFetchRequest; +import com.aliyun.emr.rss.common.network.protocol.ChunkFetchSuccess; +import com.aliyun.emr.rss.common.network.protocol.RequestMessage; +import com.aliyun.emr.rss.common.network.protocol.StreamChunkSlice; +import com.aliyun.emr.rss.common.network.server.BaseMessageHandler; import com.aliyun.emr.rss.common.network.server.StreamManager; import com.aliyun.emr.rss.common.network.server.TransportServer; import com.aliyun.emr.rss.common.network.util.MapConfigProvider; @@ -85,7 +88,7 @@ public static void setUp() throws Exception { streamManager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { + public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) { assertEquals(STREAM_ID, streamId); if (chunkIndex == BUFFER_CHUNK_INDEX) { return new NioManagedBuffer(buf); @@ -96,18 +99,20 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { } } }; - RpcHandler handler = new RpcHandler() { + BaseMessageHandler handler = new BaseMessageHandler() { @Override public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); + TransportClient client, + RequestMessage msg) { + StreamChunkSlice slice = ((ChunkFetchRequest) msg).streamChunkSlice; + ManagedBuffer buf = streamManager.getChunk(slice.streamId, slice.chunkIndex, + slice.offset, slice.len); + client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, buf)); } @Override - public StreamManager getStreamManager() { - return streamManager; + public boolean checkRegistered() { + return true; } }; TransportContext context = new TransportContext(conf, handler); diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/ProtocolSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/ProtocolSuiteJ.java deleted file mode 100644 index db5fcca42bc..00000000000 --- a/common/src/test/java/com/aliyun/emr/rss/common/network/ProtocolSuiteJ.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network; - -import java.util.List; - -import com.google.common.primitives.Ints; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.FileRegion; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.MessageToMessageEncoder; -import org.junit.Test; -import static org.junit.Assert.assertEquals; - -import com.aliyun.emr.rss.common.network.protocol.*; -import com.aliyun.emr.rss.common.network.util.ByteArrayWritableChannel; -import com.aliyun.emr.rss.common.network.util.NettyUtils; - -public class ProtocolSuiteJ { - private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), - MessageEncoder.INSTANCE); - serverChannel.writeOutbound(msg); - - EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); - - while (!serverChannel.outboundMessages().isEmpty()) { - clientChannel.writeOneInbound(serverChannel.readOutbound()); - } - - assertEquals(1, clientChannel.inboundMessages().size()); - assertEquals(msg, clientChannel.readInbound()); - } - - private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), - MessageEncoder.INSTANCE); - clientChannel.writeOutbound(msg); - - EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); - - while (!clientChannel.outboundMessages().isEmpty()) { - serverChannel.writeOneInbound(clientChannel.readOutbound()); - } - - assertEquals(1, serverChannel.inboundMessages().size()); - assertEquals(msg, serverChannel.readInbound()); - } - - @Test - public void requests() { - testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); - testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); - testClientToServer(new StreamRequest("abcde")); - testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); - } - - @Test - public void responses() { - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); - testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); - testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); - testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); - testServerToClient(new RpcFailure(0, "this is an error")); - testServerToClient(new RpcFailure(0, "")); - // Note: 
buffer size must be "0" since StreamResponse's buffer is written differently to the
-    // channel and cannot be tested like this.
-    testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
-    testServerToClient(new StreamFailure("anId", "this is an error"));
-  }
-
-  /**
-   * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer
-   * bytes, but messages, so this is needed so that the frame decoder on the receiving side can
-   * understand what MessageWithHeader actually contains.
-   */
-  private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
-
-    @Override
-    public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
-      throws Exception {
-
-      ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
-      while (in.transferred() < in.count()) {
-        in.transferTo(channel, in.transferred());
-      }
-      out.add(Unpooled.wrappedBuffer(channel.getData()));
-    }
-  }
-}
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java
index 4c87ad8231a..3e51deb42d4 100644
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java
+++ b/common/src/test/java/com/aliyun/emr/rss/common/network/RequestTimeoutIntegrationSuiteJ.java
@@ -37,7 +37,8 @@
 import com.aliyun.emr.rss.common.network.client.RpcResponseCallback;
 import com.aliyun.emr.rss.common.network.client.TransportClient;
 import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
-import com.aliyun.emr.rss.common.network.server.RpcHandler;
+import com.aliyun.emr.rss.common.network.protocol.*;
+import com.aliyun.emr.rss.common.network.server.BaseMessageHandler;
 import com.aliyun.emr.rss.common.network.server.StreamManager;
 import com.aliyun.emr.rss.common.network.server.TransportServer;
 import com.aliyun.emr.rss.common.network.util.MapConfigProvider;
@@ -64,12 +65,12 @@ public class RequestTimeoutIntegrationSuiteJ {
   @Before
   public void setUp() throws Exception {
     Map<String, String> configMap = new HashMap<>();
-    configMap.put("rss.shuffle.io.connectionTimeout", "10s");
+    configMap.put("rss.shuffle.io.connectionTimeout", "2s");
     conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
 
     defaultManager = new StreamManager() {
       @Override
-      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+      public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) {
         throw new UnsupportedOperationException();
       }
     };
@@ -90,23 +91,24 @@ public void tearDown() {
   public void timeoutInactiveRequests() throws Exception {
     final Semaphore semaphore = new Semaphore(1);
     final int responseSize = 16;
-    RpcHandler handler = new RpcHandler() {
+    BaseMessageHandler handler = new BaseMessageHandler() {
       @Override
       public void receive(
           TransportClient client,
-          ByteBuffer message,
-          RpcResponseCallback callback) {
+          RequestMessage message) {
         try {
           semaphore.acquire();
-          callback.onSuccess(ByteBuffer.allocate(responseSize));
+          client.getChannel().writeAndFlush(new RpcResponse(
+            ((RpcRequest) message).requestId,
+            new NioManagedBuffer(ByteBuffer.allocate(responseSize))));
         } catch (InterruptedException e) {
           // do nothing
         }
       }
 
       @Override
-      public StreamManager getStreamManager() {
-        return defaultManager;
+      public boolean checkRegistered() {
+        return true;
       }
     };
 
@@ -137,23 +139,24 @@ public StreamManager getStreamManager() {
   public void timeoutCleanlyClosesClient() 
throws Exception { final Semaphore semaphore = new Semaphore(0); final int responseSize = 16; - RpcHandler handler = new RpcHandler() { + BaseMessageHandler handler = new BaseMessageHandler() { @Override public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { + TransportClient client, + RequestMessage message) { try { semaphore.acquire(); - callback.onSuccess(ByteBuffer.allocate(responseSize)); + client.getChannel().writeAndFlush(new RpcResponse( + ((RpcRequest) message).requestId, + new NioManagedBuffer(ByteBuffer.allocate(responseSize)))); } catch (InterruptedException e) { // do nothing } } @Override - public StreamManager getStreamManager() { - return defaultManager; + public boolean checkRegistered() { + return true; } }; @@ -188,23 +191,25 @@ public void furtherRequestsDelay() throws Exception { final byte[] response = new byte[16]; final StreamManager manager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { + public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) { Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); return new NioManagedBuffer(ByteBuffer.wrap(response)); } }; - RpcHandler handler = new RpcHandler() { + BaseMessageHandler handler = new BaseMessageHandler() { @Override public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); + TransportClient client, + RequestMessage msg) { + StreamChunkSlice slice = ((ChunkFetchRequest) msg).streamChunkSlice; + ManagedBuffer buf = manager.getChunk(slice.streamId, slice.chunkIndex, + slice.offset, slice.len); + client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, buf)); } @Override - public StreamManager getStreamManager() { - return manager; + public boolean checkRegistered() { + return true; } }; diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/RpcIntegrationSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/RpcIntegrationSuiteJ.java index 3b9650a4f0b..1be178119c3 100644 --- a/common/src/test/java/com/aliyun/emr/rss/common/network/RpcIntegrationSuiteJ.java +++ b/common/src/test/java/com/aliyun/emr/rss/common/network/RpcIntegrationSuiteJ.java @@ -17,15 +17,13 @@ package com.aliyun.emr.rss.common.network; -import java.io.*; import java.nio.ByteBuffer; import java.util.*; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import com.google.common.base.Throwables; import com.google.common.collect.Sets; -import com.google.common.io.Files; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.AfterClass; @@ -33,15 +31,12 @@ import org.junit.Test; import static org.junit.Assert.*; -import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer; import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; -import com.aliyun.emr.rss.common.network.client.StreamCallbackWithID; import com.aliyun.emr.rss.common.network.client.TransportClient; import com.aliyun.emr.rss.common.network.client.TransportClientFactory; -import com.aliyun.emr.rss.common.network.server.OneForOneStreamManager; -import com.aliyun.emr.rss.common.network.server.RpcHandler; -import com.aliyun.emr.rss.common.network.server.StreamManager; +import com.aliyun.emr.rss.common.network.protocol.*; 
+import com.aliyun.emr.rss.common.network.server.BaseMessageHandler;
 import com.aliyun.emr.rss.common.network.server.TransportServer;
 import com.aliyun.emr.rss.common.network.util.JavaUtils;
 import com.aliyun.emr.rss.common.network.util.MapConfigProvider;
@@ -51,116 +46,70 @@ public class RpcIntegrationSuiteJ {
   static TransportConf conf;
   static TransportServer server;
   static TransportClientFactory clientFactory;
-  static RpcHandler rpcHandler;
+  static BaseMessageHandler handler;
   static List<String> oneWayMsgs;
   static StreamTestHelper testData;
 
-  static ConcurrentHashMap<String, VerifyingStreamCallback> streamCallbacks =
-    new ConcurrentHashMap<>();
-
   @BeforeClass
   public static void setUp() throws Exception {
     conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
     testData = new StreamTestHelper();
-    rpcHandler = new RpcHandler() {
+    handler = new BaseMessageHandler() {
       @Override
       public void receive(
           TransportClient client,
-          ByteBuffer message,
-          RpcResponseCallback callback) {
-        String msg = JavaUtils.bytesToString(message);
-        String[] parts = msg.split("/");
-        if (parts[0].equals("hello")) {
-          callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
-        } else if (parts[0].equals("return error")) {
-          callback.onFailure(new RuntimeException("Returned: " + parts[1]));
-        } else if (parts[0].equals("throw error")) {
-          throw new RuntimeException("Thrown: " + parts[1]);
+          RequestMessage message) {
+        if (message instanceof RpcRequest) {
+          String msg;
+          RpcRequest r = (RpcRequest) message;
+          RpcResponseCallback callback = new RpcResponseCallback() {
+            @Override
+            public void onSuccess(ByteBuffer response) {
+              client.getChannel().writeAndFlush(new RpcResponse(r.requestId,
+                new NioManagedBuffer(response)));
+            }
+
+            @Override
+            public void onFailure(Throwable e) {
+              client.getChannel().writeAndFlush(new RpcFailure(r.requestId,
+                Throwables.getStackTraceAsString(e)));
+            }
+          };
+          try {
+            msg = JavaUtils.bytesToString(message.body().nioByteBuffer());
+          } catch (Exception e) {
+            throw new RuntimeException(e);
+          }
+          String[] parts = msg.split("/");
+          if (parts[0].equals("hello")) {
+            callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
+          } else if (parts[0].equals("return error")) {
+            callback.onFailure(new RuntimeException("Returned: " + parts[1]));
+          } else if (parts[0].equals("throw error")) {
+            callback.onFailure(new RuntimeException("Thrown: " + parts[1]));
+          }
+        } else if (message instanceof OneWayMessage) {
+          String msg;
+          try {
+            msg = JavaUtils.bytesToString(message.body().nioByteBuffer());
+          } catch (Exception e) {
+            throw new RuntimeException(e);
+          }
+          oneWayMsgs.add(msg);
         }
       }
 
       @Override
-      public StreamCallbackWithID receiveStream(
-          TransportClient client,
-          ByteBuffer messageHeader,
-          RpcResponseCallback callback) {
-        return receiveStreamHelper(JavaUtils.bytesToString(messageHeader));
-      }
-
-      @Override
-      public void receive(TransportClient client, ByteBuffer message) {
-        oneWayMsgs.add(JavaUtils.bytesToString(message));
+      public boolean checkRegistered() {
+        return true;
       }
-
-      @Override
-      public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
     };
-    TransportContext context = new TransportContext(conf, rpcHandler);
+    TransportContext context = new TransportContext(conf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
     oneWayMsgs = new ArrayList<>();
   }
 
-  private static StreamCallbackWithID receiveStreamHelper(String msg) {
-    try {
-      if (msg.startsWith("fail/")) {
-        String[] parts = msg.split("/");
-        switch (parts[1]) {
-          case "exception-ondata":
-            return new StreamCallbackWithID() {
-              @Override
-              public void onData(String streamId, ByteBuffer buf) throws IOException {
-                throw new IOException("failed to read stream data!");
-              }
-
-              @Override
-              public void onComplete(String streamId) throws IOException {
-              }
-
-              @Override
-              public void onFailure(String streamId, Throwable cause) throws IOException {
-              }
-
-              @Override
-              public String getID() {
-                return msg;
-              }
-            };
-          case "exception-oncomplete":
-            return new StreamCallbackWithID() {
-              @Override
-              public void onData(String streamId, ByteBuffer buf) throws IOException {
-              }
-
-              @Override
-              public void onComplete(String streamId) throws IOException {
-                throw new IOException("exception in onComplete");
-              }
-
-              @Override
-              public void onFailure(String streamId, Throwable cause) throws IOException {
-              }
-
-              @Override
-              public String getID() {
-                return msg;
-              }
-            };
-          case "null":
-            return null;
-          default:
-            throw new IllegalArgumentException("unexpected msg: " + msg);
-        }
-      } else {
-        VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg);
-        streamCallbacks.put(msg, streamCallback);
-        return streamCallback;
-      }
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
-
   @AfterClass
   public static void tearDown() {
     server.close();
@@ -207,35 +156,6 @@ public void onFailure(Throwable e) {
     return res;
   }
 
-  private RpcResult sendRpcWithStream(String... streams) throws Exception {
-    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-    final Semaphore sem = new Semaphore(0);
-    RpcResult res = new RpcResult();
-    res.successMessages = Collections.synchronizedSet(new HashSet<String>());
-    res.errorMessages = Collections.synchronizedSet(new HashSet<String>());
-
-    for (String stream : streams) {
-      int idx = stream.lastIndexOf('/');
-      ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream));
-      String streamName = (idx == -1) ? stream : stream.substring(idx + 1);
-      ManagedBuffer data = testData.openStream(conf, streamName);
-      client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem));
-    }
-
-    if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) {
-      fail("Timeout getting response from the server");
-    }
-    streamCallbacks.values().forEach(streamCallback -> {
-      try {
-        streamCallback.verify();
-      } catch (IOException e) {
-        throw new RuntimeException(e);
-      }
-    });
-    client.close();
-    return res;
-  }
-
   private static class RpcStreamCallback implements RpcResponseCallback {
     final String streamId;
     final RpcResult res;
@@ -323,46 +243,6 @@ public void sendOneWayMessage() throws Exception {
     }
   }
 
-  @Test
-  public void sendRpcWithStreamOneAtATime() throws Exception {
-    for (String stream : StreamTestHelper.STREAMS) {
-      RpcResult res = sendRpcWithStream(stream);
-      assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty());
-      assertEquals(Sets.newHashSet(stream), res.successMessages);
-    }
-  }
-
-  @Test
-  public void sendRpcWithStreamConcurrently() throws Exception {
-    String[] streams = new String[10];
-    for (int i = 0; i < 10; i++) {
-      streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length];
-    }
-    RpcResult res = sendRpcWithStream(streams);
-    assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages);
-    assertTrue(res.errorMessages.isEmpty());
-  }
-
-  @Test
-  public void sendRpcWithStreamFailures() throws Exception {
-    // when there is a failure reading stream data, we don't try to keep the channel usable,
-    // just send back a decent error msg.
-    RpcResult exceptionInCallbackResult =
-      sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer");
-    assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream");
-
-    RpcResult nullStreamHandler =
-      sendRpcWithStream("fail/null/smallBuffer", "smallBuffer");
-    assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream");
-
-    // OTOH, if there is a failure during onComplete, the channel should still be fine
-    RpcResult exceptionInOnComplete =
-      sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer");
-    assertErrorsContain(exceptionInOnComplete.errorMessages,
-      Sets.newHashSet("Failure post-processing"));
-    assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages);
-  }
-
   private void assertErrorsContain(Set<String> errors, Set<String> contains) {
     assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " +
         errors, contains.size(), errors.size());
@@ -426,60 +306,4 @@ private Pair<Set<String>, Set<String>> checkErrorsContain(
     }
     return new ImmutablePair<>(remainingErrors, notFound);
   }
-
-  private static class VerifyingStreamCallback implements StreamCallbackWithID {
-    final String streamId;
-    final StreamSuiteJ.TestCallback helper;
-    final OutputStream out;
-    final File outFile;
-
-    VerifyingStreamCallback(String streamId) throws IOException {
-      if (streamId.equals("file")) {
-        outFile = File.createTempFile("data", ".tmp", testData.tempDir);
-        out = new FileOutputStream(outFile);
-      } else {
-        out = new ByteArrayOutputStream();
-        outFile = null;
-      }
-      this.streamId = streamId;
-      helper = new StreamSuiteJ.TestCallback(out);
-    }
-
-    void verify() throws IOException {
-      if (streamId.equals("file")) {
-        assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile));
-      } else {
-        byte[] result = ((ByteArrayOutputStream)out).toByteArray();
-        ByteBuffer srcBuffer = testData.srcBuffer(streamId);
-        ByteBuffer base;
-        synchronized (srcBuffer) {
-          base = srcBuffer.duplicate();
-        }
-        byte[] expected = new byte[base.remaining()];
-        base.get(expected);
-        assertEquals(expected.length, result.length);
-        assertTrue("buffers don't match", Arrays.equals(expected, result));
-      }
-    }
-
-    @Override
-    public void onData(String streamId, ByteBuffer buf) throws IOException {
-      helper.onData(streamId, buf);
-    }
-
-    @Override
-    public void onComplete(String streamId) throws IOException {
-      helper.onComplete(streamId);
-    }
-
-    @Override
-    public void onFailure(String streamId, Throwable cause) throws IOException {
-      helper.onFailure(streamId, cause);
    }
-
-    @Override
-    public String getID() {
-      return streamId;
-    }
-  }
 }
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/StreamSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/StreamSuiteJ.java
deleted file mode 100644 index 801dabc4e01..00000000000 --- a/common/src/test/java/com/aliyun/emr/rss/common/network/StreamSuiteJ.java +++ /dev/null @@ -1,306 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network; - -import java.io.*; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; - -import com.google.common.io.Files; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import static org.junit.Assert.*; - -import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; -import com.aliyun.emr.rss.common.network.client.StreamCallback; -import com.aliyun.emr.rss.common.network.client.TransportClient; -import com.aliyun.emr.rss.common.network.client.TransportClientFactory; -import com.aliyun.emr.rss.common.network.server.RpcHandler; -import com.aliyun.emr.rss.common.network.server.StreamManager; -import com.aliyun.emr.rss.common.network.server.TransportServer; -import com.aliyun.emr.rss.common.network.util.MapConfigProvider; -import com.aliyun.emr.rss.common.network.util.TransportConf; - -public class StreamSuiteJ { - private static final String[] STREAMS = StreamTestHelper.STREAMS; - private static StreamTestHelper testData; - - private static TransportServer server; - private static TransportClientFactory clientFactory; - - private static ByteBuffer createBuffer(int bufSize) { - ByteBuffer buf = ByteBuffer.allocate(bufSize); - for (int i = 0; i < bufSize; i ++) { - buf.put((byte) i); - } - buf.flip(); - return buf; - } - - @BeforeClass - public static void setUp() throws Exception { - testData = new StreamTestHelper(); - - final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); - final StreamManager streamManager = new StreamManager() { - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public ManagedBuffer openStream(String streamId) { - return testData.openStream(conf, streamId); - } - }; - RpcHandler handler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - @Override - public StreamManager getStreamManager() { - return streamManager; - } - }; - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - } - - @AfterClass - public static void tearDown() { - 
-    server.close();
-    clientFactory.close();
-    testData.cleanup();
-  }
-
-  @Test
-  public void testZeroLengthStream() throws Throwable {
-    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-    try {
-      StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5));
-      task.run();
-      task.check();
-    } finally {
-      client.close();
-    }
-  }
-
-  @Test
-  public void testSingleStream() throws Throwable {
-    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-    try {
-      StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
-      task.run();
-      task.check();
-    } finally {
-      client.close();
-    }
-  }
-
-  @Test
-  public void testMultipleStreams() throws Throwable {
-    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-    try {
-      for (int i = 0; i < 20; i++) {
-        StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
-          TimeUnit.SECONDS.toMillis(5));
-        task.run();
-        task.check();
-      }
-    } finally {
-      client.close();
-    }
-  }
-
-  @Test
-  public void testConcurrentStreams() throws Throwable {
-    ExecutorService executor = Executors.newFixedThreadPool(20);
-    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-
-    try {
-      List<StreamTask> tasks = new ArrayList<>();
-      for (int i = 0; i < 20; i++) {
-        StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
-          TimeUnit.SECONDS.toMillis(20));
-        tasks.add(task);
-        executor.submit(task);
-      }
-
-      executor.shutdown();
-      assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS));
-      for (StreamTask task : tasks) {
-        task.check();
-      }
-    } finally {
-      executor.shutdownNow();
-      client.close();
-    }
-  }
-
-  private static class StreamTask implements Runnable {
-
-    private final TransportClient client;
-    private final String streamId;
-    private final long timeoutMs;
-    private Throwable error;
-
-    StreamTask(TransportClient client, String streamId, long timeoutMs) {
-      this.client = client;
-      this.streamId = streamId;
-      this.timeoutMs = timeoutMs;
-    }
-
-    @Override
-    public void run() {
-      ByteBuffer srcBuffer = null;
-      OutputStream out = null;
-      File outFile = null;
-      try {
-        ByteArrayOutputStream baos = null;
-
-        switch (streamId) {
-          case "largeBuffer":
-            baos = new ByteArrayOutputStream();
-            out = baos;
-            srcBuffer = testData.largeBuffer;
-            break;
-          case "smallBuffer":
-            baos = new ByteArrayOutputStream();
-            out = baos;
-            srcBuffer = testData.smallBuffer;
-            break;
-          case "file":
-            outFile = File.createTempFile("data", ".tmp", testData.tempDir);
-            out = new FileOutputStream(outFile);
-            break;
-          case "emptyBuffer":
-            baos = new ByteArrayOutputStream();
-            out = baos;
-            srcBuffer = testData.emptyBuffer;
-            break;
-          default:
-            throw new IllegalArgumentException(streamId);
-        }
-
-        TestCallback callback = new TestCallback(out);
-        client.stream(streamId, callback);
-        callback.waitForCompletion(timeoutMs);
-
-        if (srcBuffer == null) {
-          assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile));
-        } else {
-          ByteBuffer base;
-          synchronized (srcBuffer) {
-            base = srcBuffer.duplicate();
-          }
-          byte[] result = baos.toByteArray();
-          byte[] expected = new byte[base.remaining()];
-          base.get(expected);
-          assertEquals(expected.length, result.length);
-          assertTrue("buffers don't match", Arrays.equals(expected, result));
-        }
-      } catch (Throwable t) {
-        error = t;
-      } finally {
-        if (out != null) {
-          try {
-            out.close();
-          } catch (Exception e) {
-            // ignore.
-          }
-        }
-        if (outFile != null) {
-          outFile.delete();
-        }
-      }
-    }
-
-    public void check() throws Throwable {
-      if (error != null) {
-        throw error;
-      }
-    }
-  }
-
-  static class TestCallback implements StreamCallback {
-
-    private final OutputStream out;
-    public volatile boolean completed;
-    public volatile Throwable error;
-
-    TestCallback(OutputStream out) {
-      this.out = out;
-      this.completed = false;
-    }
-
-    @Override
-    public void onData(String streamId, ByteBuffer buf) throws IOException {
-      byte[] tmp = new byte[buf.remaining()];
-      buf.get(tmp);
-      out.write(tmp);
-    }
-
-    @Override
-    public void onComplete(String streamId) throws IOException {
-      out.close();
-      synchronized (this) {
-        completed = true;
-        notifyAll();
-      }
-    }
-
-    @Override
-    public void onFailure(String streamId, Throwable cause) {
-      error = cause;
-      synchronized (this) {
-        completed = true;
-        notifyAll();
-      }
-    }
-
-    void waitForCompletion(long timeoutMs) {
-      long now = System.currentTimeMillis();
-      long deadline = now + timeoutMs;
-      synchronized (this) {
-        while (!completed && now < deadline) {
-          try {
-            wait(deadline - now);
-          } catch (InterruptedException ie) {
-            throw new RuntimeException(ie);
-          }
-          now = System.currentTimeMillis();
-        }
-      }
-      assertTrue("Timed out waiting for stream.", completed);
-      assertNull(error);
-    }
-  }
-}
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportClientFactorySuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/TransportClientFactorySuiteJ.java
index 5e7bf0c5681..02eebd42858 100644
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportClientFactorySuiteJ.java
+++ b/common/src/test/java/com/aliyun/emr/rss/common/network/TransportClientFactorySuiteJ.java
@@ -28,8 +28,7 @@
 import com.aliyun.emr.rss.common.network.client.TransportClient;
 import com.aliyun.emr.rss.common.network.client.TransportClientFactory;
-import com.aliyun.emr.rss.common.network.server.NoOpRpcHandler;
-import com.aliyun.emr.rss.common.network.server.RpcHandler;
+import com.aliyun.emr.rss.common.network.server.BaseMessageHandler;
 import com.aliyun.emr.rss.common.network.server.TransportServer;
 import com.aliyun.emr.rss.common.network.util.ConfigProvider;
 import com.aliyun.emr.rss.common.network.util.JavaUtils;
@@ -44,8 +43,8 @@ public class TransportClientFactorySuiteJ {
   @Before
   public void setUp() {
     TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
-    RpcHandler rpcHandler = new NoOpRpcHandler();
-    context = new TransportContext(conf, rpcHandler);
+    BaseMessageHandler handler = new BaseMessageHandler();
+    context = new TransportContext(conf, handler);
     server1 = context.createServer();
     server2 = context.createServer();
   }
@@ -69,8 +68,8 @@ private void testClientReuse(int maxConnections, boolean concurrent)
     configMap.put("rss.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections));
     TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
 
-    RpcHandler rpcHandler = new NoOpRpcHandler();
-    TransportContext context = new TransportContext(conf, rpcHandler);
+    BaseMessageHandler handler = new BaseMessageHandler();
+    TransportContext context = new TransportContext(conf, handler);
     TransportClientFactory factory = context.createClientFactory();
     Set<TransportClient> clients = Collections.synchronizedSet(
       new HashSet<TransportClient>());
@@ -194,7 +193,7 @@ public Iterable<Map.Entry<String, String>> getAll() {
       throw new UnsupportedOperationException();
     }
   });
-    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
+    TransportContext context = new TransportContext(conf, new BaseMessageHandler(), true);
     try (TransportClientFactory factory = context.createClientFactory()) {
       TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
       assertTrue(c1.isActive());
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportRequestHandlerSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/TransportRequestHandlerSuiteJ.java
deleted file mode 100644
index f3ca503b8de..00000000000
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportRequestHandlerSuiteJ.java
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.aliyun.emr.rss.common.network;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelPromise;
-import io.netty.channel.DefaultChannelPromise;
-import io.netty.util.concurrent.Future;
-import io.netty.util.concurrent.GenericFutureListener;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.commons.lang3.tuple.Pair;
-import org.junit.Test;
-import org.mockito.Mockito;
-
-import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
-import com.aliyun.emr.rss.common.network.client.TransportClient;
-import com.aliyun.emr.rss.common.network.protocol.*;
-import com.aliyun.emr.rss.common.network.server.*;
-
-public class TransportRequestHandlerSuiteJ {
-
-  @Test
-  public void handleFetchRequestAndStreamRequest() throws Exception {
-    RpcHandler rpcHandler = new NoOpRpcHandler();
-    OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
-    Channel channel = Mockito.mock(Channel.class);
-    List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs =
-      new ArrayList<>();
-    Mockito.when(channel.writeAndFlush(Mockito.any()))
-      .thenAnswer(invocationOnMock0 -> {
-        Object response = invocationOnMock0.getArguments()[0];
-        ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
-        responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
-        return channelFuture;
-      });
-
-    // Prepare the stream.
-    List<ManagedBuffer> managedBuffers = new ArrayList<>();
-    managedBuffers.add(new TestManagedBuffer(10));
-    managedBuffers.add(new TestManagedBuffer(20));
-    managedBuffers.add(new TestManagedBuffer(30));
-    managedBuffers.add(new TestManagedBuffer(40));
-
-    ManagedBufferIterator iterator = Mockito.mock(ManagedBufferIterator.class);
-    Mockito.when(iterator.chunk(Mockito.anyInt()))
-      .thenReturn(managedBuffers.get(0))
-      .thenReturn(managedBuffers.get(1))
-      .thenReturn(managedBuffers.get(2))
-      .thenReturn(managedBuffers.get(3));
-    Mockito.when(iterator.hasNext())
-      .thenReturn(true).thenReturn(true)
-      .thenReturn(true).thenReturn(true)
-      .thenReturn(true).thenReturn(true)
-      .thenReturn(true).thenReturn(true)
-      .thenReturn(false);
-
-    long streamId = streamManager.registerStream("test-app", iterator, channel);
-
-    assert streamManager.numStreamStates() == 1;
-
-    TransportClient reverseClient = Mockito.mock(TransportClient.class);
-    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
-      rpcHandler, 2L);
-
-    RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0));
-    requestHandler.handle(request0);
-    assert responseAndPromisePairs.size() == 1;
-    assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess;
-    assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() ==
-      managedBuffers.get(0);
-
-    RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1));
-    requestHandler.handle(request1);
-    assert responseAndPromisePairs.size() == 2;
-    assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess;
-    assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() ==
-      managedBuffers.get(1);
-
-    // Finish flushing the response for request0.
-    responseAndPromisePairs.get(0).getRight().finish(true);
-
-    RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
-    requestHandler.handle(request2);
-    assert responseAndPromisePairs.size() == 3;
-    assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse;
-    assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() ==
-      managedBuffers.get(2);
-
-    // Request3 will trigger the close of channel, because the number of max chunks being
-    // transferred is 2;
-    RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3));
-    requestHandler.handle(request3);
-    Mockito.verify(channel, Mockito.times(1)).close();
-    assert responseAndPromisePairs.size() == 3;
-
-    streamManager.connectionTerminated(channel);
-    assert streamManager.numStreamStates() == 0;
-  }
-
-  private class ExtendedChannelPromise extends DefaultChannelPromise {
-
-    private List<GenericFutureListener<Future<Void>>> listeners = new ArrayList<>();
-    private boolean success;
-
-    ExtendedChannelPromise(Channel channel) {
-      super(channel);
-      success = false;
-    }
-
-    @Override
-    public ChannelPromise addListener(
-        GenericFutureListener<? extends Future<? super Void>> listener) {
-      @SuppressWarnings("unchecked")
-      GenericFutureListener<Future<Void>> gfListener =
-        (GenericFutureListener<Future<Void>>) listener;
-      listeners.add(gfListener);
-      return super.addListener(listener);
-    }
-
-    @Override
-    public boolean isSuccess() {
-      return success;
-    }
-
-    public void finish(boolean success) {
-      this.success = success;
-      listeners.forEach(listener -> {
-        try {
-          listener.operationComplete(this);
-        } catch (Exception e) {
-          // do nothing
-        }
-      });
-    }
-  }
-}
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java
index 33a4f119e39..e8c696f7c1c 100644
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java
+++ b/common/src/test/java/com/aliyun/emr/rss/common/network/TransportResponseHandlerSuiteJ.java
@@ -17,10 +17,8 @@
 package com.aliyun.emr.rss.common.network;
 
-import java.io.IOException;
 import java.nio.ByteBuffer;
 
-import io.netty.channel.Channel;
 import io.netty.channel.local.LocalChannel;
 import org.junit.Test;
 import static org.junit.Assert.assertEquals;
@@ -29,35 +27,33 @@
 import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer;
 import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback;
 import com.aliyun.emr.rss.common.network.client.RpcResponseCallback;
-import com.aliyun.emr.rss.common.network.client.StreamCallback;
 import com.aliyun.emr.rss.common.network.client.TransportResponseHandler;
 import com.aliyun.emr.rss.common.network.protocol.*;
-import com.aliyun.emr.rss.common.network.util.TransportFrameDecoder;
 
 public class TransportResponseHandlerSuiteJ {
   @Test
   public void handleSuccessfulFetch() throws Exception {
-    StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+    StreamChunkSlice streamChunkSlice = new StreamChunkSlice(1, 0);
 
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    handler.addFetchRequest(streamChunkId, callback);
+    handler.addFetchRequest(streamChunkSlice, callback);
     assertEquals(1, handler.numOutstandingRequests());
 
-    handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
+    handler.handle(new ChunkFetchSuccess(streamChunkSlice, new TestManagedBuffer(123)));
     verify(callback, 
times(1)).onSuccess(eq(0), any()); assertEquals(0, handler.numOutstandingRequests()); } @Test public void handleFailedFetch() throws Exception { - StreamChunkId streamChunkId = new StreamChunkId(1, 0); + StreamChunkSlice streamChunkSlice = new StreamChunkSlice(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(streamChunkId, callback); + handler.addFetchRequest(streamChunkSlice, callback); assertEquals(1, handler.numOutstandingRequests()); - handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); + handler.handle(new ChunkFetchFailure(streamChunkSlice, "some error msg")); verify(callback, times(1)).onFailure(eq(0), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -66,12 +62,12 @@ public void handleFailedFetch() throws Exception { public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(new StreamChunkId(1, 0), callback); - handler.addFetchRequest(new StreamChunkId(1, 1), callback); - handler.addFetchRequest(new StreamChunkId(1, 2), callback); + handler.addFetchRequest(new StreamChunkSlice(1, 0), callback); + handler.addFetchRequest(new StreamChunkSlice(1, 1), callback); + handler.addFetchRequest(new StreamChunkSlice(1, 2), callback); assertEquals(3, handler.numOutstandingRequests()); - handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); + handler.handle(new ChunkFetchSuccess(new StreamChunkSlice(1, 0), new TestManagedBuffer(12))); handler.exceptionCaught(new Exception("duh duh duhhhh")); // should fail both b2 and b3 @@ -112,52 +108,4 @@ public void handleFailedRPC() throws Exception { verify(callback, times(1)).onFailure(any()); assertEquals(0, handler.numOutstandingRequests()); } - - @Test - public void testActiveStreams() throws Exception { - Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); - TransportResponseHandler handler = new TransportResponseHandler(c); - - StreamResponse response = new StreamResponse("stream", 1234L, null); - StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback("stream", cb); - assertEquals(1, handler.numOutstandingRequests()); - handler.handle(response); - assertEquals(1, handler.numOutstandingRequests()); - handler.deactivateStream(); - assertEquals(0, handler.numOutstandingRequests()); - - StreamFailure failure = new StreamFailure("stream", "uh-oh"); - handler.addStreamCallback("stream", cb); - assertEquals(1, handler.numOutstandingRequests()); - handler.handle(failure); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void failOutstandingStreamCallbackOnClose() throws Exception { - Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); - TransportResponseHandler handler = new TransportResponseHandler(c); - - StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback("stream-1", cb); - handler.channelInactive(); - - verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); - } - - @Test - public void failOutstandingStreamCallbackOnException() throws Exception { - Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new 
TransportFrameDecoder()); - TransportResponseHandler handler = new TransportResponseHandler(c); - - StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback("stream-1", cb); - handler.exceptionCaught(new IOException("Oops!")); - - verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); - } } diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/protocol/MessageWithHeaderSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/protocol/MessageWithHeaderSuiteJ.java deleted file mode 100644 index 890d9da8a6b..00000000000 --- a/common/src/test/java/com/aliyun/emr/rss/common/network/protocol/MessageWithHeaderSuiteJ.java +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.common.network.protocol; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import org.junit.Test; -import org.mockito.Mockito; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import com.aliyun.emr.rss.common.network.TestManagedBuffer; -import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; -import com.aliyun.emr.rss.common.network.buffer.NettyManagedBuffer; -import com.aliyun.emr.rss.common.network.util.AbstractFileRegion; -import com.aliyun.emr.rss.common.network.util.ByteArrayWritableChannel; - -public class MessageWithHeaderSuiteJ { - - @Test - public void testSingleWrite() throws Exception { - testFileRegionBody(8, 8); - } - - @Test - public void testShortWrite() throws Exception { - testFileRegionBody(8, 1); - } - - @Test - public void testByteBufBody() throws Exception { - testByteBufBody(Unpooled.copyLong(42)); - } - - @Test - public void testCompositeByteBufBodySingleBuffer() throws Exception { - ByteBuf header = Unpooled.copyLong(42); - CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); - compositeByteBuf.addComponent(true, header); - assertEquals(1, compositeByteBuf.nioBufferCount()); - testByteBufBody(compositeByteBuf); - } - - @Test - public void testCompositeByteBufBodyMultipleBuffers() throws Exception { - ByteBuf header = Unpooled.copyLong(42); - CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); - compositeByteBuf.addComponent(true, header.retainedSlice(0, 4)); - compositeByteBuf.addComponent(true, header.slice(4, 4)); - assertEquals(2, compositeByteBuf.nioBufferCount()); - testByteBufBody(compositeByteBuf); - } - - /** - * Test writing a {@link MessageWithHeader} using the given {@link ByteBuf} as header. - * - * @param header the header to use. - * @throws Exception thrown on error. 
- */ - private void testByteBufBody(ByteBuf header) throws Exception { - long expectedHeaderValue = header.getLong(header.readerIndex()); - ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); - assertEquals(1, header.refCnt()); - assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); - ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer); - - Object body = managedBuf.convertToNetty(); - assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt()); - assertEquals(1, header.refCnt()); - - MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); - ByteBuf result = doWrite(msg, 1); - assertEquals(msg.count(), result.readableBytes()); - assertEquals(expectedHeaderValue, result.readLong()); - assertEquals(84, result.readLong()); - - assertTrue(msg.release()); - assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt()); - assertEquals(0, header.refCnt()); - } - - @Test - public void testDeallocateReleasesManagedBuffer() throws Exception { - ByteBuf header = Unpooled.copyLong(42); - ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); - ByteBuf body = (ByteBuf) managedBuf.convertToNetty(); - assertEquals(2, body.refCnt()); - MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); - assertTrue(msg.release()); - Mockito.verify(managedBuf, Mockito.times(1)).release(); - assertEquals(0, body.refCnt()); - } - - private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { - ByteBuf header = Unpooled.copyLong(42); - int headerLength = header.readableBytes(); - TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); - MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count()); - - ByteBuf result = doWrite(msg, totalWrites / writesPerCall); - assertEquals(headerLength + region.count(), result.readableBytes()); - assertEquals(42, result.readLong()); - for (long i = 0; i < 8; i++) { - assertEquals(i, result.readLong()); - } - assertTrue(msg.release()); - } - - private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { - int writes = 0; - ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); - while (msg.transfered() < msg.count()) { - msg.transferTo(channel, msg.transfered()); - writes++; - } - assertTrue("Not enough writes!", minExpectedWrites <= writes); - return Unpooled.wrappedBuffer(channel.getData()); - } - - private static class TestFileRegion extends AbstractFileRegion { - - private final int writeCount; - private final int writesPerCall; - private int written; - - TestFileRegion(int totalWrites, int writesPerCall) { - this.writeCount = totalWrites; - this.writesPerCall = writesPerCall; - } - - @Override - public long count() { - return 8 * writeCount; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transferred() { - return 8 * written; - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - for (int i = 0; i < writesPerCall; i++) { - ByteBuf buf = Unpooled.copyLong((position / 8) + i); - ByteBuffer nio = buf.nioBuffer(); - while (nio.remaining() > 0) { - target.write(nio); - } - buf.release(); - written++; - } - return 8 * writesPerCall; - } - - @Override - protected void deallocate() { - } - } -} diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManagerSuiteJ.java 
b/common/src/test/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManagerSuiteJ.java
index b5ae59bf06e..6e1641369e8 100644
--- a/common/src/test/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManagerSuiteJ.java
+++ b/common/src/test/java/com/aliyun/emr/rss/common/network/server/OneForOneStreamManagerSuiteJ.java
@@ -17,55 +17,21 @@
 package com.aliyun.emr.rss.common.network.server;
 
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
 import io.netty.channel.Channel;
 import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-import com.aliyun.emr.rss.common.network.TestManagedBuffer;
-import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer;
-
 public class OneForOneStreamManagerSuiteJ {
-
-  @Test
-  public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
-    OneForOneStreamManager manager = new OneForOneStreamManager();
-    List<ManagedBuffer> buffers = new ArrayList<>();
-    TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
-    TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
-    buffers.add(buffer1);
-    buffers.add(buffer2);
-
-    Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
-    manager.registerStream("appId", buffers.iterator(), dummyChannel);
-    assert manager.numStreamStates() == 1;
-
-    manager.connectionTerminated(dummyChannel);
-
-    Mockito.verify(buffer1, Mockito.times(1)).release();
-    Mockito.verify(buffer2, Mockito.times(1)).release();
-    assert manager.numStreamStates() == 0;
-  }
-
   @Test
   public void streamStatesAreFreedWhenConnectionIsClosedEvenIfBufferIteratorThrowsException() {
     OneForOneStreamManager manager = new OneForOneStreamManager();
     @SuppressWarnings("unchecked")
-    Iterator<ManagedBuffer> buffers = Mockito.mock(Iterator.class);
-    Mockito.when(buffers.hasNext()).thenReturn(true);
-    Mockito.when(buffers.next()).thenThrow(RuntimeException.class);
-
-    ManagedBuffer mockManagedBuffer = Mockito.mock(ManagedBuffer.class);
+    FileManagedBuffers buffers = Mockito.mock(FileManagedBuffers.class);
     @SuppressWarnings("unchecked")
-    Iterator<ManagedBuffer> buffers2 = Mockito.mock(Iterator.class);
-    Mockito.when(buffers2.hasNext()).thenReturn(true).thenReturn(true);
-    Mockito.when(buffers2.next()).thenReturn(mockManagedBuffer).thenThrow(RuntimeException.class);
+    FileManagedBuffers buffers2 = Mockito.mock(FileManagedBuffers.class);
 
     Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
     manager.registerStream("appId", buffers, dummyChannel);
@@ -73,21 +39,7 @@ public void streamStatesAreFreedWhenConnectionIsClosedEvenIfBufferIteratorThrows
 
     Assert.assertEquals(2, manager.numStreamStates());
 
-    try {
-      manager.connectionTerminated(dummyChannel);
-      Assert.fail("connectionTerminated should throw exception when fails to release all buffers");
-
-    } catch (RuntimeException e) {
-
-      Mockito.verify(buffers, Mockito.times(1)).hasNext();
-      Mockito.verify(buffers, Mockito.times(1)).next();
-
-      Mockito.verify(buffers2, Mockito.times(2)).hasNext();
-      Mockito.verify(buffers2, Mockito.times(2)).next();
-
-      Mockito.verify(mockManagedBuffer, Mockito.times(1)).release();
-
-      Assert.assertEquals(0, manager.numStreamStates());
-    }
+    manager.connectionTerminated(dummyChannel);
+    assert manager.streams.isEmpty();
   }
 }
diff --git a/common/src/test/java/com/aliyun/emr/rss/common/network/util/CryptoUtilsSuiteJ.java b/common/src/test/java/com/aliyun/emr/rss/common/network/util/CryptoUtilsSuiteJ.java
deleted file mode 100644
index 53280479d77..00000000000
---
a/common/src/test/java/com/aliyun/emr/rss/common/network/util/CryptoUtilsSuiteJ.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.aliyun.emr.rss.common.network.util;
-
-import java.util.Map;
-import java.util.Properties;
-
-import com.google.common.collect.ImmutableMap;
-import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-
-public class CryptoUtilsSuiteJ {
-
-  @Test
-  public void testConfConversion() {
-    String prefix = "my.prefix.commons.config.";
-
-    String confKey1 = prefix + "a.b.c";
-    String confVal1 = "val1";
-    String cryptoKey1 = CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX + "a.b.c";
-
-    String confKey2 = prefix.substring(0, prefix.length() - 1) + "A.b.c";
-    String confVal2 = "val2";
-    String cryptoKey2 = CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX + "A.b.c";
-
-    Map<String, String> conf = ImmutableMap.of(
-      confKey1, confVal1,
-      confKey2, confVal2);
-
-    Properties cryptoConf = CryptoUtils.toCryptoConf(prefix, conf.entrySet());
-
-    assertEquals(confVal1, cryptoConf.getProperty(cryptoKey1));
-    assertFalse(cryptoConf.containsKey(cryptoKey2));
-  }
-}
diff --git a/common/src/test/scala/com/aliyun/emr/rss/common/meta/WorkerInfoSuite.scala b/common/src/test/scala/com/aliyun/emr/rss/common/meta/WorkerInfoSuite.scala
index a4a92350577..8f5e9ec9069 100644
--- a/common/src/test/scala/com/aliyun/emr/rss/common/meta/WorkerInfoSuite.scala
+++ b/common/src/test/scala/com/aliyun/emr/rss/common/meta/WorkerInfoSuite.scala
@@ -56,9 +56,9 @@ class WorkerInfoSuite extends RssFunSuite {
     assertEquals(
       "The number of WorkerInfo decoded from string is wrong.",
       pbList.size(), workerInfos.size())
-    check(h1, p1, p2, p3, a1, workerInfos)
-    check(h2, p4, p5, p6, a2, workerInfos)
-    check(h3, p7, p8, p9, a3, workerInfos)
+    check(h1, p1, p2, p3, b1, workerInfos)
+    check(h2, p4, p5, p6, b2, workerInfos)
+    check(h3, p7, p8, p9, b3, workerInfos)
   }
 
   private def check(
diff --git a/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/Master.scala b/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/Master.scala
index b811f74e14b..c4bc0035a9d 100644
--- a/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/Master.scala
+++ b/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/Master.scala
@@ -136,7 +136,7 @@ private[deploy] class Master(
   override def onDisconnected(address: RpcAddress): Unit = {
     // The disconnected client could've been either a worker or an app; remove whichever it was
-    logInfo(s"Client $address got disassociated.")
+    logDebug(s"Client $address got disassociated.")
   }
 
   def executeWithLeaderChecker[T](context: RpcCallContext, f: => T): Unit =
@@ -144,10 +144,8 @@
override def receive: PartialFunction[Any, Unit] = { case CheckForWorkerTimeOut => - logDebug("Received CheckForWorkerTimeOut request.") executeWithLeaderChecker(null, timeoutDeadWorkers()) case CheckForApplicationTimeOut => - logDebug("Received CheckForApplicationTimeOut request.") executeWithLeaderChecker(null, timeoutDeadApplications()) case WorkerLost(host, rpcPort, pushPort, fetchPort, replicatePort, requestId) => logDebug(s"Received worker lost $host:$rpcPort:$pushPort:$fetchPort.") @@ -167,11 +165,11 @@ private[deploy] class Master( fetchPort, replicatePort, numSlots, requestId)) case requestSlots @ RequestSlots(_, _, _, _, _, _) => - logDebug(s"Received RequestSlots request $requestSlots.") + logTrace(s"Received RequestSlots request $requestSlots.") executeWithLeaderChecker(context, handleRequestSlots(context, requestSlots)) case ReleaseSlots(applicationId, shuffleId, workerIds, slots, requestId) => - logDebug(s"Received ReleaseSlots request $requestId, $applicationId, $shuffleId," + + logTrace(s"Received ReleaseSlots request $requestId, $applicationId, $shuffleId," + s"workers ${workerIds.asScala.mkString(",")}, slots ${slots.asScala.mkString(",")}") executeWithLeaderChecker(context, handleReleaseSlots(context, applicationId, shuffleId, workerIds, slots, requestId)) @@ -182,7 +180,6 @@ private[deploy] class Master( handleUnregisterShuffle(context, applicationId, shuffleId, requestId)) case msg: GetBlacklist => - logDebug(s"Received Blacklist request") executeWithLeaderChecker(context, handleGetBlacklist(context, msg)) case ApplicationLost(appId, requestId) => @@ -196,16 +193,13 @@ private[deploy] class Master( fetchPort, replicatePort, numSlots, shuffleKeys, requestId)) case GetWorkerInfos => - logDebug("Received GetWorkerInfos request") executeWithLeaderChecker(context, handleGetWorkerInfos(context)) case ReportWorkerFailure(failedWorkers: util.List[WorkerInfo], requestId: String) => - logDebug("Received ReportNodeFailure request ") executeWithLeaderChecker(context, handleReportNodeFailure(context, failedWorkers, requestId)) case GetClusterLoadStatus(numPartitions: Int) => - logInfo(s"Received GetClusterLoad request") executeWithLeaderChecker(context, handleGetClusterLoadStatus(context, numPartitions)) } @@ -345,7 +339,7 @@ private[deploy] class Master( // reply false if offer slots failed if (slots == null || slots.isEmpty) { logError(s"Offer slots for $numReducers reducers of $shuffleKey failed!") - context.reply(RequestSlotsResponse(StatusCode.SlotNotAvailable, null)) + context.reply(RequestSlotsResponse(StatusCode.SlotNotAvailable, new WorkerResource())) return } diff --git a/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/MasterArguments.scala b/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/MasterArguments.scala index f1e9a04440a..17742f4bdd5 100644 --- a/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/MasterArguments.scala +++ b/server-master/src/main/scala/com/aliyun/emr/rss/service/deploy/master/MasterArguments.scala @@ -25,7 +25,7 @@ import com.aliyun.emr.rss.common.util.{IntParam, Utils} class MasterArguments(args: Array[String], conf: RssConf) { var host = Utils.localHostName() - var port = 9097 + var port = RssConf.masterPort(conf) var propertiesFile: String = null if (System.getenv("RSS_MASTER_HOST") != null) { diff --git a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/ChunkFetchRpcHandler.java 
b/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/ChunkFetchRpcHandler.java deleted file mode 100644 index 7f4d22e396d..00000000000 --- a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/ChunkFetchRpcHandler.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.service.deploy.worker; - -import java.io.FileNotFoundException; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.aliyun.emr.rss.common.exception.RssException; -import com.aliyun.emr.rss.common.metrics.source.AbstractSource; -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; -import com.aliyun.emr.rss.common.network.client.TransportClient; -import com.aliyun.emr.rss.common.network.server.FileInfo; -import com.aliyun.emr.rss.common.network.server.ManagedBufferIterator; -import com.aliyun.emr.rss.common.network.server.OneForOneStreamManager; -import com.aliyun.emr.rss.common.network.server.RpcHandler; -import com.aliyun.emr.rss.common.network.server.StreamManager; -import com.aliyun.emr.rss.common.network.util.TransportConf; - -public final class ChunkFetchRpcHandler extends RpcHandler { - - private static final Logger logger = LoggerFactory.getLogger(ChunkFetchRpcHandler.class); - - private final TransportConf conf; - private final OpenStreamHandler handler; - private final OneForOneStreamManager streamManager; - private final AbstractSource source; // metrics - - public ChunkFetchRpcHandler(TransportConf conf, AbstractSource source, OpenStreamHandler handler - ) { - this.conf = conf; - this.handler = handler; - this.streamManager = new OneForOneStreamManager(); - this.source = source; - } - - private String readString(ByteBuffer buffer) { - int length = buffer.getInt(); - byte[] bytes = new byte[length]; - buffer.get(bytes); - return new String(bytes, StandardCharsets.UTF_8); - } - - @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - String shuffleKey = readString(message); - String fileName = readString(message); - int startMapIndex = message.getInt(); - int endMapIndex = message.getInt(); - - // metrics start - source.startTimer(WorkerSource.OpenStreamTime(), shuffleKey); - FileInfo fileInfo = handler.handleOpenStream(shuffleKey, fileName, startMapIndex, endMapIndex); - - if (fileInfo != null) { - logger.debug("Received chunk fetch request {} {} {} {} get file info {}", shuffleKey, - fileName, startMapIndex, endMapIndex, fileInfo); - try { - ManagedBufferIterator iterator = new ManagedBufferIterator(fileInfo, conf); - long streamId = streamManager.registerStream( - 
client.getClientId(), iterator, client.getChannel()); - - ByteBuffer response = ByteBuffer.allocate(8 + 4); - response.putLong(streamId); - response.putInt(fileInfo.numChunks); - if (fileInfo.numChunks == 0) { - logger.debug("StreamId {} fileName {} startMapIndex {} endMapIndex {} is empty.", - streamId, fileName, startMapIndex, endMapIndex); - } - response.flip(); - callback.onSuccess(response); - } catch (IOException e) { - callback.onFailure( - new RssException("Chunk offsets meta exception ", e)); - } finally { - // metrics end - source.stopTimer(WorkerSource.OpenStreamTime(), shuffleKey); - } - } else { - // metrics end - source.stopTimer(WorkerSource.OpenStreamTime(), shuffleKey); - - callback.onFailure(new FileNotFoundException()); - } - } - - @Override - public boolean checkRegistered() { - return ((Registerable) handler).isRegistered(); - } - - @Override - public void channelInactive(TransportClient client) { - logger.debug("channel Inactive " + client.getSocketAddress()); - } - - @Override - public void exceptionCaught(Throwable cause, TransportClient client) { - logger.debug("exception caught " + cause + " " + client.getSocketAddress()); - } - - @Override - public StreamManager getStreamManager() { - return streamManager; - } -} diff --git a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/FlushBuffer.java b/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/FlushBuffer.java deleted file mode 100644 index 99f99dc738b..00000000000 --- a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/FlushBuffer.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.aliyun.emr.rss.service.deploy.worker; - -import io.netty.buffer.ByteBuf; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.aliyun.emr.rss.common.unsafe.Platform; - -public final class FlushBuffer extends MinimalByteBuf { - private static final Logger logger = LoggerFactory.getLogger(FlushBuffer.class); - - private final int id; - private final long startAddress; - private final long endAddress; - private long currentAddress; - - public FlushBuffer(int id, long startAddress, long endAddress) { - this.id = id; - this.startAddress = startAddress; - this.endAddress = endAddress; - this.currentAddress = startAddress; - } - - public int remaining() { - return (int)(endAddress - currentAddress); - } - - public void append(ByteBuf data) { - final int length = data.readableBytes(); - final int dstIndex = (int) (currentAddress - startAddress); - data.getBytes(data.readerIndex(), this, dstIndex, length); - currentAddress += length; - } - - public void reset() { - currentAddress = startAddress; - } - - public boolean hasData() { - return (currentAddress > startAddress); - } - - public int getId() { - return id; - } - - public long getStartAddress() { - return startAddress; - } - - public long getEndAddress() { - return endAddress; - } - - public long getCurrentAddress() { - return currentAddress; - } - - @Override - public int capacity() { - return (int) (endAddress - startAddress); - } - - @Override - public boolean hasMemoryAddress() { - return true; - } - - @Override - public long memoryAddress() { - return startAddress; - } - - @Override - public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { - if (src.hasMemoryAddress()) { - Platform.copyMemory(null, src.memoryAddress() + srcIndex, - null, startAddress + index, length); - } if (src.hasArray()) { - Platform.copyMemory( - src.array(), Platform.BYTE_ARRAY_OFFSET + src.arrayOffset() + srcIndex, - null, startAddress + index, length); - } else { - src.getBytes(srcIndex, this, index, length); - } - return this; - } -} diff --git a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/MinimalByteBuf.java b/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/MinimalByteBuf.java deleted file mode 100644 index feb39ee70a3..00000000000 --- a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/MinimalByteBuf.java +++ /dev/null @@ -1,950 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.aliyun.emr.rss.service.deploy.worker; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.channels.FileChannel; -import java.nio.channels.GatheringByteChannel; -import java.nio.channels.ScatteringByteChannel; -import java.nio.charset.Charset; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.ByteProcessor; - -public class MinimalByteBuf extends ByteBuf { - - @Override - public int capacity() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf capacity(int newCapacity) { - throw new UnsupportedOperationException(); - } - - @Override - public int maxCapacity() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBufAllocator alloc() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteOrder order() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf order(ByteOrder endianness) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf unwrap() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isDirect() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isReadOnly() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf asReadOnly() { - throw new UnsupportedOperationException(); - } - - @Override - public int readerIndex() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readerIndex(int readerIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public int writerIndex() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writerIndex(int writerIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setIndex(int readerIndex, int writerIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public int readableBytes() { - throw new UnsupportedOperationException(); - } - - @Override - public int writableBytes() { - throw new UnsupportedOperationException(); - } - - @Override - public int maxWritableBytes() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isReadable() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isReadable(int size) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isWritable() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isWritable(int size) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf clear() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf markReaderIndex() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf resetReaderIndex() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf markWriterIndex() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf resetWriterIndex() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf discardReadBytes() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf discardSomeReadBytes() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf ensureWritable(int minWritableBytes) { - throw new UnsupportedOperationException(); - } - - @Override - public 
int ensureWritable(int minWritableBytes, boolean force) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public short getUnsignedByte(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public short getShortLE(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getUnsignedShort(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getUnsignedShortLE(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getMedium(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getMediumLE(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getUnsignedMedium(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getUnsignedMediumLE(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public int getIntLE(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public long getUnsignedInt(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public long getUnsignedIntLE(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLongLE(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public char getChar(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf getBytes(int index, ByteBuf dst) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf getBytes(int index, ByteBuf dst, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf getBytes(int index, byte[] dst) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf getBytes(int index, ByteBuffer dst) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int getBytes(int index, FileChannel out, long position, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public CharSequence getCharSequence(int index, int length, Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf 
setBoolean(int index, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setByte(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setShort(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setShortLE(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setMedium(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setMediumLE(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setInt(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setIntLE(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setLong(int index, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setLongLE(int index, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setChar(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setFloat(int index, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setDouble(int index, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuf src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuf src, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, byte[] src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuffer src) { - throw new UnsupportedOperationException(); - } - - @Override - public int setBytes(int index, InputStream in, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int setBytes(int index, FileChannel in, long position, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setZero(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int setCharSequence(int index, CharSequence sequence, Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean readBoolean() { - throw new UnsupportedOperationException(); - } - - @Override - public byte readByte() { - throw new UnsupportedOperationException(); - } - - @Override - public short readUnsignedByte() { - throw new UnsupportedOperationException(); - } - - @Override - public short readShort() { - throw new UnsupportedOperationException(); - } - - @Override - public short readShortLE() { - throw new UnsupportedOperationException(); - } - - @Override - public int readUnsignedShort() { - throw new UnsupportedOperationException(); - } - - @Override - public int readUnsignedShortLE() { - throw 
new UnsupportedOperationException(); - } - - @Override - public int readMedium() { - throw new UnsupportedOperationException(); - } - - @Override - public int readMediumLE() { - throw new UnsupportedOperationException(); - } - - @Override - public int readUnsignedMedium() { - throw new UnsupportedOperationException(); - } - - @Override - public int readUnsignedMediumLE() { - throw new UnsupportedOperationException(); - } - - @Override - public int readInt() { - throw new UnsupportedOperationException(); - } - - @Override - public int readIntLE() { - throw new UnsupportedOperationException(); - } - - @Override - public long readUnsignedInt() { - throw new UnsupportedOperationException(); - } - - @Override - public long readUnsignedIntLE() { - throw new UnsupportedOperationException(); - } - - @Override - public long readLong() { - throw new UnsupportedOperationException(); - } - - @Override - public long readLongLE() { - throw new UnsupportedOperationException(); - } - - @Override - public char readChar() { - throw new UnsupportedOperationException(); - } - - @Override - public float readFloat() { - throw new UnsupportedOperationException(); - } - - @Override - public double readDouble() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readSlice(int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readRetainedSlice(int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(ByteBuf dst) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(ByteBuf dst, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(byte[] dst) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(ByteBuffer dst) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf readBytes(OutputStream out, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int readBytes(GatheringByteChannel out, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public CharSequence readCharSequence(int length, Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public int readBytes(FileChannel out, long position, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf skipBytes(int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBoolean(boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeByte(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeShort(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeShortLE(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeMedium(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeMediumLE(int value) { - throw new 
UnsupportedOperationException(); - } - - @Override - public ByteBuf writeInt(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeIntLE(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeLong(long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeLongLE(long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeChar(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeFloat(float value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeDouble(double value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuf src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuf src, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(byte[] src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuffer src) { - throw new UnsupportedOperationException(); - } - - @Override - public int writeBytes(InputStream in, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int writeBytes(ScatteringByteChannel in, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int writeBytes(FileChannel in, long position, int length) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeZero(int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int writeCharSequence(CharSequence sequence, Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public int indexOf(int fromIndex, int toIndex, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public int bytesBefore(byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public int bytesBefore(int length, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public int bytesBefore(int index, int length, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public int forEachByte(ByteProcessor processor) { - throw new UnsupportedOperationException(); - } - - @Override - public int forEachByte(int index, int length, ByteProcessor processor) { - throw new UnsupportedOperationException(); - } - - @Override - public int forEachByteDesc(ByteProcessor processor) { - throw new UnsupportedOperationException(); - } - - @Override - public int forEachByteDesc(int index, int length, ByteProcessor processor) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf copy() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf copy(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf slice() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf retainedSlice() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf slice(int 
index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf retainedSlice(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf duplicate() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf retainedDuplicate() { - throw new UnsupportedOperationException(); - } - - @Override - public int nioBufferCount() { - return 0; - } - - @Override - public ByteBuffer nioBuffer() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuffer nioBuffer(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuffer internalNioBuffer(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuffer[] nioBuffers() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuffer[] nioBuffers(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean hasArray() { - return false; - } - - @Override - public byte[] array() { - throw new UnsupportedOperationException(); - } - - @Override - public int arrayOffset() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean hasMemoryAddress() { - return false; - } - - @Override - public long memoryAddress() { - throw new UnsupportedOperationException(); - } - - @Override - public String toString(Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public String toString(int index, int length, Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public int hashCode() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean equals(Object obj) { - return false; - } - - @Override - public int compareTo(ByteBuf buffer) { - throw new UnsupportedOperationException(); - } - - @Override - public String toString() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf retain(int increment) { - throw new UnsupportedOperationException(); - } - - @Override - public int refCnt() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf retain() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf touch() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf touch(Object hint) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean release() { - return false; - } - - @Override - public boolean release(int decrement) { - return false; - } -} diff --git a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/OpenStreamHandler.java b/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/OpenStreamHandler.java deleted file mode 100644 index a9573ea0dac..00000000000 --- a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/OpenStreamHandler.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.service.deploy.worker; - -import com.aliyun.emr.rss.common.network.server.FileInfo; - -public interface OpenStreamHandler { - FileInfo handleOpenStream(String shuffleKey, String partitionId, int startMapIndex, - int endMapIndex); -} diff --git a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.java b/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.java deleted file mode 100644 index 7be7533f07a..00000000000 --- a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.aliyun.emr.rss.service.deploy.worker; - -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; -import com.aliyun.emr.rss.common.network.protocol.PushData; -import com.aliyun.emr.rss.common.network.protocol.PushMergedData; - -public interface PushDataHandler { - void handlePushData(PushData pushData, RpcResponseCallback callback); - void handlePushMergedData(PushMergedData pushMergedData, RpcResponseCallback callback); -} diff --git a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataRpcHandler.java b/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataRpcHandler.java deleted file mode 100644 index 65c4926b541..00000000000 --- a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/PushDataRpcHandler.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.aliyun.emr.rss.service.deploy.worker; - -import java.nio.ByteBuffer; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.aliyun.emr.rss.common.network.client.RpcResponseCallback; -import com.aliyun.emr.rss.common.network.client.TransportClient; -import com.aliyun.emr.rss.common.network.protocol.PushData; -import com.aliyun.emr.rss.common.network.protocol.PushMergedData; -import com.aliyun.emr.rss.common.network.server.OneForOneStreamManager; -import com.aliyun.emr.rss.common.network.server.RpcHandler; -import com.aliyun.emr.rss.common.network.server.StreamManager; -import com.aliyun.emr.rss.common.network.util.TransportConf; - -public final class PushDataRpcHandler extends RpcHandler { - - private static final Logger logger = LoggerFactory.getLogger(PushDataRpcHandler.class); - - private final TransportConf conf; - private final PushDataHandler handler; - private final OneForOneStreamManager streamManager; - - public PushDataRpcHandler(TransportConf conf, PushDataHandler handler) { - this.conf = conf; - this.handler = handler; - streamManager = new OneForOneStreamManager(); - } - - @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - throw new UnsupportedOperationException("PushDataRpcHandler"); - } - - @Override - public void receivePushData( - TransportClient client, PushData pushData, RpcResponseCallback callback) { - handler.handlePushData(pushData, callback); - } - - @Override - public void receivePushMergedData( - TransportClient client, PushMergedData pushMergedData, RpcResponseCallback callback) { - handler.handlePushMergedData(pushMergedData, callback); - } - - @Override - public boolean checkRegistered() { - return ((Worker) handler).isRegistered(); - } - - @Override - public void channelInactive(TransportClient client) { - logger.debug("channel Inactive " + client.getSocketAddress()); - } - - @Override - public void exceptionCaught(Throwable cause, TransportClient client) { - logger.debug("exception caught " + cause + " " + client.getSocketAddress()); - } - - @Override - public StreamManager getStreamManager() { - return streamManager; - } -} diff --git a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/Registerable.java b/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/Registerable.java deleted file mode 100644 index c8472fa8a10..00000000000 --- a/server-worker/src/main/java/com/aliyun/emr/rss/service/deploy/worker/Registerable.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-package com.aliyun.emr.rss.service.deploy.worker;
-
-public interface Registerable {
-  boolean isRegistered();
-}
diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Controller.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Controller.scala
new file mode 100644
index 00000000000..03eae9652b4
--- /dev/null
+++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Controller.scala
@@ -0,0 +1,394 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.aliyun.emr.rss.service.deploy.worker
+
+import java.io.IOException
+import java.util.{ArrayList => jArrayList, List => jList}
+import java.util.concurrent._
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
+import java.util.function.BiFunction
+
+import scala.collection.JavaConverters._
+
+import io.netty.util.{HashedWheelTimer, Timeout, TimerTask}
+import io.netty.util.internal.ConcurrentSet
+
+import com.aliyun.emr.rss.common.RssConf
+import com.aliyun.emr.rss.common.internal.Logging
+import com.aliyun.emr.rss.common.meta.{PartitionLocationInfo, WorkerInfo}
+import com.aliyun.emr.rss.common.metrics.MetricsSystem
+import com.aliyun.emr.rss.common.protocol.{PartitionLocation, PartitionSplitMode}
+import com.aliyun.emr.rss.common.protocol.PartitionLocation.StorageHint
+import com.aliyun.emr.rss.common.protocol.message.ControlMessages._
+import com.aliyun.emr.rss.common.protocol.message.StatusCode
+import com.aliyun.emr.rss.common.rpc._
+import com.aliyun.emr.rss.common.util.Utils
+
+private[deploy] class Controller(
+    override val rpcEnv: RpcEnv,
+    val conf: RssConf,
+    val metricsSystem: MetricsSystem)
+  extends RpcEndpoint with Logging {
+
+  var workerSource: WorkerSource = _
+  var localStorageManager: LocalStorageManager = _
+  var registered: AtomicBoolean = _
+  var shuffleMapperAttempts: ConcurrentHashMap[String, Array[Int]] = _
+  var workerInfo: WorkerInfo = _
+  var partitionLocationInfo: PartitionLocationInfo = _
+  var timer: HashedWheelTimer = _
+  var commitThreadPool: ThreadPoolExecutor = _
+  var asyncReplyPool: ScheduledExecutorService = _
+
+  def init(worker: Worker): Unit = {
+    workerSource = worker.workerSource
+    localStorageManager = worker.localStorageManager
+    registered = worker.registered
+    shuffleMapperAttempts = worker.shuffleMapperAttempts
+    workerInfo = worker.workerInfo
+    partitionLocationInfo = worker.partitionLocationInfo
+    timer = worker.timer
+    commitThreadPool = worker.commitThreadPool
+    asyncReplyPool = worker.asyncReplyPool
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case ReserveSlots(applicationId, shuffleId, masterLocations, slaveLocations, splitThreshold,
+      splitMode, storageHint) =>
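+      // ReserveSlots is handled synchronously: handleReserveSlots below opens one
+      // FileWriter per requested master/slave partition and destroys every writer
+      // created so far if any allocation fails, so a reservation is all-or-nothing.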
shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
+      workerSource.sample(WorkerSource.ReserveSlotsTime, shuffleKey) {
+        logDebug(s"Received ReserveSlots request, $shuffleKey, " +
+          s"master partitions: ${masterLocations.asScala.map(_.getUniqueId).mkString(",")}; " +
+          s"slave partitions: ${slaveLocations.asScala.map(_.getUniqueId).mkString(",")}.")
+        handleReserveSlots(context, applicationId, shuffleId, masterLocations,
+          slaveLocations, splitThreashold, splitMode, storageHint)
+        logDebug(s"ReserveSlots for $shuffleKey succeeded.")
+      }
+
+    case CommitFiles(applicationId, shuffleId, masterIds, slaveIds, mapAttempts) =>
+      val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
+      workerSource.sample(WorkerSource.CommitFilesTime, shuffleKey) {
+        logDebug(s"Received CommitFiles request, $shuffleKey, master files" +
+          s" ${masterIds.asScala.mkString(",")}; slave files ${slaveIds.asScala.mkString(",")}.")
+        val commitFilesTimeMs = Utils.timeIt({
+          handleCommitFiles(context, shuffleKey, masterIds, slaveIds, mapAttempts)
+        })
+        logDebug(s"Finished processing CommitFiles request for $shuffleKey in " +
+          s"$commitFilesTimeMs ms.")
+      }
+
+    case GetWorkerInfos =>
+      handleGetWorkerInfos(context)
+
+    case ThreadDump =>
+      handleThreadDump(context)
+
+    case Destroy(shuffleKey, masterLocations, slaveLocations) =>
+      handleDestroy(context, shuffleKey, masterLocations, slaveLocations)
+  }
+
+  private def handleReserveSlots(
+      context: RpcCallContext,
+      applicationId: String,
+      shuffleId: Int,
+      masterLocations: jList[PartitionLocation],
+      slaveLocations: jList[PartitionLocation],
+      splitThreshold: Long,
+      splitMode: PartitionSplitMode,
+      storageHint: StorageHint): Unit = {
+    val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
+    if (!localStorageManager.hasAvailableWorkingDirs) {
+      val msg = "Local storage has no available dirs!"
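+      // No writable working directory is left on this worker, so reject the
+      // reservation before any file writers are created.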
+      logError(s"[handleReserveSlots] $msg")
+      context.reply(ReserveSlotsResponse(StatusCode.ReserveSlotFailed, msg))
+      return
+    }
+    val masterPartitions = new jArrayList[PartitionLocation]()
+    try {
+      for (ind <- 0 until masterLocations.size()) {
+        val location = masterLocations.get(ind)
+        val writer = localStorageManager.createWriter(applicationId, shuffleId, location,
+          splitThreshold, splitMode)
+        masterPartitions.add(new WorkingPartition(location, writer))
+      }
+    } catch {
+      case e: Exception =>
+        logError(s"CreateWriter for $shuffleKey failed.", e)
+    }
+    if (masterPartitions.size() < masterLocations.size()) {
+      val msg = s"Not all master partitions were satisfied for $shuffleKey"
+      logWarning(s"[handleReserveSlots] $msg, will destroy writers.")
+      masterPartitions.asScala.foreach(_.asInstanceOf[WorkingPartition].getFileWriter.destroy())
+      context.reply(ReserveSlotsResponse(StatusCode.ReserveSlotFailed, msg))
+      return
+    }
+
+    val slavePartitions = new jArrayList[PartitionLocation]()
+    try {
+      for (ind <- 0 until slaveLocations.size()) {
+        val location = slaveLocations.get(ind)
+        val writer = localStorageManager.createWriter(applicationId, shuffleId,
+          location, splitThreshold, splitMode)
+        slavePartitions.add(new WorkingPartition(location, writer))
+      }
+    } catch {
+      case e: Exception =>
+        logError(s"CreateWriter for $shuffleKey failed.", e)
+    }
+    if (slavePartitions.size() < slaveLocations.size()) {
+      val msg = s"Not all slave partitions were satisfied for $shuffleKey"
+      logWarning(s"[handleReserveSlots] $msg, destroy writers.")
+      masterPartitions.asScala.foreach(_.asInstanceOf[WorkingPartition].getFileWriter.destroy())
+      slavePartitions.asScala.foreach(_.asInstanceOf[WorkingPartition].getFileWriter.destroy())
+      context.reply(ReserveSlotsResponse(StatusCode.ReserveSlotFailed, msg))
+      return
+    }
+
+    // reserve success, update status
+    partitionLocationInfo.addMasterPartitions(shuffleKey, masterPartitions)
+    partitionLocationInfo.addSlavePartitions(shuffleKey, slavePartitions)
+    workerInfo.allocateSlots(shuffleKey, masterPartitions.size() + slavePartitions.size())
+    logInfo(s"Reserved ${masterPartitions.size()} master locations and ${slavePartitions.size()}" +
+      s" slave locations for $shuffleKey master: ${masterPartitions}\nslave: ${slavePartitions}.")
+    context.reply(ReserveSlotsResponse(StatusCode.Success))
+  }
+
+  private def commitFiles(
+      shuffleKey: String,
+      uniqueIds: jList[String],
+      committedIds: ConcurrentSet[String],
+      failedIds: ConcurrentSet[String],
+      master: Boolean = true): CompletableFuture[Void] = {
+    var future: CompletableFuture[Void] = null
+
+    if (uniqueIds != null) {
+      uniqueIds.asScala.foreach { uniqueId =>
+        val task = CompletableFuture.runAsync(new Runnable {
+          override def run(): Unit = {
+            try {
+              val location = if (master) {
+                partitionLocationInfo.getMasterLocation(shuffleKey, uniqueId)
+              } else {
+                partitionLocationInfo.getSlaveLocation(shuffleKey, uniqueId)
+              }
+
+              if (location == null) {
+                logWarning(s"Partition location for $shuffleKey $uniqueId doesn't exist.")
+                return
+              }
+
+              val fileWriter = location.asInstanceOf[WorkingPartition].getFileWriter
+              val bytes = fileWriter.close()
+              if (bytes > 0L) {
+                committedIds.add(uniqueId)
+              }
+            } catch {
+              case e: IOException =>
+                logError(s"Commit file for $shuffleKey $uniqueId failed.", e)
+                failedIds.add(uniqueId)
+            }
+          }
+        }, commitThreadPool)
+
+        if (future == null) {
+          future = task
+        } else {
+          future = CompletableFuture.allOf(future, task)
+        }
+      }
+    }
+
+    future
+  }
+
+  private def handleCommitFiles(
+      context: RpcCallContext,
+      shuffleKey: String,
+      masterIds: jList[String],
+      slaveIds: jList[String],
+      mapAttempts: Array[Int]): Unit = {
+    // reply ShuffleNotRegistered if the shuffle does not exist on this worker
+    if (!partitionLocationInfo.containsShuffle(shuffleKey)) {
+      logError(s"Shuffle $shuffleKey doesn't exist!")
+      context.reply(CommitFilesResponse(
+        StatusCode.ShuffleNotRegistered, List.empty.asJava, List.empty.asJava,
+        masterIds, slaveIds))
+      return
+    }
+
+    // close and flush files.
+    shuffleMapperAttempts.putIfAbsent(shuffleKey, mapAttempts)
+
+    // Use ConcurrentSet to avoid excessive lock contention.
+    val committedMasterIds = new ConcurrentSet[String]()
+    val committedSlaveIds = new ConcurrentSet[String]()
+    val failedMasterIds = new ConcurrentSet[String]()
+    val failedSlaveIds = new ConcurrentSet[String]()
+
+    val masterFuture = commitFiles(shuffleKey, masterIds, committedMasterIds, failedMasterIds)
+    val slaveFuture = commitFiles(shuffleKey, slaveIds, committedSlaveIds, failedSlaveIds, false)
+
+    val future = if (masterFuture != null && slaveFuture != null) {
+      CompletableFuture.allOf(masterFuture, slaveFuture)
+    } else if (masterFuture != null) {
+      masterFuture
+    } else if (slaveFuture != null) {
+      slaveFuture
+    } else {
+      null
+    }
+
+    def reply(): Unit = {
+      // release slots before reply.
+      val numSlotsReleased =
+        partitionLocationInfo.removeMasterPartitions(shuffleKey, masterIds) +
+          partitionLocationInfo.removeSlavePartitions(shuffleKey, slaveIds)
+      workerInfo.releaseSlots(shuffleKey, numSlotsReleased)
+
+      val committedMasterIdList = new jArrayList[String](committedMasterIds)
+      val committedSlaveIdList = new jArrayList[String](committedSlaveIds)
+      val failedMasterIdList = new jArrayList[String](failedMasterIds)
+      val failedSlaveIdList = new jArrayList[String](failedSlaveIds)
+      // reply
+      if (failedMasterIds.isEmpty && failedSlaveIds.isEmpty) {
+        logInfo(s"CommitFiles for $shuffleKey succeeded with ${committedMasterIds.size()}" +
+          s" master partitions and ${committedSlaveIds.size()} slave partitions!")
+        context.reply(CommitFilesResponse(
+          StatusCode.Success, committedMasterIdList, committedSlaveIdList,
+          List.empty.asJava, List.empty.asJava))
+      } else {
+        logWarning(s"CommitFiles for $shuffleKey failed with ${failedMasterIds.size()} master" +
+          s" partitions and ${failedSlaveIds.size()} slave partitions!")
+        context.reply(CommitFilesResponse(StatusCode.PartialSuccess, committedMasterIdList,
+          committedSlaveIdList, failedMasterIdList, failedSlaveIdList))
+      }
+    }
+
+    if (future != null) {
+      val result = new AtomicReference[CompletableFuture[Unit]]()
+      val flushTimeout = RssConf.flushTimeout(conf)
+
+      val timeout = timer.newTimeout(new TimerTask {
+        override def run(timeout: Timeout): Unit = {
+          if (result.get() != null) {
+            result.get().cancel(true)
+            logWarning(s"After waiting $flushTimeout s, cancelling all commit file jobs.")
+          }
+        }
+      }, flushTimeout, TimeUnit.SECONDS)
+
+      result.set(future.handleAsync(new BiFunction[Void, Throwable, Unit] {
+        override def apply(v: Void, t: Throwable): Unit = {
+          if (null != t) {
+            t match {
+              case _: CancellationException =>
+                logWarning("While handling commitFiles, canceled.")
+              case ee: ExecutionException =>
+                logError("While handling commitFiles, ExecutionException raised.", ee)
+              case ie: InterruptedException =>
+                logWarning("While handling commitFiles, interrupted.")
+                Thread.currentThread().interrupt()
+                throw ie
+              case _: TimeoutException =>
+                logWarning(s"While handling commitFiles, timeout after $flushTimeout s.")
+              case throwable: Throwable =>
+                logError("While handling commitFiles, exception occurs.", throwable)
+            }
+          } else {
+            // finish, cancel timeout job first.
+            timeout.cancel()
+            reply()
+          }
+        }
+      }, asyncReplyPool)) // should not use commitThreadPool in case of block by commit files.
+    } else {
+      // If both of two futures are null, then reply directly.
+      reply()
+    }
+  }
+
+  private def handleDestroy(
+      context: RpcCallContext,
+      shuffleKey: String,
+      masterLocations: jList[String],
+      slaveLocations: jList[String]): Unit = {
+    // check whether shuffleKey has been registered
+    if (!partitionLocationInfo.containsShuffle(shuffleKey)) {
+      logWarning(s"Shuffle $shuffleKey not registered!")
+      context.reply(DestroyResponse(
+        StatusCode.ShuffleNotRegistered, masterLocations, slaveLocations))
+      return
+    }
+
+    val failedMasters = new jArrayList[String]()
+    val failedSlaves = new jArrayList[String]()
+
+    // destroy master locations
+    if (masterLocations != null && !masterLocations.isEmpty) {
+      masterLocations.asScala.foreach { loc =>
+        val allocatedLoc = partitionLocationInfo.getMasterLocation(shuffleKey, loc)
+        if (allocatedLoc == null) {
+          failedMasters.add(loc)
+        } else {
+          allocatedLoc.asInstanceOf[WorkingPartition].getFileWriter.destroy()
+        }
+      }
+      // remove master locations from WorkerInfo
+      partitionLocationInfo.removeMasterPartitions(shuffleKey, masterLocations)
+    }
+    // destroy slave locations
+    if (slaveLocations != null && !slaveLocations.isEmpty) {
+      slaveLocations.asScala.foreach { loc =>
+        val allocatedLoc = partitionLocationInfo.getSlaveLocation(shuffleKey, loc)
+        if (allocatedLoc == null) {
+          failedSlaves.add(loc)
+        } else {
+          allocatedLoc.asInstanceOf[WorkingPartition].getFileWriter.destroy()
+        }
+      }
+      // remove slave locations from worker info
+      partitionLocationInfo.removeSlavePartitions(shuffleKey, slaveLocations)
+    }
+    // reply
+    if (failedMasters.isEmpty && failedSlaves.isEmpty) {
+      logInfo(s"Destroyed ${masterLocations.size()} master locations and ${slaveLocations.size()}" +
+        s" slave locations for $shuffleKey successfully.")
+      context.reply(DestroyResponse(StatusCode.Success, null, null))
+    } else {
+      logInfo(s"Failed to destroy ${failedMasters.size()}/${masterLocations.size()} master locations and" +
+        s" ${failedSlaves.size()}/${slaveLocations.size()} slave locations for" +
+        s" $shuffleKey, replying PartialSuccess.")
+      context.reply(DestroyResponse(StatusCode.PartialSuccess, failedMasters, failedSlaves))
+    }
+  }
+
+  private def handleGetWorkerInfos(context: RpcCallContext): Unit = {
+    val list = new jArrayList[WorkerInfo]()
+    list.add(workerInfo)
+    context.reply(GetWorkerInfosResponse(StatusCode.Success, list.asScala.toList: _*))
+  }
+
+  private def handleThreadDump(context: RpcCallContext): Unit = {
+    val threadDump = Utils.getThreadDump()
+    context.reply(ThreadDumpResponse(threadDump))
+  }
+
+  def isRegistered(): Boolean = {
+    registered.get()
+  }
+}
diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/FetchHandler.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/FetchHandler.scala
new file mode 100644
index 00000000000..5ca139c7fdd
--- /dev/null
+++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/FetchHandler.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.aliyun.emr.rss.service.deploy.worker
+
+import java.io.FileNotFoundException
+import java.io.IOException
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.atomic.AtomicBoolean
+
+import com.google.common.base.Throwables
+import io.netty.util.concurrent.{Future, GenericFutureListener}
+
+import com.aliyun.emr.rss.common.exception.RssException
+import com.aliyun.emr.rss.common.internal.Logging
+import com.aliyun.emr.rss.common.metrics.source.NetWorkSource
+import com.aliyun.emr.rss.common.network.buffer.NioManagedBuffer
+import com.aliyun.emr.rss.common.network.client.RpcResponseCallback
+import com.aliyun.emr.rss.common.network.client.TransportClient
+import com.aliyun.emr.rss.common.network.protocol._
+import com.aliyun.emr.rss.common.network.server.{BaseMessageHandler, FileInfo}
+import com.aliyun.emr.rss.common.network.server.FileManagedBuffers
+import com.aliyun.emr.rss.common.network.server.OneForOneStreamManager
+import com.aliyun.emr.rss.common.network.util.NettyUtils
+import com.aliyun.emr.rss.common.network.util.TransportConf
+
+class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logging {
+  var streamManager = new OneForOneStreamManager()
+  var source: WorkerSource = _
+  var localStorageManager: LocalStorageManager = _
+  var partitionsSorter: PartitionFilesSorter = _
+  var registered: AtomicBoolean = _
+
+  def init(worker: Worker): Unit = {
+    this.source = worker.workerSource
+    this.localStorageManager = worker.localStorageManager
+    this.partitionsSorter = worker.partitionsSorter
+    this.registered = worker.registered
+  }
+
+  def openStream(
+      shuffleKey: String,
+      fileName: String,
+      startMapIndex: Int,
+      endMapIndex: Int): FileInfo = {
+    // find FileWriter responsible for the data
+    val fileWriter = localStorageManager.getWriter(shuffleKey, fileName)
+    if (fileWriter == null) {
+      logWarning(s"File $fileName for $shuffleKey was not found!")
+      null
+    } else {
+      partitionsSorter.openStream(shuffleKey, fileName, fileWriter, startMapIndex, endMapIndex)
+    }
+  }
+
+  override def receive(client: TransportClient, msg: RequestMessage): Unit = {
+    msg match {
+      case r: ChunkFetchRequest =>
+        handleChunkFetchRequest(client, r)
+      case r: RpcRequest =>
+        handleOpenStream(client, r)
+    }
+  }
+
+  def handleOpenStream(client: TransportClient, request: RpcRequest): Unit = {
+    val msg = AbstractMessage.fromByteBuffer(request.body().nioByteBuffer())
+    val openBlocks = msg.asInstanceOf[OpenStream]
+    val shuffleKey = new String(openBlocks.shuffleKey, StandardCharsets.UTF_8)
+    val fileName = new String(openBlocks.fileName, StandardCharsets.UTF_8)
+    val startMapIndex = openBlocks.startMapIndex
+    val endMapIndex = openBlocks.endMapIndex
+    // metrics start
+    source.startTimer(WorkerSource.OpenStreamTime, shuffleKey)
+    val fileInfo = openStream(shuffleKey, fileName, startMapIndex, endMapIndex)
+
+    if (fileInfo != null) {
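+      // The requested file exists: register a stream over its chunks and reply with a
+      // StreamHandle(streamId, numChunks); the client then pulls chunks one at a time.
+      // A hedged sketch of the client side of this exchange (sendRpcSync is an assumed
+      // TransportClient helper; the production retry path lives in RetryingChunkClient):
+      //   val open   = new OpenStream(shuffleKey.getBytes, fileName.getBytes,
+      //                               startMapIndex, endMapIndex)
+      //   val resp   = client.sendRpcSync(open.toByteBuffer, timeoutMs)
+      //   val handle = AbstractMessage.fromByteBuffer(resp).asInstanceOf[StreamHandle]
+      //   (0 until handle.numChunks).foreach { i =>
+      //     /* send ChunkFetchRequest(StreamChunkSlice(handle.streamId, i)) */
+      //   }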
logDebug(s"Received chunk fetch request $shuffleKey $fileName" + + s"$startMapIndex $endMapIndex get file info $fileInfo") + try { + val buffers = new FileManagedBuffers(fileInfo, conf) + val streamId = streamManager.registerStream(client.getClientId, buffers, client.getChannel) + val streamHandle = new StreamHandle(streamId, fileInfo.numChunks) + if (fileInfo.numChunks == 0) { + logDebug(s"StreamId $streamId fileName $fileName startMapIndex" + + s" $startMapIndex endMapIndex $endMapIndex is empty.") + } + client.getChannel.writeAndFlush(new RpcResponse(request.requestId, + new NioManagedBuffer(streamHandle.toByteBuffer))) + } catch { + case e: IOException => + client.getChannel.writeAndFlush(new RpcFailure(request.requestId, + Throwables.getStackTraceAsString(new RssException("Chunk offsets meta exception ", e)))) + } finally { + // metrics end + source.stopTimer(WorkerSource.OpenStreamTime, shuffleKey) + request.body().release() + } + } else { + source.stopTimer(WorkerSource.OpenStreamTime, shuffleKey) + client.getChannel.writeAndFlush(new RpcFailure(request.requestId, + Throwables.getStackTraceAsString(new FileNotFoundException))) + } + } + + def handleChunkFetchRequest(client: TransportClient, req: ChunkFetchRequest): Unit = { + source.startTimer(NetWorkSource.FetchChunkTime, req.toString) + logTrace(s"Received req from ${NettyUtils.getRemoteAddress(client.getChannel)}" + + s" to fetch block ${req.streamChunkSlice}") + + val chunksBeingTransferred = streamManager.chunksBeingTransferred + if (chunksBeingTransferred >= conf.maxChunksBeingTransferred) { + logError(s"The number of chunks being transferred $chunksBeingTransferred" + + s"is above ${conf.maxChunksBeingTransferred()}.") + source.stopTimer(NetWorkSource.FetchChunkTime, req.toString) + } else { + try { + streamManager.checkAuthorization(client, req.streamChunkSlice.streamId) + + val buf = streamManager.getChunk(req.streamChunkSlice.streamId, + req.streamChunkSlice.chunkIndex, req.streamChunkSlice.offset, req.streamChunkSlice.len) + streamManager.chunkBeingSent(req.streamChunkSlice.streamId) + client.getChannel.writeAndFlush(new ChunkFetchSuccess(req.streamChunkSlice, buf)) + .addListener(new GenericFutureListener[Future[_ >: Void]] { + override def operationComplete(future: Future[_ >: Void]): Unit = { + streamManager.chunkSent(req.streamChunkSlice.streamId) + source.stopTimer(NetWorkSource.FetchChunkTime, req.toString) + } + }) + } catch { + case e: Exception => + logError(String.format(s"Error opening block ${req.streamChunkSlice} for request from" + + s" ${NettyUtils.getRemoteAddress(client.getChannel)}", e)) + client.getChannel.writeAndFlush(new ChunkFetchFailure(req.streamChunkSlice, + Throwables.getStackTraceAsString(e))) + source.stopTimer(NetWorkSource.FetchChunkTime, req.toString) + } + } + } + + override def checkRegistered: Boolean = registered.get + + override def channelInactive(client: TransportClient): Unit = { + logDebug("channel Inactive " + client.getSocketAddress) + } + + override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { + logDebug("exception caught " + cause + " " + client.getSocketAddress) + } +} diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.scala new file mode 100644 index 00000000000..f278b7ef59d --- /dev/null +++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/PushDataHandler.scala @@ -0,0 +1,406 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.aliyun.emr.rss.service.deploy.worker + +import java.nio.ByteBuffer +import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor} +import java.util.concurrent.atomic.AtomicBoolean + +import com.google.common.base.Throwables +import io.netty.buffer.ByteBuf + +import com.aliyun.emr.rss.common.exception.AlreadyClosedException +import com.aliyun.emr.rss.common.internal.Logging +import com.aliyun.emr.rss.common.meta.{PartitionLocationInfo, WorkerInfo} +import com.aliyun.emr.rss.common.network.buffer.{NettyManagedBuffer, NioManagedBuffer} +import com.aliyun.emr.rss.common.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory} +import com.aliyun.emr.rss.common.network.protocol.{PushData, PushMergedData, RequestMessage, RpcFailure, RpcResponse} +import com.aliyun.emr.rss.common.network.server.BaseMessageHandler +import com.aliyun.emr.rss.common.protocol.{PartitionLocation, PartitionSplitMode} +import com.aliyun.emr.rss.common.protocol.message.StatusCode +import com.aliyun.emr.rss.common.unsafe.Platform + +class PushDataHandler extends BaseMessageHandler with Logging { + + var workerSource: WorkerSource = _ + var partitionLocationInfo: PartitionLocationInfo = _ + var shuffleMapperAttempts: ConcurrentHashMap[String, Array[Int]] = _ + var replicateThreadPool: ThreadPoolExecutor = _ + var unavailablePeers: ConcurrentHashMap[WorkerInfo, Long] = _ + var pushClientFactory: TransportClientFactory = _ + var registered: AtomicBoolean = _ + + def init(worker: Worker): Unit = { + workerSource = worker.workerSource + partitionLocationInfo = worker.partitionLocationInfo + shuffleMapperAttempts = worker.shuffleMapperAttempts + replicateThreadPool = worker.replicateThreadPool + unavailablePeers = worker.unavailablePeers + pushClientFactory = worker.pushClientFactory + registered = worker.registered + } + + override def receive(client: TransportClient, msg: RequestMessage): Unit = + msg match { + case pushData: PushData => + try { + handlePushData(pushData, new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + client.getChannel.writeAndFlush(new RpcResponse( + pushData.requestId, new NioManagedBuffer(response))) + } + + override def onFailure(e: Throwable): Unit = { + logError("[processPushData] Process pushData onFailure! 
ShuffleKey: " + + pushData.shuffleKey + ", partitionUniqueId: " + pushData.partitionUniqueId, e) + client.getChannel.writeAndFlush(new RpcFailure(pushData.requestId, e.getMessage)) + } + }) + } catch { + case e: Exception => + logError(s"Error while handlePushData $pushData", e) + client.getChannel.writeAndFlush(new RpcFailure(pushData.requestId, + Throwables.getStackTraceAsString(e))) + } finally { + pushData.body().release() + } + case pushMergedData: PushMergedData => + try { + handlePushMergedData(pushMergedData, new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + client.getChannel.writeAndFlush(new RpcResponse( + pushMergedData.requestId, new NioManagedBuffer(response))) + } + + override def onFailure(e: Throwable): Unit = { + logError("[processPushMergedData] Process PushMergedData onFailure! ShuffleKey: " + + pushMergedData.shuffleKey + + ", partitionUniqueId: " + pushMergedData.partitionUniqueIds.mkString(","), e) + client.getChannel.writeAndFlush( + new RpcFailure(pushMergedData.requestId, e.getMessage())) + } + }) + } catch { + case e: Exception => + logError(s"Error while handlePushMergedData $pushMergedData", e); + client.getChannel.writeAndFlush(new RpcFailure(pushMergedData.requestId, + Throwables.getStackTraceAsString(e))); + } finally { + pushMergedData.body().release() + } + } + + def handlePushData(pushData: PushData, callback: RpcResponseCallback): Unit = { + val shuffleKey = pushData.shuffleKey + val mode = PartitionLocation.getMode(pushData.mode) + val body = pushData.body.asInstanceOf[NettyManagedBuffer].getBuf + val isMaster = mode == PartitionLocation.Mode.Master + + val key = s"${pushData.requestId}" + if (isMaster) { + workerSource.startTimer(WorkerSource.MasterPushDataTime, key) + } else { + workerSource.startTimer(WorkerSource.SlavePushDataTime, key) + } + + // find FileWriter responsible for the data + val location = if (isMaster) { + partitionLocationInfo.getMasterLocation(shuffleKey, pushData.partitionUniqueId) + } else { + partitionLocationInfo.getSlaveLocation(shuffleKey, pushData.partitionUniqueId) + } + + val wrappedCallback = new RpcResponseCallback() { + override def onSuccess(response: ByteBuffer): Unit = { + if (isMaster) { + workerSource.stopTimer(WorkerSource.MasterPushDataTime, key) + if (response.remaining() > 0) { + val resp = ByteBuffer.allocate(response.remaining()) + resp.put(response) + resp.flip() + callback.onSuccess(resp) + } else { + callback.onSuccess(response) + } + } else { + workerSource.stopTimer(WorkerSource.SlavePushDataTime, key) + callback.onSuccess(response) + } + } + + override def onFailure(e: Throwable): Unit = { + logError(s"[handlePushData.onFailure] partitionLocation: $location") + workerSource.incCounter(WorkerSource.PushDataFailCount) + callback.onFailure(new Exception(StatusCode.PushDataFailSlave.getMessage(), e)) + } + } + + if (location == null) { + val (mapId, attemptId) = getMapAttempt(body) + if (shuffleMapperAttempts.containsKey(shuffleKey) && + -1 != shuffleMapperAttempts.get(shuffleKey)(mapId)) { + // partition data has already been committed + logInfo(s"Receive push data from speculative task(shuffle $shuffleKey, map $mapId, " + + s" attempt $attemptId), but this mapper has already been ended.") + wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.StageEnded.getValue))) + } else { + val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId, " + + s"attempt $attemptId, uniqueId ${pushData.partitionUniqueId})." 
+ logWarning(s"[handlePushData] $msg") + callback.onFailure(new Exception(StatusCode.PushDataFailPartitionNotFound.getMessage())) + } + return + } + val fileWriter = location.asInstanceOf[WorkingPartition].getFileWriter + val exception = fileWriter.getException + if (exception != null) { + logWarning(s"[handlePushData] fileWriter $fileWriter has Exception $exception") + val message = if (isMaster) { + StatusCode.PushDataFailMain.getMessage() + } else { + StatusCode.PushDataFailSlave.getMessage() + } + callback.onFailure(new Exception(message, exception)) + return + } + if (isMaster && fileWriter.getFileLength > fileWriter.getSplitThreshold()) { + fileWriter.setSplitFlag() + if (fileWriter.getSplitMode == PartitionSplitMode.soft) { + callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.SoftSplit.getValue))) + } else { + callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HardSplit.getValue))) + return + } + } + fileWriter.incrementPendingWrites() + + // for master, send data to slave + if (location.getPeer != null && isMaster) { + pushData.body().retain() + replicateThreadPool.submit(new Runnable { + override def run(): Unit = { + val peer = location.getPeer + val peerWorker = new WorkerInfo(peer.getHost, peer.getRpcPort, peer.getPushPort, + peer.getFetchPort, peer.getReplicatePort, -1, null) + if (unavailablePeers.containsKey(peerWorker)) { + pushData.body().release() + wrappedCallback.onFailure(new Exception(s"Peer $peerWorker unavailable!")) + return + } + try { + val client = pushClientFactory.createClient(peer.getHost, peer.getReplicatePort, + location.getReduceId) + val newPushData = new PushData( + PartitionLocation.Mode.Slave.mode(), + shuffleKey, + pushData.partitionUniqueId, + pushData.body) + client.pushData(newPushData, wrappedCallback) + } catch { + case e: Exception => + pushData.body().release() + unavailablePeers.put(peerWorker, System.currentTimeMillis()) + wrappedCallback.onFailure(e) + } + } + }) + } else { + wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]())) + } + + try { + fileWriter.write(body) + } catch { + case e: AlreadyClosedException => + fileWriter.decrementPendingWrites() + val (mapId, attemptId) = getMapAttempt(body) + val endedAttempt = if (shuffleMapperAttempts.containsKey(shuffleKey)) { + shuffleMapperAttempts.get(shuffleKey)(mapId) + } else -1 + logWarning(s"Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" + + s" $attemptId), caused by ${e.getMessage}") + case e: Exception => + logError("Exception encountered when write.", e) + } + } + + def handlePushMergedData( + pushMergedData: PushMergedData, + callback: RpcResponseCallback): Unit = { + val shuffleKey = pushMergedData.shuffleKey + val mode = PartitionLocation.getMode(pushMergedData.mode) + val batchOffsets = pushMergedData.batchOffsets + val body = pushMergedData.body.asInstanceOf[NettyManagedBuffer].getBuf + val isMaster = mode == PartitionLocation.Mode.Master + + val key = s"${pushMergedData.requestId}" + if (isMaster) { + workerSource.startTimer(WorkerSource.MasterPushDataTime, key) + } else { + workerSource.startTimer(WorkerSource.SlavePushDataTime, key) + } + + val wrappedCallback = new RpcResponseCallback() { + override def onSuccess(response: ByteBuffer): Unit = { + if (isMaster) { + workerSource.stopTimer(WorkerSource.MasterPushDataTime, key) + if (response.remaining() > 0) { + val resp = ByteBuffer.allocate(response.remaining()) + resp.put(response) + resp.flip() + callback.onSuccess(resp) + } else { + callback.onSuccess(response) + } + } else { + 
workerSource.stopTimer(WorkerSource.SlavePushDataTime, key) + callback.onSuccess(response) + } + } + + override def onFailure(e: Throwable): Unit = { + workerSource.incCounter(WorkerSource.PushDataFailCount) + callback.onFailure(new Exception(StatusCode.PushDataFailSlave.getMessage, e)) + } + } + + // find FileWriters responsible for the data + val locations = pushMergedData.partitionUniqueIds.map { id => + val loc = if (isMaster) { + partitionLocationInfo.getMasterLocation(shuffleKey, id) + } else { + partitionLocationInfo.getSlaveLocation(shuffleKey, id) + } + if (loc == null) { + val (mapId, attemptId) = getMapAttempt(body) + if (shuffleMapperAttempts.containsKey(shuffleKey) + && -1 != shuffleMapperAttempts.get(shuffleKey)(mapId)) { + val msg = s"Receive push data from speculative task(shuffle $shuffleKey, map $mapId," + + s" attempt $attemptId), but this mapper has already been ended." + logInfo(msg) + wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.StageEnded.getValue))) + } else { + val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId," + + s" attempt $attemptId, uniqueId $id)." + logWarning(s"[handlePushMergedData] $msg") + wrappedCallback.onFailure(new Exception(msg)) + } + return + } + loc + } + + val fileWriters = locations.map(_.asInstanceOf[WorkingPartition].getFileWriter) + val fileWriterWithException = fileWriters.find(_.getException != null) + if (fileWriterWithException.nonEmpty) { + val exception = fileWriterWithException.get.getException + logDebug(s"[handlePushMergedData] fileWriter ${fileWriterWithException}" + + s" has Exception $exception") + val message = if (isMaster) { + StatusCode.PushDataFailMain.getMessage() + } else { + StatusCode.PushDataFailSlave.getMessage() + } + callback.onFailure(new Exception(message, exception)) + return + } + fileWriters.foreach(_.incrementPendingWrites()) + + // for master, send data to slave + if (locations.head.getPeer != null && isMaster) { + pushMergedData.body().retain() + replicateThreadPool.submit(new Runnable { + override def run(): Unit = { + val location = locations.head + val peer = location.getPeer + val peerWorker = new WorkerInfo(peer.getHost, + peer.getRpcPort, peer.getPushPort, peer.getFetchPort, peer.getReplicatePort, -1, null) + if (unavailablePeers.containsKey(peerWorker)) { + pushMergedData.body().release() + wrappedCallback.onFailure(new Exception(s"Peer $peerWorker unavailable!")) + return + } + try { + val client = pushClientFactory.createClient( + peer.getHost, peer.getReplicatePort, location.getReduceId) + val newPushMergedData = new PushMergedData( + PartitionLocation.Mode.Slave.mode(), + shuffleKey, + pushMergedData.partitionUniqueIds, + batchOffsets, + pushMergedData.body) + client.pushMergedData(newPushMergedData, wrappedCallback) + } catch { + case e: Exception => + pushMergedData.body().release() + unavailablePeers.put(peerWorker, System.currentTimeMillis()) + wrappedCallback.onFailure(e) + } + } + }) + } else { + wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]())) + } + + var index = 0 + var fileWriter: FileWriter = null + var alreadyClosed = false + while (index < fileWriters.length) { + fileWriter = fileWriters(index) + val offset = body.readerIndex() + batchOffsets(index) + val length = if (index == fileWriters.length - 1) { + body.readableBytes() - batchOffsets(index) + } else { + batchOffsets(index + 1) - batchOffsets(index) + } + val batchBody = body.slice(offset, length) + + try { + if (!alreadyClosed) { + fileWriter.write(batchBody) + } else 
{ + fileWriter.decrementPendingWrites() + } + } catch { + case e: AlreadyClosedException => + fileWriter.decrementPendingWrites() + alreadyClosed = true + val (mapId, attemptId) = getMapAttempt(body) + val endedAttempt = if (shuffleMapperAttempts.containsKey(shuffleKey)) { + shuffleMapperAttempts.get(shuffleKey)(mapId) + } else -1 + logWarning(s"Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" + + s" $attemptId), caused by ${e.getMessage}") + case e: Exception => + logError("Exception encountered when write.", e) + } + index += 1 + } + } + + private def getMapAttempt(body: ByteBuf): (Int, Int) = { + // header: mapId attemptId batchId compressedTotalSize + val header = new Array[Byte](8) + body.getBytes(body.readerIndex(), header) + val mapId = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET) + val attemptId = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET + 4) + (mapId, attemptId) + } + + override def checkRegistered(): Boolean = registered.get() +} diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala index c86296729d8..537335cff51 100644 --- a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala +++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/Worker.scala @@ -17,49 +17,50 @@ package com.aliyun.emr.rss.service.deploy.worker -import java.io.IOException -import java.nio.ByteBuffer -import java.util.{ArrayList => jArrayList, HashSet => jHashSet, List => jList} -import java.util.concurrent.{CancellationException, CompletableFuture, ConcurrentHashMap, ExecutionException, LinkedBlockingQueue, ScheduledFuture, TimeoutException, TimeUnit} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} -import java.util.function.BiFunction +import java.util.{ArrayList => jArrayList, HashSet => jHashSet} +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ -import io.netty.buffer.ByteBuf -import io.netty.util.{HashedWheelTimer, Timeout, TimerTask} -import io.netty.util.internal.ConcurrentSet +import io.netty.util.HashedWheelTimer import com.aliyun.emr.rss.common.RssConf -import com.aliyun.emr.rss.common.RssConf.{memoryTrimActionThreshold, partitionSortMaxMemoryRatio, partitionSortTimeout, workerDirectMemoryPressureCheckIntervalMs, workerDirectMemoryReportIntervalSecond, workerPausePushDataRatio, workerPauseRepcaliteRatio, workerResumeRatio} -import com.aliyun.emr.rss.common.exception.{AlreadyClosedException, RssException} +import com.aliyun.emr.rss.common.RssConf._ +import com.aliyun.emr.rss.common.exception.RssException import com.aliyun.emr.rss.common.haclient.RssHARetryClient import com.aliyun.emr.rss.common.internal.Logging import com.aliyun.emr.rss.common.meta.{PartitionLocationInfo, WorkerInfo} import com.aliyun.emr.rss.common.metrics.MetricsSystem import com.aliyun.emr.rss.common.metrics.source.{JVMCPUSource, JVMSource, NetWorkSource} import com.aliyun.emr.rss.common.network.TransportContext -import com.aliyun.emr.rss.common.network.buffer.NettyManagedBuffer -import com.aliyun.emr.rss.common.network.client.{RpcResponseCallback, TransportClientBootstrap} -import com.aliyun.emr.rss.common.network.protocol.{PushData, PushMergedData} -import com.aliyun.emr.rss.common.network.server.{ChannelsLimiter, FileInfo, MemoryTracker, TransportServerBootstrap} -import com.aliyun.emr.rss.common.protocol.{PartitionLocation, 
PartitionSplitMode, RpcNameConstants, TransportModuleConstants} -import com.aliyun.emr.rss.common.protocol.PartitionLocation.StorageHint +import com.aliyun.emr.rss.common.network.client.TransportClientBootstrap +import com.aliyun.emr.rss.common.network.server.{ChannelsLimiter, MemoryTracker, TransportServerBootstrap} +import com.aliyun.emr.rss.common.protocol.{RpcNameConstants, TransportModuleConstants} import com.aliyun.emr.rss.common.protocol.message.ControlMessages._ -import com.aliyun.emr.rss.common.protocol.message.StatusCode import com.aliyun.emr.rss.common.rpc._ -import com.aliyun.emr.rss.common.unsafe.Platform import com.aliyun.emr.rss.common.util.{ThreadUtils, Utils} import com.aliyun.emr.rss.server.common.http.{HttpServer, HttpServerInitializer} import com.aliyun.emr.rss.service.deploy.worker.http.HttpRequestHandler private[deploy] class Worker( - override val rpcEnv: RpcEnv, val conf: RssConf, - val metricsSystem: MetricsSystem) - extends RpcEndpoint with PushDataHandler with OpenStreamHandler with Registerable with Logging { + val workerArgs: WorkerArguments) extends Logging { - private val workerSource = { + val rpcEnv = RpcEnv.create( + RpcNameConstants.WORKER_SYS, + workerArgs.host, + workerArgs.host, + workerArgs.port, + conf, + Math.max(64, Runtime.getRuntime.availableProcessors())) + + private val host = rpcEnv.address.host + private val rpcPort = rpcEnv.address.port + Utils.checkHost(host) + + val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, WorkerSource.ServletPath) + val workerSource = { val source = new WorkerSource(conf) metricsSystem.registerSource(source) metricsSystem.registerSource(new NetWorkSource(conf, MetricsSystem.ROLE_WOKRER)) @@ -68,6 +69,8 @@ private[deploy] class Worker( source } + val localStorageManager = new LocalStorageManager(conf, workerSource, this) + val memoryTracker = MemoryTracker.initialize( workerPausePushDataRatio(conf), workerPauseRepcaliteRatio(conf), @@ -76,24 +79,25 @@ private[deploy] class Worker( workerDirectMemoryPressureCheckIntervalMs(conf), workerDirectMemoryReportIntervalSecond(conf), memoryTrimActionThreshold(conf)) - - private val localStorageManager = new LocalStorageManager(conf, workerSource, this) memoryTracker.registerMemoryListener(localStorageManager) - private val partitionsSorter = new PartitionFilesSorter(memoryTracker, + val partitionsSorter = new PartitionFilesSorter(memoryTracker, partitionSortTimeout(conf), RssConf.workerFetchChunkSize(conf), RssConf.memoryReservedForSingleSort(conf), workerSource) - private val (pushServer, pushClientFactory) = { + var controller = new Controller(rpcEnv, conf, metricsSystem) + rpcEnv.setupEndpoint(RpcNameConstants.WORKER_EP, controller) + + val pushDataHandler = new PushDataHandler() + val (pushServer, pushClientFactory) = { val closeIdleConnections = RssConf.closeIdleConnections(conf) val numThreads = conf.getInt("rss.push.io.threads", localStorageManager.numDisks * 2) val transportConf = Utils.fromRssConf(conf, TransportModuleConstants.PUSH_MODULE, numThreads) - val rpcHandler = new PushDataRpcHandler(transportConf, this) val pushServerLimiter = new ChannelsLimiter(TransportModuleConstants.PUSH_MODULE) val transportContext: TransportContext = - new TransportContext(transportConf, rpcHandler, closeIdleConnections, workerSource, + new TransportContext(transportConf, pushDataHandler, closeIdleConnections, workerSource, pushServerLimiter) val serverBootstraps = new jArrayList[TransportServerBootstrap]() val clientBootstraps = new 
jArrayList[TransportClientBootstrap]() @@ -101,58 +105,74 @@ private[deploy] class Worker( transportContext.createClientFactory(clientBootstraps)) } + val replicateHandler = new PushDataHandler() private val replicateServer = { val closeIdleConnections = RssConf.closeIdleConnections(conf) val numThreads = conf.getInt("rss.replicate.io.threads", localStorageManager.numDisks * 2) val transportConf = Utils.fromRssConf(conf, TransportModuleConstants.REPLICATE_MODULE, numThreads) - val rpcHandler = new PushDataRpcHandler(transportConf, this) val replicateLimiter = new ChannelsLimiter(TransportModuleConstants.REPLICATE_MODULE) val transportContext: TransportContext = - new TransportContext(transportConf, rpcHandler, closeIdleConnections, workerSource, + new TransportContext(transportConf, replicateHandler, closeIdleConnections, workerSource, replicateLimiter) val serverBootstraps = new jArrayList[TransportServerBootstrap]() transportContext.createServer(RssConf.replicateServerPort(conf), serverBootstraps) } + var fetchHandler: FetchHandler = _ private val fetchServer = { val closeIdleConnections = RssConf.closeIdleConnections(conf) val numThreads = conf.getInt("rss.fetch.io.threads", localStorageManager.numDisks * 2) val transportConf = Utils.fromRssConf(conf, TransportModuleConstants.FETCH_MODULE, numThreads) - val rpcHandler = new ChunkFetchRpcHandler(transportConf, workerSource, this) + fetchHandler = new FetchHandler(transportConf) val transportContext: TransportContext = - new TransportContext(transportConf, rpcHandler, closeIdleConnections, workerSource) + new TransportContext(transportConf, fetchHandler, closeIdleConnections, workerSource) val serverBootstraps = new jArrayList[TransportServerBootstrap]() transportContext.createServer(RssConf.fetchServerPort(conf), serverBootstraps) } - private val host = rpcEnv.address.host - private val rpcPort = rpcEnv.address.port private val pushPort = pushServer.getPort private val fetchPort = fetchServer.getPort private val replicatePort = replicateServer.getPort - Utils.checkHost(host) assert(pushPort > 0) assert(fetchPort > 0) assert(replicatePort > 0) + // worker info + val workerInfo = new WorkerInfo(host, rpcPort, pushPort, fetchPort, replicatePort, + RssConf.workerNumSlots(conf, localStorageManager.numDisks), controller.self) + // whether this Worker registered to Master succesfully - private val registered = new AtomicBoolean(false) + val registered = new AtomicBoolean(false) - private val shuffleMapperAttempts = new ConcurrentHashMap[String, Array[Int]]() + val shuffleMapperAttempts = new ConcurrentHashMap[String, Array[Int]]() + val partitionLocationInfo = new PartitionLocationInfo private val rssHARetryClient = new RssHARetryClient(rpcEnv, conf) - // worker info - private val workerInfo = new WorkerInfo(host, rpcPort, pushPort, fetchPort, replicatePort, - RssConf.workerNumSlots(conf, localStorageManager.numDisks), self) + // (workerInfo -> last connect timeout timestamp) + val unavailablePeers = new ConcurrentHashMap[WorkerInfo, Long]() - private val partitionLocationInfo = new PartitionLocationInfo + // Threads + private val forwardMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + private var logAvailableFlushBuffersTask: ScheduledFuture[_] = _ + private var sendHeartbeatTask: ScheduledFuture[_] = _ + private var checkFastfailTask: ScheduledFuture[_] = _ + val replicateThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "worker-replicate-data", 
RssConf.workerReplicateNumThreads(conf))
+  val commitThreadPool = ThreadUtils.newDaemonCachedThreadPool(
+    "Worker-CommitFiles", RssConf.workerAsyncCommitFileThreads(conf))
+  val asyncReplyPool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("async-reply")
+  val timer = new HashedWheelTimer()
 
-  private val replicateFastfailDuration = RssConf.replicateFastFailDurationMs(conf)
-  // (workerInfo -> last connect timeout timestamp)
-  private val unavailablePeers = new ConcurrentHashMap[WorkerInfo, Long]()
+  // Configs
+  private val HEARTBEAT_MILLIS = RssConf.workerTimeoutMs(conf) / 4
+  private val REPLICATE_FAST_FAIL_DURATION = RssConf.replicateFastFailDurationMs(conf)
+
+  private val cleanTaskQueue = new LinkedBlockingQueue[jHashSet[String]]
+  var cleaner: Thread = _
 
   workerSource.addGauge(
     WorkerSource.RegisteredShuffleCount, _ => partitionLocationInfo.shuffleKeySet.size())
@@ -167,25 +187,6 @@ private[deploy] class Worker(
   workerSource.addGauge(WorkerSource.PausePushDataAndReplicateCount,
     _ => memoryTracker.getPausePushDataAndReplicateCounter)
 
-  // Threads
-  private val forwardMessageScheduler =
-    ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler")
-  private val replicateThreadPool = ThreadUtils.newDaemonCachedThreadPool(
-    "worker-replicate-data", RssConf.workerReplicateNumThreads(conf))
-  private val timer = new HashedWheelTimer()
-
-  // Configs
-  private val HEARTBEAT_MILLIS = RssConf.workerTimeoutMs(conf) / 4
-
-  // shared ExecutorService for flush
-  private val commitThreadPool = ThreadUtils.newDaemonCachedThreadPool(
-    "Worker-CommitFiles", RssConf.workerAsyncCommitFileThreads(conf))
-  private val asyncReplyPool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("async-reply")
-
-  private var logAvailableFlushBuffersTask: ScheduledFuture[_] = _
-  private var sendHeartbeatTask: ScheduledFuture[_] = _
-  private var checkFastfailTask: ScheduledFuture[_] = _
-
   def updateNumSlots(numSlots: Int): Unit = {
     workerInfo.setNumSlots(numSlots)
     heartBeatToMaster()
@@ -199,13 +200,24 @@ private[deploy] class Worker(
       HeartbeatFromWorker(host, rpcPort, pushPort, fetchPort, replicatePort, workerInfo.numSlots,
         shuffleKeys)
       , classOf[HeartbeatResponse])
-    cleanTaskQueue.put(response.expiredShuffleKeys)
-    if (!response.registered) {
-      logError("Current worker not registered in master")
+    if (response.registered) {
+      cleanTaskQueue.put(response.expiredShuffleKeys)
+    } else {
+      logError("Worker is not registered with master; cleaning all shuffle data and registering again.")
+      // Clean all shuffle related metadata and data
+      cleanup(shuffleKeys)
+      try {
+        registerWithMaster()
+      } catch {
+        case e: Throwable =>
+          logError("Re-registering with master failed after worker was lost.", e)
+          // Registration failed, so stop the server
+          controller.stop()
+      }
    }
  }
 
-  override def onStart(): Unit = {
+  def init(): Unit = {
     logInfo(s"Starting Worker $host:$pushPort:$fetchPort:$replicatePort" +
       s" with ${workerInfo.numSlots} slots.")
     registerWithMaster()
@@ -221,15 +233,62 @@ private[deploy] class Worker(
       override def run(): Unit = Utils.tryLogNonFatalError {
         unavailablePeers.entrySet().asScala.foreach(entry => {
           if (System.currentTimeMillis() - entry.getValue >
-            replicateFastfailDuration) {
+            REPLICATE_FAST_FAIL_DURATION) {
             unavailablePeers.remove(entry.getKey)
           }
         })
       }
-    }, 0, replicateFastfailDuration, TimeUnit.MILLISECONDS)
+    }, 0, REPLICATE_FAST_FAIL_DURATION, TimeUnit.MILLISECONDS)
+
+    if (RssConf.metricsSystemEnable(conf)) {
+      logInfo(s"Metrics system enabled!")
+      metricsSystem.start()
+
+      var port = RssConf.workerPrometheusMetricPort(conf)
+      var initialized = false
+      while (!initialized) {
+        try {
+          val httpServer = new HttpServer(
+            new HttpServerInitializer(
+              new HttpRequestHandler(metricsSystem.getPrometheusHandler)), port)
+          httpServer.start()
+          initialized = true
+        } catch {
+          case e: Exception =>
+            logWarning(s"HttpServer port $port may already be in use, trying port ${port + 1}.", e)
+            port += 1
+            Thread.sleep(1000)
+        }
+      }
+    }
+
+    cleaner = new Thread("Cleaner") {
+      override def run(): Unit = {
+        while (true) {
+          val expiredShuffleKeys = cleanTaskQueue.take()
+          try {
+            cleanup(expiredShuffleKeys)
+          } catch {
+            case e: Throwable =>
+              logError("Cleanup failed", e)
+          }
+        }
+      }
+    }
+
+    pushDataHandler.init(this)
+    replicateHandler.init(this)
+    fetchHandler.init(this)
+    controller.init(this)
+
+    cleaner.setDaemon(true)
+    cleaner.start()
+
+    rpcEnv.awaitTermination()
+    stop()
   }
 
-  override def onStop(): Unit = {
+  def stop(): Unit = {
     logInfo("Stopping RSS Worker.")
 
     if (sendHeartbeatTask != null) {
@@ -261,648 +320,7 @@ private[deploy] class Worker(
     logInfo("RSS Worker is stopped.")
   }
 
-  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-    case ReserveSlots(applicationId, shuffleId, masterLocations, slaveLocations, splitThreashold,
-      splitMode, storageHint) =>
-      val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
-      workerSource.sample(WorkerSource.ReserveSlotsTime, shuffleKey) {
-        logInfo(s"Received ReserveSlots request, $shuffleKey," +
-          s" master number: ${masterLocations.size()}, slave number: ${slaveLocations.size()}")
-        logDebug(s"Received ReserveSlots request, $shuffleKey, " +
-          s"master partitions: ${masterLocations.asScala.map(_.getUniqueId).mkString(",")}; " +
-          s"slave partitions: ${slaveLocations.asScala.map(_.getUniqueId).mkString(",")}.")
-        handleReserveSlots(context, applicationId, shuffleId, masterLocations,
-          slaveLocations, splitThreashold, splitMode, storageHint)
-        logDebug(s"ReserveSlots for $shuffleKey succeed.")
-      }
-
-    case CommitFiles(applicationId, shuffleId, masterIds, slaveIds, mapAttempts) =>
-      val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
-      workerSource.sample(WorkerSource.CommitFilesTime, shuffleKey) {
-        logDebug(s"Received CommitFiles request, $shuffleKey, master files" +
-          s" ${masterIds.asScala.mkString(",")}; slave files ${slaveIds.asScala.mkString(",")}.")
-        val commitFilesTimeMs = Utils.timeIt({
-          handleCommitFiles(context, shuffleKey, masterIds, slaveIds, mapAttempts)
-        })
-        logDebug(s"Done processed CommitFiles request with shuffleKey $shuffleKey, in " +
-          s"${commitFilesTimeMs}ms.")
-      }
-
-    case GetWorkerInfos =>
-      logDebug("Received GetWorkerInfos request.")
-      handleGetWorkerInfos(context)
-
-    case ThreadDump =>
-      logDebug("Receive ThreadDump request.")
-      handleThreadDump(context)
-
-    case Destroy(shuffleKey, masterLocations, slaveLocations) =>
-      logDebug(s"Receive Destroy request, $shuffleKey.")
-      handleDestroy(context, shuffleKey, masterLocations, slaveLocations)
-  }
-
-  private def handleReserveSlots(
-      context: RpcCallContext,
-      applicationId: String,
-      shuffleId: Int,
-      masterLocations: jList[PartitionLocation],
-      slaveLocations: jList[PartitionLocation],
-      splitThreshold: Long,
-      splitMode: PartitionSplitMode,
-      storageHint: StorageHint): Unit = {
-    val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
-    if (!localStorageManager.hasAvailableWorkingDirs) {
-      val msg = "Local storage has no available dirs!"
- logError(s"[handleReserveSlots] $msg") - context.reply(ReserveSlotsResponse(StatusCode.ReserveSlotFailed, msg)) - return - } - val masterPartitions = new jArrayList[PartitionLocation]() - try { - for (ind <- 0 until masterLocations.size()) { - val location = masterLocations.get(ind) - val writer = localStorageManager.createWriter(applicationId, shuffleId, location, - splitThreshold, splitMode) - masterPartitions.add(new WorkingPartition(location, writer)) - } - } catch { - case e: Exception => - logError(s"CreateWriter for $shuffleKey failed.", e) - } - if (masterPartitions.size() < masterLocations.size()) { - val msg = s"Not all master partition satisfied for $shuffleKey" - logWarning(s"[handleReserveSlots] $msg, will destroy writers.") - masterPartitions.asScala.foreach(_.asInstanceOf[WorkingPartition].getFileWriter.destroy()) - context.reply(ReserveSlotsResponse(StatusCode.ReserveSlotFailed, msg)) - return - } - - val slavePartitions = new jArrayList[PartitionLocation]() - try { - for (ind <- 0 until slaveLocations.size()) { - val location = slaveLocations.get(ind) - val writer = localStorageManager.createWriter(applicationId, shuffleId, - location, splitThreshold, splitMode) - slavePartitions.add(new WorkingPartition(location, writer)) - } - } catch { - case e: Exception => - logError(s"CreateWriter for $shuffleKey failed.", e) - } - if (slavePartitions.size() < slaveLocations.size()) { - val msg = s"Not all slave partition satisfied for $shuffleKey" - logWarning(s"[handleReserveSlots] $msg, destroy writers.") - masterPartitions.asScala.foreach(_.asInstanceOf[WorkingPartition].getFileWriter.destroy()) - slavePartitions.asScala.foreach(_.asInstanceOf[WorkingPartition].getFileWriter.destroy()) - context.reply(ReserveSlotsResponse(StatusCode.ReserveSlotFailed, msg)) - return - } - - // reserve success, update status - partitionLocationInfo.addMasterPartitions(shuffleKey, masterPartitions) - partitionLocationInfo.addSlavePartitions(shuffleKey, slavePartitions) - workerInfo.allocateSlots(shuffleKey, masterPartitions.size() + slavePartitions.size()) - logInfo(s"Reserved ${masterPartitions.size()} master location and ${slavePartitions.size()}" + - s" slave location for $shuffleKey master: ${masterPartitions}\nslave: ${slavePartitions}.") - context.reply(ReserveSlotsResponse(StatusCode.Success)) - } - - private def commitFiles( - shuffleKey: String, - uniqueIds: jList[String], - committedIds: ConcurrentSet[String], - failedIds: ConcurrentSet[String], - master: Boolean = true): CompletableFuture[Void] = { - var future: CompletableFuture[Void] = null - - if (uniqueIds != null) { - uniqueIds.asScala.foreach { uniqueId => - val task = CompletableFuture.runAsync(new Runnable { - override def run(): Unit = { - logDebug(s"Committing $shuffleKey $uniqueId") - try { - val location = if (master) { - partitionLocationInfo.getMasterLocation(shuffleKey, uniqueId) - } else { - partitionLocationInfo.getSlaveLocation(shuffleKey, uniqueId) - } - - if (location == null) { - logWarning(s"Get Partition Location for $shuffleKey $uniqueId but didn't exist.") - return - } - - val fileWriter = location.asInstanceOf[WorkingPartition].getFileWriter - val bytes = fileWriter.close() - if (bytes > 0L) { - logDebug(s"FileName ${fileWriter.getFile.getAbsoluteFile}, size $bytes") - committedIds.add(uniqueId) - } - } catch { - case e: IOException => - logError(s"Commit file for $shuffleKey $uniqueId failed.", e) - failedIds.add(uniqueId) - } - } - }, commitThreadPool) - - if (future == null) { - future = task - } else { - 
future = CompletableFuture.allOf(future, task) - } - } - } - - future - } - - private def handleCommitFiles( - context: RpcCallContext, - shuffleKey: String, - masterIds: jList[String], - slaveIds: jList[String], - mapAttempts: Array[Int]): Unit = { - // return null if shuffleKey does not exist - if (!partitionLocationInfo.containsShuffle(shuffleKey)) { - logError(s"Shuffle $shuffleKey doesn't exist!") - context.reply(CommitFilesResponse( - StatusCode.ShuffleNotRegistered, new jArrayList[String](), new jArrayList[String](), - masterIds, slaveIds)) - return - } - - logDebug(s"[handleCommitFiles] ${shuffleKey} -> ${mapAttempts.mkString(",")}") - // close and flush files. - shuffleMapperAttempts.putIfAbsent(shuffleKey, mapAttempts) - - // Use ConcurrentSet to avoid excessive lock contention. - val committedMasterIds = new ConcurrentSet[String]() - val committedSlaveIds = new ConcurrentSet[String]() - val failedMasterIds = new ConcurrentSet[String]() - val failedSlaveIds = new ConcurrentSet[String]() - - val masterFuture = commitFiles(shuffleKey, masterIds, committedMasterIds, failedMasterIds) - val slaveFuture = commitFiles(shuffleKey, slaveIds, committedSlaveIds, failedSlaveIds, false) - - val future = if (masterFuture != null && slaveFuture != null) { - CompletableFuture.allOf(masterFuture, slaveFuture) - } else if (masterFuture != null) { - masterFuture - } else if (slaveFuture != null) { - slaveFuture - } else { - null - } - - def reply(): Unit = { - // release slots before reply. - val numSlotsReleased = - partitionLocationInfo.removeMasterPartitions(shuffleKey, masterIds) + - partitionLocationInfo.removeSlavePartitions(shuffleKey, slaveIds) - workerInfo.releaseSlots(shuffleKey, numSlotsReleased) - - val committedMasterIdList = new jArrayList[String](committedMasterIds) - val committedSlaveIdList = new jArrayList[String](committedSlaveIds) - val failedMasterIdList = new jArrayList[String](failedMasterIds) - val failedSlaveIdList = new jArrayList[String](failedSlaveIds) - // reply - if (failedMasterIds.isEmpty && failedSlaveIds.isEmpty) { - logInfo(s"CommitFiles for $shuffleKey success with ${committedMasterIds.size()}" + - s" master partitions and ${committedSlaveIds.size()} slave partitions!") - context.reply(CommitFilesResponse( - StatusCode.Success, committedMasterIdList, committedSlaveIdList, - new jArrayList[String](), new jArrayList[String]())) - } else { - logWarning(s"CommitFiles for $shuffleKey failed with ${failedMasterIds.size()} master" + - s" partitions and ${failedSlaveIds.size()} slave partitions!") - context.reply(CommitFilesResponse(StatusCode.PartialSuccess, committedMasterIdList, - committedSlaveIdList, failedMasterIdList, failedSlaveIdList)) - } - } - - if (future != null) { - val result = new AtomicReference[CompletableFuture[Unit]]() - val flushTimeout = RssConf.flushTimeout(conf) - - val timeout = timer.newTimeout(new TimerTask { - override def run(timeout: Timeout): Unit = { - if (result.get() != null) { - result.get().cancel(true) - logWarning(s"After waiting $flushTimeout s, cancel all commit file jobs.") - } - } - }, flushTimeout, TimeUnit.SECONDS) - - result.set(future.handleAsync(new BiFunction[Void, Throwable, Unit] { - override def apply(v: Void, t: Throwable): Unit = { - if (null != t) { - t match { - case _: CancellationException => - logWarning("While handling commitFiles, canceled.") - case ee: ExecutionException => - logError("While handling commitFiles, ExecutionException raised.", ee) - case ie: InterruptedException => - logWarning("While handling 
commitFiles, interrupted.") - Thread.currentThread().interrupt() - throw ie - case _: TimeoutException => - logWarning(s"While handling commitFiles, timeout after $flushTimeout s.") - case throwable: Throwable => - logError("While handling commitFiles, exception occurs.", throwable) - } - } else { - // finish, cancel timeout job first. - timeout.cancel() - logDebug(s"Handle commitFiles successfully $shuffleKey, reply message.") - reply() - } - } - }, asyncReplyPool)) // should not use commitThreadPool in case of block by commit files. - } else { - logDebug(s"All future is null, reply directly for $shuffleKey.") - // If both of two futures are null, then reply directly. - reply() - } - } - - private def handleDestroy( - context: RpcCallContext, - shuffleKey: String, - masterLocations: jList[String], - slaveLocations: jList[String]): Unit = { - // check whether shuffleKey has registered - if (!partitionLocationInfo.containsShuffle(shuffleKey)) { - logWarning(s"Shuffle $shuffleKey not registered!") - context.reply(DestroyResponse( - StatusCode.ShuffleNotRegistered, masterLocations, slaveLocations)) - return - } - - val failedMasters = new jArrayList[String]() - val failedSlaves = new jArrayList[String]() - - // destroy master locations - if (masterLocations != null && !masterLocations.isEmpty) { - masterLocations.asScala.foreach { loc => - val allocatedLoc = partitionLocationInfo.getMasterLocation(shuffleKey, loc) - if (allocatedLoc == null) { - failedMasters.add(loc) - } else { - allocatedLoc.asInstanceOf[WorkingPartition].getFileWriter.destroy() - } - } - // remove master locations from WorkerInfo - partitionLocationInfo.removeMasterPartitions(shuffleKey, masterLocations) - } - // destroy slave locations - if (slaveLocations != null && !slaveLocations.isEmpty) { - slaveLocations.asScala.foreach { loc => - val allocatedLoc = partitionLocationInfo.getSlaveLocation(shuffleKey, loc) - if (allocatedLoc == null) { - failedSlaves.add(loc) - } else { - allocatedLoc.asInstanceOf[WorkingPartition].getFileWriter.destroy() - } - } - // remove slave locations from worker info - partitionLocationInfo.removeSlavePartitions(shuffleKey, slaveLocations) - } - // reply - if (failedMasters.isEmpty && failedSlaves.isEmpty) { - logInfo(s"Destroy ${masterLocations.size()} master location and ${slaveLocations.size()}" + - s" slave locations for $shuffleKey successfully.") - context.reply(DestroyResponse(StatusCode.Success, null, null)) - } else { - logInfo(s"Destroy ${failedMasters.size()}/${masterLocations.size()} master location and" + - s"${failedSlaves.size()}/${slaveLocations.size()} slave location for" + - s" $shuffleKey PartialSuccess.") - context.reply(DestroyResponse(StatusCode.PartialSuccess, failedMasters, failedSlaves)) - } - } - - private def handleGetWorkerInfos(context: RpcCallContext): Unit = { - val list = new jArrayList[WorkerInfo]() - list.add(workerInfo) - context.reply(GetWorkerInfosResponse(StatusCode.Success, list.asScala.toList: _*)) - } - - private def handleThreadDump(context: RpcCallContext): Unit = { - val threadDump = Utils.getThreadDump() - context.reply(ThreadDumpResponse(threadDump)) - } - - private def getMapAttempt(body: ByteBuf): (Int, Int) = { - // header: mapId attemptId batchId compressedTotalSize - val header = new Array[Byte](8) - body.getBytes(body.readerIndex(), header) - val mapId = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET) - val attemptId = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET + 4) - (mapId, attemptId) - } - - override def handlePushData(pushData: 
PushData, callback: RpcResponseCallback): Unit = { - val shuffleKey = pushData.shuffleKey - val mode = PartitionLocation.getMode(pushData.mode) - val body = pushData.body.asInstanceOf[NettyManagedBuffer].getBuf - val isMaster = mode == PartitionLocation.Mode.Master - val bodySize = pushData.body().size() - - val key = s"${pushData.requestId}" - if (isMaster) { - workerSource.startTimer(WorkerSource.MasterPushDataTime, key) - } else { - workerSource.startTimer(WorkerSource.SlavePushDataTime, key) - } - - // find FileWriter responsible for the data - val location = if (isMaster) { - partitionLocationInfo.getMasterLocation(shuffleKey, pushData.partitionUniqueId) - } else { - partitionLocationInfo.getSlaveLocation(shuffleKey, pushData.partitionUniqueId) - } - - val wrappedCallback = new RpcResponseCallback() { - override def onSuccess(response: ByteBuffer): Unit = { - if (isMaster) { - workerSource.stopTimer(WorkerSource.MasterPushDataTime, key) - if (response.remaining() > 0) { - val resp = ByteBuffer.allocate(response.remaining()) - resp.put(response) - resp.flip() - callback.onSuccess(resp) - } else { - callback.onSuccess(response) - } - } else { - workerSource.stopTimer(WorkerSource.SlavePushDataTime, key) - callback.onSuccess(response) - } - } - - override def onFailure(e: Throwable): Unit = { - logError(s"[handlePushData.onFailure] partitionLocation: $location") - workerSource.incCounter(WorkerSource.PushDataFailCount) - callback.onFailure(new Exception(StatusCode.PushDataFailSlave.getMessage(), e)) - } - } - - if (location == null) { - val (mapId, attemptId) = getMapAttempt(body) - if (shuffleMapperAttempts.containsKey(shuffleKey) && - -1 != shuffleMapperAttempts.get(shuffleKey)(mapId)) { - // partition data has already been committed - logInfo(s"Receive push data from speculative task(shuffle $shuffleKey, map $mapId, " + - s" attempt $attemptId), but this mapper has already been ended.") - wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.StageEnded.getValue))) - } else { - val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId, " + - s"attempt $attemptId, uniqueId ${pushData.partitionUniqueId})." 
- logWarning(s"[handlePushData] $msg") - callback.onFailure(new Exception(StatusCode.PushDataFailPartitionNotFound.getMessage())) - } - return - } - val fileWriter = location.asInstanceOf[WorkingPartition].getFileWriter - val exception = fileWriter.getException - if (exception != null) { - logWarning(s"[handlePushData] fileWriter $fileWriter has Exception $exception") - val message = if (isMaster) { - StatusCode.PushDataFailMain.getMessage() - } else { - StatusCode.PushDataFailSlave.getMessage() - } - callback.onFailure(new Exception(message, exception)) - return - } - if (isMaster && fileWriter.getFileLength > fileWriter.getSplitThreshold()) { - fileWriter.setSplitFlag() - if (fileWriter.getSplitMode == PartitionSplitMode.soft) { - callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.SoftSplit.getValue))) - } else { - callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HardSplit.getValue))) - return - } - } - fileWriter.incrementPendingWrites() - - // for master, send data to slave - if (location.getPeer != null && isMaster) { - pushData.body().retain() - replicateThreadPool.submit(new Runnable { - override def run(): Unit = { - val peer = location.getPeer - val peerWorker = new WorkerInfo(peer.getHost, peer.getRpcPort, peer.getPushPort, - peer.getFetchPort, peer.getReplicatePort, -1, null) - if (unavailablePeers.containsKey(peerWorker)) { - pushData.body().release() - wrappedCallback.onFailure(new Exception(s"Peer $peerWorker unavailable!")) - return - } - try { - val client = pushClientFactory.createClient(peer.getHost, peer.getReplicatePort, - location.getReduceId) - val newPushData = new PushData( - PartitionLocation.Mode.Slave.mode(), - shuffleKey, - pushData.partitionUniqueId, - pushData.body) - client.pushData(newPushData, wrappedCallback) - } catch { - case e: Exception => - pushData.body().release() - unavailablePeers.put(peerWorker, System.currentTimeMillis()) - wrappedCallback.onFailure(e) - } - } - }) - } else { - wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]())) - } - - try { - fileWriter.write(body) - } catch { - case e: AlreadyClosedException => - fileWriter.decrementPendingWrites() - val (mapId, attemptId) = getMapAttempt(body) - val endedAttempt = if (shuffleMapperAttempts.containsKey(shuffleKey)) { - shuffleMapperAttempts.get(shuffleKey)(mapId) - } else -1 - logWarning(s"Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" + - s" $attemptId), caused by ${e.getMessage}") - case e: Exception => - logError("Exception encountered when write.", e) - } - } - - override def handlePushMergedData( - pushMergedData: PushMergedData, callback: RpcResponseCallback): Unit = { - val shuffleKey = pushMergedData.shuffleKey - val mode = PartitionLocation.getMode(pushMergedData.mode) - val batchOffsets = pushMergedData.batchOffsets - val body = pushMergedData.body.asInstanceOf[NettyManagedBuffer].getBuf - val isMaster = mode == PartitionLocation.Mode.Master - val bodySize = pushMergedData.body().size() - - val key = s"${pushMergedData.requestId}" - if (isMaster) { - workerSource.startTimer(WorkerSource.MasterPushDataTime, key) - } else { - workerSource.startTimer(WorkerSource.SlavePushDataTime, key) - } - - val wrappedCallback = new RpcResponseCallback() { - override def onSuccess(response: ByteBuffer): Unit = { - if (isMaster) { - workerSource.stopTimer(WorkerSource.MasterPushDataTime, key) - if (response.remaining() > 0) { - val resp = ByteBuffer.allocate(response.remaining()) - resp.put(response) - resp.flip() - callback.onSuccess(resp) - } else { - 
callback.onSuccess(response) - } - } else { - workerSource.stopTimer(WorkerSource.SlavePushDataTime, key) - callback.onSuccess(response) - } - } - - override def onFailure(e: Throwable): Unit = { - workerSource.incCounter(WorkerSource.PushDataFailCount) - callback.onFailure(new Exception(StatusCode.PushDataFailSlave.getMessage, e)) - } - } - - // find FileWriters responsible for the data - val locations = pushMergedData.partitionUniqueIds.map { id => - val loc = if (isMaster) { - partitionLocationInfo.getMasterLocation(shuffleKey, id) - } else { - partitionLocationInfo.getSlaveLocation(shuffleKey, id) - } - if (loc == null) { - val (mapId, attemptId) = getMapAttempt(body) - if (shuffleMapperAttempts.containsKey(shuffleKey) - && -1 != shuffleMapperAttempts.get(shuffleKey)(mapId)) { - val msg = s"Receive push data from speculative task(shuffle $shuffleKey, map $mapId," + - s" attempt $attemptId), but this mapper has already been ended." - logInfo(msg) - wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.StageEnded.getValue))) - } else { - val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId," + - s" attempt $attemptId, uniqueId $id)." - logWarning(s"[handlePushMergedData] $msg") - wrappedCallback.onFailure(new Exception(msg)) - } - return - } - loc - } - - val fileWriters = locations.map(_.asInstanceOf[WorkingPartition].getFileWriter) - val fileWriterWithException = fileWriters.find(_.getException != null) - if (fileWriterWithException.nonEmpty) { - val exception = fileWriterWithException.get.getException - logDebug(s"[handlePushMergedData] fileWriter ${fileWriterWithException}" + - s" has Exception $exception") - val message = if (isMaster) { - StatusCode.PushDataFailMain.getMessage() - } else { - StatusCode.PushDataFailSlave.getMessage() - } - callback.onFailure(new Exception(message, exception)) - return - } - fileWriters.foreach(_.incrementPendingWrites()) - - // for master, send data to slave - if (locations.head.getPeer != null && isMaster) { - pushMergedData.body().retain() - replicateThreadPool.submit(new Runnable { - override def run(): Unit = { - val location = locations.head - val peer = location.getPeer - val peerWorker = new WorkerInfo(peer.getHost, - peer.getRpcPort, peer.getPushPort, peer.getFetchPort, peer.getReplicatePort, -1, null) - if (unavailablePeers.containsKey(peerWorker)) { - pushMergedData.body().release() - wrappedCallback.onFailure(new Exception(s"Peer $peerWorker unavailable!")) - return - } - try { - val client = pushClientFactory.createClient( - peer.getHost, peer.getReplicatePort, location.getReduceId) - val newPushMergedData = new PushMergedData( - PartitionLocation.Mode.Slave.mode(), - shuffleKey, - pushMergedData.partitionUniqueIds, - batchOffsets, - pushMergedData.body) - client.pushMergedData(newPushMergedData, wrappedCallback) - } catch { - case e: Exception => - pushMergedData.body().release() - unavailablePeers.put(peerWorker, System.currentTimeMillis()) - wrappedCallback.onFailure(e) - } - } - }) - } else { - wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]())) - } - - var index = 0 - var fileWriter: FileWriter = null - var alreadyClosed = false - while (index < fileWriters.length) { - fileWriter = fileWriters(index) - val offset = body.readerIndex() + batchOffsets(index) - val length = if (index == fileWriters.length - 1) { - body.readableBytes() - batchOffsets(index) - } else { - batchOffsets(index + 1) - batchOffsets(index) - } - val batchBody = body.slice(offset, length) - - try { - if 
(!alreadyClosed) {
-          fileWriter.write(batchBody)
-        } else {
-          fileWriter.decrementPendingWrites()
-        }
-      } catch {
-        case e: AlreadyClosedException =>
-          fileWriter.decrementPendingWrites()
-          alreadyClosed = true
-          val (mapId, attemptId) = getMapAttempt(body)
-          val endedAttempt = if (shuffleMapperAttempts.containsKey(shuffleKey)) {
-            shuffleMapperAttempts.get(shuffleKey)(mapId)
-          } else -1
-          logWarning(s"Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" +
-            s" $attemptId), caused by ${e.getMessage}")
-        case e: Exception =>
-          logError("Exception encountered when write.", e)
-      }
-      index += 1
-    }
-  }
-
-  override def handleOpenStream(shuffleKey: String, fileName: String, startMapIndex: Int,
-    endMapIndex: Int): FileInfo = {
-    // find FileWriter responsible for the data
-    val fileWriter = localStorageManager.getWriter(shuffleKey, fileName)
-    if (fileWriter eq null) {
-      logWarning(s"File $fileName for $shuffleKey was not found!")
-      return null
-    }
-    partitionsSorter.openStream(shuffleKey, fileName, fileWriter, startMapIndex, endMapIndex);
-  }
-  private def registerWithMaster() {
-    logDebug("Trying to register with master.")
     var registerTimeout = RssConf.registerWorkerTimeoutMs(conf)
     val delta = 2000
     while (registerTimeout > 0) {
@@ -931,24 +349,7 @@ private[deploy] class Worker(
     throw new RssException("Register worker failed.")
   }
 
-  private val cleanTaskQueue = new LinkedBlockingQueue[jHashSet[String]]
-  private val cleaner = new Thread("Cleaner") {
-    override def run(): Unit = {
-      while (true) {
-        val expiredShuffleKeys = cleanTaskQueue.take()
-        try {
-          cleanup(expiredShuffleKeys)
-        } catch {
-          case e: Throwable =>
-            logError("Cleanup failed", e)
-        }
-      }
-    }
-  }
-  cleaner.setDaemon(true)
-  cleaner.start()
-
-  private def cleanup(expiredShuffleKeys: jHashSet[String]): Unit = {
+  private def cleanup(expiredShuffleKeys: jHashSet[String]): Unit = synchronized {
     expiredShuffleKeys.asScala.foreach { shuffleKey =>
       partitionLocationInfo.getAllMasterLocations(shuffleKey).asScala.foreach { partition =>
         partition.asInstanceOf[WorkingPartition].getFileWriter.destroy()
@@ -959,10 +360,9 @@ private[deploy] class Worker(
       partitionLocationInfo.removeMasterPartitions(shuffleKey)
       partitionLocationInfo.removeSlavePartitions(shuffleKey)
       shuffleMapperAttempts.remove(shuffleKey)
-      partitionsSorter.cleanup(expiredShuffleKeys)
       logInfo(s"Cleaned up expired shuffle $shuffleKey")
     }
-
+    partitionsSorter.cleanup(expiredShuffleKeys)
     localStorageManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
   }
 
@@ -983,39 +383,7 @@ private[deploy] object Worker extends Logging {
       conf.set("rss.master.address", RpcAddress.fromRssURL(workerArgs.master).toString)
     }
 
-    val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, WorkerSource.ServletPath)
-
-    val rpcEnv = RpcEnv.create(
-      RpcNameConstants.WORKER_SYS,
-      workerArgs.host,
-      workerArgs.host,
-      workerArgs.port,
-      conf,
-      Math.max(64, Runtime.getRuntime.availableProcessors()))
-    rpcEnv.setupEndpoint(RpcNameConstants.WORKER_EP, new Worker(rpcEnv, conf, metricsSystem))
-
-    if (RssConf.metricsSystemEnable(conf)) {
-      logInfo(s"Metrics system enabled!")
-      metricsSystem.start()
-
-      var port = RssConf.workerPrometheusMetricPort(conf)
-      var initialized = false
-      while (!initialized) {
-        try {
-          val httpServer = new HttpServer(
-            new HttpServerInitializer(
-              new HttpRequestHandler(metricsSystem.getPrometheusHandler)), port)
-          httpServer.start()
-          initialized = true
-        } catch {
-          case e: Exception =>
-            logWarning(s"HttpServer pushPort $port may already exist, try 
pushPort ${port + 1}.", e) - port += 1 - Thread.sleep(1000) - } - } - } - - rpcEnv.awaitTermination() + val worker = new Worker(conf, workerArgs) + worker.init() } } diff --git a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala index 5a1a71f739a..23ae90b1328 100644 --- a/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala +++ b/server-worker/src/main/scala/com/aliyun/emr/rss/service/deploy/worker/WorkerArguments.scala @@ -25,7 +25,7 @@ import com.aliyun.emr.rss.common.util.{IntParam, Utils} class WorkerArguments(args: Array[String], conf: RssConf) { var host = Utils.localHostName() - var port = 0 + var port = RssConf.workerRPCPort(conf) // var master: String = null // for local testing. var master: String = null diff --git a/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java b/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java index 7e1d2e897c0..0e525baec54 100644 --- a/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java +++ b/server-worker/src/test/java/com/aliyun/emr/rss/service/deploy/worker/FileWriterSuiteJ.java @@ -42,12 +42,14 @@ import static org.junit.Assert.fail; import com.aliyun.emr.rss.common.RssConf; -import com.aliyun.emr.rss.common.metrics.source.AbstractSource; import com.aliyun.emr.rss.common.network.TransportContext; import com.aliyun.emr.rss.common.network.buffer.ManagedBuffer; import com.aliyun.emr.rss.common.network.client.ChunkReceivedCallback; import com.aliyun.emr.rss.common.network.client.TransportClient; import com.aliyun.emr.rss.common.network.client.TransportClientFactory; +import com.aliyun.emr.rss.common.network.protocol.AbstractMessage; +import com.aliyun.emr.rss.common.network.protocol.OpenStream; +import com.aliyun.emr.rss.common.network.protocol.StreamHandle; import com.aliyun.emr.rss.common.network.server.FileInfo; import com.aliyun.emr.rss.common.network.server.MemoryTracker; import com.aliyun.emr.rss.common.network.server.TransportServer; @@ -70,7 +72,7 @@ public class FileWriterSuiteJ { private static File tempDir = null; private static DiskFlusher flusher = null; - private static AbstractSource source = null; + private static WorkerSource source = null; private static TransportServer server; private static TransportClientFactory clientFactory; @@ -84,7 +86,7 @@ public class FileWriterSuiteJ { public static void beforeAll() { tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "rss"); - source = Mockito.mock(AbstractSource.class); + source = Mockito.mock(WorkerSource.class); Mockito.doAnswer(invocationOnMock -> { Function0 function = (Function0) invocationOnMock.getArguments()[2]; return function.apply(); @@ -96,8 +98,26 @@ public static void beforeAll() { } public static void setupChunkServer(FileInfo info) throws Exception { - ChunkFetchRpcHandler handler = new ChunkFetchRpcHandler(transConf, source, - new OpenStreamer(info)); + FetchHandler handler = new FetchHandler(transConf) { + @Override + public FileInfo openStream( + String shuffleKey, + String fileName, + int startMapIndex, + int endMapIndex) { + return info; + } + + @Override + public WorkerSource source() { + return source; + } + + @Override + public boolean checkRegistered() { + return true; + } + }; TransportContext context = new TransportContext(transConf, handler); server = 
context.createServer(); @@ -133,45 +153,21 @@ public void releaseBuffers() { } } - static class OpenStreamer implements OpenStreamHandler, Registerable { - - private FileInfo fileInfo = null; - - OpenStreamer(FileInfo info) { - this.fileInfo = info; - } - - @Override - public boolean isRegistered() { - return true; - } - - @Override - public FileInfo handleOpenStream(String shuffleKey, String partitionId, int startMapIndex, - int endMapIndex) { - return fileInfo; - } - } - public ByteBuffer createOpenMessage() { byte[] shuffleKeyBytes = "shuffleKey".getBytes(StandardCharsets.UTF_8); byte[] fileNameBytes = "location".getBytes(StandardCharsets.UTF_8); - ByteBuffer openMessage = ByteBuffer.allocate( - 4 + shuffleKeyBytes.length + 4 + fileNameBytes.length + 8 + 8); - openMessage.putInt(shuffleKeyBytes.length); - openMessage.put(shuffleKeyBytes); - openMessage.putInt(fileNameBytes.length); - openMessage.put(fileNameBytes); - openMessage.putInt(0); - openMessage.putInt(Integer.MAX_VALUE); - openMessage.flip(); - return openMessage; + + OpenStream openBlocks = new OpenStream(shuffleKeyBytes, fileNameBytes, + 0, Integer.MAX_VALUE); + + return openBlocks.toByteBuffer(); } private void setUpConn(TransportClient client) { ByteBuffer resp = client.sendRpcSync(createOpenMessage(), 10000); - streamId = resp.getLong(); - numChunks = resp.getInt(); + StreamHandle streamHandle = (StreamHandle) AbstractMessage.fromByteBuffer(resp); + streamId = streamHandle.streamId; + numChunks = streamHandle.numChunks; } private FetchResult fetchChunks(TransportClient client, diff --git a/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/MiniClusterFeature.scala b/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/MiniClusterFeature.scala index 9d8a4be4d03..053dcd65a44 100644 --- a/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/MiniClusterFeature.scala +++ b/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/MiniClusterFeature.scala @@ -18,9 +18,7 @@ package com.aliyun.emr.rss.service.deploy import java.util.concurrent.atomic.AtomicInteger - import io.netty.channel.ChannelFuture - import com.aliyun.emr.rss.common.RssConf import com.aliyun.emr.rss.common.internal.Logging import com.aliyun.emr.rss.common.metrics.MetricsSystem @@ -28,7 +26,7 @@ import com.aliyun.emr.rss.common.rpc.RpcEnv import com.aliyun.emr.rss.common.protocol.RpcNameConstants import com.aliyun.emr.rss.server.common.http.{HttpServer, HttpServerInitializer} import com.aliyun.emr.rss.service.deploy.master.{Master, MasterArguments, MasterSource} -import com.aliyun.emr.rss.service.deploy.worker.{Worker, WorkerArguments, WorkerSource} +import com.aliyun.emr.rss.service.deploy.worker.{Controller, Worker, WorkerArguments, WorkerSource} trait MiniClusterFeature extends Logging { val workerPromethusPort = new AtomicInteger(12378) @@ -114,8 +112,7 @@ trait MiniClusterFeature extends Logging { workerArguments.port, conf, 4) - val worker = new Worker(rpcEnv, conf, metricsSystem) - rpcEnv.setupEndpoint(RpcNameConstants.WORKER_EP, worker) + val worker = new Worker(conf, workerArguments) var channelFuture: ChannelFuture = null if (RssConf.metricsSystemEnable(conf)) { @@ -151,23 +148,23 @@ trait MiniClusterFeature extends Logging { Thread.sleep(5000L) val (worker1, workerRpcEnv1, workerMetric1) = createWorker(workerConfs) - val workerThread1 = runnerWrap(workerRpcEnv1.awaitTermination()) + val workerThread1 = runnerWrap(worker1.init()) workerThread1.start() val (worker2, workerRpcEnv2, workerMetric2) = 
createWorker(workerConfs)
-    val workerThread2 = runnerWrap(workerRpcEnv2.awaitTermination())
+    val workerThread2 = runnerWrap(worker2.init())
     workerThread2.start()
 
     val (worker3, workerRpcEnv3, workerMetric3) = createWorker(workerConfs)
-    val workerThread3 = runnerWrap(workerRpcEnv3.awaitTermination())
+    val workerThread3 = runnerWrap(worker3.init())
     workerThread3.start()
 
     val (worker4, workerRpcEnv4, workerMetric4) = createWorker(workerConfs)
-    val workerThread4 = runnerWrap(workerRpcEnv4.awaitTermination())
+    val workerThread4 = runnerWrap(worker4.init())
     workerThread4.start()
 
     val (worker5, workerRpcEnv5, workerMetric5) = createWorker(workerConfs)
-    val workerThread5 = runnerWrap(workerRpcEnv5.awaitTermination())
+    val workerThread5 = runnerWrap(worker5.init())
     workerThread5.start()
 
     Thread.sleep(5000L)
diff --git a/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/SparkTestBase.scala b/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/SparkTestBase.scala
index f87842b3fcb..0c9b053f72b 100644
--- a/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/SparkTestBase.scala
+++ b/server-worker/src/test/scala/com/aliyun/emr/rss/service/deploy/SparkTestBase.scala
@@ -75,9 +75,9 @@ trait SparkTestBase extends Logging with MiniClusterFeature {
     val (worker2, workerRpcEnv2, workerMetric2) = createWorker()
     val (worker3, workerRpcEnv3, workerMetric3) = createWorker()
     val masterThread = runnerWrap(masterRpcEnv.awaitTermination())
-    val workerThread1 = runnerWrap(workerRpcEnv1.awaitTermination())
-    val workerThread2 = runnerWrap(workerRpcEnv2.awaitTermination())
-    val workerThread3 = runnerWrap(workerRpcEnv3.awaitTermination())
+    val workerThread1 = runnerWrap(worker1.init())
+    val workerThread2 = runnerWrap(worker2.init())
+    val workerThread3 = runnerWrap(worker3.init())
 
     masterThread.start()
     Thread.sleep(5000L)

From 42a0870be0491022f35bb1ce800409bd8a5462eb Mon Sep 17 00:00:00 2001
From: "haiming@lccomputing.com"
Date: Tue, 12 Jul 2022 14:13:38 +0800
Subject: [PATCH 4/4] fix memory leak: release unclaimed RPC response body in
 TransportResponseHandler

When an RpcResponse arrives for a request that is no longer outstanding
(for example, one that already failed or timed out and was removed from
outstandingRpcs), the handler logged a warning but never released the
response body, leaking the underlying Netty buffer. Release the body
explicitly before dropping the response.

---
 .../emr/rss/common/network/client/TransportResponseHandler.java | 1 +
 1 file changed, 1 insertion(+)

diff --git a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java
index 7931e4014f1..5771ca84a62 100644
--- a/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java
+++ b/common/src/main/java/com/aliyun/emr/rss/common/network/client/TransportResponseHandler.java
@@ -153,6 +153,7 @@ public void handle(ResponseMessage message) throws Exception {
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
           resp.requestId, NettyUtils.getRemoteAddress(channel), resp.body().size());
+        resp.body().release();
       } else {
         outstandingRpcs.remove(resp.requestId);
         try {