[SPARK-14007] [SQL] Manage the memory used by hash map in shuffled hash join #11826

Closed
wants to merge 4 commits into from
Changes from all commits
@@ -268,8 +268,8 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
// there is not enough memory actually; it means the actual free memory is smaller than
// what the MemoryManager thought, so we should keep the acquired memory.
acquiredButNotUsed += acquired;
synchronized (this) {
acquiredButNotUsed += acquired;
allocatedPages.clear(pageNumber);
}
// this could trigger spilling to free some pages.
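The fix above moves both pieces of bookkeeping (crediting acquiredButNotUsed and clearing the page number) into one synchronized block, so a concurrent thread cannot observe a page number that is still marked allocated after the acquisition has been abandoned. A minimal sketch of the pattern in Scala; the class and field names below are illustrative, not Spark's:

```scala
// Hedged sketch of the locking pattern, not Spark code: both fields are updated
// under the same lock, mirroring the synchronized block added in allocatePage.
class PageBookkeeper {
  private val allocatedPages = new java.util.BitSet(64)
  private var acquiredButNotUsed = 0L

  def abandonFailedAllocation(pageNumber: Int, acquired: Long): Unit = synchronized {
    // Keep the acquired-but-unused memory accounted for...
    acquiredButNotUsed += acquired
    // ...and release the page number atomically with respect to other threads.
    allocatedPages.clear(pageNumber)
  }
}
```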
@@ -100,13 +100,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
* of the join will be broadcasted and the other side will be streamed, with no shuffling
* performed. If both sides of the join are eligible to be broadcasted then the
* - Shuffle hash join: if single partition is small enough to build a hash table.
* - Shuffle hash join: if the average size of a single partition is small enough to build a hash
* table.
* - Sort merge: if the matching join keys are sortable.
*/
object EquiJoinSelection extends Strategy with PredicateHelper {

/**
* Matches a plan whose single partition should be small enough to build a hash table.
*
* Note: this assumes that the number of partitions is fixed; additional work is required if it
* is dynamic.
*/
def canBuildHashMap(plan: LogicalPlan): Boolean = {
plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
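For intuition, the predicate above budgets the whole build side at autoBroadcastJoinThreshold per shuffle partition, i.e. it only requires the average partition to fit in a hash table. A rough worked example with assumed values (the 10 MB threshold and 200 partitions are assumptions, not part of the patch):

```scala
// Hedged illustration of the canBuildHashMap threshold with assumed defaults.
val autoBroadcastJoinThreshold = 10L * 1024 * 1024  // 10 MB, assumed value of spark.sql.autoBroadcastJoinThreshold
val numShufflePartitions = 200                      // assumed value of spark.sql.shuffle.partitions
val planSizeInBytes = 1L * 1024 * 1024 * 1024       // a hypothetical 1 GB build side

// Same shape as the predicate in canBuildHashMap: total size vs. threshold * partitions.
val canBuild = planSizeInBytes < autoBroadcastJoinThreshold * numShufflePartitions
// 1 GB < 10 MB * 200 (~2 GB), so canBuild is true; each partition averages about 5 MB.
```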
@@ -109,51 +109,6 @@ private[execution] trait UniqueHashedRelation extends HashedRelation {
}
}

/**
* A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
*/
private[joins] class GeneralHashedRelation(
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {

// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)

override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)

override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
}

override def readExternal(in: ObjectInput): Unit = {
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
}
}


/**
* A specialized [[HashedRelation]] that maps key into a single value. This implementation
* assumes the key is unique.
*/
private[joins] class UniqueKeyHashedRelation(
private var hashTable: JavaHashMap[InternalRow, InternalRow])
extends UniqueHashedRelation with Externalizable {

// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)

override def getValue(key: InternalRow): InternalRow = hashTable.get(key)

override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
}

override def readExternal(in: ObjectInput): Unit = {
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
}
}


private[execution] object HashedRelation {

/**
@@ -162,51 +117,16 @@ private[execution] object HashedRelation {
* Note: The caller should make sure that these InternalRow are different objects.
*/
def apply(
canJoinKeyFitWithinLong: Boolean,
input: Iterator[InternalRow],
keyGenerator: Projection,
sizeEstimate: Int = 64): HashedRelation = {

if (keyGenerator.isInstanceOf[UnsafeProjection]) {
return UnsafeHashedRelation(
input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
}

// TODO: Use Spark's HashMap implementation.
val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
var currentRow: InternalRow = null

// Whether the join key is unique. If the key is unique, we can convert the underlying
// hash map into one specialized for this.
var keyIsUnique = true

// Create a mapping of buildKeys -> rows
while (input.hasNext) {
currentRow = input.next()
val rowKey = keyGenerator(currentRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new CompactBuffer[InternalRow]()
hashTable.put(rowKey.copy(), newMatchList)
newMatchList
} else {
keyIsUnique = false
existingMatchList
}
matchList += currentRow
}
}

if (keyIsUnique) {
val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size)
val iter = hashTable.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
uniqHashTable.put(entry.getKey, entry.getValue()(0))
}
new UniqueKeyHashedRelation(uniqHashTable)
if (canJoinKeyFitWithinLong) {
LongHashedRelation(input, keyGenerator, sizeEstimate)
} else {
new GeneralHashedRelation(hashTable)
UnsafeHashedRelation(
input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
}
}
}
@@ -428,6 +348,7 @@ private[joins] object UnsafeHashedRelation {
sizeEstimate: Int): HashedRelation = {

// Use a Java hash table here because unsafe maps expect fixed size records
// TODO: Use BytesToBytesMap for memory efficiency
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)

// Create a mapping of buildKeys -> rows
@@ -683,11 +604,7 @@ private[execution] case class HashedRelationBroadcastMode(

override def transform(rows: Array[InternalRow]): HashedRelation = {
val generator = UnsafeProjection.create(keys, attributes)
if (canJoinKeyFitWithinLong) {
LongHashedRelation(rows.iterator, generator, rows.length)
} else {
HashedRelation(rows.iterator, generator, rows.length)
}
HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
}

private lazy val canonicalizedKeys: Seq[Expression] = {
@@ -703,4 +620,3 @@ private[execution] case class HashedRelationBroadcastMode(
case _ => false
}
}

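With the refactor in HashedRelation.apply, callers such as HashedRelationBroadcastMode.transform no longer pick the map type themselves; they pass canJoinKeyFitWithinLong and let the single entry point dispatch to LongHashedRelation or UnsafeHashedRelation. A minimal usage sketch (the helper and its inputs are assumed for illustration; only the final call mirrors the diff):

```scala
// Hedged usage sketch of the unified HashedRelation.apply; `buildRows` and
// `keyGenerator` stand in for whatever the caller already has. Assumes this
// lives alongside HashedRelation in org.apache.spark.sql.execution.joins.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection

def buildForJoin(
    canJoinKeyFitWithinLong: Boolean,
    buildRows: Iterator[InternalRow],
    keyGenerator: UnsafeProjection): HashedRelation = {
  // The entry point decides between the long-keyed and the UnsafeRow-keyed map.
  HashedRelation(canJoinKeyFitWithinLong, buildRows, keyGenerator, sizeEstimate = 64)
}
```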
@@ -17,17 +17,18 @@

package org.apache.spark.sql.execution.joins

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.memory.MemoryMode
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
* Performs an inner hash join of two child relations by first shuffling the data using the join
* keys.
* Performs a hash join of two child relations by first shuffling the data using the join keys.
*/
case class ShuffledHashJoin(
leftKeys: Seq[Expression],
@@ -55,11 +56,45 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
// Try to acquire some memory for the hash table; this could trigger other operators to free
// some memory. The memory acquired here will mostly be held until the end of the task.
val context = TaskContext.get()
val memoryManager = context.taskMemoryManager()
var acquired = 0L
var used = 0L
context.addTaskCompletionListener((t: TaskContext) =>
memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
)

val copiedIter = iter.map { row =>
// It's hard to guess exactly how much memory will be used, so we make a rough guess here.
// TODO: use BytesToBytesMap instead of HashMap for memory efficiency
// Each pair in the HashMap will have two UnsafeRows, one CompactBuffer, and maybe 10+ pointers
val needed = 150 + row.getSizeInBytes
Member: Can you please explain (and possibly comment in code) the reason behind choosing 150?

if (needed > acquired - used) {
val got = memoryManager.acquireExecutionMemory(
Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
if (got < needed) {
throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
Contributor: nit: missing space between shuffled and hash

"hash join, please use sort merge join by setting " +
"spark.sql.join.preferSortMergeJoin=true")
}
acquired += got
Member: Given that we release the acquired memory on task failure/completion in L67, shouldn't we update the value of acquired memory before throwing the Spark exception above?

Contributor Author: The acquired memory will be accounted for in the task memory manager itself; the memory acquired here will be used anyway.

}
used += needed
// HashedRelation requires that the UnsafeRows be separate objects.
row.copy()
}

HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
}

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
val joinedRow = new JoinedRow
joinType match {
case Inner =>
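As a rough check on the 150 + row.getSizeInBytes estimate used in buildHashedRelation (the constant covers the map entry, the CompactBuffer, and assorted pointers per pair, as the comment notes), here is a back-of-the-envelope sketch with assumed numbers:

```scala
// Hedged arithmetic sketch, not Spark code: estimating the hash map footprint for
// one partition using the same per-entry guess as buildHashedRelation.
val rowSizeInBytes = 48L        // assumed average UnsafeRow size on the build side
val perEntryOverhead = 150L     // constant from the patch: entry objects, CompactBuffer, pointers
val rowsInPartition = 1000000L  // assumed build-side rows in this partition

val estimatedBytes = rowsInPartition * (perEntryOverhead + rowSizeInBytes)
// ≈188 MiB for this partition. Memory is acquired lazily as rows stream in, in chunks of
// at least taskMemoryManager.pageSizeBytes(), and released by the task-completion listener.
println(s"estimated hash map footprint: ${estimatedBytes / (1024 * 1024)} MB")
```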
@@ -29,42 +29,6 @@ import org.apache.spark.util.collection.CompactBuffer

class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {

// Key is simply the record itself
private val keyProjection = new Projection {
override def apply(row: InternalRow): InternalRow = row
}

test("GeneralHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])

assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(InternalRow(10)) === null)

val data2 = CompactBuffer[InternalRow](data(2))
data2 += data(2)
assert(hashed.get(data(2)) === data2)
}

test("UniqueKeyHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])

assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
assert(hashed.get(InternalRow(10)) === null)

val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
assert(uniqHashed.getValue(data(0)) === data(0))
assert(uniqHashed.getValue(data(1)) === data(1))
assert(uniqHashed.getValue(data(2)) === data(2))
assert(uniqHashed.getValue(InternalRow(10)) === null)
}

test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))