
[BUG] ClusterLoadFallbackPolicy is not strict when a shuffle with big partitions registers (#30)
TonyDoen authored Jan 26, 2022
1 parent 040ce00 commit 302891a
Showing 9 changed files with 64 additions and 47 deletions.
1 change: 0 additions & 1 deletion CONTRIBUTING.md
@@ -22,4 +22,3 @@ There are already some further improvements on the schedule and welcome to contact
1. Spark AE Support.
2. Metrics Enhancement.
3. Multiple-Engine Support.

@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.rss
 
-import org.apache.spark.{ShuffleDependency, SparkConf}
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.internal.SQLConf
 
 import com.aliyun.emr.rss.client.write.LifecycleManager
@@ -28,10 +28,9 @@ class RssShuffleFallbackPolicyRunner(sparkConf: SparkConf) extends Logging {
 
   private lazy val essConf = RssShuffleManager.fromSparkConf(sparkConf)
 
-  def applyAllFallbackPolicy(dependency: ShuffleDependency[_, _, _],
-      lifecycleManager: LifecycleManager): Boolean = {
-    applyForceFallbackPolicy() || applyShufflePartitionsFallbackPolicy(dependency) ||
-      applyAQEFallbackPolicy() || applyClusterLoadFallbackPolicy(lifecycleManager)
+  def applyAllFallbackPolicy(lifecycleManager: LifecycleManager, numPartitions: Int): Boolean = {
+    applyForceFallbackPolicy() || applyShufflePartitionsFallbackPolicy(numPartitions) ||
+      applyAQEFallbackPolicy() || applyClusterLoadFallbackPolicy(lifecycleManager, numPartitions)
   }
 
   /**
@@ -42,15 +41,15 @@ class RssShuffleFallbackPolicyRunner(sparkConf: SparkConf) extends Logging {
 
   /**
    * if shuffle partitions > rss.max.partition.number, fallback to external shuffle
-   * @param dependency shuffle dependency
+   * @param numPartitions shuffle partitions
    * @return return if shuffle partitions bigger than limit
    */
-  def applyShufflePartitionsFallbackPolicy(dependency: ShuffleDependency[_, _, _]): Boolean = {
-    val needFallback = dependency.partitioner.numPartitions >=
-      RssConf.maxPartitionNumSupported(essConf)
+  def applyShufflePartitionsFallbackPolicy(numPartitions: Int): Boolean = {
+    val confNumPartitions = RssConf.maxPartitionNumSupported(essConf)
+    val needFallback = numPartitions >= confNumPartitions
     if (needFallback) {
-      logInfo(s"Shuffle num of partitions: ${dependency.partitioner.numPartitions}" +
-        s" is bigger than the limit: ${RssConf.maxPartitionNumSupported(essConf)}," +
+      logInfo(s"Shuffle num of partitions: $numPartitions" +
+        s" is bigger than the limit: $confNumPartitions," +
         s" need fallback to spark shuffle")
     }
     needFallback
@@ -73,8 +72,13 @@ class RssShuffleFallbackPolicyRunner(sparkConf: SparkConf) extends Logging {
    * if rss cluster is under high load, fallback to external shuffle
    * @return if rss cluster's slots used percent is overhead the limit
    */
-  def applyClusterLoadFallbackPolicy(lifecycleManager: LifecycleManager): Boolean = {
-    val needFallback = lifecycleManager.isClusterOverload()
+  def applyClusterLoadFallbackPolicy(lifecycleManager: LifecycleManager, numPartitions: Int):
+      Boolean = {
+    if (!RssConf.clusterLoadFallbackEnabled(essConf)) {
+      return false
+    }
+
+    val needFallback = lifecycleManager.isClusterOverload(numPartitions)
     if (needFallback) {
       logWarning(s"Cluster is overload: $needFallback")
     }
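
The four policies short-circuit left to right, so the cheap local checks (force flag, partition count, AQE) run before the cluster-load check, which costs an RPC to the master. Passing numPartitions instead of the whole ShuffleDependency also decouples the policy runner from Spark's dependency type, which is why both RssShuffleManager call sites below now pass dependency.partitioner.numPartitions. A minimal sketch of the ordering in plain Scala, with stub predicates standing in for the real policies (all names and thresholds here are illustrative, not the project's API):

object FallbackChainSketch {
  // Stubs: each returns true when its policy demands a fallback to sort shuffle.
  def applyForceFallbackPolicy(): Boolean = false
  def applyShufflePartitionsFallbackPolicy(numPartitions: Int): Boolean = numPartitions >= 10000
  def applyAQEFallbackPolicy(): Boolean = false
  def applyClusterLoadFallbackPolicy(numPartitions: Int): Boolean = {
    println("RPC to master") // reached only if every cheaper check said "no fallback"
    false
  }

  // || short-circuits, so the RPC-backed policy is evaluated last and only when needed.
  def applyAllFallbackPolicy(numPartitions: Int): Boolean =
    applyForceFallbackPolicy() || applyShufflePartitionsFallbackPolicy(numPartitions) ||
      applyAQEFallbackPolicy() || applyClusterLoadFallbackPolicy(numPartitions)

  def main(args: Array[String]): Unit = {
    println(applyAllFallbackPolicy(20000)) // true, and no RPC is made
    println(applyAllFallbackPolicy(100))   // false, after one RPC
  }
}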
@@ -75,7 +75,8 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
     newAppId = Some(RssShuffleManager.genNewAppId(dependency.rdd.context))
     newAppId.foreach(initializeLifecycleManager)
 
-    if (fallbackPolicyRunner.applyAllFallbackPolicy(dependency, lifecycleManager.get)) {
+    if (fallbackPolicyRunner.applyAllFallbackPolicy(lifecycleManager.get,
+        dependency.partitioner.numPartitions)) {
       logWarning("Fallback to SortShuffleManager!")
       sortShuffleIds.add(shuffleId)
       sortShuffleManager.registerShuffle(shuffleId, numMaps, dependency)
@@ -28,10 +28,9 @@ class RssShuffleFallbackPolicyRunner(sparkConf: SparkConf) extends Logging {
 
   private lazy val essConf = RssShuffleManager.fromSparkConf(sparkConf)
 
-  def applyAllFallbackPolicy(dependency: ShuffleDependency[_, _, _],
-      lifecycleManager: LifecycleManager): Boolean = {
-    applyForceFallbackPolicy() || applyShufflePartitionsFallbackPolicy(dependency) ||
-      applyAQEFallbackPolicy() || applyClusterLoadFallbackPolicy(lifecycleManager)
+  def applyAllFallbackPolicy(lifecycleManager: LifecycleManager, numPartitions: Int): Boolean = {
+    applyForceFallbackPolicy() || applyShufflePartitionsFallbackPolicy(numPartitions) ||
+      applyAQEFallbackPolicy() || applyClusterLoadFallbackPolicy(lifecycleManager, numPartitions)
   }
 
   /**
@@ -42,16 +41,16 @@ class RssShuffleFallbackPolicyRunner(sparkConf: SparkConf) extends Logging {
 
   /**
    * if shuffle partitions > rss.max.partition.number, fallback to external shuffle
-   * @param dependency shuffle dependency
+   * @param numPartitions shuffle partitions
    * @return return if shuffle partitions bigger than limit
    */
-  def applyShufflePartitionsFallbackPolicy(dependency: ShuffleDependency[_, _, _]): Boolean = {
-    val needFallback = dependency.partitioner.numPartitions >=
-      RssConf.maxPartitionNumSupported(essConf)
+  def applyShufflePartitionsFallbackPolicy(numPartitions: Int): Boolean = {
+    val confNumPartitions = RssConf.maxPartitionNumSupported(essConf)
+    val needFallback = numPartitions >= confNumPartitions
     if (needFallback) {
-      logInfo(s"Shuffle num of partitions: ${dependency.partitioner.numPartitions}" +
-        s" is bigger than the limit: ${RssConf.maxPartitionNumSupported(essConf)}," +
-        s" need fallback to spark shuffle")
+      logInfo(s"Shuffle num of partitions: $numPartitions" +
+        s" is bigger than the limit: $confNumPartitions," +
+        s" need fallback to spark shuffle")
     }
     needFallback
   }
@@ -75,8 +74,13 @@ class RssShuffleFallbackPolicyRunner(sparkConf: SparkConf) extends Logging {
    * if rss cluster is under high load, fallback to external shuffle
    * @return if rss cluster's slots used percent is overhead the limit
    */
-  def applyClusterLoadFallbackPolicy(lifeCycleManager: LifecycleManager): Boolean = {
-    val needFallback = lifeCycleManager.isClusterOverload()
+  def applyClusterLoadFallbackPolicy(lifeCycleManager: LifecycleManager, numPartitions: Int):
+      Boolean = {
+    if (!RssConf.clusterLoadFallbackEnabled(essConf)) {
+      return false
+    }
+
+    val needFallback = lifeCycleManager.isClusterOverload(numPartitions)
     if (needFallback) {
       logWarning(s"Cluster is overload: $needFallback")
     }
@@ -71,7 +71,8 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
     newAppId = Some(RssShuffleManager.genNewAppId(dependency.rdd.context))
     newAppId.foreach(initializeLifecycleManager)
 
-    if (fallbackPolicyRunner.applyAllFallbackPolicy(dependency, lifecycleManager.get)) {
+    if (fallbackPolicyRunner.applyAllFallbackPolicy(lifecycleManager.get,
+        dependency.partitioner.numPartitions)) {
       logWarning("Fallback to SortShuffleManager!")
       sortShuffleIds.add(shuffleId)
       sortShuffleManager.registerShuffle(shuffleId, dependency)
@@ -1132,10 +1132,10 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
     blacklist.addAll(failedWorker)
   }
 
-  def isClusterOverload(): Boolean = {
+  def isClusterOverload(numPartitions: Int = 0): Boolean = {
     logInfo(s"Ask Sync Cluster Load Status")
     try {
-      rssHARetryClient.askSync[GetClusterLoadStatusResponse](GetClusterLoadStatus,
+      rssHARetryClient.askSync[GetClusterLoadStatusResponse](GetClusterLoadStatus(numPartitions),
         classOf[GetClusterLoadStatusResponse]).isOverload
     } catch {
       case e: Exception =>
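
isClusterOverload is a blocking ask to the master; the catch branch is cut off in this view, so the error handling below is an assumption, not the commit's code. A self-contained sketch of the request/response shape, with a stub standing in for rssHARetryClient.askSync:

import scala.util.{Failure, Success, Try}

object IsClusterOverloadSketch {
  case class GetClusterLoadStatus(numPartitions: Int)
  case class GetClusterLoadStatusResponse(isOverload: Boolean)

  // Stand-in for rssHARetryClient.askSync: a blocking request/response call.
  def askSync(msg: GetClusterLoadStatus): GetClusterLoadStatusResponse =
    GetClusterLoadStatusResponse(isOverload = false)

  def isClusterOverload(numPartitions: Int = 0): Boolean =
    Try(askSync(GetClusterLoadStatus(numPartitions)).isOverload) match {
      case Success(overload) => overload
      case Failure(_) => false // assumed fail-open: an unreachable master counts as "not overloaded"
    }

  def main(args: Array[String]): Unit =
    println(isClusterOverload(numPartitions = 300))
}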
4 changes: 4 additions & 0 deletions common/src/main/scala/com/aliyun/emr/rss/common/RssConf.scala
@@ -607,6 +607,10 @@ object RssConf extends Logging {
     conf.getInt("rss.worker.prometheus.metric.port", 9096)
   }
 
+  def clusterLoadFallbackEnabled(conf: RssConf): Boolean = {
+    conf.getBoolean("rss.clusterLoad.fallback.enabled", defaultValue = true)
+  }
+
   def offerSlotsExtraSize(conf: RssConf): Int = {
     conf.getInt("rss.offer.slots.extra.size", 2)
   }
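
The new key defaults to true, so the stricter cluster-load policy is active unless a user explicitly opts out. A Map-backed sketch of the lookup semantics; only the key name and the default come from the diff, the rest is illustrative:

object ClusterLoadFallbackConfSketch {
  def clusterLoadFallbackEnabled(conf: Map[String, String]): Boolean =
    conf.get("rss.clusterLoad.fallback.enabled").map(_.toBoolean).getOrElse(true)

  def main(args: Array[String]): Unit = {
    println(clusterLoadFallbackEnabled(Map.empty)) // true: on by default
    println(clusterLoadFallbackEnabled(Map("rss.clusterLoad.fallback.enabled" -> "false"))) // false: opted out
  }
}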
@@ -184,7 +184,7 @@ object ControlMessages {
   case class GetBlacklistResponse(statusCode: StatusCode,
       blacklist: util.List[WorkerInfo], unknownWorkers: util.List[WorkerInfo])
 
-  case object GetClusterLoadStatus extends Message
+  case class GetClusterLoadStatus(numPartitions: Int) extends Message
 
   case class GetClusterLoadStatusResponse(isOverload: Boolean)
 
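
Turning the payload-free case object into a case class lets the requested partition count ride along with the RPC, so the master can pattern-match it out directly. A minimal sketch of that dispatch (the sealed trait and handler are illustrative stand-ins for the real Message hierarchy):

object MessageDispatchSketch {
  sealed trait Message
  case class GetClusterLoadStatus(numPartitions: Int) extends Message

  def receive(msg: Message): Unit = msg match {
    case GetClusterLoadStatus(numPartitions) =>
      println(s"checking cluster load for $numPartitions requested partitions")
  }

  def main(args: Array[String]): Unit = receive(GetClusterLoadStatus(4096))
}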
@@ -21,6 +21,7 @@ import java.util
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.Random
 
 import com.aliyun.emr.rss.common.RssConf
@@ -81,7 +82,7 @@ private[deploy] class Master(
     // worker count
     source.addGauge(MasterSource.WorkerCount,
       _ => statusSystem.workers.size())
-    val (totalSlots, usedSlots, overloadWorkerCount, _) = getClusterLoad
+    val (totalSlots, usedSlots, overloadWorkerCount) = getClusterLoad
     // worker slots count
     source.addGauge(MasterSource.WorkerSlotsCount, _ => totalSlots)
     // worker slots used count
@@ -189,9 +190,9 @@ private[deploy] class Master(
       executeWithLeaderChecker(context,
         handleReportNodeFailure(context, failedWorkers, requestId))
 
-    case GetClusterLoadStatus =>
+    case GetClusterLoadStatus(numPartitions: Int) =>
       logInfo(s"Received GetClusterLoad request")
-      executeWithLeaderChecker(context, handleGetClusterLoadStatus(context))
+      executeWithLeaderChecker(context, handleGetClusterLoadStatus(context, numPartitions))
   }
 
   private def timeoutDeadWorkers() {
@@ -421,32 +422,35 @@ private[deploy] class Master(
     context.reply(OneWayMessageResponse)
   }
 
-  private def handleGetClusterLoadStatus(context: RpcCallContext): Unit = {
-    val (_, _, _, result) = getClusterLoad
+  private def handleGetClusterLoadStatus(context: RpcCallContext, numPartitions: Int): Unit = {
+    val clusterSlotsUsageLimit: Double = RssConf.clusterSlotsUsageLimitPercent(conf)
+    val (totalSlots, usedSlots, _) = getClusterLoad
+
+    val totalUsedRatio: Double = (usedSlots + numPartitions) / totalSlots.toDouble
+    val result = totalUsedRatio >= clusterSlotsUsageLimit
+    logInfo(s"Current cluster slots usage:$totalUsedRatio, conf:$clusterSlotsUsageLimit, " +
+      s"overload:$result")
     context.reply(GetClusterLoadStatusResponse(result))
   }
 
-  private def getClusterLoad: (Int, Int, Int, Boolean) = {
-    if (workersSnapShot.isEmpty) {
-      return (0, 0, 0, false)
+  private def getClusterLoad: (Int, Int, Int) = {
+    val workers: mutable.Buffer[WorkerInfo] = workersSnapShot.asScala
+    if (workers.isEmpty) {
+      return (0, 0, 0)
     }
 
     val clusterSlotsUsageLimit: Double = RssConf.clusterSlotsUsageLimitPercent(conf)
 
-    val (totalSlots, usedSlots, overloadWorkers) = workersSnapShot.asScala.map(workerInfo => {
+    val (totalSlots, usedSlots, overloadWorkers) = workers.map(workerInfo => {
       val allSlots: Int = workerInfo.numSlots
       val usedSlots: Int = workerInfo.usedSlots()
-      val flag: Int = if (usedSlots/allSlots.toDouble >= clusterSlotsUsageLimit) 1 else 0
+      val flag: Int = if (usedSlots / allSlots.toDouble >= clusterSlotsUsageLimit) 1 else 0
       (allSlots, usedSlots, flag)
     }).reduce((pair1, pair2) => {
       (pair1._1 + pair2._1, pair1._2 + pair2._2, pair1._3 + pair2._3)
     })
 
-    val totalUsedRatio: Double = usedSlots / totalSlots.toDouble
-    val result = totalUsedRatio >= clusterSlotsUsageLimit
-    logInfo(s"Current cluster slots usage:$totalUsedRatio, conf:$clusterSlotsUsageLimit, " +
-      s"overload:$result")
-    (totalSlots, usedSlots, overloadWorkers, result)
+    (totalSlots, usedSlots, overloadWorkers)
   }
 
   private def workersNotBlacklisted(
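
This hunk is the substance of the fix: the old check compared only current usage (usedSlots / totalSlots) against the limit, so a shuffle about to register many partitions could still slip under it. The new check counts the incoming partitions as slots before comparing. A runnable sketch with a worked example; the slot counts and the 0.95 limit are made-up numbers, not the project's defaults:

object OverloadCheckSketch {
  def isOverload(totalSlots: Int, usedSlots: Int, numPartitions: Int,
      usageLimit: Double): Boolean = {
    if (totalSlots == 0) return false // guard for an empty cluster (the commit early-returns (0, 0, 0) one level lower)
    (usedSlots + numPartitions) / totalSlots.toDouble >= usageLimit
  }

  def main(args: Array[String]): Unit = {
    // Old behavior would pass: 800 / 1000.0 = 0.8 < 0.95.
    // New behavior falls back: (800 + 300) / 1000.0 = 1.1 >= 0.95.
    println(isOverload(totalSlots = 1000, usedSlots = 800, numPartitions = 300, usageLimit = 0.95))
  }
}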
