Skip to content

Commit

Permalink
[CELEBORN-1700][FOLLOWUP] Support ShuffleFallbackCount metric for fal…
Browse files Browse the repository at this point in the history
…lback to vanilla Flink built-in shuffle implementation
  • Loading branch information
SteNicholas committed Dec 19, 2024
1 parent cec88b2 commit 64a123d
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -54,6 +55,7 @@
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.plugin.flink.fallback.ShuffleFallbackPolicy;
import org.apache.celeborn.plugin.flink.fallback.ShuffleFallbackPolicyRunner;
import org.apache.celeborn.plugin.flink.utils.FlinkUtils;

Expand All @@ -63,8 +65,8 @@ public class RemoteShuffleMaster implements ShuffleMaster<ShuffleDescriptor> {
private final ShuffleMasterContext shuffleMasterContext;
// Flink JobId -> Celeborn register shuffleIds
private final Map<JobID, Set<Integer>> jobShuffleIds = JavaUtils.newConcurrentHashMap();
private final ConcurrentHashMap.KeySetView<JobID, Boolean> nettyJobIds =
ConcurrentHashMap.newKeySet();
private final ConcurrentHashMap<JobID, String> jobFallbackPolicies =
JavaUtils.newConcurrentHashMap();
private String celebornAppId;
private volatile LifecycleManager lifecycleManager;
private final ShuffleTaskInfo shuffleTaskInfo = new ShuffleTaskInfo();
Expand Down Expand Up @@ -106,18 +108,21 @@ public void registerJob(JobShuffleContext context) {
}

try {
if (nettyShuffleServiceFactory != null
&& ShuffleFallbackPolicyRunner.applyFallbackPolicies(context, conf, lifecycleManager)) {
LOG.warn("Fallback to vanilla Flink NettyShuffleMaster for job: {}.", jobID);
nettyJobIds.add(jobID);
nettyShuffleMaster().registerJob(context);
} else {
Set<Integer> previousShuffleIds = jobShuffleIds.putIfAbsent(jobID, new HashSet<>());
if (previousShuffleIds != null) {
throw new RuntimeException("Duplicated registration job: " + jobID);
if (nettyShuffleServiceFactory != null) {
Optional<ShuffleFallbackPolicy> shuffleFallbackPolicy =
ShuffleFallbackPolicyRunner.applyFallbackPolicies(context, conf, lifecycleManager);
if (shuffleFallbackPolicy.isPresent()) {
LOG.warn("Fallback to vanilla Flink NettyShuffleMaster for job: {}.", jobID);
jobFallbackPolicies.put(jobID, shuffleFallbackPolicy.get().getClass().getName());
nettyShuffleMaster().registerJob(context);
return;
}
shuffleResourceTracker.registerJob(context);
}
Set<Integer> previousShuffleIds = jobShuffleIds.putIfAbsent(jobID, new HashSet<>());
if (previousShuffleIds != null) {
throw new RuntimeException("Duplicated registration job: " + jobID);
}
shuffleResourceTracker.registerJob(context);
} catch (CelebornIOException e) {
throw new RuntimeException(e);
}
Expand All @@ -126,7 +131,7 @@ public void registerJob(JobShuffleContext context) {
@Override
public void unregisterJob(JobID jobID) {
LOG.info("Unregister job: {}.", jobID);
if (nettyJobIds.remove(jobID)) {
if (jobFallbackPolicies.remove(jobID) != null) {
nettyShuffleMaster().unregisterJob(jobID);
return;
}
Expand All @@ -152,8 +157,13 @@ public CompletableFuture<ShuffleDescriptor> registerPartitionWithProducer(
JobID jobID, PartitionDescriptor partitionDescriptor, ProducerDescriptor producerDescriptor) {
return CompletableFuture.supplyAsync(
() -> {
if (nettyJobIds.contains(jobID)) {
lifecycleManager.shuffleCount().increment();
String jobFallbackPolicy = jobFallbackPolicies.get(jobID);
if (jobFallbackPolicy != null) {
try {
lifecycleManager
.shuffleFallbackCounts()
.compute(jobFallbackPolicy, (key, value) -> value == null ? 1L : value + 1L);
return nettyShuffleMaster()
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Expand Down Expand Up @@ -270,7 +280,7 @@ public MemorySize computeShuffleMemorySizeForTask(
@Override
public void close() throws Exception {
try {
nettyJobIds.clear();
jobFallbackPolicies.clear();
jobShuffleIds.clear();
LifecycleManager manager = lifecycleManager;
if (null != manager) {
Expand Down Expand Up @@ -318,7 +328,12 @@ private NettyShuffleMaster nettyShuffleMaster() {
}

@VisibleForTesting
public ConcurrentHashMap.KeySetView<JobID, Boolean> nettyJobIds() {
return nettyJobIds;
public LifecycleManager lifecycleManager() {
return lifecycleManager;
}

@VisibleForTesting
public ConcurrentHashMap<JobID, String> jobFallbackPolicies() {
return jobFallbackPolicies;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class ShuffleFallbackPolicyRunner {
private static final List<ShuffleFallbackPolicy> FALLBACK_POLICIES =
ShuffleFallbackPolicyFactory.getShuffleFallbackPolicies();

public static boolean applyFallbackPolicies(
public static Optional<ShuffleFallbackPolicy> applyFallbackPolicies(
JobShuffleContext shuffleContext,
CelebornConf celebornConf,
LifecycleManager lifecycleManager)
Expand All @@ -44,11 +44,11 @@ public static boolean applyFallbackPolicies(
shuffleFallbackPolicy.needFallback(
shuffleContext, celebornConf, lifecycleManager))
.findFirst();
boolean needFallback = fallbackPolicy.isPresent();
if (needFallback && FallbackPolicy.NEVER.equals(celebornConf.flinkShuffleFallbackPolicy())) {
if (fallbackPolicy.isPresent()
&& FallbackPolicy.NEVER.equals(celebornConf.flinkShuffleFallbackPolicy())) {
throw new CelebornIOException(
"Fallback to flink built-in shuffle implementation is prohibited.");
}
return needFallback;
return fallbackPolicy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.FallbackPolicy;
import org.apache.celeborn.common.util.Utils$;
import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy;
import org.apache.celeborn.plugin.flink.utils.FlinkUtils;

public class RemoteShuffleMasterSuiteJ {
Expand Down Expand Up @@ -91,9 +92,9 @@ public void testRegisterJobWithForceFallbackPolicy() {
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID));
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID));
remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId());
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty());
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty());
}

@Test
Expand Down Expand Up @@ -128,6 +129,7 @@ public void testRegisterPartitionWithProducer()
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
mapPartitionShuffleDescriptor =
remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
Expand All @@ -147,6 +149,32 @@ public void testRegisterPartitionWithProducer()
Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
}

@Test
public void testRegisterPartitionWithProducerForForceFallbackPolicy()
throws UnknownHostException, ExecutionException, InterruptedException {
configuration.setString(
CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name());
remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory());
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);

IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0);
ProducerDescriptor producerDescriptor = createProducerDescriptor();
ShuffleDescriptor shuffleDescriptor =
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor);
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
Map<String, Long> shuffleFallbackCounts =
remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts();
Assert.assertEquals(1, shuffleFallbackCounts.size());
Assert.assertEquals(
1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue());
}

@Test
public void testRegisterMultipleJobs()
throws UnknownHostException, ExecutionException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.FallbackPolicy;
import org.apache.celeborn.common.util.Utils$;
import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy;
import org.apache.celeborn.plugin.flink.utils.FlinkUtils;

public class RemoteShuffleMasterSuiteJ {
Expand Down Expand Up @@ -91,9 +92,9 @@ public void testRegisterJobWithForceFallbackPolicy() {
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID));
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID));
remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId());
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty());
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty());
}

@Test
Expand Down Expand Up @@ -128,6 +129,7 @@ public void testRegisterPartitionWithProducer()
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
mapPartitionShuffleDescriptor =
remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
Expand All @@ -147,6 +149,32 @@ public void testRegisterPartitionWithProducer()
Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
}

@Test
public void testRegisterPartitionWithProducerForForceFallbackPolicy()
throws UnknownHostException, ExecutionException, InterruptedException {
configuration.setString(
CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name());
remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory());
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);

IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0);
ProducerDescriptor producerDescriptor = createProducerDescriptor();
ShuffleDescriptor shuffleDescriptor =
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor);
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
Map<String, Long> shuffleFallbackCounts =
remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts();
Assert.assertEquals(1, shuffleFallbackCounts.size());
Assert.assertEquals(
1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue());
}

