diff --git a/pom.xml b/pom.xml
index c3f40bb297..33b8869742 100644
--- a/pom.xml
+++ b/pom.xml
@@ -59,6 +59,7 @@
${user.home}/clover.license
32.0.1-jre
3.2.3.3.2.3.3-2
+ 4.3.1
4.1.94.Final
0.13.0
1.19
@@ -848,7 +849,7 @@
org.mockito
mockito-core
- 4.3.1
+ ${mockito-core.version}
org.apache.commons
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java
index bcb7bb58ea..f718aea8d6 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java
@@ -171,8 +171,6 @@ public void logProgress(boolean updateOnClose) {
private void processDataMovementEvent(DataMovementEvent dme, DataMovementEventPayloadProto shufflePayload, BitSet emptyPartitionsBitSet) throws IOException {
int srcIndex = dme.getSourceIndex();
- String hostIdentifier = shufflePayload.getHost() + ":" + shufflePayload.getPort();
-
if (LOG.isDebugEnabled()) {
LOG.debug("DME srcIdx: " + srcIndex + ", targetIndex: " + dme.getTargetIndex()
+ ", attemptNum: " + dme.getVersion() + ", payload: " + ShuffleUtils
@@ -198,20 +196,7 @@ private void processDataMovementEvent(DataMovementEvent dme, DataMovementEventPa
CompositeInputAttemptIdentifier srcAttemptIdentifier = constructInputAttemptIdentifier(dme.getTargetIndex(), 1, dme.getVersion(),
shufflePayload, (useSharedInputs && srcIndex == 0));
- if (shufflePayload.hasData()) {
- DataProto dataProto = shufflePayload.getData();
-
- FetchedInput fetchedInput =
- inputAllocator.allocate(dataProto.getRawLength(),
- dataProto.getCompressedLength(), srcAttemptIdentifier);
- moveDataToFetchedInput(dataProto, fetchedInput, hostIdentifier);
- shuffleManager.addCompletedInputWithData(srcAttemptIdentifier, fetchedInput);
-
- LOG.debug("Payload via DME : " + srcAttemptIdentifier);
- } else {
- shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(),
- srcAttemptIdentifier, srcIndex);
- }
+ processShufflePayload(shufflePayload, srcAttemptIdentifier, srcIndex);
}
private void moveDataToFetchedInput(DataProto dataProto,
@@ -274,7 +259,25 @@ private void processCompositeRoutedDataMovementEvent(CompositeRoutedDataMovement
CompositeInputAttemptIdentifier srcAttemptIdentifier = constructInputAttemptIdentifier(crdme.getTargetIndex(), crdme.getCount(), crdme.getVersion(),
shufflePayload, (useSharedInputs && partitionId == 0));
- shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(), srcAttemptIdentifier, partitionId);
+ processShufflePayload(shufflePayload, srcAttemptIdentifier, partitionId);
+ }
+
+ private void processShufflePayload(DataMovementEventPayloadProto shufflePayload,
+ CompositeInputAttemptIdentifier srcAttemptIdentifier, int srcIndex) throws IOException {
+ if (shufflePayload.hasData()) {
+ DataProto dataProto = shufflePayload.getData();
+ String hostIdentifier = shufflePayload.getHost() + ":" + shufflePayload.getPort();
+ FetchedInput fetchedInput =
+ inputAllocator.allocate(dataProto.getRawLength(),
+ dataProto.getCompressedLength(), srcAttemptIdentifier);
+ moveDataToFetchedInput(dataProto, fetchedInput, hostIdentifier);
+ shuffleManager.addCompletedInputWithData(srcAttemptIdentifier, fetchedInput);
+
+ LOG.debug("Payload via DME : " + srcAttemptIdentifier);
+ } else {
+ shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(),
+ srcAttemptIdentifier, srcIndex);
+ }
}
private void processInputFailedEvent(InputFailedEvent ife) {
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java
index 51dc172df3..229573f70c 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java
@@ -26,12 +26,14 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
+import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.BitSet;
import java.util.Collections;
@@ -42,6 +44,7 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.security.token.Token;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezExecutors;
@@ -56,15 +59,20 @@
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.ExecutionContext;
import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.events.CompositeRoutedDataMovementEvent;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
+import org.apache.tez.runtime.library.common.shuffle.MemoryFetchedInput;
import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.mockito.MockedStatic;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
@@ -342,6 +350,53 @@ public void testPipelinedShuffleEvents_WithEmptyPartitions() throws IOException
verify(inputContext).killSelf(any(), anyString());
}
+ /**
+ * Verify that data movement events with shuffle data are processed properly.
+ *
+ * @throws IOException
+ */
+ @Test(timeout = 5000)
+ public void testDataMovementEventsWithShuffleData() throws IOException {
+ InputContext inputContext = mock(InputContext.class);
+ ShuffleManager shuffleManager = mock(ShuffleManager.class);
+ ShuffleManager compositeFetchShuffleManager = mock(ShuffleManager.class);
+ FetchedInputAllocator inputAllocator = mock(FetchedInputAllocator.class);
+ MemoryFetchedInput memoryFetchedInput = mock(MemoryFetchedInput.class);
+
+ when(memoryFetchedInput.getType()).thenReturn(FetchedInput.Type.MEMORY);
+ when(memoryFetchedInput.getBytes()).thenReturn("data".getBytes());
+ when(inputAllocator.allocate(anyLong(), anyLong(), any(InputAttemptIdentifier.class)))
+ .thenReturn(memoryFetchedInput);
+
+ ShuffleInputEventHandlerImpl eventHandler = new ShuffleInputEventHandlerImpl(inputContext,
+ shuffleManager, inputAllocator, null, true, 4, false);
+
+ ShuffleInputEventHandlerImpl compositeFetchEventHandler = new ShuffleInputEventHandlerImpl(inputContext,
+ compositeFetchShuffleManager, inputAllocator, null, true, 4, true);
+
+ DataMovementEvent dataMovementEvent = (DataMovementEvent) createDataMovementEventWithShuffleData(false);
+ CompositeRoutedDataMovementEvent compositeRoutedDataMovementEvent =
+ (CompositeRoutedDataMovementEvent) createDataMovementEventWithShuffleData(true);
+
+ List eventListWithDme = new LinkedList<>();
+ eventListWithDme.add(dataMovementEvent);
+ eventListWithDme.add(compositeRoutedDataMovementEvent);
+
+ try (MockedStatic shuffleUtils = mockStatic(ShuffleUtils.class)) {
+ shuffleUtils.when(() -> ShuffleUtils
+ .shuffleToMemory(any(byte[].class), any(InputStream.class), anyInt(), anyInt(), any(CompressionCodec.class),
+ anyBoolean(), anyInt(), any(), any(InputAttemptIdentifier.class)))
+ .thenAnswer((Answer) invocation -> null);
+ eventHandler.handleEvents(eventListWithDme);
+ compositeFetchEventHandler.handleEvents(eventListWithDme);
+
+ verify(shuffleManager, times(2))
+ .addCompletedInputWithData(any(InputAttemptIdentifier.class), any(FetchedInput.class));
+ verify(compositeFetchShuffleManager, times(2))
+ .addCompletedInputWithData(any(InputAttemptIdentifier.class), any(FetchedInput.class));
+ }
+ }
+
private Event createDataMovementEvent(boolean addSpillDetails, int srcIdx, int targetIdx,
int spillId, boolean isLastSpill, BitSet emptyPartitions, int numPartitions, int attemptNum)
throws IOException {
@@ -397,4 +452,19 @@ private ByteString createEmptyPartitionByteString(int... emptyPartitions) throws
return emptyPartitionsBytesString;
}
+ private Event createDataMovementEventWithShuffleData(boolean isComposite) {
+ DataMovementEventPayloadProto.Builder builder = DataMovementEventPayloadProto.newBuilder();
+ builder.setHost(HOST);
+ builder.setPort(PORT);
+ builder.setPathComponent(PATH_COMPONENT);
+ ShuffleUserPayloads.DataProto.Builder dataProtoBuilder = ShuffleUserPayloads.DataProto.newBuilder()
+ .setData(ByteString.copyFromUtf8("data"));
+ builder.setData(dataProtoBuilder);
+
+ Event dme = isComposite?
+ CompositeRoutedDataMovementEvent.create(0, 1, 1, 0, builder.build().toByteString().asReadOnlyByteBuffer()):
+ DataMovementEvent.create(0, 1, 0, builder.build().toByteString().asReadOnlyByteBuffer());
+ return dme;
+ }
+
}
diff --git a/tez-runtime-library/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/tez-runtime-library/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker
new file mode 100644
index 0000000000..a258d79ad3
--- /dev/null
+++ b/tez-runtime-library/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker
@@ -0,0 +1,13 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mock-maker-inline