diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ee82d9fa7874b..182abacc475ae 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -25,9 +25,11 @@ import scala.concurrent.Await
import akka.actor._
import akka.pattern.ask
+
+import org.apache.spark.util._
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util._
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
@@ -168,8 +170,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing all output locations for shuffle " + shuffleId))
+ throw new MetadataFetchFailedException(
+ shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
} else {
statuses.synchronized {
@@ -371,8 +373,8 @@ private[spark] object MapOutputTracker {
statuses.map {
status =>
if (status == null) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing an output location for shuffle " + shuffleId))
+ throw new MetadataFetchFailedException(
+ shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
(status.location, decompressSize(status.compressedSizes(reduceId)))
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 5e8bd8c8e533a..df42d679b4699 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -65,7 +65,7 @@ case object Resubmitted extends TaskFailedReason {
*/
@DeveloperApi
case class FetchFailed(
- bmAddress: BlockManagerId,
+ bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleId: Int,
mapId: Int,
reduceId: Int)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 1f0785d4056a7..92bf2bca5ea45 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,8 +26,8 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import org.apache.spark._
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.{AkkaUtils, Utils}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 378cf1aaebe7b..82163eadd56e9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -75,9 +75,11 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId)
@DeveloperApi
case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent
+@DeveloperApi
case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String)
extends SparkListenerEvent
+@DeveloperApi
case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent
/** An event used in the listener to shutdown the listener daemon thread. */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 9a4be43ee219f..8ec482a6f6d9c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -106,6 +106,8 @@ private[spark] class Stage(
id
}
+ def attemptId: Int = nextAttemptId
+
val name = callSite.short
val details = callSite.long
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 1481d70db42e1..4c96b9e5fef60 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -21,6 +21,10 @@ import java.nio.ByteBuffer
import org.apache.spark.util.SerializableBuffer
+/**
+ * Description of a task that gets passed on to executors for execution, usually created by
+ * [[TaskSetManager.resourceOffer]].
+ */
private[spark] class TaskDescription(
val taskId: Long,
val executorId: String,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 4c62e4dc0bac8..6aecdfe8e6656 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -27,10 +27,12 @@ import org.apache.spark.annotation.DeveloperApi
class TaskInfo(
val taskId: Long,
val index: Int,
+ val attempt: Int,
val launchTime: Long,
val executorId: String,
val host: String,
- val taskLocality: TaskLocality.TaskLocality) {
+ val taskLocality: TaskLocality.TaskLocality,
+ val speculative: Boolean) {
/**
* The time when the task started remotely getting the result. Will not be set if the
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index c0898f64fb0c9..83ff6b8550b4f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -335,17 +335,19 @@ private[spark] class TaskSetManager(
/**
* Dequeue a pending task for a given node and return its index and locality level.
* Only search for tasks matching the given locality constraint.
+ *
+ * @return An option containing (task index within the task set, locality, is speculative?)
*/
private def findTask(execId: String, host: String, locality: TaskLocality.Value)
- : Option[(Int, TaskLocality.Value)] =
+ : Option[(Int, TaskLocality.Value, Boolean)] =
{
for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
+ return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) {
- return Some((index, TaskLocality.NODE_LOCAL))
+ return Some((index, TaskLocality.NODE_LOCAL, false))
}
}
@@ -354,23 +356,25 @@ private[spark] class TaskSetManager(
rack <- sched.getRackForHost(host)
index <- findTaskFromList(execId, getPendingTasksForRack(rack))
} {
- return Some((index, TaskLocality.RACK_LOCAL))
+ return Some((index, TaskLocality.RACK_LOCAL, false))
}
}
// Look for no-pref tasks after rack-local tasks since they can run anywhere.
for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
+ return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
for (index <- findTaskFromList(execId, allPendingTasks)) {
- return Some((index, TaskLocality.ANY))
+ return Some((index, TaskLocality.ANY, false))
}
}
// Finally, if all else has failed, find a speculative task
- findSpeculativeTask(execId, host, locality)
+ findSpeculativeTask(execId, host, locality).map { case (taskIndex, allowedLocality) =>
+ (taskIndex, allowedLocality, true)
+ }
}
/**
@@ -391,7 +395,7 @@ private[spark] class TaskSetManager(
}
findTask(execId, host, allowedLocality) match {
- case Some((index, taskLocality)) => {
+ case Some((index, taskLocality, speculative)) => {
// Found a task; do some bookkeeping and return a task description
val task = tasks(index)
val taskId = sched.newTaskId()
@@ -400,7 +404,9 @@ private[spark] class TaskSetManager(
taskSet.id, index, taskId, execId, host, taskLocality))
// Do various bookkeeping
copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
+ val attemptNum = taskAttempts(index).size
+ val info = new TaskInfo(
+ taskId, index, attemptNum + 1, curTime, execId, host, taskLocality, speculative)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
// Update our locality level for delay scheduling
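For illustration, a minimal sketch (not part of the patch) of how the extended TaskInfo is constructed with the new `attempt` and `speculative` fields, mirroring the bookkeeping in `resourceOffer` above; the object name and literal values are placeholders:

    package org.apache.spark.scheduler

    // Sketch only: attempts are recorded 1-based here, matching `attemptNum + 1` above.
    object TaskInfoSketch {
      def main(args: Array[String]): Unit = {
        val priorAttempts = 0  // e.g. taskAttempts(index).size for a first attempt
        val info = new TaskInfo(
          taskId = 42L,
          index = 7,
          attempt = priorAttempts + 1,
          launchTime = System.currentTimeMillis(),
          executorId = "exec-1",
          host = "host-1",
          taskLocality = TaskLocality.PROCESS_LOCAL,
          speculative = false)
        println(s"task ${info.taskId} attempt ${info.attempt} speculative=${info.speculative}")
      }
    }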
diff --git a/core/src/main/scala/org/apache/spark/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
similarity index 50%
rename from core/src/main/scala/org/apache/spark/FetchFailedException.scala
rename to core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 8eaa26bdb1b5b..71c08e9d5a8c3 100644
--- a/core/src/main/scala/org/apache/spark/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -15,31 +15,38 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.shuffle
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.{FetchFailed, TaskEndReason}
+/**
+ * Failed to fetch a shuffle block. The executor catches this exception and propagates it
+ * back to the DAGScheduler (through TaskEndReason) so that the previous stage can be resubmitted.
+ *
+ * Note that bmAddress can be null.
+ */
private[spark] class FetchFailedException(
- taskEndReason: TaskEndReason,
- message: String,
- cause: Throwable)
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int)
extends Exception {
- def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int,
- cause: Throwable) =
- this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
- cause)
-
- def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
- this(FetchFailed(null, shuffleId, -1, reduceId),
- "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
-
- override def getMessage(): String = message
+ override def getMessage: String =
+ "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+}
- override def getCause(): Throwable = cause
-
- def toTaskEndReason: TaskEndReason = taskEndReason
+/**
+ * Failed to get shuffle metadata from [[org.apache.spark.MapOutputTracker]].
+ */
+private[spark] class MetadataFetchFailedException(
+ shuffleId: Int,
+ reduceId: Int,
+ message: String)
+ extends FetchFailedException(null, shuffleId, -1, reduceId) {
+ override def getMessage: String = message
}
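For illustration, a minimal sketch (not part of the patch) of how the reworked exceptions are intended to be used: a block fetch failure carries the BlockManagerId, a metadata failure uses the null-address subclass, and both convert to a FetchFailed TaskEndReason for the DAGScheduler. The object name, ids and ports below are placeholders:

    package org.apache.spark.shuffle

    import org.apache.spark.TaskEndReason
    import org.apache.spark.storage.BlockManagerId

    object FetchFailedSketch {
      // Both exception types share toTaskEndReason, which the executor reports back.
      def reasonFor(t: Throwable): Option[TaskEndReason] = t match {
        case ffe: FetchFailedException => Some(ffe.toTaskEndReason)
        case _ => None
      }

      def main(args: Array[String]): Unit = {
        val bm = BlockManagerId("exec-1", "host-1", 7337, 7338)  // placeholder ids/ports
        val blockFailure = new FetchFailedException(bm, shuffleId = 0, mapId = 3, reduceId = 1)
        val metadataFailure = new MetadataFetchFailedException(
          shuffleId = 0, reduceId = 1, message = "Missing all output locations for shuffle 0")
        Seq(blockFailure, metadataFailure).foreach(e => println(reasonFor(e)))
      }
    }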
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index b05b6ea345df3..a932455776e34 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -20,11 +20,12 @@ package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import org.apache.spark._
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
-import org.apache.spark._
private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
@@ -63,7 +64,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index 6cb43c02b8f08..2d8c3b949c1ac 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -79,6 +79,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
val maximumMemory = values("Maximum Memory")
val memoryUsed = values("Memory Used")
val diskUsed = values("Disk Used")
+ // scalastyle:off
{values("Executor ID")} |
{values("Address")} |
@@ -94,10 +95,11 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
{values("Failed Tasks")} |
{values("Complete Tasks")} |
{values("Total Tasks")} |
- {Utils.msDurationToString(values("Task Time").toLong)} |
- {Utils.bytesToString(values("Shuffle Read").toLong)} |
- {Utils.bytesToString(values("Shuffle Write").toLong)} |
+ {Utils.msDurationToString(values("Task Time").toLong)} |
+ {Utils.bytesToString(values("Shuffle Read").toLong)} |
+ {Utils.bytesToString(values("Shuffle Write").toLong)} |
+ // scalastyle:on
}
/** Represent an executor's info as a map given a storage status index */
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index c83e196c9c156..add0e9878a546 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -67,18 +67,20 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) {
executorIdToSummary match {
case Some(x) =>
x.toSeq.sortBy(_._1).map { case (k, v) => {
+ // scalastyle:off
             <td>{k}</td>
             <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
-            <td>{UIUtils.formatDuration(v.taskTime)}</td>
+            <td sorttable_customkey={v.taskTime.toString}>{UIUtils.formatDuration(v.taskTime)}</td>
             <td>{v.failedTasks + v.succeededTasks}</td>
             <td>{v.failedTasks}</td>
             <td>{v.succeededTasks}</td>
-            <td>{Utils.bytesToString(v.shuffleRead)}</td>
-            <td>{Utils.bytesToString(v.shuffleWrite)}</td>
-            <td>{Utils.bytesToString(v.memoryBytesSpilled)}</td>
-            <td>{Utils.bytesToString(v.diskBytesSpilled)}</td>
+            <td sorttable_customkey={v.shuffleRead.toString}>{Utils.bytesToString(v.shuffleRead)}</td>
+            <td sorttable_customkey={v.shuffleWrite.toString}>{Utils.bytesToString(v.shuffleWrite)}</td>
+            <td sorttable_customkey={v.memoryBytesSpilled.toString}>{Utils.bytesToString(v.memoryBytesSpilled)}</td>
+            <td sorttable_customkey={v.diskBytesSpilled.toString}>{Utils.bytesToString(v.diskBytesSpilled)}</td>
+ // scalastyle:on
}
}
case _ => Seq[Node]()
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 8b65f0671bdb9..8e3d5d1cd4c6b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -95,8 +95,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
// scalastyle:on
val taskHeaders: Seq[String] =
- Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
- Seq("Duration", "GC Time", "Result Ser Time") ++
+ Seq(
+ "Index", "ID", "Attempt", "Status", "Locality Level", "Executor",
+ "Launch Time", "Duration", "GC Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
{if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
{if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++
@@ -245,6 +246,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
         <td>{info.index}</td>
         <td>{info.taskId}</td>
+        <td>{
+          if (info.speculative) s"${info.attempt} (speculative)" else info.attempt.toString
+        }</td>
         <td>{info.status}</td>
         <td>{info.taskLocality}</td>
         <td>{info.host}</td>
@@ -255,9 +259,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
         <td>
           {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""}
         </td>
+
{if (shuffleRead) {
{shuffleReadReadable}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
index b66edd91f56c0..9813d9330ac7f 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
@@ -49,6 +49,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
/** Render an HTML row representing an RDD */
private def rddRow(rdd: RDDInfo): Seq[Node] = {
+ // scalastyle:off
|
@@ -59,9 +60,10 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
       </td>
       <td>{rdd.numCachedPartitions}</td>
       <td>{"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)}</td>
-      <td>{Utils.bytesToString(rdd.memSize)}</td>
-      <td>{Utils.bytesToString(rdd.tachyonSize)}</td>
-      <td>{Utils.bytesToString(rdd.diskSize)}</td>
+      <td sorttable_customkey={rdd.memSize.toString}>{Utils.bytesToString(rdd.memSize)}</td>
+      <td sorttable_customkey={rdd.tachyonSize.toString}>{Utils.bytesToString(rdd.tachyonSize)}</td>
+      <td sorttable_customkey={rdd.diskSize.toString}>{Utils.bytesToString(rdd.diskSize)}</td>
+ // scalastyle:on
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 7cecbfe62a382..6245b4b8023c2 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -32,6 +32,8 @@ import org.apache.spark.storage._
import org.apache.spark._
private[spark] object JsonProtocol {
+ // TODO: Remove this file and put JSON serialization into each individual class.
+
private implicit val format = DefaultFormats
/** ------------------------------------------------- *
@@ -194,10 +196,12 @@ private[spark] object JsonProtocol {
def taskInfoToJson(taskInfo: TaskInfo): JValue = {
("Task ID" -> taskInfo.taskId) ~
("Index" -> taskInfo.index) ~
+ ("Attempt" -> taskInfo.attempt) ~
("Launch Time" -> taskInfo.launchTime) ~
("Executor ID" -> taskInfo.executorId) ~
("Host" -> taskInfo.host) ~
("Locality" -> taskInfo.taskLocality.toString) ~
+ ("Speculative" -> taskInfo.speculative) ~
("Getting Result Time" -> taskInfo.gettingResultTime) ~
("Finish Time" -> taskInfo.finishTime) ~
("Failed" -> taskInfo.failed) ~
@@ -487,16 +491,19 @@ private[spark] object JsonProtocol {
def taskInfoFromJson(json: JValue): TaskInfo = {
val taskId = (json \ "Task ID").extract[Long]
val index = (json \ "Index").extract[Int]
+ val attempt = (json \ "Attempt").extractOpt[Int].getOrElse(1)
val launchTime = (json \ "Launch Time").extract[Long]
val executorId = (json \ "Executor ID").extract[String]
val host = (json \ "Host").extract[String]
val taskLocality = TaskLocality.withName((json \ "Locality").extract[String])
+ val speculative = (json \ "Speculative").extractOpt[Boolean].getOrElse(false)
val gettingResultTime = (json \ "Getting Result Time").extract[Long]
val finishTime = (json \ "Finish Time").extract[Long]
val failed = (json \ "Failed").extract[Boolean]
val serializedSize = (json \ "Serialized Size").extract[Int]
- val taskInfo = new TaskInfo(taskId, index, launchTime, executorId, host, taskLocality)
+ val taskInfo =
+ new TaskInfo(taskId, index, attempt, launchTime, executorId, host, taskLocality, speculative)
taskInfo.gettingResultTime = gettingResultTime
taskInfo.finishTime = finishTime
taskInfo.failed = failed
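For illustration, a minimal sketch (not part of the patch) of the round trip through JsonProtocol, and of the backward-compatibility defaults: logs written before this change carry no "Attempt" or "Speculative" fields, so extractOpt falls back to 1 and false. The object name is a placeholder:

    package org.apache.spark.util

    import org.apache.spark.scheduler.{TaskInfo, TaskLocality}

    object TaskInfoJsonSketch {
      def main(args: Array[String]): Unit = {
        val info = new TaskInfo(1L, 0, 2, 0L, "exec-1", "host-1", TaskLocality.NODE_LOCAL, true)
        val json = JsonProtocol.taskInfoToJson(info)       // includes "Attempt" and "Speculative"
        val restored = JsonProtocol.taskInfoFromJson(json)
        assert(restored.attempt == 2 && restored.speculative)
      }
    }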
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 247f10173f1e9..32c5fdad75e58 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -54,17 +54,17 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
*/
@DeveloperApi
class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
- (implicit random: Random = new XORShiftRandom)
extends RandomSampler[T, T] {
- def this(ratio: Double)(implicit random: Random = new XORShiftRandom)
- = this(0.0d, ratio)(random)
+ private[random] var rng: Random = new XORShiftRandom
- override def setSeed(seed: Long) = random.setSeed(seed)
+ def this(ratio: Double) = this(0.0d, ratio)
+
+ override def setSeed(seed: Long) = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = {
items.filter { item =>
- val x = random.nextDouble()
+ val x = rng.nextDouble()
(x >= lb && x < ub) ^ complement
}
}
@@ -72,7 +72,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
/**
 * Return a sampler that is the complement of the range specified by the current sampler.
*/
- def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
+ def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
override def clone = new BernoulliSampler[T](lb, ub, complement)
}
@@ -81,21 +81,21 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
* :: DeveloperApi ::
* A sampler based on values drawn from Poisson distribution.
*
- * @param poisson a Poisson random number generator
+ * @param mean Poisson mean
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T](mean: Double)
- (implicit var poisson: Poisson = new Poisson(mean, new DRand))
- extends RandomSampler[T, T] {
+class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
+
+ private[random] var rng = new Poisson(mean, new DRand)
override def setSeed(seed: Long) {
- poisson = new Poisson(mean, new DRand(seed.toInt))
+ rng = new Poisson(mean, new DRand(seed.toInt))
}
override def sample(items: Iterator[T]): Iterator[T] = {
items.flatMap { item =>
- val count = poisson.nextInt()
+ val count = rng.nextInt()
if (count == 0) {
Iterator.empty
} else {
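For illustration, a minimal sketch (not part of the patch) of the samplers after the implicit RNG parameters are removed: each instance now owns its generator, which is why the suites below swap it by assigning the package-private `rng` field. The object name and sample sizes are placeholders:

    package org.apache.spark.util.random

    object SamplerSketch {
      def main(args: Array[String]): Unit = {
        val bernoulli = new BernoulliSampler[Int](0.5)   // keep roughly 50% of items, no replacement
        bernoulli.setSeed(42L)                           // seeds the internal XORShiftRandom
        val kept = bernoulli.sample(Iterator.range(0, 100)).size

        val poisson = new PoissonSampler[Int](0.5)       // Poisson(0.5) copies of each item
        poisson.setSeed(42L)                             // rebuilds the internal Poisson(mean, DRand)
        val emitted = poisson.sample(Iterator.range(0, 100)).size

        println(s"bernoulli kept $kept items, poisson emitted $emitted items")
      }
    }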
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 95ba273f16a71..9702838085627 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -24,6 +24,7 @@ import akka.testkit.TestActorRef
import org.scalatest.FunSuite
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.AkkaUtils
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 00c273df63b29..5dd8de319a654 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.rdd
import org.scalatest.FunSuite
import org.apache.spark.SharedSparkContext
-import org.apache.spark.util.random.RandomSampler
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler}
/** a sampler that outputs its seed */
class MockSampler extends RandomSampler[Long, Long] {
@@ -32,7 +32,7 @@ class MockSampler extends RandomSampler[Long, Long] {
}
override def sample(items: Iterator[Long]): Iterator[Long] = {
- return Iterator(s)
+ Iterator(s)
}
override def clone = new MockSampler
@@ -40,11 +40,21 @@ class MockSampler extends RandomSampler[Long, Long] {
class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
- test("seedDistribution") {
+ test("seed distribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler
val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
- assert(sample.distinct.count == 2, "Seeds must be different.")
+ assert(sample.distinct().count == 2, "Seeds must be different.")
+ }
+
+ test("concurrency") {
+ // SPARK-2251: zip with self computes each partition twice.
+ // We want to make sure there are no concurrency issues.
+ val rdd = sc.parallelize(0 until 111, 10)
+ for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
+ val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+ sampled.zip(sampled).count()
+ }
}
}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index e0fec6a068bd1..fa43b66c6cb5a 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -66,7 +66,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
// finish this task, should get updated shuffleRead
shuffleReadMetrics.remoteBytesRead = 1000
taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
- var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
var task = new ShuffleMapTask(0, null, null, 0, null)
val taskType = Utils.getFormattedClassName(task)
@@ -75,7 +75,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
.shuffleRead == 1000)
// finish a task with unknown executor-id, nothing should happen
- taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo =
+ new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true)
taskInfo.finishTime = 1
task = new ShuffleMapTask(0, null, null, 0, null)
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
@@ -84,7 +85,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
// finish this task, should get updated duration
shuffleReadMetrics.remoteBytesRead = 1000
taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
- taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
task = new ShuffleMapTask(0, null, null, 0, null)
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
@@ -94,7 +95,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
// finish this task, should get updated duration
shuffleReadMetrics.remoteBytesRead = 1000
taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
- taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
task = new ShuffleMapTask(0, null, null, 0, null)
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
@@ -106,7 +107,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val conf = new SparkConf()
val listener = new JobProgressListener(conf)
val metrics = new TaskMetrics()
- val taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
val task = new ShuffleMapTask(0, null, null, 0, null)
val taskType = Utils.getFormattedClassName(task)
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 495e1b7a0a214..6c49870455873 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -35,10 +35,11 @@ class JsonProtocolSuite extends FunSuite {
val stageSubmitted =
SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties)
val stageCompleted = SparkListenerStageCompleted(makeStageInfo(101, 201, 301, 401L, 501L))
- val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 444L))
- val taskGettingResult = SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 3000L))
+ val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 1, 444L, false))
+ val taskGettingResult =
+ SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 5, 3000L, true))
val taskEnd = SparkListenerTaskEnd(1, "ShuffleMapTask", Success,
- makeTaskInfo(123L, 234, 345L), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800))
+ makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800))
val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties)
val jobEnd = SparkListenerJobEnd(20, JobSucceeded)
val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]](
@@ -73,7 +74,7 @@ class JsonProtocolSuite extends FunSuite {
test("Dependent Classes") {
testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L))
testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L))
- testTaskInfo(makeTaskInfo(999L, 888, 777L))
+ testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false))
testTaskMetrics(makeTaskMetrics(33333L, 44444L, 55555L, 66666L, 7, 8))
testBlockManagerId(BlockManagerId("Hong", "Kong", 500, 1000))
@@ -269,10 +270,12 @@ class JsonProtocolSuite extends FunSuite {
private def assertEquals(info1: TaskInfo, info2: TaskInfo) {
assert(info1.taskId === info2.taskId)
assert(info1.index === info2.index)
+ assert(info1.attempt === info2.attempt)
assert(info1.launchTime === info2.launchTime)
assert(info1.executorId === info2.executorId)
assert(info1.host === info2.host)
assert(info1.taskLocality === info2.taskLocality)
+ assert(info1.speculative === info2.speculative)
assert(info1.gettingResultTime === info2.gettingResultTime)
assert(info1.finishTime === info2.finishTime)
assert(info1.failed === info2.failed)
@@ -453,8 +456,8 @@ class JsonProtocolSuite extends FunSuite {
new StageInfo(a, "greetings", b, rddInfos, "details")
}
- private def makeTaskInfo(a: Long, b: Int, c: Long) = {
- new TaskInfo(a, b, c, "executor", "your kind sir", TaskLocality.NODE_LOCAL)
+ private def makeTaskInfo(a: Long, b: Int, c: Int, d: Long, speculative: Boolean) = {
+ new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL, speculative)
}
private def makeTaskMetrics(a: Long, b: Long, c: Long, d: Long, e: Int, f: Int) = {
@@ -510,37 +513,60 @@ class JsonProtocolSuite extends FunSuite {
private val taskStartJsonString =
"""
- {"Event":"SparkListenerTaskStart","Stage ID":111,"Task Info":{"Task ID":222,
- "Index":333,"Launch Time":444,"Executor ID":"executor","Host":"your kind sir",
- "Locality":"NODE_LOCAL","Getting Result Time":0,"Finish Time":0,"Failed":false,
- "Serialized Size":0}}
- """
+ |{"Event":"SparkListenerTaskStart","Stage ID":111,"Task Info":{"Task ID":222,
+ |"Index":333,"Attempt":1,"Launch Time":444,"Executor ID":"executor","Host":"your kind sir",
+ |"Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,
+ |"Failed":false,"Serialized Size":0}}
+ """.stripMargin
private val taskGettingResultJsonString =
"""
- {"Event":"SparkListenerTaskGettingResult","Task Info":{"Task ID":1000,"Index":
- 2000,"Launch Time":3000,"Executor ID":"executor","Host":"your kind sir",
- "Locality":"NODE_LOCAL","Getting Result Time":0,"Finish Time":0,"Failed":false,
- "Serialized Size":0}}
- """
+ |{"Event":"SparkListenerTaskGettingResult","Task Info":
+ | {"Task ID":1000,"Index":2000,"Attempt":5,"Launch Time":3000,"Executor ID":"executor",
+ | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":true,"Getting Result Time":0,
+ | "Finish Time":0,"Failed":false,"Serialized Size":0
+ | }
+ |}
+ """.stripMargin
private val taskEndJsonString =
"""
- {"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask",
- "Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":123,"Index":
- 234,"Launch Time":345,"Executor ID":"executor","Host":"your kind sir",
- "Locality":"NODE_LOCAL","Getting Result Time":0,"Finish Time":0,"Failed":
- false,"Serialized Size":0},"Task Metrics":{"Host Name":"localhost",
- "Executor Deserialize Time":300,"Executor Run Time":400,"Result Size":500,
- "JVM GC Time":600,"Result Serialization Time":700,"Memory Bytes Spilled":
- 800,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Shuffle Finish Time":
- 900,"Total Blocks Fetched":1500,"Remote Blocks Fetched":800,"Local Blocks Fetched":
- 700,"Fetch Wait Time":900,"Remote Bytes Read":1000},"Shuffle Write Metrics":
- {"Shuffle Bytes Written":1200,"Shuffle Write Time":1500},"Updated Blocks":
- [{"Block ID":"rdd_0_0","Status":{"Storage Level":{"Use Disk":true,"Use Memory":true,
- "Use Tachyon":false,"Deserialized":false,"Replication":2},"Memory Size":0,"Tachyon Size":0,
- "Disk Size":0}}]}}
- """
+ |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask",
+ |"Task End Reason":{"Reason":"Success"},
+ |"Task Info":{
+ | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor",
+ | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false,
+ | "Getting Result Time":0,"Finish Time":0,"Failed":false,"Serialized Size":0
+ |},
+ |"Task Metrics":{
+ | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400,
+ | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700,
+ | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0,
+ | "Shuffle Read Metrics":{
+ | "Shuffle Finish Time":900,
+ | "Total Blocks Fetched":1500,
+ | "Remote Blocks Fetched":800,
+ | "Local Blocks Fetched":700,
+ | "Fetch Wait Time":900,
+ | "Remote Bytes Read":1000
+ | },
+ | "Shuffle Write Metrics":{
+ | "Shuffle Bytes Written":1200,
+ | "Shuffle Write Time":1500},
+ | "Updated Blocks":[
+ | {"Block ID":"rdd_0_0",
+ | "Status":{
+ | "Storage Level":{
+ | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false,
+ | "Replication":2
+ | },
+ | "Memory Size":0,"Tachyon Size":0,"Disk Size":0
+ | }
+ | }
+ | ]
+ | }
+ |}
+ """.stripMargin
private val jobStartJsonString =
"""
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
index e166787f17544..36877476e708e 100644
--- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
+ val sampler = new BernoulliSampler[Int](0.25, 0.55)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
}
}
@@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+ val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
}
}
@@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.35)(random)
+ val sampler = new BernoulliSampler[Int](0.35)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
}
}
@@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+ val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
}
}
@@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
random.setSeed(10L)
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.2)(random)
+ val sampler = new BernoulliSampler[Int](0.2)
+ sampler.rng = random
sampler.setSeed(10L)
}
}
@@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(poisson) {
- val sampler = new PoissonSampler[Int](0.2)(poisson)
+ val sampler = new PoissonSampler[Int](0.2)
+ sampler.rng = poisson
assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b20b5de8c46eb..fb517e40677ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -257,8 +257,11 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
* Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
*/
object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
- // split the condition expression into 3 parts,
- // (canEvaluateInLeftSide, canEvaluateInRightSide, haveToEvaluateWithBothSide)
+ /**
+ * Splits join condition expressions into three categories based on the attributes required
+ * to evaluate them.
+   * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
+ */
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
val (leftEvaluateCondition, rest) =
condition.partition(_.references subsetOf left.outputSet)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index a43bef389c4bf..026692abe067d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -105,57 +105,39 @@ object PhysicalOperation extends PredicateHelper {
}
/**
- * A pattern that finds joins with equality conditions that can be evaluated using hashing
- * techniques. For inner joins, any filters on top of the join operator are also matched.
+ * A pattern that finds joins with equality conditions that can be evaluated as equi-joins.
*/
-object HashFilteredJoin extends Logging with PredicateHelper {
+object ExtractEquiJoinKeys extends Logging with PredicateHelper {
  /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
type ReturnType =
(JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan)
def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
- // All predicates can be evaluated for inner join (i.e., those that are in the ON
- // clause and WHERE clause.)
- case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
- logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
- splitPredicates(predicates ++ condition, join)
- // All predicates can be evaluated for left semi join (those that are in the WHERE
- // clause can only from left table, so they can all be pushed down.)
- case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
- logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
- splitPredicates(predicates ++ condition, join)
case join @ Join(left, right, joinType, condition) =>
- logger.debug(s"Considering hash join on: $condition")
- splitPredicates(condition.toSeq, join)
- case _ => None
- }
-
- // Find equi-join predicates that can be evaluated before the join, and thus can be used
- // as join keys.
- def splitPredicates(allPredicates: Seq[Expression], join: Join): Option[ReturnType] = {
- val Join(left, right, joinType, _) = join
- val (joinPredicates, otherPredicates) =
- allPredicates.flatMap(splitConjunctivePredicates).partition {
- case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
- (canEvaluate(l, right) && canEvaluate(r, left)) => true
- case _ => false
+ logger.debug(s"Considering join on: $condition")
+ // Find equi-join predicates that can be evaluated before the join, and thus can be used
+ // as join keys.
+ val (joinPredicates, otherPredicates) =
+ condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
+ case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
+ (canEvaluate(l, right) && canEvaluate(r, left)) => true
+ case _ => false
+ }
+
+ val joinKeys = joinPredicates.map {
+ case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
+ case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
}
-
- val joinKeys = joinPredicates.map {
- case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
- case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
- }
-
- // Do not consider this strategy if there are no join keys.
- if (joinKeys.nonEmpty) {
val leftKeys = joinKeys.map(_._1)
val rightKeys = joinKeys.map(_._2)
- Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
- } else {
- logger.debug(s"Avoiding hash join with no join keys.")
- None
- }
+ if (joinKeys.nonEmpty) {
+ logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}")
+ Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
+ } else {
+ None
+ }
+ case _ => None
}
}
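For illustration, a minimal sketch (not part of the patch) of what a planner rule sees when it matches the renamed extractor; the tuple is (joinType, leftKeys, rightKeys, otherCondition, left, right), and `describeJoin` is a hypothetical helper, not Spark API:

    package org.apache.spark.sql.catalyst.planning

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    object EquiJoinKeysSketch {
      def describeJoin(plan: LogicalPlan): String = plan match {
        case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, otherCondition, left, right) =>
          s"$joinType equi-join on ${leftKeys.zip(rightKeys)}" +
            otherCondition.map(c => s" with residual condition $c").getOrElse("")
        case _ =>
          "not an equi-join"   // e.g. a join whose condition has no usable EqualTo predicates
      }
    }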
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3cd29967d1cd5..0925605b7c4d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -31,9 +31,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- // Find left semi joins where at least some predicates can be evaluated by matching hash
- // keys using the HashFilteredJoin pattern.
- case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+ // Find left semi joins where at least some predicates can be evaluated by matching join keys
+ case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = execution.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
@@ -46,7 +45,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Uses the HashFilteredJoin pattern to find joins where at least some of the predicates can be
+ * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be
* evaluated by matching hash keys.
*/
object HashJoin extends Strategy with PredicateHelper {
@@ -65,7 +64,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case HashFilteredJoin(
+ case ExtractEquiJoinKeys(
Inner,
leftKeys,
rightKeys,
@@ -75,7 +74,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if broadcastTables.contains(b.tableName) =>
broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
- case HashFilteredJoin(
+ case ExtractEquiJoinKeys(
Inner,
leftKeys,
rightKeys,
@@ -85,7 +84,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if broadcastTables.contains(b.tableName) =>
broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
- case HashFilteredJoin(Inner, leftKeys, rightKeys, condition, left, right) =>
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val hashJoin =
execution.ShuffledHashJoin(
leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))