@Test
public void testRegisterMultipleJobs()
throws UnknownHostException, ExecutionException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.shuffle.JobShuffleContext;
import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
import org.apache.flink.runtime.shuffle.PartitionDescriptor;
import org.apache.flink.runtime.shuffle.ProducerDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
import org.junit.After;
Expand All @@ -56,6 +58,7 @@
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.FallbackPolicy;
import org.apache.celeborn.common.util.Utils$;
import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy;
import org.apache.celeborn.plugin.flink.utils.FlinkUtils;

public class RemoteShuffleMasterSuiteJ {
Expand Down Expand Up @@ -98,9 +101,9 @@ public void testRegisterJobWithForceFallbackPolicy() {
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID));
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID));
remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId());
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty());
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty());
}

@Test
Expand Down Expand Up @@ -135,6 +138,7 @@ public void testRegisterPartitionWithProducer()
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
mapPartitionShuffleDescriptor =
remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
Expand All @@ -154,6 +158,32 @@ public void testRegisterPartitionWithProducer()
Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
}

@Test
public void testRegisterPartitionWithProducerForForceFallbackPolicy()
throws UnknownHostException, ExecutionException, InterruptedException {
configuration.setString(
CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name());
remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory());
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);

IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0);
ProducerDescriptor producerDescriptor = createProducerDescriptor();
ShuffleDescriptor shuffleDescriptor =
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor);
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
Map<String, Long> shuffleFallbackCounts =
remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts();
Assert.assertEquals(1, shuffleFallbackCounts.size());
Assert.assertEquals(
1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue());
}

