[SPARK-14007] [SQL] Manage the memory used by hash map in shuffled hash join #11826
Changes from all commits
@@ -17,17 +17,18 @@
 package org.apache.spark.sql.execution.joins

+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
- * Performs an inner hash join of two child relations by first shuffling the data using the join
- * keys.
+ * Performs a hash join of two child relations by first shuffling the data using the join keys.
  */
 case class ShuffledHashJoin(
     leftKeys: Seq[Expression],
@@ -55,11 +56,45 @@ case class ShuffledHashJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

+  private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+    // Try to acquire some memory for the hash table; this may trigger other operators to free
+    // some memory. The memory acquired here will mostly be held until the end of the task.
+    val context = TaskContext.get()
+    val memoryManager = context.taskMemoryManager()
+    var acquired = 0L
+    var used = 0L
+    context.addTaskCompletionListener((t: TaskContext) =>
+      memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
+    )
+
+    val copiedIter = iter.map { row =>
+      // It's hard to guess exactly how much memory will be used, so we make a rough guess here.
+      // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
+      // Each pair in the HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
+      val needed = 150 + row.getSizeInBytes
+      if (needed > acquired - used) {
+        val got = memoryManager.acquireExecutionMemory(
+          Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+        if (got < needed) {
+          throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
+            "hash join, please use sort merge join by setting " +
+            "spark.sql.join.preferSortMergeJoin=true")
+        }

[Review comment] nit: missing space between "shuffled" and "hash" in the exception message.

+        acquired += got

[Review comment] Given that we release the …

[Reply] The acquired memory will be accounted in task memory manager itself, the …

+      }
+      used += needed
+      // HashedRelation requires that the UnsafeRows be separate objects.
+      row.copy()
+    }
+
+    HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
+  }
+
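For readers outside the Spark codebase, the pattern the new method uses — reserve memory in page-sized chunks, track a separate `used` counter so the manager is not called once per row, and release everything in one shot when the task finishes — can be sketched without Spark's internals. The `FakeMemoryManager` below is invented for this example and is not Spark's `TaskMemoryManager` (whose real API differs); a plain try/finally stands in for the task-completion listener:

```scala
object MemoryAcquisitionSketch {

  // Hypothetical stand-in for Spark's TaskMemoryManager: grants at most `limit` bytes in total.
  class FakeMemoryManager(limit: Long) {
    val pageSizeBytes: Long = 1L << 20 // pretend a "page" is 1 MB
    private var granted = 0L
    def acquireExecutionMemory(requested: Long): Long = {
      val got = math.min(requested, limit - granted) // grant whatever is left, up to the request
      granted += got
      got
    }
    def releaseExecutionMemory(bytes: Long): Unit = granted -= bytes
  }

  // Same bookkeeping as buildHashedRelation: over-reserve in page-sized chunks,
  // fail fast if the reservation cannot be met, release everything at the end.
  def buildWithReservation(rowSizes: Iterator[Long], mm: FakeMemoryManager): Long = {
    var acquired = 0L // bytes reserved with the manager so far
    var used = 0L     // bytes the rough per-entry estimate says the map consumes
    try {
      rowSizes.foreach { rowSize =>
        val needed = 150L + rowSize // per-entry heuristic from the patch
        if (needed > acquired - used) {
          val got = mm.acquireExecutionMemory(math.max(mm.pageSizeBytes, needed))
          if (got < needed) {
            throw new RuntimeException("can't reserve enough memory for the hash map")
          }
          acquired += got
        }
        used += needed
      }
      used
    } finally {
      // The patch releases in a TaskCompletionListener; try/finally plays that role here.
      mm.releaseExecutionMemory(acquired)
    }
  }

  def main(args: Array[String]): Unit = {
    val mm = new FakeMemoryManager(limit = 4L << 20)
    println(buildWithReservation(Iterator.fill(1000)(64L), mm)) // prints 214000
  }
}
```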
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")

     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
-      val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
+      val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
       val joinedRow = new JoinedRow
       joinType match {
         case Inner =>
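To see the shape of what `doExecute` does per pair of co-partitioned iterators, here is a toy inner hash join over plain Scala collections: the build side is hashed first, then the stream side is probed against it. This is an illustration only; it ignores partitioning, codegen, null handling, and the memory accounting above, and an immutable `Map` stands in for `HashedRelation`:

```scala
// `build` plays the role of buildIter and `stream` of streamIter.
def hashJoin[K, L, R](stream: Iterator[(K, L)], build: Iterator[(K, R)]): Iterator[(L, R)] = {
  // Build phase: hash the (hopefully smaller) build side by join key.
  val hashed: Map[K, Seq[(K, R)]] = build.toSeq.groupBy(_._1)
  // Probe phase: stream the other side through the map, as in the Inner case above.
  stream.flatMap { case (k, l) =>
    hashed.getOrElse(k, Seq.empty).map { case (_, r) => (l, r) }
  }
}

// hashJoin(Iterator("a" -> 1, "b" -> 2), Iterator("a" -> 10, "a" -> 20)).toList
// == List((1,10), (1,20))
```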
[Review comment] Can you please explain (and possibly comment in code) the reason behind choosing 150?
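The thread does not answer this here, but a plausible back-of-envelope, assuming a 64-bit JVM with compressed oops, lands near 150 bytes per entry for the objects the code comment lists. Every figure below is an illustrative assumption, not a measurement:

```scala
// Rough per-entry overhead of the java.util.HashMap-based relation.
val unsafeRows    = 2 * 32L // two UnsafeRow wrappers: ~12 B header + fields, aligned to ~32 B
val compactBuffer = 40L     // one CompactBuffer object plus its backing array header
val pointers      = 10 * 4L // "maybe 10+ pointers" at 4 B each under compressed oops
println(unsafeRows + compactBuffer + pointers) // 144 — in the neighborhood of 150
```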