Skip to content

Commit

Permalink
TEZ-4450: Shuffle data fetch fails when shuffle data is transferred v…
Browse files Browse the repository at this point in the history
…ia CompositeRoutedDataMovementEvent (apache#243) (Ganesha Shreedhara reviewed by Laszlo Bodor)

(cherry picked from commit 8ebc4b0)
  • Loading branch information
ganeshashree authored and prabhjyotsingh committed Nov 12, 2024
1 parent f9cbbf4 commit 94da1f2
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 18 deletions.
3 changes: 2 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
<clover.license>${user.home}/clover.license</clover.license>
<guava.version>32.0.1-jre</guava.version>
<hadoop.version>3.2.3.3.2.3.2-3</hadoop.version>
<mockito-core.version>4.3.1</mockito-core.version>
<netty.version>4.1.94.Final</netty.version>
<pig.version>0.13.0</pig.version>
<jersey.version>1.19</jersey.version>
Expand Down Expand Up @@ -799,7 +800,7 @@
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.8</version>
<version>${mockito-core.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,6 @@ public void logProgress(boolean updateOnClose) {
private void processDataMovementEvent(DataMovementEvent dme, DataMovementEventPayloadProto shufflePayload, BitSet emptyPartitionsBitSet) throws IOException {
int srcIndex = dme.getSourceIndex();

String hostIdentifier = shufflePayload.getHost() + ":" + shufflePayload.getPort();

if (LOG.isDebugEnabled()) {
LOG.debug("DME srcIdx: " + srcIndex + ", targetIndex: " + dme.getTargetIndex()
+ ", attemptNum: " + dme.getVersion() + ", payload: " + ShuffleUtils
Expand All @@ -198,20 +196,7 @@ private void processDataMovementEvent(DataMovementEvent dme, DataMovementEventPa
CompositeInputAttemptIdentifier srcAttemptIdentifier = constructInputAttemptIdentifier(dme.getTargetIndex(), 1, dme.getVersion(),
shufflePayload, (useSharedInputs && srcIndex == 0));

if (shufflePayload.hasData()) {
DataProto dataProto = shufflePayload.getData();

FetchedInput fetchedInput =
inputAllocator.allocate(dataProto.getRawLength(),
dataProto.getCompressedLength(), srcAttemptIdentifier);
moveDataToFetchedInput(dataProto, fetchedInput, hostIdentifier);
shuffleManager.addCompletedInputWithData(srcAttemptIdentifier, fetchedInput);

LOG.debug("Payload via DME : " + srcAttemptIdentifier);
} else {
shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(),
srcAttemptIdentifier, srcIndex);
}
processShufflePayload(shufflePayload, srcAttemptIdentifier, srcIndex);
}

private void moveDataToFetchedInput(DataProto dataProto,
Expand Down Expand Up @@ -274,7 +259,25 @@ private void processCompositeRoutedDataMovementEvent(CompositeRoutedDataMovement
CompositeInputAttemptIdentifier srcAttemptIdentifier = constructInputAttemptIdentifier(crdme.getTargetIndex(), crdme.getCount(), crdme.getVersion(),
shufflePayload, (useSharedInputs && partitionId == 0));

shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(), srcAttemptIdentifier, partitionId);
processShufflePayload(shufflePayload, srcAttemptIdentifier, partitionId);
}

private void processShufflePayload(DataMovementEventPayloadProto shufflePayload,
CompositeInputAttemptIdentifier srcAttemptIdentifier, int srcIndex) throws IOException {
if (shufflePayload.hasData()) {
DataProto dataProto = shufflePayload.getData();
String hostIdentifier = shufflePayload.getHost() + ":" + shufflePayload.getPort();
FetchedInput fetchedInput =
inputAllocator.allocate(dataProto.getRawLength(),
dataProto.getCompressedLength(), srcAttemptIdentifier);
moveDataToFetchedInput(dataProto, fetchedInput, hostIdentifier);
shuffleManager.addCompletedInputWithData(srcAttemptIdentifier, fetchedInput);

LOG.debug("Payload via DME : " + srcAttemptIdentifier);
} else {
shuffleManager.addKnownInput(shufflePayload.getHost(), shufflePayload.getPort(),
srcAttemptIdentifier, srcIndex);
}
}

private void processInputFailedEvent(InputFailedEvent ife) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,21 @@
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.BitSet;
import java.util.Collections;
Expand All @@ -40,6 +47,7 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.security.token.Token;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezExecutors;
Expand All @@ -54,15 +62,20 @@
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.ExecutionContext;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.events.CompositeRoutedDataMovementEvent;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
import org.apache.tez.runtime.library.common.shuffle.MemoryFetchedInput;
import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.MockedStatic;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

Expand Down Expand Up @@ -340,6 +353,53 @@ public void testPipelinedShuffleEvents_WithEmptyPartitions() throws IOException
verify(inputContext).killSelf(any(Throwable.class), anyString());
}

/**
* Verify that data movement events with shuffle data are processed properly.
*
* @throws IOException
*/
@Test(timeout = 5000)
public void testDataMovementEventsWithShuffleData() throws IOException {
InputContext inputContext = mock(InputContext.class);
ShuffleManager shuffleManager = mock(ShuffleManager.class);
ShuffleManager compositeFetchShuffleManager = mock(ShuffleManager.class);
FetchedInputAllocator inputAllocator = mock(FetchedInputAllocator.class);
MemoryFetchedInput memoryFetchedInput = mock(MemoryFetchedInput.class);

when(memoryFetchedInput.getType()).thenReturn(FetchedInput.Type.MEMORY);
when(memoryFetchedInput.getBytes()).thenReturn("data".getBytes());
when(inputAllocator.allocate(anyLong(), anyLong(), any(InputAttemptIdentifier.class)))
.thenReturn(memoryFetchedInput);

ShuffleInputEventHandlerImpl eventHandler = new ShuffleInputEventHandlerImpl(inputContext,
shuffleManager, inputAllocator, null, true, 4, false);

ShuffleInputEventHandlerImpl compositeFetchEventHandler = new ShuffleInputEventHandlerImpl(inputContext,
compositeFetchShuffleManager, inputAllocator, null, true, 4, true);

DataMovementEvent dataMovementEvent = (DataMovementEvent) createDataMovementEventWithShuffleData(false);
CompositeRoutedDataMovementEvent compositeRoutedDataMovementEvent =
(CompositeRoutedDataMovementEvent) createDataMovementEventWithShuffleData(true);

List<Event> eventListWithDme = new LinkedList<>();
eventListWithDme.add(dataMovementEvent);
eventListWithDme.add(compositeRoutedDataMovementEvent);

try (MockedStatic<ShuffleUtils> shuffleUtils = mockStatic(ShuffleUtils.class)) {
shuffleUtils.when(() -> ShuffleUtils
.shuffleToMemory(any(byte[].class), any(InputStream.class), anyInt(), anyInt(), any(CompressionCodec.class),
anyBoolean(), anyInt(), any(), any(InputAttemptIdentifier.class)))
.thenAnswer((Answer<Void>) invocation -> null);
eventHandler.handleEvents(eventListWithDme);
compositeFetchEventHandler.handleEvents(eventListWithDme);

verify(shuffleManager, times(2))
.addCompletedInputWithData(any(InputAttemptIdentifier.class), any(FetchedInput.class));
verify(compositeFetchShuffleManager, times(2))
.addCompletedInputWithData(any(InputAttemptIdentifier.class), any(FetchedInput.class));
}
}

private Event createDataMovementEvent(boolean addSpillDetails, int srcIdx, int targetIdx,
int spillId, boolean isLastSpill, BitSet emptyPartitions, int numPartitions, int attemptNum)
throws IOException {
Expand Down Expand Up @@ -395,4 +455,19 @@ private ByteString createEmptyPartitionByteString(int... emptyPartitions) throws
return emptyPartitionsBytesString;
}

private Event createDataMovementEventWithShuffleData(boolean isComposite) {
DataMovementEventPayloadProto.Builder builder = DataMovementEventPayloadProto.newBuilder();
builder.setHost(HOST);
builder.setPort(PORT);
builder.setPathComponent(PATH_COMPONENT);
ShuffleUserPayloads.DataProto.Builder dataProtoBuilder = ShuffleUserPayloads.DataProto.newBuilder()
.setData(ByteString.copyFromUtf8("data"));
builder.setData(dataProtoBuilder);

Event dme = isComposite?
CompositeRoutedDataMovementEvent.create(0, 1, 1, 0, builder.build().toByteString().asReadOnlyByteBuffer()):
DataMovementEvent.create(0, 1, 0, builder.build().toByteString().asReadOnlyByteBuffer());
return dme;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

mock-maker-inline

0 comments on commit 94da1f2

Please sign in to comment.