@Test
public void testRegisterMultipleJobs()
throws UnknownHostException, ExecutionException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.shuffle.JobShuffleContext;
import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
import org.apache.flink.runtime.shuffle.PartitionDescriptor;
import org.apache.flink.runtime.shuffle.ProducerDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
import org.junit.After;
Expand All @@ -56,6 +58,7 @@
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.FallbackPolicy;
import org.apache.celeborn.common.util.Utils$;
import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy;
import org.apache.celeborn.plugin.flink.utils.FlinkUtils;

public class RemoteShuffleMasterSuiteJ {
Expand Down Expand Up @@ -98,9 +101,9 @@ public void testRegisterJobWithForceFallbackPolicy() {
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID));
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID));
remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId());
Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty());
Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty());
}

@Test
Expand All @@ -118,6 +121,7 @@ public void testRegisterPartitionWithProducer()
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource();
ShuffleResourceDescriptor mapPartitionShuffleDescriptor =
shuffleResource.getMapPartitionShuffleDescriptor();
Expand Down Expand Up @@ -154,6 +158,32 @@ public void testRegisterPartitionWithProducer()
Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
}

@Test
public void testRegisterPartitionWithProducerForForceFallbackPolicy()
throws UnknownHostException, ExecutionException, InterruptedException {
configuration.setString(
CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name());
remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory());
JobID jobID = JobID.generate();
JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
remoteShuffleMaster.registerJob(jobShuffleContext);

IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0);
ProducerDescriptor producerDescriptor = createProducerDescriptor();
ShuffleDescriptor shuffleDescriptor =
remoteShuffleMaster
.registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
.get();
Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor);
Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum());
Map<String, Long> shuffleFallbackCounts =
remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts();
Assert.assertEquals(1, shuffleFallbackCounts.size());
Assert.assertEquals(
1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue());
}

@Test
public void testRegisterMultipleJobs()
throws UnknownHostException, ExecutionException, InterruptedException {
Expand Down
Loading

0 comments on commit 64a123d

Please sign in to comment.