From b20f05d3a56b2abc340c425052cd5b383aad2227 Mon Sep 17 00:00:00 2001 From: Prabhjyot Singh Date: Wed, 20 Nov 2024 12:30:52 +0530 Subject: [PATCH] ODP-2647: TEZ-4450: Shuffle data fetch fails when shuffle data is transferred via CompositeRoutedDataMovementEvent (#243) (Ganesha Shreedhara reviewed by Laszlo Bodor) (#20) (cherry picked from commit 8ebc4b00f1d66ee88475b1e96691f658dab967be) (cherry picked from commit 94da1f226ffb0c8cb89b827506a9921526a848b8) (cherry picked from commit 238df2382f9705e59fac827d117c0e0ec712132a) Co-authored-by: Ganesha Shreedhara --- pom.xml | 3 +- .../impl/ShuffleInputEventHandlerImpl.java | 37 +++++----- .../TestShuffleInputEventHandlerImpl.java | 70 +++++++++++++++++++ .../org.mockito.plugins.MockMaker | 13 ++++ 4 files changed, 105 insertions(+), 18 deletions(-) create mode 100644 tez-runtime-library/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker diff --git a/pom.xml b/pom.xml index d3b9982134..8fc472920f 100644 --- a/pom.xml +++ b/pom.xml @@ -59,6 +59,7 @@ ${user.home}/clover.license 32.0.1-jre 3.2.3.3.2.3.2-204 + 4.3.1 4.1.94.Final 0.13.0 1.19 @@ -848,7 +849,7 @@ org.mockito mockito-core - 4.3.1 + ${mockito-core.version} org.apache.commons diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java index bcb7bb58ea..f718aea8d6 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandlerImpl.java @@ -171,8 +171,6 @@ public void logProgress(boolean updateOnClose) { private void processDataMovementEvent(DataMovementEvent dme, DataMovementEventPayloadProto shufflePayload, BitSet emptyPartitionsBitSet) throws IOException { int srcIndex = dme.getSourceIndex(); - String hostIdentifier = shufflePayload.getHost() + ":" + shufflePayload.getPort(); - if (LOG.isDebugEnabled()) { LOG.debug("DME srcIdx: " + srcIndex + ", targetIndex: " + dme.getTargetIndex() + ", attemptNum: " + dme.getVersion() + ", payload: " + ShuffleUtils @@ -198,20 +196,7 @@ private void processDataMovementEvent(DataMovementEvent dme, DataMovementEventPa CompositeInputAttemptIdentifier srcAttemptIdentifier = constructInputAttemptIdentifier(dme.getTargetIndex(), 1, dme.getVersion(), shufflePayload, (useSharedInputs && srcIndex == 0)); - if (shufflePayload.hasData()) { - DataProto dataProto = shufflePayload.getData(); - - FetchedInput fetchedInput = - inputAllocator.allocate(dataProto.getRawLength(), - dataProto.getCompressedLength(), srcAttemptIdentifier); - moveDataToFetchedInput(dataProto, fetchedInput, hostIdentifier); - shuffleManager.addCompletedInputWithData(srcAttemptIdentifier, fetchedInput); - - LOG.debug("Payload via DME : " + srcAttemptIdentifier); - } else { - shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(), - srcAttemptIdentifier, srcIndex); - } + processShufflePayload(shufflePayload, srcAttemptIdentifier, srcIndex); } private void moveDataToFetchedInput(DataProto dataProto, @@ -274,7 +259,25 @@ private void processCompositeRoutedDataMovementEvent(CompositeRoutedDataMovement CompositeInputAttemptIdentifier srcAttemptIdentifier = constructInputAttemptIdentifier(crdme.getTargetIndex(), crdme.getCount(), crdme.getVersion(), shufflePayload, (useSharedInputs && partitionId == 0)); - shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(), srcAttemptIdentifier, partitionId); + processShufflePayload(shufflePayload, srcAttemptIdentifier, partitionId); + } + + private void processShufflePayload(DataMovementEventPayloadProto shufflePayload, + CompositeInputAttemptIdentifier srcAttemptIdentifier, int srcIndex) throws IOException { + if (shufflePayload.hasData()) { + DataProto dataProto = shufflePayload.getData(); + String hostIdentifier = shufflePayload.getHost() + ":" + shufflePayload.getPort(); + FetchedInput fetchedInput = + inputAllocator.allocate(dataProto.getRawLength(), + dataProto.getCompressedLength(), srcAttemptIdentifier); + moveDataToFetchedInput(dataProto, fetchedInput, hostIdentifier); + shuffleManager.addCompletedInputWithData(srcAttemptIdentifier, fetchedInput); + + LOG.debug("Payload via DME : " + srcAttemptIdentifier); + } else { + shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(), + srcAttemptIdentifier, srcIndex); + } } private void processInputFailedEvent(InputFailedEvent ife) { diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java index 51dc172df3..229573f70c 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/TestShuffleInputEventHandlerImpl.java @@ -26,12 +26,14 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; +import java.io.InputStream; import java.nio.ByteBuffer; import java.util.BitSet; import java.util.Collections; @@ -42,6 +44,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.security.token.Token; import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.TezExecutors; @@ -56,15 +59,20 @@ import org.apache.tez.runtime.api.Event; import org.apache.tez.runtime.api.ExecutionContext; import org.apache.tez.runtime.api.InputContext; +import org.apache.tez.runtime.api.events.CompositeRoutedDataMovementEvent; import org.apache.tez.runtime.api.events.DataMovementEvent; import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier; import org.apache.tez.runtime.library.common.InputAttemptIdentifier; +import org.apache.tez.runtime.library.common.shuffle.FetchedInput; import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator; +import org.apache.tez.runtime.library.common.shuffle.MemoryFetchedInput; import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils; +import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads; import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.mockito.MockedStatic; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -342,6 +350,53 @@ public void testPipelinedShuffleEvents_WithEmptyPartitions() throws IOException verify(inputContext).killSelf(any(), anyString()); } + /** + * Verify that data movement events with shuffle data are processed properly. + * + * @throws IOException + */ + @Test(timeout = 5000) + public void testDataMovementEventsWithShuffleData() throws IOException { + InputContext inputContext = mock(InputContext.class); + ShuffleManager shuffleManager = mock(ShuffleManager.class); + ShuffleManager compositeFetchShuffleManager = mock(ShuffleManager.class); + FetchedInputAllocator inputAllocator = mock(FetchedInputAllocator.class); + MemoryFetchedInput memoryFetchedInput = mock(MemoryFetchedInput.class); + + when(memoryFetchedInput.getType()).thenReturn(FetchedInput.Type.MEMORY); + when(memoryFetchedInput.getBytes()).thenReturn("data".getBytes()); + when(inputAllocator.allocate(anyLong(), anyLong(), any(InputAttemptIdentifier.class))) + .thenReturn(memoryFetchedInput); + + ShuffleInputEventHandlerImpl eventHandler = new ShuffleInputEventHandlerImpl(inputContext, + shuffleManager, inputAllocator, null, true, 4, false); + + ShuffleInputEventHandlerImpl compositeFetchEventHandler = new ShuffleInputEventHandlerImpl(inputContext, + compositeFetchShuffleManager, inputAllocator, null, true, 4, true); + + DataMovementEvent dataMovementEvent = (DataMovementEvent) createDataMovementEventWithShuffleData(false); + CompositeRoutedDataMovementEvent compositeRoutedDataMovementEvent = + (CompositeRoutedDataMovementEvent) createDataMovementEventWithShuffleData(true); + + List eventListWithDme = new LinkedList<>(); + eventListWithDme.add(dataMovementEvent); + eventListWithDme.add(compositeRoutedDataMovementEvent); + + try (MockedStatic shuffleUtils = mockStatic(ShuffleUtils.class)) { + shuffleUtils.when(() -> ShuffleUtils + .shuffleToMemory(any(byte[].class), any(InputStream.class), anyInt(), anyInt(), any(CompressionCodec.class), + anyBoolean(), anyInt(), any(), any(InputAttemptIdentifier.class))) + .thenAnswer((Answer) invocation -> null); + eventHandler.handleEvents(eventListWithDme); + compositeFetchEventHandler.handleEvents(eventListWithDme); + + verify(shuffleManager, times(2)) + .addCompletedInputWithData(any(InputAttemptIdentifier.class), any(FetchedInput.class)); + verify(compositeFetchShuffleManager, times(2)) + .addCompletedInputWithData(any(InputAttemptIdentifier.class), any(FetchedInput.class)); + } + } + private Event createDataMovementEvent(boolean addSpillDetails, int srcIdx, int targetIdx, int spillId, boolean isLastSpill, BitSet emptyPartitions, int numPartitions, int attemptNum) throws IOException { @@ -397,4 +452,19 @@ private ByteString createEmptyPartitionByteString(int... emptyPartitions) throws return emptyPartitionsBytesString; } + private Event createDataMovementEventWithShuffleData(boolean isComposite) { + DataMovementEventPayloadProto.Builder builder = DataMovementEventPayloadProto.newBuilder(); + builder.setHost(HOST); + builder.setPort(PORT); + builder.setPathComponent(PATH_COMPONENT); + ShuffleUserPayloads.DataProto.Builder dataProtoBuilder = ShuffleUserPayloads.DataProto.newBuilder() + .setData(ByteString.copyFromUtf8("data")); + builder.setData(dataProtoBuilder); + + Event dme = isComposite? + CompositeRoutedDataMovementEvent.create(0, 1, 1, 0, builder.build().toByteString().asReadOnlyByteBuffer()): + DataMovementEvent.create(0, 1, 0, builder.build().toByteString().asReadOnlyByteBuffer()); + return dme; + } + } diff --git a/tez-runtime-library/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/tez-runtime-library/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 0000000000..a258d79ad3 --- /dev/null +++ b/tez-runtime-library/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mock-maker-inline